From df4db391887ee4bd3e3f42ac1de3f39b5c4cecb7 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Al=C3=A1n=20Mu=C3=B1oz?= <amuoz@ed.ac.uk> Date: Fri, 24 Jun 2022 17:12:53 +0100 Subject: [PATCH] add support for multiple channels --- aliby/utils/imageViewer.py | 345 +++++++++++++++++++++++-------------- pyproject.toml | 5 +- 2 files changed, 221 insertions(+), 129 deletions(-) diff --git a/aliby/utils/imageViewer.py b/aliby/utils/imageViewer.py index 5d5f622b..6a60ad18 100644 --- a/aliby/utils/imageViewer.py +++ b/aliby/utils/imageViewer.py @@ -11,10 +11,13 @@ trange = list(range(0, 30)) ncols = 8 riv = remoteImageViewer(fpath) -riv.plot_labeled_traps(trap_id, trange, ncols) +riv.plot_labelled_traps(trap_id, trange, ncols) """ +import re +import typing as t + import matplotlib.pyplot as plt import numpy as np import yaml @@ -25,6 +28,30 @@ from skimage.morphology import dilation from aliby.io.image import Image as OImage from aliby.tile.tiler import Tiler +from aliby.tile.traps import stretch_image + +default_colours = { + "Brightfield": "Greys_r", + "GFP": "Greens_r", + "mCherry": "Reds_r", + "cell_label": "Set1", +} + + +def custom_imshow(a, norm=None, cmap=None, *args, **kwargs): + """ + Wrapper on plt.imshow function. + """ + if cmap is None: + cmap = "Greys_r" + return plt.imshow( + a, + *args, + cmap=cmap, + interpolation=None, + interpolation_stage="rgba", + **kwargs, + ) class localImageViewer: @@ -96,19 +123,16 @@ class remoteImageViewer: # Print cell label at a given time-point return self.cells.labels_at_time(tp) - def nonempty_traps_at_time(self, tp: int): - # Print cell label at a given time-point - return [ - trap_id - for trap_id, labels in self.labels_at_time(tp).items() - if len(labels) - ] - - def nonempty_tps_in_trap(self, trap_id: int): - # Print cell label at a given time-point - return self.cells.labels_at_time(trap_id) + def random_valid_trap_tp( + self, min_ncells: int = None, min_consecutive_tps: int = None + ): + # Call Cells convenience function to pick a random trap and tp + # containing cells for x cells for y + return self.cells.random_valid_trap_tp( + min_ncells=min_ncells, min_consecutive_tps=min_consecutive_tps + ) - def get_position(self): + def get_entire_position(self): raise (NotImplementedError) def get_position_timelapse(self): @@ -128,48 +152,76 @@ class remoteImageViewer: self.tiler.image = image.data return self.tiler.get_tc(tp, channel) - def _find_channels(self, channels): + def _find_channels(self, channels: str, guess: bool = True): channels = channels or self.tiler.ref_channel - if isinstance(channels, int): + if isinstance(channels, (int, str)): channels = [channels] - elif isinstance(channels, list) and isinstance(channels[0], str): - channels = [self.tiler.channels.index(ch) for ch in channels] + if isinstance(channels[0], str): + if guess: + channels = [self.tiler.channels.index(ch) for ch in channels] + else: + channels = [ + re.search(ch, tiler_channels) + for ch in channels + for tiler_channels in self.tiler.channels + ] return channels - def get_trap_timepoints( - self, trap_id, tps, channels=None, z=None, server_info=None + def get_pos_timepoints( + self, + tps: t.Union[int, t.Collection[int]], + channels: t.Union[str, t.Collection[str]] = None, + z: int = None, + server_info=None, + **kwargs, ): + + if tps and not isinstance(tps, t.Collection): + tps = range(tps) + + # TODO add support for multiple channels or refactor + if channels and not isinstance(channels, t.Collection): + channels = [channels] + + # if z and isinstance(z, t.Collection): + # z = list(z) + if z is None: + z = 0 + server_info = server_info or self.server_info - channels = self._find_channels(channels) + channels = 0 or self._find_channels(channels) z = z or self.tiler.ref_z - ch_tps = set([(channels[0], tp) for tp in tps]) + ch_tps = [(channels[0], tp) for tp in tps] with OImage(self.image_id, **server_info) as image: self.tiler.image = image.data - if ch_tps.difference(self.full.keys()): - tps = set(tps).difference(self.full.keys()) - for ch, tp in ch_tps: - if z is 0: - self.full[(ch, tp)] = self.tiler.get_traps_timepoint( - tp, channels=[ch], z=[z] - )[:, 0, 0, ..., 0] - else: - self.full[(ch, tp)] = self.tiler.get_traps_timepoint( - tp, channels=[ch], z=[z] - )[:, 0, 0, ..., z] + # if ch_tps.difference(self.full.keys()): + # tps = set(tps).difference(self.full.keys()) + for ch, tp in ch_tps: + if (ch, tp) not in self.full: + self.full[(ch, tp)] = self.tiler.get_traps_timepoint( + tp, channels=[ch], z=[z] + )[:, 0, 0, ..., z] requested_trap = {tp: self.full[(ch, tp)] for ch, tp in ch_tps} return requested_trap - def get_labeled_trap(self, trap_id, tps, **kwargs): - imgs = self.get_trap_timepoints(trap_id, tps, **kwargs) + def get_labelled_trap( + self, + trap_id: int, + tps: t.Union[range, t.Collection[int]], + **kwargs, + ) -> t.Tuple[np.array]: + imgs = self.get_pos_timepoints(tps, **kwargs) imgs_list = [x[trap_id] for x in imgs.values()] outlines = [ self.cells.at_time(tp, kind="edgemask").get(trap_id, []) for tp in tps ] - lbls = [self.cells.labels_at_time(tp).get(trap_id, []) for tp in tps] + lbls = [ + self.cells.labels_at_time(tp - 1).get(trap_id, []) for tp in tps + ] lbld_outlines = [ np.dstack([mask * lbl for mask, lbl in zip(maskset, lblset)]).max( axis=2 @@ -190,119 +242,158 @@ class remoteImageViewer: imgs = {} for ch in self._find_channels(channels): - out, imgs[ch] = self.get_labeled_trap( + out, imgs[ch] = self.get_labelled_trap( trap_id, trange, channels=[ch], **kwargs ) return out, imgs - def plot_labeled_zstacks( - self, trap_id, channels, trange, z=None, **kwargs + # def plot_labelled_zstacks( + # self, trap_id, channels, trange, z=None, **kwargs + # ): + # # if z is None: + # # z = + # out, images = self.get_images(trap_id, trange, channels, z=z, **kwargs) + + def plot_labelled_trap( + self, + trap_id: int, + channels, + trange: t.Union[range, t.Collection[int]], + remove_axis: bool = False, + savefile: str = None, + skip_outlines: bool = False, + norm: str = None, + ncols: int = None, + img_plot_kwargs: dict = {}, + lbl_plot_kwargs: dict = {}, + **kwargs, ): - # if z is None: - # z = - out, images = self.get_images(trap_id, trange, channels, z=z, **kwargs) + if ncols is None: + ncols = len(trange) + nrows = int(np.ceil(len(trange) / ncols)) + width = self.tiler.tile_size * ncols - def plot_labeled_channelrows(self, trap_id, channels, trange, **kwargs): out, images = self.get_images(trap_id, trange, channels, **kwargs) # dilation makes outlines easier to see out = dilation(out).astype(float) out[out == 0] = np.nan - img_set = np.concatenate([v for v in imgs.values()], axis=0) - tiled_out = np.tile(out, (len(imgs), 1)) - plt.imshow( - img_set, - interpolation=None, - cmap="Greys_r", + channel_labels = [ + self.tiler.channels[ch] if isinstance(ch, int) else ch + for ch in channels + ] + + assert not norm or norm in ( + "l1", + "l2", + "max", + ), "Invalid norm argument." + + if norm and norm in ("l1", "l2", "max"): + images = {k: stretch_image(v) for k, v in images.items()} + + # images = [concat_pad(img, width, nrows) for img in images.values()] + images = [concat_pad(img, width, nrows) for img in images.values()] + tiled_imgs = {} + tiled_imgs["img"] = np.concatenate(images, axis=0) + tiled_imgs["cell_labels"] = np.concatenate( + [concat_pad(out, width, nrows) for _ in images], axis=0 ) - plt.imshow( - tiled_out, + + custom_imshow( + tiled_imgs["img"], + **img_plot_kwargs, + ) + custom_imshow( + tiled_imgs["cell_labels"], cmap="Set1", - interpolation=None, + **lbl_plot_kwargs, ) plt.yticks( ticks=[ - self.tiler.tile_size * (i + 0.5) for i in range(len(channels)) - ], - labels=[ - self.tiler.channels[ch] if isinstance(ch, int) else ch - for ch in channels + self.tiler.tile_size * (i + 0.5) + + (i * self.tiler.tile_size * nrows) + for i in range(len(channels)) ], + labels=channel_labels, ) plt.xticks( - ticks=[ - self.tiler.tile_size * (i + 0.5) for i in range(len(trange)) - ], - labels=[t for t in trange], + ticks=[self.tiler.tile_size * (i + 0.5) for i in range(ncols)], + labels=["+ {} ".format(i) for i in range(ncols)], ) + plt.xlabel("Additional time-points") plt.show() - def plot_labeled_traps( - self, - trap_id, - trange, - ncols, - remove_axis=False, - savefile=False, - skip_outlines=False, - **kwargs, - ): - """ - Wrapper to plot a single trap over time - - Parameters - --------- - :trap_id: int trap identification - :trange: list list of time points to fetch - """ - nrows = len(trange) // ncols - width = self.tiler.tile_size * ncols - out, img = self.get_labeled_trap(trap_id, trange, **kwargs) - - # dilation makes outlines easier to see - out = dilation(out).astype(float) - out[out == 0] = np.nan - - def concat_pad(array): - return np.concatenate( - np.array_split( - np.pad( - array, - ((0, 0), (0, array.shape[1] % width)), - constant_values=np.nan, - ), - nrows, - axis=1, - ) - ) - - plt.imshow( - concat_pad(img), - interpolation=None, - cmap="Greys_r", + # def plot_labelled_trap( + # self, + # trap_id: int, + # trange: t.Union[range, t.Collection[int]], + # ncols: int, + # remove_axis: bool = False, + # savefile: str = None, + # skip_outlines: bool = False, + # **kwargs, + # ): + # """ + # Wrapper to plot a single trap over time + + # Parameters + # --------- + # :trap_id: int trap identification + # :trange: Collection or Range list of time points to fetch + # """ + # nrows = len(trange) // ncols + # width = self.tiler.tile_size * ncols + # out, img = self.get_labelled_trap(trap_id, trange, **kwargs) + + # # dilation makes outlines easier to see + # out = dilation(out).astype(float) + # out[out == 0] = np.nan + + # # interpolation_kwargs = {""} + + # custom_imshow( + # concat_pad(img), + # cmap="Greys_r", + # ) + # if not skip_outlines: + # custom_imshow( + # concat_pad(out), + # cmap="Set1", + # ) + + # bbox_inches = None + # if remove_axis: + # plt.axis("off") + # bbox_inches = "tight" + + # else: + # plt.yticks( + # ticks=[self.tiler.tile_size * (i + 0.5) for i in range(nrows)], + # labels=[trange[0] + ncols * i for i in range(nrows)], + # ) + + # if not savefile: + # plt.show() + # else: + # if np.any(out): + # plt.savefig(savefile, bbox_inches=bbox_inches) + + +def concat_pad(a: np.array, width, nrows): + """ + Melt an array into having multiple blocks as rows + """ + return np.concatenate( + np.array_split( + np.pad( + a, + # ((0, 0), (0, width - (a.shape[1] % width))), + ((0, 0), (0, a.shape[1] % width)), + constant_values=np.nan, + ), + nrows, + axis=1, ) - if not skip_outlines: - plt.imshow( - concat_pad(out), - # concat_pad(mask), - cmap="Set1", - interpolation=None, - ) - - bbox_inches = None - if remove_axis: - plt.axis("off") - bbox_inches = "tight" - - else: - plt.yticks( - ticks=[self.tiler.tile_size * (i + 0.5) for i in range(nrows)], - labels=[trange[0] + ncols * i for i in range(nrows)], - ) - - if not savefile: - plt.show() - else: - if np.any(out): - plt.savefig(savefile, bbox_inches=bbox_inches) + ) diff --git a/pyproject.toml b/pyproject.toml index f65420af..7bcbff7c 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -24,13 +24,14 @@ requests-toolbelt = "^0.9.1" h5py = "2.10" # I/O into files imageio = "2.8.0" omero-py = ">=5.6.2" # contact omero server -zeroc-ice = "3.6.5" # pin networking interface -aliby-agora = "^0.2.29" +aliby-agora = "^0.2.30" aliby-baby = "^0.1.10" aliby-post = "^0.1.27" p-tqdm = "^1.3.3" # Parallel progress bars xmltodict = "^0.13.0" # read ome-tiff metadata protobuf = "<=3.20.1" # For pytest to work +zeroc-ice = {version="3.6.5"} # networking interface, slow to build +# zeroc-ice = {version="3.6.5", optional=true} # To be set as optional in the future [tool.poetry.dev-dependencies] -- GitLab