Compare revisions

Changes are shown as if the source revision was being merged into the target revision.
Commits on Source (124)
Showing with 3650 additions and 2487 deletions
Source diff could not be displayed: it is too large.
...@@ -33,7 +33,7 @@ pathos = "^0.2.8" # Lambda-friendly multithreading ...@@ -33,7 +33,7 @@ pathos = "^0.2.8" # Lambda-friendly multithreading
p-tqdm = "^1.3.3" p-tqdm = "^1.3.3"
pandas = ">=1.3.3" pandas = ">=1.3.3"
py-find-1st = "^1.1.5" # Fast indexing py-find-1st = "^1.1.5" # Fast indexing
scikit-learn = ">=1.0.2" # Used for an extraction metric scikit-learn = ">=1.0.2, <1.3" # Used for an extraction metric
scipy = ">=1.7.3" scipy = ">=1.7.3"
# Pipeline + I/O # Pipeline + I/O
...@@ -46,14 +46,11 @@ xmltodict = "^0.13.0" # read ome-tiff metadata ...@@ -46,14 +46,11 @@ xmltodict = "^0.13.0" # read ome-tiff metadata
zarr = "^2.14.0" zarr = "^2.14.0"
GitPython = "^3.1.27" GitPython = "^3.1.27"
h5py = "2.10" # File I/O h5py = "2.10" # File I/O
aliby-baby = "^0.1.17"
# Networking # Networking
omero-py = { version = ">=5.6.2", optional = true } # contact omero server omero-py = { version = ">=5.6.2", optional = true } # contact omero server
# Baby segmentation
aliby-baby = {version = "^0.1.17", optional=true}
# Postprocessing # Postprocessing
[tool.poetry.group.pp.dependencies] [tool.poetry.group.pp.dependencies]
leidenalg = "^0.8.8" leidenalg = "^0.8.8"
...@@ -113,7 +110,6 @@ grid-strategy = {version = "^0.0.1", optional=true} ...@@ -113,7 +110,6 @@ grid-strategy = {version = "^0.0.1", optional=true}
[tool.poetry.extras] [tool.poetry.extras]
omero = ["omero-py"] omero = ["omero-py"]
baby = ["aliby-baby"]
[tool.black] [tool.black]
line-length = 79 line-length = 79
......
...@@ -7,7 +7,7 @@ from pathlib import Path ...@@ -7,7 +7,7 @@ from pathlib import Path
from time import perf_counter from time import perf_counter
from typing import Union from typing import Union
from flatten_dict import flatten from flatten_dict import flatten, unflatten
from yaml import dump, safe_load from yaml import dump, safe_load
from agora.logging import timer from agora.logging import timer
...@@ -17,16 +17,14 @@ atomic = t.Union[int, float, str, bool] ...@@ -17,16 +17,14 @@ atomic = t.Union[int, float, str, bool]
class ParametersABC(ABC): class ParametersABC(ABC):
""" """
Defines parameters as attributes and allows parameters to Define parameters typically for a step in the pipeline.
be converted to either a dictionary or to yaml.
Outputs can be either a dict or yaml.
No attribute should be called "parameters"! No attribute should be called "parameters"!
""" """
def __init__(self, **kwargs): def __init__(self, **kwargs):
""" """Define parameters as attributes."""
Defines parameters as attributes
"""
assert ( assert (
"parameters" not in kwargs "parameters" not in kwargs
), "No attribute should be named parameters" ), "No attribute should be named parameters"
...@@ -35,8 +33,9 @@ class ParametersABC(ABC): ...@@ -35,8 +33,9 @@ class ParametersABC(ABC):
def to_dict(self, iterable="null") -> t.Dict: def to_dict(self, iterable="null") -> t.Dict:
""" """
Recursive function to return a nested dictionary of the Return a nested dictionary of the attributes of the class instance.
attributes of the class instance.
Use recursion.
""" """
if isinstance(iterable, dict): if isinstance(iterable, dict):
if any( if any(
...@@ -47,9 +46,11 @@ class ParametersABC(ABC): ...@@ -47,9 +46,11 @@ class ParametersABC(ABC):
] ]
): ):
return { return {
k: v.to_dict() k: (
if hasattr(v, "to_dict") v.to_dict()
else self.to_dict(v) if hasattr(v, "to_dict")
else self.to_dict(v)
)
for k, v in iterable.items() for k, v in iterable.items()
} }
else: else:
...@@ -62,7 +63,8 @@ class ParametersABC(ABC): ...@@ -62,7 +63,8 @@ class ParametersABC(ABC):
def to_yaml(self, path: Union[Path, str] = None): def to_yaml(self, path: Union[Path, str] = None):
""" """
Returns a yaml stream of the attributes of the class instance. Return a yaml stream of the attributes of the class instance.
If path is provided, the yaml stream is saved there. If path is provided, the yaml stream is saved there.
Parameters Parameters
...@@ -77,20 +79,19 @@ class ParametersABC(ABC): ...@@ -77,20 +79,19 @@ class ParametersABC(ABC):
@classmethod @classmethod
def from_dict(cls, d: dict): def from_dict(cls, d: dict):
"""Initialise from a dict of parameters."""
return cls(**d) return cls(**d)
@classmethod @classmethod
def from_yaml(cls, source: Union[Path, str]): def from_yaml(cls, source: Union[Path, str]):
""" """Initialise from a yaml filename or stdin."""
Returns instance from a yaml filename or stdin
"""
is_buffer = True is_buffer = True
try: try:
if Path(source).exists(): if Path(source).exists():
is_buffer = False is_buffer = False
except Exception as _: except Exception as e:
print(e)
assert isinstance(source, str), "Invalid source type." assert isinstance(source, str), "Invalid source type."
if is_buffer: if is_buffer:
params = safe_load(source) params = safe_load(source)
else: else:
...@@ -100,86 +101,48 @@ class ParametersABC(ABC): ...@@ -100,86 +101,48 @@ class ParametersABC(ABC):
@classmethod @classmethod
def default(cls, **kwargs): def default(cls, **kwargs):
"""Initialise allowing the default parameters to be potentially replaced."""
overriden_defaults = copy(cls._defaults) overriden_defaults = copy(cls._defaults)
for k, v in kwargs.items(): for k, v in kwargs.items():
overriden_defaults[k] = v overriden_defaults[k] = v
return cls.from_dict(overriden_defaults) return cls.from_dict(overriden_defaults)
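For illustration, a minimal sketch of how these classmethods fit together; the subclass and its defaults below are invented for the example:
class ExampleParameters(ParametersABC):
    _defaults = {"channel": "Brightfield", "window": {"size": 5}}
params = ExampleParameters.default(channel="GFP")   # override one default
params.to_dict()   # nested dict of the attributes, e.g. {'channel': 'GFP', 'window': {'size': 5}}
params.to_yaml()   # the same attributes as a yaml stream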
def update(self, name: str, new_value): def update(self, name: str, new_value):
""" """Update a parameter in the nested dict of parameters."""
Update values recursively flat_params_dict = flatten(self.to_dict(), keep_empty_types=(dict,))
if name is a dictionary, replace data where existing found or add if not. names_found = [
It warns against type changes. param for param in flat_params_dict.keys() if name in param
]
If the existing structure under name is a dictionary, if len(names_found) == 1:
it looks for the first occurrence and modifies it accordingly. keys = names_found.pop()
if type(flat_params_dict[keys]) is not type(new_value):
If a leaf node that is to be changed is a collection, it adds the new elements. print("Warning:Changing type is risky.")
""" flat_params_dict[keys] = new_value
params_dict = unflatten(flat_params_dict)
assert name not in ( # replace all old values
"parameters", for key, value in params_dict.items():
"params", setattr(self, key, value)
), "Attribute can't be named params or parameters"
if name in self.__dict__:
if check_type_recursive(getattr(self, name), new_value):
print("Warnings:Type changes are risky")
if isinstance(getattr(self, name), dict):
flattened = flatten(self.to_dict())
names_found = [k for k in flattened.keys() if name in k]
found_idx = [keys.index(name) for keys in names_found]
assert len(names_found), f"{name} not found as key."
keys = None
if len(names_found) > 1:
for level in zip(found_idx, names_found):
if level == min(found_idx):
keys = level
print(
f"Warning: {name} was found in multiple keys. Selected {keys}"
)
break
else:
keys = names_found.pop()
if keys:
current_val = flattened.get(keys, None)
# if isinstance(current_val, t.Collection):
elif isinstance(getattr(self, name), t.Collection):
add_to_collection(getattr(self, name), new_value)
elif isinstance(getattr(self, name), set):
pass # TODO implement
new_d = getattr(self, name)
new_d.update(new_value)
setattr(self, name, new_d)
else: else:
setattr(self, name, new_value) print(f"Warning:{name} was neither recognised nor updated.")
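A rough usage sketch of the flatten-based update; the parameter names are invented:
# Hypothetical nested parameters: {"tiler": {"tile_size": 117}, "ref_channel": "Brightfield"}
params.update("tile_size", 96)
# "tile_size" matches exactly one flattened key, so after unflattening
# params.tiler becomes {"tile_size": 96}; ambiguous or unknown names only print a warning.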
def add_to_collection( def add_to_collection(
collection: t.Collection, value: t.Union[atomic, t.Collection] collection: t.Collection, element: t.Union[atomic, t.Collection]
): ):
# Adds element(s) in place. """Add elements to a collection, a list or set, in place."""
if not isinstance(value, t.Collection): if not isinstance(element, t.Collection):
value = [value] element = [element]
if isinstance(collection, list): if isinstance(collection, list):
collection += value collection += element
elif isinstance(collection, set): elif isinstance(collection, set):
collection.update(value) collection.update(element)
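For example:
# add_to_collection([1, 2], 3) extends the list in place to [1, 2, 3];
# add_to_collection({1, 2}, [3, 4]) updates the set in place to {1, 2, 3, 4}.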
class ProcessABC(ABC): class ProcessABC(ABC):
""" """
Base class for processes. Base class for processes.
Defines parameters as attributes and requires run method to be defined.
Define parameters as attributes and requires a run method.
""" """
def __init__(self, parameters): def __init__(self, parameters):
...@@ -190,8 +153,8 @@ class ProcessABC(ABC): ...@@ -190,8 +153,8 @@ class ProcessABC(ABC):
""" """
self._parameters = parameters self._parameters = parameters
# convert parameters to dictionary # convert parameters to dictionary
# and then define each parameter as an attribute
for k, v in parameters.to_dict().items(): for k, v in parameters.to_dict().items():
# define each parameter as an attribute
setattr(self, k, v) setattr(self, k, v)
@property @property
...@@ -202,32 +165,12 @@ class ProcessABC(ABC): ...@@ -202,32 +165,12 @@ class ProcessABC(ABC):
def run(self): def run(self):
pass pass
def _log(self, message: str, level: str = "warning"): def log(self, message: str, level: str = "warning"):
# Log messages in the corresponding level """Log messages at the corresponding level."""
logger = logging.getLogger("aliby") logger = logging.getLogger("aliby")
getattr(logger, level)(f"{self.__class__.__name__}: {message}") getattr(logger, level)(f"{self.__class__.__name__}: {message}")
def check_type_recursive(val1, val2):
same_types = True
if not isinstance(val1, type(val2)) and not all(
type(x) in (Path, str) for x in (val1, val2) # Ignore str->path
):
return False
if not isinstance(val1, t.Iterable) and not isinstance(val2, t.Iterable):
return isinstance(val1, type(val2))
elif isinstance(val1, (tuple, list)) and isinstance(val2, (tuple, list)):
return bool(
sum([check_type_recursive(v1, v2) for v1, v2 in zip(val1, val2)])
)
elif isinstance(val1, dict) and isinstance(val2, dict):
if not len(val1) or not len(val2):
return False
for k in val2.keys():
same_types = same_types and check_type_recursive(val1[k], val2[k])
return same_types
class StepABC(ProcessABC): class StepABC(ProcessABC):
""" """
Base class that expands on ProcessABC to include tools used by Aliby steps. Base class that expands on ProcessABC to include tools used by Aliby steps.
...@@ -243,11 +186,9 @@ class StepABC(ProcessABC): ...@@ -243,11 +186,9 @@ class StepABC(ProcessABC):
@timer @timer
def run_tp(self, tp: int, **kwargs): def run_tp(self, tp: int, **kwargs):
""" """Time and log the timing of a step."""
Time and log the timing of a step.
"""
return self._run_tp(tp, **kwargs) return self._run_tp(tp, **kwargs)
def run(self): def run(self):
# Replace run with run_tp # Replace run with run_tp
raise Warning("Steps use run_tp instead of run") raise Warning("Steps use run_tp instead of run.")
""" """
Tools to interact with h5 files and handle data consistently. Tools to interact with h5 files and handle data consistently.
""" """
import collections import collections
import logging import logging
import typing as t import typing as t
...@@ -23,20 +24,19 @@ class BridgeH5: ...@@ -23,20 +24,19 @@ class BridgeH5:
"""Initialise with the name of the h5 file.""" """Initialise with the name of the h5 file."""
self.filename = filename self.filename = filename
if flag is not None: if flag is not None:
self._hdf = h5py.File(filename, flag) self.hdf = h5py.File(filename, flag)
self._filecheck assert (
"cell_info" in self.hdf
), "Invalid file. No 'cell_info' found."
def _log(self, message: str, level: str = "warn"): def log(self, message: str, level: str = "warn"):
# Log messages in the corresponding level # Log messages in the corresponding level
logger = logging.getLogger("aliby") logger = logging.getLogger("aliby")
getattr(logger, level)(f"{self.__class__.__name__}: {message}") getattr(logger, level)(f"{self.__class__.__name__}: {message}")
def _filecheck(self):
assert "cell_info" in self._hdf, "Invalid file. No 'cell_info' found."
def close(self): def close(self):
"""Close the h5 file.""" """Close the h5 file."""
self._hdf.close() self.hdf.close()
@property @property
def meta_h5(self) -> t.Dict[str, t.Any]: def meta_h5(self) -> t.Dict[str, t.Any]:
...@@ -83,7 +83,7 @@ class BridgeH5: ...@@ -83,7 +83,7 @@ class BridgeH5:
def get_npairs_over_time(self, nstepsback=2): def get_npairs_over_time(self, nstepsback=2):
tree = self.cell_tree tree = self.cell_tree
npairs = [] npairs = []
for tp in self._hdf["cell_info"]["processed_timepoints"][()]: for tp in self.hdf["cell_info"]["processed_timepoints"][()]:
tmp_tree = { tmp_tree = {
k: {k2: v2 for k2, v2 in v.items() if k2 <= tp} k: {k2: v2 for k2, v2 in v.items() if k2 <= tp}
for k, v in tree.items() for k, v in tree.items()
...@@ -115,7 +115,7 @@ class BridgeH5: ...@@ -115,7 +115,7 @@ class BridgeH5:
---------- ----------
Nested dictionary where keys (or branches) are the upper levels and the leaves are the last element of :fields:. Nested dictionary where keys (or branches) are the upper levels and the leaves are the last element of :fields:.
""" """
zipped_info = (*zip(*[self._hdf["cell_info"][f][()] for f in fields]),) zipped_info = (*zip(*[self.hdf["cell_info"][f][()] for f in fields]),)
return recursive_groupsort(zipped_info) return recursive_groupsort(zipped_info)
......
...@@ -14,183 +14,170 @@ from utils_find_1st import cmp_equal, find_1st ...@@ -14,183 +14,170 @@ from utils_find_1st import cmp_equal, find_1st
class Cells: class Cells:
""" """
Extracts information from an h5 file. This class accesses: Extract information from an h5 file.
Use output from BABY to find cells detected, get, and fill, edge masks
and retrieve mother-bud relationships.
This class accesses in the h5 file:
'cell_info', which contains 'angles', 'cell_label', 'centres', 'cell_info', which contains 'angles', 'cell_label', 'centres',
'edgemasks', 'ellipse_dims', 'mother_assign', 'mother_assign_dynamic', 'edgemasks', 'ellipse_dims', 'mother_assign', 'mother_assign_dynamic',
'radii', 'timepoint', 'trap'. 'radii', 'timepoint', and 'trap'. All of which except for 'edgemasks'
All of these except for 'edgemasks' are a 1D ndarray. are a 1D ndarray.
'trap_info', which contains 'drifts', 'trap_locations' 'trap_info', which contains 'drifts', and 'trap_locations'.
The "timepoint", "cell_label", and "trap" variables are mutually consistent
1D lists.
Examples are self["timepoint"][self.get_idx(1, 3)] to find the time points
where cell 1 was present in trap 3.
""" """
def __init__(self, filename, path="cell_info"): def __init__(self, filename, path="cell_info"):
"""Initialise from a filename."""
self.filename: t.Optional[t.Union[str, Path]] = filename self.filename: t.Optional[t.Union[str, Path]] = filename
self.cinfo_path: t.Optional[str] = path self.cinfo_path: t.Optional[str] = path
self._edgemasks: t.Optional[str] = None self._edgemasks: t.Optional[str] = None
self._tile_size: t.Optional[int] = None self._tile_size: t.Optional[int] = None
def __getitem__(self, item):
"""
Dynamically fetch data from the h5 file and save as an attribute.
These attributes are accessed like dict keys.
"""
assert item != "edgemasks", "Edgemasks must not be loaded as a whole"
_item = "_" + item
if not hasattr(self, _item):
setattr(self, _item, self.fetch(item))
return getattr(self, _item)
def fetch(self, path):
"""Get data from the h5 file."""
with h5py.File(self.filename, mode="r") as f:
return f[self.cinfo_path][path][()]
@classmethod @classmethod
def from_source(cls, source: t.Union[Path, str]): def from_source(cls, source: t.Union[Path, str]):
"""Ensure initiating file is a Path object."""
return cls(Path(source)) return cls(Path(source))
def _log(self, message: str, level: str = "warn"): def log(self, message: str, level: str = "warn"):
# Log messages in the corresponding level """Log messages in the corresponding level."""
logger = logging.getLogger("aliby") logger = logging.getLogger("aliby")
getattr(logger, level)(f"{self.__class__.__name__}: {message}") getattr(logger, level)(f"{self.__class__.__name__}: {message}")
@staticmethod @staticmethod
def _asdense(array: np.ndarray): def asdense(array: np.ndarray):
"""Convert sparse array to dense array."""
if not isdense(array): if not isdense(array):
array = array.todense() array = array.todense()
return array return array
@staticmethod @staticmethod
def _astype(array: np.ndarray, kind: str): def astype(array: np.ndarray, kind: str):
# Convert sparse arrays if needed and if kind is 'mask' it fills the outline """Convert sparse arrays if needed; if kind is 'mask' fill the outline."""
array = Cells._asdense(array) array = Cells.asdense(array)
if kind == "mask": if kind == "mask":
array = ndimage.binary_fill_holes(array).astype(bool) array = ndimage.binary_fill_holes(array).astype(bool)
return array return array
def _get_idx(self, cell_id: int, trap_id: int): def get_idx(self, cell_id: int, trap_id: int):
# returns boolean array of time points where both the cell with cell_id and the trap with trap_id exist """Return boolean array giving indices for a cell_id and trap_id."""
return (self["cell_label"] == cell_id) & (self["trap"] == trap_id) return (self["cell_label"] == cell_id) & (self["trap"] == trap_id)
@property @property
def max_labels(self) -> t.List[int]: def max_labels(self) -> t.List[int]:
return [max((0, *self.labels_in_trap(i))) for i in range(self.ntraps)] """Return the maximum cell label per tile."""
return [
max((0, *self.cell_labels_in_trap(i))) for i in range(self.ntraps)
]
@property @property
def max_label(self) -> int: def max_label(self) -> int:
"""Return the maximum cell label over all tiles."""
return sum(self.max_labels) return sum(self.max_labels)
@property @property
def ntraps(self) -> int: def ntraps(self) -> int:
# find the number of traps from the h5 file """Find the number of tiles, or traps."""
with h5py.File(self.filename, mode="r") as f: with h5py.File(self.filename, mode="r") as f:
return len(f["trap_info/trap_locations"][()]) return len(f["trap_info/trap_locations"][()])
@property @property
def tinterval(self): def tinterval(self):
"""Return time interval in seconds."""
with h5py.File(self.filename, mode="r") as f: with h5py.File(self.filename, mode="r") as f:
return f.attrs["time_settings/timeinterval"] return f.attrs["time_settings/timeinterval"]
@property @property
def traps(self) -> t.List[int]: def traps(self) -> t.List[int]:
# returns a list of traps """List unique tile, or trap, IDs."""
return list(set(self["trap"])) return list(set(self["trap"]))
@property @property
def tile_size(self) -> t.Union[int, t.Tuple[int], None]: def tile_size(self) -> t.Union[int, t.Tuple[int], None]:
"""Give the x- and y- sizes of a tile."""
if self._tile_size is None: if self._tile_size is None:
with h5py.File(self.filename, mode="r") as f: with h5py.File(self.filename, mode="r") as f:
# self._tile_size = f["trap_info/tile_size"][0]
self._tile_size = f["cell_info/edgemasks"].shape[1:] self._tile_size = f["cell_info/edgemasks"].shape[1:]
return self._tile_size return self._tile_size
def nonempty_tp_in_trap(self, trap_id: int) -> set: def nonempty_tp_in_trap(self, trap_id: int) -> set:
# given a trap_id returns time points in which cells are available """Given a tile, return time points for which cells are available."""
return set(self["timepoint"][self["trap"] == trap_id]) return set(self["timepoint"][self["trap"] == trap_id])
@property @property
def edgemasks(self) -> t.List[np.ndarray]: def edgemasks(self) -> t.List[np.ndarray]:
# returns the masks per tile """Return a list of masks for every cell at every trap and time point."""
if self._edgemasks is None: if self._edgemasks is None:
edgem_path: str = "edgemasks" edgem_path: str = "edgemasks"
self._edgemasks = self._fetch(edgem_path) self._edgemasks = self.fetch(edgem_path)
return self._edgemasks return self._edgemasks
@property @property
def labels(self) -> t.List[t.List[int]]: def labels(self) -> t.List[t.List[int]]:
""" """Return all cell labels per tile as a set for all tiles."""
Return all cell labels in object return [self.cell_labels_in_trap(trap) for trap in range(self.ntraps)]
We use mother_assign to list traps because it is the only property that appears even
when no cells are found
"""
return [self.labels_in_trap(trap) for trap in range(self.ntraps)]
def max_labels_in_frame(self, frame: int) -> t.List[int]: def max_labels_in_frame(self, final_time_point: int) -> t.List[int]:
# Return the maximum label for each trap in the given frame """Get the maximal cell label for each tile within a frame of time."""
max_labels = [ max_labels = [
self["cell_label"][ self["cell_label"][
(self["timepoint"] <= frame) & (self["trap"] == trap_id) (self["timepoint"] <= final_time_point)
& (self["trap"] == trap_id)
] ]
for trap_id in range(self.ntraps) for trap_id in range(self.ntraps)
] ]
return [max([0, *labels]) for labels in max_labels] return [max([0, *labels]) for labels in max_labels]
def where(self, cell_id: int, trap_id: int): def where(self, cell_id: int, trap_id: int):
""" """Return time points, indices, and edge masks for a cell and trap."""
Parameters idx = self.get_idx(cell_id, trap_id)
----------
cell_id: int
Cell index
trap_id: int
Trap index
Returns
----------
indices int array
boolean mask array
edge_ix int array
"""
indices = self._get_idx(cell_id, trap_id)
edgem_ix = self._edgem_where(cell_id, trap_id)
return ( return (
self["timepoint"][indices], self["timepoint"][idx],
indices, idx,
edgem_ix, self.edgemasks_where(cell_id, trap_id),
) )
def mask(self, cell_id, trap_id): def mask(self, cell_id, trap_id):
""" """Return the times and the filled edge masks for a cell and trap."""
Returns the times and the binary masks of a given cell in a given tile.
Parameters
----------
cell_id : int
The unique ID of the cell.
tile_id : int
The unique ID of the tile.
Returns
-------
Tuple[np.ndarray, np.ndarray]
The times when the binary masks were taken and the binary masks of the given cell in the given tile.
"""
times, outlines = self.outline(cell_id, trap_id) times, outlines = self.outline(cell_id, trap_id)
return times, np.array( return times, np.array(
[ndimage.morphology.binary_fill_holes(o) for o in outlines] [ndimage.morphology.binary_fill_holes(o) for o in outlines]
) )
def at_time( def at_time(
self, timepoint: t.Iterable[int], kind="mask" self, timepoint: int, kind="mask"
) -> t.List[t.List[np.ndarray]]: ) -> t.List[t.List[np.ndarray]]:
""" """Return a dict with traps as keys and cell masks as values for a time point."""
Returns a list of lists of binary masks in a given list of time points. idx = self["timepoint"] == timepoint
traps = self["trap"][idx]
Parameters edgemasks = self.edgemasks_from_idx(idx)
----------
timepoints : Iterable[int]
The list of time points for which to return the binary masks.
kind : str, optional
The type of binary masks to return, by default "mask".
Returns
-------
List[List[np.ndarray]]
A list of lists with binary masks grouped by tile IDs.
"""
ix = self["timepoint"] == timepoint
traps = self["trap"][ix]
edgemasks = self._edgem_from_masking(ix)
masks = [ masks = [
self._astype(edgemask, kind) Cells.astype(edgemask, kind)
for edgemask in edgemasks for edgemask in edgemasks
if edgemask.any() if edgemask.any()
] ]
...@@ -199,22 +186,7 @@ class Cells: ...@@ -199,22 +186,7 @@ class Cells:
def at_times( def at_times(
self, timepoints: t.Iterable[int], kind="mask" self, timepoints: t.Iterable[int], kind="mask"
) -> t.List[t.List[np.ndarray]]: ) -> t.List[t.List[np.ndarray]]:
""" """Return a list of lists of cell masks one for specified time point."""
Returns a list of lists of binary masks for a given list of time points.
Parameters
----------
timepoints : Iterable[int]
The list of time points for which to return the binary masks.
kind : str, optional
The type of binary masks to return, by default "mask".
Returns
-------
List[List[np.ndarray]]
A list of lists with binary masks grouped by tile IDs.
"""
return [ return [
[ [
np.stack(tile_masks) if len(tile_masks) else [] np.stack(tile_masks) if len(tile_masks) else []
...@@ -226,98 +198,84 @@ class Cells: ...@@ -226,98 +198,84 @@ class Cells:
def group_by_traps( def group_by_traps(
self, traps: t.Collection, cell_labels: t.Collection self, traps: t.Collection, cell_labels: t.Collection
) -> t.Dict[int, t.List[int]]: ) -> t.Dict[int, t.List[int]]:
""" """Return a dict with traps as keys and a list of labels as values."""
Returns a dict with traps as keys and list of labels as value.
Note that the total number of traps are calculated from Cells.traps.
"""
iterator = groupby(zip(traps, cell_labels), lambda x: x[0]) iterator = groupby(zip(traps, cell_labels), lambda x: x[0])
d = {key: [x[1] for x in group] for key, group in iterator} d = {key: [x[1] for x in group] for key, group in iterator}
d = {i: d.get(i, []) for i in self.traps} d = {i: d.get(i, []) for i in self.traps}
return d return d
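For example, with traps [0, 0, 2] and cell_labels [1, 2, 1]:
# groupby gives {0: [1, 2], 2: [1]}; any other trap present in the
# experiment is then added with an empty list.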
def labels_in_trap(self, trap_id: int) -> t.Set[int]: def cell_labels_in_trap(self, trap_id: int) -> t.Set[int]:
# return set of cell ids for a given trap """Return unique cell labels for a given trap."""
return set((self["cell_label"][self["trap"] == trap_id])) return set((self["cell_label"][self["trap"] == trap_id]))
def labels_at_time(self, timepoint: int) -> t.Dict[int, t.List[int]]: def labels_at_time(self, timepoint: int) -> t.Dict[int, t.List[int]]:
"""Return a dict with traps as keys and cell labels as values for a time point."""
labels = self["cell_label"][self["timepoint"] == timepoint] labels = self["cell_label"][self["timepoint"] == timepoint]
traps = self["trap"][self["timepoint"] == timepoint] traps = self["trap"][self["timepoint"] == timepoint]
return self.group_by_traps(traps, labels) return self.group_by_traps(traps, labels)
def __getitem__(self, item): def edgemasks_from_idx(self, idx):
assert item != "edgemasks", "Edgemasks must not be loaded as a whole" """Get edge masks from the h5 file."""
_item = "_" + item
if not hasattr(self, _item):
setattr(self, _item, self._fetch(item))
return getattr(self, _item)
def _fetch(self, path):
with h5py.File(self.filename, mode="r") as f: with h5py.File(self.filename, mode="r") as f:
return f[self.cinfo_path][path][()] edgem = f[self.cinfo_path + "/edgemasks"][idx, ...]
def _edgem_from_masking(self, mask):
with h5py.File(self.filename, mode="r") as f:
edgem = f[self.cinfo_path + "/edgemasks"][mask, ...]
return edgem return edgem
def _edgem_where(self, cell_id, trap_id): def edgemasks_where(self, cell_id, trap_id):
id_mask = self._get_idx(cell_id, trap_id) """Get the edge masks for a given cell and trap for all time points."""
edgem = self._edgem_from_masking(id_mask) idx = self.get_idx(cell_id, trap_id)
edgemasks = self.edgemasks_from_idx(idx)
return edgem return edgemasks
def outline(self, cell_id: int, trap_id: int): def outline(self, cell_id: int, trap_id: int):
id_mask = self._get_idx(cell_id, trap_id) """Get times and edge masks for a given cell and trap."""
times = self["timepoint"][id_mask] idx = self.get_idx(cell_id, trap_id)
times = self["timepoint"][idx]
return times, self._edgem_from_masking(id_mask) return times, self.edgemasks_from_idx(idx)
@property @property
def ntimepoints(self) -> int: def ntimepoints(self) -> int:
"""Return total number of time points in the experiment."""
return self["timepoint"].max() + 1 return self["timepoint"].max() + 1
@cached_property @cached_property
def _cells_vs_tps(self): def cells_vs_tps(self):
# Binary matrix showing the presence of all cells in all time points """Boolean matrix showing when cells are present for all time points."""
ncells_per_tile = [len(x) for x in self.labels] total_ncells = sum([len(x) for x in self.labels])
cells_vs_tps = np.zeros( cells_vs_tps = np.zeros((total_ncells, self.ntimepoints), dtype=bool)
(sum(ncells_per_tile), self.ntimepoints), dtype=bool
)
cells_vs_tps[ cells_vs_tps[
self._cell_cumsum[self["trap"]] + self["cell_label"] - 1, self.cell_cumlsum[self["trap"]] + self["cell_label"] - 1,
self["timepoint"], self["timepoint"],
] = True ] = True
return cells_vs_tps return cells_vs_tps
@cached_property @cached_property
def _cell_cumsum(self): def cell_cumlsum(self):
# Cumulative sum indicating the number of cells per tile """Find cumulative sum over tiles of the number of cells present."""
ncells_per_tile = [len(x) for x in self.labels] ncells_per_tile = [len(x) for x in self.labels]
cumsum = np.roll(np.cumsum(ncells_per_tile), shift=1) cumsum = np.roll(np.cumsum(ncells_per_tile), shift=1)
cumsum[0] = 0 cumsum[0] = 0
return cumsum return cumsum
def _flat_index_to_tuple_location(self, idx: int) -> t.Tuple[int, int]: def index_to_tile_and_cell(self, idx: int) -> t.Tuple[int, int]:
# Convert a cell index to a tuple """Convert an index to the equivalent pair of tile and cell IDs."""
# Note that it assumes tiles and cell labels are flattened, but tile_id = int(np.where(idx + 1 > self.cell_cumlsum)[0][-1])
# it is agnostic to tps cell_label = idx - self.cell_cumlsum[tile_id] + 1
tile_id = int(np.where(idx + 1 > self._cell_cumsum)[0][-1])
cell_label = idx - self._cell_cumsum[tile_id] + 1
return tile_id, cell_label return tile_id, cell_label
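A small worked example of this mapping, with invented cell counts:
# Suppose the tiles hold 3, 2, and 4 cells, so cell_cumlsum is [0, 3, 5].
# For flat index 4: 4 + 1 exceeds 0 and 3 but not 5, so tile_id = 1,
# and cell_label = 4 - 3 + 1 = 2, i.e. the second cell of the second tile.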
@property @property
def _tiles_vs_cells_vs_tps(self): def tiles_vs_cells_vs_tps(self):
"""
Boolean matrix showing if a cell is present.
The matrix is indexed by trap, cell label, and time point.
"""
ncells_mat = np.zeros( ncells_mat = np.zeros(
(self.ntraps, self["cell_label"].max(), self.ntimepoints), (self.ntraps, self["cell_label"].max(), self.ntimepoints),
dtype=bool, dtype=bool,
) )
ncells_mat[ ncells_mat[self["trap"], self["cell_label"] - 1, self["timepoint"]] = (
self["trap"], self["cell_label"] - 1, self["timepoint"] True
] = True )
return ncells_mat return ncells_mat
def cell_tp_where( def cell_tp_where(
...@@ -325,32 +283,37 @@ class Cells: ...@@ -325,32 +283,37 @@ class Cells:
min_consecutive_tps: int = 15, min_consecutive_tps: int = 15,
interval: None or t.Tuple[int, int] = None, interval: None or t.Tuple[int, int] = None,
): ):
"""
Find cells present for all time points in a sliding window of time.
The result can be restricted to a particular interval of time.
"""
window = sliding_window_view( window = sliding_window_view(
self._cells_vs_tps, min_consecutive_tps, axis=1 self.cells_vs_tps, min_consecutive_tps, axis=1
) )
tp_min = window.sum(axis=-1) == min_consecutive_tps tp_min = window.sum(axis=-1) == min_consecutive_tps
# apply a filter to restrict to an interval of time
# Apply an interval filter to focus on a slice
if interval is not None: if interval is not None:
interval = tuple(np.array(interval)) interval = tuple(np.array(interval))
else: else:
interval = (0, window.shape[1]) interval = (0, window.shape[1])
low_boundary, high_boundary = interval low_boundary, high_boundary = interval
tp_min[:, :low_boundary] = False tp_min[:, :low_boundary] = False
tp_min[:, high_boundary:] = False tp_min[:, high_boundary:] = False
return tp_min return tp_min
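To illustrate the sliding-window test used here, a self-contained sketch with invented data:
import numpy as np
from numpy.lib.stride_tricks import sliding_window_view

presence = np.array([[True, True, True, False, True]])  # one cell over five time points
window = sliding_window_view(presence, 3, axis=1)
print(window.sum(axis=-1) == 3)  # [[ True False False]]: only the first window is fully occupied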
@lru_cache(20) @lru_cache(20)
def mothers_in_trap(self, trap_id: int): def mothers_in_trap(self, trap_id: int):
"""Return mothers at a trap."""
return self.mothers[trap_id] return self.mothers[trap_id]
@cached_property @cached_property
def mothers(self): def mothers(self):
""" """
Return nested list with final prediction of mother id for each cell Return a list of mother IDs for each cell in each tile.
Use Baby's "mother_assign_dynamic".
An ID of zero implies that no mother was assigned.
""" """
return self.mother_assign_from_dynamic( return self.mother_assign_from_dynamic(
self["mother_assign_dynamic"], self["mother_assign_dynamic"],
...@@ -362,73 +325,71 @@ class Cells: ...@@ -362,73 +325,71 @@ class Cells:
@cached_property @cached_property
def mothers_daughters(self) -> np.ndarray: def mothers_daughters(self) -> np.ndarray:
""" """
Return a single array with three columns, containing information about Return mother-daughter relationships for all tiles.
the mother-daughter relationships: tile, mothers and daughters.
Returns Returns
------- -------
np.ndarray mothers_daughters: np.ndarray
An array with shape (n, 3) where n is the number of mother-daughter pairs found. An array with shape (n, 3) where n is the number of mother-daughter
The columns contain: pairs found. The first column is the tile_id for the tile where the
- tile: the tile where the mother cell is located. mother cell is located. The second column is the cell index of a
- mothers: the index of the mother cell within the tile. mother cell in the tile. The third column is the index of the
- daughters: the index of the daughter cell within the tile. corresponding daughter cell.
""" """
nested_massign = self.mothers # list of arrays, one per tile, giving mothers of each cell in each tile
mothers = self.mothers
if sum([x for y in nested_massign for x in y]): if sum([x for y in mothers for x in y]):
mothers_daughters = np.array( mothers_daughters = np.array(
[ [
(tid, m, d) (trap_id, mother, bud)
for tid, trapcells in enumerate(nested_massign) for trap_id, trapcells in enumerate(mothers)
for d, m in enumerate(trapcells, 1) for bud, mother in enumerate(trapcells, start=1)
if m if mother
], ],
dtype=np.uint16, dtype=np.uint16,
) )
else: else:
mothers_daughters = np.array([]) mothers_daughters = np.array([])
self._log("No mother-daughters assigned") self.log("No mother-daughters assigned")
return mothers_daughters return mothers_daughters
@staticmethod @staticmethod
def mother_assign_to_mb_matrix(ma: t.List[np.array]): def mother_assign_to_mb_matrix(ma: t.List[np.array]):
""" """
Convert from a list of lists of mother-bud paired assignments to a Convert a list of mother-daughters into a boolean sparse matrix.
sparse matrix with a boolean dtype. The rows correspond to
to daughter buds. The values are boolean and indicate whether a Each row in the matrix correspond to daughter buds.
given cell is a mother cell and a given daughter bud is assigned If an entry is True, a given cell is a mother cell and a given
to the mother cell in the next timepoint. daughter bud is assigned to the mother cell in the next time point.
Parameters: Parameters:
----------- -----------
ma : list of lists of integers ma : list of lists of integers
A list of lists of mother-bud assignments. The i-th sublist contains the A list of lists of mother-bud assignments.
bud assignments for the i-th tile. The integers in each sublist The i-th sublist contains the bud assignments for the i-th tile.
represent the mother label, if it is zero no mother was found. The integers in each sublist represent the mother label, with zero
implying no mother found.
Returns: Returns:
-------- --------
mb_matrix : boolean numpy array of shape (n, m) mb_matrix : boolean numpy array of shape (n, m)
An n x m boolean numpy array where n is the total number of cells (sum An n x m array where n is the total number of cells (sum
of the lengths of all sublists in ma) and m is the maximum number of buds of the lengths of all sublists in ma) and m is the maximum
assigned to any mother cell in ma. The value at (i, j) is True if cell i number of buds assigned to any mother cell in ma.
is a daughter cell and cell j is its mother assigned to i. The value at (i, j) is True if cell i is a daughter cell and
cell j is its assigned mother.
Examples: Examples:
-------- --------
ma = [[0, 0, 1], [0, 1, 0]] >>> ma = [[0, 0, 1], [0, 1, 0]]
Cells(None).mother_assign_to_mb_matrix(ma) >>> Cells(None).mother_assign_to_mb_matrix(ma)
# array([[False, False, False, False, False, False], >>> array([[False, False, False, False, False, False],
# [False, False, False, False, False, False], [False, False, False, False, False, False],
# [ True, False, False, False, False, False], [ True, False, False, False, False, False],
# [False, False, False, False, False, False], [False, False, False, False, False, False],
# [False, False, False, True, False, False], [False, False, False, True, False, False],
# [False, False, False, False, False, False]]) [False, False, False, False, False, False]])
""" """
ncells = sum([len(t) for t in ma]) ncells = sum([len(t) for t in ma])
mb_matrix = np.zeros((ncells, ncells), dtype=bool) mb_matrix = np.zeros((ncells, ncells), dtype=bool)
c = 0 c = 0
...@@ -436,69 +397,78 @@ class Cells: ...@@ -436,69 +397,78 @@ class Cells:
for d, m in enumerate(cells): for d, m in enumerate(cells):
if m: if m:
mb_matrix[c + d, c + m - 1] = True mb_matrix[c + d, c + m - 1] = True
c += len(cells) c += len(cells)
return mb_matrix return mb_matrix
@staticmethod @staticmethod
def mother_assign_from_dynamic( def mother_assign_from_dynamic(
ma: np.ndarray, cell_label: t.List[int], trap: t.List[int], ntraps: int ma: np.ndarray,
cell_label: t.List[int],
trap: t.List[int],
ntraps: int,
) -> t.List[t.List[int]]: ) -> t.List[t.List[int]]:
""" """
Interpolate the associated mothers from the 'mother_assign_dynamic' feature. Find mothers from Baby's 'mother_assign_dynamic' variable.
Parameters Parameters
---------- ----------
ma: np.ndarray ma: np.ndarray
An array with shape (n_t, n_c) containing the 'mother_assign_dynamic' feature. An array of length number of time points times number of cells
containing the 'mother_assign_dynamic' produced by Baby.
cell_label: List[int] cell_label: List[int]
A list containing the cell labels. A list of cell labels.
trap: List[int] trap: List[int]
A list containing the trap labels. A list of trap labels.
ntraps: int ntraps: int
The total number of traps. The total number of traps.
Returns Returns
------- -------
List[List[int]] List[List[int]]
A list of lists containing the interpolated mother assignment for each cell in each trap. A list giving the mothers for each cell at each trap.
""" """
idlist = list(zip(trap, cell_label)) ids = np.unique(list(zip(trap, cell_label)), axis=0)
cell_gid = np.unique(idlist, axis=0) # find when each cell last appeared at its trap
last_lin_preds = [ last_lin_preds = [
find_1st( find_1st(
((cell_label[::-1] == lbl) & (trap[::-1] == tr)), (
(cell_label[::-1] == cell_label_id)
& (trap[::-1] == trap_id)
),
True, True,
cmp_equal, cmp_equal,
) )
for tr, lbl in cell_gid for trap_id, cell_label_id in ids
] ]
# find the cell's mother using the latest prediction from Baby
mother_assign_sorted = ma[::-1][last_lin_preds] mother_assign_sorted = ma[::-1][last_lin_preds]
# rearrange as a list of mother IDs for each cell in each tile
traps = cell_gid[:, 0] traps = ids[:, 0]
iterator = groupby(zip(traps, mother_assign_sorted), lambda x: x[0]) iterator = groupby(zip(traps, mother_assign_sorted), lambda x: x[0])
d = {key: [x[1] for x in group] for key, group in iterator} d = {trap: [x[1] for x in mothers] for trap, mothers in iterator}
nested_massign = [d.get(i, []) for i in range(ntraps)] mothers = [d.get(i, []) for i in range(ntraps)]
return mothers
return nested_massign ###############################################################################
# Apparently unused below here
###############################################################################
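A worked sketch of mother_assign_from_dynamic above, with invented values:
# trap       = [0, 0, 0, 0]
# cell_label = [1, 2, 1, 2]
# ma         = [0, 0, 0, 1]   # at the last time point, cell 2's mother is cell 1
# Reversing the arrays and applying find_1st picks each cell's final entry,
# so the mothers for trap 0 become [0, 1]: no mother for cell 1, and cell 1 is cell 2's mother.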
@lru_cache(maxsize=200) @lru_cache(maxsize=200)
def labelled_in_frame( def labelled_in_frame(
self, frame: int, global_id: bool = False self, frame: int, global_id: bool = False
) -> np.ndarray: ) -> np.ndarray:
""" """
Returns labels in a 4D ndarray with the global ids with shape Return labels in a 4D ndarray with potentially global ids.
(ntraps, max_nlabels, ysize, xsize) at a given frame.
Use lru_cache to cache the results for speed.
Parameters Parameters
---------- ----------
frame : int frame : int
The frame number. The frame number (time point).
global_id : bool, optional global_id : bool, optional
If True, the returned array contains global ids, otherwise it If True, the returned array contains global ids, otherwise only
contains only the local ids of the labels. Default is False. the local ids of the labels.
Returns Returns
------- -------
...@@ -507,18 +477,12 @@ class Cells: ...@@ -507,18 +477,12 @@ class Cells:
The array has dimensions (ntraps, max_nlabels, ysize, xsize), The array has dimensions (ntraps, max_nlabels, ysize, xsize),
where max_nlabels is specific for this frame, not the entire where max_nlabels is specific for this frame, not the entire
experiment. experiment.
Notes
-----
This method uses lru_cache to cache the results for faster access.
""" """
labels_in_frame = self.labels_at_time(frame) labels_in_frame = self.labels_at_time(frame)
n_labels = [ n_labels = [
len(labels_in_frame.get(trap_id, [])) len(labels_in_frame.get(trap_id, []))
for trap_id in range(self.ntraps) for trap_id in range(self.ntraps)
] ]
# maxes = self.max_labels_in_frame(frame)
stacks_in_frame = self.get_stacks_in_frame(frame, self.tile_size) stacks_in_frame = self.get_stacks_in_frame(frame, self.tile_size)
first_id = np.cumsum([0, *n_labels]) first_id = np.cumsum([0, *n_labels])
labels_mat = np.zeros( labels_mat = np.zeros(
...@@ -552,7 +516,9 @@ class Cells: ...@@ -552,7 +516,9 @@ class Cells:
self, frame: int, tile_shape: t.Tuple[int] self, frame: int, tile_shape: t.Tuple[int]
) -> t.List[np.ndarray]: ) -> t.List[np.ndarray]:
""" """
Returns a list of stacked masks, each corresponding to a tile at a given timepoint. Return a list of stacked masks.
Each corresponds to a tile at a given time point.
Parameters Parameters
---------- ----------
...@@ -564,7 +530,7 @@ class Cells: ...@@ -564,7 +530,7 @@ class Cells:
Returns Returns
------- -------
List[np.ndarray] List[np.ndarray]
List of stacked masks for each tile at the given timepoint. List of stacked masks for each tile at the given time point.
""" """
masks = self.at_time(frame) masks = self.at_time(frame)
return [ return [
...@@ -574,7 +540,7 @@ class Cells: ...@@ -574,7 +540,7 @@ class Cells:
for trap_id in range(self.ntraps) for trap_id in range(self.ntraps)
] ]
def _sample_tiles_tps( def sample_tiles_tps(
self, self,
size=1, size=1,
min_consecutive_ntps: int = 15, min_consecutive_ntps: int = 15,
...@@ -582,7 +548,7 @@ class Cells: ...@@ -582,7 +548,7 @@ class Cells:
interval=None, interval=None,
) -> t.Tuple[np.ndarray, np.ndarray]: ) -> t.Tuple[np.ndarray, np.ndarray]:
""" """
Sample tiles that have a minimum number of cells and are occupied for at least a minimum number of consecutive timepoints. Sample tiles that have a minimum number of cells and are occupied for at least a minimum number of consecutive time points.
Parameters Parameters
---------- ----------
...@@ -591,7 +557,7 @@ class Cells: ...@@ -591,7 +557,7 @@ class Cells:
min_ncells: int, optional (default=2) min_ncells: int, optional (default=2)
The minimum number of cells per tile. The minimum number of cells per tile.
min_consecutive_ntps: int, optional (default=5) min_consecutive_ntps: int, optional (default=5)
The minimum number of consecutive timepoints a cell must be present in a trap. The minimum number of consecutive time points a cell must be present in a trap.
seed: int, optional (default=0) seed: int, optional (default=0)
Random seed value for reproducibility. Random seed value for reproducibility.
interval: None or Tuple(int,int), optional (default=None) interval: None or Tuple(int,int), optional (default=None)
...@@ -612,21 +578,15 @@ class Cells: ...@@ -612,21 +578,15 @@ class Cells:
min_consecutive_tps=min_consecutive_ntps, min_consecutive_tps=min_consecutive_ntps,
interval=interval, interval=interval,
) )
# Find all valid tiles with min_ncells for at least min_tps # Find all valid tiles with min_ncells for at least min_tps
index_id, tps = np.where(cell_availability_matrix) index_id, tps = np.where(cell_availability_matrix)
if interval is None: # Limit search if interval is None: # Limit search
interval = (0, cell_availability_matrix.shape[1]) interval = (0, cell_availability_matrix.shape[1])
np.random.seed(seed) np.random.seed(seed)
choices = np.random.randint(len(index_id), size=size) choices = np.random.randint(len(index_id), size=size)
linear_indices = np.zeros_like(self["cell_label"], dtype=bool) linear_indices = np.zeros_like(self["cell_label"], dtype=bool)
for cell_index_flat, tp in zip(index_id[choices], tps[choices]): for cell_index_flat, tp in zip(index_id[choices], tps[choices]):
tile_id, cell_label = self._flat_index_to_tuple_location( tile_id, cell_label = self.index_to_tile_and_cell(cell_index_flat)
cell_index_flat
)
linear_indices[ linear_indices[
( (
(self["cell_label"] == cell_label) (self["cell_label"] == cell_label)
...@@ -634,10 +594,9 @@ class Cells: ...@@ -634,10 +594,9 @@ class Cells:
& (self["timepoint"] == tp) & (self["timepoint"] == tp)
) )
] = True ] = True
return linear_indices return linear_indices
def _sample_masks( def sample_masks(
self, self,
size: int = 1, size: int = 1,
min_consecutive_ntps: int = 15, min_consecutive_ntps: int = 15,
...@@ -668,31 +627,28 @@ class Cells: ...@@ -668,31 +627,28 @@ class Cells:
The second tuple contains: The second tuple contains:
- `masks`: A list of 2D numpy arrays representing the binary masks of the sampled cells at each timepoint. - `masks`: A list of 2D numpy arrays representing the binary masks of the sampled cells at each timepoint.
""" """
sampled_bitmask = self._sample_tiles_tps( sampled_bitmask = self.sample_tiles_tps(
size=size, size=size,
min_consecutive_ntps=min_consecutive_ntps, min_consecutive_ntps=min_consecutive_ntps,
seed=seed, seed=seed,
interval=interval, interval=interval,
) )
# Sort sampled tiles to use automatic cache when possible # Sort sampled tiles to use automatic cache when possible
tile_ids = self["trap"][sampled_bitmask] tile_ids = self["trap"][sampled_bitmask]
cell_labels = self["cell_label"][sampled_bitmask] cell_labels = self["cell_label"][sampled_bitmask]
tps = self["timepoint"][sampled_bitmask] tps = self["timepoint"][sampled_bitmask]
masks = [] masks = []
for tile_id, cell_label, tp in zip(tile_ids, cell_labels, tps): for tile_id, cell_label, tp in zip(tile_ids, cell_labels, tps):
local_idx = self.labels_at_time(tp)[tile_id].index(cell_label) local_idx = self.labels_at_time(tp)[tile_id].index(cell_label)
tile_mask = self.at_time(tp)[tile_id][local_idx] tile_mask = self.at_time(tp)[tile_id][local_idx]
masks.append(tile_mask) masks.append(tile_mask)
return (tile_ids, cell_labels, tps), np.stack(masks) return (tile_ids, cell_labels, tps), np.stack(masks)
def matrix_trap_tp_where( def matrix_trap_tp_where(
self, min_ncells: int = 2, min_consecutive_tps: int = 5 self, min_ncells: int = 2, min_consecutive_tps: int = 5
): ):
""" """
NOTE CURRENLTY UNUSED WITHIN ALIBY THE MOMENT. MAY BE USEFUL IN THE FUTURE. NOTE CURRENTLY UNUSED BUT USEFUL.
Return a matrix of shape (ntraps x ntps - min_consecutive_tps) to Return a matrix of shape (ntraps x ntps - min_consecutive_tps) to
indicate traps and time-points where min_ncells are available for at least min_consecutive_tps indicate traps and time-points where min_ncells are available for at least min_consecutive_tps
...@@ -708,9 +664,8 @@ class Cells: ...@@ -708,9 +664,8 @@ class Cells:
(ntraps x ( ntps-min_consecutive_tps )) 2D boolean numpy array where rows are trap ids and columns are timepoint windows. (ntraps x ( ntps-min_consecutive_tps )) 2D boolean numpy array where rows are trap ids and columns are timepoint windows.
If the value in a cell is true its corresponding trap and timepoint contains more than min_ncells for at least min_consecutive time-points. If the value in a cell is true its corresponding trap and timepoint contains more than min_ncells for at least min_consecutive time-points.
""" """
window = sliding_window_view( window = sliding_window_view(
self._tiles_vs_cells_vs_tps, min_consecutive_tps, axis=2 self.tiles_vs_cells_vs_tps, min_consecutive_tps, axis=2
) )
tp_min = window.sum(axis=-1) == min_consecutive_tps tp_min = window.sum(axis=-1) == min_consecutive_tps
ncells_tp_min = tp_min.sum(axis=1) >= min_ncells ncells_tp_min = tp_min.sum(axis=1) >= min_ncells
...@@ -720,7 +675,7 @@ class Cells: ...@@ -720,7 +675,7 @@ class Cells:
def stack_masks_in_tile( def stack_masks_in_tile(
masks: t.List[np.ndarray], tile_shape: t.Tuple[int] masks: t.List[np.ndarray], tile_shape: t.Tuple[int]
) -> np.ndarray: ) -> np.ndarray:
# Stack all masks in a trap padding accordingly if no outlines found """Stack all masks in a trap, padding accordingly if no outlines found."""
result = np.zeros((0, *tile_shape), dtype=bool) result = np.zeros((0, *tile_shape), dtype=bool)
if len(masks): if len(masks):
result = np.stack(masks) result = np.stack(masks)
......
...@@ -6,17 +6,19 @@ import typing as t ...@@ -6,17 +6,19 @@ import typing as t
from functools import wraps from functools import wraps
def _first_arg_str_to_df( def _first_arg_str_to_raw_df(
fn: t.Callable, fn: t.Callable,
): ):
"""Enable Signal-like classes to convert strings to data sets.""" """Enable Signal-like classes to convert strings to data sets."""
@wraps(fn) @wraps(fn)
def format_input(*args, **kwargs): def format_input(*args, **kwargs):
cls = args[0] cls = args[0]
data = args[1] data = args[1]
if isinstance(data, str): if isinstance(data, str):
# get data from h5 file # get data from h5 file using Signal's get_raw
data = cls.get_raw(data) data = cls.get_raw(data)
# replace path in the undecorated function with data # replace path in the undecorated function with data
return fn(cls, data, *args[2:], **kwargs) return fn(cls, data, *args[2:], **kwargs)
return format_input return format_input
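A sketch of how the decorator might be applied; the class and method below are invented for illustration:
class ExampleSignal:
    def get_raw(self, dataset: str):
        ...  # would load the data set from the h5 file

    @_first_arg_str_to_raw_df
    def max_over_time(self, data):
        # here data is always the loaded data set, even if the caller passed a path string
        return data.max(axis=1)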
""" """
Anthology of interfaces fordispatch_metadata_parse different parsers and lack of them. Aliby decides on using different metadata parsers based on two elements:
1. The parameter given by PipelineParameters (either True/False or a string
ALIBY decides on using different metadata parsers based on two elements: pointing to the metadata file)
2. The available files in the root folder where images are found (either
1. The parameter given by PipelineParameters (Either True/False, or a string pointing to the metadata file) remote or locally).
2. The available files in the root folder where images are found (remote or locally)
If parameters is a string pointing to a metadata file, Aliby picks a parser
If parameters is a string pointing to a metadata file, ALIBY picks a parser based on the file format. based on the file format.
If parameters is True (as a boolean), ALIBY searches for any available file and uses the first valid one. If parameters is True, Aliby searches for any available file and uses the
If there are no metadata files, ALIBY requires indicating indices for tiler, segmentation and extraction. first valid one.
If there are no metadata files, Aliby requires indices in the tiff file names
for tiler, segmentation, and extraction.
WARNING: grammars depend on the directory structure of a local log-file_parser
repository.
""" """
import glob import glob
import logging import logging
import numpy as np
import os import os
import typing as t import typing as t
from datetime import datetime from datetime import datetime
...@@ -27,28 +32,32 @@ from logfile_parser.swainlab_parser import parse_from_swainlab_grammar ...@@ -27,28 +32,32 @@ from logfile_parser.swainlab_parser import parse_from_swainlab_grammar
class MetaData: class MetaData:
"""Small metadata Process that loads log.""" """Metadata process that loads and parses log files."""
def __init__(self, log_dir, store): def __init__(self, log_dir, store):
"""Initialise with log-file directory and h5 location to write."""
self.log_dir = log_dir self.log_dir = log_dir
self.store = store self.store = store
self.metadata_writer = Writer(self.store) self.metadata_writer = Writer(self.store)
def __getitem__(self, item): def __getitem__(self, item):
"""Load log and access item in resulting meta data dictionary."""
return self.load_logs()[item] return self.load_logs()[item]
def load_logs(self): def load_logs(self):
# parsed_flattened = parse_logfiles(self.log_dir) """Load log using a hierarchy of parsers."""
parsed_flattened = dispatch_metadata_parser(self.log_dir) parsed_flattened = dispatch_metadata_parser(self.log_dir)
return parsed_flattened return parsed_flattened
def run(self, overwrite=False): def run(self, overwrite=False):
"""Load and parse logs and write to h5 file."""
metadata_dict = self.load_logs() metadata_dict = self.load_logs()
self.metadata_writer.write( self.metadata_writer.write(
path="/", meta=metadata_dict, overwrite=overwrite path="/", meta=metadata_dict, overwrite=overwrite
) )
def add_field(self, field_name, field_value, **kwargs): def add_field(self, field_name, field_value, **kwargs):
"""Write a field and its values to the h5 file."""
self.metadata_writer.write( self.metadata_writer.write(
path="/", path="/",
meta={field_name: field_value}, meta={field_name: field_value},
...@@ -56,206 +65,220 @@ class MetaData: ...@@ -56,206 +65,220 @@ class MetaData:
) )
def add_fields(self, fields_values: dict, **kwargs): def add_fields(self, fields_values: dict, **kwargs):
"""Write a dict of fields and values to the h5 file."""
for field, value in fields_values.items(): for field, value in fields_values.items():
self.add_field(field, value) self.add_field(field, value)
# Paradigm: able to do something with all datatypes present in log files,
# then pare down on what specific information is really useful later.
# Needed because HDF5 attributes do not support dictionaries
def flatten_dict(nested_dict, separator="/"): def flatten_dict(nested_dict, separator="/"):
""" """
Flattens nested dictionary. If empty return as-is. Flatten nested dictionary because h5 attributes cannot be dicts.
If empty return as-is.
""" """
flattened = {} flattened = {}
if nested_dict: if nested_dict:
df = pd.json_normalize(nested_dict, sep=separator) df = pd.json_normalize(nested_dict, sep=separator)
flattened = df.to_dict(orient="records")[0] or {} flattened = df.to_dict(orient="records")[0] or {}
return flattened return flattened
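For example:
# flatten_dict({"channels": {"GFP": {"exposure": 100}}})
# returns {"channels/GFP/exposure": 100}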
# Needed because HDF5 attributes do not support datetime objects
# Takes care of time zones & daylight saving
def datetime_to_timestamp(time, locale="Europe/London"): def datetime_to_timestamp(time, locale="Europe/London"):
""" """Convert datetime object to UNIX timestamp."""
Convert datetime object to UNIX timestamp # h5 attributes do not support datetime objects
"""
return timezone(locale).localize(time).timestamp() return timezone(locale).localize(time).timestamp()
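For example:
# datetime_to_timestamp(datetime(2023, 6, 1, 12, 0)) localises the naive datetime
# to Europe/London (BST, UTC+1) and returns 1685617200.0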
def find_file(root_dir, regex): def find_file(root_dir, regex):
"""Find files in a directory using regex."""
# ignore aliby.log files
file = [ file = [
f f
for f in glob.glob(os.path.join(str(root_dir), regex)) for f in glob.glob(os.path.join(str(root_dir), regex))
if Path(f).name != "aliby.log" # Skip filename reserved for aliby if Path(f).name != "aliby.log"
] ]
if len(file) > 1:
print(
"Warning:Metadata: More than one logfile found. Defaulting to first option."
)
file = [sorted(file)[0]]
if len(file) == 0: if len(file) == 0:
logging.getLogger("aliby").log( return None
logging.WARNING, "Metadata: No valid swainlab .log found." elif len(file) > 1:
print(
"Warning:Metadata: More than one log file found."
" Defaulting to first option."
) )
return sorted(file)[0]
else: else:
return file[0] return file[0]
return None
# TODO: re-write this as a class if appropriate
# WARNING: grammars depend on the directory structure of a locally installed
# logfile_parser repo
def parse_logfiles( def parse_logfiles(
root_dir, root_dir,
acq_grammar="multiDGUI_acq_format.json", acq_grammar="multiDGUI_acq_format.json",
log_grammar="multiDGUI_log_format.json", log_grammar="multiDGUI_log_format.json",
): ):
""" """
Parse acq and log files depending on the grammar specified, then merge into Parse acq and log files using the grammar specified.
single dict.
Merge results into a single dict.
""" """
# Both acq and log files contain useful information.
# ACQ_FILE = 'flavin_htb2_glucose_long_ramp_DelftAcq.txt'
# LOG_FILE = 'flavin_htb2_glucose_long_ramp_Delftlog.txt'
log_parser = Parser(log_grammar) log_parser = Parser(log_grammar)
acq_parser = Parser(acq_grammar) acq_parser = Parser(acq_grammar)
log_file = find_file(root_dir, "*log.txt") log_file = find_file(root_dir, "*log.txt")
acq_file = find_file(root_dir, "*[Aa]cq.txt") acq_file = find_file(root_dir, "*[Aa]cq.txt")
# parse into a single dict
parsed = {} parsed = {}
if log_file and acq_file: if log_file and acq_file:
with open(log_file, "r") as f: with open(log_file, "r") as f:
log_parsed = log_parser.parse(f) log_parsed = log_parser.parse(f)
with open(acq_file, "r") as f: with open(acq_file, "r") as f:
acq_parsed = acq_parser.parse(f) acq_parsed = acq_parser.parse(f)
parsed = {**acq_parsed, **log_parsed} parsed = {**acq_parsed, **log_parsed}
# convert data to having time stamps
for key, value in parsed.items(): for key, value in parsed.items():
if isinstance(value, datetime): if isinstance(value, datetime):
parsed[key] = datetime_to_timestamp(value) parsed[key] = datetime_to_timestamp(value)
# flatten dict
parsed_flattened = flatten_dict(parsed) parsed_flattened = flatten_dict(parsed)
for k, v in parsed_flattened.items(): for k, v in parsed_flattened.items():
if isinstance(v, list): if isinstance(v, list):
# replace None with 0
parsed_flattened[k] = [0 if el is None else el for el in v] parsed_flattened[k] = [0 if el is None else el for el in v]
return parsed_flattened return parsed_flattened
def get_meta_swainlab(parsed_metadata: dict): def find_channels_by_position(meta):
""" """
Convert raw parsing of Swainlab logfile to the metadata interface. Parse metadata to find the imaging channels for each group.
Input: Return a dict with groups as keys and channels as values.
-------- """
parsed_metadata: Dict[str, str or int or DataFrame or Dict] if isinstance(meta, pd.DataFrame):
default['general', 'image_config', 'device_properties', 'group_position', 'group_time', 'group_config'] imaging_channels = list(meta.columns)
channels_dict = {group: [] for group in meta.index}
for group in channels_dict:
for channel in imaging_channels:
if meta.loc[group, channel] is not None:
channels_dict[group].append(channel)
elif isinstance(meta, dict):
channels_dict = {
position_name: [] for position_name in meta["positions/posname"]
}
imaging_channels = meta["channels"]
for i, position_name in enumerate(meta["positions/posname"]):
for imaging_channel in imaging_channels:
if (
"positions/" + imaging_channel in meta
and meta["positions/" + imaging_channel][i]
):
channels_dict[position_name].append(imaging_channel)
else:
channels_dict = {}
return channels_dict
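# A minimal sketch of find_channels_by_position on the DataFrame form of
# group_config; group names, channels, and exposure values are hypothetical.
# The object dtype keeps None entries as None rather than NaN.
example_config = pd.DataFrame(
    {"Brightfield": [60, 60], "GFP": [100, None]},
    index=["group_1", "group_2"],
    dtype=object,
)
find_channels_by_position(example_config)
# {'group_1': ['Brightfield', 'GFP'], 'group_2': ['Brightfield']}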
def get_minimal_meta_swainlab(parsed_metadata: dict):
"""
Extract channels from parsed metadata.
Returns: Parameters
-------- --------
Dictionary with metadata following the standard parsed_metadata: dict[str, str or int or DataFrame or Dict]
default['general', 'image_config', 'device_properties',
'group_position', 'group_time', 'group_config']
Returns
--------
Dict with channels metadata
""" """
channels_dict = find_channels_by_position(parsed_metadata["group_config"])
channels = parsed_metadata["image_config"]["Image config"].values.tolist() channels = parsed_metadata["image_config"]["Image config"].values.tolist()
# nframes = int(parsed_metadata["group_time"]["frames"].max()) ntps = parsed_metadata["group_time"]["frames"].max()
timeinterval = parsed_metadata["group_time"]["interval"].min()
# return {"channels": channels, "nframes": nframes} minimal_meta = {
return {"channels": channels} "channels_by_group": channels_dict,
"channels": channels,
"time_settings/ntimepoints": int(ntps),
"time_settings/timeinterval": int(timeinterval),
}
return minimal_meta
def get_meta_from_legacy(parsed_metadata: dict): def get_meta_from_legacy(parsed_metadata: dict):
"""Fix naming convention for channels in legacy .txt log files."""
result = parsed_metadata result = parsed_metadata
result["channels"] = result["channels/channel"] result["channels"] = result["channels/channel"]
return result return result
def parse_swainlab_metadata(filedir: t.Union[str, Path]): def parse_swainlab_metadata(filedir: t.Union[str, Path]):
""" """Parse new, .log, and old, .txt, files in a directory into a dict."""
Dispatcher function that determines which parser to use based on the file ending.
Input:
--------
filedir: Directory where the logfile is located.
Returns:
--------
Dictionary with minimal metadata
"""
filedir = Path(filedir) filedir = Path(filedir)
filepath = find_file(filedir, "*.log") filepath = find_file(filedir, "*.log")
if filepath: if filepath:
# new log files ending in .log
raw_parse = parse_from_swainlab_grammar(filepath) raw_parse = parse_from_swainlab_grammar(filepath)
minimal_meta = get_meta_swainlab(raw_parse) minimal_meta = get_minimal_meta_swainlab(raw_parse)
else: else:
# old log files ending in .txt
if filedir.is_file() or str(filedir).endswith(".zarr"): if filedir.is_file() or str(filedir).endswith(".zarr"):
# log file is in parent directory
filedir = filedir.parent filedir = filedir.parent
legacy_parse = parse_logfiles(filedir) legacy_parse = parse_logfiles(filedir)
minimal_meta = ( minimal_meta = (
get_meta_from_legacy(legacy_parse) if legacy_parse else {} get_meta_from_legacy(legacy_parse) if legacy_parse else {}
) )
return minimal_meta return minimal_meta
def dispatch_metadata_parser(filepath: t.Union[str, Path]): def dispatch_metadata_parser(filepath: t.Union[str, Path]):
""" """
Function to dispatch different metadata parsers that convert logfiles into a Dispatch different metadata parsers that convert logfiles into a dictionary.
basic metadata dictionary. Currently only contains the swainlab log parsers.
Currently only contains the swainlab log parsers.
Input: Parameters
-------- --------
filepath: str existing file containing metadata, or folder containing naming conventions filepath: str
File containing metadata or folder containing naming conventions.
""" """
parsed_meta = parse_swainlab_metadata(filepath) parsed_meta = parse_swainlab_metadata(filepath)
if parsed_meta is None: if parsed_meta is None:
# try to deduce metadata
parsed_meta = dir_to_meta(Path(filepath)) parsed_meta = dir_to_meta(Path(filepath))
return parsed_meta return parsed_meta
def dir_to_meta(path: Path, suffix="tiff"): def dir_to_meta(path: Path, suffix="tiff"):
"""Deduce meta data from the naming convention of tiff files."""
filenames = list(path.glob(f"*.{suffix}")) filenames = list(path.glob(f"*.{suffix}"))
try: try:
# Deduct order from filenames # deduce order from filenames
dimorder = "".join( dim_order = "".join(
map(lambda x: x[0], filenames[0].stem.split("_")[1:]) map(lambda x: x[0], filenames[0].stem.split("_")[1:])
) )
dim_value = list( dim_value = list(
map( map(
lambda f: filename_to_dict_indices(f.stem), lambda f: filename_to_dict_indices(f.stem),
path.glob("*.tiff"), path.glob("*.tiff"),
) )
) )
maxes = [max(map(lambda x: x[dim], dim_value)) for dim in dimorder] maxs = [max(map(lambda x: x[dim], dim_value)) for dim in dim_order]
mins = [min(map(lambda x: x[dim], dim_value)) for dim in dimorder] mins = [min(map(lambda x: x[dim], dim_value)) for dim in dim_order]
_dim_shapes = [ dim_shapes = [
max_val - min_val + 1 for max_val, min_val in zip(maxes, mins) max_val - min_val + 1 for max_val, min_val in zip(maxs, mins)
] ]
meta = { meta = {
"size_" + dim: shape for dim, shape in zip(dimorder, _dim_shapes) "size_" + dim: shape for dim, shape in zip(dim_order, dim_shapes)
} }
except Exception as e: except Exception as e:
print( print(
f"Warning:Metadata: Cannot extract dimensions from filenames. Empty meta set {e}" "Warning:Metadata: Cannot extract dimensions from filenames."
f" Empty meta set {e}"
) )
meta = {} meta = {}
return meta return meta
def filename_to_dict_indices(stem: str): def filename_to_dict_indices(stem: str):
"""Convert a file name into a dict by splitting."""
return { return {
dim_number[0]: int(dim_number[1:]) dim_number[0]: int(dim_number[1:])
for dim_number in stem.split("_")[1:] for dim_number in stem.split("_")[1:]
......
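# A small sketch of filename_to_dict_indices on a hypothetical tiff stem that
# follows the "<position>_<dim><index>_..." naming convention assumed by
# dir_to_meta.
filename_to_dict_indices("pos001_t0005_c02_z003")
# {'t': 5, 'c': 2, 'z': 3}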
...@@ -5,7 +5,7 @@ import h5py ...@@ -5,7 +5,7 @@ import h5py
import numpy as np import numpy as np
from agora.io.bridge import groupsort from agora.io.bridge import groupsort
from agora.io.writer import load_attributes from agora.io.writer import load_meta
class DynamicReader: class DynamicReader:
...@@ -13,7 +13,7 @@ class DynamicReader: ...@@ -13,7 +13,7 @@ class DynamicReader:
def __init__(self, file: str): def __init__(self, file: str):
self.file = file self.file = file
self.metadata = load_attributes(file) self.metadata = load_meta(file)
class StateReader(DynamicReader): class StateReader(DynamicReader):
......
...@@ -9,9 +9,10 @@ import h5py ...@@ -9,9 +9,10 @@ import h5py
import numpy as np import numpy as np
import pandas as pd import pandas as pd
import aliby.global_parameters as global_parameters
from agora.io.bridge import BridgeH5 from agora.io.bridge import BridgeH5
from agora.io.decorators import _first_arg_str_to_df from agora.io.decorators import _first_arg_str_to_raw_df
from agora.utils.indexing import validate_association from agora.utils.indexing import validate_lineage
from agora.utils.kymograph import add_index_levels from agora.utils.kymograph import add_index_levels
from agora.utils.merge import apply_merges from agora.utils.merge import apply_merges
...@@ -20,11 +21,13 @@ class Signal(BridgeH5): ...@@ -20,11 +21,13 @@ class Signal(BridgeH5):
""" """
Fetch data from h5 files for post-processing. Fetch data from h5 files for post-processing.
Signal assumes that the metadata and data are accessible to perform time-adjustments and apply previously recorded post-processes. Signal assumes that the metadata and data are accessible to
perform time-adjustments and apply previously recorded
post-processes.
""" """
def __init__(self, file: t.Union[str, Path]): def __init__(self, file: t.Union[str, Path]):
"""Define index_names for dataframes, candidate fluorescence channels, and composite statistics.""" """Initialise defining index names for the dataframe."""
super().__init__(file, flag=None) super().__init__(file, flag=None)
self.index_names = ( self.index_names = (
"experiment", "experiment",
...@@ -33,51 +36,33 @@ class Signal(BridgeH5): ...@@ -33,51 +36,33 @@ class Signal(BridgeH5):
"cell_label", "cell_label",
"mother_label", "mother_label",
) )
self.candidate_channels = ( self.candidate_channels = global_parameters.possible_imaging_channels
"GFP",
"GFPFast",
"mCherry",
"Flavin",
"Citrine",
"mKO2",
"Cy5",
"pHluorin405",
)
def __getitem__(self, dsets: t.Union[str, t.Collection]): def get(
"""Get and potentially pre-process data from h5 file and return as a dataframe.""" self,
if isinstance(dsets, str): # no pre-processing dset: t.Union[str, t.Collection],
return self.get(dsets) tmax_in_mins: int = None,
elif isinstance(dsets, list): # pre-processing ):
is_bgd = [dset.endswith("imBackground") for dset in dsets] """Get Signal after merging and picking."""
# Check we are not comparing tile-indexed and cell-indexed data if isinstance(dset, str):
assert sum(is_bgd) == 0 or sum(is_bgd) == len( record = self.get_raw(dset, tmax_in_mins=tmax_in_mins)
dsets if record is not None:
), "Tile data and cell data can't be mixed" picked_merged = self.apply_merging_picking(record)
return [self.get(dset) for dset in dsets] return self.add_name(picked_merged, dset)
elif isinstance(dset, list):
return [self.get(d) for d in dset]
else: else:
raise Exception(f"Invalid type {type(dsets)} to get datasets") raise Exception("Error in Signal.get")
def get(self, dsets: t.Union[str, t.Collection], **kwargs):
"""Get and potentially pre-process data from h5 file and return as a dataframe."""
if isinstance(dsets, str): # no pre-processing
df = self.get_raw(dsets, **kwargs)
prepost_applied = self.apply_prepost(dsets, **kwargs)
return self.add_name(prepost_applied, dsets)
@staticmethod @staticmethod
def add_name(df, name): def add_name(df, name):
"""Add column of identical strings to a dataframe.""" """Add name of the Signal as an attribute to its data frame."""
df.name = name df.name = name
return df return df
def cols_in_mins(self, df: pd.DataFrame): def cols_in_mins(self, df: pd.DataFrame):
# Convert numerical columns in a dataframe to minutes """Convert numerical columns in a data frame to minutes."""
try: df.columns = (df.columns * self.tinterval // 60).astype(int)
df.columns = (df.columns * self.tinterval // 60).astype(int)
except Exception as e:
self._log(f"Unable to convert columns to minutes: {e}", "debug")
return df return df
@cached_property @cached_property
...@@ -88,20 +73,39 @@ class Signal(BridgeH5): ...@@ -88,20 +73,39 @@ class Signal(BridgeH5):
@cached_property @cached_property
def tinterval(self) -> int: def tinterval(self) -> int:
"""Find the interval between time points (minutes).""" """Find the interval between time points (seconds)."""
tinterval_location = "time_settings/timeinterval" tinterval_location = "time_settings/timeinterval"
with h5py.File(self.filename, "r") as f: with h5py.File(self.filename, "r") as f:
if tinterval_location in f.attrs: if tinterval_location in f.attrs:
return f.attrs[tinterval_location][0] res = f.attrs[tinterval_location]
if type(res) is list:
return res[0]
else:
return res
else: else:
logging.getlogger("aliby").warn( logging.getLogger("aliby").warn(
f"{str(self.filename).split('/')[-1]}: using default time interval of 5 minutes" f"{str(self.filename).split('/')[-1]}: using default time interval of 300 seconds."
) )
return 5 return 300
def retained(self, signal, cutoff: float = 0, tmax_in_mins: int = None):
"""Get retained cells for a Signal or list of Signals."""
if isinstance(signal, str):
# get data frame
signal = self.get(signal, tmax_in_mins=tmax_in_mins)
if isinstance(signal, pd.DataFrame):
return self.get_retained(signal, cutoff)
elif isinstance(signal, list):
return [self.get_retained(d, cutoff=cutoff) for d in signal]
@staticmethod @staticmethod
def get_retained(df, cutoff): def get_retained(df, cutoff):
"""Return a fraction of the df, one without later time points.""" """
Return sub data frame with retained cells.
Cells must be present for at least cutoff fraction of the total number
of time points.
"""
return df.loc[bn.nansum(df.notna(), axis=1) > df.shape[1] * cutoff] return df.loc[bn.nansum(df.notna(), axis=1) > df.shape[1] * cutoff]
@property @property
...@@ -110,20 +114,6 @@ class Signal(BridgeH5): ...@@ -110,20 +114,6 @@ class Signal(BridgeH5):
with h5py.File(self.filename, "r") as f: with h5py.File(self.filename, "r") as f:
return list(f.attrs["channels"]) return list(f.attrs["channels"])
@_first_arg_str_to_df
def retained(self, signal, cutoff=0.8):
"""
Load data (via decorator) and reduce the resulting dataframe.
Load data for a signal or a list of signals and reduce the resulting
dataframes to a fraction of their original size, losing late time
points.
"""
if isinstance(signal, pd.DataFrame):
return self.get_retained(signal, cutoff)
elif isinstance(signal, list):
return [self.get_retained(d, cutoff=cutoff) for d in signal]
@lru_cache(2) @lru_cache(2)
def lineage( def lineage(
self, lineage_location: t.Optional[str] = None, merged: bool = False self, lineage_location: t.Optional[str] = None, merged: bool = False
...@@ -131,91 +121,72 @@ class Signal(BridgeH5): ...@@ -131,91 +121,72 @@ class Signal(BridgeH5):
""" """
Get lineage data from a given location in the h5 file. Get lineage data from a given location in the h5 file.
Returns an array with three columns: the tile id, the mother label, and the daughter label. Returns an array with three columns: the tile id, the mother label,
and the daughter label.
""" """
if lineage_location is None: if lineage_location is None:
lineage_location = "modifiers/lineage_merged" lineage_location = "modifiers/lineage_merged"
with h5py.File(self.filename, "r") as f: with h5py.File(self.filename, "r") as f:
# if lineage_location not in f:
# lineage_location = lineage_location.split("_")[0]
if lineage_location not in f: if lineage_location not in f:
lineage_location = "postprocessing/lineage" lineage_location = "postprocessing/lineage"
tile_mo_da = f[lineage_location] traps_mothers_daughters = f[lineage_location]
if isinstance(traps_mothers_daughters, h5py.Dataset):
if isinstance(tile_mo_da, h5py.Dataset): lineage = traps_mothers_daughters[()]
lineage = tile_mo_da[()]
else: else:
lineage = np.array( lineage = np.array(
( (
tile_mo_da["trap"], traps_mothers_daughters["trap"],
tile_mo_da["mother_label"], traps_mothers_daughters["mother_label"],
tile_mo_da["daughter_label"], traps_mothers_daughters["daughter_label"],
) )
).T ).T
return lineage return lineage
@_first_arg_str_to_df # @_first_arg_str_to_raw_df
def apply_prepost( def apply_merging_picking(
self, self,
data: t.Union[str, pd.DataFrame], data: t.Union[str, pd.DataFrame],
merges: t.Union[np.ndarray, bool] = True, merges: t.Union[np.ndarray, bool] = True,
picks: t.Union[t.Collection, bool] = True, picks: t.Union[t.Collection, bool] = True,
): ):
""" """
Apply modifier operations (picker or merger) to a dataframe. Apply picking and merging to a Signal data frame.
Parameters Parameters
---------- ----------
data : t.Union[str, pd.DataFrame] data : t.Union[str, pd.DataFrame]
DataFrame or path to one. A data frame or a path to one.
merges : t.Union[np.ndarray, bool] merges : t.Union[np.ndarray, bool]
(optional) 2-D array with three columns: the tile id, the mother label, and the daughter id. (optional) An array of pairs of (trap, cell) indices to merge.
If True, fetch merges from file. If True, fetch merges from file.
picks : t.Union[np.ndarray, bool] picks : t.Union[np.ndarray, bool]
(optional) 2-D array with two columns: the tiles and (optional) An array of (trap, cell) indices.
the cell labels.
If True, fetch picks from file. If True, fetch picks from file.
Examples
--------
FIXME: Add docs.
""" """
if isinstance(merges, bool): if isinstance(merges, bool):
merges: np.ndarray = self.load_merges() if merges else np.array([]) merges = self.load_merges() if merges else np.array([])
if merges.any(): if merges.any():
merged = apply_merges(data, merges) merged = apply_merges(data, merges)
else: else:
merged = copy(data) merged = copy(data)
if isinstance(picks, bool): if isinstance(picks, bool):
picks = ( if picks is True:
self.get_picks(names=merged.index.names) # load picks from h5
if picks picks = self.get_picks(
else set(merged.index) names=merged.index.names, path="modifiers/picks/"
)
else:
return merged
if len(picks):
picked_indices = set(picks).intersection(
[tuple(x) for x in merged.index]
) )
with h5py.File(self.filename, "r") as f: return merged.loc[picked_indices]
if "modifiers/picks" in f and picks: else:
if picks: return merged
return merged.loc[
set(picks).intersection(
[tuple(x) for x in merged.index]
)
]
else:
if isinstance(merged.index, pd.MultiIndex):
empty_lvls = [[] for i in merged.index.names]
index = pd.MultiIndex(
levels=empty_lvls,
codes=empty_lvls,
names=merged.index.names,
)
else:
index = pd.Index([], name=merged.index.name)
merged = pd.DataFrame([], index=index)
return merged
@cached_property @cached_property
def p_available(self): def print_available(self):
"""Print data sets available in h5 file.""" """Print data sets available in h5 file."""
if not hasattr(self, "_available"): if not hasattr(self, "_available"):
self._available = [] self._available = []
...@@ -233,13 +204,12 @@ class Signal(BridgeH5): ...@@ -233,13 +204,12 @@ class Signal(BridgeH5):
with h5py.File(self.filename, "r") as f: with h5py.File(self.filename, "r") as f:
f.visititems(self.store_signal_path) f.visititems(self.store_signal_path)
except Exception as e: except Exception as e:
self._log("Exception when visiting h5: {}".format(e), "exception") self.log("Exception when visiting h5: {}".format(e), "exception")
return self._available return self._available
def get_merged(self, dataset): def get_merged(self, dataset):
"""Run preprocessing for merges.""" """Run merging."""
return self.apply_prepost(dataset, picks=False) return self.apply_merging_picking(dataset, picks=False)
@cached_property @cached_property
def merges(self) -> np.ndarray: def merges(self) -> np.ndarray:
...@@ -265,43 +235,73 @@ class Signal(BridgeH5): ...@@ -265,43 +235,73 @@ class Signal(BridgeH5):
dataset: str or t.List[str], dataset: str or t.List[str],
in_minutes: bool = True, in_minutes: bool = True,
lineage: bool = False, lineage: bool = False,
tmax_in_mins: int = None,
**kwargs,
) -> pd.DataFrame or t.List[pd.DataFrame]: ) -> pd.DataFrame or t.List[pd.DataFrame]:
""" """
Load data from a h5 file and return as a dataframe. Get raw Signal without merging, picking, and lineage information.
Parameters Parameters
---------- ----------
dataset: str or list of strs dataset: str or list of strs
The name of the h5 file or a list of h5 file names The name of the h5 file or a list of h5 file names.
in_minutes: boolean in_minutes: boolean
If True, If True, convert column headings to times in minutes.
lineage: boolean lineage: boolean
If True, add mother_label to index.
run_lineage_check: boolean
If True, raise exception if a likely error in the lineage assignment.
tmax_in_mins: int (optional)
Discard data for times > tmax_in_mins. Cells with all NaNs will also
be discarded to help with assigning lineages.
Setting tmax_in_mins is a way to ignore parts of the experiment with
incorrect lineages generated by clogging.
""" """
try: if isinstance(dataset, str):
if isinstance(dataset, str): with h5py.File(self.filename, "r") as f:
with h5py.File(self.filename, "r") as f: df = self.dataset_to_df(f, dataset)
df = self.dataset_to_df(f, dataset).sort_index() if df is not None:
df = df.sort_index()
if in_minutes: if in_minutes:
df = self.cols_in_mins(df) df = self.cols_in_mins(df)
elif isinstance(dataset, list): # limit data by time and discard NaNs
return [ if (
self.get_raw(dset, in_minutes=in_minutes, lineage=lineage) in_minutes
for dset in dataset and tmax_in_mins
] and type(tmax_in_mins) is int
if lineage: # assume that df is sorted ):
mother_label = np.zeros(len(df), dtype=int) df = df[df.columns[df.columns <= tmax_in_mins]]
lineage = self.lineage() df = df.dropna(how="all")
a, b = validate_association( # add mother label to data frame
lineage, if lineage:
np.array(df.index.to_list()), if "mother_label" in df.index.names:
match_column=1, df = df.droplevel("mother_label")
mother_label = np.zeros(len(df), dtype=int)
lineage = self.lineage()
(
valid_lineage,
valid_indices,
lineage,
) = validate_lineage(
lineage,
indices=np.array(df.index.to_list()),
how="daughters",
)
mother_label[valid_indices] = lineage[valid_lineage, 1]
df = add_index_levels(
df, {"mother_label": mother_label}
)
return df
elif isinstance(dataset, list):
return [
self.get_raw(
dset,
in_minutes=in_minutes,
lineage=lineage,
tmax_in_mins=tmax_in_mins,
) )
mother_label[b] = lineage[a, 1] for dset in dataset
df = add_index_levels(df, {"mother_label": mother_label}) ]
return df
except Exception as e:
self._log(f"Could not fetch dataset {dataset}: {e}", "error")
raise e
def load_merges(self): def load_merges(self):
"""Get merge events going up to the first level.""" """Get merge events going up to the first level."""
...@@ -316,32 +316,36 @@ class Signal(BridgeH5): ...@@ -316,32 +316,36 @@ class Signal(BridgeH5):
names: t.Tuple[str, ...] = ("trap", "cell_label"), names: t.Tuple[str, ...] = ("trap", "cell_label"),
path: str = "modifiers/picks/", path: str = "modifiers/picks/",
) -> t.Set[t.Tuple[int, str]]: ) -> t.Set[t.Tuple[int, str]]:
"""Get the relevant picks based on names.""" """Get picks from the h5 file."""
with h5py.File(self.filename, "r") as f: with h5py.File(self.filename, "r") as f:
picks = set()
if path in f: if path in f:
picks = set( picks = set(
zip(*[f[path + name] for name in names if name in f[path]]) zip(*[f[path + name] for name in names if name in f[path]])
) )
else:
picks = set()
return picks return picks
def dataset_to_df(self, f: h5py.File, path: str) -> pd.DataFrame: def dataset_to_df(self, f: h5py.File, path: str) -> pd.DataFrame:
"""Get data from h5 file as a dataframe.""" """Get data from h5 file as a dataframe."""
assert path in f, f"{path} not in {f}" if path not in f:
dset = f[path] self.log(f"{path} not in {f}.")
values, index, columns = [], [], [] return None
index_names = copy(self.index_names) else:
valid_names = [lbl for lbl in index_names if lbl in dset.keys()] dset = f[path]
if valid_names: values, index, columns = [], [], []
index = pd.MultiIndex.from_arrays( index_names = copy(self.index_names)
[dset[lbl] for lbl in valid_names], names=valid_names valid_names = [lbl for lbl in index_names if lbl in dset.keys()]
) if valid_names:
columns = dset.attrs.get("columns", None) index = pd.MultiIndex.from_arrays(
if "timepoint" in dset: [dset[lbl] for lbl in valid_names], names=valid_names
columns = f[path + "/timepoint"][()] )
values = f[path + "/values"][()] columns = dset.attrs.get("columns", None)
df = pd.DataFrame(values, index=index, columns=columns) if "timepoint" in dset:
return df columns = f[path + "/timepoint"][()]
values = f[path + "/values"][()]
df = pd.DataFrame(values, index=index, columns=columns)
return df
@property @property
def stem(self): def stem(self):
...@@ -353,10 +357,7 @@ class Signal(BridgeH5): ...@@ -353,10 +357,7 @@ class Signal(BridgeH5):
fullname: str, fullname: str,
node: t.Union[h5py.Dataset, h5py.Group], node: t.Union[h5py.Dataset, h5py.Group],
): ):
""" """Store the name of a signal if it is a leaf node and if it starts with extraction."""
Store the name of a signal if it is a leaf node
(a group with no more groups inside) and if it starts with extraction.
"""
if isinstance(node, h5py.Group) and np.all( if isinstance(node, h5py.Group) and np.all(
[isinstance(x, h5py.Dataset) for x in node.values()] [isinstance(x, h5py.Dataset) for x in node.values()]
): ):
...@@ -379,62 +380,7 @@ class Signal(BridgeH5): ...@@ -379,62 +380,7 @@ class Signal(BridgeH5):
if isinstance(obj, h5py.Group) and name.endswith("picks"): if isinstance(obj, h5py.Group) and name.endswith("picks"):
return obj[()] return obj[()]
# TODO FUTURE add stages support to fluigent system
@property @property
def ntps(self) -> int: def ntps(self) -> int:
"""Get number of time points from the metadata.""" """Get number of time points from the metadata."""
return self.meta_h5["time_settings/ntimepoints"][0] return self.meta_h5["time_settings/ntimepoints"][0]
@property
def stages(self) -> t.List[str]:
"""Get the contents of the pump with highest flow rate at each stage."""
flowrate_name = "pumpinit/flowrate"
pumprate_name = "pumprate"
switchtimes_name = "switchtimes"
main_pump_id = np.concatenate(
(
(np.argmax(self.meta_h5[flowrate_name]),),
np.argmax(self.meta_h5[pumprate_name], axis=0),
)
)
if not self.meta_h5[switchtimes_name][0]: # Cover for t0 switches
main_pump_id = main_pump_id[1:]
return [self.meta_h5["pumpinit/contents"][i] for i in main_pump_id]
@property
def nstages(self) -> int:
return len(self.switch_times) + 1
@property
def max_span(self) -> int:
return int(self.tinterval * self.ntps / 60)
@property
def switch_times(self) -> t.List[int]:
switchtimes_name = "switchtimes"
switches_minutes = self.meta_h5[switchtimes_name]
return [
t_min
for t_min in switches_minutes
if t_min and t_min < self.max_span
] # Cover for t0 switches
@property
def stages_span(self) -> t.Tuple[t.Tuple[str, int], ...]:
"""Get consecutive stages and their corresponding number of time points."""
transition_tps = (0, *self.switch_times, self.max_span)
spans = [
end - start
for start, end in zip(transition_tps[:-1], transition_tps[1:])
if end <= self.max_span
]
return tuple((stage, ntps) for stage, ntps in zip(self.stages, spans))
@property
def stages_span_tp(self) -> t.Tuple[t.Tuple[str, int], ...]:
return tuple(
[
(name, (t_min * 60) // self.tinterval)
for name, t_min in self.stages_span
]
)
...@@ -15,9 +15,10 @@ from agora.io.bridge import BridgeH5 ...@@ -15,9 +15,10 @@ from agora.io.bridge import BridgeH5
#################### Dynamic version ################################## #################### Dynamic version ##################################
def load_attributes(file: str, group="/"): def load_meta(file: str, group="/"):
""" """
Load the metadata from an h5 file and convert to a dictionary, including the "parameters" field which is stored as YAML. Load the metadata from an h5 file and convert to a dictionary, including
the "parameters" field which is stored as YAML.
Parameters Parameters
---------- ----------
...@@ -26,8 +27,9 @@ def load_attributes(file: str, group="/"): ...@@ -26,8 +27,9 @@ def load_attributes(file: str, group="/"):
group: str, optional group: str, optional
The group in the h5 file from which to read the data The group in the h5 file from which to read the data
""" """
# load the metadata, stored as attributes, from the h5 file and return as a dictionary # load the metadata, stored as attributes, from the h5 file
with h5py.File(file, "r") as f: with h5py.File(file, "r") as f:
# return as a dict
meta = dict(f[group].attrs.items()) meta = dict(f[group].attrs.items())
if "parameters" in meta: if "parameters" in meta:
# convert from yaml format into dict # convert from yaml format into dict
...@@ -51,9 +53,9 @@ class DynamicWriter: ...@@ -51,9 +53,9 @@ class DynamicWriter:
self.file = file self.file = file
# the metadata is stored as attributes in the h5 file # the metadata is stored as attributes in the h5 file
if Path(file).exists(): if Path(file).exists():
self.metadata = load_attributes(file) self.metadata = load_meta(file)
def _log(self, message: str, level: str = "warn"): def log(self, message: str, level: str = "warn"):
# Log messages in the corresponding level # Log messages in the corresponding level
logger = logging.getLogger("aliby") logger = logging.getLogger("aliby")
getattr(logger, level)(f"{self.__class__.__name__}: {message}") getattr(logger, level)(f"{self.__class__.__name__}: {message}")
...@@ -102,9 +104,11 @@ class DynamicWriter: ...@@ -102,9 +104,11 @@ class DynamicWriter:
maxshape=max_shape, maxshape=max_shape,
dtype=dtype, dtype=dtype,
compression=self.compression, compression=self.compression,
compression_opts=self.compression_opts compression_opts=(
if self.compression is not None self.compression_opts
else None, if self.compression is not None
else None
),
) )
# write all data, signified by the empty tuple # write all data, signified by the empty tuple
hgroup[key][()] = data hgroup[key][()] = data
...@@ -172,7 +176,7 @@ class DynamicWriter: ...@@ -172,7 +176,7 @@ class DynamicWriter:
# append or create new dataset # append or create new dataset
self._append(value, key, hgroup) self._append(value, key, hgroup)
except Exception as e: except Exception as e:
self._log( self.log(
f"{key}:{value} could not be written: {e}", "error" f"{key}:{value} could not be written: {e}", "error"
) )
# write metadata # write metadata
...@@ -448,7 +452,6 @@ class Writer(BridgeH5): ...@@ -448,7 +452,6 @@ class Writer(BridgeH5):
""" """
self.id_cache = {} self.id_cache = {}
with h5py.File(self.filename, "a") as f: with h5py.File(self.filename, "a") as f:
# Alan, haven't we already opened the h5 file through BridgeH5's init?
if overwrite == "overwrite": # TODO refactor overwriting if overwrite == "overwrite": # TODO refactor overwriting
if path in f: if path in f:
del f[path] del f[path]
...@@ -490,7 +493,12 @@ class Writer(BridgeH5): ...@@ -490,7 +493,12 @@ class Writer(BridgeH5):
def write_meta(self, f: h5py.File, path: str, attr: str, data: Iterable): def write_meta(self, f: h5py.File, path: str, attr: str, data: Iterable):
"""Write metadata to an open h5 file.""" """Write metadata to an open h5 file."""
obj = f.require_group(path) obj = f.require_group(path)
obj.attrs[attr] = data if type(data) is dict:
# necessary for channels_dict from find_channels_by_position
for key, vlist in data.items():
obj.attrs[attr + key] = vlist
else:
obj.attrs[attr] = data
@staticmethod @staticmethod
def write_arraylike(f: h5py.File, path: str, data: Iterable, **kwargs): def write_arraylike(f: h5py.File, path: str, data: Iterable, **kwargs):
...@@ -535,7 +543,6 @@ class Writer(BridgeH5): ...@@ -535,7 +543,6 @@ class Writer(BridgeH5):
path + "values" if path.endswith("/") else path + "/values" path + "values" if path.endswith("/") else path + "/values"
) )
if path not in f: if path not in f:
# create dataset and write data # create dataset and write data
max_ncells = 2e5 max_ncells = 2e5
max_tps = 1e3 max_tps = 1e3
...@@ -581,7 +588,6 @@ class Writer(BridgeH5): ...@@ -581,7 +588,6 @@ class Writer(BridgeH5):
else: else:
f[path].attrs["columns"] = df.columns.tolist() f[path].attrs["columns"] = df.columns.tolist()
else: else:
# path exists # path exists
dset = f[values_path] dset = f[values_path]
...@@ -618,9 +624,9 @@ class Writer(BridgeH5): ...@@ -618,9 +624,9 @@ class Writer(BridgeH5):
# sort indices for h5 indexing # sort indices for h5 indexing
incremental_existing = np.argsort(found_indices) incremental_existing = np.argsort(found_indices)
self.id_cache[df.index.nlevels][ self.id_cache[df.index.nlevels]["found_indices"] = (
"found_indices" found_indices[incremental_existing]
] = found_indices[incremental_existing] )
self.id_cache[df.index.nlevels]["found_multi"] = found_multis[ self.id_cache[df.index.nlevels]["found_multi"] = found_multis[
incremental_existing incremental_existing
] ]
......
#!/usr/bin/env jupyter
"""
Add general logging functions and decorators
"""
import logging import logging
from time import perf_counter from time import perf_counter
def timer(func): def timer(func):
# Log duration of a function into aliby logfile """Log duration of a function into the aliby log file."""
def wrap_func(*args, **kwargs): def wrap_func(*args, **kwargs):
t1 = perf_counter() t1 = perf_counter()
result = func(*args, **kwargs) result = func(*args, **kwargs)
......
#!/usr/bin/env jupyter
"""
Utilities based on association are used to efficiently acquire indices of tracklets with some kind of relationship.
This can be:
- Cells that are to be merged
- Cells that have a linear relationship
"""
import numpy as np import numpy as np
import typing as t import pandas as pd
# data type to link together trap and cell ids
i_dtype = {"names": ["trap_id", "cell_id"], "formats": [np.int64, np.int64]}
def validate_association(
association: np.ndarray,
indices: np.ndarray,
match_column: t.Optional[int] = None,
) -> t.Tuple[np.ndarray, np.ndarray]:
"""Select rows from the first array that are present in both.
We use casting for fast multiindexing, generalising for lineage dynamics
Parameters
----------
association : np.ndarray
2-D array where columns are (trap, mother, daughter) or 3-D array where
dimensions are (X,trap,2), containing tuples ((trap,mother), (trap,daughter))
across the 3rd dimension.
indices : np.ndarray
2-D array where each column is a different level. This should not include mother_label.
match_column: int
int indicating a specific column is required to match (i.e.
0-1 for target-source when trying to merge tracklets or mother-bud for lineage)
must be present in indices. If it is false one match suffices for the resultant indices
vector to be True.
Returns
-------
np.ndarray
1-D boolean array indicating valid merge events.
np.ndarray
1-D boolean array indicating indices with an association relationship.
Examples
--------
>>> import numpy as np
>>> from agora.utils.indexing import validate_association
>>> merges = np.array(range(12)).reshape(3,2,2)
>>> indices = np.array(range(6)).reshape(3,2)
>>> print(merges, indices)
>>> print(merges); print(indices)
[[[ 0 1]
[ 2 3]]
[[ 4 5]
[ 6 7]]
[[ 8 9]
[10 11]]]
[[0 1]
[2 3]
[4 5]]
>>> valid_associations, valid_indices = validate_association(merges, indices)
>>> print(valid_associations, valid_indices)
[ True False False] [ True True False]
def validate_lineage(
lineage: np.ndarray,
indices: np.ndarray,
how: str = "families",
):
""" """
if association.ndim == 2: Identify mother-bud pairs both in lineage and a Signal's indices.
# Reshape into 3-D array for broadcasting if neded
# association = np.stack( We expect the lineage information to be unique: a bud should not have
# (association[:, [0, 1]], association[:, [0, 2]]), axis=1 two mothers.
# )
association = _assoc_indices_to_3d(association) Lineage is returned with buds assigned only to their first mother if they
have multiple.
# Compare existing association with available indices
# Swap trap and label axes for the association array to correctly cast Parameters
valid_ndassociation = association[..., None] == indices.T[None, ...] ----------
lineage : np.ndarray
# Broadcasting is confusing (but efficient): 2D array of lineage associations where columns are
# First we check the dimension across trap and cell id, to ensure both match (trap, mother, daughter)
valid_cell_ids = valid_ndassociation.all(axis=2) or
a 3D array, which is an array of 2 X 2 arrays comprising
if match_column is None: [[trap_id, mother_label], [trap_id, daughter_label]].
# Then we check the merge tuples to check which cases have both target and source indices : np.ndarray
valid_association = valid_cell_ids.any(axis=2).all(axis=1) A 2D array of cell indices from a Signal, (trap_id, cell_label).
This array should not include mother_label.
# Finally we check the dimension that crosses all indices, to ensure the pair how: str
# is present in a valid merge event. If "mothers", matches indicate mothers from mother-bud pairs;
valid_indices = ( If "daughters", matches indicate daughters from mother-bud pairs;
valid_ndassociation[valid_association].all(axis=2).any(axis=(0, 1)) If "families", matches indicate mothers and daughters in mother-bud pairs.
)
else: # We fetch specific indices if we aim for the ones with one present Returns
valid_indices = valid_cell_ids[:, match_column].any(axis=0) -------
# Valid association then becomes a boolean array, true means that there is a valid_lineage: boolean np.ndarray
# match (match_column) between that cell and the index 1D array indicating matched elements in lineage.
valid_association = ( valid_indices: boolean np.ndarray
valid_cell_ids[:, match_column] & valid_indices 1D array indicating matched elements in indices.
).any(axis=1) lineage: np.ndarray
Any bud already having a mother that is assigned to another has that
second assignment discarded.
Examples
--------
>>> import numpy as np
>>> from agora.utils.indexing import validate_lineage
>>> lineage = np.array([ [[0, 1], [0, 3]], [[0, 1], [0, 4]], [[0, 1], [0, 6]], [[0, 4], [0, 7]] ])
>>> indices = np.array([ [0, 1], [0, 2], [0, 3]])
>>> valid_lineage, valid_indices, lineage = validate_lineage(lineage, indices)
>>> print(valid_lineage)
array([ True, False, False, False])
>>> print(valid_indices)
array([ True, False, True])
and
>>> lineage = np.array([[[0,3], [0,1]], [[0,2], [0,4]]])
>>> indices = np.array([[0,1], [0,2], [0,3]])
>>> valid_lineage, valid_indices, lineage = validate_lineage(lineage, indices)
>>> print(valid_lineage)
array([ True, False])
>>> print(valid_indices)
array([ True, False, True])
"""
invert_lineage = False
if lineage.ndim == 2:
# [trap, mother, daughter] becomes [[trap, mother], [trap, daughter]]
lineage = assoc_indices_to_3d(lineage)
invert_lineage = True
if how == "mothers":
c_index = 0
elif how == "daughters":
c_index = 1
# if buds have two mothers, pick the first one
lineage = lineage[
~pd.DataFrame(lineage[:, 1, :]).duplicated().values, :, :
]
# find valid lineage
valid_lineages = index_isin(lineage, indices)
if how == "families":
# both mother and bud must be in indices
valid_lineage = valid_lineages.all(axis=1)
else:
valid_lineage = valid_lineages[:, c_index, :]
flat_valid_lineage = valid_lineage.flatten()
# find valid indices
selected_lineages = lineage[flat_valid_lineage, ...]
if how == "families":
# select only pairs of mother and bud indices
valid_indices = index_isin(indices, selected_lineages)
else:
valid_indices = index_isin(indices, selected_lineages[:, c_index, :])
flat_valid_indices = valid_indices.flatten()
# put the corrected lineage in the right format
if invert_lineage:
lineage = assoc_indices_to_2d(lineage)
return flat_valid_lineage, flat_valid_indices, lineage
def index_isin(x: np.ndarray, y: np.ndarray) -> np.ndarray:
"""
Find those elements of x that are in y.
return valid_association, valid_indices Both arrays must be arrays of integer indices,
such as (trap_id, cell_id).
"""
x = np.ascontiguousarray(x, dtype=np.int64)
y = np.ascontiguousarray(y, dtype=np.int64)
xv = x.view(i_dtype)
inboth = np.intersect1d(xv, y.view(i_dtype))
x_bool = np.isin(xv, inboth)
return x_bool
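# A small sketch of index_isin with hypothetical (trap_id, cell_label) pairs;
# viewing the rows through the structured i_dtype lets numpy compare whole
# (trap, cell) pairs at once.
example_x = np.array([[0, 1], [0, 2], [1, 3]])
example_y = np.array([[0, 2], [1, 3], [2, 4]])
index_isin(example_x, example_y)
# array([[False],
#        [ True],
#        [ True]])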
def _assoc_indices_to_3d(ndarray: np.ndarray): def assoc_indices_to_3d(ndarray: np.ndarray):
""" """
Convert the last column to a new row while repeating all previous indices. Convert the last column to a new row and repeat first column's values.
This is useful when converting a signal multiindex before comparing association. For example: [trap, mother, daughter] becomes
[[trap, mother], [trap, daughter]].
Assumes the input array has shape (N,3) Assumes the input array has shape (N,3).
""" """
result = ndarray result = ndarray
if len(ndarray) and ndarray.ndim > 1: if len(ndarray) and ndarray.ndim > 1:
if ndarray.shape[1] == 3: # Faster indexing for single positions # faster indexing for single positions
if ndarray.shape[1] == 3:
result = np.transpose( result = np.transpose(
np.hstack((ndarray[:, [0]], ndarray)).reshape(-1, 2, 2), np.hstack((ndarray[:, [0]], ndarray)).reshape(-1, 2, 2),
axes=[0, 2, 1], axes=[0, 2, 1],
) )
else: # 20% slower but more general indexing else:
# 20% slower but more general indexing
columns = np.arange(ndarray.shape[1]) columns = np.arange(ndarray.shape[1])
result = np.stack( result = np.stack(
( (
ndarray[:, np.delete(columns, -1)], ndarray[:, np.delete(columns, -1)],
...@@ -132,21 +150,11 @@ def _assoc_indices_to_3d(ndarray: np.ndarray): ...@@ -132,21 +150,11 @@ def _assoc_indices_to_3d(ndarray: np.ndarray):
return result return result
def _3d_index_to_2d(array: np.ndarray): def assoc_indices_to_2d(array: np.ndarray):
""" """Convert indices to 2d."""
Opposite to _assoc_indices_to_3d.
"""
result = array result = array
if len(array): if len(array):
result = np.concatenate( result = np.concatenate(
(array[:, 0, :], array[:, 1, 1, np.newaxis]), axis=1 (array[:, 0, :], array[:, 1, 1, np.newaxis]), axis=1
) )
return result return result
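# A round-trip sketch with hypothetical values: a (trap, mother, daughter) row
# becomes a 2 x 2 block of (trap, label) pairs and back again.
example_2d = np.array([[0, 1, 3], [2, 5, 6]])
example_3d = assoc_indices_to_3d(example_2d)
# array([[[0, 1], [0, 3]],
#        [[2, 5], [2, 6]]])
assoc_indices_to_2d(example_3d)
# array([[0, 1, 3], [2, 5, 6]])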
def compare_indices(x: np.ndarray, y: np.ndarray) -> np.ndarray:
"""
Fetch two 2-D indices and return a binary 2-D matrix
where a True value links two cells where all cells are the same
"""
return (x[..., None] == y.T[None, ...]).all(axis=1)
...@@ -6,7 +6,7 @@ import numpy as np ...@@ -6,7 +6,7 @@ import numpy as np
import pandas as pd import pandas as pd
from sklearn.cluster import KMeans from sklearn.cluster import KMeans
from agora.utils.indexing import validate_association # from agora.utils.indexing import validate_association
index_row = t.Tuple[str, str, int, int] index_row = t.Tuple[str, str, int, int]
...@@ -86,16 +86,19 @@ def bidirectional_retainment_filter( ...@@ -86,16 +86,19 @@ def bidirectional_retainment_filter(
daughters_thresh: int = 7, daughters_thresh: int = 7,
) -> pd.DataFrame: ) -> pd.DataFrame:
""" """
Retrieve families where mothers are present for more than a fraction of the experiment, and daughters for longer than some number of time-points. Retrieve families where mothers are present for more than a fraction
of the experiment and daughters for longer than some number of
time-points.
Parameters Parameters
---------- ----------
df: pd.DataFrame df: pd.DataFrame
Data Data
mothers_thresh: float mothers_thresh: float
Minimum fraction of experiment's total duration for which mothers must be present. Minimum fraction of experiment's total duration for which mothers
must be present.
daughters_thresh: int daughters_thresh: int
Minimum number of time points for which daughters must be observed Minimum number of time points for which daughters must be observed.
""" """
# daughters # daughters
all_daughters = df.loc[df.index.get_level_values("mother_label") > 0] all_daughters = df.loc[df.index.get_level_values("mother_label") > 0]
...@@ -170,6 +173,7 @@ def slices_from_spans(spans: t.Tuple[int], df: pd.DataFrame) -> t.List[slice]: ...@@ -170,6 +173,7 @@ def slices_from_spans(spans: t.Tuple[int], df: pd.DataFrame) -> t.List[slice]:
def drop_mother_label(index: pd.MultiIndex) -> np.ndarray: def drop_mother_label(index: pd.MultiIndex) -> np.ndarray:
"""Remove mother_label level from a MultiIndex."""
no_mother_label = index no_mother_label = index
if "mother_label" in index.names: if "mother_label" in index.names:
no_mother_label = index.droplevel("mother_label") no_mother_label = index.droplevel("mother_label")
......
#!/usr/bin/env python3
import re
import typing as t
import numpy as np
import pandas as pd
from agora.io.bridge import groupsort
from itertools import groupby
def mb_array_to_dict(mb_array: np.ndarray):
"""
Convert a lineage ndarray (trap, mother_id, daughter_id)
into a dictionary of lists ( mother_id ->[daughters_ids] )
"""
return {
(trap, mo): [(trap, d[0]) for d in daughters]
for trap, mo_da in groupsort(mb_array).items()
for mo, daughters in groupsort(mo_da).items()
}
...@@ -3,90 +3,161 @@ ...@@ -3,90 +3,161 @@
Functions to efficiently merge rows in DataFrames. Functions to efficiently merge rows in DataFrames.
""" """
import typing as t import typing as t
from copy import copy
import numpy as np import numpy as np
import pandas as pd import pandas as pd
from utils_find_1st import cmp_larger, find_1st from utils_find_1st import cmp_larger, find_1st
from agora.utils.indexing import compare_indices, validate_association from agora.utils.indexing import index_isin
def group_merges(merges: np.ndarray) -> t.List[t.Tuple]:
"""
Convert merges into a list of merges for traps requiring multiple
merges and then for traps requiring single merges.
"""
left_tracks = merges[:, 0]
right_tracks = merges[:, 1]
# find traps requiring multiple merges
linr = merges[index_isin(left_tracks, right_tracks).flatten(), :]
rinl = merges[index_isin(right_tracks, left_tracks).flatten(), :]
# make unique and order merges for each trap
multi_merge = np.unique(np.concatenate((linr, rinl)), axis=0)
# find traps requiring a single merge
single_merge = merges[
~index_isin(merges, multi_merge).all(axis=1).flatten(), :
]
# convert to lists of arrays
single_merge_list = [[sm] for sm in single_merge]
multi_merge_list = [
multi_merge[multi_merge[:, 0, 0] == trap_id, ...]
for trap_id in np.unique(multi_merge[:, 0, 0])
]
res = [*multi_merge_list, *single_merge_list]
return res
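# A sketch of group_merges on hypothetical merges: trap 0 needs two chained
# merges (track 1 -> 2 -> 3), whereas trap 1 needs a single merge (5 -> 6).
example_merges = np.array(
    [
        [[0, 1], [0, 2]],
        [[0, 2], [0, 3]],
        [[1, 5], [1, 6]],
    ]
)
group_merges(example_merges)
# returns the two chained trap-0 merges grouped into one array, followed by a
# one-element list holding the single trap-1 merge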
def merge_lineage(
lineage: np.ndarray, merges: np.ndarray
) -> (np.ndarray, np.ndarray):
"""
Use merges to update lineage information.
Check if merging causes any buds to have multiple mothers and discard
those incorrect merges.
Return updated lineage and merge arrays.
"""
flat_lineage = lineage.reshape(-1, 2)
bud_mother_dict = {
tuple(bud): mother for bud, mother in zip(lineage[:, 1], lineage[:, 0])
}
left_tracks = merges[:, 0]
# find left tracks that are in lineages
valid_lineages = index_isin(flat_lineage, left_tracks).flatten()
# group into multi- and then single merges
grouped_merges = group_merges(merges)
# perform merges
if valid_lineages.any():
# indices of each left track -> indices of rightmost right track
replacement_dict = {
tuple(contig_pair[0]): merge[-1][1]
for merge in grouped_merges
for contig_pair in merge
}
# if both key and value are buds, they must have the same mother
buds = lineage[:, 1]
incorrect_merges = [
key
for key in replacement_dict
if np.any(index_isin(buds, replacement_dict[key]).flatten())
and np.any(index_isin(buds, key).flatten())
and not np.array_equal(
bud_mother_dict[key],
bud_mother_dict[tuple(replacement_dict[key])],
)
]
if incorrect_merges:
# reassign incorrect merges so that they have no effect
for key in incorrect_merges:
replacement_dict[key] = key
# find only correct merges
new_merges = merges[
~index_isin(
merges[:, 0], np.array(incorrect_merges)
).flatten(),
...,
]
else:
new_merges = merges
# correct lineage information
# replace mother or bud index with index of rightmost track
flat_lineage[valid_lineages] = [
replacement_dict[tuple(index)]
for index in flat_lineage[valid_lineages]
]
else:
new_merges = merges
# reverse flattening
new_lineage = flat_lineage.reshape(-1, 2, 2)
# remove any duplicates
new_lineage = np.unique(new_lineage, axis=0)
return new_lineage, new_merges
def apply_merges(data: pd.DataFrame, merges: np.ndarray): def apply_merges(data: pd.DataFrame, merges: np.ndarray):
"""Split data in two, one subset for rows relevant for merging and one """
without them. It uses an array of source tracklets and target tracklets Generate a new data frame containing merged tracks.
to efficiently merge them.
Parameters Parameters
---------- ----------
data : pd.DataFrame data : pd.DataFrame
Input DataFrame. A Signal data frame.
merges : np.ndarray merges : np.ndarray
3-D ndarray where dimensions are (X,2,2): nmerges, source-target An array of pairs of (trap, cell) indices to merge.
pair and single-cell identifiers, respectively.
Examples
--------
FIXME: Add docs.
""" """
indices = data.index indices = data.index
if "mother_label" in indices.names: if "mother_label" in indices.names:
indices = indices.droplevel("mother_label") indices = indices.droplevel("mother_label")
valid_merges, indices = validate_association( indices = np.array(list(indices))
merges, np.array(list(indices)) # merges in the data frame's indices
) valid_merges = index_isin(merges, indices).all(axis=1).flatten()
# corresponding indices for the data frame in merges
# Assign non-merged selected_merges = merges[valid_merges, ...]
merged = data.loc[~indices] valid_indices = index_isin(indices, selected_merges).flatten()
# data not requiring merging
# Implement the merges and drop source rows. merged = data.loc[~valid_indices]
# TODO Use matrices to perform merges in batch # merge tracks
# for ecficiency
if valid_merges.any(): if valid_merges.any():
to_merge = data.loc[indices] to_merge = data.loc[valid_indices].copy()
targets, sources = zip(*merges[valid_merges]) left_indices = merges[valid_merges, 0]
for source, target in zip(sources, targets): right_indices = merges[valid_merges, 1]
target = tuple(target) # join left track with right track
to_merge.loc[target] = join_tracks_pair( for left_index, right_index in zip(left_indices, right_indices):
to_merge.loc[target].values, to_merge.loc[tuple(left_index)] = join_two_tracks(
to_merge.loc[tuple(source)].values, to_merge.loc[tuple(left_index)].values,
to_merge.loc[tuple(right_index)].values,
) )
to_merge.drop(map(tuple, sources), inplace=True) # drop indices for right tracks
to_merge.drop(map(tuple, right_indices), inplace=True)
# add to data not requiring merges
merged = pd.concat((merged, to_merge), names=data.index.names) merged = pd.concat((merged, to_merge), names=data.index.names)
return merged return merged
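# A minimal sketch of apply_merges with hypothetical values: the tracks of
# cells 1 and 2 in trap 0 are joined and the right-hand track's row dropped.
example_df = pd.DataFrame(
    [[1.0, 2.0, 0.0], [0.0, 0.0, 3.0]],
    index=pd.MultiIndex.from_tuples(
        [(0, 1), (0, 2)], names=["trap", "cell_label"]
    ),
    columns=[0, 5, 10],
)
example_merges = np.array([[[0, 1], [0, 2]]])
apply_merges(example_df, example_merges)
# a single row indexed (0, 1) with values [1.0, 2.0, 3.0]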
def join_tracks_pair(target: np.ndarray, source: np.ndarray) -> np.ndarray: def join_two_tracks(
""" left_track: np.ndarray, right_track: np.ndarray
Join two tracks and return the new value of the target. ) -> np.ndarray:
""" """Join two tracks and return the new one."""
target_copy = target new_track = left_track.copy()
end = find_1st(target_copy[::-1], 0, cmp_larger) # find last positive element by inverting track
target_copy[-end:] = source[-end:] end = find_1st(left_track[::-1], 0, cmp_larger)
return target_copy # merge tracks into one
new_track[-end:] = right_track[-end:]
return new_track
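# A sketch of join_two_tracks with hypothetical values: find_1st locates the
# last non-zero entry of the left track by scanning it reversed, and the right
# track overwrites everything after that point.
example_left = np.array([1.2, 1.3, 0.0, 0.0])
example_right = np.array([0.0, 0.0, 1.4, 1.5])
join_two_tracks(example_left, example_right)
# array([1.2, 1.3, 1.4, 1.5])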
def group_merges(merges: np.ndarray) -> t.List[t.Tuple]:
# Return a list where the cell is present as source and target
# (multimerges)
sources_targets = compare_indices(merges[:, 0, :], merges[:, 1, :])
is_multimerge = sources_targets.any(axis=0) | sources_targets.any(axis=1)
is_monomerge = ~is_multimerge
multimerge_subsets = union_find(zip(*np.where(sources_targets)))
merge_groups = [merges[np.array(tuple(x))] for x in multimerge_subsets]
sorted_merges = list(map(sort_association, merge_groups))
# Ensure that source and target are at the edges ##################################################################
return [
*sorted_merges,
*[[event] for event in merges[is_monomerge]],
]
def union_find(lsts): def union_find(lsts):
...@@ -120,27 +191,3 @@ def sort_association(array: np.ndarray): ...@@ -120,27 +191,3 @@ def sort_association(array: np.ndarray):
[res.append(x) for x in np.flip(order).flatten() if x not in res] [res.append(x) for x in np.flip(order).flatten() if x not in res]
sorted_array = array[np.array(res)] sorted_array = array[np.array(res)]
return sorted_array return sorted_array
def merge_association(
association: np.ndarray, merges: np.ndarray
) -> np.ndarray:
grouped_merges = group_merges(merges)
flat_indices = association.reshape(-1, 2)
comparison_mat = compare_indices(merges[:, 0], flat_indices)
valid_indices = comparison_mat.any(axis=0)
if valid_indices.any(): # Where valid, perform transformation
replacement_d = {}
for dataset in grouped_merges:
for k in dataset:
replacement_d[tuple(k[0])] = dataset[-1][1]
flat_indices[valid_indices] = [
replacement_d[tuple(i)] for i in flat_indices[valid_indices]
]
merged_indices = flat_indices.reshape(-1, 2, 2)
return merged_indices
""" """
Orchestration module and network mid-level interfaces. Orchestration module and network mid-level interfaces.
""" """
from .version import __version__
...@@ -22,18 +22,16 @@ from requests.exceptions import HTTPError, Timeout ...@@ -22,18 +22,16 @@ from requests.exceptions import HTTPError, Timeout
################### Dask Methods ################################ ################### Dask Methods ################################
def format_segmentation(segmentation, tp): def format_segmentation(segmentation, tp):
"""Format a single timepoint into a dictionary. """
Format BABY's results from a single time point into a dictionary.
Parameters Parameters
------------ ------------
segmentation: list segmentation: list
A list of results, each result is the output of the crawler, which is JSON-encoded A list of results, each result is the output of BABY
crawler, which is JSON-encoded.
tp: int tp: int
the time point considered The time point.
Returns
--------
A dictionary containing the formatted results of BABY
""" """
# Segmentation is a list of dictionaries, ordered by trap # Segmentation is a list of dictionaries, ordered by trap
# Add trap information # Add trap information
...@@ -204,6 +202,7 @@ def choose_model_from_params( ...@@ -204,6 +202,7 @@ def choose_model_from_params(
------- -------
model_name : str model_name : str
""" """
# cameras prime95 and evolve have become sCMOS and EMCCD
valid_models = list(modelsets().keys()) valid_models = list(modelsets().keys())
# Apply modelset filter if specified # Apply modelset filter if specified
......
import itertools
import re
import typing as t
from pathlib import Path
import numpy as np
from baby import BabyCrawler, modelsets
from agora.abc import ParametersABC, StepABC
class BabyParameters(ParametersABC):
"""Parameters used for running BABY."""
def __init__(
self,
modelset_name,
clogging_thresh,
min_bud_tps,
isbud_thresh,
):
"""Initialise parameters for BABY."""
# pixel_size is specified in BABY's model sets
self.modelset_name = modelset_name
self.clogging_thresh = clogging_thresh
self.min_bud_tps = min_bud_tps
self.isbud_thresh = isbud_thresh
@classmethod
def default(cls, **kwargs):
"""Define default parameters; kwargs choose BABY model set."""
return cls(
modelset_name=get_modelset_name_from_params(**kwargs),
clogging_thresh=1,
min_bud_tps=3,
isbud_thresh=0.5,
)
def update_baby_modelset(self, path: t.Union[str, Path, t.Dict[str, str]]):
"""
Replace default BABY model and flattener.
Both are saved in a folder by our retraining script.
"""
if isinstance(path, dict):
weights_flattener = {k: Path(v) for k, v in path.items()}
else:
weights_dir = Path(path)
weights_flattener = {
"flattener_file": weights_dir.parent / "flattener.json",
"morph_model_file": weights_dir / "weights.h5",
}
self.update("modelset_name", weights_flattener)
class BabyRunner(StepABC):
"""
A BabyRunner object for cell segmentation.
Segments one time point at a time.
"""
def __init__(self, tiler, parameters=None, **kwargs):
"""Instantiate from a Tiler object."""
self.tiler = tiler
modelset_name = (
get_modelset_name_from_params(**kwargs)
if parameters is None
else parameters.modelset_name
)
tiler_z = self.tiler.shape[-3]
if f"{tiler_z}z" not in modelset_name:
raise KeyError(
f"Tiler z-stack ({tiler_z}) and model"
f" ({modelset_name}) do not match."
)
if parameters is None:
brain = modelsets.get(modelset_name)
else:
brain = modelsets.get(
modelset_name,
clogging_thresh=parameters.clogging_thresh,
min_bud_tps=parameters.min_bud_tps,
isbud_thresh=parameters.isbud_thresh,
)
self.crawler = BabyCrawler(brain)
self.brightfield_channel = self.tiler.ref_channel_index
@classmethod
def from_tiler(cls, parameters: BabyParameters, tiler):
"""Explicitly instantiate from a Tiler object."""
return cls(tiler, parameters)
def get_data(self, tp):
"""Get image and re-arrange axes."""
img_from_tiler = self.tiler.get_tp_data(tp, self.brightfield_channel)
# move z axis to the last axis; Baby expects (n, x, y, z)
img = np.moveaxis(img_from_tiler, 1, destination=-1)
return img
def _run_tp(
self,
tp,
refine_outlines=True,
assign_mothers=True,
with_edgemasks=True,
**kwargs,
):
"""Segment data from one time point."""
img = self.get_data(tp)
segmentation = self.crawler.step(
img,
refine_outlines=refine_outlines,
assign_mothers=assign_mothers,
with_edgemasks=with_edgemasks,
**kwargs,
)
res = format_segmentation(segmentation, tp)
return res
def get_modelset_name_from_params(
imaging_device="alcatras",
channel="brightfield",
camera="sCMOS",
zoom="60x",
n_stacks="5z",
):
"""Get the appropriate model set from BABY's trained models."""
# list of models - microscopy setups - for which BABY has been trained
# cameras prime95 and evolve have become sCMOS and EMCCD
possible_models = list(modelsets.remote_modelsets()["models"].keys())
# filter possible_models
params = [
str(x) if x is not None else ".+"
for x in [imaging_device, channel.lower(), camera, zoom, n_stacks]
]
params_regex = re.compile("-".join(params) + "$")
valid_models = [
res for res in filter(params_regex.search, possible_models)
]
# check that there are valid models
if len(valid_models) == 1:
return valid_models[0]
else:
raise KeyError(
"Error in finding BABY model sets matching {}".format(
", ".join(params)
)
)
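# Illustrative sketch of the model-set selection above: the fields are joined
# with "-" (unspecified fields become ".+") and anchored with "$", so exactly
# one candidate name should match. The candidate names below are made up.
def _example_modelset_regex():
    import re

    candidates = [
        "yeast-alcatras-brightfield-sCMOS-60x-5z",
        "yeast-alcatras-brightfield-EMCCD-60x-5z",
        "yeast-alcatras-gfp-sCMOS-60x-1z",
    ]
    fields = ["alcatras", "brightfield", "sCMOS", "60x", "5z"]
    pattern = re.compile("-".join(fields) + "$")
    matches = [name for name in candidates if pattern.search(name)]
    assert matches == ["yeast-alcatras-brightfield-sCMOS-60x-5z"]
    return matches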
def format_segmentation(segmentation, tp):
"""
Format BABY's results for a single time point into a dict.
The dict has BABY's outputs as keys and lists of the results
for each segmented cell as values.
Parameters
------------
segmentation: list
A list of BABY's results as dicts for each tile.
tp: int
The time point.
"""
# segmentation is a list of dictionaries for each tile
for i, tile_dict in enumerate(segmentation):
# assign the trap ID to each cell identified
tile_dict["trap"] = [i] * len(tile_dict["cell_label"])
# record mothers for each labelled cell
tile_dict["mother_assign_dynamic"] = np.array(
tile_dict["mother_assign"]
)[np.array(tile_dict["cell_label"], dtype=int) - 1]
# merge into a dict with BABY's outputs as keys and
# lists of results for all cells as values
merged = {
output: list(
itertools.chain.from_iterable(
tile_dict[output] for tile_dict in segmentation
)
)
for output in segmentation[0].keys()
}
# remove mother_assign
merged.pop("mother_assign", None)
# ensure that each value is a list of the same length
no_cells = min([len(v) for v in merged.values()])
merged = {k: v[:no_cells] for k, v in merged.items()}
# define time point key
merged["timepoint"] = [tp] * no_cells
return merged
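# Illustrative sketch of format_segmentation on toy per-tile results
# (assuming this module is aliby.baby_sitter, as imported by aliby.pipeline).
# Real BABY output carries more keys, for example edge masks.
def _example_format_segmentation():
    toy_segmentation = [
        # tile 0: two cells; cell 1 is the mother of cell 2
        {"cell_label": [1, 2], "mother_assign": [0, 1]},
        # tile 1: one cell with no mother
        {"cell_label": [1], "mother_assign": [0]},
    ]
    merged = format_segmentation(toy_segmentation, tp=0)
    assert merged["trap"] == [0, 0, 1]
    assert merged["cell_label"] == [1, 2, 1]
    assert merged["timepoint"] == [0, 0, 0]
    assert "mother_assign" not in merged
    return merged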
"""Set up and run pipelines: tiling, segmentation, extraction, and then post-processing."""
import logging
import os
import re
import traceback
import typing as t
from copy import copy
from importlib.metadata import version
from pathlib import Path
from pprint import pprint
import baby
import baby.errors
import h5py
import numpy as np
import tensorflow as tf
from pathos.multiprocessing import Pool
from tqdm import tqdm
try:
if baby.__version__ == "v0.30.1":
from aliby.baby_sitter import BabyParameters, BabyRunner
except AttributeError:
from aliby.baby_client import BabyParameters, BabyRunner
import aliby.global_parameters as global_parameters
from agora.abc import ParametersABC, ProcessABC
from agora.io.metadata import MetaData
from agora.io.reader import StateReader
from agora.io.signal import Signal
from agora.io.writer import LinearBabyWriter, StateWriter, TilerWriter
from aliby.io.dataset import dispatch_dataset
from aliby.io.image import dispatch_image
from aliby.tile.tiler import Tiler, TilerParameters
from extraction.core.extractor import (
Extractor,
ExtractorParameters,
extraction_params_from_meta,
)
from postprocessor.core.postprocessing import (
PostProcessor,
PostProcessorParameters,
)
# stop warnings from TensorFlow
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "2"
logging.getLogger("tensorflow").setLevel(logging.ERROR)
class PipelineParameters(ParametersABC):
"""Define parameters for the steps of the pipeline."""
_pool_index = None
def __init__(
self,
general,
tiler,
baby,
extraction,
postprocessing,
):
"""Initialise, but called by a class method - not directly."""
self.general = general
self.tiler = tiler
self.baby = baby
self.extraction = extraction
self.postprocessing = postprocessing
@classmethod
def default(
cls,
general={},
tiler={},
baby={},
extraction={},
postprocessing={},
):
"""
Initialise parameters for steps of the pipeline.
Some parameters are extracted from the log files.
Parameters
---------
general: dict
Parameters to set up the pipeline.
tiler: dict
Parameters for tiler.
baby: dict (optional)
Parameters for Baby.
extraction: dict (optional)
Parameters for extraction.
postprocessing: dict (optional)
Parameters for post-processing.
"""
if (
isinstance(general["expt_id"], Path)
and general["expt_id"].exists()
):
expt_id = str(general["expt_id"])
else:
expt_id = general["expt_id"]
directory = Path(general["directory"])
# get metadata from log files either locally or via OMERO
with dispatch_dataset(
expt_id,
**{k: general.get(k) for k in ("host", "username", "password")},
) as conn:
directory = directory / conn.unique_name
if not directory.exists():
directory.mkdir(parents=True)
# download logs for metadata
conn.cache_logs(directory)
try:
meta_d = MetaData(directory, None).load_logs()
except Exception as e:
logging.getLogger("aliby").warning(
f"WARNING:Metadata: error when loading: {e}"
)
minimal_default_meta = {
"channels": ["Brightfield"],
"ntps": [2000],
}
# set minimal metadata
meta_d = minimal_default_meta
# define default values for general parameters
tps = meta_d.get("ntps", 2000)
defaults = {
"general": dict(
id=expt_id,
distributed=0,
tps=tps,
directory=str(directory.parent),
filter="",
earlystop=global_parameters.earlystop,
logfile_level="INFO",
use_explog=True,
)
}
# update default values for general using inputs
for k, v in general.items():
if k not in defaults["general"]:
defaults["general"][k] = v
elif isinstance(v, dict):
for k2, v2 in v.items():
defaults["general"][k][k2] = v2
else:
defaults["general"][k] = v
# default Tiler parameters
defaults["tiler"] = TilerParameters.default(**tiler).to_dict()
# generate a backup channel for when logfile meta is available
# but not image metadata.
backup_ref_channel = None
if "channels" in meta_d and isinstance(
defaults["tiler"]["ref_channel"], str
):
backup_ref_channel = meta_d["channels"].index(
defaults["tiler"]["ref_channel"]
)
defaults["tiler"]["backup_ref_channel"] = backup_ref_channel
# default BABY parameters
defaults["baby"] = BabyParameters.default(**baby).to_dict()
# default Extraction parameters
defaults["extraction"] = extraction_params_from_meta(meta_d)
# default PostProcessing parameters
defaults["postprocessing"] = PostProcessorParameters.default(
**postprocessing
).to_dict()
return cls(**{k: v for k, v in defaults.items()})
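# Illustrative sketch of how default() merges user-supplied "general"
# settings into the built-in defaults: unknown keys are added, dict values
# are merged one level deep, and everything else overrides the default.
# The values below are hypothetical.
def _example_merge_general_settings():
    defaults = {"distributed": 0, "tps": 2000, "earlystop": {"min_tp": 100}}
    user = {"tps": 300, "earlystop": {"min_tp": 50}, "host": "omero.example.org"}
    for k, v in user.items():
        if k not in defaults:
            defaults[k] = v
        elif isinstance(v, dict):
            for k2, v2 in v.items():
                defaults[k][k2] = v2
        else:
            defaults[k] = v
    assert defaults == {
        "distributed": 0,
        "tps": 300,
        "earlystop": {"min_tp": 50},
        "host": "omero.example.org",
    }
    return defaults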
class Pipeline(ProcessABC):
"""
Initialise and run tiling, segmentation, extraction and post-processing.
Each step feeds the next one.
To customise parameters for any step, use the PipelineParameters class.
"""
pipeline_steps = ["tiler", "baby", "extraction"]
step_sequence = [
"tiler",
"baby",
"extraction",
"postprocessing",
]
# specify the group in the h5 files written by each step
writer_groups = {
"tiler": ["trap_info"],
"baby": ["cell_info"],
"extraction": ["extraction"],
"postprocessing": ["postprocessing", "modifiers"],
}
writers = { # TODO integrate Extractor and PostProcessing in here
"tiler": [("tiler", TilerWriter)],
"baby": [("baby", LinearBabyWriter), ("state", StateWriter)],
}
def __init__(self, parameters: PipelineParameters, store=None):
"""Initialise - not usually called directly."""
super().__init__(parameters)
if store is not None:
store = Path(store)
self.store = store
@staticmethod
def setLogger(
folder, file_level: str = "INFO", stream_level: str = "WARNING"
):
"""Initialise and format logger."""
logger = logging.getLogger("aliby")
logger.setLevel(getattr(logging, file_level))
formatter = logging.Formatter(
"%(asctime)s - %(levelname)s:%(message)s",
datefmt="%Y-%m-%dT%H:%M:%S%z",
)
# for streams - stdout, files, etc.
ch = logging.StreamHandler()
ch.setLevel(getattr(logging, stream_level))
ch.setFormatter(formatter)
logger.addHandler(ch)
# create file handler that logs even debug messages
fh = logging.FileHandler(Path(folder) / "aliby.log", "w+")
fh.setLevel(getattr(logging, file_level))
fh.setFormatter(formatter)
logger.addHandler(fh)
@classmethod
def from_folder(cls, dir_path):
"""
Re-process all h5 files in a given folder.
All files must share the same parameters, even if they have different channels.
Parameters
---------
dir_path : str or Pathlib
Folder containing the files.
"""
# find h5 files
dir_path = Path(dir_path)
files = list(dir_path.rglob("*.h5"))
assert len(files), "No valid files found in folder"
fpath = files[0]
# TODO add support for non-standard unique folder names
with h5py.File(fpath, "r") as f:
pipeline_parameters = PipelineParameters.from_yaml(
f.attrs["parameters"]
)
pipeline_parameters.general["directory"] = dir_path.parent
pipeline_parameters.general["filter"] = [fpath.stem for fpath in files]
# fix legacy post-processing parameters
post_process_params = pipeline_parameters.postprocessing.get(
"parameters", None
)
if post_process_params:
pipeline_parameters.postprocessing["param_sets"] = copy(
post_process_params
)
del pipeline_parameters.postprocessing["parameters"]
return cls(pipeline_parameters)
@classmethod
def from_existing_h5(cls, fpath):
"""
Re-process an existing h5 file.
Not suitable for more than one file.
Parameters
---------
fpath: str
Name of file.
"""
with h5py.File(fpath, "r") as f:
pipeline_parameters = PipelineParameters.from_yaml(
f.attrs["parameters"]
)
directory = Path(fpath).parent
pipeline_parameters.general["directory"] = directory
pipeline_parameters.general["filter"] = Path(fpath).stem
post_process_params = pipeline_parameters.postprocessing.get(
"parameters", None
)
if post_process_params:
pipeline_parameters.postprocessing["param_sets"] = copy(
post_process_params
)
del pipeline_parameters.postprocessing["parameters"]
return cls(pipeline_parameters, store=directory)
@property
def logger(self):
return logging.getLogger("aliby")
def run(self):
"""Run separate pipelines for all positions in an experiment."""
# display configuration
config = self.parameters.to_dict()
print("\nalibylite\n")
try:
logging.getLogger("aliby").info(f"Using Baby {baby.__version__}.")
except AttributeError:
logging.getLogger("aliby").info("Using original Baby.")
for step in config:
print("\n---\n" + step + "\n---")
pprint(config[step])
print()
# extract from configuration
expt_id = config["general"]["id"]
distributed = config["general"]["distributed"]
position_filter = config["general"]["filter"]
root_dir = Path(config["general"]["directory"])
self.server_info = {
k: config["general"].get(k)
for k in ("host", "username", "password")
}
dispatcher = dispatch_dataset(expt_id, **self.server_info)
logging.getLogger("aliby").info(
f"Fetching data using {dispatcher.__class__.__name__}."
)
# get log files, either locally or via OMERO
with dispatcher as conn:
position_ids = conn.get_images()
directory = self.store or root_dir / conn.unique_name
if not directory.exists():
directory.mkdir(parents=True)
# get logs to use for metadata
conn.cache_logs(directory)
print("Positions available:")
for i, pos in enumerate(position_ids.keys()):
print("\t" + f"{i}: " + pos.split(".")[0])
# update configuration
self.parameters.general["directory"] = str(directory)
config["general"]["directory"] = directory
self.setLogger(directory)
# pick particular positions if desired
if position_filter is not None:
if isinstance(position_filter, list):
position_ids = {
k: v
for filt in position_filter
for k, v in self.apply_filter(position_ids, filt).items()
}
else:
position_ids = self.apply_filter(position_ids, position_filter)
if not len(position_ids):
raise Exception("No images to segment.")
else:
print("\nPositions selected:")
for pos in position_ids:
print("\t" + pos.split(".")[0])
# create and run pipelines
if distributed != 0:
# multiple cores
with Pool(distributed) as p:
results = p.map(
lambda x: self.run_one_position(*x),
[
(position_id, i)
for i, position_id in enumerate(position_ids.items())
],
)
else:
# single core
results = [
self.run_one_position((position_id, position_id_path), 1)
for position_id, position_id_path in tqdm(position_ids.items())
]
return results
def apply_filter(self, position_ids: dict, position_filter: t.Union[int, str]):
"""
Select positions.
Either pick a particular position or use a regular expression
to parse their file names.
"""
if isinstance(position_filter, str):
# pick positions using a regular expression
position_ids = {
k: v
for k, v in position_ids.items()
if re.search(position_filter, k)
}
elif isinstance(position_filter, int):
# pick a particular position
position_ids = {
k: v
for i, (k, v) in enumerate(position_ids.items())
if i == position_filter
}
return position_ids
def run_one_position(
self,
name_image_id: t.Tuple[str, t.Union[str, Path, int]],
index: t.Optional[int] = None,
):
"""Set up and run a pipeline for one position."""
self._pool_index = index
name, image_id = name_image_id
# session is defined by calling setup_pipeline;
# can it be deleted here?
session = None
run_kwargs = {"extraction": {"cell_labels": None, "masks": None}}
try:
pipe, session = self.setup_pipeline(image_id, name)
loaded_writers = {
name: writer(pipe["filename"])
for k in self.step_sequence
if k in self.writers
for name, writer in self.writers[k]
}
writer_overwrite_kwargs = {
"state": loaded_writers["state"].datatypes.keys(),
"baby": ["mother_assign"],
}
# START PIPELINE
frac_clogged_traps = 0.0
min_process_from = min(pipe["process_from"].values())
with dispatch_image(image_id)(
image_id, **self.server_info
) as image:
# initialise steps
if "tiler" not in pipe["steps"]:
pipe["config"]["tiler"]["position_name"] = name.split(".")[
0
]
# loads local meta data from image
pipe["steps"]["tiler"] = Tiler.from_image(
image,
TilerParameters.from_dict(pipe["config"]["tiler"]),
)
if pipe["process_from"]["baby"] < pipe["tps"]:
session = initialise_tf(2)
pipe["steps"]["baby"] = BabyRunner.from_tiler(
BabyParameters.from_dict(pipe["config"]["baby"]),
pipe["steps"]["tiler"],
)
if pipe["trackers_state"]:
pipe["steps"]["baby"].crawler.tracker_states = pipe[
"trackers_state"
]
if pipe["process_from"]["extraction"] < pipe["tps"]:
exparams = ExtractorParameters.from_dict(
pipe["config"]["extraction"]
)
pipe["steps"]["extraction"] = Extractor.from_tiler(
exparams,
store=pipe["filename"],
tiler=pipe["steps"]["tiler"],
)
# initiate progress bar
progress_bar = tqdm(
range(min_process_from, pipe["tps"]),
desc=image.name,
initial=min_process_from,
total=pipe["tps"],
)
# run through time points
for i in progress_bar:
if (
frac_clogged_traps
< pipe["earlystop"]["thresh_pos_clogged"]
or i < pipe["earlystop"]["min_tp"]
):
# run through steps
for step in self.pipeline_steps:
if i >= pipe["process_from"][step]:
# perform step
try:
result = pipe["steps"][step].run_tp(
i, **run_kwargs.get(step, {})
)
except baby.errors.Clogging:
logging.getLogger("aliby").warning(
"WARNING:Clogging threshold exceeded in BABY."
)
# write result to h5 file using writers
# extractor writes to h5 itself
if step in loaded_writers:
loaded_writers[step].write(
data=result,
overwrite=writer_overwrite_kwargs.get(
step, []
),
tp=i,
meta={"last_processed": i},
)
# clean up
if (
step == "tiler"
and i == min_process_from
):
logging.getLogger("aliby").info(
f"Found {pipe['steps']['tiler'].no_tiles} traps in {image.name}"
)
elif step == "baby":
# write state
loaded_writers["state"].write(
data=pipe["steps"][
step
].crawler.tracker_states,
overwrite=loaded_writers[
"state"
].datatypes.keys(),
tp=i,
)
elif step == "extraction":
# remove masks and labels after extraction
for k in ["masks", "cell_labels"]:
run_kwargs[step][k] = None
# check and report clogging
frac_clogged_traps = check_earlystop(
pipe["filename"],
pipe["earlystop"],
pipe["steps"]["tiler"].tile_size,
)
if frac_clogged_traps > 0.3:
self.log(
f"{name}:Clogged_traps:{frac_clogged_traps}"
)
frac = np.round(frac_clogged_traps * 100)
progress_bar.set_postfix_str(f"{frac} Clogged")
else:
# stop if too many traps are clogged
self.log(
f"{name}:Stopped early at time {i} with {frac_clogged_traps} clogged traps"
)
pipe["meta"].add_fields({"end_status": "Clogged"})
break
pipe["meta"].add_fields({"last_processed": i})
pipe["meta"].add_fields({"end_status": "Success"})
# run post-processing
post_proc_params = PostProcessorParameters.from_dict(
pipe["config"]["postprocessing"]
)
PostProcessor(pipe["filename"], post_proc_params).run()
self.log("Analysis finished successfully.", "info")
return 1
except Exception as e:
# catch bugs during setup or run time
logging.exception(
f"{name}: Exception caught.",
exc_info=True,
)
# print the type, value, and stack trace of the exception
traceback.print_exc()
raise e
finally:
close_session(session)
def setup_pipeline(
self,
image_id: int,
name: str,
) -> t.Tuple[t.Dict, None]:
"""
Initialise steps in a pipeline.
If necessary use a file to re-start experiments already partly run.
Parameters
----------
image_id : int or str
Identifier of a data set in an OMERO server or a filename.
Returns
-------
pipe: dict
With keys
filename: str
Path to a h5 file to write to.
meta: object
agora.io.metadata.MetaData object
config: dict
Configuration parameters.
process_from: dict
Gives time points from which each step of the
pipeline should start.
tps: int
Number of time points.
steps: dict
earlystop: dict
Parameters to check whether the pipeline should
be stopped.
trackers_state: list
States of any trackers from earlier runs.
session: None
"""
pipe = {}
config = self.parameters.to_dict()
# TODO Verify if session must be passed
session = None
pipe["earlystop"] = config["general"].get("earlystop", None)
pipe["process_from"] = {k: 0 for k in self.pipeline_steps}
pipe["steps"] = {}
# check overwriting
overwrite_id = config["general"].get("overwrite", 0)
overwrite = {step: True for step in self.step_sequence}
if overwrite_id and overwrite_id is not True:
overwrite = {
step: self.step_sequence.index(overwrite_id) < i
for i, step in enumerate(self.step_sequence, 1)
}
# set up
directory = config["general"]["directory"]
pipe["trackers_state"] = []
with dispatch_image(image_id)(image_id, **self.server_info) as image:
pipe["filename"] = Path(f"{directory}/{image.name}.h5")
# set up metadata from the log files to be written to the h5 file
pipe["meta"] = MetaData(directory, pipe["filename"])
from_start = bool(np.any(list(overwrite.values())))
# remove existing h5 file if overwriting
if (
from_start
and (
config["general"].get("overwrite", False)
or np.all(list(overwrite.values()))
)
and pipe["filename"].exists()
):
os.remove(pipe["filename"])
# if the file exists with no previous segmentation use its tiler
if pipe["filename"].exists():
self.log("Result file exists.", "info")
if not overwrite["tiler"]:
tiler_params_dict = TilerParameters.default().to_dict()
tiler_params_dict["position_name"] = name.split(".")[0]
tiler_params = TilerParameters.from_dict(tiler_params_dict)
pipe["steps"]["tiler"] = Tiler.from_h5(
image, pipe["filename"], tiler_params
)
try:
(
process_from,
trackers_state,
overwrite,
) = self._load_config_from_file(
pipe["filename"],
pipe["process_from"],
pipe["trackers_state"],
overwrite,
)
# get state array
pipe["trackers_state"] = (
[]
if overwrite["baby"]
else StateReader(
pipe["filename"]
).get_formatted_states()
)
config["tiler"] = pipe["steps"][
"tiler"
].parameters.to_dict()
except Exception:
self.log("Overwriting tiling data")
if config["general"]["use_explog"]:
pipe["meta"].run()
pipe["config"] = config
# add metadata not in the log file
pipe["meta"].add_fields(
{
"aliby_version": version("aliby"),
"baby_version": version("aliby-baby"),
"omero_id": config["general"]["id"],
"image_id": (
image_id
if isinstance(image_id, int)
else str(image_id)
),
"parameters": PipelineParameters.from_dict(
config
).to_yaml(),
}
)
pipe["tps"] = min(config["general"]["tps"], image.data.shape[0])
return pipe, session
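# Illustrative usage of the Pipeline class (hypothetical experiment id and
# output directory; a local path or an OMERO id can be used as expt_id):
#     params = PipelineParameters.default(
#         general={"expt_id": "experiment_dir", "directory": "results"}
#     )
#     Pipeline(params).run()
#
# Illustrative sketch of the two position-selection modes accepted by
# apply_filter: an int picks a position by index; a str is treated as a
# regular expression over position names. The names below are made up.
def _example_position_selection():
    import re

    position_ids = {
        "pos001.ome.tiff": "id1",
        "pos002.ome.tiff": "id2",
        "ura7_pos003.ome.tiff": "id3",
    }
    # str filter: keep positions whose name matches the regular expression
    by_regex = {
        k: v for k, v in position_ids.items() if re.search("ura7", k)
    }
    # int filter: keep only the position at index 1
    by_index = {
        k: v for i, (k, v) in enumerate(position_ids.items()) if i == 1
    }
    assert list(by_regex) == ["ura7_pos003.ome.tiff"]
    assert list(by_index) == ["pos002.ome.tiff"]
    return by_regex, by_index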
def check_earlystop(filename: str, es_parameters: dict, tile_size: int):
"""
Check whether tiles have become clogged in recent time points.
Return the fraction of clogged tiles, where a clogged tile has both
too many cells and too much of its area covered by cells.
Parameters
----------
filename: str
Name of h5 file.
es_parameters: dict
Parameters defining when early stopping should happen.
For example:
{'min_tp': 100,
'thresh_pos_clogged': 0.4,
'thresh_trap_ncells': 8,
'thresh_trap_area': 0.9,
'ntps_to_eval': 5}
tile_size: int
Size of tile.
"""
# get the area of the cells organised by trap and cell number
s = Signal(filename)
df = s.get_raw("/extraction/general/None/area")
# check the latest time points only
cells_used = df[
df.columns[-1 - es_parameters["ntps_to_eval"] : -1]
].dropna(how="all")
# find tiles with too many cells
traps_above_nthresh = (
cells_used.groupby("trap").count().apply(np.mean, axis=1)
> es_parameters["thresh_trap_ncells"]
)
# find tiles with cells covering too great a fraction of the tiles' area
traps_above_athresh = (
cells_used.groupby("trap").sum().apply(np.mean, axis=1) / tile_size**2
> es_parameters["thresh_trap_area"]
)
return (traps_above_nthresh & traps_above_athresh).mean()
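# Illustrative sketch of the clogging test above on a toy DataFrame of cell
# areas (rows indexed by trap and cell label, columns are time points); the
# numbers are invented. A trap counts as clogged only when it exceeds both
# the cell-number and the area thresholds.
def _example_clogging_fraction():
    import numpy as np
    import pandas as pd

    index = pd.MultiIndex.from_tuples(
        [(0, 1), (0, 2), (1, 1), (1, 2), (1, 3)],
        names=["trap", "cell_label"],
    )
    areas = pd.DataFrame(
        [
            [50, 55, 60],
            [40, 45, 50],
            [500, 520, 540],
            [480, 500, 520],
            [450, 470, 490],
        ],
        index=index,
    )
    tile_size = 40
    thresh_trap_ncells, thresh_trap_area = 2, 0.8
    above_n = (
        areas.groupby("trap").count().apply(np.mean, axis=1)
        > thresh_trap_ncells
    )
    above_a = (
        areas.groupby("trap").sum().apply(np.mean, axis=1) / tile_size**2
        > thresh_trap_area
    )
    # trap 1 holds three cells covering over 90% of the tile area,
    # so half of the traps are clogged
    assert (above_n & above_a).mean() == 0.5
    return (above_n & above_a).mean()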
def close_session(session):
"""Close session for multiprocessing."""
if session:
session.close()
def initialise_tf(version):
"""Initialise tensorflow."""
if version == 1:
core_config = tf.ConfigProto()
core_config.gpu_options.allow_growth = True
session = tf.Session(config=core_config)
return session
if version == 2:
gpus = tf.config.experimental.list_physical_devices("GPU")
if gpus:
for gpu in gpus:
tf.config.experimental.set_memory_growth(gpu, True)
logical_gpus = tf.config.experimental.list_logical_devices("GPU")
print(
len(gpus), "Physical GPUs,", len(logical_gpus), "Logical GPUs"
)
return None
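# Illustrative sketch: the pipeline calls initialise_tf(2) before building a
# BabyRunner; under TensorFlow 2 the function only enables GPU memory growth
# and returns None, so close_session is then a no-op.
def _example_initialise_tf():
    session = initialise_tf(2)  # returns None for TensorFlow 2
    close_session(session)  # does nothing when session is None
    return session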
# parameters for stopping the pipeline early when thresholds are exceeded
earlystop = dict(
min_tp=100,
thresh_pos_clogged=0.4,
thresh_trap_ncells=8,
thresh_trap_area=0.9,
ntps_to_eval=5,
)
# imaging properties of the microscope
imaging_specifications = {
"pixel_size": 0.236,
"z_size": 0.6,
"spacing": 0.6,
}
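# Illustrative sketch, assuming pixel_size is in microns per pixel: convert
# a cell area measured in pixels into square microns. The area value is
# hypothetical.
def _example_pixel_to_micron():
    pixel_size = imaging_specifications["pixel_size"]  # microns per pixel
    area_in_pixels = 300
    area_in_um2 = area_in_pixels * pixel_size**2
    assert round(area_in_um2, 3) == round(300 * 0.236**2, 3)
    return area_in_um2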
# possible imaging channels
possible_imaging_channels = [
"Citrine",
"GFP",
"GFPFast",
"mCherry",
"Flavin",
"Citrine",
"mKO2",
"Cy5",
"pHluorin405",
"pHluorin488",
]
# functions to apply to the fluorescence of each cell
fluorescence_functions = [
"mean",
"median",
"std",
"imBackground",
"max5px_median",
]