diff --git a/src/aliby/utils/plot.py b/src/aliby/utils/plot.py index 72cf29a441589a7de2252a155e33e4468ea37380..16734d206498160d31cbbe6e630a6488c592b7d5 100644 --- a/src/aliby/utils/plot.py +++ b/src/aliby/utils/plot.py @@ -3,10 +3,31 @@ Basic plotting functions for cell visualisation """ +import typing as t + +import numpy as np +from grid_strategy import strategies from matplotlib import pyplot as plt -def plot_overlay(bg, fg, alpha=0.5, ax=plt) -> None: +def plot_overlay( + bg: np.ndarray, fg: np.ndarray, alpha: float = 0.5, ax=plt +) -> None: + ax.imshow(bg, cmap=plt.cm.gray, interpolation="none") ax.imshow(fg, alpha=alpha, interpolation="none") ax.axis("off") + + +def plot_overlay_in_square(data: t.Tuple[np.ndarray, np.ndarray]): + specs = strategies.SquareStrategy("center").get_grid(len(data)) + for i, (gs, (tile, mask)) in enumerate(zip(specs, data)): + ax = plt.subplot(gs) + plot_overlay(tile, mask, ax=ax) + + +def plot_in_square(data: t.Iterable): + specs = strategies.SquareStrategy("center").get_grid(len(data)) + for i, (gs, datum) in enumerate(zip(specs, data)): + ax = plt.subplot(gs) + ax.imshow(datum)