Skip to content
Snippets Groups Projects

Compare revisions

Changes are shown as if the source revision was being merged into the target revision. Learn more about comparing revisions.

Source

Select target project
No results found

Target

Select target project
  • swain-lab/aliby/aliby-mirror
  • swain-lab/aliby/alibylite
2 results
Show changes
Showing
with 619 additions and 1957 deletions
"""
Testing the "run" functions in the pipeline elements.
"""
import pytest
# Skip every test in this module while the suite is still work in progress.
# pytest only honours a module-wide mark when it is bound to the special
# ``pytestmark`` name; a bare ``pytest.mark.skip(...)`` expression is
# evaluated and thrown away, so nothing was actually being skipped.
pytestmark = pytest.mark.skip(reason='All tests still WIP')
# Todo: data needed: an experiment object
# Todo: data needed: an sqlite database
# Todo: data needed: a Shelf storage
class TestPipeline:
    """Placeholders for tests of the pipeline elements' ``run`` functions.

    Every test is an empty stub until the fixtures listed in the TODO
    comments above (experiment object, sqlite database, Shelf storage)
    are available.
    """

    def test_experiment(self):
        """Stub: not implemented yet."""

    def test_omero_experiment(self):
        """Stub: not implemented yet."""

    def test_tiler(self):
        """Stub: not implemented yet."""

    def test_baby_client(self):
        """Stub: not implemented yet."""

    def test_baby_runner(self):
        """Stub: not implemented yet."""

    def test_pipeline(self):
        """Stub: not implemented yet."""
import pytest
# Skip every test in this module while the suite is still work in progress.
# The mark has to be assigned to ``pytestmark`` for pytest to apply it; the
# previous bare expression had no effect.
pytestmark = pytest.mark.skip(reason="all tests still WIP")
from core.core import PersistentDict
# Todo: temporary file needed
class TestPersistentDict:
    """Smoke test for ``PersistentDict`` backed by a temporary JSON file."""

    @pytest.fixture(autouse=True)
    def _get_json_file(self, tmp_path):
        """Point ``self._filename`` at a fresh file for each test.

        The fixture is function-scoped on purpose: ``tmp_path`` is itself
        function-scoped, so requesting it from a class-scoped fixture is a
        scope mismatch, and attributes set on ``self`` are only reliably
        visible to the test when fixture and test share the same instance.
        """
        self._filename = tmp_path / 'persistent_dict.json'

    def test_persistent_dict(self):
        """Write two nested keys; assertions on the result are still TODO."""
        p = PersistentDict(self._filename)
        p['hello/from/the/other/side'] = "adele"
        p['hello/how/you/doing'] = 'lionel'
        # Todo: run checks
# Todo: data needed - small experiment
class TestExperiment:
    """Placeholder tests for the experiment object (needs a small experiment)."""

    def test_shape(self):
        """Stub: not implemented yet."""

    def test_positions(self):
        """Stub: not implemented yet."""

    def test_channels(self):
        """Stub: not implemented yet."""

    def test_hypercube(self):
        """Stub: not implemented yet."""
# Todo: data needed - a dummy OMERO server
class TestConnection:
    """Placeholder tests for the OMERO connection (needs a dummy server)."""

    def test_dataset(self):
        """Stub: not implemented yet."""

    def test_image(self):
        """Stub: not implemented yet."""
# Todo data needed - a position
class TestTimelapse:
    """Placeholder tests for the timelapse accessors (needs a position)."""

    def test_id(self):
        """Stub: not implemented yet."""

    def test_name(self):
        """Stub: not implemented yet."""

    def test_size_z(self):
        """Stub: not implemented yet."""

    def test_size_c(self):
        """Stub: not implemented yet."""

    def test_size_t(self):
        """Stub: not implemented yet."""

    def test_size_x(self):
        """Stub: not implemented yet."""

    def test_size_y(self):
        """Stub: not implemented yet."""

    def test_channels(self):
        """Stub: not implemented yet."""

    def test_channel_index(self):
        """Stub: not implemented yet."""
# Todo: data needed image and template
class TestTrapUtils:
    """Placeholder tests for the trap utilities (needs an image and template)."""

    def test_trap_locations(self):
        """Stub: not implemented yet."""

    def test_tile_shape(self):
        """Stub: not implemented yet."""

    def test_get_tile(self):
        """Stub: not implemented yet."""

    def test_centre(self):
        """Stub: not implemented yet."""
# Todo: data needed - a functional experiment object
class TestTiler:
    """Placeholder tests for the tiler (needs a functional experiment object)."""

    def test_n_timepoints(self):
        """Stub: not implemented yet."""

    def test_n_traps(self):
        """Stub: not implemented yet."""

    def test_get_trap_timelapse(self):
        """Stub: not implemented yet."""

    def test_get_trap_timepoints(self):
        """Stub: not implemented yet."""
# Todo: data needed - a functional tiler object
# Todo: running server needed
class TestBabyClient:
    """Placeholder tests for the BABY client (needs a tiler and a running server)."""

    def test_get_new_session(self):
        """Stub: not implemented yet."""

    def test_queue_image(self):
        """Stub: not implemented yet."""

    def test_get_segmentation(self):
        """Stub: not implemented yet."""
# Todo: data needed - a functional tiler object
class TestBabyRunner:
    """Placeholder tests for the BABY runner (needs a functional tiler object)."""

    def test_model_choice(self):
        """Stub: not implemented yet."""

    def test_properties(self):
        """Stub: not implemented yet."""

    def test_segment(self):
        """Stub: not implemented yet."""
"""Segment/segmented pipelines.
Includes splitting the image into traps/parts,
cell segmentation, nucleus segmentation."""
import warnings
from functools import lru_cache
import h5py
import numpy as np
from pathlib import Path, PosixPath
from skimage.registration import phase_cross_correlation
from agora.abc import ParametersABC, ProcessABC
from aliby.traps import segment_traps
from agora.io.writer import load_attributes
# Directory, next to this module, where trap template arrays are looked up.
trap_template_directory = Path(__file__).parent / "trap_templates"
# TODO do we need multiple templates, one for each setup?
# The template is currently an empty placeholder; the disabled call that
# would load the real one from disk is kept in the trailing comment.
trap_template = np.array([]) # np.load(trap_template_directory / "trap_prime.npy")
def get_tile_shapes(x, tile_size, max_shape):
    """Return the bounds ``(xmin, xmax, ymin, ymax)`` of a tile centred on ``x``.

    :param x: Tile centre; only ``x[0]`` and ``x[1]`` are read.
    :param tile_size: Side length of the (square) tile.
    :param max_shape: Upper limits for the first and second coordinate.
    :return: ``(xmin, xmin + tile_size, ymin, ymin + tile_size)``.

    A tile that would overrun ``max_shape`` is shifted back so that it ends
    exactly on the limit. Only the second coordinate is clamped at zero on
    the low side; the first may come out negative.
    """
    offset = tile_size // 2
    first = int(x[0] - offset)
    second = max(0, int(x[1] - offset))
    # Shift the tile back inside the upper limits where necessary.
    first = min(first, max_shape[0] - tile_size)
    second = min(second, max_shape[1] - tile_size)
    return first, first + tile_size, second, second + tile_size
###################### Dask versions ########################
class Trap:
    """A single square tile, tracked through time via its parent's drifts.

    :param centre: Initial centre of the trap (two coordinates).
    :param parent: Object exposing a ``drifts`` sequence; used to look up
        the drift accumulated up to a given time point.
    :param size: Side length of the tile.
    :param max_size: Upper bound used by ``padding_required``.
    """

    def __init__(self, centre, parent, size, max_size):
        self.centre = centre
        self.parent = parent  # Used to access drifts
        self.size = size
        self.half_size = size // 2
        self.max_size = max_size

    def padding_required(self, tp):
        """Return True if the tile at time ``tp`` leaves ``[0, max_size]``.

        Implemented with plain boolean logic rather than ``assert`` so the
        result stays correct when Python runs with assertions disabled.
        """
        centre = self.at_time(tp)
        fits = np.all(centre - self.half_size >= 0) and np.all(
            centre + self.half_size <= self.max_size
        )
        return not bool(fits)

    def at_time(self, tp):
        """Return the trap centre at time ``tp``.

        The centre is the initial centre minus the sum of the first ``tp``
        entries of the parent's drifts.
        """
        # NOTE(review): drifts[:tp] excludes the drift stored at index tp
        # itself -- confirm this matches how the drifts are recorded.
        drifts = self.parent.drifts
        return self.centre - np.sum(drifts[:tp], axis=0)

    def as_tile(self, tp):
        """Return the trap at time ``tp`` as an ``(x, y, w, h)`` tuple.

        ``x`` and ``y`` are the integer coordinates of the tile corner (the
        centre minus half the tile size); ``w`` and ``h`` both equal the
        tile size.
        """
        x, y = self.at_time(tp)
        # Tile corner
        x = int(x - self.half_size)
        y = int(y - self.half_size)
        return x, y, self.size, self.size

    def as_range(self, tp):
        """Return the trap at time ``tp`` as two slices usable for array indexing."""
        x, y, w, h = self.as_tile(tp)
        return slice(x, x + w), slice(y, y + h)
class TrapLocations:
    """The set of traps found in a position, together with the recorded drifts.

    :param initial_location: Iterable of trap centres; one ``Trap`` is
        created per entry.
    :param tile_size: Side length of each trap tile.
    :param max_size: Upper bound handed to every ``Trap``.
    :param drifts: Sequence of per-time-point drifts. Defaults to a fresh
        empty list for each instance.
    """

    def __init__(self, initial_location, tile_size, max_size=1200, drifts=None):
        self.tile_size = tile_size
        self.max_size = max_size
        self.initial_location = initial_location
        self.traps = [
            Trap(centre, self, tile_size, max_size) for centre in initial_location
        ]
        # A literal ``[]`` default would be shared by every instance, and the
        # list is appended to in place while drifts are being computed.
        self.drifts = [] if drifts is None else drifts

    @classmethod
    def from_source(cls, fpath: str):
        """Build trap locations from the ``trap_info`` datasets of an HDF5 file.

        The tile size is fixed at 96 here rather than read from the file.
        """
        with h5py.File(fpath, "r") as f:
            # TODO read tile size from file metadata
            drifts = f["trap_info/drifts"][()]
            tlocs = cls(f["trap_info/trap_locations"][()], tile_size=96, drifts=drifts)
        return tlocs

    @property
    def shape(self):
        """Tuple ``(number of traps, number of recorded drifts)``."""
        return len(self.traps), len(self.drifts)

    def __len__(self):
        return len(self.traps)

    def __iter__(self):
        yield from self.traps

    def padding_required(self, tp):
        """Return True if any trap needs padding at time ``tp``."""
        return any(trap.padding_required(tp) for trap in self.traps)

    def to_dict(self, tp):
        """Return the values to be written for time point ``tp``.

        The initial locations and the size attributes are only included for
        ``tp == 0``; the drift of ``tp`` (with a leading axis added) is
        always included.
        """
        res = dict()
        if tp == 0:
            res["trap_locations"] = self.initial_location
            res["attrs/tile_size"] = self.tile_size
            res["attrs/max_size"] = self.max_size
        res["drifts"] = np.expand_dims(self.drifts[tp], axis=0)
        # res['processed_timepoints'] = tp
        return res

    @classmethod
    def read_hdf5(cls, file):
        """Rebuild trap locations, including sizes, from the ``trap_info`` group."""
        with h5py.File(file, "r") as hfile:
            trap_info = hfile["trap_info"]
            initial_locations = trap_info["trap_locations"][()]
            drifts = trap_info["drifts"][()]
            max_size = trap_info.attrs["max_size"]
            tile_size = trap_info.attrs["tile_size"]
        trap_locs = cls(initial_locations, tile_size, max_size=max_size)
        trap_locs.drifts = drifts
        return trap_locs
class TilerParameters(ParametersABC):
    """Parameters of the tiler.

    :param tile_size: Side length of the trap tiles.
    :param ref_channel: Name of the channel used as reference.
    :param ref_z: Index of the z-section used as reference.
    :param template_name: Optional name of the trap template.
    """

    def __init__(
        self, tile_size: int, ref_channel: str, ref_z: int, template_name: str = None
    ):
        self.tile_size = tile_size
        self.ref_channel = ref_channel
        self.ref_z = ref_z
        self.template_name = template_name

    @classmethod
    def from_template(cls, template_name: str, ref_channel: str, ref_z: int):
        """Create parameters whose tile size is the first dimension of a template.

        Previously this referenced an undefined ``template`` variable and
        passed a ``template_path`` keyword that the constructor does not
        accept, so it could never succeed.
        """
        # NOTE(review): assumes ``template_name`` is the file name of a
        # ``.npy`` array inside ``trap_template_directory`` -- confirm the
        # intended naming convention.
        template = np.load(trap_template_directory / template_name)
        return cls(template.shape[0], ref_channel, ref_z, template_name=template_name)

    @classmethod
    def default(cls):
        """Return the default parameters: 96-pixel tiles, "Brightfield", z index 0."""
        return cls(96, "Brightfield", 0)
class Tiler(ProcessABC):
    """Find traps in a time-lapse image and register the drift between frames.

    The image is indexed as ``image[timepoint, channel, z]`` throughout this
    class.

    :param image: The image data.
    :param metadata: Mapping with at least a ``"channels"`` entry.
    :param parameters: Tiler parameters (tile size, reference channel and z).
    """

    def __init__(
        self,
        image,
        metadata,
        parameters: TilerParameters,
    ):
        super().__init__(parameters)
        self.image = image
        self.channels = metadata["channels"]
        # Set explicitly so the methods below do not depend on the base class
        # copying the parameters onto the instance.
        self.tile_size = parameters.tile_size
        self.ref_z = parameters.ref_z
        self.ref_channel = self.get_channel_index(parameters.ref_channel)

    @classmethod
    def from_image(cls, image, parameters: TilerParameters):
        """Build a tiler from an object exposing ``data`` and ``metadata``."""
        return cls(image.data, image.metadata, parameters)

    @classmethod
    def from_hdf5(cls, image, filepath, tile_size=None):
        """Build a tiler whose trap locations are read from an HDF5 file.

        :param image: The image data.
        :param filepath: HDF5 file holding the trap info and attributes.
        :param tile_size: Tile size; defaults to the one stored in the file.

        The constructor only accepts ``(image, metadata, parameters)``, so
        the parameters are assembled here (reference channel and z are taken
        from ``TilerParameters.default()``) and the trap locations are
        attached afterwards.
        """
        trap_locs = TrapLocations.read_hdf5(filepath)
        metadata = load_attributes(filepath)
        metadata["channels"] = metadata["channels/channel"].tolist()
        if tile_size is None:
            tile_size = trap_locs.tile_size
        defaults = TilerParameters.default()
        parameters = TilerParameters(tile_size, defaults.ref_channel, defaults.ref_z)
        tiler = cls(image=image, metadata=metadata, parameters=parameters)
        tiler.trap_locs = trap_locs
        # One drift is recorded per processed time point (see ``run``).
        tiler.n_processed = len(trap_locs.drifts)
        return tiler

    # NOTE(review): the cache lives on the function, so its keys include
    # ``self`` and cached entries keep the instance alive.
    @lru_cache(maxsize=2)
    def get_tc(self, t, c):
        """Return the computed image for time point ``t`` and channel ``c``."""
        # Get image
        full = self.image[t, c].compute()  # FORCE THE CACHE
        return full

    @property
    def shape(self):
        """The image shape with its last three axes reversed."""
        # NOTE(review): the image is indexed [t, c, ...] elsewhere in this
        # class, so the first two names below may be swapped -- confirm.
        c, t, z, y, x = self.image.shape
        return (c, t, x, y, z)

    @property
    def n_processed(self):
        """Number of time points processed so far (0 until first set)."""
        if not hasattr(self, "_n_processed"):
            self._n_processed = 0
        return self._n_processed

    @n_processed.setter
    def n_processed(self, value):
        self._n_processed = value

    @property
    def n_traps(self):
        """Number of traps currently tracked."""
        return len(self.trap_locs)

    @property
    def finished(self):
        """True once as many time points were processed as the image's first axis holds."""
        return self.n_processed == self.image.shape[0]

    def _initialise_traps(self, tile_size):
        """Find initial trap positions.

        Removes all those that are too close to the edge so no padding is necessary.
        """
        half_tile = tile_size // 2
        max_size = min(self.image.shape[-2:])
        # First time point, reference channel, reference z-position
        initial_image = self.image[0, self.ref_channel, self.ref_z]
        trap_locs = segment_traps(initial_image, tile_size)
        trap_locs = [
            [x, y]
            for x, y in trap_locs
            if half_tile < x < max_size - half_tile
            and half_tile < y < max_size - half_tile
        ]
        self.trap_locs = TrapLocations(trap_locs, tile_size)

    def find_drift(self, tp):
        """Register time point ``tp`` against the previous one and store the drift.

        For ``tp == 0`` the image is registered against itself.
        """
        # TODO check that the drift doesn't move any tiles out of the image, remove them from list if so
        prev_tp = max(0, tp - 1)
        drift, _, _ = phase_cross_correlation(
            self.image[prev_tp, self.ref_channel, self.ref_z],
            self.image[tp, self.ref_channel, self.ref_z],
        )
        self.trap_locs.drifts.append(drift)

    def get_tp_data(self, tp, c):
        """Return the stacked tiles of every trap for time point ``tp``, channel ``c``."""
        full = self.get_tc(tp, c)
        # if self.trap_locs.padding_required(tp):
        traps = [self.ifoob_pad(full, trap.as_range(tp)) for trap in self.trap_locs]
        return np.stack(traps)

    def get_trap_data(self, trap_id, tp, c):
        """Return the tile of a single trap for time point ``tp``, channel ``c``."""
        full = self.get_tc(tp, c)
        trap = self.trap_locs.traps[trap_id]
        ndtrap = self.ifoob_pad(full, trap.as_range(tp))
        return ndtrap

    @staticmethod
    def ifoob_pad(full, slices):
        """
        Returns the slices padded if it is out of bounds

        Parameters:
        ----------
        full: (zstacks, max_size, max_size) ndarray
            Entire position with zstacks as first axis
        slices: tuple of two slices
            Each slice indicates an axis to index

        Returns
        Trap for given slices, padded with median if needed, or np.nan if the padding is too much
        """
        max_size = full.shape[-1]
        # Clip the requested window to what the image actually holds.
        y, x = [slice(max(0, s.start), min(max_size, s.stop)) for s in slices]
        trap = full[:, y, x]
        # (before, after) amounts by which each slice overruns the image.
        padding = np.array(
            [(-min(0, s.start), -min(0, max_size - s.stop)) for s in slices]
        )
        if padding.any():
            tile_size = slices[0].stop - slices[0].start
            if (padding > tile_size / 4).any():
                # More than a quarter of the tile is missing: give up on it.
                trap = np.full((full.shape[0], tile_size, tile_size), np.nan)
            else:
                trap = np.pad(trap, [[0, 0]] + padding.tolist(), "median")
        return trap

    def run_tp(self, tp):
        """Process time point ``tp``, refusing time points already processed."""
        assert tp >= self.n_processed, "Time point already processed"
        # TODO check contiguity?
        return self.run(tp)

    def run(self, tp):
        """Process time point ``tp`` and return the values for the writer.

        Traps are located on the first call; every call records one drift
        and increments ``n_processed``.
        """
        if self.n_processed == 0:
            self._initialise_traps(self.tile_size)
        self.find_drift(tp)  # Get drift
        # update n_processed
        self.n_processed += 1
        # Return result for writer
        return self.trap_locs.to_dict(tp)

    # The next set of functions are necessary for the extraction object
    def get_traps_timepoint(self, tp, tile_size=None, channels=None, z=None):
        """Return the tiles of all traps at time point ``tp``.

        :param tp: Time point.
        :param tile_size: Currently ignored.
        :param channels: Channel indices to fetch; defaults to ``[0]``.
        :param z: z indices to fetch; defaults to ``[0]``. A single int is
            treated as a one-element list.
        :return: Array whose first axes are (trap, channel, 1).
        """
        # FIXME we currently ignore the tile size
        # FIXME can we ignore z(always give)
        # ``None`` defaults used to crash (channels) or silently insert an
        # extra axis (z); fall back to the first channel / z-section instead.
        channels = [0] if channels is None else channels
        z = [0] if z is None else z
        if isinstance(z, int):
            z = [z]
        res = []
        for c in channels:
            val = self.get_tp_data(tp, c)[:, z]  # Only return requested z
            # positions
            # Starts at traps, z, y, x
            # Turn to Trap, C, T, X, Y, Z order
            val = val.swapaxes(1, 3).swapaxes(1, 2)
            val = np.expand_dims(val, axis=1)
            res.append(val)
        return np.stack(res, axis=1)

    def get_channel_index(self, item):
        """Return the index of the first channel whose name contains ``item``.

        Returns None when no channel matches.
        """
        for i, ch in enumerate(self.channels):
            if item in ch:
                return i

    def get_position_annotation(self):
        """Return None; placeholder kept for matlab support."""
        # TODO required for matlab support
        return None
"""
A set of utilities for dealing with ALCATRAS traps
"""
import numpy as np
from tqdm import tqdm
from skimage import transform, feature
from skimage.filters.rank import entropy
from skimage.filters import threshold_otsu
from skimage.segmentation import clear_border
from skimage.measure import label, regionprops
from skimage.morphology import disk, closing, square
def stretch_image(image):
    """Contrast-stretch ``image`` between its 2nd and 98th percentiles.

    The image is first scaled to the 0-255 range, then clipped to the two
    percentiles and rescaled so the clipped range maps onto 0-1.
    """
    lowest = image.min()
    scaled = ((image - lowest) / (image.max() - image.min())) * 255
    low_cut, high_cut = np.percentile(scaled, 2), np.percentile(scaled, 98)
    clipped = np.clip(scaled, low_cut, high_cut)
    return (clipped - low_cut) / (high_cut - low_cut)
def segment_traps(image, tile_size, downscale=0.4):
    """Locate traps in ``image``.

    An entropy filter followed by Otsu thresholding yields candidate
    regions; the candidate with the smallest minor axis is cut out as a
    template, which is then matched against the whole image with
    ``identify_trap_locations``.

    :param image: 2D image to search.
    :param tile_size: Side length of a trap tile; bounds the accepted region
        area and sets the template size.
    :param downscale: Scale applied before the entropy filter. When fewer
        than 10 traps are found and ``downscale`` is not 1, the search is
        repeated at full resolution.
    :return: Trap coordinates as returned by ``identify_trap_locations``.
    :raises ValueError: If no candidate region is available to build the
        template from.
    """
    img = image  # Keep a memory of image in case need to re-run
    # TODO Optimise the hyperparameters
    disk_radius = int(min([0.01 * x for x in img.shape]))
    min_area = 0.2 * (tile_size ** 2)
    max_area = tile_size ** 2 * 0.8
    half_tile = tile_size // 2
    if downscale != 1:
        img = transform.rescale(image, downscale)
    entropy_image = entropy(img, disk(disk_radius))
    if downscale != 1:
        entropy_image = transform.rescale(entropy_image, 1 / downscale)

    # apply threshold
    thresh = threshold_otsu(entropy_image)
    bw = closing(entropy_image > thresh, square(3))

    # remove artifacts connected to image border
    cleared = clear_border(bw)

    # label image regions and keep those of a plausible size; the region
    # properties are computed once and reused for centroids and axes
    label_image = label(cleared)
    candidates = [
        region
        for region in regionprops(label_image)
        if min_area < region.area < max_area
    ]
    if not candidates:
        raise ValueError("No candidate trap regions found in the image.")

    traps = np.array([region.centroid for region in candidates]).round().astype(int)
    ma = (
        np.array([region.minor_axis_length for region in candidates])
        .round()
        .astype(int)
    )

    # Discard candidates whose tile would not fit inside the image.
    maskx = (half_tile < traps[:, 0]) & (traps[:, 0] < image.shape[0] - half_tile)
    masky = (half_tile < traps[:, 1]) & (traps[:, 1] < image.shape[1] - half_tile)
    keep = maskx & masky
    if not keep.any():
        raise ValueError("No candidate trap region lies fully inside the image.")
    traps = traps[keep, :]
    ma = ma[keep]

    # The candidate with the smallest minor axis becomes the template.
    x, y = np.round(traps[ma.argmin()]).astype(int)
    template = image[x - half_tile : x + half_tile, y - half_tile : y + half_tile]

    traps = identify_trap_locations(image, template)

    if len(traps) < 10 and downscale != 1:
        print("Trying again.")
        return segment_traps(image, tile_size, downscale=1)

    return traps
# def segment_traps(image, tile_size, downscale=0.4):
# # Make image go between 0 and 255
# img = image # Keep a memory of image in case need to re-run
# image = stretch_image(image)
# # TODO Optimise the hyperparameters
# disk_radius = int(min([0.01 * x for x in img.shape]))
# min_area = 0.1 * (tile_size ** 2)
# if downscale != 1:
# img = transform.rescale(image, downscale)
# entropy_image = entropy(img, disk(disk_radius))
# if downscale != 1:
# entropy_image = transform.rescale(entropy_image, 1 / downscale)
# # apply threshold
# thresh = threshold_otsu(entropy_image)
# bw = closing(entropy_image > thresh, square(3))
# # remove artifacts connected to image border
# cleared = clear_border(bw)
# # label image regions
# label_image = label(cleared)
# traps = [
# region.centroid for region in regionprops(label_image) if region.area > min_area
# ]
# if len(traps) < 10 and downscale != 1:
# print("Trying again.")
# return segment_traps(image, tile_size, downscale=1)
# return traps
def identify_trap_locations(
    image, trap_template, optimize_scale=True, downscale=0.35, trap_size=None
):
    """
    Identify the traps in a single image based on a trap template.

    Image and template are both downscaled before template matching to
    speed up the search. The template rotation (0, 90, 180 or 270 degrees)
    and, optionally, its scale are chosen by comparing the 99.9th
    percentile of the squared match response.

    :param image: Image to search.
    :param trap_template: Template of a single trap.
    :param optimize_scale: Whether to also try ten template scales between
        0.5 and 2.
    :param downscale: Scale applied to image and template before matching.
    :param trap_size: Size used for the border exclusion; defaults to the
        first dimension of the template.
    :return: Peak coordinates of the match response, rescaled back by
        ``1 / downscale``.
    """
    if trap_size is None:
        trap_size = trap_template.shape[0]

    # Careful, the image is float16!
    small_image = transform.rescale(image.astype(float), downscale)
    small_template = transform.rescale(trap_template, downscale)
    fill_value = np.median(small_image)

    # TODO random search hyperparameter optimization
    # optimize rotation
    rotation_scores = {}
    for angle in [0, 90, 180, 270]:
        rotated = transform.rotate(small_template, angle, cval=fill_value)
        response = feature.match_template(
            small_image, rotated, pad_input=True, mode="median"
        )
        rotation_scores[angle] = response ** 2
    best_rotation = max(
        rotation_scores, key=lambda angle: np.percentile(rotation_scores[angle], 99.9)
    )
    small_template = transform.rotate(small_template, best_rotation, cval=fill_value)

    if not optimize_scale:
        matched = feature.match_template(
            small_image, small_template, pad_input=True, mode="median"
        )
    else:
        scale_scores = {}
        for scale in np.linspace(0.5, 2, 10):
            response = feature.match_template(
                small_image,
                transform.rescale(small_template, scale),
                mode="median",
                pad_input=True,
            )
            scale_scores[scale] = response ** 2
        best_scale = max(
            scale_scores, key=lambda scale: np.percentile(scale_scores[scale], 99.9)
        )
        matched = scale_scores[best_scale]

    full_size_response = transform.rescale(matched, 1 / downscale)
    return feature.peak_local_max(
        full_size_response,
        min_distance=int(trap_template.shape[0] * 0.70),
        exclude_border=(trap_size // 3),
    )
def get_tile_shapes(x, tile_size, max_shape):
    """Return the bounds ``(xmin, xmax, ymin, ymax)`` of a tile centred on ``x``.

    :param x: Tile centre; only ``x[0]`` and ``x[1]`` are read.
    :param tile_size: Side length of the (square) tile.
    :param max_shape: Accepted but not used.

    The second coordinate is clamped at zero on the low side; the first is
    not, and neither is limited on the high side.
    """
    offset = tile_size // 2
    first = int(x[0] - offset)
    second = int(x[1] - offset)
    if second < 0:
        second = 0
    return first, first + tile_size, second, second + tile_size
def in_image(img, xmin, xmax, ymin, ymax, xidx=2, yidx=3):
    """Return whether the window ``[xmin, xmax) x [ymin, ymax)`` lies inside ``img``.

    :param img: Object with a ``shape``; axes ``xidx`` and ``yidx`` are checked.
    :return: Always a bool. Previously some paths fell off the end of the
        function and returned ``None``.
    """
    # NOTE(review): the upper bounds are compared strictly, so a window that
    # ends exactly on the image border is reported as outside -- kept as is.
    return bool(
        xmin >= 0
        and ymin >= 0
        and xmax < img.shape[xidx]
        and ymax < img.shape[yidx]
    )
def get_xy_tile(img, xmin, xmax, ymin, ymax, xidx=2, yidx=3, pad_val=None):
    """Cut the window ``[xmin, xmax) x [ymin, ymax)`` out of ``img``.

    :param img: Array to cut from; ``xidx`` and ``yidx`` select the axes the
        window applies to, all other axes are kept whole.
    :param pad_val: Value used to fill the parts of the window that fall
        outside the image; defaults to the median of ``img``.
    :return: The tile, padded where the window overruns the image.
    """
    if pad_val is None:
        pad_val = np.median(img)
    x_extent = img.shape[xidx]
    y_extent = img.shape[yidx]

    # Take the part of the window that the image actually covers.
    selector = [slice(None) for _ in img.shape]
    selector[xidx] = slice(max(0, xmin), min(xmax, x_extent))
    selector[yidx] = slice(max(0, ymin), min(ymax, y_extent))
    tile = img[tuple(selector)]

    if in_image(img, xmin, xmax, ymin, ymax, xidx, yidx):
        return tile

    # Fill in whatever the window requested beyond the image.
    pad_widths = [(0, 0) for _ in img.shape]
    pad_widths[xidx] = (max(-xmin, 0), max(xmax - x_extent, 0))
    pad_widths[yidx] = (max(-ymin, 0), max(ymax - y_extent, 0))
    return np.pad(tile, pad_widths, constant_values=pad_val)
def get_trap_timelapse(
    raw_expt, trap_locations, trap_id, tile_size=117, channels=None, z=None
):
    """
    Get a timelapse for a given trap by specifying the trap_id

    :param raw_expt: Indexable experiment data with a ``shape``.
    :param trap_locations: Per-time-point trap locations, indexed by
        consecutive integers starting at 0.
    :param trap_id: An integer defining which trap to choose.
    :param tile_size: The size of the trap tile.
    :param channels: Which channels to fetch, indexed from 0.
        If None, defaults to [0]
    :param z: Which z_stacks to fetch, indexed from 0.
        If None, defaults to [0].
    :return: The per-time-point tiles joined with ``np.hstack``.
    """
    # Set the defaults (list is mutable)
    if channels is None:
        channels = [0]
    if z is None:
        z = [0]

    max_shape = (raw_expt.shape[2], raw_expt.shape[3])
    n_timepoints = len(trap_locations)

    # Bounds of this trap's tile at every time point.
    bounds_per_tp = []
    for tp in range(n_timepoints):
        trap_centre = trap_locations[tp][trap_id]
        bounds_per_tp.append(
            get_tile_shapes((trap_centre[0], trap_centre[1]), tile_size, max_shape)
        )

    timelapse = []
    for tp, (xmin, xmax, ymin, ymax) in enumerate(bounds_per_tp):
        frame = raw_expt[channels, tp, :, :, z]
        timelapse.append(get_xy_tile(frame, xmin, xmax, ymin, ymax, pad_val=None))
    return np.hstack(timelapse)
def get_trap_timelapse_omero(
    raw_expt, trap_locations, trap_id, tile_size=117, channels=None, z=None, t=None
):
    """
    Get a timelapse for a given trap by specifying the trap_id

    :param raw_expt: A Timelapse object from which data is obtained
    :param trap_locations: Trap locations, indexed by time point and trap id.
    :param trap_id: An integer defining which trap to choose.
    :param tile_size: The size of the trap tile.
    :param channels: Which channels to fetch, indexed from 0.
        If None, defaults to [0]
    :param z: Which z_stacks to fetch, indexed from 0.
        If None, defaults to [0].
    :param t: Which time points to fetch; any sequence or array of indices.
        If None, defaults to every time point of ``raw_expt``.
    :return: A numpy array with the timelapse in (C,T,X,Y,Z) order; parts of
        a tile that fall outside the image are left as NaN.
    """
    # Set the defaults (list is mutable)
    channels = channels if channels is not None else [0]
    z_positions = z if z is not None else [0]
    # Normalise to an array so that plain lists are accepted as well; the
    # lookup list is built once instead of on every loop iteration.
    times = (
        np.asarray(t) if t is not None else np.arange(raw_expt.shape[1])
    )  # TODO choose sub-set of time points
    time_list = times.tolist()
    shape = (len(channels), len(times), tile_size, tile_size, len(z_positions))
    # Get trap location for that id:
    zct_tiles, slices, trap_ids = all_tiles(
        trap_locations, shape, raw_expt, z_positions, channels, times, [trap_id]
    )

    # TODO Make this an explicit function in TimelapseOMERO
    images = raw_expt.pixels.getTiles(zct_tiles)
    timelapse = np.full(shape, np.nan)
    total = len(zct_tiles)
    for (z, c, t, _), (y, x), image in tqdm(
        zip(zct_tiles, slices, images), total=total
    ):
        ch = channels.index(c)
        tp = time_list.index(t)
        z_pos = z_positions.index(z)
        timelapse[ch, tp, x[0] : x[1], y[0] : y[1], z_pos] = image

    # for x in timelapse:  # By channel
    #    np.nan_to_num(x, nan=np.nanmedian(x), copy=False)
    return timelapse
def all_tiles(trap_locations, shape, raw_expt, z_positions, channels, times, traps):
    """List the tiles to request for every (z, channel, time, trap) combination.

    :param trap_locations: Trap centres, indexed by time point and trap id.
    :param shape: Five-element output shape; entries 2 and 3 give the tile
        extents.
    :param raw_expt: Object whose five-element ``shape`` gives the image
        limits in entries 2 and 3.
    :return: Three parallel lists: ``(z, channel, t, tile)`` requests, the
        position of the available data inside each tile, and the trap id of
        each request.
    """
    _, _, tile_x, tile_y, _ = shape
    _, _, limit_x, limit_y, _ = raw_expt.shape

    requests, placements, request_trap_ids = [], [], []
    for z in z_positions:
        for ch in channels:
            for t in times:
                for trap_id in traps:
                    bounds = tile_where(
                        trap_locations[t][trap_id], tile_x, tile_y, limit_x, limit_y
                    )
                    xmin, ymin, xmax, ymax, r_xmin, r_ymin, r_xmax, r_ymax = bounds
                    # Where the available part sits relative to the tile origin.
                    y_span = (r_ymin - ymin, r_ymax - ymin)
                    x_span = (r_xmin - xmin, r_xmax - xmin)
                    placements.append((y_span, x_span))
                    tile = (r_ymin, r_xmin, r_ymax - r_ymin, r_xmax - r_xmin)
                    requests.append((z, ch, t, tile))
                    request_trap_ids.append(trap_id)  # So we remember the order!
    return requests, placements, request_trap_ids
def tile_where(centre, x, y, MAX_X, MAX_Y):
    """Return the requested and the available bounds of a tile.

    :param centre: Tile centre; ``centre[1]`` is used for x and
        ``centre[0]`` for y.
    :param x: Tile extent along x.
    :param y: Tile extent along y.
    :param MAX_X: Upper limit along x.
    :param MAX_Y: Upper limit along y.
    :return: ``(xmin, ymin, xmax, ymax, r_xmin, r_ymin, r_xmax, r_ymax)``
        where the ``r_`` values are the requested bounds clipped to
        ``[0, MAX]``.
    """
    # Requested window
    xmin = int(centre[1] - x // 2)
    ymin = int(centre[0] - y // 2)
    xmax, ymax = xmin + x, ymin + y
    # Window actually available inside the limits
    clipped = (max(0, xmin), max(0, ymin), min(MAX_X, xmax), min(MAX_Y, ymax))
    return (xmin, ymin, xmax, ymax) + clipped
def get_tile(shape, center, raw_expt, ch, t, z):
    """Returns a tile from the raw experiment with a given shape.

    :param shape: Five-element shape of the tile; entries 2 and 3 are the
        tile extents along the third and fourth axis of ``raw_expt``.
    :param center: Centre of the tile; ``center[1]`` positions the third
        axis and ``center[0]`` the fourth.
    :param raw_expt: Five-dimensional data to cut the tile from.
    :param ch: Index for the first axis of ``raw_expt``.
    :param t: Index for the second axis of ``raw_expt``.
    :param z: Index for the last axis of ``raw_expt``.
    :return: Array of ``shape``; positions outside ``raw_expt`` stay NaN.
    """
    _, _, size_x, size_y, _ = shape
    _, _, limit_x, limit_y, _ = raw_expt.shape
    tile = np.full(shape, np.nan)

    # Requested window
    xmin = int(center[1] - size_x // 2)
    ymin = int(center[0] - size_y // 2)
    xmax, ymax = xmin + size_x, ymin + size_y

    # Part of the window that the data actually covers
    avail_x = slice(max(0, xmin), min(limit_x, xmax))
    avail_y = slice(max(0, ymin), min(limit_y, ymax))

    # Same part, expressed relative to the tile origin
    dest_x = slice(avail_x.start - xmin, avail_x.stop - xmin)
    dest_y = slice(avail_y.start - ymin, avail_y.stop - ymin)

    # Fill values
    tile[:, :, dest_x, dest_y, :] = raw_expt[ch, t, avail_x, avail_y, z]
    # fill_val = np.nanmedian(tile)
    # np.nan_to_num(tile, nan=fill_val, copy=False)
    return tile
def get_traps_timepoint(
    raw_expt, trap_locations, tp, tile_size=96, channels=None, z=None
):
    """
    Get all the traps from a given time point

    :param raw_expt: Timelapse object from which data is obtained; its
        five-element ``shape`` ends with the number of z-sections.
    :param trap_locations: Trap locations, indexed by time point and trap id.
    :param tp: Time point to fetch.
    :param tile_size: The size of the trap tile.
    :param channels: Which channels to fetch; if None, defaults to [0].
    :param z: Which z_stacks to fetch: a single index, a sequence of indices
        or a slice. If None, defaults to [0].
    :return: A numpy array with the traps in the (trap, C, T, X, Y,
        Z) order
    """
    # Set the defaults (list is mutable)
    channels = channels if channels is not None else [0]
    z_positions = z if z is not None else [0]
    # Always work with an explicit list of z indices so that the number of
    # z-sections matches what is actually requested.
    if isinstance(z_positions, slice):
        # slice is not iterable; expand it, honouring start and step
        z_positions = list(range(*z_positions.indices(raw_expt.shape[4])))
    elif isinstance(z_positions, (int, np.integer)):
        z_positions = [int(z_positions)]
    else:
        z_positions = list(z_positions)
    n_z = len(z_positions)

    n_traps = len(trap_locations[tp])
    trap_ids = list(range(n_traps))

    shape = (len(channels), 1, tile_size, tile_size, n_z)
    # all tiles
    zct_tiles, slices, trap_ids = all_tiles(
        trap_locations, shape, raw_expt, z_positions, channels, [tp], trap_ids
    )

    # TODO Make this an explicit function in TimelapseOMERO
    images = raw_expt.pixels.getTiles(zct_tiles)
    # Initialise empty traps
    traps = np.full((n_traps,) + shape, np.nan)
    for trap_id, (z, c, _, _), (y, x), image in zip(
        trap_ids, zct_tiles, slices, images
    ):
        ch = channels.index(c)
        z_pos = z_positions.index(z)
        traps[trap_id, ch, 0, x[0] : x[1], y[0] : y[1], z_pos] = image
    # Fill whatever is still missing with the median of the same trap.
    for x in traps:  # By trap
        np.nan_to_num(x, nan=np.nanmedian(x), copy=False)
    return traps
def centre(img, percentage=0.3):
    """Return the central crop of a 2D image.

    :param img: Two-dimensional array.
    :param percentage: Fraction of each side to keep (rounded up).
    :return: The crop, centred on the middle of the image.
    """
    height, width = img.shape
    crop_w = int(np.ceil(width * percentage))
    crop_h = int(np.ceil(height * percentage))
    left = int(width // 2 - (crop_w // 2))
    top = int(height // 2 - (crop_h // 2))
    rows = slice(top, top + crop_h)
    cols = slice(left, left + crop_w)
    return img[rows, cols]
def align_timelapse_images(
    raw_data, channel=0, reference_reset_time=80, reference_reset_drift=25
):
    """
    Uses image registration to measure the drift between consecutive images.

    Only the central crop (see ``centre``) of the first z-section of
    ``channel`` is registered, and every time point is registered against
    the one immediately before it.

    :param raw_data: Data indexed as ``[channel, timepoint, :, :, z]``.
    :param channel: index of the channel to use for image registration.
    :param reference_reset_time: Currently unused.
    :param reference_reset_drift: If a shift differs from the previous one
        by more than this along any axis, it is taken to be inaccurate and
        the previous shift is used instead.
    :return: Array with one shift per time point; the first is ``[0, 0]``.
    """
    # ``feature.register_translation`` has been removed from scikit-image;
    # this is its replacement. Imported here to keep the block self-contained.
    from skimage.registration import phase_cross_correlation

    ref = centre(np.squeeze(raw_data[channel, 0, :, :, 0]))
    size_t = raw_data.shape[1]

    drift = [np.array([0, 0])]
    for i in range(1, size_t):
        img = centre(np.squeeze(raw_data[channel, i, :, :, 0]))

        shifts, _, _ = phase_cross_correlation(ref, img)
        # If a huge move is detected at a single time point it is taken
        # to be inaccurate and the correction from the previous time point
        # is used.
        # This might be common if there is a focus loss for example.
        if any(abs(x - y) > reference_reset_drift for x, y in zip(shifts, drift[-1])):
            shifts = drift[-1]

        drift.append(shifts)
        ref = img

        # TODO test necessity for references, description below
        #   If the images have drifted too far from the reference or too
        #   much time has passed we change the reference and keep track of
        #   which images are kept as references
    return np.stack(drift)
import itertools
import logging
import h5py
import numpy as np
from pathlib import Path
from tqdm import tqdm
import cv2
from aliby.io.matlab import matObject
from agora.io.utils import Cache, imread, get_store_path
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)
def parse_local_fs(pos_dir, tp=None):
    """
    Map the image files of one position directory by time point, channel and z.

    Local file structure:
    - pos_dir
        -- exptID_{timepointID}_{ChannelID}_{z_position_id}.png

    :param pos_dir: Directory holding the images of one position.
    :param tp: Optional container of zero-based time points to keep.
    :return: ``{timepoint: {channel: {z_order: path}}}`` where ``timepoint``
        is the file's time point ID minus one and ``z_order`` enumerates the
        files sorted by their z field.
    """
    pos_dir = Path(pos_dir)

    def name_fields(img_name):
        return img_name.stem.split("_")

    def channel_idx(img_name):
        return name_fields(img_name)[-2]

    def tp_idx(img_name):
        return int(name_fields(img_name)[-3]) - 1

    def z_idx(img_name):
        return name_fields(img_name)[-1]

    if tp is None:
        img_list = list(pos_dir.iterdir())
    else:
        img_list = [img for img in pos_dir.iterdir() if tp_idx(img) in tp]

    img_mapper = {}
    # groupby only merges adjacent items, hence the sort by the same key.
    by_timepoint = itertools.groupby(sorted(img_list, key=tp_idx), key=tp_idx)
    for timepoint, tp_files in by_timepoint:
        by_channel = itertools.groupby(
            sorted(tp_files, key=channel_idx), key=channel_idx
        )
        channel_map = {}
        for channel, channel_files in by_channel:
            channel_map[channel] = dict(enumerate(sorted(channel_files, key=z_idx)))
        img_mapper[int(timepoint)] = channel_map
    return img_mapper
class Timelapse:
    """
    Timelapse class contains the specifics of one position.

    Subclasses fill in the private size/name attributes, provide an
    ``image_cache`` and implement ``get_hypercube``.
    """

    def __init__(self):
        self._id = None
        self._name = None
        self._channels = []
        self._size_c = 0
        self._size_t = 0
        self._size_x = 0
        self._size_y = 0
        self._size_z = 0
        self.image_cache = None
        self.annotation = None

    def __repr__(self):
        # ``__repr__`` must return a string; the name is None until a
        # subclass sets it.
        return str(self.name)

    def full_mask(self):
        """Return an all-False boolean array with the shape of the position."""
        return np.full(self.shape, False)

    def __getitem__(self, item):
        """Return the data for ``item``, loading it when the cache has gaps.

        If the cached values contain any NaN, the whole item is reloaded
        with ``load_fn``, written back to the cache and returned.
        """
        cached = self.image_cache[item]
        # Check if there are missing values, if so reload
        # TODO only reload missing
        mask = np.isnan(cached)
        if np.any(mask):
            full = self.load_fn(item)
            # TODO speed this up by recognising the shape from the item
            shape = cached.shape
            self.image_cache[item] = np.reshape(full, shape)
            return full
        return cached

    def get_hypercube(
        self, x=None, y=None, z_positions=None, channels=None, timepoints=None
    ):
        """Fetch image data; must be implemented by subclasses.

        The signature mirrors the keyword arguments ``load_fn`` passes.
        """
        raise NotImplementedError("get_hypercube must be implemented by subclasses")

    def load_fn(self, item):
        """
        Translate an indexing ``item`` into a ``get_hypercube`` call.

        The hypercube is ordered as: C, T, X, Y, Z

        :param item: An int or slice (channels only), or a sequence with one
            entry per axis in the order above.
        :return: Whatever ``get_hypercube`` returns.
        """

        def parse_slice(s):
            step = s.step if s.step is not None else 1
            if s.start is None and s.stop is None:
                return None
            elif s.start is None and s.stop is not None:
                return range(0, s.stop, step)
            elif s.start is not None and s.stop is None:
                return [s.start]
            else:  # both s.start and s.stop are not None
                return range(s.start, s.stop, step)

        def parse_subitem(subitem, kw):
            if isinstance(subitem, (int, float)):
                res = [int(subitem)]
            elif isinstance(subitem, (list, tuple)):
                res = list(subitem)
            elif isinstance(subitem, slice):
                res = parse_slice(subitem)
            else:
                res = subitem
                # raise ValueError(f"Cannot parse slice {kw}: {subitem}")

            if kw in ["x", "y"]:
                # Need exactly two values
                if res is not None:
                    if len(res) < 2:
                        # A single value was passed: it is paired with size_x.
                        # NOTE(review): size_x is used for "y" as well --
                        # confirm whether size_y was intended there.
                        res = [res[0], self.size_x]
                    elif len(res) > 2:
                        res = [res[0], res[-1] + 1]
            return res

        if isinstance(item, int):
            return self.get_hypercube(
                x=None, y=None, z_positions=None, channels=[item], timepoints=None
            )
        elif isinstance(item, slice):
            return self.get_hypercube(channels=parse_slice(item))
        keywords = ["channels", "timepoints", "x", "y", "z_positions"]
        kwargs = dict()
        for kw, subitem in zip(keywords, item):
            kwargs[kw] = parse_subitem(subitem, kw)
        return self.get_hypercube(**kwargs)

    @property
    def shape(self):
        """Tuple ``(size_c, size_t, size_x, size_y, size_z)``."""
        return (self.size_c, self.size_t, self.size_x, self.size_y, self.size_z)

    @property
    def id(self):
        return self._id

    @property
    def name(self):
        return self._name

    @property
    def size_z(self):
        return self._size_z

    @property
    def size_c(self):
        return self._size_c

    @property
    def size_t(self):
        return self._size_t

    @property
    def size_x(self):
        return self._size_x

    @property
    def size_y(self):
        return self._size_y

    @property
    def channels(self):
        return self._channels

    def get_channel_index(self, channel):
        """Return the position of ``channel`` in ``channels``."""
        return self.channels.index(channel)
def load_annotation(filepath: Path):
    """Load an annotation file as a ``matObject``.

    :param filepath: Path of the annotation file.
    :return: The ``matObject`` built from ``filepath``.
    :raises ValueError: If the file could not be loaded; the original error
        is attached as the cause.
    """
    try:
        return matObject(filepath)
    except Exception as e:
        # Only exception instances can be raised; raising the bare message
        # string would itself fail with a TypeError and hide the message.
        raise ValueError(
            "Could not load annotation file. \n"
            "Non MATLAB files currently unsupported"
        ) from e
class TimelapseOMERO(Timelapse):
    """
    Connected to an Image object which handles database I/O.

    :param image: Image object queried for its pixels, id, name, sizes and
        channel labels.
    :param annotation: Optional annotation file; loaded with
        ``load_annotation`` when not None.
    :param cache: Object whose ``require_dataset`` provides the image cache.
    :param kwargs: ``compression`` is forwarded to the cache dataset.
    """

    def __init__(self, image, annotation, cache, **kwargs):
        super().__init__()
        self.image = image
        # Pre-load pixels
        self.pixels = image.getPrimaryPixels()
        self._id = image.getId()
        self._name = image.getName()
        self._size_x = image.getSizeX()
        self._size_y = image.getSizeY()
        self._size_z = image.getSizeZ()
        self._size_c = image.getSizeC()
        self._size_t = image.getSizeT()
        self._channels = image.getChannelLabels()
        # Check whether there are file annotations for this position
        if annotation is not None:
            self.annotation = load_annotation(annotation)
        # Get an HDF5 dataset to use as a cache.
        self.image_cache = cache.require_dataset(
            self.name,
            self.shape,
            dtype=np.float16,
            fillvalue=np.nan,
            compression=kwargs.get("compression", None),
        )

    def get_hypercube(
        self, x=None, y=None, z_positions=None, channels=None, timepoints=None
    ):
        """Fetch planes through ``pixels.getTiles`` and return them in C, T, X, Y, Z order.

        :param x: ``(min, max)`` pair along x, or None for no restriction.
        :param y: ``(min, max)`` pair along y, or None for no restriction.
        :param z_positions: z indices; None means all, empty means ``[0]``.
        :param channels: Channel indices; None means all, empty means ``[0]``.
        :param timepoints: Time points; None means all, empty means ``[0]``.
        """
        # Tile specification: (x origin, y origin, width, height)
        if x is None and y is None:
            tile = None  # Get full plane
        elif x is None:
            y_lo, y_hi = y
            tile = (None, y_lo, None, y_hi - y_lo)
        elif y is None:
            x_lo, x_hi = x
            tile = (x_lo, None, x_hi - x_lo, None)
        else:
            x_lo, x_hi = x
            y_lo, y_hi = y
            tile = (x_lo, y_lo, x_hi - x_lo, y_hi - y_lo)

        # None selects every index; an empty selection falls back to [0].
        z_positions = (
            range(self.size_z) if z_positions is None else z_positions
        ) or [0]
        channels = (range(self.size_c) if channels is None else channels) or [0]
        timepoints = (range(self.size_t) if timepoints is None else timepoints) or [0]

        requests = [
            (z, c, t, tile)
            for z, c, t in itertools.product(z_positions, channels, timepoints)
        ]
        planes = list(self.pixels.getTiles(requests))
        first_plane = planes[0]
        stacked_shape = (
            len(z_positions),
            len(channels),
            len(timepoints),
            first_plane.shape[-2],
            first_plane.shape[-1],
        )
        result = np.stack(planes).reshape(stacked_shape)
        # Set to C, T, X, Y, Z order: swap the two plane axes, then send the
        # leading z axis to the end.
        swapped = np.moveaxis(result, -1, -2)
        return np.moveaxis(swapped, 0, -1)

    def cache_set(self, save_dir, timepoints, expt_name, quiet=True):
        """Write every plane of the given time points as a PNG under ``save_dir``.

        :return: List of ``(position name, timepoint)`` pairs.
        """
        # TODO deprecate when this is default
        pos_dir = save_dir / self.name
        if not pos_dir.exists():
            pos_dir.mkdir()
        for tp in tqdm(timepoints, desc=self.name):
            for channel in tqdm(self.channels, disable=quiet):
                ch_id = self.get_channel_index(channel)
                for z_pos in tqdm(range(self.size_z), disable=quiet):
                    image = self.get_hypercube(
                        x=None,
                        y=None,
                        channels=[ch_id],
                        z_positions=[z_pos],
                        timepoints=[tp],
                    )
                    # Time point and z are one-based in the file name.
                    im_name = "{}_{:06d}_{}_{:03d}.png".format(
                        expt_name, tp + 1, channel, z_pos + 1
                    )
                    cv2.imwrite(str(pos_dir / im_name), np.squeeze(image))
        # TODO update positions table to get the number of timepoints?
        return list(itertools.product([self.name], timepoints))

    def run(self, keys, store, save_dir="./", **kwargs):
        """
        Parse file structure and get images for the timepoints in keys.

        Creates an empty position-specific store file when it does not exist
        yet and returns ``keys`` unchanged (None when ``keys`` is None).
        """
        save_dir = Path(save_dir)
        if keys is None:
            # TODO save final metadata
            return None
        # A position specific store
        store = save_dir / store
        store = store.with_name(self.name + store.name)
        # Create store if it does not exist
        if not store.exists():
            # The first run, add metadata to the store
            with h5py.File(store, "w") as pos_store:
                # TODO Add metadata to the store.
                pass
        # TODO check how sensible the keys are with what is available
        #   if some of the keys don't make sense, log a warning and remove
        #   them so that the next steps of the pipeline make sense
        return keys

    def clear_cache(self):
        """Call ``clear()`` on the image cache."""
        # NOTE(review): confirm that the object returned by
        # ``require_dataset`` actually provides ``clear()``.
        self.image_cache.clear()
class TimelapseLocal(Timelapse):
    """
    Timelapse linked to a local directory containing the images for one
    position in an experiment.
    """

    def __init__(
        self, position, root_dir, finished=True, annotation=None, cache=None, **kwargs
    ):
        """
        Linked to a local directory containing the images for one position
        in an experiment.

        Can be a still running experiment or a finished one.

        :param position: Name of the position
        :param root_dir: Root directory
        :param finished: Whether the experiment has finished running or the
            class will be used as part of a pipeline, mostly with calls to `run`
        :param annotation: annotation passed to ``load_annotation``, or None
        :param cache: object providing ``require_dataset``; when None no
            image cache is created
        :param kwargs: only ``compression`` is read (default None)
        """
        super().__init__()
        self.pos_dir = Path(root_dir) / position
        assert self.pos_dir.exists()
        self._id = position
        self._name = position
        if finished:
            self.image_mapper = parse_local_fs(self.pos_dir)
            self._update_metadata()
        else:
            self.image_mapper = dict()
        self.annotation = None
        # Check whether there are file annotations for this position
        if annotation is not None:
            self.annotation = load_annotation(annotation)
        compression = kwargs.get("compression", None)
        # `cache` defaults to None, so only build the dataset when a cache
        # was actually supplied instead of failing with AttributeError.
        if cache is None:
            self.image_cache = None
        else:
            self.image_cache = cache.require_dataset(
                self.name,
                self.shape,
                dtype=np.float16,
                fillvalue=np.nan,
                compression=compression,
            )

    def _update_metadata(self):
        """
        Recompute the size and channel metadata from ``self.image_mapper``.

        When the mapper is empty only the number of timepoints is updated
        (to 0); the remaining metadata is left untouched.
        """
        self._size_t = len(self.image_mapper)
        if not self.image_mapper:
            # Nothing has been parsed yet: there are no channels or images
            # from which to derive the remaining metadata.
            return
        # Todo: if cy5 is the first one it causes issues with getting x, y
        #   hence the sorted but it's not very robust
        self._channels = sorted(
            set().union(*(tp.keys() for tp in self.image_mapper.values()))
        )
        self._size_c = len(self._channels)
        # Todo: refactor so we don't rely on there being any images at all
        self._size_z = max(len(self.image_mapper[0][ch]) for ch in self._channels)
        # Read the first channel/timepoint to find the in-plane dimensions
        single_img = self.get_hypercube(
            x=None, y=None, z_positions=None, channels=[0], timepoints=[0]
        )
        self._size_x = single_img.shape[2]
        self._size_y = single_img.shape[3]

    def get_hypercube(
        self, x=None, y=None, z_positions=None, channels=None, timepoints=None
    ):
        """
        Read images from disk and return them as an array in C, T, X, Y, Z
        order.

        :param x: ``(xmin, xmax)`` slice bounds on the first image axis, or
            None for no cropping
        :param y: ``(ymin, ymax)`` slice bounds on the second image axis, or
            None for no cropping
        :param z_positions: z indices to read; None means all of them
        :param channels: channel indices to read; None means all of them
        :param timepoints: time indices to read; None means all of them
        :return: array with axes ordered C, T, X, Y, Z
        """
        xmin, xmax = x if x is not None else (None, None)
        ymin, ymax = y if y is not None else (None, None)

        if z_positions is None:
            z_positions = range(self.size_z)
        if channels is None:
            channels = range(self.size_c)
        if timepoints is None:
            timepoints = range(self.size_t)

        def z_pos_getter(z_positions, ch_id, t):
            # A z position with no file on disk is filled with zeros
            default = np.zeros((self.size_x, self.size_y))
            names = [
                self.image_mapper[t][self.channels[ch_id]].get(i, None)
                for i in z_positions
            ]
            return [imread(name) if name is not None else default for name in names]

        # nested list of images in C, T, X, Y, Z order
        ctxyz = []
        for ch_id in channels:
            txyz = []
            for t in timepoints:
                xyz = z_pos_getter(z_positions, ch_id, t)
                # Stack the z planes along a third axis, then crop in-plane
                txyz.append(np.dstack(xyz)[xmin:xmax, ymin:ymax])
            ctxyz.append(np.stack(txyz))
        return np.stack(ctxyz)

    def clear_cache(self):
        """Call ``clear()`` on the image cache, if one was created."""
        if self.image_cache is not None:
            self.image_cache.clear()

    def run(self, keys, store, save_dir="./", **kwargs):
        """
        Parse file structure and get images for the time points in keys.

        Creates an empty position-specific store file on first use.

        :param keys: time point or list of time points; None returns None
            immediately, a single int is wrapped in a list
        :param store: store name handed to ``get_store_path``
        :param save_dir: directory handed to ``get_store_path``
        :return: the list of keys, or None when ``keys`` is None
        """
        if keys is None:
            return None
        if isinstance(keys, int):
            keys = [keys]
        self.image_mapper.update(parse_local_fs(self.pos_dir, tp=keys))
        self._update_metadata()
        # Create store if it does not exist
        store = get_store_path(save_dir, store, self.name)
        if not store.exists():
            # The first run, add metadata to the store
            with h5py.File(store, "w"):
                # TODO Add metadata to the store.
                pass
        # TODO check how sensible the keys are with what is available
        #   if some of the keys don't make sense, log a warning and remove
        #   them so that the next steps of the pipeline make sense
        return keys
This diff is collapsed.
{"host": "sce-bio-c04287.bio.ed.ac.uk", "password": "***REMOVED***", "port": 4064,
"user": "upload", "experiment": 10932}
# Modified from https://ome.github.io/training-docker/12-dockercompose/
# Compose file defining three services (database, omeroserver, omeroweb)
# that are all attached to the same network, omero-network.
# NOTE(review): the passwords below are hard-coded in plain text; presumably
# this stack is only meant for local testing — confirm before reusing it.
version: "3"
services:
# PostgreSQL database; omeroserver points at it via CONFIG_omero_db_host.
database:
image: "postgres:11"
environment:
- POSTGRES_USER=omero
- POSTGRES_DB=omero
- POSTGRES_PASSWORD=SeCrEtPaSsWoRd
networks:
- omero-network
volumes:
- "database-volume:/var/lib/postgresql/data"
# OMERO server; its database user, password and name match the values set
# on the database service above.
omeroserver:
image: "openmicroscopy/omero-server:5.6"
environment:
- CONFIG_omero_db_host=database
- CONFIG_omero_db_user=omero
- CONFIG_omero_db_pass=SeCrEtPaSsWoRd
- CONFIG_omero_db_name=omero
- ROOTPASS=omero-root-password
networks:
- omero-network
# Ports published to the host
ports:
- "4063:4063"
- "4064:4064"
volumes:
- "omero-volume:/OMERO"
# OMERO web client, configured to talk to the omeroserver service.
omeroweb:
image: "openmicroscopy/omero-web-standalone:master"
environment:
- OMEROHOST=omeroserver
networks:
- omero-network
ports:
- "4080:4080"
# Network shared by all three services
networks:
omero-network:
# Named volumes referenced by the database and omeroserver services
volumes:
database-volume:
omero-volume:
numpydoc>=1.3.1
aliby[network]>=0.1.43
sphinx-autodoc-typehints==1.19.2
sphinx-rtd-theme==1.0.0
sphinxcontrib-applehelp==1.0.2
sphinxcontrib-devhelp==1.0.2
sphinxcontrib-htmlhelp==2.0.0
sphinxcontrib-jsmath==1.0.1
sphinxcontrib-qthelp==1.0.3
sphinxcontrib-serializinghtml==1.1.5
myst-parser
This diff is collapsed.
# Running the analysis pipeline
You can run the analysis pipeline either via the command line interface (CLI) or using a script that incorporates the `aliby.pipeline.Pipeline` object.
## CLI
On a CLI, you can use the `aliby-run` command. This command takes options as follows:
- `--host`: Address of image-hosting server.
- `--username`: Username to access image-hosting server.
- `--password`: Password to access image-hosting server.
- `--expt_id`: Number ID of experiment stored on host server.
- `--distributed`: Number of distributed cores to use for segmentation and signal processing. If 0, there is no parallelisation.
- `--tps`: Optional. Number of time points from the beginning of the experiment to use. If not specified, the pipeline processes all time points.
- `--directory`: Optional. Parent directory to save the data files (HDF5) generated, `./data` by default; the files will be stored in a child directory whose name is the name of the experiment.
- `--filter`: Optional. List of positions to use for analysis. Alternatively, a regex (regular expression) or list of regexes to search for positions. **Note: the CLI currently cannot take a list of strings as input.**
- `--overwrite`: Optional. Whether to overwrite an existing data directory. True by default.
- `--override_meta`: Optional. Whether to override existing metadata. True by default.
Example usage:
```bash
aliby-run --expt_id EXPT_PATH --distributed 4 --tps None
```
To run on an experiment hosted on an OMERO server, the basic arguments are:
```bash
aliby-run --expt_id XXX --host SERVER.ADDRESS --username USER --password PASSWORD
```
## Script
Use the `aliby.pipeline.Pipeline` object and supply a dictionary, following the example below. The meanings of the parameters are the same as described in the CLI section above.
```python
#!/usr/bin/env python3
from aliby.pipeline import Pipeline, PipelineParameters
# Specify experiment IDs
ids = [101, 102]
for i in ids:
print(i)
try:
params = PipelineParameters.default(
# Create dictionary to define pipeline parameters.
general={
"expt_id": i,
"distributed": 6,
"host": "INSERT ADDRESS HERE",
"username": "INSERT USERNAME HERE",
"password": "INSERT PASSWORD HERE",
# Ensure data will be overwritten
"override_meta": True,
"overwrite": True,
}
)
# Fine-grained control beyond general parameters:
# change specific leaf in the extraction tree.
# This example tells the pipeline to additionally compute the
# nuc_est_conv quantity, which is a measure of the degree of
# localisation of a signal in a cell.
params = params.to_dict()
leaf_to_change = params["extraction"]["tree"]["GFP"]["np_max"]
leaf_to_change.add("nuc_est_conv")
# Regenerate PipelineParameters
p = Pipeline(PipelineParameters.from_dict(params))
# Run pipeline
p.run()
# Error handling
except Exception as e:
print(e)
```
This example code can be the contents of a `run.py` file, and you can run it via
```bash
python run.py
```
in the appropriate virtual environment.
Alternatively, the example code can be the contents of a cell in a jupyter notebook.
This diff is collapsed.
This diff is collapsed.
..
DO NOT DELETE THIS FILE! It contains the all-important `.. autosummary::` directive with `:recursive:` option, without
which API documentation wouldn't get extracted from docstrings by the `sphinx.ext.autosummary` engine. It is hidden
(not declared in any toctree) to remove an unnecessary intermediate page; index.rst instead points directly to the
package page. DO NOT REMOVE THIS FILE!
.. autosummary::
:toctree: _autosummary
:template: custom-module-template.rst
:recursive:
aliby
agora
extraction
postprocessor
logfile_parser
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.