diff --git a/src/agora/utils/indexing.py b/src/agora/utils/indexing.py
index 9a07a6d52faa9e5c5a577b1e2776456a39174e7a..e89d2b7b9c83e4c3e3bea504876a65dbeb22f54a 100644
--- a/src/agora/utils/indexing.py
+++ b/src/agora/utils/indexing.py
@@ -9,6 +9,9 @@ This can be:
 import numpy as np
 import typing as t
 
+# data type to link together trap and cell ids
+i_dtype = {"names": ["trap_id", "cell_id"], "formats": [np.int64, np.int64]}
+
 
 def validate_lineage(
     lineage: np.ndarray, indices: np.ndarray, how: str = "families"
@@ -75,13 +78,8 @@ def validate_lineage(
         c_index = 0
     elif how == "daughters":
         c_index = 1
-    # data type to link together trap and cell ids
-    dtype = {"names": ["trap_id", "cell_id"], "formats": [np.int64, np.int64]}
-    lineage = np.ascontiguousarray(lineage, dtype=np.int64)
-    # find (trap, cell_ids) in intersection
-    inboth = np.intersect1d(lineage.view(dtype), indices.view(dtype))
     # find valid lineage
-    valid_lineages = np.isin(lineage.view(dtype), inboth)
+    valid_lineages = index_isin(lineage, indices)
     if how == "families":
         # both mother and bud must be in indices
         valid_lineage = valid_lineages.all(axis=1)
@@ -92,11 +90,12 @@ def validate_lineage(
     if how == "families":
         # select only pairs of mother and bud indices
         valid_indices = np.isin(
-            indices.view(dtype), selected_lineages.view(dtype)
+            indices.view(i_dtype), selected_lineages.view(i_dtype)
         )
     else:
         valid_indices = np.isin(
-            indices.view(dtype), selected_lineages.view(dtype)[:, c_index, :]
+            indices.view(i_dtype),
+            selected_lineages.view(i_dtype)[:, c_index, :],
         )
     if valid_indices[valid_indices].size != valid_lineage[valid_lineage].size:
         raise Exception(
@@ -244,3 +243,18 @@ def compare_indices(x: np.ndarray, y: np.ndarray) -> np.ndarray:
     where a True value links two cells where all cells are the same
     """
     return (x[..., None] == y.T[None, ...]).all(axis=1)
+
+
+def index_isin(x: np.ndarray, y: np.ndarray) -> np.ndarray:
+    """
+    Find those elements of x that are in y.
+
+    Both arrays must be arrays of integer indices,
+    such as (trap_id, cell_id).
+    """
+    x = np.ascontiguousarray(x, dtype=np.int64)
+    y = np.ascontiguousarray(y, dtype=np.int64)
+    xv = x.view(i_dtype)
+    inboth = np.intersect1d(xv, y.view(i_dtype))
+    x_bool = np.isin(xv, inboth)
+    return x_bool
diff --git a/src/agora/utils/merge.py b/src/agora/utils/merge.py
index b588871be583e8ce8478c581e9b24b1456677f33..1e48feb8608d413371d0f03d7e05c7f69340cde7 100644
--- a/src/agora/utils/merge.py
+++ b/src/agora/utils/merge.py
@@ -9,7 +9,11 @@ import numpy as np
 import pandas as pd
 from utils_find_1st import cmp_larger, find_1st
 
-from agora.utils.indexing import compare_indices, validate_association
+from agora.utils.indexing import (
+    index_isin,
+    compare_indices,
+    validate_association,
+)
 
 
 def apply_merges(data: pd.DataFrame, merges: np.ndarray):
@@ -73,23 +77,41 @@ def join_tracks_pair(target: np.ndarray, source: np.ndarray) -> np.ndarray:
 
 
 def group_merges(merges: np.ndarray) -> t.List[t.Tuple]:
-    # Return a list where the cell is present as source and target
-    # (multimerges)
-
-    sources_targets = compare_indices(merges[:, 0, :], merges[:, 1, :])
-    is_multimerge = sources_targets.any(axis=0) | sources_targets.any(axis=1)
-    is_monomerge = ~is_multimerge
-
-    multimerge_subsets = union_find(zip(*np.where(sources_targets)))
-    merge_groups = [merges[np.array(tuple(x))] for x in multimerge_subsets]
-
-    sorted_merges = list(map(sort_association, merge_groups))
-
-    # Ensure that source and target are at the edges
-    return [
-        *sorted_merges,
-        *[[event] for event in merges[is_monomerge]],
+    """
+    Convert merges into a list of merges for traps requiring multiple
+    merges and then for traps requiring single merges.
+    """
+    left_track = merges[:, 0]
+    right_track = merges[:, 1]
+    # find traps requiring multiple merges
+    linr = merges[index_isin(left_track, right_track).flatten(), :]
+    rinl = merges[index_isin(right_track, left_track).flatten(), :]
+    # make unique and order merges for each trap
+    multi_merge = np.unique(np.concatenate((linr, rinl)), axis=0)
+    # find traps requiring a singe merge
+    single_merge = merges[
+        ~index_isin(merges, multi_merge).all(axis=1).flatten(), :
     ]
+    # convert to lists of arrays
+    single_merge_list = [[sm] for sm in single_merge]
+    multi_merge_list = [
+        multi_merge[multi_merge[:, 0, 0] == trap_id, ...]
+        for trap_id in np.unique(multi_merge[:, 0, 0])
+    ]
+    res = [*multi_merge_list, *single_merge_list]
+    # #
+    # sources_targets = compare_indices(merges[:, 0, :], merges[:, 1, :])
+    # is_multimerge = sources_targets.any(axis=0) | sources_targets.any(axis=1)
+    # is_monomerge = ~is_multimerge
+    # multimerge_subsets = union_find(zip(*np.where(sources_targets)))
+    # merge_groups = [merges[np.array(tuple(x))] for x in multimerge_subsets]
+    # sorted_merges = list(map(sort_association, merge_groups))
+    # res = [
+    #     *sorted_merges,
+    #     *[[event] for event in merges[is_monomerge]],
+    # ]
+    # #
+    return res
 
 
 def union_find(lsts):
@@ -125,25 +147,39 @@ def sort_association(array: np.ndarray):
     return sorted_array
 
 
-def merge_association(
-    association: np.ndarray, merges: np.ndarray
-) -> np.ndarray:
+def merge_association(lineage: np.ndarray, merges: np.ndarray) -> np.ndarray:
+    """Use merges to update lineage information."""
+    flat_lineage = lineage.reshape(-1, 2)
+    left_track = merges[:, 0]
+    # comparison_mat = compare_indices(left_track, flat_lineage)
+    # valid_indices = comparison_mat.any(axis=0)
+    valid_lineages = index_isin(flat_lineage, left_track).flatten()
+    # group into multi- and single merges
     grouped_merges = group_merges(merges)
-
-    flat_indices = association.reshape(-1, 2)
-    comparison_mat = compare_indices(merges[:, 0], flat_indices)
-
-    valid_indices = comparison_mat.any(axis=0)
-
-    if valid_indices.any():  # Where valid, perform transformation
-        replacement_d = {}
-        for dataset in grouped_merges:
-            for k in dataset:
-                replacement_d[tuple(k[0])] = dataset[-1][1]
-
-        flat_indices[valid_indices] = [
-            replacement_d[tuple(i)] for i in flat_indices[valid_indices]
+    # perform merges
+    if valid_lineages.any():
+        # indices of each left track -> indices of rightmost track
+        replacement_dict = {
+            tuple(contig_pair[0]): merge[-1][1]
+            for merge in grouped_merges
+            for contig_pair in merge
+        }
+        # correct lineage information
+        # replace mother or bud index with index of rightmost track
+        flat_lineage[valid_lineages] = [
+            replacement_dict[tuple(i)] for i in flat_lineage[valid_lineages]
         ]
-
-    merged_indices = flat_indices.reshape(-1, 2, 2)
-    return merged_indices
+    # reverse flattening
+    new_lineage = flat_lineage.reshape(-1, 2, 2)
+    # remove any duplicates
+    new_lineage = np.unique(new_lineage, axis=0)
+    # buds should have only one mother
+    buds = new_lineage[:, 1]
+    ubuds, counts = np.unique(buds, axis=0, return_counts=True)
+    duplicate_buds = ubuds[counts > 1, :]
+    # duplicates
+    new_lineage[index_isin(buds, duplicate_buds).flatten(), ...]
+    # original
+    lineage[index_isin(lineage[:, 1], duplicate_buds).flatten(), ...]
+    breakpoint()
+    return new_lineage