From 914eae9713e626ca6f774a1bef9f058c11dd20a5 Mon Sep 17 00:00:00 2001
From: Peter Swain <peter.swain@ed.ac.uk>
Date: Fri, 28 Apr 2023 12:30:30 +0100
Subject: [PATCH] streamlined indexing simplifying broadcasting

---
 src/agora/utils/indexing.py | 78 +++++++++++++++++++------------------
 1 file changed, 40 insertions(+), 38 deletions(-)

diff --git a/src/agora/utils/indexing.py b/src/agora/utils/indexing.py
index 24fa004e..0e50be02 100644
--- a/src/agora/utils/indexing.py
+++ b/src/agora/utils/indexing.py
@@ -17,9 +17,9 @@ def validate_association(
     match_column: t.Optional[int] = None,
 ) -> t.Tuple[np.ndarray, np.ndarray]:
     """
-    Identify matches between two arrays by matching rows.
+    Identify matches between two arrays by comparing rows.
 
-    We use broadcasting for fast multi-indexing, generalising for lineage dynamics.
+    We use broadcasting for speed.
 
     Parameters
     ----------
@@ -30,68 +30,72 @@ def validate_association(
     indices : np.ndarray
         a 2D array where each column is a different level, such as (trap_id, cell_label). This should not include mother_label.
     match_column: int
-        int indicating a specific column that is required to match (i.e. 0-1 for target-source when trying to merge tracklets or mother-bud for lineage)
-        must be present in indices.
-        If None, one match suffices for the resultant indices
-        vector to be True.
+        If 0, match on mothers only; if 1, match on daughters only.
+        If None, an association must match on both mother and daughter.
 
     Returns
     -------
     valid_association: boolean np.ndarray
-        1D array indicating valid elements in association.
+        1D array indicating elements in association with matches.
     valid_indices: boolean np.ndarray
-        1D array indicating valid elements in indices.
+        1D array indicating elements in indices with matches.
 
     Examples
     --------
     >>> import numpy as np
     >>> from agora.utils.indexing import validate_association
 
-    >>> association = np.array([ [[0, 1], [0, 3]], [[0, 1], [0, 4]] ])
+    >>> association = np.array([ [[0, 1], [0, 3]], [[0, 1], [0, 4]], [[0, 1], [0, 6]], [[0, 4], [0, 7]] ])
     >>> indices = np.array([ [0, 1], [0, 2], [0, 3]])
     >>> print(indices.T)
 
-    >>> valid_associations, valid_indices = validate_association(merges, indices)
+    >>> valid_associations, valid_indices = validate_association(association, indices)
 
     >>> print(valid_associations)
-    array([ True, False])
+    array([ True, False, False, False])
     >>> print(valid_indices)
     array([ True, False, True])
     """
     if association.ndim == 2:
         # reshape into 3D array for broadcasting
-        # [trap, mother, daughter] becomes [[trap, mother], [trap, daughter]] for each trap
+        # for each trap, [trap, mother, daughter] becomes
+        #  [[trap, mother], [trap, daughter]]
         association = _assoc_indices_to_3d(association)
-    # compare existing association with available indices
-    # swap trap and cell_label axes for the indices array to correctly broadcast
-    # compare [[trap, mother], [trap, daughter]] with [trap, cell_label] for all traps in association and for all [trap, cell_label] pairs in indices
+    # use broadcasting to compare association with indices
+    # transpose indices so its (trap, cell_label) axis broadcasts against association's
+    indicesT = indices.T
+    # compare each [trap, mother] and [trap, daughter] pair in association
+    # with every [trap, cell_label] row in indices
     valid_ndassociation = (
-        association[..., np.newaxis] == indices.T[np.newaxis, ...]
+        association[..., np.newaxis] == indicesT[np.newaxis, ...]
     )
-    # broadcasting is confusing (but efficient):
-    # first, we check the dimension across trap and cell id to ensure both match
-    # 1. find only those comparisons with both trap_ids and cell labels matching - they are now marked as True
+    # True only where both the trap_id and the cell label match
     valid_cell_ids = valid_ndassociation.all(axis=2)
     if match_column is None:
-        # then, we check the merge tuples to check which have both target and source
-        # 2. keep only those comparisons that match at least one row in indices
-        # 3. keep those that have a match for both mother and daughter in association
-        valid_association = valid_cell_ids.any(axis=2).all(axis=1)
-        # finally, we check the dimension that crosses all indices to ensure the pair
-        # is present in a valid merge event
-        valid_indices = (
-            valid_ndassociation[valid_association].all(axis=2).any(axis=(0, 1))
-        )
-        myversion = valid_cell_ids.any(axis=1).any(axis=0)
-        1 / 0
+        # 1. find matches in association
+        # True where a mother or daughter matches at least one row in indices
+        va_intermediate = valid_cell_ids.any(axis=2)
+        # True where both the mother and the daughter match rows in indices
+        valid_association = va_intermediate.all(axis=1)
+        # 2. find matches in indices
+        # True where a row in indices matches the mother or daughter of an association
+        ind_intermediate = valid_cell_ids.any(axis=1)
+        # True where a row in indices matches at least one row in association
+        valid_indices = ind_intermediate.any(axis=0)
+        # OLD
+        # valid_indices = (
+        #     valid_ndassociation[valid_association].all(axis=2).any(axis=(0, 1))
+        # )
     else:
-        # we fetch specific indices if we aim for the ones with one present
+        # match_column selects mothers if 0 and daughters if 1
+        # True where the selected cell matches at least one row in indices
+        valid_association = valid_cell_ids[:, match_column].any(axis=1)
+        # True where a row in indices matches the selected cell of some association
         valid_indices = valid_cell_ids[:, match_column].any(axis=0)
-        # calid association then becomes a boolean array: True means that there is a
-        # match (match_column) between that cell and the index
-        valid_association = (
-            valid_cell_ids[:, match_column] & valid_indices
-        ).any(axis=1)
+        # OLD
+        # valid_association = (
+        #     valid_cell_ids[:, match_column] & valid_indices
+        # ).any(axis=1)
     return valid_association, valid_indices
 
 
@@ -105,8 +109,6 @@ def _assoc_indices_to_3d(ndarray: np.ndarray):
     [ [0, 1, 3], [0, 1, 4] ]
     becomes
     [ [[0, 1], [0, 3]], [[0, 1], [0, 4]] ]
-
-    This is used to convert a signal MultiIndex before comparing association.
     """
     result = ndarray
     if len(ndarray) and ndarray.ndim > 1:
-- 
GitLab