Skip to content
Snippets Groups Projects
Commit 25c73bb5 authored by pswain's avatar pswain
Browse files

feature(imageviewer): co-viewing multiple traps

parent 9ff355a4
No related branches found
No related tags found
No related merge requests found
...@@ -161,7 +161,20 @@ class ImageViewer: ...@@ -161,7 +161,20 @@ class ImageViewer:
return tile_dict return tile_dict
def get_data_for_viewing(self, trap_id, tps, channels_to_skip): def get_data_for_viewing(self, trap_id, tps, channels_to_skip):
"""Get images and outlines as multidimensional arrays for Napari.""" """
Get images and outlines as multidimensional arrays for Napari.
Returns
-------
ts_images: array
An array of images ordered by time points, channels, z stack,
x and y values.
ts_labels: array
An array of images of cell outlines ordered by time points,
z stack, x and y values.
channels: list of str
List of channel names.
"""
# get outlines and tiles # get outlines and tiles
outlines = self.get_outlines(trap_id, tps) outlines = self.get_outlines(trap_id, tps)
tiles_dict = self.get_tiles(trap_id, tps, channels_to_skip) tiles_dict = self.get_tiles(trap_id, tps, channels_to_skip)
...@@ -183,7 +196,65 @@ class ImageViewer: ...@@ -183,7 +196,65 @@ class ImageViewer:
ts_labels[tp_index, 0, ...] = outlines[tp_index] ts_labels[tp_index, 0, ...] = outlines[tp_index]
return ts_images, ts_labels, channels return ts_images, ts_labels, channels
def view(self, trap_id, tpt_end=10, tpt_start=0, channels_to_skip=None): def concat(self, arrangement, image_dict, axis):
"""
Concat dict of images into one image array.
Following the vertical layout in arrangment.
"""
# concatenate vertically into a list
images_v = [
np.concatenate(
[image_dict[trap_id] for trap_id in v_stack], axis=axis
)
for v_stack in arrangement
]
# concatenate horizontally into an array
if len(arrangement) > 1:
images = np.concatenate(
[image_v_stack for image_v_stack in images_v], axis=axis + 1
)
else:
images = images_v[0]
return images
def combine_tiles(self, ts_images_dict, ts_labels_dict, no_vertical_tiles):
"""Combine tiles into one image first vertically then horizontally."""
no_tiles = len(ts_images_dict)
trap_ids = list(ts_images_dict.keys())
# find how tiles will be arranged in the concatenated image
if no_tiles < no_vertical_tiles:
arrangement = [trap_ids]
else:
arrangement = [
trap_ids[i : min(i + no_vertical_tiles, no_tiles)]
for i in range(
0,
int(
np.floor(no_tiles / no_vertical_tiles)
* no_vertical_tiles
),
no_vertical_tiles,
)
]
if no_tiles > np.array(arrangement).size:
print(
f"Too many tiles: {no_tiles - np.array(arrangement).size}"
" will be ignored."
)
# concatenate images
ts_images = self.concat(arrangement, ts_images_dict, axis=3)
ts_labels = self.concat(arrangement, ts_labels_dict, axis=2)
return ts_images, ts_labels
def view(
self,
trap_ids,
tpt_end=10,
tpt_start=0,
channels_to_skip=None,
no_vertical_tiles=3,
):
""" """
Use Napari to view all channels and outlines for a particular trap. Use Napari to view all channels and outlines for a particular trap.
...@@ -200,10 +271,22 @@ class ImageViewer: ...@@ -200,10 +271,22 @@ class ImageViewer:
will be slow. will be slow.
""" """
tps = np.arange(tpt_start, tpt_end + 1) tps = np.arange(tpt_start, tpt_end + 1)
ts_images, ts_labels, channels = self.get_data_for_viewing( if isinstance(trap_ids, int):
trap_id, tps, channels_to_skip trap_ids = [trap_ids]
ts_images_dict, ts_labels_dict = {}, {}
for trap_id in trap_ids:
ts_images_dict[trap_id], ts_labels_dict[trap_id], channels = (
self.get_data_for_viewing(trap_id, tps, channels_to_skip)
)
# combine tiles
ts_images, ts_labels = self.combine_tiles(
ts_images_dict, ts_labels_dict, no_vertical_tiles
) )
# launch napari # launch napari
self.launch_napari(ts_images, ts_labels, channels)
def launch_napari(self, ts_images, ts_labels, channels):
"""Use Napari to see the images and outlines."""
viewer = napari.Viewer() viewer = napari.Viewer()
viewer.add_image( viewer.add_image(
ts_images[:, channels.index("Brightfield"), ...], ts_images[:, channels.index("Brightfield"), ...],
...@@ -222,57 +305,7 @@ class ImageViewer: ...@@ -222,57 +305,7 @@ class ImageViewer:
) )
# class LocalImageViewer(BaseImageViewer): ####
# """
# View images from local files.
# Files are either zarr or organised in directories.
# """
# def __init__(self, h5file: str, image_file: str):
# """Initialise using a h5file and a zarr file of images."""
# h5file_path = Path(h5file)
# image_file_path = Path(image_file)
# if h5file_path.exists() and image_file_path.exists():
# super().__init__(h5file_path)
# with dispatch_image(image_file_path)(image_file_path) as image:
# self.tiler = Tiler.from_h5(image, h5file_path)
# self.cells = Cells.from_source(h5file_path)
# else:
# if not h5file_path.exists():
# print(f" Trouble loading {h5file}.")
# if not image_file_path.exists():
# print(f" Trouble loading {image_file}.")
# class RemoteImageViewer(BaseImageViewer):
# """View images from OMERO."""
# def __init__(
# self, h5file: str, server_info: t.Dict[str, str], omero_id: int
# ):
# """Initialise using a h5file and importing aliby.io.omero."""
# h5file_path = Path(h5file)
# super().__init__(h5file_path)
# with h5py.File(h5file_path, "r") as f:
# # get image_id from the h5 file
# image_id = f.attrs.get("image_id")
# if image_id is None:
# # get image_id from OMERO
# with Dataset(omero_id, **server_info) as dataset_om:
# positions = dataset_om.get_position_ids()
# image_id = positions.get(h5file_path.name.split(".")[0])
# if image_id is None:
# print("Can't find an image.")
# else:
# print(f"Using image ID {image_id}.")
# self.image_id = image_id
# image = OImage(image_id, **server_info)
# print("Connected to OMERO.")
# self.tiler = Tiler.from_h5(image, h5file_path)
# self.cells = Cells.from_source(h5file_path)
def colormap(channel): def colormap(channel):
"""Find default colormap.""" """Find default colormap."""
if "GFP" in channel: if "GFP" in channel:
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment