diff --git a/src/aliby/pipeline.py b/src/aliby/pipeline.py
index 6c70c575721d2ed0a00705599bbeafed07cd2605..2edcc52c55e5ed72b3bd8d9bc9b740847ed64397 100644
--- a/src/aliby/pipeline.py
+++ b/src/aliby/pipeline.py
@@ -474,7 +474,7 @@ class Pipeline(ProcessABC):
                                         and i == min_process_from
                                     ):
                                         logging.getLogger("aliby").info(
-                                            f"Found {steps['tiler'].n_traps} traps in {image.name}"
+                                            f"Found {steps['tiler'].n_tiles} traps in {image.name}"
                                         )
                                     elif step == "baby":
                                         # write state and pass info to ext (Alan: what's ext?)
diff --git a/src/aliby/tile/tiler.py b/src/aliby/tile/tiler.py
index 34aa89d1c2af76976f134a0fd38d58919cedc9bb..c91f412daa5bb51f23190e1a75fc2b9e3e3c3448 100644
--- a/src/aliby/tile/tiler.py
+++ b/src/aliby/tile/tiler.py
@@ -1,19 +1,15 @@
 """
-Tiler: Tiles and tracks traps.
+Tiler: Divides images into smaller tiles.
 
-The tasks of the Tiler are selecting regions of interest, or tiles, of an image - with one tile per trap, tracking and correcting for the drift of the microscope stage over time, and handling errors and bridging between the image data and ALIBY’s image-processing steps.
+The tasks of the Tiler are selecting regions of interest, or tiles, of images - with one trap per tile, correcting for the drift of the microscope stage over time, and handling errors and bridging between the image data and Aliby’s image-processing steps.
 
 Tiler subclasses deal with either network connections or local files.
 
-To find traps, we use a two-step process: we analyse the bright-field image to produce the template of a trap, and we fit this template to the image to find the traps' centres.
+To find tiles, we use a two-step process: we analyse the bright-field image to produce the template of a trap, and we fit this template to the image to find the tiles' centres.
 
 We use texture-based segmentation (entropy) to split the image into foreground -- cells and traps -- and background, which we then identify with an Otsu filter. Two methods are used to produce a template trap from these regions: pick the trap with the smallest minor axis length and average over all validated traps.
 
-A peak-identifying algorithm recovers the x and y-axis location of traps in the original image, and we choose the templating approach that identifies the most traps
-
-One key method is Tiler.run.
-
-The image-processing is performed by traps/segment_traps.
+A peak-identifying algorithm recovers the x and y-axis location of traps in the original image, and we choose the approach to template that identifies the most tiles.
 
 The experiment is stored as an array with a standard indexing order of (Time, Channels, Z-stack, X, Y).
 """
@@ -34,11 +30,12 @@ from aliby.io.image import ImageLocalOME, ImageDir, ImageDummy
 from aliby.tile.traps import segment_traps
 
 
-class Trap:
+class Tile:
     """
-    Stores a trap's location and size.
-    Allows checks to see if the trap should be padded.
-    Can export the trap either in OMERO or numpy formats.
+    Store a tile's location and size.
+
+    Checks to see if the tile should be padded.
+    Can export the tile either in OMERO or numpy formats.
     """
 
     def __init__(self, centre, parent, size, max_size):
@@ -50,31 +47,28 @@ class Trap:
 
     def at_time(self, tp: int) -> t.List[int]:
         """
-        Return trap centre at time tp by applying drifts
+        Return tile's centre by applying drifts.
 
         Parameters
         ----------
         tp: integer
-            Index for a time point
-
-        Returns
-        -------
-        trap_centre:
+            Index for the time point of interest.
         """
         drifts = self.parent.drifts
-        trap_centre = self.centre - np.sum(drifts[: tp + 1], axis=0)
-        return list(trap_centre.astype(int))
+        tile_centre = self.centre - np.sum(drifts[: tp + 1], axis=0)
+        return list(tile_centre.astype(int))
 
-    def as_tile(self, tp):
+    def as_tile(self, tp: int):
         """
-        Return trap in the OMERO tile format of x, y, w, h
-        where x, y are at the bottom left corner of the tile
+        Return tile in the OMERO tile format of x, y, w, h.
+
+        Here x, y are at the bottom left corner of the tile
         and w and h are the tile width and height.
 
         Parameters
         ----------
         tp: integer
-            Index for a time point
+            Index for the time point of interest.
 
         Returns
         -------
@@ -93,10 +87,10 @@ class Trap:
         y = int(y - self.half_size)
         return x, y, self.size, self.size
 
-    def as_range(self, tp):
+    def as_range(self, tp: int):
         """
-        Return trap in a range format: two slice objects that can
-        be used in arrays
+        Return tile in a range format: two slice objects that can
+        be used in arrays.
 
         Parameters
         ----------
@@ -112,11 +106,8 @@ class Trap:
         return slice(x, x + w), slice(y, y + h)
 
 
-class TrapLocations:
-    """
-    Stores each trap as an instance of Trap.
-    Traps can be iterated.
-    """
+class TileLocations:
+    """Store each tile as an instance of Tile."""
 
     def __init__(
         self,
@@ -130,29 +121,27 @@ class TrapLocations:
         self.tile_size = tile_size
         self.max_size = max_size
         self.initial_location = initial_location
-        self.traps = [
-            Trap(centre, self, tile_size or max_size, max_size)
+        self.tiles = [
+            Tile(centre, self, tile_size or max_size, max_size)
             for centre in initial_location
         ]
         self.drifts = drifts
 
     def __len__(self):
-        return len(self.traps)
+        return len(self.tiles)
 
     def __iter__(self):
-        yield from self.traps
+        yield from self.tiles
 
     @property
     def shape(self):
-        """
-        Returns no of traps and no of drifts
-        """
-        return len(self.traps), len(self.drifts)
+        """Return numbers of tiles and drifts."""
+        return len(self.tiles), len(self.drifts)
 
-    def to_dict(self, tp):
+    def to_dict(self, tp: int):
         """
-        Export inital locations, tile_size, max_size, and drifts
-        as a dictionary
+        Export initial locations, tile_size, max_size, and drifts
+        as a dictionary.
 
         Parameters
         ----------
@@ -168,47 +157,49 @@ class TrapLocations:
         return res
 
     def at_time(self, tp: int) -> np.ndarray:
-        # Returns ( ntraps, 2 ) ndarray with the trap centres as individual rows
-        return np.array([trap.at_time(tp) for trap in self.traps])
+        """Return an array of tile centres (x- and y-coords)."""
+        return np.array([tile.at_time(tp) for tile in self.tiles])
 
     @classmethod
     def from_tiler_init(
         cls, initial_location, tile_size: int = None, max_size: int = 1200
     ):
-        """
-        Instantiate class from an instance of the Tiler class
-        """
+        """Instantiate from a Tiler."""
         return cls(initial_location, tile_size, max_size, drifts=[])
 
     @classmethod
     def read_hdf5(cls, file):
-        """
-        Instantiate class from a hdf5 file
-        """
+        """Instantiate from a h5 file."""
         with h5py.File(file, "r") as hfile:
-            trap_info = hfile["trap_info"]
-            initial_locations = trap_info["trap_locations"][()]
-            drifts = trap_info["drifts"][()].tolist()
-            max_size = trap_info.attrs["max_size"]
-            tile_size = trap_info.attrs["tile_size"]
-        trap_locs = cls(initial_locations, tile_size, max_size=max_size)
-        trap_locs.drifts = drifts
-        return trap_locs
+            tile_info = hfile["trap_info"]
+            initial_locations = tile_info["trap_locations"][()]
+            drifts = tile_info["drifts"][()].tolist()
+            max_size = tile_info.attrs["max_size"]
+            tile_size = tile_info.attrs["tile_size"]
+        tile_loc_cls = cls(initial_locations, tile_size, max_size=max_size)
+        tile_loc_cls.drifts = drifts
+        return tile_loc_cls
 
 
 class TilerParameters(ParametersABC):
-    _defaults = {"tile_size": 117, "ref_channel": "Brightfield", "ref_z": 0}
+    """Set default parameters for Tiler."""
+
+    _defaults = {
+        "tile_size": 117,
+        "ref_channel": "Brightfield",
+        "ref_z": 0,
+    }
 
 
 class Tiler(StepABC):
     """
-    Remote Timelapse Tiler.
+    Divide images into smaller tiles for faster processing.
 
-    Finds traps and re-registers images if there is any drifting.
-    Fetches images from a server.
+    Finds tiles and re-registers images if they drift.
+    Fetch images from an OMERO server if necessary.
 
-    Uses an Image instance, which lazily provides the data on pixels, and, as
-    an independent argument, metadata.
+    Uses an Image instance, which lazily provides the data on pixels,
+    and, as an independent argument, metadata.
     """
 
     def __init__(
@@ -216,17 +207,17 @@ class Tiler(StepABC):
         image: da.core.Array,
         metadata: dict,
         parameters: TilerParameters,
-        trap_locs=None,
+        tile_locs=None,
     ):
         """
-        Initialise Tiler
+        Initialise.
 
         Parameters
         ----------
         image: an instance of Image
         metadata: dictionary
-        parameters: an instance of TilerPameters
-        trap_locs: (optional)
+        parameters: an instance of TilerParameters
+        tile_locs: (optional)
         """
         super().__init__(parameters)
         self.image = image
@@ -235,8 +226,7 @@ class Tiler(StepABC):
             "channels", list(range(metadata["size_c"]))
         )
         self.ref_channel = self.get_channel_index(parameters.ref_channel)
-
-        self.trap_locs = trap_locs
+        self.tile_locs = tile_locs
         try:
             self.z_perchannel = {
                 ch: zsect
@@ -244,24 +234,24 @@ class Tiler(StepABC):
             }
         except Exception as e:
             self._log(f"No z_perchannel data: {e}")
-
         self.tile_size = self.tile_size or min(self.image.shape[-2:])
 
     @classmethod
     def dummy(cls, parameters: dict):
         """
-        Instantiate dummy Tiler from dummy image
+        Instantiate dummy Tiler from dummy image.
 
         If image.dimorder exists dimensions are saved in that order.
         Otherwise default to "tczyx".
 
         Parameters
         ----------
-        parameters: dictionary output of an instance of TilerParameters
+        parameters: dict
+            An instance of TilerParameters converted to a dict.
         """
         imgdmy_obj = ImageDummy(parameters)
         dummy_image = imgdmy_obj.get_data_lazy()
-        # Default to "tczyx" if image.dimorder is None
+        # default to "tczyx" if image.dimorder is None
         dummy_omero_metadata = {
             f"size_{dim}": dim_size
             for dim, dim_size in zip(
@@ -277,7 +267,6 @@ class Tiler(StepABC):
                 "name": "",
             }
         )
-
         return cls(
             imgdmy_obj.data,
             dummy_omero_metadata,
@@ -287,7 +276,7 @@ class Tiler(StepABC):
     @classmethod
     def from_image(cls, image, parameters: TilerParameters):
         """
-        Instantiate Tiler from an Image instance
+        Instantiate from an Image instance.
 
         Parameters
         ----------
@@ -306,7 +295,7 @@ class Tiler(StepABC):
         parameters: TilerParameters = None,
     ):
         """
-        Instantiate Tiler from hdf5 files
+        Instantiate from h5 files.
 
         Parameters
         ----------
@@ -315,7 +304,7 @@ class Tiler(StepABC):
             Path to a directory of h5 files
         parameters: an instance of TileParameters (optional)
         """
-        trap_locs = TrapLocations.read_hdf5(filepath)
+        tile_locs = TileLocations.read_hdf5(filepath)
         metadata = BridgeH5(filepath).meta_h5
         metadata["channels"] = image.metadata["channels"]
         if parameters is None:
@@ -324,16 +313,17 @@ class Tiler(StepABC):
             image.data,
             metadata,
             parameters,
-            trap_locs=trap_locs,
+            tile_locs=tile_locs,
         )
-        if hasattr(trap_locs, "drifts"):
-            tiler.n_processed = len(trap_locs.drifts)
+        if hasattr(tile_locs, "drifts"):
+            tiler.n_processed = len(tile_locs.drifts)
         return tiler
 
     @lru_cache(maxsize=2)
-    def get_tc(self, t, c):
+    def get_tc(self, t: int, c: int):
         """
         Load image using dask.
+
         Assumes the image is arranged as
             no of time points
             no of channels
@@ -348,7 +338,7 @@ class Tiler(StepABC):
         c: integer
             An index for a channel
 
-        Retruns
+        Returns
         -------
         full: an array of images
         """
@@ -358,16 +348,13 @@ class Tiler(StepABC):
     @property
     def shape(self):
         """
-        Returns properties of the time-lapse as shown by self.image.shape
-
+        Return properties of the time-lapse as shown by self.image.shape
         """
         return self.image.shape
 
     @property
     def n_processed(self):
-        """
-        Returns the number of images that have been processed
-        """
+        """Return the number of processed images."""
         if not hasattr(self, "_n_processed"):
             self._n_processed = 0
         return self._n_processed
@@ -377,22 +364,21 @@ class Tiler(StepABC):
         self._n_processed = value
 
     @property
-    def n_traps(self):
-        """
-        Returns number of traps
-        """
-        return len(self.trap_locs)
+    def n_tiles(self):
+        """Return number of tiles."""
+        return len(self.tile_locs)
 
-    def initialise_traps(self, tile_size: int = None):
+    def initialise_tiles(self, tile_size: int = None):
         """
-        Find initial trap positions if they have not been initialised.
-        Removes all those that are too close to the edge so no padding
-        is necessary.
+        Find initial positions of tiles.
+
+        Remove tiles that are too close to the edge of the image
+        so no padding is necessary.
 
         Parameters
         ----------
         tile_size: integer
-            The size of a tile
+            The size of a tile.
         """
         initial_image = self.image[0, self.ref_channel, self.ref_z]
         if tile_size:
@@ -400,27 +386,27 @@ class Tiler(StepABC):
             # max_size is the minimal number of x or y pixels
             max_size = min(self.image.shape[-2:])
             # first time point, reference channel, reference z-position
-            # find the traps
-            trap_locs = segment_traps(initial_image, tile_size)
-            # keep only traps that are not near an edge
-            trap_locs = [
+            # find the tiles
+            tile_locs = segment_traps(initial_image, tile_size)
+            # keep only tiles that are not near an edge
+            tile_locs = [
                 [x, y]
-                for x, y in trap_locs
+                for x, y in tile_locs
                 if half_tile < x < max_size - half_tile
                 and half_tile < y < max_size - half_tile
             ]
-            # store traps in an instance of TrapLocations
-            self.trap_locs = TrapLocations.from_tiler_init(
-                trap_locs, tile_size
+            # store tiles in an instance of TileLocations
+            self.tile_locs = TileLocations.from_tiler_init(
+                tile_locs, tile_size
             )
         else:
             yx_shape = self.image.shape[-2:]
-            trap_locs = [[x // 2 for x in yx_shape]]
-            self.trap_locs = TrapLocations.from_tiler_init(
-                trap_locs, max_size=min(yx_shape)
+            tile_locs = [[x // 2 for x in yx_shape]]
+            self.tile_locs = TileLocations.from_tiler_init(
+                tile_locs, max_size=min(yx_shape)
             )
 
-    def find_drift(self, tp):
+    def find_drift(self, tp: int):
         """
         Find any translational drift between two images at consecutive
         time points using cross correlation.
@@ -428,7 +414,7 @@ class Tiler(StepABC):
         Arguments
         ---------
         tp: integer
-            Index for a time point
+            Index for a time point.
         """
         prev_tp = max(0, tp - 1)
         # cross-correlate
@@ -437,14 +423,14 @@ class Tiler(StepABC):
             self.image[tp, self.ref_channel, self.ref_z],
         )
         # store drift
-        if 0 < tp < len(self.trap_locs.drifts):
-            self.trap_locs.drifts[tp] = drift.tolist()
+        if 0 < tp < len(self.tile_locs.drifts):
+            self.tile_locs.drifts[tp] = drift.tolist()
         else:
-            self.trap_locs.drifts.append(drift.tolist())
+            self.tile_locs.drifts.append(drift.tolist())
 
     def get_tp_data(self, tp, c):
         """
-        Returns all traps corrected for drift.
+        Return all tiles corrected for drift.
 
         Parameters
         ----------
@@ -453,41 +439,42 @@ class Tiler(StepABC):
         c: integer
             An index for a channel
         """
-        traps = []
+        tiles = []
         # get OMERO image
         full = self.get_tc(tp, c)
-        for trap in self.trap_locs:
-            # pad trap if necessary
-            ndtrap = self.ifoob_pad(full, trap.as_range(tp))
-            traps.append(ndtrap)
-        return np.stack(traps)
+        for tile in self.tile_locs:
+            # pad tile if necessary
+            ndtile = self.ifoob_pad(full, tile.as_range(tp))
+            tiles.append(ndtile)
+        return np.stack(tiles)
 
-    def get_trap_data(self, trap_id, tp, c):
+    def get_tile_data(self, tile_id: int, tp: int, c: int):
         """
-        Returns a particular trap corrected for drift and padding
+        Return a particular tile corrected for drift and padding.
 
         Parameters
         ----------
-        trap_id: integer
-            Number of trap
+        tile_id: integer
+            Number of tile.
         tp: integer
-            Index of time points
+            Index of time points.
         c: integer
-            Index of channel
+            Index of channel.
 
         Returns
         -------
-        ndtrap: array
+        ndtile: array
             An array of (x, y) arrays, one for each z stack
         """
         full = self.get_tc(tp, c)
-        trap = self.trap_locs.traps[trap_id]
-        ndtrap = self.ifoob_pad(full, trap.as_range(tp))
-        return ndtrap
+        tile = self.tile_locs.tiles[tile_id]
+        ndtile = self.ifoob_pad(full, tile.as_range(tp))
+        return ndtile
 
-    def _run_tp(self, tp):
+    def _run_tp(self, tp: int):
         """
-        Find traps if they have not yet been found.
+        Find tiles if they have not yet been found.
+
         Determine any translational drift of the current image from the
         previous one.
 
@@ -498,10 +485,10 @@ class Tiler(StepABC):
         """
         # assert tp >= self.n_processed, "Time point already processed"
         # TODO check contiguity?
-        if self.n_processed == 0 or not hasattr(self.trap_locs, "drifts"):
-            self.initialise_traps(self.tile_size)
-        if hasattr(self.trap_locs, "drifts"):
-            drift_len = len(self.trap_locs.drifts)
+        if self.n_processed == 0 or not hasattr(self.tile_locs, "drifts"):
+            self.initialise_tiles(self.tile_size)
+        if hasattr(self.tile_locs, "drifts"):
+            drift_len = len(self.tile_locs.drifts)
             if self.n_processed != drift_len:
                 warnings.warn("Tiler:n_processed and ndrifts don't match")
                 self.n_processed = drift_len
@@ -510,7 +497,7 @@ class Tiler(StepABC):
         # update n_processed
         self.n_processed = tp + 1
         # return result for writer
-        return self.trap_locs.to_dict(tp)
+        return self.tile_locs.to_dict(tp)
 
     def run(self, time_dim=None):
         """
@@ -524,14 +511,13 @@ class Tiler(StepABC):
 
     def get_traps_timepoint(self, *args, **kwargs):
         self._log(
-            "get_trap_timepoints is deprecated; get_tiles_timepoint instead."
+            "get_traps_timepoint is deprecated; get_tiles_timepoint instead."
         )
-
         return self.get_tiles_timepoint(*args, **kwargs)
 
     # The next set of functions are necessary for the extraction object
     def get_tiles_timepoint(
-        self, tp, tile_shape=None, channels=None, z: int = 0
+        self, tp: int, tile_shape=None, channels=None, z: int = 0
     ) -> np.ndarray:
         """
         Get a multidimensional array with all tiles for a set of channels
@@ -553,10 +539,10 @@ class Tiler(StepABC):
         Returns
         -------
         res: array
-            Data arranged as (traps, channels, timepoints, X, Y, Z)
+            Data arranged as (tiles, channels, timepoints, X, Y, Z)
         """
-        # FIXME add support for subtiling trap
-        # FIXME can we ignore z(always  give)
+        # FIXME add support for sub-tiling a tile
+        # FIXME can we ignore z
         if channels is None:
             channels = [0]
         elif isinstance(channels, str):
@@ -566,8 +552,8 @@ class Tiler(StepABC):
         for c in channels:
             # only return requested z
             val = self.get_tp_data(tp, c)[:, z]
-            # starts with the order: traps, z, y, x
-            # returns the order: trap, C, T, X, Y, Z
+            # starts with the order: tiles, z, y, x
+            # returns the order: tiles, C, T, X, Y, Z
             val = val.swapaxes(1, 3).swapaxes(1, 2)
             val = np.expand_dims(val, axis=1)
             res.append(val)
@@ -584,16 +570,19 @@ class Tiler(StepABC):
 
     @property
     def ref_channel_index(self):
+        """Return index of reference channel."""
         return self.get_channel_index(self.parameters.ref_channel)
 
     def get_channel_index(self, channel: str or int):
         """
-        Find index for channel using regex. Returns the first matched string.
+        Find index for channel using regex.
+
+        Returns the first matched string.
 
         Parameters
         ----------
         channel: string or int
-            The channel or index to be used
+            The channel or index to be used.
         """
         if isinstance(channel, str):
             channel = find_channel_index(self.channels, channel)
@@ -606,7 +595,7 @@ class Tiler(StepABC):
     @staticmethod
     def ifoob_pad(full, slices):
         """
-        Returns the slices padded if it is out of bounds.
+        Return the slices padded if out of bounds.
 
         Parameters
         ----------
@@ -614,11 +603,11 @@ class Tiler(StepABC):
             Slice of OMERO image (zstacks, x, y) - the entire position
             with zstacks as first axis
         slices: tuple of two slices
-            Delineates indiceds for the x- and y- ranges of the tile.
+            Delineates indices for the x- and y- ranges of the tile.
 
         Returns
         -------
-        trap: array
+        tile: array
             A tile with all z stacks for the given slices.
             If some padding is needed, the median of the image is used.
             If much padding is needed, a tile of NaN is returned.
@@ -628,7 +617,7 @@ class Tiler(StepABC):
         # ignore parts of the tile outside of the image
         y, x = [slice(max(0, s.start), min(max_size, s.stop)) for s in slices]
         # get the tile including all z stacks
-        trap = full[:, y, x]
+        tile = full[:, y, x]
         # find extent of padding needed in x and y
         padding = np.array(
             [(-min(0, s.start), -min(0, max_size - s.stop)) for s in slices]
@@ -638,13 +627,15 @@ class Tiler(StepABC):
             if (padding > tile_size / 4).any():
                 # too much of the tile is outside of the image
                 # fill with NaN
-                trap = np.full((full.shape[0], tile_size, tile_size), np.nan)
+                tile = np.full((full.shape[0], tile_size, tile_size), np.nan)
             else:
-                # pad tile with median value of trap image
-                trap = np.pad(trap, [[0, 0]] + padding.tolist(), "median")
-        return trap
+                # pad tile with median value of the tile
+                tile = np.pad(tile, [[0, 0]] + padding.tolist(), "median")
+        return tile
 
 
+# Alan: do we need these as well as get_channel_index and get_channel_name?
+# self._log below is not defined
 def find_channel_index(image_channels: t.List[str], channel: str):
     """
     Access
@@ -659,7 +650,14 @@ def find_channel_index(image_channels: t.List[str], channel: str):
 
 def find_channel_name(image_channels: t.List[str], channel: str):
     """
-    Find the name of the channel according to a given channel regex.
+    Find the name of the channel using regex.
+
+    Parameters
+    ----------
+    image_channels: list of str
+        Channels.
+    channel: str
+        A regular expression.
     """
     index = find_channel_index(image_channels, channel)
     if index is not None:
diff --git a/src/aliby/tile/traps.py b/src/aliby/tile/traps.py
index 4eddeb7e45a0f39ea0de28c865b79b685061da5b..65b21b9ce744d59bd87b2232b12b6971a05a8395 100644
--- a/src/aliby/tile/traps.py
+++ b/src/aliby/tile/traps.py
@@ -1,7 +1,4 @@
-"""
-A set of utilities for dealing with ALCATRAS traps
-"""
-
+"""Functions for identifying and dealing with ALCATRAS traps."""
 
 import numpy as np
 from skimage import feature, transform
@@ -31,10 +28,10 @@ def segment_traps(
     **identify_traps_kwargs,
 ):
     """
-    Uses an entropy filter and Otsu thresholding to find a trap template,
+    Use an entropy filter and Otsu thresholding to find a trap template,
     which is then passed to identify_trap_locations.
 
-    To obtain candidate traps it the major axis length of a tile must be smaller than tilesize.
+    To obtain candidate traps the major axis length of a tile must be smaller than tilesize.
 
     The hyperparameters have not been optimised.
 
@@ -60,7 +57,7 @@ def segment_traps(
     Returns
     -------
     traps: an array of pairs of integers
-        The coordinates of the centroids of the traps
+        The coordinates of the centroids of the traps.
     """
     # keep a memory of image in case need to re-run
     img = image
@@ -144,17 +141,18 @@ def identify_trap_locations(
     image, trap_template, optimize_scale=True, downscale=0.35, trap_size=None
 ):
     """
-    Identify the traps in a single image based on a trap template,
-    which requires the trap template to be similar to the image
-    (same camera, same magification - ideally the same experiment).
+    Identify the traps in a single image based on a trap template.
+
+    Requires the trap template to be similar to the image
+    (same camera, same magnification - ideally the same experiment).
 
-    Uses normalised correlation in scikit-image's match_template.
+    Use normalised correlation in scikit-image's to match_template.
 
-    The search is speeded up by downscaling both the image and
+    The search is sped up by down-scaling both the image and
     the trap template before running the template matching.
 
     The trap template is rotated and re-scaled to improve matching.
-    The parameters of the rotation and rescaling are optimised, although
+    The parameters of the rotation and re-scaling are optimised, although
     over restricted ranges.
 
     Parameters
@@ -243,4 +241,4 @@ def stretch_image(image):
     maxval = np.percentile(image, 98)
     image = np.clip(image, minval, maxval)
     image = (image - minval) / (maxval - minval)
-    return image
\ No newline at end of file
+    return image