diff --git a/extraction/__init__.py b/extraction/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/extraction/core/__init__.py b/extraction/core/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..cf529d79a32f43a95d91ad5704aa71a9c3237974 --- /dev/null +++ b/extraction/core/__init__.py @@ -0,0 +1,2 @@ +#!/usr/bin/env python + diff --git a/extraction/core/extractor.py b/extraction/core/extractor.py new file mode 100644 index 0000000000000000000000000000000000000000..b3f7abdcf862db6cc0fc1bb573668f26353b0e84 --- /dev/null +++ b/extraction/core/extractor.py @@ -0,0 +1,534 @@ +import os +from pathlib import Path, PosixPath +import pkg_resources +from collections.abc import Iterable + +# from copy import copy +from typing import Union, List, Dict, Callable +from datetime import datetime + +import numpy as np +import pandas as pd +from scipy.sparse import dok_matrix, vstack, issparse +from tqdm import tqdm + + +from extraction.core.functions.loaders import ( + load_funs, + load_custom_args, + load_redfuns, + load_mergefuns, +) +from extraction.core.functions.defaults import exparams_from_meta +from extraction.core.functions.distributors import trap_apply, reduce_z +from extraction.core.functions.utils import depth + +from agora.abc import ProcessABC, ParametersABC +from agora.io.writer import Writer, load_attributes +from agora.io.cells import Cells +from agora.tile.tiler import Tiler + +CELL_FUNS, TRAPFUNS, FUNS = load_funs() +CUSTOM_FUNS, CUSTOM_ARGS = load_custom_args() +RED_FUNS = load_redfuns() +MERGE_FUNS = load_mergefuns() + +# Assign datatype depending on the metric used +m2type = {"mean": np.float32, "median": np.ubyte, "imBackground": np.ubyte} + + +class ExtractorParameters(ParametersABC): + """ + Base class to define parameters for extraction + :tree: dict of depth n. If not of depth three tree will be filled with Nones + str channel -> U(function,None) reduction -> str metric + + """ + + def __init__( + self, + tree: Dict[Union[str, None], Dict[Union[Callable, None], List[str]]] = None, + sub_bg: set = set(), + multichannel_ops: Dict = {}, + ): + + self.tree = fill_tree(tree) + + self.sub_bg = sub_bg + self.multichannel_ops = multichannel_ops + + @staticmethod + def guess_from_meta(store_name: str, suffix="fast"): + """ + Make a guess on the parameters using the hdf5 metadata + + Add anything as a suffix, default "fast" + + + Parameters: + store_name : str or Path indicating the results' storage. + suffix : str to add at the end of the predicted parameter set + """ + + with h5py.open(store_name, "r") as f: + microscope = f["/"].attrs.get("microscope") # TODO Check this with Arin + assert microscope, "No metadata found" + + return "_".join((microscope, suffix)) + + @classmethod + def default(cls): + return cls({}) + + @classmethod + def from_meta(cls, meta): + return cls(**exparams_from_meta(meta)) + + +class Extractor(ProcessABC): + """ + Base class to perform feature extraction. + + Parameters + ---------- + parameters: core.extractor Parameters + Parameters that include with channels, reduction and + b extraction functions to use. + store: str + Path to hdf5 storage file. Must contain cell outlines. + tiler: pipeline-core.core.segmentation tiler + Class that contains or fetches the image to be used for segmentation. 
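+
+    Example
+    -------
+    A minimal usage sketch; here ``store_path`` and ``tiler`` stand for an
+    existing hdf5 results file and a matching Tiler instance:
+
+    >>> params = ExtractorParameters(tree={"general": {"None": ["area"]}})
+    >>> ext = Extractor.from_tiler(params, store_path, tiler)
+    >>> d = ext.run(tps=[0, 1])  # dict of DataFrames keyed like "general/None/area"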
+ """ + + default_meta = {"pixel_size": 0.236, "z_size": 0.6, "spacing": 0.6} + + def __init__(self, parameters: ExtractorParameters, store: str, tiler: Tiler): + self.params = parameters + self.local = store + self.load_meta() + self.tiler = tiler + self.load_funs() + + @classmethod + def from_tiler(cls, parameters: ExtractorParameters, store: str, tiler: Tiler): + return cls(parameters, store=store, tiler=tiler) + + @classmethod + def from_img(cls, parameters: ExtractorParameters, store: str, img_meta: tuple): + return cls(parameters, store=store, tiler=Tiler(*img_meta)) + + # @classmethod + # def from_store(cls, parameters: ExtractorParameters, store: str, img_meta: tuple): + # return cls(parameters, store=store, tiler=Tiler(*img_meta)) + + @property + def channels(self): + if not hasattr(self, "_channels"): + if type(self.params.tree) is dict: + self._channels = tuple(self.params.tree.keys()) + + return self._channels + + @property + def current_position(self): + return self.local.split("/")[-1][:-3] + + @property + def group(self): # Path within hdf5 + if not hasattr(self, "_out_path"): + self._group = "/extraction/" + return self._group + + @property + def pos_file(self, store_name="store.h5"): + if not hasattr(self, "_pos_file"): + return self.local + + def load_custom_funs(self): + """ + Load parameters of functions that require them from expt. + These must be loaded within the Extractor instance because their parameters + depend on their experiment's metadata. + """ + funs = set( + [ + fun + for ch in self.params.tree.values() + for red in ch.values() + for fun in red + ] + ) + funs = funs.intersection(CUSTOM_FUNS.keys()) + ARG_VALS = { + k: {k2: self.get_meta(k2) for k2 in v} for k, v in CUSTOM_ARGS.items() + } + # self._custom_funs = {trap_apply(CUSTOM_FUNS[fun],]) + self._custom_funs = {} + for k, f in CUSTOM_FUNS.items(): + + def tmp(f): + return lambda m, img: trap_apply(f, m, img, **ARG_VALS.get(k, {})) + + self._custom_funs[k] = tmp(f) + + def load_funs(self): + self.load_custom_funs() + self._all_cell_funs = set(self._custom_funs.keys()).union(CELL_FUNS) + self._all_funs = {**self._custom_funs, **FUNS} + + def load_meta(self): + self.meta = load_attributes(self.local) + + def get_traps( + self, tp: int, channels: list = None, z: list = None, **kwargs + ) -> tuple: + if channels is None: + channel_ids = list(range(len(self.tiler.channels))) + elif len(channels): + channel_ids = [self.tiler.get_channel_index(ch) for ch in channels] + else: + channel_ids = None + + if z is None: + z = list(range(self.tiler.shape[-1])) + + traps = ( + self.tiler.get_traps_timepoint(tp, channels=channel_ids, z=z, **kwargs) + if channel_ids + else None + ) + + return traps + + def extract_traps( + self, + traps: List[np.array], + masks: List[np.array], + metric: str, + labels: List[int] = None, + ) -> dict: + """ + Apply a function for a whole position. + + :traps: List[np.array] list of images + :masks: List[np.array] list of masks + :metric:str metric to extract + :labels: List[int] cell Labels to use as indices for output DataFrame + :pos_info: bool Whether to add the position as index or not + + returns + :d: Dictionary of dataframe + """ + + if labels is None: + raise Warning("No labels given. 
Sorting cells using index.") + + cell_fun = True if metric in self._all_cell_funs else False + + idx = [] + results = [] + + for trap_id, (mask_set, trap, lbl_set) in enumerate( + zip(masks, traps, labels.values()) + ): + if len(mask_set): # ignore empty traps + result = self._all_funs[metric](mask_set, trap) + if cell_fun: + for lbl, val in zip(lbl_set, result): + results.append(val) + idx.append((trap_id, lbl)) + else: + results.append(result) + idx.append(trap_id) + + return (tuple(results), tuple(idx)) + + def extract_funs( + self, traps: List[np.array], masks: List[np.array], metrics: List[str], **kwargs + ) -> dict: + """ + Extract multiple metrics from a timepoint + """ + d = { + metric: self.extract_traps( + traps=traps, masks=masks, metric=metric, **kwargs + ) + for metric in metrics + } + + return d + + def reduce_extract( + self, traps: Union[np.array, None], masks: list, red_metrics: dict, **kwargs + ) -> dict: + """ + :param red_metrics: dict in which keys are reduction funcions and + values are strings indicating the metric function + :**kwargs: All other arguments, must include masks and traps. + """ + + reduced_traps = {} + if traps is not None: + for red_fun in red_metrics.keys(): + reduced_traps[red_fun] = [ + self.reduce_dims(trap, method=RED_FUNS[red_fun]) for trap in traps + ] + + d = { + red_fun: self.extract_funs( + metrics=metrics, + traps=reduced_traps.get(red_fun, [None for _ in masks]), + masks=masks, + **kwargs, + ) + for red_fun, metrics in red_metrics.items() + } + return d + + def reduce_dims(self, img: np.array, method=None) -> np.array: + # assert len(img.shape) == 3, "Incorrect number of dimensions" + + if method is None: + return img + + return reduce_z(img, method) + + def extract_tp( + self, tp: int, tree: dict = None, tile_size: int = 117, **kwargs + ) -> dict: + """ + :param tp: int timepoint from which to extract results + :param tree: dict of dict {channel : {reduction_function : metrics}} + :**kwargs: Must include masks and preferably labels. 
+ """ + + if tree is None: + tree = self.params.tree + + ch_tree = {ch: v for ch, v in tree.items() if ch != "general"} + tree_chs = (*ch_tree,) + + cells = Cells.hdf(self.local) + + # labels + raw_labels = cells.labels_at_time(tp) + labels = { + trap_id: raw_labels.get(trap_id, []) for trap_id in range(cells.ntraps) + } + + # masks + raw_masks = cells.at_time(tp, kind="mask") + + masks = {trap_id: [] for trap_id in range(cells.ntraps)} + for trap_id, cells in raw_masks.items(): + if len(cells): + masks[trap_id] = np.dstack(np.array(cells)).astype(bool) + + masks = [np.array(v) for v in masks.values()] + + # traps + traps = self.get_traps(tp, tile_size=tile_size, channels=tree_chs) + + self.img_bgsub = {} + if self.params.sub_bg: + bg = [ + ~np.sum(m, axis=2).astype(bool) + if np.any(m) + else np.zeros((tile_size, tile_size)) + for m in masks + ] + + d = {} + + for ch, red_metrics in tree.items(): + img = None + # ch != is necessary for threading + if ch != "general" and traps is not None and len(traps): + img = traps[:, tree_chs.index(ch), 0] + + d[ch] = self.reduce_extract( + red_metrics=red_metrics, traps=img, masks=masks, labels=labels, **kwargs + ) + + if ( + ch in self.params.sub_bg and img is not None + ): # Calculate metrics with subtracted bg + ch_bs = ch + "_bgsub" + + self.img_bgsub[ch_bs] = [] + for trap, maskset in zip(img, bg): + + cells_fl = np.zeros_like(trap) + + is_cell = np.where(maskset) + if len(is_cell[0]): # skip calculation for empty traps + cells_fl = np.median(trap[is_cell], axis=0) + + self.img_bgsub[ch_bs].append(trap - cells_fl) + + d[ch_bs] = self.reduce_extract( + red_metrics=ch_tree[ch], + traps=self.img_bgsub[ch_bs], + masks=masks, + labels=labels, + **kwargs, + ) + + # Additional operations between multiple channels (e.g. 
pH calculations) + for name, (chs, merge_fun, red_metrics) in self.params.multichannel_ops.items(): + if len( + set(chs).intersection(set(self.img_bgsub.keys()).union(tree_chs)) + ) == len(chs): + imgs = [self.get_imgs(ch, traps, tree_chs) for ch in chs] + merged = MERGE_FUNS[merge_fun](*imgs) + d[name] = self.reduce_extract( + red_metrics=red_metrics, + traps=merged, + masks=masks, + labels=labels, + **kwargs, + ) + + del traps, masks + return d + + def get_imgs(self, channel, traps, channels=None): + """ + Returns the image from a correct source, either raw or bgsub + + + :channel: str name of channel to get + :img: ndarray (trap_id, channel, tp, tile_size, tile_size, n_zstacks) of standard channels + :channels: List of channels + """ + + if channels is None: + channels = (*self.params.tree,) + + if channel in channels: + return traps[:, channels.index(channel), 0] + elif channel in self.img_bgsub: + return self.img_bgsub[channel] + + def run(self, tree=None, tps: List[int] = None, save=True, **kwargs) -> dict: + + if tree is None: + tree = self.params.tree + + if tps is None: + tps = list(range(self.meta["time_settings/ntimepoints"])) + + d = {} + for tp in tps: + new = flatten_nest( + self.extract_tp(tp=tp, tree=tree, **kwargs), + to="series", + tp=tp, + ) + + for k in new.keys(): + n = new[k] + d[k] = pd.concat((d.get(k, None), n), axis=1) + + for k in d.keys(): + indices = ["experiment", "position", "trap", "cell_label"] + idx = ( + indices[-d[k].index.nlevels :] + if d[k].index.nlevels > 1 + else [indices[-2]] + ) + d[k].index.names = idx + + toreturn = d + + if save: + self.save_to_hdf(toreturn) + + return toreturn + + def extract_pos( + self, tree=None, tps: List[int] = None, save=True, **kwargs + ) -> dict: + + if tree is None: + tree = self.params.tree + + if tps is None: + tps = list(range(self.meta["time_settings/ntimepoints"])) + + d = {} + for tp in tps: + new = flatten_nest( + self.extract_tp(tp=tp, tree=tree, **kwargs), + to="series", + tp=tp, + ) + + for k in new.keys(): + n = new[k] + d[k] = pd.concat((d.get(k, None), n), axis=1) + + for k in d.keys(): + indices = ["experiment", "position", "trap", "cell_label"] + idx = ( + indices[-d[k].index.nlevels :] + if d[k].index.nlevels > 1 + else [indices[-2]] + ) + d[k].index.names = idx + + toreturn = d + + if save: + self.save_to_hdf(toreturn) + + return toreturn + + def save_to_hdf(self, group_df, path=None): + if path is None: + path = self.local + + self.writer = Writer(path) + for path, df in group_df.items(): + dset_path = "/extraction/" + path + self.writer.write(dset_path, df) + self.writer.id_cache.clear() + + def get_meta(self, flds): + if not hasattr(flds, "__iter__"): + flds = [flds] + meta_short = {k.split("/")[-1]: v for k, v in self.meta.items()} + return {f: meta_short.get(f, self.default_meta.get(f, None)) for f in flds} + + +### Helpers +def flatten_nest(nest: dict, to="series", tp: int = None) -> dict: + """ + Convert a nested extraction dict into a dict of series + :param nest: dict contained the nested results of extraction + :param to: str = 'series' Determine output format, either list or pd.Series + :param tp: int timepoint used to name the series + """ + + d = {} + for k0, v0 in nest.items(): + for k1, v1 in v0.items(): + for k2, v2 in v1.items(): + d["/".join((k0, k1, k2))] = ( + pd.Series(*v2, name=tp) if to == "series" else v2 + ) + + return d + + +def fill_tree(tree): + if tree is None: + return None + tree_depth = depth(tree) + if depth(tree) < 3: + d = {None: {None: {None: []}}} + for _ in 
range(2 - tree_depth): + d = d[None] + d[None] = tree + tree = d + return tree diff --git a/extraction/core/functions/__init__.py b/extraction/core/functions/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e5a0d9b4834ec8f46d6e0d1256c6dcaad2e460fe --- /dev/null +++ b/extraction/core/functions/__init__.py @@ -0,0 +1 @@ +#!/usr/bin/env python3 diff --git a/extraction/core/functions/cell.py b/extraction/core/functions/cell.py new file mode 100644 index 0000000000000000000000000000000000000000..cfbe5626b8c1f2a036110757d2242fe253d33939 --- /dev/null +++ b/extraction/core/functions/cell.py @@ -0,0 +1,124 @@ +""" +Base functions to extract information from a single cell + +These functions are automatically read, so only add new functions with +the same arguments as the existing ones. + +Input: +:cell_mask: (x,y) 2-D cell mask +:trap_image: (x,y) 2-D or (x,y,z) 3-D cell mask + + +np.where is used to cover for cases where z>1 +""" +import math + +import numpy as np +from scipy import ndimage +from sklearn.cluster import KMeans + + +def area(cell_mask, trap_image=None): + return np.sum(cell_mask, dtype=int) + + +def mean(cell_mask, trap_image): + return np.mean(trap_image[np.where(cell_mask)], dtype=float) + + +def median(cell_mask, trap_image): + return np.median(trap_image[np.where(cell_mask)]) + + +def max2p5pc(cell_mask, trap_image): + npixels = cell_mask.sum() + top_pixels = int(np.ceil(npixels * 0.025)) + + sorted_vals = np.sort(trap_image[np.where(cell_mask)], axis=None) + top_vals = sorted_vals[-top_pixels:] + max2p5pc = np.mean(top_vals, dtype=float) + + return max2p5pc + + +def max5px(cell_mask, trap_image): + sorted_vals = np.sort(trap_image[np.where(cell_mask)], axis=None) + top_vals = sorted_vals[-5:] + max5px = np.mean(top_vals, dtype=float) + + return max5px + + +def std(cell_mask, trap_image): + return np.std(trap_image[np.where(cell_mask)], dtype=float) + + +def k2_top_median(cell_mask, trap_image): + # Use kmeans to cluster the contents of a cell in two, return the high median + # Useful when a big non-tagged organelle (e.g. vacuole) occupies a big fraction + # of the cell + if not np.any(cell_mask): + return np.nan + + X = trap_image[np.where(cell_mask)].reshape(-1, 1) + kmeans = KMeans(n_clusters=2, random_state=0).fit(X) + high_clust_id = kmeans.cluster_centers_.argmax() + major_cluster = X[kmeans.predict(X) == high_clust_id] + + k2_top_median = np.median(major_cluster, axis=None) + return k2_top_median + + +def membraneMax5(cell_mask, trap_image): + pass + + +def membraneMedian(cell_mask, trap_image): + pass + + +def volume(cell_mask, trap_image=None): + """Volume from a cell mask, assuming an ellipse. + + Assumes the mask is the median plane of the ellipsoid. + Assumes rotational symmetry around the major axis. + """ + min_ax, maj_ax = min_maj_approximation(cell_mask, trap_image) + return (4 * math.pi * min_ax**2 * maj_ax) / 3 + + +def conical_volume(cell_mask, trap_image=None): + padded = np.pad(cell_mask, 1, mode='constant', constant_values=0) + nearest_neighbor = ndimage.morphology.distance_transform_edt( + padded == 1) * padded + return 4 * (nearest_neighbor.sum()) + + +def spherical_volume(cell_mask, trap_image=None): + area = cell_mask.sum() + r = np.sqrt(area / np.pi) + return (4 * np.pi * r**3) / 3 + + +def min_maj_approximation(cell_mask, trap_image=None): + """Length approximation of minor and major axes of an ellipse from mask. 
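+
+    Both lengths are read off Euclidean distance transforms of the padded mask:
+    the largest inscribed distance gives the minor semi-axis, and the farthest
+    in-mask distance from that point approximates the major semi-axis.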
+ + + :param cell_mask: + :param trap_image: + :return: + """ + padded = np.pad(cell_mask,1, mode='constant', constant_values=0) + nn = ndimage.morphology.distance_transform_edt(padded == 1) * padded + dn = ndimage.morphology.distance_transform_edt(nn - nn.max()) * padded + cone_top = ndimage.morphology.distance_transform_edt(dn == 0) * padded + min_ax = np.round(nn.max()) + maj_ax = np.round(dn.max() + cone_top.sum()/2) + return min_ax, maj_ax + + +def eccentricity(cell_mask, trap_image=None): + min_ax, maj_ax = min_maj_approximation(cell_mask) + return np.sqrt(maj_ax**2 - min_ax**2)/maj_ax + + diff --git a/extraction/core/functions/custom/__init__.py b/extraction/core/functions/custom/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e5a0d9b4834ec8f46d6e0d1256c6dcaad2e460fe --- /dev/null +++ b/extraction/core/functions/custom/__init__.py @@ -0,0 +1 @@ +#!/usr/bin/env python3 diff --git a/extraction/core/functions/custom/localisation.py b/extraction/core/functions/custom/localisation.py new file mode 100644 index 0000000000000000000000000000000000000000..09f9094a597316cb6558076efe3c3530af2382fc --- /dev/null +++ b/extraction/core/functions/custom/localisation.py @@ -0,0 +1,117 @@ +""" How to do the nuc Est Conv from MATLAB +Based on the code in MattSegCode/Matt Seg +GUI/@timelapseTraps/extractCellDataStacksParfor.m + +Especially lines 342 to 399. +This part only replicates the method to get the nuc_est_conv values +""" +import numpy as np +import scipy +import skimage + + +def matlab_style_gauss2D(shape=(3, 3), sigma=0.5): + """ + 2D gaussian mask - should give the same result as MATLAB's + fspecial('gaussian',[shape],[sigma]) + """ + m, n = [(ss - 1.0) / 2.0 for ss in shape] + y, x = np.ogrid[-m : m + 1, -n : n + 1] + h = np.exp(-(x * x + y * y) / (2.0 * sigma * sigma)) + h[h < np.finfo(h.dtype).eps * h.max()] = 0 + sumh = h.sum() + if sumh != 0: + h /= sumh + return h + + +def gauss3D(shape=(3, 3, 3), sigma=(0.5, 0.5, 0.5)): + """3D gaussian mask - based on MATLAB's fspecial but made 3D.""" + m, n, p = [(ss - 1.0) / 2.0 for ss in shape] + z, y, x = np.ogrid[-p : p + 1, -m : m + 1, -n : n + 1] + sigmax, sigmay, sigmaz = sigma + xx = (x ** 2) / (2 * sigmax) + yy = (y ** 2) / (2 * sigmay) + zz = (z ** 2) / (2 * sigmaz) + h = np.exp(-(xx + yy + zz)) + h[h < np.finfo(h.dtype).eps * h.max()] = 0 # Truncate + sumh = h.sum() + if sumh != 0: + h /= sumh + return h + + +def small_peaks_conv(cell_mask, trap_image): + cell_fluo = trap_image[cell_mask] + # Get the number of pixels in the cell + num_cell_fluo = len(np.nonzero(cell_fluo)[0]) + # Sort cell pixels in descending fluorescence order + ratio_overlap = num_cell_fluo * 0.025 # TODO what is this? + + # Small Peak Conv + # Convolution parameters + conv_matrix = np.zeros((3, 3)) + # This makes a matrix with zeros in the corners and ones every where else + # Basically the minimal disk. + conv_matrix[1, :] = 1 + conv_matrix[:, 1] = 1 + # Reshape so that it is the size of a fifth of the cell, which is what we + # expect the size of the nucleus to be. + # TODO directly get a disk of that size? 
+ # new_shape = tuple(x * ratio_overlap / 5 for x in conv_matrix.shape) + # conv_matrix = scipy.misc.imresize(conv_matrix, new_shape) + conv_matrix = skimage.morphology.disk(3 * ratio_overlap / 5) + # Apply convolution to the image + # TODO maybe rename 'conv_matrix' to 'kernel' + fluo_peaks = scipy.signal.convolve(trap_image, conv_matrix, "same") + fluo_peaks = fluo_peaks[cell_mask] + small_peak_conv = np.max(fluo_peaks) + return small_peak_conv + + +def nuc_est_conv(cell_mask, trap_image): + """ + :param cell_mask: the segmentation mask of the cell (filled) + :param trap_image: the image for the trap in which the cell is (all + channels) + """ + cell_loc = cell_mask # np.where(cell_mask)[0] + cell_fluo = trap_image[cell_mask] + num_cell_fluo = len(np.nonzero(cell_fluo)[0]) + + # Nuc Est Conv + alpha = 0.95 + approx_nuc_radius = np.sqrt(0.085 * num_cell_fluo / np.pi) + chi2inv = scipy.stats.distributions.chi2.ppf(alpha, df=2) + sd_est = approx_nuc_radius / np.sqrt(chi2inv) + + nuc_filt_hw = np.ceil(2 * approx_nuc_radius) + nuc_filter = matlab_style_gauss2D((2 * nuc_filt_hw + 1,) * 2, sd_est) + + cell_image = trap_image - np.median(cell_fluo) + cell_image[~cell_loc] = 0 + + nuc_conv = scipy.signal.convolve(cell_image, nuc_filter, "same") + nuc_est_conv = np.max(nuc_conv) + nuc_est_conv /= np.sum(nuc_filter ** 2) * alpha * np.pi * chi2inv * sd_est ** 2 + return nuc_est_conv + + +def nuc_conv_3d(cell_mask, trap_image, pixel_size=0.23, spacing=0.6): + cell_mask = np.dstack([cell_mask] * trap_image.shape[-1]) + ratio = spacing / pixel_size + cell_fluo = trap_image[cell_mask] + num_cell_fluo = len(np.nonzero(cell_fluo)[0]) + # Nuc Est Conv + alpha = 0.95 + approx_nuc_radius = np.sqrt(0.085 * num_cell_fluo / np.pi) + chi2inv = scipy.stats.distributions.chi2.ppf(alpha, df=2) + sd_est = approx_nuc_radius / np.sqrt(chi2inv) + nuc_filt_hw = np.ceil(2 * approx_nuc_radius) + nuc_filter = gauss3D((2 * nuc_filt_hw + 1,) * 3, (sd_est, sd_est, sd_est * ratio)) + cell_image = trap_image - np.median(cell_fluo) + cell_image[~cell_mask] = 0 + nuc_conv = scipy.signal.convolve(cell_image, nuc_filter, "same") + nuc_est_conv = np.max(nuc_conv) + nuc_est_conv /= np.sum(nuc_filter ** 2) * alpha * np.pi * chi2inv * sd_est ** 2 + return nuc_est_conv diff --git a/extraction/core/functions/defaults.py b/extraction/core/functions/defaults.py new file mode 100644 index 0000000000000000000000000000000000000000..7bd673a0737f0dc98a98565e5514615df1bb0c6b --- /dev/null +++ b/extraction/core/functions/defaults.py @@ -0,0 +1,67 @@ +# File with defaults for ease of use +from typing import Union +from pathlib import PosixPath, Path +import json + + +def exparams_from_meta(meta: Union[dict, PosixPath, str], extras=["ph"]): + """ + Obtain parameters from metadata of hdf5 file + """ + meta = meta if isinstance(meta, dict) else load_attributes(meta) + base = { + "tree": {"general": {"None": ["area", "volume"]}}, + "multichannel_ops": {}, + } + + av_channels = { + "Citrine", + "GFP", + "GFPFast", + "mCherry", + "pHluorin405", + "Flavin", + "Cy5", + "mKO2", + } + + default_reductions = {"np_max"} + default_metrics = {"mean", "median", "imBackground", "max2p5pc"} + default_rm = {r: default_metrics for r in default_reductions} + + av_flch = av_channels.intersection(meta["channels/channel"]).difference( + {"Brightfield, DIC"} + ) + + for ch in av_flch: + base["tree"][ch] = default_rm + + base["sub_bg"] = av_flch + + # Additional extraction + if "ph" in extras and {"pHluorin405", "GFPFast"}.issubset(av_flch): + + sets = { + b + a: (x, y) + 
for a, x in zip( + ["", "_bgsub"], + ( + ["GFPFast", "pHluorin405"], + ["GFPFast_bgsub", "pHluorin405_bgsub"], + ), + ) + for b, y in zip(["em_ratio", "gsum"], ["div0", "np_add"]) + } + for i, v in sets.items(): + base["multichannel_ops"][i] = [ + *v, + default_rm, + ] + + return base + + +def load_attributes(file: str, group="/"): + with h5py.File(file, "r") as f: + meta = dict(f[group].attrs.items()) + return meta diff --git a/extraction/core/functions/distributors.py b/extraction/core/functions/distributors.py new file mode 100644 index 0000000000000000000000000000000000000000..c0b8253c63d571bc73318d6d459d760ee95386e3 --- /dev/null +++ b/extraction/core/functions/distributors.py @@ -0,0 +1,24 @@ +import numpy as np + + +def trap_apply(cell_fun, cell_masks, trap_image, **kwargs): + """ + Apply a cell_function to a mask, trap_image pair + + :param cell_fun: function to apply to a cell (from extraction/cell.py) + :param cell_masks: (numpy 3d array) cells' segmentation mask + :param trap_image: the image for the trap in which the cell is (all + channels) + :**kwargs: parameters to pass if needed for custom functions + """ + + cells_iter = (*range(cell_masks.shape[2]),) + return [cell_fun(cell_masks[..., i], trap_image, **kwargs) for i in cells_iter] + + +def reduce_z(trap_image, fun): + # Optimise the reduction function if possible + if isinstance(fun, np.ufunc): + return fun.reduce(trap_image, axis=2) + else: + return np.apply_along_axis(fun, 2, trap_image) diff --git a/extraction/core/functions/io.py b/extraction/core/functions/io.py new file mode 100644 index 0000000000000000000000000000000000000000..3f146d1cdcb0cb836051e0b61c83993933f0dc4a --- /dev/null +++ b/extraction/core/functions/io.py @@ -0,0 +1,12 @@ +from yaml import load, dump + + +def dict_to_yaml(d, f): + with open(f, "w") as f: + dump(d, f) + + +def add_attrs(hdfile, path, files): + group = hdfile.create_group(path) + for k, v in files: + group.attrs[k] = v diff --git a/extraction/core/functions/loaders.py b/extraction/core/functions/loaders.py new file mode 100644 index 0000000000000000000000000000000000000000..7196574f67f155f67b7878ecfdde0f09eb1e71c8 --- /dev/null +++ b/extraction/core/functions/loaders.py @@ -0,0 +1,66 @@ +import numpy as np +from inspect import getmembers, isfunction, getargspec +from extraction.core.functions import cell, trap +from extraction.core.functions.custom import localisation +from extraction.core.functions.math import div0 +from extraction.core.functions.distributors import trap_apply + + +def load_cellfuns_core(): + # Generate str -> trap_function dict from functions in core.cell + return {f[0]: f[1] for f in getmembers(cell) if isfunction(f[1])} + + +def load_custom_args(): + """ + Load custom functions. 
If they have extra arguments also load these + """ + funs = {f[0]: f[1] for f in getmembers(localisation) if isfunction(f[1])} + args = { + k: getargspec(v).args[2:] + for k, v in funs.items() + if set(["cell_mask", "trap_image"]).intersection(getargspec(v).args) + } + + return ({k: funs[k] for k in args.keys()}, {k: v for k, v in args.items() if v}) + + +def load_cellfuns(): + # Generate str -> trap_function dict from core.cell and core.trap functions + cell_funs = load_cellfuns_core() + CELLFUNS = {} + for k, f in cell_funs.items(): + if isfunction(f): + + def tmp(f): + return lambda m, img: trap_apply(f, m, img) + + CELLFUNS[k] = tmp(f) + return CELLFUNS + + +def load_trapfuns(): + TRAPFUNS = {f[0]: f[1] for f in getmembers(trap) if isfunction(f[1])} + return TRAPFUNS + + +def load_funs(): + CELLFUNS = load_cellfuns() + TRAPFUNS = load_trapfuns() + + return CELLFUNS, TRAPFUNS, {**TRAPFUNS, **CELLFUNS} + + +def load_redfuns(): # TODO make defining reduction functions more flexible + RED_FUNS = { + "np_max": np.maximum, + "np_mean": np.mean, + "np_median": np.median, + "None": None, + } + return RED_FUNS + + +def load_mergefuns(): + MERGE_FUNS = {"div0": div0, "np_add": np.add} + return MERGE_FUNS diff --git a/extraction/core/functions/math.py b/extraction/core/functions/math.py new file mode 100644 index 0000000000000000000000000000000000000000..95c1add81067d7cf0c7004b4a1788408b6d1fac3 --- /dev/null +++ b/extraction/core/functions/math.py @@ -0,0 +1,15 @@ +import numpy as np + + +def div0(a, b, fill=0): + """a / b, divide by 0 -> `fill` + div0( [-1, 0, 1], 0, fill=np.nan) -> [nan nan nan] + div0( 1, 0, fill=np.inf ) -> inf + """ + with np.errstate(divide="ignore", invalid="ignore"): + c = np.true_divide(a, b) + if np.isscalar(c): + return c if np.isfinite(c) else fill + else: + c[~np.isfinite(c)] = fill + return c diff --git a/extraction/core/functions/trap.py b/extraction/core/functions/trap.py new file mode 100644 index 0000000000000000000000000000000000000000..483f97d8bffca1ee4c264f703ada67f0a8e4933b --- /dev/null +++ b/extraction/core/functions/trap.py @@ -0,0 +1,16 @@ +## Trap-wise calculations + +import numpy as np + + +def imBackground(cell_masks, trap_image): + ''' + :param cell_masks: (numpy 3d array) cells' segmentation mask + :param trap_image: the image for the trap in which the cell is (all + channels) + ''' + if not len(cell_masks): + cell_masks = np.zeros_like(trap_image) + + background = ~cell_masks.sum(axis=2).astype(bool) + return (np.median(trap_image[np.where(background)])) diff --git a/extraction/core/functions/utils.py b/extraction/core/functions/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..dcae68c9bde33a86108cf5a6c5bda701c80420e5 --- /dev/null +++ b/extraction/core/functions/utils.py @@ -0,0 +1,19 @@ +from collections import deque + + +def depth(d): + """ + Copied from https://stackoverflow.com/a/23499088 + + Used to determine the depth of our config trees and fill them + """ + queue = deque([(id(d), d, 1)]) + memo = set() + while queue: + id_, o, level = queue.popleft() + if id_ in memo: + continue + memo.add(id_) + if isinstance(o, dict): + queue += ((id(v), v, level + 1) for v in o.values()) + return level diff --git a/extraction/core/functions/versioning.py b/extraction/core/functions/versioning.py new file mode 100644 index 0000000000000000000000000000000000000000..222126ef0d75264d9614bf595c775ac1fc26f1f0 --- /dev/null +++ b/extraction/core/functions/versioning.py @@ -0,0 +1,12 @@ +import git +import pkg_resources + + +def 
get_sha(): + repo = git.Repo(search_parent_directories=True) + sha = repo.head.object.hexsha + return sha + + +def get_version(pkg="extraction"): + return pkg_resources.require(pkg)[0].version diff --git a/extraction/core/lineage.py b/extraction/core/lineage.py new file mode 100644 index 0000000000000000000000000000000000000000..d8876a882e6c86c8080a507636d11c37039dc4fa --- /dev/null +++ b/extraction/core/lineage.py @@ -0,0 +1,31 @@ +from copy import copy + +def reassign_mo_bud(mo_bud, trans): + """ + Update mother_bud dictionary using another dict with tracks joined + + input + :param mo_bud: dict with mother's ids as keys and daughters' as values + :param trans: dict of joint tracks where moved track -> static track + + output + mo_bud with updated cell ids + """ + + val2lst = lambda x: [j for i in x.values() for j in i] + + bud_inter=set(val2lst(mo_bud)).intersection(trans.keys()) + + # translate daughters + mo_bud = copy(mo_bud) + for k,das in mo_bud.items(): + for da in bud_inter.intersection(das): + mo_bud[k][mo_bud[k].index(da)] = trans[da] + + # translate mothers + mo_inter = set(mo_bud.keys()).intersection(trans.keys()) + for k in mo_inter: + mo_bud[trans[k]] = mo_bud.get(trans[k], []) + mo_bud[k] + del mo_bud[k] + + return mo_bud diff --git a/extraction/core/omero.py b/extraction/core/omero.py new file mode 100644 index 0000000000000000000000000000000000000000..cd59ce75441fd43a84a732d2f71400880e4e5525 --- /dev/null +++ b/extraction/core/omero.py @@ -0,0 +1,27 @@ +from tqdm import tqdm + +from omero.gateway import BlitzGateway + +# Helper funs +def connect_omero(): + conn = BlitzGateway(*get_creds(), host='islay.bio.ed.ac.uk', port=4064) + conn.connect() + return conn + +def get_creds(): + return('upload', + '***REMOVED***', #OMERO Password + ) + +def download_file(f): + """ + Download file in chunks using FileWrapper object + """ + desc = 'Downloading ' + f.getFileName() + \ + ' (' + str(round(f.getFileSize()/1000**2, 2)) + 'Mb)' + + down_file = bytearray() + for c in tqdm(f.getFileInChunks(), desc=desc): + down_file += c + + return down_file diff --git a/extraction/core/tracks.py b/extraction/core/tracks.py new file mode 100644 index 0000000000000000000000000000000000000000..71002b72bf096332b29507e40dd39f97b951149e --- /dev/null +++ b/extraction/core/tracks.py @@ -0,0 +1,498 @@ +''' +Functions to process, filter and merge tracks. +''' + +# from collections import Counter + +from copy import copy +from typing import Union, List + +import numpy as np +import pandas as pd + +from scipy.signal import savgol_filter +# from scipy.optimize import linear_sum_assignment +# from scipy.optimize import curve_fit + +from matplotlib import pyplot as plt + +def load_test_dset(): + # Load development dataset to test functions + return pd.DataFrame({('a',1,1):[2, 5, np.nan, 6,8] + [np.nan] * 5, + ('a',1,2):list(range(2,12)), + ('a',1,3):[np.nan] * 8 + [6,7], + ('a',1,4):[np.nan] * 5 + [9,12,10,14,18]}, + index=range(1,11)).T + +def get_ntps(track:pd.Series) -> int: + # Get number of timepoints + indices = np.where(track.notna()) + return np.max(indices) - np.min(indices) + + +def get_tracks_ntps(tracks:pd.DataFrame) -> pd.Series: + return tracks.apply(get_ntps, axis=1) + +def get_avg_gr(track:pd.Series) -> int: + ''' + Get average growth rate for a track. 
+ + :param tracks: Series with volume and timepoints as indices + ''' + ntps = get_ntps(track) + vals = track.dropna().values + gr = (vals[-1] - vals[0] )/ ntps + return gr + + +def get_avg_grs(tracks:pd.DataFrame) -> pd.DataFrame: + ''' + Get average growth rate for a group of tracks + + :param tracks: (m x n) dataframe where rows are cell tracks and + columns are timepoints + ''' + return tracks.apply(get_avg_gr, axis=1) + + +def clean_tracks(tracks, min_len:int=6, min_gr:float=0.5) -> pd.DataFrame: + ''' + Clean small non-growing tracks and return the reduced dataframe + + :param tracks: (m x n) dataframe where rows are cell tracks and + columns are timepoints + :param min_len: int number of timepoints cells must have not to be removed + :param min_gr: float Minimum mean growth rate to assume an outline is growing + ''' + ntps = get_tracks_ntps(tracks) + grs = get_avg_grs(tracks) + + growing_long_tracks = tracks.loc[(ntps >= min_len) & (grs > min_gr)] + + return (growing_long_tracks) + +def merge_tracks(tracks, drop=False, **kwargs) -> pd.DataFrame: + ''' + Join tracks that are contiguous and within a volume threshold of each other + + :param tracks: (m x n) dataframe where rows are cell tracks and + columns are timepoints + :param kwargs: args passed to get_joinable + + returns + + :joint_tracks: (m x n) Dataframe where rows are cell tracks and + columns are timepoints. Merged tracks are still present but filled + with np.nans. + ''' + + # calculate tracks that can be merged until no more traps can be merged + joinable_pairs = get_joinable(tracks, **kwargs) + if joinable_pairs: + tracks = join_tracks(tracks, joinable_pairs, drop=drop) + joint_ids = get_joint_ids(joinable_pairs) + + return (tracks, joint_ids) + + +def get_joint_ids(merging_seqs) -> dict: + ''' + Convert a series of merges into a dictionary where + the key is the cell_id of destination and the value a list + of the other track ids that were merged into the key + + :param merging_seqs: list of tuples of indices indicating the + sequence of merging events. It is important for this to be in sequential order + + How it works: + + The order of merging matters for naming, always the leftmost track will keep the id + + For example, having tracks (a, b, c, d) and the iterations of merge events: + + 0 a b c d + 1 a b cd + 2 ab cd + 3 abcd + + We shold get: + + output {a:a, b:a, c:a, d:a} + + ''' + targets, origins = list(zip(*merging_seqs)) + static_tracks = set(targets).difference(origins) + + joint = {track_id: track_id for track_id in static_tracks} + for target, origin in merging_seqs: + joint[origin] = target + + moved_target = [k for k,v in joint.items() \ + if joint[v]!=v and v in joint.values()] + + for orig in moved_target: + joint[orig] = rec_bottom(joint, orig) + + return {k:v for k,v in joint.items() if k!=v} # remove ids that point to themselves + +def rec_bottom(d, k): + if d[k] == k: + return k + else: + return rec_bottom(d, d[k]) + +def join_tracks(tracks, joinable_pairs, drop=False) -> pd.DataFrame: + ''' + Join pairs of tracks from later tps towards the start. + + :param tracks: (m x n) dataframe where rows are cell tracks and + columns are timepoints + + returns (copy) + + :param joint_tracks: (m x n) Dataframe where rows are cell tracks and + columns are timepoints. Merged tracks are still present but filled + with np.nans. 
+ :param drop: bool indicating whether or not to drop moved rows + + ''' + + + tmp = copy(tracks) + for target, source in joinable_pairs: + tmp.loc[target] = join_track_pairs(tmp.loc[target], tmp.loc[source]) + + if drop: + tmp = tmp.drop(source) + + return (tmp) + +def join_track_pairs(track1, track2): + tmp = copy(track1) + tmp.loc[track2.dropna().index] = track2.dropna().values + + return tmp + +def get_joinable(tracks, smooth=False, tol=0.1, window=5, degree=3) -> dict: + ''' + Get the pair of track (without repeats) that have a smaller error than the + tolerance. If there is a track that can be assigned to two or more other + ones, it chooses the one with a lowest error. + + :param tracks: (m x n) dataframe where rows are cell tracks and + columns are timepoints + :param tol: float or int threshold of average (prediction error/std) necessary + to consider two tracks the same. If float is fraction of first track, + if int it is absolute units. + :param window: int value of window used for savgol_filter + :param degree: int value of polynomial degree passed to savgol_filter + + ''' + + tracks.index.names = ['pos', 'trap', 'cell'] #TODO remove this once it is integrated in the tracker + # contig=tracks.groupby(['pos','trap']).apply(tracks2contig) + clean = clean_tracks(tracks, min_len=window+1, min_gr = 0.9) # get useful tracks + contig=clean.groupby(['pos','trap']).apply(get_contiguous_pairs) + contig = contig.loc[contig.apply(len) > 0] + # candict = {k:v for d in contig.values for k,v in d.items()} + + # smooth all relevant tracks + + linear=set([k for v in contig.values for i in v for j in i for k in j]) + if smooth: # Apply savgol filter TODO fix nans affecting edge placing + savgol_on_srs = lambda x: non_uniform_savgol(x.index, x.values, + window, degree) + smoothed_tracks = clean.loc[linear].apply(savgol_on_srs,1) + else: + smoothed_tracks = clean.loc[linear].apply(lambda x: np.array(x.values), axis=1) + + # fetch edges from ids TODO (IF necessary, here we can compare growth rates) + idx_to_edge = lambda preposts: [([get_val(smoothed_tracks.loc[pre],-1) for pre in pres], + [get_val(smoothed_tracks.loc[post],0) for post in posts]) + for pres, posts in preposts] + edges = contig.apply(idx_to_edge) + + closest_pairs = edges.apply(get_vec_closest_pairs, tol=tol) + + #match local with global ids + joinable_ids = [localid_to_idx(closest_pairs.loc[i], contig.loc[i])\ + for i in closest_pairs.index] + + return [pair for pairset in joinable_ids for pair in pairset] + +get_val = lambda x, n: x[~np.isnan(x)][n] if len(x[~np.isnan(x)]) else np.nan + + +def localid_to_idx(local_ids, contig_trap): + """Fetch then original ids from a nested list with joinable local_ids + + input + :param local_ids: list of list of pairs with cell ids to be joint + :param local_ids: list of list of pairs with corresponding cell ids + + return + list of pairs with (experiment-level) ids to be joint + """ + lin_pairs = [] + for i,pairs in enumerate(local_ids): + if len(pairs): + for left,right in pairs: + lin_pairs.append((contig_trap[i][0][left], + contig_trap[i][1][right])) + return lin_pairs + +def get_vec_closest_pairs(lst:List, **kwags): + return [get_closest_pairs(*l, **kwags) for l in lst] + +def get_closest_pairs(pre:List[float], post:List[float], tol:Union[float,int]=1): + """Calculate a cost matrix the Hungarian algorithm to pick the best set of + options + + input + :param pre: list of floats with edges on left + :param post: list of floats with edges on right + :param tol: int or float if int metrics of 
tolerance, if float fraction + + returns + :: list of indices corresponding to the best solutions for matrices + + """ + if len(pre) > len(post): + dMetric = np.abs(np.subtract.outer(post,pre)) + else: + dMetric = np.abs(np.subtract.outer(pre,post)) + # dMetric[np.isnan(dMetric)] = tol + 1 + np.nanmax(dMetric) # nans will be filtered + # ids = linear_sum_assignment(dMetric) + dMetric[np.isnan(dMetric)] = tol + 1 + np.nanmax(dMetric) # nans will be filtered + + ids = solve_matrix(dMetric) + if not len(ids[0]): + return [] + + norm = np.array(pre)[ids[len(pre)>len(post)]] if tol<1 else 1 # relative or absolute tol + result = dMetric[ids]/norm + ids = ids if len(pre)<len(post) else ids[::-1] + + return [idx for idx,res in zip(zip(*ids), result) if res < tol] + +def solve_matrix(dMetric): + """ + Solve cost matrix focusing on getting the smallest cost at each iteration. + + input + :param dMetric: np.array cost matrix + + returns + tuple of np.arrays indicating picks with lowest individual value + """ + glob_is = [] + glob_js = [] + if (~np.isnan(dMetric)).any(): + tmp = copy(dMetric ) + std = sorted(tmp[~np.isnan(tmp)]) + while (~np.isnan(std)).any(): + v = std[0] + i_s,j_s = np.where(tmp==v) + i = i_s[0] + j = j_s[0] + tmp[i,:]+= np.nan + tmp[:,j]+= np.nan + glob_is.append(i) + glob_js.append(j) + + std = sorted(tmp[~np.isnan(tmp)]) + + + return (np.array( glob_is ), np.array( glob_js )) + +def plot_joinable(tracks, joinable_pairs, max=64): + """ + Convenience plotting function for debugging and data vis + """ + + nx=8 + ny=8 + _, axes = plt.subplots(nx,ny) + for i in range(nx): + for j in range(ny): + if i*ny+j < len(joinable_pairs): + ax = axes[i, j] + pre, post = joinable_pairs[i * ny + j] + pre_srs = tracks.loc[pre].dropna() + post_srs = tracks.loc[post].dropna() + ax.plot(pre_srs.index, pre_srs.values , 'b') + # try: + # totrange = np.arange(pre_srs.index[0],post_srs.index[-1]) + # ax.plot(totrange, interpolate(pre_srs, totrange), 'r-') + # except: + # pass + ax.plot(post_srs.index, post_srs.values, 'g') + + plt.show() + +def get_contiguous_pairs(tracks: pd.DataFrame) -> list: + ''' + Get all pair of contiguous track ids from a tracks dataframe. 
+ + :param tracks: (m x n) dataframe where rows are cell tracks and + columns are timepoints + :param min_dgr: float minimum difference in growth rate from the interpolation + ''' + # indices = np.where(tracks.notna()) + + + mins, maxes = [tracks.notna().apply(np.where, axis=1).apply(fn) + for fn in (np.min, np.max)] + + mins_d = mins.groupby(mins).apply(lambda x: x.index.tolist()) + mins_d.index = mins_d.index - 1 # make indices equal + maxes_d = maxes.groupby(maxes).apply(lambda x: x.index.tolist()) + + common = sorted(set(mins_d.index).intersection(maxes_d.index), reverse=True) + + return [(maxes_d[t], mins_d[t]) for t in common] + +# def fit_track(track: pd.Series, obj=None): +# if obj is None: +# obj = objective + +# x = track.dropna().index +# y = track.dropna().values +# popt, _ = curve_fit(obj, x, y) + +# return popt + +# def interpolate(track, xs) -> list: +# ''' +# Interpolate next timepoint from a track + +# :param track: pd.Series of volume growth over a time period +# :param t: int timepoint to interpolate +# ''' +# popt = fit_track(track) +# # perr = np.sqrt(np.diag(pcov)) +# return objective(np.array(xs), *popt) + + +# def objective(x,a,b,c,d) -> float: +# # return (a)/(1+b*np.exp(c*x))+d +# return (((x+d)*a)/((x+d)+b))+c + +# def cand_pairs_to_dict(candidates): +# d={x:[] for x,_ in candidates} +# for x,y in candidates: +# d[x].append(y) +# return d + + +def non_uniform_savgol(x, y, window, polynom): + """ + Applies a Savitzky-Golay filter to y with non-uniform spacing + as defined in x + + This is based on https://dsp.stackexchange.com/questions/1676/savitzky-golay-smoothing-filter-for-not-equally-spaced-data + The borders are interpolated like scipy.signal.savgol_filter would do + + source: https://dsp.stackexchange.com/a/64313 + + Parameters + ---------- + x : array_like + List of floats representing the x values of the data + y : array_like + List of floats representing the y values. Must have same length + as x + window : int (odd) + Window length of datapoints. Must be odd and smaller than x + polynom : int + The order of polynom used. 
Must be smaller than the window size + + Returns + ------- + np.array of float + The smoothed y values + """ + if len(x) != len(y): + raise ValueError('"x" and "y" must be of the same size') + + if len(x) < window: + raise ValueError('The data size must be larger than the window size') + + if type(window) is not int: + raise TypeError('"window" must be an integer') + + if window % 2 == 0: + raise ValueError('The "window" must be an odd integer') + + if type(polynom) is not int: + raise TypeError('"polynom" must be an integer') + + if polynom >= window: + raise ValueError('"polynom" must be less than "window"') + + half_window = window // 2 + polynom += 1 + + # Initialize variables + A = np.empty((window, polynom)) # Matrix + tA = np.empty((polynom, window)) # Transposed matrix + t = np.empty(window) # Local x variables + y_smoothed = np.full(len(y), np.nan) + + # Start smoothing + for i in range(half_window, len(x) - half_window, 1): + # Center a window of x values on x[i] + for j in range(0, window, 1): + t[j] = x[i + j - half_window] - x[i] + + # Create the initial matrix A and its transposed form tA + for j in range(0, window, 1): + r = 1.0 + for k in range(0, polynom, 1): + A[j, k] = r + tA[k, j] = r + r *= t[j] + + # Multiply the two matrices + tAA = np.matmul(tA, A) + + # Invert the product of the matrices + tAA = np.linalg.inv(tAA) + + # Calculate the pseudoinverse of the design matrix + coeffs = np.matmul(tAA, tA) + + # Calculate c0 which is also the y value for y[i] + y_smoothed[i] = 0 + for j in range(0, window, 1): + y_smoothed[i] += coeffs[0, j] * y[i + j - half_window] + + # If at the end or beginning, store all coefficients for the polynom + if i == half_window: + first_coeffs = np.zeros(polynom) + for j in range(0, window, 1): + for k in range(polynom): + first_coeffs[k] += coeffs[k, j] * y[j] + elif i == len(x) - half_window - 1: + last_coeffs = np.zeros(polynom) + for j in range(0, window, 1): + for k in range(polynom): + last_coeffs[k] += coeffs[k, j] * y[len(y) - window + j] + + # Interpolate the result at the left border + for i in range(0, half_window, 1): + y_smoothed[i] = 0 + x_i = 1 + for j in range(0, polynom, 1): + y_smoothed[i] += first_coeffs[j] * x_i + x_i *= x[i] - x[half_window] + + # Interpolate the result at the right border + for i in range(len(x) - half_window, len(x), 1): + y_smoothed[i] = 0 + x_i = 1 + for j in range(0, polynom, 1): + y_smoothed[i] += last_coeffs[j] * x_i + x_i *= x[i] - x[-half_window - 1] + + return y_smoothed diff --git a/extraction/examples/argo.py b/extraction/examples/argo.py new file mode 100644 index 0000000000000000000000000000000000000000..d361c1ac411abdb9b81b32433fd13531c9a01071 --- /dev/null +++ b/extraction/examples/argo.py @@ -0,0 +1,39 @@ +# Example of argo experiment explorer +from agora.argo import Argo +from extraction.core.extractor import Extractor +from extraction.core.parameters import Parameters +from extraction.core.functions.defaults import get_params + +argo = Argo() +argo.load() +# argo.channels("GFP") +argo.tags(["Alan"]) +argo.complete() +# argo.cExperiment() +# argo.tiler_cells() + +# params = Parameters(**get_params("batman_ph_dual_fast")) + + +# def try_extract(d): +# try: +# params = Parameters(**get_params("batman_ph_dual_fast")) +# ext = Extractor(params, source=d.getId()) +# ext.load_tiler_cells() +# ext.process_experiment() +# print(d.getId(), d.getName(), "Experiment processed") +# return True +# except: +# print(d.getId(), d.getName(), "Experiment not processed") + +# return False + + +# from 
multiprocessing.dummy import Pool as ThreadPool + +# pool = ThreadPool(4) +# results = pool.map(try_extract, argo.dsets) +# import pickle + +# with open("results.pkl", "wb") as f: +# pickle.dump(results, f) diff --git a/extraction/examples/data.py b/extraction/examples/data.py new file mode 100644 index 0000000000000000000000000000000000000000..949d3182ce542d82bf2d72575ddcf5d42000f421 --- /dev/null +++ b/extraction/examples/data.py @@ -0,0 +1,108 @@ +""" +Functions to load and reshape examples for extraction development + +The basic format for data will be pair data/masks pairs. Data will be +assumed to be a single slide, given that this reduction is expected +to happen beforehand. + +The most basic functions were copied from Swain Lab's baby module, +specifically baby/io.py +""" + +import os +import json +import re + +from pathlib import Path +from itertools import groupby +from typing import Callable + +import numpy as np +import random +from imageio import imread + +from extraction.core.functions.distributors import reduce_z + + +def load_tiled_image(filename): + tImg = imread(filename) + info = json.loads(tImg.meta.get("Description", "{}")) + tw, th = info.get("tilesize", tImg.shape[0:2]) + nt = info.get("ntiles", 1) + nr, nc = info.get("layout", (1, 1)) + nc_final_row = np.mod(nt, nc) + img = np.zeros((tw, th, nt), dtype=tImg.dtype) + for i in range(nr): + i_nc = nc_final_row if i + 1 == nr and nc_final_row > 0 else nc + for j in range(i_nc): + ind = i * nc + j + img[:, :, ind] = tImg[i * tw : (i + 1) * tw, j * th : (j + 1) * th] + return img, info + + +def load_paired_images(filenames, typeA="Brightfield", typeB="segoutlines"): + re_imlbl = re.compile(r"^(.*)_(" + typeA + r"|" + typeB + r")\.png$") + # For groupby to work, the list needs to be sorted; also has the side + # effect of ensuring filenames is no longer a generator + filenames = sorted(filenames) + matches = [re_imlbl.match(f.name) for f in filenames] + valid = filter(lambda m: m[0], zip(matches, filenames)) + grouped = { + k: {m.group(2): f for m, f in v} + for k, v in groupby(valid, key=lambda m: m[0].group(1)) + } + valid = [set(v.keys()).issuperset({typeA, typeB}) for v in grouped.values()] + if not all(valid): + raise Exception + return { + l: {t: load_tiled_image(f) for t, f in g.items()} for l, g in grouped.items() + } + + +def load(path=None): + """ + Loads annotated pngs into memory. Only designed for GFP and brightfield. 
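+
+    A minimal usage sketch (by default images are read from the bundled
+    pairs_data folder next to this module):
+
+    >>> dsets = load()      # list of {channel: image} dicts, one per annotated image pair
+    >>> flat = load_1z()    # same images, but z-stacks reduced to a single slice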
+ + input + :path: Folder used to look for images + + returns + list of dictionaries containing GFP, Brightfield and segoutlines channel + """ + if path is None: + path = Path(os.path.dirname(os.path.realpath(__file__))) / Path("pairs_data") + + image_dir = Path(path) + channels = ["Brightfield", "GFP"] + imset = {"segoutlines": {}} + for ch in channels: + imset[ch] = {} + pos = load_paired_images(image_dir.glob("*.png"), typeA=ch) + for key, img in pos.items(): + imset[ch][key] = img[ch][0] + imset["segoutlines"][key] = img["segoutlines"][0].astype(bool) + + return [{ch: imset[ch][pos] for ch in imset.keys()} for pos in pos.keys()] + + +def load_1z(fun: Callable = np.maximum, path: str = None): + """ + --- + fun: Function used to reduce the multiple stacks + path: Path to pass to load function + + + """ + dsets = load(path=path) + dsets_1z = [] + for dset in dsets: + tmp = {} + for ch, img in dset.items(): + if ch == "segoutlines": + tmp[ch] = img + else: + tmp[ch] = reduce_z(img, fun) + + dsets_1z.append(tmp) + + return dsets_1z diff --git a/extraction/examples/pairs_data/pos010_trap001_tp0001_Brightfield.png b/extraction/examples/pairs_data/pos010_trap001_tp0001_Brightfield.png new file mode 100644 index 0000000000000000000000000000000000000000..d50e92fbf3a744f01fd44993720c2b2d38b3fb0b Binary files /dev/null and b/extraction/examples/pairs_data/pos010_trap001_tp0001_Brightfield.png differ diff --git a/extraction/examples/pairs_data/pos010_trap001_tp0001_GFP.png b/extraction/examples/pairs_data/pos010_trap001_tp0001_GFP.png new file mode 100644 index 0000000000000000000000000000000000000000..2b50663f26ed9c7f0305c35339d23fbbc1b529d8 Binary files /dev/null and b/extraction/examples/pairs_data/pos010_trap001_tp0001_GFP.png differ diff --git a/extraction/examples/pairs_data/pos010_trap001_tp0001_segoutlines.png b/extraction/examples/pairs_data/pos010_trap001_tp0001_segoutlines.png new file mode 100644 index 0000000000000000000000000000000000000000..eb29a60c11d99411cad09711b046d1cb32318e4d Binary files /dev/null and b/extraction/examples/pairs_data/pos010_trap001_tp0001_segoutlines.png differ diff --git a/extraction/examples/pairs_data/pos010_trap050_tp0001_Brightfield.png b/extraction/examples/pairs_data/pos010_trap050_tp0001_Brightfield.png new file mode 100644 index 0000000000000000000000000000000000000000..9f26817ea8dffffd33f360dd39031ee59830d673 Binary files /dev/null and b/extraction/examples/pairs_data/pos010_trap050_tp0001_Brightfield.png differ diff --git a/extraction/examples/pairs_data/pos010_trap050_tp0001_GFP.png b/extraction/examples/pairs_data/pos010_trap050_tp0001_GFP.png new file mode 100644 index 0000000000000000000000000000000000000000..a600fefdf7dff067a3274f808d359491babca203 Binary files /dev/null and b/extraction/examples/pairs_data/pos010_trap050_tp0001_GFP.png differ diff --git a/extraction/examples/pairs_data/pos010_trap050_tp0001_segoutlines.png b/extraction/examples/pairs_data/pos010_trap050_tp0001_segoutlines.png new file mode 100644 index 0000000000000000000000000000000000000000..8dc7a456ee8df34e09e6bee62008ecb5653ff24f Binary files /dev/null and b/extraction/examples/pairs_data/pos010_trap050_tp0001_segoutlines.png differ diff --git a/extraction/examples/pairs_data/pos010_trap081_tp0001_Brightfield.png b/extraction/examples/pairs_data/pos010_trap081_tp0001_Brightfield.png new file mode 100644 index 0000000000000000000000000000000000000000..5670a02a388cd94fdbcf9974295b766caa09632c Binary files /dev/null and 
b/extraction/examples/pairs_data/pos010_trap081_tp0001_Brightfield.png differ diff --git a/extraction/examples/pairs_data/pos010_trap081_tp0001_GFP.png b/extraction/examples/pairs_data/pos010_trap081_tp0001_GFP.png new file mode 100644 index 0000000000000000000000000000000000000000..7b5702fc7a99653930bd22aa3ef520ac2c2749f3 Binary files /dev/null and b/extraction/examples/pairs_data/pos010_trap081_tp0001_GFP.png differ diff --git a/extraction/examples/pairs_data/pos010_trap081_tp0001_segoutlines.png b/extraction/examples/pairs_data/pos010_trap081_tp0001_segoutlines.png new file mode 100644 index 0000000000000000000000000000000000000000..909cfe00ec9dedd03e2f8bd565eaa2c18f6a5e02 Binary files /dev/null and b/extraction/examples/pairs_data/pos010_trap081_tp0001_segoutlines.png differ diff --git a/extraction/examples/pos_example.py b/extraction/examples/pos_example.py new file mode 100644 index 0000000000000000000000000000000000000000..f535c9cdd3e000a2280322dcb9c6ffd5e49aab1e --- /dev/null +++ b/extraction/examples/pos_example.py @@ -0,0 +1,20 @@ +from extraction.core.parameters import Parameters +from extraction.core.extractor import Extractor +import numpy as np + +params = Parameters( + tree={ + "general": {"None": ["area"]}, + "GFPFast": {"np_max": ["mean", "median", "imBackground"]}, + "pHluorin405": {"np_max": ["mean", "median", "imBackground"]}, + "mCherry": { + "np_max": ["mean", "median", "imBackground", "max5px", "max2p5pc"] + }, + } +) + + +ext = Extractor(params, omero_id=19310) +# ext.extract_exp(tile_size=117) +d=ext.extract_tp(tp=1,tile_size=117) + diff --git a/extraction/examples/test_pipeline.py b/extraction/examples/test_pipeline.py new file mode 100644 index 0000000000000000000000000000000000000000..26c0cea4dce38d8b8a36cf8d342147daee161da2 --- /dev/null +++ b/extraction/examples/test_pipeline.py @@ -0,0 +1,26 @@ +import numpy as np +from pathlib import Path +from extraction.core.extractor import Extractor +from extraction.core.parameters import Parameters +from extraction.core.functions.defaults import get_params + +params = Parameters(**get_params("batman_ph_dual_fast")) +# ext = Extractor(params, source=19918) # 19831 +ext = Extractor(params, source=19831) +ext.load_tiler() +self = ext +# s=self.extract_exp(tree={'general':{None:['area']}, 'GFPFast':{np.maximum:['median','mean']}},poses=self.expt.positions[:2], tps=[0,1], stg='df') +s = self.extract_exp() +# # import cProfile +# # profile = cProfile.Profile() +# # profile.enable() +# # ext.change_position(ext.expt.positions[1]) +# # tracks = self.extract_pos( +# # tree={('general'):{None: # Other metrics can be used +# # [tidy_metric]}})#['general',None,'area'] + +# # profile.disable() +# # import pstats +# # ps = pstats.Stats(profile) +# # ps.sort_stats('cumulative') +# # ps.print_stats() diff --git a/extraction/examples/tiler_error.py b/extraction/examples/tiler_error.py new file mode 100644 index 0000000000000000000000000000000000000000..9915edf2afc012d81a769a83d9d1b2ad95f91103 --- /dev/null +++ b/extraction/examples/tiler_error.py @@ -0,0 +1,35 @@ +from core.experiment import Experiment +from core.segment import Tiler +expt = Experiment.from_source(19310, #Experiment ID on OMERO + 'upload', #OMERO Username + '***REMOVED***', #OMERO Password + 'islay.bio.ed.ac.uk', #OMERO host + port=4064 #This is default + ) + + +# Load whole position +img=expt[0,0,:,:,2] +plt.imshow(img[0,0,...,0]); plt.show() + +# Manually get template +tilesize=117 +x0=827 +y0=632 +trap_template = img[0,0,x0:x0+tilesize,y0:y0+tilesize,0] 
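+# NB: this example assumes matplotlib.pyplot has been imported as plt, and the
+# (x0, y0) offsets above were picked by eye for this particular position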
+plt.imshow(trap_template); plt.show() + +tiler = Tiler(expt, template = trap_template) + +# Load images (takes about 5 mins) +trap_tps = tiler.get_traps_timepoint(0, tile_size=117, z=[2]) + +#Plot found traps +nrows, ncols = (5,5) +fig, axes = plt.subplots(nrows,ncols) +for i in range(nrows): + for j in range(ncols): + if i*nrows+j < trap_tps.shape[0]: + axes[i,j].imshow(trap_tps[i*nrows+j,0,0,...,0]) +plt.show() + diff --git a/tests/argo/test_argo.py b/tests/argo/test_argo.py new file mode 100644 index 0000000000000000000000000000000000000000..c07e7d18c7523e12f2299f14098232144102864b --- /dev/null +++ b/tests/argo/test_argo.py @@ -0,0 +1,32 @@ +# Example of argo experiment explorer +import pytest +from agora.argo import Argo + + +@pytest.mark.skip(reason="no way of testing this without sensitive info") +def test_load(): + argo = Argo() + argo.load() + return 1 + + +@pytest.mark.skip(reason="no way of testing this without sensitive info") +def test_channel_filter(): + argo = Argo() + argo.load() + argo.channels("GFP") + return 1 + + +@pytest.mark.skip(reason="no way of testing this without sensitive info") +def test_tags(): + argo = Argo() + argo.load() + argo.channels("GFP") + argo.tags(["Alan", "batgirl"]) + return 1 + + +@pytest.mark.skip(reason="no way of testing this without sensitive info") +def test_timepoint(): + pass diff --git a/tests/extraction/__pycache__/log_test.cpython-37-pytest-6.2.5.pyc b/tests/extraction/__pycache__/log_test.cpython-37-pytest-6.2.5.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6d598779a3e3d03ea6335aa26a3c1b73d713e53e Binary files /dev/null and b/tests/extraction/__pycache__/log_test.cpython-37-pytest-6.2.5.pyc differ diff --git a/tests/extraction/data/mo_bud.pkl b/tests/extraction/data/mo_bud.pkl new file mode 100644 index 0000000000000000000000000000000000000000..e78d2e2c6be8ea3036d94c96a0c2bf56bb03e9c5 Binary files /dev/null and b/tests/extraction/data/mo_bud.pkl differ diff --git a/tests/extraction/data/tracks.pkl b/tests/extraction/data/tracks.pkl new file mode 100644 index 0000000000000000000000000000000000000000..6ac81124b81e01c101c23a8b68937831f9fe6c42 Binary files /dev/null and b/tests/extraction/data/tracks.pkl differ diff --git a/tests/extraction/log_test.py b/tests/extraction/log_test.py new file mode 100644 index 0000000000000000000000000000000000000000..a09b5d01805ddf5daeb3e9225e90c919723122a8 --- /dev/null +++ b/tests/extraction/log_test.py @@ -0,0 +1,6 @@ +#!/usr/bin/env python3 +import pytest + + +def test_dummy(): + print("passed") diff --git a/tests/extraction/test_base.py b/tests/extraction/test_base.py new file mode 100644 index 0000000000000000000000000000000000000000..68c13066ada6d827b4b632773b62199abd58c084 --- /dev/null +++ b/tests/extraction/test_base.py @@ -0,0 +1,65 @@ +from itertools import product +import pytest + +from extraction.core.extractor import hollowExtractor, Parameters +from extraction.core.functions import cell +from extraction.core.functions.trap import imBackground +from extraction.core.functions.loaders import ( + load_funs, + load_cellfuns, + load_trapfuns, + load_redfuns, +) +from extraction.examples import data +from extraction.core.functions.defaults import get_params + +dsets1z = data.load_1z() +dsets = data.load() +masks = [d["segoutlines"] for d in dsets1z] +functions = load_funs()[2].values() +tree = { + c: {r: list(load_funs()[2].keys()) for r in load_redfuns()} + for c in dsets[0] + if c != "segoutlines" +} + + +@pytest.mark.parametrize( + ["imgs", "masks", "f"], 
list(product(dsets1z, masks, functions)) +) +def test_metrics_run(imgs, masks, f): + """ + Test all core cell functions using pre-flattened images + """ + + for ch, img in imgs.items(): + if ch is not "segoutlines": + assert tuple(masks.shape[:2]) == tuple(imgs[ch].shape) + f(masks, img) + + +@pytest.mark.parametrize(["imgs", "masks", "tree"], product(dsets, masks, tree)) +def test_extractor(imgs, masks, tree): + """ + Test a tiler-less extractor using an instance built using default parameters. + + + Tests reduce-extract + """ + extractor = hollowExtractor(Parameters(**get_params("batgirl_fast"))) + # Load all available functions + extractor._all_funs = load_funs()[2] + extractor._all_cell_funs = load_cellfuns() + extractor.tree = tree + traps = imgs["GFP"] + # Generate mock labels + labels = list(range(masks.shape[2])) + for ch_branches in extractor.params.tree.values(): + print( + extractor.reduce_extract( + red_metrics=ch_branches, + traps=[traps], + masks=[masks], + labels={0: labels}, + ) + ) diff --git a/tests/extraction/test_functions.py b/tests/extraction/test_functions.py new file mode 100644 index 0000000000000000000000000000000000000000..95fd3f0f5d05e3d6854a3ad66f4d6a5b5c771302 --- /dev/null +++ b/tests/extraction/test_functions.py @@ -0,0 +1,24 @@ +import numpy as np +from pathlib import Path +from extraction.core.extractor import Extractor +from extraction.core.parameters import Parameters +from extraction.core.functions.defaults import get_params + +params = Parameters(**get_params("batman_dual")) +ext = Extractor(params, source=19310) +ext.load_funs() + + +def test_custom_output(): + self = ext + mask = np.zeros((6, 6, 2), dtype=bool) + mask[2:4, 2:4, 0] = True + mask[3:5, 3:5, 1] = True + img = np.random.randint(1, 11, size=6 ** 2 * 5).reshape(6, 6, 5) + + for i, f in self._custom_funs.items(): + if "3d" in i: + res = f(mask, img) + else: + res = f(mask, np.maximum.reduce(img, axis=2)) + assert len(res) == mask.shape[2], "Output doesn't match input" diff --git a/tests/extraction/test_mo_bud.py b/tests/extraction/test_mo_bud.py new file mode 100644 index 0000000000000000000000000000000000000000..620cbb31d1e08211a75627cedd9f869b776ae809 --- /dev/null +++ b/tests/extraction/test_mo_bud.py @@ -0,0 +1,43 @@ +import pytest + +import os +from pathlib import Path + +# from copy import copy + +import pickle + +from extraction.core.tracks import get_joinable, get_joint_ids +from extraction.core.extractor import Extractor, ExtractorParameters +from extraction.core.lineage import reassign_mo_bud + + +DATA_DIR = Path(os.path.dirname(os.path.realpath(__file__))) / Path("data") + + +def test_mobud_translation(tracks_pkl=None, mo_bud_pkl=None): + + if tracks_pkl is None: + tracks_pkl = "tracks.pkl" + + if mo_bud_pkl is None: + mo_bud_pkl = "mo_bud.pkl" + + mo_bud_pkl = Path(mo_bud_pkl) + tracks_pkl = Path(tracks_pkl) + + with open(DATA_DIR / tracks_pkl, "rb") as f: + tracks = pickle.load(f) + with open(DATA_DIR / mo_bud_pkl, "rb") as f: + mo_bud = pickle.load(f) + + params = Parameters(**get_params("batman_dual")) + ext = Extractor(params) + + joinable = get_joinable(tracks, **ext.params.merge_tracks) + trans = get_joint_ids(joinable) + + # Check that we have reassigned cell labels + mo_bud2 = reassign_mo_bud(mo_bud, trans) + + assert mo_bud != mo_bud2 diff --git a/tests/extraction/test_tracks.py b/tests/extraction/test_tracks.py new file mode 100644 index 0000000000000000000000000000000000000000..40976fcca875972cf01d5018b0ce64205ce13c46 --- /dev/null +++ b/tests/extraction/test_tracks.py 
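The extraction tests above all exercise the same nested parameter layout, a tree of channel -> z-reduction -> per-cell metrics, which ExtractorParameters and reduce_extract consume. For reference, a minimal hand-rolled traversal of such a tree, using NumPy only and hypothetical toy inputs, could look like the sketch below (an illustration of the layout, not the library's own implementation):

import numpy as np

# Toy inputs: one (x, y, z) image stack and an (x, y, ncells) boolean mask stack.
img = np.random.randint(0, 255, size=(117, 117, 5))
masks = np.zeros((117, 117, 2), dtype=bool)
masks[20:40, 20:40, 0] = True
masks[60:80, 60:80, 1] = True

tree = {"GFP": {np.maximum: ["mean", "median"]}}

results = {}
for ch, reductions in tree.items():
    for red_fun, metrics in reductions.items():
        plane = red_fun.reduce(img, axis=2)  # collapse the z-stack, as reduce_z does
        for metric in metrics:
            fun = getattr(np, metric)
            # One value per cell: apply the metric to the pixels inside each mask.
            results[(ch, red_fun.__name__, metric)] = [
                fun(plane[masks[..., i]]) for i in range(masks.shape[-1])
            ]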
@@ -0,0 +1,31 @@ + +from extraction.core.tracks import load_test_dset, clean_tracks, merge_tracks + +def test_clean_tracks(): + tracks = load_test_dset() + clean = clean_tracks(tracks, min_len=3) + + assert len(clean) < len(tracks) + pass + +def test_merge_tracks_drop(): + tracks = load_test_dset() + + joint_tracks,joint_ids = merge_tracks(tracks,window=3, degree=2, drop=True) + + assert len(joint_tracks)<len(tracks), 'Error when merging' + + assert len(joint_ids), 'No joint ids found' + + pass + +def test_merge_tracks_nodrop(): + tracks = load_test_dset() + + joint_tracks,joint_ids = merge_tracks(tracks,window=3, degree=2, drop=False) + + assert len(joint_tracks)==len(tracks), 'Error when merging' + + assert len(joint_ids), 'No joint ids found' + + pass diff --git a/tests/extraction/test_volume.py b/tests/extraction/test_volume.py new file mode 100644 index 0000000000000000000000000000000000000000..019ad57ec242b73240febb5c15f32db8e26cfb35 --- /dev/null +++ b/tests/extraction/test_volume.py @@ -0,0 +1,74 @@ +import pytest +from skimage.morphology import disk, erosion +from skimage import draw +import numpy as np + +from extraction.core.functions.cell import volume +from extraction.core.functions.cell import min_maj_approximation +from extraction.core.functions.cell import eccentricity + +threshold = 0.01 +radii = list(range(10, 100, 10)) +circularities = np.arange(0.4, 1., 0.1) +eccentricities = np.arange(0, 0.9, 0.1) +rotations = [10, 20, 30, 40, 50, 60, 70, 80, 90] + + +def ellipse(x, y, rotate=0): + shape = (4 * x, 4 * y) + img = np.zeros(shape, dtype=np.uint8) + rr, cc = draw.ellipse(2 * x, 2 * y, x, y, rotation=np.deg2rad(rotate)) + img[rr, cc] = 1 + return img + + +def maj_from_min(min_ax, ecc): + y = np.sqrt(min_ax ** 2 / (1 - ecc ** 2)) + return np.round(y).astype(int) + + +@pytest.mark.parametrize('r', radii) +def test_volume_circular(r): + im = disk(r) + v = volume(im) + real_v = (4 * np.pi * r ** 3) / 3 + err = np.abs(v - real_v) / real_v + assert err < threshold + assert np.isclose(v, real_v, rtol=threshold * real_v) + + +@pytest.mark.parametrize('x', radii) +@pytest.mark.parametrize('ecc', eccentricities) +@pytest.mark.parametrize('rotation', rotations) +def test_volume_ellipsoid(x, ecc, rotation): + y = maj_from_min(x, ecc) + im = ellipse(x, y, rotation) + v = volume(im) + real_v = (4 * np.pi * x * y * x) / 3 + err = np.abs(v - real_v) / real_v + assert err < threshold + assert np.isclose(v, real_v, rtol=threshold * real_v) + return v, real_v + + +@pytest.mark.parametrize('x', radii) +@pytest.mark.parametrize('ecc', eccentricities) +@pytest.mark.parametrize('rotation', rotations) +def test_approximation(x, ecc, rotation): + y = maj_from_min(x, ecc) + im = ellipse(x, y, rotation) + min_ax, maj_ax = min_maj_approximation(im) + assert np.allclose([min_ax, maj_ax], [x, y], + rtol=threshold * min(np.array([x, y]))) + + +@pytest.mark.parametrize('x', radii) +@pytest.mark.parametrize('ecc', eccentricities) +@pytest.mark.parametrize('rotation', rotations) +def test_roundness(x, ecc, rotation): + y = maj_from_min(x, ecc) + real_ecc = np.sqrt(y ** 2 - x ** 2) / y + im = ellipse(x, y, rotation) + e = eccentricity(im) + assert np.isclose(real_ecc, e, rtol=threshold * real_ecc) + diff --git a/utils/__init__.py b/utils/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/utils/argo.py b/utils/argo.py new file mode 100644 index 0000000000000000000000000000000000000000..fcfa08fe5bac58d332b94b8ae1f22bf9f3cd48e8 
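test_volume.py relies on two closed-form facts about the prolate spheroid obtained by rotating an ellipse with semi-minor axis x and semi-major axis y about its major axis: its volume is V = 4*pi*x^2*y/3 (the `(4 * np.pi * x * y * x) / 3` ground truth in test_volume_ellipsoid) and its eccentricity is e = sqrt(y^2 - x^2)/y, which inverts to y = x/sqrt(1 - e^2) as computed by maj_from_min. A quick self-contained check of that round trip, with arbitrary example values:

import numpy as np

x, ecc = 20, 0.6                      # semi-minor axis and target eccentricity
y = x / np.sqrt(1 - ecc ** 2)         # semi-major axis, as maj_from_min computes (before rounding)
assert np.isclose(np.sqrt(y ** 2 - x ** 2) / y, ecc)  # eccentricity round-trips
true_volume = 4 / 3 * np.pi * x ** 2 * y              # prolate-spheroid volume used as ground truth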
--- /dev/null +++ b/utils/argo.py @@ -0,0 +1,527 @@ +import io +import operator +from pathlib import Path, PosixPath +from collections import Counter +from datetime import datetime +import re +import csv + +import numpy as np + +from tqdm import tqdm + +from logfile_parser import Parser +from omero.gateway import BlitzGateway, TagAnnotationWrapper + + +class OmeroExplorer: + def __init__(self, host, user, password, min_date=(2020, 6, 1)): + self.conn = BlitzGateway(user, password, host=host) + self.conn.connect() + + self.min_date = min_date + self.backups = {} + self.removed = [] + + @property + def cache(self): + if not hasattr(self, "_cache"): + self._cache = {v.id: get_annotsets(v) for v in self.dsets} + return self._cache + + @property + def raw_log(self): + return {k: v["log"] for k, v in self.cache.items()} + + @property + def raw_log_end(self): + if not hasattr(self, "_raw_log_end"): + self._raw_log_end = {d.id: get_logfile(d) for d in self.dsets} + return self._raw_log_end + + @property + def log(self): + return {k: parse_annot(v, "log") for k, v in self.raw_log.items()} + + @property + def raw_acq(self): + return {k: v["acq"] for k, v in self.cache.items()} + + @property + def acq(self): + return {k: parse_annot(v, "acq") for k, v in self.raw_acq.items()} + + def load(self, min_id=18000, min_date=None): + """ + :min_id: int + :min_date: tuple + """ + if min_date is None: + min_date = self.min_date + self._dsets_bak = [ + d for d in self.conn.getObjects("Dataset") if d.getId() > min_id + ] + + if min_date: + if len(min_date) < 3: + min_date = min_date + tuple([1 for i in range(3 - len(min_date))]) + min_date = datetime(*min_date) + + # sort by dates + dates = [d.getDate() for d in self._dsets_bak] + self._dsets_bak[:] = [a for b, a in sorted(zip(dates, self._dsets_bak))] + + self._dsets_bak = [d for d in self._dsets_bak if d.getDate() >= min_date] + + self.dsets = self._dsets_bak + self.n_dsets + + def dset(self, n): + try: + return [x for x in self.dsets if x.id == n][0] + except: + return + + def channels(self, setkey, present=True): + """ + :setkey: str indicating a set of channels + :present: bool indicating whether the search should or shold not be in the dset + """ + self.dsets = [ + v for v in self.acqs.values() if present == has_channels(v, setkey) + ] + self.n_dsets + + def update_cache(self): + if not hasattr(self, "acq") or not hasattr(self, "log"): + for attr in ["acq", "log"]: + print("Updating raw ", attr) + setattr( + self, + "raw_" + attr, + {v.id: get_annotsets(v)[attr] for v in self.dsets}, + ) + setattr( + self, + attr, + { + v.id: parse_annot(getattr(self, "raw_" + attr)[v.id], attr) + for v in self.dsets + }, + ) + else: + + for attr in ["acq", "log", "raw_acq", "raw_log"]: + setattr( + self, attr, {i.id: getattr(self, attr)[i.id] for i in self.dsets} + ) + + @property + def dsets(self): + if not hasattr(self, "_dsets"): + self._dsets = self.load() + + return self._dsets + + @dsets.setter + def dsets(self, dsets): + if hasattr(self, "_dsets"): + if self._dsets is None: + self._dsets = [] + self.removed += [ + x for x in self._dsets if x.id not in [y.id for y in dsets] + ] + + self._dsets = dsets + + def tags(self, tags, present=True): + """ + :setkey str tags to filter + """ + if type(tags) is not list: + tags = [str(tags)] + + self.dsets = [v for v in self.dsets if present == self.has_tags(v, tags)] + self.n_dsets + + @property + def all_tags(self): + if not hasattr(self, "_tags"): + self._tags = { + d.id: [ + x.getValue() + for x in d.listAnnotations() + if 
isinstance(x, TagAnnotationWrapper) + ] + for d in self.dsets + } + return self._tags + + def get_timepoints(self): + self.image_wrappers = {d.id: list(d.listChildren())[0] for d in self.dsets} + + return {k: i.getSizeT() for k, i in self.image_wrappers.items()} + + def timepoints(self, n, op="greater"): + "Filter experiments using the number of timepoints" + op = operator.gt if op == "greater" else operator.le + self._timepoints = self.get_timepoints() + + self.dsets = [v for v in tqdm(self.dsets) if op(self._timepoints[v.id], n)] + + def microscope(self, microscope): + self.microscopes = { + dset.id: self.get_microscope(self.log[dset.id]) for dset in self.dsets + } + + self.n_dsets + + def get_microscope(self, parsed_log): + return parsed_log["microscope"] + + def reset(self, backup_id=None): + self.dsets = self.backups.get(backup_id, self._dsets_bak) + self.n_dsets + + def backup(self, name): + self.backups[name] = self.dsets + + def reset_backup(self, name): + self.dsets = self.backups[name] + + def cExperiment(self, present=True): + self.dsets = [ + v + for v in self.dsets + if present + * sum( + [ + "cExperiment" in x.getFileName() + for x in v.listAnnotations() + if hasattr(x, "getFileName") + ] + ) + ] + self.n_dsets + + @staticmethod + def is_complete(logfile): + return logfile.endswith("Experiment completed\r\r\n") + + @staticmethod + def contains_regex(logfile): + pass + # return re. + + def tiler_cells(self, present=True): + self.__dsets = [v for v in self.dsets if present == tiler_cells_load(v)] + + @property + def n_dsets(self): + print("{} datasets.".format(len(self.dsets))) + + @property + def desc(self): + for d in self.dsets: + print( + "{}\t{}\t{}\t{}".format( + d.getDate().strftime("%x"), + d.getId(), + d.getName(), + d.getDescription(), + ) + ) + + @property + def ids(self): + return [d.getId() for d in self.dsets] + + # @property + # def acqs(self): + # if not hasattr(self, "_acqs") or len(self.__dict__.get("_acqs", [])) != len( + # self.dsets + # ): + # self._acqs = [get_annot(get_annotsets(d), "acq") for d in self.dsets] + # return self._acqs + + def get_ph_params(self): + t = [ + { + ch: [exp, v] + for ch, exp, v in zip(j["channel"], j["exposure"], j["voltage"]) + if ch in {"GFPFast", "pHluorin405"} + } + for j in [i["channels"] for i in self.acqs] + ] + + ph_param_pairs = [(tuple(x.values())) for x in t if np.all(list(x.values()))] + + return Counter([str(x) for x in ph_param_pairs]) + + def find_duplicate_candidates(self, days_tol=2): + # Find experiments with the same name or Aim and from similar upload dates + # and group them for cleaning + pass + + def group_by_date(tol=1): + dates = [x.getDate() for x in self.dsets] + distances = np.array( + [[abs(convert_to_hours(a - b)) for a in dates] for b in dates] + ) + return explore_booldiag(distances > tol, 0, []) + + @property + def complete(self): + self.completed = {k: self.is_complete(v) for k, v in self.raw_log_end.items()} + self.dsets = [dset for dset in self.dsets if self.completed[dset.id]] + return self.n_dsets + + def save(self, fname): + with open(fname + ".tsv", "w") as f: + writer = csv.writer(f, delimiter="\t") + for d in self.dsets: + writer.writerow( + [ + d.getDate().strftime("%x"), + d.getId(), + d.getName(), + d.getDescription(), + ] + ) + + @property + def positions(self): + return {x.id: len(list(x.listChildren())) for x in self.dsets} + + def has_tags(self, d, tags): + if set(tags).intersection(self.all_tags[d.id]): + return True + + +def explore_booldiag(bool_field, current_position, 
cluster_start_end): + # Recursively find the square clusters over the diagonal. Allows for duplicates + # returns a list of tuples with the start, end of clusters + if current_position < len(bool_field) - 1: + elements = np.where(bool_field[current_position]) + if len(elements[0]) > 1: + start = elements[0][0] + end = elements[0][-1] + else: + start = elements[0][0] + end = elements[0][0] + + cluster_start_end.append((start, end)) + return explore_square(bool_field, end + 1, cluster_start_end) + else: + return cluster_start_end + _ + + +def convert_to_hours(delta): + total_seconds = delta.total_seconds() + hours = int(total_seconds // 3600) + return hours + + +class Argo(OmeroExplorer): + def __init__(self,*args, **kwargs): + super().__init__(*args,**kwargs) + + +def get_creds(): + return ( + "upload", + "***REMOVED***", # OMERO Password + ) + + +def list_files(dset): + return {x for x in dset.listAnnotations() if hasattr(x, "getFileName")} + + +def annot_from_dset(dset, kind): + v = [x for x in dset.listAnnotations() if hasattr(x, "getFileName")] + infname = kind if kind is "log" else kind.title() + try: + acqfile = [x for x in v if x.getFileName().endswith(infname + ".txt")][0] + decoded = list(acqfile.getFileInChunks())[0].decode("utf-8") + acq = parse_annot(decoded, kind) + except: + return {} + + return acq + + +def check_channels(acq, channels, _all=True): + I = set(acq["channels"]["channel"]).intersection(channels) + + condition = False + if _all: + if len(I) == len(channels): + condition = True + else: + if len(I): + condition = True + + return condition + + +def get_chs(exptype): + exptypes = { + "dual_ph": ("GFP", "pHluorin405", "mCherry"), + "ph": ("GFP", "pHluorin405"), + "dual": ("GFP", "mCherry"), + "mCherry": ("mCherry"), + } + return exptypes[exptype] + + +def load_annot_from_cache(exp_id, cache_dir="cache/"): + if type(cache_dir) is not PosixPath: + cache_dir = Path(cache_dir) + + annot_sets = {} + for fname in cache_dir.joinpath(exp_id).rglob("*.txt"): + fmt = fname.name[:3] + with open(fname, "r") as f: + annot_sets[fmt] = f.read() + + return annot_sets + + +def get_annot(annot_sets, fmt): + """ + Get parsed annotation file + """ + str_io = annot_sets.get(fmt, None) + return parse_annot(str_io, fmt) + + +def parse_annot(str_io, fmt): + parser = Parser("multiDGUI_" + fmt + "_format") + return parser.parse(io.StringIO(str_io)) + + +def get_log_date(annot_sets): + log = get_annot(annot_sets, "log") + return log.get("date", None) + + +def get_log_microscope(annot_sets): + log = get_annot(annot_sets, "log") + return log.get("microscope", None) + + +def get_annotsets(dset): + annot_files = [ + annot.getFile() for annot in dset.listAnnotations() if hasattr(annot, "getFile") + ] + annot_sets = { + ftype[:-4].lower(): annot + for ftype in ("log.txt", "Acq.txt", "Pos.txt") + for annot in annot_files + if ftype in annot.getName() + } + annot_sets = { + key: list(a.getFileInChunks())[0].decode("utf-8") + for key, a in annot_sets.items() + } + return annot_sets + + +# def has_tags(d, tags): +# if set(tags).intersection(annot_from_dset(d, "log").get("omero_tags", [])): +# return True + + +def load_acq(dset): + try: + acq = annot_from_dset(dset, kind="acq") + return acq + except: + print("dset", dset.getId(), " failed acq loading") + return False + + +def has_channels(dset, exptype): + acq = load_acq(dset) + if acq: + return check_channels(acq, get_chs(exptype)) + else: + return + + +def get_id_from_name(exp_name, conn=None): + if conn is None: + conn = BlitzGateway(*get_creds(), 
host="islay.bio.ed.ac.uk", port=4064) + + if not conn.isConnected(): + conn.connect() + + cand_dsets = [ + d + for d in conn.getObjects("Dataset") # , opts={'offset': 10600, + # 'limit':500}) + if exp_name in d.name + ] # increase the offset for better speed + + # return cand_dsets + if len(cand_dsets) > 1: + # Get date and try to find it using date and microscope name and date + + # found = [] + # for cand in cand_dsets: + # annot_sets = get_annotsets(cand) + # date = get_log_date(annot_sets) + # microscope = get_log_microscope(annot_sets) + # if date==date_name and microscope == microscope_name: + # found.append(cand) + + # if True:#len(found)==1: + # return best_cand.id#best_cand = found[0] + if True: + + print("Multiple options found. Selecting the one with most children") + + max_dset = np.argmax( + [ + len(list(conn.getObject("Dataset", c.id).listChildren())) + for c in cand_dsets + ] + ) + + best_cand = cand_dsets[max_dset] + + return best_cand.id + elif len(cand_dsets) == 1: + return cand_dsets[0].id + + +# Custom functions +def compare_dsets_voltages_exp(dsets): + a = {} + for d in dsets: + try: + acq = annot_from_dset(d, kind="acq")["channels"] + a[d.getId()] = { + k: (v, e) + for k, v, e in zip(acq["channel"], acq["voltage"], acq["exposure"]) + } + + except: + print(d, "didnt work") + + return a + + +def get_logfile(dset): + annot_file = [ + annot.getFile() + for annot in dset.listAnnotations() + if hasattr(annot, "getFile") and annot.getFileName().endswith("log.txt") + ][0] + return list(annot_file.getFileInChunks())[-1].decode("utf-8") + + +# 19920 -> 19300/19310 +#