Skip to content
Snippets Groups Projects
Commit a43b512d authored by pswain's avatar pswain
Browse files

fixed merging causing one bud to have multiple mothers

parent ad973bf1
No related branches found
No related tags found
No related merge requests found
...@@ -85,8 +85,11 @@ def validate_lineage( ...@@ -85,8 +85,11 @@ def validate_lineage(
valid_lineage = valid_lineages.all(axis=1) valid_lineage = valid_lineages.all(axis=1)
else: else:
valid_lineage = valid_lineages[:, c_index, :] valid_lineage = valid_lineages[:, c_index, :]
flat_valid_lineage = valid_lineage.flatten()
# find valid indices # find valid indices
selected_lineages = lineage[valid_lineage.flatten(), ...] selected_lineages = np.ascontiguousarray(
lineage[flat_valid_lineage, ...], dtype=np.int64
)
if how == "families": if how == "families":
# select only pairs of mother and bud indices # select only pairs of mother and bud indices
valid_indices = np.isin( valid_indices = np.isin(
...@@ -97,12 +100,19 @@ def validate_lineage( ...@@ -97,12 +100,19 @@ def validate_lineage(
indices.view(i_dtype), indices.view(i_dtype),
selected_lineages.view(i_dtype)[:, c_index, :], selected_lineages.view(i_dtype)[:, c_index, :],
) )
if valid_indices[valid_indices].size != valid_lineage[valid_lineage].size: flat_valid_indices = valid_indices.flatten()
if (
indices[flat_valid_indices, :].size
!= np.unique(
lineage[flat_valid_lineage, :].reshape(-1, 2), axis=0
).size
):
# all unique indices in valid_lineages should be in valid_indices
raise Exception( raise Exception(
"Error in validate_lineage: " "Error in validate_lineage: "
"lineage information is likely not unique." "lineage information is likely not unique."
) )
return valid_lineage.flatten(), valid_indices.flatten() return flat_valid_lineage, flat_valid_indices
def validate_association( def validate_association(
......
...@@ -81,11 +81,11 @@ def group_merges(merges: np.ndarray) -> t.List[t.Tuple]: ...@@ -81,11 +81,11 @@ def group_merges(merges: np.ndarray) -> t.List[t.Tuple]:
Convert merges into a list of merges for traps requiring multiple Convert merges into a list of merges for traps requiring multiple
merges and then for traps requiring single merges. merges and then for traps requiring single merges.
""" """
left_track = merges[:, 0] left_tracks = merges[:, 0]
right_track = merges[:, 1] right_tracks = merges[:, 1]
# find traps requiring multiple merges # find traps requiring multiple merges
linr = merges[index_isin(left_track, right_track).flatten(), :] linr = merges[index_isin(left_tracks, right_tracks).flatten(), :]
rinl = merges[index_isin(right_track, left_track).flatten(), :] rinl = merges[index_isin(right_tracks, left_tracks).flatten(), :]
# make unique and order merges for each trap # make unique and order merges for each trap
multi_merge = np.unique(np.concatenate((linr, rinl)), axis=0) multi_merge = np.unique(np.concatenate((linr, rinl)), axis=0)
# find traps requiring a singe merge # find traps requiring a singe merge
...@@ -99,18 +99,6 @@ def group_merges(merges: np.ndarray) -> t.List[t.Tuple]: ...@@ -99,18 +99,6 @@ def group_merges(merges: np.ndarray) -> t.List[t.Tuple]:
for trap_id in np.unique(multi_merge[:, 0, 0]) for trap_id in np.unique(multi_merge[:, 0, 0])
] ]
res = [*multi_merge_list, *single_merge_list] res = [*multi_merge_list, *single_merge_list]
# #
# sources_targets = compare_indices(merges[:, 0, :], merges[:, 1, :])
# is_multimerge = sources_targets.any(axis=0) | sources_targets.any(axis=1)
# is_monomerge = ~is_multimerge
# multimerge_subsets = union_find(zip(*np.where(sources_targets)))
# merge_groups = [merges[np.array(tuple(x))] for x in multimerge_subsets]
# sorted_merges = list(map(sort_association, merge_groups))
# res = [
# *sorted_merges,
# *[[event] for event in merges[is_monomerge]],
# ]
# #
return res return res
...@@ -150,36 +138,45 @@ def sort_association(array: np.ndarray): ...@@ -150,36 +138,45 @@ def sort_association(array: np.ndarray):
def merge_association(lineage: np.ndarray, merges: np.ndarray) -> np.ndarray: def merge_association(lineage: np.ndarray, merges: np.ndarray) -> np.ndarray:
"""Use merges to update lineage information.""" """Use merges to update lineage information."""
flat_lineage = lineage.reshape(-1, 2) flat_lineage = lineage.reshape(-1, 2)
left_track = merges[:, 0] bud_mother_dict = {
# comparison_mat = compare_indices(left_track, flat_lineage) tuple(bud): mother for bud, mother in zip(lineage[:, 1], lineage[:, 0])
# valid_indices = comparison_mat.any(axis=0) }
valid_lineages = index_isin(flat_lineage, left_track).flatten() left_tracks = merges[:, 0]
# group into multi- and single merges # find left tracks that are in lineages
valid_lineages = index_isin(flat_lineage, left_tracks).flatten()
# group into multi- and then single merges
grouped_merges = group_merges(merges) grouped_merges = group_merges(merges)
# perform merges # perform merges
if valid_lineages.any(): if valid_lineages.any():
# indices of each left track -> indices of rightmost track # indices of each left track -> indices of rightmost right track
replacement_dict = { replacement_dict = {
tuple(contig_pair[0]): merge[-1][1] tuple(contig_pair[0]): merge[-1][1]
for merge in grouped_merges for merge in grouped_merges
for contig_pair in merge for contig_pair in merge
} }
# if both key and value are buds, they must have the same mother
buds = lineage[:, 1]
incorrect_merges = [
key
for key in replacement_dict
if np.any(index_isin(buds, replacement_dict[key]).flatten())
and np.any(index_isin(buds, key).flatten())
and not np.array_equal(
bud_mother_dict[key],
bud_mother_dict[tuple(replacement_dict[key])],
)
]
# reassign incorrect merges so that they have no affect
for key in incorrect_merges:
replacement_dict[key] = key
# correct lineage information # correct lineage information
# replace mother or bud index with index of rightmost track # replace mother or bud index with index of rightmost track
flat_lineage[valid_lineages] = [ flat_lineage[valid_lineages] = [
replacement_dict[tuple(i)] for i in flat_lineage[valid_lineages] replacement_dict[tuple(index)]
for index in flat_lineage[valid_lineages]
] ]
# reverse flattening # reverse flattening
new_lineage = flat_lineage.reshape(-1, 2, 2) new_lineage = flat_lineage.reshape(-1, 2, 2)
# remove any duplicates # remove any duplicates
new_lineage = np.unique(new_lineage, axis=0) new_lineage = np.unique(new_lineage, axis=0)
# buds should have only one mother
buds = new_lineage[:, 1]
ubuds, counts = np.unique(buds, axis=0, return_counts=True)
duplicate_buds = ubuds[counts > 1, :]
# duplicates
new_lineage[index_isin(buds, duplicate_buds).flatten(), ...]
# original
lineage[index_isin(lineage[:, 1], duplicate_buds).flatten(), ...]
breakpoint()
return new_lineage return new_lineage
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment