From a645c836f064ad781bdf4e45534456bda6d0ccf0 Mon Sep 17 00:00:00 2001
From: Swainlab <peter.swain@ed.ac.uk>
Date: Thu, 10 Aug 2023 19:45:34 +0100
Subject: [PATCH] tidying of pipeline

---
 src/aliby/pipeline.py | 133 +++++++++++++++++++++---------------------
 1 file changed, 66 insertions(+), 67 deletions(-)

diff --git a/src/aliby/pipeline.py b/src/aliby/pipeline.py
index 7a1213a1..c31d0fc2 100644
--- a/src/aliby/pipeline.py
+++ b/src/aliby/pipeline.py
@@ -350,10 +350,10 @@ class Pipeline(ProcessABC):
                 )
         else:
             # single core
-            results = []
-            for k, v in tqdm(position_ids.items()):
-                r = self.run_one_position((k, v), 1)
-                results.append(r)
+            results = [
+                self.run_one_position((k, v), 1)
+                for k, v in tqdm(position_ids.items())
+            ]
         return results
 
     def apply_filter(self, position_ids: dict, pos_filter: int or str):
@@ -387,14 +387,14 @@ class Pipeline(ProcessABC):
         """Set up and run a pipeline for one position."""
         self._pool_index = index
         name, image_id = name_image_id
-        # session is defined by calling setup_pipeline.
+            # session is defined by calling setup_pipeline.
         # can it be deleted here?
         session = None
         run_kwargs = {"extraction": {"cell_labels": None, "masks": None}}
         try:
-            setup, session = self.setup_pipeline(image_id)
+            pipe, session = self.setup_pipeline(image_id)
             loaded_writers = {
-                name: writer(setup["filename"])
+                name: writer(pipe["filename"])
                 for k in self.step_sequence
                 if k in self.writers
                 for name, writer in self.writers[k]
@@ -406,55 +406,54 @@ class Pipeline(ProcessABC):
 
             # START PIPELINE
             frac_clogged_traps = 0.0
-            min_process_from = min(setup["process_from"].values())
+            min_process_from = min(pipe["process_from"].values())
             with dispatch_image(image_id)(
                 image_id, **self.server_info
             ) as image:
                 # initialise steps
-                if "tiler" not in setup["steps"]:
-                    setup["steps"]["tiler"] = Tiler.from_image(
+                if "tiler" not in pipe["steps"]:
+                    pipe["steps"]["tiler"] = Tiler.from_image(
                         image,
-                        TilerParameters.from_dict(setup["config"]["tiler"]),
+                        TilerParameters.from_dict(pipe["config"]["tiler"]),
                     )
-                if setup["process_from"]["baby"] < setup["tps"]:
+                if pipe["process_from"]["baby"] < pipe["tps"]:
                     session = initialise_tf(2)
-                    setup["steps"]["baby"] = BabyRunner.from_tiler(
-                        BabyParameters.from_dict(setup["config"]["baby"]),
-                        setup["steps"]["tiler"],
+                    pipe["steps"]["baby"] = BabyRunner.from_tiler(
+                        BabyParameters.from_dict(pipe["config"]["baby"]),
+                        pipe["steps"]["tiler"],
                     )
-                    if setup["trackers_state"]:
-                        setup["steps"]["baby"].crawler.tracker_states = setup[
+                    if pipe["trackers_state"]:
+                        pipe["steps"]["baby"].crawler.tracker_states = pipe[
                             "trackers_state"
                         ]
-                # limit extraction parameters using the available channels in tiler
-                if setup["process_from"]["extraction"] < setup["tps"]:
+                if pipe["process_from"]["extraction"] < pipe["tps"]:
                     exparams = ExtractorParameters.from_dict(
-                        setup["config"]["extraction"]
+                        pipe["config"]["extraction"]
                     )
-                    setup["steps"]["extraction"] = Extractor.from_tiler(
+                    pipe["steps"]["extraction"] = Extractor.from_tiler(
                         exparams,
-                        store=setup["filename"],
-                        tiler=setup["steps"]["tiler"],
+                        store=pipe["filename"],
+                        tiler=pipe["steps"]["tiler"],
                     )
-                    # set up progress bar
+                    # initiate progress bar
                     pbar = tqdm(
-                        range(min_process_from, setup["tps"]),
+                        range(min_process_from, pipe["tps"]),
                         desc=image.name,
                         initial=min_process_from,
-                        total=setup["tps"],
+                        total=pipe["tps"],
                     )
                     # run through time points
                     for i in pbar:
                         if (
                             frac_clogged_traps
-                            < setup["earlystop"]["thresh_pos_clogged"]
-                            or i < setup["earlystop"]["min_tp"]
+                            < pipe["earlystop"]["thresh_pos_clogged"]
+                            or i < pipe["earlystop"]["min_tp"]
                         ):
                             # run through steps
                             for step in self.pipeline_steps:
-                                if i >= setup["process_from"][step]:
+                                if i >= pipe["process_from"][step]:
                                     # perform step
-                                    result = setup["steps"][step].run_tp(
+                                    result = pipe["steps"][step].run_tp(
                                         i, **run_kwargs.get(step, {})
                                     )
                                     # write to h5 file using writers
@@ -474,12 +473,12 @@ class Pipeline(ProcessABC):
                                         and i == min_process_from
                                     ):
                                         logging.getLogger("aliby").info(
-                                            f"Found {setup['steps']['tiler'].n_tiles} traps in {image.name}"
+                                            f"Found {pipe['steps']['tiler'].n_tiles} traps in {image.name}"
                                         )
                                     elif step == "baby":
-                                        # write state and pass info to Extractor
+                                        # write state
                                         loaded_writers["state"].write(
-                                            data=setup["steps"][
+                                            data=pipe["steps"][
                                                 step
                                             ].crawler.tracker_states,
                                             overwrite=loaded_writers[
@@ -493,9 +492,9 @@ class Pipeline(ProcessABC):
                                             run_kwargs[step][k] = None
                             # check and report clogging
                             frac_clogged_traps = self.check_earlystop(
-                                setup["filename"],
-                                setup["earlystop"],
-                                setup["steps"]["tiler"].tile_size,
+                                pipe["filename"],
+                                pipe["earlystop"],
+                                pipe["steps"]["tiler"].tile_size,
                             )
                             if frac_clogged_traps > 0.3:
                                 self._log(
@@ -508,15 +507,15 @@ class Pipeline(ProcessABC):
                             self._log(
                                 f"{name}:Stopped early at time {i} with {frac_clogged_traps} clogged traps"
                             )
-                            setup["meta"].add_fields({"end_status": "Clogged"})
+                            pipe["meta"].add_fields({"end_status": "Clogged"})
                             break
-                        setup["meta"].add_fields({"last_processed": i})
+                        pipe["meta"].add_fields({"last_processed": i})
+                    pipe["meta"].add_fields({"end_status": "Success"})
                     # run post-processing
-                    setup["meta"].add_fields({"end_status": "Success"})
                     post_proc_params = PostProcessorParameters.from_dict(
-                        setup["config"]["postprocessing"]
+                        pipe["config"]["postprocessing"]
                     )
-                    PostProcessor(setup["filename"], post_proc_params).run()
+                    PostProcessor(pipe["filename"], post_proc_params).run()
                     self._log("Analysis finished successfully.", "info")
                     return 1
         except Exception as e:
@@ -529,7 +528,7 @@ class Pipeline(ProcessABC):
             traceback.print_exc()
             raise e
         finally:
-            _close_session(session)
+            close_session(session)
 
     def setup_pipeline(
         self, image_id: int
@@ -555,7 +554,7 @@ class Pipeline(ProcessABC):
 
         Returns
         -------
-        setup: dict
+        pipe: dict
             With keys
                 filename: str
                     Path to a h5 file to write to.
@@ -576,13 +575,13 @@ class Pipeline(ProcessABC):
                     States of any trackers from earlier runs.
         session: None
         """
-        setup = {}
+        pipe = {}
         config = self.parameters.to_dict()
         # TODO Alan: Verify if session must be passed
         session = None
-        setup["earlystop"] = config["general"].get("earlystop", None)
-        setup["process_from"] = {k: 0 for k in self.pipeline_steps}
-        setup["steps"] = {}
+        pipe["earlystop"] = config["general"].get("earlystop", None)
+        pipe["process_from"] = {k: 0 for k in self.pipeline_steps}
+        pipe["steps"] = {}
         # check overwriting
         overwrite_id = config["general"].get("overwrite", 0)
         overwrite = {step: True for step in self.step_sequence}
@@ -593,11 +592,11 @@ class Pipeline(ProcessABC):
             }
         # set up
         directory = config["general"]["directory"]
-        setup["trackers_state"] = []
+        pipe["trackers_state"] = []
         with dispatch_image(image_id)(image_id, **self.server_info) as image:
-            setup["filename"] = Path(f"{directory}/{image.name}.h5")
+            pipe["filename"] = Path(f"{directory}/{image.name}.h5")
             # load metadata from h5 file
-            setup["meta"] = MetaData(directory, setup["filename"])
+            pipe["meta"] = MetaData(directory, pipe["filename"])
             from_start = True if np.any(overwrite.values()) else False
             # remove existing h5 file if overwriting
             if (
@@ -606,15 +605,15 @@ class Pipeline(ProcessABC):
                     config["general"].get("overwrite", False)
                     or np.all(list(overwrite.values()))
                 )
-                and setup["filename"].exists()
+                and pipe["filename"].exists()
             ):
-                os.remove(setup["filename"])
+                os.remove(pipe["filename"])
             # if the file exists with no previous segmentation use its tiler
-            if setup["filename"].exists():
+            if pipe["filename"].exists():
                 self._log("Result file exists.", "info")
                 if not overwrite["tiler"]:
-                    setup["steps"]["tiler"] = Tiler.from_h5(
-                        image, setup["filename"]
+                    pipe["steps"]["tiler"] = Tiler.from_h5(
+                        image, pipe["filename"]
                     )
                     try:
                         (
@@ -622,30 +621,30 @@ class Pipeline(ProcessABC):
                             trackers_state,
                             overwrite,
                         ) = self._load_config_from_file(
-                            setup["filename"],
-                            setup["process_from"],
-                            setup["trackers_state"],
+                            pipe["filename"],
+                            pipe["process_from"],
+                            pipe["trackers_state"],
                             overwrite,
                         )
                         # get state array
-                        setup["trackers_state"] = (
+                        pipe["trackers_state"] = (
                             []
                             if overwrite["baby"]
                             else StateReader(
-                                setup["filename"]
+                                pipe["filename"]
                             ).get_formatted_states()
                         )
-                        config["tiler"] = setup["steps"][
+                        config["tiler"] = pipe["steps"][
                             "tiler"
                         ].parameters.to_dict()
                     except Exception:
                         self._log("Overwriting tiling data")
 
             if config["general"]["use_explog"]:
-                setup["meta"].run()
-            setup["config"] = config
+                pipe["meta"].run()
+            pipe["config"] = config
             # add metadata not in the log file
-            setup["meta"].add_fields(
+            pipe["meta"].add_fields(
                 {
                     "aliby_version": version("aliby"),
                     "baby_version": version("aliby-baby"),
@@ -658,8 +657,8 @@ class Pipeline(ProcessABC):
                     ).to_yaml(),
                 }
             )
-            setup["tps"] = min(config["general"]["tps"], image.data.shape[0])
-            return setup, session
+            pipe["tps"] = min(config["general"]["tps"], image.data.shape[0])
+            return pipe, session
 
     @staticmethod
     def check_earlystop(filename: str, es_parameters: dict, tile_size: int):
@@ -705,6 +704,6 @@ class Pipeline(ProcessABC):
         return (traps_above_nthresh & traps_above_athresh).mean()
 
 
-def _close_session(session):
+def close_session(session):
     if session:
         session.close()
-- 
GitLab