diff --git a/src/agora/io/bridge.py b/src/agora/io/bridge.py
index 4a2f6093e7d8aaf4ed8e0e0f8abfcfaffb796449..478408f37d1f0724a232876f0f5ce0da5b02b585 100644
--- a/src/agora/io/bridge.py
+++ b/src/agora/io/bridge.py
@@ -1,5 +1,5 @@
 """
-Tools to interact with hdf5 files and handle data consistently.
+Tools to interact with h5 files and handle data consistently.
 """
 import collections
 from itertools import chain, groupby, product
@@ -13,26 +13,28 @@ import yaml
 
 class BridgeH5:
     """
-    Base class to interact with h5 data stores.
-    It also contains functions useful to predict how long should segmentation take.
+    Base class to interact with h5 files.
+
+    It includes functions that predict how long segmentation will take.
     """
 
     def __init__(self, filename, flag="r"):
+        """Initialise with the name of the h5 file."""
         self.filename = filename
         if flag is not None:
             self._hdf = h5py.File(filename, flag)
-
-            self._filecheck
+            self._filecheck()
 
     def _filecheck(self):
         assert "cell_info" in self._hdf, "Invalid file. No 'cell_info' found."
 
     def close(self):
+        """Close the h5 file."""
         self._hdf.close()
 
     @property
     def meta_h5(self) -> t.Dict[str, t.Any]:
-        # Return metadata as indicated in h5 file
+        """Return metadata, defining it if necessary."""
         if not hasattr(self, "_meta_h5"):
             with h5py.File(self.filename, "r") as f:
                 self._meta_h5 = dict(f.attrs)
@@ -44,24 +46,24 @@ class BridgeH5:
 
     @staticmethod
     def get_consecutives(tree, nstepsback):
-        # Receives a sorted tree and returns the keys of consecutive elements
-        vals = {k: np.array(list(v)) for k, v in tree.items()}  # get tp level
+        """Receives a sorted tree and returns the keys of consecutive elements."""
+        # get tp level
+        vals = {k: np.array(list(v)) for k, v in tree.items()}
+        # get indices of consecutive elements
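+        # i.e. positions where time points n+1 apart differ by exactly n+1 steps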
         where_consec = [
             {
                 k: np.where(np.subtract(v[n + 1 :], v[: -n - 1]) == n + 1)[0]
                 for k, v in vals.items()
             }
             for n in range(nstepsback)
-        ]  # get indices of consecutive elements
+        ]
         return where_consec
 
     def get_npairs(self, nstepsback=2, tree=None):
         if tree is None:
             tree = self.cell_tree
-
         consecutive = self.get_consecutives(tree, nstepsback=nstepsback)
         flat_tree = flatten(tree)
-
         n_predictions = 0
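+        # each pair of linked time points contributes n_cells(tp1) * n_cells(tp2) possible matches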
         for i, d in enumerate(consecutive, 1):
             flat = list(chain(*[product([k], list(v)) for k, v in d.items()]))
@@ -70,55 +72,49 @@ class BridgeH5:
                 n_predictions += len(flat_tree.get(p[0], [])) * len(
                     flat_tree.get(p[1], [])
                 )
-
         return n_predictions
 
     def get_npairs_over_time(self, nstepsback=2):
         tree = self.cell_tree
         npairs = []
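+        # count pairs using only the data available up to each processed time point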
-        for t in self._hdf["cell_info"]["processed_timepoints"][()]:
+        for tp in self._hdf["cell_info"]["processed_timepoints"][()]:
             tmp_tree = {
-                k: {k2: v2 for k2, v2 in v.items() if k2 <= t}
+                k: {k2: v2 for k2, v2 in v.items() if k2 <= tp}
                 for k, v in tree.items()
             }
             npairs.append(self.get_npairs(tree=tmp_tree))
-
         return np.diff(npairs)
 
     def get_info_tree(
         self, fields: Union[tuple, list] = ("trap", "timepoint", "cell_label")
     ):
         """
-        Returns traps, time points and labels for this position in form of a tree
-        in the hierarchy determined by the argument fields. Note that it is
-        compressed to non-empty elements and timepoints.
+        Return traps, time points and labels for this position in the form of a tree in the hierarchy determined by the argument fields.
+
+        Note that it is compressed to non-empty elements and timepoints.
 
         Default hierarchy is:
         - trap
         - time point
         - cell label
 
-        This function currently produces trees of depth 3, but it can easily be
-        extended for deeper trees if needed (e.g. considering groups,
-        chambers and/or positions).
+        This function currently produces trees of depth 3, but it can easily be extended for deeper trees if needed (e.g. considering groups, chambers and/or positions).
 
         Parameters
         ----------
-            fields: Fields to fetch from 'cell_info' inside the hdf5 storage
+        fields: list of str
+            Fields to fetch from 'cell_info' inside the h5 file.
 
         Returns
         ----------
-            Nested dictionary where keys (or branches) are the upper levels
-             and the leaves are the last element of :fields:.
+        Nested dictionary where keys (or branches) are the upper levels and the leaves are the last element of :fields:.
         """
         zipped_info = (*zip(*[self._hdf["cell_info"][f][()] for f in fields]),)
-
         return recursive_groupsort(zipped_info)
 
 
 def groupsort(iterable: Union[tuple, list]):
-    # Sorts iterable and returns a dictionary where the values are grouped by the first element.
-
+    """Sorts iterable and returns a dictionary where the values are grouped by the first element."""
     iterable = sorted(iterable, key=lambda x: x[0])
     grouped = {
         k: [x[1:] for x in v] for k, v in groupby(iterable, lambda x: x[0])
@@ -127,17 +123,18 @@ def groupsort(iterable: Union[tuple, list]):
 
 
 def recursive_groupsort(iterable):
-    # Recursive extension of groupsort
+    """Recursive extension of groupsort."""
     if len(iterable[0]) > 1:
         return {
             k: recursive_groupsort(v) for k, v in groupsort(iterable).items()
         }
-    else:  # Only two elements in list
+    else:
+        # each element contains a single value, so return a flat list of leaves
         return [x[0] for x in iterable]
 
 
 def flatten(d, parent_key="", sep="_"):
-    """Flatten nested dict. Adapted from https://stackoverflow.com/a/6027615"""
+    """Flatten nested dict. Adapted from https://stackoverflow.com/a/6027615."""
     items = []
     for k, v in d.items():
         new_key = parent_key + (k,) if parent_key else (k,)
@@ -149,18 +146,19 @@ def flatten(d, parent_key="", sep="_"):
 
 
 def attrs_from_h5(fpath: str):
-    """Return attributes as dict from h5 file"""
+    """Return attributes as dict from an h5 file."""
     with h5py.File(fpath, "r") as f:
         return dict(f.attrs)
 
 
 def parameters_from_h5(fpath: str):
+    """Return parameters from an h5 file."""
     attrs = attrs_from_h5(fpath)
     return yaml.safe_load(attrs["parameters"])
 
 
 def image_creds_from_h5(fpath: str):
-    """Return image id and server credentials from h5"""
+    """Return image id and server credentials from an h5."""
     attrs = attrs_from_h5(fpath)
     return (
         attrs["image_id"],
diff --git a/src/agora/io/writer.py b/src/agora/io/writer.py
index 66fec1f40d69c08412e246be4e21cd20d6fdb410..6c07e40609f188006fe703debb605dd14a2426a3 100644
--- a/src/agora/io/writer.py
+++ b/src/agora/io/writer.py
@@ -620,22 +620,23 @@ class StateWriter(DynamicWriter):
 
 #################### Extraction version ###############################
 class Writer(BridgeH5):
-    """
-    Class in charge of transforming data into compatible structures.
+    """Class to transform data into compatible structures."""
 
-    Decoupling interface from implementation!
-
-    Parameters
-    ----------
-    filename: str
-        Name of file to write into
-    flag: str, default=None
-        Flag to pass to the default file reader. If None the file remains closed.
-    compression: str, default="gzip"
-        Compression method passed on to h5py writing functions (only used for dataframes and other array-like data).
-    """
+    # Alan: when is this used?
 
     def __init__(self, filename, flag=None, compression="gzip"):
+        """
+        Initialise the writer.
+
+        Parameters
+        ----------
+        filename: str
+            Name of the file to write into.
+        flag: str, default=None
+            Flag to pass to the default file reader. If None the file remains closed.
+        compression: str, default="gzip"
+            Compression method passed on to h5py writing functions (only used for dataframes and other array-like data).
+        """
         super().__init__(filename, flag=flag)
         self.compression = compression
 
@@ -647,16 +648,19 @@ class Writer(BridgeH5):
         overwrite: str = None,
     ):
         """
+        Write data and metadata to a particular path in the h5 file.
+
         Parameters
         ----------
         path : str
-            Path inside h5 file to write into.
-        data : Iterable, default = None
-        meta : dict, default = {}
-
+            Path inside h5 file into which to write.
+        data : Iterable, optional
+        meta : dict, optional
+        overwrite : str, optional
         """
         self.id_cache = {}
         with h5py.File(self.filename, "a") as f:
+            # Alan, haven't we already opened the h5 file through BridgeH5's init?
             if overwrite == "overwrite":  # TODO refactor overwriting
                 if path in f:
                     del f[path]
@@ -668,49 +672,57 @@ class Writer(BridgeH5):
             # elif overwrite == "skip":
             #     if path in f:
             #         logging.debug("Skipping dataset {}".format(path))
-
             logging.debug(
                 "{} {} to {} and {} metadata fields".format(
                     overwrite, type(data), path, len(meta)
                 )
             )
+            # write data
             if data is not None:
                 self.write_dset(f, path, data)
+            # write metadata
             if meta:
                 for attr, metadata in meta.items():
                     self.write_meta(f, path, attr, data=metadata)
 
     def write_dset(self, f: h5py.File, path: str, data: Iterable):
+        """Write data in different ways depending on its type to an open h5 file."""
+        # data is a dataframe
         if isinstance(data, pd.DataFrame):
             self.write_pd(f, path, data, compression=self.compression)
+        # data is a pandas multi-index
         elif isinstance(data, pd.MultiIndex):
+            # Alan: should we still not compress here?
             self.write_index(f, path, data)  # , compression=self.compression)
+        # data is a dictionary of dataframes
         elif isinstance(data, Dict) and np.all(
-            [isinstance(x, pd.DataFrame) for x in data.values]
+            [isinstance(x, pd.DataFrame) for x in data.values()]
         ):
             for k, df in data.items():
                 self.write_dset(f, path + f"/{k}", df)
+        # data is an iterable
         elif isinstance(data, Iterable):
             self.write_arraylike(f, path, data)
+        # data is a float or integer
         else:
             self.write_atomic(data, f, path)
 
     def write_meta(self, f: h5py.File, path: str, attr: str, data: Iterable):
+        """Write metadata to an open h5 file."""
         obj = f.require_group(path)
-
         obj.attrs[attr] = data
 
     @staticmethod
     def write_arraylike(f: h5py.File, path: str, data: Iterable, **kwargs):
+        """Write an iterable."""
         if path in f:
             del f[path]
-
         narray = np.array(data)
-
-        chunks = None
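+        # chunk by rows: each chunk holds one element along the first axis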
         if narray.any():
             chunks = (1, *narray.shape[1:])
-
+        else:
+            chunks = None
+        # create dset
         dset = f.create_dataset(
             path,
             shape=narray.shape,
@@ -718,10 +730,12 @@ class Writer(BridgeH5):
             dtype="int",
             compression=kwargs.get("compression", None),
         )
+        # add data to dset
         dset[()] = narray
 
     @staticmethod
     def write_index(f, path, pd_index, **kwargs):
+        """Write a multi-index dataframe."""
         f.require_group(path)  # TODO check if we can remove this
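+        # write each level of the multi-index as its own dataset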
         for i, name in enumerate(pd_index.names):
             ids = pd_index.get_level_values(i)
@@ -736,12 +750,14 @@ class Writer(BridgeH5):
             indices[()] = ids
 
     def write_pd(self, f, path, df, **kwargs):
+        """Write a dataframe."""
         values_path = (
             path + "values" if path.endswith("/") else path + "/values"
         )
         if path not in f:
-            max_ncells = 2e5
 
+            # create dataset and write data
+            max_ncells = 2e5
             max_tps = 1e3
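+            # upper bounds on the number of cells and time points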
             f.create_dataset(
                 name=values_path,
@@ -754,7 +770,7 @@ class Writer(BridgeH5):
             )
             dset = f[values_path]
             dset[()] = df.values
-
+            # create dataset and write indices
             for name in df.index.names:
                 indices_path = "/".join((path, name))
                 f.create_dataset(
@@ -766,7 +782,7 @@ class Writer(BridgeH5):
                 )
                 dset = f[indices_path]
                 dset[()] = df.index.get_level_values(level=name).tolist()
-
+            # create dataset and write columns
             if (
                 df.columns.dtype == np.int
                 or df.columns.dtype == np.dtype("uint")
@@ -784,9 +800,11 @@ class Writer(BridgeH5):
             else:
                 f[path].attrs["columns"] = df.columns.tolist()
         else:
+            # the path exists, so append to the existing dataset
             dset = f[values_path]
 
-            # Filter out repeated timepoints
+            # filter out repeated timepoints
             new_tps = set(df.columns)
             if path + "/timepoint" in f:
                 new_tps = new_tps.difference(f[path + "/timepoint"][()])
@@ -795,16 +813,18 @@ class Writer(BridgeH5):
             if (
                 not hasattr(self, "id_cache")
                 or df.index.nlevels not in self.id_cache
-            ):  # Use cache dict to store previously-obtained indices
+            ):
+                # use cache dict to store previously obtained indices
                 self.id_cache[df.index.nlevels] = {}
                 existing_ids = self.get_existing_ids(
                     f, [path + "/" + x for x in df.index.names]
                 )
-                # Split indices in existing and additional
+                # split indices into existing and additional
                 new = df.index.tolist()
                 if (
                     df.index.nlevels == 1
-                ):  # Cover for cases with a single index
+                ):
+                    # cover cases with a single index
                     new = [(x,) for x in df.index.tolist()]
                 (
                     found_multis,
@@ -817,7 +837,7 @@ class Writer(BridgeH5):
                     locate_indices(existing_ids, found_multis)
                 )
 
-                # We must sort our indices for h5py indexing
+                # sort indices for h5 indexing
                 incremental_existing = np.argsort(found_indices)
                 self.id_cache[df.index.nlevels][
                     "found_indices"
@@ -842,7 +862,7 @@ class Writer(BridgeH5):
             ].values
             ncells, ntps = f[values_path].shape
 
-            # Add found cells
+            # add found cells
             dset.resize(dset.shape[1] + df.shape[1], axis=1)
             dset[:, ntps:] = np.nan
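+            # the new columns start as NaN and are filled in below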
 
@@ -851,13 +871,13 @@ class Writer(BridgeH5):
                 "found_indices"
             ]
 
-            # Cover for case when all labels are new
+            # fill in existing cells (skipped when all labels are new)
             if found_indices_sorted.any():
                 # h5py does not allow bidimensional indexing,
                 # so we have to iterate over the columns
                 for i, tp in enumerate(df.columns):
                     dset[found_indices_sorted, tp] = existing_values[:, i]
-            # Add new cells
+            # add new cells
             n_newcells = len(
                 self.id_cache[df.index.nlevels]["additional_multis"]
             )
@@ -892,24 +912,20 @@ class Writer(BridgeH5):
 
     @staticmethod
     def get_existing_ids(f, paths):
-        # Fetch indices and convert them to a (nentries, nlevels) ndarray
+        """Fetch indices and convert them to a (nentries, nlevels) ndarray."""
         return np.array([f[path][()] for path in paths]).T
 
     @staticmethod
     def find_ids(existing, new):
-        # Compare two tuple sets and return the intersection and difference
-        # (elements in the 'new' set not in 'existing')
+        """Compare two tuple sets and return the intersection and difference (elements in the 'new' set not in 'existing')."""
         set_existing = set([tuple(*x) for x in zip(existing.tolist())])
         existing_cells = np.array(list(set_existing.intersection(new)))
         new_cells = np.array(list(set(new).difference(set_existing)))
+        return existing_cells, new_cells
 
-        return (
-            existing_cells,
-            new_cells,
-        )
 
 
-# @staticmethod
 def locate_indices(existing, new):
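+    """Return the row positions in existing that match each id in new."""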
     if new.any():
         if new.shape[1] > 1:
@@ -930,7 +946,7 @@ def locate_indices(existing, new):
 
 
 def _tuple_or_int(x):
-    # Convert tuple to int if it only contains one value
+    """Convert tuple to int if it only contains one value."""
     if len(x) == 1:
         return x[0]
     else: