diff --git a/src/agora/utils/association.py b/src/agora/utils/association.py
index 656aa787fbf4029f964bfc4ba971b80b2f1177e4..d523e427c5fed741204061295317371d533b70df 100644
--- a/src/agora/utils/association.py
+++ b/src/agora/utils/association.py
@@ -24,9 +24,10 @@ def validate_association(
         ----------
         association : np.ndarray
             2-D array where columns are (trap, mother, daughter) or 3-D array where
-            dimensions are (X, (trap,mother), (trap,daughter))
+            dimensions are (X,trap,2), containing tuples ((trap,mother), (trap,daughter))
+            across the 3rd dimension.
         indices : np.ndarray
-            2-D array where each column is a different level.
+            2-D array where each column is a different level. This should not include mother_label.
         match_column: int
             int indicating a specific column is required to match (i.e.
             0-1 for target-source when trying to merge tracklets or mother-bud for lineage)
@@ -68,11 +69,12 @@ def validate_association(
     [ True False False] [ True  True False]
 
     """
-    if association.ndim < 3:
+    if association.ndim == 2:
         # Reshape into 3-D array for broadcasting if neded
-        association = np.stack(
-            (association[:, [0, 1]], association[:, [0, 2]]), axis=1
-        )
+        # association = np.stack(
+        #     (association[:, [0, 1]], association[:, [0, 2]]), axis=1
+        # )
+        association = last_col_as_rows(association)
 
     # Compare existing association with available indices
     # Swap trap and label axes for the association array to correctly cast
@@ -100,3 +102,20 @@ def validate_association(
         ).any(axis=1)
 
     return valid_association, valid_indices
+
+
+def last_col_as_rows(ndarray: np.ndarray):
+    """
+    Convert the last column to a new row while repeating all previous indices.
+
+    This is useful when converting a signal multiindex before comparing association.
+    """
+    columns = np.arange(ndarray.shape[1])
+
+    return np.stack(
+        (
+            ndarray[:, np.delete(columns, -1)],
+            ndarray[:, np.delete(columns, -2)],
+        ),
+        axis=1,
+    )