From b87fc9a5e3d5320720419ed6ab3d24786313e8ee Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Al=C3=A1n=20Mu=C3=B1oz?= <alan.munoz@ed.ac.uk> Date: Tue, 28 Feb 2023 11:51:11 +0000 Subject: [PATCH] fix(extraction): tests account for dimension used --- src/extraction/core/functions/distributors.py | 11 ++++++----- src/extraction/local_data_loaders.py | 6 +++--- tests/aliby/network/extraction/test_base.py | 2 +- 3 files changed, 10 insertions(+), 9 deletions(-) diff --git a/src/extraction/core/functions/distributors.py b/src/extraction/core/functions/distributors.py index 92152a88..e9b5265f 100644 --- a/src/extraction/core/functions/distributors.py +++ b/src/extraction/core/functions/distributors.py @@ -23,7 +23,7 @@ def trap_apply(cell_fun, cell_masks, *args, **kwargs): return [cell_fun(mask, *args, **kwargs) for mask in cell_masks] -def reduce_z(trap_image: np.ndarray, fun: t.Callable): +def reduce_z(trap_image: np.ndarray, fun: t.Callable, axis: int = 0): """ Reduce the trap_image to 2d. @@ -33,15 +33,16 @@ def reduce_z(trap_image: np.ndarray, fun: t.Callable): Images for all the channels associated with a trap fun: function Function to execute the reduction - + axis: int (default 0) + Axis in which we apply the reduction operation. """ # FUTURE replace with py3.10's match-case. if ( hasattr(fun, "__module__") and fun.__module__[:10] == "bottleneck" ): # Bottleneck type - return getattr(bn.reduce, fun.__name__)(trap_image, axis=2) + return getattr(bn.reduce, fun.__name__)(trap_image, axis=axis) elif isinstance(fun, np.ufunc): # optimise the reduction function if possible - return fun.reduce(trap_image, axis=2) + return fun.reduce(trap_image, axis=axis) else: # WARNING: Very slow, only use when no alternatives exist - return np.apply_along_axis(fun, 2, trap_image) + return np.apply_along_axis(fun, axis, trap_image) diff --git a/src/extraction/local_data_loaders.py b/src/extraction/local_data_loaders.py index 8f9a371d..0079b056 100644 --- a/src/extraction/local_data_loaders.py +++ b/src/extraction/local_data_loaders.py @@ -29,12 +29,13 @@ def load_tiled_image(filename): nt = info.get("ntiles", 1) nr, nc = info.get("layout", (1, 1)) nc_final_row = np.mod(nt, nc) - img = np.zeros((tw, th, nt), dtype=tImg.dtype) + img = np.zeros((nt, tw, th), dtype=tImg.dtype) for i in range(nr): i_nc = nc_final_row if i + 1 == nr and nc_final_row > 0 else nc for j in range(i_nc): ind = i * nc + j - img[:, :, ind] = tImg[i * tw : (i + 1) * tw, j * th : (j + 1) * th] + img[ind, :, :] = tImg[i * tw : (i + 1) * tw, j * th : (j + 1) * th] + return img, info @@ -72,7 +73,6 @@ def load(path=None): """ if path is None: - # FUTURE can be replaced by importlib.resources.files('aliby') after upgrading to 3.11 path = ( files("aliby").parent.parent / "examples" diff --git a/tests/aliby/network/extraction/test_base.py b/tests/aliby/network/extraction/test_base.py index 59ebc976..14209dbd 100644 --- a/tests/aliby/network/extraction/test_base.py +++ b/tests/aliby/network/extraction/test_base.py @@ -30,7 +30,7 @@ def test_metrics_run(imgs, masks, f): for ch, img in imgs.items(): if ch != "segoutlines": - assert tuple(masks.shape[:2]) == tuple(imgs[ch].shape) + assert tuple(masks.shape[-2:]) == tuple(imgs[ch].shape) f(masks, img) -- GitLab