diff --git a/__init__.py b/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e5a0d9b4834ec8f46d6e0d1256c6dcaad2e460fe --- /dev/null +++ b/__init__.py @@ -0,0 +1 @@ +#!/usr/bin/env python3 diff --git a/abc.py b/abc.py new file mode 100644 index 0000000000000000000000000000000000000000..2c32876e1af2da3c36207b7fc8177b3d62b7abd7 --- /dev/null +++ b/abc.py @@ -0,0 +1,218 @@ +import typing as t +from abc import ABC, abstractmethod +from collections.abc import Iterable +from copy import copy +from pathlib import Path, PosixPath +from typing import Union + +from yaml import dump, safe_load +from flatten_dict import flatten + +atomic = t.Union[int, float, str, bool] + + +class ParametersABC(ABC): + """ + Defines parameters as attributes and allows parameters to + be converted to either a dictionary or to yaml. + + No attribute should be called "parameters"! + """ + + def __init__(self, **kwargs): + """ + Defines parameters as attributes + """ + assert ( + "parameters" not in kwargs + ), "No attribute should be named parameters" + for k, v in kwargs.items(): + setattr(self, k, v) + + def to_dict(self, iterable="null") -> t.Dict: + """ + Recursive function to return a nested dictionary of the + attributes of the class instance. + """ + if isinstance(iterable, dict): + if any( + [ + True + for x in iterable.values() + if isinstance(x, Iterable) or hasattr(x, "to_dict") + ] + ): + return { + k: v.to_dict() + if hasattr(v, "to_dict") + else self.to_dict(v) + for k, v in iterable.items() + } + else: + return iterable + elif iterable == "null": + # use instance's built-in __dict__ dictionary of attributes + return self.to_dict(self.__dict__) + else: + return iterable + + def to_yaml(self, path: Union[PosixPath, str] = None): + """ + Returns a yaml stream of the attributes of the class instance. + If path is provided, the yaml stream is saved there. + + Parameters + ---------- + path : Union[PosixPath, str] + Output path. + """ + if path: + with open(Path(path), "w") as f: + dump(self.to_dict(), f) + return dump(self.to_dict()) + + @classmethod + def from_dict(cls, d: dict): + return cls(**d) + + @classmethod + def from_yaml(cls, source: Union[PosixPath, str]): + """ + Returns instance from a yaml filename or stdin + """ + is_buffer = True + try: + if Path(source).exists(): + is_buffer = False + except Exception: + pass + if is_buffer: + params = safe_load(source) + else: + with open(source) as f: + params = safe_load(f) + return cls(**params) + + @classmethod + def default(cls, **kwargs): + overriden_defaults = copy(cls._defaults) + for k, v in kwargs.items(): + overriden_defaults[k] = v + return cls.from_dict(overriden_defaults) + + def update(self, name: str, new_value): + """ + Update values recursively + if name is a dictionary, replace data where existing found or add if not. + It warns against type changes. + + If the existing structure under name is a dictionary, + it looks for the first occurrence and modifies it accordingly. + + If a leaf node that is to be changed is a collection, it adds the new elements. 
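+
+        Illustrative call (the attribute name and values are hypothetical,
+        assuming the instance already holds a dict attribute of that name):
+
+            params.update("thresholds", {"area": 5})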
+ """ + + assert name not in ( + "parameters", + "params", + ), "Attribute can't be named params or parameters" + + if name in self.__dict__: + if check_type_recursive(getattr(self, name), new_value): + print("Warnings:Type changes are risky") + + if isinstance(getattr(self, name), dict): + flattened = flatten(self.to_dict()) + names_found = [k for k in flattened.keys() if name in k] + found_idx = [keys.index(name) for keys in names_found] + + assert len(names_found), f"{name} not found as key." + + keys = None + if len(names_found) > 1: + for level in zip(found_idx, names_found): + if level == min(found_idx): + keys = level + print( + f"Warning: {name} was found in multiple keys. Selected {keys}" + ) + break + + else: + keys = names_found.pop() + + if keys: + current_val = flattened.get(keys, None) + # if isinstance(current_val, t.Collection): + + elif isinstance(getattr(self, name), t.Collection): + add_to_collection(getattr(self, name), new_value) + + elif isinstance(getattr(self, name), set): + pass # TODO implement + + new_d = getattr(self, name) + new_d.update(new_value) + setattr(self, name, new_d) + + else: + setattr(self, name, new_value) + + +def add_to_collection( + collection: t.Collection, value: t.Union[atomic, t.Collection] +): + # Adds element(s) in place. + if not isinstance(value, t.Collection): + value = [value] + if isinstance(collection, list): + collection += value + elif isinstance(collection, set): + collection.update(value) + + +class ProcessABC(ABC): + """ + Base class for processes. + Defines parameters as attributes and requires run method to be defined. + """ + + def __init__(self, parameters): + """ + Arguments + --------- + parameters: instance of ParametersABC + """ + self._parameters = parameters + # convert parameters to dictionary + # and then define each parameter as an attribute + for k, v in parameters.to_dict().items(): + setattr(self, k, v) + + @property + def parameters(self): + return self._parameters + + @abstractmethod + def run(self): + pass + + +def check_type_recursive(val1, val2): + same_types = True + if not isinstance(val1, type(val2)) and not all( + type(x) in (PosixPath, str) for x in (val1, val2) # Ignore str->path + ): + return False + if not isinstance(val1, t.Iterable) and not isinstance(val2, t.Iterable): + return isinstance(val1, type(val2)) + elif isinstance(val1, (tuple, list)) and isinstance(val2, (tuple, list)): + return bool( + sum([check_type_recursive(v1, v2) for v1, v2 in zip(val1, val2)]) + ) + elif isinstance(val1, dict) and isinstance(val2, dict): + if not len(val1) or not len(val2): + return False + for k in val2.keys(): + same_types = same_types and check_type_recursive(val1[k], val2[k]) + return same_types diff --git a/io/__init__.py b/io/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e5a0d9b4834ec8f46d6e0d1256c6dcaad2e460fe --- /dev/null +++ b/io/__init__.py @@ -0,0 +1 @@ +#!/usr/bin/env python3 diff --git a/io/bridge.py b/io/bridge.py new file mode 100644 index 0000000000000000000000000000000000000000..f797f07bf45b72c8a16040474c6a6445510c9ec5 --- /dev/null +++ b/io/bridge.py @@ -0,0 +1,166 @@ +""" +Tools to interact with hdf5 files and handle data consistently. +""" +import collections +from itertools import chain, groupby, product +from typing import Union +import typing as t + +import h5py +import numpy as np +import yaml + + +class BridgeH5: + """ + Base class to interact with h5 data stores. + It also contains functions useful to predict how long should segmentation take. 
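+
+    Example (illustrative; the file name is hypothetical):
+
+    >>> bridge = BridgeH5("position001.h5")
+    >>> tree = bridge.cell_tree  # nested dict: trap -> timepoint -> cell labels
+    >>> bridge.close()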
+ """ + + def __init__(self, filename, flag="r"): + self.filename = filename + if flag is not None: + self._hdf = h5py.File(filename, flag) + + self._filecheck + + def _filecheck(self): + assert "cell_info" in self._hdf, "Invalid file. No 'cell_info' found." + + def close(self): + self._hdf.close() + + @property + def meta_h5(self) -> t.Dict[str, t.Any]: + # Return metadata as indicated in h5 file + if not hasattr(self, "_meta_h5"): + with h5py.File(self.filename, "r") as f: + self._meta_h5 = dict(f.attrs) + return self._meta_h5 + + @property + def cell_tree(self): + return self.get_info_tree() + + @staticmethod + def get_consecutives(tree, nstepsback): + # Receives a sorted tree and returns the keys of consecutive elements + vals = {k: np.array(list(v)) for k, v in tree.items()} # get tp level + where_consec = [ + { + k: np.where(np.subtract(v[n + 1 :], v[: -n - 1]) == n + 1)[0] + for k, v in vals.items() + } + for n in range(nstepsback) + ] # get indices of consecutive elements + return where_consec + + def get_npairs(self, nstepsback=2, tree=None): + if tree is None: + tree = self.cell_tree + + consecutive = self.get_consecutives(tree, nstepsback=nstepsback) + flat_tree = flatten(tree) + + n_predictions = 0 + for i, d in enumerate(consecutive, 1): + flat = list(chain(*[product([k], list(v)) for k, v in d.items()])) + pairs = [(f, (f[0], f[1] + i)) for f in flat] + for p in pairs: + n_predictions += len(flat_tree.get(p[0], [])) * len( + flat_tree.get(p[1], []) + ) + + return n_predictions + + def get_npairs_over_time(self, nstepsback=2): + tree = self.cell_tree + npairs = [] + for t in self._hdf["cell_info"]["processed_timepoints"][()]: + tmp_tree = { + k: {k2: v2 for k2, v2 in v.items() if k2 <= t} + for k, v in tree.items() + } + npairs.append(self.get_npairs(tree=tmp_tree)) + + return np.diff(npairs) + + def get_info_tree( + self, fields: Union[tuple, list] = ("trap", "timepoint", "cell_label") + ): + """ + Returns traps, time points and labels for this position in form of a tree + in the hierarchy determined by the argument fields. Note that it is + compressed to non-empty elements and timepoints. + + Default hierarchy is: + - trap + - time point + - cell label + + This function currently produces trees of depth 3, but it can easily be + extended for deeper trees if needed (e.g. considering groups, + chambers and/or positions). + + input + :fields: Fields to fetch from 'cell_info' inside the hdf5 storage + + returns + :tree: Nested dictionary where keys (or branches) are the upper levels + and the leaves are the last element of :fields:. + """ + zipped_info = (*zip(*[self._hdf["cell_info"][f][()] for f in fields]),) + + return recursive_groupsort(zipped_info) + + +def groupsort(iterable: Union[tuple, list]): + # Sorts iterable and returns a dictionary where the values are grouped by the first element. + + iterable = sorted(iterable, key=lambda x: x[0]) + grouped = { + k: [x[1:] for x in v] for k, v in groupby(iterable, lambda x: x[0]) + } + return grouped + + +def recursive_groupsort(iterable): + # Recursive extension of groupsort + if len(iterable[0]) > 1: + return { + k: recursive_groupsort(v) for k, v in groupsort(iterable).items() + } + else: # Only two elements in list + return [x[0] for x in iterable] + + +def flatten(d, parent_key="", sep="_"): + """Flatten nested dict. 
Adapted from https://stackoverflow.com/a/6027615""" + items = [] + for k, v in d.items(): + new_key = parent_key + (k,) if parent_key else (k,) + if isinstance(v, collections.MutableMapping): + items.extend(flatten(v, new_key, sep=sep).items()) + else: + items.append((new_key, v)) + return dict(items) + + +def attrs_from_h5(fpath: str): + """Return attributes as dict from h5 file""" + with h5py.File(fpath, "r") as f: + return dict(f.attrs) + + +def parameters_from_h5(fpath: str): + attrs = attrs_from_h5(fpath) + return yaml.safe_load(attrs["parameters"]) + + +def image_creds_from_h5(fpath: str): + """Return image id and server credentials from h5""" + attrs = attrs_from_h5(fpath) + return ( + attrs["image_id"], + yaml.safe_load(attrs["parameters"])["general"]["server_info"], + ) diff --git a/io/cells.py b/io/cells.py new file mode 100644 index 0000000000000000000000000000000000000000..a8ca61b0e13c28b266d5bb2b9bd8fd02384106e3 --- /dev/null +++ b/io/cells.py @@ -0,0 +1,403 @@ +import logging +import typing as t +from collections.abc import Iterable +from itertools import groupby +from pathlib import Path, PosixPath +from functools import lru_cache + +import h5py +import numpy as np +from numpy.lib.stride_tricks import sliding_window_view +from scipy import ndimage +from scipy.sparse.base import isdense +from utils_find_1st import cmp_equal, find_1st + + +class Cells: + """ + Extracts information from an h5 file. This class accesses: + + 'cell_info', which contains 'angles', 'cell_label', 'centres', + 'edgemasks', 'ellipse_dims', 'mother_assign', 'mother_assign_dynamic', + 'radii', 'timepoint', 'trap'. + All of these except for 'edgemasks' are a 1D ndarray. + + 'trap_info', which contains 'drifts', 'trap_locations' + + """ + + def __init__(self, filename, path="cell_info"): + self.filename: t.Optional[t.Union[str, PosixPath]] = filename + self.cinfo_path: t.Optional[str] = path + self._edgemasks: t.Optional[str] = None + self._tile_size: t.Optional[int] = None + + @classmethod + def from_source(cls, source: t.Union[PosixPath, str]): + return cls(Path(source)) + + @staticmethod + def _asdense(array: np.ndarray): + if not isdense(array): + array = array.todense() + return array + + @staticmethod + def _astype(array: np.ndarray, kind: str): + # Convert sparse arrays if needed and if kind is 'mask' it fills the outline + array = Cells._asdense(array) + if kind == "mask": + array = ndimage.binary_fill_holes(array).astype(bool) + return array + + def _get_idx(self, cell_id: int, trap_id: int): + # returns boolean array of time points where both the cell with cell_id and the trap with trap_id exist + return (self["cell_label"] == cell_id) & (self["trap"] == trap_id) + + @property + def max_labels(self) -> t.List[int]: + return [max((0, *self.labels_in_trap(i))) for i in range(self.ntraps)] + + @property + def max_label(self) -> int: + return sum(self.max_labels) + + @property + def ntraps(self) -> int: + # find the number of traps from the h5 file + with h5py.File(self.filename, mode="r") as f: + return len(f["trap_info/trap_locations"][()]) + + @property + def tinterval(self): + with h5py.File(self.filename, mode="r") as f: + return f.attrs["time_settings/timeinterval"] + + @property + def traps(self) -> t.List[int]: + # returns a list of traps + return list(set(self["trap"])) + + @property + def tile_size(self) -> t.Union[int, t.Tuple[int], None]: + if self._tile_size is None: + with h5py.File(self.filename, mode="r") as f: + # self._tile_size = f["trap_info/tile_size"][0] + self._tile_size = 
f["cell_info/edgemasks"].shape[1:] + return self._tile_size + + def nonempty_tp_in_trap(self, trap_id: int) -> set: + # given a trap_id returns time points in which cells are available + return set(self["timepoint"][self["trap"] == trap_id]) + + @property + def edgemasks(self) -> t.List[np.ndarray]: + # returns the masks per tile + if self._edgemasks is None: + edgem_path: str = "edgemasks" + self._edgemasks = self._fetch(edgem_path) + return self._edgemasks + + @property + def labels(self) -> t.List[t.List[int]]: + """ + Return all cell labels in object + We use mother_assign to list traps because it is the only property that appears even + when no cells are found + """ + return [self.labels_in_trap(trap) for trap in range(self.ntraps)] + + def max_labels_in_frame(self, frame: int) -> t.List[int]: + # Return the maximum label for each trap in the given frame + max_labels = [ + self["cell_label"][ + (self["timepoint"] <= frame) & (self["trap"] == trap_id) + ] + for trap_id in range(self.ntraps) + ] + return [max([0, *labels]) for labels in max_labels] + + def where(self, cell_id: int, trap_id: int): + """ + Parameters + ---------- + cell_id: int + Cell index + trap_id: int + Trap index + + Returns + ---------- + indices int array + boolean mask array + edge_ix int array + """ + indices = self._get_idx(cell_id, trap_id) + edgem_ix = self._edgem_where(cell_id, trap_id) + return ( + self["timepoint"][indices], + indices, + edgem_ix, + ) + + def mask(self, cell_id, trap_id): + times, outlines = self.outline(cell_id, trap_id) + return times, np.array( + [ndimage.morphology.binary_fill_holes(o) for o in outlines] + ) + + def at_time(self, timepoint, kind="mask"): + ix = self["timepoint"] == timepoint + traps = self["trap"][ix] + edgemasks = self._edgem_from_masking(ix) + masks = [ + self._astype(edgemask, kind) + for edgemask in edgemasks + if edgemask.any() + ] + return self.group_by_traps(traps, masks) + + def group_by_traps( + self, traps: t.Collection, cell_labels: t.Collection + ) -> t.Dict[int, t.List[int]]: + # returns a dict with traps as keys and list of labels as value + # Data is a + iterator = groupby(zip(traps, cell_labels), lambda x: x[0]) + d = {key: [x[1] for x in group] for key, group in iterator} + d = {i: d.get(i, []) for i in self.traps} + return d + + def labels_in_trap(self, trap_id: int) -> t.Set[int]: + # return set of cell ids for a given trap + return set((self["cell_label"][self["trap"] == trap_id])) + + def labels_at_time(self, timepoint: int) -> t.Dict[int, t.List[int]]: + labels = self["cell_label"][self["timepoint"] == timepoint] + traps = self["trap"][self["timepoint"] == timepoint] + return self.group_by_traps(traps, labels) + + def __getitem__(self, item): + assert item != "edgemasks", "Edgemasks must not be loaded as a whole" + + _item = "_" + item + if not hasattr(self, _item): + setattr(self, _item, self._fetch(item)) + return getattr(self, _item) + + def _fetch(self, path): + with h5py.File(self.filename, mode="r") as f: + return f[self.cinfo_path][path][()] + + def _edgem_from_masking(self, mask): + with h5py.File(self.filename, mode="r") as f: + edgem = f[self.cinfo_path + "/edgemasks"][mask, ...] 
+ return edgem + + def _edgem_where(self, cell_id, trap_id): + id_mask = self._get_idx(cell_id, trap_id) + edgem = self._edgem_from_masking(id_mask) + + return edgem + + def outline(self, cell_id: int, trap_id: int): + id_mask = self._get_idx(cell_id, trap_id) + times = self["timepoint"][id_mask] + + return times, self._edgem_from_masking(id_mask) + + @property + def ntimepoints(self) -> int: + return self["timepoint"].max() + 1 + + @property + def ncells_matrix(self): + ncells_mat = np.zeros( + (self.ntraps, self["cell_label"].max(), self.ntimepoints), + dtype=bool, + ) + ncells_mat[ + self["trap"], self["cell_label"] - 1, self["timepoint"] + ] = True + return ncells_mat + + def matrix_trap_tp_where( + self, min_ncells: int = None, min_consecutive_tps: int = None + ): + """ + Return a matrix of shape (ntraps x ntps - min_consecutive_tps to + indicate traps and time-points where min_ncells are available for at least min_consecutive_tps + + Parameters + --------- + min_ncells: int Minimum number of cells + min_consecutive_tps: int + Minimum number of time-points a + + Returns + --------- + (ntraps x ( ntps-min_consecutive_tps )) 2D boolean numpy array where rows are trap ids and columns are timepoint windows. + If the value in a cell is true its corresponding trap and timepoint contains more than min_ncells for at least min_consecutive time-points. + """ + if min_ncells is None: + min_ncells = 2 + if min_consecutive_tps is None: + min_consecutive_tps = 5 + + window = sliding_window_view( + self.ncells_matrix, min_consecutive_tps, axis=2 + ) + tp_min = window.sum(axis=-1) == min_consecutive_tps + ncells_tp_min = tp_min.sum(axis=1) >= min_ncells + return ncells_tp_min + + def random_valid_trap_tp( + self, min_ncells: int = None, min_consecutive_tps: int = None + ): + # Return a randomly-selected pair of trap_id and timepoints + mat = self.matrix_trap_tp_where( + min_ncells=min_ncells, min_consecutive_tps=min_consecutive_tps + ) + traps, tps = np.where(mat) + rand = np.random.randint(mat.sum()) + return (traps[rand], tps[rand]) + + def mothers_in_trap(self, trap_id: int): + return self.mothers[trap_id] + + @property + def mothers(self): + """ + Return nested list with final prediction of mother id for each cell + """ + return self.mother_assign_from_dynamic( + self["mother_assign_dynamic"], + self["cell_label"], + self["trap"], + self.ntraps, + ) + + @property + def mothers_daughters(self): + nested_massign = self.mothers + + if sum([x for y in nested_massign for x in y]): + mothers, daughters = zip( + *[ + ((tid, m), (tid, d)) + for tid, trapcells in enumerate(nested_massign) + for d, m in enumerate(trapcells, 1) + if m + ] + ) + else: + mothers, daughters = ([], []) + # print("Warning:Cells: No mother-daughters assigned") + + return mothers, daughters + + @staticmethod + def mother_assign_to_mb_matrix(ma: t.List[np.array]): + # Convert from list of lists to mother_bud sparse matrix + ncells = sum([len(t) for t in ma]) + mb_matrix = np.zeros((ncells, ncells), dtype=bool) + c = 0 + for cells in ma: + for d, m in enumerate(cells): + if m: + mb_matrix[c + d, c + m - 1] = True + + c += len(cells) + + return mb_matrix + + @staticmethod + def mother_assign_from_dynamic( + ma, cell_label: t.List[int], trap: t.List[int], ntraps: int + ): + """ + Interpolate the list of lists containing the associated mothers from the mother_assign_dynamic feature + """ + idlist = list(zip(trap, cell_label)) + cell_gid = np.unique(idlist, axis=0) + + last_lin_preds = [ + find_1st( + ((cell_label[::-1] == lbl) & 
(trap[::-1] == tr)), + True, + cmp_equal, + ) + for tr, lbl in cell_gid + ] + mother_assign_sorted = ma[::-1][last_lin_preds] + + traps = cell_gid[:, 0] + iterator = groupby(zip(traps, mother_assign_sorted), lambda x: x[0]) + d = {key: [x[1] for x in group] for key, group in iterator} + nested_massign = [d.get(i, []) for i in range(ntraps)] + + return nested_massign + + @lru_cache(maxsize=200) + def labelled_in_frame(self, frame: int, global_id=False) -> np.ndarray: + """ + Return labels in a ndarray with the global ids + with shape (ntraps, max_nlabels, ysize, xsize) + at a given frame. + + max_nlabels is specific for this frame, not + the entire experiment. + """ + labels_in_frame = self.labels_at_time(frame) + n_labels = [ + len(labels_in_frame.get(trap_id, [])) + for trap_id in range(self.ntraps) + ] + # maxes = self.max_labels_in_frame(frame) + stacks_in_frame = self.get_stacks_in_frame(frame, self.tile_size) + first_id = np.cumsum([0, *n_labels]) + labels_mat = np.zeros( + ( + self.ntraps, + max(n_labels), + *self.tile_size, + ), + dtype=int, + ) + for trap_id, masks in enumerate(stacks_in_frame): # new_axis = np.pad( + if trap_id in labels_in_frame: + new_axis = np.array(labels_in_frame[trap_id], dtype=int)[ + :, np.newaxis, np.newaxis + ] + global_id_masks = new_axis * masks + if global_id: + global_id_masks += first_id[trap_id] * masks + global_id_masks = np.pad( + global_id_masks, + pad_width=( + (0, labels_mat.shape[1] - global_id_masks.shape[0]), + (0, 0), + (0, 0), + ), + ) + labels_mat[trap_id] += global_id_masks + return labels_mat + + def get_stacks_in_frame(self, frame: int, tile_shape: t.Tuple[int]): + # Stack all cells in a trap-wise manner + masks = self.at_time(frame) + return [ + stack_masks_in_trap( + masks.get(trap_id, np.array([], dtype=bool)), tile_shape + ) + for trap_id in range(self.ntraps) + ] + + +def stack_masks_in_trap( + masks: t.List[np.ndarray], tile_shape: t.Tuple[int] +) -> np.ndarray: + # Stack all masks in a trap padding accordingly if no outlines found + result = np.zeros((0, *tile_shape), dtype=bool) + if len(masks): + result = np.array(masks) + return result diff --git a/io/metadata.py b/io/metadata.py new file mode 100644 index 0000000000000000000000000000000000000000..d5a5c04e7ab173c47a9c45aa173964e0bc4b8373 --- /dev/null +++ b/io/metadata.py @@ -0,0 +1,116 @@ +import glob +import os +from datetime import datetime + +import pandas as pd +from pytz import timezone + +from agora.io.writer import Writer +from logfile_parser import Parser + + +class MetaData: + """Small metadata Process that loads log.""" + + def __init__(self, log_dir, store): + self.log_dir = log_dir + self.store = store + self.metadata_writer = Writer(self.store) + + def __getitem__(self, item): + return self.load_logs()[item] + + def load_logs(self): + parsed_flattened = parse_logfiles(self.log_dir) + return parsed_flattened + + def run(self, overwrite=False): + metadata_dict = self.load_logs() + self.metadata_writer.write( + path="/", meta=metadata_dict, overwrite=overwrite + ) + + def add_field(self, field_name, field_value, **kwargs): + self.metadata_writer.write( + path="/", + meta={field_name: field_value}, + **kwargs, + ) + + def add_fields(self, fields_values: dict, **kwargs): + for field, value in fields_values.items(): + self.add_field(field, value) + + +# Paradigm: able to do something with all datatypes present in log files, +# then pare down on what specific information is really useful later. 
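+#
+# Typical usage of MetaData (illustrative; the paths are hypothetical):
+#   meta = MetaData("/path/to/experiment/logs", "experiment.h5")
+#   meta.run()  # parse the acq/log files and store them as h5 attributes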
+ +# Needed because HDF5 attributes do not support dictionaries +def flatten_dict(nested_dict, separator="/"): + """ + Flattens nested dictionary + """ + df = pd.json_normalize(nested_dict, sep=separator) + return df.to_dict(orient="records")[0] + + +# Needed because HDF5 attributes do not support datetime objects +# Takes care of time zones & daylight saving +def datetime_to_timestamp(time, locale="Europe/London"): + """ + Convert datetime object to UNIX timestamp + """ + return timezone(locale).localize(time).timestamp() + + +def find_file(root_dir, regex): + file = glob.glob(os.path.join(str(root_dir), regex)) + if len(file) != 1: + return None + else: + return file[0] + + +# TODO: re-write this as a class if appropriate +# WARNING: grammars depend on the directory structure of a locally installed +# logfile_parser repo +def parse_logfiles( + root_dir, + acq_grammar="multiDGUI_acq_format.json", + log_grammar="multiDGUI_log_format.json", +): + """ + Parse acq and log files depending on the grammar specified, then merge into + single dict. + """ + # Both acq and log files contain useful information. + # ACQ_FILE = 'flavin_htb2_glucose_long_ramp_DelftAcq.txt' + # LOG_FILE = 'flavin_htb2_glucose_long_ramp_Delftlog.txt' + log_parser = Parser(log_grammar) + try: + log_file = find_file(root_dir, "*log.txt") + except FileNotFoundError: + raise ValueError("Experiment log file not found.") + with open(log_file, "r") as f: + log_parsed = log_parser.parse(f) + + acq_parser = Parser(acq_grammar) + try: + acq_file = find_file(root_dir, "*[Aa]cq.txt") + except FileNotFoundError: + raise ValueError("Experiment acq file not found.") + with open(acq_file, "r") as f: + acq_parsed = acq_parser.parse(f) + + parsed = {**acq_parsed, **log_parsed} + + for key, value in parsed.items(): + if isinstance(value, datetime): + parsed[key] = datetime_to_timestamp(value) + + parsed_flattened = flatten_dict(parsed) + for k, v in parsed_flattened.items(): + if isinstance(v, list): + parsed_flattened[k] = [0 if el is None else el for el in v] + + return parsed_flattened diff --git a/io/reader.py b/io/reader.py new file mode 100644 index 0000000000000000000000000000000000000000..b67934834a12564240f143ad12357a106fe449ad --- /dev/null +++ b/io/reader.py @@ -0,0 +1,119 @@ +#!/usr/bin/env python3 +from pathlib import Path + +import h5py +import numpy as np + +from agora.io.bridge import groupsort +from agora.io.writer import load_attributes + + +class DynamicReader: + group = "" + + def __init__(self, file: str): + self.file = file + self.metadata = load_attributes(file) + + +class StateReader(DynamicReader): + """ + Analogous to StateWriter: + + + Possible cases (and data shapes): + - max_lbl (ntraps, 1) -> One int per trap. + - tp_back, trap, cell_label -> One int per cell_label-timepoint + - prev_feats -> A fixed number of floats per cell_label-timepoint (default is 9) + - lifetime, p_was_bud, p_is_mother -> (nTotalCells, 2) A (Ncells, 2) matrix where the first column is the trap, + and its index for such trap (+1) is its cell label. + - ba_cum ->. (2^n, 2^n, None) 3d array where the lineage score is contained for all traps - traps in the 3rd dimension 3d array where the lineage score is contained for all traps - traps in the 3rd dimension. + 2^n >= ncells, it is kept in powers of two for efficiency. 
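+
+    Example (illustrative; the file name is hypothetical):
+
+    >>> reader = StateReader("position001.h5")
+    >>> states = reader.get_formatted_states()  # one state dict per trap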
+ + """ + + data_types = {} + datatypes = { + "max_lbl": ((None, 1), np.uint16), + "tp_back": ((None, 1), np.uint16), + "trap": ((None, 1), np.int16), + "cell_lbls": ((None, 1), np.uint16), + "prev_feats": ((None, None), np.float64), + "lifetime": ((None, 2), np.uint16), + "p_was_bud": ((None, 2), np.float64), + "p_is_mother": ((None, 2), np.float64), + "ba_cum": ((None, None), np.float64), + } + group = "last_state" + + def __init__(self, file: str): + super().__init__(file) + + def format_tps(self): + pass + + def format_traps(self): + pass + + def format_bacum(self): + pass + + def read_raw(self, key, dtype): + with h5py.File(self.file, "r") as f: + raw = f[self.group + "/" + key][()].astype(dtype) + + return raw + + def read_all(self): + + self.raw_data = { + key: self.read_raw(key, dtype) + for key, (_, dtype) in self.datatypes.items() + } + + return self.raw_data + + def reconstruct_states(self, data: dict): + ntps_back = max(data["tp_back"]) + 1 + + from copy import copy + + tpback_as_idx = copy(data["tp_back"]) + trap_as_idx = copy(data["trap"]) + + states = {k: {"max_lbl": v} for k, v in enumerate(data["max_lbl"])} + for val_name in ("cell_lbls", "prev_feats"): + for k in states.keys(): + if val_name == "cell_lbls": + states[k][val_name] = [[] for _ in range(ntps_back)] + else: + states[k][val_name] = [ + np.zeros( + (0, data[val_name].shape[1]), dtype=np.float64 + ) + for _ in range(ntps_back) + ] + + data[val_name] = list( + zip(trap_as_idx, tpback_as_idx, data[val_name]) + ) + for k, v in groupsort(data[val_name]).items(): + states[k][val_name] = [ + np.array([w[0] for w in val]) + for val in groupsort(v).values() + ] + + for val_name in ("lifetime", "p_was_bud", "p_is_mother"): + for k in states.keys(): + states[k][val_name] = np.array([]) + # This contains no time points back + for k, v in groupsort(data[val_name]).items(): + states[k][val_name] = np.array([val[0] for val in v]) + + for trap_id, ba_matrix in enumerate(data["ba_cum"]): + states[trap_id]["ba_cum"] = np.array(ba_matrix, dtype=np.float64) + + return [val for val in states.values()] + + def get_formatted_states(self): + return self.reconstruct_states(self.read_all()) diff --git a/io/signal.py b/io/signal.py new file mode 100644 index 0000000000000000000000000000000000000000..673b13047b9d864294ae36fd9eebc96b365d2004 --- /dev/null +++ b/io/signal.py @@ -0,0 +1,357 @@ +import typing as t +from copy import copy +from pathlib import PosixPath + +import h5py +import numpy as np +import pandas as pd +from utils_find_1st import cmp_larger, find_1st + +from agora.io.bridge import BridgeH5 + + +class Signal(BridgeH5): + """ + Class that fetches data from the hdf5 storage for post-processing + + Signal is works under the assumption that metadata and data are + accessible, to perform time-adjustments and apply previously-recorded + postprocesses. 
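+
+    Example (illustrative; the file and dataset names are hypothetical):
+
+    >>> signal = Signal("position001.h5")
+    >>> df = signal["extraction/general/None/area"]  # one row per (trap, cell)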
+ """ + + def __init__(self, file: t.Union[str, PosixPath]): + super().__init__(file, flag=None) + + self.names = ["experiment", "position", "trap"] + + def __getitem__(self, dsets: t.Union[str, t.Collection]): + + assert isinstance( + dsets, (str, t.Collection) + ), "Incorrect type for dset" + + if isinstance(dsets, str) and dsets.endswith("imBackground"): + df = self.get_raw(dsets) + + elif isinstance(dsets, str): + df = self.apply_prepost(dsets) + + elif isinstance(dsets, list): + is_bgd = [dset.endswith("imBackground") for dset in dsets] + assert sum(is_bgd) == 0 or sum(is_bgd) == len( + dsets + ), "Trap data and cell data can't be mixed" + return [ + self.add_name(self.apply_prepost(dset), dset) for dset in dsets + ] + + # return self.cols_in_mins(self.add_name(df, dsets)) + return self.add_name(df, dsets) + + @staticmethod + def add_name(df, name): + df.name = name + return df + + def cols_in_mins(self, df: pd.DataFrame): + # Convert numerical columns in a dataframe to minutes + try: + df.columns = (df.columns * self.tinterval // 60).astype(int) + except Exception as e: + print( + """ + Warning:Can't convert columns to minutes. Signal {}.{}""".format( + df.name, e + ) + ) + return df + + @property + def ntimepoints(self): + with h5py.File(self.filename, "r") as f: + return f["extraction/general/None/area/timepoint"][-1] + 1 + + @property + def tinterval(self) -> int: + tinterval_location = "time_settings/timeinterval" + with h5py.File(self.filename, "r") as f: + return f.attrs[tinterval_location][0] + + @staticmethod + def get_retained(df, cutoff): + return df.loc[df.notna().sum(axis=1) > df.shape[1] * cutoff] + + def retained(self, signal, cutoff=0.8): + + df = self[signal] + if isinstance(df, pd.DataFrame): + return self.get_retained(df, cutoff) + + elif isinstance(df, list): + return [self.get_retained(d, cutoff=cutoff) for d in df] + + def apply_prepost(self, dataset: str, skip_pick: t.Optional[bool] = None): + """ + Apply modifier operations (picker, merger) to a given dataframe. 
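+        Merges recorded under "modifiers/merges" are applied first; the
+        result is then restricted to the indices picked under
+        "modifiers/picks", unless skip_pick is set.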
+ """ + merges = self.get_merges() + df = self.get_raw(dataset) + merged = copy(df) + if merges.any(): + # Split in two dfs, one with rows relevant for merging and one + # without them + valid_merges = merges[ + ( + merges[:, :, :, None] + == np.array(list(df.index)).T[:, None, :] + ) + .all(axis=(1, 2)) + .any(axis=1) + ] # Casting allows fast multiindexing + + merged = self.apply_merge( + df.loc[map(tuple, valid_merges.reshape(-1, 2))], + valid_merges, + ) + + nonmergeable_ids = df.index.difference(valid_merges.reshape(-1, 2)) + + merged = pd.concat( + (merged, df.loc[nonmergeable_ids]), names=df.index.names + ) + + with h5py.File(self.filename, "r") as f: + if "modifiers/picks" in f and not skip_pick: + picks = self.get_picks(names=merged.index.names) + # missing_cells = [i for i in picks if tuple(i) not in + # set(merged.index)] + + if picks: + return merged.loc[ + set(picks).intersection( + [tuple(x) for x in merged.index] + ) + ] + return merged.loc[picks] + else: + if isinstance(merged.index, pd.MultiIndex): + empty_lvls = [[] for i in merged.index.names] + index = pd.MultiIndex( + levels=empty_lvls, + codes=empty_lvls, + names=merged.index.names, + ) + else: + index = pd.Index([], name=merged.index.name) + merged = pd.DataFrame([], index=index) + return merged + + @property + def datasets(self): + if not hasattr(self, "_siglist"): + self._siglist = [] + + with h5py.File(self.filename, "r") as f: + f.visititems(self.get_siglist) + + for sig in self.siglist: + print(sig) + + @property + def p_siglist(self): + """Print signal list""" + self.datasets + + @property + def siglist(self): + """Return list of signals""" + try: + if not hasattr(self, "_siglist"): + self._siglist = [] + with h5py.File(self.filename, "r") as f: + f.visititems(self.get_siglist) + except Exception as e: + print("Error visiting h5: {}".format(e)) + self._siglist = [] + + return self._siglist + + def get_merged(self, dataset): + return self.apply_prepost(dataset, skip_pick=True) + + @property + def merges(self): + with h5py.File(self.filename, "r") as f: + dsets = f.visititems(self._if_merges) + return dsets + + @property + def n_merges(self): + print("{} merge events".format(len(self.merges))) + + @property + def picks(self): + with h5py.File(self.filename, "r") as f: + dsets = f.visititems(self._if_picks) + return dsets + + def apply_merge(self, df, changes): + if len(changes): + + for target, source in changes: + df.loc[tuple(target)] = self.join_tracks_pair( + df.loc[tuple(target)], df.loc[tuple(source)] + ) + df.drop(tuple(source), inplace=True) + + return df + + def get_raw(self, dataset, in_minutes=True): + try: + if isinstance(dataset, str): + with h5py.File(self.filename, "r") as f: + df = self.dset_to_df(f, dataset) + if in_minutes: + df = self.cols_in_mins(df) + return df + elif isinstance(dataset, list): + return [self.get_raw(dset) for dset in dataset] + except Exception as e: + print(f"Could not fetch dataset {dataset}") + print(e) + + def get_merges(self): + # fetch merge events going up to the first level + with h5py.File(self.filename, "r") as f: + merges = f.get("modifiers/merges", np.array([])) + if not isinstance(merges, np.ndarray): + merges = merges[()] + + return merges + + # def get_picks(self, levels): + def get_picks(self, names, path="modifiers/picks/"): + with h5py.File(self.filename, "r") as f: + if path in f: + return list(zip(*[f[path + name] for name in names])) + # return f["modifiers/picks"] + else: + return None + + def dset_to_df(self, f, dataset): + dset = f[dataset] + names = 
copy(self.names) + if not dataset.endswith("imBackground"): + names.append("cell_label") + lbls = {lbl: dset[lbl][()] for lbl in names if lbl in dset.keys()} + index = pd.MultiIndex.from_arrays( + list(lbls.values()), names=names[-len(lbls) :] + ) + + columns = ( + dset["timepoint"][()] + if "timepoint" in dset + else dset.attrs["columns"] + ) + + df = pd.DataFrame(dset[("values")][()], index=index, columns=columns) + + return df + + @property + def stem(self): + return self.filename.stem + + @staticmethod + def dataset_to_df(f: h5py.File, path: str): + + all_indices = ["experiment", "position", "trap", "cell_label"] + indices = { + k: f[path][k][()] for k in all_indices if k in f[path].keys() + } + return pd.DataFrame( + f[path + "/values"][()], + index=pd.MultiIndex.from_arrays( + list(indices.values()), names=indices.keys() + ), + columns=f[path + "/timepoint"][()], + ) + + def get_siglist(self, name: str, node): + fullname = node.name + if isinstance(node, h5py.Group) and np.all( + [isinstance(x, h5py.Dataset) for x in node.values()] + ): + self._if_ext_or_post(fullname, self._siglist) + + @staticmethod + def _if_ext_or_post(name: str, siglist: list): + if name.startswith("/extraction") or name.startswith( + "/postprocessing" + ): + siglist.append(name) + + @staticmethod + def _if_merges(name: str, obj): + if isinstance(obj, h5py.Dataset) and name.startswith( + "modifiers/merges" + ): + return obj[()] + + @staticmethod + def _if_picks(name: str, obj): + if isinstance(obj, h5py.Group) and name.endswith("picks"): + return obj[()] + + @staticmethod + def join_tracks_pair(target: pd.Series, source: pd.Series): + """ + Join two tracks + """ + tgt_copy = copy(target) + end = find_1st(target.values[::-1], 0, cmp_larger) + tgt_copy.iloc[-end:] = source.iloc[-end:].values + return tgt_copy + + # TODO FUTURE add stages support to fluigent system + @property + def ntps(self) -> int: + # Return number of time-points according to the metadata + return self.meta_h5["time_settings/ntimepoints"][0] + + @property + def stages(self) -> t.List[str]: + """ + Return the contents of the pump with highest flow rate + at each stage. 
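+        The initial stage is taken from "pumpinit/flowrate" and later stages
+        from "pumprate", both read from the h5 metadata; the returned pump
+        contents come from "pumpinit/contents".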
+ """ + flowrate_name = "pumpinit/flowrate" + pumprate_name = "pumprate" + main_pump_id = np.concatenate( + ( + (np.argmax(self.meta_h5[flowrate_name]),), + np.argmax(self.meta_h5[pumprate_name], axis=0), + ) + ) + return [self.meta_h5["pumpinit/contents"][i] for i in main_pump_id] + + @property + def nstages(self) -> int: + switchtimes_name = "switchtimes" + return self.meta_h5[switchtimes_name] + 1 + + @property + def max_span(self) -> int: + return int(self.tinterval * self.ntps / 60) + + @property + def stages_span(self) -> t.Tuple[t.Tuple[str, int], ...]: + # Return consecutive stages and their corresponding number of time-points + switchtimes_name = "switchtimes" + transition_tps = (0, *self.meta_h5[switchtimes_name]) + spans = [ + end - start + for start, end in zip(transition_tps[:-1], transition_tps[1:]) + if end <= self.max_span + ] + return tuple((stage, ntps) for stage, ntps in zip(self.stages, spans)) diff --git a/io/utils.py b/io/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..0acca82cb57fbab596d990a9e9d554f0c8b80344 --- /dev/null +++ b/io/utils.py @@ -0,0 +1,142 @@ +""" +Utility functions and classes +""" +import itertools +import logging +import operator +import typing as t +from functools import partial, wraps +from pathlib import Path +from time import perf_counter +from typing import Callable + +import cv2 +import h5py +import numpy as np + + +def repr_obj(obj, indent=0): + """ + Helper function to display info about OMERO objects. + Not all objects will have a "name" or owner field. + """ + string = """%s%s:%s Name:"%s" (owner=%s)""" % ( + " " * indent, + obj.OMERO_CLASS, + obj.getId(), + obj.getName(), + obj.getAnnotation(), + ) + + return string + + +def imread(path): + return cv2.imread(str(path), -1) + + +class ImageCache: + """HDF5-based image cache for faster loading of the images once they've + been read. + """ + + def __init__(self, file, name, shape, remote_fn): + self.store = h5py.File(file, "a") + # Create a dataset + self.dataset = self.store.create_dataset( + name, shape, dtype=np.float, fill_value=np.nan + ) + self.remote_fn = remote_fn + + def __getitem__(self, item): + cached = self.dataset[item] + if np.any(np.isnan(cached)): + full = self.remote_fn(item) + self.dataset[item] = full + return full + else: + return cached + + +class Cache: + """ + Fixed-length mapping to use as a cache. + Deletes items in FIFO manner when maximum allowed length is reached. + """ + + def __init__(self, max_len=5000, load_fn: Callable = imread): + """ + :param max_len: Maximum number of items in the cache. + :param load_fn: The function used to load new items if they are not + available in the Cache + """ + self._dict = dict() + self._queue = [] + self.load_fn = load_fn + self.max_len = max_len + + def __getitem__(self, item): + if item not in self._dict: + self.load_item(item) + return self._dict[item] + + def load_item(self, item): + self._dict[item] = self.load_fn(item) + # Clean up the queue + self._queue.append(item) + if len(self._queue) > self.max_len: + del self._dict[self._queue.pop(0)] + + def clear(self): + self._dict.clear() + self._queue.clear() + + +def accumulate(list_: list) -> t.Generator: + """Accumulate list based on the first value""" + list_ = sorted(list_) + it = itertools.groupby(list_, operator.itemgetter(0)) + for key, sub_iter in it: + yield key, [x[1] for x in sub_iter] + + +def get_store_path(save_dir, store, name): + """Create a path to a position-specific store. 
+ + This combines the name and the store's base name into a file path within save_dir. + For example. + >>> get_store_path('data', 'baby_seg.h5', 'pos001') + Path(data/pos001baby_seg.h5') + + :param save_dir: The root directory in which to save the file, absolute + path. + :param store: The base name of the store + :param name: The name of the position + :return: Path(save_dir) / name+store + """ + store = Path(save_dir) / store + store = store.with_name(name + store.name) + return store + + +def parametrized(dec): + def layer(*args, **kwargs): + def repl(f): + return dec(f, *args, **kwargs) + + return repl + + return layer + + +@parametrized +def timed(f, name=None): + @wraps(f) + def decorated(*args, **kwargs): + t = perf_counter() + res = f(*args, **kwargs) + to_print = name or f.__name__ + logging.debug(f"Timing:{to_print}:{perf_counter() - t}s") + return res + + return decorated diff --git a/io/writer.py b/io/writer.py new file mode 100644 index 0000000000000000000000000000000000000000..d3030bfd572e255080c8b97e9d7a983f968baada --- /dev/null +++ b/io/writer.py @@ -0,0 +1,787 @@ +import itertools +import logging +from collections.abc import Iterable +from time import perf_counter +from typing import Dict + +import h5py +import numpy as np +import pandas as pd +import yaml +from utils_find_1st import cmp_equal, find_1st + +from agora.io.bridge import BridgeH5 +from agora.io.utils import timed + +#################### Dynamic version ################################## + + +def load_attributes(file: str, group="/"): + with h5py.File(file, "r") as f: + meta = dict(f[group].attrs.items()) + if "parameters" in meta: + meta["parameters"] = yaml.safe_load(meta["parameters"]) + return meta + + +class DynamicWriter: + data_types = {} + group = "" + compression = "gzip" + compression_opts = 9 + + def __init__(self, file: str): + self.file = file + self.metadata = load_attributes(file) + + def _append(self, data, key, hgroup): + """Append data to existing dataset.""" + try: + n = len(data) + except Exception as e: + logging.debug( + "DynamicWriter:Attributes have no length: {}".format(e) + ) + n = 1 + if key not in hgroup: + # TODO Include sparsity check + max_shape, dtype = self.datatypes[key] + shape = (n,) + max_shape[1:] + hgroup.create_dataset( + key, + shape=shape, + maxshape=max_shape, + dtype=dtype, + compression=self.compression, + compression_opts=self.compression_opts + if self.compression is not None + else None, + ) + hgroup[key][()] = data + else: + # The dataset already exists, expand it + + try: # FIXME This is broken by bugged mother-bud assignment + dset = hgroup[key] + dset.resize(dset.shape[0] + n, axis=0) + dset[-n:] = data + except Exception as e: + logging.debug( + "DynamicWriter:Inconsistency between dataset shape and new empty data: {}".format( + e + ) + ) + return + + def _overwrite(self, data, key, hgroup): + """Overwrite existing dataset with new data""" + # We do not append to mother_assign; raise error if already saved + data_shape = np.shape(data) + max_shape, dtype = self.datatypes[key] + if key in hgroup: + del hgroup[key] + hgroup.require_dataset( + key, shape=data_shape, dtype=dtype, compression=self.compression + ) + hgroup[key][()] = data + + def _check_key(self, key): + if key not in self.datatypes: + raise KeyError(f"No defined data type for key {key}") + + def write(self, data, overwrite: list, meta={}): + # Data is a dictionary, if not, make it one + # Overwrite data is a list + with h5py.File(self.file, "a") as store: + hgroup = 
store.require_group(self.group) + + for key, value in data.items(): + # We're only saving data that has a pre-defined data-type + self._check_key(key) + try: + if key.startswith("attrs/"): # metadata + key = key.split("/")[1] # First thing after attrs + hgroup.attrs[key] = value + elif key in overwrite: + self._overwrite(value, key, hgroup) + else: + self._append(value, key, hgroup) + except Exception as e: + print(key, value) + raise (e) + for key, value in meta.items(): + hgroup.attrs[key] = value + + return + + +##################### Special instances ##################### +class TilerWriter(DynamicWriter): + datatypes = { + "trap_locations": ((None, 2), np.uint16), + "drifts": ((None, 2), np.float32), + "attrs/tile_size": ((1,), np.uint16), + "attrs/max_size": ((1,), np.uint16), + } + group = "trap_info" + + def write(self, data, overwrite: list, tp: int, meta={}): + """ + Skips writing data if it were to overwrite it,using drift as a marker + """ + + skip = False + with h5py.File(self.file, "a") as store: + hgroup = store.require_group(self.group) + + nprev = hgroup.get("drifts", None) + if nprev and tp < nprev.shape[0]: + print(f"Tiler: Skipping timepoint {tp}") + skip = True + + if not skip: + super().write(data=data, overwrite=overwrite, meta=meta) + + +tile_size = 117 + + +@timed() +def save_complex(array, dataset): + # Dataset needs to be 2D + n = len(array) + if n > 0: + dataset.resize(dataset.shape[0] + n, axis=0) + dataset[-n:, 0] = array.real + dataset[-n:, 1] = array.imag + + +@timed() +def load_complex(dataset): + array = dataset[:, 0] + 1j * dataset[:, 1] + return array + + +class BabyWriter(DynamicWriter): + compression = "gzip" + max_ncells = 2e5 # Could just make this None + max_tps = 1e3 # Could just make this None + chunk_cells = 25 # The number of cells in a chunk for edge masks + default_tile_size = 117 + datatypes = { + "centres": ((None, 2), np.uint16), + "position": ((None,), np.uint16), + "angles": ((None,), h5py.vlen_dtype(np.float32)), + "radii": ((None,), h5py.vlen_dtype(np.float32)), + "edgemasks": ((max_ncells, max_tps, tile_size, tile_size), bool), + "ellipse_dims": ((None, 2), np.float32), + "cell_label": ((None,), np.uint16), + "trap": ((None,), np.uint16), + "timepoint": ((None,), np.uint16), + # "mother_assign": ((None,), h5py.vlen_dtype(np.uint16)), + "mother_assign_dynamic": ((None,), np.uint16), + "volumes": ((None,), np.float32), + } + group = "cell_info" + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + # Get max_tps and trap info + self._traps_initialised = False + + def __init_trap_info(self): + # Should only be run after the traps have been initialised + trap_metadata = load_attributes(self.file, "trap_info") + tile_size = trap_metadata.get("tile_size", self.default_tile_size) + max_tps = self.metadata["time_settings/ntimepoints"][0] + self.datatypes["edgemasks"] = ( + (self.max_ncells, max_tps, tile_size, tile_size), + bool, + ) + self._traps_initialised = True + + def __init_edgemasks(self, hgroup, edgemasks, current_indices, n_cells): + # Create values dataset + # This holds the edge masks directly and + # Is of shape (n_tps, n_cells, tile_size, tile_size) + key = "edgemasks" + max_shape, dtype = self.datatypes[key] + shape = (n_cells, 1) + max_shape[2:] + chunks = (self.chunk_cells, 1) + max_shape[2:] + val_dset = hgroup.create_dataset( + "values", + shape=shape, + maxshape=max_shape, + dtype=dtype, + chunks=chunks, + compression=self.compression, + ) + val_dset[:, 0] = edgemasks + # Create index dataset + # 
Holds the (trap, cell_id) description used to index into the + # values and is of shape (n_cells, 2) + ix_max_shape = (max_shape[0], 2) + ix_shape = (0, 2) + ix_dtype = np.uint16 + ix_dset = hgroup.create_dataset( + "indices", + shape=ix_shape, + maxshape=ix_max_shape, + dtype=ix_dtype, + compression=self.compression, + ) + save_complex(current_indices, ix_dset) + + def __append_edgemasks(self, hgroup, edgemasks, current_indices): + # key = "edgemasks" + val_dset = hgroup["values"] + ix_dset = hgroup["indices"] + existing_indices = load_complex(ix_dset) + # Check if there are any new labels + available = np.in1d(current_indices, existing_indices) + missing = current_indices[~available] + all_indices = np.concatenate([existing_indices, missing]) + # Resizing + t = perf_counter() + n_tps = val_dset.shape[1] + 1 + n_add_cells = len(missing) + # RESIZE DATASET FOR TIME and Cells + new_shape = (val_dset.shape[0] + n_add_cells, n_tps) + val_dset.shape[ + 2: + ] + val_dset.resize(new_shape) + logging.debug(f"Timing:resizing:{perf_counter() - t}") + # Writing data + cell_indices = np.where(np.in1d(all_indices, current_indices))[0] + for ix, mask in zip(cell_indices, edgemasks): + try: + val_dset[ix, n_tps - 1] = mask + except Exception as e: + logging.debug( + "Exception: {}:{}, {}, {}".format( + e, ix, n_tps, val_dset.shape + ) + ) + # Save the index values + save_complex(missing, ix_dset) + + def write_edgemasks(self, data, keys, hgroup): + if not self._traps_initialised: + self.__init_trap_info() + # DATA is TRAP_IDS, CELL_LABELS, EDGEMASKS in a structured array + key = "edgemasks" + val_key = "values" + # idx_key = "indices" + # Length of edgemasks + traps, cell_labels, edgemasks = data + n_cells = len(cell_labels) + hgroup = hgroup.require_group(key) + current_indices = np.array(traps) + 1j * np.array(cell_labels) + if val_key not in hgroup: + self.__init_edgemasks(hgroup, edgemasks, current_indices, n_cells) + else: + self.__append_edgemasks(hgroup, edgemasks, current_indices) + + def write(self, data, overwrite: list, tp: int = None, meta={}): + with h5py.File(self.file, "a") as store: + hgroup = store.require_group(self.group) + + for key, value in data.items(): + # We're only saving data that has a pre-defined data-type + self._check_key(key) + try: + if key.startswith("attrs/"): # metadata + key = key.split("/")[1] # First thing after attrs + hgroup.attrs[key] = value + elif key in overwrite: + self._overwrite(value, key, hgroup) + elif key == "edgemasks": + keys = ["trap", "cell_label", "edgemasks"] + value = [data[x] for x in keys] + + edgemask_dset = hgroup.get(key + "/values", None) + if ( + # tp > 0 + edgemask_dset + and tp < edgemask_dset[()].shape[1] + ): + print(f"BabyWriter: Skipping edgemasks in tp {tp}") + else: + # print(f"BabyWriter: Writing edgemasks in tp {tp}") + self.write_edgemasks(value, keys, hgroup) + else: + self._append(value, key, hgroup) + except Exception as e: + print(key, value) + raise (e) + + # Meta + for key, value in meta.items(): + hgroup.attrs[key] = value + + return + + +class LinearBabyWriter(DynamicWriter): + # TODO make this YAML + compression = "gzip" + datatypes = { + "centres": ((None, 2), np.uint16), + "position": ((None,), np.uint16), + "angles": ((None,), h5py.vlen_dtype(np.float32)), + "radii": ((None,), h5py.vlen_dtype(np.float32)), + "edgemasks": ((None, tile_size, tile_size), bool), + "ellipse_dims": ((None, 2), np.float32), + "cell_label": ((None,), np.uint16), + "trap": ((None,), np.uint16), + "timepoint": ((None,), np.uint16), + # 
"mother_assign": ((None,), h5py.vlen_dtype(np.uint16)), + "mother_assign_dynamic": ((None,), np.uint16), + "volumes": ((None,), np.float32), + } + group = "cell_info" + + def write(self, data, overwrite: list, tp=None, meta={}): + # Data is a dictionary, if not, make it one + # Overwrite data is a list + + with h5py.File(self.file, "a") as store: + hgroup = store.require_group(self.group) + available_tps = hgroup.get("timepoint", None) + if not available_tps or tp not in np.unique(available_tps[()]): + super().write(data, overwrite) + else: + print(f"BabyWriter: Skipping tp {tp}") + + for key, value in meta.items(): + hgroup.attrs[key] = value + + +class StateWriter(DynamicWriter): + datatypes = { + "max_lbl": ((None, 1), np.uint16), + "tp_back": ((None, 1), np.uint16), + "trap": ((None, 1), np.int16), + "cell_lbls": ((None, 1), np.uint16), + "prev_feats": ((None, None), np.float32), + "lifetime": ((None, 2), np.uint16), + "p_was_bud": ((None, 2), np.float32), + "p_is_mother": ((None, 2), np.float32), + "ba_cum": ((None, None), np.float32), + } + group = "last_state" + compression = "gzip" + + @staticmethod + def format_field(states: list, field: str): + # Flatten a field in the states list to save as an hdf5 dataset + fields = [pos_state[field] for pos_state in states] + return fields + + @staticmethod + def format_values_tpback(states: list, val_name: str): + tp_back, trap, value = [ + [[] for _ in states[0][val_name]] for _ in range(3) + ] + + lbl_tuples = [ + (tp_back, trap, cell_label) + for trap, state in enumerate(states) + for tp_back, value in enumerate(state[val_name]) + for cell_label in value + ] + if len(lbl_tuples): + tp_back, trap, value = zip(*lbl_tuples) + + return tp_back, trap, value + + @staticmethod + def format_values_traps(states: list, val_name: str): + formatted = np.array( + [ + (trap, clabel_val) + for trap, state in enumerate(states) + for clabel_val in state[val_name] + ] + ) + return formatted + + @staticmethod + def pad_if_needed(array: np.ndarray, pad_size: int): + padded = np.zeros((pad_size, pad_size)).astype(float) + length = len(array) + padded[:length, :length] = array + + return padded + + def format_states(self, states: list): + formatted_state = {"max_lbl": [state["max_lbl"] for state in states]} + tp_back, trap, cell_label = self.format_values_tpback( + states, "cell_lbls" + ) + _, _, prev_feats = self.format_values_tpback(states, "prev_feats") + + # Heterogeneous datasets + formatted_state["tp_back"] = tp_back + formatted_state["trap"] = trap + formatted_state["cell_lbls"] = cell_label + formatted_state["prev_feats"] = np.array(prev_feats) + + # One entry per cell label - tp_back independent + for val_name in ("lifetime", "p_was_bud", "p_is_mother"): + formatted_state[val_name] = self.format_values_traps( + states, val_name + ) + + bacum_max = max([len(state["ba_cum"]) for state in states]) + + formatted_state["ba_cum"] = np.array( + [ + self.pad_if_needed(state["ba_cum"], bacum_max) + for state in states + ] + ) + + return formatted_state + + def write(self, data, overwrite: Iterable, tp: int = None): + # formatted_data = self.format_states(data) + # super().write(data=formatted_data, overwrite=overwrite) + if len(data): + last_tp = 0 + if tp is None: + tp = 0 + + try: + with h5py.File(self.file, "r") as f: + gr = f.get(self.group, None) + if gr: + last_tp = gr.attrs.get("tp", 0) + + # print(f"{ self.file } - tp: {tp}, last_tp: {last_tp}") + if tp == 0 or tp > last_tp: + # print(f"Writing timepoint {tp}") + formatted_data = 
self.format_states(data) + super().write(data=formatted_data, overwrite=overwrite) + with h5py.File(self.file, "a") as f: + # print(f"Writing tp {tp}") + f[self.group].attrs["tp"] = tp + elif tp > 0 and tp <= last_tp: + print(f"BabyWriter: Skipping timepoint {tp}") + except Exception as e: + raise (e) + else: + print("Skipping overwriting empty state") + + +#################### Extraction version ############################### +class Writer(BridgeH5): + """ + Class in charge of transforming data into compatible formats + + Decoupling interface from implementation! + + Parameters + ---------- + filename: str Name of file to write into + flag: str, default=None + Flag to pass to the default file reader. If None the file remains closed. + compression: str, default=None + Compression method passed on to h5py writing functions (only used for + dataframes and other array-like data.) + """ + + def __init__(self, filename, compression=None): + super().__init__(filename, flag=None) + + if compression is None: + self.compression = "gzip" + + def write( + self, + path: str, + data: Iterable = None, + meta: Dict = {}, + overwrite: str = None, + ): + """ + Parameters + ---------- + path : str + Path inside h5 file to write into. + data : Iterable, default = None + meta : Dict, default = {} + + """ + self.id_cache = {} + with h5py.File(self.filename, "a") as f: + if overwrite == "overwrite": # TODO refactor overwriting + if path in f: + del f[path] + # elif overwrite == "accumulate": # Add a number if needed + # if path in f: + # parent, name = path.rsplit("/", maxsplit=1) + # n = sum([x.startswith(name) for x in f[path]]) + # path = path + str(n).zfill(3) + # elif overwrite == "skip": + # if path in f: + # logging.debug("Skipping dataset {}".format(path)) + + logging.debug( + "{} {} to {} and {} metadata fields".format( + overwrite, type(data), path, len(meta) + ) + ) + if data is not None: + self.write_dset(f, path, data) + if meta: + for attr, metadata in meta.items(): + self.write_meta(f, path, attr, data=metadata) + + def write_dset(self, f: h5py.File, path: str, data: Iterable): + if isinstance(data, pd.DataFrame): + self.write_pd(f, path, data, compression=self.compression) + elif isinstance(data, pd.MultiIndex): + self.write_index(f, path, data) # , compression=self.compression) + elif isinstance(data, Dict) and np.all( + [isinstance(x, pd.DataFrame) for x in data.values] + ): + for k, df in data.items(): + self.write_dset(f, path + f"/{k}", df) + elif isinstance(data, Iterable): + self.write_arraylike(f, path, data) + else: + self.write_atomic(data, f, path) + + def write_meta(self, f: h5py.File, path: str, attr: str, data: Iterable): + obj = f.require_group(path) + + obj.attrs[attr] = data + + @staticmethod + def write_arraylike(f: h5py.File, path: str, data: Iterable, **kwargs): + if path in f: + del f[path] + + narray = np.array(data) + + chunks = None + if narray.any(): + chunks = (1, *narray.shape[1:]) + + dset = f.create_dataset( + path, + shape=narray.shape, + chunks=chunks, + dtype="int", + compression=kwargs.get("compression", None), + ) + dset[()] = narray + + @staticmethod + def write_index(f, path, pd_index, **kwargs): + f.require_group(path) # TODO check if we can remove this + for i, name in enumerate(pd_index.names): + ids = pd_index.get_level_values(i) + id_path = path + "/" + name + f.create_dataset( + name=id_path, + shape=(len(ids),), + dtype="uint16", + compression=kwargs.get("compression", None), + ) + indices = f[id_path] + indices[()] = ids + + def write_pd(self, f, 
+
+    def write_pd(self, f, path, df, **kwargs):
+        values_path = (
+            path + "values" if path.endswith("/") else path + "/values"
+        )
+        if path not in f:
+            max_ncells = int(2e5)
+
+            max_tps = int(1e3)
+            f.create_dataset(
+                name=values_path,
+                shape=df.shape,
+                # chunks=(min(df.shape[0], 1), df.shape[1]),
+                # dtype=df.dtypes.iloc[0],  # made NaNs in int columns negative
+                dtype="float",
+                maxshape=(max_ncells, max_tps),
+                compression=kwargs.get("compression", None),
+            )
+            dset = f[values_path]
+            dset[()] = df.values
+
+            for name in df.index.names:
+                indices_path = "/".join((path, name))
+                f.create_dataset(
+                    name=indices_path,
+                    shape=(len(df),),
+                    dtype="uint16",  # Assuming we'll always use int indices
+                    chunks=True,
+                    maxshape=(max_ncells,),
+                )
+                dset = f[indices_path]
+                dset[()] = df.index.get_level_values(level=name).tolist()
+
+            if (
+                np.issubdtype(df.columns.dtype, np.integer)
+                or df.columns.name == "timepoint"
+            ):
+                tp_path = path + "/timepoint"
+                f.create_dataset(
+                    name=tp_path,
+                    shape=(df.shape[1],),
+                    maxshape=(max_tps,),
+                    dtype="uint16",
+                )
+                tps = list(range(df.shape[1]))
+                f[tp_path][tps] = tps
+            else:
+                f[path].attrs["columns"] = df.columns.tolist()
+        else:
+            dset = f[values_path]
+
+            # Filter out repeated timepoints
+            new_tps = set(df.columns)
+            if path + "/timepoint" in f:
+                new_tps = new_tps.difference(f[path + "/timepoint"][()])
+            df = df[sorted(new_tps)]
+
+            if (
+                not hasattr(self, "id_cache")
+                or df.index.nlevels not in self.id_cache
+            ):  # Use cache dict to store previously-obtained indices
+                self.id_cache[df.index.nlevels] = {}
+                existing_ids = self.get_existing_ids(
+                    f, [path + "/" + x for x in df.index.names]
+                )
+                # Split indices in existing and additional
+                new = df.index.tolist()
+                if (
+                    df.index.nlevels == 1
+                ):  # Cover for cases with a single index
+                    new = [(x,) for x in df.index.tolist()]
+                (
+                    found_multis,
+                    self.id_cache[df.index.nlevels]["additional_multis"],
+                ) = self.find_ids(
+                    existing=existing_ids,
+                    new=new,
+                )
+                found_indices = np.array(
+                    locate_indices(existing_ids, found_multis)
+                )
+
+                # We must sort our indices for h5py indexing
+                incremental_existing = np.argsort(found_indices)
+                self.id_cache[df.index.nlevels][
+                    "found_indices"
+                ] = found_indices[incremental_existing]
+                self.id_cache[df.index.nlevels]["found_multi"] = found_multis[
+                    incremental_existing
+                ]
+
+            existing_values = df.loc[
+                [
+                    _tuple_or_int(x)
+                    for x in self.id_cache[df.index.nlevels]["found_multi"]
+                ]
+            ].values
+            new_values = df.loc[
+                [
+                    _tuple_or_int(x)
+                    for x in self.id_cache[df.index.nlevels][
+                        "additional_multis"
+                    ]
+                ]
+            ].values
+            ncells, ntps = f[values_path].shape
+
+            # Add found cells
+            dset.resize(dset.shape[1] + df.shape[1], axis=1)
+            dset[:, ntps:] = np.nan
+            for i, tp in enumerate(df.columns):
+                dset[
+                    self.id_cache[df.index.nlevels]["found_indices"], tp
+                ] = existing_values[:, i]
+            # Add new cells
+            n_newcells = len(
+                self.id_cache[df.index.nlevels]["additional_multis"]
+            )
+            dset.resize(dset.shape[0] + n_newcells, axis=0)
+            dset[ncells:, :] = np.nan
+
+            for i, tp in enumerate(df.columns):
+                dset[ncells:, tp] = new_values[:, i]
+
+            # save indices
+            for i, name in enumerate(df.index.names):
+                tmp = path + "/" + name
+                dset = f[tmp]
+                n = dset.shape[0]
+                dset.resize(n + n_newcells, axis=0)
+                dset[n:] = (
+                    self.id_cache[df.index.nlevels]["additional_multis"][:, i]
+                    if len(
+                        self.id_cache[df.index.nlevels][
+                            "additional_multis"
+                        ].shape
+                    )
+                    > 1
+                    else self.id_cache[df.index.nlevels]["additional_multis"]
+                )
+
+            tmp = path + "/timepoint"
+            dset = f[tmp]
+            n = dset.shape[0]
+            dset.resize(n + df.shape[1], axis=0)
+            dset[n:] = df.columns.tolist()
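+
+    # Worked example (hypothetical values) of the ID bookkeeping used by
+    # write_pd above, via find_ids and locate_indices defined below:
+    #
+    #     existing = np.array([[0, 1], [0, 2]])  # (trap, cell_label) on file
+    #     new = [(0, 2), (1, 5)]                 # labels arriving in this call
+    #     Writer.find_ids(existing, new)   # -> (array([[0, 2]]), array([[1, 5]]))
+    #     locate_indices(existing, np.array([[0, 2]]))  # -> [1]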
+
+    @staticmethod
+    def get_existing_ids(f, paths):
+        # Fetch indices and convert them to a (nentries, nlevels) ndarray
+        return np.array([f[path][()] for path in paths]).T
+
+    @staticmethod
+    def find_ids(existing, new):
+        # Compare two sets of tuples and return the intersection and the
+        # difference (elements of 'new' that are not in 'existing')
+        set_existing = set(map(tuple, existing.tolist()))
+        existing_cells = np.array(list(set_existing.intersection(new)))
+        new_cells = np.array(list(set(new).difference(set_existing)))
+
+        return (
+            existing_cells,
+            new_cells,
+        )
+
+
+def locate_indices(existing, new):
+    if new.any():
+        if new.shape[1] > 1:
+            return [
+                find_1st(
+                    (existing[:, 0] == n[0]) & (existing[:, 1] == n[1]),
+                    True,
+                    cmp_equal,
+                )
+                for n in new
+            ]
+        else:
+            return [
+                find_1st(existing[:, 0] == n, True, cmp_equal) for n in new
+            ]
+    else:
+        return []
+
+
+def _tuple_or_int(x):
+    # Unwrap single-element tuples
+    if len(x) == 1:
+        return x[0]
+    else:
+        return x
diff --git a/utils/__init__.py b/utils/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/utils/example.py b/utils/example.py
new file mode 100644
index 0000000000000000000000000000000000000000..e3ff571acc0cb780e2feb92007111dbe041a91fb
--- /dev/null
+++ b/utils/example.py
@@ -0,0 +1,53 @@
+"""This is an example module to show the structure."""
+from typing import Union
+
+
+class ExampleClass:
+    """This is an example class to show the structure."""
+
+    def __init__(self, parameter: int):
+        """This class takes one parameter and is used to add one to that
+        parameter.
+
+        :param parameter: The parameter for this class
+        """
+        self.parameter = parameter
+
+    def add_one(self):
+        """Takes the parameter and adds one.
+
+        >>> x = ExampleClass(1)
+        >>> x.add_one()
+        2
+
+        :return: the parameter + 1
+        """
+        return self.parameter + 1
+
+    def add_n(self, n: int):
+        """Adds n to the class instance's parameter.
+
+        For instance:
+        >>> x = ExampleClass(1)
+        >>> x.add_n(10)
+        11
+
+        :param n: The number to add
+        :return: the parameter + n
+        """
+        return self.parameter + n
+
+
+def example_function(parameter: Union[int, str]):
+    """This is a factory function for an ExampleClass.
+
+    :param parameter: the parameter to give to the example class
+    :return: An example class
+    """
+    try:
+        return ExampleClass(int(parameter))
+    except ValueError as e:
+        raise ValueError(
+            f"The parameter {parameter} could not be turned "
+            "into an integer."
+        ) from e
diff --git a/utils/lineage.py b/utils/lineage.py
new file mode 100644
index 0000000000000000000000000000000000000000..b72c69cda3243893f8ae7dd31521c280e342b456
--- /dev/null
+++ b/utils/lineage.py
@@ -0,0 +1,27 @@
+#!/usr/bin/env python3
+import numpy as np
+import pandas as pd
+
+from agora.io.bridge import groupsort
+
+
+def mb_array_to_dict(mb_array: np.ndarray):
+    """
+    Convert a lineage ndarray (trap, mother_id, daughter_id) into a
+    dictionary mapping (trap, mother_id) -> [(trap, daughter_id), ...].
+    """
+    return {
+        (trap, mo): [(trap, d[0]) for d in daughters]
+        for trap, mo_da in groupsort(mb_array).items()
+        for mo, daughters in groupsort(mo_da).items()
+    }
+
+
+def mb_array_to_indices(mb_array: np.ndarray):
+    """
+    Convert a lineage ndarray (trap, mother_id, daughter_id) into a pandas
+    MultiIndex of (trap, cell_id) pairs covering both mothers and daughters.
+    """
+    return pd.MultiIndex.from_arrays(mb_array[:, :2].T).union(
+        pd.MultiIndex.from_arrays(mb_array[:, [0, 2]].T)
+    )
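+
+
+# Worked example with a hypothetical lineage array, assuming groupsort
+# groups rows by their first column and drops it, as its use above implies:
+#
+#     mb_array = np.array([[3, 1, 2], [3, 1, 4]])  # (trap, mother, daughter)
+#     mb_array_to_dict(mb_array)
+#     # -> {(3, 1): [(3, 2), (3, 4)]}
+#     mb_array_to_indices(mb_array)
+#     # -> MultiIndex over {(3, 1), (3, 2), (3, 4)}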