From 148d92c32ed250ff70def7430a03e94d7f9c8338 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Al=C3=A1n=20Mu=C3=B1oz?= <alan.munoz@ed.ac.uk>
Date: Wed, 5 Oct 2022 17:40:48 +0100
Subject: [PATCH] refactor(picker): use shared function for indices

---
 src/postprocessor/core/reshapers/picker.py | 41 ++++++++++++++--------
 1 file changed, 27 insertions(+), 14 deletions(-)

diff --git a/src/postprocessor/core/reshapers/picker.py b/src/postprocessor/core/reshapers/picker.py
index 117b5d4b..4d8a6b2d 100644
--- a/src/postprocessor/core/reshapers/picker.py
+++ b/src/postprocessor/core/reshapers/picker.py
@@ -1,19 +1,23 @@
-from abc import ABC, abstractmethod
+# from abc import ABC, abstractmethod
 
 # from copy import copy
-from itertools import groupby
-from typing import List, Tuple, Union
+# from itertools import groupby
+# from typing import List, Tuple, Union
+import typing as t
+from typing import Union
 
-import igraph as ig
+# import igraph as ig
 import numpy as np
 import pandas as pd
+
 from agora.abc import ParametersABC
 from agora.io.cells import Cells
-from utils_find_1st import cmp_equal, find_1st
 
-from postprocessor.core.lineageprocess import LineageProcess
-from postprocessor.core.functions.tracks import max_nonstop_ntps, max_ntps
+# from postprocessor.core.functions.tracks import max_nonstop_ntps, max_ntps
 from agora.utils.association import validate_association
+from postprocessor.core.lineageprocess import LineageProcess
+
+# from utils_find_1st import cmp_equal, find_1st
 
 
 class pickerParameters(ParametersABC):
@@ -47,11 +51,18 @@ class picker(LineageProcess):
 
         self.cells = cells
 
-    def pick_by_lineage(self, signal, how):
+    def pick_by_lineage(
+        self,
+        signal: pd.DataFrame,
+        how: str,
+        mothers_daughters: t.Optional[np.ndarray] = None,
+    ):
         self.orig_signals = signal
 
         idx = np.array(signal.index.to_list())
-        mothers_daughters = self.cells.mothers_daughters
+
+        if mothers_daughters is None:
+            mothers_daughters = self.cells.mothers_daughters
         valid_indices, valid_lineage = [slice(None)] * 2
 
         if how == "mothers":
@@ -60,11 +71,11 @@ class picker(LineageProcess):
             )
         elif how == "daughters":
             valid_lineage, valid_indices = validate_association(
-                mothers_daughters, idx, match_column=0
+                mothers_daughters, idx, match_column=1
             )
         elif how == "families":  # Mothers and daughters that are still present
             valid_lineage, valid_indices = validate_association(
-                mothers_daughters, idx, match_column=0
+                mothers_daughters, idx
             )
 
         idx = idx[valid_indices]
@@ -72,9 +83,11 @@ class picker(LineageProcess):
 
         return mothers_daughters, idx
 
-    def loc_lineage(self, signals: pd.DataFrame, how: str):
-        _, valid_indices = self.pick_by_lineage(signals, how)
-        return signals.loc[valid_indices]
+    def loc_lineage(self, kymo: pd.DataFrame, how: str, lineage=None):
+        _, valid_indices = self.pick_by_lineage(
+            kymo, how, mothers_daughters=lineage
+        )
+        return kymo.loc[[tuple(x) for x in valid_indices]]
 
     def pick_by_condition(self, signals, condition, thresh):
         idx = self.switch_case(signals, condition, thresh)
-- 
GitLab