diff --git a/src/agora/utils/indexing_new.py b/src/agora/utils/indexing_new.py index 7c749088c3276d95ccb9c54b8d6b7b99ea6b4fc9..f2ddf0b89d1c778b24b153ffddc4c96c715b1ca0 100644 --- a/src/agora/utils/indexing_new.py +++ b/src/agora/utils/indexing_new.py @@ -11,7 +11,33 @@ import numpy as np import typing as t -def validate_association_new( +def validate_lineage( + association: np.ndarray, + indices: np.ndarray, + match_column: t.Optional[int] = None, +) -> t.Tuple[np.ndarray, np.ndarray]: + if association.ndim == 2: + # reshape into 3D array for broadcasting + # for each trap, [trap, mother, daughter] becomes + # [[trap, mother], [trap, daughter]] + association = _assoc_indices_to_3d(association) + + dtype = {"names": ["trap_id", "cell_id"], "formats": [int, int]} + intersections = np.intersect1d( + association.view(dtype), indices.view(dtype) + ) + inter_array = intersections.view(int).reshape(-1, 2) + valid_association = np.isin(association.view(dtype), intersections).all( + axis=1 + ) + valid_indices = np.intersect1d( + association[valid_association.flatten(), ...].view(dtype), + indices.view(dtype), + ) + return valid_association, valid_indices + + +def validate_association( association: np.ndarray, indices: np.ndarray, match_column: t.Optional[int] = None, @@ -82,15 +108,19 @@ def validate_association_new( indicesT = indices.T # compare each of [[trap, mother], [trap, daughter]] for all traps # in association with [trap, cell_label] for all traps in indices + # association is no_traps x 2 x 2; indices is no_traps X 2 + # valid_ndassociation is no_traps_association x 2 x 2 x no_traps_indices valid_ndassociation = ( association[..., np.newaxis] == indicesT[np.newaxis, ...] ) # find matches in association ### - # make True comparisons have both trap_ids and cell labels matching + # make True comparisons with both trap_ids and cell labels matching + # compare trap_ids and cell_ids for each pair of traps valid_cell_ids = valid_ndassociation.all(axis=2) if match_column is None: # make True comparisons match at least one row in indices + # at least one cell_id matches va_intermediate = valid_cell_ids.any(axis=2) # make True comparisons have both mother and bud matching rows in indices valid_association = va_intermediate.all(axis=1) @@ -105,7 +135,7 @@ def validate_association_new( valid_cell_ids_va = valid_ndassociation[valid_association].all(axis=2) if match_column is None: # make True comparisons match either a mother or a bud in association - valid_indices = valid_cell_ids_va.any(axis=1)[0] + valid_indices = valid_cell_ids_va.any(axis=(0, 1)) else: valid_indices = valid_cell_ids_va[:, match_column][0] @@ -135,8 +165,12 @@ def validate_association_new( valid_cell_ids_a[:, match_column] & valid_indices ).any(axis=1) - assert valid_association != valid_association_a, "valid_association error" - assert valid_indices != valid_indices_a, "valid_indices error" + assert np.array_equal( + valid_association, valid_association_a + ), "valid_association error" + assert np.array_equal( + valid_indices, valid_indices_a + ), "valid_indices error" return valid_association, valid_indices diff --git a/src/postprocessor/core/reshapers/picker.py b/src/postprocessor/core/reshapers/picker.py index 74f3a4b0f6ab64b4f68e3258e02d35fe92b217f2..2b5414b0529f8298529f2a4851fc2d9679cee9e2 100644 --- a/src/postprocessor/core/reshapers/picker.py +++ b/src/postprocessor/core/reshapers/picker.py @@ -5,7 +5,7 @@ import pandas as pd from agora.abc import ParametersABC from agora.io.cells import Cells -from agora.utils.indexing import validate_association +from agora.utils.indexing_new import validate_association from agora.utils.cast import _str_to_int from agora.utils.kymograph import drop_mother_label from postprocessor.core.lineageprocess import LineageProcess