From 9dbee0bef687f5947c488fdba3a3e664f2997256 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Al=C3=A1n=20Mu=C3=B1oz?= <alan.munoz@ed.ac.uk>
Date: Wed, 7 Sep 2022 11:18:56 +0100
Subject: [PATCH] type(extractor): Improve overall typing in ext

---
 extraction/core/extractor.py | 68 ++++++++++++++++++++++--------------
 1 file changed, 42 insertions(+), 26 deletions(-)

diff --git a/extraction/core/extractor.py b/extraction/core/extractor.py
index db9ee726..bce0ef85 100644
--- a/extraction/core/extractor.py
+++ b/extraction/core/extractor.py
@@ -20,6 +20,15 @@ from extraction.core.functions.loaders import (
     load_redfuns,
 )
 
+# Define types
+reduction_method = t.Union[t.Callable, str, None]
+extraction_tree = t.Dict[
+    str, t.Dict[reduction_method, t.Dict[str, t.Collection]]
+]
+extraction_result = t.Dict[
+    str, t.Dict[reduction_method, t.Dict[str, t.Dict[str, pd.Series]]]
+]
+
 # Global parameters used to load functions that either analyse cells or their background. These global parameters both allow the functions to be stored in a dictionary for access only on demand and to be defined simply in extraction/core/functions.
 CELL_FUNS, TRAPFUNS, FUNS = load_funs()
 CUSTOM_FUNS, CUSTOM_ARGS = load_custom_args()
@@ -37,7 +46,7 @@ class ExtractorParameters(ParametersABC):
 
     def __init__(
         self,
-        tree: t.Dict[str, t.Dict[t.Callable, t.List[str]]],
+        tree: t.Dict[str, t.Dict[reduction_method, t.Collection[str]]],
         sub_bg: set = set(),
         multichannel_ops: t.Dict = {},
     ):
@@ -112,8 +121,8 @@ class Extractor(ProcessABC):
     def __init__(
         self,
         parameters: ExtractorParameters,
-        store: str = None,
-        tiler: Tiler = None,
+        store: t.Optional[str] = None,
+        tiler: t.Optional[Tiler] = None,
     ):
         """
         Initialise Extractor.
@@ -193,7 +202,7 @@ class Extractor(ProcessABC):
         # consider only those already loaded from CUSTOM_FUNS
         funs = funs.intersection(CUSTOM_FUNS.keys())
         # find their arguments
-        ARG_VALS = {
+        self._custom_arg_vals = {
             k: {k2: self.get_meta(k2) for k2 in v}
             for k, v in CUSTOM_ARGS.items()
         }
@@ -204,7 +213,10 @@ class Extractor(ProcessABC):
             def tmp(f):
                 # pass extra arguments to custom function
                 return lambda cell_masks, trap_image: trap_apply(
-                    f, cell_masks, trap_image, **ARG_VALS.get(k, {})
+                    f,
+                    cell_masks,
+                    trap_image,
+                    **self._custom_arg_vals.get(k, {}),
                 )
 
             self._custom_funs[k] = tmp(f)
@@ -222,10 +234,10 @@ class Extractor(ProcessABC):
     def get_tiles(
         self,
         tp: int,
-        channels: list = None,
-        z: list = None,
+        channels: t.Optional[t.List[t.Union[str, int]]] = None,
+        z: t.Optional[t.List[str]] = None,
         **kwargs,
-    ) -> tuple:
+    ) -> t.Optional[np.ndarray]:
         """
         Finds traps for a given time point and given channels and z-stacks.
         Returns None if no traps are found.
@@ -254,7 +266,7 @@ class Extractor(ProcessABC):
         if z is None:
             z = list(range(self.tiler.shape[-1]))
         # gets the data via tiler
-        traps = (
+        tiles = (
             self.tiler.get_tiles_timepoint(
                 tp, channels=channel_ids, z=z, **kwargs
             )
@@ -262,15 +274,15 @@ class Extractor(ProcessABC):
             else None
         )
         # data arranged as (traps, channels, timepoints, X, Y, Z)
-        return traps
+        return tiles
 
     def extract_traps(
         self,
-        traps: List[np.array],
-        masks: List[np.array],
+        traps: t.List[np.ndarray],
+        masks: t.List[np.ndarray],
         metric: str,
-        labels: t.Dict = None,
-    ) -> dict:
+        labels: t.Dict[int, t.List[int]],
+    ) -> t.Tuple[t.Union[t.Tuple[float], t.Tuple[t.Tuple[int]]]]:
         """
         Apply a function to a whole position.
 
@@ -321,9 +333,9 @@ class Extractor(ProcessABC):
         self,
         traps: List[np.array],
         masks: List[np.array],
-        metrics: List[str],
+        metrics: t.List[str],
         **kwargs,
-    ) -> dict:
+    ) -> t.Dict[str, pd.Series]:
         """
         Returns dict with metrics as key and metrics applied to data as values for data from one timepoint.
         """
@@ -337,11 +349,11 @@ class Extractor(ProcessABC):
 
     def reduce_extract(
         self,
-        traps: np.array,
-        masks: list,
-        red_metrics: dict,
+        traps: np.ndarray,
+        masks: t.List[np.ndarray],
+        red_metrics: t.Dict[reduction_method, t.Collection[str]],
         **kwargs,
-    ) -> dict:
+    ) -> t.Dict[str, t.Dict[reduction_method, t.Dict[str, pd.Series]]]:
         """
         Wrapper to apply reduction and then extraction.
 
@@ -381,7 +393,9 @@ class Extractor(ProcessABC):
         }
         return d
 
-    def reduce_dims(self, img: np.array, method=None) -> np.array:
+    def reduce_dims(
+        self, img: np.ndarray, method: reduction_method = None
+    ) -> np.ndarray:
         """
         Collapse a z-stack into 2d array using method.
         If method is None, return the original data.
@@ -401,10 +415,10 @@ class Extractor(ProcessABC):
     def extract_tp(
         self,
         tp: int,
-        tree: dict = None,
+        tree: t.Optional[extraction_tree] = None,
         tile_size: int = 117,
-        masks=None,
-        labels=None,
+        masks: t.Optional[t.List[np.ndarray]] = None,
+        labels: t.Optional[t.List[int]] = None,
         **kwargs,
     ) -> t.Dict[str, t.Dict[str, t.Dict[str, tuple]]]:
         """
@@ -546,7 +560,7 @@ class Extractor(ProcessABC):
 
         return d
 
-    def get_imgs(self, channel, traps, channels=None):
+    def get_imgs(self, channel: t.Optional[str], traps, channels=None):
         """
         Returns the image from a correct source, either raw or bgsub
 
@@ -663,7 +677,9 @@ class Extractor(ProcessABC):
 
 
 ### Helpers
-def flatten_nesteddict(nest: dict, to="series", tp: int = None) -> dict:
+def flatten_nesteddict(
+    nest: dict, to="series", tp: int = None
+) -> t.Dict[str, pd.Series]:
     """
     Converts a nested extraction dict into a dict of pd.Series
 
-- 
GitLab