From 9c3bee4232c1c1751908fef130d23932544aa18e Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Al=C3=A1n=20Mu=C3=B1oz?= <alan.munoz@ed.ac.uk>
Date: Fri, 7 Oct 2022 18:37:34 +0100
Subject: [PATCH] refactor(grouper): to use chainer

---
 src/postprocessor/grouper.py | 207 ++++++++++++++++++++---------------
 1 file changed, 118 insertions(+), 89 deletions(-)

diff --git a/src/postprocessor/grouper.py b/src/postprocessor/grouper.py
index e8a6f3ce..daaec4aa 100644
--- a/src/postprocessor/grouper.py
+++ b/src/postprocessor/grouper.py
@@ -13,7 +13,7 @@ import matplotlib.pyplot as plt
 import numpy as np
 import pandas as pd
 import seaborn as sns
-from agora.io.signal import Signal
+from postprocessor.chainer import Chainer
 from pathos.multiprocessing import Pool
 
 
@@ -28,27 +28,27 @@ class Grouper(ABC):
         assert path.exists(), "Dir does not exist"
         self.files = list(path.glob("*.h5"))
         assert len(self.files), "No valid h5 files in dir"
-        self.load_signals()
+        self.load_chains()
 
-    def load_signals(self) -> None:
-        # Sets self.signals
-        self.signals = {f.name[:-3]: Signal(f) for f in self.files}
+    def load_chains(self) -> None:
+        # Sets self.chainers
+        self.chainers = {f.name[:-3]: Chainer(f) for f in self.files}
 
     @property
-    def fsignal(self) -> Signal:
+    def fsignal(self) -> Chainer:
         # Returns first signal
-        return list(self.signals.values())[0]
+        return list(self.chainers.values())[0]
 
     @property
     def ntimepoints(self) -> int:
-        return max([s.ntimepoints for s in self.signals.values()])
+        return max([s.ntimepoints for s in self.chainers.values()])
 
     @property
     def tintervals(self) -> float:
-        tintervals = set([s.tinterval / 60 for s in self.signals.values()])
+        tintervals = set([s.tinterval / 60 for s in self.chainers.values()])
         assert (
             len(tintervals) == 1
-        ), "Not all signals have the same time interval"
+        ), "Not all chains have the same time interval"
 
         return max(tintervals)
 
@@ -60,7 +60,7 @@ class Grouper(ABC):
     def available_grouped(self) -> None:
         if not hasattr(self, "_available_grouped"):
             self._available_grouped = Counter(
-                [x for s in self.signals.values() for x in s.available]
+                [x for s in self.chainers.values() for x in s.available]
             )
 
         for s, n in self._available_grouped.items():
@@ -68,7 +68,7 @@ class Grouper(ABC):
 
     @property
     def datasets(self) -> None:
-        """Print available datasets in first Signal instance."""
+        """Print available datasets in first Chainer instance."""
         return self.fsignal.datasets
 
     @abstractproperty
@@ -78,13 +78,12 @@ class Grouper(ABC):
     def concat_signal(
         self,
         path: str,
-        reduce_cols: t.Optional[bool] = None,
-        axis: int = 0,
-        mode: str = "retained",
         pool: t.Optional[int] = None,
+        mode: str = "retained",
+        standard: t.Optional[bool] = False,
         **kwargs,
     ):
-        """Concatenate a single signal.
+        """Concatenate a single signal.
 
         Parameters
         ----------
@@ -106,36 +105,58 @@ class Grouper(ABC):
         if path.startswith("/"):
             path = path.strip("/")
 
-        # Check the path is in a given signal
-        sitems = {k: v for k, v in self.signals.items() if path in v.available}
-        nsignals_dif = len(self.signals) - len(sitems)
-        if nsignals_dif:
-            print(
-                f"Grouper:Warning: {nsignals_dif} signals do not contain"
-                f" channel {path}"
-            )
+        sitems = self.filter_path(path)
+        if standard:
+            fn_pos = concat_standard
+        else:
+            fn_pos = concat_signal_ind
+            kwargs["mode"] = mode
 
-        signals = self.pool_function(
+        kymographs = self.pool_function(
             path=path,
-            f=concat_signal_ind,
-            mode=mode,
+            f=fn_pos,
             pool=pool,
-            signals=sitems,
+            chainers=sitems,
             **kwargs,
         )
 
-        errors = [k for s, k in zip(signals, self.signals.keys()) if s is None]
-        signals = [s for s in signals if s is not None]
+        errors = [
+            k
+            for kymo, k in zip(kymographs, sitems.keys())
+            if kymo is None
+        ]
+        kymographs = [kymo for kymo in kymographs if kymo is not None]
         if len(errors):
             print("Warning: Positions contain errors {errors}")
-            assert len(signals), "All datasets contain errors"
-        sorted = pd.concat(signals, axis=axis).sort_index()
-        if reduce_cols:
-            sorted = sorted.apply(np.nanmean, axis=1)
-            spath = path.split("/")
-            sorted.name = "_".join([spath[1], spath[-1]])
 
-        return sorted
+        assert len(kymographs), "All datasets contain errors"
+
+        concat_sorted = (
+            pd.concat(kymographs, axis=0)
+            .reorder_levels(
+                ("group", "position", "trap", "cell_label", "mother_label")
+            )
+            .sort_index()
+        )
+        return concat_sorted
+
+    def filter_path(self, path: str) -> t.Dict[str, Chainer]:
+        # Check the path is in a given signal
+        sitems = {
+            k: v
+            for k, v in self.chainers.items()
+            if path in [*v.available, *v.common_chains]
+        }
+        nchains_dif = len(self.chainers) - len(sitems)
+        if nchains_dif:
+            print(
+                f"Grouper:Warning: {nchains_dif} chains do not contain"
+                f" channel {path}"
+            )
+
+        assert len(sitems), "No valid dataset to use"
+
+        return sitems
 
     @property
     def nmembers(self) -> t.Dict[str, int]:
@@ -144,7 +165,7 @@ class Grouper(ABC):
 
     @property
     def ntraps(self):
-        for pos, s in self.signals.items():
+        for pos, s in self.chainers.items():
             with h5py.File(s.filename, "r") as f:
                 print(pos, f["/trap_info/trap_locations"].shape[0])
 
@@ -152,7 +173,7 @@ class Grouper(ABC):
     def ntraps_by_pos(self) -> t.Dict[str, int]:
         # Return total number of traps grouped
         ntraps = {}
-        for pos, s in self.signals.items():
+        for pos, s in self.chainers.items():
             with h5py.File(s.filename, "r") as f:
                 ntraps[pos] = f["/trap_info/trap_locations"].shape[0]
 
@@ -164,7 +185,7 @@ class Grouper(ABC):
 
     def traplocs(self):
         d = {}
-        for pos, s in self.signals.items():
+        for pos, s in self.chainers.items():
             with h5py.File(s.filename, "r") as f:
                 d[pos] = f["/trap_info/trap_locations"][()]
         return d
@@ -201,15 +222,16 @@ class Grouper(ABC):
         path: str,
         f: t.Callable,
         pool: t.Optional[int] = None,
-        signals: t.Dict[str, Signal] = None,
+        chainers: t.Dict[str, Chainer] = None,
         **kwargs,
     ):
         """
-        Wrapper to add support for threading to process independent signals.
+        Wrapper to add support for threading to process independent chains.
         Particularly useful when aggregating multiple elements.
         """
-        pool = pool or 8
-        signals = signals or self.signals
+        if pool is None:
+            pool = 8
+        chainers = chainers or self.chainers
 
         if pool:
 
@@ -217,25 +239,37 @@ class Grouper(ABC):
                 kymographs = p.map(
                     lambda x: f(
                         path=path,
-                        signal=x[1],
+                        chainer=x[1],
                         group=self.positions_groups[x[0]],
+                        position=x[0],
                         **kwargs,
                     ),
-                    signals.items(),
+                    chainers.items(),
                 )
         else:
             kymographs = [
                 f(
                     path=path,
-                    signal=signal,
+                    chainer=chainer,
                     group=self.positions_groups[name],
+                    position=name,
                     **kwargs,
                 )
-                for name, signal in self.signals.items()
+                for name, chainer in chainers.items()
             ]
 
         return kymographs
 
+    @property
+    def channels(self):
+        return set(
+            [
+                channel
+                for chainer in self.chainers.values()
+                for channel in chainer.channels
+            ]
+        )
+
     @property
     def stages_span(self):
         return self.fsignal.stages_span
@@ -273,33 +307,13 @@ class NameGrouper(Grouper):
     def positions_groups(self) -> t.Dict[str, str]:
         if not hasattr(self, "_positions_groups"):
             self._positions_groups = {}
-            for name in self.signals.keys():
+            for name in self.chainers.keys():
                 self._positions_groups[name] = name[
                     self.criteria[0] : self.criteria[1]
                 ]
 
         return self._positions_groups
 
-    # def aggregate_multisignals(self, paths=None, **kwargs):
-    #     aggregated = pd.concat(
-    #         [
-    #             self.concat_signal(path, reduce_cols=np.nanmean, **kwargs)
-    #             for path in paths
-    #         ],
-    #         axis=1,
-    #     )
-    #     # ph = pd.Series(
-    #     #     [
-    #     #         self.ph_from_group(x[list(aggregated.index.names).index("group")])
-    #     #         for x in aggregated.index
-    #     #     ],
-    #     #     index=aggregated.index,
-    #     #     name="media_pH",
-    #     # )
-    #     # self.aggregated = pd.concat((aggregated, ph), axis=1)
-
-    #     return aggregated
-
 
 class phGrouper(NameGrouper):
     """Grouper for pH calibration experiments where all surveyed media pH
@@ -318,8 +332,8 @@ class phGrouper(NameGrouper):
 
         return float(group_name.replace("_", "."))
 
-    def aggregate_multisignals(self, paths: list) -> pd.DataFrame:
-        """Accumulate multiple signals."""
+    def aggregate_multichains(self, paths: list) -> pd.DataFrame:
+        """Accumulate multiple chains."""
 
         aggregated = pd.concat(
             [
@@ -343,9 +357,26 @@ class phGrouper(NameGrouper):
         return aggregated
 
 
+def concat_standard(
+    path: str,
+    chainer: Chainer,
+    group: str,
+    position: t.Optional[str] = None,
+    **kwargs,
+) -> pd.DataFrame:
+
+    combined = chainer.get(path, **kwargs).copy()
+    combined["position"] = position
+    combined["group"] = group
+    combined.set_index(["group", "position"], inplace=True, append=True)
+    combined.index = combined.index.copy().swaplevel(-2, 0).swaplevel(-1, 1)
+
+    return combined
+
+
 def concat_signal_ind(
     path: str,
-    signal: Signal,
+    chainer: Chainer,
     group: str,
     mode: str = "retained",
     position=None,
@@ -354,23 +385,21 @@ def concat_signal_ind(
     """Core function that handles retrieval of an individual signal, applies
     filtering if requested and adjusts indices."""
     if position is None:
-        position = signal.stem
+        position = chainer.stem
     if mode == "retained":
-        combined = signal.retained(path, **kwargs)
+        combined = chainer.retained(path, **kwargs)
     if mode == "mothers":
         raise (NotImplementedError)
     elif mode == "raw":
-        combined = signal.get_raw(path, **kwargs)
+        combined = chainer.get_raw(path, **kwargs)
     elif mode == "families":
-        combined = signal[path]
+        combined = chainer[path]
     combined["position"] = position
     combined["group"] = group
     combined.set_index(["group", "position"], inplace=True, append=True)
     combined.index = combined.index.swaplevel(-2, 0).swaplevel(-1, 1)
 
     return combined
-    # except:
-    #     return None
 
 
 class MultiGrouper:
@@ -385,7 +414,7 @@ class MultiGrouper:
             self.exp_dirs = [Path(x) for x in source]
         self.groupers = [NameGrouper(d) for d in self.exp_dirs]
         for group in self.groupers:
-            group.load_signals()
+            group.load_chains()
 
     @property
     def available(self) -> None:
@@ -406,7 +435,7 @@ class MultiGrouper:
 
         if not hasattr(self, "_sigtable"):
             raw_mat = [
-                [s.available for s in gpr.signals.values()]
+                [s.available for s in gpr.chainers.values()]
                 for gpr in self.groupers
             ]
             available_grouped = [
@@ -435,7 +464,7 @@ class MultiGrouper:
         return self._sigtable
 
     def sigtable_plot(self) -> None:
-        """Plot number of signals for all available experiments.
+        """Plot number of chains for all available experiments.
 
         Examples
         --------
@@ -452,17 +481,17 @@ class MultiGrouper:
 
     def aggregate_signal(
         self,
-        signals: Union[str, list],
+        path: Union[str, list],
         **kwargs,
     ) -> Union[pd.DataFrame, Dict[str, pd.DataFrame]]:
-        """Aggregate signals from multiple Groupers (and thus experiments)
+        """Aggregate chains from multiple Groupers (and thus experiments)
 
         Parameters
         ----------
-        signals : Union[str, list]
+        path : Union[str, list]
             string or list of strings indicating the signal(s) to fetch.
         **kwargs : keyword arguments to pass to Grouper.concat_signal
-            Customise the filters and format to fetch signals.
+            Customise the filters and format to fetch chains.
 
         Returns
         -------
@@ -473,11 +502,11 @@ class MultiGrouper:
         --------
         FIXME: Add docs.
         """
-        if isinstance(signals, str):
-            signals = [signals]
+        if isinstance(path, str):
+            path = [path]
 
-        sigs = {s: [] for s in signals}
-        for s in signals:
+        sigs = {s: [] for s in path}
+        for s in path:
             for grp in self.groupers:
                 try:
                     sigset = grp.concat_signal(s, **kwargs)
-- 
GitLab