diff --git a/src/aliby/utils/imageViewer.py b/src/aliby/utils/imageViewer.py
index 52a15ee763f776aba369276e34748b1becff76c0..f3331983dbe73755f8ed7f17f7fcd69e1f87a3fa 100644
--- a/src/aliby/utils/imageViewer.py
+++ b/src/aliby/utils/imageViewer.py
@@ -1,24 +1,7 @@
-"""
-ImageViewer class, used to look at individual or multiple traps over time.
-
-
-Example of usage:
-
-fpath = "/home/alan/data/16543/URA8_young018.h5"
-
-tile_id = 9
-trange = list(range(0, 10))
-ncols = 8
-
-riv = remoteImageViewer(fpath)
-riv.plot_labelled_trap(tile_id, trange, [0], ncols=ncols)
-
-"""
-
 import re
 import typing as t
 from abc import ABC
-
+from pathlib import Path
 import h5py
 import matplotlib.pyplot as plt
 import numpy as np
@@ -27,13 +10,9 @@ from skimage.morphology import dilation
 
 from agora.io.cells import Cells
 from agora.io.metadata import dispatch_metadata_parser
-from aliby.io.image import ImageDir, ImageZarr, dispatch_image
+from aliby.io.image import dispatch_image
 
-try:
-    from aliby.io.omero import UnsafeImage as OImage
-except ModuleNotFoundError:
-    print("Viewing available only for local files.")
-from aliby.tile.tiler import Tiler, TilerParameters
+from aliby.tile.tiler import Tiler
 from aliby.utils.plot import stretch_clip
 
 default_colours = {
@@ -59,18 +38,23 @@ def custom_imshow(a, norm=None, cmap=None, *args, **kwargs):
 
 
 class BaseImageViewer(ABC):
-    def __init__(self, fpath):
-        self._fpath = fpath
-        self.attrs = dispatch_metadata_parser(fpath.parent)
-        self._logfiles_meta = {}
-        self.image_id = self.attrs.get("image_id")
+    """Base class with routines common to all ImageViewers."""
+
+    def __init__(self, h5file_path):
+        """Initialise from a Path to a h5 file."""
+        self.h5file_path = h5file_path
+        self.logfiles_meta = dispatch_metadata_parser(h5file_path.parent)
+        self.image_id = self.logfiles_meta.get("image_id")
         if self.image_id is None:
-            with h5py.File(fpath, "r") as f:
+            with h5py.File(h5file_path, "r") as f:
                 self.image_id = f.attrs.get("image_id")
-        assert self.image_id is not None, "No valid image_id found in metadata"
+        if self.image_id is None:
+            raise ("No valid image_id found in metadata.")
+        self.full = {}
 
     @property
     def shape(self):
+        """Return shape of image array."""
         return self.tiler.image.shape
 
     @property
@@ -87,193 +71,112 @@ class BaseImageViewer(ABC):
         """Find cell label at a given time point."""
         return self.cells.labels_at_time(tp)
 
-
-class LocalImageViewer(BaseImageViewer):
-    """
-    Generate figures from local files.
-
-    File are either zarr or organised in directories.
-    TODO move common functionality from RemoteImageViewer to BaseImageViewer
-    """
-
-    def __init__(self, h5file_path: str, data_path: str):
-        super().__init__(h5file_path)
-        self._image_class = (
-            ImageZarr if str(data_path).endswith(".zarr") else ImageDir
-        )
-        image = ImageZarr(data_path)
-        self.tiler = Tiler(
-            image.data,
-            self._meta if hasattr(self, "_meta") else self._logfiles_meta,
-            TilerParameters.default(),
-        )
-        self.cells = Cells.from_source(h5file_path)
-
-
-class RemoteImageViewer(BaseImageViewer):
-    """
-    This ImageViewer combines fetching remote images with tiling and outline display.
-    """
-
-    _credentials = ("host", "username", "password")
-
-    def __init__(
-        self,
-        results_path: str,
-        server_info: t.Dict[str, str],
+    def find_channel_indices(
+        self, channels: t.Union[str, t.Collection[str]], guess=True
     ):
-        super().__init__(results_path)
-        self._server_info = server_info or {
-            k: self.attrs["parameters"]["general"][k]
-            for k in self._credentials
-        }
-        self._image_instance = OImage(self.image_id, **self._server_info)
-        self.tiler = Tiler.from_h5(self._image_instance, results_path)
-        self.cells = Cells.from_source(results_path)
-
-    def random_valid_trap_tp(
-        self,
-        min_ncells: int = None,
-        min_consecutive_tps: int = None,
-        label_modulo: int = None,
-    ):
-        # Call Cells convenience function to pick a random trap and tp
-        # containing cells for x cells for y
-        return self.cells.random_valid_trap_tp(
-            min_ncells=min_ncells,
-            min_consecutive_tps=min_consecutive_tps,
-        )
-
-    def get_entire_position(self):
-        raise (NotImplementedError)
-
-    def get_position_timelapse(self):
-        raise (NotImplementedError)
-
-    @property
-    def full(self):
-        if not hasattr(self, "_full"):
-            self._full = {}
-        return self._full
-
-    def get_tc(self, tp, channel=None, server_info=None):
-        server_info = server_info or self._server_info
-        channel = channel or self.tiler.ref_channel
-
-        with self._image_class(self.image_id, **server_info) as image:
-            self.tiler.image = image.data
-            return self.tiler.load_image(tp, channel)
-
-    def _find_channels(self, channels: str, guess: bool = True):
+        """Find index for particular channels."""
         channels = channels or self.tiler.ref_channel
         if isinstance(channels, (int, str)):
             channels = [channels]
         if isinstance(channels[0], str):
             if guess:
-                channels = [self.tiler.channels.index(ch) for ch in channels]
+                indices = [self.tiler.channels.index(ch) for ch in channels]
             else:
-                channels = [
+                indices = [
                     re.search(ch, tiler_channels)
                     for ch in channels
                     for tiler_channels in self.tiler.channels
                 ]
+            return indices
+        else:
+            return channels
+
+    def get_outlines_tiles_dict(self, tile_id, trange, channels):
+        """Get outlines and dict of tiles with channel indices as keys."""
+        outlines = None
+        tile_dict = {}
+        for ch in self.find_channel_indices(channels):
+            outlines, tile_dict[ch] = self.get_outlines_tiles(
+                tile_id, trange, channels=[ch]
+            )
+        return outlines, tile_dict
+
+    def get_outlines_tiles(
+        self,
+        tile_id: int,
+        tps: t.Union[range, t.Collection[int]],
+        channels=None,
+        concatenate=True,
+        **kwargs,
+    ) -> t.Tuple[np.array]:
+        """
+        Get masks uniquely labelled for each cell with the corresponding tiles.
 
-        return channels
+        Returns a list of masks, each an array with distinct masks for each cell,
+        and an array of tiles for the given channel.
+        """
+        tile_dict = self.get_tiles(tps, channels=channels, **kwargs)
+        # get tiles of interest
+        tiles = [x[tile_id] for x in tile_dict.values()]
+        # get outlines for each time point
+        outlines = [
+            self.cells.at_time(tp, kind="edgemask").get(tile_id, []) for tp in tps
+        ]
+        # get cell labels for each time point
+        cell_labels = [self.cells.labels_at_time(tp).get(tile_id, []) for tp in tps]
+        # generate one image with all cell outlines uniquely labelled per tile
+        labelled_outlines = [
+            np.stack(
+                [outline * label for outline, label in zip(outlines_tp, labels_tp)]
+            ).max(axis=0)
+            if len(labels_tp)
+            else np.zeros_like(tiles[0]).astype(bool)
+            for outlines_tp, labels_tp in zip(outlines, cell_labels)
+        ]
+        if concatenate:
+            # concatenate to allow potential image processing
+            labelled_outlines = np.concatenate(labelled_outlines, axis=1)
+            tiles = np.concatenate(tiles, axis=1)
+        return labelled_outlines, tiles
 
-    def get_pos_timepoints(
+    def get_tiles(
         self,
         tps: t.Union[int, t.Collection[int]],
-        channels: t.Union[str, t.Collection[str]] = None,
+        channels: None,
         z: int = None,
-        server_info=None,
     ):
+        """Get dict with time points as keys and all available tiles as values."""
         if tps and not isinstance(tps, t.Collection):
             tps = range(tps)
-
-        # TODO add support for multiple channels or refactor
-        if channels and not isinstance(channels, t.Collection):
-            channels = [channels]
-
         if z is None:
             z = 0
-
-        server_info = server_info or self._server_info
-        channels = 0 or self._find_channels(channels)
         z = z or self.tiler.ref_z
-
-        ch_tps = [(channels[0], tp) for tp in tps]
-
-        image = self._image_instance
-        self.tiler.image = image.data
+        channel_indices = self.find_channel_indices(channels)
+        ch_tps = [(channel_indices[0], tp) for tp in tps]
         for ch, tp in ch_tps:
             if (ch, tp) not in self.full:
                 self.full[(ch, tp)] = self.tiler.get_tiles_timepoint(
                     tp, channels=[ch], z=[z]
                 )[:, 0, 0, z, ...]
-        requested_trap = {tp: self.full[(ch, tp)] for ch, tp in ch_tps}
-
-        return requested_trap
-
-    def get_labelled_trap(
-        self,
-        tile_id: int,
-        tps: t.Union[range, t.Collection[int]],
-        channels=None,
-        concatenate=True,
-        **kwargs,
-    ) -> t.Tuple[np.array]:
-        """
-        Core method to fetch traps and labels together
-        """
-        imgs = self.get_pos_timepoints(tps, channels=channels, **kwargs)
-        imgs_list = [x[tile_id] for x in imgs.values()]
-        outlines = [
-            self.cells.at_time(tp, kind="edgemask").get(tile_id, [])
-            for tp in tps
-        ]
-        lbls = [self.cells.labels_at_time(tp).get(tile_id, []) for tp in tps]
-        lbld_outlines = [
-            np.stack([mask * lbl for mask, lbl in zip(maskset, lblset)]).max(
-                axis=0
-            )
-            if len(lblset)
-            else np.zeros_like(imgs_list[0]).astype(bool)
-            for maskset, lblset in zip(outlines, lbls)
-        ]
-        if concatenate:
-            lbld_outlines = np.concatenate(lbld_outlines, axis=1)
-            imgs_list = np.concatenate(imgs_list, axis=1)
-        return lbld_outlines, imgs_list
-
-    def get_images(self, tile_id, trange, channels, **kwargs):
-        """
-        Wrapper to fetch images
-        """
-        out = None
-        imgs = {}
-
-        for ch in self._find_channels(channels):
-            out, imgs[ch] = self.get_labelled_trap(
-                tile_id, trange, channels=[ch], **kwargs
-            )
-        return out, imgs
+        tile_dict = {tp: self.full[(ch, tp)] for ch, tp in ch_tps}
+        return tile_dict
 
     def plot_labelled_trap(
         self,
-        tile_id: int,
+        tile_id,
         channels,
         trange: t.Union[range, t.Collection[int]],
         remove_axis: bool = False,
-        savefile: str = None,
         skip_outlines: bool = False,
-        norm: str = None,
+        norm=True,
         ncols: int = None,
         local_colours: bool = True,
         img_plot_kwargs: dict = {},
         lbl_plot_kwargs: dict = {"alpha": 0.8},
         **kwargs,
     ):
-        """Wrapper to plot time-lapses of individual traps
+        """
+        Plot time-lapses of individual tiles.
 
         Use Cells and Tiler to generate images of cells with their resulting
         outlines.
@@ -282,15 +185,13 @@ class RemoteImageViewer(BaseImageViewer):
         ----------
         tile_id : int
             Identifier of trap
-        channels : Union[str, int]
+        channel : Union[str, int]
             Channels to use
         trange : t.Union[range, t.Collection[int]]
             Range or collection indicating the time-points to use.
         remove_axis : bool
             None, "off", or "x". Determines whether to remove the x-axis, both
             axes or none.
-        savefile : str
-            Saves file to a location.
         skip_outlines : bool
             Do not add overlay with outlines
         norm : str
@@ -304,68 +205,40 @@ class RemoteImageViewer(BaseImageViewer):
             Arguments to pass to plt.imshow used for images.
         lbl_plot_kwargs : dict
             Keyword arguments to pass to label plots.
-        **kwargs : dict
-            Additional keyword arguments passed to ImageViewer.get_images.
-
-        Examples
-        --------
-        FIXME: Add docs.
-
         """
+        # set up for plotting
         if ncols is None:
             ncols = len(trange)
         nrows = int(np.ceil(len(trange) / ncols))
         width = self.tiler.tile_size * ncols
-
-        out, images = self.get_images(tile_id, trange, channels, **kwargs)
-
-        # dilation makes outlines easier to see
-        out = dilation(out).astype(float)
-        out[out == 0] = np.nan
-
+        outlines, tiles_dict = self.get_outlines_tiles_dict(tile_id, trange, channels)
         channel_labels = [
-            self.tiler.channels[ch] if isinstance(ch, int) else ch
-            for ch in channels
+            self.tiler.channels[ch] if isinstance(ch, int) else ch for ch in channels
+        ]
+        # dilate to make outlines easier to see
+        outlines = dilation(outlines).astype(float)
+        outlines[outlines == 0] = np.nan
+        # split concatenated tiles into one tile per time point in a row
+        tiles = [
+            into_image_time_series(tile, width, nrows) for tile in tiles_dict.values()
         ]
-
-        assert not norm or norm in (
-            "l1",
-            "l2",
-            "max",
-        ), "Invalid norm argument."
-
-        if norm and norm in ("l1", "l2", "max"):
-            images = {k: stretch_clip(v) for k, v in images.items()}
-
-        images = [concat_pad(img, width, nrows) for img in images.values()]
         # TODO convert to RGB to draw fluorescence with colour
-        tiled_imgs = {}
-        tiled_imgs["img"] = np.concatenate(images, axis=0)
-        tiled_imgs["cell_labels"] = np.concatenate(
-            [concat_pad(out, width, nrows) for _ in images], axis=0
-        )
-
-        custom_imshow(
-            tiled_imgs["img"],
-            **img_plot_kwargs,
+        res = {}
+        # concatenate different channels vertically for display
+        res["tiles"] = np.concatenate(tiles, axis=0)
+        res["cell_labels"] = np.concatenate(
+            [into_image_time_series(outlines, width, nrows) for _ in tiles], axis=0
         )
+        custom_imshow(res["tiles"], **img_plot_kwargs)
         custom_imshow(
-            tiled_imgs["cell_labels"],
-            cmap=sns.color_palette("Paired", as_cmap=True),
-            **lbl_plot_kwargs,
+            res["cell_labels"], cmap=default_colours["cell_label"], **lbl_plot_kwargs
         )
-
         if remove_axis is True:
             plt.axis("off")
         elif remove_axis == "x":
             plt.tick_params(
-                axis="x",
-                which="both",
-                bottom=False,
-                top=False,
-                labelbottom=False,
+                axis="x", which="both", bottom=False, top=False, labelbottom=False
             )
-
         if remove_axis != "True":
             plt.yticks(
                 ticks=[
@@ -375,39 +248,63 @@ class RemoteImageViewer(BaseImageViewer):
                 ],
                 labels=channel_labels,
             )
-
         if not remove_axis:
             xlabels = (
-                ["+ {} ".format(i) for i in range(ncols)]
-                if nrows > 1
-                else list(trange)
+                ["+ {} ".format(i) for i in range(ncols)] if nrows > 1 else list(trange)
             )
             plt.xlabel("Time-point")
-
             plt.xticks(
                 ticks=[self.tiler.tile_size * (i + 0.5) for i in range(ncols)],
                 labels=xlabels,
             )
+        if not np.any(outlines):
+            print("ImageViewer:Warning: No cell outlines found.")
+        plt.tight_layout()
+        plt.show(block=False)
 
-        if not np.any(out):
-            print("ImageViewer:Warning:No cell outlines found")
-
-        if savefile:
-            plt.savefig(savefile, bbox_inches="tight", dpi=300)
-            plt.close()
-        else:
-            plt.show()
 
-
-def concat_pad(a: np.array, width, nrows):
+class LocalImageViewer(BaseImageViewer):
     """
-    Melt an array into having multiple blocks as rows
+    View images from local files.
+
+    File are either zarr or organised in directories.
     """
+
+    def __init__(self, h5file: str, image_direc: str):
+        """Initialise using a h5file and a local directory of images."""
+        h5file_path = Path(h5file)
+        image_direc_path = Path(image_direc)
+        super().__init__(h5file_path)
+        with dispatch_image(image_direc_path)(image_direc_path) as image:
+            self.tiler = Tiler.from_h5(image, h5file_path)
+        self.cells = Cells.from_source(h5file_path)
+
+
+class RemoteImageViewer(BaseImageViewer):
+    """Fetching remote images with tiling and outline display."""
+
+    credentials = ("host", "username", "password")
+
+    def __init__(self, h5file: str, server_info: t.Dict[str, str]):
+        """Initialise using a h5file and importing aliby.io.omero."""
+        from aliby.io.omero import UnsafeImage as OImage
+
+        h5file_path = Path(h5file)
+        super().__init__(h5file_path)
+        self.server_info = server_info or {
+            k: self.attrs["parameters"]["general"][k] for k in self.credentials
+        }
+        image = OImage(self.image_id, **self._server_info)
+        self.tiler = Tiler.from_h5(image, h5file_path)
+        self.cells = Cells.from_source(h5file_path)
+
+
+def into_image_time_series(a: np.array, width, nrows):
+    """Split into sub-arrays and then concatenate into one."""
     return np.concatenate(
         np.array_split(
             np.pad(
                 a,
-                # ((0, 0), (0, width - (a.shape[1] % width))),
                 ((0, 0), (0, a.shape[1] % width)),
                 constant_values=np.nan,
             ),