diff --git a/src/agora/io/signal.py b/src/agora/io/signal.py index 5c52f2d34fc3d732021646e7f26545920e1ca599..7f8af04818c5b8dd29647f806ddadf0935e5c46d 100644 --- a/src/agora/io/signal.py +++ b/src/agora/io/signal.py @@ -435,6 +435,10 @@ def validate_merges(merges: np.ndarray, indices: np.ndarray) -> np.ndarray: # valid_ndmerges = merges.swapaxes(1, 2)[..., None] == indices.T[:, None, :] valid_ndmerges = merges[..., None] == indices.T[None, ...] - valid_merges = merges[valid_ndmerges.all(axis=2).any(axis=2).any(axis=1)] + # Casting is confusing (but efficient): + # - First we check the dimension across trap and cell id, to ensure both match + # - Then we check the dimension that crosses all indices, to ensure the pair is present there + # - Finally we check the merge tuples to check which cases have both target and source + valid_merges = merges[valid_ndmerges.all(axis=2).any(axis=2).all(axis=1)] # valid_merges = merges[allnan.any(axis=1)] return valid_merges