Newer
Older
from typing import Callable, Dict, List
from agora.abc import ParametersABC, ProcessABC
from agora.io.cells import Cells
from agora.io.writer import Writer, load_attributes
from aliby.tile.tiler import Tiler
from extraction.core.functions.defaults import exparams_from_meta
from extraction.core.functions.distributors import reduce_z, trap_apply
from extraction.core.functions.loaders import (
load_custom_args,
)
from extraction.core.functions.utils import depth
# Module-level registries of the functions used during extraction, which
# analyse either cells or their background. Keeping them in global
# dictionaries lets the functions be looked up by name only when needed and
# be defined simply in extraction/core/functions.
# CELL_FUNS and FUNS are combined with the custom functions in
# Extractor.load_funs.
CELL_FUNS, TRAPFUNS, FUNS = load_funs()
# CUSTOM_FUNS maps a name to a function that needs extra arguments;
# CUSTOM_ARGS is keyed by function name and lists the metadata fields that
# supply those arguments (see Extractor.load_custom_funs).
CUSTOM_FUNS, CUSTOM_ARGS = load_custom_args()
# Reduction functions, looked up by name (RED_FUNS[red_fun]).
RED_FUNS = load_redfuns()
# Functions merging several channels, looked up by name (MERGE_FUNS[merge_fun]).
MERGE_FUNS = load_mergefuns()
# Assign datatype depending on the metric used
# m2type = {"mean": np.float32, "median": np.ubyte, "imBackground": np.ubyte}
class ExtractorParameters(ParametersABC):
    """
    Base class to define parameters for extraction.

    Attributes
    ----------
    tree: dict or None
        Extraction tree after being passed through fill_tree.
    sub_bg: set
        Channels for which metrics are also computed on
        background-subtracted images.
    multichannel_ops: dict
        Operations combining several channels; each value is unpacked by
        the Extractor as (channels, merge function name, reduction metrics).
    """

    def __init__(
        self,
        tree: Dict[str, Dict[Callable, List[str]]] = None,
        sub_bg: set = None,
        multichannel_ops: Dict = None,
    ):
        """
        Parameters
        ----------
        tree: dict
            Nested dictionary indicating channels, reduction functions and
            metrics to be used.
            str channel -> U(function,None) reduction -> str metric
            If not of depth three, tree will be filled with Nones.
        sub_bg: set (optional)
            Channels to analyse with the background subtracted.
            Defaults to a new empty set.
        multichannel_ops: dict (optional)
            Operations that use more than one channel.
            Defaults to a new empty dict.
        """
        self.tree = fill_tree(tree)
        # Build the containers per instance: a literal set()/{} in the
        # signature would be created once and shared by every instance
        # that relies on the default.
        self.sub_bg = set() if sub_bg is None else sub_bg
        self.multichannel_ops = (
            {} if multichannel_ops is None else multichannel_ops
        )

    @staticmethod
    def guess_from_meta(store_name: str, suffix="fast"):
        """
        Guess the name of a parameter set from the metadata of an h5 file.

        Parameters
        ----------
        store_name : str or Path
            For a h5 file
        suffix : str
            Added at the end of the predicted parameter set

        Returns
        -------
        str
            The "microscope" attribute of the file's root group joined to
            suffix with an underscore.

        Raises
        ------
        AssertionError
            If the root group has no, or an empty, "microscope" attribute.
        """
        # NOTE(review): h5py is not among the imports visible in this
        # section of the file - confirm it is imported at module level.
        with h5py.File(store_name, "r") as f:
            microscope = f["/"].attrs.get("microscope")
        assert microscope, "No metadata found"
        return "_".join((microscope, suffix))

    @classmethod
    def default(cls):
        """Create parameters from an empty tree."""
        return cls({})

    @classmethod
    def from_meta(cls, meta):
        """Create parameters from the keyword arguments derived from meta."""
        return cls(**exparams_from_meta(meta))
class Extractor(ProcessABC):
"""
The Extractor applies a metric, such as area or median, to cells identified in the image tiles using the cell masks.
Its methods therefore require both tile images and masks.
Usually one metric is applied per mask, but there are tile-specific backgrounds (Alan), which apply one metric per tile.
Extraction follows a three-level tree structure. Channels, such as GFP, are the root level; the second level is the reduction algorithm, such as maximum projection; the last level is the metric - the specific operation to apply to the cells in the image identified by the mask, such as median, which is the median value of the pixels in each cell.
parameters: core.extractor Parameters
Parameters that include the channels, reduction and
extraction functions to use.
store: str
Path to hdf5 storage file. Must contain cell outlines.
tiler: pipeline-core.core.segmentation tiler
Class that contains or fetches the image to be used for segmentation.
default_meta = {
"pixel_size": 0.236,
"z_size": 0.6,
"spacing": 0.6,
}
self,
parameters: ExtractorParameters,
store: str = None,
tiler: Tiler = None,
"""
Initialise Extractor.
Parameters
----------
parameters: ExtractorParameters object
store: str
Name of h5 file
tiler: Tiler object
"""
if store:
self.local = store
self.load_meta()
self.meta = {"channel": parameters.to_dict()["tree"].keys()}
if tiler:
self.tiler = tiler
cls,
parameters: ExtractorParameters,
store: str,
tiler: Tiler,
return cls(parameters, store=store, tiler=tiler)
@classmethod
cls,
parameters: ExtractorParameters,
store: str,
img_meta: tuple,
return cls(parameters, store=store, tiler=Tiler(*img_meta))
@property
def channels(self):
    """
    Channel names, taken from the top-level keys of the extraction tree.

    The tuple is computed on first access and cached in self._channels,
    so later changes to self.params.tree are not reflected.

    Raises
    ------
    AttributeError
        If nothing is cached and self.params.tree is not a dict.
    """
    if not hasattr(self, "_channels"):
        # isinstance rather than an exact type comparison, so that dict
        # subclasses such as OrderedDict are accepted too
        if isinstance(self.params.tree, dict):
            self._channels = tuple(self.params.tree.keys())
    return self._channels
@property
def current_position(self):
    """
    Name of the position, derived from the file name in self.local.

    Takes the text after the last "/" and drops its final three
    characters, which assumes a ".h5" suffix.
    """
    # str() lets self.local be a pathlib.Path as well as a plain string;
    # a Path has no split method.
    file_name = str(self.local).split("/")[-1]
    return file_name[:-3]
@property
if not hasattr(self, "_out_path"):
self._group = "/extraction/"
return self._group
def load_custom_funs(self):
"""
Define any custom functions to be functions of cell_masks and trap_image only.
Any other parameters are taken from the experiment's metadata and automatically applied. These parameters therefore must be loaded within an Extractor instance.
funs = set(
[
fun
for ch in self.params.tree.values()
for red in ch.values()
for fun in red
]
)
k: {k2: self.get_meta(k2) for k2 in v}
for k, v in CUSTOM_ARGS.items()
# define custom functions - those with extra arguments other than cell_masks and trap_image - as functions of two variables
self._custom_funs = {}
for k, f in CUSTOM_FUNS.items():
def tmp(f):
# pass extra arguments to custom function
return lambda cell_masks, trap_image: trap_apply(
f, cell_masks, trap_image, **ARG_VALS.get(k, {})
self._custom_funs[k] = tmp(f)
def load_funs(self):
    """
    Build the function look-ups used for extraction.

    Calls load_custom_funs, then sets
    - self._all_cell_funs: the custom function names together with
      CELL_FUNS, as a set;
    - self._all_funs: the custom functions merged with FUNS, where an
      entry in FUNS wins over a custom function of the same name.
    """
    self.load_custom_funs()
    custom = self._custom_funs
    self._all_cell_funs = set(custom).union(CELL_FUNS)
    merged = dict(custom)
    # applied second so that FUNS overrides same-named custom functions
    merged.update(FUNS)
    self._all_funs = merged
def load_meta(self):
    """Store in self.meta the result of load_attributes for self.local."""
    attributes = load_attributes(self.local)
    self.meta = attributes
def get_traps(
self,
tp: int,
channels: list = None,
z: list = None,
**kwargs,
"""
Finds traps for a given time point and given channels and z-stacks.
Returns None if no traps are found.
Any additional keyword arguments are passed to tiler.get_traps_timepoint
Parameters
----------
tp: int
Time point of interest
channels: list of strings (optional)
Channels of interest
z: list of integers (optional)
Indices for the z-stacks of interest
"""
channel_ids = list(range(len(self.tiler.channels)))
elif len(channels):
channel_ids = [self.tiler.get_channel_index(ch) for ch in channels]
else:
if z is None:
z = list(range(self.tiler.shape[-1]))
self.tiler.get_traps_timepoint(
tp, channels=channel_ids, z=z, **kwargs
)
return traps
def extract_traps(
self,
traps: List[np.array],
masks: List[np.array],
metric: str,
traps: list of arrays
List of images.
masks: list of arrays
List of masks.
Metric to extract.
labels: dict
A dict of cell labels with trap_ids as keys and a list of cell labels as values.
res_idx: a tuple of tuples
A two-tuple of a tuple of results and a tuple with the corresponding trap_id and cell labels
raise Warning("No labels given. Sorting cells using index.")
cell_fun = True if metric in self._all_cell_funs else False
idx = []
results = []
for trap_id, (mask_set, trap, lbl_set) in enumerate(
zip(masks, traps, labels.values())
):
# ignore empty traps
if len(mask_set):
# apply metric either a cell function or otherwise
result = self._all_funs[metric](mask_set, trap)
if cell_fun:
for lbl, val in zip(lbl_set, result):
results.append(val)
idx.append((trap_id, lbl))
else:
self,
traps: List[np.array],
masks: List[np.array],
metrics: List[str],
**kwargs,
Returns dict with metrics as key and metrics applied to data as values for data from one timepoint.
"""
d = {
metric: self.extract_traps(
traps=traps, masks=masks, metric=metric, **kwargs
)
for metric in metrics
}
return d
def reduce_extract(
self,
traps: np.array,
masks: list,
red_metrics: dict,
**kwargs,
Wrapper to apply reduction and then extraction.
Parameters
----------
traps: array
An array of image data arranged as (traps, X, Y, Z)
masks: list of arrays
An array of masks for each trap: one per cell at the trap
red_metrics: dict
dict for which keys are reduction functions and values are either a list or a set of strings giving the metric functions.
For example: {'np_max': {'max5px', 'mean', 'median'}}
All other arguments and must include masks and traps. Alan: still true?
Returns
------
Dictionary of dataframes with the corresponding reductions and metrics nested.
# create dict with keys naming the reduction in the z-direction and the reduced data as values
reduced_traps = {}
if traps is not None:
for red_fun in red_metrics.keys():
reduced_traps[red_fun] = [
self.reduce_dims(trap, method=RED_FUNS[red_fun])
for trap in traps
]
d = {
red_fun: self.extract_funs(
metrics=metrics,
traps=reduced_traps.get(red_fun, [None for _ in masks]),
masks=masks,
**kwargs,
)
for red_fun, metrics in red_metrics.items()
}
return d
def reduce_dims(self, img: np.array, method=None) -> np.array:
Collapse a z-stack into 2d array using method.
If method is None, return the original data.
Parameters
----------
img: array
An array of the image data arranged as (X, Y, Z)
method: function
The reduction function
self,
tp: int,
tree: dict = None,
tile_size: int = 117,
masks=None,
labels=None,
**kwargs,
"""
Core extraction method for an individual time-point.
Parameters
----------
tp : int
tree : dict
Nested dictionary indicating channels, reduction functions and
For example: {'general': {'None': ['area', 'volume', 'eccentricity']}}
masks : list of arrays
A list of masks per trap with each mask having dimensions (ncells, tile_size,
labels : dict
A dictionary with trap_ids as keys and cell_labels as values.
**kwargs : keyword arguments
Passed to extractor.reduce_extract.
d: dict
Dictionary of the results with three levels of dictionaries.
The first level has channels as keys.
The second level has reduction metrics as keys.
The third level has cell or background metrics as keys and a two-tuple as values.
The first tuple is the result of applying the metrics to a particular cell or trap; the second tuple is either (trap_id, cell_label) for a metric applied to a cell or a trap_id for a metric applied to a trap.
An example is d["GFP"]["np_max"]["mean"][0], which gives a tuple of the calculated mean GFP fluorescence for all cells.
# dictionary with channel: {reduction algorithm : metric}
ch_tree = {ch: v for ch, v in tree.items() if ch != "general"}
if labels is None:
raw_labels = cells.labels_at_time(tp)
labels = {
trap_id: raw_labels.get(trap_id, [])
for trap_id in range(cells.ntraps)
# find the cell masks for a given trap as a dict with trap_ids as keys
if masks is None:
raw_masks = cells.at_time(tp, kind="mask")
masks = {trap_id: [] for trap_id in range(cells.ntraps)}
for trap_id, cells in raw_masks.items():
if len(cells):
masks[trap_id] = np.dstack(np.array(cells)).astype(bool)
# find image data at the time point
# stored as an array arranged as (traps, channels, timepoints, X, Y, Z)
# Alan: traps does not appear the best name here!
traps = self.get_traps(tp, tile_shape=tile_size, channels=tree_chs)
# generate boolean masks for background as a list with one mask per trap
~np.sum(m, axis=2).astype(bool)
if np.any(m)
else np.zeros((tile_size, tile_size))
for m in masks
]
if ch != "general" and traps is not None and len(traps):
# image data for all traps and z sections for a particular channel
# as an array arranged as (no traps, X, Y, no Z channels)
red_metrics=red_metrics,
traps=img,
masks=masks,
labels=labels,
**kwargs,
# apply metrics to image data with the background subtracted
if ch in self.params.sub_bg and img is not None:
# calculate metrics with subtracted bg
ch_bs = ch + "_bgsub"
self.img_bgsub[ch_bs] = []
# Alan: should this not be is_not_cell?
is_cell = np.where(bg)
# skip for empty traps
if len(is_cell[0]):
d[ch_bs] = self.reduce_extract(
red_metrics=ch_tree[ch],
traps=self.img_bgsub[ch_bs],
masks=masks,
labels=labels,
**kwargs,
)
# apply any metrics that use multiple channels (eg pH calculations)
for name, (
chs,
merge_fun,
red_metrics,
) in self.params.multichannel_ops.items():
set(chs).intersection(
set(self.img_bgsub.keys()).union(tree_chs)
)
) == len(chs):
imgs = [self.get_imgs(ch, traps, tree_chs) for ch in chs]
merged = MERGE_FUNS[merge_fun](*imgs)
d[name] = self.reduce_extract(
red_metrics=red_metrics,
traps=merged,
masks=masks,
labels=labels,
**kwargs,
)
return d
def get_imgs(self, channel, traps, channels=None):
    """
    Return the image for a channel from the correct source, raw or bgsub.

    Parameters
    ----------
    channel: str
        Name of channel to get.
    traps: ndarray
        An array of the image data having dimensions of
        (trap_id, channel, tp, tile_size, tile_size, n_zstacks).
    channels: list of str (optional)
        List of available channels, in the order of the channel axis of
        traps. Defaults to the keys of self.params.tree.

    Returns
    -------
    img: ndarray or None
        For a raw channel, traps[:, channel index, 0]; otherwise the entry
        stored in self.img_bgsub. None if the channel is found in neither.
    """
    available = tuple(self.params.tree) if channels is None else channels
    if channel in available:
        position = available.index(channel)
        return traps[:, position, 0]
    if channel in self.img_bgsub:
        return self.img_bgsub[channel]
    return None
Wrapper to add compatibility with other steps of the pipeline.
self,
tree=None,
tps: List[int] = None,
save=True,
**kwargs,
"""
Parameters
----------
tree: dict
Nested dictionary indicating channels, reduction functions and
metrics to be used.
For example: {'general': {'None': ['area', 'volume', 'eccentricity']}}
tps: list of int (optional)
Time points to include.
save: boolean (optional)
If True, save results to h5 file.
kwargs: keyword arguments (optional)
Passed to extract_tp.
Returns
-------
d: dict
A dict of the extracted data with a concatenated string of channel, reduction metric, and cell metric as keys and pd.Series of the extracted data as values.
"""
if tree is None:
tree = self.params.tree
if tps is None:
tps = list(range(self.meta["time_settings/ntimepoints"][0]))
# extract for each time point and convert to dict of pd.Series
new = flatten_nesteddict(
self.extract_tp(tp=tp, tree=tree, **kwargs),
to="series",
tp=tp,
)
d[k] = pd.concat((d.get(k, None), new[k]), axis=1)
# add indices to pd.Series containing the extracted data
for k in d.keys():
indices = ["experiment", "position", "trap", "cell_label"]
idx = (
indices[-d[k].index.nlevels :]
if d[k].index.nlevels > 1
else [indices[-2]]
)
d[k].index.names = idx
# Alan: isn't this identical to run?
# def extract_pos(
# self, tree=None, tps: List[int] = None, save=True, **kwargs
# ) -> dict:
# if tps is None:
# tps = list(range(self.meta["time_settings/ntimepoints"]))
# d = {}
# for tp in tps:
# new = flatten_nest(
# self.extract_tp(tp=tp, tree=tree, **kwargs),
# to="series",
# tp=tp,
# )
# for k in new.keys():
# n = new[k]
# d[k] = pd.concat((d.get(k, None), n), axis=1)
# for k in d.keys():
# indices = ["experiment", "position", "trap", "cell_label"]
# idx = (
# indices[-d[k].index.nlevels :]
# if d[k].index.nlevels > 1
# else [indices[-2]]
# )
# d[k].index.names = idx
def save_to_hdf(self, dict_series, path=None):
    """
    Save the extracted data to the h5 file.

    A new Writer is created for the target file and kept as self.writer.
    Each series is written under "/extraction/" followed by its key, and
    the writer's id_cache is cleared once everything is written.

    Parameters
    ----------
    dict_series: dict
        A dictionary of the extracted data, created by run.
    path: Path (optional)
        To the h5 file. Defaults to self.local.
    """
    target = self.local if path is None else path
    self.writer = Writer(target)
    for extract_name, series in dict_series.items():
        self.writer.write("/extraction/" + extract_name, series)
    self.writer.id_cache.clear()
def get_meta(self, flds):
    """
    Look up metadata values for one or several fields.

    Keys of self.meta are shortened to the part after their last "/", so
    a key such as "time_settings/ntimepoints" is found as "ntimepoints".
    A field missing from the metadata falls back to self.default_meta
    and then to None.

    Parameters
    ----------
    flds: str or iterable of str
        A single field name or a collection of field names.

    Returns
    -------
    dict
        Maps each requested field to its value.
    """
    # A string is itself iterable: without the explicit check a single
    # name such as "pixel_size" would be looked up character by character.
    if isinstance(flds, (str, bytes)) or not hasattr(flds, "__iter__"):
        flds = [flds]
    meta_short = {key.split("/")[-1]: val for key, val in self.meta.items()}
    return {
        fld: meta_short.get(fld, self.default_meta.get(fld, None))
        for fld in flds
    }
def flatten_nesteddict(nest: dict, to="series", tp: int = None) -> dict:
    """
    Flatten a three-level nested dict of results into a single-level dict.

    Parameters
    ----------
    nest: dict of dicts
        Contains the nested results of extraction.
    to: str (optional)
        Specifies the format of the output, either pd.Series (default) or a list
    tp: int
        Timepoint used to name the pd.Series

    Returns
    -------
    d: dict
        A dict with a concatenated string of channel, reduction metric, and cell metric as keys and either a pd.Series or a list of the corresponding extracted data as values.
    """
    # the output format is the same for every entry, so decide it once
    as_series = to == "series"
    d = {}
    for k0, v0 in nest.items():
        for k1, v1 in v0.items():
            for k2, v2 in v1.items():
                key = "/".join((k0, k1, k2))
                # v2 is unpacked positionally into pd.Series, i.e. as
                # (data, index) for a two-tuple
                d[key] = pd.Series(*v2, name=tp) if as_series else v2
    return d
def fill_tree(tree):
    """
    Wrap a tree of depth less than three in a None-keyed level.

    Parameters
    ----------
    tree: dict or None
        Extraction tree; its depth is measured with depth().

    Returns
    -------
    None if tree is None; tree itself if its depth is at least three;
    otherwise the scaffold level to which tree was attached.
    """
    if tree is None:
        return None
    tree_depth = depth(tree)
    if tree_depth >= 3:
        return tree
    # scaffold of three None-keyed levels; walk down one level for every
    # level the tree is short of two
    node = {None: {None: {None: []}}}
    for _ in range(2 - tree_depth):
        node = node[None]
    node[None] = tree
    # NOTE(review): node is the level the tree was attached to, not the
    # outermost scaffold, so the result is {None: tree} whatever the
    # depth - confirm that this is the intended filling.
    return node
"""
Extractor that only cares about receiving images and masks.
Used for testing.
"""
def __init__(self, parameters):
self.params = parameters