From 37cc32fc5de80743c7654a27fdb6d0491f8cf468 Mon Sep 17 00:00:00 2001 From: pswain <peter.swain@ed.ac.uk> Date: Thu, 21 Dec 2023 19:03:10 +0000 Subject: [PATCH] feature(grouper, signal): added tmax_in_mins_dict and tmax_in_mins --- src/agora/io/signal.py | 32 +++++++++++++++++------ src/postprocessor/grouper.py | 49 ++++++++++++++++++++---------------- 2 files changed, 53 insertions(+), 28 deletions(-) diff --git a/src/agora/io/signal.py b/src/agora/io/signal.py index 2f787ff..3a5261f 100644 --- a/src/agora/io/signal.py +++ b/src/agora/io/signal.py @@ -52,10 +52,14 @@ class Signal(BridgeH5): else: raise Exception(f"Invalid type {type(dsets)} to get datasets") - def get(self, dset_name: t.Union[str, t.Collection]): + def get( + self, + dset_name: t.Union[str, t.Collection], + tmax_in_mins: int = None, + ): """Get Signal after merging and picking.""" if isinstance(dset_name, str): - dsets = self.get_raw(dset_name) + dsets = self.get_raw(dset_name, tmax_in_mins) if dsets is not None: picked_merged = self.apply_merging_picking(dsets) return self.add_name(picked_merged, dset_name) @@ -101,11 +105,11 @@ class Signal(BridgeH5): ) return 300 - def retained(self, signal, cutoff=0): + def retained(self, signal, cutoff=0, tmax_in_mins: int = None): """Get retained cells for a Signal or list of Signals.""" if isinstance(signal, str): # get data frame - signal = self.get(signal) + signal = self.get(signal, tmax_in_mins=tmax_in_mins) if isinstance(signal, pd.DataFrame): return self.get_retained(signal, cutoff) elif isinstance(signal, list): @@ -199,7 +203,7 @@ class Signal(BridgeH5): return merged @cached_property - def p_available(self): + def print_available(self): """Print data sets available in h5 file.""" if not hasattr(self, "_available"): self._available = [] @@ -218,7 +222,6 @@ class Signal(BridgeH5): f.visititems(self.store_signal_path) except Exception as e: self._log("Exception when visiting h5: {}".format(e), "exception") - return self._available def get_merged(self, dataset): @@ -250,6 +253,7 @@ class Signal(BridgeH5): in_minutes: bool = True, lineage: bool = False, stop_on_lineage_check: bool = True, + tmax_in_mins: int = None, **kwargs, ) -> pd.DataFrame or t.List[pd.DataFrame]: """ @@ -265,6 +269,11 @@ class Signal(BridgeH5): If True, add mother_label to index. run_lineage_check: boolean If True, raise exception if a likely error in the lineage assignment. + tmax_in_mins: int (optional) + Discard data for times > tmax_in_mins. Cells with all NaNs will also + be discarded to help with assigning lineages. + Setting tmax_in_mins is a way to ignore parts of the experiment with + incorrect lineages generated by clogging. """ try: if isinstance(dataset, str): @@ -274,6 +283,10 @@ class Signal(BridgeH5): df = df.sort_index() if in_minutes: df = self.cols_in_mins(df) + # limit data by time and discard NaNs + if tmax_in_mins and type(tmax_in_mins) is int: + df = df[df.columns[df.columns < tmax_in_mins]] + df = df.dropna(how="all") # add mother label to data frame if lineage: mother_label = np.zeros(len(df), dtype=int) @@ -293,7 +306,12 @@ class Signal(BridgeH5): return df elif isinstance(dataset, list): return [ - self.get_raw(dset, in_minutes=in_minutes, lineage=lineage) + self.get_raw( + dset, + in_minutes=in_minutes, + lineage=lineage, + tmax_in_mins=tmax_in_mins, + ) for dset in dataset ] except Exception as e: diff --git a/src/postprocessor/grouper.py b/src/postprocessor/grouper.py index 0abc8ad..f964af7 100644 --- a/src/postprocessor/grouper.py +++ b/src/postprocessor/grouper.py @@ -33,7 +33,7 @@ class Grouper(ABC): def load_positions(self) -> None: """Load a Signal for each position, or h5 file.""" - self.positions = {f.name[:-3]: Signal(f) for f in self.files} + self.positions = {f.name[:-3]: Signal(f) for f in sorted(self.files)} @property def first_signal(self) -> Signal: @@ -132,23 +132,25 @@ class Grouper(ABC): # check for errors errors = [ position - for record, position in zip(records, self.positions.keys()) + for record, position in zip(records, good_positions.keys()) if record is None ] records = [record for record in records if record is not None] if len(errors): print(f"Warning: Positions ({errors}) contain errors.") assert len(records), "All data sets contain errors" - # combine into one dataframe + # combine into one data frame concat = pd.concat(records, axis=0) if len(concat.index.names) > 4: - # reorder levels in the multi-index dataframe + # reorder levels in the multi-index data frame # when mother_label is present concat = concat.reorder_levels( ("group", "position", "trap", "cell_label", "mother_label") ) concat_sorted = concat.sort_index() return concat_sorted + else: + print("No data found.") def filter_positions(self, path: str) -> t.Dict[str, Signal]: """Filter positions to those whose data is available in the h5 file.""" @@ -267,13 +269,8 @@ class Grouper(ABC): @property def groups(self) -> t.Tuple[str]: - """Get groups, sorted alphabetically, as a tuple.""" - return tuple(sorted(set(self.positions_groups.values()))) - - @property - def positions(self) -> t.Tuple[str]: - """Get positions, sorted alphabetically, as a tuple.""" - return tuple(sorted(set(self.positions_groups.keys()))) + """Get groups, sorted alphabetically, as a list.""" + return list(sorted(set(self.positions_groups.values()))) def concat_one_signal( @@ -282,29 +279,39 @@ def concat_one_signal( group: str, mode: str = "retained", position_name=None, + tmax_in_mins_dict=None, **kwargs, ) -> pd.DataFrame: - """ - Retrieve an individual signal. - - Applies filtering if requested and adjusts indices. - """ + """Retrieve a signal for one position.""" + if tmax_in_mins_dict and position_name in tmax_in_mins_dict: + tmax_in_mins = tmax_in_mins_dict[position_name] + else: + tmax_in_mins = None if position_name is None: # name of h5 file position_name = position.stem - print(f"Finding signal for {position_name}.") + if tmax_in_mins: + print( + f" Loading {path} for {position_name} up to time {tmax_in_mins}." + ) + else: + print(f" Loading {path} for {position_name}.") if mode == "retained": - combined = position.retained(path, **kwargs) + combined = position.retained(path, tmax_in_mins=tmax_in_mins, **kwargs) elif mode == "raw": - combined = position.get_raw(path, **kwargs) + combined = position.get_raw(path, tmax_in_mins=tmax_in_mins, **kwargs) elif mode == "daughters": - combined = position.get_raw(path, lineage=True, **kwargs) + combined = position.get_raw( + path, lineage=True, tmax_in_mins=tmax_in_mins, **kwargs + ) if combined is not None: combined = combined.loc[ combined.index.get_level_values("mother_label") > 0 ] elif mode == "mothers": - combined = position.get_raw(path, lineage=True, **kwargs) + combined = position.get_raw( + path, lineage=True, tmax_in_mins=tmax_in_mins, **kwargs + ) if combined is not None: combined = combined.loc[ combined.index.get_level_values("mother_label") == 0 -- GitLab