Skip to content
Snippets Groups Projects
Commit 9c3bee42 authored by Alán Muñoz
Browse files

refactor(grouper): to use chainer

parent 07ac0b9f
No related branches found
No related tags found
No related merge requests found
...@@ -13,7 +13,7 @@ import matplotlib.pyplot as plt ...@@ -13,7 +13,7 @@ import matplotlib.pyplot as plt
import numpy as np import numpy as np
import pandas as pd import pandas as pd
import seaborn as sns import seaborn as sns
from agora.io.signal import Signal from postprocessor.chainer import Chainer
from pathos.multiprocessing import Pool from pathos.multiprocessing import Pool
...@@ -28,27 +28,27 @@ class Grouper(ABC): ...@@ -28,27 +28,27 @@ class Grouper(ABC):
assert path.exists(), "Dir does not exist" assert path.exists(), "Dir does not exist"
self.files = list(path.glob("*.h5")) self.files = list(path.glob("*.h5"))
assert len(self.files), "No valid h5 files in dir" assert len(self.files), "No valid h5 files in dir"
self.load_signals() self.load_chains()
def load_chains(self) -> None:
    """Create one Chainer per h5 file and store them in ``self.chainers``.

    Keys are the file names without their extension; values are
    ``Chainer(file)`` instances, in the order of ``self.files``.
    """
    # ``stem`` drops the trailing ".h5" without relying on its length.
    self.chainers = {f.stem: Chainer(f) for f in self.files}
@property
def fsignal(self) -> "Chainer":
    """Return the first Chainer stored in ``self.chainers``.

    "First" follows the insertion order of the ``self.chainers`` dict.

    Raises
    ------
    IndexError
        If ``self.chainers`` is empty.
    """
    try:
        # Avoids materialising every value just to take the first one.
        return next(iter(self.chainers.values()))
    except StopIteration:
        raise IndexError("list index out of range") from None
@property
def ntimepoints(self) -> int:
    """Return the largest ``ntimepoints`` found among all chainers."""
    return max(chain.ntimepoints for chain in self.chainers.values())
@property
def tintervals(self) -> float:
    """Return the time interval shared by all chainers, divided by 60.

    Each chainer's ``tinterval`` is divided by 60 before comparison.

    Raises
    ------
    AssertionError
        If the chainers do not all report the same interval.
    """
    unique_intervals = {
        chain.tinterval / 60 for chain in self.chainers.values()
    }
    # Raised explicitly (rather than via ``assert``) so the check is not
    # stripped when Python runs with -O; the exception type is unchanged.
    if len(unique_intervals) != 1:
        raise AssertionError("Not all chains have the same time interval")
    return max(unique_intervals)
...@@ -60,7 +60,7 @@ class Grouper(ABC): ...@@ -60,7 +60,7 @@ class Grouper(ABC):
def available_grouped(self) -> None: def available_grouped(self) -> None:
if not hasattr(self, "_available_grouped"): if not hasattr(self, "_available_grouped"):
self._available_grouped = Counter( self._available_grouped = Counter(
[x for s in self.signals.values() for x in s.available] [x for s in self.chainers.values() for x in s.available]
) )
for s, n in self._available_grouped.items(): for s, n in self._available_grouped.items():
...@@ -68,7 +68,7 @@ class Grouper(ABC): ...@@ -68,7 +68,7 @@ class Grouper(ABC):
@property
def datasets(self) -> None:
    """Expose the ``datasets`` attribute of the first Chainer instance.

    The value of ``self.fsignal.datasets`` is returned unchanged.
    NOTE(review): the ``-> None`` annotation and the original "Print"
    wording suggest the underlying attribute prints and returns nothing;
    confirm against ``Chainer.datasets``.
    """
    return self.fsignal.datasets
@abstractproperty @abstractproperty
...@@ -78,13 +78,12 @@ class Grouper(ABC): ...@@ -78,13 +78,12 @@ class Grouper(ABC):
def concat_signal( def concat_signal(
self, self,
path: str, path: str,
reduce_cols: t.Optional[bool] = None,
axis: int = 0,
mode: str = "retained",
pool: t.Optional[int] = None, pool: t.Optional[int] = None,
mode: str = "retained",
standard: t.Optional[bool] = False,
**kwargs, **kwargs,
): ):
"""Concatenate a single signal. """Concate
Parameters Parameters
---------- ----------
...@@ -106,36 +105,58 @@ class Grouper(ABC): ...@@ -106,36 +105,58 @@ class Grouper(ABC):
if path.startswith("/"): if path.startswith("/"):
path = path.strip("/") path = path.strip("/")
# Check the path is in a given signal sitems = self.filter_path(path)
sitems = {k: v for k, v in self.signals.items() if path in v.available} if standard:
nsignals_dif = len(self.signals) - len(sitems) fn_pos = concat_standard
if nsignals_dif: else:
print( fn_pos = concat_signal_ind
f"Grouper:Warning: {nsignals_dif} signals do not contain" kwargs["mode"] = mode
f" channel {path}"
)
signals = self.pool_function( kymographs = self.pool_function(
path=path, path=path,
f=concat_signal_ind, f=fn_pos,
mode=mode,
pool=pool, pool=pool,
signals=sitems, chainers=sitems,
**kwargs, **kwargs,
) )
errors = [k for s, k in zip(signals, self.signals.keys()) if s is None] errors = [
signals = [s for s in signals if s is not None] k
for kymo, k in zip(kymographs, self.chainers.keys())
if kymo is None
]
kymographs = [kymo for kymo in kymographs if kymo is not None]
if len(errors): if len(errors):
print("Warning: Positions contain errors {errors}") print("Warning: Positions contain errors {errors}")
assert len(signals), "All datasets contain errors"
sorted = pd.concat(signals, axis=axis).sort_index()
if reduce_cols:
sorted = sorted.apply(np.nanmean, axis=1)
spath = path.split("/")
sorted.name = "_".join([spath[1], spath[-1]])
return sorted assert len(kymographs), "All datasets contain errors"
concat_sorted = (
pd.concat(kymographs, axis=0)
.reorder_levels(
("group", "position", "trap", "cell_label", "mother_label")
)
.sort_index()
)
return concat_sorted
def filter_path(self, path: str) -> t.Dict[str, "Chainer"]:
    """Select the chainers that can provide ``path``.

    A chainer is kept when ``path`` appears in its ``available`` or in its
    ``common_chains``. A warning is printed when some chainers are dropped.

    Parameters
    ----------
    path : str
        Name of the dataset/chain to look for.

    Returns
    -------
    Dict[str, Chainer]
        Subset of ``self.chainers`` (same keys, same objects).

    Raises
    ------
    AssertionError
        If no chainer contains ``path``.
    """
    # Test the two collections directly instead of concatenating them into
    # a new list for every chainer.
    selected = {
        name: chain
        for name, chain in self.chainers.items()
        if path in chain.available or path in chain.common_chains
    }
    n_missing = len(self.chainers) - len(selected)
    if n_missing:
        print(
            f"Grouper:Warning: {n_missing} chains do not contain"
            f" channel {path}"
        )

    # Explicit raise keeps the check alive under -O; same exception type.
    if not selected:
        raise AssertionError("No valid dataset to use")

    return selected
@property @property
def nmembers(self) -> t.Dict[str, int]: def nmembers(self) -> t.Dict[str, int]:
...@@ -144,7 +165,7 @@ class Grouper(ABC): ...@@ -144,7 +165,7 @@ class Grouper(ABC):
@property
def ntraps(self):
    """Print, for each position, the number of rows of its trap locations.

    For every chainer the h5 file at ``chainer.filename`` is opened
    read-only and the first dimension of ``/trap_info/trap_locations`` is
    printed next to the position name. Nothing is returned.
    """
    for position, chain in self.chainers.items():
        with h5py.File(chain.filename, "r") as h5file:
            n_locations = h5file["/trap_info/trap_locations"].shape[0]
            print(position, n_locations)
...@@ -152,7 +173,7 @@ class Grouper(ABC): ...@@ -152,7 +173,7 @@ class Grouper(ABC):
def ntraps_by_pos(self) -> t.Dict[str, int]: def ntraps_by_pos(self) -> t.Dict[str, int]:
# Return total number of traps grouped # Return total number of traps grouped
ntraps = {} ntraps = {}
for pos, s in self.signals.items(): for pos, s in self.chainers.items():
with h5py.File(s.filename, "r") as f: with h5py.File(s.filename, "r") as f:
ntraps[pos] = f["/trap_info/trap_locations"].shape[0] ntraps[pos] = f["/trap_info/trap_locations"].shape[0]
...@@ -164,7 +185,7 @@ class Grouper(ABC): ...@@ -164,7 +185,7 @@ class Grouper(ABC):
def traplocs(self):
    """Read the trap locations of every position.

    Returns
    -------
    dict
        Maps each key of ``self.chainers`` to the full contents of the
        ``/trap_info/trap_locations`` dataset read from that chainer's
        h5 file (``chainer.filename``, opened read-only).
    """
    locations = {}
    for position, chain in self.chainers.items():
        with h5py.File(chain.filename, "r") as h5file:
            # ``[()]`` loads the whole dataset before the file is closed.
            locations[position] = h5file["/trap_info/trap_locations"][()]
    return locations
...@@ -201,15 +222,16 @@ class Grouper(ABC): ...@@ -201,15 +222,16 @@ class Grouper(ABC):
path: str, path: str,
f: t.Callable, f: t.Callable,
pool: t.Optional[int] = None, pool: t.Optional[int] = None,
signals: t.Dict[str, Signal] = None, chainers: t.Dict[str, Chainer] = None,
**kwargs, **kwargs,
): ):
""" """
Wrapper to add support for threading to process independent signals. Wrapper to add support for threading to process independent chains.
Particularly useful when aggregating multiple elements. Particularly useful when aggregating multiple elements.
""" """
pool = pool or 8 if pool is None:
signals = signals or self.signals pool = 8
chainers = chainers or self.chainers
if pool: if pool:
...@@ -217,25 +239,37 @@ class Grouper(ABC): ...@@ -217,25 +239,37 @@ class Grouper(ABC):
kymographs = p.map( kymographs = p.map(
lambda x: f( lambda x: f(
path=path, path=path,
signal=x[1], chainer=x[1],
group=self.positions_groups[x[0]], group=self.positions_groups[x[0]],
position=x[0],
**kwargs, **kwargs,
), ),
signals.items(), chainers.items(),
) )
else: else:
kymographs = [ kymographs = [
f( f(
path=path, path=path,
signal=signal, chainer=chainer,
group=self.positions_groups[name], group=self.positions_groups[name],
position=name,
**kwargs, **kwargs,
) )
for name, signal in self.signals.items() for name, chainer in self.chainers.items()
] ]
return kymographs return kymographs
@property
def channels(self):
    """Return the set of channels found across all chainers.

    The result is the union of every chainer's ``channels``.
    """
    return {
        channel
        for chainer in self.chainers.values()
        for channel in chainer.channels
    }
@property
def stages_span(self):
    """Return ``stages_span`` as reported by the first Chainer.

    Only ``self.fsignal`` is consulted; other chainers are not checked.
    """
    return self.fsignal.stages_span
...@@ -273,33 +307,13 @@ class NameGrouper(Grouper): ...@@ -273,33 +307,13 @@ class NameGrouper(Grouper):
def positions_groups(self) -> t.Dict[str, str]:
    """Map every position name to its group name.

    The group is the slice ``name[criteria[0]:criteria[1]]`` of the
    position name. The mapping is computed once and cached on the instance
    as ``_positions_groups``.
    """
    if not hasattr(self, "_positions_groups"):
        start, stop = self.criteria[0], self.criteria[1]
        # Build the whole mapping before publishing it, so a failure
        # midway does not leave a partially filled cache behind.
        self._positions_groups = {
            name: name[start:stop] for name in self.chainers
        }

    return self._positions_groups
# def aggregate_multisignals(self, paths=None, **kwargs):
# aggregated = pd.concat(
# [
# self.concat_signal(path, reduce_cols=np.nanmean, **kwargs)
# for path in paths
# ],
# axis=1,
# )
# # ph = pd.Series(
# # [
# # self.ph_from_group(x[list(aggregated.index.names).index("group")])
# # for x in aggregated.index
# # ],
# # index=aggregated.index,
# # name="media_pH",
# # )
# # self.aggregated = pd.concat((aggregated, ph), axis=1)
# return aggregated
class phGrouper(NameGrouper): class phGrouper(NameGrouper):
"""Grouper for pH calibration experiments where all surveyed media pH """Grouper for pH calibration experiments where all surveyed media pH
...@@ -318,8 +332,8 @@ class phGrouper(NameGrouper): ...@@ -318,8 +332,8 @@ class phGrouper(NameGrouper):
return float(group_name.replace("_", ".")) return float(group_name.replace("_", "."))
def aggregate_multisignals(self, paths: list) -> pd.DataFrame: def aggregate_multichains(self, paths: list) -> pd.DataFrame:
"""Accumulate multiple signals.""" """Accumulate multiple chains."""
aggregated = pd.concat( aggregated = pd.concat(
[ [
...@@ -343,9 +357,26 @@ class phGrouper(NameGrouper): ...@@ -343,9 +357,26 @@ class phGrouper(NameGrouper):
return aggregated return aggregated
def concat_standard(
    path: str,
    chainer: "Chainer",
    group: str,
    position: t.Optional[str] = None,
    **kwargs,
) -> pd.DataFrame:
    """Fetch one dataset from a chainer and tag it with group and position.

    Parameters
    ----------
    path : str
        Dataset/chain name passed to ``chainer.get``.
    chainer : Chainer
        Source of the data; ``chainer.get(path, **kwargs)`` is called.
    group : str
        Group label stored in the ``group`` index level.
    position : str, optional
        Position label stored in the ``position`` index level. Defaults to
        ``chainer.stem``, mirroring ``concat_signal_ind``.
    **kwargs
        Forwarded to ``chainer.get``.

    Returns
    -------
    pd.DataFrame
        A copy of the fetched frame whose index gains ``group`` and
        ``position`` levels, swapped into the first and second slots.
    """
    if position is None:
        position = chainer.stem

    # Work on a copy so the frame held by the chainer is never modified.
    tagged = chainer.get(path, **kwargs).copy()
    tagged["position"] = position
    tagged["group"] = group
    tagged = tagged.set_index(["group", "position"], append=True)

    # The appended levels sit last; swap them with levels 0 and 1.
    tagged.index = tagged.index.swaplevel(-2, 0).swaplevel(-1, 1)

    return tagged
def concat_signal_ind( def concat_signal_ind(
path: str, path: str,
signal: Signal, chainer: Chainer,
group: str, group: str,
mode: str = "retained", mode: str = "retained",
position=None, position=None,
...@@ -354,23 +385,21 @@ def concat_signal_ind( ...@@ -354,23 +385,21 @@ def concat_signal_ind(
"""Core function that handles retrieval of an individual signal, applies """Core function that handles retrieval of an individual signal, applies
filtering if requested and adjusts indices.""" filtering if requested and adjusts indices."""
if position is None: if position is None:
position = signal.stem position = chainer.stem
if mode == "retained": if mode == "retained":
combined = signal.retained(path, **kwargs) combined = chainer.retained(path, **kwargs)
if mode == "mothers": if mode == "mothers":
raise (NotImplementedError) raise (NotImplementedError)
elif mode == "raw": elif mode == "raw":
combined = signal.get_raw(path, **kwargs) combined = chainer.get_raw(path, **kwargs)
elif mode == "families": elif mode == "families":
combined = signal[path] combined = chainer[path]
combined["position"] = position combined["position"] = position
combined["group"] = group combined["group"] = group
combined.set_index(["group", "position"], inplace=True, append=True) combined.set_index(["group", "position"], inplace=True, append=True)
combined.index = combined.index.swaplevel(-2, 0).swaplevel(-1, 1) combined.index = combined.index.swaplevel(-2, 0).swaplevel(-1, 1)
return combined return combined
# except:
# return None
class MultiGrouper: class MultiGrouper:
...@@ -385,7 +414,7 @@ class MultiGrouper: ...@@ -385,7 +414,7 @@ class MultiGrouper:
self.exp_dirs = [Path(x) for x in source] self.exp_dirs = [Path(x) for x in source]
self.groupers = [NameGrouper(d) for d in self.exp_dirs] self.groupers = [NameGrouper(d) for d in self.exp_dirs]
for group in self.groupers: for group in self.groupers:
group.load_signals() group.load_chains()
@property @property
def available(self) -> None: def available(self) -> None:
...@@ -406,7 +435,7 @@ class MultiGrouper: ...@@ -406,7 +435,7 @@ class MultiGrouper:
if not hasattr(self, "_sigtable"): if not hasattr(self, "_sigtable"):
raw_mat = [ raw_mat = [
[s.available for s in gpr.signals.values()] [s.available for s in gpr.chains.values()]
for gpr in self.groupers for gpr in self.groupers
] ]
available_grouped = [ available_grouped = [
...@@ -435,7 +464,7 @@ class MultiGrouper: ...@@ -435,7 +464,7 @@ class MultiGrouper:
return self._sigtable return self._sigtable
def sigtable_plot(self) -> None: def sigtable_plot(self) -> None:
"""Plot number of signals for all available experiments. """Plot number of chains for all available experiments.
Examples Examples
-------- --------
...@@ -452,17 +481,17 @@ class MultiGrouper: ...@@ -452,17 +481,17 @@ class MultiGrouper:
def aggregate_signal( def aggregate_signal(
self, self,
signals: Union[str, list], path: Union[str, list],
**kwargs, **kwargs,
) -> Union[pd.DataFrame, Dict[str, pd.DataFrame]]: ) -> Union[pd.DataFrame, Dict[str, pd.DataFrame]]:
"""Aggregate signals from multiple Groupers (and thus experiments) """Aggregate chains from multiple Groupers (and thus experiments)
Parameters Parameters
---------- ----------
signals : Union[str, list] chains : Union[str, list]
string or list of strings indicating the signal(s) to fetch. string or list of strings indicating the signal(s) to fetch.
**kwargs : keyword arguments to pass to Grouper.concat_signal **kwargs : keyword arguments to pass to Grouper.concat_signal
Customise the filters and format to fetch signals. Customise the filters and format to fetch chains.
Returns Returns
------- -------
...@@ -473,11 +502,11 @@ class MultiGrouper: ...@@ -473,11 +502,11 @@ class MultiGrouper:
-------- --------
FIXME: Add docs. FIXME: Add docs.
""" """
if isinstance(signals, str): if isinstance(path, str):
signals = [signals] path = [path]
sigs = {s: [] for s in signals} sigs = {s: [] for s in path}
for s in signals: for s in path:
for grp in self.groupers: for grp in self.groupers:
try: try:
sigset = grp.concat_signal(s, **kwargs) sigset = grp.concat_signal(s, **kwargs)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment