diff --git a/src/agora/io/signal.py b/src/agora/io/signal.py
index 5c52f2d34fc3d732021646e7f26545920e1ca599..7f8af04818c5b8dd29647f806ddadf0935e5c46d 100644
--- a/src/agora/io/signal.py
+++ b/src/agora/io/signal.py
@@ -435,6 +435,10 @@ def validate_merges(merges: np.ndarray, indices: np.ndarray) -> np.ndarray:
     # valid_ndmerges = merges.swapaxes(1, 2)[..., None] == indices.T[:, None, :]
     valid_ndmerges = merges[..., None] == indices.T[None, ...]
 
-    valid_merges = merges[valid_ndmerges.all(axis=2).any(axis=2).any(axis=1)]
+    # Casting is confusing (but efficient):
+    # - First we check the dimension across trap and cell id, to ensure both match
+    # - Then we check the dimension that crosses all indices, to ensure the pair is present there
+    # - Finally we check the merge tuples to check which cases have both target and source
+    valid_merges = merges[valid_ndmerges.all(axis=2).any(axis=2).all(axis=1)]
     # valid_merges = merges[allnan.any(axis=1)]
     return valid_merges