From 792fb895ba99f36b6ae54372c9125e7ac06c5cf9 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Al=C3=A1n=20Mu=C3=B1oz?= <alan.munoz@ed.ac.uk>
Date: Wed, 28 Sep 2022 00:47:52 +0100
Subject: [PATCH] refactor(lineageprocess): add lineage requirements

---
 .../core/processes/lineageprocess.py          | 41 ++++++++++++++++---
 1 file changed, 35 insertions(+), 6 deletions(-)

diff --git a/src/postprocessor/core/processes/lineageprocess.py b/src/postprocessor/core/processes/lineageprocess.py
index 2c0cc6a0..38565e18 100644
--- a/src/postprocessor/core/processes/lineageprocess.py
+++ b/src/postprocessor/core/processes/lineageprocess.py
@@ -1,7 +1,11 @@
+import typing as t
+from abc import abstractmethod
+
 import numpy as np
 import pandas as pd
-from agora.abc import ParametersABC
 
+from agora.abc import ParametersABC
+from agora.utils.lineage import group_matrix
 from postprocessor.core.abc import PostProcessABC
 
 
@@ -21,11 +25,6 @@ class LineageProcess(PostProcessABC):
     def __init__(self, parameters: LineageProcessParameters):
         super().__init__(parameters)
 
-    def run(
-        self,
-    ):
-        pass
-
     def filter_signal_cells(self, signal: pd.DataFrame):
         """
         Use casting to filter cell ids in signal and lineage
@@ -44,3 +43,33 @@ class LineageProcess(PostProcessABC):
         )
 
         return self.lineage[mo_av & da_av]
+
+    @abstractmethod
+    def run(
+        self,
+        data: pd.DataFrame,
+        mother_bud_ids: t.Dict[t.Tuple[int], t.Collection[int]],
+        *args,
+    ):
+        pass
+
+    @classmethod
+    def as_function(
+        cls,
+        data: pd.DataFrame,
+        lineage: t.Union[t.Dict[t.Tuple[int], t.List[int]]],
+        *extra_data,
+        **kwargs,
+    ):
+        """
+        Overrides PostProcess.as_function classmethod.
+        Lineage functions require lineage information to be passed if run as function.
+        """
+        if isinstance(lineage, np.ndarray):
+            lineage = group_matrix(lineage, n_keys=2)
+
+        parameters = cls.default_parameters(**kwargs)
+        return cls(parameters=parameters).run(
+            data, mother_bud_ids=lineage, *extra_data
+        )
+        # super().as_function(data, *extra_data, lineage=lineage, **kwargs)
-- 
GitLab