# extractor.py
import typing as t
import warnings
from itertools import repeat
from time import perf_counter
from typing import Callable, Dict, List

import h5py
import numpy as np
import pandas as pd
from agora.abc import ParametersABC, ProcessABC
from agora.io.cells import Cells
from agora.io.writer import Writer, load_attributes
from aliby.tile.tiler import Tiler
from extraction.core.functions.defaults import exparams_from_meta
from extraction.core.functions.distributors import reduce_z, trap_apply
from extraction.core.functions.loaders import (
    load_custom_args,
    load_funs,
    load_mergefuns,
    load_redfuns,
)
from extraction.core.functions.utils import depth

# Global parameters used to load functions that either analyse cells or their
# background. These global parameters both allow the functions to be stored in
# a dictionary for access only on demand and to be defined simply in
# extraction/core/functions.
# CELL_FUNS is combined with the custom function names in Extractor.load_funs;
# FUNS and CUSTOM_FUNS map a metric name to a callable; CUSTOM_ARGS maps a
# custom function name to the names of its extra arguments; RED_FUNS and
# MERGE_FUNS are looked up by name in reduce_extract and extract_tp.
CELL_FUNS, TRAPFUNS, FUNS = load_funs()
CUSTOM_FUNS, CUSTOM_ARGS = load_custom_args()
RED_FUNS = load_redfuns()
MERGE_FUNS = load_mergefuns()

# Assign datatype depending on the metric used
# m2type = {"mean": np.float32, "median": np.ubyte, "imBackground": np.ubyte}


class ExtractorParameters(ParametersABC):
    """
    Base class to define parameters for extraction.
    """

    def __init__(
        self,
        tree: Dict[str, Dict[Callable, List[str]]] = None,
        sub_bg: set = None,
        multichannel_ops: Dict = None,
    ):
        """
        Parameters
        ----------
        tree: dict
            Nested dictionary indicating channels, reduction functions and
            metrics to be used.
            str channel -> U(function,None) reduction -> str metric
            If not of depth three, tree will be filled with Nones.
        sub_bg: set (optional)
            Channels for which metrics are also extracted from
            background-subtracted images. Defaults to a new empty set.
        multichannel_ops: dict (optional)
            Maps a result name to a (channels, merge function name,
            reduction->metrics dict) triple, as unpacked in
            Extractor.extract_tp. Defaults to a new empty dict.
        """
        self.tree = fill_tree(tree)
        # build fresh containers per instance: a mutable default argument
        # would be shared by every instance created without that argument
        self.sub_bg = set() if sub_bg is None else sub_bg
        self.multichannel_ops = (
            {} if multichannel_ops is None else multichannel_ops
        )

    @staticmethod
    def guess_from_meta(store_name: str, suffix="fast"):
        """
        Find the microscope used from the h5 metadata.

        Parameters
        ----------
        store_name : str or Path
            For a h5 file
        suffix : str
            Added at the end of the predicted parameter set

        Returns
        -------
        str
            The microscope name and the suffix joined by an underscore.

        Raises
        ------
        AssertionError
            If the file's root attributes have no "microscope" entry.
        """
        with h5py.File(store_name, "r") as f:
            microscope = f["/"].attrs.get("microscope")
        # raised explicitly so that the check survives `python -O`
        if not microscope:
            raise AssertionError("No metadata found")
        return "_".join((microscope, suffix))

    @classmethod
    def default(cls):
        """Return parameters built from an empty tree."""
        return cls({})

    @classmethod
    def from_meta(cls, meta):
        """Return parameters built from the output of exparams_from_meta(meta)."""
        return cls(**exparams_from_meta(meta))


class Extractor(ProcessABC):
    """
    The Extractor applies a metric, such as area or median, to cells identified in the image tiles using the cell masks.

    Its methods therefore require both tile images and masks.

    Usually one metric is applied per mask, but there are tile-specific backgrounds (Alan), which apply one metric per tile.

    Extraction follows a three-level tree structure. Channels, such as GFP, are the root level; the second level is the reduction algorithm, such as maximum projection; the last level is the metric - the specific operation to apply to the cells in the image identified by the mask, such as median, which is the median value of the pixels in each cell.

    Parameters
    ----------
    parameters: core.extractor Parameters
        Parameters that include the channels, reduction and
        extraction functions to use.
    store: str
        Path to hdf5 storage file. Must contain cell outlines.
    tiler: pipeline-core.core.segmentation tiler
        Class that contains or fetches the image to be used for segmentation.
    """

    # Fallback values used by get_meta when a field is absent from the
    # loaded metadata.
    # NOTE(review): units are not stated here - presumably microns; confirm.
    default_meta = {
        "pixel_size": 0.236,
        "z_size": 0.6,
        "spacing": 0.6,
    }

    def __init__(
        self,
        parameters: ExtractorParameters,
        store: str = None,
        tiler: Tiler = None,
    ):
        """
        Initialise Extractor.

        Parameters
        ----------
        parameters: ExtractorParameters object
        store: str
            Name of h5 file. If given, it is kept as self.local and the
            metadata are loaded from it.
        tiler: Tiler object
            If given, it is kept as self.tiler.
        """
        self.params = parameters
        if store:
            self.local = store
            self.load_meta()
        else:
            # if no h5 file, use the parameters directly
            self.meta = {"channel": parameters.to_dict()["tree"].keys()}
        if tiler:
            self.tiler = tiler
        self.load_funs()

    @classmethod
    def from_tiler(
        cls,
        parameters: ExtractorParameters,
        store: str,
        tiler: Tiler,
    ):
        """
        Initiate an Extractor from an existing Tiler.

        Parameters
        ----------
        parameters: ExtractorParameters object
        store: str
            Name of h5 file
        tiler: Tiler object
        """
        # NOTE(review): the `def` line of this method was missing from the
        # file; the name `from_tiler` is inferred from the original
        # "initiate from tiler" comment - confirm against callers.
        return cls(parameters, store=store, tiler=tiler)

    @classmethod
    def from_img(
        cls,
        parameters: ExtractorParameters,
        store: str,
        img_meta: tuple,
    ):
        """
        Initiate an Extractor from image metadata.

        Parameters
        ----------
        parameters: ExtractorParameters object
        store: str
            Name of h5 file
        img_meta: tuple
            Positional arguments used to construct the Tiler.
        """
        # NOTE(review): the `def` line of this method was missing from the
        # file; the name `from_img` is inferred from the original
        # "initiate from image" comment - confirm against callers.
        return cls(parameters, store=store, tiler=Tiler(*img_meta))

    @property
    def channels(self):
        """
        Return a tuple of strings of the available channels.

        The channels are the top-level keys of self.params.tree. The tuple is
        built on first access and cached as self._channels.

        Raises
        ------
        AttributeError
            If self.params.tree is not a dict and no channels were cached.
        """
        if not hasattr(self, "_channels"):
            # isinstance also accepts dict subclasses such as OrderedDict
            if isinstance(self.params.tree, dict):
                self._channels = tuple(self.params.tree.keys())
        return self._channels

    @property
    def current_position(self):
        """
        Return the position name derived from the h5 file name.

        The name is the last "/"-separated component of self.local with its
        final three characters (the ".h5" extension) removed.
        """
        # str() allows self.local to be a pathlib.Path as well as a string
        return str(self.local).split("/")[-1][:-3]

    @property
    def group(self):
        """Return the path within the h5 file where extraction results are stored."""
        # check the attribute that is actually cached (previously "_out_path")
        if not hasattr(self, "_group"):
            self._group = "/extraction/"
        return self._group

    def load_custom_funs(self):
        """
        Define any custom functions to be functions of cell_masks and trap_image only.

        Any other parameters are taken from the experiment's metadata and automatically applied. These parameters therefore must be loaded within an Extractor instance.

        Every function in CUSTOM_FUNS is wrapped, whether or not it appears in
        self.params.tree, and the wrappers are stored in self._custom_funs.
        """
        # find the values of the extra arguments of each custom function
        # NOTE(review): get_meta returns a dict, so each argument value here
        # is itself a dict rather than a bare metadata value - confirm that
        # the custom functions expect this.
        arg_vals = {
            k: {k2: self.get_meta(k2) for k2 in v}
            for k, v in CUSTOM_ARGS.items()
        }

        def bind(fun, extra_args):
            # fun and extra_args are bound per call, so each wrapper keeps
            # its own arguments rather than those of the last loop iteration
            return lambda cell_masks, trap_image: trap_apply(
                fun, cell_masks, trap_image, **extra_args
            )

        # define custom functions - those with extra arguments other than
        # cell_masks and trap_image - as functions of two variables
        self._custom_funs = {
            k: bind(f, arg_vals.get(k, {})) for k, f in CUSTOM_FUNS.items()
        }

    def load_funs(self):
        """
        Gather the metric functions available to this Extractor.

        Sets self._all_cell_funs, the names of the custom functions together
        with CELL_FUNS, and self._all_funs, a dict of the custom functions
        merged with FUNS.
        """
        self.load_custom_funs()
        self._all_cell_funs = set(self._custom_funs.keys()).union(CELL_FUNS)
        # merge the two dicts; entries of FUNS win if a name is in both
        self._all_funs = {**self._custom_funs, **FUNS}

    def load_meta(self):
        """Load metadata from the h5 file whose name is given by self.local."""
        self.meta = load_attributes(self.local)

    def get_traps(
        self,
        tp: int,
        channels: list = None,
        z: list = None,
        **kwargs,
    ) -> tuple:
        """
        Finds traps for a given time point and given channels and z-stacks.
        Returns None if no traps are found.

        Any additional keyword arguments are passed to tiler.get_traps_timepoint

        Parameters
        ----------
        tp: int
            Time point of interest
        channels: list of strings (optional)
            Channels of interest. If None, all of the tiler's channels are
            used; if empty, None is returned.
        z: list of integers (optional)
            Indices for the z-stacks of interest. If None, all are used.

        Returns
        -------
        The result of tiler.get_traps_timepoint, or None if there are no
        channels to fetch. extract_tp treats the result as data arranged as
        (traps, channels, timepoints, X, Y, Z).
        """
        if channels is None:
            # find channels from tiler
            channel_ids = list(range(len(self.tiler.channels)))
        elif len(channels):
            # a subset of channels was specified
            channel_ids = [self.tiler.get_channel_index(ch) for ch in channels]
        else:
            # an empty selection: there is nothing to fetch
            channel_ids = None
        # a list of the indices of the z stacks
        if z is None:
            z = list(range(self.tiler.shape[-1]))
        if not channel_ids:
            return None
        # gets the data via tiler
        return self.tiler.get_traps_timepoint(
            tp, channels=channel_ids, z=z, **kwargs
        )

    def extract_traps(
        self,
        traps: List[np.array],
        masks: List[np.array],
        metric: str,
        labels: Dict = None,
    ) -> tuple:
        """
        Apply a function to a whole position.

        Parameters
        ----------
        traps: list of arrays
            List of images.
        masks: list of arrays
            List of masks.
        metric: str
            Metric to extract.
        labels: dict (optional)
            A dict of cell labels with trap_ids as keys and a list of cell
            labels as values. If None, a warning is issued and each cell is
            identified by its position in the result for its trap.

        Returns
        -------
        res_idx: a tuple of tuples
            A two-tuple of a tuple of results and a tuple with the corresponding trap_id and cell labels
        """
        if labels is None:
            # previously this raised a Warning, contradicting its message;
            # now do what the message says and fall back to indices
            warnings.warn("No labels given. Sorting cells using index.")
            label_sets = repeat(None)
        else:
            label_sets = labels.values()
        cell_fun = metric in self._all_cell_funs
        idx = []
        results = []
        for trap_id, (mask_set, trap, lbl_set) in enumerate(
            zip(masks, traps, label_sets)
        ):
            # ignore empty traps
            if not len(mask_set):
                continue
            # apply metric either a cell function or otherwise
            result = self._all_funs[metric](mask_set, trap)
            if cell_fun:
                # store results for each cell separately
                labelled = (
                    enumerate(result)
                    if lbl_set is None
                    else zip(lbl_set, result)
                )
                for lbl, val in labelled:
                    results.append(val)
                    idx.append((trap_id, lbl))
            else:
                # background (trap) function
                results.append(result)
                idx.append(trap_id)
        res_idx = (tuple(results), tuple(idx))
        return res_idx

    def extract_funs(
        self,
        traps: List[np.array],
        masks: List[np.array],
        metrics: List[str],
        **kwargs,
    ) -> dict:
        """
        Returns dict with metrics as key and metrics applied to data as values for data from one timepoint.

        Parameters
        ----------
        traps: list of arrays
            List of images.
        masks: list of arrays
            List of masks.
        metrics: list of str
            Metrics to extract.
        **kwargs: dict
            Passed to extract_traps.
        """
        return {
            metric: self.extract_traps(
                traps=traps, masks=masks, metric=metric, **kwargs
            )
            for metric in metrics
        }

    def reduce_extract(
        self,
        traps: np.array,
        masks: list,
        red_metrics: dict,
        **kwargs,
    ) -> dict:
        """
        Wrapper to apply reduction and then extraction.

        Parameters
        ----------
        traps: array
            An array of image data arranged as (traps, X, Y, Z), or None, in
            which case each trap's image is passed on as None.
        masks: list of arrays
            An array of masks for each trap: one per cell at the trap
        red_metrics: dict
            dict for which keys are reduction functions and values are either a list or a set of strings giving the metric functions.
            For example: {'np_max': {'max5px', 'mean', 'median'}}
        **kwargs: dict
            All other arguments, passed to extract_funs.

        Returns
        ------
        Dictionary with the reduction functions as keys and the output of
        extract_funs for the correspondingly reduced data as values.
        """
        # create dict with keys naming the reduction in the z-direction and
        # the reduced data as values
        reduced_traps = {}
        if traps is not None:
            for red_fun in red_metrics:
                reduced_traps[red_fun] = [
                    self.reduce_dims(trap, method=RED_FUNS[red_fun])
                    for trap in traps
                ]
        return {
            red_fun: self.extract_funs(
                metrics=metrics,
                # without image data, pass one None per trap
                traps=reduced_traps.get(red_fun, [None for _ in masks]),
                masks=masks,
                **kwargs,
            )
            for red_fun, metrics in red_metrics.items()
        }

    def reduce_dims(self, img: np.array, method=None) -> np.array:
        """
        Collapse a z-stack into 2d array using method.
        If method is None, return the original data.

        Parameters
        ----------
        img: array
            An array of the image data arranged as (X, Y, Z)
        method: function
            The reduction function
        """
        if method is None:
            return img
        return reduce_z(img, method)

    def extract_tp(
        self,
        tp: int,
        tree: dict = None,
        tile_size: int = 117,
        masks=None,
        labels=None,
        **kwargs,
    ) -> t.Dict[str, t.Dict[str, t.Dict[str, tuple]]]:
        """
        Core extraction method for an individual time-point.

        Parameters
        ----------
        tp : int
            Time point being analysed.
        tree : dict
            Nested dictionary indicating channels, reduction functions and
            metrics to be used.
            For example: {'general': {'None': ['area', 'volume', 'eccentricity']}}
        tile_size : int
            Size of the tile to be extracted.
        masks : dict or list of arrays (optional)
            The masks for each trap, either as a dict with trap_ids as keys
            or as a list with one entry per trap. If None, the masks are read
            from the h5 file and the cells of a trap are stacked along the
            last axis.
        labels : dict (optional)
            A dictionary with trap_ids as keys and cell_labels as values.
            If None, the labels are read from the h5 file.
        **kwargs : keyword arguments
            Passed to extractor.reduce_extract.

        Returns
        -------
        d: dict
            Dictionary of the results with three levels of dictionaries.
            The first level has channels as keys.
            The second level has reduction metrics as keys.
            The third level has cell or background metrics as keys and a two-tuple as values.
            The first tuple is the result of applying the metrics to a particular cell or trap; the second tuple is either (trap_id, cell_label) for a metric applied to a cell or a trap_id for a metric applied to a trap.

            An example is d["GFP"]["np_max"]["mean"][0], which gives a tuple of the calculated mean GFP fluorescence for all cells.

        Notes
        -----
        The background-subtracted images are kept in self.img_bgsub, which is
        reset on every call.
        """
        if tree is None:
            # use default
            tree = self.params.tree
        # dictionary with channel: {reduction algorithm : metric}
        ch_tree = {ch: v for ch, v in tree.items() if ch != "general"}
        # tuple of the channels
        tree_chs = (*ch_tree,)
        # create a Cells object to extract information from the h5 file
        cells = Cells(self.local)
        # find the cell labels and store as dict with trap_ids as keys
        if labels is None:
            raw_labels = cells.labels_at_time(tp)
            labels = {
                trap_id: raw_labels.get(trap_id, [])
                for trap_id in range(cells.ntraps)
            }
        # find the cell masks for a given trap as a dict with trap_ids as keys
        if masks is None:
            raw_masks = cells.at_time(tp, kind="mask")
            masks = {trap_id: [] for trap_id in range(cells.ntraps)}
            # trap_masks, not `cells`, so the Cells object is not shadowed
            for trap_id, trap_masks in raw_masks.items():
                if len(trap_masks):
                    masks[trap_id] = np.dstack(np.array(trap_masks)).astype(
                        bool
                    )
        # convert to a list of masks; a list passed by the caller is accepted
        mask_sets = masks.values() if isinstance(masks, dict) else masks
        masks = [np.array(v) for v in mask_sets]

        # find image data at the time point
        # stored as an array arranged as (traps, channels, timepoints, X, Y, Z)
        # Alan: traps does not appear the best name here!
        traps = self.get_traps(tp, tile_shape=tile_size, channels=tree_chs)
        # generate boolean masks for background as a list with one mask per trap
        if self.params.sub_bg:
            bgs = [
                ~np.sum(m, axis=2).astype(bool)
                if np.any(m)
                else np.zeros((tile_size, tile_size))
                for m in masks
            ]

        # perform extraction by applying metrics
        d = {}
        self.img_bgsub = {}
        for ch, red_metrics in tree.items():
            # NB ch != is necessary for threading
            if ch != "general" and traps is not None and len(traps):
                # image data for all traps and z sections for a particular channel
                # as an array arranged as (no traps, X, Y, no Z channels)
                img = traps[:, tree_chs.index(ch), 0]
            else:
                img = None
            # apply metrics to image data
            d[ch] = self.reduce_extract(
                red_metrics=red_metrics,
                traps=img,
                masks=masks,
                labels=labels,
                **kwargs,
            )
            # apply metrics to image data with the background subtracted
            if ch in self.params.sub_bg and img is not None:
                # calculate metrics with subtracted bg
                ch_bs = ch + "_bgsub"
                self.img_bgsub[ch_bs] = []
                for trap, bg in zip(img, bgs):
                    bg_median = np.zeros_like(trap)
                    # pixels that belong to no cell
                    bg_pixels = np.where(bg)
                    # skip if there are no background pixels; bgs is all
                    # zeros for a trap without cells, so nothing is
                    # subtracted there
                    if len(bg_pixels[0]):
                        bg_median = np.median(trap[bg_pixels], axis=0)
                    # subtract median background
                    self.img_bgsub[ch_bs].append(trap - bg_median)
                # apply metrics to background-corrected data
                d[ch_bs] = self.reduce_extract(
                    red_metrics=ch_tree[ch],
                    traps=self.img_bgsub[ch_bs],
                    masks=masks,
                    labels=labels,
                    **kwargs,
                )

        # apply any metrics that use multiple channels (eg pH calculations)
        for name, (
            chs,
            merge_fun,
            red_metrics,
        ) in self.params.multichannel_ops.items():
            # only if every required channel is available, raw or bgsub
            if len(
                set(chs).intersection(
                    set(self.img_bgsub.keys()).union(tree_chs)
                )
            ) == len(chs):
                imgs = [self.get_imgs(ch, traps, tree_chs) for ch in chs]
                merged = MERGE_FUNS[merge_fun](*imgs)
                d[name] = self.reduce_extract(
                    red_metrics=red_metrics,
                    traps=merged,
                    masks=masks,
                    labels=labels,
                    **kwargs,
                )

        return d

    def get_imgs(self, channel, traps, channels=None):
        """
        Returns the image from a correct source, either raw or bgsub

        Parameters
        ----------
        channel: str
            Name of channel to get.
        traps: ndarray
            An array of the image data having dimensions of (trap_id, channel, tp, tile_size, tile_size, n_zstacks).
        channels: list of str (optional)
            List of available channels. Defaults to the keys of
            self.params.tree.

        Returns
        -------
        img: ndarray
            An array of image data with dimensions (no traps, X, Y, no Z channels),
            or None if the channel is neither a raw channel nor in
            self.img_bgsub.
        """
        if channels is None:
            channels = (*self.params.tree,)
        # raw channels take precedence over background-subtracted ones
        if channel in channels:
            return traps[:, channels.index(channel), 0]
        if channel in self.img_bgsub:
            return self.img_bgsub[channel]
        return None

    def run_tp(self, tp, **kwargs):
        """
        Wrapper to add compatibility with other steps of the pipeline.

        Calls run for the single time point tp, passing on any keyword
        arguments, and returns its result.
        """
        return self.run(tps=[tp], **kwargs)

pswain's avatar
pswain committed
        self,
        tree=None,
        tps: List[int] = None,
        save=True,
        **kwargs,
pswain's avatar
pswain committed
        """
        Parameters
        ----------
        tree: dict
            Nested dictionary indicating channels, reduction functions and
            metrics to be used.
            For example: {'general': {'None': ['area', 'volume', 'eccentricity']}}
        tps: list of int (optional)
            Time points to include.
        save: boolean (optional)
            If True, save results to h5 file.
        kwargs: keyword arguments (optional)
            Passed to extract_tp.
pswain's avatar
pswain committed
        Returns
        -------
        d: dict
            A dict of the extracted data with a concatenated string of channel, reduction metric, and cell metric as keys and pd.Series of the extracted data as values.
        """
Alán Muñoz's avatar
Alán Muñoz committed
        if tree is None:
            tree = self.params.tree
        if tps is None:
            tps = list(range(self.meta["time_settings/ntimepoints"][0]))
pswain's avatar
pswain committed
        # store results in dict
Alán Muñoz's avatar
Alán Muñoz committed
        d = {}
        for tp in tps:
pswain's avatar
pswain committed
            # extract for each time point and convert to dict of pd.Series
            new = flatten_nesteddict(
Alán Muñoz's avatar
Alán Muñoz committed
                self.extract_tp(tp=tp, tree=tree, **kwargs),
                to="series",
                tp=tp,
            )
pswain's avatar
pswain committed
            # concatenate with data extracted from early time points
Alán Muñoz's avatar
Alán Muñoz committed
            for k in new.keys():
pswain's avatar
pswain committed
                d[k] = pd.concat((d.get(k, None), new[k]), axis=1)
        # add indices to pd.Series containing the extracted data
Alán Muñoz's avatar
Alán Muñoz committed
        for k in d.keys():
            indices = ["experiment", "position", "trap", "cell_label"]
            idx = (
                indices[-d[k].index.nlevels :]
                if d[k].index.nlevels > 1
                else [indices[-2]]
            )
            d[k].index.names = idx
pswain's avatar
pswain committed
        # save
Alán Muñoz's avatar
Alán Muñoz committed
        if save:
pswain's avatar
pswain committed
            self.save_to_hdf(d)
        return d

    # Alan: isn't this identical to run?
    # def extract_pos(
    #     self, tree=None, tps: List[int] = None, save=True, **kwargs
    # ) -> dict:
    #     if tree is None:
    #         tree = self.params.tree
    #     if tps is None:
    #         tps = list(range(self.meta["time_settings/ntimepoints"]))
    #     d = {}
    #     for tp in tps:
    #         new = flatten_nest(
    #             self.extract_tp(tp=tp, tree=tree, **kwargs),
    #             to="series",
    #             tp=tp,
    #         )
    #         for k in new.keys():
    #             n = new[k]
    #             d[k] = pd.concat((d.get(k, None), n), axis=1)
    #     for k in d.keys():
    #         indices = ["experiment", "position", "trap", "cell_label"]
    #         idx = (
    #             indices[-d[k].index.nlevels :]
    #             if d[k].index.nlevels > 1
    #             else [indices[-2]]
    #         )
    #         d[k].index.names = idx
    #         toreturn = d
    #     if save:
    #         self.save_to_hdf(toreturn)
    #     return toreturn

    def save_to_hdf(self, dict_series, path=None):
        """
        Save the extracted data to the h5 file.

        Parameters
        ----------
        dict_series: dict
            A dictionary of the extracted data, created by run.
        path: Path (optional)
            To the h5 file. Defaults to self.local.
        """
        if path is None:
            path = self.local
        self.writer = Writer(path)
        # each entry is written under "/extraction/" using its key as the name
        for extract_name, series in dict_series.items():
            dset_path = "/extraction/" + extract_name
            self.writer.write(dset_path, series)
        self.writer.id_cache.clear()

    def get_meta(self, flds):
        """
        Obtain metadata for one or multiple fields.

        Parameters
        ----------
        flds: str or iterable of str
            Name, or names, of the metadata fields. Only the part of a
            metadata key after its last "/" is compared with these names.

        Returns
        -------
        dict
            The value of each field, taken from self.meta, else from
            self.default_meta, else None.
        """
        # make flds a list; a string is iterable but is a single field name,
        # not a collection of one-character names
        if isinstance(flds, str) or not hasattr(flds, "__iter__"):
            flds = [flds]
        meta_short = {k.split("/")[-1]: v for k, v in self.meta.items()}
        return {
            f: meta_short.get(f, self.default_meta.get(f, None)) for f in flds
        }


### Helpers
def flatten_nesteddict(nest: dict, to="series", tp: int = None) -> dict:
    """
    Converts a nested extraction dict into a dict of pd.Series

    Parameters
    ----------
    nest: dict of dicts
        Contains the nested results of extraction.
    to: str (optional)
        Specifies the format of the output, either pd.Series (default) or a list
    tp: int
        Timepoint used to name the pd.Series

    Returns
    -------
    d: dict
        A dict with a concatenated string of channel, reduction metric, and cell metric as keys and either a pd.Series or a list of the corresponding extracted data as values.
    """
    as_series = to == "series"
    d = {}
    for k0, v0 in nest.items():
        for k1, v1 in v0.items():
            for k2, v2 in v1.items():
                # v2 is unpacked as the positional arguments of pd.Series,
                # so a (results, indices) pair becomes its data and index
                d["/".join((k0, k1, k2))] = (
                    pd.Series(*v2, name=tp) if as_series else v2
                )
    return d


# Alan: this is still called by ExtractorParameters.__init__
def fill_tree(tree):
    """
    Wrap a shallow extraction tree in dicts keyed by None.

    Parameters
    ----------
    tree: dict or None
        The extraction tree, expected to be of depth three.

    Returns
    -------
    None if tree is None; tree itself if depth(tree) is at least three;
    otherwise a dict with the single key None whose value is tree.
    """
    if tree is None:
        return None
    tree_depth = depth(tree)
    if tree_depth >= 3:
        return tree
    d = {None: {None: {None: []}}}
    # NOTE(review): for a tree of depth below two, d is rebound to an inner
    # dict here, so the result has only one wrapping level and is not of
    # depth three - confirm whether that is intended.
    for _ in range(2 - tree_depth):
        d = d[None]
    d[None] = tree
    return d


class hollowExtractor(Extractor):
    """
    Extractor that only cares about receiving images and masks.
    Used for testing.
    """

    def __init__(self, parameters):
        self.params = parameters