Skip to content
Snippets Groups Projects
Commit a43b512d authored by pswain's avatar pswain
Browse files

fixed merging causing one bud to have multiple mothers

parent ad973bf1
No related branches found
No related tags found
No related merge requests found
......@@ -85,8 +85,11 @@ def validate_lineage(
valid_lineage = valid_lineages.all(axis=1)
else:
valid_lineage = valid_lineages[:, c_index, :]
flat_valid_lineage = valid_lineage.flatten()
# find valid indices
selected_lineages = lineage[valid_lineage.flatten(), ...]
selected_lineages = np.ascontiguousarray(
lineage[flat_valid_lineage, ...], dtype=np.int64
)
if how == "families":
# select only pairs of mother and bud indices
valid_indices = np.isin(
......@@ -97,12 +100,19 @@ def validate_lineage(
indices.view(i_dtype),
selected_lineages.view(i_dtype)[:, c_index, :],
)
if valid_indices[valid_indices].size != valid_lineage[valid_lineage].size:
flat_valid_indices = valid_indices.flatten()
if (
indices[flat_valid_indices, :].size
!= np.unique(
lineage[flat_valid_lineage, :].reshape(-1, 2), axis=0
).size
):
# all unique indices in valid_lineages should be in valid_indices
raise Exception(
"Error in validate_lineage: "
"lineage information is likely not unique."
)
return valid_lineage.flatten(), valid_indices.flatten()
return flat_valid_lineage, flat_valid_indices
def validate_association(
......
......@@ -81,11 +81,11 @@ def group_merges(merges: np.ndarray) -> t.List[t.Tuple]:
Convert merges into a list of merges for traps requiring multiple
merges and then for traps requiring single merges.
"""
left_track = merges[:, 0]
right_track = merges[:, 1]
left_tracks = merges[:, 0]
right_tracks = merges[:, 1]
# find traps requiring multiple merges
linr = merges[index_isin(left_track, right_track).flatten(), :]
rinl = merges[index_isin(right_track, left_track).flatten(), :]
linr = merges[index_isin(left_tracks, right_tracks).flatten(), :]
rinl = merges[index_isin(right_tracks, left_tracks).flatten(), :]
# make unique and order merges for each trap
multi_merge = np.unique(np.concatenate((linr, rinl)), axis=0)
# find traps requiring a singe merge
......@@ -99,18 +99,6 @@ def group_merges(merges: np.ndarray) -> t.List[t.Tuple]:
for trap_id in np.unique(multi_merge[:, 0, 0])
]
res = [*multi_merge_list, *single_merge_list]
# #
# sources_targets = compare_indices(merges[:, 0, :], merges[:, 1, :])
# is_multimerge = sources_targets.any(axis=0) | sources_targets.any(axis=1)
# is_monomerge = ~is_multimerge
# multimerge_subsets = union_find(zip(*np.where(sources_targets)))
# merge_groups = [merges[np.array(tuple(x))] for x in multimerge_subsets]
# sorted_merges = list(map(sort_association, merge_groups))
# res = [
# *sorted_merges,
# *[[event] for event in merges[is_monomerge]],
# ]
# #
return res
......@@ -150,36 +138,45 @@ def sort_association(array: np.ndarray):
def merge_association(lineage: np.ndarray, merges: np.ndarray) -> np.ndarray:
"""Use merges to update lineage information."""
flat_lineage = lineage.reshape(-1, 2)
left_track = merges[:, 0]
# comparison_mat = compare_indices(left_track, flat_lineage)
# valid_indices = comparison_mat.any(axis=0)
valid_lineages = index_isin(flat_lineage, left_track).flatten()
# group into multi- and single merges
bud_mother_dict = {
tuple(bud): mother for bud, mother in zip(lineage[:, 1], lineage[:, 0])
}
left_tracks = merges[:, 0]
# find left tracks that are in lineages
valid_lineages = index_isin(flat_lineage, left_tracks).flatten()
# group into multi- and then single merges
grouped_merges = group_merges(merges)
# perform merges
if valid_lineages.any():
# indices of each left track -> indices of rightmost track
# indices of each left track -> indices of rightmost right track
replacement_dict = {
tuple(contig_pair[0]): merge[-1][1]
for merge in grouped_merges
for contig_pair in merge
}
# if both key and value are buds, they must have the same mother
buds = lineage[:, 1]
incorrect_merges = [
key
for key in replacement_dict
if np.any(index_isin(buds, replacement_dict[key]).flatten())
and np.any(index_isin(buds, key).flatten())
and not np.array_equal(
bud_mother_dict[key],
bud_mother_dict[tuple(replacement_dict[key])],
)
]
# reassign incorrect merges so that they have no affect
for key in incorrect_merges:
replacement_dict[key] = key
# correct lineage information
# replace mother or bud index with index of rightmost track
flat_lineage[valid_lineages] = [
replacement_dict[tuple(i)] for i in flat_lineage[valid_lineages]
replacement_dict[tuple(index)]
for index in flat_lineage[valid_lineages]
]
# reverse flattening
new_lineage = flat_lineage.reshape(-1, 2, 2)
# remove any duplicates
new_lineage = np.unique(new_lineage, axis=0)
# buds should have only one mother
buds = new_lineage[:, 1]
ubuds, counts = np.unique(buds, axis=0, return_counts=True)
duplicate_buds = ubuds[counts > 1, :]
# duplicates
new_lineage[index_isin(buds, duplicate_buds).flatten(), ...]
# original
lineage[index_isin(lineage[:, 1], duplicate_buds).flatten(), ...]
breakpoint()
return new_lineage
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment