Skip to content
Snippets Groups Projects
Commit 37cc32fc authored by pswain's avatar pswain
Browse files

feature(grouper, signal): added tmax_in_mins_dict and tmax_in_mins

parent 20d85934
No related branches found
No related tags found
No related merge requests found
...@@ -52,10 +52,14 @@ class Signal(BridgeH5): ...@@ -52,10 +52,14 @@ class Signal(BridgeH5):
else: else:
raise Exception(f"Invalid type {type(dsets)} to get datasets") raise Exception(f"Invalid type {type(dsets)} to get datasets")
def get(self, dset_name: t.Union[str, t.Collection]): def get(
self,
dset_name: t.Union[str, t.Collection],
tmax_in_mins: int = None,
):
"""Get Signal after merging and picking.""" """Get Signal after merging and picking."""
if isinstance(dset_name, str): if isinstance(dset_name, str):
dsets = self.get_raw(dset_name) dsets = self.get_raw(dset_name, tmax_in_mins)
if dsets is not None: if dsets is not None:
picked_merged = self.apply_merging_picking(dsets) picked_merged = self.apply_merging_picking(dsets)
return self.add_name(picked_merged, dset_name) return self.add_name(picked_merged, dset_name)
...@@ -101,11 +105,11 @@ class Signal(BridgeH5): ...@@ -101,11 +105,11 @@ class Signal(BridgeH5):
) )
return 300 return 300
def retained(self, signal, cutoff=0): def retained(self, signal, cutoff=0, tmax_in_mins: int = None):
"""Get retained cells for a Signal or list of Signals.""" """Get retained cells for a Signal or list of Signals."""
if isinstance(signal, str): if isinstance(signal, str):
# get data frame # get data frame
signal = self.get(signal) signal = self.get(signal, tmax_in_mins=tmax_in_mins)
if isinstance(signal, pd.DataFrame): if isinstance(signal, pd.DataFrame):
return self.get_retained(signal, cutoff) return self.get_retained(signal, cutoff)
elif isinstance(signal, list): elif isinstance(signal, list):
...@@ -199,7 +203,7 @@ class Signal(BridgeH5): ...@@ -199,7 +203,7 @@ class Signal(BridgeH5):
return merged return merged
@cached_property @cached_property
def p_available(self): def print_available(self):
"""Print data sets available in h5 file.""" """Print data sets available in h5 file."""
if not hasattr(self, "_available"): if not hasattr(self, "_available"):
self._available = [] self._available = []
...@@ -218,7 +222,6 @@ class Signal(BridgeH5): ...@@ -218,7 +222,6 @@ class Signal(BridgeH5):
f.visititems(self.store_signal_path) f.visititems(self.store_signal_path)
except Exception as e: except Exception as e:
self._log("Exception when visiting h5: {}".format(e), "exception") self._log("Exception when visiting h5: {}".format(e), "exception")
return self._available return self._available
def get_merged(self, dataset): def get_merged(self, dataset):
...@@ -250,6 +253,7 @@ class Signal(BridgeH5): ...@@ -250,6 +253,7 @@ class Signal(BridgeH5):
in_minutes: bool = True, in_minutes: bool = True,
lineage: bool = False, lineage: bool = False,
stop_on_lineage_check: bool = True, stop_on_lineage_check: bool = True,
tmax_in_mins: int = None,
**kwargs, **kwargs,
) -> pd.DataFrame or t.List[pd.DataFrame]: ) -> pd.DataFrame or t.List[pd.DataFrame]:
""" """
...@@ -265,6 +269,11 @@ class Signal(BridgeH5): ...@@ -265,6 +269,11 @@ class Signal(BridgeH5):
If True, add mother_label to index. If True, add mother_label to index.
run_lineage_check: boolean run_lineage_check: boolean
If True, raise exception if a likely error in the lineage assignment. If True, raise exception if a likely error in the lineage assignment.
tmax_in_mins: int (optional)
Discard data for times >= tmax_in_mins. Cells whose remaining values are
all NaN are also discarded to help with assigning lineages.
Setting tmax_in_mins is a way to ignore parts of the experiment with
incorrect lineages generated by clogging.
""" """
try: try:
if isinstance(dataset, str): if isinstance(dataset, str):
...@@ -274,6 +283,10 @@ class Signal(BridgeH5): ...@@ -274,6 +283,10 @@ class Signal(BridgeH5):
df = df.sort_index() df = df.sort_index()
if in_minutes: if in_minutes:
df = self.cols_in_mins(df) df = self.cols_in_mins(df)
# limit data to times before tmax_in_mins and discard rows that are all NaN
if tmax_in_mins and type(tmax_in_mins) is int:
df = df[df.columns[df.columns < tmax_in_mins]]
df = df.dropna(how="all")
# add mother label to data frame # add mother label to data frame
if lineage: if lineage:
mother_label = np.zeros(len(df), dtype=int) mother_label = np.zeros(len(df), dtype=int)
...@@ -293,7 +306,12 @@ class Signal(BridgeH5): ...@@ -293,7 +306,12 @@ class Signal(BridgeH5):
return df return df
elif isinstance(dataset, list): elif isinstance(dataset, list):
return [ return [
self.get_raw(dset, in_minutes=in_minutes, lineage=lineage) self.get_raw(
dset,
in_minutes=in_minutes,
lineage=lineage,
tmax_in_mins=tmax_in_mins,
)
for dset in dataset for dset in dataset
] ]
except Exception as e: except Exception as e:
......
...@@ -33,7 +33,7 @@ class Grouper(ABC): ...@@ -33,7 +33,7 @@ class Grouper(ABC):
def load_positions(self) -> None: def load_positions(self) -> None:
"""Load a Signal for each position, or h5 file.""" """Load a Signal for each position, or h5 file."""
self.positions = {f.name[:-3]: Signal(f) for f in self.files} self.positions = {f.name[:-3]: Signal(f) for f in sorted(self.files)}
@property @property
def first_signal(self) -> Signal: def first_signal(self) -> Signal:
...@@ -132,23 +132,25 @@ class Grouper(ABC): ...@@ -132,23 +132,25 @@ class Grouper(ABC):
# check for errors # check for errors
errors = [ errors = [
position position
for record, position in zip(records, self.positions.keys()) for record, position in zip(records, good_positions.keys())
if record is None if record is None
] ]
records = [record for record in records if record is not None] records = [record for record in records if record is not None]
if len(errors): if len(errors):
print(f"Warning: Positions ({errors}) contain errors.") print(f"Warning: Positions ({errors}) contain errors.")
assert len(records), "All data sets contain errors" assert len(records), "All data sets contain errors"
# combine into one dataframe # combine into one data frame
concat = pd.concat(records, axis=0) concat = pd.concat(records, axis=0)
if len(concat.index.names) > 4: if len(concat.index.names) > 4:
# reorder levels in the multi-index dataframe # reorder levels in the multi-index data frame
# when mother_label is present # when mother_label is present
concat = concat.reorder_levels( concat = concat.reorder_levels(
("group", "position", "trap", "cell_label", "mother_label") ("group", "position", "trap", "cell_label", "mother_label")
) )
concat_sorted = concat.sort_index() concat_sorted = concat.sort_index()
return concat_sorted return concat_sorted
else:
print("No data found.")
def filter_positions(self, path: str) -> t.Dict[str, Signal]: def filter_positions(self, path: str) -> t.Dict[str, Signal]:
"""Filter positions to those whose data is available in the h5 file.""" """Filter positions to those whose data is available in the h5 file."""
...@@ -267,13 +269,8 @@ class Grouper(ABC): ...@@ -267,13 +269,8 @@ class Grouper(ABC):
@property @property
def groups(self) -> t.Tuple[str]: def groups(self) -> t.Tuple[str]:
"""Get groups, sorted alphabetically, as a tuple.""" """Get groups, sorted alphabetically, as a list."""
return tuple(sorted(set(self.positions_groups.values()))) return list(sorted(set(self.positions_groups.values())))
@property
def positions(self) -> t.Tuple[str]:
"""Get positions, sorted alphabetically, as a tuple."""
return tuple(sorted(set(self.positions_groups.keys())))
def concat_one_signal( def concat_one_signal(
...@@ -282,29 +279,39 @@ def concat_one_signal( ...@@ -282,29 +279,39 @@ def concat_one_signal(
group: str, group: str,
mode: str = "retained", mode: str = "retained",
position_name=None, position_name=None,
tmax_in_mins_dict=None,
**kwargs, **kwargs,
) -> pd.DataFrame: ) -> pd.DataFrame:
""" """Retrieve a signal for one position."""
Retrieve an individual signal. if tmax_in_mins_dict and position_name in tmax_in_mins_dict:
tmax_in_mins = tmax_in_mins_dict[position_name]
Applies filtering if requested and adjusts indices. else:
""" tmax_in_mins = None
if position_name is None: if position_name is None:
# name of h5 file # name of h5 file
position_name = position.stem position_name = position.stem
print(f"Finding signal for {position_name}.") if tmax_in_mins:
print(
f" Loading {path} for {position_name} up to time {tmax_in_mins}."
)
else:
print(f" Loading {path} for {position_name}.")
if mode == "retained": if mode == "retained":
combined = position.retained(path, **kwargs) combined = position.retained(path, tmax_in_mins=tmax_in_mins, **kwargs)
elif mode == "raw": elif mode == "raw":
combined = position.get_raw(path, **kwargs) combined = position.get_raw(path, tmax_in_mins=tmax_in_mins, **kwargs)
elif mode == "daughters": elif mode == "daughters":
combined = position.get_raw(path, lineage=True, **kwargs) combined = position.get_raw(
path, lineage=True, tmax_in_mins=tmax_in_mins, **kwargs
)
if combined is not None: if combined is not None:
combined = combined.loc[ combined = combined.loc[
combined.index.get_level_values("mother_label") > 0 combined.index.get_level_values("mother_label") > 0
] ]
elif mode == "mothers": elif mode == "mothers":
combined = position.get_raw(path, lineage=True, **kwargs) combined = position.get_raw(
path, lineage=True, tmax_in_mins=tmax_in_mins, **kwargs
)
if combined is not None: if combined is not None:
combined = combined.loc[ combined = combined.loc[
combined.index.get_level_values("mother_label") == 0 combined.index.get_level_values("mother_label") == 0
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment