diff --git a/src/wela/imageviewer.py b/src/wela/imageviewer.py index 75f5363fed233fc4da0580ff8b6e8015594b6f46..2a349f408bf6d3c9007b92ec13aacbe6e45551e0 100644 --- a/src/wela/imageviewer.py +++ b/src/wela/imageviewer.py @@ -161,7 +161,20 @@ class ImageViewer: return tile_dict def get_data_for_viewing(self, trap_id, tps, channels_to_skip): - """Get images and outlines as multidimensional arrays for Napari.""" + """ + Get images and outlines as multidimensional arrays for Napari. + + Returns + ------- + ts_images: array + An array of images ordered by time points, channels, z stack, + x and y values. + ts_labels: array + An array of images of cell outlines ordered by time points, + z stack, x and y values. + channels: list of str + List of channel names. + """ # get outlines and tiles outlines = self.get_outlines(trap_id, tps) tiles_dict = self.get_tiles(trap_id, tps, channels_to_skip) @@ -183,7 +196,65 @@ class ImageViewer: ts_labels[tp_index, 0, ...] = outlines[tp_index] return ts_images, ts_labels, channels - def view(self, trap_id, tpt_end=10, tpt_start=0, channels_to_skip=None): + def concat(self, arrangement, image_dict, axis): + """ + Concat dict of images into one image array. + + Following the vertical layout in arrangment. + """ + # concatenate vertically into a list + images_v = [ + np.concatenate( + [image_dict[trap_id] for trap_id in v_stack], axis=axis + ) + for v_stack in arrangement + ] + # concatenate horizontally into an array + if len(arrangement) > 1: + images = np.concatenate( + [image_v_stack for image_v_stack in images_v], axis=axis + 1 + ) + else: + images = images_v[0] + return images + + def combine_tiles(self, ts_images_dict, ts_labels_dict, no_vertical_tiles): + """Combine tiles into one image first vertically then horizontally.""" + no_tiles = len(ts_images_dict) + trap_ids = list(ts_images_dict.keys()) + # find how tiles will be arranged in the concatenated image + if no_tiles < no_vertical_tiles: + arrangement = [trap_ids] + else: + arrangement = [ + trap_ids[i : min(i + no_vertical_tiles, no_tiles)] + for i in range( + 0, + int( + np.floor(no_tiles / no_vertical_tiles) + * no_vertical_tiles + ), + no_vertical_tiles, + ) + ] + if no_tiles > np.array(arrangement).size: + print( + f"Too many tiles: {no_tiles - np.array(arrangement).size}" + " will be ignored." + ) + # concatenate images + ts_images = self.concat(arrangement, ts_images_dict, axis=3) + ts_labels = self.concat(arrangement, ts_labels_dict, axis=2) + return ts_images, ts_labels + + def view( + self, + trap_ids, + tpt_end=10, + tpt_start=0, + channels_to_skip=None, + no_vertical_tiles=3, + ): """ Use Napari to view all channels and outlines for a particular trap. @@ -200,10 +271,22 @@ class ImageViewer: will be slow. """ tps = np.arange(tpt_start, tpt_end + 1) - ts_images, ts_labels, channels = self.get_data_for_viewing( - trap_id, tps, channels_to_skip + if isinstance(trap_ids, int): + trap_ids = [trap_ids] + ts_images_dict, ts_labels_dict = {}, {} + for trap_id in trap_ids: + ts_images_dict[trap_id], ts_labels_dict[trap_id], channels = ( + self.get_data_for_viewing(trap_id, tps, channels_to_skip) + ) + # combine tiles + ts_images, ts_labels = self.combine_tiles( + ts_images_dict, ts_labels_dict, no_vertical_tiles ) # launch napari + self.launch_napari(ts_images, ts_labels, channels) + + def launch_napari(self, ts_images, ts_labels, channels): + """Use Napari to see the images and outlines.""" viewer = napari.Viewer() viewer.add_image( ts_images[:, channels.index("Brightfield"), ...], @@ -222,57 +305,7 @@ class ImageViewer: ) -# class LocalImageViewer(BaseImageViewer): -# """ -# View images from local files. - -# Files are either zarr or organised in directories. -# """ - -# def __init__(self, h5file: str, image_file: str): -# """Initialise using a h5file and a zarr file of images.""" -# h5file_path = Path(h5file) -# image_file_path = Path(image_file) -# if h5file_path.exists() and image_file_path.exists(): -# super().__init__(h5file_path) -# with dispatch_image(image_file_path)(image_file_path) as image: -# self.tiler = Tiler.from_h5(image, h5file_path) -# self.cells = Cells.from_source(h5file_path) -# else: -# if not h5file_path.exists(): -# print(f" Trouble loading {h5file}.") -# if not image_file_path.exists(): -# print(f" Trouble loading {image_file}.") - - -# class RemoteImageViewer(BaseImageViewer): -# """View images from OMERO.""" - -# def __init__( -# self, h5file: str, server_info: t.Dict[str, str], omero_id: int -# ): -# """Initialise using a h5file and importing aliby.io.omero.""" -# h5file_path = Path(h5file) -# super().__init__(h5file_path) -# with h5py.File(h5file_path, "r") as f: -# # get image_id from the h5 file -# image_id = f.attrs.get("image_id") -# if image_id is None: -# # get image_id from OMERO -# with Dataset(omero_id, **server_info) as dataset_om: -# positions = dataset_om.get_position_ids() -# image_id = positions.get(h5file_path.name.split(".")[0]) -# if image_id is None: -# print("Can't find an image.") -# else: -# print(f"Using image ID {image_id}.") -# self.image_id = image_id -# image = OImage(image_id, **server_info) -# print("Connected to OMERO.") -# self.tiler = Tiler.from_h5(image, h5file_path) -# self.cells = Cells.from_source(h5file_path) - - +#### def colormap(channel): """Find default colormap.""" if "GFP" in channel: