From be41600875889fa7cf712e703a78e238a383e1fb Mon Sep 17 00:00:00 2001
From: Peter Swain <peter.swain@ed.ac.uk>
Date: Fri, 10 Feb 2023 16:09:02 +0000
Subject: [PATCH] Minor tweaks to documentation in extraction

---
 src/aliby/tile/tiler.py                     |   2 +-
 src/extraction/core/extractor.py            | 129 +++++++++-----------
 src/extraction/core/functions/cell.py       |  57 +++++----
 src/extraction/core/functions/defaults.py   |  40 +++---
 src/extraction/core/functions/loaders.py    |  32 ++---
 src/extraction/core/functions/math_utils.py |   1 -
 src/extraction/core/functions/trap.py       |  10 +-
 7 files changed, 130 insertions(+), 141 deletions(-)

diff --git a/src/aliby/tile/tiler.py b/src/aliby/tile/tiler.py
index c91f412d..9c1c08ce 100644
--- a/src/aliby/tile/tiler.py
+++ b/src/aliby/tile/tiler.py
@@ -539,7 +539,7 @@ class Tiler(StepABC):
         Returns
         -------
         res: array
-            Data arranged as (tiles, channels, timepoints, X, Y, Z)
+            Data arranged as (tiles, channels, time points, X, Y, Z)
         """
         # FIXME add support for sub-tiling a tile
         # FIXME can we ignore z
diff --git a/src/extraction/core/extractor.py b/src/extraction/core/extractor.py
index eee3185a..f22898a8 100644
--- a/src/extraction/core/extractor.py
+++ b/src/extraction/core/extractor.py
@@ -18,7 +18,7 @@ from extraction.core.functions.loaders import (
     load_redfuns,
 )
 
-# Define types
+# define types
 reduction_method = t.Union[t.Callable, str, None]
 extraction_tree = t.Dict[
     str, t.Dict[reduction_method, t.Dict[str, t.Collection]]
@@ -27,7 +27,7 @@ extraction_result = t.Dict[
     str, t.Dict[reduction_method, t.Dict[str, t.Dict[str, pd.Series]]]
 ]
 
-# Global parameters used to load functions that either analyse cells or their background. These global parameters both allow the functions to be stored in a dictionary for access only on demand and to be defined simply in extraction/core/functions.
+# Global variables used to load functions that either analyse cells or their background. These global variables both allow the functions to be stored in a dictionary for access only on demand and to be defined simply in extraction/core/functions.
 CELL_FUNS, TRAPFUNS, FUNS = load_funs()
 CUSTOM_FUNS, CUSTOM_ARGS = load_custom_args()
 RED_FUNS = load_redfuns()
@@ -37,9 +37,7 @@ RED_FUNS = load_redfuns()
 
 
 class ExtractorParameters(ParametersABC):
-    """
-    Base class to define parameters for extraction.
-    """
+    """Base class to define parameters for extraction."""
 
     def __init__(
         self,
@@ -48,12 +46,14 @@ class ExtractorParameters(ParametersABC):
         multichannel_ops: t.Dict = {},
     ):
         """
+        Initialise.
+
         Parameters
         ----------
         tree: dict
             Nested dictionary indicating channels, reduction functions and
             metrics to be used.
-            str channel -> U(function,None) reduction -> str metric
+            str channel -> U(function, None) reduction -> str metric
             If not of depth three, tree will be filled with None.
         sub_bg: set
         multichannel_ops: dict
@@ -65,7 +65,7 @@ class ExtractorParameters(ParametersABC):
     @staticmethod
     def guess_from_meta(store_name: str, suffix="fast"):
         """
-        Find the microscope used from the h5 metadata.
+        Find the microscope name from the h5 metadata.
 
         Parameters
         ----------
@@ -98,17 +98,7 @@ class Extractor(StepABC):
 
     Usually the metric is applied to only a tile's masked area, but some metrics depend on the whole tile.
 
-    Extraction follows a three-level tree structure. Channels, such as GFP, are the root level; the reduction algorithm, such as maximum projection, is the second level; the specific metric, or operation, to apply to the masks is the third level.
-
-    Parameters
-    ----------
-    parameters: core.extractor Parameters
-        Parameters that include the channels, and reduction and
-        extraction functions.
-    store: str
-        Path to the h5 file, which must contain the cell masks.
-    tiler: pipeline-core.core.segmentation tiler
-        Class that contains or fetches the images used for segmentation.
+    Extraction follows a three-level tree structure. Channels, such as GFP, are the root level; the reduction algorithm, such as maximum projection, is the second level; the specific metric, or operation, to apply to the masks, such as mean, is the third level.
     """
 
     # Alan: should this data be stored here or all such data in a separate file
@@ -129,10 +119,13 @@ class Extractor(StepABC):
 
         Parameters
         ----------
-        parameters: ExtractorParameters object
+        parameters: core.extractor Parameters
+            Parameters that include the channels, reduction and
+            extraction functions.
         store: str
-            Name of h5 file
-        tiler: Tiler object
+            Path to the h5 file containing the cell masks.
+        tiler: pipeline-core.core.segmentation tiler
+            Class that contains or fetches the images used for segmentation.
         """
         self.params = parameters
         if store:
@@ -179,14 +172,18 @@ class Extractor(StepABC):
 
     @property
     def group(self):
-        # returns path within h5 file
+        """Return path within the h5 file."""
         if not hasattr(self, "_out_path"):
             self._group = "/extraction/"
         return self._group
 
     def load_custom_funs(self):
         """
-        Define any custom functions to be functions of cell_masks and trap_image only.
+        Incorporate the extra arguments of custom functions into their definitions.
+
+        Normal functions only have cell_masks and trap_image as their
+        arguments, and here custom functions are made the same by
+        setting the values of their extra arguments.
 
         Any other parameters are taken from the experiment's metadata and automatically applied. These parameters therefore must be loaded within an Extractor instance.
         """
@@ -206,12 +203,13 @@ class Extractor(StepABC):
             k: {k2: self.get_meta(k2) for k2 in v}
             for k, v in CUSTOM_ARGS.items()
         }
-        # define custom functions - those with extra arguments other than cell_masks and trap_image - as functions of two variables
+        # define custom functions
         self._custom_funs = {}
         for k, f in CUSTOM_FUNS.items():
 
             def tmp(f):
                 # pass extra arguments to custom function
+                # return a function of cell_masks and trap_image
                 return lambda cell_masks, trap_image: trap_apply(
                     f,
                     cell_masks,
@@ -222,6 +220,7 @@ class Extractor(StepABC):
             self._custom_funs[k] = tmp(f)
 
     def load_funs(self):
+        """Define all functions, including custum ones."""
         self.load_custom_funs()
         self._all_cell_funs = set(self._custom_funs.keys()).union(CELL_FUNS)
         # merge the two dicts
@@ -239,20 +238,18 @@ class Extractor(StepABC):
         **kwargs,
     ) -> t.Optional[np.ndarray]:
         """
-        Find tiles for a given time point and given channels and z-stacks.
-
-        Returns None if no tiles are found.
+        Find tiles for a given time point, channels, and z-stacks.
 
         Any additional keyword arguments are passed to tiler.get_tiles_timepoint
 
         Parameters
         ----------
         tp: int
-            Time point of interest
+            Time point of interest.
         channels: list of strings (optional)
-            Channels of interest
+            Channels of interest.
         z: list of integers (optional)
-            Indices for the z-stacks of interest
+            Indices for the z-stacks of interest.
         """
         if channels is None:
             # find channels from tiler
@@ -265,16 +262,16 @@ class Extractor(StepABC):
             channel_ids = None
         if z is None:
             # gets the tiles data via tiler
-            z: t.List[int] = list(range(self.tiler.shape[-3]))
-        tiles = (
+            z = list(range(self.tiler.shape[-3]))
+        res = (
             self.tiler.get_tiles_timepoint(
                 tp, channels=channel_ids, z=z, **kwargs
             )
             if channel_ids
             else None
         )
-        # data arranged as (traps, channels, timepoints, X, Y, Z)
-        return tiles
+        # data arranged as (tiles, channels, time points, X, Y, Z)
+        return res
 
     def extract_traps(
         self,
@@ -302,11 +299,10 @@ class Extractor(StepABC):
         Returns
         -------
         res_idx: a tuple of tuples
-            A two-tuple of a tuple of results and a tuple with the corresponding trap_id and cell labels
+            A two-tuple comprising a tuple of results and a tuple of the tile_id and cell labels
         """
         if labels is None:
             self._log("No labels given. Sorting cells using index.")
-
         cell_fun = True if metric in self._all_cell_funs else False
         idx = []
         results = []
@@ -337,7 +333,9 @@ class Extractor(StepABC):
         **kwargs,
     ) -> t.Dict[str, pd.Series]:
         """
-        Returns dict with metrics as key and metrics applied to data as values for data from one timepoint.
+        Return dict with metrics as key and metrics applied to data as values.
+
+        Data from one time point is used.
         """
         d = {
             metric: self.extract_traps(
@@ -349,7 +347,7 @@ class Extractor(StepABC):
 
     def reduce_extract(
         self,
-        traps: np.ndarray,
+        tiles_data: np.ndarray,
         masks: t.List[np.ndarray],
         red_metrics: t.Dict[reduction_method, t.Collection[str]],
         **kwargs,
@@ -359,8 +357,8 @@ class Extractor(StepABC):
 
         Parameters
         ----------
-        traps: array
-            An array of image data arranged as (traps, X, Y, Z)
+        tiles_data: array
+            An array of image data arranged as (tiles, X, Y, Z)
         masks: list of arrays
             An array of masks for each trap: one per cell at the trap
         red_metrics: dict
@@ -371,20 +369,20 @@ class Extractor(StepABC):
 
         Returns
         ------
-        Dictionary of dataframes with the corresponding reductions and metrics nested.
+        Dict of dataframes with the corresponding reductions and metrics nested.
         """
         # create dict with keys naming the reduction in the z-direction and the reduced data as values
-        reduced_traps = {}
-        if traps is not None:
+        reduced_tiles_data = {}
+        if tiles_data is not None:
             for red_fun in red_metrics.keys():
-                reduced_traps[red_fun] = [
-                    self.reduce_dims(trap, method=RED_FUNS[red_fun])
-                    for trap in traps
+                reduced_tiles_data[red_fun] = [
+                    self.reduce_dims(tile_data, method=RED_FUNS[red_fun])
+                    for tile_data in tiles_data
                 ]
         d = {
             red_fun: self.extract_funs(
                 metrics=metrics,
-                traps=reduced_traps.get(red_fun, [None for _ in masks]),
+                traps=reduced_tiles_data.get(red_fun, [None for _ in masks]),
                 masks=masks,
                 **kwargs,
             )
@@ -403,9 +401,9 @@ class Extractor(StepABC):
         Parameters
         ----------
         img: array
-            An array of the image data arranged as (X, Y, Z)
+            An array of the image data arranged as (X, Y, Z).
         method: function
-            The reduction function
+            The reduction function.
         """
         reduced = img
         if method is not None:
@@ -422,7 +420,7 @@ class Extractor(StepABC):
         **kwargs,
     ) -> t.Dict[str, t.Dict[str, t.Dict[str, tuple]]]:
         """
-        Extract for an individual time-point.
+        Extract for an individual time point.
 
         Parameters
         ----------
@@ -452,7 +450,6 @@ class Extractor(StepABC):
             The first tuple is the result of applying the metrics to a particular cell or trap; the second tuple is either (trap_id, cell_label) for a metric applied to a cell or a trap_id for a metric applied to a trap.
 
             An example is d["GFP"]["np_max"]["mean"][0], which gives a tuple of the calculated mean GFP fluorescence for all cells.
-
         """
         # TODO Can we split the different extraction types into sub-methods to make this easier to read?
         if tree is None:
@@ -464,7 +461,6 @@ class Extractor(StepABC):
         tree_chs = (*ch_tree,)
         # create a Cells object to extract information from the h5 file
         cells = Cells(self.local)
-
         # find the cell labels and store as dict with trap_ids as keys
         if labels is None:
             raw_labels = cells.labels_at_time(tp)
@@ -472,7 +468,6 @@ class Extractor(StepABC):
                 trap_id: raw_labels.get(trap_id, [])
                 for trap_id in range(cells.ntraps)
             }
-
         # find the cell masks for a given trap as a dict with trap_ids as keys
         if masks is None:
             raw_masks = cells.at_time(tp, kind="mask")
@@ -482,11 +477,9 @@ class Extractor(StepABC):
                     masks[trap_id] = np.dstack(np.array(cells)).astype(bool)
         # convert to a list of masks
         masks = [np.array(v) for v in masks.values()]
-
         # find image data at the time point
-        # stored as an array arranged as (traps, channels, timepoints, X, Y, Z)
+        # stored as an array arranged as (traps, channels, time points, X, Y, Z)
         tiles = self.get_tiles(tp, tile_shape=tile_size, channels=tree_chs)
-
         # generate boolean masks for background as a list with one mask per trap
         bgs = []
         if self.params.sub_bg:
@@ -496,7 +489,6 @@ class Extractor(StepABC):
                 else np.zeros((tile_size, tile_size))
                 for m in masks
             ]
-
         # perform extraction by applying metrics
         d = {}
         self.img_bgsub = {}
@@ -510,9 +502,9 @@ class Extractor(StepABC):
                 img = None
             # apply metrics to image data
             d[ch] = self.reduce_extract(
-                red_metrics=red_metrics,
-                traps=img,
+                tiles_data=img,
                 masks=masks,
+                red_metrics=red_metrics,
                 labels=labels,
                 **kwargs,
             )
@@ -537,8 +529,7 @@ class Extractor(StepABC):
                     labels=labels,
                     **kwargs,
                 )
-
-        # apply any metrics that use multiple channels (eg pH calculations)
+        # apply any metrics using multiple channels, such as pH calculations
         for name, (
             chs,
             merge_fun,
@@ -560,10 +551,9 @@ class Extractor(StepABC):
                     labels=labels,
                     **kwargs,
                 )
-
         return d
 
-    def get_imgs(self, channel: t.Optional[str], traps, channels=None):
+    def get_imgs(self, channel: t.Optional[str], tiles, channels=None):
         """
         Return image from a correct source, either raw or bgsub.
 
@@ -571,20 +561,20 @@ class Extractor(StepABC):
         ----------
         channel: str
             Name of channel to get.
-        traps: ndarray
-            An array of the image data having dimensions of (trap_id, channel, tp, tile_size, tile_size, n_zstacks).
+        tiles: ndarray
+            An array of the image data having dimensions of (tile_id, channel, tp, tile_size, tile_size, n_zstacks).
         channels: list of str (optional)
             List of available channels.
 
         Returns
         -------
         img: ndarray
-            An array of image data with dimensions (no traps, X, Y, no Z channels)
+            An array of image data with dimensions (no tiles, X, Y, no Z channels)
         """
         if channels is None:
             channels = (*self.params.tree,)
         if channel in channels:  # TODO start here to fetch channel using regex
-            return traps[:, channels.index(channel), 0]
+            return tiles[:, channels.index(channel), 0]
         elif channel in self.img_bgsub:
             return self.img_bgsub[channel]
 
@@ -622,7 +612,6 @@ class Extractor(StepABC):
             tps = list(range(self.meta["time_settings/ntimepoints"][0]))
         elif isinstance(tps, int):
             tps = [tps]
-
         # store results in dict
         d = {}
         for tp in tps:
@@ -669,7 +658,7 @@ class Extractor(StepABC):
         self.writer.id_cache.clear()
 
     def get_meta(self, flds: t.Union[str, t.Collection]):
-        # Obtain metadata for one or multiple fields
+        """Obtain metadata for one or multiple fields."""
         if isinstance(flds, str):
             flds = [flds]
         meta_short = {k.split("/")[-1]: v for k, v in self.meta.items()}
@@ -692,7 +681,7 @@ def flatten_nesteddict(
     to: str (optional)
         Specifies the format of the output, either pd.Series (default) or a list
     tp: int
-        Timepoint used to name the pd.Series
+        Time point used to name the pd.Series
 
     Returns
     -------
diff --git a/src/extraction/core/functions/cell.py b/src/extraction/core/functions/cell.py
index c3d99d27..0e7b9fe8 100644
--- a/src/extraction/core/functions/cell.py
+++ b/src/extraction/core/functions/cell.py
@@ -1,13 +1,15 @@
 """
-Base functions to extract information from a single cell
+Base functions to extract information from a single cell.
 
-These functions are automatically read by extractor.py, and so can only have the cell_mask and trap_image as inputs and must return only one value.
+These functions are automatically read by extractor.py, and
+so can only have the cell_mask and trap_image as inputs. They
+must return only one value.
 
 They assume that there are no NaNs in the image.
- We use bottleneck when it performs faster than numpy:
-- Median
-- values containing NaNs (We make sure this does not happen)
 
+We use the module bottleneck when it performs faster than numpy:
+- Median
+- values containing NaNs (but we make sure this does not happen)
 """
 import math
 import typing as t
@@ -19,24 +21,24 @@ from scipy import ndimage
 
 def area(cell_mask) -> int:
     """
-    Find the area of a cell mask
+    Find the area of a cell mask.
 
     Parameters
     ----------
     cell_mask: 2d array
-        Segmentation mask for the cell
+        Segmentation mask for the cell.
     """
     return np.sum(cell_mask)
 
 
 def eccentricity(cell_mask) -> float:
     """
-    Find the eccentricity using the approximate major and minor axes
+    Find the eccentricity using the approximate major and minor axes.
 
     Parameters
     ----------
     cell_mask: 2d array
-        Segmentation mask for the cell
+        Segmentation mask for the cell.
     """
     min_ax, maj_ax = min_maj_approximation(cell_mask)
     return np.sqrt(maj_ax**2 - min_ax**2) / maj_ax
@@ -44,12 +46,12 @@ def eccentricity(cell_mask) -> float:
 
 def mean(cell_mask, trap_image) -> float:
     """
-    Finds the mean of the pixels in the cell.
+    Find the mean of the pixels in the cell.
 
     Parameters
     ----------
     cell_mask: 2d array
-        Segmentation mask for the cell
+        Segmentation mask for the cell.
     trap_image: 2d array
     """
     return np.mean(trap_image[cell_mask])
@@ -57,12 +59,12 @@ def mean(cell_mask, trap_image) -> float:
 
 def median(cell_mask, trap_image) -> int:
     """
-    Finds the median of the pixels in the cell.
+    Find the median of the pixels in the cell.
 
     Parameters
     ----------
     cell_mask: 2d array
-        Segmentation mask for the cell
+         Segmentation mask for the cell.
     trap_image: 2d array
     """
     return bn.median(trap_image[cell_mask])
@@ -70,12 +72,12 @@ def median(cell_mask, trap_image) -> int:
 
 def max2p5pc(cell_mask, trap_image) -> float:
     """
-    Finds the mean of the brightest 2.5% of pixels in the cell.
+    Find the mean of the brightest 2.5% of pixels in the cell.
 
     Parameters
     ----------
     cell_mask: 2d array
-        Segmentation mask for the cell
+        Segmentation mask for the cell.
     trap_image: 2d array
     """
     # number of pixels in mask
@@ -84,19 +86,18 @@ def max2p5pc(cell_mask, trap_image) -> float:
     # sort pixels in cell and find highest 2.5%
     pixels = trap_image[cell_mask]
     top_values = bn.partition(pixels, len(pixels) - n_top)[-n_top:]
-
     # find mean of these highest pixels
     return np.mean(top_values)
 
 
 def max5px(cell_mask, trap_image) -> float:
     """
-    Finds the mean of the five brightest pixels in the cell.
+    Find the mean of the five brightest pixels in the cell.
 
     Parameters
     ----------
     cell_mask: 2d array
-        Segmentation mask for the cell
+        Segmentation mask for the cell.
     trap_image: 2d array
     """
     # sort pixels in cell
@@ -109,12 +110,12 @@ def max5px(cell_mask, trap_image) -> float:
 
 def std(cell_mask, trap_image):
     """
-    Finds the standard deviation of the values of the pixels in the cell.
+    Find the standard deviation of the values of the pixels in the cell.
 
     Parameters
     ----------
     cell_mask: 2d array
-        Segmentation mask for the cell
+        Segmentation mask for the cell.
     trap_image: 2d array
     """
     return np.std(trap_image[cell_mask])
@@ -122,12 +123,15 @@ def std(cell_mask, trap_image):
 
 def volume(cell_mask) -> float:
     """
-    Estimates the volume of the cell assuming it is an ellipsoid with the mask providing a cross-section through the median plane of the ellipsoid.
+    Estimate the volume of the cell.
+
+    Assumes the cell is an ellipsoid with the mask providing
+    a cross-section through its median plane.
 
     Parameters
     ----------
     cell_mask: 2d array
-        Segmentation mask for the cell
+        Segmentation mask for the cell.
     """
     min_ax, maj_ax = min_maj_approximation(cell_mask)
     return (4 * np.pi * min_ax**2 * maj_ax) / 3
@@ -135,7 +139,7 @@ def volume(cell_mask) -> float:
 
 def conical_volume(cell_mask):
     """
-    Estimates the volume of the cell
+    Estimate the volume of the cell.
 
     Parameters
     ----------
@@ -151,7 +155,10 @@ def conical_volume(cell_mask):
 
 def spherical_volume(cell_mask):
     """
-    Estimates the volume of the cell assuming it is a sphere with the mask providing a cross-section through the median plane of the sphere.
+    Estimate the volume of the cell.
+
+    Assumes the cell is a sphere with the mask providing
+    a cross-section through its median plane.
 
     Parameters
     ----------
@@ -165,7 +172,7 @@ def spherical_volume(cell_mask):
 
 def min_maj_approximation(cell_mask) -> t.Tuple[int]:
     """
-    Finds the lengths of the minor and major axes of an ellipse from a cell mask.
+    Find the lengths of the minor and major axes of an ellipse from a cell mask.
 
     Parameters
     ----------
diff --git a/src/extraction/core/functions/defaults.py b/src/extraction/core/functions/defaults.py
index 84798680..4fcb4094 100644
--- a/src/extraction/core/functions/defaults.py
+++ b/src/extraction/core/functions/defaults.py
@@ -2,23 +2,26 @@
 import re
 import typing as t
 from pathlib import PosixPath
-
 import h5py
 
+# should we move these functions here?
+from aliby.tile.tiler import find_channel_name
+
 
 def exparams_from_meta(
     meta: t.Union[dict, PosixPath, str], extras: t.Collection[str] = ["ph"]
 ):
     """
-    Obtain parameters from metadata of hdf5 file.
-    It compares a list of candidate channels using case-inspecific REGEX to identify valid channels.
+    Obtain parameters from metadata of the h5 file.
+
+    Compares a list of candidate channels using case-insensitive
+    REGEX to identify valid channels.
     """
-    meta = meta if isinstance(meta, dict) else load_attributes(meta)
+    meta = meta if isinstance(meta, dict) else load_metadata(meta)
     base = {
         "tree": {"general": {"None": ["area", "volume", "eccentricity"]}},
         "multichannel_ops": {},
     }
-
     candidate_channels = {
         "Citrine",
         "GFP",
@@ -30,7 +33,6 @@ def exparams_from_meta(
         "Cy5",
         "mKO2",
     }
-
     default_reductions = {"max"}
     default_metrics = {
         "mean",
@@ -40,33 +42,26 @@ def exparams_from_meta(
         "max5px",
         # "nuc_est_conv",
     }
-
-    # Defined ratiometric combinations that can be used as ratio
-    # key is numerator and value is denominator; add more to support additional channel names
+    # define ratiometric combinations
+    # key is numerator and value is denominator
+    # add more to support additional channel names
     ratiometric_combinations = {"phluorin405": ("phluorin488", "gfpfast")}
-
     default_reduction_metrics = {
         r: default_metrics for r in default_reductions
     }
     # default_rm["None"] = ["nuc_conv_3d"] # Uncomment this to add nuc_conv_3d (slow)
-
-    from aliby.tile.tiler import find_channel_name
-
     extant_fluorescence_ch = []
     for av_channel in candidate_channels:
-        # Find channels in metadata whose names match
+        # find matching channels in metadata
         found_channel = find_channel_name(meta.get("channels", []), av_channel)
         if found_channel is not None:
             extant_fluorescence_ch.append(found_channel)
-
     for ch in extant_fluorescence_ch:
         base["tree"][ch] = default_reduction_metrics
-
     base["sub_bg"] = extant_fluorescence_ch
-
-    # Additional extraction defaults when channels available
+    # additional extraction defaults if the channels are available
     if "ph" in extras:
-        # SWAINLAB-specific names
+        # SWAINLAB specific names
         # find first valid combination of ratiometric fluorescence channels
         numerator_channel, denominator_channel = (None, None)
         for ch1, chs2 in ratiometric_combinations.items():
@@ -80,8 +75,7 @@ def exparams_from_meta(
                     if found_channel2:
                         denominator_channel = found_channel2
                         break
-
-        # If two compatible ratiometric channels are available
+        # if two compatible ratiometric channels are available
         if numerator_channel is not None and denominator_channel is not None:
             sets = {
                 b + a: (x, y)
@@ -102,11 +96,11 @@ def exparams_from_meta(
                     *v,
                     default_reduction_metrics,
                 ]
-
     return base
 
 
-def load_attributes(file: t.Union[str, PosixPath], group="/"):
+def load_metadata(file: t.Union[str, PosixPath], group="/"):
+    """Get meta data from an h5 file."""
     with h5py.File(file, "r") as f:
         meta = dict(f[group].attrs.items())
     return meta
diff --git a/src/extraction/core/functions/loaders.py b/src/extraction/core/functions/loaders.py
index 9d2e9c47..ff83b20c 100644
--- a/src/extraction/core/functions/loaders.py
+++ b/src/extraction/core/functions/loaders.py
@@ -11,14 +11,13 @@ from extraction.core.functions.math_utils import div0
 
 """
 Load functions for analysing cells and their background.
-Note that inspect.getmembers returns a list of function names and functions, and inspect.getfullargspec returns a function's arguments.
+Note that inspect.getmembers returns a list of function names and functions,
+and inspect.getfullargspec returns a function's arguments.
 """
 
 
 def load_cellfuns_core():
-    """
-    Load functions from the cell module and return as a dict.
-    """
+    """Load functions from the cell module and return as a dict."""
     return {
         f[0]: f[1]
         for f in getmembers(cell)
@@ -31,7 +30,10 @@ def load_custom_args() -> t.Tuple[
     (t.Dict[str, t.Callable], t.Dict[str, t.List[str]])
 ]:
     """
-    Load custom functions from the localisation module and return the functions and any additional arguments, other than cell_mask and trap_image, as dictionaries.
+    Load custom functions from the localisation module.
+
+    Return the functions and any additional arguments other
+    than cell_mask and trap_image as dictionaries.
     """
     # load functions from module
     funs = {
@@ -57,7 +59,8 @@ def load_custom_args() -> t.Tuple[
 
 def load_cellfuns():
     """
-    Creates a dict of core functions that can be used on an array of cell_masks.
+    Create a dict of core functions for use on cell_masks.
+
     The core functions only work on a single mask.
     """
     # create dict of the core functions from cell.py - these functions apply to a single mask
@@ -81,9 +84,7 @@ def load_cellfuns():
 
 
 def load_trapfuns():
-    """
-    Load functions that are applied to an entire trap or tile or subsection of an image rather than to single cells.
-    """
+    """Load functions that are applied to an entire tile."""
     TRAPFUNS = {
         f[0]: f[1]
         for f in getmembers(trap)
@@ -94,9 +95,7 @@ def load_trapfuns():
 
 
 def load_funs():
-    """
-    Combine all automatically loaded functions
-    """
+    """Combine all automatically loaded functions."""
     CELLFUNS = load_cellfuns()
     TRAPFUNS = load_trapfuns()
     # return dict of cell funs, dict of trap funs, and dict of both
@@ -111,7 +110,10 @@ def load_redfuns(
     """
     Load functions to reduce a multidimensional image by one dimension.
 
-    It can take custom functions as arguments.
+    Parameters
+    ----------
+    additional_reducers: function or a dict of functions (optional)
+        Functions to perform the reduction.
     """
     RED_FUNS = {
         "max": bn.nanmax,
@@ -121,12 +123,10 @@ def load_redfuns(
         "add": bn.nansum,
         "None": None,
     }
-
     if additional_reducers is not None:
         if isinstance(additional_reducers, FunctionType):
             additional_reducers = [
                 (additional_reducers.__name__, additional_reducers)
             ]
-        RED_FUNS.update(name, fun)
-
+        RED_FUNS.update(additional_reducers)
     return RED_FUNS
diff --git a/src/extraction/core/functions/math_utils.py b/src/extraction/core/functions/math_utils.py
index eeae8e0c..a6216ea9 100644
--- a/src/extraction/core/functions/math_utils.py
+++ b/src/extraction/core/functions/math_utils.py
@@ -20,7 +20,6 @@ def div0(array, fill=0, axis=-1):
     slices_0, slices_1 = [[slice(None)] * len(array.shape)] * 2
     slices_0[axis] = 0
     slices_1[axis] = 1
-
     with np.errstate(divide="ignore", invalid="ignore"):
         c = np.true_divide(
             array[tuple(slices_0)],
diff --git a/src/extraction/core/functions/trap.py b/src/extraction/core/functions/trap.py
index b3cd7d13..f1f491e0 100644
--- a/src/extraction/core/functions/trap.py
+++ b/src/extraction/core/functions/trap.py
@@ -5,14 +5,14 @@ import numpy as np
 
 def imBackground(cell_masks, trap_image):
     """
-    Finds the median background (pixels not comprising cells) from trap_image
+    Find the median background (pixels not comprising cells) from trap_image.
 
     Parameters
     ----------
     cell_masks: 3d array
        Segmentation masks for cells
     trap_image:
-        The image (all channels) for the tile containing the cell
+        The image (all channels) for the tile containing the cell.
     """
     if not len(cell_masks):
         # create cell_masks if none are given
@@ -25,14 +25,14 @@ def imBackground(cell_masks, trap_image):
 
 def background_max5(cell_masks, trap_image):
     """
-    Finds the mean of the maximum five pixels of the background (pixels not comprising cells) from trap_image
+    Finds the mean of the maximum five pixels of the background.
 
     Parameters
     ----------
     cell_masks: 3d array
-        Segmentation masks for cells
+        Segmentation masks for cells.
     trap_image:
-        The image (all channels) for the tile containing the cell
+        The image (all channels) for the tile containing the cell.
     """
     if not len(cell_masks):
         # create cell_masks if none are given
-- 
GitLab