Skip to content
Snippets Groups Projects
Commit f6968e17 authored by Alán Muñoz's avatar Alán Muñoz
Browse files

make reader and writer more compatible

parent a5074c99
No related branches found
No related tags found
No related merge requests found
...@@ -37,12 +37,12 @@ class StateReader(DynamicReader): ...@@ -37,12 +37,12 @@ class StateReader(DynamicReader):
"max_lbl": ((None, 1), np.uint16), "max_lbl": ((None, 1), np.uint16),
"tp_back": ((None, 1), np.uint16), "tp_back": ((None, 1), np.uint16),
"trap": ((None, 1), np.int16), "trap": ((None, 1), np.int16),
"cell_label": ((None, 1), np.uint16), "cell_lbls": ((None, 1), np.uint16),
"prev_feats": ((None, None), np.float32), "prev_feats": ((None, None), np.float64),
"lifetime": ((None, 2), np.uint16), "lifetime": ((None, 2), np.uint16),
"p_was_bud": ((None, 2), np.float32), "p_was_bud": ((None, 2), np.float64),
"p_is_mother": ((None, 2), np.float32), "p_is_mother": ((None, 2), np.float64),
"ba_cum": ((None, None), np.float32), "ba_cum": ((None, None), np.float64),
} }
group = "last_state" group = "last_state"
...@@ -82,24 +82,35 @@ class StateReader(DynamicReader): ...@@ -82,24 +82,35 @@ class StateReader(DynamicReader):
trap_as_idx = copy(data["trap"]) trap_as_idx = copy(data["trap"])
states = {k: {"max_lbl": v} for k, v in enumerate(data["max_lbl"])} states = {k: {"max_lbl": v} for k, v in enumerate(data["max_lbl"])}
for val_name in ("cell_label", "prev_feats"): for val_name in ("cell_lbls", "prev_feats"):
for k in states.keys():
if val_name == "cell_lbls":
states[k][val_name] = [[] for _ in range(ntps_back)]
else:
states[k][val_name] = [
np.zeros((0, data[val_name].shape[1]), dtype=np.float64)
for _ in range(ntps_back)
]
data[val_name] = list(zip(trap_as_idx, tpback_as_idx, data[val_name])) data[val_name] = list(zip(trap_as_idx, tpback_as_idx, data[val_name]))
for k, v in groupsort(data[val_name]).items(): for k, v in groupsort(data[val_name]).items():
states[k][val_name] = [ states[k][val_name] = [
[w[0] for w in val] for val in groupsort(v).values() np.array([w[0] for w in val]) for val in groupsort(v).values()
] ]
for val_name in ("lifetime", "p_was_bud", "p_is_mother"): for val_name in ("lifetime", "p_was_bud", "p_is_mother"):
for k in states.keys():
states[k][val_name] = np.array([])
# This contains no time points back # This contains no time points back
for k, v in groupsort(data[val_name]).items(): for k, v in groupsort(data[val_name]).items():
states[k][val_name] = [val[0] for val in v] states[k][val_name] = np.array([val[0] for val in v])
for trap_id, ba_matrix in enumerate(data["ba_cum"]): for trap_id, ba_matrix in enumerate(data["ba_cum"]):
states[trap_id]["ba_cum"] = np.array([]) states[trap_id]["ba_cum"] = np.zeros((0, 0), dtype=np.float64)
if ba_matrix.sum(): if ba_matrix.any():
states[trap_id]["ba_cum"] = ba_matrix states[trap_id]["ba_cum"] = np.array(ba_matrix, dtype=np.float64)
return states return [val for val in states.values()]
def get_formatted_states(self): def get_formatted_states(self):
return self.reconstruct_states(self.read_all()) return self.reconstruct_states(self.read_all())
...@@ -263,7 +263,7 @@ class BabyWriter(DynamicWriter): ...@@ -263,7 +263,7 @@ class BabyWriter(DynamicWriter):
else: else:
self.__append_edgemasks(hgroup, edgemasks, current_indices) self.__append_edgemasks(hgroup, edgemasks, current_indices)
def write(self, data, overwrite: list): def write(self, data, overwrite: list, tp: int = None):
with h5py.File(self.file, "a") as store: with h5py.File(self.file, "a") as store:
hgroup = store.require_group(self.group) hgroup = store.require_group(self.group)
...@@ -279,7 +279,12 @@ class BabyWriter(DynamicWriter): ...@@ -279,7 +279,12 @@ class BabyWriter(DynamicWriter):
elif key == "edgemasks": elif key == "edgemasks":
keys = ["trap", "cell_label", "edgemasks"] keys = ["trap", "cell_label", "edgemasks"]
value = [data[x] for x in keys] value = [data[x] for x in keys]
self.write_edgemasks(value, keys, hgroup)
edgemask_dset = hgroup.get(key + "/values", None)
if edgemask_dset and tp <= edgemask_dset[()].shape[1]:
print("BabyWriter: Skipping tp {tp}")
else:
self.write_edgemasks(value, keys, hgroup)
else: else:
self._append(value, key, hgroup) self._append(value, key, hgroup)
except Exception as e: except Exception as e:
...@@ -293,7 +298,7 @@ class StateWriter(DynamicWriter): ...@@ -293,7 +298,7 @@ class StateWriter(DynamicWriter):
"max_lbl": ((None, 1), np.uint16), "max_lbl": ((None, 1), np.uint16),
"tp_back": ((None, 1), np.uint16), "tp_back": ((None, 1), np.uint16),
"trap": ((None, 1), np.int16), "trap": ((None, 1), np.int16),
"cell_label": ((None, 1), np.uint16), "cell_lbls": ((None, 1), np.uint16),
"prev_feats": ((None, None), np.float32), "prev_feats": ((None, None), np.float32),
"lifetime": ((None, 2), np.uint16), "lifetime": ((None, 2), np.uint16),
"p_was_bud": ((None, 2), np.float32), "p_was_bud": ((None, 2), np.float32),
...@@ -347,7 +352,7 @@ class StateWriter(DynamicWriter): ...@@ -347,7 +352,7 @@ class StateWriter(DynamicWriter):
# Heterogeneous datasets # Heterogeneous datasets
formatted_state["tp_back"] = tp_back formatted_state["tp_back"] = tp_back
formatted_state["trap"] = trap formatted_state["trap"] = trap
formatted_state["cell_label"] = cell_label formatted_state["cell_lbls"] = cell_label
formatted_state["prev_feats"] = np.array(prev_feats) formatted_state["prev_feats"] = np.array(prev_feats)
# One entry per cell label - tp_back independent # One entry per cell label - tp_back independent
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment