From 8ef3726cc4f980d8ac3c2ab3b67a10f905551245 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Al=C3=A1n=20Mu=C3=B1oz?= <alan.munoz@ed.ac.uk> Date: Mon, 3 Apr 2023 16:38:20 +0100 Subject: [PATCH] fix(pp): refresh bud_metric and merging --- src/agora/utils/merge.py | 20 ++++++---- src/postprocessor/core/processor.py | 4 +- .../core/reshapers/bud_metric.py | 37 +++++++++++++------ 3 files changed, 40 insertions(+), 21 deletions(-) diff --git a/src/agora/utils/merge.py b/src/agora/utils/merge.py index e74a9a03..93669912 100644 --- a/src/agora/utils/merge.py +++ b/src/agora/utils/merge.py @@ -31,8 +31,11 @@ def apply_merges(data: pd.DataFrame, merges: np.ndarray): """ + indices = data.index + if "mother_label" in indices.names: + indices = indices.droplevel("mother_label") valid_merges, indices = validate_association( - merges, np.array(list(data.index)) + merges, np.array(list(indices)) ) # Assign non-merged @@ -129,14 +132,15 @@ def merge_association( valid_indices = comparison_mat.any(axis=0) - replacement_d = {} - for dataset in grouped_merges: - for k in dataset: - replacement_d[tuple(k[0])] = dataset[-1][1] + if valid_indices.any(): # Where valid, perform transformation + replacement_d = {} + for dataset in grouped_merges: + for k in dataset: + replacement_d[tuple(k[0])] = dataset[-1][1] - flat_indices[valid_indices] = [ - replacement_d[tuple(i)] for i in flat_indices[valid_indices] - ] + flat_indices[valid_indices] = [ + replacement_d[tuple(i)] for i in flat_indices[valid_indices] + ] merged_indices = flat_indices.reshape(-1, 2, 2) return merged_indices diff --git a/src/postprocessor/core/processor.py b/src/postprocessor/core/processor.py index e0adac14..72bbf3f7 100644 --- a/src/postprocessor/core/processor.py +++ b/src/postprocessor/core/processor.py @@ -204,7 +204,9 @@ class PostProcessor(ProcessABC): self._writer.write( "modifiers/picks", data=pd.MultiIndex.from_arrays( - picked_indices, names=["trap", "cell_label"] + picked_indices.T, + # names=["trap", "cell_label", "mother_label"], + names=["trap", "cell_label"], ), overwrite="overwrite", ) diff --git a/src/postprocessor/core/reshapers/bud_metric.py b/src/postprocessor/core/reshapers/bud_metric.py index d23ea115..bf813739 100644 --- a/src/postprocessor/core/reshapers/bud_metric.py +++ b/src/postprocessor/core/reshapers/bud_metric.py @@ -65,23 +65,36 @@ class BudMetric(LineageProcess): mothers_mat = np.zeros((len(md), signal.shape[1])) cells_were_dropped = 0 # Flag determines if mothers (1), daughters (2) or both were missing (3) - md_index = signal.index.droplevel("mother_label") - if md is not None: # Get intersection of lineage and current index + md_index = signal.index + if ( + "mother_label" not in md_index.names + ): # Generate mother label from md dict if unavailable + d = {v: k for k, values in md.items() for v in values} + signal["mother_label"] = list( + map(lambda x: d.get(x, [0])[-1], signal.index) + ) + signal.set_index("mother_label", append=True, inplace=True) related_items = set( [*md.keys(), *[y for x in md.values() for y in x]] ) - intersection = md_index.intersection(related_items) - if len(intersection) < len(signal): - print("Dropped cells before bud_metric") # TODO log - - signal = ( - signal.reset_index("mother_label") - .loc(axis=0)[intersection] - .set_index("mother_label", append=True) - ) + md_index = md_index.intersection(related_items) + elif "mother_label" in md_index.names: + md_index = md_index.droplevel("mother_label") + else: + raise ("Unavailable relationship information") + + if len(md_index) < len(signal): + print("Dropped cells before bud_metric") # TODO log + + signal = ( + signal.reset_index("mother_label") + .loc(axis=0)[md_index] + .set_index("mother_label", append=True) + ) names = list(signal.index.names) del names[-2] + output_df = ( signal.loc[signal.index.get_level_values("mother_label") > 0] .groupby(names) @@ -101,7 +114,7 @@ def _combine_daughter_tracks(tracks: t.Collection[pd.Series]): prioritising the most recent entity. """ sorted_da_ids = tracks.sort_index(level="cell_label") - tp_fvt = sorted_da_ids.apply(lambda x: x.last_valid_index(), axis=0) + tp_fvt = sorted_da_ids.apply(lambda x: x.first_valid_index(), axis=1) tp_fvt = sorted_da_ids.index.get_indexer(tp_fvt) tp_fvt[tp_fvt < 0] = sorted_da_ids.shape[0] - 1 -- GitLab