From e55aed3dc4af23e4097738ae8b12fcd485260606 Mon Sep 17 00:00:00 2001
From: Peter Swain <peter.swain@ed.ac.uk>
Date: Sat, 3 Sep 2022 18:16:51 +0100
Subject: [PATCH] more extractor

---
 extraction/core/extractor.py | 271 +++++++++++++++++++++--------------
 1 file changed, 164 insertions(+), 107 deletions(-)

diff --git a/extraction/core/extractor.py b/extraction/core/extractor.py
index 773da141..c2406a16 100644
--- a/extraction/core/extractor.py
+++ b/extraction/core/extractor.py
@@ -166,6 +167,7 @@ class Extractor(ProcessABC):
         return self._channels
 
     @property
+    # Alan: does this work. local is not a string.
    def current_position(self):
        return self.local.split("/")[-1][:-3]
 
@@ -270,52 +271,53 @@ class Extractor(ProcessABC):
        traps: List[np.array],
        masks: List[np.array],
        metric: str,
-        labels: List[int] = None,
+        labels: Dict = None,
    ) -> dict:
        """
-        Apply a function for a whole position.
+        Apply a function to a whole position.
 
        Parameters
        ----------
-        traps: List[np.array]
-            List of images
-        masks: List[np.array]
-            List of masks
+        traps: list of arrays
+            List of images.
+        masks: list of arrays
+            List of masks.
        metric: str
-            Metric to extract
-        labels: List[int]
-            Cell labels to use as indices in output dataFrame
+            Metric to extract.
+        labels: dict
+            A dict of cell labels with trap_ids as keys and a list of cell labels as values.
        pos_info: bool
-            Whether to add the position as an index or not
+            Whether to add the position as an index or not.
 
        Returns
        -------
-        d: dict
-            A dictionary of dataframes
+        res_idx: a tuple of tuples
+            A two-tuple of a tuple of results and a tuple with the corresponding trap_id and cell labels
 
        """
-        if labels is None:
+        # Alan: it looks like this will crash if Labels is None
            raise Warning("No labels given. Sorting cells using index.")
-
        cell_fun = True if metric in self._all_cell_funs else False
-
        idx = []
        results = []
-
        for trap_id, (mask_set, trap, lbl_set) in enumerate(
            zip(masks, traps, labels.values())
        ):
-            if len(mask_set):  # ignore empty traps
+            # ignore empty traps
+            if len(mask_set):
+                # apply metric either a cell function or otherwise
                result = self._all_funs[metric](mask_set, trap)
                if cell_fun:
+                    # store results for each cell separately
                    for lbl, val in zip(lbl_set, result):
                        results.append(val)
                        idx.append((trap_id, lbl))
                else:
+                    # background (trap) function
                    results.append(result)
                    idx.append(trap_id)
-
-        return (tuple(results), tuple(idx))
+        res_idx = (tuple(results), tuple(idx))
+        return res_idx
 
    def extract_funs(
        self,
@@ -406,7 +408,7 @@ class Extractor(ProcessABC):
        masks=None,
        labels=None,
        **kwargs,
-    ) -> t.Dict[str, t.Dict[str, pd.Series]]:
+    ) -> t.Dict[str, t.Dict[str, t.Dict[str, tuple]]]:
        """
        Core extraction method for an individual time-point.
 
@@ -420,17 +422,25 @@ class Extractor(ProcessABC):
            For example: {'general': {'None': ['area', 'volume', 'eccentricity']}}
        tile_size : int
            Size of the tile to be extracted.
-        masks : np.ndarray
-            A 3d boolean numpy array with dimensions (ncells, tile_size,
+        masks : list of arrays
+            A list of masks per trap with each mask having dimensions (ncells, tile_size,
            tile_size).
-        labels : t.List[t.List[int]]
-            List of lists of ints indicating the ids of masks.
+        labels : dict
+            A dictionary with trap_ids as keys and cell_labels as values.
        **kwargs : keyword arguments
            Passed to extractor.reduce_extract.
 
        Returns
        -------
-        dict
+        d: dict
+            Dictionary of the results with three levels of dictionaries.
+            The first level has channels as keys.
+            The second level has reduction metrics as keys.
+            The third level has cell or background metrics as keys and a two-tuple as values.
+            The first tuple is the result of applying the metrics to a particular cell or trap; the second tuple is either (trap_id, cell_label) for a metric applied to a cell or a trap_id for a metric applied to a trap.
+
+            An example is d["GFP"]["np_max"]["mean"][0], which gives a tuple of the calculated mean GFP fluorescence for all cells.
+
        """
        if tree is None:
            # use default
@@ -465,26 +475,27 @@ class Extractor(ProcessABC):
        # Alan: traps does not appear the best name here!
        traps = self.get_traps(tp, tile_shape=tile_size, channels=tree_chs)
 
-        self.img_bgsub = {}
+        # generate boolean masks for background as a list with one mask per trap
        if self.params.sub_bg:
-            # generate boolean masks for background as a list with one mask per trap
-            bg = [
+            bgs = [
                ~np.sum(m, axis=2).astype(bool)
                if np.any(m)
                else np.zeros((tile_size, tile_size))
                for m in masks
            ]
 
+        # perform extraction by applying metrics
        d = {}
+        self.img_bgsub = {}
        for ch, red_metrics in tree.items():
-            # image data for all traps and z sections for a particular channel
-            # as an array arranged as (traps, X, Y, Z)
            # NB ch != is necessary for threading
            if ch != "general" and traps is not None and len(traps):
+                # image data for all traps and z sections for a particular channel
+                # as an array arranged as (no traps, X, Y, no Z channels)
                img = traps[:, tree_chs.index(ch), 0]
            else:
                img = None
-
+            # apply metrics to image data
            d[ch] = self.reduce_extract(
                red_metrics=red_metrics,
                traps=img,
@@ -492,23 +503,21 @@ class Extractor(ProcessABC):
                labels=labels,
                **kwargs,
            )
-
-            if (
-                ch in self.params.sub_bg and img is not None
-            ):  # Calculate metrics with subtracted bg
+            # apply metrics to image data with the background subtracted
+            if ch in self.params.sub_bg and img is not None:
+                # calculate metrics with subtracted bg
                ch_bs = ch + "_bgsub"
-
                self.img_bgsub[ch_bs] = []
-                for trap, maskset in zip(img, bg):
-
+                for trap, bg in zip(img, bgs):
                    cells_fl = np.zeros_like(trap)
-
-                    is_cell = np.where(maskset)
-                    if len(is_cell[0]):  # skip calculation for empty traps
+                    # Alan: should this not be is_not_cell?
+                    is_cell = np.where(bg)
+                    # skip for empty traps
+                    if len(is_cell[0]):
                        cells_fl = np.median(trap[is_cell], axis=0)
-
+                    # subtract median background
                    self.img_bgsub[ch_bs].append(trap - cells_fl)
-
+                # apply metrics to background-corrected data
                d[ch_bs] = self.reduce_extract(
                    red_metrics=ch_tree[ch],
                    traps=self.img_bgsub[ch_bs],
@@ -517,7 +526,7 @@ class Extractor(ProcessABC):
                    **kwargs,
                )
 
-        # Additional operations between multiple channels (e.g. pH calculations)
+        # apply any metrics that use multiple channels (eg pH calculations)
        for name, (
            chs,
            merge_fun,
@@ -544,14 +553,22 @@ class Extractor(ProcessABC):
        """
        Returns the image from a correct source, either raw or bgsub
 
-        :channel: str name of channel to get
-        :img: ndarray (trap_id, channel, tp, tile_size, tile_size, n_zstacks) of standard channels
-        :channels: List of channels
-        """
+        Parameters
+        ----------
+        channel: str
+            Name of channel to get.
+        traps: ndarray
+            An array of the image data having dimensions of (trap_id, channel, tp, tile_size, tile_size, n_zstacks).
+        channels: list of str (optional)
+            List of available channels.
+
+        Returns
+        -------
+        img: ndarray
+            An array of image data with dimensions (no traps, X, Y, no Z channels)
+        """
        if channels is None:
            channels = (*self.params.tree,)
-
        if channel in channels:
            return traps[:, channels.index(channel), 0]
        elif channel in self.img_bgsub:
@@ -559,32 +576,53 @@ class Extractor(ProcessABC):
 
    def run_tp(self, tp, **kwargs):
        """
-        Wrapper to add compatiblibility with other pipeline steps
+        Wrapper to add compatibility with other steps of the pipeline.
        """
        return self.run(tps=[tp], **kwargs)
 
    def run(
-        self, tree=None, tps: List[int] = None, save=True, **kwargs
+        self,
+        tree=None,
+        tps: List[int] = None,
+        save=True,
+        **kwargs,
    ) -> dict:
+        """
+        Parameters
+        ----------
+        tree: dict
+            Nested dictionary indicating channels, reduction functions and
+            metrics to be used.
+            For example: {'general': {'None': ['area', 'volume', 'eccentricity']}}
+        tps: list of int (optional)
+            Time points to include.
+        save: boolean (optional)
+            If True, save results to h5 file.
+        kwargs: keyword arguments (optional)
+            Passed to extract_tp.
+
+        Returns
+        -------
+        d: dict
+            A dict of the extracted data with a concatenated string of channel, reduction metric, and cell metric as keys and pd.Series of the extracted data as values.
+        """
        if tree is None:
            tree = self.params.tree
-
        if tps is None:
            tps = list(range(self.meta["time_settings/ntimepoints"][0]))
-
+        # store results in dict
        d = {}
        for tp in tps:
-            new = flatten_nest(
+            # extract for each time point and convert to dict of pd.Series
+            new = flatten_nesteddict(
                self.extract_tp(tp=tp, tree=tree, **kwargs),
                to="series",
                tp=tp,
            )
-
+            # concatenate with data extracted from early time points
            for k in new.keys():
-                n = new[k]
-                d[k] = pd.concat((d.get(k, None), n), axis=1)
-
+                d[k] = pd.concat((d.get(k, None), new[k]), axis=1)
+        # add indices to pd.Series containing the extracted data
        for k in d.keys():
            indices = ["experiment", "position", "trap", "cell_label"]
            idx = (
@@ -593,65 +631,73 @@ class Extractor(ProcessABC):
                else [indices[-2]]
            )
            d[k].index.names = idx
-
-        toreturn = d
-
+        # save
        if save:
-            self.save_to_hdf(toreturn)
+            self.save_to_hdf(d)
+        return d
 
-        return toreturn
+    # Alan: isn't this identical to run?
+    # def extract_pos(
+    #     self, tree=None, tps: List[int] = None, save=True, **kwargs
+    # ) -> dict:
 
-    def extract_pos(
-        self, tree=None, tps: List[int] = None, save=True, **kwargs
-    ) -> dict:
+    #     if tree is None:
+    #         tree = self.params.tree
 
-        if tree is None:
-            tree = self.params.tree
+    #     if tps is None:
+    #         tps = list(range(self.meta["time_settings/ntimepoints"]))
 
-        if tps is None:
-            tps = list(range(self.meta["time_settings/ntimepoints"]))
+    #     d = {}
+    #     for tp in tps:
+    #         new = flatten_nest(
+    #             self.extract_tp(tp=tp, tree=tree, **kwargs),
+    #             to="series",
+    #             tp=tp,
+    #         )
 
-        d = {}
-        for tp in tps:
-            new = flatten_nest(
-                self.extract_tp(tp=tp, tree=tree, **kwargs),
-                to="series",
-                tp=tp,
-            )
+    #         for k in new.keys():
+    #             n = new[k]
+    #             d[k] = pd.concat((d.get(k, None), n), axis=1)
 
-            for k in new.keys():
-                n = new[k]
-                d[k] = pd.concat((d.get(k, None), n), axis=1)
+    #         for k in d.keys():
+    #             indices = ["experiment", "position", "trap", "cell_label"]
+    #             idx = (
+    #                 indices[-d[k].index.nlevels :]
+    #                 if d[k].index.nlevels > 1
+    #                 else [indices[-2]]
+    #             )
+    #             d[k].index.names = idx
 
-            for k in d.keys():
-                indices = ["experiment", "position", "trap", "cell_label"]
-                idx = (
-                    indices[-d[k].index.nlevels :]
-                    if d[k].index.nlevels > 1
-                    else [indices[-2]]
-                )
-                d[k].index.names = idx
+    #     toreturn = d
 
-        toreturn = d
+    #     if save:
+    #         self.save_to_hdf(toreturn)
 
-        if save:
-            self.save_to_hdf(toreturn)
+    #     return toreturn
 
-        return toreturn
+    def save_to_hdf(self, dict_series, path=None):
+        """
+        Save the extracted data to the h5 file.
 
-    def save_to_hdf(self, group_df, path=None):
+        Parameters
+        ----------
+        dict_series: dict
+            A dictionary of the extracted data, created by run.
+        path: Path (optional)
+            To the h5 file.
+        """
        if path is None:
            path = self.local
-
        self.writer = Writer(path)
-        for path, df in group_df.items():
-            dset_path = "/extraction/" + path
-            self.writer.write(dset_path, df)
+        for extract_name, series in dict_series.items():
+            dset_path = "/extraction/" + extract_name
+            self.writer.write(dset_path, series)
        self.writer.id_cache.clear()
 
    def get_meta(self, flds):
+        # Alan: unsure what this is doing. seems to break for "nuc_conv_3d"
+        # make flds a list
        if not hasattr(flds, "__iter__"):
-            # make flds a list
            flds = [flds]
        meta_short = {k.split("/")[-1]: v for k, v in self.meta.items()}
        return {
@@ -660,14 +706,24 @@
 
 ### Helpers
 
-def flatten_nest(nest: dict, to="series", tp: int = None) -> dict:
-    """
-    Convert a nested extraction dict into a dict of series
-    :param nest: dict contained the nested results of extraction
-    :param to: str = 'series' Determine output format, either list or pd.Series
-    :param tp: int timepoint used to name the series
-    """
+def flatten_nesteddict(nest: dict, to="series", tp: int = None) -> dict:
+    """
+    Converts a nested extraction dict into a dict of pd.Series
+
+    Parameters
+    ----------
+    nest: dict of dicts
+        Contains the nested results of extraction.
+    to: str (optional)
+        Specifies the format of the output, either pd.Series (default) or a list
+    tp: int
+        Timepoint used to name the pd.Series
+
+    Returns
+    -------
+    d: dict
+        A dict with a concatenated string of channel, reduction metric, and cell metric as keys and either a pd.Series or a list of the corresponding extracted data as values.
+    """
    d = {}
    for k0, v0 in nest.items():
        for k1, v1 in v0.items():
@@ -675,10 +731,10 @@ def flatten_nest(nest: dict, to="series", tp: int = None) -> dict:
            d["/".join((k0, k1, k2))] = (
                pd.Series(*v2, name=tp) if to == "series" else v2
            )
-
    return d
 
 
+# Alan: this no longer seems to be used
def fill_tree(tree):
    if tree is None:
        return None
@@ -693,8 +749,9 @@
 
 
class hollowExtractor(Extractor):
-    """Extractor that only cares about receiving image and masks,
-    used for testing.
+    """
+    Extractor that only cares about receiving images and masks.
+    Used for testing.
    """
 
    def __init__(self, parameters):
-- 
GitLab
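
Editor's note: the sketch below is a minimal, self-contained illustration (pandas only) of the data shapes documented in this patch. The nested time-point result and the numeric values are invented for illustration; the "channel/reduction/metric" key format, the pd.Series construction, and the index naming follow the docstrings of extract_tp, flatten_nesteddict, and run as they appear above, not the library itself.

import pandas as pd

# extract_tp returns channel -> reduction -> metric -> (results, idx), where
# idx pairs each result with its (trap_id, cell_label); values here are toy data
tp_result = {
    "GFP": {
        "np_max": {
            "mean": ((120.5, 98.2, 101.7), ((0, 1), (0, 2), (3, 1))),
        },
    },
}

# flatten_nesteddict builds one pd.Series per "channel/reduction/metric" key,
# named after the time point
tp = 0
flat = {
    "/".join((k0, k1, k2)): pd.Series(*v2, name=tp)
    for k0, v0 in tp_result.items()
    for k1, v1 in v0.items()
    for k2, v2 in v1.items()
}

# run concatenates successive time points column-wise (pd.concat silently drops
# the initial None) and names the index levels from its indices list
d = {}
for k, series in flat.items():
    d[k] = pd.concat((d.get(k, None), series), axis=1)
    d[k].index.names = ["trap", "cell_label"]

# each value is a DataFrame: rows indexed by (trap, cell_label), one column per tp
print(d["GFP/np_max/mean"])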