Skip to content
Snippets Groups Projects
Commit 9dbee0be authored by Alán Muñoz's avatar Alán Muñoz
Browse files

type(extractor): Improve overall typing in ext

parent 680720b0
No related branches found
No related tags found
No related merge requests found
...@@ -20,6 +20,15 @@ from extraction.core.functions.loaders import ( ...@@ -20,6 +20,15 @@ from extraction.core.functions.loaders import (
load_redfuns, load_redfuns,
) )
# Define types
reduction_method = t.Union[t.Callable, str, None]
extraction_tree = t.Dict[
str, t.Dict[reduction_method, t.Dict[str, t.Collection]]
]
extraction_result = t.Dict[
str, t.Dict[reduction_method, t.Dict[str, t.Dict[str, pd.Series]]]
]
# Global parameters used to load functions that either analyse cells or their background. These global parameters both allow the functions to be stored in a dictionary for access only on demand and to be defined simply in extraction/core/functions. # Global parameters used to load functions that either analyse cells or their background. These global parameters both allow the functions to be stored in a dictionary for access only on demand and to be defined simply in extraction/core/functions.
CELL_FUNS, TRAPFUNS, FUNS = load_funs() CELL_FUNS, TRAPFUNS, FUNS = load_funs()
CUSTOM_FUNS, CUSTOM_ARGS = load_custom_args() CUSTOM_FUNS, CUSTOM_ARGS = load_custom_args()
...@@ -37,7 +46,7 @@ class ExtractorParameters(ParametersABC): ...@@ -37,7 +46,7 @@ class ExtractorParameters(ParametersABC):
def __init__( def __init__(
self, self,
tree: t.Dict[str, t.Dict[t.Callable, t.List[str]]], tree: t.Dict[str, t.Dict[reduction_method, t.Collection[str]]],
sub_bg: set = set(), sub_bg: set = set(),
multichannel_ops: t.Dict = {}, multichannel_ops: t.Dict = {},
): ):
...@@ -112,8 +121,8 @@ class Extractor(ProcessABC): ...@@ -112,8 +121,8 @@ class Extractor(ProcessABC):
def __init__( def __init__(
self, self,
parameters: ExtractorParameters, parameters: ExtractorParameters,
store: str = None, store: t.Optional[str] = None,
tiler: Tiler = None, tiler: t.Optional[Tiler] = None,
): ):
""" """
Initialise Extractor. Initialise Extractor.
...@@ -193,7 +202,7 @@ class Extractor(ProcessABC): ...@@ -193,7 +202,7 @@ class Extractor(ProcessABC):
# consider only those already loaded from CUSTOM_FUNS # consider only those already loaded from CUSTOM_FUNS
funs = funs.intersection(CUSTOM_FUNS.keys()) funs = funs.intersection(CUSTOM_FUNS.keys())
# find their arguments # find their arguments
ARG_VALS = { self._custom_arg_vals = {
k: {k2: self.get_meta(k2) for k2 in v} k: {k2: self.get_meta(k2) for k2 in v}
for k, v in CUSTOM_ARGS.items() for k, v in CUSTOM_ARGS.items()
} }
...@@ -204,7 +213,10 @@ class Extractor(ProcessABC): ...@@ -204,7 +213,10 @@ class Extractor(ProcessABC):
def tmp(f): def tmp(f):
# pass extra arguments to custom function # pass extra arguments to custom function
return lambda cell_masks, trap_image: trap_apply( return lambda cell_masks, trap_image: trap_apply(
f, cell_masks, trap_image, **ARG_VALS.get(k, {}) f,
cell_masks,
trap_image,
**self._custom_arg_vals.get(k, {}),
) )
self._custom_funs[k] = tmp(f) self._custom_funs[k] = tmp(f)
...@@ -222,10 +234,10 @@ class Extractor(ProcessABC): ...@@ -222,10 +234,10 @@ class Extractor(ProcessABC):
def get_tiles( def get_tiles(
self, self,
tp: int, tp: int,
channels: list = None, channels: t.Optional[t.List[t.Union[str, int]]] = None,
z: list = None, z: t.Optional[t.List[str]] = None,
**kwargs, **kwargs,
) -> tuple: ) -> t.Optional[np.ndarray]:
""" """
Finds traps for a given time point and given channels and z-stacks. Finds traps for a given time point and given channels and z-stacks.
Returns None if no traps are found. Returns None if no traps are found.
...@@ -254,7 +266,7 @@ class Extractor(ProcessABC): ...@@ -254,7 +266,7 @@ class Extractor(ProcessABC):
if z is None: if z is None:
z = list(range(self.tiler.shape[-1])) z = list(range(self.tiler.shape[-1]))
# gets the data via tiler # gets the data via tiler
traps = ( tiles = (
self.tiler.get_tiles_timepoint( self.tiler.get_tiles_timepoint(
tp, channels=channel_ids, z=z, **kwargs tp, channels=channel_ids, z=z, **kwargs
) )
...@@ -262,15 +274,15 @@ class Extractor(ProcessABC): ...@@ -262,15 +274,15 @@ class Extractor(ProcessABC):
else None else None
) )
# data arranged as (traps, channels, timepoints, X, Y, Z) # data arranged as (traps, channels, timepoints, X, Y, Z)
return traps return tiles
def extract_traps( def extract_traps(
self, self,
traps: List[np.array], traps: t.List[np.ndarray],
masks: List[np.array], masks: t.List[np.ndarray],
metric: str, metric: str,
labels: t.Dict = None, labels: t.Dict[int, t.List[int]],
) -> dict: ) -> t.Tuple[t.Union[t.Tuple[float], t.Tuple[t.Tuple[int]]]]:
""" """
Apply a function to a whole position. Apply a function to a whole position.
...@@ -321,9 +333,9 @@ class Extractor(ProcessABC): ...@@ -321,9 +333,9 @@ class Extractor(ProcessABC):
self, self,
traps: List[np.array], traps: List[np.array],
masks: List[np.array], masks: List[np.array],
metrics: List[str], metrics: t.List[str],
**kwargs, **kwargs,
) -> dict: ) -> t.Dict[str, pd.Series]:
""" """
Returns dict with metrics as key and metrics applied to data as values for data from one timepoint. Returns dict with metrics as key and metrics applied to data as values for data from one timepoint.
""" """
...@@ -337,11 +349,11 @@ class Extractor(ProcessABC): ...@@ -337,11 +349,11 @@ class Extractor(ProcessABC):
def reduce_extract( def reduce_extract(
self, self,
traps: np.array, traps: np.ndarray,
masks: list, masks: t.List[np.ndarray],
red_metrics: dict, red_metrics: t.Dict[reduction_method, t.Collection[str]],
**kwargs, **kwargs,
) -> dict: ) -> t.Dict[str, t.Dict[reduction_method, t.Dict[str, pd.Series]]]:
""" """
Wrapper to apply reduction and then extraction. Wrapper to apply reduction and then extraction.
...@@ -381,7 +393,9 @@ class Extractor(ProcessABC): ...@@ -381,7 +393,9 @@ class Extractor(ProcessABC):
} }
return d return d
def reduce_dims(self, img: np.array, method=None) -> np.array: def reduce_dims(
self, img: np.ndarray, method: reduction_method = None
) -> np.ndarray:
""" """
Collapse a z-stack into 2d array using method. Collapse a z-stack into 2d array using method.
If method is None, return the original data. If method is None, return the original data.
...@@ -401,10 +415,10 @@ class Extractor(ProcessABC): ...@@ -401,10 +415,10 @@ class Extractor(ProcessABC):
def extract_tp( def extract_tp(
self, self,
tp: int, tp: int,
tree: dict = None, tree: t.Optional[extraction_tree] = None,
tile_size: int = 117, tile_size: int = 117,
masks=None, masks: t.Optional[t.List[np.ndarray]] = None,
labels=None, labels: t.Optional[t.List[int]] = None,
**kwargs, **kwargs,
) -> t.Dict[str, t.Dict[str, t.Dict[str, tuple]]]: ) -> t.Dict[str, t.Dict[str, t.Dict[str, tuple]]]:
""" """
...@@ -546,7 +560,7 @@ class Extractor(ProcessABC): ...@@ -546,7 +560,7 @@ class Extractor(ProcessABC):
return d return d
def get_imgs(self, channel, traps, channels=None): def get_imgs(self, channel: t.Optional[str], traps, channels=None):
""" """
Returns the image from a correct source, either raw or bgsub Returns the image from a correct source, either raw or bgsub
...@@ -663,7 +677,9 @@ class Extractor(ProcessABC): ...@@ -663,7 +677,9 @@ class Extractor(ProcessABC):
### Helpers ### Helpers
def flatten_nesteddict(nest: dict, to="series", tp: int = None) -> dict: def flatten_nesteddict(
nest: dict, to="series", tp: int = None
) -> t.Dict[str, pd.Series]:
""" """
Converts a nested extraction dict into a dict of pd.Series Converts a nested extraction dict into a dict of pd.Series
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment