diff --git a/aliby/experiment.py b/aliby/experiment.py index 3871194217d4d14df61a44d3a3242de5ab848aa4..55e418b5beba2e0416cf6692053cb6760804af86 100644 --- a/aliby/experiment.py +++ b/aliby/experiment.py @@ -76,10 +76,13 @@ from dask import delayed def get_data_lazy(image) -> da.Array: - """Get 5D dask array, with delayed reading from OMERO image.""" + """ + Get 5D dask array, with delayed reading from OMERO image. + """ nt, nc, nz, ny, nx = [getattr(image, f"getSize{x}")() for x in "TCZYX"] pixels = image.getPrimaryPixels() dtype = PIXEL_TYPES.get(pixels.getPixelsType().value, None) + # using dask get_plane = delayed(lambda idx: pixels.getPlane(*idx)) def get_lazy_plane(zct): diff --git a/aliby/io/image.py b/aliby/io/image.py index b849169ae2fe4496fdb7b9f78bec4e4693b8c566..7a4f845b21f684c481b6dcc47fc995dac640dbf7 100644 --- a/aliby/io/image.py +++ b/aliby/io/image.py @@ -5,6 +5,10 @@ from datetime import datetime import xmltodict from tifffile import TiffFile + +# dask extends numpy to multi-core machines and distributed clusters +# and allows data to be stored that is larger than the RAM by +# sharing between RAM and a hard disk import dask.array as da from dask.array.image import imread @@ -77,23 +81,50 @@ class ImageLocal: class Image(Argo): - """""" + """ + Loads images from OMERO and gives access to the data and metadata. + """ def __init__(self, image_id, **server_info): + ''' + Establishes the connection to the OMERO server via the Argo + base class. + + Parameters + ---------- + image_id: integer + server_info: dictionary + Specifies the host, username, and password as strings + ''' super().__init__(**server_info) self.image_id = image_id + # images from OMERO self._image_wrap = None @property def image_wrap(self): - + ''' + Get images from OMERO + ''' if self._image_wrap is None: + # get images using OMERO self._image_wrap = self.conn.getObject("Image", self.image_id) return self._image_wrap - # Version with local file processing + # version with local file processing def get_data_lazy_local(path: str) -> da.Array: - """Return 5D dask array. For lazy-loading local multidimensional tiff files""" + """ + For lazy-loading - loading on demand only -- local, + multidimensional tiff files. + + Parameters + ---------- + path: string + + Returns + ------- + 5D dask array + """ return da.from_delayed(imread(str(path))[0], shape=()) @property @@ -106,6 +137,10 @@ class Image(Argo): @property def metadata(self): + """ + Store metadata saved in OMERO: image size, number of time points, + labels of channels, and image name. + """ meta = dict() meta["size_x"] = self.image_wrap.getSizeX() meta["size_y"] = self.image_wrap.getSizeY() diff --git a/aliby/io/omero.py b/aliby/io/omero.py index ab214f58f9edc839687cd681779fa1a5394beba4..bd7bfc83a0af83ec52d157457a5f20134a8bd89c 100644 --- a/aliby/io/omero.py +++ b/aliby/io/omero.py @@ -1,19 +1,32 @@ from omero.gateway import BlitzGateway - class Argo: """ - Base OMERO-interactive class + Base class to interact with OMERO. + See + https://docs.openmicroscopy.org/omero/5.6.0/developers/Python.html """ def __init__( - self, host="islay.bio.ed.ac.uk", username="upload", password="***REMOVED***" + self, + host="islay.bio.ed.ac.uk", + username="upload", + password="***REMOVED***", ): + """ + Parameters + ---------- + host : string + web address of OMERO host + username: string + password : string + """ self.conn = None self.host = host self.username = username self.password = password + # standard method required for Python's with statement def __enter__(self): self.conn = BlitzGateway( host=self.host, username=self.username, passwd=self.password @@ -21,6 +34,7 @@ class Argo: self.conn.connect() return self + # standard method required for Python's with statement def __exit__(self, *exc): self.conn.close() return False diff --git a/aliby/tile/tiler.py b/aliby/tile/tiler.py index 87ccf247d2be83da989b1029426ec40fd270adf4..fbb7d6e8c2a9a87a53725900ce47884e28bfdc5b 100644 --- a/aliby/tile/tiler.py +++ b/aliby/tile/tiler.py @@ -3,60 +3,72 @@ Includes splitting the image into traps/parts, cell segmentation, nucleus segmentation.""" import warnings from functools import lru_cache - import h5py import numpy as np - from pathlib import Path, PosixPath - from skimage.registration import phase_cross_correlation - -from agora.abc import ParametersABC, ProcessABC from aliby.tile.traps import segment_traps - +from agora.abc import ParametersABC, ProcessABC from agora.io.writer import load_attributes + +# Alan: is this necessary? trap_template_directory = Path(__file__).parent / "trap_templates" # TODO do we need multiple templates, one for each setup? -trap_template = np.array([]) # np.load(trap_template_directory / "trap_prime.npy") +trap_template = np.array([]) +# np.load(trap_template_directory / "trap_prime.npy") +# +# def get_tile_shapes(x, tile_size, max_shape): +# half_size = tile_size // 2 +# xmin = int(x[0] - half_size) +# ymin = max(0, int(x[1] - half_size)) +# if xmin + tile_size > max_shape[0]: +# xmin = max_shape[0] - tile_size +# if ymin + tile_size > max_shape[1]: +# ymin = max_shape[1] - tile_size +# return xmin, xmin + tile_size, ymin, ymin + tile_size -def get_tile_shapes(x, tile_size, max_shape): - half_size = tile_size // 2 - xmin = int(x[0] - half_size) - ymin = max(0, int(x[1] - half_size)) - if xmin + tile_size > max_shape[0]: - xmin = max_shape[0] - tile_size - if ymin + tile_size > max_shape[1]: - ymin = max_shape[1] - tile_size - return xmin, xmin + tile_size, ymin, ymin + tile_size - -###################### Dask versions ######################## class Trap: + ''' + Stores a trap's location and size. + Allows checks to see if the trap should be padded. + Can export the trap either in OMERO or numpy formats. + ''' def __init__(self, centre, parent, size, max_size): self.centre = centre - self.parent = parent # Used to access drifts + self.parent = parent # used to access drifts self.size = size self.half_size = size // 2 self.max_size = max_size + ### + def padding_required(self, tp): - """Check if we need to pad the trap image for this time point.""" + """ + Check if we need to pad the trap image for this time point. + """ try: assert all(self.at_time(tp) - self.half_size >= 0) assert all(self.at_time(tp) + self.half_size <= self.max_size) + # return False except AssertionError: return True return False def at_time(self, tp): - """Return trap centre at time tp""" + """ + Return trap centre at time tp by applying drifts + """ drifts = self.parent.drifts return self.centre - np.sum(drifts[: tp + 1], axis=0) + ### + def as_tile(self, tp): - """Return trap in the OMERO tile format of x, y, w, h + """ + Return trap in the OMERO tile format of x, y, w, h Also returns the padding necessary for this tile. """ @@ -66,62 +78,85 @@ class Trap: y = int(y - self.half_size) return x, y, self.size, self.size + ### + def as_range(self, tp): - """Return trap in a range format, two slice objects that can be used in Arrays""" + """ + Return trap in a range format: two slice objects that can + be used in Arrays + """ x, y, w, h = self.as_tile(tp) return slice(x, x + w), slice(y, y + h) +### class TrapLocations: - def __init__(self, initial_location, tile_size, max_size=1200, drifts=None): + ''' + Stores each trap as an instance of Trap. + Traps can be iterated. + ''' + def __init__( + self, initial_location, tile_size, max_size=1200, drifts=None + ): if drifts is None: drifts = [] self.tile_size = tile_size self.max_size = max_size self.initial_location = initial_location self.traps = [ - Trap(centre, self, tile_size, max_size) for centre in initial_location + Trap(centre, self, tile_size, max_size) + for centre in initial_location ] self.drifts = drifts - # @classmethod - # def from_source(cls, fpath: str): - # with h5py.File(fpath, "r") as f: - # # TODO read tile size from file metadata - # drifts = f["trap_info/drifts"][()].tolist() - # tlocs = cls(f["trap_info/trap_locations"][()], tile_size=96, drifts=drifts) - - # return tlocs - - @property - def shape(self): - return len(self.traps), len(self.drifts) - def __len__(self): return len(self.traps) def __iter__(self): yield from self.traps + ### + + @property + def shape(self): + ''' + Returns no of traps and no of drifts + ''' + return len(self.traps), len(self.drifts) + def padding_required(self, tp): + ''' + Check if any traps need padding + ''' return any([trap.padding_required(tp) for trap in self.traps]) def to_dict(self, tp): + ''' + Export inital locations, tile_size, max_size, and drifts + as a dictionary + ''' res = dict() if tp == 0: res["trap_locations"] = self.initial_location res["attrs/tile_size"] = self.tile_size res["attrs/max_size"] = self.max_size res["drifts"] = np.expand_dims(self.drifts[tp], axis=0) - # res["processed_timepoints"] = tp return res + ### + @classmethod def from_tiler_init(cls, initial_location, tile_size, max_size=1200): + ''' + Instantiate class from an instance of the Tiler class + ''' return cls(initial_location, tile_size, max_size, drifts=[]) @classmethod def read_hdf5(cls, file): + ''' + Instantiate class from a hdf5 file + ''' with h5py.File(file, "r") as hfile: trap_info = hfile["trap_info"] initial_locations = trap_info["trap_locations"][()] @@ -130,15 +165,17 @@ class TrapLocations: tile_size = trap_info.attrs["tile_size"] trap_locs = cls(initial_locations, tile_size, max_size=max_size) trap_locs.drifts = drifts - # trap_locs.n_processed = len(drifts) return trap_locs +### class TilerParameters(ParametersABC): + _defaults = {"tile_size": 117, + "ref_channel": "Brightfield", + "ref_z": 0} - _defaults = {"tile_size": 117, "ref_channel": "Brightfield", "ref_z": 0} - - +### +# Alan: is this necessary? class TilerABC(ProcessABC): """ Base class for different types of Tilers. @@ -182,10 +219,15 @@ class TilerABC(ProcessABC): return trap +#### + + class Tiler(ProcessABC): - """Remote Timelapse Tiler. + """ + Remote Timelapse Tiler. - Does trap finding and image registration. Fetches images from as erver + Finds traps and re-registers images if there is any drifting. + Fetches images from a server. """ def __init__( @@ -204,11 +246,15 @@ class Tiler(ProcessABC): try: self.z_perchannel = { ch: metadata["zsectioning/nsections"] if zsect else 1 - for zsect, ch in zip(metadata["channels"], metadata["channels/zsect"]) + for zsect, ch in zip( + metadata["channels"], metadata["channels/zsect"] + ) } except Exception as e: print(f"Warning:Tiler: No z_perchannel data: {e}") + ### + @classmethod def from_image(cls, image, parameters: TilerParameters): return cls(image.data, image.metadata, parameters) @@ -234,14 +280,26 @@ class Tiler(ProcessABC): tiler.n_processed = len(trap_locs.drifts) return tiler + ### + @lru_cache(maxsize=2) def get_tc(self, t, c): # Get image full = self.image[t, c].compute() # FORCE THE CACHE return full + ### + @property def shape(self): + """ + Returns properties of the time-lapse experiment + no of channels + no of time points + no of z stacks + no of pixels in y direction + no of pixles in z direction + """ c, t, z, y, x = self.image.shape return (c, t, x, y, z) @@ -261,40 +319,63 @@ class Tiler(ProcessABC): @property def finished(self): + """ + Returns True if all channels have been processed + """ return self.n_processed == self.image.shape[0] + ### + def _initialise_traps(self, tile_size): - """Find initial trap positions. + """ + Find initial trap positions. - Removes all those that are too close to the edge so no padding is necessary. + Removes all those that are too close to the edge so no padding + is necessary. """ half_tile = tile_size // 2 + # max_size is the minimal no of x or y pixels max_size = min(self.image.shape[-2:]) + # first time point, first channel, first z-position initial_image = self.image[ 0, self.ref_channel, self.ref_z - ] # First time point, first channel, first z-position + ] + # find the traps trap_locs = segment_traps(initial_image, tile_size) + # keep only traps that are not near an edge trap_locs = [ [x, y] for x, y in trap_locs if half_tile < x < max_size - half_tile and half_tile < y < max_size - half_tile ] + # store traps in an instance of TrapLocations self.trap_locs = TrapLocations.from_tiler_init(trap_locs, tile_size) - # self.trap_locs = TrapLocations(trap_locs, tile_size) + + + ### def find_drift(self, tp): - # TODO check that the drift doesn't move any tiles out of the image, remove them from list if so + ''' + Find any translational drifts between two images at consecutive + time points using cross correlation + ''' + # TODO check that the drift doesn't move any tiles out of + # the image, remove them from list if so prev_tp = max(0, tp - 1) + # cross-correlate drift, error, _ = phase_cross_correlation( self.image[prev_tp, self.ref_channel, self.ref_z], self.image[tp, self.ref_channel, self.ref_z], ) + # store drift if 0 < tp < len(self.trap_locs.drifts): self.trap_locs.drifts[tp] = drift.tolist() else: self.trap_locs.drifts.append(drift.tolist()) + ### + def get_tp_data(self, tp, c): traps = [] full = self.get_tc(tp, c) @@ -305,75 +386,51 @@ class Tiler(ProcessABC): traps.append(ndtrap) return np.stack(traps) + ### + def get_trap_data(self, trap_id, tp, c): full = self.get_tc(tp, c) trap = self.trap_locs.traps[trap_id] ndtrap = self.ifoob_pad(full, trap.as_range(tp)) - return ndtrap - @staticmethod - def ifoob_pad(full, slices): # TODO Remove when inheriting TilerABC - """ - Returns the slices padded if it is out of bounds - - Parameters: - ---------- - full: (zstacks, max_size, max_size) ndarray - Entire position with zstacks as first axis - slices: tuple of two slices - Each slice indicates an axis to index - - - Returns - Trap for given slices, padded with median if needed, or np.nan if the padding is too much - """ - max_size = full.shape[-1] - - y, x = [slice(max(0, s.start), min(max_size, s.stop)) for s in slices] - trap = full[:, y, x] - - padding = np.array( - [(-min(0, s.start), -min(0, max_size - s.stop)) for s in slices] - ) - if padding.any(): - tile_size = slices[0].stop - slices[0].start - if (padding > tile_size / 4).any(): - trap = np.full((full.shape[0], tile_size, tile_size), np.nan) - else: - - trap = np.pad(trap, [[0, 0]] + padding.tolist(), "median") - - return trap + ### def run_tp(self, tp): + ''' + Find traps if they have not yet been found. + Determine any translational drift of the current image from the + previous one. + ''' # assert tp >= self.n_processed, "Time point already processed" # TODO check contiguity? if self.n_processed == 0 or not hasattr(self.trap_locs, "drifts"): self._initialise_traps(self.tile_size) - if hasattr(self.trap_locs, "drifts"): drift_len = len(self.trap_locs.drifts) - if self.n_processed != drift_len: - raise (Exception("Tiler:N_processed and ndrifts don't match")) + raise Exception("Tiler:n_processed and ndrifts don't match") self.n_processed = drift_len - - self.find_drift(tp) # Get drift + # determine drift + self.find_drift(tp) # update n_processed self.n_processed = tp + 1 - # Return result for writer + # return result for writer return self.trap_locs.to_dict(tp) + # Alan !!!! this function is the same as the previous one !!!!! def run(self, tp): if self.n_processed == 0: self._initialise_traps(self.tile_size) - self.find_drift(tp) # Get drift + # determine drift + self.find_drift(tp) # update n_processed self.n_processed += 1 - # Return result for writer + # return result for writer return self.trap_locs.to_dict(tp) + ### + # The next set of functions are necessary for the extraction object def get_traps_timepoint(self, tp, tile_size=None, channels=None, z=None): # FIXME we currently ignore the tile size @@ -389,11 +446,52 @@ class Tiler(ProcessABC): res.append(val) return np.stack(res, axis=1) + ### + def get_channel_index(self, item): for i, ch in enumerate(self.channels): if item in ch: return i + ### + def get_position_annotation(self): # TODO required for matlab support return None + + + ### + + @staticmethod + def ifoob_pad(full, slices): # TODO Remove when inheriting TilerABC + """ + Returns the slices padded if it is out of bounds + + Parameters: + ---------- + full: (zstacks, max_size, max_size) ndarray + Entire position with zstacks as first axis + slices: tuple of two slices + Each slice indicates an axis to index + + + Returns + Trap for given slices, padded with median if needed, or np.nan if the padding is too much + """ + max_size = full.shape[-1] + + y, x = [slice(max(0, s.start), min(max_size, s.stop)) for s in slices] + trap = full[:, y, x] + + padding = np.array( + [(-min(0, s.start), -min(0, max_size - s.stop)) for s in slices] + ) + if padding.any(): + tile_size = slices[0].stop - slices[0].start + if (padding > tile_size / 4).any(): + trap = np.full((full.shape[0], tile_size, tile_size), np.nan) + else: + + trap = np.pad(trap, [[0, 0]] + padding.tolist(), "median") + + return trap