Skip to content
Snippets Groups Projects
Commit df4db391 authored by Alán Muñoz's avatar Alán Muñoz
Browse files

add support for multiple channels

parent a713bb56
No related branches found
No related tags found
No related merge requests found
......@@ -11,10 +11,13 @@ trange = list(range(0, 30))
ncols = 8
riv = remoteImageViewer(fpath)
riv.plot_labeled_traps(trap_id, trange, ncols)
riv.plot_labelled_traps(trap_id, trange, ncols)
"""
import re
import typing as t
import matplotlib.pyplot as plt
import numpy as np
import yaml
......@@ -25,6 +28,30 @@ from skimage.morphology import dilation
from aliby.io.image import Image as OImage
from aliby.tile.tiler import Tiler
from aliby.tile.traps import stretch_image
default_colours = {
"Brightfield": "Greys_r",
"GFP": "Greens_r",
"mCherry": "Reds_r",
"cell_label": "Set1",
}
def custom_imshow(a, norm=None, cmap=None, *args, **kwargs):
"""
Wrapper on plt.imshow function.
"""
if cmap is None:
cmap = "Greys_r"
return plt.imshow(
a,
*args,
cmap=cmap,
interpolation=None,
interpolation_stage="rgba",
**kwargs,
)
class localImageViewer:
......@@ -96,19 +123,16 @@ class remoteImageViewer:
# Print cell label at a given time-point
return self.cells.labels_at_time(tp)
def nonempty_traps_at_time(self, tp: int):
# Print cell label at a given time-point
return [
trap_id
for trap_id, labels in self.labels_at_time(tp).items()
if len(labels)
]
def nonempty_tps_in_trap(self, trap_id: int):
# Print cell label at a given time-point
return self.cells.labels_at_time(trap_id)
def random_valid_trap_tp(
self, min_ncells: int = None, min_consecutive_tps: int = None
):
# Call Cells convenience function to pick a random trap and tp
# containing cells for x cells for y
return self.cells.random_valid_trap_tp(
min_ncells=min_ncells, min_consecutive_tps=min_consecutive_tps
)
def get_position(self):
def get_entire_position(self):
raise (NotImplementedError)
def get_position_timelapse(self):
......@@ -128,48 +152,76 @@ class remoteImageViewer:
self.tiler.image = image.data
return self.tiler.get_tc(tp, channel)
def _find_channels(self, channels):
def _find_channels(self, channels: str, guess: bool = True):
channels = channels or self.tiler.ref_channel
if isinstance(channels, int):
if isinstance(channels, (int, str)):
channels = [channels]
elif isinstance(channels, list) and isinstance(channels[0], str):
channels = [self.tiler.channels.index(ch) for ch in channels]
if isinstance(channels[0], str):
if guess:
channels = [self.tiler.channels.index(ch) for ch in channels]
else:
channels = [
re.search(ch, tiler_channels)
for ch in channels
for tiler_channels in self.tiler.channels
]
return channels
def get_trap_timepoints(
self, trap_id, tps, channels=None, z=None, server_info=None
def get_pos_timepoints(
self,
tps: t.Union[int, t.Collection[int]],
channels: t.Union[str, t.Collection[str]] = None,
z: int = None,
server_info=None,
**kwargs,
):
if tps and not isinstance(tps, t.Collection):
tps = range(tps)
# TODO add support for multiple channels or refactor
if channels and not isinstance(channels, t.Collection):
channels = [channels]
# if z and isinstance(z, t.Collection):
# z = list(z)
if z is None:
z = 0
server_info = server_info or self.server_info
channels = self._find_channels(channels)
channels = 0 or self._find_channels(channels)
z = z or self.tiler.ref_z
ch_tps = set([(channels[0], tp) for tp in tps])
ch_tps = [(channels[0], tp) for tp in tps]
with OImage(self.image_id, **server_info) as image:
self.tiler.image = image.data
if ch_tps.difference(self.full.keys()):
tps = set(tps).difference(self.full.keys())
for ch, tp in ch_tps:
if z is 0:
self.full[(ch, tp)] = self.tiler.get_traps_timepoint(
tp, channels=[ch], z=[z]
)[:, 0, 0, ..., 0]
else:
self.full[(ch, tp)] = self.tiler.get_traps_timepoint(
tp, channels=[ch], z=[z]
)[:, 0, 0, ..., z]
# if ch_tps.difference(self.full.keys()):
# tps = set(tps).difference(self.full.keys())
for ch, tp in ch_tps:
if (ch, tp) not in self.full:
self.full[(ch, tp)] = self.tiler.get_traps_timepoint(
tp, channels=[ch], z=[z]
)[:, 0, 0, ..., z]
requested_trap = {tp: self.full[(ch, tp)] for ch, tp in ch_tps}
return requested_trap
def get_labeled_trap(self, trap_id, tps, **kwargs):
imgs = self.get_trap_timepoints(trap_id, tps, **kwargs)
def get_labelled_trap(
self,
trap_id: int,
tps: t.Union[range, t.Collection[int]],
**kwargs,
) -> t.Tuple[np.array]:
imgs = self.get_pos_timepoints(tps, **kwargs)
imgs_list = [x[trap_id] for x in imgs.values()]
outlines = [
self.cells.at_time(tp, kind="edgemask").get(trap_id, [])
for tp in tps
]
lbls = [self.cells.labels_at_time(tp).get(trap_id, []) for tp in tps]
lbls = [
self.cells.labels_at_time(tp - 1).get(trap_id, []) for tp in tps
]
lbld_outlines = [
np.dstack([mask * lbl for mask, lbl in zip(maskset, lblset)]).max(
axis=2
......@@ -190,119 +242,158 @@ class remoteImageViewer:
imgs = {}
for ch in self._find_channels(channels):
out, imgs[ch] = self.get_labeled_trap(
out, imgs[ch] = self.get_labelled_trap(
trap_id, trange, channels=[ch], **kwargs
)
return out, imgs
def plot_labeled_zstacks(
self, trap_id, channels, trange, z=None, **kwargs
# def plot_labelled_zstacks(
# self, trap_id, channels, trange, z=None, **kwargs
# ):
# # if z is None:
# # z =
# out, images = self.get_images(trap_id, trange, channels, z=z, **kwargs)
def plot_labelled_trap(
self,
trap_id: int,
channels,
trange: t.Union[range, t.Collection[int]],
remove_axis: bool = False,
savefile: str = None,
skip_outlines: bool = False,
norm: str = None,
ncols: int = None,
img_plot_kwargs: dict = {},
lbl_plot_kwargs: dict = {},
**kwargs,
):
# if z is None:
# z =
out, images = self.get_images(trap_id, trange, channels, z=z, **kwargs)
if ncols is None:
ncols = len(trange)
nrows = int(np.ceil(len(trange) / ncols))
width = self.tiler.tile_size * ncols
def plot_labeled_channelrows(self, trap_id, channels, trange, **kwargs):
out, images = self.get_images(trap_id, trange, channels, **kwargs)
# dilation makes outlines easier to see
out = dilation(out).astype(float)
out[out == 0] = np.nan
img_set = np.concatenate([v for v in imgs.values()], axis=0)
tiled_out = np.tile(out, (len(imgs), 1))
plt.imshow(
img_set,
interpolation=None,
cmap="Greys_r",
channel_labels = [
self.tiler.channels[ch] if isinstance(ch, int) else ch
for ch in channels
]
assert not norm or norm in (
"l1",
"l2",
"max",
), "Invalid norm argument."
if norm and norm in ("l1", "l2", "max"):
images = {k: stretch_image(v) for k, v in images.items()}
# images = [concat_pad(img, width, nrows) for img in images.values()]
images = [concat_pad(img, width, nrows) for img in images.values()]
tiled_imgs = {}
tiled_imgs["img"] = np.concatenate(images, axis=0)
tiled_imgs["cell_labels"] = np.concatenate(
[concat_pad(out, width, nrows) for _ in images], axis=0
)
plt.imshow(
tiled_out,
custom_imshow(
tiled_imgs["img"],
**img_plot_kwargs,
)
custom_imshow(
tiled_imgs["cell_labels"],
cmap="Set1",
interpolation=None,
**lbl_plot_kwargs,
)
plt.yticks(
ticks=[
self.tiler.tile_size * (i + 0.5) for i in range(len(channels))
],
labels=[
self.tiler.channels[ch] if isinstance(ch, int) else ch
for ch in channels
self.tiler.tile_size * (i + 0.5)
+ (i * self.tiler.tile_size * nrows)
for i in range(len(channels))
],
labels=channel_labels,
)
plt.xticks(
ticks=[
self.tiler.tile_size * (i + 0.5) for i in range(len(trange))
],
labels=[t for t in trange],
ticks=[self.tiler.tile_size * (i + 0.5) for i in range(ncols)],
labels=["+ {} ".format(i) for i in range(ncols)],
)
plt.xlabel("Additional time-points")
plt.show()
def plot_labeled_traps(
self,
trap_id,
trange,
ncols,
remove_axis=False,
savefile=False,
skip_outlines=False,
**kwargs,
):
"""
Wrapper to plot a single trap over time
Parameters
---------
:trap_id: int trap identification
:trange: list list of time points to fetch
"""
nrows = len(trange) // ncols
width = self.tiler.tile_size * ncols
out, img = self.get_labeled_trap(trap_id, trange, **kwargs)
# dilation makes outlines easier to see
out = dilation(out).astype(float)
out[out == 0] = np.nan
def concat_pad(array):
return np.concatenate(
np.array_split(
np.pad(
array,
((0, 0), (0, array.shape[1] % width)),
constant_values=np.nan,
),
nrows,
axis=1,
)
)
plt.imshow(
concat_pad(img),
interpolation=None,
cmap="Greys_r",
# def plot_labelled_trap(
# self,
# trap_id: int,
# trange: t.Union[range, t.Collection[int]],
# ncols: int,
# remove_axis: bool = False,
# savefile: str = None,
# skip_outlines: bool = False,
# **kwargs,
# ):
# """
# Wrapper to plot a single trap over time
# Parameters
# ---------
# :trap_id: int trap identification
# :trange: Collection or Range list of time points to fetch
# """
# nrows = len(trange) // ncols
# width = self.tiler.tile_size * ncols
# out, img = self.get_labelled_trap(trap_id, trange, **kwargs)
# # dilation makes outlines easier to see
# out = dilation(out).astype(float)
# out[out == 0] = np.nan
# # interpolation_kwargs = {""}
# custom_imshow(
# concat_pad(img),
# cmap="Greys_r",
# )
# if not skip_outlines:
# custom_imshow(
# concat_pad(out),
# cmap="Set1",
# )
# bbox_inches = None
# if remove_axis:
# plt.axis("off")
# bbox_inches = "tight"
# else:
# plt.yticks(
# ticks=[self.tiler.tile_size * (i + 0.5) for i in range(nrows)],
# labels=[trange[0] + ncols * i for i in range(nrows)],
# )
# if not savefile:
# plt.show()
# else:
# if np.any(out):
# plt.savefig(savefile, bbox_inches=bbox_inches)
def concat_pad(a: np.array, width, nrows):
"""
Melt an array into having multiple blocks as rows
"""
return np.concatenate(
np.array_split(
np.pad(
a,
# ((0, 0), (0, width - (a.shape[1] % width))),
((0, 0), (0, a.shape[1] % width)),
constant_values=np.nan,
),
nrows,
axis=1,
)
if not skip_outlines:
plt.imshow(
concat_pad(out),
# concat_pad(mask),
cmap="Set1",
interpolation=None,
)
bbox_inches = None
if remove_axis:
plt.axis("off")
bbox_inches = "tight"
else:
plt.yticks(
ticks=[self.tiler.tile_size * (i + 0.5) for i in range(nrows)],
labels=[trange[0] + ncols * i for i in range(nrows)],
)
if not savefile:
plt.show()
else:
if np.any(out):
plt.savefig(savefile, bbox_inches=bbox_inches)
)
......@@ -24,13 +24,14 @@ requests-toolbelt = "^0.9.1"
h5py = "2.10" # I/O into files
imageio = "2.8.0"
omero-py = ">=5.6.2" # contact omero server
zeroc-ice = "3.6.5" # pin networking interface
aliby-agora = "^0.2.29"
aliby-agora = "^0.2.30"
aliby-baby = "^0.1.10"
aliby-post = "^0.1.27"
p-tqdm = "^1.3.3" # Parallel progress bars
xmltodict = "^0.13.0" # read ome-tiff metadata
protobuf = "<=3.20.1" # For pytest to work
zeroc-ice = {version="3.6.5"} # networking interface, slow to build
# zeroc-ice = {version="3.6.5", optional=true} # To be set as optional in the future
[tool.poetry.dev-dependencies]
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment