Skip to content
Snippets Groups Projects
Commit b87fc9a5 authored by Alán Muñoz's avatar Alán Muñoz
Browse files

fix(extraction): tests account for dimension used

parent 59ff4a72
No related branches found
No related tags found
No related merge requests found
...@@ -23,7 +23,7 @@ def trap_apply(cell_fun, cell_masks, *args, **kwargs): ...@@ -23,7 +23,7 @@ def trap_apply(cell_fun, cell_masks, *args, **kwargs):
return [cell_fun(mask, *args, **kwargs) for mask in cell_masks] return [cell_fun(mask, *args, **kwargs) for mask in cell_masks]
def reduce_z(trap_image: np.ndarray, fun: t.Callable): def reduce_z(trap_image: np.ndarray, fun: t.Callable, axis: int = 0):
""" """
Reduce the trap_image to 2d. Reduce the trap_image to 2d.
...@@ -33,15 +33,16 @@ def reduce_z(trap_image: np.ndarray, fun: t.Callable): ...@@ -33,15 +33,16 @@ def reduce_z(trap_image: np.ndarray, fun: t.Callable):
Images for all the channels associated with a trap Images for all the channels associated with a trap
fun: function fun: function
Function to execute the reduction Function to execute the reduction
axis: int (default 0)
Axis in which we apply the reduction operation.
""" """
# FUTURE replace with py3.10's match-case. # FUTURE replace with py3.10's match-case.
if ( if (
hasattr(fun, "__module__") and fun.__module__[:10] == "bottleneck" hasattr(fun, "__module__") and fun.__module__[:10] == "bottleneck"
): # Bottleneck type ): # Bottleneck type
return getattr(bn.reduce, fun.__name__)(trap_image, axis=2) return getattr(bn.reduce, fun.__name__)(trap_image, axis=axis)
elif isinstance(fun, np.ufunc): elif isinstance(fun, np.ufunc):
# optimise the reduction function if possible # optimise the reduction function if possible
return fun.reduce(trap_image, axis=2) return fun.reduce(trap_image, axis=axis)
else: # WARNING: Very slow, only use when no alternatives exist else: # WARNING: Very slow, only use when no alternatives exist
return np.apply_along_axis(fun, 2, trap_image) return np.apply_along_axis(fun, axis, trap_image)
...@@ -29,12 +29,13 @@ def load_tiled_image(filename): ...@@ -29,12 +29,13 @@ def load_tiled_image(filename):
nt = info.get("ntiles", 1) nt = info.get("ntiles", 1)
nr, nc = info.get("layout", (1, 1)) nr, nc = info.get("layout", (1, 1))
nc_final_row = np.mod(nt, nc) nc_final_row = np.mod(nt, nc)
img = np.zeros((tw, th, nt), dtype=tImg.dtype) img = np.zeros((nt, tw, th), dtype=tImg.dtype)
for i in range(nr): for i in range(nr):
i_nc = nc_final_row if i + 1 == nr and nc_final_row > 0 else nc i_nc = nc_final_row if i + 1 == nr and nc_final_row > 0 else nc
for j in range(i_nc): for j in range(i_nc):
ind = i * nc + j ind = i * nc + j
img[:, :, ind] = tImg[i * tw : (i + 1) * tw, j * th : (j + 1) * th] img[ind, :, :] = tImg[i * tw : (i + 1) * tw, j * th : (j + 1) * th]
return img, info return img, info
...@@ -72,7 +73,6 @@ def load(path=None): ...@@ -72,7 +73,6 @@ def load(path=None):
""" """
if path is None: if path is None:
# FUTURE can be replaced by importlib.resources.files('aliby') after upgrading to 3.11
path = ( path = (
files("aliby").parent.parent files("aliby").parent.parent
/ "examples" / "examples"
......
...@@ -30,7 +30,7 @@ def test_metrics_run(imgs, masks, f): ...@@ -30,7 +30,7 @@ def test_metrics_run(imgs, masks, f):
for ch, img in imgs.items(): for ch, img in imgs.items():
if ch != "segoutlines": if ch != "segoutlines":
assert tuple(masks.shape[:2]) == tuple(imgs[ch].shape) assert tuple(masks.shape[-2:]) == tuple(imgs[ch].shape)
f(masks, img) f(masks, img)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment