diff --git a/src/wela/imageviewer.py b/src/wela/imageviewer.py index ff3c762309bcd6fb1f10923c577d60c65023ce18..5a94cb7ce6d4d0e26dbd6fcc5844250cd7823fd5 100644 --- a/src/wela/imageviewer.py +++ b/src/wela/imageviewer.py @@ -6,10 +6,10 @@ import h5py import napari import numpy as np from agora.io.cells import Cells -from agora.io.metadata import parse_metadata from aliby.io.image import dispatch_image -from aliby.tile.tiler import Tiler +from aliby.io.omero import Dataset from aliby.io.omero import UnsafeImage as OImage +from aliby.tile.tiler import Tiler def colormap(channel): @@ -32,10 +32,23 @@ class BaseImageViewer(ABC): print(f"Viewing {str(h5file_path)}") self.full = {} - def get_tiles(self, trap_id, tps, cell_only=True): + def print_trap_info(self, cells): + """List available traps - those with identified cells.""" + traps_with_labels = [ + i for i, labels in enumerate(cells.labels) if labels + ] + print(f"Traps with labelled cells {traps_with_labels}.") + print(f"Maximum number of time points {cells.ntimepoints}.") + + def get_tiles(self, trap_id, tps, channels_to_skip=None, cell_only=True): """Get dict of tiles with channel indices as keys.""" tiles_dict = {} - channels = self.tiler.channels + if channels_to_skip is None: + channels = self.tiler.channels + else: + channels = [ + ch for ch in self.tiler.channels if ch not in channels_to_skip + ] channel_indices = [channels.index(ch) for ch in channels] for ch_index, ch in zip(channel_indices, channels): tile_dict_for_ch = self.get_all_tiles(tps, ch_index) @@ -99,23 +112,24 @@ class BaseImageViewer(ABC): """ Get dict with time points as keys and all available tiles as values. - We assume there is only a single channel. + We assume only a single channel. """ z = z or self.tiler.ref_z ch_tps = [(channel_index, tp) for tp in tps] for ch, tp in ch_tps: if (ch, tp) not in self.full: + print(f"Getting {self.tiler.channels[ch]} at time point {tp}.") self.full[(ch, tp)] = self.tiler.get_tiles_timepoint( tp, channels=[ch], z=[z] )[:, 0, 0, z, ...] tile_dict = {tp: self.full[(ch, tp)] for ch, tp in ch_tps} return tile_dict - def get_data_for_viewing(self, trap_id, tps): + def get_data_for_viewing(self, trap_id, tps, channels_to_skip): """Get images and outlines as multidimensional arrays for Napari.""" # get outlines and tiles outlines = self.get_outlines(trap_id, tps) - tiles_dict = self.get_tiles(trap_id, tps) + tiles_dict = self.get_tiles(trap_id, tps, channels_to_skip) channels = list(tiles_dict.keys()) # put time series into one array with dimensions TCZYX ydim, xdim = tiles_dict[list(tiles_dict.keys())[0]][0].shape @@ -134,7 +148,7 @@ class BaseImageViewer(ABC): ts_labels[tp_index, 0, ...] = outlines[tp_index] return ts_images, ts_labels, channels - def view(self, trap_id, tps=10): + def view(self, trap_id, tps=10, channels_to_skip=None): """ Use Napari to view all channels and outlines for a particular trap. @@ -155,7 +169,7 @@ class BaseImageViewer(ABC): elif type(tps) is int: tps = np.arange(tps) ts_images, ts_labels, channels = self.get_data_for_viewing( - trap_id, tps + trap_id, tps, channels_to_skip ) # launch napari viewer = napari.Viewer() @@ -192,11 +206,7 @@ class LocalImageViewer(BaseImageViewer): with dispatch_image(image_file_path)(image_file_path) as image: self.tiler = Tiler.from_h5(image, h5file_path) self.cells = Cells.from_source(h5file_path) - traps_with_labels = [ - i for i, labels in enumerate(self.cells.labels) if labels - ] - print(f"Traps with labels {traps_with_labels}.") - print(f"Maximum number of time points {self.cells.ntimepoints}.") + self.print_trap_info(self.cells) else: if not h5file_path.exists(): print(f" Trouble loading {h5file}.") @@ -214,13 +224,22 @@ class RemoteImageViewer(BaseImageViewer): h5file_path = Path(h5file) super().__init__(h5file_path) with h5py.File(h5file_path, "r") as f: - breakpoint() - self.image_id = f.attrs.get("image_id") - - image = OImage(omero_id, **server_info) - breakpoint() - self.tiler = Tiler.from_h5(image, h5file_path) - self.cells = Cells.from_source(h5file_path) + # get image_id from the h5 file + image_id = f.attrs.get("image_id") + if image_id is None: + # get image_id from OMERO + with Dataset(omero_id, **server_info) as dataset_om: + positions = dataset_om.get_position_ids() + image_id = positions.get(h5file_path.name.split(".")[0]) + if image_id is None: + print("Can't find an image.") + else: + self.image_id = image_id + image = OImage(image_id, **server_info) + print("Connected to OMERO.") + self.tiler = Tiler.from_h5(image, h5file_path) + self.cells = Cells.from_source(h5file_path) + self.print_trap_info(self.cells) def get_files(