diff --git a/src/wela/imageviewer.py b/src/wela/imageviewer.py
index a2bfd5f436b58e8e11d24dc6f1f9b0d2c4812401..bed51f6c63b394774a2b62a7330d3078e01c4647 100644
--- a/src/wela/imageviewer.py
+++ b/src/wela/imageviewer.py
@@ -9,7 +9,10 @@ except ModuleNotFoundError:
         "Napari cannot be imported.\nRun",
         ' python -m pip install "napari[all]"',
     )
+from typing import Any, Dict, List
+
 import numpy as np
+import numpy.typing as npt
 from agora.io.cells import Cells
 from aliby.io.image import dispatch_image
 from aliby.io.omero import Dataset
@@ -47,7 +50,7 @@ class ImageViewer:
                 print(f" Trouble loading {image_file}.")
 
     @classmethod
-    def remote(cls, h5file: str, server_info: dict, omero_id: int):
+    def remote(cls, h5file: str, server_info: Dict, omero_id: int):
         """View images from OMERO."""
         iv = cls(h5file)
         with h5py.File(iv.h5file_path, "r") as f:
@@ -69,7 +72,9 @@ class ImageViewer:
             iv.cells = Cells.from_source(iv.h5file_path)
             return iv
 
-    def get_all_traps_with_cells(self, tpt_end, tpt_start=0, display=True):
+    def get_all_traps_with_cells(
+        self, tpt_end: int, tpt_start: int = 0, display: bool = True
+    ):
         """List traps with cells."""
         cells = self.cells
         tpts = range(tpt_start, tpt_end)
@@ -84,7 +89,9 @@ class ImageViewer:
             print(f"Traps with cells {list(traps_with_cells)}")
         return traps_with_cells
 
-    def sample_traps_with_cells(self, no_cells, tpt_end, tpt_start=0):
+    def sample_traps_with_cells(
+        self, no_cells: int, tpt_end: int, tpt_start: int = 0
+    ):
         """Sample some traps that have cells."""
         traps_with_cells = self.get_all_traps_with_cells(
             tpt_end, tpt_start, display=False
@@ -96,7 +103,12 @@ class ImageViewer:
         )
         return samples
 
-    def get_tiles(self, trap_id, tps, channels_to_skip=None, cell_only=True):
+    def get_tiles(
+        self,
+        trap_id: int,
+        tps: List[int],
+        channels_to_skip: List[str] = None,
+    ):
         """Get dict of tiles with channel indices as keys."""
         tiles_dict = {}
         if channels_to_skip is None:
@@ -132,7 +144,7 @@ class ImageViewer:
                 tiles_dict[ch] = new_tiles
         return tiles_dict
 
-    def get_outlines(self, trap_id, tps):
+    def get_outlines(self, trap_id: int, tps: List[int]):
         """Get uniquely labelled outlines for each cell time point."""
         # get outlines for each time point
         outlines = [
@@ -161,14 +173,14 @@ class ImageViewer:
 
     def get_all_tiles(
         self,
-        tps,
-        channel_index,
-        z=0,
+        tps: List[int],
+        channel_index: str,
+        z: int = 0,
     ):
         """
         Get dict with time points as keys and all available tiles as values.
 
-        We assume only a single channel.
+        Assume only a single channel.
         """
         z = z or self.tiler.ref_z
         ch_tps = [(channel_index, tp) for tp in tps]
@@ -181,7 +193,9 @@ class ImageViewer:
         tile_dict = {tp: self.full[(ch, tp)] for ch, tp in ch_tps}
         return tile_dict
 
-    def get_data_for_viewing(self, trap_id, tps, channels_to_skip):
+    def get_data_for_viewing(
+        self, trap_id: int, tps: List[int], channels_to_skip: List[str]
+    ):
         """
         Get images and outlines as multidimensional arrays for Napari.
 
@@ -217,11 +231,11 @@ class ImageViewer:
             ts_labels[tp_index, 0, ...] = outlines[tp_index]
         return ts_images, ts_labels, channels
 
-    def concat(self, arrangement, image_dict, axis):
+    def concat(self, arrangement: List[int], image_dict: Dict, axis: int):
         """
         Concat dict of images into one image array.
 
-        Following the vertical layout in arrangment.
+        Follow the vertical layout in arrangment.
         """
         # concatenate vertically into a list
         images_v = [
@@ -239,23 +253,22 @@ class ImageViewer:
             images = images_v[0]
         return images
 
-    def combine_tiles(self, ts_images_dict, ts_labels_dict, no_vertical_tiles):
-        """Combine tiles into one image first vertically then horizontally."""
+    def combine_tiles(
+        self, ts_images_dict: Dict, ts_labels_dict: Dict, no_rows: int
+    ):
+        """Combine tiles into one image first into rows then columns."""
         no_tiles = len(ts_images_dict)
         trap_ids = list(ts_images_dict.keys())
         # find how tiles will be arranged in the concatenated image
-        if no_tiles < no_vertical_tiles:
+        if no_tiles < no_rows:
             arrangement = [trap_ids]
         else:
             arrangement = [
-                trap_ids[i : min(i + no_vertical_tiles, no_tiles)]
+                trap_ids[i : min(i + no_rows, no_tiles)]
                 for i in range(
                     0,
-                    int(
-                        np.floor(no_tiles / no_vertical_tiles)
-                        * no_vertical_tiles
-                    ),
-                    no_vertical_tiles,
+                    int(np.floor(no_tiles / no_rows) * no_rows),
+                    no_rows,
                 )
             ]
         if no_tiles > np.array(arrangement).size:
@@ -271,26 +284,30 @@ class ImageViewer:
 
     def view(
         self,
-        trap_ids,
-        tpt_end=10,
-        tpt_start=0,
-        channels_to_skip=None,
-        no_vertical_tiles=3,
+        trap_ids: List[int],
+        tpt_start: int = 0,
+        tpt_end: int = 10,
+        channels_to_skip: List[str] = None,
+        no_rows: int = 2,
     ):
         """
-        Use Napari to view all channels and outlines for a particular trap.
+        Use Napari to view all channels and outlines for particular traps.
 
         Fluorescence channels will not be immediately visible.
+        Concatenating traps into one image can become slow for multiple traps.
 
         Parameters
         ----------
-        trap_id: int
-            The trap to be viewed.
-        tps: int or array of ints
-            Either the last time point to be viewed or a rage of time points
-            to view.
-            If None, all time points will be viewed, but gathering the images
-            will be slow.
+        trap_ids: list of int
+            The traps to be viewed.
+        tpt_start: int
+            The index for the initial time point to view.
+        tpt_end: int
+            The index for the final time point.
+        channels_to_skip: list of str
+            Channels to ignore, such as "cy5".
+        no_rows: int
+            The number of rows of traps in the final concatenated image.
         """
         tps = np.arange(tpt_start, tpt_end + 1)
         if isinstance(trap_ids, int):
@@ -302,13 +319,18 @@ class ImageViewer:
             )
         # combine tiles
         ts_images, ts_labels = self.combine_tiles(
-            ts_images_dict, ts_labels_dict, no_vertical_tiles
+            ts_images_dict, ts_labels_dict, no_rows
         )
         # launch napari
         self.launch_napari(ts_images, ts_labels, channels)
 
-    def launch_napari(self, ts_images, ts_labels, channels):
-        """Use Napari to see the images and outlines."""
+    def launch_napari(
+        self,
+        ts_images: npt.NDArray[Any],
+        ts_labels: npt.NDArray[Any],
+        channels: List[str],
+    ):
+        """Call Napari viewer."""
         viewer = napari.Viewer()
         viewer.add_image(
             ts_images[:, channels.index("Brightfield"), ...],
@@ -328,7 +350,9 @@ class ImageViewer:
 
 
 ####
-def colormap(channel):
+
+
+def colormap(channel: str):
     """Find default colormap."""
     if "GFP" in channel:
         colormap = "green"