From 3b9c9edadf07ad8be5bc1f797fa37e0603e07d83 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Al=C3=A1n=20Mu=C3=B1oz?= <amuoz@ed.ac.uk>
Date: Thu, 14 Jul 2022 18:14:34 +0100
Subject: [PATCH] refactor(pipeline): split methods and add types

---
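Notes:

Split the monolithic create_pipeline() into smaller, typed pieces:
position setup moves to _setup_pipeline(), resume-point discovery to
_load_config_from_file() and the static legacy_get_last_tp(), and
session cleanup to the module-level _close_session(). It also fixes
BabyParameters.default() receiving the extraction kwargs in
PipelineParameters.default().

As a minimal sketch of the resume logic under the new names (the file
name and step below are hypothetical, for illustration only):

    import h5py

    from aliby.pipeline import Pipeline

    # Resume from the time point after the last one a step wrote,
    # mirroring what _load_config_from_file does for every step.
    with h5py.File("position0.h5", "r") as f:
        process_from = Pipeline.legacy_get_last_tp("baby")(f) + 1
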
 aliby/pipeline.py | 518 ++++++++++++++++++++++++----------------------
 1 file changed, 273 insertions(+), 245 deletions(-)

diff --git a/aliby/pipeline.py b/aliby/pipeline.py
index 7f02704b..be6936ed 100644
--- a/aliby/pipeline.py
+++ b/aliby/pipeline.py
@@ -2,17 +2,14 @@
 Pipeline and chaining elements.
 """
 import logging
 import os
 import re
 import traceback
+import typing as t
 from copy import copy
 from itertools import groupby
 from pathlib import Path, PosixPath
-
-# from p_tqdm import p_map
 from time import perf_counter
-
-# from abc import ABC, abstractmethod
 from typing import Union
 
 import h5py
@@ -54,6 +51,8 @@ logging.basicConfig(
 
 
 class PipelineParameters(ParametersABC):
+    _pool_index = None
+
     def __init__(
         self, general, tiler, baby, extraction, postprocessing, reporting
     ):
@@ -72,7 +71,7 @@ class PipelineParameters(ParametersABC):
         baby={},
         extraction={},
         postprocessing={},
-        reporting={},
+        # reporting={},
     ):
         """
         Load unit test experiment
@@ -137,14 +136,16 @@ class PipelineParameters(ParametersABC):
                 defaults["general"][k] = v
 
         defaults["tiler"] = TilerParameters.default(**tiler).to_dict()
-        defaults["baby"] = BabyParameters.default(**extraction).to_dict()
-        defaults["extraction"] = exparams_from_meta(meta)
+        defaults["baby"] = BabyParameters.default(**baby).to_dict()
+        defaults["extraction"] = (
+            exparams_from_meta(meta)
+            or BabyParameters.default(**extraction).to_dict()
+        )
         defaults["postprocessing"] = PostProcessorParameters.default(
             **postprocessing
         ).to_dict()
         defaults["reporting"] = {}
-        # for k in defaults.keys():
-        #     exec("defaults[k].update(" + k + ")")
+
         return cls(**{k: v for k, v in defaults.items()})
 
     def load_logs(self):
@@ -161,12 +162,15 @@ class Pipeline(ProcessABC):
     """
 
     iterative_steps = ["tiler", "baby", "extraction"]
+
     step_sequence = [
         "tiler",
         "baby",
         "extraction",
         "postprocessing",
     ]
+
+    # HDF5 groups written by each step; used to delete a step's data when overwriting
     writer_groups = {
         "tiler": ["trap_info"],
         "baby": ["cell_info"],
@@ -244,7 +248,6 @@ class Pipeline(ProcessABC):
         pipeline_parameters.general["directory"] = directory
         pipeline_parameters.general["filter"] = Path(fpath).stem
 
-        # Fix legacy postprocessing parameters
         post_process_params = pipeline_parameters.postprocessing.get(
             "parameters", None
         )
@@ -322,13 +325,6 @@ class Pipeline(ProcessABC):
                     # position=0,
                 )
 
-            # results = p_map(
-            #     lambda x: self.create_pipeline(*x),
-            #     [(k, i) for i, k in enumerate(image_ids.items())],
-            #     num_cpus=distributed,
-            #     position=0,
-            # )
-
         else:  # Sequential
             results = []
             for k, v in tqdm(image_ids.items()):
@@ -338,142 +334,42 @@ class Pipeline(ProcessABC):
         return results
 
     def create_pipeline(self, image_id, index=None):
-        config = self.parameters.to_dict()
-        pparams = config
+        self._pool_index = index
         name, image_id = image_id
-        general_config = config["general"]
         session = None
-        earlystop = general_config.get("earlystop", None)
-        process_from = {k: 0 for k in self.iterative_steps}
-        steps = {}
-        ow = {k: 0 for k in self.step_sequence}
-
-        # check overwriting
-        ow_id = general_config.get("overwrite", 0)
-        ow = {step: True for step in self.step_sequence}
-        if ow_id and ow_id != True:
-            ow = {
-                step: self.step_sequence.index(ow_id) < i
-                for i, step in enumerate(self.step_sequence, 1)
-            }
-
+        filename = None
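+        # Per-step kwargs for run_tp; extraction's labels and masks are reset each time point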
+        run_kwargs = {"extraction": {"labels": None, "masks": None}}
         try:
-            # Set up
-            directory = general_config["directory"]
+            (
+                filename,
+                meta,
+                config,
+                process_from,
+                tps,
+                steps,
+                earlystop,
+                session,
+                trackers_state,
+            ) = self._setup_pipeline(image_id)
+
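+            # Instantiate the HDF5 writers registered for each step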
+            loaded_writers = {
+                name: writer(filename)
+                for k in self.step_sequence
+                if k in self.writers
+                for name, writer in self.writers[k]
+            }
+            writer_ow_kwargs = {
+                "state": loaded_writers["state"].datatypes.keys(),
+                "baby": ["mother_assign"],
+            }
 
-            image_wrapper = get_image_class(image_id)
+            # START PIPELINE
+            frac_clogged_traps = 0
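+            # Earliest time point that any step still needs to process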
+            min_process_from = min(process_from.values())
 
-            with image_wrapper(
+            with get_image_class(image_id)(
                 image_id, **self.general.get("server_info", {})
             ) as image:
-                filename = f"{directory}/{image.name}.h5"
-                meta = MetaData(directory, filename)
-
-                from_start = True
-                trackers_state = None
-                if Path(filename).exists():  # If no previous segmentation
-
-                    if not ow["tiler"]:  # Try to load config from file
-                        try:
-                            with h5py.File(filename, "r") as f:
-                                steps["tiler"] = Tiler.from_hdf5(
-                                    image, filename
-                                )
-
-                                legacy_get_last_tp = {  # Function to support seg in ver < 0.24
-                                    "tiler": lambda f: f[
-                                        "trap_info/drifts"
-                                    ].shape[0]
-                                    - 1,
-                                    "baby": lambda f: f["cell_info/timepoint"][
-                                        -1
-                                    ],
-                                    "extraction": lambda f: f[
-                                        "extraction/general/None/area/timepoint"
-                                    ][-1],
-                                }
-                                for k, v in process_from.items():
-                                    if not ow[k]:
-                                        process_from[k] = legacy_get_last_tp[
-                                            k
-                                        ](f)
-                                        # process_from[k] = f[
-                                        #     self.writer_groups[k][-1]
-                                        # ].attrs.get(
-                                        #     "last_processed",
-                                        #     legacy_get_last_tp[k](f),
-                                        # )
-                                        process_from[k] += 1
-                                # get state array
-                                if not ow["baby"]:
-                                    trackers_state = StateReader(
-                                        filename
-                                    ).get_formatted_states()
-
-                            config["tiler"] = steps[
-                                "tiler"
-                            ].parameters.to_dict()
-                            if not np.any(ow.values()):
-                                from_start = False
-                                print(
-                                    f"Existing file {filename} will be used."
-                                )
-                        except Exception as e:
-                            print(e)
-
-                    # Delete datasets to overwrite and update pipeline data
-                    with h5py.File(filename, "r") as f:
-                        pparams = PipelineParameters.from_yaml(
-                            f.attrs["parameters"]
-                        ).to_dict()
-
-                    with h5py.File(filename, "a") as f:
-                        for k, v in ow.items():
-                            if v:
-                                for gname in self.writer_groups[k]:
-                                    if gname in f:
-                                        del f[gname]
-
-                            pparams[k] = config[k]
-                    meta.add_fields(
-                        {
-                            "parameters": PipelineParameters.from_dict(
-                                pparams
-                            ).to_yaml()
-                        },
-                        overwrite=True,
-                    )
-
-                if from_start:  # New experiment or overwriting
-                    if config.get("overwrite", False) == True or np.all(
-                        list(ow.values())
-                    ):
-                        if Path(filename).exists():
-                            os.remove(filename)
-                    meta.run()
-                    meta.add_fields(  # Add non-logfile metadata
-                        {
-                            "omero_id,": config["general"]["id"],
-                            "image_id": image_id,
-                            "parameters": PipelineParameters.from_dict(
-                                pparams
-                            ).to_yaml(),
-                        }
-                    )
-
-                tps = min(general_config["tps"], image.data.shape[0])
-
-                run_kwargs = {"extraction": {"labels": None, "masks": None}}
-                loaded_writers = {
-                    name: writer(filename)
-                    for k in self.step_sequence
-                    if k in self.writers
-                    for name, writer in self.writers[k]
-                }
-                writer_ow_kwargs = {
-                    "state": loaded_writers["state"].datatypes.keys(),
-                    "baby": ["mother_assign"],
-                }
 
                 # Initialise Steps
                 if "tiler" not in steps:
@@ -508,7 +404,7 @@ class Pipeline(ProcessABC):
                         [c + "_bgsub" for c in config["extraction"]["sub_bg"]]
                     )
                     tmp = copy(config["extraction"]["multichannel_ops"])
-                    for op, (input_ch, op_id, red_ext) in tmp.items():
+                    for op, (input_ch, _, _) in tmp.items():
                         if not set(input_ch).issubset(av_channels_wsub):
                             del config["extraction"]["multichannel_ops"][op]
 
@@ -518,112 +414,101 @@ class Pipeline(ProcessABC):
                     steps["extraction"] = Extractor.from_tiler(
                         exparams, store=filename, tiler=steps["tiler"]
                     )
+                    pbar = tqdm(
+                        range(min_process_from, tps),
+                        desc=image.name,
+                        initial=min_process_from,
+                        total=tps,
+                        # position=self._pool_index + 1,
+                    )
+                    for i in pbar:
 
-                # RUN
-                # Adjust tps based on how many tps are available on the server
-                frac_clogged_traps = 0
-                # print(f"Processing from {process_from}")
-                min_process_from = min(process_from.values())
-                pbar = tqdm(
-                    range(min_process_from, tps),
-                    desc=image.name,
-                    initial=min_process_from,
-                    total=tps,
-                    # position=index + 1,
-                )
-                for i in pbar:
-
-                    if (
-                        frac_clogged_traps < earlystop["thresh_pos_clogged"]
-                        or i < earlystop["min_tp"]
-                    ):
-
-                        for step in self.iterative_steps:
-                            if i >= process_from[step]:
-                                t = perf_counter()
-                                result = steps[step].run_tp(
-                                    i, **run_kwargs.get(step, {})
-                                )
-                                logging.debug(
-                                    f"Timing:{step}:{perf_counter() - t}s"
-                                )
-                                if step in loaded_writers:
+                        if (
+                            frac_clogged_traps
+                            < earlystop["thresh_pos_clogged"]
+                            or i < earlystop["min_tp"]
+                        ):
+
+                            for step in self.iterative_steps:
+                                if i >= process_from[step]:
                                     t = perf_counter()
-                                    loaded_writers[step].write(
-                                        data=result,
-                                        overwrite=writer_ow_kwargs.get(
-                                            step, []
-                                        ),
-                                        tp=i,
-                                        meta={"last_processed": i},
+                                    result = steps[step].run_tp(
+                                        i, **run_kwargs.get(step, {})
                                     )
                                     logging.debug(
-                                        f"Timing:Writing-{step}:{perf_counter() - t}s"
+                                        f"Timing:{step}:{perf_counter() - t}s"
                                     )
+                                    if step in loaded_writers:
+                                        t = perf_counter()
+                                        loaded_writers[step].write(
+                                            data=result,
+                                            overwrite=writer_ow_kwargs.get(
+                                                step, []
+                                            ),
+                                            tp=i,
+                                            meta={"last_processed": i},
+                                        )
+                                        logging.debug(
+                                            f"Timing:Writing-{step}:{perf_counter() - t}s"
+                                        )
 
-                                # Step-specific actions
-
-                                if step == "tiler":
-                                    if i == min_process_from:
+                                    # Step-specific actions
+                                    if (
+                                        step == "tiler"
+                                        and i == min_process_from
+                                    ):
                                         print(
                                             f"Found {steps['tiler'].n_traps} traps in {image.name}"
                                         )
-                                elif (
-                                    step == "baby"
-                                ):  # Write state and pass info to ext
-                                    loaded_writers["state"].write(
-                                        data=steps[
-                                            step
-                                        ].crawler.tracker_states,
-                                        overwrite=loaded_writers[
-                                            "state"
-                                        ].datatypes.keys(),
-                                        tp=i,
-                                    )
-                                    labels, masks = groupby_traps(
-                                        result["trap"],
-                                        result["cell_label"],
-                                        result["edgemasks"],
-                                        steps["tiler"].n_traps,
-                                    )
-
-                                elif (
-                                    step == "extraction"
-                                ):  # Remove mask/label after ext
-                                    for k in ["masks", "labels"]:
-                                        run_kwargs[step][k] = None
-
-                        frac_clogged_traps = self.check_earlystop(
-                            filename, earlystop, steps["tiler"].tile_size
-                        )
-                        logging.debug(
-                            f"Quality:Clogged_traps:{frac_clogged_traps}"
-                        )
-
-                        frac = np.round(frac_clogged_traps * 100)
-                        pbar.set_postfix_str(f"{frac} Clogged")
-                    else:  # Stop if more than X% traps are clogged
-                        logging.debug(
-                            f"EarlyStop:{earlystop['thresh_pos_clogged']*100}% traps clogged at time point {i}"
-                        )
-                        print(
-                            f"Stopping analysis at time {i} with {frac_clogged_traps} clogged traps"
-                        )
-                        meta.add_fields({"end_status": "Clogged"})
-                        break
-
-                    meta.add_fields({"last_processed": i})
-                # Run post processing
-
-                meta.add_fields({"end_status": "Success"})
-                post_proc_params = PostProcessorParameters.from_dict(
-                    config["postprocessing"]
-                )
-                PostProcessor(filename, post_proc_params).run()
+                                    elif (
+                                        step == "baby"
+                                    ):  # Write state and pass info to ext
+                                        loaded_writers["state"].write(
+                                            data=steps[
+                                                step
+                                            ].crawler.tracker_states,
+                                            overwrite=loaded_writers[
+                                                "state"
+                                            ].datatypes.keys(),
+                                            tp=i,
+                                        )
+                                    elif (
+                                        step == "extraction"
+                                    ):  # Remove mask/label after ext
+                                        for k in ["masks", "labels"]:
+                                            run_kwargs[step][k] = None
+
+                            frac_clogged_traps = self.check_earlystop(
+                                filename, earlystop, steps["tiler"].tile_size
+                            )
+                            logging.debug(
+                                f"Quality:Clogged_traps:{frac_clogged_traps}"
+                            )
+
+                            frac = np.round(frac_clogged_traps * 100)
+                            pbar.set_postfix_str(f"{frac}% Clogged")
+                        else:  # Stop if more than X% traps are clogged
+                            logging.debug(
+                                f"EarlyStop:{earlystop['thresh_pos_clogged']*100}% traps clogged at time point {i}"
+                            )
+                            print(
+                                f"Stopping analysis at time {i} with {frac_clogged_traps} clogged traps"
+                            )
+                            meta.add_fields({"end_status": "Clogged"})
+                            break
+
+                        meta.add_fields({"last_processed": i})
+
+                    meta.add_fields({"end_status": "Success"})
+                    # Run post-processing
+                    post_proc_params = PostProcessorParameters.from_dict(
+                        config["postprocessing"]
+                    )
+                    PostProcessor(filename, post_proc_params).run()
 
-                # return 1
+                    return 1
 
-        except Exception as e:  # bug in the trap getting
+        except Exception as e:  # error during setup or at runtime
             logging.exception(
                 f"Caught exception in worker thread (x = {name}):",
                 exc_info=True,
@@ -632,11 +517,9 @@ class Pipeline(ProcessABC):
             # This prints the type, value, and stack trace of the
             # current exception being handled.
             traceback.print_exc()
-            print()
             raise e
         finally:
-            if session:
-                session.close()
+            _close_session(session)
 
         # try:
         #     compiler = ExperimentCompiler(None, filepath)
@@ -666,6 +549,146 @@ class Pipeline(ProcessABC):
 
         return (traps_above_nthresh & traps_above_athresh).mean()
 
+    def _load_config_from_file(
+        self,
+        filename: PosixPath,
+        process_from: t.Dict[str, int],
+        trackers_state: t.List,
+        overwrite: t.Dict[str, bool],
+    ):
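+        """Read from an existing HDF5 file the last time point
+        each step processed, so that a partial run can be resumed."""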
+        with h5py.File(filename, "r") as f:
+            for k in process_from.keys():
+                if not overwrite[k]:
+                    process_from[k] = self.legacy_get_last_tp(k)(f)
+                    process_from[k] += 1
+        return process_from, trackers_state, overwrite
+
+    @staticmethod
+    def legacy_get_last_tp(step: str) -> t.Callable:
+        """Get last time-point in different ways depending
+        on which step we are using
+
+        To support segmentation in aliby < v0.24
+        TODO Deprecate and replace with State method
+        """
+        switch_case = {
+            "tiler": lambda f: f["trap_info/drifts"].shape[0] - 1,
+            "baby": lambda f: f["cell_info/timepoint"][-1],
+            "extraction": lambda f: f[
+                "extraction/general/None/area/timepoint"
+            ][-1],
+        }
+        return switch_case[step]
+
+    def _setup_pipeline(self, image_id: int):
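+        """Initialise pipeline components and metadata for one position.
+
+        Returns a tuple of (filename, meta, config, process_from, tps,
+        steps, earlystop, session, trackers_state), as consumed by
+        create_pipeline.
+        """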
+        config = self.parameters.to_dict()
+        pparams = config
+        general_config = config["general"]
+        session = None
+        earlystop = general_config.get("earlystop", None)
+        process_from = {k: 0 for k in self.iterative_steps}
+        steps = {}
+
+        # check overwriting
+        ow_id = general_config.get("overwrite", 0)
+        ow = {step: True for step in self.step_sequence}
+        if ow_id and ow_id is not True:
+            ow = {
+                step: self.step_sequence.index(ow_id) < i
+                for i, step in enumerate(self.step_sequence, 1)
+            }
+
+        # Set up
+        directory = general_config["directory"]
+
+        with get_image_class(image_id)(
+            image_id, **self.general.get("server_info", {})
+        ) as image:
+            filename = Path(f"{directory}/{image.name}.h5")
+            meta = MetaData(directory, filename)
+
+            from_start = any(ow.values())
+            trackers_state = []
+            # A results file exists: resume from it, keeping the tiler unless overwriting
+            if filename.exists():
+                if not ow["tiler"]:
+                    steps["tiler"] = Tiler.from_hdf5(image, filename)
+                    try:
+                        (
+                            process_from,
+                            trackers_state,
+                            ow,
+                        ) = self._load_config_from_file(
+                            filename, process_from, trackers_state, ow
+                        )
+                        # get state array
+                        trackers_state = (
+                            []
+                            if ow["baby"]
+                            else StateReader(filename).get_formatted_states()
+                        )
+
+                        config["tiler"] = steps["tiler"].parameters.to_dict()
+                    except Exception as e:
+                        print(e)
+
+                # Delete datasets to overwrite and update pipeline data
+                # Use existing parameters
+                with h5py.File(filename, "a") as f:
+                    pparams = PipelineParameters.from_yaml(
+                        f.attrs["parameters"]
+                    ).to_dict()
+
+                    for k, v in ow.items():
+                        if v:
+                            for gname in self.writer_groups[k]:
+                                if gname in f:
+                                    del f[gname]
+
+                        pparams[k] = config[k]
+                meta.add_fields(
+                    {
+                        "parameters": PipelineParameters.from_dict(
+                            pparams
+                        ).to_yaml()
+                    },
+                    overwrite=True,
+                )
+
+            if from_start:  # New experiment or overwriting
+                if (
+                    config.get("overwrite", False) is True
+                    or np.all(list(ow.values()))
+                ) and filename.exists():
+                    os.remove(filename)
+
+                meta.run()
+                meta.add_fields(  # Add non-logfile metadata
+                    {
+                        "omero_id,": config["general"]["id"],
+                        "image_id": image_id,
+                        "parameters": PipelineParameters.from_dict(
+                            pparams
+                        ).to_yaml(),
+                    }
+                )
+
+            tps = min(general_config["tps"], image.data.shape[0])
+
+            return (
+                filename,
+                meta,
+                config,
+                process_from,
+                tps,
+                steps,
+                earlystop,
+                session,
+                trackers_state,
+            )
+
 
 def groupby_traps(traps, labels, edgemasks, ntraps):
     # Group data by traps to pass onto extractor without re-reading hdf5
@@ -685,3 +708,8 @@ def groupby_traps(traps, labels, edgemasks, ntraps):
     masks = {i: mask_d.get(i, []) for i in range(ntraps)}
 
     return labels, masks
+
+
+def _close_session(session):
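+    """Close the server session, if one was opened."""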
+    if session:
+        session.close()
-- 
GitLab