diff --git a/aliby/__init__.py b/aliby/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/aliby/baby_client.py b/aliby/baby_client.py new file mode 100644 index 0000000000000000000000000000000000000000..4dd610796e0fe73ea52d7bf96e82ca94336cc2d1 --- /dev/null +++ b/aliby/baby_client.py @@ -0,0 +1,268 @@ +import collections +import itertools +import json +import time +from pathlib import Path +from typing import Iterable + +import h5py +import numpy as np +import pandas as pd +import re +import requests +import tensorflow as tf +from tqdm import tqdm + +from agora.base import ParametersABC, ProcessABC +import baby.errors +from baby import modelsets +from baby.brain import BabyBrain +from baby.crawler import BabyCrawler +from requests.exceptions import Timeout, HTTPError +from requests_toolbelt.multipart.encoder import MultipartEncoder + +from aliby.utils import Cache, accumulate, get_store_path + + +################### Dask Methods ################################ +def format_segmentation(segmentation, tp): + """Format a single timepoint into a dictionary. + + Parameters + ------------ + segmentation: list + A list of results, each result is the output of the crawler, which is JSON-encoded + tp: int + the time point considered + + Returns + -------- + A dictionary containing the formatted results of BABY + """ + # Segmentation is a list of dictionaries, ordered by trap + # Add trap information + # mother_assign = None + for i, x in enumerate(segmentation): + x["trap"] = [i] * len(x["cell_label"]) + x["mother_assign_dynamic"] = np.array(x["mother_assign"])[ + np.array(x["cell_label"], dtype=int) - 1 + ] + # Merge into a dictionary of lists, by column + merged = { + k: list(itertools.chain.from_iterable(res[k] for res in segmentation)) + for k in segmentation[0].keys() + } + # Special case for mother_assign + # merged["mother_assign_dynamic"] = [merged["mother_assign"]] + if "mother_assign" in merged: + del merged["mother_assign"] + mother_assign = [x["mother_assign"] for x in segmentation] + # Check that the lists are all of the same length (in case of errors in + # BABY) + n_cells = min([len(v) for v in merged.values()]) + merged = {k: v[:n_cells] for k, v in merged.items()} + merged["timepoint"] = [tp] * n_cells + merged["mother_assign"] = mother_assign + return merged + + +class BabyParameters(ParametersABC): + def __init__( + self, + model_config, + tracker_params, + clogging_thresh, + min_bud_tps, + isbud_thresh, + session, + graph, + print_info, + suppress_errors, + error_dump_dir, + tf_version, + ): + self.model_config = model_config + self.tracker_params = tracker_params + self.clogging_thresh = clogging_thresh + self.min_bud_tps = min_bud_tps + self.isbud_thresh = isbud_thresh + self.session = session + self.graph = graph + self.print_info = print_info + self.suppress_errors = suppress_errors + self.error_dump_dir = error_dump_dir + self.tf_version = tf_version + + @classmethod + def default(cls, **kwargs): + """kwargs passes values to the model chooser""" + return cls( + model_config=choose_model_from_params(**kwargs), + tracker_params=dict(ctrack_params=dict(), budtrack_params=dict()), + clogging_thresh=1, + min_bud_tps=3, + isbud_thresh=0.5, + session=None, + graph=None, + print_info=False, + suppress_errors=False, + error_dump_dir=None, + tf_version=2, + ) + + +class BabyRunner: + """A BabyRunner object for cell segmentation. 
+ + Does segmentation one time point at a time.""" + + def __init__(self, tiler, parameters=None, *args, **kwargs): + self.tiler = tiler + # self.model_config = modelsets()[choose_model_from_params(**kwargs)] + self.model_config = modelsets()[ + ( + parameters.model_config + if parameters is not None + else choose_model_from_params(**kwargs) + ) + ] + self.brain = BabyBrain(**self.model_config) + self.crawler = BabyCrawler(self.brain) + self.bf_channel = self.tiler.get_channel_index("Brightfield") + + @classmethod + def from_tiler(cls, parameters: BabyParameters, tiler): + return cls(tiler, parameters) + + def get_data(self, tp): + # Swap axes x and z, probably shouldn't swap, just move z + return self.tiler.get_tp_data(tp, self.bf_channel).swapaxes(1, 3).swapaxes(1, 2) + + def run_tp(self, tp, with_edgemasks=True, assign_mothers=True, **kwargs): + """Simulating processing time with sleep""" + # Access the image + img = self.get_data(tp) + segmentation = self.crawler.step( + img, with_edgemasks=with_edgemasks, assign_mothers=assign_mothers, **kwargs + ) + return format_segmentation(segmentation, tp) + + +class BabyClient: + """A dummy BabyClient object for Dask Demo. + + + Does segmentation one time point at a time. + Should work better with the parallelisation. + """ + + bf_channel = 0 + model_name = "prime95b_brightfield_60x_5z" + url = "http://localhost:5101" + max_tries = 50 + sleep_time = 0.1 + + def __init__(self, tiler, *args, **kwargs): + self.tiler = tiler + self._session = None + + @property + def session(self): + if self._session is None: + r_session = requests.get(self.url + f"/session/{self.model_name}") + r_session.raise_for_status() + self._session = r_session.json()["sessionid"] + return self._session + + def get_data(self, tp): + return self.tiler.get_tp_data(tp, self.bf_channel).swapaxes(1, 3) + + def queue_image(self, img, **kwargs): + bit_depth = img.dtype.itemsize * 8 # bit depth = byte_size * 8 + data = create_request(img.shape, bit_depth, img, **kwargs) + status = requests.post( + self.url + f"/segment?sessionid={self.session}", + data=data, + headers={"Content-Type": data.content_type}, + ) + status.raise_for_status() + return status + + def get_segmentation(self): + try: + seg_response = requests.get( + self.url + f"/segment?sessionid={self.session}", timeout=120 + ) + seg_response.raise_for_status() + result = seg_response.json() + except Timeout as e: + raise e + except HTTPError as e: + raise e + return result + + def run_tp(self, tp, **kwargs): + # Get data + img = self.get_data(tp) + # Queue image + status = self.queue_image(img, **kwargs) + # Get segmentation + for _ in range(self.max_tries): + try: + seg = self.get_segmentation() + break + except (Timeout, HTTPError): + time.sleep(self.sleep_time) + continue + return format_segmentation(seg, tp) + + +def choose_model_from_params( + modelset_filter=None, + camera="prime95b", + channel="brightfield", + zoom="60x", + n_stacks="5z", + **kwargs, +): + """ + Define which model to query from the server based on a set of parameters. + + Parameters + ---------- + valid_models: List[str] + The names of the models that are available. + modelset_filter: str + A regex filter to apply on the models to start. + camera: str + The camera used in the experiment (case insensitive). + channel:str + The channel used for segmentation (case insensitive). + zoom: str + The zoom on the channel. 
+ n_stacks: str + The number of z_stacks to use in segmentation + + Returns + ------- + model_name : str + """ + valid_models = list(modelsets().keys()) + + # Apply modelset filter if specified + if modelset_filter is not None: + msf_regex = re.compile(modelset_filter) + valid_models = filter(msf_regex.search, valid_models) + + # Apply parameter filters if specified + params = [ + str(x) if x is not None else ".+" + for x in [camera.lower(), channel.lower(), zoom, n_stacks] + ] + params_re = re.compile("^" + "_".join(params) + "$") + valid_models = list(filter(params_re.search, valid_models)) + # Check that there are valid models + if len(valid_models) == 0: + raise KeyError("No model sets found matching {}".format(", ".join(params))) + # Pick the first model + return valid_models[0] diff --git a/aliby/cells.py b/aliby/cells.py new file mode 100644 index 0000000000000000000000000000000000000000..463e232c5b70808b13571039d269f23879cef62c --- /dev/null +++ b/aliby/cells.py @@ -0,0 +1,325 @@ +import logging +from pathlib import Path, PosixPath +from time import perf_counter +from typing import Union +from itertools import groupby +from collections.abc import Iterable + +from utils_find_1st import find_1st, cmp_equal +import h5py +import numpy as np +from scipy import ndimage +from scipy.sparse.base import isdense + +from aliby.io.matlab import matObject +from aliby.utils import timed +from aliby.io.writer import load_complex + + +def cell_factory(store, type="matlab"): + if isinstance(store, matObject): + return CellsMat(store) + if type == "matlab": + mat_object = matObject(store) + return CellsMat(mat_object) + elif type == "hdf5": + return CellsHDF(store) + else: + raise TypeError( + "Could not get cells for type {}:" "valid types are matlab and hdf5" + ) + + +class Cells: + """An object that gathers information about all the cells in a given + trap. 
+ This is the abstract object, used for type testing + """ + + def __init__(self): + pass + + @staticmethod + def from_source(source: Union[PosixPath, str], kind: str = None): + if isinstance(source, str): + source = Path(source) + if kind is None: # Infer kind from filename + kind = "matlab" if source.suffix == ".mat" else "hdf5" + return cell_factory(source, kind) + + @staticmethod + def _asdense(array): + if not isdense(array): + array = array.todense() + return array + + @staticmethod + def _astype(array, kind): + # Convert sparse arrays if needed and if kind is 'mask' it fills the outline + array = Cells._asdense(array) + if kind == "mask": + array = ndimage.binary_fill_holes(array).astype(int) + return array + + @classmethod + def hdf(cls, fpath): + return CellsHDF(fpath) + + @classmethod + def mat(cls, path): + return CellsMat(matObject(store)) + + +class CellsHDF(Cells): + def __init__(self, filename, path="cell_info"): + self.filename = filename + self.cinfo_path = path + self._edgem_indices = None + self._edgemasks = None + self._tile_size = None + + def __getitem__(self, item): + if item == "edgemasks": + return self.edgemasks + _item = "_" + item + if not hasattr(self, _item): + setattr(self, _item, self._fetch(item)) + return getattr(self, _item) + + def _get_idx(self, cell_id, trap_id): + return (self["cell_label"] == cell_id) & (self["trap"] == trap_id) + + def _fetch(self, path): + with h5py.File(self.filename, mode="r") as f: + return f[self.cinfo_path][path][()] + + @property + def ntraps(self): + with h5py.File(self.filename, mode="r") as f: + return len(f["/trap_info/trap_locations"][()]) + + @property + def traps(self): + return list(set(self["trap"])) + + @property + def tile_size(self): # TODO read from metadata + if self._tile_size is None: + with h5py.File(self.filename, mode="r") as f: + self._tile_size == f["trap_info/tile_size"][0] + return self._tile_size + + @property + def edgem_indices(self): + if self._edgem_indices is None: + edgem_path = "edgemasks/indices" + self._edgem_indices = load_complex(self._fetch(edgem_path)) + return self._edgem_indices + + @property + def edgemasks(self): + if self._edgemasks is None: + edgem_path = "edgemasks/values" + self._edgemasks = self._fetch(edgem_path) + + return self._edgemasks + + def _edgem_where(self, cell_id, trap_id): + ix = trap_id + 1j * cell_id + return find_1st(self.edgem_indices == ix, True, cmp_equal) + + @property + def labels(self): + """ + Return all cell labels in object + We use mother_assign to list traps because it is the only propriety that appears even + when no cells are found""" + return [self.labels_in_trap(trap) for trap in self.traps] + + def where(self, cell_id, trap_id): + """ + Returns + Parameters + ---------- + cell_id: int + Cell index + trap_id: int + Trap index + + Returns + ---------- + indices int array + boolean mask array + edge_ix int array + """ + indices = self._get_idx(cell_id, trap_id) + edgem_ix = self._edgem_where(cell_id, trap_id) + return ( + self["timepoint"][indices], + indices, + edgem_ix, + ) # FIXME edgem_ix makes output different to matlab's Cell + + def outline(self, cell_id, trap_id): + times, indices, cell_ix = self.where(cell_id, trap_id) + return times, self["edgemasks"][cell_ix, times] + + def mask(self, cell_id, trap_id): + times, outlines = self.outline(cell_id, trap_id) + return times, np.array( + [ndimage.morphology.binary_fill_holes(o) for o in outlines] + ) + + def at_time(self, timepoint, kind="mask"): + ix = self["timepoint"] == timepoint + cell_ix = 
self["cell_label"][ix] + traps = self["trap"][ix] + indices = traps + 1j * cell_ix + choose = np.in1d(self.edgem_indices, indices) + edgemasks = self["edgemasks"][choose, timepoint] + masks = [ + self._astype(edgemask, kind) for edgemask in edgemasks if edgemask.any() + ] + return self.group_by_traps(traps, masks) + + def group_by_traps(self, traps, data): + # returns a dict with traps as keys and labels as value + iterator = groupby(zip(traps, data), lambda x: x[0]) + d = {key: [x[1] for x in group] for key, group in iterator} + d = {i: d.get(i, []) for i in self.traps} + return d + + def labels_in_trap(self, trap_id): + # Return set of cell ids in a trap. + return set((self["cell_label"][self["trap"] == trap_id])) + + def labels_at_time(self, timepoint): + labels = self["cell_label"][self["timepoint"] == timepoint] + traps = self["trap"][self["timepoint"] == timepoint] + return self.group_by_traps(traps, labels) + + +class CellsMat(Cells): + def __init__(self, mat_object): + super(CellsMat, self).__init__() + # TODO add __contains__ to the matObject + timelapse_traps = mat_object.get( + "timelapseTrapsOmero", mat_object.get("timelapseTraps", None) + ) + if timelapse_traps is None: + raise NotImplementedError( + "Could not find a timelapseTraps or " + "timelapseTrapsOmero object. Cells " + "from cellResults not implemented" + ) + else: + self.trap_info = timelapse_traps["cTimepoint"]["trapInfo"] + + if isinstance(self.trap_info, list): + self.trap_info = { + k: list([res.get(k, []) for res in self.trap_info]) + for k in self.trap_info[0].keys() + } + + def where(self, cell_id, trap_id): + times, indices = zip( + *[ + (tp, np.where(cell_id == x)[0][0]) + for tp, x in enumerate(self.trap_info["cellLabel"][:, trap_id].tolist()) + if np.any(cell_id == x) + ] + ) + return times, indices + + def outline(self, cell_id, trap_id): + times, indices = self.where(cell_id, trap_id) + info = self.trap_info["cell"][times, trap_id] + + def get_segmented(cell, index): + if cell["segmented"].ndim == 0: + return cell["segmented"][()].todense() + else: + return cell["segmented"][index].todense() + + segmentation_outline = [ + get_segmented(cell, idx) for idx, cell in zip(indices, info) + ] + return times, np.array(segmentation_outline) + + def mask(self, cell_id, trap_id): + times, outlines = self.outline(cell_id, trap_id) + return times, np.array( + [ndimage.morphology.binary_fill_holes(o) for o in outlines] + ) + + def at_time(self, timepoint, kind="outline"): + + """Returns the segmentations for all the cells at a given timepoint. + + FIXME: this is extremely hacky and accounts for differently saved + results in the matlab object. Deprecate ASAP. 
+ """ + # Case 1: only one cell per trap: trap_info['cell'][timepoint] is a + # structured array + if isinstance(self.trap_info["cell"][timepoint], dict): + segmentations = [ + self._astype(x, "outline") + for x in self.trap_info["cell"][timepoint]["segmented"] + ] + # Case 2: Multiple cells per trap: it becomes a list of arrays or + # dictionaries, one for each trap + # Case 2.1 : it's a dictionary + elif isinstance(self.trap_info["cell"][timepoint][0], dict): + segmentations = [] + for x in self.trap_info["cell"][timepoint]: + seg = x["segmented"] + if not isinstance(seg, np.ndarray): + seg = [seg] + segmentations.append([self._astype(y, "outline") for y in seg]) + # Case 2.2 : it's an array + else: + segmentations = [ + [self._astype(y, type) for y in x["segmented"]] if x.ndim != 0 else [] + for x in self.trap_info["cell"][timepoint] + ] + # Return dict for compatibility with hdf5 output + return {i: v for i, v in enumerate(segmentations)} + + def labels_at_time(self, tp): + labels = self.trap_info["cellLabel"] + labels = [_aslist(x) for x in labels[tp]] + labels = {i: [lbl for lbl in lblset] for i, lblset in enumerate(labels)} + return labels + + @property + def ntraps(self): + return len(self.trap_info["cellLabel"][0]) + + @property + def tile_size(self): + pass + + +class ExtractionRunner: + """An object to run extraction of fluorescence, and general data out of + segmented data. + + Configure with what extraction we want to run. + Cell selection criteria. + Filtering criteria. + """ + + def __init__(self, tiler, cells): + pass + + def run(self, keys, store, **kwargs): + pass + + +def _aslist(x): + if isinstance(x, Iterable): + if hasattr(x, "tolist"): + x = x.tolist() + else: + x = [x] + return x diff --git a/aliby/core.py b/aliby/core.py new file mode 100644 index 0000000000000000000000000000000000000000..56422434475f9bd42e2d891f3a32ba3295d14f75 --- /dev/null +++ b/aliby/core.py @@ -0,0 +1,34 @@ +"""Barebones implementation of the structure/organisation of experiments.""" + + +class Experiment: + def __init__(self): + self.strains = dict() + self._metadata = None + + def add_strains(self, name, strain): + self.strains[name] = strain + + +class Strain: + def __init__(self): + self.positions = dict() + + def add_position(self, name, position): + self.positions[name] = position + + +class Position: + def __init__(self): + self.traps = [] + + def add_trap(self, trap): + self.traps.append(trap) + + +class Trap: # TODO Name this Tile? 
+ def __init__(self): + self.cells = [] + + def add_cell(self, cell): + self.cells.append(cell) diff --git a/aliby/experiment.py b/aliby/experiment.py new file mode 100644 index 0000000000000000000000000000000000000000..9919f40feffbe22a1598be6eb563514ad151cfe3 --- /dev/null +++ b/aliby/experiment.py @@ -0,0 +1,499 @@ +"""Core classes for the pipeline""" +import atexit +import itertools +import os +import abc +import glob +import json +import warnings +from getpass import getpass +from pathlib import Path +import re +import logging +from typing import Union + +import h5py +from tqdm import tqdm +import pandas as pd + +import omero +from omero.gateway import BlitzGateway +from logfile_parser import Parser + +from aliby.timelapse import TimelapseOMERO, TimelapseLocal +from aliby.utils import accumulate + +from aliby.io.writer import Writer + +logger = logging.getLogger(__name__) + +########################### Dask objects ################################### +##################### ENVIRONMENT INITIALISATION ################ +import omero +from omero.gateway import BlitzGateway, PixelsWrapper +from omero.model import enums as omero_enums +import numpy as np + +# Set up the pixels so that we can reuse them across sessions (?) +PIXEL_TYPES = { + omero_enums.PixelsTypeint8: np.int8, + omero_enums.PixelsTypeuint8: np.uint8, + omero_enums.PixelsTypeint16: np.int16, + omero_enums.PixelsTypeuint16: np.uint16, + omero_enums.PixelsTypeint32: np.int32, + omero_enums.PixelsTypeuint32: np.uint32, + omero_enums.PixelsTypefloat: np.float32, + omero_enums.PixelsTypedouble: np.float64, +} + + +class NonCachedPixelsWrapper(PixelsWrapper): + """Extend gateway.PixelWrapper to override _prepareRawPixelsStore.""" + + def _prepareRawPixelsStore(self): + """ + Creates RawPixelsStore and sets the id etc + This overrides the superclass behaviour to make sure that + we don't re-use RawPixelStore in multiple processes since + the Store may be closed in 1 process while still needed elsewhere. + This is needed when napari requests may planes simultaneously, + e.g. when switching to 3D view. 
+ """ + ps = self._conn.c.sf.createRawPixelsStore() + ps.setPixelsId(self._obj.id.val, True, self._conn.SERVICE_OPTS) + return ps + + +omero.gateway.PixelsWrapper = NonCachedPixelsWrapper +# Update the BlitzGateway to use our NonCachedPixelsWrapper +omero.gateway.refreshWrappers() + + +###################### DATA ACCESS ################### +import dask.array as da +from dask import delayed + + +def get_data_lazy(image) -> da.Array: + """Get 5D dask array, with delayed reading from OMERO image.""" + nt, nc, nz, ny, nx = [getattr(image, f"getSize{x}")() for x in "TCZYX"] + pixels = image.getPrimaryPixels() + dtype = PIXEL_TYPES.get(pixels.getPixelsType().value, None) + get_plane = delayed(lambda idx: pixels.getPlane(*idx)) + + def get_lazy_plane(zct): + return da.from_delayed(get_plane(zct), shape=(ny, nx), dtype=dtype) + + # 5D stack: TCZXY + t_stacks = [] + for t in range(nt): + c_stacks = [] + for c in range(nc): + z_stack = [] + for z in range(nz): + z_stack.append(get_lazy_plane((z, c, t))) + c_stacks.append(da.stack(z_stack)) + t_stacks.append(da.stack(c_stacks)) + return da.stack(t_stacks) + + +# Metadata writer +from aliby.io.metadata_parser import parse_logfiles + + +class MetaData: + """Small metadata Process that loads log.""" + + def __init__(self, log_dir, store): + self.log_dir = log_dir + self.store = store + + def load_logs(self): + parsed_flattened = parse_logfiles(self.log_dir) + return parsed_flattened + + def run(self): + metadata_writer = Writer(self.store) + metadata_dict = self.load_logs() + metadata_writer.write(path="/", meta=metadata_dict, overwrite=False) + + +########################### Old Objects #################################### + + +class Experiment(abc.ABC): + """ + Abstract base class for experiments. + Gives all the functions that need to be implemented in both the local + version and the Omero version of the Experiment class. + + As this is an abstract class, experiments can not be directly instantiated + through the usual `__init__` function, but must be instantiated from a + source. + >>> expt = Experiment.from_source(root_directory) + Data from the current timelapse can be obtained from the experiment using + colon and comma separated slicing. + The order of data is C, T, X, Y, Z + C, T and Z can have any slice + X and Y will only consider the beginning and end as we want the images + to be continuous + >>> bf_1 = expt[0, 0, :, :, :] # First channel, first timepoint, all x,y,z + """ + + __metaclass__ = abc.ABCMeta + + # metadata_parser = AcqMetadataParser() + + def __init__(self): + self.exptID = "" + self._current_position = None + self.position_to_process = 0 + + def __getitem__(self, item): + return self.current_position[item] + + @property + def shape(self): + return self.current_position.shape + + @staticmethod + def from_source(*args, **kwargs): + """ + Factory method to construct an instance of an Experiment subclass ( + either ExperimentOMERO or ExperimentLocal). + + :param source: Where the data is stored (OMERO server or directory + name) + :param kwargs: If OMERO server, `user` and `password` keyword + arguments are required. If the data is stored locally keyword + arguments are ignored. 
+ """ + if len(args) > 1: + logger.debug("ExperimentOMERO: {}".format(args, kwargs)) + return ExperimentOMERO(*args, **kwargs) + else: + logger.debug("ExperimentLocal: {}".format(args, kwargs)) + return ExperimentLocal(*args, **kwargs) + + @property + @abc.abstractmethod + def positions(self): + """Returns list of available position names""" + return + + @abc.abstractmethod + def get_position(self, position): + return + + @property + def current_position(self): + return self._current_position + + @property + def channels(self): + return self._current_position.channels + + @current_position.setter + def current_position(self, position): + self._current_position = self.get_position(position) + + def get_hypercube(self, x, y, z_positions, channels, timepoints): + return self.current_position.get_hypercube( + x, y, z_positions, channels, timepoints + ) + + +# Todo: cache images like in ExperimentLocal +class ExperimentOMERO(Experiment): + """ + Experiment class to organise different timelapses. + Connected to a Dataset object which handles database I/O. + """ + + def __init__(self, omero_id, host, port=4064, **kwargs): + super(ExperimentOMERO, self).__init__() + self.exptID = omero_id + # Get annotations + self.use_annotations = kwargs.get("use_annotations", True) + self._files = None + self._tags = None + + # Create a connection + self.connection = BlitzGateway( + kwargs.get("username") or input("Username: "), + kwargs.get("password") or getpass("Password: "), + host=host, + port=port, + ) + connected = self.connection.connect() + try: + assert connected is True, "Could not connect to server." + except AssertionError as e: + self.connection.close() + raise (e) + try: # Run everything that could cause the initialisation to fail + self.dataset = self.connection.getObject("Dataset", self.exptID) + self.name = self.dataset.getName() + # Create positions objects + self._positions = { + img.getName(): img.getId() + for img in sorted( + self.dataset.listChildren(), key=lambda x: x.getName() + ) + } + # Set up local cache + self.root_dir = Path(kwargs.get("save_dir", "./")) / self.name + if not self.root_dir.exists(): + self.root_dir.mkdir(parents=True) + self.compression = kwargs.get("compression", None) + self.image_cache = h5py.File(self.root_dir / "images.h5", "a") + + # Set up the current position as the first in the list + self._current_position = self.get_position(self.positions[0]) + self.running_tp = 0 + except Exception as e: + # Close the connection! 
+ print("Error in initialisation, closing connection.") + self.connection.close() + print(self.connection.isConnected()) + raise e + atexit.register(self.close) # Close everything if program ends + + def close(self): + print("Clean-up on exit.") + self.image_cache.close() + self.connection.close() + + @property + def files(self): + if self._files is None: + self._files = { + x.getFileName(): x + for x in self.dataset.listAnnotations() + if isinstance(x, omero.gateway.FileAnnotationWrapper) + } + return self._files + + @property + def tags(self): + if self._tags is None: + self._tags = { + x.getName(): x + for x in self.dataset.listAnnotations() + if isinstance(x, omero.gateway.TagAnnotationWrapper) + } + return self._tags + + @property + def positions(self): + return list(self._positions.keys()) + + def _get_position_annotation(self, position): + # Get file annotations filtered by position name and ordered by + # creation date + r = re.compile(position) + wrappers = sorted( + [self.files[key] for key in filter(r.match, self.files)], + key=lambda x: x.creationEventDate(), + reverse=True, + ) + # Choose newest file + if len(wrappers) < 1: + return None + else: + # Choose the newest annotation and cache it + annotation = wrappers[0] + filepath = self.root_dir / annotation.getFileName().replace("/", "_") + if not filepath.exists(): + with open(str(filepath), "wb") as fd: + for chunk in annotation.getFileInChunks(): + fd.write(chunk) + return filepath + + def get_position(self, position): + """Get a Timelapse object for a given position by name""" + # assert position in self.positions, "Position not available." + img = self.connection.getObject("Image", self._positions[position]) + if self.use_annotations: + annotation = self._get_position_annotation(position) + else: + annotation = None + return TimelapseOMERO(img, annotation, self.image_cache) + + def cache_locally( + self, + root_dir="./", + positions=None, + channels=None, + timepoints=None, + z_positions=None, + ): + """ + Save the experiment locally. + + :param root_dir: The directory in which the experiment will be + saved. The experiment will be a subdirectory of "root_directory" + and will be named by its id. 
+ """ + logger.warning("Saving experiment {}; may take some time.".format(self.name)) + + if positions is None: + positions = self.positions + if channels is None: + channels = self.current_position.channels + if timepoints is None: + timepoints = range(self.current_position.size_t) + if z_positions is None: + z_positions = range(self.current_position.size_z) + + save_dir = Path(root_dir) / self.name + if not save_dir.exists(): + save_dir.mkdir() + # Save the images + for pos_name in tqdm(positions): + pos = self.get_position(pos_name) + pos_dir = save_dir / pos_name + if not pos_dir.exists(): + pos_dir.mkdir() + self.cache_set(pos, range(pos.size_t)) + + self.cache_logs(save_dir) + # Save the file annotations + cache_config = dict( + positions=positions, + channels=channels, + timepoints=timepoints, + z_positions=z_positions, + ) + with open(str(save_dir / "cache.config"), "w") as fd: + json.dump(cache_config, fd) + logger.info("Downloaded experiment {}".format(self.exptID)) + + def cache_logs(self, **kwargs): + # Save the file annotations + tags = dict() # and the tag annotations + for annotation in self.dataset.listAnnotations(): + if isinstance(annotation, omero.gateway.FileAnnotationWrapper): + filepath = self.root_dir / annotation.getFileName().replace("/", "_") + if str(filepath).endswith("txt") and not filepath.exists(): + # Save only the text files + with open(str(filepath), "wb") as fd: + for chunk in annotation.getFileInChunks(): + fd.write(chunk) + if isinstance(annotation, omero.gateway.TagAnnotationWrapper): + key = annotation.getDescription() + if key == "": + key = "misc. tags" + if key in tags: + if not isinstance(tags[key], list): + tags[key] = [tags[key]] + tags[key].append(annotation.getValue()) + else: + tags[key] = annotation.getValue() + with open(str(self.root_dir / "omero_tags.json"), "w") as fd: + json.dump(tags, fd) + return + + def run(self, keys: Union[list, int], store, **kwargs): + if self.running_tp == 0: + self.cache_logs(**kwargs) + self.running_tp = 1 # Todo rename based on annotations + run_tps = dict() + for pos, tps in accumulate(keys): + position = self.get_position(pos) + run_tps[pos] = position.run(tps, store, save_dir=self.root_dir) + # Update the keys to match what was actually run + keys = [(pos, tp) for pos in run_tps for tp in run_tps[pos]] + return keys + + +class ExperimentLocal(Experiment): + def __init__(self, root_dir, finished=True): + super(ExperimentLocal, self).__init__() + self.root_dir = Path(root_dir) + self.exptID = self.root_dir.name + self._pos_mapper = dict() + # Fixme: Made the assumption that the Acq file gets saved before the + # experiment is run and that the information in that file is + # trustworthy. 
+ acq_file = self._find_acq_file() + acq_parser = Parser("multiDGUI_acq_format") + with open(acq_file, "r") as fd: + metadata = acq_parser.parse(fd) + self.metadata = metadata + self.metadata["finished"] = finished + self.files = [f for f in self.root_dir.iterdir() if f.is_file()] + self.image_cache = h5py.File(self.root_dir / "images.h5", "a") + if self.finished: + cache = self._find_cache() + # log = self._find_log() # Todo: add log metadata + if cache is not None: + with open(cache, "r") as fd: + cache_config = json.load(fd) + self.metadata.update(**cache_config) + self._current_position = self.get_position(self.positions[0]) + + def _find_file(self, regex): + file = glob.glob(os.path.join(str(self.root_dir), regex)) + if len(file) != 1: + return None + else: + return file[0] + + def _find_acq_file(self): + file = self._find_file("*[Aa]cq.txt") + if file is None: + raise ValueError( + "Cannot load this experiment. There are either " + "too many or too few acq files." + ) + return file + + def _find_cache(self): + return self._find_file("cache.config") + + @property + def finished(self): + return self.metadata["finished"] + + @property + def running(self): + return not self.metadata["finished"] + + @property + def positions(self): + return self.metadata["positions"]["posname"] + + def _get_position_annotation(self, position): + r = re.compile(position) + files = list(filter(lambda x: r.match(x.stem), self.files)) + if len(files) == 0: + return None + files = sorted(files, key=lambda x: x.lstat().st_ctime, reverse=True) + # Get the newest and return as string + return files[0] + + def get_position(self, position): + if position not in self._pos_mapper: + annotation = self._get_position_annotation(position) + self._pos_mapper[position] = TimelapseLocal( + position, + self.root_dir, + finished=self.finished, + annotation=annotation, + cache=self.image_cache, + ) + return self._pos_mapper[position] + + def run(self, keys, store, **kwargs): + """ + + :param keys: List of (position, time point) tuples to process. + :return: + """ + run_tps = dict() + for pos, tps in accumulate(keys): + run_tps[pos] = self.get_position(pos).run(tps, store) + # Update the keys to match what was actually run + keys = [(pos, tp) for pos in run_tps for tp in run_tps[pos]] + return keys diff --git a/aliby/extract.py b/aliby/extract.py new file mode 100644 index 0000000000000000000000000000000000000000..edd7c063cd9866c6f994a0d487eb912fe366e54a --- /dev/null +++ b/aliby/extract.py @@ -0,0 +1,279 @@ +""" +A module to extract data from a processed experiment. +""" +import h5py +import numpy as np +from tqdm import tqdm + +from core.io.matlab import matObject +from growth_rate.estimate_gr import estimate_gr + + +class Extracted: + # TODO write the filtering functions. + def __init__(self): + self.volume = None + self._keep = None + + def filter(self, filename=None, **kwargs): + """ + 1. Filter out small non-growing tracks. This means: + a. the cell size never reaches beyond a certain size-threshold + volume_thresh or + b. the cell's volume doesn't increase by at least a minimum + amount over its lifetime + 2. Join daughter tracks that are contiguous and within a volume + threshold of each other + 3. Discard tracks that are shorter than a threshold number of + timepoints + + This function is used to fix tracking/bud errors in post-processing. + The parameters define the thresholds used to determine which cells are + discarded. + FIXME Ideally we get to a point where this is no longer needed. 
+ :return: + """ + #self.join_tracks() + filter_out = self.filter_size(**kwargs) + filter_out += self.filter_lifespan(**kwargs) + # TODO save data or just filtering parameters? + #self.to_hdf(filename) + self.keep = ~filter_out + + def filter_size(self, volume_thresh=7, growth_thresh=10, **kwargs): + """Filter out small and non-growing cells. + :param volume_thresh: Size threshold for small cells + :param growth_thresh: Size difference threshold for non-growing cells + """ + filter_out = np.where(np.max(self.volume, axis=1) < volume_thresh, + True, False) + growth = [v[v > 0] for v in self.volume] + growth = np.array([v[-1] - v[0] if len(v) > 0 else 0 for v in growth]) + filter_out += np.where(growth < growth_thresh, True, False) + return filter_out + + def filter_lifespan(self, min_time=5, **kwargs): + """Remove daughter cells that have a small life span. + + :param min_time: The minimum life span, under which cells are removed. + """ + # TODO What if there are nan values? + filter_out = np.where(np.count_nonzero(self.volume, axis=1) < + min_time, True, False) + return filter_out + + def join_tracks(self, threshold=7): + """ Join contiguous tracks that are within a certain volume + threshold of each other. + + :param threshold: Maximum volume difference to join contiguous tracks. + :return: + """ + # For all pairs of cells + # + pass + + +class ExtractedHDF(Extracted): + # TODO pull all the data out of the HFile and filter! + def __init__(self, file): + # We consider the data to be read-only + self.hfile = h5py.File(file, 'r') + + +class ExtractedMat(Extracted): + """ Pulls the extracted data out of the MATLAB cTimelapse file. + + This is mostly a convenience function in order to run the + gaussian-processes growth-rate estimation + """ + def __init__(self, file, debug=False): + ct = matObject(file) + self.debug = debug + # Pre-computed data + # TODO what if there is no timelapseTrapsOmero? + self.metadata = ct['timelapseTrapsOmero']['metadata'] + self.extracted_data = ct['timelapseTrapsOmero']['extractedData'] + self.channels = ct['timelapseTrapsOmero']['extractionParameters'][ + 'functionParameters']['channels'].tolist() + self.time_settings = ct['timelapseTrapsOmero']['metadata']['acq'][ + 'times'] + # Get filtering information + n_cells = self.extracted_data['cellNum'][0].shape + self.keep = np.full(n_cells, True) + # Not yet computed data + self._growth_rate = None + self._daughter_index = None + + + def get_channel_index(self, channel): + """Get index of channel based on name. This only considers + fluorescence channels.""" + return self.channels.index(channel) + + @property + def trap_num(self): + return self.extracted_data['trapNum'][0][self.keep] + + @property + def cell_num(self): + return self.extracted_data['cellNum'][0][self.keep] + + def identity(self, cell_idx): + """Get the (position), trap, and cell label given a cell's global + index.""" + # Todo include position when using full strain + trap = self.trap_num[cell_idx] + cell = self.cell_num[cell_idx] + return trap, cell + + def global_index(self, trap_id, cell_label): + """Get the global index of a cell given it's trap/cellNum + combination.""" + candidates = np.where(np.logical_and( + (self.trap_num == trap_id), # +1? 
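+            # The "+1?" above presumably flags the MATLAB (1-based) versus
+            # Python (0-based) trap-index convention; left unresolved here.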
+ (self.cell_num == cell_label) + ))[0] + # TODO raise error if number of candidates != 1 + if len(candidates) == 1: + return candidates[0] + elif len(candidates) == 0: + return -1 + else: + raise(IndexError("No such cell/trap combination")) + + @property + def daughter_label(self): + """Returns the cell label of the daughters of each cell over the + timelapse. + + 0 corresponds to no daughter. This *not* the index of the daughter + cell within the data. To get this, use daughter_index. + """ + return self.extracted_data['daughterLabel'][0][self.keep] + + def _single_daughter_idx(self, mother_idx, daughter_labels): + trap_id, _ = self.identity(mother_idx) + daughter_index = [self.global_index(trap_id, cell_label) for + cell_label + in daughter_labels] + return daughter_index + + @property + def daughter_index(self): + """Returns the global index of the daughters of each cell. + + This is different from the daughter label because it corresponds to + the index of the daughter when counting all of the cells. This can + be used to index within the data arrays. + """ + if self._daughter_index is None: + daughter_index = [self._single_daughter_idx(i, daughter_labels) + for i, daughter_labels in enumerate( + self.daughter_label)] + self._daughter_index = np.array(daughter_index) + return self._daughter_index + + @property + def births(self): + return np.array(self.extracted_data['births'][0].todense())[self.keep] + + @property + def volume(self): + """Get the volume of all of the cells""" + return np.array(self.extracted_data['volume'][0].todense())[self.keep] + + def _gr_estimation(self): + dt = self.time_settings['interval'] / 360 # s to h conversion + results = [] + for v in tqdm(self.volume): + results.append(estimate_gr(v, dt)) + merged = {k: np.stack([x[k] for x in results]) for k in results[0]} + self._gr_results = merged + return + + @property + def growth_rate(self): + """Get the growth rate for all cells. + + Note that this uses the gaussian processes method of estimating + growth rate by default. If there is no growth rate in the given file + (usually the case for MATLAB), it needs to run estimation first. + This can take a while. + """ + # TODO cache the results of growth rate estimation. + if self._gr_results is None: + dt = self.time_settings['interval'] / 360 # s to h conversion + self._growth_rate = [estimate_gr(v, dt) for v in self.volume] + return self._gr_results['growth_rate'] + + def _fluo_attribute(self, channel, attribute): + channel_id = self.get_channel_index(channel) + res = np.array(self.extracted_data[attribute][channel_id].todense()) + return res[self.keep] + + def protein_localisation(self, channel, method='nucEstConv'): + """Returns protein localisation data for a given channel. + + Uses the 'nucEstConv' by default. Alternatives are 'smallPeakConv', + 'max5px', 'max2p5pc' + """ + return self._fluo_attribute(channel, method) + + def background_fluo(self, channel): + return self._fluo_attribute(channel, 'imBackground') + + def mean(self, channel): + return self._fluo_attribute(channel, 'mean') + + def median(self, channel): + return self._fluo_attribute(channel, 'median') + + def filter(self, filename=None): + """Filters and saves results to and HDF5 file. + + This is necessary because we cannot write to the MATLAB file, + so the results of the filter cannot be saved in the object. 
+ """ + super().filter(filename=filename) + self._growth_rate = None # reset growth rate so it is recomputed + + def to_hdf(self, filename): + """Store the current results, including any filtering done, to a file. + + TODO Should we save filtered results or just re-do? + :param filename: + :return: + """ + store = h5py.File(filename, 'w') + try: + # Store (some of the) metadata + for meta in ['experiment', 'username', 'microscope', + 'comments', 'project', 'date', 'posname', + 'exptid']: + store.attrs[meta] = self.metadata[meta] + # TODO store timing information? + store.attrs['time_interval'] = self.time_settings['interval'] + store.attrs['timepoints'] = self.time_settings['ntimepoints'] + store.attrs['total_duration'] = self.time_settings['totalduration'] + # Store volume, births, daughterLabel, trapNum, cellNum + for key in ['volume', 'births', 'daughter_label', 'trap_num', + 'cell_num']: + store[key] = getattr(self, key) + # Store growth rate results + if self._gr_results: + grp = store.create_group('gaussian_process') + for key, val in self._gr_results.items(): + grp[key] = val + for channel in self.channels: + # Create a group for each channel + grp = store.create_group(channel) + # Store protein_localisation, background fluorescence, mean, median + # for each channel + grp['protein_localisation'] = self.protein_localisation(channel) + grp['background_fluo'] = self.background_fluo(channel) + grp['mean'] = self.mean(channel) + grp['median'] = self.median(channel) + finally: + store.close() + diff --git a/aliby/grouper.py b/aliby/grouper.py new file mode 100644 index 0000000000000000000000000000000000000000..c3d6f7aea4d832a5d0f763d2bd16c62c15adea74 --- /dev/null +++ b/aliby/grouper.py @@ -0,0 +1,175 @@ +#!/usr/bin/env python3 + +from abc import ABC, abstractmethod, abstractproperty +from pathlib import Path +from pathos.multiprocessing import Pool + +import h5py +import numpy as np +import pandas as pd + +from aliby.io.signal import Signal + + +class Grouper(ABC): + """ + Base grouper class + """ + + files = [] + + def __init__(self, dir): + self.files = list(Path(dir).glob("*.h5")) + self.load_signals() + + def load_signals(self): + self.signals = {f.name[:-3]: Signal(f) for f in self.files} + + @property + def fsignal(self): + return list(self.signals.values())[0] + + @property + def siglist(self): + return self.fsignal.datasets + + @abstractproperty + def group_names(): + pass + + def concat_signal(self, path, reduce_cols=None, axis=0, pool=8): + group_names = self.group_names + sitems = self.signals.items() + if pool: + with Pool(pool) as p: + signals = p.map( + lambda x: concat_signal_ind(path, group_names, x[0], x[1]), + sitems, + ) + else: + signals = [ + concat_signal_ind(path, group_names, name, signal) + for name, signal in sitems + ] + + signals = [s for s in signals if s is not None] + sorted = pd.concat(signals, axis=axis).sort_index() + if reduce_cols: + sorted = sorted.apply(np.nanmean, axis=1) + spath = path.split("/") + sorted.name = "_".join([spath[1], spath[-1]]) + + return sorted + + @property + def ntraps(self): + for pos, s in self.signals.items(): + with h5py.File(s.filename, "r") as f: + print(pos, f["/trap_info/trap_locations"].shape[0]) + + def traplocs(self): + d = {} + for pos, s in self.signals.items(): + with h5py.File(s.filename, "r") as f: + d[pos] = f["/trap_info/trap_locations"][()] + return d + + +class MetaGrouper(Grouper): + """Group positions using metadata's 'group' number""" + + pass + + +class NameGrouper(Grouper): + """ + Group a set of 
positions using a subsection of the name + """ + + def __init__(self, dir, by=None): + super().__init__(dir=dir) + + if by is None: + by = (0, -4) + self.by = by + + @property + def group_names(self): + if not hasattr(self, "_group_names"): + self._group_names = {} + for name in self.signals.keys(): + self._group_names[name] = name[self.by[0] : self.by[1]] + + return self._group_names + + def aggregate_multisignals(self, paths=None, **kwargs): + + aggregated = pd.concat( + [ + self.concat_signal(path, reduce_cols=np.nanmean, **kwargs) + for path in paths + ], + axis=1, + ) + # ph = pd.Series( + # [ + # self.ph_from_group(x[list(aggregated.index.names).index("group")]) + # for x in aggregated.index + # ], + # index=aggregated.index, + # name="media_pH", + # ) + # self.aggregated = pd.concat((aggregated, ph), axis=1) + + return aggregated + + +class phGrouper(NameGrouper): + """ + Grouper for pH calibration experiments where all surveyed media pH values + are within a single experiment. + """ + + def __init__(self, dir, by=(3, 7)): + super().__init__(dir=dir, by=by) + + def get_ph(self): + self.ph = {gn: self.ph_from_group(gn) for gn in self.group_names} + + @staticmethod + def ph_from_group(group_name): + if group_name.startswith("ph_"): + group_name = group_name[3:] + + return float(group_name.replace("_", ".")) + + def aggregate_multisignals(self, paths): + + aggregated = pd.concat( + [self.concat_signal(path, reduce_cols=np.nanmean) for path in paths], axis=1 + ) + ph = pd.Series( + [ + self.ph_from_group(x[list(aggregated.index.names).index("group")]) + for x in aggregated.index + ], + index=aggregated.index, + name="media_pH", + ) + aggregated = pd.concat((aggregated, ph), axis=1) + + return aggregated + + +def concat_signal_ind(path, group_names, group, signal): + print("Looking at ", group) + # try: + combined = signal[path] + combined["position"] = group + combined["group"] = group_names[group] + combined.set_index(["group", "position"], inplace=True, append=True) + combined.index = combined.index.swaplevel(-2, 0).swaplevel(-1, 1) + + return combined + # except: + # return None diff --git a/aliby/haystack.py b/aliby/haystack.py new file mode 100644 index 0000000000000000000000000000000000000000..c30cf4aad367bb811cae00f9e081d2f4676aa240 --- /dev/null +++ b/aliby/haystack.py @@ -0,0 +1,97 @@ +import numpy as np +from time import perf_counter +from pathlib import Path + +import tensorflow as tf + +from aliby.io.writer import DynamicWriter + + +def initialise_tf(version): + # Initialise tensorflow + if version == 1: + core_config = tf.ConfigProto() + core_config.gpu_options.allow_growth = True + session = tf.Session(config=core_config) + return session + # TODO this only works for TF2 + if version == 2: + gpus = tf.config.experimental.list_physical_devices("GPU") + if gpus: + for gpu in gpus: + tf.config.experimental.set_memory_growth(gpu, True) + logical_gpus = tf.config.experimental.list_logical_devices("GPU") + print(len(gpus), "Physical GPUs,", len(logical_gpus), "Logical GPUs") + return None + + +def timer(func, *args, **kwargs): + start = perf_counter() + result = func(*args, **kwargs) + print(f"Function {func.__name__}: {perf_counter() - start}s") + return result + + +################## CUSTOM OBJECTS ################################## + + +class ModelPredictor: + """Generic object that takes a NN and returns the prediction. + + Use for predicting fluorescence/other from bright field. + This does not do instance segmentations of anything. 
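+
+    Note: as written, get_data relies on self.bf_channel and run_tp calls
+    self._format_result, neither of which is defined on this class (the
+    defined method is format_result).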
+ """ + + def __init__(self, tiler, model, name): + self.tiler = tiler + self.model = model + self.name = name + + def get_data(self, tp): + # Change axes to X,Y,Z rather than Z,Y,X + return self.tiler.get_tp_data(tp, self.bf_channel).swapaxes(1, 3).swapaxes(1, 2) + + def format_result(self, result, tp): + return {self.name: result, "timepoints": [tp] * len(result)} + + def run_tp(self, tp, **kwargs): + """Simulating processing time with sleep""" + # Access the image + segmentation = self.model.predict(self.get_data(tp)) + return self._format_result(segmentation, tp) + + +class ModelPredictorWriter(DynamicWriter): + def __init__(self, file, name, shape, dtype): + super.__init__(file) + self.datatypes = {name: (shape, dtype), "timepoint": ((None,), np.uint16)} + self.group = f"{self.name}_info" + + +class Saver: + channel_names = {0: "BrightField", 1: "GFP"} + + def __init__(self, tiler, save_directory, pos_name): + """This class straight up saves the trap data for use with neural networks in the future.""" + self.tiler = tiler + self.name = pos_name + self.save_dir = Path(save_directory) + + def channel_dir(self, index): + ch_dir = self.save_dir / self.channel_names[index] + if not ch_dir.exists(): + ch_dir.mkdir() + return ch_dir + + def get_data(self, tp, ch): + return self.tiler.get_tp_data(tp, ch).swapaxes(1, 3).swapaxes(1, 2) + + def cache(self, tp): + # Get a given time point + # split into channels + for ch in self.channel_names: + ch_dir = self.channel_dir(ch) + data = self.get_data(tp, ch) + for tid, trap in enumerate(data): + np.save(ch_dir / f"{self.name}_{tid}_{tp}.npy", trap) + return diff --git a/aliby/io/__init__.py b/aliby/io/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/aliby/io/base.py b/aliby/io/base.py new file mode 100644 index 0000000000000000000000000000000000000000..a9b911736943ecbcc9abe79300e52ceff5fe34e1 --- /dev/null +++ b/aliby/io/base.py @@ -0,0 +1,142 @@ +from typing import Union +import collections +from itertools import groupby, chain, product + +import numpy as np +import h5py + + +class BridgeH5: + """ + Base class to interact with h5 data stores. + It also contains functions useful to predict how long should segmentation take. + """ + + def __init__(self, filename, flag="r"): + self.filename = filename + if flag is not None: + self._hdf = h5py.File(filename, flag) + + self._filecheck + + def _filecheck(self): + assert "cell_info" in self._hdf, "Invalid file. No 'cell_info' found." 
+ + def close(self): + self._hdf.close() + + def max_ncellpairs(self, nstepsback): + """ + Get maximum number of cell pairs to be calculated + """ + + dset = self._hdf["cell_info"][()] + # attrs = self._hdf[dataset].attrs + pass + + @property + def cell_tree(self): + return self.get_info_tree() + + def get_n_cellpairs(self, nstepsback=2): + cell_tree = self.cell_tree + # get pair of consecutive trap-time points + pass + + @staticmethod + def get_consecutives(tree, nstepsback): + # Receives a sorted tree and returns the keys of consecutive elements + vals = {k: np.array(list(v)) for k, v in tree.items()} # get tp level + where_consec = [ + { + k: np.where(np.subtract(v[n + 1 :], v[: -n - 1]) == n + 1)[0] + for k, v in vals.items() + } + for n in range(nstepsback) + ] # get indices of consecutive elements + return where_consec + + def get_npairs(self, nstepsback=2, tree=None): + if tree is None: + tree = self.cell_tree + + consecutive = self.get_consecutives(tree, nstepsback=nstepsback) + flat_tree = flatten(tree) + + n_predictions = 0 + for i, d in enumerate(consecutive, 1): + flat = list(chain(*[product([k], list(v)) for k, v in d.items()])) + pairs = [(f, (f[0], f[1] + i)) for f in flat] + for p in pairs: + n_predictions += len(flat_tree.get(p[0], [])) * len( + flat_tree.get(p[1], []) + ) + + return n_predictions + + def get_npairs_over_time(self, nstepsback=2): + tree = self.cell_tree + npairs = [] + for t in self._hdf["cell_info"]["processed_timepoints"][()]: + tmp_tree = { + k: {k2: v2 for k2, v2 in v.items() if k2 <= t} for k, v in tree.items() + } + npairs.append(self.get_npairs(tree=tmp_tree)) + + return np.diff(npairs) + + def get_info_tree( + self, fields: Union[tuple, list] = ("trap", "timepoint", "cell_label") + ): + """ + Returns traps, time points and labels for this position in form of a tree + in the hierarchy determined by the argument fields. Note that it is + compressed to non-empty elements and timepoints. + + Default hierarchy is: + - trap + - time point + - cell label + + This function currently produces trees of depth 3, but it can easily be + extended for deeper trees if needed (e.g. considering groups, + chambers and/or positions). + + input + :fields: Fields to fetch from 'cell_info' inside the hdf5 storage + + returns + :tree: Nested dictionary where keys (or branches) are the upper levels + and the leaves are the last element of :fields:. + """ + zipped_info = (*zip(*[self._hdf["cell_info"][f][()] for f in fields]),) + + return recursive_groupsort(zipped_info) + + +def groupsort(iterable: Union[tuple, list]): + # Sorts iterable and returns a dictionary where the values are grouped by the first element. + + iterable = sorted(iterable, key=lambda x: x[0]) + grouped = {k: [x[1:] for x in v] for k, v in groupby(iterable, lambda x: x[0])} + return grouped + + +def recursive_groupsort(iterable): + # Recursive extension of groupsort + if len(iterable[0]) > 1: + return {k: recursive_groupsort(v) for k, v in groupsort(iterable).items()} + else: # Only two elements in list + return [x[0] for x in iterable] + + +def flatten(d, parent_key="", sep="_"): + """Flatten nested dict. 
Adapted from https://stackoverflow.com/a/6027615""" + items = [] + for k, v in d.items(): + new_key = parent_key + (k,) if parent_key else (k,) + if isinstance(v, collections.MutableMapping): + items.extend(flatten(v, new_key, sep=sep).items()) + else: + items.append((new_key, v)) + return dict(items) diff --git a/aliby/io/matlab.py b/aliby/io/matlab.py new file mode 100644 index 0000000000000000000000000000000000000000..18362ed377b86e577bc6dc703a0f5c8bb67d5c58 --- /dev/null +++ b/aliby/io/matlab.py @@ -0,0 +1,569 @@ +"""Read and convert MATLAB files from Swain Lab platform. + +TODO: Information that I need from lab members esp J and A + * Lots of examples to try + * Any ideas on what these Map objects are? + +TODO: Update Swain Lab wiki + +All credit to Matt Bauman for +the reverse engineering at https://nbviewer.jupyter.org/gist/mbauman/9121961 +""" + +import re +import struct +import sys +from collections import Iterable +from io import BytesIO + +import h5py +import numpy as np +import pandas as pd +import scipy +from numpy.compat import asstr + +# TODO only use this if scipy>=1.6 or so +from scipy.io import matlab +from scipy.io.matlab.mio5 import MatFile5Reader +from scipy.io.matlab.mio5_params import mat_struct + +from aliby.io.utils import read_int, read_string, read_delim + + +def read_minimat_vars(rdr): + rdr.initialize_read() + mdict = {"__globals__": []} + i = 0 + while not rdr.end_of_stream(): + hdr, next_position = rdr.read_var_header() + name = asstr(hdr.name) + if name == "": + name = "var_%d" % i + i += 1 + res = rdr.read_var_array(hdr, process=False) + rdr.mat_stream.seek(next_position) + mdict[name] = res + if hdr.is_global: + mdict["__globals__"].append(name) + return mdict + + +def read_workspace_vars(fname): + fp = open(fname, "rb") + rdr = MatFile5Reader(fp, struct_as_record=True, squeeze_me=True) + vars = rdr.get_variables() + fws = vars["__function_workspace__"] + ws_bs = BytesIO(fws.tostring()) + ws_bs.seek(2) + rdr.mat_stream = ws_bs + # Guess byte order. 
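+    # The MAT-file header stores the endian indicator "MI" as an int16; on a
+    # little-endian file the raw bytes read back as b"IM", hence "<" below.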
+ mi = rdr.mat_stream.read(2) + rdr.byte_order = mi == b"IM" and "<" or ">" + rdr.mat_stream.read(4) # presumably byte padding + mdict = read_minimat_vars(rdr) + fp.close() + return mdict + + +class matObject: + """A python read-out of MATLAB objects + The objects pulled out of the + """ + + def __init__(self, filepath): + self.filepath = filepath # For record + self.classname = None + self.object_name = None + self.buffer = None + self.version = None + self.names = None + self.segments = None + self.heap = None + self.attrs = dict() + self._init_buffer() + self._init_heap() + self._read_header() + self.parse_file() + + def __getitem__(self, item): + return self.attrs[item] + + def keys(self): + """Returns the names of the available properties""" + return self.attrs.keys() + + def get(self, item, default=None): + return self.attrs.get(item, default) + + def _init_buffer(self): + fp = open(self.filepath, "rb") + rdr = MatFile5Reader(fp, struct_as_record=True, squeeze_me=True) + vars = rdr.get_variables() + self.classname = vars["None"]["s2"][0].decode("utf-8") + self.object_name = vars["None"]["s0"][0].decode("utf-8") + fws = vars["__function_workspace__"] + self.buffer = BytesIO(fws.tostring()) + fp.close() + + def _init_heap(self): + super_data = read_workspace_vars(self.filepath) + elem = super_data["var_0"][0, 0] + if isinstance(elem, mat_struct): + self.heap = elem.MCOS[0]["arr"] + else: + self.heap = elem["MCOS"][0]["arr"] + + def _read_header(self): + self.buffer.seek(248) # the start of the header + version = read_int(self.buffer) + n_str = read_int(self.buffer) + + offsets = read_int(self.buffer, n=6) + + # check that the next two are zeros + reserved = read_int(self.buffer, n=2) + assert all( + [x == 0 for x in reserved] + ), "Non-zero reserved header fields: {}".format(reserved) + # check that we are at the right place + assert self.buffer.tell() == 288, "String elemnts begin at 288" + hdrs = [] + for i in range(n_str): + hdrs.append(read_string(self.buffer)) + self.names = hdrs + self.version = version + # The offsets are actually STARTING FROM 248 as well + self.segments = [x + 248 for x in offsets] # list(offsets) + return + + def parse_file(self): + # Get class attributes from segment 1 + self.buffer.seek(self.segments[0]) + classes = self._parse_class_attributes(self.segments[1]) + # Get first set of properties from segment 2 + self.buffer.seek(self.segments[1]) + props1 = self._parse_properties(self.segments[2]) + # Get the property description from segment 3 + self.buffer.seek(self.segments[2]) + object_info = self._parse_prop_description(classes, self.segments[3]) + # Get more properties from segment 4 + self.buffer.seek(self.segments[3]) + props2 = self._parse_properties(self.segments[4]) + # Check that the last segment is empty + self.buffer.seek(self.segments[4]) + seg5_length = (self.segments[5] - self.segments[4]) // 8 + read_delim(self.buffer, seg5_length) + props = (props1, props2) + self._to_attrs(object_info, props) + + def _to_attrs(self, object_info, props): + """Re-organise the various classes and subclasses into a nested + dictionary. 
+ :return: + """ + for pkg_clss, indices, idx in object_info: + pkg, clss = pkg_clss + idx = max(indices) + which = indices.index(idx) + obj = flatten_obj(props[which][idx]) + subdict = self.attrs + if pkg != "": + subdict = self.attrs.setdefault(pkg, {}) + if clss in subdict: + if isinstance(subdict[clss], list): + subdict[clss].append(obj) + else: + subdict[clss] = [subdict[clss]] + subdict[clss].append(obj) + else: + subdict[clss] = obj + + def describe(self): + describe(self.attrs) + + def _parse_class_attributes(self, section_end): + """Read the Class attributes = the first segment""" + read_delim(self.buffer, 4) + classes = [] + while self.buffer.tell() < section_end: + package_index = read_int(self.buffer) - 1 + package = self.names[package_index] if package_index > 0 else "" + name_idx = read_int(self.buffer) - 1 + name = self.names[name_idx] if name_idx > 0 else "" + classes.append((package, name)) + read_delim(self.buffer, 2) + return classes + + def _parse_prop_description(self, classes, section_end): + """Parse the description of each property = the third segment""" + read_delim(self.buffer, 6) + object_info = [] + while self.buffer.tell() < section_end: + class_idx = read_int(self.buffer) - 1 + class_type = classes[class_idx] + read_delim(self.buffer, 2) + indices = [x - 1 for x in read_int(self.buffer, 2)] + obj_id = read_int(self.buffer) + object_info.append((class_type, indices, obj_id)) + return object_info + + def _parse_properties(self, section_end): + """ + Parse the actual values of the attributes == segments 2 and 4 + """ + read_delim(self.buffer, 2) + props = [] + while self.buffer.tell() < section_end: + n_props = read_int(self.buffer) + d = parse_prop(n_props, self.buffer, self.names, self.heap) + if not d: # Empty dictionary + break + props.append(d) + # Move to next 8-byte aligned offset + self.buffer.seek(self.buffer.tell() + self.buffer.tell() % 8) + return props + + def to_hdf(self, filename): + f = h5py.File(filename, mode="w") + save_to_hdf(f, "/", self.attrs) + + +def describe(d, indent=0, width=4, out=None): + for key, value in d.items(): + print(f'{"": <{width * indent}}' + str(key), file=out) + if isinstance(value, dict): + describe(value, indent + 1, out=out) + elif isinstance(value, np.ndarray): + print( + f'{"": <{width * (indent + 1)}} {value.shape} array ' + f"of type {value.dtype}", + file=out, + ) + elif isinstance(value, scipy.sparse.csc.csc_matrix): + print( + f'{"": <{width * (indent + 1)}} {value.shape} ' + f"sparse matrix of type {value.dtype}", + file=out, + ) + elif isinstance(value, Iterable) and not isinstance(value, str): + print( + f'{"": <{width * (indent + 1)}} {type(value)} of len ' f"{len(value)}", + file=out, + ) + else: + print(f'{"": <{width * (indent + 1)}} {value}', file=out) + + +def parse_prop(n_props, buff, names, heap): + d = dict() + for i in range(n_props): + name_idx, flag, heap_idx = read_int(buff, 3) + if flag not in [0, 1, 2] and name_idx == 0: + n_props = flag + buff.seek(buff.tell() - 1) # go back on one byte + d = parse_prop(n_props, buff, names, heap) + else: + item_name = names[name_idx - 1] + if flag == 0: + d[item_name] = names[heap_idx] + elif flag == 1: + d[item_name] = heap[heap_idx + 2] # Todo: what is the heap? 
+ elif flag == 2: + assert 0 <= heap_idx <= 1, ( + "Boolean flag has a value other " "than 0 or 1 " + ) + d[item_name] = bool(heap_idx) + else: + raise ValueError( + "unknown flag {} for property {} with heap " + "index {}".format(flag, item_name, heap_idx) + ) + return d + + +def is_object(x): + """Checking object dtype for structured numpy arrays""" + if x.dtype.names is not None and len(x.dtype.names) > 1: # Complex obj + return all(x.dtype[ix] == np.object for ix in range(len(x.dtype))) + else: # simple object + return x.dtype == np.object + + +def flatten_obj(arr): + # TODO turn structured arrays into nested dicts of lists rather that + # lists of dicts + if isinstance(arr, np.ndarray): + if arr.dtype.names: + arrdict = dict() + for fieldname in arr.dtype.names: + arrdict[fieldname] = flatten_obj(arr[fieldname]) + arr = arrdict + elif arr.dtype == np.object and arr.ndim == 0: + arr = flatten_obj(arr[()]) + elif arr.dtype == np.object and arr.ndim > 0: + try: + arr = np.stack(arr) + if arr.dtype.names: + d = {k: flatten_obj(arr[k]) for k in arr.dtype.names} + arr = d + except: + arr = [flatten_obj(x) for x in arr.tolist()] + elif isinstance(arr, dict): + arr = {k: flatten_obj(v) for k, v in arr.items()} + elif isinstance(arr, list): + try: + arr = flatten_obj(np.stack(arr)) + except: + arr = [flatten_obj(x) for x in arr] + return arr + + +def save_to_hdf(h5file, path, dic): + """ + Saving a MATLAB object to HDF5 + """ + if isinstance(dic, list): + dic = {str(i): v for i, v in enumerate(dic)} + for key, item in dic.items(): + if isinstance(item, (int, float, str)): + h5file[path].attrs.create(key, item) + elif isinstance(item, list): + if len(item) == 0 and path + key not in h5file: # empty list empty group + h5file.create_group(path + key) + if all(isinstance(x, (int, float, str)) for x in item): + if path not in h5file: + h5file.create_group(path) + h5file[path].attrs.create(key, item) + else: + if path + key not in h5file: + h5file.create_group(path + key) + save_to_hdf( + h5file, path + key + "/", {str(i): x for i, x in enumerate(item)} + ) + elif isinstance(item, scipy.sparse.csc.csc_matrix): + try: + h5file.create_dataset( + path + key, data=item.todense(), compression="gzip" + ) + except Exception as e: + print(path + key) + raise e + elif isinstance(item, (np.ndarray, np.int64, np.float64)): + if item.dtype == np.dtype("<U1"): # Strings to 'S' type for HDF5 + item = item.astype("S") + try: + h5file.create_dataset(path + key, data=item, compression="gzip") + except Exception as e: + print(path + key) + raise e + elif isinstance(item, dict): + if path + key not in h5file: + h5file.create_group(path + key) + save_to_hdf(h5file, path + key + "/", item) + elif item is None: + continue + else: + raise ValueError(f"Cannot save {type(item)} type at key {path + key}") + + +## NOT YET FULLY IMPLEMENTED! 
+ + +class _Info: + def __init__(self, info): + self.info = info + self._identity = None + + def __getitem__(self, item): + val = self.info[item] + if val.shape[0] == 1: + val = val[0] + if 0 in val[1].shape: + val = val[0] + if isinstance(val, scipy.sparse.csc.csc_matrix): + return np.asarray(val.todense()) + if val.dtype == np.dtype("O"): + # 3d "sparse matrix" + if all(isinstance(x, scipy.sparse.csc.csc_matrix) for x in val): + val = np.array([x.todense() for x in val]) + # TODO: The actual object data + equality = val[0] == val[1] + if isinstance(equality, scipy.sparse.csc.csc_matrix): + equality = equality.todense() + if equality.all(): + val = val[0] + return np.squeeze(val) + + @property + def categories(self): + return self.info.dtype.names + + +class TrapInfo(_Info): + def __init__(self, info): + """ + The information on all of the traps in a given position. + + :param info: The TrapInfo structure, can be found in the heap of + the CTimelapse at index 7 + """ + super().__init__(info) + + +class CellInfo(_Info): + def __init__(self, info): + """ + The extracted information of all cells in a given position. + :param info: The CellInfo structure, can be found in the heap + of the CTimelapse at index 15. + """ + super().__init__(info) + + @property + def identity(self): + if self._identity is None: + self._identity = pd.DataFrame( + zip(self["trapNum"], self["cellNum"]), columns=["trapNum", "cellNum"] + ) + return self._identity + + def index(self, trapNum, cellNum): + query = "trapNum=={} and cellNum=={}".format(trapNum, cellNum) + try: + result = self.identity.query(query).index[0] + except Exception as e: + print(query) + raise e + return result + + @property + def nucEstConv1(self): + return np.asarray(self.info["nuc_est_conv"][0][0].todense()) + + @property + def nucEstConv2(self): + return np.asarray(self.info["nuc_est_conv"][0][1].todense()) + + @property + def mothers(self): + return np.where((self["births"] != 0).any(axis=1))[0] + + def daughters(self, mother_index): + """ + Get daughters of cell with index `mother_index`. + + :param mother_index: the index of the mother within the data. This is + different from the mother's cell/trap identity. + """ + daughter_ids = np.unique(self["daughterLabel"][mother_index]).tolist() + daughter_ids.remove(0) + mother_trap = self.identity["trapNum"].loc[mother_index] + daughters = [self.index(mother_trap, cellNum) for cellNum in daughter_ids] + return daughters + + +def _todict(matobj): + """ + A recursive function which constructs from matobjects nested dictionaries + """ + if not hasattr(matobj, "_fieldnames"): + return matobj + d = {} + for strg in matobj._fieldnames: + elem = matobj.__dict__[strg] + if isinstance(elem, matlab.mio5_params.mat_struct): + d[strg] = _todict(elem) + elif isinstance(elem, np.ndarray): + d[strg] = _toarray(elem) + else: + d[strg] = elem + return d + + +def _toarray(ndarray): + """ + A recursive function which constructs ndarray from cellarrays + (which are loaded as numpy ndarrays), recursing into the elements + if they contain matobjects. 
+ """ + if ndarray.dtype != "float64": + elem_list = [] + for sub_elem in ndarray: + if isinstance(sub_elem, matlab.mio5_params.mat_struct): + elem_list.append(_todict(sub_elem)) + elif isinstance(sub_elem, np.ndarray): + elem_list.append(_toarray(sub_elem)) + else: + elem_list.append(sub_elem) + return np.array(elem_list) + else: + return ndarray + + +from pathlib import Path + + +class Strain: + """The cell info for all the positions of a strain.""" + + def __init__(self, origin, strain): + self.origin = Path(origin) + self.files = [x for x in origin.iterdir() if strain in str(x)] + self.cts = [matObject(x) for x in self.files] + self.cinfos = [CellInfo(x.heap[15]) for x in self.cts] + self._identity = None + + def __getitem__(self, item): + try: + return np.concatenate([c[item] for c in self.cinfos]) + except ValueError: # If first axis is the channel + return np.concatenate([c[item] for c in self.cinfos], axis=1) + + @property + def categories(self): + return set.union(*[set(c.categories) for c in self.cinfos]) + + @property + def identity(self): + if self._identity is None: + identities = [] + for pos_id, cinfo in enumerate(self.cinfos): + identity = cinfo.identity + identity["position"] = pos_id + identities.append(identity) + self._identity = pd.concat(identities, ignore_index=True) + return self._identity + + def index(self, posNum, trapNum, cellNum): + query = "position=={} and trapNum=={} and cellNum=={}".format( + posNum, trapNum, cellNum + ) + try: + result = self.identity.query(query).index[0] + except Exception as e: + raise e + return result + + @property + def mothers(self): + # At least two births are needed to be considered a mother cell + return np.where(np.count_nonzero(self["births"], axis=1) > 3)[0] + + def daughters(self, mother_index): + """ + Get daughters of cell with index `mother_index`. + + :param mother_index: the index of the mother within the data. This is + different from the mother's pos/trap/cell identity. + """ + daughter_ids = np.unique(self["daughterLabel"][mother_index]).tolist() + if 0 in daughter_ids: + daughter_ids.remove(0) + mother_pos_trap = self.identity[["position", "trapNum"]].loc[mother_index] + daughters = [] + for cellNum in daughter_ids: + try: + daughters.append(self.index(*mother_pos_trap, cellNum)) + except IndexError: + continue + return daughters diff --git a/aliby/io/metadata_parser.py b/aliby/io/metadata_parser.py new file mode 100644 index 0000000000000000000000000000000000000000..81938152ecdf00bf3d892d4231ca44e5b47d6635 --- /dev/null +++ b/aliby/io/metadata_parser.py @@ -0,0 +1,77 @@ +""" +Parse microscopy log files according to specified JSON grammars. +Produces dictionary to include in HDF5 +""" +import glob +import os +import numpy as np +import pandas as pd +from datetime import datetime, timezone +from pytz import timezone + +from logfile_parser import Parser + +# Paradigm: able to do something with all datatypes present in log files, +# then pare down on what specific information is really useful later. 
+ +# Needed because HDF5 attributes do not support dictionaries +def flatten_dict(nested_dict, separator='/'): + ''' + Flattens nested dictionary + ''' + df = pd.json_normalize(nested_dict, sep=separator) + return df.to_dict(orient='records')[0] + +# Needed because HDF5 attributes do not support datetime objects +# Takes care of time zones & daylight saving +def datetime_to_timestamp(time, locale = 'Europe/London'): + ''' + Convert datetime object to UNIX timestamp + ''' + return timezone(locale).localize(time).timestamp() + +def find_file(root_dir, regex): + file = glob.glob(os.path.join(str(root_dir), regex)) + if len(file) != 1: + return None + else: + return file[0] + +# TODO: re-write this as a class if appropriate +# WARNING: grammars depend on the directory structure of a locally installed +# logfile_parser repo +def parse_logfiles(root_dir, + acq_grammar = 'multiDGUI_acq_format.json', + log_grammar = 'multiDGUI_log_format.json'): + ''' + Parse acq and log files depending on the grammar specified, then merge into + single dict. + ''' + # Both acq and log files contain useful information. + #ACQ_FILE = 'flavin_htb2_glucose_long_ramp_DelftAcq.txt' + #LOG_FILE = 'flavin_htb2_glucose_long_ramp_Delftlog.txt' + log_parser = Parser(log_grammar) + try: + log_file = find_file(root_dir, '*log.txt') + except: + raise ValueError('Experiment log file not found.') + with open(log_file, 'r') as f: + log_parsed = log_parser.parse(f) + + acq_parser = Parser(acq_grammar) + try: + acq_file = find_file(root_dir, '*[Aa]cq.txt') + except: + raise ValueError('Experiment acq file not found.') + with open(acq_file, 'r') as f: + acq_parsed = acq_parser.parse(f) + + parsed = {**acq_parsed, **log_parsed} + + for key, value in parsed.items(): + if isinstance(value, datetime): + parsed[key] = datetime_to_timestamp(value) + + parsed_flattened = flatten_dict(parsed) + + return parsed_flattened diff --git a/aliby/io/omero.py b/aliby/io/omero.py new file mode 100644 index 0000000000000000000000000000000000000000..266d7d163db6137b584b357c9361a726101bed5c --- /dev/null +++ b/aliby/io/omero.py @@ -0,0 +1,133 @@ +import h5py +import omero +from omero.gateway import BlitzGateway +from aliby.experiment import get_data_lazy +from aliby.cells import CellsHDF + + +class Argo: + # TODO use the one in extraction? 
+ def __init__( + self, host="islay.bio.ed.ac.uk", username="upload", password="***REMOVED***" + ): + self.conn = None + self.host = host + self.username = username + self.password = password + + def get_meta(self): + pass + + def __enter__(self): + self.conn = BlitzGateway( + host=self.host, username=self.username, passwd=self.password + ) + self.conn.connect() + return self + + def __exit__(self, *exc): + self.conn.close() + return False + + +class Dataset(Argo): + def __init__(self, expt_id): + super().__init__() + self.expt_id = expt_id + self._files = None + + @property + def dataset(self): + return self.conn.getObject("Dataset", self.expt_id) + + @property + def name(self): + return self.dataset.getName() + + @property + def date(self): + return self.dataset.getDate() + + @property + def unique_name(self): + return "_".join((self.date.strftime("%Y_%m_%d").replace("/", "_"), self.name)) + + def get_images(self): + return {im.getName(): im.getId() for im in self.dataset.listChildren()} + + @property + def files(self): + if self._files is None: + self._files = { + x.getFileName(): x + for x in self.dataset.listAnnotations() + if isinstance(x, omero.gateway.FileAnnotationWrapper) + } + return self._files + + @property + def tags(self): + if self._tags is None: + self._tags = { + x.getName(): x + for x in self.dataset.listAnnotations() + if isinstance(x, omero.gateway.TagAnnotationWrapper) + } + return self._tags + + def cache_logs(self, root_dir): + for name, annotation in self.files.items(): + filepath = root_dir / annotation.getFileName().replace("/", "_") + if str(filepath).endswith("txt") and not filepath.exists(): + # Save only the text files + with open(str(filepath), "wb") as fd: + for chunk in annotation.getFileInChunks(): + fd.write(chunk) + return True + + +class Image(Argo): + def __init__(self, image_id): + super().__init__() + self.image_id = image_id + self._image_wrap = None + + @property + def image_wrap(self): + # TODO check that it is alive/ connected + if self._image_wrap is None: + self._image_wrap = self.conn.getObject("Image", self.image_id) + return self._image_wrap + + @property + def name(self): + return self.image_wrap.getName() + + @property + def data(self): + return get_data_lazy(self.image_wrap) + + @property + def metadata(self): + meta = dict() + meta["size_x"] = self.image_wrap.getSizeX() + meta["size_y"] = self.image_wrap.getSizeY() + meta["size_z"] = self.image_wrap.getSizeZ() + meta["size_c"] = self.image_wrap.getSizeC() + meta["size_t"] = self.image_wrap.getSizeT() + meta["channels"] = self.image_wrap.getChannelLabels() + meta["name"] = self.image_wrap.getName() + return meta + + +class Cells(CellsHDF): + def __init__(self, filename): + file = h5py.File(filename, "r") + super().__init__(file) + + def __enter__(self): + return self + + def __exit__(self, *exc): + self.close + return False diff --git a/aliby/io/signal.py b/aliby/io/signal.py new file mode 100644 index 0000000000000000000000000000000000000000..63b930cb05660e4f050809c7781d8b787a4cc7c4 --- /dev/null +++ b/aliby/io/signal.py @@ -0,0 +1,234 @@ +import numpy as np +from copy import copy +from itertools import accumulate + +from numpy import ndarray + +# from more_itertools import first_true + +import h5py +import pandas as pd +from utils_find_1st import find_1st, cmp_larger + +from aliby.io.base import BridgeH5 + + +class Signal(BridgeH5): + """ + Class that fetches data from the hdf5 storage for post-processing + """ + + def __init__(self, file): + super().__init__(file, flag=None) + + 
self.names = ["experiment", "position", "trap"] + + @staticmethod + def add_name(df, name): + df.name = name + return df + + def mothers(self, signal, cutoff=0.8): + df = self[signal] + get_mothers = lambda df: df.loc[df.notna().sum(axis=1) > df.shape[1] * cutoff] + if isinstance(df, pd.DataFrame): + return get_mothers(df) + elif isinstance(df, list): + return [get_mothers(d) for d in df] + + def __getitem__(self, dsets): + + if isinstance(dsets, str) and ( + dsets.startswith("postprocessing") + or dsets.startswith("/postprocessing") + or dsets.endswith("imBackground") + ): + df = self.get_raw(dsets) + + elif isinstance(dsets, str): + df = self.apply_prepost(dsets) + + elif isinstance(dsets, list): + is_bgd = [dset.endswith("imBackground") for dset in dsets] + assert sum(is_bgd) == 0 or sum(is_bgd) == len( + dsets + ), "Trap data and cell data can't be mixed" + with h5py.File(self.filename, "r") as f: + return [self.add_name(self.apply_prepost(dset), dset) for dset in dsets] + + return self.add_name(df, dsets) + + def apply_prepost(self, dataset: str): + merges = self.get_merges() + with h5py.File(self.filename, "r") as f: + df = self.dset_to_df(f, dataset) + + merged = df + if merges.any(): + # Split in two dfs, one with rows relevant for merging and one without them + mergable_ids = pd.MultiIndex.from_arrays( + np.unique(merges.reshape(-1, 2), axis=0).T, + names=df.index.names, + ) + merged = self.apply_merge(df.loc[mergable_ids], merges) + + nonmergable_ids = df.index.difference(mergable_ids) + + merged = pd.concat( + (merged, df.loc[nonmergable_ids]), names=df.index.names + ) + + search = lambda a, b: np.where( + np.in1d( + np.ravel_multi_index(a.T, a.max(0) + 1), + np.ravel_multi_index(b.T, a.max(0) + 1), + ) + ) + if "modifiers/picks" in f: + picks = self.get_picks(names=merged.index.names) + missing_cells = [i for i in picks if tuple(i) not in set(merged.index)] + + if picks: + # return merged.loc[ + # set(picks).intersection([tuple(x) for x in merged.index]) + # ] + return merged.loc[picks] + else: + if isinstance(merged.index, pd.MultiIndex): + empty_lvls = [[] for i in merged.index.names] + index = pd.MultiIndex( + levels=empty_lvls, + codes=empty_lvls, + names=merged.index.names, + ) + else: + index = pd.Index([], name=merged.index.name) + merged = pd.DataFrame([], index=index) + return merged + + @property + def datasets(self): + with h5py.File(self.filename, "r") as f: + dsets = f.visititems(self._if_ext_or_post) + return dsets + + def get_merged(self, dataset): + return self.apply_prepost(dataset, skip_pick=True) + + @property + def merges(self): + with h5py.File(self.filename, "r") as f: + dsets = f.visititems(self._if_merges) + return dsets + + @property + def n_merges(self): + print("{} merge events".format(len(self.merges))) + + @property + def merges(self): + with h5py.File(self.filename, "r") as f: + dsets = f.visititems(self._if_merges) + return dsets + + @property + def picks(self): + with h5py.File(self.filename, "r") as f: + dsets = f.visititems(self._if_picks) + return dsets + + def apply_merge(self, df, changes): + if len(changes): + + for target, source in changes: + df.loc[tuple(target)] = self.join_tracks_pair( + df.loc[tuple(target)], df.loc[tuple(source)] + ) + df.drop(tuple(source), inplace=True) + + return df + + def get_raw(self, dataset): + if isinstance(dataset, str): + with h5py.File(self.filename, "r") as f: + return self.dset_to_df(f, dataset) + elif isinstance(dataset, list): + return [self.get_raw(dset) for dset in dataset] + + def get_merges(self): 
+ # fetch merge events going up to the first level + with h5py.File(self.filename, "r") as f: + merges = f.get("modifiers/merges", np.array([])) + if not isinstance(merges, np.ndarray): + merges = merges[()] + + return merges + + # def get_picks(self, levels): + def get_picks(self, names, path="modifiers/picks/"): + with h5py.File(self.filename, "r") as f: + if path in f: + return list(zip(*[f[path + name] for name in names])) + # return f["modifiers/picks"] + else: + return None + + def dset_to_df(self, f, dataset): + dset = f[dataset] + names = copy(self.names) + if not dataset.endswith("imBackground"): + names.append("cell_label") + lbls = {lbl: dset[lbl][()] for lbl in names if lbl in dset.keys()} + index = pd.MultiIndex.from_arrays( + list(lbls.values()), names=names[-len(lbls) :] + ) + + columns = ( + dset["timepoint"][()] if "timepoint" in dset else dset.attrs["columns"] + ) + + df = pd.DataFrame(dset[("values")][()], index=index, columns=columns) + + return df + + @staticmethod + def dataset_to_df(f: h5py.File, path: str, mode: str = "h5py"): + + if mode is "h5py": + all_indices = ["experiment", "position", "trap", "cell_label"] + indices = {k: f[path][k][()] for k in all_indices if k in f[path].keys()} + return pd.DataFrame( + f[path + "/values"][()], + index=pd.MultiIndex.from_arrays( + list(indices.values()), names=indices.keys() + ), + columns=f[path + "/timepoint"][()], + ) + + @staticmethod + def _if_ext_or_post(name, *args): + flag = False + if name.startswith("extraction") and len(name.split("/")) == 4: + flag = True + elif name.startswith("postprocessing") and len(name.split("/")) == 3: + flag = True + + if flag: + print(name) + + @staticmethod + def _if_merges(name: str, obj): + if isinstance(obj, h5py.Dataset) and name.startswith("modifiers/merges"): + return obj[()] + + @staticmethod + def _if_picks(name: str, obj): + if isinstance(obj, h5py.Group) and name.endswith("picks"): + return obj[()] + + @staticmethod + def join_tracks_pair(target, source): + tgt_copy = copy(target) + end = find_1st(target.values[::-1], 0, cmp_larger) + tgt_copy.iloc[-end:] = source.iloc[-end:].values + return tgt_copy diff --git a/aliby/io/utils.py b/aliby/io/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..a1029e822d7b1fa91daf28f63886d9fa44ce4bc2 --- /dev/null +++ b/aliby/io/utils.py @@ -0,0 +1,44 @@ +import re +import struct + + +def clean_ascii(text): + return re.sub(r'[^\x20-\x7F]', '.', text) + + +def xxd(x, start=0, stop=None): + if stop is None: + stop = len(x) + for i in range(start, stop, 8): + # Row number + print("%04d" % i, end=" ") + # Hexadecimal bytes + for r in range(i, i + 8): + print("%02x" % x[r], end="") + if (r + 1) % 4 == 0: + print(" ", end="") + # ASCII + print(" ", clean_ascii(x[i:i + 8].decode('utf-8', errors='ignore')), + " ", end="") + # Int32 + print('{:>10} {:>10}'.format(*struct.unpack('II', x[i: i + 8])), + end=" ") + print("") # Newline + return + + +# Buffer reading functions +def read_int(buffer, n=1): + res = struct.unpack('I' * n, buffer.read(4 * n)) + if n == 1: + res = res[0] + return res + + +def read_string(buffer): + return ''.join([x.decode() for x in iter(lambda: buffer.read(1), b'\x00')]) + + +def read_delim(buffer, n): + delim = read_int(buffer, n) + assert all([x == 0 for x in delim]), "Unknown nonzero value in delimiter" diff --git a/aliby/io/writer.py b/aliby/io/writer.py new file mode 100644 index 0000000000000000000000000000000000000000..d94e624c3d8e282a18d4e90092438343190b95c9 --- /dev/null +++ 
b/aliby/io/writer.py @@ -0,0 +1,567 @@ +import itertools +import logging +from time import perf_counter + +import h5py +import numpy as np +import pandas as pd +from collections.abc import Iterable +from typing import Dict + +from utils_find_1st import find_1st, cmp_equal + +from aliby.io.base import BridgeH5 +from aliby.utils import timed + + +#################### Dynamic version ################################## + + +def load_attributes(file: str, group="/"): + with h5py.File(file, "r") as f: + meta = dict(f[group].attrs.items()) + return meta + + +class DynamicWriter: + data_types = {} + group = "" + compression = None + + def __init__(self, file: str): + self.file = file + self.metadata = load_attributes(file) + + def _append(self, data, key, hgroup): + """Append data to existing dataset.""" + try: + n = len(data) + except: + # Attributes have no length + n = 1 + if key not in hgroup: + # TODO Include sparsity check + max_shape, dtype = self.datatypes[key] + shape = (n,) + max_shape[1:] + hgroup.create_dataset( + key, + shape=shape, + maxshape=max_shape, + dtype=dtype, + compression=self.compression, + ) + hgroup[key][()] = data + else: + # The dataset already exists, expand it + + try: # FIXME This is broken by bugged mother-bud assignment + dset = hgroup[key] + dset.resize(dset.shape[0] + n, axis=0) + dset[-n:] = data + except: + logging.debug( + "DynamicWriter:Inconsistency between dataset shape and new empty data" + ) + return + + def _overwrite(self, data, key, hgroup): + """Overwrite existing dataset with new data""" + # We do not append to mother_assign; raise error if already saved + n = len(data) + max_shape, dtype = self.datatypes[key] + if key in hgroup: + del hgroup[key] + hgroup.require_dataset( + key, shape=(n,), dtype=dtype, compression=self.compression + ) + hgroup[key][()] = data + + def _check_key(self, key): + if key not in self.datatypes: + raise KeyError(f"No defined data type for key {key}") + + def write(self, data, overwrite: list): + # Data is a dictionary, if not, make it one + # Overwrite data is a dictionary + with h5py.File(self.file, "a") as store: + hgroup = store.require_group(self.group) + + for key, value in data.items(): + # We're only saving data that has a pre-defined data-type + self._check_key(key) + try: + if key.startswith("attrs/"): # metadata + key = key.split("/")[1] # First thing after attrs + hgroup.attrs[key] = value + elif key in overwrite: + self._overwrite(value, key, hgroup) + else: + self._append(value, key, hgroup) + except Exception as e: + print(key, value) + raise (e) + return + + +##################### Special instances ##################### +class TilerWriter(DynamicWriter): + datatypes = { + "trap_locations": ((None, 2), np.uint16), + "drifts": ((None, 2), np.float32), + "attrs/tile_size": ((1,), np.uint16), + "attrs/max_size": ((1,), np.uint16), + } + group = "trap_info" + + +tile_size = 117 + + +@timed() +def save_complex(array, dataset): + # Dataset needs to be 2D + n = len(array) + if n > 0: + dataset.resize(dataset.shape[0] + n, axis=0) + dataset[-n:, 0] = array.real + dataset[-n:, 1] = array.imag + + +@timed() +def load_complex(dataset): + array = dataset[:, 0] + 1j * dataset[:, 1] + return array + + +class BabyWriter(DynamicWriter): + compression = "gzip" + max_ncells = 2e5 # Could just make this None + max_tps = 1e3 # Could just make this None + chunk_cells = 25 # The number of cells in a chunk for edge masks + default_tile_size = 117 + datatypes = { + "centres": ((None, 2), np.uint16), + "position": ((None,), 
np.uint16), + "angles": ((None,), h5py.vlen_dtype(np.float32)), + "radii": ((None,), h5py.vlen_dtype(np.float32)), + "edgemasks": ((max_ncells, max_tps, tile_size, tile_size), np.bool), + "ellipse_dims": ((None, 2), np.float32), + "cell_label": ((None,), np.uint16), + "trap": ((None,), np.uint16), + "timepoint": ((None,), np.uint16), + "mother_assign": ((None,), h5py.vlen_dtype(np.uint16)), + "mother_assign_dynamic": ((None,), np.uint16), + "volumes": ((None,), np.float32), + } + group = "cell_info" + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + # Get max_tps and trap info + self._traps_initialised = False + + def __init_trap_info(self): + # Should only be run after the traps have been initialised + trap_metadata = load_attributes(self.file, "trap_info") + tile_size = trap_metadata.get("tile_size", self.default_tile_size) + max_tps = self.metadata["time_settings/ntimepoints"][0] + self.datatypes["edgemasks"] = ( + (self.max_ncells, max_tps, tile_size, tile_size), + np.bool, + ) + self._traps_initialised = True + + def __init_edgemasks(self, hgroup, edgemasks, current_indices, n_cells): + # Create values dataset + # This holds the edge masks directly and + # Is of shape (n_tps, n_cells, tile_size, tile_size) + key = "edgemasks" + max_shape, dtype = self.datatypes[key] + shape = (n_cells, 1) + max_shape[2:] + chunks = (self.chunk_cells, 1) + max_shape[2:] + val_dset = hgroup.create_dataset( + "values", + shape=shape, + maxshape=max_shape, + dtype=dtype, + chunks=chunks, + compression=self.compression, + ) + val_dset[:, 0] = edgemasks + # Create index dataset + # Holds the (trap, cell_id) description used to index into the + # values and is of shape (n_cells, 2) + ix_max_shape = (max_shape[0], 2) + ix_shape = (0, 2) + ix_dtype = np.uint16 + ix_dset = hgroup.create_dataset( + "indices", + shape=ix_shape, + maxshape=ix_max_shape, + dtype=ix_dtype, + compression=self.compression, + ) + save_complex(current_indices, ix_dset) + + def __append_edgemasks(self, hgroup, edgemasks, current_indices): + key = "edgemasks" + val_dset = hgroup["values"] + ix_dset = hgroup["indices"] + existing_indices = load_complex(ix_dset) + # Check if there are any new labels + available = np.in1d(current_indices, existing_indices) + missing = current_indices[~available] + all_indices = np.concatenate([existing_indices, missing]) + # Resizing + t = perf_counter() + n_tps = val_dset.shape[1] + 1 + n_add_cells = len(missing) + # RESIZE DATASET FOR TIME and Cells + new_shape = (val_dset.shape[0] + n_add_cells, n_tps) + val_dset.shape[2:] + val_dset.resize(new_shape) + logging.debug(f"Timing:resizing:{perf_counter() - t}") + # Writing data + cell_indices = np.where(np.in1d(all_indices, current_indices))[0] + for ix, mask in zip(cell_indices, edgemasks): + try: + val_dset[ix, n_tps - 1] = mask + except Exception as e: + logging.debug(f"{ix}, {n_tps}, {val_dset.shape}") + # Save the index values + save_complex(missing, ix_dset) + + def write_edgemasks(self, data, keys, hgroup): + if not self._traps_initialised: + self.__init_trap_info() + # DATA is TRAP_IDS, CELL_LABELS, EDGEMASKS in a structured array + key = "edgemasks" + val_key = "values" + idx_key = "indices" + # Length of edgemasks + traps, cell_labels, edgemasks = data + n_cells = len(cell_labels) + hgroup = hgroup.require_group(key) + current_indices = np.array(traps) + 1j * np.array(cell_labels) + if val_key not in hgroup: + self.__init_edgemasks(hgroup, edgemasks, current_indices, n_cells) + else: + self.__append_edgemasks(hgroup, 
edgemasks, current_indices) + + def write(self, data, overwrite: list): + with h5py.File(self.file, "a") as store: + hgroup = store.require_group(self.group) + + for key, value in data.items(): + # We're only saving data that has a pre-defined data-type + self._check_key(key) + try: + if key.startswith("attrs/"): # metadata + key = key.split("/")[1] # First thing after attrs + hgroup.attrs[key] = value + elif key in overwrite: + self._overwrite(value, key, hgroup) + elif key == "edgemasks": + keys = ["trap", "cell_label", "edgemasks"] + value = [data[x] for x in keys] + self.write_edgemasks(value, keys, hgroup) + else: + self._append(value, key, hgroup) + except Exception as e: + print(key, value) + raise (e) + return + + +#################### Extraction version ############################### +class Writer(BridgeH5): + """ + Class in charge of transforming data into compatible formats + + Decoupling interface from implementation! + + Parameters + ---------- + filename: str Name of file to write into + flag: str, default=None + Flag to pass to the default file reader. If None the file remains closed. + compression: str, default=None + Compression method passed on to h5py writing functions (only used for + dataframes and other array-like data.) + """ + + def __init__(self, filename, compression=None): + super().__init__(filename, flag=None) + + if compression is None: + self.compression = "gzip" + + def write( + self, + path: str, + data: Iterable = None, + meta: Dict = {}, + overwrite: str = None, + ): + """ + Parameters + ---------- + path : str + Path inside h5 file to write into. + data : Iterable, default = None + meta : Dict, default = {} + + + """ + self.id_cache = {} + with h5py.File(self.filename, "a") as f: + if overwrite == "overwrite": # TODO refactor overwriting + if path in f: + del f[path] + elif overwrite == "accumulate": # Add a number if needed + if path in f: + parent, name = path.rsplit("/", maxsplit=1) + n = sum([x.startswith(name) for x in f[path]]) + path = path + str(n).zfill(3) + elif overwrite == "skip": + if path in f: + logging.debug("Skipping dataset {}".format(path)) + + logging.debug( + "{} {} to {} and {} metadata fields".format( + overwrite, type(data), path, len(meta) + ) + ) + if data is not None: + self.write_dset(f, path, data) + if meta: + for attr, metadata in meta.items(): + self.write_meta(f, path, attr, data=metadata) + + def write_dset(self, f: h5py.File, path: str, data: Iterable): + if isinstance(data, pd.DataFrame): + self.write_pd(f, path, data, compression=self.compression) + elif isinstance(data, pd.MultiIndex): + self.write_index(f, path, data) # , compression=self.compression) + elif isinstance(data, Iterable): + self.write_arraylike(f, path, data) + else: + self.write_atomic(data, f, path) + + def write_meta(self, f: h5py.File, path: str, attr: str, data: Iterable): + obj = f.require_group(path) + + obj.attrs[attr] = data + + @staticmethod + def write_arraylike(f: h5py.File, path: str, data: Iterable, **kwargs): + if path in f: + del f[path] + + narray = np.array(data) + + chunks = None + if narray.any(): + chunks = (1, *narray.shape[1:]) + + dset = f.create_dataset( + path, + shape=narray.shape, + chunks=chunks, + dtype="int", + compression=kwargs.get("compression", None), + ) + dset[()] = narray + + @staticmethod # TODO Use this function to implement Diane's dynamic writer + def write_dynamic(f: h5py.File, path: str, data: Iterable): + pass + + @staticmethod + def write_index(f, path, pd_index, **kwargs): + f.require_group(path) # TODO 
check if we can remove this + for i, name in enumerate(pd_index.names): + ids = pd_index.get_level_values(i) + id_path = path + "/" + name + f.create_dataset( + name=id_path, + shape=(len(ids),), + dtype="uint16", + compression=kwargs.get("compression", None), + ) + indices = f[id_path] + indices[()] = ids + + def write_pd(self, f, path, df, **kwargs): + values_path = path + "values" if path.endswith("/") else path + "/values" + if path not in f: + max_ncells = 2e5 + + max_tps = 1e3 + f.create_dataset( + name=values_path, + shape=df.shape, + # chunks=(min(df.shape[0], 1), df.shape[1]), + # dtype=df.dtypes.iloc[0], This is making NaN in ints into negative vals + dtype="float", + maxshape=(max_ncells, max_tps), + compression=kwargs.get("compression", None), + ) + dset = f[values_path] + dset[()] = df.values + + for name in df.index.names: + indices_path = "/".join((path, name)) + f.create_dataset( + name=indices_path, + shape=(len(df),), + dtype="uint16", # Assuming we'll always use int indices + chunks=True, + maxshape=(max_ncells,), + ) + dset = f[indices_path] + dset[()] = df.index.get_level_values(level=name).tolist() + + if df.columns.dtype == np.int or df.columns.dtype == np.dtype("uint"): + tp_path = path + "/timepoint" + f.create_dataset( + name=tp_path, + shape=(df.shape[1],), + maxshape=(max_tps,), + dtype="uint16", + ) + tps = df.columns.tolist() + f[tp_path][tps] = tps + else: + f[path].attrs["columns"] = df.columns.tolist() + else: + dset = f[values_path] + + # Filter out repeated timepoints + new_tps = set(df.columns) + if path + "/timepoint" in f: + new_tps = new_tps.difference(f[path + "/timepoint"][()]) + df = df[new_tps] + + if ( + not hasattr(self, "id_cache") or not df.index.nlevels in self.id_cache + ): # Use cache dict to store previously-obtained indices + self.id_cache[df.index.nlevels] = {} + existing_ids = self.get_existing_ids( + f, [path + "/" + x for x in df.index.names] + ) + # Split indices in existing and additional + new = df.index.tolist() + if df.index.nlevels == 1: # Cover for cases with a single index + new = [(x,) for x in df.index.tolist()] + ( + found_multis, + self.id_cache[df.index.nlevels]["additional_multis"], + ) = self.find_ids( + existing=existing_ids, + new=new, + ) + found_indices = np.array(locate_indices(existing_ids, found_multis)) + + # We must sort our indices for h5py indexing + incremental_existing = np.argsort(found_indices) + self.id_cache[df.index.nlevels]["found_indices"] = found_indices[ + incremental_existing + ] + self.id_cache[df.index.nlevels]["found_multi"] = found_multis[ + incremental_existing + ] + + existing_values = df.loc[ + [ + _tuple_or_int(x) + for x in self.id_cache[df.index.nlevels]["found_multi"] + ] + ].values + new_values = df.loc[ + [ + _tuple_or_int(x) + for x in self.id_cache[df.index.nlevels]["additional_multis"] + ] + ].values + ncells, ntps = f[values_path].shape + + # Add found cells + dset.resize(dset.shape[1] + df.shape[1], axis=1) + dset[:, ntps:] = np.nan + for i, tp in enumerate(df.columns): + dset[ + self.id_cache[df.index.nlevels]["found_indices"], tp + ] = existing_values[:, i] + # Add new cells + n_newcells = len(self.id_cache[df.index.nlevels]["additional_multis"]) + dset.resize(dset.shape[0] + n_newcells, axis=0) + dset[ncells:, :] = np.nan + + for i, tp in enumerate(df.columns): + dset[ncells:, tp] = new_values[:, i] + + # save indices + for i, name in enumerate(df.index.names): + tmp = path + "/" + name + dset = f[tmp] + n = dset.shape[0] + dset.resize(n + n_newcells, axis=0) + dset[n:] = ( + 
self.id_cache[df.index.nlevels]["additional_multis"][:, i] + if len(self.id_cache[df.index.nlevels]["additional_multis"].shape) + > 1 + else self.id_cache[df.index.nlevels]["additional_multis"] + ) + + tmp = path + "/timepoint" + dset = f[tmp] + n = dset.shape[0] + dset.resize(n + df.shape[1], axis=0) + dset[n:] = df.columns.tolist() + + @staticmethod + def get_existing_ids(f, paths): + # Fetch indices and convert them to a (nentries, nlevels) ndarray + return np.array([f[path][()] for path in paths]).T + + @staticmethod + def find_ids(existing, new): + # Compare two tuple sets and return the intersection and difference + # (elements in the 'new' set not in 'existing') + set_existing = set([tuple(*x) for x in zip(existing.tolist())]) + existing_cells = np.array(list(set_existing.intersection(new))) + new_cells = np.array(list(set(new).difference(set_existing))) + + return ( + existing_cells, + new_cells, + ) + + +# @staticmethod +def locate_indices(existing, new): + if new.any(): + if new.shape[1] > 1: + return [ + find_1st( + (existing[:, 0] == n[0]) & (existing[:, 1] == n[1]), True, cmp_equal + ) + for n in new + ] + else: + return [find_1st(existing[:, 0] == n, True, cmp_equal) for n in new] + else: + return [] + + +# def tuple_or_int(x): +# if isinstance(x, Iterable): +# return tuple(x) +# else: +# return x +def _tuple_or_int(x): + # Convert tuple to int if it only contains one value + if len(x) == 1: + return x[0] + else: + return x diff --git a/aliby/multiexperiment.py b/aliby/multiexperiment.py new file mode 100644 index 0000000000000000000000000000000000000000..1076ec65572d8fec2e30210519cba1b18ea4ba06 --- /dev/null +++ b/aliby/multiexperiment.py @@ -0,0 +1,25 @@ +from pathos.multiprocessing import Pool + +from aliby.pipeline import PipelineParameters, Pipeline + + +class MultiExp: + """ + Manages cases when you need to segment several different experiments with a single + position (e.g. pH calibration). + """ + + def __init__(self, expt_ids, npools=8, *args, **kwargs): + + self.expt_ids = expt_ids + + def run(self): + run_expt = lambda expt: Pipeline( + PipelineParameters.default(general={"expt_id": expt, "distributed": 0}) + ).run() + with Pool(npools) as p: + results = p.map(lambda x: self.create_pipeline(x), self.exp_ids) + + @classmethod + def default(self): + return cls(expt_ids=list(range(20448, 20467 + 1))) diff --git a/aliby/pipeline.py b/aliby/pipeline.py new file mode 100644 index 0000000000000000000000000000000000000000..7d0c31588f92d9ff14e0ff7bbfbe150d4599bffd --- /dev/null +++ b/aliby/pipeline.py @@ -0,0 +1,274 @@ +""" +Pipeline and chaining elements. 
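+
+PipelineParameters bundles nested defaults for tiling, segmentation,
+extraction and post-processing; Pipeline runs those steps for every position
+of an experiment and writes the results to one HDF5 file per position.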
+""" +import logging +import os +from abc import ABC, abstractmethod +from typing import List +from pathlib import Path +import traceback + +import itertools +import yaml +from tqdm import tqdm +from time import perf_counter + +import numpy as np +import pandas as pd +from pathos.multiprocessing import Pool + +from agora.base import ParametersABC, ProcessABC +from aliby.experiment import MetaData +from aliby.io.omero import Dataset, Image +from aliby.haystack import initialise_tf +from aliby.baby_client import BabyRunner, BabyParameters +from aliby.segment import Tiler, TilerParameters +from aliby.io.writer import TilerWriter, BabyWriter +from aliby.io.signal import Signal +from extraction.core.extractor import Extractor, ExtractorParameters +from extraction.core.functions.defaults import exparams_from_meta +from postprocessor.core.processor import PostProcessor, PostProcessorParameters + + +class PipelineParameters(ParametersABC): + def __init__(self, general, tiler, baby, extraction, postprocessing): + self.general = general + self.tiler = tiler + self.baby = baby + self.extraction = extraction + self.postprocessing = postprocessing + + @classmethod + def default( + cls, + general={}, + tiler={}, + baby={}, + extraction={}, + postprocessing={}, + ): + """ + Load unit test experiment + :expt_id: Experiment id + :directory: Output directory + + Provides default parameters for the entire pipeline. This downloads the logfiles and sets the default + timepoints and extraction parameters from there. + """ + expt_id = general.get("expt_id", 19993) + directory = Path(general.get("directory", "../data")) + with Dataset(int(expt_id)) as conn: + directory = directory / conn.unique_name + if not directory.exists(): + directory.mkdir(parents=True) + # Download logs to use for metadata + conn.cache_logs(directory) + meta = MetaData(directory, None).load_logs() + tps = meta["time_settings/ntimepoints"][0] + defaults = { + "general": dict( + id=expt_id, + distributed=0, + tps=tps, + directory=directory, + strain="", + earlystop=dict( + min_tp=180, + thresh_pos_clogged=0.3, + thresh_trap_clogged=7, + ntps_to_eval=5, + ), + ) + } + defaults["tiler"] = TilerParameters.default().to_dict() + defaults["baby"] = BabyParameters.default().to_dict() + defaults["extraction"] = exparams_from_meta(meta) + defaults["postprocessing"] = PostProcessorParameters.default().to_dict() + for k in defaults.keys(): + exec("defaults[k].update(" + k + ")") + return cls(**{k: v for k, v in defaults.items()}) + + def load_logs(self): + parsed_flattened = parse_logfiles(self.log_dir) + return parsed_flattened + + +class Pipeline(ProcessABC): + """ + A chained set of Pipeline elements connected through pipes. 
+ """ + + # Tiling, Segmentation,Extraction and Postprocessing should use their own default parameters + + # Early stop for clogging + earlystop = { + "min_tp": 180, + "thresh_pos_clogged": 0.3, + "thresh_trap_clogged": 7, + "ntps_to_eval": 5, + } + + def __init__(self, parameters: PipelineParameters): + super().__init__(parameters) + self.store = self.parameters.general["directory"] + + @classmethod + def from_yaml(cls, fpath): + # This is just a convenience function, think before implementing + # for other processes + return cls(parameters=PipelineParameters.from_yaml(fpath)) + + def run(self): + # Config holds the general information, use in main + # Steps holds the description of tasks with their parameters + # Steps: all holds general tasks + # steps: strain_name holds task for a given strain + config = self.parameters.to_dict() + expt_id = config["general"]["id"] + distributed = config["general"]["distributed"] + strain_filter = config["general"]["strain"] + root_dir = config["general"]["directory"] + root_dir = Path(root_dir) + + print("Searching OMERO") + # Do all initialis + with Dataset(int(expt_id)) as conn: + image_ids = conn.get_images() + directory = root_dir / conn.unique_name + if not directory.exists(): + directory.mkdir(parents=True) + # Download logs to use for metadata + conn.cache_logs(directory) + + # Modify to the configuration + self.parameters.general["directory"] = directory + config["general"]["directory"] = directory + + # Filter TODO integrate filter onto class and add regex + image_ids = {k: v for k, v in image_ids.items() if k.startswith(strain_filter)} + + if distributed != 0: # Gives the number of simultaneous processes + with Pool(distributed) as p: + results = p.map(lambda x: self.create_pipeline(x), image_ids.items()) + return results + else: # Sequential + results = [] + for k, v in image_ids.items(): + r = self.create_pipeline((k, v)) + results.append(r) + + def create_pipeline(self, image_id): + config = self.parameters.to_dict() + name, image_id = image_id + general_config = config["general"] + session = None + earlystop = general_config["earlystop"] + try: + directory = general_config["directory"] + with Image(image_id) as image: + filename = f"{directory}/{image.name}.h5" + try: + os.remove(filename) + except: + pass + + # Run metadata first + process_from = 0 + # if True: # not Path(filename).exists(): + meta = MetaData(directory, filename) + meta.run() + tiler = Tiler.from_image( + image, TilerParameters.from_dict(config["tiler"]) + ) + # else: TODO add support to continue local experiments? 
+ # tiler = Tiler.from_hdf5(image.data, filename) + # s = Signal(filename) + # process_from = s["/general/None/extraction/volume"].columns[-1] + # if process_from > 2: + # process_from = process_from - 3 + # tiler.n_processed = process_from + + writer = TilerWriter(filename) + session = initialise_tf(2) + runner = BabyRunner.from_tiler( + BabyParameters.from_dict(config["baby"]), tiler + ) + bwriter = BabyWriter(filename) + exparams = ExtractorParameters.from_dict(config["extraction"]) + ext = Extractor.from_tiler(exparams, store=filename, tiler=tiler) + + # RUN + tps = general_config["tps"] + frac_clogged_traps = 0 + for i in tqdm( + range(process_from, tps), desc=image.name, initial=process_from + ): + if ( + frac_clogged_traps < earlystop["thresh_pos_clogged"] + or i < earlystop["min_tp"] + ): + t = perf_counter() + trap_info = tiler.run_tp(i) + logging.debug(f"Timing:Trap:{perf_counter() - t}s") + t = perf_counter() + writer.write(trap_info, overwrite=[]) + logging.debug(f"Timing:Writing-trap:{perf_counter() - t}s") + t = perf_counter() + seg = runner.run_tp(i) + logging.debug(f"Timing:Segmentation:{perf_counter() - t}s") + # logging.debug( + # f"Segmentation failed:Segmentation:{perf_counter() - t}s" + # ) + t = perf_counter() + bwriter.write(seg, overwrite=["mother_assign"]) + logging.debug(f"Timing:Writing-baby:{perf_counter() - t}s") + t = perf_counter() + + tmp = ext.run(tps=[i]) + logging.debug(f"Timing:Extraction:{perf_counter() - t}s") + else: # Stop if more than X% traps are clogged + logging.debug( + f"EarlyStop:{earlystop['thresh_pos_clogged']*100}% traps clogged at time point {i}" + ) + print( + f"Stopping analysis at time {i} with {frac_clogged_traps} clogged traps" + ) + break + + if ( + i > earlystop["min_tp"] + ): # Calculate the fraction of clogged traps + frac_clogged_traps = self.check_earlystop(filename, earlystop) + logging.debug(f"Quality:Clogged_traps:{frac_clogged_traps}") + print("Frac clogged traps: ", frac_clogged_traps) + + # Run post processing + post_proc_params = PostProcessorParameters.from_dict( + self.parameters.postprocessing + ).to_dict() + PostProcessor(filename, post_proc_params).run() + return True + except Exception as e: # bug in the trap getting + print(f"Caught exception in worker thread (x = {name}):") + # This prints the type, value, and stack trace of the + # current exception being handled. + traceback.print_exc() + print() + raise e + finally: + if session: + session.close() + + def check_earlystop(self, filename, es_parameters): + s = Signal(filename) + df = s["/extraction/general/None/area"] + frac_clogged_traps = ( + df[df.columns[-1 - es_parameters["ntps_to_eval"] : -1]] + .dropna(how="all") + .notna() + .groupby("trap") + .apply(sum) + .apply(np.mean, axis=1) + > es_parameters["thresh_trap_clogged"] + ).mean() + return frac_clogged_traps diff --git a/aliby/post_processing.py b/aliby/post_processing.py new file mode 100644 index 0000000000000000000000000000000000000000..58481dd84c44911d6f90f2fb3947eaeb199f9ab9 --- /dev/null +++ b/aliby/post_processing.py @@ -0,0 +1,189 @@ +""" +Post-processing utilities + +Notes: I don't have statistics on ranges of radii for each of the knots in +the radial spline representation, but we regularly extract the average of +these radii for each cell. 
So, depending on camera/lens, we get: + * 60x evolve: mean radii of 2-14 pixels (and measured areas of 30-750 + pixels^2) + * 60x prime95b: mean radii of 3-24 pixels (and measured areas of 60-2000 + pixels^2) + +And I presume that for a 100x lens we would get an ~5/3 increase over those +values. + +In terms of the current volume estimation method, it's currently only +implemented in the AnalysisToolbox repository, but it's super simple: + +mVol = 4/3*pi*sqrt(mArea/pi).^3 + +where mArea is simply the sum of pixels for that cell. +""" +import matplotlib.pyplot as plt +import numpy as np +from mpl_toolkits.mplot3d.art3d import Poly3DCollection +from scipy import ndimage +from skimage.morphology import erosion, ball +from skimage import measure, draw + + +def my_ball(radius): + """Generates a ball-shaped structuring element. + + This is the 3D equivalent of a disk. + A pixel is within the neighborhood if the Euclidean distance between + it and the origin is no greater than radius. + + Parameters + ---------- + radius : int + The radius of the ball-shaped structuring element. + + Other Parameters + ---------------- + dtype : data-type + The data type of the structuring element. + + Returns + ------- + selem : ndarray + The structuring element where elements of the neighborhood + are 1 and 0 otherwise. + """ + n = 2 * radius + 1 + Z, Y, X = np.mgrid[-radius:radius:n * 1j, + -radius:radius:n * 1j, + -radius:radius:n * 1j] + X **= 2 + Y **= 2 + Z **= 2 + X += Y + X += Z + # s = X ** 2 + Y ** 2 + Z ** 2 + return X <= radius * radius + +def circle_outline(r): + return ellipse_perimeter(r, r) + +def ellipse_perimeter(x, y): + im_shape = int(2*max(x, y) + 1) + img = np.zeros((im_shape, im_shape), dtype=np.uint8) + rr, cc = draw.ellipse_perimeter(int(im_shape//2), int(im_shape//2), + int(x), int(y)) + img[rr, cc] = 1 + return np.pad(img, 1) + +def capped_cylinder(x, y): + max_size = (y + 2*x + 2) + pixels = np.zeros((max_size, max_size)) + + rect_start = ((max_size-x)//2, x + 1) + rr, cc = draw.rectangle_perimeter(rect_start, extent=(x, y), + shape=(max_size, max_size)) + pixels[rr, cc] = 1 + circle_centres = [(max_size//2 - 1, x), + (max_size//2 - 1, max_size - x - 1 )] + for r, c in circle_centres: + rr, cc = draw.circle_perimeter(r, c, (x + 1)//2, + shape=(max_size, max_size)) + pixels[rr, cc] = 1 + pixels = ndimage.morphology.binary_fill_holes(pixels) + pixels ^= erosion(pixels) + return pixels + +def volume_of_sphere(radius): + return 4 / 3 * np.pi * radius**3 + +def plot_voxels(voxels): + verts, faces, normals, values = measure.marching_cubes_lewiner( + voxels, 0) + fig = plt.figure(figsize=(10, 10)) + ax = fig.add_subplot(111, projection='3d') + mesh = Poly3DCollection(verts[faces]) + mesh.set_edgecolor('k') + ax.add_collection3d(mesh) + ax.set_xlim(0, voxels.shape[0]) + ax.set_ylim(0, voxels.shape[1]) + ax.set_zlim(0, voxels.shape[2]) + plt.tight_layout() + plt.show() + +# Volume estimation +def union_of_spheres(outline, shape='my_ball', debug=False): + filled = ndimage.binary_fill_holes(outline) + nearest_neighbor = ndimage.morphology.distance_transform_edt( + outline == 0) * filled + voxels = np.zeros((filled.shape[0], filled.shape[1], max(filled.shape))) + c_z = voxels.shape[2] // 2 + for x,y in zip(*np.where(filled)): + radius = nearest_neighbor[(x,y)] + if radius > 0: + if shape == 'ball': + b = ball(radius) + elif shape == 'my_ball': + b = my_ball(radius) + else: + raise ValueError(f"{shape} is not an accepted value for " + f"shape.") + centre_b = ndimage.measurements.center_of_mass(b) + 
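+            # Stamp the ball into the voxel volume, centred on (x, y) in-plane
+            # and on the mid-plane c_z in depth.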
+ I,J,K = np.ogrid[:b.shape[0], :b.shape[1], :b.shape[2]] + voxels[I + int(x - centre_b[0]), J + int(y - centre_b[1]), + K + int(c_z - centre_b[2])] += b + if debug: + plot_voxels(voxels) + return voxels.astype(bool).sum() + +def improved_uos(outline, shape='my_ball', debug=False): + filled = ndimage.binary_fill_holes(outline) + nearest_neighbor = ndimage.morphology.distance_transform_edt( + outline == 0) * filled + voxels = np.zeros((filled.shape[0], filled.shape[1], max(filled.shape))) + c_z = voxels.shape[2] // 2 + + while np.any(nearest_neighbor != 0): + radius = np.max(nearest_neighbor) + x, y = np.argwhere(nearest_neighbor == radius)[0] + if shape == 'ball': + b = ball(np.ceil(radius)) + elif shape == 'my_ball': + b = my_ball(np.ceil(radius)) + else: + raise ValueError(f"{shape} is not an accepted value for shape") + centre_b = ndimage.measurements.center_of_mass(b) + + I, J, K = np.ogrid[:b.shape[0], :b.shape[1], :b.shape[2]] + voxels[I + int(x - centre_b[0]), J + int(y - centre_b[1]), + K + int(c_z - centre_b[2])] += b + + # Use the central disk of the ball from voxels to get the circle + # = 0 if nn[x,y] < r else nn[x,y] + rr, cc = draw.circle(x, y, np.ceil(radius), nearest_neighbor.shape) + nearest_neighbor[rr, cc] = 0 + if debug: + plot_voxels(voxels) + return voxels.astype(bool).sum() + +def conical(outline, debug=False): + nearest_neighbor = ndimage.morphology.distance_transform_edt( + outline == 0) * ndimage.binary_fill_holes(outline) + if debug: + hf = plt.figure() + ha = hf.add_subplot(111, projection='3d') + + X, Y = np.meshgrid(np.arange(nearest_neighbor.shape[0]), + np.arange(nearest_neighbor.shape[1])) + ha.plot_surface(X, Y, nearest_neighbor) + plt.show() + return 4 * nearest_neighbor.sum() + +def volume(outline, method='spheres'): + if method=='conical': + return conical(outline) + elif method=='spheres': + return union_of_spheres(outline) + else: + raise ValueError(f"Method {method} not implemented.") + +def circularity(outline): + pass \ No newline at end of file diff --git a/aliby/results.py b/aliby/results.py new file mode 100644 index 0000000000000000000000000000000000000000..fd12c2831dc9da6e84017eb5e80b3672492013d2 --- /dev/null +++ b/aliby/results.py @@ -0,0 +1,35 @@ +"""Pipeline results classes and utilities""" + + +class SegmentationResults: + """ + Object storing the data from the Segmentation pipeline. + Everything is stored as an `AttributeDict`, which is a `defaultdict` where + you can get elements as attributes. + + In addition, it implements: + - IO functionality (read from file, write to file) + """ + def __init__(self, raw_expt): + pass + + + + +class CellResults: + """ + Results on a set of cells TODO: what set of cells, how? + + Contains: + * cellInf describing which cells are taken into account + * annotations on the cell + * segmentation maps of the cell TODO: how to define and save this? + * trapLocations TODO: why is this not part of cellInf? + """ + + def __init__(self, cellInf=None, annotations=None, segmentation=None, + trapLocations=None): + self._cellInf = cellInf + self._annotations = annotations + self._segmentation = segmentation + self._trapLocations = trapLocations diff --git a/aliby/segment.py b/aliby/segment.py new file mode 100644 index 0000000000000000000000000000000000000000..6550bd404f1d042b99f24105677900e6e29c60e2 --- /dev/null +++ b/aliby/segment.py @@ -0,0 +1,344 @@ +"""Segment/segmented pipelines. 
+Includes splitting the image into traps/parts, +cell segmentation, nucleus segmentation.""" +import warnings +from functools import lru_cache + +import h5py +import numpy as np + +from pathlib import Path, PosixPath + +from skimage.registration import phase_cross_correlation + +from agora.base import ParametersABC, ProcessABC +from aliby.traps import segment_traps +from aliby.timelapse import TimelapseOMERO +from aliby.io.matlab import matObject +from aliby.traps import ( + identify_trap_locations, + get_trap_timelapse, + get_traps_timepoint, + centre, + get_trap_timelapse_omero, +) +from aliby.utils import accumulate, get_store_path + +from aliby.io.writer import Writer, load_attributes +from aliby.io.metadata_parser import parse_logfiles + +trap_template_directory = Path(__file__).parent / "trap_templates" +# TODO do we need multiple templates, one for each setup? +trap_template = np.array([]) # np.load(trap_template_directory / "trap_prime.npy") + + +def get_tile_shapes(x, tile_size, max_shape): + half_size = tile_size // 2 + xmin = int(x[0] - half_size) + ymin = max(0, int(x[1] - half_size)) + if xmin + tile_size > max_shape[0]: + xmin = max_shape[0] - tile_size + if ymin + tile_size > max_shape[1]: + ymin = max_shape[1] - tile_size + return xmin, xmin + tile_size, ymin, ymin + tile_size + + +###################### Dask versions ######################## +class Trap: + def __init__(self, centre, parent, size, max_size): + self.centre = centre + self.parent = parent # Used to access drifts + self.size = size + self.half_size = size // 2 + self.max_size = max_size + + def padding_required(self, tp): + """Check if we need to pad the trap image for this time point.""" + try: + assert all(self.at_time(tp) - self.half_size >= 0) + assert all(self.at_time(tp) + self.half_size <= self.max_size) + except AssertionError: + return True + return False + + def at_time(self, tp): + """Return trap centre at time tp""" + drifts = self.parent.drifts + return self.centre - np.sum(drifts[:tp], axis=0) + + def as_tile(self, tp): + """Return trap in the OMERO tile format of x, y, w, h + + Also returns the padding necessary for this tile. 
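+        (In this implementation only the tile origin and size are returned;
+        whether padding is needed is checked separately via padding_required.)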
+ """ + x, y = self.at_time(tp) + # tile bottom corner + x = int(x - self.half_size) + y = int(y - self.half_size) + return x, y, self.size, self.size + + def as_range(self, tp): + """Return trap in a range format, two slice objects that can be used in Arrays""" + x, y, w, h = self.as_tile(tp) + return slice(x, x + w), slice(y, y + h) + + +class TrapLocations: + def __init__(self, initial_location, tile_size, max_size=1200, drifts=[]): + self.tile_size = tile_size + self.max_size = max_size + self.initial_location = initial_location + self.traps = [ + Trap(centre, self, tile_size, max_size) for centre in initial_location + ] + self.drifts = drifts + + @classmethod + def from_source(cls, fpath: str): + with h5py.File(fpath, "r") as f: + # TODO read tile size from file metadata + drifts = f["trap_info/drifts"][()] + tlocs = cls(f["trap_info/trap_locations"][()], tile_size=96, drifts=drifts) + + return tlocs + + @property + def shape(self): + return len(self.traps), len(self.drifts) + + def __len__(self): + return len(self.traps) + + def __iter__(self): + yield from self.traps + + def padding_required(self, tp): + return any([trap.padding_required(tp) for trap in self.traps]) + + def to_dict(self, tp): + res = dict() + if tp == 0: + res["trap_locations"] = self.initial_location + res["attrs/tile_size"] = self.tile_size + res["attrs/max_size"] = self.max_size + res["drifts"] = np.expand_dims(self.drifts[tp], axis=0) + # res['processed_timepoints'] = tp + return res + + @classmethod + def read_hdf5(cls, file): + with h5py.File(file, "r") as hfile: + trap_info = hfile["trap_info"] + initial_locations = trap_info["trap_locations"][()] + drifts = trap_info["drifts"][()] + max_size = trap_info.attrs["max_size"] + tile_size = trap_info.attrs["tile_size"] + trap_locs = cls(initial_locations, tile_size, max_size=max_size) + trap_locs.drifts = drifts + return trap_locs + + +class TilerParameters(ParametersABC): + def __init__( + self, tile_size: int, ref_channel: str, ref_z: int, template_name: str = None + ): + self.tile_size = tile_size + self.ref_channel = ref_channel + self.ref_z = ref_z + self.template_name = template_name + + @classmethod + def from_template(cls, template_name: str, ref_channel: str, ref_z: int): + return cls(template.shape[0], ref_channel, ref_z, template_path=template_name) + + @classmethod + def default(cls): + return cls(96, "Brightfield", 0) + + +class Tiler(ProcessABC): + """A dummy TimelapseTiler object fora Dask Demo. 
+ + Does trap finding and image registration.""" + + def __init__( + self, + image, + metadata, + parameters: TilerParameters, + ): + super().__init__(parameters) + self.image = image + self.channels = metadata["channels"] + self.ref_channel = self.get_channel_index(parameters.ref_channel) + + @classmethod + def from_image(cls, image, parameters: TilerParameters): + return cls(image.data, image.metadata, parameters) + + @classmethod + def from_hdf5(cls, image, filepath, tile_size=None): + trap_locs = TrapLocations.read_hdf5(filepath) + metadata = load_attributes(filepath) + metadata["channels"] = metadata["channels/channel"].tolist() + if tile_size is None: + tile_size = trap_locs.tile_size + return Tiler( + image=image, + metadata=metadata, + template=None, + tile_size=tile_size, + trap_locs=trap_locs, + ) + + @lru_cache(maxsize=2) + def get_tc(self, t, c): + # Get image + full = self.image[t, c].compute() # FORCE THE CACHE + return full + + @property + def shape(self): + c, t, z, y, x = self.image.shape + return (c, t, x, y, z) + + @property + def n_processed(self): + if not hasattr(self, "_n_processed"): + self._n_processed = 0 + return self._n_processed + + @n_processed.setter + def n_processed(self, value): + self._n_processed = value + + @property + def n_traps(self): + return len(self.trap_locs) + + @property + def finished(self): + return self.n_processed == self.image.shape[0] + + def _initialise_traps(self, tile_size): + """Find initial trap positions. + + Removes all those that are too close to the edge so no padding is necessary. + """ + half_tile = tile_size // 2 + max_size = min(self.image.shape[-2:]) + initial_image = self.image[ + 0, self.ref_channel, self.ref_z + ] # First time point, first channel, first z-position + trap_locs = segment_traps(initial_image, tile_size) + trap_locs = [ + [x, y] + for x, y in trap_locs + if half_tile < x < max_size - half_tile + and half_tile < y < max_size - half_tile + ] + self.trap_locs = TrapLocations(trap_locs, tile_size) + + def find_drift(self, tp): + # TODO check that the drift doesn't move any tiles out of the image, remove them from list if so + prev_tp = max(0, tp - 1) + drift, error, _ = phase_cross_correlation( + self.image[prev_tp, self.ref_channel, self.ref_z], + self.image[tp, self.ref_channel, self.ref_z], + ) + self.trap_locs.drifts.append(drift) + + def get_tp_data(self, tp, c): + traps = [] + full = self.get_tc(tp, c) + # if self.trap_locs.padding_required(tp): + for trap in self.trap_locs: + ndtrap = self.ifoob_pad(full, trap.as_range(tp)) + + traps.append(ndtrap) + return np.stack(traps) + + def get_trap_data(self, trap_id, tp, c): + full = self.get_tc(tp, c) + trap = self.trap_locs.traps[trap_id] + ndtrap = self.ifoob_pad(full, trap.as_range(tp)) + + return ndtrap + + @staticmethod + def ifoob_pad(full, slices): + """ + Returns the slices padded if it is out of bounds + + Parameters: + ---------- + full: (zstacks, max_size, max_size) ndarray + Entire position with zstacks as first axis + slices: tuple of two slices + Each slice indicates an axis to index + + + Returns + Trap for given slices, padded with median if needed, or np.nan if the padding is too much + """ + max_size = full.shape[-1] + + y, x = [slice(max(0, s.start), min(max_size, s.stop)) for s in slices] + trap = full[:, y, x] + + padding = np.array( + [(-min(0, s.start), -min(0, max_size - s.stop)) for s in slices] + ) + if padding.any(): + tile_size = slices[0].stop - slices[0].start + if (padding > tile_size / 4).any(): + trap = np.full((full.shape[0], 
tile_size, tile_size), np.nan) + else: + + trap = np.pad(trap, [[0, 0]] + padding.tolist(), "median") + + return trap + + def run_tp(self, tp): + assert tp >= self.n_processed, "Time point already processed" + # TODO check contiguity? + if self.n_processed == 0: + self._initialise_traps(self.tile_size) + self.find_drift(tp) # Get drift + # update n_processed + self.n_processed += 1 + # Return result for writer + return self.trap_locs.to_dict(tp) + + def run(self, tp): + if self.n_processed == 0: + self._initialise_traps(self.tile_size) + self.find_drift(tp) # Get drift + # update n_processed + self.n_processed += 1 + # Return result for writer + return self.trap_locs.to_dict(tp) + + # The next set of functions are necessary for the extraction object + def get_traps_timepoint(self, tp, tile_size=None, channels=None, z=None): + # FIXME we currently ignore the tile size + # FIXME can we ignore z(always give) + res = [] + for c in channels: + val = self.get_tp_data(tp, c)[:, z] # Only return requested z + # positions + # Starts at traps, z, y, x + # Turn to Trap, C, T, X, Y, Z order + val = val.swapaxes(1, 3).swapaxes(1, 2) + val = np.expand_dims(val, axis=1) + res.append(val) + return np.stack(res, axis=1) + + def get_channel_index(self, item): + for i, ch in enumerate(self.channels): + if item in ch: + return i + + def get_position_annotation(self): + # TODO required for matlab support + return None diff --git a/aliby/tests/__init__.py b/aliby/tests/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/aliby/tests/test_integration.py b/aliby/tests/test_integration.py new file mode 100644 index 0000000000000000000000000000000000000000..10e45f8ae16b8319cf18266bf2e3a331e502969b --- /dev/null +++ b/aliby/tests/test_integration.py @@ -0,0 +1,28 @@ +""" +Testing the "run" functions in the pipeline elements. 
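+
+These are placeholders for now: they need fixture data (an experiment object,
+an sqlite database and a Shelf storage, see the TODOs below) before they can
+be implemented.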
+""" +import pytest +pytest.mark.skip(reason='All tests still WIP') + +# Todo: data needed: an experiment object +# Todo: data needed: an sqlite database +# Todo: data needed: a Shelf storage +class TestPipeline: + def test_experiment(self): + pass + + def test_omero_experiment(self): + pass + + def test_tiler(self): + pass + + def test_baby_client(self): + pass + + def test_baby_runner(self): + pass + + def test_pipeline(self): + pass + diff --git a/aliby/tests/test_units.py b/aliby/tests/test_units.py new file mode 100644 index 0000000000000000000000000000000000000000..447ab6018d2e249612094202a490c0beb9ae4324 --- /dev/null +++ b/aliby/tests/test_units.py @@ -0,0 +1,99 @@ +import pytest +pytest.mark.skip("all tests still WIP") + + +from core.core import PersistentDict + +# Todo: temporary file needed +class TestPersistentDict: + @pytest.fixture(autouse=True, scope='class') + def _get_json_file(self, tmp_path): + self._filename = tmp_path / 'persistent_dict.json' + + def test_persistent_dict(self): + p = PersistentDict(self._filename) + p['hello/from/the/other/side'] = "adele" + p['hello/how/you/doing'] = 'lionel' + # Todo: run checks + + +# Todo: data needed - small experiment +class TestExperiment: + def test_shape(self): + pass + def test_positions(self): + pass + def test_channels(self): + pass + def test_hypercube(self): + pass + +# Todo: data needed - a dummy OMERO server +class TestConnection: + def test_dataset(self): + pass + def test_image(self): + pass + +# Todo data needed - a position +class TestTimelapse: + def test_id(self): + pass + def test_name(self): + pass + def test_size_z(self): + pass + def test_size_c(self): + pass + def test_size_t(self): + pass + def test_size_x(self): + pass + def test_size_y(self): + pass + def test_channels(self): + pass + def test_channel_index(self): + pass + +# Todo: data needed image and template +class TestTrapUtils: + def test_trap_locations(self): + pass + def test_tile_shape(self): + pass + def test_get_tile(self): + pass + def test_centre(self): + pass + +# Todo: data needed - a functional experiment object +class TestTiler: + def test_n_timepoints(self): + pass + def test_n_traps(self): + pass + def test_get_trap_timelapse(self): + pass + def test_get_trap_timepoints(self): + pass + +# Todo: data needed - a functional tiler object +# Todo: running server needed +class TestBabyClient: + def test_get_new_session(self): + pass + def test_queue_image(self): + pass + def test_get_segmentation(self): + pass + +# Todo: data needed - a functional tiler object +class TestBabyRunner: + def test_model_choice(self): + pass + def test_properties(self): + pass + def test_segment(self): + pass + diff --git a/aliby/timelapse.py b/aliby/timelapse.py new file mode 100644 index 0000000000000000000000000000000000000000..b2d2233ffdb1e17988ae2512db201e00db8c2fda --- /dev/null +++ b/aliby/timelapse.py @@ -0,0 +1,427 @@ +import itertools +import logging + +import h5py +import numpy as np +from pathlib import Path + +from tqdm import tqdm +import cv2 + +from aliby.io.matlab import matObject +from aliby.utils import Cache, imread, get_store_path + +logger = logging.getLogger(__name__) + + +def parse_local_fs(pos_dir, tp=None): + """ + Local file structure: + - pos_dir + -- exptID_{timepointID}_{ChannelID}_{z_position_id}.png + + :param pos_dirs: + :return: Image_mapper + """ + pos_dir = Path(pos_dir) + + img_mapper = dict() + + def channel_idx(img_name): + return img_name.stem.split("_")[-2] + + def tp_idx(img_name): + return int(img_name.stem.split("_")[-3]) 
- 1 + + def z_idx(img_name): + return img_name.stem.split("_")[-1] + + if tp is not None: + img_list = [img for img in pos_dir.iterdir() if tp_idx(img) in tp] + else: + img_list = [img for img in pos_dir.iterdir()] + + for tp, group in itertools.groupby(sorted(img_list, key=tp_idx), key=tp_idx): + img_mapper[int(tp)] = { + channel: {i: item for i, item in enumerate(sorted(grp, key=z_idx))} + for channel, grp in itertools.groupby( + sorted(group, key=channel_idx), key=channel_idx + ) + } + return img_mapper + + +class Timelapse: + """ + Timelapse class contains the specifics of one position. + """ + + def __init__(self): + self._id = None + self._name = None + self._channels = [] + self._size_c = 0 + self._size_t = 0 + self._size_x = 0 + self._size_y = 0 + self._size_z = 0 + self.image_cache = None + self.annotation = None + + def __repr__(self): + return self.name + + def full_mask(self): + return np.full(self.shape, False) + + def __getitem__(self, item): + cached = self.image_cache[item] + # Check if there are missing values, if so reload + # TODO only reload missing + mask = np.isnan(cached) + if np.any(mask): + full = self.load_fn(item) + shape = self.image_cache[ + item + ].shape # TODO speed this up by recognising the shape from the item + self.image_cache[item] = np.reshape(full, shape) + return full + return cached + + def get_hypercube(self): + pass + + def load_fn(self, item): + """ + The hypercube is ordered as: C, T, X, Y, Z + :param item: + :return: + """ + + def parse_slice(s): + step = s.step if s.step is not None else 1 + if s.start is None and s.stop is None: + return None + elif s.start is None and s.stop is not None: + return range(0, s.stop, step) + elif s.start is not None and s.stop is None: + return [s.start] + else: # both s.start and s.stop are not None + return range(s.start, s.stop, step) + + def parse_subitem(subitem, kw): + if isinstance(subitem, (int, float)): + res = [int(subitem)] + elif isinstance(subitem, list) or isinstance(subitem, tuple): + res = list(subitem) + elif isinstance(subitem, slice): + res = parse_slice(subitem) + else: + res = subitem + # raise ValueError(f"Cannot parse slice {kw}: {subitem}") + + if kw in ["x", "y"]: + # Need exactly two values + if res is not None: + if len(res) < 2: + # An int was passed, assume it was + res = [res[0], self.size_x] + elif len(res) > 2: + res = [res[0], res[-1] + 1] + return res + + if isinstance(item, int): + return self.get_hypercube( + x=None, y=None, z_positions=None, channels=[item], timepoints=None + ) + elif isinstance(item, slice): + return self.get_hypercube(channels=parse_slice(item)) + keywords = ["channels", "timepoints", "x", "y", "z_positions"] + kwargs = dict() + for kw, subitem in zip(keywords, item): + kwargs[kw] = parse_subitem(subitem, kw) + return self.get_hypercube(**kwargs) + + @property + def shape(self): + return (self.size_c, self.size_t, self.size_x, self.size_y, self.size_z) + + @property + def id(self): + return self._id + + @property + def name(self): + return self._name + + @property + def size_z(self): + return self._size_z + + @property + def size_c(self): + return self._size_c + + @property + def size_t(self): + return self._size_t + + @property + def size_x(self): + return self._size_x + + @property + def size_y(self): + return self._size_y + + @property + def channels(self): + return self._channels + + def get_channel_index(self, channel): + return self.channels.index(channel) + + +def load_annotation(filepath: Path): + try: + return matObject(filepath) + except Exception 
as e: + raise ( + "Could not load annotation file. \n" + "Non MATLAB files currently unsupported" + ) from e + + +class TimelapseOMERO(Timelapse): + """ + Connected to an Image object which handles database I/O. + """ + + def __init__(self, image, annotation, cache, **kwargs): + super(TimelapseOMERO, self).__init__() + self.image = image + # Pre-load pixels + self.pixels = self.image.getPrimaryPixels() + self._id = self.image.getId() + self._name = self.image.getName() + self._size_x = self.image.getSizeX() + self._size_y = self.image.getSizeY() + self._size_z = self.image.getSizeZ() + self._size_c = self.image.getSizeC() + self._size_t = self.image.getSizeT() + self._channels = self.image.getChannelLabels() + # Check whether there are file annotations for this position + if annotation is not None: + self.annotation = load_annotation(annotation) + # Get an HDF5 dataset to use as a cache. + compression = kwargs.get("compression", None) + self.image_cache = cache.require_dataset( + self.name, + self.shape, + dtype=np.float16, + fillvalue=np.nan, + compression=compression, + ) + + def get_hypercube( + self, x=None, y=None, z_positions=None, channels=None, timepoints=None + ): + if x is None and y is None: + tile = None # Get full plane + elif x is None: + ymin, ymax = y + tile = (None, ymin, None, ymax - ymin) + elif y is None: + xmin, xmax = x + tile = (xmin, None, xmax - xmin, None) + else: + xmin, xmax = x + ymin, ymax = y + tile = (xmin, ymin, xmax - xmin, ymax - ymin) + + if z_positions is None: + z_positions = range(self.size_z) + if channels is None: + channels = range(self.size_c) + if timepoints is None: + timepoints = range(self.size_t) + + z_positions = z_positions or [0] + channels = channels or [0] + timepoints = timepoints or [0] + + zcttile_list = [ + (z, c, t, tile) + for z, c, t in itertools.product(z_positions, channels, timepoints) + ] + planes = list(self.pixels.getTiles(zcttile_list)) + order = ( + len(z_positions), + len(channels), + len(timepoints), + planes[0].shape[-2], + planes[0].shape[-1], + ) + result = np.stack([x for x in planes]).reshape(order) + # Set to C, T, X, Y, Z order + result = np.moveaxis(result, -1, -2) + return np.moveaxis(result, 0, -1) + + def cache_set(self, save_dir, timepoints, expt_name, quiet=True): + # TODO deprecate when this is default + pos_dir = save_dir / self.name + if not pos_dir.exists(): + pos_dir.mkdir() + for tp in tqdm(timepoints, desc=self.name): + for channel in tqdm(self.channels, disable=quiet): + for z_pos in tqdm(range(self.size_z), disable=quiet): + ch_id = self.get_channel_index(channel) + image = self.get_hypercube( + x=None, + y=None, + channels=[ch_id], + z_positions=[z_pos], + timepoints=[tp], + ) + im_name = "{}_{:06d}_{}_{:03d}.png".format( + expt_name, tp + 1, channel, z_pos + 1 + ) + cv2.imwrite(str(pos_dir / im_name), np.squeeze(image)) + # TODO update positions table to get the number of timepoints? + return list(itertools.product([self.name], timepoints)) + + def run(self, keys, store, save_dir="./", **kwargs): + """ + Parse file structure and get images for the timepoints in keys. + """ + save_dir = Path(save_dir) + if keys is None: + # TODO save final metadata + return None + store = save_dir / store + # A position specific store + store = store.with_name(self.name + store.name) + # Create store if it does not exist + if not store.exists(): + # The first run, add metadata to the store + with h5py.File(store, "w") as pos_store: + # TODO Add metadata to the store. 
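+                    # A possible sketch (attribute names here are hypothetical,
+                    # not part of the original patch): record position-level
+                    # metadata so later pipeline steps can read it back, e.g.
+                    #     pos_store.attrs["position"] = self.name
+                    #     pos_store.attrs["channels"] = self.channels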
+ pass + # TODO check how sensible the keys are with what is available + # if some of the keys don't make sense, log a warning and remove + # them so that the next steps of the pipeline make sense + return keys + + def clear_cache(self): + self.image_cache.clear() + + +class TimelapseLocal(Timelapse): + def __init__( + self, position, root_dir, finished=True, annotation=None, cache=None, **kwargs + ): + """ + Linked to a local directory containing the images for one position + in an experiment. + Can be a still running experiment or a finished one. + + :param position: Name of the position + :param root_dir: Root directory + :param finished: Whether the experiment has finished running or the + class will be used as part of a pipeline, mostly with calls to `run` + """ + super(TimelapseLocal, self).__init__() + self.pos_dir = Path(root_dir) / position + assert self.pos_dir.exists() + self._id = position + self._name = position + if finished: + self.image_mapper = parse_local_fs(self.pos_dir) + self._update_metadata() + else: + self.image_mapper = dict() + self.annotation = None + # Check whether there are file annotations for this position + if annotation is not None: + self.annotation = load_annotation(annotation) + compression = kwargs.get("compression", None) + self.image_cache = cache.require_dataset( + self.name, + self.shape, + dtype=np.float16, + fillvalue=np.nan, + compression=compression, + ) + + def _update_metadata(self): + self._size_t = len(self.image_mapper) + # Todo: if cy5 is the first one it causes issues with getting x, y + # hence the sorted but it's not very robust + self._channels = sorted( + list(set.union(*[set(tp.keys()) for tp in self.image_mapper.values()])) + ) + self._size_c = len(self._channels) + # Todo: refactor so we don't rely on there being any images at all + self._size_z = max([len(self.image_mapper[0][ch]) for ch in self._channels]) + single_img = self.get_hypercube( + x=None, y=None, z_positions=None, channels=[0], timepoints=[0] + ) + self._size_x = single_img.shape[2] + self._size_y = single_img.shape[3] + + def get_hypercube( + self, x=None, y=None, z_positions=None, channels=None, timepoints=None + ): + xmin, xmax = x if x is not None else (None, None) + ymin, ymax = y if y is not None else (None, None) + + if z_positions is None: + z_positions = range(self.size_z) + if channels is None: + channels = range(self.size_c) + if timepoints is None: + timepoints = range(self.size_t) + + def z_pos_getter(z_positions, ch_id, t): + default = np.zeros((self.size_x, self.size_y)) + names = [ + self.image_mapper[t][self.channels[ch_id]].get(i, None) + for i in z_positions + ] + res = [imread(name) if name is not None else default for name in names] + return res + + # nested list of images in C, T, X, Y, Z order + ctxyz = [] + for ch_id in channels: + txyz = [] + for t in timepoints: + xyz = z_pos_getter(z_positions, ch_id, t) + txyz.append(np.dstack(list(xyz))[xmin:xmax, ymin:ymax]) + ctxyz.append(np.stack(txyz)) + return np.stack(ctxyz) + + def clear_cache(self): + self.image_cache.clear() + + def run(self, keys, store, save_dir="./", **kwargs): + """ + Parse file structure and get images for the time points in keys. 
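+
+        :param keys: Time point index or list of indices to fetch; if None,
+            nothing is done and None is returned.
+        :param store: Base name of the per-position HDF5 store.
+        :param save_dir: Directory in which the store is created.
+        :return: The keys, for use by the next step of the pipeline.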
+ """ + if keys is None: + return None + elif isinstance(keys, int): + keys = [keys] + self.image_mapper.update(parse_local_fs(self.pos_dir, tp=keys)) + self._update_metadata() + # Create store if it does not exist + store = get_store_path(save_dir, store, self.name) + if not store.exists(): + # The first run, add metadata to the store + with h5py.File(store, "w") as pos_store: + # TODO Add metadata to the store. + pass + # TODO check how sensible the keys are with what is available + # if some of the keys don't make sense, log a warning and remove + # them so that the next steps of the pipeline make sense + return keys diff --git a/aliby/traps.py b/aliby/traps.py new file mode 100644 index 0000000000000000000000000000000000000000..e37eb925eb5763c43efbabed961fb76385aa5e4c --- /dev/null +++ b/aliby/traps.py @@ -0,0 +1,480 @@ +""" +A set of utilities for dealing with ALCATRAS traps +""" + +import numpy as np +from tqdm import tqdm + +from skimage import transform, feature +from skimage.filters.rank import entropy +from skimage.filters import threshold_otsu +from skimage.segmentation import clear_border +from skimage.measure import label, regionprops +from skimage.morphology import disk, closing, square + + +def stretch_image(image): + image = ((image - image.min()) / (image.max() - image.min())) * 255 + minval = np.percentile(image, 2) + maxval = np.percentile(image, 98) + image = np.clip(image, minval, maxval) + image = (image - minval) / (maxval - minval) + return image + + +def segment_traps(image, tile_size, downscale=0.4): + # Make image go between 0 and 255 + img = image # Keep a memory of image in case need to re-run + # stretched = stretch_image(image) + # img = stretch_image(image) + # TODO Optimise the hyperparameters + disk_radius = int(min([0.01 * x for x in img.shape])) + min_area = 0.2 * (tile_size ** 2) + if downscale != 1: + img = transform.rescale(image, downscale) + entropy_image = entropy(img, disk(disk_radius)) + if downscale != 1: + entropy_image = transform.rescale(entropy_image, 1 / downscale) + + # apply threshold + thresh = threshold_otsu(entropy_image) + bw = closing(entropy_image > thresh, square(3)) + + # remove artifacts connected to image border + cleared = clear_border(bw) + + # label image regions + label_image = label(cleared) + areas = [ + region.area + for region in regionprops(label_image) + if region.area > min_area and region.area < tile_size ** 2 * 0.8 + ] + traps = ( + np.array( + [ + region.centroid + for region in regionprops(label_image) + if region.area > min_area and region.area < tile_size ** 2 * 0.8 + ] + ) + .round() + .astype(int) + ) + ma = ( + np.array( + [ + region.minor_axis_length + for region in regionprops(label_image) + if region.area > min_area and region.area < tile_size ** 2 * 0.8 + ] + ) + .round() + .astype(int) + ) + maskx = (tile_size // 2 < traps[:, 0]) & ( + traps[:, 0] < image.shape[0] - tile_size // 2 + ) + masky = (tile_size // 2 < traps[:, 1]) & ( + traps[:, 1] < image.shape[1] - tile_size // 2 + ) + + traps = traps[maskx & masky, :] + ma = ma[maskx & masky] + + chosen_trap_coords = np.round(traps[ma.argmin()]).astype(int) + x, y = chosen_trap_coords + template = image[ + x - tile_size // 2 : x + tile_size // 2, y - tile_size // 2 : y + tile_size // 2 + ] + + traps = identify_trap_locations(image, template) + + if len(traps) < 10 and downscale != 1: + print("Trying again.") + return segment_traps(image, tile_size, downscale=1) + + return traps + + +# def segment_traps(image, tile_size, downscale=0.4): +# # Make image go 
between 0 and 255 +# img = image # Keep a memory of image in case need to re-run +# image = stretch_image(image) +# # TODO Optimise the hyperparameters +# disk_radius = int(min([0.01 * x for x in img.shape])) +# min_area = 0.1 * (tile_size ** 2) +# if downscale != 1: +# img = transform.rescale(image, downscale) +# entropy_image = entropy(img, disk(disk_radius)) +# if downscale != 1: +# entropy_image = transform.rescale(entropy_image, 1 / downscale) + +# # apply threshold +# thresh = threshold_otsu(entropy_image) +# bw = closing(entropy_image > thresh, square(3)) + +# # remove artifacts connected to image border +# cleared = clear_border(bw) + +# # label image regions +# label_image = label(cleared) +# traps = [ +# region.centroid for region in regionprops(label_image) if region.area > min_area +# ] +# if len(traps) < 10 and downscale != 1: +# print("Trying again.") +# return segment_traps(image, tile_size, downscale=1) +# return traps + + +def identify_trap_locations( + image, trap_template, optimize_scale=True, downscale=0.35, trap_size=None +): + """ + Identify the traps in a single image based on a trap template. + This assumes a trap template that is similar to the image in question + (same camera, same magification; ideally same experiment). + + This method speeds up the search by downscaling both the image and + the trap template before running the template match. + It also optimizes the scale and the rotation of the trap template. + + :param image: + :param trap_template: + :param optimize_scale: + :param downscale: + :param trap_rotation: + :return: + """ + trap_size = trap_size if trap_size is not None else trap_template.shape[0] + # Careful, the image is float16! + img = transform.rescale(image.astype(float), downscale) + temp = transform.rescale(trap_template, downscale) + + # TODO random search hyperparameter optimization + # optimize rotation + matches = { + rotation: feature.match_template( + img, + transform.rotate(temp, rotation, cval=np.median(img)), + pad_input=True, + mode="median", + ) + ** 2 + for rotation in [0, 90, 180, 270] + } + best_rotation = max(matches, key=lambda x: np.percentile(matches[x], 99.9)) + temp = transform.rotate(temp, best_rotation, cval=np.median(img)) + + if optimize_scale: + scales = np.linspace(0.5, 2, 10) + matches = { + scale: feature.match_template( + img, transform.rescale(temp, scale), mode="median", pad_input=True + ) + ** 2 + for scale in scales + } + best_scale = max(matches, key=lambda x: np.percentile(matches[x], 99.9)) + matched = matches[best_scale] + else: + matched = feature.match_template(img, temp, pad_input=True, mode="median") + + coordinates = feature.peak_local_max( + transform.rescale(matched, 1 / downscale), + min_distance=int(trap_template.shape[0] * 0.70), + exclude_border=(trap_size // 3), + ) + return coordinates + + +def get_tile_shapes(x, tile_size, max_shape): + half_size = tile_size // 2 + xmin = int(x[0] - half_size) + ymin = max(0, int(x[1] - half_size)) + # if xmin + tile_size > max_shape[0]: + # xmin = max_shape[0] - tile_size + # if ymin + tile_size > max_shape[1]: + # # ymin = max_shape[1] - tile_size + # return max(xmin, 0), xmin + tile_size, max(ymin, 0), ymin + tile_size + return xmin, xmin + tile_size, ymin, ymin + tile_size + + +def in_image(img, xmin, xmax, ymin, ymax, xidx=2, yidx=3): + if xmin >= 0 and ymin >= 0: + if xmax < img.shape[xidx] and ymax < img.shape[yidx]: + return True + else: + return False + + +def get_xy_tile(img, xmin, xmax, ymin, ymax, xidx=2, yidx=3, pad_val=None): + if pad_val is 
None: + pad_val = np.median(img) + # Get the tile from the image + idx = [slice(None)] * len(img.shape) + idx[xidx] = slice(max(0, xmin), min(xmax, img.shape[xidx])) + idx[yidx] = slice(max(0, ymin), min(ymax, img.shape[yidx])) + tile = img[tuple(idx)] + # Check if the tile is in the image + if in_image(img, xmin, xmax, ymin, ymax, xidx, yidx): + return tile + else: + # Add padding + pad_shape = [(0, 0)] * len(img.shape) + pad_shape[xidx] = (max(-xmin, 0), max(xmax - img.shape[xidx], 0)) + pad_shape[yidx] = (max(-ymin, 0), max(ymax - img.shape[yidx], 0)) + tile = np.pad(tile, pad_shape, constant_values=pad_val) + return tile + + +def get_trap_timelapse( + raw_expt, trap_locations, trap_id, tile_size=117, channels=None, z=None +): + """ + Get a timelapse for a given trap by specifying the trap_id + :param trap_id: An integer defining which trap to choose. Counted + between 0 and Tiler.n_traps - 1 + :param tile_size: The size of the trap tile (centered around the + trap as much as possible, edge cases exist) + :param channels: Which channels to fetch, indexed from 0. + If None, defaults to [0] + :param z: Which z_stacks to fetch, indexed from 0. + If None, defaults to [0]. + :return: A numpy array with the timelapse in (C,T,X,Y,Z) order + """ + # Set the defaults (list is mutable) + channels = channels if channels is not None else [0] + z = z if z is not None else [0] + # Get trap location for that id: + trap_centers = [trap_locations[i][trap_id] for i in range(len(trap_locations))] + + max_shape = (raw_expt.shape[2], raw_expt.shape[3]) + tiles_shapes = [ + get_tile_shapes((x[0], x[1]), tile_size, max_shape) for x in trap_centers + ] + + timelapse = [ + get_xy_tile( + raw_expt[channels, i, :, :, z], xmin, xmax, ymin, ymax, pad_val=None + ) + for i, (xmin, xmax, ymin, ymax) in enumerate(tiles_shapes) + ] + return np.hstack(timelapse) + + +def get_trap_timelapse_omero( + raw_expt, trap_locations, trap_id, tile_size=117, channels=None, z=None, t=None +): + """ + Get a timelapse for a given trap by specifying the trap_id + :param raw_expt: A Timelapse object from which data is obtained + :param trap_id: An integer defining which trap to choose. Counted + between 0 and Tiler.n_traps - 1 + :param tile_size: The size of the trap tile (centered around the + trap as much as possible, edge cases exist) + :param channels: Which channels to fetch, indexed from 0. + If None, defaults to [0] + :param z: Which z_stacks to fetch, indexed from 0. + If None, defaults to [0]. 
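+    :param t: Which time points to fetch, indexed from 0.
+        If None, defaults to all time points in the experiment.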
+ :return: A numpy array with the timelapse in (C,T,X,Y,Z) order + """ + # Set the defaults (list is mutable) + channels = channels if channels is not None else [0] + z_positions = z if z is not None else [0] + times = ( + t if t is not None else np.arange(raw_expt.shape[1]) + ) # TODO choose sub-set of time points + shape = (len(channels), len(times), tile_size, tile_size, len(z_positions)) + # Get trap location for that id: + zct_tiles, slices, trap_ids = all_tiles( + trap_locations, shape, raw_expt, z_positions, channels, times, [trap_id] + ) + + # TODO Make this an explicit function in TimelapseOMERO + images = raw_expt.pixels.getTiles(zct_tiles) + timelapse = np.full(shape, np.nan) + total = len(zct_tiles) + for (z, c, t, _), (y, x), image in tqdm( + zip(zct_tiles, slices, images), total=total + ): + ch = channels.index(c) + tp = times.tolist().index(t) + z_pos = z_positions.index(z) + timelapse[ch, tp, x[0] : x[1], y[0] : y[1], z_pos] = image + + # for x in timelapse: # By channel + # np.nan_to_num(x, nan=np.nanmedian(x), copy=False) + return timelapse + + +def all_tiles(trap_locations, shape, raw_expt, z_positions, channels, times, traps): + _, _, x, y, _ = shape + _, _, MAX_X, MAX_Y, _ = raw_expt.shape + + trap_ids = [] + zct_tiles = [] + slices = [] + for z in z_positions: + for ch in channels: + for t in times: + for trap_id in traps: + centre = trap_locations[t][trap_id] + xmin, ymin, xmax, ymax, r_xmin, r_ymin, r_xmax, r_ymax = tile_where( + centre, x, y, MAX_X, MAX_Y + ) + slices.append( + ((r_ymin - ymin, r_ymax - ymin), (r_xmin - xmin, r_xmax - xmin)) + ) + tile = (r_ymin, r_xmin, r_ymax - r_ymin, r_xmax - r_xmin) + zct_tiles.append((z, ch, t, tile)) + trap_ids.append(trap_id) # So we remember the order! + return zct_tiles, slices, trap_ids + + +def tile_where(centre, x, y, MAX_X, MAX_Y): + # Find the position of the tile + xmin = int(centre[1] - x // 2) + ymin = int(centre[0] - y // 2) + xmax = xmin + x + ymax = ymin + y + # What do we actually have available? + r_xmin = max(0, xmin) + r_xmax = min(MAX_X, xmax) + r_ymin = max(0, ymin) + r_ymax = min(MAX_Y, ymax) + return xmin, ymin, xmax, ymax, r_xmin, r_ymin, r_xmax, r_ymax + + +def get_tile(shape, center, raw_expt, ch, t, z): + """Returns a tile from the raw experiment with a given shape. + + :param shape: The shape of the tile in (C, T, Z, Y, X) order. + :param center: The x,y position of the centre of the tile + :param + """ + _, _, x, y, _ = shape + _, _, MAX_X, MAX_Y, _ = raw_expt.shape + tile = np.full(shape, np.nan) + + # Find the position of the tile + xmin = int(center[1] - x // 2) + ymin = int(center[0] - y // 2) + xmax = xmin + x + ymax = ymin + y + # What do we actually have available? 
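+    # Clamp the requested tile to the image bounds; any part of the request
+    # that falls outside the image is left as NaN in `tile`.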
+ r_xmin = max(0, xmin) + r_xmax = min(MAX_X, xmax) + r_ymin = max(0, ymin) + r_ymax = min(MAX_Y, ymax) + + # Fill values + tile[ + :, :, (r_xmin - xmin) : (r_xmax - xmin), (r_ymin - ymin) : (r_ymax - ymin), : + ] = raw_expt[ch, t, r_xmin:r_xmax, r_ymin:r_ymax, z] + # fill_val = np.nanmedian(tile) + # np.nan_to_num(tile, nan=fill_val, copy=False) + return tile + + +def get_traps_timepoint( + raw_expt, trap_locations, tp, tile_size=96, channels=None, z=None +): + """ + Get all the traps from a given time point + :param raw_expt: + :param trap_locations: + :param tp: + :param tile_size: + :param channels: + :param z: + :return: A numpy array with the traps in the (trap, C, T, X, Y, + Z) order + """ + + # Set the defaults (list is mutable) + channels = channels if channels is not None else [0] + z_positions = z if z is not None else [0] + if isinstance(z_positions, slice): + n_z = z_positions.stop + z_positions = list(range(n_z)) # slice is not iterable error + elif isinstance(z_positions, list): + n_z = len(z_positions) + else: + n_z = 1 + + n_traps = len(trap_locations[tp]) + trap_ids = list(range(n_traps)) + shape = (len(channels), 1, tile_size, tile_size, n_z) + # all tiles + zct_tiles, slices, trap_ids = all_tiles( + trap_locations, shape, raw_expt, z_positions, channels, [tp], trap_ids + ) + # TODO Make this an explicit function in TimelapseOMERO + images = raw_expt.pixels.getTiles(zct_tiles) + # Initialise empty traps + traps = np.full((n_traps,) + shape, np.nan) + for trap_id, (z, c, _, _), (y, x), image in zip( + trap_ids, zct_tiles, slices, images + ): + ch = channels.index(c) + z_pos = z_positions.index(z) + traps[trap_id, ch, 0, x[0] : x[1], y[0] : y[1], z_pos] = image + for x in traps: # By channel + np.nan_to_num(x, nan=np.nanmedian(x), copy=False) + return traps + + +def centre(img, percentage=0.3): + y, x = img.shape + cropx = int(np.ceil(x * percentage)) + cropy = int(np.ceil(y * percentage)) + startx = int(x // 2 - (cropx // 2)) + starty = int(y // 2 - (cropy // 2)) + return img[starty : starty + cropy, startx : startx + cropx] + + +def align_timelapse_images( + raw_data, channel=0, reference_reset_time=80, reference_reset_drift=25 +): + """ + Uses image registration to align images in the timelapse. + Uses the channel with id `channel` to perform the registration. + + Starts with the first timepoint as a reference and changes the + reference to the current timepoint if either the images have moved + by half of a trap width or `reference_reset_time` has been reached. + + Sets `self.drift`, a 3D numpy array with shape (t, drift_x, drift_y). + We assume no drift occurs in the z-direction. + + :param reference_reset_drift: Upper bound on the allowed drift before + resetting the reference image. + :param reference_reset_time: Upper bound on number of time points to + register before resetting the reference image. + :param channel: index of the channel to use for image registration. + """ + ref = centre(np.squeeze(raw_data[channel, 0, :, :, 0])) + size_t = raw_data.shape[1] + + drift = [np.array([0, 0])] + for i in range(1, size_t): + img = centre(np.squeeze(raw_data[channel, i, :, :, 0])) + + shifts, _, _ = feature.register_translation(ref, img) + # If a huge move is detected at a single time point it is taken + # to be inaccurate and the correction from the previous time point + # is used. + # This might be common if there is a focus loss for example. 
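+    # Compare the new shift with the previous drift axis by axis; a jump of
+    # more than reference_reset_drift pixels is treated as spurious and the
+    # previous correction is reused.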
+ if any([abs(x - y) > reference_reset_drift for x, y in zip(shifts, drift[-1])]): + shifts = drift[-1] + + drift.append(shifts) + ref = img + + # TODO test necessity for references, description below + # If the images have drifted too far from the reference or too + # much time has passed we change the reference and keep track of + # which images are kept as references + return np.stack(drift) diff --git a/aliby/utils.py b/aliby/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..613bdb720e2e1338261a5965f1c329f7250189e7 --- /dev/null +++ b/aliby/utils.py @@ -0,0 +1,135 @@ +""" +Utility functions and classes +""" +import itertools +import logging +import operator +from pathlib import Path +from typing import Callable + +import h5py +import imageio +import cv2 +import numpy as np + +def repr_obj(obj, indent=0): + """ + Helper function to display info about OMERO objects. + Not all objects will have a "name" or owner field. + """ + string = """%s%s:%s Name:"%s" (owner=%s)""" % ( + " " * indent, + obj.OMERO_CLASS, + obj.getId(), + obj.getName(), + obj.getAnnotation()) + + return string + +def imread(path): + return cv2.imread(str(path), -1) + + +class ImageCache: + """HDF5-based image cache for faster loading of the images once they've + been read. + """ + def __init__(self, file, name, shape, remote_fn): + self.store = h5py.File(file, 'a') + # Create a dataset + self.dataset = self.store.create_dataset(name, shape, + dtype=np.float, + fill_value=np.nan) + self.remote_fn = remote_fn + + def __getitem__(self, item): + cached = self.dataset[item] + if np.any(np.isnan(cached)): + full = self.remote_fn(item) + self.dataset[item] = full + return full + else: + return cached + + +class Cache: + """ + Fixed-length mapping to use as a cache. + Deletes items in FIFO manner when maximum allowed length is reached. + """ + def __init__(self, max_len=5000, load_fn: Callable = imread): + """ + :param max_len: Maximum number of items in the cache. + :param load_fn: The function used to load new items if they are not + available in the Cache + """ + self._dict = dict() + self._queue = [] + self.load_fn = load_fn + self.max_len=max_len + + def __getitem__(self, item): + if item not in self._dict: + self.load_item(item) + return self._dict[item] + + def load_item(self, item): + self._dict[item] = self.load_fn(item) + # Clean up the queue + self._queue.append(item) + if len(self._queue) > self.max_len: + del self._dict[self._queue.pop(0)] + + def clear(self): + self._dict.clear() + self._queue.clear() + + +def accumulate(l: list): + l = sorted(l) + it = itertools.groupby(l, operator.itemgetter(0)) + for key, sub_iter in it: + yield key, [x[1] for x in sub_iter] + + +def get_store_path(save_dir, store, name): + """Create a path to a position-specific store. + + This combines the name and the store's base name into a file path within save_dir. + For example. + >>> get_store_path('data', 'baby_seg.h5', 'pos001') + Path(data/pos001baby_seg.h5') + + :param save_dir: The root directory in which to save the file, absolute + path. 
+ :param store: The base name of the store + :param name: The name of the position + :return: Path(save_dir) / name+store + """ + store = Path(save_dir) / store + store = store.with_name(name + store.name) + return store + +def parametrized(dec): + def layer(*args, **kwargs): + def repl(f): + return dec(f, *args, **kwargs) + return repl + return layer + +from functools import wraps, partial +from time import perf_counter +import logging +@parametrized +def timed(f, name=None): + @wraps(f) + def decorated(*args, **kwargs): + t = perf_counter() + res = f(*args, **kwargs) + to_print = name or f.__name__ + logging.debug(f'Timing:{to_print}:{perf_counter() - t}s') + return res + return decorated + + + diff --git a/scripts/dev_lineage.py b/scripts/dev_lineage.py index b2d443fbb587e37f6e552326240014f928290e8a..e817c4b5dfad31e4e0520766aa01aeb34e06c7bb 100644 --- a/scripts/dev_lineage.py +++ b/scripts/dev_lineage.py @@ -13,14 +13,15 @@ from utils_find_1st import find_1st, cmp_equal # ) dpath = Path( - "/home/alan/Documents/sync_docs/data/data/2021_03_20_sugarShift_pal_glu_pal_Myo1Whi5_Sfp1Nhp6a__00/2021_03_20_sugarShift_pal_glu_pal_Myo1Whi5_Sfp1Nhp6a__00" + # "/home/alan/Documents/dev/stoa_libs/pipeline-core/data/2021_11_04_doseResponse_raf_1_15_2_glu_01_2_dual_phluorin_whi5_constantMedia_00/" + "/home/alan/Documents/dev/stoa_libs/pipeline-core/data/2020_10_22_2tozero_Hxts_02/2020_10_22_2tozero_Hxts_02" ) def compare_ma_methods(fpath): with h5py.File(fpath, "r") as f: - ma = f["cell_info/mother_assign"][()] + maf = f["cell_info/mother_assign"][()] mad = f["cell_info/mother_assign_dynamic"][()] trap = f["cell_info/trap"][()] timepoint = f["cell_info/timepoint"][()] @@ -31,12 +32,19 @@ def compare_ma_methods(fpath): def mother_assign_from_dynamic(ma, cell_label, trap, ntraps: int): """ Interpolate the list of lists containing the associated mothers from the mother_assign_dynamic feature + + Parameters: + + ma: + cell_label: + trap: + ntraps: """ - idlist = list(zip(trap, label)) + idlist = list(zip(trap, cell_label)) cell_gid = np.unique(idlist, axis=0) last_lin_preds = [ - find_1st(((label[::-1] == lbl) & (trap[::-1] == tr)), True, cmp_equal) + find_1st(((cell_label[::-1] == lbl) & (trap[::-1] == tr)), True, cmp_equal) for tr, lbl in cell_gid ] mother_assign_sorted = ma[::-1][last_lin_preds] @@ -51,9 +59,18 @@ def compare_ma_methods(fpath): mad_fixed = mother_assign_from_dynamic(mad, cell_label, trap, len(np.unique(trap))) dyn = sum([np.array(i, dtype=bool).sum() for i in mad_fixed]) - nondyn = sum([x.astype(bool).sum() for x in ma]) - return dyn, nondyn + dyn_len = sum([len(i) for i in mad_fixed]) + nondyn = sum([x.astype(bool).sum() for x in maf]) + nondyn_len = sum([len(x) for x in maf]) + return (dyn, dyn_len), (nondyn, nondyn_len) + # return mad_fixed, maf + + +nids = [print(compare_ma_methods(fpath)) for fpath in dpath.glob("*.h5")] +tmp = nids[0] + +fpath = list(dpath.glob("*.h5"))[0] +from aliby.io.signal import Signal -for fpath in dpath.glob("*.h5"): - print(compare_ma_methods(fpath)) +s = Signal(fpath) diff --git a/scripts/distributed_alan.py b/scripts/distributed_alan.py index 28bb68bd30c50a03d251ed3e18f5ac8ea198b255..760f3eff11c8dc5a867c1a0bfca4f3f471e93125 100644 --- a/scripts/distributed_alan.py +++ b/scripts/distributed_alan.py @@ -16,15 +16,15 @@ import seaborn as sns import operator -from pcore.experiment import MetaData -from pcore.io.omero import Dataset, Image -from pcore.haystack import initialise_tf -from pcore.baby_client import BabyRunner -from pcore.segment import Tiler 
-from pcore.io.writer import TilerWriter, BabyWriter -from pcore.utils import timed - -from pcore.io.signal import Signal +from aliby.experiment import MetaData +from aliby.io.omero import Dataset, Image +from aliby.haystack import initialise_tf +from aliby.baby_client import BabyRunner +from aliby.segment import Tiler +from aliby.io.writer import TilerWriter, BabyWriter +from aliby.utils import timed + +from aliby.io.signal import Signal from extraction.core.functions.defaults import exparams_from_meta from extraction.core.extractor import Extractor from extraction.core.parameters import Parameters diff --git a/scripts/run_default.py b/scripts/run_default.py index acc31ef5a132d06df6aed1485b9a8f676f7d1125..a9d6bd41688ee46df55ca363d5c6ca260a1a9e20 100644 --- a/scripts/run_default.py +++ b/scripts/run_default.py @@ -1,5 +1,5 @@ #!/usr/bin/env python3 -from pcore.pipeline import PipelineParameters, Pipeline +from aliby.pipeline import PipelineParameters, Pipeline p = Pipeline(PipelineParameters.default()) p.run() diff --git a/setup.py b/setup.py index e6e978a7687368346021360b9b580a198caf8a35..64bb81ee39c098ff885650b45d816e86bb4e5ab6 100644 --- a/setup.py +++ b/setup.py @@ -1,12 +1,12 @@ from setuptools import setup, find_packages -print("find_packages outputs ", find_packages("pcore")) +print("find_packages outputs ", find_packages("aliby")) setup( name="pipeline-core", version="0.1.1-dev", packages=find_packages(), - # package_dir={"": "pcore"}, - # packages=['pcore', 'pcore.io'], + # package_dir={"": "aliby"}, + # packages=['aliby', 'aliby.io'], # include_package_data=True, url="", license="",