Commit e1c36f24 authored by Alán Muñoz

refactor(merge): add indexing merge functions

parent bfddc70d
@@ -9,7 +9,7 @@
 import numpy as np
 import pandas as pd
 from utils_find_1st import cmp_larger, find_1st
-from agora.utils.indexing import validate_association
+from agora.utils.indexing import compare_indices, validate_association


 def apply_merges(data: pd.DataFrame, merges: np.ndarray):
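The newly imported compare_indices is not shown in this diff; from the way it is used below, it appears to build a pairwise row-equality matrix between two arrays of (trap, cell_label) index pairs. A minimal stand-in consistent with that usage (the name compare_indices_sketch and its exact behaviour are assumptions inferred from this diff, not the agora.utils.indexing code):

import numpy as np

def compare_indices_sketch(a: np.ndarray, b: np.ndarray) -> np.ndarray:
    # Boolean matrix M with M[i, j] True when row i of `a` equals row j of `b`.
    # Rows are index pairs such as (trap, cell_label).
    return (a[:, None, :] == b[None, :, :]).all(axis=2)

Under that reading, compare_indices(merges[:, 0, :], merges[:, 1, :]) in group_merges below flags pairs of merge events that share a track, i.e. chained merges.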
@@ -39,15 +39,16 @@ def apply_merges(data: pd.DataFrame, merges: np.ndarray):
     merged = data.loc[~indices]

     # Implement the merges and drop source rows.
+    # TODO Use matrices to perform merges in batch
+    # for efficiency
     if valid_merges.any():
         to_merge = data.loc[indices]
         targets, sources = zip(*merges[valid_merges])
         for source, target in zip(sources, targets):
-            to_merge.loc[target] = copy(
-                join_tracks_pair(
-                    to_merge.loc[tuple(target)].values,
-                    to_merge.loc[tuple(source)].values,
-                )
-            )
+            target = tuple(target)
+            to_merge.loc[target] = join_tracks_pair(
+                to_merge.loc[target].values,
+                to_merge.loc[tuple(source)].values,
+            )
         to_merge.drop(map(tuple, sources), inplace=True)
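For context, data is a signal DataFrame indexed by (trap, cell_label) with time points as columns, and merges is an (n, 2, 2) array pairing a target index row with a source index row, as unpacked by zip(*merges[valid_merges]) above. A hypothetical call sketch (the index values are made up, and the exact return value depends on validate_association, which is not shown in this diff):

import numpy as np
import pandas as pd

# Hypothetical signal: rows are (trap, cell_label) tracks, columns are time points.
index = pd.MultiIndex.from_tuples(
    [(1, 1), (1, 2), (2, 1)], names=["trap", "cell_label"]
)
signal = pd.DataFrame(
    [[1.1, 1.3, 0.0], [0.0, 0.0, 1.4], [2.0, 2.1, 2.2]], index=index
)

# One merge event: track (1, 2) continues track (1, 1).
merges = np.array([[[1, 1], [1, 2]]])

merged = apply_merges(signal, merges)
# Given the loop above, row (1, 1) should end up as [1.1, 1.3, 1.4]
# and row (1, 2) should be dropped from the result.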
@@ -63,3 +64,79 @@ def join_tracks_pair(target: np.ndarray, source: np.ndarray) -> np.ndarray:
     end = find_1st(target_copy[::-1], 0, cmp_larger)
     target_copy[-end:] = source[-end:]
     return target_copy
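join_tracks_pair locates the last positive value in the target and overwrites everything after it with the corresponding tail of the source, i.e. it fills the target track's trailing gap with the source track's continuation. A small worked example (values are hypothetical):

import numpy as np

target = np.array([1.1, 1.3, 0.0, 0.0])  # target track stops after t = 1
source = np.array([0.0, 0.0, 1.4, 1.6])  # source track starts at t = 2

join_tracks_pair(target, source)
# -> array([1.1, 1.3, 1.4, 1.6]): the two trailing zeros of the target are
#    replaced by the corresponding tail of the source.

The rest of the hunk adds the new indexing-merge helpers, reproduced below.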
def group_merges(merges: np.ndarray) -> t.List[t.Tuple]:
    # Return a list where the cell is present as source and target
    # (multimerges)
    sources_targets = compare_indices(merges[:, 0, :], merges[:, 1, :])
    is_multimerge = sources_targets.any(axis=0) | sources_targets.any(axis=1)
    is_monomerge = ~is_multimerge

    multimerge_subsets = union_find(zip(*np.where(sources_targets)))
    merge_groups = [merges[np.array(tuple(x))] for x in multimerge_subsets]

    sorted_merges = list(map(sort_association, merge_groups))

    # Ensure that source and target are at the edges
    return [
        *sorted_merges,
        *[[event] for event in merges[is_monomerge]],
    ]
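group_merges separates chained merge events (a track that appears in one event's first slot and another event's second slot) from standalone ones: chains are grouped via union_find and ordered by sort_association, while standalone events come back as single-element lists. A hypothetical input and its expected grouping, assuming compare_indices behaves as sketched earlier:

import numpy as np

merges = np.array(
    [
        [[1, 1], [1, 2]],  # shares track (1, 2) with the next event -> chained
        [[1, 2], [1, 3]],
        [[4, 7], [4, 9]],  # standalone merge
    ]
)

group_merges(merges)
# -> [array([[[1, 1], [1, 2]],
#            [[1, 2], [1, 3]]]),   # the chain, ordered by sort_association
#     [array([[4, 7], [4, 9]])]]   # the standalone event, wrapped in a list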
def union_find(lsts):
    sets = [set(lst) for lst in lsts if lst]
    merged = True
    while merged:
        merged = False
        results = []
        while sets:
            common, rest = sets[0], sets[1:]
            sets = []
            for x in rest:
                if x.isdisjoint(common):
                    sets.append(x)
                else:
                    merged = True
                    common |= x
            results.append(common)
        sets = results
    return sets
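union_find is a small fixed-point set-merging pass rather than a classic disjoint-set structure: it repeatedly unions any input groups that share an element until all remaining sets are pairwise disjoint. For example, with the function above in scope:

union_find([(0, 1), (1, 2), (3, 4)])
# -> [{0, 1, 2}, {3, 4}]

In group_merges it receives the (row, column) pairs of the chained-merge matrix, so each resulting set collects the event indices of one merge chain.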
def sort_association(array: np.ndarray):
    # Sort the internal associations

    order = np.where(
        (array[:, 0, ..., None] == array[:, 1].T[None, ...]).all(axis=1)
    )

    res = []
    [res.append(x) for x in np.flip(order).flatten() if x not in res]
    sorted_array = array[np.array(res)]
    return sorted_array
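sort_association reorders a group of chained events so that the chain's endpoints sit at the edges of the array, as the comment in group_merges states; merge_association below only relies on the final event. With two events given out of order:

import numpy as np

chain = np.array([[[1, 2], [1, 3]],   # second half of the chain listed first
                  [[1, 1], [1, 2]]])

sort_association(chain)
# -> array([[[1, 1], [1, 2]],
#           [[1, 2], [1, 3]]])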
def merge_association(
    association: np.ndarray, merges: np.ndarray
) -> np.ndarray:
    grouped_merges = group_merges(merges)

    flat_indices = association.reshape(-1, 2)
    comparison_mat = compare_indices(merges[:, 0], flat_indices)
    valid_indices = comparison_mat.any(axis=0)

    replacement_d = {}
    for dataset in grouped_merges:
        for k in dataset:
            replacement_d[tuple(k[0])] = dataset[-1][1]

    flat_indices[valid_indices] = [
        replacement_d[tuple(i)] for i in flat_indices[valid_indices]
    ]

    merged_indices = flat_indices.reshape(-1, 2, 2)
    return merged_indices
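merge_association rewrites an index association, here assumed to be an (n, 2, 2) array of paired (trap, cell_label) rows such as mother/bud pairs: any index that appears in the first slot of a merge event is relabelled to the second index of the final event in its merge chain. A hypothetical example, again assuming compare_indices behaves as sketched earlier:

import numpy as np

merges = np.array([[[1, 1], [1, 2]],
                   [[1, 2], [1, 3]]])      # one chain of two merge events

association = np.array([[[1, 1], [1, 4]],  # references track (1, 1)
                        [[2, 5], [2, 6]]])  # untouched by any merge

merge_association(association, merges)
# -> array([[[1, 3], [1, 4]],
#           [[2, 5], [2, 6]]])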