From 37cc32fc5de80743c7654a27fdb6d0491f8cf468 Mon Sep 17 00:00:00 2001
From: pswain <peter.swain@ed.ac.uk>
Date: Thu, 21 Dec 2023 19:03:10 +0000
Subject: [PATCH] feature(grouper, signal): added tmax_in_mins_dict and
 tmax_in_mins

---
 src/agora/io/signal.py       | 32 +++++++++++++++++------
 src/postprocessor/grouper.py | 49 ++++++++++++++++++++----------------
 2 files changed, 53 insertions(+), 28 deletions(-)

diff --git a/src/agora/io/signal.py b/src/agora/io/signal.py
index 2f787ff..3a5261f 100644
--- a/src/agora/io/signal.py
+++ b/src/agora/io/signal.py
@@ -52,10 +52,14 @@ class Signal(BridgeH5):
         else:
             raise Exception(f"Invalid type {type(dsets)} to get datasets")
 
-    def get(self, dset_name: t.Union[str, t.Collection]):
+    def get(
+        self,
+        dset_name: t.Union[str, t.Collection],
+        tmax_in_mins: int = None,
+    ):
         """Get Signal after merging and picking."""
         if isinstance(dset_name, str):
-            dsets = self.get_raw(dset_name)
+            dsets = self.get_raw(dset_name, tmax_in_mins=tmax_in_mins)
             if dsets is not None:
                 picked_merged = self.apply_merging_picking(dsets)
                 return self.add_name(picked_merged, dset_name)
@@ -101,11 +105,11 @@ class Signal(BridgeH5):
                 )
                 return 300
 
-    def retained(self, signal, cutoff=0):
+    def retained(self, signal, cutoff=0, tmax_in_mins: int = None):
         """Get retained cells for a Signal or list of Signals."""
         if isinstance(signal, str):
             # get data frame
-            signal = self.get(signal)
+            signal = self.get(signal, tmax_in_mins=tmax_in_mins)
         if isinstance(signal, pd.DataFrame):
             return self.get_retained(signal, cutoff)
         elif isinstance(signal, list):
@@ -199,7 +203,7 @@ class Signal(BridgeH5):
             return merged
 
     @cached_property
-    def p_available(self):
+    def print_available(self):
         """Print data sets available in h5 file."""
         if not hasattr(self, "_available"):
             self._available = []
@@ -218,7 +222,6 @@ class Signal(BridgeH5):
                 f.visititems(self.store_signal_path)
         except Exception as e:
             self._log("Exception when visiting h5: {}".format(e), "exception")
-
         return self._available
 
     def get_merged(self, dataset):
@@ -250,6 +253,7 @@ class Signal(BridgeH5):
         in_minutes: bool = True,
         lineage: bool = False,
         stop_on_lineage_check: bool = True,
+        tmax_in_mins: int = None,
         **kwargs,
     ) -> pd.DataFrame or t.List[pd.DataFrame]:
         """
@@ -265,6 +269,11 @@ class Signal(BridgeH5):
             If True, add mother_label to index.
         run_lineage_check: boolean
             If True, raise exception if a likely error in the lineage assignment.
+        tmax_in_mins: int (optional)
+            Discard data for times > tmax_in_mins. Cells with all NaNs will also
+            be discarded to help with assigning lineages.
+            Setting tmax_in_mins is a way to ignore parts of the experiment with
+            incorrect lineages generated by clogging.
         """
         try:
             if isinstance(dataset, str):
@@ -274,6 +283,10 @@ class Signal(BridgeH5):
                         df = df.sort_index()
                         if in_minutes:
                             df = self.cols_in_mins(df)
+                        # keep times <= tmax_in_mins; drop all-NaN cells
+                        if tmax_in_mins and type(tmax_in_mins) is int:
+                            df = df[df.columns[df.columns <= tmax_in_mins]]
+                            df = df.dropna(how="all")
                         # add mother label to data frame
                         if lineage:
                             mother_label = np.zeros(len(df), dtype=int)
@@ -293,7 +306,12 @@ class Signal(BridgeH5):
                     return df
             elif isinstance(dataset, list):
                 return [
-                    self.get_raw(dset, in_minutes=in_minutes, lineage=lineage)
+                    self.get_raw(
+                        dset,
+                        in_minutes=in_minutes,
+                        lineage=lineage,
+                        tmax_in_mins=tmax_in_mins,
+                    )
                     for dset in dataset
                 ]
         except Exception as e:
diff --git a/src/postprocessor/grouper.py b/src/postprocessor/grouper.py
index 0abc8ad..f964af7 100644
--- a/src/postprocessor/grouper.py
+++ b/src/postprocessor/grouper.py
@@ -33,7 +33,7 @@ class Grouper(ABC):
 
     def load_positions(self) -> None:
         """Load a Signal for each position, or h5 file."""
-        self.positions = {f.name[:-3]: Signal(f) for f in self.files}
+        self.positions = {f.name[:-3]: Signal(f) for f in sorted(self.files)}
 
     @property
     def first_signal(self) -> Signal:
@@ -132,23 +132,25 @@ class Grouper(ABC):
             # check for errors
             errors = [
                 position
-                for record, position in zip(records, self.positions.keys())
+                for record, position in zip(records, good_positions.keys())
                 if record is None
             ]
             records = [record for record in records if record is not None]
             if len(errors):
                 print(f"Warning: Positions ({errors}) contain errors.")
             assert len(records), "All data sets contain errors"
-            # combine into one dataframe
+            # combine into one data frame
             concat = pd.concat(records, axis=0)
             if len(concat.index.names) > 4:
-                # reorder levels in the multi-index dataframe
+                # reorder levels in the multi-index data frame
                 # when mother_label is present
                 concat = concat.reorder_levels(
                     ("group", "position", "trap", "cell_label", "mother_label")
                 )
             concat_sorted = concat.sort_index()
             return concat_sorted
+        else:
+            print("No data found.")
 
     def filter_positions(self, path: str) -> t.Dict[str, Signal]:
         """Filter positions to those whose data is available in the h5 file."""
@@ -267,13 +269,8 @@ class Grouper(ABC):
 
     @property
     def groups(self) -> t.Tuple[str]:
-        """Get groups, sorted alphabetically, as a tuple."""
-        return tuple(sorted(set(self.positions_groups.values())))
-
-    @property
-    def positions(self) -> t.Tuple[str]:
-        """Get positions, sorted alphabetically, as a tuple."""
-        return tuple(sorted(set(self.positions_groups.keys())))
+        """Get groups, sorted alphabetically, as a list."""
+        return list(sorted(set(self.positions_groups.values())))
 
 
 def concat_one_signal(
@@ -282,29 +279,39 @@ def concat_one_signal(
     group: str,
     mode: str = "retained",
     position_name=None,
+    tmax_in_mins_dict=None,
     **kwargs,
 ) -> pd.DataFrame:
-    """
-    Retrieve an individual signal.
-
-    Applies filtering if requested and adjusts indices.
-    """
+    """Retrieve a signal for one position."""
     if position_name is None:
         # name of h5 file
         position_name = position.stem
-    print(f"Finding signal for {position_name}.")
+    if tmax_in_mins_dict and position_name in tmax_in_mins_dict:
+        tmax_in_mins = tmax_in_mins_dict[position_name]
+    else:
+        tmax_in_mins = None
+    if tmax_in_mins:
+        print(
+            f" Loading {path} for {position_name} up to time {tmax_in_mins}."
+        )
+    else:
+        print(f" Loading {path} for {position_name}.")
     if mode == "retained":
-        combined = position.retained(path, **kwargs)
+        combined = position.retained(path, tmax_in_mins=tmax_in_mins, **kwargs)
     elif mode == "raw":
-        combined = position.get_raw(path, **kwargs)
+        combined = position.get_raw(path, tmax_in_mins=tmax_in_mins, **kwargs)
     elif mode == "daughters":
-        combined = position.get_raw(path, lineage=True, **kwargs)
+        combined = position.get_raw(
+            path, lineage=True, tmax_in_mins=tmax_in_mins, **kwargs
+        )
         if combined is not None:
             combined = combined.loc[
                 combined.index.get_level_values("mother_label") > 0
             ]
     elif mode == "mothers":
-        combined = position.get_raw(path, lineage=True, **kwargs)
+        combined = position.get_raw(
+            path, lineage=True, tmax_in_mins=tmax_in_mins, **kwargs
+        )
         if combined is not None:
             combined = combined.loc[
                 combined.index.get_level_values("mother_label") == 0
-- 
GitLab