diff --git a/src/agora/utils/merge.py b/src/agora/utils/merge.py index 8f3aee4eaf798b0f2f53b8920c29f35363e9fb19..47ec082100b3d6ae200d688e1c64cdbd83a67085 100644 --- a/src/agora/utils/merge.py +++ b/src/agora/utils/merge.py @@ -41,13 +41,15 @@ def apply_merges(data: pd.DataFrame, merges: np.ndarray): # Implement the merges and drop source rows. if valid_merges.any(): to_merge = data.loc[indices] - for target, source in merges[valid_merges]: - target, source = tuple(target), tuple(source) - to_merge.loc[target] = join_tracks_pair( - to_merge.loc[target].values, - to_merge.loc[source].values, + targets, sources = zip(*merges[valid_merges]) + for source, target in zip(sources, targets): + to_merge.loc[target] = copy( + join_tracks_pair( + to_merge.loc[tuple(target)].values, + to_merge.loc[tuple(source)].values, + ) ) - to_merge.drop(source, inplace=True) + to_merge.drop(map(tuple, sources), inplace=True) merged = pd.concat((merged, to_merge), names=data.index.names) return merged @@ -57,7 +59,7 @@ def join_tracks_pair(target: np.ndarray, source: np.ndarray) -> np.ndarray: """ Join two tracks and return the new value of the target. """ - target_copy = copy(target) + target_copy = target end = find_1st(target_copy[::-1], 0, cmp_larger) target_copy[-end:] = source[-end:] return target_copy