diff --git a/aliby/utils/imageViewer.py b/aliby/utils/imageViewer.py
index cfdb25421683318859e496b31f6fa8f90934a942..f50f0185768da50eda435f61d2b30b7335011706 100644
--- a/aliby/utils/imageViewer.py
+++ b/aliby/utils/imageViewer.py
@@ -101,9 +101,12 @@ class remoteImageViewer:
 
         return channels
 
-    def get_trap_timepoints(self, trap_id, tps, channels=None, server_info=None):
+    def get_trap_timepoints(
+        self, trap_id, tps, channels=None, z=None, server_info=None
+    ):
         server_info = server_info or self.server_info
         channels = self.find_channels(channels)
+        z = z or self.tiler.ref_z
 
         ch_tps = set([(channels[0], tp) for tp in tps])
         with OImage(self.image_id, **server_info) as image:
@@ -111,9 +114,14 @@ class remoteImageViewer:
             if ch_tps.difference(self.full.keys()):
                 tps = set(tps).difference(self.full.keys())
                 for ch, tp in ch_tps:
-                    self.full[(ch, tp)] = self.tiler.get_traps_timepoint(
-                        tp, channels=[ch], z=[0]
-                    )[:, 0, 0, ..., 0]
+                    if z is 0:
+                        self.full[(ch, tp)] = self.tiler.get_traps_timepoint(
+                            tp, channels=[ch], z=[z]
+                        )[:, 0, 0, ..., 0]
+                    else:
+                        self.full[(ch, tp)] = self.tiler.get_traps_timepoint(
+                            tp, channels=[ch], z=[z]
+                        )[:, 0, 0, ..., z]
             requested_trap = {tp: self.full[(ch, tp)] for ch, tp in ch_tps}
 
             return requested_trap
@@ -135,12 +143,25 @@ class remoteImageViewer:
         img_concat = np.concatenate(imgs_list, axis=1)
         return outline_concat, img_concat
 
-    def plot_labeled_channelrows(self, trap_id, channels, trange, **kwargs):
+    def get_images(self, trap_id, trange, channels, **kwargs):
+        """
+        Wrapper to fetch images
+        """
         imgs = {}
+
         for ch in self.find_channels(channels):
             out, imgs[ch] = self.get_labeled_trap(
                 trap_id, trange, channels=[ch], **kwargs
             )
+        return out, imgs
+
+    def plot_labeled_zstacks(self, trap_id, channels, trange, z=None, **kwargs):
+        # if z is None:
+        #     z =
+        out, images = self.get_imgs(trap_id, trange, channels, z=z, **kwargs)
+
+    def plot_labeled_channelrows(self, trap_id, channels, trange, **kwargs):
+        out, images = self.get_imgs(trap_id, trange, channels, **kwargs)
 
         # dilation makes outlines easier to see
         out = dilation(out).astype(float)
@@ -171,7 +192,16 @@ class remoteImageViewer:
         )
         plt.show()
 
-    def plot_labeled_traps(self, trap_id, trange, ncols, **kwargs):
+    def plot_labeled_traps(
+        self,
+        trap_id,
+        trange,
+        ncols,
+        remove_axis=False,
+        savefile=False,
+        skip_outlines=False,
+        **kwargs
+    ):
         """
         Wrapper to plot a single trap over time
 
@@ -202,18 +232,31 @@ class remoteImageViewer:
             )
 
         plt.imshow(
-            self.concat_pad(img),
+            concat_pad(img),
             interpolation=None,
             cmap="Greys_r",
         )
-        plt.imshow(
-            self.concat_pad(out),
-            # concat_pad(mask),
-            cmap="Set1",
-            interpolation=None,
-        )
-        plt.yticks(
-            ticks=[self.tiler.tile_size * (i + 0.5) for i in range(nrows)],
-            labels=[trange[0] + ncols * i for i in range(nrows)],
-        )
-        plt.show()
+        if not skip_outlines:
+            plt.imshow(
+                concat_pad(out),
+                # concat_pad(mask),
+                cmap="Set1",
+                interpolation=None,
+            )
+
+        bbox_inches = None
+        if remove_axis:
+            plt.axis("off")
+            bbox_inches = "tight"
+
+        else:
+            plt.yticks(
+                ticks=[self.tiler.tile_size * (i + 0.5) for i in range(nrows)],
+                labels=[trange[0] + ncols * i for i in range(nrows)],
+            )
+
+        if not savefile:
+            plt.show()
+        else:
+            if np.any(out):
+                plt.savefig(savefile, bbox_inches=bbox_inches)