From 001e309aea9cfc30e9654adf9fa3ce0e2a9595df Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Al=C3=A1n=20Mu=C3=B1oz?= <alan.munoz@ed.ac.uk>
Date: Sun, 23 Oct 2022 19:29:57 +0100
Subject: [PATCH] refactor(grouper): update imports for filtering

---
 src/postprocessor/grouper.py | 14 ++++++--------
 1 file changed, 6 insertions(+), 8 deletions(-)

diff --git a/src/postprocessor/grouper.py b/src/postprocessor/grouper.py
index 5c062d7a..4be02a8d 100644
--- a/src/postprocessor/grouper.py
+++ b/src/postprocessor/grouper.py
@@ -15,7 +15,11 @@ import pandas as pd
 import seaborn as sns
 from pathos.multiprocessing import Pool
 
-from agora.utils.kymograph import drop_level, intersection_matrix
+from agora.utils.kymograph import (
+    drop_level,
+    get_mother_ilocs_from_daughters,
+    intersection_matrix,
+)
 from postprocessor.chainer import Chainer
 
 
@@ -403,13 +407,7 @@ def concat_signal_ind(
     elif mode == "families":
         combined = chainer.get_raw(path, **kwargs)
 
-        daughter_ids = combined.index[
-            combined.index.get_level_values("mother_label") > 0
-        ]
-        mother_id_mask = intersection_matrix(
-            daughter_ids.droplevel("cell_label"),
-            drop_level(combined, "mother_label", as_list=False),
-        ).any(axis=0)
+        mother_id_mask = get_mother_ilocs_from_daughters(df)
         combined = combined.loc[
             combined.index[mother_id_mask].union(daughter_ids)
         ]
-- 
GitLab