From 9dbee0bef687f5947c488fdba3a3e664f2997256 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Al=C3=A1n=20Mu=C3=B1oz?= <alan.munoz@ed.ac.uk> Date: Wed, 7 Sep 2022 11:18:56 +0100 Subject: [PATCH] type(extractor): Improve overall typing in ext --- extraction/core/extractor.py | 68 ++++++++++++++++++++++-------------- 1 file changed, 42 insertions(+), 26 deletions(-) diff --git a/extraction/core/extractor.py b/extraction/core/extractor.py index db9ee726..bce0ef85 100644 --- a/extraction/core/extractor.py +++ b/extraction/core/extractor.py @@ -20,6 +20,15 @@ from extraction.core.functions.loaders import ( load_redfuns, ) +# Define types +reduction_method = t.Union[t.Callable, str, None] +extraction_tree = t.Dict[ + str, t.Dict[reduction_method, t.Dict[str, t.Collection]] +] +extraction_result = t.Dict[ + str, t.Dict[reduction_method, t.Dict[str, t.Dict[str, pd.Series]]] +] + # Global parameters used to load functions that either analyse cells or their background. These global parameters both allow the functions to be stored in a dictionary for access only on demand and to be defined simply in extraction/core/functions. CELL_FUNS, TRAPFUNS, FUNS = load_funs() CUSTOM_FUNS, CUSTOM_ARGS = load_custom_args() @@ -37,7 +46,7 @@ class ExtractorParameters(ParametersABC): def __init__( self, - tree: t.Dict[str, t.Dict[t.Callable, t.List[str]]], + tree: t.Dict[str, t.Dict[reduction_method, t.Collection[str]]], sub_bg: set = set(), multichannel_ops: t.Dict = {}, ): @@ -112,8 +121,8 @@ class Extractor(ProcessABC): def __init__( self, parameters: ExtractorParameters, - store: str = None, - tiler: Tiler = None, + store: t.Optional[str] = None, + tiler: t.Optional[Tiler] = None, ): """ Initialise Extractor. @@ -193,7 +202,7 @@ class Extractor(ProcessABC): # consider only those already loaded from CUSTOM_FUNS funs = funs.intersection(CUSTOM_FUNS.keys()) # find their arguments - ARG_VALS = { + self._custom_arg_vals = { k: {k2: self.get_meta(k2) for k2 in v} for k, v in CUSTOM_ARGS.items() } @@ -204,7 +213,10 @@ class Extractor(ProcessABC): def tmp(f): # pass extra arguments to custom function return lambda cell_masks, trap_image: trap_apply( - f, cell_masks, trap_image, **ARG_VALS.get(k, {}) + f, + cell_masks, + trap_image, + **self._custom_arg_vals.get(k, {}), ) self._custom_funs[k] = tmp(f) @@ -222,10 +234,10 @@ class Extractor(ProcessABC): def get_tiles( self, tp: int, - channels: list = None, - z: list = None, + channels: t.Optional[t.List[t.Union[str, int]]] = None, + z: t.Optional[t.List[str]] = None, **kwargs, - ) -> tuple: + ) -> t.Optional[np.ndarray]: """ Finds traps for a given time point and given channels and z-stacks. Returns None if no traps are found. @@ -254,7 +266,7 @@ class Extractor(ProcessABC): if z is None: z = list(range(self.tiler.shape[-1])) # gets the data via tiler - traps = ( + tiles = ( self.tiler.get_tiles_timepoint( tp, channels=channel_ids, z=z, **kwargs ) @@ -262,15 +274,15 @@ class Extractor(ProcessABC): else None ) # data arranged as (traps, channels, timepoints, X, Y, Z) - return traps + return tiles def extract_traps( self, - traps: List[np.array], - masks: List[np.array], + traps: t.List[np.ndarray], + masks: t.List[np.ndarray], metric: str, - labels: t.Dict = None, - ) -> dict: + labels: t.Dict[int, t.List[int]], + ) -> t.Tuple[t.Union[t.Tuple[float], t.Tuple[t.Tuple[int]]]]: """ Apply a function to a whole position. @@ -321,9 +333,9 @@ class Extractor(ProcessABC): self, traps: List[np.array], masks: List[np.array], - metrics: List[str], + metrics: t.List[str], **kwargs, - ) -> dict: + ) -> t.Dict[str, pd.Series]: """ Returns dict with metrics as key and metrics applied to data as values for data from one timepoint. """ @@ -337,11 +349,11 @@ class Extractor(ProcessABC): def reduce_extract( self, - traps: np.array, - masks: list, - red_metrics: dict, + traps: np.ndarray, + masks: t.List[np.ndarray], + red_metrics: t.Dict[reduction_method, t.Collection[str]], **kwargs, - ) -> dict: + ) -> t.Dict[str, t.Dict[reduction_method, t.Dict[str, pd.Series]]]: """ Wrapper to apply reduction and then extraction. @@ -381,7 +393,9 @@ class Extractor(ProcessABC): } return d - def reduce_dims(self, img: np.array, method=None) -> np.array: + def reduce_dims( + self, img: np.ndarray, method: reduction_method = None + ) -> np.ndarray: """ Collapse a z-stack into 2d array using method. If method is None, return the original data. @@ -401,10 +415,10 @@ class Extractor(ProcessABC): def extract_tp( self, tp: int, - tree: dict = None, + tree: t.Optional[extraction_tree] = None, tile_size: int = 117, - masks=None, - labels=None, + masks: t.Optional[t.List[np.ndarray]] = None, + labels: t.Optional[t.List[int]] = None, **kwargs, ) -> t.Dict[str, t.Dict[str, t.Dict[str, tuple]]]: """ @@ -546,7 +560,7 @@ class Extractor(ProcessABC): return d - def get_imgs(self, channel, traps, channels=None): + def get_imgs(self, channel: t.Optional[str], traps, channels=None): """ Returns the image from a correct source, either raw or bgsub @@ -663,7 +677,9 @@ class Extractor(ProcessABC): ### Helpers -def flatten_nesteddict(nest: dict, to="series", tp: int = None) -> dict: +def flatten_nesteddict( + nest: dict, to="series", tp: int = None +) -> t.Dict[str, pd.Series]: """ Converts a nested extraction dict into a dict of pd.Series -- GitLab