diff --git a/src/agora/io/bridge.py b/src/agora/io/bridge.py
index 4a2f6093e7d8aaf4ed8e0e0f8abfcfaffb796449..478408f37d1f0724a232876f0f5ce0da5b02b585 100644
--- a/src/agora/io/bridge.py
+++ b/src/agora/io/bridge.py
@@ -1,5 +1,5 @@
 """
-Tools to interact with hdf5 files and handle data consistently.
+Tools to interact with h5 files and handle data consistently.
 """
 import collections
 from itertools import chain, groupby, product
@@ -13,26 +13,28 @@ import yaml
 
 
 class BridgeH5:
     """
-    Base class to interact with h5 data stores.
-    It also contains functions useful to predict how long should segmentation take.
+    Base class to interact with h5 files.
+
+    It includes functions that predict how long segmentation will take.
    """
 
     def __init__(self, filename, flag="r"):
+        """Initialise with the name of the h5 file."""
         self.filename = filename
         if flag is not None:
             self._hdf = h5py.File(filename, flag)
-            self._filecheck
 
     def _filecheck(self):
         assert "cell_info" in self._hdf, "Invalid file. No 'cell_info' found."
 
     def close(self):
+        """Close the h5 file."""
         self._hdf.close()
 
     @property
     def meta_h5(self) -> t.Dict[str, t.Any]:
-        # Return metadata as indicated in h5 file
+        """Return metadata, defining it if necessary."""
         if not hasattr(self, "_meta_h5"):
             with h5py.File(self.filename, "r") as f:
                 self._meta_h5 = dict(f.attrs)
@@ -44,24 +46,24 @@ class BridgeH5:
 
     @staticmethod
     def get_consecutives(tree, nstepsback):
-        # Receives a sorted tree and returns the keys of consecutive elements
-        vals = {k: np.array(list(v)) for k, v in tree.items()}  # get tp level
+        """Receive a sorted tree and return the keys of consecutive elements."""
+        # get tp level
+        vals = {k: np.array(list(v)) for k, v in tree.items()}
+        # get indices of consecutive elements
         where_consec = [
             {
                 k: np.where(np.subtract(v[n + 1 :], v[: -n - 1]) == n + 1)[0]
                 for k, v in vals.items()
             }
             for n in range(nstepsback)
-        ]  # get indices of consecutive elements
+        ]
         return where_consec
 
     def get_npairs(self, nstepsback=2, tree=None):
         if tree is None:
             tree = self.cell_tree
-
         consecutive = self.get_consecutives(tree, nstepsback=nstepsback)
         flat_tree = flatten(tree)
-
         n_predictions = 0
         for i, d in enumerate(consecutive, 1):
             flat = list(chain(*[product([k], list(v)) for k, v in d.items()]))
@@ -70,55 +72,49 @@ class BridgeH5:
             n_predictions += len(flat_tree.get(p[0], [])) * len(
                 flat_tree.get(p[1], [])
             )
-
         return n_predictions
 
     def get_npairs_over_time(self, nstepsback=2):
         tree = self.cell_tree
         npairs = []
-        for t in self._hdf["cell_info"]["processed_timepoints"][()]:
+        for tp in self._hdf["cell_info"]["processed_timepoints"][()]:
             tmp_tree = {
-                k: {k2: v2 for k2, v2 in v.items() if k2 <= t}
+                k: {k2: v2 for k2, v2 in v.items() if k2 <= tp}
                 for k, v in tree.items()
             }
             npairs.append(self.get_npairs(tree=tmp_tree))
-
         return np.diff(npairs)
 
     def get_info_tree(
         self, fields: Union[tuple, list] = ("trap", "timepoint", "cell_label")
     ):
         """
-        Returns traps, time points and labels for this position in form of a tree
-        in the hierarchy determined by the argument fields. Note that it is
-        compressed to non-empty elements and timepoints.
+        Return traps, time points and labels for this position in the form of a tree in the hierarchy determined by the argument fields.
+
+        Note that it is compressed to non-empty elements and timepoints.
 
         Default hierarchy is:
         - trap
         - time point
        - cell label
 
-        This function currently produces trees of depth 3, but it can easily be
-        extended for deeper trees if needed (e.g. considering groups,
-        chambers and/or positions).
+        This function currently produces trees of depth 3, but it can easily be extended for deeper trees if needed (e.g. considering groups, chambers and/or positions).
 
         Parameters
         ----------
-        fields: Fields to fetch from 'cell_info' inside the hdf5 storage
+        fields: list of str
+            Fields to fetch from 'cell_info' inside the h5 file.
 
         Returns
         ----------
-        Nested dictionary where keys (or branches) are the upper levels
-            and the leaves are the last element of :fields:.
+        Nested dictionary where keys (or branches) are the upper levels and the leaves are the last element of :fields:.
         """
         zipped_info = (*zip(*[self._hdf["cell_info"][f][()] for f in fields]),)
-
         return recursive_groupsort(zipped_info)
 
 
 def groupsort(iterable: Union[tuple, list]):
-    # Sorts iterable and returns a dictionary where the values are grouped by the first element.
-
+    """Sort an iterable and return a dictionary where the values are grouped by the first element."""
     iterable = sorted(iterable, key=lambda x: x[0])
     grouped = {
         k: [x[1:] for x in v] for k, v in groupby(iterable, lambda x: x[0])
     }
@@ -127,17 +123,18 @@ def groupsort(iterable: Union[tuple, list]):
 
 
 def recursive_groupsort(iterable):
-    # Recursive extension of groupsort
+    """Recursive extension of groupsort."""
     if len(iterable[0]) > 1:
         return {
             k: recursive_groupsort(v) for k, v in groupsort(iterable).items()
         }
-    else:  # Only two elements in list
+    else:
+        # only two elements in list
         return [x[0] for x in iterable]
 
 
 def flatten(d, parent_key="", sep="_"):
-    """Flatten nested dict. Adapted from https://stackoverflow.com/a/6027615"""
+    """Flatten nested dict. Adapted from https://stackoverflow.com/a/6027615."""
     items = []
     for k, v in d.items():
         new_key = parent_key + (k,) if parent_key else (k,)
@@ -149,18 +146,19 @@ def flatten(d, parent_key="", sep="_"):
 
 
 def attrs_from_h5(fpath: str):
-    """Return attributes as dict from h5 file"""
+    """Return attributes as dict from an h5 file."""
     with h5py.File(fpath, "r") as f:
         return dict(f.attrs)
 
 
 def parameters_from_h5(fpath: str):
+    """Return parameters from an h5 file."""
     attrs = attrs_from_h5(fpath)
     return yaml.safe_load(attrs["parameters"])
 
 
 def image_creds_from_h5(fpath: str):
-    """Return image id and server credentials from h5"""
+    """Return image id and server credentials from an h5 file."""
     attrs = attrs_from_h5(fpath)
     return (
         attrs["image_id"],
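
# ---------------------------------------------------------------------------
# Illustrative sketch, not part of the patch: one way the BridgeH5 helpers in
# bridge.py above might be used. The file name "position001.h5" is hypothetical;
# the import path follows the file's location under src/agora/io/bridge.py.
from agora.io.bridge import BridgeH5, attrs_from_h5

bridge = BridgeH5("position001.h5")  # flag="r" by default, so opened read-only
metadata = bridge.meta_h5  # dict of the file's top-level attributes
tree = bridge.get_info_tree()  # nested dict: trap -> time point -> cell labels
bridge.close()

# read the attributes without keeping the file open
attrs = attrs_from_h5("position001.h5")
# ---------------------------------------------------------------------------
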
diff --git a/src/agora/io/writer.py b/src/agora/io/writer.py
index 66fec1f40d69c08412e246be4e21cd20d6fdb410..6c07e40609f188006fe703debb605dd14a2426a3 100644
--- a/src/agora/io/writer.py
+++ b/src/agora/io/writer.py
@@ -620,22 +620,23 @@ class StateWriter(DynamicWriter):
 
 #################### Extraction version ###############################
 class Writer(BridgeH5):
-    """
-    Class in charge of transforming data into compatible structures.
+    """Class to transform data into compatible structures."""
 
-    Decoupling interface from implementation!
-
-    Parameters
-    ----------
-    filename: str
-        Name of file to write into
-    flag: str, default=None
-        Flag to pass to the default file reader. If None the file remains closed.
-    compression: str, default="gzip"
-        Compression method passed on to h5py writing functions (only used for dataframes and other array-like data).
-    """
+    # Alan: when is this used?
 
     def __init__(self, filename, flag=None, compression="gzip"):
+        """
+        Initialise the writer.
+
+        Parameters
+        ----------
+        filename: str
+            Name of file to write into
+        flag: str, default=None
+            Flag to pass to the default file reader. If None the file remains closed.
+        compression: str, default="gzip"
+            Compression method passed on to h5py writing functions (only used for dataframes and other array-like data).
+        """
         super().__init__(filename, flag=flag)
         self.compression = compression
 
@@ -647,16 +648,19 @@ class Writer(BridgeH5):
         overwrite: str = None,
     ):
         """
+        Write data and metadata to a particular path in the h5 file.
+
         Parameters
         ----------
         path : str
-            Path inside h5 file to write into.
-        data : Iterable, default = None
-        meta : dict, default = {}
-
+            Path inside h5 file into which to write.
+        data : Iterable, optional
+        meta : dict, optional
+        overwrite: str, optional
         """
         self.id_cache = {}
         with h5py.File(self.filename, "a") as f:
+            # Alan, haven't we already opened the h5 file through BridgeH5's init?
             if overwrite == "overwrite":  # TODO refactor overwriting
                 if path in f:
                     del f[path]
@@ -668,49 +672,57 @@ class Writer(BridgeH5):
             # elif overwrite == "skip":
             #     if path in f:
             #         logging.debug("Skipping dataset {}".format(path))
-
             logging.debug(
                 "{} {} to {} and {} metadata fields".format(
                     overwrite, type(data), path, len(meta)
                 )
             )
+            # write data
             if data is not None:
                 self.write_dset(f, path, data)
+            # write metadata
             if meta:
                 for attr, metadata in meta.items():
                     self.write_meta(f, path, attr, data=metadata)
 
     def write_dset(self, f: h5py.File, path: str, data: Iterable):
+        """Write data in different ways depending on its type to an open h5 file."""
+        # data is a dataframe
         if isinstance(data, pd.DataFrame):
             self.write_pd(f, path, data, compression=self.compression)
+        # data is a multi-index dataframe
         elif isinstance(data, pd.MultiIndex):
+            # Alan: should we still not compress here?
             self.write_index(f, path, data)  # , compression=self.compression)
+        # data is a dictionary of dataframes
         elif isinstance(data, Dict) and np.all(
             [isinstance(x, pd.DataFrame) for x in data.values]
         ):
             for k, df in data.items():
                 self.write_dset(f, path + f"/{k}", df)
+        # data is an iterable
         elif isinstance(data, Iterable):
             self.write_arraylike(f, path, data)
+        # data is a float or integer
         else:
             self.write_atomic(data, f, path)
 
     def write_meta(self, f: h5py.File, path: str, attr: str, data: Iterable):
+        """Write metadata to an open h5 file."""
         obj = f.require_group(path)
-
         obj.attrs[attr] = data
 
     @staticmethod
     def write_arraylike(f: h5py.File, path: str, data: Iterable, **kwargs):
+        """Write an iterable."""
         if path in f:
             del f[path]
-
         narray = np.array(data)
-
-        chunks = None
         if narray.any():
             chunks = (1, *narray.shape[1:])
-
+        else:
+            chunks = None
+        # create dset
         dset = f.create_dataset(
             path,
             shape=narray.shape,
@@ -718,10 +730,12 @@ class Writer(BridgeH5):
             dtype="int",
             compression=kwargs.get("compression", None),
         )
+        # add data to dset
         dset[()] = narray
 
     @staticmethod
     def write_index(f, path, pd_index, **kwargs):
+        """Write a multi-index dataframe."""
         f.require_group(path)  # TODO check if we can remove this
         for i, name in enumerate(pd_index.names):
             ids = pd_index.get_level_values(i)
@@ -736,12 +750,14 @@ class Writer(BridgeH5):
         indices[()] = ids
 
     def write_pd(self, f, path, df, **kwargs):
+        """Write a dataframe."""
         values_path = (
             path + "values" if path.endswith("/") else path + "/values"
         )
         if path not in f:
-            max_ncells = 2e5
+            # create dataset and write data
+            max_ncells = 2e5
             max_tps = 1e3
             f.create_dataset(
                 name=values_path,
@@ -754,7 +770,7 @@ class Writer(BridgeH5):
             )
             dset = f[values_path]
             dset[()] = df.values
-
+            # create dataset and write indices
             for name in df.index.names:
                 indices_path = "/".join((path, name))
                 f.create_dataset(
@@ -766,7 +782,7 @@ class Writer(BridgeH5):
                 )
                 dset = f[indices_path]
                 dset[()] = df.index.get_level_values(level=name).tolist()
-
+            # create dataset and write columns
             if (
                 df.columns.dtype == np.int
                 or df.columns.dtype == np.dtype("uint")
@@ -784,9 +800,11 @@ class Writer(BridgeH5):
             else:
                 f[path].attrs["columns"] = df.columns.tolist()
         else:
+
+            # path exists
             dset = f[values_path]
 
-            # Filter out repeated timepoints
+            # filter out repeated timepoints
             new_tps = set(df.columns)
             if path + "/timepoint" in f:
                 new_tps = new_tps.difference(f[path + "/timepoint"][()])
@@ -795,16 +813,18 @@ class Writer(BridgeH5):
             if (
                 not hasattr(self, "id_cache")
                 or df.index.nlevels not in self.id_cache
-            ):  # Use cache dict to store previously-obtained indices
+            ):
+                # use cache dict to store previously obtained indices
                 self.id_cache[df.index.nlevels] = {}
                 existing_ids = self.get_existing_ids(
                     f, [path + "/" + x for x in df.index.names]
                 )
-                # Split indices in existing and additional
+                # split indices in existing and additional
                 new = df.index.tolist()
                 if (
                     df.index.nlevels == 1
-                ):  # Cover for cases with a single index
+                ):
+                    # cover cases with a single index
                     new = [(x,) for x in df.index.tolist()]
                 (
                     found_multis,
@@ -817,7 +837,7 @@ class Writer(BridgeH5):
                     locate_indices(existing_ids, found_multis)
                 )
-                # We must sort our indices for h5py indexing
+                # sort indices for h5 indexing
                 incremental_existing = np.argsort(found_indices)
                 self.id_cache[df.index.nlevels][
                     "found_indices"
                 ]
@@ -842,7 +862,7 @@ class Writer(BridgeH5):
                 ].values
                 ncells, ntps = f[values_path].shape
 
-                # Add found cells
+                # add found cells
                 dset.resize(dset.shape[1] + df.shape[1], axis=1)
                 dset[:, ntps:] = np.nan
 
@@ -851,13 +871,13 @@ class Writer(BridgeH5):
                 found_indices_sorted = self.id_cache[df.index.nlevels][
                     "found_indices"
                 ]
-                # Cover for case when all labels are new
+                # case when all labels are new
                 if found_indices_sorted.any():
                     # h5py does not allow bidimensional indexing,
                     # so we have to iterate over the columns
                     for i, tp in enumerate(df.columns):
                         dset[found_indices_sorted, tp] = existing_values[:, i]
-                # Add new cells
+                # add new cells
                 n_newcells = len(
                     self.id_cache[df.index.nlevels]["additional_multis"]
                 )
@@ -892,24 +912,20 @@ class Writer(BridgeH5):
 
     @staticmethod
     def get_existing_ids(f, paths):
-        # Fetch indices and convert them to a (nentries, nlevels) ndarray
+        """Fetch indices and convert them to a (nentries, nlevels) ndarray."""
         return np.array([f[path][()] for path in paths]).T
 
     @staticmethod
     def find_ids(existing, new):
-        # Compare two tuple sets and return the intersection and difference
-        # (elements in the 'new' set not in 'existing')
+        """Compare two tuple sets and return the intersection and difference (elements in the 'new' set not in 'existing')."""
         set_existing = set([tuple(*x) for x in zip(existing.tolist())])
         existing_cells = np.array(list(set_existing.intersection(new)))
         new_cells = np.array(list(set(new).difference(set_existing)))
+        return existing_cells, new_cells
+
-        return (
-            existing_cells,
-            new_cells,
-        )
 
 
-# @staticmethod
 def locate_indices(existing, new):
     if new.any():
         if new.shape[1] > 1:
@@ -930,7 +946,7 @@ def locate_indices(existing, new):
 
 
 def _tuple_or_int(x):
-    # Convert tuple to int if it only contains one value
+    """Convert tuple to int if it only contains one value."""
     if len(x) == 1:
         return x[0]
     else: