From 43b390905538c514f425d9b6be4758d55158c954 Mon Sep 17 00:00:00 2001 From: Swainlab <peter.swain@ed.ac.uk> Date: Mon, 14 Aug 2023 16:26:46 +0100 Subject: [PATCH] partway through add multichannel op to extractor --- src/extraction/core/extractor.py | 44 +++++++++++---------- src/extraction/core/functions/cell.py | 11 ++++++ tests/aliby/network/extraction/test_base.py | 6 +-- 3 files changed, 37 insertions(+), 24 deletions(-) diff --git a/src/extraction/core/extractor.py b/src/extraction/core/extractor.py index 24cf7171..6d025188 100644 --- a/src/extraction/core/extractor.py +++ b/src/extraction/core/extractor.py @@ -140,7 +140,7 @@ class Extractor(StepABC): [c + "_bgsub" for c in self.params.sub_bg] ) # remove any multichannel operations requiring a missing channel - for op, (input_ch, _, _) in dict(self.params.multichannel_ops): + for op, (input_ch, _, _) in self.params.multichannel_ops.items(): if not set(input_ch).issubset(available_channels_bgsub): self.params.multichannel_ops.pop(op) self.load_funs() @@ -306,8 +306,6 @@ class Extractor(StepABC): cell_labels: dict A dict of cell labels with trap_ids as keys and a list of cell labels as values. - pos_info: bool - Whether to add the position as an index or not. Returns ------- @@ -557,19 +555,18 @@ class Extractor(StepABC): """ Extract using all metrics requiring multiple channels. """ + available_chs = set(self.img_bgsub.keys()).union( + tree_bits["tree_channels"] + ) d = {} for name, ( chs, - merge_fun, - red_metrics, + reduction_fun, + op, ) in self.params.multichannel_ops.items(): - if len( - set(chs).intersection( - set(self.img_bgsub.keys()).union( - tree_bits["tree_channels"] - ) - ) - ) == len(chs): + common_chs = set(chs).intersection(available_chs) + # all required channels should be available + if len(common_chs) == len(chs): channels_stack = np.stack( [ self.get_imgs(ch, tiles, tree_bits["tree_channels"]) @@ -577,13 +574,18 @@ class Extractor(StepABC): ], axis=-1, ) - merged = RED_FUNS[merge_fun](channels_stack, axis=-1) - d[name] = self.reduce_extract( - red_metrics=red_metrics, - traps=merged, - masks=masks, - cell_labels=cell_labels, - **kwargs, + # reduce in Z + traps = RED_FUNS[reduction_fun](channels_stack, axis=1) + # evaluate multichannel op + if name not in d: + d[name] = {} + if reduction_fun not in d[name]: + d[name][reduction_fun] = {} + d[name][reduction_fun][op] = self.extract_traps( + traps, + masks, + op, + cell_labels, ) return d @@ -652,10 +654,10 @@ class Extractor(StepABC): res_one, self.img_bgsub = self.extract_one_channel( tree_bits, cell_labels, tiles, masks, bgs, **kwargs ) - res_two = self.extract_multiple_channels( + res_multiple = self.extract_multiple_channels( tree_bits, cell_labels, tiles, masks, **kwargs ) - res = {**res_one, **res_two} + res = {**res_one, **res_multiple} return res def get_imgs(self, channel: t.Optional[str], tiles, channels=None): diff --git a/src/extraction/core/functions/cell.py b/src/extraction/core/functions/cell.py index 4d97f23a..f668c567 100644 --- a/src/extraction/core/functions/cell.py +++ b/src/extraction/core/functions/cell.py @@ -229,3 +229,14 @@ def moment_of_inertia(cell_mask, trap_image): return moi else: return np.nan + + +def ratio(cell_mask, trap_image): + """Find the median ratio between two fluorescence channels.""" + if trap_image.ndim == 3 and trap_image.shape[-1] == 2: + fl_1 = trap_image[..., 0][cell_mask] + fl_2 = trap_image[..., 1][cell_mask] + div = np.median(fl_1 / fl_2) + else: + div = np.nan + return div diff --git a/tests/aliby/network/extraction/test_base.py b/tests/aliby/network/extraction/test_base.py index 14209dbd..1ba8b452 100644 --- a/tests/aliby/network/extraction/test_base.py +++ b/tests/aliby/network/extraction/test_base.py @@ -56,9 +56,9 @@ def test_extractor(imgs, masks, tree): for ch_branches in extractor.params.tree.values(): print( extractor.reduce_extract( - red_metrics=ch_branches, - traps=[traps], - masks=[masks], + [traps], + [masks], + ch_branches, labels={0: labels}, ) ) -- GitLab