From a533b16b355c8bc025ad25182c9aceb2b067d11e Mon Sep 17 00:00:00 2001
From: Swainlab <peter.swain@ed.ac.uk>
Date: Fri, 28 Jul 2023 13:14:20 +0100
Subject: [PATCH] finished docs on tracks; rewrote for transparency

---
 src/postprocessor/core/functions/tracks.py | 194 ++++++++++++++++-----
 1 file changed, 147 insertions(+), 47 deletions(-)

diff --git a/src/postprocessor/core/functions/tracks.py b/src/postprocessor/core/functions/tracks.py
index 8deecb39..fc7f445b 100644
--- a/src/postprocessor/core/functions/tracks.py
+++ b/src/postprocessor/core/functions/tracks.py
@@ -5,8 +5,7 @@
 We call two tracks contiguous if they are adjacent in time: the
 maximal time point of one is one time point less than the minimal
 time point of the other.
-A right track can have multiple potential left tracks. We must
-pick the best.
+A right track can have multiple potential left tracks. We pick the best.
 """

 import typing as t
@@ -22,7 +21,97 @@
 from utils_find_1st import cmp_larger, find_1st

 from postprocessor.core.processes.savgol import non_uniform_savgol
-def get_joinable(tracks, smooth=False, tol=0.1, window=5, degree=3) -> dict:
+def get_joinable(tracks, smooth=False, tol=0.2, window=5, degree=3) -> dict:
+    """
+    Find all pairs of tracks that should be joined.
+
+    Each track is defined by (trap_id, cell_id).
+
+    If there are multiple choices of which, say, left tracks to join to a
+    right track, pick the best using the Signal values to do so.
+
+    To score two tracks, we predict the future value of a left track and
+    compare with the mean initial values of a right track.
+
+    Parameters
+    ----------
+    tracks: pd.DataFrame
+        A Signal, usually area, where rows are cell tracks and columns are
+        time points.
+    smooth: boolean
+        If True, smooth tracks with a savgol_filter.
+    tol: float < 1 or int
+        If int, compare the absolute distance between predicted values
+        for the left and right end points of two contiguous tracks.
+        If float, compare the distance relative to the magnitude of the
+        end point of the left track.
+    window: int
+        Length of window used for predictions and for any savgol_filter.
+    degree: int
+        The degree of the polynomial used by the savgol_filter.
+    """
+    # only consider time series with more than two non-NaN data points
+    tracks = tracks.loc[tracks.notna().sum(axis=1) > 2]
+    # get contiguous tracks
+    if smooth:
+        # specialise to tracks with growing cells and of long duration
+        clean = clean_tracks(tracks, min_duration=window + 1, min_gr=0.9)
+        contigs = clean.groupby(["trap"]).apply(get_contiguous_pairs)
+    else:
+        contigs = tracks.groupby(["trap"]).apply(get_contiguous_pairs)
+    # remove traps with no contiguous tracks
+    contigs = contigs.loc[contigs.apply(len) > 0]
+    # flatten to (trap, cell_id) pairs
+    flat = set([k for v in contigs.values for i in v for j in i for k in j])
+    # make a data frame of contiguous tracks with the tracks as arrays
+    if smooth:
+        smoothed_tracks = clean.loc[flat].apply(
+            lambda x: non_uniform_savgol(x.index, x.values, window, degree),
+            axis=1,
+        )
+    else:
+        smoothed_tracks = tracks.loc[flat].apply(
+            lambda x: np.array(x.values), axis=1
+        )
+    # get the Signal values for neighbouring end points of contiguous tracks
+    actual_edge_values = contigs.apply(
+        lambda x: get_edge_values(x, smoothed_tracks)
+    )
+    # get the predicted values
+    predicted_edge_values = contigs.apply(
+        lambda x: get_predicted_edge_values(x, smoothed_tracks, window)
+    )
+    # score predicted edge values: low values are best
+    prediction_scores = predicted_edge_values.apply(get_dMetric_wrap)
+    # find contiguous tracks to join for each trap
+    trap_contigs_to_join = []
+    for idx in contigs.index:
+        local_contigs = contigs.loc[idx]
+        # find indices of best left and right tracks to join
+        best_indices = find_best_from_scores_wrap(
+            prediction_scores.loc[idx], actual_edge_values.loc[idx], tol=tol
+        )
+        # find tracks from the indices
+        trap_contigs_to_join.append(
+            [
+                (contig[0][left], contig[1][right])
+                for best_index, contig in zip(best_indices, local_contigs)
+                for (left, right) in best_index
+                if best_index
+            ]
+        )
+    # return only the pairs of contiguous tracks
+    contigs_to_join = [
+        contigs
+        for trap_tracks in trap_contigs_to_join
+        for contigs in trap_tracks
+    ]
+    return contigs_to_join
+
+
+def get_joinable_original(
+    tracks, smooth=False, tol=0.1, window=5, degree=3
+) -> dict:
     """
     Get the pair of track (without repeats) that have a smaller error than the
     tolerance. If there is a track that can be assigned to two or more other
@@ -66,34 +155,41 @@
             lambda x: np.array(x.values), axis=1
         )
     # get the Signal values for neighbouring end points of contiguous tracks
-    actual_edges = contigs.apply(lambda x: get_edge_values(x, smoothed_tracks))
+    actual_edge_values = contigs.apply(
+        lambda x: get_edge_values(x, smoothed_tracks)
+    )
     # get the predicted values
-    predicted_edges = contigs.apply(
+    predicted_edge_values = contigs.apply(
         lambda x: get_predicted_edge_values(x, smoothed_tracks, window)
     )
-    # Prediction of pre and mean of post
-    prediction_costs = predicted_edges.apply(get_dMetric_wrap, tol=tol)
+    # score predicted edge values
+    prediction_scores = predicted_edge_values.apply(get_dMetric_wrap, tol=tol)
     solutions = [
-        solve_matrices_wrap(cost, edges, tol=tol)
-        for (trap_id, cost), edges in zip(
-            prediction_costs.items(), actual_edges
+        # for all sets of contigs at a trap
+        find_best_from_scores_wrap(cost, edge_values, tol=tol)
+        for (trap_id, cost), edge_values in zip(
+            prediction_scores.items(), actual_edge_values
         )
     ]
-    breakpoint()
     closest_pairs = pd.Series(
         solutions,
-        index=edges_dMetric_pred.index,
+        index=prediction_scores.index,
    )
     # match local with global ids
     joinable_ids = [
         localid_to_idx(closest_pairs.loc[i], contigs.loc[i])
         for i in closest_pairs.index
     ]
-    return [pair for pairset in joinable_ids for pair in pairset]
+
+    contigs_to_join = [
+        contigs for trap_tracks in joinable_ids for contigs in trap_tracks
+    ]
+    return contigs_to_join


 def load_test_dset():
-    """Load development dataset to test functions."""
+    """Load test data set."""
     return pd.DataFrame(
         {
             ("a", 1, 1): [2, 5, np.nan, 6, 8] + [np.nan] * 5,
@@ -345,49 +441,59 @@
     return [get_dMetric(*sublist, **kwargs) for sublist in lst]


-def solve_matrices_wrap(dMetric: List, edges: List, **kwargs):
+def find_best_from_scores_wrap(dMetric: List, edges: List, **kwargs):
     """Calculate solve_matrices on a list."""
     return [
-        solve_matrices(mat, edgeset, **kwargs)
+        find_best_from_scores(mat, edgeset, **kwargs)
         for mat, edgeset in zip(dMetric, edges)
     ]


-def get_dMetric(pre: List[float], post: List[float], tol):
+def get_dMetric(pre_values: List[float], post_values: List[float]):
     """
-    Calculate a cost matrix based on the difference between two Signal
+    Calculate a scoring matrix based on the difference between two Signal
     values.

+    We generate one score per pair of contiguous tracks.
+
+    Lower scores are better.
+
     Parameters
     ----------
-    pre: list of floats
+    pre_values: list of floats
         Values of the Signal for left contiguous tracks.
-    post: list of floats
+    post_values: list of floats
         Values of the Signal for right contiguous tracks.
     """
-    if len(pre) > len(post):
-        dMetric = np.abs(np.subtract.outer(post, pre))
+    if len(pre_values) > len(post_values):
+        dMetric = np.abs(np.subtract.outer(post_values, pre_values))
     else:
-        dMetric = np.abs(np.subtract.outer(pre, post))
-    # replace NaNs with maximal cost values
-    dMetric[np.isnan(dMetric)] = tol + 1 + np.nanmax(dMetric)
+        dMetric = np.abs(np.subtract.outer(pre_values, post_values))
+    # replace NaNs with maximal values
+    dMetric[np.isnan(dMetric)] = 1 + np.nanmax(dMetric)
     return dMetric


-def solve_matrices(cost: np.ndarray, edges: List, tol: Union[float, int] = 1):
-    """
-    Solve the distance matrices obtained in get_dMetric and/or merged from
-    independent dMetric matrices.
-    """
-    ids = solve_matrix(cost)
+def find_best_from_scores(
+    scores: np.ndarray, actual_edge_values: List, tol: Union[float, int] = 1
+):
+    """Find indices for left and right contiguous tracks with scores below a tolerance."""
+    ids = find_best_indices(scores)
     if len(ids[0]):
-        pre, post = edges
+        pre_value, post_value = actual_edge_values
+        # score with relative or absolute distance
         norm = (
-            np.array(pre)[ids[len(pre) > len(post)]] if tol < 1 else 1
-        )  # relative or absolute tol
-        result = dMetric[ids] / norm
-        ids = ids if len(pre) < len(post) else ids[::-1]
-        return [idx for idx, res in zip(zip(*ids), result) if res <= tol]
+            np.array(pre_value)[ids[len(pre_value) > len(post_value)]]
+            if tol < 1
+            else 1
+        )
+        best_scores = scores[ids] / norm
+        ids = ids if len(pre_value) < len(post_value) else ids[::-1]
+        # keep only indices with best_score less than the tolerance
+        indices = [
+            idx for idx, score in zip(zip(*ids), best_scores) if score <= tol
+        ]
+        return indices
     else:
         return []
@@ -409,18 +515,18 @@ def get_closest_pairs(
     """
     dMetric = get_dMetric(pre, post, tol)
-    return solve_matrices(dMetric, pre, post, tol)
+    return find_best_from_scores(dMetric, pre, post, tol)


-def solve_matrix(dMetric):
-    """Arrange indices to the cost matrix in order of increasing cost."""
+def find_best_indices(dMetric):
+    """Find indices for left and right contiguous tracks with minimal scores."""
     glob_is = []
     glob_js = []
     if (~np.isnan(dMetric)).any():
         lMetric = copy(dMetric)
         sortedMetric = sorted(lMetric[~np.isnan(lMetric)])
         while (~np.isnan(sortedMetric)).any():
-            # indices of point with minimal cost
+            # indices of point with the lowest score
             i_s, j_s = np.where(lMetric == sortedMetric[0])
             i = i_s[0]
             j = j_s[0]
@@ -432,7 +538,6 @@
             lMetric[:, j] += np.nan
             sortedMetric = sorted(lMetric[~np.isnan(lMetric)])
         indices = (np.array(glob_is), np.array(glob_js))
-        breakpoint()
     return indices
@@ -476,11 +581,6 @@ def get_contiguous_pairs(tracks: pd.DataFrame) -> list:
         tracks.notna().apply(np.where, axis=1).apply(fn)
         for fn in (np.min, np.max)
     ]
-    # mins.name = "min_tpt"
-    # maxs.name = "max_tpt"
-    # df = pd.merge(mins, maxs, right_index=True, left_index=True)
-    # df["duration"] = df.max_tpt - df.min_tpt
-    # # flip so that time points become the index
     mins_d = mins.groupby(mins).apply(lambda x: x.index.tolist())
     maxs_d = maxs.groupby(maxs).apply(lambda x: x.index.tolist())
--
GitLab