diff --git a/src/agora/utils/indexing.py b/src/agora/utils/indexing.py index e89d2b7b9c83e4c3e3bea504876a65dbeb22f54a..3147d5df8d869009a62e31b1970cae42d063d044 100644 --- a/src/agora/utils/indexing.py +++ b/src/agora/utils/indexing.py @@ -85,8 +85,11 @@ def validate_lineage( valid_lineage = valid_lineages.all(axis=1) else: valid_lineage = valid_lineages[:, c_index, :] + flat_valid_lineage = valid_lineage.flatten() # find valid indices - selected_lineages = lineage[valid_lineage.flatten(), ...] + selected_lineages = np.ascontiguousarray( + lineage[flat_valid_lineage, ...], dtype=np.int64 + ) if how == "families": # select only pairs of mother and bud indices valid_indices = np.isin( @@ -97,12 +100,19 @@ def validate_lineage( indices.view(i_dtype), selected_lineages.view(i_dtype)[:, c_index, :], ) - if valid_indices[valid_indices].size != valid_lineage[valid_lineage].size: + flat_valid_indices = valid_indices.flatten() + if ( + indices[flat_valid_indices, :].size + != np.unique( + lineage[flat_valid_lineage, :].reshape(-1, 2), axis=0 + ).size + ): + # all unique indices in valid_lineages should be in valid_indices raise Exception( "Error in validate_lineage: " "lineage information is likely not unique." ) - return valid_lineage.flatten(), valid_indices.flatten() + return flat_valid_lineage, flat_valid_indices def validate_association( diff --git a/src/agora/utils/merge.py b/src/agora/utils/merge.py index 1e48feb8608d413371d0f03d7e05c7f69340cde7..70769d9233867c37fbba90a0fd726ed1446edc3a 100644 --- a/src/agora/utils/merge.py +++ b/src/agora/utils/merge.py @@ -81,11 +81,11 @@ def group_merges(merges: np.ndarray) -> t.List[t.Tuple]: Convert merges into a list of merges for traps requiring multiple merges and then for traps requiring single merges. 
""" - left_track = merges[:, 0] - right_track = merges[:, 1] + left_tracks = merges[:, 0] + right_tracks = merges[:, 1] # find traps requiring multiple merges - linr = merges[index_isin(left_track, right_track).flatten(), :] - rinl = merges[index_isin(right_track, left_track).flatten(), :] + linr = merges[index_isin(left_tracks, right_tracks).flatten(), :] + rinl = merges[index_isin(right_tracks, left_tracks).flatten(), :] # make unique and order merges for each trap multi_merge = np.unique(np.concatenate((linr, rinl)), axis=0) # find traps requiring a singe merge @@ -99,18 +99,6 @@ def group_merges(merges: np.ndarray) -> t.List[t.Tuple]: for trap_id in np.unique(multi_merge[:, 0, 0]) ] res = [*multi_merge_list, *single_merge_list] - # # - # sources_targets = compare_indices(merges[:, 0, :], merges[:, 1, :]) - # is_multimerge = sources_targets.any(axis=0) | sources_targets.any(axis=1) - # is_monomerge = ~is_multimerge - # multimerge_subsets = union_find(zip(*np.where(sources_targets))) - # merge_groups = [merges[np.array(tuple(x))] for x in multimerge_subsets] - # sorted_merges = list(map(sort_association, merge_groups)) - # res = [ - # *sorted_merges, - # *[[event] for event in merges[is_monomerge]], - # ] - # # return res @@ -150,36 +138,45 @@ def sort_association(array: np.ndarray): def merge_association(lineage: np.ndarray, merges: np.ndarray) -> np.ndarray: """Use merges to update lineage information.""" flat_lineage = lineage.reshape(-1, 2) - left_track = merges[:, 0] - # comparison_mat = compare_indices(left_track, flat_lineage) - # valid_indices = comparison_mat.any(axis=0) - valid_lineages = index_isin(flat_lineage, left_track).flatten() - # group into multi- and single merges + bud_mother_dict = { + tuple(bud): mother for bud, mother in zip(lineage[:, 1], lineage[:, 0]) + } + left_tracks = merges[:, 0] + # find left tracks that are in lineages + valid_lineages = index_isin(flat_lineage, left_tracks).flatten() + # group into multi- and then single merges 
 grouped_merges = group_merges(merges) # perform merges if valid_lineages.any(): - # indices of each left track -> indices of rightmost track + # indices of each left track -> indices of rightmost right track replacement_dict = { tuple(contig_pair[0]): merge[-1][1] for merge in grouped_merges for contig_pair in merge } + # if both key and value are buds, they must have the same mother + buds = lineage[:, 1] + incorrect_merges = [ + key + for key in replacement_dict + if np.any(index_isin(buds, replacement_dict[key]).flatten()) + and np.any(index_isin(buds, key).flatten()) + and not np.array_equal( + bud_mother_dict[key], + bud_mother_dict[tuple(replacement_dict[key])], + ) + ] + # reassign incorrect merges so that they have no effect + for key in incorrect_merges: + replacement_dict[key] = key # correct lineage information # replace mother or bud index with index of rightmost track flat_lineage[valid_lineages] = [ - replacement_dict[tuple(i)] for i in flat_lineage[valid_lineages] + replacement_dict[tuple(index)] + for index in flat_lineage[valid_lineages] ] # reverse flattening new_lineage = flat_lineage.reshape(-1, 2, 2) # remove any duplicates new_lineage = np.unique(new_lineage, axis=0) - # buds should have only one mother - buds = new_lineage[:, 1] - ubuds, counts = np.unique(buds, axis=0, return_counts=True) - duplicate_buds = ubuds[counts > 1, :] - # duplicates - new_lineage[index_isin(buds, duplicate_buds).flatten(), ...] - # original - lineage[index_isin(lineage[:, 1], duplicate_buds).flatten(), ...] - breakpoint() return new_lineage