diff --git a/src/agora/io/signal.py b/src/agora/io/signal.py
index 4c6bf0f7efd1a75ff24821100e475b5df7664041..49f7bd7ae7d912730ee2211d84386f2390531934 100644
--- a/src/agora/io/signal.py
+++ b/src/agora/io/signal.py
@@ -19,14 +19,11 @@ class Signal(BridgeH5):
     """
     Class that fetches data from the hdf5 storage for post-processing
 
-    Signal is works under the assumption that metadata and data are
-    accessible, to perform time-adjustments and apply previously-recorded
-    postprocesses.
+    Signal assumes that the metadata and data are accessible to perform time-adjustments and apply previously-recorded postprocesses.
     """
 
     def __init__(self, file: t.Union[str, PosixPath]):
         super().__init__(file, flag=None)
-
         self.index_names = (
             "experiment",
             "position",
@@ -34,7 +31,6 @@ class Signal(BridgeH5):
             "cell_label",
             "mother_label",
         )
-
         self.candidate_channels = (
             "GFP",
             "GFPFast",
@@ -45,21 +41,14 @@ class Signal(BridgeH5):
             "Cy5",
             "pHluorin405",
         )
-
         equivalences = {
             "m5m": ("extraction/GFP/max/max5px", "extraction/GFP/max/median")
         }
 
     def __getitem__(self, dsets: t.Union[str, t.Collection]):
-
-        if isinstance(
-            dsets, str
-        ):  # or isinstance(Dsets,dsets.endswith("imBackground"):
+        if isinstance(dsets, str):
             df = self.get_raw(dsets)
-
-        # elif isinstance(dsets, str):
-        #     df = self.apply_prepost(dsets)
-
+            return self.add_name(df, dsets)
         elif isinstance(dsets, list):
             is_bgd = [dset.endswith("imBackground") for dset in dsets]
             assert sum(is_bgd) == 0 or sum(is_bgd) == len(
@@ -71,9 +60,6 @@ class Signal(BridgeH5):
             )
         else:
             raise Exception(f"Invalid type {type(dsets)} to get datasets")
-        # return self.cols_in_mins(self.add_name(df, dsets))
-        return self.add_name(df, dsets)
-
     @staticmethod
     def add_name(df, name):
         df.name = name
@@ -96,7 +82,11 @@ class Signal(BridgeH5):
     def tinterval(self) -> int:
         tinterval_location = "time_settings/timeinterval"
         with h5py.File(self.filename, "r") as f:
-            return f.attrs[tinterval_location][0]
+            if tinterval_location in f:
+                return f.attrs[tinterval_location][0]
+            else:
+                print("Using default time interval of 5 minutes")
+                return 5.0
 
     @staticmethod
     def get_retained(df, cutoff):
@@ -109,14 +99,10 @@ class Signal(BridgeH5):
 
     @_first_arg_str_to_df
     def retained(self, signal, cutoff=0.8):
-
-        df = signal
-        # df = self[signal]
-        if isinstance(df, pd.DataFrame):
-            return self.get_retained(df, cutoff)
-
-        elif isinstance(df, list):
-            return [self.get_retained(d, cutoff=cutoff) for d in df]
+        if isinstance(signal, pd.DataFrame):
+            return self.get_retained(signal, cutoff)
+        elif isinstance(signal, list):
+            return [self.get_retained(d, cutoff=cutoff) for d in signal]
 
     @lru_cache(2)
     def lineage(
@@ -132,7 +118,6 @@ class Signal(BridgeH5):
         lineage_location = "postprocessing/lineage"
         if merged:
             lineage_location += "_merged"
-
         with h5py.File(self.filename, "r") as f:
             trap_mo_da = f[lineage_location]
             lineage = np.array(
@@ -175,31 +160,26 @@ class Signal(BridgeH5):
         """
         if isinstance(merges, bool):
             merges: np.ndarray = self.get_merges() if merges else np.array([])
-
-        merged = copy(data)
-
         if merges.any():
             merged = apply_merges(data, merges)
-
+        else:
+            merged = copy(data)
         if isinstance(picks, bool):
             picks = (
                 self.get_picks(names=merged.index.names)
                 if picks
                 else set(merged.index)
            )
-
        with h5py.File(self.filename, "r") as f:
            if "modifiers/picks" in f and picks:
                # missing_cells = [i for i in picks if tuple(i) not in
                # set(merged.index)]
-
                if picks:
                    return merged.loc[
                        set(picks).intersection(
                            [tuple(x) for x in merged.index]
                        )
                    ]
-
                else:
                    if isinstance(merged.index, pd.MultiIndex):
                        empty_lvls = [[] for i in merged.index.names]
@@ -217,10 +197,8 @@ class Signal(BridgeH5):
     def datasets(self):
         if not hasattr(self, "_available"):
             self._available = []
-
             with h5py.File(self.filename, "r") as f:
                 f.visititems(self.store_signal_url)
-
         for sig in self._available:
             print(sig)
 
@@ -238,10 +216,8 @@ class Signal(BridgeH5):
             with h5py.File(self.filename, "r") as f:
                 f.visititems(self.store_signal_url)
-
         except Exception as e:
             print("Error visiting h5: {}".format(e))
-
         return self._available
 
     def get_merged(self, dataset):
@@ -266,6 +242,17 @@ class Signal(BridgeH5):
     def get_raw(
         self, dataset: str, in_minutes: bool = True, lineage: bool = False
     ):
+        """
+        Load a dataset from the h5 file and return it as a dataframe.
+
+        Parameters
+        ----------
+        dataset: str or list of strs
+            The path of the dataset within the h5 file, or a list of such paths.
+        in_minutes: boolean
+            If True, convert the time points in the column headings into minutes.
+        lineage: boolean
+            If True, add a mother_label level to the dataframe's index.
+        """
         try:
             if isinstance(dataset, str):
                 with h5py.File(self.filename, "r") as f:
@@ -274,8 +261,8 @@ class Signal(BridgeH5):
                         df = self.cols_in_mins(df)
             elif isinstance(dataset, list):
                 return [self.get_raw(dset) for dset in dataset]
-
-            if lineage:  # This assumes that df is sorted
+            if lineage:
+                # assumes that df is sorted
                 mother_label = np.zeros(len(df), dtype=int)
                 lineage = self.lineage()
                 a, b = validate_association(
@@ -285,9 +272,7 @@ class Signal(BridgeH5):
                 )
                 mother_label[b] = lineage[a, 1]
                 df = add_index_levels(df, {"mother_label": mother_label})
-
             return df
-
         except Exception as e:
             print(f"Could not fetch dataset {dataset}")
             raise e
@@ -298,7 +283,6 @@ class Signal(BridgeH5):
             merges = f.get("modifiers/merges", np.array([]))
             if not isinstance(merges, np.ndarray):
                 merges = merges[()]
-
         return merges
 
     def get_picks(
@@ -313,34 +297,25 @@ class Signal(BridgeH5):
             picks = set()
             if path in f:
                 picks = set(zip(*[f[path + name] for name in names]))
-
         return picks
 
     def dataset_to_df(self, f: h5py.File, path: str) -> pd.DataFrame:
         """
         Fetch DataFrame from results storage file.
         """
-
         assert path in f, f"{path} not in {f}"
-
         dset = f[path]
-
-        values, index, columns = ([], [], [])
-
+        values, index, columns = [], [], []
         index_names = copy(self.index_names)
         valid_names = [lbl for lbl in index_names if lbl in dset.keys()]
         if valid_names:
-
             index = pd.MultiIndex.from_arrays(
                 [dset[lbl] for lbl in valid_names], names=valid_names
             )
-
-        columns = dset.attrs.get("columns", None)  # dset.attrs["columns"]
+        columns = dset.attrs.get("columns", None)
         if "timepoint" in dset:
             columns = f[path + "/timepoint"][()]
-
         values = f[path + "/values"][()]
-
         return pd.DataFrame(
             values,
             index=index,
@@ -351,24 +326,6 @@ class Signal(BridgeH5):
     def stem(self):
         return self.filename.stem
 
-    # def dataset_to_df(self, f: h5py.File, path: str):
-
-    #     all_indices = self.index_names
-
-    #     valid_indices = {
-    #         k: f[path][k][()] for k in all_indices if k in f[path].keys()
-    #     }
-
-    #     new_index = pd.MultiIndex.from_arrays(
-    #         list(valid_indices.values()), names=valid_indices.keys()
-    #     )
-
-    #     return pd.DataFrame(
-    #         f[path + "/values"][()],
-    #         index=new_index,
-    #         columns=f[path + "/timepoint"][()],
-    #     )
-
     def store_signal_url(
         self, fullname: str, node: t.Union[h5py.Dataset, h5py.Group]
     ):
@@ -413,7 +370,6 @@ class Signal(BridgeH5):
         flowrate_name = "pumpinit/flowrate"
         pumprate_name = "pumprate"
         switchtimes_name = "switchtimes"
-
         main_pump_id = np.concatenate(
             (
                 (np.argmax(self.meta_h5[flowrate_name]),),
@@ -436,7 +392,6 @@ class Signal(BridgeH5):
     def switch_times(self) -> t.List[int]:
         switchtimes_name = "switchtimes"
         switches_minutes = self.meta_h5[switchtimes_name]
-
         return [
             t_min
             for t_min in switches_minutes
diff --git a/src/postprocessor/chainer.py b/src/postprocessor/chainer.py
index b9c43b0c1f0cedc63bd16296b58a32e25c45aa5f..43269ee2b27ff0c713333308b4a7422eca77ee54 100644
--- a/src/postprocessor/chainer.py
+++ b/src/postprocessor/chainer.py
@@ -20,13 +20,15 @@ class Chainer(Signal):
     Instead of reading processes previously applied, it executes them
     when called.
     """
-
-    process_types = ("multisignal", "processes", "reshapers")
-    common_chains = {}
+    #process_types = ("multisignal", "processes", "reshapers")
+    #common_chains = {}
 
     def __init__(self, *args, **kwargs):
         super().__init__(*args, **kwargs)
         for channel in self.candidate_channels:
+            # find first channel in h5 file that corresponds to a candidate_channel
+            # but channel is redefined. why is there a loop over candidate channels?
+            # what about capitals?
             try:
                 channel = [
                     ch for ch in self.channels if re.match("channel", ch)
@@ -34,8 +36,9 @@ class Chainer(Signal):
                 break
             except:
                 pass
-
         try:
+            # what's this?
+            # composite statistic comprising the quotient of two others
            equivalences = {
                "m5m": (
                    f"extraction/{channel}/max/max5px",
@@ -43,13 +46,15 @@ class Chainer(Signal):
                ),
            }
 
+            # function to add bgsub to urls
            def replace_url(url: str, bgsub: str = ""):
-                # return pattern with bgsub
                channel = url.split("/")[1]
                if "bgsub" in bgsub:
+                    # add bgsub to url
                    url = re.sub(channel, f"{channel}_bgsub", url)
                return url
 
+            # add chain with and without bgsub
            self.common_chains = {
                alias + bgsub: lambda **kwargs: self.get(
                    replace_url(denominator, alias + bgsub), **kwargs
@@ -59,7 +64,6 @@ class Chainer(Signal):
                )
                for alias, (denominator, numerator) in equivalences.items()
                for bgsub in ("", "_bgsub")
            }
-
        except:
            pass
@@ -72,20 +76,17 @@ class Chainer(Signal):
        retain: t.Optional[float] = None,
        **kwargs,
    ):
-        if dataset in self.common_chains:  # Produce dataset on the fly
+        if dataset in self.common_chains:
+            # produce dataset on the fly
            data = self.common_chains[dataset](**kwargs)
        else:
            data = self.get_raw(dataset, in_minutes=in_minutes)
            if chain:
                data = self.apply_chain(data, chain, **kwargs)
-
        if retain:
            data = data.loc[data.notna().sum(axis=1) > data.shape[1] * retain]
-
-        if (
-            stages and "stage" not in data.columns.names
-        ):  # Return stages as additional column level
-
+        if (stages and "stage" not in data.columns.names):
+            # return stages as additional column level
            stages_index = [
                x
                for i, (name, span) in enumerate(self.stages_span_tp)
@@ -95,13 +96,13 @@ class Chainer(Signal):
                zip(stages_index, data.columns),
                names=("stage", "time"),
            )
-
        return data
 
    def apply_chain(
        self, input_data: pd.DataFrame, chain: t.Tuple[str, ...], **kwargs
    ):
-        """Apply a series of processes to a dataset.
+        """
+        Apply a series of processes to a dataset.
 
        In a similar fashion to how postprocessing works, Chainer
        allows the consecutive application of processes to a dataset.
        Parameters can be
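Illustrative usage sketch (not part of the diff): a rough example of how Chainer.get might be called, based on the signature visible in the hunks above. The file path is a placeholder, and the m5m alias is only available when the channel lookup in __init__ succeeds.

from postprocessor.chainer import Chainer

chainer = Chainer("experiment/position_001.h5")  # placeholder path

# raw dataset with columns in minutes; retain=0.8 drops cells with more
# than roughly 20% missing time points
gfp = chainer.get("extraction/GFP/max/median", retain=0.8)

# composite statistic computed on the fly via common_chains, if registered
m5m = chainer.get("m5m")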
diff --git a/src/postprocessor/grouper.py b/src/postprocessor/grouper.py
index 167c0b5e205775b825b1950041e741a523b519dc..9f0683edb4dcc1a64b8ac1a8f502fb1c0e316932 100644
--- a/src/postprocessor/grouper.py
+++ b/src/postprocessor/grouper.py
@@ -26,8 +26,6 @@ from postprocessor.chainer import Chainer
 class Grouper(ABC):
     """Base grouper class."""
 
-    files = []
-
     def __init__(self, dir: Union[str, PosixPath]):
         path = Path(dir)
         self.name = path.name
@@ -37,12 +35,11 @@ class Grouper(ABC):
         self.load_chains()
 
     def load_chains(self) -> None:
-        # Sets self.chainers
         self.chainers = {f.name[:-3]: Chainer(f) for f in self.files}
 
     @property
     def fsignal(self) -> Chainer:
-        # Returns first signal
+        # returns first signal
         return list(self.chainers.values())[0]
 
     @property
@@ -110,14 +107,12 @@ class Grouper(ABC):
         """
         if path.startswith("/"):
             path = path.strip("/")
-
         sitems = self.filter_path(path)
         if standard:
             fn_pos = concat_standard
         else:
             fn_pos = concat_signal_ind
             kwargs["mode"] = mode
-
         kymographs = self.pool_function(
             path=path,
             f=fn_pos,
@@ -125,7 +120,6 @@ class Grouper(ABC):
             chainers=sitems,
             **kwargs,
         )
-
         errors = [
             k
             for kymo, k in zip(kymographs, self.chainers.keys())
@@ -134,20 +128,15 @@ class Grouper(ABC):
         kymographs = [kymo for kymo in kymographs if kymo is not None]
         if len(errors):
             print("Warning: Positions contain errors {errors}")
-
         assert len(kymographs), "All datasets contain errors"
-
         concat = pd.concat(kymographs, axis=0)
-
         if (
             len(concat.index.names) > 4
         ):  # Reorder levels when mother_label is present
             concat = concat.reorder_levels(
                 ("group", "position", "trap", "cell_label", "mother_label")
             )
-
         concat_sorted = concat.sort_index()
-
         return concat_sorted
 
     def filter_path(self, path: str) -> t.Dict[str, Chainer]:
@@ -163,11 +152,9 @@ class Grouper(ABC):
                 f"Grouper:Warning: {nchains_dif} chains do not contain"
                 f" channel {path}"
             )
-
         assert len(
             sitems
         ), f"No valid dataset to use. Valid datasets are {self.available}"
-
         return sitems
 
     @property
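Illustrative usage sketch (not part of the diff): how the concatenation step touched above might be driven. Grouper is abstract, so `grouper` here stands in for an instance of a concrete subclass, and the method name is not visible in the hunks, so concat_signal is assumed.

# `grouper` is assumed to be an instance of a concrete Grouper subclass
df = grouper.concat_signal("/extraction/GFP/max/median")  # assumed method name
# per the hunks above: a leading "/" is stripped, positions that error are
# reported and dropped, and when a mother_label level is present the index is
# reordered to (group, position, trap, cell_label, mother_label) and sorted
print(df.head())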