From 01adfb23806779a95796f2aaad4a239295a50811 Mon Sep 17 00:00:00 2001
From: Swainlab <peter.swain@ed.ac.uk>
Date: Fri, 21 Jul 2023 12:56:54 +0100
Subject: [PATCH] added and verified validate_lineage, now used by picker

---
 src/agora/utils/indexing_new.py            | 89 ++++++++++++++++++----
 src/postprocessor/core/reshapers/picker.py | 21 ++++-
 2 files changed, 94 insertions(+), 16 deletions(-)

diff --git a/src/agora/utils/indexing_new.py b/src/agora/utils/indexing_new.py
index 0f7b4a26..aca8667c 100644
--- a/src/agora/utils/indexing_new.py
+++ b/src/agora/utils/indexing_new.py
@@ -12,31 +12,90 @@ import typing as t
 
 
 def validate_lineage(
-    lineage: np.ndarray,
-    indices: np.ndarray,
-    match_column: t.Optional[int] = None,
-) -> t.Tuple[np.ndarray, np.ndarray]:
+    lineage: np.ndarray, indices: np.ndarray, how: str = "families"
+):
     """
     Identify mother-bud pairs that exist both in lineage and a Signal's indices.
+
+    Parameters
+    ----------
+    lineage : np.ndarray
+        2D array of lineage associations where columns are
+        (trap, mother, daughter)
+        or
+        a 3D array, which is an array of 2 X 2 arrays comprising
+        [[trap_id, mother_label], [trap_id, daughter_label]].
+    indices : np.ndarray
+        A 2D array of cell indices from a Signal, (trap_id, cell_label).
+        This array should not include mother_label.
+    how: str
+        If "mothers", matches indicate mothers from mother-bud pairs;
+        If "daughters", matches indicate daughters from mother-bud pairs;
+        If "families", matches indicate mothers and daughters in mother-bud pairs.
+
+    Returns
+    -------
+    valid_lineage: boolean np.ndarray
+        1D array indicating matched elements in lineage.
+    valid_indices: boolean np.ndarray
+        1D array indicating matched elements in indices.
+
+    Examples
+    --------
+    >>> import numpy as np
+    >>> from agora.utils.indexing import validate_lineage
+
+    >>> lineage = np.array([ [[0, 1], [0, 3]], [[0, 1], [0, 4]], [[0, 1], [0, 6]], [[0, 4], [0, 7]] ])
+    >>> indices = np.array([ [0, 1], [0, 2], [0, 3]])
+
+    >>> valid_lineage, valid_indices = validate_lineage(lineage, indices)
+
+    >>> print(valid_lineage)
+     array([ True, False, False, False])
+    >>> print(valid_indices)
+     array([ True, False, True])
+
+    and
+
+    >>> lineage = np.array([[[0,3], [0,1]], [[0,2], [0,4]]])
+    >>> indices = np.array([[0,1], [0,2], [0,3]])
+    >>> valid_lineage, valid_indices = validate_lineage(lineage, indices)
+    >>> print(valid_lineage)
+     array([ True, False])
+    >>> print(valid_indices)
+     array([ True, False, True])
     """
     if lineage.ndim == 2:
         # [trap, mother, daughter] becomes [[trap, mother], [trap, daughter]]
         lineage = _assoc_indices_to_3d(lineage)
+    if how == "mothers":
+        c_index = 0
+    elif how == "daughters":
+        c_index = 1
     # dtype links together trap and cell ids
-    dtype = {
-        "names": ["trap_id", "cell_id"],
-        "formats": [np.int64, np.int64],
-    }
+    dtype = {"names": ["trap_id", "cell_id"], "formats": [np.int64, np.int64]}
     lineage = np.ascontiguousarray(lineage, dtype=np.int64)
     # find (trap, cell_ids) in intersection
     inboth = np.intersect1d(lineage.view(dtype), indices.view(dtype))
-    # both mother and bud must be in indices
-    valid_lineage = np.isin(lineage.view(dtype), inboth).all(axis=1).flatten()
-    # select only pairs of mother and bud indices
-    valid_indices = np.isin(
-        indices.view(dtype), lineage[valid_lineage.flatten(), ...].view(dtype)
-    ).flatten()
-    return valid_lineage, valid_indices
+    # find valid lineage
+    valid_lineage = np.isin(lineage.view(dtype), inboth)
+    if how == "families":
+        # both mother and bud must be in indices
+        valid_lineage = valid_lineage.all(axis=1)
+    else:
+        valid_lineage = valid_lineage[:, c_index, :]
+    # find valid indices
+    possible_indices = lineage[valid_lineage.flatten(), ...]
+    if how == "families":
+        # select only pairs of mother and bud indices
+        valid_indices = np.isin(
+            indices.view(dtype), possible_indices.view(dtype)
+        )
+    else:
+        valid_indices = np.isin(
+            indices.view(dtype), possible_indices.view(dtype)[:, c_index, :]
+        )
+    return valid_lineage.flatten(), valid_indices.flatten()
 
 
 def validate_association(
diff --git a/src/postprocessor/core/reshapers/picker.py b/src/postprocessor/core/reshapers/picker.py
index 2b5414b0..c05d4279 100644
--- a/src/postprocessor/core/reshapers/picker.py
+++ b/src/postprocessor/core/reshapers/picker.py
@@ -5,7 +5,8 @@ import pandas as pd
 
 from agora.abc import ParametersABC
 from agora.io.cells import Cells
-from agora.utils.indexing_new import validate_association
+from agora.utils.indexing import validate_association
+from agora.utils.indexing_new import validate_lineage
 from agora.utils.cast import _str_to_int
 from agora.utils.kymograph import drop_mother_label
 from postprocessor.core.lineageprocess import LineageProcess
@@ -52,6 +53,23 @@ class Picker(LineageProcess):
         signal: pd.DataFrame,
         how: str,
         mothers_daughters: t.Optional[np.ndarray] = None,
+    ) -> pd.MultiIndex:
+        """
+        Return rows of a signal corresponding to either mothers, daughters,
+        or mother-daughter pairs using lineage information.
+        """
+        cells_present = drop_mother_label(signal.index)
+        mothers_daughters = self.get_lineage_information(signal)
+        _, valid_indices = validate_lineage(
+            mothers_daughters, cells_present, how
+        )
+        return signal.index[valid_indices]
+
+    def pick_by_lineage_original(
+        self,
+        signal: pd.DataFrame,
+        how: str,
+        mothers_daughters: t.Optional[np.ndarray] = None,
     ) -> pd.MultiIndex:
         """
         Return rows of a signal corresponding to either mothers, daughters,
@@ -60,6 +78,7 @@ class Picker(LineageProcess):
         cells_present = drop_mother_label(signal.index)
         mothers_daughters = self.get_lineage_information(signal)
         #: might be better if match_column defined as a string to make everything one line
+        breakpoint()
         if how == "mothers":
             _, valid_indices = validate_association(
                 mothers_daughters, cells_present, match_column=0
-- 
GitLab