From a533b16b355c8bc025ad25182c9aceb2b067d11e Mon Sep 17 00:00:00 2001
From: Swainlab <peter.swain@ed.ac.uk>
Date: Fri, 28 Jul 2023 13:14:20 +0100
Subject: [PATCH] finished docs on tracks; rewrote for transparency

---
 src/postprocessor/core/functions/tracks.py | 194 ++++++++++++++++-----
 1 file changed, 147 insertions(+), 47 deletions(-)

diff --git a/src/postprocessor/core/functions/tracks.py b/src/postprocessor/core/functions/tracks.py
index 8deecb39..fc7f445b 100644
--- a/src/postprocessor/core/functions/tracks.py
+++ b/src/postprocessor/core/functions/tracks.py
@@ -5,8 +5,7 @@
 We call two tracks contiguous if they are adjacent in time: the
 maximal time point of one is one time point less than the minimal
 time point of the other.
-A right track can have multiple potential left tracks. We must
-pick the best.
+A right track can have multiple potential left tracks. We pick the best.
 """

 import typing as t
@@ -22,7 +21,97 @@
 from utils_find_1st import cmp_larger, find_1st

 from postprocessor.core.processes.savgol import non_uniform_savgol
-def get_joinable(tracks, smooth=False, tol=0.1, window=5, degree=3) -> dict:
+def get_joinable(tracks, smooth=False, tol=0.2, window=5, degree=3) -> dict:
+    """
+    Find all pairs of tracks that should be joined.
+
+    Each track is defined by (trap_id, cell_id).
+
+    If there are multiple choices of which, say, left tracks to join to a
+    right track, pick the best using the Signal values to do so.
+
+    To score two tracks, we predict the future value of a left track and
+    compare with the mean initial values of a right track.
+
+    Parameters
+    ----------
+    tracks: pd.DataFrame
+        A Signal, usually area, where rows are cell tracks and columns are
+        time points.
+    smooth: boolean
+        If True, smooth tracks with a savgol_filter.
+    tol: float < 1 or int
+        If int, compare the absolute distance between predicted values
+        for the left and right end points of two contiguous tracks.
+        If float, compare the distance relative to the magnitude of the
+        end point of the left track.
+    window: int
+        Length of window used for predictions and for any savgol_filter.
+    degree: int
+        The degree of the polynomial used by the savgol_filter.
+    """
+    # only consider time series with more than two non-NaN data points
+    tracks = tracks.loc[tracks.notna().sum(axis=1) > 2]
+    # get contiguous tracks
+    if smooth:
+        # specialise to tracks with growing cells and of long duration
+        clean = clean_tracks(tracks, min_duration=window + 1, min_gr=0.9)
+        contigs = clean.groupby(["trap"]).apply(get_contiguous_pairs)
+    else:
+        contigs = tracks.groupby(["trap"]).apply(get_contiguous_pairs)
+    # remove traps with no contiguous tracks
+    contigs = contigs.loc[contigs.apply(len) > 0]
+    # flatten to (trap, cell_id) pairs
+    flat = set([k for v in contigs.values for i in v for j in i for k in j])
+    # make a data frame of contiguous tracks with the tracks as arrays
+    if smooth:
+        smoothed_tracks = clean.loc[flat].apply(
+            lambda x: non_uniform_savgol(x.index, x.values, window, degree),
+            axis=1,
+        )
+    else:
+        smoothed_tracks = tracks.loc[flat].apply(
+            lambda x: np.array(x.values), axis=1
+        )
+    # get the Signal values for neighbouring end points of contiguous tracks
+    actual_edge_values = contigs.apply(
+        lambda x: get_edge_values(x, smoothed_tracks)
+    )
+    # get the predicted values
+    predicted_edge_values = contigs.apply(
+        lambda x: get_predicted_edge_values(x, smoothed_tracks, window)
+    )
+    # score predicted edge values: low values are best
+    prediction_scores = predicted_edge_values.apply(get_dMetric_wrap)
+    # find contiguous tracks to join for each trap
+    trap_contigs_to_join = []
+    for idx in contigs.index:
+        local_contigs = contigs.loc[idx]
+        # find indices of best left and right tracks to join
+        best_indices = find_best_from_scores_wrap(
+            prediction_scores.loc[idx], actual_edge_values.loc[idx], tol=tol
+        )
+        # find tracks from the indices
+        trap_contigs_to_join.append(
+            [
+                (contig[0][left], contig[1][right])
+                for best_index, contig in zip(best_indices, local_contigs)
+                for (left, right) in best_index
+                if best_index
+            ]
+        )
+    # return only the pairs of contiguous tracks
+    contigs_to_join = [
+        contigs
+        for trap_tracks in trap_contigs_to_join
+        for contigs in trap_tracks
+    ]
+    return contigs_to_join
+
+
+def get_joinable_original(
+    tracks, smooth=False, tol=0.1, window=5, degree=3
+) -> dict:
     """
     Get the pair of track (without repeats) that have a smaller error than the
     tolerance. If there is a track that can be assigned to two or more other
@@ -66,34 +155,41 @@
             lambda x: np.array(x.values), axis=1
         )
     # get the Signal values for neighbouring end points of contiguous tracks
-    actual_edges = contigs.apply(lambda x: get_edge_values(x, smoothed_tracks))
+    actual_edge_values = contigs.apply(
+        lambda x: get_edge_values(x, smoothed_tracks)
+    )
     # get the predicted values
-    predicted_edges = contigs.apply(
+    predicted_edge_values = contigs.apply(
         lambda x: get_predicted_edge_values(x, smoothed_tracks, window)
     )
-    # Prediction of pre and mean of post
-    prediction_costs = predicted_edges.apply(get_dMetric_wrap, tol=tol)
+    # score predicted edge values
+    prediction_scores = predicted_edge_values.apply(get_dMetric_wrap, tol=tol)
     solutions = [
-        solve_matrices_wrap(cost, edges, tol=tol)
-        for (trap_id, cost), edges in zip(
-            prediction_costs.items(), actual_edges
+        # for all sets of contigs at a trap
+        find_best_from_scores_wrap(cost, edge_values, tol=tol)
+        for (trap_id, cost), edge_values in zip(
+            prediction_scores.items(), actual_edge_values
         )
     ]
-    breakpoint()
     closest_pairs = pd.Series(
         solutions,
-        index=edges_dMetric_pred.index,
+        index=prediction_scores.index,
    )
     # match local with global ids
     joinable_ids = [
         localid_to_idx(closest_pairs.loc[i], contigs.loc[i])
         for i in closest_pairs.index
     ]
-    return [pair for pairset in joinable_ids for pair in pairset]
+
+    contigs_to_join = [
+        contigs for trap_tracks in joinable_ids for contigs in trap_tracks
+    ]
+    return contigs_to_join


 def load_test_dset():
-    """Load development dataset to test functions."""
+    """Load test data set."""
     return pd.DataFrame(
         {
             ("a", 1, 1): [2, 5, np.nan, 6, 8] + [np.nan] * 5,
@@ -345,49 +441,59 @@
     return [get_dMetric(*sublist, **kwargs) for sublist in lst]


-def solve_matrices_wrap(dMetric: List, edges: List, **kwargs):
+def find_best_from_scores_wrap(dMetric: List, edges: List, **kwargs):
     """Calculate solve_matrices on a list."""
     return [
-        solve_matrices(mat, edgeset, **kwargs)
+        find_best_from_scores(mat, edgeset, **kwargs)
         for mat, edgeset in zip(dMetric, edges)
     ]


-def get_dMetric(pre: List[float], post: List[float], tol):
+def get_dMetric(pre_values: List[float], post_values: List[float]):
     """
-    Calculate a cost matrix based on the difference between two Signal
+    Calculate a scoring matrix based on the difference between two Signal
     values.

+    We generate one score per pair of contiguous tracks.
+
+    Lower scores are better.
+
     Parameters
     ----------
-    pre: list of floats
+    pre_values: list of floats
         Values of the Signal for left contiguous tracks.
-    post: list of floats
+    post_values: list of floats
         Values of the Signal for right contiguous tracks.
     """
-    if len(pre) > len(post):
-        dMetric = np.abs(np.subtract.outer(post, pre))
+    if len(pre_values) > len(post_values):
+        dMetric = np.abs(np.subtract.outer(post_values, pre_values))
     else:
-        dMetric = np.abs(np.subtract.outer(pre, post))
-    # replace NaNs with maximal cost values
-    dMetric[np.isnan(dMetric)] = tol + 1 + np.nanmax(dMetric)
+        dMetric = np.abs(np.subtract.outer(pre_values, post_values))
+    # replace NaNs with maximal values
+    dMetric[np.isnan(dMetric)] = 1 + np.nanmax(dMetric)
     return dMetric


-def solve_matrices(cost: np.ndarray, edges: List, tol: Union[float, int] = 1):
-    """
-    Solve the distance matrices obtained in get_dMetric and/or merged from
-    independent dMetric matrices.
-    """
-    ids = solve_matrix(cost)
+def find_best_from_scores(
+    scores: np.ndarray, actual_edge_values: List, tol: Union[float, int] = 1
+):
+    """Find indices for left and right contiguous tracks with scores below a tolerance."""
+    ids = find_best_indices(scores)
     if len(ids[0]):
-        pre, post = edges
+        pre_value, post_value = actual_edge_values
+        # score with relative or absolute distance
         norm = (
-            np.array(pre)[ids[len(pre) > len(post)]] if tol < 1 else 1
-        )  # relative or absolute tol
-        result = dMetric[ids] / norm
-        ids = ids if len(pre) < len(post) else ids[::-1]
-        return [idx for idx, res in zip(zip(*ids), result) if res <= tol]
+            np.array(pre_value)[ids[len(pre_value) > len(post_value)]]
+            if tol < 1
+            else 1
+        )
+        best_scores = scores[ids] / norm
+        ids = ids if len(pre_value) < len(post_value) else ids[::-1]
+        # keep only indices with best_score less than the tolerance
+        indices = [
+            idx for idx, score in zip(zip(*ids), best_scores) if score <= tol
+        ]
+        return indices
     else:
         return []
@@ -409,18 +515,18 @@ def get_closest_pairs(
     """
     dMetric = get_dMetric(pre, post, tol)
-    return solve_matrices(dMetric, pre, post, tol)
+    return find_best_from_scores(dMetric, pre, post, tol)


-def solve_matrix(dMetric):
-    """Arrange indices to the cost matrix in order of increasing cost."""
+def find_best_indices(dMetric):
+    """Find indices for left and right contiguous tracks with minimal scores."""
     glob_is = []
     glob_js = []
     if (~np.isnan(dMetric)).any():
         lMetric = copy(dMetric)
         sortedMetric = sorted(lMetric[~np.isnan(lMetric)])
         while (~np.isnan(sortedMetric)).any():
-            # indices of point with minimal cost
+            # indices of point with the lowest score
             i_s, j_s = np.where(lMetric == sortedMetric[0])
             i = i_s[0]
             j = j_s[0]
@@ -432,7 +538,6 @@
             lMetric[:, j] += np.nan
             sortedMetric = sorted(lMetric[~np.isnan(lMetric)])
         indices = (np.array(glob_is), np.array(glob_js))
-        breakpoint()
     return indices
@@ -476,11 +581,6 @@ def get_contiguous_pairs(tracks: pd.DataFrame) -> list:
         tracks.notna().apply(np.where, axis=1).apply(fn)
         for fn in (np.min, np.max)
     ]
-    # mins.name = "min_tpt"
-    # maxs.name = "max_tpt"
-    # df = pd.merge(mins, maxs, right_index=True, left_index=True)
-    # df["duration"] = df.max_tpt - df.min_tpt
-    # # flip so that time points become the index
     mins_d = mins.groupby(mins).apply(lambda x: x.index.tolist())
     maxs_d = maxs.groupby(maxs).apply(lambda x: x.index.tolist())
--
GitLab