From 657493195ce62d75a02e1abe39e3746bfb45b003 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Al=C3=A1n=20Mu=C3=B1oz?= <alan.munoz@ed.ac.uk>
Date: Thu, 2 Mar 2023 00:21:16 +0000
Subject: [PATCH] feat(plot): add plot_overlay

---
 src/aliby/utils/plot.py | 23 ++++++++++++++++++++++-
 1 file changed, 22 insertions(+), 1 deletion(-)

diff --git a/src/aliby/utils/plot.py b/src/aliby/utils/plot.py
index 72cf29a4..16734d20 100644
--- a/src/aliby/utils/plot.py
+++ b/src/aliby/utils/plot.py
@@ -3,10 +3,31 @@
 Basic plotting functions for cell visualisation
 """
 
+import typing as t
+
+import numpy as np
+from grid_strategy import strategies
 from matplotlib import pyplot as plt
 
 
-def plot_overlay(bg, fg, alpha=0.5, ax=plt) -> None:
+def plot_overlay(
+    bg: np.ndarray, fg: np.ndarray, alpha: float = 0.5, ax=plt
+) -> None:
+
     ax.imshow(bg, cmap=plt.cm.gray, interpolation="none")
     ax.imshow(fg, alpha=alpha, interpolation="none")
     ax.axis("off")
+
+
+def plot_overlay_in_square(data: t.Tuple[np.ndarray, np.ndarray]):
+    specs = strategies.SquareStrategy("center").get_grid(len(data))
+    for i, (gs, (tile, mask)) in enumerate(zip(specs, data)):
+        ax = plt.subplot(gs)
+        plot_overlay(tile, mask, ax=ax)
+
+
+def plot_in_square(data: t.Iterable):
+    specs = strategies.SquareStrategy("center").get_grid(len(data))
+    for i, (gs, datum) in enumerate(zip(specs, data)):
+        ax = plt.subplot(gs)
+        ax.imshow(datum)
-- 
GitLab