From b5483679f44848e2b61c74cc868c170b9f388d58 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Al=C3=A1n=20Mu=C3=B1oz?= <amuoz@ed.ac.uk>
Date: Fri, 25 Mar 2022 13:16:19 +0000
Subject: [PATCH] update imageviewer, add example and plot method

---
 aliby/utils/imageViewer.py | 153 +++++++++++++++++++++----------------
 1 file changed, 86 insertions(+), 67 deletions(-)

diff --git a/aliby/utils/imageViewer.py b/aliby/utils/imageViewer.py
index b0195d3b..32281a11 100644
--- a/aliby/utils/imageViewer.py
+++ b/aliby/utils/imageViewer.py
@@ -1,8 +1,31 @@
-"""This is an example module to show the structure."""
-from typing import Union
+"""
+ImageViewer class, used to look at individual or multiple traps over time.
 
+
+Example of usage:
+
+fpath = "/home/alan/Documents/dev/skeletons/scripts/data/16543_2019_07_16_aggregates_CTP_switch_2_0glu_0_0glu_URA7young_URA8young_URA8old_01/URA8_young018.h5"
+
+trap_id = 9
+trange = list(range(0, 30))
+ncols = 8
+
+riv = remoteImageViewer(fpath)
+riv.plot_labeled_traps(trap_id, trange, ncols)
+
+"""
+
+import yaml
 import numpy as np
 from PIL import Image
+from skimage.morphology import dilation
+
+from agora.io.cells import CellsLinear as Cells
+from agora.io.writer import load_attributes
+from aliby.io.omero import Image as OImage
+from aliby.tile.tiler import Tiler
+
+import matplotlib.pyplot as plt
 
 
 class localImageViewer:
@@ -33,50 +56,27 @@ class localImageViewer:
         Image.fromarray(pixvals.astype(np.uint8))
 
 
-from aliby.tile.tiler import Tiler, TilerParameters, TrapLocations
-from agora.io.writer import load_attributes
-
-
-import json
-
-with open("/home/alan/Documents/dev/skeletons/server_info.json", "r") as f:
-    # json.dump(
-    #     {
-    #         "host": "islay.bio.ed.ac.uk",
-    #         "username": "upload",
-    #         "password": "",
-    #     },
-    #     f,
-    # )
-    server_info = json.load(f)
-
-
-import h5py
-from aliby.io.omero import Image
+class remoteImageViewer:
+    def __init__(self, fpath, server_info=None):
+        attrs = load_attributes(fpath)
 
-from agora.io.cells import CellsLinear as Cells
+        self.image_id = attrs.get("image_id")
+        assert self.image_id is not None, "No valid image_id found in metadata"
 
+        if server_info is None:
+            server_info = yaml.safe_load(attrs["parameters"])["general"]["server_info"]
+        self.server_info = server_info
 
-class remoteImageViewer:
-    def __init__(self, fpath):
-        with h5py.File(fpath, "r") as f:
-            self.image_id = f.attrs.get("image_id", None) or 105146
-        # trap_locs = TrapLocations.from_source(fpath)
-        with Image(self.image_id, **server_info) as image:
+        with OImage(self.image_id, **self.server_info) as image:
             self.tiler = Tiler.from_hdf5(image, fpath)
-        self.cells = Cells.from_source(fpath)
-        # if parameters is None:
-        #     parameters = TilerParameters.default()
 
-        # with h5py.File(hdf, "r") as f:
-        #     # image_id = f.attrs["omero_id"]
-        #     image_id = 16543
+        self.cells = Cells.from_source(fpath)
 
     def get_position(self):
-        pass
+        raise (NotImplementedError)
 
     def get_position_timelapse(self):
-        pass
+        raise (NotImplementedError)
 
     @property
     def full(self):
@@ -84,13 +84,17 @@ class remoteImageViewer:
             self._full = {}
         return self._full
 
-    def get_tc(self, tp):
-        with Image(self.image_id, **server_info) as image:
+    def get_tc(self, tp, server_info=None):
+        server_info = server_info or self.server_info
+
+        with OImage(self.image_id, **server_info) as image:
             self.tiler.image = image.data
             return self.tiler.get_tc(tp, riv.tiler.ref_channel)
 
-    def get_trap_timepoints(self, trap_id, tps):
-        with Image(self.image_id, **server_info) as image:
+    def get_trap_timepoints(self, trap_id, tps, server_info=None):
+        server_info = server_info or self.server_info
+
+        with OImage(self.image_id, **server_info) as image:
             self.tiler.image = image.data
             if set(tps).difference(self.full.keys()):
                 tps = set(tps).difference(self.full.keys())
@@ -119,30 +123,45 @@ class remoteImageViewer:
         img_concat = np.concatenate(imgs_list, axis=1)
         return outline_concat, img_concat
 
+    def plot_labeled_traps(self, trap_id, trange, ncols, **kwargs):
+        """
+        Wrapper to plot a single trap over time
 
-import matplotlib.pyplot as plt
-
-# fpath = "/home/alan/Documents/dev/skeletons/data/2019_07_16_aggregates_CTP_switch_2_0glu_0_0glu_URA7young_URA8young_URA8old_01/2019_07_16_aggregates_CTP_switch_2_0glu_0_0glu_URA7young_URA8young_URA8old_01/URA8_old007.h5"
-fpath = "/home/alan/Documents/dev/skeletons/data/2021_11_01_01_Raf_00/2021_11_01_01_Raf_00/d1134002.h5"
-riv = remoteImageViewer(fpath)
-# pos = riv.get_tc(0)
-out, img = riv.get_labeled_trap(9, list(range(0, 30)))
-out_bak = out
-out = dilation(out).astype(float)
-out[out == 0] = np.nan
-plt.imshow(
-    np.concatenate(np.array_split(img, 6, axis=1)),
-    interpolation=None,
-    cmap="Greys_r",
-)
-plt.imshow(
-    np.concatenate(np.array_split(out, 6, axis=1)),
-    cmap="Set1",
-    interpolation=None,
-)
-plt.show()
-
-concat = lambda a: np.concatenate([x for x in a])
-add = lambda a: np.sum(a, axis=0)
-# plt.imshow(add(roll(tmp[0], 10), np.roll(roll(tmp[1], 11), 6, axis=0)))
-# plt.show()
+        Parameters
+        ---------
+        :trap_id: int trap identification
+        :trange: list list of time points to fetch
+        """
+        nrows = len(trange) // ncols
+        width = riv.tiler.tile_size * ncols
+        out, img = self.get_labeled_trap(trap_id, trange)
+
+        # dilation makes outlines easier to see
+        out = dilation(out).astype(float)
+        out[out == 0] = np.nan
+
+        def concat_pad(array):
+            return np.concatenate(
+                np.array_split(
+                    np.pad(
+                        array,
+                        ((0, 0), (0, array.shape[1] % width)),
+                        constant_values=np.nan,
+                    ),
+                    nrows,
+                    axis=1,
+                )
+            )
+
+        plt.imshow(
+            concat_pad(img),
+            interpolation=None,
+            cmap="Greys_r",
+        )
+        plt.imshow(
+            concat_pad(out),
+            # concat_pad(mask),
+            cmap="Set1",
+            interpolation=None,
+        )
+        plt.show()
-- 
GitLab