Skip to content
Snippets Groups Projects
Commit 2df9da06 authored by pswain's avatar pswain
Browse files

some annotation

parent 4a4623b4
No related branches found
No related tags found
No related merge requests found
......@@ -76,10 +76,13 @@ from dask import delayed
def get_data_lazy(image) -> da.Array:
"""Get 5D dask array, with delayed reading from OMERO image."""
"""
Get 5D dask array, with delayed reading from OMERO image.
"""
nt, nc, nz, ny, nx = [getattr(image, f"getSize{x}")() for x in "TCZYX"]
pixels = image.getPrimaryPixels()
dtype = PIXEL_TYPES.get(pixels.getPixelsType().value, None)
# using dask
get_plane = delayed(lambda idx: pixels.getPlane(*idx))
def get_lazy_plane(zct):
......
......@@ -5,6 +5,10 @@ from datetime import datetime
import xmltodict
from tifffile import TiffFile
# dask extends numpy to multi-core machines and distributed clusters
# and allows data to be stored that is larger than the RAM by
# sharing between RAM and a hard disk
import dask.array as da
from dask.array.image import imread
......@@ -77,23 +81,50 @@ class ImageLocal:
class Image(Argo):
""""""
"""
Loads images from OMERO and gives access to the data and metadata.
"""
def __init__(self, image_id, **server_info):
'''
Establishes the connection to the OMERO server via the Argo
base class.
Parameters
----------
image_id: integer
server_info: dictionary
Specifies the host, username, and password as strings
'''
super().__init__(**server_info)
self.image_id = image_id
# images from OMERO
self._image_wrap = None
@property
def image_wrap(self):
'''
Get images from OMERO
'''
if self._image_wrap is None:
# get images using OMERO
self._image_wrap = self.conn.getObject("Image", self.image_id)
return self._image_wrap
# Version with local file processing
# version with local file processing
def get_data_lazy_local(path: str) -> da.Array:
"""Return 5D dask array. For lazy-loading local multidimensional tiff files"""
"""
For lazy-loading - loading on demand only -- local,
multidimensional tiff files.
Parameters
----------
path: string
Returns
-------
5D dask array
"""
return da.from_delayed(imread(str(path))[0], shape=())
@property
......@@ -106,6 +137,10 @@ class Image(Argo):
@property
def metadata(self):
"""
Store metadata saved in OMERO: image size, number of time points,
labels of channels, and image name.
"""
meta = dict()
meta["size_x"] = self.image_wrap.getSizeX()
meta["size_y"] = self.image_wrap.getSizeY()
......
from omero.gateway import BlitzGateway
class Argo:
"""
Base OMERO-interactive class
Base class to interact with OMERO.
See
https://docs.openmicroscopy.org/omero/5.6.0/developers/Python.html
"""
def __init__(
self, host="islay.bio.ed.ac.uk", username="upload", password="***REMOVED***"
self,
host="islay.bio.ed.ac.uk",
username="upload",
password="***REMOVED***",
):
"""
Parameters
----------
host : string
web address of OMERO host
username: string
password : string
"""
self.conn = None
self.host = host
self.username = username
self.password = password
# standard method required for Python's with statement
def __enter__(self):
self.conn = BlitzGateway(
host=self.host, username=self.username, passwd=self.password
......@@ -21,6 +34,7 @@ class Argo:
self.conn.connect()
return self
# standard method required for Python's with statement
def __exit__(self, *exc):
self.conn.close()
return False
......@@ -3,60 +3,72 @@ Includes splitting the image into traps/parts,
cell segmentation, nucleus segmentation."""
import warnings
from functools import lru_cache
import h5py
import numpy as np
from pathlib import Path, PosixPath
from skimage.registration import phase_cross_correlation
from agora.abc import ParametersABC, ProcessABC
from aliby.tile.traps import segment_traps
from agora.abc import ParametersABC, ProcessABC
from agora.io.writer import load_attributes
# Alan: is this necessary?
trap_template_directory = Path(__file__).parent / "trap_templates"
# TODO do we need multiple templates, one for each setup?
trap_template = np.array([]) # np.load(trap_template_directory / "trap_prime.npy")
trap_template = np.array([])
# np.load(trap_template_directory / "trap_prime.npy")
#
# def get_tile_shapes(x, tile_size, max_shape):
# half_size = tile_size // 2
# xmin = int(x[0] - half_size)
# ymin = max(0, int(x[1] - half_size))
# if xmin + tile_size > max_shape[0]:
# xmin = max_shape[0] - tile_size
# if ymin + tile_size > max_shape[1]:
# ymin = max_shape[1] - tile_size
# return xmin, xmin + tile_size, ymin, ymin + tile_size
def get_tile_shapes(x, tile_size, max_shape):
half_size = tile_size // 2
xmin = int(x[0] - half_size)
ymin = max(0, int(x[1] - half_size))
if xmin + tile_size > max_shape[0]:
xmin = max_shape[0] - tile_size
if ymin + tile_size > max_shape[1]:
ymin = max_shape[1] - tile_size
return xmin, xmin + tile_size, ymin, ymin + tile_size
###################### Dask versions ########################
class Trap:
'''
Stores a trap's location and size.
Allows checks to see if the trap should be padded.
Can export the trap either in OMERO or numpy formats.
'''
def __init__(self, centre, parent, size, max_size):
self.centre = centre
self.parent = parent # Used to access drifts
self.parent = parent # used to access drifts
self.size = size
self.half_size = size // 2
self.max_size = max_size
###
def padding_required(self, tp):
"""Check if we need to pad the trap image for this time point."""
"""
Check if we need to pad the trap image for this time point.
"""
try:
assert all(self.at_time(tp) - self.half_size >= 0)
assert all(self.at_time(tp) + self.half_size <= self.max_size)
# return False
except AssertionError:
return True
return False
def at_time(self, tp):
"""Return trap centre at time tp"""
"""
Return trap centre at time tp by applying drifts
"""
drifts = self.parent.drifts
return self.centre - np.sum(drifts[: tp + 1], axis=0)
###
def as_tile(self, tp):
"""Return trap in the OMERO tile format of x, y, w, h
"""
Return trap in the OMERO tile format of x, y, w, h
Also returns the padding necessary for this tile.
"""
......@@ -66,62 +78,85 @@ class Trap:
y = int(y - self.half_size)
return x, y, self.size, self.size
###
def as_range(self, tp):
"""Return trap in a range format, two slice objects that can be used in Arrays"""
"""
Return trap in a range format: two slice objects that can
be used in Arrays
"""
x, y, w, h = self.as_tile(tp)
return slice(x, x + w), slice(y, y + h)
###
class TrapLocations:
def __init__(self, initial_location, tile_size, max_size=1200, drifts=None):
'''
Stores each trap as an instance of Trap.
Traps can be iterated.
'''
def __init__(
self, initial_location, tile_size, max_size=1200, drifts=None
):
if drifts is None:
drifts = []
self.tile_size = tile_size
self.max_size = max_size
self.initial_location = initial_location
self.traps = [
Trap(centre, self, tile_size, max_size) for centre in initial_location
Trap(centre, self, tile_size, max_size)
for centre in initial_location
]
self.drifts = drifts
# @classmethod
# def from_source(cls, fpath: str):
# with h5py.File(fpath, "r") as f:
# # TODO read tile size from file metadata
# drifts = f["trap_info/drifts"][()].tolist()
# tlocs = cls(f["trap_info/trap_locations"][()], tile_size=96, drifts=drifts)
# return tlocs
@property
def shape(self):
return len(self.traps), len(self.drifts)
def __len__(self):
return len(self.traps)
def __iter__(self):
yield from self.traps
###
@property
def shape(self):
'''
Returns no of traps and no of drifts
'''
return len(self.traps), len(self.drifts)
def padding_required(self, tp):
'''
Check if any traps need padding
'''
return any([trap.padding_required(tp) for trap in self.traps])
def to_dict(self, tp):
'''
Export inital locations, tile_size, max_size, and drifts
as a dictionary
'''
res = dict()
if tp == 0:
res["trap_locations"] = self.initial_location
res["attrs/tile_size"] = self.tile_size
res["attrs/max_size"] = self.max_size
res["drifts"] = np.expand_dims(self.drifts[tp], axis=0)
# res["processed_timepoints"] = tp
return res
###
@classmethod
def from_tiler_init(cls, initial_location, tile_size, max_size=1200):
'''
Instantiate class from an instance of the Tiler class
'''
return cls(initial_location, tile_size, max_size, drifts=[])
@classmethod
def read_hdf5(cls, file):
'''
Instantiate class from a hdf5 file
'''
with h5py.File(file, "r") as hfile:
trap_info = hfile["trap_info"]
initial_locations = trap_info["trap_locations"][()]
......@@ -130,15 +165,17 @@ class TrapLocations:
tile_size = trap_info.attrs["tile_size"]
trap_locs = cls(initial_locations, tile_size, max_size=max_size)
trap_locs.drifts = drifts
# trap_locs.n_processed = len(drifts)
return trap_locs
###
class TilerParameters(ParametersABC):
_defaults = {"tile_size": 117,
"ref_channel": "Brightfield",
"ref_z": 0}
_defaults = {"tile_size": 117, "ref_channel": "Brightfield", "ref_z": 0}
###
# Alan: is this necessary?
class TilerABC(ProcessABC):
"""
Base class for different types of Tilers.
......@@ -182,10 +219,15 @@ class TilerABC(ProcessABC):
return trap
####
class Tiler(ProcessABC):
"""Remote Timelapse Tiler.
"""
Remote Timelapse Tiler.
Does trap finding and image registration. Fetches images from as erver
Finds traps and re-registers images if there is any drifting.
Fetches images from a server.
"""
def __init__(
......@@ -204,11 +246,15 @@ class Tiler(ProcessABC):
try:
self.z_perchannel = {
ch: metadata["zsectioning/nsections"] if zsect else 1
for zsect, ch in zip(metadata["channels"], metadata["channels/zsect"])
for zsect, ch in zip(
metadata["channels"], metadata["channels/zsect"]
)
}
except Exception as e:
print(f"Warning:Tiler: No z_perchannel data: {e}")
###
@classmethod
def from_image(cls, image, parameters: TilerParameters):
return cls(image.data, image.metadata, parameters)
......@@ -234,14 +280,26 @@ class Tiler(ProcessABC):
tiler.n_processed = len(trap_locs.drifts)
return tiler
###
@lru_cache(maxsize=2)
def get_tc(self, t, c):
# Get image
full = self.image[t, c].compute() # FORCE THE CACHE
return full
###
@property
def shape(self):
"""
Returns properties of the time-lapse experiment
no of channels
no of time points
no of z stacks
no of pixels in y direction
no of pixles in z direction
"""
c, t, z, y, x = self.image.shape
return (c, t, x, y, z)
......@@ -261,40 +319,63 @@ class Tiler(ProcessABC):
@property
def finished(self):
"""
Returns True if all channels have been processed
"""
return self.n_processed == self.image.shape[0]
###
def _initialise_traps(self, tile_size):
"""Find initial trap positions.
"""
Find initial trap positions.
Removes all those that are too close to the edge so no padding is necessary.
Removes all those that are too close to the edge so no padding
is necessary.
"""
half_tile = tile_size // 2
# max_size is the minimal no of x or y pixels
max_size = min(self.image.shape[-2:])
# first time point, first channel, first z-position
initial_image = self.image[
0, self.ref_channel, self.ref_z
] # First time point, first channel, first z-position
]
# find the traps
trap_locs = segment_traps(initial_image, tile_size)
# keep only traps that are not near an edge
trap_locs = [
[x, y]
for x, y in trap_locs
if half_tile < x < max_size - half_tile
and half_tile < y < max_size - half_tile
]
# store traps in an instance of TrapLocations
self.trap_locs = TrapLocations.from_tiler_init(trap_locs, tile_size)
# self.trap_locs = TrapLocations(trap_locs, tile_size)
###
def find_drift(self, tp):
# TODO check that the drift doesn't move any tiles out of the image, remove them from list if so
'''
Find any translational drifts between two images at consecutive
time points using cross correlation
'''
# TODO check that the drift doesn't move any tiles out of
# the image, remove them from list if so
prev_tp = max(0, tp - 1)
# cross-correlate
drift, error, _ = phase_cross_correlation(
self.image[prev_tp, self.ref_channel, self.ref_z],
self.image[tp, self.ref_channel, self.ref_z],
)
# store drift
if 0 < tp < len(self.trap_locs.drifts):
self.trap_locs.drifts[tp] = drift.tolist()
else:
self.trap_locs.drifts.append(drift.tolist())
###
def get_tp_data(self, tp, c):
traps = []
full = self.get_tc(tp, c)
......@@ -305,75 +386,51 @@ class Tiler(ProcessABC):
traps.append(ndtrap)
return np.stack(traps)
###
def get_trap_data(self, trap_id, tp, c):
full = self.get_tc(tp, c)
trap = self.trap_locs.traps[trap_id]
ndtrap = self.ifoob_pad(full, trap.as_range(tp))
return ndtrap
@staticmethod
def ifoob_pad(full, slices): # TODO Remove when inheriting TilerABC
"""
Returns the slices padded if it is out of bounds
Parameters:
----------
full: (zstacks, max_size, max_size) ndarray
Entire position with zstacks as first axis
slices: tuple of two slices
Each slice indicates an axis to index
Returns
Trap for given slices, padded with median if needed, or np.nan if the padding is too much
"""
max_size = full.shape[-1]
y, x = [slice(max(0, s.start), min(max_size, s.stop)) for s in slices]
trap = full[:, y, x]
padding = np.array(
[(-min(0, s.start), -min(0, max_size - s.stop)) for s in slices]
)
if padding.any():
tile_size = slices[0].stop - slices[0].start
if (padding > tile_size / 4).any():
trap = np.full((full.shape[0], tile_size, tile_size), np.nan)
else:
trap = np.pad(trap, [[0, 0]] + padding.tolist(), "median")
return trap
###
def run_tp(self, tp):
'''
Find traps if they have not yet been found.
Determine any translational drift of the current image from the
previous one.
'''
# assert tp >= self.n_processed, "Time point already processed"
# TODO check contiguity?
if self.n_processed == 0 or not hasattr(self.trap_locs, "drifts"):
self._initialise_traps(self.tile_size)
if hasattr(self.trap_locs, "drifts"):
drift_len = len(self.trap_locs.drifts)
if self.n_processed != drift_len:
raise (Exception("Tiler:N_processed and ndrifts don't match"))
raise Exception("Tiler:n_processed and ndrifts don't match")
self.n_processed = drift_len
self.find_drift(tp) # Get drift
# determine drift
self.find_drift(tp)
# update n_processed
self.n_processed = tp + 1
# Return result for writer
# return result for writer
return self.trap_locs.to_dict(tp)
# Alan !!!! this function is the same as the previous one !!!!!
def run(self, tp):
if self.n_processed == 0:
self._initialise_traps(self.tile_size)
self.find_drift(tp) # Get drift
# determine drift
self.find_drift(tp)
# update n_processed
self.n_processed += 1
# Return result for writer
# return result for writer
return self.trap_locs.to_dict(tp)
###
# The next set of functions are necessary for the extraction object
def get_traps_timepoint(self, tp, tile_size=None, channels=None, z=None):
# FIXME we currently ignore the tile size
......@@ -389,11 +446,52 @@ class Tiler(ProcessABC):
res.append(val)
return np.stack(res, axis=1)
###
def get_channel_index(self, item):
for i, ch in enumerate(self.channels):
if item in ch:
return i
###
def get_position_annotation(self):
# TODO required for matlab support
return None
###
@staticmethod
def ifoob_pad(full, slices): # TODO Remove when inheriting TilerABC
"""
Returns the slices padded if it is out of bounds
Parameters:
----------
full: (zstacks, max_size, max_size) ndarray
Entire position with zstacks as first axis
slices: tuple of two slices
Each slice indicates an axis to index
Returns
Trap for given slices, padded with median if needed, or np.nan if the padding is too much
"""
max_size = full.shape[-1]
y, x = [slice(max(0, s.start), min(max_size, s.stop)) for s in slices]
trap = full[:, y, x]
padding = np.array(
[(-min(0, s.start), -min(0, max_size - s.stop)) for s in slices]
)
if padding.any():
tile_size = slices[0].stop - slices[0].start
if (padding > tile_size / 4).any():
trap = np.full((full.shape[0], tile_size, tile_size), np.nan)
else:
trap = np.pad(trap, [[0, 0]] + padding.tolist(), "median")
return trap
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment