Skip to content
Snippets Groups Projects

Compare revisions

Changes are shown as if the source revision was being merged into the target revision. Learn more about comparing revisions.

Source

Select target project
No results found

Target

Select target project
  • swain-lab/aliby/aliby-mirror
  • swain-lab/aliby/alibylite
2 results
Show changes
Showing
with 4801 additions and 222 deletions
#!/usr/bin/env jupyter
import argparse
from agora.utils.cast import _str_to_int
from aliby.pipeline import Pipeline, PipelineParameters
def run():
    """
    Run a default microscopy analysis pipeline.

    Parse command-line arguments, overlay them on a set of default parameter
    values, then construct and execute the pipeline with the result. Every
    default parameter is exposed as an optional ``--<name>`` argument.
    Values given on the command line are passed through ``_str_to_int``
    before being stored.

    Returns
    -------
    None
    """
    parser = argparse.ArgumentParser(
        prog="aliby-run",
        description="Run a default microscopy analysis pipeline",
    )
    param_values = {
        "expt_id": None,
        "distributed": 2,
        "tps": 2,
        "directory": "./data",
        "filter": 0,
        "host": None,
        "username": None,
        "password": None,
    }
    for k in param_values:
        parser.add_argument(f"--{k}", action="store")
    args = parser.parse_args()
    for k in param_values:
        raw_value = getattr(args, k)
        if raw_value is None:
            # Argument not supplied on the command line: keep the default.
            continue
        passed_value = _str_to_int(raw_value)
        # Test against None rather than truthiness so that explicit falsy
        # values (e.g. ``--distributed 0``) still override the defaults.
        if passed_value is not None:
            param_values[k] = passed_value
    params = PipelineParameters.default(general=param_values)
    p = Pipeline(params)
    p.run()
import numpy as np
from time import perf_counter
"""
Neural network initialisation.
"""
from pathlib import Path
from time import perf_counter
import numpy as np
import tensorflow as tf
from agora.io.writer import DynamicWriter
......@@ -21,7 +23,9 @@ def initialise_tf(version):
for gpu in gpus:
tf.config.experimental.set_memory_growth(gpu, True)
logical_gpus = tf.config.experimental.list_logical_devices("GPU")
print(len(gpus), "Physical GPUs,", len(logical_gpus), "Logical GPUs")
print(
len(gpus), "Physical GPUs,", len(logical_gpus), "Logical GPUs"
)
return None
......@@ -33,8 +37,6 @@ def timer(func, *args, **kwargs):
################## CUSTOM OBJECTS ##################################
class ModelPredictor:
"""Generic object that takes a NN and returns the prediction.
......@@ -49,12 +51,16 @@ class ModelPredictor:
def get_data(self, tp):
# Change axes to X,Y,Z rather than Z,Y,X
return self.tiler.get_tp_data(tp, self.bf_channel).swapaxes(1, 3).swapaxes(1, 2)
return (
self.tiler.get_tp_data(tp, self.bf_channel)
.swapaxes(1, 3)
.swapaxes(1, 2)
)
def format_result(self, result, tp):
return {self.name: result, "timepoints": [tp] * len(result)}
def run_tp(self, tp, **kwargs):
def run_tp(self, tp):
"""Simulating processing time with sleep"""
# Access the image
segmentation = self.model.predict(self.get_data(tp))
......@@ -64,34 +70,8 @@ class ModelPredictor:
class ModelPredictorWriter(DynamicWriter):
    """Writer that declares the datasets produced by a model predictor."""

    def __init__(self, file, name, shape, dtype):
        """
        Parameters
        ----------
        file
            Forwarded unchanged to ``DynamicWriter.__init__``.
        name
            Key under which the prediction dataset is declared.
        shape
            Shape specification stored alongside ``dtype`` for ``name``.
        dtype
            Data type stored alongside ``shape`` for ``name``.
        """
        # ``super`` must be called to obtain the proxy object; the bare
        # ``super.__init__(file)`` raised a TypeError.
        super().__init__(file)
        self.datatypes = {
            name: (shape, dtype),
            "timepoint": ((None,), np.uint16),
        }
        # NOTE(review): ``self.name`` is not assigned in this class; it is
        # presumably provided by DynamicWriter - confirm.
        self.group = f"{self.name}_info"
class Saver:
    """Save per-trap image data to ``.npy`` files, one folder per channel."""

    # Maps a channel index to the name of the sub-directory it is saved in.
    channel_names = {0: "BrightField", 1: "GFP"}

    def __init__(self, tiler, save_directory, pos_name):
        """
        Parameters
        ----------
        tiler
            Object exposing ``get_tp_data(tp, channel)``.
        save_directory
            Root directory under which the channel folders are created.
        pos_name
            Position name, used as the prefix of every saved file.
        """
        self.tiler = tiler
        self.name = pos_name
        self.save_dir = Path(save_directory)

    def channel_dir(self, index):
        """Return the output directory for channel ``index``, creating it."""
        ch_dir = self.save_dir / self.channel_names[index]
        # Create missing parents too and tolerate an already-existing
        # directory, instead of the check-then-create sequence that failed
        # when ``save_dir`` itself did not exist yet.
        ch_dir.mkdir(parents=True, exist_ok=True)
        return ch_dir

    def get_data(self, tp, ch):
        """Fetch data for one time point and channel with axis 1 moved last."""
        return self.tiler.get_tp_data(tp, ch).swapaxes(1, 3).swapaxes(1, 2)

    def cache(self, tp):
        """Write every trap of every known channel at time point ``tp``."""
        for ch in self.channel_names:
            ch_dir = self.channel_dir(ch)
            data = self.get_data(tp, ch)
            # The first axis of ``data`` indexes traps.
            for tid, trap in enumerate(data):
                np.save(ch_dir / f"{self.name}_{tid}_{tp}.npy", trap)
        return
"""
Methods and classes to fetch image data and metadata.
"""
#!/usr/bin/env python3
"""
Dataset is a group of classes to manage multiple types of experiments:
- Remote experiments on an OMERO server (located in src/aliby/io/omero.py)
- Local experiments in a multidimensional OME-TIFF image containing the metadata
- Local experiments in a directory containing multiple positions in independent images with or without metadata
"""
import os
import shutil
import time
import typing as t
from abc import ABC, abstractproperty, abstractmethod
from pathlib import Path
from agora.io.bridge import BridgeH5
from aliby.io.image import ImageLocalOME
def dispatch_dataset(expt_id: t.Union[int, str, Path], **kwargs):
    """
    Find paths to the data.

    Connects to OMERO if data is remotely available.

    Parameters
    ----------
    expt_id: int, str or Path
        To identify the data, either an OMERO ID or an OME-TIFF file or a
        local directory.
    **kwargs
        Forwarded to the OMERO ``Dataset`` constructor; not used for local
        data.

    Returns
    -------
    A Dataset instance, either network-dependent or local.

    Raises
    ------
    Warning
        If ``expt_id`` is neither an integer nor a path-like value.
    """
    if isinstance(expt_id, int):
        # data available online; import lazily so that local-only use does
        # not need the OMERO module
        from aliby.io.omero import Dataset

        return Dataset(expt_id, **kwargs)
    if isinstance(expt_id, (str, Path)):
        # data available locally
        expt_path = Path(expt_id)
        if expt_path.is_dir():
            # data in multiple folders
            return DatasetLocalDir(expt_path)
        # data in one folder as OME-TIFF files
        return DatasetLocalOME(expt_path)
    raise Warning(f"{expt_id} is an invalid expt_id")
class DatasetLocalABC(ABC):
    """
    Abstract Base class to find local files, either OME-XML or raw images.

    Subclasses must provide ``date`` and ``get_images``.
    """

    # File-name endings recognised as image data and as metadata/log files.
    _valid_suffixes = ("tiff", "png", "zarr")
    _valid_meta_suffixes = ("txt", "log")

    def __init__(self, dpath: t.Union[str, Path], *args, **kwargs):
        self.path = Path(dpath)

    def __enter__(self):
        return self

    def __exit__(self, *exc):
        # Nothing to release; returning False lets exceptions propagate.
        return False

    @property
    def dataset(self):
        """Return the path of the data set."""
        return self.path

    @property
    def name(self):
        """Return the final component of the data set path."""
        return self.path.name

    @property
    def unique_name(self):
        """Return the final component of the data set path."""
        return self.path.name

    @property
    def files(self):
        """Return a dictionary with any available metadata files.

        The search is recursive and the result is cached on first access.
        Only regular files are returned: a directory whose name happens to
        end in a metadata suffix cannot be copied by ``cache_logs``.
        """
        if not hasattr(self, "_files"):
            meta_suffixes = tuple(self._valid_meta_suffixes)
            self._files = {
                f: f
                for f in self.path.rglob("*")
                if f.is_file() and str(f).endswith(meta_suffixes)
            }
        return self._files

    def cache_logs(self, root_dir):
        """Copy metadata files to results folder.

        ``root_dir`` must support the ``/`` operator (e.g. a ``Path``).
        Files are copied flat, keeping only their base name. Returns True.
        """
        for name, annotation in self.files.items():
            shutil.copy(annotation, root_dir / name.name)
        return True

    @property
    @abstractmethod
    def date(self):
        pass

    @abstractmethod
    def get_images(self):
        pass
class DatasetLocalDir(DatasetLocalABC):
    """Find paths to a data set, comprising multiple images in different folders."""

    def __init__(self, dpath: t.Union[str, Path], *args, **kwargs):
        super().__init__(dpath)

    @property
    def date(self):
        """Return the folder's last-modification date as ``YYYYMMDD``.

        The timestamp is converted with ``time.localtime`` directly rather
        than being formatted by ``time.ctime`` and re-parsed by
        ``time.strptime``, which depends on the active locale's day and
        month names.
        """
        modified = os.path.getmtime(self.path)
        return time.strftime("%Y%m%d", time.localtime(modified))

    def get_images(self):
        """Return a dictionary of folder or file names and their paths.

        An entry is kept when it is a directory holding at least one file
        with a valid image suffix, or when its own suffix is a valid image
        suffix.

        FUTURE 3.12 use pathlib is_junction to pick Dir or File
        """
        images = {
            item.name: item
            for item in self.path.glob("*/")
            # Parentheses spell out the original and/or precedence.
            if (
                item.is_dir()
                and any(
                    path
                    for suffix in self._valid_suffixes
                    for path in item.glob(f"*.{suffix}")
                )
            )
            or item.suffix[1:] in self._valid_suffixes
        }
        return images
class DatasetLocalOME(DatasetLocalABC):
    """Locate image files directly inside a folder (OME-TIFF layout)."""

    def __init__(self, dpath: t.Union[str, Path], *args, **kwargs):
        super().__init__(dpath)
        n_images = len(self.get_images())
        assert (
            n_images
        ), f"No valid files found. Formats are {self._valid_suffixes}"

    @property
    def date(self):
        """Read the date from the metadata of the first image found."""
        image_paths = list(self.get_images().values())
        first_image = ImageLocalOME(image_paths[0])
        return first_image.date

    def get_images(self):
        """Map each image file name to its path, given as a string."""
        found = {}
        for suffix in self._valid_suffixes:
            for image_file in self.path.glob(f"*.{suffix}"):
                found[image_file.name] = str(image_file)
        return found
#!/usr/bin/env python3
"""
Image: Loads images and registers them.
Image instances loads images from a specified directory into an object that
also contains image properties such as name and metadata. Pixels from images
are stored in dask arrays; the standard way is to store them in 5-dimensional
arrays: T(ime point), C(channel), Z(-stack), Y, X.
This module consists of a base Image class (BaseLocalImage). ImageLocalOME
handles local OMERO images. ImageDir handles cases in which images are split
into directories, with each time point and channel having its own image file.
ImageDummy is a dummy class for silent failure testing.
"""
import typing as t
from abc import ABC, abstractmethod, abstractproperty
from datetime import datetime
from pathlib import Path
import dask.array as da
import numpy as np
import xmltodict
import zarr
from dask.array.image import imread
from importlib_resources import files
from tifffile import TiffFile
from agora.io.metadata import dir_to_meta, dispatch_metadata_parser
def get_examples_dir():
    """Return the ``examples/tiler`` directory that holds the dummy tiler image."""
    package_root = files("aliby")
    # The examples folder sits two levels above the package directory.
    return package_root.parent.parent / "examples" / "tiler"
def instantiate_image(
    source: t.Union[str, int, t.Dict[str, str], Path], **kwargs
):
    """Instantiate the appropriate image class for ``source``.

    The class is chosen by ``dispatch_image`` and then called with
    ``source`` and any extra keyword arguments.

    Parameters
    ----------
    source : t.Union[str, int, t.Dict[str, str], Path]
        Image identifier
    **kwargs
        Forwarded to the constructor of the selected image class.

    Examples
    --------
    image_path = "path/to/image"
    with instantiate_image(image_path) as img:
        print(img.data, img.metadata)
    """
    image_class = dispatch_image(source)
    return image_class(source, **kwargs)
def dispatch_image(source: t.Union[str, int, t.Dict[str, str], Path]):
    """
    Pick the appropriate Image class depending on the source of data.

    Parameters
    ----------
    source : t.Union[str, int, t.Dict[str, str], Path]
        An integer selects the OMERO ``Image`` class; a dict selects
        ``ImageDir``; a path to a directory selects ``ImageZarr`` when it
        ends in ``.zarr`` and ``ImageDir`` otherwise; a path to a file
        selects ``ImageLocalOME``.

    Returns
    -------
    The selected class (not an instance).

    Raises
    ------
    Exception
        If ``source`` matches none of the cases above.
    """
    if isinstance(source, (int, np.integer)):
        # Imported lazily so local-only use does not need the OMERO module.
        from aliby.io.omero import Image

        return Image
    if isinstance(source, dict):
        # A dict cannot be turned into a Path, so it must be handled before
        # any suffix inspection.
        return ImageDir
    if isinstance(source, (str, Path)):
        path = Path(source)
        if path.is_dir():
            return ImageZarr if path.suffix == ".zarr" else ImageDir
        if path.is_file():
            return ImageLocalOME
    raise Exception(f"Invalid data source at {source}")
class BaseLocalImage(ABC):
    """
    Common base for local image classes.

    Stores the image path, acts as a context manager, and offers shared
    helpers for rechunking pixel data and loading metadata.
    """

    _default_dimorder = "tczyx"

    def __init__(self, path: t.Union[str, Path]):
        # If directory, assume contents are naturally sorted
        self.path = Path(path)

    def __enter__(self):
        return self

    def __exit__(self, *exc):
        # Report whatever exception details were passed, then let the
        # exception propagate by returning False.
        for detail in (item for item in exc if item is not None):
            print(detail)
        return False

    def rechunk_data(self, img):
        """Rechunk ``img`` into single-plane chunks sized from the metadata."""
        plane_chunks = (
            1,
            1,
            1,
            self._meta["size_y"],
            self._meta["size_x"],
        )
        self._rechunked_img = da.rechunk(img, chunks=plane_chunks)
        return self._rechunked_img

    @abstractmethod
    def get_data_lazy(self) -> da.Array:
        pass

    @property
    @abstractmethod
    def name(self):
        pass

    @property
    @abstractmethod
    def dimorder(self):
        pass

    @property
    def data(self):
        """Return the result of ``get_data_lazy``."""
        return self.get_data_lazy()

    @property
    def metadata(self):
        """Return the stored metadata."""
        return self._meta

    def set_meta(self):
        """Load metadata for ``self.path`` through the parser dispatcher."""
        self._meta = dispatch_metadata_parser(self.path)
class ImageLocalOME(BaseLocalImage):
    """
    Local OMERO Image class.

    This is a derivative Image class. It fetches an image from OMEXML data
    format, in which a multidimensional tiff image contains the metadata.
    """

    def __init__(self, path: str, dimorder=None):
        """
        Parameters
        ----------
        path : str
            Location of the OME-TIFF file.
        dimorder : optional
            Accepted for interface compatibility; currently not used.
        """
        super().__init__(path)
        self._id = str(path)
        # ``name``, ``date``, ``dimorder`` and ``rechunk_data`` all read
        # ``self._meta``, so the metadata has to be loaded on construction.
        self.set_meta()

    def set_meta(self):
        """Parse the OME-XML header and store it, with sizes, in ``_meta``.

        On success ``_meta`` is the parsed ``OME`` dictionary extended with
        the flat keys ``size_<dim>``, ``channels``, ``name`` and ``type``.
        If anything fails, a warning is printed, the default dimension order
        is assumed, and ``_meta`` holds whatever flat keys were collected
        before the failure.
        """
        meta = dict()
        try:
            # ``path`` was an undefined name here; the file is ``self.path``.
            with TiffFile(str(self.path)) as f:
                self._meta = xmltodict.parse(f.ome_metadata)["OME"]
            pixels = self._meta["Image"]["Pixels"]
            for dim in self.dimorder:
                meta["size_" + dim.lower()] = int(pixels["@Size" + dim])
            meta["channels"] = [x["@Name"] for x in pixels["Channel"]]
            meta["name"] = self._meta["Image"]["@Name"]
            meta["type"] = pixels["@Type"]
            # Keep the raw OME tree (needed by ``date``/``dimorder``) and add
            # the flat keys (needed by ``name``/``rechunk_data``).
            self._meta.update(meta)
        except Exception as e:  # Images not in OMEXML
            print("Warning:Metadata not found: {}".format(e))
            print(
                f"Warning: No dimensional info provided. Assuming {self._default_dimorder}"
            )
            # Mark non-existent dimensions for padding
            self.base = self._default_dimorder
            # ``base`` was an undefined name here; use the attribute.
            self._dimorder = self.base
            self._meta = meta

    @property
    def name(self):
        """Image name taken from the metadata."""
        return self._meta["name"]

    @property
    def date(self):
        """Date parsed from the tag annotation whose description is "Date"."""
        date_str = [
            x
            for x in self._meta["StructuredAnnotations"]["TagAnnotation"]
            if x["Description"] == "Date"
        ][0]["Value"]
        return datetime.strptime(date_str, "%d-%b-%Y")

    @property
    def dimorder(self):
        """Order of dimensions in image"""
        if not hasattr(self, "_dimorder"):
            self._dimorder = self._meta["Image"]["Pixels"]["@DimensionOrder"]
        return self._dimorder

    @dimorder.setter
    def dimorder(self, order: str):
        self._dimorder = order

    def get_data_lazy(self) -> da.Array:
        """Return 5D dask array. For lazy-loading multidimensional tiff files"""
        # ``imread`` stacks matching files along a new first axis; take the
        # first (only) file. The array must not be wrapped in a tuple, which
        # ``rechunk_data`` cannot handle.
        img = imread(str(self.path))[0]
        if hasattr(self, "ids"):
            # Custom dimension order, we rearrange the axes for compatibility
            for i, d in enumerate(self._dimorder):
                self._meta["size_" + d.lower()] = img.shape[i]
            target_order = (
                *self.ids,
                *[
                    i
                    for i, d in enumerate(self.base)
                    if d not in self.dimorder
                ],
            )
            # Append singleton axes so the array always has five dimensions.
            reshaped = da.reshape(
                img,
                shape=(
                    *img.shape,
                    *[1 for _ in range(5 - len(self.dimorder))],
                ),
            )
            img = da.moveaxis(
                reshaped, range(len(reshaped.shape)), target_order
            )
        return self.rechunk_data(img)
class ImageDir(BaseLocalImage):
    """
    Image class for the case in which all images are split in one or
    multiple folders with time-points and channels as independent files.

    It inherits from BaseLocalImage so we only override methods that are critical.

    Assumptions:
    - One folder per position.
    - Images are flat.
    - Channel, Time, z-stack and the others are determined by filenames.
    - Provides Dimorder as it is set in the filenames, or expects order during instantiation
    """

    def __init__(self, path: t.Union[str, Path], **kwargs):
        super().__init__(path)
        self.image_id = str(self.path.stem)
        # Metadata is derived from the directory contents by ``dir_to_meta``.
        self._meta = dir_to_meta(self.path)

    def get_data_lazy(self) -> da.Array:
        """Return 5D dask array. For lazy-loading local multidimensional tiff files"""
        # Read every ``*.tiff`` file in the folder into one stacked array.
        img = imread(str(self.path / "*.tiff"))
        # If extra channels, pick the first stack of the last dimensions
        while len(img.shape) > 3:
            img = img[..., 0]
        # NOTE(review): the original indentation was lost; the extent of this
        # ``if`` body is reconstructed - confirm against the repository.
        if self._meta:
            self._meta["size_x"], self._meta["size_y"] = img.shape[-2:]
            # Reshape using metadata
            # img = da.reshape(img, (*self._meta, *img.shape[1:]))
            # presumably ``_meta`` holds only ``size_*`` entries here - verify
            img = da.reshape(img, self._meta.values())
            # The last character of each ``size_*`` key names the dimension.
            original_order = [
                i[-1] for i in self._meta.keys() if i.startswith("size")
            ]
            # Swap axis to conform with normal order
            target_order = [
                self._default_dimorder.index(x) for x in original_order
            ]
            img = da.moveaxis(
                img,
                list(range(len(original_order))),
                target_order,
            )
        pixels = self.rechunk_data(img)
        return pixels

    @property
    def name(self):
        """Return the stem of the image path."""
        return self.path.stem

    @property
    def dimorder(self):
        """Return the dimension names found in the ``size_*`` metadata keys."""
        # Assumes only dimensions start with "size"
        return [
            k.split("_")[-1] for k in self._meta.keys() if k.startswith("size")
        ]
class ImageZarr(BaseLocalImage):
    """
    Image stored as a zarr archive.

    Such archives are produced by the script
    skeletons/scripts/howto_omero/convert_clone_zarr_to_tiff.py
    """

    def __init__(self, path: t.Union[str, Path], **kwargs):
        super().__init__(path)
        self.set_meta()
        try:
            self._img = zarr.open(self.path)
            self.add_size_to_meta()
        except Exception as err:
            # Best effort: report the problem and carry on.
            print(f"Could not add size info to metadata: {err}")

    def get_data_lazy(self) -> da.Array:
        """Return the array opened from the zarr archive."""
        return self._img

    def add_size_to_meta(self):
        """Record the length of every axis in the metadata as ``size_<dim>``."""
        axis_sizes = {}
        for dim, length in zip(self.dimorder, self._img.shape):
            axis_sizes[f"size_{dim}"] = length
        self._meta.update(axis_sizes)

    @property
    def name(self):
        """Return the stem of the image path."""
        return self.path.stem

    @property
    def dimorder(self):
        """Return the fixed dimension order ``TCZYX``."""
        # FIXME hardcoded order based on zarr compression/cloning script
        return "TCZYX"
class ImageDummy(BaseLocalImage):
    """
    Stand-in Image class for tests.

    ImageDummy imitates the other Image classes closely enough to be
    accepted by Tiler. It exists to help expose silent failures: when
    something goes wrong, it should be possible to tell bad parameters from
    bad input data.

    The tiler parameters are assumed to be known before the instance is
    built, as is the case in a typical pipeline run.
    """

    def __init__(self, tiler_parameters: dict):
        """Build the image instance.

        Parameters
        ----------
        tiler_parameters : dict
            Tiler parameters, in dict form. Following
            aliby.tile.tiler.TilerParameters, the keys are: "tile_size"
            (size of tile), "ref_channel" (reference channel for tiling),
            and "ref_z" (reference z-stack, 0 to choose a default). Only
            "ref_channel" and "ref_z" are read here.
        """
        self.ref_channel = tiler_parameters["ref_channel"]
        self.ref_z = tiler_parameters["ref_z"]

    # Goal: make Tiler happy.
    @staticmethod
    def pad_array(
        image_array: da.Array,
        dim: int,
        n_empty_slices: int,
        image_position: int = 0,
    ):
        """Extend a dask array along one axis, padding with zeros.

        Zero blocks shaped like ``image_array`` are concatenated before and
        after it along ``dim``.

        Parameters
        ----------
        image_array : da.Array
            Input dask array
        dim : int
            Axis along which the array is extended.
        n_empty_slices : int
            Total number of zero blocks added along ``dim``.
        image_position : int
            Number of zero blocks placed before the input array, default 0
            (input comes first).

        Examples
        --------
        ```
        extended_array = pad_array(
            my_da_array, dim = 2, n_empty_slices = 4, image_position = 1)
        ```
        Extends `my_da_array` along the 3rd axis (axes count from 0) with 4
        zero blocks, one of them placed before the original content.
        """
        blank = da.zeros_like(image_array)
        leading = [blank] * image_position
        trailing = [blank] * (n_empty_slices - image_position)
        return da.concatenate(leading + [image_array] + trailing, axis=dim)

    # Logic: We want to return a image instance
    def get_data_lazy(self) -> da.Array:
        """Return 5D dask array built from the example image. Dummy image."""
        # TODO: Make this robust to having multiple TIFF images, one for each z-section,
        # all falling under the same "pypipeline_unit_test_00_000001_Brightfield_*.tif"
        # naming scheme. The aim is to create a multidimensional dask array that stores
        # the z-stacks.
        img_filename = "pypipeline_unit_test_00_000001_Brightfield_003.tif"
        img_path = get_examples_dir() / img_filename
        # TODO: Write a test to confirm the loaded array has three
        # dimensions with sizes 1, 1200 and 1200.
        img = imread(str(img_path))
        # Prepend t and c axes of length one.
        five_d_shape = (1, 1, img.shape[-3], img.shape[-2], img.shape[-1])
        img = da.reshape(img, five_d_shape)
        # Pad the t, c and z axes in turn.
        img = self.pad_array(img, dim=0, n_empty_slices=199)  # 200 timepoints
        img = self.pad_array(img, dim=1, n_empty_slices=2)  # 3 channels
        img = self.pad_array(
            img, dim=2, n_empty_slices=4, image_position=self.ref_z
        )  # 5 z-stacks
        return img

    @property
    def name(self):
        return None

    @property
    def dimorder(self):
        return None
"""
Tools to manage I/O using a remote OMERO server.
"""
import re
import typing as t
from abc import abstractmethod
from pathlib import Path
import dask.array as da
import numpy as np
import omero
from dask import delayed
from omero.gateway import BlitzGateway
from omero.model import enums as omero_enums
from yaml import safe_load
from agora.io.bridge import BridgeH5
# Lookup table converting OMERO pixel-type definitions into numpy types.
# Keys are the OMERO enum constants; values are the matching numpy scalar
# types.
PIXEL_TYPES = {
    omero_enums.PixelsTypeint8: np.int8,
    omero_enums.PixelsTypeuint8: np.uint8,
    omero_enums.PixelsTypeint16: np.int16,
    omero_enums.PixelsTypeuint16: np.uint16,
    omero_enums.PixelsTypeint32: np.int32,
    omero_enums.PixelsTypeuint32: np.uint32,
    omero_enums.PixelsTypefloat: np.float32,
    omero_enums.PixelsTypedouble: np.float64,
}
class BridgeOmero:
    """
    Core to interact with OMERO, using credentials or fetching them from h5 file (temporary trick).

    See
    https://docs.openmicroscopy.org/omero/5.6.0/developers/Python.html
    """

    def __init__(
        self,
        host: str = None,
        username: str = None,
        password: str = None,
        ome_id: int = None,
    ):
        """
        Parameters
        ----------
        host : string
            web address of OMERO host
        username: string
        password : string
        ome_id: Optional int
            Unique identifier on Omero database. Used to fetch specific objects.

        Raises
        ------
        AssertionError
            If any of host, username or password is missing.
        """
        # The password itself is deliberately left out of the message so it
        # cannot leak into logs or tracebacks.
        assert all((host, username, password)), (
            f"Invalid credentials. host: {host}, user: {username}, "
            f"pwd: {'<provided>' if password else '<missing>'}"
        )
        self.conn = None
        self.host = host
        self.username = username
        self.password = password
        self.ome_id = ome_id

    # standard method required for Python's with statement
    def __enter__(self):
        self.create_gate()
        return self

    @property
    def ome_class(self):
        """Return the OMERO object wrapper for ``ome_id``, fetched once.

        The object type ("Dataset" or "Image") is inferred from the name of
        the concrete class.
        """
        # Initialise Omero Object Wrapper for instances when applicable.
        if not hasattr(self, "_ome_class"):
            assert (
                self.conn.isConnected() and self.ome_id is not None
            ), "No Blitz connection or valid omero id"
            ome_type = [
                valid_name
                for valid_name in ("Dataset", "Image")
                if re.match(
                    f".*{ valid_name }.*",
                    self.__class__.__name__,
                    re.IGNORECASE,
                )
            ][0]
            self._ome_class = self.conn.getObject(ome_type, self.ome_id)
            assert self._ome_class, f"{ome_type} {self.ome_id} not found."
        return self._ome_class

    def create_gate(self) -> bool:
        """Open the connection to the server and report whether it is up."""
        self.conn = BlitzGateway(
            host=self.host, username=self.username, passwd=self.password
        )
        self.conn.connect()
        self.conn.c.enableKeepAlive(60)
        # Return the status so the ``-> bool`` annotation holds.
        return self.conn.isConnected()

    # standard method required for Python's with statement
    def __exit__(self, *exc) -> bool:
        for e in exc:
            if e is not None:
                print(e)
        # ``conn`` stays None until ``create_gate`` has run.
        if self.conn is not None:
            self.conn.close()
        return False

    @classmethod
    def server_info_from_h5(
        cls,
        filepath: t.Union[str, Path],
    ):
        """Return server info from hdf5 file.

        Parameters
        ----------
        cls : BridgeOmero
            BridgeOmero class
        filepath : t.Union[str, Path]
            Location of hdf5 file.

        Returns
        -------
        dict
            The "host", "username" and "password" entries found under
            ``parameters -> general`` in the file's metadata.
        """
        bridge = BridgeH5(filepath)
        meta = safe_load(bridge.meta_h5["parameters"])["general"]
        server_info = {k: meta[k] for k in ("host", "username", "password")}
        return server_info

    def set_id(self, ome_id: int):
        """Set the identifier of the OMERO object to work on."""
        self.ome_id = ome_id

    @property
    def file_annotations(self):
        """File names of all annotations that expose ``getFileName``."""
        valid_annotations = [
            ann.getFileName()
            for ann in self.ome_class.listAnnotations()
            if hasattr(ann, "getFileName")
        ]
        return valid_annotations

    def add_file_as_annotation(
        self, file_to_upload: t.Union[str, Path], **kwargs
    ):
        """Upload annotation to object on OMERO server. Only valid in subclasses.

        Parameters
        ----------
        file_to_upload: File to upload
        **kwargs: Additional keyword arguments passed on
            to BlitzGateway.createFileAnnfromLocalFile
        """
        file_annotation = self.conn.createFileAnnfromLocalFile(
            file_to_upload,
            mimetype="text/plain",
            **kwargs,
        )
        self.ome_class.linkAnnotation(file_annotation)
class Dataset(BridgeOmero):
    """
    Tool to interact with Omero Datasets remotely, access their
    metadata and associated files and images.

    Parameters
    ----------
    expt_id: int Dataset id on server
    server_info: dict host, username and password
    """

    def __init__(self, expt_id: int, **server_info):
        super().__init__(ome_id=expt_id, **server_info)

    @property
    def name(self):
        """Name of the dataset on the server."""
        return self.ome_class.getName()

    @property
    def date(self):
        """Date of the dataset on the server."""
        return self.ome_class.getDate()

    @property
    def unique_name(self):
        """Identifier built from the id, the date and the name."""
        return "_".join(
            (
                str(self.ome_id),
                self.date.strftime("%Y_%m_%d").replace("/", "_"),
                self.name,
            )
        )

    def get_images(self):
        """Map the name of every child image to its id."""
        return {
            im.getName(): im.getId() for im in self.ome_class.listChildren()
        }

    @property
    def files(self):
        """File annotations keyed by file name, fetched once.

        Raises
        ------
        Exception
            If there are no file annotations, or if their number differs
            from the number of entries in ``file_annotations``.
        """
        if not hasattr(self, "_files"):
            self._files = {
                x.getFileName(): x
                for x in self.ome_class.listAnnotations()
                if isinstance(x, omero.gateway.FileAnnotationWrapper)
            }
        if not len(self._files):
            raise Exception(
                "exception:metadata: experiment has no annotation files."
            )
        elif len(self.file_annotations) != len(self._files):
            raise Exception("Number of files and annotations do not match")
        return self._files

    @property
    def tags(self):
        """Tag annotations, fetched once."""
        # ``_tags`` is never initialised elsewhere in this class, so read it
        # defensively instead of raising AttributeError on first access.
        if getattr(self, "_tags", None) is None:
            # NOTE(review): ``getname`` is lower-case here while every other
            # accessor uses ``getName`` - confirm it resolves on the wrapper.
            self._tags = {
                x.getname(): x
                for x in self.ome_class.listAnnotations()
                if isinstance(x, omero.gateway.TagAnnotationWrapper)
            }
        return self._tags

    def cache_logs(self, root_dir):
        """Download text annotations into ``root_dir``; existing files are kept.

        Returns True.
        """
        valid_suffixes = ("txt", "log")
        for _, annotation in self.files.items():
            filepath = root_dir / annotation.getFileName().replace("/", "_")
            if (
                str(filepath).endswith(valid_suffixes)
                and not filepath.exists()
            ):
                # save only the text files
                with open(str(filepath), "wb") as fd:
                    for chunk in annotation.getFileInChunks():
                        fd.write(chunk)
        return True

    @classmethod
    def from_h5(
        cls,
        filepath: t.Union[str, Path],
    ):
        """Instantiate Dataset from a hdf5 file.

        Parameters
        ----------
        cls : Dataset
            Dataset class
        filepath : t.Union[str, Path]
            Location of hdf5 file.

        Returns
        -------
        A Dataset built from the first identifier key found in the file's
        metadata, or None when no such key is present.
        """
        bridge = BridgeH5(filepath)
        # Key names under which the identifier may be stored.
        dataset_keys = ("omero_id", "omero_id,", "dataset_id")
        for k in dataset_keys:
            if k in bridge.meta_h5:
                return cls(
                    bridge.meta_h5[k], **cls.server_info_from_h5(filepath)
                )
class Image(BridgeOmero):
    """
    Access the pixel data and metadata of an image held on OMERO.
    """

    def __init__(self, image_id: int, **server_info):
        """
        Store the image id and the server credentials through the
        BridgeOmero base class.

        Parameters
        ----------
        image_id: integer
        server_info: dictionary
            Specifies the host, username, and password as strings
        """
        super().__init__(ome_id=image_id, **server_info)

    @classmethod
    def from_h5(
        cls,
        filepath: t.Union[str, Path],
    ):
        """Build an Image from the identifiers stored in a hdf5 file.

        Parameters
        ----------
        cls : Image
            Image class
        filepath : t.Union[str, Path]
            Location of hdf5 file.
        """
        h5_bridge = BridgeH5(filepath)
        stored_id = h5_bridge.meta_h5["image_id"]
        credentials = cls.server_info_from_h5(filepath)
        return cls(stored_id, **credentials)

    @property
    def name(self):
        """Name of the image on the server."""
        return self.ome_class.getName()

    @property
    def data(self):
        """Lazily loaded pixel data of the image."""
        return get_data_lazy(self.ome_class)

    @property
    def metadata(self):
        """
        Collect the metadata saved in OMERO: image size along each axis,
        labels of channels, and image name.
        """
        wrapper = self.ome_class
        return {
            "size_x": wrapper.getSizeX(),
            "size_y": self.ome_class.getSizeY(),
            "size_z": self.ome_class.getSizeZ(),
            "size_c": self.ome_class.getSizeC(),
            "size_t": self.ome_class.getSizeT(),
            "channels": self.ome_class.getChannelLabels(),
            "name": self.ome_class.getName(),
        }
class UnsafeImage(Image):
    """
    Image variant that opens its server connection on construction.

    This class is a temporary solution while we find a way to use
    context managers inside napari. It risks resulting in zombie connections
    and producing freezes in an OMERO server.
    """

    def __init__(self, image_id, **server_info):
        """
        Store the credentials through the parent class, then connect at once.

        Parameters
        ----------
        image_id: integer
        server_info: dictionary
            Specifies the host, username, and password as strings
        """
        super().__init__(image_id, **server_info)
        self.create_gate()

    @property
    def data(self):
        """Lazily loaded pixel data; on failure, report and reconnect."""
        try:
            pixels = get_data_lazy(self.ome_class)
        except Exception as err:
            print(f"ERROR: Failed fetching image from server: {err}")
            self.conn.connect(False)
        else:
            return pixels
def get_data_lazy(image) -> da.Array:
    """
    Build a 5D dask array whose planes are read from the OMERO image on demand.
    """
    nt, nc, nz, ny, nx = (
        getattr(image, f"getSize{axis}")() for axis in "TCZYX"
    )
    pixels = image.getPrimaryPixels()
    dtype = PIXEL_TYPES.get(pixels.getPixelsType().value, None)
    # Plane reads are wrapped in ``delayed`` so nothing is fetched yet.
    get_plane = delayed(lambda idx: pixels.getPlane(*idx))

    def get_lazy_plane(zct):
        return da.from_delayed(get_plane(zct), shape=(ny, nx), dtype=dtype)

    # Stack planes into z-stacks, z-stacks into channels, channels into
    # time points; planes are requested with a (z, c, t) index.
    return da.stack(
        [
            da.stack(
                [
                    da.stack(
                        [get_lazy_plane((z, c, t)) for z in range(nz)]
                    )
                    for c in range(nc)
                ]
            )
            for t in range(nt)
        ]
    )
"""
Models that link regions of interest, such as mothers and buds.
"""
"""
Extracted from the baby repository. Bud Tracker algorithm to link
cell outlines as mothers and buds.
"""
# /usr/bin/env jupyter
import pickle
from os.path import dirname, join
import numpy as np
from skimage.draw import polygon
from agora.track_abc import FeatureCalculator
# Directory holding the pickled tracker models, located next to this module.
models_path = join(dirname(__file__), "./models")
class BudTracker(FeatureCalculator):
def __init__(self, model=None, feats2use=None, **kwargs):
if model is None:
model_file = join(models_path, "mb_model_20201022.pkl")
with open(model_file, "rb") as file_to_load:
model = pickle.load(file_to_load)
self.model = model
if feats2use is None:
feats2use = ["centroid", "area", "minor_axis_length"]
super().__init__(feats2use, **kwargs)
self.a_ind = self.outfeats.index("area")
self.ma_ind = self.outfeats.index("minor_axis_length")
### Assign mother-
def calc_mother_bud_stats(self, p_budneck, p_bud, masks, feats=None):
"""
---
input
:p_budneck: 2d ndarray (size_x, size_y) giving the probability that a
pixel corresponds to a bud neck
:p_bud: 2d ndarray (size_x, size_y) giving the probability that a pixel
corresponds to a bud
:masks: 3d ndarray (ncells, size_x, size_y)
:feats: ndarray (ncells, nfeats)
NB: ASSUMES FEATS HAVE ALREADY BEEN NORMALISED!
returns
:n2darray: 2d ndarray (ncells x ncells, n_feats) specifying,
for each pair of cells in the masks array, the features used for
mother-bud pair prediction (as per 'feats2use')
"""
if feats is None:
feats = self.calc_feats_from_mask(masks)
elif len(feats) != len(masks):
raise Exception("number of features must match number of masks")
ncells = len(masks)
# Entries will be NaN unless validly specified below
p_bud_mat = np.nan * np.ones((ncells, ncells))
p_budneck_mat = np.nan * np.ones((ncells, ncells))
budneck_ratio_mat = np.nan * np.ones((ncells, ncells))
size_ratio_mat = np.nan * np.ones((ncells, ncells))
adjacency_mat = np.nan * np.ones((ncells, ncells))
for m in range(ncells):
for d in range(ncells):
if m == d:
# Mother-bud pairs can only be between different cells
continue
p_bud_mat[m, d] = np.mean(p_bud[masks[d].astype("bool")])
a_i = self.a_ind
size_ratio_mat[m, d] = feats[m, a_i] / feats[d, a_i]
# Draw rectangle
r_points = self.get_rpoints(feats, d, m)
if r_points is None:
continue
rr, cc = polygon(
r_points[0, :], r_points[1, :], p_budneck.shape
)
if len(rr) == 0:
# Rectangles with zero size are not informative
continue
r_im = np.zeros(p_budneck.shape, dtype="bool")
r_im[rr, cc] = True
# Calculate the mean of bud neck probabilities greater than some threshold
pbn = p_budneck[r_im].flatten()
pbn = pbn[pbn > 0.2]
p_budneck_mat[m, d] = np.mean(pbn) if len(pbn) > 0 else 0
# Normalise number of bud-neck positive pixels by the scale of
# the bud (a value proportional to circumference):
raw_circumf_est = np.sqrt(feats[d, a_i]) / self.pixel_size
budneck_ratio_mat[m, d] = pbn.sum() / raw_circumf_est
# Adjacency is the proportion of the joining rectangle that overlaps the mother daughter union
md_union = masks[m] | masks[d]
adjacency_mat[m, d] = np.sum(md_union & r_im) / np.sum(r_im)
return np.hstack(
[
s.flatten()[:, np.newaxis]
for s in (
p_bud_mat,
size_ratio_mat,
p_budneck_mat,
budneck_ratio_mat,
adjacency_mat,
)
]
)
def predict_mother_bud(self, p_budneck, p_bud, masks, feats=None):
"""
---
input
:p_budneck: 2d ndarray (size_x, size_y) giving the probability that a
pixel corresponds to a bud neck
:p_bud: 2d ndarray (size_x, size_y) giving the probability that a pixel
corresponds to a bud
:masks: 3d ndarray (ncells, size_x, size_y)
:feats: ndarray (ncells, nfeats)
returns
:n2darray: 2d ndarray (ncells, ncells) giving probability that a cell
(row) is a mother to another cell (column)
"""
ncells = len(masks)
mb_stats = self.calc_mother_bud_stats(p_budneck, p_bud, masks, feats)
good_stats = ~np.isnan(mb_stats).any(axis=1)
# Assume probability of bud assignment for any rows that are NaN will
# be zero
ba_probs = np.zeros(ncells**2)
if good_stats.any():
ba_probs[good_stats] = self.model.predict_proba(
mb_stats[good_stats, :]
)[:, 1]
ba_probs = ba_probs.reshape((ncells,) * 2)
return ba_probs
def get_rpoints(self, feats, d, m):
    """
    Find the corners of a rectangle joining a mother and a daughter cell.

    NB: ASSUMES FEATS HAVE ALREADY BEEN NORMALISED!

    Parameters
    ----------
    feats: 2d ndarray (ncells, nfeats)
        The first two columns are used as the centre of each cell and
        column ``self.ma_ind`` sets the width of the rectangle.
    d: int
        Row of the daughter cell in feats.
    m: int
        Row of the mother cell in feats.

    Returns
    -------
    r_points: 2d ndarray (2, 4) or None
        Coordinates of the rectangle's corners, or None if the two
        centres coincide.
    """
    # undo the scaling by pixel size
    scaled = feats / self.pixel_size
    mother_xy = scaled[m, :2]
    daughter_xy = scaled[d, :2]
    # rectangle is at least two units wide
    width = np.max((2, scaled[d, self.ma_ind] * 0.25))
    # vector along the rectangle, from mother to daughter
    along = daughter_xy - mother_xy
    # rotate by 90 degrees to get the direction across the rectangle
    quarter_turn = np.array([[0, -1], [1, 0]])
    across = np.matmul(quarter_turn, along)
    across_norm = np.linalg.norm(across)
    if across_norm == 0:
        return None
    across = width * across / across_norm
    corners = np.zeros((2, 4))
    corners[:, 0] = mother_xy - 0.5 * across
    corners[:, 1] = corners[:, 0] + along
    corners[:, 2] = corners[:, 1] + across
    corners[:, 3] = corners[:, 2] - along
    return corners
"""Set up and run pipelines: tiling, segmentation, extraction, and then post-processing."""
import logging
import os
import re
import traceback
import typing as t
from copy import copy
from importlib.metadata import version
from pathlib import Path
import h5py
import numpy as np
import pandas as pd
from pathos.multiprocessing import Pool
from tqdm import tqdm
from agora.abc import ParametersABC, ProcessABC
from agora.io.metadata import MetaData, parse_logfiles
from agora.io.reader import StateReader
from agora.io.signal import Signal
from agora.io.writer import (
LinearBabyWriter,
StateWriter,
TilerWriter,
)
from aliby.baby_client import BabyParameters, BabyRunner
from aliby.haystack import initialise_tf
from aliby.io.dataset import dispatch_dataset
from aliby.io.image import dispatch_image
from aliby.tile.tiler import Tiler, TilerParameters
from extraction.core.extractor import Extractor, ExtractorParameters
from extraction.core.functions.defaults import exparams_from_meta
from postprocessor.core.processor import PostProcessor, PostProcessorParameters
class PipelineParameters(ParametersABC):
    """Define parameters for the steps of the pipeline."""

    _pool_index = None

    def __init__(
        self, general, tiler, baby, extraction, postprocessing, reporting
    ):
        """Initialise, but called by a class method - not directly."""
        self.general = general
        self.tiler = tiler
        self.baby = baby
        self.extraction = extraction
        self.postprocessing = postprocessing
        self.reporting = reporting

    @classmethod
    def default(
        cls,
        general=None,
        tiler=None,
        baby=None,
        extraction=None,
        postprocessing=None,
    ):
        """
        Initialise parameters for steps of the pipeline.

        Some parameters are extracted from the log files.

        Parameters
        ---------
        general: dict
            Parameters to set up the pipeline. Must contain "directory".
        tiler: dict (optional)
            Parameters for tiler.
        baby: dict (optional)
            Parameters for Baby.
        extraction: dict (optional)
            Parameters for extraction.
        postprocessing: dict (optional)
            Parameters for post-processing.

        Returns
        -------
        An instance of PipelineParameters.

        Raises
        ------
        KeyError
            If general has no "directory" entry.
        """
        # work on a copy: general is modified below and neither a shared
        # default nor the caller's dictionary should be changed
        general = dict(general) if general else {}
        tiler = tiler or {}
        baby = baby or {}
        extraction = extraction or {}
        postprocessing = postprocessing or {}
        expt_id = general.get("expt_id", 19993)
        if isinstance(expt_id, Path):
            assert expt_id.exists()
            expt_id = str(expt_id)
            general["expt_id"] = expt_id
        directory = Path(general["directory"])
        # get log files, either locally or via OMERO
        with dispatch_dataset(
            expt_id,
            **{k: general.get(k) for k in ("host", "username", "password")},
        ) as conn:
            directory = directory / conn.unique_name
            directory.mkdir(parents=True, exist_ok=True)
            # download logs for metadata
            conn.cache_logs(directory)
        try:
            meta_d = MetaData(directory, None).load_logs()
        except Exception as e:
            # best effort: carry on with minimal metadata if logs are unusable
            logging.getLogger("aliby").warning(
                f"WARNING:Metadata: error when loading: {e}"
            )
            minimal_default_meta = {
                "channels": ["Brightfield"],
                "ntps": [2000],
            }
            # set minimal metadata
            meta_d = minimal_default_meta
        # define default values for general parameters
        # NOTE(review): the minimal metadata stores ntps as a list while the
        # fallback here is an int - confirm which one consumers expect
        tps = meta_d.get("ntps", 2000)
        defaults = {
            "general": dict(
                id=expt_id,
                distributed=0,
                tps=tps,
                directory=str(directory.parent),
                filter="",
                earlystop=dict(
                    min_tp=100,
                    thresh_pos_clogged=0.4,
                    thresh_trap_ncells=8,
                    thresh_trap_area=0.9,
                    ntps_to_eval=5,
                ),
                logfile_level="INFO",
                use_explog=True,
            )
        }
        # update default values using inputs
        for k, v in general.items():
            if k not in defaults["general"]:
                defaults["general"][k] = v
            elif isinstance(v, dict):
                # merge nested dictionaries, such as earlystop, key by key
                for k2, v2 in v.items():
                    defaults["general"][k][k2] = v2
            else:
                defaults["general"][k] = v
        # define defaults and update with any inputs
        defaults["tiler"] = TilerParameters.default(**tiler).to_dict()
        # Generate a backup channel, for when logfile meta is available
        # but not image metadata.
        backup_ref_channel = None
        if "channels" in meta_d and isinstance(
            defaults["tiler"]["ref_channel"], str
        ):
            backup_ref_channel = meta_d["channels"].index(
                defaults["tiler"]["ref_channel"]
            )
        defaults["tiler"]["backup_ref_channel"] = backup_ref_channel
        defaults["baby"] = BabyParameters.default(**baby).to_dict()
        # NOTE(review): the fallback builds *Baby* parameters from the
        # extraction inputs; ExtractorParameters looks like the intended
        # class - confirm before changing
        defaults["extraction"] = (
            exparams_from_meta(meta_d)
            or BabyParameters.default(**extraction).to_dict()
        )
        defaults["postprocessing"] = PostProcessorParameters.default(
            **postprocessing
        ).to_dict()
        defaults["reporting"] = {}
        return cls(**defaults)

    def load_logs(self):
        """Parse the log files found in self.log_dir."""
        # NOTE(review): log_dir is not set anywhere in this class - confirm
        # where it is expected to come from
        parsed_flattened = parse_logfiles(self.log_dir)
        return parsed_flattened
class Pipeline(ProcessABC):
    """
    Initialise and run tiling, segmentation, extraction and post-processing.

    Each step feeds the next one.

    To customise parameters for any step use the PipelineParameters class.
    """

    # steps run in this order at every time point by run_one_position
    pipeline_steps = ["tiler", "baby", "extraction"]
    # all steps in order; used to decide which steps to overwrite and
    # which writers to load
    step_sequence = [
        "tiler",
        "baby",
        "extraction",
        "postprocessing",
    ]
    # Specify the group in the h5 files written by each step
    writer_groups = {
        "tiler": ["trap_info"],
        "baby": ["cell_info"],
        "extraction": ["extraction"],
        "postprocessing": ["postprocessing", "modifiers"],
    }
    # for each step, pairs of (writer name, writer class); each class is
    # instantiated with the h5 filename in run_one_position
    writers = {  # TODO integrate Extractor and PostProcessing in here
        "tiler": [("tiler", TilerWriter)],
        "baby": [("baby", LinearBabyWriter), ("state", StateWriter)],
    }
def __init__(self, parameters: PipelineParameters, store=None):
    """
    Initialise - not usually called directly.

    Parameters
    ----------
    parameters: PipelineParameters
        Parameters for all steps of the pipeline.
    store: str or Path (optional)
        If given, stored as a Path in self.store; run() uses it in
        place of the default output directory.
    """
    super().__init__(parameters)
    self.store = Path(store) if store is not None else store
@staticmethod
def setLogger(
    folder, file_level: str = "INFO", stream_level: str = "WARNING"
):
    """
    Initialise and format the "aliby" logger.

    A stream handler and a file handler writing to aliby.log in folder
    are added to the logger every time this function is called.

    Parameters
    ----------
    folder: str or Path
        Directory in which aliby.log is created (opened with mode "w+").
    file_level: str
        Name of the logging level for the logger and the file handler.
    stream_level: str
        Name of the logging level for the stream handler.
    """
    aliby_logger = logging.getLogger("aliby")
    aliby_logger.setLevel(getattr(logging, file_level))
    line_format = logging.Formatter(
        "%(asctime)s - %(levelname)s:%(message)s",
        datefmt="%Y-%m-%dT%H:%M:%S%z",
    )

    def _attach(handler, level_name):
        # set level and format of a handler and register it
        handler.setLevel(getattr(logging, level_name))
        handler.setFormatter(line_format)
        aliby_logger.addHandler(handler)

    # for streams - stdout, files, etc.
    _attach(logging.StreamHandler(), stream_level)
    # file handler, which by default also records messages below WARNING
    _attach(
        logging.FileHandler(Path(folder) / "aliby.log", "w+"), file_level
    )
@classmethod
def from_yaml(cls, fpath):
    """
    Instantiate using parameters read by PipelineParameters.from_yaml.

    Parameters
    ----------
    fpath:
        Passed unchanged to PipelineParameters.from_yaml; presumably a
        path to a YAML file - confirm against ParametersABC.
    """
    # This is just a convenience function, think before implementing
    # for other processes
    return cls(parameters=PipelineParameters.from_yaml(fpath))
@classmethod
def from_folder(cls, dir_path):
    """
    Re-process all h5 files in a given folder.

    All files must share the same parameters, even if they have
    different channels: the parameters are read from the first file
    found and applied to all.

    Parameters
    ---------
    dir_path : str or Pathlib
        Folder containing the files, which is searched recursively.
    """
    folder = Path(dir_path)
    # find h5 files
    h5_files = list(folder.rglob("*.h5"))
    assert len(h5_files), "No valid files found in folder"
    # TODO add support for non-standard unique folder names
    with h5py.File(h5_files[0], "r") as h5_file:
        params = PipelineParameters.from_yaml(h5_file.attrs["parameters"])
    params.general["directory"] = folder.parent
    # restrict the run to the positions that already have a file
    params.general["filter"] = [h5_path.stem for h5_path in h5_files]
    # fix legacy post-processing parameters
    legacy_params = params.postprocessing.get("parameters", None)
    if legacy_params:
        params.postprocessing["param_sets"] = copy(legacy_params)
        del params.postprocessing["parameters"]
    return cls(params)
@classmethod
def from_existing_h5(cls, fpath):
    """
    Re-process an existing h5 file.

    Not suitable for more than one file.

    Parameters
    ---------
    fpath: str
        Name of file. Its parent folder becomes both the output
        directory and the store of the new Pipeline.
    """
    with h5py.File(fpath, "r") as h5_file:
        params = PipelineParameters.from_yaml(h5_file.attrs["parameters"])
    h5_path = Path(fpath)
    store_dir = h5_path.parent
    params.general["directory"] = store_dir
    # process only the position corresponding to this file
    params.general["filter"] = h5_path.stem
    # rename legacy post-processing parameters
    legacy_params = params.postprocessing.get("parameters", None)
    if legacy_params:
        params.postprocessing["param_sets"] = copy(legacy_params)
        del params.postprocessing["parameters"]
    return cls(params, store=store_dir)
@property
def _logger(self):
    """Return the logger named "aliby"."""
    return logging.getLogger("aliby")
def run(self):
    """
    Run separate pipelines for all positions in an experiment.

    Returns
    -------
    results: list
        The value returned by run_one_position for each position.
    """
    # general information in config
    config = self.parameters.to_dict()
    general = config["general"]
    expt_id = general["id"]
    distributed = general["distributed"]
    pos_filter = general["filter"]
    root_dir = Path(general["directory"])
    self.server_info = {
        key: general.get(key) for key in ("host", "username", "password")
    }
    dispatcher = dispatch_dataset(expt_id, **self.server_info)
    logging.getLogger("aliby").info(
        f"Fetching data using {dispatcher.__class__.__name__}"
    )
    # get log files, either locally or via OMERO
    with dispatcher as conn:
        image_ids = conn.get_images()
        directory = self.store or root_dir / conn.unique_name
        if not directory.exists():
            directory.mkdir(parents=True)
        # download logs to use for metadata
        conn.cache_logs(directory)
    # update configuration
    self.parameters.general["directory"] = str(directory)
    general["directory"] = directory
    self.setLogger(directory)
    # pick particular images if desired
    if pos_filter is not None:
        if isinstance(pos_filter, list):
            # union of the images selected by each filter
            selected = {}
            for filt in pos_filter:
                selected.update(self.apply_filter(image_ids, filt))
            image_ids = selected
        else:
            image_ids = self.apply_filter(image_ids, pos_filter)
    assert len(image_ids), "No images to segment"
    # create pipelines
    if distributed != 0:
        # multiple cores
        jobs = [
            (name_and_id, pos_index)
            for pos_index, name_and_id in enumerate(image_ids.items())
        ]
        with Pool(distributed) as p:
            results = p.map(lambda x: self.run_one_position(*x), jobs)
    else:
        # single core
        results = [
            self.run_one_position((name, image_id), 1)
            for name, image_id in tqdm(image_ids.items())
        ]
    return results
def apply_filter(self, image_ids: dict, filt: t.Union[int, str]):
    """
    Select images by position or by a regular expression on their names.

    Parameters
    ----------
    image_ids: dict
        Maps the name of each image to its identifier.
    filt: int or str
        If a string, a regular expression; images whose name it matches
        anywhere (re.search) are kept. If an integer, the zero-based
        position of the single image to keep.

    Returns
    -------
    dict
        The selected images, which is empty if nothing matches. For any
        other type of filt, image_ids is returned unchanged.
    """
    if isinstance(filt, str):
        # pick images using a regular expression, compiled only once
        pattern = re.compile(filt)
        return {k: v for k, v in image_ids.items() if pattern.search(k)}
    if isinstance(filt, int):
        # pick the filt'th image
        return {
            k: v for i, (k, v) in enumerate(image_ids.items()) if i == filt
        }
    return image_ids
def run_one_position(
    self,
    name_image_id: t.Tuple[str, t.Union[str, Path, int]],
    index: t.Optional[int] = None,
):
    """
    Set up and run a pipeline for one position.

    For every time point, the steps in pipeline_steps are run in order
    and their results are written to the position's h5 file. After the
    loop over time points, post-processing is run on that file.

    Parameters
    ----------
    name_image_id: tuple
        The name of the position and the identifier of its image.
    index: int (optional)
        Stored as self._pool_index.

    Returns
    -------
    1 if the position is processed without raising an exception.

    Raises
    ------
    Exception
        Any exception raised during set up or processing is logged,
        its traceback is printed, and it is raised again.
    """
    self._pool_index = index
    name, image_id = name_image_id
    # session and filename are defined by calling setup_pipeline.
    # can they be deleted here?
    # (session is read in the finally clause below, so it has to be bound
    # before the try in case _setup_pipeline raises)
    session = None
    filename = None
    #
    # keyword arguments passed to run_tp of each step
    run_kwargs = {"extraction": {"labels": None, "masks": None}}
    try:
        (
            filename,
            meta,
            config,
            process_from,
            tps,
            steps,
            earlystop,
            session,
            trackers_state,
        ) = self._setup_pipeline(image_id)
        # instantiate the writers of every step that has any
        loaded_writers = {
            name: writer(filename)
            for k in self.step_sequence
            if k in self.writers
            for name, writer in self.writers[k]
        }
        # what each writer is asked to overwrite
        writer_ow_kwargs = {
            "state": loaded_writers["state"].datatypes.keys(),
            "baby": ["mother_assign"],
        }
        # START PIPELINE
        frac_clogged_traps = 0.0
        # earliest time point from which any step must be run
        min_process_from = min(process_from.values())
        with dispatch_image(image_id)(
            image_id, **self.server_info
        ) as image:
            # initialise steps
            if "tiler" not in steps:
                steps["tiler"] = Tiler.from_image(
                    image, TilerParameters.from_dict(config["tiler"])
                )
            # only set up segmentation if there are time points left for it
            if process_from["baby"] < tps:
                session = initialise_tf(2)
                steps["baby"] = BabyRunner.from_tiler(
                    BabyParameters.from_dict(config["baby"]),
                    steps["tiler"],
                )
                if trackers_state:
                    steps["baby"].crawler.tracker_states = trackers_state
            # limit extraction parameters using the available channels in tiler
            if process_from["extraction"] < tps:
                # TODO Move this parameter validation into Extractor
                av_channels = set((*steps["tiler"].channels, "general"))
                config["extraction"]["tree"] = {
                    k: v
                    for k, v in config["extraction"]["tree"].items()
                    if k in av_channels
                }
                config["extraction"]["sub_bg"] = av_channels.intersection(
                    config["extraction"]["sub_bg"]
                )
                av_channels_wsub = av_channels.union(
                    [c + "_bgsub" for c in config["extraction"]["sub_bg"]]
                )
                # iterate over a copy because entries are deleted from
                # the original inside the loop
                tmp = copy(config["extraction"]["multichannel_ops"])
                for op, (input_ch, _, _) in tmp.items():
                    if not set(input_ch).issubset(av_channels_wsub):
                        del config["extraction"]["multichannel_ops"][op]
                exparams = ExtractorParameters.from_dict(
                    config["extraction"]
                )
                steps["extraction"] = Extractor.from_tiler(
                    exparams, store=filename, tiler=steps["tiler"]
                )
                # set up progress meter
                pbar = tqdm(
                    range(min_process_from, tps),
                    desc=image.name,
                    initial=min_process_from,
                    total=tps,
                )
                for i in pbar:
                    # carry on while few traps are clogged or while fewer
                    # than min_tp time points have been processed
                    if (
                        frac_clogged_traps
                        < earlystop["thresh_pos_clogged"]
                        or i < earlystop["min_tp"]
                    ):
                        # run through steps
                        for step in self.pipeline_steps:
                            if i >= process_from[step]:
                                result = steps[step].run_tp(
                                    i, **run_kwargs.get(step, {})
                                )
                                if step in loaded_writers:
                                    loaded_writers[step].write(
                                        data=result,
                                        overwrite=writer_ow_kwargs.get(
                                            step, []
                                        ),
                                        tp=i,
                                        meta={"last_processed": i},
                                    )
                                # perform step
                                if (
                                    step == "tiler"
                                    and i == min_process_from
                                ):
                                    logging.getLogger("aliby").info(
                                        f"Found {steps['tiler'].n_tiles} traps in {image.name}"
                                    )
                                elif step == "baby":
                                    # write state and pass info to Extractor
                                    loaded_writers["state"].write(
                                        data=steps[
                                            step
                                        ].crawler.tracker_states,
                                        overwrite=loaded_writers[
                                            "state"
                                        ].datatypes.keys(),
                                        tp=i,
                                    )
                                elif step == "extraction":
                                    # remove mask/label after extraction
                                    for k in ["masks", "labels"]:
                                        run_kwargs[step][k] = None
                        # check and report clogging
                        frac_clogged_traps = self.check_earlystop(
                            filename, earlystop, steps["tiler"].tile_size
                        )
                        if frac_clogged_traps > 0.3:
                            self._log(
                                f"{name}:Clogged_traps:{frac_clogged_traps}"
                            )
                            frac = np.round(frac_clogged_traps * 100)
                            pbar.set_postfix_str(f"{frac} Clogged")
                    else:
                        # stop if too many traps are clogged
                        self._log(
                            f"{name}:Stopped early at time {i} with {frac_clogged_traps} clogged traps"
                        )
                        meta.add_fields({"end_status": "Clogged"})
                        break
                    meta.add_fields({"last_processed": i})
                # run post-processing
                # NOTE(review): this is also reached after the early-stop
                # break above, so an end_status of "Clogged" appears to be
                # replaced by "Success" - confirm this is intended
                meta.add_fields({"end_status": "Success"})
                post_proc_params = PostProcessorParameters.from_dict(
                    config["postprocessing"]
                )
                PostProcessor(filename, post_proc_params).run()
                self._log("Analysis finished successfully.", "info")
                return 1
    except Exception as e:
        # catch bugs during setup or run time
        # NOTE(review): logging.exception uses the root logger, not the
        # "aliby" logger configured by setLogger
        logging.exception(
            f"{name}: Exception caught.",
            exc_info=True,
        )
        # print the type, value, and stack trace of the exception
        traceback.print_exc()
        raise e
    finally:
        # always close the session, whether or not processing succeeded
        _close_session(session)
@staticmethod
def check_earlystop(filename: str, es_parameters: dict, tile_size: int):
    """
    Check recent time points for tiles with too many cells.

    Returns the fraction of clogged tiles, where clogged tiles have
    both too many cells and too much of their area covered by cells.

    Parameters
    ----------
    filename: str
        Name of h5 file.
    es_parameters: dict
        Parameters defining when early stopping should happen.
        For example:
                {'min_tp': 100,
                'thresh_pos_clogged': 0.4,
                'thresh_trap_ncells': 8,
                'thresh_trap_area': 0.9,
                'ntps_to_eval': 5}
    tile_size: int
        Size of tile.
    """
    # get the area of the cells organised by trap and cell number
    areas = Signal(filename).get_raw("/extraction/general/None/area")
    # check the latest time points only; the slice stops before the
    # final column
    window = areas.columns[-1 - es_parameters["ntps_to_eval"] : -1]
    recent = areas[window].dropna(how="all")
    per_trap = recent.groupby("trap")
    # find tiles with too many cells, averaged over the window
    mean_ncells = per_trap.count().apply(np.mean, axis=1)
    too_many_cells = mean_ncells > es_parameters["thresh_trap_ncells"]
    # find tiles with cells covering too great a fraction of the tiles' area
    mean_area_fraction = (
        per_trap.sum().apply(np.mean, axis=1) / tile_size**2
    )
    too_much_area = mean_area_fraction > es_parameters["thresh_trap_area"]
    return (too_many_cells & too_much_area).mean()
# FIXME: Remove this functionality. It used to be for
# older hdf5 file formats.
def _load_config_from_file(
    self,
    filename: Path,
    process_from: t.Dict[str, int],
    trackers_state: t.List,
    overwrite: t.Dict[str, bool],
):
    """
    Find the time point from which each step should restart.

    For every step that is not to be overwritten, process_from is set
    to one more than the last time point found in the h5 file.

    Parameters
    ----------
    filename: Path
        An existing h5 file.
    process_from: dict
        Maps each step to the time point to start from; updated in place.
    trackers_state: list
        Returned unchanged.
    overwrite: dict
        Maps each step to whether its results should be overwritten;
        returned unchanged.

    Returns
    -------
    process_from, trackers_state, overwrite
    """
    with h5py.File(filename, "r") as f:
        for k in process_from:
            if not overwrite[k]:
                # legacy_get_last_tp is a function returning a reader for
                # the step, so it has to be called, not subscripted
                process_from[k] = self.legacy_get_last_tp(k)(f) + 1
    return process_from, trackers_state, overwrite
# FIXME: Remove this functionality. It used to be for
# older hdf5 file formats.
@staticmethod
def legacy_get_last_tp(step: str) -> t.Callable:
    """
    Return a function that reads a step's last time point from an h5 file.

    The last time point is found in different ways depending on the
    step. To support segmentation in aliby < v0.24.
    TODO Deprecate and replace with State method

    Parameters
    ----------
    step: str
        One of "tiler", "baby" or "extraction".

    Raises
    ------
    KeyError
        If step is not one of the supported steps.
    """
    if step == "tiler":
        # one drift is stored per time point
        return lambda f: f["trap_info/drifts"].shape[0] - 1
    if step == "baby":
        return lambda f: f["cell_info/timepoint"][-1]
    if step == "extraction":
        return lambda f: f["extraction/general/None/area/timepoint"][-1]
    raise KeyError(step)
def _setup_pipeline(
    self, image_id: t.Union[int, str]
) -> t.Tuple[
    Path,
    MetaData,
    t.Dict,
    t.Dict[str, int],
    int,
    t.Dict,
    t.Optional[t.Dict],
    None,
    t.List[np.ndarray],
]:
    """
    Initialise steps in a pipeline.

    If necessary use a file to re-start experiments already partly run.

    Parameters
    ----------
    image_id : int or str
        Identifier of a data set in an OMERO server or a filename.

    Returns
    -------
    filename: str
        Path to a h5 file to write to.
    meta: object
        agora.io.metadata.MetaData object
    config: dict
        Configuration parameters.
    process_from: dict
        Gives from which time point each step of the pipeline should start.
    tps: int
        Number of time points.
    steps: dict
        Any steps already initialised, indexed by name.
    earlystop: dict
        Parameters to check whether the pipeline should be stopped.
    session: None
    trackers_state: list
        States of any trackers from earlier runs.
    """
    config = self.parameters.to_dict()
    # TODO Alan: Verify if session must be passed
    session = None
    earlystop = config["general"].get("earlystop", None)
    process_from = {k: 0 for k in self.pipeline_steps}
    steps = {}
    # check overwriting
    ow_id = config["general"].get("overwrite", 0)
    ow = {step: True for step in self.step_sequence}
    if ow_id and ow_id is not True:
        # overwrite the step named by ow_id and all steps after it
        ow = {
            step: self.step_sequence.index(ow_id) < i
            for i, step in enumerate(self.step_sequence, 1)
        }
    # Set up
    directory = config["general"]["directory"]
    trackers_state: t.List[np.ndarray] = []
    with dispatch_image(image_id)(image_id, **self.server_info) as image:
        filename = Path(f"{directory}/{image.name}.h5")
        meta = MetaData(directory, filename)
        # use the builtin: np.any applied to a dict view tests the view
        # itself, not its elements
        from_start = any(ow.values())
        # remove existing file if overwriting
        if (
            from_start
            and (
                config["general"].get("overwrite", False)
                or all(ow.values())
            )
            and filename.exists()
        ):
            os.remove(filename)
        # if the file exists with no previous segmentation use its tiler
        if filename.exists():
            self._log("Result file exists.", "info")
            if not ow["tiler"]:
                # NOTE(review): the Tiler methods visible in tiler.py
                # include from_h5 but not from_hdf5 - confirm this name
                steps["tiler"] = Tiler.from_hdf5(image, filename)
                try:
                    (
                        process_from,
                        trackers_state,
                        ow,
                    ) = self._load_config_from_file(
                        filename, process_from, trackers_state, ow
                    )
                    # get state array
                    trackers_state = (
                        []
                        if ow["baby"]
                        else StateReader(filename).get_formatted_states()
                    )
                    config["tiler"] = steps["tiler"].parameters.to_dict()
                except Exception:
                    # best effort: if the file cannot be read, carry on
                    # without resuming from it
                    self._log("Overwriting tiling data")
        if config["general"]["use_explog"]:
            meta.run()
        # add metadata not in the log file
        meta.add_fields(
            {
                "aliby_version": version("aliby"),
                "baby_version": version("aliby-baby"),
                "omero_id": config["general"]["id"],
                "image_id": image_id
                if isinstance(image_id, int)
                else str(image_id),
                "parameters": PipelineParameters.from_dict(
                    config
                ).to_yaml(),
            }
        )
        # never process more time points than the image has
        tps = min(config["general"]["tps"], image.data.shape[0])
        return (
            filename,
            meta,
            config,
            process_from,
            tps,
            steps,
            earlystop,
            session,
            trackers_state,
        )
def _close_session(session):
    """Close session unless it is None or otherwise falsy."""
    if not session:
        return
    session.close()
"""
Select regions from a larger image for efficient processing.
"""
"""
Tiler: Divides images into smaller tiles.
The tasks of the Tiler are selecting regions of interest, or tiles, of images - with one trap per tile, correcting for the drift of the microscope stage over time, and handling errors and bridging between the image data and Aliby’s image-processing steps.
Tiler subclasses deal with either network connections or local files.
To find tiles, we use a two-step process: we analyse the bright-field image to produce the template of a trap, and we fit this template to the image to find the tiles' centres.
We use texture-based segmentation (entropy) to split the image into foreground -- cells and traps -- and background, which we then identify with an Otsu filter. Two methods are used to produce a template trap from these regions: pick the trap with the smallest minor axis length and average over all validated traps.
A peak-identifying algorithm recovers the x and y-axis location of traps in the original image, and we choose the approach to template that identifies the most tiles.
The experiment is stored as an array with a standard indexing order of (Time, Channels, Z-stack, X, Y).
"""
import logging
import re
import typing as t
import warnings
from functools import lru_cache
from pathlib import Path
import dask.array as da
import h5py
import numpy as np
from skimage.registration import phase_cross_correlation
from agora.abc import ParametersABC, StepABC
from agora.io.writer import BridgeH5
from aliby.io.image import ImageDummy
from aliby.tile.traps import segment_traps
class Tile:
    """
    Store a tile's location and size.

    The location changes with time because of drift, which is read
    from the parent object.
    Can export the tile either in OMERO or numpy formats.
    """

    def __init__(self, centre, parent, size, max_size):
        self.centre, self.parent = centre, parent  # parent gives drifts
        self.size, self.max_size = size, max_size
        self.half_size = size // 2

    def at_time(self, tp: int) -> t.List[int]:
        """
        Return tile's centre by applying drifts.

        Parameters
        ----------
        tp: integer
            Index for the time point of interest.
        """
        # total drift up to and including time point tp
        total_drift = np.sum(self.parent.drifts[: tp + 1], axis=0)
        return list((self.centre - total_drift).astype(int))

    def as_tile(self, tp: int):
        """
        Return tile in the OMERO tile format of x, y, w, h.

        Here x, y are at the bottom left corner of the tile
        and w and h are the tile width and height.

        Parameters
        ----------
        tp: integer
            Index for the time point of interest.

        Returns
        -------
        x: int
            x-coordinate of bottom left corner of tile
        y: int
            y-coordinate of bottom left corner of tile
        w: int
            Width of tile
        h: int
            Height of tile
        """
        centre_x, centre_y = self.at_time(tp)
        # move from the centre to the corner
        corner_x = int(centre_x - self.half_size)
        corner_y = int(centre_y - self.half_size)
        return corner_x, corner_y, self.size, self.size

    def as_range(self, tp: int):
        """
        Return tile in a range format: two slice objects that can
        be used in arrays.

        Parameters
        ----------
        tp: integer
            Index for a time point

        Returns
        -------
        A slice of x coordinates from left to right
        A slice of y coordinates from top to bottom
        """
        left, top, width, height = self.as_tile(tp)
        x_range = slice(left, left + width)
        y_range = slice(top, top + height)
        return x_range, y_range
class TileLocations:
    """Store each tile as an instance of Tile, together with the drifts."""

    def __init__(
        self,
        initial_location: np.array,
        tile_size: int = None,
        max_size: int = 1200,
        drifts: np.array = None,
    ):
        self.tile_size = tile_size
        self.max_size = max_size
        self.initial_location = initial_location
        # without a tile size, each tile is as large as max_size
        side = tile_size or max_size
        self.tiles = [
            Tile(centre, self, side, max_size) for centre in initial_location
        ]
        self.drifts = [] if drifts is None else drifts

    def __len__(self):
        return len(self.tiles)

    def __iter__(self):
        for tile in self.tiles:
            yield tile

    @property
    def shape(self):
        """Return numbers of tiles and drifts."""
        n_tiles = len(self.tiles)
        n_drifts = len(self.drifts)
        return n_tiles, n_drifts

    def to_dict(self, tp: int):
        """
        Export the drift at a time point as a dictionary.

        For the first time point only, the initial locations, tile_size
        and max_size are exported too.

        Parameters
        ----------
        tp: integer
            An index for a time point
        """
        exported = {}
        if tp == 0:
            exported.update(
                {
                    "trap_locations": self.initial_location,
                    "attrs/tile_size": self.tile_size,
                    "attrs/max_size": self.max_size,
                }
            )
        exported["drifts"] = np.expand_dims(self.drifts[tp], axis=0)
        return exported

    def at_time(self, tp: int) -> np.ndarray:
        """Return an array of tile centres (x- and y-coords)."""
        centres = [tile.at_time(tp) for tile in self.tiles]
        return np.array(centres)

    @classmethod
    def from_tiler_init(
        cls, initial_location, tile_size: int = None, max_size: int = 1200
    ):
        """Instantiate from a Tiler, with no drifts yet recorded."""
        return cls(initial_location, tile_size, max_size, drifts=[])

    @classmethod
    def read_hdf5(cls, file):
        """Instantiate from the trap_info group of a h5 file."""
        with h5py.File(file, "r") as hfile:
            group = hfile["trap_info"]
            locations = group["trap_locations"][()]
            stored_drifts = group["drifts"][()].tolist()
            max_size = group.attrs["max_size"]
            tile_size = group.attrs["tile_size"]
        instance = cls(locations, tile_size, max_size=max_size)
        instance.drifts = stored_drifts
        return instance
class TilerParameters(ParametersABC):
    """
    Define default values for the parameters of a Tiler.

    tile_size: int
        Size of a tile.
    ref_channel: str or int
        Reference channel, used to find tiles and drift.
    ref_z: int
        Reference z-position, used to find tiles and drift.
    backup_ref_channel: int or None
        If an int, the index of the reference channel. Used when the
        image does not include metadata, ref_channel is a string, and
        channel names are included in parsed logfiles.
    """

    # values used when a parameter is not specified
    _defaults = {
        "tile_size": 117,
        "ref_channel": "Brightfield",
        "ref_z": 0,
        "backup_ref_channel": None,
    }
class Tiler(StepABC):
"""
Divide images into smaller tiles for faster processing.
Finds tiles and re-registers images if they drift.
Fetch images from an OMERO server if necessary.
Uses an Image instance, which lazily provides the data on pixels,
and, as an independent argument, metadata.
"""
def __init__(
    self,
    image: da.core.Array,
    metadata: dict,
    parameters: TilerParameters,
    tile_locs=None,
):
    """
    Initialise.

    Parameters
    ----------
    image: array
        Pixel data, whose last two dimensions are used as the size of
        the image.
    metadata: dictionary
        May contain "channels", or "size_c" from which integer channels
        are generated, and "zsections".
    parameters: an instance of TilerParameters
    tile_locs: (optional)
        Locations of tiles if already known.
    """
    super().__init__(parameters)
    self.image = image
    self._metadata = metadata
    # without channel names, use integers as channels
    self.channels = metadata.get(
        "channels",
        list(range(metadata.get("size_c", 0))),
    )
    self.ref_channel = self.get_channel_index(parameters.ref_channel)
    if self.ref_channel is None:
        # fall back on the index of the reference channel; this must not
        # be followed by another call to get_channel_index, which would
        # replace the backup with None again
        # (backup_ref_channel is presumably found in parameters by the
        # parent class - confirm)
        self.ref_channel = self.backup_ref_channel
    self.tile_locs = tile_locs
    try:
        self.z_perchannel = dict(zip(self.channels, metadata["zsections"]))
    except Exception as e:
        # best effort: z_perchannel is left undefined
        self._log(f"No z_perchannel data: {e}")
    # without a tile size, use the smallest of the last two dimensions
    self.tile_size = self.tile_size or min(self.image.shape[-2:])
@classmethod
def dummy(cls, parameters: dict):
    """
    Instantiate dummy Tiler from dummy image.

    If image.dimorder exists dimensions are saved in that order.
    Otherwise default to "tczyx".

    Parameters
    ----------
    parameters: dict
        An instance of TilerParameters converted to a dict.
    """
    dummy_source = ImageDummy(parameters)
    lazy_data = dummy_source.get_data_lazy()
    # default to "tczyx" if image.dimorder is None
    dim_names = dummy_source.dimorder or "tczyx"
    metadata = {}
    for dim, dim_size in zip(dim_names, lazy_data.shape):
        metadata[f"size_{dim}"] = dim_size
    # the reference channel comes first, followed by placeholders
    channel_names = [
        parameters["ref_channel"],
        *(["nil"] * (metadata["size_c"] - 1)),
    ]
    metadata["channels"] = channel_names
    metadata["name"] = ""
    return cls(
        dummy_source.data,
        metadata,
        TilerParameters.from_dict(parameters),
    )
@classmethod
def from_image(cls, image, parameters: TilerParameters):
    """
    Instantiate from an Image instance.

    Parameters
    ----------
    image: an instance of Image
        Its data and metadata attributes are passed to the Tiler.
    parameters: an instance of TilerParameters
    """
    return cls(image.data, image.metadata, parameters)
@classmethod
def from_h5(
    cls,
    image,
    filepath: t.Union[str, Path],
    parameters: t.Optional[TilerParameters] = None,
):
    """
    Instantiate from a h5 file.

    Tile locations and metadata are read from the file, but the
    channels are taken from the image.

    Parameters
    ----------
    image: an instance of Image
    filepath: str or Path instance
        Path to a h5 file.
    parameters: an instance of TileParameters (optional)
        Default parameters are used if None.
    """
    stored_locs = TileLocations.read_hdf5(filepath)
    h5_metadata = BridgeH5(filepath).meta_h5
    h5_metadata["channels"] = image.metadata["channels"]
    tiler_params = (
        TilerParameters.default() if parameters is None else parameters
    )
    new_tiler = cls(
        image.data, h5_metadata, tiler_params, tile_locs=stored_locs
    )
    # one drift is stored per time point processed
    if hasattr(stored_locs, "drifts"):
        new_tiler.n_processed = len(stored_locs.drifts)
    return new_tiler
@lru_cache(maxsize=2)
def get_tc(self, t: int, c: int) -> np.ndarray:
    """
    Load the image for one time point and one channel.

    Results of the two most recent distinct calls are cached.

    Assumes the image is arranged as
        no of time points
        no of channels
        no of z stacks
        no of pixels in y direction
        no of pixels in x direction

    Parameters
    ----------
    t: integer
        An index for a time point
    c: integer
        An index for a channel

    Returns
    -------
    full: an array of images
    """
    selected = self.image[t, c]
    if not hasattr(selected, "compute"):
        # already in memory
        return selected
    # If using dask fetch images here
    return selected.compute(scheduler="synchronous")
@property
def shape(self):
    """Return the shape of the time-lapse, as given by self.image.shape."""
    return self.image.shape
@property
def n_processed(self):
    """Return the number of processed images."""
    # initialise lazily to zero on first access
    if not hasattr(self, "_n_processed"):
        self._n_processed = 0
    return self._n_processed

@n_processed.setter
def n_processed(self, value):
    """Set the number of processed images."""
    self._n_processed = value
@property
def n_tiles(self):
    """Return number of tiles, which is the length of self.tile_locs."""
    return len(self.tile_locs)
def initialise_tiles(self, tile_size: int = None):
    """
    Find initial positions of tiles and store them in self.tile_locs.

    Remove tiles that are too close to the edge of the image
    so no padding is necessary. If no tile size is given, a single
    tile centred on the image is used.

    Parameters
    ----------
    tile_size: integer
        The size of a tile.
    """
    # first time point, reference channel, reference z-position
    first_frame = self.image[0, self.ref_channel, self.ref_z]
    if not tile_size:
        # one tile, centred, as large as the smallest image dimension
        yx_shape = self.image.shape[-2:]
        image_centre = [extent // 2 for extent in yx_shape]
        self.tile_locs = TileLocations.from_tiler_init(
            [image_centre], max_size=min(yx_shape)
        )
        return
    margin = tile_size // 2
    # the same limit, from the smallest dimension, is used for x and y
    upper_limit = min(self.image.shape[-2:]) - margin
    # find the tiles
    candidates = segment_traps(first_frame, tile_size)
    # keep only tiles that are not near an edge
    interior = []
    for x, y in candidates:
        if margin < x < upper_limit and margin < y < upper_limit:
            interior.append([x, y])
    # store tiles in an instance of TileLocations
    self.tile_locs = TileLocations.from_tiler_init(interior, tile_size)
def find_drift(self, tp: int):
    """
    Find any translational drift between two images at consecutive
    time points using cross correlation.

    The drift is stored in self.tile_locs.drifts, replacing any drift
    already stored for tp when tp is positive and appended otherwise.

    Arguments
    ---------
    tp: integer
        Index for a time point.
    """
    # the first time point is compared with itself
    reference = self.image[max(0, tp - 1), self.ref_channel, self.ref_z]
    current = self.image[tp, self.ref_channel, self.ref_z]
    # cross-correlate
    shift, _error, _phase = phase_cross_correlation(reference, current)
    # store drift
    recorded = self.tile_locs.drifts
    if tp > 0 and tp < len(recorded):
        recorded[tp] = shift.tolist()
    else:
        recorded.append(shift.tolist())
def get_tp_data(self, tp, c) -> np.ndarray:
    """
    Returns all tiles corrected for drift.

    Parameters
    ----------
    tp: integer
        An index for a time point
    c: integer
        An index for a channel

    Returns
    ----------
    Numpy ndarray of tiles with shape (tile, z, y, x)
    """
    # get OMERO image
    full = self.get_tc(tp, c)
    # cut out each tile, padding if necessary
    cut_outs = [
        self.ifoob_pad(full, tile.as_range(tp)) for tile in self.tile_locs
    ]
    return np.stack(cut_outs)
def get_tile_data(self, tile_id: int, tp: int, c: int):
    """
    Return a particular tile corrected for drift and padding.

    Parameters
    ----------
    tile_id: integer
        Number of tile.
    tp: integer
        Index of time points.
    c: integer
        Index of channel.

    Returns
    -------
    ndtile: array
        An array of (x, y) arrays, one for each z stack
    """
    whole_image = self.get_tc(tp, c)
    chosen = self.tile_locs.tiles[tile_id]
    tile_range = chosen.as_range(tp)
    return self.ifoob_pad(whole_image, tile_range)
def _run_tp(self, tp: int):
    """
    Find tiles if they have not yet been found.

    Determine any translational drift of the current image from the
    previous one.

    Arguments
    ---------
    tp: integer
        The time point to tile.

    Returns
    -------
    The tile locations at tp as a dictionary, for the writer.
    """
    # assert tp >= self.n_processed, "Time point already processed"
    # TODO check contiguity?
    needs_tiles = self.n_processed == 0 or not hasattr(
        self.tile_locs, "drifts"
    )
    if needs_tiles:
        self.initialise_tiles(self.tile_size)
    if hasattr(self.tile_locs, "drifts"):
        # the number of drifts is taken to be the number processed
        n_drifts = len(self.tile_locs.drifts)
        if n_drifts != self.n_processed:
            warnings.warn("Tiler:n_processed and ndrifts don't match")
            self.n_processed = n_drifts
    # determine drift
    self.find_drift(tp)
    # update n_processed
    self.n_processed = tp + 1
    # return result for writer
    return self.tile_locs.to_dict(tp)
def run(self, time_dim=None):
    """
    Tile all time points in an experiment at once.

    Parameters
    ----------
    time_dim: integer (optional)
        Axis of self.image that indexes time; axis 0 is used when None.
    """
    axis = 0 if time_dim is None else time_dim
    n_frames = self.image.shape[axis]
    for frame in range(n_frames):
        self.run_tp(frame)
    return None
def get_traps_timepoint(self, *args, **kwargs):
    """Deprecated alias: log a notice, then forward to get_tiles_timepoint."""
    notice = "get_traps_timepoint is deprecated; get_tiles_timepoint instead."
    self._log(notice)
    tiles = self.get_tiles_timepoint(*args, **kwargs)
    return tiles
# The next set of functions are necessary for the extraction object
def get_tiles_timepoint(
    self, tp: int, tile_shape=None, channels=None, z: int = 0
) -> np.ndarray:
    """
    Get a multidimensional array with all tiles for a set of channels
    at one z-section.

    Used by extractor.

    Parameters
    ---------
    tp: int
        Index of time point
    tile_shape: int or tuple of two ints
        Size of tile in x and y dimensions
    channels: int, string, or list of these
        Channels of interest, each passed on to get_tp_data. A single
        int or string is treated as a one-element list; channel 0 is
        used when None.
    z: int
        Index of z-channel of interest

    Returns
    -------
    res: array
        Data arranged as (tiles, channels, 1, y, x): the per-channel
        arrays from get_tp_data, reduced to the requested z, gain a
        singleton axis and are stacked along axis 1.
    """
    # FIXME add support for sub-tiling a tile
    # FIXME can we ignore z
    if channels is None:
        channels = [0]
    elif isinstance(channels, (str, int)):
        # a bare channel would otherwise fail when iterated below
        channels = [channels]
    # get the data
    res = []
    for c in channels:
        # only return requested z: (tiles, z, y, x) -> (tiles, y, x)
        val = self.get_tp_data(tp, c)[:, z]
        # add a singleton axis: (tiles, 1, y, x)
        val = np.expand_dims(val, axis=1)
        res.append(val)
    if tile_shape is not None:
        if isinstance(tile_shape, int):
            tile_shape = (tile_shape, tile_shape)
        # NOTE(review): shape[-3:-2] is the singleton axis added above,
        # so this looks like it never compares against the tile's y/x
        # size - confirm the intended axes before relying on it.
        assert np.all(
            [
                (tile_size - ax) > -1
                for tile_size, ax in zip(tile_shape, res[0].shape[-3:-2])
            ]
        )
    return np.stack(res, axis=1)
@property
def ref_channel_index(self):
    """Index of the reference channel named in self.parameters."""
    ref_name = self.parameters.ref_channel
    return self.get_channel_index(ref_name)
def get_channel_index(
    self, channel: t.Union[str, int]
) -> t.Optional[int]:
    """
    Find the index of a channel.

    A string is matched against self.channels as a regex through
    find_channel_index, which returns the first match. An integer is
    returned unchanged. If self.channels contains only integers (no
    image metadata), a non-integer channel gives None.

    Parameters
    ----------
    channel: string or int
        The channel or index to be used.

    Returns
    -------
    The channel index, or None when it cannot be determined.
    """
    # no channel names available: only integer indices are meaningful
    if all(map(lambda x: isinstance(x, int), self.channels)):
        channel = channel if isinstance(channel, int) else None
    if isinstance(channel, str):
        channel = find_channel_index(self.channels, channel)
    return channel
@staticmethod
def ifoob_pad(full, slices):
    """
    Return the slices padded if out of bounds.

    Parameters
    ----------
    full: array
        Slice of OMERO image (zstacks, x, y) - the entire position
        with zstacks as first axis
    slices: tuple of two slices
        Delineates indices for the ranges of the tile along the last
        two axes of `full`, in that order.

    Returns
    -------
    tile: array
        A tile with all z stacks for the given slices.
        If some padding is needed, np.pad's "median" mode is used.
        If any side needs more than a quarter of the tile size as
        padding, a tile of NaN is returned.
    """
    # image size along each tiled axis, so non-square images are
    # clipped and padded against the correct limit
    limits = full.shape[-2:]
    # ignore parts of the tile outside of the image
    y, x = [
        slice(max(0, s.start), min(limit, s.stop))
        for s, limit in zip(slices, limits)
    ]
    # get the tile including all z stacks
    tile = full[:, y, x]
    # find extent of padding needed (before, after) along each axis
    padding = np.array(
        [
            (-min(0, s.start), -min(0, limit - s.stop))
            for s, limit in zip(slices, limits)
        ]
    )
    if padding.any():
        tile_size = slices[0].stop - slices[0].start
        if (padding > tile_size / 4).any():
            # too much of the tile is outside of the image
            # fill with NaN
            tile = np.full((full.shape[0], tile_size, tile_size), np.nan)
        else:
            # pad the tile, leaving the z axis untouched
            tile = np.pad(tile, [[0, 0]] + padding.tolist(), "median")
    return tile
# FIXME: Refactor to support both channel or index
# self._log below is not defined
def find_channel_index(image_channels: t.List[str], channel: str):
    """
    Find the index of the first channel matching a regular expression.

    Matching is case-insensitive and anchored at the start of each
    channel name (re.match). A warning is logged to the "aliby" logger
    when the pattern matches only part of the channel name.

    Parameters
    ----------
    image_channels: list of str
        Channels.
    channel: str
        A regular expression.

    Returns
    -------
    The index of the first matching channel, or None if none match.
    """
    for i, ch in enumerate(image_channels):
        found = re.match(channel, ch, re.IGNORECASE)
        if found:
            # found.end() is where the match stops; endpos is only the
            # search limit and always equals the string length here
            if found.end() < len(ch):
                logging.getLogger("aliby").log(
                    logging.WARNING,
                    f"Channel {channel} matched {ch} using regex",
                )
            return i
    return None
def find_channel_name(image_channels: t.List[str], channel: str):
    """
    Find the name of the channel using regex.

    Parameters
    ----------
    image_channels: list of str
        Channels.
    channel: str
        A regular expression.

    Returns
    -------
    The matching channel name, or None when find_channel_index finds
    no match.
    """
    position = find_channel_index(image_channels, channel)
    if position is None:
        return None
    return image_channels[position]
"""Functions for identifying and dealing with ALCATRAS traps."""
import numpy as np
from skimage import feature, transform
from skimage.filters import threshold_otsu
from skimage.filters.rank import entropy
from skimage.measure import label, regionprops
from skimage.morphology import closing, disk, square
from skimage.segmentation import clear_border
from skimage.util import img_as_ubyte
def half_floor(x, tile_size):
    """Return x minus half of tile_size, the half rounded down."""
    lower_half = tile_size // 2
    return x - lower_half
def half_ceil(x, tile_size):
    """Return x plus half of tile_size, the half rounded up."""
    # floor division by -2 gives minus the rounded-up half
    negative_upper_half = tile_size // -2
    return x - negative_upper_half
def segment_traps(
    image,
    tile_size,
    downscale=0.4,
    disk_radius_frac=0.01,
    square_size=3,
    min_frac_tilesize=0.3,
    **identify_traps_kwargs,
):
    """
    Use an entropy filter and Otsu thresholding to find a trap template,
    which is then passed to identify_trap_locations.

    To be a candidate trap, a region's major axis length must lie between
    min_frac_tilesize * tile_size and tile_size.

    The hyperparameters have not been optimised.

    Parameters
    ----------
    image: 2D array
    tile_size: integer
        Size of the tile
    downscale: float (optional)
        Fraction by which to shrink image
    disk_radius_frac: float (optional)
        Radius of the disk used in the entropy filter, as a fraction of
        the smallest image dimension
    square_size: integer (optional)
        Parameter for a morphological closing applied to thresholded
        image
    min_frac_tilesize: float (optional)
        Lower bound, as a fraction of tile_size, on the major axis
        length of regions suspected of containing traps.
    identify_traps_kwargs:
        Passed to identify_trap_locations

    Returns
    -------
    traps: an array of pairs of integers
        The coordinates returned by identify_trap_locations.

    Raises
    ------
    AssertionError
        If no region qualifies as a candidate trap.
    """
    # keep a memory of image in case need to re-run
    img = image
    # lower bound on major axis length of traps
    min_mal = min_frac_tilesize * tile_size
    # shrink image
    if downscale != 1:
        img = transform.rescale(image, downscale)
    # generate an entropy image using a disk footprint
    disk_radius = int(min([disk_radius_frac * x for x in img.shape]))
    entropy_image = entropy(img_as_ubyte(img), disk(disk_radius))
    if downscale != 1:
        entropy_image = transform.rescale(entropy_image, 1 / downscale)
    # find Otsu threshold for entropy image
    thresh = threshold_otsu(entropy_image)
    # apply morphological closing to thresholded, and so binary, image
    bw = closing(entropy_image > thresh, square(square_size))
    # remove artifacts connected to image border
    cleared = clear_border(bw)
    # label distinct regions of the image
    label_image = label(cleared)
    # find regions likely to contain traps:
    # with a major axis length within a certain range
    # and a centroid at least tile_size // 2 away from the image edge
    valid_regions = [
        region
        for region in regionprops(label_image)
        if min_mal < region.major_axis_length < tile_size
        and tile_size // 2
        < region.centroid[0]
        < half_floor(image.shape[0], tile_size) - 1
        and tile_size // 2
        < region.centroid[1]
        < half_floor(image.shape[1], tile_size) - 1
    ]
    assert valid_regions, "No valid tiling regions found"
    # find centroids of valid regions
    centroids = (
        np.array([region.centroid for region in valid_regions])
        .round()
        .astype(int)
    )
    # make candidate templates from the traps found
    candidate_templates = [
        image[
            half_floor(x, tile_size) : half_ceil(x, tile_size),
            half_floor(y, tile_size) : half_ceil(y, tile_size),
        ]
        for x, y in centroids
    ]
    # make a mean template from all the found traps
    mean_template = np.stack(candidate_templates).astype(int).mean(axis=0)
    # find traps using the mean trap template
    traps = identify_trap_locations(
        image, mean_template, **identify_traps_kwargs
    )
    # if there are too few traps, try again without down-scaling,
    # keeping every other setting the caller asked for
    traps_retry = []
    if len(traps) < 30 and downscale != 1:
        print("Tiler:TrapIdentification: Trying again.")
        traps_retry = segment_traps(
            image,
            tile_size,
            downscale=1,
            disk_radius_frac=disk_radius_frac,
            square_size=square_size,
            min_frac_tilesize=min_frac_tilesize,
            **identify_traps_kwargs,
        )
    # return results with the most number of traps
    if len(traps_retry) < len(traps):
        return traps
    else:
        return traps_retry
def identify_trap_locations(
    image, trap_template, optimize_scale=True, downscale=0.35, trap_size=None
):
    """
    Identify the traps in a single image based on a trap template.

    Requires the trap template to be similar to the image
    (same camera, same magnification - ideally the same experiment).

    Uses normalised correlation through scikit-image's match_template.

    The search is sped up by down-scaling both the image and
    the trap template before running the template matching.

    The template is tried at four rotations and, if optimize_scale is
    True, at ten scales between 0.5 and 2; the best candidate is the
    one with the highest 99.9th percentile of squared correlation.

    Parameters
    ----------
    image: 2D array
    trap_template: 2D array
    optimize_scale : boolean (optional)
    downscale: float (optional)
        Fraction by which to downscale to increase speed
    trap_size: integer (optional)
        If unspecified, the size is determined from the trap_template

    Returns
    -------
    coordinates: an array of pairs of integers
    """
    if trap_size is None:
        trap_size = trap_template.shape[0]
    # careful: the image is float16!
    img = transform.rescale(image.astype(float), downscale)
    template = transform.rescale(trap_template, downscale)

    def squared_match(candidate):
        # squared because the sign of the correlation is unimportant
        return (
            feature.match_template(
                img, candidate, pad_input=True, mode="median"
            )
            ** 2
        )

    def top_score(correlations):
        return np.percentile(correlations, 99.9)

    # value used to fill the corners exposed by a rotation
    fill_value = np.median(img)
    # try multiple rotations of template to determine
    # which best matches the image
    by_rotation = {}
    for angle in [0, 90, 180, 270]:
        rotated = transform.rotate(template, angle, cval=fill_value)
        by_rotation[angle] = squared_match(rotated)
    best_rotation = max(
        by_rotation, key=lambda angle: top_score(by_rotation[angle])
    )
    # rotate template by best rotation
    template = transform.rotate(template, best_rotation, cval=fill_value)
    if not optimize_scale:
        # find the image of normalised correlations with the template
        matched = feature.match_template(
            img, template, pad_input=True, mode="median"
        )
    else:
        # try multiple scales applied to template to determine which
        # best matches the image
        by_scale = {}
        for scale in np.linspace(0.5, 2, 10):
            by_scale[scale] = squared_match(
                transform.rescale(template, scale)
            )
        best_scale = max(
            by_scale, key=lambda scale: top_score(by_scale[scale])
        )
        # choose the best result - an image of squared normalised
        # correlations with the template
        matched = by_scale[best_scale]
    # re-scale back the image of normalised correlations
    # find the coordinates of local maxima
    full_size_match = transform.rescale(matched, 1 / downscale)
    coordinates = feature.peak_local_max(
        full_size_match,
        min_distance=int(trap_size * 0.70),
        exclude_border=(trap_size // 3),
    )
    return coordinates
"""
Classes that link outlines within and between time-points.
"""
import numpy as np
import pandas as pd
from scipy.ndimage import binary_fill_holes
from baby.io import load_tiled_image
from baby.tracker.core import CellTracker
class CellBenchmarker: # TODO Simplify this by inheritance
"""
Takes a metadata dataframe and a model and estimates the prediction in a trap-wise manner.
This class can also produce confusion matrices for a given Tracker and validation dataset.
"""
def __init__(self, meta, model, bak_model, nstepsback=None):
    """
    Store a copy of the metadata and build the tracker to benchmark.

    Parameters
    ----------
    meta: pandas DataFrame
        Metadata; it is copied and a "cont_list_index" column numbering
        the rows consecutively is added to the copy.
    model, bak_model:
        Passed to CellTracker as its main and backup models.
    nstepsback: int (optional)
        Number of time points to look back. The tracker's own value is
        used when None.
    """
    self.indices = ["experimentID", "position", "trap", "tp"]
    self.cindices = self.indices + ["cellLabels"]
    self.meta = meta.copy()
    self.meta["cont_list_index"] = (*range(len(self.meta)),)
    self.tracker = CellTracker(model=model, bak_model=bak_model)
    # always define nstepsback, also when the caller supplies a value
    if nstepsback is None:
        nstepsback = self.tracker.nstepsback
    self.nstepsback = nstepsback
    # evaluate the property once so the trap locations are cached
    self.traps_loc
@property
def traps_loc(self):
    """
    Unique trap locations found in the metadata index.

    Each location is the leading part of an index entry, up to and
    including the "trap" level. The result is cached in _traps_loc.
    """
    if hasattr(self, "_traps_loc"):
        return self._traps_loc
    n_levels = self.indices.index("trap") + 1
    unique_rows = np.unique(
        [key[:n_levels] for key in self.meta.index],
        axis=0,
    )
    # str->int conversion of every entry after the first
    self._traps_loc = tuple(
        (row[0], *map(int, row[1:])) for row in unique_rows
    )
    return self._traps_loc
@property
def masks(self):
    """
    Masks for every row of the metadata, loaded once and then cached.

    Each file named in the "filename" column is loaded with
    load_tiled_image, holes are filled slice by slice along the third
    axis, and that axis is then moved to the front.
    """
    if not hasattr(self, "_masks"):
        self._masks = [
            load_tiled_image(path)[0] for path in self.meta["filename"]
        ]
        for stack in self._masks:
            for k in range(stack.shape[2]):
                stack[..., k] = binary_fill_holes(stack[..., k])
        self._masks = [np.moveaxis(stack, 2, 0) for stack in self._masks]
    return self._masks
def predict_lbls_from_tpimgs(self, tp_img_tuple):
    """
    Track cells through a sequence of (time point, masks) pairs.

    Each set of masks is labelled with self.tracker.get_new_lbls, given
    the labels and features of the last self.nstepsback time points.

    Returns
    -------
    tuple
        The last time point seen (None if the sequence is empty) and
        the list of label lists, one per time point.
    """
    max_lbl = 0
    prev_feats = []
    cell_lbls = []
    # defined up front so an empty sequence does not leave tp unbound
    tp = None
    for tp, masks in tp_img_tuple:
        lastn_lbls = cell_lbls[-self.nstepsback :]
        lastn_feats = prev_feats[-self.nstepsback :]
        new_lbl, feats, max_lbl = self.tracker.get_new_lbls(
            masks, lastn_lbls, lastn_feats, max_lbl
        )
        cell_lbls = cell_lbls + [new_lbl]
        prev_feats = prev_feats + [feats]
    return (tp, cell_lbls)
def df_get_imglist(self, exp, pos, trap, tp=None):
    """
    Pair the index entries of one trap with their masks.

    Returns a zip of (index entry, mask) for the metadata rows selected
    by (exp, pos, trap). The tp argument is not used.
    """
    columns = ["cont_list_index", "cellLabels"]
    trap_rows = self.meta.loc[(exp, pos, trap), columns]
    trap_masks = [self.masks[k] for k in trap_rows["cont_list_index"]]
    return zip(trap_rows.index, trap_masks)
def predict_set(self, exp, pos, trap, tp=None):
    """
    Predict the cell labels for every time point of one trap.

    Returns the list of label lists produced by
    predict_lbls_from_tpimgs. The tp argument is not used.
    """
    timepoint_masks = tuple(self.df_get_imglist(exp, pos, trap))
    _, predicted = self.predict_lbls_from_tpimgs(timepoint_masks)
    return predicted
def compare_traps(self, exp, pos, trap):
    """
    Error calculator for testing model and assignment heuristics.

    Predicts the labels of one trap and compares them with the stored
    labels. The comparison uses local indices: for each pair of
    consecutive time points, every cell in the later one is described
    by its position in the earlier one's label list (-1 if absent).

    Returns
    -------
    frac_correct: float
        Fraction of local assignments that agree between stored and
        predicted labels.
    error_cid: array
        Stored cell labels (from the second time point onwards) whose
        local assignment differs.
    """
    print("Processing trap {}, {}, {}".format(exp, pos, trap))
    new_cids = self.predict_set(exp, pos, trap)
    test_df = self.meta.loc(axis=0)[(exp, pos, trap)].copy()
    test_df["pred_cellLabels"] = new_cids
    orig = test_df["cellLabels"].values
    new = test_df["pred_cellLabels"].values
    # first entry: stored labels; second entry: predicted labels
    local_indices = [[], []]
    for i, case in enumerate(
        (zip(orig[:-1], orig[1:]), zip(new[:-1], new[1:]))
    ):
        for prev_cells, pos_cells in case:
            local_assignment = [
                prev_cells.index(cell) if cell in prev_cells else -1
                for cell in pos_cells
            ]
            local_indices[i] += local_assignment
    # local_indices always holds exactly two flat lists, so the former
    # "more than two" branches could never run and have been removed
    flt_test, flt_new = [np.array(case) for case in local_indices]
    correct = flt_test == flt_new
    error_cid = (
        test_df.iloc[1:]["cellLabels"].explode().dropna()[~correct].values
    )
    frac_correct = np.mean(correct)
    print("Fraction of correct predictions", frac_correct)
    return (frac_correct, error_cid)
def predict_all(self):
    """
    Predict all datasets defined in self.traps_loc.

    Results are stored in self.predictions, keyed by
    (nstepsback, thresh, address).
    """
    # parameter grid; currently a single combination
    grid = [(depth, cutoff) for depth in [2] for cutoff in [0.9]]
    self.predictions = {}
    for nstepsback, thresh in grid:
        self.nstepsback = nstepsback
        self.tracker.nstepsback = nstepsback
        self.low_thresh = 1 - thresh
        self.high_thresh = thresh
        self.thresh = thresh * 5 / 8
        for address in self.traps_loc:
            key = (nstepsback, thresh, address)
            self.predictions[key] = self.predict_set(*address)
def calculate_errsum(self):
    """
    Calculate all errors, addresses of images with errors and error fractions.

    Returns
    -------
    tuple of three dicts keyed by (thresh, nstepsback):
        fractions per trap, errors per trap address, and error counts
        per trap. Traps without errors get a fraction of 1.0 and a
        count of 0.
    """
    frac_errs = {}
    all_errs = {}
    nerrs = {}
    for nstepsback in list(range(1, 3)):
        for thresh in [0.95]:
            self.nstepsback = nstepsback
            self.tracker.nstepsback = nstepsback
            self.low_thresh = 1 - thresh
            self.high_thresh = thresh
            self.thresh = thresh * 5 / 8
            key = (thresh, nstepsback)
            all_errs[key] = {}
            frac_errs[key] = []
            nerrs[key] = []
            for address in self.traps_loc:
                fraction, errors = self.compare_traps(*address)
                if not len(errors):
                    nerrs[key].append(0)
                    frac_errs[key].append(1.0)
                    continue
                all_errs[key][address] = errors
                frac_errs[key].append(fraction)
                nerrs[key].append(len(errors))
    return (frac_errs, all_errs, nerrs)
def get_truth_matrix_from_pair(self, pair):
    """
    Build the ground-truth link matrix for two time points of a trap.

    Requires self.meta.

    args:
    :pair: tuple of size 4 (experimentID, position, trap, (tp1, tp2))

    returns
    :truth_mat: boolean ndarray of shape (ncells(tp1), ncells(tp2))
    linking cells in tp1 to cells in tp2
    """
    location = pair[:3]
    timepoints = pair[3]
    labels_first = self.meta.loc[location + (timepoints[0],), "cellLabels"]
    labels_second = self.meta.loc[location + (timepoints[1],), "cellLabels"]
    return gen_boolmat_from_clabs(labels_first, labels_second)
def get_mota_stats(self, pair, thresh=0.7, *args, **kwargs):
    """
    Count link matches and mismatches for a pair of time points.

    The features of the two masks are computed with the tracker, turned
    into link probabilities, and thresholded; the resulting predicted
    links are compared with the ground truth.

    args:
    :pair: tuple of size 4 (experimentID, position, trap, (tp1, tp2))
    :thresh: probability above which a link is predicted
    :args, kwargs: passed to self.tracker.predict_proba_from_ndarray

    returns
    (true_pos, false_pos, false_neg): number of links that are true
    and predicted, predicted but not true, and true but not predicted.
    """
    true_mat = self.get_truth_matrix_from_pair(pair)
    # these inputs were previously undefined names in this method
    masks = [self.masks[i] for i in self.meta.loc[pair, "cont_list_index"]]
    feats = [self.tracker.calc_feats_from_mask(mask) for mask in masks]
    ndarray = self.tracker.calc_feat_ndarray(*feats)
    prob_mat = self.tracker.predict_proba_from_ndarray(
        ndarray, *args, **kwargs
    )
    pred_mat = prob_mat > thresh
    if not len(true_mat) and not len(pred_mat):
        return (0, 0, 0)
    true_flat = true_mat.flatten()
    pred_flat = pred_mat.flatten()
    true_pos = np.sum(true_flat & pred_flat)
    false_pos = np.sum(~true_flat & pred_flat)
    false_neg = np.sum(true_flat & ~pred_flat)
    # TODO add identity switch
    return (true_pos, false_pos, false_neg)
def gen_cm_stats(self, pair, thresh=0.7, *args, **kwargs):
    """
    Calculate confusion matrix for a pair of pos-timepoints.

    Sets the tracker's low and high thresholds to 1 - thresh and
    thresh, predicts link probabilities between the cells of the two
    time points, and compares the thresholded prediction with the
    ground truth.

    args:
    :pair: tuple of size 4 (experimentID, position, trap, (tp1, tp2))
    :thresh: probability above which a link is predicted
    :args, kwargs: passed to self.tracker.predict_proba_from_ndarray

    returns
    (true_pos, false_pos, false_neg, true_neg) counts over all
    candidate links; all zeros when both matrices are empty.
    """
    masks = [self.masks[i] for i in self.meta.loc[pair, "cont_list_index"]]
    feats = [self.tracker.calc_feats_from_mask(mask) for mask in masks]
    ndarray = self.tracker.calc_feat_ndarray(*feats)
    self.tracker.low_thresh = 1 - thresh
    self.tracker.high_thresh = thresh
    prob_mat = self.tracker.predict_proba_from_ndarray(
        ndarray, *args, **kwargs
    )
    pred_mat = prob_mat > thresh
    true_mat = self.get_truth_matrix_from_pair(pair)
    if not len(true_mat) and not len(pred_mat):
        return (0, 0, 0, 0)
    true_flat = true_mat.flatten()
    pred_flat = pred_mat.flatten()
    true_pos = np.sum(true_flat & pred_flat)
    # false positive: predicted link that is not in the ground truth
    false_pos = np.sum(~true_flat & pred_flat)
    # false negative: ground-truth link that was not predicted
    false_neg = np.sum(true_flat & ~pred_flat)
    true_neg = np.sum(~true_flat & ~pred_flat)
    return (true_pos, false_pos, false_neg, true_neg)
def extract_pairs_from_trap(self, trap_loc):
    """
    List the consecutive time-point pairs of one trap.

    args:
    :trap_loc: tuple locating the trap in the index of self.meta

    returns
    list of tuples, each being trap_loc extended by a (tp1, tp2) pair
    of neighbouring index entries.
    """
    # only the index is needed, so no columns are selected; this avoids
    # depending on a "list_index" column that this class never creates
    timepoints = self.meta.loc(axis=0)[trap_loc].index
    return [
        trap_loc + (pair,)
        for pair in zip(timepoints[:-1], timepoints[1:])
    ]
def gen_pairlist(self):
    """Store in self.pairs the time-point pairs of every trap location."""
    pairs_per_trap = []
    for trap in self.traps_loc:
        pairs_per_trap.append(self.extract_pairs_from_trap(trap))
    self.pairs = pairs_per_trap
def gen_cm_from_pairs(self, thresh=0.5, *args, **kwargs):
    """
    Sum the confusion-matrix counts over every pair in self.pairs.

    args:
    :thresh: threshold passed to gen_cm_stats
    :args, kwargs: passed on to gen_cm_stats

    returns
    dict with keys "tp", "fp", "fn" and "tn"; also stored in
    self._con_mat.
    """
    keys = ("tp", "fp", "fn", "tn")
    con_mat = dict.fromkeys(keys, 0)
    for pairset in self.pairs:
        for pair in pairset:
            # thresh is passed positionally: as a keyword it would
            # clash with the first of any extra positional arguments
            res = self.gen_cm_stats(pair, thresh, *args, **kwargs)
            for key, count in zip(keys, res):
                con_mat[key] += count
    self._con_mat = con_mat
    return self._con_mat
def get_frac_error_df(self):
    """
    Run calculate_errsum and return the fractions in long format.

    The three results of calculate_errsum are stored in self.frac_errs,
    self.all_errs and self.nerrs; the fractions are returned as a
    melted DataFrame.
    """
    fractions, errors, counts = self.calculate_errsum()
    self.frac_errs = fractions
    self.all_errs = errors
    self.nerrs = counts
    # nerrs_df = pd.DataFrame(self.nerrs).melt()
    return pd.DataFrame(self.frac_errs).melt()
def gen_errorplots(self):
    """
    Plot the fraction of correct assignments as a bar plot.

    Bars are grouped by backtrace depth and coloured by threshold. The
    figure is saved as tracker_benchmark_btdepth.png and then shown.
    """
    frac_df = self.get_frac_error_df()
    import seaborn as sns
    from matplotlib import pyplot as plt

    # ax = sns.barplot(x='variable_0', y='value', data=frac_df)
    axes = sns.barplot(
        data=frac_df, x="variable_1", y="value", hue="variable_0"
    )
    axes.set(
        xlabel="Backtrace depth",
        ylabel="Fraction of correct assignments",
        ylim=(0.9, 1),
    )
    plt.legend(title="Threshold")
    plt.savefig("tracker_benchmark_btdepth.png")
    plt.show()
# def plot_pair(self, address)
def gen_boolmat_from_clabs(clabs1, clabs2):
    """
    Build a boolean matrix marking equal labels in two label lists.

    Entry (i, j) is True when clabs1[i] equals clabs2[j]. An empty 1-D
    array is returned when np.any is False for both inputs.
    """
    if not (np.any(clabs1) or np.any(clabs2)):
        return np.array([])
    boolmat = np.zeros((len(clabs1), len(clabs2)), dtype=bool)
    for row, first in enumerate(clabs1):
        boolmat[row] = [first == second for second in clabs2]
    return boolmat
def gen_stats_dict(results):
    """
    Generate a dictionary of statistics from the results of different
    binary classification tasks, for example, using different thresholds.

    output
    dictionary containing the name of statistic as a key and a list
    of that statistic for the data subsets.
    """
    metrics = {
        "precision": get_precision,
        "recall": get_recall,
        "TNR": get_tnr,
        "balanced_acc": get_balanced_acc,
    }
    stats_dict = {}
    for name, metric in metrics.items():
        stats_dict[name] = [metric(res) for res in results]
    return stats_dict
def get_precision(res_dict):
    """Return tp / (tp + fp) from a dict of confusion-matrix counts."""
    predicted_positive = res_dict["tp"] + res_dict["fp"]
    return res_dict["tp"] / predicted_positive
def get_recall(res_dict):
    """Return tp / (tp + fn) from a dict of confusion-matrix counts."""
    actual_positive = res_dict["tp"] + res_dict["fn"]
    return res_dict["tp"] / actual_positive
def get_tnr(res_dict):
    """Return tn / (tn + fp) from a dict of confusion-matrix counts."""
    actual_negative = res_dict["tn"] + res_dict["fp"]
    return res_dict["tn"] / actual_negative
def get_balanced_acc(res_dict):
    """Return the mean of get_recall and get_tnr for the given counts."""
    recall = get_recall(res_dict)
    specificity = get_tnr(res_dict)
    return (recall + specificity) / 2
#!/usr/bin/env python
"""
TrackerCoordinator class to coordinate cell tracking and bud assignment.
"""
import pickle
import typing as t
from collections import Counter
from os.path import dirname, join
from pathlib import Path
import numpy as np
from scipy.optimize import linear_sum_assignment
from sklearn.ensemble import RandomForestClassifier
from sklearn.svm import SVC
from agora.track_abc import FeatureCalculator
models_path = join(dirname(__file__), "../models")
class CellTracker(FeatureCalculator):
"""
Class used to manage cell tracking. You can call it using an existing model or
use the inherited CellTrainer to get a new one.
Initialization parameters:
:model: sklearn.ensemble.RandomForestClassifier object
:trapfeats: Features to manually calculate within a trap
:extrafeats: Additional features to calculate
:model: Model to use, if provided ignores all other args but threshs
:bak_model: Backup model to use when prediction is unsure
:nstepsback: int Number of timepoints to go back
:thresh: float Cut-off value to assume a cell is not new
:low_thresh: Lower thresh for model switching
:high_thresh: Higher thresh for model switching.
Probabilities between these two thresholds summon the
backup model.
:aweights: area weight for barycentre calculations
# Feature order in array features
1. basic features
2. trap features (within a trap)
3. extra features (process between imgs)
"""
def __init__(
    self,
    feats2use=None,
    trapfeats=None,
    extrafeats=None,
    model=None,
    bak_model=None,
    thresh=None,
    low_thresh=None,
    high_thresh=None,
    nstepsback=None,
    aweights=None,
    red_fun=None,
    max_distance=None,
    **kwargs,
):
    """
    Set up the tracker, loading models and deriving the features to use.

    When feats2use is None (the non-training case) the main and backup
    models are loaded - from the bundled default files if not given -
    and the features are taken from the union of both models'
    all_ifeats. Thresholds, nstepsback and red_fun fall back to 0.5,
    0.4/0.6, 3 and np.nanmax when left as None.

    model and bak_model may be given as a str or Path, in which case
    they are unpickled from that file.
    """
    if trapfeats is None:
        trapfeats = ()
    if extrafeats is None:
        extrafeats = ()
    # NOTE: pickle.load can execute arbitrary code; only pass paths to
    # model files from a trusted source.
    # isinstance is needed: Path(...) creates a PosixPath/WindowsPath,
    # whose exact type is never Path itself.
    if isinstance(model, (str, Path)):
        with open(Path(model), "rb") as f:
            model = pickle.load(f)
    if isinstance(bak_model, (str, Path)):
        with open(Path(bak_model), "rb") as f:
            bak_model = pickle.load(f)
    # keep the caller's value instead of only ever storing None
    self.aweights = aweights
    if feats2use is None:  # Ignore this block when training
        if model is None:
            model = self.load_model(models_path, "ct_XGBC_20220703_11.pkl")
        if bak_model is None:
            bak_model = self.load_model(
                models_path, "ct_XGBC_20220703_10.pkl"
            )
        self.model = model
        self.bak_model = bak_model
        main_feats = model.all_ifeats
        bak_feats = bak_model.all_ifeats
        feats2use, trapfeats, extrafeats = [
            tuple(sorted(set(main).union(bak)))
            for main, bak in zip(main_feats, bak_feats)
        ]
    # Training AND non-training part
    super().__init__(
        feats2use, trapfeats=trapfeats, extrafeats=extrafeats, **kwargs
    )
    self.extrafeats = tuple(extrafeats)
    self.all_ofeats = self.outfeats + trapfeats + extrafeats
    self.noutfeats = len(self.all_ofeats)
    if hasattr(self, "bak_model"):  # Back to non-training only
        self.mainof_ids = [
            self.all_ofeats.index(f) for f in self.model.all_ofeats
        ]
        self.bakof_ids = [
            self.all_ofeats.index(f) for f in self.bak_model.all_ofeats
        ]
    if nstepsback is None:
        nstepsback = 3
    self.nstepsback = nstepsback
    if thresh is None:
        thresh = 0.5
    self.thresh = thresh
    if low_thresh is None:
        low_thresh = 0.4
    if high_thresh is None:
        high_thresh = 0.6
    self.low_thresh, self.high_thresh = low_thresh, high_thresh
    if red_fun is None:
        red_fun = np.nanmax
    self.red_fun = red_fun
    # always define the attribute; None means no distance filtering
    self.max_distance = max_distance
def get_feats2use(self):
    """
    Return the features to use for the main and the backup model.

    The number of features of each loaded model is obtained with
    get_nfeats_from_model and translated with switch_case_nfeats.
    """
    feats_per_model = []
    for loaded_model in (self.model, self.bak_model):
        n_feats = get_nfeats_from_model(loaded_model)
        feats_per_model.append(n_feats)
    # max_nfeats = max((nfeats, nfeats_bak))
    main_n, backup_n = feats_per_model
    return (switch_case_nfeats(main_n), switch_case_nfeats(backup_n))
def calc_feat_ndarray(self, prev_feats, new_feats):
    """
    Calculate feature ndarray using two ndarrays of features.
    ---
    input
    :prev_feats: ndarray (ncells, nfeats) of timepoint 1
    :new_feats: ndarray (ncells, nfeats) of timepoint 2

    returns
    :n3darray: ndarray (ncells_prev, ncells_new, nfeats) containing a
    cell-wise subtraction of the features in the input ndarrays,
    after processing by self.calc_dtfeats. An empty array is returned
    if either input has no non-zero entry.
    """
    if not new_feats.any() or not prev_feats.any():
        return np.array([])
    n_feats = self.ntfeats
    differences = np.empty((len(prev_feats), len(new_feats), n_feats))
    for k in range(n_feats):
        # previous cells vary along rows, new cells along columns
        differences[:, :, k] = (
            prev_feats[:, k, np.newaxis] - new_feats[np.newaxis, :, k]
        )
    return self.calc_dtfeats(differences)
def calc_dtfeats(self, n3darray):
    """
    Calculates features obtained between timepoints, such as distance
    for every pair of cells from t1 to t2.
    ---
    input
    :n3darray: ndarray (ncells_prev, ncells_new, ntfeats) containing a
    cell-wise subtraction of the features in the input ndarrays.

    returns
    :newarray: 3d array taking the features specified in self.outfeats
    and self.trapfeats and adding dtfeats
    """
    n_out = len(self.outfeats)
    n_trap = len(self.trapfeats)
    newarray = np.empty(n3darray.shape[:-1] + (self.noutfeats,))
    # copy the output features, then the trap features
    newarray[..., :n_out] = n3darray[..., :n_out]
    newarray[..., n_out : n_out + n_trap] = n3darray[
        ..., len(self.out_merged) :
    ]
    distance_slots = [
        slot for slot, feat in enumerate(self.all_ofeats) if feat == "distance"
    ]
    for slot in distance_slots:
        dx = n3darray[..., self.xind]
        dy = n3darray[..., self.yind]
        newarray[..., slot] = np.sqrt(dx**2 + dy**2)
    return newarray
def assign_lbls(
    self,
    prob_backtrace: np.ndarray,
    prev_lbls: t.List[t.List[int]],
    red_fun=None,
):
    """Assign labels using a prediction matrix of n x m x l, where n is
    the number of cells in the previous images, m the number of steps
    back considered and l the number of cells in the new image. Cells
    that are not matched receive the label zero.
    ---
    input

    :prob_backtrace: Probability n x m x l array
    :prev_lbls: Labels corresponding to the first axis of prob_backtrace.
    :red_fun: Function used to collapse the previous timepoints into one.
    If none provided, self.red_fun is used.

    returns

    :new_lbls: ndarray of newly assigned labels obtained, new cells as
    zero.
    """
    reducer = self.red_fun if red_fun is None else red_fun
    new_lbls = np.zeros(prob_backtrace.shape[2], dtype=int)
    pred_matrix = reducer(prob_backtrace, axis=1)
    if not pred_matrix.any():
        return new_lbls
    # assign available hits: maximise the summed probability
    matched_rows, matched_cols = linear_sum_assignment(-pred_matrix)
    for row, col in zip(matched_rows, matched_cols):
        if pred_matrix[row, col] > self.thresh:
            new_lbls[col] = prev_lbls[row]
    return new_lbls
def predict_proba_from_ndarray(
    self,
    array_3d: np.ndarray,
    model: str = None,
    boolean: bool = False,
    max_distance: float = None,
):
    """
    Predict, for every pair of cells, whether they are the same cell.

    input

    :array_3d: (ncells_tp1, ncells_tp2, out_feats) ndarray
    :model: str, {'model', 'bak_model'} name of the attribute holding
    the single model to use instead of the ensemble
    :boolean: bool, if False returns probability, if True returns prediction
    :max_distance: float Maximum distance to be considered. If None it
    uses the instance's value, if zero it skips checking distances.

    requires
    :self.model:
    :self.mainof_ids: list of indices corresponding to the main model's features
    :self.bakof_ids: list of indices corresponding to the backup model's features

    returns

    (ncells_tp1, ncells_tp2) ndarray with probabilities or prediction
    of cell identities depending on "boolean" arg. Pairs further apart
    than max_distance are left at zero. An empty array is returned for
    empty input.
    """
    if array_3d.size == 0:
        return np.array([])
    if model is None:
        model2use = self.model
        bak_model2use = self.bak_model
        bakof_ids = self.bakof_ids
        mainof_ids = self.mainof_ids
    else:
        # use the model the caller named, for both roles
        model2use = getattr(self, model)
        bak_model2use = model2use
        mainof_ids = [
            self.all_ofeats.index(f) for f in model2use.all_ofeats
        ]
        bakof_ids = mainof_ids
    fun2use = "predict" if boolean else "predict_proba"
    predict_fun = getattr(model2use, fun2use)
    bak_pred_fun = getattr(bak_model2use, fun2use)

    def positive_class(output):
        # predict gives one value per sample; predict_proba gives one
        # column per class, of which the second is wanted
        output = np.asarray(output)
        return output if output.ndim == 1 else output[:, 1]

    if max_distance is None:
        # the attribute may be missing on instances created before it
        # was always set by __init__
        max_distance = getattr(self, "max_distance", None)
    orig_shape = array_3d.shape[:2]
    # Ignore cells that are too far away to possibly be the same
    cells_near = np.ones(orig_shape, dtype=bool)
    if max_distance and set(self.all_ofeats).intersection(
        ("distance", "centroid-0", "centroid-1")
    ):
        if "distance" in self.all_ofeats:
            cells_near = (
                array_3d[..., self.all_ofeats.index("distance")]
                < max_distance
            )
        else:  # Calculate euclidean distance
            d0 = array_3d[..., self.all_ofeats.index("centroid-0")]
            d1 = array_3d[..., self.all_ofeats.index("centroid-1")]
            cells_near = np.sqrt(d0**2 + d1**2) < max_distance
    pred_matrix = np.zeros(orig_shape)
    if cells_near.any():
        prob = positive_class(
            predict_fun(array_3d[cells_near][:, mainof_ids])
        )
        # pairs the main model is unsure about go to the backup model
        uncertain_dfeats = (self.low_thresh < prob) & (
            prob < self.high_thresh
        )
        if uncertain_dfeats.any():
            bak_prob = positive_class(
                bak_pred_fun(
                    array_3d[cells_near][uncertain_dfeats][:, bakof_ids]
                )
            )
            probs_compared = np.stack((prob[uncertain_dfeats], bak_prob))
            # keep whichever answer is further from 0.5
            most_confident_proba = abs((probs_compared - 0.5)).argmax(
                axis=0
            )
            prob[uncertain_dfeats] = probs_compared[
                most_confident_proba, range(most_confident_proba.shape[0])
            ]
        pred_matrix[cells_near] = prob
    return pred_matrix
def get_new_lbls(
    self,
    new_img,
    prev_lbls,
    prev_feats,
    max_lbl,
    new_feats=None,
    pixel_size=None,
    **kwargs,
):
    """
    Core function to calculate the new cell labels.

    ----

    input

    :new_img: ndarray (len, width, ncells) containing the cell outlines
    :max_lbl: int indicating the last assigned cell label
    :prev_feats: list of ndarrays of size (ncells x nfeatures)
    containing the features of previous timepoints
    :prev_lbls: list of list of ints corresponding to the cell labels in
    the previous timepoints
    :new_feats: (optional) Directly give a feature ndarray. It ignores
    new_img if given.
    :kwargs: Additional keyword values passed to self.predict_proba_from_ndarray

    returns

    :new_lbls: list of labels assigned to new timepoint
    :new_feats: the features of the new timepoint
    :new_max: updated max cell label assigned

    If the new features have no non-zero entry, ([], [], max_lbl) is
    returned.
    """
    if new_feats is None:
        new_feats = self.calc_feats_from_mask(new_img)
    if not new_feats.any():
        return ([], [], max_lbl)
    n_new = len(new_feats)
    if not np.any([len(prev_feat) for prev_feat in prev_feats]):
        # nothing to match against: every cell gets a fresh label
        fresh_lbls = [*range(max_lbl + 1, max_lbl + n_new + 1)]
        return (fresh_lbls, new_feats, max_lbl + n_new)
    # distinct previous labels, in order of first appearance
    counts = Counter(lbl for lbl_set in prev_lbls for lbl in lbl_set)
    lbls_order = list(counts.keys())
    probs = np.full((len(lbls_order), self.nstepsback, n_new), np.nan)
    for step, (lblset, prev_feat) in enumerate(zip(prev_lbls, prev_feats)):
        if not len(prev_feat):
            continue
        feats_3darray = self.calc_feat_ndarray(prev_feat, new_feats)
        pred_matrix = self.predict_proba_from_ndarray(
            feats_3darray,
            **kwargs,
        )
        for j, lbl in enumerate(lblset):
            probs[lbls_order.index(lbl), step, :] = pred_matrix[j, :]
    new_lbls = self.assign_lbls(probs, lbls_order)
    # unmatched cells (label zero) receive the next unused labels
    is_new = new_lbls == 0
    new_max = max_lbl + sum(is_new)
    new_lbls[is_new] = [*range(max_lbl + 1, new_max + 1)]
    # ensure that label output is consistently a list
    return (new_lbls.tolist(), new_feats, new_max)
def probabilities_from_impair(
    self,
    image_t1: np.ndarray,
    image_t2: np.ndarray,
    kwargs_feat_calc: t.Optional[dict] = None,
    **kwargs,
):
    """
    Convenience function to test tracking between two time-points

    :image_t1: np.ndarray containing mask of first time-point
    :image_t2: np.ndarray containing mask of second time-point
    :kwargs_feat_calc: are passed to self.calc_feats_from_mask calls;
    None means no extra arguments
    :kwargs: are passed to self.predict_proba_from_ndarray

    Returns the matrix from self.predict_proba_from_ndarray, or an
    empty array if either set of features has no non-zero entry.
    """
    # None instead of a shared mutable {} default
    if kwargs_feat_calc is None:
        kwargs_feat_calc = {}
    feats_t1 = self.calc_feats_from_mask(image_t1, **kwargs_feat_calc)
    feats_t2 = self.calc_feats_from_mask(image_t2, **kwargs_feat_calc)
    probability_matrix = np.array([])
    if feats_t1.any() and feats_t2.any():
        feats_3darray = self.calc_feat_ndarray(feats_t1, feats_t2)
        probability_matrix = self.predict_proba_from_ndarray(
            feats_3darray, **kwargs
        )
    return probability_matrix
# Step CellTracker
def run_tp(
    self,
    masks: np.ndarray,
    state: t.Optional[t.Dict[str, t.Union[int, list]]] = None,
    **kwargs,
) -> t.Tuple[t.List[int], t.Dict[str, t.Union[int, list]]]:
    """Assign labels to new masks using a state dictionary.

    Parameters
    ----------
    masks : np.ndarray
        Cell masks to label.
    state : t.Dict[str, t.Union[int, list]]
        Dictionary containing maximum cell label ("max_lbl"), previous cell
        labels ("cell_lbls") and previous features ("prev_feats"). None or
        missing keys start from an empty history.
    kwargs : keyword arguments
        Keyword arguments passed to self.get_new_lbls

    Returns
    -------
    t.Tuple[t.List[int], t.Dict[str, t.Union[int, list]]]
        New labels and new state dictionary. The state passed in is not
        modified; a new dictionary with extended histories is returned.

    Examples
    --------
    FIXME: Add example beyond the trivial one.

    import numpy as np
    from baby.tracker.core import CellTracker
    from tqdm import tqdm

    # overlapping outlines are of shape (t,z,x,y)
    masks = np.zeros((5, 3, 20, 20), dtype=bool)
    masks[0, 0, 2:6, 2:6] = True
    masks[1:, 0, 13:14, 13:14] = True
    masks[:, 1, 8:12, 8:12] = True
    masks[:, 2, 14:18, 14:18] = True

    ct = CellTracker(pixel_size=0.185)
    labels = []
    state = None
    for masks_tp in tqdm(masks):
        new_labels, state = ct.run_tp(masks_tp, state=state)
        labels.append(new_labels)

    # should result in state['cell_lbls']
    # [[1, 2, 3], [4, 2, 3], [4, 2, 3], [4, 2, 3], [4, 2, 3]]
    """
    if state is None:
        state = {}

    max_lbl = state.get("max_lbl", 0)
    cell_lbls = state.get("cell_lbls", [])
    prev_feats = state.get("prev_feats", [])

    # Get features for cells at this time point
    feats = self.calc_feats_from_mask(masks)

    # Only the most recent nstepsback time points take part in the matching
    lastn_lbls = cell_lbls[-self.nstepsback :]
    lastn_feats = prev_feats[-self.nstepsback :]

    new_lbls, _, max_lbl = self.get_new_lbls(
        masks, lastn_lbls, lastn_feats, max_lbl, feats, **kwargs
    )

    # Build new lists rather than appending so the caller's state is untouched
    state = {
        "max_lbl": max_lbl,
        "cell_lbls": cell_lbls + [new_lbls],
        "prev_feats": prev_feats + [feats],
    }

    return (new_lbls, state)
# Helper functions
def switch_case_nfeats(nfeats):
    """
    Convenience temporary function that maps the number of features a model
    was trained on to the feature names used for tracking.

    input
    :nfeats: int, number of features; one of 4, 5, 6, 7, 8, 10, 11 or 15

    returns
    list with three tuples of feature names: [main feats, barycentre-based
    feats, extra feats]. "centroid" is a single name in the tuples even
    though the key counts it as two features.

    raises
    AssertionError if nfeats is not one of the supported values.
    """
    shape = ("area", "minor_axis_length", "major_axis_length", "bbox_area")
    shape_perimeter = shape + ("perimeter",)
    # Shape features without the computationally-expensive ones
    shape_cheap = shape + ("equivalent_diameter", "orientation")
    shape_all = shape + (
        "eccentricity",
        "equivalent_diameter",
        "solidity",
        "extent",
        "orientation",
        "perimeter",
    )
    centroid = ("centroid",)
    bary = ("baryangle", "barydist")
    distance = ("distance",)

    main_feats = {
        4: [shape, (), ()],
        5: [shape_perimeter, (), ()],
        6: [shape_perimeter, (), distance],
        7: [shape_perimeter, bary, ()],
        8: [shape_cheap, bary, ()],  # Minus centroid
        10: [centroid + shape_cheap, bary, ()],  # Minus distance
        # Minus computationally-expensive features
        11: [centroid + shape_cheap, bary, distance],
        15: [centroid + shape_all, bary, distance],  # All features
    }

    if nfeats not in main_feats:
        # Raised explicitly (instead of an assert statement) so the check
        # still runs under `python -O`; the exception type is kept so
        # existing callers see the same error.
        raise AssertionError("invalid nfeats")

    return main_feats[nfeats]
def get_nfeats_from_model(model) -> int:
    """
    Return the number of input features a fitted classifier expects.

    :model: fitted SVC or RandomForestClassifier

    Raises TypeError for any other model type (previously this surfaced as
    an unbound-local error on the return statement).
    """
    if isinstance(model, SVC):
        return model.support_vectors_.shape[-1]

    if isinstance(model, RandomForestClassifier):
        # `n_features_` was deprecated in favour of `n_features_in_` in newer
        # scikit-learn releases; prefer the new attribute and fall back to
        # the old one for models fitted with older versions.
        n_in = getattr(model, "n_features_in_", None)
        return model.n_features_ if n_in is None else n_in

    raise TypeError(
        f"Cannot determine number of features for model of type "
        f"{type(model).__name__}"
    )
# If you publish results that make use of this software or the Birth Annotator
# for Budding Yeast algorithm, please cite:
# Julian M J Pietsch, Alán Muñoz, Diane Adjavon, Ivan B N Clark, Peter S
# Swain, 2021, Birth Annotator for Budding Yeast (in preparation).
#
#
# The MIT License (MIT)
#
# Copyright (c) Julian Pietsch, Alán Muñoz and Diane Adjavon 2021
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to
# deal in the Software without restriction, including without limitation the
# rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
# sell copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in
# all copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS
# IN THE SOFTWARE.
import typing as t
import numpy as np
# Calculate barycentre
def calc_barycentre(centres, weights=None, **kwargs):
    """
    Weighted average of the cell centres along the first axis.

    :centres: ndarray containing the (x,y) centres of each cell
    :weights: (optional) weights to consider for each cell; when omitted
        every entry is weighted equally
    :kwargs: accepted and ignored
    """
    uniform = np.ones_like(centres)
    chosen = uniform if weights is None else weights
    return np.average(centres, axis=0, weights=chosen)
# Calculate distance to center
def calc_barydists(centres, bary, **kwargs):
    """
    Euclidean distance from every centre to the barycentre.

    :centres: ndarray with one centre per row
    :bary: barycentre, broadcast against each row of centres
    :kwargs: accepted and ignored
    """
    offsets = centres - bary
    squared = np.sum(offsets**2, axis=1)
    return np.sqrt(squared)
# Calculate angle to center
def calc_baryangles(centres, bary, areas=None, **kwargs):
    """
    Calculate the angle of each cell centre around the barycentre.

    :centres: ndarray with one centre per row (two columns)
    :bary: barycentre, broadcast against each row of centres
    :areas: (optional) one area per cell; when given, the cell with the
        largest area is used as angle 0 and all angles are shifted by it
    :kwargs: accepted and ignored

    returns

    1-D ndarray of angles in radians, computed as
    arctan2(first column, second column) of the centre-to-barycentre offsets.
    """
    vec2bary = np.asarray(centres) - bary
    # Vectorised equivalent of applying arctan2(*row) to every row
    angles = np.arctan2(vec2bary[:, 0], vec2bary[:, 1])
    if areas is not None:
        anchor_cell = np.argmax(areas)
        angles -= angles[anchor_cell]

    return angles
def pick_baryfun(key):
    """
    Return the barycentre-feature function registered under key.

    :key: "barydist" or "baryangle"; any other key raises KeyError
    """
    dispatch = dict(barydist=calc_barydists, baryangle=calc_baryangles)
    return dispatch[key]
## Tracking benchmark utils
def lol_to_adj(cell_ids: t.List[t.List[int]]):
    """
    Convert a list of lists with cell ids (one inner list per time point)
    into a matrix representing a graph.

    Every cell occurrence gets one row/column, numbered in order of
    appearance. Entry (i, j) is 1 when occurrence i has the same id as
    occurrence j of the immediately preceding time point.

    Note that information is lost in the process, and a matrix can't be
    turned back into a list of lists by itself.

    input

    :cell_ids: list of lists with cell ids

    returns

    :adj_matrix: (n, n) ndarray where n is the total number of cell
        occurrences
    """
    n = sum(len(ids) for ids in cell_ids)
    adj_mat = np.zeros((n, n))

    prev = None
    prev_start = 0  # global index of the first occurrence in `prev`
    for ids in cell_ids:
        if prev is not None:
            cur_start = prev_start + len(prev)
            # First position of each id in the previous time point,
            # matching what list.index would return
            first_pos = {}
            for j, el in enumerate(prev):
                first_pos.setdefault(el, j)
            for i, el in enumerate(ids):
                j = first_pos.get(el)
                if j is not None:
                    adj_mat[cur_start + i, prev_start + j] = True
            prev_start = cur_start
        # Advance the reference so links always point one time point back
        prev = ids

    return adj_mat
def compare_pred_truth_lols(prediction, truth):
    """
    Count the entries in which the graph matrices of two labellings differ.

    input

    :prediction: list of lists with predicted cell ids
    :truth: list of lists with real cell ids

    returns

    number of differences between the matrices built by lol_to_adj
    """
    delta = lol_to_adj(prediction) - lol_to_adj(truth)
    return int(np.count_nonzero(delta))
"""Additional tools to fetch and handle datasets programatically.
"""
import csv
import io
import operator
from pathlib import Path, PosixPath
import re
import typing as t
from collections import Counter
from datetime import datetime
import re
import csv
from pathlib import Path
import numpy as np
from tqdm import tqdm
from logfile_parser import Parser
from omero.gateway import BlitzGateway, TagAnnotationWrapper
from omero.gateway import BlitzGateway, TagAnnotationWrapper, _DatasetWrapper
from tqdm import tqdm
class OmeroExplorer:
def __init__(self, host, user, password, min_date=(2020, 6, 1)):
def __init__(
self,
host: str,
user: str,
password: str,
min_date: t.Tuple[int, int, int] = (2020, 6, 1),
):
self.conn = BlitzGateway(user, password, host=host)
self.conn.connect()
......@@ -51,7 +56,7 @@ class OmeroExplorer:
def acq(self):
return {k: parse_annot(v, "acq") for k, v in self.raw_acq.items()}
def load(self, min_id=18000, min_date=None):
def load(self, min_id=0, min_date=None):
"""
:min_id: int
:min_date: tuple
......@@ -64,23 +69,38 @@ class OmeroExplorer:
if min_date:
if len(min_date) < 3:
min_date = min_date + tuple([1 for i in range(3 - len(min_date))])
min_date = min_date + tuple(
[1 for _ in range(3 - len(min_date))]
)
min_date = datetime(*min_date)
# sort by dates
dates = [d.getDate() for d in self._dsets_bak]
self._dsets_bak[:] = [a for b, a in sorted(zip(dates, self._dsets_bak))]
self._dsets_bak[:] = [
a
for _, a in sorted(
zip(dates, self._dsets_bak), key=lambda x: x[0]
)
]
self._dsets_bak = [d for d in self._dsets_bak if d.getDate() >= min_date]
self._dsets_bak = [
d for d in self._dsets_bak if d.getDate() >= min_date
]
self.dsets = self._dsets_bak
self.n_dsets
def image_ids(self):
return {
dset.getId(): [im.getId() for im in dset.listChildren()]
for dset in self.dsets
}
def dset(self, n):
try:
return [x for x in self.dsets if x.id == n][0]
except:
return
except Exception as e:
print(f"Could not fetch all data xsets: {e}")
def channels(self, setkey, present=True):
"""
......@@ -105,7 +125,9 @@ class OmeroExplorer:
self,
attr,
{
v.id: parse_annot(getattr(self, "raw_" + attr)[v.id], attr)
v.id: parse_annot(
getattr(self, "raw_" + attr)[v.id], attr
)
for v in self.dsets
},
)
......@@ -113,7 +135,9 @@ class OmeroExplorer:
for attr in ["acq", "log", "raw_acq", "raw_log"]:
setattr(
self, attr, {i.id: getattr(self, attr)[i.id] for i in self.dsets}
self,
attr,
{i.id: getattr(self, attr)[i.id] for i in self.dsets},
)
@property
......@@ -141,7 +165,9 @@ class OmeroExplorer:
if type(tags) is not list:
tags = [str(tags)]
self.dsets = [v for v in self.dsets if present == self.has_tags(v, tags)]
self.dsets = [
v for v in self.dsets if present == self.has_tags(v, tags)
]
self.n_dsets
@property
......@@ -158,7 +184,9 @@ class OmeroExplorer:
return self._tags
def get_timepoints(self):
self.image_wrappers = {d.id: list(d.listChildren())[0] for d in self.dsets}
self.image_wrappers = {
d.id: list(d.listChildren())[0] for d in self.dsets
}
return {k: i.getSizeT() for k, i in self.image_wrappers.items()}
......@@ -167,11 +195,14 @@ class OmeroExplorer:
op = operator.gt if op == "greater" else operator.le
self._timepoints = self.get_timepoints()
self.dsets = [v for v in tqdm(self.dsets) if op(self._timepoints[v.id], n)]
self.dsets = [
v for v in tqdm(self.dsets) if op(self._timepoints[v.id], n)
]
def microscope(self, microscope):
self.microscopes = {
dset.id: self.get_microscope(self.log[dset.id]) for dset in self.dsets
dset.id: self.get_microscope(self.log[dset.id])
for dset in self.dsets
}
self.n_dsets
......@@ -189,32 +220,21 @@ class OmeroExplorer:
def reset_backup(self, name):
self.dsets = self.backups[name]
def cExperiment(self, present=True):
self.dsets = [
v
for v in self.dsets
if present
* sum(
[
"cExperiment" in x.getFileName()
for x in v.listAnnotations()
if hasattr(x, "getFileName")
]
)
]
self.n_dsets
@staticmethod
def is_complete(logfile):
def is_complete(logfile: str):
return logfile.endswith("Experiment completed\r\r\n")
@staticmethod
def contains_regex(logfile):
pass
# return re.
def count_errors(logfile: str):
return re.findall("ERROR CAUGHT", logfile)
def tiler_cells(self, present=True):
self.__dsets = [v for v in self.dsets if present == tiler_cells_load(v)]
@staticmethod
def count_drift_alert(logfile: str):
return re.findall("High drift alert!", logfile)
@staticmethod
def is_interrupted(logfile: str):
return "Experiment stopped by user" in logfile
@property
def n_dsets(self):
......@@ -236,25 +256,21 @@ class OmeroExplorer:
def ids(self):
return [d.getId() for d in self.dsets]
# @property
# def acqs(self):
# if not hasattr(self, "_acqs") or len(self.__dict__.get("_acqs", [])) != len(
# self.dsets
# ):
# self._acqs = [get_annot(get_annotsets(d), "acq") for d in self.dsets]
# return self._acqs
def get_ph_params(self):
t = [
{
ch: [exp, v]
for ch, exp, v in zip(j["channel"], j["exposure"], j["voltage"])
for ch, exp, v in zip(
j["channel"], j["exposure"], j["voltage"]
)
if ch in {"GFPFast", "pHluorin405"}
}
for j in [i["channels"] for i in self.acqs]
]
ph_param_pairs = [(tuple(x.values())) for x in t if np.all(list(x.values()))]
ph_param_pairs = [
(tuple(x.values())) for x in t if np.all(list(x.values()))
]
return Counter([str(x) for x in ph_param_pairs])
......@@ -263,16 +279,41 @@ class OmeroExplorer:
# and group them for cleaning
pass
def group_by_date(tol=1):
dates = [x.getDate() for x in self.dsets]
distances = np.array(
[[abs(convert_to_hours(a - b)) for a in dates] for b in dates]
)
return explore_booldiag(distances > tol, 0, [])
def return_completed(self, kind="complete"):
return {
k: getattr(self, f"is_{kind}")(v.get("log", ""))
for k, v in self.cache.items()
}
def dset_count(
self,
dset: t.Union[int, _DatasetWrapper],
kind: str = "errors",
norm: bool = True,
):
if isinstance(dset, int):
dset = self.conn.getObject("Dataset", dset)
total_images_tps = sum([im.getSizeT() for im in dset.listChildren()])
return len(
getattr(self, f"count_{kind}")(
self.cache[dset.getId()].get("log", ""), norm=norm
)
) / (norm * total_images_tps)
def count_in_log(self, kind="errors", norm: bool = True):
return {
k: self.dset_count(k, kind=kind, norm=norm)
for k, v in self.cache.items()
}
@property
def complete(self):
self.completed = {k: self.is_complete(v) for k, v in self.raw_log_end.items()}
self.completed = {
k: self.is_complete(v.get("log", ""))
for k, v in self.cache.items()
}
self.dsets = [dset for dset in self.dsets if self.completed[dset.id]]
return self.n_dsets
......@@ -298,25 +339,6 @@ class OmeroExplorer:
return True
def explore_booldiag(bool_field, current_position, cluster_start_end):
# Recursively find the square clusters over the diagonal. Allows for duplicates
# returns a list of tuples with the start, end of clusters
if current_position < len(bool_field) - 1:
elements = np.where(bool_field[current_position])
if len(elements[0]) > 1:
start = elements[0][0]
end = elements[0][-1]
else:
start = elements[0][0]
end = elements[0][0]
cluster_start_end.append((start, end))
return explore_square(bool_field, end + 1, cluster_start_end)
else:
return cluster_start_end
_
def convert_to_hours(delta):
total_seconds = delta.total_seconds()
hours = int(total_seconds // 3600)
......@@ -324,49 +346,41 @@ def convert_to_hours(delta):
class Argo(OmeroExplorer):
def __init__(self,*args, **kwargs):
super().__init__(*args,**kwargs)
def get_creds():
return (
"upload",
"***REMOVED***", # OMERO Password
)
def list_files(dset):
return {x for x in dset.listAnnotations() if hasattr(x, "getFileName")}
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
def annot_from_dset(dset, kind):
v = [x for x in dset.listAnnotations() if hasattr(x, "getFileName")]
infname = kind if kind is "log" else kind.title()
infname = kind if kind == "log" else kind.title()
try:
acqfile = [x for x in v if x.getFileName().endswith(infname + ".txt")][0]
acqfile = [x for x in v if x.getFileName().endswith(infname + ".txt")][
0
]
decoded = list(acqfile.getFileInChunks())[0].decode("utf-8")
acq = parse_annot(decoded, kind)
except:
except Exception as e:
print(f"Conversion from acquisition file failed: {e}")
return {}
return acq
def check_channels(acq, channels, _all=True):
I = set(acq["channels"]["channel"]).intersection(channels)
shared_channels = set(acq["channels"]["channel"]).intersection(channels)
condition = False
if _all:
if len(I) == len(channels):
if len(shared_channels) == len(channels):
condition = True
else:
if len(I):
if len(shared_channels):
condition = True
return condition
def get_chs(exptype):
# TODO Documentation
exptypes = {
"dual_ph": ("GFP", "pHluorin405", "mCherry"),
"ph": ("GFP", "pHluorin405"),
......@@ -377,7 +391,8 @@ def get_chs(exptype):
def load_annot_from_cache(exp_id, cache_dir="cache/"):
if type(cache_dir) is not PosixPath:
# TODO Documentation
if type(cache_dir) is not Path:
cache_dir = Path(cache_dir)
annot_sets = {}
......@@ -402,19 +417,11 @@ def parse_annot(str_io, fmt):
return parser.parse(io.StringIO(str_io))
def get_log_date(annot_sets):
log = get_annot(annot_sets, "log")
return log.get("date", None)
def get_log_microscope(annot_sets):
log = get_annot(annot_sets, "log")
return log.get("microscope", None)
def get_annotsets(dset):
annot_files = [
annot.getFile() for annot in dset.listAnnotations() if hasattr(annot, "getFile")
annot.getFile()
for annot in dset.listAnnotations()
if hasattr(annot, "getFile")
]
annot_sets = {
ftype[:-4].lower(): annot
......@@ -429,21 +436,18 @@ def get_annotsets(dset):
return annot_sets
# def has_tags(d, tags):
# if set(tags).intersection(annot_from_dset(d, "log").get("omero_tags", [])):
# return True
def load_acq(dset):
# TODO Documentation
try:
acq = annot_from_dset(dset, kind="acq")
return acq
except:
print("dset", dset.getId(), " failed acq loading")
except Exception as e:
print(f"dset{dset.getId()}failed acq loading: {e}")
return False
def has_channels(dset, exptype):
# TODO Documentation
acq = load_acq(dset)
if acq:
return check_channels(acq, get_chs(exptype))
......@@ -451,77 +455,12 @@ def has_channels(dset, exptype):
return
def get_id_from_name(exp_name, conn=None):
if conn is None:
conn = BlitzGateway(*get_creds(), host="islay.bio.ed.ac.uk", port=4064)
if not conn.isConnected():
conn.connect()
cand_dsets = [
d
for d in conn.getObjects("Dataset") # , opts={'offset': 10600,
# 'limit':500})
if exp_name in d.name
] # increase the offset for better speed
# return cand_dsets
if len(cand_dsets) > 1:
# Get date and try to find it using date and microscope name and date
# found = []
# for cand in cand_dsets:
# annot_sets = get_annotsets(cand)
# date = get_log_date(annot_sets)
# microscope = get_log_microscope(annot_sets)
# if date==date_name and microscope == microscope_name:
# found.append(cand)
# if True:#len(found)==1:
# return best_cand.id#best_cand = found[0]
if True:
print("Multiple options found. Selecting the one with most children")
max_dset = np.argmax(
[
len(list(conn.getObject("Dataset", c.id).listChildren()))
for c in cand_dsets
]
)
best_cand = cand_dsets[max_dset]
return best_cand.id
elif len(cand_dsets) == 1:
return cand_dsets[0].id
# Custom functions
def compare_dsets_voltages_exp(dsets):
a = {}
for d in dsets:
try:
acq = annot_from_dset(d, kind="acq")["channels"]
a[d.getId()] = {
k: (v, e)
for k, v, e in zip(acq["channel"], acq["voltage"], acq["exposure"])
}
except:
print(d, "didnt work")
return a
def get_logfile(dset):
# TODO Documentation
annot_file = [
annot.getFile()
for annot in dset.listAnnotations()
if hasattr(annot, "getFile") and annot.getFileName().endswith("log.txt")
if hasattr(annot, "getFile")
and annot.getFileName().endswith("log.txt")
][0]
return list(annot_file.getFileInChunks())[-1].decode("utf-8")
# 19920 -> 19300/19310
#
"""
ImageViewer class, used to look at individual or multiple traps over time.
Example of usage:
fpath = "/home/alan/Documents/dev/skeletons/scripts/data/16543_2019_07_16_aggregates_CTP_switch_2_0glu_0_0glu_URA7young_URA8young_URA8old_01/URA8_young018.h5"
tile_id = 9
trange = list(range(0, 10))
ncols = 8
riv = RemoteImageViewer(fpath, server_info)  # server_info: dict with "host", "username" and "password"
riv.plot_labelled_trap(tile_id, channels=[0], trange=trange, ncols=ncols)
"""
import re
import typing as t
import h5py
from abc import ABC
import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns
from PIL import Image
from skimage.morphology import dilation
from agora.io.cells import Cells
from agora.io.metadata import dispatch_metadata_parser
from aliby.tile.tiler import Tiler, TilerParameters
from aliby.utils.plot import stretch_clip
default_colours = {
"Brightfield": "Greys_r",
"GFP": "Greens_r",
"mCherry": "Reds_r",
"cell_label": sns.color_palette("Paired", as_cmap=True),
}
def custom_imshow(a, norm=None, cmap=None, *args, **kwargs):
    """
    Wrapper on plt.imshow function.

    Parameters
    ----------
    a : array-like
        Image to draw.
    norm : optional
        Forwarded to plt.imshow as ``norm`` (it used to be dropped).
    cmap : optional
        Colour map; defaults to "Greys_r".
    *args, **kwargs
        Forwarded to plt.imshow. ``interpolation`` and
        ``interpolation_stage`` default to None and "rgba" but can be
        overridden through kwargs.

    Returns
    -------
    Whatever plt.imshow returns.
    """
    if cmap is None:
        cmap = "Greys_r"

    # setdefault keeps the historical defaults without raising a duplicate
    # keyword error when the caller supplies their own values
    kwargs.setdefault("interpolation", None)
    kwargs.setdefault("interpolation_stage", "rgba")

    return plt.imshow(
        a,
        *args,
        cmap=cmap,
        norm=norm,
        **kwargs,
    )
class BaseImageViewer(ABC):
    """
    Common state and accessors for image viewers.

    On construction the image_id is read from the metadata found next to
    the results file, falling back to the attributes of the h5 file itself.
    Subclasses are expected to set ``self.tiler`` and ``self.cells``, which
    the properties below read.
    """

    def __init__(self, fpath):
        # Function-scope import: pathlib is not among this module's imports
        from pathlib import Path

        # Accept both str and Path; `.parent` below needs a Path
        fpath = Path(fpath)
        self._fpath = fpath
        attrs = dispatch_metadata_parser(fpath.parent)
        self._logfiles_meta = {}

        self.image_id = attrs.get("image_id")
        if self.image_id is None:
            with h5py.File(fpath, "r") as f:
                self.image_id = f.attrs.get("image_id")

        assert self.image_id is not None, "No valid image_id found in metadata"

    @property
    def shape(self):
        # Shape of the image held by the tiler
        return self.tiler.image.shape

    @property
    def ntraps(self):
        # Number of traps as reported by the cells object
        return self.cells.ntraps

    @property
    def max_labels(self):
        # Return the maximum cell label of each entry in cells.labels
        return [max(x) for x in self.cells.labels]

    def labels_at_time(self, tp: int):
        # Return cell labels at a given time-point
        return self.cells.labels_at_time(tp)
class LocalImageViewer(BaseImageViewer):
    """
    Tool to generate figures from local files, either zarr or files organised
    in directories.

    TODO move common functionality from RemoteImageViewer to BaseImageViewer
    """

    def __init__(self, results_path: str, data_path: str):
        super().__init__(results_path)

        from aliby.io.image import ImageDir, ImageZarr

        # ".zar" is still accepted because it was the suffix tested before
        is_zarr = str(data_path).endswith((".zarr", ".zar"))
        self._image_class = ImageZarr if is_zarr else ImageDir

        # Use the class selected above; the previous call went through a
        # name (dispatch_image) that is not defined in this module.
        with self._image_class(data_path) as image:
            self.tiler = Tiler(
                image.data,
                self._meta if hasattr(self, "_meta") else self._logfiles_meta,
                TilerParameters.default(),
            )

        self.cells = Cells.from_source(results_path)
class RemoteImageViewer(BaseImageViewer):
    """
    This ImageViewer combines fetching remote images with tiling and outline display.
    """

    _credentials = ("host", "username", "password")

    def __init__(
        self,
        results_path: str,
        server_info: t.Dict[str, str],
    ):
        super().__init__(results_path)

        from pathlib import Path

        from aliby.io.omero import UnsafeImage as OImage

        if not server_info:
            # Fall back to the credentials stored with the experiment
            # metadata (same parser call the base class makes).
            meta = dispatch_metadata_parser(Path(results_path).parent)
            server_info = {
                k: meta["parameters"]["general"][k] for k in self._credentials
            }
        self._server_info = server_info

        # get_tc reads self._image_class, which was never set on this class.
        # NOTE(review): assumes OImage can be used as a context manager.
        self._image_class = OImage
        self._image_instance = OImage(self.image_id, **self._server_info)
        self.tiler = Tiler.from_h5(self._image_instance, results_path)

        self.cells = Cells.from_source(results_path)

    def random_valid_trap_tp(
        self,
        min_ncells: int = None,
        min_consecutive_tps: int = None,
        label_modulo: int = None,
    ):
        # Call Cells convenience function to pick a random trap and tp
        # containing cells for x cells for y
        # NOTE(review): label_modulo is accepted but not forwarded - confirm
        # whether Cells.random_valid_trap_tp supports it.
        return self.cells.random_valid_trap_tp(
            min_ncells=min_ncells,
            min_consecutive_tps=min_consecutive_tps,
        )

    def get_entire_position(self):
        raise NotImplementedError

    def get_position_timelapse(self):
        raise NotImplementedError

    @property
    def full(self):
        # Lazily-created cache of tiles keyed by (channel, time point)
        if not hasattr(self, "_full"):
            self._full = {}
        return self._full

    def get_tc(self, tp, channel=None, server_info=None):
        server_info = server_info or self._server_info
        channel = channel or self.tiler.ref_channel

        with self._image_class(self.image_id, **server_info) as image:
            self.tiler.image = image.data
            return self.tiler.get_tc(tp, channel)

    def _find_channels(
        self,
        channels: t.Union[int, str, t.Collection, None],
        guess: bool = True,
    ):
        """
        Normalise a channel specification into a list.

        None or an empty collection selects the tiler's reference channel;
        a single int or str is wrapped in a list. Channel names are turned
        into indices of self.tiler.channels: by exact match when guess is
        True, otherwise by treating each name as a regular expression and
        collecting the index of every tiler channel it matches.
        """
        # Test for "unset" explicitly so that channel index 0 is not
        # mistaken for a missing value.
        unset = channels is None or (
            hasattr(channels, "__len__") and len(channels) == 0
        )
        if unset:
            channels = self.tiler.ref_channel

        if isinstance(channels, (int, str)):
            channels = [channels]

        if isinstance(channels[0], str):
            if guess:
                channels = [self.tiler.channels.index(ch) for ch in channels]
            else:
                # Return indices (not re.Match objects) so that the result
                # is usable wherever channel indices are expected.
                channels = [
                    idx
                    for ch in channels
                    for idx, tiler_channel in enumerate(self.tiler.channels)
                    if re.search(ch, tiler_channel)
                ]

        return channels

    def get_pos_timepoints(
        self,
        tps: t.Union[int, t.Collection[int]],
        channels: t.Union[str, t.Collection[str]] = None,
        z: int = None,
        server_info=None,
    ):
        """
        Fetch (and cache in self.full) the tiles of the first requested
        channel for every requested time point.

        An int for tps is expanded to range(tps). Returns a dict mapping
        time point to the cached tiles. server_info is accepted but not
        used here.
        """
        if not isinstance(tps, t.Collection):
            tps = range(tps)

        # TODO add support for multiple channels or refactor
        channels = self._find_channels(channels)

        # NOTE(review): as before, both None and 0 select the tiler's
        # reference z - confirm whether an explicit z=0 should be honoured.
        z = z or self.tiler.ref_z

        ch_tps = [(channels[0], tp) for tp in tps]

        image = self._image_instance
        self.tiler.image = image.data
        for ch, tp in ch_tps:
            if (ch, tp) not in self.full:
                self.full[(ch, tp)] = self.tiler.get_tiles_timepoint(
                    tp, channels=[ch], z=[z]
                )[:, 0, 0, z, ...]
        requested_trap = {tp: self.full[(ch, tp)] for ch, tp in ch_tps}

        return requested_trap

    def get_labelled_trap(
        self,
        tile_id: int,
        tps: t.Union[range, t.Collection[int]],
        channels=None,
        concatenate=True,
        **kwargs,
    ) -> t.Tuple[t.Union[np.ndarray, list], t.Union[np.ndarray, list]]:
        """
        Core method to fetch traps and labels together

        Returns the labelled outlines and the images of one tile over the
        requested time points; both are concatenated along axis 1 when
        concatenate is True, otherwise they are lists with one entry per
        time point.
        """
        imgs = self.get_pos_timepoints(tps, channels=channels, **kwargs)
        imgs_list = [x[tile_id] for x in imgs.values()]
        outlines = [
            self.cells.at_time(tp, kind="edgemask").get(tile_id, [])
            for tp in tps
        ]
        lbls = [self.cells.labels_at_time(tp).get(tile_id, []) for tp in tps]
        # Paint every outline with its label; an empty frame stands in for
        # time points without cells
        lbld_outlines = [
            np.stack([mask * lbl for mask, lbl in zip(maskset, lblset)]).max(
                axis=0
            )
            if len(lblset)
            else np.zeros_like(imgs_list[0]).astype(bool)
            for maskset, lblset in zip(outlines, lbls)
        ]
        if concatenate:
            lbld_outlines = np.concatenate(lbld_outlines, axis=1)
            imgs_list = np.concatenate(imgs_list, axis=1)
        return lbld_outlines, imgs_list

    def get_images(self, tile_id, trange, channels, **kwargs):
        """
        Wrapper to fetch images

        Returns the labelled outlines (from the last channel processed) and
        a dict mapping each channel to its images.
        """
        out = None
        imgs = {}

        for ch in self._find_channels(channels):
            out, imgs[ch] = self.get_labelled_trap(
                tile_id, trange, channels=[ch], **kwargs
            )
        return out, imgs

    def plot_labelled_trap(
        self,
        tile_id: int,
        channels,
        trange: t.Union[range, t.Collection[int]],
        remove_axis: bool = False,
        savefile: str = None,
        skip_outlines: bool = False,
        norm: str = None,
        ncols: int = None,
        local_colours: bool = True,
        img_plot_kwargs: dict = None,
        lbl_plot_kwargs: dict = None,
        **kwargs,
    ):
        """Wrapper to plot time-lapses of individual traps

        Use Cells and Tiler to generate images of cells with their resulting
        outlines.

        Parameters
        ----------
        tile_id : int
            Identifier of trap
        channels : Union[str, int]
            Channels to use
        trange : t.Union[range, t.Collection[int]]
            Range or collection indicating the time-points to use.
        remove_axis : bool
            False/None keeps both axes, True or "off" removes both axes and
            "x" removes only the x-axis ticks.
        savefile : str
            Saves file to a location.
        skip_outlines : bool
            Do not add overlay with outlines
        norm : str
            Normalise signals; one of "l1", "l2" or "max".
        ncols : int
            Number of columns to plot.
        local_colours : bool
            Bypass label indicators to guarantee that colours are not repeated
            (TODO implement)
        img_plot_kwargs : dict
            Arguments to pass to plt.imshow used for images. None means no
            extra arguments.
        lbl_plot_kwargs : dict
            Keyword arguments to pass to label plots. None means
            {"alpha": 0.8}.
        **kwargs : dict
            Additional keyword arguments passed to ImageViewer.get_images.

        Examples
        --------
        FIXME: Add docs.
        """
        # None sentinels replace the former mutable default dicts
        img_plot_kwargs = dict(img_plot_kwargs or {})
        lbl_plot_kwargs = (
            {"alpha": 0.8} if lbl_plot_kwargs is None else dict(lbl_plot_kwargs)
        )

        if ncols is None:
            ncols = len(trange)
        nrows = int(np.ceil(len(trange) / ncols))
        width = self.tiler.tile_size * ncols

        out, images = self.get_images(tile_id, trange, channels, **kwargs)

        # Check for outlines before zeros are replaced by NaN: NaN is truthy,
        # so testing afterwards would never report missing outlines.
        has_outlines = bool(np.any(out))

        # dilation makes outlines easier to see
        out = dilation(out).astype(float)
        out[out == 0] = np.nan

        # Label the channels that were actually fetched, so a single channel
        # given as str/int (or None) does not get iterated character-wise.
        channel_labels = [
            self.tiler.channels[ch] if isinstance(ch, int) else ch
            for ch in images
        ]

        assert not norm or norm in (
            "l1",
            "l2",
            "max",
        ), "Invalid norm argument."

        if norm and norm in ("l1", "l2", "max"):
            images = {k: stretch_clip(v) for k, v in images.items()}

        images = [concat_pad(img, width, nrows) for img in images.values()]
        # TODO convert to RGB to draw fluorescence with colour
        tiled_imgs = {}
        tiled_imgs["img"] = np.concatenate(images, axis=0)
        tiled_imgs["cell_labels"] = np.concatenate(
            [concat_pad(out, width, nrows) for _ in images], axis=0
        )

        custom_imshow(
            tiled_imgs["img"],
            **img_plot_kwargs,
        )
        if not skip_outlines:
            custom_imshow(
                tiled_imgs["cell_labels"],
                cmap=sns.color_palette("Paired", as_cmap=True),
                **lbl_plot_kwargs,
            )

        axes_off = remove_axis is True or remove_axis == "off"
        if axes_off:
            plt.axis("off")
        elif remove_axis == "x":
            plt.tick_params(
                axis="x",
                which="both",
                bottom=False,
                top=False,
                labelbottom=False,
            )

        if not axes_off:
            # One tick per channel, centred on its band of rows
            plt.yticks(
                ticks=[
                    (i * self.tiler.tile_size * nrows)
                    + self.tiler.tile_size * nrows / 2
                    for i in range(len(channel_labels))
                ],
                labels=channel_labels,
            )

        if not remove_axis:
            xlabels = (
                ["+ {} ".format(i) for i in range(ncols)]
                if nrows > 1
                else list(trange)
            )
            plt.xlabel("Time-point")

            plt.xticks(
                ticks=[self.tiler.tile_size * (i + 0.5) for i in range(ncols)],
                labels=xlabels,
            )

        if not has_outlines:
            print("ImageViewer:Warning:No cell outlines found")

        if savefile:
            plt.savefig(savefile, bbox_inches="tight", dpi=300)
            plt.close()
        else:
            plt.show()
def concat_pad(a: np.ndarray, width, nrows):
    """
    Melt an array into having multiple blocks as rows

    The array is padded with NaN on the right until its number of columns
    is a multiple of width, then split into nrows blocks along axis 1 which
    are stacked along axis 0.

    Parameters
    ----------
    a : np.ndarray
        2-D array to rearrange.
    width : int
        Number of columns each row block should have.
    nrows : int
        Number of blocks to split the (padded) array into.

    Returns
    -------
    np.ndarray
        The stacked blocks. The input dtype is kept when no padding is
        needed; otherwise the result is float so that it can hold NaN.
    """
    a = np.asarray(a)
    # Columns missing to complete the last row block (0 when already exact)
    missing = (-a.shape[1]) % width
    if missing:
        a = np.pad(
            a.astype(float),
            ((0, 0), (0, missing)),
            constant_values=np.nan,
        )
    return np.concatenate(np.array_split(a, nrows, axis=1))
#!/usr/bin/env jupyter
"""
Basic plotting functions for cell visualisation
"""
import typing as t
import numpy as np
from grid_strategy import strategies
from matplotlib import pyplot as plt
def plot_overlay(
    bg: np.ndarray, fg: np.ndarray, alpha: float = 1.0, ax=plt
) -> t.Tuple[t.Any, t.Any]:
    """
    Plot two images, one on top of the other.

    Parameters
    ----------
    bg : np.ndarray
        Background image, drawn in grey scale.
    fg : np.ndarray
        Foreground image, rescaled with ``stretch`` and drawn on top.
    alpha : float
        Opacity of the foreground.
    ax : pyplot module or Axes
        Object whose ``imshow`` and ``axis`` are used for drawing.

    Returns
    -------
    Tuple with the two objects returned by the imshow calls (background,
    foreground).
    """
    ax1 = ax.imshow(
        bg, cmap=plt.cm.gray, interpolation="none", interpolation_stage="rgba"
    )
    ax2 = ax.imshow(
        stretch(fg),
        alpha=alpha,
        interpolation="none",
        interpolation_stage="rgba",
    )
    # Hide the axes of the target that was drawn on, not whichever axes
    # pyplot currently considers active
    ax.axis("off")
    return ax1, ax2
def plot_overlay_in_square(data: t.Tuple[np.ndarray, np.ndarray]):
    """
    Draw every (tile, mask) pair of data as an overlay, one subplot per
    pair, on a grid laid out by grid_strategy's centred SquareStrategy.
    """
    grid = strategies.SquareStrategy("center").get_grid(len(data))
    for cell, (tile, mask) in zip(grid, data):
        target = plt.subplot(cell)
        plot_overlay(tile, mask, ax=target)
def plot_in_square(data: t.Iterable):
    """
    Show every image of data in its own subplot on a grid laid out by
    grid_strategy's centred SquareStrategy. Takes single images only
    (no mask overlay).
    """
    grid = strategies.SquareStrategy("center").get_grid(len(data))
    for cell, datum in zip(grid, data):
        plt.subplot(cell).imshow(datum)
def stretch_clip(image, clip=True):
    """
    Performs contrast stretching on an input image.

    This function takes an array-like input image and enhances its contrast by adjusting
    the dynamic range of pixel values. It first scales the pixel values between 0 and 255,
    then (when clip is True) clips the values that are below the 2nd percentile or above
    the 98th percentile. Finally, the pixel values are scaled to the range between 0 and 1.
    NaN pixels are ignored when computing the statistics and stay NaN in the output.

    Parameters
    ----------
    image : array-like
        Input image.
    clip : bool
        Whether to clip to the 2nd-98th percentile range before the final
        rescaling.

    Returns
    -------
    stretched : ndarray
        Contrast-stretched version of the input image, with the same shape
        as the input. A constant image maps to zeros.

    Examples
    --------
    >>> stretch_clip(np.array([[0.0, 50.0], [100.0, np.nan]]))
    array([[0. , 0.5],
           [1. , nan]])
    """
    image = np.asarray(image, dtype=float)

    # nan-aware statistics keep the image shape intact; dropping NaNs with
    # a boolean mask would flatten the image to 1-D
    lowest, highest = np.nanmin(image), np.nanmax(image)
    if highest == lowest:
        # Constant image: nothing to stretch, avoid dividing by zero
        flat = np.zeros_like(image)
        flat[np.isnan(image)] = np.nan
        return flat

    image = ((image - lowest) / (highest - lowest)) * 255

    minval, maxval = 0.0, 255.0
    if clip:
        low_pct = np.nanpercentile(image, 2)
        high_pct = np.nanpercentile(image, 98)
        # Only clip when the percentiles span a non-empty range
        if high_pct > low_pct:
            minval, maxval = low_pct, high_pct
            image = np.clip(image, minval, maxval)

    return (image - minval) / (maxval - minval)
def stretch(image):
    """
    Rescale image linearly so that its smallest non-NaN value maps to 0 and
    its largest to 1. NaN entries are ignored for the range and stay NaN.
    """
    finite = image[~np.isnan(image)]
    low = finite.min()
    span = finite.max() - low
    return (image - low) / span