From df4db391887ee4bd3e3f42ac1de3f39b5c4cecb7 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Al=C3=A1n=20Mu=C3=B1oz?= <amuoz@ed.ac.uk>
Date: Fri, 24 Jun 2022 17:12:53 +0100
Subject: [PATCH] add support for multiple channels

---
 aliby/utils/imageViewer.py | 345 +++++++++++++++++++++++--------------
 pyproject.toml             |   5 +-
 2 files changed, 221 insertions(+), 129 deletions(-)

diff --git a/aliby/utils/imageViewer.py b/aliby/utils/imageViewer.py
index 5d5f622b..6a60ad18 100644
--- a/aliby/utils/imageViewer.py
+++ b/aliby/utils/imageViewer.py
@@ -11,10 +11,13 @@ trange = list(range(0, 30))
 ncols = 8
 
 riv = remoteImageViewer(fpath)
-riv.plot_labeled_traps(trap_id, trange, ncols)
+riv.plot_labelled_traps(trap_id, trange, ncols)
 
 """
 
+import re
+import typing as t
+
 import matplotlib.pyplot as plt
 import numpy as np
 import yaml
@@ -25,6 +28,30 @@ from skimage.morphology import dilation
 
 from aliby.io.image import Image as OImage
 from aliby.tile.tiler import Tiler
+from aliby.tile.traps import stretch_image
+
+default_colours = {
+    "Brightfield": "Greys_r",
+    "GFP": "Greens_r",
+    "mCherry": "Reds_r",
+    "cell_label": "Set1",
+}
+
+
+def custom_imshow(a, norm=None, cmap=None, *args, **kwargs):
+    """
+    Wrapper on plt.imshow function.
+    """
+    if cmap is None:
+        cmap = "Greys_r"
+    return plt.imshow(
+        a,
+        *args,
+        cmap=cmap,
+        interpolation=None,
+        interpolation_stage="rgba",
+        **kwargs,
+    )
 
 
 class localImageViewer:
@@ -96,19 +123,16 @@ class remoteImageViewer:
         # Print  cell label at a given time-point
         return self.cells.labels_at_time(tp)
 
-    def nonempty_traps_at_time(self, tp: int):
-        # Print cell label at a given time-point
-        return [
-            trap_id
-            for trap_id, labels in self.labels_at_time(tp).items()
-            if len(labels)
-        ]
-
-    def nonempty_tps_in_trap(self, trap_id: int):
-        # Print cell label at a given time-point
-        return self.cells.labels_at_time(trap_id)
+    def random_valid_trap_tp(
+        self, min_ncells: int = None, min_consecutive_tps: int = None
+    ):
+        # Call Cells convenience function to pick a random trap and tp
+        # containing cells for x cells for y
+        return self.cells.random_valid_trap_tp(
+            min_ncells=min_ncells, min_consecutive_tps=min_consecutive_tps
+        )
 
-    def get_position(self):
+    def get_entire_position(self):
         raise (NotImplementedError)
 
     def get_position_timelapse(self):
@@ -128,48 +152,76 @@ class remoteImageViewer:
             self.tiler.image = image.data
             return self.tiler.get_tc(tp, channel)
 
-    def _find_channels(self, channels):
+    def _find_channels(self, channels: str, guess: bool = True):
         channels = channels or self.tiler.ref_channel
-        if isinstance(channels, int):
+        if isinstance(channels, (int, str)):
             channels = [channels]
-        elif isinstance(channels, list) and isinstance(channels[0], str):
-            channels = [self.tiler.channels.index(ch) for ch in channels]
+        if isinstance(channels[0], str):
+            if guess:
+                channels = [self.tiler.channels.index(ch) for ch in channels]
+            else:
+                channels = [
+                    re.search(ch, tiler_channels)
+                    for ch in channels
+                    for tiler_channels in self.tiler.channels
+                ]
 
         return channels
 
-    def get_trap_timepoints(
-        self, trap_id, tps, channels=None, z=None, server_info=None
+    def get_pos_timepoints(
+        self,
+        tps: t.Union[int, t.Collection[int]],
+        channels: t.Union[str, t.Collection[str]] = None,
+        z: int = None,
+        server_info=None,
+        **kwargs,
     ):
+
+        if tps and not isinstance(tps, t.Collection):
+            tps = range(tps)
+
+        # TODO add support for multiple channels or refactor
+        if channels and not isinstance(channels, t.Collection):
+            channels = [channels]
+
+        # if z and isinstance(z, t.Collection):
+        #     z = list(z)
+        if z is None:
+            z = 0
+
         server_info = server_info or self.server_info
-        channels = self._find_channels(channels)
+        channels = 0 or self._find_channels(channels)
         z = z or self.tiler.ref_z
 
-        ch_tps = set([(channels[0], tp) for tp in tps])
+        ch_tps = [(channels[0], tp) for tp in tps]
         with OImage(self.image_id, **server_info) as image:
             self.tiler.image = image.data
-            if ch_tps.difference(self.full.keys()):
-                tps = set(tps).difference(self.full.keys())
-                for ch, tp in ch_tps:
-                    if z is 0:
-                        self.full[(ch, tp)] = self.tiler.get_traps_timepoint(
-                            tp, channels=[ch], z=[z]
-                        )[:, 0, 0, ..., 0]
-                    else:
-                        self.full[(ch, tp)] = self.tiler.get_traps_timepoint(
-                            tp, channels=[ch], z=[z]
-                        )[:, 0, 0, ..., z]
+            # if ch_tps.difference(self.full.keys()):
+            # tps = set(tps).difference(self.full.keys())
+            for ch, tp in ch_tps:
+                if (ch, tp) not in self.full:
+                    self.full[(ch, tp)] = self.tiler.get_traps_timepoint(
+                        tp, channels=[ch], z=[z]
+                    )[:, 0, 0, ..., z]
             requested_trap = {tp: self.full[(ch, tp)] for ch, tp in ch_tps}
 
             return requested_trap
 
-    def get_labeled_trap(self, trap_id, tps, **kwargs):
-        imgs = self.get_trap_timepoints(trap_id, tps, **kwargs)
+    def get_labelled_trap(
+        self,
+        trap_id: int,
+        tps: t.Union[range, t.Collection[int]],
+        **kwargs,
+    ) -> t.Tuple[np.array]:
+        imgs = self.get_pos_timepoints(tps, **kwargs)
         imgs_list = [x[trap_id] for x in imgs.values()]
         outlines = [
             self.cells.at_time(tp, kind="edgemask").get(trap_id, [])
             for tp in tps
         ]
-        lbls = [self.cells.labels_at_time(tp).get(trap_id, []) for tp in tps]
+        lbls = [
+            self.cells.labels_at_time(tp - 1).get(trap_id, []) for tp in tps
+        ]
         lbld_outlines = [
             np.dstack([mask * lbl for mask, lbl in zip(maskset, lblset)]).max(
                 axis=2
@@ -190,119 +242,158 @@ class remoteImageViewer:
         imgs = {}
 
         for ch in self._find_channels(channels):
-            out, imgs[ch] = self.get_labeled_trap(
+            out, imgs[ch] = self.get_labelled_trap(
                 trap_id, trange, channels=[ch], **kwargs
             )
         return out, imgs
 
-    def plot_labeled_zstacks(
-        self, trap_id, channels, trange, z=None, **kwargs
+    # def plot_labelled_zstacks(
+    #     self, trap_id, channels, trange, z=None, **kwargs
+    # ):
+    #     # if z is None:
+    #     #     z =
+    #     out, images = self.get_images(trap_id, trange, channels, z=z, **kwargs)
+
+    def plot_labelled_trap(
+        self,
+        trap_id: int,
+        channels,
+        trange: t.Union[range, t.Collection[int]],
+        remove_axis: bool = False,
+        savefile: str = None,
+        skip_outlines: bool = False,
+        norm: str = None,
+        ncols: int = None,
+        img_plot_kwargs: dict = {},
+        lbl_plot_kwargs: dict = {},
+        **kwargs,
     ):
-        # if z is None:
-        #     z =
-        out, images = self.get_images(trap_id, trange, channels, z=z, **kwargs)
+        if ncols is None:
+            ncols = len(trange)
+        nrows = int(np.ceil(len(trange) / ncols))
+        width = self.tiler.tile_size * ncols
 
-    def plot_labeled_channelrows(self, trap_id, channels, trange, **kwargs):
         out, images = self.get_images(trap_id, trange, channels, **kwargs)
 
         # dilation makes outlines easier to see
         out = dilation(out).astype(float)
         out[out == 0] = np.nan
 
-        img_set = np.concatenate([v for v in imgs.values()], axis=0)
-        tiled_out = np.tile(out, (len(imgs), 1))
-        plt.imshow(
-            img_set,
-            interpolation=None,
-            cmap="Greys_r",
+        channel_labels = [
+            self.tiler.channels[ch] if isinstance(ch, int) else ch
+            for ch in channels
+        ]
+
+        assert not norm or norm in (
+            "l1",
+            "l2",
+            "max",
+        ), "Invalid norm argument."
+
+        if norm and norm in ("l1", "l2", "max"):
+            images = {k: stretch_image(v) for k, v in images.items()}
+
+        # images = [concat_pad(img, width, nrows) for img in images.values()]
+        images = [concat_pad(img, width, nrows) for img in images.values()]
+        tiled_imgs = {}
+        tiled_imgs["img"] = np.concatenate(images, axis=0)
+        tiled_imgs["cell_labels"] = np.concatenate(
+            [concat_pad(out, width, nrows) for _ in images], axis=0
         )
-        plt.imshow(
-            tiled_out,
+
+        custom_imshow(
+            tiled_imgs["img"],
+            **img_plot_kwargs,
+        )
+        custom_imshow(
+            tiled_imgs["cell_labels"],
             cmap="Set1",
-            interpolation=None,
+            **lbl_plot_kwargs,
         )
         plt.yticks(
             ticks=[
-                self.tiler.tile_size * (i + 0.5) for i in range(len(channels))
-            ],
-            labels=[
-                self.tiler.channels[ch] if isinstance(ch, int) else ch
-                for ch in channels
+                self.tiler.tile_size * (i + 0.5)
+                + (i * self.tiler.tile_size * nrows)
+                for i in range(len(channels))
             ],
+            labels=channel_labels,
         )
         plt.xticks(
-            ticks=[
-                self.tiler.tile_size * (i + 0.5) for i in range(len(trange))
-            ],
-            labels=[t for t in trange],
+            ticks=[self.tiler.tile_size * (i + 0.5) for i in range(ncols)],
+            labels=["+ {} ".format(i) for i in range(ncols)],
         )
+        plt.xlabel("Additional time-points")
         plt.show()
 
-    def plot_labeled_traps(
-        self,
-        trap_id,
-        trange,
-        ncols,
-        remove_axis=False,
-        savefile=False,
-        skip_outlines=False,
-        **kwargs,
-    ):
-        """
-        Wrapper to plot a single trap over time
-
-        Parameters
-        ---------
-        :trap_id: int trap identification
-        :trange: list list of time points to fetch
-        """
-        nrows = len(trange) // ncols
-        width = self.tiler.tile_size * ncols
-        out, img = self.get_labeled_trap(trap_id, trange, **kwargs)
-
-        # dilation makes outlines easier to see
-        out = dilation(out).astype(float)
-        out[out == 0] = np.nan
-
-        def concat_pad(array):
-            return np.concatenate(
-                np.array_split(
-                    np.pad(
-                        array,
-                        ((0, 0), (0, array.shape[1] % width)),
-                        constant_values=np.nan,
-                    ),
-                    nrows,
-                    axis=1,
-                )
-            )
-
-        plt.imshow(
-            concat_pad(img),
-            interpolation=None,
-            cmap="Greys_r",
+    # def plot_labelled_trap(
+    #     self,
+    #     trap_id: int,
+    #     trange: t.Union[range, t.Collection[int]],
+    #     ncols: int,
+    #     remove_axis: bool = False,
+    #     savefile: str = None,
+    #     skip_outlines: bool = False,
+    #     **kwargs,
+    # ):
+    #     """
+    #     Wrapper to plot a single trap over time
+
+    #     Parameters
+    #     ---------
+    #     :trap_id: int trap identification
+    #     :trange: Collection or Range list of time points to fetch
+    #     """
+    #     nrows = len(trange) // ncols
+    #     width = self.tiler.tile_size * ncols
+    #     out, img = self.get_labelled_trap(trap_id, trange, **kwargs)
+
+    #     # dilation makes outlines easier to see
+    #     out = dilation(out).astype(float)
+    #     out[out == 0] = np.nan
+
+    #     # interpolation_kwargs = {""}
+
+    #     custom_imshow(
+    #         concat_pad(img),
+    #         cmap="Greys_r",
+    #     )
+    #     if not skip_outlines:
+    #         custom_imshow(
+    #             concat_pad(out),
+    #             cmap="Set1",
+    #         )
+
+    #     bbox_inches = None
+    #     if remove_axis:
+    #         plt.axis("off")
+    #         bbox_inches = "tight"
+
+    #     else:
+    #         plt.yticks(
+    #             ticks=[self.tiler.tile_size * (i + 0.5) for i in range(nrows)],
+    #             labels=[trange[0] + ncols * i for i in range(nrows)],
+    #         )
+
+    #     if not savefile:
+    #         plt.show()
+    #     else:
+    #         if np.any(out):
+    #             plt.savefig(savefile, bbox_inches=bbox_inches)
+
+
+def concat_pad(a: np.array, width, nrows):
+    """
+    Melt an array into having multiple blocks as rows
+    """
+    return np.concatenate(
+        np.array_split(
+            np.pad(
+                a,
+                # ((0, 0), (0, width - (a.shape[1] % width))),
+                ((0, 0), (0, a.shape[1] % width)),
+                constant_values=np.nan,
+            ),
+            nrows,
+            axis=1,
         )
-        if not skip_outlines:
-            plt.imshow(
-                concat_pad(out),
-                # concat_pad(mask),
-                cmap="Set1",
-                interpolation=None,
-            )
-
-        bbox_inches = None
-        if remove_axis:
-            plt.axis("off")
-            bbox_inches = "tight"
-
-        else:
-            plt.yticks(
-                ticks=[self.tiler.tile_size * (i + 0.5) for i in range(nrows)],
-                labels=[trange[0] + ncols * i for i in range(nrows)],
-            )
-
-        if not savefile:
-            plt.show()
-        else:
-            if np.any(out):
-                plt.savefig(savefile, bbox_inches=bbox_inches)
+    )
diff --git a/pyproject.toml b/pyproject.toml
index f65420af..7bcbff7c 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -24,13 +24,14 @@ requests-toolbelt = "^0.9.1"
 h5py = "2.10" # I/O into files
 imageio = "2.8.0"
 omero-py = ">=5.6.2" # contact omero server
-zeroc-ice = "3.6.5" # pin networking interface
-aliby-agora = "^0.2.29"
+aliby-agora = "^0.2.30"
 aliby-baby = "^0.1.10"
 aliby-post = "^0.1.27"
 p-tqdm = "^1.3.3" # Parallel progress bars
 xmltodict = "^0.13.0" # read ome-tiff metadata
 protobuf = "<=3.20.1" # For pytest to work
+zeroc-ice = {version="3.6.5"} # networking interface, slow to build
+# zeroc-ice = {version="3.6.5", optional=true} # To be set as optional in the future
 
 
 [tool.poetry.dev-dependencies]
-- 
GitLab