diff --git a/src/agora/utils/association.py b/src/agora/utils/association.py index 656aa787fbf4029f964bfc4ba971b80b2f1177e4..d523e427c5fed741204061295317371d533b70df 100644 --- a/src/agora/utils/association.py +++ b/src/agora/utils/association.py @@ -24,9 +24,10 @@ def validate_association( ---------- association : np.ndarray 2-D array where columns are (trap, mother, daughter) or 3-D array where - dimensions are (X, (trap,mother), (trap,daughter)) + dimensions are (X,trap,2), containing tuples ((trap,mother), (trap,daughter)) + across the 3rd dimension. indices : np.ndarray - 2-D array where each column is a different level. + 2-D array where each column is a different level. This should not include mother_label. match_column: int int indicating a specific column is required to match (i.e. 0-1 for target-source when trying to merge tracklets or mother-bud for lineage) @@ -68,11 +69,12 @@ def validate_association( [ True False False] [ True True False] """ - if association.ndim < 3: + if association.ndim == 2: # Reshape into 3-D array for broadcasting if neded - association = np.stack( - (association[:, [0, 1]], association[:, [0, 2]]), axis=1 - ) + # association = np.stack( + # (association[:, [0, 1]], association[:, [0, 2]]), axis=1 + # ) + association = last_col_as_rows(association) # Compare existing association with available indices # Swap trap and label axes for the association array to correctly cast @@ -100,3 +102,20 @@ def validate_association( ).any(axis=1) return valid_association, valid_indices + + +def last_col_as_rows(ndarray: np.ndarray): + """ + Convert the last column to a new row while repeating all previous indices. + + This is useful when converting a signal multiindex before comparing association. + """ + columns = np.arange(ndarray.shape[1]) + + return np.stack( + ( + ndarray[:, np.delete(columns, -1)], + ndarray[:, np.delete(columns, -2)], + ), + axis=1, + )