diff --git a/src/agora/io/bridge.py b/src/agora/io/bridge.py index 4a2f6093e7d8aaf4ed8e0e0f8abfcfaffb796449..478408f37d1f0724a232876f0f5ce0da5b02b585 100644 --- a/src/agora/io/bridge.py +++ b/src/agora/io/bridge.py @@ -1,5 +1,5 @@ """ -Tools to interact with hdf5 files and handle data consistently. +Tools to interact with h5 files and handle data consistently. """ import collections from itertools import chain, groupby, product @@ -13,26 +13,28 @@ import yaml class BridgeH5: """ - Base class to interact with h5 data stores. - It also contains functions useful to predict how long should segmentation take. + Base class to interact with h5 files. + + It includes functions that predict how long segmentation will take. """ def __init__(self, filename, flag="r"): + """Initialise with the name of the h5 file.""" self.filename = filename if flag is not None: self._hdf = h5py.File(filename, flag) - self._filecheck def _filecheck(self): assert "cell_info" in self._hdf, "Invalid file. No 'cell_info' found." def close(self): + """Close the h5 file.""" self._hdf.close() @property def meta_h5(self) -> t.Dict[str, t.Any]: - # Return metadata as indicated in h5 file + """Return metadata, defining it if necessary.""" if not hasattr(self, "_meta_h5"): with h5py.File(self.filename, "r") as f: self._meta_h5 = dict(f.attrs) @@ -44,24 +46,24 @@ class BridgeH5: @staticmethod def get_consecutives(tree, nstepsback): - # Receives a sorted tree and returns the keys of consecutive elements - vals = {k: np.array(list(v)) for k, v in tree.items()} # get tp level + """Receives a sorted tree and returns the keys of consecutive elements.""" + # get tp level + vals = {k: np.array(list(v)) for k, v in tree.items()} + # get indices of consecutive elements where_consec = [ { k: np.where(np.subtract(v[n + 1 :], v[: -n - 1]) == n + 1)[0] for k, v in vals.items() } for n in range(nstepsback) - ] # get indices of consecutive elements + ] return where_consec def get_npairs(self, nstepsback=2, tree=None): if tree is None: tree = self.cell_tree - consecutive = self.get_consecutives(tree, nstepsback=nstepsback) flat_tree = flatten(tree) - n_predictions = 0 for i, d in enumerate(consecutive, 1): flat = list(chain(*[product([k], list(v)) for k, v in d.items()])) @@ -70,55 +72,49 @@ class BridgeH5: n_predictions += len(flat_tree.get(p[0], [])) * len( flat_tree.get(p[1], []) ) - return n_predictions def get_npairs_over_time(self, nstepsback=2): tree = self.cell_tree npairs = [] - for t in self._hdf["cell_info"]["processed_timepoints"][()]: + for tp in self._hdf["cell_info"]["processed_timepoints"][()]: tmp_tree = { - k: {k2: v2 for k2, v2 in v.items() if k2 <= t} + k: {k2: v2 for k2, v2 in v.items() if k2 <= tp} for k, v in tree.items() } npairs.append(self.get_npairs(tree=tmp_tree)) - return np.diff(npairs) def get_info_tree( self, fields: Union[tuple, list] = ("trap", "timepoint", "cell_label") ): """ - Returns traps, time points and labels for this position in form of a tree - in the hierarchy determined by the argument fields. Note that it is - compressed to non-empty elements and timepoints. + Return traps, time points and labels for this position in the form of a tree in the hierarchy determined by the argument fields. + + Note that it is compressed to non-empty elements and timepoints. Default hierarchy is: - trap - time point - cell label - This function currently produces trees of depth 3, but it can easily be - extended for deeper trees if needed (e.g. considering groups, - chambers and/or positions). 
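For context, a minimal usage sketch of the tree and pair-counting helpers touched above; the file name is hypothetical and assumes a position file that already contains a "cell_info" group:

    from agora.io.bridge import BridgeH5, flatten

    bridge = BridgeH5("position001.h5")
    # nested dict: {trap: {timepoint: [cell labels]}}, compressed to non-empty entries
    tree = bridge.get_info_tree()
    # flatten to {(trap, timepoint): [cell labels]} for quick look-ups
    flat = flatten(tree)
    # estimate how many cell pairs the tracker will have to score
    print(bridge.get_npairs(nstepsback=2, tree=tree))
    bridge.close()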
+ This function currently produces trees of depth 3, but it can easily be extended for deeper trees if needed (e.g. considering groups, chambers and/or positions). Parameters ---------- - fields: Fields to fetch from 'cell_info' inside the hdf5 storage + fields: list of strs + Fields to fetch from 'cell_info' inside the h5 file. Returns ---------- - Nested dictionary where keys (or branches) are the upper levels - and the leaves are the last element of :fields:. + Nested dictionary where keys (or branches) are the upper levels and the leaves are the last element of :fields:. """ zipped_info = (*zip(*[self._hdf["cell_info"][f][()] for f in fields]),) - return recursive_groupsort(zipped_info) def groupsort(iterable: Union[tuple, list]): - # Sorts iterable and returns a dictionary where the values are grouped by the first element. - + """Sorts iterable and returns a dictionary where the values are grouped by the first element.""" iterable = sorted(iterable, key=lambda x: x[0]) grouped = { k: [x[1:] for x in v] for k, v in groupby(iterable, lambda x: x[0]) @@ -127,17 +123,18 @@ def groupsort(iterable: Union[tuple, list]): def recursive_groupsort(iterable): - # Recursive extension of groupsort + """Recursive extension of groupsort.""" if len(iterable[0]) > 1: return { k: recursive_groupsort(v) for k, v in groupsort(iterable).items() } - else: # Only two elements in list + else: + # only two elements in list return [x[0] for x in iterable] def flatten(d, parent_key="", sep="_"): - """Flatten nested dict. Adapted from https://stackoverflow.com/a/6027615""" + """Flatten nested dict. Adapted from https://stackoverflow.com/a/6027615.""" items = [] for k, v in d.items(): new_key = parent_key + (k,) if parent_key else (k,) @@ -149,18 +146,19 @@ def flatten(d, parent_key="", sep="_"): def attrs_from_h5(fpath: str): - """Return attributes as dict from h5 file""" + """Return attributes as dict from an h5 file.""" with h5py.File(fpath, "r") as f: return dict(f.attrs) def parameters_from_h5(fpath: str): + """Return parameters from an h5 file.""" attrs = attrs_from_h5(fpath) return yaml.safe_load(attrs["parameters"]) def image_creds_from_h5(fpath: str): - """Return image id and server credentials from h5""" + """Return image id and server credentials from an h5.""" attrs = attrs_from_h5(fpath) return ( attrs["image_id"], diff --git a/src/agora/io/writer.py b/src/agora/io/writer.py index c448f59f64f59c9bc77ace34fc931e014eecd04a..f11827fa9adb6818e354c0574f91a419770f8f9a 100644 --- a/src/agora/io/writer.py +++ b/src/agora/io/writer.py @@ -18,35 +18,78 @@ from agora.io.utils import timed def load_attributes(file: str, group="/"): + """ + Load the metadata from an h5 file and convert to a dictionary, including the "parameters" field which is stored as YAML. 
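Since several of the helpers above rely on the "parameters" attribute being a YAML string, here is a minimal sketch of that round trip; the file name and parameter values are made up:

    import h5py
    import yaml
    from agora.io.bridge import parameters_from_h5

    params = {"tiler": {"tile_size": 117}, "baby": {"model": "default"}}
    with h5py.File("example.h5", "a") as f:
        # the pipeline stores the nested dict as a single YAML string attribute
        f.attrs["parameters"] = yaml.dump(params)
    # safe_load recovers the nested dict
    assert parameters_from_h5("example.h5") == params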
+ + Parameters + ---------- + file: str + Name of the h5 file + group: str, optional + The group in the h5 file from which to read the data + """ + # load the metadata, stored as attributes, from the h5 file and return as a dictionary with h5py.File(file, "r") as f: meta = dict(f[group].attrs.items()) if "parameters" in meta: + # convert from yaml format into dict meta["parameters"] = yaml.safe_load(meta["parameters"]) return meta class DynamicWriter: + """Provides a parent class for all writers.""" + + # a dict giving for each dataset a tuple, comprising the dataset's maximum size, as a 2D tuple, and its type data_types = {} + # the group in the h5 file to write to group = "" + # compression info compression = "gzip" compression_opts = 9 metadata = None def __init__(self, file: str): self.file = file + # the metadata is stored as attributes in the h5 file if Path(file).exists(): self.metadata = load_attributes(file) def _append(self, data, key, hgroup): - """Append data to existing dataset.""" + """ + Append data to existing dataset in the h5 file otherwise create a new one. + + Parameters + ---------- + data + Data to be written, typically a numpy array + key: str + Name of dataset + hgroup: str + Destination group in the h5 file + """ try: n = len(data) except Exception as e: logging.debug( - "DynamicWriter:Attributes have no length: {}".format(e) + "DynamicWriter: Attributes have no length: {}".format(e) ) n = 1 - if key not in hgroup: + if key in hgroup: + # append to existing dataset + try: + # FIXME This is broken by bugged mother-bud assignment + dset = hgroup[key] + dset.resize(dset.shape[0] + n, axis=0) + dset[-n:] = data + except Exception as e: + logging.debug( + "DynamicWriter: Inconsistency between dataset shape and new empty data: {}".format( + e + ) + ) + else: + # create new dataset # TODO Include sparsity check max_shape, dtype = self.datatypes[key] shape = (n,) + max_shape[1:] @@ -60,66 +103,87 @@ class DynamicWriter: if self.compression is not None else None, ) + # write all data, signified by the empty tuple hgroup[key][()] = data - else: - # The dataset already exists, expand it - - try: # FIXME This is broken by bugged mother-bud assignment - dset = hgroup[key] - dset.resize(dset.shape[0] + n, axis=0) - dset[-n:] = data - except Exception as e: - logging.debug( - "DynamicWriter:Inconsistency between dataset shape and new empty data: {}".format( - e - ) - ) - return def _overwrite(self, data, key, hgroup): - """Overwrite existing dataset with new data""" + """ + Delete and then replace existing dataset in h5 file. + + Parameters + ---------- + data + Data to be written, typically a numpy array + key: str + Name of dataset + hgroup: str + Destination group in the h5 file + """ # We do not append to mother_assign; raise error if already saved data_shape = np.shape(data) max_shape, dtype = self.datatypes[key] + # delete existing data if key in hgroup: del hgroup[key] + # write new data hgroup.require_dataset( - key, shape=data_shape, dtype=dtype, compression=self.compression + key, + shape=data_shape, + dtype=dtype, + compression=self.compression, ) + # write all data, signified by the empty tuple hgroup[key][()] = data - def _check_key(self, key): - if key not in self.datatypes: - raise KeyError(f"No defined data type for key {key}") + # def _check_key(self, key): + # if key not in self.datatypes: + # raise KeyError(f"No defined data type for key {key}") + + def write(self, data: dict, overwrite: list, meta: dict = {}): + """ + Write data and metadata to h5 file. 
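The append path above relies on h5py's resizable datasets; a standalone sketch of that pattern, using TilerWriter's "trap_locations" layout with made-up data:

    import h5py
    import numpy as np

    data = np.zeros((10, 2), dtype=np.uint16)
    with h5py.File("example.h5", "a") as f:
        hgroup = f.require_group("trap_info")
        if "trap_locations" not in hgroup:
            # first write: an unlimited first axis lets later time points be appended
            hgroup.require_dataset(
                "trap_locations",
                shape=data.shape,
                maxshape=(None, 2),
                dtype=np.uint16,
                compression="gzip",
            )
            hgroup["trap_locations"][()] = data
        else:
            # later writes: grow the first axis and fill the new rows
            dset = hgroup["trap_locations"]
            dset.resize(dset.shape[0] + len(data), axis=0)
            dset[-len(data):] = data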
- def write(self, data, overwrite: list, meta={}): - # Data is a dictionary, if not, make it one - # Overwrite data is a list + Parameters + ---------- + data: dict + A dict of datasets and data + overwrite: list of str + A list of datasets to overwrite + meta: dict, optional + Metadata to be written as attributes of the h5 file + """ with h5py.File(self.file, "a") as store: + # open group, creating if necessary hgroup = store.require_group(self.group) - + # write data for key, value in data.items(): - # We're only saving data that has a pre-defined data-type - self._check_key(key) - try: - if key.startswith("attrs/"): # metadata - key = key.split("/")[1] # First thing after attrs - hgroup.attrs[key] = value - elif key in overwrite: - self._overwrite(value, key, hgroup) - else: - self._append(value, key, hgroup) - except Exception as e: - print(key, value) - raise (e) + # only save data with a pre-defined data-type + if key not in self.datatypes: + raise KeyError(f"No defined data type for key {key}") + else: + try: + if key.startswith("attrs/"): + # metadata + key = key.split("/")[1] + hgroup.attrs[key] = value + elif key in overwrite: + # delete and replace existing dataset + self._overwrite(value, key, hgroup) + else: + # append or create new dataset + self._append(value, key, hgroup) + except Exception as e: + print(key, value) + raise (e) + # write metadata for key, value in meta.items(): hgroup.attrs[key] = value - return - ##################### Special instances ##################### class TilerWriter(DynamicWriter): + """Write data stored in a Tiler instance to h5 files.""" + datatypes = { "trap_locations": ((None, 2), np.uint16), "drifts": ((None, 2), np.float32), @@ -128,30 +192,40 @@ class TilerWriter(DynamicWriter): } group = "trap_info" - def write(self, data, overwrite: list, tp: int, meta={}): - """ - Skips writing data if it were to overwrite it,using drift as a marker + def write(self, data: dict, overwrite: list, tp: int, meta: dict = {}): """ + Write data for time points that have none. + Parameters + ---------- + data: dict + A dict of datasets and data + overwrite: list of str + A list of datasets to overwrite + tp: int + The time point of interest + meta: dict, optional + Metadata to be written as attributes of the h5 file + """ skip = False + # append to h5 file with h5py.File(self.file, "a") as store: + # open group, creating if necessary hgroup = store.require_group(self.group) - + # find xy drift for each time point as proof that it has already been processed nprev = hgroup.get("drifts", None) if nprev and tp < nprev.shape[0]: + # data already exists print(f"Tiler: Skipping timepoint {tp}") skip = True - if not skip: super().write(data=data, overwrite=overwrite, meta=meta) -tile_size = 117 - - +# Alan: we use complex numbers because... @timed() def save_complex(array, dataset): - # Dataset needs to be 2D + # append array, an 1D array of complex numbers, onto dataset, a 2D array of real numbers n = len(array) if n > 0: dataset.resize(dataset.shape[0] + n, axis=0) @@ -161,22 +235,33 @@ def save_complex(array, dataset): @timed() def load_complex(dataset): + # convert 2D dataset into a 1D array of complex numbers array = dataset[:, 0] + 1j * dataset[:, 1] return array class BabyWriter(DynamicWriter): + """ + Write data stored in a Baby instance to h5 files. + + Assumes the edgemasks are of form ((max_ncells, max_tps, tile_size, tile_size), bool). 
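To make the complex-number trick concrete, a self-contained sketch of what save_complex and load_complex do; the dataset name and the (trap, cell_label) values are illustrative:

    import h5py
    import numpy as np

    traps = np.array([0, 0, 3])
    cell_labels = np.array([1, 2, 1])
    # pack each (trap, cell_label) pair into a single complex number
    indices = traps + 1j * cell_labels
    with h5py.File("example.h5", "a") as f:
        dset = f.create_dataset(
            "indices", shape=(0, 2), maxshape=(None, 2), dtype=np.uint16
        )
        # save_complex: append real and imaginary parts as two integer columns
        dset.resize(len(indices), axis=0)
        dset[-len(indices):, 0] = indices.real
        dset[-len(indices):, 1] = indices.imag
        # load_complex: reverse the packing
        recovered = dset[:, 0] + 1j * dset[:, 1]
        assert np.array_equal(recovered, indices)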
+ """ + compression = "gzip" - max_ncells = 2e5 # Could just make this None + max_ncells = 2e5 # Alan: Could just make this None max_tps = 1e3 # Could just make this None - chunk_cells = 25 # The number of cells in a chunk for edge masks + # the number of cells in a chunk for edge masks + chunk_cells = 25 default_tile_size = 117 datatypes = { "centres": ((None, 2), np.uint16), "position": ((None,), np.uint16), "angles": ((None,), h5py.vlen_dtype(np.float32)), "radii": ((None,), h5py.vlen_dtype(np.float32)), - "edgemasks": ((max_ncells, max_tps, tile_size, tile_size), bool), + "edgemasks": ( + (max_ncells, max_tps, default_tile_size, default_tile_size), + bool, + ), "ellipse_dims": ((None, 2), np.float32), "cell_label": ((None,), np.uint16), "trap": ((None,), np.uint16), @@ -189,11 +274,10 @@ class BabyWriter(DynamicWriter): def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) - # Get max_tps and trap info self._traps_initialised = False def __init_trap_info(self): - # Should only be run after the traps have been initialised + # requires traps to have been initialised trap_metadata = load_attributes(self.file, "trap_info") tile_size = trap_metadata.get("tile_size", self.default_tile_size) max_tps = self.metadata["time_settings/ntimepoints"][0] @@ -204,9 +288,8 @@ class BabyWriter(DynamicWriter): self._traps_initialised = True def __init_edgemasks(self, hgroup, edgemasks, current_indices, n_cells): - # Create values dataset - # This holds the edge masks directly and - # Is of shape (n_tps, n_cells, tile_size, tile_size) + # create the values dataset in the h5 file + # holds the edge masks and has shape (n_tps, n_cells, tile_size, tile_size) key = "edgemasks" max_shape, dtype = self.datatypes[key] shape = (n_cells, 1) + max_shape[2:] @@ -220,9 +303,8 @@ class BabyWriter(DynamicWriter): compression=self.compression, ) val_dset[:, 0] = edgemasks - # Create index dataset - # Holds the (trap, cell_id) description used to index into the - # values and is of shape (n_cells, 2) + # create index dataset in the h5 file: + # the (trap, cell_id) description used to index into the values and has shape (n_cells, 2) ix_max_shape = (max_shape[0], 2) ix_shape = (0, 2) ix_dtype = np.uint16 @@ -236,25 +318,24 @@ class BabyWriter(DynamicWriter): save_complex(current_indices, ix_dset) def __append_edgemasks(self, hgroup, edgemasks, current_indices): - # key = "edgemasks" val_dset = hgroup["values"] ix_dset = hgroup["indices"] existing_indices = load_complex(ix_dset) - # Check if there are any new labels + # check if there are any new labels available = np.in1d(current_indices, existing_indices) missing = current_indices[~available] all_indices = np.concatenate([existing_indices, missing]) - # Resizing - t = perf_counter() + # resizing + debug_t = perf_counter() # for timing code for debugging n_tps = val_dset.shape[1] + 1 n_add_cells = len(missing) - # RESIZE DATASET FOR TIME and Cells + # resize dataset for Time and Cells new_shape = (val_dset.shape[0] + n_add_cells, n_tps) + val_dset.shape[ 2: ] val_dset.resize(new_shape) - logging.debug(f"Timing:resizing:{perf_counter() - t}") - # Writing data + logging.debug(f"Timing:resizing:{perf_counter() - debug_t}") + # write data cell_indices = np.where(np.in1d(all_indices, current_indices))[0] for ix, mask in zip(cell_indices, edgemasks): try: @@ -265,75 +346,112 @@ class BabyWriter(DynamicWriter): e, ix, n_tps, val_dset.shape ) ) - # Save the index values + # save the index values save_complex(missing, ix_dset) def write_edgemasks(self, data, keys, 
hgroup): + """ + Write edgemasks to h5 file. + + Parameters + ---------- + data: list of arrays + Data to be written, in the form (trap_ids, cell_labels, edgemasks) + keys: list of str + Names corresponding to the elements of data. + For example: ["trap", "cell_label", "edgemasks"] + hgroup: group object + Group to write to in h5 file. + """ if not self._traps_initialised: self.__init_trap_info() - # DATA is TRAP_IDS, CELL_LABELS, EDGEMASKS in a structured array key = "edgemasks" val_key = "values" - # idx_key = "indices" - # Length of edgemasks traps, cell_labels, edgemasks = data n_cells = len(cell_labels) hgroup = hgroup.require_group(key) + # create complex indices with traps as real part and cell_labels as imaginary part current_indices = np.array(traps) + 1j * np.array(cell_labels) if val_key not in hgroup: self.__init_edgemasks(hgroup, edgemasks, current_indices, n_cells) else: self.__append_edgemasks(hgroup, edgemasks, current_indices) - def write(self, data, overwrite: list, tp: int = None, meta={}): + def write( + self, data: dict, overwrite: list, tp: int = None, meta: dict = {} + ): + """ + Write data from a Baby instance, including edgemasks. + + Parameters + ---------- + data: dict + A dict of datasets and data + overwrite: list of str + A list of datasets to overwrite + tp: int + The time point of interest + meta: dict, optional + Metadata to be written as attributes of the h5 file + """ with h5py.File(self.file, "a") as store: hgroup = store.require_group(self.group) - + # write data for key, value in data.items(): - # We're only saving data that has a pre-defined data-type - self._check_key(key) - try: - if key.startswith("attrs/"): # metadata - key = key.split("/")[1] # First thing after attrs - hgroup.attrs[key] = value - elif key in overwrite: - self._overwrite(value, key, hgroup) - elif key == "edgemasks": - keys = ["trap", "cell_label", "edgemasks"] - value = [data[x] for x in keys] - - edgemask_dset = hgroup.get(key + "/values", None) - if ( - # tp > 0 - edgemask_dset - and tp < edgemask_dset[()].shape[1] - ): - print(f"BabyWriter: Skipping edgemasks in tp {tp}") + if key not in self.datatypes: + raise KeyError( + f"BabyWriter: No defined data type for key {key}" + ) + else: + try: + if key.startswith("attrs/"): + # metadata + key = key.split("/")[1] + hgroup.attrs[key] = value + elif key in overwrite: + # delete and replace existing dataset + self._overwrite(value, key, hgroup) + elif key == "edgemasks": + keys = ["trap", "cell_label", "edgemasks"] + value = [data[x] for x in keys] + edgemask_dset = hgroup.get(key + "/values", None) + if ( + edgemask_dset + and tp < edgemask_dset[()].shape[1] + ): + # data already exists + print( + f"BabyWriter: Skipping edgemasks in tp {tp}" + ) + else: + self.write_edgemasks(value, keys, hgroup) else: - # print(f"BabyWriter: Writing edgemasks in tp {tp}") - self.write_edgemasks(value, keys, hgroup) - else: - self._append(value, key, hgroup) - except Exception as e: - print(key, value) - raise (e) - - # Meta + # append or create new dataset + self._append(value, key, hgroup) + except Exception as e: + print(key, value) + raise (e) + # write metadata for key, value in meta.items(): hgroup.attrs[key] = value - return - class LinearBabyWriter(DynamicWriter): - # TODO make this YAML + """ + Write data stored in a Baby instance to h5 files. + + Assumes the edgemasks are of form ((None, tile_size, tile_size), bool). + """ + + # TODO make this YAML: Alan: why? 
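A small sketch of the index bookkeeping that write_edgemasks and __append_edgemasks perform, with made-up (trap, cell_label) pairs already packed as complex numbers:

    import numpy as np

    # indices already stored in the file and those produced at this time point
    existing_indices = np.array([0 + 1j, 0 + 2j, 3 + 1j])
    current_indices = np.array([0 + 1j, 3 + 1j, 3 + 2j])
    # which current indices already have a row in the values dataset
    available = np.in1d(current_indices, existing_indices)
    missing = current_indices[~available]
    all_indices = np.concatenate([existing_indices, missing])
    # rows to update at this time point, in the order they appear on disk
    cell_rows = np.where(np.in1d(all_indices, current_indices))[0]
    print(missing)     # [3.+2.j]
    print(cell_rows)   # [0 2 3]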
compression = "gzip" + _default_tile_size = 117 datatypes = { "centres": ((None, 2), np.uint16), "position": ((None,), np.uint16), "angles": ((None,), h5py.vlen_dtype(np.float32)), "radii": ((None,), h5py.vlen_dtype(np.float32)), - "edgemasks": ((None, tile_size, tile_size), bool), + "edgemasks": ((None, _default_tile_size, _default_tile_size), bool), "ellipse_dims": ((None, 2), np.float32), "cell_label": ((None,), np.uint16), "trap": ((None,), np.uint16), @@ -344,32 +462,63 @@ class LinearBabyWriter(DynamicWriter): } group = "cell_info" - def write(self, data, overwrite: list, tp=None, meta={}): - # Data is a dictionary, if not, make it one - # Overwrite data is a list + def write( + self, data: dict, overwrite: list, tp: int = None, meta: dict = {} + ): + """ + Check that the data does not already exist before writing. + Parameters + ---------- + data: dict + A dict of datasets and data + overwrite: list of str + A list of datasets to overwrite + tp: int + The time point of interest + meta: dict, optional + Metadata to be written as attributes of the h5 file + """ with h5py.File(self.file, "a") as store: hgroup = store.require_group(self.group) available_tps = hgroup.get("timepoint", None) + # write data if not available_tps or tp not in np.unique(available_tps[()]): super().write(data, overwrite) else: + # data already exists print(f"BabyWriter: Skipping tp {tp}") - + # write metadata for key, value in meta.items(): hgroup.attrs[key] = value class StateWriter(DynamicWriter): + """ + Write information summarising the current state of the pipeline - the 'last_state' dataset in the h5 file. + + Datatypes are specified with the first variable specifying the number of traps and the other specifying the shape of the object. + + """ + datatypes = { + # the highest cell label assigned for each time point "max_lbl": ((None, 1), np.uint16), + # how far back we go for tracking "tp_back": ((None, 1), np.uint16), + # trap labels "trap": ((None, 1), np.int16), + # cell labels "cell_lbls": ((None, 1), np.uint16), + # previous cell features for tracking "prev_feats": ((None, None), np.float32), + # number of images for which a cell has been present "lifetime": ((None, 2), np.uint16), + # probability of a mother-bud relationship given a bud "p_was_bud": ((None, 2), np.float32), + # probability of a mother-bud relationship given a mother "p_is_mother": ((None, 2), np.float32), + # cumulative matrix, over time, of bud assignments "ba_cum": ((None, None), np.float32), } group = "last_state" @@ -377,29 +526,33 @@ class StateWriter(DynamicWriter): @staticmethod def format_field(states: list, field: str): - # Flatten a field in the states list to save as an hdf5 dataset + """Flatten a field in the states list to save as an h5 dataset.""" fields = [pos_state[field] for pos_state in states] return fields @staticmethod def format_values_tpback(states: list, val_name: str): + """Unpacks a dict of state data into tp_back, trap, value.""" + # initialise as empty lists + # Alan: is this initialisation necessary?
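The "angles" and "radii" entries above use variable-length dtypes because each cell outline has a different number of points; a standalone sketch with made-up values:

    import h5py
    import numpy as np

    angles = [
        np.array([0.1, 0.5, 1.2], dtype=np.float32),
        np.array([0.3, 0.9], dtype=np.float32),
    ]
    with h5py.File("example.h5", "a") as f:
        grp = f.require_group("cell_info")
        dset = grp.create_dataset(
            "angles",
            shape=(len(angles),),
            maxshape=(None,),
            dtype=h5py.vlen_dtype(np.float32),
        )
        for i, row in enumerate(angles):
            dset[i] = row
        # each row comes back as its own 1D float32 array
        assert len(dset[0]) == 3 and len(dset[1]) == 2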
tp_back, trap, value = [ [[] for _ in states[0][val_name]] for _ in range(3) ] - + # store results as a list of tuples lbl_tuples = [ (tp_back, trap, cell_label) for trap, state in enumerate(states) for tp_back, value in enumerate(state[val_name]) for cell_label in value ] + # unpack list of tuples to define variables if len(lbl_tuples): tp_back, trap, value = zip(*lbl_tuples) - return tp_back, trap, value @staticmethod def format_values_traps(states: list, val_name: str): + """Format either lifetime, p_was_bud, or p_is_mother variables as a list.""" formatted = np.array( [ (trap, clabel_val) @@ -411,113 +564,106 @@ class StateWriter(DynamicWriter): @staticmethod def pad_if_needed(array: np.ndarray, pad_size: int): + """Pad a 2D array with zeros for large indices.""" padded = np.zeros((pad_size, pad_size)).astype(float) length = len(array) padded[:length, :length] = array - return padded def format_states(self, states: list): + """Re-format state data into a dict of lists, with one element per list per state.""" formatted_state = {"max_lbl": [state["max_lbl"] for state in states]} tp_back, trap, cell_label = self.format_values_tpback( states, "cell_lbls" ) _, _, prev_feats = self.format_values_tpback(states, "prev_feats") - - # Heterogeneous datasets + # store lists in a dict formatted_state["tp_back"] = tp_back formatted_state["trap"] = trap formatted_state["cell_lbls"] = cell_label formatted_state["prev_feats"] = np.array(prev_feats) - - # One entry per cell label - tp_back independent + # one entry per cell label - tp_back independent for val_name in ("lifetime", "p_was_bud", "p_is_mother"): formatted_state[val_name] = self.format_values_traps( states, val_name ) - bacum_max = max([len(state["ba_cum"]) for state in states]) - formatted_state["ba_cum"] = np.array( [ self.pad_if_needed(state["ba_cum"], bacum_max) for state in states ] ) - return formatted_state - def write(self, data, overwrite: Iterable, tp: int = None): - # formatted_data = self.format_states(data) - # super().write(data=formatted_data, overwrite=overwrite) + def write(self, data: dict, overwrite: list, tp: int = 0): + """Write the current state of the pipeline.""" if len(data): last_tp = 0 - if tp is None: - tp = 0 - try: with h5py.File(self.file, "r") as f: gr = f.get(self.group, None) if gr: last_tp = gr.attrs.get("tp", 0) - - # print(f"{ self.file } - tp: {tp}, last_tp: {last_tp}") if tp == 0 or tp > last_tp: - # print(f"Writing timepoint {tp}") + # write formatted_data = self.format_states(data) super().write(data=formatted_data, overwrite=overwrite) with h5py.File(self.file, "a") as f: - # print(f"Writing tp {tp}") + # record that data for the timepoint has been written f[self.group].attrs["tp"] = tp elif tp > 0 and tp <= last_tp: - print(f"BabyWriter: Skipping timepoint {tp}") + # data already present + print(f"StateWriter: Skipping timepoint {tp}") except Exception as e: raise (e) else: - print("Skipping overwriting empty state") + print("Skipping overwriting: no data") #################### Extraction version ############################### class Writer(BridgeH5): - """ - Class in charge of transforming data into compatible formats + """Class to transform data into compatible structures.""" - Decoupling interface from implementation! + # Alan: when is this used? - Parameters - ---------- - filename: str Name of file to write into + def __init__(self, filename, flag=None, compression="gzip"): + """ + Initialise the writer.
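Why pad_if_needed exists: the ba_cum matrices grow as cells appear, so they must be padded to a common size before np.array can stack them; a sketch with made-up matrices:

    import numpy as np

    ba_cums = [np.ones((2, 2)), np.ones((3, 3))]
    pad_size = max(len(m) for m in ba_cums)
    padded = []
    for m in ba_cums:
        out = np.zeros((pad_size, pad_size))
        # copy the existing matrix into the top-left corner
        out[: len(m), : len(m)] = m
        padded.append(out)
    stacked = np.array(padded)
    assert stacked.shape == (2, 3, 3)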
+ + Parameters + ---------- + filename: str + Name of file to write into flag: str, default=None Flag to pass to the default file reader. If None the file remains closed. - compression: str, default=None - Compression method passed on to h5py writing functions (only used for - dataframes and other array-like data.) - """ - - def __init__(self, filename, compression=None): - super().__init__(filename, flag=None) - - if compression is None: - self.compression = "gzip" + compression: str, default="gzip" + Compression method passed on to h5py writing functions (only used for dataframes and other array-like data). + """ + super().__init__(filename, flag=flag) + self.compression = compression def write( self, path: str, data: Iterable = None, - meta: Dict = {}, + meta: dict = {}, overwrite: str = None, ): """ + Write data and metadata to a particular path in the h5 file. + Parameters ---------- path : str - Path inside h5 file to write into. - data : Iterable, default = None - meta : Dict, default = {} - + Path inside h5 file into which to write. + data : Iterable, optional + meta : dict, optional + overwrite: str, optional """ self.id_cache = {} with h5py.File(self.filename, "a") as f: + # Alan, haven't we already opened the h5 file through BridgeH5's init? if overwrite == "overwrite": # TODO refactor overwriting if path in f: del f[path] @@ -529,49 +675,57 @@ class Writer(BridgeH5): # elif overwrite == "skip": # if path in f: # logging.debug("Skipping dataset {}".format(path)) - logging.debug( "{} {} to {} and {} metadata fields".format( overwrite, type(data), path, len(meta) ) ) + # write data if data is not None: self.write_dset(f, path, data) + # write metadata if meta: for attr, metadata in meta.items(): self.write_meta(f, path, attr, data=metadata) def write_dset(self, f: h5py.File, path: str, data: Iterable): + """Write data in different ways depending on its type to an open h5 file.""" + # data is a dataframe if isinstance(data, pd.DataFrame): self.write_pd(f, path, data, compression=self.compression) + # data is a multi-index dataframe elif isinstance(data, pd.MultiIndex): + # Alan: should we still not compress here?
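A minimal usage sketch of this extraction-side Writer; the signal path, index values and data are made up:

    import numpy as np
    import pandas as pd
    from agora.io.writer import Writer

    # rows are (trap, cell_label), columns are time points
    index = pd.MultiIndex.from_tuples(
        [(0, 1), (0, 2), (3, 1)], names=["trap", "cell_label"]
    )
    df = pd.DataFrame(np.random.rand(3, 2), index=index, columns=[0, 1])
    writer = Writer("example.h5")
    # writes values, trap, cell_label and timepoint datasets under the given path
    writer.write("extraction/GFP/max/median", data=df, meta={"source": "example"})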
self.write_index(f, path, data) # , compression=self.compression) + # data is a dictionary of dataframes elif isinstance(data, Dict) and np.all( [isinstance(x, pd.DataFrame) for x in data.values] ): for k, df in data.items(): self.write_dset(f, path + f"/{k}", df) + # data is an iterable elif isinstance(data, Iterable): self.write_arraylike(f, path, data) + # data is a float or integer else: self.write_atomic(data, f, path) def write_meta(self, f: h5py.File, path: str, attr: str, data: Iterable): + """Write metadata to an open h5 file.""" obj = f.require_group(path) - obj.attrs[attr] = data @staticmethod def write_arraylike(f: h5py.File, path: str, data: Iterable, **kwargs): + """Write an iterable.""" if path in f: del f[path] - narray = np.array(data) - - chunks = None if narray.any(): chunks = (1, *narray.shape[1:]) - + else: + chunks = None + # create dset dset = f.create_dataset( path, shape=narray.shape, @@ -579,10 +733,12 @@ class Writer(BridgeH5): dtype="int", compression=kwargs.get("compression", None), ) + # add data to dset dset[()] = narray @staticmethod def write_index(f, path, pd_index, **kwargs): + """Write a multi-index dataframe.""" f.require_group(path) # TODO check if we can remove this for i, name in enumerate(pd_index.names): ids = pd_index.get_level_values(i) @@ -597,12 +753,14 @@ class Writer(BridgeH5): indices[()] = ids def write_pd(self, f, path, df, **kwargs): + """Write a dataframe.""" values_path = ( path + "values" if path.endswith("/") else path + "/values" ) if path not in f: - max_ncells = 2e5 + # create dataset and write data + max_ncells = 2e5 max_tps = 1e3 f.create_dataset( name=values_path, @@ -616,6 +774,7 @@ class Writer(BridgeH5): dset = f[values_path] dset[()] = df.values + # create dateset and write indices if not len(df): # Only write more if not empty return None @@ -630,7 +789,7 @@ class Writer(BridgeH5): ) dset = f[indices_path] dset[()] = df.index.get_level_values(level=name).tolist() - + # create dataset and write columns if ( df.columns.dtype == np.int or df.columns.dtype == np.dtype("uint") @@ -648,9 +807,11 @@ class Writer(BridgeH5): else: f[path].attrs["columns"] = df.columns.tolist() else: + + # path exists dset = f[values_path] - # Filter out repeated timepoints + # filter out repeated timepoints new_tps = set(df.columns) if path + "/timepoint" in f: new_tps = new_tps.difference(f[path + "/timepoint"][()]) @@ -659,16 +820,18 @@ class Writer(BridgeH5): if ( not hasattr(self, "id_cache") or df.index.nlevels not in self.id_cache - ): # Use cache dict to store previously-obtained indices + ): + # use cache dict to store previously obtained indices self.id_cache[df.index.nlevels] = {} existing_ids = self.get_existing_ids( f, [path + "/" + x for x in df.index.names] ) - # Split indices in existing and additional + # split indices in existing and additional new = df.index.tolist() if ( df.index.nlevels == 1 - ): # Cover for cases with a single index + ): + # cover cases with a single index new = [(x,) for x in df.index.tolist()] ( found_multis, @@ -681,7 +844,7 @@ class Writer(BridgeH5): locate_indices(existing_ids, found_multis) ) - # We must sort our indices for h5py indexing + # sort indices for h5 indexing incremental_existing = np.argsort(found_indices) self.id_cache[df.index.nlevels][ "found_indices" @@ -706,7 +869,7 @@ class Writer(BridgeH5): ].values ncells, ntps = f[values_path].shape - # Add found cells + # add found cells dset.resize(dset.shape[1] + df.shape[1], axis=1) dset[:, ntps:] = np.nan @@ -715,13 +878,13 @@ class 
Writer(BridgeH5): "found_indices" ] - # Cover for case when all labels are new + # case when all labels are new if found_indices_sorted.any(): # h5py does not allow bidimensional indexing, # so we have to iterate over the columns for i, tp in enumerate(df.columns): dset[found_indices_sorted, tp] = existing_values[:, i] - # Add new cells + # add new cells n_newcells = len( self.id_cache[df.index.nlevels]["additional_multis"] ) @@ -756,24 +919,20 @@ class Writer(BridgeH5): @staticmethod def get_existing_ids(f, paths): - # Fetch indices and convert them to a (nentries, nlevels) ndarray + """Fetch indices and convert them to a (nentries, nlevels) ndarray.""" return np.array([f[path][()] for path in paths]).T @staticmethod def find_ids(existing, new): - # Compare two tuple sets and return the intersection and difference - # (elements in the 'new' set not in 'existing') + """Compare two tuple sets and return the intersection and difference (elements in the 'new' set not in 'existing').""" set_existing = set([tuple(*x) for x in zip(existing.tolist())]) existing_cells = np.array(list(set_existing.intersection(new))) new_cells = np.array(list(set(new).difference(set_existing))) + return existing_cells, new_cells + - return ( - existing_cells, - new_cells, - ) -# @staticmethod def locate_indices(existing, new): if new.any(): if new.shape[1] > 1: @@ -794,7 +953,7 @@ def locate_indices(existing, new): def _tuple_or_int(x): - # Convert tuple to int if it only contains one value + """Convert tuple to int if it only contains one value.""" if len(x) == 1: return x[0] else:
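To make the set logic in find_ids concrete, a small example with made-up (trap, cell_label) rows; the original returns unsorted arrays, and sorting here is only for a stable printout:

    import numpy as np

    # index rows already in the file and rows in the incoming dataframe
    existing = np.array([[0, 1], [0, 2], [3, 1]])
    new = [(0, 2), (3, 1), (3, 2)]
    set_existing = set(tuple(x) for x in existing.tolist())
    # rows present in both, and rows that must be appended as new cells
    found = np.array(sorted(set_existing.intersection(new)))
    additional = np.array(sorted(set(new).difference(set_existing)))
    print(found)       # [[0 2] [3 1]]
    print(additional)  # [[3 2]]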