diff --git a/src/aliby/utils/vis_tools.py b/src/aliby/utils/vis_tools.py
index 3d4caef4e1c7c2fb460df42fb64423e6200d34b3..7d688211cf5630a945ee2a2617f2a0cfcfd034bb 100644
--- a/src/aliby/utils/vis_tools.py
+++ b/src/aliby/utils/vis_tools.py
@@ -63,7 +63,8 @@ def get_cellmasks_at_times(results_path: str, timepoints: t.List[int] = [0]):
 
 
 def concatenate_dims(ndarray, axis1: int, axis2: int):
-    return np.concatenate(np.moveaxis(ndarray, axis1, 0), axis=axis2)
+    axis2 = len(ndarray.shape) + axis2 if axis2 < 0 else axis2
+    return np.concatenate(np.moveaxis(ndarray, axis1, 0), axis=axis2 - 1)
 
 
 def get_tile_mask_pairs(