Compare revisions

Changes are shown as if the source revision was being merged into the target revision.

Target projects: swain-lab/aliby/aliby-mirror, swain-lab/aliby/alibylite

Showing 1107 additions and 878 deletions
......@@ -7,7 +7,7 @@ import pandas as pd
from agora.abc import ParametersABC, StepABC
from agora.io.cells import Cells
from agora.io.writer import Writer, load_attributes
from agora.io.writer import Writer, load_meta
from aliby.tile.tiler import Tiler
from extraction.core.functions.defaults import exparams_from_meta
from extraction.core.functions.distributors import reduce_z, trap_apply
......@@ -16,6 +16,7 @@ from extraction.core.functions.loaders import (
load_funs,
load_redfuns,
)
import aliby.global_parameters as global_parameters
# define types
reduction_method = t.Union[t.Callable, str, None]
......@@ -26,14 +27,14 @@ extraction_result = t.Dict[
str, t.Dict[reduction_method, t.Dict[str, t.Dict[str, pd.Series]]]
]
# Global variables used to load functions that either analyse cells or their background. These global variables both allow the functions to be stored in a dictionary for access only on demand and to be defined simply in extraction/core/functions.
CELL_FUNS, TRAPFUNS, FUNS = load_funs()
# Global variables used to load functions that either analyse cells
# or their background. These global variables both allow the functions
# to be stored in a dictionary for access only on demand and to be
# defined simply in extraction/core/functions.
CELL_FUNS, TRAP_FUNS, ALL_FUNS = load_funs()
CUSTOM_FUNS, CUSTOM_ARGS = load_custom_args()
RED_FUNS = load_redfuns()
# Assign datatype depending on the metric used
# m2type = {"mean": np.float32, "median": np.ubyte, "imBackground": np.ubyte}
class ExtractorParameters(ParametersABC):
"""Base class to define parameters for extraction."""
......@@ -74,21 +75,22 @@ class Extractor(StepABC):
"""
Apply a metric to cells identified in the tiles.
Using the cell masks, the Extractor applies a metric, such as area or median, to cells identified in the image tiles.
Using the cell masks, the Extractor applies a metric, such as
area or median, to cells identified in the image tiles.
Its methods require both tile images and masks.
Usually the metric is applied to only a tile's masked area, but some metrics depend on the whole tile.
Usually the metric is applied to only a tile's masked area, but
some metrics depend on the whole tile.
Extraction follows a three-level tree structure. Channels, such as GFP, are the root level; the reduction algorithm, such as maximum projection, is the second level; the specific metric, or operation, to apply to the masks, such as mean, is the third level.
Extraction follows a three-level tree structure. Channels, such
as GFP, are the root level; the reduction algorithm, such as
maximum projection, is the second level; the specific metric,
or operation, to apply to the masks, such as mean, is the third
or leaf level.
"""
# TODO Alan: Move this to a location with the SwainLab defaults
default_meta = {
"pixel_size": 0.236,
"z_size": 0.6,
"spacing": 0.6,
}
default_meta = global_parameters.imaging_specifications
def __init__(
self,
......@@ -107,17 +109,36 @@ class Extractor(StepABC):
store: str
Path to the h5 file containing the cell masks.
tiler: pipeline-core.core.segmentation tiler
Class that contains or fetches the images used for segmentation.
Class that contains or fetches the images used for
segmentation.
"""
self.params = parameters
if store:
self.local = store
self.load_meta()
self.meta = load_meta(self.local)
else:
# if no h5 file, use the parameters directly
self.meta = {"channel": parameters.to_dict()["tree"].keys()}
if tiler:
self.tiler = tiler
available_channels = set((*tiler.channels, "general"))
# only extract for channels available
self.params.tree = {
k: v
for k, v in self.params.tree.items()
if k in available_channels
}
self.params.sub_bg = available_channels.intersection(
self.params.sub_bg
)
# add background subtracted channels to those available
available_channels_bgsub = available_channels.union(
[c + "_bgsub" for c in self.params.sub_bg]
)
# remove any multichannel operations requiring a missing channel;
# iterate over a copy so entries can be popped safely
for op, (input_ch, _, _) in list(self.params.multichannel_ops.items()):
if not set(input_ch).issubset(available_channels_bgsub):
self.params.multichannel_ops.pop(op)
self.load_funs()
@classmethod
......@@ -161,13 +182,16 @@ class Extractor(StepABC):
def load_custom_funs(self):
"""
Incorporate the extra arguments of custom functions into their definitions.
Incorporate the extra arguments of custom functions into their
definitions.
Normal functions only have cell_masks and trap_image as their
arguments, and here custom functions are made the same by
setting the values of their extra arguments.
Any other parameters are taken from the experiment's metadata and automatically applied. These parameters therefore must be loaded within an Extractor instance.
Any other parameters are taken from the experiment's metadata
and automatically applied. These parameters therefore must be
loaded within an Extractor instance.
"""
# find functions specified in params.tree
funs = set(
......@@ -202,27 +226,23 @@ class Extractor(StepABC):
self._custom_funs[k] = tmp(f)
def load_funs(self):
"""Define all functions, including custum ones."""
"""Define all functions, including custom ones."""
self.load_custom_funs()
self._all_cell_funs = set(self._custom_funs.keys()).union(CELL_FUNS)
# merge the two dicts
self._all_funs = {**self._custom_funs, **FUNS}
def load_meta(self):
"""Load metadata from h5 file."""
self.meta = load_attributes(self.local)
self._all_funs = {**self._custom_funs, **ALL_FUNS}
def get_tiles(
self,
tp: int,
channels: t.Optional[t.List[t.Union[str, int]]] = None,
z: t.Optional[t.List[str]] = None,
**kwargs,
) -> t.Optional[np.ndarray]:
"""
Find tiles for a given time point, channels, and z-stacks.
Any additional keyword arguments are passed to tiler.get_tiles_timepoint
Any additional keyword arguments are passed to
tiler.get_tiles_timepoint
Parameters
----------
......@@ -243,24 +263,23 @@ class Extractor(StepABC):
# a list of the indices of the z stacks
channel_ids = None
if z is None:
# gets the tiles data via tiler
# include all Z channels
z = list(range(self.tiler.shape[-3]))
# get the image data via tiler
res = (
self.tiler.get_tiles_timepoint(
tp, channels=channel_ids, z=z, **kwargs
)
self.tiler.get_tiles_timepoint(tp, channels=channel_ids, z=z)
if channel_ids
else None
)
# data arranged as (tiles, channels, time points, X, Y, Z)
# res has dimensions (tiles, channels, 1, Z, X, Y)
return res
def extract_traps(
self,
traps: t.List[np.ndarray],
masks: t.List[np.ndarray],
metric: str,
labels: t.Dict[int, t.List[int]],
cell_property: str,
cell_labels: t.Dict[int, t.List[int]],
) -> t.Tuple[t.Union[t.Tuple[float], t.Tuple[t.Tuple[int]]]]:
"""
Apply a function to a whole position.
......@@ -271,35 +290,35 @@ class Extractor(StepABC):
t.List of images.
masks: list of arrays
t.List of masks.
metric: str
Metric to extract.
labels: dict
A dict of cell labels with trap_ids as keys and a list of cell labels as values.
pos_info: bool
Whether to add the position as an index or not.
cell_property: str
Property to extract, including imBackground.
cell_labels: dict
A dict of cell labels with trap_ids as keys and a list
of cell labels as values.
Returns
-------
res_idx: a tuple of tuples
A two-tuple comprising a tuple of results and a tuple of the tile_id and cell labels
A two-tuple comprising a tuple of results and a tuple of
the tile_id and cell labels
"""
if labels is None:
self._log("No labels given. Sorting cells using index.")
cell_fun = True if metric in self._all_cell_funs else False
if cell_labels is None:
self._log("No cell labels given. Sorting cells using index.")
cell_fun = True if cell_property in self._all_cell_funs else False
idx = []
results = []
for trap_id, (mask_set, trap, lbl_set) in enumerate(
zip(masks, traps, labels.values())
for trap_id, (mask_set, trap, local_cell_labels) in enumerate(
zip(masks, traps, cell_labels.values())
):
# ignore empty traps
if len(mask_set):
# apply metric either a cell function or otherwise
result = self._all_funs[metric](mask_set, trap)
# find property from the tile
result = self._all_funs[cell_property](mask_set, trap)
if cell_fun:
# store results for each cell separately
for lbl, val in zip(lbl_set, result):
for cell_label, val in zip(local_cell_labels, result):
results.append(val)
idx.append((trap_id, lbl))
idx.append((trap_id, cell_label))
else:
# background (trap) function
results.append(result)
......@@ -311,19 +330,19 @@ class Extractor(StepABC):
self,
traps: t.List[np.array],
masks: t.List[np.array],
metrics: t.List[str],
cell_properties: t.List[str],
**kwargs,
) -> t.Dict[str, pd.Series]:
"""
Return dict with metrics as key and metrics applied to data as values.
Return a dict with cell properties as keys and their extracted values.
Data from one time point is used.
"""
d = {
metric: self.extract_traps(
traps=traps, masks=masks, metric=metric, **kwargs
cell_property: self.extract_traps(
traps=traps, masks=masks, cell_property=cell_property, **kwargs
)
for metric in metrics
for cell_property in cell_properties
}
return d
......@@ -331,11 +350,11 @@ class Extractor(StepABC):
self,
traps: np.ndarray,
masks: t.List[np.ndarray],
red_metrics: t.Dict[reduction_method, t.Collection[str]],
tree_branch: t.Dict[reduction_method, t.Collection[str]],
**kwargs,
) -> t.Dict[str, t.Dict[reduction_method, t.Dict[str, pd.Series]]]:
"""
Wrapper to apply reduction and then extraction.
Wrapper to reduce to a 2D image and then extract.
Parameters
----------
......@@ -343,8 +362,10 @@ class Extractor(StepABC):
An array of image data arranged as (tiles, X, Y, Z)
masks: list of arrays
An array of masks for each trap: one per cell at the trap
red_metrics: dict
dict for which keys are reduction functions and values are either a list or a set of strings giving the metric functions.
tree_branch: dict
An upper branch of the extraction tree: a dict for which
keys are reduction functions and values are either a list
or a set of strings giving the cell properties to be found.
For example: {'np_max': {'max5px', 'mean', 'median'}}
**kwargs: dict
All other arguments passed to Extractor.extract_funs.
......@@ -353,22 +374,27 @@ class Extractor(StepABC):
------
Dict of dataframes with the corresponding reductions and metrics nested.
"""
# create dict with keys naming the reduction in the z-direction and the reduced data as values
reduced_tiles_data = {}
# FIXME hack to pass tests
if "labels" in kwargs:
kwargs["cell_labels"] = kwargs.pop("labels")
# create dict with keys naming the reduction in the z-direction
# and the reduced data as values
reduced_tiles = {}
if traps is not None:
for red_fun in red_metrics.keys():
reduced_tiles_data[red_fun] = [
for red_fun in tree_branch.keys():
reduced_tiles[red_fun] = [
self.reduce_dims(tile_data, method=RED_FUNS[red_fun])
for tile_data in traps
]
# calculate cell and tile properties
d = {
red_fun: self.extract_funs(
metrics=metrics,
traps=reduced_tiles_data.get(red_fun, [None for _ in masks]),
cell_properties=cell_properties,
traps=reduced_tiles.get(red_fun, [None for _ in masks]),
masks=masks,
**kwargs,
)
for red_fun, metrics in red_metrics.items()
for red_fun, cell_properties in tree_branch.items()
}
return d
......@@ -392,64 +418,24 @@ class Extractor(StepABC):
reduced = reduce_z(img, method)
return reduced
def extract_tp(
self,
tp: int,
tree: t.Optional[extraction_tree] = None,
tile_size: int = 117,
masks: t.Optional[t.List[np.ndarray]] = None,
labels: t.Optional[t.List[int]] = None,
**kwargs,
) -> t.Dict[str, t.Dict[str, t.Dict[str, tuple]]]:
"""
Extract for an individual time point.
Parameters
----------
tp : int
Time point being analysed.
tree : dict
Nested dictionary indicating channels, reduction functions and
metrics to be used.
For example: {'general': {'None': ['area', 'volume', 'eccentricity']}}
tile_size : int
Size of the tile to be extracted.
masks : list of arrays
A list of masks per trap with each mask having dimensions (ncells, tile_size,
tile_size).
labels : dict
A dictionary with trap_ids as keys and cell_labels as values.
**kwargs : keyword arguments
Passed to extractor.reduce_extract.
Returns
-------
d: dict
Dictionary of the results with three levels of dictionaries.
The first level has channels as keys.
The second level has reduction metrics as keys.
The third level has cell or background metrics as keys and a two-tuple as values.
The first tuple is the result of applying the metrics to a particular cell or trap; the second tuple is either (trap_id, cell_label) for a metric applied to a cell or a trap_id for a metric applied to a trap.
An example is d["GFP"]["np_max"]["mean"][0], which gives a tuple of the calculated mean GFP fluorescence for all cells.
"""
# TODO Can we split the different extraction types into sub-methods to make this easier to read?
def make_tree_bits(self, tree):
"""Put extraction tree and information for the channels into a dict."""
if tree is None:
# use default
tree: extraction_tree = self.params.tree
# dictionary with channel: {reduction algorithm : metric}
ch_tree = {ch: v for ch, v in tree.items() if ch != "general"}
# tuple of the channels
tree_chs = (*ch_tree,)
# create a Cells object to extract information from the h5 file
cells = Cells(self.local)
# find the cell labels and store as dict with trap_ids as keys
if labels is None:
raw_labels = cells.labels_at_time(tp)
labels = {
trap_id: raw_labels.get(trap_id, [])
for trap_id in range(cells.ntraps)
}
tree_bits = {
"tree": tree,
# dictionary with channel: {reduction algorithm : metric}
"channel_tree": {
ch: v for ch, v in tree.items() if ch != "general"
},
}
# tuple of the fluorescence channels
tree_bits["tree_channels"] = (*tree_bits["channel_tree"],)
return tree_bits
def get_masks(self, tp, masks, cells):
"""Get the masks as a list with an array of masks for each trap."""
# find the cell masks for a given trap as a dict with trap_ids as keys
if masks is None:
raw_masks = cells.at_time(tp, kind="mask")
......@@ -458,16 +444,31 @@ class Extractor(StepABC):
if len(cells):
masks[trap_id] = np.stack(np.array(cells)).astype(bool)
# convert to a list of masks
# one array of size (no cells, tile_size, tile_size) per trap
masks = [np.array(v) for v in masks.values()]
# find image data at the time point
# stored as an array arranged as (traps, channels, time points, X, Y, Z)
tiles = self.get_tiles(tp, tile_shape=tile_size, channels=tree_chs)
# generate boolean masks for background as a list with one mask per trap
bgs = np.array([])
return masks
def get_cell_labels(self, tp, cell_labels, cells):
"""Get the cell labels per trap as a dict with trap_ids as keys."""
if cell_labels is None:
raw_cell_labels = cells.labels_at_time(tp)
cell_labels = {
trap_id: raw_cell_labels.get(trap_id, [])
for trap_id in range(cells.ntraps)
}
return cell_labels
def get_background_masks(self, masks, tile_size):
"""
Generate boolean background masks.
Combine masks per trap and then take the logical inverse.
"""
if self.params.sub_bg:
bgs = ~np.array(
list(
map(
# sum over masks for each cell
lambda x: np.sum(x, axis=0)
if np.any(x)
else np.zeros((tile_size, tile_size)),
......@@ -475,77 +476,177 @@ class Extractor(StepABC):
)
)
).astype(bool)
# perform extraction by applying metrics
else:
bgs = np.array([])
return bgs
def extract_one_channel(
self, tree_bits, cell_labels, tiles, masks, bgs, **kwargs
):
"""
Extract using all metrics requiring a single channel.
Apply first without and then with background subtraction.
Return the extraction results and a dict of background
corrected images.
"""
d = {}
self.img_bgsub = {}
for ch, red_metrics in tree.items():
img_bgsub = {}
for ch, tree_branch in tree_bits["tree"].items():
# NB the ch != "general" check is necessary for threading
if ch != "general" and tiles is not None and len(tiles):
# image data for all traps and z sections for a particular channel
# as an array arranged as (tiles, Z, X, Y, )
img = tiles[:, tree_chs.index(ch), 0]
# image data for all traps for a particular channel and time point
# arranged as (traps, Z, X, Y)
# we use 0 here to access the single time point available
img = tiles[:, tree_bits["tree_channels"].index(ch), 0]
else:
# no reduction applied to bright-field images
img = None
# apply metrics to image data
d[ch] = self.reduce_extract(
traps=img,
masks=masks,
red_metrics=red_metrics,
labels=labels,
tree_branch=tree_branch,
cell_labels=cell_labels,
**kwargs,
)
# apply metrics to image data with the background subtracted
if bgs.any() and ch in self.params.sub_bg and img is not None:
# calculate metrics with subtracted bg
# calculate metrics with background subtracted
ch_bs = ch + "_bgsub"
# subtract median background
self.img_bgsub[ch_bs] = np.moveaxis(
np.stack(
list(
map(
lambda tile, mask: np.moveaxis(tile, 0, -1)
- bn.median(tile[:, mask], axis=1),
img,
bgs,
)
)
),
-1,
1,
) # End with tiles, z, y, x
bgsub_mapping = map(
# move Z to last column to allow subtraction
lambda img, bgs: np.moveaxis(img, 0, -1)
# median of background over all pixels for each Z section
- bn.median(img[:, bgs], axis=1),
img,
bgs,
)
# apply map and convert to array
mapping_result = np.stack(list(bgsub_mapping))
# move Z axis back to the second column
img_bgsub[ch_bs] = np.moveaxis(mapping_result, -1, 1)
# apply metrics to background-corrected data
d[ch_bs] = self.reduce_extract(
red_metrics=ch_tree[ch],
traps=self.img_bgsub[ch_bs],
tree_branch=tree_bits["channel_tree"][ch],
traps=img_bgsub[ch_bs],
masks=masks,
labels=labels,
cell_labels=cell_labels,
**kwargs,
)
# apply any metrics using multiple channels, such as pH calculations
return d, img_bgsub
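A minimal sketch of the median background subtraction performed above, assuming bn refers to bottleneck (as used elsewhere in this module) and that a tile is arranged as (Z, X, Y) with a boolean background mask of shape (X, Y):
import bottleneck as bn
import numpy as np

tile = np.random.rand(3, 8, 8)            # one tile arranged as (Z, X, Y)
bg_mask = np.zeros((8, 8), dtype=bool)    # background pixels of that tile
bg_mask[:2] = True
# median background per Z section, taken over background pixels only
bg_per_z = bn.median(tile[:, bg_mask], axis=1)
# move Z last so the per-Z medians broadcast, subtract, then move Z back
tile_bgsub = np.moveaxis(np.moveaxis(tile, 0, -1) - bg_per_z, -1, 0)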
def extract_multiple_channels(
self, tree_bits, cell_labels, tiles, masks, **kwargs
):
"""
Extract using all metrics requiring multiple channels.
"""
available_chs = set(self.img_bgsub.keys()).union(
tree_bits["tree_channels"]
)
d = {}
for name, (
chs,
merge_fun,
red_metrics,
reduction_fun,
op,
) in self.params.multichannel_ops.items():
if len(
set(chs).intersection(
set(self.img_bgsub.keys()).union(tree_chs)
)
) == len(chs):
common_chs = set(chs).intersection(available_chs)
# all required channels should be available
if len(common_chs) == len(chs):
channels_stack = np.stack(
[self.get_imgs(ch, tiles, tree_chs) for ch in chs], axis=-1
[
self.get_imgs(ch, tiles, tree_bits["tree_channels"])
for ch in chs
],
axis=-1,
)
merged = RED_FUNS[merge_fun](channels_stack, axis=-1)
d[name] = self.reduce_extract(
red_metrics=red_metrics,
traps=merged,
masks=masks,
labels=labels,
**kwargs,
# reduce in Z
traps = RED_FUNS[reduction_fun](channels_stack, axis=1)
# evaluate multichannel op
if name not in d:
d[name] = {}
if reduction_fun not in d[name]:
d[name][reduction_fun] = {}
d[name][reduction_fun][op] = self.extract_traps(
traps,
masks,
op,
cell_labels,
)
return d
def extract_tp(
self,
tp: int,
tree: t.Optional[extraction_tree] = None,
tile_size: int = 117,
masks: t.Optional[t.List[np.ndarray]] = None,
cell_labels: t.Optional[t.List[int]] = None,
**kwargs,
) -> t.Dict[str, t.Dict[str, t.Dict[str, tuple]]]:
"""
Extract for an individual time point.
Parameters
----------
tp : int
Time point being analysed.
tree : dict
Nested dictionary indicating channels, reduction functions
and metrics to be used.
For example: {'general': {'None': ['area', 'volume', 'eccentricity']}}
tile_size : int
Size of the tile to be extracted.
masks : list of arrays
A list of mask arrays, one per trap, each with dimensions
(ncells, tile_size, tile_size): one mask per cell.
cell_labels : dict
A dictionary with trap_ids as keys and cell_labels as values.
**kwargs : keyword arguments
Passed to extractor.reduce_extract.
Returns
-------
d: dict
Dictionary of the results with three levels of dictionaries.
The first level has channels as keys.
The second level has reduction metrics as keys.
The third level has cell or background metrics as keys and a
two-tuple as values.
The first tuple is the result of applying the metrics to a
particular cell or trap; the second tuple is either
(trap_id, cell_label) for a metric applied to a cell or a
trap_id for a metric applied to a trap.
An example is d["GFP"]["np_max"]["mean"][0], which gives a tuple
of the calculated mean GFP fluorescence for all cells.
"""
# dict of information from extraction tree
tree_bits = self.make_tree_bits(tree)
# create a Cells object to extract information from the h5 file
cells = Cells(self.local)
# find the cell labels as dict with trap_ids as keys
cell_labels = self.get_cell_labels(tp, cell_labels, cells)
# get masks one per cell per trap
masks = self.get_masks(tp, masks, cells)
# find image data at the time point
# stored as an array arranged as (traps, channels, 1, Z, X, Y)
tiles = self.get_tiles(tp, channels=tree_bits["tree_channels"])
# generate boolean masks for background for each trap
bgs = self.get_background_masks(masks, tile_size)
# perform extraction
res_one, self.img_bgsub = self.extract_one_channel(
tree_bits, cell_labels, tiles, masks, bgs, **kwargs
)
res_multiple = self.extract_multiple_channels(
tree_bits, cell_labels, tiles, masks, **kwargs
)
res = {**res_one, **res_multiple}
return res
def get_imgs(self, channel: t.Optional[str], tiles, channels=None):
"""
Return image from a correct source, either raw or bgsub.
......@@ -555,14 +656,16 @@ class Extractor(StepABC):
channel: str
Name of channel to get.
tiles: ndarray
An array of the image data having dimensions of (tile_id, channel, tp, tile_size, tile_size, n_zstacks).
An array of the image data having dimensions of
(tile_id, channel, tp, tile_size, tile_size, n_zstacks).
channels: list of str (optional)
t.List of available channels.
Returns
-------
img: ndarray
An array of image data with dimensions (no tiles, X, Y, no Z channels)
An array of image data with dimensions
(no tiles, X, Y, no Z channels)
"""
if channels is None:
channels = (*self.params.tree,)
......@@ -579,7 +682,9 @@ class Extractor(StepABC):
**kwargs,
) -> dict:
"""
Wrapper to add compatibility with other steps of the pipeline.
Run extraction for one position and for the specified time points.
Save the results to a h5 file.
Parameters
----------
......@@ -597,7 +702,9 @@ class Extractor(StepABC):
Returns
-------
d: dict
A dict of the extracted data with a concatenated string of channel, reduction metric, and cell metric as keys and pd.Series of the extracted data as values.
A dict of the extracted data for one position with a concatenated
string of channel, reduction metric, and cell metric as keys and
pd.DataFrame of the extracted data for all time points as values.
"""
if tree is None:
tree = self.params.tree
......@@ -628,12 +735,12 @@ class Extractor(StepABC):
d[k].index.names = idx
# save
if save:
self.save_to_hdf(d)
self.save_to_h5(d)
return d
def save_to_hdf(self, dict_series, path=None):
def save_to_h5(self, dict_series, path=None):
"""
Save the extracted data to the h5 file.
Save the extracted data for one position to the h5 file.
Parameters
----------
......@@ -672,14 +779,17 @@ def flatten_nesteddict(
nest: dict of dicts
Contains the nested results of extraction.
to: str (optional)
Specifies the format of the output, either pd.Series (default) or a list
Specifies the format of the output, either pd.Series (default)
or a list
tp: int
Time point used to name the pd.Series
Returns
-------
d: dict
A dict with a concatenated string of channel, reduction metric, and cell metric as keys and either a pd.Series or a list of the corresponding extracted data as values.
A dict with a concatenated string of channel, reduction metric,
and cell metric as keys and either a pd.Series or a list of the
corresponding extracted data as values.
"""
d = {}
for k0, v0 in nest.items():
......@@ -689,14 +799,3 @@ def flatten_nesteddict(
pd.Series(*v2, name=tp) if to == "series" else v2
)
return d
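A small sketch of flattening one extraction result; the innermost two-tuple follows extract_traps, a tuple of values plus a tuple of (trap_id, cell_label) indices, and the key concatenation is whatever the loop above builds:
nested = {"GFP": {"np_max": {"mean": ((100.0, 110.0), ((0, 1), (0, 2)))}}}
flat = flatten_nesteddict(nested, to="series", tp=0)
# each value is a pd.Series named 0 and indexed by (trap_id, cell_label)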
class hollowExtractor(Extractor):
"""
Extractor that only cares about receiving images and masks.
Used for testing.
"""
def __init__(self, parameters):
self.params = parameters
......@@ -90,9 +90,10 @@ def max2p5pc(cell_mask, trap_image) -> float:
return np.mean(top_values)
def max5px(cell_mask, trap_image) -> float:
def max5px_median(cell_mask, trap_image) -> float:
"""
Find the mean of the five brightest pixels in the cell.
Find the mean of the five brightest pixels in the cell divided by the
median of all pixels.
Parameters
----------
......@@ -105,7 +106,11 @@ def max5px(cell_mask, trap_image) -> float:
top_values = bn.partition(pixels, len(pixels) - 5)[-5:]
# find mean of five brightest pixels
max5px = np.mean(top_values)
return max5px
med = np.median(pixels)
if med == 0:
return np.nan
else:
return max5px / med
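A small usage sketch on synthetic data, assuming the boolean cell mask and 2D tile image convention used by the other cell functions in this file:
import numpy as np

cell_mask = np.zeros((6, 6), dtype=bool)
cell_mask[1:5, 1:5] = True                    # a 16-pixel "cell"
trap_image = np.arange(36, dtype=float).reshape(6, 6)
value = max5px_median(cell_mask, trap_image)  # mean of 5 brightest pixels / median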
def std(cell_mask, trap_image):
......@@ -193,3 +198,50 @@ def min_maj_approximation(cell_mask) -> t.Tuple[int]:
# + distance from the center of cone top to edge of cone top
maj_ax = np.round(np.max(dn) + np.sum(cone_top) / 2)
return min_ax, maj_ax
def moment_of_inertia(cell_mask, trap_image):
"""
Find moment of inertia - a measure of homogeneity.
From iopscience.iop.org/article/10.1088/1742-6596/1962/1/012028
which cites ieeexplore.ieee.org/document/1057692.
"""
# set pixels not in the cell to zero without modifying the caller's image
x = np.where(cell_mask, trap_image, 0)
if np.any(x):
# x-axis : column=x-axis
columnvec = np.arange(1, x.shape[1] + 1, 1)[:, None].T
# y-axis : row=y-axis
rowvec = np.arange(1, x.shape[0] + 1, 1)[:, None]
# find raw moments
M00 = np.sum(x)
M10 = np.sum(np.multiply(x, columnvec))
M01 = np.sum(np.multiply(x, rowvec))
# find centroid
Xm = M10 / M00
Ym = M01 / M00
# find central moments
Mu00 = M00
Mu20 = np.sum(np.multiply(x, (columnvec - Xm) ** 2))
Mu02 = np.sum(np.multiply(x, (rowvec - Ym) ** 2))
# find invariants
Eta20 = Mu20 / Mu00 ** (1 + (2 + 0) / 2)
Eta02 = Mu02 / Mu00 ** (1 + (0 + 2) / 2)
# find moments of inertia
moi = Eta20 + Eta02
return moi
else:
return np.nan
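An illustrative evaluation on synthetic data; a compact, uniform patch gives a small moment of inertia:
import numpy as np

mask = np.zeros((4, 4), dtype=bool)
mask[1:3, 1:3] = True
img = np.ones((4, 4))
moi = moment_of_inertia(mask, img)   # 0.125 for this uniform 2x2 patch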
def ratio(cell_mask, trap_image):
"""Find the median ratio between two fluorescence channels."""
if trap_image.ndim == 3 and trap_image.shape[-1] == 2:
fl_1 = trap_image[..., 0][cell_mask]
fl_2 = trap_image[..., 1][cell_mask]
div = np.median(fl_1 / fl_2)
else:
div = np.nan
return div
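And a sketch for ratio, which expects a two-channel tile stacked along the last axis:
import numpy as np

mask = np.zeros((4, 4), dtype=bool)
mask[1:3, 1:3] = True
img_2ch = np.stack([np.full((4, 4), 2.0), np.full((4, 4), 4.0)], axis=-1)
median_ratio = ratio(mask, img_2ch)   # median of 2.0 / 4.0, i.e. 0.5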
""" How to do the nuc Est Conv from MATLAB
"""
How to do the nuc Est Conv from MATLAB
Based on the code in MattSegCode/Matt Seg
GUI/@timelapseTraps/extractCellDataStacksParfor.m
Especially lines 342 to 399.
Especially lines 342 to 399.
This part only replicates the method to get the nuc_est_conv values
"""
import typing as t
......
# File with defaults for ease of use
import re
import typing as t
from pathlib import Path
import h5py
# should we move these functions here?
import aliby.global_parameters as global_parameters
from aliby.tile.tiler import find_channel_name
......@@ -22,26 +22,10 @@ def exparams_from_meta(
"tree": {"general": {"None": ["area", "volume", "eccentricity"]}},
"multichannel_ops": {},
}
candidate_channels = {
"Citrine",
"GFP",
"GFPFast",
"mCherry",
"pHluorin405",
"pHluorin488",
"Flavin",
"Cy5",
"mKO2",
}
candidate_channels = set(global_parameters.possible_imaging_channels)
default_reductions = {"max"}
default_metrics = {
"mean",
"median",
"std",
"imBackground",
"max5px",
# "nuc_est_conv",
}
default_metrics = set(global_parameters.fluorescence_functions)
# define ratiometric combinations
# key is numerator and value is denominator
# add more to support additional channel names
......@@ -59,6 +43,7 @@ def exparams_from_meta(
for ch in extant_fluorescence_ch:
base["tree"][ch] = default_reduction_metrics
base["sub_bg"] = extant_fluorescence_ch
# additional extraction defaults if the channels are available
if "ph" in extras:
# SWAINLAB specific names
......
......@@ -44,5 +44,6 @@ def reduce_z(trap_image: np.ndarray, fun: t.Callable, axis: int = 0):
elif isinstance(fun, np.ufunc):
# optimise the reduction function if possible
return fun.reduce(trap_image, axis=axis)
else: # WARNING: Very slow, only use when no alternatives exist
else:
# WARNING: Very slow, only use when no alternatives exist
return np.apply_along_axis(fun, axis, trap_image)
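For instance, the maximum projection used by the defaults here can take the fast ufunc path; a sketch, assuming the (Z, X, Y) layout used elsewhere in this module:
import numpy as np

z_stack = np.random.rand(5, 117, 117)                 # (Z, X, Y)
max_projection = np.maximum.reduce(z_stack, axis=0)   # what reduce_z does for a ufunc
# a plain Python callable instead falls back to the slow
# np.apply_along_axis branch above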
......@@ -11,8 +11,10 @@ from extraction.core.functions.math_utils import div0
"""
Load functions for analysing cells and their background.
Note that inspect.getmembers returns a list of function names and functions,
and inspect.getfullargspec returns a function's arguments.
Note that inspect.getmembers returns a list of function names
and functions, and inspect.getfullargspec returns a
function's arguments.
"""
......@@ -66,7 +68,7 @@ def load_cellfuns():
# create dict of the core functions from cell.py - these functions apply to a single mask
cell_funs = load_cellfuns_core()
# create a dict of functions that apply the core functions to an array of cell_masks
CELLFUNS = {}
CELL_FUNS = {}
for f_name, f in cell_funs.items():
if isfunction(f):
......@@ -79,27 +81,27 @@ def load_cellfuns():
# function that applies f to m and img, the trap_image
return lambda m, img: trap_apply(f, m, img)
CELLFUNS[f_name] = tmp(f)
return CELLFUNS
CELL_FUNS[f_name] = tmp(f)
return CELL_FUNS
def load_trapfuns():
"""Load functions that are applied to an entire tile."""
TRAPFUNS = {
TRAP_FUNS = {
f[0]: f[1]
for f in getmembers(trap)
if isfunction(f[1])
and f[1].__module__.startswith("extraction.core.functions")
}
return TRAPFUNS
return TRAP_FUNS
def load_funs():
"""Combine all automatically loaded functions."""
CELLFUNS = load_cellfuns()
TRAPFUNS = load_trapfuns()
CELL_FUNS = load_cellfuns()
TRAP_FUNS = load_trapfuns()
# return dict of cell funs, dict of trap funs, and dict of both
return CELLFUNS, TRAPFUNS, {**TRAPFUNS, **CELLFUNS}
return CELL_FUNS, TRAP_FUNS, {**TRAP_FUNS, **CELL_FUNS}
def load_redfuns(
......
......@@ -14,17 +14,17 @@ from postprocessor.core.abc import get_process
class Chainer(Signal):
"""
Extend Signal by applying post-processes and allowing composite signals that combine basic signals.
It "chains" multiple processes upon fetching a dataset to produce the desired datasets.
Instead of reading processes previously applied, it executes
Chainer "chains" multiple processes upon fetching a dataset.
Instead of reading processes previously applied, Chainer executes
them when called.
"""
_synonyms = {
"m5m": ("extraction/GFP/max/max5px", "extraction/GFP/max/median")
}
_synonyms = {}
def __init__(self, *args, **kwargs):
"""Initialise chainer."""
super().__init__(*args, **kwargs)
def replace_path(path: str, bgsub: bool = ""):
......@@ -34,7 +34,7 @@ class Chainer(Signal):
path = re.sub(channel, f"{channel}{suffix}", path)
return path
# Add chain with and without bgsub for composite statistics
# add chain with and without bgsub for composite statistics
self.common_chains = {
alias
+ bgsub: lambda **kwargs: self.get(
......
......@@ -46,18 +46,16 @@ def get_process(process, suffix="") -> PostProcessABC or ParametersABC or None:
_to_snake_case(process),
_to_pascal_case(_to_snake_case(process)),
)
found = None
for possible_location, process_syntax in product(
possible_locations, valid_syntaxes
):
location = f"{base_location}.{possible_location}.{_to_snake_case(process)}.{process_syntax}{suffix}"
# instantiate class but not a class object
found = locate(location)
if found is not None:
break
else:
raise Exception(
f"{process} not found in locations {possible_locations} at {base_location}"
)
......
"""
Functions to process, filter and merge tracks.
"""
Functions to process, filter, and merge tracks.
We call two tracks contiguous if they are adjacent in time: the
maximal time point of one is one time point less than the
minimal time point of the other.
# from collections import Counter
A right track can have multiple potential left tracks. We pick the best.
"""
import typing as t
from copy import copy
......@@ -17,8 +21,385 @@ from utils_find_1st import cmp_larger, find_1st
from postprocessor.core.processes.savgol import non_uniform_savgol
def get_merges(tracks, smooth=False, tol=0.2, window=5, degree=3) -> dict:
"""
Find all pairs of tracks that should be joined.
Each track is defined by (trap_id, cell_id).
If there are multiple candidate left tracks for a right track,
pick the best using the Signal values.
To score two tracks, we predict the future value of a left track and
compare with the mean initial values of a right track.
Parameters
----------
tracks: pd.DataFrame
A Signal, usually area, where rows are cell tracks and columns are
time points.
smooth: boolean
If True, smooth tracks with a savgol_filter.
tol: float < 1 or int
If int, compare the absolute distance between predicted values
for the left and right end points of two contiguous tracks.
If float, compare the distance relative to the magnitude of the
end point of the left track.
window: int
Length of window used for predictions and for any savgol_filter.
degree: int
The degree of the polynomial used by the savgol_filter.
"""
# only consider time series with more than two non-NaN data points
tracks = tracks.loc[tracks.notna().sum(axis=1) > 2]
# get contiguous tracks
if smooth:
# specialise to tracks with growing cells and of long duration
clean = clean_tracks(tracks, min_duration=window + 1, min_gr=0.9)
contigs = clean.groupby(["trap"]).apply(get_contiguous_pairs)
else:
contigs = tracks.groupby(["trap"]).apply(get_contiguous_pairs)
# remove traps with no contiguous tracks
contigs = contigs.loc[contigs.apply(len) > 0]
# flatten to (trap, cell_id) pairs
flat = set([k for v in contigs.values for i in v for j in i for k in j])
# make a data frame of contiguous tracks with the tracks as arrays
if smooth:
smoothed_tracks = clean.loc[flat].apply(
lambda x: non_uniform_savgol(x.index, x.values, window, degree),
axis=1,
)
else:
smoothed_tracks = tracks.loc[flat].apply(
lambda x: np.array(x.values), axis=1
)
# get the Signal values for neighbouring end points of contiguous tracks
actual_edge_values = contigs.apply(
lambda x: get_edge_values(x, smoothed_tracks)
)
# get the predicted values
predicted_edge_values = contigs.apply(
lambda x: get_predicted_edge_values(x, smoothed_tracks, window)
)
# score predicted edge values: low values are best
prediction_scores = predicted_edge_values.apply(get_dMetric_wrap)
# find contiguous tracks to join for each trap
trap_contigs_to_join = []
for idx in contigs.index:
local_contigs = contigs.loc[idx]
# find indices of best left and right tracks to join
best_indices = find_best_from_scores_wrap(
prediction_scores.loc[idx], actual_edge_values.loc[idx], tol=tol
)
# find tracks from the indices
trap_contigs_to_join.append(
[
(contig[0][left], contig[1][right])
for best_index, contig in zip(best_indices, local_contigs)
for (left, right) in best_index
if best_index
]
)
# return only the pairs of contiguous tracks
contigs_to_join = [
contigs
for trap_tracks in trap_contigs_to_join
for contigs in trap_tracks
]
merges = np.array(contigs_to_join, dtype=int)
return merges
def clean_tracks(
tracks, min_duration: int = 15, min_gr: float = 1.0
) -> pd.DataFrame:
"""Remove small non-growing tracks and return the reduced data frame."""
ntps = tracks.apply(max_ntps, axis=1)
grs = tracks.apply(get_avg_gr, axis=1)
growing_long_tracks = tracks.loc[(ntps >= min_duration) & (grs > min_gr)]
return growing_long_tracks
def get_contiguous_pairs(tracks: pd.DataFrame) -> list:
"""
Get all pairs of contiguous track ids from a tracks data frame.
For two tracks to be contiguous, they must be exactly adjacent.
Parameters
----------
tracks: pd.Dataframe
A dataframe where rows are cell tracks and columns are time
points.
"""
# TODO add support for skipping time points
# find time points bounding tracks of non-NaN values
mins, maxs = [
tracks.notna().apply(np.where, axis=1).apply(fn)
for fn in (np.min, np.max)
]
# flip so that time points become the index
mins_d = mins.groupby(mins).apply(lambda x: x.index.tolist())
maxs_d = maxs.groupby(maxs).apply(lambda x: x.index.tolist())
# reduce minimal time point to make a right track overlap with a left track
mins_d.index = mins_d.index - 1
# find common end points
common = sorted(set(mins_d.index).intersection(maxs_d.index), reverse=True)
contigs = [(maxs_d[t], mins_d[t]) for t in common]
return contigs
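A small runnable sketch; the index layout, a (trap, cell_label) MultiIndex with time points as columns, follows the docstrings in this module:
import numpy as np
import pandas as pd

index = pd.MultiIndex.from_tuples([(0, 1), (0, 2)], names=["trap", "cell_label"])
tracks = pd.DataFrame(
    [
        [10.0, 11.0, 12.0, np.nan, np.nan, np.nan],  # ends at time point 2
        [np.nan, np.nan, np.nan, 13.0, 14.0, 15.0],  # starts at time point 3
    ],
    index=index,
)
print(get_contiguous_pairs(tracks))   # [([(0, 1)], [(0, 2)])]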
def get_edge_values(contigs_ids, smoothed_tracks):
"""Get Signal values for adjacent end points for each contiguous track."""
values = [
(
[get_value(smoothed_tracks.loc[pre_id], -1) for pre_id in pre_ids],
[
get_value(smoothed_tracks.loc[post_id], 0)
for post_id in post_ids
],
)
for pre_ids, post_ids in contigs_ids
]
return values
def get_predicted_edge_values(contigs_ids, smoothed_tracks, window):
"""
Find neighbouring values of two contiguous tracks.
Predict the next value for the leftmost track using window values
and find the mean of the initial window values of the rightmost
track.
"""
result = []
for pre_ids, post_ids in contigs_ids:
pre_res = []
# left contiguous tracks
for pre_id in pre_ids:
# get last window values of a track
y = get_values_i(smoothed_tracks.loc[pre_id], -window)
# predict next value
pre_res.append(
np.poly1d(np.polyfit(range(len(y)), y, 1))(len(y) + 1),
)
# right contiguous tracks
pos_res = [
# mean value of initial window values of a track
get_mean_value_i(smoothed_tracks.loc[post_id], window)
for post_id in post_ids
]
result.append([pre_res, pos_res])
return result
def get_dMetric_wrap(lst: List, **kwargs):
"""Calculate dMetric on a list."""
return [get_dMetric(*sublist, **kwargs) for sublist in lst]
def get_dMetric(pre_values: List[float], post_values: List[float]):
"""
Calculate a scoring matrix based on the difference between two Signal
values.
We generate one score per pair of contiguous tracks.
Lower scores are better.
Parameters
----------
pre_values: list of floats
Values of the Signal for left contiguous tracks.
post_values: list of floats
Values of the Signal for right contiguous tracks.
"""
if len(pre_values) > len(post_values):
dMetric = np.abs(np.subtract.outer(post_values, pre_values))
else:
dMetric = np.abs(np.subtract.outer(pre_values, post_values))
# replace NaNs with maximal values
dMetric[np.isnan(dMetric)] = 1 + np.nanmax(dMetric)
return dMetric
def find_best_from_scores_wrap(dMetric: List, edges: List, **kwargs):
"""Calculate solve_matrices on a list."""
return [
find_best_from_scores(mat, edgeset, **kwargs)
for mat, edgeset in zip(dMetric, edges)
]
def find_best_from_scores(
scores: np.ndarray, actual_edge_values: List, tol: Union[float, int] = 1
):
"""Find indices for left and right contiguous tracks with scores below a tolerance."""
ids = find_best_indices(scores)
if len(ids[0]):
pre_value, post_value = actual_edge_values
# score with relative or absolute distance
norm = (
np.array(pre_value)[ids[len(pre_value) > len(post_value)]]
if tol < 1
else 1
)
best_scores = scores[ids] / norm
ids = ids if len(pre_value) < len(post_value) else ids[::-1]
# keep only indices with best_score less than the tolerance
indices = [
idx for idx, score in zip(zip(*ids), best_scores) if score <= tol
]
return indices
else:
return []
def find_best_indices(dMetric):
"""Find indices for left and right contiguous tracks with minimal scores."""
glob_is = []
glob_js = []
if (~np.isnan(dMetric)).any():
lMetric = copy(dMetric)
sortedMetric = sorted(lMetric[~np.isnan(lMetric)])
while (~np.isnan(sortedMetric)).any():
# indices of point with the lowest score
i_s, j_s = np.where(lMetric == sortedMetric[0])
i = i_s[0]
j = j_s[0]
# store this point
glob_is.append(i)
glob_js.append(j)
# remove from lMetric
lMetric[i, :] += np.nan
lMetric[:, j] += np.nan
sortedMetric = sorted(lMetric[~np.isnan(lMetric)])
indices = (np.array(glob_is), np.array(glob_js))
return indices
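A sketch of scoring and greedy matching for two left and two right tracks, with illustrative values:
# predicted end values of two left tracks vs initial means of two right tracks
scores = get_dMetric([10.0, 20.0], [11.0, 19.0])
# scores == array([[1., 9.],
#                  [9., 1.]])
left_ids, right_ids = find_best_indices(scores)
# the lowest scores are picked greedily: pairs (0, 0) and (1, 1)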
def get_value(x, n):
"""Get value from an array ignoring NaN."""
val = x[~np.isnan(x)][n] if len(x[~np.isnan(x)]) else np.nan
return val
def get_mean_value_i(x, i):
"""Get track's mean Signal value from values either from or up to an index."""
if not len(x[~np.isnan(x)]):
return np.nan
else:
if i > 0:
v = x[~np.isnan(x)][:i]
else:
v = x[~np.isnan(x)][i:]
return np.nanmean(v)
def get_values_i(x, i):
"""Get track's Signal values either from or up to an index."""
if not len(x[~np.isnan(x)]):
return np.nan
else:
if i > 0:
v = x[~np.isnan(x)][:i]
else:
v = x[~np.isnan(x)][i:]
return v
def get_avg_gr(track: pd.Series) -> float:
"""Get average growth rate for a track."""
ntps = max_ntps(track)
vals = track.dropna().values
gr = (vals[-1] - vals[0]) / ntps
return gr
######################################################################
def get_joinable_original(
tracks, smooth=False, tol=0.1, window=5, degree=3
) -> dict:
"""
Get the pairs of tracks (without repeats) that have a smaller error than the
tolerance. If a track can be assigned to two or more others, choose the one
with the lowest error.
Parameters
----------
tracks: (m x n) Signal
A Signal, usually area, dataframe where rows are cell tracks and
columns are time points.
tol: float or int
threshold of average (prediction error/std) necessary
to consider two tracks the same. If float is fraction of first track,
if int it is absolute units.
window: int
value of window used for savgol_filter
degree: int
value of polynomial degree passed to savgol_filter
"""
# only consider time series with more than two non-NaN data points
tracks = tracks.loc[tracks.notna().sum(axis=1) > 2]
# get contiguous tracks
if smooth:
# specialise to tracks with growing cells and of long duration
clean = clean_tracks(tracks, min_duration=window + 1, min_gr=0.9)
contigs = clean.groupby(["trap"]).apply(get_contiguous_pairs)
else:
contigs = tracks.groupby(["trap"]).apply(get_contiguous_pairs)
# remove traps with no contiguous tracks
contigs = contigs.loc[contigs.apply(len) > 0]
# flatten to (trap, cell_id) pairs
flat = set([k for v in contigs.values for i in v for j in i for k in j])
# make a data frame of contiguous tracks with the tracks as arrays
if smooth:
smoothed_tracks = clean.loc[flat].apply(
lambda x: non_uniform_savgol(x.index, x.values, window, degree),
axis=1,
)
else:
smoothed_tracks = tracks.loc[flat].apply(
lambda x: np.array(x.values), axis=1
)
# get the Signal values for neighbouring end points of contiguous tracks
actual_edge_values = contigs.apply(
lambda x: get_edge_values(x, smoothed_tracks)
)
# get the predicted values
predicted_edge_values = contigs.apply(
lambda x: get_predicted_edge_values(x, smoothed_tracks, window)
)
# score predicted edge values
prediction_scores = predicted_edge_values.apply(get_dMetric_wrap)
solutions = [
# for all sets of contigs at a trap
find_best_from_scores_wrap(cost, edge_values, tol=tol)
for (trap_id, cost), edge_values in zip(
prediction_scores.items(), actual_edge_values
)
]
closest_pairs = pd.Series(
solutions,
index=prediction_scores.index,
)
# match local with global ids
joinable_ids = [
localid_to_idx(closest_pairs.loc[i], contigs.loc[i])
for i in closest_pairs.index
]
contigs_to_join = [
contigs for trap_tracks in joinable_ids for contigs in trap_tracks
]
return contigs_to_join
def load_test_dset():
"""Load development dataset to test functions."""
"""Load test data set."""
return pd.DataFrame(
{
("a", 1, 1): [2, 5, np.nan, 6, 8] + [np.nan] * 5,
......@@ -45,49 +426,6 @@ def max_nonstop_ntps(track: pd.Series) -> int:
return max(consecutive_nonas_grouped)
def get_tracks_ntps(tracks: pd.DataFrame) -> pd.Series:
return tracks.apply(max_ntps, axis=1)
def get_avg_gr(track: pd.Series) -> int:
"""
Get average growth rate for a track.
:param tracks: Series with volume and timepoints as indices
"""
ntps = max_ntps(track)
vals = track.dropna().values
gr = (vals[-1] - vals[0]) / ntps
return gr
def get_avg_grs(tracks: pd.DataFrame) -> pd.DataFrame:
"""
Get average growth rate for a group of tracks
:param tracks: (m x n) dataframe where rows are cell tracks and
columns are timepoints
"""
return tracks.apply(get_avg_gr, axis=1)
def clean_tracks(
tracks, min_len: int = 15, min_gr: float = 1.0
) -> pd.DataFrame:
"""
Clean small non-growing tracks and return the reduced dataframe
:param tracks: (m x n) dataframe where rows are cell tracks and
columns are timepoints
:param min_len: int number of timepoints cells must have not to be removed
:param min_gr: float Minimum mean growth rate to assume an outline is growing
"""
ntps = get_tracks_ntps(tracks)
grs = get_avg_grs(tracks)
growing_long_tracks = tracks.loc[(ntps >= min_len) & (grs > min_gr)]
return growing_long_tracks
def merge_tracks(
tracks, drop=False, **kwargs
) -> t.Tuple[pd.DataFrame, t.Collection]:
......@@ -191,141 +529,6 @@ def join_track_pair(target, source):
return tgt_copy
def get_joinable(tracks, smooth=False, tol=0.1, window=5, degree=3) -> dict:
"""
Get the pair of track (without repeats) that have a smaller error than the
tolerance. If there is a track that can be assigned to two or more other
ones, choose the one with lowest error.
:param tracks: (m x n) dataframe where rows are cell tracks and
columns are timepoints
:param tol: float or int threshold of average (prediction error/std) necessary
to consider two tracks the same. If float is fraction of first track,
if int it is absolute units.
:param window: int value of window used for savgol_filter
:param degree: int value of polynomial degree passed to savgol_filter
"""
tracks = tracks.loc[tracks.notna().sum(axis=1) > 2]
# Commented because we are not smoothing in this step yet
# candict = {k:v for d in contig.values for k,v in d.items()}
# smooth all relevant tracks
if smooth: # Apply savgol filter TODO fix nans affecting edge placing
clean = clean_tracks(
tracks, min_len=window + 1, min_gr=0.9
) # get useful tracks
def savgol_on_srs(x):
return non_uniform_savgol(x.index, x.values, window, degree)
contig = clean.groupby(["trap"]).apply(get_contiguous_pairs)
contig = contig.loc[contig.apply(len) > 0]
flat = set([k for v in contig.values for i in v for j in i for k in j])
smoothed_tracks = clean.loc[flat].apply(savgol_on_srs, 1)
else:
contig = tracks.groupby(["trap"]).apply(get_contiguous_pairs)
contig = contig.loc[contig.apply(len) > 0]
flat = set([k for v in contig.values for i in v for j in i for k in j])
smoothed_tracks = tracks.loc[flat].apply(
lambda x: np.array(x.values), axis=1
)
# fetch edges from ids TODO (IF necessary, here we can compare growth rates)
def idx_to_edge(preposts):
return [
(
[get_val(smoothed_tracks.loc[pre], -1) for pre in pres],
[get_val(smoothed_tracks.loc[post], 0) for post in posts],
)
for pres, posts in preposts
]
# idx_to_means = lambda preposts: [
# (
# [get_means(smoothed_tracks.loc[pre], -window) for pre in pres],
# [get_means(smoothed_tracks.loc[post], window) for post in posts],
# )
# for pres, posts in preposts
# ]
def idx_to_pred(preposts):
result = []
for pres, posts in preposts:
pre_res = []
for pre in pres:
y = get_last_i(smoothed_tracks.loc[pre], -window)
pre_res.append(
np.poly1d(np.polyfit(range(len(y)), y, 1))(len(y) + 1),
)
pos_res = [
get_means(smoothed_tracks.loc[post], window) for post in posts
]
result.append([pre_res, pos_res])
return result
edges = contig.apply(idx_to_edge) # Raw edges
# edges_mean = contig.apply(idx_to_means) # Mean of both
pre_pred = contig.apply(idx_to_pred) # Prediction of pre and mean of post
# edges_dMetric = edges.apply(get_dMetric_wrap, tol=tol)
# edges_dMetric_mean = edges_mean.apply(get_dMetric_wrap, tol=tol)
edges_dMetric_pred = pre_pred.apply(get_dMetric_wrap, tol=tol)
# combined_dMetric = pd.Series(
# [
# [np.nanmin((a, b), axis=0) for a, b in zip(x, y)]
# for x, y in zip(edges_dMetric, edges_dMetric_mean)
# ],
# index=edges_dMetric.index,
# )
# closest_pairs = combined_dMetric.apply(get_vec_closest_pairs, tol=tol)
solutions = []
# for (i, dMetrics), edgeset in zip(combined_dMetric.items(), edges):
for (i, dMetrics), edgeset in zip(edges_dMetric_pred.items(), edges):
solutions.append(solve_matrices_wrap(dMetrics, edgeset, tol=tol))
closest_pairs = pd.Series(
solutions,
index=edges_dMetric_pred.index,
)
# match local with global ids
joinable_ids = [
localid_to_idx(closest_pairs.loc[i], contig.loc[i])
for i in closest_pairs.index
]
return [pair for pairset in joinable_ids for pair in pairset]
def get_val(x, n):
return x[~np.isnan(x)][n] if len(x[~np.isnan(x)]) else np.nan
def get_means(x, i):
if not len(x[~np.isnan(x)]):
return np.nan
if i > 0:
v = x[~np.isnan(x)][:i]
else:
v = x[~np.isnan(x)][i:]
return np.nanmean(v)
def get_last_i(x, i):
if not len(x[~np.isnan(x)]):
return np.nan
if i > 0:
v = x[~np.isnan(x)][:i]
else:
v = x[~np.isnan(x)][i:]
return v
def localid_to_idx(local_ids, contig_trap):
"""
Fetch the original ids from a nested list with joinable local_ids.
......@@ -351,60 +554,6 @@ def get_vec_closest_pairs(lst: List, **kwargs):
return [get_closest_pairs(*sublist, **kwargs) for sublist in lst]
def get_dMetric_wrap(lst: List, **kwargs):
return [get_dMetric(*sublist, **kwargs) for sublist in lst]
def solve_matrices_wrap(dMetric: List, edges: List, **kwargs):
return [
solve_matrices(mat, edgeset, **kwargs)
for mat, edgeset in zip(dMetric, edges)
]
def get_dMetric(
pre: List[float], post: List[float], tol: Union[float, int] = 1
):
"""Calculate a cost matrix
input
:param pre: list of floats with edges on left
:param post: list of floats with edges on right
:param tol: int or float if int metrics of tolerance, if float fraction
returns
:: list of indices corresponding to the best solutions for matrices
"""
if len(pre) > len(post):
dMetric = np.abs(np.subtract.outer(post, pre))
else:
dMetric = np.abs(np.subtract.outer(pre, post))
dMetric[np.isnan(dMetric)] = (
tol + 1 + np.nanmax(dMetric)
) # nans will be filtered
return dMetric
def solve_matrices(
dMetric: np.ndarray, prepost: List, tol: Union[float, int] = 1
):
"""
Solve the distance matrices obtained in get_dMetric and/or merged from
independent dMetric matrices.
"""
ids = solve_matrix(dMetric)
if not len(ids[0]):
return []
pre, post = prepost
norm = (
np.array(pre)[ids[len(pre) > len(post)]] if tol < 1 else 1
) # relative or absolute tol
result = dMetric[ids] / norm
ids = ids if len(pre) < len(post) else ids[::-1]
return [idx for idx, res in zip(zip(*ids), result) if res <= tol]
def get_closest_pairs(
pre: List[float], post: List[float], tol: Union[float, int] = 1
):
......@@ -422,41 +571,11 @@ def get_closest_pairs(
"""
dMetric = get_dMetric(pre, post, tol)
return solve_matrices(dMetric, pre, post, tol)
def solve_matrix(dMetric):
"""
Solve cost matrix focusing on getting the smallest cost at each iteration.
input
:param dMetric: np.array cost matrix
returns
tuple of np.arrays indicating picks with lowest individual value
"""
glob_is = []
glob_js = []
if (~np.isnan(dMetric)).any():
tmp = copy(dMetric)
std = sorted(tmp[~np.isnan(tmp)])
while (~np.isnan(std)).any():
v = std[0]
i_s, j_s = np.where(tmp == v)
i = i_s[0]
j = j_s[0]
tmp[i, :] += np.nan
tmp[:, j] += np.nan
glob_is.append(i)
glob_js.append(j)
std = sorted(tmp[~np.isnan(tmp)])
return (np.array(glob_is), np.array(glob_js))
return find_best_from_scores(dMetric, pre, post, tol)
def plot_joinable(tracks, joinable_pairs):
"""
Convenience plotting function for debugging and data vis
"""
"""Convenience plotting function for debugging."""
nx = 8
ny = 8
_, axes = plt.subplots(nx, ny)
......@@ -475,59 +594,3 @@ def plot_joinable(tracks, joinable_pairs):
# pass
ax.plot(post_srs.index, post_srs.values, "g")
plt.show()
def get_contiguous_pairs(tracks: pd.DataFrame) -> list:
"""
Get all pair of contiguous track ids from a tracks dataframe.
:param tracks: (m x n) dataframe where rows are cell tracks and
columns are timepoints
:param min_dgr: float minimum difference in growth rate from
the interpolation
"""
mins, maxes = [
tracks.notna().apply(np.where, axis=1).apply(fn)
for fn in (np.min, np.max)
]
mins_d = mins.groupby(mins).apply(lambda x: x.index.tolist())
mins_d.index = mins_d.index - 1 # make indices equal
# TODO add support for skipping time points
maxes_d = maxes.groupby(maxes).apply(lambda x: x.index.tolist())
common = sorted(
set(mins_d.index).intersection(maxes_d.index), reverse=True
)
return [(maxes_d[t], mins_d[t]) for t in common]
# def fit_track(track: pd.Series, obj=None):
# if obj is None:
# obj = objective
# x = track.dropna().index
# y = track.dropna().values
# popt, _ = curve_fit(obj, x, y)
# return popt
# def interpolate(track, xs) -> list:
# '''
# Interpolate next timepoint from a track
# :param track: pd.Series of volume growth over a time period
# :param t: int timepoint to interpolate
# '''
# popt = fit_track(track)
# # perr = np.sqrt(np.diag(pcov))
# return objective(np.array(xs), *popt)
# def objective(x,a,b,c,d) -> float:
# # return (a)/(1+b*np.exp(c*x))+d
# return (((x+d)*a)/((x+d)+b))+c
# def cand_pairs_to_dict(candidates):
# d={x:[] for x,_ in candidates}
# for x,y in candidates:
# d[x].append(y)
# return d
......@@ -12,19 +12,20 @@ from postprocessor.core.abc import PostProcessABC
class LineageProcessParameters(ParametersABC):
"""
Parameters
"""
"""Parameters - none are necessary."""
_defaults = {}
class LineageProcess(PostProcessABC):
"""
Lineage process that must be passed a (N,3) lineage matrix (where the columns are trap, mother, daughter respectively)
To analyse lineage data.
Currently bare bones, but extracts lineage information from a Signal or Cells object.
"""
def __init__(self, parameters: LineageProcessParameters):
"""Initialise using PostProcessABC."""
super().__init__(parameters)
@abstractmethod
......@@ -34,6 +35,7 @@ class LineageProcess(PostProcessABC):
lineage: np.ndarray,
*args,
):
"""Implement method required by PostProcessABC - undefined."""
pass
@classmethod
......@@ -45,8 +47,9 @@ class LineageProcess(PostProcessABC):
**kwargs,
):
"""
Overrides PostProcess.as_function classmethod.
Lineage functions require lineage information to be passed if run as function.
Override PostProcessABC.as_function method.
Lineage functions require lineage information to be passed when run as functions.
"""
parameters = cls.default_parameters(**kwargs)
return cls(parameters=parameters).run(
......@@ -54,8 +57,9 @@ class LineageProcess(PostProcessABC):
)
def get_lineage_information(self, signal=None, merged=True):
"""Get lineage as an array with tile IDs, mother labels, and corresponding bud labels."""
if signal is not None and "mother_label" in signal.index.names:
# from kymograph
lineage = get_index_as_np(signal)
elif hasattr(self, "lineage"):
lineage = self.lineage
......@@ -68,5 +72,5 @@ class LineageProcess(PostProcessABC):
elif self.cells is not None:
lineage = self.cells.mothers_daughters
else:
raise Exception("No linage information found")
raise Exception("No lineage information found")
return lineage
# change "prepost" to "preprocess"; change filename to postprocessor_engine.py ??
import typing as t
from itertools import takewhile
......@@ -13,7 +14,7 @@ from agora.utils.indexing import (
_3d_index_to_2d,
_assoc_indices_to_3d,
)
from agora.utils.merge import merge_association
from agora.utils.merge import merge_lineage
from postprocessor.core.abc import get_parameters, get_process
from postprocessor.core.lineageprocess import (
LineageProcess,
......@@ -61,36 +62,17 @@ class PostProcessorParameters(ParametersABC):
kind: list of str
If "ph_batman" included, add targets for experiments using pHlourin.
"""
# each subitem specifies the function to be called and the location
# on the h5 file to be written
# each subitem specifies the function to be called
# and the h5-file location for the results
#: why does merger have a string and picker a list?
targets = {
"prepost": {
"merger": "/extraction/general/None/area",
"picker": ["/extraction/general/None/area"],
},
"processes": [
[
"buddings",
["/extraction/general/None/volume"],
],
[
"dsignal",
[
"/extraction/general/None/volume",
],
],
[
"bud_metric",
[
"/extraction/general/None/volume",
],
],
[
"dsignal",
[
"/postprocessing/bud_metric/extraction_general_None_volume",
],
],
["buddings", ["/extraction/general/None/volume"]],
["bud_metric", ["/extraction/general/None/volume"]],
],
}
param_sets = {
......@@ -100,36 +82,13 @@ class PostProcessorParameters(ParametersABC):
}
}
outpaths = {}
outpaths["aggregate"] = "/postprocessing/experiment_wide/aggregated/"
# pHluorin experiments are special
if "ph_batman" in kind:
targets["processes"]["dsignal"].append(
[
"/extraction/em_ratio/np_max/mean",
"/extraction/em_ratio/np_max/median",
"/extraction/em_ratio_bgsub/np_max/mean",
"/extraction/em_ratio_bgsub/np_max/median",
]
)
targets["processes"]["aggregate"].append(
[
[
"/extraction/em_ratio/np_max/mean",
"/extraction/em_ratio/np_max/median",
"/extraction/em_ratio_bgsub/np_max/mean",
"/extraction/em_ratio_bgsub/np_max/median",
"/extraction/gsum/np_max/median",
"/extraction/gsum/np_max/mean",
]
],
)
return cls(targets=targets, param_sets=param_sets, outpaths=outpaths)
class PostProcessor(ProcessABC):
def __init__(self, filename, parameters):
"""
Initialise PostProcessor
Initialise PostProcessor.
Parameters
----------
@@ -139,9 +98,9 @@ class PostProcessor(ProcessABC):
An instance of PostProcessorParameters.
"""
super().__init__(parameters)
self._filename = filename
self._signal = Signal(filename)
self._writer = Writer(filename)
self.filename = filename
self.signal = Signal(filename)
self.writer = Writer(filename)
# parameters for merger and picker
dicted_params = {
i: parameters["param_sets"]["prepost"][i]
@@ -150,7 +109,7 @@ class PostProcessor(ProcessABC):
for k in dicted_params.keys():
if not isinstance(dicted_params[k], dict):
dicted_params[k] = dicted_params[k].to_dict()
# merger and picker
# initialise merger and picker
self.merger = Merger(
MergerParameters.from_dict(dicted_params["merger"])
)
@@ -158,12 +117,12 @@ class PostProcessor(ProcessABC):
PickerParameters.from_dict(dicted_params["picker"]),
cells=Cells.from_source(filename),
)
# processes, such as buddings
# get processes, such as buddings
self.classfun = {
process: get_process(process)
for process, _ in parameters["targets"]["processes"]
}
# parameters for the process in classfun
# get parameters for the processes in classfun
self.parameters_classfun = {
process: get_parameters(process)
for process, _ in parameters["targets"]["processes"]
@@ -172,64 +131,51 @@ class PostProcessor(ProcessABC):
self.targets = parameters["targets"]
def run_prepost(self):
"""Using picker, get and write lineages, returning mothers and daughters."""
"""Important processes run before normal post-processing ones"""
record = self._signal.get_raw(self.targets["prepost"]["merger"])
merges = np.array(self.merger.run(record), dtype=int)
self._writer.write(
"modifiers/merges", data=[np.array(x) for x in merges]
)
"""
Run merger, get lineages, and then run picker.
Necessary before any processes can run.
"""
# run merger
record = self.signal.get_raw(self.targets["prepost"]["merger"])
merges = self.merger.run(record)
# get lineages from cells object attached to picker
lineage = _assoc_indices_to_3d(self.picker.cells.mothers_daughters)
lineage_merged = []
if merges.any(): # Update lineages after merge events
merged_indices = merge_association(lineage, merges)
# Remove repeated labels post-merging
lineage_merged = np.unique(merged_indices, axis=0)
self.lineage = _3d_index_to_2d(
lineage_merged if len(lineage_merged) else lineage
if merges.any():
# update lineages and merges after merging
new_lineage, new_merges = merge_lineage(lineage, merges)
else:
new_lineage = lineage
new_merges = merges
self.lineage = _3d_index_to_2d(new_lineage)
self.writer.write(
"modifiers/merges", data=[np.array(x) for x in new_merges]
)
self._writer.write(
"modifiers/lineage_merged", _3d_index_to_2d(lineage_merged)
self.writer.write(
"modifiers/lineage_merged", _3d_index_to_2d(new_lineage)
)
# run picker
picked_indices = self.picker.run(
self._signal[self.targets["prepost"]["picker"][0]]
self.signal[self.targets["prepost"]["picker"][0]]
)
if picked_indices.any():
self._writer.write(
self.writer.write(
"modifiers/picks",
data=pd.MultiIndex.from_arrays(
picked_indices.T,
# names=["trap", "cell_label", "mother_label"],
names=["trap", "cell_label"],
picked_indices.T, names=["trap", "cell_label"]
),
overwrite="overwrite",
)
@staticmethod
def pick_mother(a, b):
"""Update the mother id following this priorities:
The mother has a lower id
"""
x = max(a, b)
if min([a, b]):
x = [a, b][np.argmin([a, b])]
return x
def run(self):
"""
Write the results to the h5 file.
Processes include identifying buddings and finding bud metrics.
"""
# run merger, picker, and find lineages
self.run_prepost()
# run processes
# run processes: process is a str; datasets is a list of str
for process, datasets in tqdm(self.targets["processes"]):
if process in self.parameters["param_sets"].get("processes", {}):
# parameters already assigned
@@ -237,25 +183,23 @@ class PostProcessor(ProcessABC):
self.parameters[process]
)
else:
# assign parameters
# assign default parameters
parameters = self.parameters_classfun[process].default()
# load process
# load process - instantiate an object in the class
loaded_process = self.classfun[process](parameters)
if isinstance(parameters, LineageProcessParameters):
loaded_process.lineage = self.lineage
# apply process to each dataset
for dataset in datasets:
self.run_process(dataset, process, loaded_process)
def run_process(self, dataset, process, loaded_process):
"""Run process on a single dataset and write the result."""
# define signal
"""Run process to obtain a single dataset and write the result."""
# get pre-processed data
if isinstance(dataset, list):
# multisignal process
signal = [self._signal[d] for d in dataset]
signal = [self.signal[d] for d in dataset]
elif isinstance(dataset, str):
signal = self._signal[dataset]
signal = self.signal[dataset]
else:
raise ("Incorrect dataset")
# run process on signal
@@ -269,8 +213,9 @@ class PostProcessor(ProcessABC):
[], columns=signal.columns, index=signal.index
)
result.columns.names = ["timepoint"]
# define outpath, where result will be written
# use outpath to write result
if process in self.parameters["outpaths"]:
# outpath already defined
outpath = self.parameters["outpaths"][process]
elif isinstance(dataset, list):
# no outpath is defined
@@ -317,4 +262,16 @@ class PostProcessor(ProcessABC):
result: t.Union[t.List, pd.DataFrame, np.ndarray],
metadata: t.Dict,
):
self._writer.write(path, result, meta=metadata, overwrite="overwrite")
self.writer.write(path, result, meta=metadata, overwrite="overwrite")
@staticmethod
def pick_mother(a, b):
"""
Update the mother id, assuming the mother has the lower id;
an id of zero (no mother) is ignored.
"""
x = max(a, b)
if min([a, b]):
x = [a, b][np.argmin([a, b])]
return x
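# Worked sketch (hypothetical ids): pick_mother(3, 5) returns 3, the
# lower non-zero id, whereas pick_mother(0, 4) returns 4 because an id
# of zero means no mother was assigned.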
import typing as t
from typing import Dict, Tuple
import numpy as np
import pandas as pd
@@ -12,17 +11,16 @@ from postprocessor.core.lineageprocess import (
class BudMetricParameters(LineageProcessParameters):
"""
Parameters
"""
"""Give default location of lineage information."""
_defaults = {"lineage_location": "postprocessing/lineage_merged"}
class BudMetric(LineageProcess):
"""
Requires mother-bud information to create a new dataframe where the indices are mother ids and
values are the daughters' values for a given signal.
Requires mother-bud information to create a new dataframe where the
indices are mother ids and values are the daughters' values for a
given signal.
"""
def __init__(self, parameters: BudMetricParameters):
@@ -31,94 +29,149 @@ class BudMetric(LineageProcess):
def run(
self,
signal: pd.DataFrame,
lineage: Dict[pd.Index, Tuple[pd.Index]] = None,
lineage: t.Dict[pd.Index, t.Tuple[pd.Index]] = None,
):
if lineage is None:
# define lineage
if hasattr(self, "lineage"):
lineage = self.lineage
else:
# lineage information in the Signal dataframe
assert "mother_label" in signal.index.names
lineage = signal.index.to_list()
return self.get_bud_metric(signal, mb_array_to_dict(lineage))
@staticmethod
def get_bud_metric(
signal: pd.DataFrame, md: Dict[Tuple, Tuple[Tuple]] = None
signal: pd.DataFrame, md: t.Dict[t.Tuple, t.Tuple[t.Tuple]] = None
):
"""
signal: Daughter-inclusive dataframe
md: Mother-daughters dictionary where each key is a mother's index and each value a list of daughter indices
Get fvi (First Valid Index) for all cells
Create empty matrix
for every mother:
- Get daughters' subdataframe
- sort daughters by cell label
- get series of fvis
- concatenate the values of these ranges from the dataframe
Fill the empty matrix
Convert matrix into dataframe using mother indices
Generate a dataframe of a Signal for buds indexed by their mothers,
concatenating data from all the buds for each mother.
Parameters
----------
signal: pd.DataFrame
A dataframe that includes data for both mothers and daughters.
md: dict
A dict of lineage information with each key a mother's index,
defined as (trap, cell_label), and the corresponding values are a
list of daughter indices, also defined as (trap, cell_label).
"""
mothers_mat = np.zeros((len(md), signal.shape[1]))
cells_were_dropped = 0 # Flag determines if mothers (1), daughters (2) or both were missing (3)
md_index = signal.index
if (
"mother_label" not in md_index.names
): # Generate mother label from md dict if unavailable
d = {v: k for k, values in md.items() for v in values}
# md_index should only comprise (trap, cell_label)
if "mother_label" not in md_index.names:
# dict with daughter indices as keys and mother indices as values
bud_dict = {v: k for k, values in md.items() for v in values}
# generate mother_label in Signal using the mother's cell_label
# cells with no mothers have a mother_label of 0
signal["mother_label"] = list(
map(lambda x: d.get(x, [0])[-1], signal.index)
map(lambda x: bud_dict.get(x, [0])[-1], signal.index)
)
signal.set_index("mother_label", append=True, inplace=True)
related_items = set(
[*md.keys(), *[y for x in md.values() for y in x]]
)
md_index = md_index.intersection(related_items)
elif "mother_label" in md_index.names:
md_index = md_index.droplevel("mother_label")
# combine mothers and daughter indices
mothers_index = md.keys()
daughters_index = [y for x in md.values() for y in x]
relations = set([*mothers_index, *daughters_index])
# keep from md_index only cells that are mother or daughters
md_index = md_index.intersection(relations)
else:
raise ("Unavailable relationship information")
md_index = md_index.droplevel("mother_label")
if len(md_index) < len(signal):
print("Dropped cells before bud_metric") # TODO log
print(
f"Dropped {len(signal) - len(md_index)} cells before applying bud_metric"
) # TODO log
# restrict signal to the cells in md_index, temporarily dropping the mother_label index level to do so
signal = (
signal.reset_index("mother_label")
.loc(axis=0)[md_index]
.set_index("mother_label", append=True)
)
names = list(signal.index.names)
del names[-2]
output_df = (
signal.loc[signal.index.get_level_values("mother_label") > 0]
.groupby(names)
.apply(lambda x: _combine_daughter_tracks(x))
# restrict to daughters: cells with a mother
mother_labels = signal.index.get_level_values("mother_label")
daughter_df = signal.loc[mother_labels > 0]
# join data for daughters with the same mother
output_df = daughter_df.groupby(["trap", "mother_label"]).apply(
combine_daughter_tracks
)
output_df.columns = signal.columns
output_df["padding_level"] = 0
output_df.set_index("padding_level", append=True, inplace=True)
# daughter data is indexed by mothers, which themselves have no mothers
output_df["temp_mother_label"] = 0
output_df.set_index("temp_mother_label", append=True, inplace=True)
if len(output_df):
output_df.index.names = signal.index.names
return output_df
def _combine_daughter_tracks(tracks: t.Collection[pd.Series]):
def combine_daughter_tracks(tracks: pd.DataFrame):
"""
Combine multiple time series of daughter cells into one time series.
Concatenate daughter values into one time series starting with the first
daughter and replacing later values with the values from the next daughter,
and so on.
Parameters
----------
tracks: a Signal
Data for all daughters, which are distinguished by different cell_labels,
for a particular trap and mother_label.
"""
# sort by daughter IDs
bud_df = tracks.sort_index(level="cell_label")
# remove multi-index
no_rows = len(bud_df)
bud_df.index = range(no_rows)
# find time point of first non-NaN data point of each row
init_tps = [
bud_df.iloc[irow].first_valid_index() for irow in range(no_rows)
]
# sort so that earliest daughter is first
sorted_rows = np.argsort(init_tps)
init_tps = np.sort(init_tps)
# combine data for all daughters
combined_tracks = np.nan * np.ones(tracks.columns.size)
for j, jrow in enumerate(sorted_rows):
# over-write with next earliest daughter
combined_tracks[bud_df.columns.get_loc(init_tps[j]) :] = (
bud_df.iloc[jrow].loc[init_tps[j] :].values
)
return pd.Series(combined_tracks, index=tracks.columns)
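# Worked sketch (hypothetical values) for two daughters of one mother:
# daughter a: [nan, 2.0, 2.5, 3.0, nan]
# daughter b: [nan, nan, nan, 1.0, 1.5]
# The combined series keeps the earliest daughter's values and is
# over-written from the next daughter's first valid time point onwards:
# combined:   [nan, 2.0, 2.5, 1.0, 1.5]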
def _combine_daughter_tracks_original(tracks: pd.DataFrame):
"""
Combine multiple time series of cells into one, overwriting values
prioritising the most recent entity.
Combine multiple time series of daughter cells into one time series.
At any one time, a mother cell should have only one daughter.
Two daughters are still sometimes present at the same time point, and we
then choose the daughter that appears first.
TODO We need to fix examples with more than one daughter at a time point.
Parameters
----------
tracks: a Signal
Data for all daughters, which are distinguished by different cell_labels,
for a particular trap and mother_label.
"""
sorted_da_ids = tracks.sort_index(level="cell_label")
sorted_da_ids.index = range(len(sorted_da_ids))
tp_fvt = sorted_da_ids.apply(lambda x: x.first_valid_index(), axis=0)
tp_fvt = sorted_da_ids.columns.get_indexer(tp_fvt)
tp_fvt[tp_fvt < 0] = len(sorted_da_ids) - 1
_metric = np.choose(tp_fvt, sorted_da_ids.values)
return pd.Series(_metric, index=tracks.columns)
# sort by daughter IDs
bud_df = tracks.sort_index(level="cell_label")
# remove multi-index
bud_df.index = range(len(bud_df))
# find which row of sorted_df has the daughter for each time point
tp_fvt: pd.Series = bud_df.apply(lambda x: x.first_valid_index(), axis=0)
# combine data for all daughters
combined_tracks = np.nan * np.ones(tracks.columns.size)
for bud_row in np.unique(tp_fvt.dropna().values).astype(int):
ilocs = np.where(tp_fvt.values == bud_row)[0]
combined_tracks[ilocs] = bud_df.values[bud_row, ilocs]
# TODO delete old version
tp_fvt = bud_df.columns.get_indexer(tp_fvt)
tp_fvt[tp_fvt == -1] = len(bud_df) - 1
old = np.choose(tp_fvt, bud_df.values)
assert (
(combined_tracks == old) | (np.isnan(combined_tracks) & np.isnan(old))
).all(), "yikes"
return pd.Series(combined_tracks, index=tracks.columns)
@@ -13,74 +13,69 @@ from postprocessor.core.lineageprocess import (
class buddingsParameters(LineageProcessParameters):
"""Parameter class to obtain budding events.
Parameters
----------
LineageProcessParameters : lineage_location
Location of lineage matrix to be used for calculations.
Examples
--------
FIXME: Add docs.
"""
"""Give the location of lineage information in the h5 file."""
_defaults = {"lineage_location": "postprocessing/lineage_merged"}
class buddings(LineageProcess):
"""
Calculate buddings in a trap assuming one mother per trap
returns a pandas series with the buddings.
Generate a dataframe of budding events.
We assume one mother per trap.
We define a budding event as the moment in which a bud was identified for
the first time, even if the bud is not considered one until later
in the experiment.
A bud may not be considered a bud until later in the experiment.
"""
def __init__(self, parameters: buddingsParameters):
"""Initialise buddings."""
super().__init__(parameters)
def run(
self, signal: pd.DataFrame, lineage: np.ndarray = None
) -> pd.DataFrame:
lineage = lineage or self.lineage
"""
Generate dataframe of budding events.
# Get time of first appearance for all cells
fvi = signal.apply(lambda x: x.first_valid_index(), axis=1)
Find daughters for those mothers in a Signal with lineage data.
Create a dataframe indicating the time each daughter first appears.
# Select mother cells in a given dataset
We use the data from Signal only to find when the daughters appear, by
their first non-NaN value.
"""
# lineage is (trap, mother, daughter)
lineage = self.lineage if lineage is None else lineage
# select traps and mothers in the signal that have lineage data
traps_mothers: t.Dict[tuple, list] = {
tuple(mo): [] for mo in lineage[:, :2] if tuple(mo) in signal.index
tuple(trap_mo): []
for trap_mo in lineage[:, :2]
if tuple(trap_mo) in signal.index
}
# add daughters, potentially multiple, for these traps and mothers
for trap, mother, daughter in lineage:
if (trap, mother) in traps_mothers.keys():
traps_mothers[(trap, mother)].append(daughter)
# a new dataframe with dimensions (n_mother_cells * n_tps)
mothers = signal.loc[
set(signal.index).intersection(traps_mothers.keys())
]
# Create a new dataframe with dimensions (n_mother_cells * n_timepoints)
buddings = pd.DataFrame(
np.zeros((mothers.shape[0], signal.shape[1])).astype(bool),
np.zeros(mothers.shape).astype(bool),
index=mothers.index,
columns=signal.columns,
)
buddings.columns.names = ["timepoint"]
# Fill the budding events
for mother_id, daughters in traps_mothers.items():
daughters_idx = set(
fvi.loc[
fvi.index.intersection(
list(product((mother_id[0],), daughters))
)
].values
).difference({0})
buddings.loc[
mother_id,
daughters_idx,
] = True
# get time of first non-NaN value of signal for every cell using Pandas
fvi = signal.apply(lambda x: x.first_valid_index(), axis=1)
# fill the budding events
for trap_mother_id, daughters in traps_mothers.items():
trap_daughter_ids = [
i for i in product((trap_mother_id[0],), daughters)
]
times_of_bud_appearance = fvi.loc[
fvi.index.intersection(trap_daughter_ids)
].values
# ignore zeros - buds in first image are not budding events
daughters_idx = set(times_of_bud_appearance).difference({0})
buddings.loc[trap_mother_id, daughters_idx] = True
return buddings
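# Illustrative sketch of the result (hypothetical labels): buddings is a
# boolean dataframe indexed like the mothers' rows of signal, with one
# column per time point; an entry is True at the time point when one of
# that mother's buds first appears in the data.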
from agora.abc import ParametersABC
import numpy as np
from agora.abc import ParametersABC
from postprocessor.core.abc import PostProcessABC
from postprocessor.core.functions.tracks import get_joinable
from postprocessor.core.functions.tracks import get_merges
class MergerParameters(ParametersABC):
@@ -10,11 +11,11 @@ class MergerParameters(ParametersABC):
There are five parameters expected in the dict:
smooth, boolean
smooth: boolean
Whether or not to smooth with a savgol_filter.
tol: float or int
tol: float or int
The threshold of average prediction error/std necessary to
consider two tracks the same.
consider two tracks to be the same.
If float, the threshold is the fraction of the first track;
if int, the threshold is in absolute units.
window: int
@@ -33,19 +34,20 @@ class MergerParameters(ParametersABC):
class Merger(PostProcessABC):
"""Combine rows of tracklet that are likely to be the same."""
"""Find array of pairs of (trap, cell) indices to be merged."""
def __init__(self, parameters):
super().__init__(parameters)
def run(self, signal):
joinable = []
if signal.shape[1] > 4:
joinable = get_joinable(
merges = get_merges(
signal,
smooth=self.parameters.smooth,
tol=self.parameters.tolerance,
window=self.parameters.window,
degree=self.parameters.degree,
)
return joinable
else:
merges = np.array([])
return merges
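# Illustrative sketch (assumed format): merges is expected to be an
# array of index pairs, each pairing two (trap, cell_label) tracks that
# the merger judges to belong to the same cell, e.g.
# np.array([[[3, 1], [3, 6]]]) for cells 1 and 6 of trap 3.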
@@ -5,13 +5,24 @@ import pandas as pd
from agora.abc import ParametersABC
from agora.io.cells import Cells
from agora.utils.indexing import validate_association
from agora.utils.indexing import validate_lineage
from agora.utils.cast import _str_to_int
from agora.utils.kymograph import drop_mother_label
from postprocessor.core.lineageprocess import LineageProcess
class PickerParameters(ParametersABC):
"""
A dictionary specifying the sequence of picks in order.
"lineage" is further specified by "mothers", "daughters", and
"families" (mother-bud pairs).
"condition" is further specified by "present", "any_present", or
"growing" and a threshold, either a number of time points or a
fraction of the total duration of the experiment.
"""
_defaults = {
"sequence": [
["lineage", "families"],
@@ -22,11 +33,8 @@ class PickerParameters(ParametersABC):
class Picker(LineageProcess):
"""
:cells: Cell object passed to the constructor
:condition: Tuple with condition and associated parameter(s), conditions can be
"present", "nonstoply_present" or "quantile".
Determine the thresholds or fractions of signals to use.
:lineage: str {"mothers", "daughters", "families" (mothers AND daughters), "orphans"}. Mothers/daughters picks cells with those tags, families pick the union of both and orphans the difference between the total and families.
Picker selects cells using lineage information and by
how and for how long they are retained in the data set.
"""
def __init__(
@@ -34,6 +42,7 @@ class Picker(LineageProcess):
parameters: PickerParameters,
cells: Cells or None = None,
):
"""Initialise picker."""
super().__init__(parameters=parameters)
self.cells = cells
@@ -43,93 +52,102 @@ class Picker(LineageProcess):
how: str,
mothers_daughters: t.Optional[np.ndarray] = None,
) -> pd.MultiIndex:
"""
Return rows of a signal using lineage information.
Rows correspond to either mothers, daughters, or mother-daughter
pairs.
"""
cells_present = drop_mother_label(signal.index)
mothers_daughters = self.get_lineage_information(signal)
valid_indices = slice(None)
if how == "mothers":
_, valid_indices = validate_association(
mothers_daughters, cells_present, match_column=0
)
elif how == "daughters":
_, valid_indices = validate_association(
mothers_daughters, cells_present, match_column=1
)
elif how == "families": # Mothers and daughters that are still present
_, valid_indices = validate_association(
mothers_daughters, cells_present
)
_, valid_indices = validate_lineage(
mothers_daughters, cells_present, how
)
return signal.index[valid_indices]
def pick_by_condition(self, signal, condition, thresh):
idx = self.switch_case(signal, condition, thresh)
return idx
def run(self, signal):
"""
Pick indices from the index of a signal's dataframe.
Typically, we first pick by lineage, then by condition.
The indices are returned as an array.
"""
self.orig_signal = signal
indices = set(signal.index)
lineage = self.get_lineage_information(signal)
if len(lineage):
self.mothers = lineage[:, :2]
self.mothers = lineage[:, [0, 1]]
self.daughters = lineage[:, [0, 2]]
for alg, *params in self.sequence:
new_indices = tuple()
if indices:
if alg == "lineage":
# pick by lineage
param1 = params[0]
new_indices = getattr(self, "pick_by_" + alg)(
new_indices = self.pick_by_lineage(
signal.loc[list(indices)], param1
)
else:
# pick by condition
param1, *param2 = params
new_indices = getattr(self, "pick_by_" + alg)(
new_indices = self.pick_by_condition(
signal.loc[list(indices)], param1, param2
)
new_indices = [tuple(x) for x in new_indices]
else:
new_indices = tuple()
# number of indices reduces for each iteration of the loop
indices = indices.intersection(new_indices)
else:
self._log(f"No lineage assignment")
self._log("No lineage assignment")
indices = np.array([])
return np.array([tuple(map(_str_to_int, x)) for x in indices])
# convert to array
indices_arr = np.array([tuple(map(_str_to_int, x)) for x in indices])
return indices_arr
def switch_case(
def pick_by_condition(
self,
signal: pd.DataFrame,
condition: str,
threshold: t.Union[float, int, list],
):
"""Pick indices from signal by any_present, present, and growing."""
if len(threshold) == 1:
threshold = [_as_int(*threshold, signal.shape[1])]
#: is this correct for "growing"?
case_mgr = {
"any_present": lambda s, thresh: any_present(s, thresh),
"present": lambda s, thresh: s.notna().sum(axis=1) > thresh,
"nonstoply_present": lambda s, thresh: s.apply(thresh, axis=1)
> thresh,
"growing": lambda s, thresh: s.diff(axis=1).sum(axis=1) > thresh,
"any_present": lambda s, threshold: any_present(s, threshold),
"present": lambda s, threshold: s.notna().sum(axis=1) > threshold,
"growing": lambda s, threshold: s.diff(axis=1).sum(axis=1)
> threshold,
}
return set(signal.index[case_mgr[condition](signal, *threshold)])
# apply condition
idx = set(signal.index[case_mgr[condition](signal, *threshold)])
new_indices = [tuple(x) for x in idx]
return new_indices
def _as_int(threshold: t.Union[float, int], ntps: int):
"""Convert a fraction of the total experiment duration into a number of time points."""
if type(threshold) is float:
threshold = ntps * threshold
return threshold
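# Example: _as_int(0.05, 400) converts a fractional threshold of 5% of a
# 400-time-point experiment into a threshold of 20 time points.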
def any_present(signal, threshold):
"""
Return a mask for cells, True if there is a cell in that trap that was present for more than :threshold: timepoints.
"""
"""Find traps where at least one cell stays for more than threshold time points."""
# all_traps contains repeated traps, which have more than one cell
all_traps = [x[0] for x in signal.index]
# for each cell, flag whether it is present for more than threshold time points, grouped by trap
full_traps = (signal.notna().sum(axis=1) > threshold).groupby("trap")
# expand to the size of signal.index: each row of trap_array marks
# the rows of signal.index belonging to one trap with a long-lived cell
trap_array = np.array(
[
np.isin(all_traps, trap_id) & full
for trap_id, full in full_traps.any().items()
]
)
# convert to pd.Series
any_present = pd.Series(
np.sum(
[
np.isin([x[0] for x in signal.index], i) & v
for i, v in (signal.notna().sum(axis=1) > threshold)
.groupby("trap")
.any()
.items()
],
axis=0,
).astype(bool),
index=signal.index,
np.sum(trap_array, axis=0).astype(bool), index=signal.index
)
return any_present
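# Illustrative sketch (hypothetical data): if trap 0 holds cells 1 and 2
# and only cell 1 is present for more than threshold time points, the
# rows for both cells of trap 0 are True; rows for traps with no such
# long-lived cell are False.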
@@ -86,15 +86,15 @@ class Grouper(ABC):
**kwargs,
):
"""
Concatenate data for one signal from different h5 files, with
one h5 file per position, into a dataframe.
Concatenate data for one signal from different h5 files, one for
each position, into a dataframe.
Parameters
----------
path : str
Signal location within h5py file
Signal location within h5 file.
pool : int
Number of threads used; if 0 or None only one core is used
Number of threads used; if 0 or None only one core is used.
mode: str
standard: boolean
**kwargs : key, value pairings
@@ -107,35 +107,38 @@ class Grouper(ABC):
if path.startswith("/"):
path = path.strip("/")
good_chains = self.filter_chains(path)
if standard:
fn_pos = concat_standard
else:
fn_pos = concat_signal_ind
kwargs["mode"] = mode
records = self.pool_function(
path=path,
f=fn_pos,
pool=pool,
chainers=good_chains,
**kwargs,
)
# check for errors
errors = [
k for kymo, k in zip(records, self.chainers.keys()) if kymo is None
]
records = [record for record in records if record is not None]
if len(errors):
print("Warning: Positions contain errors {errors}")
assert len(records), "All data sets contain errors"
# combine into one dataframe
concat = pd.concat(records, axis=0)
if len(concat.index.names) > 4:
# reorder levels in the multi-index dataframe when mother_label is present
concat = concat.reorder_levels(
("group", "position", "trap", "cell_label", "mother_label")
if good_chains:
if standard:
fn_pos = concat_standard
else:
fn_pos = concat_one_signal
kwargs["mode"] = mode
records = self.pool_function(
path=path,
f=fn_pos,
pool=pool,
chainers=good_chains,
**kwargs,
)
concat_sorted = concat.sort_index()
return concat_sorted
# check for errors
errors = [
k
for kymo, k in zip(records, self.chainers.keys())
if kymo is None
]
records = [record for record in records if record is not None]
if len(errors):
print("Warning: Positions contain errors {errors}")
assert len(records), "All data sets contain errors"
# combine into one dataframe
concat = pd.concat(records, axis=0)
if len(concat.index.names) > 4:
# reorder levels in the multi-index dataframe when mother_label is present
concat = concat.reorder_levels(
("group", "position", "trap", "cell_label", "mother_label")
)
concat_sorted = concat.sort_index()
return concat_sorted
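# Illustrative sketch of the result: the concatenated dataframe is
# assumed to stack every position's data, with rows indexed by
# (group, position, trap, cell_label), plus mother_label when lineage
# information is present, and columns corresponding to time points.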
def filter_chains(self, path: str) -> t.Dict[str, Chainer]:
"""Filter chains to those whose data is available in the h5 file."""
@@ -150,9 +153,6 @@ class Grouper(ABC):
f"Grouper:Warning: {nchains_dif} chains do not contain"
f" channel {path}"
)
assert len(
good_chains
), f"No valid dataset to use. Valid datasets are {self.available}"
return good_chains
def pool_function(
@@ -163,9 +163,8 @@ class Grouper(ABC):
chainers: t.Dict[str, Chainer] = None,
**kwargs,
):
"""Enable different threads for independent chains, particularly useful when aggregating multiple elements."""
if pool is None:
pass
"""Enable different threads for independent chains, particularly
useful when aggregating multiple elements."""
chainers = chainers or self.chainers
if pool:
with Pool(pool) as p:
@@ -267,17 +266,17 @@ class Grouper(ABC):
@property
def stages_span(self):
# FAILS on my example
# TODO: fails on my example
return self.fsignal.stages_span
@property
def max_span(self):
# FAILS on my example
# TODO: fails on my example
return self.fsignal.max_span
@property
def stages(self):
# FAILS on my example
# TODO: fails on my example
return self.fsignal.stages
@property
@@ -370,7 +369,7 @@ def concat_standard(
return combined
def concat_signal_ind(
def concat_one_signal(
path: str,
chainer: Chainer,
group: str,
@@ -56,9 +56,9 @@ def test_extractor(imgs, masks, tree):
for ch_branches in extractor.params.tree.values():
print(
extractor.reduce_extract(
red_metrics=ch_branches,
traps=[traps],
masks=[masks],
[traps],
[masks],
ch_branches,
labels={0: labels},
)
)