Compare revisions

Changes are shown as if the source revision was being merged into the target revision.

Commits on Source (31)
Showing changes with 850 additions and 586 deletions
......@@ -17,16 +17,14 @@ atomic = t.Union[int, float, str, bool]
class ParametersABC(ABC):
"""
Defines parameters as attributes and allows parameters to
Define parameters as attributes and allow parameters to
be converted to either a dictionary or to yaml.
No attribute should be called "parameters"!
"""
def __init__(self, **kwargs):
"""
Defines parameters as attributes
"""
"""Define parameters as attributes."""
assert (
"parameters" not in kwargs
), "No attribute should be named parameters"
......@@ -35,8 +33,9 @@ class ParametersABC(ABC):
def to_dict(self, iterable="null") -> t.Dict:
"""
Recursive function to return a nested dictionary of the
attributes of the class instance.
Return a nested dictionary of the attributes of the class instance.
Uses recursion.
"""
if isinstance(iterable, dict):
if any(
......@@ -62,7 +61,8 @@ class ParametersABC(ABC):
def to_yaml(self, path: Union[Path, str] = None):
"""
Returns a yaml stream of the attributes of the class instance.
Return a yaml stream of the attributes of the class instance.
If path is provided, the yaml stream is saved there.
Parameters
......@@ -81,9 +81,7 @@ class ParametersABC(ABC):
@classmethod
def from_yaml(cls, source: Union[Path, str]):
"""
Returns instance from a yaml filename or stdin
"""
"""Return instance from a yaml filename or stdin."""
is_buffer = True
try:
if Path(source).exists():
......@@ -107,7 +105,8 @@ class ParametersABC(ABC):
def update(self, name: str, new_value):
"""
Update values recursively
Update values recursively.
If name is a dictionary, replace data where it already exists and add it where it does not.
It warns against type changes.
......@@ -179,7 +178,8 @@ def add_to_collection(
class ProcessABC(ABC):
"""
Base class for processes.
Defines parameters as attributes and requires run method to be defined.
Define parameters as attributes and require a run method.
"""
def __init__(self, parameters):
......@@ -243,11 +243,9 @@ class StepABC(ProcessABC):
@timer
def run_tp(self, tp: int, **kwargs):
"""
Time and log the timing of a step.
"""
"""Time and log the timing of a step."""
return self._run_tp(tp, **kwargs)
def run(self):
# Replace run with run_tp
raise Warning("Steps use run_tp instead of run")
raise Warning("Steps use run_tp instead of run.")
This diff is collapsed.
......@@ -6,17 +6,19 @@ import typing as t
from functools import wraps
def _first_arg_str_to_df(
def _first_arg_str_to_raw_df(
fn: t.Callable,
):
"""Enable Signal-like classes to convert strings to data sets."""
@wraps(fn)
def format_input(*args, **kwargs):
cls = args[0]
data = args[1]
if isinstance(data, str):
# get data from h5 file
# get data from h5 file using Signal's get_raw
data = cls.get_raw(data)
# replace path in the undecorated function with data
return fn(cls, data, *args[2:], **kwargs)
return format_input
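With the decorator applied, the two calls below are equivalent: a string is converted into a dataframe via Signal's get_raw before the wrapped method runs (the dataset path is illustrative).
signal.apply_prepost("extraction/general/None/area")
signal.apply_prepost(signal.get_raw("extraction/general/None/area"))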
......@@ -66,7 +66,7 @@ class MetaData:
# Needed because HDF5 attributes do not support dictionaries
def flatten_dict(nested_dict, separator="/"):
"""
Flattens nested dictionary. If empty return as-is.
Flatten nested dictionary. If empty return as-is.
"""
flattened = {}
if nested_dict:
......@@ -79,9 +79,7 @@ def flatten_dict(nested_dict, separator="/"):
# Needed because HDF5 attributes do not support datetime objects
# Takes care of time zones & daylight saving
def datetime_to_timestamp(time, locale="Europe/London"):
"""
Convert datetime object to UNIX timestamp
"""
"""Convert datetime object to UNIX timestamp."""
return timezone(locale).localize(time).timestamp()
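For example (the acquisition time is made up), a summer time stamp receives the daylight-saving offset of "Europe/London" before conversion to seconds since the Unix epoch:
from datetime import datetime
datetime_to_timestamp(datetime(2023, 6, 1, 12, 0))  # float seconds, BST-adjusted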
......@@ -189,36 +187,37 @@ def parse_swainlab_metadata(filedir: t.Union[str, Path]):
Dictionary with minimal metadata
"""
filedir = Path(filedir)
filepath = find_file(filedir, "*.log")
if filepath:
# new log files
raw_parse = parse_from_swainlab_grammar(filepath)
minimal_meta = get_meta_swainlab(raw_parse)
else:
# old log files
if filedir.is_file() or str(filedir).endswith(".zarr"):
# log file is in parent directory
filedir = filedir.parent
legacy_parse = parse_logfiles(filedir)
minimal_meta = (
get_meta_from_legacy(legacy_parse) if legacy_parse else {}
)
return minimal_meta
def dispatch_metadata_parser(filepath: t.Union[str, Path]):
"""
Function to dispatch different metadata parsers that convert logfiles into a
basic metadata dictionary. Currently only contains the swainlab log parsers.
Dispatch different metadata parsers that convert logfiles into a dictionary.
Currently only contains the swainlab log parsers.
Input:
--------
filepath: str existing file containing metadata, or folder containing naming conventions
filepath: str existing file containing metadata, or folder containing naming
conventions
"""
parsed_meta = parse_swainlab_metadata(filepath)
if parsed_meta is None:
parsed_meta = dir_to_meta
return parsed_meta
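A hypothetical call (the path is illustrative) that returns the minimal metadata dictionary parsed from the experiment's log files:
minimal_meta = dispatch_metadata_parser("/data/experiments/2023_06_01_pH_run")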
......
......@@ -10,8 +10,8 @@ import numpy as np
import pandas as pd
from agora.io.bridge import BridgeH5
from agora.io.decorators import _first_arg_str_to_df
from agora.utils.indexing import validate_association
from agora.io.decorators import _first_arg_str_to_raw_df
from agora.utils.indexing import validate_lineage
from agora.utils.kymograph import add_index_levels
from agora.utils.merge import apply_merges
......@@ -20,11 +20,14 @@ class Signal(BridgeH5):
"""
Fetch data from h5 files for post-processing.
Signal assumes that the metadata and data are accessible to perform time-adjustments and apply previously recorded post-processes.
Signal assumes that the metadata and data are accessible to
perform time-adjustments and apply previously recorded
post-processes.
"""
def __init__(self, file: t.Union[str, Path]):
"""Define index_names for dataframes, candidate fluorescence channels, and composite statistics."""
"""Define index_names for dataframes, candidate fluorescence channels,
and composite statistics."""
super().__init__(file, flag=None)
self.index_names = (
"experiment",
......@@ -46,9 +49,9 @@ class Signal(BridgeH5):
def __getitem__(self, dsets: t.Union[str, t.Collection]):
"""Get and potentially pre-process data from h5 file and return as a dataframe."""
if isinstance(dsets, str): # no pre-processing
if isinstance(dsets, str):
return self.get(dsets)
elif isinstance(dsets, list): # pre-processing
elif isinstance(dsets, list):
is_bgd = [dset.endswith("imBackground") for dset in dsets]
# Check we are not comparing tile-indexed and cell-indexed data
assert sum(is_bgd) == 0 or sum(is_bgd) == len(
......@@ -58,22 +61,23 @@ class Signal(BridgeH5):
else:
raise Exception(f"Invalid type {type(dsets)} to get datasets")
def get(self, dsets: t.Union[str, t.Collection], **kwargs):
"""Get and potentially pre-process data from h5 file and return as a dataframe."""
if isinstance(dsets, str): # no pre-processing
df = self.get_raw(dsets, **kwargs)
def get(self, dset_name: t.Union[str, t.Collection], **kwargs):
"""Return pre-processed data as a dataframe."""
if isinstance(dset_name, str):
dsets = self.get_raw(dset_name, **kwargs)
prepost_applied = self.apply_prepost(dsets, **kwargs)
return self.add_name(prepost_applied, dsets)
return self.add_name(prepost_applied, dset_name)
else:
raise Exception("Error in Signal.get")
@staticmethod
def add_name(df, name):
"""Add column of identical strings to a dataframe."""
"""Add name of the Signal as an attribute to its corresponding dataframe."""
df.name = name
return df
def cols_in_mins(self, df: pd.DataFrame):
# Convert numerical columns in a dataframe to minutes
"""Convert numerical columns in a dataframe to minutes."""
try:
df.columns = (df.columns * self.tinterval // 60).astype(int)
except Exception as e:
......@@ -94,14 +98,15 @@ class Signal(BridgeH5):
if tinterval_location in f.attrs:
return f.attrs[tinterval_location][0]
else:
logging.getlogger("aliby").warn(
logging.getLogger("aliby").warn(
f"{str(self.filename).split('/')[-1]}: using default time interval of 5 minutes"
)
return 5
@staticmethod
def get_retained(df, cutoff):
"""Return a fraction of the df, one without later time points."""
"""Return rows of df with at least cutoff fraction of the total number
of time points."""
return df.loc[bn.nansum(df.notna(), axis=1) > df.shape[1] * cutoff]
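A toy illustration of the retention rule (data are made up): with cutoff=0.8 and five time points, only tracks with data in more than four time points survive.
import numpy as np
import pandas as pd
df = pd.DataFrame(
    [[1.0, 2.0, 3.0, 4.0, 5.0],
     [1.0, 2.0, np.nan, np.nan, np.nan]]
)
Signal.get_retained(df, cutoff=0.8)  # keeps only the first row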
@property
......@@ -110,15 +115,15 @@ class Signal(BridgeH5):
with h5py.File(self.filename, "r") as f:
return list(f.attrs["channels"])
@_first_arg_str_to_df
def retained(self, signal, cutoff=0.8):
"""
Load data (via decorator) and reduce the resulting dataframe.
Load data for a signal or a list of signals and reduce the resulting
dataframes to a fraction of their original size, losing late time
points.
dataframes to rows with sufficient numbers of time points.
"""
if isinstance(signal, str):
signal = self.get_raw(signal)
if isinstance(signal, pd.DataFrame):
return self.get_retained(signal, cutoff)
elif isinstance(signal, list):
......@@ -131,17 +136,15 @@ class Signal(BridgeH5):
"""
Get lineage data from a given location in the h5 file.
Returns an array with three columns: the tile id, the mother label, and the daughter label.
Returns an array with three columns: the tile id, the mother label,
and the daughter label.
"""
if lineage_location is None:
lineage_location = "modifiers/lineage_merged"
with h5py.File(self.filename, "r") as f:
# if lineage_location not in f:
# lineage_location = lineage_location.split("_")[0]
if lineage_location not in f:
lineage_location = "postprocessing/lineage"
tile_mo_da = f[lineage_location]
if isinstance(tile_mo_da, h5py.Dataset):
lineage = tile_mo_da[()]
else:
......@@ -154,7 +157,7 @@ class Signal(BridgeH5):
).T
return lineage
@_first_arg_str_to_df
@_first_arg_str_to_raw_df
def apply_prepost(
self,
data: t.Union[str, pd.DataFrame],
......@@ -272,7 +275,7 @@ class Signal(BridgeH5):
Parameters
----------
dataset: str or list of strs
The name of the h5 file or a list of h5 file names
The name of the h5 file or a list of h5 file names.
in_minutes: boolean
If True,
lineage: boolean
......@@ -288,15 +291,17 @@ class Signal(BridgeH5):
self.get_raw(dset, in_minutes=in_minutes, lineage=lineage)
for dset in dataset
]
if lineage: # assume that df is sorted
if lineage:
# assume that df is sorted
mother_label = np.zeros(len(df), dtype=int)
lineage = self.lineage()
a, b = validate_association(
# information on buds
valid_lineage, valid_indices = validate_lineage(
lineage,
np.array(df.index.to_list()),
match_column=1,
"daughters",
)
mother_label[b] = lineage[a, 1]
mother_label[valid_indices] = lineage[valid_lineage, 1]
df = add_index_levels(df, {"mother_label": mother_label})
return df
except Exception as e:
......@@ -353,10 +358,7 @@ class Signal(BridgeH5):
fullname: str,
node: t.Union[h5py.Dataset, h5py.Group],
):
"""
Store the name of a signal if it is a leaf node
(a group with no more groups inside) and if it starts with extraction.
"""
"""Store the name of a signal if it is a leaf node and if it starts with extraction."""
if isinstance(node, h5py.Group) and np.all(
[isinstance(x, h5py.Dataset) for x in node.values()]
):
......
......@@ -10,6 +10,102 @@ import numpy as np
import typing as t
def validate_lineage(
lineage: np.ndarray, indices: np.ndarray, how: str = "families"
):
"""
Identify mother-bud pairs that exist both in lineage and a Signal's
indices.
We expect the lineage information to be unique: a bud should not have
two mothers.
Parameters
----------
lineage : np.ndarray
2D array of lineage associations where columns are
(trap, mother, daughter)
or
a 3D array, which is an array of 2 X 2 arrays comprising
[[trap_id, mother_label], [trap_id, daughter_label]].
indices : np.ndarray
A 2D array of cell indices from a Signal, (trap_id, cell_label).
This array should not include mother_label.
how: str
If "mothers", matches indicate mothers from mother-bud pairs;
If "daughters", matches indicate daughters from mother-bud pairs;
If "families", matches indicate mothers and daughters in mother-bud pairs.
Returns
-------
valid_lineage: boolean np.ndarray
1D array indicating matched elements in lineage.
valid_indices: boolean np.ndarray
1D array indicating matched elements in indices.
Examples
--------
>>> import numpy as np
>>> from agora.utils.indexing import validate_lineage
>>> lineage = np.array([ [[0, 1], [0, 3]], [[0, 1], [0, 4]], [[0, 1], [0, 6]], [[0, 4], [0, 7]] ])
>>> indices = np.array([ [0, 1], [0, 2], [0, 3]])
>>> valid_lineage, valid_indices = validate_lineage(lineage, indices)
>>> print(valid_lineage)
array([ True, False, False, False])
>>> print(valid_indices)
array([ True, False, True])
and
>>> lineage = np.array([[[0,3], [0,1]], [[0,2], [0,4]]])
>>> indices = np.array([[0,1], [0,2], [0,3]])
>>> valid_lineage, valid_indices = validate_lineage(lineage, indices)
>>> print(valid_lineage)
array([ True, False])
>>> print(valid_indices)
array([ True, False, True])
"""
if lineage.ndim == 2:
# [trap, mother, daughter] becomes [[trap, mother], [trap, daughter]]
lineage = _assoc_indices_to_3d(lineage)
if how == "mothers":
c_index = 0
elif how == "daughters":
c_index = 1
# data type to link together trap and cell ids
dtype = {"names": ["trap_id", "cell_id"], "formats": [np.int64, np.int64]}
lineage = np.ascontiguousarray(lineage, dtype=np.int64)
# find (trap, cell_ids) in intersection
inboth = np.intersect1d(lineage.view(dtype), indices.view(dtype))
# find valid lineage
valid_lineages = np.isin(lineage.view(dtype), inboth)
if how == "families":
# both mother and bud must be in indices
valid_lineage = valid_lineages.all(axis=1)
else:
valid_lineage = valid_lineages[:, c_index, :]
# find valid indices
selected_lineages = lineage[valid_lineage.flatten(), ...]
if how == "families":
# select only pairs of mother and bud indices
valid_indices = np.isin(
indices.view(dtype), selected_lineages.view(dtype)
)
else:
valid_indices = np.isin(
indices.view(dtype), selected_lineages.view(dtype)[:, c_index, :]
)
if valid_indices[valid_indices].size != valid_lineage[valid_lineage].size:
raise Exception(
"Error in validate_lineage: "
"lineage information is likely not unique."
)
return valid_lineage.flatten(), valid_indices.flatten()
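The row matching relies on viewing each (trap_id, cell_label) pair as one structured record, so np.intersect1d and np.isin compare whole rows at once; a standalone sketch with made-up indices:
import numpy as np
dtype = {"names": ["trap_id", "cell_id"], "formats": [np.int64, np.int64]}
in_lineage = np.ascontiguousarray([[0, 1], [0, 3], [1, 5]], dtype=np.int64)
in_signal = np.ascontiguousarray([[0, 3], [1, 5], [2, 7]], dtype=np.int64)
# rows present in both arrays, compared as whole (trap, cell) records
in_both = np.intersect1d(in_lineage.view(dtype), in_signal.view(dtype))
np.isin(in_lineage.view(dtype), in_both).flatten()  # array([False, True, True])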
def validate_association(
association: np.ndarray,
indices: np.ndarray,
......@@ -133,9 +229,7 @@ def _assoc_indices_to_3d(ndarray: np.ndarray):
def _3d_index_to_2d(array: np.ndarray):
"""
Opposite to _assoc_indices_to_3d.
"""
"""Revert _assoc_indices_to_3d."""
result = array
if len(array):
result = np.concatenate(
......
#!/usr/bin/env jupyter
"""
Utilities based on association are used to efficiently acquire indices of
tracklets with some kind of relationship.
This can be:
- Cells that are to be merged.
- Cells that have a lineage relationship.
"""
import numpy as np
import typing as t
def validate_association(
association: np.ndarray,
indices: np.ndarray,
match_column: t.Optional[int] = None,
) -> t.Tuple[np.ndarray, np.ndarray]:
"""
Identify mother-bud pairs that exist both in lineage and a Signal's indices.
"""
if association.ndim == 2:
# reshape into 3D array for broadcasting
# for each trap, [trap, mother, daughter] becomes
# [[trap, mother], [trap, daughter]]
association = _assoc_indices_to_3d(association)
valid_association, valid_indices = validate_lineage(association, indices)
# Alan's working code
# Compare existing association with available indices
# Swap trap and label axes for the association array to correctly cast
valid_ndassociation_a = association[..., None] == indices.T[None, ...]
# Broadcasting is confusing (but efficient):
# First we check the dimension across trap and cell id, to ensure both match
valid_cell_ids_a = valid_ndassociation_a.all(axis=2)
if match_column is None:
# Then we check the merge tuples to check which cases have both target and source
valid_association_a = valid_cell_ids_a.any(axis=2).all(axis=1)
# Finally we check the dimension that crosses all indices, to ensure the pair
# is present in a valid merge event.
valid_indices_a = (
valid_ndassociation_a[valid_association_a]
.all(axis=2)
.any(axis=(0, 1))
)
else: # We fetch specific indices if we aim for the ones with one present
valid_indices_a = valid_cell_ids_a[:, match_column].any(axis=0)
# Valid association then becomes a boolean array, true means that there is a
# match (match_column) between that cell and the index
valid_association_a = (
valid_cell_ids_a[:, match_column] & valid_indices
).any(axis=1)
assert np.array_equal(
valid_association, valid_association_a
), "valid_association error"
assert np.array_equal(
valid_indices, valid_indices_a
), "valid_indices error"
return valid_association, valid_indices
def validate_association_old(
association: np.ndarray,
indices: np.ndarray,
match_column: t.Optional[int] = None,
) -> t.Tuple[np.ndarray, np.ndarray]:
"""
Identify mother-bud pairs that exist both in lineage and a Signal's indices.
Parameters
----------
association : np.ndarray
2D array of lineage associations where columns are (trap, mother, daughter)
or
a 3D array, which is an array of 2 X 2 arrays comprising [[trap_id, mother_label], [trap_id, daughter_label]].
indices : np.ndarray
A 2D array where each column is a different level, such as (trap_id, cell_label), which typically is an index of a Signal
dataframe. This array should not include mother_label.
match_column: int
If 0, matches indicate mothers from mother-bud pairs;
If 1, matches indicate daughters from mother-bud pairs;
If None, matches indicate either mothers or daughters in mother-bud pairs.
Returns
-------
valid_association: boolean np.ndarray
1D array indicating elements in association with matches.
valid_indices: boolean np.ndarray
1D array indicating elements in indices with matches.
Examples
--------
>>> import numpy as np
>>> from agora.utils.indexing import validate_association
>>> association = np.array([ [[0, 1], [0, 3]], [[0, 1], [0, 4]], [[0, 1], [0, 6]], [[0, 4], [0, 7]] ])
>>> indices = np.array([ [0, 1], [0, 2], [0, 3]])
>>> print(indices.T)
>>> valid_association, valid_indices = validate_association(association, indices)
>>> print(valid_association)
array([ True, False, False, False])
>>> print(valid_indices)
array([ True, False, True])
and
>>> association = np.array([[[0,3], [0,1]], [[0,2], [0,4]]])
>>> indices = np.array([[0,1], [0,2], [0,3]])
>>> valid_association, valid_indices = validate_association(association, indices)
>>> print(valid_association)
array([ True, False])
>>> print(valid_indices)
array([ True, False, True])
"""
if association.ndim == 2:
# reshape into 3D array for broadcasting
# for each trap, [trap, mother, daughter] becomes
# [[trap, mother], [trap, daughter]]
association = _assoc_indices_to_3d(association)
# use broadcasting to compare association with indices
# swap trap and cell_label axes for correct broadcasting
indicesT = indices.T
# compare each of [[trap, mother], [trap, daughter]] for all traps
# in association with [trap, cell_label] for all traps in indices
# association is no_traps x 2 x 2; indices is no_traps X 2
# valid_ndassociation is no_traps_association x 2 x 2 x no_traps_indices
valid_ndassociation = (
association[..., np.newaxis] == indicesT[np.newaxis, ...]
)
# find matches in association
###
# make True comparisons with both trap_ids and cell labels matching
# compare trap_ids and cell_ids for each pair of traps
valid_cell_ids = valid_ndassociation.all(axis=2)
if match_column is None:
# make True comparisons match at least one row in indices
# at least one cell_id matches
va_intermediate = valid_cell_ids.any(axis=2)
# make True comparisons have both mother and bud matching rows in indices
valid_association = va_intermediate.all(axis=1)
else:
# match_column selects mothers if 0 and daughters if 1
# make True match at least one row in indices
valid_association = valid_cell_ids[:, match_column].any(axis=1)
# find matches in indices
###
# make True comparisons have a validated association for both the mother and bud
# make True comparisons have both trap_ids and cell labels matching
valid_cell_ids_va = valid_ndassociation[valid_association].all(axis=2)
if match_column is None:
# make True comparisons match either a mother or a bud in association
valid_indices = valid_cell_ids_va.any(axis=(0, 1))
else:
valid_indices = valid_cell_ids_va[:, match_column][0]
# Alan's working code
# Compare existing association with available indices
# Swap trap and label axes for the association array to correctly cast
valid_ndassociation_a = association[..., None] == indices.T[None, ...]
# Broadcasting is confusing (but efficient):
# First we check the dimension across trap and cell id, to ensure both match
valid_cell_ids_a = valid_ndassociation_a.all(axis=2)
if match_column is None:
# Then we check the merge tuples to check which cases have both target and source
valid_association_a = valid_cell_ids_a.any(axis=2).all(axis=1)
# Finally we check the dimension that crosses all indices, to ensure the pair
# is present in a valid merge event.
valid_indices_a = (
valid_ndassociation_a[valid_association_a]
.all(axis=2)
.any(axis=(0, 1))
)
else: # We fetch specific indices if we aim for the ones with one present
valid_indices_a = valid_cell_ids_a[:, match_column].any(axis=0)
# Valid association then becomes a boolean array, true means that there is a
# match (match_column) between that cell and the index
valid_association_a = (
valid_cell_ids_a[:, match_column] & valid_indices
).any(axis=1)
assert np.array_equal(
valid_association, valid_association_a
), "valid_association error"
assert np.array_equal(
valid_indices, valid_indices_a
), "valid_indices error"
return valid_association, valid_indices
def _assoc_indices_to_3d(ndarray: np.ndarray):
"""
Reorganise an array of shape (N, 3) into one of shape (N, 2, 2).
Reorganise an array so that the last entry of each row is removed
and generates a new row. This new row retains all other entries of
the original row.
Example:
[ [0, 1, 3], [0, 1, 4] ]
becomes
[ [[0, 1], [0, 3]], [[0, 1], [0, 4]] ]
"""
result = ndarray
if len(ndarray) and ndarray.ndim > 1:
if ndarray.shape[1] == 3:
# faster indexing for single positions
result = np.transpose(
np.hstack((ndarray[:, [0]], ndarray)).reshape(-1, 2, 2),
axes=[0, 2, 1],
)
else:
# 20% slower, but more general indexing
columns = np.arange(ndarray.shape[1])
result = np.stack(
(
ndarray[:, np.delete(columns, -1)],
ndarray[:, np.delete(columns, -2)],
),
axis=1,
)
return result
def _3d_index_to_2d(array: np.ndarray):
"""Revert switch from _assoc_indices_to_3d."""
result = array
if len(array):
result = np.concatenate(
(array[:, 0, :], array[:, 1, 1, np.newaxis]), axis=1
)
return result
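A quick round trip using this module's helpers, with the values from the docstring example:
import numpy as np
lineage_2d = np.array([[0, 1, 3], [0, 1, 4]])
lineage_3d = _assoc_indices_to_3d(lineage_2d)
# array([[[0, 1], [0, 3]],
#        [[0, 1], [0, 4]]])
np.array_equal(_3d_index_to_2d(lineage_3d), lineage_2d)  # True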
def compare_indices(x: np.ndarray, y: np.ndarray) -> np.ndarray:
"""
Compare two 2D arrays using broadcasting.
Return a binary array where a True value links two cells where
all cells are the same.
"""
return (x[..., np.newaxis] == y.T[np.newaxis, ...]).all(axis=1)
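A toy example of the broadcast comparison (indices are made up); only row [0, 2] of x also appears in y:
import numpy as np
x = np.array([[0, 1], [0, 2]])
y = np.array([[0, 2], [1, 1]])
compare_indices(x, y)
# array([[False, False],
#        [ True, False]])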
......@@ -86,16 +86,19 @@ def bidirectional_retainment_filter(
daughters_thresh: int = 7,
) -> pd.DataFrame:
"""
Retrieve families where mothers are present for more than a fraction of the experiment, and daughters for longer than some number of time-points.
Retrieve families where mothers are present for more than a fraction
of the experiment and daughters for longer than some number of
time-points.
Parameters
----------
df: pd.DataFrame
Data
mothers_thresh: float
Minimum fraction of experiment's total duration for which mothers must be present.
Minimum fraction of experiment's total duration for which mothers
must be present.
daughters_thresh: int
Minimum number of time points for which daughters must be observed
Minimum number of time points for which daughters must be observed.
"""
# daughters
all_daughters = df.loc[df.index.get_level_values("mother_label") > 0]
......@@ -170,6 +173,7 @@ def slices_from_spans(spans: t.Tuple[int], df: pd.DataFrame) -> t.List[slice]:
def drop_mother_label(index: pd.MultiIndex) -> np.ndarray:
"""Remove mother_label level from a MultiIndex."""
no_mother_label = index
if "mother_label" in index.names:
no_mother_label = index.droplevel("mother_label")
......
#!/usr/bin/env python3
import re
import typing as t
import numpy as np
import pandas as pd
from agora.io.bridge import groupsort
from itertools import groupby
def mb_array_to_dict(mb_array: np.ndarray):
......@@ -19,4 +15,3 @@ def mb_array_to_dict(mb_array: np.ndarray):
for trap, mo_da in groupsort(mb_array).items()
for mo, daughters in groupsort(mo_da).items()
}
......@@ -13,8 +13,11 @@ from agora.utils.indexing import compare_indices, validate_association
def apply_merges(data: pd.DataFrame, merges: np.ndarray):
"""Split data in two, one subset for rows relevant for merging and one
without them. It uses an array of source tracklets and target tracklets
"""
Split data in two, one subset for rows relevant for merging and one
without them.
Use an array of source tracklets and target tracklets
to efficiently merge them.
Parameters
......@@ -43,9 +46,9 @@ def apply_merges(data: pd.DataFrame, merges: np.ndarray):
# Implement the merges and drop source rows.
# TODO Use matrices to perform merges in batch
# for ecficiency
# for efficiency
if valid_merges.any():
to_merge = data.loc[indices]
to_merge = data.loc[indices].copy()
targets, sources = zip(*merges[valid_merges])
for source, target in zip(sources, targets):
target = tuple(target)
......
......@@ -54,7 +54,7 @@ class DatasetLocalABC(ABC):
Abstract Base class to find local files, either OME-XML or raw images.
"""
_valid_suffixes = ("tiff", "png", "zarr")
_valid_suffixes = ("tiff", "png", "zarr", "tif")
_valid_meta_suffixes = ("txt", "log")
def __init__(self, dpath: t.Union[str, Path], *args, **kwargs):
......
......@@ -30,14 +30,14 @@ from agora.io.metadata import dir_to_meta, dispatch_metadata_parser
def get_examples_dir():
"""Get examples directory which stores dummy image for tiler"""
"""Get examples directory that stores dummy image for tiler."""
return files("aliby").parent.parent / "examples" / "tiler"
def instantiate_image(
source: t.Union[str, int, t.Dict[str, str], Path], **kwargs
):
"""Wrapper to instatiate the appropiate image
"""Wrapper to instantiate the appropriate image
Parameters
----------
......@@ -55,26 +55,26 @@ def instantiate_image(
def dispatch_image(source: t.Union[str, int, t.Dict[str, str], Path]):
"""
Wrapper to pick the appropiate Image class depending on the source of data.
"""
"""Pick the appropriate Image class depending on the source of data."""
if isinstance(source, (int, np.int64)):
from aliby.io.omero import Image
instatiator = Image
instantiator = Image
elif isinstance(source, dict) or (
isinstance(source, (str, Path)) and Path(source).is_dir()
):
if Path(source).suffix == ".zarr":
instatiator = ImageZarr
instantiator = ImageZarr
else:
instatiator = ImageDir
instantiator = ImageDir
elif isinstance(source, Path) and source.is_file():
# my addition
instantiator = ImageLocalOME
elif isinstance(source, str) and Path(source).is_file():
instatiator = ImageLocalOME
instantiator = ImageLocalOME
else:
raise Exception(f"Invalid data source at {source}")
return instatiator
return instantiator
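Hypothetical usage (the path is illustrative and assumed to be an existing zarr directory): dispatch_image returns a class, which the caller then instantiates with the same source.
ImageClass = dispatch_image("/data/run1/pos001.zarr")  # ImageZarr for an existing zarr store
image = ImageClass("/data/run1/pos001.zarr")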
class BaseLocalImage(ABC):
......@@ -82,6 +82,7 @@ class BaseLocalImage(ABC):
Base Image class to set path and provide context management method.
"""
# default image order
_default_dimorder = "tczyx"
def __init__(self, path: t.Union[str, Path]):
......@@ -98,8 +99,7 @@ class BaseLocalImage(ABC):
return False
def rechunk_data(self, img):
# Format image using x and y size from metadata.
"""Format image using x and y size from metadata."""
self._rechunked_img = da.rechunk(
img,
chunks=(
......@@ -145,16 +145,16 @@ class ImageLocalOME(BaseLocalImage):
in which a multidimensional tiff image contains the metadata.
"""
def __init__(self, path: str, dimorder=None):
def __init__(self, path: str, dimorder=None, **kwargs):
super().__init__(path)
self._id = str(path)
self.set_meta(str(path))
def set_meta(self):
def set_meta(self, path):
meta = dict()
try:
with TiffFile(path) as f:
self._meta = xmltodict.parse(f.ome_metadata)["OME"]
for dim in self.dimorder:
meta["size_" + dim.lower()] = int(
self._meta["Image"]["Pixels"]["@Size" + dim]
......@@ -165,21 +165,19 @@ class ImageLocalOME(BaseLocalImage):
]
meta["name"] = self._meta["Image"]["@Name"]
meta["type"] = self._meta["Image"]["Pixels"]["@Type"]
except Exception as e: # Images not in OMEXML
except Exception as e:
# images not in OMEXML
print("Warning:Metadata not found: {}".format(e))
print(
f"Warning: No dimensional info provided. Assuming {self._default_dimorder}"
"Warning: No dimensional info provided. "
f"Assuming {self._default_dimorder}"
)
# Mark non-existent dimensions for padding
# mark non-existent dimensions for padding
self.base = self._default_dimorder
# self.ids = [self.index(i) for i in dimorder]
self._dimorder = base
self._dimorder = self.base
self._meta = meta
# self._meta["name"] = Path(path).name.split(".")[0]
@property
def name(self):
......@@ -246,7 +244,7 @@ class ImageDir(BaseLocalImage):
It inherits from BaseLocalImage so we only override methods that are critical.
Assumptions:
- One folders per position.
- One folder per position.
- Images are flat.
- Channel, Time, z-stack and the others are determined by filenames.
  - Provides Dimorder as it is set in the filenames, or expects order during instantiation
......@@ -318,7 +316,7 @@ class ImageZarr(BaseLocalImage):
print(f"Could not add size info to metadata: {e}")
def get_data_lazy(self) -> da.Array:
"""Return 5D dask array. For lazy-loading local multidimensional zarr files"""
"""Return 5D dask array for lazy-loading local multidimensional zarr files."""
return self._img
def add_size_to_meta(self):
......
......@@ -154,6 +154,7 @@ class PipelineParameters(ParametersABC):
defaults["tiler"]["backup_ref_channel"] = backup_ref_channel
defaults["baby"] = BabyParameters.default(**baby).to_dict()
# why are BabyParameters here as an alternative?
defaults["extraction"] = (
exparams_from_meta(meta_d)
or BabyParameters.default(**extraction).to_dict()
......@@ -320,7 +321,7 @@ class Pipeline(ProcessABC):
)
# get log files, either locally or via OMERO
with dispatcher as conn:
image_ids = conn.get_images()
position_ids = conn.get_images()
directory = self.store or root_dir / conn.unique_name
if not directory.exists():
directory.mkdir(parents=True)
......@@ -330,29 +331,29 @@ class Pipeline(ProcessABC):
self.parameters.general["directory"] = str(directory)
config["general"]["directory"] = directory
self.setLogger(directory)
# pick particular images if desired
# pick particular positions if desired
if pos_filter is not None:
if isinstance(pos_filter, list):
image_ids = {
position_ids = {
k: v
for filt in pos_filter
for k, v in self.apply_filter(image_ids, filt).items()
for k, v in self.apply_filter(position_ids, filt).items()
}
else:
image_ids = self.apply_filter(image_ids, pos_filter)
assert len(image_ids), "No images to segment"
position_ids = self.apply_filter(position_ids, pos_filter)
assert len(position_ids), "No images to segment"
# create pipelines
if distributed != 0:
# multiple cores
with Pool(distributed) as p:
results = p.map(
lambda x: self.run_one_position(*x),
[(k, i) for i, k in enumerate(image_ids.items())],
[(k, i) for i, k in enumerate(position_ids.items())],
)
else:
# single core
results = []
for k, v in tqdm(image_ids.items()):
for k, v in tqdm(position_ids.items()):
r = self.run_one_position((k, v), 1)
results.append(r)
return results
......@@ -432,6 +433,7 @@ class Pipeline(ProcessABC):
if process_from["extraction"] < tps:
# TODO Move this parameter validation into Extractor
av_channels = set((*steps["tiler"].channels, "general"))
# overwrite extraction specified by PipelineParameters !!
config["extraction"]["tree"] = {
k: v
for k, v in config["extraction"]["tree"].items()
......@@ -453,13 +455,14 @@ class Pipeline(ProcessABC):
steps["extraction"] = Extractor.from_tiler(
exparams, store=filename, tiler=steps["tiler"]
)
# set up progress meter
# set up progress bar
pbar = tqdm(
range(min_process_from, tps),
desc=image.name,
initial=min_process_from,
total=tps,
)
# run through time points
for i in pbar:
if (
frac_clogged_traps
......@@ -469,9 +472,12 @@ class Pipeline(ProcessABC):
# run through steps
for step in self.pipeline_steps:
if i >= process_from[step]:
# perform step
result = steps[step].run_tp(
i, **run_kwargs.get(step, {})
)
# write to h5 file using writers
# extractor writes to h5 itself
if step in loaded_writers:
loaded_writers[step].write(
data=result,
......@@ -481,7 +487,7 @@ class Pipeline(ProcessABC):
tp=i,
meta={"last_processed": i},
)
# perform step
# clean up
if (
step == "tiler"
and i == min_process_from
......@@ -501,7 +507,7 @@ class Pipeline(ProcessABC):
tp=i,
)
elif step == "extraction":
# remove mask/label after extraction
# remove masks and labels after extraction
for k in ["masks", "labels"]:
run_kwargs[step][k] = None
# check and report clogging
......@@ -586,41 +592,6 @@ class Pipeline(ProcessABC):
)
return (traps_above_nthresh & traps_above_athresh).mean()
# FIXME: Remove this functionality. It used to be for
# older hdf5 file formats.
def _load_config_from_file(
self,
filename: Path,
process_from: t.Dict[str, int],
trackers_state: t.List,
overwrite: t.Dict[str, bool],
):
with h5py.File(filename, "r") as f:
for k in process_from.keys():
if not overwrite[k]:
process_from[k] = self.legacy_get_last_tp[k](f)
process_from[k] += 1
return process_from, trackers_state, overwrite
# FIXME: Remove this functionality. It used to be for
# older hdf5 file formats.
@staticmethod
def legacy_get_last_tp(step: str) -> t.Callable:
"""Get last time-point in different ways depending
on which step we are using
To support segmentation in aliby < v0.24
TODO Deprecate and replace with State method
"""
switch_case = {
"tiler": lambda f: f["trap_info/drifts"].shape[0] - 1,
"baby": lambda f: f["cell_info/timepoint"][-1],
"extraction": lambda f: f[
"extraction/general/None/area/timepoint"
][-1],
}
return switch_case[step]
def _setup_pipeline(
self, image_id: int
) -> t.Tuple[
......@@ -676,13 +647,12 @@ class Pipeline(ProcessABC):
step: self.step_sequence.index(ow_id) < i
for i, step in enumerate(self.step_sequence, 1)
}
# Set up
# set up
directory = config["general"]["directory"]
trackers_state: t.List[np.ndarray] = []
with dispatch_image(image_id)(image_id, **self.server_info) as image:
filename = Path(f"{directory}/{image.name}.h5")
# load metadata
meta = MetaData(directory, filename)
from_start = True if np.any(ow.values()) else False
# remove existing file if overwriting
......@@ -716,7 +686,7 @@ class Pipeline(ProcessABC):
)
config["tiler"] = steps["tiler"].parameters.to_dict()
except Exception:
self._log(f"Overwriting tiling data")
self._log("Overwriting tiling data")
if config["general"]["use_explog"]:
meta.run()
......
"""
Tiler: Divides images into smaller tiles.
The tasks of the Tiler are selecting regions of interest, or tiles, of images - with one trap per tile, correcting for the drift of the microscope stage over time, and handling errors and bridging between the image data and Aliby’s image-processing steps.
The tasks of the Tiler are selecting regions of interest, or tiles, of
images - with one trap per tile, correcting for the drift of the microscope
stage over time, and handling errors and bridging between the image data
and Aliby’s image-processing steps.
Tiler subclasses deal with either network connections or local files.
To find tiles, we use a two-step process: we analyse the bright-field image to produce the template of a trap, and we fit this template to the image to find the tiles' centres.
To find tiles, we use a two-step process: we analyse the bright-field image
to produce the template of a trap, and we fit this template to the image to
find the tiles' centres.
We use texture-based segmentation (entropy) to split the image into foreground -- cells and traps -- and background, which we then identify with an Otsu filter. Two methods are used to produce a template trap from these regions: pick the trap with the smallest minor axis length and average over all validated traps.
We use texture-based segmentation (entropy) to split the image into
foreground -- cells and traps -- and background, which we then identify with
an Otsu filter. Two methods are used to produce a template trap from these
regions: pick the trap with the smallest minor axis length and average over
all validated traps.
A peak-identifying algorithm recovers the x and y-axis location of traps in the original image, and we choose the approach to template that identifies the most tiles.
A peak-identifying algorithm recovers the x and y-axis location of traps in
the original image, and we choose the approach to template that identifies
the most tiles.
The experiment is stored as an array with a standard indexing order of (Time, Channels, Z-stack, X, Y).
The experiment is stored as an array with a standard indexing order of
(Time, Channels, Z-stack, X, Y).
"""
import logging
import re
......@@ -355,9 +367,9 @@ class Tiler(StepABC):
full: an array of images
"""
full = self.image[t, c]
if hasattr(full, "compute"): # If using dask fetch images here
if hasattr(full, "compute"):
# if using dask fetch images
full = full.compute(scheduler="synchronous")
return full
@property
......@@ -593,7 +605,10 @@ class Tiler(StepABC):
def get_channel_index(self, channel: str or int) -> int or None:
"""
Find index for channel using regex. Returns the first matched string.
Find index for channel using regex.
Return the first matched string.
If self.channels is integers (no image metadata) it returns None.
If channel is integer
......@@ -602,10 +617,8 @@ class Tiler(StepABC):
channel: string or int
The channel or index to be used.
"""
if all(map(lambda x: isinstance(x, int), self.channels)):
channel = channel if isinstance(channel, int) else None
if isinstance(channel, str):
channel = find_channel_index(self.channels, channel)
return channel
......
......@@ -80,7 +80,7 @@ class Extractor(StepABC):
Usually the metric is applied to only a tile's masked area, but some metrics depend on the whole tile.
Extraction follows a three-level tree structure. Channels, such as GFP, are the root level; the reduction algorithm, such as maximum projection, is the second level; the specific metric, or operation, to apply to the masks, such as mean, is the third level.
Extraction follows a three-level tree structure. Channels, such as GFP, are the root level; the reduction algorithm, such as maximum projection, is the second level; the specific metric, or operation, to apply to the masks, such as mean, is the third or leaf level.
"""
# TODO Alan: Move this to a location with the SwainLab defaults
......@@ -202,7 +202,7 @@ class Extractor(StepABC):
self._custom_funs[k] = tmp(f)
def load_funs(self):
"""Define all functions, including custum ones."""
"""Define all functions, including custom ones."""
self.load_custom_funs()
self._all_cell_funs = set(self._custom_funs.keys()).union(CELL_FUNS)
# merge the two dicts
......@@ -335,7 +335,7 @@ class Extractor(StepABC):
**kwargs,
) -> t.Dict[str, t.Dict[reduction_method, t.Dict[str, pd.Series]]]:
"""
Wrapper to apply reduction and then extraction.
Wrapper to reduce to a 2D image and then extract.
Parameters
----------
......@@ -499,7 +499,6 @@ class Extractor(StepABC):
# calculate metrics with subtracted bg
ch_bs = ch + "_bgsub"
# subtract median background
self.img_bgsub[ch_bs] = np.moveaxis(
np.stack(
list(
......@@ -579,7 +578,9 @@ class Extractor(StepABC):
**kwargs,
) -> dict:
"""
Wrapper to add compatibility with other steps of the pipeline.
Run extraction for one position and for the specified time points.
Save the results to a h5 file.
Parameters
----------
......@@ -597,7 +598,7 @@ class Extractor(StepABC):
Returns
-------
d: dict
A dict of the extracted data with a concatenated string of channel, reduction metric, and cell metric as keys and pd.Series of the extracted data as values.
A dict of the extracted data for one position with a concatenated string of channel, reduction metric, and cell metric as keys and pd.DataFrame of the extracted data for all time points as values.
"""
if tree is None:
tree = self.params.tree
......@@ -633,7 +634,7 @@ class Extractor(StepABC):
def save_to_hdf(self, dict_series, path=None):
"""
Save the extracted data to the h5 file.
Save the extracted data for one position to the h5 file.
Parameters
----------
......
# File with defaults for ease of use
import re
import typing as t
from pathlib import Path
import h5py
# should we move these functions here?
from aliby.tile.tiler import find_channel_name
......@@ -59,6 +58,7 @@ def exparams_from_meta(
for ch in extant_fluorescence_ch:
base["tree"][ch] = default_reduction_metrics
base["sub_bg"] = extant_fluorescence_ch
# additional extraction defaults if the channels are available
if "ph" in extras:
# SWAINLAB specific names
......
......@@ -14,9 +14,10 @@ from postprocessor.core.abc import get_process
class Chainer(Signal):
"""
Extend Signal by applying post-processes and allowing composite signals that combine basic signals.
It "chains" multiple processes upon fetching a dataset to produce the desired datasets.
Instead of reading processes previously applied, it executes
Chainer "chains" multiple processes upon fetching a dataset.
Instead of reading processes previously applied, Chainer executes
them when called.
"""
......@@ -25,6 +26,7 @@ class Chainer(Signal):
}
def __init__(self, *args, **kwargs):
"""Initialise chainer."""
super().__init__(*args, **kwargs)
def replace_path(path: str, bgsub: bool = ""):
......@@ -34,7 +36,7 @@ class Chainer(Signal):
path = re.sub(channel, f"{channel}{suffix}", path)
return path
# Add chain with and without bgsub for composite statistics
# add chain with and without bgsub for composite statistics
self.common_chains = {
alias
+ bgsub: lambda **kwargs: self.get(
......
"""
Functions to process, filter and merge tracks.
"""
Functions to process, filter, and merge tracks.
We call two tracks contiguous if they are adjacent in time: the
maximal time point of one is one time point less than the
minimal time point of the other.
# from collections import Counter
A right track can have multiple potential left tracks. We must
pick the best.
"""
import typing as t
from copy import copy
......@@ -17,6 +22,76 @@ from utils_find_1st import cmp_larger, find_1st
from postprocessor.core.processes.savgol import non_uniform_savgol
def get_joinable(tracks, smooth=False, tol=0.1, window=5, degree=3) -> dict:
"""
Get the pairs of tracks (without repeats) that have an error smaller than the
tolerance. If a track can be assigned to two or more others, choose the one
with the lowest error.
Parameters
----------
tracks: (m x n) Signal
A Signal, usually area, dataframe where rows are cell tracks and
columns are time points.
tol: float or int
threshold of average (prediction error/std) necessary
to consider two tracks the same. If float is fraction of first track,
if int it is absolute units.
window: int
value of window used for savgol_filter
degree: int
value of polynomial degree passed to savgol_filter
"""
# only consider time series with more than two non-NaN data points
tracks = tracks.loc[tracks.notna().sum(axis=1) > 2]
# get contiguous tracks
if smooth:
# specialise to tracks with growing cells and of long duration
clean = clean_tracks(tracks, min_duration=window + 1, min_gr=0.9)
contigs = clean.groupby(["trap"]).apply(get_contiguous_pairs)
else:
contigs = tracks.groupby(["trap"]).apply(get_contiguous_pairs)
# remove traps with no contiguous tracks
contigs = contigs.loc[contigs.apply(len) > 0]
# flatten to (trap, cell_id) pairs
flat = set([k for v in contigs.values for i in v for j in i for k in j])
# make a data frame of contiguous tracks with the tracks as arrays
if smooth:
smoothed_tracks = clean.loc[flat].apply(
lambda x: non_uniform_savgol(x.index, x.values, window, degree),
axis=1,
)
else:
smoothed_tracks = tracks.loc[flat].apply(
lambda x: np.array(x.values), axis=1
)
# get the Signal values for neighbouring end points of contiguous tracks
actual_edges = contigs.apply(lambda x: get_edge_values(x, smoothed_tracks))
# get the predicted values
predicted_edges = contigs.apply(
lambda x: get_predicted_edge_values(x, smoothed_tracks, window)
)
# Prediction of pre and mean of post
prediction_costs = predicted_edges.apply(get_dMetric_wrap, tol=tol)
solutions = [
solve_matrices_wrap(cost, edges, tol=tol)
for (trap_id, cost), edges in zip(
prediction_costs.items(), actual_edges
)
]
closest_pairs = pd.Series(
solutions,
index=prediction_costs.index,
)
# match local with global ids
joinable_ids = [
localid_to_idx(closest_pairs.loc[i], contigs.loc[i])
for i in closest_pairs.index
]
return [pair for pairset in joinable_ids for pair in pairset]
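A hypothetical call on an area Signal data frame (area_df is assumed) whose rows are (trap, cell_label) tracks and whose columns are time points:
joinable_pairs = get_joinable(area_df, smooth=False, tol=0.1, window=5)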
def load_test_dset():
"""Load development dataset to test functions."""
return pd.DataFrame(
......@@ -45,46 +120,21 @@ def max_nonstop_ntps(track: pd.Series) -> int:
return max(consecutive_nonas_grouped)
def get_tracks_ntps(tracks: pd.DataFrame) -> pd.Series:
return tracks.apply(max_ntps, axis=1)
def get_avg_gr(track: pd.Series) -> int:
"""
Get average growth rate for a track.
:param tracks: Series with volume and timepoints as indices
"""
def get_avg_gr(track: pd.Series) -> float:
"""Get average growth rate for a track."""
ntps = max_ntps(track)
vals = track.dropna().values
gr = (vals[-1] - vals[0]) / ntps
return gr
def get_avg_grs(tracks: pd.DataFrame) -> pd.DataFrame:
"""
Get average growth rate for a group of tracks
:param tracks: (m x n) dataframe where rows are cell tracks and
columns are timepoints
"""
return tracks.apply(get_avg_gr, axis=1)
def clean_tracks(
tracks, min_len: int = 15, min_gr: float = 1.0
tracks, min_duration: int = 15, min_gr: float = 1.0
) -> pd.DataFrame:
"""
Clean small non-growing tracks and return the reduced dataframe
:param tracks: (m x n) dataframe where rows are cell tracks and
columns are timepoints
:param min_len: int number of timepoints cells must have not to be removed
:param min_gr: float Minimum mean growth rate to assume an outline is growing
"""
ntps = get_tracks_ntps(tracks)
grs = get_avg_grs(tracks)
growing_long_tracks = tracks.loc[(ntps >= min_len) & (grs > min_gr)]
"""Remove small non-growing tracks and return the reduced data frame."""
ntps = tracks.apply(max_ntps, axis=1)
grs = tracks.apply(get_avg_gr, axis=1)
growing_long_tracks = tracks.loc[(ntps >= min_duration) & (grs > min_gr)]
return growing_long_tracks
......@@ -191,139 +241,78 @@ def join_track_pair(target, source):
return tgt_copy
def get_joinable(tracks, smooth=False, tol=0.1, window=5, degree=3) -> dict:
"""
Get the pair of track (without repeats) that have a smaller error than the
tolerance. If there is a track that can be assigned to two or more other
ones, choose the one with lowest error.
def get_edge_values(contigs_ids, smoothed_tracks):
"""Get Signal values for adjacent end points for each contiguous track."""
values = [
(
[get_value(smoothed_tracks.loc[pre_id], -1) for pre_id in pre_ids],
[
get_value(smoothed_tracks.loc[post_id], 0)
for post_id in post_ids
],
)
for pre_ids, post_ids in contigs_ids
]
return values
:param tracks: (m x n) dataframe where rows are cell tracks and
columns are timepoints
:param tol: float or int threshold of average (prediction error/std) necessary
to consider two tracks the same. If float is fraction of first track,
if int it is absolute units.
:param window: int value of window used for savgol_filter
:param degree: int value of polynomial degree passed to savgol_filter
def get_predicted_edge_values(contigs_ids, smoothed_tracks, window):
"""
tracks = tracks.loc[tracks.notna().sum(axis=1) > 2]
Find neighbouring values of two contiguous tracks.
# Commented because we are not smoothing in this step yet
# candict = {k:v for d in contig.values for k,v in d.items()}
# smooth all relevant tracks
if smooth: # Apply savgol filter TODO fix nans affecting edge placing
clean = clean_tracks(
tracks, min_len=window + 1, min_gr=0.9
) # get useful tracks
def savgol_on_srs(x):
return non_uniform_savgol(x.index, x.values, window, degree)
contig = clean.groupby(["trap"]).apply(get_contiguous_pairs)
contig = contig.loc[contig.apply(len) > 0]
flat = set([k for v in contig.values for i in v for j in i for k in j])
smoothed_tracks = clean.loc[flat].apply(savgol_on_srs, 1)
else:
contig = tracks.groupby(["trap"]).apply(get_contiguous_pairs)
contig = contig.loc[contig.apply(len) > 0]
flat = set([k for v in contig.values for i in v for j in i for k in j])
smoothed_tracks = tracks.loc[flat].apply(
lambda x: np.array(x.values), axis=1
)
# fetch edges from ids TODO (IF necessary, here we can compare growth rates)
def idx_to_edge(preposts):
return [
(
[get_val(smoothed_tracks.loc[pre], -1) for pre in pres],
[get_val(smoothed_tracks.loc[post], 0) for post in posts],
Predict the next value for the leftmost track using window values
and find the mean of the initial window values of the rightmost
track.
"""
result = []
for pre_ids, post_ids in contigs_ids:
pre_res = []
# left contiguous tracks
for pre_id in pre_ids:
# get last window values of a track
y = get_values_i(smoothed_tracks.loc[pre_id], -window)
# predict next value
pre_res.append(
np.poly1d(np.polyfit(range(len(y)), y, 1))(len(y) + 1),
)
for pres, posts in preposts
# right contiguous tracks
pos_res = [
# mean value of initial window values of a track
get_mean_value_i(smoothed_tracks.loc[post_id], window)
for post_id in post_ids
]
result.append([pre_res, pos_res])
return result
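A toy version of the prediction step for a single left track (numbers are made up): fit a line to the last window values and evaluate it just past the window, as the loop above does.
import numpy as np
y = np.array([2.0, 2.2, 2.4, 2.6, 2.8])  # last `window` values of a left track
np.poly1d(np.polyfit(range(len(y)), y, 1))(len(y) + 1)
# ~3.2, compared later with the mean of the right track's first values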
# idx_to_means = lambda preposts: [
# (
# [get_means(smoothed_tracks.loc[pre], -window) for pre in pres],
# [get_means(smoothed_tracks.loc[post], window) for post in posts],
# )
# for pres, posts in preposts
# ]
def idx_to_pred(preposts):
result = []
for pres, posts in preposts:
pre_res = []
for pre in pres:
y = get_last_i(smoothed_tracks.loc[pre], -window)
pre_res.append(
np.poly1d(np.polyfit(range(len(y)), y, 1))(len(y) + 1),
)
pos_res = [
get_means(smoothed_tracks.loc[post], window) for post in posts
]
result.append([pre_res, pos_res])
return result
edges = contig.apply(idx_to_edge) # Raw edges
# edges_mean = contig.apply(idx_to_means) # Mean of both
pre_pred = contig.apply(idx_to_pred) # Prediction of pre and mean of post
# edges_dMetric = edges.apply(get_dMetric_wrap, tol=tol)
# edges_dMetric_mean = edges_mean.apply(get_dMetric_wrap, tol=tol)
edges_dMetric_pred = pre_pred.apply(get_dMetric_wrap, tol=tol)
# combined_dMetric = pd.Series(
# [
# [np.nanmin((a, b), axis=0) for a, b in zip(x, y)]
# for x, y in zip(edges_dMetric, edges_dMetric_mean)
# ],
# index=edges_dMetric.index,
# )
# closest_pairs = combined_dMetric.apply(get_vec_closest_pairs, tol=tol)
solutions = []
# for (i, dMetrics), edgeset in zip(combined_dMetric.items(), edges):
for (i, dMetrics), edgeset in zip(edges_dMetric_pred.items(), edges):
solutions.append(solve_matrices_wrap(dMetrics, edgeset, tol=tol))
closest_pairs = pd.Series(
solutions,
index=edges_dMetric_pred.index,
)
# match local with global ids
joinable_ids = [
localid_to_idx(closest_pairs.loc[i], contig.loc[i])
for i in closest_pairs.index
]
return [pair for pairset in joinable_ids for pair in pairset]
def get_val(x, n):
return x[~np.isnan(x)][n] if len(x[~np.isnan(x)]) else np.nan
def get_value(x, n):
"""Get value from an array ignoring NaN."""
val = x[~np.isnan(x)][n] if len(x[~np.isnan(x)]) else np.nan
return val
def get_means(x, i):
def get_mean_value_i(x, i):
"""Get track's mean Signal value from values either from or up to an index."""
if not len(x[~np.isnan(x)]):
return np.nan
if i > 0:
v = x[~np.isnan(x)][:i]
else:
v = x[~np.isnan(x)][i:]
return np.nanmean(v)
def get_last_i(x, i):
def get_values_i(x, i):
"""Get track's Signal values either from or up to an index."""
if not len(x[~np.isnan(x)]):
return np.nan
if i > 0:
v = x[~np.isnan(x)][:i]
else:
v = x[~np.isnan(x)][i:]
return v
def localid_to_idx(local_ids, contig_trap):
......@@ -352,57 +341,55 @@ def get_vec_closest_pairs(lst: List, **kwargs):
def get_dMetric_wrap(lst: List, **kwargs):
"""Calculate dMetric on a list."""
return [get_dMetric(*sublist, **kwargs) for sublist in lst]
def solve_matrices_wrap(dMetric: List, edges: List, **kwargs):
"""Calculate solve_matrices on a list."""
return [
solve_matrices(mat, edgeset, **kwargs)
for mat, edgeset in zip(dMetric, edges)
]
def get_dMetric(
pre: List[float], post: List[float], tol: Union[float, int] = 1
):
"""Calculate a cost matrix
input
:param pre: list of floats with edges on left
:param post: list of floats with edges on right
:param tol: int or float if int metrics of tolerance, if float fraction
returns
:: list of indices corresponding to the best solutions for matrices
def get_dMetric(pre: List[float], post: List[float], tol):
"""
Calculate a cost matrix based on the difference between two Signal
values.
Parameters
----------
pre: list of floats
Values of the Signal for left contiguous tracks.
post: list of floats
Values of the Signal for right contiguous tracks.
"""
if len(pre) > len(post):
dMetric = np.abs(np.subtract.outer(post, pre))
else:
dMetric = np.abs(np.subtract.outer(pre, post))
dMetric[np.isnan(dMetric)] = (
tol + 1 + np.nanmax(dMetric)
) # nans will be filtered
# replace NaNs with maximal cost values
dMetric[np.isnan(dMetric)] = tol + 1 + np.nanmax(dMetric)
return dMetric
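For illustration, with made-up edge values, each entry of the cost matrix is the absolute difference between a left-track end value and a right-track start value:
import numpy as np
pre, post = [3.0, 5.0], [3.2, 5.1, 9.0]
np.abs(np.subtract.outer(pre, post))
# approximately array([[0.2, 2.1, 6. ],
#                      [1.8, 0.1, 4. ]])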
def solve_matrices(
dMetric: np.ndarray, prepost: List, tol: Union[float, int] = 1
):
def solve_matrices(cost: np.ndarray, edges: List, tol: Union[float, int] = 1):
"""
Solve the distance matrices obtained in get_dMetric and/or merged from
independent dMetric matrices.
"""
ids = solve_matrix(dMetric)
if not len(ids[0]):
ids = solve_matrix(cost)
if len(ids[0]):
pre, post = edges
norm = (
np.array(pre)[ids[len(pre) > len(post)]] if tol < 1 else 1
) # relative or absolute tol
result = dMetric[ids] / norm
ids = ids if len(pre) < len(post) else ids[::-1]
return [idx for idx, res in zip(zip(*ids), result) if res <= tol]
else:
return []
pre, post = prepost
norm = (
np.array(pre)[ids[len(pre) > len(post)]] if tol < 1 else 1
) # relative or absolute tol
result = dMetric[ids] / norm
ids = ids if len(pre) < len(post) else ids[::-1]
return [idx for idx, res in zip(zip(*ids), result) if res <= tol]
def get_closest_pairs(
......@@ -426,37 +413,31 @@ def get_closest_pairs(
def solve_matrix(dMetric):
"""
Solve cost matrix focusing on getting the smallest cost at each iteration.
input
:param dMetric: np.array cost matrix
returns
tuple of np.arrays indicating picks with lowest individual value
"""
"""Arrange indices to the cost matrix in order of increasing cost."""
glob_is = []
glob_js = []
if (~np.isnan(dMetric)).any():
tmp = copy(dMetric)
std = sorted(tmp[~np.isnan(tmp)])
while (~np.isnan(std)).any():
v = std[0]
i_s, j_s = np.where(tmp == v)
lMetric = copy(dMetric)
sortedMetric = sorted(lMetric[~np.isnan(lMetric)])
while (~np.isnan(sortedMetric)).any():
# indices of point with minimal cost
i_s, j_s = np.where(lMetric == sortedMetric[0])
i = i_s[0]
j = j_s[0]
tmp[i, :] += np.nan
tmp[:, j] += np.nan
# store this point
glob_is.append(i)
glob_js.append(j)
std = sorted(tmp[~np.isnan(tmp)])
return (np.array(glob_is), np.array(glob_js))
# remove from lMetric
lMetric[i, :] += np.nan
lMetric[:, j] += np.nan
sortedMetric = sorted(lMetric[~np.isnan(lMetric)])
indices = (np.array(glob_is), np.array(glob_js))
return indices
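A small worked example of the greedy assignment (cost values are made up): the smallest cost is taken first and its row and column are then excluded.
import numpy as np
cost = np.array([[0.5, 0.1],
                 [0.2, 0.9]])
solve_matrix(cost)
# (array([0, 1]), array([1, 0])): row 0 pairs with column 1, then row 1 with column 0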
def plot_joinable(tracks, joinable_pairs):
"""
Convenience plotting function for debugging and data vis
"""
"""Convenience plotting function for debugging."""
nx = 8
ny = 8
_, axes = plt.subplots(nx, ny)
......@@ -479,55 +460,33 @@ def plot_joinable(tracks, joinable_pairs):
def get_contiguous_pairs(tracks: pd.DataFrame) -> list:
"""
Get all pair of contiguous track ids from a tracks dataframe.
Get all pair of contiguous track ids from a tracks data frame.
:param tracks: (m x n) dataframe where rows are cell tracks and
columns are timepoints
:param min_dgr: float minimum difference in growth rate from
the interpolation
For two tracks to be contiguous, they must be exactly adjacent.
Parameters
----------
tracks: pd.Dataframe
A dataframe where rows are cell tracks and columns are time
points.
"""
mins, maxes = [
# TODO add support for skipping time points
# find time points bounding tracks of non-NaN values
mins, maxs = [
tracks.notna().apply(np.where, axis=1).apply(fn)
for fn in (np.min, np.max)
]
# mins.name = "min_tpt"
# maxs.name = "max_tpt"
# df = pd.merge(mins, maxs, right_index=True, left_index=True)
# df["duration"] = df.max_tpt - df.min_tpt
#
# flip so that time points become the index
mins_d = mins.groupby(mins).apply(lambda x: x.index.tolist())
mins_d.index = mins_d.index - 1 # make indices equal
# TODO add support for skipping time points
maxes_d = maxes.groupby(maxes).apply(lambda x: x.index.tolist())
common = sorted(
set(mins_d.index).intersection(maxes_d.index), reverse=True
)
return [(maxes_d[t], mins_d[t]) for t in common]
# def fit_track(track: pd.Series, obj=None):
# if obj is None:
# obj = objective
# x = track.dropna().index
# y = track.dropna().values
# popt, _ = curve_fit(obj, x, y)
# return popt
# def interpolate(track, xs) -> list:
# '''
# Interpolate next timepoint from a track
# :param track: pd.Series of volume growth over a time period
# :param t: int timepoint to interpolate
# '''
# popt = fit_track(track)
# # perr = np.sqrt(np.diag(pcov))
# return objective(np.array(xs), *popt)
# def objective(x,a,b,c,d) -> float:
# # return (a)/(1+b*np.exp(c*x))+d
# return (((x+d)*a)/((x+d)+b))+c
# def cand_pairs_to_dict(candidates):
# d={x:[] for x,_ in candidates}
# for x,y in candidates:
# d[x].append(y)
# return d
maxs_d = maxs.groupby(maxs).apply(lambda x: x.index.tolist())
# shift start times down by one so a track beginning right after another ends shares its index
mins_d.index = mins_d.index - 1
# find common end points
common = sorted(set(mins_d.index).intersection(maxs_d.index), reverse=True)
contigs = [(maxs_d[t], mins_d[t]) for t in common]
return contigs
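# Illustrative usage with toy data (not from the pipeline): rows are cell
# tracks, columns are time points; a pair is contiguous when one track ends at
# the time point immediately before the other begins. `_example_contiguous`
# is a hypothetical helper for documentation only.
def _example_contiguous():
    """Build a toy tracks data frame and report its contiguous track pairs."""
    import numpy as np
    import pandas as pd

    toy_tracks = pd.DataFrame(
        [
            [1.0, 1.1, np.nan, np.nan],  # track 0 spans time points 0-1
            [np.nan, np.nan, 1.2, 1.3],  # track 1 spans time points 2-3
        ]
    )
    # expected to give [([0], [1])]: track 0 ends just before track 1 starts
    return get_contiguous_pairs(toy_tracks)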
@@ -12,19 +12,20 @@ from postprocessor.core.abc import PostProcessABC
class LineageProcessParameters(ParametersABC):
"""
Parameters
"""
"""Parameters - none are necessary."""
_defaults = {}
class LineageProcess(PostProcessABC):
"""
Lineage process that must be passed a (N,3) lineage matrix (where the columns are trap, mother, daughter respectively)
Analyse lineage data.
Currently bare bones, but extracts lineage information from a Signal or Cells object.
"""
def __init__(self, parameters: LineageProcessParameters):
"""Initialise using PostProcessABC."""
super().__init__(parameters)
@abstractmethod
@@ -34,6 +35,7 @@ class LineageProcess(PostProcessABC):
lineage: np.ndarray,
*args,
):
"""Implement method required by PostProcessABC - undefined."""
pass
@classmethod
@@ -45,8 +47,9 @@ class LineageProcess(PostProcessABC):
**kwargs,
):
"""
Overrides PostProcess.as_function classmethod.
Lineage functions require lineage information to be passed if run as function.
Override the PostProcessABC.as_function method.
Lineage processes require lineage information to be passed when run as functions.
"""
parameters = cls.default_parameters(**kwargs)
return cls(parameters=parameters).run(
@@ -54,8 +57,9 @@ class LineageProcess(PostProcessABC):
)
def get_lineage_information(self, signal=None, merged=True):
"""Get lineage as an array with tile IDs, mother labels, and corresponding bud labels."""
if signal is not None and "mother_label" in signal.index.names:
# from kymograph
lineage = get_index_as_np(signal)
elif hasattr(self, "lineage"):
lineage = self.lineage
@@ -68,5 +72,5 @@ class LineageProcess(PostProcessABC):
elif self.cells is not None:
lineage = self.cells.mothers_daughters
else:
raise Exception("No linage information found")
raise Exception("No lineage information found")
return lineage
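# A hedged sketch of the lineage format this class expects: an (N, 3) integer
# array whose columns are (trap/tile id, mother label, bud label). The values
# below are made up purely for illustration.
def _example_lineage():
    """Return a toy lineage array in the expected (N, 3) layout."""
    import numpy as np

    return np.array(
        [
            [0, 1, 2],  # in trap 0, cell 1 is the mother of bud 2
            [0, 1, 3],  # the same mother later produces bud 3
            [4, 2, 5],  # in trap 4, cell 2 mothers bud 5
        ]
    )
# A LineageProcess subclass would receive such an array via its `lineage`
# attribute or as the `lineage` argument to `run`.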
# change "prepost" to "preprocess"; change filename to postprocessor_engine.py ??
import typing as t
from itertools import takewhile
@@ -61,36 +62,24 @@ class PostProcessorParameters(ParametersABC):
kind: list of str
If "ph_batman" included, add targets for experiments using pHlourin.
"""
# each subitem specifies the function to be called and the location
# on the h5 file to be written
# each subitem specifies the function to be called
# and the h5-file location for the results
#: why does merger have a string and picker a list?
targets = {
"prepost": {
"merger": "/extraction/general/None/area",
"picker": ["/extraction/general/None/area"],
},
"processes": [
[
"buddings",
["/extraction/general/None/volume"],
],
[
"dsignal",
[
"/extraction/general/None/volume",
],
],
[
"bud_metric",
[
"/extraction/general/None/volume",
],
],
[
"dsignal",
[
"/postprocessing/bud_metric/extraction_general_None_volume",
],
],
["buddings", ["/extraction/general/None/volume"]],
# ["dsignal", ["/extraction/general/None/volume"]],
["bud_metric", ["/extraction/general/None/volume"]],
# [
# "dsignal",
# [
# "/postprocessing/bud_metric/extraction_general_None_volume"
# ],
# ],
],
}
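# Hypothetical illustration (not default behaviour): each entry in
# targets["processes"] pairs a registered process name with the list of
# h5 paths it consumes, so another analysis could be enabled by appending
# such a pair, e.g.
#   targets["processes"].append(
#       ["dsignal", ["/extraction/general/None/volume"]]
#   )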
param_sets = {
@@ -129,7 +118,7 @@ class PostProcessorParameters(ParametersABC):
class PostProcessor(ProcessABC):
def __init__(self, filename, parameters):
"""
Initialise PostProcessor
Initialise PostProcessor.
Parameters
----------
@@ -150,7 +139,7 @@ class PostProcessor(ProcessABC):
for k in dicted_params.keys():
if not isinstance(dicted_params[k], dict):
dicted_params[k] = dicted_params[k].to_dict()
# merger and picker
# initialise merger and picker
self.merger = Merger(
MergerParameters.from_dict(dicted_params["merger"])
)
@@ -158,12 +147,12 @@ class PostProcessor(ProcessABC):
PickerParameters.from_dict(dicted_params["picker"]),
cells=Cells.from_source(filename),
)
# processes, such as buddings
# get processes, such as buddings
self.classfun = {
process: get_process(process)
for process, _ in parameters["targets"]["processes"]
}
# parameters for the process in classfun
# get parameters for the processes in classfun
self.parameters_classfun = {
process: get_parameters(process)
for process, _ in parameters["targets"]["processes"]
@@ -172,31 +161,32 @@ class PostProcessor(ProcessABC):
self.targets = parameters["targets"]
def run_prepost(self):
"""Using picker, get and write lineages, returning mothers and daughters."""
"""Important processes run before normal post-processing ones"""
"""
Run picker and merger and get lineages.
Necessary before any processes can run.
"""
# run merger
record = self._signal.get_raw(self.targets["prepost"]["merger"])
merges = np.array(self.merger.run(record), dtype=int)
self._writer.write(
"modifiers/merges", data=[np.array(x) for x in merges]
)
# get lineages from picker
lineage = _assoc_indices_to_3d(self.picker.cells.mothers_daughters)
lineage_merged = []
if merges.any(): # Update lineages after merge events
if merges.any():
# update lineages after merge events
merged_indices = merge_association(lineage, merges)
# Remove repeated labels post-merging
# remove repeated labels post-merging
lineage_merged = np.unique(merged_indices, axis=0)
self.lineage = _3d_index_to_2d(
lineage_merged if len(lineage_merged) else lineage
)
self._writer.write(
"modifiers/lineage_merged", _3d_index_to_2d(lineage_merged)
)
# run picker
picked_indices = self.picker.run(
self._signal[self.targets["prepost"]["picker"][0]]
)
@@ -211,25 +201,15 @@ class PostProcessor(ProcessABC):
overwrite="overwrite",
)
@staticmethod
def pick_mother(a, b):
"""Update the mother id following this priorities:
The mother has a lower id
"""
x = max(a, b)
if min([a, b]):
x = [a, b][np.argmin([a, b])]
return x
def run(self):
"""
Write the results to the h5 file.
Processes include identifying buddings and finding bud metrics.
"""
# run merger, picker, and find lineages
self.run_prepost()
# run processes
# run processes: process is a str; datasets is a list of str
for process, datasets in tqdm(self.targets["processes"]):
if process in self.parameters["param_sets"].get("processes", {}):
# parameters already assigned
@@ -243,16 +223,14 @@ class PostProcessor(ProcessABC):
loaded_process = self.classfun[process](parameters)
if isinstance(parameters, LineageProcessParameters):
loaded_process.lineage = self.lineage
# apply process to each dataset
for dataset in datasets:
self.run_process(dataset, process, loaded_process)
def run_process(self, dataset, process, loaded_process):
"""Run process on a single dataset and write the result."""
# define signal
"""Run process to obtain a single dataset and write the result."""
# get pre-processed data
if isinstance(dataset, list):
# multisignal process
signal = [self._signal[d] for d in dataset]
elif isinstance(dataset, str):
signal = self._signal[dataset]
@@ -269,8 +247,9 @@ class PostProcessor(ProcessABC):
[], columns=signal.columns, index=signal.index
)
result.columns.names = ["timepoint"]
# define outpath, where result will be written
# use outpath to write result
if process in self.parameters["outpaths"]:
# outpath already defined
outpath = self.parameters["outpaths"][process]
elif isinstance(dataset, list):
# no outpath is defined
@@ -318,3 +297,15 @@ class PostProcessor(ProcessABC):
metadata: t.Dict,
):
self._writer.write(path, result, meta=metadata, overwrite="overwrite")
@staticmethod
def pick_mother(a, b):
"""
Update the mother id following this priority:
the mother is the lower of the two non-zero ids.
"""
x = max(a, b)
if min([a, b]):
x = [a, b][np.argmin([a, b])]
return x
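# Illustrative behaviour (toy values): the lower of the two candidate ids wins
# unless one of them is zero, in which case the non-zero id is returned.
#   PostProcessor.pick_mother(3, 5)  # -> 3
#   PostProcessor.pick_mother(0, 5)  # -> 5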