Skip to content
Snippets Groups Projects
Commit 2766f814 authored by pswain's avatar pswain
Browse files

new attempt at validate_association

parent cc84c443
No related branches found
No related tags found
No related merge requests found
......@@ -11,7 +11,33 @@ import numpy as np
import typing as t
def validate_association_new(
def validate_lineage(
association: np.ndarray,
indices: np.ndarray,
match_column: t.Optional[int] = None,
) -> t.Tuple[np.ndarray, np.ndarray]:
if association.ndim == 2:
# reshape into 3D array for broadcasting
# for each trap, [trap, mother, daughter] becomes
# [[trap, mother], [trap, daughter]]
association = _assoc_indices_to_3d(association)
dtype = {"names": ["trap_id", "cell_id"], "formats": [int, int]}
intersections = np.intersect1d(
association.view(dtype), indices.view(dtype)
)
inter_array = intersections.view(int).reshape(-1, 2)
valid_association = np.isin(association.view(dtype), intersections).all(
axis=1
)
valid_indices = np.intersect1d(
association[valid_association.flatten(), ...].view(dtype),
indices.view(dtype),
)
return valid_association, valid_indices
def validate_association(
association: np.ndarray,
indices: np.ndarray,
match_column: t.Optional[int] = None,
......@@ -82,15 +108,19 @@ def validate_association_new(
indicesT = indices.T
# compare each of [[trap, mother], [trap, daughter]] for all traps
# in association with [trap, cell_label] for all traps in indices
# association is no_traps x 2 x 2; indices is no_traps X 2
# valid_ndassociation is no_traps_association x 2 x 2 x no_traps_indices
valid_ndassociation = (
association[..., np.newaxis] == indicesT[np.newaxis, ...]
)
# find matches in association
###
# make True comparisons have both trap_ids and cell labels matching
# make True comparisons with both trap_ids and cell labels matching
# compare trap_ids and cell_ids for each pair of traps
valid_cell_ids = valid_ndassociation.all(axis=2)
if match_column is None:
# make True comparisons match at least one row in indices
# at least one cell_id matches
va_intermediate = valid_cell_ids.any(axis=2)
# make True comparisons have both mother and bud matching rows in indices
valid_association = va_intermediate.all(axis=1)
......@@ -105,7 +135,7 @@ def validate_association_new(
valid_cell_ids_va = valid_ndassociation[valid_association].all(axis=2)
if match_column is None:
# make True comparisons match either a mother or a bud in association
valid_indices = valid_cell_ids_va.any(axis=1)[0]
valid_indices = valid_cell_ids_va.any(axis=(0, 1))
else:
valid_indices = valid_cell_ids_va[:, match_column][0]
......@@ -135,8 +165,12 @@ def validate_association_new(
valid_cell_ids_a[:, match_column] & valid_indices
).any(axis=1)
assert valid_association != valid_association_a, "valid_association error"
assert valid_indices != valid_indices_a, "valid_indices error"
assert np.array_equal(
valid_association, valid_association_a
), "valid_association error"
assert np.array_equal(
valid_indices, valid_indices_a
), "valid_indices error"
return valid_association, valid_indices
......
......@@ -5,7 +5,7 @@ import pandas as pd
from agora.abc import ParametersABC
from agora.io.cells import Cells
from agora.utils.indexing import validate_association
from agora.utils.indexing_new import validate_association
from agora.utils.cast import _str_to_int
from agora.utils.kymograph import drop_mother_label
from postprocessor.core.lineageprocess import LineageProcess
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment