"""extractor.py: apply metrics to the cells identified in image tiles."""
import typing as t
from pathlib import Path

import bottleneck as bn
import h5py
import numpy as np
import pandas as pd

import aliby.global_parameters as global_parameters
from agora.abc import ParametersABC, StepABC
from agora.io.cells import Cells
from agora.io.writer import Writer, load_meta
from aliby.tile.tiler import Tiler, find_channel_name
from extraction.core.functions.distributors import reduce_z, trap_apply
from extraction.core.functions.loaders import (
    load_custom_args,
    load_funs,
    load_redfuns,
)

# Type aliases.
# A reduction is given as a callable, as a name, or as None.
reduction_method = t.Union[t.Callable, str, None]
# channel -> reduction -> collection of metrics
extraction_tree = t.Dict[
    str, t.Dict[reduction_method, t.Dict[str, t.Collection]]
]
extraction_result = t.Dict[
    str, t.Dict[reduction_method, t.Dict[str, t.Dict[str, pd.Series]]]
]

# Global variables used to load functions that either analyse cells
# or their background. These global variables both allow the functions
# to be stored in a dictionary for access only on demand and to be
# defined simply in extraction/core/functions.
CELL_FUNS, TRAP_FUNS, ALL_FUNS = load_funs()
CUSTOM_FUNS, CUSTOM_ARGS = load_custom_args()
REDUCTION_FUNS = load_redfuns()


def extraction_params_from_meta(
    meta: t.Union[dict, Path, str], extras: t.Collection[str] = ("ph",)
):
    """
    Obtain parameters for extraction from meta data.

    Parameters
    ----------
    meta: dict, Path or str
        Either the meta data as a dict or the path to an h5 file, in
        which case the attributes of the file's root group are used.
    extras: collection of str (optional)
        Not used by this function; kept so that existing calls still work.

    Returns
    -------
    base: dict
        A dict with keys "tree" (the extraction tree), "multichannel_ops"
        (empty) and "sub_bg" (a list of the fluorescence channels found
        in the meta data).
    """
    if not isinstance(meta, dict):
        # load meta data
        with h5py.File(meta, "r") as f:
            meta = dict(f["/"].attrs.items())
    base = {
        "tree": {"general": {"None": ["area", "volume", "eccentricity"]}},
        "multichannel_ops": {},
    }
    candidate_channels = set(global_parameters.possible_imaging_channels)
    default_reductions = {"max"}
    default_metrics = set(global_parameters.fluorescence_functions)
    # NOTE: to also extract nuc_conv_3d (slow), add a "None": ["nuc_conv_3d"]
    # entry to the reductions of each channel below
    channel_names = meta.get("channels", [])
    extant_fluorescence_ch = []
    for av_channel in candidate_channels:
        # find matching channels in metadata
        found_channel = find_channel_name(channel_names, av_channel)
        if found_channel is not None:
            extant_fluorescence_ch.append(found_channel)
    for ch in extant_fluorescence_ch:
        # give each channel its own containers so that editing the tree
        # of one channel cannot silently change the others
        base["tree"][ch] = {
            reduction: set(default_metrics)
            for reduction in default_reductions
        }
    base["sub_bg"] = extant_fluorescence_ch
    return base


class ExtractorParameters(ParametersABC):
    """Base class to define parameters for extraction."""

    def __init__(
        self,
        tree: extraction_tree,
        sub_bg: t.Optional[set] = None,
        multichannel_ops: t.Optional[t.Dict] = None,
    ):
        """
        Define the parameters.

        Parameters
        ----------
        tree: dict
            Nested dictionary indicating channels, reduction functions and
            metrics to be used.
            str channel -> U(function, None) reduction -> str metric
            NOTE(review): the original text says a tree not of depth three
            "will be filled with None"; nothing in this class does that,
            so confirm where it happens.
        sub_bg: set (optional)
            Channels for which background-subtracted images are wanted.
            A new empty set is used if None.
        multichannel_ops: dict (optional)
            Operations needing several channels; the Extractor unpacks
            each value as (channels, reduction, function name).
            A new empty dict is used if None.
        """
        self.tree = tree
        # create the containers here, not in the signature: a default
        # written in the signature is shared by every instance, and the
        # Extractor edits these attributes in place
        self.sub_bg = set() if sub_bg is None else sub_bg
        self.multichannel_ops = (
            {} if multichannel_ops is None else multichannel_ops
        )

    @classmethod
    def default(cls):
        """Instantiate with an empty extraction tree."""
        return cls({})

    @classmethod
    def from_meta(cls, meta):
        """Instantiate from the meta data; used by Pipeline."""
        return cls(**extraction_params_from_meta(meta))


class Extractor(StepABC):
    """
    Apply a metric to cells identified in the tiles.

    Using the cell masks, the Extractor applies a metric, such as
    area or median, to cells identified in the image tiles.

    Its methods require both tile images and masks.

    Usually the metric is applied to only a tile's masked area, but
    some metrics depend on the whole tile.

    Extraction follows a three-level tree structure. Channels, such
    as GFP, are the root level; the reduction algorithm, such as
    maximum projection, is the second level; the specific metric,
    or operation, to apply to the masks, such as mean, is the third
    or leaf level.
    """

    # values that get_meta falls back on when a field is not in self.meta
    default_meta = global_parameters.imaging_specifications

    def __init__(
        self,
        parameters: ExtractorParameters,
        store: t.Optional[str] = None,
        tiler: t.Optional[Tiler] = None,
    ):
        """
        Initialise Extractor.

        Parameters
        ----------
        parameters: core.extractor Parameters
            Parameters that include the channels, reduction and
            extraction functions.
        store: str
            Path to the h5 file containing the cell masks.
        tiler: pipeline-core.core.segmentation tiler
            Class that contains or fetches the images used for
            segmentation.
        """
        self.params = parameters
        if store:
            self.h5path = store
            self.meta = load_meta(self.h5path)
        else:
            # if no h5 file, use the parameters directly
            self.meta = {"channel": parameters.to_dict()["tree"].keys()}
        if tiler:
            self.tiler = tiler
            available_channels = set((*tiler.channels, "general"))
            # only extract for channels available
            self.params.tree = {
                k: v
                for k, v in self.params.tree.items()
                if k in available_channels
            }
            self.params.sub_bg = available_channels.intersection(
                self.params.sub_bg
            )
            # add background subtracted channels to those available
            available_channels_bgsub = available_channels.union(
                [c + "_bgsub" for c in self.params.sub_bg]
            )
            # keep only multichannel operations whose input channels all
            # exist; build a new dict because removing entries from a dict
            # while iterating over it raises a RuntimeError
            self.params.multichannel_ops = {
                op: (input_ch, reduction, fun)
                for op, (
                    input_ch,
                    reduction,
                    fun,
                ) in self.params.multichannel_ops.items()
                if set(input_ch).issubset(available_channels_bgsub)
            }
        self.load_funs()

    @classmethod
    def from_tiler(
        cls,
        parameters: ExtractorParameters,
        store: str,
        tiler: Tiler,
    ):
        """Initiate from a tiler instance."""
        return cls(parameters, store=store, tiler=tiler)

    @classmethod
    def from_img(
        cls,
        parameters: ExtractorParameters,
        store: str,
        img_meta: tuple,
    ):
        """
        Initiate from images.

        The items of img_meta are passed as positional arguments to Tiler.
        """
        return cls(parameters, store=store, tiler=Tiler(*img_meta))

    @property
    def channels(self):
        """
        Get a tuple of the available channels.

        The tuple is made from the keys of params.tree the first time the
        property is read and is stored for later reads. If params.tree is
        not a dict, nothing is stored and an AttributeError is raised.
        """
        if not hasattr(self, "_channels"):
            # isinstance also accepts subclasses of dict, such as OrderedDict
            if isinstance(self.params.tree, dict):
                self._channels = tuple(self.params.tree.keys())
        return self._channels

    @property
    def current_position(self):
        """
        Return position being analysed.

        This is the last "/"-separated component of h5path without its
        file extension, so "exp/pos001.h5" gives "pos001".
        """
        file_name = str(self.h5path).split("/")[-1]
        # drop whatever extension is present rather than assuming that it
        # is exactly three characters long
        stem, dot, _ = file_name.rpartition(".")
        return stem if dot else file_name

    @property
    def group(self):
        """Return out path to write in the h5 file."""
        # test for the attribute that is actually stored; the test used to
        # look for "_out_path", which is never set, so the stored value
        # was overwritten on every access
        if not hasattr(self, "_group"):
            self._group = "/extraction/"
        return self._group

    def load_funs(self):
        """Define all functions, including custom ones."""
        self.load_custom_funs()
        self.all_cell_funs = set(self.custom_funs.keys()).union(CELL_FUNS)
        # merge the two dicts; for a name present in both, ALL_FUNS wins
        self.all_funs = {**self.custom_funs, **ALL_FUNS}

    def load_custom_funs(self):
        """
        Incorporate extra arguments of custom functions into their definitions.

        Normal functions only have cell_masks and trap_image as their
        arguments, and here custom functions are made the same by
        setting the values of their extra arguments.

        Any other parameters are taken from the experiment's metadata
        and automatically applied. These parameters therefore must be
        loaded within an Extractor instance.
        """
        # find functions specified in params.tree
        funs = {
            fun
            for channel in self.params.tree.values()
            for reduction in channel.values()
            for fun in reduction
        }
        # consider only those already loaded from CUSTOM_FUNS
        # NOTE(review): funs is not used below; a wrapper is made for every
        # function in CUSTOM_FUNS whether or not the tree asks for it
        funs = funs.intersection(CUSTOM_FUNS.keys())
        # find their arguments
        # NOTE(review): get_meta returns a dict, so each argument's value
        # here is {name: value} rather than value - confirm this is what
        # the custom functions expect
        self.custom_arg_vals = {
            k: {k2: self.get_meta(k2) for k2 in v}
            for k, v in CUSTOM_ARGS.items()
        }
        # define custom functions
        self.custom_funs = {}
        for k, f in CUSTOM_FUNS.items():

            def tmp(f, k=k):
                # pass extra arguments to custom function
                # return a function of cell_masks and trap_image
                # k is bound as a default so that each wrapper looks up its
                # own arguments, not those of the last function in the loop
                return lambda cell_masks, trap_image: trap_apply(
                    f,
                    cell_masks,
                    trap_image,
                    **self.custom_arg_vals.get(k, {}),
                )

            self.custom_funs[k] = tmp(f)

    def get_tiles(
        self,
        tp: int,
        channels: t.Optional[t.List[t.Union[str, int]]] = None,
        z: t.Optional[t.List[int]] = None,
    ) -> t.Optional[np.ndarray]:
        """
        Find tiles for a given time point, channels, and z-stacks.

        The image data is fetched with tiler.get_tiles_timepoint.

        Parameters
        ----------
        tp: int
            Time point of interest.
        channels: list of strings (optional)
            Channels of interest. If None, all of the tiler's channels
            are used; if empty, no image data is fetched.
        z: list of integers (optional)
            Indices for the z-stacks of interest. If None, all are used.

        Returns
        -------
        tiles: array or None
            The result of tiler.get_tiles_timepoint, or None if there
            are no channels to fetch.
        """
        if channels is None:
            # find channels from tiler
            channel_ids = list(range(len(self.tiler.channels)))
        elif len(channels):
            # a subset of channels was specified
            channel_ids = [self.tiler.get_channel_index(ch) for ch in channels]
        else:
            # an empty collection of channels: nothing to fetch
            channel_ids = None
        if z is None:
            # include all Z channels
            z = list(range(self.tiler.shape[-3]))
        # get the image data via tiler
        tiles = (
            self.tiler.get_tiles_timepoint(tp, channels=channel_ids, z=z)
            if channel_ids
            else None
        )
        # tiles has dimensions (tiles, channels, 1, Z, X, Y)
        return tiles

    def apply_cell_function(
        self,
        traps: t.List[np.ndarray],
        masks: t.List[np.ndarray],
        cell_function: str,
        cell_labels: t.Dict[int, t.List[int]],
    ) -> t.Tuple[t.Union[t.Tuple[float], t.Tuple[t.Tuple[int]]]]:
        """
        Apply a cell function to all cells at all traps for one time point.

        Parameters
        ----------
        traps: list of arrays
            t.List of images.
        masks: list of arrays
            t.List of masks.
        cell_function: str
            Function to apply.
        cell_labels: dict
            A dict with trap_ids as keys and a list of cell labels as
            values. If None, the cells of each trap are labelled by their
            position in that trap's masks.

        Returns
        -------
        res_idx: a tuple of tuples
            A two-tuple comprising a tuple of results and a tuple of
            the tile_id and cell labels
        """
        if cell_labels is None:
            self._log("No cell labels given. Sorting cells using index.")
            # fall back on the index of each mask, as the message says;
            # calling .values() on None would raise an AttributeError
            labels_per_trap = [range(len(mask_set)) for mask_set in masks]
        else:
            labels_per_trap = cell_labels.values()
        cell_fun = cell_function in self.all_cell_funs
        idx = []
        results = []
        for trap_id, (mask_set, trap, local_cell_labels) in enumerate(
            zip(masks, traps, labels_per_trap)
        ):
            # ignore empty traps
            if len(mask_set):
                # find property from the tile
                result = self.all_funs[cell_function](mask_set, trap)
                if cell_fun:
                    # store results for each cell separately
                    for cell_label, val in zip(local_cell_labels, result):
                        results.append(val)
                        idx.append((trap_id, cell_label))
                else:
                    # background (trap) function
                    results.append(result)
                    idx.append(trap_id)
        res_idx = (tuple(results), tuple(idx))
        return res_idx

    def apply_cell_funs(
        self,
        cell_funs: t.Collection[str],
        tiles: t.List[np.array],
        masks: t.List[np.array],
        **kwargs,
    ) -> t.Dict[str, pd.Series]:
        """
        Return dict with cell_funs as keys and the corresponding results as values.

        Data from one time point is used.

        Any other keyword arguments are passed to apply_cell_function.
        """
        d = {
            cell_fun: self.apply_cell_function(
                traps=tiles, masks=masks, cell_function=cell_fun, **kwargs
            )
            for cell_fun in cell_funs
        }
        return d

    def reduce_extract(
        self,
        tiles: np.ndarray,
        masks: t.List[np.ndarray],
        reduction_cell_funs: t.Dict[reduction_method, t.Collection[str]],
        **kwargs,
    ) -> t.Dict[str, t.Dict[reduction_method, t.Dict[str, pd.Series]]]:
        """
        Reduce to a 2D image and then extract.

        Parameters
        ----------
        tiles: array
            An array of image data arranged as (tiles, X, Y, Z)
            or None if there is no image data.
        masks: list of arrays
            An array of masks for each trap: one per cell at the trap
        reduction_cell_funs: dict
            An upper branch of the extraction tree: a dict for which
            keys are reduction functions and values are either a list
            or a set of strings giving the cell functions to apply.
            For example: {'np_max': {'max5px', 'mean', 'median'}}
        **kwargs: dict
            All other arguments passed to Extractor.apply_cell_funs.

        Returns
        ------
        Dict of dataframes with the corresponding reductions and metrics nested.
        """
        # create dict with keys naming the reduction in the z-direction
        # and the reduced data as values
        reduced_tiles = {}
        if tiles is not None:
            for reduction in reduction_cell_funs.keys():
                reduced_tiles[reduction] = [
                    self.reduce_dims(
                        tile_data, method=REDUCTION_FUNS[reduction]
                    )
                    for tile_data in tiles
                ]
        # calculate cell and tile properties
        d = {
            reduction: self.apply_cell_funs(
                # with no image data, pass one None per trap
                tiles=reduced_tiles.get(reduction, [None for _ in masks]),
                masks=masks,
                cell_funs=cell_funs,
                **kwargs,
            )
            for reduction, cell_funs in reduction_cell_funs.items()
        }
        return d

    def reduce_dims(
        self, img: np.ndarray, method: reduction_method = None
    ) -> np.ndarray:
        """
        Collapse a z-stack into 2d array using method.

        If method is None, return the original data.

        Parameters
        ----------
        img: array
            An array of the image data arranged as (X, Y, Z).
        method: function
            The reduction function.
        """
        if method is None:
            return img
        return reduce_z(img, method)

    def make_tree_dict(self, tree: t.Optional[extraction_tree] = None):
        """
        Put extraction tree into a dict.

        Parameters
        ----------
        tree: dict (optional)
            The extraction tree; if None, params.tree is used.

        Returns
        -------
        tree_dict: dict
            With keys "tree", the whole tree; "channels_tree", the tree
            without its "general" branch; and "channels", a tuple of the
            keys of "channels_tree".
        """
        if tree is None:
            # use default
            tree = self.params.tree
        # the extraction tree for fluorescence channels
        channels_tree = {ch: v for ch, v in tree.items() if ch != "general"}
        return {
            # the whole extraction tree
            "tree": tree,
            "channels_tree": channels_tree,
            # tuple of the fluorescence channels
            "channels": tuple(channels_tree),
        }

    def get_masks(self, tp, masks, cells):
        """
        Get the masks as a list with an array of masks for each trap.

        Parameters
        ----------
        tp: int
            Time point; only used if masks is None.
        masks: dict, list or None
            Masks per trap, either as a dict with trap_ids as keys or as
            a list. If None, the masks are found from cells.
        cells: Cells
            Provides at_time and ntraps; only used if masks is None.

        Returns
        -------
        masks: list of arrays
            One array per trap; the array is empty for a trap with no
            cells.
        """
        if masks is None:
            # find the cell masks for a given trap as a dict with trap_ids as keys
            raw_masks = cells.at_time(tp, kind="mask")
            masks = {trap_id: [] for trap_id in range(cells.ntraps)}
            for trap_id, trap_masks in raw_masks.items():
                if len(trap_masks):
                    masks[trap_id] = np.stack(np.array(trap_masks)).astype(
                        bool
                    )
        # extract_tp documents masks as a list, so accept a list as well
        # as a dict instead of failing on .values()
        per_trap = masks.values() if isinstance(masks, dict) else masks
        # convert to a list of masks
        # one array of size (no cells, tile_size, tile_size) per trap
        return [np.array(v) for v in per_trap]

    def get_cell_labels(self, tp, cell_labels, cells):
        """
        Get the cell labels per trap as a dict with trap_ids as keys.

        Labels that are passed in are returned as they are. Otherwise they
        are found from cells for time point tp, and any trap without an
        entry is given an empty list.
        """
        if cell_labels is not None:
            return cell_labels
        labels_at_tp = cells.labels_at_time(tp)
        return {
            trap_id: labels_at_tp.get(trap_id, [])
            for trap_id in range(cells.ntraps)
        }

    def get_background_masks(self, masks, tile_size):
        """
        Generate boolean background masks.

        Combine masks per trap and then take the logical inverse.

        Parameters
        ----------
        masks: list of arrays
            The masks of the cells for each trap.
        tile_size: int
            Width of the square background used for a trap with no cells.

        Returns
        -------
        bgs: array
            One boolean mask per trap that is True away from the cells,
            or an empty array if params.sub_bg is empty.
        """
        if not self.params.sub_bg:
            return np.array([])
        covered_by_cells = [
            # sum over masks for each cell
            np.sum(trap_masks, axis=0)
            if np.any(trap_masks)
            # a trap with no cells is all background
            else np.zeros((tile_size, tile_size))
            for trap_masks in masks
        ]
        return ~np.array(covered_by_cells).astype(bool)

    def extract_one_channel(
        self, tree_dict, cell_labels, img, img_bgsub, masks, **kwargs
    ):
        """
        Extract as dict all metrics requiring only a single channel.

        Parameters
        ----------
        tree_dict: dict
            Its "tree" entry gives the reductions and metrics per channel.
        cell_labels: dict
            A dict with trap_ids as keys and cell labels as values.
        img: dict
            Images with channels as keys.
        img_bgsub: dict
            Background-subtracted images with channel + "_bgsub" as keys.
        masks: list of arrays
            The masks of the cells for each trap.
        **kwargs: dict
            Passed to reduce_extract.

        Returns
        -------
        d: dict
            The results of reduce_extract with channels, and channels
            followed by "_bgsub", as keys.
        """
        d = {}
        for ch, reduction_cell_funs in tree_dict["tree"].items():
            # extract from all images including bright field
            d[ch] = self.reduce_extract(
                # use None for "general"; no fluorescence image
                tiles=img.get(ch, None),
                masks=masks,
                reduction_cell_funs=reduction_cell_funs,
                cell_labels=cell_labels,
                **kwargs,
            )
            bgsub_key = ch + "_bgsub"
            # a channel has no background-subtracted image if it is not
            # in params.sub_bg, so skip it instead of raising a KeyError
            if ch != "general" and bgsub_key in img_bgsub:
                # extract from background-corrected fluorescence images
                d[bgsub_key] = self.reduce_extract(
                    tiles=img_bgsub[bgsub_key],
                    masks=masks,
                    reduction_cell_funs=reduction_cell_funs,
                    cell_labels=cell_labels,
                    **kwargs,
                )
        return d

    def extract_multiple_channels(self, cell_labels, img, img_bgsub, masks):
        """
        Extract as a dict all metrics requiring multiple channels.

        Each entry of params.multichannel_ops is unpacked as
        (channels, reduction, function name). The function is applied
        once to the images and once to the background-subtracted images,
        with the result for the latter stored under the function name
        followed by "_bgsub".

        Returns
        -------
        d: dict
            Nested as operation name -> reduction -> function name, with
            the results of apply_cell_function as values.
        """
        # NB multichannel functions do not use tree_dict
        available_channels = set(list(img.keys()) + list(img_bgsub.keys()))
        d = {}
        for multichannel_fun_name, (
            channels,
            reduction,
            multichannel_function,
        ) in self.params.multichannel_ops.items():
            common_channels = set(channels).intersection(available_channels)
            # all required channels should be available
            if len(common_channels) != len(channels):
                continue
            for images, suffix in zip([img, img_bgsub], ["", "_bgsub"]):
                # skip a set of images that lacks one of the channels
                # instead of raising a KeyError
                if any(ch + suffix not in images for ch in channels):
                    continue
                # channels
                channels_stack = np.stack(
                    [images[ch + suffix] for ch in channels],
                    axis=-1,
                )
                # reduce in Z
                tiles = REDUCTION_FUNS[reduction](channels_stack, axis=1)
                # set up dict and apply multichannel function
                by_reduction = d.setdefault(multichannel_fun_name, {})
                by_reduction.setdefault(reduction, {})[
                    multichannel_function + suffix
                ] = self.apply_cell_function(
                    tiles,
                    masks,
                    multichannel_function,
                    cell_labels,
                )
        return d

    def extract_tp(
        self,
        tp: int,
        tree: t.Optional[extraction_tree] = None,
        tile_size: int = 117,
        masks: t.Optional[t.List[np.ndarray]] = None,
        cell_labels: t.Optional[t.Dict[int, t.List[int]]] = None,
        **kwargs,
    ) -> t.Dict[str, t.Dict[str, t.Dict[str, tuple]]]:
        """
        Extract for an individual time point.

        Parameters
        ----------
        tp : int
            Time point being analysed.
        tree : dict
            Nested dictionary indicating channels, reduction functions
            and metrics to be used.
            For example: {'general': {'None': ['area', 'volume', 'eccentricity']}}
        tile_size : int
            Size of the tile to be extracted.
        masks : list of arrays
            A list of masks per trap with each mask having dimensions
            (ncells, tile_size, tile_size) and with one mask per cell.
        cell_labels : dict
            A dictionary with trap_ids as keys and cell_labels as values.
        **kwargs : keyword arguments
            Passed to extractor.reduce_extract.

        Returns
        -------
        d: dict
            Dictionary of the results with three levels of dictionaries.
            The first level has channels as keys.
            The second level has reduction metrics as keys.
            The third level has cell or background metrics as keys and a
            two-tuple as values.
            The first tuple is the result of applying the metrics to a
            particular cell or trap; the second tuple is either
            (trap_id, cell_label) for a metric applied to a cell or a
            trap_id for a metric applied to a trap.

            An example is d["GFP"]["np_max"]["mean"][0], which gives a tuple
            of the calculated mean GFP fluorescence for all cells.
        """
        # dict of information from extraction tree
        tree_dict = self.make_tree_dict(tree)
        # create a Cells object to extract information from the h5 file
        cells = Cells(self.h5path)
        # find the cell labels as dict with trap_ids as keys
        cell_labels = self.get_cell_labels(tp, cell_labels, cells)
        # get masks one per cell per trap
        masks = self.get_masks(tp, masks, cells)
        # find image data for all traps at the time point
        # stored as an array arranged as (traps, channels, 1, Z, X, Y)
        tiles = self.get_tiles(tp, channels=tree_dict["channels"])
        # generate boolean masks for background for each trap
        bgs = self.get_background_masks(masks, tile_size)
        # get images and background corrected images as dicts
        # with fluorescence channels as keys
        img, img_bgsub = self.get_imgs_background_subtract(
            tree_dict, tiles, bgs
        )
        # perform extraction
        res_one = self.extract_one_channel(
            tree_dict, cell_labels, img, img_bgsub, masks, **kwargs
        )
        res_multiple = self.extract_multiple_channels(
            cell_labels, img, img_bgsub, masks
        )
        res = {**res_one, **res_multiple}
        return res

    def get_imgs_background_subtract(self, tree_dict, tiles, bgs):
        """
        Get two dicts of fluorescence images.

        Return images and background subtracted image for all traps
        for one time point.

        Parameters
        ----------
        tree_dict: dict
            Uses "channels_tree" for the channels to fetch and "channels"
            for the position of each channel in tiles.
        tiles: array or None
            Image data arranged as (traps, channels, 1, Z, X, Y).
        bgs: array
            Boolean background masks, one per trap.

        Returns
        -------
        img: dict
            Images with channels as keys; None if there are no tiles.
        img_bgsub: dict
            Background-subtracted images with channel + "_bgsub" as keys.
            A channel is included only if it is in params.sub_bg and
            there are background masks; its value is None if there are
            no tiles.
        """
        img = {}
        img_bgsub = {}
        have_tiles = tiles is not None and len(tiles)
        for ch in tree_dict["channels_tree"]:
            bgsub_key = ch + "_bgsub"
            if not have_tiles:
                img[ch] = None
                # use the same key as when images exist; this is the key
                # that extract_one_channel reads
                img_bgsub[bgsub_key] = None
                continue
            # image data for all traps for a particular channel and
            # time point arranged as (traps, Z, X, Y)
            # we use 0 here to access the single time point available
            img[ch] = tiles[:, tree_dict["channels"].index(ch), 0]
            if (
                bgs.any()
                and ch in self.params.sub_bg
                and img[ch] is not None
            ):
                # subtract median background; the image and the background
                # mask of each trap are paired explicitly
                per_trap = [
                    # move Z to last column to allow subtraction
                    np.moveaxis(trap_img, 0, -1)
                    # median of background over all pixels for each Z section
                    - bn.median(trap_img[:, trap_bg], axis=1)
                    for trap_img, trap_bg in zip(img[ch], bgs)
                ]
                # convert to array and move Z axis back to the second column
                img_bgsub[bgsub_key] = np.moveaxis(np.stack(per_trap), -1, 1)
        return img, img_bgsub

    def get_imgs_old(self, channel: t.Optional[str], tiles, channels=None):
        """
        Return image from a correct source, either raw or bgsub.

        Parameters
        ----------
        channel: str
            Name of channel to get.
        tiles: ndarray
            An array of the image data having dimensions of
            (tile_id, channel, tp, tile_size, tile_size, n_zstacks).
        channels: list of str (optional)
            t.List of available channels.

        Returns
        -------
        img: ndarray
            An array of image data with dimensions
            (no tiles, X, Y, no Z channels)
            None is returned if channel is found in neither source.
        """
        if channels is None:
            channels = (*self.params.tree,)
        if channel in channels:  # TODO start here to fetch channel using regex
            return tiles[:, channels.index(channel), 0]
        # NOTE(review): no visible code sets self.img_bgsub - confirm that
        # it exists before relying on this branch
        elif channel in self.img_bgsub:
            return self.img_bgsub[channel]

    def run(
        self,
        tps: t.List[int] = None,
        tree=None,
        save=True,
        **kwargs,
    ) -> dict:
        """
        Run extraction for one position and for the specified time points.

        Save the results to a h5 file.

        Parameters
        ----------
        tps: list of int (optional)
            Time points to include.
        tree: dict (optional)
            Nested dictionary indicating channels, reduction functions and
            metrics to be used.
            For example: {'general': {'None': ['area', 'volume', 'eccentricity']}}
        save: boolean (optional)
            If True, save results to h5 file.
        kwargs: keyword arguments (optional)
            Passed to extract_tp.

        Returns
        -------
        d: dict
            A dict of the extracted data for one position with a concatenated
            string of channel, reduction metric, and cell metric as keys and
            pd.DataFrame of the extracted data for all time points as values.
        """
        if tree is None:
            tree = self.params.tree
        if tps is None:
            tps = list(range(self.meta["time_settings/ntimepoints"][0]))
        elif isinstance(tps, int):
            tps = [tps]
        # store results in dict
        d = {}
        for tp in tps:
            # extract for each time point and convert to dict of pd.Series
            new = flatten_nesteddict(
                self.extract_tp(tp=tp, tree=tree, **kwargs),
                to="series",
                tp=tp,
            )
            # concatenate with data extracted from earlier time points
            for k in new.keys():
                d[k] = pd.concat((d.get(k, None), new[k]), axis=1)
        # add indices to pd.Series containing the extracted data
        for k in d.keys():
            indices = ["experiment", "position", "trap", "cell_label"]
            idx = (
                indices[-d[k].index.nlevels :]
                if d[k].index.nlevels > 1
                else [indices[-2]]
            )
            d[k].index.names = idx
        # save
        if save:
            self.save_to_h5(d)
        return d

    def save_to_h5(self, dict_series, path=None):
        """
        Save the extracted data for one position to the h5 file.

        Parameters
        ----------
        dict_series: dict
            A dictionary of the extracted data, created by run.
        path: Path (optional)
            To the h5 file. If None, self.h5path is used.
        """
        if path is None:
            path = self.h5path
        self.writer = Writer(path)
        for extract_name, series in dict_series.items():
            dset_path = "/extraction/" + extract_name
            self.writer.write(dset_path, series)
        self.writer.id_cache.clear()

    def get_meta(self, flds: t.Union[str, t.Collection]):
        """
        Obtain metadata for one or multiple fields.

        Each key of self.meta is matched on the part after its last "/".
        A field that is not in self.meta is looked up in default_meta and
        is None if it is not there either.

        Returns
        -------
        A dict with the fields as keys.
        """
        if isinstance(flds, str):
            flds = [flds]
        meta_short = {k.split("/")[-1]: v for k, v in self.meta.items()}
        return {
            f: meta_short.get(f, self.default_meta.get(f, None)) for f in flds
        }


def flatten_nesteddict(
    nest: dict, to="series", tp: int = None
) -> t.Dict[str, pd.Series]:
    """
    Convert a nested extraction dict into a dict of pd.Series.

    Parameters
    ----------
    nest: dict of dicts
        Contains the nested results of extraction.
    to: str (optional)
        Specifies the format of the output, either pd.Series (default)
        or a list
    tp: int
        Time point used to name the pd.Series

    Returns
    -------
    d: dict
        A dict with a concatenated string of channel, reduction metric,
        and cell metric as keys and either a pd.Series or a list of the
        corresponding extracted data as values.
    """
    as_series = to == "series"
    d = {}
    for k0, v0 in nest.items():
        for k1, v1 in v0.items():
            for k2, v2 in v1.items():
                # v2 is unpacked into the positional arguments of pd.Series
                d["/".join((k0, k1, k2))] = (
                    pd.Series(*v2, name=tp) if as_series else v2
                )
    return d