diff --git a/src/agora/io/cells.py b/src/agora/io/cells.py
index e72d32f4c13d88272a4962394b408b0a76b47bec..182590b7ee2637a8516c902f79f6c763c47454dd 100644
--- a/src/agora/io/cells.py
+++ b/src/agora/io/cells.py
@@ -165,7 +165,7 @@ class Cells:
     def at_times(self, timepoints: t.Iterable[int], kind="mask"):
         return [
             [
-                np.dstack(tile_masks) if len(tile_masks) else []
+                np.stack(tile_masks) if len(tile_masks) else []
                 for tile_masks in self.at_time(tp, kind=kind).values()
             ]
             for tp in timepoints
diff --git a/src/aliby/tile/traps.py b/src/aliby/tile/traps.py
index 4eddeb7e45a0f39ea0de28c865b79b685061da5b..1777289258daacdb919cb2be44077fae4b426d67 100644
--- a/src/aliby/tile/traps.py
+++ b/src/aliby/tile/traps.py
@@ -120,7 +120,7 @@ def segment_traps(
         for x, y in centroids
     ]
     # make a mean template from all the found traps
-    mean_template = np.dstack(candidate_templates).astype(int).mean(axis=-1)
+    mean_template = np.stack(candidate_templates).astype(int).mean(axis=0)
 
     # find traps using the mean trap template
     traps = identify_trap_locations(
@@ -243,4 +243,4 @@ def stretch_image(image):
     maxval = np.percentile(image, 98)
     image = np.clip(image, minval, maxval)
     image = (image - minval) / (maxval - minval)
-    return image
\ No newline at end of file
+    return image
diff --git a/src/aliby/utils/imageViewer.py b/src/aliby/utils/imageViewer.py
index 54392a0904dab9e925864df4b43fd0233ee20d0d..b08ebb3389b7c02bd3f1fa6d3357777332a23b89 100644
--- a/src/aliby/utils/imageViewer.py
+++ b/src/aliby/utils/imageViewer.py
@@ -264,8 +264,8 @@ class remoteImageViewer(BaseImageViewer):
         ]
         lbls = [self.cells.labels_at_time(tp).get(trap_id, []) for tp in tps]
         lbld_outlines = [
-            np.dstack([mask * lbl for mask, lbl in zip(maskset, lblset)]).max(
-                axis=2
+            np.stack([mask * lbl for mask, lbl in zip(maskset, lblset)]).max(
+                axis=0
             )
             if len(lblset)
             else np.zeros_like(imgs_list[0]).astype(bool)
diff --git a/src/extraction/core/extractor.py b/src/extraction/core/extractor.py
index eee3185a613eb19af8faa5b88566d76d1f197d53..61f21e5db023e1c56ccacf8999564fe4a84ee976 100644
--- a/src/extraction/core/extractor.py
+++ b/src/extraction/core/extractor.py
@@ -479,7 +479,7 @@ class Extractor(StepABC):
             masks = {trap_id: [] for trap_id in range(cells.ntraps)}
             for trap_id, cells in raw_masks.items():
                 if len(cells):
-                    masks[trap_id] = np.dstack(np.array(cells)).astype(bool)
+                    masks[trap_id] = np.stack(np.array(cells)).astype(bool)
         # convert to a list of masks
         masks = [np.array(v) for v in masks.values()]
 
@@ -491,7 +491,7 @@ class Extractor(StepABC):
         bgs = []
         if self.params.sub_bg:
             bgs = [
-                ~np.sum(m, axis=2).astype(bool)
+                ~np.sum(m, axis=0).astype(bool)
                 if np.any(m)
                 else np.zeros((tile_size, tile_size))
                 for m in masks