Skip to content
Snippets Groups Projects
Commit 897da295 authored by Alán Muñoz's avatar Alán Muñoz
Browse files

refactor(signal): use decorator to parse input

parent 0c6d7660
No related branches found
No related tags found
No related merge requests found
......@@ -9,6 +9,7 @@ import pandas as pd
from utils_find_1st import cmp_larger, find_1st
from agora.io.bridge import BridgeH5
from agora.io.decorators import _first_arg_str_to_df
class Signal(BridgeH5):
......@@ -122,18 +123,24 @@ class Signal(BridgeH5):
).T
return lineage
def apply_prepost(self, dataset: str, skip_pick: t.Optional[bool] = None):
@_first_arg_str_to_df
def apply_prepost(
self,
data: t.Union[str, pd.DataFrame],
merges: np.ndarray = None,
picks: t.Optional[bool] = None,
):
"""
Apply modifier operations (picker, merger) to a given dataframe.
"""
merges = self.get_merges()
df = self.get_raw(dataset)
merged = copy(df)
if merges is None:
merges = self.get_merges()
merged = copy(data)
if merges.any():
# Split in two dfs, one with rows relevant for merging and one
# without them
valid_merges = validate_merges(merges, np.array(list(df.index)))
valid_merges = validate_merges(merges, np.array(list(data.index)))
# TODO use the same info from validate_merges to select both
valid_indices = [
......@@ -141,18 +148,18 @@ class Signal(BridgeH5):
for x in (np.unique(valid_merges.reshape(-1, 2), axis=0))
]
merged = self.apply_merge(
df.loc[valid_indices],
data.loc[valid_indices],
valid_merges,
)
nonmergeable_ids = df.index.difference(valid_indices)
nonmergeable_ids = data.index.difference(valid_indices)
merged = pd.concat(
(merged, df.loc[nonmergeable_ids]), names=df.index.names
(merged, data.loc[nonmergeable_ids]), names=data.index.names
)
with h5py.File(self.filename, "r") as f:
if "modifiers/picks" in f and not skip_pick:
if "modifiers/picks" in f and not picks:
picks = self.get_picks(names=merged.index.names)
# missing_cells = [i for i in picks if tuple(i) not in
# set(merged.index)]
......@@ -163,7 +170,7 @@ class Signal(BridgeH5):
[tuple(x) for x in merged.index]
)
]
return merged.loc[picks]
else:
if isinstance(merged.index, pd.MultiIndex):
empty_lvls = [[] for i in merged.index.names]
......@@ -236,7 +243,7 @@ class Signal(BridgeH5):
return df
def get_raw(self, dataset, in_minutes=True):
def get_raw(self, dataset: str, in_minutes: bool = True):
try:
if isinstance(dataset, str):
with h5py.File(self.filename, "r") as f:
......@@ -268,7 +275,13 @@ class Signal(BridgeH5):
else:
return None
def dataset_to_df(self, f: h5py.File, path: str):
def dataset_to_df(self, f: h5py.File, path: str) -> pd.DataFrame:
"""
Fetch DataFrame from results storage file.
"""
assert path in f, f"{path} not in {f}"
dset = f[path]
index_names = copy(self.index_names)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment