From e855162d286a0088b86e619802fe4bfec2c9549b Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Al=C3=A1n=20Mu=C3=B1oz?= <amuoz@ed.ac.uk>
Date: Fri, 4 Mar 2022 18:35:52 +0000
Subject: [PATCH] threading bugfixes

---
 aliby/pipeline.py | 66 +++++++++++++++++++++++------------------------
 1 file changed, 33 insertions(+), 33 deletions(-)

diff --git a/aliby/pipeline.py b/aliby/pipeline.py
index 20d0e285..af5de9a7 100644
--- a/aliby/pipeline.py
+++ b/aliby/pipeline.py
@@ -120,15 +120,12 @@ class Pipeline(ProcessABC):
     """
 
     iterative_steps = ["tiler", "baby", "extraction"]
-    steps = {}
     step_sequence = [
         "tiler",
         "baby",
         "extraction",
         "postprocessing",
     ]
-    process_from = {k: 0 for k in iterative_steps}
-    ow = {k: 0 for k in step_sequence}
     writer_groups = {
         "tiler": ["trap_info"],
         "baby": ["cell_info"],
@@ -242,11 +239,14 @@ class Pipeline(ProcessABC):
         general_config = config["general"]
         session = None
         earlystop = general_config.get("earlystop", None)
+        steps = {}
+        process_from = {k: 0 for k in self.iterative_steps}
+        ow = {k: 0 for k in self.step_sequence}
 
         # check overwriting
         ow_id = config.get("overwrite", 0)
         if ow_id:
-            self.ow = {
+            ow = {
                 step: self.step_sequence.index(ow_id) < i
                 for i, step in enumerate(self.step_sequence, 1)
             }
@@ -268,7 +268,7 @@ class Pipeline(ProcessABC):
                             f.attrs["parameters"]
                         ).to_dict()
 
-                    for k, v in self.ow.items():
+                    for k, v in ow.items():
                         if v:
                             with h5py.File(filename, "a") as f:
                                 del f[self.writer_groups[k]]
@@ -278,16 +278,16 @@ class Pipeline(ProcessABC):
                         overwrite=True,
                     )
 
-                    if not self.ow["tiler"]:  # Try to load config from file
+                    if not ow["tiler"]:  # Try to load config from file
                         try:
                             with h5py.File(filename, "r") as f:
-                                self.steps["tiler"] = Tiler.from_hdf5(image, filename)
+                                steps["tiler"] = Tiler.from_hdf5(image, filename)
                                 s = Signal(filename)
 
-                                for k, v in self.process_from.items():
-                                    if not self.ow[k]:
+                                for k, v in process_from.items():
+                                    if not ow[k]:
 
-                                        self.process_from[k] = (
+                                        process_from[k] = (
                                             f[self.writer_groups[k][-1]].attrs.get(
                                                 "last_processed",
                                                 max(
@@ -305,17 +305,17 @@ class Pipeline(ProcessABC):
                                             + 1
                                         )
                                 # get state array
-                                if not self.ow["baby"]:
+                                if not ow["baby"]:
                                     trackers_state = StateReader(
                                         filename
                                     ).get_formatted_states()
-                                self.steps["tiler"].n_processed = max(
-                                    0, self.process_from["tiler"] - 1
+                                steps["tiler"].n_processed = max(
+                                    0, process_from["tiler"] - 1
                                 )
                                 # process_from += 1
 
-                            config["tiler"] = self.steps["tiler"].parameters.to_dict()
-                            if not np.any(self.ow.values()):
+                            config["tiler"] = steps["tiler"].parameters.to_dict()
+                            if not np.any(ow.values()):
                                 from_start = False
                                 print(f"Existing file {filename} will be used.")
                         except Exception as e:
@@ -323,7 +323,7 @@ class Pipeline(ProcessABC):
 
                 if from_start:  # New experiment or overwriting
                     if config.get("overwrite", False) is True or np.all(
-                        list(self.ow.values())
+                        list(ow.values())
                     ):
                         if Path(filename).exists():
                             os.remove(filename)
@@ -353,22 +353,22 @@ class Pipeline(ProcessABC):
                 }
 
                 # Initialise Steps
-                if "tiler" not in self.steps:
-                    self.steps["tiler"] = Tiler.from_image(
+                if "tiler" not in steps:
+                    steps["tiler"] = Tiler.from_image(
                         image, TilerParameters.from_dict(config["tiler"])
                     )
 
-                if self.process_from["baby"] < tps:
+                if process_from["baby"] < tps:
                     session = initialise_tf(2)
-                    self.steps["baby"] = BabyRunner.from_tiler(
-                        BabyParameters.from_dict(config["baby"]), self.steps["tiler"]
+                    steps["baby"] = BabyRunner.from_tiler(
+                        BabyParameters.from_dict(config["baby"]), steps["tiler"]
                     )
                     if trackers_state:
-                        self.steps["baby"].crawler.tracker_states = trackers_state
+                        steps["baby"].crawler.tracker_states = trackers_state
 
                 # Limit extraction parameters during run using the available channels in tiler
-                if self.process_from["extraction"] < tps:
-                    av_channels = set((*self.steps["tiler"].channels, "general"))
+                if process_from["extraction"] < tps:
+                    av_channels = set((*steps["tiler"].channels, "general"))
                     config["extraction"]["tree"] = {
                         k: v
                         for k, v in config["extraction"]["tree"].items()
@@ -394,15 +394,15 @@ class Pipeline(ProcessABC):
                     }
 
                     exparams = ExtractorParameters.from_dict(config["extraction"])
-                    self.steps["extraction"] = Extractor.from_tiler(
-                        exparams, store=filename, tiler=self.steps["tiler"]
+                    steps["extraction"] = Extractor.from_tiler(
+                        exparams, store=filename, tiler=steps["tiler"]
                     )
 
                 # RUN
                 # Adjust tps based on how many tps are available on the server
                 frac_clogged_traps = 0
                 # print(f"Processing from {process_from}")
-                min_process_from = min(self.process_from.values())
+                min_process_from = min(process_from.values())
                 pbar = tqdm(
                     range(min_process_from, tps),
                     desc=image.name,
@@ -418,9 +418,9 @@ class Pipeline(ProcessABC):
                     ):
 
                         for step in self.iterative_steps:
-                            if i >= self.process_from[step]:
+                            if i >= process_from[step]:
                                 t = perf_counter()
-                                result = self.steps[step].run_tp(
+                                result = steps[step].run_tp(
                                     i, **run_kwargs.get(step, {})
                                 )
                                 logging.debug(f"Timing:{step}:{perf_counter() - t}s")
@@ -439,7 +439,7 @@ class Pipeline(ProcessABC):
                                 # Step-specific actions
                                 if step == "baby":  # Write state and pass info to ext
                                     loaded_writers["state"].write(
-                                        data=self.steps[step].crawler.tracker_states,
+                                        data=steps[step].crawler.tracker_states,
                                         overwrite=loaded_writers[
                                             "state"
                                         ].datatypes.keys(),
@@ -449,7 +449,7 @@ class Pipeline(ProcessABC):
                                         result["trap"],
                                         result["cell_label"],
                                         result["edgemasks"],
-                                        self.steps["tiler"].n_traps,
+                                        steps["tiler"].n_traps,
                                     )
 
                                 elif (
@@ -460,11 +460,11 @@ class Pipeline(ProcessABC):
 
                             if i == min_process_from:
                                 print(
-                                    f"Found {self.steps['tiler'].n_traps} traps in {image.name}"
+                                    f"Found {steps['tiler'].n_traps} traps in {image.name}"
                                 )
 
                         frac_clogged_traps = self.check_earlystop(
-                            filename, earlystop, self.steps["tiler"].tile_size
+                            filename, earlystop, steps["tiler"].tile_size
                         )
                         logging.debug(f"Quality:Clogged_traps:{frac_clogged_traps}")
 
-- 
GitLab