From b87fc9a5e3d5320720419ed6ab3d24786313e8ee Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Al=C3=A1n=20Mu=C3=B1oz?= <alan.munoz@ed.ac.uk>
Date: Tue, 28 Feb 2023 11:51:11 +0000
Subject: [PATCH] fix(extraction): tests account for dimension used

---
 src/extraction/core/functions/distributors.py | 11 ++++++-----
 src/extraction/local_data_loaders.py          |  6 +++---
 tests/aliby/network/extraction/test_base.py   |  2 +-
 3 files changed, 10 insertions(+), 9 deletions(-)

diff --git a/src/extraction/core/functions/distributors.py b/src/extraction/core/functions/distributors.py
index 92152a88..e9b5265f 100644
--- a/src/extraction/core/functions/distributors.py
+++ b/src/extraction/core/functions/distributors.py
@@ -23,7 +23,7 @@ def trap_apply(cell_fun, cell_masks, *args, **kwargs):
     return [cell_fun(mask, *args, **kwargs) for mask in cell_masks]
 
 
-def reduce_z(trap_image: np.ndarray, fun: t.Callable):
+def reduce_z(trap_image: np.ndarray, fun: t.Callable, axis: int = 0):
     """
     Reduce the trap_image to 2d.
 
@@ -33,15 +33,16 @@ def reduce_z(trap_image: np.ndarray, fun: t.Callable):
         Images for all the channels associated with a trap
     fun: function
         Function to execute the reduction
-
+    axis: int (default 0)
+        Axis in which we apply the reduction operation.
     """
     # FUTURE replace with py3.10's match-case.
     if (
         hasattr(fun, "__module__") and fun.__module__[:10] == "bottleneck"
     ):  # Bottleneck type
-        return getattr(bn.reduce, fun.__name__)(trap_image, axis=2)
+        return getattr(bn.reduce, fun.__name__)(trap_image, axis=axis)
     elif isinstance(fun, np.ufunc):
         # optimise the reduction function if possible
-        return fun.reduce(trap_image, axis=2)
+        return fun.reduce(trap_image, axis=axis)
     else:  # WARNING: Very slow, only use when no alternatives exist
-        return np.apply_along_axis(fun, 2, trap_image)
+        return np.apply_along_axis(fun, axis, trap_image)
diff --git a/src/extraction/local_data_loaders.py b/src/extraction/local_data_loaders.py
index 8f9a371d..0079b056 100644
--- a/src/extraction/local_data_loaders.py
+++ b/src/extraction/local_data_loaders.py
@@ -29,12 +29,13 @@ def load_tiled_image(filename):
     nt = info.get("ntiles", 1)
     nr, nc = info.get("layout", (1, 1))
     nc_final_row = np.mod(nt, nc)
-    img = np.zeros((tw, th, nt), dtype=tImg.dtype)
+    img = np.zeros((nt, tw, th), dtype=tImg.dtype)
     for i in range(nr):
         i_nc = nc_final_row if i + 1 == nr and nc_final_row > 0 else nc
         for j in range(i_nc):
             ind = i * nc + j
-            img[:, :, ind] = tImg[i * tw : (i + 1) * tw, j * th : (j + 1) * th]
+            img[ind, :, :] = tImg[i * tw : (i + 1) * tw, j * th : (j + 1) * th]
+
     return img, info
 
 
@@ -72,7 +73,6 @@ def load(path=None):
     """
     if path is None:
 
-        # FUTURE can be replaced by importlib.resources.files('aliby') after upgrading to 3.11
         path = (
             files("aliby").parent.parent
             / "examples"
diff --git a/tests/aliby/network/extraction/test_base.py b/tests/aliby/network/extraction/test_base.py
index 59ebc976..14209dbd 100644
--- a/tests/aliby/network/extraction/test_base.py
+++ b/tests/aliby/network/extraction/test_base.py
@@ -30,7 +30,7 @@ def test_metrics_run(imgs, masks, f):
 
     for ch, img in imgs.items():
         if ch != "segoutlines":
-            assert tuple(masks.shape[:2]) == tuple(imgs[ch].shape)
+            assert tuple(masks.shape[-2:]) == tuple(imgs[ch].shape)
             f(masks, img)
 
 
-- 
GitLab