From ce10e3358d82715664e43c2924ba1f1858f7ca2e Mon Sep 17 00:00:00 2001 From: pswain <peter.swain@ed.ac.uk> Date: Fri, 2 Aug 2024 10:54:01 +0100 Subject: [PATCH] change: added typing; corrected docs --- src/wela/imageviewer.py | 100 +++++++++++++++++++++++++--------------- 1 file changed, 62 insertions(+), 38 deletions(-) diff --git a/src/wela/imageviewer.py b/src/wela/imageviewer.py index a2bfd5f..bed51f6 100644 --- a/src/wela/imageviewer.py +++ b/src/wela/imageviewer.py @@ -9,7 +9,10 @@ except ModuleNotFoundError: "Napari cannot be imported.\nRun", ' python -m pip install "napari[all]"', ) +from typing import Any, Dict, List + import numpy as np +import numpy.typing as npt from agora.io.cells import Cells from aliby.io.image import dispatch_image from aliby.io.omero import Dataset @@ -47,7 +50,7 @@ class ImageViewer: print(f" Trouble loading {image_file}.") @classmethod - def remote(cls, h5file: str, server_info: dict, omero_id: int): + def remote(cls, h5file: str, server_info: Dict, omero_id: int): """View images from OMERO.""" iv = cls(h5file) with h5py.File(iv.h5file_path, "r") as f: @@ -69,7 +72,9 @@ class ImageViewer: iv.cells = Cells.from_source(iv.h5file_path) return iv - def get_all_traps_with_cells(self, tpt_end, tpt_start=0, display=True): + def get_all_traps_with_cells( + self, tpt_end: int, tpt_start: int = 0, display: bool = True + ): """List traps with cells.""" cells = self.cells tpts = range(tpt_start, tpt_end) @@ -84,7 +89,9 @@ class ImageViewer: print(f"Traps with cells {list(traps_with_cells)}") return traps_with_cells - def sample_traps_with_cells(self, no_cells, tpt_end, tpt_start=0): + def sample_traps_with_cells( + self, no_cells: int, tpt_end: int, tpt_start: int = 0 + ): """Sample some traps that have cells.""" traps_with_cells = self.get_all_traps_with_cells( tpt_end, tpt_start, display=False @@ -96,7 +103,12 @@ class ImageViewer: ) return samples - def get_tiles(self, trap_id, tps, channels_to_skip=None, cell_only=True): + def get_tiles( + self, + trap_id: int, + tps: List[int], + channels_to_skip: List[str] = None, + ): """Get dict of tiles with channel indices as keys.""" tiles_dict = {} if channels_to_skip is None: @@ -132,7 +144,7 @@ class ImageViewer: tiles_dict[ch] = new_tiles return tiles_dict - def get_outlines(self, trap_id, tps): + def get_outlines(self, trap_id: int, tps: List[int]): """Get uniquely labelled outlines for each cell time point.""" # get outlines for each time point outlines = [ @@ -161,14 +173,14 @@ class ImageViewer: def get_all_tiles( self, - tps, - channel_index, - z=0, + tps: List[int], + channel_index: str, + z: int = 0, ): """ Get dict with time points as keys and all available tiles as values. - We assume only a single channel. + Assume only a single channel. """ z = z or self.tiler.ref_z ch_tps = [(channel_index, tp) for tp in tps] @@ -181,7 +193,9 @@ class ImageViewer: tile_dict = {tp: self.full[(ch, tp)] for ch, tp in ch_tps} return tile_dict - def get_data_for_viewing(self, trap_id, tps, channels_to_skip): + def get_data_for_viewing( + self, trap_id: int, tps: List[int], channels_to_skip: List[str] + ): """ Get images and outlines as multidimensional arrays for Napari. @@ -217,11 +231,11 @@ class ImageViewer: ts_labels[tp_index, 0, ...] = outlines[tp_index] return ts_images, ts_labels, channels - def concat(self, arrangement, image_dict, axis): + def concat(self, arrangement: List[int], image_dict: Dict, axis: int): """ Concat dict of images into one image array. - Following the vertical layout in arrangment. + Follow the vertical layout in arrangment. """ # concatenate vertically into a list images_v = [ @@ -239,23 +253,22 @@ class ImageViewer: images = images_v[0] return images - def combine_tiles(self, ts_images_dict, ts_labels_dict, no_vertical_tiles): - """Combine tiles into one image first vertically then horizontally.""" + def combine_tiles( + self, ts_images_dict: Dict, ts_labels_dict: Dict, no_rows: int + ): + """Combine tiles into one image first into rows then columns.""" no_tiles = len(ts_images_dict) trap_ids = list(ts_images_dict.keys()) # find how tiles will be arranged in the concatenated image - if no_tiles < no_vertical_tiles: + if no_tiles < no_rows: arrangement = [trap_ids] else: arrangement = [ - trap_ids[i : min(i + no_vertical_tiles, no_tiles)] + trap_ids[i : min(i + no_rows, no_tiles)] for i in range( 0, - int( - np.floor(no_tiles / no_vertical_tiles) - * no_vertical_tiles - ), - no_vertical_tiles, + int(np.floor(no_tiles / no_rows) * no_rows), + no_rows, ) ] if no_tiles > np.array(arrangement).size: @@ -271,26 +284,30 @@ class ImageViewer: def view( self, - trap_ids, - tpt_end=10, - tpt_start=0, - channels_to_skip=None, - no_vertical_tiles=3, + trap_ids: List[int], + tpt_start: int = 0, + tpt_end: int = 10, + channels_to_skip: List[str] = None, + no_rows: int = 2, ): """ - Use Napari to view all channels and outlines for a particular trap. + Use Napari to view all channels and outlines for particular traps. Fluorescence channels will not be immediately visible. + Concatenating traps into one image can become slow for multiple traps. Parameters ---------- - trap_id: int - The trap to be viewed. - tps: int or array of ints - Either the last time point to be viewed or a rage of time points - to view. - If None, all time points will be viewed, but gathering the images - will be slow. + trap_ids: list of int + The traps to be viewed. + tpt_start: int + The index for the initial time point to view. + tpt_end: int + The index for the final time point. + channels_to_skip: list of str + Channels to ignore, such as "cy5". + no_rows: int + The number of rows of traps in the final concatenated image. """ tps = np.arange(tpt_start, tpt_end + 1) if isinstance(trap_ids, int): @@ -302,13 +319,18 @@ class ImageViewer: ) # combine tiles ts_images, ts_labels = self.combine_tiles( - ts_images_dict, ts_labels_dict, no_vertical_tiles + ts_images_dict, ts_labels_dict, no_rows ) # launch napari self.launch_napari(ts_images, ts_labels, channels) - def launch_napari(self, ts_images, ts_labels, channels): - """Use Napari to see the images and outlines.""" + def launch_napari( + self, + ts_images: npt.NDArray[Any], + ts_labels: npt.NDArray[Any], + channels: List[str], + ): + """Call Napari viewer.""" viewer = napari.Viewer() viewer.add_image( ts_images[:, channels.index("Brightfield"), ...], @@ -328,7 +350,9 @@ class ImageViewer: #### -def colormap(channel): + + +def colormap(channel: str): """Find default colormap.""" if "GFP" in channel: colormap = "green" -- GitLab