Commit e0a48510 authored by pswain's avatar pswain

minor code re-arrangements

parent ba5266d4
@@ -19,14 +19,11 @@ class Signal(BridgeH5):
"""
Class that fetches data from the hdf5 storage for post-processing.
Signal works under the assumption that metadata and data are
accessible, to perform time-adjustments and apply previously-recorded
postprocesses.
Signal assumes that the metadata and data are accessible to perform time-adjustments and apply previously-recorded postprocesses.
"""
def __init__(self, file: t.Union[str, PosixPath]):
super().__init__(file, flag=None)
self.index_names = (
"experiment",
"position",
@@ -34,7 +31,6 @@
"cell_label",
"mother_label",
)
self.candidate_channels = (
"GFP",
"GFPFast",
@@ -45,21 +41,14 @@
"Cy5",
"pHluorin405",
)
equivalences = {
"m5m": ("extraction/GFP/max/max5px", "extraction/GFP/max/median")
}
def __getitem__(self, dsets: t.Union[str, t.Collection]):
if isinstance(
dsets, str
): # or isinstance(Dsets,dsets.endswith("imBackground"):
if isinstance(dsets, str):
df = self.get_raw(dsets)
# elif isinstance(dsets, str):
# df = self.apply_prepost(dsets)
return self.add_name(df, dsets)
elif isinstance(dsets, list):
is_bgd = [dset.endswith("imBackground") for dset in dsets]
assert sum(is_bgd) == 0 or sum(is_bgd) == len(
@@ -71,9 +60,6 @@ class Signal(BridgeH5):
else:
raise Exception(f"Invalid type {type(dsets)} to get datasets")
# return self.cols_in_mins(self.add_name(df, dsets))
return self.add_name(df, dsets)
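A minimal usage sketch of this indexing interface; the import path, file name, and extraction paths below are assumptions for illustration, not taken from this commit:

from agora.io.signal import Signal  # assumed import path

signal = Signal("example_position.h5")
gfp = signal["extraction/GFP/max/median"]  # one DataFrame, named after its path
pair = signal[["extraction/GFP/max/median", "extraction/GFP/max/max5px"]]  # list of DataFrames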
@staticmethod
def add_name(df, name):
df.name = name
@@ -96,7 +82,11 @@ class Signal(BridgeH5):
def tinterval(self) -> int:
tinterval_location = "time_settings/timeinterval"
with h5py.File(self.filename, "r") as f:
return f.attrs[tinterval_location][0]
if tinterval_location in f:
return f.attrs[tinterval_location][0]
else:
print("Using default time interval of 5 minutes")
return 5.0
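For reference, the same fallback behaviour can be reproduced directly with h5py; the file name is a placeholder:

import h5py

with h5py.File("example_position.h5", "r") as f:
    # attrs.get returns the stored interval when present, otherwise the 5-minute default
    tinterval = f.attrs.get("time_settings/timeinterval", [5.0])[0]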
@staticmethod
def get_retained(df, cutoff):
@@ -109,14 +99,10 @@ class Signal(BridgeH5):
@_first_arg_str_to_df
def retained(self, signal, cutoff=0.8):
df = signal
# df = self[signal]
if isinstance(df, pd.DataFrame):
return self.get_retained(df, cutoff)
elif isinstance(df, list):
return [self.get_retained(d, cutoff=cutoff) for d in df]
if isinstance(signal, pd.DataFrame):
return self.get_retained(signal, cutoff)
elif isinstance(signal, list):
return [self.get_retained(d, cutoff=cutoff) for d in signal]
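get_retained itself is defined above the fold of this hunk; a plausible pandas-only sketch of its behaviour, assuming it keeps rows observed in at least a cutoff fraction of time points (not necessarily the project's exact implementation):

import pandas as pd

def get_retained_sketch(df: pd.DataFrame, cutoff: float = 0.8) -> pd.DataFrame:
    # keep cells (rows) that have data in at least `cutoff` of the time points (columns)
    return df.loc[df.notna().mean(axis=1) >= cutoff]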
@lru_cache(2)
def lineage(
@@ -132,7 +118,6 @@ class Signal(BridgeH5):
lineage_location = "postprocessing/lineage"
if merged:
lineage_location += "_merged"
with h5py.File(self.filename, "r") as f:
trap_mo_da = f[lineage_location]
lineage = np.array(
@@ -175,31 +160,26 @@ class Signal(BridgeH5):
"""
if isinstance(merges, bool):
merges: np.ndarray = self.get_merges() if merges else np.array([])
merged = copy(data)
if merges.any():
merged = apply_merges(data, merges)
else:
merged = copy(data)
if isinstance(picks, bool):
picks = (
self.get_picks(names=merged.index.names)
if picks
else set(merged.index)
)
with h5py.File(self.filename, "r") as f:
if "modifiers/picks" in f and picks:
# missing_cells = [i for i in picks if tuple(i) not in
# set(merged.index)]
if picks:
return merged.loc[
set(picks).intersection(
[tuple(x) for x in merged.index]
)
]
else:
if isinstance(merged.index, pd.MultiIndex):
empty_lvls = [[] for i in merged.index.names]
@@ -217,10 +197,8 @@ class Signal(BridgeH5):
def datasets(self):
if not hasattr(self, "_available"):
self._available = []
with h5py.File(self.filename, "r") as f:
f.visititems(self.store_signal_url)
for sig in self._available:
print(sig)
@@ -238,10 +216,8 @@ class Signal(BridgeH5):
with h5py.File(self.filename, "r") as f:
f.visititems(self.store_signal_url)
except Exception as e:
print("Error visiting h5: {}".format(e))
return self._available
def get_merged(self, dataset):
@@ -266,6 +242,17 @@ class Signal(BridgeH5):
def get_raw(
self, dataset: str, in_minutes: bool = True, lineage: bool = False
):
"""
Load a dataset from the h5 file and return it as a dataframe.

Parameters
----------
dataset: str or list of str
The name of the dataset in the h5 file, or a list of dataset names.
in_minutes: boolean
If True, convert the column headings from time points into minutes.
lineage: boolean
If True, add a mother_label level to the dataframe's index.
"""
try:
if isinstance(dataset, str):
with h5py.File(self.filename, "r") as f:
@@ -274,8 +261,8 @@ class Signal(BridgeH5):
df = self.cols_in_mins(df)
elif isinstance(dataset, list):
return [self.get_raw(dset) for dset in dataset]
if lineage: # This assumes that df is sorted
if lineage:
# assumes that df is sorted
mother_label = np.zeros(len(df), dtype=int)
lineage = self.lineage()
a, b = validate_association(
@@ -285,9 +272,7 @@ class Signal(BridgeH5):
)
mother_label[b] = lineage[a, 1]
df = add_index_levels(df, {"mother_label": mother_label})
return df
except Exception as e:
print(f"Could not fetch dataset {dataset}")
raise e
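An illustrative call of get_raw; the file and dataset names are placeholders, not taken from this commit:

signal = Signal("example_position.h5")
df = signal.get_raw("extraction/GFP/max/median", in_minutes=True, lineage=True)
# columns are now labelled in minutes and the index has gained a mother_label level
print(df.index.names)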
@@ -298,7 +283,6 @@ class Signal(BridgeH5):
merges = f.get("modifiers/merges", np.array([]))
if not isinstance(merges, np.ndarray):
merges = merges[()]
return merges
def get_picks(
@@ -313,34 +297,25 @@ class Signal(BridgeH5):
picks = set()
if path in f:
picks = set(zip(*[f[path + name] for name in names]))
return picks
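The zip above turns the per-level arrays stored under modifiers/picks into a set of index tuples; a toy illustration with invented values:

trap = [3, 3, 7]
cell_label = [1, 2, 1]
picks = set(zip(trap, cell_label))  # {(3, 1), (3, 2), (7, 1)}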
def dataset_to_df(self, f: h5py.File, path: str) -> pd.DataFrame:
"""
Fetch DataFrame from results storage file.
"""
assert path in f, f"{path} not in {f}"
dset = f[path]
values, index, columns = ([], [], [])
values, index, columns = [], [], []
index_names = copy(self.index_names)
valid_names = [lbl for lbl in index_names if lbl in dset.keys()]
if valid_names:
index = pd.MultiIndex.from_arrays(
[dset[lbl] for lbl in valid_names], names=valid_names
)
columns = dset.attrs.get("columns", None) # dset.attrs["columns"]
columns = dset.attrs.get("columns", None)
if "timepoint" in dset:
columns = f[path + "/timepoint"][()]
values = f[path + "/values"][()]
return pd.DataFrame(
values,
index=index,
@@ -351,24 +326,6 @@ class Signal(BridgeH5):
def stem(self):
return self.filename.stem
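A sketch of the group layout that dataset_to_df expects, inferred from the reads above; the path and values are invented for illustration:

import h5py
import numpy as np

with h5py.File("toy.h5", "w") as f:
    grp = f.create_group("extraction/GFP/max/median")
    grp.create_dataset("trap", data=np.array([0, 0, 1]))        # index level
    grp.create_dataset("cell_label", data=np.array([1, 2, 1]))  # index level
    grp.create_dataset("timepoint", data=np.arange(5))          # becomes the columns
    grp.create_dataset("values", data=np.random.rand(3, 5))     # becomes the cell values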
# def dataset_to_df(self, f: h5py.File, path: str):
# all_indices = self.index_names
# valid_indices = {
# k: f[path][k][()] for k in all_indices if k in f[path].keys()
# }
# new_index = pd.MultiIndex.from_arrays(
# list(valid_indices.values()), names=valid_indices.keys()
# )
# return pd.DataFrame(
# f[path + "/values"][()],
# index=new_index,
# columns=f[path + "/timepoint"][()],
# )
def store_signal_url(
self, fullname: str, node: t.Union[h5py.Dataset, h5py.Group]
):
@@ -413,7 +370,6 @@ class Signal(BridgeH5):
flowrate_name = "pumpinit/flowrate"
pumprate_name = "pumprate"
switchtimes_name = "switchtimes"
main_pump_id = np.concatenate(
(
(np.argmax(self.meta_h5[flowrate_name]),),
@@ -436,7 +392,6 @@ class Signal(BridgeH5):
def switch_times(self) -> t.List[int]:
switchtimes_name = "switchtimes"
switches_minutes = self.meta_h5[switchtimes_name]
return [
t_min
for t_min in switches_minutes
@@ -20,13 +20,15 @@ class Chainer(Signal):
Instead of reading processes previously applied, it executes
them when called.
"""
process_types = ("multisignal", "processes", "reshapers")
common_chains = {}
#process_types = ("multisignal", "processes", "reshapers")
#common_chains = {}
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
for channel in self.candidate_channels:
# find first channel in h5 file that corresponds to a candidate_channel
# but channel is redefined. why is there a loop over candidate channels?
# what about capitals?
try:
channel = [
ch for ch in self.channels if re.match("channel", ch)
@@ -34,8 +36,9 @@ class Chainer(Signal):
break
except:
pass
try:
# what's this?
# composite statistic comprising the quotient of two others
equivalences = {
"m5m": (
f"extraction/{channel}/max/max5px",
@@ -43,13 +46,15 @@ class Chainer(Signal):
),
}
# function to add bgsub to urls
def replace_url(url: str, bgsub: str = ""):
# return pattern with bgsub
channel = url.split("/")[1]
if "bgsub" in bgsub:
# add bgsub to url
url = re.sub(channel, f"{channel}_bgsub", url)
return url
# add chain with and without bgsub
self.common_chains = {
alias
+ bgsub: lambda **kwargs: self.get(
@@ -59,7 +64,6 @@ class Chainer(Signal):
for alias, (denominator, numerator) in equivalences.items()
for bgsub in ("", "_bgsub")
}
except:
pass
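Two small illustrations of what the block above sets up, with placeholder file names and paths; replace_url is local to __init__, so this is only a sketch. replace_url rewrites the channel segment when a bgsub suffix is requested, and the m5m alias combines the two extraction signals listed in equivalences as a quotient:

replace_url("extraction/GFP/max/median", "_bgsub")  # -> "extraction/GFP_bgsub/max/median"
replace_url("extraction/GFP/max/median")            # -> unchanged

chainer = Chainer("example_position.h5")
max5px = chainer.get_raw("extraction/GFP/max/max5px")
median = chainer.get_raw("extraction/GFP/max/median")
ratio = max5px / median  # one plausible orientation of the m5m quotient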
@@ -72,20 +76,17 @@ class Chainer(Signal):
retain: t.Optional[float] = None,
**kwargs,
):
if dataset in self.common_chains: # Produce dataset on the fly
if dataset in self.common_chains:
# produce dataset on the fly
data = self.common_chains[dataset](**kwargs)
else:
data = self.get_raw(dataset, in_minutes=in_minutes)
if chain:
data = self.apply_chain(data, chain, **kwargs)
if retain:
data = data.loc[data.notna().sum(axis=1) > data.shape[1] * retain]
if (
stages and "stage" not in data.columns.names
): # Return stages as additional column level
if (stages and "stage" not in data.columns.names):
# return stages as additional column level
stages_index = [
x
for i, (name, span) in enumerate(self.stages_span_tp)
@@ -95,13 +96,13 @@ class Chainer(Signal):
zip(stages_index, data.columns),
names=("stage", "time"),
)
return data
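A hedged example of calling get; the file name, dataset path, and retain threshold are placeholders:

chainer = Chainer("example_position.h5")
df = chainer.get("extraction/GFP/max/median", retain=0.8)
# rows (cells) missing more than 20% of their time points have been dropped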
def apply_chain(
self, input_data: pd.DataFrame, chain: t.Tuple[str, ...], **kwargs
):
"""Apply a series of processes to a dataset.
"""
Apply a series of processes to a dataset.
In a similar fashion to how postprocessing works, Chainer allows the
consecutive application of processes to a dataset. Parameters can be
@@ -26,8 +26,6 @@ from postprocessor.chainer import Chainer
class Grouper(ABC):
"""Base grouper class."""
files = []
def __init__(self, dir: Union[str, PosixPath]):
path = Path(dir)
self.name = path.name
@@ -37,12 +35,11 @@ class Grouper(ABC):
self.load_chains()
def load_chains(self) -> None:
# Sets self.chainers
self.chainers = {f.name[:-3]: Chainer(f) for f in self.files}
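Each position's h5 file becomes one Chainer, keyed by its file name with the ".h5" suffix stripped; a toy illustration of the keying with invented file names:

from pathlib import Path

files = [Path("pos001.h5"), Path("pos002.h5")]
keys = [f.name[:-3] for f in files]  # ["pos001", "pos002"]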
@property
def fsignal(self) -> Chainer:
# Returns first signal
# returns first signal
return list(self.chainers.values())[0]
@property
@@ -110,14 +107,12 @@ class Grouper(ABC):
"""
if path.startswith("/"):
path = path.strip("/")
sitems = self.filter_path(path)
if standard:
fn_pos = concat_standard
else:
fn_pos = concat_signal_ind
kwargs["mode"] = mode
kymographs = self.pool_function(
path=path,
f=fn_pos,
@@ -125,7 +120,6 @@ class Grouper(ABC):
chainers=sitems,
**kwargs,
)
errors = [
k
for kymo, k in zip(kymographs, self.chainers.keys())
@@ -134,20 +128,15 @@ class Grouper(ABC):
kymographs = [kymo for kymo in kymographs if kymo is not None]
if len(errors):
print("Warning: Positions contain errors {errors}")
assert len(kymographs), "All datasets contain errors"
concat = pd.concat(kymographs, axis=0)
if (
len(concat.index.names) > 4
): # Reorder levels when mother_label is present
concat = concat.reorder_levels(
("group", "position", "trap", "cell_label", "mother_label")
)
concat_sorted = concat.sort_index()
return concat_sorted
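The concatenated result carries a five-level index once mother_label is present. A sketch of typical downstream use, assuming a concrete subclass of Grouper and that the method above is exposed as concat_signal (both names are assumptions):

grouper = MyGrouper("/path/to/experiment")  # hypothetical concrete subclass of Grouper
df = grouper.concat_signal("extraction/GFP/max/median")  # method name assumed
print(df.index.names)  # ("group", "position", "trap", "cell_label", "mother_label")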
def filter_path(self, path: str) -> t.Dict[str, Chainer]:
@@ -163,11 +152,9 @@ class Grouper(ABC):
f"Grouper:Warning: {nchains_dif} chains do not contain"
f" channel {path}"
)
assert len(
sitems
), f"No valid dataset to use. Valid datasets are {self.available}"
return sitems
@property