Skip to content
Snippets Groups Projects
Commit 914eae97 authored by pswain's avatar pswain
Browse files

streamlined indexing simplifying broadcasting

parent ee850164
No related branches found
No related tags found
No related merge requests found
...@@ -17,9 +17,9 @@ def validate_association( ...@@ -17,9 +17,9 @@ def validate_association(
match_column: t.Optional[int] = None, match_column: t.Optional[int] = None,
) -> t.Tuple[np.ndarray, np.ndarray]: ) -> t.Tuple[np.ndarray, np.ndarray]:
""" """
Identify matches between two arrays by matching rows. Identify matches between two arrays by comparing rows.
We use broadcasting for fast multi-indexing, generalising for lineage dynamics. We use broadcasting for speed.
Parameters Parameters
---------- ----------
...@@ -30,68 +30,72 @@ def validate_association( ...@@ -30,68 +30,72 @@ def validate_association(
indices : np.ndarray indices : np.ndarray
a 2D array where each column is a different level, such as (trap_id, cell_label). This should not include mother_label. a 2D array where each column is a different level, such as (trap_id, cell_label). This should not include mother_label.
match_column: int match_column: int
int indicating a specific column that is required to match (i.e. 0-1 for target-source when trying to merge tracklets or mother-bud for lineage) If 0, match mothers; if 1 match daughters.
must be present in indices. If None, match both mothers and daughters.
If None, one match suffices for the resultant indices
vector to be True.
Returns Returns
------- -------
valid_association: boolean np.ndarray valid_association: boolean np.ndarray
1D array indicating valid elements in association. 1D array indicating elements in association with matches.
valid_indices: boolean np.ndarray valid_indices: boolean np.ndarray
1D array indicating valid elements in indices. 1D array indicating elements in indices with matches.
Examples Examples
-------- --------
>>> import numpy as np >>> import numpy as np
>>> from agora.utils.indexing import validate_association >>> from agora.utils.indexing import validate_association
>>> association = np.array([ [[0, 1], [0, 3]], [[0, 1], [0, 4]] ]) >>> association = np.array([ [[0, 1], [0, 3]], [[0, 1], [0, 4]], [[0, 1], [0, 6]], [[0, 4], [0, 7]] ])
>>> indices = np.array([ [0, 1], [0, 2], [0, 3]]) >>> indices = np.array([ [0, 1], [0, 2], [0, 3]])
>>> print(indices.T) >>> print(indices.T)
>>> valid_associations, valid_indices = validate_association(merges, indices) >>> valid_associations, valid_indices = validate_association(association, indices)
>>> print(valid_associations) >>> print(valid_associations)
array([ True, False]) array([ True, False, False, False])
>>> print(valid_indices) >>> print(valid_indices)
array([ True, False, True]) array([ True, False, True])
""" """
if association.ndim == 2: if association.ndim == 2:
# reshape into 3D array for broadcasting # reshape into 3D array for broadcasting
# [trap, mother, daughter] becomes [[trap, mother], [trap, daughter]] for each trap # for each trap, [trap, mother, daughter] becomes
# [[trap, mother], [trap, daughter]]
association = _assoc_indices_to_3d(association) association = _assoc_indices_to_3d(association)
# compare existing association with available indices # use broadcasting to compare association with indices
# swap trap and cell_label axes for the indices array to correctly broadcast # swap trap and cell_label axes for correct broadcasting
# compare [[trap, mother], [trap, daughter]] with [trap, cell_label] for all traps in association and for all [trap, cell_label] pairs in indices indicesT = indices.T
# compare each of [[trap, mother], [trap, daughter]] for all traps
# in association with [trap, cell_label] for all traps in indices
valid_ndassociation = ( valid_ndassociation = (
association[..., np.newaxis] == indices.T[np.newaxis, ...] association[..., np.newaxis] == indicesT[np.newaxis, ...]
) )
# broadcasting is confusing (but efficient): # make True comparisons have both trap_ids and cell labels matching
# first, we check the dimension across trap and cell id to ensure both match
# 1. find only those comparisons with both trap_ids and cell labels matching - they are now marked as True
valid_cell_ids = valid_ndassociation.all(axis=2) valid_cell_ids = valid_ndassociation.all(axis=2)
if match_column is None: if match_column is None:
# then, we check the merge tuples to check which have both target and source # 1. find matches in association
# 2. keep only those comparisons that match at least one row in indices # make True comparisons match at least one row in indices
# 3. keep those that have a match for both mother and daughter in association va_intermediate = valid_cell_ids.any(axis=2)
valid_association = valid_cell_ids.any(axis=2).all(axis=1) # make True comparisons have both mother and daughter matching rows in indices
# finally, we check the dimension that crosses all indices to ensure the pair valid_association = va_intermediate.all(axis=1)
# is present in a valid merge event # 2. find matches in indices
valid_indices = ( # make True comparisons match for at least one mother or daughter in association
valid_ndassociation[valid_association].all(axis=2).any(axis=(0, 1)) ind_intermediate = valid_cell_ids.any(axis=1)
) # make True comparisons match for at least one row in association
myversion = valid_cell_ids.any(axis=1).any(axis=0) valid_indices = ind_intermediate.any(axis=0)
1 / 0 # OLD
# valid_indices = (
# valid_ndassociation[valid_association].all(axis=2).any(axis=(0, 1))
# )
else: else:
# we fetch specific indices if we aim for the ones with one present # match_column selects mothers if 0 and daughters if 1
# make True match at least one row in indices
valid_association = valid_cell_ids[:, match_column].any(axis=1)
# make True match at least one row in association
valid_indices = valid_cell_ids[:, match_column].any(axis=0) valid_indices = valid_cell_ids[:, match_column].any(axis=0)
# calid association then becomes a boolean array: True means that there is a # OLD
# match (match_column) between that cell and the index # valid_association = (
valid_association = ( # valid_cell_ids[:, match_column] & valid_indices
valid_cell_ids[:, match_column] & valid_indices # ).any(axis=1)
).any(axis=1)
return valid_association, valid_indices return valid_association, valid_indices
...@@ -105,8 +109,6 @@ def _assoc_indices_to_3d(ndarray: np.ndarray): ...@@ -105,8 +109,6 @@ def _assoc_indices_to_3d(ndarray: np.ndarray):
[ [0, 1, 3], [0, 1, 4] ] [ [0, 1, 3], [0, 1, 4] ]
becomes becomes
[ [[0, 1], [0, 3]], [[0, 1], [0, 4]] ] [ [[0, 1], [0, 3]], [[0, 1], [0, 4]] ]
This is used to convert a signal MultiIndex before comparing association.
""" """
result = ndarray result = ndarray
if len(ndarray) and ndarray.ndim > 1: if len(ndarray) and ndarray.ndim > 1:
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment