diff --git a/src/agora/io/writer.py b/src/agora/io/writer.py index 6ce821b9aa2510f8128b8072dc4a535a52f75da2..faadeab4cf47961f4991f5be935dc80c9757418f 100644 --- a/src/agora/io/writer.py +++ b/src/agora/io/writer.py @@ -17,7 +17,7 @@ from agora.io.utils import timed def load_attributes(file: str, group="/"): - ''' + """ Load the metadata from an h5 file and convert to a dictionary, including the "parameters" field which is stored as YAML. Parameters @@ -26,7 +26,7 @@ def load_attributes(file: str, group="/"): Name of the h5 file group: str, optional The group in the h5 file from which to read the data - ''' + """ # load the metadata, stored as attributes, from the h5 file and return as a dictionary with h5py.File(file, "r") as f: meta = dict(f[group].attrs.items()) @@ -37,10 +37,9 @@ def load_attributes(file: str, group="/"): class DynamicWriter: - ''' - A parent class for all writers - ''' - # a dict giving a tuple of the maximum size, as a 2D tuple, and the type for each dataset + """Provides a parent class for all writers.""" + + # a dict giving for each dataset a tuple, comprising the dataset's maximum size, as a 2D tuple, and its type data_types = {} # the group in the h5 file to write to group = "" @@ -55,7 +54,7 @@ class DynamicWriter: def _append(self, data, key, hgroup): """ - Append data to existing dataset in the h5 file otherwise create a new one + Append data to existing dataset in the h5 file otherwise create a new one. Parameters ---------- @@ -107,7 +106,7 @@ class DynamicWriter: def _overwrite(self, data, key, hgroup): """ - Delete and then replace existing dataset in h5 file + Delete and then replace existing dataset in h5 file. Parameters ---------- @@ -135,9 +134,9 @@ class DynamicWriter: # if key not in self.datatypes: # raise KeyError(f"No defined data type for key {key}") - def write(self, data: dict, overwrite: list, meta={}): - ''' - Write data and metadata to h5 file + def write(self, data: dict, overwrite: list, meta: dict = {}): + """ + Write data and metadata to h5 file. Parameters ---------- @@ -147,7 +146,7 @@ class DynamicWriter: A list of datasets to overwrite meta: dict, optional Metadata to be written as attributes of the h5 file - ''' + """ with h5py.File(self.file, "a") as store: # open group, creating if necessary hgroup = store.require_group(self.group) @@ -178,6 +177,8 @@ class DynamicWriter: ##################### Special instances ##################### class TilerWriter(DynamicWriter): + """Write data stored in a Tiler instance to h5 files.""" + datatypes = { "trap_locations": ((None, 2), np.uint16), "drifts": ((None, 2), np.float32), @@ -186,9 +187,20 @@ class TilerWriter(DynamicWriter): } group = "trap_info" - def write(self, data, overwrite: list, tp: int, meta={}): + def write(self, data: dict, overwrite: list, tp: int, meta: dict = {}): """ - Custom function to avoid writing over any data that already exists at time point tp + Write data for time points that have none. + + Parameters + ---------- + data: dict + A dict of datasets and data + overwrite: list of str + A list of datasets to overwrite + tp: int + The time point of interest + meta: dict, optional + Metadata to be written as attributes of the h5 file """ skip = False # append to h5 file @@ -226,6 +238,12 @@ def load_complex(dataset): class BabyWriter(DynamicWriter): + """ + Write data stored in a Baby instance to h5 files. + + Assumes the edgemasks are of form ((max_ncells, max_tps, tile_size, tile_size), bool). + """ + compression = "gzip" max_ncells = 2e5 # Alan: Could just make this None max_tps = 1e3 # Could just make this None @@ -324,7 +342,19 @@ class BabyWriter(DynamicWriter): save_complex(missing, ix_dset) def write_edgemasks(self, data, keys, hgroup): - # data has shape (trap_ids, cell_labels, edgemasks) and is a structured array + """ + Write edgemasks to h5 file. + + Parameters + ---------- + data: list of arrays + Data to be written, in the form (trap_ids, cell_labels, edgemasks) + keys: list of str + Names corresponding to the elements of data. + For example: ["trap", "cell_label", "edgemasks"] + hgroup: group object + Group to write to in h5 file. + """ if not self._traps_initialised: self.__init_trap_info() key = "edgemasks" @@ -339,9 +369,24 @@ class BabyWriter(DynamicWriter): else: self.__append_edgemasks(hgroup, edgemasks, current_indices) - def write(self, data, overwrite: list, tp: int = None, meta={}): + def write(self, data: dict, overwrite: list, tp: int = None, meta: dict = {}): + """ + Write data from a Baby instance, including edgemasks. + + Parameters + ---------- + data: dict + A dict of datasets and data + overwrite: list of str + A list of datasets to overwrite + tp: int + The time point of interest + meta: dict, optional + Metadata to be written as attributes of the h5 file + """ with h5py.File(self.file, "a") as store: hgroup = store.require_group(self.group) + # write data for key, value in data.items(): if key not in self.datatypes: raise KeyError(f"BabyWriter: No defined data type for key {key}") @@ -378,7 +423,13 @@ class BabyWriter(DynamicWriter): class LinearBabyWriter(DynamicWriter): - # TODO make this YAML + """ + Write data stored in a Baby instance to h5 files. + + Assumes the edgemasks are of form ((None, tile_size, tile_size), bool). + """ + + # TODO make this YAML: Alan: why? compression = "gzip" datatypes = { "centres": ((None, 2), np.uint16), @@ -396,31 +447,51 @@ class LinearBabyWriter(DynamicWriter): } group = "cell_info" - def write(self, data, overwrite: list, tp=None, meta={}): - # Data is a dictionary, if not, make it one - # Overwrite data is a list + def write(self, data: dict, overwrite: list, tp: int = None, meta: dict = {}): + """ + Check data does not exist before writing. + Parameters + ---------- + data: dict + A dict of datasets and data + overwrite: list of str + A list of datasets to overwrite + tp: int + The time point of interest + meta: dict, optional + Metadata to be written as attributes of the h5 file + """ with h5py.File(self.file, "a") as store: hgroup = store.require_group(self.group) available_tps = hgroup.get("timepoint", None) + # write data if not available_tps or tp not in np.unique(available_tps[()]): super().write(data, overwrite) else: + # data already exists print(f"BabyWriter: Skipping tp {tp}") - + # write metadata for key, value in meta.items(): hgroup.attrs[key] = value class StateWriter(DynamicWriter): + """Write information summarising the current state of the pipeline - the 'last_state' dataset in the h5 file""" + datatypes = { + # the highest cell label assigned for each time point "max_lbl": ((None, 1), np.uint16), "tp_back": ((None, 1), np.uint16), + # trap labels "trap": ((None, 1), np.int16), + # cell labels "cell_lbls": ((None, 1), np.uint16), "prev_feats": ((None, None), np.float32), "lifetime": ((None, 2), np.uint16), + # probability of being a bud "p_was_bud": ((None, 2), np.float32), + # probability of being a mother "p_is_mother": ((None, 2), np.float32), "ba_cum": ((None, None), np.float32), } @@ -429,7 +500,7 @@ class StateWriter(DynamicWriter): @staticmethod def format_field(states: list, field: str): - # Flatten a field in the states list to save as an hdf5 dataset + """Flatten a field in the states list to save as an h5 dataset.""" fields = [pos_state[field] for pos_state in states] return fields @@ -438,7 +509,6 @@ class StateWriter(DynamicWriter): tp_back, trap, value = [ [[] for _ in states[0][val_name]] for _ in range(3) ] - lbl_tuples = [ (tp_back, trap, cell_label) for trap, state in enumerate(states) @@ -447,7 +517,6 @@ class StateWriter(DynamicWriter): ] if len(lbl_tuples): tp_back, trap, value = zip(*lbl_tuples) - return tp_back, trap, value @staticmethod @@ -463,10 +532,10 @@ class StateWriter(DynamicWriter): @staticmethod def pad_if_needed(array: np.ndarray, pad_size: int): + """Pad a 2D array with zeros.""" padded = np.zeros((pad_size, pad_size)).astype(float) length = len(array) padded[:length, :length] = array - return padded def format_states(self, states: list): @@ -475,88 +544,78 @@ class StateWriter(DynamicWriter): states, "cell_lbls" ) _, _, prev_feats = self.format_values_tpback(states, "prev_feats") - - # Heterogeneous datasets + # heterogeneous datasets formatted_state["tp_back"] = tp_back formatted_state["trap"] = trap formatted_state["cell_lbls"] = cell_label formatted_state["prev_feats"] = np.array(prev_feats) - - # One entry per cell label - tp_back independent + # one entry per cell label - tp_back independent for val_name in ("lifetime", "p_was_bud", "p_is_mother"): formatted_state[val_name] = self.format_values_traps( states, val_name ) - bacum_max = max([len(state["ba_cum"]) for state in states]) - formatted_state["ba_cum"] = np.array( [ self.pad_if_needed(state["ba_cum"], bacum_max) for state in states ] ) - return formatted_state - def write(self, data, overwrite: Iterable, tp: int = None): - # formatted_data = self.format_states(data) - # super().write(data=formatted_data, overwrite=overwrite) + def write(self, data: dict, overwrite: list, tp: int = None): + """Write.""" if len(data): last_tp = 0 if tp is None: tp = 0 - try: with h5py.File(self.file, "r") as f: gr = f.get(self.group, None) if gr: last_tp = gr.attrs.get("tp", 0) - # print(f"{ self.file } - tp: {tp}, last_tp: {last_tp}") if tp == 0 or tp > last_tp: - # print(f"Writing timepoint {tp}") + # write formatted_data = self.format_states(data) super().write(data=formatted_data, overwrite=overwrite) with h5py.File(self.file, "a") as f: - # print(f"Writing tp {tp}") f[self.group].attrs["tp"] = tp elif tp > 0 and tp <= last_tp: - print(f"BabyWriter: Skipping timepoint {tp}") + # data already present + print(f"StateWriter: Skipping timepoint {tp}") except Exception as e: raise (e) else: - print("Skipping overwriting empty state") + print("Skipping overwriting: no data") #################### Extraction version ############################### class Writer(BridgeH5): """ - Class in charge of transforming data into compatible formats + Class in charge of transforming data into compatible formats. Decoupling interface from implementation! Parameters ---------- - filename: str Name of file to write into - flag: str, default=None - Flag to pass to the default file reader. If None the file remains closed. - compression: str, default=None - Compression method passed on to h5py writing functions (only used for - dataframes and other array-like data.) + filename: str + Name of file to write into + flag: str, default=None + Flag to pass to the default file reader. If None the file remains closed. + compression: str, default="gzip" + Compression method passed on to h5py writing functions (only used for dataframes and other array-like data). """ - def __init__(self, filename, compression=None): - super().__init__(filename, flag=None) - - if compression is None: - self.compression = "gzip" + def __init__(self, filename, flag=None, compression="gzip"): + super().__init__(filename, flag=flag) + self.compression = compression def write( self, path: str, data: Iterable = None, - meta: Dict = {}, + meta: dict = {}, overwrite: str = None, ): """ @@ -565,7 +624,7 @@ class Writer(BridgeH5): path : str Path inside h5 file to write into. data : Iterable, default = None - meta : Dict, default = {} + meta : dict, default = {} """ self.id_cache = {}