diff --git a/src/postprocessor/grouper.py b/src/postprocessor/grouper.py index 5c062d7a9dc2c0675f3cb5f965dbcde66dc41363..4be02a8defed3fe0ddbfaae9e8ea224319dbd3b4 100644 --- a/src/postprocessor/grouper.py +++ b/src/postprocessor/grouper.py @@ -15,7 +15,11 @@ import pandas as pd import seaborn as sns from pathos.multiprocessing import Pool -from agora.utils.kymograph import drop_level, intersection_matrix +from agora.utils.kymograph import ( + drop_level, + get_mother_ilocs_from_daughters, + intersection_matrix, +) from postprocessor.chainer import Chainer @@ -403,13 +407,7 @@ def concat_signal_ind( elif mode == "families": combined = chainer.get_raw(path, **kwargs) - daughter_ids = combined.index[ - combined.index.get_level_values("mother_label") > 0 - ] - mother_id_mask = intersection_matrix( - daughter_ids.droplevel("cell_label"), - drop_level(combined, "mother_label", as_list=False), - ).any(axis=0) + mother_id_mask = get_mother_ilocs_from_daughters(df) combined = combined.loc[ combined.index[mother_id_mask].union(daughter_ids) ]