From b1919402ef89c0fb15ba041da5af0749a64d2c3f Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Al=C3=A1n=20Mu=C3=B1oz?= <amuoz@ed.ac.uk>
Date: Fri, 8 Oct 2021 19:48:31 +0100
Subject: [PATCH] add lineage management

Former-commit-id: 8aaaf056b023a2362888d5c635950a9bd53ff705
---
 core/processes/picker.py | 94 +++++++++++++++++++++++++++++++---------
 core/processor.py        | 63 ++++++++++++++++++++++++---
 2 files changed, 130 insertions(+), 27 deletions(-)

diff --git a/core/processes/picker.py b/core/processes/picker.py
index 3668e922..54c658bc 100644
--- a/core/processes/picker.py
+++ b/core/processes/picker.py
@@ -1,6 +1,8 @@
 from typing import Tuple, Union, List
 from abc import ABC, abstractmethod
 
+from itertools import groupby
+from utils_find_1st import find_1st, cmp_equal
 import numpy as np
 import pandas as pd
 
@@ -15,10 +17,12 @@ class pickerParameters(ParametersABC):
         self,
         condition: Tuple[str, Union[float, int]] = None,
         lineage: str = None,
+        lineage_conditional: str = None,
         sequence: List[str] = ["lineage", "condition"],
     ):
         self.condition = condition
         self.lineage = lineage
+        self.lineage_conditional = lineage_conditional
         self.sequence = sequence
 
     @classmethod
@@ -26,8 +30,9 @@ class pickerParameters(ParametersABC):
         return cls.from_dict(
             {
                 "condition": ["present", 0.8],
-                "lineage": None,
-                "sequence": ["lineage", "condition"],
+                "lineage": "families",
+                "lineage_conditional": "include",
+                "sequence": ["condition", "lineage"],
             }
         )
 
@@ -59,44 +64,93 @@ class picker(ProcessABC):
         for cells in ma:
             for d, m in enumerate(cells):
                 if m:
-                    mb_matrix[c + d, c + m] = True
+                    mb_matrix[c + d, c + m - 1] = True
 
             c += len(cells)
 
         return mb_matrix
 
+    @staticmethod
+    def mother_assign_from_dynamic(ma, label, trap, ntraps: int):
+        """
+        Build one list of mother labels per trap, using each cell's last entry in the mother_assign_dynamic feature
+        """
+        idlist = list(zip(trap, label))
+        cell_gid = np.unique(idlist, axis=0)
+
+        last_lin_preds = [
+            find_1st(((label[::-1] == lbl) & (trap[::-1] == tr)), True, cmp_equal)
+            for tr, lbl in cell_gid
+        ]
+        mother_assign_sorted = ma[::-1][last_lin_preds]  # positions refer to the reversed arrays
+
+        traps = cell_gid[:, 0]
+        iterator = groupby(zip(traps, mother_assign_sorted), lambda x: x[0])
+        d = {key: [x[1] for x in group] for key, group in iterator}
+        nested_massign = [d.get(i, []) for i in range(ntraps)]
+
+        return nested_massign
+
     def pick_by_lineage(self, signals):
         idx = signals.index
 
         if self.lineage:
-            ma = self._cells["mother_assign"]
-            mother_bud_mat = self.mother_assign_to_mb_matrix(ma)
-            daughters, mothers = np.where(mother_bud_mat)
+            ma = self._cells["mother_assign_dynamic"]
+            trap = self._cells["trap"]
+            label = self._cells["cell_label"]
+            nested_massign = self.mother_assign_from_dynamic(
+                ma, label, trap, self._cells.ntraps
+            )
+            # mother_bud_mat = self.mother_assign_to_mb_matrix(nested_massign)
+
+            idx = set(  # one (trap, 1-based position) pair per entry of nested_massign
+                [
+                    (tid, i + 1)
+                    for tid, x in enumerate(nested_massign)
+                    for i in range(len(x))
+                ]
+            )
+            mothers, daughters = zip(
+                *[
+                    ((tid, m), (tid, d))
+                    for tid, trapcells in enumerate(nested_massign)
+                    for d, m in enumerate(trapcells, 1)
+                    if m
+                ]
+            )
+            self.mothers = mothers
+            self.daughters = daughters
+
+            mothers = set(mothers)
+            daughters = set(daughters)
+            # daughters, mothers = np.where(mother_bud_mat)
             if self.lineage == "mothers":
-                idx = idx[mothers]
+                idx = mothers
             elif self.lineage == "daughters":
-                idx = idx[daughters]
+                idx = daughters
             elif self.lineage == "families" or self.lineage == "orphans":
-                families = list(set(np.append(daughters, mothers)))
+                families = mothers.union(daughters)
                 if self.lineage == "families":
-                    idx = idx[families]
-                else:  # orphans
-                    idx = idx[list(set(range(len(idx))).difference(families))]
+                    idx = families
+                elif self.lineage == "orphans":  # cells that are neither mother nor daughter
+                    idx = idx.difference(families)
 
-            idx = list(set(idx).intersection(signals.index))
+            idx = idx.intersection(signals.index)
 
-        return signals.loc[idx]
+        return idx
 
     def pick_by_condition(self, signals):
         idx = self.switch_case(self.condition[0], signals, self.condition[1])
-        return signals.loc[idx]
+        return idx  # the set built by switch_case, no longer the selected rows of `signals`
 
     def run(self, signals):
+        """Run each picker in `sequence`; return daughters, mothers and picked indices."""
+        indices = set(signals.index)
         for alg in self.sequence:
-            if alg == "condition":
-                pass
-            self.signals = getattr(self, "pick_by_" + alg)(signals)
-        return self.signals
+            indices = indices.intersection(getattr(self, "pick_by_" + alg)(signals))
+
+        daughters, mothers = self.daughters, self.mothers
+        return np.array(daughters), np.array(mothers), np.array(list(indices))
 
     @staticmethod
     def switch_case(
@@ -111,7 +165,7 @@ class picker(ProcessABC):
             > threshold_asint,
             "quantile": [np.quantile(signals.values[signals.notna()], threshold)],
         }
-        return case_mgr[condition]
+        return set(signals.loc[case_mgr[condition]].index)  # only labels selected by the condition
 
 
 def _as_int(threshold: Union[float, int], ntps: int):
diff --git a/core/processor.py b/core/processor.py
index 239edd99..09caa6a1 100644
--- a/core/processor.py
+++ b/core/processor.py
@@ -109,6 +109,7 @@ class PostProcessor:
 
     def run_prepost(self):
         """Important processes run before normal post-processing ones"""
+
         merge_events = self.merger.run(self._signal[self.targets["prepost"]["merger"]])
 
         with h5py.File(self._filename, "r") as f:
@@ -117,8 +118,54 @@ class PostProcessor:
         changes_history = list(prev_idchanges) + [np.array(x) for x in merge_events]
         self._writer.write("modifiers/merges", data=changes_history)
 
-        picks = self.picker.run(self._signal[self.targets["prepost"]["picker"][0]])
-        self._writer.write("modifiers/picks", data=picks)
+        with h5py.File(self._filename, "a") as f:  # TODO Remove this once done tweaking
+            if "modifiers/picks" in f:
+                del f["modifiers/picks"]
+
+        daughters, mothers, indices = self.picker.run(  # picker.run returns daughters first
+            self._signal[self.targets["prepost"]["picker"][0]]
+        )
+        self._writer.write(
+            "postprocessing/lineage",
+            data=pd.MultiIndex.from_arrays(
+                np.append(mothers, daughters[:, 1].reshape(-1, 1), axis=1).T,
+                names=["trap", "mother_label", "daughter_label"],
+            ),
+            overwrite="overwrite",
+        )
+
+        # apply merge to mother-daughter
+        moset = set([tuple(x) for x in mothers])
+        daset = set([tuple(x) for x in daughters])
+        picked_set = set([tuple(x) for x in indices])
+        with h5py.File(self._filename, "a") as f:
+            merge_events = f["modifiers/merges"][()]
+        merged_moda = set([tuple(x) for x in merge_events[:, 0, :]]).intersection(
+            set([*moset, *daset, *picked_set])
+        )
+        for source, target in merge_events:
+            if tuple(source) in merged_moda:  # compare whole rows; np.isin matched values individually
+                mothers[(mothers == source).all(axis=1)] = target
+                daughters[(daughters == source).all(axis=1)] = target
+                indices[(indices == source).all(axis=1)] = target
+
+        self._writer.write(
+            "postprocessing/lineage_merged",
+            data=pd.MultiIndex.from_arrays(
+                np.append(mothers, daughters[:, 1].reshape(-1, 1), axis=1).T,
+                names=["trap", "mother_label", "daughter_label"],
+            ),
+            overwrite="overwrite",
+        )
+
+        self._writer.write(
+            "modifiers/picks",
+            data=pd.MultiIndex.from_arrays(
+                indices.T,
+                names=["trap", "cell_label"],
+            ),
+            overwrite="overwrite",
+        )
 
     def run(self):
         self.run_prepost()
@@ -165,16 +212,18 @@ class PostProcessor:
                 else:
                     raise ("Outpath not defined", type(dataset))
 
-                if isinstance(result, dict): # Multiple Signals as output
+                if isinstance(result, dict):  # Multiple Signals as output
                     for k, v in result:
                         self.write_result(
-                            "/postprocessing/" + process + "/" + outpath +
-                            f'/{k}',
-                            v, metadata={}
+                            "/postprocessing/" + process + "/" + outpath + f"/{k}",
+                            v,
+                            metadata={},
                         )
                 else:
                     self.write_result(
-                        "/postprocessing/" + process + "/" + outpath, result, metadata={}
+                        "/postprocessing/" + process + "/" + outpath,
+                        result,
+                        metadata={},
                     )
 
     def write_result(
-- 
GitLab