From 27b9c3f6eb085dec0add5f553f46468f1af6b622 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Al=C3=A1n=20Mu=C3=B1oz?= <alan.munoz@ed.ac.uk>
Date: Mon, 17 Oct 2022 15:12:57 +0100
Subject: [PATCH] fix(signal): fringe cases transition at edges

---
 src/agora/io/signal.py       | 28 ++++++++++++++++++++++++----
 src/postprocessor/chainer.py |  2 +-
 2 files changed, 25 insertions(+), 5 deletions(-)

diff --git a/src/agora/io/signal.py b/src/agora/io/signal.py
index 489a5256..ed49ffd9 100644
--- a/src/agora/io/signal.py
+++ b/src/agora/io/signal.py
@@ -381,31 +381,51 @@ class Signal(BridgeH5):
         """
         flowrate_name = "pumpinit/flowrate"
         pumprate_name = "pumprate"
+        switchtimes_name = "switchtimes"
+
         main_pump_id = np.concatenate(
             (
                 (np.argmax(self.meta_h5[flowrate_name]),),
                 np.argmax(self.meta_h5[pumprate_name], axis=0),
             )
         )
+        if not self.meta_h5[switchtimes_name][0]:  # Cover for t0 switches
+            main_pump_id = main_pump_id[1:]
         return [self.meta_h5["pumpinit/contents"][i] for i in main_pump_id]
 
     @property
     def nstages(self) -> int:
-        switchtimes_name = "switchtimes"
-        return self.meta_h5[switchtimes_name] + 1
+        return len(self.switch_times) + 1
 
     @property
     def max_span(self) -> int:
         return int(self.tinterval * self.ntps / 60)
 
+    @property
+    def switch_frames(self) -> t.List[int]:
+        switchtimes_name = "switchtimes"
+        switch_frames = self.meta_h5[switchtimes_name]
+
+        return [
+            tp for tp in switch_frames if tp and tp < self.max_span
+        ]  # Cover for t0 switches
+
     @property
     def stages_span(self) -> t.Tuple[t.Tuple[str, int], ...]:
         # Return consecutive stages and their corresponding number of time-points
-        switchtimes_name = "switchtimes"
-        transition_tps = (0, *self.meta_h5[switchtimes_name])
+        transition_tps = (0, *self.switch_frames, self.max_span)
         spans = [
             end - start
             for start, end in zip(transition_tps[:-1], transition_tps[1:])
             if end <= self.max_span
         ]
         return tuple((stage, ntps) for stage, ntps in zip(self.stages, spans))
+
+    @property
+    def stages_span_tp(self) -> t.Tuple[t.Tuple[str, int], ...]:
+        return tuple(
+            [
+                (name, t_min // self.tinterval * 60)
+                for name, t_min in self.stages_span
+            ]
+        )
diff --git a/src/postprocessor/chainer.py b/src/postprocessor/chainer.py
index 25dabb1d..d174a4b2 100644
--- a/src/postprocessor/chainer.py
+++ b/src/postprocessor/chainer.py
@@ -80,7 +80,7 @@ class Chainer(Signal):
 
             stages_index = [
                 x
-                for i, (name, span) in enumerate(self.stages_span)
+                for i, (name, span) in enumerate(self.stages_span_tp)
                 for x in (f"{i} { name }",) * span
             ]
             data.columns = pd.MultiIndex.from_tuples(
-- 
GitLab