diff --git a/src/agora/utils/indexing.py b/src/agora/utils/indexing.py
index 9a07a6d52faa9e5c5a577b1e2776456a39174e7a..e89d2b7b9c83e4c3e3bea504876a65dbeb22f54a 100644
--- a/src/agora/utils/indexing.py
+++ b/src/agora/utils/indexing.py
@@ -9,6 +9,9 @@ This can be:
 import numpy as np
 import typing as t
 
+# data type to link together trap and cell ids
+i_dtype = {"names": ["trap_id", "cell_id"], "formats": [np.int64, np.int64]}
+
 
 def validate_lineage(
     lineage: np.ndarray, indices: np.ndarray, how: str = "families"
@@ -75,13 +78,8 @@ def validate_lineage(
         c_index = 0
     elif how == "daughters":
         c_index = 1
-    # data type to link together trap and cell ids
-    dtype = {"names": ["trap_id", "cell_id"], "formats": [np.int64, np.int64]}
-    lineage = np.ascontiguousarray(lineage, dtype=np.int64)
-    # find (trap, cell_ids) in intersection
-    inboth = np.intersect1d(lineage.view(dtype), indices.view(dtype))
     # find valid lineage
-    valid_lineages = np.isin(lineage.view(dtype), inboth)
+    valid_lineages = index_isin(lineage, indices)
     if how == "families":
         # both mother and bud must be in indices
         valid_lineage = valid_lineages.all(axis=1)
@@ -92,11 +90,12 @@ def validate_lineage(
     if how == "families":
         # select only pairs of mother and bud indices
         valid_indices = np.isin(
-            indices.view(dtype), selected_lineages.view(dtype)
+            indices.view(i_dtype), selected_lineages.view(i_dtype)
         )
     else:
         valid_indices = np.isin(
-            indices.view(dtype), selected_lineages.view(dtype)[:, c_index, :]
+            indices.view(i_dtype),
+            selected_lineages.view(i_dtype)[:, c_index, :],
         )
     if valid_indices[valid_indices].size != valid_lineage[valid_lineage].size:
         raise Exception(
@@ -244,3 +243,18 @@ def compare_indices(x: np.ndarray, y: np.ndarray) -> np.ndarray:
         where a True value links two cells where all cells are the same
     """
     return (x[..., None] == y.T[None, ...]).all(axis=1)
+
+
+def index_isin(x: np.ndarray, y: np.ndarray) -> np.ndarray:
+    """
+    Return a boolean array indicating which indices of x are in y.
+
+    Both arrays must be arrays of integer indices,
+    such as (trap_id, cell_id).
+    """
+    x = np.ascontiguousarray(x, dtype=np.int64)
+    y = np.ascontiguousarray(y, dtype=np.int64)
+    xv = x.view(i_dtype)
+    inboth = np.intersect1d(xv, y.view(i_dtype))
+    x_bool = np.isin(xv, inboth)
+    return x_bool
diff --git a/src/agora/utils/merge.py b/src/agora/utils/merge.py
index b588871be583e8ce8478c581e9b24b1456677f33..1e48feb8608d413371d0f03d7e05c7f69340cde7 100644
--- a/src/agora/utils/merge.py
+++ b/src/agora/utils/merge.py
@@ -9,7 +9,11 @@ import numpy as np
 import pandas as pd
 from utils_find_1st import cmp_larger, find_1st
 
-from agora.utils.indexing import compare_indices, validate_association
+from agora.utils.indexing import (
+    index_isin,
+    compare_indices,
+    validate_association,
+)
 
 
 def apply_merges(data: pd.DataFrame, merges: np.ndarray):
@@ -73,23 +77,29 @@ def join_tracks_pair(target: np.ndarray, source: np.ndarray) -> np.ndarray:
 
 
 def group_merges(merges: np.ndarray) -> t.List[t.Tuple]:
-    # Return a list where the cell is present as source and target
-    # (multimerges)
-
-    sources_targets = compare_indices(merges[:, 0, :], merges[:, 1, :])
-    is_multimerge = sources_targets.any(axis=0) | sources_targets.any(axis=1)
-    is_monomerge = ~is_multimerge
-
-    multimerge_subsets = union_find(zip(*np.where(sources_targets)))
-    merge_groups = [merges[np.array(tuple(x))] for x in multimerge_subsets]
-
-    sorted_merges = list(map(sort_association, merge_groups))
-
-    # Ensure that source and target are at the edges
-    return [
-        *sorted_merges,
-        *[[event] for event in merges[is_monomerge]],
+    """
+    Convert merges into a list: traps requiring multiple merges come
+    first, followed by traps requiring a single merge.
+    """
+    left_track = merges[:, 0]
+    right_track = merges[:, 1]
+    # find traps requiring multiple merges
+    linr = merges[index_isin(left_track, right_track).flatten(), :]
+    rinl = merges[index_isin(right_track, left_track).flatten(), :]
+    # make unique and order merges for each trap
+    multi_merge = np.unique(np.concatenate((linr, rinl)), axis=0)
+    # find traps requiring a single merge
+    single_merge = merges[
+        ~index_isin(merges, multi_merge).all(axis=1).flatten(), :
     ]
+    # convert to lists of arrays
+    single_merge_list = [[sm] for sm in single_merge]
+    multi_merge_list = [
+        multi_merge[multi_merge[:, 0, 0] == trap_id, ...]
+        for trap_id in np.unique(multi_merge[:, 0, 0])
+    ]
+    res = [*multi_merge_list, *single_merge_list]
+    return res
 
 
 def union_find(lsts):
@@ -125,25 +135,34 @@ def sort_association(array: np.ndarray):
     return sorted_array
 
 
-def merge_association(
-    association: np.ndarray, merges: np.ndarray
-) -> np.ndarray:
+def merge_association(lineage: np.ndarray, merges: np.ndarray) -> np.ndarray:
+    """Use merges to update lineage information."""
+    flat_lineage = lineage.reshape(-1, 2)
+    left_track = merges[:, 0]
+    valid_lineages = index_isin(flat_lineage, left_track).flatten()
+    # group into multi- and single merges
     grouped_merges = group_merges(merges)
-
-    flat_indices = association.reshape(-1, 2)
-    comparison_mat = compare_indices(merges[:, 0], flat_indices)
-
-    valid_indices = comparison_mat.any(axis=0)
-
-    if valid_indices.any():  # Where valid, perform transformation
-        replacement_d = {}
-        for dataset in grouped_merges:
-            for k in dataset:
-                replacement_d[tuple(k[0])] = dataset[-1][1]
-
-        flat_indices[valid_indices] = [
-            replacement_d[tuple(i)] for i in flat_indices[valid_indices]
+    # perform merges
+    if valid_lineages.any():
+        # indices of each left track -> indices of rightmost track
+        replacement_dict = {
+            tuple(contig_pair[0]): merge[-1][1]
+            for merge in grouped_merges
+            for contig_pair in merge
+        }
+        # correct lineage information
+        # replace mother or bud index with index of rightmost track
+        flat_lineage[valid_lineages] = [
+            replacement_dict[tuple(i)] for i in flat_lineage[valid_lineages]
         ]
-
-    merged_indices = flat_indices.reshape(-1, 2, 2)
-    return merged_indices
+    # reverse flattening
+    new_lineage = flat_lineage.reshape(-1, 2, 2)
+    # remove any duplicates
+    new_lineage = np.unique(new_lineage, axis=0)
+    # sanity check: buds should have only one mother
+    buds = new_lineage[:, 1]
+    ubuds, counts = np.unique(buds, axis=0, return_counts=True)
+    duplicate_buds = ubuds[counts > 1, :]
+    if duplicate_buds.size:
+        raise Exception("Merging gave some buds more than one mother.")
+    return new_lineage
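
Note on the new index_isin helper: it compares whole (trap_id, cell_id) pairs at once by viewing each (N, 2) int64 array as a one-dimensional structured array, so np.intersect1d and np.isin match rows rather than individual integers. A minimal sketch of that idea, using invented ids rather than pipeline data:

import numpy as np

i_dtype = {"names": ["trap_id", "cell_id"], "formats": [np.int64, np.int64]}

# invented (trap_id, cell_id) pairs, purely for illustration
x = np.array([[0, 1], [0, 3], [2, 5]], dtype=np.int64)
y = np.array([[0, 3], [2, 5], [7, 7]], dtype=np.int64)

# viewing each (N, 2) array as an (N, 1) structured array makes
# intersect1d and isin compare whole rows rather than single integers
xv = np.ascontiguousarray(x).view(i_dtype)
yv = np.ascontiguousarray(y).view(i_dtype)
inboth = np.intersect1d(xv, yv)
print(np.isin(xv, inboth).flatten())  # [False  True  True]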
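
Note on the rewritten merging: group_merges collects, per trap, the merges whose tracks chain together, and merge_association then maps every left (earlier) track to the rightmost track of its group before rewriting the lineage array. A rough sketch of that replacement step, with invented track ids and a group assumed to be already ordered (as np.unique leaves it above):

import numpy as np

# invented example: in trap 1, track 2 merges into track 3,
# which later merges into track 5
group = np.array([[[1, 2], [1, 3]],
                  [[1, 3], [1, 5]]])
# every left track in the ordered group maps to the rightmost track, (1, 5)
replacement_dict = {tuple(pair[0]): group[-1][1] for pair in group}

# a lineage pair whose mother is a merged-away track gets rewritten
lineage = np.array([[[1, 2], [1, 7]]])
flat = lineage.reshape(-1, 2)
flat[0] = replacement_dict[tuple(flat[0])]
print(lineage)  # mother (1, 2) has become (1, 5); bud (1, 7) is unchanged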