Skip to content
Snippets Groups Projects
Commit 8bc14018 authored by Alán Muñoz's avatar Alán Muñoz
Browse files

add time-awareness to StateWriter

parent f6968e17
No related branches found
No related tags found
No related merge requests found
......@@ -106,9 +106,7 @@ class StateReader(DynamicReader):
states[k][val_name] = np.array([val[0] for val in v])
for trap_id, ba_matrix in enumerate(data["ba_cum"]):
states[trap_id]["ba_cum"] = np.zeros((0, 0), dtype=np.float64)
if ba_matrix.any():
states[trap_id]["ba_cum"] = np.array(ba_matrix, dtype=np.float64)
states[trap_id]["ba_cum"] = np.array(ba_matrix, dtype=np.float64)
return [val for val in states.values()]
......
......@@ -282,7 +282,7 @@ class BabyWriter(DynamicWriter):
edgemask_dset = hgroup.get(key + "/values", None)
if edgemask_dset and tp <= edgemask_dset[()].shape[1]:
print("BabyWriter: Skipping tp {tp}")
print(f"BabyWriter: Skipping tp {tp}")
else:
self.write_edgemasks(value, keys, hgroup)
else:
......@@ -367,9 +367,29 @@ class StateWriter(DynamicWriter):
return formatted_state
def write(self, data, overwrite: Iterable):
formatted_data = self.format_states(data)
super().write(data=formatted_data, overwrite=overwrite)
def write(self, data, overwrite: Iterable, tp: int = None):
# formatted_data = self.format_states(data)
# super().write(data=formatted_data, overwrite=overwrite)
last_tp = 0
if tp is None:
tp = 0
try:
with h5py.File(self.file, "r") as f:
gr = f.get(self.group, None)
if gr:
last_tp = gr.attrs.get("tp", None)
if not tp or tp > last_tp:
formatted_data = self.format_states(data)
super().write(data=formatted_data, overwrite=overwrite)
with h5py.File(self.file, "a") as f:
print(f"Writing tp {tp}")
f[self.group].attrs["tp"] = tp
elif tp and tp <= last_tp:
print(f"Skipping timepoint {tp}")
except Exception as e:
raise (e)
#################### Extraction version ###############################
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment