From e602d557ff61c61e3747928d80b8733228d09244 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Al=C3=A1n=20Mu=C3=B1oz?= <amuoz@ed.ac.uk>
Date: Wed, 29 Jun 2022 18:39:18 +0100
Subject: [PATCH] style(extractor): Blackify with new defaults

---
 extraction/core/extractor.py | 72 +++++++++++++++++++++++++++---------
 1 file changed, 55 insertions(+), 17 deletions(-)

diff --git a/extraction/core/extractor.py b/extraction/core/extractor.py
index 9f769eaa..6b7e0ca8 100644
--- a/extraction/core/extractor.py
+++ b/extraction/core/extractor.py
@@ -39,7 +39,7 @@ class ExtractorParameters(ParametersABC):
 
     def __init__(
         self,
-        tree: Dict[Union[str, None], Dict[Union[Callable, None], List[str]]] = None,
+        tree: Dict[str, Dict[Callable, List[str]]] = None,
         sub_bg: set = set(),
         multichannel_ops: Dict = {},
     ):
@@ -63,7 +63,9 @@ class ExtractorParameters(ParametersABC):
         """
 
         with h5py.open(store_name, "r") as f:
-            microscope = f["/"].attrs.get("microscope")  # TODO Check this with Arin
+            microscope = f["/"].attrs.get(
+                "microscope"
+            )  # TODO Check this with Arin
         assert microscope, "No metadata found"
 
         return "_".join((microscope, suffix))
@@ -95,7 +97,10 @@ class Extractor(ProcessABC):
     default_meta = {"pixel_size": 0.236, "z_size": 0.6, "spacing": 0.6}
 
     def __init__(
-        self, parameters: ExtractorParameters, store: str = None, tiler: Tiler = None
+        self,
+        parameters: ExtractorParameters,
+        store: str = None,
+        tiler: Tiler = None,
     ):
         self.params = parameters
         if store:
@@ -108,11 +113,15 @@ class Extractor(ProcessABC):
         self.load_funs()
 
     @classmethod
-    def from_tiler(cls, parameters: ExtractorParameters, store: str, tiler: Tiler):
+    def from_tiler(
+        cls, parameters: ExtractorParameters, store: str, tiler: Tiler
+    ):
         return cls(parameters, store=store, tiler=tiler)
 
     @classmethod
-    def from_img(cls, parameters: ExtractorParameters, store: str, img_meta: tuple):
+    def from_img(
+        cls, parameters: ExtractorParameters, store: str, img_meta: tuple
+    ):
         return cls(parameters, store=store, tiler=Tiler(*img_meta))
 
     @property
@@ -154,14 +163,17 @@ class Extractor(ProcessABC):
         )
         funs = funs.intersection(CUSTOM_FUNS.keys())
         ARG_VALS = {
-            k: {k2: self.get_meta(k2) for k2 in v} for k, v in CUSTOM_ARGS.items()
+            k: {k2: self.get_meta(k2) for k2 in v}
+            for k, v in CUSTOM_ARGS.items()
         }
         # self._custom_funs = {trap_apply(CUSTOM_FUNS[fun],])
         self._custom_funs = {}
         for k, f in CUSTOM_FUNS.items():
 
             def tmp(f):
-                return lambda m, img: trap_apply(f, m, img, **ARG_VALS.get(k, {}))
+                return lambda m, img: trap_apply(
+                    f, m, img, **ARG_VALS.get(k, {})
+                )
 
             self._custom_funs[k] = tmp(f)
 
@@ -187,7 +199,9 @@ class Extractor(ProcessABC):
             z = list(range(self.tiler.shape[-1]))
 
         traps = (
-            self.tiler.get_traps_timepoint(tp, channels=channel_ids, z=z, **kwargs)
+            self.tiler.get_traps_timepoint(
+                tp, channels=channel_ids, z=z, **kwargs
+            )
             if channel_ids
             else None
         )
@@ -238,7 +252,11 @@ class Extractor(ProcessABC):
         return (tuple(results), tuple(idx))
 
     def extract_funs(
-        self, traps: List[np.array], masks: List[np.array], metrics: List[str], **kwargs
+        self,
+        traps: List[np.array],
+        masks: List[np.array],
+        metrics: List[str],
+        **kwargs,
     ) -> dict:
         """
         Extract multiple metrics from a timepoint
@@ -253,7 +271,11 @@ class Extractor(ProcessABC):
         return d
 
     def reduce_extract(
-        self, traps: Union[np.array, None], masks: list, red_metrics: dict, **kwargs
+        self,
+        traps: np.array,
+        masks: list,
+        red_metrics: dict,
+        **kwargs,
     ) -> dict:
         """
         Wrapper to apply reduction and then extraction.
@@ -274,7 +296,8 @@ class Extractor(ProcessABC):
         if traps is not None:
             for red_fun in red_metrics.keys():
                 reduced_traps[red_fun] = [
-                    self.reduce_dims(trap, method=RED_FUNS[red_fun]) for trap in traps
+                    self.reduce_dims(trap, method=RED_FUNS[red_fun])
+                    for trap in traps
                 ]
 
         d = {
@@ -326,7 +349,8 @@ class Extractor(ProcessABC):
         if labels is None:
             raw_labels = cells.labels_at_time(tp)
             labels = {
-                trap_id: raw_labels.get(trap_id, []) for trap_id in range(cells.ntraps)
+                trap_id: raw_labels.get(trap_id, [])
+                for trap_id in range(cells.ntraps)
             }
 
         # masks
@@ -368,7 +392,11 @@ class Extractor(ProcessABC):
                 img = traps[:, tree_chs.index(ch), 0]
 
             d[ch] = self.reduce_extract(
-                red_metrics=red_metrics, traps=img, masks=masks, labels=labels, **kwargs
+                red_metrics=red_metrics,
+                traps=img,
+                masks=masks,
+                labels=labels,
+                **kwargs,
             )
 
             if (
@@ -396,9 +424,15 @@ class Extractor(ProcessABC):
                 )
 
         # Additional operations between multiple channels (e.g. pH calculations)
-        for name, (chs, merge_fun, red_metrics) in self.params.multichannel_ops.items():
+        for name, (
+            chs,
+            merge_fun,
+            red_metrics,
+        ) in self.params.multichannel_ops.items():
             if len(
-                set(chs).intersection(set(self.img_bgsub.keys()).union(tree_chs))
+                set(chs).intersection(
+                    set(self.img_bgsub.keys()).union(tree_chs)
+                )
             ) == len(chs):
                 imgs = [self.get_imgs(ch, traps, tree_chs) for ch in chs]
                 merged = MERGE_FUNS[merge_fun](*imgs)
@@ -435,7 +469,9 @@ class Extractor(ProcessABC):
         """
         return self.run(tps=[tp], **kwargs)
 
-    def run(self, tree=None, tps: List[int] = None, save=True, **kwargs) -> dict:
+    def run(
+        self, tree=None, tps: List[int] = None, save=True, **kwargs
+    ) -> dict:
 
         if tree is None:
             tree = self.params.tree
@@ -523,7 +559,9 @@ class Extractor(ProcessABC):
         if not hasattr(flds, "__iter__"):
             flds = [flds]
         meta_short = {k.split("/")[-1]: v for k, v in self.meta.items()}
-        return {f: meta_short.get(f, self.default_meta.get(f, None)) for f in flds}
+        return {
+            f: meta_short.get(f, self.default_meta.get(f, None)) for f in flds
+        }
 
 
 ### Helpers
-- 
GitLab