diff --git a/src/postprocessor/core/processor.py b/src/postprocessor/core/processor.py
index 6f92799ec38b3928daa27e6f2ebfde394fa7f2b2..d4cd0367a7b1d0a9d29c65a1c3ea90263b1af0a1 100644
--- a/src/postprocessor/core/processor.py
+++ b/src/postprocessor/core/processor.py
@@ -14,8 +14,8 @@ from agora.io.writer import Writer
 from agora.utils.indexing import (
     _3d_index_to_2d,
     _assoc_indices_to_3d,
-    compare_indices,
 )
+from agora.utils.merge import merge_association
 from agora.utils.kymograph import get_index_as_np
 from postprocessor.core.abc import get_parameters, get_process
 from postprocessor.core.lineageprocess import (
@@ -168,24 +168,9 @@ class PostProcessor(ProcessABC):
 
         if merges.any():  # Update lineages after merge events
 
-            grouped_merges = group_merges(merges)
-
-            flat_indices = lineage.reshape(-1, 2)
-            comparison_mat = compare_indices(merges[:, 0], flat_indices)
-
-            valid_indices = comparison_mat.any(axis=0)
-
-            replacement_d = {}
-            for dataset in grouped_merges:
-                for k in dataset:
-                    replacement_d[tuple(k[0])] = dataset[-1][1]
-
-            flat_indices[valid_indices] = [
-                replacement_d[tuple(i)] for i in flat_indices[valid_indices]
-            ]
-
+            merged_indices = merge_association(lineage, merges)
             # Remove repeated labels post-merging
-            lineage_merged = np.unique(flat_indices.reshape(-1, 2, 2), axis=0)
+            lineage_merged = np.unique(merged_indices, axis=0)
 
         self.lineage = _3d_index_to_2d(
             lineage_merged if len(lineage_merged) else lineage
@@ -307,56 +292,3 @@ class PostProcessor(ProcessABC):
         if not result.any().any():
             logging.getLogger("aliby").warning(f"Record {path} is empty")
         self._writer.write(path, result, meta=metadata, overwrite="overwrite")
-
-
-def union_find(lsts):
-    sets = [set(lst) for lst in lsts if lst]
-    merged = True
-    while merged:
-        merged = False
-        results = []
-        while sets:
-            common, rest = sets[0], sets[1:]
-            sets = []
-            for x in rest:
-                if x.isdisjoint(common):
-                    sets.append(x)
-                else:
-                    merged = True
-                    common |= x
-            results.append(common)
-        sets = results
-    return sets
-
-
-def group_merges(merges: np.ndarray) -> t.List[t.Tuple]:
-    # Return a list where the cell is present as source and target
-    # (multimerges)
-
-    sources_targets = compare_indices(merges[:, 0, :], merges[:, 1, :])
-    is_multimerge = sources_targets.any(axis=0) | sources_targets.any(axis=1)
-    is_monomerge = ~is_multimerge
-
-    multimerge_subsets = union_find(zip(*np.where(sources_targets)))
-    merge_groups = [merges[np.array(tuple(x))] for x in multimerge_subsets]
-
-    sorted_merges = list(map(sort_association, merge_groups))
-
-    # Ensure that source and target are at the edges
-    return [
-        *sorted_merges,
-        *[[event] for event in merges[is_monomerge]],
-    ]
-
-
-def sort_association(array: np.ndarray):
-    # Sort the internal associations
-
-    order = np.where(
-        (array[:, 0, ..., None] == array[:, 1].T[None, ...]).all(axis=1)
-    )
-
-    res = []
-    [res.append(x) for x in np.flip(order).flatten() if x not in res]
-    sorted_array = array[np.array(res)]
-    return sorted_array