Compare revisions

Changes are shown as if the source revision was being merged into the target revision.

Projects: swain-lab/aliby/aliby-mirror, swain-lab/aliby/alibylite

Commits on Source (45)
Showing with 1333 additions and 827 deletions
......@@ -17,16 +17,14 @@ atomic = t.Union[int, float, str, bool]
class ParametersABC(ABC):
"""
Defines parameters as attributes and allows parameters to
Define parameters as attributes and allow parameters to
be converted to either a dictionary or to yaml.
No attribute should be called "parameters"!
"""
def __init__(self, **kwargs):
"""
Defines parameters as attributes
"""
"""Define parameters as attributes."""
assert (
"parameters" not in kwargs
), "No attribute should be named parameters"
......@@ -35,8 +33,9 @@ class ParametersABC(ABC):
def to_dict(self, iterable="null") -> t.Dict:
"""
Recursive function to return a nested dictionary of the
attributes of the class instance.
Return a nested dictionary of the attributes of the class instance.
Uses recursion.
"""
if isinstance(iterable, dict):
if any(
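
As a toy sketch of the recursion (MyParameters and its attributes are hypothetical; the output assumes only the behaviour documented above):

# hypothetical subclass; ParametersABC stores kwargs as attributes
class MyParameters(ParametersABC):
    pass

params = MyParameters(fps=5, tiler={"tile_size": 117, "ref_channel": "Brightfield"})
params.to_dict()
# expected: {'fps': 5, 'tiler': {'tile_size': 117, 'ref_channel': 'Brightfield'}}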
......@@ -62,7 +61,8 @@ class ParametersABC(ABC):
def to_yaml(self, path: Union[Path, str] = None):
"""
Returns a yaml stream of the attributes of the class instance.
Return a yaml stream of the attributes of the class instance.
If path is provided, the yaml stream is saved there.
Parameters
......@@ -81,9 +81,7 @@ class ParametersABC(ABC):
@classmethod
def from_yaml(cls, source: Union[Path, str]):
"""
Returns instance from a yaml filename or stdin
"""
"""Return instance from a yaml filename or stdin."""
is_buffer = True
try:
if Path(source).exists():
......@@ -107,7 +105,8 @@ class ParametersABC(ABC):
def update(self, name: str, new_value):
"""
Update values recursively
Update values recursively.
If name is a dictionary, replace values where they already exist and add them where they do not.
It warns against type changes.
......@@ -179,7 +178,8 @@ def add_to_collection(
class ProcessABC(ABC):
"""
Base class for processes.
Defines parameters as attributes and requires run method to be defined.
Define parameters as attributes and require a run method.
"""
def __init__(self, parameters):
......@@ -243,11 +243,9 @@ class StepABC(ProcessABC):
@timer
def run_tp(self, tp: int, **kwargs):
"""
Time and log the timing of a step.
"""
"""Time and log the timing of a step."""
return self._run_tp(tp, **kwargs)
def run(self):
# Replace run with run_tp
raise Warning("Steps use run_tp instead of run")
raise Warning("Steps use run_tp instead of run.")
......@@ -6,17 +6,19 @@ import typing as t
from functools import wraps
def _first_arg_str_to_df(
def _first_arg_str_to_raw_df(
fn: t.Callable,
):
"""Enable Signal-like classes to convert strings to data sets."""
@wraps(fn)
def format_input(*args, **kwargs):
cls = args[0]
data = args[1]
if isinstance(data, str):
# get data from h5 file
# get data from h5 file using Signal's get_raw
data = cls.get_raw(data)
# replace path in the undecorated function with data
return fn(cls, data, *args[2:], **kwargs)
return format_input
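
A sketch of how the renamed decorator is meant to be used, mirroring Signal.apply_prepost further below (SignalLike here is illustrative):

class SignalLike:
    def get_raw(self, dataset_name):
        ...  # fetch a dataframe from the h5 file

    @_first_arg_str_to_raw_df
    def apply_prepost(self, data, **kwargs):
        # here `data` is always a dataframe, never a string
        ...

# these two calls then behave identically:
# signal.apply_prepost("extraction/.../volume")  # illustrative dataset name
# signal.apply_prepost(volume_df)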
......@@ -66,7 +66,7 @@ class MetaData:
# Needed because HDF5 attributes do not support dictionaries
def flatten_dict(nested_dict, separator="/"):
"""
Flattens nested dictionary. If empty return as-is.
Flatten nested dictionary. If empty return as-is.
"""
flattened = {}
if nested_dict:
......@@ -79,9 +79,7 @@ def flatten_dict(nested_dict, separator="/"):
# Needed because HDF5 attributes do not support datetime objects
# Takes care of time zones & daylight saving
def datetime_to_timestamp(time, locale="Europe/London"):
"""
Convert datetime object to UNIX timestamp
"""
"""Convert datetime object to UNIX timestamp."""
return timezone(locale).localize(time).timestamp()
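
For instance, a quick check of the timezone handling (assuming the module imports timezone from pytz, whose localize method the function uses):

from datetime import datetime

# midnight on 1 Jan 2020 in London is UTC, so this equals the UTC timestamp
datetime_to_timestamp(datetime(2020, 1, 1))  # 1577836800.0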
......@@ -189,36 +187,37 @@ def parse_swainlab_metadata(filedir: t.Union[str, Path]):
Dictionary with minimal metadata
"""
filedir = Path(filedir)
filepath = find_file(filedir, "*.log")
if filepath:
# new log files
raw_parse = parse_from_swainlab_grammar(filepath)
minimal_meta = get_meta_swainlab(raw_parse)
else:
# old log files
if filedir.is_file() or str(filedir).endswith(".zarr"):
# log file is in parent directory
filedir = filedir.parent
legacy_parse = parse_logfiles(filedir)
minimal_meta = (
get_meta_from_legacy(legacy_parse) if legacy_parse else {}
)
return minimal_meta
def dispatch_metadata_parser(filepath: t.Union[str, Path]):
"""
Function to dispatch different metadata parsers that convert logfiles into a
basic metadata dictionary. Currently only contains the swainlab log parsers.
Dispatch different metadata parsers that convert logfiles into a dictionary.
Currently only contains the swainlab log parsers.
Input:
--------
filepath: str existing file containing metadata, or folder containing naming conventions
filepath: str existing file containing metadata, or folder containing naming
conventions
"""
parsed_meta = parse_swainlab_metadata(filepath)
if parsed_meta is None:
parsed_meta = dir_to_meta(filepath)
return parsed_meta
......
......@@ -10,8 +10,8 @@ import numpy as np
import pandas as pd
from agora.io.bridge import BridgeH5
from agora.io.decorators import _first_arg_str_to_df
from agora.utils.indexing import validate_association
from agora.io.decorators import _first_arg_str_to_raw_df
from agora.utils.indexing import validate_lineage
from agora.utils.kymograph import add_index_levels
from agora.utils.merge import apply_merges
......@@ -20,11 +20,14 @@ class Signal(BridgeH5):
"""
Fetch data from h5 files for post-processing.
Signal assumes that the metadata and data are accessible to
perform time-adjustments and apply previously recorded
post-processes.
"""
def __init__(self, file: t.Union[str, Path]):
"""Define index_names for dataframes, candidate fluorescence channels, and composite statistics."""
"""Define index_names for dataframes, candidate fluorescence channels,
and composite statistics."""
super().__init__(file, flag=None)
self.index_names = (
"experiment",
......@@ -46,9 +49,9 @@ class Signal(BridgeH5):
def __getitem__(self, dsets: t.Union[str, t.Collection]):
"""Get and potentially pre-process data from h5 file and return as a dataframe."""
if isinstance(dsets, str): # no pre-processing
if isinstance(dsets, str):
return self.get(dsets)
elif isinstance(dsets, list): # pre-processing
elif isinstance(dsets, list):
is_bgd = [dset.endswith("imBackground") for dset in dsets]
# Check we are not comparing tile-indexed and cell-indexed data
assert sum(is_bgd) == 0 or sum(is_bgd) == len(
......@@ -58,22 +61,23 @@ class Signal(BridgeH5):
else:
raise Exception(f"Invalid type {type(dsets)} to get datasets")
def get(self, dsets: t.Union[str, t.Collection], **kwargs):
"""Get and potentially pre-process data from h5 file and return as a dataframe."""
if isinstance(dsets, str): # no pre-processing
df = self.get_raw(dsets, **kwargs)
def get(self, dset_name: t.Union[str, t.Collection], **kwargs):
"""Return pre-processed data as a dataframe."""
if isinstance(dset_name, str):
dsets = self.get_raw(dset_name, **kwargs)
prepost_applied = self.apply_prepost(dsets, **kwargs)
return self.add_name(prepost_applied, dsets)
return self.add_name(prepost_applied, dset_name)
else:
raise Exception("Error in Signal.get")
@staticmethod
def add_name(df, name):
"""Add column of identical strings to a dataframe."""
"""Add name of the Signal as an attribute to its corresponding dataframe."""
df.name = name
return df
def cols_in_mins(self, df: pd.DataFrame):
# Convert numerical columns in a dataframe to minutes
"""Convert numerical columns in a dataframe to minutes."""
try:
df.columns = (df.columns * self.tinterval // 60).astype(int)
except Exception as e:
......@@ -94,14 +98,15 @@ class Signal(BridgeH5):
if tinterval_location in f.attrs:
return f.attrs[tinterval_location][0]
else:
logging.getlogger("aliby").warn(
logging.getLogger("aliby").warn(
f"{str(self.filename).split('/')[-1]}: using default time interval of 5 minutes"
)
return 5
@staticmethod
def get_retained(df, cutoff):
"""Return a fraction of the df, one without later time points."""
"""Return rows of df with at least cutoff fraction of the total number
of time points."""
return df.loc[bn.nansum(df.notna(), axis=1) > df.shape[1] * cutoff]
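
A toy illustration of the cutoff (with the default of 0.8, a row must have data for more than 80% of time points; assumes Signal has been imported):

import numpy as np
import pandas as pd

df = pd.DataFrame([[1.0, 2.0, 3.0, 4.0, 5.0],
                   [1.0, np.nan, np.nan, np.nan, np.nan]])
Signal.get_retained(df, cutoff=0.8)  # keeps only the first row (5 > 0.8 * 5)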
@property
......@@ -110,15 +115,15 @@ class Signal(BridgeH5):
with h5py.File(self.filename, "r") as f:
return list(f.attrs["channels"])
@_first_arg_str_to_df
def retained(self, signal, cutoff=0.8):
"""
Load data (via decorator) and reduce the resulting dataframe.
Load data for a signal or a list of signals and reduce the resulting
dataframes to a fraction of their original size, losing late time
points.
dataframes to rows with sufficient numbers of time points.
"""
if isinstance(signal, str):
signal = self.get_raw(signal)
if isinstance(signal, pd.DataFrame):
return self.get_retained(signal, cutoff)
elif isinstance(signal, list):
......@@ -131,17 +136,15 @@ class Signal(BridgeH5):
"""
Get lineage data from a given location in the h5 file.
Returns an array with three columns: the tile id, the mother label,
and the daughter label.
"""
if lineage_location is None:
lineage_location = "modifiers/lineage_merged"
with h5py.File(self.filename, "r") as f:
# if lineage_location not in f:
# lineage_location = lineage_location.split("_")[0]
if lineage_location not in f:
lineage_location = "postprocessing/lineage"
tile_mo_da = f[lineage_location]
if isinstance(tile_mo_da, h5py.Dataset):
lineage = tile_mo_da[()]
else:
......@@ -154,7 +157,7 @@ class Signal(BridgeH5):
).T
return lineage
@_first_arg_str_to_df
@_first_arg_str_to_raw_df
def apply_prepost(
self,
data: t.Union[str, pd.DataFrame],
......@@ -162,57 +165,40 @@ class Signal(BridgeH5):
picks: t.Union[t.Collection, bool] = True,
):
"""
Apply modifier operations (picker or merger) to a dataframe.
Apply picking and merging to a Signal data frame.
Parameters
----------
data : t.Union[str, pd.DataFrame]
DataFrame or path to one.
A data frame or a path to one.
merges : t.Union[np.ndarray, bool]
(optional) 2-D array with three columns: the tile id, the mother label, and the daughter id.
(optional) An array of pairs of (trap, cell) indices to merge.
If True, fetch merges from file.
picks : t.Union[np.ndarray, bool]
(optional) 2-D array with two columns: the tiles and
the cell labels.
(optional) An array of (trap, cell) indices.
If True, fetch picks from file.
Examples
--------
FIXME: Add docs.
"""
if isinstance(merges, bool):
merges: np.ndarray = self.load_merges() if merges else np.array([])
merges = self.load_merges() if merges else np.array([])
if merges.any():
merged = apply_merges(data, merges)
else:
merged = copy(data)
if isinstance(picks, bool):
picks = (
self.get_picks(names=merged.index.names)
self.get_picks(
names=merged.index.names, path="modifiers/picks/"
)
if picks
else set(merged.index)
else merged.index
)
with h5py.File(self.filename, "r") as f:
if "modifiers/picks" in f and picks:
if picks:
return merged.loc[
set(picks).intersection(
[tuple(x) for x in merged.index]
)
]
else:
if isinstance(merged.index, pd.MultiIndex):
empty_lvls = [[] for i in merged.index.names]
index = pd.MultiIndex(
levels=empty_lvls,
codes=empty_lvls,
names=merged.index.names,
)
else:
index = pd.Index([], name=merged.index.name)
merged = pd.DataFrame([], index=index)
return merged
if picks:
picked_indices = set(picks).intersection(
[tuple(x) for x in merged.index]
)
return merged.loc[picked_indices]
else:
return merged
@cached_property
def p_available(self):
......@@ -272,10 +258,11 @@ class Signal(BridgeH5):
Parameters
----------
dataset: str or list of strs
The name of the h5 file or a list of h5 file names
The name of the h5 file or a list of h5 file names.
in_minutes: boolean
If True,
If True, convert column headings to times in minutes.
lineage: boolean
If True, add mother_label to index.
"""
try:
if isinstance(dataset, str):
......@@ -288,15 +275,17 @@ class Signal(BridgeH5):
self.get_raw(dset, in_minutes=in_minutes, lineage=lineage)
for dset in dataset
]
if lineage: # assume that df is sorted
if lineage:
# assume that df is sorted
mother_label = np.zeros(len(df), dtype=int)
lineage = self.lineage()
a, b = validate_association(
# information on buds
valid_lineage, valid_indices = validate_lineage(
lineage,
np.array(df.index.to_list()),
match_column=1,
"daughters",
)
mother_label[b] = lineage[a, 1]
mother_label[valid_indices] = lineage[valid_lineage, 1]
df = add_index_levels(df, {"mother_label": mother_label})
return df
except Exception as e:
......@@ -316,13 +305,14 @@ class Signal(BridgeH5):
names: t.Tuple[str, ...] = ("trap", "cell_label"),
path: str = "modifiers/picks/",
) -> t.Set[t.Tuple[int, str]]:
"""Get the relevant picks based on names."""
"""Get picks from the h5 file."""
with h5py.File(self.filename, "r") as f:
picks = set()
if path in f:
picks = set(
zip(*[f[path + name] for name in names if name in f[path]])
)
else:
picks = set()
return picks
def dataset_to_df(self, f: h5py.File, path: str) -> pd.DataFrame:
......@@ -353,10 +343,7 @@ class Signal(BridgeH5):
fullname: str,
node: t.Union[h5py.Dataset, h5py.Group],
):
"""
Store the name of a signal if it is a leaf node
(a group with no more groups inside) and if it starts with extraction.
"""
"""Store the name of a signal if it is a leaf node and if it starts with extraction."""
if isinstance(node, h5py.Group) and np.all(
[isinstance(x, h5py.Dataset) for x in node.values()]
):
......
......@@ -9,6 +9,152 @@ This can be:
import numpy as np
import typing as t
# data type to link together trap and cell ids
i_dtype = {"names": ["trap_id", "cell_id"], "formats": [np.int64, np.int64]}
def validate_lineage(
lineage: np.ndarray, indices: np.ndarray, how: str = "families"
):
"""
Identify mother-bud pairs that exist both in lineage and a Signal's
indices.
We expect the lineage information to be unique: a bud should not have
two mothers.
Parameters
----------
lineage : np.ndarray
2D array of lineage associations where columns are
(trap, mother, daughter)
or
a 3D array, which is an array of 2 X 2 arrays comprising
[[trap_id, mother_label], [trap_id, daughter_label]].
indices : np.ndarray
A 2D array of cell indices from a Signal, (trap_id, cell_label).
This array should not include mother_label.
how: str
If "mothers", matches indicate mothers from mother-bud pairs;
If "daughters", matches indicate daughters from mother-bud pairs;
If "families", matches indicate mothers and daughters in mother-bud pairs.
Returns
-------
valid_lineage: boolean np.ndarray
1D array indicating matched elements in lineage.
valid_indices: boolean np.ndarray
1D array indicating matched elements in indices.
Examples
--------
>>> import numpy as np
>>> from agora.utils.indexing import validate_lineage
>>> lineage = np.array([ [[0, 1], [0, 3]], [[0, 1], [0, 4]], [[0, 1], [0, 6]], [[0, 4], [0, 7]] ])
>>> indices = np.array([ [0, 1], [0, 2], [0, 3]])
>>> valid_lineage, valid_indices = validate_lineage(lineage, indices)
>>> print(valid_lineage)
array([ True, False, False, False])
>>> print(valid_indices)
array([ True, False, True])
and
>>> lineage = np.array([[[0,3], [0,1]], [[0,2], [0,4]]])
>>> indices = np.array([[0,1], [0,2], [0,3]])
>>> valid_lineage, valid_indices = validate_lineage(lineage, indices)
>>> print(valid_lineage)
array([ True, False])
>>> print(valid_indices)
array([ True, False, True])
"""
if lineage.ndim == 2:
# [trap, mother, daughter] becomes [[trap, mother], [trap, daughter]]
lineage = _assoc_indices_to_3d(lineage)
if how == "mothers":
c_index = 0
elif how == "daughters":
c_index = 1
# find valid lineage
valid_lineages = index_isin(lineage, indices)
if how == "families":
# both mother and bud must be in indices
valid_lineage = valid_lineages.all(axis=1)
else:
valid_lineage = valid_lineages[:, c_index, :]
flat_valid_lineage = valid_lineage.flatten()
# find valid indices
selected_lineages = lineage[flat_valid_lineage, ...]
if how == "families":
# select only pairs of mother and bud indices
valid_indices = index_isin(indices, selected_lineages)
else:
valid_indices = index_isin(indices, selected_lineages[:, c_index, :])
flat_valid_indices = valid_indices.flatten()
if (
indices[flat_valid_indices, :].size
!= np.unique(
lineage[flat_valid_lineage, :].reshape(-1, 2), axis=0
).size
):
# all unique indices in valid_lineages should be in valid_indices
raise Exception(
"Error in validate_lineage: "
"lineage information is likely not unique."
)
return flat_valid_lineage, flat_valid_indices
def index_isin(x: np.ndarray, y: np.ndarray) -> np.ndarray:
"""
Find those elements of x that are in y.
Both arrays must be arrays of integer indices,
such as (trap_id, cell_id).
"""
x = np.ascontiguousarray(x, dtype=np.int64)
y = np.ascontiguousarray(y, dtype=np.int64)
xv = x.view(i_dtype)
inboth = np.intersect1d(xv, y.view(i_dtype))
x_bool = np.isin(xv, inboth)
return x_bool
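
A toy example of the row-wise membership test (each row is a (trap_id, cell_id) pair):

import numpy as np

x = np.array([[0, 1], [0, 2], [1, 1]])
y = np.array([[0, 2], [1, 1]])
index_isin(x, y).flatten()  # array([False,  True,  True])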
def _assoc_indices_to_3d(ndarray: np.ndarray):
"""
Convert the last column to a new row and repeat first column's values.
For example: [trap, mother, daughter] becomes
[[trap, mother], [trap, daughter]].
Assumes the input array has shape (N,3).
"""
result = ndarray
if len(ndarray) and ndarray.ndim > 1:
# faster indexing for single positions
if ndarray.shape[1] == 3:
result = np.transpose(
np.hstack((ndarray[:, [0]], ndarray)).reshape(-1, 2, 2),
axes=[0, 2, 1],
)
else:
# 20% slower but more general indexing
columns = np.arange(ndarray.shape[1])
result = np.stack(
(
ndarray[:, np.delete(columns, -1)],
ndarray[:, np.delete(columns, -2)],
),
axis=1,
)
return result
###################################################################
def validate_association(
association: np.ndarray,
......@@ -104,38 +250,8 @@ def validate_association(
return valid_association, valid_indices
def _assoc_indices_to_3d(ndarray: np.ndarray):
"""
Convert the last column to a new row while repeating all previous indices.
This is useful when converting a signal multiindex before comparing association.
Assumes the input array has shape (N,3)
"""
result = ndarray
if len(ndarray) and ndarray.ndim > 1:
if ndarray.shape[1] == 3: # Faster indexing for single positions
result = np.transpose(
np.hstack((ndarray[:, [0]], ndarray)).reshape(-1, 2, 2),
axes=[0, 2, 1],
)
else: # 20% slower but more general indexing
columns = np.arange(ndarray.shape[1])
result = np.stack(
(
ndarray[:, np.delete(columns, -1)],
ndarray[:, np.delete(columns, -2)],
),
axis=1,
)
return result
def _3d_index_to_2d(array: np.ndarray):
"""
Opposite to _assoc_indices_to_3d.
"""
"""Revert _assoc_indices_to_3d."""
result = array
if len(array):
result = np.concatenate(
......
#!/usr/bin/env jupyter
"""
Utilities based on association are used to efficiently acquire indices of
tracklets with some kind of relationship.
This can be:
- Cells that are to be merged.
- Cells that have a lineage relationship.
"""
import numpy as np
import typing as t
def validate_association(
association: np.ndarray,
indices: np.ndarray,
match_column: t.Optional[int] = None,
) -> t.Tuple[np.ndarray, np.ndarray]:
"""
Identify mother-bud pairs that exist both in lineage and a Signal's indices.
"""
if association.ndim == 2:
# reshape into 3D array for broadcasting
# for each trap, [trap, mother, daughter] becomes
# [[trap, mother], [trap, daughter]]
association = _assoc_indices_to_3d(association)
valid_association, valid_indices = validate_lineage(association, indices)
# Alan's working code
# Compare existing association with available indices
# Swap trap and label axes for the association array to correctly cast
valid_ndassociation_a = association[..., None] == indices.T[None, ...]
# Broadcasting is confusing (but efficient):
# First we check the dimension across trap and cell id, to ensure both match
valid_cell_ids_a = valid_ndassociation_a.all(axis=2)
if match_column is None:
# Then we check the merge tuples to check which cases have both target and source
valid_association_a = valid_cell_ids_a.any(axis=2).all(axis=1)
# Finally we check the dimension that crosses all indices, to ensure the pair
# is present in a valid merge event.
valid_indices_a = (
valid_ndassociation_a[valid_association_a]
.all(axis=2)
.any(axis=(0, 1))
)
else: # We fetch specific indices if we aim for the ones with one present
valid_indices_a = valid_cell_ids_a[:, match_column].any(axis=0)
# Valid association then becomes a boolean array, true means that there is a
# match (match_column) between that cell and the index
valid_association_a = (
valid_cell_ids_a[:, match_column] & valid_indices
).any(axis=1)
assert np.array_equal(
valid_association, valid_association_a
), "valid_association error"
assert np.array_equal(
valid_indices, valid_indices_a
), "valid_indices error"
return valid_association, valid_indices
def validate_association_old(
association: np.ndarray,
indices: np.ndarray,
match_column: t.Optional[int] = None,
) -> t.Tuple[np.ndarray, np.ndarray]:
"""
Identify mother-bud pairs that exist both in lineage and a Signal's indices.
Parameters
----------
association : np.ndarray
2D array of lineage associations where columns are (trap, mother, daughter)
or
a 3D array, which is an array of 2 X 2 arrays comprising [[trap_id, mother_label], [trap_id, daughter_label]].
indices : np.ndarray
A 2D array where each column is a different level, such as (trap_id, cell_label), which typically is an index of a Signal
dataframe. This array should not include mother_label.
match_column: int
If 0, matches indicate mothers from mother-bud pairs;
If 1, matches indicate daughters from mother-bud pairs;
If None, matches indicate either mothers or daughters in mother-bud pairs.
Returns
-------
valid_association: boolean np.ndarray
1D array indicating elements in association with matches.
valid_indices: boolean np.ndarray
1D array indicating elements in indices with matches.
Examples
--------
>>> import numpy as np
>>> from agora.utils.indexing import validate_association
>>> association = np.array([ [[0, 1], [0, 3]], [[0, 1], [0, 4]], [[0, 1], [0, 6]], [[0, 4], [0, 7]] ])
>>> indices = np.array([ [0, 1], [0, 2], [0, 3]])
>>> print(indices.T)
>>> valid_association, valid_indices = validate_association(association, indices)
>>> print(valid_association)
array([ True, False, False, False])
>>> print(valid_indices)
array([ True, False, True])
and
>>> association = np.array([[[0,3], [0,1]], [[0,2], [0,4]]])
>>> indices = np.array([[0,1], [0,2], [0,3]])
>>> valid_association, valid_indices = validate_association(association, indices)
>>> print(valid_association)
array([ True, False])
>>> print(valid_indices)
array([ True, False, True])
"""
if association.ndim == 2:
# reshape into 3D array for broadcasting
# for each trap, [trap, mother, daughter] becomes
# [[trap, mother], [trap, daughter]]
association = _assoc_indices_to_3d(association)
# use broadcasting to compare association with indices
# swap trap and cell_label axes for correct broadcasting
indicesT = indices.T
# compare each of [[trap, mother], [trap, daughter]] for all traps
# in association with [trap, cell_label] for all traps in indices
# association is no_traps x 2 x 2; indices is no_traps X 2
# valid_ndassociation is no_traps_association x 2 x 2 x no_traps_indices
valid_ndassociation = (
association[..., np.newaxis] == indicesT[np.newaxis, ...]
)
# find matches in association
###
# make True comparisons with both trap_ids and cell labels matching
# compare trap_ids and cell_ids for each pair of traps
valid_cell_ids = valid_ndassociation.all(axis=2)
if match_column is None:
# make True comparisons match at least one row in indices
# at least one cell_id matches
va_intermediate = valid_cell_ids.any(axis=2)
# make True comparisons have both mother and bud matching rows in indices
valid_association = va_intermediate.all(axis=1)
else:
# match_column selects mothers if 0 and daughters if 1
# make True match at least one row in indices
valid_association = valid_cell_ids[:, match_column].any(axis=1)
# find matches in indices
###
# make True comparisons have a validated association for both the mother and bud
# make True comparisons have both trap_ids and cell labels matching
valid_cell_ids_va = valid_ndassociation[valid_association].all(axis=2)
if match_column is None:
# make True comparisons match either a mother or a bud in association
valid_indices = valid_cell_ids_va.any(axis=(0, 1))
else:
valid_indices = valid_cell_ids_va[:, match_column][0]
# Alan's working code
# Compare existing association with available indices
# Swap trap and label axes for the association array to correctly cast
valid_ndassociation_a = association[..., None] == indices.T[None, ...]
# Broadcasting is confusing (but efficient):
# First we check the dimension across trap and cell id, to ensure both match
valid_cell_ids_a = valid_ndassociation_a.all(axis=2)
if match_column is None:
# Then we check the merge tuples to check which cases have both target and source
valid_association_a = valid_cell_ids_a.any(axis=2).all(axis=1)
# Finally we check the dimension that crosses all indices, to ensure the pair
# is present in a valid merge event.
valid_indices_a = (
valid_ndassociation_a[valid_association_a]
.all(axis=2)
.any(axis=(0, 1))
)
else: # We fetch specific indices if we aim for the ones with one present
valid_indices_a = valid_cell_ids_a[:, match_column].any(axis=0)
# Valid association then becomes a boolean array, true means that there is a
# match (match_column) between that cell and the index
valid_association_a = (
valid_cell_ids_a[:, match_column] & valid_indices
).any(axis=1)
assert np.array_equal(
valid_association, valid_association_a
), "valid_association error"
assert np.array_equal(
valid_indices, valid_indices_a
), "valid_indices error"
return valid_association, valid_indices
def _assoc_indices_to_3d(ndarray: np.ndarray):
"""
Reorganise an array of shape (N, 3) into one of shape (N, 2, 2).
Reorganise an array so that the last entry of each row is removed
and generates a new row. This new row retains all other entries of
the original row.
Example:
[ [0, 1, 3], [0, 1, 4] ]
becomes
[ [[0, 1], [0, 3]], [[0, 1], [0, 4]] ]
"""
result = ndarray
if len(ndarray) and ndarray.ndim > 1:
if ndarray.shape[1] == 3:
# faster indexing for single positions
result = np.transpose(
np.hstack((ndarray[:, [0]], ndarray)).reshape(-1, 2, 2),
axes=[0, 2, 1],
)
else:
# 20% slower, but more general indexing
columns = np.arange(ndarray.shape[1])
result = np.stack(
(
ndarray[:, np.delete(columns, -1)],
ndarray[:, np.delete(columns, -2)],
),
axis=1,
)
return result
def _3d_index_to_2d(array: np.ndarray):
"""Revert switch from _assoc_indices_to_3d."""
result = array
if len(array):
result = np.concatenate(
(array[:, 0, :], array[:, 1, 1, np.newaxis]), axis=1
)
return result
def compare_indices(x: np.ndarray, y: np.ndarray) -> np.ndarray:
"""
Compare two 2D arrays using broadcasting.
Return a binary array where a True value links two cells where
all cells are the same.
"""
return (x[..., np.newaxis] == y.T[np.newaxis, ...]).all(axis=1)
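
For example, compare_indices pairs every row of x against every row of y:

import numpy as np

x = np.array([[0, 1], [0, 2]])
y = np.array([[0, 2], [1, 3]])
compare_indices(x, y)
# array([[False, False],
#        [ True, False]])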
......@@ -86,16 +86,19 @@ def bidirectional_retainment_filter(
daughters_thresh: int = 7,
) -> pd.DataFrame:
"""
Retrieve families where mothers are present for more than a fraction
of the experiment and daughters for longer than some number of
time-points.
Parameters
----------
df: pd.DataFrame
Data
mothers_thresh: float
Minimum fraction of experiment's total duration for which mothers
must be present.
daughters_thresh: int
Minimum number of time points for which daughters must be observed.
"""
# daughters
all_daughters = df.loc[df.index.get_level_values("mother_label") > 0]
......@@ -170,6 +173,7 @@ def slices_from_spans(spans: t.Tuple[int], df: pd.DataFrame) -> t.List[slice]:
def drop_mother_label(index: pd.MultiIndex) -> np.ndarray:
"""Remove mother_label level from a MultiIndex."""
no_mother_label = index
if "mother_label" in index.names:
no_mother_label = index.droplevel("mother_label")
......
#!/usr/bin/env python3
import re
import typing as t
import numpy as np
import pandas as pd
from agora.io.bridge import groupsort
from itertools import groupby
def mb_array_to_dict(mb_array: np.ndarray):
......@@ -19,4 +15,3 @@ def mb_array_to_dict(mb_array: np.ndarray):
for trap, mo_da in groupsort(mb_array).items()
for mo, daughters in groupsort(mo_da).items()
}
......@@ -3,90 +3,159 @@
Functions to efficiently merge rows in DataFrames.
"""
import typing as t
from copy import copy
import numpy as np
import pandas as pd
from utils_find_1st import cmp_larger, find_1st
from agora.utils.indexing import compare_indices, validate_association
from agora.utils.indexing import index_isin
def group_merges(merges: np.ndarray) -> t.List[t.Tuple]:
"""
Convert merges into a list of merges for traps requiring multiple
merges and then for traps requiring single merges.
"""
left_tracks = merges[:, 0]
right_tracks = merges[:, 1]
# find traps requiring multiple merges
linr = merges[index_isin(left_tracks, right_tracks).flatten(), :]
rinl = merges[index_isin(right_tracks, left_tracks).flatten(), :]
# make unique and order merges for each trap
multi_merge = np.unique(np.concatenate((linr, rinl)), axis=0)
# find traps requiring a single merge
single_merge = merges[
~index_isin(merges, multi_merge).all(axis=1).flatten(), :
]
# convert to lists of arrays
single_merge_list = [[sm] for sm in single_merge]
multi_merge_list = [
multi_merge[multi_merge[:, 0, 0] == trap_id, ...]
for trap_id in np.unique(multi_merge[:, 0, 0])
]
res = [*multi_merge_list, *single_merge_list]
return res
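
A small example of the grouping, where merges is an (N, 2, 2) array of (left track, right track) pairs of (trap, cell) indices:

import numpy as np

merges = np.array([
    [[0, 1], [0, 2]],  # trap 0: track 1 merges into track 2
    [[0, 2], [0, 5]],  # trap 0: track 2 merges into track 5 (chained)
    [[3, 1], [3, 8]],  # trap 3: an independent single merge
])
group_merges(merges)
# -> [the two chained trap-0 merges as one array, then [the trap-3 merge]]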
def merge_lineage(
lineage: np.ndarray, merges: np.ndarray
) -> (np.ndarray, np.ndarray):
"""
Use merges to update lineage information.
Check if merging causes any buds to have multiple mothers and discard
those incorrect merges.
Return updated lineage and merge arrays.
"""
flat_lineage = lineage.reshape(-1, 2)
bud_mother_dict = {
tuple(bud): mother for bud, mother in zip(lineage[:, 1], lineage[:, 0])
}
left_tracks = merges[:, 0]
# find left tracks that are in lineages
valid_lineages = index_isin(flat_lineage, left_tracks).flatten()
# group into multi- and then single merges
grouped_merges = group_merges(merges)
# perform merges
if valid_lineages.any():
# indices of each left track -> indices of rightmost right track
replacement_dict = {
tuple(contig_pair[0]): merge[-1][1]
for merge in grouped_merges
for contig_pair in merge
}
# if both key and value are buds, they must have the same mother
buds = lineage[:, 1]
incorrect_merges = [
key
for key in replacement_dict
if np.any(index_isin(buds, replacement_dict[key]).flatten())
and np.any(index_isin(buds, key).flatten())
and not np.array_equal(
bud_mother_dict[key],
bud_mother_dict[tuple(replacement_dict[key])],
)
]
if incorrect_merges:
# reassign incorrect merges so that they have no effect
for key in incorrect_merges:
replacement_dict[key] = key
# find only correct merges
new_merges = merges[
~index_isin(
merges[:, 0], np.array(incorrect_merges)
).flatten(),
...,
]
else:
new_merges = merges
# correct lineage information
# replace mother or bud index with index of rightmost track
flat_lineage[valid_lineages] = [
replacement_dict[tuple(index)]
for index in flat_lineage[valid_lineages]
]
# reverse flattening
new_lineage = flat_lineage.reshape(-1, 2, 2)
# remove any duplicates
new_lineage = np.unique(new_lineage, axis=0)
return new_lineage, new_merges
def apply_merges(data: pd.DataFrame, merges: np.ndarray):
"""Split data in two, one subset for rows relevant for merging and one
without them. It uses an array of source tracklets and target tracklets
to efficiently merge them.
"""
Generate a new data frame containing merged tracks.
Parameters
----------
data : pd.DataFrame
Input DataFrame.
A Signal data frame.
merges : np.ndarray
3-D ndarray where dimensions are (X,2,2): nmerges, source-target
pair and single-cell identifiers, respectively.
Examples
--------
FIXME: Add docs.
An array of pairs of (trap, cell) indices to merge.
"""
indices = data.index
if "mother_label" in indices.names:
indices = indices.droplevel("mother_label")
valid_merges, indices = validate_association(
merges, np.array(list(indices))
)
# Assign non-merged
merged = data.loc[~indices]
# Implement the merges and drop source rows.
# TODO Use matrices to perform merges in batch
# for efficiency
indices = np.array(list(indices))
# merges in the data frame's indices
valid_merges = index_isin(merges, indices).all(axis=1).flatten()
# corresponding indices for the data frame in merges
selected_merges = merges[valid_merges, ...]
valid_indices = index_isin(indices, selected_merges).flatten()
# data not requiring merging
merged = data.loc[~valid_indices]
# merge tracks
if valid_merges.any():
to_merge = data.loc[indices]
targets, sources = zip(*merges[valid_merges])
for source, target in zip(sources, targets):
target = tuple(target)
to_merge.loc[target] = join_tracks_pair(
to_merge.loc[target].values,
to_merge.loc[tuple(source)].values,
to_merge = data.loc[valid_indices].copy()
left_indices = merges[:, 0]
right_indices = merges[:, 1]
# join left track with right track
for left_index, right_index in zip(left_indices, right_indices):
to_merge.loc[tuple(left_index)] = join_two_tracks(
to_merge.loc[tuple(left_index)].values,
to_merge.loc[tuple(right_index)].values,
)
to_merge.drop(map(tuple, sources), inplace=True)
# drop indices for right tracks
to_merge.drop(map(tuple, right_indices), inplace=True)
# add to data not requiring merges
merged = pd.concat((merged, to_merge), names=data.index.names)
return merged
def join_tracks_pair(target: np.ndarray, source: np.ndarray) -> np.ndarray:
"""
Join two tracks and return the new value of the target.
"""
target_copy = target
end = find_1st(target_copy[::-1], 0, cmp_larger)
target_copy[-end:] = source[-end:]
return target_copy
def group_merges(merges: np.ndarray) -> t.List[t.Tuple]:
# Return a list where the cell is present as source and target
# (multimerges)
sources_targets = compare_indices(merges[:, 0, :], merges[:, 1, :])
is_multimerge = sources_targets.any(axis=0) | sources_targets.any(axis=1)
is_monomerge = ~is_multimerge
multimerge_subsets = union_find(zip(*np.where(sources_targets)))
merge_groups = [merges[np.array(tuple(x))] for x in multimerge_subsets]
def join_two_tracks(
left_track: np.ndarray, right_track: np.ndarray
) -> np.ndarray:
"""Join two tracks and return the new one."""
new_track = left_track.copy()
# find last positive element by inverting track
end = find_1st(left_track[::-1], 0, cmp_larger)
# merge tracks into one
new_track[-end:] = right_track[-end:]
return new_track
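
For instance, if the left track ends after three time points, the right track fills in the remainder:

import numpy as np

left = np.array([1.1, 1.3, 1.5, 0.0, 0.0])
right = np.array([0.0, 0.0, 0.0, 1.6, 1.8])
join_two_tracks(left, right)  # array([1.1, 1.3, 1.5, 1.6, 1.8])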
sorted_merges = list(map(sort_association, merge_groups))
# Ensure that source and target are at the edges
return [
*sorted_merges,
*[[event] for event in merges[is_monomerge]],
]
##################################################################
def union_find(lsts):
......@@ -120,27 +189,3 @@ def sort_association(array: np.ndarray):
[res.append(x) for x in np.flip(order).flatten() if x not in res]
sorted_array = array[np.array(res)]
return sorted_array
def merge_association(
association: np.ndarray, merges: np.ndarray
) -> np.ndarray:
grouped_merges = group_merges(merges)
flat_indices = association.reshape(-1, 2)
comparison_mat = compare_indices(merges[:, 0], flat_indices)
valid_indices = comparison_mat.any(axis=0)
if valid_indices.any(): # Where valid, perform transformation
replacement_d = {}
for dataset in grouped_merges:
for k in dataset:
replacement_d[tuple(k[0])] = dataset[-1][1]
flat_indices[valid_indices] = [
replacement_d[tuple(i)] for i in flat_indices[valid_indices]
]
merged_indices = flat_indices.reshape(-1, 2, 2)
return merged_indices
......@@ -54,7 +54,7 @@ class DatasetLocalABC(ABC):
Abstract Base class to find local files, either OME-XML or raw images.
"""
_valid_suffixes = ("tiff", "png", "zarr")
_valid_suffixes = ("tiff", "png", "zarr", "tif")
_valid_meta_suffixes = ("txt", "log")
def __init__(self, dpath: t.Union[str, Path], *args, **kwargs):
......
......@@ -30,14 +30,14 @@ from agora.io.metadata import dir_to_meta, dispatch_metadata_parser
def get_examples_dir():
"""Get examples directory which stores dummy image for tiler"""
"""Get examples directory that stores dummy image for tiler."""
return files("aliby").parent.parent / "examples" / "tiler"
def instantiate_image(
source: t.Union[str, int, t.Dict[str, str], Path], **kwargs
):
"""Wrapper to instatiate the appropiate image
"""Wrapper to instantiate the appropriate image
Parameters
----------
......@@ -55,26 +55,26 @@ def instantiate_image(
def dispatch_image(source: t.Union[str, int, t.Dict[str, str], Path]):
"""
Wrapper to pick the appropiate Image class depending on the source of data.
"""
"""Pick the appropriate Image class depending on the source of data."""
if isinstance(source, (int, np.int64)):
from aliby.io.omero import Image
instatiator = Image
instantiator = Image
elif isinstance(source, dict) or (
isinstance(source, (str, Path)) and Path(source).is_dir()
):
if Path(source).suffix == ".zarr":
instatiator = ImageZarr
instantiator = ImageZarr
else:
instatiator = ImageDir
instantiator = ImageDir
elif isinstance(source, Path) and source.is_file():
# my addition
instantiator = ImageLocalOME
elif isinstance(source, str) and Path(source).is_file():
instatiator = ImageLocalOME
instantiator = ImageLocalOME
else:
raise Exception(f"Invalid data source at {source}")
return instatiator
return instantiator
class BaseLocalImage(ABC):
......@@ -82,6 +82,7 @@ class BaseLocalImage(ABC):
Base Image class to set path and provide context management method.
"""
# default image order
_default_dimorder = "tczyx"
def __init__(self, path: t.Union[str, Path]):
......@@ -98,8 +99,7 @@ class BaseLocalImage(ABC):
return False
def rechunk_data(self, img):
# Format image using x and y size from metadata.
"""Format image using x and y size from metadata."""
self._rechunked_img = da.rechunk(
img,
chunks=(
......@@ -145,16 +145,16 @@ class ImageLocalOME(BaseLocalImage):
in which a multidimensional tiff image contains the metadata.
"""
def __init__(self, path: str, dimorder=None):
def __init__(self, path: str, dimorder=None, **kwargs):
super().__init__(path)
self._id = str(path)
self.set_meta(str(path))
def set_meta(self):
def set_meta(self, path):
meta = dict()
try:
with TiffFile(path) as f:
self._meta = xmltodict.parse(f.ome_metadata)["OME"]
for dim in self.dimorder:
meta["size_" + dim.lower()] = int(
self._meta["Image"]["Pixels"]["@Size" + dim]
......@@ -165,21 +165,19 @@ class ImageLocalOME(BaseLocalImage):
]
meta["name"] = self._meta["Image"]["@Name"]
meta["type"] = self._meta["Image"]["Pixels"]["@Type"]
except Exception as e: # Images not in OMEXML
except Exception as e:
# images not in OMEXML
print("Warning:Metadata not found: {}".format(e))
print(
f"Warning: No dimensional info provided. Assuming {self._default_dimorder}"
"Warning: No dimensional info provided. "
f"Assuming {self._default_dimorder}"
)
# Mark non-existent dimensions for padding
# mark non-existent dimensions for padding
self.base = self._default_dimorder
# self.ids = [self.index(i) for i in dimorder]
self._dimorder = base
self._dimorder = self.base
self._meta = meta
# self._meta["name"] = Path(path).name.split(".")[0]
@property
def name(self):
......@@ -246,7 +244,7 @@ class ImageDir(BaseLocalImage):
It inherits from BaseLocalImage so we only override methods that are critical.
Assumptions:
- One folders per position.
- One folder per position.
- Images are flat.
- Channel, Time, z-stack and the others are determined by filenames.
- Provides Dimorder as it is set in the filenames, or expects order during instantiation
......@@ -318,7 +316,7 @@ class ImageZarr(BaseLocalImage):
print(f"Could not add size info to metadata: {e}")
def get_data_lazy(self) -> da.Array:
"""Return 5D dask array. For lazy-loading local multidimensional zarr files"""
"""Return 5D dask array for lazy-loading local multidimensional zarr files."""
return self._img
def add_size_to_meta(self):
......
"""
Tiler: Divides images into smaller tiles.
The tasks of the Tiler are selecting regions of interest, or tiles, of
images - with one trap per tile, correcting for the drift of the microscope
stage over time, and handling errors and bridging between the image data
and Aliby’s image-processing steps.

Tiler subclasses deal with either network connections or local files.

To find tiles, we use a two-step process: we analyse the bright-field image
to produce the template of a trap, and we fit this template to the image to
find the tiles' centres.

We use texture-based segmentation (entropy) to split the image into
foreground -- cells and traps -- and background, which we then identify with
an Otsu filter. Two methods are used to produce a template trap from these
regions: pick the trap with the smallest minor axis length and average over
all validated traps.

A peak-identifying algorithm recovers the x and y-axis location of traps in
the original image, and we choose the approach to template that identifies
the most tiles.

The experiment is stored as an array with a standard indexing order of
(Time, Channels, Z-stack, X, Y).
"""
import logging
import re
......@@ -169,7 +181,7 @@ class TileLocations:
return cls(initial_location, tile_size, max_size, drifts=[])
@classmethod
def read_hdf5(cls, file):
def read_h5(cls, file):
"""Instantiate from a h5 file."""
with h5py.File(file, "r") as hfile:
tile_info = hfile["trap_info"]
......@@ -316,7 +328,7 @@ class Tiler(StepABC):
Path to a directory of h5 files
parameters: an instance of TileParameters (optional)
"""
tile_locs = TileLocations.read_hdf5(filepath)
tile_locs = TileLocations.read_h5(filepath)
metadata = BridgeH5(filepath).meta_h5
metadata["channels"] = image.metadata["channels"]
if parameters is None:
......@@ -332,7 +344,7 @@ class Tiler(StepABC):
return tiler
@lru_cache(maxsize=2)
def get_tc(self, t: int, c: int) -> np.ndarray:
def get_tc(self, tp: int, c: int) -> np.ndarray:
"""
Load image using dask.
......@@ -345,7 +357,7 @@ class Tiler(StepABC):
Parameters
----------
t: integer
tp: integer
An index for a time point
c: integer
An index for a channel
......@@ -354,10 +366,10 @@ class Tiler(StepABC):
-------
full: an array of images
"""
full = self.image[t, c]
if hasattr(full, "compute"): # If using dask fetch images here
full = self.image[tp, c]
if hasattr(full, "compute"):
# if using dask fetch images
full = full.compute(scheduler="synchronous")
return full
@property
......@@ -558,9 +570,8 @@ class Tiler(StepABC):
Returns
-------
res: array
Data arranged as (tiles, channels, time points, X, Y, Z)
Data arranged as (tiles, channels, Z, X, Y)
"""
# FIXME add support for sub-tiling a tile
# FIXME can we ignore z
if channels is None:
channels = [0]
......@@ -571,8 +582,7 @@ class Tiler(StepABC):
for c in channels:
# only return requested z
val = self.get_tp_data(tp, c)[:, z]
# starts with the order: tiles, z, y, x
# returns the order: tiles, C, T, Z, X, Y
# starts with the order: tiles, Z, Y, X
val = np.expand_dims(val, axis=1)
res.append(val)
if tile_shape is not None:
......@@ -584,7 +594,10 @@ class Tiler(StepABC):
for tile_size, ax in zip(tile_shape, res[0].shape[-3:-2])
]
)
return np.stack(res, axis=1)
# convert to array with channels as first column
# final has dimensions (tiles, channels, 1, Z, X, Y)
final = np.stack(res, axis=1)
return final
@property
def ref_channel_index(self):
......@@ -593,7 +606,10 @@ class Tiler(StepABC):
def get_channel_index(self, channel: str or int) -> int or None:
"""
Find index for channel using regex. Returns the first matched string.
Find index for channel using regex.
Return the first matched string.
If self.channels is integers (no image metadata) it returns None.
If channel is integer
......@@ -602,10 +618,8 @@ class Tiler(StepABC):
channel: string or int
The channel or index to be used.
"""
if all(map(lambda x: isinstance(x, int), self.channels)):
channel = channel if isinstance(channel, int) else None
if isinstance(channel, str):
channel = find_channel_index(self.channels, channel)
return channel
......
......@@ -193,3 +193,50 @@ def min_maj_approximation(cell_mask) -> t.Tuple[int]:
# + distance from the center of cone top to edge of cone top
maj_ax = np.round(np.max(dn) + np.sum(cone_top) / 2)
return min_ax, maj_ax
def moment_of_inertia(cell_mask, trap_image):
"""
Find moment of inertia - a measure of homogeneity.
From iopscience.iop.org/article/10.1088/1742-6596/1962/1/012028
which cites ieeexplore.ieee.org/document/1057692.
"""
# set pixels not in cell to zero
trap_image[~cell_mask] = 0
x = trap_image
if np.any(x):
# x-axis : column=x-axis
columnvec = np.arange(1, x.shape[1] + 1, 1)[:, None].T
# y-axis : row=y-axis
rowvec = np.arange(1, x.shape[0] + 1, 1)[:, None]
# find raw moments
M00 = np.sum(x)
M10 = np.sum(np.multiply(x, columnvec))
M01 = np.sum(np.multiply(x, rowvec))
# find centroid
Xm = M10 / M00
Ym = M01 / M00
# find central moments
Mu00 = M00
Mu20 = np.sum(np.multiply(x, (columnvec - Xm) ** 2))
Mu02 = np.sum(np.multiply(x, (rowvec - Ym) ** 2))
# find invariants
Eta20 = Mu20 / Mu00 ** (1 + (2 + 0) / 2)
Eta02 = Mu02 / Mu00 ** (1 + (0 + 2) / 2)
# find moments of inertia
moi = Eta20 + Eta02
return moi
else:
return np.nan
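
Written out, with $I(x,y)$ the masked pixel intensities and $(\bar{x}, \bar{y})$ the centroid, the code computes the normalised central moments

\mu_{20} = \sum_{x,y} I(x,y)\,(x - \bar{x})^2, \qquad \eta_{20} = \frac{\mu_{20}}{\mu_{00}^{2}},

and likewise $\mu_{02}$ and $\eta_{02}$, returning moi $= \eta_{20} + \eta_{02}$.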
def ratio(cell_mask, trap_image):
"""Find the median ratio between two fluorescence channels."""
if trap_image.ndim == 3 and trap_image.shape[-1] == 2:
fl_1 = trap_image[..., 0][cell_mask]
fl_2 = trap_image[..., 1][cell_mask]
div = np.median(fl_1 / fl_2)
else:
div = np.nan
return div
""" How to do the nuc Est Conv from MATLAB
"""
How to do the nuc Est Conv from MATLAB
Based on the code in MattSegCode/Matt Seg
GUI/@timelapseTraps/extractCellDataStacksParfor.m
Especially lines 342 to 399.
Especially lines 342 to 399.
This part only replicates the method to get the nuc_est_conv values
"""
import typing as t
......
# File with defaults for ease of use
import re
import typing as t
from pathlib import Path
import h5py
# should we move these functions here?
from aliby.tile.tiler import find_channel_name
......@@ -59,6 +58,7 @@ def exparams_from_meta(
for ch in extant_fluorescence_ch:
base["tree"][ch] = default_reduction_metrics
base["sub_bg"] = extant_fluorescence_ch
# additional extraction defaults if the channels are available
if "ph" in extras:
# SWAINLAB specific names
......
......@@ -44,5 +44,6 @@ def reduce_z(trap_image: np.ndarray, fun: t.Callable, axis: int = 0):
elif isinstance(fun, np.ufunc):
# optimise the reduction function if possible
return fun.reduce(trap_image, axis=axis)
else: # WARNING: Very slow, only use when no alternatives exist
else:
# WARNING: Very slow, only use when no alternatives exist
return np.apply_along_axis(fun, axis, trap_image)
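
For example, a ufunc such as np.maximum takes the optimised reduce path (a sketch, assuming a (Z, Y, X) stack):

import numpy as np

stack = np.arange(24, dtype=float).reshape(2, 3, 4)  # (Z, Y, X)
reduce_z(stack, np.maximum)  # maximum projection over Z, shape (3, 4)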
......@@ -11,8 +11,10 @@ from extraction.core.functions.math_utils import div0
"""
Load functions for analysing cells and their background.
Note that inspect.getmembers returns a list of function names
and functions, and inspect.getfullargspec returns a
function's arguments.
"""
......@@ -66,7 +68,7 @@ def load_cellfuns():
# create dict of the core functions from cell.py - these functions apply to a single mask
cell_funs = load_cellfuns_core()
# create a dict of functions that apply the core functions to an array of cell_masks
CELLFUNS = {}
CELL_FUNS = {}
for f_name, f in cell_funs.items():
if isfunction(f):
......@@ -79,27 +81,27 @@ def load_cellfuns():
# function that applies f to m and img, the trap_image
return lambda m, img: trap_apply(f, m, img)
CELLFUNS[f_name] = tmp(f)
return CELLFUNS
CELL_FUNS[f_name] = tmp(f)
return CELL_FUNS
def load_trapfuns():
"""Load functions that are applied to an entire tile."""
TRAPFUNS = {
TRAP_FUNS = {
f[0]: f[1]
for f in getmembers(trap)
if isfunction(f[1])
and f[1].__module__.startswith("extraction.core.functions")
}
return TRAPFUNS
return TRAP_FUNS
def load_funs():
"""Combine all automatically loaded functions."""
CELLFUNS = load_cellfuns()
TRAPFUNS = load_trapfuns()
CELL_FUNS = load_cellfuns()
TRAP_FUNS = load_trapfuns()
# return dict of cell funs, dict of trap funs, and dict of both
return CELLFUNS, TRAPFUNS, {**TRAPFUNS, **CELLFUNS}
return CELL_FUNS, TRAP_FUNS, {**TRAP_FUNS, **CELL_FUNS}
def load_redfuns(
......