From 1eaca3da7388a193f9981f0a48f72ebb64c2cfbf Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Al=C3=A1n=20Mu=C3=B1oz?= <alan.munoz@ed.ac.uk>
Date: Tue, 28 Feb 2023 20:02:20 +0000
Subject: [PATCH] feat(vis_tools): add overlay_masks_tiles

---
 src/aliby/utils/vis_tools.py | 55 ++++++++++++++++++++++++++++++++++++
 1 file changed, 55 insertions(+)

diff --git a/src/aliby/utils/vis_tools.py b/src/aliby/utils/vis_tools.py
index c8937d5d..22d920b2 100644
--- a/src/aliby/utils/vis_tools.py
+++ b/src/aliby/utils/vis_tools.py
@@ -1,6 +1,9 @@
 #!/usr/bin/env jupyter
 """
 Visualisation tools useful to generate figures cell pictures and figures from scripts.
+
+These do not depend on matplotlib to work, they focus on array processing.
+To check plot-related functions look at plots.py in this folder.
 """
 import typing as t
 from copy import copy
@@ -152,3 +155,55 @@ def crop_mask(img: np.ndarray, mask: np.ndarray):
     img = copy(img).astype(float)
     img[~mask] = np.nan
     return img
+
+
+def overlay_masks_tiles(
+    image_path: str,
+    results_path: str,
+    masks: np.ndarray,
+    locations: t.Tuple[t.Tuple[int], t.Tuple[int], t.Tuple[int]],
+    bg_channel: int = 0,
+    fg_channel: int = 1,
+    reduce_z: t.Union[None, t.Callable] = np.max,
+) -> t.Tuple[np.ndarray, np.ndarray]:
+
+    tcs = np.stack(
+        [
+            [
+                fetch_tc(image_path, results_path, tp, i)
+                for i in (bg_channel, fg_channel)
+            ]
+            for tp in locations[1]
+        ]
+    )  # Returns TC(tile)ZYX
+
+    tiles = np.stack(
+        [tcs[i, :, tile].astype(float) for i, tile in enumerate(locations[0])]
+    )
+
+    reduced_z = (
+        reduce_z(tiles, axis=2) if reduce_z else concatenate_dims(tiles, 2, -2)
+    )
+
+    repeated_mask = np.stack(
+        [tile_like(mask, reduced_z[0, 0]) for mask in masks]
+    )
+
+    cropped_fg = np.stack(
+        [crop_mask(c, mask) for mask, c in zip(repeated_mask, reduced_z[:, 1])]
+    )
+
+    return reduced_z[:, 0], cropped_fg
+
+
+def _sample_n_tiles_masks(
+    image_path: str, results_path: str, n: int, seed: int = 0
+) -> t.Tuple[t.Tuple, t.Tuple[np.ndarray, np.ndarray]]:
+
+    cells = Cells(results_path)
+    locations, masks = cells._sample_masks(n, seed=seed)
+
+    processed_tiles, cropped_masks = overlay_masks_tiles(
+        image_path, results_path, masks, [locations[i] for i in (0, 2)]
+    )
+    return locations, (processed_tiles, cropped_masks)
-- 
GitLab