diff --git a/src/agora/io/bridge.py b/src/agora/io/bridge.py
index 4a2f6093e7d8aaf4ed8e0e0f8abfcfaffb796449..478408f37d1f0724a232876f0f5ce0da5b02b585 100644
--- a/src/agora/io/bridge.py
+++ b/src/agora/io/bridge.py
@@ -1,5 +1,5 @@
 """
-Tools to interact with hdf5 files and handle data consistently.
+Tools to interact with h5 files and handle data consistently.
 """
 import collections
 from itertools import chain, groupby, product
@@ -13,26 +13,28 @@ import yaml
 
 class BridgeH5:
     """
-    Base class to interact with h5 data stores.
-    It also contains functions useful to predict how long should segmentation take.
+    Base class to interact with h5 files.
+
+    It includes functions that predict how long segmentation will take.
     """
 
     def __init__(self, filename, flag="r"):
+        """Initialise with the name of the h5 file."""
         self.filename = filename
         if flag is not None:
             self._hdf = h5py.File(filename, flag)
-
             self._filecheck
 
     def _filecheck(self):
         assert "cell_info" in self._hdf, "Invalid file. No 'cell_info' found."
 
     def close(self):
+        """Close the h5 file."""
         self._hdf.close()
 
     @property
     def meta_h5(self) -> t.Dict[str, t.Any]:
-        # Return metadata as indicated in h5 file
+        """Return metadata, defining it if necessary."""
         if not hasattr(self, "_meta_h5"):
             with h5py.File(self.filename, "r") as f:
                 self._meta_h5 = dict(f.attrs)
@@ -44,24 +46,24 @@ class BridgeH5:
 
     @staticmethod
     def get_consecutives(tree, nstepsback):
-        # Receives a sorted tree and returns the keys of consecutive elements
-        vals = {k: np.array(list(v)) for k, v in tree.items()}  # get tp level
+        """Receives a sorted tree and returns the keys of consecutive elements."""
+        # get tp level
+        vals = {k: np.array(list(v)) for k, v in tree.items()}
+        # get indices of consecutive elements
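+        # for each step back n, an element at index i counts as consecutive if v[i + n + 1] - v[i] == n + 1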
         where_consec = [
             {
                 k: np.where(np.subtract(v[n + 1 :], v[: -n - 1]) == n + 1)[0]
                 for k, v in vals.items()
             }
             for n in range(nstepsback)
-        ]  # get indices of consecutive elements
+        ]
         return where_consec
 
     def get_npairs(self, nstepsback=2, tree=None):
         if tree is None:
             tree = self.cell_tree
-
         consecutive = self.get_consecutives(tree, nstepsback=nstepsback)
         flat_tree = flatten(tree)
-
         n_predictions = 0
         for i, d in enumerate(consecutive, 1):
             flat = list(chain(*[product([k], list(v)) for k, v in d.items()]))
@@ -70,55 +72,49 @@ class BridgeH5:
                 n_predictions += len(flat_tree.get(p[0], [])) * len(
                     flat_tree.get(p[1], [])
                 )
-
         return n_predictions
 
     def get_npairs_over_time(self, nstepsback=2):
         tree = self.cell_tree
         npairs = []
-        for t in self._hdf["cell_info"]["processed_timepoints"][()]:
+        for tp in self._hdf["cell_info"]["processed_timepoints"][()]:
             tmp_tree = {
-                k: {k2: v2 for k2, v2 in v.items() if k2 <= t}
+                k: {k2: v2 for k2, v2 in v.items() if k2 <= tp}
                 for k, v in tree.items()
             }
             npairs.append(self.get_npairs(tree=tmp_tree))
-
         return np.diff(npairs)
 
     def get_info_tree(
         self, fields: Union[tuple, list] = ("trap", "timepoint", "cell_label")
     ):
         """
-        Returns traps, time points and labels for this position in form of a tree
-        in the hierarchy determined by the argument fields. Note that it is
-        compressed to non-empty elements and timepoints.
+        Return traps, time points and labels for this position in the form of a tree in the hierarchy determined by the argument fields.
+
+        Note that it is compressed to non-empty elements and timepoints.
 
         Default hierarchy is:
         - trap
         - time point
         - cell label
 
-        This function currently produces trees of depth 3, but it can easily be
-        extended for deeper trees if needed (e.g. considering groups,
-        chambers and/or positions).
+        This function currently produces trees of depth 3, but it can easily be extended for deeper trees if needed (e.g. considering groups, chambers and/or positions).
 
         Parameters
         ----------
-            fields: Fields to fetch from 'cell_info' inside the hdf5 storage
+        fields: list of str
+            Fields to fetch from 'cell_info' inside the h5 file.
 
         Returns
         ----------
-            Nested dictionary where keys (or branches) are the upper levels
-             and the leaves are the last element of :fields:.
+        Nested dictionary where keys (or branches) are the upper levels and the leaves are the last element of :fields:.
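+
+        For example (illustrative), with the default fields the returned tree has the form
+        {trap: {timepoint: [cell_label, ...]}}.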
         """
         zipped_info = (*zip(*[self._hdf["cell_info"][f][()] for f in fields]),)
-
         return recursive_groupsort(zipped_info)
 
 
 def groupsort(iterable: Union[tuple, list]):
-    # Sorts iterable and returns a dictionary where the values are grouped by the first element.
-
+    """Sorts iterable and returns a dictionary where the values are grouped by the first element."""
     iterable = sorted(iterable, key=lambda x: x[0])
     grouped = {
         k: [x[1:] for x in v] for k, v in groupby(iterable, lambda x: x[0])
@@ -127,17 +123,18 @@ def groupsort(iterable: Union[tuple, list]):
 
 
 def recursive_groupsort(iterable):
-    # Recursive extension of groupsort
+    """Recursive extension of groupsort."""
     if len(iterable[0]) > 1:
         return {
             k: recursive_groupsort(v) for k, v in groupsort(iterable).items()
         }
-    else:  # Only two elements in list
+    else:
+        # only two elements in list
         return [x[0] for x in iterable]
 
 
 def flatten(d, parent_key="", sep="_"):
-    """Flatten nested dict. Adapted from https://stackoverflow.com/a/6027615"""
+    """Flatten nested dict. Adapted from https://stackoverflow.com/a/6027615."""
     items = []
     for k, v in d.items():
         new_key = parent_key + (k,) if parent_key else (k,)
@@ -149,18 +146,19 @@ def flatten(d, parent_key="", sep="_"):
 
 
 def attrs_from_h5(fpath: str):
-    """Return attributes as dict from h5 file"""
+    """Return attributes as dict from an h5 file."""
     with h5py.File(fpath, "r") as f:
         return dict(f.attrs)
 
 
 def parameters_from_h5(fpath: str):
+    """Return parameters from an h5 file."""
     attrs = attrs_from_h5(fpath)
     return yaml.safe_load(attrs["parameters"])
 
 
 def image_creds_from_h5(fpath: str):
-    """Return image id and server credentials from h5"""
+    """Return image id and server credentials from an h5."""
     attrs = attrs_from_h5(fpath)
     return (
         attrs["image_id"],
diff --git a/src/agora/io/writer.py b/src/agora/io/writer.py
index c448f59f64f59c9bc77ace34fc931e014eecd04a..f11827fa9adb6818e354c0574f91a419770f8f9a 100644
--- a/src/agora/io/writer.py
+++ b/src/agora/io/writer.py
@@ -18,35 +18,78 @@ from agora.io.utils import timed
 
 
 def load_attributes(file: str, group="/"):
+    """
+    Load the metadata from an h5 file and convert it to a dictionary, including the "parameters" field, which is stored as YAML.
+
+    Parameters
+    ----------
+    file: str
+        Name of the h5 file
+    group: str, optional
+        The group in the h5 file from which to read the data
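+
+    Examples
+    --------
+    An illustrative call; the file name is hypothetical:
+
+    >>> meta = load_attributes("experiment_position001.h5")
+    >>> parameters = meta.get("parameters", {})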
+    """
+    # load the metadata, stored as attributes, from the h5 file and return as a dictionary
     with h5py.File(file, "r") as f:
         meta = dict(f[group].attrs.items())
     if "parameters" in meta:
+        # convert from yaml format into dict
         meta["parameters"] = yaml.safe_load(meta["parameters"])
     return meta
 
 
 class DynamicWriter:
+    """Provides a parent class for all writers."""
+
+    # a dict mapping each dataset name to a tuple of its maximum shape and its data type
     data_types = {}
+    # the group in the h5 file to write to
     group = ""
+    # compression info
     compression = "gzip"
     compression_opts = 9
     metadata = None
 
     def __init__(self, file: str):
         self.file = file
+        # the metadata is stored as attributes in the h5 file
         if Path(file).exists():
             self.metadata = load_attributes(file)
 
     def _append(self, data, key, hgroup):
-        """Append data to existing dataset."""
+        """
+        Append data to an existing dataset in the h5 file; otherwise create a new one.
+
+        Parameters
+        ----------
+        data
+            Data to be written, typically a numpy array
+        key: str
+            Name of dataset
+        hgroup: group object
+            Destination group in the h5 file
+        """
         try:
             n = len(data)
         except Exception as e:
             logging.debug(
-                "DynamicWriter:Attributes have no length: {}".format(e)
+                "DynamicWriter: Attributes have no length: {}".format(e)
             )
             n = 1
-        if key not in hgroup:
+        if key in hgroup:
+            # append to existing dataset
+            try:
+                # FIXME This is broken by bugged mother-bud assignment
+                dset = hgroup[key]
+                dset.resize(dset.shape[0] + n, axis=0)
+                dset[-n:] = data
+            except Exception as e:
+                logging.debug(
+                    "DynamicWriter: Inconsistency between dataset shape and new empty data: {}".format(
+                        e
+                    )
+                )
+        else:
+            # create new dataset
             # TODO Include sparsity check
             max_shape, dtype = self.datatypes[key]
             shape = (n,) + max_shape[1:]
@@ -60,66 +103,87 @@ class DynamicWriter:
                 if self.compression is not None
                 else None,
             )
+            # write all data, signified by the empty tuple
             hgroup[key][()] = data
-        else:
-            # The dataset already exists, expand it
-
-            try:  # FIXME This is broken by bugged mother-bud assignment
-                dset = hgroup[key]
-                dset.resize(dset.shape[0] + n, axis=0)
-                dset[-n:] = data
-            except Exception as e:
-                logging.debug(
-                    "DynamicWriter:Inconsistency between dataset shape and new empty data: {}".format(
-                        e
-                    )
-                )
-        return
 
     def _overwrite(self, data, key, hgroup):
-        """Overwrite existing dataset with new data"""
+        """
+        Delete and then replace existing dataset in h5 file.
+
+        Parameters
+        ----------
+        data
+            Data to be written, typically a numpy array
+        key: str
+            Name of dataset
+        hgroup: group object
+            Destination group in the h5 file
+        """
         # We do not append to mother_assign; raise error if already saved
         data_shape = np.shape(data)
         max_shape, dtype = self.datatypes[key]
+        # delete existing data
         if key in hgroup:
             del hgroup[key]
+        # write new data
         hgroup.require_dataset(
-            key, shape=data_shape, dtype=dtype, compression=self.compression
+            key,
+            shape=data_shape,
+            dtype=dtype,
+            compression=self.compression,
         )
+        # write all data, signified by the empty tuple
         hgroup[key][()] = data
 
-    def _check_key(self, key):
-        if key not in self.datatypes:
-            raise KeyError(f"No defined data type for key {key}")
+    # def _check_key(self, key):
+    #     if key not in self.datatypes:
+    #         raise KeyError(f"No defined data type for key {key}")
+
+    def write(self, data: dict, overwrite: list, meta: dict = {}):
+        """
+        Write data and metadata to h5 file.
 
-    def write(self, data, overwrite: list, meta={}):
-        # Data is a dictionary, if not, make it one
-        # Overwrite data is a list
+        Parameters
+        ----------
+        data: dict
+            A dict of datasets and data
+        overwrite: list of str
+            A list of datasets to overwrite
+        meta: dict, optional
+            Metadata to be written as attributes of the h5 file
+        """
         with h5py.File(self.file, "a") as store:
+            # open group, creating if necessary
             hgroup = store.require_group(self.group)
-
+            # write data
             for key, value in data.items():
-                # We're only saving data that has a pre-defined data-type
-                self._check_key(key)
-                try:
-                    if key.startswith("attrs/"):  # metadata
-                        key = key.split("/")[1]  # First thing after attrs
-                        hgroup.attrs[key] = value
-                    elif key in overwrite:
-                        self._overwrite(value, key, hgroup)
-                    else:
-                        self._append(value, key, hgroup)
-                except Exception as e:
-                    print(key, value)
-                    raise (e)
+                # only save data with a pre-defined data-type
+                if key not in self.datatypes:
+                    raise KeyError(f"No defined data type for key {key}")
+                else:
+                    try:
+                        if key.startswith("attrs/"):
+                            # metadata
+                            key = key.split("/")[1]
+                            hgroup.attrs[key] = value
+                        elif key in overwrite:
+                            # delete and replace existing dataset
+                            self._overwrite(value, key, hgroup)
+                        else:
+                            # append or create new dataset
+                            self._append(value, key, hgroup)
+                    except Exception as e:
+                        print(key, value)
+                        raise (e)
+            # write metadata
             for key, value in meta.items():
                 hgroup.attrs[key] = value
 
-        return
-
 
 ##################### Special instances #####################
 class TilerWriter(DynamicWriter):
+    """Write data stored in a Tiler instance to h5 files."""
+
     datatypes = {
         "trap_locations": ((None, 2), np.uint16),
         "drifts": ((None, 2), np.float32),
@@ -128,30 +192,40 @@ class TilerWriter(DynamicWriter):
     }
     group = "trap_info"
 
-    def write(self, data, overwrite: list, tp: int, meta={}):
-        """
-        Skips writing data if it were to overwrite it,using drift as a marker
+    def write(self, data: dict, overwrite: list, tp: int, meta: dict = {}):
         """
+        Write data for a time point, unless that time point has already been processed.
 
+        Parameters
+        ----------
+        data: dict
+            A dict of datasets and data
+        overwrite: list of str
+            A list of datasets to overwrite
+        tp: int
+            The time point of interest
+        meta: dict, optional
+            Metadata to be written as attributes of the h5 file
+        """
         skip = False
+        # append to h5 file
         with h5py.File(self.file, "a") as store:
+            # open group, creating if necessary
             hgroup = store.require_group(self.group)
-
+            # use the presence of xy drifts to check whether a time point has already been processed
             nprev = hgroup.get("drifts", None)
             if nprev and tp < nprev.shape[0]:
+                # data already exists
                 print(f"Tiler: Skipping timepoint {tp}")
                 skip = True
-
         if not skip:
             super().write(data=data, overwrite=overwrite, meta=meta)
 
 
-tile_size = 117
-
-
+# Alan: we use complex numbers because...
 @timed()
 def save_complex(array, dataset):
-    # Dataset needs to be 2D
+    # append array, a 1D array of complex numbers, onto dataset, a 2D array of real numbers
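+    # column 0 holds the real parts and column 1 the imaginary parts (see load_complex)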
     n = len(array)
     if n > 0:
         dataset.resize(dataset.shape[0] + n, axis=0)
@@ -161,22 +235,33 @@ def save_complex(array, dataset):
 
 @timed()
 def load_complex(dataset):
+    # convert 2D dataset into a 1D array of complex numbers
     array = dataset[:, 0] + 1j * dataset[:, 1]
     return array
 
 
 class BabyWriter(DynamicWriter):
+    """
+    Write data stored in a Baby instance to h5 files.
+
+    Assumes the edgemasks are of form ((max_ncells, max_tps, tile_size, tile_size), bool).
+    """
+
     compression = "gzip"
-    max_ncells = 2e5  # Could just make this None
+    max_ncells = 2e5  # Alan: Could just make this None
     max_tps = 1e3  # Could just make this None
-    chunk_cells = 25  # The number of cells in a chunk for edge masks
+    # the number of cells in a chunk for edge masks
+    chunk_cells = 25
     default_tile_size = 117
     datatypes = {
         "centres": ((None, 2), np.uint16),
         "position": ((None,), np.uint16),
         "angles": ((None,), h5py.vlen_dtype(np.float32)),
         "radii": ((None,), h5py.vlen_dtype(np.float32)),
-        "edgemasks": ((max_ncells, max_tps, tile_size, tile_size), bool),
+        "edgemasks": (
+            (max_ncells, max_tps, default_tile_size, default_tile_size),
+            bool,
+        ),
         "ellipse_dims": ((None, 2), np.float32),
         "cell_label": ((None,), np.uint16),
         "trap": ((None,), np.uint16),
@@ -189,11 +274,10 @@ class BabyWriter(DynamicWriter):
 
     def __init__(self, *args, **kwargs):
         super().__init__(*args, **kwargs)
-        # Get max_tps and trap info
         self._traps_initialised = False
 
     def __init_trap_info(self):
-        # Should only be run after the traps have been initialised
+        # requires traps to have been initialised
         trap_metadata = load_attributes(self.file, "trap_info")
         tile_size = trap_metadata.get("tile_size", self.default_tile_size)
         max_tps = self.metadata["time_settings/ntimepoints"][0]
@@ -204,9 +288,8 @@ class BabyWriter(DynamicWriter):
         self._traps_initialised = True
 
     def __init_edgemasks(self, hgroup, edgemasks, current_indices, n_cells):
-        # Create values dataset
-        # This holds the edge masks directly and
-        # Is of shape (n_tps, n_cells, tile_size, tile_size)
+        # create the values dataset in the h5 file
+        # holds the edge masks and has shape (n_tps, n_cells, tile_size, tile_size)
         key = "edgemasks"
         max_shape, dtype = self.datatypes[key]
         shape = (n_cells, 1) + max_shape[2:]
@@ -220,9 +303,8 @@ class BabyWriter(DynamicWriter):
             compression=self.compression,
         )
         val_dset[:, 0] = edgemasks
-        # Create index dataset
-        # Holds the (trap, cell_id) description used to index into the
-        # values and is of shape (n_cells, 2)
+        # create index dataset in the h5 file:
+        # the (trap, cell_id) description used to index into the values and has shape (n_cells, 2)
         ix_max_shape = (max_shape[0], 2)
         ix_shape = (0, 2)
         ix_dtype = np.uint16
@@ -236,25 +318,24 @@ class BabyWriter(DynamicWriter):
         save_complex(current_indices, ix_dset)
 
     def __append_edgemasks(self, hgroup, edgemasks, current_indices):
-        # key = "edgemasks"
         val_dset = hgroup["values"]
         ix_dset = hgroup["indices"]
         existing_indices = load_complex(ix_dset)
-        # Check if there are any new labels
+        # check if there are any new labels
         available = np.in1d(current_indices, existing_indices)
         missing = current_indices[~available]
         all_indices = np.concatenate([existing_indices, missing])
-        # Resizing
-        t = perf_counter()
+        # resizing
+        debug_t = perf_counter()  # for timing code for debugging
         n_tps = val_dset.shape[1] + 1
         n_add_cells = len(missing)
-        # RESIZE DATASET FOR TIME and Cells
+        # resize dataset in the time and cell dimensions
         new_shape = (val_dset.shape[0] + n_add_cells, n_tps) + val_dset.shape[
             2:
         ]
         val_dset.resize(new_shape)
-        logging.debug(f"Timing:resizing:{perf_counter() - t}")
-        # Writing data
+        logging.debug(f"Timing:resizing:{perf_counter() - debug_t}")
+        # write data
         cell_indices = np.where(np.in1d(all_indices, current_indices))[0]
         for ix, mask in zip(cell_indices, edgemasks):
             try:
@@ -265,75 +346,112 @@ class BabyWriter(DynamicWriter):
                         e, ix, n_tps, val_dset.shape
                     )
                 )
-        # Save the index values
+        # save the index values
         save_complex(missing, ix_dset)
 
     def write_edgemasks(self, data, keys, hgroup):
+        """
+        Write edgemasks to h5 file.
+
+        Parameters
+        ----------
+        data: list of arrays
+            Data to be written, in the form (trap_ids, cell_labels, edgemasks)
+        keys: list of str
+            Names corresponding to the elements of data.
+            For example: ["trap", "cell_label", "edgemasks"]
+        hgroup: group object
+            Group to write to in h5 file.
+        """
         if not self._traps_initialised:
             self.__init_trap_info()
-        # DATA is TRAP_IDS, CELL_LABELS, EDGEMASKS in a structured array
         key = "edgemasks"
         val_key = "values"
-        # idx_key = "indices"
-        # Length of edgemasks
         traps, cell_labels, edgemasks = data
         n_cells = len(cell_labels)
         hgroup = hgroup.require_group(key)
+        # create complex indices with traps as real part and cell_labels as imaginary part
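+        # e.g. trap 3 and cell label 5 become 3+5j, giving one value per (trap, cell) pair for comparison with np.in1d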
         current_indices = np.array(traps) + 1j * np.array(cell_labels)
         if val_key not in hgroup:
             self.__init_edgemasks(hgroup, edgemasks, current_indices, n_cells)
         else:
             self.__append_edgemasks(hgroup, edgemasks, current_indices)
 
-    def write(self, data, overwrite: list, tp: int = None, meta={}):
+    def write(
+        self, data: dict, overwrite: list, tp: int = None, meta: dict = {}
+    ):
+        """
+        Write data from a Baby instance, including edgemasks.
+
+        Parameters
+        ----------
+        data: dict
+            A dict of datasets and data
+        overwrite: list of str
+            A list of datasets to overwrite
+        tp: int
+            The time point of interest
+        meta: dict, optional
+            Metadata to be written as attributes of the h5 file
+        """
         with h5py.File(self.file, "a") as store:
             hgroup = store.require_group(self.group)
-
+            # write data
             for key, value in data.items():
-                # We're only saving data that has a pre-defined data-type
-                self._check_key(key)
-                try:
-                    if key.startswith("attrs/"):  # metadata
-                        key = key.split("/")[1]  # First thing after attrs
-                        hgroup.attrs[key] = value
-                    elif key in overwrite:
-                        self._overwrite(value, key, hgroup)
-                    elif key == "edgemasks":
-                        keys = ["trap", "cell_label", "edgemasks"]
-                        value = [data[x] for x in keys]
-
-                        edgemask_dset = hgroup.get(key + "/values", None)
-                        if (
-                            # tp > 0
-                            edgemask_dset
-                            and tp < edgemask_dset[()].shape[1]
-                        ):
-                            print(f"BabyWriter: Skipping edgemasks in tp {tp}")
+                if key not in self.datatypes:
+                    raise KeyError(
+                        f"BabyWriter: No defined data type for key {key}"
+                    )
+                else:
+                    try:
+                        if key.startswith("attrs/"):
+                            # metadata
+                            key = key.split("/")[1]
+                            hgroup.attrs[key] = value
+                        elif key in overwrite:
+                            # delete and replace existing dataset
+                            self._overwrite(value, key, hgroup)
+                        elif key == "edgemasks":
+                            keys = ["trap", "cell_label", "edgemasks"]
+                            value = [data[x] for x in keys]
+                            edgemask_dset = hgroup.get(key + "/values", None)
+                            if (
+                                edgemask_dset
+                                and tp < edgemask_dset[()].shape[1]
+                            ):
+                                # data already exists
+                                print(
+                                    f"BabyWriter: Skipping edgemasks in tp {tp}"
+                                )
+                            else:
+                                self.write_edgemasks(value, keys, hgroup)
                         else:
-                            # print(f"BabyWriter: Writing edgemasks in tp {tp}")
-                            self.write_edgemasks(value, keys, hgroup)
-                    else:
-                        self._append(value, key, hgroup)
-                except Exception as e:
-                    print(key, value)
-                    raise (e)
-
-        # Meta
+                            # append or create new dataset
+                            self._append(value, key, hgroup)
+                    except Exception as e:
+                        print(key, value)
+                        raise (e)
+        # write metadata
         for key, value in meta.items():
             hgroup.attrs[key] = value
 
-        return
-
 
 class LinearBabyWriter(DynamicWriter):
-    # TODO make this YAML
+    """
+    Write data stored in a Baby instance to h5 files.
+
+    Assumes the edgemasks are of form ((None, tile_size, tile_size), bool).
+    """
+
+    # TODO make this YAML: Alan: why?
     compression = "gzip"
+    _default_tile_size = 117
     datatypes = {
         "centres": ((None, 2), np.uint16),
         "position": ((None,), np.uint16),
         "angles": ((None,), h5py.vlen_dtype(np.float32)),
         "radii": ((None,), h5py.vlen_dtype(np.float32)),
-        "edgemasks": ((None, tile_size, tile_size), bool),
+        "edgemasks": ((None, _default_tile_size, _default_tile_size), bool),
         "ellipse_dims": ((None, 2), np.float32),
         "cell_label": ((None,), np.uint16),
         "trap": ((None,), np.uint16),
@@ -344,32 +462,63 @@ class LinearBabyWriter(DynamicWriter):
     }
     group = "cell_info"
 
-    def write(self, data, overwrite: list, tp=None, meta={}):
-        # Data is a dictionary, if not, make it one
-        # Overwrite data is a list
+    def write(
+        self, data: dict, overwrite: list, tp: int = None, meta: dict = {}
+    ):
+        """
+        Check that data for the time point does not already exist before writing.
 
+        Parameters
+        ----------
+        data: dict
+            A dict of datasets and data
+        overwrite: list of str
+            A list of datasets to overwrite
+        tp: int
+            The time point of interest
+        meta: dict, optional
+            Metadata to be written as attributes of the h5 file
+        """
         with h5py.File(self.file, "a") as store:
             hgroup = store.require_group(self.group)
             available_tps = hgroup.get("timepoint", None)
+            # write data
             if not available_tps or tp not in np.unique(available_tps[()]):
                 super().write(data, overwrite)
             else:
+                # data already exists
                 print(f"BabyWriter: Skipping tp {tp}")
-
+            # write metadata
             for key, value in meta.items():
                 hgroup.attrs[key] = value
 
 
 class StateWriter(DynamicWriter):
+    """
+    Write information summarising the current state of the pipeline to the 'last_state' group in the h5 file.
+
+    Datatypes are specified with the first variable giving the number of traps and the second giving the shape of the object.
+    """
+
     datatypes = {
+        # the highest cell label assigned for each time point
         "max_lbl": ((None, 1), np.uint16),
+        # how far back we go for tracking
         "tp_back": ((None, 1), np.uint16),
+        # trap labels
         "trap": ((None, 1), np.int16),
+        # cell labels
         "cell_lbls": ((None, 1), np.uint16),
+        # previous cell features for tracking
         "prev_feats": ((None, None), np.float32),
+        # number of images for which a cell has been present
         "lifetime": ((None, 2), np.uint16),
+        # probability of a mother-bud relationship given a bud
         "p_was_bud": ((None, 2), np.float32),
+        # probability of a mother-bud relationship given a mother
         "p_is_mother": ((None, 2), np.float32),
+        # cumulative matrix, over time, of bud assignments
         "ba_cum": ((None, None), np.float32),
     }
     group = "last_state"
@@ -377,29 +526,33 @@ class StateWriter(DynamicWriter):
 
     @staticmethod
     def format_field(states: list, field: str):
-        # Flatten a field in the states list to save as an hdf5 dataset
+        """Flatten a field in the states list to save as an h5 dataset."""
         fields = [pos_state[field] for pos_state in states]
         return fields
 
     @staticmethod
     def format_values_tpback(states: list, val_name: str):
+        """Unpacks a dict of state data into tp_back, trap, value."""
+        # initialise as empty lists
+        # Alan: is this initialisation necessary?
         tp_back, trap, value = [
             [[] for _ in states[0][val_name]] for _ in range(3)
         ]
-
+        # store results as a list of tuples
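+        # each tuple is of the form (tp_back, trap, value)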
         lbl_tuples = [
             (tp_back, trap, cell_label)
             for trap, state in enumerate(states)
             for tp_back, value in enumerate(state[val_name])
             for cell_label in value
         ]
+        # unpack list of tuples to define variables
         if len(lbl_tuples):
             tp_back, trap, value = zip(*lbl_tuples)
-
         return tp_back, trap, value
 
     @staticmethod
     def format_values_traps(states: list, val_name: str):
+        """Format either lifetime, p_was_bud, or p_is_mother variables as a list."""
         formatted = np.array(
             [
                 (trap, clabel_val)
@@ -411,113 +564,106 @@ class StateWriter(DynamicWriter):
 
     @staticmethod
     def pad_if_needed(array: np.ndarray, pad_size: int):
+        """Pad a 2D array with zeros for large indices."""
         padded = np.zeros((pad_size, pad_size)).astype(float)
         length = len(array)
         padded[:length, :length] = array
-
         return padded
 
     def format_states(self, states: list):
+        """Re-format state data into a dict of lists, with one element per per list per state."""
         formatted_state = {"max_lbl": [state["max_lbl"] for state in states]}
         tp_back, trap, cell_label = self.format_values_tpback(
             states, "cell_lbls"
         )
         _, _, prev_feats = self.format_values_tpback(states, "prev_feats")
-
-        # Heterogeneous datasets
+        # store lists in a dict
         formatted_state["tp_back"] = tp_back
         formatted_state["trap"] = trap
         formatted_state["cell_lbls"] = cell_label
         formatted_state["prev_feats"] = np.array(prev_feats)
-
-        # One entry per cell label - tp_back independent
+        # one entry per cell label - tp_back independent
         for val_name in ("lifetime", "p_was_bud", "p_is_mother"):
             formatted_state[val_name] = self.format_values_traps(
                 states, val_name
             )
-
         bacum_max = max([len(state["ba_cum"]) for state in states])
-
         formatted_state["ba_cum"] = np.array(
             [
                 self.pad_if_needed(state["ba_cum"], bacum_max)
                 for state in states
             ]
         )
-
         return formatted_state
 
-    def write(self, data, overwrite: Iterable, tp: int = None):
-        # formatted_data = self.format_states(data)
-        # super().write(data=formatted_data, overwrite=overwrite)
+    def write(self, data: dict, overwrite: list, tp: int = 0):
+        """Write the current state of the pipeline."""
         if len(data):
             last_tp = 0
-            if tp is None:
-                tp = 0
-
             try:
                 with h5py.File(self.file, "r") as f:
                     gr = f.get(self.group, None)
                     if gr:
                         last_tp = gr.attrs.get("tp", 0)
-
-                # print(f"{ self.file } - tp: {tp}, last_tp: {last_tp}")
                 if tp == 0 or tp > last_tp:
-                    # print(f"Writing timepoint {tp}")
+                    # write
                     formatted_data = self.format_states(data)
                     super().write(data=formatted_data, overwrite=overwrite)
                     with h5py.File(self.file, "a") as f:
-                        # print(f"Writing tp {tp}")
+                        # record that data for the timepoint has been written
                         f[self.group].attrs["tp"] = tp
                 elif tp > 0 and tp <= last_tp:
-                    print(f"BabyWriter: Skipping timepoint {tp}")
+                    # data already present
+                    print(f"StateWriter: Skipping timepoint {tp}")
             except Exception as e:
                 raise (e)
         else:
-            print("Skipping overwriting empty state")
+            print("Skipping overwriting: no data")
 
 
 #################### Extraction version ###############################
 class Writer(BridgeH5):
-    """
-    Class in charge of transforming data into compatible formats
+    """Class to transform data into compatible structures."""
 
-    Decoupling interface from implementation!
+    # Alan: when is this used?
 
-    Parameters
-    ----------
-        filename: str Name of file to write into
+    def __init__(self, filename, flag=None, compression="gzip"):
+        """
+        Initialise the writer.
+
+        Parameters
+        ----------
+        filename: str
+            Name of file to write into
         flag: str, default=None
             Flag to pass to the default file reader. If None the file remains closed.
-        compression: str, default=None
-            Compression method passed on to h5py writing functions (only used for
-        dataframes and other array-like data.)
-    """
-
-    def __init__(self, filename, compression=None):
-        super().__init__(filename, flag=None)
-
-        if compression is None:
-            self.compression = "gzip"
+        compression: str, default="gzip"
+            Compression method passed on to h5py writing functions (only used for dataframes and other array-like data).
+        """
+        super().__init__(filename, flag=flag)
+        self.compression = compression
 
     def write(
         self,
         path: str,
         data: Iterable = None,
-        meta: Dict = {},
+        meta: dict = {},
         overwrite: str = None,
     ):
         """
+        Write data and metadata to a particular path in the h5 file.
+
         Parameters
         ----------
         path : str
-            Path inside h5 file to write into.
-        data : Iterable, default = None
-        meta : Dict, default = {}
-
+            Path inside h5 file into which to write.
+        data : Iterable, optional
+            Data to write into the h5 file
+        meta : dict, optional
+            Metadata to write as attributes of the path
+        overwrite : str, optional
+            If "overwrite", delete any existing data at the path before writing
         """
         self.id_cache = {}
         with h5py.File(self.filename, "a") as f:
+            # Alan, haven't we already opened the h5 file through BridgeH5's init?
             if overwrite == "overwrite":  # TODO refactor overwriting
                 if path in f:
                     del f[path]
@@ -529,49 +675,57 @@ class Writer(BridgeH5):
             # elif overwrite == "skip":
             #     if path in f:
             #         logging.debug("Skipping dataset {}".format(path))
-
             logging.debug(
                 "{} {} to {} and {} metadata fields".format(
                     overwrite, type(data), path, len(meta)
                 )
             )
+            # write data
             if data is not None:
                 self.write_dset(f, path, data)
+            # write metadata
             if meta:
                 for attr, metadata in meta.items():
                     self.write_meta(f, path, attr, data=metadata)
 
     def write_dset(self, f: h5py.File, path: str, data: Iterable):
+        """Write data in different ways depending on its type to an open h5 file."""
+        # data is a dataframe
         if isinstance(data, pd.DataFrame):
             self.write_pd(f, path, data, compression=self.compression)
+        # data is a pandas MultiIndex
         elif isinstance(data, pd.MultiIndex):
+            # Alan: should we still not compress here?
             self.write_index(f, path, data)  # , compression=self.compression)
+        # data is a dictionary of dataframes
         elif isinstance(data, Dict) and np.all(
             [isinstance(x, pd.DataFrame) for x in data.values]
         ):
             for k, df in data.items():
                 self.write_dset(f, path + f"/{k}", df)
+        # data is an iterable
         elif isinstance(data, Iterable):
             self.write_arraylike(f, path, data)
+        # data is a float or integer
         else:
             self.write_atomic(data, f, path)
 
     def write_meta(self, f: h5py.File, path: str, attr: str, data: Iterable):
+        """Write metadata to an open h5 file."""
         obj = f.require_group(path)
-
         obj.attrs[attr] = data
 
     @staticmethod
     def write_arraylike(f: h5py.File, path: str, data: Iterable, **kwargs):
+        """Write an iterable."""
         if path in f:
             del f[path]
-
         narray = np.array(data)
-
-        chunks = None
         if narray.any():
             chunks = (1, *narray.shape[1:])
-
+        else:
+            chunks = None
+        # create dset
         dset = f.create_dataset(
             path,
             shape=narray.shape,
@@ -579,10 +733,12 @@ class Writer(BridgeH5):
             dtype="int",
             compression=kwargs.get("compression", None),
         )
+        # add data to dset
         dset[()] = narray
 
     @staticmethod
     def write_index(f, path, pd_index, **kwargs):
+        """Write a multi-index dataframe."""
         f.require_group(path)  # TODO check if we can remove this
         for i, name in enumerate(pd_index.names):
             ids = pd_index.get_level_values(i)
@@ -597,12 +753,14 @@ class Writer(BridgeH5):
             indices[()] = ids
 
     def write_pd(self, f, path, df, **kwargs):
+        """Write a dataframe."""
         values_path = (
             path + "values" if path.endswith("/") else path + "/values"
         )
         if path not in f:
-            max_ncells = 2e5
 
+            # create dataset and write data
+            max_ncells = 2e5
             max_tps = 1e3
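+            # the values dataset has one row per cell and one column per time point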
             f.create_dataset(
                 name=values_path,
@@ -616,6 +774,7 @@ class Writer(BridgeH5):
             dset = f[values_path]
             dset[()] = df.values
 
+            # create dataset and write indices
             if not len(df):  # Only write more if not empty
                 return None
 
@@ -630,7 +789,7 @@ class Writer(BridgeH5):
                 )
                 dset = f[indices_path]
                 dset[()] = df.index.get_level_values(level=name).tolist()
-
+            # create dataset and write columns
             if (
                 df.columns.dtype == np.int
                 or df.columns.dtype == np.dtype("uint")
@@ -648,9 +807,11 @@ class Writer(BridgeH5):
             else:
                 f[path].attrs["columns"] = df.columns.tolist()
         else:
+
+            # path exists
             dset = f[values_path]
 
-            # Filter out repeated timepoints
+            # filter out repeated timepoints
             new_tps = set(df.columns)
             if path + "/timepoint" in f:
                 new_tps = new_tps.difference(f[path + "/timepoint"][()])
@@ -659,16 +820,18 @@ class Writer(BridgeH5):
             if (
                 not hasattr(self, "id_cache")
                 or df.index.nlevels not in self.id_cache
-            ):  # Use cache dict to store previously-obtained indices
+            ):
+                # use cache dict to store previously obtained indices
                 self.id_cache[df.index.nlevels] = {}
                 existing_ids = self.get_existing_ids(
                     f, [path + "/" + x for x in df.index.names]
                 )
-                # Split indices in existing and additional
+                # split indices in existing and additional
                 new = df.index.tolist()
                 if (
                     df.index.nlevels == 1
-                ):  # Cover for cases with a single index
+                ):
+                    # cover cases with a single index
                     new = [(x,) for x in df.index.tolist()]
                 (
                     found_multis,
@@ -681,7 +844,7 @@ class Writer(BridgeH5):
                     locate_indices(existing_ids, found_multis)
                 )
 
-                # We must sort our indices for h5py indexing
+                # sort indices for h5 indexing
                 incremental_existing = np.argsort(found_indices)
                 self.id_cache[df.index.nlevels][
                     "found_indices"
@@ -706,7 +869,7 @@ class Writer(BridgeH5):
             ].values
             ncells, ntps = f[values_path].shape
 
-            # Add found cells
+            # add found cells
             dset.resize(dset.shape[1] + df.shape[1], axis=1)
             dset[:, ntps:] = np.nan
 
@@ -715,13 +878,13 @@ class Writer(BridgeH5):
                 "found_indices"
             ]
 
-            # Cover for case when all labels are new
+            # update existing cells only if some labels already exist; all labels may be new
             if found_indices_sorted.any():
                 # h5py does not allow bidimensional indexing,
                 # so we have to iterate over the columns
                 for i, tp in enumerate(df.columns):
                     dset[found_indices_sorted, tp] = existing_values[:, i]
-            # Add new cells
+            # add new cells
             n_newcells = len(
                 self.id_cache[df.index.nlevels]["additional_multis"]
             )
@@ -756,24 +919,20 @@ class Writer(BridgeH5):
 
     @staticmethod
     def get_existing_ids(f, paths):
-        # Fetch indices and convert them to a (nentries, nlevels) ndarray
+        """Fetch indices and convert them to a (nentries, nlevels) ndarray."""
         return np.array([f[path][()] for path in paths]).T
 
     @staticmethod
     def find_ids(existing, new):
-        # Compare two tuple sets and return the intersection and difference
-        # (elements in the 'new' set not in 'existing')
+        """Compare two tuple sets and return the intersection and difference (elements in the 'new' set not in 'existing')."""
         set_existing = set([tuple(*x) for x in zip(existing.tolist())])
         existing_cells = np.array(list(set_existing.intersection(new)))
         new_cells = np.array(list(set(new).difference(set_existing)))
+        return existing_cells, new_cells
-        return (
-            existing_cells,
-            new_cells,
-        )
 
 
-# @staticmethod
 def locate_indices(existing, new):
     if new.any():
         if new.shape[1] > 1:
@@ -794,7 +953,7 @@ def locate_indices(existing, new):
 
 
 def _tuple_or_int(x):
-    # Convert tuple to int if it only contains one value
+    """Convert tuple to int if it only contains one value."""
     if len(x) == 1:
         return x[0]
     else: