Skip to content
Snippets Groups Projects
Commit dd6a27d3 authored by pswain's avatar pswain
Browse files

removed chainer; NameGrouper is now Grouper

parent 6a5afb7c
No related branches found
No related tags found
No related merge requests found
#!/usr/bin/env jupyter
import re
import typing as t
from copy import copy
import pandas as pd
from agora.io.signal import Signal
from agora.utils.kymograph import bidirectional_retainment_filter
from postprocessor.core.abc import get_process
class Chainer(Signal):
    """
    Extend Signal by applying post-processes and allowing composite signals
    that combine basic signals.

    Chainer "chains" multiple processes upon fetching a dataset. Instead of
    reading processes previously applied, Chainer executes them when called.
    """

    # Map alias -> (denominator, numerator) signal paths; subclasses populate
    # this to expose composite ratio signals via common_chains.
    _synonyms: t.Dict[str, t.Tuple[str, str]] = {}

    def __init__(self, *args, **kwargs):
        """Initialise chainer and build the composite-ratio chains."""
        super().__init__(*args, **kwargs)

        def replace_path(path: str, bgsub: str = "") -> str:
            # Add a "_bgsub" suffix to the channel component of a signal path
            # when background subtraction is requested (any truthy bgsub).
            channel = path.split("/")[1]
            suffix = "_bgsub" if bgsub else ""
            return re.sub(channel, f"{channel}{suffix}", path)

        # Add a chain with and without bgsub for composite statistics.
        # Loop variables are bound as lambda default arguments: without this,
        # Python's late-binding closures would make every lambda use the values
        # from the LAST iteration of the comprehension.
        # NOTE(review): replace_path receives `alias + bgsub` (always truthy),
        # so the suffix is always applied here — confirm this is intended.
        self.common_chains = {
            alias
            + bgsub: lambda alias=alias, bgsub=bgsub, denominator=denominator, numerator=numerator, **kwargs: self.get(
                replace_path(denominator, alias + bgsub), **kwargs
            )
            / self.get(replace_path(numerator, alias + bgsub), **kwargs)
            for alias, (denominator, numerator) in self._synonyms.items()
            for bgsub in ("", "_bgsub")
        }

    def get(
        self,
        dataset: str,
        chain: t.Collection[str] = ("standard", "interpolate", "savgol"),
        in_minutes: bool = True,
        stages: bool = True,
        retain: t.Optional[float] = None,
        **kwargs,
    ):
        """
        Load data from an h5 file, applying the requested chain of processes.

        Parameters
        ----------
        dataset : str
            Name of the signal, or of a composite chain in common_chains.
        chain : t.Collection[str]
            Names of processes to apply consecutively to the raw data.
        in_minutes : bool
            If True, convert time points to minutes.
        stages : bool
            If True, add a "stage" level to the columns' MultiIndex.
        retain : t.Optional[float]
            If given, keep only rows retained at this fraction of time points.
        **kwargs : kwargs
            Passed on to the composite chain or to apply_chain.
        """
        if dataset in self.common_chains:
            # get dataset for composite chains
            data = self.common_chains[dataset](**kwargs)
        else:
            # use Signal's get_raw
            data = self.get_raw(dataset, in_minutes=in_minutes, lineage=True)
            if chain:
                data = self.apply_chain(data, chain, **kwargs)
        if retain:
            # keep data only from early time points
            data = self.get_retained(data, retain)
        if stages and "stage" not in data.columns.names:
            # return stages as an additional column level; each stage name is
            # repeated for as many time points as the stage spans
            stages_index = [
                x
                for i, (name, span) in enumerate(self.stages_span_tp)
                for x in (f"{i} { name }",) * span
            ]
            data.columns = pd.MultiIndex.from_tuples(
                zip(stages_index, data.columns),
                names=("stage", "time"),
            )
        return data

    def apply_chain(
        self, input_data: pd.DataFrame, chain: t.Tuple[str, ...], **kwargs
    ):
        """
        Apply a series of processes to a data set.

        Like postprocessing, Chainer consecutively applies processes.
        Parameters can be passed as kwargs.
        Chainer does not support applying the same process multiple times
        with different parameters.

        Parameters
        ----------
        input_data : pd.DataFrame
            Input data to process.
        chain : t.Tuple[str, ...]
            Tuple of strings with the names of the processes.
        **kwargs : kwargs
            Arguments passed on to Process.as_function() method to modify
            the parameters.

        Returns
        -------
        pd.DataFrame
            Result of applying all processes in order; intermediate results
            are stored in self._intermediate_steps.
        """
        # copy so the caller's frame is not mutated by the chain
        result = copy(input_data)
        self._intermediate_steps = []
        for process in chain:
            if process == "standard":
                result = bidirectional_retainment_filter(result)
            else:
                params = kwargs.get(process, {})
                process_cls = get_process(process)
                result = process_cls.as_function(result, **params)
                process_type = process_cls.__module__.split(".")[-2]
                if process_type == "reshapers":
                    if process == "merger":
                        raise NotImplementedError(
                            "merger is not supported by Chainer"
                        )
            self._intermediate_steps.append(result)
        return result
...@@ -13,14 +13,13 @@ import pandas as pd ...@@ -13,14 +13,13 @@ import pandas as pd
from pathos.multiprocessing import Pool from pathos.multiprocessing import Pool
from agora.io.signal import Signal from agora.io.signal import Signal
from postprocessor.chainer import Chainer
class Grouper(ABC): class Grouper(ABC):
"""Base grouper class.""" """Base grouper class."""
def __init__(self, dir: Union[str, Path], name_inds=(0, -4)): def __init__(self, dir: Union[str, Path], name_inds=(0, -4)):
"""Find h5 files and load a chain for each one.""" """Find h5 files and load each one."""
path = Path(dir) path = Path(dir)
assert path.exists(), f"{str(dir)} does not exist" assert path.exists(), f"{str(dir)} does not exist"
self.name = path.name self.name = path.name
...@@ -29,31 +28,33 @@ class Grouper(ABC): ...@@ -29,31 +28,33 @@ class Grouper(ABC):
self.load_positions() self.load_positions()
self.positions_groups = { self.positions_groups = {
name: name[name_inds[0] : name_inds[1]] name: name[name_inds[0] : name_inds[1]]
for name in self.chainers.keys() for name in self.positions.keys()
} }
def load_positions(self) -> None: def load_positions(self) -> None:
"""Load a chain for each position, or h5 file.""" """Load a Signal for each position, or h5 file."""
self.chainers = {f.name[:-3]: Signal(f) for f in self.files} self.positions = {f.name[:-3]: Signal(f) for f in self.files}
@property @property
def first_signal(self) -> Signal: def first_signal(self) -> Signal:
"""Get Signal for the first position.""" """Get Signal for the first position."""
return list(self.chainers.values())[0] return list(self.positions.values())[0]
@property @property
def ntimepoints(self) -> int: def ntimepoints(self) -> int:
"""Find number of time points.""" """Find number of time points."""
return max([s.ntimepoints for s in self.chainers.values()]) return max([s.ntimepoints for s in self.positions.values()])
@property @property
def max_tinterval(self) -> float: def tinterval(self) -> float:
"""Find the maximum time interval for all chains.""" """Find the time interval for all positions."""
tintervals = set([s.tinterval / 60 for s in self.chainers.values()]) tintervals = list(
set([s.tinterval / 60 for s in self.positions.values()])
)
assert ( assert (
len(tintervals) == 1 len(tintervals) == 1
), "Not all chains have the same time interval" ), "Not all positions have the same time interval."
return max(tintervals) return tintervals[0]
@property @property
def available(self) -> t.Collection[str]: def available(self) -> t.Collection[str]:
...@@ -65,7 +66,7 @@ class Grouper(ABC): ...@@ -65,7 +66,7 @@ class Grouper(ABC):
"""Display available signals and the number of positions with these signals.""" """Display available signals and the number of positions with these signals."""
if not hasattr(self, "available_grouped"): if not hasattr(self, "available_grouped"):
self._available_grouped = Counter( self._available_grouped = Counter(
[x for s in self.chainers.values() for x in s.available] [x for s in self.positions.values() for x in s.available]
) )
for s, n in self._available_grouped.items(): for s, n in self._available_grouped.items():
print(f"{s} - {n}") print(f"{s} - {n}")
...@@ -78,7 +79,7 @@ class Grouper(ABC): ...@@ -78,7 +79,7 @@ class Grouper(ABC):
**kwargs, **kwargs,
): ):
""" """
Concatenate data for one signal from different h5 files into a data frame. Concatenate data for one signal from different h5 files.
Each h5 files corresponds to a different position. Each h5 files corresponds to a different position.
...@@ -89,7 +90,6 @@ class Grouper(ABC): ...@@ -89,7 +90,6 @@ class Grouper(ABC):
pool : int pool : int
Number of threads used; if 0 or None only one core is used. Number of threads used; if 0 or None only one core is used.
mode: str mode: str
standard: boolean
**kwargs : key, value pairings **kwargs : key, value pairings
Named arguments for concat_ind_function Named arguments for concat_ind_function
...@@ -101,19 +101,18 @@ class Grouper(ABC): ...@@ -101,19 +101,18 @@ class Grouper(ABC):
path = path.strip("/") path = path.strip("/")
good_positions = self.filter_positions(path) good_positions = self.filter_positions(path)
if good_positions: if good_positions:
fn_pos = concat_one_signal
kwargs["mode"] = mode kwargs["mode"] = mode
records = self.pool_function( records = self.pool_function(
path=path, path=path,
f=fn_pos, f=concat_one_signal,
pool=pool, pool=pool,
chainers=good_positions, positions=good_positions,
**kwargs, **kwargs,
) )
# check for errors # check for errors
errors = [ errors = [
k k
for kymo, k in zip(records, self.chainers.keys()) for kymo, k in zip(records, self.positions.keys())
if kymo is None if kymo is None
] ]
records = [record for record in records if record is not None] records = [record for record in records if record is not None]
...@@ -130,12 +129,12 @@ class Grouper(ABC): ...@@ -130,12 +129,12 @@ class Grouper(ABC):
concat_sorted = concat.sort_index() concat_sorted = concat.sort_index()
return concat_sorted return concat_sorted
def filter_positions(self, path: str) -> t.Dict[str, Chainer]: def filter_positions(self, path: str) -> t.Dict[str, Signal]:
"""Filter chains to those whose data is available in the h5 file.""" """Filter positions to those whose data is available in the h5 file."""
good_positions = { good_positions = {
k: v for k, v in self.chainers.items() if path in [*v.available] k: v for k, v in self.positions.items() if path in [*v.available]
} }
no_positions_dif = len(self.chainers) - len(good_positions) no_positions_dif = len(self.positions) - len(good_positions)
if no_positions_dif: if no_positions_dif:
print( print(
f"Grouper: Warning: {no_positions_dif} positions do not contain" f"Grouper: Warning: {no_positions_dif} positions do not contain"
...@@ -148,52 +147,52 @@ class Grouper(ABC): ...@@ -148,52 +147,52 @@ class Grouper(ABC):
path: str, path: str,
f: t.Callable, f: t.Callable,
pool: t.Optional[int] = None, pool: t.Optional[int] = None,
chainers: t.Dict[str, Chainer] = None, positions: t.Dict[str, Signal] = None,
**kwargs, **kwargs,
): ):
""" """
Enable different threads for independent chains. Enable different threads for different positions.
Particularly useful when aggregating multiple elements. Particularly useful when aggregating multiple elements.
""" """
chainers = chainers or self.chainers positions = positions or self.positions
if pool: if pool:
with Pool(pool) as p: with Pool(pool) as p:
records = p.map( records = p.map(
lambda x: f( lambda x: f(
path=path, path=path,
chainer=x[1], position=x[1],
group=self.positions_groups[x[0]], group=self.positions_groups[x[0]],
position=x[0], position_name=x[0],
**kwargs, **kwargs,
), ),
chainers.items(), positions.items(),
) )
else: else:
records = [ records = [
f( f(
path=path, path=path,
chainer=chainer, position=position,
group=self.positions_groups[name], group=self.positions_groups[name],
position=name, position_name=name,
**kwargs, **kwargs,
) )
for name, chainer in self.chainers.items() for name, position in self.positions.items()
] ]
return records return records
@property @property
def no_tiles(self): def no_tiles(self):
"""Get total number of tiles per position (h5 file).""" """Get total number of tiles per position (h5 file)."""
for pos, s in self.chainers.items(): for pos, s in self.positions.items():
with h5py.File(s.filename, "r") as f: with h5py.File(s.filename, "r") as f:
print(pos, f["/trap_info/trap_locations"].shape[0]) print(pos, f["/trap_info/trap_locations"].shape[0])
@property @property
def tilelocs(self) -> t.Dict[str, np.ndarray]: def tile_locs(self) -> t.Dict[str, np.ndarray]:
"""Get the locations of the tiles for each position as a dictionary.""" """Get the locations of the tiles for each position as a dictionary."""
d = {} d = {}
for pos, s in self.chainers.items(): for pos, s in self.positions.items():
with h5py.File(s.filename, "r") as f: with h5py.File(s.filename, "r") as f:
d[pos] = f["/trap_info/trap_locations"][()] d[pos] = f["/trap_info/trap_locations"][()]
return d return d
...@@ -223,16 +222,11 @@ class Grouper(ABC): ...@@ -223,16 +222,11 @@ class Grouper(ABC):
return set( return set(
[ [
channel channel
for position in self.chainers.values() for position in self.positions.values()
for channel in position.channels for channel in position.channels
] ]
) )
@property
def tinterval(self):
"""Get interval between time points in seconds."""
return self.first_signal.tinterval
@property @property
def no_members(self) -> t.Dict[str, int]: def no_members(self) -> t.Dict[str, int]:
"""Get the number of positions belonging to each group.""" """Get the number of positions belonging to each group."""
...@@ -242,7 +236,7 @@ class Grouper(ABC): ...@@ -242,7 +236,7 @@ class Grouper(ABC):
def no_tiles_by_group(self) -> t.Dict[str, int]: def no_tiles_by_group(self) -> t.Dict[str, int]:
"""Get total number of tiles per group.""" """Get total number of tiles per group."""
no_tiles = {} no_tiles = {}
for pos, s in self.chainers.items(): for pos, s in self.positions.items():
with h5py.File(s.filename, "r") as f: with h5py.File(s.filename, "r") as f:
no_tiles[pos] = f["/trap_info/trap_locations"].shape[0] no_tiles[pos] = f["/trap_info/trap_locations"].shape[0]
no_tiles_by_group = {k: 0 for k in self.groups} no_tiles_by_group = {k: 0 for k in self.groups}
...@@ -261,21 +255,12 @@ class Grouper(ABC): ...@@ -261,21 +255,12 @@ class Grouper(ABC):
return tuple(sorted(set(self.positions_groups.keys()))) return tuple(sorted(set(self.positions_groups.keys())))
class NameGrouper(Grouper):
"""Group a set of positions with a shorter version of the group's name."""
def __init__(self, dir, name_inds=(0, -4)):
"""Define the indices to slice names."""
super().__init__(dir=dir)
self.name_inds = name_inds
def concat_one_signal( def concat_one_signal(
path: str, path: str,
chainer: Chainer, position: Signal,
group: str, group: str,
mode: str = "retained", mode: str = "retained",
position=None, position_name=None,
**kwargs, **kwargs,
) -> pd.DataFrame: ) -> pd.DataFrame:
""" """
...@@ -283,27 +268,26 @@ def concat_one_signal( ...@@ -283,27 +268,26 @@ def concat_one_signal(
Applies filtering if requested and adjusts indices. Applies filtering if requested and adjusts indices.
""" """
if position is None: if position_name is None:
# name of h5 file # name of h5 file
position = chainer.stem position_name = position.stem
if mode == "retained": if mode == "retained":
combined = chainer.retained(path, **kwargs) combined = position.retained(path, **kwargs)
elif mode == "raw": elif mode == "raw":
combined = chainer.get_raw(path, **kwargs) combined = position.get_raw(path, **kwargs)
elif mode == "daughters": elif mode == "daughters":
combined = chainer.get_raw(path, **kwargs) combined = position.get_raw(path, **kwargs)
combined = combined.loc[ combined = combined.loc[
combined.index.get_level_values("mother_label") > 0 combined.index.get_level_values("mother_label") > 0
] ]
elif mode == "families": elif mode == "families":
combined = chainer[path] combined = position[path]
else: else:
raise Exception(f"{mode} not recognised.") raise Exception(f"concat_one_signal: {mode} not recognised.")
if combined is not None: if combined is not None:
# adjust indices # adjust indices
combined["position"] = position combined["position"] = position_name
combined["group"] = group combined["group"] = group
combined.set_index(["group", "position"], inplace=True, append=True) combined.set_index(["group", "position"], inplace=True, append=True)
combined.index = combined.index.swaplevel(-2, 0).swaplevel(-1, 1) combined.index = combined.index.swaplevel(-2, 0).swaplevel(-1, 1)
# should there be an error message if None is returned?
return combined return combined
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment