Skip to content
Snippets Groups Projects
Commit 43b39090 authored by pswain's avatar pswain
Browse files

partway through add multichannel op to extractor

parent 109b10f2
No related branches found
No related tags found
No related merge requests found
...@@ -140,7 +140,7 @@ class Extractor(StepABC): ...@@ -140,7 +140,7 @@ class Extractor(StepABC):
[c + "_bgsub" for c in self.params.sub_bg] [c + "_bgsub" for c in self.params.sub_bg]
) )
# remove any multichannel operations requiring a missing channel # remove any multichannel operations requiring a missing channel
for op, (input_ch, _, _) in dict(self.params.multichannel_ops): for op, (input_ch, _, _) in self.params.multichannel_ops.items():
if not set(input_ch).issubset(available_channels_bgsub): if not set(input_ch).issubset(available_channels_bgsub):
self.params.multichannel_ops.pop(op) self.params.multichannel_ops.pop(op)
self.load_funs() self.load_funs()
...@@ -306,8 +306,6 @@ class Extractor(StepABC): ...@@ -306,8 +306,6 @@ class Extractor(StepABC):
cell_labels: dict cell_labels: dict
A dict of cell labels with trap_ids as keys and a list A dict of cell labels with trap_ids as keys and a list
of cell labels as values. of cell labels as values.
pos_info: bool
Whether to add the position as an index or not.
Returns Returns
------- -------
...@@ -557,19 +555,18 @@ class Extractor(StepABC): ...@@ -557,19 +555,18 @@ class Extractor(StepABC):
""" """
Extract using all metrics requiring multiple channels. Extract using all metrics requiring multiple channels.
""" """
available_chs = set(self.img_bgsub.keys()).union(
tree_bits["tree_channels"]
)
d = {} d = {}
for name, ( for name, (
chs, chs,
merge_fun, reduction_fun,
red_metrics, op,
) in self.params.multichannel_ops.items(): ) in self.params.multichannel_ops.items():
if len( common_chs = set(chs).intersection(available_chs)
set(chs).intersection( # all required channels should be available
set(self.img_bgsub.keys()).union( if len(common_chs) == len(chs):
tree_bits["tree_channels"]
)
)
) == len(chs):
channels_stack = np.stack( channels_stack = np.stack(
[ [
self.get_imgs(ch, tiles, tree_bits["tree_channels"]) self.get_imgs(ch, tiles, tree_bits["tree_channels"])
...@@ -577,13 +574,18 @@ class Extractor(StepABC): ...@@ -577,13 +574,18 @@ class Extractor(StepABC):
], ],
axis=-1, axis=-1,
) )
merged = RED_FUNS[merge_fun](channels_stack, axis=-1) # reduce in Z
d[name] = self.reduce_extract( traps = RED_FUNS[reduction_fun](channels_stack, axis=1)
red_metrics=red_metrics, # evaluate multichannel op
traps=merged, if name not in d:
masks=masks, d[name] = {}
cell_labels=cell_labels, if reduction_fun not in d[name]:
**kwargs, d[name][reduction_fun] = {}
d[name][reduction_fun][op] = self.extract_traps(
traps,
masks,
op,
cell_labels,
) )
return d return d
...@@ -652,10 +654,10 @@ class Extractor(StepABC): ...@@ -652,10 +654,10 @@ class Extractor(StepABC):
res_one, self.img_bgsub = self.extract_one_channel( res_one, self.img_bgsub = self.extract_one_channel(
tree_bits, cell_labels, tiles, masks, bgs, **kwargs tree_bits, cell_labels, tiles, masks, bgs, **kwargs
) )
res_two = self.extract_multiple_channels( res_multiple = self.extract_multiple_channels(
tree_bits, cell_labels, tiles, masks, **kwargs tree_bits, cell_labels, tiles, masks, **kwargs
) )
res = {**res_one, **res_two} res = {**res_one, **res_multiple}
return res return res
def get_imgs(self, channel: t.Optional[str], tiles, channels=None): def get_imgs(self, channel: t.Optional[str], tiles, channels=None):
......
...@@ -229,3 +229,14 @@ def moment_of_inertia(cell_mask, trap_image): ...@@ -229,3 +229,14 @@ def moment_of_inertia(cell_mask, trap_image):
return moi return moi
else: else:
return np.nan return np.nan
def ratio(cell_mask, trap_image):
"""Find the median ratio between two fluorescence channels."""
if trap_image.ndim == 3 and trap_image.shape[-1] == 2:
fl_1 = trap_image[..., 0][cell_mask]
fl_2 = trap_image[..., 1][cell_mask]
div = np.median(fl_1 / fl_2)
else:
div = np.nan
return div
...@@ -56,9 +56,9 @@ def test_extractor(imgs, masks, tree): ...@@ -56,9 +56,9 @@ def test_extractor(imgs, masks, tree):
for ch_branches in extractor.params.tree.values(): for ch_branches in extractor.params.tree.values():
print( print(
extractor.reduce_extract( extractor.reduce_extract(
red_metrics=ch_branches, [traps],
traps=[traps], [masks],
masks=[masks], ch_branches,
labels={0: labels}, labels={0: labels},
) )
) )
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment