From 914eae9713e626ca6f774a1bef9f058c11dd20a5 Mon Sep 17 00:00:00 2001 From: Peter Swain <peter.swain@ed.ac.uk> Date: Fri, 28 Apr 2023 12:30:30 +0100 Subject: [PATCH] streamlined indexing simplifying broadcasting --- src/agora/utils/indexing.py | 78 +++++++++++++++++++------------------ 1 file changed, 40 insertions(+), 38 deletions(-) diff --git a/src/agora/utils/indexing.py b/src/agora/utils/indexing.py index 24fa004e..0e50be02 100644 --- a/src/agora/utils/indexing.py +++ b/src/agora/utils/indexing.py @@ -17,9 +17,9 @@ def validate_association( match_column: t.Optional[int] = None, ) -> t.Tuple[np.ndarray, np.ndarray]: """ - Identify matches between two arrays by matching rows. + Identify matches between two arrays by comparing rows. - We use broadcasting for fast multi-indexing, generalising for lineage dynamics. + We use broadcasting for speed. Parameters ---------- @@ -30,68 +30,72 @@ def validate_association( indices : np.ndarray a 2D array where each column is a different level, such as (trap_id, cell_label). This should not include mother_label. match_column: int - int indicating a specific column that is required to match (i.e. 0-1 for target-source when trying to merge tracklets or mother-bud for lineage) - must be present in indices. - If None, one match suffices for the resultant indices - vector to be True. + If 0, match mothers; if 1 match daughters. + If None, match both mothers and daughters. Returns ------- valid_association: boolean np.ndarray - 1D array indicating valid elements in association. + 1D array indicating elements in association with matches. valid_indices: boolean np.ndarray - 1D array indicating valid elements in indices. + 1D array indicating elements in indices with matches. Examples -------- >>> import numpy as np >>> from agora.utils.indexing import validate_association - >>> association = np.array([ [[0, 1], [0, 3]], [[0, 1], [0, 4]] ]) + >>> association = np.array([ [[0, 1], [0, 3]], [[0, 1], [0, 4]], [[0, 1], [0, 6]], [[0, 4], [0, 7]] ]) >>> indices = np.array([ [0, 1], [0, 2], [0, 3]]) >>> print(indices.T) - >>> valid_associations, valid_indices = validate_association(merges, indices) + >>> valid_associations, valid_indices = validate_association(association, indices) >>> print(valid_associations) - array([ True, False]) + array([ True, False, False, False]) >>> print(valid_indices) array([ True, False, True]) """ if association.ndim == 2: # reshape into 3D array for broadcasting - # [trap, mother, daughter] becomes [[trap, mother], [trap, daughter]] for each trap + # for each trap, [trap, mother, daughter] becomes + # [[trap, mother], [trap, daughter]] association = _assoc_indices_to_3d(association) - # compare existing association with available indices - # swap trap and cell_label axes for the indices array to correctly broadcast - # compare [[trap, mother], [trap, daughter]] with [trap, cell_label] for all traps in association and for all [trap, cell_label] pairs in indices + # use broadcasting to compare association with indices + # swap trap and cell_label axes for correct broadcasting + indicesT = indices.T + # compare each of [[trap, mother], [trap, daughter]] for all traps + # in association with [trap, cell_label] for all traps in indices valid_ndassociation = ( - association[..., np.newaxis] == indices.T[np.newaxis, ...] + association[..., np.newaxis] == indicesT[np.newaxis, ...] ) - # broadcasting is confusing (but efficient): - # first, we check the dimension across trap and cell id to ensure both match - # 1. find only those comparisons with both trap_ids and cell labels matching - they are now marked as True + # make True comparisons have both trap_ids and cell labels matching valid_cell_ids = valid_ndassociation.all(axis=2) if match_column is None: - # then, we check the merge tuples to check which have both target and source - # 2. keep only those comparisons that match at least one row in indices - # 3. keep those that have a match for both mother and daughter in association - valid_association = valid_cell_ids.any(axis=2).all(axis=1) - # finally, we check the dimension that crosses all indices to ensure the pair - # is present in a valid merge event - valid_indices = ( - valid_ndassociation[valid_association].all(axis=2).any(axis=(0, 1)) - ) - myversion = valid_cell_ids.any(axis=1).any(axis=0) - 1 / 0 + # 1. find matches in association + # make True comparisons match at least one row in indices + va_intermediate = valid_cell_ids.any(axis=2) + # make True comparisons have both mother and daughter matching rows in indices + valid_association = va_intermediate.all(axis=1) + # 2. find matches in indices + # make True comparisons match for at least one mother or daughter in association + ind_intermediate = valid_cell_ids.any(axis=1) + # make True comparisons match for at least one row in association + valid_indices = ind_intermediate.any(axis=0) + # OLD + # valid_indices = ( + # valid_ndassociation[valid_association].all(axis=2).any(axis=(0, 1)) + # ) else: - # we fetch specific indices if we aim for the ones with one present + # match_column selects mothers if 0 and daughters if 1 + # make True match at least one row in indices + valid_association = valid_cell_ids[:, match_column].any(axis=1) + # make True match at least one row in association valid_indices = valid_cell_ids[:, match_column].any(axis=0) - # calid association then becomes a boolean array: True means that there is a - # match (match_column) between that cell and the index - valid_association = ( - valid_cell_ids[:, match_column] & valid_indices - ).any(axis=1) + # OLD + # valid_association = ( + # valid_cell_ids[:, match_column] & valid_indices + # ).any(axis=1) return valid_association, valid_indices @@ -105,8 +109,6 @@ def _assoc_indices_to_3d(ndarray: np.ndarray): [ [0, 1, 3], [0, 1, 4] ] becomes [ [[0, 1], [0, 3]], [[0, 1], [0, 4]] ] - - This is used to convert a signal MultiIndex before comparing association. """ result = ndarray if len(ndarray) and ndarray.ndim > 1: -- GitLab