From 9ea738be7d38bae6f94023761ff2137584d9df6a Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Al=C3=A1n=20Mu=C3=B1oz?= <amuoz@ed.ac.uk>
Date: Tue, 1 Feb 2022 13:10:58 +0000
Subject: [PATCH] add continue_interrupted functionality
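
Allow a pipeline run to pick up from an existing HDF5 file instead of
always starting from scratch. The first unprocessed time point is read
from the file's "last_processed" attribute, the Tiler state is restored
with Tiler.from_hdf5, drifts for reprocessed time points are overwritten
in place rather than appended, and an "end_status" field (Success,
Clogged or Untiled) is recorded in the metadata.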

---
 aliby/pipeline.py   | 29 +++++++++++++++++------------
 aliby/tile/tiler.py | 35 +++++++++++++++++++----------------
 2 files changed, 36 insertions(+), 28 deletions(-)
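
Notes: the resume logic keys off the "last_processed" attribute that each
run writes into the HDF5 file. A minimal sketch of the idea, assuming a
file produced by an earlier, possibly interrupted, run; get_resume_point
is an illustrative helper, not part of this patch:

    import h5py

    def get_resume_point(filename: str) -> int:
        """Return the first time point still to process (0 for a fresh file)."""
        with h5py.File(filename, "r") as f:
            # Each run records the last completed time point as a file attribute.
            last_processed = f.attrs.get("last_processed", -1)
        return int(last_processed) + 1

On resumption, Tiler.find_drift may revisit time points whose drift was
already recorded, so the new code overwrites in place instead of
appending. A standalone sketch of that behaviour, where record_drift is a
hypothetical helper mirroring the logic added to Tiler.find_drift:

    import numpy as np

    def record_drift(drifts: list, tp: int, drift: np.ndarray) -> None:
        # Overwrite the drift if this time point was already processed on a
        # previous run; otherwise append, as on a fresh run.
        if 0 < tp < len(drifts):
            drifts[tp] = drift.tolist()
        else:
            drifts.append(drift.tolist())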

diff --git a/aliby/pipeline.py b/aliby/pipeline.py
index 1da0f055..70f52a4d 100644
--- a/aliby/pipeline.py
+++ b/aliby/pipeline.py
@@ -26,7 +26,7 @@ from aliby.baby_client import BabyRunner, BabyParameters
 from aliby.tile.tiler import Tiler, TilerParameters
 from aliby.io.omero import Dataset, Image
 from agora.abc import ParametersABC, ProcessABC
-from agora.io.writer import TilerWriter, BabyWriter
+from agora.io.writer import TilerWriter, BabyWriter, StateWriter
 from agora.io.signal import Signal
 from extraction.core.extractor import Extractor, ExtractorParameters
 from extraction.core.functions.defaults import exparams_from_meta
@@ -86,9 +86,9 @@ class PipelineParameters(ParametersABC):
                 distributed=0,
                 tps=tps,
                 directory=str(directory),
-                strain="",
+                filter="",
                 earlystop=dict(
-                    min_tp=100,
+                    min_tp=50,
                     thresh_pos_clogged=0.3,
                     thresh_trap_clogged=7,
                     ntps_to_eval=5,
@@ -219,7 +219,7 @@ class Pipeline(ProcessABC):
                     try:
                         print(f"Existing file {filename} will be used.")
                         with h5py.File(filename, "r") as f:
-                            tiler = Tiler.from_hdf5(image.data, filename)
+                            tiler = Tiler.from_hdf5(image, filename)
                             s = Signal(filename)
                             process_from = (
                                 f.attrs["last_processed"]
@@ -230,9 +230,8 @@ class Pipeline(ProcessABC):
                             )
                             # get state array
                             state_array = f.get("state_array", 0)
-                        if process_from > 2:
-                            process_from = process_from - 3
                             tiler.n_processed = process_from
+                            process_from += 1
                         from_start = False
                     except:
                         pass
@@ -244,17 +243,17 @@ class Pipeline(ProcessABC):
                         pass
 
                     process_from = 0
+                    meta.run()
+                    meta.add_fields(
+                        {"omero_id": config["general"]["id"], "image_id": image_id}
+                    )
                     try:
-                        meta.run()
-                        meta.add_fields(
-                            {"omero_id,": config["general"]["id"], "image_id": image_id}
-                        )
                         tiler = Tiler.from_image(
                             image, TilerParameters.from_dict(config["tiler"])
                         )
                     except:
                         # Remove and try to run again?
-                        pass
+                        meta.add_fields({"end_status": "Untiled"})
 
                 writer = TilerWriter(filename)
                 session = initialise_tf(2)
@@ -306,7 +305,7 @@ class Pipeline(ProcessABC):
                         trap_info = tiler.run_tp(i)
                         logging.debug(f"Timing:Trap:{perf_counter() - t}s")
                         t = perf_counter()
-                        writer.write(trap_info, overwrite=[])
+                        writer.write(trap_info, overwrite=[], tp=i)
                         logging.debug(f"Timing:Writing-trap:{perf_counter() - t}s")
                         t = perf_counter()
                         seg = runner.run_tp(i)
@@ -318,6 +317,8 @@ class Pipeline(ProcessABC):
                         bwriter.write(seg, overwrite=["mother_assign"])
                         logging.debug(f"Timing:Writing-baby:{perf_counter() - t}s")
 
+                        # TODO add time-skipping for cases where an
+                        # interruption happens after writing segmentation but before extraction
                         t = perf_counter()
                         labels, masks = groupby_traps(
                             seg["trap"],
@@ -334,6 +335,7 @@ class Pipeline(ProcessABC):
                         print(
                             f"Stopping analysis at time {i} with {frac_clogged_traps} clogged traps"
                         )
+                        meta.add_fields({"end_status": "Clogged"})
                         break
 
                     if (
@@ -345,12 +347,15 @@ class Pipeline(ProcessABC):
 
                     meta.add_fields({"last_processed": i})
                 # Run post processing
+
+                meta.add_fields({"end_status": "Success"})
                 post_proc_params = PostProcessorParameters.from_dict(
                     self.parameters.postprocessing
                 ).to_dict()
                 PostProcessor(filename, post_proc_params).run()
 
                 return 1
+
         except Exception as e:  # bug in the trap getting
             logging.exception(
                 f"Caught exception in worker thread (x = {name}):", exc_info=True
diff --git a/aliby/tile/tiler.py b/aliby/tile/tiler.py
index cd5b91a8..f95a99ae 100644
--- a/aliby/tile/tiler.py
+++ b/aliby/tile/tiler.py
@@ -82,14 +82,14 @@ class TrapLocations:
         ]
         self.drifts = drifts
 
-    @classmethod
-    def from_source(cls, fpath: str):
-        with h5py.File(fpath, "r") as f:
-            # TODO read tile size from file metadata
-            drifts = f["trap_info/drifts"][()]
-            tlocs = cls(f["trap_info/trap_locations"][()], tile_size=96, drifts=drifts)
+    # @classmethod
+    # def from_source(cls, fpath: str):
+    #     with h5py.File(fpath, "r") as f:
+    #         # TODO read tile size from file metadata
+    #         drifts = f["trap_info/drifts"][()].tolist()
+    #         tlocs = cls(f["trap_info/trap_locations"][()], tile_size=96, drifts=drifts)
 
-        return tlocs
+    #     return tlocs
 
     @property
     def shape(self):
@@ -106,12 +106,12 @@ class TrapLocations:
 
     def to_dict(self, tp):
         res = dict()
-        if tp == 0:
-            res["trap_locations"] = self.initial_location
-            res["attrs/tile_size"] = self.tile_size
-            res["attrs/max_size"] = self.max_size
+        # Always write locations and attrs (previously only at tp == 0) so interrupted runs can resume
+        res["trap_locations"] = self.initial_location
+        res["attrs/tile_size"] = self.tile_size
+        res["attrs/max_size"] = self.max_size
         res["drifts"] = np.expand_dims(self.drifts[tp], axis=0)
-        # res['processed_timepoints'] = tp
+        # res["processed_timepoints"] = tp
         return res
 
     @classmethod
@@ -119,7 +119,7 @@ class TrapLocations:
         with h5py.File(file, "r") as hfile:
             trap_info = hfile["trap_info"]
             initial_locations = trap_info["trap_locations"][()]
-            drifts = trap_info["drifts"][()]
+            drifts = trap_info["drifts"][()].tolist()
             max_size = trap_info.attrs["max_size"]
             tile_size = trap_info.attrs["tile_size"]
         trap_locs = cls(initial_locations, tile_size, max_size=max_size)
@@ -234,7 +234,10 @@ class Tiler(ProcessABC):
             self.image[prev_tp, self.ref_channel, self.ref_z],
             self.image[tp, self.ref_channel, self.ref_z],
         )
-        self.trap_locs.drifts.append(drift)
+        if 0 < tp < len(self.trap_locs.drifts):
+            self.trap_locs.drifts[tp] = drift.tolist()
+        else:
+            self.trap_locs.drifts.append(drift.tolist())
 
     def get_tp_data(self, tp, c):
         traps = []
@@ -288,13 +291,13 @@ class Tiler(ProcessABC):
         return trap
 
     def run_tp(self, tp):
-        assert tp >= self.n_processed, "Time point already processed"
+        # Reprocessing earlier time points is allowed when resuming an interrupted run
         # TODO check contiguity?
         if self.n_processed == 0:
             self._initialise_traps(self.tile_size)
         self.find_drift(tp)  # Get drift
         # update n_processed
-        self.n_processed += 1
+        self.n_processed = tp + 1
         # Return result for writer
         return self.trap_locs.to_dict(tp)
 
-- 
GitLab