Newer
Older
from typing import Callable, Dict, List
from agora.abc import ParametersABC, ProcessABC
from agora.io.cells import Cells
from agora.io.writer import Writer, load_attributes
from aliby.tile.tiler import Tiler
from extraction.core.functions.defaults import exparams_from_meta
from extraction.core.functions.distributors import reduce_z, trap_apply
from extraction.core.functions.loaders import (
load_custom_args,
)
from extraction.core.functions.utils import depth
# Global parameters used to load functions that either analyse cells or their background. These global parameters both allow the functions to be stored in a dictionary for access only on demand and to be defined simply in extraction/core/functions.
CELL_FUNS, TRAPFUNS, FUNS = load_funs()
CUSTOM_FUNS, CUSTOM_ARGS = load_custom_args()
RED_FUNS = load_redfuns()
MERGE_FUNS = load_mergefuns()
# Assign datatype depending on the metric used
# m2type = {"mean": np.float32, "median": np.ubyte, "imBackground": np.ubyte}
class ExtractorParameters(ParametersABC):
"""
Base class to define parameters for extraction
"""
def __init__(
self,
tree: Dict[str, Dict[Callable, List[str]]] = None,
sub_bg: set = set(),
multichannel_ops: Dict = {},
):
"""
Parameters
----------
tree: dict
Nested dictionary indicating channels, reduction functions and
metrics to be used.
str channel -> U(function,None) reduction -> str metric
If not of depth three, tree will be filled with Nones.
sub_bg: set
multichannel_ops: dict
"""
self.tree = fill_tree(tree)
self.sub_bg = sub_bg
self.multichannel_ops = multichannel_ops
@staticmethod
def guess_from_meta(store_name: str, suffix="fast"):
"""
Parameters
----------
store_name : str or Path
For a h5 file
suffix : str
Added at the end of the predicted parameter set
with h5py.File(store_name, "r") as f:
microscope = f["/"].attrs.get("microscope")
assert microscope, "No metadata found"
return "_".join((microscope, suffix))
@classmethod
def default(cls):
return cls({})
@classmethod
def from_meta(cls, meta):
return cls(**exparams_from_meta(meta))
class Extractor(ProcessABC):
"""
The Extractor applies a metric, such as area or median, to cells identified in the image tiles using the cell masks.
Its methods therefore require both tile images and masks.
Usually one metric is applied per mask, but there are tile-specific backgrounds (Alan), which apply one metric per tile.
Extraction follows a three-level tree structure. Channels, such as GFP, are the root level; the second level is the reduction algorithm, such as maximum projection; the last level is the metric - the specific operation to apply to the cells in the image identified by the mask, such as median, which is the median value of the pixels in each cell.
parameters: core.extractor Parameters
Parameters that include with channels, reduction and
extraction functions to use.
store: str
Path to hdf5 storage file. Must contain cell outlines.
tiler: pipeline-core.core.segmentation tiler
Class that contains or fetches the image to be used for segmentation.
default_meta = {
"pixel_size": 0.236,
"z_size": 0.6,
"spacing": 0.6,
}
self,
parameters: ExtractorParameters,
store: str = None,
tiler: Tiler = None,
"""
Initialise Extractor.
Parameters
----------
parameters: ExtractorParameters object
store: str
Name of h5 file
tiler: Tiler object
"""
if store:
self.local = store
self.load_meta()
self.meta = {"channel": parameters.to_dict()["tree"].keys()}
if tiler:
self.tiler = tiler
cls,
parameters: ExtractorParameters,
store: str,
tiler: Tiler,
return cls(parameters, store=store, tiler=tiler)
@classmethod
cls,
parameters: ExtractorParameters,
store: str,
img_meta: tuple,
return cls(parameters, store=store, tiler=Tiler(*img_meta))
@property
def channels(self):
if not hasattr(self, "_channels"):
if type(self.params.tree) is dict:
self._channels = tuple(self.params.tree.keys())
return self._channels
@property
def current_position(self):
return self.local.split("/")[-1][:-3]
@property
if not hasattr(self, "_out_path"):
self._group = "/extraction/"
return self._group
def load_custom_funs(self):
"""
Define any custom functions to be functions of cell_masks and trap_image only.
Any other parameters are taken from the experiment's metadata and automatically applied. These parameters therefore must be loaded within an Extractor instance.
funs = set(
[
fun
for ch in self.params.tree.values()
for red in ch.values()
for fun in red
]
)
k: {k2: self.get_meta(k2) for k2 in v}
for k, v in CUSTOM_ARGS.items()
# define custom functions - those with extra arguments other than cell_masks and trap_image - as functions of two variables
self._custom_funs = {}
for k, f in CUSTOM_FUNS.items():
def tmp(f):
# pass extra arguments to custom function
return lambda cell_masks, trap_image: trap_apply(
f, cell_masks, trap_image, **ARG_VALS.get(k, {})
self._custom_funs[k] = tmp(f)
def load_funs(self):
self.load_custom_funs()
self._all_cell_funs = set(self._custom_funs.keys()).union(CELL_FUNS)
self._all_funs = {**self._custom_funs, **FUNS}
def load_meta(self):
self.meta = load_attributes(self.local)
def get_traps(
self, tp: int, channels: list = None, z: list = None, **kwargs
) -> tuple:
"""
Finds traps for a given time point and given channels and z-stacks.
Returns None if no traps are found.
Any additional keyword arguments are passed to tiler.get_traps_timepoint
Parameters
----------
tp: int
Time point of interest
channels: list of strings (optional)
Channels of interest
z: list of integers (optional)
Indices for the z-stacks of interest
"""
channel_ids = list(range(len(self.tiler.channels)))
elif len(channels):
channel_ids = [self.tiler.get_channel_index(ch) for ch in channels]
else:
if z is None:
z = list(range(self.tiler.shape[-1]))
self.tiler.get_traps_timepoint(
tp, channels=channel_ids, z=z, **kwargs
)
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
if channel_ids
else None
)
return traps
def extract_traps(
self,
traps: List[np.array],
masks: List[np.array],
metric: str,
labels: List[int] = None,
) -> dict:
"""
Apply a function for a whole position.
:traps: List[np.array] list of images
:masks: List[np.array] list of masks
:metric:str metric to extract
:labels: List[int] cell Labels to use as indices for output DataFrame
:pos_info: bool Whether to add the position as index or not
returns
:d: Dictionary of dataframe
"""
if labels is None:
raise Warning("No labels given. Sorting cells using index.")
cell_fun = True if metric in self._all_cell_funs else False
idx = []
results = []
for trap_id, (mask_set, trap, lbl_set) in enumerate(
zip(masks, traps, labels.values())
):
if len(mask_set): # ignore empty traps
result = self._all_funs[metric](mask_set, trap)
if cell_fun:
for lbl, val in zip(lbl_set, result):
results.append(val)
idx.append((trap_id, lbl))
else:
results.append(result)
idx.append(trap_id)
return (tuple(results), tuple(idx))
def extract_funs(
self,
traps: List[np.array],
masks: List[np.array],
metrics: List[str],
**kwargs,
) -> dict:
"""
Extract multiple metrics from a timepoint
"""
d = {
metric: self.extract_traps(
traps=traps, masks=masks, metric=metric, **kwargs
)
for metric in metrics
}
return d
def reduce_extract(
self,
traps: np.array,
masks: list,
red_metrics: dict,
**kwargs,
Wrapper to apply reduction and then extraction.
Parameters
----------
param red_metrics: dict
dict for which keys are reduction funcions and values are strings indicating the metric function
**kwargs: dict
All other arguments and must include masks and traps.
Returns
------
Dictionary of dataframes with the corresponding reductions and metrics nested.
reduced_traps = {}
if traps is not None:
for red_fun in red_metrics.keys():
reduced_traps[red_fun] = [
self.reduce_dims(trap, method=RED_FUNS[red_fun])
for trap in traps
]
d = {
red_fun: self.extract_funs(
metrics=metrics,
traps=reduced_traps.get(red_fun, [None for _ in masks]),
masks=masks,
**kwargs,
)
for red_fun, metrics in red_metrics.items()
}
return d
def reduce_dims(self, img: np.array, method=None) -> np.array:
self,
tp: int,
tree: dict = None,
tile_size: int = 117,
masks=None,
labels=None,
**kwargs,
) -> t.Dict[str, t.Dict[str, pd.Series]]:
"""
Core extraction method for an individual time-point.
Parameters
----------
tp : int
tree : dict
Nested dictionary indicating channels, reduction functions and
tile_size : int
size of the tile to be extracted.
masks : np.ndarray
A 3-D boolean numpy array with dimensions (ncells, tile_size,
tile_size).
List of lists of ints indicating the ids of masks.
**kwargs : Additional keyword arguments to be passed to extractor.reduce_extract.
Returns
-------
dict
# dictionary with channel: {reduction algorithm : metric}
ch_tree = {ch: v for ch, v in tree.items() if ch != "general"}
if labels is None:
raw_labels = cells.labels_at_time(tp)
labels = {
trap_id: raw_labels.get(trap_id, [])
for trap_id in range(cells.ntraps)
if masks is None:
raw_masks = cells.at_time(tp, kind="mask")
masks = {trap_id: [] for trap_id in range(cells.ntraps)}
for trap_id, cells in raw_masks.items():
if len(cells):
masks[trap_id] = np.dstack(np.array(cells)).astype(bool)
traps = self.get_traps(tp, tile_shape=tile_size, channels=tree_chs)
self.img_bgsub = {}
if self.params.sub_bg:
bg = [
~np.sum(m, axis=2).astype(bool)
if np.any(m)
else np.zeros((tile_size, tile_size))
for m in masks
]
d = {}
for ch, red_metrics in tree.items():
img = None
# ch != is necessary for threading
if ch != "general" and traps is not None and len(traps):
img = traps[:, tree_chs.index(ch), 0]
d[ch] = self.reduce_extract(
red_metrics=red_metrics,
traps=img,
masks=masks,
labels=labels,
**kwargs,
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
)
if (
ch in self.params.sub_bg and img is not None
): # Calculate metrics with subtracted bg
ch_bs = ch + "_bgsub"
self.img_bgsub[ch_bs] = []
for trap, maskset in zip(img, bg):
cells_fl = np.zeros_like(trap)
is_cell = np.where(maskset)
if len(is_cell[0]): # skip calculation for empty traps
cells_fl = np.median(trap[is_cell], axis=0)
self.img_bgsub[ch_bs].append(trap - cells_fl)
d[ch_bs] = self.reduce_extract(
red_metrics=ch_tree[ch],
traps=self.img_bgsub[ch_bs],
masks=masks,
labels=labels,
**kwargs,
)
# Additional operations between multiple channels (e.g. pH calculations)
for name, (
chs,
merge_fun,
red_metrics,
) in self.params.multichannel_ops.items():
set(chs).intersection(
set(self.img_bgsub.keys()).union(tree_chs)
)
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
) == len(chs):
imgs = [self.get_imgs(ch, traps, tree_chs) for ch in chs]
merged = MERGE_FUNS[merge_fun](*imgs)
d[name] = self.reduce_extract(
red_metrics=red_metrics,
traps=merged,
masks=masks,
labels=labels,
**kwargs,
)
return d
def get_imgs(self, channel, traps, channels=None):
"""
Returns the image from a correct source, either raw or bgsub
:channel: str name of channel to get
:img: ndarray (trap_id, channel, tp, tile_size, tile_size, n_zstacks) of standard channels
:channels: List of channels
"""
if channels is None:
channels = (*self.params.tree,)
if channel in channels:
return traps[:, channels.index(channel), 0]
elif channel in self.img_bgsub:
return self.img_bgsub[channel]
def run_tp(self, tp, **kwargs):
"""
Wrapper to add compatiblibility with other pipeline steps
"""
return self.run(tps=[tp], **kwargs)
def run(
self, tree=None, tps: List[int] = None, save=True, **kwargs
) -> dict:
if tree is None:
tree = self.params.tree
if tps is None:
tps = list(range(self.meta["time_settings/ntimepoints"][0]))
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
d = {}
for tp in tps:
new = flatten_nest(
self.extract_tp(tp=tp, tree=tree, **kwargs),
to="series",
tp=tp,
)
for k in new.keys():
n = new[k]
d[k] = pd.concat((d.get(k, None), n), axis=1)
for k in d.keys():
indices = ["experiment", "position", "trap", "cell_label"]
idx = (
indices[-d[k].index.nlevels :]
if d[k].index.nlevels > 1
else [indices[-2]]
)
d[k].index.names = idx
toreturn = d
if save:
self.save_to_hdf(toreturn)
return toreturn
def extract_pos(
self, tree=None, tps: List[int] = None, save=True, **kwargs
) -> dict:
if tree is None:
tree = self.params.tree
if tps is None:
tps = list(range(self.meta["time_settings/ntimepoints"]))
d = {}
for tp in tps:
new = flatten_nest(
self.extract_tp(tp=tp, tree=tree, **kwargs),
to="series",
tp=tp,
)
for k in new.keys():
n = new[k]
d[k] = pd.concat((d.get(k, None), n), axis=1)
for k in d.keys():
indices = ["experiment", "position", "trap", "cell_label"]
idx = (
indices[-d[k].index.nlevels :]
if d[k].index.nlevels > 1
else [indices[-2]]
)
d[k].index.names = idx
toreturn = d
if save:
self.save_to_hdf(toreturn)
return toreturn
def save_to_hdf(self, group_df, path=None):
if path is None:
path = self.local
self.writer = Writer(path)
for path, df in group_df.items():
dset_path = "/extraction/" + path
self.writer.write(dset_path, df)
self.writer.id_cache.clear()
def get_meta(self, flds):
if not hasattr(flds, "__iter__"):
flds = [flds]
meta_short = {k.split("/")[-1]: v for k, v in self.meta.items()}
return {
f: meta_short.get(f, self.default_meta.get(f, None)) for f in flds
}
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
### Helpers
def flatten_nest(nest: dict, to="series", tp: int = None) -> dict:
"""
Convert a nested extraction dict into a dict of series
:param nest: dict contained the nested results of extraction
:param to: str = 'series' Determine output format, either list or pd.Series
:param tp: int timepoint used to name the series
"""
d = {}
for k0, v0 in nest.items():
for k1, v1 in v0.items():
for k2, v2 in v1.items():
d["/".join((k0, k1, k2))] = (
pd.Series(*v2, name=tp) if to == "series" else v2
)
return d
def fill_tree(tree):
if tree is None:
return None
tree_depth = depth(tree)
if depth(tree) < 3:
d = {None: {None: {None: []}}}
for _ in range(2 - tree_depth):
d = d[None]
d[None] = tree
tree = d
return tree