From c14c2ad98292e8683f66e892d66a7b9bd0b0df32 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Al=C3=A1n=20Mu=C3=B1oz?= <alan.munoz@ed.ac.uk> Date: Thu, 2 Mar 2023 11:16:50 +0000 Subject: [PATCH] change(vis_tools): refactor _overlay_mask_tile --- src/aliby/utils/vis_tools.py | 96 ++++++++++++++++++++++-------------- 1 file changed, 60 insertions(+), 36 deletions(-) diff --git a/src/aliby/utils/vis_tools.py b/src/aliby/utils/vis_tools.py index 2436b421..602e8868 100644 --- a/src/aliby/utils/vis_tools.py +++ b/src/aliby/utils/vis_tools.py @@ -41,7 +41,7 @@ def get_tiles_at_times( Parameters ---------- image_path : str - hdf5 location + hdf5 index timepoints : t.List[int] list of timepoints to fetch tile_reduction : t.Union[int, t.List[int], str, t.Callable] @@ -164,60 +164,84 @@ def crop_mask(img: np.ndarray, mask: np.ndarray): return img -def overlay_masks_tiles( +def _sample_n_tiles_masks( image_path: str, results_path: str, - masks: np.ndarray, - locations: t.Tuple[t.Tuple[int], t.Tuple[int], t.Tuple[int]], + n: int, + seed: int = 0, + interval=None, + as_generator=False, +) -> t.Tuple[t.Tuple, t.Tuple[np.ndarray, np.ndarray]]: + + cells = Cells(results_path) + indices, masks = cells._sample_masks(n, seed=seed, interval=interval) + + processed_tiles, cropped_masks = _overlay_masks_tiles( + image_path, + results_path, + masks, + [indices[i] for i in (0, 2)], + as_generator=as_generator, + ) + return indices, (processed_tiles, cropped_masks) + + +def _overlay_mask_tile( + image_path: str, + results_path: str, + mask: np.ndarray, + index: t.Tuple[int, int, int], bg_channel: int = 0, fg_channel: int = 1, reduce_z: t.Union[None, t.Callable] = np.max, + as_generator: bool = False, ) -> t.Tuple[np.ndarray, np.ndarray]: + """ + Return a tuplw with two channels + """ - tcs = np.stack( + tc = np.stack( [ - [ - fetch_tc(image_path, results_path, tp, i) - for i in (bg_channel, fg_channel) - ] - for tp in locations[1] + fetch_tc(image_path, results_path, index[1], i) + for i in (bg_channel, fg_channel) ] - ) # Returns TC(tile)ZYX + ) # Returns C(tile)ZYX - tiles = np.stack( - [tcs[i, :, tile].astype(float) for i, tile in enumerate(locations[0])] - ) + tiles = tc[:, index[0]].astype(float) reduced_z = ( - reduce_z(tiles, axis=2) if reduce_z else concatenate_dims(tiles, 2, -2) + reduce_z(tiles, axis=1) if reduce_z else concatenate_dims(tiles, 1, -2) ) - repeated_mask = np.stack( - [tile_like(mask, reduced_z[0, 0]) for mask in masks] - ) + repeated_mask = tile_like(mask, reduced_z[0]) - cropped_fg = np.stack( - [crop_mask(c, mask) for mask, c in zip(repeated_mask, reduced_z[:, 1])] - ) + cropped_fg = crop_mask(reduced_z[1], repeated_mask) - return reduced_z[:, 0], cropped_fg + return reduced_z[0], cropped_fg -def _sample_n_tiles_masks( +def _overlay_masks_tiles( image_path: str, results_path: str, - n: int, - seed: int = 0, - interval=None, -) -> t.Tuple[t.Tuple, t.Tuple[np.ndarray, np.ndarray]]: + masks: np.ndarray, + indices: t.Tuple[t.Tuple[int], t.Tuple[int], t.Tuple[int]], + bg_channel: int = 0, + fg_channel: int = 1, + reduce_z: t.Union[None, t.Callable] = np.max, + as_generator: bool = False, +) -> t.Tuple[np.ndarray, np.ndarray]: - cells = Cells(results_path) - locations, masks = cells._sample_masks(n, seed=seed, interval=interval) + tmp = [ + _overlay_mask_tile( + image_path, + results_path, + mask, + index, + bg_channel, + fg_channel, + reduce_z, + ) + for mask, index in zip(masks, zip(*indices)) + ] - processed_tiles, cropped_masks = overlay_masks_tiles( - image_path, - results_path, - masks, - [locations[i] for i in (0, 2)], - ) - return locations, (processed_tiles, cropped_masks) + return [np.stack(x) for x in zip(*tmp)] -- GitLab