From 897da295f1b7b1efe3013de0d8336016ac621011 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Al=C3=A1n=20Mu=C3=B1oz?= <alan.munoz@ed.ac.uk> Date: Thu, 29 Sep 2022 21:03:56 +0100 Subject: [PATCH] refactor(signal): use decorator to parse input --- src/agora/io/signal.py | 37 +++++++++++++++++++++++++------------ 1 file changed, 25 insertions(+), 12 deletions(-) diff --git a/src/agora/io/signal.py b/src/agora/io/signal.py index e97e2dc3..5c52f2d3 100644 --- a/src/agora/io/signal.py +++ b/src/agora/io/signal.py @@ -9,6 +9,7 @@ import pandas as pd from utils_find_1st import cmp_larger, find_1st from agora.io.bridge import BridgeH5 +from agora.io.decorators import _first_arg_str_to_df class Signal(BridgeH5): @@ -122,18 +123,24 @@ class Signal(BridgeH5): ).T return lineage - def apply_prepost(self, dataset: str, skip_pick: t.Optional[bool] = None): + @_first_arg_str_to_df + def apply_prepost( + self, + data: t.Union[str, pd.DataFrame], + merges: np.ndarray = None, + picks: t.Optional[bool] = None, + ): """ Apply modifier operations (picker, merger) to a given dataframe. """ - merges = self.get_merges() - df = self.get_raw(dataset) - merged = copy(df) + if merges is None: + merges = self.get_merges() + merged = copy(data) if merges.any(): # Split in two dfs, one with rows relevant for merging and one # without them - valid_merges = validate_merges(merges, np.array(list(df.index))) + valid_merges = validate_merges(merges, np.array(list(data.index))) # TODO use the same info from validate_merges to select both valid_indices = [ @@ -141,18 +148,18 @@ class Signal(BridgeH5): for x in (np.unique(valid_merges.reshape(-1, 2), axis=0)) ] merged = self.apply_merge( - df.loc[valid_indices], + data.loc[valid_indices], valid_merges, ) - nonmergeable_ids = df.index.difference(valid_indices) + nonmergeable_ids = data.index.difference(valid_indices) merged = pd.concat( - (merged, df.loc[nonmergeable_ids]), names=df.index.names + (merged, data.loc[nonmergeable_ids]), names=data.index.names ) with h5py.File(self.filename, "r") as f: - if "modifiers/picks" in f and not skip_pick: + if "modifiers/picks" in f and not picks: picks = self.get_picks(names=merged.index.names) # missing_cells = [i for i in picks if tuple(i) not in # set(merged.index)] @@ -163,7 +170,7 @@ class Signal(BridgeH5): [tuple(x) for x in merged.index] ) ] - return merged.loc[picks] + else: if isinstance(merged.index, pd.MultiIndex): empty_lvls = [[] for i in merged.index.names] @@ -236,7 +243,7 @@ class Signal(BridgeH5): return df - def get_raw(self, dataset, in_minutes=True): + def get_raw(self, dataset: str, in_minutes: bool = True): try: if isinstance(dataset, str): with h5py.File(self.filename, "r") as f: @@ -268,7 +275,13 @@ class Signal(BridgeH5): else: return None - def dataset_to_df(self, f: h5py.File, path: str): + def dataset_to_df(self, f: h5py.File, path: str) -> pd.DataFrame: + """ + Fetch DataFrame from results storage file. + """ + + assert path in f, f"{path} not in {f}" + dset = f[path] index_names = copy(self.index_names) -- GitLab