diff --git a/src/aliby/utils/imageViewer.py b/src/aliby/utils/imageViewer.py index 52a15ee763f776aba369276e34748b1becff76c0..f3331983dbe73755f8ed7f17f7fcd69e1f87a3fa 100644 --- a/src/aliby/utils/imageViewer.py +++ b/src/aliby/utils/imageViewer.py @@ -1,24 +1,7 @@ -""" -ImageViewer class, used to look at individual or multiple traps over time. - - -Example of usage: - -fpath = "/home/alan/data/16543/URA8_young018.h5" - -tile_id = 9 -trange = list(range(0, 10)) -ncols = 8 - -riv = remoteImageViewer(fpath) -riv.plot_labelled_trap(tile_id, trange, [0], ncols=ncols) - -""" - import re import typing as t from abc import ABC - +from pathlib import Path import h5py import matplotlib.pyplot as plt import numpy as np @@ -27,13 +10,9 @@ from skimage.morphology import dilation from agora.io.cells import Cells from agora.io.metadata import dispatch_metadata_parser -from aliby.io.image import ImageDir, ImageZarr, dispatch_image +from aliby.io.image import dispatch_image -try: - from aliby.io.omero import UnsafeImage as OImage -except ModuleNotFoundError: - print("Viewing available only for local files.") -from aliby.tile.tiler import Tiler, TilerParameters +from aliby.tile.tiler import Tiler from aliby.utils.plot import stretch_clip default_colours = { @@ -59,18 +38,23 @@ def custom_imshow(a, norm=None, cmap=None, *args, **kwargs): class BaseImageViewer(ABC): - def __init__(self, fpath): - self._fpath = fpath - self.attrs = dispatch_metadata_parser(fpath.parent) - self._logfiles_meta = {} - self.image_id = self.attrs.get("image_id") + """Base class with routines common to all ImageViewers.""" + + def __init__(self, h5file_path): + """Initialise from a Path to a h5 file.""" + self.h5file_path = h5file_path + self.logfiles_meta = dispatch_metadata_parser(h5file_path.parent) + self.image_id = self.logfiles_meta.get("image_id") if self.image_id is None: - with h5py.File(fpath, "r") as f: + with h5py.File(h5file_path, "r") as f: self.image_id = f.attrs.get("image_id") - assert self.image_id is not None, "No valid image_id found in metadata" + if self.image_id is None: + raise ("No valid image_id found in metadata.") + self.full = {} @property def shape(self): + """Return shape of image array.""" return self.tiler.image.shape @property @@ -87,193 +71,112 @@ class BaseImageViewer(ABC): """Find cell label at a given time point.""" return self.cells.labels_at_time(tp) - -class LocalImageViewer(BaseImageViewer): - """ - Generate figures from local files. - - File are either zarr or organised in directories. - TODO move common functionality from RemoteImageViewer to BaseImageViewer - """ - - def __init__(self, h5file_path: str, data_path: str): - super().__init__(h5file_path) - self._image_class = ( - ImageZarr if str(data_path).endswith(".zarr") else ImageDir - ) - image = ImageZarr(data_path) - self.tiler = Tiler( - image.data, - self._meta if hasattr(self, "_meta") else self._logfiles_meta, - TilerParameters.default(), - ) - self.cells = Cells.from_source(h5file_path) - - -class RemoteImageViewer(BaseImageViewer): - """ - This ImageViewer combines fetching remote images with tiling and outline display. - """ - - _credentials = ("host", "username", "password") - - def __init__( - self, - results_path: str, - server_info: t.Dict[str, str], + def find_channel_indices( + self, channels: t.Union[str, t.Collection[str]], guess=True ): - super().__init__(results_path) - self._server_info = server_info or { - k: self.attrs["parameters"]["general"][k] - for k in self._credentials - } - self._image_instance = OImage(self.image_id, **self._server_info) - self.tiler = Tiler.from_h5(self._image_instance, results_path) - self.cells = Cells.from_source(results_path) - - def random_valid_trap_tp( - self, - min_ncells: int = None, - min_consecutive_tps: int = None, - label_modulo: int = None, - ): - # Call Cells convenience function to pick a random trap and tp - # containing cells for x cells for y - return self.cells.random_valid_trap_tp( - min_ncells=min_ncells, - min_consecutive_tps=min_consecutive_tps, - ) - - def get_entire_position(self): - raise (NotImplementedError) - - def get_position_timelapse(self): - raise (NotImplementedError) - - @property - def full(self): - if not hasattr(self, "_full"): - self._full = {} - return self._full - - def get_tc(self, tp, channel=None, server_info=None): - server_info = server_info or self._server_info - channel = channel or self.tiler.ref_channel - - with self._image_class(self.image_id, **server_info) as image: - self.tiler.image = image.data - return self.tiler.load_image(tp, channel) - - def _find_channels(self, channels: str, guess: bool = True): + """Find index for particular channels.""" channels = channels or self.tiler.ref_channel if isinstance(channels, (int, str)): channels = [channels] if isinstance(channels[0], str): if guess: - channels = [self.tiler.channels.index(ch) for ch in channels] + indices = [self.tiler.channels.index(ch) for ch in channels] else: - channels = [ + indices = [ re.search(ch, tiler_channels) for ch in channels for tiler_channels in self.tiler.channels ] + return indices + else: + return channels + + def get_outlines_tiles_dict(self, tile_id, trange, channels): + """Get outlines and dict of tiles with channel indices as keys.""" + outlines = None + tile_dict = {} + for ch in self.find_channel_indices(channels): + outlines, tile_dict[ch] = self.get_outlines_tiles( + tile_id, trange, channels=[ch] + ) + return outlines, tile_dict + + def get_outlines_tiles( + self, + tile_id: int, + tps: t.Union[range, t.Collection[int]], + channels=None, + concatenate=True, + **kwargs, + ) -> t.Tuple[np.array]: + """ + Get masks uniquely labelled for each cell with the corresponding tiles. - return channels + Returns a list of masks, each an array with distinct masks for each cell, + and an array of tiles for the given channel. + """ + tile_dict = self.get_tiles(tps, channels=channels, **kwargs) + # get tiles of interest + tiles = [x[tile_id] for x in tile_dict.values()] + # get outlines for each time point + outlines = [ + self.cells.at_time(tp, kind="edgemask").get(tile_id, []) for tp in tps + ] + # get cell labels for each time point + cell_labels = [self.cells.labels_at_time(tp).get(tile_id, []) for tp in tps] + # generate one image with all cell outlines uniquely labelled per tile + labelled_outlines = [ + np.stack( + [outline * label for outline, label in zip(outlines_tp, labels_tp)] + ).max(axis=0) + if len(labels_tp) + else np.zeros_like(tiles[0]).astype(bool) + for outlines_tp, labels_tp in zip(outlines, cell_labels) + ] + if concatenate: + # concatenate to allow potential image processing + labelled_outlines = np.concatenate(labelled_outlines, axis=1) + tiles = np.concatenate(tiles, axis=1) + return labelled_outlines, tiles - def get_pos_timepoints( + def get_tiles( self, tps: t.Union[int, t.Collection[int]], - channels: t.Union[str, t.Collection[str]] = None, + channels: None, z: int = None, - server_info=None, ): + """Get dict with time points as keys and all available tiles as values.""" if tps and not isinstance(tps, t.Collection): tps = range(tps) - - # TODO add support for multiple channels or refactor - if channels and not isinstance(channels, t.Collection): - channels = [channels] - if z is None: z = 0 - - server_info = server_info or self._server_info - channels = 0 or self._find_channels(channels) z = z or self.tiler.ref_z - - ch_tps = [(channels[0], tp) for tp in tps] - - image = self._image_instance - self.tiler.image = image.data + channel_indices = self.find_channel_indices(channels) + ch_tps = [(channel_indices[0], tp) for tp in tps] for ch, tp in ch_tps: if (ch, tp) not in self.full: self.full[(ch, tp)] = self.tiler.get_tiles_timepoint( tp, channels=[ch], z=[z] )[:, 0, 0, z, ...] - requested_trap = {tp: self.full[(ch, tp)] for ch, tp in ch_tps} - - return requested_trap - - def get_labelled_trap( - self, - tile_id: int, - tps: t.Union[range, t.Collection[int]], - channels=None, - concatenate=True, - **kwargs, - ) -> t.Tuple[np.array]: - """ - Core method to fetch traps and labels together - """ - imgs = self.get_pos_timepoints(tps, channels=channels, **kwargs) - imgs_list = [x[tile_id] for x in imgs.values()] - outlines = [ - self.cells.at_time(tp, kind="edgemask").get(tile_id, []) - for tp in tps - ] - lbls = [self.cells.labels_at_time(tp).get(tile_id, []) for tp in tps] - lbld_outlines = [ - np.stack([mask * lbl for mask, lbl in zip(maskset, lblset)]).max( - axis=0 - ) - if len(lblset) - else np.zeros_like(imgs_list[0]).astype(bool) - for maskset, lblset in zip(outlines, lbls) - ] - if concatenate: - lbld_outlines = np.concatenate(lbld_outlines, axis=1) - imgs_list = np.concatenate(imgs_list, axis=1) - return lbld_outlines, imgs_list - - def get_images(self, tile_id, trange, channels, **kwargs): - """ - Wrapper to fetch images - """ - out = None - imgs = {} - - for ch in self._find_channels(channels): - out, imgs[ch] = self.get_labelled_trap( - tile_id, trange, channels=[ch], **kwargs - ) - return out, imgs + tile_dict = {tp: self.full[(ch, tp)] for ch, tp in ch_tps} + return tile_dict def plot_labelled_trap( self, - tile_id: int, + tile_id, channels, trange: t.Union[range, t.Collection[int]], remove_axis: bool = False, - savefile: str = None, skip_outlines: bool = False, - norm: str = None, + norm=True, ncols: int = None, local_colours: bool = True, img_plot_kwargs: dict = {}, lbl_plot_kwargs: dict = {"alpha": 0.8}, **kwargs, ): - """Wrapper to plot time-lapses of individual traps + """ + Plot time-lapses of individual tiles. Use Cells and Tiler to generate images of cells with their resulting outlines. @@ -282,15 +185,13 @@ class RemoteImageViewer(BaseImageViewer): ---------- tile_id : int Identifier of trap - channels : Union[str, int] + channel : Union[str, int] Channels to use trange : t.Union[range, t.Collection[int]] Range or collection indicating the time-points to use. remove_axis : bool None, "off", or "x". Determines whether to remove the x-axis, both axes or none. - savefile : str - Saves file to a location. skip_outlines : bool Do not add overlay with outlines norm : str @@ -304,68 +205,40 @@ class RemoteImageViewer(BaseImageViewer): Arguments to pass to plt.imshow used for images. lbl_plot_kwargs : dict Keyword arguments to pass to label plots. - **kwargs : dict - Additional keyword arguments passed to ImageViewer.get_images. - - Examples - -------- - FIXME: Add docs. - """ + # set up for plotting if ncols is None: ncols = len(trange) nrows = int(np.ceil(len(trange) / ncols)) width = self.tiler.tile_size * ncols - - out, images = self.get_images(tile_id, trange, channels, **kwargs) - - # dilation makes outlines easier to see - out = dilation(out).astype(float) - out[out == 0] = np.nan - + outlines, tiles_dict = self.get_outlines_tiles_dict(tile_id, trange, channels) channel_labels = [ - self.tiler.channels[ch] if isinstance(ch, int) else ch - for ch in channels + self.tiler.channels[ch] if isinstance(ch, int) else ch for ch in channels + ] + # dilate to make outlines easier to see + outlines = dilation(outlines).astype(float) + outlines[outlines == 0] = np.nan + # split concatenated tiles into one tile per time point in a row + tiles = [ + into_image_time_series(tile, width, nrows) for tile in tiles_dict.values() ] - - assert not norm or norm in ( - "l1", - "l2", - "max", - ), "Invalid norm argument." - - if norm and norm in ("l1", "l2", "max"): - images = {k: stretch_clip(v) for k, v in images.items()} - - images = [concat_pad(img, width, nrows) for img in images.values()] # TODO convert to RGB to draw fluorescence with colour - tiled_imgs = {} - tiled_imgs["img"] = np.concatenate(images, axis=0) - tiled_imgs["cell_labels"] = np.concatenate( - [concat_pad(out, width, nrows) for _ in images], axis=0 - ) - - custom_imshow( - tiled_imgs["img"], - **img_plot_kwargs, + res = {} + # concatenate different channels vertically for display + res["tiles"] = np.concatenate(tiles, axis=0) + res["cell_labels"] = np.concatenate( + [into_image_time_series(outlines, width, nrows) for _ in tiles], axis=0 ) + custom_imshow(res["tiles"], **img_plot_kwargs) custom_imshow( - tiled_imgs["cell_labels"], - cmap=sns.color_palette("Paired", as_cmap=True), - **lbl_plot_kwargs, + res["cell_labels"], cmap=default_colours["cell_label"], **lbl_plot_kwargs ) - if remove_axis is True: plt.axis("off") elif remove_axis == "x": plt.tick_params( - axis="x", - which="both", - bottom=False, - top=False, - labelbottom=False, + axis="x", which="both", bottom=False, top=False, labelbottom=False ) - if remove_axis != "True": plt.yticks( ticks=[ @@ -375,39 +248,63 @@ class RemoteImageViewer(BaseImageViewer): ], labels=channel_labels, ) - if not remove_axis: xlabels = ( - ["+ {} ".format(i) for i in range(ncols)] - if nrows > 1 - else list(trange) + ["+ {} ".format(i) for i in range(ncols)] if nrows > 1 else list(trange) ) plt.xlabel("Time-point") - plt.xticks( ticks=[self.tiler.tile_size * (i + 0.5) for i in range(ncols)], labels=xlabels, ) + if not np.any(outlines): + print("ImageViewer:Warning: No cell outlines found.") + plt.tight_layout() + plt.show(block=False) - if not np.any(out): - print("ImageViewer:Warning:No cell outlines found") - - if savefile: - plt.savefig(savefile, bbox_inches="tight", dpi=300) - plt.close() - else: - plt.show() - -def concat_pad(a: np.array, width, nrows): +class LocalImageViewer(BaseImageViewer): """ - Melt an array into having multiple blocks as rows + View images from local files. + + File are either zarr or organised in directories. """ + + def __init__(self, h5file: str, image_direc: str): + """Initialise using a h5file and a local directory of images.""" + h5file_path = Path(h5file) + image_direc_path = Path(image_direc) + super().__init__(h5file_path) + with dispatch_image(image_direc_path)(image_direc_path) as image: + self.tiler = Tiler.from_h5(image, h5file_path) + self.cells = Cells.from_source(h5file_path) + + +class RemoteImageViewer(BaseImageViewer): + """Fetching remote images with tiling and outline display.""" + + credentials = ("host", "username", "password") + + def __init__(self, h5file: str, server_info: t.Dict[str, str]): + """Initialise using a h5file and importing aliby.io.omero.""" + from aliby.io.omero import UnsafeImage as OImage + + h5file_path = Path(h5file) + super().__init__(h5file_path) + self.server_info = server_info or { + k: self.attrs["parameters"]["general"][k] for k in self.credentials + } + image = OImage(self.image_id, **self._server_info) + self.tiler = Tiler.from_h5(image, h5file_path) + self.cells = Cells.from_source(h5file_path) + + +def into_image_time_series(a: np.array, width, nrows): + """Split into sub-arrays and then concatenate into one.""" return np.concatenate( np.array_split( np.pad( a, - # ((0, 0), (0, width - (a.shape[1] % width))), ((0, 0), (0, a.shape[1] % width)), constant_values=np.nan, ),