diff --git a/src/agora/io/writer.py b/src/agora/io/writer.py index faadeab4cf47961f4991f5be935dc80c9757418f..66fec1f40d69c08412e246be4e21cd20d6fdb410 100644 --- a/src/agora/io/writer.py +++ b/src/agora/io/writer.py @@ -103,7 +103,6 @@ class DynamicWriter: # write all data, signified by the empty tuple hgroup[key][()] = data - def _overwrite(self, data, key, hgroup): """ Delete and then replace existing dataset in h5 file. @@ -125,7 +124,10 @@ class DynamicWriter: del hgroup[key] # write new data hgroup.require_dataset( - key, shape=data_shape, dtype=dtype, compression=self.compression, + key, + shape=data_shape, + dtype=dtype, + compression=self.compression, ) # write all data, signified by the empty tuple hgroup[key][()] = data @@ -153,9 +155,9 @@ class DynamicWriter: # write data for key, value in data.items(): # only save data with a pre-defined data-type - if key not in self.datatypes: - raise KeyError(f"No defined data type for key {key}") - else: + if key not in self.datatypes: + raise KeyError(f"No defined data type for key {key}") + else: try: if key.startswith("attrs/"): # metadata @@ -216,8 +218,6 @@ class TilerWriter(DynamicWriter): if not skip: super().write(data=data, overwrite=overwrite, meta=meta) -# Alan: why's this here? -tile_size = 117 # Alan: we use complex numbers because... @timed() @@ -255,7 +255,10 @@ class BabyWriter(DynamicWriter): "position": ((None,), np.uint16), "angles": ((None,), h5py.vlen_dtype(np.float32)), "radii": ((None,), h5py.vlen_dtype(np.float32)), - "edgemasks": ((max_ncells, max_tps, tile_size, tile_size), bool), + "edgemasks": ( + (max_ncells, max_tps, default_tile_size, default_tile_size), + bool, + ), "ellipse_dims": ((None, 2), np.float32), "cell_label": ((None,), np.uint16), "trap": ((None,), np.uint16), @@ -324,7 +327,9 @@ class BabyWriter(DynamicWriter): n_tps = val_dset.shape[1] + 1 n_add_cells = len(missing) # resize dataset for Time and Cells - new_shape = (val_dset.shape[0] + n_add_cells, n_tps) + val_dset.shape[2:] + new_shape = (val_dset.shape[0] + n_add_cells, n_tps) + val_dset.shape[ + 2: + ] val_dset.resize(new_shape) logging.debug(f"Timing:resizing:{perf_counter() - debug_t}") # write data @@ -369,7 +374,9 @@ class BabyWriter(DynamicWriter): else: self.__append_edgemasks(hgroup, edgemasks, current_indices) - def write(self, data: dict, overwrite: list, tp: int = None, meta: dict = {}): + def write( + self, data: dict, overwrite: list, tp: int = None, meta: dict = {} + ): """ Write data from a Baby instance, including edgemasks. @@ -389,7 +396,9 @@ class BabyWriter(DynamicWriter): # write data for key, value in data.items(): if key not in self.datatypes: - raise KeyError(f"BabyWriter: No defined data type for key {key}") + raise KeyError( + f"BabyWriter: No defined data type for key {key}" + ) else: try: if key.startswith("attrs/"): @@ -408,7 +417,9 @@ class BabyWriter(DynamicWriter): and tp < edgemask_dset[()].shape[1] ): # data already exists - print(f"BabyWriter: Skipping edgemasks in tp {tp}") + print( + f"BabyWriter: Skipping edgemasks in tp {tp}" + ) else: self.write_edgemasks(value, keys, hgroup) else: @@ -431,12 +442,13 @@ class LinearBabyWriter(DynamicWriter): # TODO make this YAML: Alan: why? compression = "gzip" + _default_tile_size = 117 datatypes = { "centres": ((None, 2), np.uint16), "position": ((None,), np.uint16), "angles": ((None,), h5py.vlen_dtype(np.float32)), "radii": ((None,), h5py.vlen_dtype(np.float32)), - "edgemasks": ((None, tile_size, tile_size), bool), + "edgemasks": ((None, _default_tile_size, _default_tile_size), bool), "ellipse_dims": ((None, 2), np.float32), "cell_label": ((None,), np.uint16), "trap": ((None,), np.uint16), @@ -447,7 +459,9 @@ class LinearBabyWriter(DynamicWriter): } group = "cell_info" - def write(self, data: dict, overwrite: list, tp: int = None, meta: dict = {}): + def write( + self, data: dict, overwrite: list, tp: int = None, meta: dict = {} + ): """ Check data does not exist before writing. @@ -477,22 +491,31 @@ class LinearBabyWriter(DynamicWriter): class StateWriter(DynamicWriter): - """Write information summarising the current state of the pipeline - the 'last_state' dataset in the h5 file""" + """ + Write information summarising the current state of the pipeline - the 'last_state' dataset in the h5 file. + + MOVEDatatypes are specified with the first variable specifying the number of traps and the other specifying the shape of the object. + + """ datatypes = { # the highest cell label assigned for each time point "max_lbl": ((None, 1), np.uint16), + # how far back we go for tracking "tp_back": ((None, 1), np.uint16), # trap labels "trap": ((None, 1), np.int16), # cell labels "cell_lbls": ((None, 1), np.uint16), + # previous cell features for tracking "prev_feats": ((None, None), np.float32), + # number of images for which a cell has been present "lifetime": ((None, 2), np.uint16), - # probability of being a bud + # probability of a mother-bud relationship given a bud "p_was_bud": ((None, 2), np.float32), - # probability of being a mother + # probability of a mother-bud relationship given a mother "p_is_mother": ((None, 2), np.float32), + # cumulative matrix, over time, of bud assignments "ba_cum": ((None, None), np.float32), } group = "last_state" @@ -506,21 +529,27 @@ class StateWriter(DynamicWriter): @staticmethod def format_values_tpback(states: list, val_name: str): + """Unpacks a dict of state data into tp_back, trap, value.""" + # initialise as empty lists + # Alan: is this initialisation necessary? tp_back, trap, value = [ [[] for _ in states[0][val_name]] for _ in range(3) ] + # store results as a list of tuples lbl_tuples = [ (tp_back, trap, cell_label) for trap, state in enumerate(states) for tp_back, value in enumerate(state[val_name]) for cell_label in value ] + # unpack list of tuples to define variables if len(lbl_tuples): tp_back, trap, value = zip(*lbl_tuples) return tp_back, trap, value @staticmethod def format_values_traps(states: list, val_name: str): + """Format either lifetime, p_was_bud, or p_is_mother variables as a list.""" formatted = np.array( [ (trap, clabel_val) @@ -532,19 +561,20 @@ class StateWriter(DynamicWriter): @staticmethod def pad_if_needed(array: np.ndarray, pad_size: int): - """Pad a 2D array with zeros.""" + """Pad a 2D array with zeros for large indices.""" padded = np.zeros((pad_size, pad_size)).astype(float) length = len(array) padded[:length, :length] = array return padded def format_states(self, states: list): + """Re-format state data into a dict of lists, with one element per per list per state.""" formatted_state = {"max_lbl": [state["max_lbl"] for state in states]} tp_back, trap, cell_label = self.format_values_tpback( states, "cell_lbls" ) _, _, prev_feats = self.format_values_tpback(states, "prev_feats") - # heterogeneous datasets + # store lists in a dict formatted_state["tp_back"] = tp_back formatted_state["trap"] = trap formatted_state["cell_lbls"] = cell_label @@ -563,23 +593,21 @@ class StateWriter(DynamicWriter): ) return formatted_state - def write(self, data: dict, overwrite: list, tp: int = None): - """Write.""" + def write(self, data: dict, overwrite: list, tp: int = 0): + """Write the current state of the pipeline.""" if len(data): last_tp = 0 - if tp is None: - tp = 0 try: with h5py.File(self.file, "r") as f: gr = f.get(self.group, None) if gr: last_tp = gr.attrs.get("tp", 0) - # print(f"{ self.file } - tp: {tp}, last_tp: {last_tp}") if tp == 0 or tp > last_tp: # write formatted_data = self.format_states(data) super().write(data=formatted_data, overwrite=overwrite) with h5py.File(self.file, "a") as f: + # record that data for the timepoint has been written f[self.group].attrs["tp"] = tp elif tp > 0 and tp <= last_tp: # data already present @@ -593,7 +621,7 @@ class StateWriter(DynamicWriter): #################### Extraction version ############################### class Writer(BridgeH5): """ - Class in charge of transforming data into compatible formats. + Class in charge of transforming data into compatible structures. Decoupling interface from implementation!