Commit 87e943f5 authored by pswain

further comments for writer and bridge

parent 6feffce2
""" """
Tools to interact with hdf5 files and handle data consistently. Tools to interact with h5 files and handle data consistently.
""" """
import collections import collections
from itertools import chain, groupby, product from itertools import chain, groupby, product
...@@ -13,26 +13,28 @@ import yaml ...@@ -13,26 +13,28 @@ import yaml
class BridgeH5:
    """
    Base class to interact with h5 files.

    It includes functions that predict how long segmentation will take.
    """

    def __init__(self, filename, flag="r"):
        """Initialise with the name of the h5 file."""
        self.filename = filename
        if flag is not None:
            self._hdf = h5py.File(filename, flag)
            self._filecheck()
    def _filecheck(self):
        assert "cell_info" in self._hdf, "Invalid file. No 'cell_info' found."

    def close(self):
        """Close the h5 file."""
        self._hdf.close()

    @property
    def meta_h5(self) -> t.Dict[str, t.Any]:
        """Return metadata, defining it if necessary."""
        if not hasattr(self, "_meta_h5"):
            with h5py.File(self.filename, "r") as f:
                self._meta_h5 = dict(f.attrs)
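    # A hedged usage sketch, not from the source; "example.h5" stands for a
    # hypothetical file produced by the pipeline with a "cell_info" group:
    #
    #   bridge = BridgeH5("example.h5")
    #   print(bridge.meta_h5)  # file-level attributes as a dict
    #   bridge.close()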
@@ -44,24 +46,24 @@ class BridgeH5:
    @staticmethod
    def get_consecutives(tree, nstepsback):
        """Receives a sorted tree and returns the keys of consecutive elements."""
        # get tp level
        vals = {k: np.array(list(v)) for k, v in tree.items()}
        # get indices of consecutive elements
        where_consec = [
            {
                k: np.where(np.subtract(v[n + 1 :], v[: -n - 1]) == n + 1)[0]
                for k, v in vals.items()
            }
            for n in range(nstepsback)
        ]
        return where_consec
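    # For intuition, a toy (hypothetical) tree mapping trap -> sorted
    # timepoints, and what get_consecutives returns for it:
    #
    #   tree = {0: [1, 2, 3, 5]}
    #   BridgeH5.get_consecutives(tree, nstepsback=2)
    #   # n=0: gaps of 1 between neighbours -> {0: array([0, 1])}
    #   # n=1: gaps of 2 two steps apart    -> {0: array([0])}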
    def get_npairs(self, nstepsback=2, tree=None):
        if tree is None:
            tree = self.cell_tree
        consecutive = self.get_consecutives(tree, nstepsback=nstepsback)
        flat_tree = flatten(tree)
        n_predictions = 0
        for i, d in enumerate(consecutive, 1):
            flat = list(chain(*[product([k], list(v)) for k, v in d.items()]))
@@ -70,55 +72,49 @@ class BridgeH5:
            n_predictions += len(flat_tree.get(p[0], [])) * len(
                flat_tree.get(p[1], [])
            )
        return n_predictions

    def get_npairs_over_time(self, nstepsback=2):
        tree = self.cell_tree
        npairs = []
        for tp in self._hdf["cell_info"]["processed_timepoints"][()]:
            tmp_tree = {
                k: {k2: v2 for k2, v2 in v.items() if k2 <= tp}
                for k, v in tree.items()
            }
            npairs.append(self.get_npairs(tree=tmp_tree))
        return np.diff(npairs)
    def get_info_tree(
        self, fields: Union[tuple, list] = ("trap", "timepoint", "cell_label")
    ):
        """
        Return traps, time points and labels for this position in the form
        of a tree in the hierarchy determined by the argument fields.

        Note that it is compressed to non-empty elements and timepoints.

        Default hierarchy is:
        - trap
        - time point
        - cell label

        This function currently produces trees of depth 3, but it can easily
        be extended for deeper trees if needed (e.g. considering groups,
        chambers and/or positions).

        Parameters
        ----------
        fields: list of strs
            Fields to fetch from 'cell_info' inside the h5 file.

        Returns
        -------
        Nested dictionary where keys (or branches) are the upper levels
        and the leaves are the last element of :fields:.
        """
        zipped_info = (*zip(*[self._hdf["cell_info"][f][()] for f in fields]),)
        return recursive_groupsort(zipped_info)
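    # A sketch of the structure returned for the default fields, with
    # illustrative (made-up) values: trap -> timepoint -> [cell labels].
    #
    #   bridge = BridgeH5("example.h5")  # hypothetical file
    #   bridge.get_info_tree()
    #   # {0: {0: [1, 2], 1: [1, 2, 3]}, 1: {0: [1]}}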
def groupsort(iterable: Union[tuple, list]):
    """Sorts iterable and returns a dictionary where the values are grouped by the first element."""
    iterable = sorted(iterable, key=lambda x: x[0])
    grouped = {
        k: [x[1:] for x in v] for k, v in groupby(iterable, lambda x: x[0])
@@ -127,17 +123,18 @@ def groupsort(iterable: Union[tuple, list]):
def recursive_groupsort(iterable):
    """Recursive extension of groupsort."""
    if len(iterable[0]) > 1:
        return {
            k: recursive_groupsort(v) for k, v in groupsort(iterable).items()
        }
    else:
        # the leaves: each tuple has a single element left
        return [x[0] for x in iterable]
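# A small worked example (values are illustrative) of how groupsort and
# recursive_groupsort nest tuples into a tree:
#
#   pairs = [(0, 1, "a"), (0, 2, "b"), (1, 1, "c")]
#   groupsort(pairs)            # {0: [(1, "a"), (2, "b")], 1: [(1, "c")]}
#   recursive_groupsort(pairs)  # {0: {1: ["a"], 2: ["b"]}, 1: {1: ["c"]}}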
def flatten(d, parent_key="", sep="_"):
    """Flatten nested dict. Adapted from https://stackoverflow.com/a/6027615."""
    items = []
    for k, v in d.items():
        new_key = parent_key + (k,) if parent_key else (k,)
@@ -149,18 +146,19 @@ def flatten(d, parent_key="", sep="_"):
def attrs_from_h5(fpath: str):
    """Return attributes as dict from an h5 file."""
    with h5py.File(fpath, "r") as f:
        return dict(f.attrs)

def parameters_from_h5(fpath: str):
    """Return parameters from an h5 file."""
    attrs = attrs_from_h5(fpath)
    return yaml.safe_load(attrs["parameters"])
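# Hedged usage sketch; the file name is hypothetical, and the file is assumed
# to store pipeline parameters as a YAML string in its "parameters" attribute,
# as parameters_from_h5 expects:
#
#   attrs = attrs_from_h5("example.h5")        # all file-level attributes
#   params = parameters_from_h5("example.h5")  # YAML-decoded parameters dict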
def image_creds_from_h5(fpath: str):
    """Return image id and server credentials from an h5 file."""
    attrs = attrs_from_h5(fpath)
    return (
        attrs["image_id"],
...
@@ -620,22 +620,23 @@ class StateWriter(DynamicWriter):

#################### Extraction version ###############################
class Writer(BridgeH5):
""" """Class to transform data into compatible structures."""
Class in charge of transforming data into compatible structures.
Decoupling interface from implementation! # Alan: when is this used?
Parameters
----------
filename: str
Name of file to write into
flag: str, default=None
Flag to pass to the default file reader. If None the file remains closed.
compression: str, default="gzip"
Compression method passed on to h5py writing functions (only used for dataframes and other array-like data).
"""
    def __init__(self, filename, flag=None, compression="gzip"):
        """
        Initialise the writer.

        Parameters
        ----------
        filename: str
            Name of file to write into.
        flag: str, default=None
            Flag to pass to the default file reader. If None, the file remains closed.
        compression: str, default="gzip"
            Compression method passed on to h5py writing functions
            (only used for dataframes and other array-like data).
        """
        super().__init__(filename, flag=flag)
        self.compression = compression

@@ -647,16 +648,19 @@ class Writer(BridgeH5):
        overwrite: str = None,
    ):
        """
        Write data and metadata to a particular path in the h5 file.

        Parameters
        ----------
        path : str
            Path inside h5 file into which to write.
        data : Iterable, optional
        meta : dict, optional
        overwrite: str, optional
        """
        self.id_cache = {}
        with h5py.File(self.filename, "a") as f:
            # Alan, haven't we already opened the h5 file through BridgeH5's init?
            if overwrite == "overwrite":  # TODO refactor overwriting
                if path in f:
                    del f[path]
@@ -668,49 +672,57 @@ class Writer(BridgeH5):
            # elif overwrite == "skip":
            #     if path in f:
            #         logging.debug("Skipping dataset {}".format(path))
            logging.debug(
                "{} {} to {} and {} metadata fields".format(
                    overwrite, type(data), path, len(meta)
                )
            )
            # write data
            if data is not None:
                self.write_dset(f, path, data)
            # write metadata
            if meta:
                for attr, metadata in meta.items():
                    self.write_meta(f, path, attr, data=metadata)
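    # A hedged usage sketch (the file name, path, and index names are
    # illustrative, not from the source):
    #
    #   writer = Writer("example.h5")
    #   index = pd.MultiIndex.from_tuples(
    #       [(0, 1)], names=["trap", "cell_label"]
    #   )
    #   df = pd.DataFrame([[1.0, 2.0]], index=index, columns=[0, 1])
    #   writer.write(
    #       "extraction/example", data=df, meta={"source": "sketch"},
    #       overwrite="overwrite",
    #   )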
    def write_dset(self, f: h5py.File, path: str, data: Iterable):
        """Write data to an open h5 file in different ways depending on its type."""
        # data is a dataframe
        if isinstance(data, pd.DataFrame):
            self.write_pd(f, path, data, compression=self.compression)
        # data is a multi-index
        elif isinstance(data, pd.MultiIndex):
            # Alan: should we still not compress here?
            self.write_index(f, path, data)  # , compression=self.compression)
        # data is a dictionary of dataframes
        elif isinstance(data, Dict) and np.all(
            [isinstance(x, pd.DataFrame) for x in data.values()]
        ):
            for k, df in data.items():
                self.write_dset(f, path + f"/{k}", df)
        # data is an iterable
        elif isinstance(data, Iterable):
            self.write_arraylike(f, path, data)
        # data is a float or integer
        else:
            self.write_atomic(data, f, path)

    def write_meta(self, f: h5py.File, path: str, attr: str, data: Iterable):
        """Write metadata to an open h5 file."""
        obj = f.require_group(path)
        obj.attrs[attr] = data
    @staticmethod
    def write_arraylike(f: h5py.File, path: str, data: Iterable, **kwargs):
        """Write an iterable."""
        if path in f:
            del f[path]
        narray = np.array(data)
        if narray.any():
            chunks = (1, *narray.shape[1:])
        else:
            chunks = None
        # create dset
        dset = f.create_dataset(
            path,
            shape=narray.shape,
@@ -718,10 +730,12 @@ class Writer(BridgeH5):
dtype="int", dtype="int",
compression=kwargs.get("compression", None), compression=kwargs.get("compression", None),
) )
# add data to dset
dset[()] = narray dset[()] = narray
@staticmethod @staticmethod
def write_index(f, path, pd_index, **kwargs): def write_index(f, path, pd_index, **kwargs):
"""Write a multi-index dataframe."""
f.require_group(path) # TODO check if we can remove this f.require_group(path) # TODO check if we can remove this
for i, name in enumerate(pd_index.names): for i, name in enumerate(pd_index.names):
ids = pd_index.get_level_values(i) ids = pd_index.get_level_values(i)
@@ -736,12 +750,14 @@ class Writer(BridgeH5):
            indices[()] = ids

    def write_pd(self, f, path, df, **kwargs):
        """Write a dataframe."""
        values_path = (
            path + "values" if path.endswith("/") else path + "/values"
        )
        if path not in f:
            # create dataset and write data
            max_ncells = 2e5
            max_tps = 1e3
            f.create_dataset(
                name=values_path,
@@ -754,7 +770,7 @@ class Writer(BridgeH5):
            )
            dset = f[values_path]
            dset[()] = df.values
            # create dataset and write indices
            for name in df.index.names:
                indices_path = "/".join((path, name))
                f.create_dataset(
@@ -766,7 +782,7 @@ class Writer(BridgeH5):
                )
                dset = f[indices_path]
                dset[()] = df.index.get_level_values(level=name).tolist()
            # create dataset and write columns
            if (
                df.columns.dtype == np.int
                or df.columns.dtype == np.dtype("uint")
@@ -784,9 +800,11 @@ class Writer(BridgeH5):
            else:
                f[path].attrs["columns"] = df.columns.tolist()
        else:
            # path exists
            dset = f[values_path]

            # filter out repeated timepoints
            new_tps = set(df.columns)
            if path + "/timepoint" in f:
                new_tps = new_tps.difference(f[path + "/timepoint"][()])
@@ -795,16 +813,18 @@ class Writer(BridgeH5):
            if (
                not hasattr(self, "id_cache")
                or df.index.nlevels not in self.id_cache
            ):
                # use cache dict to store previously obtained indices
                self.id_cache[df.index.nlevels] = {}
                existing_ids = self.get_existing_ids(
                    f, [path + "/" + x for x in df.index.names]
                )
                # split indices into existing and additional
                new = df.index.tolist()
                if (
                    df.index.nlevels == 1
                ):
                    # cover cases with a single index
                    new = [(x,) for x in df.index.tolist()]
                (
                    found_multis,
@@ -817,7 +837,7 @@ class Writer(BridgeH5):
                    locate_indices(existing_ids, found_multis)
                )
                # sort indices for h5 indexing
                incremental_existing = np.argsort(found_indices)
                self.id_cache[df.index.nlevels][
                    "found_indices"
@@ -842,7 +862,7 @@ class Writer(BridgeH5):
                ].values
            ncells, ntps = f[values_path].shape
            # add found cells
            dset.resize(dset.shape[1] + df.shape[1], axis=1)
            dset[:, ntps:] = np.nan
@@ -851,13 +871,13 @@ class Writer(BridgeH5):
"found_indices" "found_indices"
] ]
# Cover for case when all labels are new # case when all labels are new
if found_indices_sorted.any(): if found_indices_sorted.any():
# h5py does not allow bidimensional indexing, # h5py does not allow bidimensional indexing,
# so we have to iterate over the columns # so we have to iterate over the columns
for i, tp in enumerate(df.columns): for i, tp in enumerate(df.columns):
dset[found_indices_sorted, tp] = existing_values[:, i] dset[found_indices_sorted, tp] = existing_values[:, i]
# Add new cells # add new cells
n_newcells = len( n_newcells = len(
self.id_cache[df.index.nlevels]["additional_multis"] self.id_cache[df.index.nlevels]["additional_multis"]
) )
@@ -892,24 +912,20 @@ class Writer(BridgeH5):
    @staticmethod
    def get_existing_ids(f, paths):
        """Fetch indices and convert them to a (nentries, nlevels) ndarray."""
        return np.array([f[path][()] for path in paths]).T

    @staticmethod
    def find_ids(existing, new):
        """
        Compare two tuple sets and return the intersection and difference
        (elements in the 'new' set not in 'existing').
        """
        set_existing = set([tuple(*x) for x in zip(existing.tolist())])
        existing_cells = np.array(list(set_existing.intersection(new)))
        new_cells = np.array(list(set(new).difference(set_existing)))
        return existing_cells, new_cells

def locate_indices(existing, new):
    if new.any():
        if new.shape[1] > 1:
@@ -930,7 +946,7 @@ def locate_indices(existing, new):
def _tuple_or_int(x):
    """Convert tuple to int if it only contains one value."""
    if len(x) == 1:
        return x[0]
    else:
...