From c14c2ad98292e8683f66e892d66a7b9bd0b0df32 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Al=C3=A1n=20Mu=C3=B1oz?= <alan.munoz@ed.ac.uk>
Date: Thu, 2 Mar 2023 11:16:50 +0000
Subject: [PATCH] change(vis_tools): refactor _overlay_mask_tile

---
 src/aliby/utils/vis_tools.py | 96 ++++++++++++++++++++++--------------
 1 file changed, 60 insertions(+), 36 deletions(-)

diff --git a/src/aliby/utils/vis_tools.py b/src/aliby/utils/vis_tools.py
index 2436b421..602e8868 100644
--- a/src/aliby/utils/vis_tools.py
+++ b/src/aliby/utils/vis_tools.py
@@ -41,7 +41,7 @@ def get_tiles_at_times(
     Parameters
     ----------
     image_path : str
-        hdf5 location
+        hdf5 index
     timepoints : t.List[int]
         list of timepoints to fetch
     tile_reduction : t.Union[int, t.List[int], str, t.Callable]
@@ -164,60 +164,84 @@ def crop_mask(img: np.ndarray, mask: np.ndarray):
     return img
 
 
-def overlay_masks_tiles(
+def _sample_n_tiles_masks(
     image_path: str,
     results_path: str,
-    masks: np.ndarray,
-    locations: t.Tuple[t.Tuple[int], t.Tuple[int], t.Tuple[int]],
+    n: int,
+    seed: int = 0,
+    interval=None,
+    as_generator=False,
+) -> t.Tuple[t.Tuple, t.Tuple[np.ndarray, np.ndarray]]:
+
+    cells = Cells(results_path)
+    indices, masks = cells._sample_masks(n, seed=seed, interval=interval)
+
+    processed_tiles, cropped_masks = _overlay_masks_tiles(
+        image_path,
+        results_path,
+        masks,
+        [indices[i] for i in (0, 2)],
+        as_generator=as_generator,
+    )
+    return indices, (processed_tiles, cropped_masks)
+
+
+def _overlay_mask_tile(
+    image_path: str,
+    results_path: str,
+    mask: np.ndarray,
+    index: t.Tuple[int, int, int],
     bg_channel: int = 0,
     fg_channel: int = 1,
     reduce_z: t.Union[None, t.Callable] = np.max,
+    as_generator: bool = False,
 ) -> t.Tuple[np.ndarray, np.ndarray]:
+    """
+    Return a tuplw with two channels
+    """
 
-    tcs = np.stack(
+    tc = np.stack(
         [
-            [
-                fetch_tc(image_path, results_path, tp, i)
-                for i in (bg_channel, fg_channel)
-            ]
-            for tp in locations[1]
+            fetch_tc(image_path, results_path, index[1], i)
+            for i in (bg_channel, fg_channel)
         ]
-    )  # Returns TC(tile)ZYX
+    )  # Returns C(tile)ZYX
 
-    tiles = np.stack(
-        [tcs[i, :, tile].astype(float) for i, tile in enumerate(locations[0])]
-    )
+    tiles = tc[:, index[0]].astype(float)
 
     reduced_z = (
-        reduce_z(tiles, axis=2) if reduce_z else concatenate_dims(tiles, 2, -2)
+        reduce_z(tiles, axis=1) if reduce_z else concatenate_dims(tiles, 1, -2)
     )
 
-    repeated_mask = np.stack(
-        [tile_like(mask, reduced_z[0, 0]) for mask in masks]
-    )
+    repeated_mask = tile_like(mask, reduced_z[0])
 
-    cropped_fg = np.stack(
-        [crop_mask(c, mask) for mask, c in zip(repeated_mask, reduced_z[:, 1])]
-    )
+    cropped_fg = crop_mask(reduced_z[1], repeated_mask)
 
-    return reduced_z[:, 0], cropped_fg
+    return reduced_z[0], cropped_fg
 
 
-def _sample_n_tiles_masks(
+def _overlay_masks_tiles(
     image_path: str,
     results_path: str,
-    n: int,
-    seed: int = 0,
-    interval=None,
-) -> t.Tuple[t.Tuple, t.Tuple[np.ndarray, np.ndarray]]:
+    masks: np.ndarray,
+    indices: t.Tuple[t.Tuple[int], t.Tuple[int], t.Tuple[int]],
+    bg_channel: int = 0,
+    fg_channel: int = 1,
+    reduce_z: t.Union[None, t.Callable] = np.max,
+    as_generator: bool = False,
+) -> t.Tuple[np.ndarray, np.ndarray]:
 
-    cells = Cells(results_path)
-    locations, masks = cells._sample_masks(n, seed=seed, interval=interval)
+    tmp = [
+        _overlay_mask_tile(
+            image_path,
+            results_path,
+            mask,
+            index,
+            bg_channel,
+            fg_channel,
+            reduce_z,
+        )
+        for mask, index in zip(masks, zip(*indices))
+    ]
 
-    processed_tiles, cropped_masks = overlay_masks_tiles(
-        image_path,
-        results_path,
-        masks,
-        [locations[i] for i in (0, 2)],
-    )
-    return locations, (processed_tiles, cropped_masks)
+    return [np.stack(x) for x in zip(*tmp)]
-- 
GitLab