diff --git a/.gitignore b/.gitignore index fb11995ecf7a23417b7105faac4fd539370d9dbb..ef97f288f045c275e9d39282062406480412353a 100644 --- a/.gitignore +++ b/.gitignore @@ -1,4 +1,5 @@ *.pyc +*.asv .*.sw? *~ .DS_Store @@ -13,3 +14,5 @@ baby-server-test.png *.egg-info/ **/venv/ /tf1_venv/ +/dist +/tests/test-modelset-cache/ diff --git a/LICENSE.txt b/LICENSE.txt index df6ab2d7a11345bfd94c81ba9e31d37c44a48791..e95c479f323c327c150ff2f9276739747314ad6e 100644 --- a/LICENSE.txt +++ b/LICENSE.txt @@ -1,12 +1,14 @@ If you publish results that make use of this software or the Birth Annotator for Budding Yeast algorithm, please cite: -Julian M J Pietsch, Alán Muñoz, Diane Adjavon, Ivan B N Clark, Peter S Swain, -2021, Birth Annotator for Budding Yeast (in preparation). +Pietsch, J.M.J., Muñoz, A.F., Adjavon, D.-Y.A., Farquhar, I., Clark, I.B.N., +and Swain, P.S. (2023). Determining growth rates from bright-field images of +budding cells through identifying overlaps. eLife. 12:e79812. +https://doi.org/10.7554/eLife.79812 The MIT License (MIT) -Copyright (c) Julian Pietsch, Alán Muñoz and Diane Adjavon 2021 +Copyright (c) Julian Pietsch, Alán Muñoz and Diane Adjavon 2023 Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal diff --git a/MANIFEST.in b/MANIFEST.in deleted file mode 100644 index bb4ac0c66c1510fa03ad47f7727055fcdcfbd778..0000000000000000000000000000000000000000 --- a/MANIFEST.in +++ /dev/null @@ -1,4 +0,0 @@ -graft python/baby -include README.md -include LICENSE.txt -include AUTHORS.txt diff --git a/README.md b/README.md index c59fbc1ceef00ec54d49b71d2b4a7fbd64f01008..a77c33d75f38ab6a0bb883d767a105e8e04bdcf1 100644 --- a/README.md +++ b/README.md @@ -12,92 +12,97 @@ budding cells from bright-field stacks. The Birth Annotator for Budding Yeast The algorithm is described in: -Julian M J Pietsch, Alán F Muñoz, Diane-Yayra A Adjavon, Ivan B N Clark, Peter -S Swain, 2022, A label-free method to track individuals and lineages of -budding cells (in submission). +[Julian MJ Pietsch, Alán F Muñoz, Diane-Yayra A Adjavon, Iseabail Farquhar, +Ivan BN Clark, Peter S Swain. (2023). Determining growth rates from +bright-field images of budding cells through identifying overlaps. eLife. +12:e79812.](https://doi.org/10.7554/eLife.79812) + ## Installation -BABY can be used with Python versions 3.6-3.8 (see below for details). If you -wish to use the latest compatible versions of all packages, BABY can simply be -installed by first obtaining this repository (e.g., `git clone -https://git.ecdf.ed.ac.uk/jpietsch/baby.git`), and then running pip on the -repository directory: +We recommend installing BABY and all its requirements in a virtual environment +(e.g., `conda create` if you are using Anaconda, or `python3 -m venv` otherwise). + +If you do not have a GPU and simply wish to use the latest compatible versions +of all packages, BABY can be installed by first obtaining this repository +(e.g., `git clone https://git.ecdf.ed.ac.uk/jpietsch/baby.git`), and then +using pip: ```bash -> cd baby -> pip install . +> pip install baby/ ``` -NB: The '.' is important! - -If you pull new changes, you need to update by running: `pip install -U .` from -within the repository directory. +NB: You can update by running: `pip install -U baby/`. *Developers:* You may prefer to install an editable version: ```bash -> pip install -e . +> pip install -e baby/ ``` -This avoids the need to run the update command. - -**Requirements for Python and TensorFlow** - -BABY requires Python 3 and [TensorFlow](https://www.tensorflow.org). The -models were trained in TensorFlow 1.14.0, but are compatible with versions of -TensorFlow up to 2.3.4. The required version of Python depends on the version -of TensorFlow you choose. We recommend either: +### Python and TensorFlow version -- Python 3.6 and TensorFlow 1.14, -- Python 3.7 and TensorFlow 1.15, or -- Python 3.8 and TensorFlow 2.3. +BABY requires Python 3 and [TensorFlow](https://www.tensorflow.org). +Different versions of TensorFlow have different Python version requirements. +You can find a table of matching versions +[here](https://www.tensorflow.org/install/source#tested_build_configurations). -In any case, it is recommended that you install TensorFlow and all other -required packages into a virtual environment (i.e., `conda create` if you are -using Anaconda, or `python3 -m venv` otherwise). +Our models were trained with TensorFlow version 2.8, but have been tested up +to version 2.14. -By default, BABY will trigger installation of the highest compatible version of -TensorFlow. If you want to use an earlier version as suggested above, then -first install that version in your virtual environment by running: +By default, BABY will trigger installation of the highest compatible version +of TensorFlow. If you want to use an earlier version, then first install that +version in your virtual environment by running: ```bash -> pip install tensorflow==1.14 +> pip install tensorflow==2.8 ``` and then follow the instructions for installing BABY as above. -**NB:** To make use of a GPU you should also follow the other [set up -instructions](https://www.tensorflow.org/install/gpu). +### Running with GPU -**NB:** For `tensorflow==1.14`, you will also need to downgrade the default -version of `h5py`: +To make use of a GPU you should follow the [TensorFlow set up +instructions](https://www.tensorflow.org/install/gpu) before installing BABY. -```bash -> pip uninstall h5py -> pip install h5py==2.9.0 -``` +BABY can make use of Metal on M1/M2 Macs by following the instructions +[here](https://developer.apple.com/metal/tensorflow-plugin/). + + +## Quickstart using the Python API -## Run using the Python API +The BABY algorithm makes use of several machine learning models that are +defined as a model set. Various model sets are available, and each has been +optimised for a particular species, microfluidics device, pixel size, channel +and number of input Z sections. -Create a new `BabyBrain` with one of the model sets. The `brain` contains -all the models and parameters for segmenting and tracking cells. +You can get a list of available model sets and the types of input they were +trained for using the `meta` function in the `modelsets` module: ```python ->>> from baby import BabyBrain, BabyCrawler, modelsets ->>> modelset = modelsets()['evolve_brightfield_60x_5z'] ->>> brain = BabyBrain(**modelset) +>>> from baby import modelsets +>>> modelsets.meta() +``` + +You then load your favourite model set as a `BabyBrain` object, which +coordinates all the models and parameters in the set to produce tracked and +segmented outlines from input images. You can get a `BabyBrain` for a given +model set using the `get` function in the `modelsets` module: + +```python +>>> brain = modelsets.get('yeast-alcatras-brightfield-EMCCD-60x-5z') ``` For each time course you want to process, instantiate a new `BabyCrawler`. The crawler keeps track of cells between time steps. ```python +>>> from baby import BabyCrawler >>> crawler = BabyCrawler(brain) ``` -Load an image time series (from the `tests` subdirectory in this example). The -image should have shape (x, y, z). +Load an image time series (from the `tests` subdirectory in this repository). +The image should have shape (x, y, z). ```python >>> from baby.io import load_tiled_image @@ -150,12 +155,6 @@ requests using: > baby-phone ``` -or on windows: - -``` -> baby-phone.exe -``` - Server runs by default on [http://0.0.0.0:5101](). HTTP requests need to be sent to the correct URL endpoint, but the HTTP API is currently undocumented. The primary client implementation is in Matlab. diff --git a/hatch.toml b/hatch.toml new file mode 100644 index 0000000000000000000000000000000000000000..10caa8e586007468886f88c811e30e50f17e84da --- /dev/null +++ b/hatch.toml @@ -0,0 +1,12 @@ +[envs.test] +python = "3.10" +dependencies = [ + "pytest", +] +[envs.test.env-vars] +BABY_MODELSETS_PATH = "tests/test-modelset-cache" + +[envs.test_cache] +template = "test" +[envs.test_cache.env-vars] +BABY_MODELSETS_PATH = "" diff --git a/pyproject.toml b/pyproject.toml new file mode 100644 index 0000000000000000000000000000000000000000..b336e92a4ee32960b0a9ac0de3820477adbf0646 --- /dev/null +++ b/pyproject.toml @@ -0,0 +1,68 @@ +[build-system] +requires = ["hatchling"] +build-backend = "hatchling.build" + +[project] +name = "baby-seg" +dynamic = ["version"] +description = "Birth Annotator for Budding Yeast" +readme = "README.md" +license = "MIT" +license-files = { paths = ["LICENSE.txt"] } +authors = [ + {name = "Julian Pietsch", email = "julian.pietsch@synmikro.mpi-marburg.mpg.de"} +] +requires-python = ">=3.6" +dependencies = [ + "aiohttp", + "gaussianprocessderivatives", + "imageio", + "matplotlib", + "numpy", + "pandas", + "pillow<9", + "requests", + "scikit-image", + "scikit-learn<1.3", + "scipy", + "tensorflow>=1.14", + "tensorflow-metal; platform_system == 'Darwin' and platform_machine == 'arm64'", + "tqdm", +] +classifiers = [ + "Programming Language :: Python :: 3", + "License :: OSI Approved :: MIT License", + "Operating System :: OS Independent", + "Intended Audience :: Science/Research", + "Topic :: Scientific/Engineering :: Artificial Intelligence", + "Topic :: Scientific/Engineering :: Image Processing", +] + +[project.optional-dependencies] +dev = [ + "elasticdeform", + "keras_tuner", + "pytest", +] + +[project.scripts] +baby-fit-grs = "baby.postprocessing:main" +baby-phone = "baby.server:main" +baby-race = "baby.speed_tests:main" + +[project.urls] +Homepage = "https://git.ecdf.ed.ac.uk/swain-lab/baby" + +[tool.hatch.version] +path = "python/baby/__init__.py" + +[tool.hatch.build.targets.wheel] +packages = [ + "python/baby", +] + +[tool.hatch.build.targets.sdist] +include = [ + "/python", + "/AUTHORS.txt", +] diff --git a/python/baby/__init__.py b/python/baby/__init__.py index 77f8b476878a21c7c3e1829542ed6f1e6bec6f49..002f91abbba22e8591ac5f7041682d87eddebe09 100644 --- a/python/baby/__init__.py +++ b/python/baby/__init__.py @@ -1,12 +1,14 @@ # If you publish results that make use of this software or the Birth Annotator # for Budding Yeast algorithm, please cite: -# Julian M J Pietsch, Alán Muñoz, Diane Adjavon, Ivan B N Clark, Peter S -# Swain, 2021, Birth Annotator for Budding Yeast (in preparation). +# Pietsch, J.M.J., Muñoz, A.F., Adjavon, D.-Y.A., Farquhar, I., Clark, I.B.N., +# and Swain, P.S. (2023). Determining growth rates from bright-field images of +# budding cells through identifying overlaps. eLife. 12:e79812. +# https://doi.org/10.7554/eLife.79812 # # # The MIT License (MIT) # -# Copyright (c) Julian Pietsch, Alán Muñoz and Diane Adjavon 2021 +# Copyright (c) Julian Pietsch, Alán Muñoz and Diane Adjavon 2023 # # Permission is hereby granted, free of charge, to any person obtaining a copy # of this software and associated documentation files (the "Software"), to @@ -26,19 +28,7 @@ # FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS # IN THE SOFTWARE. """Mostly used to access models and model-sets""" -from pathlib import Path -import json from .brain import BabyBrain from .crawler import BabyCrawler -BASE_DIR = Path(__file__).parent -MODEL_DIR = BASE_DIR / 'models' - -def modelsets(): - with open(BASE_DIR / 'modelsets.json', 'r') as fd: - msets = json.load(fd) - return msets - -# Todo: should probably be removed, but used in Tests -def model_path(): - return MODEL_DIR +__version__ = 'v0.30.0' diff --git a/python/baby/augmentation.py b/python/baby/augmentation.py index 38ee7a195a7a36e4b64bb8614ce663cb67ba6aff..23b959f8d621255172e27666545f9a5650eceac3 100644 --- a/python/baby/augmentation.py +++ b/python/baby/augmentation.py @@ -1,12 +1,14 @@ # If you publish results that make use of this software or the Birth Annotator # for Budding Yeast algorithm, please cite: -# Julian M J Pietsch, Alán Muñoz, Diane Adjavon, Ivan B N Clark, Peter S -# Swain, 2021, Birth Annotator for Budding Yeast (in preparation). +# Pietsch, J.M.J., Muñoz, A.F., Adjavon, D.-Y.A., Farquhar, I., Clark, I.B.N., +# and Swain, P.S. (2023). Determining growth rates from bright-field images of +# budding cells through identifying overlaps. eLife. 12:e79812. +# https://doi.org/10.7554/eLife.79812 # # # The MIT License (MIT) # -# Copyright (c) Julian Pietsch, Alán Muñoz and Diane Adjavon 2021 +# Copyright (c) Julian Pietsch, Alán Muñoz and Diane Adjavon 2023 # # Permission is hereby granted, free of charge, to any person obtaining a copy # of this software and associated documentation files (the "Software"), to @@ -30,7 +32,7 @@ A set of augmentation functions for on-the-fly preprocessing of images before input into the Neural Network. All augmentation functions are meant to work with unstructured arrays (in particular, elastic_deform, turn, and the shifts probably will return errors if used on structured arrays. They all take exactly -two arrays as input, and will perform identitcal transformations on both arrays. +two arrays as input, and will perform identical transformations on both arrays. """ from __future__ import absolute_import, division, print_function import json @@ -38,21 +40,28 @@ import numpy as np from itertools import permutations, repeat from scipy.ndimage import map_coordinates, gaussian_filter, shift from scipy.ndimage.morphology import binary_fill_holes +from scipy import interpolate from skimage import transform from skimage.feature import canny from skimage.filters import gaussian from skimage.draw import rectangle_perimeter +import elasticdeform +# import cv2 -from .preprocessing import segoutline_flattening -from .errors import BadParam +from .preprocessing import segoutline_flattening, connect_pixel_gaps +from .errors import BadParam, BadFile -AUGMENTATION_ORDER = ('substacks', 'rotate', 'vshift', 'hshift', 'downscale', - 'crop', 'vflip', 'hflip', 'movestacks', 'noise') +AUGMENTATION_ORDER = ('substacks', 'vshift', 'hshift', 'rough_crop', + 'downscale', 'shadows', 'blur', 'elastic_deform', + 'rotate', 'crop', 'vflip', 'hflip', + 'movestacks', 'noise', 'pad', 'gamma', + 'hist_deform', 'hist_scale') class Augmenter(object): - def __init__(self, xy_out=80, probs={}, p_noop=0.05, substacks=None): + def __init__(self, xy_out=80, probs={}, p_noop=0.05, substacks=None, + only_basic_augs=True): """ Random data augmentation of img and lbl. @@ -61,35 +70,62 @@ class Augmenter(object): xy_out : int or pair of ints as a tuple The intended width and height of the final augmented image. - probs : dict of floats in [0, 1] + probs : dict of floats in [0, 1] or None Specify non-default probabilities for the named augmentations. - Augmentations with zero probability are omitted. + Augmentations with zero probability are omitted. To get all + augmentations with equal probability, simply specify an empty + dict. p_noop : float in [0, 1] Adjusts the probability for no augmentation operation in the default case. If set to 1, then all operations will be omitted by default (i.e., unless a non-zero probability is specified for that operation in probs). + + only_basic_augs: bool + If True, then elastic_deform and img intensity augmentations are + omitted (i.e., "shadows", "elastic_deform", "gamma", + "hist_deform", "hist_scale"). This was the default for the BABY + paper. Specify False to include all augmentations. NB: this has + lower priority than the `probs` arg, e.g., if `shadows=0.2` is + specified in probs, then the shadows aug will be active. """ if type(xy_out) is int: self.xy_out = (xy_out, xy_out) elif len(xy_out) == 2: - self.xy_out = xy_out + self.xy_out = tuple(xy_out) else: raise Exception('"xy_out" must be an int or pair of ints') self.xy_in = None + self.pad_value = None + self.aug_log = [] + self._vshift = None + self._hshift = None + + self.preserve_bool_lbl = True + + # Interpolation order for label operations + # This should be 0 for bitmask labels + self.lbl_order = 0 + + if only_basic_augs: + custom_probs = probs + probs = dict(shadows=0, elastic_deform=0, gamma=0, hist_deform=0, + hist_scale=0, blur=0) + probs.update(**custom_probs) + # Treat 'crop', 'pad' and 'substacks' specially to have p = 1 + guaranteed_augs = ('substacks', 'rough_crop', 'pad', 'crop') self.aug_order = [ a for a in AUGMENTATION_ORDER - if probs.get(a, 1) > 0 or a == 'crop' + if probs.get(a, 1) > 0 or a in guaranteed_augs ] - - # Treat 'crop' and 'substacks' specially to have p = 1 - guaranteed_augs = ('crop', 'substacks') - n_augs = len(self.aug_order) - len(guaranteed_augs) - p_default = (1. - p_noop)**n_augs / n_augs + n_augs = len([1 for a in self.aug_order + if probs.get(a, 0) < 1 + and a not in guaranteed_augs]) + p_default = 1. - p_noop**(1/n_augs) self.probs = np.array([ 1 if a in guaranteed_augs else probs.get(a, p_default) for a in self.aug_order @@ -126,22 +162,27 @@ class Augmenter(object): assert img.shape[:2] == lbl.shape[:2], \ 'xy dimensions of img and lbl are mismatched' - lbl_is_bool = lbl.dtype == 'bool' + lbl_is_bool = lbl.dtype == 'bool' and self.preserve_bool_lbl self.xy_in = img.shape[:2] + self.pad_value = None + self.aug_log = [] for a, p in zip(self.aug_order, self.probs): if p == 0.0: continue elif p == 1.0: + self.aug_log.append(a) img, lbl = getattr(self, a)(img, lbl) else: if np.random.uniform() < p: + self.aug_log.append(a) img, lbl = getattr(self, a)(img, lbl) if lbl_is_bool and lbl.dtype != 'bool': lbl = lbl > 0.5 # ensure label stays boolean self.xy_in = None + self.pad_value = None # Ensure that x and y dimensions match the intended output size assert img.shape[:2] == self.xy_out and lbl.shape[:2] == self.xy_out, \ @@ -200,6 +241,40 @@ class Augmenter(object): img = img[:, :, ss] self.refslice = int(np.median(np.where(ss))) + 1 + self.pad_value = np.median(img) + return img, lbl + + def pad(self, img, lbl, update_size=True): + """Ensure image/label size are large enough + + A guaranteed augmentation that pads image to ensure that + as required Guaranteed augmentation simply run to ensure image/label sizes + """ + + # We store a pad value to ensure consistent padding if this routine is + # called more than once in the augmentation chain + if self.pad_value is None: + self.pad_value = np.median(img) + + # Ensure that img and lbl xy dimensions are at least as large as the + # requested output + Xin, Yin = img.shape[:2] + Xout, Yout = self.xy_out + Xpad = Xout - Xin if Xin < Xout else 0 + Ypad = Yout - Yin if Yin < Yout else 0 + if Xpad > 0 or Ypad > 0: + XpadL = Xpad // 2 + Xpad = (XpadL, Xpad - XpadL) + YpadL = Ypad // 2 + Ypad = (YpadL, Ypad - YpadL) + img = np.pad(img, (Xpad, Ypad, (0, 0)), + mode='constant', constant_values=self.pad_value) + lbl = np.pad(lbl, (Xpad, Ypad, (0, 0)), mode='constant') + + # Update input size by default + if update_size: + self.xy_in = img.shape[:2] + return img, lbl def rotate(self, img, lbl): @@ -236,9 +311,8 @@ class Augmenter(object): inshape = self.xy_in[0] maxpix = np.max([0, (inshape - self.xy_out[0]) // 2]) - pix = np.random.choice(np.arange(-maxpix, maxpix + 1, dtype='int')) - return (shift(img, [pix, 0, 0], mode='reflect', order=0), - shift(lbl, [pix, 0, 0], mode='reflect', order=0)) + self._vshift = np.random.choice(np.arange(-maxpix, maxpix + 1, dtype='int')) + return img, lbl def hshift(self, img, lbl, maxpix=None): """Shift along width, max of 10px by default @@ -257,9 +331,22 @@ class Augmenter(object): inshape = self.xy_in[1] maxpix = np.max([0, (inshape - self.xy_out[1]) // 2]) - pix = np.random.choice(np.arange(-maxpix, maxpix + 1, dtype='int')) - return (shift(img, [0, pix, 0], mode='reflect', order=0), - shift(lbl, [0, pix, 0], mode='reflect', order=0)) + self._hshift = np.random.choice(np.arange(-maxpix, maxpix + 1, dtype='int')) + return img, lbl + + def rough_crop(self, img, lbl, xysize=None, nonempty_slices=None): + if xysize is None: + xysize = np.array(self.xy_out, dtype=float) * 1.7 + xysize = np.ceil(xysize).astype(int) + if self.pad_value is None: + self.pad_value = np.median(img) + vshift = 0 if self._vshift is None else self._vshift + hshift = 0 if self._hshift is None else self._hshift + img = _shift_and_crop(img, vshift, hshift, xysize, self.pad_value) + lbl = _shift_and_crop(lbl, vshift, hshift, xysize, 0, + nonempty_slices=nonempty_slices) + self._vshift, self._hshift = None, None + return img, lbl def crop(self, img, lbl, xysize=None): if xysize is None: @@ -327,48 +414,155 @@ class Augmenter(object): size=img.shape) return img, lbl - def elastic_deform(self, img, lbl, params={}): - """Slight deformation + def elastic_deform(self, img, lbl): + """Random deformation based on elasticdeform package from PyPI - Elastic deformation of images as described in - Simard, Steinkraus and Platt, "Best Practices for - Convolutional Neural Networks applied to Visual Document Analysis", in - Proc. of the International Conference on Document Analysis and - Recognition, 2003. - Adapted from: - https://gist.github.com/chsasank/4d8f68caf01f041a6453e67fb30f8f5a + We aim for relatively weak deformation, so assign grid points + approximately every 32 pixels and keep the sigma low. - Example image: + The interpolation order for the lbl is set from the `lbl_order` + property of this class. It should be 0 if the lbl is a bitmask and no + smoothing has been applied, but can be the default 3 if smoothing of + a bitmask image has been applied. + """ + + # Want only light deformation, so approximately one grid point every + # 32 pixels: + npoints = np.maximum(np.round(np.array(lbl.shape[:2])/32), 1) + npoints = npoints.astype('int').tolist() + return elasticdeform.deform_random_grid( + [img, lbl], sigma=2, points=npoints, + order=[3, self.lbl_order], mode='reflect', axis=[(0, 1), (0, 1)]) + + def shadows(self, img, lbl): + """Location-dependent intensity deformation - .. image:: ../report/figures/augmentations/elastic_deform.* + Introduce local changes in intensity by fitting a bivariate spline to + a grid of random intensities. """ - alpha = 0.75 * img.shape[1] - sigma = 0.08 * img.shape[1] - x_y = _elastic_deform(np.dstack([img, lbl]), alpha=alpha, sigma=sigma) - return np.split(x_y, [img.shape[2]], axis=2) - def identity(self, img, lbl): - """Do nothing + # Want only slow changes in intensity, so approximately one grid point + # every 64 pixels: + npoints = np.maximum(np.round(np.array(img.shape[:2])/64), 1) + npoints = npoints.astype('int').tolist() + + # Set up grid with random intensities and interpolate + x = np.linspace(0, img.shape[0] - 1, np.maximum(npoints[0], 2)) + y = np.linspace(0, img.shape[1] - 1, np.maximum(npoints[1], 2)) + I = np.random.normal(size=(x.size, y.size)) + mapping = interpolate.RectBivariateSpline( + x, y, I, kx=np.minimum(x.size, 4) - 1, + ky=np.minimum(y.size, 4) - 1, s=(x.size * y.size) * 0.25) + # Scaled so that 90% of distribution sits within 0.5- to 2-fold + I_scaler = np.exp(0.4 * mapping(np.arange(img.shape[0]), + np.arange(img.shape[1]), grid=True)) + # Truncate to a two-fold change in intensity + I_scaler = np.minimum(np.maximum(I_scaler, 0.5), 2) + # Ensure that the mean intensity of the entire image stays + # approximately constant + I_scaler /= np.median(I_scaler) + # Apply scaling consistently across z-sections + I_scaler = I_scaler[..., None] + img_min = img.min() + img_aug = img_min + (img - img_min) * I_scaler + return img_aug, lbl + + def gamma(self, img, lbl): + """Adjust image gamma + + Picks a random gamma transform to apply to the image + """ - Example image: + # Pick a random gamma + g = np.random.uniform(0.3, 2.0) + + # Attempt to robustly maintain the intensity range + iQ = np.quantile(img, (0, 0.02, 0.5, 0.98, 1)) + imin, iptp = iQ[0], iQ[4] - iQ[0] + iQn = (iQ - imin) / iptp + # aim for consistent difference between 2/98 quantiles + sc = np.diff(iQ[[1, 3]]) / np.diff(iQn[[1, 3]] ** g) + # aim for consistent median + off = iQ[2] - iQn[2] ** g + + return off + sc * ((img - imin) / iptp) ** g, lbl + + def hist_deform(self, img, lbl): + """Random deformation of intensity histogram + + Maps equally-spaced intensity ranges to new ranges sampled from the + Dirichlet distribution (of equal average lengths), but uses + interpolation to smooth the mapping. - .. image:: ../report/figures/augmentations/identity.* + Inspired by the `histogram_voodoo` implementation provided by Daniel + Eaton from the Paulsson lab for the Delta segmentation package. """ - return img, lbl + npoints = 5 # includes end points + control_points = np.linspace(0, 1, num=npoints) + # split into `npoints - 1` segments with roughly equal lengths + # i.e., want relatively high alpha parameter: + mapped_points = np.random.dirichlet((2 * npoints,) * (npoints - 1)) + mapped_points = np.cumsum(np.concatenate([[0], mapped_points])) + mapping = interpolate.PchipInterpolator(control_points, mapped_points) + + # Attempt to robustly maintain the intensity range + iQ = np.quantile(img, (0, 0.02, 0.5, 0.98, 1)) + imin, iptp = iQ[0], iQ[4] - iQ[0] + iQn = (iQ - imin) / iptp + # aim for consistent difference between 2/98 quantiles + sc = np.diff(iQ[[1, 3]]) / np.diff(mapping(iQn[[1, 3]])) + # aim for consistent median + off = iQ[2] - mapping(iQn[2]) + + return off + sc * mapping((img - imin) / iptp), lbl + + def hist_scale(self, img, lbl): + """Randomly offset and scale the intensity histogram + """ + + # Sample an offset within 25% of the range of image intensities + imgrange = np.diff(np.quantile(img, [0.02, 0.98])) + off = np.random.uniform(-0.25*imgrange, 0.25*imgrange) + + # Fairly sample a scaling factor between roughly 0.5 and 2 + sc = np.random.lognormal(sigma=0.4) + # Truncate to a minimum of 0.5 and maximum of 2 + sc = np.minimum(np.maximum(sc, 0.5), 2) + + return (img + off) * sc, lbl + + def blur(self, img, lbl): + """Apply gaussian filter to img to simulate loss of focus + + Draw the sigma from an uniform distribution between 1 and 4. + """ + return gaussian(img, np.random.uniform(1, 4)), lbl class SmoothingSigmaModel(object): """Model for picking a smoothing sigma for gaussian filtering - a, b and c should be obtained by fitting data to the following model: + There are two model types, 'exponential' and 'constant' + + For the exponential model, `a`, `b` and `c` should be obtained by fitting + data to the following model: nedge = c + a * exp(b * sigma) + + For the constant model, set `a` such that: + sigma = a + i.e., sigma is independent of nedge """ - def __init__(self, a=None, b=None, c=None): + def __init__(self, a=None, b=None, c=None, model='exponential'): self._a = a self._b = b self._c = c - self._formula = 'sigma = log((nedge - c) / a) / b' + if model == 'exponential': + self._formula = 'sigma = log((nedge - c) / a) / b' + elif model == 'constant': + self._formula = 'sigma = a' + else: + raise BadParam('Unrecognised model type') def save(self, filename): with open(filename, 'wt') as f: @@ -383,24 +577,41 @@ class SmoothingSigmaModel(object): def load(self, filename): with open(filename, 'rt') as f: model = json.load(f) - if model.get('formula') == 'sigma = m*log(nedge) + c': + formula = model.get('formula') + if formula == 'sigma = m*log(nedge) + c': m, c = (model.get(k, 0) for k in ('m', 'c')) self._a = np.exp(-c * m) self._b = 1 / m self._c = 0 - elif self._formula != model.get('formula'): - raise BadFile('Model formula does not match SmoothingSigmaModel') - else: + self._formula = 'sigma = log((nedge - c) / a) / b' + elif formula == 'sigma = a': + self._a = model.get('a', 0) + self._b, self._c = None, None + self._formula = formula + elif formula == 'sigma = log((nedge - c) / a) / b': self._a, self._b, self._c = ( model.get(k, 0) for k in ('a', 'b', 'c')) + self._formula = formula + else: + raise BadFile('Model with unrecognised formula encountered') def __repr__(self): - return 'SmoothingSigmaModel: {}; a = {:.2f}, b = {:.2f}, c = {:.2f}'.format( - self._formula, self._a, self._b, self._c) - - def __call__(self, s): - return np.log(np.clip( - (np.sum(s) - self._c) / self._a, 1, None)) / self._b + keys = ('a', 'b', 'c') + coefs = (getattr(self, '_' + k) for k in keys) + coefs = [ + f'{k} = {v:.2f}' for k, v in zip(keys, coefs) if v is not None] + return 'SmoothingSigmaModel: {}; {}'.format( + self._formula, ', '.join(coefs)) + + def __call__(self, s, scaling=1.): + if self._formula == 'sigma = a': + return self._a + elif self._formula == 'sigma = log((nedge - c) / a) / b': + nedge = np.sum(s) * scaling + return np.log(np.clip( + (nedge - self._c) / self._a, 1, None)) / self._b + else: + raise BadParam('Unrecognised formula encountered') class SmoothedLabelAugmenter(Augmenter): @@ -408,19 +619,24 @@ class SmoothedLabelAugmenter(Augmenter): def __init__(self, sigmafunc, targetgenfunc=segoutline_flattening, + canny_padding=2, **kwargs): super(SmoothedLabelAugmenter, self).__init__(**kwargs) self.sigmafunc = sigmafunc self.targetgenfunc = targetgenfunc + self.canny_padding = canny_padding + # Since label smoothing is applied, interpolation order for label + # transformations can be increased + self.lbl_order = 3 def __call__(self, img, lbl_info): """This augmenter needs to be used in combination with a label preprocessing function that returns both images and info. """ - lbl, info = lbl_info + lbl, self.current_info = lbl_info - # Smooth filled label for to avoid anti-aliasing artefacts + # Smooth filled label to avoid anti-aliasing artefacts lbl = lbl.astype('float') for l in range(lbl.shape[2]): o = lbl[..., l] # slice a single outline @@ -428,13 +644,44 @@ class SmoothedLabelAugmenter(Augmenter): lbl[..., l] = gaussian(binary_fill_holes(o), self.sigmafunc(o)) - img, lbl = super(SmoothedLabelAugmenter, self).__call__(img, lbl) + return super(SmoothedLabelAugmenter, self).__call__(img, lbl) + + def rotate(self, img, lbl): + """Random rotation + + Example image: + + .. image:: ../report/figures/augmentations/turn.* + """ + angle = np.random.choice(360) + return (transform.rotate(img, + angle=angle, + order=3, + mode='reflect', + resize=True), + transform.rotate(lbl, + angle=angle, + order=3, + mode='reflect', + resize=True)) + + def crop(self, img, lbl, xysize=None): + # Overload the crop function to restore filled cell masks and generate + # the flattened targets. Performing these operations before applying + # cropping helps to avoid boundary effects and misclassification of + # size group. + # + if xysize is None: + xysize = np.array(self.xy_out) - # NB: to limit open shapes, the crop operation has been overloaded to - # find edges before cropping + # Find edges from blurred images and fill + for s in range(lbl.shape[2]): + lbl[:, :, s] = _filled_canny(lbl[:, :, s], self.canny_padding) - # Finally generate flattened targets from segmentation outlines + info = self.current_info if 'focusStack' in info: + # Modify info dict that gets passed to flattener in order to + # specify focus relative to reference slice formed from substack info = info.copy() cellFocus = info['focusStack'] if type(cellFocus) != list: @@ -443,18 +690,9 @@ class SmoothedLabelAugmenter(Augmenter): info['focusStack'] = [f - self.refslice for f in cellFocus] # print('new focus = {}'.format(', '.join([str(f) for f in info['focusStack']]))) + # Generate flattened targets from segmentation outlines lbl = self.targetgenfunc(lbl, info) - return img, lbl - - def crop(self, img, lbl, xysize=None): - if xysize is None: - xysize = np.array(self.xy_out) - - # Find edges and fill cells before cropping - for s in range(lbl.shape[2]): - lbl[:, :, s] = _filled_canny(lbl[:, :, s]) - return _apply_crop(img, xysize), _apply_crop(lbl, xysize) def downscale(self, img, lbl, maxpix=None): @@ -520,17 +758,16 @@ class ScalingAugmenter(SmoothedLabelAugmenter): scale_index = self.aug_order.index('downscale') self.scale_prob = self.probs[scale_index] self.probs[scale_index] = 1 + self.preserve_bool_lbl = False def __call__(self, img, lbl_info): lbl, info = lbl_info + self.current_info = info self._input_pix_size = info.get('pixel_size', self.target_pixel_size) self._scaling = self._input_pix_size / self.target_pixel_size self._outshape = np.round(np.array(self.xy_out) / self._scaling) - img, lbl = super(ScalingAugmenter, self).__call__(img, lbl_info) - # self._input_pix_size = None - # iself.scaling = None - # self._outshape = None - return img, lbl + + return super(SmoothedLabelAugmenter, self).__call__(img, lbl) def vshift(self, img, lbl, maxpix=None): if maxpix is None: @@ -550,99 +787,144 @@ class ScalingAugmenter(SmoothedLabelAugmenter): maxpix = np.max([0, (inshape - self._outshape[1]) // 2]) return super(ScalingAugmenter, self).hshift(img, lbl, maxpix=maxpix) + def rough_crop(self, img, lbl): + xysize = np.array(self.xy_out, dtype=float) * 1.7 / self._scaling + xysize = np.ceil(xysize).astype(int) + + nonempty = np.ones(lbl.shape[2], dtype=bool) + n_total = nonempty.size if lbl.any() else 0 + img, lbl = super(ScalingAugmenter, self).rough_crop( + img, lbl, xysize=xysize, nonempty_slices=nonempty) + + # Remove empty sections of lbl and sync with info + if lbl.shape[2] == 0: + lbl = np.zeros(lbl.shape[:2] + (1,), dtype=bool) + info = self.current_info.copy() + for k in ['cellLabels', 'focusStack', 'buds']: + if k in info and info[k] is not None: + l = info[k] + if type(l) != list: + l = [l] + if len(l) != n_total: + raise Exception(f'"{k}" info does not match label image shape ({str(info)})') + info[k] = [x for x, m in zip(l, nonempty) if m] + self.current_info = info + + # Smooth filled label to avoid anti-aliasing artefacts + lbl = lbl.astype('float') + for l in range(lbl.shape[2]): + o = lbl[..., l] # slice a single outline + sigma = max(self.sigmafunc(o), 1) + if o.sum() > 0: + lbl[..., l] = gaussian(_bordered_fill_holes(o), sigma) + + return img, lbl + def downscale(self, img, lbl, maxpix=None): # Scale image and label to target pixel size inshape = img.shape[:2] scaling = self._scaling # Apply random scaling according to probability for this op + # also need to correct the aug_log + last_aug = self.aug_log.pop() + assert last_aug == 'downscale' p = self.scale_prob if p == 1.0 or (p > 0 and np.random.uniform() < p): + self.aug_log.append('downscale') scaling += scaling * self.scale_frac * np.random.uniform(-1, 1) - outshape = np.round(np.array(img.shape[:2]) * scaling) - outshape = np.maximum(outshape, self.xy_out) - - return (transform.resize(img, outshape), - transform.resize(lbl, outshape, anti_aliasing=False)) + outshape = np.round(np.array(img.shape[:2]) * scaling).astype(int) + # OpenCV is much faster but much less accurate... + # img = cv2.resize(img, outshape, interpolation=cv2.INTER_LINEAR) + # if img.ndim == 2: + # img = img[:, :, None] + # lbl = cv2.resize(lbl, outshape, interpolation=cv2.INTER_LINEAR) + # if lbl.ndim == 2: + # lbl = lbl[:, :, None] + img = transform.resize(img, outshape, mode='edge', anti_aliasing=False) + lbl = transform.resize(lbl, outshape, order=3, mode='edge', anti_aliasing=False) + return img, lbl # =============== UTILITY FUNCTIONS ====================== # +def _shift_and_crop(stack, vshift, hshift, xysize, pad_value, + nonempty_slices=None): + r_in, c_in = stack.shape[:2] + r_out, c_out = xysize + rb_in = (r_in - r_out) // 2 - vshift + cb_in = (c_in - c_out) // 2 - hshift + re_in = rb_in + r_out + ce_in = cb_in + c_out + rb_out = max(0, -rb_in) + re_out = min(r_out, r_out - re_in + r_in) + cb_out = max(0, -cb_in) + ce_out = min(c_out, c_out - ce_in + c_in) + rb_in = max(0, rb_in) + re_in = min(r_in, re_in) + cb_in = max(0, cb_in) + ce_in = min(c_in, ce_in) + slicedims = stack.shape[2:] + if nonempty_slices is not None: + nonempty_slices[()] = stack[rb_in:re_in, cb_in:ce_in].any(axis=(0,1)) + slicedims = (int(nonempty_slices.sum()),) + slicedims[1:] + out = np.full_like(stack, pad_value, shape=tuple(xysize) + slicedims) + if nonempty_slices is None: + out[rb_out:re_out, cb_out:ce_out] = stack[rb_in:re_in, cb_in:ce_in] + else: + out[rb_out:re_out, cb_out:ce_out, :] = stack[ + rb_in:re_in, cb_in:ce_in, nonempty_slices] + return out + def _apply_crop(stack, xysize): cropy, cropx = xysize starty, startx = stack.shape[:2] - startx = (startx - cropx) // 2 - starty = (starty - cropy) // 2 - return stack[starty:(starty + cropy), startx:(startx + cropx), ...] + if startx > cropx: + startx = (startx - cropx) // 2 + stack = stack[:, startx:(startx + cropx), ...] + if starty > cropy: + starty = (starty - cropy) // 2 + stack = stack[starty:(starty + cropy), :, ...] + return stack -def _elastic_deform(image, alpha, sigma, random_state=None): - """ - Elastic deformation of images as described in [Simard2003]_. - [Simard2003] Simard, Steinkraus and Platt, "Best Practices for - Convolutional Neural Networks applied to Visual Document Analysis", in - Proc. of the International Conference on Document Analysis and - Recognition, 2003. - Adapted from: - https://gist.github.com/chsasank/4d8f68caf01f041a6453e67fb30f8f5a +def _bordered_fill_holes(edgemask, pad_width=1, prepadded=False): + """Fill holes in a mask treating the boundary as an edge + + NB: this function assumes that no cells intersect with opposing borders. """ - if random_state is None: - random_state = np.random.RandomState(None) - - shape = image[:, :, 0].shape - dx = gaussian_filter( - (random_state.rand(*shape) * 2 - 1), sigma, mode="constant", - cval=0) * alpha - dy = gaussian_filter( - (random_state.rand(*shape) * 2 - 1), sigma, mode="constant", - cval=0) * alpha - - _x, _y = np.meshgrid(np.arange(shape[0]), - np.arange(shape[1]), - indexing='ij') - indices = np.reshape(_x + dx, (-1, 1)), np.reshape(_y + dy, (-1, 1)) - if len(image.shape) == 3: - result = np.empty_like(image) - for d in range(image.shape[2]): - # iterate over depth - cval = np.median(image[:, :, d]) - result[:, :, d] = map_coordinates(image[:, :, d], - indices, - order=1, - cval=cval).reshape(shape) - result - else: - cval = np.median(image) - result = map_coordinates(image, indices, order=1, - cval=cval).reshape(shape) - return result + bp = pad_width + if not prepadded: + edgemask = np.pad(edgemask, bp) + + mask = np.zeros(edgemask.shape, dtype='bool') + mask[bp:-bp, bp:-bp] = edgemask[bp:-bp, bp:-bp] -def _filled_canny(segblur, bp=2): + # The following assumes that the cell does not intersect with opposing + # borders, filling bordering with two U-shaped border edges: + mask[:bp, :] = True + mask[-bp:, :] = True + mask[:, :bp] = True + mask = binary_fill_holes(mask) + + mask[:, :bp] = False + mask[:, -bp:] = True + mask = binary_fill_holes(mask) + + return mask[bp:-bp, bp:-bp] + + +def _filled_canny(segblur, pad_width=2): """Use canny to find edge and fill object Handles intersections with border by assuming that the object cannot intersect all borders at once. segblur: segmentation image that has been gaussian blurred - bp: border padding + pad_width: border padding """ - - se = canny(np.pad(segblur, bp, 'edge'), sigma=0) - sf = np.zeros(se.shape, dtype='bool') - sf[bp:-bp, bp:-bp] = se[bp:-bp, bp:-bp] - - # The following assumes that the cell does not intersect with opposing - # borders, filling bordering with two U-shaped border edges: - sf[:bp, :] = True - sf[-bp:, :] = True - sf[:, :bp] = True - sf = binary_fill_holes(sf) - - sf[:, :bp] = False - sf[:, -bp:] = True - sf = binary_fill_holes(sf) - - return sf[bp:-bp, bp:-bp] + se = canny(np.pad(segblur, pad_width, 'edge'), sigma=0) + return _bordered_fill_holes(se, pad_width=pad_width, prepadded=True) diff --git a/python/baby/brain.py b/python/baby/brain.py index 50e9f4023e6278b8329ccec938d99e530182bf21..0f9cb0b529b0342cd7f20053e30792b786b8d413 100644 --- a/python/baby/brain.py +++ b/python/baby/brain.py @@ -1,23 +1,25 @@ # If you publish results that make use of this software or the Birth Annotator # for Budding Yeast algorithm, please cite: -# Julian M J Pietsch, Alán Muñoz, Diane Adjavon, Ivan B N Clark, Peter S -# Swain, 2021, Birth Annotator for Budding Yeast (in preparation). -# -# +# Pietsch, J.M.J., Muñoz, A.F., Adjavon, D.-Y.A., Farquhar, I., Clark, I.B.N., +# and Swain, P.S. (2023). Determining growth rates from bright-field images of +# budding cells through identifying overlaps. eLife. 12:e79812. +# https://doi.org/10.7554/eLife.79812 +# +# # The MIT License (MIT) -# -# Copyright (c) Julian Pietsch, Alán Muñoz and Diane Adjavon 2021 -# +# +# Copyright (c) Julian Pietsch, Alán Muñoz and Diane Adjavon 2023 +# # Permission is hereby granted, free of charge, to any person obtaining a copy # of this software and associated documentation files (the "Software"), to # deal in the Software without restriction, including without limitation the # rights to use, copy, modify, merge, publish, distribute, sublicense, and/or # sell copies of the Software, and to permit persons to whom the Software is # furnished to do so, subject to the following conditions: -# +# # The above copyright notice and this permission notice shall be included in # all copies or substantial portions of the Software. -# +# # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE @@ -27,7 +29,8 @@ # IN THE SOFTWARE. from __future__ import (absolute_import, division, print_function, unicode_literals) -from os.path import dirname, join, isfile, isdir +from pathlib import Path +import json from itertools import repeat, chain import numpy as np @@ -37,30 +40,42 @@ import tensorflow as tf from tensorflow.keras import models, layers from tensorflow.keras import backend as K -from .losses import bce_dice_loss, dice_loss, dice_coeff -from .preprocessing import robust_norm, SegmentationFlattening -from .morph_thresh_seg import MorphSegGrouped -from .tracker.core import MasterTracker -from .utils import batch_iterator, split_batch_pred -from .brain_util import _segment, _track, _segment_and_track_parallel +try: + from tensorflow.keras.optimizers import AdamW +except ImportError: + try: + from tensorflow.keras.optimizers.experimental import AdamW + except ImportError: + try: + from tensorflow_addons.optimizers import AdamW + except ImportError: + raise ImportError('You need to pip install tensorflow-addons with this version of tensorflow') + +from skimage import transform -models_path = join(dirname(__file__), 'models') +from . import modelsets +from .losses import bce_dice_loss, dice_loss, dice_coeff +from .preprocessing import SegmentationFlattening +from .morph_thresh_seg import MorphSegGrouped, SegmentationParameters +from .tracker.core import MasterTracker, MMTracker +from .utils import batch_iterator, split_batch_pred, as_python_object +from .brain_util import (_segment, _track, _segment_and_track_parallel, + _apply_preprocessing, _rescale_output, + _tile_generator, _stitch_tiles, _batch_generator) +from .errors import BadFile, BadParam tf_version = [int(v) for v in tf.__version__.split('.')] -# Default to optimal segmentation parameters found in Jupyter notebook -# segmentation-190906.ipynb: -default_params = { - 'interior_threshold': (0.7, 0.5, 0.5), - 'nclosing': (1, 0, 0), - 'nopening': (1, 0, 0), - 'connectivity': (2, 2, 1), - 'pedge_thresh': 0.001, - 'fit_radial': True, - 'edge_sub_dilations': 1, - 'use_group_thresh': True, - 'group_thresh_expansion': 0.1 -} + +DEFAULT_MODELSET = 'yeast-alcatras-brightfield-EMCCD-60x-5z' +DEFAULT_MODELSET_PARAMS = None + + +def _default_params(): + global DEFAULT_MODELSET_PARAMS + if DEFAULT_MODELSET_PARAMS is None: + DEFAULT_MODELSET_PARAMS = modelsets.get_params(DEFAULT_MODELSET) + return DEFAULT_MODELSET_PARAMS class BabyBrain(object): @@ -70,37 +85,58 @@ class BabyBrain(object): file or the name of the file in the default "models" dir shipped with this package. If any are left unspecified, then default models will be loaded. - :param morph_model_file: neural network model taking a stack of images and - outputting predictions according to the paired flattener model. - :param flattener_file: a saved `SegmentationFlattening` model specifying - the trained target types for each output layer of the neural network. - :param celltrack_model_file: - :param budassign_model_file: - :param default_image_size: - :param params: dict of keyword parameters to be passed to the - `morph_seg_grouped` function when segmenting. - :param session: optionally specify the Tensorflow session to load the - neural network model into (useful only for Tensorflow versions <2) - :param graph: optionally specify the Tensorflow graph to load the neural - network model into (useful only for Tensorflow versions <2) - :param suppress_errors: whether or not to catch Exceptions raised during - segmentation or tracking. If True, then any Exceptions will be logged - using standard Python logging. - :param error_dump_dir: optionally specify a directory in which to dump - input parameters when an error is caught. + Args: + morph_model_file: neural network model taking a stack of images and + outputting predictions according to the paired flattener model. + flattener_file: a saved `SegmentationFlattening` model specifying + the trained target types for each output layer of the neural network. + celltrack_model_file: file name of the saved machine learning model + for default tracking predictions. + celltrack_backup_model_file: file name of the saved machine learning + model for backup tracking predictions. + budassign_model_file: file name of the saved machine learning model + for predicting bud assignments + pixel_size (float): Target pixel size for inputs to the trained CNN. + default_image_size (None or Tuple[int] or int): Optionally specify an + alternative to the input size of the trained CNN as a ``(W, H)`` + tuple giving the ``W`` and ``H`` of the image. If just an ``int`` + is specified, then ``W = H`` is assumed. + params (str or Path or SegmentationParameters): Segmentation + parameters to use with :py:class:`MorphSegGrouped`. May be + specified as :py:class:`SegmentationParameters`, or the path to a + saved :py:class:`SegmentationParameters`. + modelset_path (str or Path): path to a folder containing the files + specified by the other arguments. The shared folder of the local + model set cache is always checked if file is not found at + `modelset_path`. See :py:func:`modelsets.resolve` for details on + how paths to model files are resolved. + session: optionally specify the Tensorflow session to load the + neural network model into (useful only for Tensorflow versions <2) + graph: optionally specify the Tensorflow graph to load the neural + network model into (useful only for Tensorflow versions <2) + suppress_errors: whether or not to catch Exceptions raised during + segmentation or tracking. If True, then any Exceptions will be logged + using standard Python logging. + error_dump_dir: optionally specify a directory in which to dump + input parameters when an error is caught. ''' def __init__(self, morph_model_file=None, flattener_file=None, celltrack_model_file=None, + celltrack_backup_model_file=None, budassign_model_file=None, + mmtracking=False, pixel_size=0.263, default_image_size=None, - params=default_params, + params=SegmentationParameters(), + nstepsback=None, clogging_thresh=0.75, min_bud_tps=3, isbud_thresh=0.5, + input_norm_dw=False, + modelset_path=DEFAULT_MODELSET, session=None, graph=None, print_info=False, @@ -110,25 +146,28 @@ class BabyBrain(object): self.reshaped_models = {} if morph_model_file is None: - morph_model_file = join(models_path, 'I5_msd_d80_20190916.hdf5') - elif not isfile(morph_model_file): - morph_model_file = join(models_path, morph_model_file) + morph_model_file = _default_params()['morph_model_file'] + morph_model_file = modelsets.resolve(morph_model_file, modelset_path) if flattener_file is None: - flattener_file = join(models_path, 'flattener_v2_20190905.json') - elif not isfile(flattener_file): - flattener_file = join(models_path, flattener_file) + flattener_file = _default_params()['flattener_file'] + flattener_file = modelsets.resolve(flattener_file, modelset_path) + + if type(params) == dict: + params = SegmentationParameters(**params) + if type(params) != SegmentationParameters: + param_file = modelsets.resolve(params, modelset_path) + with open(param_file, 'rt') as f: + params = json.load(f, object_hook=as_python_object) + if type(params) != SegmentationParameters: + raise BadFile('Specified file does not contain a valid ' + '`SegmentationParameters` object') + + if not params.fit_radial: + raise BadParam('`BabyBrain` currently only works for ' + '`SegmentationParameters` with `fit_radial=True`') - if celltrack_model_file is None: - celltrack_model_file = join(models_path, 'ct_svc_20201106_6.pkl') - elif not isfile(celltrack_model_file): - celltrack_model_file = join(models_path, celltrack_model_file) - - if budassign_model_file is None: - budassign_model_file = join(models_path, - 'baby_randomforest_20190906.pkl') - elif not isfile(budassign_model_file): - budassign_model_file = join(models_path, budassign_model_file) + self.params = params self.session = None self.graph = None @@ -154,7 +193,8 @@ class BabyBrain(object): custom_objects={ 'bce_dice_loss': bce_dice_loss, 'dice_loss': dice_loss, - 'dice_coeff': dice_coeff + 'dice_coeff': dice_coeff, + 'AdamW': AdamW }) else: self.morph_model = models.load_model( @@ -164,35 +204,60 @@ class BabyBrain(object): 'dice_loss': dice_loss, 'dice_coeff': dice_coeff }) + if tf_version[0] == 2 and tf_version[1] > 3: + # TF 2.4 no longer supports reshaping using the functional API + # on a new Input since it allows multiple inputs to a layer + # Recommendation is to define generic Input layers with None + # for variable dimensions, or to disable input checking with + # the following code. + # See release notes for version 2.4.0 at + # https://github.com/tensorflow/tensorflow/blob/master/RELEASE.md + self.morph_model.input_spec = None self.suppress_errors = suppress_errors self.error_dump_dir = error_dump_dir + self.input_norm_dw = input_norm_dw self.flattener = SegmentationFlattening(flattener_file) - self.params = params - if ('use_group_thresh' not in self.params and - 'group_thresh_expansion' in self.params): - self.params['use_group_thresh'] = True self.morph_segmenter = MorphSegGrouped(self.flattener, - fit_radial=True, + params=params, return_masks=True, - return_coords=True, - **self.params) + return_coords=True) self.pixel_size = pixel_size self.clogging_thresh = clogging_thresh - # Load tracker models and initialise Tracker - with open(celltrack_model_file, 'rb') as f: - celltrack_model = pickle.load(f) - with open(budassign_model_file, 'rb') as f: - budassign_model = pickle.load(f) - self.tracker = MasterTracker(ctrack_args={'model': celltrack_model}, - btrack_args={'model': budassign_model}, - min_bud_tps=min_bud_tps, - isbud_thresh=isbud_thresh, - px_size=pixel_size) + if mmtracking: + # Set the tracker engine to the model-free Mother Machine variant + self.tracker = MMTracker(px_size=pixel_size) + else: + # Load tracker models and initialise Tracker + if celltrack_model_file is None: + celltrack_model_file = _default_params()['celltrack_model_file'] + celltrack_model_file = modelsets.resolve(celltrack_model_file, modelset_path) + + if celltrack_backup_model_file is None: + celltrack_backup_model_file = _default_params()[ + 'celltrack_backup_model_file'] + celltrack_backup_model_file = modelsets.resolve( + celltrack_backup_model_file, modelset_path) + + if budassign_model_file is None: + budassign_model_file = _default_params()['budassign_model_file'] + budassign_model_file = modelsets.resolve(budassign_model_file, modelset_path) + + ctrack_args ={ + 'model': celltrack_model_file, + 'bak_model': celltrack_backup_model_file, + 'nstepsback': nstepsback + } + btrack_args = {'model': budassign_model_file} + self.tracker = MasterTracker(ctrack_args=ctrack_args, + btrack_args=btrack_args, + min_bud_tps=min_bud_tps, + isbud_thresh=isbud_thresh, + px_size=pixel_size) # Run prediction on mock image to load model for prediction _, x, y, z = self.morph_model.input.shape @@ -210,43 +275,78 @@ class BabyBrain(object): def depth(self): return self.morph_model.input.shape[3] - def morph_predict(self, X, needs_context=True): + def _predict(self, X, needs_context=True): if tf_version[0] == 1 and needs_context: with self.graph.as_default(): K.set_session(self.session) - return self.morph_predict(X, needs_context=False) + return self._predict(X, needs_context=False) imdims = X.shape[1:3] - # Current MSD model requires shape to be divisible by 8 - nndims = tuple([int(np.ceil(float(d) / 8.)) * 8 for d in imdims]) + # Current MSD model requires shape to be divisible by 8. Standard 5 + # layer U-Net requires shape to be divisible by 16. + # CNN model for FRET is very sensitive to padding: we need images to + # be at least 64px wide, the padding needs to be centred and the pad + # value needs to be the image median. This basically makes the + # following match the padding protocol in the Augmenter + nndims = tuple(max(64, int(np.ceil(float(d) / 16.)) * 16) + for d in imdims) + xpadoff, ypadoff = 0, 0 if not all([n == i for n, i in zip(nndims, imdims)]): xpad, ypad = tuple(n - i for n, i in zip(nndims, imdims)) - X = np.pad(X, ((0, 0), (0, xpad), (0, ypad), (0, 0)), 'edge') + xpadoff, ypadoff = xpad // 2, ypad // 2 + xpad = (xpadoff, xpad - xpadoff) + ypad = (ypadoff, ypad - ypadoff) + X = np.array([np.pad(Xi, (xpad, ypad, (0, 0)), + mode='constant', + constant_values=np.median(Xi)) + for Xi in X]) if nndims not in self.reshaped_models: - base_input_shape = self.morph_model.input.shape[1:3] - if all([n == m for n, m in zip(nndims, base_input_shape)]): - self.reshaped_models[nndims] = self.morph_model - else: - i = layers.Input(shape=X.shape[1:]) - self.reshaped_models[nndims] = models.Model( - i, self.morph_model(i)) + i = layers.Input(shape=X.shape[1:]) + self.reshaped_models[nndims] = models.Model( + i, self.morph_model(i)) if tf_version[0] == 1 and self.print_info: print('Running prediction in session "{}"...'.format( K.get_session())) - pred = self.reshaped_models[nndims].predict(X) + pred = self.reshaped_models[nndims].predict(X, verbose=0) - return [p[:, :imdims[0], :imdims[1], :] for p in pred] + return [p[:, xpadoff:xpadoff+imdims[0], ypadoff:ypadoff+imdims[1], :] + for p in pred] + + def morph_predict(self, X, pixel_size=None, overlap_size=48, + yield_rescaling=False, keep_bb_pixel_size=False): + # First preprocess each brightfield image in batch + X, rescaling, inshape = _apply_preprocessing( + X, self.input_norm_dw, pixel_size, self.pixel_size) + + if yield_rescaling: + yield rescaling, inshape + + tilegen = _tile_generator(X, overlap_size=overlap_size) + tiling_strategy = next(tilegen) + predgen = chain(*map(lambda x: split_batch_pred(self._predict(x)), + _batch_generator(tilegen, 8))) + + for pred in _stitch_tiles(predgen, *tiling_strategy): + if keep_bb_pixel_size: + yield pred + else: + yield transform.resize( + pred, (len(pred),) + inshape, order=1) def segment(self, bf_img_batch, + pixel_size=None, + overlap_size=48, yield_edgemasks=False, yield_masks=False, yield_preds=False, yield_volumes=False, - refine_outlines=False): + refine_outlines=False, + yield_rescaling=False, + keep_bb_pixel_size=False): '''Generator yielding segmented output for a batch of input images :param bf_img_batch: a list of ndarray with shape (X, Y, Z), or @@ -272,19 +372,26 @@ class BabyBrain(object): - volumes: (optional) list of floats corresponding, for each cell, to the conical section method for cell volume estimation ''' - # First preprocess each brightfield image in batch - bf_img_batch = np.stack( - [robust_norm(img, {}) for img in bf_img_batch]) - for batch in batch_iterator(bf_img_batch): - morph_preds = split_batch_pred(self.morph_predict(batch)) - - for cnn_output in morph_preds: - yield _segment(self.morph_segmenter, cnn_output, - refine_outlines, yield_volumes, yield_masks, - yield_preds, yield_edgemasks, - self.clogging_thresh, self.error_dump_dir, - self.suppress_errors) + predgen = self.morph_predict(bf_img_batch, + pixel_size=pixel_size, + overlap_size=overlap_size, + yield_rescaling=True, + keep_bb_pixel_size=True) + rescaling, inshape = next(predgen) + if yield_rescaling: + yield rescaling, inshape + + for cnn_output in predgen: + segout = _segment(self.morph_segmenter, cnn_output, + refine_outlines, yield_volumes, yield_masks, + yield_preds, yield_edgemasks, + self.clogging_thresh, self.error_dump_dir, + self.suppress_errors) + if not keep_bb_pixel_size: + _rescale_output(segout, rescaling, inshape, + self.morph_segmenter.params.cartesian_spline) + yield segout def run(self, bf_img_batch): '''Implementation of legacy runner function... @@ -314,6 +421,7 @@ class BabyBrain(object): def segment_and_track(self, bf_img_batch, tracker_states=None, + pixel_size=None, yield_next=False, yield_edgemasks=False, assign_mothers=False, @@ -365,26 +473,46 @@ class BabyBrain(object): tracker_states = repeat(None) tnames = self.flattener.names() - i_budneck = tnames.index('bud_neck') bud_target = 'sml_fill' if 'sml_fill' in tnames else 'sml_inte' - i_bud = tnames.index(bud_target) + assign_buds = 'bud_neck' in tnames and bud_target in tnames + if assign_buds: + i_budneck = tnames.index('bud_neck') + i_bud = tnames.index(bud_target) + else: + i_budneck = None + i_bud = None segment_gen = self.segment(bf_img_batch, + pixel_size=pixel_size, yield_masks=True, yield_edgemasks=True, yield_preds=True, yield_volumes=yield_volumes, - refine_outlines=refine_outlines) + refine_outlines=refine_outlines, + yield_rescaling=True, + keep_bb_pixel_size=True) + + rescaling, inshape = next(segment_gen) for seg, state in zip(segment_gen, tracker_states): - yield _track(self.tracker, seg, state, i_budneck, i_bud, + trackout = _track(self.tracker, seg, state, i_budneck, i_bud, assign_mothers, return_baprobs, yield_edgemasks, yield_next, self.error_dump_dir, self.suppress_errors) + if yield_next: + segout, state = trackout + _rescale_output(segout, rescaling, inshape, + self.morph_segmenter.params.cartesian_spline) + yield segout, state + else: + _rescale_output(trackout, rescaling, inshape, + self.morph_segmenter.params.cartesian_spline) + yield trackout def segment_and_track_parallel(self, bf_img_batch, tracker_states=None, + pixel_size=None, yield_next=False, yield_edgemasks=False, assign_mothers=False, @@ -432,21 +560,21 @@ class BabyBrain(object): tracker states for this time point as a tuple ''' - # First preprocess each brightfield image in batch - bf_img_batch = np.stack( - [robust_norm(img, {}) for img in bf_img_batch]) - # Do not run the CNN in parallel - morph_preds = list( - chain(*(split_batch_pred(self.morph_predict(batch)) - for batch in batch_iterator(bf_img_batch)))) + predgen = self.morph_predict(bf_img_batch, + pixel_size=pixel_size, + yield_rescaling=True, + keep_bb_pixel_size=True) + rescaling, inshape = next(predgen) + preds = list(predgen) if tracker_states is None: tracker_states = repeat(None) trackout = _segment_and_track_parallel( - self.morph_segmenter, self.tracker, self.flattener, morph_preds, - tracker_states, refine_outlines, yield_volumes, yield_edgemasks, - self.clogging_thresh, assign_mothers, return_baprobs, yield_next, - njobs, self.error_dump_dir, self.suppress_errors) + self.morph_segmenter, self.tracker, self.flattener, preds, + tracker_states, rescaling, inshape, refine_outlines, + yield_volumes, yield_edgemasks, self.clogging_thresh, + assign_mothers, return_baprobs, yield_next, njobs, + self.error_dump_dir, self.suppress_errors) return trackout diff --git a/python/baby/brain_util.py b/python/baby/brain_util.py index 10a7cd15a4703ae811d802f9315a61e551f40095..ef6e49845925fd7bafccad2b05dd8caaf5166af7 100644 --- a/python/baby/brain_util.py +++ b/python/baby/brain_util.py @@ -1,23 +1,25 @@ # If you publish results that make use of this software or the Birth Annotator # for Budding Yeast algorithm, please cite: -# Julian M J Pietsch, Alán Muñoz, Diane Adjavon, Ivan B N Clark, Peter S -# Swain, 2021, Birth Annotator for Budding Yeast (in preparation). -# -# +# Pietsch, J.M.J., Muñoz, A.F., Adjavon, D.-Y.A., Farquhar, I., Clark, I.B.N., +# and Swain, P.S. (2023). Determining growth rates from bright-field images of +# budding cells through identifying overlaps. eLife. 12:e79812. +# https://doi.org/10.7554/eLife.79812 +# +# # The MIT License (MIT) -# -# Copyright (c) Julian Pietsch, Alán Muñoz and Diane Adjavon 2021 -# +# +# Copyright (c) Julian Pietsch, Alán Muñoz and Diane Adjavon 2023 +# # Permission is hereby granted, free of charge, to any person obtaining a copy # of this software and associated documentation files (the "Software"), to # deal in the Software without restriction, including without limitation the # rights to use, copy, modify, merge, publish, distribute, sublicense, and/or # sell copies of the Software, and to permit persons to whom the Software is # furnished to do so, subject to the following conditions: -# +# # The above copyright notice and this permission notice shall be included in # all copies or substantial portions of the Software. -# +# # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE @@ -31,13 +33,18 @@ from os.path import join, isdir from time import strftime from uuid import uuid1 import logging +from itertools import islice import numpy as np import pickle -from scipy.ndimage import binary_dilation +from scipy.ndimage import binary_dilation, binary_fill_holes +from skimage import transform +from skimage.morphology import diamond from .morph_thresh_seg import SegmentationOutput +from .segmentation import draw_radial from .io import save_tiled_image -from .errors import Clogging +from .errors import Clogging, BadParam +from .preprocessing import robust_norm, robust_norm_dw _logger = None @@ -62,6 +69,170 @@ def _generate_error_dump_id(): return strftime('%Y-%m-%d_%H-%M-%S_') + str(uuid1()) +def _batch_generator(iterable, n=8): + """Yields batches from an iterator + Based on + https://docs.python.org/3/library/itertools.html#itertools-recipes""" + it = iter(iterable) + while True: + batch = list(islice(it, n)) + if not batch: + return + yield np.array(batch) + + +def _tile_generator(imgarray, + overlap_size=48, + tile_sizes=np.arange(64, 257, 8)): + """Yields optimal tiles from an image.""" + + imH, imW = imgarray.shape[1:3] + + if imH * imW <= np.square(np.max(tile_sizes)): + # If area is less than that of max tile size we will simply return + # images without modification + nH, nW = 1, 1 + tileH, tileW = imH, imW + else: + tile_sizes = tile_sizes[tile_sizes > overlap_size + 32] + if len(tile_sizes) == 0: + raise BadParam('`overlap_size` is too large for the ' + + 'specified tile_sizes') + + # For each candidate tile size, calculate the required number of tiles + # per image (rounding up), including allowance for overlaps + nH = np.ceil((imH - overlap_size)/(tile_sizes - overlap_size)).astype(int) + nW = np.ceil((imW - overlap_size)/(tile_sizes - overlap_size)).astype(int) + + # Calculate the new total size given the rounding above + paddedH = nH * (tile_sizes - overlap_size) + overlap_size + paddedW = nW * (tile_sizes - overlap_size) + overlap_size + + # Choose the tile size that minimises the fractional padding + padH = paddedH - imH + i_best = np.argmin(padH / tile_sizes) + nH, tileH = nH[i_best], tile_sizes[i_best] + padH = padH[i_best] + padW = paddedW - imW + i_best = np.argmin(padW / tile_sizes) + nW, tileW = nW[i_best], tile_sizes[i_best] + padW = padW[i_best] + + # Pad image to tiled size + imgarray = np.pad( + imgarray, ((0, 0), (0, padH), (0, padW), (0, 0))) + + # First yield tiling details for future stitching + yield nH, nW, imH, imW, overlap_size + + # Split the images up into tiles + for img in imgarray: + for j in range(nH): + for i in range(nW): + Hl = j * (tileH - overlap_size) + Hu = Hl + tileH + Wl = i * (tileW - overlap_size) + Wu = Wl + tileW + yield img[Hl:Hu, Wl:Wu, ...] + + +def _stitch_tiles(tilestream, nH, nW, imH, imW, nOvlap): + """Stitches tiles back together""" + + # Regions where tiles overlap use linear decay for weighted averaging + lindecay = np.linspace(1, 0, nOvlap + 2)[1:-1] + Hdecay = np.c_[lindecay][..., None] + Wdecay = np.c_[lindecay].T[..., None] + + # Batch tiles into each complete image + for tiles in zip(*[iter(tilestream)] * (nW * nH)): + # Transpose tiles so that we can stitch with h/vstack + tiles = [tile.transpose((1, 2, 0)) for tile in tiles] + + img = None + # Batch tiles into rows at a time + for row in zip(*[iter(tiles)]*nW): + # Stitch tiles into a row + rowimg = row[0] + for t in row[1:]: + rowimg[:, -nOvlap:] = (rowimg[:, -nOvlap:] * Wdecay + + t[:, :nOvlap] * Wdecay[:, ::-1]) + rowimg = np.hstack((rowimg, t[:, nOvlap:])) + + # Stitch rows together + if img is None: + img = rowimg + else: + img[-nOvlap:, :] = (img[-nOvlap:, :] * Hdecay + + rowimg[:nOvlap, :] * Hdecay[::-1, :]) + img = np.vstack((img, rowimg[nOvlap:, :])) + + yield img[:imH, :imW, :].transpose((2, 0, 1)) + + +def _apply_preprocessing(bf_imgs, input_norm_dw, pxsize_in, pxsize_out): + input_norm = robust_norm_dw if input_norm_dw else robust_norm + bf_imgs = np.stack([input_norm(img, {}) for img in bf_imgs]) + inshape = bf_imgs.shape + rescaling = None + if pxsize_in is not None and pxsize_in != pxsize_out: + rescaling = pxsize_in / pxsize_out + if rescaling is not None: + bf_imgs = list(bf_imgs) + for i, img in enumerate(bf_imgs): + bf_imgs[i] = transform.rescale(img, + rescaling, + order=1, + channel_axis=2) + bf_imgs = np.stack(bf_imgs) + rescaling = 1. / rescaling + return bf_imgs, rescaling, inshape[1:3] + + +def _rescale_output(output, rescaling, outshape, cartesian_spline): + if rescaling is None: + return + + if 'centres' in output: + output['centres'] = [[x * rescaling for x in cell] + for cell in output['centres']] + if 'radii' in output: + output['radii'] = [[r * rescaling for r in cell] + for cell in output['radii']] + has_coords = {'centres', 'angles', 'radii'}.issubset(output.keys()) + if has_coords and ('edgemasks' in output or 'masks' in output): + edgemasks = np.stack([ + draw_radial(np.array(r), np.array(a), c, outshape, + cartesian_spline=cartesian_spline) + for r, a, c in zip(output['radii'], output['angles'], + output['centres']) + ]) + _0xy = (0,) + outshape + if 'edgemasks' in output: + if output['edgemasks'].shape[0] == 0: + output['edgemasks'] = np.zeros(_0xy, dtype=bool) + elif has_coords: + output['edgemasks'] = edgemasks + else: + output['edgemasks'] = transform.resize( + output['edgemasks'], (len(output['edgemasks']),) + outshape, order=0) + if 'masks' in output: + if output['masks'].shape[0] == 0: + output['masks'] = np.zeros(_0xy, dtype=bool) + elif has_coords: + output['masks'] = binary_fill_holes( + edgemasks, diamond(1)[None, ...]) + else: + output['masks'] = transform.resize( + output['masks'], (len(output['masks']),) + outshape, order=0) + if 'preds' in output: + output['preds'] = transform.resize( + output['preds'], (len(output['preds']),) + outshape, order=1) + if 'volumes' in output: + vscaling = rescaling ** 3 + output['volumes'] = [v * vscaling for v in output['volumes']] + + def _segment(segmenter, cnn_output, refine_outlines, @@ -166,10 +337,18 @@ def _track(tracker, if logger is None: logger = _get_logger() + if i_budneck is None: + p_budneck = np.zeros(seg['preds'].shape[1:]) + else: + p_budneck = seg['preds'][i_budneck] + if i_bud is None: + p_bud = np.zeros(seg['preds'].shape[1:]) + else: + p_bud = seg['preds'][i_bud] try: tracking = tracker.step_trackers(seg['masks'], - seg['preds'][i_budneck], - seg['preds'][i_bud], + p_budneck, + p_bud, state=state, assign_mothers=assign_mothers, return_baprobs=return_baprobs) @@ -184,10 +363,10 @@ def _track(tracker, seg['masks'].transpose((1, 2, 0)).astype('uint8'), fprefix + '_masks.png') save_tiled_image( - np.uint16((2**16 - 1) * seg['preds'][i_budneck, :, :, None]), + np.uint16((2**16 - 1) * p_budneck[..., None]), fprefix + '_budneck_pred.png') save_tiled_image( - np.uint16((2**16 - 1) * seg['preds'][i_bud, :, :, None]), + np.uint16((2**16 - 1) * p_bud[..., None]), fprefix + '_bud_pred.png') with open(fprefix + '_state.pkl', 'wb') as f: pickle.dump(state, f) @@ -250,6 +429,8 @@ def _segment_and_track(segmenter, state, i_budneck, i_bud, + rescaling, + outshape, refine_outlines, yield_volumes, yield_edgemasks, @@ -274,40 +455,53 @@ def _segment_and_track(segmenter, error_dump_dir, suppress_errors, logger=logger) - return _track(tracker, - segout, - state, - i_budneck, - i_bud, - assign_mothers, - return_baprobs, - yield_edgemasks, - yield_next, - error_dump_dir, - suppress_errors, - logger=logger) + trackout = _track(tracker, + segout, + state, + i_budneck, + i_bud, + assign_mothers, + return_baprobs, + yield_edgemasks, + yield_next, + error_dump_dir, + suppress_errors, + logger=logger) + if yield_next: + segout, state = trackout + _rescale_output(segout, rescaling, outshape, + segmenter.params.cartesian_spline) + return segout, state + else: + _rescale_output(trackout, rescaling, outshape, + segmenter.params.cartesian_spline) + return trackout -def _segment_and_track_parallel(segmenter, tracker, flattener, morph_preds, - tracker_states, refine_outlines, - yield_volumes, yield_edgemasks, - clogging_thresh, assign_mothers, - return_baprobs, yield_next, njobs, - error_dump_dir, suppress_errors): - # logger = _get_logger() +def _segment_and_track_parallel(segmenter, tracker, flattener, morph_preds, + tracker_states, rescaling, outshape, + refine_outlines, yield_volumes, + yield_edgemasks, clogging_thresh, + assign_mothers, return_baprobs, yield_next, + njobs, error_dump_dir, suppress_errors): tnames = flattener.names() - i_budneck = tnames.index('bud_neck') bud_target = 'sml_fill' if 'sml_fill' in tnames else 'sml_inte' - i_bud = tnames.index(bud_target) + assign_buds = 'bud_neck' in tnames and bud_target in tnames + if assign_buds: + i_budneck = tnames.index('bud_neck') + i_bud = tnames.index(bud_target) + else: + i_budneck = None + i_bud = None # Run segmentation and tracking in parallel from joblib import Parallel, delayed return Parallel(n_jobs=njobs, mmap_mode='c')( delayed(_segment_and_track) (segmenter, tracker, cnn_output, state, i_budneck, i_bud, - refine_outlines, yield_volumes, yield_edgemasks, clogging_thresh, - assign_mothers, return_baprobs, yield_next, error_dump_dir, - suppress_errors) + rescaling, outshape, refine_outlines, yield_volumes, yield_edgemasks, + clogging_thresh, assign_mothers, return_baprobs, yield_next, + error_dump_dir, suppress_errors) for cnn_output, state in zip(morph_preds, tracker_states)) diff --git a/python/baby/crawler.py b/python/baby/crawler.py index 0929f140d1285c8595f147d1242b11562c53dec5..328e9fe3ffbc917fbd071349d9d3afbdc35b3409 100644 --- a/python/baby/crawler.py +++ b/python/baby/crawler.py @@ -1,12 +1,14 @@ # If you publish results that make use of this software or the Birth Annotator # for Budding Yeast algorithm, please cite: -# Julian M J Pietsch, Alán Muñoz, Diane Adjavon, Ivan B N Clark, Peter S -# Swain, 2021, Birth Annotator for Budding Yeast (in preparation). +# Pietsch, J.M.J., Muñoz, A.F., Adjavon, D.-Y.A., Farquhar, I., Clark, I.B.N., +# and Swain, P.S. (2023). Determining growth rates from bright-field images of +# budding cells through identifying overlaps. eLife. 12:e79812. +# https://doi.org/10.7554/eLife.79812 # # # The MIT License (MIT) # -# Copyright (c) Julian Pietsch, Alán Muñoz and Diane Adjavon 2021 +# Copyright (c) Julian Pietsch, Alán Muñoz and Diane Adjavon 2023 # # Permission is hereby granted, free of charge, to any person obtaining a copy # of this software and associated documentation files (the "Software"), to @@ -47,6 +49,7 @@ class BabyCrawler(object): def step(self, bf_img_batch, + pixel_size=None, with_edgemasks=False, assign_mothers=False, return_baprobs=False, @@ -92,20 +95,24 @@ class BabyCrawler(object): output = [] if parallel: + kwargs = {k: v for k, v in kwargs.items() if k in {'njobs'}} seg_trk_gen = self.baby_brain.segment_and_track_parallel( bf_img_batch, tracker_states=self.tracker_states, yield_next=True, + pixel_size=pixel_size, yield_edgemasks=with_edgemasks, yield_volumes=with_volumes, assign_mothers=assign_mothers, return_baprobs=return_baprobs, - refine_outlines=refine_outlines) + refine_outlines=refine_outlines, + **kwargs) else: seg_trk_gen = self.baby_brain.segment_and_track( bf_img_batch, tracker_states=self.tracker_states, yield_next=True, + pixel_size=pixel_size, yield_edgemasks=with_edgemasks, yield_volumes=with_volumes, assign_mothers=assign_mothers, diff --git a/python/baby/errors.py b/python/baby/errors.py index c53d3d55b58777d7cee435bc0edd2113b8ae6efc..ce052a9181a054a5836767d351b1dd7e95fb062e 100644 --- a/python/baby/errors.py +++ b/python/baby/errors.py @@ -1,12 +1,14 @@ # If you publish results that make use of this software or the Birth Annotator # for Budding Yeast algorithm, please cite: -# Julian M J Pietsch, Alán Muñoz, Diane Adjavon, Ivan B N Clark, Peter S -# Swain, 2021, Birth Annotator for Budding Yeast (in preparation). +# Pietsch, J.M.J., Muñoz, A.F., Adjavon, D.-Y.A., Farquhar, I., Clark, I.B.N., +# and Swain, P.S. (2023). Determining growth rates from bright-field images of +# budding cells through identifying overlaps. eLife. 12:e79812. +# https://doi.org/10.7554/eLife.79812 # # # The MIT License (MIT) # -# Copyright (c) Julian Pietsch, Alán Muñoz and Diane Adjavon 2021 +# Copyright (c) Julian Pietsch, Alán Muñoz and Diane Adjavon 2023 # # Permission is hereby granted, free of charge, to any person obtaining a copy # of this software and associated documentation files (the "Software"), to @@ -56,3 +58,6 @@ class BadOutput(Exception): class Clogging(Exception): pass + +class BadModel(Exception): + pass diff --git a/python/baby/generator.py b/python/baby/generator.py index 5e6fa8c48aa7535fa0ac667af57ef7bec182e8a5..f3e2e975dc36e83ec765b3f5c3dc8c25a6cb4a7f 100644 --- a/python/baby/generator.py +++ b/python/baby/generator.py @@ -1,12 +1,14 @@ # If you publish results that make use of this software or the Birth Annotator # for Budding Yeast algorithm, please cite: -# Julian M J Pietsch, Alán Muñoz, Diane Adjavon, Ivan B N Clark, Peter S -# Swain, 2021, Birth Annotator for Budding Yeast (in preparation). +# Pietsch, J.M.J., Muñoz, A.F., Adjavon, D.-Y.A., Farquhar, I., Clark, I.B.N., +# and Swain, P.S. (2023). Determining growth rates from bright-field images of +# budding cells through identifying overlaps. eLife. 12:e79812. +# https://doi.org/10.7554/eLife.79812 # # # The MIT License (MIT) # -# Copyright (c) Julian Pietsch, Alán Muñoz and Diane Adjavon 2021 +# Copyright (c) Julian Pietsch, Alán Muñoz and Diane Adjavon 2023 # # Permission is hereby granted, free of charge, to any person obtaining a copy # of this software and associated documentation files (the "Software"), to @@ -25,15 +27,21 @@ # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING # FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS # IN THE SOFTWARE. +import json from pathlib import Path +from contextlib import contextmanager import numpy as np from functools import namedtuple -from tensorflow.python.keras.utils.data_utils import Sequence +from itertools import repeat +from tensorflow.keras.utils import Sequence from tqdm import tqdm +from matplotlib import pyplot as plt +from PIL import Image from .io import load_tiled_image from .augmentation import Augmenter from .preprocessing import standard_norm +from .visualise import colour_segstack # Following are used by ImgLblInMem and should be deprecated from .preprocessing import robust_norm as preprocess_brightfield @@ -44,7 +52,8 @@ ImageLabelShapes = namedtuple('ImageLabelShapes', ('input', 'output')) class ImageLabel(Sequence): def __init__(self, paths, batch_size, aug, preprocess=None, - in_memory=False): + in_memory=False, balanced_sampling=False, + use_sample_weights=False, n_jobs=4): """Generator for training image-label pairs. Arguments: @@ -55,12 +64,15 @@ class ImageLabel(Sequence): preprocess: a single callable or tuple of callables (one for each file of pair); None specifies the default `standard_norm` in_memory: whether or not to load all images into memory + balanced_sampling: whether to increase sampling of images based on + their shape relative to the output crop """ self.batch_size = batch_size assert callable(aug), '"aug" must be a callable' self.aug = aug + self.n_jobs = n_jobs # Check that all required images exist self.paths = [(Path(img), Path(lbl)) for img, lbl in paths] @@ -84,29 +96,98 @@ class ImageLabel(Sequence): [ppf(*load_tiled_image(img)) for ppf, img in zip(self.preprocess, imgs)] for imgs in tqdm(self.paths) ] + + self.balanced_sampling = balanced_sampling + self.use_sample_weights = use_sample_weights # Initialise ordering self.on_epoch_end() def __len__(self): - return int(np.ceil(len(self.paths) / float(self.batch_size))) + return int(np.ceil(sum(self.nsamples) / float(self.batch_size))) @property def shapes(self): - if not hasattr(self, '_shapes') or not self._shapes: - if len(self.paths) == 0: - return ImageLabelShapes(tuple(), tuple()) + if len(self.paths) == 0: + return ImageLabelShapes(tuple(), tuple()) + + img, lbl = self.get_by_index(0) + if type(lbl) == tuple and len(lbl) == 2: + lbl, _ = lbl + Nbatch = (self.batch_size,) + return ImageLabelShapes(Nbatch + img.shape, Nbatch + lbl.shape) + + def _collect_size_info(self): + self._rawshapes = [] + self._pixelsizes = [] + self._weights = [] + for imgfile, lblfile in self.paths: + img = Image.open(imgfile) + img_info = json.loads(img.info.get('Description', '{}')) + if 'tilesize' in img_info and 'ntiles' in img_info: + img_shape = tuple(img_info['tilesize']) + (img_info['ntiles'],) + else: + img_shape = img.size + (1,) + self._rawshapes.append(img_shape) + + lbl = Image.open(lblfile) + lbl_info = json.loads(lbl.info.get('Description', '{}')) + self._pixelsizes.append(lbl_info.get('pixel_size')) + self._weights.append(lbl_info.get('weight',1)) + self._weights = np.array(self._weights) - img, lbl = self.get_by_index(0) - Nbatch = (self.batch_size,) - self._shapes = ImageLabelShapes(Nbatch + img.shape, - Nbatch + lbl.shape) - return self._shapes + @property + def rawshapes(self): + if not hasattr(self, '_rawshapes') or self._rawshapes is None: + self._collect_size_info() + return self._rawshapes + + @property + def pixelsizes(self): + if not hasattr(self, '_pixelsizes') or self._pixelsizes is None: + self._collect_size_info() + return self._pixelsizes + + @property + def weights(self): + if not hasattr(self, '_weights') or self._weights is None: + self._collect_size_info() + return self._weights + + @property + def target_pixel_size(self): + val = None + if hasattr(self.aug, 'target_pixel_size'): + val = self.aug.target_pixel_size + return val + + @property + def nsamples(self): + if ~self.balanced_sampling: + return np.ones(len(self.rawshapes), dtype='int').tolist() + + aug_shape = self.shapes.input[1:] + usepxsz = self.target_pixel_size is not None + if usepxsz: + aug_size = np.array(aug_shape, dtype=float) + aug_size[:2] *= self.target_pixel_size + nsamples = [] + for in_shape, pxsz in zip(self.rawshapes, self.pixelsizes): + if pxsz and usepxsz: + in_size = np.array(in_shape, dtype=float) + in_size[:2] *= pxsz + szratio = np.floor_divide(in_size, aug_size) + else: + szratio = np.floor_divide(in_shape, aug_shape) + szratio = szratio.astype(int) + nsamples.append(np.prod(np.maximum(szratio, 1))) + return nsamples def on_epoch_end(self): # Shuffle samples for next epoch - Nsamples = len(self.paths) - self.ordering = np.random.choice(Nsamples, Nsamples, replace=False) + self.ordering = np.repeat(np.arange(len(self.nsamples)), + self.nsamples) + np.random.shuffle(self.ordering) @property def n_pairs(self): @@ -124,24 +205,166 @@ class ImageLabel(Sequence): else: return aug(img, lbl) + def parallel_get_indices(self, inds, n_jobs=None): + if n_jobs is None: + n_jobs = self.n_jobs + passthrough = lambda img, lbl: (img, lbl) + img_lbl_pairs = [self.get_by_index(i, aug=passthrough) for i in inds] + from joblib import Parallel, delayed + return Parallel(n_jobs=n_jobs)( + delayed(self.aug)(img, lbl) for img, lbl in img_lbl_pairs) + def __getitem__(self, idx): Nbatch = self.batch_size current_batch = self.ordering[idx * Nbatch:(idx + 1) * Nbatch] - img_batch = [] - lbl_batch = [] - - for i in current_batch: - img, lbl = self.get_by_index(i) - lbl = np.dsplit(lbl, lbl.shape[2]) + if self.n_jobs > 1: + img_batch, lbl_batch = zip(*self.parallel_get_indices(current_batch)) + else: + img_batch, lbl_batch = zip(*[self.get_by_index(i) for i in + current_batch]) - img_batch.append(img) - lbl_batch.append(lbl) + lbl_batch = [np.dsplit(lbl, lbl.shape[2]) for lbl in lbl_batch] img_batch = np.array(img_batch) lbl_batch = [np.array(lw) for lw in zip(*lbl_batch)] - return img_batch, lbl_batch + if self.use_sample_weights: + return img_batch, lbl_batch, self.weights[current_batch] + else: + return img_batch, lbl_batch + + def plot_sample(self, i=0, figsize=3): + """Plot a sample batch from the generator + + This function assumes that the assigned Augmenter produces label + images that can be concatenated along a new axis. + """ + img_batch, lbl_batch = self[i][:2] + lbl_batch = np.concatenate(lbl_batch, axis=3) + + n_sections = img_batch.shape[3] + n_targets = lbl_batch.shape[3] + + target_names = repeat(None, n_targets) + edge_inds = None + if hasattr(self.aug, 'targetgenfunc'): + if hasattr(self.aug.targetgenfunc, 'names'): + target_names = self.aug.targetgenfunc.names() + if hasattr(self.aug.targetgenfunc, 'targets'): + edge_inds = np.flatnonzero( + [t.prop == 'edge' for t in self.aug.targetgenfunc.targets]) + + ncol = len(img_batch) + nrow = n_sections + n_targets + fig, axs = plt.subplots(nrow, ncol, + figsize=(figsize * ncol, figsize * nrow)) + + # Plot img sections first... + for axrow, section in zip(axs[:n_sections], + np.split(img_batch, n_sections, axis=3)): + for ax, img, lbl in zip(axrow, section, lbl_batch): + ax.imshow(img, cmap='gray') + if edge_inds is not None: + ax.imshow(colour_segstack(lbl[..., edge_inds], dw=True)) + ax.grid(False) + ax.set(xticks=[], yticks=[]) + + # ...then plot targets + for axrow, target, name in zip(axs[n_sections:], + np.split(lbl_batch, n_targets, axis=3), + target_names): + for ax, lbl in zip(axrow, target): + ax.imshow(lbl, cmap='gray') + ax.grid(False) + ax.set(xticks=[], yticks=[]) + if name is not None: + ax.set_title(name) + + return fig, axs + + + +@contextmanager +def augmented_generator(gen: ImageLabel, aug: Augmenter): + # Save the previous augmenter if any + saved_aug = gen.aug + gen.aug = aug + try: + yield gen + # Todo: add except otherwise there might be an issue of there is an error? + finally: + gen.aug = saved_aug + + +class AugmentedGenerator(Sequence): + """Wraps a generator with an alternative augmenter. + + Args: + gen (ImageLabel): Generator to wrap. + aug (augmentation.Augmenter): Augmenter to use. + """ + def __init__(self, gen, aug): + self._gen = gen + self._aug = aug + self.on_epoch_end() + + def __len__(self): + with augmented_generator(self._gen, self._aug) as g: + return len(g) + + @property + def batch_size(self): + return self._gen.batch_size + + @property + def shapes(self): + with augmented_generator(self._gen, self._aug) as g: + return g.shapes + + @property + def nsamples(self): + with augmented_generator(self._gen, self._aug) as g: + return g.nsamples + + @property + def rawshapes(self): + return self._gen.rawshapes + + @property + def pixelsizes(self): + return self._gen.pixelsizes + + def on_epoch_end(self): + with augmented_generator(self._gen, self._aug) as g: + g.on_epoch_end() + + @property + def ordering(self): + return self._gen.ordering + + @property + def n_pairs(self): + return self._gen.n_pairs + + def get_by_index(self, i, aug=None): + if aug is None: + with augmented_generator(self._gen, self._aug) as g: + return g.get_by_index(i) + else: + return self._gen.get_by_index(i, aug=aug) + + def parallel_get_indices(self, inds, **kwargs): + with augmented_generator(self._gen, self._aug) as g: + return g.parallel_get_indices(inds, **kwargs) + + def __getitem__(self, idx): + with augmented_generator(self._gen, self._aug) as g: + return g[idx] + + def plot_sample(self, *args, **kwargs): + with augmented_generator(self._gen, self._aug) as g: + return g.plot_sample(*args, **kwargs) class ImgLblInMem(Sequence): diff --git a/python/baby/io.py b/python/baby/io.py index f468bf451443099fa2f6262e51ba160b3492e773..858977ca056f084f4caa2ee32dd1ec72ec6f4571 100644 --- a/python/baby/io.py +++ b/python/baby/io.py @@ -1,12 +1,14 @@ # If you publish results that make use of this software or the Birth Annotator # for Budding Yeast algorithm, please cite: -# Julian M J Pietsch, Alán Muñoz, Diane Adjavon, Ivan B N Clark, Peter S -# Swain, 2021, Birth Annotator for Budding Yeast (in preparation). +# Pietsch, J.M.J., Muñoz, A.F., Adjavon, D.-Y.A., Farquhar, I., Clark, I.B.N., +# and Swain, P.S. (2023). Determining growth rates from bright-field images of +# budding cells through identifying overlaps. eLife. 12:e79812. +# https://doi.org/10.7554/eLife.79812 # # # The MIT License (MIT) # -# Copyright (c) Julian Pietsch, Alán Muñoz and Diane Adjavon 2021 +# Copyright (c) Julian Pietsch, Alán Muñoz and Diane Adjavon 2023 # # Permission is hereby granted, free of charge, to any person obtaining a copy # of this software and associated documentation files (the "Software"), to @@ -29,7 +31,7 @@ import json import re from pathlib import Path -from typing import Union +from typing import Union, Tuple from fnmatch import translate as glob_to_re from os import walk from itertools import groupby, chain, repeat @@ -37,17 +39,38 @@ from collections import namedtuple, Counter import numpy as np import random import pandas as pd -from PIL.PngImagePlugin import PngInfo -from imageio import imread, imwrite from sklearn.model_selection import train_test_split +from PIL.PngImagePlugin import PngInfo +LEGACY_IIO = False +try: + from imageio import v3 as iio +except ImportError: + from imageio import imread, imwrite + LEGACY_IIO = True + from .errors import LayoutError, UnpairedImagesError from .utils import PathEncoder def load_tiled_image(filename): - tImg = imread(filename) - info = json.loads(tImg.meta.get('Description', '{}')) + if LEGACY_IIO: + tImg = imread(filename) + info = tImg.meta + else: + tImg = iio.imread(filename) + info = iio.immeta(filename) + # Due to a limitation in Pillow, 16-bit (uint16) PNG images are + # currently loaded as int32. This will apparently change in a future + # version of Pillow, but for the moment, manually force conversion to + # uint16. It is still possible to write 16-bit PNG images, and in my + # tests, images of dtype int32 are saved as uint16 (as are images of + # dtype uint32). Other dtypes (e.g., bool and uint8) load and write as + # expected. + if str(filename).lower().endswith('.png') and tImg.dtype == 'int32': + tImg = tImg.astype('uint16') + + info = json.loads(info.get('Description', '{}')) tw, th = info.get('tilesize', tImg.shape[0:2]) nt = info.get('ntiles', 1) nr, nc = info.get('layout', (1, 1)) @@ -61,7 +84,7 @@ def load_tiled_image(filename): return img, info -def save_tiled_image(img, filename, info={}, layout=None): +def save_tiled_image(img, filename, info={}, layout=None, fill=0): if layout is not None and len(layout) != 2: raise LayoutError('"layout" must a 2-tuple') @@ -82,7 +105,7 @@ def save_tiled_image(img, filename, info={}, layout=None): info['layout'] = (nr, nc) nc_final_row = np.mod(nt, nc) - tImg = np.zeros((tw * nr, th * nc), dtype=img.dtype) + tImg = np.full((tw * nr, th * nc), fill, dtype=img.dtype) for i in range(nr): i_nc = nc_final_row if i + 1 == nr and nc_final_row > 0 else nc for j in range(i_nc): @@ -91,8 +114,11 @@ def save_tiled_image(img, filename, info={}, layout=None): meta = PngInfo() meta.add_text('Description', json.dumps(info)) - imwrite(filename, tImg, format='png', pnginfo=meta, - prefer_uint8=tImg.dtype != 'uint16') + if LEGACY_IIO: + imwrite(filename, tImg, format='png', pnginfo=meta, + prefer_uint8=False) + else: + iio.imwrite(filename, tImg, extension='.png', pnginfo=meta) def load_paired_images(filenames, typeA='Brightfield', typeB='segoutlines'): @@ -274,7 +300,7 @@ class TrainValPairs(object): # case insensitive manner re_img = re.compile(r'^(.*)' + img_suffix + r'$', re.IGNORECASE) re_lbl = re.compile(r'^(.*)' + lbl_suffix + r'$', re.IGNORECASE) - png_files = sorted(Path(base_dir).rglob('*.png')) + png_files = sorted(Path(base_dir).resolve().rglob('*.png')) matches = [(re_img.search(f.stem), re_lbl.search(f.stem), f) for f in png_files] matches = [('img', im, f) if im else ('lbl', lm, f) @@ -329,6 +355,14 @@ class TrainValPairs(object): len(self.training), len(self.validation)) +IMAGE_INFO_GROUP_BY_MAP = { + 'experimentID': str, + 'position': int, + 'trap': int, + 'tp': int +} + + class TrainValTestPairs(object): @property @@ -502,6 +536,42 @@ class TrainValTestPairs(object): val_size=0.2, test_size=0.2, group_by=('experimentID', 'position', 'trap')): + """Search a directory for image/label pairs to add + + Images are assumed to be in PNG format (i.e., they must have extension + .png). The images can be annotated in the 'Description' meta data slot + using a JSON-encoded dictionary (see :py:function:`save_tiled_image` + and :py:function:`load_tiled_image`). + + To ensure that highly similar image/label pairs (e.g., from + consecutive time points) are not unfairly separated into training and + either the validation or test sets, pairs are allocated according to + their 'group', which is defined by the ``group_by`` parameter. This + requires the images to include meta data annotations as described + above (valid keys are currently limited to 'experimentID', 'position', + 'trap' and 'tp'). + + Args: + base_dir: base directory in which to search for images. All + subfolders are recursively searched. + img_suffix (str): the suffix (before .png extension) of all image + files. + lbl_suffix (str): the suffix (before .png extension) of all label + files. + val_size (float): fraction of files to split into validation set. + test_size (float): fraction of files to split into test set. + group_by (Union[str, Tuple[str]]): perform the split according to + groups defined by these tokens as found in the meta data of + each image. + """ + + # Check that all group_by values are valid + if type(group_by) == str: + group_by = (group_by,) + if not all((t in IMAGE_INFO_GROUP_BY_MAP for t in group_by)): + raise BadParam('group_by must be one of {}'.format( + ', '.join(f'"{k}"' for k in IMAGE_INFO_GROUP_BY_MAP.keys()))) + only_outlines = False if img_suffix is None: img_suffix='segoutlines' @@ -513,7 +583,7 @@ class TrainValTestPairs(object): # case insensitive manner re_img = re.compile(r'^(.*)' + img_suffix + r'$', re.IGNORECASE) re_lbl = re.compile(r'^(.*)' + lbl_suffix + r'$', re.IGNORECASE) - png_files = sorted(Path(base_dir).rglob('*.png')) + png_files = sorted(Path(base_dir).resolve().rglob('*.png')) matches = [(re_img.search(f.stem), re_lbl.search(f.stem), f) for f in png_files] matches = [('img', im, f) if im else ('lbl', lm, f) @@ -541,14 +611,18 @@ class TrainValTestPairs(object): if len(pairs) == 0: return - # Choose a split that ensures separation by group keys and avoids, - # e.g., splitting same cell but different time points + # Choose a split that ensures separation by group keys and can avoid, + # e.g., splitting same cell but different time points. info = [ json.loads(imread(l).meta.get('Description', '{}')) for _, l in pairs ] + # Collect grouping variables to split pairs up according to their + # group. NB: we force everything to be a string in case variables are + # loaded inconsistently from the JSON meta data. Any missing or False + # values are annotated with a unique label taken from enumeration: pair_groups = [ - tuple(i.get(f, 'missing_' + str(e)) + tuple(str(i.get(f) or 'missing_' + str(e)) for f in group_by) for e, i in enumerate(info) ] @@ -563,11 +637,16 @@ class TrainValTestPairs(object): unique_groups, [int(npairs*train_size), int(npairs * (1-test_size))]) - reformat = lambda exp, pos, trap : (exp, int(pos), int(trap)) + # BELOW CODE REMOVED SINCE I IT DOES NOT ACHIEVE WHAT I THINK IS THE + # INTENDED PURPOSE OF ORDERING THE OUTPUT + # reformat = lambda exp, pos, trap : (exp, int(pos), int(trap)) + # train_groups = set([reformat(*t) for t in train_groups]) + # val_groups = set([reformat(*t) for t in val_groups]) + # test_groups = set([reformat(*t) for t in test_groups]) - train_groups = set([reformat(*t) for t in train_groups]) - val_groups = set([reformat(*t) for t in val_groups]) - test_groups = set([reformat(*t) for t in test_groups]) + train_groups = {tuple(x) for x in train_groups} + val_groups = {tuple(x) for x in val_groups} + test_groups = {tuple(x) for x in test_groups} # Add new pairs to the existing train-val split self.training += [p for p, g in zip(pairs, pair_groups) if diff --git a/python/baby/layers.py b/python/baby/layers.py index cbf71a0f36fbbcb6afc4aa5f7491f63d367dffea..3d9c84de6074dadee9c816b41610915ede1040ff 100644 --- a/python/baby/layers.py +++ b/python/baby/layers.py @@ -1,12 +1,14 @@ # If you publish results that make use of this software or the Birth Annotator # for Budding Yeast algorithm, please cite: -# Julian M J Pietsch, Alán Muñoz, Diane Adjavon, Ivan B N Clark, Peter S -# Swain, 2021, Birth Annotator for Budding Yeast (in preparation). +# Pietsch, J.M.J., Muñoz, A.F., Adjavon, D.-Y.A., Farquhar, I., Clark, I.B.N., +# and Swain, P.S. (2023). Determining growth rates from bright-field images of +# budding cells through identifying overlaps. eLife. 12:e79812. +# https://doi.org/10.7554/eLife.79812 # # # The MIT License (MIT) # -# Copyright (c) Julian Pietsch, Alán Muñoz and Diane Adjavon 2021 +# Copyright (c) Julian Pietsch, Alán Muñoz and Diane Adjavon 2023 # # Permission is hereby granted, free of charge, to any person obtaining a copy # of this software and associated documentation files (the "Software"), to @@ -25,97 +27,207 @@ # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING # FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS # IN THE SOFTWARE. -from tensorflow.python.keras.layers import ( +import numpy as np +from itertools import repeat +from string import ascii_lowercase as alphabet +from tensorflow.keras.layers import ( Conv2D, BatchNormalization, Activation, MaxPooling2D, - Conv2DTranspose, concatenate, Dropout + Conv2DTranspose, concatenate, Dropout, LayerNormalization, + add, DepthwiseConv2D ) ### U-NET layers ### -def conv_block(input_tensor, num_filters, stage, batchnorm=True, dropout=0.): - postfix = '_{}a'.format(stage + 1) - encoder = Conv2D(num_filters, (3, 3), padding='same', - name='enc_conv' + postfix)(input_tensor) +def res_block(input_tensor, nfilters, kernel=3, expand_ratio=4, drop=0., + activation='swish', prefix='', suffix='', use_bias=False, + init='glorot_uniform', pre_activate=False, post_activate=False, + convnext=False, **kwargs): + """A flexible residual block + + Primarily inspired from EfficientNet block, but without the squeeze and + excite block since that is unlikely to be important for the highly + application-specific models we typically train. + + If expand_ratio < 1 and post_activate = True, it is a special case of the + ResNeXt block where the number of groups equals the number of features. + + If convnext=True, then expansion occurs after depth-wise convolution, + batch normalisation is omitted in favour of a single layer normalisation + after depth-wise convolution and there is only a single activation in the + expanded space. + """ + + tmplt = '_'.join([e for e in (prefix, '{}', suffix) if len(e) > 0]) + lbl = lambda l: tmplt.format(l) + + x = input_tensor + + if not convnext: + if pre_activate: + x = BatchNormalization(name=lbl('pre_bn'))(x) + x = Activation(activation, name=lbl('pre_act'))(x) + + # Expand from bottleneck by a factor of expand_ratio + x = Conv2D(x.shape[-1] * expand_ratio, 1, padding='same', + use_bias=use_bias, + kernel_initializer=init, name=lbl('expand'))(x) + x = BatchNormalization(name=lbl('expand_bn'))(x) + x = Activation(activation, name=lbl('expand_act'))(x) + + # Mix in spatial dimension + x = DepthwiseConv2D(kernel, padding='same', use_bias=use_bias, + depthwise_initializer=init, name=lbl('mixXY'))(x) + if convnext: + x = LayerNormalization(epsilon=1e-6, name=lbl('mixXY_ln'))(x) + # Expand from bottleneck + x = Conv2D(x.shape[-1] * expand_ratio, 1, padding='same', + use_bias=use_bias, kernel_initializer=init, + name=lbl('expand'))(x) + x = Activation(activation, name=lbl('expand_act'))(x) + else: + # Do standard norm/activation + x = BatchNormalization(name=lbl('mixXY_bn'))(x) + x = Activation(activation, name=lbl('mixXY_act'))(x) + + # Mix across features + x = Conv2D(nfilters, 1, padding='same', use_bias=use_bias, + kernel_initializer=init, name=lbl('proj'))(x) + if not pre_activate and not convnext: + x = BatchNormalization(name=lbl('proj_bn'))(x) + + # Stochastically drop this entire block + if drop > 0: + x = Dropout(drop, noise_shape=(None, 1, 1, 1), name=lbl('drop'))(x) + + # Sum at bottleneck (though, will not be bottleneck if expand_ratio > 1) + if input_tensor.shape[-1] != nfilters: + input_tensor = Conv2D(nfilters, 1, padding='same', use_bias=use_bias, + kernel_initializer=init, name=lbl('input_proj'))(input_tensor) + x = add([input_tensor, x], name=lbl('add')) + + # Post activation to match ResNeXt models + if post_activate and not pre_activate and not convnext: + x = Activation(activation, name=lbl('add_act'))(x) + + return x + + +def conv_block(input_tensor, nfilters, kernel=3, prefix='', suffix='', + batchnorm=True, dropout=0., activation='relu', init='glorot_uniform', + **kwargs): + """Standard convolution with batch norm and activation + """ + + tmplt = '_'.join([e for e in (prefix, '{}', suffix) if len(e) > 0]) + lbl = lambda l: tmplt.format(l) + + x = input_tensor + x = Conv2D(nfilters, kernel, padding='same', kernel_initializer=init, + name=lbl('conv'))(x) if batchnorm: - encoder = BatchNormalization(name='enc_bn' + postfix)(encoder) - encoder = Activation('relu', name='enc_act' + postfix)(encoder) + x = BatchNormalization(name=lbl('conv_bn'))(x) + x = Activation(activation, name=lbl('conv_act'))(x) if dropout > 0: - encoder = Dropout(dropout, name='enc_dropout' + postfix)(encoder) + x = Dropout(dropout, name=lbl('conv_dropout'))(x) + + return x - postfix = '_{}b'.format(stage + 1) - encoder = Conv2D(num_filters, (3, 3), padding='same', - name='enc_conv' + postfix)(encoder) - if batchnorm: - encoder = BatchNormalization(name='enc_bn' + postfix)(encoder) - encoder = Activation('relu', name='enc_act' + postfix)(encoder) - return encoder +def encoder_block(input_tensor, nfilters, stage, repeats=2, block=conv_block, + dropout=0., drop=repeat(0.), batchnorm=True, conv_pool=False, + init='glorot_uniform', **kwargs): + x = input_tensor + for i in range(repeats - 1): + x = block(x, nfilters, dropout=dropout, batchnorm=batchnorm, + prefix='enc', init=init, suffix=f'{stage + 1}{alphabet[i]}', + **kwargs) + encoder = block(x, nfilters, batchnorm=batchnorm, init=init, prefix='enc', + suffix=f'{stage + 1}{alphabet[repeats - 1]}', **kwargs) + if conv_pool: + encoder_pool = DepthwiseConv2D(2, strides=2, padding='same', + kernel_initializer=init, name=f'down_{stage + 1}')(encoder) + if batchnorm: + encoder_pool = BatchNormalization( + name=f'down_bn_{stage + 1}')(encoder_pool) + else: + encoder_pool = MaxPooling2D( + (2, 2), strides=(2, 2), name=f'down_{stage + 1}')(encoder) -def encoder_block(input_tensor, num_filters, stage, batchnorm=True, - dropout=0.): - encoder = conv_block(input_tensor, num_filters, stage, - batchnorm=batchnorm, - dropout=dropout) - encoder_pool = MaxPooling2D( - (2, 2), strides=(2, 2), name='down_{}'.format(stage + 1))(encoder) return encoder_pool, encoder -def decoder_block(input_tensor, concat_tensor, num_filters, stage, - batchnorm=True, prename='', dropout=0.): - postfix = '_{}'.format(stage + 1) - decoder = Conv2DTranspose(num_filters, (2, 2), strides=(2, 2), padding='same', - name=prename + 'up' + postfix)(input_tensor) - decoder = concatenate([concat_tensor, decoder], axis=-1, - name=prename + 'skip' + postfix) - if batchnorm: - decoder = BatchNormalization(name=prename + 'up_bn' + postfix)(decoder) - decoder = Activation('relu', name=prename + 'up_act' + postfix)(decoder) - if dropout > 0: - decoder = Dropout(dropout, name='dec_dropout' + postfix)(decoder) +def decoder_block(input_tensor, skip_tensor, nfilters, stage, repeats=2, + block=conv_block, dropout=0., drop=repeat(0.), batchnorm=True, + prename='', init='glorot_uniform', residual_skip=False, + up_activate=True, activation='relu', **kwargs): - postfix = '_{}a'.format(stage + 1) - decoder = Conv2D(num_filters, (3, 3), padding='same', - name=prename + 'dec_conv' + postfix)(decoder) + x = input_tensor + x = Conv2DTranspose(nfilters, 2, strides=2, padding='same', + kernel_initializer=init, name=prename + f'up_{stage + 1}')(x) if batchnorm: - decoder = BatchNormalization(name=prename + 'dec_bn' + postfix)(decoder) - decoder = Activation('relu', name=prename + 'dec_act' + postfix)(decoder) - if dropout > 0: - decoder = Dropout(dropout, name='dec_dropout' + postfix)(decoder) + x = BatchNormalization(name=prename + f'up_bn_{stage + 1}')(x) + if up_activate: + x = Activation(activation, name=prename + f'up_act_{stage + 1}')(x) + if residual_skip: + x = add([skip_tensor, x], name=prename + f'up_skip_{stage + 1}') + else: + x = concatenate([skip_tensor, x], axis=-1, + name=prename + f'up_skip_{stage + 1}') + + drop = iter(drop) + for i in range(repeats - 1): + x = block(x, nfilters, dropout=dropout, drop=next(drop), + batchnorm=batchnorm, init=init, prefix='dec', + suffix=f'{stage + 1}{alphabet[i]}', **kwargs) + decoder = block(x, nfilters, drop=next(drop), batchnorm=batchnorm, + init=init, prefix='dec', + suffix=f'{stage + 1}{alphabet[repeats - 1]}', **kwargs) - postfix = '_{}b'.format(stage + 1) - decoder = Conv2D(num_filters, (3, 3), padding='same', - name=prename + 'dec_conv' + postfix)(decoder) - if batchnorm: - decoder = BatchNormalization(name=prename + 'dec_bn' + postfix)(decoder) - decoder = Activation('relu', name=prename + 'dec_act' + postfix)(decoder) - if dropout > 0: - decoder = Dropout(dropout, name='dec_dropout' + postfix)(decoder) return decoder -def unet_block(input_tensor, layer_sizes, batchnorm=True, dropout=0.): +def unet_block(input_tensor, layer_sizes, enc_repeats=2, dec_repeats=2, + drop=0., dropout=0., block=conv_block, stem=False, **kwargs): + + nlayers = len(layer_sizes) + + # Rate of stochastic block dropping increases linearly with depth + drop = iter(np.linspace(0, drop, nlayers * enc_repeats + (nlayers - 1) * + dec_repeats)) + + if stem: + x = input_tensor + x = Conv2D(layer_sizes[0], 3, padding='same', name='stem')(x) + if not kwargs.get('pre_activate', False): + x = BatchNormalization(name='stem_bn')(x) + x = Activation(kwargs.get('activation', 'relu'), name='stem_act')(x) + input_tensor = x + # Encoding upper_layer = input_tensor encoding_layers = [] - for i, num_filters in enumerate(layer_sizes[:-1]): - upper_layer, encoder = encoder_block(upper_layer, num_filters, i, - batchnorm=batchnorm, - dropout=dropout) + for i, nfilters in enumerate(layer_sizes[:-1]): + upper_layer, encoder = encoder_block(upper_layer, nfilters, i, + repeats=enc_repeats, dropout=dropout, drop=drop, block=block, + **kwargs) encoding_layers.append(encoder) # Centre - lower_layer = conv_block(upper_layer, layer_sizes[-1], len(layer_sizes) - 1, - batchnorm=batchnorm, dropout=dropout) + x = upper_layer + for i in range(enc_repeats - 1): + x = block(x, layer_sizes[-1], dropout=dropout, drop=next(drop), + prefix='enc', suffix=f'{nlayers}{alphabet[i]}', **kwargs) + lower_layer = block(x, layer_sizes[-1], drop=next(drop), prefix='enc', + suffix=f'{nlayers}{alphabet[enc_repeats - 1]}', **kwargs) # Decoding - for i, num_filters in reversed(list(enumerate(layer_sizes[:-1]))): - lower_layer = decoder_block(lower_layer, encoding_layers[i], - num_filters, i, batchnorm=batchnorm, - dropout=dropout) + for i, nfilters in reversed(list(enumerate(layer_sizes[:-1]))): + lower_layer = decoder_block(lower_layer, encoding_layers[i], nfilters, + i, repeats=dec_repeats, dropout=dropout, drop=drop, + block=block, **kwargs) return lower_layer diff --git a/python/baby/losses.py b/python/baby/losses.py index 18231b088fe1409ea046c411b75419a508c984af..cc8079e917e94862454f712bb9f884b9ccf5d11c 100644 --- a/python/baby/losses.py +++ b/python/baby/losses.py @@ -1,12 +1,14 @@ # If you publish results that make use of this software or the Birth Annotator # for Budding Yeast algorithm, please cite: -# Julian M J Pietsch, Alán Muñoz, Diane Adjavon, Ivan B N Clark, Peter S -# Swain, 2021, Birth Annotator for Budding Yeast (in preparation). +# Pietsch, J.M.J., Muñoz, A.F., Adjavon, D.-Y.A., Farquhar, I., Clark, I.B.N., +# and Swain, P.S. (2023). Determining growth rates from bright-field images of +# budding cells through identifying overlaps. eLife. 12:e79812. +# https://doi.org/10.7554/eLife.79812 # # # The MIT License (MIT) # -# Copyright (c) Julian Pietsch, Alán Muñoz and Diane Adjavon 2021 +# Copyright (c) Julian Pietsch, Alán Muñoz and Diane Adjavon 2023 # # Permission is hereby granted, free of charge, to any person obtaining a copy # of this software and associated documentation files (the "Software"), to @@ -26,7 +28,7 @@ # FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS # IN THE SOFTWARE. import tensorflow as tf -from tensorflow.python.keras.losses import binary_crossentropy +from tensorflow.keras.losses import binary_crossentropy ### LOSS FUNCTIONS ### @@ -37,9 +39,10 @@ def dice_coeff(y_true, y_pred): # Flatten y_true_f = tf.reshape(y_true, [-1]) y_pred_f = tf.reshape(y_pred, [-1]) - intersection = tf.reduce_sum(y_true_f * y_pred_f) + intersection = tf.reduce_sum(tf.boolean_mask(y_pred_f, y_true_f)) score = ((2. * intersection + smooth) / - (tf.reduce_sum(y_true_f) + tf.reduce_sum(y_pred_f) + smooth)) + (tf.math.count_nonzero(y_true_f, dtype=tf.dtypes.float32) + + tf.reduce_sum(y_pred_f) + smooth)) return score diff --git a/python/baby/models.py b/python/baby/models.py index ebf47ccd21e9a940ade15deb319bf1daa90c6365..9d3d87705ad81625b825cce063b3216796e2641b 100644 --- a/python/baby/models.py +++ b/python/baby/models.py @@ -1,12 +1,14 @@ # If you publish results that make use of this software or the Birth Annotator # for Budding Yeast algorithm, please cite: -# Julian M J Pietsch, Alán Muñoz, Diane Adjavon, Ivan B N Clark, Peter S -# Swain, 2021, Birth Annotator for Budding Yeast (in preparation). +# Pietsch, J.M.J., Muñoz, A.F., Adjavon, D.-Y.A., Farquhar, I., Clark, I.B.N., +# and Swain, P.S. (2023). Determining growth rates from bright-field images of +# budding cells through identifying overlaps. eLife. 12:e79812. +# https://doi.org/10.7554/eLife.79812 # # # The MIT License (MIT) # -# Copyright (c) Julian Pietsch, Alán Muñoz and Diane Adjavon 2021 +# Copyright (c) Julian Pietsch, Alán Muñoz and Diane Adjavon 2023 # # Permission is hereby granted, free of charge, to any person obtaining a copy # of this software and associated documentation files (the "Software"), to @@ -25,12 +27,13 @@ # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING # FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS # IN THE SOFTWARE. -from tensorflow.python.keras.models import Model -from tensorflow.python.keras.optimizers import Adam -from tensorflow.python.keras.layers import Input +from tensorflow.keras.models import Model +from tensorflow.keras.optimizers import Adam +from tensorflow.keras.layers import Input +from tensorflow.keras.initializers import VarianceScaling from .utils import named_obj -from .layers import msd_block, unet_block, make_outputs +from .layers import msd_block, unet_block, conv_block, res_block, make_outputs from .losses import bce_dice_loss, dice_coeff @@ -39,13 +42,18 @@ def named_model_fn(name): def wrap(f): @named_obj(name) - def model_fn(generator, flattener, weights={}, **kwargs): + def model_fn(generator, flattener, weights={}, use_adamw=False, **kwargs): weights = {n: weights.get(n, 1) for n in flattener.names()} inputs = Input(shape=generator.shapes.input[1:]) model = Model(inputs=[inputs], outputs=make_outputs(f(inputs, **kwargs), flattener.names())) - model.compile(optimizer=Adam(amsgrad=False), + if use_adamw: + from tensorflow_addons.optimizers import AdamW + optimizer = AdamW(weight_decay=0.00025) + else: + optimizer = Adam(amsgrad=False) + model.compile(optimizer=optimizer, metrics=[dice_coeff], loss=bce_dice_loss, loss_weights=weights) @@ -75,6 +83,46 @@ def unet(inputs, depth=4, layer_size=8, batchnorm=True, dropout=0.): dropout=dropout, batchnorm=batchnorm) +@named_model_fn('unet3') +def unet3(inputs, depth=3, layer_size=8, batchnorm=True, dropout=0.): + layer_sizes = [layer_size*(2**i) for i in range(depth)] + return unet_block(inputs, layer_sizes, + dropout=dropout, batchnorm=batchnorm) + + +@named_model_fn('unet_hyper') +def unet_hyper(inputs, width=64, depth=4, initializer='glorot_uniform', + block_type='conv', **kwargs): + layer_sizes = [width*(2**i) for i in range(depth)] + if initializer == 'variance_scaling': + initializer = VarianceScaling(2., mode='fan_out') + block_args = { + 'conv': dict(block=conv_block, stem=False), + 'effnet': dict(block=res_block, stem=True), + 'effnet-preact': dict(block=res_block, pre_activate=True, stem=True), + 'convnext': dict(block=res_block, stem=True, convnext=True) + }[block_type] + block_args.update(kwargs) + return unet_block(inputs, layer_sizes, init=initializer, **block_args) + + +@named_obj('unet_convnext') +def unet_convnext(generator, flattener, weights={}, width=8, depth=4, + kernel=7, enc_repeats=3, dec_repeats=2, expand_ratio=4, + activation='swish', initializer='variance_scaling', **kwargs): + """U-net model with ConvNeXt blocks and defaults + + With default parameters, produces a model of similar size to the default + unet model here (depth 4, width 8). + """ + return unet_hyper(generator, flattener, weights=weights, width=width, + depth=depth, kernel=kernel, enc_repeats=enc_repeats, + dec_repeats=dec_repeats, block_type='convnext', + expand_ratio=expand_ratio, activation=activation, + initializer=initializer, use_adamw=True, residual_skip=True, + conv_pool=True, up_activate=False, **kwargs) + + @named_model_fn('msd') def msd(inputs, depth=80, width=1, n_dilations=4, dilation=1, batchnorm=True): dilations = [dilation * (2 ** i) for i in range(n_dilations)] diff --git a/python/baby/models/I1_evolve_unet_4s_20200630.hdf5 b/python/baby/models/I1_evolve_unet_4s_20200630.hdf5 deleted file mode 100644 index cf4b862e5aa88bf1534edebf7fdc06eda316727f..0000000000000000000000000000000000000000 Binary files a/python/baby/models/I1_evolve_unet_4s_20200630.hdf5 and /dev/null differ diff --git a/python/baby/models/I1_evolve_unet_4s_20210302.hdf5 b/python/baby/models/I1_evolve_unet_4s_20210302.hdf5 deleted file mode 100644 index 6f5d000647bd1e235672b6fea4105bcce9289ee4..0000000000000000000000000000000000000000 Binary files a/python/baby/models/I1_evolve_unet_4s_20210302.hdf5 and /dev/null differ diff --git a/python/baby/models/I1_prime_unet_4s_20200630.hdf5 b/python/baby/models/I1_prime_unet_4s_20200630.hdf5 deleted file mode 100644 index f5f890a31e82e211dea7b4ae890e769f707c48dc..0000000000000000000000000000000000000000 Binary files a/python/baby/models/I1_prime_unet_4s_20200630.hdf5 and /dev/null differ diff --git a/python/baby/models/I1_prime_unet_4s_20210302.hdf5 b/python/baby/models/I1_prime_unet_4s_20210302.hdf5 deleted file mode 100644 index 5093dba134babbcb94c35df94f022305dbcbd24c..0000000000000000000000000000000000000000 Binary files a/python/baby/models/I1_prime_unet_4s_20210302.hdf5 and /dev/null differ diff --git a/python/baby/models/I3_evolve_unet_4s_20200903.hdf5 b/python/baby/models/I3_evolve_unet_4s_20200903.hdf5 deleted file mode 100644 index 785ba76d57a1ed9272f0d0a3a81dce63b0cd081e..0000000000000000000000000000000000000000 Binary files a/python/baby/models/I3_evolve_unet_4s_20200903.hdf5 and /dev/null differ diff --git a/python/baby/models/I3_evolve_unet_4s_20210215.hdf5 b/python/baby/models/I3_evolve_unet_4s_20210215.hdf5 deleted file mode 100644 index fede45d6909cb21be94c2940cf28803eeb7f7d85..0000000000000000000000000000000000000000 Binary files a/python/baby/models/I3_evolve_unet_4s_20210215.hdf5 and /dev/null differ diff --git a/python/baby/models/I3_prime_unet_4s_20200903.hdf5 b/python/baby/models/I3_prime_unet_4s_20200903.hdf5 deleted file mode 100644 index f4a1dba19ca8be6bf8351ae593226ac65bdf7d19..0000000000000000000000000000000000000000 Binary files a/python/baby/models/I3_prime_unet_4s_20200903.hdf5 and /dev/null differ diff --git a/python/baby/models/I3_prime_unet_4s_20210215.hdf5 b/python/baby/models/I3_prime_unet_4s_20210215.hdf5 deleted file mode 100644 index be7e266f30a0f03ba0c502655fc58ab809d483d4..0000000000000000000000000000000000000000 Binary files a/python/baby/models/I3_prime_unet_4s_20210215.hdf5 and /dev/null differ diff --git a/python/baby/models/I5_evolve_unet_4s_20210120.hdf5 b/python/baby/models/I5_evolve_unet_4s_20210120.hdf5 deleted file mode 100644 index f7b133276b73e3c30fd2b3bab81309e0e694addd..0000000000000000000000000000000000000000 Binary files a/python/baby/models/I5_evolve_unet_4s_20210120.hdf5 and /dev/null differ diff --git a/python/baby/models/I5_prime_unet_4s_20210120.hdf5 b/python/baby/models/I5_prime_unet_4s_20210120.hdf5 deleted file mode 100644 index 89afdb85d621332913a9fd1816a7d03c00a24277..0000000000000000000000000000000000000000 Binary files a/python/baby/models/I5_prime_unet_4s_20210120.hdf5 and /dev/null differ diff --git a/python/baby/models/ct_rf_20210125_9.pkl b/python/baby/models/ct_rf_20210125_9.pkl deleted file mode 100644 index a86a15b55c53f4a92a71e91c732ef32355ab6164..0000000000000000000000000000000000000000 Binary files a/python/baby/models/ct_rf_20210125_9.pkl and /dev/null differ diff --git a/python/baby/models/ct_rf_20210201_12.pkl b/python/baby/models/ct_rf_20210201_12.pkl deleted file mode 100644 index aff4247d0a4973186315b36f19ac6f1caf4c2057..0000000000000000000000000000000000000000 Binary files a/python/baby/models/ct_rf_20210201_12.pkl and /dev/null differ diff --git a/python/baby/models/ct_rf_20210201_4.pkl b/python/baby/models/ct_rf_20210201_4.pkl deleted file mode 100644 index cc5cd715a297ed113be2d8503c96e9d4d91a67f1..0000000000000000000000000000000000000000 Binary files a/python/baby/models/ct_rf_20210201_4.pkl and /dev/null differ diff --git a/python/baby/models/flattener_60x_1z_evolve_20210302.json b/python/baby/models/flattener_60x_1z_evolve_20210302.json deleted file mode 100644 index b1333bbc17f6b11554750809dda4bf8a9cf0fe5c..0000000000000000000000000000000000000000 --- a/python/baby/models/flattener_60x_1z_evolve_20210302.json +++ /dev/null @@ -1 +0,0 @@ -{"groupdef": {"small": {"_python_NamedTuple": {"lower": 1, "upper": 134, "budonly": false, "focus": null}, "__module__": "baby.preprocessing", "__class__": "CellGroup"}, "medium": {"_python_NamedTuple": {"lower": 101, "upper": 177, "budonly": false, "focus": null}, "__module__": "baby.preprocessing", "__class__": "CellGroup"}, "large": {"_python_NamedTuple": {"lower": 143, "upper": Infinity, "budonly": false, "focus": null}, "__module__": "baby.preprocessing", "__class__": "CellGroup"}, "buds": {"_python_NamedTuple": {"lower": 1, "upper": 200, "budonly": true, "focus": null}, "__module__": "baby.preprocessing", "__class__": "CellGroup"}}, "groupprops": {"small": {"_python_set": ["filled", "edge"]}, "medium": {"_python_set": ["edge", "overlap", "interior"]}, "large": {"_python_set": ["edge", "overlap", "interior"]}, "buds": {"_python_set": ["budneck"]}}, "targets": [{"_python_NamedTuple": {"name": "lge_inte", "group": "large", "prop": "interior", "nerode": 4, "ndilate": 0, "ndilate_overlaps": 0, "ndilate_mother": 2}, "__module__": "baby.preprocessing", "__class__": "PredTarget"}, {"_python_NamedTuple": {"name": "lge_edge", "group": "large", "prop": "edge", "nerode": 0, "ndilate": 0, "ndilate_overlaps": 0, "ndilate_mother": 2}, "__module__": "baby.preprocessing", "__class__": "PredTarget"}, {"_python_NamedTuple": {"name": "mid_inte", "group": "medium", "prop": "interior", "nerode": 3, "ndilate": 0, "ndilate_overlaps": 0, "ndilate_mother": 2}, "__module__": "baby.preprocessing", "__class__": "PredTarget"}, {"_python_NamedTuple": {"name": "mid_edge", "group": "medium", "prop": "edge", "nerode": 0, "ndilate": 0, "ndilate_overlaps": 0, "ndilate_mother": 2}, "__module__": "baby.preprocessing", "__class__": "PredTarget"}, {"_python_NamedTuple": {"name": "sml_inte", "group": "small", "prop": "filled", "nerode": 0, "ndilate": 0, "ndilate_overlaps": 0, "ndilate_mother": 2}, "__module__": "baby.preprocessing", "__class__": "PredTarget"}, {"_python_NamedTuple": {"name": "sml_edge", "group": "small", "prop": "edge", "nerode": 0, "ndilate": 0, "ndilate_overlaps": 0, "ndilate_mother": 2}, "__module__": "baby.preprocessing", "__class__": "PredTarget"}, {"_python_NamedTuple": {"name": "bud_neck", "group": "buds", "prop": "budneck", "nerode": 0, "ndilate": 0, "ndilate_overlaps": 0, "ndilate_mother": 2}, "__module__": "baby.preprocessing", "__class__": "PredTarget"}]} \ No newline at end of file diff --git a/python/baby/models/flattener_60x_1z_prime_20210302.json b/python/baby/models/flattener_60x_1z_prime_20210302.json deleted file mode 100644 index 4bd3545c288f66539370e1db9da2bd69a61291bf..0000000000000000000000000000000000000000 --- a/python/baby/models/flattener_60x_1z_prime_20210302.json +++ /dev/null @@ -1 +0,0 @@ -{"groupdef": {"small": {"_python_NamedTuple": {"lower": 1, "upper": 258, "budonly": false, "focus": null}, "__module__": "baby.preprocessing", "__class__": "CellGroup"}, "medium": {"_python_NamedTuple": {"lower": 192, "upper": 352, "budonly": false, "focus": null}, "__module__": "baby.preprocessing", "__class__": "CellGroup"}, "large": {"_python_NamedTuple": {"lower": 286, "upper": Infinity, "budonly": false, "focus": null}, "__module__": "baby.preprocessing", "__class__": "CellGroup"}, "buds": {"_python_NamedTuple": {"lower": 1, "upper": 300, "budonly": true, "focus": null}, "__module__": "baby.preprocessing", "__class__": "CellGroup"}}, "groupprops": {"small": {"_python_set": ["filled", "edge"]}, "medium": {"_python_set": ["overlap", "edge", "interior"]}, "large": {"_python_set": ["overlap", "edge", "interior"]}, "buds": {"_python_set": ["budneck"]}}, "targets": [{"_python_NamedTuple": {"name": "lge_inte", "group": "large", "prop": "interior", "nerode": 5, "ndilate": 0, "ndilate_overlaps": 0, "ndilate_mother": 2}, "__module__": "baby.preprocessing", "__class__": "PredTarget"}, {"_python_NamedTuple": {"name": "lge_edge", "group": "large", "prop": "edge", "nerode": 0, "ndilate": 0, "ndilate_overlaps": 0, "ndilate_mother": 2}, "__module__": "baby.preprocessing", "__class__": "PredTarget"}, {"_python_NamedTuple": {"name": "mid_inte", "group": "medium", "prop": "interior", "nerode": 5, "ndilate": 0, "ndilate_overlaps": 0, "ndilate_mother": 2}, "__module__": "baby.preprocessing", "__class__": "PredTarget"}, {"_python_NamedTuple": {"name": "mid_edge", "group": "medium", "prop": "edge", "nerode": 0, "ndilate": 0, "ndilate_overlaps": 0, "ndilate_mother": 2}, "__module__": "baby.preprocessing", "__class__": "PredTarget"}, {"_python_NamedTuple": {"name": "sml_inte", "group": "small", "prop": "filled", "nerode": 0, "ndilate": 0, "ndilate_overlaps": 0, "ndilate_mother": 2}, "__module__": "baby.preprocessing", "__class__": "PredTarget"}, {"_python_NamedTuple": {"name": "sml_edge", "group": "small", "prop": "edge", "nerode": 0, "ndilate": 0, "ndilate_overlaps": 0, "ndilate_mother": 2}, "__module__": "baby.preprocessing", "__class__": "PredTarget"}, {"_python_NamedTuple": {"name": "bud_neck", "group": "buds", "prop": "budneck", "nerode": 0, "ndilate": 0, "ndilate_overlaps": 0, "ndilate_mother": 2}, "__module__": "baby.preprocessing", "__class__": "PredTarget"}]} \ No newline at end of file diff --git a/python/baby/models/flattener_60x_3z_evolve_20210215.json b/python/baby/models/flattener_60x_3z_evolve_20210215.json deleted file mode 100644 index 1f75baa878fad6c46cdfc539b888972fa9fdcc82..0000000000000000000000000000000000000000 --- a/python/baby/models/flattener_60x_3z_evolve_20210215.json +++ /dev/null @@ -1 +0,0 @@ -{"groupdef": {"small": {"_python_NamedTuple": {"lower": 1, "upper": 133, "budonly": false, "focus": null}, "__module__": "baby.preprocessing", "__class__": "CellGroup"}, "medium": {"_python_NamedTuple": {"lower": 100, "upper": 175, "budonly": false, "focus": null}, "__module__": "baby.preprocessing", "__class__": "CellGroup"}, "large": {"_python_NamedTuple": {"lower": 142, "upper": Infinity, "budonly": false, "focus": null}, "__module__": "baby.preprocessing", "__class__": "CellGroup"}, "buds": {"_python_NamedTuple": {"lower": 1, "upper": 200, "budonly": true, "focus": null}, "__module__": "baby.preprocessing", "__class__": "CellGroup"}}, "groupprops": {"small": {"_python_set": ["filled", "edge"]}, "medium": {"_python_set": ["overlap", "edge", "interior"]}, "large": {"_python_set": ["overlap", "edge", "interior"]}, "buds": {"_python_set": ["budneck"]}}, "targets": [{"_python_NamedTuple": {"name": "lge_inte", "group": "large", "prop": "interior", "nerode": 4, "ndilate": 0, "ndilate_overlaps": 0, "ndilate_mother": 2}, "__module__": "baby.preprocessing", "__class__": "PredTarget"}, {"_python_NamedTuple": {"name": "lge_edge", "group": "large", "prop": "edge", "nerode": 0, "ndilate": 0, "ndilate_overlaps": 0, "ndilate_mother": 2}, "__module__": "baby.preprocessing", "__class__": "PredTarget"}, {"_python_NamedTuple": {"name": "mid_inte", "group": "medium", "prop": "interior", "nerode": 3, "ndilate": 0, "ndilate_overlaps": 0, "ndilate_mother": 2}, "__module__": "baby.preprocessing", "__class__": "PredTarget"}, {"_python_NamedTuple": {"name": "mid_edge", "group": "medium", "prop": "edge", "nerode": 0, "ndilate": 0, "ndilate_overlaps": 0, "ndilate_mother": 2}, "__module__": "baby.preprocessing", "__class__": "PredTarget"}, {"_python_NamedTuple": {"name": "sml_inte", "group": "small", "prop": "filled", "nerode": 0, "ndilate": 0, "ndilate_overlaps": 0, "ndilate_mother": 2}, "__module__": "baby.preprocessing", "__class__": "PredTarget"}, {"_python_NamedTuple": {"name": "sml_edge", "group": "small", "prop": "edge", "nerode": 0, "ndilate": 0, "ndilate_overlaps": 0, "ndilate_mother": 2}, "__module__": "baby.preprocessing", "__class__": "PredTarget"}, {"_python_NamedTuple": {"name": "bud_neck", "group": "buds", "prop": "budneck", "nerode": 0, "ndilate": 0, "ndilate_overlaps": 0, "ndilate_mother": 2}, "__module__": "baby.preprocessing", "__class__": "PredTarget"}]} \ No newline at end of file diff --git a/python/baby/models/flattener_60x_3z_prime_20210215.json b/python/baby/models/flattener_60x_3z_prime_20210215.json deleted file mode 100644 index a35852ef72d755894d4397f0337373f8bf2b567b..0000000000000000000000000000000000000000 --- a/python/baby/models/flattener_60x_3z_prime_20210215.json +++ /dev/null @@ -1 +0,0 @@ -{"groupdef": {"small": {"_python_NamedTuple": {"lower": 1, "upper": 248, "budonly": false, "focus": null}, "__module__": "baby.preprocessing", "__class__": "CellGroup"}, "medium": {"_python_NamedTuple": {"lower": 181, "upper": 319, "budonly": false, "focus": null}, "__module__": "baby.preprocessing", "__class__": "CellGroup"}, "large": {"_python_NamedTuple": {"lower": 252, "upper": Infinity, "budonly": false, "focus": null}, "__module__": "baby.preprocessing", "__class__": "CellGroup"}, "buds": {"_python_NamedTuple": {"lower": 1, "upper": 300, "budonly": true, "focus": null}, "__module__": "baby.preprocessing", "__class__": "CellGroup"}}, "groupprops": {"small": {"_python_set": ["filled", "edge"]}, "medium": {"_python_set": ["overlap", "edge", "interior"]}, "large": {"_python_set": ["overlap", "edge", "interior"]}, "buds": {"_python_set": ["budneck"]}}, "targets": [{"_python_NamedTuple": {"name": "lge_inte", "group": "large", "prop": "interior", "nerode": 5, "ndilate": 0, "ndilate_overlaps": 0, "ndilate_mother": 2}, "__module__": "baby.preprocessing", "__class__": "PredTarget"}, {"_python_NamedTuple": {"name": "lge_edge", "group": "large", "prop": "edge", "nerode": 0, "ndilate": 0, "ndilate_overlaps": 0, "ndilate_mother": 2}, "__module__": "baby.preprocessing", "__class__": "PredTarget"}, {"_python_NamedTuple": {"name": "mid_inte", "group": "medium", "prop": "interior", "nerode": 3, "ndilate": 0, "ndilate_overlaps": 0, "ndilate_mother": 2}, "__module__": "baby.preprocessing", "__class__": "PredTarget"}, {"_python_NamedTuple": {"name": "mid_edge", "group": "medium", "prop": "edge", "nerode": 0, "ndilate": 0, "ndilate_overlaps": 0, "ndilate_mother": 2}, "__module__": "baby.preprocessing", "__class__": "PredTarget"}, {"_python_NamedTuple": {"name": "sml_inte", "group": "small", "prop": "filled", "nerode": 0, "ndilate": 0, "ndilate_overlaps": 0, "ndilate_mother": 2}, "__module__": "baby.preprocessing", "__class__": "PredTarget"}, {"_python_NamedTuple": {"name": "sml_edge", "group": "small", "prop": "edge", "nerode": 0, "ndilate": 0, "ndilate_overlaps": 0, "ndilate_mother": 2}, "__module__": "baby.preprocessing", "__class__": "PredTarget"}, {"_python_NamedTuple": {"name": "bud_neck", "group": "buds", "prop": "budneck", "nerode": 0, "ndilate": 0, "ndilate_overlaps": 0, "ndilate_mother": 2}, "__module__": "baby.preprocessing", "__class__": "PredTarget"}]} \ No newline at end of file diff --git a/python/baby/models/flattener_60x_evolve_20200630.json b/python/baby/models/flattener_60x_evolve_20200630.json deleted file mode 100644 index bf693db44d61dac6be5ea26a3932cddf5de6ce9e..0000000000000000000000000000000000000000 --- a/python/baby/models/flattener_60x_evolve_20200630.json +++ /dev/null @@ -1 +0,0 @@ -{"groupdef": {"small": {"_python_NamedTuple": {"lower": 1, "upper": 104, "budonly": false, "focus": null}, "__module__": "baby.preprocessing", "__class__": "CellGroup"}, "medium": {"_python_NamedTuple": {"lower": 76, "upper": 153, "budonly": false, "focus": null}, "__module__": "baby.preprocessing", "__class__": "CellGroup"}, "large": {"_python_NamedTuple": {"lower": 124, "upper": Infinity, "budonly": false, "focus": null}, "__module__": "baby.preprocessing", "__class__": "CellGroup"}, "buds": {"_python_NamedTuple": {"lower": 1, "upper": 200, "budonly": true, "focus": null}, "__module__": "baby.preprocessing", "__class__": "CellGroup"}}, "groupprops": {"small": {"_python_set": ["filled", "edge"]}, "medium": {"_python_set": ["interior", "overlap", "edge"]}, "large": {"_python_set": ["interior", "overlap", "edge"]}, "buds": {"_python_set": ["budneck"]}}, "targets": [{"_python_NamedTuple": {"name": "lge_inte", "group": "large", "prop": "interior", "nerode": 3, "ndilate": 0, "ndilate_overlaps": 0, "ndilate_mother": 2}, "__module__": "baby.preprocessing", "__class__": "PredTarget"}, {"_python_NamedTuple": {"name": "lge_edge", "group": "large", "prop": "edge", "nerode": 0, "ndilate": 0, "ndilate_overlaps": 0, "ndilate_mother": 2}, "__module__": "baby.preprocessing", "__class__": "PredTarget"}, {"_python_NamedTuple": {"name": "mid_inte", "group": "medium", "prop": "interior", "nerode": 2, "ndilate": 0, "ndilate_overlaps": 0, "ndilate_mother": 2}, "__module__": "baby.preprocessing", "__class__": "PredTarget"}, {"_python_NamedTuple": {"name": "mid_edge", "group": "medium", "prop": "edge", "nerode": 0, "ndilate": 0, "ndilate_overlaps": 0, "ndilate_mother": 2}, "__module__": "baby.preprocessing", "__class__": "PredTarget"}, {"_python_NamedTuple": {"name": "sml_inte", "group": "small", "prop": "filled", "nerode": 0, "ndilate": 0, "ndilate_overlaps": 0, "ndilate_mother": 2}, "__module__": "baby.preprocessing", "__class__": "PredTarget"}, {"_python_NamedTuple": {"name": "sml_edge", "group": "small", "prop": "edge", "nerode": 0, "ndilate": 0, "ndilate_overlaps": 0, "ndilate_mother": 2}, "__module__": "baby.preprocessing", "__class__": "PredTarget"}, {"_python_NamedTuple": {"name": "bud_neck", "group": "buds", "prop": "budneck", "nerode": 0, "ndilate": 0, "ndilate_overlaps": 0, "ndilate_mother": 2}, "__module__": "baby.preprocessing", "__class__": "PredTarget"}]} \ No newline at end of file diff --git a/python/baby/models/flattener_60x_evolve_20210120.json b/python/baby/models/flattener_60x_evolve_20210120.json deleted file mode 100644 index 7c06d8316c0e5421083adbfe96437e6d7402e0a2..0000000000000000000000000000000000000000 --- a/python/baby/models/flattener_60x_evolve_20210120.json +++ /dev/null @@ -1 +0,0 @@ -{"groupdef": {"small": {"_python_NamedTuple": {"lower": 1, "upper": 133, "budonly": false, "focus": null}, "__module__": "baby.preprocessing", "__class__": "CellGroup"}, "medium": {"_python_NamedTuple": {"lower": 100, "upper": 176, "budonly": false, "focus": null}, "__module__": "baby.preprocessing", "__class__": "CellGroup"}, "large": {"_python_NamedTuple": {"lower": 143, "upper": Infinity, "budonly": false, "focus": null}, "__module__": "baby.preprocessing", "__class__": "CellGroup"}, "buds": {"_python_NamedTuple": {"lower": 1, "upper": 200, "budonly": true, "focus": null}, "__module__": "baby.preprocessing", "__class__": "CellGroup"}}, "groupprops": {"small": {"_python_set": ["edge", "filled"]}, "medium": {"_python_set": ["overlap", "edge", "interior"]}, "large": {"_python_set": ["overlap", "edge", "interior"]}, "buds": {"_python_set": ["budneck"]}}, "targets": [{"_python_NamedTuple": {"name": "lge_inte", "group": "large", "prop": "interior", "nerode": 4, "ndilate": 0, "ndilate_overlaps": 0, "ndilate_mother": 2}, "__module__": "baby.preprocessing", "__class__": "PredTarget"}, {"_python_NamedTuple": {"name": "lge_edge", "group": "large", "prop": "edge", "nerode": 0, "ndilate": 0, "ndilate_overlaps": 0, "ndilate_mother": 2}, "__module__": "baby.preprocessing", "__class__": "PredTarget"}, {"_python_NamedTuple": {"name": "mid_inte", "group": "medium", "prop": "interior", "nerode": 3, "ndilate": 0, "ndilate_overlaps": 0, "ndilate_mother": 2}, "__module__": "baby.preprocessing", "__class__": "PredTarget"}, {"_python_NamedTuple": {"name": "mid_edge", "group": "medium", "prop": "edge", "nerode": 0, "ndilate": 0, "ndilate_overlaps": 0, "ndilate_mother": 2}, "__module__": "baby.preprocessing", "__class__": "PredTarget"}, {"_python_NamedTuple": {"name": "sml_inte", "group": "small", "prop": "filled", "nerode": 0, "ndilate": 0, "ndilate_overlaps": 0, "ndilate_mother": 2}, "__module__": "baby.preprocessing", "__class__": "PredTarget"}, {"_python_NamedTuple": {"name": "sml_edge", "group": "small", "prop": "edge", "nerode": 0, "ndilate": 0, "ndilate_overlaps": 0, "ndilate_mother": 2}, "__module__": "baby.preprocessing", "__class__": "PredTarget"}, {"_python_NamedTuple": {"name": "bud_neck", "group": "buds", "prop": "budneck", "nerode": 0, "ndilate": 0, "ndilate_overlaps": 0, "ndilate_mother": 2}, "__module__": "baby.preprocessing", "__class__": "PredTarget"}]} \ No newline at end of file diff --git a/python/baby/models/flattener_60x_prime_20200630.json b/python/baby/models/flattener_60x_prime_20200630.json deleted file mode 100644 index d156a0d3cd015ca234b6853130ab75a2bea53ae6..0000000000000000000000000000000000000000 --- a/python/baby/models/flattener_60x_prime_20200630.json +++ /dev/null @@ -1 +0,0 @@ -{"groupdef": {"small": {"_python_NamedTuple": {"lower": 1, "upper": 285, "budonly": false, "focus": null}, "__module__": "baby.preprocessing", "__class__": "CellGroup"}, "medium": {"_python_NamedTuple": {"lower": 227, "upper": 340, "budonly": false, "focus": null}, "__module__": "baby.preprocessing", "__class__": "CellGroup"}, "large": {"_python_NamedTuple": {"lower": 282, "upper": Infinity, "budonly": false, "focus": null}, "__module__": "baby.preprocessing", "__class__": "CellGroup"}, "buds": {"_python_NamedTuple": {"lower": 1, "upper": 300, "budonly": true, "focus": null}, "__module__": "baby.preprocessing", "__class__": "CellGroup"}}, "groupprops": {"small": {"_python_set": ["filled", "edge"]}, "medium": {"_python_set": ["interior", "overlap", "edge"]}, "large": {"_python_set": ["interior", "overlap", "edge"]}, "buds": {"_python_set": ["budneck"]}}, "targets": [{"_python_NamedTuple": {"name": "lge_inte", "group": "large", "prop": "interior", "nerode": 5, "ndilate": 0, "ndilate_overlaps": 0, "ndilate_mother": 2}, "__module__": "baby.preprocessing", "__class__": "PredTarget"}, {"_python_NamedTuple": {"name": "lge_edge", "group": "large", "prop": "edge", "nerode": 0, "ndilate": 0, "ndilate_overlaps": 0, "ndilate_mother": 2}, "__module__": "baby.preprocessing", "__class__": "PredTarget"}, {"_python_NamedTuple": {"name": "mid_inte", "group": "medium", "prop": "interior", "nerode": 5, "ndilate": 0, "ndilate_overlaps": 0, "ndilate_mother": 2}, "__module__": "baby.preprocessing", "__class__": "PredTarget"}, {"_python_NamedTuple": {"name": "mid_edge", "group": "medium", "prop": "edge", "nerode": 0, "ndilate": 0, "ndilate_overlaps": 0, "ndilate_mother": 2}, "__module__": "baby.preprocessing", "__class__": "PredTarget"}, {"_python_NamedTuple": {"name": "sml_inte", "group": "small", "prop": "filled", "nerode": 0, "ndilate": 0, "ndilate_overlaps": 0, "ndilate_mother": 2}, "__module__": "baby.preprocessing", "__class__": "PredTarget"}, {"_python_NamedTuple": {"name": "sml_edge", "group": "small", "prop": "edge", "nerode": 0, "ndilate": 0, "ndilate_overlaps": 0, "ndilate_mother": 2}, "__module__": "baby.preprocessing", "__class__": "PredTarget"}, {"_python_NamedTuple": {"name": "bud_neck", "group": "buds", "prop": "budneck", "nerode": 0, "ndilate": 0, "ndilate_overlaps": 0, "ndilate_mother": 2}, "__module__": "baby.preprocessing", "__class__": "PredTarget"}]} \ No newline at end of file diff --git a/python/baby/models/flattener_60x_prime_20210120.json b/python/baby/models/flattener_60x_prime_20210120.json deleted file mode 100644 index 03cb6e0dd99aaf62459221cb19e68cadc70c140b..0000000000000000000000000000000000000000 --- a/python/baby/models/flattener_60x_prime_20210120.json +++ /dev/null @@ -1 +0,0 @@ -{"groupdef": {"small": {"_python_NamedTuple": {"lower": 1, "upper": 255, "budonly": false, "focus": null}, "__module__": "baby.preprocessing", "__class__": "CellGroup"}, "medium": {"_python_NamedTuple": {"lower": 189, "upper": 319, "budonly": false, "focus": null}, "__module__": "baby.preprocessing", "__class__": "CellGroup"}, "large": {"_python_NamedTuple": {"lower": 252, "upper": Infinity, "budonly": false, "focus": null}, "__module__": "baby.preprocessing", "__class__": "CellGroup"}, "buds": {"_python_NamedTuple": {"lower": 1, "upper": 300, "budonly": true, "focus": null}, "__module__": "baby.preprocessing", "__class__": "CellGroup"}}, "groupprops": {"small": {"_python_set": ["edge", "filled"]}, "medium": {"_python_set": ["edge", "overlap", "interior"]}, "large": {"_python_set": ["edge", "overlap", "interior"]}, "buds": {"_python_set": ["budneck"]}}, "targets": [{"_python_NamedTuple": {"name": "lge_inte", "group": "large", "prop": "interior", "nerode": 5, "ndilate": 0, "ndilate_overlaps": 0, "ndilate_mother": 2}, "__module__": "baby.preprocessing", "__class__": "PredTarget"}, {"_python_NamedTuple": {"name": "lge_edge", "group": "large", "prop": "edge", "nerode": 0, "ndilate": 0, "ndilate_overlaps": 0, "ndilate_mother": 2}, "__module__": "baby.preprocessing", "__class__": "PredTarget"}, {"_python_NamedTuple": {"name": "mid_inte", "group": "medium", "prop": "interior", "nerode": 4, "ndilate": 0, "ndilate_overlaps": 0, "ndilate_mother": 2}, "__module__": "baby.preprocessing", "__class__": "PredTarget"}, {"_python_NamedTuple": {"name": "mid_edge", "group": "medium", "prop": "edge", "nerode": 0, "ndilate": 0, "ndilate_overlaps": 0, "ndilate_mother": 2}, "__module__": "baby.preprocessing", "__class__": "PredTarget"}, {"_python_NamedTuple": {"name": "sml_inte", "group": "small", "prop": "filled", "nerode": 0, "ndilate": 0, "ndilate_overlaps": 0, "ndilate_mother": 2}, "__module__": "baby.preprocessing", "__class__": "PredTarget"}, {"_python_NamedTuple": {"name": "sml_edge", "group": "small", "prop": "edge", "nerode": 0, "ndilate": 0, "ndilate_overlaps": 0, "ndilate_mother": 2}, "__module__": "baby.preprocessing", "__class__": "PredTarget"}, {"_python_NamedTuple": {"name": "bud_neck", "group": "buds", "prop": "budneck", "nerode": 0, "ndilate": 0, "ndilate_overlaps": 0, "ndilate_mother": 2}, "__module__": "baby.preprocessing", "__class__": "PredTarget"}]} \ No newline at end of file diff --git a/python/baby/models/mb_model_20201022.pkl b/python/baby/models/mb_model_20201022.pkl deleted file mode 100644 index e25c52bdf6a7debca807a6de540cf36e673eb0e6..0000000000000000000000000000000000000000 Binary files a/python/baby/models/mb_model_20201022.pkl and /dev/null differ diff --git a/python/baby/models/mb_model_60x_1z_evolve_210302.pkl b/python/baby/models/mb_model_60x_1z_evolve_210302.pkl deleted file mode 100644 index bad010d4ea03f38ecc1d90828f2c667792459936..0000000000000000000000000000000000000000 Binary files a/python/baby/models/mb_model_60x_1z_evolve_210302.pkl and /dev/null differ diff --git a/python/baby/models/mb_model_60x_1z_prime_210302.pkl b/python/baby/models/mb_model_60x_1z_prime_210302.pkl deleted file mode 100644 index 32d399892f0bfe40d0a1d9426a3ca08e607b32ea..0000000000000000000000000000000000000000 Binary files a/python/baby/models/mb_model_60x_1z_prime_210302.pkl and /dev/null differ diff --git a/python/baby/models/mb_model_60x_3z_evolve_20210215.pkl b/python/baby/models/mb_model_60x_3z_evolve_20210215.pkl deleted file mode 100644 index fb2b349dc580a4ef2062ce0d4b2fab5e8abe69b2..0000000000000000000000000000000000000000 Binary files a/python/baby/models/mb_model_60x_3z_evolve_20210215.pkl and /dev/null differ diff --git a/python/baby/models/mb_model_60x_3z_prime_20210215.pkl b/python/baby/models/mb_model_60x_3z_prime_20210215.pkl deleted file mode 100644 index 97a6fd1ca0a3ebe0a9093cff56e3537e1f1c4bcb..0000000000000000000000000000000000000000 Binary files a/python/baby/models/mb_model_60x_3z_prime_20210215.pkl and /dev/null differ diff --git a/python/baby/models/mb_model_60x_evolve_20210120.pkl b/python/baby/models/mb_model_60x_evolve_20210120.pkl deleted file mode 100644 index 3b86ce1254f8c1205f39004c816d80ff35f7461c..0000000000000000000000000000000000000000 Binary files a/python/baby/models/mb_model_60x_evolve_20210120.pkl and /dev/null differ diff --git a/python/baby/models/mb_model_60x_prime_20210120.pkl b/python/baby/models/mb_model_60x_prime_20210120.pkl deleted file mode 100644 index 8c1fb44059d2b71e135d51014320ca0799959fb6..0000000000000000000000000000000000000000 Binary files a/python/baby/models/mb_model_60x_prime_20210120.pkl and /dev/null differ diff --git a/python/baby/modelsets.json b/python/baby/modelsets.json deleted file mode 100644 index fcdaf59a60a683ec5a82e558b6c82311cd44ce22..0000000000000000000000000000000000000000 --- a/python/baby/modelsets.json +++ /dev/null @@ -1,103 +0,0 @@ -{ - "prime95b_brightfield_60x_5z": { - "morph_model_file": "I5_prime_unet_4s_20210120.hdf5", - "flattener_file": "flattener_60x_prime_20210120.json", - "celltrack_model_file": "ct_rf_20210201_12.pkl", - "budassign_model_file": "mb_model_60x_prime_20210120.pkl", - "pixel_size": 0.182, - "default_image_size": [117, 117], - "params": { - "interior_threshold": [0.95, 0.5, 0.35], - "nclosing": [0, 0, 0], "nopening": [2, 0, 0], - "connectivity": [1, 1, 2], - "edge_sub_dilations": [0, 0, 0], - "containment_thresh": 0.85, "min_area": 19, - "pedge_thresh": [0.0, 0.0028, 0.0012], - "group_thresh_expansion": [0.28, 0.06, 0.32] - } - }, - "evolve_brightfield_60x_5z": { - "morph_model_file": "I5_evolve_unet_4s_20210120.hdf5", - "flattener_file": "flattener_60x_evolve_20210120.json", - "celltrack_model_file": "ct_rf_20210201_12.pkl", - "budassign_model_file": "mb_model_60x_evolve_20210120.pkl", - "pixel_size": 0.263, - "default_image_size": [81, 81], - "params": { - "interior_threshold": [0.9, 0.65, 0.8], - "nclosing": [0, 1, 0], "nopening": [1, 2, 0], - "connectivity": [1, 1, 1], - "edge_sub_dilations": [0, 2, 0], - "containment_thresh": 0.85, "min_area": 9, - "pedge_thresh": [0.0, 0.0, 0.0014], - "group_thresh_expansion": [0.14, 0.06, 0.24] - }, - "isbud_thresh": 0.35 - }, - "prime95b_brightfield_60x_1z": { - "morph_model_file": "I1_prime_unet_4s_20210302.hdf5", - "flattener_file": "flattener_60x_1z_prime_20210302.json", - "celltrack_model_file": "ct_rf_20210201_12.pkl", - "budassign_model_file": "mb_model_60x_1z_prime_210302.pkl", - "pixel_size": 0.182, - "default_image_size": [117, 117], - "params": { - "interior_threshold": [0.5, 0.8, 0.4], - "nclosing": [0, 0, 0], "nopening": [1, 2, 0], - "connectivity": [1, 1, 1], - "edge_sub_dilations": [0, 1, 0], - "containment_thresh": 0.8, "min_area": 19, - "pedge_thresh": [0.0, 0.0, 0.0002], - "group_thresh_expansion": [0.22, 0.4, 0.24] - } - }, - "evolve_brightfield_60x_1z": { - "morph_model_file": "I1_evolve_unet_4s_20210302.hdf5", - "flattener_file": "flattener_60x_1z_evolve_20210302.json", - "celltrack_model_file": "ct_rf_20210201_12.pkl", - "budassign_model_file": "mb_model_60x_1z_evolve_210302.pkl", - "pixel_size": 0.263, - "default_image_size": [81, 81], - "params": { - "interior_threshold": [0.35, 0.8, 0.35], - "nclosing": [0, 0, 0], "nopening": [0, 0, 1], - "connectivity": [1, 1, 1], "edge_sub_dilations": [0, 1, 0], - "containment_thresh": 0.7, "min_area": 14, - "pedge_thresh": [0.0, 0.0, 0.0009], - "group_thresh_expansion": [0.16, 0.4, 0.0] - } - }, - "prime95b_brightfield_60x_3z": { - "morph_model_file": "I3_prime_unet_4s_20210215.hdf5", - "flattener_file": "flattener_60x_3z_prime_20210215.json", - "celltrack_model_file": "ct_rf_20210201_12.pkl", - "budassign_model_file": "mb_model_60x_3z_prime_20210215.pkl", - "pixel_size": 0.182, - "default_image_size": [117, 117], - "params": { - "interior_threshold": [0.65, 0.9, 0.3], - "nclosing": [0, 0, 0], "nopening": [1, 2, 1], - "connectivity": [1, 1, 1], - "edge_sub_dilations": [0, 0, 0], - "containment_thresh": 0.6, "min_area": 19, - "pedge_thresh": [0.0, 0.0, 0.0012], - "group_thresh_expansion": [0.12, 0.4, 0.3] - } - }, - "evolve_brightfield_60x_3z": { - "morph_model_file": "I3_evolve_unet_4s_20210215.hdf5", - "flattener_file": "flattener_60x_3z_evolve_20210215.json", - "celltrack_model_file": "ct_rf_20210201_12.pkl", - "budassign_model_file": "mb_model_60x_3z_evolve_20210215.pkl", - "pixel_size": 0.263, - "default_image_size": [81, 81], - "params": { - "interior_threshold": [0.75, 0.4, 0.45], - "nclosing": [0, 0, 0], "nopening": [0, 2, 0], - "connectivity": [1, 1, 1], "edge_sub_dilations": [0, 0, 0], - "containment_thresh": 0.9, "min_area": 9, - "pedge_thresh": [0.0, 0.0015, 0.0], - "group_thresh_expansion": [0.28, 0.38, 0.22] - } - } -} diff --git a/python/baby/modelsets.py b/python/baby/modelsets.py new file mode 100644 index 0000000000000000000000000000000000000000..645112e363c3e51c3ab509ae64ee43cd971c651d --- /dev/null +++ b/python/baby/modelsets.py @@ -0,0 +1,404 @@ +# If you publish results that make use of this software or the Birth Annotator +# for Budding Yeast algorithm, please cite: +# Pietsch, J.M.J., Muñoz, A.F., Adjavon, D.-Y.A., Farquhar, I., Clark, I.B.N., +# and Swain, P.S. (2023). Determining growth rates from bright-field images of +# budding cells through identifying overlaps. eLife. 12:e79812. +# https://doi.org/10.7554/eLife.79812 +# +# +# The MIT License (MIT) +# +# Copyright (c) Julian Pietsch, Alán Muñoz and Diane Adjavon 2023 +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to +# deal in the Software without restriction, including without limitation the +# rights to use, copy, modify, merge, publish, distribute, sublicense, and/or +# sell copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING +# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS +# IN THE SOFTWARE. +# SOFTWARE. +import os +from pathlib import Path +import requests +from urllib.parse import urljoin, quote +import json + +from .errors import BadParam, BadModel, BadFile +from .utils import jsonify, as_python_object +from .morph_thresh_seg import SegmentationParameters + +MODELSETS_FILENAME = 'modelsets.json' +MODELSET_FILENAME = 'modelset.json' + +BABY_MODELS_URL = 'https://julianpietsch.github.io/baby-models/' +MODELSETS_URL = urljoin(BABY_MODELS_URL, MODELSETS_FILENAME) + +ENV_VAR_MODELSETS_PATH = 'BABY_MODELSETS_PATH' +ENV_LOCAL_MODELSETS_PATH = os.environ.get(ENV_VAR_MODELSETS_PATH) +if (type(ENV_LOCAL_MODELSETS_PATH) == str and + ENV_LOCAL_MODELSETS_PATH.startswith('~')): + ENV_LOCAL_MODELSETS_PATH = Path(ENV_LOCAL_MODELSETS_PATH).expanduser() +DEFAULT_LOCAL_MODELSETS_PATH = Path.home() / '.baby_models' +LOCAL_MODELSETS_PATH = ENV_LOCAL_MODELSETS_PATH or DEFAULT_LOCAL_MODELSETS_PATH +LOCAL_MODELSETS_PATH = Path(LOCAL_MODELSETS_PATH) +LOCAL_MODELSETS_CACHE = LOCAL_MODELSETS_PATH / MODELSETS_FILENAME + +SHARE_PATH = 'shared' + + +def remote_modelsets(): + '''Retrieve info on the model sets available for download + + Returns: + Nested dict hierarchy decoded from remote JSON file. The root of the + hierarchy has keys 'models' and 'shared'. + + The value of the 'models' key is a dict whose keys are the folder + names for the model sets and whose values are dicts giving details for + each model set. Each model set dict has the following items: + - a 'name' key with value being a str naming the model, + - a 'files' key with value being a list of str specifying all file + names in the model set folder, + - a 'meta' key with value being a dict of meta data on the model, + - a 'brain_params' key with value being a dict of parameters to + instantiate a :py:class:`brain.BabyBrain` object. + + The 'shared' key at the root of the hierarchy points to a list of str + giving file names in the 'shared' folder. + + ''' + msets_text = requests.get(MODELSETS_URL).text + return json.loads(msets_text, object_hook=as_python_object) + + +def _ensure_local_path(): + '''Creates the directory to store model sets if it does not exist''' + if not LOCAL_MODELSETS_PATH.is_dir(): + LOCAL_MODELSETS_PATH.mkdir(parents=True) + local_share_path = LOCAL_MODELSETS_PATH / SHARE_PATH + if not local_share_path.is_dir(): + local_share_path.mkdir(parents=True) + + +def update_local_cache(): + '''Update the locally cached info on model sets available for download''' + _ensure_local_path() + with open(LOCAL_MODELSETS_CACHE, 'wt') as f: + json.dump(jsonify(remote_modelsets()), f) + + +def _ensure_local_cache(update=False): + if not LOCAL_MODELSETS_CACHE.exists() or update: + update_local_cache() + + +def specifications(update=False, local=False): + '''Get full specifications for all model sets + + By default returns specifications for all models, including models that + have not yet been downloaded. + + Args: + update: whether or not to update the locally-cached list of models + local: whether to include only locally-available models. + + Returns: + a `dict` mapping model set IDs to dicts specifying details for each + model set. Each model set dict has at least the following items: + - a 'name' key with value being a str naming the model, + - a 'meta' key with value being a dict of meta data on the model, + - a 'brain_params' key with value being a dict of parameters to + instantiate a :py:class:`brain.BabyBrain` object. + ''' + _ensure_local_cache(update=update) + if local: + return local_modelsets()['models'] + with open(LOCAL_MODELSETS_CACHE, 'rt') as f: + modelsets_info = json.load(f, object_hook=as_python_object) + return modelsets_info['models'] + + +def ids(update=False, local=False): + '''List the available model sets by ID + + Args: + update: whether or not to update the locally-cached list of models + local: whether to include only locally-available models. + + Returns: + a `list` of `str` giving the ID of each available model set + ''' + return list(specifications(update=update, local=local).keys()) + + +def meta(update=False, local=False): + '''Obtain descriptive meta information on each model set + + Args: + update: whether or not to update the locally-cached list of models + local: whether to include only locally-available models. + + Returns: + a `dict` mapping model set IDs to dicts of meta information associated + with each model set + ''' + return {msId: v['meta'] for msId, v in + specifications(update=update, local=local).items()} + + +def local_modelsets(): + _ensure_local_path() + modelsets = {} + for modelset_file in LOCAL_MODELSETS_PATH.glob('**/' + MODELSET_FILENAME): + modelset_id = modelset_file.parent.name + with open(modelset_file, 'rt') as f: + modelsets[modelset_id] = json.load(f, object_hook=as_python_object) + modelsets[modelset_id]['files'] = [ + p.name for p in modelset_file.parent.iterdir()] + + local_share_path = LOCAL_MODELSETS_PATH / SHARE_PATH + shared_files = [p.name for p in local_share_path.iterdir()] + + return {'models': modelsets, 'shared': shared_files} + + +def resolve(filename, modelset_id): + '''Resolve the path to a file specified by a model set + + File names are resolved by first assuming the `modelset_id` argument + specifies a directory containing model files (the model set path). If + `modelset_id` is not a directory, then it is assumed to be a model ID as + per :py:func:`ids`, and then the corresponding model set path in the local + cache will be searched first. + + If the file is not found in the model set path, it will be searched for in + the shared directory of the local cache. + + Args: + filename (str): file name to resolve, typically as specified by one of + the keys obtained from :py:func:`get_params`. + modelset_id (str): one of the model IDs as returned by :py:func:`ids`. + + Returns: + A `Path` object giving the path to the specified file. If the file + cannot be found a BadParam exception is raised. + ''' + if Path(modelset_id).is_dir(): + modelset_path = Path(modelset_id) + else: + modelset_path = LOCAL_MODELSETS_PATH / modelset_id + + trial_path = modelset_path / filename + if trial_path.is_file(): + return trial_path + + trial_path = LOCAL_MODELSETS_PATH / SHARE_PATH / filename + if trial_path.is_file(): + return trial_path + else: + raise BadParam(f'The file {filename} could not be resolved for model set {modelset_id}') + + +def _get_modelset_files(modelset): + params = modelset['brain_params'] + modelset_files = [v for k, v in params.items() if k.endswith('_file')] + if type(params['params']) not in {dict, SegmentationParameters}: + modelset_files.append(params['params']) + return modelset_files + + +def _missing_files(modelset, modelset_id): + modelset_files = _get_modelset_files(modelset) + missing = [] + for fname in modelset_files: + try: + resolve(fname, modelset_id) + except BadParam: + missing.append(fname) + return missing + + +def update(modelset_ids=None, force=True, cleanup=False, verbose=True): + '''Updates any outdated model sets that are available remotely + + Args: + modelset_ids: a list of str specifying one of the model IDs as + returned by :py:func:`ids`. By default, only updates model sets + that have already been downloaded locally. Set this to `'all'` to + download all available model sets. + force: whether to replace existing files or just obtain missing ones. + cleanup: whether to delete model set and shared files that are no + longer on the remote + verbose: whether to print download status to standard out. + ''' + remote_mset_info = remote_modelsets() + remote_msets = remote_mset_info['models'] + local_mset_info = local_modelsets() # ensures local path + local_msets = local_mset_info['models'] + + if modelset_ids is None: + remote_ids = set(remote_msets.keys()) + modelset_ids = list(remote_ids.intersection(local_msets.keys())) + elif modelset_ids == 'all': + modelset_ids = list(remote_msets.keys()) + + invalid_ids = [msId for msId in modelset_ids if msId not in remote_msets] + if any(invalid_ids): + invalid_ids = ', '.join([f'`{msId}`' for msId in invalid_ids]) + raise BadParam(f'Requested model set(s) {invalid_ids} not available') + + # Update model sets from remote host + for mset_id in modelset_ids: + mset_meta = remote_msets[mset_id] + local_mset_dir = LOCAL_MODELSETS_PATH / mset_id + if not local_mset_dir.exists(): + local_mset_dir.mkdir(parents=True) + + mset_missing = mset_id not in local_msets + mset_changed = mset_missing or local_msets[mset_id] != mset_meta + if mset_missing or (mset_changed and not force): + new_mset_meta = mset_meta.copy() + del new_mset_meta['files'] + with open(local_mset_dir / MODELSET_FILENAME, 'wt') as f: + json.dump(jsonify(new_mset_meta), f) + + if not force and not mset_missing: + # If there is already a local model set and we are not forcing an + # update, then we will only proceed with file download / cleanup + # if the local model has missing files. This allows for the case + # where a local version of a model has different files to those + # found on the server, and so avoids downloading potentially + # outdated extras. + nmissing = len(_missing_files(local_msets[mset_id], mset_id)) + if nmissing == 0: + continue + + remote_mset_files = set(mset_meta['files']) + local_mset_files = local_msets.get(mset_id, {}).get('files', []) + if cleanup: + # Clean up any old files that are no longer on the remote + for local_file in local_mset_files: + if local_file not in remote_mset_files: + (local_mset_dir / local_file).unlink() + + remote_mset_files = remote_mset_files.difference({MODELSET_FILENAME}) + remote_mset_dir = urljoin(BABY_MODELS_URL, mset_id + '/') + if force: + files_to_download = remote_mset_files + else: + files_to_download = remote_mset_files.difference(local_mset_files) + if len(files_to_download) > 0 and verbose: + print(f'Downloading files for {mset_id}...') + for remote_file in files_to_download: + r = requests.get(urljoin(remote_mset_dir, remote_file)) + with open(local_mset_dir / remote_file, 'wb') as f: + for chunk in r.iter_content(chunk_size=128): + f.write(chunk) + + # Update shared files from remote host + remote_share_dir = urljoin(BABY_MODELS_URL, SHARE_PATH + '/') + remote_shared_files = set(remote_mset_info['shared']) + local_share_path = LOCAL_MODELSETS_PATH / SHARE_PATH + local_shared_files = local_mset_info['shared'] + + if cleanup: + # Clean up any old files that are no longer on the remote + for local_file in local_shared_files: + if local_file not in remote_shared_files: + (local_share_path / local_file).unlink() + + if force: + files_to_download = remote_shared_files + else: + files_to_download = remote_shared_files.difference(local_shared_files) + + # Download any files that need updating + if len(files_to_download) > 0 and verbose: + print('Downloading shared files...') + for remote_file in files_to_download: + r = requests.get(urljoin(remote_share_dir, remote_file)) + with open(local_share_path / remote_file, 'wb') as f: + for chunk in r.iter_content(chunk_size=128): + f.write(chunk) + + +def _ensure_modelset(modelset_id): + '''Ensure that a model set has been downloaded and is ready to use + + Args: + modelset_id: a `str` specifying one of the model IDs as returned by + :py:func:`ids`. + ''' + _ensure_local_path() + local_path = LOCAL_MODELSETS_PATH / modelset_id + share_path = LOCAL_MODELSETS_PATH / SHARE_PATH + local_modelset_file = local_path / MODELSET_FILENAME + updated = False + if not local_modelset_file.exists(): + update([modelset_id], force=False) + updated = True + with open(local_modelset_file, 'rt') as f: + modelset = json.load(f, object_hook=as_python_object) + modelset_files = _get_modelset_files(modelset) + for fname in modelset_files: + try: + resolve(fname, modelset_id) + except BadParam: + if updated: + raise BadModel('Model is corrupt. Contact maintainer.') + else: + update([modelset_id], force=False) + try: + resolve(fname, modelset_id) + except BadParam: + raise BadModel('Model is corrupt. Contact maintainer.') + updated = True + + +def get_params(modelset_id): + '''Get model set parameters + + The parameters are designed to be supplied as the argument to + instantiate a :py:class:`brain.BabyBrain` object. + + The model set will be automatically downloaded if it has not yet been. + + Args: + modelset_id: a `str` specifying one of the model IDs as returned by + :py:func:`ids`. + ''' + _ensure_modelset(modelset_id) + local_path = LOCAL_MODELSETS_PATH / modelset_id + local_modelset_file = local_path / MODELSET_FILENAME + with open(local_modelset_file, 'rt') as f: + modelset = json.load(f, object_hook=as_python_object) + return modelset['brain_params'] + + +def get(modelset_id, **kwargs): + '''Get a model set as a BabyBrain object + + The model set will be automatically downloaded if it has not yet been. + + Args: + modelset_id: a `str` specifying one of the model IDs as returned by + :py:func:`ids`. + + Returns: + A :py:class:`brain.BabyBrain` object instantiated with the model set + parameters. + ''' + from .brain import BabyBrain + return BabyBrain(modelset_path=modelset_id, **get_params(modelset_id), + **kwargs) diff --git a/python/baby/morph_thresh_seg.py b/python/baby/morph_thresh_seg.py index f09e10d5484747dc81f3ba24752b0d23b1803d7d..2e9832c62e6f6991fd3234605f918d591035dcda 100644 --- a/python/baby/morph_thresh_seg.py +++ b/python/baby/morph_thresh_seg.py @@ -1,12 +1,14 @@ # If you publish results that make use of this software or the Birth Annotator # for Budding Yeast algorithm, please cite: -# Julian M J Pietsch, Alán Muñoz, Diane Adjavon, Ivan B N Clark, Peter S -# Swain, 2021, Birth Annotator for Budding Yeast (in preparation). +# Pietsch, J.M.J., Muñoz, A.F., Adjavon, D.-Y.A., Farquhar, I., Clark, I.B.N., +# and Swain, P.S. (2023). Determining growth rates from bright-field images of +# budding cells through identifying overlaps. eLife. 12:e79812. +# https://doi.org/10.7554/eLife.79812 # # # The MIT License (MIT) # -# Copyright (c) Julian Pietsch, Alán Muñoz and Diane Adjavon 2021 +# Copyright (c) Julian Pietsch, Alán Muñoz and Diane Adjavon 2023 # # Permission is hereby granted, free of charge, to any person obtaining a copy # of this software and associated documentation files (the "Software"), to @@ -30,15 +32,142 @@ import itertools from scipy import ndimage -from typing import Union, Iterable, Any, Optional, List, NamedTuple +from typing import (Union, Sequence, Any, Optional, List, NamedTuple, Literal, + Tuple) from .errors import BadParam -from .segmentation import mask_containment, iterative_erosion, thresh_seg, \ - binary_edge, single_region_prop, outline_to_radial, get_edge_scores, \ - refine_radial_grouped, iterative_dilation, draw_radial +from baby import segmentation +from .segmentation import ( + iterative_erosion, threshold_segmentation, binary_edge, + single_region_prop, mask_to_knots, get_edge_scores, refine_radial_grouped, + iterative_dilation, draw_radial) +from .utils import EncodableNamedTuple from .volume import volume +BROADCASTABLE_PARAMS = {'edge_sub_dilations', 'interior_threshold', + 'nclosing', 'nopening', 'connectivity', + 'min_area', 'bbox_padding', 'pedge_thresh', + 'group_thresh_expansion'} + + +@EncodableNamedTuple +class SegmentationParameters(NamedTuple): + """Defines parameters for segmentation from CNN outputs. + + Designed to be used with :py:class:`MorphSegGrouped` and descriptions of + the attributes below should be read in conjunction with that class. + + Certain parameters can be specified group-wise, that is, with a different + value for each of the groupings defined by ``cellgroups``. If specified as + a singleton, the same parameter value will be replicated for each group. If + specified as a list, the length of the list should match the number of + cell groupings. + + This class has been decorated with :py:func:`utils.EncodableNamedTuple`, + so can be reversibly encoded to JSON using :py:func:`utils.jsonify`. + + Attributes: + cellgroups: By default, all groups defined by the ``flattener`` + (except ``budonly`` groups) will be included in descending size + order. Optionally specify a sequence of ``str`` or a sequence of + tuples of ``str`` naming specific flattener groups or respectively + combinations thereof. + edge_sub_dilations: Number of iterations of grayscale morphological + dilation to apply to the edge probability image when it is used + to dampen the interior probability image before thresholding. + Specify ``None`` to disable dampening, or ``0`` to dampen with + the (undilated) edge probability image. Can be specified + group-wise. + interior_threshold: Threshold to apply to interior probability image + for locating cell instances (see + :py:func:`segmentation.threshold_segmentation`). Can be specified + group-wise. + connectivity: Connectivity for labelling connected regions (see + :py:func:`segmentation.threshold_segmentation`). Can be specified + group-wise. + nclosing: Number of closing iterations to apply to identified masks + (see :py:func:`segmentation.threshold_segmentation`). Can be + specified group-wise. + nopening: Number of opening iterations to apply to identified masks + (see :py:func:`segmentation.threshold_segmentation`). Can be + specified group-wise. + min_area: Minimum area (in square pixels) below which a mask will be + discarded. Has precedence over a smaller minimum specified by + any of the ``cellgroups``. Can be specified group-wise. + fit_radial: If ``True``, edges are generated by splines with knots + defined in radial coordinates; if ``False``, edges are taken + directly from the dilated masks. + cartesian_spline: If ``True``, splines are interpolated in cartesian + coordinates; if ``False``, they are interpolated in radial + coordinates. + bbox_padding: The number of pixels of padding that should be + included when limiting mask processing to a bounding box. + nrays_thresh_map: See :py:func:`segmentation.guess_radial_edge`. Only + applies when ``cartesian_spline=False`` and ``curvy_knots=False``. + nrays_max: See :py:func:`segmentation.guess_radial_edge`. Only applies + when ``cartesian_spline=False`` and ``curvy_knots=False``. + n_trials: See :py:func:`segmentation.guess_cartesian_edge`. Only + applies when ``cartesian_spline=True`` and ``curvy_knots=False``. + pxperknot: See :py:func:`segmentation.guess_cartesian_edge`. Only + applies when ``cartesian_spline=True`` and ``curvy_knots=False``. + squishiness: See :py:func:`segmentation.guess_cartesian_edge`. Only + applies when ``cartesian_spline=True`` and ``curvy_knots=False``. + alignedness: See :py:func:`segmentation.guess_cartesian_edge`. Only + applies when ``cartesian_spline=True`` and ``curvy_knots=False``. + curvy_knots: If ``True``, knots for splines are placed at points of + estimated high curvature + (:py:func:`segmentation.curvy_knots_from_outline`); if ``False``, + knots are placed at equally spaced locations in radial space + (``cartesian_spline=False``) or approximately equally spaced + locations along the perimeter (``cartesian_spline=True``). + n_knots_fraction: Determines the number of knots as a fraction of edge + pixel count for the case where ``curvy_knots=True`` + (see:py:func:`segmentation.curvy_knots_from_outline`). + pedge_thresh: Threshold on the edge probability scores of guessed + outlines, below which outlines will be discarded. Can be specified + group-wise. If any are specified as ``None``, then the decision on + which contained outline to keep in neighbouring groups based on + area rather than edge score. + use_group_thresh: If ``True``, discard masks if their areas fall + outside the size limits as defined for the group they originated + in; if ``False``, keep all generated masks (except those below the + ``min_area`` threshold). + group_thresh_expansion: Expansion factor for the size limits defined + by each group. Can be specified group-wise. + containment_func: Specifies the metric to use to decide whether cells + in adjacent groups should be considered identical or not. + Currently the only valid choices are 'mask_containment' and + 'mask_iou' (see :py:func:`segmentation.mask_containment` and + :py:func:`segmentation.mask_iou`). + containment_thresh: Threshold above which to consider cells in + adjacent groups identical. + """ + cellgroups: Union[None, List[str], List[Tuple[str, ...]]] = None + edge_sub_dilations: Union[None, int, List[Optional[int]]] = 0 + interior_threshold: Union[float, List[float]] = 0.5 + nclosing: Union[int, List[int]] = 0 + nopening: Union[int, List[int]] = 0 + connectivity: Union[int, List[int]] = 2 + min_area: Union[int, List[int]] = 10 + fit_radial: bool = True + cartesian_spline: int = False + bbox_padding: Union[int, List[int]] = 10 + nrays_thresh_map: List[Tuple[float, int]] = [(5., 4), (20., 6)] + nrays_max: int = 8 + n_trials: int = 10 + pxperknot: float = 4. + squishiness: float = 0.1 + alignedness: float = 10. + curvy_knots: bool = False + n_knots_fraction: float = 0.5 + pedge_thresh: Union[None, float, List[Optional[float]]] = None + use_group_thresh: bool = False + group_thresh_expansion: Union[float, List[float]] = 0. + containment_func: Literal['mask_containment', 'mask_iou'] = 'mask_containment' + containment_thresh: int = 0.8 + + # class ContainmentFunction: # def __init__(self, threshold: float = .8): # self.threshold = threshold @@ -60,13 +189,11 @@ from .volume import volume # a = A() class Cell: - def __init__(self, area, mask, predicted_edge, border_rect, - fit_radial=True): + def __init__(self, area, mask, predicted_edge, border_rect, params): self.area = area self.mask = mask self.predicted_edge = predicted_edge - self.fit_radial = fit_radial - self.border_rect = border_rect + self.params = params self._coords = None self._edge = None @@ -79,15 +206,14 @@ class Cell: self.predicted_edge)[0] return self._edge_score - def _calculate_properties(self, fit_radial): - self._edge = binary_edge(self.mask) - if fit_radial: - rprop = single_region_prop(self.mask) - coords, edge = outline_to_radial(self.edge, rprop, - return_outline=True) + def _calculate_properties(self): + if self.params.fit_radial: + coords, edge = mask_to_knots(self.mask, + p_edge=self.predicted_edge, + **self.params._asdict()) self.mask = ndimage.binary_fill_holes(edge) else: - edge = self._edge | (self.border_rect & self.mask) + edge = binary_edge(self.mask) | (self.border_rect & self.mask) coords = tuple() self._coords = coords self._edge = edge @@ -95,13 +221,13 @@ class Cell: @property def edge(self): if self._edge is None: - self._calculate_properties(fit_radial=self.fit_radial) + self._calculate_properties() return self._edge @property def coords(self): if self._coords is None: - self._calculate_properties(fit_radial=self.fit_radial) + self._calculate_properties() return self._coords @property @@ -172,53 +298,30 @@ class Target: class Group: - def __init__(self, - targets, - min_area=10., - use_thresh=False, - thresh_expansion=0., - pedge_thresh=None, - interior_threshold=0.5, - n_closing=0, - n_opening=0, - connectivity=2, - edge_sub_dilations=None): - # Parameter assignment - self.__connectivity = connectivity - self.__min_area = min_area - self.__use_thresh = use_thresh - self.__thresh_expansion = thresh_expansion - self.__pedge_thresh = pedge_thresh - self.__n_closing = n_closing - self.__n_opening = n_opening - self.__interior_threshold = interior_threshold - self.edge_sub_dilations = edge_sub_dilations - - # Subgroup targets + def __init__(self, targets, params): + + self.params = params self.targets = targets + self.cells = [] # Computed members self._n_erode = None self._lower = None self._upper = None - # Dunno yet, probably functions - self.cells = [] - def _calculate_bounds(self): # get bounds - if self.__use_thresh: + if self.params.use_group_thresh: lower = min(target.definition.get('lower', 1.) for target in self.targets) upper = max(target.definition.get('upper', float('inf')) for target in self.targets) - expansion = self.__thresh_expansion * (lower - if upper == float('inf') - else upper - lower) - lower = max(lower - expansion, self.__min_area) + expansion = self.params.group_thresh_expansion * ( + lower if upper == float('inf') else upper - lower) + lower = max(lower - expansion, self.params.min_area) upper += expansion else: - lower, upper = self.__min_area, float('inf') + lower, upper = self.params.min_area, float('inf') self._lower = lower self._upper = upper @@ -244,7 +347,7 @@ class Group: @property def max_n_erode(self): max_n_erode = max(self.n_erode) - if self.edge_sub_dilations is not None and max_n_erode == 0: + if self.params.edge_sub_dilations is not None and max_n_erode == 0: return 1 else: return max_n_erode # Todo: add number of dilations? @@ -255,7 +358,7 @@ class Group: @property def interior_threshold(self): - return self.__interior_threshold or 0.5 + return self.params.interior_threshold or 0.5 def prediction(self, pred, target_name, erode=False): predictions = [target.prediction(pred, target_name) @@ -272,16 +375,22 @@ class Group: for p, n_erode in zip(predictions, self.n_erode)] return predictions - def segment(self, pred, border_rect, fit_radial=False): - """ - Obtain the cell masks, areas, edges, and coordiantes from the - prediction of the interior of the cell group. - - - :param fit_radial: - :param pred: The neural network's prediction - :param border_rect: A boolean array delimiting the border of the - prediction arrays + def segment(self, pred, border_rect): + """Find cell instances for this group from interior probability + + Given raw CNN output ``pred``, extract the interior probability image + defined for this group, then apply threshold segmentation as per + :py:func:`segmentation.threshold_segmentation` and create a + :py:class:`Cell` for each identified mask. Depending on the + :py:class:`SegmentationParameters`, some filtering by mask area and/or + probability of the mask edge on the edge probability image may be + applied. + + Args: + pred (ndarray): The neural network's prediction with shape + (targets, X, Y, 1). + border_rect (ndarray): An image of shape (X, Y) which is False + except for a one-pixel border of True values. """ # TODO this is the bit where the use of the overlap informations @@ -289,27 +398,33 @@ class Group: pred_interior = self.prediction(pred, 'interior', erode=True) pred_edge = self.prediction(pred, 'edge', erode=False) - if self.edge_sub_dilations is not None: + if self.params.edge_sub_dilations is not None: pred_edge = iterative_dilation(pred_edge, - self.edge_sub_dilations) + self.params.edge_sub_dilations) pred_interior *= (1 - pred_edge) - masks_areas = [(m, a) for m, a in thresh_seg( - pred_interior, interior_threshold=self.__interior_threshold or 0.5, - nclosing=self.__n_closing or 0, - nopening=self.__n_opening or 0, - ndilate=self.max_n_erode, - return_area=True, connectivity=self.__connectivity) - if self.lower <= a < self.upper] - self.cells = [Cell(a, m, pred_edge, border_rect, fit_radial=fit_radial) - for m, a in masks_areas] + masks_areas = [ + (m, a) for m, a in threshold_segmentation( + pred_interior, + interior_threshold=self.params.interior_threshold or 0.5, + nclosing=self.params.nclosing or 0, + nopening=self.params.nopening or 0, + ndilate=self.max_n_erode, + return_area=True, + connectivity=self.params.connectivity) + if self.lower <= a < self.upper + ] + self.cells = [ + Cell(a, m, pred_edge, border_rect, self.params) + for m, a in masks_areas + ] # Remove cells that do not exceed the p_edge threshold - if self.__pedge_thresh is not None: + if self.params.pedge_thresh is not None: self.cells = [cell for cell in self.cells - if cell.edge_score > self.__pedge_thresh] + if cell.edge_score > self.params.pedge_thresh] -def broadcast_arg(arg: Union[Iterable, Any], +def broadcast_arg(arg: Union[Sequence, Any], argname: Optional[str] = None, n_groups: Optional[int] = 3): if argname is None: @@ -323,98 +438,82 @@ def broadcast_arg(arg: Union[Iterable, Any], return [arg] * n_groups -class MorphSegGrouped: - def __init__(self, flattener, cellgroups=None, - interior_threshold=0.5, nclosing=0, nopening=0, - connectivity=2, - min_area=10, pedge_thresh=None, fit_radial=False, - use_group_thresh=False, group_thresh_expansion=0., - edge_sub_dilations=None, - containment_thresh=0.8, containment_func=mask_containment, - return_masks=False, return_coords=False, return_volume=False): - """ - :param flattener: - :param cellgroups: - :param interior_threshold: - :param nclosing: - :param nopening: - :param connectivity: - :param min_area: - :param pedge_thresh: - :param fit_radial: - :param use_group_thresh: - :param group_thresh_expansion: - :param edge_sub_dilations: - :param containment_thresh: - :param containment_func: - :param return_masks: - :param return_coords: - """ - # Todo: assertions about valid options - # (e.g. 0 < interior_threshold < 1) - # Assign options and parameters - self.pedge_thresh = pedge_thresh - self.fit_radial = fit_radial - self.containment_thresh = containment_thresh - self.containment_func = containment_func +class MorphSegGrouped: + """Provides an instance segmentation method given CNN target definitions. + + Args: + flattener (preprocessing.SegmentationFlattening): Target definitions + for the CNN predictions that will be used as input to + :py:meth:`segment`. + params (SegmentationParameters): Parameters used for segmentation. See + :py:class:`SegmentationParameters` for more details. + return_masks (bool): Whether masks should be returned by + :py:meth:`segment`. + return_coords (bool): Whether knot coordinates should be returned by + :py:meth:`segment`. + + Attributes: + groups (List[Group]): Segmentation processing/results for each group. + flattener (preprocessing.SegmentationFlattening): As for the + ``flattener`` argument. + params (SegmentationParameters): Parameters used for segmentation + including any broadcasting for the number of cell groups. + return_masks: Whether masks should be returned by :py:meth:`segment`. + return_coords: Whether knot coordinates should be returned by + :py:meth:`segment`. + """ + + def __init__(self, flattener, params=SegmentationParameters(), + return_masks=False, return_coords=False): + self.flattener = flattener self.return_masks = return_masks self.return_coords = return_coords - self.return_volume = return_volume - self.flattener = flattener + # TODO: assertions about valid options + # (e.g. 0 < interior_threshold < 1) - # Define group parameters + # Define group parameters + cellgroups = params.cellgroups if cellgroups is None: - cellgroups = ['large', 'medium', 'small'] + cellgroups = [t.group for t in flattener.targets + if t.prop in {'interior', 'filled'}] cellgroups = [(g,) if isinstance(g, str) else g for g in cellgroups] n_groups = len(cellgroups) - interior_threshold = broadcast_arg(interior_threshold, - 'interior_threshold', n_groups) - n_closing = broadcast_arg(nclosing, 'nclosing', n_groups) - n_opening = broadcast_arg(nopening, 'nopening', n_groups) - min_area = broadcast_arg(min_area, 'min_area', n_groups) - connectivity = broadcast_arg(connectivity, 'connectivity', n_groups) - pedge_thresh = broadcast_arg(pedge_thresh, 'pedge_thresh', n_groups) - group_thresh_expansion = broadcast_arg(group_thresh_expansion, - 'group_thresh_expansion', - n_groups) - edge_sub_dilations = broadcast_arg(edge_sub_dilations, - 'edge_substraction_dilations', - n_groups) + + # Broadcast relevant parameters + params = params._replace(**{ + p: broadcast_arg(getattr(params, p), p, n_groups) + for p in BROADCASTABLE_PARAMS + }) # Minimum area must be larger than 1 to avoid generating cells with # no size: - min_area = [np.max([a, 1]) for a in min_area] + params = params._replace( + min_area=[np.max([a, 1]) for a in params.min_area]) + + self.params = params + + self.containment_func = getattr(segmentation, params.containment_func) + self.containment_thresh = params.containment_thresh # Initialize the different groups and their targets self.groups = [] for i, target_names in enumerate(cellgroups): targets = [Target(name, flattener) for name in target_names] - self.groups.append(Group(targets, min_area=min_area[i], - use_thresh=use_group_thresh, - thresh_expansion=group_thresh_expansion[i], - pedge_thresh=pedge_thresh[i], - interior_threshold=interior_threshold[i], - n_closing=n_closing[i], - n_opening=n_opening[i], - connectivity=connectivity[i], - edge_sub_dilations=edge_sub_dilations[i])) - self.group_segs = None + self.groups.append(Group(targets, params=params._replace(**{ + p: getattr(params, p)[i] for p in BROADCASTABLE_PARAMS + }))) # Todo: This is ideally the form of the input argument def contains(self, a, b): return self.containment_func(a, b) > self.containment_thresh def remove_duplicates(self): + """Resolve any cells duplicated across adjacent groups. """ - Resolve any cells duplicated across adjacent groups: - :param group_segs: - :return: The group segmentations with duplicates removed - """ - - if self.pedge_thresh is None: + if all([t is None for t in self.params.pedge_thresh]): def accessor(cell): return cell.area else: @@ -444,7 +543,7 @@ class MorphSegGrouped: def extract_edges(self, pred, shape, refine_outlines, return_volume): masks = [[]] if refine_outlines: - if not self.fit_radial: + if not self.params.fit_radial: raise BadParam( '"refine_outlines" requires "fit_radial" to have been specified' ) @@ -456,11 +555,14 @@ class MorphSegGrouped: if predicted_edges: coords = list(itertools.chain.from_iterable( - refine_radial_grouped(grouped_coords, - predicted_edges))) + refine_radial_grouped( + grouped_coords, predicted_edges, + cartesian_spline=self.params.cartesian_spline))) else: coords = tuple() - edges = [draw_radial(radii, angles, centre, shape) + edges = [draw_radial( + radii, angles, centre, shape, + cartesian_spline=self.params.cartesian_spline) for centre, radii, angles in coords] if self.return_masks: masks = [ndimage.binary_fill_holes(e) for e in edges] @@ -480,21 +582,23 @@ class MorphSegGrouped: for cell in group.cells] if len(outputs) > 0: - return zip(*outputs) + return (list(o) for o in zip(*outputs)) else: return 4 * [[]] if return_volume else 3 * [[]] def segment(self, pred, refine_outlines=False, return_volume=False): - """ - Take the output of the neural network and turn it into an instance - segmentation output. - - :param pred: list of prediction images (ndarray with shape (x, y)) - matching `self.flattener.names()` - :return: a list of boolean edge images (ndarray shape (x, y)), one for - each cell identified. If `return_masks` and/or `return_coords` are - true, the output will be a tuple of edge images, filled masks, and/or - radial coordinates. + """Returns segmented instances based on the output of the CNN. + + Args: + pred: A list of prediction images (ndarray with shape (x, y)) + matching the names of :py:attr:`flattener` (see + :py:meth:`preprocessing.SegmentationFlattening.names`). + + Returns: + A list of boolean edge images (ndarray shape (x, y)), one for each + cell identified. If ``return_masks`` and/or ``return_coords`` are + true, the output will be a tuple of edge images, filled masks, + and/or radial coordinates. """ if len(pred) != len(self.flattener.names()): raise BadParam( @@ -506,7 +610,7 @@ class MorphSegGrouped: constant_values=True) for group in self.groups: - group.segment(pred, border_rect, fit_radial=self.fit_radial) + group.segment(pred, border_rect) # Remove cells that are duplicated in several groups self.remove_duplicates() diff --git a/python/baby/performance.py b/python/baby/performance.py index 888ed3de8501e102546e7246afc2979a77a2e009..d7a1005ef615e086468786f7bbe1a67ca3a32ef5 100644 --- a/python/baby/performance.py +++ b/python/baby/performance.py @@ -1,12 +1,14 @@ # If you publish results that make use of this software or the Birth Annotator # for Budding Yeast algorithm, please cite: -# Julian M J Pietsch, Alán Muñoz, Diane Adjavon, Ivan B N Clark, Peter S -# Swain, 2021, Birth Annotator for Budding Yeast (in preparation). +# Pietsch, J.M.J., Muñoz, A.F., Adjavon, D.-Y.A., Farquhar, I., Clark, I.B.N., +# and Swain, P.S. (2023). Determining growth rates from bright-field images of +# budding cells through identifying overlaps. eLife. 12:e79812. +# https://doi.org/10.7554/eLife.79812 # # # The MIT License (MIT) # -# Copyright (c) Julian Pietsch, Alán Muñoz and Diane Adjavon 2021 +# Copyright (c) Julian Pietsch, Alán Muñoz and Diane Adjavon 2023 # # Permission is hereby granted, free of charge, to any person obtaining a copy # of this software and associated documentation files (the "Software"), to diff --git a/python/baby/postprocessing.py b/python/baby/postprocessing.py index 65225788f4b1088550825f30782cec6b944d651c..b4d52f96740b927f2a46201a9859ef94946c950c 100644 --- a/python/baby/postprocessing.py +++ b/python/baby/postprocessing.py @@ -1,12 +1,14 @@ # If you publish results that make use of this software or the Birth Annotator # for Budding Yeast algorithm, please cite: -# Julian M J Pietsch, Alán Muñoz, Diane Adjavon, Ivan B N Clark, Peter S -# Swain, 2021, Birth Annotator for Budding Yeast (in preparation). +# Pietsch, J.M.J., Muñoz, A.F., Adjavon, D.-Y.A., Farquhar, I., Clark, I.B.N., +# and Swain, P.S. (2023). Determining growth rates from bright-field images of +# budding cells through identifying overlaps. eLife. 12:e79812. +# https://doi.org/10.7554/eLife.79812 # # # The MIT License (MIT) # -# Copyright (c) Julian Pietsch, Alán Muñoz and Diane Adjavon 2021 +# Copyright (c) Julian Pietsch, Alán Muñoz and Diane Adjavon 2023 # # Permission is hereby granted, free of charge, to any person obtaining a copy # of this software and associated documentation files (the "Software"), to @@ -243,8 +245,8 @@ def fit_gr_data(filename, run_moth=True, run_dght=True, split_at_birth=False, fvals = ['f', 'df', 'ddf'] for g in tqdm(data['D']): grp = data['D'][g] - allT = grp['times'] - allV = grp['vol'] + allT = grp['times'][()] + allV = grp['vol'][()] if log_volume: allV = np.log(allV) diff --git a/python/baby/preprocessing.py b/python/baby/preprocessing.py index 3626e141463f36848e732d05a99972358f7b1fba..1e3ff1e31748c0244c1393093c4fe3c5c827e60f 100644 --- a/python/baby/preprocessing.py +++ b/python/baby/preprocessing.py @@ -1,12 +1,14 @@ # If you publish results that make use of this software or the Birth Annotator # for Budding Yeast algorithm, please cite: -# Julian M J Pietsch, Alán Muñoz, Diane Adjavon, Ivan B N Clark, Peter S -# Swain, 2021, Birth Annotator for Budding Yeast (in preparation). +# Pietsch, J.M.J., Muñoz, A.F., Adjavon, D.-Y.A., Farquhar, I., Clark, I.B.N., +# and Swain, P.S. (2023). Determining growth rates from bright-field images of +# budding cells through identifying overlaps. eLife. 12:e79812. +# https://doi.org/10.7554/eLife.79812 # # # The MIT License (MIT) # -# Copyright (c) Julian Pietsch, Alán Muñoz and Diane Adjavon 2021 +# Copyright (c) Julian Pietsch, Alán Muñoz and Diane Adjavon 2023 # # Permission is hereby granted, free of charge, to any person obtaining a copy # of this software and associated documentation files (the "Software"), to @@ -25,25 +27,25 @@ # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING # FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS # IN THE SOFTWARE. +import json +from itertools import chain +from typing import NamedTuple, Union import numpy as np from skimage import img_as_float from scipy.ndimage import ( - generic_filter, minimum_filter, maximum_filter, + convolve, minimum_filter, maximum_filter, binary_fill_holes, binary_dilation, binary_erosion, binary_closing ) from skimage.measure import regionprops from skimage.draw import polygon from skimage.morphology import diamond -from itertools import chain -from typing import NamedTuple, Union -import json from .segmentation import binary_edge from .utils import EncodableNamedTuple, jsonify, as_python_object # Depth-wise structuring elements for square or full connectivity dwsquareconn = diamond(1)[..., None] -dwfullconn = np.ones((3,3,1), dtype='uint8') +dwfullconn = np.ones((3,3,1), dtype=np.uint8) def raw_norm(img, info): """Keep raw values but scale non-float images to [0, 1]""" @@ -73,79 +75,38 @@ def robust_norm(img, info, q_outliers=0.04): return (img - mid) * (1 - q_outliers) / imrange -def seg_norm(img, info): - img = img > 0 - # Connect any 1-pixel gaps: - imconn = generic_filter(img.astype('int'), np.sum, footprint=dwfullconn) - imconn = binary_erosion(imconn > 1, dwsquareconn) | img - return imconn, info - - -def segoutline_flattening(fill_stack, info): - """Returns a stack of images in order: - edge: edges for all cells flattened into a single layer - filled: filled area for all cells flattened into a single layer - interiors: filled area excluding the edges - overlap: regions occupied by more than one cell - budring: edges located between mother and bud centres - bud: buds that are are smaller than a fractional threshold of the mother - """ - imsize = fill_stack.shape[0:2] - ncells = fill_stack.shape[2] - imout = np.zeros(imsize+(6,), dtype='bool') - - edge_stack = binary_edge(fill_stack, dwsquareconn) - - edge_flat = np.any(edge_stack, axis=2) - fill_flat = np.any(fill_stack, axis=2) - overlap = np.sum(fill_stack, axis=2)>1 - - imout[:,:,0] = edge_flat # edge - imout[:,:,1] = fill_flat # filled - imout[:,:,2] = fill_flat & ~edge_flat & ~overlap # interiors - imout[:,:,3] = overlap # overlap +def robust_norm_dw(img, info, q_outliers=0.04): + """Robust normalisation of intensity to [-1,1] applied depth-wise""" + img = img.copy().astype('float') + hq = q_outliers / 2 + for i in range(img.shape[2]): + low, mid, high = np.quantile(img[:,:,i], (hq, 0.5, 1-hq)) + imrange = high - low + imrange = 1 if imrange == 0 else imrange + img[:,:,i] = (img[:,:,i] - mid) * (1 - q_outliers) / imrange + return img - bud_pairs = [(m, np.nonzero(np.array(info.get('cellLabels', []))==b)[0][0]) - for m, b in enumerate(info.get('buds', []) or []) if b>0] - cell_info = [ - regionprops(fill_stack[:,:,i].astype('int32'), coordinates='rc') - for i in range(fill_stack.shape[2]) - ] - cell_info = [p[0] if len(p)>0 else None for p in cell_info] +def connect_pixel_gaps(img): + """Connects any 1-pixel gaps in an edge image""" + # Connect any 1-pixel gaps: + imconn = convolve(img.astype(np.uint8), dwfullconn, mode='constant') + imconn = binary_erosion(imconn > 1, dwsquareconn) | img + return imconn + - for m, b in bud_pairs: - if cell_info[m] is None or cell_info[b] is None: - # Label possible transformed outside field of view by augmentation - continue - if m == b: - raise Exception('a mother cannot be its own bud') - m_centre = np.array(cell_info[m].centroid).T - b_centre = np.array(cell_info[b].centroid).T - r_width = cell_info[b].minor_axis_length*0.25 - r_hvec = b_centre-m_centre - r_wvec = np.matmul(np.array([[0,-1],[1,0]]), r_hvec) - if np.linalg.norm(r_wvec) == 0: - raise Exception('mother and bud have coincident centres') - r_wvec = r_width*r_wvec/np.linalg.norm(r_wvec) - r_points = np.zeros((2,4)) - r_points[:,0] = m_centre-0.5*r_wvec - r_points[:,1] = r_points[:,0] + r_hvec - r_points[:,2] = r_points[:,1] + r_wvec - r_points[:,3] = r_points[:,2] - r_hvec - r_inds, c_inds = polygon(r_points[0,:], r_points[1,:], imsize) - r_im = np.zeros(fill_stack.shape[0:2], dtype='bool') - r_im[r_inds, c_inds] = 1 +def seg_norm(img, info): + img = img > 0 + return connect_pixel_gaps(img), info - # Bud junction - bj = (edge_stack[:,:,m] | edge_stack[:,:,b]) & r_im - imout[:,:,4] |= binary_dilation(binary_closing(bj)) - # Smaller buds - if (cell_info[b].area / cell_info[m].area) < 0.7: - imout[:,:,5] |= fill_stack[:,:,b] +def flattener_norm_func(flattener): + def norm_func(img, info): + img, info = seg_norm(img, info) + img = binary_fill_holes(img, dwsquareconn) + return flattener(img, info) - return imout + return norm_func @EncodableNamedTuple @@ -265,6 +226,16 @@ class SegmentationFlattening(object): def names(self): return tuple(t.name for t in self.targets) + def group_names(self, exclude_budonly=False): + groups = self.groupdef.items() + if exclude_budonly: + groups = [(k, g) for k, g in groups if not g.budonly] + sort_lower = sorted( + groups, key=lambda i: i[1].lower, reverse=True) + sort_upper = sorted( + sort_lower, key=lambda i: i[1].upper, reverse=True) + return next(zip(*sort_upper)) + def getGroupTargets(self, group, propfilter=None): assert group in self.groupdef, \ '"{}" group does not exist'.format(group) @@ -438,10 +409,77 @@ class SegmentationFlattening(object): return np.dstack(targetims) -def flattener_norm_func(flattener): - def norm_func(img, info): - img, info = seg_norm(img, info) - img = binary_fill_holes(img, dwsquareconn) - return flattener(img, info) +################## +### DEPRECATED ### +################## - return norm_func + +def segoutline_flattening(fill_stack, info): + """Returns a stack of images in order. + + DEPRECATED + + Args: + edge: edges for all cells flattened into a single layer + filled: filled area for all cells flattened into a single layer + interiors: filled area excluding the edges + overlap: regions occupied by more than one cell + budring: edges located between mother and bud centres + bud: buds that are are smaller than a fractional threshold of the mother + """ + imsize = fill_stack.shape[0:2] + ncells = fill_stack.shape[2] + imout = np.zeros(imsize+(6,), dtype='bool') + + edge_stack = binary_edge(fill_stack, dwsquareconn) + + edge_flat = np.any(edge_stack, axis=2) + fill_flat = np.any(fill_stack, axis=2) + overlap = np.sum(fill_stack, axis=2)>1 + + imout[:,:,0] = edge_flat # edge + imout[:,:,1] = fill_flat # filled + imout[:,:,2] = fill_flat & ~edge_flat & ~overlap # interiors + imout[:,:,3] = overlap # overlap + + bud_pairs = [(m, np.nonzero(np.array(info.get('cellLabels', []))==b)[0][0]) + for m, b in enumerate(info.get('buds', []) or []) if b>0] + + cell_info = [ + regionprops(fill_stack[:,:,i].astype('int32'), coordinates='rc') + for i in range(fill_stack.shape[2]) + ] + cell_info = [p[0] if len(p)>0 else None for p in cell_info] + + for m, b in bud_pairs: + if cell_info[m] is None or cell_info[b] is None: + # Label possible transformed outside field of view by augmentation + continue + if m == b: + raise Exception('a mother cannot be its own bud') + m_centre = np.array(cell_info[m].centroid).T + b_centre = np.array(cell_info[b].centroid).T + r_width = cell_info[b].minor_axis_length*0.25 + r_hvec = b_centre-m_centre + r_wvec = np.matmul(np.array([[0,-1],[1,0]]), r_hvec) + if np.linalg.norm(r_wvec) == 0: + raise Exception('mother and bud have coincident centres') + r_wvec = r_width*r_wvec/np.linalg.norm(r_wvec) + r_points = np.zeros((2,4)) + r_points[:,0] = m_centre-0.5*r_wvec + r_points[:,1] = r_points[:,0] + r_hvec + r_points[:,2] = r_points[:,1] + r_wvec + r_points[:,3] = r_points[:,2] - r_hvec + r_inds, c_inds = polygon(r_points[0,:], r_points[1,:], imsize) + r_im = np.zeros(fill_stack.shape[0:2], dtype='bool') + r_im[r_inds, c_inds] = 1 + + # Bud junction + bj = (edge_stack[:,:,m] | edge_stack[:,:,b]) & r_im + imout[:,:,4] |= binary_dilation(binary_closing(bj)) + + # Smaller buds + if (cell_info[b].area / cell_info[m].area) < 0.7: + imout[:,:,5] |= fill_stack[:,:,b] + + return imout diff --git a/python/baby/profile.py b/python/baby/profile.py index d59cec1b0b1820618f80280efb8e368590d68799..4de4ea07ebb6636ebd8df407a10478e578b4fbff 100644 --- a/python/baby/profile.py +++ b/python/baby/profile.py @@ -1,12 +1,14 @@ # If you publish results that make use of this software or the Birth Annotator # for Budding Yeast algorithm, please cite: -# Julian M J Pietsch, Alán Muñoz, Diane Adjavon, Ivan B N Clark, Peter S -# Swain, 2021, Birth Annotator for Budding Yeast (in preparation). +# Pietsch, J.M.J., Muñoz, A.F., Adjavon, D.-Y.A., Farquhar, I., Clark, I.B.N., +# and Swain, P.S. (2023). Determining growth rates from bright-field images of +# budding cells through identifying overlaps. eLife. 12:e79812. +# https://doi.org/10.7554/eLife.79812 # # # The MIT License (MIT) # -# Copyright (c) Julian Pietsch, Alán Muñoz and Diane Adjavon 2021 +# Copyright (c) Julian Pietsch, Alán Muñoz and Diane Adjavon 2023 # # Permission is hereby granted, free of charge, to any person obtaining a copy # of this software and associated documentation files (the "Software"), to diff --git a/python/baby/seg_trainer.py b/python/baby/seg_trainer.py deleted file mode 100644 index 91fc368ad5a6bf2a7d3d4bf2496d74f9ef264af6..0000000000000000000000000000000000000000 --- a/python/baby/seg_trainer.py +++ /dev/null @@ -1,449 +0,0 @@ -# If you publish results that make use of this software or the Birth Annotator -# for Budding Yeast algorithm, please cite: -# Julian M J Pietsch, Alán Muñoz, Diane Adjavon, Ivan B N Clark, Peter S -# Swain, 2021, Birth Annotator for Budding Yeast (in preparation). -# -# -# The MIT License (MIT) -# -# Copyright (c) Julian Pietsch, Alán Muñoz and Diane Adjavon 2021 -# -# Permission is hereby granted, free of charge, to any person obtaining a copy -# of this software and associated documentation files (the "Software"), to -# deal in the Software without restriction, including without limitation the -# rights to use, copy, modify, merge, publish, distribute, sublicense, and/or -# sell copies of the Software, and to permit persons to whom the Software is -# furnished to do so, subject to the following conditions: -# -# The above copyright notice and this permission notice shall be included in -# all copies or substantial portions of the Software. -# -# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR -# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE -# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING -# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS -# IN THE SOFTWARE. -from math import floor, log10 -from itertools import combinations, product, chain, repeat -from typing import NamedTuple, Union, Tuple, Any -import numpy as np -np.seterr(all='ignore') -import pandas as pd - -from baby.segmentation import mask_containment -from baby.morph_thresh_seg import MorphSegGrouped -from baby.performance import calc_IoUs, best_IoU -from baby.brain import default_params -from baby.errors import BadProcess, BadParam - -BASIC_PARAMS = { - 'interior_threshold', 'nclosing', 'nopening', 'connectivity', - 'edge_sub_dilations' -} - -round_to_n = lambda x, n: round(x, -int(floor(log10(x))) + (n - 1)) - - -class Score(NamedTuple): - precision: float - recall: float - F1: float - F0_5: float - F2: float - meanIoU: float - - -class SegFilterParamOptim: - """ - # TODO What does this class do - * What are the parameters and what do they mean - * What are the defaults, what are the ranges/admissible options? - :param flattener: - :param basic_params: - :param IoU_thresh: - :param scoring: - :param nbootstraps: - :param bootstrap_frac: - """ - def __init__(self, - flattener, - basic_params={}, - IoU_thresh=0.5, - scoring='F0_5', - nbootstraps=10, - bootstrap_frac=0.9): - - self.IoU_thresh = IoU_thresh - self.scoring = scoring - self.nbootstraps = nbootstraps - self.bootstrap_frac = bootstrap_frac - self._basic_params = default_params.copy() - self._basic_params.update( - {k: v for k, v in basic_params.items() if k in BASIC_PARAMS}) - self._basic_params.update({ - 'fit_radial': True, - 'min_area': 1, - 'pedge_thresh': None, - 'use_group_thresh': False - }) - self.segmenter = MorphSegGrouped(flattener, - return_masks=True, - **self._basic_params) - - self.group_info = [] - for g, group in enumerate(self.segrps): - lower = min( - target.definition.get('lower', 1.) - for target in group.targets) - upper = max( - target.definition.get('upper', float('inf')) - for target in group.targets) - grange = lower if upper == float('inf') else upper - lower - self.group_info.append((g, lower, upper, grange)) - - @property - def scoring(self): - """ The scoring method used during evaluation of the segmentation. - Accepted values are: # TODO define the scoring metrics - * precision: - * recall: - * F1: - * F0_5: - * F2: - * meanIoU: - :return: str scoring method - """ - return self._scoring - - @scoring.setter - def scoring(self, val): - if val not in Score._fields: - raise BadParam('Specified scoring metric not available') - self._scoring = val - - @property - def basic_params(self): - return self._basic_params - - @property - def segrps(self): - return self.segmenter.groups - - @property - def stat_table(self): - val = getattr(self, '_stat_table', None) - if val is None: - raise BadProcess('"generate_stat_table" has not been run') - return val - - @property - def stat_table_bootstraps(self): - val = getattr(self, '_stat_table_bootstraps', None) - if val is None: - raise BadProcess('"generate_stat_table" has not been run') - return val - - @property - def truth(self): - val = getattr(self, '_nPs', None) - if val is None: - raise BadProcess('"generate_stat_table" has not been run') - return val - - @property - def truth_bootstraps(self): - val = getattr(self, '_nPs_bootstraps', None) - if val is None: - raise BadProcess('"generate_stat_table" has not been run') - return val - - @property - def opt_params(self): - val = getattr(self, '_opt_params', None) - if val is None: - raise BadProcess('"fit_filter_params" has not been run') - return val - - @property - def opt_score(self): - val = getattr(self, '_opt_score', None) - if val is None: - raise BadProcess('"fit_filter_params" has not been run') - return val - - def generate_stat_table(self, example_gen): - rows_truth = [] - rows = [] - for s, seg_ex in enumerate(example_gen): - ncells = len(seg_ex.target) if seg_ex.target.any() else 0 - rows_truth.append((s, ncells)) - - # Perform within-group segmentation - shape = np.squeeze(seg_ex.pred[0]).shape - border_rect = np.pad(np.zeros(tuple(x - 2 for x in shape), - dtype='bool'), - pad_width=1, - mode='constant', - constant_values=True) - masks = [] - for group in self.segrps: - group.segment(seg_ex.pred, border_rect, fit_radial=True) - for cell in group.cells: - masks.append(cell.mask) - - # Calculate containment scores across groups - contained_cells = {} - paired_groups = zip(self.segrps, self.segrps[1:]) - for g, (lower_group, upper_group) in enumerate(paired_groups): - for l, lower in enumerate(lower_group.cells): - for u, upper in enumerate(upper_group.cells): - containment = mask_containment(lower.mask, upper.mask) - if containment > 0: - if lower.edge_score > upper.edge_score: - contained_cells[(g + 1, u)] = containment - else: - contained_cells[(g, l)] = containment - - if ncells > 0: - IoUs = calc_IoUs(seg_ex.target, masks, fill_holes=False) - max_IoU = IoUs.max(axis=0) - assignments = IoUs.argmax(axis=0) - _, best_assignments = best_IoU(IoUs.T) - else: - max_IoU = np.zeros(len(masks)) - assignments = np.zeros(len(masks), dtype=np.uint16) - best_assignments = -np.ones(len(masks), dtype=np.int32) - ind = 0 - for g, group in enumerate(self.segrps): - for c, cell in enumerate(group.cells): - rows.append((s, g, c, cell.area, cell.edge_score, - contained_cells.get((g, c), - 0.), assignments[ind], - max_IoU[ind], best_assignments[ind])) - ind += 1 - - df_truth = pd.DataFrame(rows_truth, columns=['example', 'ncells']) - df_truth = df_truth.set_index('example') - self._nPs = df_truth.ncells - - dtypes = [('example', np.uint16), - ('group', np.uint8), - ('cell', np.uint16), - ('area', np.uint16), - ('p_edge', np.float64), - ('containment', np.float64), - ('assignments', np.uint16), - ('max_IoU', np.float64), - ('best_assignments', np.int32)] - df = pd.DataFrame(np.array(rows, dtype=dtypes)) - - df['is_best'] = ((df.best_assignments >= 0) & - (df.max_IoU >= self.IoU_thresh)) - df['eid'] = df.example - df['uid'] = tuple(zip(df.example, df.assignments)) - - # Generate a set of bootstrapping filters over 90% of the examples - examples = list(set(df_truth.index.values)) - nperboot = np.round(self.bootstrap_frac * len(examples)).astype(int) - bootstraps = [ - np.random.choice(examples, nperboot, replace=True) - for _ in range(self.nbootstraps) - ] - self._nPs_bootstraps = [df_truth.loc[b].sum() for b in bootstraps] - # Limit bootstrap examples to those present in segmentation output - bootstraps = [b[np.isin(b, df.example)] for b in bootstraps] - df.set_index('example', drop=False, inplace=True) - example_counts = df.example.value_counts() - self._stat_table_bootstraps = [] - for b in bootstraps: - df_boot = df.loc[b] - # Renumber examples to handle the case of duplicated examples in - # the bootstrap: - df_boot['eid'] = tuple( - chain(*(repeat(i, example_counts.loc[e]) - for i, e in enumerate(b)))) - df_boot['uid'] = tuple(zip(df_boot.eid, df_boot.assignments)) - df_boot.set_index('uid', inplace=True) - self._stat_table_bootstraps.append(df_boot) - - df.set_index('uid', inplace=True) - self._stat_table = df - - def filter_trial(self, - pedge_thresh, - group_thresh_expansion, - containment_thresh, - min_area, - bootstrap=True, - return_stderr=False): - if bootstrap: - dfs = self.stat_table_bootstraps - truths = self.truth_bootstraps - else: - dfs = [self.stat_table] - truths = [self.truth.sum()] - - uidcols = ['eid', 'assignments'] - score_boots = [] - for df, nT in zip(dfs, truths): - rejects = ((df.containment > containment_thresh) | - (df.area < min_area)) - for t_pe, g_ex, (g, l, u, gr) in zip(pedge_thresh, - group_thresh_expansion, - self.group_info): - g_ex = g_ex * gr - l = max(l - g_ex, 1) - u = u + g_ex - rejects |= (df.group == g) & ((df.p_edge < t_pe) | - (df.area < l) | (df.area > u)) - TP_mask = (~rejects) & (df.max_IoU >= self.IoU_thresh) - - # # TODO compare speed of: - # TPs_IoU = df.loc[TP_mask].groupby(uidcols).max_IoU.max() - # # with speed of: - TPs_IoU = [] - current_eid = 0 - asgn = {} - for m, e, a, iou in zip(TP_mask, df.eid, df.assignments, - df.max_IoU): - if e != current_eid: - TPs_IoU.extend(asgn.values()) - current_eid = e - asgn = {} - if not m: - continue - asgn[a] = max(iou, asgn.get(a, 0)) - TPs_IoU.extend(asgn.values()) - # # END TODO - - nPs = np.sum(~rejects) - nTPs = len(TPs_IoU) - nFPs = nPs - nTPs - nFNs = nT - nTPs - precision = nTPs / (nTPs + nFPs) - recall = nTPs / (nTPs + nFNs) - # Fbeta = (1 + beta^2) * P * R / (beta^2 * P + R) - F1 = 2 * precision * recall / (precision + recall) - F0_5 = 1.25 * precision * recall / (0.25 * precision + recall) - F2 = 5 * precision * recall / (4 * precision + recall) - score_boots.append( - Score(precision, recall, F1, F0_5, F2, np.mean(TPs_IoU))) - - score_boots = np.array(score_boots) - mean_score = Score(*score_boots.mean(axis=0)) - if return_stderr: - stderr = score_boots.std(axis=0) / np.sqrt(score_boots.shape[0]) - return (mean_score, Score(*stderr)) - else: - return mean_score - - def fit_filter_params(self, lazy=False, bootstrap=False): - # Define parameter grid values, firstly those not specific to a group - params = { - ('containment_thresh', None): np.linspace(0, 1, 21), - ('min_area', None): np.arange(0, 20, 1) - } - - # Determine the pedge_threshold range based on the observed p_edge - # range for each group - if (self.stat_table.is_best.all() or not self.stat_table.is_best.any()): - t_pe_upper = self.stat_table.groupby('group').p_edge.mean() - else: - q_pe = self.stat_table.groupby(['group', 'is_best']) - q_pe = q_pe.p_edge.quantile([0.25, 0.95]).unstack((1, 2)) - t_pe_upper = q_pe.loc[:, [(False, 0.95), (True, 0.25)]].mean(1) - - t_pe_vals = [ - np.arange(0, u, round_to_n(u / 20, 1)) for u in t_pe_upper - ] - - # Set group-specific parameter grid values - g_ex_vals = repeat(np.linspace(0, 0.4, 21)) - for g, (t_pe, g_ex) in enumerate(zip(t_pe_vals, g_ex_vals)): - params[('pedge_thresh', g)] = t_pe - params[('group_thresh_expansion', g)] = g_ex - - # Default starting point is with thresholds off and no group expansion - ngroups = len(self.segrps) - dflt_params = { - 'containment_thresh': 0, - 'min_area': 0, - 'pedge_thresh': list(repeat(0, ngroups)), - 'group_thresh_expansion': list(repeat(0, ngroups)) - } - - # Search first along each parameter dimension with all others kept at - # default: - opt_params = {} - for k, pvals in params.items(): - scrs = [] - for v in pvals: - p = _sub_params({k: v}, dflt_params) - scr = self.filter_trial(**p, bootstrap=bootstrap) - scrs.append(getattr(scr, self.scoring)) - maxInd = np.argmax(scrs) - opt_params[k] = pvals[maxInd] - - # Reset the template parameters to the best along each dimension - base_params = _sub_params(opt_params, dflt_params) - - if lazy: - # Simply repeat search along each parameter dimension, but now - # using the new optimum as a starting point - opt_params = {} - for k, pvals in params.items(): - scrs = [] - for v in pvals: - p = _sub_params({k: v}, base_params) - scr = self.filter_trial(**p, bootstrap=bootstrap) - scrs.append(getattr(scr, self.scoring)) - maxInd = np.argmax(scrs) - opt_params[k] = pvals[maxInd] - opt_params = _sub_params(opt_params, base_params) - scr = self.filter_trial(**opt_params, bootstrap=bootstrap) - self._opt_params = opt_params - self._opt_score = getattr(scr, self.scoring) - return - - # Next perform a joint search for parameters with optimal pairings - opt_param_pairs = {k: {v} for k, v in opt_params.items()} - for k1, k2 in combinations(params.keys(), 2): - scrs = [(v1, v2, - getattr( - self.filter_trial(**_sub_params({ - k1: v1, - k2: v2 - }, base_params), - bootstrap=bootstrap), - self.scoring)) - for v1, v2 in product(params[k1], params[k2])] - p1opt, p2opt, _ = max(scrs, key=lambda x: x[2]) - opt_param_pairs[k1].add(p1opt) - opt_param_pairs[k2].add(p2opt) - - # Finally search over all combinations of the parameter values found - # with optimal pairings - scrs = [] - for pvals in product(*opt_param_pairs.values()): - p = {k: v for k, v in zip(opt_param_pairs.keys(), pvals)} - p = _sub_params(p, base_params) - scrs.append((p, - getattr(self.filter_trial(**p, bootstrap=bootstrap), - self.scoring))) - - self._opt_params, self._opt_score = max(scrs, key=lambda x: x[1]) - - -def _sub_params(sub, param_template): - p = { - k: v.copy() if type(v) == list else v - for k, v in param_template.items() - } - for (k, g), v in sub.items(): - if g is None: - p[k] = v - else: - p[k][g] = v - return p diff --git a/python/baby/segmentation.py b/python/baby/segmentation.py index 4b40d103914ca8899ee7c9bb00810084db33050e..633c3eb76c61b7d25eba68442248ef15748b905b 100644 --- a/python/baby/segmentation.py +++ b/python/baby/segmentation.py @@ -1,12 +1,14 @@ # If you publish results that make use of this software or the Birth Annotator # for Budding Yeast algorithm, please cite: -# Julian M J Pietsch, Alán Muñoz, Diane Adjavon, Ivan B N Clark, Peter S -# Swain, 2021, Birth Annotator for Budding Yeast (in preparation). +# Pietsch, J.M.J., Muñoz, A.F., Adjavon, D.-Y.A., Farquhar, I., Clark, I.B.N., +# and Swain, P.S. (2023). Determining growth rates from bright-field images of +# budding cells through identifying overlaps. eLife. 12:e79812. +# https://doi.org/10.7554/eLife.79812 # # # The MIT License (MIT) # -# Copyright (c) Julian Pietsch, Alán Muñoz and Diane Adjavon 2021 +# Copyright (c) Julian Pietsch, Alán Muñoz and Diane Adjavon 2023 # # Permission is hereby granted, free of charge, to any person obtaining a copy # of this software and associated documentation files (the "Software"), to @@ -25,39 +27,1024 @@ # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING # FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS # IN THE SOFTWARE. -from collections import Iterable -from itertools import chain, compress, repeat + import numpy as np -from numpy import newaxis as nax -from scipy.ndimage import (minimum_filter, binary_dilation, binary_erosion, - binary_closing, binary_opening, binary_fill_holes) +from itertools import chain +from typing import NamedTuple, Tuple +import inspect from scipy import interpolate from scipy.optimize import least_squares +from scipy.ndimage import (minimum_filter, binary_dilation, binary_erosion, + binary_closing, binary_opening, binary_fill_holes) +from skimage import filters from skimage.measure import label, regionprops -from skimage.segmentation import morphological_geodesic_active_contour from skimage.morphology import diamond, erosion, dilation -from skimage.draw import ellipse_perimeter -from skimage import filters -from .errors import BadParam +from .errors import BadParam + + +############################ +### UTILITY FUNCTIONS ### +############################ + + +squareconn = diamond(1) # 3x3 filter for 1-connected patches +fullconn = np.ones((3, 3), dtype='uint8') + + +def binary_edge(imfill, footprint=fullconn): + """Get square-connected edges from filled image:""" + return minimum_filter(imfill, footprint=footprint) != imfill + + +def iterative_erosion(img, iterations=1, **kwargs): + if iterations < 1: + return img + + for i in range(iterations): + img = erosion(img, **kwargs) + return img + + +def iterative_dilation(img, iterations=1, **kwargs): + if iterations is None: + return img + for _ in range(iterations): + img = dilation(img, **kwargs) + return img + + +def single_region_prop(mask): + return regionprops(mask.astype(np.uint8))[0] + + +def get_edge_scores(outlines, p_edge): + """Return probability scores for a list of outlines + + NB: BREAKING CHANGE. From July 2022 this has been corrected to take the + mean only over edge pixels rather than the entire image. Segmentation + parameters will need to be reoptimised. + + :param outlines: list of outline images (2D bool) + :param p_edge: edge probability image (2D float with values in [0, 1]) + + :return: list of edge probability scores (each in [0, 1]) + """ + return [ + p_edge[binary_dilation(o, iterations=2)].mean() for o in outlines + ] + + +def mask_iou(a, b): + """Intersection over union (IoU) between boolean masks""" + return np.sum(a & b) / np.sum(a | b) + + +def mask_containment(a, b): + """Max of intersection over a or over b""" + return np.max(np.sum(a & b) / np.array([np.sum(a), np.sum(b)])) + + +def bbox_overlaps(bboxA, bboxB): + """Returns True if the bounding boxes are overlapping. + + Args: + bboxA (tuple): A bounding box ``(min_row, min_col, max_row, max_col)`` + as defined by :py:func:`skimage.measure.regionprops`. + bboxB (tuple): A bounding box ``(min_row, min_col, max_row, max_col)`` + as defined by :py:func:`skimage.measure.regionprops`. + """ + lrA, lcA, urA, ucA = regA.bbox + lrB, lcB, urB, ucB = regB.bbox + rA = np.array([lrA, urA]) + cA = np.array([lcA, ucA]) + return ((not ((rA > urB).all() or (rA < lrB).all())) and + (not ((cA > ucB).all() or (cA < lcB).all()))) + + +def region_iou(regA, regB): + """Efficiently computes the IoU between two RegionProperties.""" + if bbox_overlaps(regA, regB): + bb_lr, bb_lc, _, _ = np.stack((regA.bbox, regB.bbox)).min(axis=0) + _, _, bb_ur, bb_uc = np.stack((regA.bbox, regB.bbox)).max(axis=0) + bboxA = np.zeros((bb_ur - bb_lr, bb_uc - bb_lc), dtype='bool') + bboxB = bboxA.copy() + bb = regA.bbox + bboxA[bb[0] - bb_lr:bb[2] - bb_lr, + bb[1] - bb_lc:bb[3] - bb_lc] = regA.image + bb = regB.bbox + bboxB[bb[0] - bb_lr:bb[2] - bb_lr, + bb[1] - bb_lc:bb[3] - bb_lc] = regB.image + return np.sum(bboxA & bboxB) / np.sum(bboxA | bboxB) + else: + return 0. + + +def limit_to_bbox(imglist, bbox, padding=10): + imH, imW = imglist[0].shape[:2] + assert all([i.shape[:2] == (imH, imW) for i in imglist]) + rmin, cmin, rmax, cmax = bbox + rmin = np.maximum(rmin - padding, 0) + cmin = np.maximum(cmin - padding, 0) + rmax = np.minimum(rmax + padding, imH) + cmax = np.minimum(cmax + padding, imW) + return (i[rmin:rmax, cmin:cmax] for i in imglist), (rmin, cmin, imH, imW) + + +def restore_from_bbox(imglist, bbunmap): + bbH, bbW = imglist[0].shape[:2] + assert all([i.shape[:2] == (bbH, bbW) for i in imglist]) + rmin, cmin, imH, imW = bbunmap + for img in imglist: + restored = np.zeros((imH, imW) + img.shape[2:], dtype=img.dtype) + restored[rmin:rmin + bbH, cmin:cmin + bbW] = img + yield restored + + +def threshold_segmentation(p_int, + interior_threshold=0.5, + connectivity=None, + nclosing=0, + nopening=0, + ndilate=0, + return_area=False): + """Generic threshold-based segmentation routine for eroded interiors. + + Finds cell mask instances from a thresholded probability image of cell + interiors (possibly trained on eroded targets). Each candidate mask is + then independently cleaned by binary closing and then binary opening. + Finally, masks are dilated to restore to an original size (i.e., to undo + the level of erosion applied to training target images). + + Args: + p_int: An ndarray specifying cell interior probability for each pixel. + interior_threshold: Threshold to apply to the probability image. + connectivity: Connectivity as defined by `skimage.measure.label`. + nclosing: Number of iterations in closing operation; if 0 then binary + closing is not applied. + nopening: Number of iterations in opening operation; if 0 then binary + opening is not applied. + ndilate: Number of iterations of binary dilation to apply at end. + return_area: If True, then function yields (mask, area) tuples. + + Yields: + Arrays with the same shape as `p_int` for each identified cell mask. If + `return_area=True`, then function yields (mask, area) tuples where the + area is the number of True pixels in the mask. + """ + + lbl, nmasks = label(p_int > interior_threshold, + return_num=True, + connectivity=connectivity) + for l in range(nmasks): + mask = lbl == l + 1 + if nclosing > 0: + mask = binary_closing(mask, iterations=nclosing) + if nopening > 0: + mask = binary_opening(mask, iterations=nopening) + if ndilate > 0: + mask = binary_dilation(mask, iterations=ndilate) + + if return_area: + yield mask, mask.sum() + else: + yield mask + + +def ordered_edge_points(mask, edge=None, border_rect=None): + """Returns edge coordinates ordered around the mask perimeter. + + Uses the tangent to the filled mask image to step pixel-wise around an + edge image and terminate when within a pixel length of the first pixel, or + when steps would be larger than 2 pixels to continue. Note that this + function does not return all edge points, just those pixels closest to a 1 + pixel step along the direction of the tangent. + + Args: + mask (ndarray): A 2D bitmask for a single cell. + edge (None or ndarray): To save recomputation, optionally provide the + edge image obtained from :py:func:`binary_edge`. + border_rect (None or ndarray): To save recomputation, optionally + provide an image of same shape as the mask with ``False`` values + except for a 1-pixel border of ``True`` values. + + Returns: + An ndarray of shape ``(N_edge_pixels, 2)`` giving row/column + coordinates for the ordered edge pixels. + """ + if edge is None: + edge = binary_edge(mask) + + if border_rect is None: + border_rect = np.pad( + np.zeros(tuple(x - 2 for x in mask.shape), dtype='bool'), + pad_width=1, mode='constant', constant_values=True) + + # Need to ensure an edge if mask is adjacent to the border + edge = binary_edge(mask) | (border_rect & mask) + + X, Y = np.nonzero(edge) + edgepts = np.c_[X,Y] + + # Use Sobel filter on Gaussian-blurred mask to estimate tangent along edge + mask_blur = filters.gaussian(mask, 1, mode='constant') + hgrad = filters.sobel_h(mask_blur) + vgrad = filters.sobel_v(mask_blur) + edge_hgrad = hgrad[edge] + edge_vgrad = vgrad[edge] + gradvec = np.c_[edge_hgrad, edge_vgrad] + normgradvec = gradvec / np.sqrt(np.sum(np.square(gradvec), axis=1)[:, None]) + rot90 = np.array([[0, 1], [-1, 0]]) + tngvec = np.matmul(normgradvec, rot90) + + # Loop over edge points in direction of tangent to get ordered list + unvisited = np.ones(edgepts.shape[0], dtype='bool') + i = 0 + start = edgepts[i] + ptorder = [] + for o in range(unvisited.size): + unvisited[i] = False + nextpt = edgepts[[i]] + tngvec[[i]] + if np.sum(np.square(nextpt - start)) <= 2 and i > 0: + # We have returned to the first pixel + break + nextpt_sqdist = np.sum(np.square(edgepts[unvisited]-nextpt), axis=1) + minInd = np.argmin(nextpt_sqdist) + if nextpt_sqdist[minInd] > 4: + # We are jumping further than just neighbouring pixels + break + i = np.flatnonzero(unvisited)[minInd] + ptorder.append(i) + + return edgepts[ptorder] + + +############################ +### RADIAL KNOT SPLINES ### +############################ + + +def rc_to_radial(rr_cc, centre): + """Converts row-column coordinates to radial coordinates. + + Can be used to directly convert (row, column) output from the + `numpy.nonzero` function applied to a binary edge image + ``rc_to_radial(np.nonzero(edge_image), (0, 0))`` + + Args: + rr_cc: A tuple ``(rr, cc)`` of ndarrays specifying row and column + indices to be converted. ``rr`` and ``cc`` should have the same + shape. + centre: A tuple (rr, cc) specifying the origin of the radial + coordinate system. + + Returns: + A tuple ``(rho, phi)`` of ndarrays of same shape as ``rr`` and ``cc`` + giving the radii ``rho`` and angles ``phi`` in the radial coordinate + system. + """ + rr, cc = rr_cc + rloc, cloc = centre + rr = rr - rloc + cc = cc - cloc + return np.sqrt(np.square(rr) + np.square(cc)), np.arctan2(cc, rr) + + +def eval_radial_spline(x, rho, phi): + """Evaluates a radial spline with knots defined in radial coordinates. + + The spline is periodic across the boundary. + + Args: + x: An ndarray of angles at which to evaluate the spline. + rho: An ndarray of radii in[0, Inf) for each knot of the spline. + phi: An ndarray of angles in [-pi, pi) for each knot of the spline. + + Returns: + An ndarray with same shape as x giving corresponding interpolated + radii. + """ + + # Angles need to be in increasing order for expected behaviour of phi as a + # parametric variable + order = np.argsort(phi) + rho = rho[order] + phi = phi[order] + offset = phi[0] + + # Make the boundaries periodic + rho = np.r_[rho, rho[0]] + phi = np.r_[phi - offset, 2 * np.pi] + + tck = interpolate.splrep(phi, rho, per=True) + return interpolate.splev(np.mod(x - offset, 2 * np.pi), tck) + + +def eval_cartesian_spline(x, rho, phi): + """Evaluates a cartesian spline with knots defined in radial coordinates. + + The spline interpolates between knots in order of appearance in the array + and is periodic over the end-points. Notably, the knots do not necessarily + need to be in order of increasing phi (though this will likely be the most + common case). + + Args: + x: An ndarray of 'angles' at which to evaluate the spline. These are + not true radial angles as used for `phi` when defining the knots, + but are rather parametric locations defined over [0, 2*pi). + rho: An ndarray of radii in[0, Inf) for each knot of the spline. + phi: An ndarray of angles in [-pi, pi) for each knot of the spline. + + Returns: + A tuple (Sx, Sy) of ndarrays with same shape as x, giving the + interpolated x and y coordinates corresponding to each parametric + location in x. + """ + + # Make the boundaries periodic + rho = np.r_[rho, rho[0]] + phi = np.r_[phi, phi[0]] + + # TODO check whether the spline behaves better if the lengths between each + # knot are used to scale relative distance in the parametric variable + + # Define splines according to a parametric variable over [0,2*pi) + t = np.linspace(0, 2*np.pi, rho.size) + Xtck = interpolate.splrep(t, rho*np.cos(phi), per=True) + Ytck = interpolate.splrep(t, rho*np.sin(phi), per=True) + + # Evaluate over a modded x + Sx = interpolate.splev(np.mod(x, 2 * np.pi), Xtck) + Sy = interpolate.splev(np.mod(x, 2 * np.pi), Ytck) + return Sx, Sy + + +def draw_radial(rho, phi, centre, shape, cartesian_spline=False): + """Renders a spline defined in radial coordinates as an image. + + By default, interpolates spline in radial space. + + Args: + rho: An ndarray of radii defining knot locations. + phi: An ndarray of angles defining knot locations. + centre: A sequence of length 2 specifying the row and column defining + the origin of the radial coordinate system. + shape: A sequence of length 2 specifying the height and width of the + image to be rendered to. + cartesian_spline: If True, then interpolate spline in cartesian space. + + Returns: + An ndarray with the specified shape and of dtype bool containing the + rendered spline. + """ + + mr, mc = shape + im = np.zeros(shape, dtype='bool') + + # Estimate required sampling density from lengths of piecewise linear segments + rho_loop = np.r_[rho, rho[0]] + phi_loop = np.r_[phi, phi[0]] + xy_loop = np.c_[rho_loop*np.cos(phi_loop), rho_loop*np.sin(phi_loop)] + linperim = np.sum(np.sqrt(np.sum(np.square(np.diff(xy_loop, axis=0)), axis=1))) + neval = np.round(2.5 * linperim).astype(int) + + if neval > 1: + x = np.linspace(0, 2 * np.pi, neval) + if cartesian_spline: + Sx, Sy = eval_cartesian_spline(x, rho, phi) + rr = np.round(centre[0] + Sx).astype(int) + cc = np.round(centre[1] + Sy).astype(int) + else: + R = eval_radial_spline(x, rho, phi) + rr = np.round(centre[0] + R * np.cos(x)).astype(int) + cc = np.round(centre[1] + R * np.sin(x)).astype(int) + rr[rr < 0] = 0 + rr[rr >= mr] = mr - 1 + cc[cc < 0] = 0 + cc[cc >= mc] = mc - 1 + else: + rr = np.round(centre[0]).astype(int) + cc = np.round(centre[1]).astype(int) + im[rr, cc] = True + return im + + +def _radii_from_outline(outline, centre, ray_angles, max_ray_len): + # Improve accuracy of edge position by smoothing the outline image and using + # weighted averaging of pixel positions below: + outline = filters.gaussian(outline, 0.5) + + ray_tmplt = 0.5 * np.arange(np.round(2 * max_ray_len))[:, None] + rr_max, cc_max = outline.shape + + radii = [] + for angle in ray_angles: + ray = np.matmul(ray_tmplt, np.array((np.cos(angle), np.sin(angle)))[None, :]) + ray = np.round(centre + ray).astype('int') + rr, cc = ray[:, 0], ray[:, 1] + ray = ray[(rr >= 0) & (rr < rr_max) & (cc >= 0) & (cc < cc_max), :] + + edge_pix = np.flatnonzero( + np.squeeze(outline[ray[:, 0], ray[:, 1]]) > 0.01) + + if len(edge_pix) == 0: + radii.append(np.NaN) + continue + + ray = ray[edge_pix, :] + edge_pix = np.average(ray, + weights=outline[ray[:, 0], ray[:, 1]], + axis=0) + radii.append(np.sqrt(np.sum((edge_pix - centre)**2))) + + return np.array(radii) + + +def guess_radial_edge(edge, mask=None, rprops=None, + nrays_thresh_map=[(5., 4), (20., 6)], nrays_max=8): + """Guesses knot placement for a radial spline by casting rays. + + Given an edge image, this function casts out rays from the (filled) + mask centroid and finds where they intersect with the edge. + + Args: + edge (ndarray): A 2D bitmask of the edge for a single cell. + mask (None or ndarray): To save recomputation, optionally provide a + mask image that is the filled version of the edge image. Not + required if `rprops` is specified. + rprops (None or RegionProps): To save recomputation, optionally + provide :py:func:`skimage.measure.regionprops` that have been + calculated on the mask image. + nrays_thresh_map (List[Tuple[float, int]]): An ordered list of tuples + ``(upper_threshold, n_rays)`` that give an upper threshold on + major axis length for which the specified number of rays + ``n_rays`` will be used in stead of ``nrays_max``. The first + satisfied threshold in the list will be used to select the number + of rays. + nrays_max (int): The number of rays that will be used if the major + axis length is larger than all ``upper_threshold`` values + specified in ``nrays_thresh_map``. + + Returns: + A tuple ``(rho, phi)`` of ndarrays giving knot locations in radial + coordinates (radii ``rho`` and angles ``phi``) with origin at mask + centroid as determined by :py:func:`skimage.measure.regionprops`. + """ + if mask is None and rprops is None: + mask = binary_fill_holes(edge) + + if rprops is None: + rprops = regionprops(mask.astype('int'))[0] + + r_maj = rprops.major_axis_length + nrays = nrays_max + for upper_thresh, n in nrays_thresh_map: + if r_maj < upper_thresh: + nrays = n + break + + RL, CL, RU, CU = rprops.bbox + bbH, bbW = RU - RL, CU - CL + bbdiag = np.sqrt(bbH * bbH + bbW * bbW) + + astep = 2 * np.pi / nrays + angles = (np.mod(rprops.orientation + np.pi, astep) + + np.arange(nrays)*astep - np.pi) + centre = np.array(rprops.centroid) + + radii = _radii_from_outline(edge, centre, angles, bbdiag) + + # Use linear interpolation for any missing radii (e.g., if region intersects + # with image boundary): + nanradii = np.isnan(radii) + if nanradii.all(): + radii = 0.1 * np.ones(angles.shape) + elif nanradii.any(): + radii = np.interp(angles, + angles[~nanradii], + radii[~nanradii], + period=2 * np.pi) + + return radii, angles + + +def guess_cartesian_edge(mask, p_edge, n_trials=10, pxperknot=5., + squishiness=1., alignedness=1., return_probs=False, + edge=None, rprops=None, border_rect=None): + """Guesses knot placement for a cartesian spline by ordering edge pixels. + + Uses :py:func:`ordered_edge_points` to obtain an ordered sequence of edge + pixels, then selects random subsets of these with roughly even + perimeter-wise spacing using a Dirichlet distribution. The random trial + with the highest probability as measured by ``p_edge`` is returned. + + The number of knots is always even and never fewer than four. The first + knot is biased to align with the major axis, but a normal distribution + allows for variation about this position (see ``alignedness`` argument). + The default parameters provide even sampling of knot positions over all + possible values. A small ``squishiness`` favours more regular spacing + between knots. A large ``alignedness`` favours alignment with the major + axis. + + Args: + mask (ndarray): A 2D bitmask of a single cell (dtype bool). + p_edge (ndarray): A 2D image of edge probabilities. + n_trials (int): Number of random subsets to test. + pxperknot (float): Intended spacing (in pixels) between knots. + squishiness (float): Scaling factor for variance of the Dirichlet + distribution. + alignedness (float): Scaling factor for standard deviation of position + of the first knot. + return_probs (bool): Specify ``True`` to include probabilities of the + best trial, the edge image (determined by binary_edge) and the + image of all ordered edge pixels in the return value. + edge (None or ndarray): If a binary edge image has already been + calculated from the mask it can optionally be provided here to + save recomputation. + rprops (None or RegionProps): If + :py:func:`skimage.measure.regionprops` have already been + calculated for the mask, they can optionally be provided here to + save recomputation. + border_rect (None or ndarray): To save recomputation, optionally + provide an image of same shape as the mask with ``False`` values + except for a 1-pixel border of ``True`` values. + + Returns: + A tuple ``(centre, rho, phi)`` of ndarrays specifying knot placement in + radial coordinates for the best trial. The radial coordinate system is + centred at ``centre`` (row/column format) and the knots have radii + given by ``rho`` and angles given by ``phi``. If + ``return_probs=True``, then a tuple ``(centre, rho, phi, trial_prob, + edge_prob, ordered_edge_prob)`` is returned to additionally give the + probability of the best trial, the edge image (as determined by + :py:func:`binary_edge`), and the image of ordered edge pixels (as + determined by :py:func:`ordered_edge_points`). + """ + if edge is None: + edge = binary_edge(mask) + + # Define knots in radial coordinates at mask centroid + if rprops is None: + rprops = regionprops(mask.astype(np.uint8))[0] + ctr = np.r_[rprops.centroid] + + # Get ordered list of edge points in radial coordinates + edgepts = ordered_edge_points(mask, edge=edge, border_rect=border_rect) + edgepts_xy = edgepts - ctr[None, :] + edgepts_rho = np.sqrt(np.sum(np.square(edgepts_xy), axis=1)) + # Need to flip X and Y here to match row/column format + edgepts_phi = np.arctan2(edgepts_xy[:, 1], edgepts_xy[:, 0]) + + # Candidate ellipse origins should differ from ellipsoid orientation by + # at most 1 arc pixel at half maximum radius + # <arc length> = <angle in radians> * <radius> + phi_tol = 2. / np.max(edgepts_rho) # allowed tolerance + ori_dist = np.abs(edgepts_phi - rprops.orientation) + cand_ori = np.flatnonzero(ori_dist < phi_tol) + if len(cand_ori) > 0: + # Candidates are then filtered to those within one pixel of maximum + # radius from initial set of candidates + cand_ori = cand_ori[np.max(edgepts_rho[cand_ori]) - edgepts_rho[cand_ori] < 1] + ellipse_ori = cand_ori[np.argmin(ori_dist[cand_ori])] + else: + # Just pick closest to ellipse orientation + ellipse_ori = np.argmin(ori_dist) + + # Randomly select edge points as knots over multiple trials + # Keep the trial with the highest edge probability + N_pts = edgepts_rho.size + frac_ori = ellipse_ori / (N_pts - 1.) + # Choose an even number of knots, no less than 4 + N_knots = np.maximum(2 * np.round(N_pts / 2 / pxperknot).astype('int'), 4) + cand_rho_phi = [] + cand_probs = [] + for _ in range(n_trials): + # Split snake into roughly equal segments using Dirichlet distribution + # We want the variance to scale as 1 / N_knots + # Variance is further scaled by squishiness + # Alignment with major ellipse axis is controlled by variance in + # additive normal noise + # defaults of squishiness=1 and alignedness=1 produce even sampling of + # knots over [0, N_pts) + alpha = N_knots * np.ones(N_knots) / squishiness + sec_len = np.random.dirichlet(alpha) + frac_ind = np.cumsum(sec_len) + frac_ind += np.random.normal(loc=frac_ori, scale=1 / N_knots / alignedness) + frac_ind = np.sort(np.mod(frac_ind, 1.)) + inds = np.floor(N_pts * frac_ind).astype('int') + + # Get the knot positions in radial coordinates + knots_rho = edgepts_rho[inds] + knots_phi = edgepts_phi[inds] + + splIm = draw_radial(knots_rho, knots_phi, ctr, p_edge.shape, + cartesian_spline=True) + probs = p_edge[splIm].mean() + cand_probs.append(probs) + cand_rho_phi.append((ctr, knots_rho, knots_phi)) + + indMax = np.argmax(cand_probs) + if return_probs: + raw_edge_prob = p_edge[edge].mean() + ord_edge_prob = p_edge[edgepts[:, 0], edgepts[:, 1]].mean() + return cand_rho_phi[indMax] + (cand_probs[indMax], raw_edge_prob, ord_edge_prob) + else: + return cand_rho_phi[indMax] + + +def inds_max_curvature(R, N_pergrp, N_elim): + """Helper function for curvy_knots_from_outline. + + Args: + R: row vector of radii or matrix with rows respectively giving x and y + coordinates. + N_pergrp: group knots into groups of this size. + N_elim: eliminate this many knots from each group. + + Returns: + Indices for the knots that had the highest curvature in each group. + """ + curvature = np.diff(np.diff(R[:, np.r_[-1, 0:R.shape[1], 0]], axis=1), axis=1) + curvature = np.sqrt(np.sum(curvature**2, axis=0)) + knorder = np.argsort(np.reshape(curvature, (-1, N_pergrp)), axis=1) + rowinds = np.reshape(np.arange(knorder.size), (-1, N_pergrp)) + rowinds = rowinds[np.tile(np.arange(knorder.shape[0])[:, None], (1, N_pergrp)), knorder] + return np.sort(rowinds[:, N_elim:].flatten()) + + +def curvy_knots_from_outline(outline, rprops, cartesian_spline=False, + n_knots_fraction=0.5): + """Places knots on an outline image at points of high curvature. + + The method essentially progressively eliminates knots with low curvature + from an initial dense equi-angular array of knots located on the outline. + Curvature is calculated as the difference of the difference between + consecutive locations (radii for a radial spline; the norm of x and y for + a cartesian spline). + + Args: + outline (ndarray): A 2D bitmask for a single cell edge. + rprops (RegionProperties): The output of + :py:func:`skimage.measure.regionprops` for the filled outline. + cartesian_spline (bool): Specify ``True`` to model curvature in + cartesian rather than radial coordinates. + n_knots_fraction (float): A value in ``[0, 1]`` that determines the + number of knots as a fraction of the number of edge pixels. + + Returns: + A (radii, angles) tuple specifying the locations of the curvy knots in + radial coordinates from the `rprops` centroid + """ + Nedgepx = outline.sum() + # Initial dense number of rays is one ray for every four edge pixels, + # rounded to the nearest power of 2: + # Nrays_dense = 2**(np.round(np.log(max(Nedgepx, 16)) / np.log(2)).astype(int) - 2) + # Initial dense number of rays is one ray for every eight edge pixels, + # rounded to the nearest power of 2: + Nrays_dense = 2**(np.round(np.log(max(Nedgepx, 32)) / np.log(2)).astype(int) - 3) + + # The target number of rays rounded to the nearest four: + Nrays_final = 4 * max(1, np.round(0.25 * Nedgepx * n_knots_fraction).astype(int)) + + # Rays can be no larger than the diagonal of the bounding box + RL, CL, RU, CU = rprops.bbox + bbdiag = np.sqrt((RU - RL)**2 + (CU - CL)**2) + + # Determine radii for initial dense array + astep = 2 * np.pi / Nrays_dense + angles = np.mod(rprops.orientation + np.pi, astep) + \ + np.arange(Nrays_dense)*astep - np.pi + # Roughly compensate for elliptical squeezing by converting parameterised + # ellipse parameter to true angle. See: + # https://math.stackexchange.com/a/436125 + axmaj, axmin = rprops.major_axis_length, rprops.minor_axis_length + angles -= np.arctan((axmaj - axmin) * np.tan(angles) / (axmaj + axmin * np.tan(angles)**2)) + + centre = np.array(rprops.centroid) + radii = _radii_from_outline(outline, centre, angles, bbdiag) + + # Linearly interpolate any missing radii (e.g., if region intersects with + # image boundary): + nanradii = np.isnan(radii) + if nanradii.all(): + radii = 0.1 * np.ones(angles.shape) + elif nanradii.any(): + radii = np.interp(angles, + angles[~nanradii], + radii[~nanradii], + period=2 * np.pi) + + if cartesian_spline: + R = radii * np.vstack((np.cos(angles), np.sin(angles))) + else: + R = radii[None, :] + + # Progressively eliminate knots four at a time until the desired number of + # knots is reached + inds = np.arange(R.shape[1]) + for N_pergrp in np.r_[Nrays_dense // 4 : Nrays_final // 4 : -1]: + inds = inds[inds_max_curvature(R[:, inds], N_pergrp, 1)] + # Rotate indices over periodic boundary to mix groups + inds = inds[np.r_[-1,0:len(inds)-1]] + inds = np.sort(inds) + + return radii[inds], angles[inds] + + +class FakeRegionProperties(NamedTuple): + centroid: Tuple[int, int] + orientation: float + bbox: Tuple[int, int, int, int] + major_axis_length: float + minor_axis_length: float + + +GUESS_RADIAL_EDGE_PARAMS = { + p.name for p in + inspect.signature(guess_radial_edge).parameters.values() + if p.kind == p.POSITIONAL_OR_KEYWORD and p.default != p.empty +} + + +GUESS_CARTESIAN_EDGE_PARAMS = { + p.name for p in + inspect.signature(guess_cartesian_edge).parameters.values() + if p.kind == p.POSITIONAL_OR_KEYWORD and p.default != p.empty +} + + +CURVY_KNOTS_FROM_OUTLINE_PARAMS = { + p.name for p in + inspect.signature(curvy_knots_from_outline).parameters.values() + if p.kind == p.POSITIONAL_OR_KEYWORD and p.default != p.empty +} + + +def mask_to_knots(mask, + p_edge=None, + return_outline=True, + cartesian_spline=False, + curvy_knots=False, + bbox_padding=10, + **kwargs): + """Guess knot positions from a cell bitmask. + + Essentially a unified interface to the different methods for initial knot + placement before refinement. By default, it finds knots using + :py:func:`guess_radial_edge`. If ``cartesian_spline=True`` and + ``curvy_knots=False``, then knots are placed according to + :py:func:`guess_cartesian_edge`. If ``curvy_knots=True``, then knots are + found using :py:func:`curvy_knots_from_outline`. + + For increased performance, processing is limited to a padded bounding box + of the mask image. Any return values are, however, restored to the input + size/coordinates. + + Args: + mask (ndarray): A 2D bitmask for a single cell. + p_edge(ndarray): A 2D edge probability image with same shape as mask. + Currently only required when ``cartesian_spline=True`` and + ``curvy_knots=False``. + return_outline (bool): Specify ``False`` if you only want knot + coordinates returned. + cartesian_spline (bool): Specify ``True`` to determine knots for + splines interpolated in cartesian rather than radial coordinates. + curvy_knots (bool): Specify ``True`` to place knots at points of high + curvature according to :py:func:`curvy_knots_from_outline`. + bbox_padding (int): The number of pixels of padding that should be + included when limiting processing to the mask bounding box. + **kwargs: Additional arguments to be passed to any of the downstream + functions (:py:func:`guess_radial_edge`, + :py:func:`guess_cartesian_edge`, or + :py:func:`curvy_knots_from_outline`). + + Returns: + A tuple ``(knot_coordinates, edge_image)`` of knot coordinates + ``(centre, knot_rho, knot_phi)`` defined in a radial coordinate system + with origin ``centre`` (a size two ndarray giving row/column + location), and ndarrays ``knot_rho`` and ``knot_phi`` giving the + radius and polar angle respectively for each knot. If + ``return_outline=False``, then only ``knot_coordinates`` are returned. + """ + + rprops = single_region_prop(mask) + if cartesian_spline and not curvy_knots: + (mask_bb, p_edge_bb), bbunmap = limit_to_bbox( + (mask, p_edge), rprops.bbox) + else: + (mask_bb,), bbunmap = limit_to_bbox((mask,), rprops.bbox) + edge_bb = binary_edge(mask_bb) + + ctr = np.r_[rprops.centroid] + bboffset = np.r_[bbunmap[:2]] + limited_bb = np.r_[rprops.bbox] - np.tile(bboffset, 2) + rprops_bb = FakeRegionProperties( + centroid=ctr - bboffset, + orientation=rprops.orientation, + bbox=tuple(limited_bb.tolist()), + major_axis_length=rprops.major_axis_length, + minor_axis_length=rprops.minor_axis_length) + + if curvy_knots: + params = { + k: v for k, v in kwargs.items() + if k in CURVY_KNOTS_FROM_OUTLINE_PARAMS + } + k_rho, k_phi = curvy_knots_from_outline( + edge_bb, rprops_bb, cartesian_spline=cartesian_spline, **params) + elif cartesian_spline: + params = { + k: v for k, v in kwargs.items() + if k in GUESS_CARTESIAN_EDGE_PARAMS + } + _, k_rho, k_phi = guess_cartesian_edge( + mask_bb, p_edge_bb, edge=edge_bb, rprops=rprops_bb, + return_probs=False, **params) + else: + params = { + k: v for k, v in kwargs.items() if k in GUESS_RADIAL_EDGE_PARAMS + } + k_rho, k_phi = guess_radial_edge(edge_bb, rprops=rprops_bb, **params) + + if return_outline: + edge = draw_radial(k_rho, k_phi, ctr, mask.shape, + cartesian_spline=cartesian_spline) + return (ctr, k_rho, k_phi), edge + else: + return ctr, k_rho, k_phi + + +############################ +### OUTLINE REFINEMENT ### +############################ + + +def prior_resid_weight(resid, gauss_scale=5, exp_scale=1): + """Weights radial residuals to bias towards the initial guess + + Weight decays as a gaussian for positive residuals and exponentially for + negative residuals. So assuming `resid = rho_guess - rho_initial`, then + larger radii are favoured. + """ + W = np.zeros(resid.shape) + W[resid >= 0] = np.exp(-resid[resid >= 0]**2 / gauss_scale) + W[resid < 0] = np.exp(resid[resid < 0] / exp_scale) + return W + + +def adj_rspline_coords(adj, ref_radii, ref_angles): + """Map optimisation-space radial spline params to standard values + + Params in optimisation space are specified relative to reference radii and + reference angles. If constrained to [-1, 1], optimisation parameters will + allow a 30% change in radius, or change in angle up to 1/4 the distance + between consecutive angles. + """ + npoints = len(ref_radii) + return ( + # allow up to 30% change in radius + ref_radii * (1 + 0.3 * adj[:npoints]), + # allow changes in angle up to 1/4 the distance between points + ref_angles + adj[npoints:] * np.pi / (2 * npoints)) + + +def adj_rspline_resid(adj, rho, phi, probs, ref_radii, ref_angles): + """Weighted residual for radial spline optimisation + + Optimisation params (`adj`) are mapped according to `adj_rspline_coords`. + Target points are given in radial coordinates `rho` and `phi` with weights + `probs`. Optimisation is defined relative to radial spline params + `ref_radii` and `ref_angles`. + """ + radii, angles = adj_rspline_coords(adj, ref_radii, ref_angles) + return probs * (rho - eval_radial_spline(phi, radii, angles)) + + +def adj_cart_spline_resid(adj, rho, phi, probs, ref_radii, ref_angles): + """Weighted residual for cartesian spline optimisation in radial + coordinate system + + Optimisation params (`adj`) are mapped according to `adj_cart_spline_coords`. + Target points are given in radial coordinates `rho` and `phi` with weights + `probs`. Optimisation is defined relative to radial coordinates + `ref_radii` and `ref_angles`. + """ + radii, angles = adj_rspline_coords(adj, ref_radii, ref_angles) + return probs * (rho - eval_cartesian_spline(phi, radii, angles)) + + +def refine_radial_grouped(grouped_coords, grouped_p_edges, cartesian_spline=False): + """Refine initial radial spline by optimising to predicted edge + + Neighbouring groups are used to re-weight predicted edges belonging to + other cells using the initial guess + """ + + if cartesian_spline: + eval_spline = eval_cartesian_spline + adj_resid = adj_cart_spline_resid + else: + eval_spline = eval_radial_spline + adj_resid = adj_rspline_resid + + # Determine edge pixel locations and probabilities from NN prediction + p_edge_locs = [np.nonzero(p_edge > 0.2) for p_edge in grouped_p_edges] + p_edge_probs = [ + p_edge[rr, cc] + for p_edge, (rr, cc) in zip(grouped_p_edges, p_edge_locs) + ] + + p_edge_count = [len(rr) for rr, _ in p_edge_locs] + + opt_coords = [] + ngroups = len(grouped_coords) + for g, g_coords in enumerate(grouped_coords): + # If this group has no predicted edges, keep initial and skip + if p_edge_count[g] == 0: + opt_coords.append(g_coords) + continue + + # Compile a list of all cells in this and neighbouring groups + nbhd = list( + chain.from_iterable([ + [((gi, ci), coords) + for ci, coords in enumerate(grouped_coords[gi])] + for gi in range(max(g - 1, 0), min(g + 2, ngroups)) + if + p_edge_count[gi] > 0 # only keep if there are predicted edges + ])) + if len(nbhd) > 0: + nbhd_ids, nbhd_coords = zip(*nbhd) + else: + nbhd_ids, nbhd_coords = 2 * [[]] + + # Calculate edge pixels in radial coords for all cells in this and + # neighbouring groups: + radial_edges = [ + rc_to_radial(p_edge_locs[g], centre) + for centre, _, _ in nbhd_coords + ] + + # Calculate initial residuals and prior weights + resids = [ + rho - eval_spline(phi, radii, angles) + for (rho, phi), (_, radii, + angles) in zip(radial_edges, nbhd_coords) + ] + indep_weights = [prior_resid_weight(r) for r in resids] + + probs = p_edge_probs[g] + + g_opt_coords = [] + for c, (centre, radii, angles) in enumerate(g_coords): + ind = nbhd_ids.index((g, c)) + rho, phi = radial_edges[ind] + p_weighted = probs * indep_weights[ind] + other_weights = indep_weights[:ind] + indep_weights[ind + 1:] + if len(other_weights) > 0: + p_weighted *= (1 - np.mean(other_weights, axis=0)) + + # Remove insignificant fit data + signif = p_weighted > 0.1 + if signif.sum() < 10: + # With insufficient data, skip optimisation + g_opt_coords.append((centre, radii, angles)) + continue + p_weighted = p_weighted[signif] + phi = phi[signif] + rho = rho[signif] + + nparams = len(radii) + len(angles) + opt = least_squares(adj_resid, + np.zeros(nparams), + bounds=(-np.ones(nparams), np.ones(nparams)), + args=(rho, phi, p_weighted, radii, angles), + ftol=5e-2) -squareconn = diamond(1) # 3x3 filter for 1-connected patches -fullconn = np.ones((3, 3), dtype='uint8') + g_opt_coords.append((centre,) + + adj_rspline_coords(opt.x, radii, angles)) + opt_coords.append(g_opt_coords) -def binary_edge(imfill, footprint=fullconn): - """Get square-connected edges from filled image:""" - return minimum_filter(imfill, footprint=footprint) != imfill + return opt_coords -def mask_iou(a, b): - """Intersection over union (IoU) between boolean masks""" - return np.sum(a & b) / np.sum(a | b) +############################ +### DEPRECATED FUNCTIONS ### +############################ -def mask_containment(a, b): - """Max of intersection over a or over b""" - return np.max(np.sum(a & b) / np.array([np.sum(a), np.sum(b)])) +def get_regions(p_img, threshold): + """Find regions in a probability image sorted by likelihood""" + p_thresh = p_img > threshold + p_label = label(p_thresh, background=0) + rprops = regionprops(p_label, p_img) + rprops = [ + r for r in rprops + if r.major_axis_length > 0 and r.minor_axis_length > 0 + ] + rprops.sort(key=lambda x: x.mean_intensity, reverse=True) + return rprops def morph_thresh_masks(p_interior, @@ -115,173 +1102,20 @@ def unique_masks(masks, ref_masks, threshold=0.5, iou_func=mask_iou): return umasks -def morph_thresh_seg(cnn_outputs, - interior_threshold=0.9, - overlap_threshold=0.9, - bud_threshold=0.9, - bud_dilate=False, - bud_overlap=False, - isbud_threshold=0.5): - """Segment cell outlines from morphology output of CNN by thresholding - - Specify `overlap_threshold` or `bud_threshold` as `None` to ignore. - """ - - _, _, p_interior, p_overlap, _, p_bud = cnn_outputs - - if overlap_threshold is None: - p_overlap = None - - if isbud_threshold is None and bud_threshold is not None: - p_interior = p_interior * (1 - p_bud) - - masks = morph_thresh_masks(p_interior, - interior_threshold=interior_threshold, - p_overlap=p_overlap, - overlap_threshold=overlap_threshold) - - if bud_threshold is not None: - if not bud_overlap: - p_overlap = None - - budmasks = morph_thresh_masks(p_bud, - interior_threshold=bud_threshold, - dilate=False, - p_overlap=p_overlap, - overlap_threshold=overlap_threshold) - - if isbud_threshold is not None: - # Omit interior masks if they overlap with bud masks - masks = unique_masks(masks, - budmasks, - iou_func=mask_containment, - threshold=isbud_threshold) + budmasks - - # Return only the mask outlines - outlines = [minimum_filter(m, footprint=squareconn) != m for m in masks] - - return outlines - - -def get_regions(p_img, threshold): - """Find regions in a probability image sorted by likelihood""" - p_thresh = p_img > threshold - p_label = label(p_thresh, background=0) - rprops = regionprops(p_label, p_img) - rprops = [ - r for r in rprops - if r.major_axis_length > 0 and r.minor_axis_length > 0 - ] - rprops.sort(key=lambda x: x.mean_intensity, reverse=True) - return rprops - - -def bbox_overlaps(regA, regB): - """Returns True if the regions have overlapping bounding boxes""" - lrA, lcA, urA, ucA = regA.bbox - lrB, lcB, urB, ucB = regB.bbox - rA = np.array([lrA, urA]) - cA = np.array([lcA, ucA]) - return ((not ((rA > urB).all() or (rA < lrB).all())) and - (not ((cA > ucB).all() or (cA < lcB).all()))) - - -def region_iou(regA, regB): - if bbox_overlaps(regA, regB): - bb_lr, bb_lc, _, _ = np.stack((regA.bbox, regB.bbox)).min(axis=0) - _, _, bb_ur, bb_uc = np.stack((regA.bbox, regB.bbox)).max(axis=0) - bboxA = np.zeros((bb_ur - bb_lr, bb_uc - bb_lc), dtype='bool') - bboxB = bboxA.copy() - bb = regA.bbox - bboxA[bb[0] - bb_lr:bb[2] - bb_lr, - bb[1] - bb_lc:bb[3] - bb_lc] = regA.image - bb = regB.bbox - bboxB[bb[0] - bb_lr:bb[2] - bb_lr, - bb[1] - bb_lc:bb[3] - bb_lc] = regB.image - return np.sum(bboxA & bboxB) / np.sum(bboxA | bboxB) - else: - return 0.0 - - -def morph_ellipse_seg(cnn_outputs, - interior_threshold=0.9, - overlap_threshold=0.9, - bud_threshold=0.9, - bud_dilate=False, - bud_overlap=False, - isbud_threshold=0.5, - scaling=1.0, - offset=0): - """Segment cell outlines from morphology output of CNN as region ellipses - - Specify `overlap_threshold` or `bud_threshold` as `None` to ignore. - """ - - _, _, p_interior, p_overlap, _, p_bud = cnn_outputs - - if overlap_threshold is None: - p_overlap = None - - masks = morph_thresh_masks(p_interior, - interior_threshold=interior_threshold, - p_overlap=p_overlap, - overlap_threshold=overlap_threshold) - - if bud_threshold is not None: - if not bud_overlap: - p_overlap = None - - budmasks = morph_thresh_masks(p_bud, - interior_threshold=bud_threshold, - dilate=False, - p_overlap=p_overlap, - overlap_threshold=overlap_threshold) - - # Omit interior masks if they overlap with bud masks - masks = budmasks + unique_masks( - masks, budmasks, threshold=isbud_threshold) - - rprops = [regionprops(m.astype('int'))[0] for m in masks] - rprops = [ - r for r in rprops - if r.major_axis_length > 0 and r.minor_axis_length > 0 - ] - - outlines = [] - for region in rprops: - r, c = np.round(region.centroid).astype('int') - r_major = np.round(scaling * region.major_axis_length / 2 + - offset).astype('int') - r_minor = np.round(scaling * region.minor_axis_length / 2 + - offset).astype('int') - orientation = -region.orientation - rr, cc = ellipse_perimeter(r, - c, - r_major, - r_minor, - orientation=orientation, - shape=p_interior.shape) - outline = np.zeros(p_interior.shape, dtype='bool') - outline[rr, cc] = True - outlines.append(outline) - - return outlines - - def get_edge_force(rprop, shape): r_major = rprop.major_axis_length / 2 r_minor = rprop.minor_axis_length / 2 angle = -rprop.orientation nr = shape[0] nc = shape[1] - xmat = np.matmul(np.arange(0, nr)[:, nax], np.ones((1, nc))) - ymat = np.matmul(np.ones((nr, 1)), np.arange(0, nc)[nax, :]) + xmat = np.matmul(np.arange(0, nr)[:, None], np.ones((1, nc))) + ymat = np.matmul(np.ones((nr, 1)), np.arange(0, nc)[None, :]) xy = np.vstack([np.reshape(xmat, (1, -1)), np.reshape(ymat, (1, -1))]) rotmat = np.array([[np.cos(angle), -np.sin(angle)], [np.sin(angle), np.cos(angle)]]) radial_index = np.matmul( - rotmat, (xy - np.array(rprop.centroid)[:, nax])) / np.array( - [r_major, r_minor])[:, nax] + rotmat, (xy - np.array(rprop.centroid)[:, None])) / np.array( + [r_major, r_minor])[:, None] return np.reshape(1 - np.exp(-np.sum((radial_index)**2, 0)), (nr, nc)) @@ -337,6 +1171,7 @@ def morph_ac_seg(cnn_outputs, edge_thresh = p_edge > ac_edge_threshold over_thresh = p_over > ac_overlap_threshold + from skimage.segmentation import morphological_geodesic_active_contour for j, (force, mask) in enumerate(zip(e_forces, masks)): ij_edge_im = p_edge.copy() @@ -365,116 +1200,119 @@ def morph_ac_seg(cnn_outputs, return outlines -def eval_radial_spline(x, rho, phi): - """Evaluate a radial spline defined by radii and angles - rho: vector of radii for each point defining the spline - phi: angles in [-pi,pi) defining points of the spline - The spline is periodic across the boundary +def morph_ellipse_seg(cnn_outputs, + interior_threshold=0.9, + overlap_threshold=0.9, + bud_threshold=0.9, + bud_dilate=False, + bud_overlap=False, + isbud_threshold=0.5, + scaling=1.0, + offset=0): + """Segment cell outlines from morphology output of CNN as region ellipses + + Specify `overlap_threshold` or `bud_threshold` as `None` to ignore. """ - # Angles need to be in increasing order to correctly loop over boundary - order = np.argsort(phi) - rho = rho[order] - phi = phi[order] - offset = phi[0] + _, _, p_interior, p_overlap, _, p_bud = cnn_outputs - # Make the boundaries periodic - rho = np.concatenate((rho, rho[0, nax])) - phi = np.concatenate((phi - offset, (2 * np.pi,))) + if overlap_threshold is None: + p_overlap = None - tck = interpolate.splrep(phi, rho, per=True) - try: - return interpolate.splev(np.mod(x - offset, 2 * np.pi), tck) - except ValueError as err: - print('x:') - print(x) - print('rho:') - print(rho) - print('phi:') - print(phi) - raise err - - -def morph_radial_thresh_fit(outline, mask=None, rprops=None): - if mask is None and rprops is None: - mask = binary_fill_holes(outline) + masks = morph_thresh_masks(p_interior, + interior_threshold=interior_threshold, + p_overlap=p_overlap, + overlap_threshold=overlap_threshold) - if rprops is None: - rprops = regionprops(mask.astype('int'))[0] + if bud_threshold is not None: + if not bud_overlap: + p_overlap = None - r_maj = rprops.major_axis_length - nrays = 4 if r_maj < 5 else 6 if r_maj < 20 else 8 + budmasks = morph_thresh_masks(p_bud, + interior_threshold=bud_threshold, + dilate=False, + p_overlap=p_overlap, + overlap_threshold=overlap_threshold) - RL, CL, RU, CU = rprops.bbox - bbdiag = np.sqrt((RU - RL)**2 + (CU - CL)**2) - rr_max, cc_max = outline.shape + # Omit interior masks if they overlap with bud masks + masks = budmasks + unique_masks( + masks, budmasks, threshold=isbud_threshold) + + rprops = [regionprops(m.astype('int'))[0] for m in masks] + rprops = [ + r for r in rprops + if r.major_axis_length > 0 and r.minor_axis_length > 0 + ] + + outlines = [] + + from skimage.draw import ellipse_perimeter + for region in rprops: + r, c = np.round(region.centroid).astype('int') + r_major = np.round(scaling * region.major_axis_length / 2 + + offset).astype('int') + r_minor = np.round(scaling * region.minor_axis_length / 2 + + offset).astype('int') + orientation = -region.orientation + rr, cc = ellipse_perimeter(r, + c, + r_major, + r_minor, + orientation=orientation, + shape=p_interior.shape) + outline = np.zeros(p_interior.shape, dtype='bool') + outline[rr, cc] = True + outlines.append(outline) + + return outlines - astep = 2 * np.pi / nrays - angles = np.mod(rprops.orientation + np.pi, astep) + \ - np.arange(nrays)*astep - np.pi - centre = np.array(rprops.centroid) - # Improve accuracy of edge position by smoothing the outline image and using - # weighted averaging of pixel positions below: - outline = filters.gaussian(outline, 0.5) +def morph_thresh_seg(cnn_outputs, + interior_threshold=0.9, + overlap_threshold=0.9, + bud_threshold=0.9, + bud_dilate=False, + bud_overlap=False, + isbud_threshold=0.5): + """Segment cell outlines from morphology output of CNN by thresholding - radii = [] - for angle in angles: - ray = np.matmul(0.5 * np.arange(np.round(2 * bbdiag))[:, nax], - np.array((np.cos(angle), np.sin(angle)))[nax, :]) - ray = np.round(centre + ray).astype('int') - rr, cc = (ray[:, 0], ray[:, 1]) - ray = ray[(rr >= 0) & (rr < rr_max) & (cc >= 0) & (cc < cc_max), :] + Specify `overlap_threshold` or `bud_threshold` as `None` to ignore. + """ - edge_pix = np.flatnonzero( - np.squeeze(outline[ray[:, 0], ray[:, 1]]) > 0.01) + _, _, p_interior, p_overlap, _, p_bud = cnn_outputs - if len(edge_pix) == 0: - radii.append(np.NaN) - continue + if overlap_threshold is None: + p_overlap = None - ray = ray[edge_pix, :] - edge_pix = np.average(ray, - weights=outline[ray[:, 0], ray[:, 1]], - axis=0) - radii.append(np.sqrt(np.sum((edge_pix - centre)**2))) + if isbud_threshold is None and bud_threshold is not None: + p_interior = p_interior * (1 - p_bud) - radii = np.array(radii) + masks = morph_thresh_masks(p_interior, + interior_threshold=interior_threshold, + p_overlap=p_overlap, + overlap_threshold=overlap_threshold) - # Use linear interpolation for any missing radii (e.g., if region intersects - # with image boundary): - nanradii = np.isnan(radii) - if nanradii.all(): - radii = 0.1 * np.ones(angles.shape) - elif nanradii.any(): - radii = np.interp(angles, - angles[~nanradii], - radii[~nanradii], - period=2 * np.pi) + if bud_threshold is not None: + if not bud_overlap: + p_overlap = None - return radii, angles + budmasks = morph_thresh_masks(p_bud, + interior_threshold=bud_threshold, + dilate=False, + p_overlap=p_overlap, + overlap_threshold=overlap_threshold) + if isbud_threshold is not None: + # Omit interior masks if they overlap with bud masks + masks = unique_masks(masks, + budmasks, + iou_func=mask_containment, + threshold=isbud_threshold) + budmasks -def draw_radial(radii, angles, centre, shape): - mr, mc = shape - im = np.zeros(shape, dtype='bool') - neval = np.round(4 * np.pi * np.max(radii)).astype('int') - if neval > 1: - phi = np.linspace(0, 2 * np.pi, neval) - rho = eval_radial_spline(phi, radii, angles) - else: - phi = 0 - rho = 0 - rr = np.round(centre[0] + rho * np.cos(phi)).astype('int') - cc = np.round(centre[1] + rho * np.sin(phi)).astype('int') - rr[rr < 0] = 0 - rr[rr >= mr] = mr - 1 - cc[cc < 0] = 0 - cc[cc >= mc] = mc - 1 - im[rr, cc] = True - # valid = (rr >= 0) & (cc >= 0) & (rr < mr) & (cc < mc) - # im[rr[valid], cc[valid]] = True - return im + # Return only the mask outlines + outlines = [minimum_filter(m, footprint=squareconn) != m for m in masks] + + return outlines def morph_radial_thresh_seg(cnn_outputs, @@ -532,7 +1370,7 @@ def morph_radial_thresh_seg(cnn_outputs, outlines = [] for mask, outline, rp in zip(masks, mseg, rprops): - radii, angles = morph_radial_thresh_fit(outline, mask, rp) + radii, angles = guess_radial_edge(outline, mask, rp) if np.any(np.isnan(radii)): return mask, outline, rp outlines.append(draw_radial(radii, angles, rp.centroid, shape)) @@ -540,72 +1378,6 @@ def morph_radial_thresh_seg(cnn_outputs, return outlines -def thresh_seg(p_int, - interior_threshold=0.5, - connectivity=None, - nclosing=0, - nopening=0, - ndilate=0, - return_area=False): - """Segment cell outlines from morphology output of CNN by fitting radial - spline to threshold output - """ - - lbl, nmasks = label(p_int > interior_threshold, - return_num=True, - connectivity=connectivity) - for l in range(nmasks): - mask = lbl == l + 1 - if nclosing > 0: - mask = binary_closing(mask, iterations=nclosing) - if nopening > 0: - mask = binary_opening(mask, iterations=nopening) - if ndilate > 0: - mask = binary_dilation(mask, iterations=ndilate) - - if return_area: - yield mask, mask.sum() - else: - yield mask - - -def single_region_prop(mask): - return regionprops(mask.astype('int'))[0] - - -def outline_to_radial(outline, rprop, return_outline=False): - coords = (rprop.centroid,) + morph_radial_thresh_fit(outline, None, rprop) - if return_outline: - centroid, radii, angles = coords - outlines = draw_radial(radii, angles, centroid, outline.shape) - return coords, outlines - else: - return coords - - -def get_edge_scores(outlines, p_edge): - return [ - (p_edge * binary_dilation(o, iterations=2)).mean() for o in outlines - ] - - -def iterative_erosion(img, iterations=1, **kwargs): - if iterations < 1: - return img - - for i in range(iterations): - img = erosion(img, **kwargs) - return img - - -def iterative_dilation(img, iterations=1, **kwargs): - if iterations is None: - return img - for _ in range(iterations): - img = dilation(img, **kwargs) - return img - - def morph_seg_grouped(pred, flattener, cellgroups=['large', 'medium', 'small'], @@ -626,6 +1398,8 @@ def morph_seg_grouped(pred, return_coords=False): """Morphological segmentation for model predictions of flattener targets + DEPRECATED. Use morph_thres_seg.MorphSegGrouped class. + :param pred: list of prediction images (ndarray with shape (x, y)) matching `flattener.names()` :param flattener: an instance of `SegmentationFlattening` defining the @@ -737,13 +1511,10 @@ def morph_seg_grouped(pred, masks_areas = [ (m, a) - for m, a in thresh_seg(p_int, - interior_threshold=thresh or 0.5, - nclosing=nc or 0, - nopening=no or 0, - ndilate=max_ne, - return_area=True, - connectivity=conn) + for m, a in threshold_segmentation( + p_int, interior_threshold=thresh or 0.5, + nclosing=nc or 0, nopening=no or 0, + ndilate=max_ne, return_area=True, connectivity=conn) if a >= lower and a < upper ] @@ -752,17 +1523,15 @@ def morph_seg_grouped(pred, else: masks, areas = [], [] - edges = [binary_edge(m) for m in masks] - if fit_radial: - rprops = [single_region_prop(m) for m in masks] coords, edges = list( zip(*[ - outline_to_radial(edge, rprop, return_outline=True) - for edge, rprop in zip(edges, rprops) + mask_to_knots(mask) + for mask in masks ])) or ([], []) masks = [binary_fill_holes(o) for o in edges] else: + edges = [binary_edge(m) for m in masks] edges = [e | (border_rect & m) for e, m in zip(edges, masks)] coords = [tuple()] * len(masks) @@ -847,144 +1616,6 @@ def morph_seg_grouped(pred, return edges -def rc_to_radial(rr_cc, centre): - """Helper function to convert row-column coords to radial coords - """ - rr, cc = rr_cc - rloc, cloc = centre - rr = rr - rloc - cc = cc - cloc - return np.sqrt(rr**2 + cc**2), np.arctan2(cc, rr) - - -def prior_resid_weight(resid, gauss_scale=5, exp_scale=1): - """Weights radial residuals to bias towards the initial guess - - Weight decays as a gaussian for positive residuals and exponentially for - negative residuals. So assuming `resid = rho_guess - rho_initial`, then - larger radii are favoured. - """ - return (resid >= 0) * np.exp(-resid**2 / gauss_scale) \ - + (resid < 0) * np.exp(resid / exp_scale) - - -def adj_rspline_coords(adj, ref_radii, ref_angles): - """Map optimisation-space radial spline params to standard values - - Params in optimisation space are specified relative to reference radii and - reference angles. If constrained to [-1, 1], optimisation parameters will - allow a 30% change in radius, or change in angle up to 1/4 the distance - between consecutive angles. - """ - npoints = len(ref_radii) - return ( - # allow up to 30% change in radius - ref_radii * (1 + 0.3 * adj[:npoints]), - # allow changes in angle up to 1/4 the distance between points - ref_angles + adj[npoints:] * np.pi / (2 * npoints)) - - -def adj_rspline_resid(adj, rho, phi, probs, ref_radii, ref_angles): - """Weighted residual for radial spline optimisation - - Optimisation params (`adj`) are mapped according to `adj_rspline_coords`. - Target points are given in radial coordinates `rho` and `phi` with weights - `probs`. Optimisation is defined relative to radial spline params - `ref_radii` and `ref_angles`. - """ - radii, angles = adj_rspline_coords(adj, ref_radii, ref_angles) - return probs * (rho - eval_radial_spline(phi, radii, angles)) - - -def refine_radial_grouped(grouped_coords, grouped_p_edges): - """Refine initial radial spline by optimising to predicted edge - - Neighbouring groups are used to re-weight predicted edges belonging to - other cells using the initial guess - """ - - # Determine edge pixel locations and probabilities from NN prediction - p_edge_locs = [np.where(p_edge > 0.2) for p_edge in grouped_p_edges] - p_edge_probs = [ - p_edge[rr, cc] - for p_edge, (rr, cc) in zip(grouped_p_edges, p_edge_locs) - ] - - p_edge_count = [len(rr) for rr, _ in p_edge_locs] - - opt_coords = [] - ngroups = len(grouped_coords) - for g, g_coords in enumerate(grouped_coords): - # If this group has no predicted edges, keep initial and skip - if p_edge_count[g] == 0: - opt_coords.append(g_coords) - continue - - # Compile a list of all cells in this and neighbouring groups - nbhd = list( - chain.from_iterable([ - [((gi, ci), coords) - for ci, coords in enumerate(grouped_coords[gi])] - for gi in range(max(g - 1, 0), min(g + 2, ngroups)) - if - p_edge_count[gi] > 0 # only keep if there are predicted edges - ])) - if len(nbhd) > 0: - nbhd_ids, nbhd_coords = zip(*nbhd) - else: - nbhd_ids, nbhd_coords = 2 * [[]] - - # Calculate edge pixels in radial coords for all cells in this and - # neighbouring groups: - radial_edges = [ - rc_to_radial(p_edge_locs[g], centre) - for centre, _, _ in nbhd_coords - ] - - # Calculate initial residuals and prior weights - resids = [ - rho - eval_radial_spline(phi, radii, angles) - for (rho, phi), (_, radii, - angles) in zip(radial_edges, nbhd_coords) - ] - indep_weights = [prior_resid_weight(r) for r in resids] - - probs = p_edge_probs[g] - - g_opt_coords = [] - for c, (centre, radii, angles) in enumerate(g_coords): - ind = nbhd_ids.index((g, c)) - rho, phi = radial_edges[ind] - p_weighted = probs * indep_weights[ind] - other_weights = indep_weights[:ind] + indep_weights[ind + 1:] - if len(other_weights) > 0: - p_weighted *= (1 - np.mean(other_weights, axis=0)) - - # Remove insignificant fit data - signif = p_weighted > 0.1 - if signif.sum() < 10: - # With insufficient data, skip optimisation - g_opt_coords.append((centre, radii, angles)) - continue - p_weighted = p_weighted[signif] - phi = phi[signif] - rho = rho[signif] - - nparams = len(radii) + len(angles) - opt = least_squares(adj_rspline_resid, - np.zeros(nparams), - bounds=(-np.ones(nparams), np.ones(nparams)), - args=(rho, phi, p_weighted, radii, angles), - ftol=5e-2) - - g_opt_coords.append((centre,) + - adj_rspline_coords(opt.x, radii, angles)) - - opt_coords.append(g_opt_coords) - - return opt_coords - - def morph_radial_edge_seg(cnn_outputs): RL, CL, RU, CU = rp.bbox Rext = np.ceil(0.25 * (RU - RL)).astype('int') diff --git a/python/baby/server.py b/python/baby/server.py index 3b717ecf53db94084a8cccf8e6d37e0b31155c51..740c87bdaf17d946ca016420559f24f42298c521 100644 --- a/python/baby/server.py +++ b/python/baby/server.py @@ -2,13 +2,15 @@ # If you publish results that make use of this software or the Birth Annotator # for Budding Yeast algorithm, please cite: -# Julian M J Pietsch, Alán Muñoz, Diane Adjavon, Ivan B N Clark, Peter S -# Swain, 2021, Birth Annotator for Budding Yeast (in preparation). +# Pietsch, J.M.J., Muñoz, A.F., Adjavon, D.-Y.A., Farquhar, I., Clark, I.B.N., +# and Swain, P.S. (2023). Determining growth rates from bright-field images of +# budding cells through identifying overlaps. eLife. 12:e79812. +# https://doi.org/10.7554/eLife.79812 # # # The MIT License (MIT) # -# Copyright (c) Julian Pietsch, Alán Muñoz and Diane Adjavon 2021 +# Copyright (c) Julian Pietsch, Alán Muñoz and Diane Adjavon 2023 # # Permission is hereby granted, free of charge, to any person obtaining a copy # of this software and associated documentation files (the "Software"), to @@ -45,6 +47,8 @@ from functools import reduce from operator import mul import numpy as np +from baby import modelsets +from baby import __version__ from baby.brain import BabyBrain from baby.crawler import BabyCrawler from baby.utils import jsonify @@ -61,7 +65,7 @@ SERVER_DIR = dirname(__file__) MAX_RUNNERS = 3 MAX_SESSIONS = 20 SLEEP_TIME = 0.2 # time between threaded checks for data availability -MAX_ATTEMPTS = 300 # allows for 60s delay before timing out +MAX_ATTEMPTS = 3000 # allows for 10m delay before timing out MAX_IMG_SIZE = 100 * 1024 * 1024 # allows for raw image sizes up to 100 MB DIMS_ERROR_MSG = '"dims" must be a length 4 integer array: [ntraps, width, height, depth]' @@ -90,7 +94,7 @@ class PredMissingError(Exception): class TaskMaster(object): - def __init__(self): + def __init__(self, njobs=-2): self._lock = threading.Lock() self._runner_pool = [] @@ -100,12 +104,21 @@ class TaskMaster(object): self.tf_session = None self.tf_graph = None self.tf_version = (0, 0, 0) + self.njobs = njobs + self._modelsets_ids = None + self._modelsets_meta = None @property - def modelsets(self): - with open(join(SERVER_DIR, 'modelsets.json'), 'rt') as f: - modelsets = json.load(f) - return modelsets + def modelset_ids(self): + if self._modelsets_ids is None: + self._modelsets_ids = modelsets.ids() + return self._modelsets_ids + + @property + def modelsets_meta(self): + if self._modelsets_meta is None: + self._modelsets_meta = modelsets.meta() + return self._modelsets_meta def new_session(self, model_name): # Clean up old sessions that exceed the maximum allowed number @@ -125,16 +138,11 @@ class TaskMaster(object): return sessionid - def ensure_runner(self, model_name, modelsets=None): + def ensure_runner(self, model_name): if model_name in self.runners: print('Model "{}" already loaded. Skipping...'.format(model_name)) return - if modelsets is None: - modelsets = self.modelsets - - assert model_name in modelsets - # Clean up old runners that exceed the maximum allowed number nrunners = len(self._runner_pool) with self._lock: @@ -166,9 +174,9 @@ class TaskMaster(object): # Load BabyRunner print('Starting new runner for model "{}"...'.format(model_name)) - baby = BabyBrain(**modelsets[model_name], - session=self.tf_session, graph=self.tf_graph, - suppress_errors=True, error_dump_dir=ERR_DUMP_DIR) + baby = modelsets.get(model_name, session=self.tf_session, + graph=self.tf_graph, suppress_errors=True, + error_dump_dir=ERR_DUMP_DIR) if self.runners.get(model_name) == 'pending': with self._lock: @@ -214,7 +222,7 @@ class TaskMaster(object): # tf.keras.backend.set_session(self.tf_session) t_start = time.perf_counter() - pred = crawler.step(img, parallel=True, **kwargs) + pred = crawler.step(img, parallel=True, njobs=self.njobs, **kwargs) t_elapsed = time.perf_counter() - t_start print('...images segmented in {:.3f} seconds.'.format(t_elapsed)) @@ -243,28 +251,33 @@ class TaskMaster(object): @routes.get('/') async def version(request): - return web.json_response({'baby': 'v1.0'}) + return web.json_response({'baby': __version__}) @routes.get('/models') async def get_modelsets(request): taskmstr = request.app['TaskMaster'] - return web.json_response(list(taskmstr.modelsets.keys())) + with_meta = False + if 'meta' in request.query: + with_meta = request.query['meta'] == 'true' + if with_meta: + return web.json_response(taskmstr.modelsets_meta) + else: + return web.json_response(taskmstr.modelset_ids) @routes.get('/session/{model}') async def get_session(request): model_name = request.match_info['model'] taskmstr = request.app['TaskMaster'] - modelsets = taskmstr.modelsets - if model_name not in modelsets: + if model_name not in taskmstr.modelset_ids: raise web.HTTPNotFound(text='"{}" model is unknown'.format(model_name)) # Ensure model is loaded in another thread loop = asyncio.get_event_loop() loop.run_in_executor(request.app['Executor'], - taskmstr.ensure_runner, model_name, modelsets) + taskmstr.ensure_runner, model_name) sessionid = taskmstr.new_session(model_name) print('Creating new session "{}" with model "{}"...'.format( @@ -383,13 +396,6 @@ async def segment(request): val = await field.read(decode=True) kwargs[field.name] = json.loads(val) - if request.query.get('test', False): - print('Data received. Writing test image to "baby-server-test.png"...') - from imageio import imwrite - imwrite(join(SERVER_DIR, 'baby-server-test.png'), - np.squeeze(img[0, :, :, 0])) - return web.json_response({'status': 'test image written'}) - print('Data received. Segmenting {} images...'.format(len(img))) loop.run_in_executor(executor, taskmstr.segment, sessionid, img, kwargs) @@ -441,12 +447,21 @@ async def get_segmentation(request): app = web.Application() app.add_routes(routes) -app['TaskMaster'] = TaskMaster() -app['Executor'] = ThreadPoolExecutor(2) def main(): + from argparse import ArgumentParser + parser = ArgumentParser('Start BABY server for receiving segmentation requests') + parser.add_argument('-p', '--port', type=int, default=5101, + help='port to bind the server to') + parser.add_argument('-n', '--njobs', type=int, default=-2, + help='number of parallel jobs to run when processing') + args = parser.parse_args() + import tensorflow as tf + app['TaskMaster'] = TaskMaster(njobs=args.njobs) + app['Executor'] = ThreadPoolExecutor(2) + tf_version = tuple(int(v) for v in tf.version.VERSION.split('.')) app['TaskMaster'].tf_version = tf_version @@ -486,4 +501,4 @@ def main(): lfh.setFormatter(lff) logging.getLogger().addHandler(lfh) - web.run_app(app, port=5101) + web.run_app(app, port=args.port) diff --git a/python/baby/speed_tests.py b/python/baby/speed_tests.py index e4adf6910eeced9d4726a1b5905b0ab9feee1f62..2256e64b078cdb6c5ccf729ca84e57b7f5c15689 100644 --- a/python/baby/speed_tests.py +++ b/python/baby/speed_tests.py @@ -2,13 +2,15 @@ # If you publish results that make use of this software or the Birth Annotator # for Budding Yeast algorithm, please cite: -# Julian M J Pietsch, Alán Muñoz, Diane Adjavon, Ivan B N Clark, Peter S -# Swain, 2021, Birth Annotator for Budding Yeast (in preparation). +# Pietsch, J.M.J., Muñoz, A.F., Adjavon, D.-Y.A., Farquhar, I., Clark, I.B.N., +# and Swain, P.S. (2023). Determining growth rates from bright-field images of +# budding cells through identifying overlaps. eLife. 12:e79812. +# https://doi.org/10.7554/eLife.79812 # # # The MIT License (MIT) # -# Copyright (c) Julian Pietsch, Alán Muñoz and Diane Adjavon 2021 +# Copyright (c) Julian Pietsch, Alán Muñoz and Diane Adjavon 2023 # # Permission is hereby granted, free of charge, to any person obtaining a copy # of this software and associated documentation files (the "Software"), to diff --git a/python/baby/tracker/benchmark.py b/python/baby/tracker/benchmark.py index a701e34eac5f0ce6c1fc05717e1c6244da950fcb..c2a6e25df6698c4c0f136e36789747c0e0054b10 100644 --- a/python/baby/tracker/benchmark.py +++ b/python/baby/tracker/benchmark.py @@ -1,12 +1,14 @@ # If you publish results that make use of this software or the Birth Annotator # for Budding Yeast algorithm, please cite: -# Julian M J Pietsch, Alán Muñoz, Diane Adjavon, Ivan B N Clark, Peter S -# Swain, 2021, Birth Annotator for Budding Yeast (in preparation). +# Pietsch, J.M.J., Muñoz, A.F., Adjavon, D.-Y.A., Farquhar, I., Clark, I.B.N., +# and Swain, P.S. (2023). Determining growth rates from bright-field images of +# budding cells through identifying overlaps. eLife. 12:e79812. +# https://doi.org/10.7554/eLife.79812 # # # The MIT License (MIT) # -# Copyright (c) Julian Pietsch, Alán Muñoz and Diane Adjavon 2021 +# Copyright (c) Julian Pietsch, Alán Muñoz and Diane Adjavon 2023 # # Permission is hereby granted, free of charge, to any person obtaining a copy # of this software and associated documentation files (the "Software"), to diff --git a/python/baby/tracker/core.py b/python/baby/tracker/core.py index 50f629982f62e497ea543a254ceb12ea9e25334c..4402057589613e085fe553edbff3520a76623d9b 100644 --- a/python/baby/tracker/core.py +++ b/python/baby/tracker/core.py @@ -1,14 +1,14 @@ -#!/usr/bin/env python - # If you publish results that make use of this software or the Birth Annotator # for Budding Yeast algorithm, please cite: -# Julian M J Pietsch, Alán Muñoz, Diane Adjavon, Ivan B N Clark, Peter S -# Swain, 2021, Birth Annotator for Budding Yeast (in preparation). +# Pietsch, J.M.J., Muñoz, A.F., Adjavon, D.-Y.A., Farquhar, I., Clark, I.B.N., +# and Swain, P.S. (2023). Determining growth rates from bright-field images of +# budding cells through identifying overlaps. eLife. 12:e79812. +# https://doi.org/10.7554/eLife.79812 # # # The MIT License (MIT) # -# Copyright (c) Julian Pietsch, Alán Muñoz and Diane Adjavon 2021 +# Copyright (c) Julian Pietsch, Alán Muñoz and Diane Adjavon 2023 # # Permission is hereby granted, free of charge, to any person obtaining a copy # of this software and associated documentation files (the "Software"), to @@ -27,6 +27,8 @@ # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING # FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS # IN THE SOFTWARE. +# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS +# IN THE SOFTWARE. ''' TrackerCoordinator class to coordinate cell tracking and bud assignment. @@ -43,8 +45,11 @@ from sklearn.ensemble import RandomForestClassifier from sklearn.svm import SVC from baby.errors import BadOutput from baby.tracker.utils import calc_barycentre, pick_baryfun +from baby import modelsets + + +DEFAULT_MODELSET = 'yeast-alcatras-brightfield-EMCCD-60x-5z' -models_path = join(dirname(__file__), '../models') class FeatureCalculator: ''' @@ -94,12 +99,15 @@ class FeatureCalculator: if 'area' in self.outfeats: self.aind = self.outfeats.index('area') - def load_model(self, path, fname): + def load_model(self, fname, path=None): + if path is not None: model_file = join(path, fname) - with open(model_file, 'rb') as file_to_load: - model = pickle.load(file_to_load) + else: + model_file = fname + with open(model_file, 'rb') as file_to_load: + model = pickle.load(file_to_load) - return model + return model def calc_feats_from_mask(self, masks, feats2use=None, trapfeats=None, norm=True, px_size=None): @@ -216,6 +224,7 @@ class FeatureCalculator: return feats + class CellTracker(FeatureCalculator): ''' Class used to manage cell tracking. You can call it using an existing model or @@ -264,24 +273,26 @@ class CellTracker(FeatureCalculator): if extra_feats is None: extra_feats = () - if type(model) is str or type(model) is PosixPath: - with open(Path(model), 'rb') as f: - model = pickle.load(f) + if isinstance(model, (Path, str)): + model = self.load_model(model) - if type(bak_model) is str or type(bak_model) is PosixPath: - with open(Path(bak_model), 'rb') as f: - bak_model = pickle.load(f) + if isinstance(bak_model, (Path, str)): + bak_model = self.load_model(bak_model) if aweights is None: self.aweights = None if feats2use is None: # Ignore this block when training if model is None: - model = self.load_model( models_path, - 'ct_rf_20210201_12.pkl') + default_params = modelsets.get_params(DEFAULT_MODELSET) + model_file = default_params['celltrack_model_file'] + model_file = modelsets.resolve(model_file, DEFAULT_MODELSET) + model = self.load_model(model_file) if bak_model is None: - bak_model = self.load_model( models_path, - 'ct_rf_20210125_9.pkl') + default_params = modelsets.get_params(DEFAULT_MODELSET) + bak_file = default_params['celltrack_backup_model_file'] + bak_file = modelsets.resolve(bak_file, DEFAULT_MODELSET) + bak_model = self.load_model(bak_file) self.model = model self.bak_model = bak_model @@ -542,17 +553,20 @@ class CellTracker(FeatureCalculator): return ([], [], max_lbl) return (new_lbls, new_feats, new_max) + class BudTracker(FeatureCalculator): def __init__(self, model=None, feats2use=None, **kwargs): - if model is None: - model_file = join(models_path, - 'mb_model_20201022.pkl') - with open(model_file, 'rb') as file_to_load: - model = pickle.load(file_to_load) + if isinstance(model, (Path, str)): + model = self.load_model(model) + elif model is None: + default_params = modelsets.get_params(DEFAULT_MODELSET) + model_file = default_params['budassign_model_file'] + model_file = modelsets.resolve(model_file, DEFAULT_MODELSET) + model = self.load_model(model_file) self.model = model if feats2use is None: @@ -718,6 +732,7 @@ class BudTracker(FeatureCalculator): return r_points + class MasterTracker(FeatureCalculator): ''' Coordinates the data transmission from CellTracker to BudTracker to @@ -842,6 +857,8 @@ class MasterTracker(FeatureCalculator): p_budneck, p_bud, masks, feats[:, self.bt_idx]) lblinds = np.array(new_lbls) - 1 # new_lbls are indexed from 1 lifetime[lblinds] += 1 + # TODO: the following may lead to values of p_is_mother higher + # than one, and hence negative ba_probs below p_is_mother[lblinds] = np.maximum(p_is_mother[lblinds], ba_probs.sum(1)) p_was_bud[lblinds] = np.maximum(p_was_bud[lblinds], @@ -875,6 +892,9 @@ class MasterTracker(FeatureCalculator): if assign_mothers: if max_lbl > 0: # Calculate mother assignments for this time point + # TODO: this should proceed more like the IoU assignment + # algorithm to guarantee that two buds will not be assigned to + # the same mother ma = ba_cum[0:max_lbl, 0:max_lbl].argmax(0) + 1 # Cell must have been a bud and been present for at least # min_bud_tps @@ -895,6 +915,163 @@ class MasterTracker(FeatureCalculator): return output +class MMTracker(FeatureCalculator): + ''' + Tracker specialised for mother machines. Does not use CellTracker or + BudTracker, but a custom algorithm adapted from Sean Murray's group (Robin + Koehler's code) + + input + :ctrack_args: dict with arguments to pass on to CellTracker constructor + if None it passes all the features to use + :btrack_args: dict with arguments to pass on to BudTracker constructor + if None it passes all the features to use + :**kwargs: additional arguments passed to FeatureCalculator constructor + ''' + def __init__(self, growthrate=0.1, tol_imbalance=20, nstepsback=1, **kwargs): + self.nstepsback = nstepsback + self.growthrate = growthrate + self.tol_imbalance = tol_imbalance + feats2use = ('centroid', 'major_axis_length', 'minor_axis_length') + super().__init__(feats2use, **kwargs) + + def step_trackers(self, + masks, + p_budneck=None, + p_bud=None, + state=None, + assign_mothers=False, + return_baprobs=False, + keep_full_state=False): + ''' + Calculate features and track cells and budassignments + + For compatibility with the existing MasterTracker, we retain some + unused input arguments. + + :masks: 3d ndarray (ncells, size_x, size_y) containing cell masks + :p_budneck: NOT REQUIRED + :p_bud: NOT REQUIRED + :state: running state for the tracker, or None for initialisation + :assign_mothers: whether to include mother assignments in the returned + returns + :return_baprobs: whether to include bud assignment probability matrix + in the returned output + + returns a dict consisting of + + :cell_label: list of int, the tracked global ID for each cell mask + :state: the updated state to be used in a subsequent step + :mother_assign: (optional) list of int, specifying the assigned mother + for each cell + :p_bud_assign: (optional) matrix (list of lists of floats), bud assignment + probability matrix from `predict_mother_bud` + ''' + + gr = np.exp(self.growthrate) + tol = self.tol_imbalance + + if state is None: + state = {} + + new_lbl = state.get('new_lbl', 1) + cell_lbls = state.get('cell_lbls', [[]]) + prev_feats = state.get('prev_feats', [np.zeros((0, self.ntfeats))]) + mothers = state.get('mothers', []) + + # Get features for cells at this time point + feats = self.calc_feats_from_mask(masks, norm=False) + + N_prev = prev_feats[-1].shape[0] + N_this = feats.shape[0] + + prev_lbls = np.array(cell_lbls[-1]) + this_lbls = -np.ones(N_this, dtype='int64') + Ci = self.xind + cInd_prev = np.argsort(-prev_feats[-1][:,Ci]) + cInd_this = np.argsort(-feats[:,Ci]) + Wi = self.outfeats.index('major_axis_length') + Ws_prev = prev_feats[-1][:,Wi][cInd_prev] + Ws_this = feats[:,Wi][cInd_this] + c_prev, c_this = 0, 0 + aggW_prev, aggW_this = 0, 0 + aggC_prev, aggC_this = [], [] + + while True: + if aggW_prev < aggW_this: + if c_prev >= N_prev: + break + aggW_prev += Ws_prev[c_prev] + aggC_prev.append(cInd_prev[c_prev]) + c_prev += 1 + else: + if c_this >= N_this: + break + aggW_this += Ws_this[c_this] + aggC_this.append(cInd_this[c_this]) + c_this += 1 + + balance = aggW_prev * gr - aggW_this + + # if balance is found + if balance >= -tol and balance <= tol and aggW_prev != 0 and aggW_this != 0: + if len(aggC_prev) == 1 and len(aggC_this) == 1: + this_lbls[aggC_this[0]] = prev_lbls[aggC_prev[0]] + elif len(aggC_prev) == 1 and len(aggC_this) == 2: + this_lbls[aggC_this[0]] = prev_lbls[aggC_prev[0]] + this_lbls[aggC_this[1]] = new_lbl + mothers.append((new_lbl, prev_lbls[aggC_prev[0]])) + new_lbl += 1 + else: + # In any other ambiguous cases, mark cells simply as new tracks + for c in aggC_this: + this_lbls[c] = new_lbl + new_lbl += 1 + + aggW_prev, aggW_this = 0, 0 + aggC_prev, aggC_this = [], [] + + # Mark any remaining unbalanced/extra cells as new tracks + for c in aggC_this + cInd_this[list(range(c_this, N_this))].tolist(): + this_lbls[c] = new_lbl + new_lbl += 1 + + this_lbls = this_lbls.tolist() + + if not keep_full_state: + cell_lbls = cell_lbls[-self.nstepsback:] + prev_feats = prev_feats[-self.nstepsback:] + + # Finally update the state + state = { + 'new_lbl': new_lbl, + 'cell_lbls': cell_lbls + [this_lbls], + 'prev_feats': prev_feats + [feats], + 'mothers': mothers + } + + output = { + 'cell_label': this_lbls, + 'state': state + } + + if assign_mothers: + ma = np.zeros(new_lbl - 1, dtype='int64') + for d, m in mothers: + ma[d - 1] = m + + if np.any(ma == np.arange(1, len(ma) + 1)): + raise BadOutput('Daughter has been assigned as mother to itself') + + output['mother_assign'] = ma.tolist() + + if return_baprobs: + ba_probs = np.zeros((N_this, N_this)) + output['p_bud_assign'] = ba_probs.tolist() + + return output + + # Helper functions def switch_case_nfeats(nfeats): diff --git a/python/baby/tracker/training.py b/python/baby/tracker/training.py index 904a2e99a3f540e5147a99222fc59062c5f0de35..ddea123d8b94e4127e270cac3907a5d806e2c3aa 100644 --- a/python/baby/tracker/training.py +++ b/python/baby/tracker/training.py @@ -2,13 +2,15 @@ # If you publish results that make use of this software or the Birth Annotator # for Budding Yeast algorithm, please cite: -# Julian M J Pietsch, Alán Muñoz, Diane Adjavon, Ivan B N Clark, Peter S -# Swain, 2021, Birth Annotator for Budding Yeast (in preparation). +# Pietsch, J.M.J., Muñoz, A.F., Adjavon, D.-Y.A., Farquhar, I., Clark, I.B.N., +# and Swain, P.S. (2023). Determining growth rates from bright-field images of +# budding cells through identifying overlaps. eLife. 12:e79812. +# https://doi.org/10.7554/eLife.79812 # # # The MIT License (MIT) # -# Copyright (c) Julian Pietsch, Alán Muñoz and Diane Adjavon 2021 +# Copyright (c) Julian Pietsch, Alán Muñoz and Diane Adjavon 2023 # # Permission is hereby granted, free of charge, to any person obtaining a copy # of this software and associated documentation files (the "Software"), to @@ -42,7 +44,7 @@ from baby.io import load_tiled_image from baby.tracker.utils import pick_baryfun, calc_barycentre from baby.training.utils import TrainValProperty, TrainValTestProperty from baby.errors import BadProcess, BadParam -from .core import CellTracker, BudTracker +from .core import CellTracker from .benchmark import CellBenchmarker @@ -55,8 +57,14 @@ from sklearn.svm import SVC#, LinearSVC from sklearn.model_selection import GridSearchCV from sklearn.metrics import ( make_scorer, fbeta_score, accuracy_score, balanced_accuracy_score, - precision_score, recall_score, f1_score, plot_precision_recall_curve + precision_score, recall_score, f1_score ) +import sklearn +if int(sklearn.__version__[0]) > 0: + from sklearn.metrics import PrecisionRecallDisplay + plot_precision_recall_curve = PrecisionRecallDisplay.from_estimator +else: + from sklearn.metrics import plot_precision_recall_curve class CellTrainer(CellTracker): ''' @@ -252,7 +260,7 @@ class CellTrainer(CellTracker): def explore_hyperparams(self, model_type = 'rf'): self.model_type = model_type truth, data = *zip(*self.train), - if self.model_type is 'SVC': + if self.model_type == 'SVC': model = SVC(probability = True, shrinking=False, verbose=True, random_state=1) param_grid = { @@ -263,7 +271,7 @@ class CellTrainer(CellTracker): 'gamma': [1, 0.01, 0.0001], 'kernel': ['rbf', 'sigmoid'] } - elif model_type is 'rf': + elif model_type == 'rf': model = RandomForestClassifier(n_estimators=15, criterion='gini', max_depth=3, @@ -319,208 +327,6 @@ class CellTrainer(CellTracker): -class BudTrainer(BudTracker): - ''' - :props_file: File where generated property table will be saved - :kwargs: Additional arguments passed onto the parent Tracker; `px_size` is - especially useful. - ''' - - def __init__(self, props_file=None, **kwargs): - super().__init__(**kwargs) - # NB: we inherit self.feats2use from CellTracker class - self.props_file = props_file - self.rf_feats = ["p_bud_mat", "size_ratio_mat", "p_budneck_mat", - "budneck_ratio_mat", "adjacency_mat"] - - @property - def props_file(self): - return getattr(self, '_props_file') - - @props_file.setter - def props_file(self, filename): - if filename is not None: - self._props_file = Path(filename) - - @property - def props(self): - if getattr(self, '_props', None) is None: - if self.props_file and self.props_file.is_file(): - self.props = pd.read_csv(self.props_file) - else: - raise BadProcess( - 'The property table has not yet been generated') - return self._props - - @props.setter - def props(self, props): - props = pd.DataFrame(props) - required_cols = self.rf_feats + ['is_mb_pair', 'validation'] - if not all(c in props for c in required_cols): - raise BadParam( - '"props" does not have all required columns: {}'.format( - ', '.join(required_cols))) - self._props = props - if self.props_file: - props.to_csv(self.props_file) - - def generate_property_table(self, data, flattener, val_data=None): - '''Generates properties table that gets used for training - - :data: List or generator of `baby.training.SegExample` tuples - :flattener: Instance of a `baby.preprocessing.SegmentationFlattening` - object describing the targets of the CNN in data - ''' - tnames = flattener.names() - i_budneck = tnames.index('bud_neck') - bud_target = 'sml_fill' if 'sml_fill' in tnames else 'sml_inte' - i_bud = tnames.index(bud_target) - - if val_data is not None: - data = TrainValProperty(data, val_data) - if isinstance(data, (TrainValProperty, TrainValTestProperty)): - data = chain(zip(repeat(False), data.train), - zip(repeat(True), data.val)) - else: - data = zip(repeat(None), data) - - p_list = [] - for is_val, seg_example in data: - if len(seg_example.target) < 2: - # Skip if no pairs are present - continue - mb_stats = self.calc_mother_bud_stats(seg_example.pred[i_budneck], - seg_example.pred[i_bud], seg_example.target) - p = pd.DataFrame(mb_stats, columns=self.rf_feats) - p['validation'] = is_val - - # "cellLabels" specifies the label for each mask - cell_labels = seg_example.info.get('cellLabels', []) or [] - if type(cell_labels) is int: - cell_labels = [cell_labels] - # "buds" specifies the label of the bud for each mask - buds = seg_example.info.get('buds', []) or [] - if type(buds) is int: - buds = [buds] - - # Build a ground truth matrix identifying mother-bud pairs - ncells = len(seg_example.target) - is_mb_pair = np.zeros((ncells, ncells), dtype=bool) - mb_inds = [ - (i, cell_labels.index(b)) - for i, b in enumerate(buds) - if b > 0 and b in cell_labels - ] - if len(mb_inds) > 0: - mother_inds, bud_inds = zip(*mb_inds) - is_mb_pair[mother_inds, bud_inds] = True - p['is_mb_pair'] = is_mb_pair.flatten() - - # Ignore any rows containing NaNs - nanrows = np.isnan(mb_stats).any(axis=1) - if (p['is_mb_pair'] & nanrows).any(): - id_keys = ('experimentID', 'position', 'trap', 'tp') - info = seg_example.info - img_id = ' / '.join( - [k + ': ' + str(info[k]) for k in id_keys if k in info]) - warn('Mother-bud pairs omitted due to feature NaNs') - print('Mother-bud pair omitted in "{}"'.format(img_id)) - p = p.loc[~nanrows, :] - p_list.append(p) - - props = pd.concat(p_list, ignore_index=True) - # TODO: should search for any None values in validation column and - # assign a train-validation split to those rows - - self.props = props # also saves - - def explore_hyperparams(self, hyper_param_target='precision'): - # Train bud assignment model on validation data, since this more - # closely represents real-world performance of the CNN: - data = self.props.loc[self.props['validation'], self.rf_feats] - truth = self.props.loc[self.props['validation'], 'is_mb_pair'] - - rf = RandomForestClassifier(n_estimators=15, - criterion='gini', - max_depth=3, - class_weight='balanced') - - param_grid = { - 'n_estimators': [6, 15, 50, 100], - 'max_features': ['auto', 'sqrt', 'log2'], - 'max_depth': [2, 3, 4], - 'class_weight': [None, 'balanced', 'balanced_subsample'] - } - - def get_balanced_best_index(cv_results_): - '''Find a model balancing F1 score and speed''' - df = pd.DataFrame(cv_results_) - best_score = df.iloc[df.mean_test_f1.idxmax(), :] - thresh = best_score.mean_test_f1 - 0.1 * best_score.std_test_f1 - return df.loc[df.mean_test_f1 > thresh, 'mean_score_time'].idxmin() - - self._rf = GridSearchCV(estimator=rf, param_grid=param_grid, cv=5, - scoring=SCORING_METRICS, refit=hyper_param_target) - self._rf.fit(data, truth) - - df = pd.DataFrame(self._rf.cv_results_) - disp_cols = [c for c in df.columns if c.startswith('mean_') - or c.startswith('param_')] - print(df.loc[self._rf.best_index_, disp_cols]) - - def performance(self): - if not isinstance(getattr(self, '_rf', None), GridSearchCV): - raise BadProcess('"explore_hyperparams" has not been run') - - best_rf = self._rf.best_estimator_ - isval = self.props['validation'] - data = self.props.loc[~isval, self.rf_feats] - truth = self.props.loc[~isval, 'is_mb_pair'] - valdata = self.props.loc[isval, self.rf_feats] - valtruth = self.props.loc[isval, 'is_mb_pair'] - metrics = tuple(SCORING_METRICS.values()) - return TrainValProperty( - Score(*(m(best_rf, data, truth) for m in metrics)), - Score(*(m(best_rf, valdata, valtruth) for m in metrics))) - - def plot_PR(self): - best_rf = self._rf.best_estimator_ - isval = self.props['validation'] - valdata = self.props.loc[isval, self.rf_feats] - valtruth = self.props.loc[isval, 'is_mb_pair'] - plot_precision_recall_curve(best_rf, valdata, valtruth) - - def save_model(self, filename): - f = open(filename, 'wb') - pickle.dump(self._rf.best_estimator_, f) - - -SCORING_METRICS = { - 'accuracy': make_scorer(accuracy_score), - 'balanced_accuracy': make_scorer(balanced_accuracy_score), - 'precision': make_scorer(precision_score), - 'recall': make_scorer(recall_score), - 'f1': make_scorer(f1_score), - 'f0_5': make_scorer(fbeta_score, beta=0.5), - 'f2': make_scorer(fbeta_score, beta=2) -} - - -class Score(NamedTuple): - accuracy: float - balanced_accuracy: float - precision: float - recall: float - F1: float - F0_5: float - F2: float - - def __str__(self): - return 'Score({})'.format(', '.join([ - '{}={:.3f}'.format(k, v) for k, v in self._asdict().items() - ])) - - def get_ground_truth(cell_labels, buds): ncells = len(cell_labels) diff --git a/python/baby/tracker/utils.py b/python/baby/tracker/utils.py index 12c1dfc9c45f67d43c13a1156a00fe5c056f90aa..17b5c06670e9f502e519f5983b2f009b3eb1ca87 100644 --- a/python/baby/tracker/utils.py +++ b/python/baby/tracker/utils.py @@ -1,12 +1,14 @@ # If you publish results that make use of this software or the Birth Annotator # for Budding Yeast algorithm, please cite: -# Julian M J Pietsch, Alán Muñoz, Diane Adjavon, Ivan B N Clark, Peter S -# Swain, 2021, Birth Annotator for Budding Yeast (in preparation). +# Pietsch, J.M.J., Muñoz, A.F., Adjavon, D.-Y.A., Farquhar, I., Clark, I.B.N., +# and Swain, P.S. (2023). Determining growth rates from bright-field images of +# budding cells through identifying overlaps. eLife. 12:e79812. +# https://doi.org/10.7554/eLife.79812 # # # The MIT License (MIT) # -# Copyright (c) Julian Pietsch, Alán Muñoz and Diane Adjavon 2021 +# Copyright (c) Julian Pietsch, Alán Muñoz and Diane Adjavon 2023 # # Permission is hereby granted, free of charge, to any person obtaining a copy # of this software and associated documentation files (the "Software"), to diff --git a/python/baby/training/__init__.py b/python/baby/training/__init__.py index 09abf58649a4f6bfc38987e512e62f4332cfb07d..83aa435af93bfe80874eca5ae47c9566eff5063b 100644 --- a/python/baby/training/__init__.py +++ b/python/baby/training/__init__.py @@ -1,12 +1,14 @@ # If you publish results that make use of this software or the Birth Annotator # for Budding Yeast algorithm, please cite: -# Julian M J Pietsch, Alán Muñoz, Diane Adjavon, Ivan B N Clark, Peter S -# Swain, 2021, Birth Annotator for Budding Yeast (in preparation). +# Pietsch, J.M.J., Muñoz, A.F., Adjavon, D.-Y.A., Farquhar, I., Clark, I.B.N., +# and Swain, P.S. (2023). Determining growth rates from bright-field images of +# budding cells through identifying overlaps. eLife. 12:e79812. +# https://doi.org/10.7554/eLife.79812 # # # The MIT License (MIT) # -# Copyright (c) Julian Pietsch, Alán Muñoz and Diane Adjavon 2021 +# Copyright (c) Julian Pietsch, Alán Muñoz and Diane Adjavon 2023 # # Permission is hereby granted, free of charge, to any person obtaining a copy # of this software and associated documentation files (the "Software"), to @@ -35,7 +37,7 @@ It includes the following trainers * `HyperParameterTrainer`: CNN hyper-parameters * `CNNTrainer`: CNN using gradient descent to optimize for a given loss * `SegmentationTrainer`: hyper-parameters for post-processing of CNN Output -into cell instances and attributes + into cell instances and attributes Given the appropriate inputs, each of these can be trained separately. This is useful for fine-tuning or re-training parts separately. @@ -43,8 +45,10 @@ This is useful for fine-tuning or re-training parts separately. For training the entire framework at once, it is recommended to use the `BabyTrainer` class, which is also aliased as `Nursery`. """ +from .utils import fix_tf_rtx_gpu_bug from .smoothing_model_trainer import SmoothingModelTrainer from .flattener_trainer import FlattenerTrainer +from .segmentation_trainer import SegmentationTrainer import tensorflow as tf if tf.__version__.startswith('1'): from .v1_hyper_parameter_trainer import HyperParamV1 \ @@ -52,7 +56,4 @@ if tf.__version__.startswith('1'): else: from .hyper_parameter_trainer import HyperParameterTrainer from .cnn_trainer import CNNTrainer - -from .training import * -from .utils import fix_tf_rtx_gpu_bug - +from baby.training.training import BabyTrainer, Nursery diff --git a/python/baby/training/bud_trainer.py b/python/baby/training/bud_trainer.py new file mode 100644 index 0000000000000000000000000000000000000000..f0a60d6952ffa6122346e9caa34cc48882ca2824 --- /dev/null +++ b/python/baby/training/bud_trainer.py @@ -0,0 +1,332 @@ +# If you publish results that make use of this software or the Birth Annotator +# for Budding Yeast algorithm, please cite: +# Pietsch, J.M.J., Muñoz, A.F., Adjavon, D.-Y.A., Farquhar, I., Clark, I.B.N., +# and Swain, P.S. (2023). Determining growth rates from bright-field images of +# budding cells through identifying overlaps. eLife. 12:e79812. +# https://doi.org/10.7554/eLife.79812 +# +# +# The MIT License (MIT) +# +# Copyright (c) Julian Pietsch, Alán Muñoz and Diane Adjavon 2023 +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to +# deal in the Software without restriction, including without limitation the +# rights to use, copy, modify, merge, publish, distribute, sublicense, and/or +# sell copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING +# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS +# IN THE SOFTWARE. + +from pathlib import Path +from itertools import repeat, chain +from warnings import warn +import pickle +import numpy as np +import pandas as pd +from typing import NamedTuple + +from sklearn.ensemble import RandomForestClassifier +from sklearn.model_selection import GridSearchCV +from sklearn.metrics import ( + make_scorer, fbeta_score, accuracy_score, balanced_accuracy_score, + precision_score, recall_score, f1_score +) +import sklearn +if int(sklearn.__version__[0]) > 0: + from sklearn.metrics import PrecisionRecallDisplay + plot_precision_recall_curve = PrecisionRecallDisplay.from_estimator +else: + from sklearn.metrics import plot_precision_recall_curve + +from baby.errors import BadProcess, BadParam +from baby.tracker.core import BudTracker + +from .utils import (SharedParameterContainer, SharedDataContainer, + TrainValTestProperty, TrainValProperty, + standard_augmenter) +from .segmentation_trainer import SegmentationTrainer + + +SCORING_METRICS = { + 'accuracy': make_scorer(accuracy_score), + 'balanced_accuracy': make_scorer(balanced_accuracy_score), + 'precision': make_scorer(precision_score), + 'recall': make_scorer(recall_score), + 'f1': make_scorer(f1_score), + 'f0_5': make_scorer(fbeta_score, beta=0.5), + 'f2': make_scorer(fbeta_score, beta=2) +} + + +class Score(NamedTuple): + accuracy: float + balanced_accuracy: float + precision: float + recall: float + F1: float + F0_5: float + F2: float + + def __str__(self): + return 'Score({})'.format(', '.join([ + '{}={:.3f}'.format(k, v) for k, v in self._asdict().items() + ])) + + +class BudTrainer(BudTracker): + """Coordinates training for the mother-bud assignment model + + Args: + kwargs: Additional arguments passed onto the parent Tracker. + shared_params: Training and segmentation parameters as provided by + :py:class:`utils.SharedParameterContainer`. + shared_data: Training data as provided by + :py:class:`utils.SharedDataContainer`. + cnn_trainer: Trainer with optimised CNN. + """ + + def __init__(self, + shared_params: SharedParameterContainer, + shared_data: SharedDataContainer, + seg_trainer: SegmentationTrainer, + **kwargs): + + kwargs.setdefault('px_size', + shared_params.parameters.target_pixel_size) + kwargs.setdefault('model', False) + self._shared_params = shared_params + super().__init__(**kwargs) + + self._shared_data = shared_data + self._seg_trainer = seg_trainer + + # NB: we inherit self.feats2use from CellTracker class + self.rf_feats = ["p_bud_mat", "size_ratio_mat", "p_budneck_mat", + "budneck_ratio_mat", "adjacency_mat"] + + self._model = None + + @property + def px_size(self): + return self._shared_params.parameters.target_pixel_size + + @px_size.setter + def px_size(self, val): + if val != self._shared_params.parameters.target_pixel_size: + raise BadParam('px_size should be set via the ' + '`target_pixel_size` parameter in the ' + '`SharedParameterContainer`') + + @property + def save_dir(self): + """Base directory in which to save trained models""" + return self._shared_params.save_dir + + @property + def props_file(self): + return (self.save_dir / + self._shared_params.parameters.mother_bud_props_file) + + @property + def props(self): + if getattr(self, '_props', None) is None: + if self.props_file and self.props_file.is_file(): + self.props = pd.read_csv(self.props_file) + else: + raise BadProcess( + 'The property table has not yet been generated') + return self._props + + @props.setter + def props(self, props): + props = pd.DataFrame(props) + required_cols = self.rf_feats + ['is_mb_pair', 'validation'] + if not all(c in props for c in required_cols): + raise BadParam( + '"props" does not have all required columns: {}'.format( + ', '.join(required_cols))) + self._props = props + props.to_csv(self.props_file) + + @property + def model_save_file(self): + return (self.save_dir / + self._shared_params.parameters.mother_bud_model_file) + + @property + def model(self): + if self._model is None: + if isinstance(getattr(self, '_rf', None), GridSearchCV): + self._model = self._rf.best_estimator_ + elif self.model_save_file.isfile(): + with open(filename, 'rb') as f: + self._model = pickle.load(f) + else: + raise BadProcess('"explore_hyperparams" has not been run') + return self._model + + @model.setter + def model(self, val): + if val: + if not isinstance(val, RandomForestClassifier): + raise BadParam('model must be a RandomForestClassifier') + self._model = val + else: + self._model = None + + @property + def feature_importance(self): + return dict(zip(self.rf_feats, self.model.feature_importances_)) + + def generate_property_table(self): + """Generate table of properties to be used for training + """ + segtrainer = self._seg_trainer + flattener = segtrainer.flattener + data = chain(zip(repeat('train'), segtrainer.examples.train), + zip(repeat('val'), segtrainer.examples.val), + zip(repeat('test'), segtrainer.examples.test)) + + tnames = flattener.names() + i_budneck = tnames.index('bud_neck') + bud_target = 'sml_fill' if 'sml_fill' in tnames else 'sml_inte' + i_bud = tnames.index(bud_target) + + p_list = [] + for train_split, seg_example in data: + if len(seg_example.target) < 2: + # Skip if no pairs are present + continue + mb_stats = self.calc_mother_bud_stats(seg_example.pred[i_budneck], + seg_example.pred[i_bud], seg_example.target) + p = pd.DataFrame(mb_stats, columns=self.rf_feats) + p['validation'] = train_split == 'val' + p['testing'] = train_split == 'test' + + # "cellLabels" specifies the label for each mask + cell_labels = seg_example.info.get('cellLabels', []) or [] + if type(cell_labels) is int: + cell_labels = [cell_labels] + # "buds" specifies the label of the bud for each mask + buds = seg_example.info.get('buds', []) or [] + if type(buds) is int: + buds = [buds] + + # Build a ground truth matrix identifying mother-bud pairs + ncells = len(seg_example.target) + is_mb_pair = np.zeros((ncells, ncells), dtype=bool) + mb_inds = [ + (i, cell_labels.index(b)) + for i, b in enumerate(buds) + if b > 0 and b in cell_labels + ] + if len(mb_inds) > 0: + mother_inds, bud_inds = zip(*mb_inds) + is_mb_pair[mother_inds, bud_inds] = True + p['is_mb_pair'] = is_mb_pair.flatten() + + # Ignore any rows containing NaNs + nanrows = np.isnan(mb_stats).any(axis=1) + if (p['is_mb_pair'] & nanrows).any(): + id_keys = ('experimentID', 'position', 'trap', 'tp') + info = seg_example.info + img_id = ' / '.join( + [k + ': ' + str(info[k]) for k in id_keys if k in info]) + warn('Mother-bud pairs omitted due to feature NaNs') + print('Mother-bud pair omitted in "{}"'.format(img_id)) + p = p.loc[~nanrows, :] + p_list.append(p) + + props = pd.concat(p_list, ignore_index=True) + # TODO: should search for any None values in validation column and + # assign a train-validation split to those rows + + self.props = props # also saves + + def explore_hyperparams(self, hyper_param_target='precision'): + # Train bud assignment model on validation data, since this more + # closely represents real-world performance of the CNN: + data = self.props.loc[self.props['validation'], self.rf_feats] + truth = self.props.loc[self.props['validation'], 'is_mb_pair'] + + rf = RandomForestClassifier(n_estimators=15, + criterion='gini', + max_depth=3, + class_weight='balanced') + + param_grid = { + 'n_estimators': [6, 15, 50, 100], + 'max_features': ['auto', 'sqrt', 'log2'], + 'max_depth': [2, 3, 4], + 'class_weight': [None, 'balanced', 'balanced_subsample'] + } + + def get_balanced_best_index(cv_results_): + """Find a model balancing F1 score and speed""" + df = pd.DataFrame(cv_results_) + best_score = df.iloc[df.mean_test_f1.idxmax(), :] + thresh = best_score.mean_test_f1 - 0.1 * best_score.std_test_f1 + return df.loc[df.mean_test_f1 > thresh, 'mean_score_time'].idxmin() + + self._model = None + self._rf = GridSearchCV(estimator=rf, param_grid=param_grid, cv=5, + scoring=SCORING_METRICS, refit=hyper_param_target) + self._rf.fit(data, truth) + + def save_model(self, filename=None): + if filename is None: + filename = self.model_save_file + with open(filename, 'wb') as f: + pickle.dump(self.model, f) + + def fit(self, **kwargs): + try: + self.props + except BadProcess: + self.generate_property_table() + self.explore_hyperparams(**kwargs) + self.save_model() + + def performance(self): + isval = self.props['validation'] + istest = self.props['testing'] + data = self.props.loc[(~isval) & (~istest), self.rf_feats] + truth = self.props.loc[(~isval) & (~istest), 'is_mb_pair'] + valdata = self.props.loc[isval, self.rf_feats] + valtruth = self.props.loc[isval, 'is_mb_pair'] + testdata = self.props.loc[istest, self.rf_feats] + testtruth = self.props.loc[istest, 'is_mb_pair'] + metrics = tuple(SCORING_METRICS.values()) + return TrainValTestProperty( + Score(*(m(self.model, data, truth) for m in metrics)), + Score(*(m(self.model, valdata, valtruth) for m in metrics)), + Score(*(m(self.model, testdata, testtruth) for m in metrics))) + + def grid_search_summary(self): + if not isinstance(getattr(self, '_rf', None), GridSearchCV): + raise BadProcess('"explore_hyperparams" has not been run') + df = pd.DataFrame(self._rf.cv_results_) + disp_cols = [c for c in df.columns if c.startswith('mean_') + or c.startswith('param_')] + return df.loc[self._rf.best_index_, disp_cols] + + def plot_PR(self): + if not isinstance(getattr(self, '_rf', None), GridSearchCV): + raise BadProcess('"explore_hyperparams" has not been run') + best_rf = self._rf.best_estimator_ + isval = self.props['validation'] + valdata = self.props.loc[isval, self.rf_feats] + valtruth = self.props.loc[isval, 'is_mb_pair'] + plot_precision_recall_curve(best_rf, valdata, valtruth) + diff --git a/python/baby/training/cnn_trainer.py b/python/baby/training/cnn_trainer.py index 2807b69be04b387dfea9fbad932e6c5bacccc80a..91ba33e80427505a4c86d2f743af856f63010b1d 100644 --- a/python/baby/training/cnn_trainer.py +++ b/python/baby/training/cnn_trainer.py @@ -1,12 +1,14 @@ # If you publish results that make use of this software or the Birth Annotator # for Budding Yeast algorithm, please cite: -# Julian M J Pietsch, Alán Muñoz, Diane Adjavon, Ivan B N Clark, Peter S -# Swain, 2021, Birth Annotator for Budding Yeast (in preparation). +# Pietsch, J.M.J., Muñoz, A.F., Adjavon, D.-Y.A., Farquhar, I., Clark, I.B.N., +# and Swain, P.S. (2023). Determining growth rates from bright-field images of +# budding cells through identifying overlaps. eLife. 12:e79812. +# https://doi.org/10.7554/eLife.79812 # # # The MIT License (MIT) # -# Copyright (c) Julian Pietsch, Alán Muñoz and Diane Adjavon 2021 +# Copyright (c) Julian Pietsch, Alán Muñoz and Diane Adjavon 2023 # # Permission is hereby granted, free of charge, to any person obtaining a copy # of this software and associated documentation files (the "Software"), to @@ -25,26 +27,28 @@ # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING # FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS # IN THE SOFTWARE. +from typing import List, Tuple, Union +from types import MappingProxyType import json import pathlib -from typing import List, Tuple - -import matplotlib.pyplot as plt import pickle +import matplotlib.pyplot as plt +from matplotlib.colors import to_rgba +from scipy.signal import savgol_filter import tensorflow as tf +from tensorflow.keras.callbacks import (ModelCheckpoint, TensorBoard, + LearningRateScheduler, CSVLogger) +from tensorflow.keras.models import load_model + from baby.augmentation import Augmenter from baby.generator import ImageLabel from baby.preprocessing import SegmentationFlattening -from matplotlib.colors import to_rgba -from scipy.signal import savgol_filter -from tensorflow.python.keras.callbacks import ModelCheckpoint, TensorBoard, \ - LearningRateScheduler -from tensorflow.python.keras.models import load_model - from baby import models from baby.errors import BadType, BadProcess from baby.losses import bce_dice_loss, dice_coeff from baby.utils import get_name, schedule_steps +from .utils import SharedParameterContainer +from .flattener_trainer import FlattenerTrainer custom_objects = {'bce_dice_loss': bce_dice_loss, 'dice_coeff': dice_coeff} @@ -52,45 +56,73 @@ OPT_WEIGHTS_FILE = 'weights.h5' INIT_WEIGHTS_FILE = 'init_weights.h5' FINAL_WEIGHTS_FILE = 'final_weights.h5' HISTORY_FILE = 'history.pkl' +CSV_HISTORY_FILE = 'history.csv' LOG_DIR = 'logs' +HYPER_PARAMS_FILE = 'hyperparameters.json' class CNNTrainer: + """Methods for optimising CNN weights using gradient descent. + + After training, the optimised weights are be stored in a file named + :py:data:`OPT_WEIGHTS_FILE` within the :py:attr:`cnn_dir` directory. The + directory is specific to each of the architectures listed in ``cnn_set``. + + Args: + shared_params: Training and segmentation parameters as provided by + :py:class:`utils.SharedParameterContainer`. + flattener_trainer: Trainer that defines optimised targets for the CNN. + max_cnns: The maximum number of CNNs to keep in memory, default is 3. + cnn_fn: The CNN architecture to start with. Defaults to the + :py:attr:`utils.BabyTrainingParameters.cnn_fn` architecture + as found in ``shared_params.parameters``. """ - Methods for optimising the weights of the CNNs using gradient descent. - """ - def __init__(self, save_dir: pathlib.Path, - cnn_set: tuple, gen: ImageLabel, aug: Augmenter, - flattener: SegmentationFlattening, + def __init__(self, + shared_params: SharedParameterContainer, + flattener_trainer: FlattenerTrainer, max_cnns: int = 3, cnn_fn: str = None): - """ - Methods for optimising the weights of the CNNs using gradient descent. - - :param save_dir: base directory in which to save weights and outputs - :param cnn_set: the names of CNN architectures to be trained - :param gen: the data generator - :param aug: the data augmentor - :param flattener: the data flattener - :param max_cnns: the maximum number of CNNs to train/keep, default is 3 - :param cnn_fn: the CNN architecture to start with, defaults to None - """ - self.flattener = flattener - self.aug = aug - self.gen = gen + self._shared_params = shared_params + self._flattener_trainer = flattener_trainer + # TODO no longer needed with hyperparmeter optim self._max_cnns = max_cnns - self.save_dir = save_dir - self.cnn_set = cnn_set self._cnn_fn = cnn_fn self._cnns = dict() self._opt_cnn = None + @property + def save_dir(self): + """Base directory in which to save trained models""" + return self._shared_params.save_dir + + @property + def cnn_set(self): + return self._shared_params.parameters.cnn_set + + @property + def gen(self): + """Data generators used for training models. + + A :py:class:`utils.TrainValTestProperty` of training, validation and + test data generators with augmentations obtained from + :py:attr:`FlattenerTrainer.default_gen`. + """ + return self._flattener_trainer.default_gen + + @property + def flattener(self): + """Target definitions for models trained by this class.""" + return self._flattener_trainer.flattener + @property def cnn_fn(self): """The current CNN architecture function.""" if self._cnn_fn is None: - self.cnn_fn = self.cnn_set[0] + if self._shared_params.parameters.cnn_fn is None: + self.cnn_fn = self.cnn_set[0] + else: + self.cnn_fn = self._shared_params.parameters.cnn_fn return getattr(models, self._cnn_fn) @cnn_fn.setter @@ -100,6 +132,7 @@ class CNNTrainer: if not hasattr(models, fn): raise BadType('That is not a recognised model') self._cnn_fn = fn + self._hyperparameters = None @property def cnn_dir(self): @@ -114,39 +147,70 @@ class CNNTrainer: """The name of the currently active CNN architecture.""" return get_name(self.cnn_fn) + @property + def hyperparameters(self): + """Custom hyperparameters defined for the active CNN architecture. + + A ``dict`` specifying keyword arguments to be passed when building the + active CNN architecture. If left unset, any defaults as given in + :py:mod:`baby.models` will be used. + + NB: The property is returned as a :py:class:`types.MappingProxyType`, + so ``dict`` items cannot be modified. To change hyperparameters, the + whole ``dict`` needs to be replaced. + """ + if not getattr(self, '_hyperparameters', None): + # Load hyperparameters + hyper_param_file = self.cnn_dir / "hyperparameters.json" + if hyper_param_file.exists(): + with open(hyper_param_file, 'r') as fd: + self._hyperparameters = json.load(fd) + else: + # Use defaults specified in `models` + self._hyperparameters = {} + return MappingProxyType(self._hyperparameters) + + @hyperparameters.setter + def hyperparameters(self, params): + if not type(params) == dict or type(params) == MappingProxyType: + raise BadType('Hyperparameters must be specified as a `dict`') + hyper_param_file = self.cnn_dir / "hyperparameters.json" + with open(hyper_param_file, 'w') as f: + json.dump(dict(params), f) + self._hyperparameters = dict(params) + # If the hyperparameters have changed, we need to regenerate the model + # and initial weights + if self.cnn_name in self._cnns: + del self._cnns[self.cnn_name] + # Delete initial weights if they have already been saved + init_weights_file = self.cnn_dir / INIT_WEIGHTS_FILE + init_weights_file.unlink(missing_ok=True) + @property def cnn(self): """The keras Model for the active CNN.""" if self.cnn_name not in self._cnns: - if len(self._cnns) > self._max_cnns: + n_loaded = getattr(self, '_n_cnns_loaded', 0) + if n_loaded > self._max_cnns: # To avoid over-consuming memory reset graph # TODO: ensure TF1/TF2 compat and check RTX bug tf.keras.backend.clear_session() # Reset any potentially loaded models self._cnns = dict() self._opt_cnn = None - # Todo: separate generator from trainer - # Make model accept an input shape and a set of outputs - self.gen.train.aug = self.aug.train + n_loaded = 0 + print('Loading "{}" CNN...'.format(self.cnn_name)) - # Load hyperparameters - hyper_param_file = self.cnn_dir / "hyperparameters.json" - if not hyper_param_file.exists(): - # Todo: just use defaults - raise FileNotFoundError("Hyperparameter file {} for {} not " - "found.".format(hyper_param_file, - self.cnn_name)) - with open(hyper_param_file, 'r') as fd: - hyperparameters = json.load(fd) model = self.cnn_fn(self.gen.train, self.flattener, - **hyperparameters) + **self.hyperparameters) self._cnns[self.cnn_name] = model + self._n_cnns_loaded = n_loaded + 1 # Save initial weights if they haven't already been saved - filename = self.cnn_dir / INIT_WEIGHTS_FILE - if not filename.exists(): + init_weights_file = self.cnn_dir / INIT_WEIGHTS_FILE + if not init_weights_file.exists(): print('Saving initial weights...') - model.save_weights(str(filename)) + model.save_weights(str(init_weights_file)) return self._cnns[self.cnn_name] @property @@ -195,7 +259,7 @@ class CNNTrainer: return self._opt_cnn def fit(self, epochs: int = 400, - schedule: List[Tuple[int, float]] = None, + schedule: Union[str, List[Tuple[int, float]]] = None, replace: bool = False, extend: bool = False): """Fit the active CNN to minimise loss on the (augmented) generator. @@ -214,6 +278,11 @@ class CNNTrainer: if schedule is None: schedule = [(1e-3, epochs)] + if callable(schedule): + schedulefn = schedule + else: + schedulefn = lambda epoch: schedule_steps(epoch, schedule) + finalfile = self.cnn_dir / FINAL_WEIGHTS_FILE if extend: self.cnn.load_weights(str(finalfile)) @@ -225,22 +294,28 @@ class CNNTrainer: if not replace and optfile.is_file(): raise BadProcess('Optimised weights already exist') + csv_history_file = self.cnn_dir / CSV_HISTORY_FILE logdir = self.cnn_dir / LOG_DIR callbacks = [ ModelCheckpoint(filepath=str(optfile), monitor='val_loss', save_best_only=True, verbose=1), + CSVLogger(filename=str(csv_history_file), append=extend), TensorBoard(log_dir=str(logdir)), - LearningRateScheduler( - lambda epoch: schedule_steps(epoch, schedule)) + LearningRateScheduler(schedulefn) ] - self.gen.train.aug = self.aug.train - self.gen.val.aug = self.aug.val - history = self.cnn.fit_generator(generator=self.gen.train, - validation_data=self.gen.val, - epochs=epochs, - callbacks=callbacks) + + if tf.version.VERSION.startswith('1'): + history = self.cnn.fit_generator(generator=self.gen.train, + validation_data=self.gen.val, + epochs=epochs, + callbacks=callbacks) + else: + history = self.cnn.fit(self.gen.train, + validation_data=self.gen.val, + epochs=epochs, + callbacks=callbacks) # Save history with open(self.cnn_dir / HISTORY_FILE, 'wb') as f: diff --git a/python/baby/training/flattener_trainer.py b/python/baby/training/flattener_trainer.py index 36e47ff108e620e745721f884d32fd01b606638b..63dcf16d1451c5b760a863c05ef412767701fb1b 100644 --- a/python/baby/training/flattener_trainer.py +++ b/python/baby/training/flattener_trainer.py @@ -1,23 +1,25 @@ # If you publish results that make use of this software or the Birth Annotator # for Budding Yeast algorithm, please cite: -# Julian M J Pietsch, Alán Muñoz, Diane Adjavon, Ivan B N Clark, Peter S -# Swain, 2021, Birth Annotator for Budding Yeast (in preparation). -# -# +# Pietsch, J.M.J., Muñoz, A.F., Adjavon, D.-Y.A., Farquhar, I., Clark, I.B.N., +# and Swain, P.S. (2023). Determining growth rates from bright-field images of +# budding cells through identifying overlaps. eLife. 12:e79812. +# https://doi.org/10.7554/eLife.79812 +# +# # The MIT License (MIT) -# -# Copyright (c) Julian Pietsch, Alán Muñoz and Diane Adjavon 2021 -# +# +# Copyright (c) Julian Pietsch, Alán Muñoz and Diane Adjavon 2023 +# # Permission is hereby granted, free of charge, to any person obtaining a copy # of this software and associated documentation files (the "Software"), to # deal in the Software without restriction, including without limitation the # rights to use, copy, modify, merge, publish, distribute, sublicense, and/or # sell copies of the Software, and to permit persons to whom the Software is # furnished to do so, subject to the following conditions: -# +# # The above copyright notice and this permission notice shall be included in # all copies or substantial portions of the Software. -# +# # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE @@ -28,22 +30,67 @@ """Optimising the hyper-parameters of the `SegmentationFlattener`""" import json import pathlib +from itertools import chain import matplotlib.pyplot as plt import numpy as np +from scipy.ndimage import binary_erosion + from baby.augmentation import Augmenter from baby.errors import BadProcess, BadType, BadParam from baby.generator import ImageLabel from baby.preprocessing import dwsquareconn, SegmentationFlattening from baby.utils import find_file -from scipy.ndimage import binary_erosion +from baby.visualise import colour_segstack + +from .utils import TrainValProperty, standard_augmenter +from .smoothing_model_trainer import SmoothingModelTrainer + -from .utils import augmented_generator, TrainValProperty +def _generate_flattener_stats_inner(segs, nerode, keep_zero): + segs = segs > 0 + # First remove any sections that may have been augmented away + # NB: the ScalingAugmenter does remove most of these at the rough_crop + # step, but there will still be many excluded after the second crop + segs = segs[..., segs.any(axis=(0,1))] + s_sizes = segs.sum(axis=(0,1)).tolist() + nsegs = segs.shape[2] + esizes = [[] for s in range(nsegs)] + overlap_sizes = [[] for _ in nerode] + for e in nerode: + for s0 in range(nsegs): + seg0 = segs[..., s0] + n0 = int(seg0.sum()) + esizes[s0].append(n0) + if n0 == 0: + continue + for s1 in range(s0 + 1, nsegs): + seg1 = segs[..., s1] + # Calculate number of overlapping pixels + nO = np.sum(seg0 & seg1) + # Calculate fraction overlap + fO = float(nO / np.sum(seg0 | seg1)) + if fO > 0 or keep_zero: + sizes = tuple(sorted([s_sizes[s0], s_sizes[s1]])) + if keep_zero: + overlap_sizes[e].append(sizes + (fO, nO)) + else: + overlap_sizes[e].append(sizes + (fO,)) + segs = binary_erosion(segs, dwsquareconn) + return esizes, overlap_sizes + + +def _batch_parallel_generator(gen, sample_inds, batch_size=8, n_jobs=4): + for i in range(np.ceil(len(sample_inds) / batch_size).astype('int')): + yield gen.parallel_get_indices( + sample_inds[batch_size*i:batch_size*(i+1)], n_jobs=n_jobs) def _generate_flattener_stats(gen: ImageLabel, max_erode: int, - keep_zero=False) -> dict: + keep_zero=False, + batch_size=16, + n_jobs=4) -> dict: """ Generates flattener statistics of the data output by the generator. This function measures the size (in pixels) of segmentation mask @@ -61,36 +108,24 @@ def _generate_flattener_stats(gen: ImageLabel, size after applying the indexed number of erosions. """ nerode = list(range(max_erode + 1)) - overlap_sizes = [[] for _ in nerode] - erosion_sizes = [] - for t in range(len(gen.paths)): - _, segs = gen.get_by_index(t) - nsegs = segs.shape[2] - segs = segs > 0 - s_sizes = [int(segs[..., s].sum()) for s in range(nsegs)] - esizes = [[] for s in range(nsegs)] - for e in nerode: - for s0 in range(nsegs): - seg0 = segs[..., s0] - n0 = int(seg0.sum()) - esizes[s0].append(n0) - if n0 == 0: - continue - for s1 in range(s0 + 1, nsegs): - seg1 = segs[..., s1] - # Calculate number of overlapping pixels - nO = np.sum(seg0 & seg1) - # Calculate fraction overlap - fO = float(nO / np.sum(seg0 | seg1)) - if fO > 0 or keep_zero: - sizes = tuple(sorted([s_sizes[s0], s_sizes[s1]])) - if keep_zero: - overlap_sizes[e].append(sizes + (fO, nO)) - else: - overlap_sizes[e].append(sizes + (fO,)) - segs = binary_erosion(segs, dwsquareconn) - erosion_sizes.extend(esizes) + sample_inds = np.repeat(np.arange(len(gen.nsamples)), gen.nsamples) + sample_gen = _batch_parallel_generator(gen, sample_inds, n_jobs=n_jobs) + n_batches = np.ceil(len(sample_inds) / batch_size).astype('int') + erosion_sizes = [] + overlap_sizes = [[] for _ in range(len(nerode))] + from joblib import Parallel, delayed + from tqdm import trange + for i in trange(n_batches): + gen_batch = gen.parallel_get_indices( + sample_inds[batch_size*i:batch_size*(i+1)], n_jobs=n_jobs) + e_sizes, o_sizes = zip(*Parallel(n_jobs=n_jobs)( + delayed(_generate_flattener_stats_inner)(segs, nerode, keep_zero) + for _, segs in gen_batch)) + erosion_sizes.extend(chain(*e_sizes)) + o_sizes = [chain(*e) for e in zip(*o_sizes)] + for e, o in zip(o_sizes, overlap_sizes): + o.extend(e) return {'overlap_sizes': overlap_sizes, 'erosion_sizes': erosion_sizes} @@ -128,7 +163,7 @@ def _group_overlapping(os, thresh, pad=0): ]) -def _best_overlapping(overlapping, erosion_sizes, min_size): +def _best_overlapping(overlapping, erosion_sizes, min_size, min_size_frac): """Return overlap stats for highest level of erosion without losing cells Binary erosion is valid if it does not reduce the area of any cells below @@ -145,8 +180,12 @@ def _best_overlapping(overlapping, erosion_sizes, min_size): """ # Rearrange `erosion_sizes` by number of applied erosions sz_erosions = list(zip(*erosion_sizes)) - # Erosions are invalid if any cells drop below the minimum allowed size - e_invalid = [any([c < min_size for c in e]) for e in sz_erosions[:0:-1]] + min_median_sz = np.median(sz_erosions[0]) * min_size_frac + # Erosions are invalid if any cells drop below the minimum allowed size, + # or if the median size drops below a fraction of the original median + e_invalid = [np.any(np.array(e) < min_size) + or np.median(e) < min_median_sz + for e in sz_erosions[:0:-1]] # Return only overlap stats for valid numbers of erosions o_valid = [o for o, e in zip(overlapping[:0:-1], e_invalid) if not e] o_valid += [overlapping[0]] @@ -199,54 +238,74 @@ def _best_nerode(szg, min_size): class FlattenerTrainer: + """Optimises the hyper-parameters for the `SegmentationFlattener`. - def __init__(self, save_dir: pathlib.Path, stats_file: str, - flattener_file: str): - """Optimises the hyper-parameters for the `SegmentationFlattener`. - + #TODO describe method for optimisation - - #TODO describe method for optimisation - :param save_dir: the base directory in which to save outputs - :param stats_file: the name of the file in which the stats are saved - :param flattener_file: the name of the file defining the flattener - """ - self.save_dir = save_dir - self.stats_file = self.save_dir / stats_file - self.flattener_file = self.save_dir / flattener_file + Args: + shared_params (utils.SharedParameterContainer): Training and + segmentation parameters as provided by + :py:class:`utils.SharedParameterContainer`. + shared_data (utils.SharedDataContainer): Access to training data. + ssm_trainer (SmoothingModelTrainer): Trainer from which to obtain a + smoothing sigma model. + """ + def __init__(self, shared_params, shared_data, ssm_trainer): + self._shared_params = shared_params + self._shared_data = shared_data + self._ssm_trainer = ssm_trainer self._flattener = None self._stats = None - def generate_flattener_stats(self, - train_gen: ImageLabel, - val_gen: ImageLabel, - train_aug: Augmenter, - val_aug: Augmenter, - max_erode: int = 5): + @property + def save_dir(self): + return self._shared_params.save_dir + + @property + def stats_file(self): + """File in which to save derived data for training flattener""" + return (self.save_dir / + self._shared_params.parameters.flattener_stats_file) + + @property + def flattener_file(self): + return (self.save_dir / + self._shared_params.parameters.flattener_file) + + def generate_flattener_stats(self, max_erode: int = 5, n_jobs=None): """Generate overlap and erosion statistics for augmented data in input. - :param train_gen: the generator of training images and their labels - :param val_gen: the generator of validation images and their labels - :param train_aug: the augmenter to use for training images - :param val_aug: the augmenter to use for validation images - :param max_erode: the maximum allowed number of erosions used to + Saves results to file specified in :py:attr:`stats_file`. + + Args: + max_erode: the maximum allowed number of erosions used to generate erosion values :return: None, saves results to `self.stats_file` """ - with augmented_generator(train_gen, train_aug) as gen: - fs_train = _generate_flattener_stats(gen, max_erode) - with augmented_generator(val_gen, val_aug) as gen: - fs_val = _generate_flattener_stats(gen, max_erode) + dummy_flattener = lambda x, y: x + ssm = self._ssm_trainer.model + params = self._shared_params.parameters + # NB: use isval=True als for training aug since we do not need extra + # augmentations for calibrating the flattener + aug = standard_augmenter(ssm, dummy_flattener, params, isval=True) + train_gen, val_gen, _ = self._shared_data.gen_with_aug(aug) + + if n_jobs is None: + n_jobs = params.n_jobs + fs_train = _generate_flattener_stats(train_gen, max_erode, + n_jobs=n_jobs) + fs_val = _generate_flattener_stats(val_gen, max_erode, + n_jobs=n_jobs) with open(self.stats_file, 'wt') as f: json.dump({'train': fs_train, 'val': fs_val}, f) self._stats = None # trigger reload of property @property def stats(self) -> TrainValProperty: - """The last statistics computed, loaded from `self.stats_file` + """The last statistics computed, loaded from :py:attr:`stats_file`. - :return: The last training and validation statistics computed. - :raises: BadProcess error if the file does not exist. + Raises: + BadProcess: If the file does not exist. """ if self._stats is None: if not self.stats_file.exists(): @@ -258,6 +317,42 @@ class FlattenerTrainer: return TrainValProperty(self._stats.get('train', {}), self._stats.get('val', {})) + def plot_stats(self, nbins=30, nrows=1, sqrt_area=False): + """Plot a histogram of cell overlap statistics of the training set. + + # TODO describe what the plot means + # TODO add an image as an example + + :param nbins: binning of data, passed to `matplotlib.pyplot.hist2d` + :return: None, saves the resulting figure under `self.save_dir / + "flattener_stats.png"` + """ + overlapping = self.stats.train.get('overlap_sizes', []).copy() + if sqrt_area: + # Transform all areas by sqrt + for i, o in enumerate(overlapping): + x, y, f = zip(*o) + overlapping[i] = list(zip(np.sqrt(x), np.sqrt(y), f)) + max_erode = len(overlapping) + ncols = int(np.ceil(max_erode / nrows)) + fig, axs = plt.subplots(nrows, ncols, figsize=(4 * ncols, 4 * nrows)) + axs = axs.flatten() + x, y, _ = zip(*overlapping[0]) + max_size = max(x + y) + for ax, (e, os) in zip(axs, enumerate(overlapping)): + if len(os) > 0: + x, y, w = zip(*os) + else: + x, y, w = 3 * [[]] + ax.hist2d(x, + y, + bins=nbins, + weights=w, + range=[[0, max_size], [0, max_size]]) + ax.plot((0, max_size), (0, max_size), 'r') + ax.set_title('nerosions = {:d}'.format(e)) + fig.savefig(self.save_dir / 'flattener_stats.png') + @property def flattener(self) -> SegmentationFlattening: """The last flattener saved to file. @@ -286,7 +381,8 @@ class FlattenerTrainer: f.save(self.flattener_file) self._flattener = f - def fit(self, nbins=30, min_size=10, pad_frac=0.03, bud_max=200): + def fit(self, nbins=30, min_size=10, min_size_frac=0., pad_frac=0.03, + bud_max=200, sqrt_area=False, overlaps=None): """Optimise the parameters of the `SegmentationFlattener` based on previously computed statistics. @@ -305,13 +401,21 @@ class FlattenerTrainer: raise BadParam('"pad_frac" must be between 0 and 0.2') # Load generated stats for training data - overlapping = self.stats.train.get('overlap_sizes', []) - erosion_sizes = self.stats.train.get('erosion_sizes', []) + overlapping = self.stats.train.get('overlap_sizes', []).copy() + erosion_sizes = self.stats.train.get('erosion_sizes', []).copy() if len(overlapping) == 0 or len(erosion_sizes) == 0 or \ len(list(zip(*erosion_sizes))) != len(overlapping): raise BadProcess( '"flattener_stats.json" file appears to be corrupted') + if sqrt_area: + # Transform all areas by sqrt + min_size = np.sqrt(min_size) + erosion_sizes = np.sqrt(erosion_sizes).tolist() + for i, o in enumerate(overlapping): + x, y, f = zip(*o) + overlapping[i] = list(zip(np.sqrt(x), np.sqrt(y), f)) + # Find the best single split point by brute force iteration over a # binned version of the training data @@ -323,7 +427,8 @@ class FlattenerTrainer: edges = np.linspace(pad, max_size - pad, nbins)[1:-1] # Use overlap stats at maximum valid level of erosion - o_maxerode = _best_overlapping(overlapping, erosion_sizes, min_size) + o_maxerode = _best_overlapping(overlapping, erosion_sizes, min_size, + min_size_frac) # Then iterate over the thresholds (edges) to find which split # minimises the overlap fraction split0, w0 = _find_best_fgroup_split(o_maxerode, edges, pad=pad) @@ -335,8 +440,8 @@ class FlattenerTrainer: szgL, szgH = _group_sizes(erosion_sizes, split0, pad=pad) # And again use the overlap stats at maximum valid level of erosion - ogL = _best_overlapping(ogL, szgL, min_size) - ogH = _best_overlapping(ogH, szgH, min_size) + ogL = _best_overlapping(ogL, szgL, min_size, min_size_frac) + ogH = _best_overlapping(ogH, szgH, min_size, min_size_frac) w_ogL = sum([w for _, _, w in ogL]) w_ogH = sum([w for _, _, w in ogH]) @@ -361,14 +466,21 @@ class FlattenerTrainer: ne1 = _best_nerode(szg1, min_size) ne2 = _best_nerode(szg2, min_size) + if sqrt_area: + untransform = lambda x: int(np.round(np.square(x))) + else: + untransform = lambda x: int(np.round(x)) + flattener = SegmentationFlattening() - flattener.addGroup('small', upper=int(np.round(splits[0] + pad))) + flattener.addGroup('small', upper=untransform(splits[0] + pad)) flattener.addGroup('medium', - lower=int(np.round(splits[0] - pad)), - upper=int(np.round(splits[1] + pad))) - flattener.addGroup('large', lower=int(np.round(splits[1] - pad))) + lower=untransform(splits[0] - pad), + upper=untransform(splits[1] + pad)) + flattener.addGroup('large', lower=untransform(splits[1] - pad)) flattener.addGroup('buds', upper=bud_max, budonly=True) + if overlaps == 'all': + flattener.addGroup('all') flattener.addTarget('lge_inte', 'large', 'interior', nerode=ne2) flattener.addTarget('lge_edge', 'large', 'edge') @@ -377,35 +489,48 @@ class FlattenerTrainer: flattener.addTarget('sml_inte', 'small', 'filled', nerode=ne0) flattener.addTarget('sml_edge', 'small', 'edge') flattener.addTarget('bud_neck', 'buds', 'budneck') + if overlaps == 'all': + flattener.addTarget('all_ovlp', 'all', 'overlap') - flattener.save(self.flattener_file) - self._flattener = None + self.flattener = flattener - def plot_stats(self, nbins=30): - """Plot a histogram of cell overlap statistics of the training set. - - # TODO describe what the plot means - # TODO add an image as an example + @property + def default_gen(self): + """Get default data generators using the current flattener.""" + ssm = self._ssm_trainer.model + params = self._shared_params.parameters + t_aug = standard_augmenter(ssm, self.flattener, params, isval=False) + v_aug = standard_augmenter(ssm, self.flattener, params, isval=True) + return self._shared_data.gen_with_aug((t_aug, v_aug, v_aug)) + + def plot_default_gen_sample(self, i=0, figsize=3, validation=False): + g = self.default_gen.val if validation else self.default_gen.train + img_batch, lbl_batch = g[i] + lbl_batch = np.concatenate(lbl_batch, axis=3) + + f = self.flattener + target_names = f.names() + edge_inds = np.flatnonzero([t.prop == 'edge' for t in f.targets]) + + ncol = len(img_batch) + nrow = len(target_names) + 1 + fig = plt.figure(figsize=(figsize * ncol, figsize * nrow)) + for b, (bf, seg) in enumerate(zip(img_batch, lbl_batch)): + plt.subplot(nrow, ncol, b + 0 * ncol + 1) + plt.imshow(bf[:, :, 0], cmap='gray') + plt.imshow(colour_segstack(seg[:, :, edge_inds], dw=True)) + plt.grid(False) + plt.xticks([]) + plt.yticks([]) + + for i, name in enumerate(target_names): + plt.subplot(nrow, ncol, b + (i + 1) * ncol + 1) + plt.imshow(seg[:, :, i], cmap='gray') + plt.grid(False) + plt.xticks([]) + plt.yticks([]) + plt.title(name) + + fig.savefig(self.save_dir / '{}_generator_sample.png'.format( + 'validation' if validation else 'training')) - :param nbins: binning of data, passed to `matplotlib.pyplot.hist2d` - :return: None, saves the resulting figure under `self.save_dir / - "flattener_stats.png"` - """ - overlapping = self.stats.train.get('overlap_sizes', []) - max_erode = len(overlapping) - fig, axs = plt.subplots(1, max_erode, figsize=(16, 16 / max_erode)) - x, y, _ = zip(*overlapping[0]) - max_size = max(x + y) - for ax, (e, os) in zip(axs, enumerate(overlapping)): - if len(os) > 0: - x, y, w = zip(*os) - else: - x, y, w = 3 * [[]] - ax.hist2d(x, - y, - bins=nbins, - weights=w, - range=[[0, max_size], [0, max_size]]) - ax.plot((0, max_size), (0, max_size), 'r') - ax.set_title('nerosions = {:d}'.format(e)) - fig.savefig(self.save_dir / 'flattener_stats.png') diff --git a/python/baby/training/hyper_parameter_trainer.py b/python/baby/training/hyper_parameter_trainer.py index cedbcaf5176534268cea7df7197332c91445e5ad..b52dcaaa419c26583293563fbf181f22d848eac1 100644 --- a/python/baby/training/hyper_parameter_trainer.py +++ b/python/baby/training/hyper_parameter_trainer.py @@ -1,12 +1,14 @@ # If you publish results that make use of this software or the Birth Annotator # for Budding Yeast algorithm, please cite: -# Julian M J Pietsch, Alán Muñoz, Diane Adjavon, Ivan B N Clark, Peter S -# Swain, 2021, Birth Annotator for Budding Yeast (in preparation). +# Pietsch, J.M.J., Muñoz, A.F., Adjavon, D.-Y.A., Farquhar, I., Clark, I.B.N., +# and Swain, P.S. (2023). Determining growth rates from bright-field images of +# budding cells through identifying overlaps. eLife. 12:e79812. +# https://doi.org/10.7554/eLife.79812 # # # The MIT License (MIT) # -# Copyright (c) Julian Pietsch, Alán Muñoz and Diane Adjavon 2021 +# Copyright (c) Julian Pietsch, Alán Muñoz and Diane Adjavon 2023 # # Permission is hereby granted, free of charge, to any person obtaining a copy # of this software and associated documentation files (the "Software"), to @@ -29,12 +31,11 @@ import json from pathlib import Path from typing import Union +from baby.generator import augmented_generator from baby.training.hypermodels import get_hypermodel from kerastuner import RandomSearch, Hyperband, BayesianOptimization, Tuner, \ HyperModel -from .utils import augmented_generator - def instantiate_tuner(model, method='random', **kwargs): method = method.lower() diff --git a/python/baby/training/hypermodels.py b/python/baby/training/hypermodels.py index 3ebd0a340d7bca196be2c76906ef6688ddab4a77..01a6aa7356e7e452ad4b0a9a45ef01b94526454d 100644 --- a/python/baby/training/hypermodels.py +++ b/python/baby/training/hypermodels.py @@ -1,12 +1,14 @@ # If you publish results that make use of this software or the Birth Annotator # for Budding Yeast algorithm, please cite: -# Julian M J Pietsch, Alán Muñoz, Diane Adjavon, Ivan B N Clark, Peter S -# Swain, 2021, Birth Annotator for Budding Yeast (in preparation). +# Pietsch, J.M.J., Muñoz, A.F., Adjavon, D.-Y.A., Farquhar, I., Clark, I.B.N., +# and Swain, P.S. (2023). Determining growth rates from bright-field images of +# budding cells through identifying overlaps. eLife. 12:e79812. +# https://doi.org/10.7554/eLife.79812 # # # The MIT License (MIT) # -# Copyright (c) Julian Pietsch, Alán Muñoz and Diane Adjavon 2021 +# Copyright (c) Julian Pietsch, Alán Muñoz and Diane Adjavon 2023 # # Permission is hereby granted, free of charge, to any person obtaining a copy # of this software and associated documentation files (the "Software"), to @@ -25,11 +27,13 @@ # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING # FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS # IN THE SOFTWARE. -from baby.layers import msd_block, make_outputs, unet_block +from baby.layers import msd_block, make_outputs, unet_block, res_block, conv_block from baby.losses import dice_coeff, bce_dice_loss -from kerastuner import HyperModel -from tensorflow.python.keras import Input, Model -from tensorflow.python.keras.optimizers import Adam +from keras_tuner import HyperModel +from tensorflow.keras import Input, Model +from tensorflow.keras.initializers import VarianceScaling +from tensorflow.keras.optimizers import Adam +from tensorflow_addons.optimizers import AdamW @@ -48,18 +52,44 @@ class UNet(HyperModel): return dict(depth=4, layer_size=8, batchnorm=True, dropout=0.0) def build(self, hp): + # Universal variants + width = hp.Choice('width', [8, 16, 32, 64]) + depth = hp.Int('depth', min_value=2, max_value=5) + kernel = hp.Choice('kernel', [3, 5, 7]) + init = hp.Choice('initializer', ['glorot_uniform', 'variance_scaling']) + if init == 'variance_scaling': + init = VarianceScaling(2., mode='fan_out') + activation = hp.Choice('activation', ['relu', 'gelu', 'swish']) + conv_pool = hp.Boolean('conv_pool', default=True) + up_activate = hp.Boolean('up_activate', default=True) + residual_skip = hp.Boolean('residual_skip') + enc_repeats = hp.Int('enc_repeats', min_value=2, max_value=4) + dec_repeats = hp.Int('dec_repeats', min_value=2, max_value=4) + block_type = hp.Choice('block_type', + ['effnet', 'effnet-preact', 'convnext', 'conv']) + block_args = { + 'conv': dict(block=conv_block, stem=False), + 'effnet': dict(block=res_block, stem=True), + 'effnet-preact': dict(block=res_block, pre_activate=True, stem=True), + 'convnext': dict(block=res_block, stem=True, convnext=True) + }[block_type] + expand_ratio, drop = 1, 0. + if block_type != 'conv': + expand_ratio = hp.Choice('expand_ratio', [0.5, 1., 2., 4.]) + width = width / expand_ratio + drop = hp.Choice('block_drop', [0., 0.2]) + inputs = Input(shape=self.input_shape) - depth = hp.Int('depth', min_value=2, max_value=4, step=1) - layer_size = hp.Choice('layer_size', values=[8, 16, 32]) - layer_sizes = [layer_size*(2**i) for i in range(depth)] - batchnorm = hp.Boolean('batchnorm', default=True) - dropout = hp.Float('dropout', min_value=0., max_value=0.7, step=0.1) - unet = unet_block(inputs, layer_sizes, batchnorm=batchnorm, - dropout=dropout) + layer_sizes = [width*(2**i) for i in range(depth)] + unet = unet_block(inputs, layer_sizes, kernel=kernel, init=init, + activation=activation, conv_pool=conv_pool, + up_activate=up_activate, residual_skip=residual_skip, + enc_repeats=enc_repeats, dec_repeats=dec_repeats, + expand_ratio=expand_ratio, drop=drop, **block_args) model = Model(inputs=[inputs], outputs=make_outputs(unet, self.outputs)) # Todo: tuning optimizer - model.compile(optimizer=Adam(amsgrad=False), + model.compile(optimizer=AdamW(weight_decay=0.00025), metrics=[dice_coeff], loss=bce_dice_loss, loss_weights=self.weights) diff --git a/python/baby/training/segmentation_trainer.py b/python/baby/training/segmentation_trainer.py new file mode 100644 index 0000000000000000000000000000000000000000..68fceece662e4d4964eef6f8601b535cd149bbd3 --- /dev/null +++ b/python/baby/training/segmentation_trainer.py @@ -0,0 +1,982 @@ +# If you publish results that make use of this software or the Birth Annotator +# for Budding Yeast algorithm, please cite: +# Pietsch, J.M.J., Muñoz, A.F., Adjavon, D.-Y.A., Farquhar, I., Clark, I.B.N., +# and Swain, P.S. (2023). Determining growth rates from bright-field images of +# budding cells through identifying overlaps. eLife. 12:e79812. +# https://doi.org/10.7554/eLife.79812 +# +# +# The MIT License (MIT) +# +# Copyright (c) Julian Pietsch, Alán Muñoz and Diane Adjavon 2023 +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to +# deal in the Software without restriction, including without limitation the +# rights to use, copy, modify, merge, publish, distribute, sublicense, and/or +# sell copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING +# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS +# IN THE SOFTWARE. +from math import floor, log10 +from itertools import combinations, product, chain, repeat, islice +from typing import NamedTuple, Union, Tuple, Any +import numpy as np +import pandas as pd +from matplotlib import pyplot as plt +from tqdm import tqdm +# from numba import njit + +from baby import segmentation +from baby.morph_thresh_seg import (MorphSegGrouped, Group, + SegmentationParameters, + BROADCASTABLE_PARAMS) +from baby.utils import batch_iterator, split_batch_pred +from baby.augmentation import ScalingAugmenter +from baby.performance import calc_IoUs, best_IoU, calc_AP +from baby.errors import BadProcess, BadParam, BadType +from .utils import (SharedParameterContainer, SharedDataContainer, + TrainValTestProperty, standard_augmenter) +from .smoothing_model_trainer import SmoothingModelTrainer +from .cnn_trainer import CNNTrainer + + +# Todo add default parameters to SegTrainer +# Search space ! +DEFAULT_SEG_PARAM_COORDS = { + 'nclosing': [0, 1, 2], + 'nopening': [0, 1, 2], + 'interior_threshold': np.arange(0.3, 1.0, 0.05).tolist(), + 'connectivity': [1, 2], + 'edge_sub_dilations': [0, 1, 2] +} + + +class SegExample(NamedTuple): + """CNN output paired with target segmented outlines and info + + Used for optimising segmentation hyperparameters and for training bud + assigment models + """ + pred: np.ndarray + target: np.ndarray + info: dict + img: np.ndarray + + +class Score(NamedTuple): + """Scoring metrics for segmentation performance. + + Metrics are defined in terms of number of true positives ``TP``, number of + false positives ``FP`` and number of false negatives ``FN``. + + Attributes: + precision: ``TP / (TP + FP)``. + recall: ``TP / (TP + FN)``. + F1: Balanced F-score ``2 * precision * recall / (precision + + recall)``. + F0_5: F-score biased for recall ``1.25 * precision * recall / (0.25 * + precision + recall)``. + F2: F-score biased for precision ``5 * precision * recall / (4 * + precision + recall)``. + meanIoU: Mean intersection over union between segmented and ground + truth cell masks. + """ + precision: float + recall: float + F1: float + F0_5: float + F2: float + meanIoU: float + + +def _example_generator(cnn, dgen): + # b_iter = batch_iterator(sorted(dgen.ordering), batch_size=dgen.batch_size) + # Sorting is counterproductive to method of subsample_frac + b_iter = batch_iterator(dgen.ordering, batch_size=dgen.batch_size) + with tqdm(total=len(dgen.ordering)) as pbar: + for b_inds in b_iter: + batch = dgen.parallel_get_indices(b_inds) + + imgs = [img for img, _ in batch] + preds = split_batch_pred( + cnn.predict(np.stack(imgs), verbose=0)) + for pred, (img, (lbl, info)) in zip(preds, batch): + pbar.update() + lbl = lbl.transpose(2, 0, 1) + # Filter out examples that have been augmented away + valid = lbl.sum(axis=(1, 2)) > 0 + lbl = lbl[valid] + info = info.copy() # ensure we do not modify original + clab = info.get('cellLabels', []) or [] + if type(clab) is int: + clab = [clab] + clab = [l for l, v in zip(clab, valid) if v] + info['cellLabels'] = clab + buds = info.get('buds', []) or [] + if type(buds) is int: + buds = [buds] + buds = [b for b, v in zip(buds, valid) if v] + info['buds'] = buds + yield SegExample(pred, lbl, info, img) + + +def _sub_params(sub, param_template): + """Helper function for ``SegFilterParamOptim.fit_filter_params``. + + Substitutes parameters in a ``dict``, possibly group-specific. + """ + p = { + k: v.copy() if type(v) == list else v + for k, v in param_template.items() + } + for (k, g), v in sub.items(): + if g is None: + p[k] = v + else: + p[k][g] = v + return p + + +# @njit +def _filter_trial(containment, containment_thresh, area, min_area, + pedge_thresh, group_thresh_expansion, + gIds, gLs, gUs, gRs, group, + p_edge, max_IoU, IoU_thresh, assignId, nT): + """Numba optimised filter trial implementation. + + TODO: currently the numba version kills the kernel... This implementation + is nonetheless orders of magnitude faster than the previous. + """ + + rejects = (containment > containment_thresh) | (area < min_area) + for t_pe, g_ex, g, l, u, gr in zip(pedge_thresh, + group_thresh_expansion, + gIds, gLs, gUs, gRs): + g_ex = g_ex * gr + l = max(l - g_ex, 1) + u = u + g_ex + rejects |= (group == g) & ((p_edge < t_pe) | (area < l) | (area > u)) + + TP_mask = (~rejects) & (max_IoU >= IoU_thresh) + + # Find the maximum IoU in each eid-assignment group following the + # suggestion on stackoverflow: + # https://stackoverflow.com/questions/8623047/group-by-max-or-min-in-a-numpy-array + # NB: the grouping variables and IoUs are presorted when + # _filter_trial_cache is generated (see below). + TPs_IoU = max_IoU[TP_mask] + if TPs_IoU.size > 0: + grouping = assignId[TP_mask] + index = np.empty(TPs_IoU.size, 'bool') + index[-1] = True + last_in_group = np.any(grouping[1:] != grouping[:-1], axis=1) + index[:-1] = last_in_group + TPs_IoU = TPs_IoU[index] + + nPs = np.sum(~rejects) + nTPs = TPs_IoU.size + nFPs = nPs - nTPs + nFNs = nT - nTPs + precision = nTPs / (nTPs + nFPs) + recall = nTPs / (nTPs + nFNs) + # Fbeta = (1 + beta^2) * P * R / (beta^2 * P + R) + F1 = 2 * precision * recall / (precision + recall) + F0_5 = 1.25 * precision * recall / (0.25 * precision + recall) + F2 = 5 * precision * recall / (4 * precision + recall) + return precision, recall, F1, F0_5, F2, np.mean(TPs_IoU) + + +def _filter_trial_bootstraps(filter_trial_cache, pedge_thresh, + group_thresh_expansion, containment_thresh, + min_area, IoU_thresh, return_stderr=False): + + score_boots = [] + for df in filter_trial_cache: + scr = _filter_trial(df['containment'], containment_thresh, + df['area'], min_area, + np.array(pedge_thresh, dtype='float'), + np.array(group_thresh_expansion, + dtype='float'), + df['gIds'], df['gLs'], df['gUs'], df['gRs'], + df['group'], df['p_edge'], df['max_IoU'], + IoU_thresh, df['assignId'], df['nT']) + score_boots.append(Score(*scr)) + + score_boots = np.array(score_boots) + mean_score = Score(*score_boots.mean(axis=0)) + if return_stderr: + stderr = score_boots.std(axis=0) / np.sqrt(score_boots.shape[0]) + return (mean_score, Score(*stderr)) + else: + return mean_score + + +def _generate_stat_table(s, seg_ex, segrps, containment_func): + """Parallelisation helper for ``SegFilterParamOptim.generate_stat_table``.""" + + ncells = len(seg_ex.target) if seg_ex.target.any() else 0 + + # Perform within-group segmentation + shape = np.squeeze(seg_ex.pred[0]).shape + border_rect = np.pad(np.zeros(tuple(x - 2 for x in shape), + dtype='bool'), + pad_width=1, + mode='constant', + constant_values=True) + + # Generate local versions of the Group segmenters to avoid potential race + # conditions in parallel usage: + segrps = [Group(grp.targets, grp.params) for grp in segrps] + masks = [] + for grp in segrps: + grp.segment(seg_ex.pred, border_rect) + for cell in grp.cells: + masks.append(cell.mask) + + # Calculate containment scores across groups + contained_cells = {} + paired_groups = zip(segrps, segrps[1:]) + for g, (lower_group, upper_group) in enumerate(paired_groups): + for l, lower in enumerate(lower_group.cells): + for u, upper in enumerate(upper_group.cells): + containment = containment_func(lower.mask, upper.mask) + if containment > 0: + if lower.edge_score > upper.edge_score: + contained_cells[(g + 1, u)] = containment + else: + contained_cells[(g, l)] = containment + + if ncells > 0: + IoUs = calc_IoUs(seg_ex.target, masks, fill_holes=False) + max_IoU = IoUs.max(axis=0) + assignments = IoUs.argmax(axis=0) + _, best_assignments = best_IoU(IoUs.T) + else: + max_IoU = np.zeros(len(masks)) + assignments = np.zeros(len(masks), dtype=np.uint16) + best_assignments = -np.ones(len(masks), dtype=np.int32) + + ind = 0 + rows = [] + for g, grp in enumerate(segrps): + for c, cell in enumerate(grp.cells): + rows.append((s, g, c, cell.area, cell.edge_score, + contained_cells.get((g, c), 0.), + assignments[ind], max_IoU[ind], + best_assignments[ind])) + ind += 1 + + return (s, ncells), rows + +def _seg_filter_optim( + group_id, + new_values, + param_names, + flattener, + seg_gen, + base_params=SegmentationParameters(), + scoring='F0_5', + n_jobs=4): + """Parallelisation helper for ``SegmentationTrainer.fit_seg_params``.""" + + # Replace default params with specified ones + new_params = {} + for k, v in zip(param_names, new_values): + if k in BROADCASTABLE_PARAMS: + if k not in new_params: + gval = getattr(base_params, k) + if type(gval) != list: + raise BadType('Segmentation parameters should ' + 'have been broadcast.') + new_params[k] = gval.copy() + new_params[k][group_id] = v + else: + new_params[k] = v + new_params = base_params._replace(**new_params) + + # Optimise filtering parameters and return score + sfpo = SegFilterParamOptim(flattener, base_params=new_params, + scoring=scoring) + with np.errstate(all='ignore'): + sfpo.generate_stat_table(seg_gen, n_jobs=n_jobs) + sfpo.fit_filter_params(lazy=True, bootstrap=False, n_jobs=n_jobs) + return { + 'group': group_id, + 'base': {k: getattr(new_params, k) for k in param_names}, + 'filter': sfpo.opt_params, + 'score': sfpo.opt_score, + 'scores': sfpo.filter_trial(**sfpo.opt_params, bootstrap=False) + } + + +round_to_n = lambda x, n: round(x, -int(floor(log10(x))) + (n - 1)) + + +class SegFilterParamOptim: + """ + # TODO What does this class do + * What are the parameters and what do they mean + * What are the defaults, what are the ranges/admissible options? + :param flattener: + :param base_params: + :param IoU_thresh: + :param scoring: + :param nbootstraps: + :param bootstrap_frac: + """ + def __init__(self, + flattener, + base_params=SegmentationParameters(), + IoU_thresh=0.5, + scoring='F0_5', + nbootstraps=10, + bootstrap_frac=0.9): + + self.IoU_thresh = IoU_thresh + self.scoring = scoring + self.nbootstraps = nbootstraps + self.bootstrap_frac = bootstrap_frac + + # Force values for parameters optimised by this class + self._base_params = base_params._replace( + fit_radial=True, min_area=1, pedge_thresh=None, + use_group_thresh=False, group_thresh_expansion=0, + containment_thresh=0.8) + + self.segmenter = MorphSegGrouped(flattener, + params=self._base_params, + return_masks=True) + + self.group_info = [] + for g, group in enumerate(self.segrps): + lower = min( + target.definition.get('lower', 1.) + for target in group.targets) + upper = max( + target.definition.get('upper', float('inf')) + for target in group.targets) + grange = lower if upper == float('inf') else upper - lower + self.group_info.append((g, lower, upper, grange)) + + @property + def scoring(self): + """ The scoring method used during evaluation of the segmentation. + + Accepted values are any of the named attributes of :py:class:`Score` + specified as a ``str``. + """ + return self._scoring + + @scoring.setter + def scoring(self, val): + if val not in Score._fields: + raise BadParam('Specified scoring metric not available') + self._scoring = val + + @property + def base_params(self): + return self._base_params + + @property + def segrps(self): + return self.segmenter.groups + + @property + def stat_table(self): + val = getattr(self, '_stat_table', None) + if val is None: + raise BadProcess('"generate_stat_table" has not been run') + return val + + @property + def stat_table_bootstraps(self): + val = getattr(self, '_stat_table_bootstraps', None) + if val is None: + raise BadProcess('"generate_stat_table" has not been run') + return val + + @property + def truth(self): + val = getattr(self, '_nPs', None) + if val is None: + raise BadProcess('"generate_stat_table" has not been run') + return val + + @property + def truth_bootstraps(self): + val = getattr(self, '_nPs_bootstraps', None) + if val is None: + raise BadProcess('"generate_stat_table" has not been run') + return val + + @property + def opt_params(self): + val = getattr(self, '_opt_params', None) + if val is None: + raise BadProcess('"fit_filter_params" has not been run') + return val + + @property + def opt_score(self): + val = getattr(self, '_opt_score', None) + if val is None: + raise BadProcess('"fit_filter_params" has not been run') + return val + + def generate_stat_table(self, example_gen, n_jobs=4): + """Generates unfiltered segmentation results organised as a table. + + The generated output can be accessed as a ``pandas.DataFrame`` from + :py:attr:`stat_table` and as a list of bootstrap samples from + :py:attr:`stat_table_bootstraps`. + + Note: + This function is called from + :py:meth:`SegmentationTrainer.fit_seg_params` via + :py:func:`_seg_filter_optim` and also from + :py:meth:`SegmentationTrainer.refit_filter_seg_params`. + """ + containment_func = getattr(segmentation, + self.base_params.containment_func) + rows_truth = [] + rows = [] + + from joblib import Parallel, delayed + rows = Parallel(n_jobs=n_jobs)( + delayed(_generate_stat_table)(s, seg_ex, self.segrps, + containment_func) + for s, seg_ex in enumerate(example_gen)) + + rows_truth, rows = zip(*rows) + rows = list(chain(*rows)) + + df_truth = pd.DataFrame(rows_truth, columns=['example', 'ncells']) + df_truth = df_truth.set_index('example') + self._nPs = df_truth.ncells.to_numpy() + + dtypes = [('example', np.uint16), + ('group', np.uint8), + ('cell', np.uint16), + ('area', np.uint16), + ('p_edge', np.float64), + ('containment', np.float64), + ('assignments', np.uint16), + ('max_IoU', np.float64), + ('best_assignments', np.int32)] + df = pd.DataFrame(np.array(rows, dtype=dtypes)) + + df['is_best'] = ((df.best_assignments >= 0) & + (df.max_IoU >= self.IoU_thresh)) + df['eid'] = df.example + df['uid'] = tuple(zip(df.example, df.assignments)) + + # Generate a set of bootstrapping filters over 90% of the examples + examples = list(set(df_truth.index.values)) + nperboot = np.round(self.bootstrap_frac * len(examples)).astype(int) + bootstraps = [ + np.random.choice(examples, nperboot, replace=True) + for _ in range(self.nbootstraps) + ] + self._nPs_bootstraps = [df_truth.loc[b].sum() for b in bootstraps] + # Limit bootstrap examples to those present in segmentation output + bootstraps = [b[np.isin(b, df.example)] for b in bootstraps] + df.set_index('example', drop=False, inplace=True) + example_counts = df.example.value_counts() + self._stat_table_bootstraps = [] + for b in bootstraps: + df_boot = df.loc[b] + # Renumber examples to handle the case of duplicated examples in + # the bootstrap: + df_boot['eid'] = tuple( + chain(*(repeat(i, example_counts.loc[e]) + for i, e in enumerate(b)))) + df_boot['uid'] = tuple(zip(df_boot.eid, df_boot.assignments)) + df_boot.set_index('uid', inplace=True) + self._stat_table_bootstraps.append(df_boot) + + df.set_index('uid', inplace=True) + self._stat_table = df + + def _ensure_filter_trial_cache(self, bootstrap): + if getattr(self, '_filter_trial_cache') is None: + if bootstrap: + dfs = self.stat_table_bootstraps + truths = self.truth_bootstraps + else: + dfs = [self.stat_table] + truths = [self.truth.sum()] + self._filter_trial_cache = [] + for df, nT in zip(dfs, truths): + # From stackoverflow: + # https://stackoverflow.com/questions/8623047/group-by-max-or-min-in-a-numpy-array + # We need to pre-sort the DataFrame by a combined grouping + # metric according to eid and assignments and IoU + df = df.sort_values(['eid','assignments','max_IoU']) + # Also need to reformat the group_info + gIds, gLs, gUs, gRs = zip(*self.group_info) + cache = dict(nT=nT, gIds=gIds, + gLs=np.array(gLs, dtype='float'), + gUs=np.array(gUs, dtype='float'), + gRs=np.array(gRs, dtype='float'), + containment=df.containment.to_numpy(), + area=df.area.to_numpy(), + p_edge=df.p_edge.to_numpy(), + group=df.group.to_numpy(), + max_IoU=df.max_IoU.to_numpy(), + assignId=df[['eid','assignments']].to_numpy()) + self._filter_trial_cache.append(cache) + + def filter_trial(self, + pedge_thresh, + group_thresh_expansion, + containment_thresh, + min_area, + bootstrap=True, + return_stderr=False, + use_cache=False): + + if not use_cache: + self._filter_trial_cache = None + self._ensure_filter_trial_cache(bootstrap) + + return _filter_trial_bootstraps(self._filter_trial_cache, + pedge_thresh, group_thresh_expansion, + containment_thresh, min_area, + self.IoU_thresh, + return_stderr=return_stderr) + + def parallel_filter_trials(self, params_list, score=None, bootstrap=True, + return_stderr=False, use_cache=False, n_jobs=4): + if not use_cache: + self._filter_trial_cache = None + self._ensure_filter_trial_cache(bootstrap) + cache = self._filter_trial_cache + + from joblib import Parallel, delayed + mean_scores = Parallel(n_jobs=n_jobs)( + delayed(_filter_trial_bootstraps)(cache, **params, + IoU_thresh=self.IoU_thresh, + return_stderr=return_stderr) + for params in params_list) + + if return_stderr: + mean_scores, stderrs = zip(*mean_scores) + if score is not None: + mean_scores = [getattr(s, score) for s in mean_scores] + stderrs = [getattr(s, score) for s in stderrs] + return mean_scores, stderrs + else: + if score is not None: + mean_scores = [getattr(s, score) for s in mean_scores] + return mean_scores + + def fit_filter_params(self, lazy=False, bootstrap=False, n_jobs=4): + # Define parameter grid values, firstly those not specific to a group + params = { + ('containment_thresh', None): np.linspace(0, 1, 21), + ('min_area', None): np.arange(0, 20, 1) + } + + # Ensure we reset the filter trial cache + self._filter_trial_cache = None + self._ensure_filter_trial_cache(bootstrap) + + # Determine the pedge_threshold range based on the observed p_edge + # range for each group + if (self.stat_table.is_best.all() or not self.stat_table.is_best.any()): + t_pe_upper = self.stat_table.groupby('group').p_edge.mean() + else: + q_pe = self.stat_table.groupby(['group', 'is_best']) + q_pe = q_pe.p_edge.quantile([0.25, 0.95]).unstack((1, 2)) + t_pe_upper = q_pe.loc[:, [(False, 0.95), (True, 0.25)]].mean(1) + + t_pe_vals = [ + np.arange(0, u, round_to_n(u / 20, 1)) for u in t_pe_upper + ] + + # Set group-specific parameter grid values + g_ex_vals = repeat(np.linspace(0, 0.4, 21)) + for g, (t_pe, g_ex) in enumerate(zip(t_pe_vals, g_ex_vals)): + params[('pedge_thresh', g)] = t_pe + params[('group_thresh_expansion', g)] = g_ex + + # Default starting point is with thresholds off and no group expansion + ngroups = len(self.segrps) + dflt_params = { + 'containment_thresh': 0, + 'min_area': 0, + 'pedge_thresh': list(repeat(0, ngroups)), + 'group_thresh_expansion': list(repeat(0, ngroups)) + } + + # Search first along each parameter dimension with all others kept at + # default: + opt_params = {} + for k, pvals in params.items(): + params_list = [_sub_params({k: v}, dflt_params) for v in pvals] + scrs = self.parallel_filter_trials(params_list, + score=self.scoring, + bootstrap=bootstrap, + use_cache=True, n_jobs=n_jobs) + maxInd = np.argmax(scrs) + opt_params[k] = pvals[maxInd] + + # Reset the template parameters to the best along each dimension + base_params = _sub_params(opt_params, dflt_params) + + if lazy: + # Simply repeat search along each parameter dimension, but now + # using the new optimum as a starting point + opt_params = {} + for k, pvals in params.items(): + params_list = [_sub_params({k: v}, base_params) for v in pvals] + scrs = self.parallel_filter_trials(params_list, + score=self.scoring, + bootstrap=bootstrap, + use_cache=True, n_jobs=n_jobs) + maxInd = np.argmax(scrs) + opt_params[k] = pvals[maxInd] + opt_params = _sub_params(opt_params, base_params) + scr = self.filter_trial(**opt_params, bootstrap=bootstrap, + use_cache=True) + self._opt_params = opt_params + self._opt_score = getattr(scr, self.scoring) + return + + # Next perform a joint search for parameters with optimal pairings + opt_param_pairs = {k: {v} for k, v in opt_params.items()} + for k1, k2 in combinations(params.keys(), 2): + vals = list(product(params[k1], params[k2])) + params_list = [_sub_params({ k1: v1, k2: v2 }, base_params) + for v1, v2 in vals] + scrs = self.parallel_filter_trials(params_list, + score=self.scoring, + bootstrap=bootstrap, + use_cache=True, n_jobs=n_jobs) + maxInd = np.argmax(scrs) + p1opt, p2opt = vals[maxInd] + opt_param_pairs[k1].add(p1opt) + opt_param_pairs[k2].add(p2opt) + + # Finally search over all combinations of the parameter values found + # with optimal pairings + params_list = [_sub_params({k: v for k, v in + zip(opt_param_pairs.keys(), pvals)}, + base_params) + for pvals in product(*opt_param_pairs.values())] + scrs = self.parallel_filter_trials(params_list, + score=self.scoring, + bootstrap=bootstrap, + use_cache=True, n_jobs=n_jobs) + maxInd = np.argmax(scrs) + self._opt_params = params_list[maxInd] + self._opt_score = scrs[maxInd] + + +class SegmentationTrainer(object): + """Finds optimal segmentation parameters given a trained CNN. + + Args: + shared_params: Training and segmentation parameters as provided by + :py:class:`utils.SharedParameterContainer`. + shared_data: Training data as provided by + :py:class:`utils.SharedDataContainer`. + ssm_trainer: SmoothingModelTrainer with optimised model for + determination of smoothing sigma. + cnn_trainer: Trainer with optimised CNN. + """ + def __init__(self, + shared_params: SharedParameterContainer, + shared_data: SharedDataContainer, + ssm_trainer: SmoothingModelTrainer, + cnn_trainer: CNNTrainer): + self._shared_params = shared_params + self._shared_data = shared_data + self._ssm_trainer = ssm_trainer + self._cnn_trainer = cnn_trainer + + @property + def save_dir(self): + """Base directory in which to save trained models""" + return self._shared_params.save_dir + + @property + def training_parameters(self): + return self._shared_params.parameters + + @property + def segment_parameters(self): + return self._shared_params.segmentation_parameters + + @segment_parameters.setter + def segment_parameters(self, val): + self._shared_params.segmentation_parameters = val + + @property + def segment_parameter_coords(self): + param_coords = DEFAULT_SEG_PARAM_COORDS.copy() + param_coords.update(self.training_parameters.seg_param_coords) + return param_coords + + @property + def gen(self): + """Training, validation and test data generators with raw output. + + This attribute provides three :py:class:`ImageLabel` generators as a + :py:class:`TrainValTestProperty`, with each generator assigned a + :py:class:`ScalingAugmenter` that outputs an ``(image, label, info)`` + tuple, where ``label`` provides the unflattened ``ndarray`` of cell + outlines and ``info`` is a ``dict`` of meta-data associated with the + label (if any). Augmentations are limited to just cropping and scaling + operations to match the intended pixel size and input size of the CNN. + """ + # Create an augmenter that returns unflattened label images + aug = standard_augmenter( + self._ssm_trainer.model, + lambda lbl, _: lbl, + self.training_parameters, + isval=True) + + def seg_example_aug(img, lbl): + # Assume that the label preprocessing function also returns info + _, info = lbl + img, lbl = aug(img, lbl) + return img, (lbl > 0, info) + + # In this case, always prefer the validation augmenter + return self._shared_data.gen_with_aug(seg_example_aug) + + @property + def examples(self): + """Training, validation and test segmentation example generators. + + This attribute provides three generators as a + :py:class:`TrainValTestProperty`, with each generator yielding + a :py:class:`SegExample` for each image-label pair in the training, + validation or test image data collections. + """ + # Ensure that the saved generators are updated if more data is + # added... + if getattr(self, '_ncells', None) != self._shared_data.data.ncells: + self._seg_examples = None + self._ncells = self._shared_data.data.ncells + # ...or if data generation parameters change: + old_gen_params = getattr(self, '_current_gen_params', None) + new_gen_params = tuple(getattr(self.training_parameters, p) for p in + ('in_memory', 'input_norm_dw', 'batch_size', + 'balanced_sampling', 'use_sample_weights', + 'xy_out', 'target_pixel_size', 'substacks')) + if old_gen_params != new_gen_params: + self._seg_examples = None + self._current_gen_params = new_gen_params + + cnn = self._cnn_trainer.opt_cnn + gen = self.gen + if self.training_parameters.in_memory: + if getattr(self, '_seg_examples', None) is None: + self._seg_examples = TrainValTestProperty( + list(_example_generator(cnn, gen.train)), + list(_example_generator(cnn, gen.val)), + list(_example_generator(cnn, gen.test))) + return TrainValTestProperty( + (e for e in self._seg_examples.train), + (e for e in self._seg_examples.val), + (e for e in self._seg_examples.test)) + else: + self._seg_examples = None + return TrainValTestProperty( + _example_generator(cnn, gen.train), + _example_generator(cnn, gen.val), + _example_generator(cnn, gen.test)) + + @property + def flattener(self): + return self._cnn_trainer.flattener + + @property + def seg_param_stats(self): + if getattr(self, '_seg_param_stats', None) is None: + seg_stats_file = self.training_parameters.segmentation_stats_file + stats_file = self.save_dir / seg_stats_file + if not stats_file.is_file(): + raise BadProcess('"fit_seg_params" has not been run yet') + self._seg_param_stats = pd.read_csv(stats_file, index_col=0) + return self._seg_param_stats + + def fit_seg_params(self, n_jobs=None, scoring='F0_5', subsample_frac=1., + fit_on_split='val'): + """Find optimal segmentation hyperparameters. + + Args: + njobs (int): Number of parallel processes to run. + scoring (str): Scoring metric to be used to assess segmentation + performance. The name of any of the attributes in + :py:class:`Score` may be specified. + """ + + if n_jobs is None: + n_jobs = self.training_parameters.n_jobs + + # Initialise the default segmenter to determine the number of groups + # and obtain broadcast base parameters: + segmenter = MorphSegGrouped( + self.flattener, + params=self.segment_parameters) + ngroups = len(segmenter.groups) + base_params = segmenter.params + + # Generate parameter search grid according to training parameters + param_coords = self.segment_parameter_coords + param_grid = list(product(*param_coords.values())) + par_names = list(param_coords.keys()) + + if type(subsample_frac) == float: + subsample_frac = (subsample_frac,) + if type(fit_on_split) == str: + fit_on_split = (fit_on_split,) + if len(subsample_frac) == 1 and len(fit_on_split) > 1: + subsample_frac = subsample_frac * len(fit_on_split) + if len(fit_on_split) == 1 and len(subsample_frac) > 1: + fit_on_split = fit_on_split * len(subsample_frac) + + examples = [] + for ssf, split in zip(subsample_frac, fit_on_split): + step = max(int(np.floor(1. / ssf)), 1) + # NB: the following essentially assumes that examples are + # presented in random order (see _example_generator) + examples.extend(list(islice( + getattr(self.examples, split), None, None, step))) + + rows = [] + for gind in range(ngroups)[::-1]: + for pars in tqdm(param_grid): + rows.append(_seg_filter_optim(gind, pars, par_names, + self.flattener, examples, + base_params=base_params, + scoring=scoring, n_jobs=n_jobs)) + + rows_expanded = [] + for row in rows: + row_details = chain( + [('group', row['group']), ('score', row['score'])], + row['scores']._asdict().items(), + row['base'].items(), row['filter'].items()) + row_expanded = [] + for k, v in row_details: + if k in BROADCASTABLE_PARAMS and type(v) is list: + kvpairs = [('_'.join((k, str(g))), gv) + for g, gv in enumerate(v)] + else: + kvpairs = [(k, v)] + row_expanded.extend(kvpairs) + rows_expanded.append(dict(row_expanded)) + + # TODO: if None values are combined with integer values, the entire + # column gets converted here to float, with the None values as NaN. + # This causes errors, for example, with specification of + # edge_sub_dilations. There is currently no obvious solution to this. + self._seg_param_stats = pd.DataFrame(rows_expanded) + stats_file = (self.save_dir / + self.training_parameters.segmentation_stats_file) + self._seg_param_stats.to_csv(stats_file) + + self.refit_filter_seg_params(scoring=scoring, n_jobs=n_jobs) + + def refit_filter_seg_params(self, + lazy=False, + bootstrap=False, + scoring='F0_5', + n_jobs=None): + + if n_jobs is None: + n_jobs = self.training_parameters.n_jobs + + # Initialise the default segmenter to determine the number of groups + # and obtain broadcast base parameters: + segmenter = MorphSegGrouped( + self.flattener, + params=self.segment_parameters) + ngroups = len(segmenter.groups) + base_params = segmenter.params + + # Merge the best parameters from each group into a single parameter set + par_names = list(self.segment_parameter_coords.keys()) + broadcast_par_names = [k for k in par_names + if k in BROADCASTABLE_PARAMS] + merged_params = {k: getattr(base_params, k) for k in par_names} + for k in broadcast_par_names: + merged_params[k] = merged_params[k].copy() + stats = self.seg_param_stats + for g, r in enumerate(stats.groupby('group').score.idxmax()): + for k in broadcast_par_names: + merged_params[k][g] = stats.loc[r, k + '_' + str(g)] + merged_params = base_params._replace(**merged_params) + + sfpo = SegFilterParamOptim(self.flattener, + base_params=merged_params, + scoring=scoring) + with np.errstate(all='ignore'): + sfpo.generate_stat_table(list(self.examples.val), n_jobs=n_jobs) + sfpo.fit_filter_params(lazy=lazy, bootstrap=bootstrap, + n_jobs=n_jobs) + + self.segment_parameters = merged_params._replace(**sfpo.opt_params) + + def validate_seg_params( + self, + iou_thresh=0.7, + save=True, + refine_outlines=True): + segmenter = MorphSegGrouped(self.flattener, + params=self.segment_parameters, + return_masks=True) + edge_inds = [ + i for i, t in enumerate(self.flattener.targets) + if t.prop == 'edge' + ] + stats = {} + dfs = {} + for k, seg_exs in zip(self.examples._fields, self.examples): + stats[k] = [] + for seg_ex in seg_exs: + seg = segmenter.segment(seg_ex.pred, + refine_outlines=refine_outlines) + edge_scores = np.array([ + seg_ex.pred[edge_inds, ...].max(axis=0)[s].mean() + for s in seg.edges + ]) + IoUs = calc_IoUs(seg_ex.target, seg.masks) + bIoU, _ = best_IoU(IoUs) + stats[k].append((edge_scores, IoUs, np.mean(bIoU), + np.min(bIoU, initial=1), + calc_AP(IoUs, + probs=edge_scores, + iou_thresh=iou_thresh)[0])) + dfs[k] = pd.DataFrame([s[2:] for s in stats[k]], + columns=['IoU_mean', 'IoU_min', 'AP']) + + print({k: df.mean() for k, df in dfs.items()}) + + nrows = len(dfs) + ncols = dfs['val'].shape[1] + fig, axs = plt.subplots(nrows=nrows, + ncols=ncols, + figsize=(ncols * 4, nrows * 4)) + for axrow, (k, df) in zip(axs, dfs.items()): + for ax, col in zip(axrow, df.columns): + ax.hist(df.loc[:, col], bins=26, range=(0, 1)) + ax.set(xlabel=col, title=k) + if save: + fig.savefig(self.save_dir / 'seg_validation_plot.png') + plt.close(fig) + diff --git a/python/baby/training/smoothing_model_trainer.py b/python/baby/training/smoothing_model_trainer.py index 00388d2e8e7d7e2c4249ba36abdd646e19d2026c..9960a061528f7296dfab6e458e6badb129f7f8c8 100644 --- a/python/baby/training/smoothing_model_trainer.py +++ b/python/baby/training/smoothing_model_trainer.py @@ -1,12 +1,14 @@ # If you publish results that make use of this software or the Birth Annotator # for Budding Yeast algorithm, please cite: -# Julian M J Pietsch, Alán Muñoz, Diane Adjavon, Ivan B N Clark, Peter S -# Swain, 2021, Birth Annotator for Budding Yeast (in preparation). +# Pietsch, J.M.J., Muñoz, A.F., Adjavon, D.-Y.A., Farquhar, I., Clark, I.B.N., +# and Swain, P.S. (2023). Determining growth rates from bright-field images of +# budding cells through identifying overlaps. eLife. 12:e79812. +# https://doi.org/10.7554/eLife.79812 # # # The MIT License (MIT) # -# Copyright (c) Julian Pietsch, Alán Muñoz and Diane Adjavon 2021 +# Copyright (c) Julian Pietsch, Alán Muñoz and Diane Adjavon 2023 # # Permission is hereby granted, free of charge, to any person obtaining a copy # of this software and associated documentation files (the "Software"), to @@ -31,8 +33,10 @@ import pandas as pd from baby.augmentation import _filled_canny, _apply_crop, SmoothingSigmaModel from baby.errors import BadProcess, BadType from baby.preprocessing import dwsquareconn -from baby.segmentation import binary_edge, squareconn, morph_radial_thresh_fit, \ - draw_radial, mask_iou +from baby.generator import augmented_generator +from baby.segmentation import (binary_edge, squareconn, mask_to_knots, + draw_radial, mask_iou) +from baby.morph_thresh_seg import SegmentationParameters from baby.utils import find_file from numpy.polynomial import Polynomial from scipy.ndimage import binary_fill_holes @@ -41,10 +45,10 @@ from skimage import filters, transform from skimage.measure import regionprops from tqdm import trange -from .utils import augmented_generator, TrainValProperty +from .utils import TrainValProperty -def _generate_smoothing_sigma_stats(gen): +def _generate_smoothing_sigma_stats(gen, params): sigmas = np.arange(0.4, 5.0, 0.20) rotations = np.arange(7, 45, 7) scaling = np.linspace(0.5, 1.5, 6) @@ -52,7 +56,7 @@ def _generate_smoothing_sigma_stats(gen): square_edge = lambda m: binary_edge(m, squareconn) smoothing_stats = [] for t in trange(len(gen.paths)): - _, (segs, _) = gen.get_by_index(t) + segs = gen.get_by_index(t) if segs.shape[2] == 1 and segs.sum() == 0: continue @@ -64,13 +68,40 @@ def _generate_smoothing_sigma_stats(gen): sfill = segs_fill[..., c] sedge = segs_edge[..., c] nedge = segs[..., c].sum() + area = sfill.sum() + + # Add padding to make input image square + nx, ny = sfill.shape[:2] + if nx != ny: + rprops = regionprops(sfill.astype('int'))[0] + min_rr, min_cc, max_rr, max_cc = rprops.bbox + pad = np.round(0.25 * max(max_rr - min_rr, max_cc - min_cc)).astype(int) + min_rr = max(min_rr - pad, 0) + min_cc = max(min_cc - pad, 0) + max_rr = min(max_rr + pad, nx) + max_cc = min(max_cc + pad, ny) + sfill = sfill[min_rr:max_rr, min_cc:max_cc] + sedge = sedge[min_rr:max_rr, min_cc:max_cc] + nx, ny = sfill.shape[:2] + + # Add padding to ensure features do not rotate out of image limits + xpad = max(nx, ny) - nx + ypad = max(nx, ny) - ny + xlpad = xpad // 2 + ylpad = ypad // 2 + sfill = np.pad(sfill, ((xlpad, xpad - xlpad), + (ylpad, ypad - ylpad)), mode='constant') + sedge = np.pad(sedge, ((xlpad, xpad - xlpad), + (ylpad, ypad - ylpad)), mode='constant') - # fit radial spline to generate accurate reference edges for - # resize transformation rprops = regionprops(sfill.astype('int'))[0] centre = np.array(rprops.centroid) - radii, angles = morph_radial_thresh_fit(sedge, sfill, rprops) - genedge = draw_radial(radii, angles, centre, sedge.shape) + + # fit spline to generate accurate reference edges for resize + # transformation. We do not have a predicted edge in this case, so + # supply the ground truth edge for the fitting routine. + (_, radii, angles), genedge = mask_to_knots( + sfill, p_edge=sedge, **params._asdict()) genfill = binary_fill_holes(genedge, squareconn) # Limit the number of rotation and scaling operations by @@ -82,6 +113,7 @@ def _generate_smoothing_sigma_stats(gen): sblur = filters.gaussian(sfill, s) genblur = filters.gaussian(genfill, s) + # Identity transformation (on raw edge) spf = _filled_canny(sblur) smoothing_stats += [{ 'ind': t, @@ -90,18 +122,20 @@ def _generate_smoothing_sigma_stats(gen): 'rotation': 0, 'scaling': 1, 'nedge': nedge, + 'area': area, 'iou': mask_iou(spf, sfill), 'edge_iou': mask_iou(square_edge(spf), sedge) - }] + }] + # Rotation transformation (on raw edge) sr = transform.rotate(sblur, - angle=r, - mode='reflect', - resize=True) + angle=r, + mode='reflect', + resize=True) sr = transform.rotate(sr, - angle=-r, - mode='reflect', - resize=False) + angle=-r, + mode='reflect', + resize=False) srf = _filled_canny(_apply_crop(sr, spf.shape)) smoothing_stats += [{ 'ind': t, @@ -110,15 +144,17 @@ def _generate_smoothing_sigma_stats(gen): 'rotation': r, 'scaling': 1, 'nedge': nedge, + 'area': area, 'iou': mask_iou(srf, sfill), 'edge_iou': mask_iou(square_edge(srf), sedge) - }] + }] + # Scaling transformation (on generated edge) insize = np.array(spf.shape) outsize = np.round(insize * z).astype('int') centre_sc = outsize / 2 + z * (centre - insize / 2) genedge_sc = draw_radial(z * radii, angles, centre_sc, - outsize) + outsize, cartesian_spline=params.cartesian_spline) genfill_sc = binary_fill_holes(genedge_sc, squareconn) sd = transform.resize(genblur, outsize, anti_aliasing=False) sdf = _filled_canny(sd) @@ -129,25 +165,52 @@ def _generate_smoothing_sigma_stats(gen): 'rotation': 0, 'scaling': z, 'nedge': nedge, + 'area': area, 'iou': mask_iou(sdf, genfill_sc), 'edge_iou': mask_iou(square_edge(sdf), genedge_sc) - }] + }] return pd.DataFrame(smoothing_stats) class SmoothingModelTrainer: - def __init__(self, save_dir, stats_file, model_file): - self.save_dir = save_dir - self.stats_file = save_dir / stats_file - self.model_file = save_dir / model_file + """Trains the smoothing model for augmentations on binary masks. + + Args: + shared_params (utils.SharedParameterContainer): training and + segmentation parameters as provided by + :py:class:`utils.SharedParameterContainer`. + shared_data (utils.SharedDataContainer): Access to training data. + """ + def __init__(self, shared_params, shared_data): + self._shared_params = shared_params + self._shared_data = shared_data self._model = None self._stats = None - def generate_smoothing_sigma_stats(self, train_gen, val_gen): - with augmented_generator(train_gen, lambda x, y: (x, y)) as gen: - sss_train = _generate_smoothing_sigma_stats(gen) - with augmented_generator(val_gen, lambda x, y: (x, y)) as gen: - sss_val = _generate_smoothing_sigma_stats(gen) + @property + def save_dir(self): + return self._shared_params.save_dir + + @property + def stats_file(self): + return (self.save_dir / + self._shared_params.parameters.smoothing_sigma_stats_file) + + @property + def model_file(self): + return (self.save_dir / + self._shared_params.parameters.smoothing_sigma_model_file) + + @property + def segment_params(self): + return self._shared_params.segmentation_parameters + + def generate_smoothing_sigma_stats(self, aug=lambda x, y: y[0]): + train_gen, val_gen, _ = self._shared_data.gen + with augmented_generator(train_gen, aug) as gen: + sss_train = _generate_smoothing_sigma_stats(gen, self.segment_params) + with augmented_generator(val_gen, aug) as gen: + sss_val = _generate_smoothing_sigma_stats(gen, self.segment_params) sss_train['validation'] = False sss_val['validation'] = True sss = pd.concat((sss_train, sss_val)) @@ -158,10 +221,10 @@ class SmoothingModelTrainer: if self._stats is None: if not self.stats_file.exists(): raise BadProcess( - 'smoothing sigma stats have not been generated') + 'smoothing sigma stats have not been generated') self._stats = pd.read_csv(self.stats_file) return TrainValProperty(self._stats[~self._stats['validation']], - self._stats[self._stats['validation']]) + self._stats[self._stats['validation']]) @property def model(self): @@ -172,7 +235,7 @@ class SmoothingModelTrainer: self._model = smoothing_sigma_model else: raise BadProcess( - 'The "smoothing_sigma_model" has not been assigned yet') + 'The "smoothing_sigma_model" has not been assigned yet') return self._model @model.setter @@ -183,8 +246,8 @@ class SmoothingModelTrainer: ssm.load(ssm_file) if not isinstance(ssm, SmoothingSigmaModel): raise BadType( - '"smoothing_sigma_model" must be of type "baby.augmentation.SmoothingSigmaModel"' - ) + '"smoothing_sigma_model" must be of type "baby.augmentation.SmoothingSigmaModel"' + ) ssm.save(self.model_file) self._model = ssm @@ -194,10 +257,10 @@ class SmoothingModelTrainer: stats = self.stats.train stats = stats.groupby(idcols).apply(group_best_iou) filts = { - 'identity': (stats.scaling == 1) & (stats.rotation == 0), - 'scaling': stats.scaling != 1, - 'rotation': stats.rotation != 0 - } + 'identity': (stats.scaling == 1) & (stats.rotation == 0), + 'scaling': stats.scaling != 1, + 'rotation': stats.rotation != 0 + } return stats, filts def fit(self, filt='identity'): @@ -210,8 +273,8 @@ class SmoothingModelTrainer: b = 10 # initial guess for offset term in final model # Fit s = c + m * log(n - b); want n = b + exp((s - c)/m) pinv = Polynomial.fit(np.log(np.clip(stats.nedge - b, 1, None)), - stats.sigma, - deg=1) + stats.sigma, + deg=1) c = pinv(0) m = pinv(1) - c @@ -228,34 +291,34 @@ class SmoothingModelTrainer: params = (self.model._a, self._model._b, self.model._c) fig, axs = plt.subplots(2, - len(filts), - figsize=(12, 12 * 2 / len(filts))) + len(filts), + figsize=(12, 12 * 2 / len(filts))) sigma_max = stats.sigma.max() nedge_max = stats.nedge.max() sigma = np.linspace(0, sigma_max, 100) for ax, (k, f) in zip(axs[0], filts.items()): ax.scatter(stats[f].sigma, - stats[f].nedge, - 16, - alpha=0.05, - edgecolors='none') + stats[f].nedge, + 16, + alpha=0.05, + edgecolors='none') ax.plot(sigma, model(sigma, *params), 'r') ax.set(title=k.title(), - xlabel='sigma', - ylabel='nedge', - ylim=[0, nedge_max]) + xlabel='sigma', + ylabel='nedge', + ylim=[0, nedge_max]) nedge = np.linspace(1, nedge_max, 100) for ax, (k, f) in zip(axs[1], filts.items()): ax.scatter(stats[f].nedge, - stats[f].sigma, - 16, - alpha=0.05, - edgecolors='none') + stats[f].sigma, + 16, + alpha=0.05, + edgecolors='none') ax.plot(nedge, [self.model(n) for n in nedge], 'r') ax.set(title=k.title(), - xlabel='nedge', - ylabel='sigma', - ylim=[0, sigma_max]) + xlabel='nedge', + ylabel='sigma', + ylim=[0, sigma_max]) fig.tight_layout() fig.savefig(self.save_dir / 'fitted_smoothing_sigma_model.png') diff --git a/python/baby/training/training.py b/python/baby/training/training.py index ad9f87c502a3ff84075323eb8fb3438d8c744b54..1c1738e341a7bc52504a4c2a4f95ad9c6fa777af 100644 --- a/python/baby/training/training.py +++ b/python/baby/training/training.py @@ -1,12 +1,14 @@ # If you publish results that make use of this software or the Birth Annotator # for Budding Yeast algorithm, please cite: -# Julian M J Pietsch, Alán Muñoz, Diane Adjavon, Ivan B N Clark, Peter S -# Swain, 2021, Birth Annotator for Budding Yeast (in preparation). +# Pietsch, J.M.J., Muñoz, A.F., Adjavon, D.-Y.A., Farquhar, I., Clark, I.B.N., +# and Swain, P.S. (2023). Determining growth rates from bright-field images of +# budding cells through identifying overlaps. eLife. 12:e79812. +# https://doi.org/10.7554/eLife.79812 # # # The MIT License (MIT) # -# Copyright (c) Julian Pietsch, Alán Muñoz and Diane Adjavon 2021 +# Copyright (c) Julian Pietsch, Alán Muñoz and Diane Adjavon 2023 # # Permission is hereby granted, free of charge, to any person obtaining a copy # of this software and associated documentation files (the "Software"), to @@ -28,7 +30,6 @@ import inspect import json import pickle -import shutil import warnings from itertools import product, chain from pathlib import Path @@ -37,98 +38,50 @@ from typing import NamedTuple import numpy as np import pandas as pd import tensorflow as tf -from baby.brain import default_params -from baby.morph_thresh_seg import MorphSegGrouped -from baby.performance import calc_IoUs, best_IoU, calc_AP -from baby.seg_trainer import SegFilterParamOptim, _sub_params -from baby.tracker.training import CellTrainer, BudTrainer -from baby.training import CNNTrainer - from matplotlib import pyplot as plt -from tqdm import tqdm + +from baby.tracker.training import CellTrainer +from baby.training import CNNTrainer +from baby.errors import BadParam, BadFile, BadType, BadProcess if tf.__version__.startswith('2'): - tf.compat.v1.disable_eager_execution() + # tf.compat.v1.disable_eager_execution() from .hyper_parameter_trainer import HyperParameterTrainer else: from .v1_hyper_parameter_trainer import HyperParamV1 as \ - HyperParameterTrainer + HyperParameterTrainer -from .utils import BabyTrainerParameters, TrainValTestProperty, \ - augmented_generator +from .utils import (BabyTrainerParameters, TrainValTestProperty, + standard_augmenter, SharedParameterContainer, + SharedDataContainer) from .smoothing_model_trainer import SmoothingModelTrainer from .flattener_trainer import FlattenerTrainer - -from baby.utils import find_file, as_python_object, jsonify, batch_iterator, \ - split_batch_pred -from baby.errors import BadParam, BadFile, BadType, BadProcess -from baby.io import TrainValTestPairs -from baby.preprocessing import (robust_norm, seg_norm) -from baby.augmentation import (Augmenter, ScalingAugmenter) -from baby.generator import ImageLabel -from baby.visualise import colour_segstack - -LOG_DIR = 'logs' - - -# TODO: ADD TO UTILS in training -class SegExample(NamedTuple): - """CNN output paired with target segmented outlines and info - - Used for optimising segmentation hyperparameters and for training bud - assigment models - """ - pred: np.ndarray - target: np.ndarray - info: dict - img: np.ndarray - - -# interior_threshold: threshold on predicted interior -# nclosing: number of closing operations on threshold mask -# nopening: number of opening operations on threshold mask -# -# TODO Create SegTrainer -# TODO Add default parameters to SegTrainer -# Structure taht the parameters should actually be (shape/size) -# Thresholds modified to the correct parameters -base_seg_params = { - 'interior_threshold': [0.5, 0.5, 0.5], - 'nclosing': [0, 0, 0], - 'nopening': [0, 0, 0], - 'connectivity': [2, 2, 2], - 'edge_sub_dilations': [0, 0, 0] -} - -# Todo add default parameters to SegTrainer -# Search space ! -seg_param_coords = { - 'nclosing': [0, 1, 2], - 'nopening': [0, 1, 2], - 'interior_threshold': np.arange(0.3, 1.0, 0.05).tolist(), - 'connectivity': [1, 2], - 'edge_sub_dilations': [0, 1, 2] -} +from .segmentation_trainer import SegmentationTrainer, SegExample +from .bud_trainer import BudTrainer class BabyTrainer(object): - """Manager to set up and train BABY models - - :param save_dir: directory in which to save parameters and logs (and - from which to auto-load parameters and logs) - :param train_val_images: either a dict with keys 'training' and - 'validation' and values specifying lists of file name pairs, or the - name of a json file containing such a dict. The file name pairs - should correspond to image-label pairs suitable for input to - `baby.generator.ImageLabel`. - :param flattener: either a `baby.preprocessing.SegmentationFlattening` - object, or the name of a json file that is a saved - `SegmentationFlattening` object. + """Manager to set up and train BABY models. + + Args: + save_dir (str or Path): Directory in which to save parameters and logs + (and from which to auto-load parameters and logs). Must be an + absolute path, or specified relative to ``base_dir``. + base_dir (str or Path or None): Base directory within which all + relevant image files can be found. By default (i.e., specifying + ``None``), uses the current working directory. References to the + image files will be saved relative to this directory. If the base + directory is copied/moved but the structure within that directory + is left intact, then references to the image files will be + correctly maintained. + params (None or BabyTrainerParameters or str or Path): Optionally + specify initial training parameters as a + :py:class:`BabyTrainerParameters` instance or the path to a saved + :py:class:`BabyTrainerParameters` instance. """ def __init__(self, save_dir, base_dir=None, params=None, max_cnns=3): - # Register the save dir if base_dir is not None: base_dir = Path(base_dir) if not base_dir.is_dir(): @@ -136,25 +89,22 @@ class BabyTrainer(object): else: base_dir = Path.cwd() self.base_dir = base_dir - # All other directories are relative to the base dir + + # The save directory may be specified relative to the base directory save_dir = Path(save_dir) if not save_dir.is_absolute(): save_dir = base_dir / save_dir if not save_dir.is_dir(): raise BadParam('"save_dir" must be a valid directory') - self.save_dir = save_dir - self._parameters_file = 'parameters.json' - - # Register parameters if specified - if isinstance(params, BabyTrainerParameters): - self.parameters = params - elif isinstance(params, str): - filename = find_file(params, save_dir, 'params') - savename = save_dir / self._parameters_file - if not savename.is_file(): - shutil.copy(filename, savename) + + # Initialise shared parameters and data + self._shared_params = SharedParameterContainer( + save_dir, params=params) + self._shared_data = SharedDataContainer( + self._shared_params, base_dir=base_dir) self._max_cnns = max_cnns + # Trainers self._smoothing_sigma_trainer = None self._flattener_trainer = None @@ -162,653 +112,192 @@ class BabyTrainer(object): self._cnn_trainer = None self._track_trainer = None self._bud_trainer = None + self._segmentation_trainer = None + + @property + def parameters(self): + return self._shared_params.parameters + + @parameters.setter + def parameters(self, params): + self._shared_params.parameters = params + + @property + def segmentation_parameters(self): + return self._shared_params.segmentation_parameters + + @segmentation_parameters.setter + def segmentation_parameters(self, params): + self._shared_params.segmentation_parameters = params + + @property + def save_dir(self): + return self._shared_params.save_dir + + @property + def in_memory(self): + return self.parameters.in_memory + + @in_memory.setter + def in_memory(self, x): + self.parameters = dict(in_memory=x) + + @property + def data(self): + return self._shared_data.data + + @data.setter + def data(self, x): + self._shared_data.data = x + + @property + def tracker_data(self): + return self._shared_data.tracker_data + + @tracker_data.setter + def tracker_data(self, x): + self._shared_data.tracker_data = x + + @property + def gen(self): + return self.flattener_trainer.default_gen + + def gen_with_aug(self, aug): + return self._shared_data.gen_with_aug(aug) @property def smoothing_sigma_trainer(self): if self._smoothing_sigma_trainer is None: self._smoothing_sigma_trainer = SmoothingModelTrainer( - save_dir=self.save_dir, - stats_file=self.parameters.smoothing_sigma_stats_file, - model_file=self.parameters.smoothing_sigma_model_file) + self._shared_params, self._shared_data) return self._smoothing_sigma_trainer @property def flattener_trainer(self): if self._flattener_trainer is None: self._flattener_trainer = FlattenerTrainer( - save_dir=self.save_dir, - stats_file=self.parameters.flattener_stats_file, - flattener_file=self.parameters.flattener_file) + self._shared_params, self._shared_data, + self.smoothing_sigma_trainer) return self._flattener_trainer @property def hyperparameter_trainer(self): if self._hyperparameter_trainer is None: self._hyperparameter_trainer = HyperParameterTrainer( - save_dir=self.save_dir, - cnn_set=self.parameters.cnn_set, - gen=self.gen, - aug=self.aug, - outputs=self.flattener_trainer.flattener.names(), - tuner_params=None - # Todo: tuner file if it exists in parameters - ) + save_dir=self.save_dir, + cnn_set=self.parameters.cnn_set, + gen=self.gen, + aug=self.aug, + outputs=self.flattener_trainer.flattener.names(), + tuner_params=None + # Todo: tuner file if it exists in parameters + ) return self._hyperparameter_trainer @property def cnn_trainer(self): if self._cnn_trainer is None: self._cnn_trainer = CNNTrainer( - save_dir=self.save_dir, - cnn_set=self.parameters.cnn_set, - gen=self.gen, - aug=self.aug, - flattener=self.flattener_trainer.flattener, - cnn_fn=self.parameters.cnn_fn, - max_cnns=self._max_cnns # Todo: private access OK? - ) + self._shared_params, + self.flattener_trainer, + max_cnns=self._max_cnns # Todo: private access OK? + ) return self._cnn_trainer + @property + def segmentation_trainer(self): + if self._segmentation_trainer is None: + self._segmentation_trainer = SegmentationTrainer( + self._shared_params, self._shared_data, + self.smoothing_sigma_trainer, self.cnn_trainer) + return self._segmentation_trainer + @property def track_trainer(self): if self._track_trainer is None: self._track_trainer = CellTrainer(self.tracker_data._metadata, - self.tracker_data) + self.tracker_data) return self._track_trainer @track_trainer.setter - def track_trainer(self, all_feats2use=None): + def track_trainer(self, all_feats2use): self._track_trainer = CellTrainer(self.tracker_data._metadata, - data=self.tracker_data, - all_feats2use=all_feats2use) + data=self.tracker_data, + all_feats2use=all_feats2use) @property def bud_trainer(self): - props_file = self.save_dir / self.parameters.mother_bud_props_file - if not hasattr(self, '_bud_trainer') or not self._bud_trainer: + if self._bud_trainer is None: self._bud_trainer = BudTrainer( - props_file=props_file, - px_size=self.parameters.target_pixel_size) + self._shared_params, self._shared_data, + self.segmentation_trainer) return self._bud_trainer - @property - def parameters(self): - if not hasattr(self, '_parameters') or not self._parameters: - param_file = self.save_dir / self._parameters_file - if param_file.is_file(): - with open(param_file, 'rt') as f: - params = json.load(f, object_hook=as_python_object) - if not isinstance(params, BabyTrainerParameters): - raise BadFile('Parameters file has been corrupted') - self._parameters = params - else: - self.parameters = BabyTrainerParameters() - return self._parameters - - @parameters.setter - def parameters(self, params): - if isinstance(params, dict): - p = self.parameters._asdict() - p.update(params) - params = BabyTrainerParameters(**p) - elif not isinstance(params, BabyTrainerParameters): - params = BabyTrainerParameters(*params) - self._parameters = params - with open(self.save_dir / self._parameters_file, 'wt') as f: - json.dump(jsonify(self._parameters), f) - - @property - def in_memory(self): - return self.parameters.in_memory - - @in_memory.setter - def in_memory(self, x): - self.parameters = dict(in_memory=x) - - def _check_for_data_update(self): - if getattr(self, '_ncells', None) != self._impairs.ncells: - # Reset generators - self._gen_train = None - self._gen_val = None - self._gen_test = None - # Trigger save of the data - datafile = self.save_dir / self.parameters.train_val_test_pairs_file - self._impairs.save(datafile, self.base_dir) - self._ncells = self._impairs.ncells - - # And for thet tracker datasets too - def _tracker_check_for_data_update(self): - if getattr(self, '_tracker_ncells', - None) != self._tracker_impairs.ncells: - datafile = self.save_dir / self.parameters.tracker_tvt_pairs_file - self._tracker_impairs.save(datafile, self.base_dir) - self._tracker_ncells = self._tracker_impairs.ncells - - @property - def data(self): - if not hasattr(self, '_impairs') or not self._impairs: - self._impairs = TrainValTestPairs() - pairs_file = self.save_dir / self.parameters.train_val_test_pairs_file - if pairs_file.is_file(): - self._impairs.load(pairs_file, self.base_dir) - self._check_for_data_update() - return self._impairs - - @data.setter - def data(self, train_val_test_pairs): - if isinstance(train_val_test_pairs, str): - pairs_file = find_file(train_val_test_pairs, self.save_dir, - 'data') - train_val_test_pairs = TrainValTestPairs() - train_val_test_pairs.load(pairs_file, self.base_dir) - if not isinstance(train_val_test_pairs, TrainValTestPairs): - raise BadType( - '"data" must be of type "baby.io.TrainValTestPairs"') - self._impairs = train_val_test_pairs - self._check_for_data_update() - - @property - def tracker_data(self): - if not hasattr(self, '_impairs') or not self._impairs: - self._tracker_impairs = TrainValTestPairs() - pairs_file = self.save_dir / self.parameters.tracker_tvt_pairs_file - if pairs_file.is_file(): - self._tracker_impairs.load(pairs_file, self.base_dir) - self._check_for_data_update() - return self._tracker_impairs - - @data.setter - def tracker_data(self, train_val_test_pairs): - if isinstance(train_val_test_pairs, str): - pairs_file = find_file(train_val_test_pairs, self.save_dir, - 'data') - train_val_test_pairs = TrainValTestPairs() - train_val_test_pairs.load(pairs_file, self.base_dir) - if not isinstance(train_tvt_pairs, TrainValTestPairs): - raise BadType( - '"data" must be of type "baby.io.TrainValTestPairs"') - self._tracker_impairs = train_val_test_pairs - self._tracker_check_for_data_update() - - @property - def gen(self): - # NB: generator init ensures all specified images exist - # NB: only dummy augmenters are assigned to begin with - p = self.parameters - if not getattr(self, '_gen_train', None): - if len(self.data.training) == 0: - raise BadProcess('No training images have been added') - # Initialise generator for training images - self._gen_train = ImageLabel(self.data.training, - batch_size=p.batch_size, - aug=Augmenter(), - preprocess=(robust_norm, seg_norm), - in_memory=p.in_memory) - - if not getattr(self, '_gen_val', None): - if len(self.data.validation) == 0: - raise BadProcess('No validation images have been added') - # Initialise generator for validation images - self._gen_val = ImageLabel(self.data.validation, - batch_size=p.batch_size, - aug=Augmenter(), - preprocess=(robust_norm, seg_norm), - in_memory=p.in_memory) - - if not getattr(self, '_gen_test', None): - if len(self.data.testing) == 0: - raise BadProcess('No testing images have been added') - # Initialise generator for testing images - self._gen_test = ImageLabel(self.data.testing, - batch_size=p.batch_size, - aug=Augmenter(), - preprocess=(robust_norm, seg_norm), - in_memory=p.in_memory) - - return TrainValTestProperty(self._gen_train, self._gen_val, - self._gen_test) - - def plot_gen_sample(self, validation=False): - # TODO: Move to flattener? - g = self.gen.val if validation else self.gen.train - g.aug = self.aug.val if validation else self.aug.train - img_batch, lbl_batch = g[0] - lbl_batch = np.concatenate(lbl_batch, axis=3) - - f = self.flattener - target_names = f.names() - edge_inds = np.flatnonzero([t.prop == 'edge' for t in f.targets]) - - ncol = len(img_batch) - nrow = len(target_names) + 1 - fig = plt.figure(figsize=(3 * ncol, 3 * nrow)) - for b, (bf, seg) in enumerate(zip(img_batch, lbl_batch)): - plt.subplot(nrow, ncol, b + 0 * ncol + 1) - plt.imshow(bf[:, :, 0], cmap='gray') - plt.imshow(colour_segstack(seg[:, :, edge_inds])) - - for i, name in enumerate(target_names): - plt.subplot(nrow, ncol, b + (i + 1) * ncol + 1) - plt.imshow(seg[:, :, i], cmap='gray') - plt.title(name) - - fig.savefig(self.save_dir / '{}_generator_sample.png'.format( - 'validation' if validation else 'training')) - - def generate_smoothing_sigma_stats(self): - # train_gen = augmented_generator(self.gen.train, lambda x, y: (x, y)) - # val_gen = augmented_generator(self.gen.train, lambda x, y: (x, y)) - self.smoothing_sigma_trainer.generate_smoothing_sigma_stats( - self.gen.train, self.gen.val) + def fit_smoothing_model(self, filt='identity'): + try: + self.smoothing_sigma_trainer.stats + except BadProcess: + self.smoothing_sigma_trainer.generate_smoothing_sigma_stats() + self.smoothing_sigma_trainer.fit(filt=filt) - @property - def smoothing_sigma_stats(self): + def plot_fitted_smoothing_model(self): warnings.warn( - "nursery.smoothing_sigma_stats will soon be " - "deprecated, use nursery.smoothing_sigma_trainer.stats " - "instead", DeprecationWarning) - return self.smoothing_sigma_trainer.stats + "nursery.plot_fitted_smoothing_sigma_model will soon be " + "deprecated, use " + "nursery.smoothing_signa_trainer.plot_fitted_model " + "instead", DeprecationWarning) + self.smoothing_sigma_trainer.plot_fitted_model() @property def smoothing_sigma_model(self): return self.smoothing_sigma_trainer.model - def generate_flattener_stats(self, max_erode=5): - warnings.warn( - "nursery.generate_flattener_stats will soon be " - "deprecated, use " - "nursery.flattener_trainer.generate_stats(train_gen, " - "val_gen, train_aug, val_aug, max_erode=5) instead", - DeprecationWarning) - - flattener = lambda x, y: x - # NB: use isval=True for training aug since we do not need extra - # augmentations for calibrating the flattener - tAug = _std_aug(self.smoothing_sigma_model, flattener, - self.parameters, isval=True) - vAug = _std_aug(self.smoothing_sigma_model, - flattener, - self.parameters, - isval=True) - - self.flattener_trainer.generate_flattener_stats(*self.gen[:2], - tAug, - vAug, - max_erode=max_erode) + @smoothing_sigma_model.setter + def smoothing_sigma_model(self, m): + self.smoothing_sigma_trainer.model = m - @property - def flattener_stats(self): - warnings.warn( - "nursery.flattener_stats will soon be " - "deprecated, use nursery.flattener_trainer.stats " - "instead", DeprecationWarning) - return self.flattener_trainer.stats + def fit_flattener(self, max_erode=5, **kwargs): + try: + self.flattener_trainer.stats + except BadProcess: + self.flattener_trainer.generate_flattener_stats( + max_erode=max_erode) + self.flattener_trainer.fit(**kwargs) + + def plot_flattener_stats(self, **kwargs): + self.flattener_trainer.plot_stats(**kwargs) @property def flattener(self): - warnings.warn( - "nursery.flattener will soon be " - "deprecated, use nursery.flattener_trainer.flattener " - "instead", DeprecationWarning) return self.flattener_trainer.flattener @flattener.setter def flattener(self, f): - warnings.warn( - "nursery.flattener will soon be " - "deprecated, use nursery.flattener_trainer.flattener " - "instead", DeprecationWarning) self.flattener_trainer.flattener = f - @property - def aug(self): - p = self.parameters - t = _std_aug(self.smoothing_sigma_model, self.flattener, - self.parameters) - v = _std_aug(self.smoothing_sigma_model, - self.flattener, - self.parameters, - isval=True) - w = _std_aug(self.smoothing_sigma_model, - self.flattener, - self.parameters, - isval=True) - return TrainValTestProperty(t, v, w) - - @property - def cnn_fn(self): - warnings.warn( - "nursery.cnn_fn will soon be " - "deprecated, use nursery.cnn_trainer.cnn_fn " - "instead", DeprecationWarning) - return self.cnn_trainer.cnn_fn - - @cnn_fn.setter - def cnn_fn(self, fn): - warnings.warn( - "nursery.cnn_fn will soon be " - "deprecated, use nursery.cnn_trainer.cnn_fn " - "instead", DeprecationWarning) - self.cnn_trainer.cnn_fn = fn - - @property - def cnn_dir(self): - warnings.warn( - "nursery.cnn_dir will soon be " - "deprecated, use nursery.cnn_trainer.cnn_dir " - "instead", DeprecationWarning) - return self.cnn_trainer.cnn_dir - - @property - def cnn_name(self): - warnings.warn( - "nursery.cnn_bane will soon be " - "deprecated, use nursery.cnn_trainer.cnn_name " - "instead", DeprecationWarning) - return self.cnn_trainer.cnn_name - - @property - def cnn(self): - warnings.warn( - "nursery.cnn will soon be " - "deprecated, use nursery.cnn_trainer.cnn " - "instead", DeprecationWarning) - return self.cnn_trainer.cnn - - @property - def histories(self): - warnings.warn( - "nursery.histories will soon be " - "deprecated, use nursery.cnn_trainer.histories " - "instead", DeprecationWarning) - return self.cnn_trainer.histories - - @property - def cnn_opt_dir(self): - warnings.warn( - "nursery.opt_dir will soon be " - "deprecated, use nursery.cnn_trainer.opt_dir " - "instead", DeprecationWarning) - return self.cnn_trainer.opt_dir - - @property - def cnn_opt(self): - warnings.warn( - "nursery.cnn_opt will soon be " - "deprecated, use nursery.cnn_trainer.opt_cnn " - "instead", DeprecationWarning) - return self.cnn_trainer.opt_cnn - - def fit_smoothing_model(self, filt='identity'): - warnings.warn( - "nursery.fit_smoothing_model will soon be " - "deprecated, use nursery.smoothing_signa_trainer.fit " - "instead", DeprecationWarning) - self.smoothing_sigma_trainer.fit(filt=filt) - - def plot_fitted_smoothing_sigma_model(self): - warnings.warn( - "nursery.plot_fitted_smoothing_sigma_model will soon be " - "deprecated, use " - "nursery.smoothing_signa_trainer.plot_fitted_model " - "instead", DeprecationWarning) - self.smoothing_sigma_trainer.plot_fitted_model() - - def fit_flattener(self, **kwargs): - warnings.warn( - "nursery.fit_flattener will soon be " - "deprecated, use nursery.flattener_trainer.fit " - "instead", DeprecationWarning) - self.flattener_trainer.fit(**kwargs) - - def plot_flattener_stats(self, **kwargs): - warnings.warn( - "nursery.plot_flattener_stats will soon be " - "deprecated, use nursery.flattener_trainer.plot_stats " - "instead", DeprecationWarning) - self.flattener_trainer.plot_stats(**kwargs) + def plot_generator_sample(self, **kwargs): + self.flattener_trainer.plot_default_gen_sample(**kwargs) def fit_cnn(self, **kwargs): - warnings.warn( - "nursery.fit_cnn will soon be " - "deprecated, use nursery.cnn_trainer.fit " - "instead", DeprecationWarning) self.cnn_trainer.fit(**kwargs) - def plot_histories(self, **kwargs): - warnings.warn( - "nursery.plot_histories will soon be " - "deprecated, use nursery.cnn_trainer.plot_histories " - "instead", DeprecationWarning) + def plot_cnn_histories(self, **kwargs): self.cnn_trainer.plot_histories(**kwargs) - # TODO: move to Segmentation Param Trainer @property - def seg_examples(self): - p = self.parameters - a = ScalingAugmenter(self.smoothing_sigma_model, - lambda lbl, _: lbl, - xy_out=p.xy_out, - target_pixel_size=p.target_pixel_size, - substacks=p.substacks, - p_noop=1, - probs={ - 'vshift': 0.25, - 'hshift': 0.25 - }) - - def seg_example_aug(img, lbl): - # Assume that the label preprocessing function also returns info - _, info = lbl - img, lbl = a(img, lbl) - # In this case, always prefer the validation augmenter - return img, lbl > 0, info - - def example_generator(dgen): - opt_cnn = self.cnn_opt - b_iter = batch_iterator(list(range(dgen.n_pairs)), - batch_size=dgen.batch_size) - with tqdm(total=dgen.n_pairs) as pbar: - for b_inds in b_iter: - batch = [ - dgen.get_by_index(b, aug=seg_example_aug) - for b in b_inds - ] - preds = split_batch_pred( - opt_cnn.predict(np.stack([img for img, _, _ in batch - ]))) - for pred, (img, lbl, info) in zip(preds, batch): - pbar.update() - lbl = lbl.transpose(2, 0, 1) - # Filter out examples that have been augmented away - valid = lbl.sum(axis=(1, 2)) > 0 - lbl = lbl[valid] - clab = info.get('cellLabels', []) or [] - if type(clab) is int: - clab = [clab] - clab = [l for l, v in zip(clab, valid) if v] - info['cellLabels'] = clab - buds = info.get('buds', []) or [] - if type(buds) is int: - buds = [buds] - buds = [b for b, v in zip(buds, valid) if v] - info['buds'] = buds - yield SegExample(pred, lbl, info, img) - - if self.in_memory: - if getattr(self, '_seg_examples', None) is None: - self._seg_examples = TrainValTestProperty( - list(example_generator(self.gen.train)), - list(example_generator(self.gen.val)), - list(example_generator(self.gen.test))) - return TrainValTestProperty((e for e in self._seg_examples.train), - (e for e in self._seg_examples.val), - (e for e in self._seg_examples.test)) - else: - self._seg_examples = None - return TrainValTestProperty(example_generator(self.gen.train), - example_generator(self.gen.val), - example_generator(self.gen.test)) + def cnn_dir(self): + """Directory containing saved weights for the optimised CNN.""" + return self.cnn_trainer.opt_dir - # TODO Move to Segmentation Parameter TRainer - @property - def seg_param_stats(self): - if getattr(self, '_seg_param_stats', None) is None: - p = self.parameters - stats_file = self.save_dir / p.segmentation_stats_file - if not stats_file.is_file(): - raise BadProcess('"fit_seg_params" has not been run yet') - self._seg_param_stats = pd.read_csv(stats_file, index_col=0) - return self._seg_param_stats - - # TODO Move to Segmentation Parameter TRainer @property - def seg_params(self): - params_file = self.save_dir / self.parameters.segmentation_param_file - with open(params_file, 'rt') as f: - params = json.load(f) - return params - - # TODO Move to Segmentation Parameter TRainer - @seg_params.setter - def seg_params(self, val): - if not type(val) == dict: - raise BadParam('"seg_params" should be a "dict"') - msg_args = inspect.getfullargspec(MorphSegGrouped.__init__).args - if not set(val.keys()).issubset(msg_args): - raise BadParam( - '"seg_params" must specify arguments to "MorphSegGrouped"') - params_file = self.save_dir / self.parameters.segmentation_param_file - with open(params_file, 'wt') as f: - json.dump(jsonify(val), f) - - # TODO Deprecate and keep in bud_trainer - def generate_bud_stats(self): - self.bud_trainer.generate_property_table(self.seg_examples, - self.flattener) - - # TODO Depcrecate and keep in bud_trainer - def fit_bud_model(self, **kwargs): - self.bud_trainer.explore_hyperparams(**kwargs) - model_file = self.save_dir / self.parameters.mother_bud_model_file - self.bud_trainer.save_model(model_file) - - #Todo Move to Segmentation Parameter Trainer - def fit_seg_params(self, njobs=5, scoring='F0_5'): - param_grid = list(product(*seg_param_coords.values())) - basic_pars = list(seg_param_coords.keys()) - - # TODO switch back to validation examples - val_examples = list(self.seg_examples.val) - from joblib import Parallel, delayed - rows = [] - for gind in range(3)[::-1]: - rows.extend( - Parallel(n_jobs=njobs)( - delayed(_seg_filter_optim)(gind, - pars, - basic_pars, - self.flattener, - val_examples, - base_params=base_seg_params, - scoring=scoring) - for pars in tqdm(param_grid))) - - rows_expanded = [ - dict( - chain(*[[('_'.join((k, str(g))), gv) - for g, gv in enumerate(v)] if type(v) == list else [( - k, v)] - for k, v in chain([ - ('group', row['group']), ('score', row['score']) - ], row['basic'].items(), row['filter'].items())])) - for row in rows - ] - - self._seg_param_stats = pd.DataFrame(rows_expanded) - stats_file = self.save_dir / self.parameters.segmentation_stats_file - self._seg_param_stats.to_csv(stats_file) - - self.refit_filter_seg_params(scoring=scoring) - - # TODO move to Segmentation Parameter Trainer - def refit_filter_seg_params(self, - lazy=False, - bootstrap=False, - scoring='F0_5'): - - # Merge the best parameters from each group into a single parameter set - merged_params = {k: v.copy() for k, v in base_seg_params.items()} - stats = self.seg_param_stats - for g, r in enumerate(stats.groupby('group').score.idxmax()): - for k in merged_params: - merged_params[k][g] = stats.loc[r, k + '_' + str(g)] - - sfpo = SegFilterParamOptim(self.flattener, - basic_params=merged_params, - scoring=scoring) - sfpo.generate_stat_table(self.seg_examples.val) - - sfpo.fit_filter_params(lazy=lazy, bootstrap=bootstrap) - merged_params.update(sfpo.opt_params) - self.seg_params = merged_params - - # TODO Move to Segmentation Parameter Trainer - def validate_seg_params(self, iou_thresh=0.7, save=True): - segmenter = MorphSegGrouped(self.flattener, - return_masks=True, - fit_radial=True, - use_group_thresh=True, - **self.seg_params) - edge_inds = [ - i for i, t in enumerate(self.flattener.targets) - if t.prop == 'edge' - ] - stats = {} - dfs = {} - for k, seg_exs in zip(self.seg_examples._fields, self.seg_examples): - stats[k] = [] - for seg_ex in seg_exs: - seg = segmenter.segment(seg_ex.pred, refine_outlines=True) - edge_scores = np.array([ - seg_ex.pred[edge_inds, ...].max(axis=0)[s].mean() - for s in seg.edges - ]) - IoUs = calc_IoUs(seg_ex.target, seg.masks) - bIoU, _ = best_IoU(IoUs) - stats[k].append((edge_scores, IoUs, np.mean(bIoU), - np.min(bIoU, initial=1), - calc_AP(IoUs, - probs=edge_scores, - iou_thresh=iou_thresh)[0])) - dfs[k] = pd.DataFrame([s[2:] for s in stats[k]], - columns=['IoU_mean', 'IoU_min', 'AP']) - - print({k: df.mean() for k, df in dfs.items()}) - - nrows = len(dfs) - ncols = dfs['val'].shape[1] - fig, axs = plt.subplots(nrows=nrows, - ncols=ncols, - figsize=(ncols * 4, nrows * 4)) - for axrow, (k, df) in zip(axs, dfs.items()): - for ax, col in zip(axrow, df.columns): - ax.hist(df.loc[:, col], bins=26, range=(0, 1)) - ax.set(xlabel=col, title=k) - if save: - fig.savefig(self.save_dir / 'seg_validation_plot.png') - plt.close(fig) - + def cnn(self): + """Optimised CNN from the ``cnn_trainer``.""" + return self.cnn_trainer.opt_cnn class Nursery(BabyTrainer): pass -def load_history(subdir): - with open(LOG_DIR / subdir / 'history.pkl', 'rb') as f: - return pickle.load(f) - - def get_best_and_worst(model, gen): best = {} worst = {} @@ -843,39 +332,3 @@ def get_best_and_worst(model, gen): worst[output][worst_maxind] = out return best, worst - - -def _seg_filter_optim(g, - p, - pk, - flattener, - seg_gen, - base_params=default_params, - scoring='F0_5'): - p = _sub_params({(k, g): v for k, v in zip(pk, p)}, base_params) - sfpo = SegFilterParamOptim(flattener, basic_params=p, scoring=scoring) - sfpo.generate_stat_table(seg_gen) - sfpo.fit_filter_params(lazy=True, bootstrap=False) - return { - 'group': g, - 'basic': p, - 'filter': sfpo.opt_params, - 'score': sfpo.opt_score - } - - -def _std_aug(ssm, flattener, p, isval=False): - probs = {'vshift': 0.25, 'hshift': 0.25} - extra_args = {} - if isval: - extra_args['p_noop'] = 1 - else: - probs['rotate'] = 0.2 - - return ScalingAugmenter(ssm, - flattener, - xy_out=p.xy_out, - target_pixel_size=p.target_pixel_size, - substacks=p.substacks, - probs=probs, - **extra_args) diff --git a/python/baby/training/utils.py b/python/baby/training/utils.py index 90dbb2fea7ef229d6b712aa732e6430a6c990367..d0730044a3911f95d3d2351809b9a4946788c95f 100644 --- a/python/baby/training/utils.py +++ b/python/baby/training/utils.py @@ -1,12 +1,14 @@ # If you publish results that make use of this software or the Birth Annotator # for Budding Yeast algorithm, please cite: -# Julian M J Pietsch, Alán Muñoz, Diane Adjavon, Ivan B N Clark, Peter S -# Swain, 2021, Birth Annotator for Budding Yeast (in preparation). +# Pietsch, J.M.J., Muñoz, A.F., Adjavon, D.-Y.A., Farquhar, I., Clark, I.B.N., +# and Swain, P.S. (2023). Determining growth rates from bright-field images of +# budding cells through identifying overlaps. eLife. 12:e79812. +# https://doi.org/10.7554/eLife.79812 # # # The MIT License (MIT) # -# Copyright (c) Julian Pietsch, Alán Muñoz and Diane Adjavon 2021 +# Copyright (c) Julian Pietsch, Alán Muñoz and Diane Adjavon 2023 # # Permission is hereby granted, free of charge, to any person obtaining a copy # of this software and associated documentation files (the "Software"), to @@ -25,13 +27,24 @@ # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING # FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS # IN THE SOFTWARE. -from contextlib import contextmanager from typing import NamedTuple, Any, Tuple, Union +from itertools import repeat +import shutil +import json +from pathlib import Path import tensorflow as tf -from baby.augmentation import Augmenter -from baby.generator import ImageLabel -from baby.utils import EncodableNamedTuple +from tensorflow.keras.optimizers.schedules import CosineDecay + +from baby.errors import BadFile, BadParam, BadProcess, BadType +from baby.io import TrainValTestPairs +from baby.preprocessing import (robust_norm, robust_norm_dw, seg_norm, + SegmentationFlattening) +from baby.augmentation import Augmenter, ScalingAugmenter, SmoothingSigmaModel +from baby.generator import ImageLabel, AugmentedGenerator +from baby.utils import (find_file, EncodableNamedTuple, + jsonify, as_python_object) +from baby.morph_thresh_seg import SegmentationParameters def fix_tf_rtx_gpu_bug(): @@ -52,10 +65,10 @@ def fix_tf_rtx_gpu_bug(): logical_gpus = tf.config.experimental.list_logical_devices('GPU') print(len(gpus), "Physical GPUs,", len(logical_gpus), "Logical GPUs") - else: - raise Exception( - 'Unsupported version of tensorflow encountered ({})'.format( - tf.version.VERSION)) + else: + raise Exception( + 'Unsupported version of tensorflow encountered ({})'.format( + tf.version.VERSION)) class TrainValProperty(NamedTuple): @@ -83,24 +96,379 @@ class BabyTrainerParameters(NamedTuple): segmentation_param_file: str = 'segmentation_params.json' mother_bud_props_file: str = 'mother_bud_props.csv' mother_bud_model_file: str = 'mother_bud_model.pkl' - cnn_set: Tuple[str, ...] = ('msd_d80', 'unet_4s') + cnn_set: Tuple[str, ...] = ('unet_4s',) cnn_fn: Union[None, str] = None batch_size: int = 8 in_memory: bool = True xy_out: int = 80 target_pixel_size: float = 0.263 substacks: Union[None, int] = None + aug_probs: dict = {} + aug_p_noop: float = 0.05 + base_seg_params: dict = {} + seg_param_coords: dict = {} + input_norm_dw: bool = False + only_basic_augs: bool = True + balanced_sampling: bool = False + use_sample_weights: bool = False + canny_padding: int = 2 + n_jobs: int = 4 + + +TRAINING_PARAMETERS_FILENAME = 'training_parameters.json' +"""File name to which training parameters are saved""" + + +class SharedParameterContainer(object): + """A container of current training-related parameters. + + Designed to be used by :py:class:`training.BabyTrainer` and passed to the + children trainers. Updates on this object by the + :py:class:`training.BabyTrainer` are then automatially propagated to the + classes that use the parameters. + + Parameters are auto-saved to a file + :py:const:`TRAINING_PARAMETERS_FILENAME` within the ``save_dir``. + + Args: + save_dir (str or Path): directory in which to save parameters and/or + from which to auto-load parameters. + params (None or BabyTrainerParameters or str or Path): Optionally + specify initial training parameters as a + :py:class:`BabyTrainerParameters` instance or the path to a saved + :py:class:`BabyTrainerParameters` instance. + """ + def __init__(self, save_dir, params=None): + self.save_dir = Path(save_dir) + self._parameters_file = TRAINING_PARAMETERS_FILENAME + + # Trigger parameter loading/initialisation via property + self.parameters + self.segmentation_parameters + + # Register parameters if specified + if isinstance(params, BabyTrainerParameters): + self.parameters = params + elif params is not None: + filename = find_file(params, save_dir, 'params') + savename = save_dir / self._parameters_file + if filename != savename: + shutil.copy(filename, savename) + + @property + def parameters(self): + """A :py:class:`BabyTrainerParameters` instance. + + If not already initialised, parameters are loaded from the save file + if found, otherwise they are initialised to defaults as per + :py:class:`BabyTrainerParameters`. + + This can be set to either a new ``BabyTrainerParameters`` instance, or + to a ``dict`` that specifies existing parameter values to replace. + """ + if not getattr(self, '_parameters', None): + param_file = self.save_dir / self._parameters_file + if param_file.is_file(): + with open(param_file, 'rt') as f: + params = json.load(f, object_hook=as_python_object) + if not isinstance(params, BabyTrainerParameters): + raise BadFile('Parameters file has been corrupted') + self._parameters = params + else: + self.parameters = BabyTrainerParameters() + return self._parameters + + @parameters.setter + def parameters(self, params): + if isinstance(params, dict): + if not getattr(self, '_parameters', None): + self._parameters = BabyTrainerParameters() + params = self._parameters._replace(**params) + elif not isinstance(params, BabyTrainerParameters): + params = BabyTrainerParameters(*params) + self._parameters = params + with open(self.save_dir / self._parameters_file, 'wt') as f: + json.dump(jsonify(self._parameters), f) + + @property + def segmentation_parameters(self): + """A :py:class:`baby.segmentation.SegmentationParameters` instance. + + If not already initialised, parameters are loaded from the save file + if found, otherwise they are initialised to defaults as per + :py:class:`baby.segmentation.SegmentationParameters`. + + This can be set to either a new ``SegmentationParameters`` instance, + or to a ``dict`` that specifies existing parameter values to replace. + """ + if not getattr(self, '_segmentation_parameters', None): + seg_param_file = (self.save_dir / + self.parameters.segmentation_param_file) + if seg_param_file.is_file(): + with open(seg_param_file, 'rt') as f: + params = json.load(f, object_hook=as_python_object) + if not isinstance(params, SegmentationParameters): + raise BadFile( + 'Segmentation parameters file has been corrupted.') + self._segmentation_parameters = params + else: + self.segmentation_parameters = SegmentationParameters() + return self._segmentation_parameters + + @segmentation_parameters.setter + def segmentation_parameters(self, params): + if isinstance(params, dict): + if not getattr(self, '_segmentation_parameters', None): + self._segmentation_parameters = SegmentationParameters() + params = self._segmentation_parameters._replace(**params) + elif not isinstance(params, SegmentationParameters): + params = SegmentationParameters(*params) + self._segmentation_parameters = params + seg_param_file = (self.save_dir / + self.parameters.segmentation_param_file) + with open(seg_param_file, 'wt') as f: + json.dump(jsonify(self._segmentation_parameters), f, indent=2) + + +class SharedDataContainer(object): + """A container of current data and generators. + + Designed to be used by :py:class:`training.BabyTrainer` and passed to the + children trainers. Updates on this object by the + :py:class:`training.BabyTrainer` are then automatially propagated to the + classes that use the parameters. + + Parameters are auto-saved to a file + :py:const:`TRAINING_PARAMETERS_FILENAME` within the ``save_dir``. + + Args: + shared_params (SharedParameterContainer): Shared parameters. + base_dir (None or str or Path): Base directory within which all + relevant image files can be found. References to the image files + will be saved relative to this directory. + """ + def __init__(self, shared_params, base_dir=None): + self._shared_params = shared_params + + if base_dir is not None: + base_dir = Path(base_dir) + if not base_dir.is_dir(): + raise BadParam('"base_dir" must be a valid directory or None') + else: + base_dir = Path.cwd() + self.base_dir = base_dir + @property + def save_dir(self): + return self._shared_params.save_dir + @property + def parameters(self): + return self._shared_params.parameters -@contextmanager -def augmented_generator(gen: ImageLabel, aug: Augmenter): - # Save the previous augmenter if any - saved_aug = gen.aug - gen.aug = aug - try: - yield gen - # Todo: add except otherwise there might be an issue of there is an error? - finally: - gen.aug = saved_aug + def _check_for_data_update(self): + if getattr(self, '_ncells', None) != self._impairs.ncells: + # Reset generators + self._gen_train = None + self._gen_val = None + self._gen_test = None + # Trigger save of the data + datafile = self.save_dir / self.parameters.train_val_test_pairs_file + self._impairs.save(datafile, self.base_dir) + self._ncells = self._impairs.ncells + def _tracker_check_for_data_update(self): + # And for thet tracker datasets too + if (getattr(self, '_tracker_ncells', None) + != self._tracker_impairs.ncells): + datafile = self.save_dir / self.parameters.tracker_tvt_pairs_file + self._tracker_impairs.save(datafile, self.base_dir) + self._tracker_ncells = self._tracker_impairs.ncells + + @property + def data(self): + if not hasattr(self, '_impairs') or not self._impairs: + self._impairs = TrainValTestPairs() + pairs_file = self.save_dir / self.parameters.train_val_test_pairs_file + if pairs_file.is_file(): + self._impairs.load(pairs_file, self.base_dir) + self._check_for_data_update() + return self._impairs + + @data.setter + def data(self, train_val_test_pairs): + if isinstance(train_val_test_pairs, str): + pairs_file = find_file(train_val_test_pairs, self.save_dir, + 'data') + train_val_test_pairs = TrainValTestPairs() + train_val_test_pairs.load(pairs_file, self.base_dir) + if not isinstance(train_val_test_pairs, TrainValTestPairs): + raise BadType( + '"data" must be of type "baby.io.TrainValTestPairs"') + self._impairs = train_val_test_pairs + self._check_for_data_update() + + @property + def tracker_data(self): + if not hasattr(self, '_impairs') or not self._impairs: + self._tracker_impairs = TrainValTestPairs() + pairs_file = self.save_dir / self.parameters.tracker_tvt_pairs_file + if pairs_file.is_file(): + self._tracker_impairs.load(pairs_file, self.base_dir) + self._check_for_data_update() + return self._tracker_impairs + + @data.setter + def tracker_data(self, train_val_test_pairs): + if isinstance(train_val_test_pairs, str): + pairs_file = find_file(train_val_test_pairs, self.save_dir, + 'data') + train_val_test_pairs = TrainValTestPairs() + train_val_test_pairs.load(pairs_file, self.base_dir) + if not isinstance(train_tvt_pairs, TrainValTestPairs): + raise BadType( + '"data" must be of type "baby.io.TrainValTestPairs"') + self._tracker_impairs = train_val_test_pairs + self._tracker_check_for_data_update() + + @property + def gen(self): + """Training, validation and test data generators. + + This attribute provides three :py:class:`ImageLabel` generators as a + :py:class:`TrainValTestProperty`, with each generator assigned just a + dummy augmenter to begin with. + + Note: + Generator initialisation requires that all specified images exist. + """ + + old_gen_params = getattr(self, '_current_gen_params', None) + new_gen_params = tuple(getattr(self.parameters, p) for p in + ('in_memory', 'input_norm_dw', 'batch_size', + 'balanced_sampling', 'use_sample_weights')) + if old_gen_params != new_gen_params: + self._gen_train = None + self._gen_val = None + self._gen_test = None + self._current_gen_params = new_gen_params + + (in_memory, input_norm_dw, batch_size, + balanced_sampling, use_sample_weights) = new_gen_params + input_norm = robust_norm_dw if input_norm_dw else robust_norm + + if not getattr(self, '_gen_train', None): + if len(self.data.training) == 0: + raise BadProcess('No training images have been added') + # Initialise generator for training images + self._gen_train = ImageLabel(self.data.training, + batch_size=batch_size, + aug=Augmenter(), + preprocess=(input_norm, seg_norm), + in_memory=in_memory, + balanced_sampling=balanced_sampling, + use_sample_weights=use_sample_weights) + + if not getattr(self, '_gen_val', None): + if len(self.data.validation) == 0: + raise BadProcess('No validation images have been added') + # Initialise generator for validation images + self._gen_val = ImageLabel(self.data.validation, + batch_size=batch_size, + aug=Augmenter(), + preprocess=(input_norm, seg_norm), + in_memory=in_memory, + balanced_sampling=balanced_sampling, + use_sample_weights=use_sample_weights) + + if not getattr(self, '_gen_test', None): + if len(self.data.testing) == 0: + raise BadProcess('No testing images have been added') + # Initialise generator for testing images + self._gen_test = ImageLabel(self.data.testing, + batch_size=batch_size, + aug=Augmenter(), + preprocess=(input_norm, seg_norm), + in_memory=in_memory, + balanced_sampling=balanced_sampling, + use_sample_weights=use_sample_weights) + + self._gen_train.n_jobs = self.parameters.n_jobs + self._gen_val.n_jobs = self.parameters.n_jobs + self._gen_test.n_jobs = self.parameters.n_jobs + return TrainValTestProperty(self._gen_train, self._gen_val, + self._gen_test) + + def gen_with_aug(self, aug): + """Returns generators wrapped with alternative augmenters. + + Args: + aug (Augmenter or Tuple[Augmenter, Augmenter, Augmenter]): + Augmenter to use, or tuple of different augmenters for + training, validation and testing generators. + + Returns: + :py:class:`TrainValTestProperty` of :py:class:`AugmentedGenerator` + objects for training, validation and testing generators. + """ + atrain, aval, atest = aug if type(aug) == tuple else repeat(aug, 3) + gtrain, gval, gtest = self.gen + return TrainValTestProperty( + AugmentedGenerator(gtrain, atrain), + AugmentedGenerator(gval, aval), + AugmentedGenerator(gtest, atest)) + + +VALIDATION_AUGMENTATIONS = {'vshift', 'hshift'} + + +def standard_augmenter(ssm, flattener, params, isval=False): + """Returns an augmenter for training on flattenable inputs. + + Args: + ssm (SmoothingSigmaModel): Smoothing model to use. + flattener (SegmentationFlattening): Flattener to apply after + augmentation. + params (BabyTrainerParameters): Parameters to use when constructing + augmenter. + isval (bool): If ``True``, set unspecified augmentation probabilities + to zero for generating validation data. If ``False``, augmentation + probabilities are left at defaults (see also + :py:class:`ScalingAugmenter` and + :py:class:`BabyTrainerParameters`). + + Returns: + A :py:class:`ScalingAugmenter` with specified ``ssm`` and + ``flattener`` and parameterisation according to ``params``. + """ + probs = {'vshift': 0.25, 'hshift': 0.25} + extra_args = dict(canny_padding=params.canny_padding, + p_noop=params.aug_p_noop) + if isval: + extra_args['p_noop'] = 1 + else: + probs['rotate'] = 0.25 + + probs.update(params.aug_probs) + if isval: + probs = {k: v for k, v in probs.items() + if k in VALIDATION_AUGMENTATIONS} + + return ScalingAugmenter(ssm, + flattener, + xy_out=params.xy_out, + target_pixel_size=params.target_pixel_size, + substacks=params.substacks, + probs=probs, + only_basic_augs=params.only_basic_augs, + **extra_args) + + +def warmup_and_cosine_decay(learning_rate=0.001, warmup_steps=30, decay_steps=370): + decay = CosineDecay(learning_rate, decay_steps) + def warmup_schedule(step): + if step < warmup_steps: + return step / warmup_steps * learning_rate + else: + return decay(step - warmup_steps) + return warmup_schedule diff --git a/python/baby/training/v1_hyper_parameter_trainer.py b/python/baby/training/v1_hyper_parameter_trainer.py index 507d8f1139b9fb12e5821214abc165e49b14761a..934d33f9515bc054ca335532cd022208430d5c7b 100644 --- a/python/baby/training/v1_hyper_parameter_trainer.py +++ b/python/baby/training/v1_hyper_parameter_trainer.py @@ -1,12 +1,14 @@ # If you publish results that make use of this software or the Birth Annotator # for Budding Yeast algorithm, please cite: -# Julian M J Pietsch, Alán Muñoz, Diane Adjavon, Ivan B N Clark, Peter S -# Swain, 2021, Birth Annotator for Budding Yeast (in preparation). +# Pietsch, J.M.J., Muñoz, A.F., Adjavon, D.-Y.A., Farquhar, I., Clark, I.B.N., +# and Swain, P.S. (2023). Determining growth rates from bright-field images of +# budding cells through identifying overlaps. eLife. 12:e79812. +# https://doi.org/10.7554/eLife.79812 # # # The MIT License (MIT) # -# Copyright (c) Julian Pietsch, Alán Muñoz and Diane Adjavon 2021 +# Copyright (c) Julian Pietsch, Alán Muñoz and Diane Adjavon 2023 # # Permission is hereby granted, free of charge, to any person obtaining a copy # of this software and associated documentation files (the "Software"), to diff --git a/python/baby/utils.py b/python/baby/utils.py index 02f5bc890ada72420d1bace21931ea0c69b9b81b..ef6ef5c1dd7c9e087d92f011a75db6475517146f 100644 --- a/python/baby/utils.py +++ b/python/baby/utils.py @@ -1,12 +1,14 @@ # If you publish results that make use of this software or the Birth Annotator # for Budding Yeast algorithm, please cite: -# Julian M J Pietsch, Alán Muñoz, Diane Adjavon, Ivan B N Clark, Peter S -# Swain, 2021, Birth Annotator for Budding Yeast (in preparation). +# Pietsch, J.M.J., Muñoz, A.F., Adjavon, D.-Y.A., Farquhar, I., Clark, I.B.N., +# and Swain, P.S. (2023). Determining growth rates from bright-field images of +# budding cells through identifying overlaps. eLife. 12:e79812. +# https://doi.org/10.7554/eLife.79812 # # # The MIT License (MIT) # -# Copyright (c) Julian Pietsch, Alán Muñoz and Diane Adjavon 2021 +# Copyright (c) Julian Pietsch, Alán Muñoz and Diane Adjavon 2023 # # Permission is hereby granted, free of charge, to any person obtaining a copy # of this software and associated documentation files (the "Software"), to @@ -67,26 +69,13 @@ def get_name(obj): return getattr(obj, '_baby_name', obj.__name__) -def NamedTupleToJSON(self): - return { - '_python_NamedTuple': self._asdict(), - '__module__': self.__class__.__module__, - '__class__': self.__class__.__name__ - } - - -def EncodableNamedTuple(obj): - obj.toJSON = NamedTupleToJSON - return obj - - def jsonify(obj): if hasattr(obj, 'toJSON'): return obj.toJSON() elif hasattr(obj, 'dtype') and hasattr(obj, 'tolist'): return obj.tolist() elif isinstance(obj, tuple): - return {'_python_tuple': list(obj)} + return {'_python_tuple': [jsonify(v) for v in obj]} if isinstance(obj, set): return {'_python_set': list(obj)} elif isinstance(obj, dict): @@ -97,6 +86,19 @@ def jsonify(obj): return obj +def NamedTupleToJSON(self): + return { + '_python_NamedTuple': jsonify(self._asdict()), + '__module__': self.__class__.__module__, + '__class__': self.__class__.__name__ + } + + +def EncodableNamedTuple(obj): + obj.toJSON = NamedTupleToJSON + return obj + + def as_python_object(obj): if '_python_NamedTuple' in obj: obj_class = getattr(import_module(obj['__module__']), obj['__class__']) diff --git a/python/baby/visualise.py b/python/baby/visualise.py index 524b48fd399d220ac1c16aa5a6d6c846b80de9a0..4ad8ac94bfaee53bc6a8e322d56ed6ddadcd65dd 100644 --- a/python/baby/visualise.py +++ b/python/baby/visualise.py @@ -1,12 +1,14 @@ # If you publish results that make use of this software or the Birth Annotator # for Budding Yeast algorithm, please cite: -# Julian M J Pietsch, Alán Muñoz, Diane Adjavon, Ivan B N Clark, Peter S -# Swain, 2021, Birth Annotator for Budding Yeast (in preparation). +# Pietsch, J.M.J., Muñoz, A.F., Adjavon, D.-Y.A., Farquhar, I., Clark, I.B.N., +# and Swain, P.S. (2023). Determining growth rates from bright-field images of +# budding cells through identifying overlaps. eLife. 12:e79812. +# https://doi.org/10.7554/eLife.79812 # # # The MIT License (MIT) # -# Copyright (c) Julian Pietsch, Alán Muñoz and Diane Adjavon 2021 +# Copyright (c) Julian Pietsch, Alán Muñoz and Diane Adjavon 2023 # # Permission is hereby granted, free of charge, to any person obtaining a copy # of this software and associated documentation files (the "Software"), to @@ -80,8 +82,9 @@ def plot_ims(ims, size=4, cmap=plt.cm.gray, show=True, dw=False, **kwargs): if dw: ims = ims.transpose([2, 0, 1]) + aspect = ims.shape[2] / ims.shape[1] ncols = len(ims) - fig, axs = plt.subplots(1, ncols, figsize=(ncols * size, size), + fig, axs = plt.subplots(1, ncols, figsize=(aspect * ncols * size, size), squeeze=False) axs = axs[0] for ax, im in zip(axs, ims): diff --git a/python/baby/volume.py b/python/baby/volume.py index d62c03f297ae56b7e84fb5d1933ba8b74373f1ac..53ad6ed7a78806fbb8494b7514c8acb674ef73b6 100644 --- a/python/baby/volume.py +++ b/python/baby/volume.py @@ -1,12 +1,14 @@ # If you publish results that make use of this software or the Birth Annotator # for Budding Yeast algorithm, please cite: -# Julian M J Pietsch, Alán Muñoz, Diane Adjavon, Ivan B N Clark, Peter S -# Swain, 2021, Birth Annotator for Budding Yeast (in preparation). +# Pietsch, J.M.J., Muñoz, A.F., Adjavon, D.-Y.A., Farquhar, I., Clark, I.B.N., +# and Swain, P.S. (2023). Determining growth rates from bright-field images of +# budding cells through identifying overlaps. eLife. 12:e79812. +# https://doi.org/10.7554/eLife.79812 # # # The MIT License (MIT) # -# Copyright (c) Julian Pietsch, Alán Muñoz and Diane Adjavon 2021 +# Copyright (c) Julian Pietsch, Alán Muñoz and Diane Adjavon 2023 # # Permission is hereby granted, free of charge, to any person obtaining a copy # of this software and associated documentation files (the "Software"), to diff --git a/setup.py b/setup.py deleted file mode 100644 index 23a3df5e852764e62540fa6bf93a17221a591201..0000000000000000000000000000000000000000 --- a/setup.py +++ /dev/null @@ -1,63 +0,0 @@ -from setuptools import setup, find_packages - -setup( - name='baby', - version='0.24', - packages=find_packages('python'), - package_dir={'': 'python'}, - include_package_data=True, - entry_points={ - 'console_scripts': [ - 'baby-phone = baby.server:main', - 'baby-race = baby.speed_tests:main', - 'baby-fit-grs = baby.postprocessing:main' - ] - }, - url='', - license='MIT License', - author='Julian Pietsch', - author_email='julian.pietsch@ed.ac.uk', - description='Birth Annotator for Budding Yeast', - long_description=''' -If you publish results that make use of this software or the Birth Annotator -for Budding Yeast algorithm, please cite: -Julian M J Pietsch, Alán F Muñoz, Diane-Yayra A Adjavon, Ivan B N Clark, Peter -S Swain, 2022, A label-free method to track individuals and lineages of -budding cells (in submission). - - -The MIT License (MIT) - -Copyright (c) Julian Pietsch, Alán Muñoz and Diane Adjavon 2022 - -Permission is hereby granted, free of charge, to any person obtaining a copy -of this software and associated documentation files (the "Software"), to deal -in the Software without restriction, including without limitation the rights -to use, copy, modify, merge, publish, distribute, sublicense, and/or sell -copies of the Software, and to permit persons to whom the Software is -furnished to do so, subject to the following conditions: - -The above copyright notice and this permission notice shall be included in all -copies or substantial portions of the Software. - -THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR -IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE -AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, -OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE -SOFTWARE. - ''', - install_requires=['tensorflow>=1.14,<2.4', - 'scipy', - 'numpy', - 'pandas', - 'scikit-image', - 'scikit-learn==0.22.2', - 'tqdm', - 'imageio', - 'pillow<9', - 'matplotlib', - 'aiohttp', - 'gaussianprocessderivatives'] -) diff --git a/tests/brain_test.py b/tests/brain_test.py index 802b91460758524eb767d5818db652d0f59a3639..f7ca5319bcbc961bb3ab54cb91abde366cf00ff6 100644 --- a/tests/brain_test.py +++ b/tests/brain_test.py @@ -1,12 +1,14 @@ # If you publish results that make use of this software or the Birth Annotator # for Budding Yeast algorithm, please cite: -# Julian M J Pietsch, Alán Muñoz, Diane Adjavon, Ivan B N Clark, Peter S -# Swain, 2021, Birth Annotator for Budding Yeast (in preparation). +# Pietsch, J.M.J., Muñoz, A.F., Adjavon, D.-Y.A., Farquhar, I., Clark, I.B.N., +# and Swain, P.S. (2023). Determining growth rates from bright-field images of +# budding cells through identifying overlaps. eLife. 12:e79812. +# https://doi.org/10.7554/eLife.79812 # # # The MIT License (MIT) # -# Copyright (c) Julian Pietsch, Alán Muñoz and Diane Adjavon 2021 +# Copyright (c) Julian Pietsch, Alán Muñoz and Diane Adjavon 2023 # # Permission is hereby granted, free of charge, to any person obtaining a copy # of this software and associated documentation files (the "Software"), to @@ -34,13 +36,15 @@ import json import baby from baby.brain import BabyBrain -from baby.morph_thresh_seg import MorphSegGrouped +from baby.morph_thresh_seg import MorphSegGrouped, SegmentationParameters #from .conftest import BASE_DIR -#MODEL_PATH = BASE_DIR / 'models' -MODEL_PATH = baby.model_path() -DEFAULT_MODELSET = 'evolve_brightfield_60x_5z' +DEFAULT_MODELSET = 'yeast-alcatras-brightfield-EMCCD-60x-5z' +TEST_MODELSETS = ['yeast-alcatras-brightfield-EMCCD-60x-5z', + 'yeast-alcatras-brightfield-sCMOS-60x-5z'] + #'ecoli-mm-phase-sCMOS-100x-1z'] + eqlen_outkeys = { 'angles', 'radii', 'cell_label', 'edgemasks', 'ellipse_dims', 'volumes', @@ -52,9 +56,7 @@ eqlen_outkeys = { def bb(modelsets, tf_session_graph): # Attempt to load default evolve model tf_session, tf_graph = tf_session_graph - return BabyBrain(session=tf_session, - graph=tf_graph, - **modelsets[DEFAULT_MODELSET]) + return modelsets.get(DEFAULT_MODELSET, session=tf_session, graph=tf_graph) @pytest.fixture(scope='module') @@ -62,24 +64,44 @@ def imgstack(imgs_evolve60): return np.stack([v['Brightfield'][0] for v in imgs_evolve60.values()]) -def test_modelsets(modelsets): +def test_modelsets(modelsets, verify_all_modelsets): + # Ensure that at least the test model sets are present + all_mset_ids = modelsets.ids() + assert all([mset_id in all_mset_ids for mset_id in TEST_MODELSETS]) + bb_args = inspect.getfullargspec(BabyBrain.__init__).args - msg_args = inspect.getfullargspec(MorphSegGrouped.__init__).args + msg_args = set(SegmentationParameters()._fields) + + if verify_all_modelsets: + modelsets.update('all', force=False) + else: + modelsets.update(TEST_MODELSETS, force=False) - for mset in modelsets.values(): + for mset_info in modelsets.specifications().values(): # Make sure all parameters match the BabyBrain and MorphSegGrouped # signatures + mset = mset_info['brain_params'] assert set(mset.keys()).issubset(bb_args) params = mset.get('params', {}) - assert set(params.keys()).issubset(msg_args) + if type(params) == dict: + assert set(params.keys()).issubset(msg_args) + + share_path = modelsets.LOCAL_MODELSETS_PATH / modelsets.SHARE_PATH + for mset_id, mset_info in modelsets.specifications(local=True).items(): + mset_path = modelsets.LOCAL_MODELSETS_PATH / mset_id + mset = mset_info['brain_params'] + + assert mset_path.is_dir() + params = mset.get('params', {}) + if type(params) != dict and type(params) != SegmentationParameters: + assert ((mset_path / params).is_file() or + (share_path / params).is_file()) # Make sure all model files exist for k, v in mset.items(): if k.endswith('_file'): - assert (MODEL_PATH / v).is_file() - - # Ensure that the default test model is present - assert DEFAULT_MODELSET in modelsets + assert ((mset_path / v).is_file() or + (share_path / v).is_file()) def test_init(bb, imgstack): @@ -92,7 +114,7 @@ def test_init(bb, imgstack): [len(o['centres']) == len(o[k]) for k in o if k in eqlen_outkeys]) -def test_segment(bb, imgstack): +def test_evolve_segment(bb, imgstack): # Test segment with all options enabled output = bb.segment(imgstack, yield_edgemasks=True, @@ -105,6 +127,39 @@ def test_segment(bb, imgstack): [len(o['centres']) == len(o[k]) for k in o if k in eqlen_outkeys]) +def test_prime_segment(bb_prime60, imgs_prime60): + imgstack = np.stack([v['Brightfield'][0] for v in imgs_prime60.values()]) + # Test segment with all options enabled + output = bb_prime60.segment(imgstack, + yield_edgemasks=True, + yield_masks=True, + yield_preds=True, + yield_volumes=True, + refine_outlines=True) + for o in output: + assert all( + [len(o['centres']) == len(o[k]) for k in o if k in eqlen_outkeys]) + + +def test_mm_segment(bb_mmscmos, imgs_mmscmos): + # The sample mother machine sCMOS images have different shapes so cannot + # be stacked, so segment each image separately + for imgpair in imgs_mmscmos.values(): + img = imgpair['Brightfield'][0] + # Test segment with all options enabled except refine_outlines, which is + # not yet available for the cartesian splines used for E. coli + o = bb_mmscmos.segment(img[None, ...], + yield_edgemasks=True, + yield_masks=True, + yield_preds=True, + yield_volumes=True, + refine_outlines=False) + # Expand generator and select first image + o = list(o)[0] + assert all( + [len(o['centres']) == len(o[k]) for k in o if k in eqlen_outkeys]) + + def test_evolve_segment_and_track(bb, imgstack, imgs_evolve60): # Test stateless version output = bb.segment_and_track(imgstack, diff --git a/tests/cnn_test.py b/tests/cnn_test.py index d113088c05b404b49d1fd0fef29d9053cd8a7ba9..8fcfe3bcaf0c8e9d3d0d399cac0766c492e7f1b0 100644 --- a/tests/cnn_test.py +++ b/tests/cnn_test.py @@ -1,12 +1,14 @@ # If you publish results that make use of this software or the Birth Annotator # for Budding Yeast algorithm, please cite: -# Julian M J Pietsch, Alán Muñoz, Diane Adjavon, Ivan B N Clark, Peter S -# Swain, 2021, Birth Annotator for Budding Yeast (in preparation). +# Pietsch, J.M.J., Muñoz, A.F., Adjavon, D.-Y.A., Farquhar, I., Clark, I.B.N., +# and Swain, P.S. (2023). Determining growth rates from bright-field images of +# budding cells through identifying overlaps. eLife. 12:e79812. +# https://doi.org/10.7554/eLife.79812 # # # The MIT License (MIT) # -# Copyright (c) Julian Pietsch, Alán Muñoz and Diane Adjavon 2021 +# Copyright (c) Julian Pietsch, Alán Muñoz and Diane Adjavon 2023 # # Permission is hereby granted, free of charge, to any person obtaining a copy # of this software and associated documentation files (the "Software"), to @@ -40,18 +42,17 @@ def test_evolve_predict(bb_evolve60, imgs_evolve60, save_cnn_predictions, imgstack = np.stack([robust_norm(*v['Brightfield']) for v in imgs_evolve60.values()]) - preds = bb_evolve60.morph_predict(imgstack) - assert len(preds) == len(bb_evolve60.flattener.names()) - assert all([p_out.shape[:3] == imgstack.shape[:3] for p_out in preds]) + preds = list(bb_evolve60.morph_predict(imgstack)) + assert len(preds) == len(imgstack) + npredchan = len(bb_evolve60.flattener.names()) + assert all([len(pred) == npredchan for pred in preds]) + assert all([pred.shape[1:] == imgstack.shape[1:3] for pred in preds]) - morph_preds = split_batch_pred(preds) - assert len(morph_preds) == len(imgstack) - - assert all([pred.max() <= 1 and pred.min() >= 0 for pred in morph_preds]) + assert all([pred.max() <= 1 and pred.min() >= 0 for pred in preds]) if save_cnn_predictions: # Save prediction output as 16 bit tiled png - for pred, (k, v) in zip(morph_preds, imgs_evolve60.items()): + for pred, (k, v) in zip(preds, imgs_evolve60.items()): _, info = v['Brightfield'] info['channel'] = 'cnnpred' save_tiled_image( diff --git a/tests/conftest.py b/tests/conftest.py index cff9f5cd1e1213b7c1e803b4353fa18dedbe2bea..6e0053a3dc45b63be2938ca972d6bf1437edb77c 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,12 +1,14 @@ # If you publish results that make use of this software or the Birth Annotator # for Budding Yeast algorithm, please cite: -# Julian M J Pietsch, Alán Muñoz, Diane Adjavon, Ivan B N Clark, Peter S -# Swain, 2021, Birth Annotator for Budding Yeast (in preparation). +# Pietsch, J.M.J., Muñoz, A.F., Adjavon, D.-Y.A., Farquhar, I., Clark, I.B.N., +# and Swain, P.S. (2023). Determining growth rates from bright-field images of +# budding cells through identifying overlaps. eLife. 12:e79812. +# https://doi.org/10.7554/eLife.79812 # # # The MIT License (MIT) # -# Copyright (c) Julian Pietsch, Alán Muñoz and Diane Adjavon 2021 +# Copyright (c) Julian Pietsch, Alán Muñoz and Diane Adjavon 2023 # # Permission is hereby granted, free of charge, to any person obtaining a copy # of this software and associated documentation files (the "Software"), to @@ -26,9 +28,11 @@ # FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS # IN THE SOFTWARE. from pathlib import Path +import os import baby import pytest +from baby import modelsets as modelsets_module from baby.brain import BabyBrain from baby.io import load_paired_images @@ -48,15 +52,17 @@ def pytest_addoption(parser): "--save-crawler-output", action="store_true", default=False, help="When running crawler tests, save the predictions to tmp dir" ) + parser.addoption( + "--verify-all-modelsets", action="store_true", default=False, + help="Download all models to confirm model set files exist" + ) -@pytest.fixture(scope='session') -def model_dir(): - return baby.model_path() @pytest.fixture(scope='session') def image_dir(): return IMAGE_DIR + @pytest.fixture(scope='session') def save_cnn_predictions(request): return request.config.getoption("--save-cnn-predictions") @@ -72,14 +78,25 @@ def save_segment_outlines(request): return request.config.getoption("--save-segment-outlines") +@pytest.fixture(scope='session') +def verify_all_modelsets(request): + return request.config.getoption("--verify-all-modelsets") + + @pytest.fixture(scope='session') def imgs_evolve60(): return load_paired_images(IMAGE_DIR.glob('evolve_*.png')) @pytest.fixture(scope='session') -def modelsets(): - return baby.modelsets() +def imgs_prime60(): + return load_paired_images(IMAGE_DIR.glob('prime95b_*.png')) + + +@pytest.fixture(scope='session') +def imgs_mmscmos(): + return load_paired_images(IMAGE_DIR.glob('mmsCMOS_*.png')) + @pytest.fixture(scope='session') def tf_session_graph(): @@ -111,15 +128,36 @@ def tf_session_graph(): return tf_session, tf_graph +@pytest.fixture(scope='session') +def modelsets(tmp_path_factory): + envpath = os.environ.get(modelsets_module.ENV_VAR_MODELSETS_PATH) + if not envpath: + envpath = tmp_path_factory.mktemp('modelsets') + print(envpath) + # Patch modelsets module with temporary storage path + modelsets_module.ENV_LOCAL_MODELSETS_PATH = envpath + modelsets_module.LOCAL_MODELSETS_PATH = envpath + modelsets_module.LOCAL_MODELSETS_CACHE = ( + envpath / modelsets_module.MODELSETS_FILENAME) + return modelsets_module + + @pytest.fixture(scope='session') def bb_evolve60(modelsets, tf_session_graph): tf_session, tf_graph = tf_session_graph - return BabyBrain(session=tf_session, graph=tf_graph, - **modelsets['evolve_brightfield_60x_5z']) + return modelsets.get('yeast-alcatras-brightfield-EMCCD-60x-5z', + session=tf_session, graph=tf_graph) @pytest.fixture(scope='module') def bb_prime60(modelsets, tf_session_graph): tf_session, tf_graph = tf_session_graph - return BabyBrain(session=tf_session, graph=tf_graph, - **modelsets['prime95b_brightfield_60x_5z']) + return modelsets.get('yeast-alcatras-brightfield-sCMOS-60x-5z', + session=tf_session, graph=tf_graph) + + +@pytest.fixture(scope='module') +def bb_mmscmos(modelsets, tf_session_graph): + tf_session, tf_graph = tf_session_graph + return modelsets.get('ecoli-mm-phase-sCMOS-100x-1z', + session=tf_session, graph=tf_graph) diff --git a/tests/crawler_test.py b/tests/crawler_test.py index 97424c12ff6202142384f0cf05e42e193e707e52..5b51fed959e659b3367af0c101cbdf10582d8c5d 100644 --- a/tests/crawler_test.py +++ b/tests/crawler_test.py @@ -1,12 +1,14 @@ # If you publish results that make use of this software or the Birth Annotator # for Budding Yeast algorithm, please cite: -# Julian M J Pietsch, Alán Muñoz, Diane Adjavon, Ivan B N Clark, Peter S -# Swain, 2021, Birth Annotator for Budding Yeast (in preparation). +# Pietsch, J.M.J., Muñoz, A.F., Adjavon, D.-Y.A., Farquhar, I., Clark, I.B.N., +# and Swain, P.S. (2023). Determining growth rates from bright-field images of +# budding cells through identifying overlaps. eLife. 12:e79812. +# https://doi.org/10.7554/eLife.79812 # # # The MIT License (MIT) # -# Copyright (c) Julian Pietsch, Alán Muñoz and Diane Adjavon 2021 +# Copyright (c) Julian Pietsch, Alán Muñoz and Diane Adjavon 2023 # # Permission is hereby granted, free of charge, to any person obtaining a copy # of this software and associated documentation files (the "Software"), to diff --git a/tests/images/evolve_testA_preds.png b/tests/images/evolve_testA_preds.png index 23ef2498373b16fa2da9c79fb6aa9bfa1717f03a..3c0d72ce0106dfd6a0d7c1ddce81271384d3a043 100644 Binary files a/tests/images/evolve_testA_preds.png and b/tests/images/evolve_testA_preds.png differ diff --git a/tests/images/evolve_testB_preds.png b/tests/images/evolve_testB_preds.png index 5ab187c2e34472477e43f87a7a977a2e3eb25db6..cca0dd9c7a4755405577d00ced8d5d943197eba8 100644 Binary files a/tests/images/evolve_testB_preds.png and b/tests/images/evolve_testB_preds.png differ diff --git a/tests/images/evolve_testC_preds.png b/tests/images/evolve_testC_preds.png index e6101f64115de1e065b1e0a76a9230e5b91a6c68..d0e9fe8575337a67482f55ae079cc030f3aea2c4 100644 Binary files a/tests/images/evolve_testC_preds.png and b/tests/images/evolve_testC_preds.png differ diff --git a/tests/images/evolve_testD_preds.png b/tests/images/evolve_testD_preds.png index 2d762d03cff83b05a10f4a0350ece867553c17b3..cfdf9ced1437a838f7cca5276ee3af3f6fc23e60 100644 Binary files a/tests/images/evolve_testD_preds.png and b/tests/images/evolve_testD_preds.png differ diff --git a/tests/images/evolve_testE_preds.png b/tests/images/evolve_testE_preds.png index d1a785e82308336068445d01ea188611f2d1fd7b..d1fc0d0bb3feab593dc568f31d1ce73b1d39f57d 100644 Binary files a/tests/images/evolve_testE_preds.png and b/tests/images/evolve_testE_preds.png differ diff --git a/tests/images/evolve_testF_tp3_preds.png b/tests/images/evolve_testF_tp3_preds.png index f46d2b2ccfd49aaf7d3fdae2d432dd76a047f621..0541c2cabef40796b25967063da92d5e96c65780 100644 Binary files a/tests/images/evolve_testF_tp3_preds.png and b/tests/images/evolve_testF_tp3_preds.png differ diff --git a/tests/images/evolve_testF_tp4_preds.png b/tests/images/evolve_testF_tp4_preds.png index e884d8151baf3e720aa69a0aaf528ec73d226fbe..124a87dc7fc35e33d292fd04176f917bfd22da97 100644 Binary files a/tests/images/evolve_testF_tp4_preds.png and b/tests/images/evolve_testF_tp4_preds.png differ diff --git a/tests/images/evolve_testG_tp1_preds.png b/tests/images/evolve_testG_tp1_preds.png index 2d79c356bbe77e9dc25b05ae88286b58cab3c378..3a628f025bcbb1bf311d06b6686237f0a2c6ffd3 100644 Binary files a/tests/images/evolve_testG_tp1_preds.png and b/tests/images/evolve_testG_tp1_preds.png differ diff --git a/tests/images/evolve_testG_tp2_preds.png b/tests/images/evolve_testG_tp2_preds.png index debec78ca7e8f63675e03393c827965e89915ce2..e49496b6654d6b5853b8505b08ebcadac259d8de 100644 Binary files a/tests/images/evolve_testG_tp2_preds.png and b/tests/images/evolve_testG_tp2_preds.png differ diff --git a/tests/images/evolve_testG_tp3_preds.png b/tests/images/evolve_testG_tp3_preds.png index 0653f82ffdfe8ff1c7596766aa1c7e9f0793573c..1cde31e9d99c91f0fddcfeb3955881bfcdbcc2ee 100644 Binary files a/tests/images/evolve_testG_tp3_preds.png and b/tests/images/evolve_testG_tp3_preds.png differ diff --git a/tests/images/evolve_testG_tp4_preds.png b/tests/images/evolve_testG_tp4_preds.png index 3b226b5b8d636df2e3f37cf50799c1669013d75b..b47b3a42eb78f88e9f923903913b58eaed32fbc2 100644 Binary files a/tests/images/evolve_testG_tp4_preds.png and b/tests/images/evolve_testG_tp4_preds.png differ diff --git a/tests/images/evolve_testG_tp5_preds.png b/tests/images/evolve_testG_tp5_preds.png index 4e6af49f652c613bcf8d47a050cba55115ac3d80..8bb83faeeb01a40fdb3b79fbda4a457f97b98e47 100644 Binary files a/tests/images/evolve_testG_tp5_preds.png and b/tests/images/evolve_testG_tp5_preds.png differ diff --git a/tests/images/mmsCMOS_testA_Brightfield.png b/tests/images/mmsCMOS_testA_Brightfield.png new file mode 100644 index 0000000000000000000000000000000000000000..b880c1a5238c03a7edf3ba2b88e1609f06585559 Binary files /dev/null and b/tests/images/mmsCMOS_testA_Brightfield.png differ diff --git a/tests/images/mmsCMOS_testA_segoutlines.png b/tests/images/mmsCMOS_testA_segoutlines.png new file mode 100644 index 0000000000000000000000000000000000000000..132839155dddf20c90a4ad18998ee71a1feecd54 Binary files /dev/null and b/tests/images/mmsCMOS_testA_segoutlines.png differ diff --git a/tests/images/mmsCMOS_testB_Brightfield.png b/tests/images/mmsCMOS_testB_Brightfield.png new file mode 100644 index 0000000000000000000000000000000000000000..dc00c401f7d768c1419b67d7005f042f69a66975 Binary files /dev/null and b/tests/images/mmsCMOS_testB_Brightfield.png differ diff --git a/tests/images/mmsCMOS_testB_segoutlines.png b/tests/images/mmsCMOS_testB_segoutlines.png new file mode 100644 index 0000000000000000000000000000000000000000..7a26bf512e9fa49712d4f93c3eedc2beda5a0f6d Binary files /dev/null and b/tests/images/mmsCMOS_testB_segoutlines.png differ diff --git a/tests/io_test.py b/tests/io_test.py index e9d5c1b5a2fc4c6e685b979cafa491ddf99a5022..c40a587484393e3e5350b2979912f8d82c0d548a 100644 --- a/tests/io_test.py +++ b/tests/io_test.py @@ -1,12 +1,14 @@ # If you publish results that make use of this software or the Birth Annotator # for Budding Yeast algorithm, please cite: -# Julian M J Pietsch, Alán Muñoz, Diane Adjavon, Ivan B N Clark, Peter S -# Swain, 2021, Birth Annotator for Budding Yeast (in preparation). +# Pietsch, J.M.J., Muñoz, A.F., Adjavon, D.-Y.A., Farquhar, I., Clark, I.B.N., +# and Swain, P.S. (2023). Determining growth rates from bright-field images of +# budding cells through identifying overlaps. eLife. 12:e79812. +# https://doi.org/10.7554/eLife.79812 # # # The MIT License (MIT) # -# Copyright (c) Julian Pietsch, Alán Muñoz and Diane Adjavon 2021 +# Copyright (c) Julian Pietsch, Alán Muñoz and Diane Adjavon 2023 # # Permission is hereby granted, free of charge, to any person obtaining a copy # of this software and associated documentation files (the "Software"), to @@ -38,6 +40,7 @@ def test_tiled_save_load(tmp_path): save_tiled_image(img, testfile, info=info) loaded_img, loaded_info = load_tiled_image(testfile) + assert loaded_img.dtype == 'uint16' assert (loaded_img == img).all() assert 'experimentID' in loaded_info assert loaded_info['experimentID'] == info['experimentID'] diff --git a/tests/modelsets_test.py b/tests/modelsets_test.py new file mode 100644 index 0000000000000000000000000000000000000000..41b42beb37333095d69f99e362eda1bf22dfe23a --- /dev/null +++ b/tests/modelsets_test.py @@ -0,0 +1,123 @@ +# If you publish results that make use of this software or the Birth Annotator +# for Budding Yeast algorithm, please cite: +# Pietsch, J.M.J., Muñoz, A.F., Adjavon, D.-Y.A., Farquhar, I., Clark, I.B.N., +# and Swain, P.S. (2023). Determining growth rates from bright-field images of +# budding cells through identifying overlaps. eLife. 12:e79812. +# https://doi.org/10.7554/eLife.79812 +# +# +# The MIT License (MIT) +# +# Copyright (c) Julian Pietsch, Alán Muñoz and Diane Adjavon 2023 +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to +# deal in the Software without restriction, including without limitation the +# rights to use, copy, modify, merge, publish, distribute, sublicense, and/or +# sell copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING +# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS +# IN THE SOFTWARE. +import pytest +from pathlib import Path +import os + +from baby import BabyBrain + + +DEFAULT_MODELSET = 'yeast-alcatras-brightfield-sCMOS-60x-5z' + + +@pytest.fixture(scope='module') +def remote_modelsets(modelsets): + return modelsets.remote_modelsets() + + +def test_remote_modelsets_available(remote_modelsets): + assert 'models' in remote_modelsets + assert 'shared' in remote_modelsets + + +def test_local_cache_dir(modelsets): + localpath = modelsets.LOCAL_MODELSETS_PATH + envpath = os.environ.get(modelsets.ENV_VAR_MODELSETS_PATH) + if not envpath: + assert localpath != modelsets.DEFAULT_LOCAL_MODELSETS_PATH + else: + assert localpath == Path(envpath) + modelsets._ensure_local_path() + assert localpath.exists() + + +def test_meta(modelsets, remote_modelsets): + modelsets.ids() # trigger creation of local model sets cache + assert modelsets.LOCAL_MODELSETS_CACHE.exists() + + # Listed models match those on remote + remote_keys = remote_modelsets['models'].keys() + assert len(set(modelsets.ids()).difference(remote_keys)) == 0 + + # Listed meta data matches that on remote + remote_meta = {k: v['meta'] for k, v in + remote_modelsets['models'].items()} + assert modelsets.meta() == remote_meta + + +def test_update(modelsets, remote_modelsets): + remote_default = remote_modelsets['models'][DEFAULT_MODELSET] + modelsets.update([DEFAULT_MODELSET]) + localpath = modelsets.LOCAL_MODELSETS_PATH / DEFAULT_MODELSET + assert localpath.is_dir() + for filename in remote_default['files']: + assert (localpath / filename).is_file() + + +def test_get_params(modelsets, remote_modelsets): + remote_default = remote_modelsets['models'][DEFAULT_MODELSET] + local_params = modelsets.get_params(DEFAULT_MODELSET) + assert local_params == remote_default['brain_params'] + localpath = modelsets.LOCAL_MODELSETS_PATH / DEFAULT_MODELSET + print(localpath) + + # Test auto-update for missing model set file + (localpath / modelsets.MODELSET_FILENAME).unlink() + modelsets.get_params(DEFAULT_MODELSET) + assert (localpath / modelsets.MODELSET_FILENAME).exists() + + # Test auto-update for missing model file + target_file = [f for f in remote_default['files'] + if f != modelsets.MODELSET_FILENAME][0] + target_file = localpath / target_file + target_file.unlink() + modelsets.get_params(DEFAULT_MODELSET) + assert target_file.exists() + + +def test_resolve(modelsets): + params = modelsets.get_params(DEFAULT_MODELSET) + localpath = modelsets.LOCAL_MODELSETS_PATH / DEFAULT_MODELSET + sharepath = modelsets.LOCAL_MODELSETS_PATH / modelsets.SHARE_PATH + localtest = params['morph_model_file'] + sharetest = params['celltrack_model_file'] + assert (localpath / localtest).is_file() + assert (sharepath / sharetest).is_file() + assert ((localpath / localtest) == + modelsets.resolve(localtest, DEFAULT_MODELSET)) + assert ((sharepath / sharetest) == + modelsets.resolve(sharetest, DEFAULT_MODELSET)) + + +def test_get(modelsets): + bb = modelsets.get(DEFAULT_MODELSET) + assert type(bb) == BabyBrain + + diff --git a/tests/parallel_test.py b/tests/parallel_test.py index f38e205fcc9d13450f037997c7e2e945bff8af65..0f01d398a578b91fee08d758e5bdb824bac65f35 100644 --- a/tests/parallel_test.py +++ b/tests/parallel_test.py @@ -1,12 +1,14 @@ # If you publish results that make use of this software or the Birth Annotator # for Budding Yeast algorithm, please cite: -# Julian M J Pietsch, Alán Muñoz, Diane Adjavon, Ivan B N Clark, Peter S -# Swain, 2021, Birth Annotator for Budding Yeast (in preparation). +# Pietsch, J.M.J., Muñoz, A.F., Adjavon, D.-Y.A., Farquhar, I., Clark, I.B.N., +# and Swain, P.S. (2023). Determining growth rates from bright-field images of +# budding cells through identifying overlaps. eLife. 12:e79812. +# https://doi.org/10.7554/eLife.79812 # # # The MIT License (MIT) # -# Copyright (c) Julian Pietsch, Alán Muñoz and Diane Adjavon 2021 +# Copyright (c) Julian Pietsch, Alán Muñoz and Diane Adjavon 2023 # # Permission is hereby granted, free of charge, to any person obtaining a copy # of this software and associated documentation files (the "Software"), to @@ -28,34 +30,30 @@ import pickle from collections import namedtuple, Counter from os.path import isfile +import json +from pathlib import Path import baby import numpy as np import pytest -from baby.brain import default_params from baby.io import load_paired_images -from baby.morph_thresh_seg import MorphSegGrouped +from baby.morph_thresh_seg import MorphSegGrouped, SegmentationParameters from baby.preprocessing import raw_norm, SegmentationFlattening from baby.tracker.core import MasterTracker +from baby.utils import as_python_object from joblib import Parallel, delayed, parallel_backend -MODEL_DIR = baby.model_path() - -def resolve_file(filename): - if not isfile(filename): - filename = MODEL_DIR / filename - assert isfile(filename) - return filename +DEFAULT_MODELSET = 'yeast-alcatras-brightfield-EMCCD-60x-5z' @pytest.fixture(scope='module') def evolve60env(modelsets, image_dir): - mset = modelsets['evolve_brightfield_60x_5z'] + mset = modelsets.get_params(DEFAULT_MODELSET) # Load flattener - ff = resolve_file(mset['flattener_file']) + ff = modelsets.resolve(mset['flattener_file'], DEFAULT_MODELSET) flattener = SegmentationFlattening(ff) tnames = flattener.names() @@ -64,10 +62,16 @@ def evolve60env(modelsets, image_dir): i_bud = tnames.index(bud_target) # Load BabyBrain param defaults - params = default_params.copy() - params.update(mset.get('params', {})) + params = mset['params'] + if type(params) == dict: + params = SegmentationParameters(**params) + if type(params) != SegmentationParameters: + param_file = modelsets.resolve(mset['params'], DEFAULT_MODELSET) + with open(param_file, 'rt') as f: + params = json.load(f, object_hook=as_python_object) + assert type(params) == SegmentationParameters - segmenter = MorphSegGrouped(flattener, return_masks=True, **params) + segmenter = MorphSegGrouped(flattener, params=params, return_masks=True) # Load CNN outputs impairs = load_paired_images(image_dir.glob('evolve_test[FG]_tp*.png'), @@ -78,10 +82,10 @@ def evolve60env(modelsets, image_dir): ]), sorted([k for k in impairs.keys() if k.startswith('evolve_testG')])) # Load the celltrack and budassign models - ctm_file = resolve_file(mset['celltrack_model_file']) + ctm_file = modelsets.resolve(mset['celltrack_model_file'], DEFAULT_MODELSET) with open(ctm_file, 'rb') as f: ctm = pickle.load(f) - bam_file = resolve_file(mset['budassign_model_file']) + bam_file = modelsets.resolve(mset['budassign_model_file'], DEFAULT_MODELSET) with open(bam_file, 'rb') as f: bam = pickle.load(f) diff --git a/tests/segment_test.py b/tests/segment_test.py index 1dc16db33d0b4f591a698d813c9da82f66632a36..baada9265ef233b769bbb45f9bf0f6d298d41acd 100644 --- a/tests/segment_test.py +++ b/tests/segment_test.py @@ -1,12 +1,14 @@ # If you publish results that make use of this software or the Birth Annotator # for Budding Yeast algorithm, please cite: -# Julian M J Pietsch, Alán Muñoz, Diane Adjavon, Ivan B N Clark, Peter S -# Swain, 2021, Birth Annotator for Budding Yeast (in preparation). +# Pietsch, J.M.J., Muñoz, A.F., Adjavon, D.-Y.A., Farquhar, I., Clark, I.B.N., +# and Swain, P.S. (2023). Determining growth rates from bright-field images of +# budding cells through identifying overlaps. eLife. 12:e79812. +# https://doi.org/10.7554/eLife.79812 # # # The MIT License (MIT) # -# Copyright (c) Julian Pietsch, Alán Muñoz and Diane Adjavon 2021 +# Copyright (c) Julian Pietsch, Alán Muñoz and Diane Adjavon 2023 # # Permission is hereby granted, free of charge, to any person obtaining a copy # of this software and associated documentation files (the "Software"), to @@ -28,6 +30,9 @@ import pytest from os.path import isfile +from pathlib import Path +import json +import inspect import numpy as np from scipy.ndimage import binary_fill_holes from itertools import chain @@ -37,13 +42,17 @@ from baby.io import load_paired_images, save_tiled_image from baby.errors import BadParam from baby.preprocessing import (raw_norm, seg_norm, dwsquareconn, SegmentationFlattening) -from baby.brain import default_params -from baby.morph_thresh_seg import MorphSegGrouped, SegmentationOutput +from baby.morph_thresh_seg import (MorphSegGrouped, SegmentationOutput, + SegmentationParameters) +from baby.utils import as_python_object +from baby import segmentation from baby.segmentation import morph_seg_grouped from baby.performance import (calc_IoUs, best_IoU, calc_AP, flattener_seg_probs) -#from .conftest import MODEL_DIR, IMAGE_DIR + +DEFAULT_MODELSET = 'yeast-alcatras-brightfield-EMCCD-60x-5z' + # Tuple for variables needed to test segmentation SegmentationEnv = namedtuple( @@ -51,6 +60,19 @@ SegmentationEnv = namedtuple( ['flattener', 'cparams', 'fparams', 'cnn_out', 'truth', 'imnames']) +# Old default parameters +DEFAULT_PARAMETERS = SegmentationParameters( + interior_threshold=(0.7, 0.5, 0.5), + nclosing=(1, 0, 0), + nopening=(1, 0, 0), + connectivity=(2, 2, 1), + pedge_thresh=0.001, + fit_radial=True, + edge_sub_dilations=1, + use_group_thresh=True, + group_thresh_expansion=0.1) + + def compare_edges_and_masks(edges, masks): if len(edges) == 0 and len(masks) == 0: return 1, 1 @@ -95,24 +117,33 @@ def run_performance_checks(seg_outputs, cnn_outputs, flattener, truth): @pytest.fixture(scope='module') -def evolve60env(modelsets, model_dir, image_dir): - mset = modelsets['evolve_brightfield_60x_5z'] +def evolve60env(modelsets, image_dir): + mset = modelsets.get_params(DEFAULT_MODELSET) # Load flattener - ff = mset['flattener_file'] - if not isfile(ff): - ff = model_dir / ff - assert isfile(ff) + ff = modelsets.resolve(mset['flattener_file'], DEFAULT_MODELSET) flattener = SegmentationFlattening(ff) # Load BabyBrain param defaults - cparams = default_params.copy() - cparams.update(mset.get('params', {})) + params = mset['params'] + if type(params) == dict: + params = SegmentationParameters(**params) + if type(params) != SegmentationParameters: + param_file = modelsets.resolve(mset['params'], DEFAULT_MODELSET) + with open(param_file, 'rt') as f: + params = json.load(f, object_hook=as_python_object) + assert type(params) == SegmentationParameters + cparams = params # Convert to params compatible with morph_seg_grouped - fparams = cparams.copy() - del fparams['edge_sub_dilations'] + msg_args = inspect.signature(morph_seg_grouped).parameters.keys() + fparams = cparams._asdict() + fparams = {k: v for k, v in fparams.items() if k in msg_args} fparams['ingroup_edge_segment'] = True + fparams['containment_func'] = getattr(segmentation, + fparams['containment_func']) + if fparams['cellgroups'] is None: + fparams['cellgroups'] = ['large', 'medium', 'small'] # Load CNN outputs impairs = load_paired_images(image_dir.glob('evolve_*.png'), @@ -197,9 +228,9 @@ def test_segmenter_bbparams_empty(evolve60env): params = evolve60env.cparams ntargets = len(flattener.names()) segmenter = MorphSegGrouped(flattener, + params=params, return_masks=True, - return_coords=True, - **params) + return_coords=True) out = segmenter.segment(np.zeros((ntargets, 81, 81))) assert tuple(len(o) for o in out) == (0, 0, 0, 0) @@ -254,9 +285,9 @@ def test_segfunc_bbparams_preds(evolve60env, save_segoutlines): def test_segmenter_bbparams_preds(evolve60env, save_segoutlines): flattener, params, _, cnn_out, truth, imnames = evolve60env segmenter = MorphSegGrouped(flattener, + params=params, return_masks=True, - return_coords=True, - **params) + return_coords=True) seg_outputs = [segmenter.segment(pred) for pred in cnn_out] save_segoutlines(seg_outputs, imnames) @@ -290,9 +321,9 @@ def test_segfunc_refined_preds(evolve60env, save_segoutlines): def test_segmenter_refined_preds(evolve60env, save_segoutlines): flattener, params, _, cnn_out, truth, imnames = evolve60env segmenter = MorphSegGrouped(flattener, + params=params, return_masks=True, - return_coords=True, - **params) + return_coords=True) seg_outputs = [ segmenter.segment(pred, refine_outlines=True) for pred in cnn_out ] diff --git a/tests/tracker_test.py b/tests/tracker_test.py index 413f7e1c68a0771f8c507d7e59964b40de91d0c4..3a36c03fb4d9570bb04e79a249532e4436fb9e7a 100644 --- a/tests/tracker_test.py +++ b/tests/tracker_test.py @@ -1,12 +1,14 @@ # If you publish results that make use of this software or the Birth Annotator # for Budding Yeast algorithm, please cite: -# Julian M J Pietsch, Alán Muñoz, Diane Adjavon, Ivan B N Clark, Peter S -# Swain, 2021, Birth Annotator for Budding Yeast (in preparation). +# Pietsch, J.M.J., Muñoz, A.F., Adjavon, D.-Y.A., Farquhar, I., Clark, I.B.N., +# and Swain, P.S. (2023). Determining growth rates from bright-field images of +# budding cells through identifying overlaps. eLife. 12:e79812. +# https://doi.org/10.7554/eLife.79812 # # # The MIT License (MIT) # -# Copyright (c) Julian Pietsch, Alán Muñoz and Diane Adjavon 2021 +# Copyright (c) Julian Pietsch, Alán Muñoz and Diane Adjavon 2023 # # Permission is hereby granted, free of charge, to any person obtaining a copy # of this software and associated documentation files (the "Software"), to @@ -28,34 +30,29 @@ import pickle from collections import namedtuple, Counter from os.path import isfile +import json +from pathlib import Path import baby import numpy as np import pytest -from baby.brain import default_params from baby.io import load_paired_images -from baby.morph_thresh_seg import MorphSegGrouped +from baby.morph_thresh_seg import MorphSegGrouped, SegmentationParameters from baby.preprocessing import raw_norm, SegmentationFlattening from baby.tracker.core import MasterTracker +from baby.utils import as_python_object -MODEL_DIR = baby.model_path() TrackerEnv = namedtuple('TrackerEnv', ['masks', 'p_budneck', 'p_bud']) - - -def resolve_file(filename): - if not isfile(filename): - filename = MODEL_DIR / filename - assert isfile(filename) - return filename +DEFAULT_MODELSET = 'yeast-alcatras-brightfield-EMCCD-60x-5z' @pytest.fixture(scope='module') def evolve60env(modelsets, image_dir): - mset = modelsets['evolve_brightfield_60x_5z'] + mset = modelsets.get_params(DEFAULT_MODELSET) # Load flattener - ff = resolve_file(mset['flattener_file']) + ff = modelsets.resolve(mset['flattener_file'], DEFAULT_MODELSET) flattener = SegmentationFlattening(ff) tnames = flattener.names() @@ -64,10 +61,16 @@ def evolve60env(modelsets, image_dir): i_bud = tnames.index(bud_target) # Load BabyBrain param defaults - params = default_params.copy() - params.update(mset.get('params', {})) + params = mset['params'] + if type(params) == dict: + params = SegmentationParameters(**params) + if type(params) != SegmentationParameters: + param_file = modelsets.resolve(mset['params'], DEFAULT_MODELSET) + with open(param_file, 'rt') as f: + params = json.load(f, object_hook=as_python_object) + assert type(params) == SegmentationParameters - segmenter = MorphSegGrouped(flattener, return_masks=True, **params) + segmenter = MorphSegGrouped(flattener, params=params, return_masks=True) # Load CNN outputs impairs = load_paired_images(image_dir.glob('evolve_test[FG]_tp*.png'), @@ -94,10 +97,10 @@ def evolve60env(modelsets, image_dir): trkF, trkG = trks # Load the celltrack and budassign models - ctm_file = resolve_file(mset['celltrack_model_file']) + ctm_file = modelsets.resolve(mset['celltrack_model_file'], DEFAULT_MODELSET) with open(ctm_file, 'rb') as f: ctm = pickle.load(f) - bam_file = resolve_file(mset['budassign_model_file']) + bam_file = modelsets.resolve(mset['budassign_model_file'], DEFAULT_MODELSET) with open(bam_file, 'rb') as f: bam = pickle.load(f)