From e4e31e1df6e6543c0487fd3cd03ee420d59eabe7 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Al=C3=A1n=20Mu=C3=B1oz?= <alan.munoz@ed.ac.uk>
Date: Tue, 4 Oct 2022 20:06:15 +0100
Subject: [PATCH] refactor(merge): isolate validate_association

---
 src/agora/utils/merge.py | 59 ++++------------------------------------
 1 file changed, 5 insertions(+), 54 deletions(-)

diff --git a/src/agora/utils/merge.py b/src/agora/utils/merge.py
index 9a28fd59..4e4abdb8 100644
--- a/src/agora/utils/merge.py
+++ b/src/agora/utils/merge.py
@@ -9,6 +9,8 @@ import numpy as np
 import pandas as pd
 from utils_find_1st import cmp_larger, find_1st
 
+from agora.utils.association import validate_association
+
 
 def apply_merges(data: pd.DataFrame, merges: np.ndarray):
     """Split data in two, one subset for rows relevant for merging and one
@@ -29,7 +31,9 @@ def apply_merges(data: pd.DataFrame, merges: np.ndarray):
 
     """
 
-    valid_merges, indices = validate_merges(merges, np.array(list(data.index)))
+    valid_merges, indices = validate_association(
+        merges, np.array(list(data.index))
+    )
 
     # Assign non-merged
     merged = data.loc[~indices]
@@ -49,59 +53,6 @@ def apply_merges(data: pd.DataFrame, merges: np.ndarray):
     return merged
 
 
-def validate_merges(
-    merges: np.ndarray, indices: np.ndarray
-) -> t.Tuple[np.ndarray, np.ndarray]:
-
-    """Select rows from the first array that are present in both.
-    We use casting for fast multiindexing.
-
-
-
-
-    Parameters
-    ----------
-    merges : np.ndarray
-        2-D array where columns are (trap, mother, daughter) or 3-D array where
-        dimensions are (X, (trap,mother), (trap,daughter))
-    indices : np.ndarray
-        2-D array where each column is a different level.
-
-    Returns
-    -------
-    np.ndarray
-        1-D boolean array indicating valid merge events.
-    np.ndarray
-        1-D boolean array indicating indices involved in merging.
-
-    Examples
-    --------
-    FIXME: Add docs.
-
-    """
-    if merges.ndim < 3:
-        # Reshape into 3-D array for broadcasting if neded
-        merges = np.stack((merges[:, [0, 1]], merges[:, [0, 2]]), axis=1)
-
-    # Compare existing merges with available indices
-    # Swap trap and label axes for the merges array to correctly cast
-    # valid_ndmerges = merges.swapaxes(1, 2)[..., None] == indices.T[:, None, :]
-    valid_ndmerges = merges[..., None] == indices.T[None, ...]
-
-    # Broadcasting is confusing (but efficient):
-    # First we check the dimension across trap and cell id, to ensure both match
-    valid_cell_ids = valid_ndmerges.all(axis=2)
-
-    # Then we check the merge tuples to check which cases have both target and source
-    valid_merges = valid_cell_ids.any(axis=2).all(axis=1)
-
-    # Finalle we check the dimension that crosses all indices, to ensure the pair
-    # is present in a valid merge event.
-    valid_indices = valid_ndmerges[valid_merges].all(axis=2).any(axis=(0, 1))
-
-    return valid_merges, valid_indices
-
-
 def join_tracks_pair(target: np.ndarray, source: np.ndarray) -> np.ndarray:
     """
     Join two tracks and return the new value of the target.
-- 
GitLab