From e4e31e1df6e6543c0487fd3cd03ee420d59eabe7 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Al=C3=A1n=20Mu=C3=B1oz?= <alan.munoz@ed.ac.uk> Date: Tue, 4 Oct 2022 20:06:15 +0100 Subject: [PATCH] refactor(merge): isolate validate_association --- src/agora/utils/merge.py | 59 ++++------------------------------------ 1 file changed, 5 insertions(+), 54 deletions(-) diff --git a/src/agora/utils/merge.py b/src/agora/utils/merge.py index 9a28fd59..4e4abdb8 100644 --- a/src/agora/utils/merge.py +++ b/src/agora/utils/merge.py @@ -9,6 +9,8 @@ import numpy as np import pandas as pd from utils_find_1st import cmp_larger, find_1st +from agora.utils.association import validate_association + def apply_merges(data: pd.DataFrame, merges: np.ndarray): """Split data in two, one subset for rows relevant for merging and one @@ -29,7 +31,9 @@ def apply_merges(data: pd.DataFrame, merges: np.ndarray): """ - valid_merges, indices = validate_merges(merges, np.array(list(data.index))) + valid_merges, indices = validate_association( + merges, np.array(list(data.index)) + ) # Assign non-merged merged = data.loc[~indices] @@ -49,59 +53,6 @@ def apply_merges(data: pd.DataFrame, merges: np.ndarray): return merged -def validate_merges( - merges: np.ndarray, indices: np.ndarray -) -> t.Tuple[np.ndarray, np.ndarray]: - - """Select rows from the first array that are present in both. - We use casting for fast multiindexing. - - - - - Parameters - ---------- - merges : np.ndarray - 2-D array where columns are (trap, mother, daughter) or 3-D array where - dimensions are (X, (trap,mother), (trap,daughter)) - indices : np.ndarray - 2-D array where each column is a different level. - - Returns - ------- - np.ndarray - 1-D boolean array indicating valid merge events. - np.ndarray - 1-D boolean array indicating indices involved in merging. - - Examples - -------- - FIXME: Add docs. - - """ - if merges.ndim < 3: - # Reshape into 3-D array for broadcasting if neded - merges = np.stack((merges[:, [0, 1]], merges[:, [0, 2]]), axis=1) - - # Compare existing merges with available indices - # Swap trap and label axes for the merges array to correctly cast - # valid_ndmerges = merges.swapaxes(1, 2)[..., None] == indices.T[:, None, :] - valid_ndmerges = merges[..., None] == indices.T[None, ...] - - # Broadcasting is confusing (but efficient): - # First we check the dimension across trap and cell id, to ensure both match - valid_cell_ids = valid_ndmerges.all(axis=2) - - # Then we check the merge tuples to check which cases have both target and source - valid_merges = valid_cell_ids.any(axis=2).all(axis=1) - - # Finalle we check the dimension that crosses all indices, to ensure the pair - # is present in a valid merge event. - valid_indices = valid_ndmerges[valid_merges].all(axis=2).any(axis=(0, 1)) - - return valid_merges, valid_indices - - def join_tracks_pair(target: np.ndarray, source: np.ndarray) -> np.ndarray: """ Join two tracks and return the new value of the target. -- GitLab