From eb00e2b2548b3551a5d081f882e0d4687fcec35e Mon Sep 17 00:00:00 2001
From: pswain <peter.swain@ed.ac.uk>
Date: Thu, 21 Dec 2023 10:33:44 +0000
Subject: [PATCH] feature(signal): make check on lineage assignment by
 validate_lineage optional

---
 src/agora/io/signal.py      | 16 ++++++++-----
 src/agora/utils/indexing.py | 45 +++++++++++++++++++++----------------
 2 files changed, 36 insertions(+), 25 deletions(-)

diff --git a/src/agora/io/signal.py b/src/agora/io/signal.py
index 83907dc..e9b0dd1 100644
--- a/src/agora/io/signal.py
+++ b/src/agora/io/signal.py
@@ -142,15 +142,15 @@ class Signal(BridgeH5):
         with h5py.File(self.filename, "r") as f:
             if lineage_location not in f:
                 lineage_location = "postprocessing/lineage"
-            tile_mo_da = f[lineage_location]
-            if isinstance(tile_mo_da, h5py.Dataset):
-                lineage = tile_mo_da[()]
+            traps_mothers_daughters = f[lineage_location]
+            if isinstance(traps_mothers_daughters, h5py.Dataset):
+                lineage = traps_mothers_daughters[()]
             else:
                 lineage = np.array(
                     (
-                        tile_mo_da["trap"],
-                        tile_mo_da["mother_label"],
-                        tile_mo_da["daughter_label"],
+                        traps_mothers_daughters["trap"],
+                        traps_mothers_daughters["mother_label"],
+                        traps_mothers_daughters["daughter_label"],
                     )
                 ).T
         return lineage
@@ -249,6 +249,7 @@ class Signal(BridgeH5):
         dataset: str or t.List[str],
         in_minutes: bool = True,
         lineage: bool = False,
+        run_lineage_check: bool = True,
         **kwargs,
     ) -> pd.DataFrame or t.List[pd.DataFrame]:
         """
@@ -262,6 +263,8 @@ class Signal(BridgeH5):
             If True, convert column headings to times in minutes.
         lineage: boolean
             If True, add mother_label to index.
+        run_lineage_check: boolean
+            If True, raise an exception if there is a likely error in the lineage assignment.
         """
         try:
             if isinstance(dataset, str):
@@ -279,6 +282,7 @@ class Signal(BridgeH5):
                                 lineage,
                                 indices=np.array(df.index.to_list()),
                                 how="daughters",
+                                run_lineage_check=run_lineage_check,
                             )
                             mother_label[valid_indices] = lineage[
                                 valid_lineage, 1
diff --git a/src/agora/utils/indexing.py b/src/agora/utils/indexing.py
index faa8e25..1b20a71 100644
--- a/src/agora/utils/indexing.py
+++ b/src/agora/utils/indexing.py
@@ -5,7 +5,10 @@ i_dtype = {"names": ["trap_id", "cell_id"], "formats": [np.int64, np.int64]}
 
 
 def validate_lineage(
-    lineage: np.ndarray, indices: np.ndarray, how: str = "families"
+    lineage: np.ndarray,
+    indices: np.ndarray,
+    how: str = "families",
+    run_lineage_check: bool = True,
 ):
     """
     Identify mother-bud pairs both in lineage and a Signal's indices.
@@ -28,6 +31,9 @@ def validate_lineage(
         If "mothers", matches indicate mothers from mother-bud pairs;
         If "daughters", matches indicate daughters from mother-bud pairs;
         If "families", matches indicate mothers and daughters in mother-bud pairs.
+    run_lineage_check: bool
+        If True, check for errors in the lineage assignment such as a daughter
+        being assigned two mothers.
 
     Returns
     -------
@@ -85,24 +91,25 @@ def validate_lineage(
         valid_indices = index_isin(indices, selected_lineages[:, c_index, :])
     flat_valid_indices = valid_indices.flatten()
     # test for mismatch
-    if how == "families":
-        test_mismatch = (
-            indices[flat_valid_indices, :].size
-            != np.unique(
-                lineage[flat_valid_lineage, :].reshape(-1, 2), axis=0
-            ).size
-        )
-    else:
-        test_mismatch = (
-            indices[flat_valid_indices, :].size
-            != lineage[flat_valid_lineage, c_index, :].size
-        )
-    if test_mismatch:
-        # all unique indices in valid_lineages should be in valid_indices
-        raise Exception(
-            "Error in validate_lineage: "
-            "lineage information is likely not unique."
-        )
+    if run_lineage_check:
+        if how == "families":
+            test_mismatch = (
+                indices[flat_valid_indices, :].size
+                != np.unique(
+                    lineage[flat_valid_lineage, :].reshape(-1, 2), axis=0
+                ).size
+            )
+        else:
+            test_mismatch = (
+                indices[flat_valid_indices, :].size
+                != lineage[flat_valid_lineage, c_index, :].size
+            )
+        if test_mismatch:
+            # all unique indices in valid_lineages should be in valid_indices
+            raise Exception(
+                "Error in validate_lineage: "
+                "lineage information is likely not unique."
+            )
     return flat_valid_lineage, flat_valid_indices
 
 
-- 
GitLab