Skip to content
Snippets Groups Projects
Commit 4fc31051 authored by pswain's avatar pswain
Browse files

extended bgsub to multichannel functions

parent e5d83cf7
No related branches found
No related tags found
No related merge requests found
......@@ -306,15 +306,15 @@ class Extractor(StepABC):
# tiles has dimensions (tiles, channels, 1, Z, X, Y)
return tiles
def extract_traps(
def apply_cell_function(
self,
traps: t.List[np.ndarray],
masks: t.List[np.ndarray],
cell_property: str,
cell_function: str,
cell_labels: t.Dict[int, t.List[int]],
) -> t.Tuple[t.Union[t.Tuple[float], t.Tuple[t.Tuple[int]]]]:
"""
Apply a function to a whole position.
Apply a cell function to all cells at all traps for one time point.
Parameters
----------
......@@ -322,11 +322,11 @@ class Extractor(StepABC):
t.List of images.
masks: list of arrays
t.List of masks.
cell_property: str
Property to extract, including imBackground.
cell_function: str
Function to apply.
cell_labels: dict
A dict of cell labels with trap_ids as keys and a list
of cell labels as values.
A dict with trap_ids as keys and a list of cell labels as
values.
Returns
-------
......@@ -336,7 +336,7 @@ class Extractor(StepABC):
"""
if cell_labels is None:
self._log("No cell labels given. Sorting cells using index.")
cell_fun = True if cell_property in self.all_cell_funs else False
cell_fun = True if cell_function in self.all_cell_funs else False
idx = []
results = []
for trap_id, (mask_set, trap, local_cell_labels) in enumerate(
......@@ -345,7 +345,7 @@ class Extractor(StepABC):
# ignore empty traps
if len(mask_set):
# find property from the tile
result = self.all_funs[cell_property](mask_set, trap)
result = self.all_funs[cell_function](mask_set, trap)
if cell_fun:
# store results for each cell separately
for cell_label, val in zip(local_cell_labels, result):
......@@ -371,8 +371,8 @@ class Extractor(StepABC):
Data from one time point is used.
"""
d = {
cell_fun: self.extract_traps(
traps=tiles, masks=masks, cell_property=cell_fun, **kwargs
cell_fun: self.apply_cell_function(
traps=tiles, masks=masks, cell_function=cell_fun, **kwargs
)
for cell_fun in cell_funs
}
......@@ -513,88 +513,66 @@ class Extractor(StepABC):
return bgs
def extract_one_channel(
self, tree_dict, cell_labels, tiles, masks, bgs, **kwargs
self, tree_dict, cell_labels, img, img_bgsub, masks, **kwargs
):
"""
Extract all metrics requiring only a single channel.
Apply first without and then with background subtraction.
Return the extraction results and a dict of background
corrected images.
"""
"""Extract as dict all metrics requiring only a single channel."""
d = {}
for ch, reduction_cell_funs in tree_dict["tree"].items():
# NB ch != is necessary for threading
if ch != "general" and tiles is not None and len(tiles):
# image data for all traps for a particular channel and
# time point arranged as (traps, Z, X, Y)
# we use 0 here to access the single time point available
channel_tile = tiles[:, tree_dict["channels"].index(ch), 0]
else:
# no reduction applied to "general" - bright-field images
channel_tile = None
# apply metrics to image data
# extract from all images including bright field
d[ch] = self.reduce_extract(
tiles=channel_tile,
# use None for "general"; no fluorescence image
tiles=img.get(ch, None),
masks=masks,
reduction_cell_funs=reduction_cell_funs,
cell_labels=cell_labels,
**kwargs,
)
# apply metrics to image data with the background subtracted
if (
bgs.any()
and ch in self.params.sub_bg
and channel_tile is not None
):
# apply metrics to background-corrected data
if ch != "general":
# extract from background-corrected fluorescence images
d[ch + "_bgsub"] = self.reduce_extract(
tiles=self.img_bgsub[ch + "_bgsub"],
tiles=img_bgsub[ch + "_bgsub"],
masks=masks,
reduction_cell_funs=tree_dict["channels_tree"][ch],
reduction_cell_funs=reduction_cell_funs,
cell_labels=cell_labels,
**kwargs,
)
return d
def extract_multiple_channels(
self, tree_dict, cell_labels, tiles, masks, **kwargs
):
"""
Extract all metrics requiring multiple channels.
"""
# channels and background corrected channels
available_chs = set(self.img_bgsub.keys()).union(tree_dict["channels"])
def extract_multiple_channels(self, cell_labels, img, img_bgsub, masks):
"""Extract as a dict all metrics requiring multiple channels."""
# NB multichannel functions do not use tree_dict
available_channels = set(list(img.keys()) + list(img_bgsub.keys()))
d = {}
for name, (
chs,
reduction_fun,
op,
for multichannel_fun_name, (
channels,
reduction,
multichannel_function,
) in self.params.multichannel_ops.items():
common_chs = set(chs).intersection(available_chs)
common_channels = set(channels).intersection(available_channels)
# all required channels should be available
if len(common_chs) == len(chs):
channels_stack = np.stack(
[
self.get_imgs(ch, tiles, tree_dict["channels"])
for ch in chs
],
axis=-1,
)
# reduce in Z
traps = REDUCTION_FUNS[reduction_fun](channels_stack, axis=1)
# evaluate multichannel op
if name not in d:
d[name] = {}
if reduction_fun not in d[name]:
d[name][reduction_fun] = {}
d[name][reduction_fun][op] = self.extract_traps(
traps,
masks,
op,
cell_labels,
)
if len(common_channels) == len(channels):
for images, suffix in zip([img, img_bgsub], ["", "_bgsub"]):
# channels
channels_stack = np.stack(
[images[ch + suffix] for ch in channels],
axis=-1,
)
# reduce in Z
tiles = REDUCTION_FUNS[reduction](channels_stack, axis=1)
# set up dict
if multichannel_fun_name not in d:
d[multichannel_fun_name] = {}
if reduction not in d[multichannel_fun_name]:
d[multichannel_fun_name][reduction] = {}
# apply multichannel function
d[multichannel_fun_name][reduction][
multichannel_function + suffix
] = self.apply_cell_function(
tiles,
masks,
multichannel_function,
cell_labels,
)
return d
def extract_tp(
......@@ -656,33 +634,41 @@ class Extractor(StepABC):
tiles = self.get_tiles(tp, channels=tree_dict["channels"])
# generate boolean masks for background for each trap
bgs = self.get_background_masks(masks, tile_size)
# perform background subtraction for all traps at this time point
self.img_bgsub = self.perform_background_subtraction(
# get images and background corrected images as dicts
# with fluorescnce channels as keys
img, img_bgsub = self.get_imgs_background_subtract(
tree_dict, tiles, bgs
)
# perform extraction
res_one = self.extract_one_channel(
tree_dict, cell_labels, tiles, masks, bgs, **kwargs
tree_dict, cell_labels, img, img_bgsub, masks, **kwargs
)
res_multiple = self.extract_multiple_channels(
tree_dict, cell_labels, tiles, masks, **kwargs
cell_labels, img, img_bgsub, masks
)
res = {**res_one, **res_multiple}
return res
def perform_background_subtraction(self, tree_dict, tiles, bgs):
"""Subtract background for fluorescence channels."""
def get_imgs_background_subtract(self, tree_dict, tiles, bgs):
"""
Get two dicts of fluorescence images.
Return images and background subtracted image for all traps
for one time point.
"""
img = {}
img_bgsub = {}
for ch, _ in tree_dict["channels_tree"].items():
# NB ch != is necessary for threading
if tiles is not None and len(tiles):
# image data for all traps for a particular channel and
# time point arranged as (traps, Z, X, Y)
# we use 0 here to access the single time point available
channel_tile = tiles[:, tree_dict["channels"].index(ch), 0]
img[ch] = tiles[:, tree_dict["channels"].index(ch), 0]
if (
bgs.any()
and ch in self.params.sub_bg
and channel_tile is not None
and img[ch] is not None
):
# subtract median background
bgsub_mapping = map(
......@@ -690,7 +676,7 @@ class Extractor(StepABC):
lambda img, bgs: np.moveaxis(img, 0, -1)
# median of background over all pixels for each Z section
- bn.median(img[:, bgs], axis=1),
channel_tile,
img[ch],
bgs,
)
# apply map and convert to array
......@@ -699,9 +685,12 @@ class Extractor(StepABC):
img_bgsub[ch + "_bgsub"] = np.moveaxis(
mapping_result, -1, 1
)
return img_bgsub
else:
img[ch] = None
img_bgsub[ch] = None
return img, img_bgsub
def get_imgs(self, channel: t.Optional[str], tiles, channels=None):
def get_imgs_old(self, channel: t.Optional[str], tiles, channels=None):
"""
Return image from a correct source, either raw or bgsub.
......
......@@ -239,9 +239,12 @@ def moment_of_inertia(cell_mask, trap_image):
def ratio(cell_mask, trap_image):
"""Find the median ratio between two fluorescence channels."""
if trap_image.ndim == 3 and trap_image.shape[-1] == 2:
fl_1 = trap_image[..., 0][cell_mask]
fl_2 = trap_image[..., 1][cell_mask]
div = np.median(fl_1 / fl_2)
fl_0 = trap_image[..., 0][cell_mask]
fl_1 = trap_image[..., 1][cell_mask]
if np.any(fl_1 == 0):
div = np.nan
else:
div = np.median(fl_0 / fl_1)
else:
div = np.nan
return div
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment