Skip to content
Snippets Groups Projects
Commit 17b9d3e5 authored by Alán Muñoz's avatar Alán Muñoz
Browse files

add agora and cells

parent c92f6cd9
No related branches found
No related tags found
No related merge requests found
import logging
from pathlib import Path, PosixPath
from time import perf_counter
from typing import Union
from itertools import groupby
from collections.abc import Iterable
from utils_find_1st import find_1st, cmp_equal
import h5py
import numpy as np
from scipy import ndimage
from scipy.sparse.base import isdense
from agora.io.writer import load_complex
def cell_factory(store, type="hdf5"):
if type == "hdf5":
return CellsHDF(store)
else:
raise TypeError(
"Could not get cells for type {}:" "valid types are matlab and hdf5"
)
class Cells:
"""An object that gathers information about all the cells in a given
trap.
This is the abstract object, used for type testing
"""
def __init__(self):
pass
@staticmethod
def from_source(source: Union[PosixPath, str], kind: str = None):
if isinstance(source, str):
source = Path(source)
if kind is None: # Infer kind from filename
kind = "matlab" if source.suffix == ".mat" else "hdf5"
return cell_factory(source, kind)
@staticmethod
def _asdense(array):
if not isdense(array):
array = array.todense()
return array
@staticmethod
def _astype(array, kind):
# Convert sparse arrays if needed and if kind is 'mask' it fills the outline
array = Cells._asdense(array)
if kind == "mask":
array = ndimage.binary_fill_holes(array).astype(int)
return array
@classmethod
def hdf(cls, fpath):
return CellsHDF(fpath)
@classmethod
def mat(cls, path):
return CellsMat(matObject(store))
class CellsHDF(Cells):
def __init__(self, filename, path="cell_info"):
self.filename = filename
self.cinfo_path = path
self._edgem_indices = None
self._edgemasks = None
self._tile_size = None
def __getitem__(self, item):
if item == "edgemasks":
return self.edgemasks
_item = "_" + item
if not hasattr(self, _item):
setattr(self, _item, self._fetch(item))
return getattr(self, _item)
def _get_idx(self, cell_id, trap_id):
return (self["cell_label"] == cell_id) & (self["trap"] == trap_id)
def _fetch(self, path):
with h5py.File(self.filename, mode="r") as f:
return f[self.cinfo_path][path][()]
@property
def ntraps(self):
with h5py.File(self.filename, mode="r") as f:
return len(f["/trap_info/trap_locations"][()])
@property
def traps(self):
return list(set(self["trap"]))
@property
def tile_size(self): # TODO read from metadata
if self._tile_size is None:
with h5py.File(self.filename, mode="r") as f:
self._tile_size == f["trap_info/tile_size"][0]
return self._tile_size
@property
def edgem_indices(self):
if self._edgem_indices is None:
edgem_path = "edgemasks/indices"
self._edgem_indices = load_complex(self._fetch(edgem_path))
return self._edgem_indices
@property
def edgemasks(self):
if self._edgemasks is None:
edgem_path = "edgemasks/values"
self._edgemasks = self._fetch(edgem_path)
return self._edgemasks
def _edgem_where(self, cell_id, trap_id):
ix = trap_id + 1j * cell_id
return find_1st(self.edgem_indices == ix, True, cmp_equal)
@property
def labels(self):
"""
Return all cell labels in object
We use mother_assign to list traps because it is the only propriety that appears even
when no cells are found"""
return [self.labels_in_trap(trap) for trap in self.traps]
def where(self, cell_id, trap_id):
"""
Returns
Parameters
----------
cell_id: int
Cell index
trap_id: int
Trap index
Returns
----------
indices int array
boolean mask array
edge_ix int array
"""
indices = self._get_idx(cell_id, trap_id)
edgem_ix = self._edgem_where(cell_id, trap_id)
return (
self["timepoint"][indices],
indices,
edgem_ix,
) # FIXME edgem_ix makes output different to matlab's Cell
def outline(self, cell_id, trap_id):
times, indices, cell_ix = self.where(cell_id, trap_id)
return times, self["edgemasks"][cell_ix, times]
def mask(self, cell_id, trap_id):
times, outlines = self.outline(cell_id, trap_id)
return times, np.array(
[ndimage.morphology.binary_fill_holes(o) for o in outlines]
)
def at_time(self, timepoint, kind="mask"):
ix = self["timepoint"] == timepoint
cell_ix = self["cell_label"][ix]
traps = self["trap"][ix]
indices = traps + 1j * cell_ix
choose = np.in1d(self.edgem_indices, indices)
edgemasks = self["edgemasks"][choose, timepoint]
masks = [
self._astype(edgemask, kind) for edgemask in edgemasks if edgemask.any()
]
return self.group_by_traps(traps, masks)
def group_by_traps(self, traps, data):
# returns a dict with traps as keys and labels as value
iterator = groupby(zip(traps, data), lambda x: x[0])
d = {key: [x[1] for x in group] for key, group in iterator}
d = {i: d.get(i, []) for i in self.traps}
return d
def labels_in_trap(self, trap_id):
# Return set of cell ids in a trap.
return set((self["cell_label"][self["trap"] == trap_id]))
def labels_at_time(self, timepoint):
labels = self["cell_label"][self["timepoint"] == timepoint]
traps = self["trap"][self["timepoint"] == timepoint]
return self.group_by_traps(traps, labels)
class CellsMat(Cells):
def __init__(self, mat_object):
super(CellsMat, self).__init__()
# TODO add __contains__ to the matObject
timelapse_traps = mat_object.get(
"timelapseTrapsOmero", mat_object.get("timelapseTraps", None)
)
if timelapse_traps is None:
raise NotImplementedError(
"Could not find a timelapseTraps or "
"timelapseTrapsOmero object. Cells "
"from cellResults not implemented"
)
else:
self.trap_info = timelapse_traps["cTimepoint"]["trapInfo"]
if isinstance(self.trap_info, list):
self.trap_info = {
k: list([res.get(k, []) for res in self.trap_info])
for k in self.trap_info[0].keys()
}
def where(self, cell_id, trap_id):
times, indices = zip(
*[
(tp, np.where(cell_id == x)[0][0])
for tp, x in enumerate(self.trap_info["cellLabel"][:, trap_id].tolist())
if np.any(cell_id == x)
]
)
return times, indices
def outline(self, cell_id, trap_id):
times, indices = self.where(cell_id, trap_id)
info = self.trap_info["cell"][times, trap_id]
def get_segmented(cell, index):
if cell["segmented"].ndim == 0:
return cell["segmented"][()].todense()
else:
return cell["segmented"][index].todense()
segmentation_outline = [
get_segmented(cell, idx) for idx, cell in zip(indices, info)
]
return times, np.array(segmentation_outline)
def mask(self, cell_id, trap_id):
times, outlines = self.outline(cell_id, trap_id)
return times, np.array(
[ndimage.morphology.binary_fill_holes(o) for o in outlines]
)
def at_time(self, timepoint, kind="outline"):
"""Returns the segmentations for all the cells at a given timepoint.
FIXME: this is extremely hacky and accounts for differently saved
results in the matlab object. Deprecate ASAP.
"""
# Case 1: only one cell per trap: trap_info['cell'][timepoint] is a
# structured array
if isinstance(self.trap_info["cell"][timepoint], dict):
segmentations = [
self._astype(x, "outline")
for x in self.trap_info["cell"][timepoint]["segmented"]
]
# Case 2: Multiple cells per trap: it becomes a list of arrays or
# dictionaries, one for each trap
# Case 2.1 : it's a dictionary
elif isinstance(self.trap_info["cell"][timepoint][0], dict):
segmentations = []
for x in self.trap_info["cell"][timepoint]:
seg = x["segmented"]
if not isinstance(seg, np.ndarray):
seg = [seg]
segmentations.append([self._astype(y, "outline") for y in seg])
# Case 2.2 : it's an array
else:
segmentations = [
[self._astype(y, type) for y in x["segmented"]] if x.ndim != 0 else []
for x in self.trap_info["cell"][timepoint]
]
# Return dict for compatibility with hdf5 output
return {i: v for i, v in enumerate(segmentations)}
def labels_at_time(self, tp):
labels = self.trap_info["cellLabel"]
labels = [_aslist(x) for x in labels[tp]]
labels = {i: [lbl for lbl in lblset] for i, lblset in enumerate(labels)}
return labels
@property
def ntraps(self):
return len(self.trap_info["cellLabel"][0])
@property
def tile_size(self):
pass
class ExtractionRunner:
"""An object to run extraction of fluorescence, and general data out of
segmented data.
Configure with what extraction we want to run.
Cell selection criteria.
Filtering criteria.
"""
def __init__(self, tiler, cells):
pass
def run(self, keys, store, **kwargs):
pass
def _aslist(x):
if isinstance(x, Iterable):
if hasattr(x, "tolist"):
x = x.tolist()
else:
x = [x]
return x
"""Segment/segmented pipelines.
Includes splitting the image into traps/parts,
cell segmentation, nucleus segmentation."""
import warnings
from functools import lru_cache
import h5py
import numpy as np
from pathlib import Path, PosixPath
from skimage.registration import phase_cross_correlation
from agora.abc import ParametersABC, ProcessABC
from aliby.traps import segment_traps
from agora.io.writer import load_attributes
trap_template_directory = Path(__file__).parent / "trap_templates"
# TODO do we need multiple templates, one for each setup?
trap_template = np.array([]) # np.load(trap_template_directory / "trap_prime.npy")
def get_tile_shapes(x, tile_size, max_shape):
half_size = tile_size // 2
xmin = int(x[0] - half_size)
ymin = max(0, int(x[1] - half_size))
if xmin + tile_size > max_shape[0]:
xmin = max_shape[0] - tile_size
if ymin + tile_size > max_shape[1]:
ymin = max_shape[1] - tile_size
return xmin, xmin + tile_size, ymin, ymin + tile_size
###################### Dask versions ########################
class Trap:
def __init__(self, centre, parent, size, max_size):
self.centre = centre
self.parent = parent # Used to access drifts
self.size = size
self.half_size = size // 2
self.max_size = max_size
def padding_required(self, tp):
"""Check if we need to pad the trap image for this time point."""
try:
assert all(self.at_time(tp) - self.half_size >= 0)
assert all(self.at_time(tp) + self.half_size <= self.max_size)
except AssertionError:
return True
return False
def at_time(self, tp):
"""Return trap centre at time tp"""
drifts = self.parent.drifts
return self.centre - np.sum(drifts[:tp], axis=0)
def as_tile(self, tp):
"""Return trap in the OMERO tile format of x, y, w, h
Also returns the padding necessary for this tile.
"""
x, y = self.at_time(tp)
# tile bottom corner
x = int(x - self.half_size)
y = int(y - self.half_size)
return x, y, self.size, self.size
def as_range(self, tp):
"""Return trap in a range format, two slice objects that can be used in Arrays"""
x, y, w, h = self.as_tile(tp)
return slice(x, x + w), slice(y, y + h)
class TrapLocations:
def __init__(self, initial_location, tile_size, max_size=1200, drifts=[]):
self.tile_size = tile_size
self.max_size = max_size
self.initial_location = initial_location
self.traps = [
Trap(centre, self, tile_size, max_size) for centre in initial_location
]
self.drifts = drifts
@classmethod
def from_source(cls, fpath: str):
with h5py.File(fpath, "r") as f:
# TODO read tile size from file metadata
drifts = f["trap_info/drifts"][()]
tlocs = cls(f["trap_info/trap_locations"][()], tile_size=96, drifts=drifts)
return tlocs
@property
def shape(self):
return len(self.traps), len(self.drifts)
def __len__(self):
return len(self.traps)
def __iter__(self):
yield from self.traps
def padding_required(self, tp):
return any([trap.padding_required(tp) for trap in self.traps])
def to_dict(self, tp):
res = dict()
if tp == 0:
res["trap_locations"] = self.initial_location
res["attrs/tile_size"] = self.tile_size
res["attrs/max_size"] = self.max_size
res["drifts"] = np.expand_dims(self.drifts[tp], axis=0)
# res['processed_timepoints'] = tp
return res
@classmethod
def read_hdf5(cls, file):
with h5py.File(file, "r") as hfile:
trap_info = hfile["trap_info"]
initial_locations = trap_info["trap_locations"][()]
drifts = trap_info["drifts"][()]
max_size = trap_info.attrs["max_size"]
tile_size = trap_info.attrs["tile_size"]
trap_locs = cls(initial_locations, tile_size, max_size=max_size)
trap_locs.drifts = drifts
return trap_locs
class TilerParameters(ParametersABC):
def __init__(
self, tile_size: int, ref_channel: str, ref_z: int, template_name: str = None
):
self.tile_size = tile_size
self.ref_channel = ref_channel
self.ref_z = ref_z
self.template_name = template_name
@classmethod
def from_template(cls, template_name: str, ref_channel: str, ref_z: int):
return cls(template.shape[0], ref_channel, ref_z, template_path=template_name)
@classmethod
def default(cls):
return cls(96, "Brightfield", 0)
class Tiler(ProcessABC):
"""A dummy TimelapseTiler object fora Dask Demo.
Does trap finding and image registration."""
def __init__(
self,
image,
metadata,
parameters: TilerParameters,
):
super().__init__(parameters)
self.image = image
self.channels = metadata["channels"]
self.ref_channel = self.get_channel_index(parameters.ref_channel)
@classmethod
def from_image(cls, image, parameters: TilerParameters):
return cls(image.data, image.metadata, parameters)
@classmethod
def from_hdf5(cls, image, filepath, tile_size=None):
trap_locs = TrapLocations.read_hdf5(filepath)
metadata = load_attributes(filepath)
metadata["channels"] = metadata["channels/channel"].tolist()
if tile_size is None:
tile_size = trap_locs.tile_size
return Tiler(
image=image,
metadata=metadata,
template=None,
tile_size=tile_size,
trap_locs=trap_locs,
)
@lru_cache(maxsize=2)
def get_tc(self, t, c):
# Get image
full = self.image[t, c].compute() # FORCE THE CACHE
return full
@property
def shape(self):
c, t, z, y, x = self.image.shape
return (c, t, x, y, z)
@property
def n_processed(self):
if not hasattr(self, "_n_processed"):
self._n_processed = 0
return self._n_processed
@n_processed.setter
def n_processed(self, value):
self._n_processed = value
@property
def n_traps(self):
return len(self.trap_locs)
@property
def finished(self):
return self.n_processed == self.image.shape[0]
def _initialise_traps(self, tile_size):
"""Find initial trap positions.
Removes all those that are too close to the edge so no padding is necessary.
"""
half_tile = tile_size // 2
max_size = min(self.image.shape[-2:])
initial_image = self.image[
0, self.ref_channel, self.ref_z
] # First time point, first channel, first z-position
trap_locs = segment_traps(initial_image, tile_size)
trap_locs = [
[x, y]
for x, y in trap_locs
if half_tile < x < max_size - half_tile
and half_tile < y < max_size - half_tile
]
self.trap_locs = TrapLocations(trap_locs, tile_size)
def find_drift(self, tp):
# TODO check that the drift doesn't move any tiles out of the image, remove them from list if so
prev_tp = max(0, tp - 1)
drift, error, _ = phase_cross_correlation(
self.image[prev_tp, self.ref_channel, self.ref_z],
self.image[tp, self.ref_channel, self.ref_z],
)
self.trap_locs.drifts.append(drift)
def get_tp_data(self, tp, c):
traps = []
full = self.get_tc(tp, c)
# if self.trap_locs.padding_required(tp):
for trap in self.trap_locs:
ndtrap = self.ifoob_pad(full, trap.as_range(tp))
traps.append(ndtrap)
return np.stack(traps)
def get_trap_data(self, trap_id, tp, c):
full = self.get_tc(tp, c)
trap = self.trap_locs.traps[trap_id]
ndtrap = self.ifoob_pad(full, trap.as_range(tp))
return ndtrap
@staticmethod
def ifoob_pad(full, slices):
"""
Returns the slices padded if it is out of bounds
Parameters:
----------
full: (zstacks, max_size, max_size) ndarray
Entire position with zstacks as first axis
slices: tuple of two slices
Each slice indicates an axis to index
Returns
Trap for given slices, padded with median if needed, or np.nan if the padding is too much
"""
max_size = full.shape[-1]
y, x = [slice(max(0, s.start), min(max_size, s.stop)) for s in slices]
trap = full[:, y, x]
padding = np.array(
[(-min(0, s.start), -min(0, max_size - s.stop)) for s in slices]
)
if padding.any():
tile_size = slices[0].stop - slices[0].start
if (padding > tile_size / 4).any():
trap = np.full((full.shape[0], tile_size, tile_size), np.nan)
else:
trap = np.pad(trap, [[0, 0]] + padding.tolist(), "median")
return trap
def run_tp(self, tp):
assert tp >= self.n_processed, "Time point already processed"
# TODO check contiguity?
if self.n_processed == 0:
self._initialise_traps(self.tile_size)
self.find_drift(tp) # Get drift
# update n_processed
self.n_processed += 1
# Return result for writer
return self.trap_locs.to_dict(tp)
def run(self, tp):
if self.n_processed == 0:
self._initialise_traps(self.tile_size)
self.find_drift(tp) # Get drift
# update n_processed
self.n_processed += 1
# Return result for writer
return self.trap_locs.to_dict(tp)
# The next set of functions are necessary for the extraction object
def get_traps_timepoint(self, tp, tile_size=None, channels=None, z=None):
# FIXME we currently ignore the tile size
# FIXME can we ignore z(always give)
res = []
for c in channels:
val = self.get_tp_data(tp, c)[:, z] # Only return requested z
# positions
# Starts at traps, z, y, x
# Turn to Trap, C, T, X, Y, Z order
val = val.swapaxes(1, 3).swapaxes(1, 2)
val = np.expand_dims(val, axis=1)
res.append(val)
return np.stack(res, axis=1)
def get_channel_index(self, item):
for i, ch in enumerate(self.channels):
if item in ch:
return i
def get_position_annotation(self):
# TODO required for matlab support
return None
"""
A set of utilities for dealing with ALCATRAS traps
"""
import numpy as np
from tqdm import tqdm
from skimage import transform, feature
from skimage.filters.rank import entropy
from skimage.filters import threshold_otsu
from skimage.segmentation import clear_border
from skimage.measure import label, regionprops
from skimage.morphology import disk, closing, square
def stretch_image(image):
image = ((image - image.min()) / (image.max() - image.min())) * 255
minval = np.percentile(image, 2)
maxval = np.percentile(image, 98)
image = np.clip(image, minval, maxval)
image = (image - minval) / (maxval - minval)
return image
def segment_traps(image, tile_size, downscale=0.4):
# Make image go between 0 and 255
img = image # Keep a memory of image in case need to re-run
# stretched = stretch_image(image)
# img = stretch_image(image)
# TODO Optimise the hyperparameters
disk_radius = int(min([0.01 * x for x in img.shape]))
min_area = 0.2 * (tile_size ** 2)
if downscale != 1:
img = transform.rescale(image, downscale)
entropy_image = entropy(img, disk(disk_radius))
if downscale != 1:
entropy_image = transform.rescale(entropy_image, 1 / downscale)
# apply threshold
thresh = threshold_otsu(entropy_image)
bw = closing(entropy_image > thresh, square(3))
# remove artifacts connected to image border
cleared = clear_border(bw)
# label image regions
label_image = label(cleared)
areas = [
region.area
for region in regionprops(label_image)
if region.area > min_area and region.area < tile_size ** 2 * 0.8
]
traps = (
np.array(
[
region.centroid
for region in regionprops(label_image)
if region.area > min_area and region.area < tile_size ** 2 * 0.8
]
)
.round()
.astype(int)
)
ma = (
np.array(
[
region.minor_axis_length
for region in regionprops(label_image)
if region.area > min_area and region.area < tile_size ** 2 * 0.8
]
)
.round()
.astype(int)
)
maskx = (tile_size // 2 < traps[:, 0]) & (
traps[:, 0] < image.shape[0] - tile_size // 2
)
masky = (tile_size // 2 < traps[:, 1]) & (
traps[:, 1] < image.shape[1] - tile_size // 2
)
traps = traps[maskx & masky, :]
ma = ma[maskx & masky]
chosen_trap_coords = np.round(traps[ma.argmin()]).astype(int)
x, y = chosen_trap_coords
template = image[
x - tile_size // 2 : x + tile_size // 2, y - tile_size // 2 : y + tile_size // 2
]
traps = identify_trap_locations(image, template)
if len(traps) < 10 and downscale != 1:
print("Trying again.")
return segment_traps(image, tile_size, downscale=1)
return traps
# def segment_traps(image, tile_size, downscale=0.4):
# # Make image go between 0 and 255
# img = image # Keep a memory of image in case need to re-run
# image = stretch_image(image)
# # TODO Optimise the hyperparameters
# disk_radius = int(min([0.01 * x for x in img.shape]))
# min_area = 0.1 * (tile_size ** 2)
# if downscale != 1:
# img = transform.rescale(image, downscale)
# entropy_image = entropy(img, disk(disk_radius))
# if downscale != 1:
# entropy_image = transform.rescale(entropy_image, 1 / downscale)
# # apply threshold
# thresh = threshold_otsu(entropy_image)
# bw = closing(entropy_image > thresh, square(3))
# # remove artifacts connected to image border
# cleared = clear_border(bw)
# # label image regions
# label_image = label(cleared)
# traps = [
# region.centroid for region in regionprops(label_image) if region.area > min_area
# ]
# if len(traps) < 10 and downscale != 1:
# print("Trying again.")
# return segment_traps(image, tile_size, downscale=1)
# return traps
def identify_trap_locations(
image, trap_template, optimize_scale=True, downscale=0.35, trap_size=None
):
"""
Identify the traps in a single image based on a trap template.
This assumes a trap template that is similar to the image in question
(same camera, same magification; ideally same experiment).
This method speeds up the search by downscaling both the image and
the trap template before running the template match.
It also optimizes the scale and the rotation of the trap template.
:param image:
:param trap_template:
:param optimize_scale:
:param downscale:
:param trap_rotation:
:return:
"""
trap_size = trap_size if trap_size is not None else trap_template.shape[0]
# Careful, the image is float16!
img = transform.rescale(image.astype(float), downscale)
temp = transform.rescale(trap_template, downscale)
# TODO random search hyperparameter optimization
# optimize rotation
matches = {
rotation: feature.match_template(
img,
transform.rotate(temp, rotation, cval=np.median(img)),
pad_input=True,
mode="median",
)
** 2
for rotation in [0, 90, 180, 270]
}
best_rotation = max(matches, key=lambda x: np.percentile(matches[x], 99.9))
temp = transform.rotate(temp, best_rotation, cval=np.median(img))
if optimize_scale:
scales = np.linspace(0.5, 2, 10)
matches = {
scale: feature.match_template(
img, transform.rescale(temp, scale), mode="median", pad_input=True
)
** 2
for scale in scales
}
best_scale = max(matches, key=lambda x: np.percentile(matches[x], 99.9))
matched = matches[best_scale]
else:
matched = feature.match_template(img, temp, pad_input=True, mode="median")
coordinates = feature.peak_local_max(
transform.rescale(matched, 1 / downscale),
min_distance=int(trap_template.shape[0] * 0.70),
exclude_border=(trap_size // 3),
)
return coordinates
def get_tile_shapes(x, tile_size, max_shape):
half_size = tile_size // 2
xmin = int(x[0] - half_size)
ymin = max(0, int(x[1] - half_size))
# if xmin + tile_size > max_shape[0]:
# xmin = max_shape[0] - tile_size
# if ymin + tile_size > max_shape[1]:
# # ymin = max_shape[1] - tile_size
# return max(xmin, 0), xmin + tile_size, max(ymin, 0), ymin + tile_size
return xmin, xmin + tile_size, ymin, ymin + tile_size
def in_image(img, xmin, xmax, ymin, ymax, xidx=2, yidx=3):
if xmin >= 0 and ymin >= 0:
if xmax < img.shape[xidx] and ymax < img.shape[yidx]:
return True
else:
return False
def get_xy_tile(img, xmin, xmax, ymin, ymax, xidx=2, yidx=3, pad_val=None):
if pad_val is None:
pad_val = np.median(img)
# Get the tile from the image
idx = [slice(None)] * len(img.shape)
idx[xidx] = slice(max(0, xmin), min(xmax, img.shape[xidx]))
idx[yidx] = slice(max(0, ymin), min(ymax, img.shape[yidx]))
tile = img[tuple(idx)]
# Check if the tile is in the image
if in_image(img, xmin, xmax, ymin, ymax, xidx, yidx):
return tile
else:
# Add padding
pad_shape = [(0, 0)] * len(img.shape)
pad_shape[xidx] = (max(-xmin, 0), max(xmax - img.shape[xidx], 0))
pad_shape[yidx] = (max(-ymin, 0), max(ymax - img.shape[yidx], 0))
tile = np.pad(tile, pad_shape, constant_values=pad_val)
return tile
def get_trap_timelapse(
raw_expt, trap_locations, trap_id, tile_size=117, channels=None, z=None
):
"""
Get a timelapse for a given trap by specifying the trap_id
:param trap_id: An integer defining which trap to choose. Counted
between 0 and Tiler.n_traps - 1
:param tile_size: The size of the trap tile (centered around the
trap as much as possible, edge cases exist)
:param channels: Which channels to fetch, indexed from 0.
If None, defaults to [0]
:param z: Which z_stacks to fetch, indexed from 0.
If None, defaults to [0].
:return: A numpy array with the timelapse in (C,T,X,Y,Z) order
"""
# Set the defaults (list is mutable)
channels = channels if channels is not None else [0]
z = z if z is not None else [0]
# Get trap location for that id:
trap_centers = [trap_locations[i][trap_id] for i in range(len(trap_locations))]
max_shape = (raw_expt.shape[2], raw_expt.shape[3])
tiles_shapes = [
get_tile_shapes((x[0], x[1]), tile_size, max_shape) for x in trap_centers
]
timelapse = [
get_xy_tile(
raw_expt[channels, i, :, :, z], xmin, xmax, ymin, ymax, pad_val=None
)
for i, (xmin, xmax, ymin, ymax) in enumerate(tiles_shapes)
]
return np.hstack(timelapse)
def get_trap_timelapse_omero(
raw_expt, trap_locations, trap_id, tile_size=117, channels=None, z=None, t=None
):
"""
Get a timelapse for a given trap by specifying the trap_id
:param raw_expt: A Timelapse object from which data is obtained
:param trap_id: An integer defining which trap to choose. Counted
between 0 and Tiler.n_traps - 1
:param tile_size: The size of the trap tile (centered around the
trap as much as possible, edge cases exist)
:param channels: Which channels to fetch, indexed from 0.
If None, defaults to [0]
:param z: Which z_stacks to fetch, indexed from 0.
If None, defaults to [0].
:return: A numpy array with the timelapse in (C,T,X,Y,Z) order
"""
# Set the defaults (list is mutable)
channels = channels if channels is not None else [0]
z_positions = z if z is not None else [0]
times = (
t if t is not None else np.arange(raw_expt.shape[1])
) # TODO choose sub-set of time points
shape = (len(channels), len(times), tile_size, tile_size, len(z_positions))
# Get trap location for that id:
zct_tiles, slices, trap_ids = all_tiles(
trap_locations, shape, raw_expt, z_positions, channels, times, [trap_id]
)
# TODO Make this an explicit function in TimelapseOMERO
images = raw_expt.pixels.getTiles(zct_tiles)
timelapse = np.full(shape, np.nan)
total = len(zct_tiles)
for (z, c, t, _), (y, x), image in tqdm(
zip(zct_tiles, slices, images), total=total
):
ch = channels.index(c)
tp = times.tolist().index(t)
z_pos = z_positions.index(z)
timelapse[ch, tp, x[0] : x[1], y[0] : y[1], z_pos] = image
# for x in timelapse: # By channel
# np.nan_to_num(x, nan=np.nanmedian(x), copy=False)
return timelapse
def all_tiles(trap_locations, shape, raw_expt, z_positions, channels, times, traps):
_, _, x, y, _ = shape
_, _, MAX_X, MAX_Y, _ = raw_expt.shape
trap_ids = []
zct_tiles = []
slices = []
for z in z_positions:
for ch in channels:
for t in times:
for trap_id in traps:
centre = trap_locations[t][trap_id]
xmin, ymin, xmax, ymax, r_xmin, r_ymin, r_xmax, r_ymax = tile_where(
centre, x, y, MAX_X, MAX_Y
)
slices.append(
((r_ymin - ymin, r_ymax - ymin), (r_xmin - xmin, r_xmax - xmin))
)
tile = (r_ymin, r_xmin, r_ymax - r_ymin, r_xmax - r_xmin)
zct_tiles.append((z, ch, t, tile))
trap_ids.append(trap_id) # So we remember the order!
return zct_tiles, slices, trap_ids
def tile_where(centre, x, y, MAX_X, MAX_Y):
# Find the position of the tile
xmin = int(centre[1] - x // 2)
ymin = int(centre[0] - y // 2)
xmax = xmin + x
ymax = ymin + y
# What do we actually have available?
r_xmin = max(0, xmin)
r_xmax = min(MAX_X, xmax)
r_ymin = max(0, ymin)
r_ymax = min(MAX_Y, ymax)
return xmin, ymin, xmax, ymax, r_xmin, r_ymin, r_xmax, r_ymax
def get_tile(shape, center, raw_expt, ch, t, z):
"""Returns a tile from the raw experiment with a given shape.
:param shape: The shape of the tile in (C, T, Z, Y, X) order.
:param center: The x,y position of the centre of the tile
:param
"""
_, _, x, y, _ = shape
_, _, MAX_X, MAX_Y, _ = raw_expt.shape
tile = np.full(shape, np.nan)
# Find the position of the tile
xmin = int(center[1] - x // 2)
ymin = int(center[0] - y // 2)
xmax = xmin + x
ymax = ymin + y
# What do we actually have available?
r_xmin = max(0, xmin)
r_xmax = min(MAX_X, xmax)
r_ymin = max(0, ymin)
r_ymax = min(MAX_Y, ymax)
# Fill values
tile[
:, :, (r_xmin - xmin) : (r_xmax - xmin), (r_ymin - ymin) : (r_ymax - ymin), :
] = raw_expt[ch, t, r_xmin:r_xmax, r_ymin:r_ymax, z]
# fill_val = np.nanmedian(tile)
# np.nan_to_num(tile, nan=fill_val, copy=False)
return tile
def get_traps_timepoint(
raw_expt, trap_locations, tp, tile_size=96, channels=None, z=None
):
"""
Get all the traps from a given time point
:param raw_expt:
:param trap_locations:
:param tp:
:param tile_size:
:param channels:
:param z:
:return: A numpy array with the traps in the (trap, C, T, X, Y,
Z) order
"""
# Set the defaults (list is mutable)
channels = channels if channels is not None else [0]
z_positions = z if z is not None else [0]
if isinstance(z_positions, slice):
n_z = z_positions.stop
z_positions = list(range(n_z)) # slice is not iterable error
elif isinstance(z_positions, list):
n_z = len(z_positions)
else:
n_z = 1
n_traps = len(trap_locations[tp])
trap_ids = list(range(n_traps))
shape = (len(channels), 1, tile_size, tile_size, n_z)
# all tiles
zct_tiles, slices, trap_ids = all_tiles(
trap_locations, shape, raw_expt, z_positions, channels, [tp], trap_ids
)
# TODO Make this an explicit function in TimelapseOMERO
images = raw_expt.pixels.getTiles(zct_tiles)
# Initialise empty traps
traps = np.full((n_traps,) + shape, np.nan)
for trap_id, (z, c, _, _), (y, x), image in zip(
trap_ids, zct_tiles, slices, images
):
ch = channels.index(c)
z_pos = z_positions.index(z)
traps[trap_id, ch, 0, x[0] : x[1], y[0] : y[1], z_pos] = image
for x in traps: # By channel
np.nan_to_num(x, nan=np.nanmedian(x), copy=False)
return traps
def centre(img, percentage=0.3):
y, x = img.shape
cropx = int(np.ceil(x * percentage))
cropy = int(np.ceil(y * percentage))
startx = int(x // 2 - (cropx // 2))
starty = int(y // 2 - (cropy // 2))
return img[starty : starty + cropy, startx : startx + cropx]
def align_timelapse_images(
raw_data, channel=0, reference_reset_time=80, reference_reset_drift=25
):
"""
Uses image registration to align images in the timelapse.
Uses the channel with id `channel` to perform the registration.
Starts with the first timepoint as a reference and changes the
reference to the current timepoint if either the images have moved
by half of a trap width or `reference_reset_time` has been reached.
Sets `self.drift`, a 3D numpy array with shape (t, drift_x, drift_y).
We assume no drift occurs in the z-direction.
:param reference_reset_drift: Upper bound on the allowed drift before
resetting the reference image.
:param reference_reset_time: Upper bound on number of time points to
register before resetting the reference image.
:param channel: index of the channel to use for image registration.
"""
ref = centre(np.squeeze(raw_data[channel, 0, :, :, 0]))
size_t = raw_data.shape[1]
drift = [np.array([0, 0])]
for i in range(1, size_t):
img = centre(np.squeeze(raw_data[channel, i, :, :, 0]))
shifts, _, _ = feature.register_translation(ref, img)
# If a huge move is detected at a single time point it is taken
# to be inaccurate and the correction from the previous time point
# is used.
# This might be common if there is a focus loss for example.
if any([abs(x - y) > reference_reset_drift for x, y in zip(shifts, drift[-1])]):
shifts = drift[-1]
drift.append(shifts)
ref = img
# TODO test necessity for references, description below
# If the images have drifted too far from the reference or too
# much time has passed we change the reference and keep track of
# which images are kept as references
return np.stack(drift)
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment