diff --git a/src/agora/io/signal.py b/src/agora/io/signal.py index 8a3957c35fcf81744abdceaab91c7ad556a7d4b9..876424d12e038707156454381eff7808c1cacc66 100644 --- a/src/agora/io/signal.py +++ b/src/agora/io/signal.py @@ -11,7 +11,7 @@ import pandas as pd from agora.io.bridge import BridgeH5 from agora.io.decorators import _first_arg_str_to_raw_df -from agora.utils.indexing import validate_association +from agora.utils.indexing import validate_lineage from agora.utils.kymograph import add_index_levels from agora.utils.merge import apply_merges @@ -295,13 +295,13 @@ class Signal(BridgeH5): # assume that df is sorted mother_label = np.zeros(len(df), dtype=int) lineage = self.lineage() - valid_association, valid_indices = validate_association( + # information on buds + valid_lineage, valid_indices = validate_lineage( lineage, np.array(df.index.to_list()), - #: are mothers not match_column=0? - match_column=1, + "daughters", ) - mother_label[valid_indices] = lineage[valid_association, 1] + mother_label[valid_indices] = lineage[valid_lineage, 1] df = add_index_levels(df, {"mother_label": mother_label}) return df except Exception as e: diff --git a/src/agora/utils/indexing.py b/src/agora/utils/indexing.py index a36f7300bdd71971e3ca201b8cf59e85f579fa7f..9a07a6d52faa9e5c5a577b1e2776456a39174e7a 100644 --- a/src/agora/utils/indexing.py +++ b/src/agora/utils/indexing.py @@ -14,7 +14,11 @@ def validate_lineage( lineage: np.ndarray, indices: np.ndarray, how: str = "families" ): """ - Identify mother-bud pairs that exist both in lineage and a Signal's indices. + Identify mother-bud pairs that exist both in lineage and a Signal's + indices. + + We expect the lineage information to be unique: a bud should not have + two mothers. Parameters ---------- @@ -71,28 +75,33 @@ def validate_lineage( c_index = 0 elif how == "daughters": c_index = 1 - # dtype links together trap and cell ids + # data type to link together trap and cell ids dtype = {"names": ["trap_id", "cell_id"], "formats": [np.int64, np.int64]} lineage = np.ascontiguousarray(lineage, dtype=np.int64) # find (trap, cell_ids) in intersection inboth = np.intersect1d(lineage.view(dtype), indices.view(dtype)) # find valid lineage - valid_lineage = np.isin(lineage.view(dtype), inboth) + valid_lineages = np.isin(lineage.view(dtype), inboth) if how == "families": # both mother and bud must be in indices - valid_lineage = valid_lineage.all(axis=1) + valid_lineage = valid_lineages.all(axis=1) else: - valid_lineage = valid_lineage[:, c_index, :] + valid_lineage = valid_lineages[:, c_index, :] # find valid indices - possible_indices = lineage[valid_lineage.flatten(), ...] + selected_lineages = lineage[valid_lineage.flatten(), ...] if how == "families": # select only pairs of mother and bud indices valid_indices = np.isin( - indices.view(dtype), possible_indices.view(dtype) + indices.view(dtype), selected_lineages.view(dtype) ) else: valid_indices = np.isin( - indices.view(dtype), possible_indices.view(dtype)[:, c_index, :] + indices.view(dtype), selected_lineages.view(dtype)[:, c_index, :] + ) + if valid_indices[valid_indices].size != valid_lineage[valid_lineage].size: + raise Exception( + "Error in validate_lineage: " + "lineage information is likely not unique." ) return valid_lineage.flatten(), valid_indices.flatten() diff --git a/src/postprocessor/core/functions/tracks.py b/src/postprocessor/core/functions/tracks.py index 5fa04a1506996909f1cccb18ddb3e9c80f6eb0c3..d49a4cfbbfea06f1554aaae290f54bf358478484 100644 --- a/src/postprocessor/core/functions/tracks.py +++ b/src/postprocessor/core/functions/tracks.py @@ -197,23 +197,26 @@ def get_joinable(tracks, smooth=False, tol=0.1, window=5, degree=3) -> dict: tolerance. If there is a track that can be assigned to two or more other ones, choose the one with lowest error. - :param tracks: (m x n) dataframe where rows are cell tracks and - columns are timepoints - :param tol: float or int threshold of average (prediction error/std) necessary + Parameters + ---------- + tracks: (m x n) Signal + A Signal, usually area, dataframe where rows are cell tracks and + columns are time points. + tol: float or int + threshold of average (prediction error/std) necessary to consider two tracks the same. If float is fraction of first track, if int it is absolute units. - :param window: int value of window used for savgol_filter - :param degree: int value of polynomial degree passed to savgol_filter - + window: int + value of window used for savgol_filter + degree: int + value of polynomial degree passed to savgol_filter """ + # only consider time series with more than two non-NaN data points tracks = tracks.loc[tracks.notna().sum(axis=1) > 2] - # Commented because we are not smoothing in this step yet - # candict = {k:v for d in contig.values for k,v in d.items()} - # smooth all relevant tracks - - if smooth: # Apply savgol filter TODO fix nans affecting edge placing + if smooth: + # Apply savgol filter TODO fix nans affecting edge placing clean = clean_tracks( tracks, min_len=window + 1, min_gr=0.9 ) # get useful tracks @@ -243,14 +246,6 @@ def get_joinable(tracks, smooth=False, tol=0.1, window=5, degree=3) -> dict: for pres, posts in preposts ] - # idx_to_means = lambda preposts: [ - # ( - # [get_means(smoothed_tracks.loc[pre], -window) for pre in pres], - # [get_means(smoothed_tracks.loc[post], window) for post in posts], - # ) - # for pres, posts in preposts - # ] - def idx_to_pred(preposts): result = [] for pres, posts in preposts: @@ -274,15 +269,6 @@ def get_joinable(tracks, smooth=False, tol=0.1, window=5, degree=3) -> dict: # edges_dMetric = edges.apply(get_dMetric_wrap, tol=tol) # edges_dMetric_mean = edges_mean.apply(get_dMetric_wrap, tol=tol) edges_dMetric_pred = pre_pred.apply(get_dMetric_wrap, tol=tol) - - # combined_dMetric = pd.Series( - # [ - # [np.nanmin((a, b), axis=0) for a, b in zip(x, y)] - # for x, y in zip(edges_dMetric, edges_dMetric_mean) - # ], - # index=edges_dMetric.index, - # ) - # closest_pairs = combined_dMetric.apply(get_vec_closest_pairs, tol=tol) solutions = [] # for (i, dMetrics), edgeset in zip(combined_dMetric.items(), edges): for (i, dMetrics), edgeset in zip(edges_dMetric_pred.items(), edges): @@ -479,25 +465,25 @@ def plot_joinable(tracks, joinable_pairs): def get_contiguous_pairs(tracks: pd.DataFrame) -> list: """ - Get all pair of contiguous track ids from a tracks dataframe. + Get all pair of contiguous track ids from a tracks data frame. - :param tracks: (m x n) dataframe where rows are cell tracks and - columns are timepoints - :param min_dgr: float minimum difference in growth rate from - the interpolation + Parameters + ---------- + tracks: pd.Dataframe + A dataframe for one trap where rows are cell tracks and columns + are time points. """ - mins, maxes = [ + # time points bounding a tracklet of non-NaN values + mins, maxs = [ tracks.notna().apply(np.where, axis=1).apply(fn) for fn in (np.min, np.max) ] mins_d = mins.groupby(mins).apply(lambda x: x.index.tolist()) mins_d.index = mins_d.index - 1 # make indices equal # TODO add support for skipping time points - maxes_d = maxes.groupby(maxes).apply(lambda x: x.index.tolist()) - common = sorted( - set(mins_d.index).intersection(maxes_d.index), reverse=True - ) - return [(maxes_d[t], mins_d[t]) for t in common] + maxs_d = maxs.groupby(maxs).apply(lambda x: x.index.tolist()) + common = sorted(set(mins_d.index).intersection(maxs_d.index), reverse=True) + return [(maxs_d[t], mins_d[t]) for t in common] # def fit_track(track: pd.Series, obj=None): diff --git a/src/postprocessor/core/reshapers/merger.py b/src/postprocessor/core/reshapers/merger.py index b5115c4df9b8b08a0c85c828c65359773eec5748..5c035040d4f7c559e4b7ecf996f665a3b80370da 100644 --- a/src/postprocessor/core/reshapers/merger.py +++ b/src/postprocessor/core/reshapers/merger.py @@ -14,7 +14,7 @@ class MergerParameters(ParametersABC): Whether or not to smooth with a savgol_filter. tol: float or int The threshold of average prediction error/std necessary to - consider two tracks the same. + consider two tracks to be the same. If float, the threshold is the fraction of the first track; if int, the threshold is in absolute units. window: int