Skip to content
Snippets Groups Projects
extractor.py 19.7 KiB
Newer Older
import typing as t
from time import perf_counter
from typing import Callable, Dict, List
Alán Muñoz's avatar
Alán Muñoz committed
import numpy as np
import pandas as pd
from agora.abc import ParametersABC, ProcessABC
from agora.io.cells import Cells
from agora.io.writer import Writer, load_attributes
from aliby.tile.tiler import Tiler
from extraction.core.functions.defaults import exparams_from_meta
from extraction.core.functions.distributors import reduce_z, trap_apply
Alán Muñoz's avatar
Alán Muñoz committed
from extraction.core.functions.loaders import (
    load_custom_args,
Alán Muñoz's avatar
Alán Muñoz committed
    load_mergefuns,
Alán Muñoz's avatar
Alán Muñoz committed
)
from extraction.core.functions.utils import depth

CELL_FUNS, TRAPFUNS, FUNS = load_funs()
CUSTOM_FUNS, CUSTOM_ARGS = load_custom_args()
RED_FUNS = load_redfuns()
MERGE_FUNS = load_mergefuns()

# Assign datatype depending on the metric used
Alán Muñoz's avatar
Alán Muñoz committed
# m2type = {"mean": np.float32, "median": np.ubyte, "imBackground": np.ubyte}
Alán Muñoz's avatar
Alán Muñoz committed


class ExtractorParameters(ParametersABC):
    """
    Base class to define parameters for extraction
    """

    def __init__(
        self,
        tree: Dict[str, Dict[Callable, List[str]]] = None,
Alán Muñoz's avatar
Alán Muñoz committed
        sub_bg: set = set(),
        multichannel_ops: Dict = {},
    ):
pswain's avatar
pswain committed
        """
        Parameters
        ----------
        tree: dict
            Nested dictionary indicating channels, reduction functions and
            metrics to be used.
            str channel -> U(function,None) reduction -> str metric
            If not of depth three, tree will be filled with Nones.
        sub_bg: set
        multichannel_ops: dict
        """
Alán Muñoz's avatar
Alán Muñoz committed
        self.tree = fill_tree(tree)
        self.sub_bg = sub_bg
        self.multichannel_ops = multichannel_ops

    @staticmethod
    def guess_from_meta(store_name: str, suffix="fast"):
        """
pswain's avatar
pswain committed
        Find the microscope used from the h5 metadata
pswain's avatar
pswain committed
        Parameters
        ----------
        store_name : str or Path
            For a h5 file
        suffix : str
            Added at the end of the predicted parameter set
Alán Muñoz's avatar
Alán Muñoz committed
        """
pswain's avatar
pswain committed
        with h5py.File(store_name, "r") as f:
            microscope = f["/"].attrs.get("microscope")
Alán Muñoz's avatar
Alán Muñoz committed
        assert microscope, "No metadata found"
        return "_".join((microscope, suffix))

    @classmethod
    def default(cls):
        return cls({})

    @classmethod
    def from_meta(cls, meta):
        return cls(**exparams_from_meta(meta))


class Extractor(ProcessABC):
    """
pswain's avatar
pswain committed
    The Extractor applies a metric, such as area or median, to cells identified in the image tiles using the cell masks.

    Its methods therefore require both tile images and masks.

    Usually one metric is applied per mask, but there are tile-specific backgrounds (Alan), which apply one metric per tile.

    Extraction follows a three-level tree structure. Channels, such as GFP, are the root level; the second level is the reduction algorithm, such as maximum projection; the last level is the metric - the specific operation to apply to the cells in the image identified by the mask, such as median, which is the median value of the pixels in each cell.
pswain's avatar
pswain committed
    parameters: core.extractor Parameters
        Parameters that include with channels, reduction and
        extraction functions to use.
    store: str
        Path to hdf5 storage file. Must contain cell outlines.
    tiler: pipeline-core.core.segmentation tiler
        Class that contains or fetches the image to be used for segmentation.
pswain's avatar
pswain committed
    default_meta = {
        "pixel_size": 0.236,
        "z_size": 0.6,
        "spacing": 0.6,
    }
Alán Muñoz's avatar
Alán Muñoz committed
    def __init__(
        self,
        parameters: ExtractorParameters,
        store: str = None,
        tiler: Tiler = None,
Alán Muñoz's avatar
Alán Muñoz committed
    ):
pswain's avatar
pswain committed
        """
        Initialise Extractor.

        Parameters
        ----------
        parameters: ExtractorParameters object
        store: str
            Name of h5 file
        tiler: Tiler object
        """
Alán Muñoz's avatar
Alán Muñoz committed
        self.params = parameters
Alán Muñoz's avatar
Alán Muñoz committed
        if store:
            self.local = store
            self.load_meta()
pswain's avatar
pswain committed
        else:
            # if no h5 file, use the parameters directly
Alán Muñoz's avatar
Alán Muñoz committed
            self.meta = {"channel": parameters.to_dict()["tree"].keys()}
        if tiler:
            self.tiler = tiler
Alán Muñoz's avatar
Alán Muñoz committed
        self.load_funs()

    @classmethod
pswain's avatar
pswain committed
        cls,
        parameters: ExtractorParameters,
        store: str,
        tiler: Tiler,
pswain's avatar
pswain committed
        # initate from tiler
Alán Muñoz's avatar
Alán Muñoz committed
        return cls(parameters, store=store, tiler=tiler)

    @classmethod
pswain's avatar
pswain committed
        cls,
        parameters: ExtractorParameters,
        store: str,
        img_meta: tuple,
pswain's avatar
pswain committed
        # initiate from image
Alán Muñoz's avatar
Alán Muñoz committed
        return cls(parameters, store=store, tiler=Tiler(*img_meta))

    @property
    def channels(self):
pswain's avatar
pswain committed
        # returns a tuple of strings of the available channels
Alán Muñoz's avatar
Alán Muñoz committed
        if not hasattr(self, "_channels"):
            if type(self.params.tree) is dict:
                self._channels = tuple(self.params.tree.keys())
        return self._channels

    @property
    def current_position(self):
        return self.local.split("/")[-1][:-3]

    @property
pswain's avatar
pswain committed
    def group(self):
        # returns path within h5 file
Alán Muñoz's avatar
Alán Muñoz committed
        if not hasattr(self, "_out_path"):
            self._group = "/extraction/"
        return self._group

    def load_custom_funs(self):
        """
pswain's avatar
pswain committed
        Load any necessary parameters for functions.
        These parameters must be loaded within an Extractor instance because they depend on the experiment's metadata.

        Alan: this is challenging to follow!
Alán Muñoz's avatar
Alán Muñoz committed
        """
        funs = set(
            [
                fun
                for ch in self.params.tree.values()
                for red in ch.values()
                for fun in red
            ]
        )
        funs = funs.intersection(CUSTOM_FUNS.keys())
        ARG_VALS = {
            k: {k2: self.get_meta(k2) for k2 in v}
            for k, v in CUSTOM_ARGS.items()
Alán Muñoz's avatar
Alán Muñoz committed
        }
        # self._custom_funs = {trap_apply(CUSTOM_FUNS[fun],])
        self._custom_funs = {}
        for k, f in CUSTOM_FUNS.items():

            def tmp(f):
                return lambda m, img: trap_apply(
                    f, m, img, **ARG_VALS.get(k, {})
                )
Alán Muñoz's avatar
Alán Muñoz committed

            self._custom_funs[k] = tmp(f)

    def load_funs(self):
        self.load_custom_funs()
        self._all_cell_funs = set(self._custom_funs.keys()).union(CELL_FUNS)
        self._all_funs = {**self._custom_funs, **FUNS}

    def load_meta(self):
pswain's avatar
pswain committed
        # load metadata from h5 file whose name is given by self.local
Alán Muñoz's avatar
Alán Muñoz committed
        self.meta = load_attributes(self.local)

    def get_traps(
        self, tp: int, channels: list = None, z: list = None, **kwargs
    ) -> tuple:
pswain's avatar
pswain committed
        """
        Finds traps for a given time point and given channels and z-stacks.
        Returns None if no traps are found.

        Any additional keyword arguments are passed to tiler.get_traps_timepoint

        Parameters
        ----------
        tp: int
            Time point of interest
        channels: list of strings (optional)
            Channels of interest
        z: list of integers (optional)
            Indices for the z-stacks of interest
        """
Alán Muñoz's avatar
Alán Muñoz committed
        if channels is None:
pswain's avatar
pswain committed
            # find channels from tiler
Alán Muñoz's avatar
Alán Muñoz committed
            channel_ids = list(range(len(self.tiler.channels)))
        elif len(channels):
pswain's avatar
pswain committed
            # a subset of channels was specified
Alán Muñoz's avatar
Alán Muñoz committed
            channel_ids = [self.tiler.get_channel_index(ch) for ch in channels]
        else:
pswain's avatar
pswain committed
            # oh oh
Alán Muñoz's avatar
Alán Muñoz committed
            channel_ids = None
pswain's avatar
pswain committed
        # a list of the indices of the z stacks
Alán Muñoz's avatar
Alán Muñoz committed
        if z is None:
            z = list(range(self.tiler.shape[-1]))
pswain's avatar
pswain committed
        # find the appropiate traps from tiler
Alán Muñoz's avatar
Alán Muñoz committed
        traps = (
            self.tiler.get_traps_timepoint(
                tp, channels=channel_ids, z=z, **kwargs
            )
Alán Muñoz's avatar
Alán Muñoz committed
            if channel_ids
            else None
        )
        return traps

    def extract_traps(
        self,
        traps: List[np.array],
        masks: List[np.array],
        metric: str,
        labels: List[int] = None,
    ) -> dict:
        """
        Apply a function for a whole position.

        :traps: List[np.array] list of images
        :masks: List[np.array] list of masks
        :metric:str metric to extract
        :labels: List[int] cell Labels to use as indices for output DataFrame
        :pos_info: bool Whether to add the position as index or not

        returns
        :d: Dictionary of dataframe
        """

        if labels is None:
            raise Warning("No labels given. Sorting cells using index.")

        cell_fun = True if metric in self._all_cell_funs else False

        idx = []
        results = []

        for trap_id, (mask_set, trap, lbl_set) in enumerate(
            zip(masks, traps, labels.values())
        ):
            if len(mask_set):  # ignore empty traps
                result = self._all_funs[metric](mask_set, trap)
                if cell_fun:
                    for lbl, val in zip(lbl_set, result):
                        results.append(val)
                        idx.append((trap_id, lbl))
                else:
                    results.append(result)
                    idx.append(trap_id)

        return (tuple(results), tuple(idx))

    def extract_funs(
        self,
        traps: List[np.array],
        masks: List[np.array],
        metrics: List[str],
        **kwargs,
Alán Muñoz's avatar
Alán Muñoz committed
    ) -> dict:
        """
        Extract multiple metrics from a timepoint
        """
        d = {
            metric: self.extract_traps(
                traps=traps, masks=masks, metric=metric, **kwargs
            )
            for metric in metrics
        }
        return d

    def reduce_extract(
        self,
        traps: np.array,
        masks: list,
        red_metrics: dict,
        **kwargs,
Alán Muñoz's avatar
Alán Muñoz committed
    ) -> dict:
        """
Alán Muñoz's avatar
Alán Muñoz committed
        Wrapper to apply reduction and then extraction.

        Parameters
        ----------
Alán Muñoz's avatar
Alán Muñoz committed
        :param red_metrics: dict in which keys are reduction funcions and
        values are strings indicating the metric function
        :**kwargs: All other arguments, must include masks and traps.
Alán Muñoz's avatar
Alán Muñoz committed

        Returns
        ------
        Dictionary of dataframes with the corresponding reductions and metrics nested.

Alán Muñoz's avatar
Alán Muñoz committed
        """

        reduced_traps = {}
        if traps is not None:
            for red_fun in red_metrics.keys():
                reduced_traps[red_fun] = [
                    self.reduce_dims(trap, method=RED_FUNS[red_fun])
                    for trap in traps
Alán Muñoz's avatar
Alán Muñoz committed
                ]

        d = {
            red_fun: self.extract_funs(
                metrics=metrics,
                traps=reduced_traps.get(red_fun, [None for _ in masks]),
                masks=masks,
                **kwargs,
            )
            for red_fun, metrics in red_metrics.items()
        }
        return d

    def reduce_dims(self, img: np.array, method=None) -> np.array:
Alán Muñoz's avatar
Alán Muñoz committed
        """
        Collapse a z-stack into a single file. It may perform a null operation.
        """
Alán Muñoz's avatar
Alán Muñoz committed
        if method is None:
            return img
pswain's avatar
pswain committed
        else:
            return reduce_z(img, method)
Alán Muñoz's avatar
Alán Muñoz committed

    def extract_tp(
        self,
        tp: int,
        tree: dict = None,
        tile_size: int = 117,
        masks=None,
        labels=None,
        **kwargs,
    ) -> t.Dict[str, t.Dict[str, pd.Series]]:
        """
        Core extraction method for an individual time-point.

        Parameters
        ----------
        tp : int
            Time point being analysed.
        tree : dict
            Nested dictionary indicating channels, reduction functions and
            metrics to be used.
        tile_size : int
            size of the tile to be extracted.
        masks : np.ndarray
            A 3-D boolean numpy array with dimensions (ncells, tile_size,
            tile_size).
        labels : t.List[t.List[int]]
            List of lists of ints indicating the ids of masks.
        **kwargs : Additional keyword arguments to be passed to extractor.reduce_extract.

        Returns
        -------
        dict

        Examples
        --------
        FIXME: Add docs.

Alán Muñoz's avatar
Alán Muñoz committed

Alán Muñoz's avatar
Alán Muñoz committed
        """
        if tree is None:
            # use default
Alán Muñoz's avatar
Alán Muñoz committed
            tree = self.params.tree
        # dictionary with channel: {reduction algorithm : metric}
Alán Muñoz's avatar
Alán Muñoz committed
        ch_tree = {ch: v for ch, v in tree.items() if ch != "general"}
        # tuple of the channels
Alán Muñoz's avatar
Alán Muñoz committed
        tree_chs = (*ch_tree,)
pswain's avatar
pswain committed
        # create a Cells object to extract information from the h5 file
        cells = Cells(self.local)
pswain's avatar
pswain committed
        # find the cell labels and store as dict with trap_ids as keys
        if labels is None:
            raw_labels = cells.labels_at_time(tp)
            labels = {
                trap_id: raw_labels.get(trap_id, [])
                for trap_id in range(cells.ntraps)
pswain's avatar
pswain committed
        # find the cell masks as a dict with trap_ids as keys
        if masks is None:
            raw_masks = cells.at_time(tp, kind="mask")
            masks = {trap_id: [] for trap_id in range(cells.ntraps)}
            for trap_id, cells in raw_masks.items():
                if len(cells):
                    masks[trap_id] = np.dstack(np.array(cells)).astype(bool)
pswain's avatar
pswain committed
        # convert to a list of masks
Alán Muñoz's avatar
Alán Muñoz committed
        masks = [np.array(v) for v in masks.values()]

pswain's avatar
pswain committed
        # find traps at the time point
        traps = self.get_traps(tp, tile_shape=tile_size, channels=tree_chs)
Alán Muñoz's avatar
Alán Muñoz committed

        self.img_bgsub = {}
        if self.params.sub_bg:
pswain's avatar
pswain committed
            # generate boolean masks for background
Alán Muñoz's avatar
Alán Muñoz committed
            bg = [
                ~np.sum(m, axis=2).astype(bool)
                if np.any(m)
                else np.zeros((tile_size, tile_size))
                for m in masks
            ]

        d = {}
        for ch, red_metrics in tree.items():
            img = None
            # ch != is necessary for threading
            if ch != "general" and traps is not None and len(traps):
                img = traps[:, tree_chs.index(ch), 0]

            d[ch] = self.reduce_extract(
                red_metrics=red_metrics,
                traps=img,
                masks=masks,
                labels=labels,
                **kwargs,
Alán Muñoz's avatar
Alán Muñoz committed
            )

            if (
                ch in self.params.sub_bg and img is not None
            ):  # Calculate metrics with subtracted bg
                ch_bs = ch + "_bgsub"

                self.img_bgsub[ch_bs] = []
                for trap, maskset in zip(img, bg):

                    cells_fl = np.zeros_like(trap)

                    is_cell = np.where(maskset)
                    if len(is_cell[0]):  # skip calculation for empty traps
                        cells_fl = np.median(trap[is_cell], axis=0)

                    self.img_bgsub[ch_bs].append(trap - cells_fl)

                d[ch_bs] = self.reduce_extract(
                    red_metrics=ch_tree[ch],
                    traps=self.img_bgsub[ch_bs],
                    masks=masks,
                    labels=labels,
                    **kwargs,
                )

        # Additional operations between multiple channels (e.g. pH calculations)
        for name, (
            chs,
            merge_fun,
            red_metrics,
        ) in self.params.multichannel_ops.items():
Alán Muñoz's avatar
Alán Muñoz committed
            if len(
                set(chs).intersection(
                    set(self.img_bgsub.keys()).union(tree_chs)
                )
Alán Muñoz's avatar
Alán Muñoz committed
            ) == len(chs):
                imgs = [self.get_imgs(ch, traps, tree_chs) for ch in chs]
                merged = MERGE_FUNS[merge_fun](*imgs)
                d[name] = self.reduce_extract(
                    red_metrics=red_metrics,
                    traps=merged,
                    masks=masks,
                    labels=labels,
                    **kwargs,
                )

        return d

    def get_imgs(self, channel, traps, channels=None):
        """
        Returns the image from a correct source, either raw or bgsub

        :channel: str name of channel to get
        :img: ndarray (trap_id, channel, tp, tile_size, tile_size, n_zstacks) of standard channels
        :channels: List of channels
        """

        if channels is None:
            channels = (*self.params.tree,)

        if channel in channels:
            return traps[:, channels.index(channel), 0]
        elif channel in self.img_bgsub:
            return self.img_bgsub[channel]

Alán Muñoz's avatar
Alán Muñoz committed
    def run_tp(self, tp, **kwargs):
        """
        Wrapper to add compatiblibility with other pipeline steps
        """
        return self.run(tps=[tp], **kwargs)

    def run(
        self, tree=None, tps: List[int] = None, save=True, **kwargs
    ) -> dict:
Alán Muñoz's avatar
Alán Muñoz committed

        if tree is None:
            tree = self.params.tree

        if tps is None:
            tps = list(range(self.meta["time_settings/ntimepoints"][0]))
Alán Muñoz's avatar
Alán Muñoz committed

        d = {}
        for tp in tps:
            new = flatten_nest(
                self.extract_tp(tp=tp, tree=tree, **kwargs),
                to="series",
                tp=tp,
            )

            for k in new.keys():
                n = new[k]
                d[k] = pd.concat((d.get(k, None), n), axis=1)

        for k in d.keys():
            indices = ["experiment", "position", "trap", "cell_label"]
            idx = (
                indices[-d[k].index.nlevels :]
                if d[k].index.nlevels > 1
                else [indices[-2]]
            )
            d[k].index.names = idx

            toreturn = d

        if save:
            self.save_to_hdf(toreturn)

        return toreturn

    def extract_pos(
        self, tree=None, tps: List[int] = None, save=True, **kwargs
    ) -> dict:

        if tree is None:
            tree = self.params.tree

        if tps is None:
            tps = list(range(self.meta["time_settings/ntimepoints"]))

        d = {}
        for tp in tps:
            new = flatten_nest(
                self.extract_tp(tp=tp, tree=tree, **kwargs),
                to="series",
                tp=tp,
            )

            for k in new.keys():
                n = new[k]
                d[k] = pd.concat((d.get(k, None), n), axis=1)

        for k in d.keys():
            indices = ["experiment", "position", "trap", "cell_label"]
            idx = (
                indices[-d[k].index.nlevels :]
                if d[k].index.nlevels > 1
                else [indices[-2]]
            )
            d[k].index.names = idx

            toreturn = d

        if save:
            self.save_to_hdf(toreturn)

        return toreturn

    def save_to_hdf(self, group_df, path=None):
        if path is None:
            path = self.local

        self.writer = Writer(path)
        for path, df in group_df.items():
            dset_path = "/extraction/" + path
            self.writer.write(dset_path, df)
        self.writer.id_cache.clear()

    def get_meta(self, flds):
        if not hasattr(flds, "__iter__"):
            flds = [flds]
        meta_short = {k.split("/")[-1]: v for k, v in self.meta.items()}
        return {
            f: meta_short.get(f, self.default_meta.get(f, None)) for f in flds
        }
Alán Muñoz's avatar
Alán Muñoz committed


### Helpers
def flatten_nest(nest: dict, to="series", tp: int = None) -> dict:
    """
    Convert a nested extraction dict into a dict of series
    :param nest: dict contained the nested results of extraction
    :param to: str = 'series' Determine output format, either list or  pd.Series
    :param tp: int timepoint used to name the series
    """

    d = {}
    for k0, v0 in nest.items():
        for k1, v1 in v0.items():
            for k2, v2 in v1.items():
                d["/".join((k0, k1, k2))] = (
                    pd.Series(*v2, name=tp) if to == "series" else v2
                )

    return d


def fill_tree(tree):
    if tree is None:
        return None
    tree_depth = depth(tree)
    if depth(tree) < 3:
        d = {None: {None: {None: []}}}
        for _ in range(2 - tree_depth):
            d = d[None]
        d[None] = tree
        tree = d
    return tree
Alán Muñoz's avatar
Alán Muñoz committed


class hollowExtractor(Extractor):
    """Extractor that only cares about receiving image and masks,
    used for testing.
    """

    def __init__(self, parameters):
        self.params = parameters