Skip to content
Snippets Groups Projects
Commit 08877db3 authored by Alán Muñoz's avatar Alán Muñoz
Browse files

refactor(association): Move format handling to end

parent 29b39345
No related branches found
No related tags found
No related merge requests found
......@@ -24,9 +24,10 @@ def validate_association(
----------
association : np.ndarray
2-D array where columns are (trap, mother, daughter) or 3-D array where
dimensions are (X, (trap,mother), (trap,daughter))
dimensions are (X,trap,2), containing tuples ((trap,mother), (trap,daughter))
across the 3rd dimension.
indices : np.ndarray
2-D array where each column is a different level.
2-D array where each column is a different level. This should not include mother_label.
match_column: int
int indicating a specific column is required to match (i.e.
0-1 for target-source when trying to merge tracklets or mother-bud for lineage)
......@@ -68,11 +69,12 @@ def validate_association(
[ True False False] [ True True False]
"""
if association.ndim < 3:
if association.ndim == 2:
# Reshape into 3-D array for broadcasting if neded
association = np.stack(
(association[:, [0, 1]], association[:, [0, 2]]), axis=1
)
# association = np.stack(
# (association[:, [0, 1]], association[:, [0, 2]]), axis=1
# )
association = last_col_as_rows(association)
# Compare existing association with available indices
# Swap trap and label axes for the association array to correctly cast
......@@ -100,3 +102,20 @@ def validate_association(
).any(axis=1)
return valid_association, valid_indices
def last_col_as_rows(ndarray: np.ndarray):
"""
Convert the last column to a new row while repeating all previous indices.
This is useful when converting a signal multiindex before comparing association.
"""
columns = np.arange(ndarray.shape[1])
return np.stack(
(
ndarray[:, np.delete(columns, -1)],
ndarray[:, np.delete(columns, -2)],
),
axis=1,
)
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment