From 7967559614d2c2aa230f24fa47d13b6044cc3d2e Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Al=C3=A1n=20Mu=C3=B1oz?= <alan.munoz@ed.ac.uk>
Date: Mon, 27 Feb 2023 12:40:48 +0000
Subject: [PATCH] feat(aliby): add vis_tools

---
 src/aliby/utils/vis_tools.py | 141 +++++++++++++++++++++++++++++++++++
 1 file changed, 141 insertions(+)
 create mode 100644 src/aliby/utils/vis_tools.py

diff --git a/src/aliby/utils/vis_tools.py b/src/aliby/utils/vis_tools.py
new file mode 100644
index 00000000..3d4caef4
--- /dev/null
+++ b/src/aliby/utils/vis_tools.py
@@ -0,0 +1,141 @@
+#!/usr/bin/env jupyter
+"""
+Visualisation tools useful to generate figures cell pictures and figures from scripts.
+"""
+
+import typing as t
+
+import numpy as np
+
+from agora.io.cells import Cells
+from aliby.io.image import instatiate_image
+from aliby.tile.tiler import Tiler, TilerParameters
+
+
+def fetch_tc(image_path: str, results_path: str, t: int = 0, c: int = 0):
+    with instatiate_image(image_path) as iz:
+        tiler = Tiler.from_h5(iz, results_path, TilerParameters.default())
+        tc = tiler.get_tp_data(t, c)
+    return tc
+
+
+def get_tiles_at_times(
+    image_path: str,
+    results_path: str,
+    timepoints: t.List[int] = [0],
+    tile_reduction: t.Union[
+        int, t.List[int], str, t.Callable
+    ] = lambda x: concatenate_dims(x, 1, -1),
+    channel: int = 1,
+):
+    """Use Image and tiler to get tiled position for specific time points.
+
+    Parameters
+    ----------
+    image_path : str
+        hdf5 location
+    timepoints : t.List[int]
+        list of timepoints to fetch
+    tile_reduction : t.Union[int, t.List[int], str, t.Callable]
+        Reduce dimensionality. Generally used to collapse z-stacks into one
+
+    Examples
+    --------
+    FIXME: Add docs.
+
+
+    """
+
+    # Get the correct tile in space and time
+    with instatiate_image(image_path) as image:
+        tiler = Tiler.from_h5(image, results_path, TilerParameters.default())
+        tp_channel_stack = [
+            _dispatch_tile_reduction(tile_reduction)(
+                tiler.get_tp_data(tp, channel)
+            )
+            for tp in timepoints
+        ]
+    return tp_channel_stack
+
+
+def get_cellmasks_at_times(results_path: str, timepoints: t.List[int] = [0]):
+    return Cells(results_path).at_times(timepoints)
+
+
+def concatenate_dims(ndarray, axis1: int, axis2: int):
+    return np.concatenate(np.moveaxis(ndarray, axis1, 0), axis=axis2)
+
+
+def get_tile_mask_pairs(
+    image_path: str,
+    results_path: str,
+    timepoints: t.List[int] = [0],
+    tile_reduction=lambda x: concatenate_dims(x, 1, -1),
+) -> t.Tuple[np.ndarray, t.List[t.List[np.ndarray]]]:
+
+    return (
+        get_tiles_at_times(
+            image_path, results_path, timepoints, tile_reduction
+        ),
+        get_cellmasks_at_times(results_path, timepoints),
+    )
+
+
+def _dispatch_tile_reduction(how: t.Union[int, str, t.List[int]], axis=1):
+    """
+    Return an appropriate dimensional reduction based on the input on a specified axis.
+    If "how" is a string, it operates in dimension 1  (to match tile dimension standard Tile, Z, Y, X)
+    how: int, str or list of int
+        if int or list of int those numbers are indexed;
+        if str it assumes it is a numpy function such as np max.
+        if it is a callable it applies that operation to the array.
+        if None it returns the result as-is
+    axis: Only used when "how" is string. Determines the dimension to which the
+        standard operation is applied.
+    """
+    # FUTURE use match case when migrating to python 3.10
+
+    if how is None:
+        return lambda x: x
+    elif isinstance(how, (int, list)):
+        return lambda x: x.take(how, axis=axis)
+    elif isinstance(how, str):
+        return lambda x: getattr(x, how)(axis=axis)
+    elif isinstance(how, t.Callable):
+        return lambda x: how(x)
+    else:
+        raise Exception(f"Invalid reduction {how}")
+
+
+def tile_like(arr1, arr2):
+    """
+    Tile the first two dimensions of arr1 (ND) to match arr2 (2D)
+    """
+
+    result = arr1
+    ratio = np.divide(arr2.shape, arr1.shape[-2:]).astype(int)
+    if reps := max(ratio - 1):
+        tile_ = (
+            lambda x, n: np.tile(x, n)
+            if ratio.argmax()
+            else np.tile(x.T, n + 1).T
+        )
+        result = np.stack([tile_(mask, reps + 1) for mask in arr1])
+    return result
+
+
+def centre_mask(image: np.ndarray, mask: np.ndarray):
+    """Roll image to the centre of the image based on a mask of equal size"""
+
+    cell_centroid = (
+        np.max(np.where(mask), axis=1) + np.min(np.where(mask), axis=1)
+    ) // 2
+    tile_centre = np.array(image.shape) // 2
+    return np.roll(image, (tile_centre - cell_centroid), (0, 1))
+
+
+def long_side_vertical(arr):
+    result = arr
+    if np.subtract(*arr.shape):
+        result = arr.T
+    return result
-- 
GitLab