From f328b45fcf3fef85a15c1b2b820fd17362abd248 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Al=C3=A1n=20Mu=C3=B1oz?= <alan.munoz@ed.ac.uk>
Date: Wed, 15 Mar 2023 14:41:53 +0000
Subject: [PATCH] refactor(lineage): move functions to parent/utils

---
 src/agora/utils/kymograph.py               | 12 +++++
 src/postprocessor/core/lineageprocess.py   | 14 +++---
 src/postprocessor/core/processor.py        |  3 +-
 src/postprocessor/core/reshapers/picker.py | 51 ++++++++++------------
 4 files changed, 45 insertions(+), 35 deletions(-)

diff --git a/src/agora/utils/kymograph.py b/src/agora/utils/kymograph.py
index 71411e1d..f33c1c1d 100644
--- a/src/agora/utils/kymograph.py
+++ b/src/agora/utils/kymograph.py
@@ -163,3 +163,15 @@ def slices_from_spans(spans: t.Tuple[int], df: pd.DataFrame) -> t.List[slice]:
         for start, end in zip(cumsum[:-1], cumsum[1:])
     ]
     return slices
+
+
+def drop_mother_label(index: pd.MultiIndex) -> np.ndarray:
+    no_mother_label = index
+    if "mother_label" in index.names:
+        no_mother_label = index.droplevel("mother_label")
+    return np.array(no_mother_label.tolist())
+
+
+def get_index_as_np(signal: pd.DataFrame):
+    # Get mother labels from multiindex dataframe
+    return np.array(signal.index.to_list())
diff --git a/src/postprocessor/core/lineageprocess.py b/src/postprocessor/core/lineageprocess.py
index f10d5b3e..1c875020 100644
--- a/src/postprocessor/core/lineageprocess.py
+++ b/src/postprocessor/core/lineageprocess.py
@@ -51,9 +51,11 @@ class LineageProcess(PostProcessABC):
             data, lineage=lineage, *extra_data
         )
 
-    def load_lineage(self, lineage):
-        """
-        Reshape the lineage information if needed
-        """
-        # TODO does this need to be a function?
-        self.lineage = lineage
+    def get_lineage_information(self, signal):
+        if "mother_label" in signal.index.names:
+            lineage = get_index_as_np(signal)
+        elif self.cells is not None:
+            lineage = self.cells.mothers_daughters
+        else:
+            raise Exception("No linage information found")
+        return lineage
diff --git a/src/postprocessor/core/processor.py b/src/postprocessor/core/processor.py
index ec3b74ab..b010efd1 100644
--- a/src/postprocessor/core/processor.py
+++ b/src/postprocessor/core/processor.py
@@ -286,7 +286,8 @@ class PostProcessor(ProcessABC):
                     # self.parameters.lineage_location
                 )
                 loaded_process = self.classfun[process](parameters)
-                loaded_process.load_lineage(lineage)
+                loaded_process.lineage = lineage
+
             else:
                 loaded_process = self.classfun[process](parameters)
 
diff --git a/src/postprocessor/core/reshapers/picker.py b/src/postprocessor/core/reshapers/picker.py
index c331c842..4132cf8d 100644
--- a/src/postprocessor/core/reshapers/picker.py
+++ b/src/postprocessor/core/reshapers/picker.py
@@ -6,7 +6,8 @@ import pandas as pd
 from agora.abc import ParametersABC
 from agora.io.cells import Cells
 
-from agora.utils.association import validate_association, last_col_as_rows
+from agora.utils.association import validate_association
+from agora.utils.kymograph import drop_mother_label, get_index_as_np
 from postprocessor.core.lineageprocess import LineageProcess
 
 
@@ -24,14 +25,14 @@ class Picker(LineageProcess):
     :cells: Cell object passed to the constructor
     :condition: Tuple with condition and associated parameter(s), conditions can be
     "present", "nonstoply_present" or "quantile".
-    Determines the thersholds or fractions of signals/signals to use.
+    Determines the thersholds or fractions of signals to use.
     :lineage: str {"mothers", "daughters", "families" (mothers AND daughters), "orphans"}. Mothers/daughters picks cells with those tags, families pick the union of both and orphans the difference between the total and families.
     """
 
     def __init__(
         self,
         parameters: PickerParameters,
-        cells: Cells,
+        cells: Cells or None = None,
     ):
         super().__init__(parameters=parameters)
 
@@ -46,8 +47,7 @@ class Picker(LineageProcess):
 
         cells_present = drop_mother_label(signal.index)
 
-        if mothers_daughters is None:
-            mothers_daughters = self.cells.mothers_daughters
+        mothers_daughters = self.get_lineage_information(signal)
 
         valid_indices = slice(None)
 
@@ -66,15 +66,17 @@ class Picker(LineageProcess):
 
         return signal.index[valid_indices]
 
-    def pick_by_condition(self, signals, condition, thresh):
-        idx = self.switch_case(signals, condition, thresh)
+    def pick_by_condition(self, signal, condition, thresh):
+        idx = self.switch_case(signal, condition, thresh)
         return idx
 
-    def run(self, signals):
-        self.orig_signals = signals
-        indices = set(signals.index)
-        lineage = self.cells.mothers_daughters
-        if lineage.any():
+    def run(self, signal):
+        self.orig_signal = signal
+        indices = set(signal.index)
+
+        lineage = self.get_lineage_information(signal)
+
+        if len(lineage):
             self.mothers = lineage[:, :2]
             self.daughters = lineage[:, [0, 2]]
 
@@ -84,12 +86,12 @@ class Picker(LineageProcess):
                     if alg == "lineage":
                         param1 = params[0]
                         new_indices = getattr(self, "pick_by_" + alg)(
-                            signals.loc[list(indices)], param1
+                            signal.loc[list(indices)], param1
                         )
                     else:
                         param1, *param2 = params
                         new_indices = getattr(self, "pick_by_" + alg)(
-                            signals.loc[list(indices)], param1, param2
+                            signal.loc[list(indices)], param1, param2
                         )
                         new_indices = [tuple(x) for x in new_indices]
 
@@ -102,12 +104,12 @@ class Picker(LineageProcess):
 
     def switch_case(
         self,
-        signals: pd.DataFrame,
+        signal: pd.DataFrame,
         condition: str,
         threshold: t.Union[float, int, list],
     ):
         if len(threshold) == 1:
-            threshold = [_as_int(*threshold, signals.shape[1])]
+            threshold = [_as_int(*threshold, signal.shape[1])]
         case_mgr = {
             "any_present": lambda s, thresh: any_present(s, thresh),
             "present": lambda s, thresh: s.notna().sum(axis=1) > thresh,
@@ -115,7 +117,7 @@ class Picker(LineageProcess):
             > thresh,
             "growing": lambda s, thresh: s.diff(axis=1).sum(axis=1) > thresh,
         }
-        return set(signals.index[case_mgr[condition](signals, *threshold)])
+        return set(signal.index[case_mgr[condition](signal, *threshold)])
 
 
 def _as_int(threshold: t.Union[float, int], ntps: int):
@@ -124,28 +126,21 @@ def _as_int(threshold: t.Union[float, int], ntps: int):
     return threshold
 
 
-def any_present(signals, threshold):
+def any_present(signal, threshold):
     """
     Returns a mask for cells, True if there is a cell in that trap that was present for more than :threshold: timepoints.
     """
     any_present = pd.Series(
         np.sum(
             [
-                np.isin([x[0] for x in signals.index], i) & v
-                for i, v in (signals.notna().sum(axis=1) > threshold)
+                np.isin([x[0] for x in signal.index], i) & v
+                for i, v in (signal.notna().sum(axis=1) > threshold)
                 .groupby("trap")
                 .any()
                 .items()
             ],
             axis=0,
         ).astype(bool),
-        index=signals.index,
+        index=signal.index,
     )
     return any_present
-
-
-def drop_mother_label(index: pd.MultiIndex) -> np.ndarray:
-    no_mother_label = index
-    if "mother_label" in index.names:
-        no_mother_label = index.droplevel("mother_label")
-    return np.array(no_mother_label.tolist())
-- 
GitLab