From ab99d9fb803982e36c4072a054dcd31c3e21332a Mon Sep 17 00:00:00 2001
From: Swainlab <peter.swain@ed.ac.uk>
Date: Wed, 26 Jul 2023 18:14:33 +0100
Subject: [PATCH] unfinished docs on tracks

---
 src/postprocessor/core/functions/tracks.py | 421 ++++++++++-----------
 1 file changed, 197 insertions(+), 224 deletions(-)

diff --git a/src/postprocessor/core/functions/tracks.py b/src/postprocessor/core/functions/tracks.py
index d49a4cfb..8deecb39 100644
--- a/src/postprocessor/core/functions/tracks.py
+++ b/src/postprocessor/core/functions/tracks.py
@@ -1,8 +1,13 @@
 """
-Functions to process, filter and merge tracks.
-"""
+Functions to process, filter, and merge tracks.
+
+We call two tracks contiguous if they are adjacent in time: the
+maximal time point of one is one time point less than the
+minimal time point of the other.
 
-# from collections import Counter
+A right track can have multiple potential left tracks. We must
+pick the best.
+"""
 
 import typing as t
 from copy import copy
@@ -17,6 +22,76 @@ from utils_find_1st import cmp_larger, find_1st
 from postprocessor.core.processes.savgol import non_uniform_savgol
 
 
+def get_joinable(tracks, smooth=False, tol=0.1, window=5, degree=3) -> list:
+    """
+    Get pairs of tracks (without repeats) whose error is smaller than
+    the tolerance. If a track can be assigned to two or more other
+    tracks, choose the one with the lowest error.
+
+    Parameters
+    ----------
+    tracks: (m x n) Signal
+        A Signal, usually area, dataframe where rows are cell tracks and
+        columns are time points.
+    smooth: boolean
+        If True, discard short tracks and smooth the remainder with a
+        Savitzky-Golay filter before comparing end points.
+    tol: float or int
+        Threshold of the average prediction error needed to consider
+        two tracks the same. If a float, the threshold is a fraction of
+        the first track; if an int, it is in absolute units.
+    window: int
+        Window size used for the Savitzky-Golay filter.
+    degree: int
+        Polynomial degree passed to the Savitzky-Golay filter.
+    """
+    # only consider time series with more than two non-NaN data points
+    tracks = tracks.loc[tracks.notna().sum(axis=1) > 2]
+    # get contiguous tracks
+    if smooth:
+        # specialise to tracks with growing cells and of long duration
+        clean = clean_tracks(tracks, min_duration=window + 1, min_gr=0.9)
+        contigs = clean.groupby(["trap"]).apply(get_contiguous_pairs)
+    else:
+        contigs = tracks.groupby(["trap"]).apply(get_contiguous_pairs)
+    # remove traps with no contiguous tracks
+    contigs = contigs.loc[contigs.apply(len) > 0]
+    # flatten to (trap, cell_id) pairs
+    flat = set([k for v in contigs.values for i in v for j in i for k in j])
+    # make a data frame of contiguous tracks with the tracks as arrays
+    if smooth:
+        smoothed_tracks = clean.loc[flat].apply(
+            lambda x: non_uniform_savgol(x.index, x.values, window, degree),
+            axis=1,
+        )
+    else:
+        smoothed_tracks = tracks.loc[flat].apply(
+            lambda x: np.array(x.values), axis=1
+        )
+    # get the Signal values for neighbouring end points of contiguous tracks
+    actual_edges = contigs.apply(lambda x: get_edge_values(x, smoothed_tracks))
+    # get the predicted values
+    predicted_edges = contigs.apply(
+        lambda x: get_predicted_edge_values(x, smoothed_tracks, window)
+    )
+    # costs between the prediction for each pre track and the mean of
+    # each post track
+    prediction_costs = predicted_edges.apply(get_dMetric_wrap, tol=tol)
+    solutions = [
+        solve_matrices_wrap(cost, edges, tol=tol)
+        for (trap_id, cost), edges in zip(
+            prediction_costs.items(), actual_edges
+        )
+    ]
+    closest_pairs = pd.Series(
+        solutions,
+        index=prediction_costs.index,
+    )
+    # match local with global ids
+    joinable_ids = [
+        localid_to_idx(closest_pairs.loc[i], contigs.loc[i])
+        for i in closest_pairs.index
+    ]
+    return [pair for pairset in joinable_ids for pair in pairset]
+
+
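+# An illustrative sketch of a typical call (hypothetical variable names;
+# tracks is assumed to be a Signal-like DataFrame indexed by trap and
+# cell label, with time points as columns):
+#
+#     joinable = get_joinable(tracks, smooth=False, tol=0.1)
+#     # joinable should be a flat list of (pre track id, post track id)
+#     # pairs that can then be merged into single tracks
+
+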
 def load_test_dset():
     """Load development dataset to test functions."""
     return pd.DataFrame(
@@ -45,46 +120,21 @@ def max_nonstop_ntps(track: pd.Series) -> int:
     return max(consecutive_nonas_grouped)
 
 
-def get_tracks_ntps(tracks: pd.DataFrame) -> pd.Series:
-    return tracks.apply(max_ntps, axis=1)
-
-
-def get_avg_gr(track: pd.Series) -> int:
-    """
-    Get average growth rate for a track.
-
-    :param tracks: Series with volume and timepoints as indices
-    """
+def get_avg_gr(track: pd.Series) -> float:
+    """Get average growth rate for a track."""
     ntps = max_ntps(track)
     vals = track.dropna().values
     gr = (vals[-1] - vals[0]) / ntps
     return gr
 
 
-def get_avg_grs(tracks: pd.DataFrame) -> pd.DataFrame:
-    """
-    Get average growth rate for a group of tracks
-
-    :param tracks: (m x n) dataframe where rows are cell tracks and
-    columns are timepoints
-    """
-    return tracks.apply(get_avg_gr, axis=1)
-
-
 def clean_tracks(
-    tracks, min_len: int = 15, min_gr: float = 1.0
+    tracks, min_duration: int = 15, min_gr: float = 1.0
 ) -> pd.DataFrame:
-    """
-    Clean small non-growing tracks and return the reduced dataframe
-
-    :param tracks: (m x n) dataframe where rows are cell tracks and
-    columns are timepoints
-    :param min_len: int number of timepoints cells must have not to be removed
-    :param min_gr: float Minimum mean growth rate to assume an outline is growing
-    """
-    ntps = get_tracks_ntps(tracks)
-    grs = get_avg_grs(tracks)
-    growing_long_tracks = tracks.loc[(ntps >= min_len) & (grs > min_gr)]
+    """Remove short non-growing tracks and return the reduced data frame."""
+    ntps = tracks.apply(max_ntps, axis=1)
+    grs = tracks.apply(get_avg_gr, axis=1)
+    growing_long_tracks = tracks.loc[(ntps >= min_duration) & (grs > min_gr)]
     return growing_long_tracks
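+
+
+# Worked example with made-up numbers: for a track whose non-NaN values
+# are [10.0, 12.0, 16.0] and whose max_ntps is 3, get_avg_gr returns
+# (16.0 - 10.0) / 3 = 2.0; clean_tracks keeps a track only if it spans
+# at least min_duration time points and its growth rate exceeds min_gr.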
@@ -191,125 +241,78 @@ def join_track_pair(target, source):
     return tgt_copy
 
 
-def get_joinable(tracks, smooth=False, tol=0.1, window=5, degree=3) -> dict:
-    """
-    Get the pair of track (without repeats) that have a smaller error than the
-    tolerance. If there is a track that can be assigned to two or more other
-    ones, choose the one with lowest error.
+def get_edge_values(contigs_ids, smoothed_tracks):
+    """Get Signal values at the adjacent end points of contiguous tracks."""
+    values = [
+        (
+            [get_value(smoothed_tracks.loc[pre_id], -1) for pre_id in pre_ids],
+            [
+                get_value(smoothed_tracks.loc[post_id], 0)
+                for post_id in post_ids
+            ],
+        )
+        for pre_ids, post_ids in contigs_ids
+    ]
+    return values
 
-    Parameters
-    ----------
-    tracks: (m x n) Signal
-        A Signal, usually area, dataframe where rows are cell tracks and
-        columns are time points.
-    tol: float or int
-        threshold of average (prediction error/std) necessary
-        to consider two tracks the same. If float is fraction of first track,
-        if int it is absolute units.
-    window: int
-        value of window used for savgol_filter
-    degree: int
-        value of polynomial degree passed to savgol_filter
-    """
-    # only consider time series with more than two non-NaN data points
-    tracks = tracks.loc[tracks.notna().sum(axis=1) > 2]
-    # smooth all relevant tracks
-    if smooth:
-        # Apply savgol filter TODO fix nans affecting edge placing
-        clean = clean_tracks(
-            tracks, min_len=window + 1, min_gr=0.9
-        )  # get useful tracks
-
-        def savgol_on_srs(x):
-            return non_uniform_savgol(x.index, x.values, window, degree)
-
-        contig = clean.groupby(["trap"]).apply(get_contiguous_pairs)
-        contig = contig.loc[contig.apply(len) > 0]
-        flat = set([k for v in contig.values for i in v for j in i for k in j])
-        smoothed_tracks = clean.loc[flat].apply(savgol_on_srs, 1)
-    else:
-        contig = tracks.groupby(["trap"]).apply(get_contiguous_pairs)
-        contig = contig.loc[contig.apply(len) > 0]
-        flat = set([k for v in contig.values for i in v for j in i for k in j])
-        smoothed_tracks = tracks.loc[flat].apply(
-            lambda x: np.array(x.values), axis=1
-        )
 
+def get_predicted_edge_values(contigs_ids, smoothed_tracks, window):
+    """
+    Estimate neighbouring values for pairs of contiguous tracks.
 
-    # fetch edges from ids TODO (IF necessary, here we can compare growth rates)
-    def idx_to_edge(preposts):
-        return [
-            (
-                [get_val(smoothed_tracks.loc[pre], -1) for pre in pres],
-                [get_val(smoothed_tracks.loc[post], 0) for post in posts],
+    Predict the next value for each left track using a linear fit to its
+    last window values, and find the mean of the first window values of
+    each right track.
+    """
+    result = []
+    for pre_ids, post_ids in contigs_ids:
+        pre_res = []
+        # left contiguous tracks
+        for pre_id in pre_ids:
+            # get last window values of a track
+            y = get_values_i(smoothed_tracks.loc[pre_id], -window)
+            # predict the next value
+            pre_res.append(
+                np.poly1d(np.polyfit(range(len(y)), y, 1))(len(y) + 1),
             )
-            for pres, posts in preposts
+        # right contiguous tracks
+        pos_res = [
+            # mean of the initial window values of a track
+            get_mean_value_i(smoothed_tracks.loc[post_id], window)
+            for post_id in post_ids
         ]
+        result.append([pre_res, pos_res])
+    return result
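+
+
+# Example of the linear extrapolation used above (made-up numbers):
+#
+#     y = np.array([2.0, 4.0, 6.0, 8.0, 10.0])  # last window values
+#     np.poly1d(np.polyfit(range(len(y)), y, 1))(len(y) + 1)  # -> 14.0
+#
+# The fitted line has slope 2, so the prediction extrapolates the track
+# two time points beyond its final value.
+
+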
-    def idx_to_pred(preposts):
-        result = []
-        for pres, posts in preposts:
-            pre_res = []
-            for pre in pres:
-                y = get_last_i(smoothed_tracks.loc[pre], -window)
-                pre_res.append(
-                    np.poly1d(np.polyfit(range(len(y)), y, 1))(len(y) + 1),
-                )
-            pos_res = [
-                get_means(smoothed_tracks.loc[post], window) for post in posts
-            ]
-            result.append([pre_res, pos_res])
-
-        return result
-
-    edges = contig.apply(idx_to_edge)  # Raw edges
-    # edges_mean = contig.apply(idx_to_means)  # Mean of both
-    pre_pred = contig.apply(idx_to_pred)  # Prediction of pre and mean of post
-
-    # edges_dMetric = edges.apply(get_dMetric_wrap, tol=tol)
-    # edges_dMetric_mean = edges_mean.apply(get_dMetric_wrap, tol=tol)
-    edges_dMetric_pred = pre_pred.apply(get_dMetric_wrap, tol=tol)
-    solutions = []
-    # for (i, dMetrics), edgeset in zip(combined_dMetric.items(), edges):
-    for (i, dMetrics), edgeset in zip(edges_dMetric_pred.items(), edges):
-        solutions.append(solve_matrices_wrap(dMetrics, edgeset, tol=tol))
-
-    closest_pairs = pd.Series(
-        solutions,
-        index=edges_dMetric_pred.index,
-    )
-
-    # match local with global ids
-    joinable_ids = [
-        localid_to_idx(closest_pairs.loc[i], contig.loc[i])
-        for i in closest_pairs.index
-    ]
-
-    return [pair for pairset in joinable_ids for pair in pairset]
+def get_value(x, n):
+    """Get the nth non-NaN value from an array."""
+    val = x[~np.isnan(x)][n] if len(x[~np.isnan(x)]) else np.nan
+    return val
 
 
-def get_val(x, n):
-    return x[~np.isnan(x)][n] if len(x[~np.isnan(x)]) else np.nan
-
-
-def get_means(x, i):
+def get_mean_value_i(x, i):
+    """Get the mean of a track's non-NaN values up to or from an index."""
     if not len(x[~np.isnan(x)]):
         return np.nan
-    if i > 0:
-        v = x[~np.isnan(x)][:i]
     else:
-        v = x[~np.isnan(x)][i:]
-    return np.nanmean(v)
+        if i > 0:
+            v = x[~np.isnan(x)][:i]
+        else:
+            v = x[~np.isnan(x)][i:]
+        return np.nanmean(v)
 
 
-def get_last_i(x, i):
+def get_values_i(x, i):
+    """Get a track's non-NaN values up to or from an index."""
     if not len(x[~np.isnan(x)]):
         return np.nan
-    if i > 0:
-        v = x[~np.isnan(x)][:i]
     else:
-        v = x[~np.isnan(x)][i:]
-    return v
+        if i > 0:
+            v = x[~np.isnan(x)][:i]
+        else:
+            v = x[~np.isnan(x)][i:]
+        return v
 
 
 def localid_to_idx(local_ids, contig_trap):
@@ -338,57 +341,55 @@ def get_vec_closest_pairs(lst: List, **kwargs):
 
 
 def get_dMetric_wrap(lst: List, **kwargs):
+    """Apply get_dMetric to each pair of edge-value lists."""
     return [get_dMetric(*sublist, **kwargs) for sublist in lst]
 
 
 def solve_matrices_wrap(dMetric: List, edges: List, **kwargs):
+    """Apply solve_matrices to each cost matrix and its edge values."""
     return [
         solve_matrices(mat, edgeset, **kwargs)
         for mat, edgeset in zip(dMetric, edges)
     ]
 
 
-def get_dMetric(
-    pre: List[float], post: List[float], tol: Union[float, int] = 1
-):
-    """Calculate a cost matrix
-
-    input
-    :param pre: list of floats with edges on left
-    :param post: list of floats with edges on right
-    :param tol: int or float if int metrics of tolerance, if float fraction
-
-    returns
-    :: list of indices corresponding to the best solutions for matrices
+def get_dMetric(pre: List[float], post: List[float], tol):
+    """
+    Calculate a cost matrix based on the difference between two sets of
+    Signal values.
+
+    Parameters
+    ----------
+    pre: list of floats
+        Values of the Signal for left contiguous tracks.
+    post: list of floats
+        Values of the Signal for right contiguous tracks.
+    tol: float or int
+        Tolerance used to set the cost of NaN values.
     """
     if len(pre) > len(post):
         dMetric = np.abs(np.subtract.outer(post, pre))
     else:
         dMetric = np.abs(np.subtract.outer(pre, post))
-    dMetric[np.isnan(dMetric)] = (
-        tol + 1 + np.nanmax(dMetric)
-    )  # nans will be filtered
+    # replace NaNs with maximal cost values
+    dMetric[np.isnan(dMetric)] = tol + 1 + np.nanmax(dMetric)
     return dMetric
 
 
-def solve_matrices(
-    dMetric: np.ndarray, prepost: List, tol: Union[float, int] = 1
-):
+def solve_matrices(cost: np.ndarray, edges: List, tol: Union[float, int] = 1):
     """
     Solve the distance matrices obtained in get_dMetric and/or merged
     from independent dMetric matrices.
     """
-    ids = solve_matrix(dMetric)
-    if not len(ids[0]):
+    ids = solve_matrix(cost)
+    if len(ids[0]):
+        pre, post = edges
+        norm = (
+            np.array(pre)[ids[len(pre) > len(post)]] if tol < 1 else 1
+        )  # relative or absolute tol
+        result = cost[ids] / norm
+        ids = ids if len(pre) < len(post) else ids[::-1]
+        return [idx for idx, res in zip(zip(*ids), result) if res <= tol]
+    else:
         return []
-    pre, post = prepost
-    norm = (
-        np.array(pre)[ids[len(pre) > len(post)]] if tol < 1 else 1
-    )  # relative or absolute tol
-    result = dMetric[ids] / norm
-    ids = ids if len(pre) < len(post) else ids[::-1]
-    return [idx for idx, res in zip(zip(*ids), result) if res <= tol]
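+
+
+# Example of the greedy matching performed by solve_matrix below
+# (made-up costs):
+#
+#     cost = np.array([[0.1, 0.5], [0.4, 0.2]])
+#     solve_matrix(cost)  # -> (array([0, 1]), array([0, 1]))
+#
+# The cheapest entry (0, 0) is picked first; its row and column are then
+# masked with NaNs, and (1, 1) is the only choice left.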
 
 
 def get_closest_pairs(
@@ -412,37 +413,31 @@ def get_closest_pairs(
 
 
 def solve_matrix(dMetric):
-    """
-    Solve cost matrix focusing on getting the smallest cost at each iteration.
-
-    input
-    :param dMetric: np.array cost matrix
-
-    returns
-    tuple of np.arrays indicating picks with lowest individual value
-    """
+    """Pick indices from the cost matrix in order of increasing cost."""
     glob_is = []
     glob_js = []
     if (~np.isnan(dMetric)).any():
-        tmp = copy(dMetric)
-        std = sorted(tmp[~np.isnan(tmp)])
-        while (~np.isnan(std)).any():
-            v = std[0]
-            i_s, j_s = np.where(tmp == v)
+        lMetric = copy(dMetric)
+        sortedMetric = sorted(lMetric[~np.isnan(lMetric)])
+        while (~np.isnan(sortedMetric)).any():
+            # find indices of the point with minimal cost
+            i_s, j_s = np.where(lMetric == sortedMetric[0])
             i = i_s[0]
             j = j_s[0]
-            tmp[i, :] += np.nan
-            tmp[:, j] += np.nan
+            # store this point
             glob_is.append(i)
             glob_js.append(j)
-            std = sorted(tmp[~np.isnan(tmp)])
-    return (np.array(glob_is), np.array(glob_js))
+            # mask the point's row and column with NaNs
+            lMetric[i, :] += np.nan
+            lMetric[:, j] += np.nan
+            sortedMetric = sorted(lMetric[~np.isnan(lMetric)])
+    indices = (np.array(glob_is), np.array(glob_js))
+    return indices
 
 
 def plot_joinable(tracks, joinable_pairs):
-    """
-    Convenience plotting function for debugging and data vis
-    """
+    """Convenience plotting function for debugging."""
     nx = 8
     ny = 8
     _, axes = plt.subplots(nx, ny)
@@ -467,53 +462,31 @@ def get_contiguous_pairs(tracks: pd.DataFrame) -> list:
     """
     Get all pair of contiguous track ids from a tracks data frame.
 
+    For two tracks to be contiguous, one must end exactly one time
+    point before the other begins.
+
     Parameters
     ----------
     tracks: pd.Dataframe
-        A dataframe for one trap where rows are cell tracks and columns
-        are time points.
+        A dataframe where rows are cell tracks and columns are time
+        points.
     """
-    # time points bounding a tracklet of non-NaN values
+    # TODO add support for skipping time points
+    # find time points bounding tracks of non-NaN values
     mins, maxs = [
         tracks.notna().apply(np.where, axis=1).apply(fn)
        for fn in (np.min, np.max)
     ]
+    # flip so that time points become the index
     mins_d = mins.groupby(mins).apply(lambda x: x.index.tolist())
-    mins_d.index = mins_d.index - 1  # make indices equal
-    # TODO add support for skipping time points
     maxs_d = maxs.groupby(maxs).apply(lambda x: x.index.tolist())
+    # subtract one so that a right track's start aligns with a left
+    # track's end
+    mins_d.index = mins_d.index - 1
+    # find common end points
     common = sorted(set(mins_d.index).intersection(maxs_d.index), reverse=True)
-    return [(maxs_d[t], mins_d[t]) for t in common]
-
-
-# def fit_track(track: pd.Series, obj=None):
-#     if obj is None:
-#         obj = objective
-
-#     x = track.dropna().index
-#     y = track.dropna().values
-#     popt, _ = curve_fit(obj, x, y)

-#     return popt
-
-# def interpolate(track, xs) -> list:
-#     '''
-#     Interpolate next timepoint from a track
-
-#     :param track: pd.Series of volume growth over a time period
-#     :param t: int timepoint to interpolate
-#     '''
-#     popt = fit_track(track)
-#     # perr = np.sqrt(np.diag(pcov))
-#     return objective(np.array(xs), *popt)
-
-
-# def objective(x,a,b,c,d) -> float:
-#     # return (a)/(1+b*np.exp(c*x))+d
-#     return (((x+d)*a)/((x+d)+b))+c
-
-# def cand_pairs_to_dict(candidates):
-#     d={x:[] for x,_ in candidates}
-#     for x,y in candidates:
-#         d[x].append(y)
-#     return d
+    contigs = [(maxs_d[t], mins_d[t]) for t in common]
+    return contigs
-- 
GitLab