diff --git a/extraction/core/tracks.py b/extraction/core/tracks.py
index 71002b72bf096332b29507e40dd39f97b951149e..62e7bff4df9b3d00013347870631ac9e813f16d7 100644
--- a/extraction/core/tracks.py
+++ b/extraction/core/tracks.py
@@ -1,6 +1,6 @@
-'''
+"""
 Functions to process, filter and merge tracks.
-'''
+"""
 
 # from collections import Counter
 
@@ -10,69 +10,78 @@
 from typing import Union, List
 
 import numpy as np
 import pandas as pd
-from scipy.signal import savgol_filter
+from scipy.signal import savgol_filter
+
 # from scipy.optimize import linear_sum_assignment
 # from scipy.optimize import curve_fit
 
 from matplotlib import pyplot as plt
 
+
 def load_test_dset():
     # Load development dataset to test functions
-    return pd.DataFrame({('a',1,1):[2, 5, np.nan, 6,8] + [np.nan] * 5,
-                  ('a',1,2):list(range(2,12)),
-                  ('a',1,3):[np.nan] * 8 + [6,7],
-                  ('a',1,4):[np.nan] * 5 + [9,12,10,14,18]},
-                 index=range(1,11)).T
-
-def get_ntps(track:pd.Series) -> int:
+    return pd.DataFrame(
+        {
+            ("a", 1, 1): [2, 5, np.nan, 6, 8] + [np.nan] * 5,
+            ("a", 1, 2): list(range(2, 12)),
+            ("a", 1, 3): [np.nan] * 8 + [6, 7],
+            ("a", 1, 4): [np.nan] * 5 + [9, 12, 10, 14, 18],
+        },
+        index=range(1, 11),
+    ).T
+
+
+def get_ntps(track: pd.Series) -> int:
     # Get number of timepoints
     indices = np.where(track.notna())
     return np.max(indices) - np.min(indices)
 
 
-def get_tracks_ntps(tracks:pd.DataFrame) -> pd.Series:
+def get_tracks_ntps(tracks: pd.DataFrame) -> pd.Series:
     return tracks.apply(get_ntps, axis=1)
 
-def get_avg_gr(track:pd.Series) -> int:
-    '''
+
+def get_avg_gr(track: pd.Series) -> int:
+    """
     Get average growth rate for a track.
 
     :param tracks: Series with volume and timepoints as indices
-    '''
+    """
     ntps = get_ntps(track)
     vals = track.dropna().values
-    gr = (vals[-1] - vals[0] )/ ntps
+    gr = (vals[-1] - vals[0]) / ntps
     return gr
 
 
-def get_avg_grs(tracks:pd.DataFrame) -> pd.DataFrame:
-    '''
+def get_avg_grs(tracks: pd.DataFrame) -> pd.DataFrame:
+    """
     Get average growth rate for a group of tracks
 
     :param tracks: (m x n) dataframe where rows are cell tracks and
         columns are timepoints
-    '''
+    """
     return tracks.apply(get_avg_gr, axis=1)
 
 
-def clean_tracks(tracks, min_len:int=6, min_gr:float=0.5) -> pd.DataFrame:
-    '''
+def clean_tracks(tracks, min_len: int = 6, min_gr: float = 0.5) -> pd.DataFrame:
+    """
     Clean small non-growing tracks and return the reduced dataframe
 
     :param tracks: (m x n) dataframe where rows are cell tracks and
        columns are timepoints
    :param min_len: int number of timepoints cells must have not to be removed
    :param min_gr: float Minimum mean growth rate to assume an outline is growing
-    '''
+    """
    ntps = get_tracks_ntps(tracks)
    grs = get_avg_grs(tracks)

    growing_long_tracks = tracks.loc[(ntps >= min_len) & (grs > min_gr)]

-    return (growing_long_tracks)
+    return growing_long_tracks
+

 def merge_tracks(tracks, drop=False, **kwargs) -> pd.DataFrame:
-    '''
+    """
    Join tracks that are contiguous and within a volume
       threshold of each other

    :param tracks: (m x n) dataframe where rows are cell tracks and
@@ -84,7 +93,7 @@ def merge_tracks(tracks, drop=False, **kwargs) -> pd.DataFrame:
    :joint_tracks: (m x n) Dataframe where rows are cell tracks and
        columns are timepoints. Merged tracks are still present but filled
        with np.nans.
-    '''
+    """

    # calculate tracks that can be merged until no more traps can be merged
    joinable_pairs = get_joinable(tracks, **kwargs)
@@ -96,7 +105,7 @@ def merge_tracks(tracks, drop=False, **kwargs) -> pd.DataFrame:


 def get_joint_ids(merging_seqs) -> dict:
-    '''
+    """
    Convert a series of merges into a dictionary where
    the key is the cell_id of destination and the value a list
    of the other track ids that were merged into the key
@@ -119,7 +128,7 @@

    output {a:a, b:a, c:a, d:a}

-    '''
+    """
    targets, origins = list(zip(*merging_seqs))
    static_tracks = set(targets).difference(origins)

@@ -127,13 +136,17 @@
    for target, origin in merging_seqs:
        joint[origin] = target

-    moved_target = [k for k,v in joint.items() \
-                    if joint[v]!=v and v in joint.values()]
+    moved_target = [
+        k for k, v in joint.items() if joint[v] != v and v in joint.values()
+    ]

    for orig in moved_target:
        joint[orig] = rec_bottom(joint, orig)

-    return {k:v for k,v in joint.items() if k!=v} # remove ids that point to themselves
+    return {
+        k: v for k, v in joint.items() if k != v
+    }  # remove ids that point to themselves
+

 def rec_bottom(d, k):
    if d[k] == k:
@@ -141,8 +154,9 @@ def rec_bottom(d, k):
    else:
        return rec_bottom(d, d[k])

+
 def join_tracks(tracks, joinable_pairs, drop=False) -> pd.DataFrame:
-    '''
+    """
    Join pairs of tracks from later tps towards the start.

    :param tracks: (m x n) dataframe where rows are cell tracks and
@@ -155,8 +169,7 @@ def join_tracks(tracks, joinable_pairs, drop=False) -> pd.DataFrame:
        with np.nans.
    :param drop: bool indicating whether or not to drop moved rows

-    '''
-
+    """
    tmp = copy(tracks)

    for target, source in joinable_pairs:
@@ -165,7 +178,8 @@ def join_tracks(tracks, joinable_pairs, drop=False) -> pd.DataFrame:
        if drop:
            tmp = tmp.drop(source)

-    return (tmp)
+    return tmp
+

 def join_track_pairs(track1, track2):
    tmp = copy(track1)
@@ -173,8 +187,9 @@ def join_track_pairs(track1, track2):

    return tmp

+
 def get_joinable(tracks, smooth=False, tol=0.1, window=5, degree=3) -> dict:
-    '''
+    """
    Get the pair of track (without repeats) that have a smaller error than the
    tolerance. If there is a track that can be assigned to two or more other ones,
    it chooses the one with a lowest error.
@@ -187,39 +202,48 @@ def get_joinable(tracks, smooth=False, tol=0.1, window=5, degree=3) -> dict:

    :param window: int value of window used for savgol_filter
    :param degree: int value of polynomial degree passed to savgol_filter
-    '''
+    """

-    tracks.index.names = ['pos', 'trap', 'cell'] #TODO remove this once it is integrated in the tracker
+    tracks.index.names = [
+        "pos",
+        "trap",
+        "cell",
+    ]  # TODO remove this once it is integrated in the tracker
    # contig=tracks.groupby(['pos','trap']).apply(tracks2contig)
-    clean = clean_tracks(tracks, min_len=window+1, min_gr = 0.9) # get useful tracks
-    contig=clean.groupby(['pos','trap']).apply(get_contiguous_pairs)
+    clean = clean_tracks(tracks, min_len=window + 1, min_gr=0.9)  # get useful tracks
+    contig = clean.groupby(["pos", "trap"]).apply(get_contiguous_pairs)
    contig = contig.loc[contig.apply(len) > 0]
    # candict = {k:v for d in contig.values for k,v in d.items()}

    # smooth all relevant tracks
-    linear=set([k for v in contig.values for i in v for j in i for k in j])
-    if smooth: # Apply savgol filter TODO fix nans affecting edge placing
-        savgol_on_srs = lambda x: non_uniform_savgol(x.index, x.values,
-                                                     window, degree)
-        smoothed_tracks = clean.loc[linear].apply(savgol_on_srs,1)
+    linear = set([k for v in contig.values for i in v for j in i for k in j])
+    if smooth:  # Apply savgol filter TODO fix nans affecting edge placing
+        savgol_on_srs = lambda x: non_uniform_savgol(x.index, x.values, window, degree)
+        smoothed_tracks = clean.loc[linear].apply(savgol_on_srs, 1)
    else:
        smoothed_tracks = clean.loc[linear].apply(lambda x: np.array(x.values), axis=1)

    # fetch edges from ids TODO (IF necessary, here we can compare growth rates)
-    idx_to_edge = lambda preposts: [([get_val(smoothed_tracks.loc[pre],-1) for pre in pres],
-                                     [get_val(smoothed_tracks.loc[post],0) for post in posts])
-                                    for pres, posts in preposts]
+    idx_to_edge = lambda preposts: [
+        (
+            [get_val(smoothed_tracks.loc[pre], -1) for pre in pres],
+            [get_val(smoothed_tracks.loc[post], 0) for post in posts],
+        )
+        for pres, posts in preposts
+    ]
    edges = contig.apply(idx_to_edge)

    closest_pairs = edges.apply(get_vec_closest_pairs, tol=tol)

-    #match local with global ids
-    joinable_ids = [localid_to_idx(closest_pairs.loc[i], contig.loc[i])\
-                    for i in closest_pairs.index]
+    # match local with global ids
+    joinable_ids = [
+        localid_to_idx(closest_pairs.loc[i], contig.loc[i]) for i in closest_pairs.index
+    ]

    return [pair for pairset in joinable_ids for pair in pairset]

+
 get_val = lambda x, n: x[~np.isnan(x)][n] if len(x[~np.isnan(x)]) else np.nan
@@ -234,17 +258,18 @@
        list of pairs with (experiment-level) ids to be joint
    """
    lin_pairs = []
-    for i,pairs in enumerate(local_ids):
+    for i, pairs in enumerate(local_ids):
        if len(pairs):
-            for left,right in pairs:
-                lin_pairs.append((contig_trap[i][0][left],
-                                  contig_trap[i][1][right]))
+            for left, right in pairs:
+                lin_pairs.append((contig_trap[i][0][left], contig_trap[i][1][right]))
    return lin_pairs

-def get_vec_closest_pairs(lst:List, **kwags):
+
+def get_vec_closest_pairs(lst: List, **kwags):
    return [get_closest_pairs(*l, **kwags) for l in lst]

-def get_closest_pairs(pre:List[float], post:List[float], tol:Union[float,int]=1):
+
+def get_closest_pairs(pre: List[float], post: List[float], tol: Union[float, int] = 1):
    """Calculate a cost matrix the Hungarian algorithm to pick the best set of
    options

@@ -258,22 +283,25 @@ def get_closest_pairs(pre:List[float], post:List[float], tol:Union[float,int]=1

    """
    if len(pre) > len(post):
-        dMetric = np.abs(np.subtract.outer(post,pre))
+        dMetric = np.abs(np.subtract.outer(post, pre))
    else:
-        dMetric = np.abs(np.subtract.outer(pre,post))
+        dMetric = np.abs(np.subtract.outer(pre, post))
    # dMetric[np.isnan(dMetric)] = tol + 1 + np.nanmax(dMetric) # nans will be filtered
    # ids = linear_sum_assignment(dMetric)
-    dMetric[np.isnan(dMetric)] = tol + 1 + np.nanmax(dMetric) # nans will be filtered
+    dMetric[np.isnan(dMetric)] = tol + 1 + np.nanmax(dMetric)  # nans will be filtered
    ids = solve_matrix(dMetric)
    if not len(ids[0]):
        return []

-    norm = np.array(pre)[ids[len(pre)>len(post)]] if tol<1 else 1 # relative or absolute tol
-    result = dMetric[ids]/norm
-    ids = ids if len(pre)<len(post) else ids[::-1]
+    norm = (
+        np.array(pre)[ids[len(pre) > len(post)]] if tol < 1 else 1
+    )  # relative or absolute tol
+    result = dMetric[ids] / norm
+    ids = ids if len(pre) < len(post) else ids[::-1]
+
+    return [idx for idx, res in zip(zip(*ids), result) if res <= tol]

-    return [idx for idx,res in zip(zip(*ids), result) if res < tol]

 def solve_matrix(dMetric):
    """
@@ -288,70 +316,72 @@
    glob_is = []
    glob_js = []
    if (~np.isnan(dMetric)).any():
-        tmp = copy(dMetric )
+        tmp = copy(dMetric)
        std = sorted(tmp[~np.isnan(tmp)])
        while (~np.isnan(std)).any():
            v = std[0]
-            i_s,j_s = np.where(tmp==v)
+            i_s, j_s = np.where(tmp == v)
            i = i_s[0]
            j = j_s[0]
-            tmp[i,:]+= np.nan
-            tmp[:,j]+= np.nan
+            tmp[i, :] += np.nan
+            tmp[:, j] += np.nan
            glob_is.append(i)
            glob_js.append(j)

            std = sorted(tmp[~np.isnan(tmp)])

+    return (np.array(glob_is), np.array(glob_js))

-    return (np.array( glob_is ), np.array( glob_js ))

 def plot_joinable(tracks, joinable_pairs, max=64):
    """
    Convenience plotting function for debugging and data vis
    """

-    nx=8
-    ny=8
-    _, axes = plt.subplots(nx,ny)
+    nx = 8
+    ny = 8
+    _, axes = plt.subplots(nx, ny)
    for i in range(nx):
        for j in range(ny):
-            if i*ny+j < len(joinable_pairs):
+            if i * ny + j < len(joinable_pairs):
                ax = axes[i, j]
                pre, post = joinable_pairs[i * ny + j]
                pre_srs = tracks.loc[pre].dropna()
                post_srs = tracks.loc[post].dropna()
-                ax.plot(pre_srs.index, pre_srs.values , 'b')
+                ax.plot(pre_srs.index, pre_srs.values, "b")
                # try:
                #     totrange = np.arange(pre_srs.index[0],post_srs.index[-1])
                #     ax.plot(totrange, interpolate(pre_srs, totrange), 'r-')
                # except:
                #     pass
-                ax.plot(post_srs.index, post_srs.values, 'g')
+                ax.plot(post_srs.index, post_srs.values, "g")

    plt.show()

+
 def get_contiguous_pairs(tracks: pd.DataFrame) -> list:
-    '''
+    """
    Get all pair of contiguous track ids from a tracks dataframe.

    :param tracks: (m x n) dataframe where rows are cell tracks and
        columns are timepoints
    :param min_dgr: float minimum difference in growth rate from the interpolation
-    '''
+    """
    # indices = np.where(tracks.notna())

-
-    mins, maxes = [tracks.notna().apply(np.where, axis=1).apply(fn)
-                   for fn in (np.min, np.max)]
+    mins, maxes = [
+        tracks.notna().apply(np.where, axis=1).apply(fn) for fn in (np.min, np.max)
+    ]

    mins_d = mins.groupby(mins).apply(lambda x: x.index.tolist())
-    mins_d.index = mins_d.index - 1 # make indices equal
+    mins_d.index = mins_d.index - 1  # make indices equal
    maxes_d = maxes.groupby(maxes).apply(lambda x: x.index.tolist())

    common = sorted(set(mins_d.index).intersection(maxes_d.index), reverse=True)

    return [(maxes_d[t], mins_d[t]) for t in common]

+
 # def fit_track(track: pd.Series, obj=None):
 #     if obj is None:
 #         obj = objective
@@ -416,7 +446,7 @@
        raise ValueError('"x" and "y" must be of the same size')

    if len(x) < window:
-        raise ValueError('The data size must be larger than the window size')
+        raise ValueError("The data size must be larger than the window size")

    if type(window) is not int:
        raise TypeError('"window" must be an integer')
@@ -434,9 +464,9 @@
    polynom += 1

    # Initialize variables
-    A = np.empty((window, polynom))     # Matrix
-    tA = np.empty((polynom, window))    # Transposed matrix
-    t = np.empty(window)                # Local x variables
+    A = np.empty((window, polynom))  # Matrix
+    tA = np.empty((polynom, window))  # Transposed matrix
+    t = np.empty(window)  # Local x variables
    y_smoothed = np.full(len(y), np.nan)

    # Start smoothing
diff --git a/tests/extraction/__pycache__/log_test.cpython-37-pytest-6.2.5.pyc b/tests/extraction/__pycache__/log_test.cpython-37-pytest-6.2.5.pyc
deleted file mode 100644
index 6d598779a3e3d03ea6335aa26a3c1b73d713e53e..0000000000000000000000000000000000000000
Binary files a/tests/extraction/__pycache__/log_test.cpython-37-pytest-6.2.5.pyc and /dev/null differ
diff --git a/tests/extraction/test_tracks.py b/tests/extraction/test_tracks.py
index 40976fcca875972cf01d5018b0ce64205ce13c46..c51fd9e2394a376f3b4530848f7124284d1948d4 100644
--- a/tests/extraction/test_tracks.py
+++ b/tests/extraction/test_tracks.py
@@ -1,6 +1,6 @@
-
 from extraction.core.tracks import load_test_dset, clean_tracks, merge_tracks

+
 def test_clean_tracks():
    tracks = load_test_dset()
    clean = clean_tracks(tracks, min_len=3)
@@ -8,24 +8,26 @@ def test_clean_tracks():
    assert len(clean) < len(tracks)
    pass

+
 def test_merge_tracks_drop():
    tracks = load_test_dset()
-    joint_tracks,joint_ids = merge_tracks(tracks,window=3, degree=2, drop=True)
+    joint_tracks = merge_tracks(tracks, window=3, degree=2, drop=True, tol=1)

-    assert len(joint_tracks)<len(tracks), 'Error when merging'
-
-    assert len(joint_ids), 'No joint ids found'
+    assert len(joint_tracks) < len(tracks), "Error when merging"

    pass

+
 def test_merge_tracks_nodrop():
    tracks = load_test_dset()
-    joint_tracks,joint_ids = merge_tracks(tracks,window=3, degree=2, drop=False)
+    joint_tracks, joint_ids = merge_tracks(
+        tracks, window=3, degree=2, drop=False, tol=1
+    )

-    assert len(joint_tracks)==len(tracks), 'Error when merging'
+    assert len(joint_tracks) == len(tracks), "Error when merging"

-    assert len(joint_ids), 'No joint ids found'
+    assert len(joint_ids), "No joint ids found"

    pass
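Reviewer note, not part of the patch: a minimal sketch of how the reformatted API is exercised end to end, mirroring the updated tests (the window, degree and tol values are illustrative only):

    from extraction.core.tracks import load_test_dset, clean_tracks, merge_tracks

    tracks = load_test_dset()

    # Drop short or non-growing tracks before looking for merge candidates.
    clean = clean_tracks(tracks, min_len=3, min_gr=0.5)

    # With drop=False the merged rows stay in the dataframe (filled with np.nan)
    # and the joined id pairs are returned alongside it; with drop=True only the
    # reduced dataframe is returned.
    joint_tracks, joint_ids = merge_tracks(tracks, window=3, degree=2, drop=False, tol=1)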