diff --git a/src/aliby/pipeline.py b/src/aliby/pipeline.py
index b58302c3e69d64b3b84c0fd0a389425eb5c5dbf4..e67c6e6f659fee169c8325a125d0f23a6a11b460 100644
--- a/src/aliby/pipeline.py
+++ b/src/aliby/pipeline.py
@@ -141,7 +141,7 @@ class PipelineParameters(ParametersABC):
         # define defaults and update with any inputs
         defaults["tiler"] = TilerParameters.default(**tiler).to_dict()
 
-        # Generate a backup channel, for when logfile meta is available
+        # generate a backup channel, for when logfile meta is available
         # but not image metadata.
         backup_ref_channel = None
         if "channels" in meta_d and isinstance(
@@ -384,7 +384,7 @@ class Pipeline(ProcessABC):
         session = None
         filename = None
         #
-        run_kwargs = {"extraction": {"labels": None, "masks": None}}
+        run_kwargs = {"extraction": {"cell_labels": None, "masks": None}}
         try:
             (
                 filename,
@@ -507,7 +507,7 @@ class Pipeline(ProcessABC):
                                         )
                                     elif step == "extraction":
                                         # remove masks and labels after extraction
-                                        for k in ["masks", "labels"]:
+                                        for k in ["masks", "cell_labels"]:
                                             run_kwargs[step][k] = None
                             # check and report clogging
                             frac_clogged_traps = self.check_earlystop(
diff --git a/src/aliby/tile/tiler.py b/src/aliby/tile/tiler.py
index 27c1e814d4682228149fc0343d9a1a564cc7ad14..be01900ee28c521e4f9f421dedf4d1474e39594c 100644
--- a/src/aliby/tile/tiler.py
+++ b/src/aliby/tile/tiler.py
@@ -344,7 +344,7 @@ class Tiler(StepABC):
         return tiler
 
     @lru_cache(maxsize=2)
-    def get_tc(self, t: int, c: int) -> np.ndarray:
+    def get_tc(self, tp: int, c: int) -> np.ndarray:
         """
         Load image using dask.
 
@@ -357,7 +357,7 @@ class Tiler(StepABC):
 
         Parameters
         ----------
-        t: integer
+        tp: integer
             An index for a time point
         c: integer
             An index for a channel
@@ -366,7 +366,7 @@ class Tiler(StepABC):
         -------
         full: an array of images
         """
-        full = self.image[t, c]
+        full = self.image[tp, c]
         if hasattr(full, "compute"):
             # if using dask fetch images
             full = full.compute(scheduler="synchronous")
@@ -570,9 +570,8 @@ class Tiler(StepABC):
         Returns
         -------
         res: array
-            Data arranged as (tiles, channels, time points, X, Y, Z)
+            Data arranged as (tiles, channels, Z, X, Y)
         """
-        # FIXME add support for sub-tiling a tile
         # FIXME can we ignore z
         if channels is None:
             channels = [0]
@@ -583,8 +582,7 @@ class Tiler(StepABC):
         for c in channels:
             # only return requested z
             val = self.get_tp_data(tp, c)[:, z]
-            # starts with the order: tiles, z, y, x
-            # returns the order: tiles, C, T, Z, X, Y
+            # starts with the order: tiles, Z, Y, X
             val = np.expand_dims(val, axis=1)
             res.append(val)
         if tile_shape is not None:
@@ -596,7 +594,10 @@ class Tiler(StepABC):
                     for tile_size, ax in zip(tile_shape, res[0].shape[-3:-2])
                 ]
             )
-        return np.stack(res, axis=1)
+        # convert to array with channels as first column
+        # final has dimensions (tiles, channels, 1, Z, X, Y)
+        final = np.stack(res, axis=1)
+        return final
 
     @property
     def ref_channel_index(self):
diff --git a/src/extraction/core/extractor.py b/src/extraction/core/extractor.py
index cebb0b2a1d9229174c18f1ea9a9f5cb0a80383cf..ad42f9b9620e5d711d234c298d718a504d9206ad 100644
--- a/src/extraction/core/extractor.py
+++ b/src/extraction/core/extractor.py
@@ -26,14 +26,14 @@ extraction_result = t.Dict[
     str, t.Dict[reduction_method, t.Dict[str, t.Dict[str, pd.Series]]]
 ]
 
-# Global variables used to load functions that either analyse cells or their background. These global variables both allow the functions to be stored in a dictionary for access only on demand and to be defined simply in extraction/core/functions.
+# Global variables used to load functions that either analyse cells
+# or their background. These global variables both allow the functions
+# to be stored in a dictionary for access only on demand and to be
+# defined simply in extraction/core/functions.
 CELL_FUNS, TRAPFUNS, FUNS = load_funs()
 CUSTOM_FUNS, CUSTOM_ARGS = load_custom_args()
 RED_FUNS = load_redfuns()
 
-# Assign datatype depending on the metric used
-# m2type = {"mean": np.float32, "median": np.ubyte, "imBackground": np.ubyte}
-
 
 class ExtractorParameters(ParametersABC):
     """Base class to define parameters for extraction."""
@@ -74,13 +74,19 @@ class Extractor(StepABC):
     """
     Apply a metric to cells identified in the tiles.
 
-    Using the cell masks, the Extractor applies a metric, such as area or median, to cells identified in the image tiles.
+    Using the cell masks, the Extractor applies a metric, such as
+    area or median, to cells identified in the image tiles.
 
     Its methods require both tile images and masks.
 
-    Usually the metric is applied to only a tile's masked area, but some metrics depend on the whole tile.
+    Usually the metric is applied to only a tile's masked area, but
+    some metrics depend on the whole tile.
 
-    Extraction follows a three-level tree structure. Channels, such as GFP, are the root level; the reduction algorithm, such as maximum projection, is the second level; the specific metric, or operation, to apply to the masks, such as mean, is the third or leaf level.
+    Extraction follows a three-level tree structure. Channels, such
+    as GFP, are the root level; the reduction algorithm, such as
+    maximum projection, is the second level; the specific metric,
+    or operation, to apply to the masks, such as mean, is the third
+    or leaf level.
     """
 
     # TODO Alan: Move this to a location with the SwainLab defaults
@@ -107,7 +113,8 @@ class Extractor(StepABC):
         store: str
             Path to the h5 file containing the cell masks.
         tiler: pipeline-core.core.segmentation tiler
-            Class that contains or fetches the images used for segmentation.
+            Class that contains or fetches the images used for
+            segmentation.
         """
         self.params = parameters
         if store:
@@ -161,13 +168,16 @@ class Extractor(StepABC):
 
     def load_custom_funs(self):
         """
-        Incorporate the extra arguments of custom functions into their definitions.
+        Incorporate the extra arguments of custom functions into their
+        definitions.
 
         Normal functions only have cell_masks and trap_image as their
         arguments, and here custom functions are made the same by
         setting the values of their extra arguments.
 
-        Any other parameters are taken from the experiment's metadata and automatically applied. These parameters therefore must be loaded within an Extractor instance.
+        Any other parameters are taken from the experiment's metadata
+        and automatically applied. These parameters therefore must be
+        loaded within an Extractor instance.
         """
         # find functions specified in params.tree
         funs = set(
@@ -222,7 +232,8 @@ class Extractor(StepABC):
         """
         Find tiles for a given time point, channels, and z-stacks.
 
-        Any additional keyword arguments are passed to tiler.get_tiles_timepoint
+        Any additional keyword arguments are passed to
+        tiler.get_tiles_timepoint
 
         Parameters
         ----------
@@ -243,8 +254,9 @@ class Extractor(StepABC):
             # a list of the indices of the z stacks
             channel_ids = None
         if z is None:
-            # gets the tiles data via tiler
+            # include all Z channels
             z = list(range(self.tiler.shape[-3]))
+        # get the image data via tiler
         res = (
             self.tiler.get_tiles_timepoint(
                 tp, channels=channel_ids, z=z, **kwargs
@@ -252,7 +264,7 @@ class Extractor(StepABC):
             if channel_ids
             else None
         )
-        # data arranged as (tiles, channels, time points, X, Y, Z)
+        # res has dimensions (tiles, channels, 1, Z, X, Y)
         return res
 
     def extract_traps(
@@ -260,7 +272,7 @@ class Extractor(StepABC):
         traps: t.List[np.ndarray],
         masks: t.List[np.ndarray],
         metric: str,
-        labels: t.Dict[int, t.List[int]],
+        cell_labels: t.Dict[int, t.List[int]],
     ) -> t.Tuple[t.Union[t.Tuple[float], t.Tuple[t.Tuple[int]]]]:
         """
         Apply a function to a whole position.
@@ -273,23 +285,25 @@ class Extractor(StepABC):
             t.List of masks.
         metric: str
             Metric to extract.
-        labels: dict
-            A dict of cell labels with trap_ids as keys and a list of cell labels as values.
+        cell_labels: dict
+            A dict of cell labels with trap_ids as keys and a list
+            of cell labels as values.
         pos_info: bool
             Whether to add the position as an index or not.
 
         Returns
         -------
         res_idx: a tuple of tuples
-            A two-tuple comprising a tuple of results and a tuple of the tile_id and cell labels
+            A two-tuple comprising a tuple of results and a tuple of
+            the tile_id and cell labels
         """
-        if labels is None:
-            self._log("No labels given. Sorting cells using index.")
+        if cell_labels is None:
+            self._log("No cell labels given. Sorting cells using index.")
         cell_fun = True if metric in self._all_cell_funs else False
         idx = []
         results = []
         for trap_id, (mask_set, trap, lbl_set) in enumerate(
-            zip(masks, traps, labels.values())
+            zip(masks, traps, cell_labels.values())
         ):
             # ignore empty traps
             if len(mask_set):
@@ -344,7 +358,9 @@ class Extractor(StepABC):
         masks: list of arrays
             An array of masks for each trap: one per cell at the trap
         red_metrics: dict
-            dict for which keys are reduction functions and values are either a list or a set of strings giving the metric functions.
+            dict for which keys are reduction functions and values are
+            either a list or a set of strings giving the metric
+            functions.
             For example: {'np_max': {'max5px', 'mean', 'median'}}
         **kwargs: dict
             All other arguments passed to Extractor.extract_funs.
@@ -353,7 +369,8 @@ class Extractor(StepABC):
         ------
         Dict of dataframes with the corresponding reductions and metrics nested.
         """
-        # create dict with keys naming the reduction in the z-direction and the reduced data as values
+        # create dict with keys naming the reduction in the z-direction
+        # and the reduced data as values
         reduced_tiles_data = {}
         if traps is not None:
             for red_fun in red_metrics.keys():
@@ -392,64 +409,24 @@ class Extractor(StepABC):
             reduced = reduce_z(img, method)
         return reduced
 
-    def extract_tp(
-        self,
-        tp: int,
-        tree: t.Optional[extraction_tree] = None,
-        tile_size: int = 117,
-        masks: t.Optional[t.List[np.ndarray]] = None,
-        labels: t.Optional[t.List[int]] = None,
-        **kwargs,
-    ) -> t.Dict[str, t.Dict[str, t.Dict[str, tuple]]]:
-        """
-        Extract for an individual time point.
-
-        Parameters
-        ----------
-        tp : int
-            Time point being analysed.
-        tree : dict
-            Nested dictionary indicating channels, reduction functions and
-            metrics to be used.
-            For example: {'general': {'None': ['area', 'volume', 'eccentricity']}}
-        tile_size : int
-            Size of the tile to be extracted.
-        masks : list of arrays
-            A list of masks per trap with each mask having dimensions (ncells, tile_size,
-            tile_size).
-        labels : dict
-            A dictionary with trap_ids as keys and cell_labels as values.
-        **kwargs : keyword arguments
-            Passed to extractor.reduce_extract.
-
-        Returns
-        -------
-        d: dict
-            Dictionary of the results with three levels of dictionaries.
-            The first level has channels as keys.
-            The second level has reduction metrics as keys.
-            The third level has cell or background metrics as keys and a two-tuple as values.
-            The first tuple is the result of applying the metrics to a particular cell or trap; the second tuple is either (trap_id, cell_label) for a metric applied to a cell or a trap_id for a metric applied to a trap.
-
-            An example is d["GFP"]["np_max"]["mean"][0], which gives a tuple of the calculated mean GFP fluorescence for all cells.
-        """
-        # TODO Can we split the different extraction types into sub-methods to make this easier to read?
+    def make_tree_bits(self, tree):
+        """Put extraction tree and information for the channels into a dict."""
         if tree is None:
             # use default
             tree: extraction_tree = self.params.tree
-        # dictionary with channel: {reduction algorithm : metric}
-        ch_tree = {ch: v for ch, v in tree.items() if ch != "general"}
-        # tuple of the channels
-        tree_chs = (*ch_tree,)
-        # create a Cells object to extract information from the h5 file
-        cells = Cells(self.local)
-        # find the cell labels and store as dict with trap_ids as keys
-        if labels is None:
-            raw_labels = cells.labels_at_time(tp)
-            labels = {
-                trap_id: raw_labels.get(trap_id, [])
-                for trap_id in range(cells.ntraps)
-            }
+        tree_bits = {
+            "tree": tree,
+            # dictionary with channel: {reduction algorithm : metric}
+            "channel_tree": {
+                ch: v for ch, v in tree.items() if ch != "general"
+            },
+        }
+        # tuple of the fluorescence channels
+        tree_bits["tree_channels"] = (*tree_bits["channel_tree"],)
+        return tree_bits
+
+    def get_masks(self, tp, masks, cells):
+        """Get the masks as a list with an array of masks for each trap."""
         # find the cell masks for a given trap as a dict with trap_ids as keys
         if masks is None:
             raw_masks = cells.at_time(tp, kind="mask")
@@ -458,16 +435,31 @@ class Extractor(StepABC):
                 if len(cells):
                     masks[trap_id] = np.stack(np.array(cells)).astype(bool)
         # convert to a list of masks
+        # one array of size (no cells, tile_size, tile_size) per trap
         masks = [np.array(v) for v in masks.values()]
-        # find image data at the time point
-        # stored as an array arranged as (traps, channels, time points, X, Y, Z)
-        tiles = self.get_tiles(tp, tile_shape=tile_size, channels=tree_chs)
-        # generate boolean masks for background as a list with one mask per trap
-        bgs = np.array([])
+        return masks
+
+    def get_cell_labels(self, tp, cell_labels, cells):
+        """Get the cell labels per trap as a dict with trap_ids as keys."""
+        if cell_labels is None:
+            raw_cell_labels = cells.labels_at_time(tp)
+            cell_labels = {
+                trap_id: raw_cell_labels.get(trap_id, [])
+                for trap_id in range(cells.ntraps)
+            }
+        return cell_labels
+
+    def get_background_masks(self, masks, tile_size):
+        """
+        Generate boolean background masks.
+
+        Combine masks per trap and then take the logical inverse.
+        """
         if self.params.sub_bg:
             bgs = ~np.array(
                 list(
                     map(
+                        # sum over masks for each cell
                         lambda x: np.sum(x, axis=0)
                         if np.any(x)
                         else np.zeros((tile_size, tile_size)),
@@ -475,15 +467,29 @@ class Extractor(StepABC):
                     )
                 )
             ).astype(bool)
-        # perform extraction by applying metrics
+        else:
+            bgs = np.array([])
+        return bgs
+
+    def extract_one_channel(
+        self, tree_bits, cell_labels, tiles, masks, bgs, **kwargs
+    ):
+        """
+        Extract using all metrics requiring a single channel.
+
+        Apply first without and then with background subtraction.
+
+        Return the extraction results and a dict of background corrected images.
+        """
         d = {}
-        self.img_bgsub = {}
-        for ch, red_metrics in tree.items():
+        img_bgsub = {}
+        for ch, red_metrics in tree_bits["tree"].items():
             # NB ch != is necessary for threading
             if ch != "general" and tiles is not None and len(tiles):
-                # image data for all traps and z sections for a particular channel
-                # as an array arranged as (tiles, Z, X, Y, )
-                img = tiles[:, tree_chs.index(ch), 0]
+                # image data for all traps for a particular channel and time point
+                # arranged as (traps, Z, X, Y)
+                # we use 0 here to access the single time point available
+                img = tiles[:, tree_bits["tree_channels"].index(ch), 0]
             else:
                 img = None
             # apply metrics to image data
@@ -491,37 +497,43 @@ class Extractor(StepABC):
                 traps=img,
                 masks=masks,
                 red_metrics=red_metrics,
-                labels=labels,
+                cell_labels=cell_labels,
                 **kwargs,
             )
             # apply metrics to image data with the background subtracted
             if bgs.any() and ch in self.params.sub_bg and img is not None:
-                # calculate metrics with subtracted bg
+                # calculate metrics with background subtracted
                 ch_bs = ch + "_bgsub"
                 # subtract median background
-                self.img_bgsub[ch_bs] = np.moveaxis(
-                    np.stack(
-                        list(
-                            map(
-                                lambda tile, mask: np.moveaxis(tile, 0, -1)
-                                - bn.median(tile[:, mask], axis=1),
-                                img,
-                                bgs,
-                            )
-                        )
-                    ),
-                    -1,
-                    1,
-                )  # End with tiles, z, y, x
+                bgsub_mapping = map(
+                    # move time to last column to allow subtraction
+                    lambda img, bgs: np.moveaxis(img, 0, -1)
+                    # median of background over all pixels for each time point
+                    - bn.median(img[:, bgs], axis=1),
+                    img,
+                    bgs,
+                )
+                # apply map and convert to array
+                mapping_result = np.stack(list(bgsub_mapping))
+                # move time axis back to the second column
+                img_bgsub[ch_bs] = np.moveaxis(mapping_result, -1, 1)
                 # apply metrics to background-corrected data
                 d[ch_bs] = self.reduce_extract(
-                    red_metrics=ch_tree[ch],
-                    traps=self.img_bgsub[ch_bs],
+                    red_metrics=tree_bits["channel_tree"][ch],
+                    traps=img_bgsub[ch_bs],
                     masks=masks,
-                    labels=labels,
+                    cell_labels=cell_labels,
                     **kwargs,
                 )
-        # apply any metrics using multiple channels, such as pH calculations
+        return d, img_bgsub
+
+    def extract_multiple_channels(
+        self, tree_bits, cell_labels, tiles, masks, **kwargs
+    ):
+        """
+        Extract using all metrics requiring multiple channels.
+        """
+        d = {}
         for name, (
             chs,
             merge_fun,
@@ -529,22 +541,99 @@ class Extractor(StepABC):
         ) in self.params.multichannel_ops.items():
             if len(
                 set(chs).intersection(
-                    set(self.img_bgsub.keys()).union(tree_chs)
+                    set(self.img_bgsub.keys()).union(
+                        tree_bits["tree_channels"]
+                    )
                 )
             ) == len(chs):
                 channels_stack = np.stack(
-                    [self.get_imgs(ch, tiles, tree_chs) for ch in chs], axis=-1
+                    [
+                        self.get_imgs(ch, tiles, tree_bits["tree_channels"])
+                        for ch in chs
+                    ],
+                    axis=-1,
                 )
                 merged = RED_FUNS[merge_fun](channels_stack, axis=-1)
                 d[name] = self.reduce_extract(
                     red_metrics=red_metrics,
                     traps=merged,
                     masks=masks,
-                    labels=labels,
+                    cell_labels=cell_labels,
                     **kwargs,
                 )
         return d
 
+    def extract_tp(
+        self,
+        tp: int,
+        tree: t.Optional[extraction_tree] = None,
+        tile_size: int = 117,
+        masks: t.Optional[t.List[np.ndarray]] = None,
+        cell_labels: t.Optional[t.List[int]] = None,
+        **kwargs,
+    ) -> t.Dict[str, t.Dict[str, t.Dict[str, tuple]]]:
+        """
+        Extract for an individual time point.
+
+        Parameters
+        ----------
+        tp : int
+            Time point being analysed.
+        tree : dict
+            Nested dictionary indicating channels, reduction functions
+            and metrics to be used.
+            For example: {'general': {'None': ['area', 'volume', 'eccentricity']}}
+        tile_size : int
+            Size of the tile to be extracted.
+        masks : list of arrays
+            A list of masks per trap with each mask having dimensions
+            (ncells, tile_size, tile_size) and with one mask per cell.
+        cell_labels : dict
+            A dictionary with trap_ids as keys and cell_labels as values.
+        **kwargs : keyword arguments
+            Passed to extractor.reduce_extract.
+
+        Returns
+        -------
+        d: dict
+            Dictionary of the results with three levels of dictionaries.
+            The first level has channels as keys.
+            The second level has reduction metrics as keys.
+            The third level has cell or background metrics as keys and a
+            two-tuple as values.
+            The first tuple is the result of applying the metrics to a
+            particular cell or trap; the second tuple is either
+            (trap_id, cell_label) for a metric applied to a cell or a
+            trap_id for a metric applied to a trap.
+
+            An example is d["GFP"]["np_max"]["mean"][0], which gives a tuple
+            of the calculated mean GFP fluorescence for all cells.
+        """
+        # dict of information from extraction tree
+        tree_bits = self.make_tree_bits(tree)
+        # create a Cells object to extract information from the h5 file
+        cells = Cells(self.local)
+        # find the cell labels as dict with trap_ids as keys
+        cell_labels = self.get_cell_labels(tp, cell_labels, cells)
+        # get masks one per cell per trap
+        masks = self.get_masks(tp, masks, cells)
+        # find image data at the time point
+        # stored as an array arranged as (traps, channels, 1, Z, X, Y)
+        tiles = self.get_tiles(
+            tp, tile_shape=tile_size, channels=tree_bits["tree_channels"]
+        )
+        # generate boolean masks for background for each trap
+        bgs = self.get_background_masks(masks, tile_size)
+        # perform extraction
+        res_one, self.img_bgsub = self.extract_one_channel(
+            tree_bits, cell_labels, tiles, masks, bgs, **kwargs
+        )
+        res_two = self.extract_multiple_channels(
+            tree_bits, cell_labels, tiles, masks, **kwargs
+        )
+        res = {**res_one, **res_two}
+        return res
+
     def get_imgs(self, channel: t.Optional[str], tiles, channels=None):
         """
         Return image from a correct source, either raw or bgsub.
@@ -554,14 +643,16 @@ class Extractor(StepABC):
         channel: str
             Name of channel to get.
         tiles: ndarray
-            An array of the image data having dimensions of (tile_id, channel, tp, tile_size, tile_size, n_zstacks).
+            An array of the image data having dimensions of
+            (tile_id, channel, tp, tile_size, tile_size, n_zstacks).
         channels: list of str (optional)
             t.List of available channels.
 
         Returns
         -------
         img: ndarray
-            An array of image data with dimensions (no tiles, X, Y, no Z channels)
+            An array of image data with dimensions
+            (no tiles, X, Y, no Z channels)
         """
         if channels is None:
             channels = (*self.params.tree,)
@@ -598,7 +689,9 @@ class Extractor(StepABC):
         Returns
         -------
         d: dict
-            A dict of the extracted data for one position with a concatenated string of channel, reduction metric, and cell metric as keys and pd.DataFrame of the extracted data for all time points as values.
+            A dict of the extracted data for one position with a concatenated
+            string of channel, reduction metric, and cell metric as keys and
+            pd.DataFrame of the extracted data for all time points as values.
         """
         if tree is None:
             tree = self.params.tree
@@ -673,14 +766,17 @@ def flatten_nesteddict(
     nest: dict of dicts
         Contains the nested results of extraction.
     to: str (optional)
-        Specifies the format of the output, either pd.Series (default) or a list
+        Specifies the format of the output, either pd.Series (default)
+        or a list
     tp: int
         Time point used to name the pd.Series
 
     Returns
     -------
     d: dict
-        A dict with a concatenated string of channel, reduction metric, and cell metric as keys and either a pd.Series or a list of the corresponding extracted data as values.
+        A dict with a concatenated string of channel, reduction metric,
+        and cell metric as keys and either a pd.Series or a list of the
+        corresponding extracted data as values.
     """
     d = {}
     for k0, v0 in nest.items():
@@ -690,14 +786,3 @@ def flatten_nesteddict(
                     pd.Series(*v2, name=tp) if to == "series" else v2
                 )
     return d
-
-
-class hollowExtractor(Extractor):
-    """
-    Extractor that only cares about receiving images and masks.
-
-    Used for testing.
-    """
-
-    def __init__(self, parameters):
-        self.params = parameters
diff --git a/src/extraction/core/functions/distributors.py b/src/extraction/core/functions/distributors.py
index e9b5265f55373af6acd409d4a018d9b6341dbd7b..90838e61b21ef505f87df784fbe82da6781efce1 100644
--- a/src/extraction/core/functions/distributors.py
+++ b/src/extraction/core/functions/distributors.py
@@ -44,5 +44,6 @@ def reduce_z(trap_image: np.ndarray, fun: t.Callable, axis: int = 0):
     elif isinstance(fun, np.ufunc):
         # optimise the reduction function if possible
         return fun.reduce(trap_image, axis=axis)
-    else:  # WARNING: Very slow, only use when no alternatives exist
+    else:
+        # WARNING: Very slow, only use when no alternatives exist
         return np.apply_along_axis(fun, axis, trap_image)