Skip to content
Snippets Groups Projects
Commit 0bb395b2 authored by Peter Swain's avatar Peter Swain
Browse files

change(imageviewer): fixed RemoteImageViewer

image_id found from OMERO; added skip_channels
parent 7d945946
No related branches found
No related tags found
No related merge requests found
...@@ -6,10 +6,10 @@ import h5py ...@@ -6,10 +6,10 @@ import h5py
import napari import napari
import numpy as np import numpy as np
from agora.io.cells import Cells from agora.io.cells import Cells
from agora.io.metadata import parse_metadata
from aliby.io.image import dispatch_image from aliby.io.image import dispatch_image
from aliby.tile.tiler import Tiler from aliby.io.omero import Dataset
from aliby.io.omero import UnsafeImage as OImage from aliby.io.omero import UnsafeImage as OImage
from aliby.tile.tiler import Tiler
def colormap(channel): def colormap(channel):
...@@ -32,10 +32,23 @@ class BaseImageViewer(ABC): ...@@ -32,10 +32,23 @@ class BaseImageViewer(ABC):
print(f"Viewing {str(h5file_path)}") print(f"Viewing {str(h5file_path)}")
self.full = {} self.full = {}
def get_tiles(self, trap_id, tps, cell_only=True): def print_trap_info(self, cells):
"""List available traps - those with identified cells."""
traps_with_labels = [
i for i, labels in enumerate(cells.labels) if labels
]
print(f"Traps with labelled cells {traps_with_labels}.")
print(f"Maximum number of time points {cells.ntimepoints}.")
def get_tiles(self, trap_id, tps, channels_to_skip=None, cell_only=True):
"""Get dict of tiles with channel indices as keys.""" """Get dict of tiles with channel indices as keys."""
tiles_dict = {} tiles_dict = {}
channels = self.tiler.channels if channels_to_skip is None:
channels = self.tiler.channels
else:
channels = [
ch for ch in self.tiler.channels if ch not in channels_to_skip
]
channel_indices = [channels.index(ch) for ch in channels] channel_indices = [channels.index(ch) for ch in channels]
for ch_index, ch in zip(channel_indices, channels): for ch_index, ch in zip(channel_indices, channels):
tile_dict_for_ch = self.get_all_tiles(tps, ch_index) tile_dict_for_ch = self.get_all_tiles(tps, ch_index)
...@@ -99,23 +112,24 @@ class BaseImageViewer(ABC): ...@@ -99,23 +112,24 @@ class BaseImageViewer(ABC):
""" """
Get dict with time points as keys and all available tiles as values. Get dict with time points as keys and all available tiles as values.
We assume there is only a single channel. We assume only a single channel.
""" """
z = z or self.tiler.ref_z z = z or self.tiler.ref_z
ch_tps = [(channel_index, tp) for tp in tps] ch_tps = [(channel_index, tp) for tp in tps]
for ch, tp in ch_tps: for ch, tp in ch_tps:
if (ch, tp) not in self.full: if (ch, tp) not in self.full:
print(f"Getting {self.tiler.channels[ch]} at time point {tp}.")
self.full[(ch, tp)] = self.tiler.get_tiles_timepoint( self.full[(ch, tp)] = self.tiler.get_tiles_timepoint(
tp, channels=[ch], z=[z] tp, channels=[ch], z=[z]
)[:, 0, 0, z, ...] )[:, 0, 0, z, ...]
tile_dict = {tp: self.full[(ch, tp)] for ch, tp in ch_tps} tile_dict = {tp: self.full[(ch, tp)] for ch, tp in ch_tps}
return tile_dict return tile_dict
def get_data_for_viewing(self, trap_id, tps): def get_data_for_viewing(self, trap_id, tps, channels_to_skip):
"""Get images and outlines as multidimensional arrays for Napari.""" """Get images and outlines as multidimensional arrays for Napari."""
# get outlines and tiles # get outlines and tiles
outlines = self.get_outlines(trap_id, tps) outlines = self.get_outlines(trap_id, tps)
tiles_dict = self.get_tiles(trap_id, tps) tiles_dict = self.get_tiles(trap_id, tps, channels_to_skip)
channels = list(tiles_dict.keys()) channels = list(tiles_dict.keys())
# put time series into one array with dimensions TCZYX # put time series into one array with dimensions TCZYX
ydim, xdim = tiles_dict[list(tiles_dict.keys())[0]][0].shape ydim, xdim = tiles_dict[list(tiles_dict.keys())[0]][0].shape
...@@ -134,7 +148,7 @@ class BaseImageViewer(ABC): ...@@ -134,7 +148,7 @@ class BaseImageViewer(ABC):
ts_labels[tp_index, 0, ...] = outlines[tp_index] ts_labels[tp_index, 0, ...] = outlines[tp_index]
return ts_images, ts_labels, channels return ts_images, ts_labels, channels
def view(self, trap_id, tps=10): def view(self, trap_id, tps=10, channels_to_skip=None):
""" """
Use Napari to view all channels and outlines for a particular trap. Use Napari to view all channels and outlines for a particular trap.
...@@ -155,7 +169,7 @@ class BaseImageViewer(ABC): ...@@ -155,7 +169,7 @@ class BaseImageViewer(ABC):
elif type(tps) is int: elif type(tps) is int:
tps = np.arange(tps) tps = np.arange(tps)
ts_images, ts_labels, channels = self.get_data_for_viewing( ts_images, ts_labels, channels = self.get_data_for_viewing(
trap_id, tps trap_id, tps, channels_to_skip
) )
# launch napari # launch napari
viewer = napari.Viewer() viewer = napari.Viewer()
...@@ -192,11 +206,7 @@ class LocalImageViewer(BaseImageViewer): ...@@ -192,11 +206,7 @@ class LocalImageViewer(BaseImageViewer):
with dispatch_image(image_file_path)(image_file_path) as image: with dispatch_image(image_file_path)(image_file_path) as image:
self.tiler = Tiler.from_h5(image, h5file_path) self.tiler = Tiler.from_h5(image, h5file_path)
self.cells = Cells.from_source(h5file_path) self.cells = Cells.from_source(h5file_path)
traps_with_labels = [ self.print_trap_info(self.cells)
i for i, labels in enumerate(self.cells.labels) if labels
]
print(f"Traps with labels {traps_with_labels}.")
print(f"Maximum number of time points {self.cells.ntimepoints}.")
else: else:
if not h5file_path.exists(): if not h5file_path.exists():
print(f" Trouble loading {h5file}.") print(f" Trouble loading {h5file}.")
...@@ -214,13 +224,22 @@ class RemoteImageViewer(BaseImageViewer): ...@@ -214,13 +224,22 @@ class RemoteImageViewer(BaseImageViewer):
h5file_path = Path(h5file) h5file_path = Path(h5file)
super().__init__(h5file_path) super().__init__(h5file_path)
with h5py.File(h5file_path, "r") as f: with h5py.File(h5file_path, "r") as f:
breakpoint() # get image_id from the h5 file
self.image_id = f.attrs.get("image_id") image_id = f.attrs.get("image_id")
if image_id is None:
image = OImage(omero_id, **server_info) # get image_id from OMERO
breakpoint() with Dataset(omero_id, **server_info) as dataset_om:
self.tiler = Tiler.from_h5(image, h5file_path) positions = dataset_om.get_position_ids()
self.cells = Cells.from_source(h5file_path) image_id = positions.get(h5file_path.name.split(".")[0])
if image_id is None:
print("Can't find an image.")
else:
self.image_id = image_id
image = OImage(image_id, **server_info)
print("Connected to OMERO.")
self.tiler = Tiler.from_h5(image, h5file_path)
self.cells = Cells.from_source(h5file_path)
self.print_trap_info(self.cells)
def get_files( def get_files(
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment