diff --git a/.gitignore b/.gitignore
index fb11995ecf7a23417b7105faac4fd539370d9dbb..ef97f288f045c275e9d39282062406480412353a 100644
--- a/.gitignore
+++ b/.gitignore
@@ -1,4 +1,5 @@
 *.pyc
+*.asv
 .*.sw?
 *~
 .DS_Store
@@ -13,3 +14,5 @@ baby-server-test.png
 *.egg-info/
 **/venv/
 /tf1_venv/
+/dist
+/tests/test-modelset-cache/
diff --git a/LICENSE.txt b/LICENSE.txt
index df6ab2d7a11345bfd94c81ba9e31d37c44a48791..e95c479f323c327c150ff2f9276739747314ad6e 100644
--- a/LICENSE.txt
+++ b/LICENSE.txt
@@ -1,12 +1,14 @@
 If you publish results that make use of this software or the Birth Annotator
 for Budding Yeast algorithm, please cite:
-Julian M J Pietsch, Alán Muñoz, Diane Adjavon, Ivan B N Clark, Peter S Swain,
-2021, Birth Annotator for Budding Yeast (in preparation).
+Pietsch, J.M.J., Muñoz, A.F., Adjavon, D.-Y.A., Farquhar, I., Clark, I.B.N.,
+and Swain, P.S. (2023). Determining growth rates from bright-field images of
+budding cells through identifying overlaps. eLife. 12:e79812.
+https://doi.org/10.7554/eLife.79812
 
 
 The MIT License (MIT)
 
-Copyright (c) Julian Pietsch, Alán Muñoz and Diane Adjavon 2021
+Copyright (c) Julian Pietsch, Alán Muñoz and Diane Adjavon 2023
 
 Permission is hereby granted, free of charge, to any person obtaining a copy
 of this software and associated documentation files (the "Software"), to deal
diff --git a/MANIFEST.in b/MANIFEST.in
deleted file mode 100644
index bb4ac0c66c1510fa03ad47f7727055fcdcfbd778..0000000000000000000000000000000000000000
--- a/MANIFEST.in
+++ /dev/null
@@ -1,4 +0,0 @@
-graft python/baby
-include README.md
-include LICENSE.txt
-include AUTHORS.txt
diff --git a/README.md b/README.md
index c59fbc1ceef00ec54d49b71d2b4a7fbd64f01008..a77c33d75f38ab6a0bb883d767a105e8e04bdcf1 100644
--- a/README.md
+++ b/README.md
@@ -12,92 +12,97 @@ budding cells from bright-field stacks. The Birth Annotator for Budding Yeast
 
 The algorithm is described in:
 
-Julian M J Pietsch, Alán F Muñoz, Diane-Yayra A Adjavon, Ivan B N Clark, Peter
-S Swain, 2022, A label-free method to track individuals and lineages of
-budding cells (in submission).
+[Julian MJ Pietsch, Alán F Muñoz, Diane-Yayra A Adjavon, Iseabail Farquhar,
+Ivan BN Clark, Peter S Swain. (2023). Determining growth rates from
+bright-field images of budding cells through identifying overlaps. eLife.
+12:e79812.](https://doi.org/10.7554/eLife.79812)
+
 
 ## Installation
 
-BABY can be used with Python versions 3.6-3.8 (see below for details). If you
-wish to use the latest compatible versions of all packages, BABY can simply be
-installed by first obtaining this repository (e.g., `git clone
-https://git.ecdf.ed.ac.uk/jpietsch/baby.git`), and then running pip on the 
-repository directory:
+We recommend installing BABY and all its requirements in a virtual environment
+(e.g., `conda create` if you are using Anaconda, or `python3 -m venv` otherwise).
+
+If you do not have a GPU and simply wish to use the latest compatible versions
+of all packages, BABY can be installed by first obtaining this repository
+(e.g., `git clone https://git.ecdf.ed.ac.uk/jpietsch/baby.git`), and then
+using pip:
 
 ```bash
-> cd baby
-> pip install .
+> pip install baby/
 ```
 
-NB: The '.' is important!
-
-If you pull new changes, you need to update by running: `pip install -U .` from
-within the repository directory.
+NB: You can update by running: `pip install -U baby/`.
 
 *Developers:* You may prefer to install an editable version:
 
 ```bash
-> pip install -e .
+> pip install -e baby/
 ```
 
-This avoids the need to run the update command.
-
-**Requirements for Python and TensorFlow**
-
-BABY requires Python 3 and [TensorFlow](https://www.tensorflow.org). The
-models were trained in TensorFlow 1.14.0, but are compatible with versions of
-TensorFlow up to 2.3.4. The required version of Python depends on the version
-of TensorFlow you choose. We recommend either:
+### Python and TensorFlow version
 
-- Python 3.6 and TensorFlow 1.14,
-- Python 3.7 and TensorFlow 1.15, or
-- Python 3.8 and TensorFlow 2.3.  
+BABY requires Python 3 and [TensorFlow](https://www.tensorflow.org).
+Different versions of TensorFlow have different Python version requirements.
+You can find a table of matching versions
+[here](https://www.tensorflow.org/install/source#tested_build_configurations).
 
-In any case, it is recommended that you install TensorFlow and all other
-required packages into a virtual environment (i.e., `conda create` if you are
-using Anaconda, or `python3 -m venv` otherwise).
+Our models were trained with TensorFlow version 2.8, but have been tested up
+to version 2.14.
 
-By default, BABY will trigger installation of the highest compatible version of
-TensorFlow. If you want to use an earlier version as suggested above, then
-first install that version in your virtual environment by running:
+By default, BABY will trigger installation of the highest compatible version
+of TensorFlow. If you want to use an earlier version, then first install that
+version in your virtual environment by running:
 
 ```bash
-> pip install tensorflow==1.14
+> pip install tensorflow==2.8
 ```
 
 and then follow the instructions for installing BABY as above.
 
-**NB:** To make use of a GPU you should also follow the other [set up
-instructions](https://www.tensorflow.org/install/gpu).
+### Running with GPU
 
-**NB:** For `tensorflow==1.14`, you will also need to downgrade the default
-version of `h5py`: 
+To make use of a GPU you should follow the [TensorFlow setup
+instructions](https://www.tensorflow.org/install/gpu) before installing BABY.
 
-```bash
-> pip uninstall h5py
-> pip install h5py==2.9.0
-```
+BABY can make use of Metal on M1/M2 Macs by following the instructions
+[here](https://developer.apple.com/metal/tensorflow-plugin/).
+
+
+## Quickstart using the Python API
 
-## Run using the Python API
+The BABY algorithm makes use of several machine learning models that are
+defined as a model set. Various model sets are available, and each has been
+optimised for a particular species, microfluidics device, pixel size, channel
+and number of input Z sections.
 
-Create a new `BabyBrain` with one of the model sets. The `brain` contains
-all the models and parameters for segmenting and tracking cells.
+You can get a list of available model sets and the types of input they were
+trained for using the `meta` function in the `modelsets` module:
 
 ```python
->>> from baby import BabyBrain, BabyCrawler, modelsets
->>> modelset = modelsets()['evolve_brightfield_60x_5z']
->>> brain = BabyBrain(**modelset)
+>>> from baby import modelsets
+>>> modelsets.meta()
+```
+
+You then load your favourite model set as a `BabyBrain` object, which
+coordinates all the models and parameters in the set to produce tracked and
+segmented outlines from input images. You can get a `BabyBrain` for a given
+model set using the `get` function in the `modelsets` module:
+
+```python
+>>> brain = modelsets.get('yeast-alcatras-brightfield-EMCCD-60x-5z')
 ```
 
 For each time course you want to process, instantiate a new `BabyCrawler`. The
 crawler keeps track of cells between time steps.
 
 ```python
+>>> from baby import BabyCrawler
 >>> crawler = BabyCrawler(brain)
 ```
 
-Load an image time series (from the `tests` subdirectory in this example). The
-image should have shape (x, y, z).
+Load an image time series (from the `tests` subdirectory in this repository).
+The image should have shape (x, y, z).
 
 ```python
 >>> from baby.io import load_tiled_image
@@ -150,12 +155,6 @@ requests using:
 > baby-phone
 ```
 
-or on windows:
-
-```
-> baby-phone.exe
-```
-
 Server runs by default on [http://0.0.0.0:5101](). HTTP requests need to be
 sent to the correct URL endpoint, but the HTTP API is currently undocumented.
 The primary client implementation is in Matlab.
diff --git a/hatch.toml b/hatch.toml
new file mode 100644
index 0000000000000000000000000000000000000000..10caa8e586007468886f88c811e30e50f17e84da
--- /dev/null
+++ b/hatch.toml
@@ -0,0 +1,12 @@
+[envs.test]
+python = "3.10"
+dependencies = [
+  "pytest",
+]
+[envs.test.env-vars]
+BABY_MODELSETS_PATH = "tests/test-modelset-cache"
+
+[envs.test_cache]
+template = "test"
+[envs.test_cache.env-vars]
+BABY_MODELSETS_PATH = ""
diff --git a/pyproject.toml b/pyproject.toml
new file mode 100644
index 0000000000000000000000000000000000000000..b336e92a4ee32960b0a9ac0de3820477adbf0646
--- /dev/null
+++ b/pyproject.toml
@@ -0,0 +1,68 @@
+[build-system]
+requires = ["hatchling"]
+build-backend = "hatchling.build"
+
+[project]
+name = "baby-seg"
+dynamic = ["version"]
+description = "Birth Annotator for Budding Yeast"
+readme = "README.md"
+license = "MIT"
+license-files = { paths = ["LICENSE.txt"] }
+authors = [
+    {name = "Julian Pietsch", email = "julian.pietsch@synmikro.mpi-marburg.mpg.de"}
+]
+requires-python = ">=3.6"
+dependencies = [
+    "aiohttp",
+    "gaussianprocessderivatives",
+    "imageio",
+    "matplotlib",
+    "numpy",
+    "pandas",
+    "pillow<9",
+    "requests",
+    "scikit-image",
+    "scikit-learn<1.3",
+    "scipy",
+    "tensorflow>=1.14",
+    "tensorflow-metal; platform_system == 'Darwin' and platform_machine == 'arm64'",
+    "tqdm",
+]
+classifiers = [
+    "Programming Language :: Python :: 3",
+    "License :: OSI Approved :: MIT License",
+    "Operating System :: OS Independent",
+    "Intended Audience :: Science/Research",
+    "Topic :: Scientific/Engineering :: Artificial Intelligence",
+    "Topic :: Scientific/Engineering :: Image Processing",
+]
+
+[project.optional-dependencies]
+dev = [
+    "elasticdeform",
+    "keras_tuner",
+    "pytest",
+]
+
+[project.scripts]
+baby-fit-grs = "baby.postprocessing:main"
+baby-phone = "baby.server:main"
+baby-race = "baby.speed_tests:main"
+
+[project.urls]
+Homepage = "https://git.ecdf.ed.ac.uk/swain-lab/baby"
+
+[tool.hatch.version]
+path = "python/baby/__init__.py"
+
+[tool.hatch.build.targets.wheel]
+packages = [
+    "python/baby",
+]
+
+[tool.hatch.build.targets.sdist]
+include = [
+    "/python",
+    "/AUTHORS.txt",
+]
diff --git a/python/baby/__init__.py b/python/baby/__init__.py
index 77f8b476878a21c7c3e1829542ed6f1e6bec6f49..002f91abbba22e8591ac5f7041682d87eddebe09 100644
--- a/python/baby/__init__.py
+++ b/python/baby/__init__.py
@@ -1,12 +1,14 @@
 # If you publish results that make use of this software or the Birth Annotator
 # for Budding Yeast algorithm, please cite:
-# Julian M J Pietsch, Alán Muñoz, Diane Adjavon, Ivan B N Clark, Peter S
-# Swain, 2021, Birth Annotator for Budding Yeast (in preparation).
+# Pietsch, J.M.J., Muñoz, A.F., Adjavon, D.-Y.A., Farquhar, I., Clark, I.B.N.,
+# and Swain, P.S. (2023). Determining growth rates from bright-field images of
+# budding cells through identifying overlaps. eLife. 12:e79812.
+# https://doi.org/10.7554/eLife.79812
 # 
 # 
 # The MIT License (MIT)
 # 
-# Copyright (c) Julian Pietsch, Alán Muñoz and Diane Adjavon 2021
+# Copyright (c) Julian Pietsch, Alán Muñoz and Diane Adjavon 2023
 # 
 # Permission is hereby granted, free of charge, to any person obtaining a copy
 # of this software and associated documentation files (the "Software"), to
@@ -26,19 +28,7 @@
 # FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS
 # IN THE SOFTWARE.
 """Mostly used to access models and model-sets"""
-from pathlib import Path
-import json
 from .brain import BabyBrain
 from .crawler import BabyCrawler
 
-BASE_DIR = Path(__file__).parent
-MODEL_DIR = BASE_DIR / 'models'
-
-def modelsets():
-    with open(BASE_DIR / 'modelsets.json', 'r') as fd:
-        msets = json.load(fd)
-    return msets
-
-# Todo: should probably be removed, but used in Tests
-def model_path():
-    return MODEL_DIR
+__version__ = 'v0.30.0'
diff --git a/python/baby/augmentation.py b/python/baby/augmentation.py
index 38ee7a195a7a36e4b64bb8614ce663cb67ba6aff..23b959f8d621255172e27666545f9a5650eceac3 100644
--- a/python/baby/augmentation.py
+++ b/python/baby/augmentation.py
@@ -1,12 +1,14 @@
 # If you publish results that make use of this software or the Birth Annotator
 # for Budding Yeast algorithm, please cite:
-# Julian M J Pietsch, Alán Muñoz, Diane Adjavon, Ivan B N Clark, Peter S
-# Swain, 2021, Birth Annotator for Budding Yeast (in preparation).
+# Pietsch, J.M.J., Muñoz, A.F., Adjavon, D.-Y.A., Farquhar, I., Clark, I.B.N.,
+# and Swain, P.S. (2023). Determining growth rates from bright-field images of
+# budding cells through identifying overlaps. eLife. 12:e79812.
+# https://doi.org/10.7554/eLife.79812
 # 
 # 
 # The MIT License (MIT)
 # 
-# Copyright (c) Julian Pietsch, Alán Muñoz and Diane Adjavon 2021
+# Copyright (c) Julian Pietsch, Alán Muñoz and Diane Adjavon 2023
 # 
 # Permission is hereby granted, free of charge, to any person obtaining a copy
 # of this software and associated documentation files (the "Software"), to
@@ -30,7 +32,7 @@ A set of augmentation functions for on-the-fly preprocessing of images before
 input into the Neural Network. All augmentation functions are meant to work
 with unstructured arrays (in particular, elastic_deform, turn, and the shifts
 probably will return errors if used on structured arrays. They all take exactly
-two arrays as input, and will perform identitcal transformations on both arrays.
+two arrays as input, and will perform identical transformations on both arrays.
 """
 from __future__ import absolute_import, division, print_function
 import json
@@ -38,21 +40,28 @@ import numpy as np
 from itertools import permutations, repeat
 from scipy.ndimage import map_coordinates, gaussian_filter, shift
 from scipy.ndimage.morphology import binary_fill_holes
+from scipy import interpolate
 from skimage import transform
 from skimage.feature import canny
 from skimage.filters import gaussian
 from skimage.draw import rectangle_perimeter
+import elasticdeform
+# import cv2
 
-from .preprocessing import segoutline_flattening
-from .errors import BadParam
+from .preprocessing import segoutline_flattening, connect_pixel_gaps
+from .errors import BadParam, BadFile
 
-AUGMENTATION_ORDER = ('substacks', 'rotate', 'vshift', 'hshift', 'downscale',
-                      'crop', 'vflip', 'hflip', 'movestacks', 'noise')
+AUGMENTATION_ORDER = ('substacks', 'vshift', 'hshift', 'rough_crop',
+                      'downscale', 'shadows', 'blur', 'elastic_deform',
+                      'rotate', 'crop', 'vflip', 'hflip',
+                      'movestacks', 'noise', 'pad', 'gamma',
+                      'hist_deform', 'hist_scale')
 
 
 class Augmenter(object):
 
-    def __init__(self, xy_out=80, probs={}, p_noop=0.05, substacks=None):
+    def __init__(self, xy_out=80, probs={}, p_noop=0.05, substacks=None,
+            only_basic_augs=True):
         """
         Random data augmentation of img and lbl.
 
@@ -61,35 +70,62 @@ class Augmenter(object):
         xy_out : int or pair of ints as a tuple
             The intended width and height of the final augmented image.
 
-        probs : dict of floats in [0, 1]
+        probs : dict of floats in [0, 1] or None
             Specify non-default probabilities for the named augmentations.
-            Augmentations with zero probability are omitted.
+            Augmentations with zero probability are omitted.  To give all
+            augmentations the same default probability, specify an empty
+            dict (and set `only_basic_augs=False` to include every one).
 
         p_noop : float in [0, 1]
             Adjusts the probability for no augmentation operation in the
             default case. If set to 1, then all operations will be omitted
             by default (i.e., unless a non-zero probability is specified
             for that operation in probs).
+
+        only_basic_augs: bool
+            If True, then blur, elastic_deform and img intensity augmentations
+            are omitted (i.e., "shadows", "blur", "elastic_deform", "gamma",
+            "hist_deform", "hist_scale"). This was the default for the BABY
+            paper. Specify False to include all augmentations. NB: this has
+            lower priority than the `probs` arg, e.g., if `shadows=0.2` is
+            specified in probs, then the shadows aug will be active.
         """
 
         if type(xy_out) is int:
             self.xy_out = (xy_out, xy_out)
         elif len(xy_out) == 2:
-            self.xy_out = xy_out
+            self.xy_out = tuple(xy_out)
         else:
             raise Exception('"xy_out" must be an int or pair of ints')
 
         self.xy_in = None
+        self.pad_value = None
+        self.aug_log = []
+        self._vshift = None
+        self._hshift = None
+
+        self.preserve_bool_lbl = True
+
+        # Interpolation order for label operations
+        # This should be 0 for bitmask labels
+        self.lbl_order = 0
+
+        if only_basic_augs:
+            custom_probs = probs
+            probs = dict(shadows=0, elastic_deform=0, gamma=0, hist_deform=0,
+                    hist_scale=0, blur=0)
+            probs.update(**custom_probs)
 
+        # Guaranteed augs ('substacks', 'rough_crop', 'pad', 'crop') get p = 1
+        guaranteed_augs = ('substacks', 'rough_crop', 'pad', 'crop')
         self.aug_order = [
             a for a in AUGMENTATION_ORDER
-            if probs.get(a, 1) > 0 or a == 'crop'
+            if probs.get(a, 1) > 0 or a in guaranteed_augs
         ]
-
-        # Treat 'crop' and 'substacks' specially to have p = 1
-        guaranteed_augs = ('crop', 'substacks')
-        n_augs = len(self.aug_order) - len(guaranteed_augs)
-        p_default = (1. - p_noop)**n_augs / n_augs
+        n_augs = len([1 for a in self.aug_order
+                      if probs.get(a, 0) < 1
+                      and a not in guaranteed_augs])
+        p_default = 1. - p_noop**(1/n_augs)
         self.probs = np.array([
             1 if a in guaranteed_augs else probs.get(a, p_default)
             for a in self.aug_order
@@ -126,22 +162,27 @@ class Augmenter(object):
         assert img.shape[:2] == lbl.shape[:2], \
             'xy dimensions of img and lbl are mismatched'
 
-        lbl_is_bool = lbl.dtype == 'bool'
+        lbl_is_bool = lbl.dtype == 'bool' and self.preserve_bool_lbl
 
         self.xy_in = img.shape[:2]
+        self.pad_value = None
+        self.aug_log = []
 
         for a, p in zip(self.aug_order, self.probs):
             if p == 0.0:
                 continue
             elif p == 1.0:
+                self.aug_log.append(a)
                 img, lbl = getattr(self, a)(img, lbl)
             else:
                 if np.random.uniform() < p:
+                    self.aug_log.append(a)
                     img, lbl = getattr(self, a)(img, lbl)
             if lbl_is_bool and lbl.dtype != 'bool':
                 lbl = lbl > 0.5  # ensure label stays boolean
 
         self.xy_in = None
+        self.pad_value = None
 
         # Ensure that x and y dimensions match the intended output size
         assert img.shape[:2] == self.xy_out and lbl.shape[:2] == self.xy_out, \
@@ -200,6 +241,40 @@ class Augmenter(object):
             img = img[:, :, ss]
             self.refslice = int(np.median(np.where(ss))) + 1
 
+        self.pad_value = np.median(img)
+        return img, lbl
+
+    def pad(self, img, lbl, update_size=True):
+        """Ensure image/label size are large enough
+
+        A guaranteed augmentation that pads img and lbl as required so that
+        their xy dimensions are at least as large as the requested `xy_out`.
+        """
+
+        # We store a pad value to ensure consistent padding if this routine is
+        # called more than once in the augmentation chain
+        if self.pad_value is None:
+            self.pad_value = np.median(img)
+
+        # Ensure that img and lbl xy dimensions are at least as large as the
+        # requested output
+        Xin, Yin = img.shape[:2]
+        Xout, Yout = self.xy_out
+        Xpad = Xout - Xin if Xin < Xout else 0
+        Ypad = Yout - Yin if Yin < Yout else 0
+        if Xpad > 0 or Ypad > 0:
+            XpadL = Xpad // 2
+            Xpad = (XpadL, Xpad - XpadL)
+            YpadL = Ypad // 2
+            Ypad = (YpadL, Ypad - YpadL)
+            img = np.pad(img, (Xpad, Ypad, (0, 0)),
+                mode='constant', constant_values=self.pad_value)
+            lbl = np.pad(lbl, (Xpad, Ypad, (0, 0)), mode='constant')
+
+        # Update input size by default
+        if update_size:
+            self.xy_in = img.shape[:2]
+
         return img, lbl
 
     def rotate(self, img, lbl):
@@ -236,9 +311,8 @@ class Augmenter(object):
                 inshape = self.xy_in[0]
             maxpix = np.max([0, (inshape - self.xy_out[0]) // 2])
 
-        pix = np.random.choice(np.arange(-maxpix, maxpix + 1, dtype='int'))
-        return (shift(img, [pix, 0, 0], mode='reflect', order=0),
-                shift(lbl, [pix, 0, 0], mode='reflect', order=0))
+        self._vshift = np.random.choice(np.arange(-maxpix, maxpix + 1, dtype='int'))
+        return img, lbl
 
     def hshift(self, img, lbl, maxpix=None):
         """Shift along width, max of 10px by default
@@ -257,9 +331,22 @@ class Augmenter(object):
                 inshape = self.xy_in[1]
             maxpix = np.max([0, (inshape - self.xy_out[1]) // 2])
 
-        pix = np.random.choice(np.arange(-maxpix, maxpix + 1, dtype='int'))
-        return (shift(img, [0, pix, 0], mode='reflect', order=0),
-                shift(lbl, [0, pix, 0], mode='reflect', order=0))
+        self._hshift = np.random.choice(np.arange(-maxpix, maxpix + 1, dtype='int'))
+        return img, lbl
+
+    def rough_crop(self, img, lbl, xysize=None, nonempty_slices=None):
+        if xysize is None:
+            xysize = np.array(self.xy_out, dtype=float) * 1.7
+            xysize = np.ceil(xysize).astype(int)
+        if self.pad_value is None:
+            self.pad_value = np.median(img)
+        vshift = 0 if self._vshift is None else self._vshift
+        hshift = 0 if self._hshift is None else self._hshift
+        img = _shift_and_crop(img, vshift, hshift, xysize, self.pad_value)
+        lbl = _shift_and_crop(lbl, vshift, hshift, xysize, 0,
+                              nonempty_slices=nonempty_slices)
+        self._vshift, self._hshift = None, None
+        return img, lbl
 
     def crop(self, img, lbl, xysize=None):
         if xysize is None:
@@ -327,48 +414,155 @@ class Augmenter(object):
                                 size=img.shape)
         return img, lbl
 
-    def elastic_deform(self, img, lbl, params={}):
-        """Slight deformation
+    def elastic_deform(self, img, lbl):
+        """Random deformation based on elasticdeform package from PyPI
 
-        Elastic deformation of images as described in
-        Simard, Steinkraus and Platt, "Best Practices for
-        Convolutional Neural Networks applied to Visual Document Analysis", in
-        Proc. of the International Conference on Document Analysis and
-        Recognition, 2003.
-        Adapted from:
-        https://gist.github.com/chsasank/4d8f68caf01f041a6453e67fb30f8f5a
+        We aim for relatively weak deformation, so assign grid points
+        approximately every 32 pixels and keep the sigma low.
 
-        Example image:
+        The interpolation order for the lbl is set from the `lbl_order`
+        property of this class. It should be 0 if the lbl is a bitmask and no
+        smoothing has been applied, but can be the default 3 if smoothing of
+        a bitmask image has been applied.
+        """
+
+        # Want only light deformation, so approximately one grid point every
+        # 32 pixels:
+        npoints = np.maximum(np.round(np.array(lbl.shape[:2])/32), 1)
+        npoints = npoints.astype('int').tolist()
+        return elasticdeform.deform_random_grid(
+            [img, lbl], sigma=2, points=npoints,
+            order=[3, self.lbl_order], mode='reflect', axis=[(0, 1), (0, 1)])
+
+    def shadows(self, img, lbl):
+        """Location-dependent intensity deformation
 
-        .. image:: ../report/figures/augmentations/elastic_deform.*
+        Introduce local changes in intensity by fitting a bivariate spline to
+        a grid of random intensities. 
         """
-        alpha = 0.75 * img.shape[1]
-        sigma = 0.08 * img.shape[1]
-        x_y = _elastic_deform(np.dstack([img, lbl]), alpha=alpha, sigma=sigma)
-        return np.split(x_y, [img.shape[2]], axis=2)
 
-    def identity(self, img, lbl):
-        """Do nothing
+        # Want only slow changes in intensity, so approximately one grid point
+        # every 64 pixels:
+        npoints = np.maximum(np.round(np.array(img.shape[:2])/64), 1)
+        npoints = npoints.astype('int').tolist()
+
+        # Set up grid with random intensities and interpolate
+        x = np.linspace(0, img.shape[0] - 1, np.maximum(npoints[0], 2))
+        y = np.linspace(0, img.shape[1] - 1, np.maximum(npoints[1], 2))
+        I = np.random.normal(size=(x.size, y.size))
+        mapping = interpolate.RectBivariateSpline(
+                x, y, I, kx=np.minimum(x.size, 4) - 1,
+                ky=np.minimum(y.size, 4) - 1, s=(x.size * y.size) * 0.25)
+        # Scaled so that 90% of distribution sits within 0.5- to 2-fold
+        I_scaler = np.exp(0.4 * mapping(np.arange(img.shape[0]),
+                np.arange(img.shape[1]), grid=True))
+        # Truncate to a two-fold change in intensity
+        I_scaler = np.minimum(np.maximum(I_scaler, 0.5), 2)
+        # Ensure that the mean intensity of the entire image stays
+        # approximately constant
+        I_scaler /= np.median(I_scaler)
+        # Apply scaling consistently across z-sections
+        I_scaler = I_scaler[..., None]
+        img_min = img.min()
+        img_aug = img_min + (img - img_min) * I_scaler
+        return img_aug, lbl
+
+    def gamma(self, img, lbl):
+        """Adjust image gamma
+
+        Picks a random gamma transform to apply to the image
+        """
 
-        Example image:
+        # Pick a random gamma
+        g = np.random.uniform(0.3, 2.0)
+
+        # Attempt to robustly maintain the intensity range
+        iQ = np.quantile(img, (0, 0.02, 0.5, 0.98, 1))
+        imin, iptp = iQ[0], iQ[4] - iQ[0]
+        iQn = (iQ - imin) / iptp
+        # aim for consistent difference between 2/98 quantiles
+        sc = np.diff(iQ[[1, 3]]) / np.diff(iQn[[1, 3]] ** g)
+        # aim for consistent median
+        off = iQ[2] - iQn[2] ** g
+
+        return off + sc * ((img - imin) / iptp) ** g, lbl
+
+    def hist_deform(self, img, lbl):
+        """Random deformation of intensity histogram
+
+        Maps equally-spaced intensity ranges to new ranges sampled from the
+        Dirichlet distribution (of equal average lengths), but uses
+        interpolation to smooth the mapping.
 
-        .. image:: ../report/figures/augmentations/identity.*
+        Inspired by the `histogram_voodoo` implementation provided by Daniel
+        Eaton from the Paulsson lab for the Delta segmentation package.
         """
-        return img, lbl
+        npoints = 5  # includes end points
+        control_points = np.linspace(0, 1, num=npoints)
+        # split into `npoints - 1` segments with roughly equal lengths
+        # i.e., want relatively high alpha parameter:
+        mapped_points = np.random.dirichlet((2 * npoints,) * (npoints - 1))
+        mapped_points = np.cumsum(np.concatenate([[0], mapped_points]))
+        mapping = interpolate.PchipInterpolator(control_points, mapped_points)
+
+        # Attempt to robustly maintain the intensity range
+        iQ = np.quantile(img, (0, 0.02, 0.5, 0.98, 1))
+        imin, iptp = iQ[0], iQ[4] - iQ[0]
+        iQn = (iQ - imin) / iptp
+        # aim for consistent difference between 2/98 quantiles
+        sc = np.diff(iQ[[1, 3]]) / np.diff(mapping(iQn[[1, 3]]))
+        # aim for consistent median
+        off = iQ[2] - mapping(iQn[2])
+
+        return off + sc * mapping((img - imin) / iptp), lbl
+
+    def hist_scale(self, img, lbl):
+        """Randomly offset and scale the intensity histogram
+        """
+
+        # Sample an offset within 25% of the range of image intensities
+        imgrange = np.diff(np.quantile(img, [0.02, 0.98]))
+        off = np.random.uniform(-0.25*imgrange, 0.25*imgrange)
+
+        # Fairly sample a scaling factor between roughly 0.5 and 2
+        sc = np.random.lognormal(sigma=0.4)
+        # Truncate to a minimum of 0.5 and maximum of 2
+        sc = np.minimum(np.maximum(sc, 0.5), 2)
+
+        return (img + off) * sc, lbl
+
+    def blur(self, img, lbl):
+        """Apply gaussian filter to img to simulate loss of focus
+
+        Draw the sigma from a uniform distribution between 1 and 4.
+        """
+        return gaussian(img, np.random.uniform(1, 4)), lbl
 
 
 class SmoothingSigmaModel(object):
     """Model for picking a smoothing sigma for gaussian filtering
     
-    a, b and c should be obtained by fitting data to the following model:
+    There are two model types, 'exponential' and 'constant'
+
+    For the exponential model, `a`, `b` and `c` should be obtained by fitting
+    data to the following model:
         nedge = c + a * exp(b * sigma)
+
+    For the constant model, set `a` such that:
+        sigma = a
+    i.e., sigma is independent of nedge
     """
 
-    def __init__(self, a=None, b=None, c=None):
+    def __init__(self, a=None, b=None, c=None, model='exponential'):
         self._a = a
         self._b = b
         self._c = c
-        self._formula = 'sigma = log((nedge - c) / a) / b'
+        if model == 'exponential':
+            self._formula = 'sigma = log((nedge - c) / a) / b'
+        elif model == 'constant':
+            self._formula = 'sigma = a'
+        else:
+            raise BadParam('Unrecognised model type')
 
     def save(self, filename):
         with open(filename, 'wt') as f:
@@ -383,24 +577,41 @@ class SmoothingSigmaModel(object):
     def load(self, filename):
         with open(filename, 'rt') as f:
             model = json.load(f)
-        if model.get('formula') == 'sigma = m*log(nedge) + c':
+        formula = model.get('formula')
+        if formula == 'sigma = m*log(nedge) + c':
             m, c = (model.get(k, 0) for k in ('m', 'c'))
             self._a = np.exp(-c * m)
             self._b = 1 / m
             self._c = 0
-        elif self._formula != model.get('formula'):
-            raise BadFile('Model formula does not match SmoothingSigmaModel')
-        else:
+            self._formula = 'sigma = log((nedge - c) / a) / b'
+        elif formula == 'sigma = a':
+            self._a = model.get('a', 0)
+            self._b, self._c = None, None
+            self._formula = formula
+        elif formula == 'sigma = log((nedge - c) / a) / b':
             self._a, self._b, self._c = (
                 model.get(k, 0) for k in ('a', 'b', 'c'))
+            self._formula = formula
+        else:
+            raise BadFile('Model with unrecognised formula encountered')
 
     def __repr__(self):
-        return 'SmoothingSigmaModel: {}; a = {:.2f}, b = {:.2f}, c = {:.2f}'.format(
-            self._formula, self._a, self._b, self._c)
-
-    def __call__(self, s):
-        return np.log(np.clip(
-            (np.sum(s) - self._c) / self._a, 1, None)) / self._b
+        keys = ('a', 'b', 'c')
+        coefs = (getattr(self, '_' + k) for k in keys)
+        coefs = [
+            f'{k} = {v:.2f}' for k, v in zip(keys, coefs) if v is not None]
+        return 'SmoothingSigmaModel: {}; {}'.format(
+            self._formula, ', '.join(coefs))
+
+    def __call__(self, s, scaling=1.):
+        if self._formula == 'sigma = a':
+            return self._a
+        elif self._formula == 'sigma = log((nedge - c) / a) / b':
+            nedge = np.sum(s) * scaling
+            return np.log(np.clip(
+                (nedge - self._c) / self._a, 1, None)) / self._b
+        else:
+            raise BadParam('Unrecognised formula encountered')
 
 
 class SmoothedLabelAugmenter(Augmenter):
@@ -408,19 +619,24 @@ class SmoothedLabelAugmenter(Augmenter):
     def __init__(self,
                  sigmafunc,
                  targetgenfunc=segoutline_flattening,
+                 canny_padding=2,
                  **kwargs):
         super(SmoothedLabelAugmenter, self).__init__(**kwargs)
         self.sigmafunc = sigmafunc
         self.targetgenfunc = targetgenfunc
+        self.canny_padding = canny_padding
+        # Since label smoothing is applied, interpolation order for label
+        # transformations can be increased
+        self.lbl_order = 3
 
     def __call__(self, img, lbl_info):
         """This augmenter needs to be used in combination with a label
         preprocessing function that returns both images and info.
         """
 
-        lbl, info = lbl_info
+        lbl, self.current_info = lbl_info
 
-        # Smooth filled label for to avoid anti-aliasing artefacts
+        # Smooth filled label to avoid anti-aliasing artefacts
         lbl = lbl.astype('float')
         for l in range(lbl.shape[2]):
             o = lbl[..., l]  # slice a single outline
@@ -428,13 +644,44 @@ class SmoothedLabelAugmenter(Augmenter):
                 lbl[..., l] = gaussian(binary_fill_holes(o),
                                        self.sigmafunc(o))
 
-        img, lbl = super(SmoothedLabelAugmenter, self).__call__(img, lbl)
+        return super(SmoothedLabelAugmenter, self).__call__(img, lbl)
+
+    def rotate(self, img, lbl):
+        """Random rotation
+
+        Example image:
+
+        .. image:: ../report/figures/augmentations/turn.*
+        """
+        angle = np.random.choice(360)
+        return (transform.rotate(img,
+                                 angle=angle,
+                                 order=3,
+                                 mode='reflect',
+                                 resize=True),
+                transform.rotate(lbl,
+                                 angle=angle,
+                                 order=3,
+                                 mode='reflect',
+                                 resize=True))
+
+    def crop(self, img, lbl, xysize=None):
+        # Overload the crop function to restore filled cell masks and generate
+        # the flattened targets. Performing these operations before applying
+        # cropping helps to avoid boundary effects and misclassification of
+        # size group.
+        #
+        if xysize is None:
+            xysize = np.array(self.xy_out)
 
-        # NB: to limit open shapes, the crop operation has been overloaded to
-        # find edges before cropping
+        # Find edges from blurred images and fill
+        for s in range(lbl.shape[2]):
+            lbl[:, :, s] = _filled_canny(lbl[:, :, s], self.canny_padding)
 
-        # Finally generate flattened targets from segmentation outlines
+        info = self.current_info
         if 'focusStack' in info:
+            # Modify info dict that gets passed to flattener in order to
+            # specify focus relative to reference slice formed from substack
             info = info.copy()
             cellFocus = info['focusStack']
             if type(cellFocus) != list:
@@ -443,18 +690,9 @@ class SmoothedLabelAugmenter(Augmenter):
             info['focusStack'] = [f - self.refslice for f in cellFocus]
             # print('new focus = {}'.format(', '.join([str(f) for f in info['focusStack']])))
 
+        # Generate flattened targets from segmentation outlines
         lbl = self.targetgenfunc(lbl, info)
 
-        return img, lbl
-
-    def crop(self, img, lbl, xysize=None):
-        if xysize is None:
-            xysize = np.array(self.xy_out)
-
-        # Find edges and fill cells before cropping
-        for s in range(lbl.shape[2]):
-            lbl[:, :, s] = _filled_canny(lbl[:, :, s])
-
         return _apply_crop(img, xysize), _apply_crop(lbl, xysize)
 
     def downscale(self, img, lbl, maxpix=None):
@@ -520,17 +758,16 @@ class ScalingAugmenter(SmoothedLabelAugmenter):
         scale_index = self.aug_order.index('downscale')
         self.scale_prob = self.probs[scale_index]
         self.probs[scale_index] = 1
+        self.preserve_bool_lbl = False
 
     def __call__(self, img, lbl_info):
         lbl, info = lbl_info
+        self.current_info = info
         self._input_pix_size = info.get('pixel_size', self.target_pixel_size)
         self._scaling = self._input_pix_size / self.target_pixel_size
         self._outshape = np.round(np.array(self.xy_out) / self._scaling)
-        img, lbl = super(ScalingAugmenter, self).__call__(img, lbl_info)
-       # self._input_pix_size = None
-       # iself.scaling = None
-       # self._outshape = None
-        return img, lbl
+
+        return super(SmoothedLabelAugmenter, self).__call__(img, lbl)
 
     def vshift(self, img, lbl, maxpix=None):
         if maxpix is None:
@@ -550,99 +787,144 @@ class ScalingAugmenter(SmoothedLabelAugmenter):
             maxpix = np.max([0, (inshape - self._outshape[1]) // 2])
         return super(ScalingAugmenter, self).hshift(img, lbl, maxpix=maxpix)
 
+    def rough_crop(self, img, lbl):
+        xysize = np.array(self.xy_out, dtype=float) * 1.7 / self._scaling
+        xysize = np.ceil(xysize).astype(int)
+
+        nonempty = np.ones(lbl.shape[2], dtype=bool)
+        n_total = nonempty.size if lbl.any() else 0
+        img, lbl = super(ScalingAugmenter, self).rough_crop(
+            img, lbl, xysize=xysize, nonempty_slices=nonempty)
+
+        # Remove empty sections of lbl and sync with info
+        if lbl.shape[2] == 0:
+            lbl = np.zeros(lbl.shape[:2] + (1,), dtype=bool)
+        info = self.current_info.copy()
+        for k in ['cellLabels', 'focusStack', 'buds']:
+            if k in info and info[k] is not None:
+                l = info[k]
+                if type(l) != list:
+                    l = [l]
+                if len(l) != n_total:
+                    raise Exception(f'"{k}" info does not match label image shape ({str(info)})')
+                info[k] = [x for x, m in zip(l, nonempty) if m]
+        self.current_info = info
+
+        # Smooth filled label to avoid anti-aliasing artefacts
+        lbl = lbl.astype('float')
+        for l in range(lbl.shape[2]):
+            o = lbl[..., l]  # slice a single outline
+            sigma = max(self.sigmafunc(o), 1)
+            if o.sum() > 0:
+                lbl[..., l] = gaussian(_bordered_fill_holes(o), sigma)
+
+        return img, lbl
+
     def downscale(self, img, lbl, maxpix=None):
         # Scale image and label to target pixel size
         inshape = img.shape[:2]
         scaling = self._scaling
 
         # Apply random scaling according to probability for this op
+        # also need to correct the aug_log
+        last_aug = self.aug_log.pop()
+        assert last_aug == 'downscale'
         p = self.scale_prob
         if p == 1.0 or (p > 0 and np.random.uniform() < p):
+            self.aug_log.append('downscale')
             scaling += scaling * self.scale_frac * np.random.uniform(-1, 1)
 
-        outshape = np.round(np.array(img.shape[:2]) * scaling)
-        outshape = np.maximum(outshape, self.xy_out)
-
-        return (transform.resize(img, outshape),
-                transform.resize(lbl, outshape, anti_aliasing=False))
+        outshape = np.round(np.array(img.shape[:2]) * scaling).astype(int)
+        # OpenCV is much faster but much less accurate...
+        # img = cv2.resize(img, outshape, interpolation=cv2.INTER_LINEAR)
+        # if img.ndim == 2:
+        #     img = img[:, :, None]
+        # lbl = cv2.resize(lbl, outshape, interpolation=cv2.INTER_LINEAR)
+        # if lbl.ndim == 2:
+        #     lbl = lbl[:, :, None]
+        img = transform.resize(img, outshape, mode='edge', anti_aliasing=False)
+        lbl = transform.resize(lbl, outshape, order=3, mode='edge', anti_aliasing=False)
+        return img, lbl
 
 
 # =============== UTILITY FUNCTIONS ====================== #
 
 
+def _shift_and_crop(stack, vshift, hshift, xysize, pad_value,
+                    nonempty_slices=None):
+    r_in, c_in = stack.shape[:2]
+    r_out, c_out = xysize
+    rb_in = (r_in - r_out) // 2 - vshift
+    cb_in = (c_in - c_out) // 2 - hshift
+    re_in = rb_in + r_out
+    ce_in = cb_in + c_out
+    rb_out = max(0, -rb_in)
+    re_out = min(r_out, r_out - re_in + r_in)
+    cb_out = max(0, -cb_in)
+    ce_out = min(c_out, c_out - ce_in + c_in)
+    rb_in = max(0, rb_in)
+    re_in = min(r_in, re_in)
+    cb_in = max(0, cb_in)
+    ce_in = min(c_in, ce_in)
+    slicedims = stack.shape[2:]
+    if nonempty_slices is not None:
+        nonempty_slices[()] = stack[rb_in:re_in, cb_in:ce_in].any(axis=(0,1))
+        slicedims = (int(nonempty_slices.sum()),) + slicedims[1:]
+    out = np.full_like(stack, pad_value, shape=tuple(xysize) + slicedims)
+    if nonempty_slices is None:
+        out[rb_out:re_out, cb_out:ce_out] = stack[rb_in:re_in, cb_in:ce_in]
+    else:
+        out[rb_out:re_out, cb_out:ce_out, :] = stack[
+            rb_in:re_in, cb_in:ce_in, nonempty_slices]
+    return out
+
 def _apply_crop(stack, xysize):
     cropy, cropx = xysize
     starty, startx = stack.shape[:2]
-    startx = (startx - cropx) // 2
-    starty = (starty - cropy) // 2
-    return stack[starty:(starty + cropy), startx:(startx + cropx), ...]
+    if startx > cropx:
+        startx = (startx - cropx) // 2
+        stack = stack[:, startx:(startx + cropx), ...]
+    if starty > cropy:
+        starty = (starty - cropy) // 2
+        stack = stack[starty:(starty + cropy), :, ...]
+    return stack
 
 
-def _elastic_deform(image, alpha, sigma, random_state=None):
-    """
-    Elastic deformation of images as described in [Simard2003]_.
-    [Simard2003] Simard, Steinkraus and Platt, "Best Practices for
-    Convolutional Neural Networks applied to Visual Document Analysis", in
-    Proc. of the International Conference on Document Analysis and
-    Recognition, 2003.
-    Adapted from:
-    https://gist.github.com/chsasank/4d8f68caf01f041a6453e67fb30f8f5a
+def _bordered_fill_holes(edgemask, pad_width=1, prepadded=False):
+    """Fill holes in a mask treating the boundary as an edge
+
+    NB: this function assumes that no cells intersect with opposing borders.
     """
-    if random_state is None:
-        random_state = np.random.RandomState(None)
-
-    shape = image[:, :, 0].shape
-    dx = gaussian_filter(
-        (random_state.rand(*shape) * 2 - 1), sigma, mode="constant",
-        cval=0) * alpha
-    dy = gaussian_filter(
-        (random_state.rand(*shape) * 2 - 1), sigma, mode="constant",
-        cval=0) * alpha
-
-    _x, _y = np.meshgrid(np.arange(shape[0]),
-                         np.arange(shape[1]),
-                         indexing='ij')
-    indices = np.reshape(_x + dx, (-1, 1)), np.reshape(_y + dy, (-1, 1))
-    if len(image.shape) == 3:
-        result = np.empty_like(image)
-        for d in range(image.shape[2]):
-            # iterate over depth
-            cval = np.median(image[:, :, d])
-            result[:, :, d] = map_coordinates(image[:, :, d],
-                                              indices,
-                                              order=1,
-                                              cval=cval).reshape(shape)
-        result
-    else:
-        cval = np.median(image)
-        result = map_coordinates(image, indices, order=1,
-                                 cval=cval).reshape(shape)
-    return result
 
+    bp = pad_width
+    if not prepadded:
+        edgemask = np.pad(edgemask, bp)
+
+    mask = np.zeros(edgemask.shape, dtype='bool')
+    mask[bp:-bp, bp:-bp] = edgemask[bp:-bp, bp:-bp]
 
-def _filled_canny(segblur, bp=2):
+    # The following assumes that the cell does not intersect with opposing
+    # borders, filling bordering with two U-shaped border edges:
+    mask[:bp, :] = True
+    mask[-bp:, :] = True
+    mask[:, :bp] = True
+    mask = binary_fill_holes(mask)
+
+    mask[:, :bp] = False
+    mask[:, -bp:] = True
+    mask = binary_fill_holes(mask)
+
+    return mask[bp:-bp, bp:-bp]
+
+
+def _filled_canny(segblur, pad_width=2):
     """Use canny to find edge and fill object
 
     Handles intersections with border by assuming that the object cannot
     intersect all borders at once.
 
     segblur:  segmentation image that has been gaussian blurred
-    bp:       border padding
+    pad_width:       border padding
     """
-
-    se = canny(np.pad(segblur, bp, 'edge'), sigma=0)
-    sf = np.zeros(se.shape, dtype='bool')
-    sf[bp:-bp, bp:-bp] = se[bp:-bp, bp:-bp]
-
-    # The following assumes that the cell does not intersect with opposing
-    # borders, filling bordering with two U-shaped border edges:
-    sf[:bp, :] = True
-    sf[-bp:, :] = True
-    sf[:, :bp] = True
-    sf = binary_fill_holes(sf)
-
-    sf[:, :bp] = False
-    sf[:, -bp:] = True
-    sf = binary_fill_holes(sf)
-
-    return sf[bp:-bp, bp:-bp]
+    se = canny(np.pad(segblur, pad_width, 'edge'), sigma=0)
+    return _bordered_fill_holes(se, pad_width=pad_width, prepadded=True)
diff --git a/python/baby/brain.py b/python/baby/brain.py
index 50e9f4023e6278b8329ccec938d99e530182bf21..0f9cb0b529b0342cd7f20053e30792b786b8d413 100644
--- a/python/baby/brain.py
+++ b/python/baby/brain.py
@@ -1,23 +1,25 @@
 # If you publish results that make use of this software or the Birth Annotator
 # for Budding Yeast algorithm, please cite:
-# Julian M J Pietsch, Alán Muñoz, Diane Adjavon, Ivan B N Clark, Peter S
-# Swain, 2021, Birth Annotator for Budding Yeast (in preparation).
-#
-#
+# Pietsch, J.M.J., Muñoz, A.F., Adjavon, D.-Y.A., Farquhar, I., Clark, I.B.N.,
+# and Swain, P.S. (2023). Determining growth rates from bright-field images of
+# budding cells through identifying overlaps. eLife. 12:e79812.
+# https://doi.org/10.7554/eLife.79812
+# 
+# 
 # The MIT License (MIT)
-#
-# Copyright (c) Julian Pietsch, Alán Muñoz and Diane Adjavon 2021
-#
+# 
+# Copyright (c) Julian Pietsch, Alán Muñoz and Diane Adjavon 2023
+# 
 # Permission is hereby granted, free of charge, to any person obtaining a copy
 # of this software and associated documentation files (the "Software"), to
 # deal in the Software without restriction, including without limitation the
 # rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
 # sell copies of the Software, and to permit persons to whom the Software is
 # furnished to do so, subject to the following conditions:
-#
+# 
 # The above copyright notice and this permission notice shall be included in
 # all copies or substantial portions of the Software.
-#
+# 
 # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
 # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
 # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
@@ -27,7 +29,8 @@
 # IN THE SOFTWARE.
 from __future__ import (absolute_import, division, print_function,
                         unicode_literals)
-from os.path import dirname, join, isfile, isdir
+from pathlib import Path
+import json
 from itertools import repeat, chain
 
 import numpy as np
@@ -37,30 +40,42 @@ import tensorflow as tf
 from tensorflow.keras import models, layers
 from tensorflow.keras import backend as K
 
-from .losses import bce_dice_loss, dice_loss, dice_coeff
-from .preprocessing import robust_norm, SegmentationFlattening
-from .morph_thresh_seg import MorphSegGrouped
-from .tracker.core import MasterTracker
-from .utils import batch_iterator, split_batch_pred
-from .brain_util import _segment, _track, _segment_and_track_parallel
+try:
+    from tensorflow.keras.optimizers import AdamW
+except ImportError:
+    try:
+        from tensorflow.keras.optimizers.experimental import AdamW
+    except ImportError:
+        try:
+            from tensorflow_addons.optimizers import AdamW
+        except ImportError:
+            raise ImportError('You need to pip install tensorflow-addons with this version of tensorflow')
+
+from skimage import transform
 
-models_path = join(dirname(__file__), 'models')
+from . import modelsets
+from .losses import bce_dice_loss, dice_loss, dice_coeff
+from .preprocessing import SegmentationFlattening
+from .morph_thresh_seg import MorphSegGrouped, SegmentationParameters
+from .tracker.core import MasterTracker, MMTracker
+from .utils import batch_iterator, split_batch_pred, as_python_object
+from .brain_util import (_segment, _track, _segment_and_track_parallel,
+                         _apply_preprocessing, _rescale_output,
+                         _tile_generator, _stitch_tiles, _batch_generator)
+from .errors import BadFile, BadParam
 
 tf_version = [int(v) for v in tf.__version__.split('.')]
 
-# Default to optimal segmentation parameters found in Jupyter notebook
-# segmentation-190906.ipynb:
-default_params = {
-    'interior_threshold': (0.7, 0.5, 0.5),
-    'nclosing': (1, 0, 0),
-    'nopening': (1, 0, 0),
-    'connectivity': (2, 2, 1),
-    'pedge_thresh': 0.001,
-    'fit_radial': True,
-    'edge_sub_dilations': 1,
-    'use_group_thresh': True,
-    'group_thresh_expansion': 0.1
-}
+
+DEFAULT_MODELSET = 'yeast-alcatras-brightfield-EMCCD-60x-5z'
+DEFAULT_MODELSET_PARAMS = None
+
+
+def _default_params():
+    global DEFAULT_MODELSET_PARAMS
+    if DEFAULT_MODELSET_PARAMS is None:
+        DEFAULT_MODELSET_PARAMS = modelsets.get_params(DEFAULT_MODELSET)
+    return DEFAULT_MODELSET_PARAMS
 
 
 class BabyBrain(object):
@@ -70,37 +85,58 @@ class BabyBrain(object):
     file or the name of the file in the default "models" dir shipped with this
     package. If any are left unspecified, then default models will be loaded.
 
-    :param morph_model_file: neural network model taking a stack of images and
-        outputting predictions according to the paired flattener model.
-    :param flattener_file: a saved `SegmentationFlattening` model specifying
-        the trained target types for each output layer of the neural network.
-    :param celltrack_model_file:
-    :param budassign_model_file:
-    :param default_image_size:
-    :param params: dict of keyword parameters to be passed to the
-        `morph_seg_grouped` function when segmenting.
-    :param session: optionally specify the Tensorflow session to load the
-        neural network model into (useful only for Tensorflow versions <2)
-    :param graph: optionally specify the Tensorflow graph to load the neural
-        network model into (useful only for Tensorflow versions <2)
-    :param suppress_errors: whether or not to catch Exceptions raised during
-        segmentation or tracking. If True, then any Exceptions will be logged
-        using standard Python logging.
-    :param error_dump_dir: optionally specify a directory in which to dump
-        input parameters when an error is caught.
+    Args:
+        morph_model_file: neural network model taking a stack of images and
+            outputting predictions according to the paired flattener model.
+        flattener_file: a saved `SegmentationFlattening` model specifying
+            the trained target types for each output layer of the neural network.
+        celltrack_model_file: file name of the saved machine learning model
+            for default tracking predictions.
+        celltrack_backup_model_file: file name of the saved machine learning
+            model for backup tracking predictions.
+        budassign_model_file: file name of the saved machine learning model
+            for predicting bud assignments
+        pixel_size (float): Target pixel size for inputs to the trained CNN.
+        default_image_size (None or Tuple[int] or int): Optionally specify an
+            alternative to the input size of the trained CNN as a ``(W, H)``
+            tuple giving the ``W`` and ``H`` of the image. If just an ``int``
+            is specified, then ``W = H`` is assumed.
+        params (str or Path or SegmentationParameters): Segmentation
+            parameters to use with :py:class:`MorphSegGrouped`. May be
+            specified as :py:class:`SegmentationParameters`, or the path to a
+            saved :py:class:`SegmentationParameters`.
+        modelset_path (str or Path): path to a folder containing the files
+            specified by the other arguments. The shared folder of the local
+            model set cache is always checked if file is not found at
+            `modelset_path`. See :py:func:`modelsets.resolve` for details on
+            how paths to model files are resolved.
+        session: optionally specify the Tensorflow session to load the
+            neural network model into (useful only for Tensorflow versions <2)
+        graph: optionally specify the Tensorflow graph to load the neural
+            network model into (useful only for Tensorflow versions <2)
+        suppress_errors: whether or not to catch Exceptions raised during
+            segmentation or tracking. If True, then any Exceptions will be logged
+            using standard Python logging.
+        error_dump_dir: optionally specify a directory in which to dump
+            input parameters when an error is caught.
     '''
 
     def __init__(self,
                  morph_model_file=None,
                  flattener_file=None,
                  celltrack_model_file=None,
+                 celltrack_backup_model_file=None,
                  budassign_model_file=None,
+                 mmtracking=False,
                  pixel_size=0.263,
                  default_image_size=None,
-                 params=default_params,
+                 params=SegmentationParameters(),
+                 nstepsback=None,
                  clogging_thresh=0.75,
                  min_bud_tps=3,
                  isbud_thresh=0.5,
+                 input_norm_dw=False,
+                 modelset_path=DEFAULT_MODELSET,
                  session=None,
                  graph=None,
                  print_info=False,
@@ -110,25 +146,28 @@ class BabyBrain(object):
         self.reshaped_models = {}
 
         if morph_model_file is None:
-            morph_model_file = join(models_path, 'I5_msd_d80_20190916.hdf5')
-        elif not isfile(morph_model_file):
-            morph_model_file = join(models_path, morph_model_file)
+            morph_model_file = _default_params()['morph_model_file']
+        morph_model_file = modelsets.resolve(morph_model_file, modelset_path)
 
         if flattener_file is None:
-            flattener_file = join(models_path, 'flattener_v2_20190905.json')
-        elif not isfile(flattener_file):
-            flattener_file = join(models_path, flattener_file)
+            flattener_file = _default_params()['flattener_file']
+        flattener_file = modelsets.resolve(flattener_file, modelset_path)
+
+        if type(params) == dict:
+            params = SegmentationParameters(**params)
+        if type(params) != SegmentationParameters:
+            param_file = modelsets.resolve(params, modelset_path)
+            with open(param_file, 'rt') as f:
+                params = json.load(f, object_hook=as_python_object)
+            if type(params) != SegmentationParameters:
+                raise BadFile('Specified file does not contain a valid '
+                              '`SegmentationParameters` object')
+
+        if not params.fit_radial:
+            raise BadParam('`BabyBrain` currently only works for '
+                           '`SegmentationParameters` with `fit_radial=True`')
 
-        if celltrack_model_file is None:
-            celltrack_model_file = join(models_path, 'ct_svc_20201106_6.pkl')
-        elif not isfile(celltrack_model_file):
-            celltrack_model_file = join(models_path, celltrack_model_file)
-
-        if budassign_model_file is None:
-            budassign_model_file = join(models_path,
-                                        'baby_randomforest_20190906.pkl')
-        elif not isfile(budassign_model_file):
-            budassign_model_file = join(models_path, budassign_model_file)
+        self.params = params
 
         self.session = None
         self.graph = None
@@ -154,7 +193,8 @@ class BabyBrain(object):
                     custom_objects={
                         'bce_dice_loss': bce_dice_loss,
                         'dice_loss': dice_loss,
-                        'dice_coeff': dice_coeff
+                        'dice_coeff': dice_coeff,
+                        'AdamW': AdamW
                     })
         else:
             self.morph_model = models.load_model(
@@ -164,35 +204,60 @@ class BabyBrain(object):
                     'dice_loss': dice_loss,
                     'dice_coeff': dice_coeff
                 })
+            if tf_version[0] == 2 and tf_version[1] > 3:
+                # TF 2.4 no longer supports reshaping using the functional API
+                # on a new Input since it allows multiple inputs to a layer
+                # Recommendation is to define generic Input layers with None
+                # for variable dimensions, or to disable input checking with
+                # the following code.
+                # See release notes for version 2.4.0 at
+                # https://github.com/tensorflow/tensorflow/blob/master/RELEASE.md
+                self.morph_model.input_spec = None
 
         self.suppress_errors = suppress_errors
         self.error_dump_dir = error_dump_dir
 
+        self.input_norm_dw = input_norm_dw
         self.flattener = SegmentationFlattening(flattener_file)
-        self.params = params
 
-        if ('use_group_thresh' not in self.params and
-                'group_thresh_expansion' in self.params):
-            self.params['use_group_thresh'] = True
         self.morph_segmenter = MorphSegGrouped(self.flattener,
-                                               fit_radial=True,
+                                               params=params,
                                                return_masks=True,
-                                               return_coords=True,
-                                               **self.params)
+                                               return_coords=True)
 
         self.pixel_size = pixel_size
         self.clogging_thresh = clogging_thresh
 
-        # Load tracker models and initialise Tracker
-        with open(celltrack_model_file, 'rb') as f:
-            celltrack_model = pickle.load(f)
-        with open(budassign_model_file, 'rb') as f:
-            budassign_model = pickle.load(f)
-        self.tracker = MasterTracker(ctrack_args={'model': celltrack_model},
-                                     btrack_args={'model': budassign_model},
-                                     min_bud_tps=min_bud_tps,
-                                     isbud_thresh=isbud_thresh,
-                                     px_size=pixel_size)
+        if mmtracking:
+            # Set the tracker engine to the model-free Mother Machine variant
+            self.tracker = MMTracker(px_size=pixel_size)
+        else:
+            # Load tracker models and initialise Tracker
+            if celltrack_model_file is None:
+                celltrack_model_file = _default_params()['celltrack_model_file']
+            celltrack_model_file = modelsets.resolve(celltrack_model_file, modelset_path)
+
+            if celltrack_backup_model_file is None:
+                celltrack_backup_model_file = _default_params()[
+                    'celltrack_backup_model_file']
+            celltrack_backup_model_file = modelsets.resolve(
+                celltrack_backup_model_file, modelset_path)
+
+            if budassign_model_file is None:
+                budassign_model_file = _default_params()['budassign_model_file']
+            budassign_model_file = modelsets.resolve(budassign_model_file, modelset_path)
+
+            ctrack_args ={
+                'model': celltrack_model_file,
+                'bak_model': celltrack_backup_model_file,
+                'nstepsback': nstepsback
+            }
+            btrack_args = {'model': budassign_model_file}
+            self.tracker = MasterTracker(ctrack_args=ctrack_args,
+                                         btrack_args=btrack_args,
+                                         min_bud_tps=min_bud_tps,
+                                         isbud_thresh=isbud_thresh,
+                                         px_size=pixel_size)
 
         # Run prediction on mock image to load model for prediction
         _, x, y, z = self.morph_model.input.shape
@@ -210,43 +275,78 @@ class BabyBrain(object):
     def depth(self):
         return self.morph_model.input.shape[3]
 
-    def morph_predict(self, X, needs_context=True):
+    def _predict(self, X, needs_context=True):
         if tf_version[0] == 1 and needs_context:
             with self.graph.as_default():
                 K.set_session(self.session)
-                return self.morph_predict(X, needs_context=False)
+                return self._predict(X, needs_context=False)
 
         imdims = X.shape[1:3]
-        # Current MSD model requires shape to be divisible by 8
-        nndims = tuple([int(np.ceil(float(d) / 8.)) * 8 for d in imdims])
+        # Current MSD model requires shape to be divisible by 8. Standard 5
+        # layer U-Net requires shape to be divisible by 16.
+        # CNN model for FRET is very sensitive to padding: we need images to
+        # be at least 64px wide, the padding needs to be centred and the pad
+        # value needs to be the image median. This basically makes the
+        # following match the padding protocol in the Augmenter
+        nndims = tuple(max(64, int(np.ceil(float(d) / 16.)) * 16)
+                       for d in imdims)
+        xpadoff, ypadoff = 0, 0
         if not all([n == i for n, i in zip(nndims, imdims)]):
             xpad, ypad = tuple(n - i for n, i in zip(nndims, imdims))
-            X = np.pad(X, ((0, 0), (0, xpad), (0, ypad), (0, 0)), 'edge')
+            xpadoff, ypadoff = xpad // 2, ypad // 2
+            xpad = (xpadoff, xpad - xpadoff)
+            ypad = (ypadoff, ypad - ypadoff)
+            X = np.array([np.pad(Xi, (xpad, ypad, (0, 0)),
+                                 mode='constant',
+                                 constant_values=np.median(Xi))
+                          for Xi in X])
 
         if nndims not in self.reshaped_models:
-            base_input_shape = self.morph_model.input.shape[1:3]
-            if all([n == m for n, m in zip(nndims, base_input_shape)]):
-                self.reshaped_models[nndims] = self.morph_model
-            else:
-                i = layers.Input(shape=X.shape[1:])
-                self.reshaped_models[nndims] = models.Model(
-                    i, self.morph_model(i))
+            i = layers.Input(shape=X.shape[1:])
+            self.reshaped_models[nndims] = models.Model(
+                i, self.morph_model(i))
 
         if tf_version[0] == 1 and self.print_info:
             print('Running prediction in session "{}"...'.format(
                 K.get_session()))
 
-        pred = self.reshaped_models[nndims].predict(X)
+        pred = self.reshaped_models[nndims].predict(X, verbose=0)
 
-        return [p[:, :imdims[0], :imdims[1], :] for p in pred]
+        return [p[:, xpadoff:xpadoff+imdims[0], ypadoff:ypadoff+imdims[1], :]
+                for p in pred]
+
+    def morph_predict(self, X, pixel_size=None, overlap_size=48,
+                      yield_rescaling=False, keep_bb_pixel_size=False):
+        """Yield predictions reassembled by ``_stitch_tiles`` for ``X``.
+
+        Yields ``(rescaling, inshape)`` first when ``yield_rescaling``;
+        outputs are resized to ``inshape`` unless ``keep_bb_pixel_size``.
+        """
+        X, rescaling, inshape = _apply_preprocessing(
+            X, self.input_norm_dw, pixel_size, self.pixel_size)
+        if yield_rescaling:
+            yield rescaling, inshape
+        tilegen = _tile_generator(X, overlap_size=overlap_size)
+        tiling_strategy = next(tilegen)
+        # Chain lazily: `chain(*map(...))` ran every batch up front
+        predgen = chain.from_iterable(
+            split_batch_pred(self._predict(x))
+            for x in _batch_generator(tilegen, 8))
+        for pred in _stitch_tiles(predgen, *tiling_strategy):
+            yield pred if keep_bb_pixel_size else transform.resize(
+                pred, (len(pred),) + inshape, order=1)
 
     def segment(self,
                 bf_img_batch,
+                pixel_size=None,
+                overlap_size=48,
                 yield_edgemasks=False,
                 yield_masks=False,
                 yield_preds=False,
                 yield_volumes=False,
-                refine_outlines=False):
+                refine_outlines=False,
+                yield_rescaling=False,
+                keep_bb_pixel_size=False):
         '''Generator yielding segmented output for a batch of input images
 
         :param bf_img_batch: a list of ndarray with shape (X, Y, Z), or
@@ -272,19 +372,26 @@ class BabyBrain(object):
             - volumes: (optional) list of floats corresponding, for each cell,
               to the conical section method for cell volume estimation 
         '''
-        # First preprocess each brightfield image in batch
-        bf_img_batch = np.stack(
-            [robust_norm(img, {}) for img in bf_img_batch])
 
-        for batch in batch_iterator(bf_img_batch):
-            morph_preds = split_batch_pred(self.morph_predict(batch))
-
-            for cnn_output in morph_preds:
-                yield _segment(self.morph_segmenter, cnn_output,
-                               refine_outlines, yield_volumes, yield_masks,
-                               yield_preds, yield_edgemasks,
-                               self.clogging_thresh, self.error_dump_dir,
-                               self.suppress_errors)
+        predgen = self.morph_predict(bf_img_batch,
+                                     pixel_size=pixel_size,
+                                     overlap_size=overlap_size,
+                                     yield_rescaling=True,
+                                     keep_bb_pixel_size=True)
+        rescaling, inshape = next(predgen)
+        if yield_rescaling:
+            yield rescaling, inshape
+
+        for cnn_output in predgen:
+            segout = _segment(self.morph_segmenter, cnn_output,
+                              refine_outlines, yield_volumes, yield_masks,
+                              yield_preds, yield_edgemasks,
+                              self.clogging_thresh, self.error_dump_dir,
+                              self.suppress_errors)
+            if not keep_bb_pixel_size:
+                _rescale_output(segout, rescaling, inshape,
+                                self.morph_segmenter.params.cartesian_spline)
+            yield segout
 
     def run(self, bf_img_batch):
         '''Implementation of legacy runner function...
@@ -314,6 +421,7 @@ class BabyBrain(object):
     def segment_and_track(self,
                           bf_img_batch,
                           tracker_states=None,
+                          pixel_size=None,
                           yield_next=False,
                           yield_edgemasks=False,
                           assign_mothers=False,
@@ -365,26 +473,46 @@ class BabyBrain(object):
             tracker_states = repeat(None)
 
         tnames = self.flattener.names()
-        i_budneck = tnames.index('bud_neck')
         bud_target = 'sml_fill' if 'sml_fill' in tnames else 'sml_inte'
-        i_bud = tnames.index(bud_target)
+        assign_buds = 'bud_neck' in tnames and bud_target in tnames
+        if assign_buds:
+            i_budneck = tnames.index('bud_neck')
+            i_bud = tnames.index(bud_target)
+        else:
+            i_budneck = None
+            i_bud = None
 
         segment_gen = self.segment(bf_img_batch,
+                                   pixel_size=pixel_size,
                                    yield_masks=True,
                                    yield_edgemasks=True,
                                    yield_preds=True,
                                    yield_volumes=yield_volumes,
-                                   refine_outlines=refine_outlines)
+                                   refine_outlines=refine_outlines,
+                                   yield_rescaling=True,
+                                   keep_bb_pixel_size=True)
+
+        rescaling, inshape = next(segment_gen)
 
         for seg, state in zip(segment_gen, tracker_states):
-            yield _track(self.tracker, seg, state, i_budneck, i_bud,
+            trackout = _track(self.tracker, seg, state, i_budneck, i_bud,
                          assign_mothers, return_baprobs, yield_edgemasks,
                          yield_next, self.error_dump_dir,
                          self.suppress_errors)
+            if yield_next:
+                segout, state = trackout
+                _rescale_output(segout, rescaling, inshape,
+                                self.morph_segmenter.params.cartesian_spline)
+                yield segout, state
+            else:
+                _rescale_output(trackout, rescaling, inshape,
+                                self.morph_segmenter.params.cartesian_spline)
+                yield trackout
 
     def segment_and_track_parallel(self,
                                    bf_img_batch,
                                    tracker_states=None,
+                                   pixel_size=None,
                                    yield_next=False,
                                    yield_edgemasks=False,
                                    assign_mothers=False,
@@ -432,21 +560,21 @@ class BabyBrain(object):
             tracker states for this time point as a tuple
         '''
 
-        # First preprocess each brightfield image in batch
-        bf_img_batch = np.stack(
-            [robust_norm(img, {}) for img in bf_img_batch])
-
         # Do not run the CNN in parallel
-        morph_preds = list(
-            chain(*(split_batch_pred(self.morph_predict(batch))
-                    for batch in batch_iterator(bf_img_batch))))
+        predgen = self.morph_predict(bf_img_batch,
+                                     pixel_size=pixel_size,
+                                     yield_rescaling=True,
+                                     keep_bb_pixel_size=True)
+        rescaling, inshape = next(predgen)
+        preds = list(predgen)
 
         if tracker_states is None:
             tracker_states = repeat(None)
 
         trackout = _segment_and_track_parallel(
-            self.morph_segmenter, self.tracker, self.flattener, morph_preds,
-            tracker_states, refine_outlines, yield_volumes, yield_edgemasks,
-            self.clogging_thresh, assign_mothers, return_baprobs, yield_next,
-            njobs, self.error_dump_dir, self.suppress_errors)
+            self.morph_segmenter, self.tracker, self.flattener, preds,
+            tracker_states, rescaling, inshape, refine_outlines,
+            yield_volumes, yield_edgemasks, self.clogging_thresh,
+            assign_mothers, return_baprobs, yield_next, njobs,
+            self.error_dump_dir, self.suppress_errors)
         return trackout
diff --git a/python/baby/brain_util.py b/python/baby/brain_util.py
index 10a7cd15a4703ae811d802f9315a61e551f40095..ef6e49845925fd7bafccad2b05dd8caaf5166af7 100644
--- a/python/baby/brain_util.py
+++ b/python/baby/brain_util.py
@@ -1,23 +1,25 @@
 # If you publish results that make use of this software or the Birth Annotator
 # for Budding Yeast algorithm, please cite:
-# Julian M J Pietsch, Alán Muñoz, Diane Adjavon, Ivan B N Clark, Peter S
-# Swain, 2021, Birth Annotator for Budding Yeast (in preparation).
-#
-#
+# Pietsch, J.M.J., Muñoz, A.F., Adjavon, D.-Y.A., Farquhar, I., Clark, I.B.N.,
+# and Swain, P.S. (2023). Determining growth rates from bright-field images of
+# budding cells through identifying overlaps. eLife. 12:e79812.
+# https://doi.org/10.7554/eLife.79812
+# 
+# 
 # The MIT License (MIT)
-#
-# Copyright (c) Julian Pietsch, Alán Muñoz and Diane Adjavon 2021
-#
+# 
+# Copyright (c) Julian Pietsch, Alán Muñoz and Diane Adjavon 2023
+# 
 # Permission is hereby granted, free of charge, to any person obtaining a copy
 # of this software and associated documentation files (the "Software"), to
 # deal in the Software without restriction, including without limitation the
 # rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
 # sell copies of the Software, and to permit persons to whom the Software is
 # furnished to do so, subject to the following conditions:
-#
+# 
 # The above copyright notice and this permission notice shall be included in
 # all copies or substantial portions of the Software.
-#
+# 
 # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
 # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
 # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
@@ -31,13 +33,18 @@ from os.path import join, isdir
 from time import strftime
 from uuid import uuid1
 import logging
+from itertools import islice
 import numpy as np
 import pickle
-from scipy.ndimage import binary_dilation
+from scipy.ndimage import binary_dilation, binary_fill_holes
+from skimage import transform
+from skimage.morphology import diamond
 
 from .morph_thresh_seg import SegmentationOutput
+from .segmentation import draw_radial
 from .io import save_tiled_image
-from .errors import Clogging
+from .errors import Clogging, BadParam
+from .preprocessing import robust_norm, robust_norm_dw
 
 _logger = None
 
@@ -62,6 +69,170 @@ def _generate_error_dump_id():
     return strftime('%Y-%m-%d_%H-%M-%S_') + str(uuid1())
 
 
+def _batch_generator(iterable, n=8):
+    """Yields batches from an iterator
+    Based on
+    https://docs.python.org/3/library/itertools.html#itertools-recipes"""
+    it = iter(iterable)
+    while True:
+        batch = list(islice(it, n))
+        if not batch:
+            return
+        yield np.array(batch)
+
+
+def _tile_generator(imgarray,
+                    overlap_size=48,
+                    tile_sizes=np.arange(64, 257, 8)):
+    """Yields optimal tiles from an image."""
+
+    imH, imW = imgarray.shape[1:3]
+
+    if imH * imW <= np.square(np.max(tile_sizes)):
+        # If area is less than that of max tile size we will simply return
+        # images without modification
+        nH, nW = 1, 1
+        tileH, tileW = imH, imW
+    else:
+        tile_sizes = tile_sizes[tile_sizes > overlap_size + 32]
+        if len(tile_sizes) == 0:
+            raise BadParam('`overlap_size` is too large for the ' +
+                           'specified tile_sizes')
+
+        # For each candidate tile size, calculate the required number of tiles
+        # per image (rounding up), including allowance for overlaps
+        nH = np.ceil((imH - overlap_size)/(tile_sizes - overlap_size)).astype(int)
+        nW = np.ceil((imW - overlap_size)/(tile_sizes - overlap_size)).astype(int)
+
+        # Calculate the new total size given the rounding above
+        paddedH = nH * (tile_sizes - overlap_size) + overlap_size
+        paddedW = nW * (tile_sizes - overlap_size) + overlap_size
+
+        # Choose the tile size that minimises the fractional padding
+        padH = paddedH - imH
+        i_best =  np.argmin(padH / tile_sizes)
+        nH, tileH = nH[i_best], tile_sizes[i_best]
+        padH = padH[i_best]
+        padW = paddedW - imW
+        i_best =  np.argmin(padW / tile_sizes)
+        nW, tileW = nW[i_best], tile_sizes[i_best]
+        padW = padW[i_best]
+
+        # Pad image to tiled size
+        imgarray = np.pad(
+            imgarray, ((0, 0), (0, padH), (0, padW), (0, 0)))
+
+    # First yield tiling details for future stitching
+    yield nH, nW, imH, imW, overlap_size
+
+    # Split the images up into tiles
+    for img in imgarray:
+        for j in range(nH):
+            for i in range(nW):
+                Hl = j * (tileH - overlap_size)
+                Hu = Hl + tileH
+                Wl = i * (tileW - overlap_size)
+                Wu = Wl + tileW
+                yield img[Hl:Hu, Wl:Wu, ...]
+
+
+def _stitch_tiles(tilestream, nH, nW, imH, imW, nOvlap):
+    """Stitches tiles back together"""
+
+    # Regions where tiles overlap use linear decay for weighted averaging
+    lindecay = np.linspace(1, 0, nOvlap + 2)[1:-1]
+    Hdecay = np.c_[lindecay][..., None]
+    Wdecay = np.c_[lindecay].T[..., None]
+
+    # Batch tiles into each complete image
+    for tiles in zip(*[iter(tilestream)] * (nW * nH)):
+        # Transpose tiles so that we can stitch with h/vstack
+        tiles = [tile.transpose((1, 2, 0)) for tile in tiles]
+
+        img = None
+        # Batch tiles into rows at a time
+        for row in zip(*[iter(tiles)]*nW):
+            # Stitch tiles into a row
+            rowimg = row[0]
+            for t in row[1:]:
+                rowimg[:, -nOvlap:] = (rowimg[:, -nOvlap:] * Wdecay
+                                       + t[:, :nOvlap] * Wdecay[:, ::-1])
+                rowimg = np.hstack((rowimg, t[:, nOvlap:]))
+
+            # Stitch rows together
+            if img is None:
+                img = rowimg
+            else:
+                img[-nOvlap:, :] = (img[-nOvlap:, :] * Hdecay +
+                                    rowimg[:nOvlap, :] * Hdecay[::-1, :])
+                img = np.vstack((img, rowimg[nOvlap:, :]))
+
+        yield img[:imH, :imW, :].transpose((2, 0, 1))
+
+
+def _apply_preprocessing(bf_imgs, input_norm_dw, pxsize_in, pxsize_out):
+    input_norm = robust_norm_dw if input_norm_dw else robust_norm
+    bf_imgs = np.stack([input_norm(img, {}) for img in bf_imgs])
+    inshape = bf_imgs.shape
+    rescaling = None
+    if pxsize_in is not None and pxsize_in != pxsize_out:
+        rescaling = pxsize_in / pxsize_out
+    if rescaling is not None:
+        bf_imgs = list(bf_imgs)
+        for i, img in enumerate(bf_imgs):
+            bf_imgs[i] = transform.rescale(img,
+                                           rescaling,
+                                           order=1,
+                                           channel_axis=2)
+        bf_imgs = np.stack(bf_imgs)
+        rescaling = 1. / rescaling
+    return bf_imgs, rescaling, inshape[1:3]
+
+
+def _rescale_output(output, rescaling, outshape, cartesian_spline):
+    """Rescale segmentation `output` in place to image shape `outshape`."""
+    if rescaling is None:
+        return
+    if 'centres' in output:
+        output['centres'] = [[x * rescaling for x in cell]
+                             for cell in output['centres']]
+    if 'radii' in output:
+        output['radii'] = [[r * rescaling for r in cell]
+                           for cell in output['radii']]
+    has_coords = {'centres', 'angles', 'radii'}.issubset(output.keys())
+    # Outlines can only be redrawn if cells exist: np.stack([]) raises
+    redraw = has_coords and len(output['centres']) > 0
+    if redraw and ('edgemasks' in output or 'masks' in output):
+        edgemasks = np.stack([
+            draw_radial(np.array(r), np.array(a), c, outshape,
+                        cartesian_spline=cartesian_spline)
+            for r, a, c in zip(output['radii'], output['angles'],
+                               output['centres'])])
+    _0xy = (0,) + outshape
+    if 'edgemasks' in output:
+        if output['edgemasks'].shape[0] == 0:
+            output['edgemasks'] = np.zeros(_0xy, dtype=bool)
+        elif redraw:
+            output['edgemasks'] = edgemasks
+        else:
+            output['edgemasks'] = transform.resize(
+                output['edgemasks'], (len(output['edgemasks']),) + outshape, order=0)
+    if 'masks' in output:
+        if output['masks'].shape[0] == 0:
+            output['masks'] = np.zeros(_0xy, dtype=bool)
+        elif redraw:
+            output['masks'] = binary_fill_holes(
+                edgemasks, diamond(1)[None, ...])
+        else:
+            output['masks'] = transform.resize(
+                output['masks'], (len(output['masks']),) + outshape, order=0)
+    if 'preds' in output:
+        output['preds'] = transform.resize(
+            output['preds'], (len(output['preds']),) + outshape, order=1)
+    if 'volumes' in output:
+        output['volumes'] = [v * rescaling ** 3 for v in output['volumes']]
+
+
 def _segment(segmenter,
              cnn_output,
              refine_outlines,
@@ -166,10 +337,18 @@ def _track(tracker,
     if logger is None:
         logger = _get_logger()
 
+    if i_budneck is None:
+        p_budneck = np.zeros(seg['preds'].shape[1:])
+    else:
+        p_budneck = seg['preds'][i_budneck]
+    if i_bud is None:
+        p_bud = np.zeros(seg['preds'].shape[1:])
+    else:
+        p_bud = seg['preds'][i_bud]
     try:
         tracking = tracker.step_trackers(seg['masks'],
-                                         seg['preds'][i_budneck],
-                                         seg['preds'][i_bud],
+                                         p_budneck,
+                                         p_bud,
                                          state=state,
                                          assign_mothers=assign_mothers,
                                          return_baprobs=return_baprobs)
@@ -184,10 +363,10 @@ def _track(tracker,
                     seg['masks'].transpose((1, 2, 0)).astype('uint8'),
                     fprefix + '_masks.png')
             save_tiled_image(
-                np.uint16((2**16 - 1) * seg['preds'][i_budneck, :, :, None]),
+                np.uint16((2**16 - 1) * p_budneck[..., None]),
                 fprefix + '_budneck_pred.png')
             save_tiled_image(
-                np.uint16((2**16 - 1) * seg['preds'][i_bud, :, :, None]),
+                np.uint16((2**16 - 1) * p_bud[..., None]),
                 fprefix + '_bud_pred.png')
             with open(fprefix + '_state.pkl', 'wb') as f:
                 pickle.dump(state, f)
@@ -250,6 +429,8 @@ def _segment_and_track(segmenter,
                        state,
                        i_budneck,
                        i_bud,
+                       rescaling,
+                       outshape,
                        refine_outlines,
                        yield_volumes,
                        yield_edgemasks,
@@ -274,40 +455,53 @@ def _segment_and_track(segmenter,
                       error_dump_dir,
                       suppress_errors,
                       logger=logger)
-    return _track(tracker,
-                  segout,
-                  state,
-                  i_budneck,
-                  i_bud,
-                  assign_mothers,
-                  return_baprobs,
-                  yield_edgemasks,
-                  yield_next,
-                  error_dump_dir,
-                  suppress_errors,
-                  logger=logger)
+    trackout = _track(tracker,
+                      segout,
+                      state,
+                      i_budneck,
+                      i_bud,
+                      assign_mothers,
+                      return_baprobs,
+                      yield_edgemasks,
+                      yield_next,
+                      error_dump_dir,
+                      suppress_errors,
+                      logger=logger)
 
+    if yield_next:
+        segout, state = trackout
+        _rescale_output(segout, rescaling, outshape,
+                        segmenter.params.cartesian_spline)
+        return segout, state
+    else:
+        _rescale_output(trackout, rescaling, outshape,
+                        segmenter.params.cartesian_spline)
+        return trackout
 
-def _segment_and_track_parallel(segmenter, tracker, flattener, morph_preds,
-                                tracker_states, refine_outlines,
-                                yield_volumes, yield_edgemasks,
-                                clogging_thresh, assign_mothers,
-                                return_baprobs, yield_next, njobs,
-                                error_dump_dir, suppress_errors):
 
-    # logger = _get_logger()
+def _segment_and_track_parallel(segmenter, tracker, flattener, morph_preds,
+                                tracker_states, rescaling, outshape,
+                                refine_outlines, yield_volumes,
+                                yield_edgemasks, clogging_thresh,
+                                assign_mothers, return_baprobs, yield_next,
+                                njobs, error_dump_dir, suppress_errors):
 
     tnames = flattener.names()
-    i_budneck = tnames.index('bud_neck')
     bud_target = 'sml_fill' if 'sml_fill' in tnames else 'sml_inte'
-    i_bud = tnames.index(bud_target)
+    assign_buds = 'bud_neck' in tnames and bud_target in tnames
+    if assign_buds:
+        i_budneck = tnames.index('bud_neck')
+        i_bud = tnames.index(bud_target)
+    else:
+        i_budneck = None
+        i_bud = None
 
     # Run segmentation and tracking in parallel
     from joblib import Parallel, delayed
     return Parallel(n_jobs=njobs, mmap_mode='c')(
         delayed(_segment_and_track)
         (segmenter, tracker, cnn_output, state, i_budneck, i_bud,
-         refine_outlines, yield_volumes, yield_edgemasks, clogging_thresh,
-         assign_mothers, return_baprobs, yield_next, error_dump_dir,
-         suppress_errors)
+         rescaling, outshape, refine_outlines, yield_volumes, yield_edgemasks,
+         clogging_thresh, assign_mothers, return_baprobs, yield_next,
+         error_dump_dir, suppress_errors)
         for cnn_output, state in zip(morph_preds, tracker_states))
diff --git a/python/baby/crawler.py b/python/baby/crawler.py
index 0929f140d1285c8595f147d1242b11562c53dec5..328e9fe3ffbc917fbd071349d9d3afbdc35b3409 100644
--- a/python/baby/crawler.py
+++ b/python/baby/crawler.py
@@ -1,12 +1,14 @@
 # If you publish results that make use of this software or the Birth Annotator
 # for Budding Yeast algorithm, please cite:
-# Julian M J Pietsch, Alán Muñoz, Diane Adjavon, Ivan B N Clark, Peter S
-# Swain, 2021, Birth Annotator for Budding Yeast (in preparation).
+# Pietsch, J.M.J., Muñoz, A.F., Adjavon, D.-Y.A., Farquhar, I., Clark, I.B.N.,
+# and Swain, P.S. (2023). Determining growth rates from bright-field images of
+# budding cells through identifying overlaps. eLife. 12:e79812.
+# https://doi.org/10.7554/eLife.79812
 # 
 # 
 # The MIT License (MIT)
 # 
-# Copyright (c) Julian Pietsch, Alán Muñoz and Diane Adjavon 2021
+# Copyright (c) Julian Pietsch, Alán Muñoz and Diane Adjavon 2023
 # 
 # Permission is hereby granted, free of charge, to any person obtaining a copy
 # of this software and associated documentation files (the "Software"), to
@@ -47,6 +49,7 @@ class BabyCrawler(object):
 
     def step(self,
              bf_img_batch,
+             pixel_size=None,
              with_edgemasks=False,
              assign_mothers=False,
              return_baprobs=False,
@@ -92,20 +95,24 @@ class BabyCrawler(object):
         output = []
 
         if parallel:
+            kwargs = {k: v for k, v in kwargs.items() if k in {'njobs'}}
             seg_trk_gen = self.baby_brain.segment_and_track_parallel(
                 bf_img_batch,
                 tracker_states=self.tracker_states,
                 yield_next=True,
+                pixel_size=pixel_size,
                 yield_edgemasks=with_edgemasks,
                 yield_volumes=with_volumes,
                 assign_mothers=assign_mothers,
                 return_baprobs=return_baprobs,
-                refine_outlines=refine_outlines)
+                refine_outlines=refine_outlines,
+                **kwargs)
         else:
             seg_trk_gen = self.baby_brain.segment_and_track(
                 bf_img_batch,
                 tracker_states=self.tracker_states,
                 yield_next=True,
+                pixel_size=pixel_size,
                 yield_edgemasks=with_edgemasks,
                 yield_volumes=with_volumes,
                 assign_mothers=assign_mothers,
diff --git a/python/baby/errors.py b/python/baby/errors.py
index c53d3d55b58777d7cee435bc0edd2113b8ae6efc..ce052a9181a054a5836767d351b1dd7e95fb062e 100644
--- a/python/baby/errors.py
+++ b/python/baby/errors.py
@@ -1,12 +1,14 @@
 # If you publish results that make use of this software or the Birth Annotator
 # for Budding Yeast algorithm, please cite:
-# Julian M J Pietsch, Alán Muñoz, Diane Adjavon, Ivan B N Clark, Peter S
-# Swain, 2021, Birth Annotator for Budding Yeast (in preparation).
+# Pietsch, J.M.J., Muñoz, A.F., Adjavon, D.-Y.A., Farquhar, I., Clark, I.B.N.,
+# and Swain, P.S. (2023). Determining growth rates from bright-field images of
+# budding cells through identifying overlaps. eLife. 12:e79812.
+# https://doi.org/10.7554/eLife.79812
 # 
 # 
 # The MIT License (MIT)
 # 
-# Copyright (c) Julian Pietsch, Alán Muñoz and Diane Adjavon 2021
+# Copyright (c) Julian Pietsch, Alán Muñoz and Diane Adjavon 2023
 # 
 # Permission is hereby granted, free of charge, to any person obtaining a copy
 # of this software and associated documentation files (the "Software"), to
@@ -56,3 +58,6 @@ class BadOutput(Exception):
 class Clogging(Exception):
     pass
 
+
+class BadModel(Exception):
+    pass
diff --git a/python/baby/generator.py b/python/baby/generator.py
index 5e6fa8c48aa7535fa0ac667af57ef7bec182e8a5..f3e2e975dc36e83ec765b3f5c3dc8c25a6cb4a7f 100644
--- a/python/baby/generator.py
+++ b/python/baby/generator.py
@@ -1,12 +1,14 @@
 # If you publish results that make use of this software or the Birth Annotator
 # for Budding Yeast algorithm, please cite:
-# Julian M J Pietsch, Alán Muñoz, Diane Adjavon, Ivan B N Clark, Peter S
-# Swain, 2021, Birth Annotator for Budding Yeast (in preparation).
+# Pietsch, J.M.J., Muñoz, A.F., Adjavon, D.-Y.A., Farquhar, I., Clark, I.B.N.,
+# and Swain, P.S. (2023). Determining growth rates from bright-field images of
+# budding cells through identifying overlaps. eLife. 12:e79812.
+# https://doi.org/10.7554/eLife.79812
 # 
 # 
 # The MIT License (MIT)
 # 
-# Copyright (c) Julian Pietsch, Alán Muñoz and Diane Adjavon 2021
+# Copyright (c) Julian Pietsch, Alán Muñoz and Diane Adjavon 2023
 # 
 # Permission is hereby granted, free of charge, to any person obtaining a copy
 # of this software and associated documentation files (the "Software"), to
@@ -25,15 +27,21 @@
 # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
 # FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS
 # IN THE SOFTWARE.
+import json
 from pathlib import Path
+from contextlib import contextmanager
 import numpy as np
 from functools import namedtuple
-from tensorflow.python.keras.utils.data_utils import Sequence
+from itertools import repeat
+from tensorflow.keras.utils import Sequence
 from tqdm import tqdm
+from matplotlib import pyplot as plt
+from PIL import Image
 
 from .io import load_tiled_image
 from .augmentation import Augmenter
 from .preprocessing import standard_norm
+from .visualise import colour_segstack
 
 # Following are used by ImgLblInMem and should be deprecated
 from .preprocessing import robust_norm as preprocess_brightfield
@@ -44,7 +52,8 @@ ImageLabelShapes = namedtuple('ImageLabelShapes', ('input', 'output'))
 
 class ImageLabel(Sequence):
     def __init__(self, paths, batch_size, aug, preprocess=None,
-                 in_memory=False):
+                 in_memory=False, balanced_sampling=False,
+                 use_sample_weights=False, n_jobs=4):
         """Generator for training image-label pairs.
 
         Arguments:
@@ -55,12 +64,15 @@ class ImageLabel(Sequence):
             preprocess: a single callable or tuple of callables (one for each
                 file of pair); None specifies the default `standard_norm`
             in_memory: whether or not to load all images into memory
+            balanced_sampling: whether to increase sampling of images based on
+                their shape relative to the output crop
         """
 
         self.batch_size = batch_size
 
         assert callable(aug), '"aug" must be a callable'
         self.aug = aug
+        self.n_jobs = n_jobs
 
         # Check that all required images exist
         self.paths = [(Path(img), Path(lbl)) for img, lbl in paths]
@@ -84,29 +96,98 @@ class ImageLabel(Sequence):
                 [ppf(*load_tiled_image(img)) for ppf, img
                  in zip(self.preprocess, imgs)] for imgs in tqdm(self.paths)
             ]
+        
+        self.balanced_sampling = balanced_sampling
+        self.use_sample_weights = use_sample_weights
 
         # Initialise ordering
         self.on_epoch_end()
 
     def __len__(self):
-        return int(np.ceil(len(self.paths) / float(self.batch_size)))
+        return int(np.ceil(sum(self.nsamples) / float(self.batch_size)))
 
     @property
     def shapes(self):
-        if not hasattr(self, '_shapes') or not self._shapes:
-            if len(self.paths) == 0:
-                return ImageLabelShapes(tuple(), tuple())
+        if len(self.paths) == 0:
+            return ImageLabelShapes(tuple(), tuple())
+
+        img, lbl = self.get_by_index(0)
+        if type(lbl) == tuple and len(lbl) == 2:
+            lbl, _ = lbl
+        Nbatch = (self.batch_size,)
+        return ImageLabelShapes(Nbatch + img.shape, Nbatch + lbl.shape)
+
+    def _collect_size_info(self):
+        """Read raw shapes, pixel sizes and weights from image metadata."""
+        self._rawshapes, self._pixelsizes, weights = [], [], []
+        for imgfile, lblfile in self.paths:
+            with Image.open(imgfile) as img:
+                img_info = json.loads(img.info.get('Description', '{}'))
+                img_size = img.size
+            if 'tilesize' in img_info and 'ntiles' in img_info:
+                img_shape = tuple(img_info['tilesize']) + (img_info['ntiles'],)
+            else:
+                img_shape = img_size + (1,)
+            self._rawshapes.append(img_shape)
+
+            with Image.open(lblfile) as lbl:
+                lbl_info = json.loads(lbl.info.get('Description', '{}'))
+            self._pixelsizes.append(lbl_info.get('pixel_size'))
+            weights.append(lbl_info.get('weight', 1))
+        self._weights = np.array(weights)
 
-            img, lbl = self.get_by_index(0)
-            Nbatch = (self.batch_size,)
-            self._shapes = ImageLabelShapes(Nbatch + img.shape,
-                                          Nbatch + lbl.shape)
-        return self._shapes
+    @property
+    def rawshapes(self):
+        if not hasattr(self, '_rawshapes') or self._rawshapes is None:
+            self._collect_size_info()
+        return self._rawshapes
+
+    @property
+    def pixelsizes(self):
+        if not hasattr(self, '_pixelsizes') or self._pixelsizes is None:
+            self._collect_size_info()
+        return self._pixelsizes
+
+    @property
+    def weights(self):
+        if not hasattr(self, '_weights') or self._weights is None:
+            self._collect_size_info()
+        return self._weights
+
+    @property
+    def target_pixel_size(self):
+        val = None
+        if hasattr(self.aug, 'target_pixel_size'):
+            val = self.aug.target_pixel_size
+        return val
+
+    @property
+    def nsamples(self):
+        """Number of times each image pair is sampled per epoch."""
+        # NB: `~` on a Python bool is bitwise and always truthy; use `not`
+        if not self.balanced_sampling:
+            return np.ones(len(self.rawshapes), dtype='int').tolist()
+        aug_shape = self.shapes.input[1:]
+        usepxsz = self.target_pixel_size is not None
+        if usepxsz:
+            aug_size = np.array(aug_shape, dtype=float)
+            aug_size[:2] *= self.target_pixel_size
+        nsamples = []
+        for in_shape, pxsz in zip(self.rawshapes, self.pixelsizes):
+            if pxsz and usepxsz:
+                in_size = np.array(in_shape, dtype=float)
+                in_size[:2] *= pxsz
+                szratio = np.floor_divide(in_size, aug_size)
+            else:
+                szratio = np.floor_divide(in_shape, aug_shape)
+            nsamples.append(np.prod(np.maximum(szratio.astype(int), 1)))
+        return nsamples
 
     def on_epoch_end(self):
         # Shuffle samples for next epoch
-        Nsamples = len(self.paths)
-        self.ordering = np.random.choice(Nsamples, Nsamples, replace=False)
+        self.ordering = np.repeat(np.arange(len(self.nsamples)),
+                                  self.nsamples)
+        np.random.shuffle(self.ordering)
 
     @property
     def n_pairs(self):
@@ -124,24 +205,166 @@ class ImageLabel(Sequence):
         else:
             return aug(img, lbl)
 
+    def parallel_get_indices(self, inds, n_jobs=None):
+        if n_jobs is None:
+            n_jobs = self.n_jobs
+        passthrough = lambda img, lbl: (img, lbl)
+        img_lbl_pairs = [self.get_by_index(i, aug=passthrough) for i in inds]
+        from joblib import Parallel, delayed
+        return Parallel(n_jobs=n_jobs)(
+            delayed(self.aug)(img, lbl) for img, lbl in img_lbl_pairs)
+
     def __getitem__(self, idx):
         Nbatch = self.batch_size
         current_batch = self.ordering[idx * Nbatch:(idx + 1) * Nbatch]
 
-        img_batch = []
-        lbl_batch = []
-
-        for i in current_batch:
-            img, lbl = self.get_by_index(i)
-            lbl = np.dsplit(lbl, lbl.shape[2])
+        if self.n_jobs > 1:
+            img_batch, lbl_batch = zip(*self.parallel_get_indices(current_batch))
+        else:
+            img_batch, lbl_batch = zip(*[self.get_by_index(i) for i in
+                                         current_batch])
 
-            img_batch.append(img)
-            lbl_batch.append(lbl)
+        lbl_batch = [np.dsplit(lbl, lbl.shape[2]) for lbl in lbl_batch]
 
         img_batch = np.array(img_batch)
         lbl_batch = [np.array(lw) for lw in zip(*lbl_batch)]
 
-        return img_batch, lbl_batch
+        if self.use_sample_weights:
+            return img_batch, lbl_batch, self.weights[current_batch]
+        else:
+            return img_batch, lbl_batch
+
+    def plot_sample(self, i=0, figsize=3):
+        """Plot a sample batch from the generator
+
+        This function assumes that the assigned Augmenter produces label
+        images that can be concatenated along axis 3.
+        """
+        img_batch, lbl_batch = self[i][:2]
+        lbl_batch = np.concatenate(lbl_batch, axis=3)
+
+        n_sections = img_batch.shape[3]
+        n_targets = lbl_batch.shape[3]
+
+        target_names = repeat(None, n_targets)
+        edge_inds = None
+        if hasattr(self.aug, 'targetgenfunc'):
+            if hasattr(self.aug.targetgenfunc, 'names'):
+                target_names = self.aug.targetgenfunc.names()
+            if hasattr(self.aug.targetgenfunc, 'targets'):
+                edge_inds = np.flatnonzero(
+                    [t.prop == 'edge' for t in self.aug.targetgenfunc.targets])
+
+        ncol, nrow = len(img_batch), n_sections + n_targets
+        # squeeze=False keeps axs 2-D even when the batch holds one image
+        fig, axs = plt.subplots(nrow, ncol, squeeze=False,
+                                figsize=(figsize * ncol, figsize * nrow))
+
+        # Plot img sections first...
+        for axrow, section in zip(axs[:n_sections],
+                                  np.split(img_batch, n_sections, axis=3)):
+            for ax, img, lbl in zip(axrow, section, lbl_batch):
+                ax.imshow(img, cmap='gray')
+                if edge_inds is not None:
+                    ax.imshow(colour_segstack(lbl[..., edge_inds], dw=True))
+                ax.grid(False)
+                ax.set(xticks=[], yticks=[])
+
+        # ...then plot targets
+        for axrow, target, name in zip(axs[n_sections:],
+                                       np.split(lbl_batch, n_targets, axis=3),
+                                       target_names):
+            for ax, lbl in zip(axrow, target):
+                ax.imshow(lbl, cmap='gray')
+                ax.grid(False)
+                ax.set(xticks=[], yticks=[])
+                if name is not None:
+                    ax.set_title(name)
+
+        return fig, axs
+
+
+
+@contextmanager
+def augmented_generator(gen: ImageLabel, aug: Augmenter):
+    # Save the previous augmenter if any
+    saved_aug = gen.aug
+    gen.aug = aug
+    try:
+        yield gen
+    # Todo: add except otherwise there might be an issue if there is an error?
+    finally:
+        gen.aug = saved_aug
+
+
+class AugmentedGenerator(Sequence):
+    """Wraps a generator with an alternative augmenter.
+
+    Args:
+        gen (ImageLabel): Generator to wrap.
+        aug (augmentation.Augmenter): Augmenter to use.
+    """
+    def __init__(self, gen, aug):
+        self._gen = gen
+        self._aug = aug
+        self.on_epoch_end()
+
+    def __len__(self):
+        with augmented_generator(self._gen, self._aug) as g:
+            return len(g)
+
+    @property
+    def batch_size(self):
+        return self._gen.batch_size
+
+    @property
+    def shapes(self):
+        with augmented_generator(self._gen, self._aug) as g:
+            return g.shapes
+
+    @property
+    def nsamples(self):
+        with augmented_generator(self._gen, self._aug) as g:
+            return g.nsamples
+
+    @property
+    def rawshapes(self):
+        return self._gen.rawshapes
+
+    @property
+    def pixelsizes(self):
+        return self._gen.pixelsizes
+
+    def on_epoch_end(self):
+        with augmented_generator(self._gen, self._aug) as g:
+            g.on_epoch_end()
+
+    @property
+    def ordering(self):
+        return self._gen.ordering
+
+    @property
+    def n_pairs(self):
+        return self._gen.n_pairs
+
+    def get_by_index(self, i, aug=None):
+        if aug is None:
+            with augmented_generator(self._gen, self._aug) as g:
+                return g.get_by_index(i)
+        else:
+            return self._gen.get_by_index(i, aug=aug)
+
+    def parallel_get_indices(self, inds, **kwargs):
+        with augmented_generator(self._gen, self._aug) as g:
+            return g.parallel_get_indices(inds, **kwargs)
+
+    def __getitem__(self, idx):
+        with augmented_generator(self._gen, self._aug) as g:
+            return g[idx]
+
+    def plot_sample(self, *args, **kwargs):
+        with augmented_generator(self._gen, self._aug) as g:
+            return g.plot_sample(*args, **kwargs)
 
 
 class ImgLblInMem(Sequence):
diff --git a/python/baby/io.py b/python/baby/io.py
index f468bf451443099fa2f6262e51ba160b3492e773..858977ca056f084f4caa2ee32dd1ec72ec6f4571 100644
--- a/python/baby/io.py
+++ b/python/baby/io.py
@@ -1,12 +1,14 @@
 # If you publish results that make use of this software or the Birth Annotator
 # for Budding Yeast algorithm, please cite:
-# Julian M J Pietsch, Alán Muñoz, Diane Adjavon, Ivan B N Clark, Peter S
-# Swain, 2021, Birth Annotator for Budding Yeast (in preparation).
+# Pietsch, J.M.J., Muñoz, A.F., Adjavon, D.-Y.A., Farquhar, I., Clark, I.B.N.,
+# and Swain, P.S. (2023). Determining growth rates from bright-field images of
+# budding cells through identifying overlaps. eLife. 12:e79812.
+# https://doi.org/10.7554/eLife.79812
 # 
 # 
 # The MIT License (MIT)
 # 
-# Copyright (c) Julian Pietsch, Alán Muñoz and Diane Adjavon 2021
+# Copyright (c) Julian Pietsch, Alán Muñoz and Diane Adjavon 2023
 # 
 # Permission is hereby granted, free of charge, to any person obtaining a copy
 # of this software and associated documentation files (the "Software"), to
@@ -29,7 +31,7 @@ import json
 import re
 
 from pathlib import Path
-from typing import Union
+from typing import Union, Tuple
 from fnmatch import translate as glob_to_re
 from os import walk
 from itertools import groupby, chain, repeat
@@ -37,17 +39,38 @@ from collections import namedtuple, Counter
 import numpy as np
 import random
 import pandas as pd
-from PIL.PngImagePlugin import PngInfo
-from imageio import imread, imwrite
 from sklearn.model_selection import train_test_split
 
+from PIL.PngImagePlugin import PngInfo
+LEGACY_IIO = False
+try:
+    from imageio import v3 as iio
+except ImportError:
+    from imageio import imread, imwrite
+    LEGACY_IIO = True
+
 from .errors import LayoutError, UnpairedImagesError
 from .utils import PathEncoder
 
 
 def load_tiled_image(filename):
-    tImg = imread(filename)
-    info = json.loads(tImg.meta.get('Description', '{}'))
+    if LEGACY_IIO:
+        tImg = imread(filename)
+        info = tImg.meta
+    else:
+        tImg = iio.imread(filename)
+        info = iio.immeta(filename)
+        # Due to a limitation in Pillow, 16-bit (uint16) PNG images are
+        # currently loaded as int32. This will apparently change in a future
+        # version of Pillow, but for the moment, manually force conversion to
+        # uint16. It is still possible to write 16-bit PNG images, and in my
+        # tests, images of dtype int32 are saved as uint16 (as are images of
+        # dtype uint32). Other dtypes (e.g., bool and uint8) load and write as
+        # expected.
+        if str(filename).lower().endswith('.png') and tImg.dtype == 'int32':
+            tImg = tImg.astype('uint16')
+
+    info = json.loads(info.get('Description', '{}'))
     tw, th = info.get('tilesize', tImg.shape[0:2])
     nt = info.get('ntiles', 1)
     nr, nc = info.get('layout', (1, 1))
@@ -61,7 +84,7 @@ def load_tiled_image(filename):
     return img, info
 
 
-def save_tiled_image(img, filename, info={}, layout=None):
+def save_tiled_image(img, filename, info={}, layout=None, fill=0):
     if layout is not None and len(layout) != 2:
         raise LayoutError('"layout" must a 2-tuple')
 
@@ -82,7 +105,7 @@ def save_tiled_image(img, filename, info={}, layout=None):
     info['layout'] = (nr, nc)
 
     nc_final_row = np.mod(nt, nc)
-    tImg = np.zeros((tw * nr, th * nc), dtype=img.dtype)
+    tImg = np.full((tw * nr, th * nc), fill, dtype=img.dtype)
     for i in range(nr):
         i_nc = nc_final_row if i + 1 == nr and nc_final_row > 0 else nc
         for j in range(i_nc):
@@ -91,8 +114,11 @@ def save_tiled_image(img, filename, info={}, layout=None):
 
     meta = PngInfo()
     meta.add_text('Description', json.dumps(info))
-    imwrite(filename, tImg, format='png', pnginfo=meta,
-            prefer_uint8=tImg.dtype != 'uint16')
+    if LEGACY_IIO:
+        imwrite(filename, tImg, format='png', pnginfo=meta,
+                prefer_uint8=False)
+    else:
+        iio.imwrite(filename, tImg, extension='.png', pnginfo=meta)
 
 
 def load_paired_images(filenames, typeA='Brightfield', typeB='segoutlines'):
@@ -274,7 +300,7 @@ class TrainValPairs(object):
         # case insensitive manner
         re_img = re.compile(r'^(.*)' + img_suffix + r'$', re.IGNORECASE)
         re_lbl = re.compile(r'^(.*)' + lbl_suffix + r'$', re.IGNORECASE)
-        png_files = sorted(Path(base_dir).rglob('*.png'))
+        png_files = sorted(Path(base_dir).resolve().rglob('*.png'))
         matches = [(re_img.search(f.stem), re_lbl.search(f.stem), f)
                    for f in png_files]
         matches = [('img', im, f) if im else ('lbl', lm, f)
@@ -329,6 +355,14 @@ class TrainValPairs(object):
             len(self.training), len(self.validation))
 
 
+IMAGE_INFO_GROUP_BY_MAP = {
+    'experimentID': str,
+    'position': int,
+    'trap': int,
+    'tp': int
+}
+
+
 class TrainValTestPairs(object):
 
     @property
@@ -502,6 +536,42 @@ class TrainValTestPairs(object):
                  val_size=0.2,
                  test_size=0.2,
                  group_by=('experimentID', 'position', 'trap')):
+        """Search a directory for image/label pairs to add
+
+        Images are assumed to be in PNG format (i.e., they must have extension
+        .png). The images can be annotated in the 'Description' meta data slot
+        using a JSON-encoded dictionary (see :py:func:`save_tiled_image`
+        and :py:func:`load_tiled_image`).
+
+        To ensure that highly similar image/label pairs (e.g., from
+        consecutive time points) are not unfairly separated into training and
+        either the validation or test sets, pairs are allocated according to
+        their 'group', which is defined by the ``group_by`` parameter. This
+        requires the images to include meta data annotations as described
+        above (valid keys are currently limited to 'experimentID', 'position',
+        'trap' and 'tp').
+
+        Args:
+            base_dir: base directory in which to search for images. All
+                subfolders are recursively searched.
+            img_suffix (str): the suffix (before .png extension) of all image
+                files.
+            lbl_suffix (str): the suffix (before .png extension) of all label
+                files.
+            val_size (float): fraction of files to split into validation set.
+            test_size (float): fraction of files to split into test set.
+            group_by (Union[str, Tuple[str]]): perform the split according to
+                groups defined by these tokens as found in the meta data of
+                each image.
+        """
+        # Check that all group_by values are valid
+        if isinstance(group_by, str):
+            group_by = (group_by,)
+        if not all(t in IMAGE_INFO_GROUP_BY_MAP for t in group_by):
+            from .errors import BadParam  # not imported at module level
+            raise BadParam('group_by must be one of {}'.format(
+                ', '.join(f'"{k}"' for k in IMAGE_INFO_GROUP_BY_MAP.keys())))
+
         only_outlines = False
         if img_suffix is None:
             img_suffix='segoutlines'
@@ -513,7 +583,7 @@ class TrainValTestPairs(object):
         # case insensitive manner
         re_img = re.compile(r'^(.*)' + img_suffix + r'$', re.IGNORECASE)
         re_lbl = re.compile(r'^(.*)' + lbl_suffix + r'$', re.IGNORECASE)
-        png_files = sorted(Path(base_dir).rglob('*.png'))
+        png_files = sorted(Path(base_dir).resolve().rglob('*.png'))
         matches = [(re_img.search(f.stem), re_lbl.search(f.stem), f)
                    for f in png_files]
         matches = [('img', im, f) if im else ('lbl', lm, f)
@@ -541,14 +611,18 @@ class TrainValTestPairs(object):
         if len(pairs) == 0:
             return
 
-        # Choose a split that ensures separation by group keys and avoids,
-        # e.g., splitting same cell but different time points
+        # Choose a split that ensures separation by group keys and can avoid,
+        # e.g., splitting same cell but different time points.
         info = [
             json.loads(imread(l).meta.get('Description', '{}'))
             for _, l in pairs
         ]
+        # Collect grouping variables to split pairs up according to their
+        # group. NB: we force everything to be a string in case variables are
+        # loaded inconsistently from the JSON meta data. Any missing or False
+        # values are annotated with a unique label taken from enumeration:
         pair_groups = [
-            tuple(i.get(f, 'missing_' + str(e))
+            tuple(str(i.get(f) or 'missing_' + str(e))
                   for f in group_by)
             for e, i in enumerate(info)
         ]
@@ -563,11 +637,16 @@ class TrainValTestPairs(object):
             unique_groups,
             [int(npairs*train_size), int(npairs * (1-test_size))])
 
-        reformat = lambda exp, pos, trap : (exp, int(pos), int(trap))
+        # BELOW CODE REMOVED SINCE IT DOES NOT ACHIEVE WHAT I THINK IS THE
+        # INTENDED PURPOSE OF ORDERING THE OUTPUT
+        # reformat = lambda exp, pos, trap : (exp, int(pos), int(trap))
+        # train_groups = set([reformat(*t) for t in train_groups])
+        # val_groups = set([reformat(*t) for t in val_groups])
+        # test_groups = set([reformat(*t) for t in test_groups])
 
-        train_groups = set([reformat(*t) for t in train_groups])
-        val_groups = set([reformat(*t) for t in val_groups])
-        test_groups = set([reformat(*t) for t in test_groups])
+        train_groups = {tuple(x) for x in train_groups}
+        val_groups = {tuple(x) for x in val_groups}
+        test_groups = {tuple(x) for x in test_groups}
 
         # Add new pairs to the existing train-val split
         self.training += [p for p, g in zip(pairs, pair_groups) if
diff --git a/python/baby/layers.py b/python/baby/layers.py
index cbf71a0f36fbbcb6afc4aa5f7491f63d367dffea..3d9c84de6074dadee9c816b41610915ede1040ff 100644
--- a/python/baby/layers.py
+++ b/python/baby/layers.py
@@ -1,12 +1,14 @@
 # If you publish results that make use of this software or the Birth Annotator
 # for Budding Yeast algorithm, please cite:
-# Julian M J Pietsch, Alán Muñoz, Diane Adjavon, Ivan B N Clark, Peter S
-# Swain, 2021, Birth Annotator for Budding Yeast (in preparation).
+# Pietsch, J.M.J., Muñoz, A.F., Adjavon, D.-Y.A., Farquhar, I., Clark, I.B.N.,
+# and Swain, P.S. (2023). Determining growth rates from bright-field images of
+# budding cells through identifying overlaps. eLife. 12:e79812.
+# https://doi.org/10.7554/eLife.79812
 # 
 # 
 # The MIT License (MIT)
 # 
-# Copyright (c) Julian Pietsch, Alán Muñoz and Diane Adjavon 2021
+# Copyright (c) Julian Pietsch, Alán Muñoz and Diane Adjavon 2023
 # 
 # Permission is hereby granted, free of charge, to any person obtaining a copy
 # of this software and associated documentation files (the "Software"), to
@@ -25,97 +27,207 @@
 # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
 # FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS
 # IN THE SOFTWARE.
-from tensorflow.python.keras.layers import (
+import numpy as np
+from itertools import repeat
+from string import ascii_lowercase as alphabet
+from tensorflow.keras.layers import (
     Conv2D, BatchNormalization, Activation, MaxPooling2D,
-    Conv2DTranspose, concatenate, Dropout
+    Conv2DTranspose, concatenate, Dropout, LayerNormalization,
+    add, DepthwiseConv2D
 )
 
 
 ### U-NET layers ###
 
 
-def conv_block(input_tensor, num_filters, stage, batchnorm=True, dropout=0.):
-    postfix = '_{}a'.format(stage + 1)
-    encoder = Conv2D(num_filters, (3, 3), padding='same',
-                     name='enc_conv' + postfix)(input_tensor)
+def res_block(input_tensor, nfilters, kernel=3, expand_ratio=4, drop=0.,
+        activation='swish', prefix='', suffix='', use_bias=False,
+        init='glorot_uniform', pre_activate=False, post_activate=False,
+        convnext=False, **kwargs):
+    """A flexible residual block
+
+    Primarily inspired from EfficientNet block, but without the squeeze and
+    excite block since that is unlikely to be important for the highly
+    application-specific models we typically train.
+
+    If expand_ratio < 1 and post_activate = True, it is a special case of the
+    ResNeXt block where the number of groups equals the number of features.
+
+    If convnext=True, then expansion occurs after depth-wise convolution,
+    batch normalisation is omitted in favour of a single layer normalisation
+    after depth-wise convolution and there is only a single activation in the
+    expanded space.
+    """
+
+    tmplt = '_'.join([e for e in (prefix, '{}', suffix) if len(e) > 0])
+    lbl = lambda l: tmplt.format(l)
+
+    x = input_tensor
+
+    if not convnext:
+        if pre_activate:
+            x = BatchNormalization(name=lbl('pre_bn'))(x)
+            x = Activation(activation, name=lbl('pre_act'))(x)
+
+        # Expand from bottleneck by a factor of expand_ratio
+        x = Conv2D(x.shape[-1] * expand_ratio, 1, padding='same',
+                use_bias=use_bias,
+                kernel_initializer=init, name=lbl('expand'))(x)
+        x = BatchNormalization(name=lbl('expand_bn'))(x)
+        x = Activation(activation, name=lbl('expand_act'))(x)
+
+    # Mix in spatial dimension
+    x = DepthwiseConv2D(kernel, padding='same', use_bias=use_bias,
+            depthwise_initializer=init, name=lbl('mixXY'))(x)
+    if convnext:
+        x = LayerNormalization(epsilon=1e-6, name=lbl('mixXY_ln'))(x)
+        # Expand from bottleneck
+        x = Conv2D(x.shape[-1] * expand_ratio, 1, padding='same',
+                use_bias=use_bias, kernel_initializer=init,
+                name=lbl('expand'))(x)
+        x = Activation(activation, name=lbl('expand_act'))(x)
+    else:
+        # Do standard norm/activation
+        x = BatchNormalization(name=lbl('mixXY_bn'))(x)
+        x = Activation(activation, name=lbl('mixXY_act'))(x)
+
+    # Mix across features
+    x = Conv2D(nfilters, 1, padding='same', use_bias=use_bias,
+            kernel_initializer=init, name=lbl('proj'))(x)
+    if not pre_activate and not convnext:
+        x = BatchNormalization(name=lbl('proj_bn'))(x)
+
+    # Stochastically drop this entire block
+    if drop > 0:
+        x = Dropout(drop, noise_shape=(None, 1, 1, 1), name=lbl('drop'))(x)
+
+    # Sum at bottleneck (though, will not be bottleneck if expand_ratio > 1)
+    if input_tensor.shape[-1] != nfilters:
+        input_tensor = Conv2D(nfilters, 1, padding='same', use_bias=use_bias,
+                kernel_initializer=init, name=lbl('input_proj'))(input_tensor)
+    x = add([input_tensor, x], name=lbl('add'))
+
+    # Post activation to match ResNeXt models
+    if post_activate and not pre_activate and not convnext:
+        x = Activation(activation, name=lbl('add_act'))(x)
+
+    return x
+
+
+def conv_block(input_tensor, nfilters, kernel=3, prefix='', suffix='',
+        batchnorm=True, dropout=0., activation='relu', init='glorot_uniform',
+        **kwargs):
+    """Standard convolution with batch norm and activation
+    """
+
+    tmplt = '_'.join([e for e in (prefix, '{}', suffix) if len(e) > 0])
+    lbl = lambda l: tmplt.format(l)
+
+    x = input_tensor
+    x = Conv2D(nfilters, kernel, padding='same', kernel_initializer=init,
+            name=lbl('conv'))(x)
     if batchnorm:
-        encoder = BatchNormalization(name='enc_bn' + postfix)(encoder)
-    encoder = Activation('relu', name='enc_act' + postfix)(encoder)
+        x = BatchNormalization(name=lbl('conv_bn'))(x)
+    x = Activation(activation, name=lbl('conv_act'))(x)
     if dropout > 0:
-        encoder = Dropout(dropout, name='enc_dropout' + postfix)(encoder)
+        x = Dropout(dropout, name=lbl('conv_dropout'))(x)
+
+    return x
 
-    postfix = '_{}b'.format(stage + 1)
-    encoder = Conv2D(num_filters, (3, 3), padding='same',
-                     name='enc_conv' + postfix)(encoder)
-    if batchnorm:
-        encoder = BatchNormalization(name='enc_bn' + postfix)(encoder)
-    encoder = Activation('relu', name='enc_act' + postfix)(encoder)
 
-    return encoder
+def encoder_block(input_tensor, nfilters, stage, repeats=2, block=conv_block,
+        dropout=0., drop=repeat(0.), batchnorm=True, conv_pool=False,
+        init='glorot_uniform', **kwargs):
 
+    x = input_tensor
+    for i in range(repeats - 1):
+        x = block(x, nfilters, dropout=dropout, batchnorm=batchnorm,
+                prefix='enc', init=init, suffix=f'{stage + 1}{alphabet[i]}',
+                **kwargs)
+    encoder = block(x, nfilters, batchnorm=batchnorm, init=init, prefix='enc',
+            suffix=f'{stage + 1}{alphabet[repeats - 1]}', **kwargs)
+    if conv_pool:
+        encoder_pool = DepthwiseConv2D(2, strides=2, padding='same',
+                depthwise_initializer=init, name=f'down_{stage + 1}')(encoder)
+        if batchnorm:
+            encoder_pool = BatchNormalization(
+                    name=f'down_bn_{stage + 1}')(encoder_pool)
+    else:
+        encoder_pool = MaxPooling2D(
+            (2, 2), strides=(2, 2), name=f'down_{stage + 1}')(encoder)
 
-def encoder_block(input_tensor, num_filters, stage, batchnorm=True,
-                  dropout=0.):
-    encoder = conv_block(input_tensor, num_filters, stage,
-                         batchnorm=batchnorm,
-                         dropout=dropout)
-    encoder_pool = MaxPooling2D(
-        (2, 2), strides=(2, 2), name='down_{}'.format(stage + 1))(encoder)
     return encoder_pool, encoder
 
 
-def decoder_block(input_tensor, concat_tensor, num_filters, stage,
-                  batchnorm=True, prename='', dropout=0.):
-    postfix = '_{}'.format(stage + 1)
-    decoder = Conv2DTranspose(num_filters, (2, 2), strides=(2, 2), padding='same',
-                              name=prename + 'up' + postfix)(input_tensor)
-    decoder = concatenate([concat_tensor, decoder], axis=-1,
-                          name=prename + 'skip' + postfix)
-    if batchnorm:
-        decoder = BatchNormalization(name=prename + 'up_bn' + postfix)(decoder)
-    decoder = Activation('relu', name=prename + 'up_act' + postfix)(decoder)
-    if dropout > 0:
-        decoder = Dropout(dropout, name='dec_dropout' + postfix)(decoder)
+def decoder_block(input_tensor, skip_tensor, nfilters, stage, repeats=2,
+        block=conv_block, dropout=0., drop=repeat(0.), batchnorm=True,
+        prename='', init='glorot_uniform', residual_skip=False,
+        up_activate=True, activation='relu', **kwargs):
 
-    postfix = '_{}a'.format(stage + 1)
-    decoder = Conv2D(num_filters, (3, 3), padding='same',
-                     name=prename + 'dec_conv' + postfix)(decoder)
+    x = input_tensor
+    x = Conv2DTranspose(nfilters, 2, strides=2, padding='same',
+            kernel_initializer=init, name=prename + f'up_{stage + 1}')(x)
     if batchnorm:
-        decoder = BatchNormalization(name=prename + 'dec_bn' + postfix)(decoder)
-    decoder = Activation('relu', name=prename + 'dec_act' + postfix)(decoder)
-    if dropout > 0:
-        decoder = Dropout(dropout, name='dec_dropout' + postfix)(decoder)
+        x = BatchNormalization(name=prename + f'up_bn_{stage + 1}')(x)
+    if up_activate:
+        x = Activation(activation, name=prename + f'up_act_{stage + 1}')(x)
+    if residual_skip:
+        x = add([skip_tensor, x], name=prename + f'up_skip_{stage + 1}')
+    else:
+        x = concatenate([skip_tensor, x], axis=-1,
+                name=prename + f'up_skip_{stage + 1}')
+
+    drop = iter(drop)
+    for i in range(repeats - 1):
+        x = block(x, nfilters, dropout=dropout, drop=next(drop),
+                batchnorm=batchnorm, init=init, prefix='dec',
+                suffix=f'{stage + 1}{alphabet[i]}', **kwargs)
+    decoder = block(x, nfilters, drop=next(drop), batchnorm=batchnorm,
+            init=init, prefix='dec',
+            suffix=f'{stage + 1}{alphabet[repeats - 1]}', **kwargs)
 
-    postfix = '_{}b'.format(stage + 1)
-    decoder = Conv2D(num_filters, (3, 3), padding='same',
-                     name=prename + 'dec_conv' + postfix)(decoder)
-    if batchnorm:
-        decoder = BatchNormalization(name=prename + 'dec_bn' + postfix)(decoder)
-    decoder = Activation('relu', name=prename + 'dec_act' + postfix)(decoder)
-    if dropout > 0:
-        decoder = Dropout(dropout, name='dec_dropout' + postfix)(decoder)
     return decoder
 
 
-def unet_block(input_tensor, layer_sizes, batchnorm=True, dropout=0.):
+def unet_block(input_tensor, layer_sizes, enc_repeats=2, dec_repeats=2,
+        drop=0., dropout=0., block=conv_block, stem=False, **kwargs): 
+
+    nlayers = len(layer_sizes)
+
+    # Rate of stochastic block dropping increases linearly with depth
+    drop = iter(np.linspace(0, drop, nlayers * enc_repeats +  (nlayers - 1) *
+        dec_repeats))
+
+    if stem:
+        x = input_tensor
+        x = Conv2D(layer_sizes[0], 3, padding='same', name='stem')(x)
+        if not kwargs.get('pre_activate', False):
+            x = BatchNormalization(name='stem_bn')(x)
+            x = Activation(kwargs.get('activation', 'relu'), name='stem_act')(x)
+        input_tensor = x
+
     # Encoding
     upper_layer = input_tensor
     encoding_layers = []
-    for i, num_filters in enumerate(layer_sizes[:-1]):
-        upper_layer, encoder = encoder_block(upper_layer, num_filters, i,
-                                             batchnorm=batchnorm,
-                                             dropout=dropout)
+    for i, nfilters in enumerate(layer_sizes[:-1]):
+        upper_layer, encoder = encoder_block(upper_layer, nfilters, i,
+                repeats=enc_repeats, dropout=dropout, drop=drop, block=block,
+                **kwargs)
         encoding_layers.append(encoder)
 
     # Centre
-    lower_layer = conv_block(upper_layer, layer_sizes[-1], len(layer_sizes) - 1,
-                             batchnorm=batchnorm, dropout=dropout)
+    x = upper_layer
+    for i in range(enc_repeats - 1):
+        x = block(x, layer_sizes[-1], dropout=dropout, drop=next(drop),
+                prefix='enc', suffix=f'{nlayers}{alphabet[i]}', **kwargs)
+    lower_layer = block(x, layer_sizes[-1], drop=next(drop), prefix='enc',
+            suffix=f'{nlayers}{alphabet[enc_repeats - 1]}', **kwargs)
 
     # Decoding
-    for i, num_filters in reversed(list(enumerate(layer_sizes[:-1]))):
-        lower_layer = decoder_block(lower_layer, encoding_layers[i],
-                                    num_filters, i, batchnorm=batchnorm,
-                                    dropout=dropout)
+    for i, nfilters in reversed(list(enumerate(layer_sizes[:-1]))):
+        lower_layer = decoder_block(lower_layer, encoding_layers[i], nfilters,
+                i, repeats=dec_repeats, dropout=dropout, drop=drop,
+                block=block, **kwargs)
 
     return lower_layer
 
diff --git a/python/baby/losses.py b/python/baby/losses.py
index 18231b088fe1409ea046c411b75419a508c984af..cc8079e917e94862454f712bb9f884b9ccf5d11c 100644
--- a/python/baby/losses.py
+++ b/python/baby/losses.py
@@ -1,12 +1,14 @@
 # If you publish results that make use of this software or the Birth Annotator
 # for Budding Yeast algorithm, please cite:
-# Julian M J Pietsch, Alán Muñoz, Diane Adjavon, Ivan B N Clark, Peter S
-# Swain, 2021, Birth Annotator for Budding Yeast (in preparation).
+# Pietsch, J.M.J., Muñoz, A.F., Adjavon, D.-Y.A., Farquhar, I., Clark, I.B.N.,
+# and Swain, P.S. (2023). Determining growth rates from bright-field images of
+# budding cells through identifying overlaps. eLife. 12:e79812.
+# https://doi.org/10.7554/eLife.79812
 # 
 # 
 # The MIT License (MIT)
 # 
-# Copyright (c) Julian Pietsch, Alán Muñoz and Diane Adjavon 2021
+# Copyright (c) Julian Pietsch, Alán Muñoz and Diane Adjavon 2023
 # 
 # Permission is hereby granted, free of charge, to any person obtaining a copy
 # of this software and associated documentation files (the "Software"), to
@@ -26,7 +28,7 @@
 # FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS
 # IN THE SOFTWARE.
 import tensorflow as tf
-from tensorflow.python.keras.losses import binary_crossentropy
+from tensorflow.keras.losses import binary_crossentropy
 
 
 ### LOSS FUNCTIONS ###
@@ -37,9 +39,10 @@ def dice_coeff(y_true, y_pred):
     # Flatten
     y_true_f = tf.reshape(y_true, [-1])
     y_pred_f = tf.reshape(y_pred, [-1])
-    intersection = tf.reduce_sum(y_true_f * y_pred_f)
+    intersection = tf.reduce_sum(tf.boolean_mask(y_pred_f, y_true_f))
     score = ((2. * intersection + smooth) /
-             (tf.reduce_sum(y_true_f) + tf.reduce_sum(y_pred_f) + smooth))
+             (tf.math.count_nonzero(y_true_f, dtype=tf.dtypes.float32)
+              + tf.reduce_sum(y_pred_f) + smooth))
     return score
 
 
diff --git a/python/baby/models.py b/python/baby/models.py
index ebf47ccd21e9a940ade15deb319bf1daa90c6365..9d3d87705ad81625b825cce063b3216796e2641b 100644
--- a/python/baby/models.py
+++ b/python/baby/models.py
@@ -1,12 +1,14 @@
 # If you publish results that make use of this software or the Birth Annotator
 # for Budding Yeast algorithm, please cite:
-# Julian M J Pietsch, Alán Muñoz, Diane Adjavon, Ivan B N Clark, Peter S
-# Swain, 2021, Birth Annotator for Budding Yeast (in preparation).
+# Pietsch, J.M.J., Muñoz, A.F., Adjavon, D.-Y.A., Farquhar, I., Clark, I.B.N.,
+# and Swain, P.S. (2023). Determining growth rates from bright-field images of
+# budding cells through identifying overlaps. eLife. 12:e79812.
+# https://doi.org/10.7554/eLife.79812
 # 
 # 
 # The MIT License (MIT)
 # 
-# Copyright (c) Julian Pietsch, Alán Muñoz and Diane Adjavon 2021
+# Copyright (c) Julian Pietsch, Alán Muñoz and Diane Adjavon 2023
 # 
 # Permission is hereby granted, free of charge, to any person obtaining a copy
 # of this software and associated documentation files (the "Software"), to
@@ -25,12 +27,13 @@
 # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
 # FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS
 # IN THE SOFTWARE.
-from tensorflow.python.keras.models import Model
-from tensorflow.python.keras.optimizers import Adam
-from tensorflow.python.keras.layers import Input
+from tensorflow.keras.models import Model
+from tensorflow.keras.optimizers import Adam
+from tensorflow.keras.layers import Input
+from tensorflow.keras.initializers import VarianceScaling
 
 from .utils import named_obj
-from .layers import msd_block, unet_block, make_outputs
+from .layers import msd_block, unet_block, conv_block, res_block, make_outputs
 from .losses import bce_dice_loss, dice_coeff
 
 
@@ -39,13 +42,18 @@ def named_model_fn(name):
     def wrap(f):
 
         @named_obj(name)
-        def model_fn(generator, flattener, weights={}, **kwargs):
+        def model_fn(generator, flattener, weights={}, use_adamw=False, **kwargs):
             weights = {n: weights.get(n, 1) for n in flattener.names()}
             inputs = Input(shape=generator.shapes.input[1:])
             model = Model(inputs=[inputs],
                           outputs=make_outputs(f(inputs, **kwargs),
                                                flattener.names()))
-            model.compile(optimizer=Adam(amsgrad=False),
+            if use_adamw:
+                from tensorflow_addons.optimizers import AdamW
+                optimizer = AdamW(weight_decay=0.00025)
+            else:
+                optimizer = Adam(amsgrad=False)
+            model.compile(optimizer=optimizer,
                           metrics=[dice_coeff],
                           loss=bce_dice_loss,
                           loss_weights=weights)
@@ -75,6 +83,46 @@ def unet(inputs, depth=4, layer_size=8, batchnorm=True, dropout=0.):
                       dropout=dropout, batchnorm=batchnorm)
 
 
+@named_model_fn('unet3')
+def unet3(inputs, depth=3, layer_size=8, batchnorm=True, dropout=0.):
+    layer_sizes = [layer_size*(2**i) for i in range(depth)]
+    return unet_block(inputs, layer_sizes,
+                      dropout=dropout, batchnorm=batchnorm)
+
+
+@named_model_fn('unet_hyper')
+def unet_hyper(inputs, width=64, depth=4, initializer='glorot_uniform',
+        block_type='conv', **kwargs):
+    layer_sizes = [width*(2**i) for i in range(depth)]
+    if initializer == 'variance_scaling':
+        initializer = VarianceScaling(2., mode='fan_out')
+    block_args = {
+            'conv': dict(block=conv_block, stem=False),
+            'effnet': dict(block=res_block, stem=True),
+            'effnet-preact': dict(block=res_block, pre_activate=True, stem=True),
+            'convnext': dict(block=res_block, stem=True, convnext=True)
+            }[block_type]
+    block_args.update(kwargs)
+    return unet_block(inputs, layer_sizes, init=initializer, **block_args)
+
+
+@named_obj('unet_convnext')
+def unet_convnext(generator, flattener, weights={}, width=8, depth=4,
+        kernel=7, enc_repeats=3, dec_repeats=2, expand_ratio=4,
+        activation='swish', initializer='variance_scaling', **kwargs):
+    """U-net model with ConvNeXt blocks and defaults
+
+    With default parameters, produces a model of similar size to the default
+    unet model here (depth 4, width 8).
+    """
+    return unet_hyper(generator, flattener, weights=weights, width=width,
+            depth=depth, kernel=kernel, enc_repeats=enc_repeats,
+            dec_repeats=dec_repeats, block_type='convnext',
+            expand_ratio=expand_ratio, activation=activation,
+            initializer=initializer, use_adamw=True, residual_skip=True,
+            conv_pool=True, up_activate=False, **kwargs)
+
+
 @named_model_fn('msd')
 def msd(inputs, depth=80, width=1, n_dilations=4, dilation=1, batchnorm=True):
     dilations = [dilation * (2 ** i) for i in range(n_dilations)]
diff --git a/python/baby/models/I1_evolve_unet_4s_20200630.hdf5 b/python/baby/models/I1_evolve_unet_4s_20200630.hdf5
deleted file mode 100644
index cf4b862e5aa88bf1534edebf7fdc06eda316727f..0000000000000000000000000000000000000000
Binary files a/python/baby/models/I1_evolve_unet_4s_20200630.hdf5 and /dev/null differ
diff --git a/python/baby/models/I1_evolve_unet_4s_20210302.hdf5 b/python/baby/models/I1_evolve_unet_4s_20210302.hdf5
deleted file mode 100644
index 6f5d000647bd1e235672b6fea4105bcce9289ee4..0000000000000000000000000000000000000000
Binary files a/python/baby/models/I1_evolve_unet_4s_20210302.hdf5 and /dev/null differ
diff --git a/python/baby/models/I1_prime_unet_4s_20200630.hdf5 b/python/baby/models/I1_prime_unet_4s_20200630.hdf5
deleted file mode 100644
index f5f890a31e82e211dea7b4ae890e769f707c48dc..0000000000000000000000000000000000000000
Binary files a/python/baby/models/I1_prime_unet_4s_20200630.hdf5 and /dev/null differ
diff --git a/python/baby/models/I1_prime_unet_4s_20210302.hdf5 b/python/baby/models/I1_prime_unet_4s_20210302.hdf5
deleted file mode 100644
index 5093dba134babbcb94c35df94f022305dbcbd24c..0000000000000000000000000000000000000000
Binary files a/python/baby/models/I1_prime_unet_4s_20210302.hdf5 and /dev/null differ
diff --git a/python/baby/models/I3_evolve_unet_4s_20200903.hdf5 b/python/baby/models/I3_evolve_unet_4s_20200903.hdf5
deleted file mode 100644
index 785ba76d57a1ed9272f0d0a3a81dce63b0cd081e..0000000000000000000000000000000000000000
Binary files a/python/baby/models/I3_evolve_unet_4s_20200903.hdf5 and /dev/null differ
diff --git a/python/baby/models/I3_evolve_unet_4s_20210215.hdf5 b/python/baby/models/I3_evolve_unet_4s_20210215.hdf5
deleted file mode 100644
index fede45d6909cb21be94c2940cf28803eeb7f7d85..0000000000000000000000000000000000000000
Binary files a/python/baby/models/I3_evolve_unet_4s_20210215.hdf5 and /dev/null differ
diff --git a/python/baby/models/I3_prime_unet_4s_20200903.hdf5 b/python/baby/models/I3_prime_unet_4s_20200903.hdf5
deleted file mode 100644
index f4a1dba19ca8be6bf8351ae593226ac65bdf7d19..0000000000000000000000000000000000000000
Binary files a/python/baby/models/I3_prime_unet_4s_20200903.hdf5 and /dev/null differ
diff --git a/python/baby/models/I3_prime_unet_4s_20210215.hdf5 b/python/baby/models/I3_prime_unet_4s_20210215.hdf5
deleted file mode 100644
index be7e266f30a0f03ba0c502655fc58ab809d483d4..0000000000000000000000000000000000000000
Binary files a/python/baby/models/I3_prime_unet_4s_20210215.hdf5 and /dev/null differ
diff --git a/python/baby/models/I5_evolve_unet_4s_20210120.hdf5 b/python/baby/models/I5_evolve_unet_4s_20210120.hdf5
deleted file mode 100644
index f7b133276b73e3c30fd2b3bab81309e0e694addd..0000000000000000000000000000000000000000
Binary files a/python/baby/models/I5_evolve_unet_4s_20210120.hdf5 and /dev/null differ
diff --git a/python/baby/models/I5_prime_unet_4s_20210120.hdf5 b/python/baby/models/I5_prime_unet_4s_20210120.hdf5
deleted file mode 100644
index 89afdb85d621332913a9fd1816a7d03c00a24277..0000000000000000000000000000000000000000
Binary files a/python/baby/models/I5_prime_unet_4s_20210120.hdf5 and /dev/null differ
diff --git a/python/baby/models/ct_rf_20210125_9.pkl b/python/baby/models/ct_rf_20210125_9.pkl
deleted file mode 100644
index a86a15b55c53f4a92a71e91c732ef32355ab6164..0000000000000000000000000000000000000000
Binary files a/python/baby/models/ct_rf_20210125_9.pkl and /dev/null differ
diff --git a/python/baby/models/ct_rf_20210201_12.pkl b/python/baby/models/ct_rf_20210201_12.pkl
deleted file mode 100644
index aff4247d0a4973186315b36f19ac6f1caf4c2057..0000000000000000000000000000000000000000
Binary files a/python/baby/models/ct_rf_20210201_12.pkl and /dev/null differ
diff --git a/python/baby/models/ct_rf_20210201_4.pkl b/python/baby/models/ct_rf_20210201_4.pkl
deleted file mode 100644
index cc5cd715a297ed113be2d8503c96e9d4d91a67f1..0000000000000000000000000000000000000000
Binary files a/python/baby/models/ct_rf_20210201_4.pkl and /dev/null differ
diff --git a/python/baby/models/flattener_60x_1z_evolve_20210302.json b/python/baby/models/flattener_60x_1z_evolve_20210302.json
deleted file mode 100644
index b1333bbc17f6b11554750809dda4bf8a9cf0fe5c..0000000000000000000000000000000000000000
--- a/python/baby/models/flattener_60x_1z_evolve_20210302.json
+++ /dev/null
@@ -1 +0,0 @@
-{"groupdef": {"small": {"_python_NamedTuple": {"lower": 1, "upper": 134, "budonly": false, "focus": null}, "__module__": "baby.preprocessing", "__class__": "CellGroup"}, "medium": {"_python_NamedTuple": {"lower": 101, "upper": 177, "budonly": false, "focus": null}, "__module__": "baby.preprocessing", "__class__": "CellGroup"}, "large": {"_python_NamedTuple": {"lower": 143, "upper": Infinity, "budonly": false, "focus": null}, "__module__": "baby.preprocessing", "__class__": "CellGroup"}, "buds": {"_python_NamedTuple": {"lower": 1, "upper": 200, "budonly": true, "focus": null}, "__module__": "baby.preprocessing", "__class__": "CellGroup"}}, "groupprops": {"small": {"_python_set": ["filled", "edge"]}, "medium": {"_python_set": ["edge", "overlap", "interior"]}, "large": {"_python_set": ["edge", "overlap", "interior"]}, "buds": {"_python_set": ["budneck"]}}, "targets": [{"_python_NamedTuple": {"name": "lge_inte", "group": "large", "prop": "interior", "nerode": 4, "ndilate": 0, "ndilate_overlaps": 0, "ndilate_mother": 2}, "__module__": "baby.preprocessing", "__class__": "PredTarget"}, {"_python_NamedTuple": {"name": "lge_edge", "group": "large", "prop": "edge", "nerode": 0, "ndilate": 0, "ndilate_overlaps": 0, "ndilate_mother": 2}, "__module__": "baby.preprocessing", "__class__": "PredTarget"}, {"_python_NamedTuple": {"name": "mid_inte", "group": "medium", "prop": "interior", "nerode": 3, "ndilate": 0, "ndilate_overlaps": 0, "ndilate_mother": 2}, "__module__": "baby.preprocessing", "__class__": "PredTarget"}, {"_python_NamedTuple": {"name": "mid_edge", "group": "medium", "prop": "edge", "nerode": 0, "ndilate": 0, "ndilate_overlaps": 0, "ndilate_mother": 2}, "__module__": "baby.preprocessing", "__class__": "PredTarget"}, {"_python_NamedTuple": {"name": "sml_inte", "group": "small", "prop": "filled", "nerode": 0, "ndilate": 0, "ndilate_overlaps": 0, "ndilate_mother": 2}, "__module__": "baby.preprocessing", "__class__": "PredTarget"}, {"_python_NamedTuple": {"name": 
"sml_edge", "group": "small", "prop": "edge", "nerode": 0, "ndilate": 0, "ndilate_overlaps": 0, "ndilate_mother": 2}, "__module__": "baby.preprocessing", "__class__": "PredTarget"}, {"_python_NamedTuple": {"name": "bud_neck", "group": "buds", "prop": "budneck", "nerode": 0, "ndilate": 0, "ndilate_overlaps": 0, "ndilate_mother": 2}, "__module__": "baby.preprocessing", "__class__": "PredTarget"}]}
\ No newline at end of file
diff --git a/python/baby/models/flattener_60x_1z_prime_20210302.json b/python/baby/models/flattener_60x_1z_prime_20210302.json
deleted file mode 100644
index 4bd3545c288f66539370e1db9da2bd69a61291bf..0000000000000000000000000000000000000000
--- a/python/baby/models/flattener_60x_1z_prime_20210302.json
+++ /dev/null
@@ -1 +0,0 @@
-{"groupdef": {"small": {"_python_NamedTuple": {"lower": 1, "upper": 258, "budonly": false, "focus": null}, "__module__": "baby.preprocessing", "__class__": "CellGroup"}, "medium": {"_python_NamedTuple": {"lower": 192, "upper": 352, "budonly": false, "focus": null}, "__module__": "baby.preprocessing", "__class__": "CellGroup"}, "large": {"_python_NamedTuple": {"lower": 286, "upper": Infinity, "budonly": false, "focus": null}, "__module__": "baby.preprocessing", "__class__": "CellGroup"}, "buds": {"_python_NamedTuple": {"lower": 1, "upper": 300, "budonly": true, "focus": null}, "__module__": "baby.preprocessing", "__class__": "CellGroup"}}, "groupprops": {"small": {"_python_set": ["filled", "edge"]}, "medium": {"_python_set": ["overlap", "edge", "interior"]}, "large": {"_python_set": ["overlap", "edge", "interior"]}, "buds": {"_python_set": ["budneck"]}}, "targets": [{"_python_NamedTuple": {"name": "lge_inte", "group": "large", "prop": "interior", "nerode": 5, "ndilate": 0, "ndilate_overlaps": 0, "ndilate_mother": 2}, "__module__": "baby.preprocessing", "__class__": "PredTarget"}, {"_python_NamedTuple": {"name": "lge_edge", "group": "large", "prop": "edge", "nerode": 0, "ndilate": 0, "ndilate_overlaps": 0, "ndilate_mother": 2}, "__module__": "baby.preprocessing", "__class__": "PredTarget"}, {"_python_NamedTuple": {"name": "mid_inte", "group": "medium", "prop": "interior", "nerode": 5, "ndilate": 0, "ndilate_overlaps": 0, "ndilate_mother": 2}, "__module__": "baby.preprocessing", "__class__": "PredTarget"}, {"_python_NamedTuple": {"name": "mid_edge", "group": "medium", "prop": "edge", "nerode": 0, "ndilate": 0, "ndilate_overlaps": 0, "ndilate_mother": 2}, "__module__": "baby.preprocessing", "__class__": "PredTarget"}, {"_python_NamedTuple": {"name": "sml_inte", "group": "small", "prop": "filled", "nerode": 0, "ndilate": 0, "ndilate_overlaps": 0, "ndilate_mother": 2}, "__module__": "baby.preprocessing", "__class__": "PredTarget"}, {"_python_NamedTuple": {"name": 
"sml_edge", "group": "small", "prop": "edge", "nerode": 0, "ndilate": 0, "ndilate_overlaps": 0, "ndilate_mother": 2}, "__module__": "baby.preprocessing", "__class__": "PredTarget"}, {"_python_NamedTuple": {"name": "bud_neck", "group": "buds", "prop": "budneck", "nerode": 0, "ndilate": 0, "ndilate_overlaps": 0, "ndilate_mother": 2}, "__module__": "baby.preprocessing", "__class__": "PredTarget"}]}
\ No newline at end of file
diff --git a/python/baby/models/flattener_60x_3z_evolve_20210215.json b/python/baby/models/flattener_60x_3z_evolve_20210215.json
deleted file mode 100644
index 1f75baa878fad6c46cdfc539b888972fa9fdcc82..0000000000000000000000000000000000000000
--- a/python/baby/models/flattener_60x_3z_evolve_20210215.json
+++ /dev/null
@@ -1 +0,0 @@
-{"groupdef": {"small": {"_python_NamedTuple": {"lower": 1, "upper": 133, "budonly": false, "focus": null}, "__module__": "baby.preprocessing", "__class__": "CellGroup"}, "medium": {"_python_NamedTuple": {"lower": 100, "upper": 175, "budonly": false, "focus": null}, "__module__": "baby.preprocessing", "__class__": "CellGroup"}, "large": {"_python_NamedTuple": {"lower": 142, "upper": Infinity, "budonly": false, "focus": null}, "__module__": "baby.preprocessing", "__class__": "CellGroup"}, "buds": {"_python_NamedTuple": {"lower": 1, "upper": 200, "budonly": true, "focus": null}, "__module__": "baby.preprocessing", "__class__": "CellGroup"}}, "groupprops": {"small": {"_python_set": ["filled", "edge"]}, "medium": {"_python_set": ["overlap", "edge", "interior"]}, "large": {"_python_set": ["overlap", "edge", "interior"]}, "buds": {"_python_set": ["budneck"]}}, "targets": [{"_python_NamedTuple": {"name": "lge_inte", "group": "large", "prop": "interior", "nerode": 4, "ndilate": 0, "ndilate_overlaps": 0, "ndilate_mother": 2}, "__module__": "baby.preprocessing", "__class__": "PredTarget"}, {"_python_NamedTuple": {"name": "lge_edge", "group": "large", "prop": "edge", "nerode": 0, "ndilate": 0, "ndilate_overlaps": 0, "ndilate_mother": 2}, "__module__": "baby.preprocessing", "__class__": "PredTarget"}, {"_python_NamedTuple": {"name": "mid_inte", "group": "medium", "prop": "interior", "nerode": 3, "ndilate": 0, "ndilate_overlaps": 0, "ndilate_mother": 2}, "__module__": "baby.preprocessing", "__class__": "PredTarget"}, {"_python_NamedTuple": {"name": "mid_edge", "group": "medium", "prop": "edge", "nerode": 0, "ndilate": 0, "ndilate_overlaps": 0, "ndilate_mother": 2}, "__module__": "baby.preprocessing", "__class__": "PredTarget"}, {"_python_NamedTuple": {"name": "sml_inte", "group": "small", "prop": "filled", "nerode": 0, "ndilate": 0, "ndilate_overlaps": 0, "ndilate_mother": 2}, "__module__": "baby.preprocessing", "__class__": "PredTarget"}, {"_python_NamedTuple": {"name": 
"sml_edge", "group": "small", "prop": "edge", "nerode": 0, "ndilate": 0, "ndilate_overlaps": 0, "ndilate_mother": 2}, "__module__": "baby.preprocessing", "__class__": "PredTarget"}, {"_python_NamedTuple": {"name": "bud_neck", "group": "buds", "prop": "budneck", "nerode": 0, "ndilate": 0, "ndilate_overlaps": 0, "ndilate_mother": 2}, "__module__": "baby.preprocessing", "__class__": "PredTarget"}]}
\ No newline at end of file
diff --git a/python/baby/models/flattener_60x_3z_prime_20210215.json b/python/baby/models/flattener_60x_3z_prime_20210215.json
deleted file mode 100644
index a35852ef72d755894d4397f0337373f8bf2b567b..0000000000000000000000000000000000000000
--- a/python/baby/models/flattener_60x_3z_prime_20210215.json
+++ /dev/null
@@ -1 +0,0 @@
-{"groupdef": {"small": {"_python_NamedTuple": {"lower": 1, "upper": 248, "budonly": false, "focus": null}, "__module__": "baby.preprocessing", "__class__": "CellGroup"}, "medium": {"_python_NamedTuple": {"lower": 181, "upper": 319, "budonly": false, "focus": null}, "__module__": "baby.preprocessing", "__class__": "CellGroup"}, "large": {"_python_NamedTuple": {"lower": 252, "upper": Infinity, "budonly": false, "focus": null}, "__module__": "baby.preprocessing", "__class__": "CellGroup"}, "buds": {"_python_NamedTuple": {"lower": 1, "upper": 300, "budonly": true, "focus": null}, "__module__": "baby.preprocessing", "__class__": "CellGroup"}}, "groupprops": {"small": {"_python_set": ["filled", "edge"]}, "medium": {"_python_set": ["overlap", "edge", "interior"]}, "large": {"_python_set": ["overlap", "edge", "interior"]}, "buds": {"_python_set": ["budneck"]}}, "targets": [{"_python_NamedTuple": {"name": "lge_inte", "group": "large", "prop": "interior", "nerode": 5, "ndilate": 0, "ndilate_overlaps": 0, "ndilate_mother": 2}, "__module__": "baby.preprocessing", "__class__": "PredTarget"}, {"_python_NamedTuple": {"name": "lge_edge", "group": "large", "prop": "edge", "nerode": 0, "ndilate": 0, "ndilate_overlaps": 0, "ndilate_mother": 2}, "__module__": "baby.preprocessing", "__class__": "PredTarget"}, {"_python_NamedTuple": {"name": "mid_inte", "group": "medium", "prop": "interior", "nerode": 3, "ndilate": 0, "ndilate_overlaps": 0, "ndilate_mother": 2}, "__module__": "baby.preprocessing", "__class__": "PredTarget"}, {"_python_NamedTuple": {"name": "mid_edge", "group": "medium", "prop": "edge", "nerode": 0, "ndilate": 0, "ndilate_overlaps": 0, "ndilate_mother": 2}, "__module__": "baby.preprocessing", "__class__": "PredTarget"}, {"_python_NamedTuple": {"name": "sml_inte", "group": "small", "prop": "filled", "nerode": 0, "ndilate": 0, "ndilate_overlaps": 0, "ndilate_mother": 2}, "__module__": "baby.preprocessing", "__class__": "PredTarget"}, {"_python_NamedTuple": {"name": 
"sml_edge", "group": "small", "prop": "edge", "nerode": 0, "ndilate": 0, "ndilate_overlaps": 0, "ndilate_mother": 2}, "__module__": "baby.preprocessing", "__class__": "PredTarget"}, {"_python_NamedTuple": {"name": "bud_neck", "group": "buds", "prop": "budneck", "nerode": 0, "ndilate": 0, "ndilate_overlaps": 0, "ndilate_mother": 2}, "__module__": "baby.preprocessing", "__class__": "PredTarget"}]}
\ No newline at end of file
diff --git a/python/baby/models/flattener_60x_evolve_20200630.json b/python/baby/models/flattener_60x_evolve_20200630.json
deleted file mode 100644
index bf693db44d61dac6be5ea26a3932cddf5de6ce9e..0000000000000000000000000000000000000000
--- a/python/baby/models/flattener_60x_evolve_20200630.json
+++ /dev/null
@@ -1 +0,0 @@
-{"groupdef": {"small": {"_python_NamedTuple": {"lower": 1, "upper": 104, "budonly": false, "focus": null}, "__module__": "baby.preprocessing", "__class__": "CellGroup"}, "medium": {"_python_NamedTuple": {"lower": 76, "upper": 153, "budonly": false, "focus": null}, "__module__": "baby.preprocessing", "__class__": "CellGroup"}, "large": {"_python_NamedTuple": {"lower": 124, "upper": Infinity, "budonly": false, "focus": null}, "__module__": "baby.preprocessing", "__class__": "CellGroup"}, "buds": {"_python_NamedTuple": {"lower": 1, "upper": 200, "budonly": true, "focus": null}, "__module__": "baby.preprocessing", "__class__": "CellGroup"}}, "groupprops": {"small": {"_python_set": ["filled", "edge"]}, "medium": {"_python_set": ["interior", "overlap", "edge"]}, "large": {"_python_set": ["interior", "overlap", "edge"]}, "buds": {"_python_set": ["budneck"]}}, "targets": [{"_python_NamedTuple": {"name": "lge_inte", "group": "large", "prop": "interior", "nerode": 3, "ndilate": 0, "ndilate_overlaps": 0, "ndilate_mother": 2}, "__module__": "baby.preprocessing", "__class__": "PredTarget"}, {"_python_NamedTuple": {"name": "lge_edge", "group": "large", "prop": "edge", "nerode": 0, "ndilate": 0, "ndilate_overlaps": 0, "ndilate_mother": 2}, "__module__": "baby.preprocessing", "__class__": "PredTarget"}, {"_python_NamedTuple": {"name": "mid_inte", "group": "medium", "prop": "interior", "nerode": 2, "ndilate": 0, "ndilate_overlaps": 0, "ndilate_mother": 2}, "__module__": "baby.preprocessing", "__class__": "PredTarget"}, {"_python_NamedTuple": {"name": "mid_edge", "group": "medium", "prop": "edge", "nerode": 0, "ndilate": 0, "ndilate_overlaps": 0, "ndilate_mother": 2}, "__module__": "baby.preprocessing", "__class__": "PredTarget"}, {"_python_NamedTuple": {"name": "sml_inte", "group": "small", "prop": "filled", "nerode": 0, "ndilate": 0, "ndilate_overlaps": 0, "ndilate_mother": 2}, "__module__": "baby.preprocessing", "__class__": "PredTarget"}, {"_python_NamedTuple": {"name": 
"sml_edge", "group": "small", "prop": "edge", "nerode": 0, "ndilate": 0, "ndilate_overlaps": 0, "ndilate_mother": 2}, "__module__": "baby.preprocessing", "__class__": "PredTarget"}, {"_python_NamedTuple": {"name": "bud_neck", "group": "buds", "prop": "budneck", "nerode": 0, "ndilate": 0, "ndilate_overlaps": 0, "ndilate_mother": 2}, "__module__": "baby.preprocessing", "__class__": "PredTarget"}]}
\ No newline at end of file
diff --git a/python/baby/models/flattener_60x_evolve_20210120.json b/python/baby/models/flattener_60x_evolve_20210120.json
deleted file mode 100644
index 7c06d8316c0e5421083adbfe96437e6d7402e0a2..0000000000000000000000000000000000000000
--- a/python/baby/models/flattener_60x_evolve_20210120.json
+++ /dev/null
@@ -1 +0,0 @@
-{"groupdef": {"small": {"_python_NamedTuple": {"lower": 1, "upper": 133, "budonly": false, "focus": null}, "__module__": "baby.preprocessing", "__class__": "CellGroup"}, "medium": {"_python_NamedTuple": {"lower": 100, "upper": 176, "budonly": false, "focus": null}, "__module__": "baby.preprocessing", "__class__": "CellGroup"}, "large": {"_python_NamedTuple": {"lower": 143, "upper": Infinity, "budonly": false, "focus": null}, "__module__": "baby.preprocessing", "__class__": "CellGroup"}, "buds": {"_python_NamedTuple": {"lower": 1, "upper": 200, "budonly": true, "focus": null}, "__module__": "baby.preprocessing", "__class__": "CellGroup"}}, "groupprops": {"small": {"_python_set": ["edge", "filled"]}, "medium": {"_python_set": ["overlap", "edge", "interior"]}, "large": {"_python_set": ["overlap", "edge", "interior"]}, "buds": {"_python_set": ["budneck"]}}, "targets": [{"_python_NamedTuple": {"name": "lge_inte", "group": "large", "prop": "interior", "nerode": 4, "ndilate": 0, "ndilate_overlaps": 0, "ndilate_mother": 2}, "__module__": "baby.preprocessing", "__class__": "PredTarget"}, {"_python_NamedTuple": {"name": "lge_edge", "group": "large", "prop": "edge", "nerode": 0, "ndilate": 0, "ndilate_overlaps": 0, "ndilate_mother": 2}, "__module__": "baby.preprocessing", "__class__": "PredTarget"}, {"_python_NamedTuple": {"name": "mid_inte", "group": "medium", "prop": "interior", "nerode": 3, "ndilate": 0, "ndilate_overlaps": 0, "ndilate_mother": 2}, "__module__": "baby.preprocessing", "__class__": "PredTarget"}, {"_python_NamedTuple": {"name": "mid_edge", "group": "medium", "prop": "edge", "nerode": 0, "ndilate": 0, "ndilate_overlaps": 0, "ndilate_mother": 2}, "__module__": "baby.preprocessing", "__class__": "PredTarget"}, {"_python_NamedTuple": {"name": "sml_inte", "group": "small", "prop": "filled", "nerode": 0, "ndilate": 0, "ndilate_overlaps": 0, "ndilate_mother": 2}, "__module__": "baby.preprocessing", "__class__": "PredTarget"}, {"_python_NamedTuple": {"name": 
"sml_edge", "group": "small", "prop": "edge", "nerode": 0, "ndilate": 0, "ndilate_overlaps": 0, "ndilate_mother": 2}, "__module__": "baby.preprocessing", "__class__": "PredTarget"}, {"_python_NamedTuple": {"name": "bud_neck", "group": "buds", "prop": "budneck", "nerode": 0, "ndilate": 0, "ndilate_overlaps": 0, "ndilate_mother": 2}, "__module__": "baby.preprocessing", "__class__": "PredTarget"}]}
\ No newline at end of file
diff --git a/python/baby/models/flattener_60x_prime_20200630.json b/python/baby/models/flattener_60x_prime_20200630.json
deleted file mode 100644
index d156a0d3cd015ca234b6853130ab75a2bea53ae6..0000000000000000000000000000000000000000
--- a/python/baby/models/flattener_60x_prime_20200630.json
+++ /dev/null
@@ -1 +0,0 @@
-{"groupdef": {"small": {"_python_NamedTuple": {"lower": 1, "upper": 285, "budonly": false, "focus": null}, "__module__": "baby.preprocessing", "__class__": "CellGroup"}, "medium": {"_python_NamedTuple": {"lower": 227, "upper": 340, "budonly": false, "focus": null}, "__module__": "baby.preprocessing", "__class__": "CellGroup"}, "large": {"_python_NamedTuple": {"lower": 282, "upper": Infinity, "budonly": false, "focus": null}, "__module__": "baby.preprocessing", "__class__": "CellGroup"}, "buds": {"_python_NamedTuple": {"lower": 1, "upper": 300, "budonly": true, "focus": null}, "__module__": "baby.preprocessing", "__class__": "CellGroup"}}, "groupprops": {"small": {"_python_set": ["filled", "edge"]}, "medium": {"_python_set": ["interior", "overlap", "edge"]}, "large": {"_python_set": ["interior", "overlap", "edge"]}, "buds": {"_python_set": ["budneck"]}}, "targets": [{"_python_NamedTuple": {"name": "lge_inte", "group": "large", "prop": "interior", "nerode": 5, "ndilate": 0, "ndilate_overlaps": 0, "ndilate_mother": 2}, "__module__": "baby.preprocessing", "__class__": "PredTarget"}, {"_python_NamedTuple": {"name": "lge_edge", "group": "large", "prop": "edge", "nerode": 0, "ndilate": 0, "ndilate_overlaps": 0, "ndilate_mother": 2}, "__module__": "baby.preprocessing", "__class__": "PredTarget"}, {"_python_NamedTuple": {"name": "mid_inte", "group": "medium", "prop": "interior", "nerode": 5, "ndilate": 0, "ndilate_overlaps": 0, "ndilate_mother": 2}, "__module__": "baby.preprocessing", "__class__": "PredTarget"}, {"_python_NamedTuple": {"name": "mid_edge", "group": "medium", "prop": "edge", "nerode": 0, "ndilate": 0, "ndilate_overlaps": 0, "ndilate_mother": 2}, "__module__": "baby.preprocessing", "__class__": "PredTarget"}, {"_python_NamedTuple": {"name": "sml_inte", "group": "small", "prop": "filled", "nerode": 0, "ndilate": 0, "ndilate_overlaps": 0, "ndilate_mother": 2}, "__module__": "baby.preprocessing", "__class__": "PredTarget"}, {"_python_NamedTuple": {"name": 
"sml_edge", "group": "small", "prop": "edge", "nerode": 0, "ndilate": 0, "ndilate_overlaps": 0, "ndilate_mother": 2}, "__module__": "baby.preprocessing", "__class__": "PredTarget"}, {"_python_NamedTuple": {"name": "bud_neck", "group": "buds", "prop": "budneck", "nerode": 0, "ndilate": 0, "ndilate_overlaps": 0, "ndilate_mother": 2}, "__module__": "baby.preprocessing", "__class__": "PredTarget"}]}
\ No newline at end of file
diff --git a/python/baby/models/flattener_60x_prime_20210120.json b/python/baby/models/flattener_60x_prime_20210120.json
deleted file mode 100644
index 03cb6e0dd99aaf62459221cb19e68cadc70c140b..0000000000000000000000000000000000000000
--- a/python/baby/models/flattener_60x_prime_20210120.json
+++ /dev/null
@@ -1 +0,0 @@
-{"groupdef": {"small": {"_python_NamedTuple": {"lower": 1, "upper": 255, "budonly": false, "focus": null}, "__module__": "baby.preprocessing", "__class__": "CellGroup"}, "medium": {"_python_NamedTuple": {"lower": 189, "upper": 319, "budonly": false, "focus": null}, "__module__": "baby.preprocessing", "__class__": "CellGroup"}, "large": {"_python_NamedTuple": {"lower": 252, "upper": Infinity, "budonly": false, "focus": null}, "__module__": "baby.preprocessing", "__class__": "CellGroup"}, "buds": {"_python_NamedTuple": {"lower": 1, "upper": 300, "budonly": true, "focus": null}, "__module__": "baby.preprocessing", "__class__": "CellGroup"}}, "groupprops": {"small": {"_python_set": ["edge", "filled"]}, "medium": {"_python_set": ["edge", "overlap", "interior"]}, "large": {"_python_set": ["edge", "overlap", "interior"]}, "buds": {"_python_set": ["budneck"]}}, "targets": [{"_python_NamedTuple": {"name": "lge_inte", "group": "large", "prop": "interior", "nerode": 5, "ndilate": 0, "ndilate_overlaps": 0, "ndilate_mother": 2}, "__module__": "baby.preprocessing", "__class__": "PredTarget"}, {"_python_NamedTuple": {"name": "lge_edge", "group": "large", "prop": "edge", "nerode": 0, "ndilate": 0, "ndilate_overlaps": 0, "ndilate_mother": 2}, "__module__": "baby.preprocessing", "__class__": "PredTarget"}, {"_python_NamedTuple": {"name": "mid_inte", "group": "medium", "prop": "interior", "nerode": 4, "ndilate": 0, "ndilate_overlaps": 0, "ndilate_mother": 2}, "__module__": "baby.preprocessing", "__class__": "PredTarget"}, {"_python_NamedTuple": {"name": "mid_edge", "group": "medium", "prop": "edge", "nerode": 0, "ndilate": 0, "ndilate_overlaps": 0, "ndilate_mother": 2}, "__module__": "baby.preprocessing", "__class__": "PredTarget"}, {"_python_NamedTuple": {"name": "sml_inte", "group": "small", "prop": "filled", "nerode": 0, "ndilate": 0, "ndilate_overlaps": 0, "ndilate_mother": 2}, "__module__": "baby.preprocessing", "__class__": "PredTarget"}, {"_python_NamedTuple": {"name": 
"sml_edge", "group": "small", "prop": "edge", "nerode": 0, "ndilate": 0, "ndilate_overlaps": 0, "ndilate_mother": 2}, "__module__": "baby.preprocessing", "__class__": "PredTarget"}, {"_python_NamedTuple": {"name": "bud_neck", "group": "buds", "prop": "budneck", "nerode": 0, "ndilate": 0, "ndilate_overlaps": 0, "ndilate_mother": 2}, "__module__": "baby.preprocessing", "__class__": "PredTarget"}]}
\ No newline at end of file
diff --git a/python/baby/models/mb_model_20201022.pkl b/python/baby/models/mb_model_20201022.pkl
deleted file mode 100644
index e25c52bdf6a7debca807a6de540cf36e673eb0e6..0000000000000000000000000000000000000000
Binary files a/python/baby/models/mb_model_20201022.pkl and /dev/null differ
diff --git a/python/baby/models/mb_model_60x_1z_evolve_210302.pkl b/python/baby/models/mb_model_60x_1z_evolve_210302.pkl
deleted file mode 100644
index bad010d4ea03f38ecc1d90828f2c667792459936..0000000000000000000000000000000000000000
Binary files a/python/baby/models/mb_model_60x_1z_evolve_210302.pkl and /dev/null differ
diff --git a/python/baby/models/mb_model_60x_1z_prime_210302.pkl b/python/baby/models/mb_model_60x_1z_prime_210302.pkl
deleted file mode 100644
index 32d399892f0bfe40d0a1d9426a3ca08e607b32ea..0000000000000000000000000000000000000000
Binary files a/python/baby/models/mb_model_60x_1z_prime_210302.pkl and /dev/null differ
diff --git a/python/baby/models/mb_model_60x_3z_evolve_20210215.pkl b/python/baby/models/mb_model_60x_3z_evolve_20210215.pkl
deleted file mode 100644
index fb2b349dc580a4ef2062ce0d4b2fab5e8abe69b2..0000000000000000000000000000000000000000
Binary files a/python/baby/models/mb_model_60x_3z_evolve_20210215.pkl and /dev/null differ
diff --git a/python/baby/models/mb_model_60x_3z_prime_20210215.pkl b/python/baby/models/mb_model_60x_3z_prime_20210215.pkl
deleted file mode 100644
index 97a6fd1ca0a3ebe0a9093cff56e3537e1f1c4bcb..0000000000000000000000000000000000000000
Binary files a/python/baby/models/mb_model_60x_3z_prime_20210215.pkl and /dev/null differ
diff --git a/python/baby/models/mb_model_60x_evolve_20210120.pkl b/python/baby/models/mb_model_60x_evolve_20210120.pkl
deleted file mode 100644
index 3b86ce1254f8c1205f39004c816d80ff35f7461c..0000000000000000000000000000000000000000
Binary files a/python/baby/models/mb_model_60x_evolve_20210120.pkl and /dev/null differ
diff --git a/python/baby/models/mb_model_60x_prime_20210120.pkl b/python/baby/models/mb_model_60x_prime_20210120.pkl
deleted file mode 100644
index 8c1fb44059d2b71e135d51014320ca0799959fb6..0000000000000000000000000000000000000000
Binary files a/python/baby/models/mb_model_60x_prime_20210120.pkl and /dev/null differ
diff --git a/python/baby/modelsets.json b/python/baby/modelsets.json
deleted file mode 100644
index fcdaf59a60a683ec5a82e558b6c82311cd44ce22..0000000000000000000000000000000000000000
--- a/python/baby/modelsets.json
+++ /dev/null
@@ -1,103 +0,0 @@
-{
-  "prime95b_brightfield_60x_5z": {
-    "morph_model_file": "I5_prime_unet_4s_20210120.hdf5",
-    "flattener_file": "flattener_60x_prime_20210120.json",
-    "celltrack_model_file": "ct_rf_20210201_12.pkl",
-    "budassign_model_file": "mb_model_60x_prime_20210120.pkl",
-    "pixel_size": 0.182,
-    "default_image_size": [117, 117],
-    "params": {
-      "interior_threshold": [0.95, 0.5, 0.35],
-      "nclosing": [0, 0, 0], "nopening": [2, 0, 0],
-      "connectivity": [1, 1, 2],
-      "edge_sub_dilations": [0, 0, 0],
-      "containment_thresh": 0.85, "min_area": 19,
-      "pedge_thresh": [0.0, 0.0028, 0.0012],
-      "group_thresh_expansion": [0.28, 0.06, 0.32]
-    }
-  },
-  "evolve_brightfield_60x_5z": {
-    "morph_model_file": "I5_evolve_unet_4s_20210120.hdf5",
-    "flattener_file": "flattener_60x_evolve_20210120.json",
-    "celltrack_model_file": "ct_rf_20210201_12.pkl",
-    "budassign_model_file": "mb_model_60x_evolve_20210120.pkl",
-    "pixel_size": 0.263,
-    "default_image_size": [81, 81],
-    "params": {
-      "interior_threshold": [0.9, 0.65, 0.8],
-      "nclosing": [0, 1, 0], "nopening": [1, 2, 0],
-      "connectivity": [1, 1, 1],
-      "edge_sub_dilations": [0, 2, 0],
-      "containment_thresh": 0.85, "min_area": 9,
-      "pedge_thresh": [0.0, 0.0, 0.0014],
-      "group_thresh_expansion": [0.14, 0.06, 0.24]
-    },
-    "isbud_thresh": 0.35
-  },
-  "prime95b_brightfield_60x_1z": {
-    "morph_model_file": "I1_prime_unet_4s_20210302.hdf5",
-    "flattener_file": "flattener_60x_1z_prime_20210302.json",
-    "celltrack_model_file": "ct_rf_20210201_12.pkl",
-    "budassign_model_file": "mb_model_60x_1z_prime_210302.pkl",
-    "pixel_size": 0.182,
-    "default_image_size": [117, 117],
-    "params": {
-      "interior_threshold": [0.5, 0.8, 0.4],
-      "nclosing": [0, 0, 0], "nopening": [1, 2, 0],
-      "connectivity": [1, 1, 1],
-      "edge_sub_dilations": [0, 1, 0],
-      "containment_thresh": 0.8, "min_area": 19,
-      "pedge_thresh": [0.0, 0.0, 0.0002],
-      "group_thresh_expansion": [0.22, 0.4, 0.24]
-    }
-  },
-  "evolve_brightfield_60x_1z": {
-    "morph_model_file": "I1_evolve_unet_4s_20210302.hdf5",
-    "flattener_file": "flattener_60x_1z_evolve_20210302.json",
-    "celltrack_model_file": "ct_rf_20210201_12.pkl",
-    "budassign_model_file": "mb_model_60x_1z_evolve_210302.pkl",
-    "pixel_size": 0.263,
-    "default_image_size": [81, 81],
-    "params": {
-      "interior_threshold": [0.35, 0.8, 0.35],
-      "nclosing": [0, 0, 0], "nopening": [0, 0, 1],
-      "connectivity": [1, 1, 1], "edge_sub_dilations": [0, 1, 0],
-      "containment_thresh": 0.7, "min_area": 14,
-      "pedge_thresh": [0.0, 0.0, 0.0009],
-      "group_thresh_expansion": [0.16, 0.4, 0.0]
-    }
-  },
-  "prime95b_brightfield_60x_3z": {
-    "morph_model_file": "I3_prime_unet_4s_20210215.hdf5",
-    "flattener_file": "flattener_60x_3z_prime_20210215.json",
-    "celltrack_model_file": "ct_rf_20210201_12.pkl",
-    "budassign_model_file": "mb_model_60x_3z_prime_20210215.pkl",
-    "pixel_size": 0.182,
-    "default_image_size": [117, 117],
-    "params": {
-      "interior_threshold": [0.65, 0.9, 0.3],
-      "nclosing": [0, 0, 0], "nopening": [1, 2, 1],
-      "connectivity": [1, 1, 1],
-      "edge_sub_dilations": [0, 0, 0],
-      "containment_thresh": 0.6, "min_area": 19,
-      "pedge_thresh": [0.0, 0.0, 0.0012],
-      "group_thresh_expansion": [0.12, 0.4, 0.3]
-    }
-  },
-  "evolve_brightfield_60x_3z": {
-    "morph_model_file": "I3_evolve_unet_4s_20210215.hdf5",
-    "flattener_file": "flattener_60x_3z_evolve_20210215.json",
-    "celltrack_model_file": "ct_rf_20210201_12.pkl",
-    "budassign_model_file": "mb_model_60x_3z_evolve_20210215.pkl",
-    "pixel_size": 0.263,
-    "default_image_size": [81, 81],
-    "params": {
-      "interior_threshold": [0.75, 0.4, 0.45],
-      "nclosing": [0, 0, 0], "nopening": [0, 2, 0],
-      "connectivity": [1, 1, 1], "edge_sub_dilations": [0, 0, 0],
-      "containment_thresh": 0.9, "min_area": 9,
-      "pedge_thresh": [0.0, 0.0015, 0.0],
-      "group_thresh_expansion": [0.28, 0.38, 0.22]
-    }
-  }
-}
diff --git a/python/baby/modelsets.py b/python/baby/modelsets.py
new file mode 100644
index 0000000000000000000000000000000000000000..645112e363c3e51c3ab509ae64ee43cd971c651d
--- /dev/null
+++ b/python/baby/modelsets.py
@@ -0,0 +1,404 @@
+# If you publish results that make use of this software or the Birth Annotator
+# for Budding Yeast algorithm, please cite:
+# Pietsch, J.M.J., Muñoz, A.F., Adjavon, D.-Y.A., Farquhar, I., Clark, I.B.N.,
+# and Swain, P.S. (2023). Determining growth rates from bright-field images of
+# budding cells through identifying overlaps. eLife. 12:e79812.
+# https://doi.org/10.7554/eLife.79812
+# 
+# 
+# The MIT License (MIT)
+# 
+# Copyright (c) Julian Pietsch, Alán Muñoz and Diane Adjavon 2023
+# 
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to
+# deal in the Software without restriction, including without limitation the
+# rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
+# sell copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+# 
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+# 
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
+# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS
+# IN THE SOFTWARE.
+# SOFTWARE.
+import os
+from pathlib import Path
+import requests
+from urllib.parse import urljoin, quote
+import json
+
+from .errors import BadParam, BadModel, BadFile
+from .utils import jsonify, as_python_object
+from .morph_thresh_seg import SegmentationParameters
+
+MODELSETS_FILENAME = 'modelsets.json'
+MODELSET_FILENAME = 'modelset.json'
+
+BABY_MODELS_URL = 'https://julianpietsch.github.io/baby-models/'
+MODELSETS_URL = urljoin(BABY_MODELS_URL, MODELSETS_FILENAME)
+
+ENV_VAR_MODELSETS_PATH = 'BABY_MODELSETS_PATH'
+ENV_LOCAL_MODELSETS_PATH = os.environ.get(ENV_VAR_MODELSETS_PATH)
+if (type(ENV_LOCAL_MODELSETS_PATH) == str and
+    ENV_LOCAL_MODELSETS_PATH.startswith('~')):
+    ENV_LOCAL_MODELSETS_PATH = Path(ENV_LOCAL_MODELSETS_PATH).expanduser()
+DEFAULT_LOCAL_MODELSETS_PATH = Path.home() / '.baby_models'
+LOCAL_MODELSETS_PATH = ENV_LOCAL_MODELSETS_PATH or DEFAULT_LOCAL_MODELSETS_PATH
+LOCAL_MODELSETS_PATH = Path(LOCAL_MODELSETS_PATH)
+LOCAL_MODELSETS_CACHE = LOCAL_MODELSETS_PATH / MODELSETS_FILENAME
+
+SHARE_PATH = 'shared'
+
+
+def remote_modelsets():
+    '''Retrieve info on the model sets available for download
+
+    Returns:
+        Nested dict hierarchy decoded from remote JSON file. The root of the
+        hierarchy has keys 'models' and 'shared'.
+
+        The value of the 'models' key is a dict whose keys are the folder
+        names for the model sets and whose values are dicts giving details for
+        each model set. Each model set dict has the following items:
+        - a 'name' key with value being a str naming the model,
+        - a 'files' key with value being a list of str specifying all file
+          names in the model set folder,
+        - a 'meta' key with value being a dict of meta data on the model,
+        - a 'brain_params' key with value being a dict of parameters to
+          instantiate a :py:class:`brain.BabyBrain` object.
+
+        The 'shared' key at the root of the hierarchy points to a list of str
+        giving file names in the 'shared' folder.
+
+    '''
+    msets_text = requests.get(MODELSETS_URL).text
+    return json.loads(msets_text, object_hook=as_python_object)
+
+
+def _ensure_local_path():
+    '''Creates the directory to store model sets if it does not exist'''
+    if not LOCAL_MODELSETS_PATH.is_dir():
+        LOCAL_MODELSETS_PATH.mkdir(parents=True)
+    local_share_path = LOCAL_MODELSETS_PATH / SHARE_PATH
+    if not local_share_path.is_dir():
+        local_share_path.mkdir(parents=True)
+
+
+def update_local_cache():
+    '''Update the locally cached info on model sets available for download'''
+    msets_info = jsonify(remote_modelsets())  # fetch before truncating cache
+    _ensure_local_path()
+    LOCAL_MODELSETS_CACHE.write_text(json.dumps(msets_info))
+
+
+def _ensure_local_cache(update=False):
+    if not LOCAL_MODELSETS_CACHE.exists() or update:
+        update_local_cache()
+
+
+def specifications(update=False, local=False):
+    '''Get full specifications for all model sets
+
+    By default returns specifications for all models, including models that
+    have not yet been downloaded.
+
+    Args:
+        update: whether or not to update the locally-cached list of models
+        local: whether to include only locally-available models.
+
+    Returns:
+        a `dict` mapping model set IDs to dicts specifying details for each
+        model set. Each model set dict has at least the following items:
+        - a 'name' key with value being a str naming the model,
+        - a 'meta' key with value being a dict of meta data on the model,
+        - a 'brain_params' key with value being a dict of parameters to
+          instantiate a :py:class:`brain.BabyBrain` object.
+    '''
+    _ensure_local_cache(update=update)
+    if local:
+        return local_modelsets()['models']
+    with open(LOCAL_MODELSETS_CACHE, 'rt') as f:
+        modelsets_info = json.load(f, object_hook=as_python_object)
+    return modelsets_info['models']
+
+
+def ids(update=False, local=False):
+    '''List the available model sets by ID
+
+    Args:
+        update: whether or not to update the locally-cached list of models
+        local: whether to include only locally-available models.
+
+    Returns:
+        a `list` of `str` giving the ID of each available model set
+    '''
+    return list(specifications(update=update, local=local).keys())
+
+
+def meta(update=False, local=False):
+    '''Obtain descriptive meta information on each model set
+
+    Args:
+        update: whether or not to update the locally-cached list of models
+        local: whether to include only locally-available models.
+
+    Returns:
+        a `dict` mapping model set IDs to dicts of meta information associated
+        with each model set
+    '''
+    return {msId: v['meta'] for msId, v in
+            specifications(update=update, local=local).items()}
+
+
+def local_modelsets():
+    _ensure_local_path()
+    modelsets = {}
+    for modelset_file in LOCAL_MODELSETS_PATH.glob('**/' + MODELSET_FILENAME):
+        modelset_id = modelset_file.parent.name
+        with open(modelset_file, 'rt') as f:
+            modelsets[modelset_id] = json.load(f, object_hook=as_python_object)
+        modelsets[modelset_id]['files'] = [
+            p.name for p in modelset_file.parent.iterdir()]
+
+    local_share_path = LOCAL_MODELSETS_PATH / SHARE_PATH
+    shared_files = [p.name for p in local_share_path.iterdir()]
+
+    return {'models': modelsets, 'shared': shared_files}
+
+
+def resolve(filename, modelset_id):
+    '''Resolve the path to a file specified by a model set
+
+    File names are resolved by first assuming the `modelset_id` argument
+    specifies a directory containing model files (the model set path). If
+    `modelset_id` is not a directory, then it is assumed to be a model ID as
+    per :py:func:`ids`, and then the corresponding model set path in the local
+    cache will be searched first. 
+
+    If the file is not found in the model set path, it will be searched for in
+    the shared directory of the local cache.
+
+    Args:
+        filename (str): file name to resolve, typically as specified by one of
+            the keys obtained from :py:func:`get_params`.
+        modelset_id (str): one of the model IDs as returned by :py:func:`ids`.
+
+    Returns:
+        A `Path` object giving the path to the specified file. If the file
+        cannot be found a BadParam exception is raised.
+    '''
+    if Path(modelset_id).is_dir():
+        modelset_path = Path(modelset_id)
+    else:
+        modelset_path = LOCAL_MODELSETS_PATH / modelset_id
+
+    trial_path = modelset_path / filename
+    if trial_path.is_file():
+        return trial_path
+
+    trial_path = LOCAL_MODELSETS_PATH / SHARE_PATH / filename
+    if trial_path.is_file():
+        return trial_path
+    else:
+        raise BadParam(f'The file {filename} could not be resolved for model set {modelset_id}')
+
+
+def _get_modelset_files(modelset):
+    params = modelset['brain_params']
+    modelset_files = [v for k, v in params.items() if k.endswith('_file')]
+    if type(params['params']) not in {dict, SegmentationParameters}:
+        modelset_files.append(params['params'])
+    return modelset_files
+
+
+def _missing_files(modelset, modelset_id):
+    modelset_files = _get_modelset_files(modelset)
+    missing = []
+    for fname in modelset_files:
+        try:
+            resolve(fname, modelset_id)
+        except BadParam:
+            missing.append(fname)
+    return missing
+
+
+def update(modelset_ids=None, force=True, cleanup=False, verbose=True):
+    '''Updates any outdated model sets that are available remotely
+
+    Args:
+        modelset_ids: a list of str specifying one of the model IDs as
+            returned by :py:func:`ids`. By default, only updates model sets
+            that have already been downloaded locally. Set this to `'all'` to
+            download all available model sets.
+        force: whether to replace existing files or just obtain missing ones.
+        cleanup: whether to delete model set and shared files that are no
+            longer on the remote
+        verbose: whether to print download status to standard out.
+
+    Raises:
+        BadParam: if a requested model set is not available on the remote.
+        requests.HTTPError: if the remote refuses a file download.
+    '''
+    remote_mset_info = remote_modelsets()
+    remote_msets = remote_mset_info['models']
+    local_mset_info = local_modelsets()  # ensures local path
+    local_msets = local_mset_info['models']
+
+    if modelset_ids is None:
+        modelset_ids = [msId for msId in remote_msets if msId in local_msets]
+    elif modelset_ids == 'all':
+        modelset_ids = list(remote_msets.keys())
+
+    invalid_ids = [msId for msId in modelset_ids if msId not in remote_msets]
+    if invalid_ids:  # emptiness test; `any` would skip falsy IDs such as ''
+        invalid_ids = ', '.join([f'`{msId}`' for msId in invalid_ids])
+        raise BadParam(f'Requested model set(s) {invalid_ids} not available')
+
+    # Update model sets from remote host
+    for mset_id in modelset_ids:
+        mset_meta = remote_msets[mset_id]
+        local_mset_dir = LOCAL_MODELSETS_PATH / mset_id
+        if not local_mset_dir.exists():
+            local_mset_dir.mkdir(parents=True)
+
+        mset_missing = mset_id not in local_msets
+        mset_changed = mset_missing or local_msets[mset_id] != mset_meta
+        # NOTE(review): existing metadata is rewritten only if not force?
+        if mset_missing or (mset_changed and not force):
+            new_mset_meta = mset_meta.copy()
+            del new_mset_meta['files']
+            with open(local_mset_dir / MODELSET_FILENAME, 'wt') as f:
+                json.dump(jsonify(new_mset_meta), f)
+
+        if not force and not mset_missing:
+            # If there is already a local model set and we are not forcing an
+            # update, then we will only proceed with file download / cleanup
+            # if the local model has missing files. This allows for the case
+            # where a local version of a model has different files to those
+            # found on the server, and so avoids downloading potentially
+            # outdated extras.
+            nmissing = len(_missing_files(local_msets[mset_id], mset_id))
+            if nmissing == 0:
+                continue
+
+        remote_mset_files = set(mset_meta['files'])
+        local_mset_files = local_msets.get(mset_id, {}).get('files', [])
+        if cleanup:
+            # Clean up any old files that are no longer on the remote
+            for local_file in local_mset_files:
+                if local_file not in remote_mset_files:
+                    (local_mset_dir / local_file).unlink()
+
+        remote_mset_files = remote_mset_files.difference({MODELSET_FILENAME})
+        remote_mset_dir = urljoin(BABY_MODELS_URL, mset_id + '/')
+        files_to_download = (remote_mset_files if force else
+                             remote_mset_files.difference(local_mset_files))
+        if len(files_to_download) > 0 and verbose:
+            print(f'Downloading files for {mset_id}...')
+        for remote_file in files_to_download:
+            r = requests.get(urljoin(remote_mset_dir, remote_file))
+            r.raise_for_status()  # never save an HTTP error page as a model
+            with open(local_mset_dir / remote_file, 'wb') as f:
+                f.write(r.content)
+
+    # Update shared files from remote host
+    remote_share_dir = urljoin(BABY_MODELS_URL, SHARE_PATH + '/')
+    remote_shared_files = set(remote_mset_info['shared'])
+    local_share_path = LOCAL_MODELSETS_PATH / SHARE_PATH
+    local_shared_files = local_mset_info['shared']
+
+    if cleanup:
+        # Clean up any old files that are no longer on the remote
+        for local_file in local_shared_files:
+            if local_file not in remote_shared_files:
+                (local_share_path / local_file).unlink()
+
+    files_to_download = (remote_shared_files if force else
+                         remote_shared_files.difference(local_shared_files))
+
+    # Download any files that need updating
+    if len(files_to_download) > 0 and verbose:
+        print('Downloading shared files...')
+    for remote_file in files_to_download:
+        r = requests.get(urljoin(remote_share_dir, remote_file))
+        r.raise_for_status()  # never save an HTTP error page as a model
+        with open(local_share_path / remote_file, 'wb') as f:
+            f.write(r.content)
+
+
+def _ensure_modelset(modelset_id):
+    '''Ensure that a model set has been downloaded and is ready to use
+
+    Args:
+        modelset_id: a `str` specifying one of the model IDs as returned by
+        :py:func:`ids`.
+    '''
+    _ensure_local_path()
+    local_path = LOCAL_MODELSETS_PATH / modelset_id
+    share_path = LOCAL_MODELSETS_PATH / SHARE_PATH
+    local_modelset_file = local_path / MODELSET_FILENAME
+    updated = False
+    if not local_modelset_file.exists():
+        update([modelset_id], force=False)
+        updated = True
+    with open(local_modelset_file, 'rt') as f:
+        modelset = json.load(f, object_hook=as_python_object)
+    modelset_files = _get_modelset_files(modelset)
+    for fname in modelset_files:
+        try:
+            resolve(fname, modelset_id)
+        except BadParam:
+            if updated:
+                raise BadModel('Model is corrupt. Contact maintainer.')
+            else:
+                update([modelset_id], force=False)
+                try:
+                    resolve(fname, modelset_id)
+                except BadParam:
+                    raise BadModel('Model is corrupt. Contact maintainer.')
+                updated = True
+
+
+def get_params(modelset_id):
+    '''Get model set parameters
+
+    The parameters are designed to be supplied as the argument to
+    instantiate a :py:class:`brain.BabyBrain` object.
+
+    The model set will be automatically downloaded if it has not yet been.
+
+    Args:
+        modelset_id: a `str` specifying one of the model IDs as returned by
+        :py:func:`ids`.
+    '''
+    _ensure_modelset(modelset_id)
+    local_path = LOCAL_MODELSETS_PATH / modelset_id
+    local_modelset_file = local_path / MODELSET_FILENAME
+    with open(local_modelset_file, 'rt') as f:
+        modelset = json.load(f, object_hook=as_python_object)
+    return modelset['brain_params']
+
+
+def get(modelset_id, **kwargs):
+    '''Get a model set as a BabyBrain object
+
+    The model set will be automatically downloaded if it has not yet been.
+
+    Args:
+        modelset_id: a `str` specifying one of the model IDs as returned by
+        :py:func:`ids`.
+
+    Returns:
+        A :py:class:`brain.BabyBrain` object instantiated with the model set
+        parameters.
+    '''
+    from .brain import BabyBrain
+    return BabyBrain(modelset_path=modelset_id, **get_params(modelset_id),
+                     **kwargs)
diff --git a/python/baby/morph_thresh_seg.py b/python/baby/morph_thresh_seg.py
index f09e10d5484747dc81f3ba24752b0d23b1803d7d..2e9832c62e6f6991fd3234605f918d591035dcda 100644
--- a/python/baby/morph_thresh_seg.py
+++ b/python/baby/morph_thresh_seg.py
@@ -1,12 +1,14 @@
 # If you publish results that make use of this software or the Birth Annotator
 # for Budding Yeast algorithm, please cite:
-# Julian M J Pietsch, Alán Muñoz, Diane Adjavon, Ivan B N Clark, Peter S
-# Swain, 2021, Birth Annotator for Budding Yeast (in preparation).
+# Pietsch, J.M.J., Muñoz, A.F., Adjavon, D.-Y.A., Farquhar, I., Clark, I.B.N.,
+# and Swain, P.S. (2023). Determining growth rates from bright-field images of
+# budding cells through identifying overlaps. eLife. 12:e79812.
+# https://doi.org/10.7554/eLife.79812
 # 
 # 
 # The MIT License (MIT)
 # 
-# Copyright (c) Julian Pietsch, Alán Muñoz and Diane Adjavon 2021
+# Copyright (c) Julian Pietsch, Alán Muñoz and Diane Adjavon 2023
 # 
 # Permission is hereby granted, free of charge, to any person obtaining a copy
 # of this software and associated documentation files (the "Software"), to
@@ -30,15 +32,142 @@ import itertools
 
 from scipy import ndimage
 
-from typing import Union, Iterable, Any, Optional, List, NamedTuple
+from typing import (Union, Sequence, Any, Optional, List, NamedTuple, Literal,
+                    Tuple)
 
 from .errors import BadParam
-from .segmentation import mask_containment, iterative_erosion, thresh_seg, \
-    binary_edge, single_region_prop, outline_to_radial, get_edge_scores, \
-    refine_radial_grouped, iterative_dilation, draw_radial
+from baby import segmentation
+from .segmentation import (
+    iterative_erosion, threshold_segmentation, binary_edge,
+    single_region_prop, mask_to_knots, get_edge_scores, refine_radial_grouped,
+    iterative_dilation, draw_radial)
+from .utils import EncodableNamedTuple
 from .volume import volume
 
 
+BROADCASTABLE_PARAMS = {'edge_sub_dilations', 'interior_threshold',
+                        'nclosing', 'nopening', 'connectivity',
+                        'min_area', 'bbox_padding', 'pedge_thresh',
+                        'group_thresh_expansion'}
+
+
+@EncodableNamedTuple
+class SegmentationParameters(NamedTuple):
+    """Defines parameters for segmentation from CNN outputs.
+
+    Designed to be used with :py:class:`MorphSegGrouped` and descriptions of
+    the attributes below should be read in conjunction with that class.
+
+    Certain parameters can be specified group-wise, that is, with a different
+    value for each of the groupings defined by ``cellgroups``. If specified as
+    a singleton, the same parameter value will be replicated for each group. If
+    specified as a list, the length of the list should match the number of
+    cell groupings.
+
+    This class has been decorated with :py:func:`utils.EncodableNamedTuple`,
+    so can be reversibly encoded to JSON using :py:func:`utils.jsonify`.
+
+    Attributes:
+        cellgroups: By default, all groups defined by the ``flattener``
+            (except ``budonly`` groups) will be included in descending size
+            order. Optionally specify a sequence of ``str`` or a sequence of
+            tuples of ``str`` naming specific flattener groups or respectively
+            combinations thereof.
+        edge_sub_dilations: Number of iterations of grayscale morphological
+            dilation to apply to the edge probability image when it is used
+            to dampen the interior probability image before thresholding.
+            Specify ``None`` to disable dampening, or ``0`` to dampen with
+            the (undilated) edge probability image. Can be specified
+            group-wise.
+        interior_threshold: Threshold to apply to interior probability image
+            for locating cell instances (see
+            :py:func:`segmentation.threshold_segmentation`). Can be specified
+            group-wise.
+        connectivity: Connectivity for labelling connected regions (see
+            :py:func:`segmentation.threshold_segmentation`). Can be specified
+            group-wise.
+        nclosing: Number of closing iterations to apply to identified masks
+            (see :py:func:`segmentation.threshold_segmentation`). Can be
+            specified group-wise.
+        nopening: Number of opening iterations to apply to identified masks
+            (see :py:func:`segmentation.threshold_segmentation`). Can be
+            specified group-wise.
+        min_area: Minimum area (in square pixels) below which a mask will be
+            discarded. Has precedence over a smaller minimum specified by
+            any of the ``cellgroups``. Can be specified group-wise.
+        fit_radial: If ``True``, edges are generated by splines with knots
+            defined in radial coordinates; if ``False``, edges are taken
+            directly from the dilated masks.
+        cartesian_spline: If ``True``, splines are interpolated in cartesian
+            coordinates; if ``False``, they are interpolated in radial
+            coordinates.
+        bbox_padding: The number of pixels of padding that should be
+            included when limiting mask processing to a bounding box.
+        nrays_thresh_map: See :py:func:`segmentation.guess_radial_edge`. Only
+            applies when ``cartesian_spline=False`` and ``curvy_knots=False``.
+        nrays_max: See :py:func:`segmentation.guess_radial_edge`. Only applies
+            when ``cartesian_spline=False`` and ``curvy_knots=False``.
+        n_trials: See :py:func:`segmentation.guess_cartesian_edge`. Only
+            applies when ``cartesian_spline=True`` and ``curvy_knots=False``.
+        pxperknot: See :py:func:`segmentation.guess_cartesian_edge`. Only
+            applies when ``cartesian_spline=True`` and ``curvy_knots=False``.
+        squishiness: See :py:func:`segmentation.guess_cartesian_edge`. Only
+            applies when ``cartesian_spline=True`` and ``curvy_knots=False``.
+        alignedness: See :py:func:`segmentation.guess_cartesian_edge`. Only
+            applies when ``cartesian_spline=True`` and ``curvy_knots=False``.
+        curvy_knots: If ``True``, knots for splines are placed at points of
+            estimated high curvature
+            (:py:func:`segmentation.curvy_knots_from_outline`); if ``False``,
+            knots are placed at equally spaced locations in radial space
+            (``cartesian_spline=False``) or approximately equally spaced
+            locations along the perimeter (``cartesian_spline=True``).
+        n_knots_fraction: Determines the number of knots as a fraction of edge
+            pixel count for the case where ``curvy_knots=True``
+            (see :py:func:`segmentation.curvy_knots_from_outline`).
+        pedge_thresh: Threshold on the edge probability scores of guessed
+            outlines, below which outlines will be discarded. Can be specified
+            group-wise. If any are specified as ``None``, then the decision on
+            which contained outline to keep in neighbouring groups is based
+            on area rather than edge score.
+        use_group_thresh: If ``True``, discard masks if their areas fall
+            outside the size limits as defined for the group they originated
+            in; if ``False``, keep all generated masks (except those below the
+            ``min_area`` threshold).
+        group_thresh_expansion: Expansion factor for the size limits defined
+            by each group. Can be specified group-wise.
+        containment_func: Specifies the metric to use to decide whether cells
+            in adjacent groups should be considered identical or not.
+            Currently the only valid choices are 'mask_containment' and
+            'mask_iou' (see :py:func:`segmentation.mask_containment` and
+            :py:func:`segmentation.mask_iou`).
+        containment_thresh: Threshold above which to consider cells in
+            adjacent groups identical.
+    """
+    cellgroups: Union[None, List[str], List[Tuple[str, ...]]] = None
+    edge_sub_dilations: Union[None, int, List[Optional[int]]] = 0
+    interior_threshold: Union[float, List[float]] = 0.5
+    nclosing: Union[int, List[int]] = 0
+    nopening: Union[int, List[int]] = 0
+    connectivity: Union[int, List[int]] = 2
+    min_area: Union[int, List[int]] = 10
+    fit_radial: bool = True
+    cartesian_spline: bool = False
+    bbox_padding: Union[int, List[int]] = 10
+    nrays_thresh_map: List[Tuple[float, int]] = [(5., 4), (20., 6)]
+    nrays_max: int = 8
+    n_trials: int = 10
+    pxperknot: float = 4.
+    squishiness: float = 0.1
+    alignedness: float = 10.
+    curvy_knots: bool = False
+    n_knots_fraction: float = 0.5
+    pedge_thresh: Union[None, float, List[Optional[float]]] = None
+    use_group_thresh: bool = False
+    group_thresh_expansion: Union[float, List[float]] = 0.
+    containment_func: Literal['mask_containment', 'mask_iou'] = 'mask_containment'
+    containment_thresh: float = 0.8
+
+
 # class ContainmentFunction:
 #     def __init__(self, threshold: float = .8):
 #         self.threshold = threshold
@@ -60,13 +189,11 @@ from .volume import volume
 # a = A()
 
 class Cell:
-    def __init__(self, area, mask, predicted_edge, border_rect,
-                 fit_radial=True):
+    def __init__(self, area, mask, predicted_edge, border_rect, params):
         self.area = area
         self.mask = mask
         self.predicted_edge = predicted_edge
-        self.fit_radial = fit_radial
-        self.border_rect = border_rect
+        self.params = params
 
         self._coords = None
         self._edge = None
@@ -79,15 +206,14 @@ class Cell:
                                                self.predicted_edge)[0]
         return self._edge_score
 
-    def _calculate_properties(self, fit_radial):
-        self._edge = binary_edge(self.mask)
-        if fit_radial:
-            rprop = single_region_prop(self.mask)
-            coords, edge = outline_to_radial(self.edge, rprop,
-                                             return_outline=True)
+    def _calculate_properties(self):
+        if self.params.fit_radial:
+            coords, edge = mask_to_knots(self.mask,
+                                         p_edge=self.predicted_edge,
+                                         **self.params._asdict())
             self.mask = ndimage.binary_fill_holes(edge)
         else:
-            edge = self._edge | (self.border_rect & self.mask)
+            edge = binary_edge(self.mask) | (self.border_rect & self.mask)
             coords = tuple()
         self._coords = coords
         self._edge = edge
@@ -95,13 +221,13 @@ class Cell:
     @property
     def edge(self):
         if self._edge is None:
-            self._calculate_properties(fit_radial=self.fit_radial)
+            self._calculate_properties()
         return self._edge
 
     @property
     def coords(self):
         if self._coords is None:
-            self._calculate_properties(fit_radial=self.fit_radial)
+            self._calculate_properties()
         return self._coords
 
     @property
@@ -172,53 +298,30 @@ class Target:
 
 class Group:
 
-    def __init__(self,
-                 targets,
-                 min_area=10.,
-                 use_thresh=False,
-                 thresh_expansion=0.,
-                 pedge_thresh=None,
-                 interior_threshold=0.5,
-                 n_closing=0,
-                 n_opening=0,
-                 connectivity=2,
-                 edge_sub_dilations=None):
-        # Parameter assignment
-        self.__connectivity = connectivity
-        self.__min_area = min_area
-        self.__use_thresh = use_thresh
-        self.__thresh_expansion = thresh_expansion
-        self.__pedge_thresh = pedge_thresh
-        self.__n_closing = n_closing
-        self.__n_opening = n_opening
-        self.__interior_threshold = interior_threshold
-        self.edge_sub_dilations = edge_sub_dilations
-
-        # Subgroup targets
+    def __init__(self, targets, params):
+
+        self.params = params
         self.targets = targets
+        self.cells = []
 
         # Computed members
         self._n_erode = None
         self._lower = None
         self._upper = None
 
-        # Dunno yet, probably functions
-        self.cells = []
-
     def _calculate_bounds(self):
         # get bounds
-        if self.__use_thresh:
+        if self.params.use_group_thresh:
             lower = min(target.definition.get('lower', 1.) for target in
                         self.targets)
             upper = max(target.definition.get('upper', float('inf')) for
                         target in self.targets)
-            expansion = self.__thresh_expansion * (lower
-                                                   if upper == float('inf')
-                                                   else upper - lower)
-            lower = max(lower - expansion, self.__min_area)
+            expansion = self.params.group_thresh_expansion * (
+                lower if upper == float('inf') else upper - lower)
+            lower = max(lower - expansion, self.params.min_area)
             upper += expansion
         else:
-            lower, upper = self.__min_area, float('inf')
+            lower, upper = self.params.min_area, float('inf')
         self._lower = lower
         self._upper = upper
 
@@ -244,7 +347,7 @@ class Group:
     @property
     def max_n_erode(self):
         max_n_erode = max(self.n_erode)
-        if self.edge_sub_dilations is not None and max_n_erode == 0:
+        if self.params.edge_sub_dilations is not None and max_n_erode == 0:
             return 1
         else:
             return max_n_erode  # Todo: add number of dilations?
@@ -255,7 +358,7 @@ class Group:
 
     @property
     def interior_threshold(self):
-        return self.__interior_threshold or 0.5
+        return self.params.interior_threshold or 0.5
 
     def prediction(self, pred, target_name, erode=False):
         predictions = [target.prediction(pred, target_name)
@@ -272,16 +375,22 @@ class Group:
                        for p, n_erode in zip(predictions, self.n_erode)]
         return predictions
 
-    def segment(self, pred, border_rect, fit_radial=False):
-        """
-        Obtain the cell masks, areas, edges, and coordiantes from the
-        prediction of the interior of the cell group.
-
-
-        :param fit_radial:
-        :param pred: The neural network's prediction
-        :param border_rect: A boolean array delimiting the border of the
-            prediction arrays
+    def segment(self, pred, border_rect):
+        """Find cell instances for this group from interior probability
+
+        Given raw CNN output ``pred``, extract the interior probability image
+        defined for this group, then apply threshold segmentation as per
+        :py:func:`segmentation.threshold_segmentation` and create a
+        :py:class:`Cell` for each identified mask. Depending on the
+        :py:class:`SegmentationParameters`, some filtering by mask area and/or
+        probability of the mask edge on the edge probability image may be
+        applied.
+        
+        Args:
+            pred (ndarray): The neural network's prediction with shape
+                (targets, X, Y, 1).
+            border_rect (ndarray): An image of shape (X, Y) which is False
+                except for a one-pixel border of True values.
         """
 
         # TODO this is the bit where the use of the overlap informations
@@ -289,27 +398,33 @@ class Group:
         pred_interior = self.prediction(pred, 'interior', erode=True)
         pred_edge = self.prediction(pred, 'edge', erode=False)
 
-        if self.edge_sub_dilations is not None:
+        if self.params.edge_sub_dilations is not None:
             pred_edge = iterative_dilation(pred_edge,
-                                           self.edge_sub_dilations)
+                                           self.params.edge_sub_dilations)
             pred_interior *= (1 - pred_edge)
 
-        masks_areas = [(m, a) for m, a in thresh_seg(
-            pred_interior, interior_threshold=self.__interior_threshold or 0.5,
-            nclosing=self.__n_closing or 0,
-            nopening=self.__n_opening or 0,
-            ndilate=self.max_n_erode,
-            return_area=True, connectivity=self.__connectivity)
-                       if self.lower <= a < self.upper]
-        self.cells = [Cell(a, m, pred_edge, border_rect, fit_radial=fit_radial)
-                      for m, a in masks_areas]
+        masks_areas = [
+            (m, a) for m, a in threshold_segmentation(
+                pred_interior,
+                interior_threshold=self.params.interior_threshold or 0.5,
+                nclosing=self.params.nclosing or 0,
+                nopening=self.params.nopening or 0,
+                ndilate=self.max_n_erode,
+                return_area=True,
+                connectivity=self.params.connectivity)
+            if self.lower <= a < self.upper
+        ]
+        self.cells = [
+            Cell(a, m, pred_edge, border_rect, self.params)
+            for m, a in masks_areas
+        ]
         # Remove cells that do not exceed the p_edge threshold
-        if self.__pedge_thresh is not None:
+        if self.params.pedge_thresh is not None:
             self.cells = [cell for cell in self.cells
-                          if cell.edge_score > self.__pedge_thresh]
+                          if cell.edge_score > self.params.pedge_thresh]
 
 
-def broadcast_arg(arg: Union[Iterable, Any],
+def broadcast_arg(arg: Union[Sequence, Any],
                   argname: Optional[str] = None,
                   n_groups: Optional[int] = 3):
     if argname is None:
@@ -323,98 +438,82 @@ def broadcast_arg(arg: Union[Iterable, Any],
         return [arg] * n_groups
 
 
-class MorphSegGrouped:
-    def __init__(self, flattener, cellgroups=None,
-                 interior_threshold=0.5, nclosing=0, nopening=0,
-                 connectivity=2,
-                 min_area=10, pedge_thresh=None, fit_radial=False,
-                 use_group_thresh=False, group_thresh_expansion=0.,
-                 edge_sub_dilations=None,
-                 containment_thresh=0.8, containment_func=mask_containment,
-                 return_masks=False, return_coords=False, return_volume=False):
-        """
 
-        :param flattener:
-        :param cellgroups:
-        :param interior_threshold:
-        :param nclosing:
-        :param nopening:
-        :param connectivity:
-        :param min_area:
-        :param pedge_thresh:
-        :param fit_radial:
-        :param use_group_thresh:
-        :param group_thresh_expansion:
-        :param edge_sub_dilations:
-        :param containment_thresh:
-        :param containment_func:
-        :param return_masks:
-        :param return_coords:
-        """
-        # Todo: assertions about valid options
-        #  (e.g. 0 < interior_threshold < 1)
-        # Assign options and parameters
-        self.pedge_thresh = pedge_thresh
-        self.fit_radial = fit_radial
-        self.containment_thresh = containment_thresh
-        self.containment_func = containment_func
+class MorphSegGrouped:
+    """Provides an instance segmentation method given CNN target definitions.
+
+    Args:
+        flattener (preprocessing.SegmentationFlattening): Target definitions
+            for the CNN predictions that will be used as input to
+            :py:meth:`segment`.
+        params (SegmentationParameters): Parameters used for segmentation. See
+            :py:class:`SegmentationParameters` for more details.
+        return_masks (bool): Whether masks should be returned by
+            :py:meth:`segment`.
+        return_coords (bool): Whether knot coordinates should be returned by
+            :py:meth:`segment`.
+
+    Attributes:
+        groups (List[Group]): Segmentation processing/results for each group.
+        flattener (preprocessing.SegmentationFlattening): As for the
+            ``flattener`` argument.
+        params (SegmentationParameters): Parameters used for segmentation
+            including any broadcasting for the number of cell groups.
+        return_masks: Whether masks should be returned by :py:meth:`segment`.
+        return_coords: Whether knot coordinates should be returned by
+            :py:meth:`segment`.
+    """
+
+    def __init__(self, flattener, params=SegmentationParameters(),
+                 return_masks=False, return_coords=False):
+        self.flattener = flattener
         self.return_masks = return_masks
         self.return_coords = return_coords
-        self.return_volume = return_volume
 
-        self.flattener = flattener
+        # TODO: assertions about valid options
+        #  (e.g. 0 < interior_threshold < 1)
 
-        # Define  group parameters
+        # Define group parameters
+        cellgroups = params.cellgroups
         if cellgroups is None:
-            cellgroups = ['large', 'medium', 'small']
+            cellgroups = [t.group for t in flattener.targets
+                          if t.prop in {'interior', 'filled'}]
         cellgroups = [(g,) if isinstance(g, str) else g for g in cellgroups]
         n_groups = len(cellgroups)
-        interior_threshold = broadcast_arg(interior_threshold,
-                                           'interior_threshold', n_groups)
-        n_closing = broadcast_arg(nclosing, 'nclosing', n_groups)
-        n_opening = broadcast_arg(nopening, 'nopening', n_groups)
-        min_area = broadcast_arg(min_area, 'min_area', n_groups)
-        connectivity = broadcast_arg(connectivity, 'connectivity', n_groups)
-        pedge_thresh = broadcast_arg(pedge_thresh, 'pedge_thresh', n_groups)
-        group_thresh_expansion = broadcast_arg(group_thresh_expansion,
-                                               'group_thresh_expansion',
-                                               n_groups)
-        edge_sub_dilations = broadcast_arg(edge_sub_dilations,
-                                           'edge_substraction_dilations',
-                                           n_groups)
+
+        # Broadcast relevant parameters
+        params = params._replace(**{
+            p: broadcast_arg(getattr(params, p), p, n_groups)
+            for p in BROADCASTABLE_PARAMS
+        })
 
         # Minimum area must be larger than 1 to avoid generating cells with
         # no size:
-        min_area = [np.max([a, 1]) for a in min_area]
+        params = params._replace(
+            min_area=[np.max([a, 1]) for a in params.min_area])
+
+        self.params = params
+
+        self.containment_func = getattr(segmentation, params.containment_func)
+        self.containment_thresh = params.containment_thresh
 
         # Initialize the different groups and their targets
         self.groups = []
         for i, target_names in enumerate(cellgroups):
             targets = [Target(name, flattener) for name in target_names]
-            self.groups.append(Group(targets, min_area=min_area[i],
-                                     use_thresh=use_group_thresh,
-                                     thresh_expansion=group_thresh_expansion[i],
-                                     pedge_thresh=pedge_thresh[i],
-                                     interior_threshold=interior_threshold[i],
-                                     n_closing=n_closing[i],
-                                     n_opening=n_opening[i],
-                                     connectivity=connectivity[i],
-                                     edge_sub_dilations=edge_sub_dilations[i]))
-        self.group_segs = None
+            self.groups.append(Group(targets, params=params._replace(**{
+                p: getattr(params, p)[i] for p in BROADCASTABLE_PARAMS
+            })))
 
     # Todo: This is ideally the form of the input argument
     def contains(self, a, b):
         return self.containment_func(a, b) > self.containment_thresh
 
     def remove_duplicates(self):
+        """Resolve any cells duplicated across adjacent groups.
         """
-        Resolve any cells duplicated across adjacent groups:
 
-        :param group_segs:
-        :return: The group segmentations with duplicates removed
-        """
-
-        if self.pedge_thresh is None:
+        if all([t is None for t in self.params.pedge_thresh]):
             def accessor(cell):
                 return cell.area
         else:
@@ -444,7 +543,7 @@ class MorphSegGrouped:
     def extract_edges(self, pred, shape, refine_outlines, return_volume):
         masks = [[]]
         if refine_outlines:
-            if not self.fit_radial:
+            if not self.params.fit_radial:
                 raise BadParam(
                     '"refine_outlines" requires "fit_radial" to have been specified'
                 )
@@ -456,11 +555,14 @@ class MorphSegGrouped:
 
             if predicted_edges:
                 coords = list(itertools.chain.from_iterable(
-                    refine_radial_grouped(grouped_coords,
-                                          predicted_edges)))
+                    refine_radial_grouped(
+                        grouped_coords, predicted_edges,
+                        cartesian_spline=self.params.cartesian_spline)))
             else:
                 coords = tuple()
-            edges = [draw_radial(radii, angles, centre, shape)
+            edges = [draw_radial(
+                        radii, angles, centre, shape,
+                        cartesian_spline=self.params.cartesian_spline)
                      for centre, radii, angles in coords]
             if self.return_masks:
                 masks = [ndimage.binary_fill_holes(e) for e in edges]
@@ -480,21 +582,23 @@ class MorphSegGrouped:
                            for cell in group.cells]
 
             if len(outputs) > 0:
-                return zip(*outputs)
+                return (list(o) for o in zip(*outputs))
             else:
                 return 4 * [[]] if return_volume else 3 * [[]]
 
     def segment(self, pred, refine_outlines=False, return_volume=False):
-        """
-        Take the output of the neural network and turn it into an instance
-        segmentation output.
-
-        :param pred: list of prediction images (ndarray with shape (x, y))
-        matching `self.flattener.names()`
-        :return: a list of boolean edge images (ndarray shape (x, y)), one for
-        each cell identified. If `return_masks` and/or `return_coords` are
-        true, the output will be a tuple of edge images, filled masks, and/or
-        radial coordinates.
+        """Returns segmented instances based on the output of the CNN.
+
+        Args:
+            pred: A list of prediction images (ndarray with shape (x, y))
+                matching the names of :py:attr:`flattener` (see
+                :py:meth:`preprocessing.SegmentationFlattening.names`).
+
+        Returns:
+            A list of boolean edge images (ndarray shape (x, y)), one for each
+            cell identified. If ``return_masks`` and/or ``return_coords`` are
+            true, the output will be a tuple of edge images, filled masks,
+            and/or radial coordinates.
         """
         if len(pred) != len(self.flattener.names()):
             raise BadParam(
@@ -506,7 +610,7 @@ class MorphSegGrouped:
             constant_values=True)
 
         for group in self.groups:
-            group.segment(pred, border_rect, fit_radial=self.fit_radial)
+            group.segment(pred, border_rect)
 
         # Remove cells that are duplicated in several groups
         self.remove_duplicates()
diff --git a/python/baby/performance.py b/python/baby/performance.py
index 888ed3de8501e102546e7246afc2979a77a2e009..d7a1005ef615e086468786f7bbe1a67ca3a32ef5 100644
--- a/python/baby/performance.py
+++ b/python/baby/performance.py
@@ -1,12 +1,14 @@
 # If you publish results that make use of this software or the Birth Annotator
 # for Budding Yeast algorithm, please cite:
-# Julian M J Pietsch, Alán Muñoz, Diane Adjavon, Ivan B N Clark, Peter S
-# Swain, 2021, Birth Annotator for Budding Yeast (in preparation).
+# Pietsch, J.M.J., Muñoz, A.F., Adjavon, D.-Y.A., Farquhar, I., Clark, I.B.N.,
+# and Swain, P.S. (2023). Determining growth rates from bright-field images of
+# budding cells through identifying overlaps. eLife. 12:e79812.
+# https://doi.org/10.7554/eLife.79812
 # 
 # 
 # The MIT License (MIT)
 # 
-# Copyright (c) Julian Pietsch, Alán Muñoz and Diane Adjavon 2021
+# Copyright (c) Julian Pietsch, Alán Muñoz and Diane Adjavon 2023
 # 
 # Permission is hereby granted, free of charge, to any person obtaining a copy
 # of this software and associated documentation files (the "Software"), to
diff --git a/python/baby/postprocessing.py b/python/baby/postprocessing.py
index 65225788f4b1088550825f30782cec6b944d651c..b4d52f96740b927f2a46201a9859ef94946c950c 100644
--- a/python/baby/postprocessing.py
+++ b/python/baby/postprocessing.py
@@ -1,12 +1,14 @@
 # If you publish results that make use of this software or the Birth Annotator
 # for Budding Yeast algorithm, please cite:
-# Julian M J Pietsch, Alán Muñoz, Diane Adjavon, Ivan B N Clark, Peter S
-# Swain, 2021, Birth Annotator for Budding Yeast (in preparation).
+# Pietsch, J.M.J., Muñoz, A.F., Adjavon, D.-Y.A., Farquhar, I., Clark, I.B.N.,
+# and Swain, P.S. (2023). Determining growth rates from bright-field images of
+# budding cells through identifying overlaps. eLife. 12:e79812.
+# https://doi.org/10.7554/eLife.79812
 # 
 # 
 # The MIT License (MIT)
 # 
-# Copyright (c) Julian Pietsch, Alán Muñoz and Diane Adjavon 2021
+# Copyright (c) Julian Pietsch, Alán Muñoz and Diane Adjavon 2023
 # 
 # Permission is hereby granted, free of charge, to any person obtaining a copy
 # of this software and associated documentation files (the "Software"), to
@@ -243,8 +245,8 @@ def fit_gr_data(filename, run_moth=True, run_dght=True, split_at_birth=False,
         fvals = ['f', 'df', 'ddf']
         for g in tqdm(data['D']):
             grp = data['D'][g]
-            allT = grp['times']
-            allV = grp['vol']
+            allT = grp['times'][()]
+            allV = grp['vol'][()]
             if log_volume:
                 allV = np.log(allV)
 
diff --git a/python/baby/preprocessing.py b/python/baby/preprocessing.py
index 3626e141463f36848e732d05a99972358f7b1fba..1e3ff1e31748c0244c1393093c4fe3c5c827e60f 100644
--- a/python/baby/preprocessing.py
+++ b/python/baby/preprocessing.py
@@ -1,12 +1,14 @@
 # If you publish results that make use of this software or the Birth Annotator
 # for Budding Yeast algorithm, please cite:
-# Julian M J Pietsch, Alán Muñoz, Diane Adjavon, Ivan B N Clark, Peter S
-# Swain, 2021, Birth Annotator for Budding Yeast (in preparation).
+# Pietsch, J.M.J., Muñoz, A.F., Adjavon, D.-Y.A., Farquhar, I., Clark, I.B.N.,
+# and Swain, P.S. (2023). Determining growth rates from bright-field images of
+# budding cells through identifying overlaps. eLife. 12:e79812.
+# https://doi.org/10.7554/eLife.79812
 # 
 # 
 # The MIT License (MIT)
 # 
-# Copyright (c) Julian Pietsch, Alán Muñoz and Diane Adjavon 2021
+# Copyright (c) Julian Pietsch, Alán Muñoz and Diane Adjavon 2023
 # 
 # Permission is hereby granted, free of charge, to any person obtaining a copy
 # of this software and associated documentation files (the "Software"), to
@@ -25,25 +27,25 @@
 # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
 # FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS
 # IN THE SOFTWARE.
+import json
+from itertools import chain
+from typing import NamedTuple, Union
 import numpy as np
 from skimage import img_as_float
 from scipy.ndimage import (
-    generic_filter, minimum_filter, maximum_filter,
+    convolve, minimum_filter, maximum_filter,
     binary_fill_holes, binary_dilation, binary_erosion, binary_closing
 )
 from skimage.measure import regionprops
 from skimage.draw import polygon
 from skimage.morphology import diamond
-from itertools import chain
-from typing import NamedTuple, Union
-import json
 
 from .segmentation import binary_edge
 from .utils import EncodableNamedTuple, jsonify, as_python_object
 
 # Depth-wise structuring elements for square or full connectivity
 dwsquareconn = diamond(1)[..., None]
-dwfullconn = np.ones((3,3,1), dtype='uint8')
+dwfullconn = np.ones((3,3,1), dtype=np.uint8)
 
 def raw_norm(img, info):
     """Keep raw values but scale non-float images to [0, 1]"""
@@ -73,79 +75,38 @@ def robust_norm(img, info, q_outliers=0.04):
     return (img - mid) * (1 - q_outliers) / imrange
 
 
-def seg_norm(img, info):
-    img = img > 0
-    # Connect any 1-pixel gaps:
-    imconn = generic_filter(img.astype('int'), np.sum, footprint=dwfullconn)
-    imconn = binary_erosion(imconn > 1, dwsquareconn) | img
-    return imconn, info
-
-
-def segoutline_flattening(fill_stack, info):
-    """Returns a stack of images in order:
-        edge: edges for all cells flattened into a single layer
-        filled: filled area for all cells flattened into a single layer
-        interiors: filled area excluding the edges
-        overlap: regions occupied by more than one cell
-        budring: edges located between mother and bud centres
-        bud: buds that are are smaller than a fractional threshold of the mother
-    """
-    imsize = fill_stack.shape[0:2]
-    ncells = fill_stack.shape[2]
-    imout = np.zeros(imsize+(6,), dtype='bool')
-
-    edge_stack = binary_edge(fill_stack, dwsquareconn)
-
-    edge_flat = np.any(edge_stack, axis=2)
-    fill_flat = np.any(fill_stack, axis=2)
-    overlap = np.sum(fill_stack, axis=2)>1
-
-    imout[:,:,0] = edge_flat  # edge
-    imout[:,:,1] = fill_flat  # filled
-    imout[:,:,2] = fill_flat & ~edge_flat & ~overlap  # interiors
-    imout[:,:,3] = overlap  # overlap
+def robust_norm_dw(img, info, q_outliers=0.04):
+    """Robust normalisation of intensity to [-1,1] applied depth-wise"""
+    img = img.copy().astype('float')
+    hq = q_outliers / 2
+    for i in range(img.shape[2]):
+        low, mid, high = np.quantile(img[:,:,i], (hq, 0.5, 1-hq))
+        imrange = high - low
+        imrange = 1 if imrange == 0 else imrange
+        img[:,:,i] = (img[:,:,i] - mid) * (1 - q_outliers) / imrange
+    return img
 
-    bud_pairs = [(m, np.nonzero(np.array(info.get('cellLabels', []))==b)[0][0])
-                 for m, b in enumerate(info.get('buds', []) or []) if b>0]
 
-    cell_info = [
-        regionprops(fill_stack[:,:,i].astype('int32'), coordinates='rc')
-        for i in range(fill_stack.shape[2])
-    ]
-    cell_info = [p[0] if len(p)>0 else None for p in cell_info]
+def connect_pixel_gaps(img):
+    """Connects any 1-pixel gaps in an edge image"""
+    # Connect any 1-pixel gaps:
+    imconn = convolve(img.astype(np.uint8), dwfullconn, mode='constant')
+    imconn = binary_erosion(imconn > 1, dwsquareconn) | img
+    return imconn
+    
 
-    for m, b in bud_pairs:
-        if cell_info[m] is None or cell_info[b] is None:
-            # Label possible transformed outside field of view by augmentation
-            continue
-        if m == b:
-            raise Exception('a mother cannot be its own bud')
-        m_centre = np.array(cell_info[m].centroid).T
-        b_centre = np.array(cell_info[b].centroid).T
-        r_width = cell_info[b].minor_axis_length*0.25
-        r_hvec = b_centre-m_centre
-        r_wvec = np.matmul(np.array([[0,-1],[1,0]]), r_hvec)
-        if np.linalg.norm(r_wvec) == 0:
-            raise Exception('mother and bud have coincident centres')
-        r_wvec = r_width*r_wvec/np.linalg.norm(r_wvec)
-        r_points = np.zeros((2,4))
-        r_points[:,0] = m_centre-0.5*r_wvec
-        r_points[:,1] = r_points[:,0] + r_hvec
-        r_points[:,2] = r_points[:,1] + r_wvec
-        r_points[:,3] = r_points[:,2] - r_hvec
-        r_inds, c_inds = polygon(r_points[0,:], r_points[1,:], imsize)
-        r_im = np.zeros(fill_stack.shape[0:2], dtype='bool')
-        r_im[r_inds, c_inds] = 1
+def seg_norm(img, info):
+    img = img > 0
+    return connect_pixel_gaps(img), info
 
-        # Bud junction
-        bj = (edge_stack[:,:,m] | edge_stack[:,:,b]) & r_im
-        imout[:,:,4] |= binary_dilation(binary_closing(bj))
 
-        # Smaller buds
-        if (cell_info[b].area / cell_info[m].area) < 0.7:
-            imout[:,:,5] |= fill_stack[:,:,b]
+def flattener_norm_func(flattener):
+    def norm_func(img, info):
+        img, info = seg_norm(img, info)
+        img = binary_fill_holes(img, dwsquareconn)
+        return flattener(img, info)
 
-    return imout
+    return norm_func
 
 
 @EncodableNamedTuple
@@ -265,6 +226,16 @@ class SegmentationFlattening(object):
     def names(self):
         return tuple(t.name for t in self.targets)
 
+    def group_names(self, exclude_budonly=False):
+        groups = self.groupdef.items()
+        if exclude_budonly:
+            groups = [(k, g) for k, g in groups if not g.budonly]
+        sort_lower = sorted(
+            groups, key=lambda i: i[1].lower, reverse=True)
+        sort_upper = sorted(
+            sort_lower, key=lambda i: i[1].upper, reverse=True)
+        return next(zip(*sort_upper))
+
     def getGroupTargets(self, group, propfilter=None):
         assert group in self.groupdef, \
             '"{}" group does not exist'.format(group)
@@ -438,10 +409,77 @@ class SegmentationFlattening(object):
         return np.dstack(targetims)
 
 
-def flattener_norm_func(flattener):
-    def norm_func(img, info):
-        img, info = seg_norm(img, info)
-        img = binary_fill_holes(img, dwsquareconn)
-        return flattener(img, info)
+##################
+### DEPRECATED ###
+##################
 
-    return norm_func
+
+def segoutline_flattening(fill_stack, info):
+    """Returns a stack of images in order.
+
+    DEPRECATED
+
+    The layers of the returned stack are:
+        edge: edges for all cells flattened into a single layer
+        filled: filled area for all cells flattened into a single layer
+        interiors: filled area excluding the edges
+        overlap: regions occupied by more than one cell
+        budring: edges located between mother and bud centres
+        bud: buds that are smaller than a fractional threshold of the mother
+    """
+    imsize = fill_stack.shape[0:2]
+    ncells = fill_stack.shape[2]
+    imout = np.zeros(imsize+(6,), dtype='bool')
+
+    edge_stack = binary_edge(fill_stack, dwsquareconn)
+
+    edge_flat = np.any(edge_stack, axis=2)
+    fill_flat = np.any(fill_stack, axis=2)
+    overlap = np.sum(fill_stack, axis=2)>1
+
+    imout[:,:,0] = edge_flat  # edge
+    imout[:,:,1] = fill_flat  # filled
+    imout[:,:,2] = fill_flat & ~edge_flat & ~overlap  # interiors
+    imout[:,:,3] = overlap  # overlap
+
+    bud_pairs = [(m, np.nonzero(np.array(info.get('cellLabels', []))==b)[0][0])
+                 for m, b in enumerate(info.get('buds', []) or []) if b>0]
+
+    cell_info = [
+        regionprops(fill_stack[:,:,i].astype('int32'), coordinates='rc')
+        for i in range(fill_stack.shape[2])
+    ]
+    cell_info = [p[0] if len(p)>0 else None for p in cell_info]
+
+    for m, b in bud_pairs:
+        if cell_info[m] is None or cell_info[b] is None:
+            # Label possible transformed outside field of view by augmentation
+            continue
+        if m == b:
+            raise Exception('a mother cannot be its own bud')
+        m_centre = np.array(cell_info[m].centroid).T
+        b_centre = np.array(cell_info[b].centroid).T
+        r_width = cell_info[b].minor_axis_length*0.25
+        r_hvec = b_centre-m_centre
+        r_wvec = np.matmul(np.array([[0,-1],[1,0]]), r_hvec)
+        if np.linalg.norm(r_wvec) == 0:
+            raise Exception('mother and bud have coincident centres')
+        r_wvec = r_width*r_wvec/np.linalg.norm(r_wvec)
+        r_points = np.zeros((2,4))
+        r_points[:,0] = m_centre-0.5*r_wvec
+        r_points[:,1] = r_points[:,0] + r_hvec
+        r_points[:,2] = r_points[:,1] + r_wvec
+        r_points[:,3] = r_points[:,2] - r_hvec
+        r_inds, c_inds = polygon(r_points[0,:], r_points[1,:], imsize)
+        r_im = np.zeros(fill_stack.shape[0:2], dtype='bool')
+        r_im[r_inds, c_inds] = 1
+
+        # Bud junction
+        bj = (edge_stack[:,:,m] | edge_stack[:,:,b]) & r_im
+        imout[:,:,4] |= binary_dilation(binary_closing(bj))
+
+        # Smaller buds
+        if (cell_info[b].area / cell_info[m].area) < 0.7:
+            imout[:,:,5] |= fill_stack[:,:,b]
+
+    return imout
diff --git a/python/baby/profile.py b/python/baby/profile.py
index d59cec1b0b1820618f80280efb8e368590d68799..4de4ea07ebb6636ebd8df407a10478e578b4fbff 100644
--- a/python/baby/profile.py
+++ b/python/baby/profile.py
@@ -1,12 +1,14 @@
 # If you publish results that make use of this software or the Birth Annotator
 # for Budding Yeast algorithm, please cite:
-# Julian M J Pietsch, Alán Muñoz, Diane Adjavon, Ivan B N Clark, Peter S
-# Swain, 2021, Birth Annotator for Budding Yeast (in preparation).
+# Pietsch, J.M.J., Muñoz, A.F., Adjavon, D.-Y.A., Farquhar, I., Clark, I.B.N.,
+# and Swain, P.S. (2023). Determining growth rates from bright-field images of
+# budding cells through identifying overlaps. eLife. 12:e79812.
+# https://doi.org/10.7554/eLife.79812
 # 
 # 
 # The MIT License (MIT)
 # 
-# Copyright (c) Julian Pietsch, Alán Muñoz and Diane Adjavon 2021
+# Copyright (c) Julian Pietsch, Alán Muñoz and Diane Adjavon 2023
 # 
 # Permission is hereby granted, free of charge, to any person obtaining a copy
 # of this software and associated documentation files (the "Software"), to
diff --git a/python/baby/seg_trainer.py b/python/baby/seg_trainer.py
deleted file mode 100644
index 91fc368ad5a6bf2a7d3d4bf2496d74f9ef264af6..0000000000000000000000000000000000000000
--- a/python/baby/seg_trainer.py
+++ /dev/null
@@ -1,449 +0,0 @@
-# If you publish results that make use of this software or the Birth Annotator
-# for Budding Yeast algorithm, please cite:
-# Julian M J Pietsch, Alán Muñoz, Diane Adjavon, Ivan B N Clark, Peter S
-# Swain, 2021, Birth Annotator for Budding Yeast (in preparation).
-# 
-# 
-# The MIT License (MIT)
-# 
-# Copyright (c) Julian Pietsch, Alán Muñoz and Diane Adjavon 2021
-# 
-# Permission is hereby granted, free of charge, to any person obtaining a copy
-# of this software and associated documentation files (the "Software"), to
-# deal in the Software without restriction, including without limitation the
-# rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
-# sell copies of the Software, and to permit persons to whom the Software is
-# furnished to do so, subject to the following conditions:
-# 
-# The above copyright notice and this permission notice shall be included in
-# all copies or substantial portions of the Software.
-# 
-# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
-# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
-# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
-# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
-# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
-# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS
-# IN THE SOFTWARE.
-from math import floor, log10
-from itertools import combinations, product, chain, repeat
-from typing import NamedTuple, Union, Tuple, Any
-import numpy as np
-np.seterr(all='ignore')
-import pandas as pd
-
-from baby.segmentation import mask_containment
-from baby.morph_thresh_seg import MorphSegGrouped
-from baby.performance import calc_IoUs, best_IoU
-from baby.brain import default_params
-from baby.errors import BadProcess, BadParam
-
-BASIC_PARAMS = {
-    'interior_threshold', 'nclosing', 'nopening', 'connectivity',
-    'edge_sub_dilations'
-}
-
-round_to_n = lambda x, n: round(x, -int(floor(log10(x))) + (n - 1))
-
-
-class Score(NamedTuple):
-    precision: float
-    recall: float
-    F1: float
-    F0_5: float
-    F2: float
-    meanIoU: float
-
-
-class SegFilterParamOptim:
-    """
-    # TODO What does this class do
-        * What are the parameters and what do they mean
-        * What are the defaults, what are the ranges/admissible options?
-    :param flattener:
-    :param basic_params:
-    :param IoU_thresh:
-    :param scoring:
-    :param nbootstraps:
-    :param bootstrap_frac:
-    """
-    def __init__(self,
-                 flattener,
-                 basic_params={},
-                 IoU_thresh=0.5,
-                 scoring='F0_5',
-                 nbootstraps=10,
-                 bootstrap_frac=0.9):
-
-        self.IoU_thresh = IoU_thresh
-        self.scoring = scoring
-        self.nbootstraps = nbootstraps
-        self.bootstrap_frac = bootstrap_frac
-        self._basic_params = default_params.copy()
-        self._basic_params.update(
-            {k: v for k, v in basic_params.items() if k in BASIC_PARAMS})
-        self._basic_params.update({
-            'fit_radial': True,
-            'min_area': 1,
-            'pedge_thresh': None,
-            'use_group_thresh': False
-        })
-        self.segmenter = MorphSegGrouped(flattener,
-                                         return_masks=True,
-                                         **self._basic_params)
-
-        self.group_info = []
-        for g, group in enumerate(self.segrps):
-            lower = min(
-                target.definition.get('lower', 1.)
-                for target in group.targets)
-            upper = max(
-                target.definition.get('upper', float('inf'))
-                for target in group.targets)
-            grange = lower if upper == float('inf') else upper - lower
-            self.group_info.append((g, lower, upper, grange))
-
-    @property
-    def scoring(self):
-        """ The scoring method used during evaluation of the segmentation.
-        Accepted values are: # TODO define the scoring metrics
-        * precision:
-        * recall:
-        * F1:
-        * F0_5:
-        * F2:
-        * meanIoU:
-        :return: str scoring method
-        """
-        return self._scoring
-
-    @scoring.setter
-    def scoring(self, val):
-        if val not in Score._fields:
-            raise BadParam('Specified scoring metric not available')
-        self._scoring = val
-
-    @property
-    def basic_params(self):
-        return self._basic_params
-
-    @property
-    def segrps(self):
-        return self.segmenter.groups
-
-    @property
-    def stat_table(self):
-        val = getattr(self, '_stat_table', None)
-        if val is None:
-            raise BadProcess('"generate_stat_table" has not been run')
-        return val
-
-    @property
-    def stat_table_bootstraps(self):
-        val = getattr(self, '_stat_table_bootstraps', None)
-        if val is None:
-            raise BadProcess('"generate_stat_table" has not been run')
-        return val
-
-    @property
-    def truth(self):
-        val = getattr(self, '_nPs', None)
-        if val is None:
-            raise BadProcess('"generate_stat_table" has not been run')
-        return val
-
-    @property
-    def truth_bootstraps(self):
-        val = getattr(self, '_nPs_bootstraps', None)
-        if val is None:
-            raise BadProcess('"generate_stat_table" has not been run')
-        return val
-
-    @property
-    def opt_params(self):
-        val = getattr(self, '_opt_params', None)
-        if val is None:
-            raise BadProcess('"fit_filter_params" has not been run')
-        return val
-
-    @property
-    def opt_score(self):
-        val = getattr(self, '_opt_score', None)
-        if val is None:
-            raise BadProcess('"fit_filter_params" has not been run')
-        return val
-
-    def generate_stat_table(self, example_gen):
-        rows_truth = []
-        rows = []
-        for s, seg_ex in enumerate(example_gen):
-            ncells = len(seg_ex.target) if seg_ex.target.any() else 0
-            rows_truth.append((s, ncells))
-
-            # Perform within-group segmentation
-            shape = np.squeeze(seg_ex.pred[0]).shape
-            border_rect = np.pad(np.zeros(tuple(x - 2 for x in shape),
-                                          dtype='bool'),
-                                 pad_width=1,
-                                 mode='constant',
-                                 constant_values=True)
-            masks = []
-            for group in self.segrps:
-                group.segment(seg_ex.pred, border_rect, fit_radial=True)
-                for cell in group.cells:
-                    masks.append(cell.mask)
-
-            # Calculate containment scores across groups
-            contained_cells = {}
-            paired_groups = zip(self.segrps, self.segrps[1:])
-            for g, (lower_group, upper_group) in enumerate(paired_groups):
-                for l, lower in enumerate(lower_group.cells):
-                    for u, upper in enumerate(upper_group.cells):
-                        containment = mask_containment(lower.mask, upper.mask)
-                        if containment > 0:
-                            if lower.edge_score > upper.edge_score:
-                                contained_cells[(g + 1, u)] = containment
-                            else:
-                                contained_cells[(g, l)] = containment
-
-            if ncells > 0:
-                IoUs = calc_IoUs(seg_ex.target, masks, fill_holes=False)
-                max_IoU = IoUs.max(axis=0)
-                assignments = IoUs.argmax(axis=0)
-                _, best_assignments = best_IoU(IoUs.T)
-            else:
-                max_IoU = np.zeros(len(masks))
-                assignments = np.zeros(len(masks), dtype=np.uint16)
-                best_assignments = -np.ones(len(masks), dtype=np.int32)
-            ind = 0
-            for g, group in enumerate(self.segrps):
-                for c, cell in enumerate(group.cells):
-                    rows.append((s, g, c, cell.area, cell.edge_score,
-                                 contained_cells.get((g, c),
-                                                     0.), assignments[ind],
-                                 max_IoU[ind], best_assignments[ind]))
-                    ind += 1
-
-        df_truth = pd.DataFrame(rows_truth, columns=['example', 'ncells'])
-        df_truth = df_truth.set_index('example')
-        self._nPs = df_truth.ncells
-
-        dtypes = [('example', np.uint16),
-                  ('group', np.uint8),
-                  ('cell', np.uint16),
-                  ('area', np.uint16),
-                  ('p_edge', np.float64),
-                  ('containment', np.float64),
-                  ('assignments', np.uint16),
-                  ('max_IoU', np.float64),
-                  ('best_assignments', np.int32)]
-        df = pd.DataFrame(np.array(rows, dtype=dtypes))
-
-        df['is_best'] = ((df.best_assignments >= 0) &
-                         (df.max_IoU >= self.IoU_thresh))
-        df['eid'] = df.example
-        df['uid'] = tuple(zip(df.example, df.assignments))
-
-        # Generate a set of bootstrapping filters over 90% of the examples
-        examples = list(set(df_truth.index.values))
-        nperboot = np.round(self.bootstrap_frac * len(examples)).astype(int)
-        bootstraps = [
-            np.random.choice(examples, nperboot, replace=True)
-            for _ in range(self.nbootstraps)
-        ]
-        self._nPs_bootstraps = [df_truth.loc[b].sum() for b in bootstraps]
-        # Limit bootstrap examples to those present in segmentation output
-        bootstraps = [b[np.isin(b, df.example)] for b in bootstraps]
-        df.set_index('example', drop=False, inplace=True)
-        example_counts = df.example.value_counts()
-        self._stat_table_bootstraps = []
-        for b in bootstraps:
-            df_boot = df.loc[b]
-            # Renumber examples to handle the case of duplicated examples in
-            # the bootstrap:
-            df_boot['eid'] = tuple(
-                chain(*(repeat(i, example_counts.loc[e])
-                        for i, e in enumerate(b))))
-            df_boot['uid'] = tuple(zip(df_boot.eid, df_boot.assignments))
-            df_boot.set_index('uid', inplace=True)
-            self._stat_table_bootstraps.append(df_boot)
-
-        df.set_index('uid', inplace=True)
-        self._stat_table = df
-
-    def filter_trial(self,
-                     pedge_thresh,
-                     group_thresh_expansion,
-                     containment_thresh,
-                     min_area,
-                     bootstrap=True,
-                     return_stderr=False):
-        if bootstrap:
-            dfs = self.stat_table_bootstraps
-            truths = self.truth_bootstraps
-        else:
-            dfs = [self.stat_table]
-            truths = [self.truth.sum()]
-
-        uidcols = ['eid', 'assignments']
-        score_boots = []
-        for df, nT in zip(dfs, truths):
-            rejects = ((df.containment > containment_thresh) |
-                       (df.area < min_area))
-            for t_pe, g_ex, (g, l, u, gr) in zip(pedge_thresh,
-                                                 group_thresh_expansion,
-                                                 self.group_info):
-                g_ex = g_ex * gr
-                l = max(l - g_ex, 1)
-                u = u + g_ex
-                rejects |= (df.group == g) & ((df.p_edge < t_pe) |
-                                              (df.area < l) | (df.area > u))
-            TP_mask = (~rejects) & (df.max_IoU >= self.IoU_thresh)
-
-            # # TODO compare speed of:
-            # TPs_IoU = df.loc[TP_mask].groupby(uidcols).max_IoU.max()
-            # # with speed of:
-            TPs_IoU = []
-            current_eid = 0
-            asgn = {}
-            for m, e, a, iou in zip(TP_mask, df.eid, df.assignments,
-                    df.max_IoU):
-                if e != current_eid:
-                    TPs_IoU.extend(asgn.values())
-                    current_eid = e
-                    asgn = {}
-                if not m:
-                    continue
-                asgn[a] = max(iou, asgn.get(a, 0))
-            TPs_IoU.extend(asgn.values())
-            # # END TODO
-
-            nPs = np.sum(~rejects)
-            nTPs = len(TPs_IoU)
-            nFPs = nPs - nTPs
-            nFNs = nT - nTPs
-            precision = nTPs / (nTPs + nFPs)
-            recall = nTPs / (nTPs + nFNs)
-            # Fbeta = (1 + beta^2) * P * R / (beta^2 * P + R)
-            F1 = 2 * precision * recall / (precision + recall)
-            F0_5 = 1.25 * precision * recall / (0.25 * precision + recall)
-            F2 = 5 * precision * recall / (4 * precision + recall)
-            score_boots.append(
-                Score(precision, recall, F1, F0_5, F2, np.mean(TPs_IoU)))
-
-        score_boots = np.array(score_boots)
-        mean_score = Score(*score_boots.mean(axis=0))
-        if return_stderr:
-            stderr = score_boots.std(axis=0) / np.sqrt(score_boots.shape[0])
-            return (mean_score, Score(*stderr))
-        else:
-            return mean_score
-
-    def fit_filter_params(self, lazy=False, bootstrap=False):
-        # Define parameter grid values, firstly those not specific to a group
-        params = {
-            ('containment_thresh', None): np.linspace(0, 1, 21),
-            ('min_area', None): np.arange(0, 20, 1)
-        }
-
-        # Determine the pedge_threshold range based on the observed p_edge
-        # range for each group
-        if (self.stat_table.is_best.all() or not self.stat_table.is_best.any()):
-            t_pe_upper = self.stat_table.groupby('group').p_edge.mean()
-        else:
-            q_pe = self.stat_table.groupby(['group', 'is_best'])
-            q_pe = q_pe.p_edge.quantile([0.25, 0.95]).unstack((1, 2))
-            t_pe_upper = q_pe.loc[:, [(False, 0.95), (True, 0.25)]].mean(1)
-
-        t_pe_vals = [
-            np.arange(0, u, round_to_n(u / 20, 1)) for u in t_pe_upper
-        ]
-
-        # Set group-specific parameter grid values
-        g_ex_vals = repeat(np.linspace(0, 0.4, 21))
-        for g, (t_pe, g_ex) in enumerate(zip(t_pe_vals, g_ex_vals)):
-            params[('pedge_thresh', g)] = t_pe
-            params[('group_thresh_expansion', g)] = g_ex
-
-        # Default starting point is with thresholds off and no group expansion
-        ngroups = len(self.segrps)
-        dflt_params = {
-            'containment_thresh': 0,
-            'min_area': 0,
-            'pedge_thresh': list(repeat(0, ngroups)),
-            'group_thresh_expansion': list(repeat(0, ngroups))
-        }
-
-        # Search first along each parameter dimension with all others kept at
-        # default:
-        opt_params = {}
-        for k, pvals in params.items():
-            scrs = []
-            for v in pvals:
-                p = _sub_params({k: v}, dflt_params)
-                scr = self.filter_trial(**p, bootstrap=bootstrap)
-                scrs.append(getattr(scr, self.scoring))
-            maxInd = np.argmax(scrs)
-            opt_params[k] = pvals[maxInd]
-
-        # Reset the template parameters to the best along each dimension
-        base_params = _sub_params(opt_params, dflt_params)
-
-        if lazy:
-            # Simply repeat search along each parameter dimension, but now
-            # using the new optimum as a starting point
-            opt_params = {}
-            for k, pvals in params.items():
-                scrs = []
-                for v in pvals:
-                    p = _sub_params({k: v}, base_params)
-                    scr = self.filter_trial(**p, bootstrap=bootstrap)
-                    scrs.append(getattr(scr, self.scoring))
-                maxInd = np.argmax(scrs)
-                opt_params[k] = pvals[maxInd]
-            opt_params = _sub_params(opt_params, base_params)
-            scr = self.filter_trial(**opt_params, bootstrap=bootstrap)
-            self._opt_params = opt_params
-            self._opt_score = getattr(scr, self.scoring)
-            return
-
-        # Next perform a joint search for parameters with optimal pairings
-        opt_param_pairs = {k: {v} for k, v in opt_params.items()}
-        for k1, k2 in combinations(params.keys(), 2):
-            scrs = [(v1, v2,
-                     getattr(
-                         self.filter_trial(**_sub_params({
-                             k1: v1,
-                             k2: v2
-                         }, base_params),
-                                           bootstrap=bootstrap),
-                         self.scoring))
-                    for v1, v2 in product(params[k1], params[k2])]
-            p1opt, p2opt, _ = max(scrs, key=lambda x: x[2])
-            opt_param_pairs[k1].add(p1opt)
-            opt_param_pairs[k2].add(p2opt)
-
-        # Finally search over all combinations of the parameter values found
-        # with optimal pairings
-        scrs = []
-        for pvals in product(*opt_param_pairs.values()):
-            p = {k: v for k, v in zip(opt_param_pairs.keys(), pvals)}
-            p = _sub_params(p, base_params)
-            scrs.append((p,
-                         getattr(self.filter_trial(**p, bootstrap=bootstrap),
-                                 self.scoring)))
-
-        self._opt_params, self._opt_score = max(scrs, key=lambda x: x[1])
-
-
-def _sub_params(sub, param_template):
-    p = {
-        k: v.copy() if type(v) == list else v
-        for k, v in param_template.items()
-    }
-    for (k, g), v in sub.items():
-        if g is None:
-            p[k] = v
-        else:
-            p[k][g] = v
-    return p
diff --git a/python/baby/segmentation.py b/python/baby/segmentation.py
index 4b40d103914ca8899ee7c9bb00810084db33050e..633c3eb76c61b7d25eba68442248ef15748b905b 100644
--- a/python/baby/segmentation.py
+++ b/python/baby/segmentation.py
@@ -1,12 +1,14 @@
 # If you publish results that make use of this software or the Birth Annotator
 # for Budding Yeast algorithm, please cite:
-# Julian M J Pietsch, Alán Muñoz, Diane Adjavon, Ivan B N Clark, Peter S
-# Swain, 2021, Birth Annotator for Budding Yeast (in preparation).
+# Pietsch, J.M.J., Muñoz, A.F., Adjavon, D.-Y.A., Farquhar, I., Clark, I.B.N.,
+# and Swain, P.S. (2023). Determining growth rates from bright-field images of
+# budding cells through identifying overlaps. eLife. 12:e79812.
+# https://doi.org/10.7554/eLife.79812
 # 
 # 
 # The MIT License (MIT)
 # 
-# Copyright (c) Julian Pietsch, Alán Muñoz and Diane Adjavon 2021
+# Copyright (c) Julian Pietsch, Alán Muñoz and Diane Adjavon 2023
 # 
 # Permission is hereby granted, free of charge, to any person obtaining a copy
 # of this software and associated documentation files (the "Software"), to
@@ -25,39 +27,1024 @@
 # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
 # FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS
 # IN THE SOFTWARE.
-from collections import Iterable
-from itertools import chain, compress, repeat
+
 import numpy as np
-from numpy import newaxis as nax
-from scipy.ndimage import (minimum_filter, binary_dilation, binary_erosion,
-                           binary_closing, binary_opening, binary_fill_holes)
+from itertools import chain
+from typing import NamedTuple, Tuple
+import inspect
 from scipy import interpolate
 from scipy.optimize import least_squares
+from scipy.ndimage import (minimum_filter, binary_dilation, binary_erosion,
+                           binary_closing, binary_opening, binary_fill_holes)
+from skimage import filters
 from skimage.measure import label, regionprops
-from skimage.segmentation import morphological_geodesic_active_contour
 from skimage.morphology import diamond, erosion, dilation
-from skimage.draw import ellipse_perimeter
-from skimage import filters
 
-from .errors import BadParam
+from .errors import BadParam
+
+
+############################
+###  UTILITY FUNCTIONS   ###
+############################
+
+
+squareconn = diamond(1)  # 3x3 filter for 1-connected patches
+fullconn = np.ones((3, 3), dtype='uint8')  # 3x3 filter of all ones
+
+
+def binary_edge(imfill, footprint=fullconn):
+    """Mark pixels that differ from the minimum over their footprint."""
+    return imfill != minimum_filter(imfill, footprint=footprint)
+
+
+def iterative_erosion(img, iterations=1, **kwargs):
+    """Erode `img` `iterations` times; None or <1 returns `img` as is."""
+    if iterations is None or iterations < 1:
+        return img
+    for _ in range(iterations):
+        img = erosion(img, **kwargs)
+    return img
+
+
+def iterative_dilation(img, iterations=1, **kwargs):
+    """Dilate `img` `iterations` times; None leaves `img` untouched."""
+    n_rounds = 0 if iterations is None else iterations
+    for _ in range(n_rounds):
+        img = dilation(img, **kwargs)
+    return img
+
+
+def single_region_prop(mask):  # mask is cast to uint8 labels before measuring
+    return regionprops(mask.astype(np.uint8))[0]  # keep the first region only
+
+
+def get_edge_scores(outlines, p_edge):
+    """Score each outline by the mean edge probability along it.
+
+    Every outline is thickened by two rounds of binary dilation and the
+    score is the mean of ``p_edge`` over the thickened pixels only.
+
+    NB: BREAKING CHANGE. From July 2022 the mean is taken over edge pixels
+    rather than the whole image; segmentation parameters need reoptimising.
+
+    :param outlines: list of outline images (2D bool)
+    :param p_edge: edge probability image (2D float with values in [0, 1])
+    :return: list of edge probability scores (each in [0, 1])
+    """
+    thickened = (binary_dilation(o, iterations=2) for o in outlines)
+    return [p_edge[region].mean() for region in thickened]
+
+
+def mask_iou(a, b):
+    """Ratio of shared to combined pixel counts of two boolean masks."""
+    return np.divide(np.sum(a & b), np.sum(a | b))
+
+
+def mask_containment(a, b):
+    """Overlap area as a fraction of the smaller mask's area (a or b)."""
+    return (np.sum(a & b) / np.array([np.sum(a), np.sum(b)])).max()
+
+
+def bbox_overlaps(bboxA, bboxB):
+    """Returns True if the bounding boxes are overlapping.
+
+    Args:
+        bboxA (tuple): A bounding box ``(min_row, min_col, max_row, max_col)``
+            as defined by :py:func:`skimage.measure.regionprops`, or an
+            object exposing such a tuple as its ``bbox`` attribute.
+        bboxB (tuple): As for ``bboxA``.
+    """
+    lrA, lcA, urA, ucA = getattr(bboxA, 'bbox', bboxA)
+    lrB, lcB, urB, ucB = getattr(bboxB, 'bbox', bboxB)
+    rA = np.array([lrA, urA])
+    cA = np.array([lcA, ucA])
+    return ((not ((rA > urB).all() or (rA < lrB).all())) and
+            (not ((cA > ucB).all() or (cA < lcB).all())))
+
+
+def region_iou(regA, regB):
+    """Efficiently computes the IoU between two RegionProperties."""
+    if not bbox_overlaps(regA, regB):
+        return 0.
+    # Render both regions into the smallest window spanning both bboxes
+    boxes = np.stack((regA.bbox, regB.bbox))
+    r0, c0, _, _ = boxes.min(axis=0)
+    _, _, r1, c1 = boxes.max(axis=0)
+    canvases = []
+    for reg in (regA, regB):
+        canvas = np.zeros((r1 - r0, c1 - c0), dtype='bool')
+        lr, lc, ur, uc = reg.bbox
+        canvas[lr - r0:ur - r0, lc - c0:uc - c0] = reg.image
+        canvases.append(canvas)
+    imA, imB = canvases
+    return np.sum(imA & imB) / np.sum(imA | imB)
+
+
+def limit_to_bbox(imglist, bbox, padding=10):
+    """Crop all images to a padded bbox; also return data to undo the crop."""
+    imH, imW = imglist[0].shape[:2]
+    assert all([img.shape[:2] == (imH, imW) for img in imglist])
+    lr, lc, ur, uc = bbox
+    rmin, cmin = np.maximum(lr - padding, 0), np.maximum(lc - padding, 0)
+    rmax, cmax = np.minimum(ur + padding, imH), np.minimum(uc + padding, imW)
+    crops = (img[rmin:rmax, cmin:cmax] for img in imglist)
+    return crops, (rmin, cmin, imH, imW)
+
+
+def restore_from_bbox(imglist, bbunmap):
+    bbH, bbW = imglist[0].shape[:2]  # all crops must share this height/width
+    assert all([i.shape[:2] == (bbH, bbW) for i in imglist])
+    rmin, cmin, imH, imW = bbunmap  # crop offset and full-size image shape
+    for img in imglist:
+        restored = np.zeros((imH, imW) + img.shape[2:], dtype=img.dtype)
+        restored[rmin:rmin + bbH, cmin:cmin + bbW] = img  # paste crop back
+        yield restored
+
+
+def threshold_segmentation(p_int,
+                           interior_threshold=0.5,
+                           connectivity=None,
+                           nclosing=0,
+                           nopening=0,
+                           ndilate=0,
+                           return_area=False):
+    """Yields cell masks found by thresholding an interior probability image.
+
+    Pixels of `p_int` above the threshold are grouped into connected
+    components by `skimage.measure.label` and every component is treated as
+    one candidate cell mask. Each candidate is cleaned up independently of
+    the others: first by binary closing, then by binary opening and finally
+    by binary dilation. The dilation can restore masks to their original
+    size when `p_int` was trained on eroded interiors. A step is skipped if
+    its iteration count is not greater than zero.
+
+    Args:
+        p_int: An ndarray specifying cell interior probability for each
+            pixel.
+        interior_threshold: Threshold applied to the probability image;
+            only pixels strictly above it count as interior.
+        connectivity: Connectivity as defined by `skimage.measure.label`.
+        nclosing: Number of iterations in the closing operation.
+        nopening: Number of iterations in the opening operation.
+        ndilate: Number of iterations of binary dilation applied at the
+            end.
+        return_area: If True, then function yields (mask, area) tuples.
+
+    Yields:
+        An array with the same shape as `p_int` for each identified cell
+        mask. If `return_area=True`, (mask, area) tuples are yielded
+        instead, where the area is the number of True pixels in the
+        mask.
+    """
+
+    labelled, n_labels = label(p_int > interior_threshold,
+                               return_num=True,
+                               connectivity=connectivity)
+    # Clean-up operations in the order they are applied to each mask
+    cleanup = ((binary_closing, nclosing), (binary_opening, nopening),
+               (binary_dilation, ndilate))
+    for label_id in range(1, n_labels + 1):
+        mask = labelled == label_id
+        for morph_op, niter in cleanup:
+            if niter > 0:
+                mask = morph_op(mask, iterations=niter)
+        yield (mask, mask.sum()) if return_area else mask
+
+
+def ordered_edge_points(mask, edge=None, border_rect=None):
+    """Returns edge coordinates ordered around the mask perimeter.
+
+    Uses the tangent to the filled mask image to step pixel-wise around an
+    edge image and terminate when within a pixel length of the first pixel, or
+    when steps would be larger than 2 pixels to continue.  Note that this
+    function does not return all edge points, just those pixels closest to a 1
+    pixel step along the direction of the tangent.
+
+    Args:
+        mask (ndarray): A 2D bitmask for a single cell.
+        edge (None or ndarray): To save recomputation, optionally provide the
+            edge image obtained from :py:func:`binary_edge`.
+        border_rect (None or ndarray): To save recomputation, optionally
+            provide an image of same shape as the mask with ``False`` values
+            except for a 1-pixel border of ``True`` values.
+
+    Returns:
+        An ndarray of shape ``(N_edge_pixels, 2)`` giving row/column
+        coordinates for the ordered edge pixels.
+    """
+    if edge is None:
+        edge = binary_edge(mask)
+
+    if border_rect is None:
+        border_rect = np.pad(
+            np.zeros(tuple(x - 2 for x in mask.shape), dtype='bool'),
+            pad_width=1, mode='constant', constant_values=True)
+
+    # Need to ensure an edge if mask is adjacent to the border
+    edge = np.asarray(edge, dtype=bool) | (border_rect & mask)
+
+    X, Y = np.nonzero(edge)
+    edgepts = np.c_[X,Y]
+
+    # Use Sobel filter on Gaussian-blurred mask to estimate tangent along edge
+    mask_blur = filters.gaussian(mask, 1, mode='constant')
+    hgrad = filters.sobel_h(mask_blur)
+    vgrad = filters.sobel_v(mask_blur)
+    edge_hgrad = hgrad[edge]
+    edge_vgrad = vgrad[edge]
+    gradvec = np.c_[edge_hgrad, edge_vgrad]
+    normgradvec = gradvec / np.sqrt(np.sum(np.square(gradvec), axis=1)[:, None])
+    rot90 = np.array([[0, 1], [-1, 0]])
+    tngvec = np.matmul(normgradvec, rot90)
+
+    # Loop over edge points in direction of tangent to get ordered list
+    unvisited = np.ones(edgepts.shape[0], dtype='bool')
+    i = 0
+    start = edgepts[i]
+    ptorder = []
+    for _ in range(unvisited.size):
+        unvisited[i] = False
+        nextpt = edgepts[[i]] + tngvec[[i]]
+        if np.sum(np.square(nextpt - start)) <= 2 and i > 0:
+            # We have returned to the first pixel
+            break
+        nextpt_sqdist = np.sum(np.square(edgepts[unvisited]-nextpt), axis=1)
+        minInd = np.argmin(nextpt_sqdist)
+        if nextpt_sqdist[minInd] > 4:
+            # We are jumping further than just neighbouring pixels
+            break
+        i = np.flatnonzero(unvisited)[minInd]
+        ptorder.append(i)
+
+    return edgepts[ptorder]
+
+
+############################
+### RADIAL KNOT SPLINES  ###
+############################
+
+
+def rc_to_radial(rr_cc, centre):
+    """Converts row/column coordinates to polar (radial) coordinates.
+
+    The result of `numpy.nonzero` on a binary edge image can be passed
+    straight in as ``rr_cc``, for example:
+    ``rc_to_radial(np.nonzero(edge_image), (0, 0))``
+
+    Args:
+        rr_cc: A tuple ``(rr, cc)`` of ndarrays holding the row and
+            column indices to convert; ``rr`` and ``cc`` should have
+            the same shape.
+        centre: A tuple ``(rr, cc)`` giving the origin of the radial
+            coordinate system.
+
+    Returns:
+        A tuple ``(rho, phi)`` of ndarrays shaped like ``rr`` and ``cc``:
+        ``rho`` is the distance from ``centre`` and ``phi`` the angle
+        ``arctan2(cc, rr)`` of the centred coordinates.
+    """
+    (rows, cols), (row0, col0) = rr_cc, centre
+    d_row, d_col = rows - row0, cols - col0
+    rho = np.sqrt(np.square(d_row) + np.square(d_col))
+    phi = np.arctan2(d_col, d_row)
+    return rho, phi
+
+
+def eval_radial_spline(x, rho, phi):
+    """Interpolates radii at given angles using a periodic radial spline.
+
+    Knots may be given in any order; they are sorted by angle first.
+
+    Args:
+        x: An ndarray of angles at which to evaluate the spline.
+        rho: An ndarray of radii in [0, Inf), one per knot.
+        phi: An ndarray of angles in [-pi, pi), one per knot.
+
+    Returns:
+        An ndarray shaped like ``x`` holding the interpolated radius at
+        each angle.
+    """
+
+    # The angle is the spline's parametric variable, so knots must be
+    # arranged in order of increasing angle
+    by_angle = np.argsort(phi)
+    knot_rho, knot_phi = rho[by_angle], phi[by_angle]
+    first_phi = knot_phi[0]
+
+    # Shift angles to start at zero and close the loop by repeating the
+    # first knot a full turn later
+    knot_rho = np.r_[knot_rho, knot_rho[0]]
+    knot_phi = np.r_[knot_phi - first_phi, 2 * np.pi]
+
+    tck = interpolate.splrep(knot_phi, knot_rho, per=True)
+    return interpolate.splev(np.mod(x - first_phi, 2 * np.pi), tck)
+
+
+def eval_cartesian_spline(x, rho, phi):
+    """Evaluates a closed cartesian spline through radially-defined knots.
+
+    Knots are joined in the order in which they appear in the arrays, and
+    the curve is periodic across the end-points. The knots therefore do
+    not have to be sorted by increasing phi (although that is expected to
+    be the most common case).
+
+    Args:
+        x: An ndarray of parametric locations at which to evaluate the
+            spline. These are defined over [0, 2*pi) but are not true
+            radial angles like the `phi` used to define the knots.
+        rho: An ndarray of radii in [0, Inf), one per knot.
+        phi: An ndarray of angles in [-pi, pi), one per knot.
+
+    Returns:
+        A tuple (Sx, Sy) of ndarrays shaped like x, giving the
+        interpolated x and y coordinates at each parametric location in
+        x.
+    """
+
+    # Repeat the first knot at the end so that the curve closes on itself
+    rho_loop = np.r_[rho, rho[0]]
+    phi_loop = np.r_[phi, phi[0]]
+    knots_xy = (rho_loop * np.cos(phi_loop), rho_loop * np.sin(phi_loop))
+
+    # TODO check whether the spline behaves better if the lengths between each
+    # knot are used to scale relative distance in the parametric variable
+
+    # Knots are spaced evenly in the parametric variable over [0, 2*pi]
+    t = np.linspace(0, 2 * np.pi, rho_loop.size)
+    x_mod = np.mod(x, 2 * np.pi)
+
+    # One periodic spline per cartesian coordinate, evaluated at modded x
+    Sx, Sy = (interpolate.splev(x_mod, interpolate.splrep(t, k, per=True))
+              for k in knots_xy)
+    return Sx, Sy
+
+
+def draw_radial(rho, phi, centre, shape, cartesian_spline=False):
+    """Renders a spline defined in radial coordinates as an image.
+
+    By default, interpolates spline in radial space. Pixels that would fall
+    outside the image are clamped onto its border.
+
+    Args:
+        rho: An ndarray of radii defining knot locations.
+        phi: An ndarray of angles defining knot locations.
+        centre: A sequence of length 2 specifying the row and column defining
+            the origin of the radial coordinate system.
+        shape: A sequence of length 2 specifying the height and width of the
+            image to be rendered to.
+        cartesian_spline: If True, then interpolate spline in cartesian space.
+
+    Returns:
+        An ndarray with the specified shape and of dtype bool containing the
+        rendered spline. A spline too short to sample is drawn as the single
+        pixel nearest to ``centre``.
+    """
+
+    mr, mc = shape
+    im = np.zeros(shape, dtype='bool')
+
+    # Estimate required sampling density from lengths of piecewise linear segments
+    rho_loop = np.r_[rho, rho[0]]
+    phi_loop = np.r_[phi, phi[0]]
+    xy_loop = np.c_[rho_loop*np.cos(phi_loop), rho_loop*np.sin(phi_loop)]
+    linperim = np.sum(np.sqrt(np.sum(np.square(np.diff(xy_loop, axis=0)), axis=1)))
+    neval = np.round(2.5 * linperim).astype(int)
+
+    if neval > 1:
+        x = np.linspace(0, 2 * np.pi, neval)
+        if cartesian_spline:
+            Sx, Sy = eval_cartesian_spline(x, rho, phi)
+        else:
+            R = eval_radial_spline(x, rho, phi)
+            Sx, Sy = R * np.cos(x), R * np.sin(x)
+        rr, cc = centre[0] + Sx, centre[1] + Sy
+    else:
+        # Degenerate spline: mark only the centre pixel
+        rr, cc = centre[0], centre[1]
+    # Clamp in both branches so that an off-image centre can neither wrap
+    # around via negative indices nor raise an IndexError
+    rr = np.clip(np.round(rr).astype(int), 0, mr - 1)
+    cc = np.clip(np.round(cc).astype(int), 0, mc - 1)
+    im[rr, cc] = True
+    return im
+
+
+def _radii_from_outline(outline, centre, ray_angles, max_ray_len):
+    """Measures distance from `centre` to the outline along each ray angle.
+
+    Returns an ndarray of radii, with NaN for any ray that misses the outline.
+    """
+    # Improve accuracy of edge position by smoothing the outline image and using
+    # weighted averaging of pixel positions below:
+    outline = filters.gaussian(outline, 0.5)
+
+    ray_tmplt = 0.5 * np.arange(np.round(2 * max_ray_len))[:, None]
+    rr_max, cc_max = outline.shape
+    radii = []
+    for angle in ray_angles:
+        ray = np.matmul(ray_tmplt, np.array((np.cos(angle), np.sin(angle)))[None, :])
+        ray = np.round(centre + ray).astype('int')
+        rr, cc = ray[:, 0], ray[:, 1]
+        ray = ray[(rr >= 0) & (rr < rr_max) & (cc >= 0) & (cc < cc_max), :]
+
+        edge_pix = np.flatnonzero(np.squeeze(outline[ray[:, 0], ray[:, 1]]) > 0.01)
+        if len(edge_pix) == 0:
+            # `np.NaN` was removed in NumPy 2.0; `np.nan` works in all versions
+            radii.append(np.nan)
+            continue
+
+        ray = ray[edge_pix, :]
+        edge_pix = np.average(ray, weights=outline[ray[:, 0], ray[:, 1]], axis=0)
+        radii.append(np.sqrt(np.sum((edge_pix - centre)**2)))
+
+    return np.array(radii)
+    
+
+def guess_radial_edge(edge, mask=None, rprops=None,
+        nrays_thresh_map=[(5., 4), (20., 6)], nrays_max=8):
+    """Guesses knot placement for a radial spline by casting rays.
+
+    Given an edge image, this function casts out rays from the (filled)
+    mask centroid and finds where they intersect with the edge.
+
+    Args:
+        edge (ndarray): A 2D bitmask of the edge for a single cell.
+        mask (None or ndarray): To save recomputation, optionally provide a
+            mask image that is the filled version of the edge image. Not
+            required if `rprops` is specified.
+        rprops (None or RegionProps): To save recomputation, optionally
+            provide :py:func:`skimage.measure.regionprops` that have been
+            calculated on the mask image.
+        nrays_thresh_map (List[Tuple[float, int]]): An ordered list of tuples
+            ``(upper_threshold, n_rays)`` that give an upper threshold on
+            major axis length for which the specified number of rays
+            ``n_rays`` will be used instead of ``nrays_max``. The first
+            satisfied threshold in the list will be used to select the number
+            of rays.
+        nrays_max (int): The number of rays that will be used if the major
+            axis length is not below any of the ``upper_threshold`` values
+            specified in ``nrays_thresh_map``.
+
+    Returns:
+        A tuple ``(rho, phi)`` of ndarrays giving knot locations in radial
+        coordinates (radii ``rho`` and angles ``phi``) with origin at mask
+        centroid as determined by :py:func:`skimage.measure.regionprops`.
+    """
+    if mask is None and rprops is None:
+        mask = binary_fill_holes(edge)
+
+    if rprops is None:
+        rprops = regionprops(mask.astype('int'))[0]
+
+    r_maj = rprops.major_axis_length
+    nrays = nrays_max
+    for upper_thresh, n in nrays_thresh_map:
+        if r_maj < upper_thresh:
+            nrays = n
+            break
+
+    RL, CL, RU, CU = rprops.bbox
+    bbH, bbW = RU - RL, CU - CL
+    bbdiag = np.sqrt(bbH * bbH + bbW * bbW)  # upper bound on ray length
+
+    astep = 2 * np.pi / nrays  # angular spacing between rays
+    angles = (np.mod(rprops.orientation + np.pi, astep) +
+              np.arange(nrays)*astep - np.pi)
+    centre = np.array(rprops.centroid)
+
+    radii = _radii_from_outline(edge, centre, angles, bbdiag)
+
+    # Use linear interpolation for any missing radii (e.g., if region intersects
+    # with image boundary):
+    nanradii = np.isnan(radii)
+    if nanradii.all():
+        radii = 0.1 * np.ones(angles.shape)  # no ray found the edge at all
+    elif nanradii.any():
+        radii = np.interp(angles,
+                          angles[~nanradii],
+                          radii[~nanradii],
+                          period=2 * np.pi)
+
+    return radii, angles
+
+
+def guess_cartesian_edge(mask, p_edge, n_trials=10, pxperknot=5.,
+                         squishiness=1., alignedness=1., return_probs=False,
+                         edge=None, rprops=None, border_rect=None):
+    """Guesses knot placement for a cartesian spline by ordering edge pixels.
+
+    Uses :py:func:`ordered_edge_points` to obtain an ordered sequence of edge
+    pixels, then selects random subsets of these with roughly even
+    perimeter-wise spacing using a Dirichlet distribution. The random trial
+    with the highest probability as measured by ``p_edge`` is returned.
+
+    The number of knots is always even and never fewer than four. The first
+    knot is biased to align with the major axis, but a normal distribution
+    allows for variation about this position (see ``alignedness`` argument).
+    The default parameters provide even sampling of knot positions over all
+    possible values. A small ``squishiness`` favours more regular spacing
+    between knots. A large ``alignedness`` favours alignment with the major
+    axis.
+
+    Args:
+        mask (ndarray): A 2D bitmask of a single cell (dtype bool).
+        p_edge (ndarray): A 2D image of edge probabilities.
+        n_trials (int): Number of random subsets to test.
+        pxperknot (float): Intended spacing (in pixels) between knots.
+        squishiness (float): Scaling factor for variance of the Dirichlet
+            distribution.
+        alignedness (float): Scaling factor for standard deviation of position
+            of the first knot.
+        return_probs (bool): Specify ``True`` to include in the return value
+            the mean edge probability over the best trial, over the edge
+            image (binary_edge) and over all ordered edge pixels.
+        edge (None or ndarray): If a binary edge image has already been
+            calculated from the mask it can optionally be provided here to
+            save recomputation.
+        rprops (None or RegionProps): If
+            :py:func:`skimage.measure.regionprops` have already been
+            calculated for the mask, they can optionally be provided here to
+            save recomputation.
+        border_rect (None or ndarray): To save recomputation, optionally
+            provide an image of same shape as the mask with ``False`` values
+            except for a 1-pixel border of ``True`` values.
+
+    Returns:
+        A tuple ``(centre, rho, phi)`` of ndarrays specifying knot placement in
+        radial coordinates for the best trial. The radial coordinate system is
+        centred at ``centre`` (row/column format) and the knots have radii
+        given by ``rho`` and angles given by ``phi``. If
+        ``return_probs=True``, then a tuple ``(centre, rho, phi, trial_prob,
+        edge_prob, ordered_edge_prob)`` is returned to additionally give the
+        mean of ``p_edge`` over the best trial's rendered spline, over the
+        edge image (:py:func:`binary_edge`), and over the ordered edge pixels
+        (as determined by :py:func:`ordered_edge_points`).
+    """
+    if edge is None:
+        edge = binary_edge(mask)
+
+    # Define knots in radial coordinates at mask centroid
+    if rprops is None:
+        rprops = regionprops(mask.astype(np.uint8))[0]
+    ctr = np.r_[rprops.centroid]
+
+    # Get ordered list of edge points in radial coordinates
+    edgepts = ordered_edge_points(mask, edge=edge, border_rect=border_rect)
+    edgepts_xy = edgepts - ctr[None, :]
+    edgepts_rho = np.sqrt(np.sum(np.square(edgepts_xy), axis=1))
+    # Need to flip X and Y here to match row/column format
+    edgepts_phi = np.arctan2(edgepts_xy[:, 1], edgepts_xy[:, 0])
+
+    # Candidate ellipse origins should differ from ellipsoid orientation by
+    # at most 1 arc pixel at half maximum radius
+    # <arc length> = <angle in radians> * <radius>
+    phi_tol = 2. / np.max(edgepts_rho)  # allowed tolerance
+    ori_dist = np.abs(edgepts_phi - rprops.orientation)
+    cand_ori = np.flatnonzero(ori_dist < phi_tol)
+    if len(cand_ori) > 0:
+        # Candidates are then filtered to those within one pixel of maximum
+        # radius from initial set of candidates
+        cand_ori = cand_ori[np.max(edgepts_rho[cand_ori]) - edgepts_rho[cand_ori] < 1]
+        ellipse_ori = cand_ori[np.argmin(ori_dist[cand_ori])]
+    else:
+        # Just pick closest to ellipse orientation
+        ellipse_ori = np.argmin(ori_dist)
+
+    # Randomly select edge points as knots over multiple trials
+    # Keep the trial with the highest edge probability
+    N_pts = edgepts_rho.size
+    frac_ori = ellipse_ori / (N_pts - 1.)  # fractional position on outline
+    # Choose an even number of knots, no less than 4
+    N_knots = np.maximum(2 * np.round(N_pts / 2 / pxperknot).astype('int'), 4)
+    cand_rho_phi = []
+    cand_probs = []
+    for _ in range(n_trials):
+        # Split snake into roughly equal segments using Dirichlet distribution
+        # We want the variance to scale as 1 / N_knots
+        # Variance is further scaled by squishiness
+        # Alignment with major ellipse axis is controlled by variance in
+        # additive normal noise
+        # defaults of squishiness=1 and alignedness=1 produce even sampling of
+        # knots over [0, N_pts)
+        alpha = N_knots * np.ones(N_knots) / squishiness
+        sec_len = np.random.dirichlet(alpha)
+        frac_ind = np.cumsum(sec_len)
+        frac_ind += np.random.normal(loc=frac_ori, scale=1 / N_knots / alignedness)
+        frac_ind = np.sort(np.mod(frac_ind, 1.))
+        inds = np.floor(N_pts * frac_ind).astype('int')
+
+        # Get the knot positions in radial coordinates
+        knots_rho = edgepts_rho[inds]
+        knots_phi = edgepts_phi[inds]
+
+        splIm = draw_radial(knots_rho, knots_phi, ctr, p_edge.shape,
+                            cartesian_spline=True)
+        probs = p_edge[splIm].mean()
+        cand_probs.append(probs)
+        cand_rho_phi.append((ctr, knots_rho, knots_phi))
+
+    indMax = np.argmax(cand_probs)
+    if return_probs:
+        raw_edge_prob = p_edge[edge].mean()
+        ord_edge_prob = p_edge[edgepts[:, 0], edgepts[:, 1]].mean()
+        return cand_rho_phi[indMax] + (cand_probs[indMax], raw_edge_prob, ord_edge_prob)
+    else:
+        return cand_rho_phi[indMax]
+
+
+def inds_max_curvature(R, N_pergrp, N_elim):
+    """Helper function for curvy_knots_from_outline.
+
+    Args:
+        R: 2D ndarray with one column per knot: a single row of radii, or two
+            rows respectively giving x and y coordinates.
+        N_pergrp: group consecutive knots into groups of this size.
+        N_elim: eliminate this many lowest-curvature knots from each group.
+
+    Returns:
+        Sorted indices (columns of R) for the knots retained in each group.
+    """
+    curvature = np.diff(np.diff(R[:, np.r_[-1, 0:R.shape[1], 0]], axis=1), axis=1)
+    curvature = np.sqrt(np.sum(curvature**2, axis=0))  # norm over rows of R
+    knorder = np.argsort(np.reshape(curvature, (-1, N_pergrp)), axis=1)
+    rowinds = np.reshape(np.arange(knorder.size), (-1, N_pergrp))
+    rowinds = rowinds[np.tile(np.arange(knorder.shape[0])[:, None], (1, N_pergrp)), knorder]
+    return np.sort(rowinds[:, N_elim:].flatten())
+
+
+def curvy_knots_from_outline(outline, rprops, cartesian_spline=False,
+                             n_knots_fraction=0.5):
+    """Places knots on an outline image at points of high curvature.
+
+    The method essentially progressively eliminates knots with low curvature
+    from an initial dense equi-angular array of knots located on the outline.
+    Curvature is calculated as the difference of the difference between
+    consecutive locations (radii for a radial spline; the norm of x and y for
+    a cartesian spline).
+
+    Args:
+        outline (ndarray): A 2D bitmask for a single cell edge.
+        rprops (RegionProperties): The output of
+            :py:func:`skimage.measure.regionprops` for the filled outline.
+        cartesian_spline (bool): Specify ``True`` to model curvature in
+            cartesian rather than radial coordinates.
+        n_knots_fraction (float): A value in ``[0, 1]`` that determines the
+            number of knots as a fraction of the number of edge pixels.
+
+    Returns:
+        A (radii, angles) tuple specifying the locations of the curvy knots in
+        radial coordinates from the `rprops` centroid
+    """
+    Nedgepx = outline.sum()
+    # Initial dense number of rays is one ray for every eight edge pixels,
+    # rounded to the nearest power of 2. The edge pixel count is floored at 32
+    # so that there are never fewer than four rays. (An earlier variant of this
+    # estimate used one ray for every four edge pixels with a floor of 16,
+    # which likewise gave at least four rays.)
+    Nrays_dense = 2**(np.round(np.log(max(Nedgepx, 32)) / np.log(2)).astype(int) - 3)
+
+    # The target number of rays rounded to the nearest four:
+    Nrays_final = 4 * max(1, np.round(0.25 * Nedgepx * n_knots_fraction).astype(int))
+
+    # Rays can be no larger than the diagonal of the bounding box
+    RL, CL, RU, CU = rprops.bbox
+    bbdiag = np.sqrt((RU - RL)**2 + (CU - CL)**2)
+
+    # Determine radii for initial dense array
+    astep = 2 * np.pi / Nrays_dense
+    angles = np.mod(rprops.orientation + np.pi, astep) + \
+        np.arange(Nrays_dense)*astep - np.pi
+    # Roughly compensate for elliptical squeezing by converting parameterised
+    # ellipse parameter to true angle. See:
+    # https://math.stackexchange.com/a/436125
+    axmaj, axmin = rprops.major_axis_length, rprops.minor_axis_length
+    angles -= np.arctan((axmaj - axmin) * np.tan(angles) / (axmaj + axmin * np.tan(angles)**2))
+
+    centre = np.array(rprops.centroid)
+    radii = _radii_from_outline(outline, centre, angles, bbdiag)
+
+    # Linearly interpolate any missing radii (e.g., if region intersects with
+    # image boundary):
+    nanradii = np.isnan(radii)
+    if nanradii.all():
+        radii = 0.1 * np.ones(angles.shape)
+    elif nanradii.any():
+        radii = np.interp(angles,
+                          angles[~nanradii],
+                          radii[~nanradii],
+                          period=2 * np.pi)
+
+    if cartesian_spline:
+        R = radii * np.vstack((np.cos(angles), np.sin(angles)))
+    else:
+        R = radii[None, :]
+
+    # Progressively eliminate knots four at a time until the desired number is
+    # reached (no-op if Nrays_final >= Nrays_dense: all dense knots are kept)
+    inds = np.arange(R.shape[1])
+    for N_pergrp in np.r_[Nrays_dense // 4 : Nrays_final // 4 : -1]:
+        inds = inds[inds_max_curvature(R[:, inds], N_pergrp, 1)]
+        # Rotate indices over periodic boundary to mix groups
+        inds = inds[np.r_[-1,0:len(inds)-1]]
+    inds = np.sort(inds)
+
+    return radii[inds], angles[inds]
+
+
+class FakeRegionProperties(NamedTuple):  # minimal stand-in for regionprops
+    centroid: Tuple[int, int]  # (row, col); mask_to_knots passes an ndarray
+    orientation: float  # copied unchanged from the real regionprops
+    bbox: Tuple[int, int, int, int]  # (row_min, col_min, row_max, col_max)
+    major_axis_length: float
+    minor_axis_length: float
+
+
+GUESS_RADIAL_EDGE_PARAMS = {  # names of guess_radial_edge args with defaults
+    p.name for p in
+    inspect.signature(guess_radial_edge).parameters.values()
+    if p.kind == p.POSITIONAL_OR_KEYWORD and p.default != p.empty
+}
+
+
+GUESS_CARTESIAN_EDGE_PARAMS = {  # same, for guess_cartesian_edge
+    p.name for p in
+    inspect.signature(guess_cartesian_edge).parameters.values()
+    if p.kind == p.POSITIONAL_OR_KEYWORD and p.default != p.empty
+}
+
+
+CURVY_KNOTS_FROM_OUTLINE_PARAMS = {  # same, for curvy_knots_from_outline
+    p.name for p in
+    inspect.signature(curvy_knots_from_outline).parameters.values()
+    if p.kind == p.POSITIONAL_OR_KEYWORD and p.default != p.empty
+}
+
+
+def mask_to_knots(mask,
+                  p_edge=None,
+                  return_outline=True,
+                  cartesian_spline=False,
+                  curvy_knots=False,
+                  bbox_padding=10,
+                  **kwargs):
+    """Guess knot positions from a cell bitmask.
+
+    Essentially a unified interface to the different methods for initial knot
+    placement before refinement. By default, it finds knots using
+    :py:func:`guess_radial_edge`. If ``cartesian_spline=True`` and
+    ``curvy_knots=False``, then knots are placed according to
+    :py:func:`guess_cartesian_edge`. If ``curvy_knots=True``, then knots are
+    found using :py:func:`curvy_knots_from_outline`.
+
+    For increased performance, processing is limited to a padded bounding box
+    of the mask image. Any return values are, however, restored to the input
+    size/coordinates.
+
+    Args:
+        mask (ndarray): A 2D bitmask for a single cell.
+        p_edge (ndarray): A 2D edge probability image with same shape as mask.
+            Currently only required when ``cartesian_spline=True`` and
+            ``curvy_knots=False``.
+        return_outline (bool): Specify ``False`` if you only want knot
+            coordinates returned.
+        cartesian_spline (bool): Specify ``True`` to determine knots for
+            splines interpolated in cartesian rather than radial coordinates.
+        curvy_knots (bool): Specify ``True`` to place knots at points of high
+            curvature according to :py:func:`curvy_knots_from_outline`.
+        bbox_padding (int): Intended padding (in pixels) around the mask
+            bounding box. NOTE(review): not referenced in the function body.
+        **kwargs: Additional arguments to be passed to any of the downstream
+            functions (:py:func:`guess_radial_edge`,
+            :py:func:`guess_cartesian_edge`, or
+            :py:func:`curvy_knots_from_outline`).
+
+    Returns:
+        A tuple ``(knot_coordinates, edge_image)`` of knot coordinates
+        ``(centre, knot_rho, knot_phi)`` defined in a radial coordinate system
+        with origin ``centre`` (a size two ndarray giving row/column
+        location), and ndarrays ``knot_rho`` and ``knot_phi`` giving the
+        radius and polar angle respectively for each knot. If
+        ``return_outline=False``, then only ``knot_coordinates`` are returned.
+    """
+
+    rprops = single_region_prop(mask)
+    if cartesian_spline and not curvy_knots:
+        (mask_bb, p_edge_bb), bbunmap = limit_to_bbox(
+            (mask, p_edge), rprops.bbox)
+    else:
+        (mask_bb,), bbunmap = limit_to_bbox((mask,), rprops.bbox)
+    edge_bb = binary_edge(mask_bb)
+
+    ctr = np.r_[rprops.centroid]
+    bboffset = np.r_[bbunmap[:2]]  # presumably the crop's (row, col) origin
+    limited_bb = np.r_[rprops.bbox] - np.tile(bboffset, 2)
+    rprops_bb = FakeRegionProperties(
+        centroid=ctr - bboffset,
+        orientation=rprops.orientation,
+        bbox=tuple(limited_bb.tolist()),
+        major_axis_length=rprops.major_axis_length,
+        minor_axis_length=rprops.minor_axis_length)
+
+    if curvy_knots:
+        params = {
+            k: v for k, v in kwargs.items()
+            if k in CURVY_KNOTS_FROM_OUTLINE_PARAMS
+        }
+        k_rho, k_phi = curvy_knots_from_outline(
+            edge_bb, rprops_bb, cartesian_spline=cartesian_spline, **params)
+    elif cartesian_spline:
+        params = {
+            k: v for k, v in kwargs.items()
+            if k in GUESS_CARTESIAN_EDGE_PARAMS
+        }
+        _, k_rho, k_phi = guess_cartesian_edge(
+            mask_bb, p_edge_bb, edge=edge_bb, rprops=rprops_bb,
+            return_probs=False, **params)
+    else:
+        params = {
+            k: v for k, v in kwargs.items() if k in GUESS_RADIAL_EDGE_PARAMS
+        }
+        k_rho, k_phi = guess_radial_edge(edge_bb, rprops=rprops_bb, **params)
+
+    if return_outline:
+        edge = draw_radial(k_rho, k_phi, ctr, mask.shape,
+                           cartesian_spline=cartesian_spline)
+        return (ctr, k_rho, k_phi), edge
+    else:
+        return ctr, k_rho, k_phi
+
+
+############################
+###  OUTLINE REFINEMENT  ###
+############################
+
+
+def prior_resid_weight(resid, gauss_scale=5, exp_scale=1):
+    """Converts radial residuals into weights that favour the initial guess
+
+    Gaussian decay for ``resid >= 0`` and exponential decay for ``resid < 0``;
+    with `resid = rho_guess - rho_initial` this favours larger radii.
+    """
+    weights = np.zeros(resid.shape)
+    nonneg, neg = resid >= 0, resid < 0
+    weights[nonneg] = np.exp(-resid[nonneg]**2 / gauss_scale)
+    weights[neg] = np.exp(resid[neg] / exp_scale)
+    return weights
+
+
+def adj_rspline_coords(adj, ref_radii, ref_angles):
+    """Convert optimisation-space parameters into knot radii and angles
+
+    The first ``len(ref_radii)`` entries of `adj` scale the reference radii
+    and the remaining entries shift the reference angles. For `adj` within
+    [-1, 1], radii change by at most 30% and angles by at most a quarter of
+    the spacing between evenly distributed knots.
+    """
+    n_knots = len(ref_radii)
+    radial_adj, angular_adj = adj[:n_knots], adj[n_knots:]
+    # at most a 30% change in radius
+    radii = ref_radii * (1 + 0.3 * radial_adj)
+    # at most pi / (2 * n_knots), i.e. 1/4 of the even knot spacing
+    return radii, ref_angles + angular_adj * np.pi / (2 * n_knots)
+
+
+def adj_rspline_resid(adj, rho, phi, probs, ref_radii, ref_angles):
+    """Probability-weighted radial residuals for radial spline optimisation
+
+    `adj` is first mapped to knot radii and angles by `adj_rspline_coords`
+    relative to `ref_radii` and `ref_angles`. The resulting radial spline is
+    evaluated at the target angles `phi` and the difference from the target
+    radii `rho` is scaled by the weights `probs`.
+    """
+    knot_radii, knot_angles = adj_rspline_coords(adj, ref_radii, ref_angles)
+    return probs * (rho - eval_radial_spline(phi, knot_radii, knot_angles))
+
+
+def adj_cart_spline_resid(adj, rho, phi, probs, ref_radii, ref_angles):
+    """Weighted residual for cartesian spline optimisation in radial
+    coordinate system
+
+    Optimisation params (`adj`) are mapped by `adj_rspline_coords` relative to
+    `ref_radii` and `ref_angles`. Target points are given in radial
+    coordinates `rho` and `phi` with weights `probs`. NOTE(review):
+    `eval_cartesian_spline` returns an (Sx, Sy) tuple, not radii - confirm.
+    """
+    radii, angles = adj_rspline_coords(adj, ref_radii, ref_angles)
+    return probs * (rho - eval_cartesian_spline(phi, radii, angles))
+
+
+def refine_radial_grouped(grouped_coords, grouped_p_edges, cartesian_spline=False):
+    """Refine initial radial spline by optimising to predicted edge
+
+    Each cell in `grouped_coords` is a ``(centre, radii, angles)`` tuple; edges
+    belonging to other cells in this and adjacent groups are down-weighted.
+    """
+
+    if cartesian_spline:
+        eval_spline = eval_cartesian_spline  # NOTE(review): returns (Sx, Sy)
+        adj_resid = adj_cart_spline_resid
+    else:
+        eval_spline = eval_radial_spline
+        adj_resid = adj_rspline_resid
+
+    # Determine edge pixel locations and probabilities from NN prediction
+    p_edge_locs = [np.nonzero(p_edge > 0.2) for p_edge in grouped_p_edges]
+    p_edge_probs = [
+        p_edge[rr, cc]
+        for p_edge, (rr, cc) in zip(grouped_p_edges, p_edge_locs)
+    ]
+
+    p_edge_count = [len(rr) for rr, _ in p_edge_locs]
+
+    opt_coords = []
+    ngroups = len(grouped_coords)
+    for g, g_coords in enumerate(grouped_coords):
+        # If this group has no predicted edges, keep initial and skip
+        if p_edge_count[g] == 0:
+            opt_coords.append(g_coords)
+            continue
+
+        # Compile a list of all cells in this and neighbouring groups
+        nbhd = list(
+            chain.from_iterable([
+                [((gi, ci), coords)
+                 for ci, coords in enumerate(grouped_coords[gi])]
+                for gi in range(max(g - 1, 0), min(g + 2, ngroups))
+                if
+                p_edge_count[gi] > 0  # only keep if there are predicted edges
+            ]))
+        if len(nbhd) > 0:
+            nbhd_ids, nbhd_coords = zip(*nbhd)
+        else:
+            nbhd_ids, nbhd_coords = 2 * [[]]
+
+        # Calculate edge pixels in radial coords for all cells in this and
+        # neighbouring groups:
+        radial_edges = [
+            rc_to_radial(p_edge_locs[g], centre)
+            for centre, _, _ in nbhd_coords
+        ]
+
+        # Calculate initial residuals and prior weights
+        resids = [
+            rho - eval_spline(phi, radii, angles)
+            for (rho, phi), (_, radii,
+                             angles) in zip(radial_edges, nbhd_coords)
+        ]
+        indep_weights = [prior_resid_weight(r) for r in resids]
+
+        probs = p_edge_probs[g]
+
+        g_opt_coords = []
+        for c, (centre, radii, angles) in enumerate(g_coords):
+            ind = nbhd_ids.index((g, c))
+            rho, phi = radial_edges[ind]
+            p_weighted = probs * indep_weights[ind]
+            other_weights = indep_weights[:ind] + indep_weights[ind + 1:]
+            if len(other_weights) > 0:
+                p_weighted *= (1 - np.mean(other_weights, axis=0))
+
+            # Remove insignificant fit data
+            signif = p_weighted > 0.1
+            if signif.sum() < 10:
+                # With insufficient data, skip optimisation
+                g_opt_coords.append((centre, radii, angles))
+                continue
+            p_weighted = p_weighted[signif]
+            phi = phi[signif]
+            rho = rho[signif]
+
+            nparams = len(radii) + len(angles)
+            opt = least_squares(adj_resid,
+                                np.zeros(nparams),
+                                bounds=(-np.ones(nparams), np.ones(nparams)),
+                                args=(rho, phi, p_weighted, radii, angles),
+                                ftol=5e-2)
 
-squareconn = diamond(1)  # 3x3 filter for 1-connected patches
-fullconn = np.ones((3, 3), dtype='uint8')
+            g_opt_coords.append((centre,) +
+                                adj_rspline_coords(opt.x, radii, angles))
 
+        opt_coords.append(g_opt_coords)
 
-def binary_edge(imfill, footprint=fullconn):
-    """Get square-connected edges from filled image:"""
-    return minimum_filter(imfill, footprint=footprint) != imfill
+    return opt_coords
 
 
-def mask_iou(a, b):
-    """Intersection over union (IoU) between boolean masks"""
-    return np.sum(a & b) / np.sum(a | b)
+############################
+### DEPRECATED FUNCTIONS ###
+############################
 
 
-def mask_containment(a, b):
-    """Max of intersection over a or over b"""
-    return np.max(np.sum(a & b) / np.array([np.sum(a), np.sum(b)]))
+def get_regions(p_img, threshold):
+    """Label above-threshold regions and order them by mean probability
+
+    Regions whose major or minor axis length is not positive are discarded;
+    the rest are returned as a list with the highest mean intensity first.
+    """
+    labelled = label(p_img > threshold, background=0)
+    regions = (r for r in regionprops(labelled, p_img)
+               if r.major_axis_length > 0 and r.minor_axis_length > 0)
+    # sorted() is stable, just like list.sort(), so ties keep label order
+    return sorted(regions, key=lambda r: r.mean_intensity, reverse=True)
 
 
 def morph_thresh_masks(p_interior,
@@ -115,173 +1102,20 @@ def unique_masks(masks, ref_masks, threshold=0.5, iou_func=mask_iou):
     return umasks
 
 
-def morph_thresh_seg(cnn_outputs,
-                     interior_threshold=0.9,
-                     overlap_threshold=0.9,
-                     bud_threshold=0.9,
-                     bud_dilate=False,
-                     bud_overlap=False,
-                     isbud_threshold=0.5):
-    """Segment cell outlines from morphology output of CNN by thresholding
-
-    Specify `overlap_threshold` or `bud_threshold` as `None` to ignore.
-    """
-
-    _, _, p_interior, p_overlap, _, p_bud = cnn_outputs
-
-    if overlap_threshold is None:
-        p_overlap = None
-
-    if isbud_threshold is None and bud_threshold is not None:
-        p_interior = p_interior * (1 - p_bud)
-
-    masks = morph_thresh_masks(p_interior,
-                               interior_threshold=interior_threshold,
-                               p_overlap=p_overlap,
-                               overlap_threshold=overlap_threshold)
-
-    if bud_threshold is not None:
-        if not bud_overlap:
-            p_overlap = None
-
-        budmasks = morph_thresh_masks(p_bud,
-                                      interior_threshold=bud_threshold,
-                                      dilate=False,
-                                      p_overlap=p_overlap,
-                                      overlap_threshold=overlap_threshold)
-
-        if isbud_threshold is not None:
-            # Omit interior masks if they overlap with bud masks
-            masks = unique_masks(masks,
-                                 budmasks,
-                                 iou_func=mask_containment,
-                                 threshold=isbud_threshold) + budmasks
-
-    # Return only the mask outlines
-    outlines = [minimum_filter(m, footprint=squareconn) != m for m in masks]
-
-    return outlines
-
-
-def get_regions(p_img, threshold):
-    """Find regions in a probability image sorted by likelihood"""
-    p_thresh = p_img > threshold
-    p_label = label(p_thresh, background=0)
-    rprops = regionprops(p_label, p_img)
-    rprops = [
-        r for r in rprops
-        if r.major_axis_length > 0 and r.minor_axis_length > 0
-    ]
-    rprops.sort(key=lambda x: x.mean_intensity, reverse=True)
-    return rprops
-
-
-def bbox_overlaps(regA, regB):
-    """Returns True if the regions have overlapping bounding boxes"""
-    lrA, lcA, urA, ucA = regA.bbox
-    lrB, lcB, urB, ucB = regB.bbox
-    rA = np.array([lrA, urA])
-    cA = np.array([lcA, ucA])
-    return ((not ((rA > urB).all() or (rA < lrB).all())) and
-            (not ((cA > ucB).all() or (cA < lcB).all())))
-
-
-def region_iou(regA, regB):
-    if bbox_overlaps(regA, regB):
-        bb_lr, bb_lc, _, _ = np.stack((regA.bbox, regB.bbox)).min(axis=0)
-        _, _, bb_ur, bb_uc = np.stack((regA.bbox, regB.bbox)).max(axis=0)
-        bboxA = np.zeros((bb_ur - bb_lr, bb_uc - bb_lc), dtype='bool')
-        bboxB = bboxA.copy()
-        bb = regA.bbox
-        bboxA[bb[0] - bb_lr:bb[2] - bb_lr,
-              bb[1] - bb_lc:bb[3] - bb_lc] = regA.image
-        bb = regB.bbox
-        bboxB[bb[0] - bb_lr:bb[2] - bb_lr,
-              bb[1] - bb_lc:bb[3] - bb_lc] = regB.image
-        return np.sum(bboxA & bboxB) / np.sum(bboxA | bboxB)
-    else:
-        return 0.0
-
-
-def morph_ellipse_seg(cnn_outputs,
-                      interior_threshold=0.9,
-                      overlap_threshold=0.9,
-                      bud_threshold=0.9,
-                      bud_dilate=False,
-                      bud_overlap=False,
-                      isbud_threshold=0.5,
-                      scaling=1.0,
-                      offset=0):
-    """Segment cell outlines from morphology output of CNN as region ellipses
-
-    Specify `overlap_threshold` or `bud_threshold` as `None` to ignore.
-    """
-
-    _, _, p_interior, p_overlap, _, p_bud = cnn_outputs
-
-    if overlap_threshold is None:
-        p_overlap = None
-
-    masks = morph_thresh_masks(p_interior,
-                               interior_threshold=interior_threshold,
-                               p_overlap=p_overlap,
-                               overlap_threshold=overlap_threshold)
-
-    if bud_threshold is not None:
-        if not bud_overlap:
-            p_overlap = None
-
-        budmasks = morph_thresh_masks(p_bud,
-                                      interior_threshold=bud_threshold,
-                                      dilate=False,
-                                      p_overlap=p_overlap,
-                                      overlap_threshold=overlap_threshold)
-
-        # Omit interior masks if they overlap with bud masks
-        masks = budmasks + unique_masks(
-            masks, budmasks, threshold=isbud_threshold)
-
-    rprops = [regionprops(m.astype('int'))[0] for m in masks]
-    rprops = [
-        r for r in rprops
-        if r.major_axis_length > 0 and r.minor_axis_length > 0
-    ]
-
-    outlines = []
-    for region in rprops:
-        r, c = np.round(region.centroid).astype('int')
-        r_major = np.round(scaling * region.major_axis_length / 2 +
-                           offset).astype('int')
-        r_minor = np.round(scaling * region.minor_axis_length / 2 +
-                           offset).astype('int')
-        orientation = -region.orientation
-        rr, cc = ellipse_perimeter(r,
-                                   c,
-                                   r_major,
-                                   r_minor,
-                                   orientation=orientation,
-                                   shape=p_interior.shape)
-        outline = np.zeros(p_interior.shape, dtype='bool')
-        outline[rr, cc] = True
-        outlines.append(outline)
-
-    return outlines
-
-
 def get_edge_force(rprop, shape):
     r_major = rprop.major_axis_length / 2
     r_minor = rprop.minor_axis_length / 2
     angle = -rprop.orientation
     nr = shape[0]
     nc = shape[1]
-    xmat = np.matmul(np.arange(0, nr)[:, nax], np.ones((1, nc)))
-    ymat = np.matmul(np.ones((nr, 1)), np.arange(0, nc)[nax, :])
+    xmat = np.matmul(np.arange(0, nr)[:, None], np.ones((1, nc)))
+    ymat = np.matmul(np.ones((nr, 1)), np.arange(0, nc)[None, :])
     xy = np.vstack([np.reshape(xmat, (1, -1)), np.reshape(ymat, (1, -1))])
     rotmat = np.array([[np.cos(angle), -np.sin(angle)],
                        [np.sin(angle), np.cos(angle)]])
     radial_index = np.matmul(
-        rotmat, (xy - np.array(rprop.centroid)[:, nax])) / np.array(
-            [r_major, r_minor])[:, nax]
+        rotmat, (xy - np.array(rprop.centroid)[:, None])) / np.array(
+            [r_major, r_minor])[:, None]
     return np.reshape(1 - np.exp(-np.sum((radial_index)**2, 0)), (nr, nc))
 
 
@@ -337,6 +1171,7 @@ def morph_ac_seg(cnn_outputs,
     edge_thresh = p_edge > ac_edge_threshold
     over_thresh = p_over > ac_overlap_threshold
 
+    from skimage.segmentation import morphological_geodesic_active_contour
     for j, (force, mask) in enumerate(zip(e_forces, masks)):
         ij_edge_im = p_edge.copy()
 
@@ -365,116 +1200,119 @@ def morph_ac_seg(cnn_outputs,
     return outlines
 
 
-def eval_radial_spline(x, rho, phi):
-    """Evaluate a radial spline defined by radii and angles
-    rho: vector of radii for each point defining the spline
-    phi: angles in [-pi,pi) defining points of the spline
-    The spline is periodic across the boundary
+def morph_ellipse_seg(cnn_outputs,
+                      interior_threshold=0.9,
+                      overlap_threshold=0.9,
+                      bud_threshold=0.9,
+                      bud_dilate=False,
+                      bud_overlap=False,
+                      isbud_threshold=0.5,
+                      scaling=1.0,
+                      offset=0):
+    """Segment cell outlines from morphology output of CNN as region ellipses
+
+    Specify `overlap_threshold` or `bud_threshold` as `None` to ignore.
     """
 
-    # Angles need to be in increasing order to correctly loop over boundary
-    order = np.argsort(phi)
-    rho = rho[order]
-    phi = phi[order]
-    offset = phi[0]
+    _, _, p_interior, p_overlap, _, p_bud = cnn_outputs
 
-    # Make the boundaries periodic
-    rho = np.concatenate((rho, rho[0, nax]))
-    phi = np.concatenate((phi - offset, (2 * np.pi,)))
+    if overlap_threshold is None:
+        p_overlap = None
 
-    tck = interpolate.splrep(phi, rho, per=True)
-    try:
-        return interpolate.splev(np.mod(x - offset, 2 * np.pi), tck)
-    except ValueError as err:
-        print('x:')
-        print(x)
-        print('rho:')
-        print(rho)
-        print('phi:')
-        print(phi)
-        raise err
-
-
-def morph_radial_thresh_fit(outline, mask=None, rprops=None):
-    if mask is None and rprops is None:
-        mask = binary_fill_holes(outline)
+    masks = morph_thresh_masks(p_interior,
+                               interior_threshold=interior_threshold,
+                               p_overlap=p_overlap,
+                               overlap_threshold=overlap_threshold)
 
-    if rprops is None:
-        rprops = regionprops(mask.astype('int'))[0]
+    if bud_threshold is not None:
+        if not bud_overlap:
+            p_overlap = None
 
-    r_maj = rprops.major_axis_length
-    nrays = 4 if r_maj < 5 else 6 if r_maj < 20 else 8
+        budmasks = morph_thresh_masks(p_bud,
+                                      interior_threshold=bud_threshold,
+                                      dilate=False,
+                                      p_overlap=p_overlap,
+                                      overlap_threshold=overlap_threshold)
 
-    RL, CL, RU, CU = rprops.bbox
-    bbdiag = np.sqrt((RU - RL)**2 + (CU - CL)**2)
-    rr_max, cc_max = outline.shape
+        # Omit interior masks if they overlap with bud masks
+        masks = budmasks + unique_masks(
+            masks, budmasks, threshold=isbud_threshold)
+
+    rprops = [regionprops(m.astype('int'))[0] for m in masks]
+    rprops = [
+        r for r in rprops
+        if r.major_axis_length > 0 and r.minor_axis_length > 0
+    ]
+
+    outlines = []
+
+    from skimage.draw import ellipse_perimeter
+    for region in rprops:
+        r, c = np.round(region.centroid).astype('int')
+        r_major = np.round(scaling * region.major_axis_length / 2 +
+                           offset).astype('int')
+        r_minor = np.round(scaling * region.minor_axis_length / 2 +
+                           offset).astype('int')
+        orientation = -region.orientation
+        rr, cc = ellipse_perimeter(r,
+                                   c,
+                                   r_major,
+                                   r_minor,
+                                   orientation=orientation,
+                                   shape=p_interior.shape)
+        outline = np.zeros(p_interior.shape, dtype='bool')
+        outline[rr, cc] = True
+        outlines.append(outline)
+
+    return outlines
 
-    astep = 2 * np.pi / nrays
-    angles = np.mod(rprops.orientation + np.pi, astep) + \
-        np.arange(nrays)*astep - np.pi
-    centre = np.array(rprops.centroid)
 
-    # Improve accuracy of edge position by smoothing the outline image and using
-    # weighted averaging of pixel positions below:
-    outline = filters.gaussian(outline, 0.5)
+def morph_thresh_seg(cnn_outputs,
+                     interior_threshold=0.9,
+                     overlap_threshold=0.9,
+                     bud_threshold=0.9,
+                     bud_dilate=False,
+                     bud_overlap=False,
+                     isbud_threshold=0.5):
+    """Segment cell outlines from morphology output of CNN by thresholding
 
-    radii = []
-    for angle in angles:
-        ray = np.matmul(0.5 * np.arange(np.round(2 * bbdiag))[:, nax],
-                        np.array((np.cos(angle), np.sin(angle)))[nax, :])
-        ray = np.round(centre + ray).astype('int')
-        rr, cc = (ray[:, 0], ray[:, 1])
-        ray = ray[(rr >= 0) & (rr < rr_max) & (cc >= 0) & (cc < cc_max), :]
+    Specify `overlap_threshold` or `bud_threshold` as `None` to ignore.
+    """
 
-        edge_pix = np.flatnonzero(
-            np.squeeze(outline[ray[:, 0], ray[:, 1]]) > 0.01)
+    _, _, p_interior, p_overlap, _, p_bud = cnn_outputs
 
-        if len(edge_pix) == 0:
-            radii.append(np.NaN)
-            continue
+    if overlap_threshold is None:
+        p_overlap = None
 
-        ray = ray[edge_pix, :]
-        edge_pix = np.average(ray,
-                              weights=outline[ray[:, 0], ray[:, 1]],
-                              axis=0)
-        radii.append(np.sqrt(np.sum((edge_pix - centre)**2)))
+    if isbud_threshold is None and bud_threshold is not None:
+        p_interior = p_interior * (1 - p_bud)
 
-    radii = np.array(radii)
+    masks = morph_thresh_masks(p_interior,
+                               interior_threshold=interior_threshold,
+                               p_overlap=p_overlap,
+                               overlap_threshold=overlap_threshold)
 
-    # Use linear interpolation for any missing radii (e.g., if region intersects
-    # with image boundary):
-    nanradii = np.isnan(radii)
-    if nanradii.all():
-        radii = 0.1 * np.ones(angles.shape)
-    elif nanradii.any():
-        radii = np.interp(angles,
-                          angles[~nanradii],
-                          radii[~nanradii],
-                          period=2 * np.pi)
+    if bud_threshold is not None:
+        if not bud_overlap:
+            p_overlap = None
 
-    return radii, angles
+        budmasks = morph_thresh_masks(p_bud,
+                                      interior_threshold=bud_threshold,
+                                      dilate=False,
+                                      p_overlap=p_overlap,
+                                      overlap_threshold=overlap_threshold)
 
+        if isbud_threshold is not None:
+            # Omit interior masks if they overlap with bud masks
+            masks = unique_masks(masks,
+                                 budmasks,
+                                 iou_func=mask_containment,
+                                 threshold=isbud_threshold) + budmasks
 
-def draw_radial(radii, angles, centre, shape):
-    mr, mc = shape
-    im = np.zeros(shape, dtype='bool')
-    neval = np.round(4 * np.pi * np.max(radii)).astype('int')
-    if neval > 1:
-        phi = np.linspace(0, 2 * np.pi, neval)
-        rho = eval_radial_spline(phi, radii, angles)
-    else:
-        phi = 0
-        rho = 0
-    rr = np.round(centre[0] + rho * np.cos(phi)).astype('int')
-    cc = np.round(centre[1] + rho * np.sin(phi)).astype('int')
-    rr[rr < 0] = 0
-    rr[rr >= mr] = mr - 1
-    cc[cc < 0] = 0
-    cc[cc >= mc] = mc - 1
-    im[rr, cc] = True
-    # valid = (rr >= 0) & (cc >= 0) & (rr < mr) & (cc < mc)
-    # im[rr[valid], cc[valid]] = True
-    return im
+    # Return only the mask outlines
+    outlines = [minimum_filter(m, footprint=squareconn) != m for m in masks]
+
+    return outlines
 
 
 def morph_radial_thresh_seg(cnn_outputs,
@@ -532,7 +1370,7 @@ def morph_radial_thresh_seg(cnn_outputs,
 
     outlines = []
     for mask, outline, rp in zip(masks, mseg, rprops):
-        radii, angles = morph_radial_thresh_fit(outline, mask, rp)
+        radii, angles = guess_radial_edge(outline, mask, rp)
         if np.any(np.isnan(radii)):
             return mask, outline, rp
         outlines.append(draw_radial(radii, angles, rp.centroid, shape))
@@ -540,72 +1378,6 @@ def morph_radial_thresh_seg(cnn_outputs,
     return outlines
 
 
-def thresh_seg(p_int,
-               interior_threshold=0.5,
-               connectivity=None,
-               nclosing=0,
-               nopening=0,
-               ndilate=0,
-               return_area=False):
-    """Segment cell outlines from morphology output of CNN by fitting radial
-    spline to threshold output
-    """
-
-    lbl, nmasks = label(p_int > interior_threshold,
-                        return_num=True,
-                        connectivity=connectivity)
-    for l in range(nmasks):
-        mask = lbl == l + 1
-        if nclosing > 0:
-            mask = binary_closing(mask, iterations=nclosing)
-        if nopening > 0:
-            mask = binary_opening(mask, iterations=nopening)
-        if ndilate > 0:
-            mask = binary_dilation(mask, iterations=ndilate)
-
-        if return_area:
-            yield mask, mask.sum()
-        else:
-            yield mask
-
-
-def single_region_prop(mask):
-    return regionprops(mask.astype('int'))[0]
-
-
-def outline_to_radial(outline, rprop, return_outline=False):
-    coords = (rprop.centroid,) + morph_radial_thresh_fit(outline, None, rprop)
-    if return_outline:
-        centroid, radii, angles = coords
-        outlines = draw_radial(radii, angles, centroid, outline.shape)
-        return coords, outlines
-    else:
-        return coords
-
-
-def get_edge_scores(outlines, p_edge):
-    return [
-        (p_edge * binary_dilation(o, iterations=2)).mean() for o in outlines
-    ]
-
-
-def iterative_erosion(img, iterations=1, **kwargs):
-    if iterations < 1:
-        return img
-
-    for i in range(iterations):
-        img = erosion(img, **kwargs)
-    return img
-
-
-def iterative_dilation(img, iterations=1, **kwargs):
-    if iterations is None:
-        return img
-    for _ in range(iterations):
-        img = dilation(img, **kwargs)
-    return img
-
-
 def morph_seg_grouped(pred,
                       flattener,
                       cellgroups=['large', 'medium', 'small'],
@@ -626,6 +1398,8 @@ def morph_seg_grouped(pred,
                       return_coords=False):
     """Morphological segmentation for model predictions of flattener targets
 
+    DEPRECATED. Use the morph_thresh_seg.MorphSegGrouped class instead.
+
     :param pred: list of prediction images (ndarray with shape (x, y))
         matching `flattener.names()`
     :param flattener: an instance of `SegmentationFlattening` defining the
@@ -737,13 +1511,10 @@ def morph_seg_grouped(pred,
 
         masks_areas = [
             (m, a)
-            for m, a in thresh_seg(p_int,
-                                   interior_threshold=thresh or 0.5,
-                                   nclosing=nc or 0,
-                                   nopening=no or 0,
-                                   ndilate=max_ne,
-                                   return_area=True,
-                                   connectivity=conn)
+            for m, a in threshold_segmentation(
+                p_int, interior_threshold=thresh or 0.5,
+                nclosing=nc or 0, nopening=no or 0,
+                ndilate=max_ne, return_area=True, connectivity=conn)
             if a >= lower and a < upper
         ]
 
@@ -752,17 +1523,15 @@ def morph_seg_grouped(pred,
         else:
             masks, areas = [], []
 
-        edges = [binary_edge(m) for m in masks]
-
         if fit_radial:
-            rprops = [single_region_prop(m) for m in masks]
             coords, edges = list(
                 zip(*[
-                    outline_to_radial(edge, rprop, return_outline=True)
-                    for edge, rprop in zip(edges, rprops)
+                    mask_to_knots(mask)
+                    for mask in masks
                 ])) or ([], [])
             masks = [binary_fill_holes(o) for o in edges]
         else:
+            edges = [binary_edge(m) for m in masks]
             edges = [e | (border_rect & m) for e, m in zip(edges, masks)]
             coords = [tuple()] * len(masks)
 
@@ -847,144 +1616,6 @@ def morph_seg_grouped(pred,
         return edges
 
 
-def rc_to_radial(rr_cc, centre):
-    """Helper function to convert row-column coords to radial coords
-    """
-    rr, cc = rr_cc
-    rloc, cloc = centre
-    rr = rr - rloc
-    cc = cc - cloc
-    return np.sqrt(rr**2 + cc**2), np.arctan2(cc, rr)
-
-
-def prior_resid_weight(resid, gauss_scale=5, exp_scale=1):
-    """Weights radial residuals to bias towards the initial guess
-
-    Weight decays as a gaussian for positive residuals and exponentially for
-    negative residuals. So assuming `resid = rho_guess - rho_initial`, then
-    larger radii are favoured.
-    """
-    return (resid >= 0) * np.exp(-resid**2 / gauss_scale) \
-            + (resid < 0) * np.exp(resid / exp_scale)
-
-
-def adj_rspline_coords(adj, ref_radii, ref_angles):
-    """Map optimisation-space radial spline params to standard values
-
-    Params in optimisation space are specified relative to reference radii and
-    reference angles. If constrained to [-1, 1], optimisation parameters will
-    allow a 30% change in radius, or change in angle up to 1/4 the distance
-    between consecutive angles.
-    """
-    npoints = len(ref_radii)
-    return (
-        # allow up to 30% change in radius
-        ref_radii * (1 + 0.3 * adj[:npoints]),
-        # allow changes in angle up to 1/4 the distance between points
-        ref_angles + adj[npoints:] * np.pi / (2 * npoints))
-
-
-def adj_rspline_resid(adj, rho, phi, probs, ref_radii, ref_angles):
-    """Weighted residual for radial spline optimisation
-
-    Optimisation params (`adj`) are mapped according to `adj_rspline_coords`.
-    Target points are given in radial coordinates `rho` and `phi` with weights
-    `probs`. Optimisation is defined relative to radial spline params
-    `ref_radii` and `ref_angles`.
-    """
-    radii, angles = adj_rspline_coords(adj, ref_radii, ref_angles)
-    return probs * (rho - eval_radial_spline(phi, radii, angles))
-
-
-def refine_radial_grouped(grouped_coords, grouped_p_edges):
-    """Refine initial radial spline by optimising to predicted edge
-
-    Neighbouring groups are used to re-weight predicted edges belonging to
-    other cells using the initial guess
-    """
-
-    # Determine edge pixel locations and probabilities from NN prediction
-    p_edge_locs = [np.where(p_edge > 0.2) for p_edge in grouped_p_edges]
-    p_edge_probs = [
-        p_edge[rr, cc]
-        for p_edge, (rr, cc) in zip(grouped_p_edges, p_edge_locs)
-    ]
-
-    p_edge_count = [len(rr) for rr, _ in p_edge_locs]
-
-    opt_coords = []
-    ngroups = len(grouped_coords)
-    for g, g_coords in enumerate(grouped_coords):
-        # If this group has no predicted edges, keep initial and skip
-        if p_edge_count[g] == 0:
-            opt_coords.append(g_coords)
-            continue
-
-        # Compile a list of all cells in this and neighbouring groups
-        nbhd = list(
-            chain.from_iterable([
-                [((gi, ci), coords)
-                 for ci, coords in enumerate(grouped_coords[gi])]
-                for gi in range(max(g - 1, 0), min(g + 2, ngroups))
-                if
-                p_edge_count[gi] > 0  # only keep if there are predicted edges
-            ]))
-        if len(nbhd) > 0:
-            nbhd_ids, nbhd_coords = zip(*nbhd)
-        else:
-            nbhd_ids, nbhd_coords = 2 * [[]]
-
-        # Calculate edge pixels in radial coords for all cells in this and
-        # neighbouring groups:
-        radial_edges = [
-            rc_to_radial(p_edge_locs[g], centre)
-            for centre, _, _ in nbhd_coords
-        ]
-
-        # Calculate initial residuals and prior weights
-        resids = [
-            rho - eval_radial_spline(phi, radii, angles)
-            for (rho, phi), (_, radii,
-                             angles) in zip(radial_edges, nbhd_coords)
-        ]
-        indep_weights = [prior_resid_weight(r) for r in resids]
-
-        probs = p_edge_probs[g]
-
-        g_opt_coords = []
-        for c, (centre, radii, angles) in enumerate(g_coords):
-            ind = nbhd_ids.index((g, c))
-            rho, phi = radial_edges[ind]
-            p_weighted = probs * indep_weights[ind]
-            other_weights = indep_weights[:ind] + indep_weights[ind + 1:]
-            if len(other_weights) > 0:
-                p_weighted *= (1 - np.mean(other_weights, axis=0))
-
-            # Remove insignificant fit data
-            signif = p_weighted > 0.1
-            if signif.sum() < 10:
-                # With insufficient data, skip optimisation
-                g_opt_coords.append((centre, radii, angles))
-                continue
-            p_weighted = p_weighted[signif]
-            phi = phi[signif]
-            rho = rho[signif]
-
-            nparams = len(radii) + len(angles)
-            opt = least_squares(adj_rspline_resid,
-                                np.zeros(nparams),
-                                bounds=(-np.ones(nparams), np.ones(nparams)),
-                                args=(rho, phi, p_weighted, radii, angles),
-                                ftol=5e-2)
-
-            g_opt_coords.append((centre,) +
-                                adj_rspline_coords(opt.x, radii, angles))
-
-        opt_coords.append(g_opt_coords)
-
-    return opt_coords
-
-
 def morph_radial_edge_seg(cnn_outputs):
     RL, CL, RU, CU = rp.bbox
     Rext = np.ceil(0.25 * (RU - RL)).astype('int')
diff --git a/python/baby/server.py b/python/baby/server.py
index 3b717ecf53db94084a8cccf8e6d37e0b31155c51..740c87bdaf17d946ca016420559f24f42298c521 100644
--- a/python/baby/server.py
+++ b/python/baby/server.py
@@ -2,13 +2,15 @@
 
 # If you publish results that make use of this software or the Birth Annotator
 # for Budding Yeast algorithm, please cite:
-# Julian M J Pietsch, Alán Muñoz, Diane Adjavon, Ivan B N Clark, Peter S
-# Swain, 2021, Birth Annotator for Budding Yeast (in preparation).
+# Pietsch, J.M.J., Muñoz, A.F., Adjavon, D.-Y.A., Farquhar, I., Clark, I.B.N.,
+# and Swain, P.S. (2023). Determining growth rates from bright-field images of
+# budding cells through identifying overlaps. eLife. 12:e79812.
+# https://doi.org/10.7554/eLife.79812
 # 
 # 
 # The MIT License (MIT)
 # 
-# Copyright (c) Julian Pietsch, Alán Muñoz and Diane Adjavon 2021
+# Copyright (c) Julian Pietsch, Alán Muñoz and Diane Adjavon 2023
 # 
 # Permission is hereby granted, free of charge, to any person obtaining a copy
 # of this software and associated documentation files (the "Software"), to
@@ -45,6 +47,8 @@ from functools import reduce
 from operator import mul
 import numpy as np
 
+from baby import modelsets
+from baby import __version__
 from baby.brain import BabyBrain
 from baby.crawler import BabyCrawler
 from baby.utils import jsonify
@@ -61,7 +65,7 @@ SERVER_DIR = dirname(__file__)
 MAX_RUNNERS = 3
 MAX_SESSIONS = 20
 SLEEP_TIME = 0.2  # time between threaded checks for data availability
-MAX_ATTEMPTS = 300  # allows for 60s delay before timing out
+MAX_ATTEMPTS = 3000  # allows for a 10 min delay before timing out
 MAX_IMG_SIZE = 100 * 1024 * 1024  # allows for raw image sizes up to 100 MB
 
 DIMS_ERROR_MSG = '"dims" must be a length 4 integer array: [ntraps, width, height, depth]'
@@ -90,7 +94,7 @@ class PredMissingError(Exception):
 
 
 class TaskMaster(object):
-    def __init__(self):
+    def __init__(self, njobs=-2):
         self._lock = threading.Lock()
 
         self._runner_pool = []
@@ -100,12 +104,21 @@ class TaskMaster(object):
         self.tf_session = None
         self.tf_graph = None
         self.tf_version = (0, 0, 0)
+        self.njobs = njobs
+        self._modelsets_ids = None
+        self._modelsets_meta = None
 
     @property
-    def modelsets(self):
-        with open(join(SERVER_DIR, 'modelsets.json'), 'rt') as f:
-            modelsets = json.load(f)
-        return modelsets
+    def modelset_ids(self):
+        if self._modelsets_ids is None:
+            self._modelsets_ids = modelsets.ids()
+        return self._modelsets_ids
+
+    @property
+    def modelsets_meta(self):
+        if self._modelsets_meta is None:
+            self._modelsets_meta = modelsets.meta()
+        return self._modelsets_meta
 
     def new_session(self, model_name):
         # Clean up old sessions that exceed the maximum allowed number
@@ -125,16 +138,11 @@ class TaskMaster(object):
 
         return sessionid
 
-    def ensure_runner(self, model_name, modelsets=None):
+    def ensure_runner(self, model_name):
         if model_name in self.runners:
             print('Model "{}" already loaded. Skipping...'.format(model_name))
             return
 
-        if modelsets is None:
-            modelsets = self.modelsets
-
-        assert model_name in modelsets
-
         # Clean up old runners that exceed the maximum allowed number
         nrunners = len(self._runner_pool)
         with self._lock:
@@ -166,9 +174,9 @@ class TaskMaster(object):
         # Load BabyRunner
         print('Starting new runner for model "{}"...'.format(model_name))
 
-        baby = BabyBrain(**modelsets[model_name],
-                         session=self.tf_session, graph=self.tf_graph,
-                         suppress_errors=True, error_dump_dir=ERR_DUMP_DIR)
+        baby = modelsets.get(model_name, session=self.tf_session,
+                             graph=self.tf_graph, suppress_errors=True,
+                             error_dump_dir=ERR_DUMP_DIR)
 
         if self.runners.get(model_name) == 'pending':
             with self._lock:
@@ -214,7 +222,7 @@ class TaskMaster(object):
         #     tf.keras.backend.set_session(self.tf_session)
 
         t_start = time.perf_counter()
-        pred = crawler.step(img, parallel=True, **kwargs)
+        pred = crawler.step(img, parallel=True, njobs=self.njobs, **kwargs)
         t_elapsed = time.perf_counter() - t_start
 
         print('...images segmented in {:.3f} seconds.'.format(t_elapsed))
@@ -243,28 +251,33 @@ class TaskMaster(object):
 
 @routes.get('/')
 async def version(request):
-    return web.json_response({'baby': 'v1.0'})
+    return web.json_response({'baby': __version__})
 
 
 @routes.get('/models')
 async def get_modelsets(request):
     taskmstr = request.app['TaskMaster']
-    return web.json_response(list(taskmstr.modelsets.keys()))
+    with_meta = False
+    if 'meta' in request.query:
+        with_meta = request.query['meta'] == 'true'
+    if with_meta:
+        return web.json_response(taskmstr.modelsets_meta)
+    else:
+        return web.json_response(taskmstr.modelset_ids)
 
 
 @routes.get('/session/{model}')
 async def get_session(request):
     model_name = request.match_info['model']
     taskmstr = request.app['TaskMaster']
-    modelsets = taskmstr.modelsets
 
-    if model_name not in modelsets:
+    if model_name not in taskmstr.modelset_ids:
         raise web.HTTPNotFound(text='"{}" model is unknown'.format(model_name))
 
     # Ensure model is loaded in another thread
     loop = asyncio.get_event_loop()
     loop.run_in_executor(request.app['Executor'],
-                         taskmstr.ensure_runner, model_name, modelsets)
+                         taskmstr.ensure_runner, model_name)
     sessionid = taskmstr.new_session(model_name)
 
     print('Creating new session "{}" with model "{}"...'.format(
@@ -383,13 +396,6 @@ async def segment(request):
             val = await field.read(decode=True)
             kwargs[field.name] = json.loads(val)
 
-    if request.query.get('test', False):
-        print('Data received. Writing test image to "baby-server-test.png"...')
-        from imageio import imwrite
-        imwrite(join(SERVER_DIR, 'baby-server-test.png'),
-                np.squeeze(img[0, :, :, 0]))
-        return web.json_response({'status': 'test image written'})
-
     print('Data received. Segmenting {} images...'.format(len(img)))
 
     loop.run_in_executor(executor, taskmstr.segment, sessionid, img, kwargs)
@@ -441,12 +447,21 @@ async def get_segmentation(request):
 
 app = web.Application()
 app.add_routes(routes)
-app['TaskMaster'] = TaskMaster()
-app['Executor'] = ThreadPoolExecutor(2)
 
 def main():
+    from argparse import ArgumentParser
+    parser = ArgumentParser('Start BABY server for receiving segmentation requests')
+    parser.add_argument('-p', '--port', type=int, default=5101,
+                        help='port to bind the server to')
+    parser.add_argument('-n', '--njobs', type=int, default=-2,
+                        help='number of parallel jobs to run when processing')
+    args = parser.parse_args()
+ 
     import tensorflow as tf
 
+    app['TaskMaster'] = TaskMaster(njobs=args.njobs)
+    app['Executor'] = ThreadPoolExecutor(2)
+
     tf_version = tuple(int(v) for v in tf.version.VERSION.split('.'))
     app['TaskMaster'].tf_version = tf_version
 
@@ -486,4 +501,4 @@ def main():
         lfh.setFormatter(lff)
         logging.getLogger().addHandler(lfh)
 
-    web.run_app(app, port=5101)
+    web.run_app(app, port=args.port)
diff --git a/python/baby/speed_tests.py b/python/baby/speed_tests.py
index e4adf6910eeced9d4726a1b5905b0ab9feee1f62..2256e64b078cdb6c5ccf729ca84e57b7f5c15689 100644
--- a/python/baby/speed_tests.py
+++ b/python/baby/speed_tests.py
@@ -2,13 +2,15 @@
 
 # If you publish results that make use of this software or the Birth Annotator
 # for Budding Yeast algorithm, please cite:
-# Julian M J Pietsch, Alán Muñoz, Diane Adjavon, Ivan B N Clark, Peter S
-# Swain, 2021, Birth Annotator for Budding Yeast (in preparation).
+# Pietsch, J.M.J., Muñoz, A.F., Adjavon, D.-Y.A., Farquhar, I., Clark, I.B.N.,
+# and Swain, P.S. (2023). Determining growth rates from bright-field images of
+# budding cells through identifying overlaps. eLife. 12:e79812.
+# https://doi.org/10.7554/eLife.79812
 # 
 # 
 # The MIT License (MIT)
 # 
-# Copyright (c) Julian Pietsch, Alán Muñoz and Diane Adjavon 2021
+# Copyright (c) Julian Pietsch, Alán Muñoz and Diane Adjavon 2023
 # 
 # Permission is hereby granted, free of charge, to any person obtaining a copy
 # of this software and associated documentation files (the "Software"), to
diff --git a/python/baby/tracker/benchmark.py b/python/baby/tracker/benchmark.py
index a701e34eac5f0ce6c1fc05717e1c6244da950fcb..c2a6e25df6698c4c0f136e36789747c0e0054b10 100644
--- a/python/baby/tracker/benchmark.py
+++ b/python/baby/tracker/benchmark.py
@@ -1,12 +1,14 @@
 # If you publish results that make use of this software or the Birth Annotator
 # for Budding Yeast algorithm, please cite:
-# Julian M J Pietsch, Alán Muñoz, Diane Adjavon, Ivan B N Clark, Peter S
-# Swain, 2021, Birth Annotator for Budding Yeast (in preparation).
+# Pietsch, J.M.J., Muñoz, A.F., Adjavon, D.-Y.A., Farquhar, I., Clark, I.B.N.,
+# and Swain, P.S. (2023). Determining growth rates from bright-field images of
+# budding cells through identifying overlaps. eLife. 12:e79812.
+# https://doi.org/10.7554/eLife.79812
 # 
 # 
 # The MIT License (MIT)
 # 
-# Copyright (c) Julian Pietsch, Alán Muñoz and Diane Adjavon 2021
+# Copyright (c) Julian Pietsch, Alán Muñoz and Diane Adjavon 2023
 # 
 # Permission is hereby granted, free of charge, to any person obtaining a copy
 # of this software and associated documentation files (the "Software"), to
diff --git a/python/baby/tracker/core.py b/python/baby/tracker/core.py
index 50f629982f62e497ea543a254ceb12ea9e25334c..4402057589613e085fe553edbff3520a76623d9b 100644
--- a/python/baby/tracker/core.py
+++ b/python/baby/tracker/core.py
@@ -1,14 +1,14 @@
-#!/usr/bin/env python
-
 # If you publish results that make use of this software or the Birth Annotator
 # for Budding Yeast algorithm, please cite:
-# Julian M J Pietsch, Alán Muñoz, Diane Adjavon, Ivan B N Clark, Peter S
-# Swain, 2021, Birth Annotator for Budding Yeast (in preparation).
+# Pietsch, J.M.J., Muñoz, A.F., Adjavon, D.-Y.A., Farquhar, I., Clark, I.B.N.,
+# and Swain, P.S. (2023). Determining growth rates from bright-field images of
+# budding cells through identifying overlaps. eLife. 12:e79812.
+# https://doi.org/10.7554/eLife.79812
 # 
 # 
 # The MIT License (MIT)
 # 
-# Copyright (c) Julian Pietsch, Alán Muñoz and Diane Adjavon 2021
+# Copyright (c) Julian Pietsch, Alán Muñoz and Diane Adjavon 2023
 # 
 # Permission is hereby granted, free of charge, to any person obtaining a copy
 # of this software and associated documentation files (the "Software"), to
@@ -27,6 +27,8 @@
 # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
 # FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS
 # IN THE SOFTWARE.
+# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS
+# IN THE SOFTWARE.
 
 '''
 TrackerCoordinator class to coordinate cell tracking and bud assignment.
@@ -43,8 +45,11 @@ from sklearn.ensemble import RandomForestClassifier
 from sklearn.svm import SVC
 from baby.errors import BadOutput
 from baby.tracker.utils import calc_barycentre, pick_baryfun
+from baby import modelsets
+
+
+DEFAULT_MODELSET = 'yeast-alcatras-brightfield-EMCCD-60x-5z'
 
-models_path = join(dirname(__file__), '../models')
 
 class FeatureCalculator:
     '''
@@ -94,12 +99,15 @@ class FeatureCalculator:
         if 'area' in self.outfeats:
             self.aind = self.outfeats.index('area')
 
-    def load_model(self, path, fname):
+    def load_model(self, fname, path=None):
+        if path is not None:
             model_file = join(path, fname)
-            with open(model_file, 'rb') as file_to_load:
-                model = pickle.load(file_to_load)
+        else:
+            model_file = fname
+        with open(model_file, 'rb') as file_to_load:
+            model = pickle.load(file_to_load)
 
-            return model
+        return model
 
     def calc_feats_from_mask(self, masks, feats2use=None, trapfeats=None,
                              norm=True, px_size=None):
@@ -216,6 +224,7 @@ class FeatureCalculator:
 
         return feats
 
+
 class CellTracker(FeatureCalculator):
     '''
     Class used to manage cell tracking. You can call it using an existing model or
@@ -264,24 +273,26 @@ class CellTracker(FeatureCalculator):
         if extra_feats is None:
             extra_feats = ()
 
-        if type(model) is str or type(model) is PosixPath: 
-            with open(Path(model), 'rb') as f:
-                model = pickle.load(f)
+        if isinstance(model, (Path, str)): 
+            model = self.load_model(model)
 
-        if type(bak_model) is str or type(bak_model) is PosixPath: 
-            with open(Path(bak_model), 'rb') as f:
-                bak_model = pickle.load(f)
+        if isinstance(bak_model, (Path, str)):
+            bak_model = self.load_model(bak_model)
 
         if aweights is None:
             self.aweights = None
 
         if feats2use is None: # Ignore this block when training
             if model is None:
-                model = self.load_model( models_path,
-                                         'ct_rf_20210201_12.pkl')
+                default_params = modelsets.get_params(DEFAULT_MODELSET)
+                model_file = default_params['celltrack_model_file']
+                model_file = modelsets.resolve(model_file, DEFAULT_MODELSET)
+                model = self.load_model(model_file)
             if bak_model is None:
-                bak_model = self.load_model( models_path,
-                                         'ct_rf_20210125_9.pkl')
+                default_params = modelsets.get_params(DEFAULT_MODELSET)
+                bak_file = default_params['celltrack_backup_model_file']
+                bak_file = modelsets.resolve(bak_file, DEFAULT_MODELSET)
+                bak_model = self.load_model(bak_file)
             self.model = model
             self.bak_model = bak_model
 
@@ -542,17 +553,20 @@ class CellTracker(FeatureCalculator):
             return ([], [], max_lbl)
         return (new_lbls, new_feats, new_max)
 
+
 class BudTracker(FeatureCalculator):
     def __init__(self,
                  model=None,
                  feats2use=None,
                  **kwargs):
 
-        if model is None:
-            model_file = join(models_path,
-                                      'mb_model_20201022.pkl')
-            with open(model_file, 'rb') as file_to_load:
-                model = pickle.load(file_to_load)
+        if isinstance(model, (Path, str)):
+            model = self.load_model(model)
+        elif model is None:
+            default_params = modelsets.get_params(DEFAULT_MODELSET)
+            model_file = default_params['budassign_model_file']
+            model_file = modelsets.resolve(model_file, DEFAULT_MODELSET)
+            model = self.load_model(model_file)
         self.model = model
 
         if feats2use is None:
@@ -718,6 +732,7 @@ class BudTracker(FeatureCalculator):
 
         return r_points
 
+
 class MasterTracker(FeatureCalculator):
     '''
     Coordinates the data transmission from CellTracker to BudTracker to
@@ -842,6 +857,8 @@ class MasterTracker(FeatureCalculator):
                 p_budneck, p_bud, masks, feats[:, self.bt_idx])
             lblinds = np.array(new_lbls) - 1  # new_lbls are indexed from 1
             lifetime[lblinds] += 1
+            # TODO: the following may lead to values of p_is_mother higher
+            # than one, and hence negative ba_probs below
             p_is_mother[lblinds] = np.maximum(p_is_mother[lblinds],
                                               ba_probs.sum(1))
             p_was_bud[lblinds] = np.maximum(p_was_bud[lblinds],
@@ -875,6 +892,9 @@ class MasterTracker(FeatureCalculator):
         if assign_mothers:
             if max_lbl > 0:
                 # Calculate mother assignments for this time point
+                # TODO: this should proceed more like the IoU assignment
+                # algorithm to guarantee that two buds will not be assigned to
+                # the same mother
                 ma = ba_cum[0:max_lbl, 0:max_lbl].argmax(0) + 1
                 # Cell must have been a bud and been present for at least
                 # min_bud_tps
@@ -895,6 +915,163 @@ class MasterTracker(FeatureCalculator):
         return output
 
 
+class MMTracker(FeatureCalculator):
+    '''
+    Tracker specialised for mother machines. Does not use CellTracker or
+    BudTracker, but a custom algorithm adapted from Sean Murray's group (Robin
+    Koehler's code)
+
+    input
+    :growthrate: log growth factor per step; summed major axis lengths from
+        the previous time point are scaled by exp(growthrate) when matching
+    :tol_imbalance: maximum absolute imbalance tolerated for a match
+    :nstepsback: number of earlier time points retained when trimming state
+    :**kwargs: additional arguments passed to FeatureCalculator constructor
+    '''
+    def __init__(self, growthrate=0.1, tol_imbalance=20, nstepsback=1, **kwargs):
+        self.nstepsback = nstepsback
+        self.growthrate = growthrate
+        self.tol_imbalance = tol_imbalance
+        feats2use = ('centroid', 'major_axis_length', 'minor_axis_length')
+        super().__init__(feats2use, **kwargs)
+
+    def step_trackers(self,
+                      masks,
+                      p_budneck=None,
+                      p_bud=None,
+                      state=None,
+                      assign_mothers=False,
+                      return_baprobs=False,
+                      keep_full_state=False):
+        '''
+        Calculate features and track cells and bud assignments
+
+        For compatibility with the existing MasterTracker, we retain some
+        unused input arguments.
+
+        :masks: 3d ndarray (ncells, size_x, size_y) containing cell masks
+        :p_budneck: NOT REQUIRED
+        :p_bud: NOT REQUIRED
+        :state: running state for the tracker, or None for initialisation
+        :assign_mothers: whether to include mother assignments in the returned
+            output
+        :return_baprobs: whether to include bud assignment probability matrix
+            in the returned output
+
+        returns a dict consisting of
+
+        :cell_label: list of int, the tracked global ID for each cell mask
+        :state: the updated state to be used in a subsequent step
+        :mother_assign: (optional) list of int, specifying the assigned mother
+            for each cell label issued so far (0 where no mother is recorded)
+        :p_bud_assign: (optional) matrix (list of lists of floats); always all
+            zeros for this tracker, which does not compute probabilities
+        '''
+
+        gr = np.exp(self.growthrate)
+        tol = self.tol_imbalance
+
+        if state is None:
+            state = {}
+
+        new_lbl = state.get('new_lbl', 1)
+        cell_lbls = state.get('cell_lbls', [[]])
+        prev_feats = state.get('prev_feats', [np.zeros((0, self.ntfeats))])
+        mothers = state.get('mothers', [])
+
+        # Get features for cells at this time point
+        feats = self.calc_feats_from_mask(masks, norm=False)
+
+        N_prev = prev_feats[-1].shape[0]
+        N_this = feats.shape[0]
+
+        prev_lbls = np.array(cell_lbls[-1])
+        this_lbls = -np.ones(N_this, dtype='int64')
+        Ci = self.xind
+        cInd_prev = np.argsort(-prev_feats[-1][:,Ci])
+        cInd_this = np.argsort(-feats[:,Ci])
+        Wi = self.outfeats.index('major_axis_length')
+        Ws_prev = prev_feats[-1][:,Wi][cInd_prev]
+        Ws_this = feats[:,Wi][cInd_this]
+        c_prev, c_this = 0, 0
+        aggW_prev, aggW_this = 0, 0
+        aggC_prev, aggC_this = [], []
+
+        while True:
+            if aggW_prev < aggW_this:
+                if c_prev >= N_prev:
+                    break
+                aggW_prev += Ws_prev[c_prev]
+                aggC_prev.append(cInd_prev[c_prev])
+                c_prev += 1
+            else:
+                if c_this >= N_this:
+                    break
+                aggW_this += Ws_this[c_this]
+                aggC_this.append(cInd_this[c_this])
+                c_this += 1
+
+            balance = aggW_prev * gr - aggW_this
+
+            # if balance is found
+            if balance >= -tol and balance <= tol and aggW_prev != 0 and aggW_this != 0:
+                if len(aggC_prev) == 1 and len(aggC_this) == 1:
+                    this_lbls[aggC_this[0]] = prev_lbls[aggC_prev[0]]
+                elif len(aggC_prev) == 1 and len(aggC_this) == 2:
+                    this_lbls[aggC_this[0]] = prev_lbls[aggC_prev[0]]
+                    this_lbls[aggC_this[1]] = new_lbl
+                    mothers.append((new_lbl, prev_lbls[aggC_prev[0]]))
+                    new_lbl += 1
+                else:
+                    # In any other ambiguous cases, mark cells simply as new tracks
+                    for c in aggC_this:
+                        this_lbls[c] = new_lbl
+                        new_lbl += 1
+
+                aggW_prev, aggW_this = 0, 0
+                aggC_prev, aggC_this = [], []
+        
+        # Mark any remaining unbalanced/extra cells as new tracks
+        for c in aggC_this + cInd_this[list(range(c_this, N_this))].tolist():
+            this_lbls[c] = new_lbl
+            new_lbl += 1
+
+        this_lbls = this_lbls.tolist()
+
+        if not keep_full_state:
+            cell_lbls = cell_lbls[-self.nstepsback:]
+            prev_feats = prev_feats[-self.nstepsback:]
+
+        # Finally update the state
+        state = {
+            'new_lbl': new_lbl,
+            'cell_lbls': cell_lbls + [this_lbls],
+            'prev_feats': prev_feats + [feats],
+            'mothers': mothers
+        }
+
+        output = {
+            'cell_label': this_lbls,
+            'state': state
+        }
+
+        if assign_mothers:
+            ma = np.zeros(new_lbl - 1, dtype='int64')
+            for d, m in mothers:
+                ma[d - 1] = m
+
+            if np.any(ma == np.arange(1, len(ma) + 1)):
+                raise BadOutput('Daughter has been assigned as mother to itself')
+
+            output['mother_assign'] = ma.tolist()
+
+        if return_baprobs:
+            ba_probs = np.zeros((N_this, N_this))
+            output['p_bud_assign'] = ba_probs.tolist()
+
+        return output
+
+
 # Helper functions
 
 def switch_case_nfeats(nfeats):
diff --git a/python/baby/tracker/training.py b/python/baby/tracker/training.py
index 904a2e99a3f540e5147a99222fc59062c5f0de35..ddea123d8b94e4127e270cac3907a5d806e2c3aa 100644
--- a/python/baby/tracker/training.py
+++ b/python/baby/tracker/training.py
@@ -2,13 +2,15 @@
 
 # If you publish results that make use of this software or the Birth Annotator
 # for Budding Yeast algorithm, please cite:
-# Julian M J Pietsch, Alán Muñoz, Diane Adjavon, Ivan B N Clark, Peter S
-# Swain, 2021, Birth Annotator for Budding Yeast (in preparation).
+# Pietsch, J.M.J., Muñoz, A.F., Adjavon, D.-Y.A., Farquhar, I., Clark, I.B.N.,
+# and Swain, P.S. (2023). Determining growth rates from bright-field images of
+# budding cells through identifying overlaps. eLife. 12:e79812.
+# https://doi.org/10.7554/eLife.79812
 # 
 # 
 # The MIT License (MIT)
 # 
-# Copyright (c) Julian Pietsch, Alán Muñoz and Diane Adjavon 2021
+# Copyright (c) Julian Pietsch, Alán Muñoz and Diane Adjavon 2023
 # 
 # Permission is hereby granted, free of charge, to any person obtaining a copy
 # of this software and associated documentation files (the "Software"), to
@@ -42,7 +44,7 @@ from baby.io import load_tiled_image
 from baby.tracker.utils import pick_baryfun, calc_barycentre
 from baby.training.utils import TrainValProperty, TrainValTestProperty
 from baby.errors import BadProcess, BadParam
-from .core import CellTracker, BudTracker
+from .core import CellTracker
 
 from .benchmark import CellBenchmarker
 
@@ -55,8 +57,14 @@ from sklearn.svm import SVC#, LinearSVC
 from sklearn.model_selection import GridSearchCV
 from sklearn.metrics import (
     make_scorer, fbeta_score, accuracy_score, balanced_accuracy_score,
-    precision_score, recall_score, f1_score, plot_precision_recall_curve
+    precision_score, recall_score, f1_score
 )
+import sklearn
+if int(sklearn.__version__[0]) > 0:
+    from sklearn.metrics import PrecisionRecallDisplay
+    plot_precision_recall_curve = PrecisionRecallDisplay.from_estimator
+else:
+    from sklearn.metrics import plot_precision_recall_curve
 
 class CellTrainer(CellTracker):
     '''
@@ -252,7 +260,7 @@ class CellTrainer(CellTracker):
     def explore_hyperparams(self, model_type = 'rf'):
         self.model_type = model_type
         truth, data = *zip(*self.train),
-        if self.model_type is 'SVC':
+        if self.model_type == 'SVC':
             model = SVC(probability = True, shrinking=False,
                         verbose=True, random_state=1)
             param_grid = {
@@ -263,7 +271,7 @@ class CellTrainer(CellTracker):
               'gamma': [1, 0.01, 0.0001],
               'kernel': ['rbf', 'sigmoid']
             }
-        elif model_type is 'rf':
+        elif model_type == 'rf':
             model = RandomForestClassifier(n_estimators=15,
                                     criterion='gini',
                                     max_depth=3,
@@ -319,208 +327,6 @@ class CellTrainer(CellTracker):
 
 
 
-class BudTrainer(BudTracker):
-    '''
-    :props_file: File where generated property table will be saved
-    :kwargs: Additional arguments passed onto the parent Tracker; `px_size` is
-        especially useful.
-    '''
-
-    def __init__(self, props_file=None, **kwargs):
-        super().__init__(**kwargs)
-        # NB: we inherit self.feats2use from CellTracker class
-        self.props_file = props_file
-        self.rf_feats = ["p_bud_mat", "size_ratio_mat", "p_budneck_mat",
-                "budneck_ratio_mat", "adjacency_mat"]
-
-    @property
-    def props_file(self):
-        return getattr(self, '_props_file')
-
-    @props_file.setter
-    def props_file(self, filename):
-        if filename is not None:
-            self._props_file = Path(filename)
-
-    @property
-    def props(self):
-        if getattr(self, '_props', None) is None:
-            if self.props_file and self.props_file.is_file():
-                self.props = pd.read_csv(self.props_file)
-            else:
-                raise BadProcess(
-                        'The property table has not yet been generated')
-        return self._props
-
-    @props.setter
-    def props(self, props):
-        props = pd.DataFrame(props)
-        required_cols = self.rf_feats + ['is_mb_pair', 'validation']
-        if not all(c in props for c in required_cols):
-            raise BadParam(
-                '"props" does not have all required columns: {}'.format(
-                    ', '.join(required_cols)))
-        self._props = props
-        if self.props_file:
-            props.to_csv(self.props_file)
-
-    def generate_property_table(self, data, flattener, val_data=None):
-        '''Generates properties table that gets used for training
-
-        :data: List or generator of `baby.training.SegExample` tuples
-        :flattener: Instance of a `baby.preprocessing.SegmentationFlattening`
-            object describing the targets of the CNN in data
-        '''
-        tnames = flattener.names()
-        i_budneck = tnames.index('bud_neck')
-        bud_target = 'sml_fill' if 'sml_fill' in tnames else 'sml_inte'
-        i_bud = tnames.index(bud_target)
-
-        if val_data is not None:
-            data = TrainValProperty(data, val_data)
-        if isinstance(data, (TrainValProperty, TrainValTestProperty)):
-            data = chain(zip(repeat(False), data.train),
-                         zip(repeat(True), data.val))
-        else:
-            data = zip(repeat(None), data)
-
-        p_list = []
-        for is_val, seg_example in data:
-            if len(seg_example.target) < 2:
-                # Skip if no pairs are present
-                continue
-            mb_stats = self.calc_mother_bud_stats(seg_example.pred[i_budneck],
-                    seg_example.pred[i_bud], seg_example.target)
-            p = pd.DataFrame(mb_stats, columns=self.rf_feats)
-            p['validation'] = is_val
-
-            # "cellLabels" specifies the label for each mask
-            cell_labels = seg_example.info.get('cellLabels', []) or []
-            if type(cell_labels) is int:
-                cell_labels = [cell_labels]
-            # "buds" specifies the label of the bud for each mask
-            buds = seg_example.info.get('buds', []) or []
-            if type(buds) is int:
-                buds = [buds]
-
-            # Build a ground truth matrix identifying mother-bud pairs
-            ncells = len(seg_example.target)
-            is_mb_pair = np.zeros((ncells, ncells), dtype=bool)
-            mb_inds = [
-                (i, cell_labels.index(b))
-                for i, b in enumerate(buds)
-                if b > 0 and b in cell_labels
-            ]
-            if len(mb_inds) > 0:
-                mother_inds, bud_inds = zip(*mb_inds)
-                is_mb_pair[mother_inds, bud_inds] = True
-            p['is_mb_pair'] = is_mb_pair.flatten()
-
-            # Ignore any rows containing NaNs
-            nanrows = np.isnan(mb_stats).any(axis=1)
-            if (p['is_mb_pair'] & nanrows).any():
-                id_keys = ('experimentID', 'position', 'trap', 'tp')
-                info = seg_example.info
-                img_id = ' / '.join(
-                        [k + ': ' + str(info[k]) for k in id_keys if k in info])
-                warn('Mother-bud pairs omitted due to feature NaNs')
-                print('Mother-bud pair omitted in "{}"'.format(img_id))
-            p = p.loc[~nanrows, :]
-            p_list.append(p)
-
-        props = pd.concat(p_list, ignore_index=True)
-        # TODO: should search for any None values in validation column and
-        # assign a train-validation split to those rows
-
-        self.props = props # also saves
-
-    def explore_hyperparams(self, hyper_param_target='precision'):
-        # Train bud assignment model on validation data, since this more
-        # closely represents real-world performance of the CNN:
-        data = self.props.loc[self.props['validation'], self.rf_feats]
-        truth = self.props.loc[self.props['validation'], 'is_mb_pair']
-
-        rf = RandomForestClassifier(n_estimators=15,
-                                    criterion='gini',
-                                    max_depth=3,
-                                    class_weight='balanced')
-
-        param_grid = {
-            'n_estimators': [6, 15, 50, 100],
-            'max_features': ['auto', 'sqrt', 'log2'],
-            'max_depth': [2, 3, 4],
-            'class_weight': [None, 'balanced', 'balanced_subsample']
-        }
-
-        def get_balanced_best_index(cv_results_):
-            '''Find a model balancing F1 score and speed'''
-            df = pd.DataFrame(cv_results_)
-            best_score = df.iloc[df.mean_test_f1.idxmax(), :]
-            thresh = best_score.mean_test_f1 - 0.1 * best_score.std_test_f1
-            return df.loc[df.mean_test_f1 > thresh, 'mean_score_time'].idxmin()
-
-        self._rf = GridSearchCV(estimator=rf, param_grid=param_grid, cv=5,
-                scoring=SCORING_METRICS, refit=hyper_param_target)
-        self._rf.fit(data, truth)
-
-        df = pd.DataFrame(self._rf.cv_results_)
-        disp_cols = [c for c in df.columns if c.startswith('mean_')
-                     or c.startswith('param_')]
-        print(df.loc[self._rf.best_index_, disp_cols])
-
-    def performance(self):
-        if not isinstance(getattr(self, '_rf', None), GridSearchCV):
-            raise BadProcess('"explore_hyperparams" has not been run')
-
-        best_rf = self._rf.best_estimator_
-        isval = self.props['validation']
-        data = self.props.loc[~isval, self.rf_feats]
-        truth = self.props.loc[~isval, 'is_mb_pair']
-        valdata = self.props.loc[isval, self.rf_feats]
-        valtruth = self.props.loc[isval, 'is_mb_pair']
-        metrics = tuple(SCORING_METRICS.values())
-        return TrainValProperty(
-                Score(*(m(best_rf, data, truth) for m in metrics)),
-                Score(*(m(best_rf, valdata, valtruth) for m in metrics)))
-
-    def plot_PR(self):
-        best_rf = self._rf.best_estimator_
-        isval = self.props['validation']
-        valdata = self.props.loc[isval, self.rf_feats]
-        valtruth = self.props.loc[isval, 'is_mb_pair']
-        plot_precision_recall_curve(best_rf, valdata, valtruth)
-
-    def save_model(self, filename):
-        f = open(filename, 'wb')
-        pickle.dump(self._rf.best_estimator_, f)
-
-
-SCORING_METRICS = {
-    'accuracy': make_scorer(accuracy_score),
-    'balanced_accuracy': make_scorer(balanced_accuracy_score),
-    'precision': make_scorer(precision_score),
-    'recall': make_scorer(recall_score),
-    'f1': make_scorer(f1_score),
-    'f0_5': make_scorer(fbeta_score, beta=0.5),
-    'f2': make_scorer(fbeta_score, beta=2)
-}
-
-
-class Score(NamedTuple):
-    accuracy: float
-    balanced_accuracy: float
-    precision: float
-    recall: float
-    F1: float
-    F0_5: float
-    F2: float
-
-    def __str__(self):
-        return 'Score({})'.format(', '.join([
-            '{}={:.3f}'.format(k, v) for k, v in self._asdict().items()
-            ]))
-
-
 
 def get_ground_truth(cell_labels, buds):
     ncells = len(cell_labels)
diff --git a/python/baby/tracker/utils.py b/python/baby/tracker/utils.py
index 12c1dfc9c45f67d43c13a1156a00fe5c056f90aa..17b5c06670e9f502e519f5983b2f009b3eb1ca87 100644
--- a/python/baby/tracker/utils.py
+++ b/python/baby/tracker/utils.py
@@ -1,12 +1,14 @@
 # If you publish results that make use of this software or the Birth Annotator
 # for Budding Yeast algorithm, please cite:
-# Julian M J Pietsch, Alán Muñoz, Diane Adjavon, Ivan B N Clark, Peter S
-# Swain, 2021, Birth Annotator for Budding Yeast (in preparation).
+# Pietsch, J.M.J., Muñoz, A.F., Adjavon, D.-Y.A., Farquhar, I., Clark, I.B.N.,
+# and Swain, P.S. (2023). Determining growth rates from bright-field images of
+# budding cells through identifying overlaps. eLife. 12:e79812.
+# https://doi.org/10.7554/eLife.79812
 # 
 # 
 # The MIT License (MIT)
 # 
-# Copyright (c) Julian Pietsch, Alán Muñoz and Diane Adjavon 2021
+# Copyright (c) Julian Pietsch, Alán Muñoz and Diane Adjavon 2023
 # 
 # Permission is hereby granted, free of charge, to any person obtaining a copy
 # of this software and associated documentation files (the "Software"), to
diff --git a/python/baby/training/__init__.py b/python/baby/training/__init__.py
index 09abf58649a4f6bfc38987e512e62f4332cfb07d..83aa435af93bfe80874eca5ae47c9566eff5063b 100644
--- a/python/baby/training/__init__.py
+++ b/python/baby/training/__init__.py
@@ -1,12 +1,14 @@
 # If you publish results that make use of this software or the Birth Annotator
 # for Budding Yeast algorithm, please cite:
-# Julian M J Pietsch, Alán Muñoz, Diane Adjavon, Ivan B N Clark, Peter S
-# Swain, 2021, Birth Annotator for Budding Yeast (in preparation).
+# Pietsch, J.M.J., Muñoz, A.F., Adjavon, D.-Y.A., Farquhar, I., Clark, I.B.N.,
+# and Swain, P.S. (2023). Determining growth rates from bright-field images of
+# budding cells through identifying overlaps. eLife. 12:e79812.
+# https://doi.org/10.7554/eLife.79812
 # 
 # 
 # The MIT License (MIT)
 # 
-# Copyright (c) Julian Pietsch, Alán Muñoz and Diane Adjavon 2021
+# Copyright (c) Julian Pietsch, Alán Muñoz and Diane Adjavon 2023
 # 
 # Permission is hereby granted, free of charge, to any person obtaining a copy
 # of this software and associated documentation files (the "Software"), to
@@ -35,7 +37,7 @@ It includes the following trainers
 * `HyperParameterTrainer`: CNN hyper-parameters
 * `CNNTrainer`: CNN using gradient descent to optimize for a given loss
 * `SegmentationTrainer`: hyper-parameters for post-processing of CNN Output
-into cell instances and attributes
+    into cell instances and attributes
 
 Given the appropriate inputs, each of these can be trained separately.
 This is useful for fine-tuning or re-training parts separately.
@@ -43,8 +45,10 @@ This is useful for fine-tuning or re-training parts separately.
 For training the entire framework at once, it is recommended to use the
 `BabyTrainer` class, which is also aliased as `Nursery`.
 """
+from .utils import fix_tf_rtx_gpu_bug
 from .smoothing_model_trainer import SmoothingModelTrainer
 from .flattener_trainer import FlattenerTrainer
+from .segmentation_trainer import SegmentationTrainer
 import tensorflow as tf
 if tf.__version__.startswith('1'):
     from .v1_hyper_parameter_trainer import HyperParamV1 \
@@ -52,7 +56,4 @@ if tf.__version__.startswith('1'):
 else:
     from .hyper_parameter_trainer import HyperParameterTrainer
 from .cnn_trainer import CNNTrainer
-
-from .training import *
-from .utils import fix_tf_rtx_gpu_bug
-
+from baby.training.training import BabyTrainer, Nursery
diff --git a/python/baby/training/bud_trainer.py b/python/baby/training/bud_trainer.py
new file mode 100644
index 0000000000000000000000000000000000000000..f0a60d6952ffa6122346e9caa34cc48882ca2824
--- /dev/null
+++ b/python/baby/training/bud_trainer.py
@@ -0,0 +1,332 @@
+# If you publish results that make use of this software or the Birth Annotator
+# for Budding Yeast algorithm, please cite:
+# Pietsch, J.M.J., Muñoz, A.F., Adjavon, D.-Y.A., Farquhar, I., Clark, I.B.N.,
+# and Swain, P.S. (2023). Determining growth rates from bright-field images of
+# budding cells through identifying overlaps. eLife. 12:e79812.
+# https://doi.org/10.7554/eLife.79812
+# 
+# 
+# The MIT License (MIT)
+# 
+# Copyright (c) Julian Pietsch, Alán Muñoz and Diane Adjavon 2023
+# 
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to
+# deal in the Software without restriction, including without limitation the
+# rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
+# sell copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+# 
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+# 
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
+# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS
+# IN THE SOFTWARE.
+
+from pathlib import Path
+from itertools import repeat, chain
+from warnings import warn
+import pickle
+import numpy as np
+import pandas as pd
+from typing import NamedTuple
+
+from sklearn.ensemble import RandomForestClassifier
+from sklearn.model_selection import GridSearchCV
+from sklearn.metrics import (
+    make_scorer, fbeta_score, accuracy_score, balanced_accuracy_score,
+    precision_score, recall_score, f1_score
+)
+import sklearn
+if int(sklearn.__version__[0]) > 0:
+    from sklearn.metrics import PrecisionRecallDisplay
+    plot_precision_recall_curve = PrecisionRecallDisplay.from_estimator
+else:
+    from sklearn.metrics import plot_precision_recall_curve
+
+from baby.errors import BadProcess, BadParam
+from baby.tracker.core import BudTracker
+
+from .utils import (SharedParameterContainer, SharedDataContainer,
+                    TrainValTestProperty, TrainValProperty,
+                    standard_augmenter)
+from .segmentation_trainer import SegmentationTrainer
+
+
+SCORING_METRICS = {
+    'accuracy': make_scorer(accuracy_score),
+    'balanced_accuracy': make_scorer(balanced_accuracy_score),
+    'precision': make_scorer(precision_score),
+    'recall': make_scorer(recall_score),
+    'f1': make_scorer(f1_score),
+    'f0_5': make_scorer(fbeta_score, beta=0.5),
+    'f2': make_scorer(fbeta_score, beta=2)
+}
+
+
+class Score(NamedTuple):
+    accuracy: float
+    balanced_accuracy: float
+    precision: float
+    recall: float
+    F1: float
+    F0_5: float
+    F2: float
+
+    def __str__(self):
+        return 'Score({})'.format(', '.join([
+            '{}={:.3f}'.format(k, v) for k, v in self._asdict().items()
+            ]))
+
+
+class BudTrainer(BudTracker):
+    """Coordinates training for the mother-bud assignment model
+
+    Args:
+        kwargs: Additional arguments passed onto the parent Tracker.
+        shared_params: Training and segmentation parameters as provided by
+            :py:class:`utils.SharedParameterContainer`.
+        shared_data: Training data as provided by
+            :py:class:`utils.SharedDataContainer`.
+        seg_trainer: Segmentation trainer supplying examples and flattener.
+    """
+
+    def __init__(self,
+                 shared_params: SharedParameterContainer,
+                 shared_data: SharedDataContainer,
+                 seg_trainer: SegmentationTrainer,
+                 **kwargs):
+
+        kwargs.setdefault('px_size',
+                          shared_params.parameters.target_pixel_size)
+        kwargs.setdefault('model', False)
+        self._shared_params = shared_params
+        super().__init__(**kwargs)
+
+        self._shared_data = shared_data
+        self._seg_trainer = seg_trainer
+
+        # NB: we inherit self.feats2use from CellTracker class
+        self.rf_feats = ["p_bud_mat", "size_ratio_mat", "p_budneck_mat",
+                "budneck_ratio_mat", "adjacency_mat"]
+
+        self._model = None
+
+    @property
+    def px_size(self):
+        return self._shared_params.parameters.target_pixel_size
+
+    @px_size.setter
+    def px_size(self, val):
+        if val != self._shared_params.parameters.target_pixel_size:
+            raise BadParam('px_size should be set via the '
+                           '`target_pixel_size` parameter in the '
+                           '`SharedParameterContainer`')
+
+    @property
+    def save_dir(self):
+        """Base directory in which to save trained models"""
+        return self._shared_params.save_dir
+
+    @property
+    def props_file(self):
+        return (self.save_dir /
+                self._shared_params.parameters.mother_bud_props_file)
+
+    @property
+    def props(self):
+        if getattr(self, '_props', None) is None:
+            if self.props_file and self.props_file.is_file():
+                self.props = pd.read_csv(self.props_file)
+            else:
+                raise BadProcess(
+                        'The property table has not yet been generated')
+        return self._props
+
+    @props.setter
+    def props(self, props):  # validates the table, caches it and saves it
+        props = pd.DataFrame(props)
+        required_cols = self.rf_feats + ['is_mb_pair', 'validation']
+        if not all(c in props for c in required_cols):
+            raise BadParam(
+                '"props" does not have all required columns: {}'.format(
+                    ', '.join(required_cols)))
+        self._props = props
+        props.to_csv(self.props_file, index=False)  # no index: clean reload
+
+    @property
+    def model_save_file(self):
+        return (self.save_dir /
+                self._shared_params.parameters.mother_bud_model_file)
+
+    @property
+    def model(self):  # lazily resolved: grid-search result, else saved file
+        if self._model is None:
+            if isinstance(getattr(self, '_rf', None), GridSearchCV):
+                self._model = self._rf.best_estimator_
+            elif self.model_save_file.is_file():  # pathlib API is is_file()
+                with open(self.model_save_file, 'rb') as f:
+                    self._model = pickle.load(f)  # only load trusted files
+            else:
+                raise BadProcess('"explore_hyperparams" has not been run')
+        return self._model
+
+    @model.setter
+    def model(self, val):
+        if val:
+            if not isinstance(val, RandomForestClassifier):
+                raise BadParam('model must be a RandomForestClassifier')
+            self._model = val
+        else:
+            self._model = None
+
+    @property
+    def feature_importance(self):
+        return dict(zip(self.rf_feats, self.model.feature_importances_))
+
+    def generate_property_table(self):
+        """Generate table of properties to be used for training
+        """
+        segtrainer = self._seg_trainer
+        flattener = segtrainer.flattener
+        data = chain(zip(repeat('train'), segtrainer.examples.train),
+                     zip(repeat('val'), segtrainer.examples.val),
+                     zip(repeat('test'), segtrainer.examples.test))
+
+        tnames = flattener.names()
+        i_budneck = tnames.index('bud_neck')
+        bud_target = 'sml_fill' if 'sml_fill' in tnames else 'sml_inte'
+        i_bud = tnames.index(bud_target)
+
+        p_list = []
+        for train_split, seg_example in data:
+            if len(seg_example.target) < 2:
+                # Skip if no pairs are present
+                continue
+            mb_stats = self.calc_mother_bud_stats(seg_example.pred[i_budneck],
+                    seg_example.pred[i_bud], seg_example.target)
+            p = pd.DataFrame(mb_stats, columns=self.rf_feats)
+            p['validation'] = train_split == 'val'
+            p['testing'] = train_split == 'test'
+
+            # "cellLabels" specifies the label for each mask
+            cell_labels = seg_example.info.get('cellLabels', []) or []
+            if type(cell_labels) is int:
+                cell_labels = [cell_labels]
+            # "buds" specifies the label of the bud for each mask
+            buds = seg_example.info.get('buds', []) or []
+            if type(buds) is int:
+                buds = [buds]
+
+            # Build a ground truth matrix identifying mother-bud pairs
+            ncells = len(seg_example.target)
+            is_mb_pair = np.zeros((ncells, ncells), dtype=bool)
+            mb_inds = [
+                (i, cell_labels.index(b))
+                for i, b in enumerate(buds)
+                if b > 0 and b in cell_labels
+            ]
+            if len(mb_inds) > 0:
+                mother_inds, bud_inds = zip(*mb_inds)
+                is_mb_pair[mother_inds, bud_inds] = True
+            p['is_mb_pair'] = is_mb_pair.flatten()
+
+            # Ignore any rows containing NaNs
+            nanrows = np.isnan(mb_stats).any(axis=1)
+            if (p['is_mb_pair'] & nanrows).any():
+                id_keys = ('experimentID', 'position', 'trap', 'tp')
+                info = seg_example.info
+                img_id = ' / '.join(
+                        [k + ': ' + str(info[k]) for k in id_keys if k in info])
+                warn('Mother-bud pairs omitted due to feature NaNs')
+                print('Mother-bud pair omitted in "{}"'.format(img_id))
+            p = p.loc[~nanrows, :]
+            p_list.append(p)
+
+        props = pd.concat(p_list, ignore_index=True)
+        # TODO: should search for any None values in validation column and
+        # assign a train-validation split to those rows
+
+        self.props = props # also saves
+
+    def explore_hyperparams(self, hyper_param_target='precision'):
+        # Train bud assignment model on validation data, since this more
+        # closely represents real-world performance of the CNN:
+        data = self.props.loc[self.props['validation'], self.rf_feats]
+        truth = self.props.loc[self.props['validation'], 'is_mb_pair']
+
+        rf = RandomForestClassifier(n_estimators=15,
+                                    criterion='gini',
+                                    max_depth=3,
+                                    class_weight='balanced')
+
+        param_grid = {
+            'n_estimators': [6, 15, 50, 100],
+            'max_features': ['sqrt', 'log2'],  # 'auto' removed in sklearn 1.3
+            'max_depth': [2, 3, 4],
+            'class_weight': [None, 'balanced', 'balanced_subsample']
+        }
+
+        def get_balanced_best_index(cv_results_):  # NOTE: currently unused
+            """Find a model balancing F1 score and speed"""
+            df = pd.DataFrame(cv_results_)
+            best_score = df.iloc[df.mean_test_f1.idxmax(), :]
+            thresh = best_score.mean_test_f1 - 0.1 * best_score.std_test_f1
+            return df.loc[df.mean_test_f1 > thresh, 'mean_score_time'].idxmin()
+
+        self._model = None
+        self._rf = GridSearchCV(estimator=rf, param_grid=param_grid, cv=5,
+                scoring=SCORING_METRICS, refit=hyper_param_target)
+        self._rf.fit(data, truth)
+
+    def save_model(self, filename=None):
+        if filename is None:
+            filename = self.model_save_file
+        with open(filename, 'wb') as f:
+            pickle.dump(self.model, f)
+
+    def fit(self, **kwargs):
+        try:
+            self.props
+        except BadProcess:
+            self.generate_property_table()
+        self.explore_hyperparams(**kwargs)
+        self.save_model()
+
+    def performance(self):
+        isval = self.props['validation']
+        istest = self.props['testing']
+        data = self.props.loc[(~isval) & (~istest), self.rf_feats]
+        truth = self.props.loc[(~isval) & (~istest), 'is_mb_pair']
+        valdata = self.props.loc[isval, self.rf_feats]
+        valtruth = self.props.loc[isval, 'is_mb_pair']
+        testdata = self.props.loc[istest, self.rf_feats]
+        testtruth = self.props.loc[istest, 'is_mb_pair']
+        metrics = tuple(SCORING_METRICS.values())
+        return TrainValTestProperty(
+                Score(*(m(self.model, data, truth) for m in metrics)),
+                Score(*(m(self.model, valdata, valtruth) for m in metrics)),
+                Score(*(m(self.model, testdata, testtruth) for m in metrics)))
+
+    def grid_search_summary(self):
+        if not isinstance(getattr(self, '_rf', None), GridSearchCV):
+            raise BadProcess('"explore_hyperparams" has not been run')
+        df = pd.DataFrame(self._rf.cv_results_)
+        disp_cols = [c for c in df.columns if c.startswith('mean_')
+                     or c.startswith('param_')]
+        return df.loc[self._rf.best_index_, disp_cols]
+
+    def plot_PR(self):
+        if not isinstance(getattr(self, '_rf', None), GridSearchCV):
+            raise BadProcess('"explore_hyperparams" has not been run')
+        best_rf = self._rf.best_estimator_
+        isval = self.props['validation']
+        valdata = self.props.loc[isval, self.rf_feats]
+        valtruth = self.props.loc[isval, 'is_mb_pair']
+        plot_precision_recall_curve(best_rf, valdata, valtruth)
+
diff --git a/python/baby/training/cnn_trainer.py b/python/baby/training/cnn_trainer.py
index 2807b69be04b387dfea9fbad932e6c5bacccc80a..91ba33e80427505a4c86d2f743af856f63010b1d 100644
--- a/python/baby/training/cnn_trainer.py
+++ b/python/baby/training/cnn_trainer.py
@@ -1,12 +1,14 @@
 # If you publish results that make use of this software or the Birth Annotator
 # for Budding Yeast algorithm, please cite:
-# Julian M J Pietsch, Alán Muñoz, Diane Adjavon, Ivan B N Clark, Peter S
-# Swain, 2021, Birth Annotator for Budding Yeast (in preparation).
+# Pietsch, J.M.J., Muñoz, A.F., Adjavon, D.-Y.A., Farquhar, I., Clark, I.B.N.,
+# and Swain, P.S. (2023). Determining growth rates from bright-field images of
+# budding cells through identifying overlaps. eLife. 12:e79812.
+# https://doi.org/10.7554/eLife.79812
 # 
 # 
 # The MIT License (MIT)
 # 
-# Copyright (c) Julian Pietsch, Alán Muñoz and Diane Adjavon 2021
+# Copyright (c) Julian Pietsch, Alán Muñoz and Diane Adjavon 2023
 # 
 # Permission is hereby granted, free of charge, to any person obtaining a copy
 # of this software and associated documentation files (the "Software"), to
@@ -25,26 +27,28 @@
 # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
 # FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS
 # IN THE SOFTWARE.
+from typing import List, Tuple, Union
+from types import MappingProxyType
 import json
 import pathlib
-from typing import List, Tuple
-
-import matplotlib.pyplot as plt
 import pickle
+import matplotlib.pyplot as plt
+from matplotlib.colors import to_rgba
+from scipy.signal import savgol_filter
 import tensorflow as tf
+from tensorflow.keras.callbacks import (ModelCheckpoint, TensorBoard,
+                                        LearningRateScheduler, CSVLogger)
+from tensorflow.keras.models import load_model
+
 from baby.augmentation import Augmenter
 from baby.generator import ImageLabel
 from baby.preprocessing import SegmentationFlattening
-from matplotlib.colors import to_rgba
-from scipy.signal import savgol_filter
-from tensorflow.python.keras.callbacks import ModelCheckpoint, TensorBoard, \
-    LearningRateScheduler
-from tensorflow.python.keras.models import load_model
-
 from baby import models
 from baby.errors import BadType, BadProcess
 from baby.losses import bce_dice_loss, dice_coeff
 from baby.utils import get_name, schedule_steps
+from .utils import SharedParameterContainer
+from .flattener_trainer import FlattenerTrainer
 
 custom_objects = {'bce_dice_loss': bce_dice_loss, 'dice_coeff': dice_coeff}
 
@@ -52,45 +56,73 @@ OPT_WEIGHTS_FILE = 'weights.h5'
 INIT_WEIGHTS_FILE = 'init_weights.h5'
 FINAL_WEIGHTS_FILE = 'final_weights.h5'
 HISTORY_FILE = 'history.pkl'
+CSV_HISTORY_FILE = 'history.csv'
 LOG_DIR = 'logs'
+HYPER_PARAMS_FILE = 'hyperparameters.json'
 
 
 class CNNTrainer:
+    """Methods for optimising CNN weights using gradient descent.
+
+    After training, the optimised weights are stored in a file named
+    :py:data:`OPT_WEIGHTS_FILE` within the :py:attr:`cnn_dir` directory. The
+    directory is specific to each of the architectures listed in ``cnn_set``.
+
+    Args:
+        shared_params: Training and segmentation parameters as provided by
+            :py:class:`utils.SharedParameterContainer`.
+        flattener_trainer: Trainer that defines optimised targets for the CNN.
+        max_cnns: The maximum number of CNNs to keep in memory, default is 3.
+        cnn_fn: The CNN architecture to start with. Defaults to the
+            :py:attr:`utils.BabyTrainingParameters.cnn_fn` architecture
+            as found in ``shared_params.parameters``.
     """
-    Methods for optimising the weights of the CNNs using gradient descent.
-    """
-    def __init__(self, save_dir: pathlib.Path,
-                 cnn_set: tuple, gen: ImageLabel, aug: Augmenter,
-                 flattener: SegmentationFlattening,
+    def __init__(self,
+                 shared_params: SharedParameterContainer,
+                 flattener_trainer: FlattenerTrainer,
                  max_cnns: int = 3,
                  cnn_fn: str = None):
-        """
-        Methods for optimising the weights of the CNNs using gradient descent.
-
-        :param save_dir: base directory in which to save weights and outputs
-        :param cnn_set: the names of CNN architectures to be trained
-        :param gen: the data generator
-        :param aug: the data augmentor
-        :param flattener: the data flattener
-        :param max_cnns: the maximum number of CNNs to train/keep, default is 3
-        :param cnn_fn: the CNN architecture to start with, defaults to None
-        """
-        self.flattener = flattener
-        self.aug = aug
-        self.gen = gen
+        self._shared_params = shared_params
+        self._flattener_trainer = flattener_trainer
+
         # TODO no longer needed with hyperparmeter optim
         self._max_cnns = max_cnns
-        self.save_dir = save_dir
-        self.cnn_set = cnn_set
         self._cnn_fn = cnn_fn
         self._cnns = dict()
         self._opt_cnn = None
 
+    @property
+    def save_dir(self):
+        """Base directory in which to save trained models"""
+        return self._shared_params.save_dir
+
+    @property
+    def cnn_set(self):
+        return self._shared_params.parameters.cnn_set
+
+    @property
+    def gen(self):
+        """Data generators used for training models.
+
+        A :py:class:`utils.TrainValTestProperty` of training, validation and
+        test data generators with augmentations obtained from
+        :py:attr:`FlattenerTrainer.default_gen`.
+        """
+        return self._flattener_trainer.default_gen
+
+    @property
+    def flattener(self):
+        """Target definitions for models trained by this class."""
+        return self._flattener_trainer.flattener
+
     @property
     def cnn_fn(self):
         """The current CNN architecture function."""
         if self._cnn_fn is None:
-            self.cnn_fn = self.cnn_set[0]
+            if self._shared_params.parameters.cnn_fn is None:
+                self.cnn_fn = self.cnn_set[0]
+            else:
+                self.cnn_fn = self._shared_params.parameters.cnn_fn
         return getattr(models, self._cnn_fn)
 
     @cnn_fn.setter
@@ -100,6 +132,7 @@ class CNNTrainer:
         if not hasattr(models, fn):
             raise BadType('That is not a recognised model')
         self._cnn_fn = fn
+        self._hyperparameters = None
 
     @property
     def cnn_dir(self):
@@ -114,39 +147,70 @@ class CNNTrainer:
         """The name of the currently active CNN architecture."""
         return get_name(self.cnn_fn)
 
+    @property
+    def hyperparameters(self):
+        """Custom hyperparameters defined for the active CNN architecture.
+
+        A ``dict`` specifying keyword arguments to be passed when building the
+        active CNN architecture. If left unset, any defaults as given in
+        :py:mod:`baby.models` will be used.
+
+        NB: The property is returned as a :py:class:`types.MappingProxyType`,
+        so ``dict`` items cannot be modified. To change hyperparameters, the
+        whole ``dict`` needs to be replaced.
+        """
+        if not getattr(self, '_hyperparameters', None):
+            # Load hyperparameters saved alongside the active CNN, if any
+            hyper_param_file = self.cnn_dir / HYPER_PARAMS_FILE
+            if hyper_param_file.exists():
+                with open(hyper_param_file, 'r') as fd:
+                    self._hyperparameters = json.load(fd)
+            else:
+                # Use defaults specified in `models`
+                self._hyperparameters = {}
+        return MappingProxyType(self._hyperparameters)
+
+    @hyperparameters.setter
+    def hyperparameters(self, params):  # saves params, drops stale model
+        if not isinstance(params, (dict, MappingProxyType)):
+            raise BadType('Hyperparameters must be specified as a `dict`')
+        hyper_param_file = self.cnn_dir / HYPER_PARAMS_FILE
+        with open(hyper_param_file, 'w') as f:
+            json.dump(dict(params), f)
+        self._hyperparameters = dict(params)
+        # If the hyperparameters have changed, we need to regenerate the model
+        # and initial weights
+        if self.cnn_name in self._cnns:
+            del self._cnns[self.cnn_name]
+        # Delete initial weights if they have already been saved
+        init_weights_file = self.cnn_dir / INIT_WEIGHTS_FILE
+        init_weights_file.unlink(missing_ok=True)
+
     @property
     def cnn(self):
         """The keras Model for the active CNN."""
         if self.cnn_name not in self._cnns:
-            if len(self._cnns) > self._max_cnns:
+            n_loaded = getattr(self, '_n_cnns_loaded', 0)
+            if n_loaded > self._max_cnns:
                 # To avoid over-consuming memory reset graph
                 # TODO: ensure TF1/TF2 compat and check RTX bug
                 tf.keras.backend.clear_session()
                 # Reset any potentially loaded models
                 self._cnns = dict()
                 self._opt_cnn = None
-            # Todo: separate generator from trainer
-            #   Make model accept an input shape and a set of outputs
-            self.gen.train.aug = self.aug.train
+                n_loaded = 0
+
             print('Loading "{}" CNN...'.format(self.cnn_name))
-            # Load hyperparameters
-            hyper_param_file = self.cnn_dir / "hyperparameters.json"
-            if not hyper_param_file.exists():
-                # Todo: just use defaults
-                raise FileNotFoundError("Hyperparameter file {} for {} not "
-                                        "found.".format(hyper_param_file,
-                                                        self.cnn_name))
-            with open(hyper_param_file, 'r') as fd:
-                hyperparameters = json.load(fd)
             model = self.cnn_fn(self.gen.train, self.flattener,
-                                **hyperparameters)
+                                **self.hyperparameters)
             self._cnns[self.cnn_name] = model
+            self._n_cnns_loaded = n_loaded + 1
 
             # Save initial weights if they haven't already been saved
-            filename = self.cnn_dir / INIT_WEIGHTS_FILE
-            if not filename.exists():
+            init_weights_file = self.cnn_dir / INIT_WEIGHTS_FILE
+            if not init_weights_file.exists():
                 print('Saving initial weights...')
-                model.save_weights(str(filename))
+                model.save_weights(str(init_weights_file))
         return self._cnns[self.cnn_name]
 
     @property
@@ -195,7 +259,7 @@ class CNNTrainer:
         return self._opt_cnn
 
     def fit(self, epochs: int = 400,
-            schedule: List[Tuple[int, float]] = None,
+            schedule: Union[str, List[Tuple[int, float]]] = None,
             replace: bool = False,
             extend: bool = False):
         """Fit the active CNN to minimise loss on the (augmented) generator.
@@ -214,6 +278,11 @@ class CNNTrainer:
         if schedule is None:
             schedule = [(1e-3, epochs)]
 
+        if callable(schedule):
+            schedulefn = schedule
+        else:
+            schedulefn = lambda epoch: schedule_steps(epoch, schedule)
+
         finalfile = self.cnn_dir / FINAL_WEIGHTS_FILE
         if extend:
             self.cnn.load_weights(str(finalfile))
@@ -225,22 +294,28 @@ class CNNTrainer:
         if not replace and optfile.is_file():
             raise BadProcess('Optimised weights already exist')
 
+        csv_history_file = self.cnn_dir / CSV_HISTORY_FILE
         logdir = self.cnn_dir / LOG_DIR
         callbacks = [
             ModelCheckpoint(filepath=str(optfile),
                             monitor='val_loss',
                             save_best_only=True,
                             verbose=1),
+            CSVLogger(filename=str(csv_history_file), append=extend),
             TensorBoard(log_dir=str(logdir)),
-            LearningRateScheduler(
-                lambda epoch: schedule_steps(epoch, schedule))
+            LearningRateScheduler(schedulefn)
         ]
-        self.gen.train.aug = self.aug.train
-        self.gen.val.aug = self.aug.val
-        history = self.cnn.fit_generator(generator=self.gen.train,
-                                         validation_data=self.gen.val,
-                                         epochs=epochs,
-                                         callbacks=callbacks)
+
+        if tf.version.VERSION.startswith('1'):
+            history = self.cnn.fit_generator(generator=self.gen.train,
+                                             validation_data=self.gen.val,
+                                             epochs=epochs,
+                                             callbacks=callbacks)
+        else:
+            history = self.cnn.fit(self.gen.train,
+                                   validation_data=self.gen.val,
+                                   epochs=epochs,
+                                   callbacks=callbacks)
 
         # Save history
         with open(self.cnn_dir / HISTORY_FILE, 'wb') as f:
diff --git a/python/baby/training/flattener_trainer.py b/python/baby/training/flattener_trainer.py
index 36e47ff108e620e745721f884d32fd01b606638b..63dcf16d1451c5b760a863c05ef412767701fb1b 100644
--- a/python/baby/training/flattener_trainer.py
+++ b/python/baby/training/flattener_trainer.py
@@ -1,23 +1,25 @@
 # If you publish results that make use of this software or the Birth Annotator
 # for Budding Yeast algorithm, please cite:
-# Julian M J Pietsch, Alán Muñoz, Diane Adjavon, Ivan B N Clark, Peter S
-# Swain, 2021, Birth Annotator for Budding Yeast (in preparation).
-#
-#
+# Pietsch, J.M.J., Muñoz, A.F., Adjavon, D.-Y.A., Farquhar, I., Clark, I.B.N.,
+# and Swain, P.S. (2023). Determining growth rates from bright-field images of
+# budding cells through identifying overlaps. eLife. 12:e79812.
+# https://doi.org/10.7554/eLife.79812
+# 
+# 
 # The MIT License (MIT)
-#
-# Copyright (c) Julian Pietsch, Alán Muñoz and Diane Adjavon 2021
-#
+# 
+# Copyright (c) Julian Pietsch, Alán Muñoz and Diane Adjavon 2023
+# 
 # Permission is hereby granted, free of charge, to any person obtaining a copy
 # of this software and associated documentation files (the "Software"), to
 # deal in the Software without restriction, including without limitation the
 # rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
 # sell copies of the Software, and to permit persons to whom the Software is
 # furnished to do so, subject to the following conditions:
-#
+# 
 # The above copyright notice and this permission notice shall be included in
 # all copies or substantial portions of the Software.
-#
+# 
 # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
 # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
 # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
@@ -28,22 +30,67 @@
 """Optimising the hyper-parameters of the `SegmentationFlattener`"""
 import json
 import pathlib
+from itertools import chain
 
 import matplotlib.pyplot as plt
 import numpy as np
+from scipy.ndimage import binary_erosion
+
 from baby.augmentation import Augmenter
 from baby.errors import BadProcess, BadType, BadParam
 from baby.generator import ImageLabel
 from baby.preprocessing import dwsquareconn, SegmentationFlattening
 from baby.utils import find_file
-from scipy.ndimage import binary_erosion
+from baby.visualise import colour_segstack
+
+from .utils import TrainValProperty, standard_augmenter
+from .smoothing_model_trainer import SmoothingModelTrainer
+
 
-from .utils import augmented_generator, TrainValProperty
+def _generate_flattener_stats_inner(segs, nerode, keep_zero):
+    segs = segs > 0
+    # First remove any sections that may have been augmented away
+    # NB: the ScalingAugmenter does remove most of these at the rough_crop
+    # step, but there will still be many excluded after the second crop
+    segs = segs[..., segs.any(axis=(0,1))]
+    s_sizes = segs.sum(axis=(0,1)).tolist()
+    nsegs = segs.shape[2]
+    esizes = [[] for s in range(nsegs)]
+    overlap_sizes = [[] for _ in nerode]
+    for e in nerode:
+        for s0 in range(nsegs):
+            seg0 = segs[..., s0]
+            n0 = int(seg0.sum())
+            esizes[s0].append(n0)
+            if n0 == 0:
+                continue
+            for s1 in range(s0 + 1, nsegs):
+                seg1 = segs[..., s1]
+                # Calculate number of overlapping pixels
+                nO = np.sum(seg0 & seg1)
+                # Calculate fraction overlap
+                fO = float(nO / np.sum(seg0 | seg1))
+                if fO > 0 or keep_zero:
+                    sizes = tuple(sorted([s_sizes[s0], s_sizes[s1]]))
+                    if keep_zero:
+                        overlap_sizes[e].append(sizes + (fO, nO))
+                    else:
+                        overlap_sizes[e].append(sizes + (fO,))
+        segs = binary_erosion(segs, dwsquareconn)
+    return esizes, overlap_sizes
+
+
+def _batch_parallel_generator(gen, sample_inds, batch_size=8, n_jobs=4):
+    for i in range(np.ceil(len(sample_inds) / batch_size).astype('int')):
+        yield gen.parallel_get_indices(
+            sample_inds[batch_size*i:batch_size*(i+1)], n_jobs=n_jobs)
 
 
 def _generate_flattener_stats(gen: ImageLabel,
                               max_erode: int,
-                              keep_zero=False) -> dict:
+                              keep_zero=False,
+                              batch_size=16,
+                              n_jobs=4) -> dict:
     """ Generates flattener statistics of the data output by the generator.
 
     This function measures the size (in pixels) of segmentation mask
@@ -61,36 +108,24 @@ def _generate_flattener_stats(gen: ImageLabel,
     size after applying the indexed number of erosions.
     """
     nerode = list(range(max_erode + 1))
-    overlap_sizes = [[] for _ in nerode]
-    erosion_sizes = []
 
-    for t in range(len(gen.paths)):
-        _, segs = gen.get_by_index(t)
-        nsegs = segs.shape[2]
-        segs = segs > 0
-        s_sizes = [int(segs[..., s].sum()) for s in range(nsegs)]
-        esizes = [[] for s in range(nsegs)]
-        for e in nerode:
-            for s0 in range(nsegs):
-                seg0 = segs[..., s0]
-                n0 = int(seg0.sum())
-                esizes[s0].append(n0)
-                if n0 == 0:
-                    continue
-                for s1 in range(s0 + 1, nsegs):
-                    seg1 = segs[..., s1]
-                    # Calculate number of overlapping pixels
-                    nO = np.sum(seg0 & seg1)
-                    # Calculate fraction overlap
-                    fO = float(nO / np.sum(seg0 | seg1))
-                    if fO > 0 or keep_zero:
-                        sizes = tuple(sorted([s_sizes[s0], s_sizes[s1]]))
-                        if keep_zero:
-                            overlap_sizes[e].append(sizes + (fO, nO))
-                        else:
-                            overlap_sizes[e].append(sizes + (fO,))
-            segs = binary_erosion(segs, dwsquareconn)
-        erosion_sizes.extend(esizes)
+    sample_inds = np.repeat(np.arange(len(gen.nsamples)), gen.nsamples)
+    sample_gen = _batch_parallel_generator(gen, sample_inds, n_jobs=n_jobs)
+    n_batches = np.ceil(len(sample_inds) / batch_size).astype('int')
+    erosion_sizes = []
+    overlap_sizes = [[] for _ in range(len(nerode))]
+    from joblib import Parallel, delayed
+    from tqdm import trange
+    for i in trange(n_batches):
+        gen_batch = gen.parallel_get_indices(
+            sample_inds[batch_size*i:batch_size*(i+1)], n_jobs=n_jobs)
+        e_sizes, o_sizes = zip(*Parallel(n_jobs=n_jobs)(
+            delayed(_generate_flattener_stats_inner)(segs, nerode, keep_zero)
+            for _, segs in gen_batch))
+        erosion_sizes.extend(chain(*e_sizes))
+        o_sizes = [chain(*e) for e in zip(*o_sizes)]
+        for e, o in zip(o_sizes, overlap_sizes):
+            o.extend(e)
 
     return {'overlap_sizes': overlap_sizes, 'erosion_sizes': erosion_sizes}
 
@@ -128,7 +163,7 @@ def _group_overlapping(os, thresh, pad=0):
        ])
 
 
-def _best_overlapping(overlapping, erosion_sizes, min_size):
+def _best_overlapping(overlapping, erosion_sizes, min_size, min_size_frac):
     """Return overlap stats for highest level of erosion without losing cells
 
     Binary erosion is valid if it does not reduce the area of any cells below
@@ -145,8 +180,12 @@ def _best_overlapping(overlapping, erosion_sizes, min_size):
     """
     # Rearrange `erosion_sizes` by number of applied erosions
     sz_erosions = list(zip(*erosion_sizes))
-    # Erosions are invalid if any cells drop below the minimum allowed size
-    e_invalid = [any([c < min_size for c in e]) for e in sz_erosions[:0:-1]]
+    min_median_sz = np.median(sz_erosions[0]) * min_size_frac
+    # Erosions are invalid if any cells drop below the minimum allowed size,
+    # or if the median size drops below a fraction of the original median
+    e_invalid = [np.any(np.array(e) < min_size)
+                 or np.median(e) < min_median_sz
+                 for e in sz_erosions[:0:-1]]
     # Return only overlap stats for valid numbers of erosions
     o_valid = [o for o, e in zip(overlapping[:0:-1], e_invalid) if not e]
     o_valid += [overlapping[0]]
@@ -199,54 +238,74 @@ def _best_nerode(szg, min_size):
 
 
 class FlattenerTrainer:
+    """Optimises the hyper-parameters for the `SegmentationFlattener`.
 
-    def __init__(self, save_dir: pathlib.Path, stats_file: str,
-                 flattener_file: str):
-        """Optimises the hyper-parameters for the `SegmentationFlattener`.
-
+    #TODO describe method for optimisation
 
-
-        #TODO describe method for optimisation
-        :param save_dir: the base directory in which to save outputs
-        :param stats_file: the name of the file in which the stats are saved
-        :param flattener_file: the name of the file defining the flattener
-        """
-        self.save_dir = save_dir
-        self.stats_file = self.save_dir / stats_file
-        self.flattener_file = self.save_dir / flattener_file
+    Args:
+        shared_params (utils.SharedParameterContainer): Training and
+            segmentation parameters as provided by
+            :py:class:`utils.SharedParameterContainer`.
+        shared_data (utils.SharedDataContainer): Access to training data.
+        ssm_trainer (SmoothingModelTrainer): Trainer from which to obtain a
+            smoothing sigma model.
+    """
+    def __init__(self, shared_params, shared_data, ssm_trainer):
+        self._shared_params = shared_params
+        self._shared_data = shared_data
+        self._ssm_trainer = ssm_trainer
         self._flattener = None
         self._stats = None
 
-    def generate_flattener_stats(self,
-                                 train_gen: ImageLabel,
-                                 val_gen: ImageLabel,
-                                 train_aug: Augmenter,
-                                 val_aug: Augmenter,
-                                 max_erode: int = 5):
+    @property
+    def save_dir(self):
+        return self._shared_params.save_dir
+
+    @property
+    def stats_file(self):
+        """File in which to save derived data for training flattener"""
+        return (self.save_dir /
+                self._shared_params.parameters.flattener_stats_file)
+
+    @property
+    def flattener_file(self):
+        return (self.save_dir /
+                self._shared_params.parameters.flattener_file)
+
+    def generate_flattener_stats(self, max_erode: int = 5, n_jobs=None):
         """Generate overlap and erosion statistics for augmented data in input.
 
-        :param train_gen: the generator of training images and their labels
-        :param val_gen: the generator of validation images and their labels
-        :param train_aug: the augmenter to use for training images
-        :param val_aug: the augmenter to use for validation images
-        :param max_erode: the maximum allowed number of erosions used to
+        Saves results to file specified in :py:attr:`stats_file`.
+
+        Args:
+            max_erode: the maximum allowed number of erosions used to
         generate erosion values
         :return: None, saves results to `self.stats_file`
         """
-        with augmented_generator(train_gen, train_aug) as gen:
-            fs_train = _generate_flattener_stats(gen, max_erode)
-        with augmented_generator(val_gen, val_aug) as gen:
-            fs_val = _generate_flattener_stats(gen, max_erode)
+        dummy_flattener = lambda x, y: x
+        ssm = self._ssm_trainer.model
+        params = self._shared_params.parameters
+        # NB: use isval=True also for training aug since we do not need extra
+        # augmentations for calibrating the flattener
+        aug = standard_augmenter(ssm, dummy_flattener, params, isval=True)
+        train_gen, val_gen, _ = self._shared_data.gen_with_aug(aug)
+
+        if n_jobs is None:
+            n_jobs = params.n_jobs
+        fs_train = _generate_flattener_stats(train_gen, max_erode,
+                                             n_jobs=n_jobs)
+        fs_val = _generate_flattener_stats(val_gen, max_erode,
+                                           n_jobs=n_jobs)
         with open(self.stats_file, 'wt') as f:
             json.dump({'train': fs_train, 'val': fs_val}, f)
         self._stats = None  # trigger reload of property
 
     @property
     def stats(self) -> TrainValProperty:
-        """The last statistics computed, loaded from `self.stats_file`
+        """The last statistics computed, loaded from :py:attr:`stats_file`.
 
-        :return: The last training and validation statistics computed.
-        :raises: BadProcess error if the file does not exist.
+        Raises:
+            BadProcess: If the file does not exist.
         """
         if self._stats is None:
             if not self.stats_file.exists():
@@ -258,6 +317,42 @@ class FlattenerTrainer:
         return TrainValProperty(self._stats.get('train', {}),
                                 self._stats.get('val', {}))
 
+    def plot_stats(self, nbins=30, nrows=1, sqrt_area=False):
+        """Plot a histogram of cell overlap statistics of the training set.
+
+        # TODO describe what the plot means
+        # TODO add an image as an example
+
+        :param nbins: binning of data, passed to `matplotlib.pyplot.hist2d`
+        :return: None, saves the resulting figure under `self.save_dir /
+        "flattener_stats.png"`
+        """
+        overlapping = self.stats.train.get('overlap_sizes', []).copy()
+        if sqrt_area:
+            # Transform all areas by sqrt
+            for i, o in enumerate(overlapping):
+                x, y, f = zip(*o)
+                overlapping[i] = list(zip(np.sqrt(x), np.sqrt(y), f))
+        max_erode = len(overlapping)
+        ncols = int(np.ceil(max_erode / nrows))
+        fig, axs = plt.subplots(nrows, ncols, figsize=(4 * ncols, 4 * nrows))
+        axs = axs.flatten()
+        x, y, _ = zip(*overlapping[0])
+        max_size = max(x + y)
+        for ax, (e, os) in zip(axs, enumerate(overlapping)):
+            if len(os) > 0:
+                x, y, w = zip(*os)
+            else:
+                x, y, w = 3 * [[]]
+            ax.hist2d(x,
+                      y,
+                      bins=nbins,
+                      weights=w,
+                      range=[[0, max_size], [0, max_size]])
+            ax.plot((0, max_size), (0, max_size), 'r')
+            ax.set_title('nerosions = {:d}'.format(e))
+        fig.savefig(self.save_dir / 'flattener_stats.png')
+
     @property
     def flattener(self) -> SegmentationFlattening:
         """The last flattener saved to file.
@@ -286,7 +381,8 @@ class FlattenerTrainer:
         f.save(self.flattener_file)
         self._flattener = f
 
-    def fit(self, nbins=30, min_size=10, pad_frac=0.03, bud_max=200):
+    def fit(self, nbins=30, min_size=10, min_size_frac=0., pad_frac=0.03,
+            bud_max=200, sqrt_area=False, overlaps=None):
         """Optimise the parameters of the `SegmentationFlattener` based on
         previously computed statistics.
 
@@ -305,13 +401,21 @@ class FlattenerTrainer:
             raise BadParam('"pad_frac" must be between 0 and 0.2')
 
         # Load generated stats for training data
-        overlapping = self.stats.train.get('overlap_sizes', [])
-        erosion_sizes = self.stats.train.get('erosion_sizes', [])
+        overlapping = self.stats.train.get('overlap_sizes', []).copy()
+        erosion_sizes = self.stats.train.get('erosion_sizes', []).copy()
         if len(overlapping) == 0 or len(erosion_sizes) == 0 or \
                 len(list(zip(*erosion_sizes))) != len(overlapping):
             raise BadProcess(
                 '"flattener_stats.json" file appears to be corrupted')
 
+        if sqrt_area:
+            # Transform all areas by sqrt
+            min_size = np.sqrt(min_size)
+            erosion_sizes = np.sqrt(erosion_sizes).tolist()
+            for i, o in enumerate(overlapping):
+                x, y, f = zip(*o)
+                overlapping[i] = list(zip(np.sqrt(x), np.sqrt(y), f))
+
         # Find the best single split point by brute force iteration over a
         # binned version of the training data
 
@@ -323,7 +427,8 @@ class FlattenerTrainer:
         edges = np.linspace(pad, max_size - pad, nbins)[1:-1]
 
         # Use overlap stats at maximum valid level of erosion
-        o_maxerode = _best_overlapping(overlapping, erosion_sizes, min_size)
+        o_maxerode = _best_overlapping(overlapping, erosion_sizes, min_size,
+                                       min_size_frac)
         # Then iterate over the thresholds (edges) to find which split
         # minimises the overlap fraction
         split0, w0 = _find_best_fgroup_split(o_maxerode, edges, pad=pad)
@@ -335,8 +440,8 @@ class FlattenerTrainer:
         szgL, szgH = _group_sizes(erosion_sizes, split0, pad=pad)
 
         # And again use the overlap stats at maximum valid level of erosion
-        ogL = _best_overlapping(ogL, szgL, min_size)
-        ogH = _best_overlapping(ogH, szgH, min_size)
+        ogL = _best_overlapping(ogL, szgL, min_size, min_size_frac)
+        ogH = _best_overlapping(ogH, szgH, min_size, min_size_frac)
 
         w_ogL = sum([w for _, _, w in ogL])
         w_ogH = sum([w for _, _, w in ogH])
@@ -361,14 +466,21 @@ class FlattenerTrainer:
         ne1 = _best_nerode(szg1, min_size)
         ne2 = _best_nerode(szg2, min_size)
 
+        if sqrt_area:
+            untransform = lambda x: int(np.round(np.square(x)))
+        else:
+            untransform = lambda x: int(np.round(x))
+
         flattener = SegmentationFlattening()
 
-        flattener.addGroup('small', upper=int(np.round(splits[0] + pad)))
+        flattener.addGroup('small', upper=untransform(splits[0] + pad))
         flattener.addGroup('medium',
-                           lower=int(np.round(splits[0] - pad)),
-                           upper=int(np.round(splits[1] + pad)))
-        flattener.addGroup('large', lower=int(np.round(splits[1] - pad)))
+                           lower=untransform(splits[0] - pad),
+                           upper=untransform(splits[1] + pad))
+        flattener.addGroup('large', lower=untransform(splits[1] - pad))
         flattener.addGroup('buds', upper=bud_max, budonly=True)
+        if overlaps == 'all':
+            flattener.addGroup('all')
 
         flattener.addTarget('lge_inte', 'large', 'interior', nerode=ne2)
         flattener.addTarget('lge_edge', 'large', 'edge')
@@ -377,35 +489,48 @@ class FlattenerTrainer:
         flattener.addTarget('sml_inte', 'small', 'filled', nerode=ne0)
         flattener.addTarget('sml_edge', 'small', 'edge')
         flattener.addTarget('bud_neck', 'buds', 'budneck')
+        if overlaps == 'all':
+            flattener.addTarget('all_ovlp', 'all', 'overlap')
 
-        flattener.save(self.flattener_file)
-        self._flattener = None
+        self.flattener = flattener
 
-    def plot_stats(self, nbins=30):
-        """Plot a histogram of cell overlap statistics of the training set.
-
-        # TODO describe what the plot means
-        # TODO add an image as an example
+    @property
+    def default_gen(self):
+        """Get default data generators using the current flattener."""
+        ssm = self._ssm_trainer.model
+        params = self._shared_params.parameters
+        t_aug = standard_augmenter(ssm, self.flattener, params, isval=False)
+        v_aug = standard_augmenter(ssm, self.flattener, params, isval=True)
+        return self._shared_data.gen_with_aug((t_aug, v_aug, v_aug))
+
+    def plot_default_gen_sample(self, i=0, figsize=3, validation=False):
+        g = self.default_gen.val if validation else self.default_gen.train
+        img_batch, lbl_batch = g[i]
+        lbl_batch = np.concatenate(lbl_batch, axis=3)
+
+        f = self.flattener
+        target_names = f.names()
+        edge_inds = np.flatnonzero([t.prop == 'edge' for t in f.targets])
+
+        ncol = len(img_batch)
+        nrow = len(target_names) + 1
+        fig = plt.figure(figsize=(figsize * ncol, figsize * nrow))
+        for b, (bf, seg) in enumerate(zip(img_batch, lbl_batch)):
+            plt.subplot(nrow, ncol, b + 0 * ncol + 1)
+            plt.imshow(bf[:, :, 0], cmap='gray')
+            plt.imshow(colour_segstack(seg[:, :, edge_inds], dw=True))
+            plt.grid(False)
+            plt.xticks([])
+            plt.yticks([])
+
+            for i, name in enumerate(target_names):
+                plt.subplot(nrow, ncol, b + (i + 1) * ncol + 1)
+                plt.imshow(seg[:, :, i], cmap='gray')
+                plt.grid(False)
+                plt.xticks([])
+                plt.yticks([])
+                plt.title(name)
+
+        fig.savefig(self.save_dir / '{}_generator_sample.png'.format(
+            'validation' if validation else 'training'))
 
-        :param nbins: binning of data, passed to `matplotlib.pyplot.hist2d`
-        :return: None, saves the resulting figure under `self.save_dir /
-        "flattener_stats.png"`
-        """
-        overlapping = self.stats.train.get('overlap_sizes', [])
-        max_erode = len(overlapping)
-        fig, axs = plt.subplots(1, max_erode, figsize=(16, 16 / max_erode))
-        x, y, _ = zip(*overlapping[0])
-        max_size = max(x + y)
-        for ax, (e, os) in zip(axs, enumerate(overlapping)):
-            if len(os) > 0:
-                x, y, w = zip(*os)
-            else:
-                x, y, w = 3 * [[]]
-            ax.hist2d(x,
-                      y,
-                      bins=nbins,
-                      weights=w,
-                      range=[[0, max_size], [0, max_size]])
-            ax.plot((0, max_size), (0, max_size), 'r')
-            ax.set_title('nerosions = {:d}'.format(e))
-        fig.savefig(self.save_dir / 'flattener_stats.png')
diff --git a/python/baby/training/hyper_parameter_trainer.py b/python/baby/training/hyper_parameter_trainer.py
index cedbcaf5176534268cea7df7197332c91445e5ad..b52dcaaa419c26583293563fbf181f22d848eac1 100644
--- a/python/baby/training/hyper_parameter_trainer.py
+++ b/python/baby/training/hyper_parameter_trainer.py
@@ -1,12 +1,14 @@
 # If you publish results that make use of this software or the Birth Annotator
 # for Budding Yeast algorithm, please cite:
-# Julian M J Pietsch, Alán Muñoz, Diane Adjavon, Ivan B N Clark, Peter S
-# Swain, 2021, Birth Annotator for Budding Yeast (in preparation).
+# Pietsch, J.M.J., Muñoz, A.F., Adjavon, D.-Y.A., Farquhar, I., Clark, I.B.N.,
+# and Swain, P.S. (2023). Determining growth rates from bright-field images of
+# budding cells through identifying overlaps. eLife. 12:e79812.
+# https://doi.org/10.7554/eLife.79812
 # 
 # 
 # The MIT License (MIT)
 # 
-# Copyright (c) Julian Pietsch, Alán Muñoz and Diane Adjavon 2021
+# Copyright (c) Julian Pietsch, Alán Muñoz and Diane Adjavon 2023
 # 
 # Permission is hereby granted, free of charge, to any person obtaining a copy
 # of this software and associated documentation files (the "Software"), to
@@ -29,12 +31,11 @@ import json
 from pathlib import Path
 from typing import Union
 
+from baby.generator import augmented_generator
 from baby.training.hypermodels import get_hypermodel
 from kerastuner import RandomSearch, Hyperband, BayesianOptimization, Tuner, \
     HyperModel
 
-from .utils import augmented_generator
-
 
 def instantiate_tuner(model, method='random', **kwargs):
     method = method.lower()
diff --git a/python/baby/training/hypermodels.py b/python/baby/training/hypermodels.py
index 3ebd0a340d7bca196be2c76906ef6688ddab4a77..01a6aa7356e7e452ad4b0a9a45ef01b94526454d 100644
--- a/python/baby/training/hypermodels.py
+++ b/python/baby/training/hypermodels.py
@@ -1,12 +1,14 @@
 # If you publish results that make use of this software or the Birth Annotator
 # for Budding Yeast algorithm, please cite:
-# Julian M J Pietsch, Alán Muñoz, Diane Adjavon, Ivan B N Clark, Peter S
-# Swain, 2021, Birth Annotator for Budding Yeast (in preparation).
+# Pietsch, J.M.J., Muñoz, A.F., Adjavon, D.-Y.A., Farquhar, I., Clark, I.B.N.,
+# and Swain, P.S. (2023). Determining growth rates from bright-field images of
+# budding cells through identifying overlaps. eLife. 12:e79812.
+# https://doi.org/10.7554/eLife.79812
 # 
 # 
 # The MIT License (MIT)
 # 
-# Copyright (c) Julian Pietsch, Alán Muñoz and Diane Adjavon 2021
+# Copyright (c) Julian Pietsch, Alán Muñoz and Diane Adjavon 2023
 # 
 # Permission is hereby granted, free of charge, to any person obtaining a copy
 # of this software and associated documentation files (the "Software"), to
@@ -25,11 +27,13 @@
 # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
 # FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS
 # IN THE SOFTWARE.
-from baby.layers import msd_block, make_outputs, unet_block
+from baby.layers import msd_block, make_outputs, unet_block, res_block, conv_block
 from baby.losses import dice_coeff, bce_dice_loss
-from kerastuner import HyperModel
-from tensorflow.python.keras import Input, Model
-from tensorflow.python.keras.optimizers import Adam
+from keras_tuner import HyperModel
+from tensorflow.keras import Input, Model
+from tensorflow.keras.initializers import VarianceScaling
+from tensorflow.keras.optimizers import Adam
+from tensorflow_addons.optimizers import AdamW
 
 
 
@@ -48,18 +52,44 @@ class UNet(HyperModel):
         return dict(depth=4, layer_size=8, batchnorm=True, dropout=0.0)
 
     def build(self, hp):
+        # Universal variants
+        width = hp.Choice('width', [8, 16, 32, 64])
+        depth = hp.Int('depth', min_value=2, max_value=5)
+        kernel = hp.Choice('kernel', [3, 5, 7])
+        init = hp.Choice('initializer', ['glorot_uniform', 'variance_scaling'])
+        if init == 'variance_scaling':
+            init = VarianceScaling(2., mode='fan_out')
+        activation = hp.Choice('activation', ['relu', 'gelu', 'swish'])
+        conv_pool = hp.Boolean('conv_pool', default=True)
+        up_activate = hp.Boolean('up_activate', default=True)
+        residual_skip = hp.Boolean('residual_skip')
+        enc_repeats = hp.Int('enc_repeats', min_value=2, max_value=4)
+        dec_repeats = hp.Int('dec_repeats', min_value=2, max_value=4)
+        block_type = hp.Choice('block_type',
+                ['effnet', 'effnet-preact', 'convnext', 'conv'])
+        block_args = {
+                'conv': dict(block=conv_block, stem=False),
+                'effnet': dict(block=res_block, stem=True),
+                'effnet-preact': dict(block=res_block, pre_activate=True, stem=True),
+                'convnext': dict(block=res_block, stem=True, convnext=True)
+                }[block_type]
+        expand_ratio, drop = 1, 0.
+        if block_type != 'conv':
+            expand_ratio = hp.Choice('expand_ratio', [0.5, 1., 2., 4.])
+            width = width / expand_ratio
+            drop = hp.Choice('block_drop', [0., 0.2])
+        
         inputs = Input(shape=self.input_shape)
-        depth = hp.Int('depth', min_value=2, max_value=4, step=1)
-        layer_size = hp.Choice('layer_size', values=[8, 16, 32])
-        layer_sizes = [layer_size*(2**i) for i in range(depth)]
-        batchnorm = hp.Boolean('batchnorm', default=True)
-        dropout = hp.Float('dropout', min_value=0., max_value=0.7, step=0.1)
-        unet = unet_block(inputs, layer_sizes, batchnorm=batchnorm,
-                          dropout=dropout)
+        layer_sizes = [width*(2**i) for i in range(depth)]
+        unet = unet_block(inputs, layer_sizes, kernel=kernel, init=init,
+                activation=activation, conv_pool=conv_pool,
+                up_activate=up_activate, residual_skip=residual_skip,
+                enc_repeats=enc_repeats, dec_repeats=dec_repeats,
+                expand_ratio=expand_ratio, drop=drop, **block_args)
         model = Model(inputs=[inputs],
                       outputs=make_outputs(unet, self.outputs))
         # Todo: tuning optimizer
-        model.compile(optimizer=Adam(amsgrad=False),
+        model.compile(optimizer=AdamW(weight_decay=0.00025),
                       metrics=[dice_coeff],
                       loss=bce_dice_loss,
                       loss_weights=self.weights)
diff --git a/python/baby/training/segmentation_trainer.py b/python/baby/training/segmentation_trainer.py
new file mode 100644
index 0000000000000000000000000000000000000000..68fceece662e4d4964eef6f8601b535cd149bbd3
--- /dev/null
+++ b/python/baby/training/segmentation_trainer.py
@@ -0,0 +1,982 @@
+# If you publish results that make use of this software or the Birth Annotator
+# for Budding Yeast algorithm, please cite:
+# Pietsch, J.M.J., Muñoz, A.F., Adjavon, D.-Y.A., Farquhar, I., Clark, I.B.N.,
+# and Swain, P.S. (2023). Determining growth rates from bright-field images of
+# budding cells through identifying overlaps. eLife. 12:e79812.
+# https://doi.org/10.7554/eLife.79812
+# 
+# 
+# The MIT License (MIT)
+# 
+# Copyright (c) Julian Pietsch, Alán Muñoz and Diane Adjavon 2023
+# 
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to
+# deal in the Software without restriction, including without limitation the
+# rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
+# sell copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+# 
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+# 
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
+# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS
+# IN THE SOFTWARE.
+from math import floor, log10
+from itertools import combinations, product, chain, repeat, islice
+from typing import NamedTuple, Union, Tuple, Any
+import numpy as np
+import pandas as pd
+from matplotlib import pyplot as plt
+from tqdm import tqdm
+# from numba import njit
+
+from baby import segmentation
+from baby.morph_thresh_seg import (MorphSegGrouped, Group,
+                                   SegmentationParameters,
+                                   BROADCASTABLE_PARAMS)
+from baby.utils import batch_iterator, split_batch_pred
+from baby.augmentation import ScalingAugmenter
+from baby.performance import calc_IoUs, best_IoU, calc_AP
+from baby.errors import BadProcess, BadParam, BadType
+from .utils import (SharedParameterContainer, SharedDataContainer,
+                    TrainValTestProperty, standard_augmenter)
+from .smoothing_model_trainer import SmoothingModelTrainer
+from .cnn_trainer import CNNTrainer
+
+
+# Todo add default parameters to SegTrainer
+#       Search space !
+DEFAULT_SEG_PARAM_COORDS = {
+    'nclosing': [0, 1, 2],
+    'nopening': [0, 1, 2],
+    'interior_threshold': np.arange(0.3, 1.0, 0.05).tolist(),
+    'connectivity': [1, 2],
+    'edge_sub_dilations': [0, 1, 2]
+}
+
+
+class SegExample(NamedTuple):
+    """CNN output paired with target segmented outlines and info
+
+    Used for optimising segmentation hyperparameters and for training bud
+    assignment models
+    """
+    pred: np.ndarray
+    target: np.ndarray
+    info: dict
+    img: np.ndarray
+
+
+class Score(NamedTuple):
+    """Scoring metrics for segmentation performance.
+
+    Metrics are defined in terms of number of true positives ``TP``, number of
+    false positives ``FP`` and number of false negatives ``FN``.
+
+    Attributes:
+        precision: ``TP / (TP + FP)``.
+        recall: ``TP / (TP + FN)``.
+        F1: Balanced F-score ``2 * precision * recall / (precision +
+            recall)``.
+        F0_5: F-score favouring precision ``1.25 * precision * recall /
+            (0.25 * precision + recall)``.
+        F2: F-score favouring recall ``5 * precision * recall / (4 *
+            precision + recall)``.
+        meanIoU: Mean intersection over union between segmented and ground
+            truth cell masks.
+    """
+    precision: float
+    recall: float
+    F1: float
+    F0_5: float
+    F2: float
+    meanIoU: float
+
+
+def _example_generator(cnn, dgen):
+    # b_iter = batch_iterator(sorted(dgen.ordering), batch_size=dgen.batch_size)
+    # Sorting is counterproductive to method of subsample_frac
+    b_iter = batch_iterator(dgen.ordering, batch_size=dgen.batch_size)
+    with tqdm(total=len(dgen.ordering)) as pbar:
+        for b_inds in b_iter:
+            batch = dgen.parallel_get_indices(b_inds)
+
+            imgs = [img for img, _ in batch]
+            preds = split_batch_pred(
+                cnn.predict(np.stack(imgs), verbose=0))
+            for pred, (img, (lbl, info)) in zip(preds, batch):
+                pbar.update()
+                lbl = lbl.transpose(2, 0, 1)
+                # Filter out examples that have been augmented away
+                valid = lbl.sum(axis=(1, 2)) > 0
+                lbl = lbl[valid]
+                info = info.copy() # ensure we do not modify original
+                clab = info.get('cellLabels', []) or []
+                if type(clab) is int:
+                    clab = [clab]
+                clab = [l for l, v in zip(clab, valid) if v]
+                info['cellLabels'] = clab
+                buds = info.get('buds', []) or []
+                if type(buds) is int:
+                    buds = [buds]
+                buds = [b for b, v in zip(buds, valid) if v]
+                info['buds'] = buds
+                yield SegExample(pred, lbl, info, img)
+
+
+def _sub_params(sub, param_template):
+    """Helper function for ``SegFilterParamOptim.fit_filter_params``.
+
+    Substitutes parameters in a ``dict``, possibly group-specific.
+    """
+    p = {
+        k: v.copy() if type(v) == list else v
+        for k, v in param_template.items()
+    }
+    for (k, g), v in sub.items():
+        if g is None:
+            p[k] = v
+        else:
+            p[k][g] = v
+    return p
+
+
+# @njit
+def _filter_trial(containment, containment_thresh, area, min_area,
+                  pedge_thresh, group_thresh_expansion, 
+                  gIds, gLs, gUs, gRs, group,
+                  p_edge, max_IoU, IoU_thresh, assignId, nT):
+    """Numba optimised filter trial implementation.
+
+    TODO: currently the numba version kills the kernel... This implementation
+    is nonetheless orders of magnitude faster than the previous.
+    """
+
+    rejects = (containment > containment_thresh) | (area < min_area)
+    for t_pe, g_ex, g, l, u, gr in zip(pedge_thresh,
+                                       group_thresh_expansion,
+                                       gIds, gLs, gUs, gRs):
+        g_ex = g_ex * gr
+        l = max(l - g_ex, 1)
+        u = u + g_ex
+        rejects |= (group == g) & ((p_edge < t_pe) | (area < l) | (area > u))
+
+    TP_mask = (~rejects) & (max_IoU >= IoU_thresh)
+
+    # Find the maximum IoU in each eid-assignment group following the
+    # suggestion on stackoverflow:
+    # https://stackoverflow.com/questions/8623047/group-by-max-or-min-in-a-numpy-array
+    # NB: the grouping variables and IoUs are presorted when
+    # _filter_trial_cache is generated (see below).
+    TPs_IoU = max_IoU[TP_mask]
+    if TPs_IoU.size > 0:
+        grouping = assignId[TP_mask]
+        index = np.empty(TPs_IoU.size, 'bool')
+        index[-1] = True
+        last_in_group = np.any(grouping[1:] != grouping[:-1], axis=1)
+        index[:-1] = last_in_group
+        TPs_IoU = TPs_IoU[index]
+
+    nPs = np.sum(~rejects)
+    nTPs = TPs_IoU.size
+    nFPs = nPs - nTPs
+    nFNs = nT - nTPs
+    precision = nTPs / (nTPs + nFPs)
+    recall = nTPs / (nTPs + nFNs)
+    # Fbeta = (1 + beta^2) * P * R / (beta^2 * P + R)
+    F1 = 2 * precision * recall / (precision + recall)
+    F0_5 = 1.25 * precision * recall / (0.25 * precision + recall)
+    F2 = 5 * precision * recall / (4 * precision + recall)
+    return precision, recall, F1, F0_5, F2, np.mean(TPs_IoU)
+
+
+def _filter_trial_bootstraps(filter_trial_cache, pedge_thresh,
+                             group_thresh_expansion, containment_thresh,
+                             min_area, IoU_thresh, return_stderr=False):
+    """Score one set of filter parameters on each cached sample.
+
+    Args:
+        filter_trial_cache: list of ``dict``, one per sample, providing the
+            keys ``containment``, ``area``, ``gIds``, ``gLs``, ``gUs``,
+            ``gRs``, ``group``, ``p_edge``, ``max_IoU``, ``assignId`` and
+            ``nT`` that are handed on to :py:func:`_filter_trial`.
+        pedge_thresh: sequence of edge probability thresholds. It is
+            iterated in step with the group arrays by ``_filter_trial``, so
+            it needs one value per group.
+        group_thresh_expansion: sequence of group threshold expansions,
+            also one value per group.
+        containment_thresh: cells with a containment score above this value
+            are rejected.
+        min_area: cells with an area below this value are rejected.
+        IoU_thresh: minimum ``max_IoU`` for an accepted cell to be counted
+            as a true positive.
+        return_stderr: if ``True``, additionally return the standard error
+            of each score over the samples.
+
+    Returns:
+        A ``Score`` with the mean of each metric over the samples or, if
+        ``return_stderr`` is set, a tuple ``(mean, stderr)`` of ``Score``
+        where the standard error is the population standard deviation
+        divided by the square root of the number of samples.
+    """
+
+    # The threshold arrays are identical for every sample: convert them once
+    # rather than on each pass through the loop.
+    pedge_thresh = np.array(pedge_thresh, dtype='float')
+    group_thresh_expansion = np.array(group_thresh_expansion, dtype='float')
+
+    score_boots = np.array([
+        Score(*_filter_trial(cache['containment'], containment_thresh,
+                             cache['area'], min_area,
+                             pedge_thresh, group_thresh_expansion,
+                             cache['gIds'], cache['gLs'], cache['gUs'],
+                             cache['gRs'], cache['group'], cache['p_edge'],
+                             cache['max_IoU'], IoU_thresh,
+                             cache['assignId'], cache['nT']))
+        for cache in filter_trial_cache
+    ])
+
+    mean_score = Score(*score_boots.mean(axis=0))
+    if not return_stderr:
+        return mean_score
+
+    stderr = score_boots.std(axis=0) / np.sqrt(score_boots.shape[0])
+    return (mean_score, Score(*stderr))
+
+
+def _generate_stat_table(s, seg_ex, segrps, containment_func):
+    """Segment one example and tabulate statistics for every cell found.
+
+    Parallelisation helper for ``SegFilterParamOptim.generate_stat_table``.
+
+    Args:
+        s: index of the example; copied into each output row.
+        seg_ex: example providing ``pred`` (the input to ``Group.segment``)
+            and ``target`` (the ground truth handed to ``calc_IoUs``).
+        segrps: groups whose ``targets`` and ``params`` are used to build
+            local ``Group`` segmenters.
+        containment_func: callable taking two cell masks and returning
+            their containment score.
+
+    Returns:
+        A tuple ``((s, ncells), rows)``. ``ncells`` is ``len(seg_ex.target)``,
+        or 0 if the target has no non-zero element. ``rows`` has one tuple
+        per segmented cell: ``(s, group, cell, area, edge_score,
+        containment, assignment, max_IoU, best_assignment)``.
+    """
+
+    n_truth = 0
+    if seg_ex.target.any():
+        n_truth = len(seg_ex.target)
+
+    # Mask that is True on the one pixel wide image border only
+    img_shape = np.squeeze(seg_ex.pred[0]).shape
+    interior = np.zeros(tuple(dim - 2 for dim in img_shape), dtype='bool')
+    border_rect = np.pad(interior, pad_width=1, mode='constant',
+                         constant_values=True)
+
+    # Work on local copies of the Group segmenters to avoid potential race
+    # conditions in parallel usage, and segment within each group:
+    local_groups = [Group(grp.targets, grp.params) for grp in segrps]
+    masks = []
+    for grp in local_groups:
+        grp.segment(seg_ex.pred, border_rect)
+        masks.extend(cell.mask for cell in grp.cells)
+
+    # Containment scores are calculated between cells of adjacent groups.
+    # The score is recorded against whichever cell of the pair does not have
+    # the strictly higher edge score.
+    contained = {}
+    for g in range(len(local_groups) - 1):
+        lo_grp, hi_grp = local_groups[g], local_groups[g + 1]
+        for i, lo_cell in enumerate(lo_grp.cells):
+            for j, hi_cell in enumerate(hi_grp.cells):
+                score = containment_func(lo_cell.mask, hi_cell.mask)
+                if not score > 0:
+                    continue
+                if lo_cell.edge_score > hi_cell.edge_score:
+                    contained[(g + 1, j)] = score
+                else:
+                    contained[(g, i)] = score
+
+    if n_truth > 0:
+        IoUs = calc_IoUs(seg_ex.target, masks, fill_holes=False)
+        max_IoU = IoUs.max(axis=0)
+        assignments = IoUs.argmax(axis=0)
+        _, best_assignments = best_IoU(IoUs.T)
+    else:
+        # Without ground truth, no predicted cell can be assigned
+        n_masks = len(masks)
+        max_IoU = np.zeros(n_masks)
+        assignments = np.zeros(n_masks, dtype=np.uint16)
+        best_assignments = -np.ones(n_masks, dtype=np.int32)
+
+    # Cells are visited in the same order in which the masks were collected,
+    # so the running index lines up with the per-mask arrays above.
+    cells = [(g, c, cell)
+             for g, grp in enumerate(local_groups)
+             for c, cell in enumerate(grp.cells)]
+    rows = [(s, g, c, cell.area, cell.edge_score,
+             contained.get((g, c), 0.),
+             assignments[ind], max_IoU[ind], best_assignments[ind])
+            for ind, (g, c, cell) in enumerate(cells)]
+
+    return (s, n_truth), rows
+
+def _seg_filter_optim(
+        group_id,
+        new_values,
+        param_names,
+        flattener,
+        seg_gen,
+        base_params=SegmentationParameters(),
+        scoring='F0_5',
+        n_jobs=4):
+    """Score one point of the segmentation parameter grid for one group.
+
+    Parallelisation helper for ``SegmentationTrainer.fit_seg_params``.
+
+    Args:
+        group_id: index of the group for which broadcastable parameters are
+            replaced.
+        new_values: parameter values to trial, in step with ``param_names``.
+        param_names: names of the parameters being set.
+        flattener: flattener handed to ``SegFilterParamOptim``.
+        seg_gen: segmentation examples handed to
+            ``SegFilterParamOptim.generate_stat_table``.
+        base_params: parameters to start from; broadcastable parameters
+            must already be lists.
+        scoring: name of the score to optimise.
+        n_jobs: number of parallel jobs used by the optimiser.
+
+    Returns:
+        A ``dict`` with the keys ``group``, ``base`` (trialled values of
+        ``param_names``), ``filter`` (optimal filter parameters), ``score``
+        (the optimal score) and ``scores`` (result of a filter trial at the
+        optimum).
+
+    Raises:
+        BadType: if a broadcastable parameter of ``base_params`` is not a
+            ``list``.
+    """
+
+    # Build the set of overrides to apply to the base parameters
+    overrides = {}
+    for name, value in zip(param_names, new_values):
+        if name not in BROADCASTABLE_PARAMS:
+            overrides[name] = value
+            continue
+        if name not in overrides:
+            current = getattr(base_params, name)
+            if type(current) != list:
+                raise BadType(
+                    'Segmentation parameters should have been broadcast.')
+            # Copy so that the list held by base_params is left untouched
+            overrides[name] = current.copy()
+        overrides[name][group_id] = value
+    trial_params = base_params._replace(**overrides)
+
+    # Optimise filtering parameters and return score
+    optimiser = SegFilterParamOptim(flattener, base_params=trial_params,
+                                    scoring=scoring)
+    with np.errstate(all='ignore'):
+        optimiser.generate_stat_table(seg_gen, n_jobs=n_jobs)
+        optimiser.fit_filter_params(lazy=True, bootstrap=False,
+                                    n_jobs=n_jobs)
+
+    result = {'group': group_id}
+    result['base'] = {name: getattr(trial_params, name)
+                      for name in param_names}
+    result['filter'] = optimiser.opt_params
+    result['score'] = optimiser.opt_score
+    result['scores'] = optimiser.filter_trial(**optimiser.opt_params,
+                                              bootstrap=False)
+    return result
+
+
+def round_to_n(x, n):
+    """Round ``x`` to ``n`` significant figures.
+
+    Zero has no leading digit and is returned unchanged; negative values
+    are rounded according to their magnitude.
+    """
+    if x == 0:
+        return x
+    return round(x, -int(floor(log10(abs(x)))) + (n - 1))
+
+
+class SegFilterParamOptim:
+    """Optimises the parameters used to filter candidate segmentations.
+
+    Unfiltered segmentation results are first collected for a set of examples
+    (:py:meth:`generate_stat_table`). Filter settings can then be scored
+    against the ground truth (:py:meth:`filter_trial`) and optimised by grid
+    search (:py:meth:`fit_filter_params`). The parameters optimised are
+    ``containment_thresh`` and ``min_area`` (single values) and
+    ``pedge_thresh`` and ``group_thresh_expansion`` (one value per group).
+
+    :param flattener: flattener handed to :py:class:`MorphSegGrouped` to
+        define the segmentation groups.
+    :param base_params: segmentation parameters to start from. The values of
+        the parameters optimised by this class are overridden.
+    :param IoU_thresh: minimum IoU with a ground truth cell for a predicted
+        cell to be counted as a true positive (default 0.5).
+    :param scoring: name of the :py:class:`Score` field to maximise
+        (default ``'F0_5'``).
+    :param nbootstraps: number of bootstrap samples to draw from the examples
+        (default 10).
+    :param bootstrap_frac: number of examples in each bootstrap sample as a
+        fraction of the total number of examples (default 0.9).
+    """
+    def __init__(self,
+                 flattener,
+                 base_params=SegmentationParameters(),
+                 IoU_thresh=0.5,
+                 scoring='F0_5',
+                 nbootstraps=10,
+                 bootstrap_frac=0.9):
+
+        self.IoU_thresh = IoU_thresh
+        self.scoring = scoring
+        self.nbootstraps = nbootstraps
+        self.bootstrap_frac = bootstrap_frac
+
+        # Force values for parameters optimised by this class
+        self._base_params = base_params._replace(
+            fit_radial=True, min_area=1, pedge_thresh=None,
+            use_group_thresh=False, group_thresh_expansion=0,
+            containment_thresh=0.8)
+
+        self.segmenter = MorphSegGrouped(flattener,
+                                         params=self._base_params,
+                                         return_masks=True)
+
+        # The filter trial cache is built on demand; also remember whether
+        # it was built from the bootstrap samples or from the full table.
+        self._filter_trial_cache = None
+        self._filter_trial_cache_bootstrap = None
+
+        # For each group record (index, lower bound, upper bound, range),
+        # where the bounds are those against which cell areas are checked
+        # when filtering.
+        self.group_info = []
+        for g, group in enumerate(self.segrps):
+            lower = min(
+                target.definition.get('lower', 1.)
+                for target in group.targets)
+            upper = max(
+                target.definition.get('upper', float('inf'))
+                for target in group.targets)
+            grange = lower if upper == float('inf') else upper - lower
+            self.group_info.append((g, lower, upper, grange))
+
+    @property
+    def scoring(self):
+        """ The scoring method used during evaluation of the segmentation.
+
+        Accepted values are any of the named attributes of :py:class:`Score`
+        specified as a ``str``.
+        """
+        return self._scoring
+
+    @scoring.setter
+    def scoring(self, val):
+        if val not in Score._fields:
+            raise BadParam('Specified scoring metric not available')
+        self._scoring = val
+
+    @property
+    def base_params(self):
+        """Segmentation parameters including the values forced by this class."""
+        return self._base_params
+
+    @property
+    def segrps(self):
+        """Segmentation groups of the underlying ``MorphSegGrouped``."""
+        return self.segmenter.groups
+
+    def _stored(self, attr, producer):
+        """Return a stored result or raise ``BadProcess`` if it is missing."""
+        val = getattr(self, attr, None)
+        if val is None:
+            raise BadProcess('"{}" has not been run'.format(producer))
+        return val
+
+    @staticmethod
+    def _best_index(scores):
+        """Return the index of the highest score, ranking NaN scores last.
+
+        ``np.argmax`` alone would select the first NaN, and NaN scores do
+        arise here (for example when a trial yields no true positives).
+        """
+        scores = np.asarray(scores, dtype='float')
+        return np.argmax(np.where(np.isnan(scores), -np.inf, scores))
+
+    @property
+    def stat_table(self):
+        """Table of unfiltered segmentation results indexed by ``uid``."""
+        return self._stored('_stat_table', 'generate_stat_table')
+
+    @property
+    def stat_table_bootstraps(self):
+        """List with one table of segmentation results per bootstrap."""
+        return self._stored('_stat_table_bootstraps', 'generate_stat_table')
+
+    @property
+    def truth(self):
+        """Number of ground truth cells in each example."""
+        return self._stored('_nPs', 'generate_stat_table')
+
+    @property
+    def truth_bootstraps(self):
+        """Total number of ground truth cells in each bootstrap sample."""
+        return self._stored('_nPs_bootstraps', 'generate_stat_table')
+
+    @property
+    def opt_params(self):
+        """Optimal filter parameters found by ``fit_filter_params``."""
+        return self._stored('_opt_params', 'fit_filter_params')
+
+    @property
+    def opt_score(self):
+        """Score obtained with the optimal filter parameters."""
+        return self._stored('_opt_score', 'fit_filter_params')
+
+    def generate_stat_table(self, example_gen, n_jobs=4):
+        """Generates unfiltered segmentation results organised as a table.
+
+        The generated output can be accessed as a ``pandas.DataFrame`` from
+        :py:attr:`stat_table` and as a list of bootstrap samples from
+        :py:attr:`stat_table_bootstraps`.
+
+        Args:
+            example_gen: iterable of segmentation examples; each is processed
+                by :py:func:`_generate_stat_table`.
+            n_jobs: number of parallel jobs to run with ``joblib``.
+
+        Note:
+            This function is called from
+            :py:meth:`SegmentationTrainer.fit_seg_params` via
+            :py:func:`_seg_filter_optim` and also from
+            :py:meth:`SegmentationTrainer.refit_filter_seg_params`.
+        """
+        containment_func = getattr(segmentation,
+                                   self.base_params.containment_func)
+
+        from joblib import Parallel, delayed
+        results = Parallel(n_jobs=n_jobs)(
+            delayed(_generate_stat_table)(s, seg_ex, self.segrps,
+                                          containment_func)
+            for s, seg_ex in enumerate(example_gen))
+
+        rows_truth, rows = zip(*results)
+        rows = list(chain(*rows))
+
+        df_truth = pd.DataFrame(rows_truth, columns=['example', 'ncells'])
+        df_truth = df_truth.set_index('example')
+        self._nPs = df_truth.ncells.to_numpy()
+
+        dtypes = [('example', np.uint16),
+                  ('group', np.uint8),
+                  ('cell', np.uint16),
+                  ('area', np.uint16),
+                  ('p_edge', np.float64),
+                  ('containment', np.float64),
+                  ('assignments', np.uint16),
+                  ('max_IoU', np.float64),
+                  ('best_assignments', np.int32)]
+        df = pd.DataFrame(np.array(rows, dtype=dtypes))
+
+        df['is_best'] = ((df.best_assignments >= 0) &
+                         (df.max_IoU >= self.IoU_thresh))
+        df['eid'] = df.example
+        df['uid'] = tuple(zip(df.example, df.assignments))
+
+        # Generate a set of bootstrapping filters, each sampling (with
+        # replacement) a fraction ``bootstrap_frac`` of the examples
+        examples = list(set(df_truth.index.values))
+        nperboot = np.round(self.bootstrap_frac * len(examples)).astype(int)
+        bootstraps = [
+            np.random.choice(examples, nperboot, replace=True)
+            for _ in range(self.nbootstraps)
+        ]
+        # Select the column before summing so that each total is a scalar
+        # count rather than a one-element Series
+        self._nPs_bootstraps = [
+            df_truth.ncells.loc[b].sum() for b in bootstraps
+        ]
+        # Limit bootstrap examples to those present in segmentation output
+        bootstraps = [b[np.isin(b, df.example)] for b in bootstraps]
+        df.set_index('example', drop=False, inplace=True)
+        example_counts = df.example.value_counts()
+        self._stat_table_bootstraps = []
+        for b in bootstraps:
+            df_boot = df.loc[b]
+            # Renumber examples to handle the case of duplicated examples in
+            # the bootstrap:
+            df_boot['eid'] = tuple(
+                chain(*(repeat(i, example_counts.loc[e])
+                        for i, e in enumerate(b))))
+            df_boot['uid'] = tuple(zip(df_boot.eid, df_boot.assignments))
+            df_boot.set_index('uid', inplace=True)
+            self._stat_table_bootstraps.append(df_boot)
+
+        df.set_index('uid', inplace=True)
+        self._stat_table = df
+
+        # Any existing filter trial cache was derived from the old table
+        self._filter_trial_cache = None
+
+    def _ensure_filter_trial_cache(self, bootstrap):
+        """Build the filter trial cache unless a matching one exists.
+
+        The cache holds, for each sample, the presorted arrays needed by
+        :py:func:`_filter_trial`. It is rebuilt if it is missing or if it
+        was built for the other ``bootstrap`` mode.
+        """
+        bootstrap = bool(bootstrap)
+        cache = getattr(self, '_filter_trial_cache', None)
+        cached_mode = getattr(self, '_filter_trial_cache_bootstrap', None)
+        if cache is not None and cached_mode == bootstrap:
+            return
+
+        if bootstrap:
+            dfs = self.stat_table_bootstraps
+            truths = self.truth_bootstraps
+        else:
+            dfs = [self.stat_table]
+            truths = [self.truth.sum()]
+
+        # Reformat the group_info; this is the same for every sample
+        gIds, gLs, gUs, gRs = zip(*self.group_info)
+        gLs = np.array(gLs, dtype='float')
+        gUs = np.array(gUs, dtype='float')
+        gRs = np.array(gRs, dtype='float')
+
+        cache = []
+        for df, nT in zip(dfs, truths):
+            # From stackoverflow:
+            # https://stackoverflow.com/questions/8623047/group-by-max-or-min-in-a-numpy-array
+            # We need to pre-sort the DataFrame by a combined grouping
+            # metric according to eid and assignments and IoU
+            df = df.sort_values(['eid', 'assignments', 'max_IoU'])
+            cache.append(dict(
+                nT=nT, gIds=gIds, gLs=gLs, gUs=gUs, gRs=gRs,
+                containment=df.containment.to_numpy(),
+                area=df.area.to_numpy(),
+                p_edge=df.p_edge.to_numpy(),
+                group=df.group.to_numpy(),
+                max_IoU=df.max_IoU.to_numpy(),
+                assignId=df[['eid', 'assignments']].to_numpy()))
+
+        self._filter_trial_cache = cache
+        self._filter_trial_cache_bootstrap = bootstrap
+
+    def filter_trial(self,
+                     pedge_thresh,
+                     group_thresh_expansion,
+                     containment_thresh,
+                     min_area,
+                     bootstrap=True,
+                     return_stderr=False,
+                     use_cache=False):
+        """Score a single set of filter parameters.
+
+        Args:
+            pedge_thresh: edge probability threshold for each group.
+            group_thresh_expansion: group threshold expansion for each group.
+            containment_thresh: containment score above which cells are
+                rejected.
+            min_area: area below which cells are rejected.
+            bootstrap: score the bootstrap samples rather than the full
+                table of segmentation results.
+            return_stderr: also return the standard error of the scores.
+            use_cache: reuse an existing filter trial cache instead of
+                regenerating it.
+
+        Returns:
+            The mean ``Score`` or, if ``return_stderr`` is set, a tuple of
+            the mean ``Score`` and its standard error.
+        """
+
+        if not use_cache:
+            self._filter_trial_cache = None
+        self._ensure_filter_trial_cache(bootstrap)
+
+        return _filter_trial_bootstraps(self._filter_trial_cache,
+                                        pedge_thresh, group_thresh_expansion,
+                                        containment_thresh, min_area,
+                                        self.IoU_thresh,
+                                        return_stderr=return_stderr)
+
+    def parallel_filter_trials(self, params_list, score=None, bootstrap=True,
+                               return_stderr=False, use_cache=False, n_jobs=4):
+        """Score multiple sets of filter parameters in parallel.
+
+        Args:
+            params_list: list of ``dict``, each providing the keyword
+                arguments ``pedge_thresh``, ``group_thresh_expansion``,
+                ``containment_thresh`` and ``min_area``.
+            score: optional name of a single ``Score`` field to extract
+                from each result.
+            bootstrap: score the bootstrap samples rather than the full
+                table of segmentation results.
+            return_stderr: also return the standard errors.
+            use_cache: reuse an existing filter trial cache instead of
+                regenerating it.
+            n_jobs: number of parallel jobs to run with ``joblib``.
+
+        Returns:
+            The scores for each entry of ``params_list`` or, if
+            ``return_stderr`` is set, a tuple ``(scores, stderrs)``.
+        """
+        if not use_cache:
+            self._filter_trial_cache = None
+        self._ensure_filter_trial_cache(bootstrap)
+        cache = self._filter_trial_cache
+
+        from joblib import Parallel, delayed
+        mean_scores = Parallel(n_jobs=n_jobs)(
+            delayed(_filter_trial_bootstraps)(cache, **params,
+                                              IoU_thresh=self.IoU_thresh,
+                                              return_stderr=return_stderr)
+            for params in params_list)
+
+        if return_stderr:
+            mean_scores, stderrs = zip(*mean_scores)
+            if score is not None:
+                mean_scores = [getattr(s, score) for s in mean_scores]
+                stderrs = [getattr(s, score) for s in stderrs]
+            return mean_scores, stderrs
+        else:
+            if score is not None:
+                mean_scores = [getattr(s, score) for s in mean_scores]
+            return mean_scores
+
+    def _search_each_dimension(self, params, template, bootstrap, n_jobs):
+        """Optimise each grid parameter with the others fixed to a template.
+
+        Returns a ``dict`` mapping each key of ``params`` to the grid value
+        that gave the best score.
+        """
+        best = {}
+        for k, pvals in params.items():
+            params_list = [_sub_params({k: v}, template) for v in pvals]
+            scrs = self.parallel_filter_trials(params_list,
+                                               score=self.scoring,
+                                               bootstrap=bootstrap,
+                                               use_cache=True, n_jobs=n_jobs)
+            best[k] = pvals[self._best_index(scrs)]
+        return best
+
+    def fit_filter_params(self, lazy=False, bootstrap=False, n_jobs=4):
+        """Find optimal filter parameters by grid search.
+
+        The result is available from :py:attr:`opt_params` and
+        :py:attr:`opt_score`.
+
+        Args:
+            lazy: if ``True``, only search along each parameter dimension in
+                turn (twice); otherwise additionally search over pairs of
+                parameters and over combinations of the pairwise optima.
+            bootstrap: score on the bootstrap samples rather than on the
+                full table of segmentation results.
+            n_jobs: number of parallel jobs to run with ``joblib``.
+        """
+        ngroups = len(self.segrps)
+
+        # Define parameter grid values, firstly those not specific to a group
+        params = {
+            ('containment_thresh', None): np.linspace(0, 1, 21),
+            ('min_area', None): np.arange(0, 20, 1)
+        }
+
+        # Ensure we reset the filter trial cache
+        self._filter_trial_cache = None
+        self._ensure_filter_trial_cache(bootstrap)
+
+        # Determine the pedge_threshold range based on the observed p_edge
+        # range for each group
+        stat_table = self.stat_table
+        if stat_table.is_best.all() or not stat_table.is_best.any():
+            t_pe_upper = stat_table.groupby('group').p_edge.mean()
+        else:
+            q_pe = stat_table.groupby(['group', 'is_best'])
+            q_pe = q_pe.p_edge.quantile([0.25, 0.95]).unstack((1, 2))
+            t_pe_upper = q_pe.loc[:, [(False, 0.95), (True, 0.25)]].mean(1)
+
+        # Groups without any segmented cells are absent from the groupby
+        # result, so look up the upper limits by group label rather than by
+        # position.
+        upper_by_group = {int(g): u for g, u in t_pe_upper.items()}
+
+        # Set group-specific parameter grid values
+        for g in range(ngroups):
+            u = upper_by_group.get(g, np.nan)
+            if u > 0:
+                t_pe = np.arange(0, u, round_to_n(u / 20, 1))
+            else:
+                # No usable range (missing, NaN or non-positive upper
+                # limit): only trial the threshold switched off
+                t_pe = np.zeros(1)
+            params[('pedge_thresh', g)] = t_pe
+            params[('group_thresh_expansion', g)] = np.linspace(0, 0.4, 21)
+
+        # Default starting point is with thresholds off and no group expansion
+        dflt_params = {
+            'containment_thresh': 0,
+            'min_area': 0,
+            'pedge_thresh': list(repeat(0, ngroups)),
+            'group_thresh_expansion': list(repeat(0, ngroups))
+        }
+
+        # Search first along each parameter dimension with all others kept at
+        # default:
+        opt_params = self._search_each_dimension(params, dflt_params,
+                                                 bootstrap, n_jobs)
+
+        # Reset the template parameters to the best along each dimension
+        base_params = _sub_params(opt_params, dflt_params)
+
+        if lazy:
+            # Simply repeat search along each parameter dimension, but now
+            # using the new optimum as a starting point
+            opt_params = self._search_each_dimension(params, base_params,
+                                                     bootstrap, n_jobs)
+            opt_params = _sub_params(opt_params, base_params)
+            scr = self.filter_trial(**opt_params, bootstrap=bootstrap,
+                                    use_cache=True)
+            self._opt_params = opt_params
+            self._opt_score = getattr(scr, self.scoring)
+            return
+
+        # Next perform a joint search for parameters with optimal pairings
+        opt_param_pairs = {k: {v} for k, v in opt_params.items()}
+        for k1, k2 in combinations(params.keys(), 2):
+            vals = list(product(params[k1], params[k2]))
+            params_list = [_sub_params({k1: v1, k2: v2}, base_params)
+                           for v1, v2 in vals]
+            scrs = self.parallel_filter_trials(params_list,
+                                               score=self.scoring,
+                                               bootstrap=bootstrap,
+                                               use_cache=True, n_jobs=n_jobs)
+            p1opt, p2opt = vals[self._best_index(scrs)]
+            opt_param_pairs[k1].add(p1opt)
+            opt_param_pairs[k2].add(p2opt)
+
+        # Finally search over all combinations of the parameter values found
+        # with optimal pairings
+        params_list = [_sub_params(dict(zip(opt_param_pairs.keys(), pvals)),
+                                   base_params)
+                       for pvals in product(*opt_param_pairs.values())]
+        scrs = self.parallel_filter_trials(params_list,
+                                           score=self.scoring,
+                                           bootstrap=bootstrap,
+                                           use_cache=True, n_jobs=n_jobs)
+        maxInd = self._best_index(scrs)
+        self._opt_params = params_list[maxInd]
+        self._opt_score = scrs[maxInd]
+
+
+class SegmentationTrainer(object):
+    """Finds optimal segmentation parameters given a trained CNN.
+
+    Args:
+        shared_params: Training and segmentation parameters as provided by
+            :py:class:`utils.SharedParameterContainer`.
+        shared_data: Training data as provided by
+            :py:class:`utils.SharedDataContainer`.
+        ssm_trainer: SmoothingModelTrainer with optimised model for
+            determination of smoothing sigma.
+        cnn_trainer: Trainer with optimised CNN.
+    """
+    def __init__(self,
+                 shared_params: SharedParameterContainer,
+                 shared_data: SharedDataContainer,
+                 ssm_trainer: SmoothingModelTrainer,
+                 cnn_trainer: CNNTrainer):
+        self._shared_params = shared_params
+        self._shared_data = shared_data
+        self._ssm_trainer = ssm_trainer
+        self._cnn_trainer = cnn_trainer
+
+    @property
+    def save_dir(self):
+        """Base directory in which to save trained models"""
+        return self._shared_params.save_dir
+
+    @property
+    def training_parameters(self):
+        """Training parameters held by the shared parameter container."""
+        return self._shared_params.parameters
+
+    @property
+    def segment_parameters(self):
+        """Segmentation parameters held by the shared parameter container."""
+        return self._shared_params.segmentation_parameters
+
+    @segment_parameters.setter
+    def segment_parameters(self, val):
+        self._shared_params.segmentation_parameters = val
+
+    @property
+    def segment_parameter_coords(self):
+        """Grid coordinates to search over in :py:meth:`fit_seg_params`.
+
+        A copy of ``DEFAULT_SEG_PARAM_COORDS`` updated with the
+        ``seg_param_coords`` of the training parameters.
+        """
+        param_coords = DEFAULT_SEG_PARAM_COORDS.copy()
+        param_coords.update(self.training_parameters.seg_param_coords)
+        return param_coords
+
+    @property
+    def gen(self):
+        """Training, validation and test data generators with raw output.
+
+        This attribute provides three :py:class:`ImageLabel` generators as a
+        :py:class:`TrainValTestProperty`, with each generator assigned a
+        :py:class:`ScalingAugmenter` that outputs an ``(image, label, info)``
+        tuple, where ``label`` provides the unflattened ``ndarray`` of cell
+        outlines and ``info`` is a ``dict`` of meta-data associated with the
+        label (if any). Augmentations are limited to just cropping and scaling
+        operations to match the intended pixel size and input size of the CNN.
+        """
+        # Create an augmenter that returns unflattened label images
+        aug = standard_augmenter(
+            self._ssm_trainer.model,
+            lambda lbl, _: lbl,
+            self.training_parameters,
+            isval=True)
+
+        def seg_example_aug(img, lbl):
+            # Assume that the label preprocessing function also returns info
+            _, info = lbl
+            img, lbl = aug(img, lbl)
+            return img, (lbl > 0, info)
+
+        # In this case, always prefer the validation augmenter
+        return self._shared_data.gen_with_aug(seg_example_aug)
+
+    @property
+    def examples(self):
+        """Training, validation and test segmentation example generators.
+
+        This attribute provides three generators as a
+        :py:class:`TrainValTestProperty`, with each generator yielding
+        a :py:class:`SegExample` for each image-label pair in the training,
+        validation or test image data collections.
+        """
+        # Ensure that the saved generators are updated if more data is
+        # added...
+        if getattr(self, '_ncells', None) != self._shared_data.data.ncells:
+            self._seg_examples = None
+            self._ncells = self._shared_data.data.ncells
+        # ...or if data generation parameters change:
+        old_gen_params = getattr(self, '_current_gen_params', None)
+        new_gen_params = tuple(getattr(self.training_parameters, p) for p in
+                               ('in_memory', 'input_norm_dw', 'batch_size',
+                                'balanced_sampling', 'use_sample_weights',
+                                'xy_out', 'target_pixel_size', 'substacks'))
+        if old_gen_params != new_gen_params:
+            self._seg_examples = None
+            self._current_gen_params = new_gen_params
+
+        cnn = self._cnn_trainer.opt_cnn
+        gen = self.gen
+        if self.training_parameters.in_memory:
+            # Generate the examples once and hand out fresh iterators over
+            # the saved lists on each access
+            if getattr(self, '_seg_examples', None) is None:
+                self._seg_examples = TrainValTestProperty(
+                    list(_example_generator(cnn, gen.train)),
+                    list(_example_generator(cnn, gen.val)),
+                    list(_example_generator(cnn, gen.test)))
+            return TrainValTestProperty(
+                (e for e in self._seg_examples.train),
+                (e for e in self._seg_examples.val),
+                (e for e in self._seg_examples.test))
+        else:
+            self._seg_examples = None
+            return TrainValTestProperty(
+                _example_generator(cnn, gen.train),
+                _example_generator(cnn, gen.val),
+                _example_generator(cnn, gen.test))
+
+    @property
+    def flattener(self):
+        """Flattener of the CNN trainer."""
+        return self._cnn_trainer.flattener
+
+    @property
+    def seg_param_stats(self):
+        """Table of scores obtained by :py:meth:`fit_seg_params`.
+
+        Loaded from the segmentation stats file in :py:attr:`save_dir` if
+        not already in memory.
+
+        Raises:
+            BadProcess: if the stats file does not exist.
+        """
+        if getattr(self, '_seg_param_stats', None) is None:
+            seg_stats_file = self.training_parameters.segmentation_stats_file
+            stats_file = self.save_dir / seg_stats_file
+            if not stats_file.is_file():
+                raise BadProcess('"fit_seg_params" has not been run yet')
+            self._seg_param_stats = pd.read_csv(stats_file, index_col=0)
+        return self._seg_param_stats
+
+    def fit_seg_params(self, n_jobs=None, scoring='F0_5', subsample_frac=1.,
+                       fit_on_split='val'):
+        """Find optimal segmentation hyperparameters.
+
+        Args:
+            n_jobs (int): Number of parallel processes to run. Defaults to
+                the ``n_jobs`` of the training parameters.
+            scoring (str): Scoring metric to be used to assess segmentation
+                performance. The name of any of the attributes in
+                :py:class:`Score` may be specified.
+            subsample_frac (float or sequence of float): Fraction of the
+                examples of each split to fit on. Examples are taken at
+                regular steps of ``floor(1 / subsample_frac)``.
+            fit_on_split (str or sequence of str): Name(s) of the split(s)
+                in :py:attr:`examples` to fit on. If only one of
+                ``subsample_frac`` and ``fit_on_split`` has a single value,
+                it is repeated to match the length of the other.
+        """
+
+        if n_jobs is None:
+            n_jobs = self.training_parameters.n_jobs
+
+        # Initialise the default segmenter to determine the number of groups
+        # and obtain broadcast base parameters:
+        segmenter = MorphSegGrouped(
+            self.flattener,
+            params=self.segment_parameters)
+        ngroups = len(segmenter.groups)
+        base_params = segmenter.params
+
+        # Generate parameter search grid according to training parameters
+        param_coords = self.segment_parameter_coords
+        param_grid = list(product(*param_coords.values()))
+        par_names = list(param_coords.keys())
+
+        # Accept any scalar number (including int and NumPy scalars) as a
+        # single subsampling fraction
+        if np.isscalar(subsample_frac):
+            subsample_frac = (subsample_frac,)
+        if isinstance(fit_on_split, str):
+            fit_on_split = (fit_on_split,)
+        if len(subsample_frac) == 1 and len(fit_on_split) > 1:
+            subsample_frac = subsample_frac * len(fit_on_split)
+        if len(fit_on_split) == 1 and len(subsample_frac) > 1:
+            fit_on_split = fit_on_split * len(subsample_frac)
+
+        examples = []
+        for ssf, split in zip(subsample_frac, fit_on_split):
+            step = max(int(np.floor(1. / ssf)), 1)
+            # NB: the following essentially assumes that examples are
+            # presented in random order (see _example_generator)
+            examples.extend(islice(
+                getattr(self.examples, split), None, None, step))
+
+        rows = []
+        for gind in reversed(range(ngroups)):
+            for pars in tqdm(param_grid):
+                rows.append(_seg_filter_optim(gind, pars, par_names,
+                                              self.flattener, examples,
+                                              base_params=base_params,
+                                              scoring=scoring, n_jobs=n_jobs))
+
+        # Flatten each result into a single row, expanding list-valued
+        # broadcastable parameters into one column per group
+        rows_expanded = []
+        for row in rows:
+            row_details = chain(
+                [('group', row['group']), ('score', row['score'])],
+                row['scores']._asdict().items(),
+                row['base'].items(), row['filter'].items())
+            row_expanded = []
+            for k, v in row_details:
+                if k in BROADCASTABLE_PARAMS and type(v) is list:
+                    kvpairs = [('_'.join((k, str(g))), gv)
+                               for g, gv in enumerate(v)]
+                else:
+                    kvpairs = [(k, v)]
+                row_expanded.extend(kvpairs)
+            rows_expanded.append(dict(row_expanded))
+
+        # TODO: if None values are combined with integer values, the entire
+        # column gets converted here to float, with the None values as NaN.
+        # This causes errors, for example, with specification of
+        # edge_sub_dilations. There is currently no obvious solution to this.
+        self._seg_param_stats = pd.DataFrame(rows_expanded)
+        stats_file = (self.save_dir /
+                      self.training_parameters.segmentation_stats_file)
+        self._seg_param_stats.to_csv(stats_file)
+
+        self.refit_filter_seg_params(scoring=scoring, n_jobs=n_jobs)
+
+    def refit_filter_seg_params(self,
+                                lazy=False,
+                                bootstrap=False,
+                                scoring='F0_5',
+                                n_jobs=None):
+        """Merge the best parameters of each group and refit the filters.
+
+        The best-scoring row of :py:attr:`seg_param_stats` for each group
+        provides that group's values of the broadcastable parameters. The
+        filter parameters are then optimised on the validation examples and
+        the result is saved to :py:attr:`segment_parameters`.
+
+        Args:
+            lazy (bool): passed to ``SegFilterParamOptim.fit_filter_params``.
+            bootstrap (bool): passed to
+                ``SegFilterParamOptim.fit_filter_params``.
+            scoring (str): Scoring metric to be optimised.
+            n_jobs (int): Number of parallel processes to run. Defaults to
+                the ``n_jobs`` of the training parameters.
+        """
+
+        if n_jobs is None:
+            n_jobs = self.training_parameters.n_jobs
+
+        # Initialise the default segmenter to obtain broadcast base
+        # parameters:
+        segmenter = MorphSegGrouped(
+            self.flattener,
+            params=self.segment_parameters)
+        base_params = segmenter.params
+
+        # Merge the best parameters from each group into a single parameter set
+        par_names = list(self.segment_parameter_coords.keys())
+        broadcast_par_names = [k for k in par_names
+                               if k in BROADCASTABLE_PARAMS]
+        merged_params = {k: getattr(base_params, k) for k in par_names}
+        for k in broadcast_par_names:
+            merged_params[k] = merged_params[k].copy()
+        stats = self.seg_param_stats
+        # Use the group label itself (not its position in the groupby
+        # result) to select both the column and the list entry
+        for g, r in stats.groupby('group').score.idxmax().items():
+            g = int(g)
+            for k in broadcast_par_names:
+                merged_params[k][g] = stats.loc[r, k + '_' + str(g)]
+        merged_params = base_params._replace(**merged_params)
+
+        sfpo = SegFilterParamOptim(self.flattener,
+                                   base_params=merged_params,
+                                   scoring=scoring)
+        with np.errstate(all='ignore'):
+            sfpo.generate_stat_table(list(self.examples.val), n_jobs=n_jobs)
+            sfpo.fit_filter_params(lazy=lazy, bootstrap=bootstrap,
+                                   n_jobs=n_jobs)
+
+        self.segment_parameters = merged_params._replace(**sfpo.opt_params)
+
+    def validate_seg_params(
+            self,
+            iou_thresh=0.7,
+            save=True,
+            refine_outlines=True):
+        """Summarise segmentation performance on each data split.
+
+        For each example the mean and minimum of the best IoUs and the
+        average precision are calculated. The means over examples are
+        printed and histograms of the per-example values are plotted.
+
+        Args:
+            iou_thresh (float): IoU threshold passed to ``calc_AP``.
+            save (bool): if ``True``, save the figure as
+                ``seg_validation_plot.png`` in :py:attr:`save_dir` and then
+                close it.
+            refine_outlines (bool): passed to the segmenter.
+        """
+        segmenter = MorphSegGrouped(self.flattener,
+                                    params=self.segment_parameters,
+                                    return_masks=True)
+        edge_inds = [
+            i for i, t in enumerate(self.flattener.targets)
+            if t.prop == 'edge'
+        ]
+        stats = {}
+        dfs = {}
+        for k, seg_exs in zip(self.examples._fields, self.examples):
+            stats[k] = []
+            for seg_ex in seg_exs:
+                seg = segmenter.segment(seg_ex.pred,
+                                        refine_outlines=refine_outlines)
+                # Score each outline by the mean over its edge pixels of
+                # the maximum across the edge prediction layers
+                edge_scores = np.array([
+                    seg_ex.pred[edge_inds, ...].max(axis=0)[s].mean()
+                    for s in seg.edges
+                    ])
+                IoUs = calc_IoUs(seg_ex.target, seg.masks)
+                bIoU, _ = best_IoU(IoUs)
+                stats[k].append((edge_scores, IoUs, np.mean(bIoU),
+                    np.min(bIoU, initial=1),
+                    calc_AP(IoUs,
+                        probs=edge_scores,
+                        iou_thresh=iou_thresh)[0]))
+            dfs[k] = pd.DataFrame([s[2:] for s in stats[k]],
+                    columns=['IoU_mean', 'IoU_min', 'AP'])
+
+        print({k: df.mean() for k, df in dfs.items()})
+
+        nrows = len(dfs)
+        ncols = dfs['val'].shape[1]
+        fig, axs = plt.subplots(nrows=nrows,
+                ncols=ncols,
+                figsize=(ncols * 4, nrows * 4))
+        for axrow, (k, df) in zip(axs, dfs.items()):
+            for ax, col in zip(axrow, df.columns):
+                ax.hist(df.loc[:, col], bins=26, range=(0, 1))
+                ax.set(xlabel=col, title=k)
+        if save:
+            fig.savefig(self.save_dir / 'seg_validation_plot.png')
+            plt.close(fig)
+
diff --git a/python/baby/training/smoothing_model_trainer.py b/python/baby/training/smoothing_model_trainer.py
index 00388d2e8e7d7e2c4249ba36abdd646e19d2026c..9960a061528f7296dfab6e458e6badb129f7f8c8 100644
--- a/python/baby/training/smoothing_model_trainer.py
+++ b/python/baby/training/smoothing_model_trainer.py
@@ -1,12 +1,14 @@
 # If you publish results that make use of this software or the Birth Annotator
 # for Budding Yeast algorithm, please cite:
-# Julian M J Pietsch, Alán Muñoz, Diane Adjavon, Ivan B N Clark, Peter S
-# Swain, 2021, Birth Annotator for Budding Yeast (in preparation).
+# Pietsch, J.M.J., Muñoz, A.F., Adjavon, D.-Y.A., Farquhar, I., Clark, I.B.N.,
+# and Swain, P.S. (2023). Determining growth rates from bright-field images of
+# budding cells through identifying overlaps. eLife. 12:e79812.
+# https://doi.org/10.7554/eLife.79812
 # 
 # 
 # The MIT License (MIT)
 # 
-# Copyright (c) Julian Pietsch, Alán Muñoz and Diane Adjavon 2021
+# Copyright (c) Julian Pietsch, Alán Muñoz and Diane Adjavon 2023
 # 
 # Permission is hereby granted, free of charge, to any person obtaining a copy
 # of this software and associated documentation files (the "Software"), to
@@ -31,8 +33,10 @@ import pandas as pd
 from baby.augmentation import _filled_canny, _apply_crop, SmoothingSigmaModel
 from baby.errors import BadProcess, BadType
 from baby.preprocessing import dwsquareconn
-from baby.segmentation import binary_edge, squareconn, morph_radial_thresh_fit, \
-    draw_radial, mask_iou
+from baby.generator import augmented_generator
+from baby.segmentation import (binary_edge, squareconn, mask_to_knots,
+                               draw_radial, mask_iou)
+from baby.morph_thresh_seg import SegmentationParameters
 from baby.utils import find_file
 from numpy.polynomial import Polynomial
 from scipy.ndimage import binary_fill_holes
@@ -41,10 +45,10 @@ from skimage import filters, transform
 from skimage.measure import regionprops
 from tqdm import trange
 
-from .utils import augmented_generator, TrainValProperty
+from .utils import TrainValProperty
 
 
-def _generate_smoothing_sigma_stats(gen):
+def _generate_smoothing_sigma_stats(gen, params):
     sigmas = np.arange(0.4, 5.0, 0.20)
     rotations = np.arange(7, 45, 7)
     scaling = np.linspace(0.5, 1.5, 6)
@@ -52,7 +56,7 @@ def _generate_smoothing_sigma_stats(gen):
     square_edge = lambda m: binary_edge(m, squareconn)
     smoothing_stats = []
     for t in trange(len(gen.paths)):
-        _, (segs, _) = gen.get_by_index(t)
+        segs = gen.get_by_index(t)
 
         if segs.shape[2] == 1 and segs.sum() == 0:
             continue
@@ -64,13 +68,40 @@ def _generate_smoothing_sigma_stats(gen):
             sfill = segs_fill[..., c]
             sedge = segs_edge[..., c]
             nedge = segs[..., c].sum()
+            area = sfill.sum()
+
+            # Crop non-square images to the mask bounding box (padded by 25%)
+            nx, ny = sfill.shape[:2]
+            if nx != ny:
+                rprops = regionprops(sfill.astype('int'))[0]
+                min_rr, min_cc, max_rr, max_cc = rprops.bbox
+                pad = np.round(0.25 * max(max_rr - min_rr, max_cc - min_cc)).astype(int)
+                min_rr = max(min_rr - pad, 0)
+                min_cc = max(min_cc - pad, 0)
+                max_rr = min(max_rr + pad, nx)
+                max_cc = min(max_cc + pad, ny)
+                sfill = sfill[min_rr:max_rr, min_cc:max_cc]
+                sedge = sedge[min_rr:max_rr, min_cc:max_cc]
+                nx, ny = sfill.shape[:2]
+
+            # Pad to a square image so features do not rotate out of image limits
+            xpad = max(nx, ny) - nx
+            ypad = max(nx, ny) - ny
+            xlpad = xpad // 2
+            ylpad = ypad // 2
+            sfill = np.pad(sfill, ((xlpad, xpad - xlpad),
+                (ylpad, ypad - ylpad)), mode='constant')
+            sedge = np.pad(sedge, ((xlpad, xpad - xlpad),
+                (ylpad, ypad - ylpad)), mode='constant')
 
-            # fit radial spline to generate accurate reference edges for
-            # resize transformation
             rprops = regionprops(sfill.astype('int'))[0]
             centre = np.array(rprops.centroid)
-            radii, angles = morph_radial_thresh_fit(sedge, sfill, rprops)
-            genedge = draw_radial(radii, angles, centre, sedge.shape)
+
+            # fit spline to generate accurate reference edges for resize
+            # transformation. We do not have a predicted edge in this case, so
+            # supply the ground truth edge for the fitting routine.
+            (_, radii, angles), genedge = mask_to_knots(
+                sfill, p_edge=sedge, **params._asdict())
             genfill = binary_fill_holes(genedge, squareconn)
 
             # Limit the number of rotation and scaling operations by
@@ -82,6 +113,7 @@ def _generate_smoothing_sigma_stats(gen):
                 sblur = filters.gaussian(sfill, s)
                 genblur = filters.gaussian(genfill, s)
 
+                # Identity transformation (on raw edge)
                 spf = _filled_canny(sblur)
                 smoothing_stats += [{
                     'ind': t,
@@ -90,18 +122,20 @@ def _generate_smoothing_sigma_stats(gen):
                     'rotation': 0,
                     'scaling': 1,
                     'nedge': nedge,
+                    'area': area,
                     'iou': mask_iou(spf, sfill),
                     'edge_iou': mask_iou(square_edge(spf), sedge)
-                }]
+                    }]
 
+                # Rotation transformation (on raw edge)
                 sr = transform.rotate(sblur,
-                                      angle=r,
-                                      mode='reflect',
-                                      resize=True)
+                        angle=r,
+                        mode='reflect',
+                        resize=True)
                 sr = transform.rotate(sr,
-                                      angle=-r,
-                                      mode='reflect',
-                                      resize=False)
+                        angle=-r,
+                        mode='reflect',
+                        resize=False)
                 srf = _filled_canny(_apply_crop(sr, spf.shape))
                 smoothing_stats += [{
                     'ind': t,
@@ -110,15 +144,17 @@ def _generate_smoothing_sigma_stats(gen):
                     'rotation': r,
                     'scaling': 1,
                     'nedge': nedge,
+                    'area': area,
                     'iou': mask_iou(srf, sfill),
                     'edge_iou': mask_iou(square_edge(srf), sedge)
-                }]
+                    }]
 
+                # Scaling transformation (on generated edge)
                 insize = np.array(spf.shape)
                 outsize = np.round(insize * z).astype('int')
                 centre_sc = outsize / 2 + z * (centre - insize / 2)
                 genedge_sc = draw_radial(z * radii, angles, centre_sc,
-                                         outsize)
+                        outsize, cartesian_spline=params.cartesian_spline)
                 genfill_sc = binary_fill_holes(genedge_sc, squareconn)
                 sd = transform.resize(genblur, outsize, anti_aliasing=False)
                 sdf = _filled_canny(sd)
@@ -129,25 +165,52 @@ def _generate_smoothing_sigma_stats(gen):
                     'rotation': 0,
                     'scaling': z,
                     'nedge': nedge,
+                    'area': area,
                     'iou': mask_iou(sdf, genfill_sc),
                     'edge_iou': mask_iou(square_edge(sdf), genedge_sc)
-                }]
+                    }]
 
     return pd.DataFrame(smoothing_stats)
 
 class SmoothingModelTrainer:
-    def __init__(self, save_dir, stats_file, model_file):
-        self.save_dir = save_dir
-        self.stats_file = save_dir / stats_file
-        self.model_file = save_dir / model_file
+    """Trains the smoothing model for augmentations on binary masks.
+
+    Args:
+        shared_params (utils.SharedParameterContainer): training and
+            segmentation parameters as provided by
+            :py:class:`utils.SharedParameterContainer`.
+        shared_data (utils.SharedDataContainer): Access to training data.
+    """
+    def __init__(self, shared_params, shared_data):
+        self._shared_params = shared_params
+        self._shared_data = shared_data
         self._model = None
         self._stats = None
 
-    def generate_smoothing_sigma_stats(self, train_gen, val_gen):
-        with augmented_generator(train_gen, lambda x, y: (x, y)) as gen:
-            sss_train = _generate_smoothing_sigma_stats(gen)
-        with augmented_generator(val_gen, lambda x, y: (x, y)) as gen:
-            sss_val = _generate_smoothing_sigma_stats(gen)
+    @property
+    def save_dir(self):
+        return self._shared_params.save_dir
+
+    @property
+    def stats_file(self):
+        return (self.save_dir /
+                self._shared_params.parameters.smoothing_sigma_stats_file)
+
+    @property
+    def model_file(self):
+        return (self.save_dir /
+                self._shared_params.parameters.smoothing_sigma_model_file)
+
+    @property
+    def segment_params(self):
+        return self._shared_params.segmentation_parameters
+
+    def generate_smoothing_sigma_stats(self, aug=lambda x, y: y[0]):
+        train_gen, val_gen, _ = self._shared_data.gen
+        with augmented_generator(train_gen, aug) as gen:
+            sss_train = _generate_smoothing_sigma_stats(gen, self.segment_params)
+        with augmented_generator(val_gen, aug) as gen:
+            sss_val = _generate_smoothing_sigma_stats(gen, self.segment_params)
         sss_train['validation'] = False
         sss_val['validation'] = True
         sss = pd.concat((sss_train, sss_val))
@@ -158,10 +221,10 @@ class SmoothingModelTrainer:
         if self._stats is None:
             if not self.stats_file.exists():
                 raise BadProcess(
-                    'smoothing sigma stats have not been generated')
+                        'smoothing sigma stats have not been generated')
             self._stats = pd.read_csv(self.stats_file)
         return TrainValProperty(self._stats[~self._stats['validation']],
-                                self._stats[self._stats['validation']])
+                self._stats[self._stats['validation']])
 
     @property
     def model(self):
@@ -172,7 +235,7 @@ class SmoothingModelTrainer:
                 self._model = smoothing_sigma_model
             else:
                 raise BadProcess(
-                    'The "smoothing_sigma_model" has not been assigned yet')
+                        'The "smoothing_sigma_model" has not been assigned yet')
         return self._model
 
     @model.setter
@@ -183,8 +246,8 @@ class SmoothingModelTrainer:
             ssm.load(ssm_file)
         if not isinstance(ssm, SmoothingSigmaModel):
             raise BadType(
-                '"smoothing_sigma_model" must be of type "baby.augmentation.SmoothingSigmaModel"'
-            )
+                    '"smoothing_sigma_model" must be of type "baby.augmentation.SmoothingSigmaModel"'
+                    )
         ssm.save(self.model_file)
         self._model = ssm
 
@@ -194,10 +257,10 @@ class SmoothingModelTrainer:
         stats = self.stats.train
         stats = stats.groupby(idcols).apply(group_best_iou)
         filts = {
-            'identity': (stats.scaling == 1) & (stats.rotation == 0),
-            'scaling': stats.scaling != 1,
-            'rotation': stats.rotation != 0
-        }
+                'identity': (stats.scaling == 1) & (stats.rotation == 0),
+                'scaling': stats.scaling != 1,
+                'rotation': stats.rotation != 0
+                }
         return stats, filts
 
     def fit(self, filt='identity'):
@@ -210,8 +273,8 @@ class SmoothingModelTrainer:
         b = 10  # initial guess for offset term in final model
         # Fit s = c + m * log(n - b); want n = b + exp((s - c)/m)
         pinv = Polynomial.fit(np.log(np.clip(stats.nedge - b, 1, None)),
-                              stats.sigma,
-                              deg=1)
+                stats.sigma,
+                deg=1)
         c = pinv(0)
         m = pinv(1) - c
 
@@ -228,34 +291,34 @@ class SmoothingModelTrainer:
         params = (self.model._a, self._model._b, self.model._c)
 
         fig, axs = plt.subplots(2,
-                                len(filts),
-                                figsize=(12, 12 * 2 / len(filts)))
+                len(filts),
+                figsize=(12, 12 * 2 / len(filts)))
         sigma_max = stats.sigma.max()
         nedge_max = stats.nedge.max()
         sigma = np.linspace(0, sigma_max, 100)
         for ax, (k, f) in zip(axs[0], filts.items()):
             ax.scatter(stats[f].sigma,
-                       stats[f].nedge,
-                       16,
-                       alpha=0.05,
-                       edgecolors='none')
+                    stats[f].nedge,
+                    16,
+                    alpha=0.05,
+                    edgecolors='none')
             ax.plot(sigma, model(sigma, *params), 'r')
             ax.set(title=k.title(),
-                   xlabel='sigma',
-                   ylabel='nedge',
-                   ylim=[0, nedge_max])
+                    xlabel='sigma',
+                    ylabel='nedge',
+                    ylim=[0, nedge_max])
 
         nedge = np.linspace(1, nedge_max, 100)
         for ax, (k, f) in zip(axs[1], filts.items()):
             ax.scatter(stats[f].nedge,
-                       stats[f].sigma,
-                       16,
-                       alpha=0.05,
-                       edgecolors='none')
+                    stats[f].sigma,
+                    16,
+                    alpha=0.05,
+                    edgecolors='none')
             ax.plot(nedge, [self.model(n) for n in nedge], 'r')
             ax.set(title=k.title(),
-                   xlabel='nedge',
-                   ylabel='sigma',
-                   ylim=[0, sigma_max])
+                    xlabel='nedge',
+                    ylabel='sigma',
+                    ylim=[0, sigma_max])
         fig.tight_layout()
         fig.savefig(self.save_dir / 'fitted_smoothing_sigma_model.png')
diff --git a/python/baby/training/training.py b/python/baby/training/training.py
index ad9f87c502a3ff84075323eb8fb3438d8c744b54..1c1738e341a7bc52504a4c2a4f95ad9c6fa777af 100644
--- a/python/baby/training/training.py
+++ b/python/baby/training/training.py
@@ -1,12 +1,14 @@
 # If you publish results that make use of this software or the Birth Annotator
 # for Budding Yeast algorithm, please cite:
-# Julian M J Pietsch, Alán Muñoz, Diane Adjavon, Ivan B N Clark, Peter S
-# Swain, 2021, Birth Annotator for Budding Yeast (in preparation).
+# Pietsch, J.M.J., Muñoz, A.F., Adjavon, D.-Y.A., Farquhar, I., Clark, I.B.N.,
+# and Swain, P.S. (2023). Determining growth rates from bright-field images of
+# budding cells through identifying overlaps. eLife. 12:e79812.
+# https://doi.org/10.7554/eLife.79812
 # 
 # 
 # The MIT License (MIT)
 # 
-# Copyright (c) Julian Pietsch, Alán Muñoz and Diane Adjavon 2021
+# Copyright (c) Julian Pietsch, Alán Muñoz and Diane Adjavon 2023
 # 
 # Permission is hereby granted, free of charge, to any person obtaining a copy
 # of this software and associated documentation files (the "Software"), to
@@ -28,7 +30,6 @@
 import inspect
 import json
 import pickle
-import shutil
 import warnings
 from itertools import product, chain
 from pathlib import Path
@@ -37,98 +38,50 @@ from typing import NamedTuple
 import numpy as np
 import pandas as pd
 import tensorflow as tf
-from baby.brain import default_params
-from baby.morph_thresh_seg import MorphSegGrouped
-from baby.performance import calc_IoUs, best_IoU, calc_AP
-from baby.seg_trainer import SegFilterParamOptim, _sub_params
-from baby.tracker.training import CellTrainer, BudTrainer
-from baby.training import CNNTrainer
-
 from matplotlib import pyplot as plt
-from tqdm import tqdm
+
+from baby.tracker.training import CellTrainer
+from baby.training import CNNTrainer
+from baby.errors import BadParam, BadFile, BadType, BadProcess
 
 if tf.__version__.startswith('2'):
-    tf.compat.v1.disable_eager_execution()
+    # tf.compat.v1.disable_eager_execution()
     from .hyper_parameter_trainer import HyperParameterTrainer
 else:
     from .v1_hyper_parameter_trainer import HyperParamV1 as \
-        HyperParameterTrainer
+            HyperParameterTrainer
 
-from .utils import BabyTrainerParameters, TrainValTestProperty, \
-    augmented_generator
+from .utils import (BabyTrainerParameters, TrainValTestProperty,
+                    standard_augmenter, SharedParameterContainer,
+                    SharedDataContainer)
 from .smoothing_model_trainer import SmoothingModelTrainer
 from .flattener_trainer import FlattenerTrainer
-
-from baby.utils import find_file, as_python_object, jsonify, batch_iterator, \
-    split_batch_pred
-from baby.errors import BadParam, BadFile, BadType, BadProcess
-from baby.io import TrainValTestPairs
-from baby.preprocessing import (robust_norm, seg_norm)
-from baby.augmentation import (Augmenter, ScalingAugmenter)
-from baby.generator import ImageLabel
-from baby.visualise import colour_segstack
-
-LOG_DIR = 'logs'
-
-
-# TODO: ADD TO UTILS in training
-class SegExample(NamedTuple):
-    """CNN output paired with target segmented outlines and info
-
-    Used for optimising segmentation hyperparameters and for training bud
-    assigment models
-    """
-    pred: np.ndarray
-    target: np.ndarray
-    info: dict
-    img: np.ndarray
-
-
-# interior_threshold: threshold on predicted interior
-# nclosing: number of closing operations on threshold mask
-# nopening: number of opening operations on threshold mask
-#
-# TODO Create SegTrainer
-# TODO Add default parameters to SegTrainer
-#   Structure taht the parameters should actually be (shape/size)
-#   Thresholds modified to the correct parameters
-base_seg_params = {
-    'interior_threshold': [0.5, 0.5, 0.5],
-    'nclosing': [0, 0, 0],
-    'nopening': [0, 0, 0],
-    'connectivity': [2, 2, 2],
-    'edge_sub_dilations': [0, 0, 0]
-}
-
-# Todo add default parameters to SegTrainer
-#       Search space !
-seg_param_coords = {
-    'nclosing': [0, 1, 2],
-    'nopening': [0, 1, 2],
-    'interior_threshold': np.arange(0.3, 1.0, 0.05).tolist(),
-    'connectivity': [1, 2],
-    'edge_sub_dilations': [0, 1, 2]
-}
+from .segmentation_trainer import SegmentationTrainer, SegExample
+from .bud_trainer import BudTrainer
 
 
 class BabyTrainer(object):
-    """Manager to set up and train BABY models
-    
-    :param save_dir: directory in which to save parameters and logs (and
-        from which to auto-load parameters and logs)
-    :param train_val_images: either a dict with keys 'training' and
-        'validation' and values specifying lists of file name pairs, or the
-        name of a json file containing such a dict. The file name pairs
-        should correspond to image-label pairs suitable for input to
-        `baby.generator.ImageLabel`.
-    :param flattener: either a `baby.preprocessing.SegmentationFlattening`
-        object, or the name of a json file that is a saved
-        `SegmentationFlattening` object.
+    """Manager to set up and train BABY models.
+
+    Args:
+        save_dir (str or Path): Directory in which to save parameters and logs
+            (and from which to auto-load parameters and logs). Must be an
+            absolute path, or specified relative to ``base_dir``.
+        base_dir (str or Path or None): Base directory within which all
+            relevant image files can be found. By default (i.e., specifying
+            ``None``), uses the current working directory. References to the
+            image files will be saved relative to this directory. If the base
+            directory is copied/moved but the structure within that directory
+            is left intact, then references to the image files will be
+            correctly maintained.
+        params (None or BabyTrainerParameters or str or Path): Optionally
+            specify initial training parameters as a
+            :py:class:`BabyTrainerParameters` instance or the path to a saved
+            :py:class:`BabyTrainerParameters` instance.
     """
 
     def __init__(self, save_dir, base_dir=None, params=None, max_cnns=3):
 
-        # Register the save dir
         if base_dir is not None:
             base_dir = Path(base_dir)
             if not base_dir.is_dir():
@@ -136,25 +89,22 @@ class BabyTrainer(object):
         else:
             base_dir = Path.cwd()
         self.base_dir = base_dir
-        # All other directories are relative to the base dir
+
+        # The save directory may be specified relative to the base directory
         save_dir = Path(save_dir)
         if not save_dir.is_absolute():
             save_dir = base_dir / save_dir
         if not save_dir.is_dir():
             raise BadParam('"save_dir" must be a valid directory')
-        self.save_dir = save_dir
-        self._parameters_file = 'parameters.json'
-
-        # Register parameters if specified
-        if isinstance(params, BabyTrainerParameters):
-            self.parameters = params
-        elif isinstance(params, str):
-            filename = find_file(params, save_dir, 'params')
-            savename = save_dir / self._parameters_file
-            if not savename.is_file():
-                shutil.copy(filename, savename)
+
+        # Initialise shared parameters and data
+        self._shared_params = SharedParameterContainer(
+            save_dir, params=params)
+        self._shared_data = SharedDataContainer(
+            self._shared_params, base_dir=base_dir)
 
         self._max_cnns = max_cnns
+
         # Trainers
         self._smoothing_sigma_trainer = None
         self._flattener_trainer = None
@@ -162,653 +112,192 @@ class BabyTrainer(object):
         self._cnn_trainer = None
         self._track_trainer = None
         self._bud_trainer = None
+        self._segmentation_trainer = None
+
+    @property
+    def parameters(self):
+        return self._shared_params.parameters
+
+    @parameters.setter
+    def parameters(self, params):
+        self._shared_params.parameters = params
+
+    @property
+    def segmentation_parameters(self):
+        return self._shared_params.segmentation_parameters
+
+    @segmentation_parameters.setter
+    def segmentation_parameters(self, params):
+        self._shared_params.segmentation_parameters = params
+
+    @property
+    def save_dir(self):
+        return self._shared_params.save_dir
+
+    @property
+    def in_memory(self):
+        return self.parameters.in_memory
+
+    @in_memory.setter
+    def in_memory(self, x):
+        self.parameters = dict(in_memory=x)
+
+    @property
+    def data(self):
+        return self._shared_data.data
+
+    @data.setter
+    def data(self, x):
+        self._shared_data.data = x
+
+    @property
+    def tracker_data(self):
+        return self._shared_data.tracker_data
+
+    @tracker_data.setter
+    def tracker_data(self, x):
+        self._shared_data.tracker_data = x
+
+    @property
+    def gen(self):
+        return self.flattener_trainer.default_gen
+
+    def gen_with_aug(self, aug):
+        return self._shared_data.gen_with_aug(aug)
 
     @property
     def smoothing_sigma_trainer(self):
         if self._smoothing_sigma_trainer is None:
             self._smoothing_sigma_trainer = SmoothingModelTrainer(
-                save_dir=self.save_dir,
-                stats_file=self.parameters.smoothing_sigma_stats_file,
-                model_file=self.parameters.smoothing_sigma_model_file)
+                self._shared_params, self._shared_data)
         return self._smoothing_sigma_trainer
 
     @property
     def flattener_trainer(self):
         if self._flattener_trainer is None:
             self._flattener_trainer = FlattenerTrainer(
-                save_dir=self.save_dir,
-                stats_file=self.parameters.flattener_stats_file,
-                flattener_file=self.parameters.flattener_file)
+                self._shared_params, self._shared_data,
+                self.smoothing_sigma_trainer)
         return self._flattener_trainer
 
     @property
     def hyperparameter_trainer(self):
         if self._hyperparameter_trainer is None:
             self._hyperparameter_trainer = HyperParameterTrainer(
-                save_dir=self.save_dir,
-                cnn_set=self.parameters.cnn_set,
-                gen=self.gen,
-                aug=self.aug,
-                outputs=self.flattener_trainer.flattener.names(),
-                tuner_params=None
-                # Todo: tuner file if it exists in parameters
-            )
+                    save_dir=self.save_dir,
+                    cnn_set=self.parameters.cnn_set,
+                    gen=self.gen,
+                    aug=self.aug,
+                    outputs=self.flattener_trainer.flattener.names(),
+                    tuner_params=None
+                    # Todo: tuner file if it exists in parameters
+                    )
         return self._hyperparameter_trainer
 
     @property
     def cnn_trainer(self):
         if self._cnn_trainer is None:
             self._cnn_trainer = CNNTrainer(
-                save_dir=self.save_dir,
-                cnn_set=self.parameters.cnn_set,
-                gen=self.gen,
-                aug=self.aug,
-                flattener=self.flattener_trainer.flattener,
-                cnn_fn=self.parameters.cnn_fn,
-                max_cnns=self._max_cnns  # Todo: private access OK?
-            )
+                    self._shared_params,
+                    self.flattener_trainer,
+                    max_cnns=self._max_cnns  # Todo: private access OK?
+                    )
         return self._cnn_trainer
 
+    @property
+    def segmentation_trainer(self):
+        if self._segmentation_trainer is None:
+            self._segmentation_trainer = SegmentationTrainer(
+                self._shared_params, self._shared_data, 
+                self.smoothing_sigma_trainer, self.cnn_trainer)
+        return self._segmentation_trainer
+
     @property
     def track_trainer(self):
         if self._track_trainer is None:
             self._track_trainer = CellTrainer(self.tracker_data._metadata,
-                                              self.tracker_data)
+                    self.tracker_data)
         return self._track_trainer
 
     @track_trainer.setter
-    def track_trainer(self, all_feats2use=None):
+    def track_trainer(self, all_feats2use):
         self._track_trainer = CellTrainer(self.tracker_data._metadata,
-                                          data=self.tracker_data,
-                                          all_feats2use=all_feats2use)
+                data=self.tracker_data,
+                all_feats2use=all_feats2use)
 
     @property
     def bud_trainer(self):
-        props_file = self.save_dir / self.parameters.mother_bud_props_file
-        if not hasattr(self, '_bud_trainer') or not self._bud_trainer:
+        if self._bud_trainer is None:
             self._bud_trainer = BudTrainer(
-                props_file=props_file,
-                px_size=self.parameters.target_pixel_size)
+                self._shared_params, self._shared_data,
+                self.segmentation_trainer)
         return self._bud_trainer
 
-    @property
-    def parameters(self):
-        if not hasattr(self, '_parameters') or not self._parameters:
-            param_file = self.save_dir / self._parameters_file
-            if param_file.is_file():
-                with open(param_file, 'rt') as f:
-                    params = json.load(f, object_hook=as_python_object)
-                if not isinstance(params, BabyTrainerParameters):
-                    raise BadFile('Parameters file has been corrupted')
-                self._parameters = params
-            else:
-                self.parameters = BabyTrainerParameters()
-        return self._parameters
-
-    @parameters.setter
-    def parameters(self, params):
-        if isinstance(params, dict):
-            p = self.parameters._asdict()
-            p.update(params)
-            params = BabyTrainerParameters(**p)
-        elif not isinstance(params, BabyTrainerParameters):
-            params = BabyTrainerParameters(*params)
-        self._parameters = params
-        with open(self.save_dir / self._parameters_file, 'wt') as f:
-            json.dump(jsonify(self._parameters), f)
-
-    @property
-    def in_memory(self):
-        return self.parameters.in_memory
-
-    @in_memory.setter
-    def in_memory(self, x):
-        self.parameters = dict(in_memory=x)
-
-    def _check_for_data_update(self):
-        if getattr(self, '_ncells', None) != self._impairs.ncells:
-            # Reset generators
-            self._gen_train = None
-            self._gen_val = None
-            self._gen_test = None
-            # Trigger save of the data
-            datafile = self.save_dir / self.parameters.train_val_test_pairs_file
-            self._impairs.save(datafile, self.base_dir)
-            self._ncells = self._impairs.ncells
-
-            # And for thet tracker datasets too
-    def _tracker_check_for_data_update(self):
-        if getattr(self, '_tracker_ncells',
-                   None) != self._tracker_impairs.ncells:
-            datafile = self.save_dir / self.parameters.tracker_tvt_pairs_file
-            self._tracker_impairs.save(datafile, self.base_dir)
-            self._tracker_ncells = self._tracker_impairs.ncells
-
-    @property
-    def data(self):
-        if not hasattr(self, '_impairs') or not self._impairs:
-            self._impairs = TrainValTestPairs()
-            pairs_file = self.save_dir / self.parameters.train_val_test_pairs_file
-            if pairs_file.is_file():
-                self._impairs.load(pairs_file, self.base_dir)
-        self._check_for_data_update()
-        return self._impairs
-
-    @data.setter
-    def data(self, train_val_test_pairs):
-        if isinstance(train_val_test_pairs, str):
-            pairs_file = find_file(train_val_test_pairs, self.save_dir,
-                                   'data')
-            train_val_test_pairs = TrainValTestPairs()
-            train_val_test_pairs.load(pairs_file, self.base_dir)
-        if not isinstance(train_val_test_pairs, TrainValTestPairs):
-            raise BadType(
-                '"data" must be of type "baby.io.TrainValTestPairs"')
-        self._impairs = train_val_test_pairs
-        self._check_for_data_update()
-
-    @property
-    def tracker_data(self):
-        if not hasattr(self, '_impairs') or not self._impairs:
-            self._tracker_impairs = TrainValTestPairs()
-            pairs_file = self.save_dir / self.parameters.tracker_tvt_pairs_file
-            if pairs_file.is_file():
-                self._tracker_impairs.load(pairs_file, self.base_dir)
-        self._check_for_data_update()
-        return self._tracker_impairs
-
-    @data.setter
-    def tracker_data(self, train_val_test_pairs):
-        if isinstance(train_val_test_pairs, str):
-            pairs_file = find_file(train_val_test_pairs, self.save_dir,
-                                   'data')
-            train_val_test_pairs = TrainValTestPairs()
-            train_val_test_pairs.load(pairs_file, self.base_dir)
-        if not isinstance(train_tvt_pairs, TrainValTestPairs):
-            raise BadType(
-                '"data" must be of type "baby.io.TrainValTestPairs"')
-        self._tracker_impairs = train_val_test_pairs
-        self._tracker_check_for_data_update()
-
-    @property
-    def gen(self):
-        # NB: generator init ensures all specified images exist
-        # NB: only dummy augmenters are assigned to begin with
-        p = self.parameters
-        if not getattr(self, '_gen_train', None):
-            if len(self.data.training) == 0:
-                raise BadProcess('No training images have been added')
-            # Initialise generator for training images
-            self._gen_train = ImageLabel(self.data.training,
-                                         batch_size=p.batch_size,
-                                         aug=Augmenter(),
-                                         preprocess=(robust_norm, seg_norm),
-                                         in_memory=p.in_memory)
-
-        if not getattr(self, '_gen_val', None):
-            if len(self.data.validation) == 0:
-                raise BadProcess('No validation images have been added')
-            # Initialise generator for validation images
-            self._gen_val = ImageLabel(self.data.validation,
-                                       batch_size=p.batch_size,
-                                       aug=Augmenter(),
-                                       preprocess=(robust_norm, seg_norm),
-                                       in_memory=p.in_memory)
-
-        if not getattr(self, '_gen_test', None):
-            if len(self.data.testing) == 0:
-                raise BadProcess('No testing images have been added')
-            # Initialise generator for testing images
-            self._gen_test = ImageLabel(self.data.testing,
-                                       batch_size=p.batch_size,
-                                       aug=Augmenter(),
-                                       preprocess=(robust_norm, seg_norm),
-                                       in_memory=p.in_memory)
-
-        return TrainValTestProperty(self._gen_train, self._gen_val,
-                                    self._gen_test)
-
-    def plot_gen_sample(self, validation=False):
-        # TODO: Move to flattener?
-        g = self.gen.val if validation else self.gen.train
-        g.aug = self.aug.val if validation else self.aug.train
-        img_batch, lbl_batch = g[0]
-        lbl_batch = np.concatenate(lbl_batch, axis=3)
-
-        f = self.flattener
-        target_names = f.names()
-        edge_inds = np.flatnonzero([t.prop == 'edge' for t in f.targets])
-
-        ncol = len(img_batch)
-        nrow = len(target_names) + 1
-        fig = plt.figure(figsize=(3 * ncol, 3 * nrow))
-        for b, (bf, seg) in enumerate(zip(img_batch, lbl_batch)):
-            plt.subplot(nrow, ncol, b + 0 * ncol + 1)
-            plt.imshow(bf[:, :, 0], cmap='gray')
-            plt.imshow(colour_segstack(seg[:, :, edge_inds]))
-
-            for i, name in enumerate(target_names):
-                plt.subplot(nrow, ncol, b + (i + 1) * ncol + 1)
-                plt.imshow(seg[:, :, i], cmap='gray')
-                plt.title(name)
-
-        fig.savefig(self.save_dir / '{}_generator_sample.png'.format(
-            'validation' if validation else 'training'))
-
-    def generate_smoothing_sigma_stats(self):
-        # train_gen = augmented_generator(self.gen.train, lambda x, y: (x, y))
-        # val_gen = augmented_generator(self.gen.train, lambda x, y: (x, y))
-        self.smoothing_sigma_trainer.generate_smoothing_sigma_stats(
-            self.gen.train, self.gen.val)
+    def fit_smoothing_model(self, filt='identity'):
+        try:
+            self.smoothing_sigma_trainer.stats
+        except BadProcess:
+            self.smoothing_sigma_trainer.generate_smoothing_sigma_stats()
+        self.smoothing_sigma_trainer.fit(filt=filt)
 
-    @property
-    def smoothing_sigma_stats(self):
+    def plot_fitted_smoothing_model(self):
         warnings.warn(
-            "nursery.smoothing_sigma_stats will soon be "
-            "deprecated, use nursery.smoothing_sigma_trainer.stats "
-            "instead", DeprecationWarning)
-        return self.smoothing_sigma_trainer.stats
+                "nursery.plot_fitted_smoothing_model will soon be "
+                "deprecated, use "
+                "nursery.smoothing_sigma_trainer.plot_fitted_model "
+                "instead", DeprecationWarning)
+        self.smoothing_sigma_trainer.plot_fitted_model()
 
     @property
     def smoothing_sigma_model(self):
         return self.smoothing_sigma_trainer.model
 
-    def generate_flattener_stats(self, max_erode=5):
-        warnings.warn(
-            "nursery.generate_flattener_stats will soon be "
-            "deprecated, use "
-            "nursery.flattener_trainer.generate_stats(train_gen, "
-            "val_gen, train_aug, val_aug, max_erode=5) instead",
-            DeprecationWarning)
-
-        flattener = lambda x, y: x
-        # NB: use isval=True for training aug since we do not need extra
-        # augmentations for calibrating the flattener
-        tAug = _std_aug(self.smoothing_sigma_model, flattener,
-                        self.parameters, isval=True)
-        vAug = _std_aug(self.smoothing_sigma_model,
-                        flattener,
-                        self.parameters,
-                        isval=True)
-
-        self.flattener_trainer.generate_flattener_stats(*self.gen[:2],
-                                                        tAug,
-                                                        vAug,
-                                                        max_erode=max_erode)
+    @smoothing_sigma_model.setter
+    def smoothing_sigma_model(self, m):
+        self.smoothing_sigma_trainer.model = m
 
-    @property
-    def flattener_stats(self):
-        warnings.warn(
-            "nursery.flattener_stats will soon be "
-            "deprecated, use nursery.flattener_trainer.stats "
-            "instead", DeprecationWarning)
-        return self.flattener_trainer.stats
+    def fit_flattener(self, max_erode=5, **kwargs):
+        try:
+            self.flattener_trainer.stats
+        except BadProcess:
+            self.flattener_trainer.generate_flattener_stats(
+                max_erode=max_erode)
+        self.flattener_trainer.fit(**kwargs)
+
+    def plot_flattener_stats(self, **kwargs):
+        self.flattener_trainer.plot_stats(**kwargs)
 
     @property
     def flattener(self):
-        warnings.warn(
-            "nursery.flattener will soon be "
-            "deprecated, use nursery.flattener_trainer.flattener "
-            "instead", DeprecationWarning)
         return self.flattener_trainer.flattener
 
     @flattener.setter
     def flattener(self, f):
-        warnings.warn(
-            "nursery.flattener will soon be "
-            "deprecated, use nursery.flattener_trainer.flattener "
-            "instead", DeprecationWarning)
         self.flattener_trainer.flattener = f
 
-    @property
-    def aug(self):
-        p = self.parameters
-        t = _std_aug(self.smoothing_sigma_model, self.flattener,
-                     self.parameters)
-        v = _std_aug(self.smoothing_sigma_model,
-                     self.flattener,
-                     self.parameters,
-                     isval=True)
-        w = _std_aug(self.smoothing_sigma_model,
-                     self.flattener,
-                     self.parameters,
-                     isval=True)
-        return TrainValTestProperty(t, v, w)
-
-    @property
-    def cnn_fn(self):
-        warnings.warn(
-            "nursery.cnn_fn will soon be "
-            "deprecated, use nursery.cnn_trainer.cnn_fn "
-            "instead", DeprecationWarning)
-        return self.cnn_trainer.cnn_fn
-
-    @cnn_fn.setter
-    def cnn_fn(self, fn):
-        warnings.warn(
-            "nursery.cnn_fn will soon be "
-            "deprecated, use nursery.cnn_trainer.cnn_fn "
-            "instead", DeprecationWarning)
-        self.cnn_trainer.cnn_fn = fn
-
-    @property
-    def cnn_dir(self):
-        warnings.warn(
-            "nursery.cnn_dir will soon be "
-            "deprecated, use nursery.cnn_trainer.cnn_dir "
-            "instead", DeprecationWarning)
-        return self.cnn_trainer.cnn_dir
-
-    @property
-    def cnn_name(self):
-        warnings.warn(
-            "nursery.cnn_bane will soon be "
-            "deprecated, use nursery.cnn_trainer.cnn_name "
-            "instead", DeprecationWarning)
-        return self.cnn_trainer.cnn_name
-
-    @property
-    def cnn(self):
-        warnings.warn(
-            "nursery.cnn will soon be "
-            "deprecated, use nursery.cnn_trainer.cnn "
-            "instead", DeprecationWarning)
-        return self.cnn_trainer.cnn
-
-    @property
-    def histories(self):
-        warnings.warn(
-            "nursery.histories will soon be "
-            "deprecated, use nursery.cnn_trainer.histories "
-            "instead", DeprecationWarning)
-        return self.cnn_trainer.histories
-
-    @property
-    def cnn_opt_dir(self):
-        warnings.warn(
-            "nursery.opt_dir will soon be "
-            "deprecated, use nursery.cnn_trainer.opt_dir "
-            "instead", DeprecationWarning)
-        return self.cnn_trainer.opt_dir
-
-    @property
-    def cnn_opt(self):
-        warnings.warn(
-            "nursery.cnn_opt will soon be "
-            "deprecated, use nursery.cnn_trainer.opt_cnn "
-            "instead", DeprecationWarning)
-        return self.cnn_trainer.opt_cnn
-
-    def fit_smoothing_model(self, filt='identity'):
-        warnings.warn(
-            "nursery.fit_smoothing_model will soon be "
-            "deprecated, use nursery.smoothing_signa_trainer.fit "
-            "instead", DeprecationWarning)
-        self.smoothing_sigma_trainer.fit(filt=filt)
-
-    def plot_fitted_smoothing_sigma_model(self):
-        warnings.warn(
-            "nursery.plot_fitted_smoothing_sigma_model will soon be "
-            "deprecated, use "
-            "nursery.smoothing_signa_trainer.plot_fitted_model "
-            "instead", DeprecationWarning)
-        self.smoothing_sigma_trainer.plot_fitted_model()
-
-    def fit_flattener(self, **kwargs):
-        warnings.warn(
-            "nursery.fit_flattener will soon be "
-            "deprecated, use nursery.flattener_trainer.fit "
-            "instead", DeprecationWarning)
-        self.flattener_trainer.fit(**kwargs)
-
-    def plot_flattener_stats(self, **kwargs):
-        warnings.warn(
-            "nursery.plot_flattener_stats will soon be "
-            "deprecated, use nursery.flattener_trainer.plot_stats "
-            "instead", DeprecationWarning)
-        self.flattener_trainer.plot_stats(**kwargs)
+    def plot_generator_sample(self, **kwargs):
+        self.flattener_trainer.plot_default_gen_sample(**kwargs)
 
     def fit_cnn(self, **kwargs):
-        warnings.warn(
-            "nursery.fit_cnn will soon be "
-            "deprecated, use nursery.cnn_trainer.fit "
-            "instead", DeprecationWarning)
         self.cnn_trainer.fit(**kwargs)
 
-    def plot_histories(self, **kwargs):
-        warnings.warn(
-            "nursery.plot_histories will soon be "
-            "deprecated, use nursery.cnn_trainer.plot_histories "
-            "instead", DeprecationWarning)
+    def plot_cnn_histories(self, **kwargs):
         self.cnn_trainer.plot_histories(**kwargs)
 
-    # TODO: move to Segmentation Param Trainer
     @property
-    def seg_examples(self):
-        p = self.parameters
-        a = ScalingAugmenter(self.smoothing_sigma_model,
-                             lambda lbl, _: lbl,
-                             xy_out=p.xy_out,
-                             target_pixel_size=p.target_pixel_size,
-                             substacks=p.substacks,
-                             p_noop=1,
-                             probs={
-                                 'vshift': 0.25,
-                                 'hshift': 0.25
-                             })
-
-        def seg_example_aug(img, lbl):
-            # Assume that the label preprocessing function also returns info
-            _, info = lbl
-            img, lbl = a(img, lbl)
-            # In this case, always prefer the validation augmenter
-            return img, lbl > 0, info
-
-        def example_generator(dgen):
-            opt_cnn = self.cnn_opt
-            b_iter = batch_iterator(list(range(dgen.n_pairs)),
-                                    batch_size=dgen.batch_size)
-            with tqdm(total=dgen.n_pairs) as pbar:
-                for b_inds in b_iter:
-                    batch = [
-                        dgen.get_by_index(b, aug=seg_example_aug)
-                        for b in b_inds
-                    ]
-                    preds = split_batch_pred(
-                        opt_cnn.predict(np.stack([img for img, _, _ in batch
-                                                 ])))
-                    for pred, (img, lbl, info) in zip(preds, batch):
-                        pbar.update()
-                        lbl = lbl.transpose(2, 0, 1)
-                        # Filter out examples that have been augmented away
-                        valid = lbl.sum(axis=(1, 2)) > 0
-                        lbl = lbl[valid]
-                        clab = info.get('cellLabels', []) or []
-                        if type(clab) is int:
-                            clab = [clab]
-                        clab = [l for l, v in zip(clab, valid) if v]
-                        info['cellLabels'] = clab
-                        buds = info.get('buds', []) or []
-                        if type(buds) is int:
-                            buds = [buds]
-                        buds = [b for b, v in zip(buds, valid) if v]
-                        info['buds'] = buds
-                        yield SegExample(pred, lbl, info, img)
-
-        if self.in_memory:
-            if getattr(self, '_seg_examples', None) is None:
-                self._seg_examples = TrainValTestProperty(
-                    list(example_generator(self.gen.train)),
-                    list(example_generator(self.gen.val)),
-                    list(example_generator(self.gen.test)))
-            return TrainValTestProperty((e for e in self._seg_examples.train),
-                                        (e for e in self._seg_examples.val),
-                                        (e for e in self._seg_examples.test))
-        else:
-            self._seg_examples = None
-            return TrainValTestProperty(example_generator(self.gen.train),
-                                        example_generator(self.gen.val),
-                                        example_generator(self.gen.test))
+    def cnn_dir(self):
+        """Directory containing saved weights for the optimised CNN."""
+        return self.cnn_trainer.opt_dir
 
-    # TODO Move to Segmentation Parameter TRainer
-    @property
-    def seg_param_stats(self):
-        if getattr(self, '_seg_param_stats', None) is None:
-            p = self.parameters
-            stats_file = self.save_dir / p.segmentation_stats_file
-            if not stats_file.is_file():
-                raise BadProcess('"fit_seg_params" has not been run yet')
-            self._seg_param_stats = pd.read_csv(stats_file, index_col=0)
-        return self._seg_param_stats
-
-    # TODO Move to Segmentation Parameter TRainer
     @property
-    def seg_params(self):
-        params_file = self.save_dir / self.parameters.segmentation_param_file
-        with open(params_file, 'rt') as f:
-            params = json.load(f)
-        return params
-
-    # TODO Move to Segmentation Parameter TRainer
-    @seg_params.setter
-    def seg_params(self, val):
-        if not type(val) == dict:
-            raise BadParam('"seg_params" should be a "dict"')
-        msg_args = inspect.getfullargspec(MorphSegGrouped.__init__).args
-        if not set(val.keys()).issubset(msg_args):
-            raise BadParam(
-                '"seg_params" must specify arguments to "MorphSegGrouped"')
-        params_file = self.save_dir / self.parameters.segmentation_param_file
-        with open(params_file, 'wt') as f:
-            json.dump(jsonify(val), f)
-
-    # TODO Deprecate and keep in bud_trainer
-    def generate_bud_stats(self):
-        self.bud_trainer.generate_property_table(self.seg_examples,
-                                                 self.flattener)
-
-    # TODO Depcrecate and keep in bud_trainer
-    def fit_bud_model(self, **kwargs):
-        self.bud_trainer.explore_hyperparams(**kwargs)
-        model_file = self.save_dir / self.parameters.mother_bud_model_file
-        self.bud_trainer.save_model(model_file)
-
-    #Todo Move to Segmentation Parameter Trainer
-    def fit_seg_params(self, njobs=5, scoring='F0_5'):
-        param_grid = list(product(*seg_param_coords.values()))
-        basic_pars = list(seg_param_coords.keys())
-
-        # TODO switch back to validation examples
-        val_examples = list(self.seg_examples.val)
-        from joblib import Parallel, delayed
-        rows = []
-        for gind in range(3)[::-1]:
-            rows.extend(
-                Parallel(n_jobs=njobs)(
-                    delayed(_seg_filter_optim)(gind,
-                                               pars,
-                                               basic_pars,
-                                               self.flattener,
-                                               val_examples,
-                                               base_params=base_seg_params,
-                                               scoring=scoring)
-                    for pars in tqdm(param_grid)))
-
-        rows_expanded = [
-            dict(
-                chain(*[[('_'.join((k, str(g))), gv)
-                         for g, gv in enumerate(v)] if type(v) == list else [(
-                             k, v)]
-                        for k, v in chain([
-                            ('group', row['group']), ('score', row['score'])
-                        ], row['basic'].items(), row['filter'].items())]))
-            for row in rows
-        ]
-
-        self._seg_param_stats = pd.DataFrame(rows_expanded)
-        stats_file = self.save_dir / self.parameters.segmentation_stats_file
-        self._seg_param_stats.to_csv(stats_file)
-
-        self.refit_filter_seg_params(scoring=scoring)
-
-    # TODO move to Segmentation Parameter Trainer
-    def refit_filter_seg_params(self,
-                                lazy=False,
-                                bootstrap=False,
-                                scoring='F0_5'):
-
-        # Merge the best parameters from each group into a single parameter set
-        merged_params = {k: v.copy() for k, v in base_seg_params.items()}
-        stats = self.seg_param_stats
-        for g, r in enumerate(stats.groupby('group').score.idxmax()):
-            for k in merged_params:
-                merged_params[k][g] = stats.loc[r, k + '_' + str(g)]
-
-        sfpo = SegFilterParamOptim(self.flattener,
-                                   basic_params=merged_params,
-                                   scoring=scoring)
-        sfpo.generate_stat_table(self.seg_examples.val)
-
-        sfpo.fit_filter_params(lazy=lazy, bootstrap=bootstrap)
-        merged_params.update(sfpo.opt_params)
-        self.seg_params = merged_params
-
-    # TODO Move to Segmentation Parameter Trainer
-    def validate_seg_params(self, iou_thresh=0.7, save=True):
-        segmenter = MorphSegGrouped(self.flattener,
-                                    return_masks=True,
-                                    fit_radial=True,
-                                    use_group_thresh=True,
-                                    **self.seg_params)
-        edge_inds = [
-            i for i, t in enumerate(self.flattener.targets)
-            if t.prop == 'edge'
-        ]
-        stats = {}
-        dfs = {}
-        for k, seg_exs in zip(self.seg_examples._fields, self.seg_examples):
-            stats[k] = []
-            for seg_ex in seg_exs:
-                seg = segmenter.segment(seg_ex.pred, refine_outlines=True)
-                edge_scores = np.array([
-                    seg_ex.pred[edge_inds, ...].max(axis=0)[s].mean()
-                    for s in seg.edges
-                ])
-                IoUs = calc_IoUs(seg_ex.target, seg.masks)
-                bIoU, _ = best_IoU(IoUs)
-                stats[k].append((edge_scores, IoUs, np.mean(bIoU),
-                                 np.min(bIoU, initial=1),
-                                 calc_AP(IoUs,
-                                         probs=edge_scores,
-                                         iou_thresh=iou_thresh)[0]))
-            dfs[k] = pd.DataFrame([s[2:] for s in stats[k]],
-                                  columns=['IoU_mean', 'IoU_min', 'AP'])
-
-        print({k: df.mean() for k, df in dfs.items()})
-
-        nrows = len(dfs)
-        ncols = dfs['val'].shape[1]
-        fig, axs = plt.subplots(nrows=nrows,
-                                ncols=ncols,
-                                figsize=(ncols * 4, nrows * 4))
-        for axrow, (k, df) in zip(axs, dfs.items()):
-            for ax, col in zip(axrow, df.columns):
-                ax.hist(df.loc[:, col], bins=26, range=(0, 1))
-                ax.set(xlabel=col, title=k)
-        if save:
-            fig.savefig(self.save_dir / 'seg_validation_plot.png')
-            plt.close(fig)
-
+    def cnn(self):
+        """Optimised CNN from the ``cnn_trainer``."""
+        return self.cnn_trainer.opt_cnn
 
 class Nursery(BabyTrainer):
     pass
 
 
-def load_history(subdir):
-    with open(LOG_DIR / subdir / 'history.pkl', 'rb') as f:
-        return pickle.load(f)
-
-
 def get_best_and_worst(model, gen):
     best = {}
     worst = {}
@@ -843,39 +332,3 @@ def get_best_and_worst(model, gen):
                         worst[output][worst_maxind] = out
 
     return best, worst
-
-
-def _seg_filter_optim(g,
-                      p,
-                      pk,
-                      flattener,
-                      seg_gen,
-                      base_params=default_params,
-                      scoring='F0_5'):
-    p = _sub_params({(k, g): v for k, v in zip(pk, p)}, base_params)
-    sfpo = SegFilterParamOptim(flattener, basic_params=p, scoring=scoring)
-    sfpo.generate_stat_table(seg_gen)
-    sfpo.fit_filter_params(lazy=True, bootstrap=False)
-    return {
-        'group': g,
-        'basic': p,
-        'filter': sfpo.opt_params,
-        'score': sfpo.opt_score
-    }
-
-
-def _std_aug(ssm, flattener, p, isval=False):
-    probs = {'vshift': 0.25, 'hshift': 0.25}
-    extra_args = {}
-    if isval:
-        extra_args['p_noop'] = 1
-    else:
-        probs['rotate'] = 0.2
-
-    return ScalingAugmenter(ssm,
-                            flattener,
-                            xy_out=p.xy_out,
-                            target_pixel_size=p.target_pixel_size,
-                            substacks=p.substacks,
-                            probs=probs,
-                            **extra_args)
diff --git a/python/baby/training/utils.py b/python/baby/training/utils.py
index 90dbb2fea7ef229d6b712aa732e6430a6c990367..d0730044a3911f95d3d2351809b9a4946788c95f 100644
--- a/python/baby/training/utils.py
+++ b/python/baby/training/utils.py
@@ -1,12 +1,14 @@
 # If you publish results that make use of this software or the Birth Annotator
 # for Budding Yeast algorithm, please cite:
-# Julian M J Pietsch, Alán Muñoz, Diane Adjavon, Ivan B N Clark, Peter S
-# Swain, 2021, Birth Annotator for Budding Yeast (in preparation).
+# Pietsch, J.M.J., Muñoz, A.F., Adjavon, D.-Y.A., Farquhar, I., Clark, I.B.N.,
+# and Swain, P.S. (2023). Determining growth rates from bright-field images of
+# budding cells through identifying overlaps. eLife. 12:e79812.
+# https://doi.org/10.7554/eLife.79812
 # 
 # 
 # The MIT License (MIT)
 # 
-# Copyright (c) Julian Pietsch, Alán Muñoz and Diane Adjavon 2021
+# Copyright (c) Julian Pietsch, Alán Muñoz and Diane Adjavon 2023
 # 
 # Permission is hereby granted, free of charge, to any person obtaining a copy
 # of this software and associated documentation files (the "Software"), to
@@ -25,13 +27,24 @@
 # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
 # FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS
 # IN THE SOFTWARE.
-from contextlib import contextmanager
 from typing import NamedTuple, Any, Tuple, Union
+from itertools import repeat
+import shutil
+import json
+from pathlib import Path
 
 import tensorflow as tf
-from baby.augmentation import Augmenter
-from baby.generator import ImageLabel
-from baby.utils import EncodableNamedTuple
+from tensorflow.keras.optimizers.schedules import CosineDecay
+
+from baby.errors import BadFile, BadParam, BadProcess, BadType
+from baby.io import TrainValTestPairs
+from baby.preprocessing import (robust_norm, robust_norm_dw, seg_norm,
+                                SegmentationFlattening)
+from baby.augmentation import Augmenter, ScalingAugmenter, SmoothingSigmaModel
+from baby.generator import ImageLabel, AugmentedGenerator
+from baby.utils import (find_file, EncodableNamedTuple,
+                        jsonify, as_python_object)
+from baby.morph_thresh_seg import SegmentationParameters
 
 
 def fix_tf_rtx_gpu_bug():
@@ -52,10 +65,10 @@ def fix_tf_rtx_gpu_bug():
             logical_gpus = tf.config.experimental.list_logical_devices('GPU')
             print(len(gpus), "Physical GPUs,", len(logical_gpus),
                   "Logical GPUs")
-        else:
-            raise Exception(
-                'Unsupported version of tensorflow encountered ({})'.format(
-                    tf.version.VERSION))
+    else:
+        raise Exception(
+            'Unsupported version of tensorflow encountered ({})'.format(
+                tf.version.VERSION))
 
 
 class TrainValProperty(NamedTuple):
@@ -83,24 +96,379 @@ class BabyTrainerParameters(NamedTuple):
     segmentation_param_file: str = 'segmentation_params.json'
     mother_bud_props_file: str = 'mother_bud_props.csv'
     mother_bud_model_file: str = 'mother_bud_model.pkl'
-    cnn_set: Tuple[str, ...] = ('msd_d80', 'unet_4s')
+    cnn_set: Tuple[str, ...] = ('unet_4s',)
     cnn_fn: Union[None, str] = None
     batch_size: int = 8
     in_memory: bool = True
     xy_out: int = 80
     target_pixel_size: float = 0.263
     substacks: Union[None, int] = None
+    aug_probs: dict = {}
+    aug_p_noop: float = 0.05
+    base_seg_params: dict = {}
+    seg_param_coords: dict = {}
+    input_norm_dw: bool = False
+    only_basic_augs: bool = True
+    balanced_sampling: bool = False
+    use_sample_weights: bool = False
+    canny_padding: int = 2
+    n_jobs: int = 4
+
+
+TRAINING_PARAMETERS_FILENAME =  'training_parameters.json'
+"""File name to which training parameters are saved"""
+
+
+class SharedParameterContainer(object):
+    """A container of current training-related parameters.
+
+    Designed to be used by :py:class:`training.BabyTrainer` and passed to the
+    children trainers. Updates on this object by the
+    :py:class:`training.BabyTrainer` are then automatically propagated to the
+    classes that use the parameters.
+
+    Parameters are auto-saved to a file
+    :py:const:`TRAINING_PARAMETERS_FILENAME` within the ``save_dir``.
+
+    Args:
+        save_dir (str or Path): directory in which to save parameters and/or
+            from which to auto-load parameters.
+        params (None or BabyTrainerParameters or str or Path): Optionally
+            specify initial training parameters as a
+            :py:class:`BabyTrainerParameters` instance or the path to a saved
+            :py:class:`BabyTrainerParameters` instance.
+    """
+    def __init__(self, save_dir, params=None):
+        self.save_dir = Path(save_dir)  # normalise: ``save_dir`` may be a str
+        self._parameters_file = TRAINING_PARAMETERS_FILENAME
+        # Trigger parameter loading/initialisation via the property getters
+        self.parameters
+        self.segmentation_parameters
+
+        # Register parameters if specified (instance, or path to a saved file)
+        if isinstance(params, BabyTrainerParameters):
+            self.parameters = params
+        elif params is not None:
+            filename = find_file(params, self.save_dir, 'params')
+            savename = self.save_dir / self._parameters_file
+            if filename != savename:
+                shutil.copy(filename, savename)
+                self._parameters = None  # drop cache so the copy is re-read
+
+    @property
+    def parameters(self):
+        """A :py:class:`BabyTrainerParameters` instance.
+
+        If not already initialised, parameters are loaded from the save file
+        if found, otherwise they are initialised to defaults as per
+        :py:class:`BabyTrainerParameters`.
+
+        This can be set to either a new ``BabyTrainerParameters`` instance, or
+        to a ``dict`` that specifies existing parameter values to replace.
+        """
+        if not getattr(self, '_parameters', None):
+            param_file = self.save_dir / self._parameters_file
+            if param_file.is_file():
+                with open(param_file, 'rt') as f:
+                    params = json.load(f, object_hook=as_python_object)
+                if not isinstance(params, BabyTrainerParameters):
+                    raise BadFile('Parameters file has been corrupted')
+                self._parameters = params
+            else:
+                self.parameters = BabyTrainerParameters()
+        return self._parameters
+
+    @parameters.setter
+    def parameters(self, params):
+        if isinstance(params, dict):
+            if not getattr(self, '_parameters', None):
+                self._parameters = BabyTrainerParameters()
+            params = self._parameters._replace(**params)
+        elif not isinstance(params, BabyTrainerParameters):
+            params = BabyTrainerParameters(*params)
+        self._parameters = params
+        with open(self.save_dir / self._parameters_file, 'wt') as f:
+            json.dump(jsonify(self._parameters), f)
+
+    @property
+    def segmentation_parameters(self):
+        """A :py:class:`baby.segmentation.SegmentationParameters` instance.
+
+        If not already initialised, parameters are loaded from the save file
+        if found, otherwise they are initialised to defaults as per
+        :py:class:`baby.segmentation.SegmentationParameters`.
+
+        This can be set to either a new ``SegmentationParameters`` instance,
+        or to a ``dict`` that specifies existing parameter values to replace.
+        """
+        if not getattr(self, '_segmentation_parameters', None):
+            seg_param_file = (self.save_dir /
+                              self.parameters.segmentation_param_file)
+            if seg_param_file.is_file():
+                with open(seg_param_file, 'rt') as f:
+                    params = json.load(f, object_hook=as_python_object)
+                if not isinstance(params, SegmentationParameters):
+                    raise BadFile(
+                        'Segmentation parameters file has been corrupted.')
+                self._segmentation_parameters = params
+            else:
+                self.segmentation_parameters = SegmentationParameters()
+        return self._segmentation_parameters
+
+    @segmentation_parameters.setter
+    def segmentation_parameters(self, params):
+        if isinstance(params, dict):
+            if not getattr(self, '_segmentation_parameters', None):
+                self._segmentation_parameters = SegmentationParameters()
+            params = self._segmentation_parameters._replace(**params)
+        elif not isinstance(params, SegmentationParameters):
+            params = SegmentationParameters(*params)
+        self._segmentation_parameters = params
+        seg_param_file = (self.save_dir /
+                          self.parameters.segmentation_param_file)
+        with open(seg_param_file, 'wt') as f:
+            json.dump(jsonify(self._segmentation_parameters), f, indent=2)
+
+
+class SharedDataContainer(object):
+    """A container of current data and generators.
+
+    Designed to be used by :py:class:`training.BabyTrainer` and passed to the
+    children trainers. Updates on this object by the
+    :py:class:`training.BabyTrainer` are then automatically propagated to the
+    classes that use the data.
+
+    Data are auto-saved in ``save_dir`` to the files named by the
+    ``train_val_test_pairs_file`` and ``tracker_tvt_pairs_file`` parameters.
+
+    Args:
+        shared_params (SharedParameterContainer): Shared parameters.
+        base_dir (None or str or Path): Base directory within which all
+            relevant image files can be found. References to the image files
+            will be saved relative to this directory.
+    """
+    def __init__(self, shared_params, base_dir=None):
+        self._shared_params = shared_params
+
+        if base_dir is not None:
+            base_dir = Path(base_dir)
+            if not base_dir.is_dir():
+                raise BadParam('"base_dir" must be a valid directory or None')
+        else:
+            base_dir = Path.cwd()
+        self.base_dir = base_dir
 
+    @property
+    def save_dir(self):
+        return self._shared_params.save_dir
 
+    @property
+    def parameters(self):
+        return self._shared_params.parameters
 
-@contextmanager
-def augmented_generator(gen: ImageLabel, aug: Augmenter):
-    # Save the previous augmenter if any
-    saved_aug = gen.aug
-    gen.aug = aug
-    try:
-        yield gen
-    # Todo: add except otherwise there might be an issue of there is an error?
-    finally:
-        gen.aug = saved_aug
+    def _check_for_data_update(self):
+        if getattr(self, '_ncells', None) != self._impairs.ncells:
+            # Reset generators so that ``gen`` rebuilds them from the new data
+            self._gen_train = None
+            self._gen_val = None
+            self._gen_test = None
+            # Trigger save of the data
+            datafile = self.save_dir / self.parameters.train_val_test_pairs_file
+            self._impairs.save(datafile, self.base_dir)
+            self._ncells = self._impairs.ncells  # baseline for the next check
 
+    def _tracker_check_for_data_update(self):
+        # Re-save the tracker data set whenever its ``ncells`` has changed
+        if (getattr(self, '_tracker_ncells', None)
+            != self._tracker_impairs.ncells):
+            datafile = self.save_dir / self.parameters.tracker_tvt_pairs_file
+            self._tracker_impairs.save(datafile, self.base_dir)
+            self._tracker_ncells = self._tracker_impairs.ncells  # new baseline
+
+    @property
+    def data(self):
+        if not hasattr(self, '_impairs') or not self._impairs:
+            self._impairs = TrainValTestPairs()
+            pairs_file = self.save_dir / self.parameters.train_val_test_pairs_file
+            if pairs_file.is_file():
+                self._impairs.load(pairs_file, self.base_dir)
+        self._check_for_data_update()
+        return self._impairs
+
+    @data.setter
+    def data(self, train_val_test_pairs):
+        if isinstance(train_val_test_pairs, str):
+            pairs_file = find_file(train_val_test_pairs, self.save_dir,
+                    'data')
+            train_val_test_pairs = TrainValTestPairs()
+            train_val_test_pairs.load(pairs_file, self.base_dir)
+        if not isinstance(train_val_test_pairs, TrainValTestPairs):
+            raise BadType(
+                    '"data" must be of type "baby.io.TrainValTestPairs"')
+        self._impairs = train_val_test_pairs
+        self._check_for_data_update()
+
+    @property
+    def tracker_data(self):
+        """Train/val/test pairs for the tracker (auto-loaded if saved)."""
+        if not hasattr(self, '_tracker_impairs') or not self._tracker_impairs:
+            self._tracker_impairs = TrainValTestPairs()
+            pairs_file = self.save_dir / self.parameters.tracker_tvt_pairs_file
+            if pairs_file.is_file():
+                self._tracker_impairs.load(pairs_file, self.base_dir)
+        self._tracker_check_for_data_update()
+        return self._tracker_impairs
+
+    @tracker_data.setter
+    def tracker_data(self, train_val_test_pairs):
+        if isinstance(train_val_test_pairs, str):
+            pairs_file = find_file(train_val_test_pairs, self.save_dir, 'data')
+            train_val_test_pairs = TrainValTestPairs()
+            train_val_test_pairs.load(pairs_file, self.base_dir)
+        if not isinstance(train_val_test_pairs, TrainValTestPairs):
+            raise BadType('"tracker_data" must be of type '
+                          '"baby.io.TrainValTestPairs"')
+        self._tracker_impairs = train_val_test_pairs
+        self._tracker_check_for_data_update()
+
+    @property
+    def gen(self):
+        """Training, validation and test data generators.
+
+        This attribute provides three :py:class:`ImageLabel` generators as a
+        :py:class:`TrainValTestProperty`, with each generator assigned just a
+        dummy augmenter to begin with.
+
+        Note:
+            Generator initialisation requires that all specified images exist.
+        """
+
+        old_gen_params = getattr(self, '_current_gen_params', None)
+        new_gen_params = tuple(getattr(self.parameters, p) for p in
+                               ('in_memory', 'input_norm_dw', 'batch_size',
+                                'balanced_sampling', 'use_sample_weights'))
+        if old_gen_params != new_gen_params:
+            self._gen_train = None
+            self._gen_val = None
+            self._gen_test = None
+            self._current_gen_params = new_gen_params
+
+        (in_memory, input_norm_dw, batch_size,
+         balanced_sampling, use_sample_weights) = new_gen_params
+        input_norm = robust_norm_dw if input_norm_dw else robust_norm
+
+        if not getattr(self, '_gen_train', None):
+            if len(self.data.training) == 0:
+                raise BadProcess('No training images have been added')
+            # Initialise generator for training images
+            self._gen_train = ImageLabel(self.data.training,
+                                         batch_size=batch_size,
+                                         aug=Augmenter(),
+                                         preprocess=(input_norm, seg_norm),
+                                         in_memory=in_memory,
+                                         balanced_sampling=balanced_sampling,
+                                         use_sample_weights=use_sample_weights)
+
+        if not getattr(self, '_gen_val', None):
+            if len(self.data.validation) == 0:
+                raise BadProcess('No validation images have been added')
+            # Initialise generator for validation images
+            self._gen_val = ImageLabel(self.data.validation,
+                                       batch_size=batch_size,
+                                       aug=Augmenter(),
+                                       preprocess=(input_norm, seg_norm),
+                                       in_memory=in_memory,
+                                       balanced_sampling=balanced_sampling,
+                                       use_sample_weights=use_sample_weights)
+
+        if not getattr(self, '_gen_test', None):
+            if len(self.data.testing) == 0:
+                raise BadProcess('No testing images have been added')
+            # Initialise generator for testing images
+            self._gen_test = ImageLabel(self.data.testing,
+                                        batch_size=batch_size,
+                                        aug=Augmenter(),
+                                        preprocess=(input_norm, seg_norm),
+                                        in_memory=in_memory,
+                                        balanced_sampling=balanced_sampling,
+                                        use_sample_weights=use_sample_weights)
+
+        self._gen_train.n_jobs = self.parameters.n_jobs
+        self._gen_val.n_jobs = self.parameters.n_jobs
+        self._gen_test.n_jobs = self.parameters.n_jobs
+        return TrainValTestProperty(self._gen_train, self._gen_val,
+                self._gen_test)
+
+    def gen_with_aug(self, aug):
+        """Returns generators wrapped with alternative augmenters.
+
+        Args:
+            aug (Augmenter or Tuple[Augmenter, Augmenter, Augmenter]):
+                Augmenter to use, or tuple (plain or named) of augmenters for
+                the training, validation and testing generators respectively.
+
+        Returns:
+            :py:class:`TrainValTestProperty` of :py:class:`AugmentedGenerator`
+            objects for training, validation and testing generators.
+        """
+        atrain, aval, atest = aug if isinstance(aug, tuple) else repeat(aug, 3)
+        gtrain, gval, gtest = self.gen  # builds the generators if needed
+        return TrainValTestProperty(
+            AugmentedGenerator(gtrain, atrain),
+            AugmentedGenerator(gval, aval),
+            AugmentedGenerator(gtest, atest))
+
+
+VALIDATION_AUGMENTATIONS = {'vshift', 'hshift'}
+
+
+def standard_augmenter(ssm, flattener, params, isval=False):
+    """Returns an augmenter for training on flattenable inputs.
+
+    Args:
+        ssm (SmoothingSigmaModel): Smoothing model to use.
+        flattener (SegmentationFlattening): Flattener to apply after
+            augmentation.
+        params (BabyTrainerParameters): Parameters to use when constructing
+            augmenter.
+        isval (bool): If ``True``, build a validation augmenter: ``p_noop``
+            is forced to 1 and only the probabilities for augmentations in
+            ``VALIDATION_AUGMENTATIONS`` are passed on. If ``False``, the
+            ``'rotate'`` probability also defaults to 0.25. In both cases
+            ``params.aug_probs`` overrides the defaults set here.
+
+    Returns:
+        A :py:class:`ScalingAugmenter` with specified ``ssm`` and
+        ``flattener`` and parameterisation according to ``params``.
+    """
+    probs = {'vshift': 0.25, 'hshift': 0.25}
+    extra_args = dict(canny_padding=params.canny_padding,
+                      p_noop=params.aug_p_noop)
+    if isval:
+        extra_args['p_noop'] = 1  # overrides ``params.aug_p_noop``
+    else:
+        probs['rotate'] = 0.25
+
+    probs.update(params.aug_probs)  # user-specified values win over defaults
+    if isval:
+        probs = {k: v for k, v in probs.items()
+                 if k in VALIDATION_AUGMENTATIONS}
+
+    return ScalingAugmenter(ssm,
+            flattener,
+            xy_out=params.xy_out,
+            target_pixel_size=params.target_pixel_size,
+            substacks=params.substacks,
+            probs=probs,
+            only_basic_augs=params.only_basic_augs,
+            **extra_args)
+
+
+def warmup_and_cosine_decay(learning_rate=0.001, warmup_steps=30, decay_steps=370):
+    decay = CosineDecay(learning_rate, decay_steps)  # used once warm-up is over
+    def warmup_schedule(step):  # maps a step index to a learning rate
+        if step < warmup_steps:
+            return step / warmup_steps * learning_rate  # linear ramp from 0
+        else:
+            return decay(step - warmup_steps)  # decay is indexed from warm-up end
+    return warmup_schedule
diff --git a/python/baby/training/v1_hyper_parameter_trainer.py b/python/baby/training/v1_hyper_parameter_trainer.py
index 507d8f1139b9fb12e5821214abc165e49b14761a..934d33f9515bc054ca335532cd022208430d5c7b 100644
--- a/python/baby/training/v1_hyper_parameter_trainer.py
+++ b/python/baby/training/v1_hyper_parameter_trainer.py
@@ -1,12 +1,14 @@
 # If you publish results that make use of this software or the Birth Annotator
 # for Budding Yeast algorithm, please cite:
-# Julian M J Pietsch, Alán Muñoz, Diane Adjavon, Ivan B N Clark, Peter S
-# Swain, 2021, Birth Annotator for Budding Yeast (in preparation).
+# Pietsch, J.M.J., Muñoz, A.F., Adjavon, D.-Y.A., Farquhar, I., Clark, I.B.N.,
+# and Swain, P.S. (2023). Determining growth rates from bright-field images of
+# budding cells through identifying overlaps. eLife. 12:e79812.
+# https://doi.org/10.7554/eLife.79812
 # 
 # 
 # The MIT License (MIT)
 # 
-# Copyright (c) Julian Pietsch, Alán Muñoz and Diane Adjavon 2021
+# Copyright (c) Julian Pietsch, Alán Muñoz and Diane Adjavon 2023
 # 
 # Permission is hereby granted, free of charge, to any person obtaining a copy
 # of this software and associated documentation files (the "Software"), to
diff --git a/python/baby/utils.py b/python/baby/utils.py
index 02f5bc890ada72420d1bace21931ea0c69b9b81b..ef6ef5c1dd7c9e087d92f011a75db6475517146f 100644
--- a/python/baby/utils.py
+++ b/python/baby/utils.py
@@ -1,12 +1,14 @@
 # If you publish results that make use of this software or the Birth Annotator
 # for Budding Yeast algorithm, please cite:
-# Julian M J Pietsch, Alán Muñoz, Diane Adjavon, Ivan B N Clark, Peter S
-# Swain, 2021, Birth Annotator for Budding Yeast (in preparation).
+# Pietsch, J.M.J., Muñoz, A.F., Adjavon, D.-Y.A., Farquhar, I., Clark, I.B.N.,
+# and Swain, P.S. (2023). Determining growth rates from bright-field images of
+# budding cells through identifying overlaps. eLife. 12:e79812.
+# https://doi.org/10.7554/eLife.79812
 # 
 # 
 # The MIT License (MIT)
 # 
-# Copyright (c) Julian Pietsch, Alán Muñoz and Diane Adjavon 2021
+# Copyright (c) Julian Pietsch, Alán Muñoz and Diane Adjavon 2023
 # 
 # Permission is hereby granted, free of charge, to any person obtaining a copy
 # of this software and associated documentation files (the "Software"), to
@@ -67,26 +69,13 @@ def get_name(obj):
     return getattr(obj, '_baby_name', obj.__name__)
 
 
-def NamedTupleToJSON(self):
-    return {
-        '_python_NamedTuple': self._asdict(),
-        '__module__': self.__class__.__module__,
-        '__class__': self.__class__.__name__
-    }
-
-
-def EncodableNamedTuple(obj):
-    obj.toJSON = NamedTupleToJSON
-    return obj
-
-
 def jsonify(obj):
     if hasattr(obj, 'toJSON'):
         return obj.toJSON()
     elif hasattr(obj, 'dtype') and hasattr(obj, 'tolist'):
         return obj.tolist()
     elif isinstance(obj, tuple):
-        return {'_python_tuple': list(obj)}
+        return {'_python_tuple': [jsonify(v) for v in obj]}
     if isinstance(obj, set):
         return {'_python_set': list(obj)}
     elif isinstance(obj, dict):
@@ -97,6 +86,19 @@ def jsonify(obj):
         return obj
 
 
+def NamedTupleToJSON(self):
+    return {
+        '_python_NamedTuple': jsonify(self._asdict()),  # fields, jsonified too
+        '__module__': self.__class__.__module__,  # read by as_python_object
+        '__class__': self.__class__.__name__  # to look the class up again
+    }
+
+
+def EncodableNamedTuple(obj):
+    obj.toJSON = NamedTupleToJSON  # ``jsonify`` calls ``toJSON`` when present
+    return obj  # same object handed back with the attribute attached
+
+
 def as_python_object(obj):
     if '_python_NamedTuple' in obj:
         obj_class = getattr(import_module(obj['__module__']), obj['__class__'])
diff --git a/python/baby/visualise.py b/python/baby/visualise.py
index 524b48fd399d220ac1c16aa5a6d6c846b80de9a0..4ad8ac94bfaee53bc6a8e322d56ed6ddadcd65dd 100644
--- a/python/baby/visualise.py
+++ b/python/baby/visualise.py
@@ -1,12 +1,14 @@
 # If you publish results that make use of this software or the Birth Annotator
 # for Budding Yeast algorithm, please cite:
-# Julian M J Pietsch, Alán Muñoz, Diane Adjavon, Ivan B N Clark, Peter S
-# Swain, 2021, Birth Annotator for Budding Yeast (in preparation).
+# Pietsch, J.M.J., Muñoz, A.F., Adjavon, D.-Y.A., Farquhar, I., Clark, I.B.N.,
+# and Swain, P.S. (2023). Determining growth rates from bright-field images of
+# budding cells through identifying overlaps. eLife. 12:e79812.
+# https://doi.org/10.7554/eLife.79812
 # 
 # 
 # The MIT License (MIT)
 # 
-# Copyright (c) Julian Pietsch, Alán Muñoz and Diane Adjavon 2021
+# Copyright (c) Julian Pietsch, Alán Muñoz and Diane Adjavon 2023
 # 
 # Permission is hereby granted, free of charge, to any person obtaining a copy
 # of this software and associated documentation files (the "Software"), to
@@ -80,8 +82,9 @@ def plot_ims(ims, size=4, cmap=plt.cm.gray, show=True, dw=False, **kwargs):
     if dw:
         ims = ims.transpose([2, 0, 1])
 
+    aspect = ims.shape[2] / ims.shape[1]
     ncols = len(ims)
-    fig, axs = plt.subplots(1, ncols, figsize=(ncols * size, size),
+    fig, axs = plt.subplots(1, ncols, figsize=(aspect * ncols * size, size),
                             squeeze=False)
     axs = axs[0]
     for ax, im in zip(axs, ims):
diff --git a/python/baby/volume.py b/python/baby/volume.py
index d62c03f297ae56b7e84fb5d1933ba8b74373f1ac..53ad6ed7a78806fbb8494b7514c8acb674ef73b6 100644
--- a/python/baby/volume.py
+++ b/python/baby/volume.py
@@ -1,12 +1,14 @@
 # If you publish results that make use of this software or the Birth Annotator
 # for Budding Yeast algorithm, please cite:
-# Julian M J Pietsch, Alán Muñoz, Diane Adjavon, Ivan B N Clark, Peter S
-# Swain, 2021, Birth Annotator for Budding Yeast (in preparation).
+# Pietsch, J.M.J., Muñoz, A.F., Adjavon, D.-Y.A., Farquhar, I., Clark, I.B.N.,
+# and Swain, P.S. (2023). Determining growth rates from bright-field images of
+# budding cells through identifying overlaps. eLife. 12:e79812.
+# https://doi.org/10.7554/eLife.79812
 # 
 # 
 # The MIT License (MIT)
 # 
-# Copyright (c) Julian Pietsch, Alán Muñoz and Diane Adjavon 2021
+# Copyright (c) Julian Pietsch, Alán Muñoz and Diane Adjavon 2023
 # 
 # Permission is hereby granted, free of charge, to any person obtaining a copy
 # of this software and associated documentation files (the "Software"), to
diff --git a/setup.py b/setup.py
deleted file mode 100644
index 23a3df5e852764e62540fa6bf93a17221a591201..0000000000000000000000000000000000000000
--- a/setup.py
+++ /dev/null
@@ -1,63 +0,0 @@
-from setuptools import setup, find_packages
-
-setup(
-    name='baby',
-    version='0.24',
-    packages=find_packages('python'),
-    package_dir={'': 'python'},
-    include_package_data=True,
-    entry_points={
-        'console_scripts': [
-            'baby-phone = baby.server:main',
-            'baby-race = baby.speed_tests:main',
-            'baby-fit-grs = baby.postprocessing:main'
-            ]
-        },
-    url='',
-    license='MIT License',
-    author='Julian Pietsch',
-    author_email='julian.pietsch@ed.ac.uk',
-    description='Birth Annotator for Budding Yeast',
-    long_description='''
-If you publish results that make use of this software or the Birth Annotator
-for Budding Yeast algorithm, please cite:
-Julian M J Pietsch, Alán F Muñoz, Diane-Yayra A Adjavon, Ivan B N Clark, Peter
-S Swain, 2022, A label-free method to track individuals and lineages of
-budding cells (in submission).
-
-
-The MIT License (MIT)
-
-Copyright (c) Julian Pietsch, Alán Muñoz and Diane Adjavon 2022
-
-Permission is hereby granted, free of charge, to any person obtaining a copy
-of this software and associated documentation files (the "Software"), to deal
-in the Software without restriction, including without limitation the rights
-to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
-copies of the Software, and to permit persons to whom the Software is
-furnished to do so, subject to the following conditions:
-
-The above copyright notice and this permission notice shall be included in all
-copies or substantial portions of the Software.
-
-THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
-IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
-FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
-AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
-LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
-OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
-SOFTWARE.
-    ''',
-    install_requires=['tensorflow>=1.14,<2.4',
-                      'scipy',
-                      'numpy',
-                      'pandas',
-                      'scikit-image',
-                      'scikit-learn==0.22.2',
-                      'tqdm',
-                      'imageio',
-                      'pillow<9',
-                      'matplotlib',
-                      'aiohttp',
-                      'gaussianprocessderivatives']
-)
diff --git a/tests/brain_test.py b/tests/brain_test.py
index 802b91460758524eb767d5818db652d0f59a3639..f7ca5319bcbc961bb3ab54cb91abde366cf00ff6 100644
--- a/tests/brain_test.py
+++ b/tests/brain_test.py
@@ -1,12 +1,14 @@
 # If you publish results that make use of this software or the Birth Annotator
 # for Budding Yeast algorithm, please cite:
-# Julian M J Pietsch, Alán Muñoz, Diane Adjavon, Ivan B N Clark, Peter S
-# Swain, 2021, Birth Annotator for Budding Yeast (in preparation).
+# Pietsch, J.M.J., Muñoz, A.F., Adjavon, D.-Y.A., Farquhar, I., Clark, I.B.N.,
+# and Swain, P.S. (2023). Determining growth rates from bright-field images of
+# budding cells through identifying overlaps. eLife. 12:e79812.
+# https://doi.org/10.7554/eLife.79812
 # 
 # 
 # The MIT License (MIT)
 # 
-# Copyright (c) Julian Pietsch, Alán Muñoz and Diane Adjavon 2021
+# Copyright (c) Julian Pietsch, Alán Muñoz and Diane Adjavon 2023
 # 
 # Permission is hereby granted, free of charge, to any person obtaining a copy
 # of this software and associated documentation files (the "Software"), to
@@ -34,13 +36,15 @@ import json
 
 import baby
 from baby.brain import BabyBrain
-from baby.morph_thresh_seg import MorphSegGrouped
+from baby.morph_thresh_seg import MorphSegGrouped, SegmentationParameters
 
 #from .conftest import BASE_DIR
 
-#MODEL_PATH = BASE_DIR / 'models'
-MODEL_PATH = baby.model_path()
-DEFAULT_MODELSET = 'evolve_brightfield_60x_5z'
+DEFAULT_MODELSET = 'yeast-alcatras-brightfield-EMCCD-60x-5z'
+TEST_MODELSETS = ['yeast-alcatras-brightfield-EMCCD-60x-5z',
+                  'yeast-alcatras-brightfield-sCMOS-60x-5z']
+                  #'ecoli-mm-phase-sCMOS-100x-1z']
+
 
 eqlen_outkeys = {
     'angles', 'radii', 'cell_label', 'edgemasks', 'ellipse_dims', 'volumes',
@@ -52,9 +56,7 @@ eqlen_outkeys = {
 def bb(modelsets, tf_session_graph):
     # Attempt to load default evolve model
     tf_session, tf_graph = tf_session_graph
-    return BabyBrain(session=tf_session,
-                     graph=tf_graph,
-                     **modelsets[DEFAULT_MODELSET])
+    return modelsets.get(DEFAULT_MODELSET, session=tf_session, graph=tf_graph)
 
 
 @pytest.fixture(scope='module')
@@ -62,24 +64,44 @@ def imgstack(imgs_evolve60):
     return np.stack([v['Brightfield'][0] for v in imgs_evolve60.values()])
 
 
-def test_modelsets(modelsets):
+def test_modelsets(modelsets, verify_all_modelsets):
+    # Ensure that at least the test model sets are present
+    all_mset_ids = modelsets.ids()
+    assert all([mset_id in all_mset_ids for mset_id in TEST_MODELSETS])
+
     bb_args = inspect.getfullargspec(BabyBrain.__init__).args
-    msg_args = inspect.getfullargspec(MorphSegGrouped.__init__).args
+    msg_args = set(SegmentationParameters()._fields)
+
+    if verify_all_modelsets:
+        modelsets.update('all', force=False)
+    else:
+        modelsets.update(TEST_MODELSETS, force=False)
 
-    for mset in modelsets.values():
+    for mset_info in modelsets.specifications().values():
         # Make sure all parameters match the BabyBrain and MorphSegGrouped
         # signatures
+        mset = mset_info['brain_params']
         assert set(mset.keys()).issubset(bb_args)
         params = mset.get('params', {})
-        assert set(params.keys()).issubset(msg_args)
+        if type(params) == dict:
+            assert set(params.keys()).issubset(msg_args)
+
+    share_path = modelsets.LOCAL_MODELSETS_PATH / modelsets.SHARE_PATH
+    for mset_id, mset_info in modelsets.specifications(local=True).items():
+        mset_path = modelsets.LOCAL_MODELSETS_PATH / mset_id
+        mset = mset_info['brain_params']
+
+        assert mset_path.is_dir()
+        params = mset.get('params', {})
+        if type(params) != dict and type(params) != SegmentationParameters:
+            assert ((mset_path / params).is_file() or
+                    (share_path / params).is_file())
 
         # Make sure all model files exist
         for k, v in mset.items():
             if k.endswith('_file'):
-                assert (MODEL_PATH / v).is_file()
-
-    # Ensure that the default test model is present
-    assert DEFAULT_MODELSET in modelsets
+                assert ((mset_path / v).is_file() or
+                        (share_path / v).is_file())
 
 
 def test_init(bb, imgstack):
@@ -92,7 +114,7 @@ def test_init(bb, imgstack):
             [len(o['centres']) == len(o[k]) for k in o if k in eqlen_outkeys])
 
 
-def test_segment(bb, imgstack):
+def test_evolve_segment(bb, imgstack):
     # Test segment with all options enabled
     output = bb.segment(imgstack,
                         yield_edgemasks=True,
@@ -105,6 +127,39 @@ def test_segment(bb, imgstack):
             [len(o['centres']) == len(o[k]) for k in o if k in eqlen_outkeys])
 
 
+def test_prime_segment(bb_prime60, imgs_prime60):
+    imgstack = np.stack([v['Brightfield'][0] for v in imgs_prime60.values()])
+    # Test segment with all options enabled
+    output = bb_prime60.segment(imgstack,
+                                yield_edgemasks=True,
+                                yield_masks=True,
+                                yield_preds=True,
+                                yield_volumes=True,
+                                refine_outlines=True)
+    for o in output:
+        assert all(
+            [len(o['centres']) == len(o[k]) for k in o if k in eqlen_outkeys])
+
+
+def test_mm_segment(bb_mmscmos, imgs_mmscmos):
+    # The sample mother machine sCMOS images have different shapes so cannot
+    # be stacked, so segment each image separately
+    for imgpair in imgs_mmscmos.values():
+        img = imgpair['Brightfield'][0]
+        # Test segment with all options enabled except refine_outlines, which is
+        # not yet available for the cartesian splines used for E. coli
+        o = bb_mmscmos.segment(img[None, ...],
+                               yield_edgemasks=True,
+                               yield_masks=True,
+                               yield_preds=True,
+                               yield_volumes=True,
+                               refine_outlines=False)
+        # Expand generator and select first image
+        o = list(o)[0]
+        assert all(
+            [len(o['centres']) == len(o[k]) for k in o if k in eqlen_outkeys])
+
+
 def test_evolve_segment_and_track(bb, imgstack, imgs_evolve60):
     # Test stateless version
     output = bb.segment_and_track(imgstack,
diff --git a/tests/cnn_test.py b/tests/cnn_test.py
index d113088c05b404b49d1fd0fef29d9053cd8a7ba9..8fcfe3bcaf0c8e9d3d0d399cac0766c492e7f1b0 100644
--- a/tests/cnn_test.py
+++ b/tests/cnn_test.py
@@ -1,12 +1,14 @@
 # If you publish results that make use of this software or the Birth Annotator
 # for Budding Yeast algorithm, please cite:
-# Julian M J Pietsch, Alán Muñoz, Diane Adjavon, Ivan B N Clark, Peter S
-# Swain, 2021, Birth Annotator for Budding Yeast (in preparation).
+# Pietsch, J.M.J., Muñoz, A.F., Adjavon, D.-Y.A., Farquhar, I., Clark, I.B.N.,
+# and Swain, P.S. (2023). Determining growth rates from bright-field images of
+# budding cells through identifying overlaps. eLife. 12:e79812.
+# https://doi.org/10.7554/eLife.79812
 # 
 # 
 # The MIT License (MIT)
 # 
-# Copyright (c) Julian Pietsch, Alán Muñoz and Diane Adjavon 2021
+# Copyright (c) Julian Pietsch, Alán Muñoz and Diane Adjavon 2023
 # 
 # Permission is hereby granted, free of charge, to any person obtaining a copy
 # of this software and associated documentation files (the "Software"), to
@@ -40,18 +42,17 @@ def test_evolve_predict(bb_evolve60, imgs_evolve60, save_cnn_predictions,
     imgstack = np.stack([robust_norm(*v['Brightfield'])
                          for v in imgs_evolve60.values()])
 
-    preds = bb_evolve60.morph_predict(imgstack)
-    assert len(preds) == len(bb_evolve60.flattener.names())
-    assert all([p_out.shape[:3] == imgstack.shape[:3] for p_out in preds])
+    preds = list(bb_evolve60.morph_predict(imgstack))
+    assert len(preds) == len(imgstack)
+    npredchan = len(bb_evolve60.flattener.names())
+    assert all([len(pred) == npredchan for pred in preds]) 
+    assert all([pred.shape[1:] == imgstack.shape[1:3] for pred in preds])
 
-    morph_preds = split_batch_pred(preds)
-    assert len(morph_preds) == len(imgstack)
-
-    assert all([pred.max() <= 1 and pred.min() >= 0 for pred in morph_preds])
+    assert all([pred.max() <= 1 and pred.min() >= 0 for pred in preds])
 
     if save_cnn_predictions:
         # Save prediction output as 16 bit tiled png
-        for pred, (k, v) in zip(morph_preds, imgs_evolve60.items()):
+        for pred, (k, v) in zip(preds, imgs_evolve60.items()):
             _, info = v['Brightfield']
             info['channel'] = 'cnnpred'
             save_tiled_image(
diff --git a/tests/conftest.py b/tests/conftest.py
index cff9f5cd1e1213b7c1e803b4353fa18dedbe2bea..6e0053a3dc45b63be2938ca972d6bf1437edb77c 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -1,12 +1,14 @@
 # If you publish results that make use of this software or the Birth Annotator
 # for Budding Yeast algorithm, please cite:
-# Julian M J Pietsch, Alán Muñoz, Diane Adjavon, Ivan B N Clark, Peter S
-# Swain, 2021, Birth Annotator for Budding Yeast (in preparation).
+# Pietsch, J.M.J., Muñoz, A.F., Adjavon, D.-Y.A., Farquhar, I., Clark, I.B.N.,
+# and Swain, P.S. (2023). Determining growth rates from bright-field images of
+# budding cells through identifying overlaps. eLife. 12:e79812.
+# https://doi.org/10.7554/eLife.79812
 # 
 # 
 # The MIT License (MIT)
 # 
-# Copyright (c) Julian Pietsch, Alán Muñoz and Diane Adjavon 2021
+# Copyright (c) Julian Pietsch, Alán Muñoz and Diane Adjavon 2023
 # 
 # Permission is hereby granted, free of charge, to any person obtaining a copy
 # of this software and associated documentation files (the "Software"), to
@@ -26,9 +28,11 @@
 # FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS
 # IN THE SOFTWARE.
 from pathlib import Path
+import os
 
 import baby
 import pytest
+from baby import modelsets as modelsets_module
 from baby.brain import BabyBrain
 from baby.io import load_paired_images
 
@@ -48,15 +52,17 @@ def pytest_addoption(parser):
         "--save-crawler-output", action="store_true", default=False,
         help="When running crawler tests, save the predictions to tmp dir"
     )
+    parser.addoption(
+        "--verify-all-modelsets", action="store_true", default=False,
+        help="Download all models to confirm model set files exist"
+    )
 
-@pytest.fixture(scope='session')
-def model_dir():
-    return baby.model_path()
 
 @pytest.fixture(scope='session')
 def image_dir():
     return IMAGE_DIR
 
+
 @pytest.fixture(scope='session')
 def save_cnn_predictions(request):
     return request.config.getoption("--save-cnn-predictions")
@@ -72,14 +78,25 @@ def save_segment_outlines(request):
     return request.config.getoption("--save-segment-outlines")
 
 
+@pytest.fixture(scope='session')
+def verify_all_modelsets(request):  # value of --verify-all-modelsets
+    return request.config.getoption("--verify-all-modelsets")
+
+
 @pytest.fixture(scope='session')
 def imgs_evolve60():
     return load_paired_images(IMAGE_DIR.glob('evolve_*.png'))
 
 
 @pytest.fixture(scope='session')
-def modelsets():
-    return baby.modelsets()
+def imgs_prime60():
+    return load_paired_images(IMAGE_DIR.glob('prime95b_*.png'))
+
+
+@pytest.fixture(scope='session')
+def imgs_mmscmos():  # paired images matching IMAGE_DIR/mmsCMOS_*.png
+    return load_paired_images(IMAGE_DIR.glob('mmsCMOS_*.png'))
+
 
 @pytest.fixture(scope='session')
 def tf_session_graph():
@@ -111,15 +128,36 @@ def tf_session_graph():
     return tf_session, tf_graph
 
 
+@pytest.fixture(scope='session')
+def modelsets(tmp_path_factory):
+    """Return `baby.modelsets`, redirected to a temp dir if no env path."""
+    envpath = os.environ.get(modelsets_module.ENV_VAR_MODELSETS_PATH)
+    if not envpath:
+        envpath = tmp_path_factory.mktemp('modelsets')
+        # Patch modelsets module with temporary storage path
+        modelsets_module.ENV_LOCAL_MODELSETS_PATH = envpath
+        modelsets_module.LOCAL_MODELSETS_PATH = envpath
+        modelsets_module.LOCAL_MODELSETS_CACHE = (
+            envpath / modelsets_module.MODELSETS_FILENAME)
+    return modelsets_module
+
+
 @pytest.fixture(scope='session')
 def bb_evolve60(modelsets, tf_session_graph):
     tf_session, tf_graph = tf_session_graph
-    return BabyBrain(session=tf_session, graph=tf_graph,
-                     **modelsets['evolve_brightfield_60x_5z'])
+    return modelsets.get('yeast-alcatras-brightfield-EMCCD-60x-5z',
+                         session=tf_session, graph=tf_graph)
 
 
 @pytest.fixture(scope='module')
 def bb_prime60(modelsets, tf_session_graph):
     tf_session, tf_graph = tf_session_graph
-    return BabyBrain(session=tf_session, graph=tf_graph,
-                     **modelsets['prime95b_brightfield_60x_5z'])
+    return modelsets.get('yeast-alcatras-brightfield-sCMOS-60x-5z',
+                         session=tf_session, graph=tf_graph)
+
+
+@pytest.fixture(scope='module')
+def bb_mmscmos(modelsets, tf_session_graph):
+    sess, graph = tf_session_graph  # unpack the (session, graph) pair
+    modelset_id = 'ecoli-mm-phase-sCMOS-100x-1z'
+    return modelsets.get(modelset_id, session=sess, graph=graph)
diff --git a/tests/crawler_test.py b/tests/crawler_test.py
index 97424c12ff6202142384f0cf05e42e193e707e52..5b51fed959e659b3367af0c101cbdf10582d8c5d 100644
--- a/tests/crawler_test.py
+++ b/tests/crawler_test.py
@@ -1,12 +1,14 @@
 # If you publish results that make use of this software or the Birth Annotator
 # for Budding Yeast algorithm, please cite:
-# Julian M J Pietsch, Alán Muñoz, Diane Adjavon, Ivan B N Clark, Peter S
-# Swain, 2021, Birth Annotator for Budding Yeast (in preparation).
+# Pietsch, J.M.J., Muñoz, A.F., Adjavon, D.-Y.A., Farquhar, I., Clark, I.B.N.,
+# and Swain, P.S. (2023). Determining growth rates from bright-field images of
+# budding cells through identifying overlaps. eLife. 12:e79812.
+# https://doi.org/10.7554/eLife.79812
 # 
 # 
 # The MIT License (MIT)
 # 
-# Copyright (c) Julian Pietsch, Alán Muñoz and Diane Adjavon 2021
+# Copyright (c) Julian Pietsch, Alán Muñoz and Diane Adjavon 2023
 # 
 # Permission is hereby granted, free of charge, to any person obtaining a copy
 # of this software and associated documentation files (the "Software"), to
diff --git a/tests/images/evolve_testA_preds.png b/tests/images/evolve_testA_preds.png
index 23ef2498373b16fa2da9c79fb6aa9bfa1717f03a..3c0d72ce0106dfd6a0d7c1ddce81271384d3a043 100644
Binary files a/tests/images/evolve_testA_preds.png and b/tests/images/evolve_testA_preds.png differ
diff --git a/tests/images/evolve_testB_preds.png b/tests/images/evolve_testB_preds.png
index 5ab187c2e34472477e43f87a7a977a2e3eb25db6..cca0dd9c7a4755405577d00ced8d5d943197eba8 100644
Binary files a/tests/images/evolve_testB_preds.png and b/tests/images/evolve_testB_preds.png differ
diff --git a/tests/images/evolve_testC_preds.png b/tests/images/evolve_testC_preds.png
index e6101f64115de1e065b1e0a76a9230e5b91a6c68..d0e9fe8575337a67482f55ae079cc030f3aea2c4 100644
Binary files a/tests/images/evolve_testC_preds.png and b/tests/images/evolve_testC_preds.png differ
diff --git a/tests/images/evolve_testD_preds.png b/tests/images/evolve_testD_preds.png
index 2d762d03cff83b05a10f4a0350ece867553c17b3..cfdf9ced1437a838f7cca5276ee3af3f6fc23e60 100644
Binary files a/tests/images/evolve_testD_preds.png and b/tests/images/evolve_testD_preds.png differ
diff --git a/tests/images/evolve_testE_preds.png b/tests/images/evolve_testE_preds.png
index d1a785e82308336068445d01ea188611f2d1fd7b..d1fc0d0bb3feab593dc568f31d1ce73b1d39f57d 100644
Binary files a/tests/images/evolve_testE_preds.png and b/tests/images/evolve_testE_preds.png differ
diff --git a/tests/images/evolve_testF_tp3_preds.png b/tests/images/evolve_testF_tp3_preds.png
index f46d2b2ccfd49aaf7d3fdae2d432dd76a047f621..0541c2cabef40796b25967063da92d5e96c65780 100644
Binary files a/tests/images/evolve_testF_tp3_preds.png and b/tests/images/evolve_testF_tp3_preds.png differ
diff --git a/tests/images/evolve_testF_tp4_preds.png b/tests/images/evolve_testF_tp4_preds.png
index e884d8151baf3e720aa69a0aaf528ec73d226fbe..124a87dc7fc35e33d292fd04176f917bfd22da97 100644
Binary files a/tests/images/evolve_testF_tp4_preds.png and b/tests/images/evolve_testF_tp4_preds.png differ
diff --git a/tests/images/evolve_testG_tp1_preds.png b/tests/images/evolve_testG_tp1_preds.png
index 2d79c356bbe77e9dc25b05ae88286b58cab3c378..3a628f025bcbb1bf311d06b6686237f0a2c6ffd3 100644
Binary files a/tests/images/evolve_testG_tp1_preds.png and b/tests/images/evolve_testG_tp1_preds.png differ
diff --git a/tests/images/evolve_testG_tp2_preds.png b/tests/images/evolve_testG_tp2_preds.png
index debec78ca7e8f63675e03393c827965e89915ce2..e49496b6654d6b5853b8505b08ebcadac259d8de 100644
Binary files a/tests/images/evolve_testG_tp2_preds.png and b/tests/images/evolve_testG_tp2_preds.png differ
diff --git a/tests/images/evolve_testG_tp3_preds.png b/tests/images/evolve_testG_tp3_preds.png
index 0653f82ffdfe8ff1c7596766aa1c7e9f0793573c..1cde31e9d99c91f0fddcfeb3955881bfcdbcc2ee 100644
Binary files a/tests/images/evolve_testG_tp3_preds.png and b/tests/images/evolve_testG_tp3_preds.png differ
diff --git a/tests/images/evolve_testG_tp4_preds.png b/tests/images/evolve_testG_tp4_preds.png
index 3b226b5b8d636df2e3f37cf50799c1669013d75b..b47b3a42eb78f88e9f923903913b58eaed32fbc2 100644
Binary files a/tests/images/evolve_testG_tp4_preds.png and b/tests/images/evolve_testG_tp4_preds.png differ
diff --git a/tests/images/evolve_testG_tp5_preds.png b/tests/images/evolve_testG_tp5_preds.png
index 4e6af49f652c613bcf8d47a050cba55115ac3d80..8bb83faeeb01a40fdb3b79fbda4a457f97b98e47 100644
Binary files a/tests/images/evolve_testG_tp5_preds.png and b/tests/images/evolve_testG_tp5_preds.png differ
diff --git a/tests/images/mmsCMOS_testA_Brightfield.png b/tests/images/mmsCMOS_testA_Brightfield.png
new file mode 100644
index 0000000000000000000000000000000000000000..b880c1a5238c03a7edf3ba2b88e1609f06585559
Binary files /dev/null and b/tests/images/mmsCMOS_testA_Brightfield.png differ
diff --git a/tests/images/mmsCMOS_testA_segoutlines.png b/tests/images/mmsCMOS_testA_segoutlines.png
new file mode 100644
index 0000000000000000000000000000000000000000..132839155dddf20c90a4ad18998ee71a1feecd54
Binary files /dev/null and b/tests/images/mmsCMOS_testA_segoutlines.png differ
diff --git a/tests/images/mmsCMOS_testB_Brightfield.png b/tests/images/mmsCMOS_testB_Brightfield.png
new file mode 100644
index 0000000000000000000000000000000000000000..dc00c401f7d768c1419b67d7005f042f69a66975
Binary files /dev/null and b/tests/images/mmsCMOS_testB_Brightfield.png differ
diff --git a/tests/images/mmsCMOS_testB_segoutlines.png b/tests/images/mmsCMOS_testB_segoutlines.png
new file mode 100644
index 0000000000000000000000000000000000000000..7a26bf512e9fa49712d4f93c3eedc2beda5a0f6d
Binary files /dev/null and b/tests/images/mmsCMOS_testB_segoutlines.png differ
diff --git a/tests/io_test.py b/tests/io_test.py
index e9d5c1b5a2fc4c6e685b979cafa491ddf99a5022..c40a587484393e3e5350b2979912f8d82c0d548a 100644
--- a/tests/io_test.py
+++ b/tests/io_test.py
@@ -1,12 +1,14 @@
 # If you publish results that make use of this software or the Birth Annotator
 # for Budding Yeast algorithm, please cite:
-# Julian M J Pietsch, Alán Muñoz, Diane Adjavon, Ivan B N Clark, Peter S
-# Swain, 2021, Birth Annotator for Budding Yeast (in preparation).
+# Pietsch, J.M.J., Muñoz, A.F., Adjavon, D.-Y.A., Farquhar, I., Clark, I.B.N.,
+# and Swain, P.S. (2023). Determining growth rates from bright-field images of
+# budding cells through identifying overlaps. eLife. 12:e79812.
+# https://doi.org/10.7554/eLife.79812
 # 
 # 
 # The MIT License (MIT)
 # 
-# Copyright (c) Julian Pietsch, Alán Muñoz and Diane Adjavon 2021
+# Copyright (c) Julian Pietsch, Alán Muñoz and Diane Adjavon 2023
 # 
 # Permission is hereby granted, free of charge, to any person obtaining a copy
 # of this software and associated documentation files (the "Software"), to
@@ -38,6 +40,7 @@ def test_tiled_save_load(tmp_path):
     save_tiled_image(img, testfile, info=info)
     loaded_img, loaded_info = load_tiled_image(testfile)
 
+    assert loaded_img.dtype == 'uint16'
     assert (loaded_img == img).all()
     assert 'experimentID' in loaded_info
     assert loaded_info['experimentID'] == info['experimentID']
diff --git a/tests/modelsets_test.py b/tests/modelsets_test.py
new file mode 100644
index 0000000000000000000000000000000000000000..41b42beb37333095d69f99e362eda1bf22dfe23a
--- /dev/null
+++ b/tests/modelsets_test.py
@@ -0,0 +1,123 @@
+# If you publish results that make use of this software or the Birth Annotator
+# for Budding Yeast algorithm, please cite:
+# Pietsch, J.M.J., Muñoz, A.F., Adjavon, D.-Y.A., Farquhar, I., Clark, I.B.N.,
+# and Swain, P.S. (2023). Determining growth rates from bright-field images of
+# budding cells through identifying overlaps. eLife. 12:e79812.
+# https://doi.org/10.7554/eLife.79812
+# 
+# 
+# The MIT License (MIT)
+# 
+# Copyright (c) Julian Pietsch, Alán Muñoz and Diane Adjavon 2023
+# 
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to
+# deal in the Software without restriction, including without limitation the
+# rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
+# sell copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+# 
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+# 
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
+# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS
+# IN THE SOFTWARE.
+import pytest
+from pathlib import Path
+import os
+
+from baby import BabyBrain
+
+
+DEFAULT_MODELSET = 'yeast-alcatras-brightfield-sCMOS-60x-5z'
+
+
+@pytest.fixture(scope='module')
+def remote_modelsets(modelsets):  # evaluated once per test module
+    return modelsets.remote_modelsets()
+
+
+def test_remote_modelsets_available(remote_modelsets):
+    for section in ('models', 'shared'):  # required top-level entries
+        assert section in remote_modelsets
+
+
+def test_local_cache_dir(modelsets):
+    cache_dir = modelsets.LOCAL_MODELSETS_PATH
+    from_env = os.environ.get(modelsets.ENV_VAR_MODELSETS_PATH)
+    if from_env:  # an explicitly configured location must be honoured
+        assert cache_dir == Path(from_env)
+    else:  # otherwise the tests must not use the default location
+        assert cache_dir != modelsets.DEFAULT_LOCAL_MODELSETS_PATH
+    modelsets._ensure_local_path()
+    assert cache_dir.exists()
+
+
+def test_meta(modelsets, remote_modelsets):
+    """Compare locally listed ids and meta data against the remote."""
+    modelsets.ids()  # trigger creation of local model sets cache
+    assert modelsets.LOCAL_MODELSETS_CACHE.exists()
+    remote_models = remote_modelsets['models']
+    # Every locally listed id must also be present on the remote
+    unknown_ids = set(modelsets.ids()).difference(remote_models.keys())
+    assert len(unknown_ids) == 0
+
+    # Listed meta data matches that on remote
+    expected = {k: v['meta'] for k, v in remote_models.items()}
+    assert modelsets.meta() == expected
+
+
+def test_update(modelsets, remote_modelsets):
+    wanted = remote_modelsets['models'][DEFAULT_MODELSET]
+    modelsets.update([DEFAULT_MODELSET])  # update only the default set
+    target_dir = modelsets.LOCAL_MODELSETS_PATH / DEFAULT_MODELSET
+    assert target_dir.is_dir()
+    for name in wanted['files']:  # every remote-listed file must be local
+        assert target_dir.joinpath(name).is_file()
+
+
+def test_get_params(modelsets, remote_modelsets):
+    """Check `get_params` output and that it restores deleted files."""
+    remote_default = remote_modelsets['models'][DEFAULT_MODELSET]
+    local_params = modelsets.get_params(DEFAULT_MODELSET)
+    assert local_params == remote_default['brain_params']
+    localpath = modelsets.LOCAL_MODELSETS_PATH / DEFAULT_MODELSET
+
+    # Test auto-update for missing model set file
+    (localpath / modelsets.MODELSET_FILENAME).unlink()
+    modelsets.get_params(DEFAULT_MODELSET)
+    assert (localpath / modelsets.MODELSET_FILENAME).exists()
+
+    # Test auto-update for missing model file
+    target_file = [f for f in remote_default['files']
+                   if f != modelsets.MODELSET_FILENAME][0]
+    target_file = localpath / target_file
+    target_file.unlink()
+    modelsets.get_params(DEFAULT_MODELSET)
+    assert target_file.exists()
+
+
+def test_resolve(modelsets):
+    """Resolve one model-set-local file and one shared file."""
+    params = modelsets.get_params(DEFAULT_MODELSET)
+    root = modelsets.LOCAL_MODELSETS_PATH
+    cases = ((root / DEFAULT_MODELSET, params['morph_model_file']),
+             (root / modelsets.SHARE_PATH, params['celltrack_model_file']))
+    # Both files must exist before checking how they are resolved
+    for folder, fname in cases:
+        assert (folder / fname).is_file()
+    for folder, fname in cases:
+        resolved = modelsets.resolve(fname, DEFAULT_MODELSET)
+        assert folder / fname == resolved
+
+
+def test_get(modelsets):
+    """`get` should return a BabyBrain for the default model set."""
+    bb = modelsets.get(DEFAULT_MODELSET)
+    # isinstance rather than an exact type comparison, per PEP 8
+    assert isinstance(bb, BabyBrain)
diff --git a/tests/parallel_test.py b/tests/parallel_test.py
index f38e205fcc9d13450f037997c7e2e945bff8af65..0f01d398a578b91fee08d758e5bdb824bac65f35 100644
--- a/tests/parallel_test.py
+++ b/tests/parallel_test.py
@@ -1,12 +1,14 @@
 # If you publish results that make use of this software or the Birth Annotator
 # for Budding Yeast algorithm, please cite:
-# Julian M J Pietsch, Alán Muñoz, Diane Adjavon, Ivan B N Clark, Peter S
-# Swain, 2021, Birth Annotator for Budding Yeast (in preparation).
+# Pietsch, J.M.J., Muñoz, A.F., Adjavon, D.-Y.A., Farquhar, I., Clark, I.B.N.,
+# and Swain, P.S. (2023). Determining growth rates from bright-field images of
+# budding cells through identifying overlaps. eLife. 12:e79812.
+# https://doi.org/10.7554/eLife.79812
 # 
 # 
 # The MIT License (MIT)
 # 
-# Copyright (c) Julian Pietsch, Alán Muñoz and Diane Adjavon 2021
+# Copyright (c) Julian Pietsch, Alán Muñoz and Diane Adjavon 2023
 # 
 # Permission is hereby granted, free of charge, to any person obtaining a copy
 # of this software and associated documentation files (the "Software"), to
@@ -28,34 +30,30 @@
 import pickle
 from collections import namedtuple, Counter
 from os.path import isfile
+import json
+from pathlib import Path
 
 import baby
 import numpy as np
 import pytest
-from baby.brain import default_params
 from baby.io import load_paired_images
-from baby.morph_thresh_seg import MorphSegGrouped
+from baby.morph_thresh_seg import MorphSegGrouped, SegmentationParameters
 from baby.preprocessing import raw_norm, SegmentationFlattening
 from baby.tracker.core import MasterTracker
+from baby.utils import as_python_object
 
 from joblib import Parallel, delayed, parallel_backend
 
-MODEL_DIR = baby.model_path()
 
-
-def resolve_file(filename):
-    if not isfile(filename):
-        filename = MODEL_DIR / filename
-    assert isfile(filename)
-    return filename
+DEFAULT_MODELSET = 'yeast-alcatras-brightfield-EMCCD-60x-5z'
 
 
 @pytest.fixture(scope='module')
 def evolve60env(modelsets, image_dir):
-    mset = modelsets['evolve_brightfield_60x_5z']
+    mset = modelsets.get_params(DEFAULT_MODELSET)
 
     # Load flattener
-    ff = resolve_file(mset['flattener_file'])
+    ff = modelsets.resolve(mset['flattener_file'], DEFAULT_MODELSET)
     flattener = SegmentationFlattening(ff)
 
     tnames = flattener.names()
@@ -64,10 +62,16 @@ def evolve60env(modelsets, image_dir):
     i_bud = tnames.index(bud_target)
 
     # Load BabyBrain param defaults
-    params = default_params.copy()
-    params.update(mset.get('params', {}))
+    params = mset['params']
+    if type(params) == dict:
+        params = SegmentationParameters(**params)
+    if type(params) != SegmentationParameters:
+        param_file = modelsets.resolve(mset['params'], DEFAULT_MODELSET)
+        with open(param_file, 'rt') as f:
+            params = json.load(f, object_hook=as_python_object)
+    assert type(params) == SegmentationParameters
 
-    segmenter = MorphSegGrouped(flattener, return_masks=True, **params)
+    segmenter = MorphSegGrouped(flattener, params=params, return_masks=True)
 
     # Load CNN outputs
     impairs = load_paired_images(image_dir.glob('evolve_test[FG]_tp*.png'),
@@ -78,10 +82,10 @@ def evolve60env(modelsets, image_dir):
     ]), sorted([k for k in impairs.keys() if k.startswith('evolve_testG')]))
 
     # Load the celltrack and budassign models
-    ctm_file = resolve_file(mset['celltrack_model_file'])
+    ctm_file = modelsets.resolve(mset['celltrack_model_file'], DEFAULT_MODELSET)
     with open(ctm_file, 'rb') as f:
         ctm = pickle.load(f)
-    bam_file = resolve_file(mset['budassign_model_file'])
+    bam_file = modelsets.resolve(mset['budassign_model_file'], DEFAULT_MODELSET)
     with open(bam_file, 'rb') as f:
         bam = pickle.load(f)
 
diff --git a/tests/segment_test.py b/tests/segment_test.py
index 1dc16db33d0b4f591a698d813c9da82f66632a36..baada9265ef233b769bbb45f9bf0f6d298d41acd 100644
--- a/tests/segment_test.py
+++ b/tests/segment_test.py
@@ -1,12 +1,14 @@
 # If you publish results that make use of this software or the Birth Annotator
 # for Budding Yeast algorithm, please cite:
-# Julian M J Pietsch, Alán Muñoz, Diane Adjavon, Ivan B N Clark, Peter S
-# Swain, 2021, Birth Annotator for Budding Yeast (in preparation).
+# Pietsch, J.M.J., Muñoz, A.F., Adjavon, D.-Y.A., Farquhar, I., Clark, I.B.N.,
+# and Swain, P.S. (2023). Determining growth rates from bright-field images of
+# budding cells through identifying overlaps. eLife. 12:e79812.
+# https://doi.org/10.7554/eLife.79812
 # 
 # 
 # The MIT License (MIT)
 # 
-# Copyright (c) Julian Pietsch, Alán Muñoz and Diane Adjavon 2021
+# Copyright (c) Julian Pietsch, Alán Muñoz and Diane Adjavon 2023
 # 
 # Permission is hereby granted, free of charge, to any person obtaining a copy
 # of this software and associated documentation files (the "Software"), to
@@ -28,6 +30,9 @@
 import pytest
 
 from os.path import isfile
+from pathlib import Path
+import json
+import inspect
 import numpy as np
 from scipy.ndimage import binary_fill_holes
 from itertools import chain
@@ -37,13 +42,17 @@ from baby.io import load_paired_images, save_tiled_image
 from baby.errors import BadParam
 from baby.preprocessing import (raw_norm, seg_norm, dwsquareconn,
                                 SegmentationFlattening)
-from baby.brain import default_params
-from baby.morph_thresh_seg import MorphSegGrouped, SegmentationOutput
+from baby.morph_thresh_seg import (MorphSegGrouped, SegmentationOutput,
+                                   SegmentationParameters)
+from baby.utils import as_python_object
+from baby import segmentation
 from baby.segmentation import morph_seg_grouped
 from baby.performance import (calc_IoUs, best_IoU, calc_AP,
                               flattener_seg_probs)
 
-#from .conftest import MODEL_DIR, IMAGE_DIR
+
+DEFAULT_MODELSET = 'yeast-alcatras-brightfield-EMCCD-60x-5z'
+
 
 # Tuple for variables needed to test segmentation
 SegmentationEnv = namedtuple(
@@ -51,6 +60,19 @@ SegmentationEnv = namedtuple(
     ['flattener', 'cparams', 'fparams', 'cnn_out', 'truth', 'imnames'])
 
 
+# Old default parameters; presumably the former baby.brain.default_params
+DEFAULT_PARAMETERS = SegmentationParameters(
+    interior_threshold=(0.7, 0.5, 0.5),
+    nclosing=(1, 0, 0),
+    nopening=(1, 0, 0),
+    connectivity=(2, 2, 1),
+    pedge_thresh=0.001,
+    fit_radial=True,
+    edge_sub_dilations=1,
+    use_group_thresh=True,
+    group_thresh_expansion=0.1)
+
+
 def compare_edges_and_masks(edges, masks):
     if len(edges) == 0 and len(masks) == 0:
         return 1, 1
@@ -95,24 +117,33 @@ def run_performance_checks(seg_outputs, cnn_outputs, flattener, truth):
 
 
 @pytest.fixture(scope='module')
-def evolve60env(modelsets, model_dir, image_dir):
-    mset = modelsets['evolve_brightfield_60x_5z']
+def evolve60env(modelsets, image_dir):
+    mset = modelsets.get_params(DEFAULT_MODELSET)
 
     # Load flattener
-    ff = mset['flattener_file']
-    if not isfile(ff):
-        ff = model_dir / ff
-    assert isfile(ff)
+    ff = modelsets.resolve(mset['flattener_file'], DEFAULT_MODELSET)
     flattener = SegmentationFlattening(ff)
 
     # Load BabyBrain param defaults
-    cparams = default_params.copy()
-    cparams.update(mset.get('params', {}))
+    params = mset['params']
+    if type(params) == dict:
+        params = SegmentationParameters(**params)
+    if type(params) != SegmentationParameters:
+        param_file = modelsets.resolve(mset['params'], DEFAULT_MODELSET)
+        with open(param_file, 'rt') as f:
+            params = json.load(f, object_hook=as_python_object)
+    assert type(params) == SegmentationParameters
+    cparams = params
 
     # Convert to params compatible with morph_seg_grouped
-    fparams = cparams.copy()
-    del fparams['edge_sub_dilations']
+    msg_args = inspect.signature(morph_seg_grouped).parameters.keys()
+    fparams = cparams._asdict()
+    fparams = {k: v for k, v in fparams.items() if k in msg_args}
     fparams['ingroup_edge_segment'] = True
+    fparams['containment_func'] = getattr(segmentation,
+                                          fparams['containment_func'])
+    if fparams['cellgroups'] is None:
+        fparams['cellgroups'] = ['large', 'medium', 'small']
 
     # Load CNN outputs
     impairs = load_paired_images(image_dir.glob('evolve_*.png'),
@@ -197,9 +228,9 @@ def test_segmenter_bbparams_empty(evolve60env):
     params = evolve60env.cparams
     ntargets = len(flattener.names())
     segmenter = MorphSegGrouped(flattener,
+                                params=params,
                                 return_masks=True,
-                                return_coords=True,
-                                **params)
+                                return_coords=True)
     out = segmenter.segment(np.zeros((ntargets, 81, 81)))
     assert tuple(len(o) for o in out) == (0, 0, 0, 0)
 
@@ -254,9 +285,9 @@ def test_segfunc_bbparams_preds(evolve60env, save_segoutlines):
 def test_segmenter_bbparams_preds(evolve60env, save_segoutlines):
     flattener, params, _, cnn_out, truth, imnames = evolve60env
     segmenter = MorphSegGrouped(flattener,
+                                params=params,
                                 return_masks=True,
-                                return_coords=True,
-                                **params)
+                                return_coords=True)
     seg_outputs = [segmenter.segment(pred) for pred in cnn_out]
     save_segoutlines(seg_outputs, imnames)
 
@@ -290,9 +321,9 @@ def test_segfunc_refined_preds(evolve60env, save_segoutlines):
 def test_segmenter_refined_preds(evolve60env, save_segoutlines):
     flattener, params, _, cnn_out, truth, imnames = evolve60env
     segmenter = MorphSegGrouped(flattener,
+                                params=params,
                                 return_masks=True,
-                                return_coords=True,
-                                **params)
+                                return_coords=True)
     seg_outputs = [
         segmenter.segment(pred, refine_outlines=True) for pred in cnn_out
     ]
diff --git a/tests/tracker_test.py b/tests/tracker_test.py
index 413f7e1c68a0771f8c507d7e59964b40de91d0c4..3a36c03fb4d9570bb04e79a249532e4436fb9e7a 100644
--- a/tests/tracker_test.py
+++ b/tests/tracker_test.py
@@ -1,12 +1,14 @@
 # If you publish results that make use of this software or the Birth Annotator
 # for Budding Yeast algorithm, please cite:
-# Julian M J Pietsch, Alán Muñoz, Diane Adjavon, Ivan B N Clark, Peter S
-# Swain, 2021, Birth Annotator for Budding Yeast (in preparation).
+# Pietsch, J.M.J., Muñoz, A.F., Adjavon, D.-Y.A., Farquhar, I., Clark, I.B.N.,
+# and Swain, P.S. (2023). Determining growth rates from bright-field images of
+# budding cells through identifying overlaps. eLife. 12:e79812.
+# https://doi.org/10.7554/eLife.79812
 # 
 # 
 # The MIT License (MIT)
 # 
-# Copyright (c) Julian Pietsch, Alán Muñoz and Diane Adjavon 2021
+# Copyright (c) Julian Pietsch, Alán Muñoz and Diane Adjavon 2023
 # 
 # Permission is hereby granted, free of charge, to any person obtaining a copy
 # of this software and associated documentation files (the "Software"), to
@@ -28,34 +30,29 @@
 import pickle
 from collections import namedtuple, Counter
 from os.path import isfile
+import json
+from pathlib import Path
 
 import baby
 import numpy as np
 import pytest
-from baby.brain import default_params
 from baby.io import load_paired_images
-from baby.morph_thresh_seg import MorphSegGrouped
+from baby.morph_thresh_seg import MorphSegGrouped, SegmentationParameters
 from baby.preprocessing import raw_norm, SegmentationFlattening
 from baby.tracker.core import MasterTracker
+from baby.utils import as_python_object
 
-MODEL_DIR = baby.model_path()
 
 TrackerEnv = namedtuple('TrackerEnv', ['masks', 'p_budneck', 'p_bud'])
-
-
-def resolve_file(filename):
-    if not isfile(filename):
-        filename = MODEL_DIR / filename
-    assert isfile(filename)
-    return filename
+DEFAULT_MODELSET = 'yeast-alcatras-brightfield-EMCCD-60x-5z'
 
 
 @pytest.fixture(scope='module')
 def evolve60env(modelsets, image_dir):
-    mset = modelsets['evolve_brightfield_60x_5z']
+    mset = modelsets.get_params(DEFAULT_MODELSET)
 
     # Load flattener
-    ff = resolve_file(mset['flattener_file'])
+    ff = modelsets.resolve(mset['flattener_file'], DEFAULT_MODELSET)
     flattener = SegmentationFlattening(ff)
 
     tnames = flattener.names()
@@ -64,10 +61,16 @@ def evolve60env(modelsets, image_dir):
     i_bud = tnames.index(bud_target)
 
     # Load BabyBrain param defaults
-    params = default_params.copy()
-    params.update(mset.get('params', {}))
+    params = mset['params']
+    if type(params) == dict:
+        params = SegmentationParameters(**params)
+    if type(params) != SegmentationParameters:
+        param_file = modelsets.resolve(mset['params'], DEFAULT_MODELSET)
+        with open(param_file, 'rt') as f:
+            params = json.load(f, object_hook=as_python_object)
+    assert type(params) == SegmentationParameters
 
-    segmenter = MorphSegGrouped(flattener, return_masks=True, **params)
+    segmenter = MorphSegGrouped(flattener, params=params, return_masks=True)
 
     # Load CNN outputs
     impairs = load_paired_images(image_dir.glob('evolve_test[FG]_tp*.png'),
@@ -94,10 +97,10 @@ def evolve60env(modelsets, image_dir):
     trkF, trkG = trks
 
     # Load the celltrack and budassign models
-    ctm_file = resolve_file(mset['celltrack_model_file'])
+    ctm_file = modelsets.resolve(mset['celltrack_model_file'], DEFAULT_MODELSET)
     with open(ctm_file, 'rb') as f:
         ctm = pickle.load(f)
-    bam_file = resolve_file(mset['budassign_model_file'])
+    bam_file = modelsets.resolve(mset['budassign_model_file'], DEFAULT_MODELSET)
     with open(bam_file, 'rb') as f:
         bam = pickle.load(f)