From e602d557ff61c61e3747928d80b8733228d09244 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Al=C3=A1n=20Mu=C3=B1oz?= <amuoz@ed.ac.uk> Date: Wed, 29 Jun 2022 18:39:18 +0100 Subject: [PATCH] style(extractor): Blackify with new defaults --- extraction/core/extractor.py | 72 +++++++++++++++++++++++++++--------- 1 file changed, 55 insertions(+), 17 deletions(-) diff --git a/extraction/core/extractor.py b/extraction/core/extractor.py index 9f769eaa..6b7e0ca8 100644 --- a/extraction/core/extractor.py +++ b/extraction/core/extractor.py @@ -39,7 +39,7 @@ class ExtractorParameters(ParametersABC): def __init__( self, - tree: Dict[Union[str, None], Dict[Union[Callable, None], List[str]]] = None, + tree: Dict[str, Dict[Callable, List[str]]] = None, sub_bg: set = set(), multichannel_ops: Dict = {}, ): @@ -63,7 +63,9 @@ class ExtractorParameters(ParametersABC): """ with h5py.open(store_name, "r") as f: - microscope = f["/"].attrs.get("microscope") # TODO Check this with Arin + microscope = f["/"].attrs.get( + "microscope" + ) # TODO Check this with Arin assert microscope, "No metadata found" return "_".join((microscope, suffix)) @@ -95,7 +97,10 @@ class Extractor(ProcessABC): default_meta = {"pixel_size": 0.236, "z_size": 0.6, "spacing": 0.6} def __init__( - self, parameters: ExtractorParameters, store: str = None, tiler: Tiler = None + self, + parameters: ExtractorParameters, + store: str = None, + tiler: Tiler = None, ): self.params = parameters if store: @@ -108,11 +113,15 @@ class Extractor(ProcessABC): self.load_funs() @classmethod - def from_tiler(cls, parameters: ExtractorParameters, store: str, tiler: Tiler): + def from_tiler( + cls, parameters: ExtractorParameters, store: str, tiler: Tiler + ): return cls(parameters, store=store, tiler=tiler) @classmethod - def from_img(cls, parameters: ExtractorParameters, store: str, img_meta: tuple): + def from_img( + cls, parameters: ExtractorParameters, store: str, img_meta: tuple + ): return cls(parameters, store=store, tiler=Tiler(*img_meta)) @property @@ -154,14 +163,17 @@ class Extractor(ProcessABC): ) funs = funs.intersection(CUSTOM_FUNS.keys()) ARG_VALS = { - k: {k2: self.get_meta(k2) for k2 in v} for k, v in CUSTOM_ARGS.items() + k: {k2: self.get_meta(k2) for k2 in v} + for k, v in CUSTOM_ARGS.items() } # self._custom_funs = {trap_apply(CUSTOM_FUNS[fun],]) self._custom_funs = {} for k, f in CUSTOM_FUNS.items(): def tmp(f): - return lambda m, img: trap_apply(f, m, img, **ARG_VALS.get(k, {})) + return lambda m, img: trap_apply( + f, m, img, **ARG_VALS.get(k, {}) + ) self._custom_funs[k] = tmp(f) @@ -187,7 +199,9 @@ class Extractor(ProcessABC): z = list(range(self.tiler.shape[-1])) traps = ( - self.tiler.get_traps_timepoint(tp, channels=channel_ids, z=z, **kwargs) + self.tiler.get_traps_timepoint( + tp, channels=channel_ids, z=z, **kwargs + ) if channel_ids else None ) @@ -238,7 +252,11 @@ class Extractor(ProcessABC): return (tuple(results), tuple(idx)) def extract_funs( - self, traps: List[np.array], masks: List[np.array], metrics: List[str], **kwargs + self, + traps: List[np.array], + masks: List[np.array], + metrics: List[str], + **kwargs, ) -> dict: """ Extract multiple metrics from a timepoint @@ -253,7 +271,11 @@ class Extractor(ProcessABC): return d def reduce_extract( - self, traps: Union[np.array, None], masks: list, red_metrics: dict, **kwargs + self, + traps: np.array, + masks: list, + red_metrics: dict, + **kwargs, ) -> dict: """ Wrapper to apply reduction and then extraction. @@ -274,7 +296,8 @@ class Extractor(ProcessABC): if traps is not None: for red_fun in red_metrics.keys(): reduced_traps[red_fun] = [ - self.reduce_dims(trap, method=RED_FUNS[red_fun]) for trap in traps + self.reduce_dims(trap, method=RED_FUNS[red_fun]) + for trap in traps ] d = { @@ -326,7 +349,8 @@ class Extractor(ProcessABC): if labels is None: raw_labels = cells.labels_at_time(tp) labels = { - trap_id: raw_labels.get(trap_id, []) for trap_id in range(cells.ntraps) + trap_id: raw_labels.get(trap_id, []) + for trap_id in range(cells.ntraps) } # masks @@ -368,7 +392,11 @@ class Extractor(ProcessABC): img = traps[:, tree_chs.index(ch), 0] d[ch] = self.reduce_extract( - red_metrics=red_metrics, traps=img, masks=masks, labels=labels, **kwargs + red_metrics=red_metrics, + traps=img, + masks=masks, + labels=labels, + **kwargs, ) if ( @@ -396,9 +424,15 @@ class Extractor(ProcessABC): ) # Additional operations between multiple channels (e.g. pH calculations) - for name, (chs, merge_fun, red_metrics) in self.params.multichannel_ops.items(): + for name, ( + chs, + merge_fun, + red_metrics, + ) in self.params.multichannel_ops.items(): if len( - set(chs).intersection(set(self.img_bgsub.keys()).union(tree_chs)) + set(chs).intersection( + set(self.img_bgsub.keys()).union(tree_chs) + ) ) == len(chs): imgs = [self.get_imgs(ch, traps, tree_chs) for ch in chs] merged = MERGE_FUNS[merge_fun](*imgs) @@ -435,7 +469,9 @@ class Extractor(ProcessABC): """ return self.run(tps=[tp], **kwargs) - def run(self, tree=None, tps: List[int] = None, save=True, **kwargs) -> dict: + def run( + self, tree=None, tps: List[int] = None, save=True, **kwargs + ) -> dict: if tree is None: tree = self.params.tree @@ -523,7 +559,9 @@ class Extractor(ProcessABC): if not hasattr(flds, "__iter__"): flds = [flds] meta_short = {k.split("/")[-1]: v for k, v in self.meta.items()} - return {f: meta_short.get(f, self.default_meta.get(f, None)) for f in flds} + return { + f: meta_short.get(f, self.default_meta.get(f, None)) for f in flds + } ### Helpers -- GitLab