Skip to content
Snippets Groups Projects
imageviewer.py 8.8 KiB
Newer Older
import typing as t
from abc import ABC
from pathlib import Path

import h5py
import napari
import numpy as np
from agora.io.cells import Cells
from agora.io.metadata import parse_metadata
from aliby.io.image import dispatch_image
from aliby.tile.tiler import Tiler


def colormap(channel):
    """
    Return the default napari colormap name for a channel.

    Channel names containing "GFP" map to "green"; otherwise names
    containing "Cherry" or "RFP" map to "red"; anything else is "gray".
    """
    if "GFP" in channel:
        return "green"
    if any(tag in channel for tag in ("Cherry", "RFP")):
        return "red"
    return "gray"


class BaseImageViewer(ABC):
    """
    Base class with routines common to all ImageViewers.

    Subclasses set ``self.tiler`` and ``self.cells`` after calling
    ``BaseImageViewer.__init__``; the methods below read both.
    """

    def __init__(self, h5file_path):
        """Initialise from a Path to a h5 file."""
        self.h5file_path = h5file_path
        print(f"Viewing {str(h5file_path)}")
        # raw tiles for a whole time point, keyed by
        # (channel index, time point, z); filled lazily by get_all_tiles
        self.full = {}

    def get_tiles(self, trap_id, tps, cell_only=True):
        """
        Get dict of tiles with channel names as keys.

        Parameters
        ----------
        trap_id: int
            The trap whose tile is extracted.
        tps: iterable of ints
            Time points to extract, in order; one tile per element.
        cell_only: bool
            If True (default), pixels of every channel other than
            "Brightfield" that lie outside all cell masks are set to zero.
            If False, all tiles are returned unmasked.

        Returns
        -------
        tiles_dict: dict
            Maps each channel name in self.tiler.channels to a list of
            tiles, one per element of tps.
        """
        tps = list(tps)
        tiles_dict = {}
        for ch_index, ch in enumerate(self.tiler.channels):
            tile_dict_for_ch = self.get_all_tiles(tps, ch_index)
            # look up by time point rather than dict order so that
            # repeated time points still yield one tile each
            tiles = [tile_dict_for_ch[tp][trap_id] for tp in tps]
            if ch == "Brightfield" or not cell_only:
                tiles_dict[ch] = tiles
                continue
            masks = [
                self.cells.at_time(tp, kind="mask").get(trap_id, [])
                for tp in tps
            ]
            # some masks may be empty: then keep the whole tile
            default_mask = [np.ones(self.cells.tile_size).astype(bool)]
            nmasks = [m if m else default_mask for m in masks]
            # combine all masks for each time point
            stacked_masks = [
                np.stack(masks_tp).max(axis=0).astype(bool)
                for masks_tp in nmasks
            ]
            # make tiles with fluorescence only in mask pixels; work on a
            # copy because each tile is a view into the cache in self.full
            new_tiles = []
            for tile, stacked_mask in zip(tiles, stacked_masks):
                masked_tile = tile.copy()
                masked_tile[~stacked_mask] = 0
                new_tiles.append(masked_tile)
            tiles_dict[ch] = new_tiles
        return tiles_dict

    def get_outlines(self, trap_id, tps):
        """
        Get uniquely labelled outlines for each time point.

        Returns
        -------
        labelled_outlines: list of arrays
            One image per element of tps in which each cell's outline
            pixels carry that cell's label; where outlines overlap, the
            larger label wins. Time points without labelled cells give an
            all-False array of shape self.cells.tile_size.
        """
        # get outlines for each time point
        outlines = [
            self.cells.at_time(tp, kind="edgemask").get(trap_id, [])
            for tp in tps
        ]
        # get cell labels for each time point
        cell_labels = [
            self.cells.labels_at_time(tp).get(trap_id, []) for tp in tps
        ]
        # generate one image with all cell outlines uniquely labelled per tile
        labelled_outlines = [
            (
                np.stack(
                    [
                        (outline * label)
                        for outline, label in zip(outlines_tp, labels_tp)
                    ]
                ).max(axis=0)
                if len(labels_tp)
                else np.zeros(self.cells.tile_size).astype(bool)
            )
            for outlines_tp, labels_tp in zip(outlines, cell_labels)
        ]
        return labelled_outlines

    def get_all_tiles(
        self,
        tps,
        channel_index,
        z=0,
    ):
        """
        Get dict with time points as keys and all available tiles as values.

        Only the single channel channel_index is fetched. Results are cached
        in self.full per (channel, time point, z), so the tiler is queried
        at most once for each combination.

        Parameters
        ----------
        tps: iterable of ints
            Time points to fetch.
        channel_index: int
            Index into self.tiler.channels.
        z: int
            Z-section to fetch; a falsy value (including the default 0)
            means self.tiler.ref_z.
        """
        z = z or self.tiler.ref_z
        tps = list(tps)
        for tp in tps:
            key = (channel_index, tp, z)
            if key not in self.full:
                # NOTE(review): only z=[z] is requested but the result is
                # indexed at position z of the fourth axis; presumably the
                # tiler returns all z-sections here — confirm
                self.full[key] = self.tiler.get_tiles_timepoint(
                    tp, channels=[channel_index], z=[z]
                )[:, 0, 0, z, ...]
        return {tp: self.full[(channel_index, tp, z)] for tp in tps}

    def get_data_for_viewing(self, trap_id, tps):
        """
        Get images and outlines as multidimensional arrays for Napari.

        Parameters
        ----------
        trap_id: int
            The trap to extract.
        tps: array-like of ints
            Time points to extract; lists are accepted as well as arrays.

        Returns
        -------
        ts_images: np.ndarray of int
            Tiles with dimensions (time, channel, 1, y, x).
        ts_labels: np.ndarray of int
            Labelled outlines with dimensions (time, 1, y, x).
        channels: list of str
            Channel names in the order of the channel axis of ts_images.
        """
        tps = np.asarray(tps)
        # get outlines and tiles
        outlines = self.get_outlines(trap_id, tps)
        tiles_dict = self.get_tiles(trap_id, tps)
        channels = list(tiles_dict)
        # put time series into one array with dimensions TCZYX
        ydim, xdim = tiles_dict[channels[0]][0].shape
        ntps = tps.size
        ts_images = np.zeros((ntps, len(channels), 1, ydim, xdim), dtype=int)
        ts_labels = np.zeros((ntps, 1, ydim, xdim), dtype=int)
        # make array of time series of tiles
        for ch_index, channel in enumerate(channels):
            for tp_index, tile in enumerate(tiles_dict[channel]):
                ts_images[tp_index, ch_index, 0, ...] = tile
        # make array of time series of outlines with no channels dimension
        for tp_index, outline in enumerate(outlines):
            ts_labels[tp_index, 0, ...] = outline
        return ts_images, ts_labels, channels

    def view(self, trap_id, tps=10):
        """
        Use Napari to view all channels and outlines for a particular trap.

        Fluorescence channels will not be immediately visible.

        Parameters
        ----------
        trap_id: int
            The trap to be viewed.
        tps: int or array of ints
            Either the number of time points to view, starting from 0, or
            a range of time points to view.
            If None, all time points will be viewed, but gathering the images
            will be slow.

        Raises
        ------
        ValueError
            If there is no "Brightfield" channel.
        """
        if tps is None:
            tps = np.arange(self.cells.ntimepoints)
        elif isinstance(tps, (int, np.integer)):
            tps = np.arange(tps)
        ts_images, ts_labels, channels = self.get_data_for_viewing(
            trap_id, tps
        )
        # launch napari
        viewer = napari.Viewer()
        viewer.add_image(
            ts_images[:, channels.index("Brightfield"), ...],
            name="Brightfield",
        )
        viewer.add_labels(ts_labels, name="outlines")
        # fluorescence channels are not initially visible
        for i, channel in enumerate(channels):
            if channel != "Brightfield":
                viewer.add_image(
                    ts_images[:, i, ...],
                    name=channel,
                    colormap=colormap(channel),
                    visible=False,
                    opacity=0.5,
                )


class LocalImageViewer(BaseImageViewer):
    """
    View images from local files.

    The image source is either a zarr file or a directory of images;
    dispatch_image chooses the reader for the given path.
    """

    def __init__(self, h5file: str, image_file: str):
        """
        Initialise using a h5file and a zarr file of images.

        If either path does not exist, a message is printed for each
        missing path and the viewer is left uninitialised.
        """
        h5file_path, image_file_path = Path(h5file), Path(image_file)
        h5_found = h5file_path.exists()
        image_found = image_file_path.exists()
        # guard clause: report every missing input and stop
        if not (h5_found and image_found):
            if not h5_found:
                print(f" Trouble loading {h5file}.")
            if not image_found:
                print(f" Trouble loading {image_file}.")
            return
        super().__init__(h5file_path)
        reader = dispatch_image(image_file_path)
        with reader(image_file_path) as image:
            self.tiler = Tiler.from_h5(image, h5file_path)
        self.cells = Cells.from_source(h5file_path)
        labelled_traps = [
            trap
            for trap, trap_labels in enumerate(self.cells.labels)
            if trap_labels
        ]
        print(f"Traps with labels {labelled_traps}.")
        print(f"Maximum number of time points {self.cells.ntimepoints}.")


class RemoteImageViewer(BaseImageViewer):
    """Fetching remote images with tiling and outline display."""

    # keys looked up when server_info is not supplied
    credentials = ("host", "username", "password")

    def __init__(
        self, h5file: str, server_info: t.Optional[t.Dict[str, str]]
    ):
        """
        Initialise using a h5file and importing aliby.io.omero.

        Parameters
        ----------
        h5file: str
            Path to the h5 file.
        server_info: dict or None
            Passed as keyword arguments to aliby.io.omero.UnsafeImage.
            If empty or None, the keys in ``credentials`` are read from
            self.attrs instead.

        Raises
        ------
        ValueError
            If no image_id is found in the log-file metadata next to the
            h5 file or in the h5 file's attributes.
        """
        from aliby.io.omero import UnsafeImage as OImage

        h5file_path = Path(h5file)
        super().__init__(h5file_path)
        if not server_info:
            # NOTE(review): neither this class nor BaseImageViewer sets
            # self.attrs, so this fallback raises AttributeError unless a
            # subclass provides it — confirm where the credentials live
            server_info = {
                k: self.attrs["parameters"]["general"][k]
                for k in self.credentials
            }
        logfiles_meta = parse_metadata(h5file_path.parent)
        image_id = logfiles_meta.get("image_id")
        if image_id is None:
            # fall back to the id stored in the h5 file itself
            with h5py.File(h5file_path, "r") as f:
                image_id = f.attrs.get("image_id")
        if image_id is None:
            raise ValueError("No valid image_id found in metadata.")
        image = OImage(image_id, **server_info)
        self.tiler = Tiler.from_h5(image, h5file_path)
        self.cells = Cells.from_source(h5file_path)


def get_files(
    aliby_input: str,
    aliby_output: str,
    omero_name: str,
    position: str,
):
    """
    Find the h5 file and corresponding zarr file for one position.

    Parameters
    ----------
    aliby_input: str
        Root directory of the zarr image files; the file is looked up in
        its omero_name sub-directory.
    aliby_output: str
        Root directory of the h5 files; they are searched in its
        omero_name sub-directory.
    omero_name: str
        Name of the sub-directory used under both roots.
    position: str
        Text that must appear in the h5 file's name. A file whose stem
        equals position is preferred over other matches.

    Returns
    -------
    [h5file, image_file]: list of str
        The h5 file and the zarr file with the same stem under
        aliby_input/omero_name. The zarr file's existence is not checked.

    Raises
    ------
    FileNotFoundError
        If no h5 file name contains position.
    """
    h5dir = Path(aliby_output) / omero_name
    # match on the file name only and sort, since glob order is arbitrary
    candidates = sorted(
        f for f in h5dir.glob("*.h5") if position in f.name
    )
    if not candidates:
        raise FileNotFoundError(
            f"No h5 file for position {position!r} in {h5dir}."
        )
    # prefer an exact match so that e.g. "pos1" does not pick "pos10.h5"
    exact = [f for f in candidates if f.stem == position]
    h5file = (exact or candidates)[0]
    image_file = Path(aliby_input) / omero_name / f"{h5file.stem}.zarr"
    return [str(h5file), str(image_file)]