Compare revisions

Changes are shown as if the source revision was being merged into the target revision.

Projects: swain-lab/aliby/aliby-mirror, swain-lab/aliby/alibylite
Commits on Source (60)

Showing 2898 additions and 2558 deletions. One source diff could not be displayed because it is too large.
...@@ -33,7 +33,7 @@ pathos = "^0.2.8" # Lambda-friendly multithreading ...@@ -33,7 +33,7 @@ pathos = "^0.2.8" # Lambda-friendly multithreading
p-tqdm = "^1.3.3" p-tqdm = "^1.3.3"
pandas = ">=1.3.3" pandas = ">=1.3.3"
py-find-1st = "^1.1.5" # Fast indexing py-find-1st = "^1.1.5" # Fast indexing
scikit-learn = ">=1.0.2" # Used for an extraction metric scikit-learn = ">=1.0.2, <1.3" # Used for an extraction metric
scipy = ">=1.7.3" scipy = ">=1.7.3"
# Pipeline + I/O # Pipeline + I/O
...@@ -46,14 +46,11 @@ xmltodict = "^0.13.0" # read ome-tiff metadata ...@@ -46,14 +46,11 @@ xmltodict = "^0.13.0" # read ome-tiff metadata
zarr = "^2.14.0" zarr = "^2.14.0"
GitPython = "^3.1.27" GitPython = "^3.1.27"
h5py = "2.10" # File I/O h5py = "2.10" # File I/O
aliby-baby = "^0.1.17"
# Networking # Networking
omero-py = { version = ">=5.6.2", optional = true } # contact omero server omero-py = { version = ">=5.6.2", optional = true } # contact omero server
# Baby segmentation
aliby-baby = {version = "^0.1.17", optional=true}
# Postprocessing # Postprocessing
[tool.poetry.group.pp.dependencies] [tool.poetry.group.pp.dependencies]
leidenalg = "^0.8.8" leidenalg = "^0.8.8"
...@@ -113,7 +110,6 @@ grid-strategy = {version = "^0.0.1", optional=true} ...@@ -113,7 +110,6 @@ grid-strategy = {version = "^0.0.1", optional=true}
[tool.poetry.extras] [tool.poetry.extras]
omero = ["omero-py"] omero = ["omero-py"]
baby = ["aliby-baby"]
[tool.black] [tool.black]
line-length = 79 line-length = 79
......
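The pyproject.toml hunks above move aliby-baby behind an optional flag. As a hedged illustration only (the import name "baby" and the guard are assumptions, not taken from this diff), code that relies on the now-optional package would typically protect its import:

# Sketch of a guarded import for an optional dependency; names are illustrative.
try:
    import baby  # provided by the optional aliby-baby dependency (assumed import name)
    HAS_BABY = True
except ImportError:
    HAS_BABY = False

if not HAS_BABY:
    print("BABY segmentation unavailable; install the optional dependency to enable it.")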
...@@ -17,16 +17,14 @@ atomic = t.Union[int, float, str, bool] ...@@ -17,16 +17,14 @@ atomic = t.Union[int, float, str, bool]
class ParametersABC(ABC): class ParametersABC(ABC):
""" """
Defines parameters as attributes and allows parameters to Define parameters as attributes and allow parameters to
be converted to either a dictionary or to yaml. be converted to either a dictionary or to yaml.
No attribute should be called "parameters"! No attribute should be called "parameters"!
""" """
def __init__(self, **kwargs): def __init__(self, **kwargs):
""" """Define parameters as attributes."""
Defines parameters as attributes
"""
assert ( assert (
"parameters" not in kwargs "parameters" not in kwargs
), "No attribute should be named parameters" ), "No attribute should be named parameters"
...@@ -35,8 +33,9 @@ class ParametersABC(ABC): ...@@ -35,8 +33,9 @@ class ParametersABC(ABC):
def to_dict(self, iterable="null") -> t.Dict: def to_dict(self, iterable="null") -> t.Dict:
""" """
Recursive function to return a nested dictionary of the Return a nested dictionary of the attributes of the class instance.
attributes of the class instance.
Use recursion.
""" """
if isinstance(iterable, dict): if isinstance(iterable, dict):
if any( if any(
...@@ -62,7 +61,8 @@ class ParametersABC(ABC): ...@@ -62,7 +61,8 @@ class ParametersABC(ABC):
def to_yaml(self, path: Union[Path, str] = None): def to_yaml(self, path: Union[Path, str] = None):
""" """
Returns a yaml stream of the attributes of the class instance. Return a yaml stream of the attributes of the class instance.
If path is provided, the yaml stream is saved there. If path is provided, the yaml stream is saved there.
Parameters Parameters
...@@ -81,9 +81,7 @@ class ParametersABC(ABC): ...@@ -81,9 +81,7 @@ class ParametersABC(ABC):
@classmethod @classmethod
def from_yaml(cls, source: Union[Path, str]): def from_yaml(cls, source: Union[Path, str]):
""" """Return instance from a yaml filename or stdin."""
Returns instance from a yaml filename or stdin
"""
is_buffer = True is_buffer = True
try: try:
if Path(source).exists(): if Path(source).exists():
...@@ -107,7 +105,8 @@ class ParametersABC(ABC): ...@@ -107,7 +105,8 @@ class ParametersABC(ABC):
def update(self, name: str, new_value): def update(self, name: str, new_value):
""" """
Update values recursively Update values recursively.
if name is a dictionary, replace data where existing found or add if not. if name is a dictionary, replace data where existing found or add if not.
It warns against type changes. It warns against type changes.
...@@ -116,7 +115,6 @@ class ParametersABC(ABC): ...@@ -116,7 +115,6 @@ class ParametersABC(ABC):
If a leaf node that is to be changed is a collection, it adds the new elements. If a leaf node that is to be changed is a collection, it adds the new elements.
""" """
assert name not in ( assert name not in (
"parameters", "parameters",
"params", "params",
...@@ -179,7 +177,8 @@ def add_to_collection( ...@@ -179,7 +177,8 @@ def add_to_collection(
class ProcessABC(ABC): class ProcessABC(ABC):
""" """
Base class for processes. Base class for processes.
Defines parameters as attributes and requires run method to be defined.
Define parameters as attributes and require a run method.
""" """
def __init__(self, parameters): def __init__(self, parameters):
...@@ -190,8 +189,8 @@ class ProcessABC(ABC): ...@@ -190,8 +189,8 @@ class ProcessABC(ABC):
""" """
self._parameters = parameters self._parameters = parameters
# convert parameters to dictionary # convert parameters to dictionary
# and then define each parameter as an attribute
for k, v in parameters.to_dict().items(): for k, v in parameters.to_dict().items():
# define each parameter as an attribute
setattr(self, k, v) setattr(self, k, v)
@property @property
...@@ -243,11 +242,9 @@ class StepABC(ProcessABC): ...@@ -243,11 +242,9 @@ class StepABC(ProcessABC):
@timer @timer
def run_tp(self, tp: int, **kwargs): def run_tp(self, tp: int, **kwargs):
""" """Time and log the timing of a step."""
Time and log the timing of a step.
"""
return self._run_tp(tp, **kwargs) return self._run_tp(tp, **kwargs)
def run(self): def run(self):
# Replace run with run_tp # Replace run with run_tp
raise Warning("Steps use run_tp instead of run") raise Warning("Steps use run_tp instead of run.")
...@@ -14,183 +14,170 @@ from utils_find_1st import cmp_equal, find_1st ...@@ -14,183 +14,170 @@ from utils_find_1st import cmp_equal, find_1st
class Cells: class Cells:
""" """
Extracts information from an h5 file. This class accesses: Extract information from an h5 file.
Use output from BABY to find the cells detected, to get and fill edge masks,
and to retrieve mother-bud relationships.
This class accesses in the h5 file:
'cell_info', which contains 'angles', 'cell_label', 'centres', 'cell_info', which contains 'angles', 'cell_label', 'centres',
'edgemasks', 'ellipse_dims', 'mother_assign', 'mother_assign_dynamic', 'edgemasks', 'ellipse_dims', 'mother_assign', 'mother_assign_dynamic',
'radii', 'timepoint', 'trap'. 'radii', 'timepoint', and 'trap', all of which
All of these except for 'edgemasks' are a 1D ndarray. except for 'edgemasks' are 1D ndarrays.
'trap_info', which contains 'drifts', 'trap_locations' 'trap_info', which contains 'drifts', and 'trap_locations'.
The "timepoint", "cell_label", and "trap" variables are mutually consistent
1D lists.
For example, self["timepoint"][self.get_idx(1, 3)] gives the time points
at which cell 1 was present in trap 3.
""" """
def __init__(self, filename, path="cell_info"): def __init__(self, filename, path="cell_info"):
"""Initialise from a filename."""
self.filename: t.Optional[t.Union[str, Path]] = filename self.filename: t.Optional[t.Union[str, Path]] = filename
self.cinfo_path: t.Optional[str] = path self.cinfo_path: t.Optional[str] = path
self._edgemasks: t.Optional[str] = None self._edgemasks: t.Optional[str] = None
self._tile_size: t.Optional[int] = None self._tile_size: t.Optional[int] = None
def __getitem__(self, item):
"""
Dynamically fetch data from the h5 file and save as an attribute.
These attributes are accessed like dict keys.
"""
assert item != "edgemasks", "Edgemasks must not be loaded as a whole"
_item = "_" + item
if not hasattr(self, _item):
setattr(self, _item, self.fetch(item))
return getattr(self, _item)
def fetch(self, path):
"""Get data from the h5 file."""
with h5py.File(self.filename, mode="r") as f:
return f[self.cinfo_path][path][()]
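As a usage sketch of the lazy __getitem__/fetch pair above (the file path is illustrative, and the h5 layout is assumed to be the one the class docstring describes):

cells = Cells("experiment.h5")                 # h5 file written by the pipeline
labels = cells["cell_label"]                   # fetched from cell_info and cached as _cell_label
tps = cells["timepoint"][cells.get_idx(1, 3)]  # time points at which cell 1 is present in trap 3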
@classmethod @classmethod
def from_source(cls, source: t.Union[Path, str]): def from_source(cls, source: t.Union[Path, str]):
"""Ensure initiating file is a Path object."""
return cls(Path(source)) return cls(Path(source))
def _log(self, message: str, level: str = "warn"): def _log(self, message: str, level: str = "warn"):
# Log messages in the corresponding level """Log messages in the corresponding level."""
logger = logging.getLogger("aliby") logger = logging.getLogger("aliby")
getattr(logger, level)(f"{self.__class__.__name__}: {message}") getattr(logger, level)(f"{self.__class__.__name__}: {message}")
@staticmethod @staticmethod
def _asdense(array: np.ndarray): def asdense(array: np.ndarray):
"""Convert sparse array to dense array."""
if not isdense(array): if not isdense(array):
array = array.todense() array = array.todense()
return array return array
@staticmethod @staticmethod
def _astype(array: np.ndarray, kind: str): def astype(array: np.ndarray, kind: str):
# Convert sparse arrays if needed and if kind is 'mask' it fills the outline """Convert sparse arrays if needed; if kind is 'mask' fill the outline."""
array = Cells._asdense(array) array = Cells.asdense(array)
if kind == "mask": if kind == "mask":
array = ndimage.binary_fill_holes(array).astype(bool) array = ndimage.binary_fill_holes(array).astype(bool)
return array return array
def _get_idx(self, cell_id: int, trap_id: int): def get_idx(self, cell_id: int, trap_id: int):
# returns boolean array of time points where both the cell with cell_id and the trap with trap_id exist """Return boolean array giving indices for a cell_id and trap_id."""
return (self["cell_label"] == cell_id) & (self["trap"] == trap_id) return (self["cell_label"] == cell_id) & (self["trap"] == trap_id)
@property @property
def max_labels(self) -> t.List[int]: def max_labels(self) -> t.List[int]:
return [max((0, *self.labels_in_trap(i))) for i in range(self.ntraps)] """Return the maximum cell label per tile."""
return [
max((0, *self.cell_labels_in_trap(i))) for i in range(self.ntraps)
]
@property @property
def max_label(self) -> int: def max_label(self) -> int:
"""Return the maximum cell label over all tiles."""
return sum(self.max_labels) return sum(self.max_labels)
@property @property
def ntraps(self) -> int: def ntraps(self) -> int:
# find the number of traps from the h5 file """Find the number of tiles, or traps."""
with h5py.File(self.filename, mode="r") as f: with h5py.File(self.filename, mode="r") as f:
return len(f["trap_info/trap_locations"][()]) return len(f["trap_info/trap_locations"][()])
@property @property
def tinterval(self): def tinterval(self):
"""Return time interval in seconds."""
with h5py.File(self.filename, mode="r") as f: with h5py.File(self.filename, mode="r") as f:
return f.attrs["time_settings/timeinterval"] return f.attrs["time_settings/timeinterval"]
@property @property
def traps(self) -> t.List[int]: def traps(self) -> t.List[int]:
# returns a list of traps """List unique tile, or trap, IDs."""
return list(set(self["trap"])) return list(set(self["trap"]))
@property @property
def tile_size(self) -> t.Union[int, t.Tuple[int], None]: def tile_size(self) -> t.Union[int, t.Tuple[int], None]:
"""Give the x- and y- sizes of a tile."""
if self._tile_size is None: if self._tile_size is None:
with h5py.File(self.filename, mode="r") as f: with h5py.File(self.filename, mode="r") as f:
# self._tile_size = f["trap_info/tile_size"][0]
self._tile_size = f["cell_info/edgemasks"].shape[1:] self._tile_size = f["cell_info/edgemasks"].shape[1:]
return self._tile_size return self._tile_size
def nonempty_tp_in_trap(self, trap_id: int) -> set: def nonempty_tp_in_trap(self, trap_id: int) -> set:
# given a trap_id returns time points in which cells are available """Given a tile, return time points for which cells are available."""
return set(self["timepoint"][self["trap"] == trap_id]) return set(self["timepoint"][self["trap"] == trap_id])
@property @property
def edgemasks(self) -> t.List[np.ndarray]: def edgemasks(self) -> t.List[np.ndarray]:
# returns the masks per tile """Return a list of masks for every cell at every trap and time point."""
if self._edgemasks is None: if self._edgemasks is None:
edgem_path: str = "edgemasks" edgem_path: str = "edgemasks"
self._edgemasks = self._fetch(edgem_path) self._edgemasks = self.fetch(edgem_path)
return self._edgemasks return self._edgemasks
@property @property
def labels(self) -> t.List[t.List[int]]: def labels(self) -> t.List[t.List[int]]:
""" """Return all cell labels per tile as a set for all tiles."""
Return all cell labels in object return [self.cell_labels_in_trap(trap) for trap in range(self.ntraps)]
We use mother_assign to list traps because it is the only property that appears even
when no cells are found
"""
return [self.labels_in_trap(trap) for trap in range(self.ntraps)]
def max_labels_in_frame(self, frame: int) -> t.List[int]: def max_labels_in_frame(self, final_time_point: int) -> t.List[int]:
# Return the maximum label for each trap in the given frame """Get the maximal cell label for each tile within a frame of time."""
max_labels = [ max_labels = [
self["cell_label"][ self["cell_label"][
(self["timepoint"] <= frame) & (self["trap"] == trap_id) (self["timepoint"] <= final_time_point)
& (self["trap"] == trap_id)
] ]
for trap_id in range(self.ntraps) for trap_id in range(self.ntraps)
] ]
return [max([0, *labels]) for labels in max_labels] return [max([0, *labels]) for labels in max_labels]
def where(self, cell_id: int, trap_id: int): def where(self, cell_id: int, trap_id: int):
""" """Return time points, indices, and edge masks for a cell and trap."""
Parameters idx = self.get_idx(cell_id, trap_id)
----------
cell_id: int
Cell index
trap_id: int
Trap index
Returns
----------
indices int array
boolean mask array
edge_ix int array
"""
indices = self._get_idx(cell_id, trap_id)
edgem_ix = self._edgem_where(cell_id, trap_id)
return ( return (
self["timepoint"][indices], self["timepoint"][idx],
indices, idx,
edgem_ix, self.edgemasks_where(cell_id, trap_id),
) )
def mask(self, cell_id, trap_id): def mask(self, cell_id, trap_id):
""" """Return the times and the filled edge masks for a cell and trap."""
Returns the times and the binary masks of a given cell in a given tile.
Parameters
----------
cell_id : int
The unique ID of the cell.
tile_id : int
The unique ID of the tile.
Returns
-------
Tuple[np.ndarray, np.ndarray]
The times when the binary masks were taken and the binary masks of the given cell in the given tile.
"""
times, outlines = self.outline(cell_id, trap_id) times, outlines = self.outline(cell_id, trap_id)
return times, np.array( return times, np.array(
[ndimage.morphology.binary_fill_holes(o) for o in outlines] [ndimage.morphology.binary_fill_holes(o) for o in outlines]
) )
def at_time( def at_time(
self, timepoint: t.Iterable[int], kind="mask" self, timepoint: int, kind="mask"
) -> t.List[t.List[np.ndarray]]: ) -> t.List[t.List[np.ndarray]]:
""" """Return a dict with traps as keys and cell masks as values for a time point."""
Returns a list of lists of binary masks in a given list of time points. idx = self["timepoint"] == timepoint
traps = self["trap"][idx]
Parameters edgemasks = self.edgemasks_from_idx(idx)
----------
timepoints : Iterable[int]
The list of time points for which to return the binary masks.
kind : str, optional
The type of binary masks to return, by default "mask".
Returns
-------
List[List[np.ndarray]]
A list of lists with binary masks grouped by tile IDs.
"""
ix = self["timepoint"] == timepoint
traps = self["trap"][ix]
edgemasks = self._edgem_from_masking(ix)
masks = [ masks = [
self._astype(edgemask, kind) Cells.astype(edgemask, kind)
for edgemask in edgemasks for edgemask in edgemasks
if edgemask.any() if edgemask.any()
] ]
...@@ -199,22 +186,7 @@ class Cells: ...@@ -199,22 +186,7 @@ class Cells:
def at_times( def at_times(
self, timepoints: t.Iterable[int], kind="mask" self, timepoints: t.Iterable[int], kind="mask"
) -> t.List[t.List[np.ndarray]]: ) -> t.List[t.List[np.ndarray]]:
""" """Return a list of lists of cell masks one for specified time point."""
Returns a list of lists of binary masks for a given list of time points.
Parameters
----------
timepoints : Iterable[int]
The list of time points for which to return the binary masks.
kind : str, optional
The type of binary masks to return, by default "mask".
Returns
-------
List[List[np.ndarray]]
A list of lists with binary masks grouped by tile IDs.
"""
return [ return [
[ [
np.stack(tile_masks) if len(tile_masks) else [] np.stack(tile_masks) if len(tile_masks) else []
...@@ -226,91 +198,77 @@ class Cells: ...@@ -226,91 +198,77 @@ class Cells:
def group_by_traps( def group_by_traps(
self, traps: t.Collection, cell_labels: t.Collection self, traps: t.Collection, cell_labels: t.Collection
) -> t.Dict[int, t.List[int]]: ) -> t.Dict[int, t.List[int]]:
""" """Return a dict with traps as keys and a list of labels as values."""
Returns a dict with traps as keys and list of labels as value.
Note that the total number of traps are calculated from Cells.traps.
"""
iterator = groupby(zip(traps, cell_labels), lambda x: x[0]) iterator = groupby(zip(traps, cell_labels), lambda x: x[0])
d = {key: [x[1] for x in group] for key, group in iterator} d = {key: [x[1] for x in group] for key, group in iterator}
d = {i: d.get(i, []) for i in self.traps} d = {i: d.get(i, []) for i in self.traps}
return d return d
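A standalone illustration of the grouping performed by group_by_traps. Note that itertools.groupby only merges consecutive keys, so the entries must already be ordered by trap, as they are in the h5 file; the toy numbers below are made up:

from itertools import groupby

traps = [0, 0, 1, 3]
cell_labels = [7, 8, 3, 5]
iterator = groupby(zip(traps, cell_labels), lambda x: x[0])
d = {key: [x[1] for x in group] for key, group in iterator}
d = {i: d.get(i, []) for i in range(4)}  # fill in empty traps, assuming four traps in total
print(d)  # {0: [7, 8], 1: [3], 2: [], 3: [5]}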
def labels_in_trap(self, trap_id: int) -> t.Set[int]: def cell_labels_in_trap(self, trap_id: int) -> t.Set[int]:
# return set of cell ids for a given trap """Return unique cell labels for a given trap."""
return set((self["cell_label"][self["trap"] == trap_id])) return set((self["cell_label"][self["trap"] == trap_id]))
def labels_at_time(self, timepoint: int) -> t.Dict[int, t.List[int]]: def labels_at_time(self, timepoint: int) -> t.Dict[int, t.List[int]]:
"""Return a dict with traps as keys and cell labels as values for a time point."""
labels = self["cell_label"][self["timepoint"] == timepoint] labels = self["cell_label"][self["timepoint"] == timepoint]
traps = self["trap"][self["timepoint"] == timepoint] traps = self["trap"][self["timepoint"] == timepoint]
return self.group_by_traps(traps, labels) return self.group_by_traps(traps, labels)
def __getitem__(self, item): def edgemasks_from_idx(self, idx):
assert item != "edgemasks", "Edgemasks must not be loaded as a whole" """Get edge masks from the h5 file."""
_item = "_" + item
if not hasattr(self, _item):
setattr(self, _item, self._fetch(item))
return getattr(self, _item)
def _fetch(self, path):
with h5py.File(self.filename, mode="r") as f:
return f[self.cinfo_path][path][()]
def _edgem_from_masking(self, mask):
with h5py.File(self.filename, mode="r") as f: with h5py.File(self.filename, mode="r") as f:
edgem = f[self.cinfo_path + "/edgemasks"][mask, ...] edgem = f[self.cinfo_path + "/edgemasks"][idx, ...]
return edgem return edgem
def _edgem_where(self, cell_id, trap_id): def edgemasks_where(self, cell_id, trap_id):
id_mask = self._get_idx(cell_id, trap_id) """Get the edge masks for a given cell and trap for all time points."""
edgem = self._edgem_from_masking(id_mask) idx = self.get_idx(cell_id, trap_id)
edgemasks = self.edgemasks_from_idx(idx)
return edgem return edgemasks
def outline(self, cell_id: int, trap_id: int): def outline(self, cell_id: int, trap_id: int):
id_mask = self._get_idx(cell_id, trap_id) """Get times and edge masks for a given cell and trap."""
times = self["timepoint"][id_mask] idx = self.get_idx(cell_id, trap_id)
times = self["timepoint"][idx]
return times, self._edgem_from_masking(id_mask) return times, self.edgemasks_from_idx(idx)
@property @property
def ntimepoints(self) -> int: def ntimepoints(self) -> int:
"""Return total number of time points in the experiment."""
return self["timepoint"].max() + 1 return self["timepoint"].max() + 1
@cached_property @cached_property
def _cells_vs_tps(self): def cells_vs_tps(self):
# Binary matrix showing the presence of all cells in all time points """Boolean matrix showing when cells are present for all time points."""
ncells_per_tile = [len(x) for x in self.labels] total_ncells = sum([len(x) for x in self.labels])
cells_vs_tps = np.zeros( cells_vs_tps = np.zeros((total_ncells, self.ntimepoints), dtype=bool)
(sum(ncells_per_tile), self.ntimepoints), dtype=bool
)
cells_vs_tps[ cells_vs_tps[
self._cell_cumsum[self["trap"]] + self["cell_label"] - 1, self.cell_cumlsum[self["trap"]] + self["cell_label"] - 1,
self["timepoint"], self["timepoint"],
] = True ] = True
return cells_vs_tps return cells_vs_tps
@cached_property @cached_property
def _cell_cumsum(self): def cell_cumlsum(self):
# Cumulative sum indicating the number of cells per tile """Find cumulative sum over tiles of the number of cells present."""
ncells_per_tile = [len(x) for x in self.labels] ncells_per_tile = [len(x) for x in self.labels]
cumsum = np.roll(np.cumsum(ncells_per_tile), shift=1) cumsum = np.roll(np.cumsum(ncells_per_tile), shift=1)
cumsum[0] = 0 cumsum[0] = 0
return cumsum return cumsum
def _flat_index_to_tuple_location(self, idx: int) -> t.Tuple[int, int]: def index_to_tile_and_cell(self, idx: int) -> t.Tuple[int, int]:
# Convert a cell index to a tuple """Convert an index to the equivalent pair of tile and cell IDs."""
# Note that it assumes tiles and cell labels are flattened, but tile_id = int(np.where(idx + 1 > self.cell_cumlsum)[0][-1])
# it is agnostic to tps cell_label = idx - self.cell_cumlsum[tile_id] + 1
tile_id = int(np.where(idx + 1 > self._cell_cumsum)[0][-1])
cell_label = idx - self._cell_cumsum[tile_id] + 1
return tile_id, cell_label return tile_id, cell_label
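A worked example of the flat-index bookkeeping shared by cell_cumlsum and index_to_tile_and_cell, assuming three tiles holding 2, 3, and 1 cells; the numbers are illustrative only:

import numpy as np

ncells_per_tile = [2, 3, 1]
cumsum = np.roll(np.cumsum(ncells_per_tile), shift=1)
cumsum[0] = 0                                      # -> array([0, 2, 5])
idx = 3                                            # the fourth cell overall
tile_id = int(np.where(idx + 1 > cumsum)[0][-1])   # -> 1, i.e. the second tile
cell_label = idx - cumsum[tile_id] + 1             # -> 2, the second cell in that tile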
@property @property
def _tiles_vs_cells_vs_tps(self): def tiles_vs_cells_vs_tps(self):
"""
Boolean matrix showing if a cell is present.
The matrix is indexed by trap, cell label, and time point.
"""
ncells_mat = np.zeros( ncells_mat = np.zeros(
(self.ntraps, self["cell_label"].max(), self.ntimepoints), (self.ntraps, self["cell_label"].max(), self.ntimepoints),
dtype=bool, dtype=bool,
...@@ -325,32 +283,37 @@ class Cells: ...@@ -325,32 +283,37 @@ class Cells:
min_consecutive_tps: int = 15, min_consecutive_tps: int = 15,
interval: None or t.Tuple[int, int] = None, interval: None or t.Tuple[int, int] = None,
): ):
"""
Find cells present for all time points in a sliding window of time.
The result can be restricted to a particular interval of time.
"""
window = sliding_window_view( window = sliding_window_view(
self._cells_vs_tps, min_consecutive_tps, axis=1 self.cells_vs_tps, min_consecutive_tps, axis=1
) )
tp_min = window.sum(axis=-1) == min_consecutive_tps tp_min = window.sum(axis=-1) == min_consecutive_tps
# apply a filter to restrict to an interval of time
# Apply an interval filter to focus on a slice
if interval is not None: if interval is not None:
interval = tuple(np.array(interval)) interval = tuple(np.array(interval))
else: else:
interval = (0, window.shape[1]) interval = (0, window.shape[1])
low_boundary, high_boundary = interval low_boundary, high_boundary = interval
tp_min[:, :low_boundary] = False tp_min[:, :low_boundary] = False
tp_min[:, high_boundary:] = False tp_min[:, high_boundary:] = False
return tp_min return tp_min
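A toy version of the presence test in tile_tp_where: a cell qualifies if it is present at every time point inside a window of length min_consecutive_tps. The array below is made up for illustration:

import numpy as np
from numpy.lib.stride_tricks import sliding_window_view

presence = np.array([[True, True, True, False, True]])  # one cell over five time points
window = sliding_window_view(presence, 3, axis=1)
print(window.sum(axis=-1) == 3)  # [[ True False False]]: only the first window is fully occupied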
@lru_cache(20) @lru_cache(20)
def mothers_in_trap(self, trap_id: int): def mothers_in_trap(self, trap_id: int):
"""Return mothers at a trap."""
return self.mothers[trap_id] return self.mothers[trap_id]
@cached_property @cached_property
def mothers(self): def mothers(self):
""" """
Return nested list with final prediction of mother id for each cell Return a list of mother IDs for each cell in each tile.
Use Baby's "mother_assign_dynamic".
An ID of zero implies that no mother was assigned.
""" """
return self.mother_assign_from_dynamic( return self.mother_assign_from_dynamic(
self["mother_assign_dynamic"], self["mother_assign_dynamic"],
...@@ -362,73 +325,71 @@ class Cells: ...@@ -362,73 +325,71 @@ class Cells:
@cached_property @cached_property
def mothers_daughters(self) -> np.ndarray: def mothers_daughters(self) -> np.ndarray:
""" """
Return a single array with three columns, containing information about Return mother-daughter relationships for all tiles.
the mother-daughter relationships: tile, mothers and daughters.
Returns Returns
------- -------
np.ndarray mothers_daughters: np.ndarray
An array with shape (n, 3) where n is the number of mother-daughter pairs found. An array with shape (n, 3) where n is the number of mother-daughter
The columns contain: pairs found. The first column is the tile_id for the tile where the
- tile: the tile where the mother cell is located. mother cell is located. The second column is the cell index of a
- mothers: the index of the mother cell within the tile. mother cell in the tile. The third column is the index of the
- daughters: the index of the daughter cell within the tile. corresponding daughter cell.
""" """
nested_massign = self.mothers # list of arrays, one per tile, giving mothers of each cell in each tile
mothers = self.mothers
if sum([x for y in nested_massign for x in y]): if sum([x for y in mothers for x in y]):
mothers_daughters = np.array( mothers_daughters = np.array(
[ [
(tid, m, d) (trap_id, mother, bud)
for tid, trapcells in enumerate(nested_massign) for trap_id, trapcells in enumerate(mothers)
for d, m in enumerate(trapcells, 1) for bud, mother in enumerate(trapcells, start=1)
if m if mother
], ],
dtype=np.uint16, dtype=np.uint16,
) )
else: else:
mothers_daughters = np.array([]) mothers_daughters = np.array([])
self._log("No mother-daughters assigned") self._log("No mother-daughters assigned")
return mothers_daughters return mothers_daughters
@staticmethod @staticmethod
def mother_assign_to_mb_matrix(ma: t.List[np.array]): def mother_assign_to_mb_matrix(ma: t.List[np.array]):
""" """
Convert from a list of lists of mother-bud paired assignments to a Convert a list of mother-daughters into a boolean sparse matrix.
sparse matrix with a boolean dtype. The rows correspond to
to daughter buds. The values are boolean and indicate whether a Each row in the matrix correspond to daughter buds.
given cell is a mother cell and a given daughter bud is assigned If an entry is True, a given cell is a mother cell and a given
to the mother cell in the next timepoint. daughter bud is assigned to the mother cell in the next time point.
Parameters: Parameters:
----------- -----------
ma : list of lists of integers ma : list of lists of integers
A list of lists of mother-bud assignments. The i-th sublist contains the A list of lists of mother-bud assignments.
bud assignments for the i-th tile. The integers in each sublist The i-th sublist contains the bud assignments for the i-th tile.
represent the mother label, if it is zero no mother was found. The integers in each sublist represent the mother label, with zero
implying no mother found.
Returns: Returns:
-------- --------
mb_matrix : boolean numpy array of shape (n, m) mb_matrix : boolean numpy array of shape (n, m)
An n x m boolean numpy array where n is the total number of cells (sum An n x m array where n is the total number of cells (sum
of the lengths of all sublists in ma) and m is the maximum number of buds of the lengths of all sublists in ma) and m is the maximum
assigned to any mother cell in ma. The value at (i, j) is True if cell i number of buds assigned to any mother cell in ma.
is a daughter cell and cell j is its mother assigned to i. The value at (i, j) is True if cell i is a daughter cell and
cell j is its assigned mother.
Examples: Examples:
-------- --------
ma = [[0, 0, 1], [0, 1, 0]] >>> ma = [[0, 0, 1], [0, 1, 0]]
Cells(None).mother_assign_to_mb_matrix(ma) >>> Cells(None).mother_assign_to_mb_matrix(ma)
# array([[False, False, False, False, False, False], >>> array([[False, False, False, False, False, False],
# [False, False, False, False, False, False], [False, False, False, False, False, False],
# [ True, False, False, False, False, False], [ True, False, False, False, False, False],
# [False, False, False, False, False, False], [False, False, False, False, False, False],
# [False, False, False, True, False, False], [False, False, False, True, False, False],
# [False, False, False, False, False, False]]) [False, False, False, False, False, False]])
""" """
ncells = sum([len(t) for t in ma]) ncells = sum([len(t) for t in ma])
mb_matrix = np.zeros((ncells, ncells), dtype=bool) mb_matrix = np.zeros((ncells, ncells), dtype=bool)
c = 0 c = 0
...@@ -436,69 +397,78 @@ class Cells: ...@@ -436,69 +397,78 @@ class Cells:
for d, m in enumerate(cells): for d, m in enumerate(cells):
if m: if m:
mb_matrix[c + d, c + m - 1] = True mb_matrix[c + d, c + m - 1] = True
c += len(cells) c += len(cells)
return mb_matrix return mb_matrix
@staticmethod @staticmethod
def mother_assign_from_dynamic( def mother_assign_from_dynamic(
ma: np.ndarray, cell_label: t.List[int], trap: t.List[int], ntraps: int ma: np.ndarray,
cell_label: t.List[int],
trap: t.List[int],
ntraps: int,
) -> t.List[t.List[int]]: ) -> t.List[t.List[int]]:
""" """
Interpolate the associated mothers from the 'mother_assign_dynamic' feature. Find mothers from Baby's 'mother_assign_dynamic' variable.
Parameters Parameters
---------- ----------
ma: np.ndarray ma: np.ndarray
An array with shape (n_t, n_c) containing the 'mother_assign_dynamic' feature. An array of length number of time points times number of cells
containing the 'mother_assign_dynamic' produced by Baby.
cell_label: List[int] cell_label: List[int]
A list containing the cell labels. A list of cell labels.
trap: List[int] trap: List[int]
A list containing the trap labels. A list of trap labels.
ntraps: int ntraps: int
The total number of traps. The total number of traps.
Returns Returns
------- -------
List[List[int]] List[List[int]]
A list of lists containing the interpolated mother assignment for each cell in each trap. A list giving the mothers for each cell at each trap.
""" """
idlist = list(zip(trap, cell_label)) ids = np.unique(list(zip(trap, cell_label)), axis=0)
cell_gid = np.unique(idlist, axis=0) # find when each cell last appeared at its trap
last_lin_preds = [ last_lin_preds = [
find_1st( find_1st(
((cell_label[::-1] == lbl) & (trap[::-1] == tr)), (
(cell_label[::-1] == cell_label_id)
& (trap[::-1] == trap_id)
),
True, True,
cmp_equal, cmp_equal,
) )
for tr, lbl in cell_gid for trap_id, cell_label_id in ids
] ]
# find the cell's mother using the latest prediction from Baby
mother_assign_sorted = ma[::-1][last_lin_preds] mother_assign_sorted = ma[::-1][last_lin_preds]
# rearrange as a list of mother IDs for each cell in each tile
traps = cell_gid[:, 0] traps = ids[:, 0]
iterator = groupby(zip(traps, mother_assign_sorted), lambda x: x[0]) iterator = groupby(zip(traps, mother_assign_sorted), lambda x: x[0])
d = {key: [x[1] for x in group] for key, group in iterator} d = {trap: [x[1] for x in mothers] for trap, mothers in iterator}
nested_massign = [d.get(i, []) for i in range(ntraps)] mothers = [d.get(i, []) for i in range(ntraps)]
return mothers
return nested_massign ###############################################################################
# Apparently unused below here
###############################################################################
@lru_cache(maxsize=200) @lru_cache(maxsize=200)
def labelled_in_frame( def labelled_in_frame(
self, frame: int, global_id: bool = False self, frame: int, global_id: bool = False
) -> np.ndarray: ) -> np.ndarray:
""" """
Returns labels in a 4D ndarray with the global ids with shape Return labels in a 4D ndarray with potentially global ids.
(ntraps, max_nlabels, ysize, xsize) at a given frame.
Use lru_cache to cache the results for speed.
Parameters Parameters
---------- ----------
frame : int frame : int
The frame number. The frame number (time point).
global_id : bool, optional global_id : bool, optional
If True, the returned array contains global ids, otherwise it If True, the returned array contains global ids, otherwise only
contains only the local ids of the labels. Default is False. the local ids of the labels.
Returns Returns
------- -------
...@@ -507,18 +477,12 @@ class Cells: ...@@ -507,18 +477,12 @@ class Cells:
The array has dimensions (ntraps, max_nlabels, ysize, xsize), The array has dimensions (ntraps, max_nlabels, ysize, xsize),
where max_nlabels is specific for this frame, not the entire where max_nlabels is specific for this frame, not the entire
experiment. experiment.
Notes
-----
This method uses lru_cache to cache the results for faster access.
""" """
labels_in_frame = self.labels_at_time(frame) labels_in_frame = self.labels_at_time(frame)
n_labels = [ n_labels = [
len(labels_in_frame.get(trap_id, [])) len(labels_in_frame.get(trap_id, []))
for trap_id in range(self.ntraps) for trap_id in range(self.ntraps)
] ]
# maxes = self.max_labels_in_frame(frame)
stacks_in_frame = self.get_stacks_in_frame(frame, self.tile_size) stacks_in_frame = self.get_stacks_in_frame(frame, self.tile_size)
first_id = np.cumsum([0, *n_labels]) first_id = np.cumsum([0, *n_labels])
labels_mat = np.zeros( labels_mat = np.zeros(
...@@ -552,7 +516,9 @@ class Cells: ...@@ -552,7 +516,9 @@ class Cells:
self, frame: int, tile_shape: t.Tuple[int] self, frame: int, tile_shape: t.Tuple[int]
) -> t.List[np.ndarray]: ) -> t.List[np.ndarray]:
""" """
Returns a list of stacked masks, each corresponding to a tile at a given timepoint. Return a list of stacked masks.
Each corresponds to a tile at a given time point.
Parameters Parameters
---------- ----------
...@@ -564,7 +530,7 @@ class Cells: ...@@ -564,7 +530,7 @@ class Cells:
Returns Returns
------- -------
List[np.ndarray] List[np.ndarray]
List of stacked masks for each tile at the given timepoint. List of stacked masks for each tile at the given time point.
""" """
masks = self.at_time(frame) masks = self.at_time(frame)
return [ return [
...@@ -574,7 +540,7 @@ class Cells: ...@@ -574,7 +540,7 @@ class Cells:
for trap_id in range(self.ntraps) for trap_id in range(self.ntraps)
] ]
def _sample_tiles_tps( def sample_tiles_tps(
self, self,
size=1, size=1,
min_consecutive_ntps: int = 15, min_consecutive_ntps: int = 15,
...@@ -582,7 +548,7 @@ class Cells: ...@@ -582,7 +548,7 @@ class Cells:
interval=None, interval=None,
) -> t.Tuple[np.ndarray, np.ndarray]: ) -> t.Tuple[np.ndarray, np.ndarray]:
""" """
Sample tiles that have a minimum number of cells and are occupied for at least a minimum number of consecutive timepoints. Sample tiles that have a minimum number of cells and are occupied for at least a minimum number of consecutive time points.
Parameters Parameters
---------- ----------
...@@ -591,7 +557,7 @@ class Cells: ...@@ -591,7 +557,7 @@ class Cells:
min_ncells: int, optional (default=2) min_ncells: int, optional (default=2)
The minimum number of cells per tile. The minimum number of cells per tile.
min_consecutive_ntps: int, optional (default=5) min_consecutive_ntps: int, optional (default=5)
The minimum number of consecutive timepoints a cell must be present in a trap. The minimum number of consecutive time points a cell must be present in a trap.
seed: int, optional (default=0) seed: int, optional (default=0)
Random seed value for reproducibility. Random seed value for reproducibility.
interval: None or Tuple(int,int), optional (default=None) interval: None or Tuple(int,int), optional (default=None)
...@@ -612,21 +578,15 @@ class Cells: ...@@ -612,21 +578,15 @@ class Cells:
min_consecutive_tps=min_consecutive_ntps, min_consecutive_tps=min_consecutive_ntps,
interval=interval, interval=interval,
) )
# Find all valid tiles with min_ncells for at least min_tps # Find all valid tiles with min_ncells for at least min_tps
index_id, tps = np.where(cell_availability_matrix) index_id, tps = np.where(cell_availability_matrix)
if interval is None: # Limit search if interval is None: # Limit search
interval = (0, cell_availability_matrix.shape[1]) interval = (0, cell_availability_matrix.shape[1])
np.random.seed(seed) np.random.seed(seed)
choices = np.random.randint(len(index_id), size=size) choices = np.random.randint(len(index_id), size=size)
linear_indices = np.zeros_like(self["cell_label"], dtype=bool) linear_indices = np.zeros_like(self["cell_label"], dtype=bool)
for cell_index_flat, tp in zip(index_id[choices], tps[choices]): for cell_index_flat, tp in zip(index_id[choices], tps[choices]):
tile_id, cell_label = self._flat_index_to_tuple_location( tile_id, cell_label = self.index_to_tile_and_cell(cell_index_flat)
cell_index_flat
)
linear_indices[ linear_indices[
( (
(self["cell_label"] == cell_label) (self["cell_label"] == cell_label)
...@@ -634,10 +594,9 @@ class Cells: ...@@ -634,10 +594,9 @@ class Cells:
& (self["timepoint"] == tp) & (self["timepoint"] == tp)
) )
] = True ] = True
return linear_indices return linear_indices
def _sample_masks( def sample_masks(
self, self,
size: int = 1, size: int = 1,
min_consecutive_ntps: int = 15, min_consecutive_ntps: int = 15,
...@@ -668,31 +627,28 @@ class Cells: ...@@ -668,31 +627,28 @@ class Cells:
The second tuple contains: The second tuple contains:
- `masks`: A list of 2D numpy arrays representing the binary masks of the sampled cells at each timepoint. - `masks`: A list of 2D numpy arrays representing the binary masks of the sampled cells at each timepoint.
""" """
sampled_bitmask = self._sample_tiles_tps( sampled_bitmask = self.sample_tiles_tps(
size=size, size=size,
min_consecutive_ntps=min_consecutive_ntps, min_consecutive_ntps=min_consecutive_ntps,
seed=seed, seed=seed,
interval=interval, interval=interval,
) )
# Sort sampled tiles to use automatic cache when possible # Sort sampled tiles to use automatic cache when possible
tile_ids = self["trap"][sampled_bitmask] tile_ids = self["trap"][sampled_bitmask]
cell_labels = self["cell_label"][sampled_bitmask] cell_labels = self["cell_label"][sampled_bitmask]
tps = self["timepoint"][sampled_bitmask] tps = self["timepoint"][sampled_bitmask]
masks = [] masks = []
for tile_id, cell_label, tp in zip(tile_ids, cell_labels, tps): for tile_id, cell_label, tp in zip(tile_ids, cell_labels, tps):
local_idx = self.labels_at_time(tp)[tile_id].index(cell_label) local_idx = self.labels_at_time(tp)[tile_id].index(cell_label)
tile_mask = self.at_time(tp)[tile_id][local_idx] tile_mask = self.at_time(tp)[tile_id][local_idx]
masks.append(tile_mask) masks.append(tile_mask)
return (tile_ids, cell_labels, tps), np.stack(masks) return (tile_ids, cell_labels, tps), np.stack(masks)
def matrix_trap_tp_where( def matrix_trap_tp_where(
self, min_ncells: int = 2, min_consecutive_tps: int = 5 self, min_ncells: int = 2, min_consecutive_tps: int = 5
): ):
""" """
NOTE CURRENLTY UNUSED WITHIN ALIBY THE MOMENT. MAY BE USEFUL IN THE FUTURE. NOTE CURRENTLY UNUSED BUT USEFUL.
Return a matrix of shape (ntraps x ntps - min_consecutive_tps) to Return a matrix of shape (ntraps x ntps - min_consecutive_tps) to
indicate traps and time-points where min_ncells are available for at least min_consecutive_tps indicate traps and time-points where min_ncells are available for at least min_consecutive_tps
...@@ -708,9 +664,8 @@ class Cells: ...@@ -708,9 +664,8 @@ class Cells:
(ntraps x ( ntps-min_consecutive_tps )) 2D boolean numpy array where rows are trap ids and columns are timepoint windows. (ntraps x ( ntps-min_consecutive_tps )) 2D boolean numpy array where rows are trap ids and columns are timepoint windows.
If the value in a cell is true its corresponding trap and timepoint contains more than min_ncells for at least min_consecutive time-points. If the value in a cell is true its corresponding trap and timepoint contains more than min_ncells for at least min_consecutive time-points.
""" """
window = sliding_window_view( window = sliding_window_view(
self._tiles_vs_cells_vs_tps, min_consecutive_tps, axis=2 self.tiles_vs_cells_vs_tps, min_consecutive_tps, axis=2
) )
tp_min = window.sum(axis=-1) == min_consecutive_tps tp_min = window.sum(axis=-1) == min_consecutive_tps
ncells_tp_min = tp_min.sum(axis=1) >= min_ncells ncells_tp_min = tp_min.sum(axis=1) >= min_ncells
...@@ -720,7 +675,7 @@ class Cells: ...@@ -720,7 +675,7 @@ class Cells:
def stack_masks_in_tile( def stack_masks_in_tile(
masks: t.List[np.ndarray], tile_shape: t.Tuple[int] masks: t.List[np.ndarray], tile_shape: t.Tuple[int]
) -> np.ndarray: ) -> np.ndarray:
# Stack all masks in a trap padding accordingly if no outlines found """Stack all masks in a trap, padding accordingly if no outlines found."""
result = np.zeros((0, *tile_shape), dtype=bool) result = np.zeros((0, *tile_shape), dtype=bool)
if len(masks): if len(masks):
result = np.stack(masks) result = np.stack(masks)
......
...@@ -6,17 +6,19 @@ import typing as t ...@@ -6,17 +6,19 @@ import typing as t
from functools import wraps from functools import wraps
def _first_arg_str_to_df( def _first_arg_str_to_raw_df(
fn: t.Callable, fn: t.Callable,
): ):
"""Enable Signal-like classes to convert strings to data sets.""" """Enable Signal-like classes to convert strings to data sets."""
@wraps(fn) @wraps(fn)
def format_input(*args, **kwargs): def format_input(*args, **kwargs):
cls = args[0] cls = args[0]
data = args[1] data = args[1]
if isinstance(data, str): if isinstance(data, str):
# get data from h5 file # get data from h5 file using Signal's get_raw
data = cls.get_raw(data) data = cls.get_raw(data)
# replace path in the undecorated function with data # replace path in the undecorated function with data
return fn(cls, data, *args[2:], **kwargs) return fn(cls, data, *args[2:], **kwargs)
return format_input return format_input
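A minimal sketch of how _first_arg_str_to_raw_df is meant to wrap a Signal-like method. MiniSignal and its get_raw stand-in are illustrative, not part of aliby; only the decorator itself comes from the hunk above:

import pandas as pd

class MiniSignal:
    def get_raw(self, name):
        # stand-in for loading the named dataset from the h5 file
        return pd.DataFrame({0: [1.0]}, index=pd.Index([name], name="signal"))

    @_first_arg_str_to_raw_df
    def apply_prepost(self, data, **kwargs):
        # data is always a dataframe here, even when the caller passed a path string
        return data

MiniSignal().apply_prepost("extraction/general/None/area")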
...@@ -66,7 +66,7 @@ class MetaData: ...@@ -66,7 +66,7 @@ class MetaData:
# Needed because HDF5 attributes do not support dictionaries # Needed because HDF5 attributes do not support dictionaries
def flatten_dict(nested_dict, separator="/"): def flatten_dict(nested_dict, separator="/"):
""" """
Flattens nested dictionary. If empty return as-is. Flatten nested dictionary. If empty return as-is.
""" """
flattened = {} flattened = {}
if nested_dict: if nested_dict:
...@@ -79,9 +79,7 @@ def flatten_dict(nested_dict, separator="/"): ...@@ -79,9 +79,7 @@ def flatten_dict(nested_dict, separator="/"):
# Needed because HDF5 attributes do not support datetime objects # Needed because HDF5 attributes do not support datetime objects
# Takes care of time zones & daylight saving # Takes care of time zones & daylight saving
def datetime_to_timestamp(time, locale="Europe/London"): def datetime_to_timestamp(time, locale="Europe/London"):
""" """Convert datetime object to UNIX timestamp."""
Convert datetime object to UNIX timestamp
"""
return timezone(locale).localize(time).timestamp() return timezone(locale).localize(time).timestamp()
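A worked example of datetime_to_timestamp; the date is arbitrary, and pytz is assumed to be the source of the imported timezone function, as in the original module:

from datetime import datetime
from pytz import timezone

time = datetime(2023, 6, 1, 12, 0)  # naive local time, here British Summer Time
print(timezone("Europe/London").localize(time).timestamp())  # 1685617200.0, i.e. 11:00 UTC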
...@@ -189,36 +187,37 @@ def parse_swainlab_metadata(filedir: t.Union[str, Path]): ...@@ -189,36 +187,37 @@ def parse_swainlab_metadata(filedir: t.Union[str, Path]):
Dictionary with minimal metadata Dictionary with minimal metadata
""" """
filedir = Path(filedir) filedir = Path(filedir)
filepath = find_file(filedir, "*.log") filepath = find_file(filedir, "*.log")
if filepath: if filepath:
# new log files
raw_parse = parse_from_swainlab_grammar(filepath) raw_parse = parse_from_swainlab_grammar(filepath)
minimal_meta = get_meta_swainlab(raw_parse) minimal_meta = get_meta_swainlab(raw_parse)
else: else:
# old log files
if filedir.is_file() or str(filedir).endswith(".zarr"): if filedir.is_file() or str(filedir).endswith(".zarr"):
# log file is in parent directory
filedir = filedir.parent filedir = filedir.parent
legacy_parse = parse_logfiles(filedir) legacy_parse = parse_logfiles(filedir)
minimal_meta = ( minimal_meta = (
get_meta_from_legacy(legacy_parse) if legacy_parse else {} get_meta_from_legacy(legacy_parse) if legacy_parse else {}
) )
return minimal_meta return minimal_meta
def dispatch_metadata_parser(filepath: t.Union[str, Path]): def dispatch_metadata_parser(filepath: t.Union[str, Path]):
""" """
Function to dispatch different metadata parsers that convert logfiles into a Dispatch different metadata parsers that convert logfiles into a dictionary.
basic metadata dictionary. Currently only contains the swainlab log parsers.
Currently only contains the swainlab log parsers.
Input: Input:
-------- --------
filepath: str existing file containing metadata, or folder containing naming conventions filepath: str existing file containing metadata, or folder containing naming
conventions
""" """
parsed_meta = parse_swainlab_metadata(filepath) parsed_meta = parse_swainlab_metadata(filepath)
if parsed_meta is None: if parsed_meta is None:
parsed_meta = dir_to_meta parsed_meta = dir_to_meta
return parsed_meta return parsed_meta
......
...@@ -5,7 +5,7 @@ import h5py ...@@ -5,7 +5,7 @@ import h5py
import numpy as np import numpy as np
from agora.io.bridge import groupsort from agora.io.bridge import groupsort
from agora.io.writer import load_attributes from agora.io.writer import load_meta
class DynamicReader: class DynamicReader:
...@@ -13,7 +13,7 @@ class DynamicReader: ...@@ -13,7 +13,7 @@ class DynamicReader:
def __init__(self, file: str): def __init__(self, file: str):
self.file = file self.file = file
self.metadata = load_attributes(file) self.metadata = load_meta(file)
class StateReader(DynamicReader): class StateReader(DynamicReader):
......
...@@ -9,9 +9,10 @@ import h5py ...@@ -9,9 +9,10 @@ import h5py
import numpy as np import numpy as np
import pandas as pd import pandas as pd
import aliby.global_parameters as global_parameters
from agora.io.bridge import BridgeH5 from agora.io.bridge import BridgeH5
from agora.io.decorators import _first_arg_str_to_df from agora.io.decorators import _first_arg_str_to_raw_df
from agora.utils.indexing import validate_association from agora.utils.indexing import validate_lineage
from agora.utils.kymograph import add_index_levels from agora.utils.kymograph import add_index_levels
from agora.utils.merge import apply_merges from agora.utils.merge import apply_merges
...@@ -20,11 +21,14 @@ class Signal(BridgeH5): ...@@ -20,11 +21,14 @@ class Signal(BridgeH5):
""" """
Fetch data from h5 files for post-processing. Fetch data from h5 files for post-processing.
Signal assumes that the metadata and data are accessible to perform time-adjustments and apply previously recorded post-processes. Signal assumes that the metadata and data are accessible to
perform time-adjustments and apply previously recorded
post-processes.
""" """
def __init__(self, file: t.Union[str, Path]): def __init__(self, file: t.Union[str, Path]):
"""Define index_names for dataframes, candidate fluorescence channels, and composite statistics.""" """Define index_names for dataframes, candidate fluorescence channels,
and composite statistics."""
super().__init__(file, flag=None) super().__init__(file, flag=None)
self.index_names = ( self.index_names = (
"experiment", "experiment",
...@@ -33,22 +37,13 @@ class Signal(BridgeH5): ...@@ -33,22 +37,13 @@ class Signal(BridgeH5):
"cell_label", "cell_label",
"mother_label", "mother_label",
) )
self.candidate_channels = ( self.candidate_channels = global_parameters.possible_imaging_channels
"GFP",
"GFPFast",
"mCherry",
"Flavin",
"Citrine",
"mKO2",
"Cy5",
"pHluorin405",
)
def __getitem__(self, dsets: t.Union[str, t.Collection]): def __getitem__(self, dsets: t.Union[str, t.Collection]):
"""Get and potentially pre-process data from h5 file and return as a dataframe.""" """Get and potentially pre-process data from h5 file and return as a dataframe."""
if isinstance(dsets, str): # no pre-processing if isinstance(dsets, str):
return self.get(dsets) return self.get(dsets)
elif isinstance(dsets, list): # pre-processing elif isinstance(dsets, list):
is_bgd = [dset.endswith("imBackground") for dset in dsets] is_bgd = [dset.endswith("imBackground") for dset in dsets]
# Check we are not comparing tile-indexed and cell-indexed data # Check we are not comparing tile-indexed and cell-indexed data
assert sum(is_bgd) == 0 or sum(is_bgd) == len( assert sum(is_bgd) == 0 or sum(is_bgd) == len(
...@@ -58,22 +53,23 @@ class Signal(BridgeH5): ...@@ -58,22 +53,23 @@ class Signal(BridgeH5):
else: else:
raise Exception(f"Invalid type {type(dsets)} to get datasets") raise Exception(f"Invalid type {type(dsets)} to get datasets")
def get(self, dsets: t.Union[str, t.Collection], **kwargs): def get(self, dset_name: t.Union[str, t.Collection], **kwargs):
"""Get and potentially pre-process data from h5 file and return as a dataframe.""" """Return pre-processed data as a dataframe."""
if isinstance(dsets, str): # no pre-processing if isinstance(dset_name, str):
df = self.get_raw(dsets, **kwargs) dsets = self.get_raw(dset_name, **kwargs)
prepost_applied = self.apply_prepost(dsets, **kwargs) prepost_applied = self.apply_prepost(dsets, **kwargs)
return self.add_name(prepost_applied, dset_name)
return self.add_name(prepost_applied, dsets) else:
raise Exception("Error in Signal.get")
@staticmethod @staticmethod
def add_name(df, name): def add_name(df, name):
"""Add column of identical strings to a dataframe.""" """Add name of the Signal as an attribute to its corresponding dataframe."""
df.name = name df.name = name
return df return df
def cols_in_mins(self, df: pd.DataFrame): def cols_in_mins(self, df: pd.DataFrame):
# Convert numerical columns in a dataframe to minutes """Convert numerical columns in a dataframe to minutes."""
try: try:
df.columns = (df.columns * self.tinterval // 60).astype(int) df.columns = (df.columns * self.tinterval // 60).astype(int)
except Exception as e: except Exception as e:
...@@ -94,14 +90,15 @@ class Signal(BridgeH5): ...@@ -94,14 +90,15 @@ class Signal(BridgeH5):
if tinterval_location in f.attrs: if tinterval_location in f.attrs:
return f.attrs[tinterval_location][0] return f.attrs[tinterval_location][0]
else: else:
logging.getlogger("aliby").warn( logging.getLogger("aliby").warn(
f"{str(self.filename).split('/')[-1]}: using default time interval of 5 minutes" f"{str(self.filename).split('/')[-1]}: using default time interval of 5 minutes"
) )
return 5 return 5
@staticmethod @staticmethod
def get_retained(df, cutoff): def get_retained(df, cutoff):
"""Return a fraction of the df, one without later time points.""" """Return rows of df with at least cutoff fraction of the total number
of time points."""
return df.loc[bn.nansum(df.notna(), axis=1) > df.shape[1] * cutoff] return df.loc[bn.nansum(df.notna(), axis=1) > df.shape[1] * cutoff]
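A toy illustration of get_retained with the default cutoff of 0.8: only cells with data at more than 80% of the time points survive. The numbers are made up; the call mirrors the one-liner above:

import numpy as np
import pandas as pd
import bottleneck as bn

df = pd.DataFrame(
    [[1.0, 2.0, 3.0, 4.0, 5.0],
     [1.0, np.nan, np.nan, np.nan, np.nan]]
)
cutoff = 0.8
retained = df.loc[bn.nansum(df.notna(), axis=1) > df.shape[1] * cutoff]
print(retained.index.tolist())  # [0]: the second cell has data at too few time points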
@property @property
...@@ -110,15 +107,17 @@ class Signal(BridgeH5): ...@@ -110,15 +107,17 @@ class Signal(BridgeH5):
with h5py.File(self.filename, "r") as f: with h5py.File(self.filename, "r") as f:
return list(f.attrs["channels"]) return list(f.attrs["channels"])
@_first_arg_str_to_df def retained(
def retained(self, signal, cutoff=0.8): self, signal, cutoff=global_parameters.signal_retained_cutoff
):
""" """
Load data (via decorator) and reduce the resulting dataframe. Load data (via decorator) and reduce the resulting dataframe.
Load data for a signal or a list of signals and reduce the resulting Load data for a signal or a list of signals and reduce the resulting
dataframes to a fraction of their original size, losing late time dataframes to rows with sufficient numbers of time points.
points.
""" """
if isinstance(signal, str):
signal = self.get_raw(signal)
if isinstance(signal, pd.DataFrame): if isinstance(signal, pd.DataFrame):
return self.get_retained(signal, cutoff) return self.get_retained(signal, cutoff)
elif isinstance(signal, list): elif isinstance(signal, list):
...@@ -131,17 +130,15 @@ class Signal(BridgeH5): ...@@ -131,17 +130,15 @@ class Signal(BridgeH5):
""" """
Get lineage data from a given location in the h5 file. Get lineage data from a given location in the h5 file.
Returns an array with three columns: the tile id, the mother label, and the daughter label. Returns an array with three columns: the tile id, the mother label,
and the daughter label.
""" """
if lineage_location is None: if lineage_location is None:
lineage_location = "modifiers/lineage_merged" lineage_location = "modifiers/lineage_merged"
with h5py.File(self.filename, "r") as f: with h5py.File(self.filename, "r") as f:
# if lineage_location not in f:
# lineage_location = lineage_location.split("_")[0]
if lineage_location not in f: if lineage_location not in f:
lineage_location = "postprocessing/lineage" lineage_location = "postprocessing/lineage"
tile_mo_da = f[lineage_location] tile_mo_da = f[lineage_location]
if isinstance(tile_mo_da, h5py.Dataset): if isinstance(tile_mo_da, h5py.Dataset):
lineage = tile_mo_da[()] lineage = tile_mo_da[()]
else: else:
...@@ -154,7 +151,7 @@ class Signal(BridgeH5): ...@@ -154,7 +151,7 @@ class Signal(BridgeH5):
).T ).T
return lineage return lineage
@_first_arg_str_to_df @_first_arg_str_to_raw_df
def apply_prepost( def apply_prepost(
self, self,
data: t.Union[str, pd.DataFrame], data: t.Union[str, pd.DataFrame],
...@@ -162,57 +159,40 @@ class Signal(BridgeH5): ...@@ -162,57 +159,40 @@ class Signal(BridgeH5):
picks: t.Union[t.Collection, bool] = True, picks: t.Union[t.Collection, bool] = True,
): ):
""" """
Apply modifier operations (picker or merger) to a dataframe. Apply picking and merging to a Signal data frame.
Parameters Parameters
---------- ----------
data : t.Union[str, pd.DataFrame] data : t.Union[str, pd.DataFrame]
DataFrame or path to one. A data frame or a path to one.
merges : t.Union[np.ndarray, bool] merges : t.Union[np.ndarray, bool]
(optional) 2-D array with three columns: the tile id, the mother label, and the daughter id. (optional) An array of pairs of (trap, cell) indices to merge.
If True, fetch merges from file. If True, fetch merges from file.
picks : t.Union[np.ndarray, bool] picks : t.Union[np.ndarray, bool]
(optional) 2-D array with two columns: the tiles and (optional) An array of (trap, cell) indices.
the cell labels.
If True, fetch picks from file. If True, fetch picks from file.
Examples
--------
FIXME: Add docs.
""" """
if isinstance(merges, bool): if isinstance(merges, bool):
merges: np.ndarray = self.load_merges() if merges else np.array([]) merges = self.load_merges() if merges else np.array([])
if merges.any(): if merges.any():
merged = apply_merges(data, merges) merged = apply_merges(data, merges)
else: else:
merged = copy(data) merged = copy(data)
if isinstance(picks, bool): if isinstance(picks, bool):
picks = ( picks = (
self.get_picks(names=merged.index.names) self.get_picks(
names=merged.index.names, path="modifiers/picks/"
)
if picks if picks
else set(merged.index) else merged.index
) )
with h5py.File(self.filename, "r") as f: if picks:
if "modifiers/picks" in f and picks: picked_indices = set(picks).intersection(
if picks: [tuple(x) for x in merged.index]
return merged.loc[ )
set(picks).intersection( return merged.loc[picked_indices]
[tuple(x) for x in merged.index] else:
) return merged
]
else:
if isinstance(merged.index, pd.MultiIndex):
empty_lvls = [[] for i in merged.index.names]
index = pd.MultiIndex(
levels=empty_lvls,
codes=empty_lvls,
names=merged.index.names,
)
else:
index = pd.Index([], name=merged.index.name)
merged = pd.DataFrame([], index=index)
return merged
@cached_property @cached_property
def p_available(self): def p_available(self):
...@@ -272,10 +252,11 @@ class Signal(BridgeH5): ...@@ -272,10 +252,11 @@ class Signal(BridgeH5):
Parameters Parameters
---------- ----------
dataset: str or list of strs dataset: str or list of strs
The name of the h5 file or a list of h5 file names The name of the h5 file or a list of h5 file names.
in_minutes: boolean in_minutes: boolean
If True, If True, convert column headings to times in minutes.
lineage: boolean lineage: boolean
If True, add mother_label to index.
""" """
try: try:
if isinstance(dataset, str): if isinstance(dataset, str):
...@@ -288,15 +269,17 @@ class Signal(BridgeH5): ...@@ -288,15 +269,17 @@ class Signal(BridgeH5):
self.get_raw(dset, in_minutes=in_minutes, lineage=lineage) self.get_raw(dset, in_minutes=in_minutes, lineage=lineage)
for dset in dataset for dset in dataset
] ]
if lineage: # assume that df is sorted if lineage:
# assume that df is sorted
mother_label = np.zeros(len(df), dtype=int) mother_label = np.zeros(len(df), dtype=int)
lineage = self.lineage() lineage = self.lineage()
a, b = validate_association( # information on buds
valid_lineage, valid_indices = validate_lineage(
lineage, lineage,
np.array(df.index.to_list()), np.array(df.index.to_list()),
match_column=1, "daughters",
) )
mother_label[b] = lineage[a, 1] mother_label[valid_indices] = lineage[valid_lineage, 1]
df = add_index_levels(df, {"mother_label": mother_label}) df = add_index_levels(df, {"mother_label": mother_label})
return df return df
except Exception as e: except Exception as e:
...@@ -316,13 +299,14 @@ class Signal(BridgeH5): ...@@ -316,13 +299,14 @@ class Signal(BridgeH5):
names: t.Tuple[str, ...] = ("trap", "cell_label"), names: t.Tuple[str, ...] = ("trap", "cell_label"),
path: str = "modifiers/picks/", path: str = "modifiers/picks/",
) -> t.Set[t.Tuple[int, str]]: ) -> t.Set[t.Tuple[int, str]]:
"""Get the relevant picks based on names.""" """Get picks from the h5 file."""
with h5py.File(self.filename, "r") as f: with h5py.File(self.filename, "r") as f:
picks = set()
if path in f: if path in f:
picks = set( picks = set(
zip(*[f[path + name] for name in names if name in f[path]]) zip(*[f[path + name] for name in names if name in f[path]])
) )
else:
picks = set()
return picks return picks
def dataset_to_df(self, f: h5py.File, path: str) -> pd.DataFrame: def dataset_to_df(self, f: h5py.File, path: str) -> pd.DataFrame:
...@@ -353,10 +337,7 @@ class Signal(BridgeH5): ...@@ -353,10 +337,7 @@ class Signal(BridgeH5):
fullname: str, fullname: str,
node: t.Union[h5py.Dataset, h5py.Group], node: t.Union[h5py.Dataset, h5py.Group],
): ):
""" """Store the name of a signal if it is a leaf node and if it starts with extraction."""
Store the name of a signal if it is a leaf node
(a group with no more groups inside) and if it starts with extraction.
"""
if isinstance(node, h5py.Group) and np.all( if isinstance(node, h5py.Group) and np.all(
[isinstance(x, h5py.Dataset) for x in node.values()] [isinstance(x, h5py.Dataset) for x in node.values()]
): ):
......
...@@ -15,9 +15,10 @@ from agora.io.bridge import BridgeH5 ...@@ -15,9 +15,10 @@ from agora.io.bridge import BridgeH5
#################### Dynamic version ################################## #################### Dynamic version ##################################
def load_attributes(file: str, group="/"): def load_meta(file: str, group="/"):
""" """
Load the metadata from an h5 file and convert to a dictionary, including the "parameters" field which is stored as YAML. Load the metadata from an h5 file and convert to a dictionary, including
the "parameters" field which is stored as YAML.
Parameters Parameters
---------- ----------
...@@ -26,8 +27,9 @@ def load_attributes(file: str, group="/"): ...@@ -26,8 +27,9 @@ def load_attributes(file: str, group="/"):
group: str, optional group: str, optional
The group in the h5 file from which to read the data The group in the h5 file from which to read the data
""" """
# load the metadata, stored as attributes, from the h5 file and return as a dictionary # load the metadata, stored as attributes, from the h5 file
with h5py.File(file, "r") as f: with h5py.File(file, "r") as f:
# return as a dict
meta = dict(f[group].attrs.items()) meta = dict(f[group].attrs.items())
if "parameters" in meta: if "parameters" in meta:
# convert from yaml format into dict # convert from yaml format into dict
...@@ -51,7 +53,7 @@ class DynamicWriter: ...@@ -51,7 +53,7 @@ class DynamicWriter:
self.file = file self.file = file
# the metadata is stored as attributes in the h5 file # the metadata is stored as attributes in the h5 file
if Path(file).exists(): if Path(file).exists():
self.metadata = load_attributes(file) self.metadata = load_meta(file)
def _log(self, message: str, level: str = "warn"): def _log(self, message: str, level: str = "warn"):
# Log messages in the corresponding level # Log messages in the corresponding level
......
...@@ -9,6 +9,152 @@ This can be: ...@@ -9,6 +9,152 @@ This can be:
import numpy as np import numpy as np
import typing as t import typing as t
# data type to link together trap and cell ids
i_dtype = {"names": ["trap_id", "cell_id"], "formats": [np.int64, np.int64]}
def validate_lineage(
lineage: np.ndarray, indices: np.ndarray, how: str = "families"
):
"""
Identify mother-bud pairs that exist both in lineage and a Signal's
indices.
We expect the lineage information to be unique: a bud should not have
two mothers.
Parameters
----------
lineage : np.ndarray
2D array of lineage associations where columns are
(trap, mother, daughter)
or
a 3D array, which is an array of 2 X 2 arrays comprising
[[trap_id, mother_label], [trap_id, daughter_label]].
indices : np.ndarray
A 2D array of cell indices from a Signal, (trap_id, cell_label).
This array should not include mother_label.
how: str
If "mothers", matches indicate mothers from mother-bud pairs;
If "daughters", matches indicate daughters from mother-bud pairs;
If "families", matches indicate mothers and daughters in mother-bud pairs.
Returns
-------
valid_lineage: boolean np.ndarray
1D array indicating matched elements in lineage.
valid_indices: boolean np.ndarray
1D array indicating matched elements in indices.
Examples
--------
>>> import numpy as np
>>> from agora.utils.indexing import validate_lineage
>>> lineage = np.array([ [[0, 1], [0, 3]], [[0, 1], [0, 4]], [[0, 1], [0, 6]], [[0, 4], [0, 7]] ])
>>> indices = np.array([ [0, 1], [0, 2], [0, 3]])
>>> valid_lineage, valid_indices = validate_lineage(lineage, indices)
>>> print(valid_lineage)
array([ True, False, False, False])
>>> print(valid_indices)
array([ True, False, True])
and
>>> lineage = np.array([[[0,3], [0,1]], [[0,2], [0,4]]])
>>> indices = np.array([[0,1], [0,2], [0,3]])
>>> valid_lineage, valid_indices = validate_lineage(lineage, indices)
>>> print(valid_lineage)
array([ True, False])
>>> print(valid_indices)
array([ True, False, True])
"""
if lineage.ndim == 2:
# [trap, mother, daughter] becomes [[trap, mother], [trap, daughter]]
lineage = _assoc_indices_to_3d(lineage)
if how == "mothers":
c_index = 0
elif how == "daughters":
c_index = 1
# find valid lineage
valid_lineages = index_isin(lineage, indices)
if how == "families":
# both mother and bud must be in indices
valid_lineage = valid_lineages.all(axis=1)
else:
valid_lineage = valid_lineages[:, c_index, :]
flat_valid_lineage = valid_lineage.flatten()
# find valid indices
selected_lineages = lineage[flat_valid_lineage, ...]
if how == "families":
# select only pairs of mother and bud indices
valid_indices = index_isin(indices, selected_lineages)
else:
valid_indices = index_isin(indices, selected_lineages[:, c_index, :])
flat_valid_indices = valid_indices.flatten()
if (
indices[flat_valid_indices, :].size
!= np.unique(
lineage[flat_valid_lineage, :].reshape(-1, 2), axis=0
).size
):
# all unique indices in valid_lineages should be in valid_indices
raise Exception(
"Error in validate_lineage: "
"lineage information is likely not unique."
)
return flat_valid_lineage, flat_valid_indices
def index_isin(x: np.ndarray, y: np.ndarray) -> np.ndarray:
"""
Find those elements of x that are in y.
Both arrays must be arrays of integer indices,
such as (trap_id, cell_id).
"""
x = np.ascontiguousarray(x, dtype=np.int64)
y = np.ascontiguousarray(y, dtype=np.int64)
xv = x.view(i_dtype)
inboth = np.intersect1d(xv, y.view(i_dtype))
x_bool = np.isin(xv, inboth)
return x_bool
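A minimal sketch of the structured-dtype trick used here (toy values): viewing an (N, 2) integer array with the compound i_dtype makes each (trap_id, cell_id) row behave as a single record, so np.intersect1d and np.isin compare whole rows at once.
import numpy as np
# same compound dtype as defined at the top of the module
i_dtype = {"names": ["trap_id", "cell_id"], "formats": [np.int64, np.int64]}
x = np.array([[0, 1], [0, 2], [1, 3]], dtype=np.int64)
y = np.array([[0, 2], [1, 3], [2, 5]], dtype=np.int64)
# each row becomes one record with two int64 fields
xv = np.ascontiguousarray(x).view(i_dtype)
common = np.intersect1d(xv, np.ascontiguousarray(y).view(i_dtype))
row_in_y = np.isin(xv, common)  # [[False], [True], [True]]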
def _assoc_indices_to_3d(ndarray: np.ndarray):
"""
Convert the last column to a new row and repeat first column's values.
For example: [trap, mother, daughter] becomes
[[trap, mother], [trap, daughter]].
Assumes the input array has shape (N,3).
"""
result = ndarray
if len(ndarray) and ndarray.ndim > 1:
# faster indexing for single positions
if ndarray.shape[1] == 3:
result = np.transpose(
np.hstack((ndarray[:, [0]], ndarray)).reshape(-1, 2, 2),
axes=[0, 2, 1],
)
else:
# 20% slower but more general indexing
columns = np.arange(ndarray.shape[1])
result = np.stack(
(
ndarray[:, np.delete(columns, -1)],
ndarray[:, np.delete(columns, -2)],
),
axis=1,
)
return result
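For example, applying the function to a toy one-row lineage array gives:
import numpy as np
lineage_2d = np.array([[3, 5, 9]])  # (trap, mother, daughter)
_assoc_indices_to_3d(lineage_2d)
# array([[[3, 5],
#         [3, 9]]])  i.e. [[trap, mother], [trap, daughter]]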
###################################################################
def validate_association( def validate_association(
association: np.ndarray, association: np.ndarray,
...@@ -104,38 +250,8 @@ def validate_association( ...@@ -104,38 +250,8 @@ def validate_association(
return valid_association, valid_indices return valid_association, valid_indices
def _assoc_indices_to_3d(ndarray: np.ndarray):
"""
Convert the last column to a new row while repeating all previous indices.
This is useful when converting a signal multiindex before comparing association.
Assumes the input array has shape (N,3)
"""
result = ndarray
if len(ndarray) and ndarray.ndim > 1:
if ndarray.shape[1] == 3: # Faster indexing for single positions
result = np.transpose(
np.hstack((ndarray[:, [0]], ndarray)).reshape(-1, 2, 2),
axes=[0, 2, 1],
)
else: # 20% slower but more general indexing
columns = np.arange(ndarray.shape[1])
result = np.stack(
(
ndarray[:, np.delete(columns, -1)],
ndarray[:, np.delete(columns, -2)],
),
axis=1,
)
return result
def _3d_index_to_2d(array: np.ndarray): def _3d_index_to_2d(array: np.ndarray):
""" """Revert _assoc_indices_to_3d."""
Opposite to _assoc_indices_to_3d.
"""
result = array result = array
if len(array): if len(array):
result = np.concatenate( result = np.concatenate(
......
...@@ -86,16 +86,19 @@ def bidirectional_retainment_filter( ...@@ -86,16 +86,19 @@ def bidirectional_retainment_filter(
daughters_thresh: int = 7, daughters_thresh: int = 7,
) -> pd.DataFrame: ) -> pd.DataFrame:
""" """
Retrieve families where mothers are present for more than a fraction of the experiment, and daughters for longer than some number of time-points. Retrieve families where mothers are present for more than a fraction
of the experiment and daughters for longer than some number of
time-points.
Parameters Parameters
---------- ----------
df: pd.DataFrame df: pd.DataFrame
Data Data
mothers_thresh: float mothers_thresh: float
Minimum fraction of experiment's total duration for which mothers must be present. Minimum fraction of experiment's total duration for which mothers
must be present.
daughters_thresh: int daughters_thresh: int
Minimum number of time points for which daughters must be observed Minimum number of time points for which daughters must be observed.
""" """
# daughters # daughters
all_daughters = df.loc[df.index.get_level_values("mother_label") > 0] all_daughters = df.loc[df.index.get_level_values("mother_label") > 0]
...@@ -170,6 +173,7 @@ def slices_from_spans(spans: t.Tuple[int], df: pd.DataFrame) -> t.List[slice]: ...@@ -170,6 +173,7 @@ def slices_from_spans(spans: t.Tuple[int], df: pd.DataFrame) -> t.List[slice]:
def drop_mother_label(index: pd.MultiIndex) -> np.ndarray: def drop_mother_label(index: pd.MultiIndex) -> np.ndarray:
"""Remove mother_label level from a MultiIndex."""
no_mother_label = index no_mother_label = index
if "mother_label" in index.names: if "mother_label" in index.names:
no_mother_label = index.droplevel("mother_label") no_mother_label = index.droplevel("mother_label")
......
#!/usr/bin/env python3 #!/usr/bin/env python3
import re
import typing as t
import numpy as np import numpy as np
import pandas as pd
from agora.io.bridge import groupsort from agora.io.bridge import groupsort
from itertools import groupby
def mb_array_to_dict(mb_array: np.ndarray): def mb_array_to_dict(mb_array: np.ndarray):
...@@ -19,4 +15,3 @@ def mb_array_to_dict(mb_array: np.ndarray): ...@@ -19,4 +15,3 @@ def mb_array_to_dict(mb_array: np.ndarray):
for trap, mo_da in groupsort(mb_array).items() for trap, mo_da in groupsort(mb_array).items()
for mo, daughters in groupsort(mo_da).items() for mo, daughters in groupsort(mo_da).items()
} }
...@@ -3,90 +3,161 @@ ...@@ -3,90 +3,161 @@
Functions to efficiently merge rows in DataFrames. Functions to efficiently merge rows in DataFrames.
""" """
import typing as t import typing as t
from copy import copy
import numpy as np import numpy as np
import pandas as pd import pandas as pd
from utils_find_1st import cmp_larger, find_1st from utils_find_1st import cmp_larger, find_1st
from agora.utils.indexing import compare_indices, validate_association from agora.utils.indexing import index_isin
def group_merges(merges: np.ndarray) -> t.List[t.Tuple]:
"""
Convert merges into a list of merges for traps requiring multiple
merges and then for traps requiring single merges.
"""
left_tracks = merges[:, 0]
right_tracks = merges[:, 1]
# find traps requiring multiple merges
linr = merges[index_isin(left_tracks, right_tracks).flatten(), :]
rinl = merges[index_isin(right_tracks, left_tracks).flatten(), :]
# make unique and order merges for each trap
multi_merge = np.unique(np.concatenate((linr, rinl)), axis=0)
# find traps requiring a single merge
single_merge = merges[
~index_isin(merges, multi_merge).all(axis=1).flatten(), :
]
# convert to lists of arrays
single_merge_list = [[sm] for sm in single_merge]
multi_merge_list = [
multi_merge[multi_merge[:, 0, 0] == trap_id, ...]
for trap_id in np.unique(multi_merge[:, 0, 0])
]
res = [*multi_merge_list, *single_merge_list]
return res
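A toy example (illustrative values) of how merges are grouped: chained merges within a trap are kept together and ordered per trap, while isolated merges come back as single-element lists.
import numpy as np
merges = np.array(
    [
        [[0, 1], [0, 2]],  # track (0, 1) continues as (0, 2) ...
        [[0, 2], [0, 3]],  # ... which itself continues as (0, 3): a multi-merge
        [[1, 4], [1, 5]],  # an isolated, single merge in trap 1
    ]
)
groups = group_merges(merges)
# groups[0] holds both chained merges for trap 0;
# groups[1] is a one-element list holding the single merge [[1, 4], [1, 5]]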
def merge_lineage(
lineage: np.ndarray, merges: np.ndarray
) -> (np.ndarray, np.ndarray):
"""
Use merges to update lineage information.
Check if merging causes any buds to have multiple mothers and discard
those incorrect merges.
Return updated lineage and merge arrays.
"""
flat_lineage = lineage.reshape(-1, 2)
bud_mother_dict = {
tuple(bud): mother for bud, mother in zip(lineage[:, 1], lineage[:, 0])
}
left_tracks = merges[:, 0]
# find left tracks that are in lineages
valid_lineages = index_isin(flat_lineage, left_tracks).flatten()
# group into multi- and then single merges
grouped_merges = group_merges(merges)
# perform merges
if valid_lineages.any():
# indices of each left track -> indices of rightmost right track
replacement_dict = {
tuple(contig_pair[0]): merge[-1][1]
for merge in grouped_merges
for contig_pair in merge
}
# if both key and value are buds, they must have the same mother
buds = lineage[:, 1]
incorrect_merges = [
key
for key in replacement_dict
if np.any(index_isin(buds, replacement_dict[key]).flatten())
and np.any(index_isin(buds, key).flatten())
and not np.array_equal(
bud_mother_dict[key],
bud_mother_dict[tuple(replacement_dict[key])],
)
]
if incorrect_merges:
# reassign incorrect merges so that they have no effect
for key in incorrect_merges:
replacement_dict[key] = key
# find only correct merges
new_merges = merges[
~index_isin(
merges[:, 0], np.array(incorrect_merges)
).flatten(),
...,
]
else:
new_merges = merges
# correct lineage information
# replace mother or bud index with index of rightmost track
flat_lineage[valid_lineages] = [
replacement_dict[tuple(index)]
for index in flat_lineage[valid_lineages]
]
else:
new_merges = merges
# reverse flattening
new_lineage = flat_lineage.reshape(-1, 2, 2)
# remove any duplicates
new_lineage = np.unique(new_lineage, axis=0)
return new_lineage, new_merges
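A toy illustration (values invented): if a bud's track is merged into a later track, the lineage entry is relabelled accordingly.
import numpy as np
lineage = np.array([[[0, 1], [0, 3]]])  # mother (0, 1) with bud (0, 3)
merges = np.array([[[0, 3], [0, 5]]])   # track (0, 3) continues as track (0, 5)
new_lineage, new_merges = merge_lineage(lineage, merges)
# new_lineage -> array([[[0, 1], [0, 5]]]); the bud is now labelled (0, 5)
# new_merges is unchanged because no bud ends up with two mothers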
def apply_merges(data: pd.DataFrame, merges: np.ndarray): def apply_merges(data: pd.DataFrame, merges: np.ndarray):
"""Split data in two, one subset for rows relevant for merging and one """
without them. It uses an array of source tracklets and target tracklets Generate a new data frame containing merged tracks.
to efficiently merge them.
Parameters Parameters
---------- ----------
data : pd.DataFrame data : pd.DataFrame
Input DataFrame. A Signal data frame.
merges : np.ndarray merges : np.ndarray
3-D ndarray where dimensions are (X,2,2): nmerges, source-target An array of pairs of (trap, cell) indices to merge.
pair and single-cell identifiers, respectively.
Examples
--------
FIXME: Add docs.
""" """
indices = data.index indices = data.index
if "mother_label" in indices.names: if "mother_label" in indices.names:
indices = indices.droplevel("mother_label") indices = indices.droplevel("mother_label")
valid_merges, indices = validate_association( indices = np.array(list(indices))
merges, np.array(list(indices)) # merges in the data frame's indices
) valid_merges = index_isin(merges, indices).all(axis=1).flatten()
# corresponding indices for the data frame in merges
# Assign non-merged selected_merges = merges[valid_merges, ...]
merged = data.loc[~indices] valid_indices = index_isin(indices, selected_merges).flatten()
# data not requiring merging
# Implement the merges and drop source rows. merged = data.loc[~valid_indices]
# TODO Use matrices to perform merges in batch # merge tracks
# for ecficiency
if valid_merges.any(): if valid_merges.any():
to_merge = data.loc[indices] to_merge = data.loc[valid_indices].copy()
targets, sources = zip(*merges[valid_merges]) left_indices = merges[:, 0]
for source, target in zip(sources, targets): right_indices = merges[:, 1]
target = tuple(target) # join left track with right track
to_merge.loc[target] = join_tracks_pair( for left_index, right_index in zip(left_indices, right_indices):
to_merge.loc[target].values, to_merge.loc[tuple(left_index)] = join_two_tracks(
to_merge.loc[tuple(source)].values, to_merge.loc[tuple(left_index)].values,
to_merge.loc[tuple(right_index)].values,
) )
to_merge.drop(map(tuple, sources), inplace=True) # drop indices for right tracks
to_merge.drop(map(tuple, right_indices), inplace=True)
# add to data not requiring merges
merged = pd.concat((merged, to_merge), names=data.index.names) merged = pd.concat((merged, to_merge), names=data.index.names)
return merged return merged
def join_tracks_pair(target: np.ndarray, source: np.ndarray) -> np.ndarray: def join_two_tracks(
""" left_track: np.ndarray, right_track: np.ndarray
Join two tracks and return the new value of the target. ) -> np.ndarray:
""" """Join two tracks and return the new one."""
target_copy = target new_track = left_track.copy()
end = find_1st(target_copy[::-1], 0, cmp_larger) # find last positive element by inverting track
target_copy[-end:] = source[-end:] end = find_1st(left_track[::-1], 0, cmp_larger)
return target_copy # merge tracks into one
new_track[-end:] = right_track[-end:]
return new_track
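A toy end-to-end example of the two functions above (invented values): the left track keeps its own values up to its last observed time point, takes the right track's values afterwards, and the right track's row is dropped.
import numpy as np
import pandas as pd
index = pd.MultiIndex.from_tuples([(0, 1), (0, 2)], names=["trap", "cell_label"])
data = pd.DataFrame(
    [[1.1, 1.3, 0.0, 0.0],   # track (0, 1), observed early
     [0.0, 0.0, 1.4, 1.5]],  # track (0, 2), observed late
    index=index,
)
merges = np.array([[[0, 1], [0, 2]]])  # track (0, 2) continues track (0, 1)
apply_merges(data, merges)
# a single row indexed (0, 1) with values [1.1, 1.3, 1.4, 1.5]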
def group_merges(merges: np.ndarray) -> t.List[t.Tuple]:
# Return a list where the cell is present as source and target
# (multimerges)
sources_targets = compare_indices(merges[:, 0, :], merges[:, 1, :])
is_multimerge = sources_targets.any(axis=0) | sources_targets.any(axis=1)
is_monomerge = ~is_multimerge
multimerge_subsets = union_find(zip(*np.where(sources_targets)))
merge_groups = [merges[np.array(tuple(x))] for x in multimerge_subsets]
sorted_merges = list(map(sort_association, merge_groups))
# Ensure that source and target are at the edges ##################################################################
return [
*sorted_merges,
*[[event] for event in merges[is_monomerge]],
]
def union_find(lsts): def union_find(lsts):
...@@ -120,27 +191,3 @@ def sort_association(array: np.ndarray): ...@@ -120,27 +191,3 @@ def sort_association(array: np.ndarray):
[res.append(x) for x in np.flip(order).flatten() if x not in res] [res.append(x) for x in np.flip(order).flatten() if x not in res]
sorted_array = array[np.array(res)] sorted_array = array[np.array(res)]
return sorted_array return sorted_array
def merge_association(
association: np.ndarray, merges: np.ndarray
) -> np.ndarray:
grouped_merges = group_merges(merges)
flat_indices = association.reshape(-1, 2)
comparison_mat = compare_indices(merges[:, 0], flat_indices)
valid_indices = comparison_mat.any(axis=0)
if valid_indices.any(): # Where valid, perform transformation
replacement_d = {}
for dataset in grouped_merges:
for k in dataset:
replacement_d[tuple(k[0])] = dataset[-1][1]
flat_indices[valid_indices] = [
replacement_d[tuple(i)] for i in flat_indices[valid_indices]
]
merged_indices = flat_indices.reshape(-1, 2, 2)
return merged_indices
# parameters for stopping the pipeline early when thresholds are exceeded
earlystop = dict(
min_tp=100,
thresh_pos_clogged=0.4,
thresh_trap_ncells=8,
thresh_trap_area=0.9,
ntps_to_eval=5,
)
# imaging properties of the microscope
imaging_specifications = {
"pixel_size": 0.236,
"z_size": 0.6,
"spacing": 0.6,
}
# possible imaging channels
possible_imaging_channels = [
"Citrine",
"GFP",
"GFPFast",
"mCherry",
"Flavin",
"Citrine",
"mKO2",
"Cy5",
"pHluorin405",
"pHluorin488",
]
# functions to apply to the fluorescence of each cell
fluorescence_functions = [
"mean",
"median",
"std",
"imBackground",
"max5px_median",
]
# default fraction of time a cell must be in the experiment to be kept by Signal
signal_retained_cutoff = 0.8
...@@ -54,7 +54,7 @@ class DatasetLocalABC(ABC): ...@@ -54,7 +54,7 @@ class DatasetLocalABC(ABC):
Abstract Base class to find local files, either OME-XML or raw images. Abstract Base class to find local files, either OME-XML or raw images.
""" """
_valid_suffixes = ("tiff", "png", "zarr") _valid_suffixes = ("tiff", "png", "zarr", "tif")
_valid_meta_suffixes = ("txt", "log") _valid_meta_suffixes = ("txt", "log")
def __init__(self, dpath: t.Union[str, Path], *args, **kwargs): def __init__(self, dpath: t.Union[str, Path], *args, **kwargs):
......
...@@ -30,14 +30,14 @@ from agora.io.metadata import dir_to_meta, dispatch_metadata_parser ...@@ -30,14 +30,14 @@ from agora.io.metadata import dir_to_meta, dispatch_metadata_parser
def get_examples_dir(): def get_examples_dir():
"""Get examples directory which stores dummy image for tiler""" """Get examples directory that stores dummy image for tiler."""
return files("aliby").parent.parent / "examples" / "tiler" return files("aliby").parent.parent / "examples" / "tiler"
def instantiate_image( def instantiate_image(
source: t.Union[str, int, t.Dict[str, str], Path], **kwargs source: t.Union[str, int, t.Dict[str, str], Path], **kwargs
): ):
"""Wrapper to instatiate the appropiate image """Wrapper to instantiate the appropriate image
Parameters Parameters
---------- ----------
...@@ -55,26 +55,26 @@ def instantiate_image( ...@@ -55,26 +55,26 @@ def instantiate_image(
def dispatch_image(source: t.Union[str, int, t.Dict[str, str], Path]): def dispatch_image(source: t.Union[str, int, t.Dict[str, str], Path]):
""" """Pick the appropriate Image class depending on the source of data."""
Wrapper to pick the appropiate Image class depending on the source of data.
"""
if isinstance(source, (int, np.int64)): if isinstance(source, (int, np.int64)):
from aliby.io.omero import Image from aliby.io.omero import Image
instatiator = Image instantiator = Image
elif isinstance(source, dict) or ( elif isinstance(source, dict) or (
isinstance(source, (str, Path)) and Path(source).is_dir() isinstance(source, (str, Path)) and Path(source).is_dir()
): ):
if Path(source).suffix == ".zarr": if Path(source).suffix == ".zarr":
instatiator = ImageZarr instantiator = ImageZarr
else: else:
instatiator = ImageDir instantiator = ImageDir
elif isinstance(source, Path) and source.is_file():
# a Path pointing to a file is treated as a local OME-TIFF
instantiator = ImageLocalOME
elif isinstance(source, str) and Path(source).is_file(): elif isinstance(source, str) and Path(source).is_file():
instatiator = ImageLocalOME instantiator = ImageLocalOME
else: else:
raise Exception(f"Invalid data source at {source}") raise Exception(f"Invalid data source at {source}")
return instantiator
return instatiator
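A hedged usage sketch (the path is hypothetical and assumed to exist as a local zarr store): dispatch_image returns a class, which is then instantiated and used as a context manager, as the pipeline does further below.
# hypothetical local source; an integer id would instead select the OMERO Image class
source = "/data/experiment/position_001.zarr"
ImageClass = dispatch_image(source)
with ImageClass(source) as image:
    print(image.name)
    stack = image.data  # 5D dask array ordered (t, c, z, y, x)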
class BaseLocalImage(ABC): class BaseLocalImage(ABC):
...@@ -82,6 +82,7 @@ class BaseLocalImage(ABC): ...@@ -82,6 +82,7 @@ class BaseLocalImage(ABC):
Base Image class to set path and provide context management method. Base Image class to set path and provide context management method.
""" """
# default image order
_default_dimorder = "tczyx" _default_dimorder = "tczyx"
def __init__(self, path: t.Union[str, Path]): def __init__(self, path: t.Union[str, Path]):
...@@ -98,8 +99,7 @@ class BaseLocalImage(ABC): ...@@ -98,8 +99,7 @@ class BaseLocalImage(ABC):
return False return False
def rechunk_data(self, img): def rechunk_data(self, img):
# Format image using x and y size from metadata. """Format image using x and y size from metadata."""
self._rechunked_img = da.rechunk( self._rechunked_img = da.rechunk(
img, img,
chunks=( chunks=(
...@@ -145,16 +145,16 @@ class ImageLocalOME(BaseLocalImage): ...@@ -145,16 +145,16 @@ class ImageLocalOME(BaseLocalImage):
in which a multidimensional tiff image contains the metadata. in which a multidimensional tiff image contains the metadata.
""" """
def __init__(self, path: str, dimorder=None): def __init__(self, path: str, dimorder=None, **kwargs):
super().__init__(path) super().__init__(path)
self._id = str(path) self._id = str(path)
self.set_meta(str(path))
def set_meta(self): def set_meta(self, path):
meta = dict() meta = dict()
try: try:
with TiffFile(path) as f: with TiffFile(path) as f:
self._meta = xmltodict.parse(f.ome_metadata)["OME"] self._meta = xmltodict.parse(f.ome_metadata)["OME"]
for dim in self.dimorder: for dim in self.dimorder:
meta["size_" + dim.lower()] = int( meta["size_" + dim.lower()] = int(
self._meta["Image"]["Pixels"]["@Size" + dim] self._meta["Image"]["Pixels"]["@Size" + dim]
...@@ -165,21 +165,19 @@ class ImageLocalOME(BaseLocalImage): ...@@ -165,21 +165,19 @@ class ImageLocalOME(BaseLocalImage):
] ]
meta["name"] = self._meta["Image"]["@Name"] meta["name"] = self._meta["Image"]["@Name"]
meta["type"] = self._meta["Image"]["Pixels"]["@Type"] meta["type"] = self._meta["Image"]["Pixels"]["@Type"]
except Exception as e:
except Exception as e: # Images not in OMEXML # images not in OMEXML
print("Warning:Metadata not found: {}".format(e)) print("Warning:Metadata not found: {}".format(e))
print( print(
f"Warning: No dimensional info provided. Assuming {self._default_dimorder}" "Warning: No dimensional info provided. "
f"Assuming {self._default_dimorder}"
) )
# mark non-existent dimensions for padding
# Mark non-existent dimensions for padding
self.base = self._default_dimorder self.base = self._default_dimorder
# self.ids = [self.index(i) for i in dimorder] # self.ids = [self.index(i) for i in dimorder]
self._dimorder = self.base
self._dimorder = base
self._meta = meta self._meta = meta
# self._meta["name"] = Path(path).name.split(".")[0]
@property @property
def name(self): def name(self):
...@@ -246,7 +244,7 @@ class ImageDir(BaseLocalImage): ...@@ -246,7 +244,7 @@ class ImageDir(BaseLocalImage):
It inherits from BaseLocalImage so we only override methods that are critical. It inherits from BaseLocalImage so we only override methods that are critical.
Assumptions: Assumptions:
- One folders per position. - One folder per position.
- Images are flat. - Images are flat.
- Channel, Time, z-stack and the others are determined by filenames. - Channel, Time, z-stack and the others are determined by filenames.
- Provides Dimorder as it is set in the filenames, or expects order during instantiation - Provides Dimorder as it is set in the filenames, or expects order during instantiation
...@@ -318,7 +316,7 @@ class ImageZarr(BaseLocalImage): ...@@ -318,7 +316,7 @@ class ImageZarr(BaseLocalImage):
print(f"Could not add size info to metadata: {e}") print(f"Could not add size info to metadata: {e}")
def get_data_lazy(self) -> da.Array: def get_data_lazy(self) -> da.Array:
"""Return 5D dask array. For lazy-loading local multidimensional zarr files""" """Return 5D dask array for lazy-loading local multidimensional zarr files."""
return self._img return self._img
def add_size_to_meta(self): def add_size_to_meta(self):
......
...@@ -131,7 +131,6 @@ class BridgeOmero: ...@@ -131,7 +131,6 @@ class BridgeOmero:
FIXME: Add docs. FIXME: Add docs.
""" """
# metadata = load_attributes(filepath)
bridge = BridgeH5(filepath) bridge = BridgeH5(filepath)
meta = safe_load(bridge.meta_h5["parameters"])["general"] meta = safe_load(bridge.meta_h5["parameters"])["general"]
server_info = {k: meta[k] for k in ("host", "username", "password")} server_info = {k: meta[k] for k in ("host", "username", "password")}
...@@ -268,7 +267,6 @@ class Dataset(BridgeOmero): ...@@ -268,7 +267,6 @@ class Dataset(BridgeOmero):
FIXME: Add docs. FIXME: Add docs.
""" """
# metadata = load_attributes(filepath)
bridge = BridgeH5(filepath) bridge = BridgeH5(filepath)
dataset_keys = ("omero_id", "omero_id,", "dataset_id") dataset_keys = ("omero_id", "omero_id,", "dataset_id")
for k in dataset_keys: for k in dataset_keys:
...@@ -301,21 +299,21 @@ class Image(BridgeOmero): ...@@ -301,21 +299,21 @@ class Image(BridgeOmero):
cls, cls,
filepath: t.Union[str, Path], filepath: t.Union[str, Path],
): ):
"""Instatiate Image from a hdf5 file. """
Instantiate Image from a h5 file.
Parameters Parameters
---------- ----------
cls : Image cls : Image
Image class Image class
filepath : t.Union[str, Path] filepath : t.Union[str, Path]
Location of hdf5 file. Location of h5 file.
Examples Examples
-------- --------
FIXME: Add docs. FIXME: Add docs.
""" """
# metadata = load_attributes(filepath)
bridge = BridgeH5(filepath) bridge = BridgeH5(filepath)
image_id = bridge.meta_h5["image_id"] image_id = bridge.meta_h5["image_id"]
return cls(image_id, **cls.server_info_from_h5(filepath)) return cls(image_id, **cls.server_info_from_h5(filepath))
......
...@@ -7,22 +7,19 @@ import typing as t ...@@ -7,22 +7,19 @@ import typing as t
from copy import copy from copy import copy
from importlib.metadata import version from importlib.metadata import version
from pathlib import Path from pathlib import Path
from pprint import pprint
import h5py import h5py
import numpy as np import numpy as np
import pandas as pd
from pathos.multiprocessing import Pool from pathos.multiprocessing import Pool
from tqdm import tqdm from tqdm import tqdm
import aliby.global_parameters as global_parameters
from agora.abc import ParametersABC, ProcessABC from agora.abc import ParametersABC, ProcessABC
from agora.io.metadata import MetaData, parse_logfiles from agora.io.metadata import MetaData, parse_logfiles
from agora.io.reader import StateReader from agora.io.reader import StateReader
from agora.io.signal import Signal from agora.io.signal import Signal
from agora.io.writer import ( from agora.io.writer import LinearBabyWriter, StateWriter, TilerWriter
LinearBabyWriter,
StateWriter,
TilerWriter,
)
from aliby.baby_client import BabyParameters, BabyRunner from aliby.baby_client import BabyParameters, BabyRunner
from aliby.haystack import initialise_tf from aliby.haystack import initialise_tf
from aliby.io.dataset import dispatch_dataset from aliby.io.dataset import dispatch_dataset
...@@ -32,6 +29,10 @@ from extraction.core.extractor import Extractor, ExtractorParameters ...@@ -32,6 +29,10 @@ from extraction.core.extractor import Extractor, ExtractorParameters
from extraction.core.functions.defaults import exparams_from_meta from extraction.core.functions.defaults import exparams_from_meta
from postprocessor.core.processor import PostProcessor, PostProcessorParameters from postprocessor.core.processor import PostProcessor, PostProcessorParameters
# stop warnings from TensorFlow
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "2"
logging.getLogger("tensorflow").setLevel(logging.ERROR)
class PipelineParameters(ParametersABC): class PipelineParameters(ParametersABC):
"""Define parameters for the steps of the pipeline.""" """Define parameters for the steps of the pipeline."""
...@@ -39,7 +40,12 @@ class PipelineParameters(ParametersABC): ...@@ -39,7 +40,12 @@ class PipelineParameters(ParametersABC):
_pool_index = None _pool_index = None
def __init__( def __init__(
self, general, tiler, baby, extraction, postprocessing, reporting self,
general,
tiler,
baby,
extraction,
postprocessing,
): ):
"""Initialise, but called by a class method - not directly.""" """Initialise, but called by a class method - not directly."""
self.general = general self.general = general
...@@ -47,7 +53,6 @@ class PipelineParameters(ParametersABC): ...@@ -47,7 +53,6 @@ class PipelineParameters(ParametersABC):
self.baby = baby self.baby = baby
self.extraction = extraction self.extraction = extraction
self.postprocessing = postprocessing self.postprocessing = postprocessing
self.reporting = reporting
@classmethod @classmethod
def default( def default(
...@@ -76,16 +81,15 @@ class PipelineParameters(ParametersABC): ...@@ -76,16 +81,15 @@ class PipelineParameters(ParametersABC):
postprocessing: dict (optional) postprocessing: dict (optional)
Parameters for post-processing. Parameters for post-processing.
""" """
expt_id = general.get("expt_id", 19993) if (
if isinstance(expt_id, Path): isinstance(general["expt_id"], Path)
assert expt_id.exists() and general["expt_id"].exists()
):
expt_id = str(expt_id) expt_id = str(general["expt_id"])
general["expt_id"] = expt_id else:
expt_id = general["expt_id"]
directory = Path(general["directory"]) directory = Path(general["directory"])
# get metadata from log files either locally or via OMERO
# get log files, either locally or via OMERO
with dispatch_dataset( with dispatch_dataset(
expt_id, expt_id,
**{k: general.get(k) for k in ("host", "username", "password")}, **{k: general.get(k) for k in ("host", "username", "password")},
...@@ -107,7 +111,6 @@ class PipelineParameters(ParametersABC): ...@@ -107,7 +111,6 @@ class PipelineParameters(ParametersABC):
} }
# set minimal metadata # set minimal metadata
meta_d = minimal_default_meta meta_d = minimal_default_meta
# define default values for general parameters # define default values for general parameters
tps = meta_d.get("ntps", 2000) tps = meta_d.get("ntps", 2000)
defaults = { defaults = {
...@@ -117,19 +120,12 @@ class PipelineParameters(ParametersABC): ...@@ -117,19 +120,12 @@ class PipelineParameters(ParametersABC):
tps=tps, tps=tps,
directory=str(directory.parent), directory=str(directory.parent),
filter="", filter="",
earlystop=dict( earlystop=global_parameters.earlystop,
min_tp=100,
thresh_pos_clogged=0.4,
thresh_trap_ncells=8,
thresh_trap_area=0.9,
ntps_to_eval=5,
),
logfile_level="INFO", logfile_level="INFO",
use_explog=True, use_explog=True,
) )
} }
# update default values for general using inputs
# update default values using inputs
for k, v in general.items(): for k, v in general.items():
if k not in defaults["general"]: if k not in defaults["general"]:
defaults["general"][k] = v defaults["general"][k] = v
...@@ -138,11 +134,9 @@ class PipelineParameters(ParametersABC): ...@@ -138,11 +134,9 @@ class PipelineParameters(ParametersABC):
defaults["general"][k][k2] = v2 defaults["general"][k][k2] = v2
else: else:
defaults["general"][k] = v defaults["general"][k] = v
# default Tiler parameters
# define defaults and update with any inputs
defaults["tiler"] = TilerParameters.default(**tiler).to_dict() defaults["tiler"] = TilerParameters.default(**tiler).to_dict()
# generate a backup channel for when logfile meta is available
# Generate a backup channel, for when logfile meta is available
# but not image metadata. # but not image metadata.
backup_ref_channel = None backup_ref_channel = None
if "channels" in meta_d and isinstance( if "channels" in meta_d and isinstance(
...@@ -152,20 +146,18 @@ class PipelineParameters(ParametersABC): ...@@ -152,20 +146,18 @@ class PipelineParameters(ParametersABC):
defaults["tiler"]["ref_channel"] defaults["tiler"]["ref_channel"]
) )
defaults["tiler"]["backup_ref_channel"] = backup_ref_channel defaults["tiler"]["backup_ref_channel"] = backup_ref_channel
# default BABY parameters
defaults["baby"] = BabyParameters.default(**baby).to_dict() defaults["baby"] = BabyParameters.default(**baby).to_dict()
defaults["extraction"] = ( # default Extraction parmeters
exparams_from_meta(meta_d) defaults["extraction"] = exparams_from_meta(meta_d)
or BabyParameters.default(**extraction).to_dict() # default PostProcessing parameters
)
defaults["postprocessing"] = PostProcessorParameters.default( defaults["postprocessing"] = PostProcessorParameters.default(
**postprocessing **postprocessing
).to_dict() ).to_dict()
defaults["reporting"] = {}
return cls(**{k: v for k, v in defaults.items()}) return cls(**{k: v for k, v in defaults.items()})
def load_logs(self): def load_logs(self):
"""Load and parse log files."""
parsed_flattened = parse_logfiles(self.log_dir) parsed_flattened = parse_logfiles(self.log_dir)
return parsed_flattened return parsed_flattened
...@@ -187,7 +179,7 @@ class Pipeline(ProcessABC): ...@@ -187,7 +179,7 @@ class Pipeline(ProcessABC):
"postprocessing", "postprocessing",
] ]
# Specify the group in the h5 files written by each step # specify the group in the h5 files written by each step
writer_groups = { writer_groups = {
"tiler": ["trap_info"], "tiler": ["trap_info"],
"baby": ["cell_info"], "baby": ["cell_info"],
...@@ -228,12 +220,6 @@ class Pipeline(ProcessABC): ...@@ -228,12 +220,6 @@ class Pipeline(ProcessABC):
fh.setFormatter(formatter) fh.setFormatter(formatter)
logger.addHandler(fh) logger.addHandler(fh)
@classmethod
def from_yaml(cls, fpath):
# This is just a convenience function, think before implementing
# for other processes
return cls(parameters=PipelineParameters.from_yaml(fpath))
@classmethod @classmethod
def from_folder(cls, dir_path): def from_folder(cls, dir_path):
""" """
...@@ -304,11 +290,16 @@ class Pipeline(ProcessABC): ...@@ -304,11 +290,16 @@ class Pipeline(ProcessABC):
def run(self): def run(self):
"""Run separate pipelines for all positions in an experiment.""" """Run separate pipelines for all positions in an experiment."""
# general information in config # display configuration
config = self.parameters.to_dict() config = self.parameters.to_dict()
for step in config:
print("\n---\n" + step + "\n---")
pprint(config[step])
print()
# extract from configuration
expt_id = config["general"]["id"] expt_id = config["general"]["id"]
distributed = config["general"]["distributed"] distributed = config["general"]["distributed"]
pos_filter = config["general"]["filter"] position_filter = config["general"]["filter"]
root_dir = Path(config["general"]["directory"]) root_dir = Path(config["general"]["directory"])
self.server_info = { self.server_info = {
k: config["general"].get(k) k: config["general"].get(k)
...@@ -320,56 +311,76 @@ class Pipeline(ProcessABC): ...@@ -320,56 +311,76 @@ class Pipeline(ProcessABC):
) )
# get log files, either locally or via OMERO # get log files, either locally or via OMERO
with dispatcher as conn: with dispatcher as conn:
image_ids = conn.get_images() position_ids = conn.get_images()
directory = self.store or root_dir / conn.unique_name directory = self.store or root_dir / conn.unique_name
if not directory.exists(): if not directory.exists():
directory.mkdir(parents=True) directory.mkdir(parents=True)
# download logs to use for metadata # get logs to use for metadata
conn.cache_logs(directory) conn.cache_logs(directory)
print("Positions available:")
for i, pos in enumerate(position_ids.keys()):
print("\t" + f"{i}: " + pos.split(".")[0])
# update configuration # update configuration
self.parameters.general["directory"] = str(directory) self.parameters.general["directory"] = str(directory)
config["general"]["directory"] = directory config["general"]["directory"] = directory
self.setLogger(directory) self.setLogger(directory)
# pick particular images if desired # pick particular positions if desired
if pos_filter is not None: if position_filter is not None:
if isinstance(pos_filter, list): if isinstance(position_filter, list):
image_ids = { position_ids = {
k: v k: v
for filt in pos_filter for filt in position_filter
for k, v in self.apply_filter(image_ids, filt).items() for k, v in self.apply_filter(position_ids, filt).items()
} }
else: else:
image_ids = self.apply_filter(image_ids, pos_filter) position_ids = self.apply_filter(position_ids, position_filter)
assert len(image_ids), "No images to segment" if not len(position_ids):
# create pipelines raise Exception("No images to segment.")
else:
print("\nPositions selected:")
for pos in position_ids:
print("\t" + pos.split(".")[0])
# create and run pipelines
if distributed != 0: if distributed != 0:
# multiple cores # multiple cores
with Pool(distributed) as p: with Pool(distributed) as p:
results = p.map( results = p.map(
lambda x: self.run_one_position(*x), lambda x: self.run_one_position(*x),
[(k, i) for i, k in enumerate(image_ids.items())], [
(position_id, i)
for i, position_id in enumerate(position_ids.items())
],
) )
else: else:
# single core # single core
results = [] results = [
for k, v in tqdm(image_ids.items()): self.run_one_position((position_id, position_id_path), 1)
r = self.run_one_position((k, v), 1) for position_id, position_id_path in tqdm(position_ids.items())
results.append(r) ]
return results return results
def apply_filter(self, image_ids: dict, filt: int or str): def apply_filter(self, position_ids: dict, position_filter: int or str):
"""Select images by picking a particular one or by using a regular expression to parse their file names.""" """
if isinstance(filt, str): Select positions.
# pick images using a regular expression
image_ids = { Either pick a particular position or use a regular expression
k: v for k, v in image_ids.items() if re.search(filt, k) to parse their file names.
"""
if isinstance(position_filter, str):
# pick positions using a regular expression
position_ids = {
k: v
for k, v in position_ids.items()
if re.search(position_filter, k)
} }
elif isinstance(filt, int): elif isinstance(position_filter, int):
# pick the filt'th image # pick a particular position
image_ids = { position_ids = {
k: v for i, (k, v) in enumerate(image_ids.items()) if i == filt k: v
for i, (k, v) in enumerate(position_ids.items())
if i == position_filter
} }
return image_ids return position_ids
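A short usage illustration (the position names and the `pipeline` instance are invented): a regular expression keeps all matching positions, while an integer keeps only the position at that index.
position_ids = {"pos001.tiff": 11, "pos002.tiff": 12, "control_01.tiff": 13}
# regular expression: keep the two "pos00..." positions
pipeline.apply_filter(position_ids, "pos00")
# integer: keep only the third position, "control_01.tiff"
pipeline.apply_filter(position_ids, 2)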
def run_one_position( def run_one_position(
self, self,
...@@ -379,120 +390,101 @@ class Pipeline(ProcessABC): ...@@ -379,120 +390,101 @@ class Pipeline(ProcessABC):
"""Set up and run a pipeline for one position.""" """Set up and run a pipeline for one position."""
self._pool_index = index self._pool_index = index
name, image_id = name_image_id name, image_id = name_image_id
# session and filename are defined by calling setup_pipeline. # session is defined by calling setup_pipeline.
# can they be deleted here? # can it be deleted here?
session = None session = None
filename = None run_kwargs = {"extraction": {"cell_labels": None, "masks": None}}
#
run_kwargs = {"extraction": {"labels": None, "masks": None}}
try: try:
( pipe, session = self.setup_pipeline(image_id, name)
filename,
meta,
config,
process_from,
tps,
steps,
earlystop,
session,
trackers_state,
) = self._setup_pipeline(image_id)
loaded_writers = { loaded_writers = {
name: writer(filename) name: writer(pipe["filename"])
for k in self.step_sequence for k in self.step_sequence
if k in self.writers if k in self.writers
for name, writer in self.writers[k] for name, writer in self.writers[k]
} }
writer_ow_kwargs = { writer_overwrite_kwargs = {
"state": loaded_writers["state"].datatypes.keys(), "state": loaded_writers["state"].datatypes.keys(),
"baby": ["mother_assign"], "baby": ["mother_assign"],
} }
# START PIPELINE # START PIPELINE
frac_clogged_traps = 0.0 frac_clogged_traps = 0.0
min_process_from = min(process_from.values()) min_process_from = min(pipe["process_from"].values())
with dispatch_image(image_id)( with dispatch_image(image_id)(
image_id, **self.server_info image_id, **self.server_info
) as image: ) as image:
# initialise steps # initialise steps
if "tiler" not in steps: if "tiler" not in pipe["steps"]:
steps["tiler"] = Tiler.from_image( pipe["config"]["tiler"]["position_name"] = name.split(".")[
image, TilerParameters.from_dict(config["tiler"]) 0
]
pipe["steps"]["tiler"] = Tiler.from_image(
image,
TilerParameters.from_dict(pipe["config"]["tiler"]),
) )
if process_from["baby"] < tps: if pipe["process_from"]["baby"] < pipe["tps"]:
session = initialise_tf(2) session = initialise_tf(2)
steps["baby"] = BabyRunner.from_tiler( pipe["steps"]["baby"] = BabyRunner.from_tiler(
BabyParameters.from_dict(config["baby"]), BabyParameters.from_dict(pipe["config"]["baby"]),
steps["tiler"], pipe["steps"]["tiler"],
)
if trackers_state:
steps["baby"].crawler.tracker_states = trackers_state
# limit extraction parameters using the available channels in tiler
if process_from["extraction"] < tps:
# TODO Move this parameter validation into Extractor
av_channels = set((*steps["tiler"].channels, "general"))
config["extraction"]["tree"] = {
k: v
for k, v in config["extraction"]["tree"].items()
if k in av_channels
}
config["extraction"]["sub_bg"] = av_channels.intersection(
config["extraction"]["sub_bg"]
)
av_channels_wsub = av_channels.union(
[c + "_bgsub" for c in config["extraction"]["sub_bg"]]
) )
tmp = copy(config["extraction"]["multichannel_ops"]) if pipe["trackers_state"]:
for op, (input_ch, _, _) in tmp.items(): pipe["steps"]["baby"].crawler.tracker_states = pipe[
if not set(input_ch).issubset(av_channels_wsub): "trackers_state"
del config["extraction"]["multichannel_ops"][op] ]
if pipe["process_from"]["extraction"] < pipe["tps"]:
exparams = ExtractorParameters.from_dict( exparams = ExtractorParameters.from_dict(
config["extraction"] pipe["config"]["extraction"]
) )
steps["extraction"] = Extractor.from_tiler( pipe["steps"]["extraction"] = Extractor.from_tiler(
exparams, store=filename, tiler=steps["tiler"] exparams,
store=pipe["filename"],
tiler=pipe["steps"]["tiler"],
) )
# set up progress meter # initiate progress bar
pbar = tqdm( pbar = tqdm(
range(min_process_from, tps), range(min_process_from, pipe["tps"]),
desc=image.name, desc=image.name,
initial=min_process_from, initial=min_process_from,
total=tps, total=pipe["tps"],
) )
# run through time points
for i in pbar: for i in pbar:
if ( if (
frac_clogged_traps frac_clogged_traps
< earlystop["thresh_pos_clogged"] < pipe["earlystop"]["thresh_pos_clogged"]
or i < earlystop["min_tp"] or i < pipe["earlystop"]["min_tp"]
): ):
# run through steps # run through steps
for step in self.pipeline_steps: for step in self.pipeline_steps:
if i >= process_from[step]: if i >= pipe["process_from"][step]:
result = steps[step].run_tp( # perform step
result = pipe["steps"][step].run_tp(
i, **run_kwargs.get(step, {}) i, **run_kwargs.get(step, {})
) )
# write to h5 file using writers
# extractor writes to h5 itself
if step in loaded_writers: if step in loaded_writers:
loaded_writers[step].write( loaded_writers[step].write(
data=result, data=result,
overwrite=writer_ow_kwargs.get( overwrite=writer_overwrite_kwargs.get(
step, [] step, []
), ),
tp=i, tp=i,
meta={"last_processed": i}, meta={"last_processed": i},
) )
# perform step # clean up
if ( if (
step == "tiler" step == "tiler"
and i == min_process_from and i == min_process_from
): ):
logging.getLogger("aliby").info( logging.getLogger("aliby").info(
f"Found {steps['tiler'].n_tiles} traps in {image.name}" f"Found {pipe['steps']['tiler'].no_tiles} traps in {image.name}"
) )
elif step == "baby": elif step == "baby":
# write state and pass info to Extractor # write state
loaded_writers["state"].write( loaded_writers["state"].write(
data=steps[ data=pipe["steps"][
step step
].crawler.tracker_states, ].crawler.tracker_states,
overwrite=loaded_writers[ overwrite=loaded_writers[
...@@ -501,12 +493,14 @@ class Pipeline(ProcessABC): ...@@ -501,12 +493,14 @@ class Pipeline(ProcessABC):
tp=i, tp=i,
) )
elif step == "extraction": elif step == "extraction":
# remove mask/label after extraction # remove masks and labels after extraction
for k in ["masks", "labels"]: for k in ["masks", "cell_labels"]:
run_kwargs[step][k] = None run_kwargs[step][k] = None
# check and report clogging # check and report clogging
frac_clogged_traps = self.check_earlystop( frac_clogged_traps = self.check_earlystop(
filename, earlystop, steps["tiler"].tile_size pipe["filename"],
pipe["earlystop"],
pipe["steps"]["tiler"].tile_size,
) )
if frac_clogged_traps > 0.3: if frac_clogged_traps > 0.3:
self._log( self._log(
...@@ -519,18 +513,17 @@ class Pipeline(ProcessABC): ...@@ -519,18 +513,17 @@ class Pipeline(ProcessABC):
self._log( self._log(
f"{name}:Stopped early at time {i} with {frac_clogged_traps} clogged traps" f"{name}:Stopped early at time {i} with {frac_clogged_traps} clogged traps"
) )
meta.add_fields({"end_status": "Clogged"}) pipe["meta"].add_fields({"end_status": "Clogged"})
break break
meta.add_fields({"last_processed": i}) pipe["meta"].add_fields({"last_processed": i})
pipe["meta"].add_fields({"end_status": "Success"})
# run post-processing # run post-processing
meta.add_fields({"end_status": "Success"})
post_proc_params = PostProcessorParameters.from_dict( post_proc_params = PostProcessorParameters.from_dict(
config["postprocessing"] pipe["config"]["postprocessing"]
) )
PostProcessor(filename, post_proc_params).run() PostProcessor(pipe["filename"], post_proc_params).run()
self._log("Analysis finished successfully.", "info") self._log("Analysis finished successfully.", "info")
return 1 return 1
except Exception as e: except Exception as e:
# catch bugs during setup or run time # catch bugs during setup or run time
logging.exception( logging.exception(
...@@ -541,88 +534,12 @@ class Pipeline(ProcessABC): ...@@ -541,88 +534,12 @@ class Pipeline(ProcessABC):
traceback.print_exc() traceback.print_exc()
raise e raise e
finally: finally:
_close_session(session) close_session(session)
@staticmethod
def check_earlystop(filename: str, es_parameters: dict, tile_size: int):
"""
Check recent time points for tiles with too many cells.
Returns the fraction of clogged tiles, where clogged tiles have
too many cells or too much of their area covered by cells.
Parameters
----------
filename: str
Name of h5 file.
es_parameters: dict
Parameters defining when early stopping should happen.
For example:
{'min_tp': 100,
'thresh_pos_clogged': 0.4,
'thresh_trap_ncells': 8,
'thresh_trap_area': 0.9,
'ntps_to_eval': 5}
tile_size: int
Size of tile.
"""
# get the area of the cells organised by trap and cell number
s = Signal(filename)
df = s.get_raw("/extraction/general/None/area")
# check the latest time points only
cells_used = df[
df.columns[-1 - es_parameters["ntps_to_eval"] : -1]
].dropna(how="all")
# find tiles with too many cells
traps_above_nthresh = (
cells_used.groupby("trap").count().apply(np.mean, axis=1)
> es_parameters["thresh_trap_ncells"]
)
# find tiles with cells covering too great a fraction of the tiles' area
traps_above_athresh = (
cells_used.groupby("trap").sum().apply(np.mean, axis=1)
/ tile_size**2
> es_parameters["thresh_trap_area"]
)
return (traps_above_nthresh & traps_above_athresh).mean()
# FIXME: Remove this functionality. It used to be for def setup_pipeline(
# older hdf5 file formats.
def _load_config_from_file(
self, self,
filename: Path, image_id: int,
process_from: t.Dict[str, int], name: str,
trackers_state: t.List,
overwrite: t.Dict[str, bool],
):
with h5py.File(filename, "r") as f:
for k in process_from.keys():
if not overwrite[k]:
process_from[k] = self.legacy_get_last_tp[k](f)
process_from[k] += 1
return process_from, trackers_state, overwrite
# FIXME: Remove this functionality. It used to be for
# older hdf5 file formats.
@staticmethod
def legacy_get_last_tp(step: str) -> t.Callable:
"""Get last time-point in different ways depending
on which step we are using
To support segmentation in aliby < v0.24
TODO Deprecate and replace with State method
"""
switch_case = {
"tiler": lambda f: f["trap_info/drifts"].shape[0] - 1,
"baby": lambda f: f["cell_info/timepoint"][-1],
"extraction": lambda f: f[
"extraction/general/None/area/timepoint"
][-1],
}
return switch_case[step]
def _setup_pipeline(
self, image_id: int
) -> t.Tuple[ ) -> t.Tuple[
Path, Path,
MetaData, MetaData,
...@@ -645,83 +562,100 @@ class Pipeline(ProcessABC): ...@@ -645,83 +562,100 @@ class Pipeline(ProcessABC):
Returns Returns
------- -------
filename: str pipe: dict
Path to a h5 file to write to. With keys
meta: object filename: str
agora.io.metadata.MetaData object Path to a h5 file to write to.
config: dict meta: object
Configuration parameters. agora.io.metadata.MetaData object
process_from: dict config: dict
Gives from which time point each step of the pipeline should start. Configuration parameters.
tps: int process_from: dict
Number of time points. Gives time points from which each step of the
steps: dict pipeline should start.
earlystop: dict tps: int
Parameters to check whether the pipeline should be stopped. Number of time points.
steps: dict
earlystop: dict
Parameters to check whether the pipeline should
be stopped.
trackers_state: list
States of any trackers from earlier runs.
session: None session: None
trackers_state: list
States of any trackers from earlier runs.
""" """
pipe = {}
config = self.parameters.to_dict() config = self.parameters.to_dict()
# TODO Alan: Verify if session must be passed # TODO Alan: Verify if session must be passed
session = None session = None
earlystop = config["general"].get("earlystop", None) pipe["earlystop"] = config["general"].get("earlystop", None)
process_from = {k: 0 for k in self.pipeline_steps} pipe["process_from"] = {k: 0 for k in self.pipeline_steps}
steps = {} pipe["steps"] = {}
# check overwriting # check overwriting
ow_id = config["general"].get("overwrite", 0) overwrite_id = config["general"].get("overwrite", 0)
ow = {step: True for step in self.step_sequence} overwrite = {step: True for step in self.step_sequence}
if ow_id and ow_id is not True: if overwrite_id and overwrite_id is not True:
ow = { overwrite = {
step: self.step_sequence.index(ow_id) < i step: self.step_sequence.index(overwrite_id) < i
for i, step in enumerate(self.step_sequence, 1) for i, step in enumerate(self.step_sequence, 1)
} }
# set up
# Set up
directory = config["general"]["directory"] directory = config["general"]["directory"]
pipe["trackers_state"] = []
trackers_state: t.List[np.ndarray] = []
with dispatch_image(image_id)(image_id, **self.server_info) as image: with dispatch_image(image_id)(image_id, **self.server_info) as image:
filename = Path(f"{directory}/{image.name}.h5") pipe["filename"] = Path(f"{directory}/{image.name}.h5")
meta = MetaData(directory, filename) # load metadata from h5 file
from_start = True if np.any(ow.values()) else False pipe["meta"] = MetaData(directory, pipe["filename"])
# remove existing file if overwriting from_start = True if np.any(overwrite.values()) else False
# remove existing h5 file if overwriting
if ( if (
from_start from_start
and ( and (
config["general"].get("overwrite", False) config["general"].get("overwrite", False)
or np.all(list(ow.values())) or np.all(list(overwrite.values()))
) )
and filename.exists() and pipe["filename"].exists()
): ):
os.remove(filename) os.remove(pipe["filename"])
# if the file exists with no previous segmentation use its tiler # if the file exists with no previous segmentation use its tiler
if filename.exists(): if pipe["filename"].exists():
self._log("Result file exists.", "info") self._log("Result file exists.", "info")
if not ow["tiler"]: if not overwrite["tiler"]:
steps["tiler"] = Tiler.from_hdf5(image, filename) tiler_params_dict = TilerParameters.default().to_dict()
tiler_params_dict["position_name"] = name.split(".")[0]
tiler_params = TilerParameters.from_dict(tiler_params_dict)
pipe["steps"]["tiler"] = Tiler.from_h5(
image, pipe["filename"], tiler_params
)
try: try:
( (
process_from, process_from,
trackers_state, trackers_state,
ow, overwrite,
) = self._load_config_from_file( ) = self._load_config_from_file(
filename, process_from, trackers_state, ow pipe["filename"],
pipe["process_from"],
pipe["trackers_state"],
overwrite,
) )
# get state array # get state array
trackers_state = ( pipe["trackers_state"] = (
[] []
if ow["baby"] if overwrite["baby"]
else StateReader(filename).get_formatted_states() else StateReader(
pipe["filename"]
).get_formatted_states()
) )
config["tiler"] = steps["tiler"].parameters.to_dict() config["tiler"] = pipe["steps"][
"tiler"
].parameters.to_dict()
except Exception: except Exception:
self._log(f"Overwriting tiling data") self._log("Overwriting tiling data")
if config["general"]["use_explog"]: if config["general"]["use_explog"]:
meta.run() pipe["meta"].run()
pipe["config"] = config
# add metadata not in the log file # add metadata not in the log file
meta.add_fields( pipe["meta"].add_fields(
{ {
"aliby_version": version("aliby"), "aliby_version": version("aliby"),
"baby_version": version("aliby-baby"), "baby_version": version("aliby-baby"),
...@@ -734,20 +668,53 @@ class Pipeline(ProcessABC): ...@@ -734,20 +668,53 @@ class Pipeline(ProcessABC):
).to_yaml(), ).to_yaml(),
} }
) )
tps = min(config["general"]["tps"], image.data.shape[0]) pipe["tps"] = min(config["general"]["tps"], image.data.shape[0])
return ( return pipe, session
filename,
meta, @staticmethod
config, def check_earlystop(filename: str, es_parameters: dict, tile_size: int):
process_from, """
tps, Check recent time points for tiles with too many cells.
steps,
earlystop, Returns the fraction of clogged tiles, where clogged tiles have
session, too many cells or too much of their area covered by cells.
trackers_state,
) Parameters
----------
filename: str
Name of h5 file.
es_parameters: dict
Parameters defining when early stopping should happen.
For example:
{'min_tp': 100,
'thresh_pos_clogged': 0.4,
'thresh_trap_ncells': 8,
'thresh_trap_area': 0.9,
'ntps_to_eval': 5}
tile_size: int
Size of tile.
"""
# get the area of the cells organised by trap and cell number
s = Signal(filename)
df = s.get_raw("/extraction/general/None/area")
# check the latest time points only
cells_used = df[
df.columns[-1 - es_parameters["ntps_to_eval"] : -1]
].dropna(how="all")
# find tiles with too many cells
traps_above_nthresh = (
cells_used.groupby("trap").count().apply(np.mean, axis=1)
> es_parameters["thresh_trap_ncells"]
)
# find tiles with cells covering too great a fraction of the tiles' area
traps_above_athresh = (
cells_used.groupby("trap").sum().apply(np.mean, axis=1)
/ tile_size**2
> es_parameters["thresh_trap_area"]
)
return (traps_above_nthresh & traps_above_athresh).mean()
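# A hedged usage sketch of check_earlystop: the h5 path, the current time
# point, and the decision variable below are illustrative assumptions, not
# the pipeline's exact control flow.
es_parameters = {
    "min_tp": 100,
    "thresh_pos_clogged": 0.4,
    "thresh_trap_ncells": 8,
    "thresh_trap_area": 0.9,
    "ntps_to_eval": 5,
}
tp = 150  # index of the time point just processed (assumed)
frac_clogged = Pipeline.check_earlystop(
    "results/pos001.h5", es_parameters, tile_size=117
)
# stop processing this position once enough tiles are clogged
stop_early = (
    tp >= es_parameters["min_tp"]
    and frac_clogged > es_parameters["thresh_pos_clogged"]
)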
def _close_session(session): def close_session(session):
if session: if session:
session.close() session.close()
""" """
Tiler: Divides images into smaller tiles. Tiler: Divides images into smaller tiles.
The tasks of the Tiler are selecting regions of interest, or tiles, of images - with one trap per tile, correcting for the drift of the microscope stage over time, and handling errors and bridging between the image data and Aliby’s image-processing steps. The Tiler selects regions of interest, or tiles, from
images - one trap per tile - corrects for drift of the microscope
stage over time, handles errors, and bridges between the image data
and Aliby’s image-processing steps.
Tiler subclasses deal with either network connections or local files. Tiler subclasses deal with either network connections or local files.
To find tiles, we use a two-step process: we analyse the bright-field image to produce the template of a trap, and we fit this template to the image to find the tiles' centres. To find tiles, we use a two-step process: we analyse the bright-field image
to produce the template of a trap, and we fit this template to the image to
find the tiles' centres.
We use texture-based segmentation (entropy) to split the image into foreground -- cells and traps -- and background, which we then identify with an Otsu filter. Two methods are used to produce a template trap from these regions: pick the trap with the smallest minor axis length and average over all validated traps. We use texture-based segmentation (entropy) to split the image into
foreground -- cells and traps -- and background, which we then identify with
an Otsu filter. Two methods are used to produce a template trap from these
regions: pick the trap with the smallest minor axis length and average over
all validated traps.
A peak-identifying algorithm recovers the x and y-axis location of traps in the original image, and we choose the approach to template that identifies the most tiles. A peak-identifying algorithm recovers the x and y-axis location of traps in
the original image, and we choose the approach to template that identifies
the most tiles.
The experiment is stored as an array with a standard indexing order of (Time, Channels, Z-stack, X, Y). The experiment is stored as an array with a standard indexing order of
(Time, Channels, Z-stack, X, Y).
""" """
import logging import logging
import re import re
...@@ -25,28 +37,24 @@ import h5py ...@@ -25,28 +37,24 @@ import h5py
import numpy as np import numpy as np
from skimage.registration import phase_cross_correlation from skimage.registration import phase_cross_correlation
import aliby.global_parameters as global_parameters
from agora.abc import ParametersABC, StepABC from agora.abc import ParametersABC, StepABC
from agora.io.writer import BridgeH5 from agora.io.writer import BridgeH5
from aliby.io.image import ImageDummy
from aliby.tile.traps import segment_traps from aliby.tile.traps import segment_traps
class Tile: class Tile:
""" """Store a tile's location and size."""
Store a tile's location and size.
Checks to see if the tile should be padded.
Can export the tile either in OMERO or numpy formats.
"""
def __init__(self, centre, parent, size, max_size): def __init__(self, centre, parent_class, size, max_size):
"""Initialise using a parent class."""
self.centre = centre self.centre = centre
self.parent = parent # used to access drifts self.parent_class = parent_class # used to access drifts
self.size = size self.size = size
self.half_size = size // 2 self.half_size = size // 2
self.max_size = max_size self.max_size = max_size
def at_time(self, tp: int) -> t.List[int]: def centre_at_time(self, tp: int) -> t.List[int]:
""" """
Return tile's centre by applying drifts. Return tile's centre by applying drifts.
...@@ -55,7 +63,7 @@ class Tile: ...@@ -55,7 +63,7 @@ class Tile:
tp: integer tp: integer
Index for the time point of interest. Index for the time point of interest.
""" """
drifts = self.parent.drifts drifts = self.parent_class.drifts
tile_centre = self.centre - np.sum(drifts[: tp + 1], axis=0) tile_centre = self.centre - np.sum(drifts[: tp + 1], axis=0)
return list(tile_centre.astype(int)) return list(tile_centre.astype(int))
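# Worked example of the drift correction above, with assumed numbers: if
# parent_class.drifts is [[1, 0], [0, 2], [3, 1]] and centre is [100, 50],
# centre_at_time(1) subtracts the sum of the first two drifts and returns
# [100 - 1, 50 - 2] = [99, 48].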
...@@ -74,15 +82,15 @@ class Tile: ...@@ -74,15 +82,15 @@ class Tile:
Returns Returns
------- -------
x: int x: int
x-coordinate of bottom left corner of tile x-coordinate of bottom left corner of tile.
y: int y: int
y-coordinate of bottom left corner of tile y-coordinate of bottom left corner of tile.
w: int w: int
Width of tile Width of tile.
h: int h: int
Height of tile Height of tile.
""" """
x, y = self.at_time(tp) x, y = self.centre_at_time(tp)
# tile bottom corner # tile bottom corner
x = int(x - self.half_size) x = int(x - self.half_size)
y = int(y - self.half_size) y = int(y - self.half_size)
...@@ -90,8 +98,7 @@ class Tile: ...@@ -90,8 +98,7 @@ class Tile:
def as_range(self, tp: int): def as_range(self, tp: int):
""" """
Return tile in a range format: two slice objects that can Return a horizontal and a vertical slice of a tile.
be used in arrays.
Parameters Parameters
---------- ----------
...@@ -117,6 +124,20 @@ class TileLocations: ...@@ -117,6 +124,20 @@ class TileLocations:
max_size: int = 1200, max_size: int = 1200,
drifts: np.array = None, drifts: np.array = None,
): ):
"""
Initialise tiles as an array of Tile objects.
Parameters
----------
initial_location: array
An array of tile centres.
tile_size: int
Length of one side of a square tile.
max_size: int, optional
Default is 1200.
drifts: array
An array of translations to correct drift of the microscope.
"""
if drifts is None: if drifts is None:
drifts = [] drifts = []
self.tile_size = tile_size self.tile_size = tile_size
...@@ -129,20 +150,21 @@ class TileLocations: ...@@ -129,20 +150,21 @@ class TileLocations:
self.drifts = drifts self.drifts = drifts
def __len__(self): def __len__(self):
"""Find number of tiles."""
return len(self.tiles) return len(self.tiles)
def __iter__(self): def __iter__(self):
"""Return the next tile from the list of tiles."""
yield from self.tiles yield from self.tiles
@property @property
def shape(self): def shape(self):
"""Return numbers of tiles and drifts.""" """Return the number of tiles and the number of drifts."""
return len(self.tiles), len(self.drifts) return len(self.tiles), len(self.drifts)
def to_dict(self, tp: int): def to_dict(self, tp: int):
""" """
Export initial locations, tile_size, max_size, and drifts Export initial locations, tile_size, max_size, and drifts as a dict.
as a dictionary.
Parameters Parameters
---------- ----------
...@@ -157,19 +179,22 @@ class TileLocations: ...@@ -157,19 +179,22 @@ class TileLocations:
res["drifts"] = np.expand_dims(self.drifts[tp], axis=0) res["drifts"] = np.expand_dims(self.drifts[tp], axis=0)
return res return res
def at_time(self, tp: int) -> np.ndarray: def centres_at_time(self, tp: int) -> np.ndarray:
"""Return an array of tile centres (x- and y-coords).""" """Return an array of tile centres (x- and y-coords)."""
return np.array([tile.at_time(tp) for tile in self.tiles]) return np.array([tile.centre_at_time(tp) for tile in self.tiles])
@classmethod @classmethod
def from_tiler_init( def from_tiler_init(
cls, initial_location, tile_size: int = None, max_size: int = 1200 cls,
initial_location,
tile_size: int = None,
max_size: int = 1200,
): ):
"""Instantiate from a Tiler.""" """Instantiate from a Tiler."""
return cls(initial_location, tile_size, max_size, drifts=[]) return cls(initial_location, tile_size, max_size, drifts=[])
@classmethod @classmethod
def read_hdf5(cls, file): def read_h5(cls, file):
"""Instantiate from a h5 file.""" """Instantiate from a h5 file."""
with h5py.File(file, "r") as hfile: with h5py.File(file, "r") as hfile:
tile_info = hfile["trap_info"] tile_info = hfile["trap_info"]
...@@ -183,30 +208,41 @@ class TileLocations: ...@@ -183,30 +208,41 @@ class TileLocations:
class TilerParameters(ParametersABC): class TilerParameters(ParametersABC):
""" """Define default values for tile size and the reference channels."""
tile_size: int
ref_channel: str or int
ref_z: int
backup_ref_channel int or None, if int indicates the index for reference channel. Used when image does not include metadata, ref_channel is a string and channel names are included in parsed logfiles.
"""
_defaults = { _defaults = {
"tile_size": 117, "tile_size": 117,
"ref_channel": "Brightfield", "ref_channel": "Brightfield",
"ref_z": 0, "ref_z": 0,
"backup_ref_channel": None, "backup_ref_channel": None,
"position_name": None,
}
def find_channels_by_position(meta):
"""Parse metadata to find the imaging channels used for each group."""
channels_dict = {
position_name: [] for position_name in meta["positions/posname"]
} }
imaging_channels = meta["channels"]
for i, position_name in enumerate(meta["positions/posname"]):
for imaging_channel in imaging_channels:
if (
"positions/" + imaging_channel in meta
and meta["positions/" + imaging_channel][i]
):
channels_dict[position_name].append(imaging_channel)
return channels_dict
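# A hedged illustration of the metadata layout find_channels_by_position
# expects; the field names mirror those used above, but the values are
# invented.
example_meta = {
    "channels": ["Brightfield", "GFP"],
    "positions/posname": ["pos001", "pos002"],
    "positions/Brightfield": [True, True],
    "positions/GFP": [True, False],
}
# find_channels_by_position(example_meta) then returns
# {"pos001": ["Brightfield", "GFP"], "pos002": ["Brightfield"]}.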
class Tiler(StepABC): class Tiler(StepABC):
""" """
Divide images into smaller tiles for faster processing. Divide images into smaller tiles for faster processing.
Finds tiles and re-registers images if they drift. Find tiles and re-register images if they drift.
Fetch images from an OMERO server if necessary. Fetch images from an OMERO server if necessary.
Uses an Image instance, which lazily provides the data on pixels, Uses an Image instance, which lazily provides the pixel data,
and, as an independent argument, metadata. and metadata, which is passed as an independent argument.
""" """
...@@ -215,7 +251,7 @@ class Tiler(StepABC): ...@@ -215,7 +251,7 @@ class Tiler(StepABC):
image: da.core.Array, image: da.core.Array,
metadata: dict, metadata: dict,
parameters: TilerParameters, parameters: TilerParameters,
tile_locs=None, tile_locations=None,
): ):
""" """
Initialise. Initialise.
...@@ -229,64 +265,25 @@ class Tiler(StepABC): ...@@ -229,64 +265,25 @@ class Tiler(StepABC):
""" """
super().__init__(parameters) super().__init__(parameters)
self.image = image self.image = image
self._metadata = metadata self.position_name = parameters.to_dict()["position_name"]
self.channels = metadata.get( # get channels for this position
"channels", channel_dict = find_channels_by_position(metadata)
self.channels = channel_dict.get(
self.position_name,
list(range(metadata.get("size_c", 0))), list(range(metadata.get("size_c", 0))),
) )
# get reference channel - used for segmentation
self.ref_channel = self.get_channel_index(parameters.ref_channel) self.ref_channel = self.get_channel_index(parameters.ref_channel)
if self.ref_channel is None: if self.ref_channel is None:
self.ref_channel = self.backup_ref_channel self.ref_channel = self.backup_ref_channel
self.tile_locs = tile_locations
self.ref_channel = self.get_channel_index(parameters.ref_channel) if "zsections" in metadata:
self.tile_locs = tile_locs
try:
self.z_perchannel = { self.z_perchannel = {
ch: zsect ch: zsect
for ch, zsect in zip(self.channels, metadata["zsections"]) for ch, zsect in zip(self.channels, metadata["zsections"])
} }
except Exception as e:
self._log(f"No z_perchannel data: {e}")
self.tile_size = self.tile_size or min(self.image.shape[-2:]) self.tile_size = self.tile_size or min(self.image.shape[-2:])
@classmethod
def dummy(cls, parameters: dict):
"""
Instantiate dummy Tiler from dummy image.
If image.dimorder exists dimensions are saved in that order.
Otherwise default to "tczyx".
Parameters
----------
parameters: dict
An instance of TilerParameters converted to a dict.
"""
imgdmy_obj = ImageDummy(parameters)
dummy_image = imgdmy_obj.get_data_lazy()
# default to "tczyx" if image.dimorder is None
dummy_omero_metadata = {
f"size_{dim}": dim_size
for dim, dim_size in zip(
imgdmy_obj.dimorder or "tczyx", dummy_image.shape
)
}
dummy_omero_metadata.update(
{
"channels": [
parameters["ref_channel"],
*(["nil"] * (dummy_omero_metadata["size_c"] - 1)),
],
"name": "",
}
)
return cls(
imgdmy_obj.data,
dummy_omero_metadata,
TilerParameters.from_dict(parameters),
)
@classmethod @classmethod
def from_image(cls, image, parameters: TilerParameters): def from_image(cls, image, parameters: TilerParameters):
""" """
...@@ -307,16 +304,16 @@ class Tiler(StepABC): ...@@ -307,16 +304,16 @@ class Tiler(StepABC):
parameters: t.Optional[TilerParameters] = None, parameters: t.Optional[TilerParameters] = None,
): ):
""" """
Instantiate from h5 files. Instantiate from an h5 file.
Parameters Parameters
---------- ----------
image: an instance of Image image: an instance of Image
filepath: Path instance filepath: Path instance
Path to a directory of h5 files Path to an h5 file.
parameters: an instance of TileParameters (optional) parameters: an instance of TileParameters (optional)
""" """
tile_locs = TileLocations.read_hdf5(filepath) tile_locs = TileLocations.read_h5(filepath)
metadata = BridgeH5(filepath).meta_h5 metadata = BridgeH5(filepath).meta_h5
metadata["channels"] = image.metadata["channels"] metadata["channels"] = image.metadata["channels"]
if parameters is None: if parameters is None:
...@@ -328,11 +325,11 @@ class Tiler(StepABC): ...@@ -328,11 +325,11 @@ class Tiler(StepABC):
tile_locs=tile_locs, tile_locs=tile_locs,
) )
if hasattr(tile_locs, "drifts"): if hasattr(tile_locs, "drifts"):
tiler.n_processed = len(tile_locs.drifts) tiler.no_processed = len(tile_locs.drifts)
return tiler return tiler
@lru_cache(maxsize=2) @lru_cache(maxsize=2)
def get_tc(self, t: int, c: int) -> np.ndarray: def load_image(self, tp: int, c: int) -> np.ndarray:
""" """
Load image using dask. Load image using dask.
...@@ -345,7 +342,7 @@ class Tiler(StepABC): ...@@ -345,7 +342,7 @@ class Tiler(StepABC):
Parameters Parameters
---------- ----------
t: integer tp: integer
An index for a time point An index for a time point
c: integer c: integer
An index for a channel An index for a channel
...@@ -354,32 +351,35 @@ class Tiler(StepABC): ...@@ -354,32 +351,35 @@ class Tiler(StepABC):
------- -------
full: an array of images full: an array of images
""" """
full = self.image[t, c] full = self.image[tp, c]
if hasattr(full, "compute"): # If using dask fetch images here if hasattr(full, "compute"):
# if using dask fetch images
full = full.compute(scheduler="synchronous") full = full.compute(scheduler="synchronous")
return full return full
@property @property
def shape(self): def shape(self):
""" """
Return properties of the time-lapse as shown by self.image.shape Return the shape of the image array.
The image array is arranged as number of images, number of channels,
number of z sections, and size of the image in y and x.
""" """
return self.image.shape return self.image.shape
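# For example, a position imaged for 100 time points in two channels with
# five z-sections at 1200 x 1200 pixels would have shape
# (100, 2, 5, 1200, 1200); the numbers are purely illustrative.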
@property @property
def n_processed(self): def no_processed(self):
"""Return the number of processed images.""" """Return the number of processed images."""
if not hasattr(self, "_n_processed"): if not hasattr(self, "_no_processed"):
self._n_processed = 0 self._no_processed = 0
return self._n_processed return self._no_processed
@n_processed.setter @no_processed.setter
def n_processed(self, value): def no_processed(self, value):
self._n_processed = value self._no_processed = value
@property @property
def n_tiles(self): def no_tiles(self):
"""Return number of tiles.""" """Return number of tiles."""
return len(self.tile_locs) return len(self.tile_locs)
...@@ -398,9 +398,8 @@ class Tiler(StepABC): ...@@ -398,9 +398,8 @@ class Tiler(StepABC):
initial_image = self.image[0, self.ref_channel, self.ref_z] initial_image = self.image[0, self.ref_channel, self.ref_z]
if tile_size: if tile_size:
half_tile = tile_size // 2 half_tile = tile_size // 2
# max_size is the minimal number of x or y pixels # max_size is the minimum of the numbers of x and y pixels
max_size = min(self.image.shape[-2:]) max_size = min(self.image.shape[-2:])
# first time point, reference channel, reference z-position
# find the tiles # find the tiles
tile_locs = segment_traps(initial_image, tile_size) tile_locs = segment_traps(initial_image, tile_size)
# keep only tiles that are not near an edge # keep only tiles that are not near an edge
...@@ -415,6 +414,7 @@ class Tiler(StepABC): ...@@ -415,6 +414,7 @@ class Tiler(StepABC):
tile_locs, tile_size tile_locs, tile_size
) )
else: else:
# one tile with its centre at the image's centre
yx_shape = self.image.shape[-2:] yx_shape = self.image.shape[-2:]
tile_locs = [[x // 2 for x in yx_shape]] tile_locs = [[x // 2 for x in yx_shape]]
self.tile_locs = TileLocations.from_tiler_init( self.tile_locs = TileLocations.from_tiler_init(
...@@ -423,8 +423,9 @@ class Tiler(StepABC): ...@@ -423,8 +423,9 @@ class Tiler(StepABC):
def find_drift(self, tp: int): def find_drift(self, tp: int):
""" """
Find any translational drift between two images at consecutive Find any translational drift between two images.
time points using cross correlation.
Use cross correlation between two consecutive images.
Arguments Arguments
--------- ---------
...@@ -445,7 +446,7 @@ class Tiler(StepABC): ...@@ -445,7 +446,7 @@ class Tiler(StepABC):
def get_tp_data(self, tp, c) -> np.ndarray: def get_tp_data(self, tp, c) -> np.ndarray:
""" """
Returns all tiles corrected for drift. Return all tiles corrected for drift.
Parameters Parameters
---------- ----------
...@@ -456,25 +457,24 @@ class Tiler(StepABC): ...@@ -456,25 +457,24 @@ class Tiler(StepABC):
Returns Returns
---------- ----------
Numpy ndarray of tiles with shape (tile, z, y, x) Numpy ndarray of tiles with shape (number of tiles, z-sections, y, x)
""" """
tiles = [] tiles = []
# get OMERO image full = self.load_image(tp, c)
full = self.get_tc(tp, c)
for tile in self.tile_locs: for tile in self.tile_locs:
# pad tile if necessary # pad tile if necessary
ndtile = self.ifoob_pad(full, tile.as_range(tp)) ndtile = Tiler.if_out_of_bounds_pad(full, tile.as_range(tp))
tiles.append(ndtile) tiles.append(ndtile)
return np.stack(tiles) return np.stack(tiles)
def get_tile_data(self, tile_id: int, tp: int, c: int): def get_tile_data(self, tile_id: int, tp: int, c: int):
""" """
Return a particular tile corrected for drift and padding. Return a tile corrected for drift and padding.
Parameters Parameters
---------- ----------
tile_id: integer tile_id: integer
Number of tile. Index of tile.
tp: integer tp: integer
Index of time points. Index of time points.
c: integer c: integer
...@@ -485,14 +485,14 @@ class Tiler(StepABC): ...@@ -485,14 +485,14 @@ class Tiler(StepABC):
ndtile: array ndtile: array
An array of (x, y) arrays, one for each z stack An array of (x, y) arrays, one for each z stack
""" """
full = self.get_tc(tp, c) full = self.load_image(tp, c)
tile = self.tile_locs.tiles[tile_id] tile = self.tile_locs.tiles[tile_id]
ndtile = self.ifoob_pad(full, tile.as_range(tp)) ndtile = self.if_out_of_bounds_pad(full, tile.as_range(tp))
return ndtile return ndtile
def _run_tp(self, tp: int): def _run_tp(self, tp: int):
""" """
Find tiles if they have not yet been found. Find tiles for a given time point.
Determine any translational drift of the current image from the Determine any translational drift of the current image from the
previous one. previous one.
...@@ -502,41 +502,33 @@ class Tiler(StepABC): ...@@ -502,41 +502,33 @@ class Tiler(StepABC):
tp: integer tp: integer
The time point to tile. The time point to tile.
""" """
# assert tp >= self.n_processed, "Time point already processed" if self.no_processed == 0 or not hasattr(self.tile_locs, "drifts"):
# TODO check contiguity?
if self.n_processed == 0 or not hasattr(self.tile_locs, "drifts"):
self.initialise_tiles(self.tile_size) self.initialise_tiles(self.tile_size)
if hasattr(self.tile_locs, "drifts"): if hasattr(self.tile_locs, "drifts"):
drift_len = len(self.tile_locs.drifts) drift_len = len(self.tile_locs.drifts)
if self.n_processed != drift_len: if self.no_processed != drift_len:
warnings.warn("Tiler:n_processed and ndrifts don't match") warnings.warn(
self.n_processed = drift_len "Tiler: the number of processed tiles and the number of drifts"
# determine drift " calculated do not match."
)
self.no_processed = drift_len
# determine drift for this time point and update tile_locs.drifts
self.find_drift(tp) self.find_drift(tp)
# update n_processed # update no_processed
self.n_processed = tp + 1 self.no_processed = tp + 1
# return result for writer # return result for writer
return self.tile_locs.to_dict(tp) return self.tile_locs.to_dict(tp)
def run(self, time_dim=None): def run(self, time_dim=None):
""" """Tile all time points in an experiment at once."""
Tile all time points in an experiment at once.
"""
if time_dim is None: if time_dim is None:
time_dim = 0 time_dim = 0
for frame in range(self.image.shape[time_dim]): for frame in range(self.image.shape[time_dim]):
self.run_tp(frame) self.run_tp(frame)
return None return None
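# A hedged sketch of running the Tiler on one position; the position name,
# channel, and Image instance are assumptions, and building TilerParameters
# this way mirrors what the pipeline does before calling Tiler.from_h5.
params_dict = TilerParameters.default().to_dict()
params_dict["position_name"] = "pos001"
tiler = Tiler.from_image(image, TilerParameters.from_dict(params_dict))
tiler.run_tp(0)  # find tiles and any drift for the first time point
tiles = tiler.get_tiles_timepoint(0, channels=["Brightfield"], z=0)
# tiles is arranged as (number of tiles, channels, z-sections, y, x)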
def get_traps_timepoint(self, *args, **kwargs):
self._log(
"get_traps_timepoint is deprecated; get_tiles_timepoint instead."
)
return self.get_tiles_timepoint(*args, **kwargs)
# The next set of functions are necessary for the extraction object
def get_tiles_timepoint( def get_tiles_timepoint(
self, tp: int, tile_shape=None, channels=None, z: int = 0 self, tp: int, channels=None, z: int = 0
) -> np.ndarray: ) -> np.ndarray:
""" """
Get a multidimensional array with all tiles for a set of channels Get a multidimensional array with all tiles for a set of channels
...@@ -558,33 +550,23 @@ class Tiler(StepABC): ...@@ -558,33 +550,23 @@ class Tiler(StepABC):
Returns Returns
------- -------
res: array res: array
Data arranged as (tiles, channels, time points, X, Y, Z) Data arranged as (tiles, channels, Z, X, Y)
""" """
# FIXME add support for sub-tiling a tile
# FIXME can we ignore z
if channels is None: if channels is None:
channels = [0] channels = [0]
elif isinstance(channels, str): elif isinstance(channels, str):
channels = [channels] channels = [channels]
# get the data # get the data as a list of length of the number of channels
res = [] res = []
for c in channels: for c in channels:
# only return requested z # only return requested z
val = self.get_tp_data(tp, c)[:, z] tiles = self.get_tp_data(tp, c)[:, z]
# starts with the order: tiles, z, y, x # insert new axis at index 1 for missing channel
# returns the order: tiles, C, T, Z, X, Y tiles = np.expand_dims(tiles, axis=1)
val = np.expand_dims(val, axis=1) res.append(tiles)
res.append(val) # stack over channels if more than one
if tile_shape is not None: final = np.stack(res, axis=1)
if isinstance(tile_shape, int): return final
tile_shape = (tile_shape, tile_shape)
assert np.all(
[
(tile_size - ax) > -1
for tile_size, ax in zip(tile_shape, res[0].shape[-3:-2])
]
)
return np.stack(res, axis=1)
@property @property
def ref_channel_index(self): def ref_channel_index(self):
...@@ -593,32 +575,35 @@ class Tiler(StepABC): ...@@ -593,32 +575,35 @@ class Tiler(StepABC):
def get_channel_index(self, channel: str or int) -> int or None: def get_channel_index(self, channel: str or int) -> int or None:
""" """
Find index for channel using regex. Returns the first matched string. Find index for channel using regex.
If self.channels is integers (no image metadata) it returns None. If channels are strings, return the index of the first match.
If channel is integer If channels are integers, return the channel unchanged if it is
an integer.
an integer.
Parameters Parameters
---------- ----------
channel: string or int channel: string or int
The channel or index to be used. The channel or index to be used.
""" """
if isinstance(channel, int) and all(
if all(map(lambda x: isinstance(x, int), self.channels)): map(lambda x: isinstance(x, int), self.channels)
channel = channel if isinstance(channel, int) else None ):
return channel
if isinstance(channel, str): elif isinstance(channel, str):
channel = find_channel_index(self.channels, channel) return find_channel_index(self.channels, channel)
return channel else:
return None
@staticmethod @staticmethod
def ifoob_pad(full, slices): def if_out_of_bounds_pad(image_array, slices):
""" """
Return the slices padded if out of bounds. Pad slices if out of bounds.
Parameters Parameters
---------- ----------
full: array image_array: array
Slice of OMERO image (zstacks, x, y) - the entire position Slice of image (zstacks, x, y) - the entire position
with zstacks as first axis with zstacks as first axis
slices: tuple of two slices slices: tuple of two slices
Delineates indices for the x- and y- ranges of the tile. Delineates indices for the x- and y- ranges of the tile.
...@@ -631,11 +616,11 @@ class Tiler(StepABC): ...@@ -631,11 +616,11 @@ class Tiler(StepABC):
If much padding is needed, a tile of NaN is returned. If much padding is needed, a tile of NaN is returned.
""" """
# number of pixels in the y direction # number of pixels in the y direction
max_size = full.shape[-1] max_size = image_array.shape[-1]
# ignore parts of the tile outside of the image # ignore parts of the tile outside of the image
y, x = [slice(max(0, s.start), min(max_size, s.stop)) for s in slices] y, x = [slice(max(0, s.start), min(max_size, s.stop)) for s in slices]
# get the tile including all z stacks # get the tile including all z stacks
tile = full[:, y, x] tile = image_array[:, y, x]
# find extent of padding needed in x and y # find extent of padding needed in x and y
padding = np.array( padding = np.array(
[(-min(0, s.start), -min(0, max_size - s.stop)) for s in slices] [(-min(0, s.start), -min(0, max_size - s.stop)) for s in slices]
...@@ -643,43 +628,31 @@ class Tiler(StepABC): ...@@ -643,43 +628,31 @@ class Tiler(StepABC):
if padding.any(): if padding.any():
tile_size = slices[0].stop - slices[0].start tile_size = slices[0].stop - slices[0].start
if (padding > tile_size / 4).any(): if (padding > tile_size / 4).any():
# too much of the tile is outside of the image # fill with NaN because too much of the tile is outside of the image
# fill with NaN tile = np.full(
tile = np.full((full.shape[0], tile_size, tile_size), np.nan) (image_array.shape[0], tile_size, tile_size), np.nan
)
else: else:
# pad tile with median value of the tile # pad tile with median value of the tile
tile = np.pad(tile, [[0, 0]] + padding.tolist(), "median") tile = np.pad(tile, [[0, 0]] + padding.tolist(), "median")
return tile return tile
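# A small, self-contained check of the padding rule above (values assumed):
import numpy as np

stack = np.arange(2 * 10 * 10, dtype=float).reshape(2, 10, 10)  # (z, y, x)
tile = Tiler.if_out_of_bounds_pad(stack, (slice(-2, 6), slice(0, 8)))
# two rows fall outside the image; 2 is not greater than tile_size / 4 = 2,
# so the missing rows are filled with the tile's median value rather than
# the whole tile being replaced by NaNs
assert tile.shape == (2, 8, 8)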
# FIXME: Refactor to support both channel or index def find_channel_index(image_channels: t.List[str], channel_regex: str):
# self._log below is not defined """Use a regex to find the index of a channel."""
def find_channel_index(image_channels: t.List[str], channel: str): for index, ch in enumerate(image_channels):
""" found = re.match(channel_regex, ch, re.IGNORECASE)
Access
"""
for i, ch in enumerate(image_channels):
found = re.match(channel, ch, re.IGNORECASE)
if found: if found:
if len(found.string) - (found.endpos - found.start()): if len(found.string) - (found.endpos - found.start()):
logging.getLogger("aliby").log( logging.getLogger("aliby").log(
logging.WARNING, logging.WARNING,
f"Channel {channel} matched {ch} using regex", f"Channel {channel_regex} matched {ch} using regex",
) )
return i return index
def find_channel_name(image_channels: t.List[str], channel: str): def find_channel_name(image_channels: t.List[str], channel_regex: str):
""" """Find the name of the channel using regex."""
Find the name of the channel using regex. index = find_channel_index(image_channels, channel_regex)
Parameters
----------
image_channels: list of str
Channels.
channel: str
A regular expression.
"""
index = find_channel_index(image_channels, channel)
if index is not None: if index is not None:
return image_channels[index] return image_channels[index]
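# Hedged examples of the channel matching above; the channel names are
# invented.
channels = ["Brightfield", "GFP", "mCherry"]
find_channel_index(channels, "gfp")   # -> 1, a case-insensitive full match
find_channel_name(channels, "mch")    # -> "mCherry", a partial match
find_channel_index(channels, "Cy5")   # -> None because no channel matches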
...@@ -169,7 +169,7 @@ class RemoteImageViewer(BaseImageViewer): ...@@ -169,7 +169,7 @@ class RemoteImageViewer(BaseImageViewer):
with self._image_class(self.image_id, **server_info) as image: with self._image_class(self.image_id, **server_info) as image:
self.tiler.image = image.data self.tiler.image = image.data
return self.tiler.get_tc(tp, channel) return self.tiler.load_image(tp, channel)
def _find_channels(self, channels: str, guess: bool = True): def _find_channels(self, channels: str, guess: bool = True):
channels = channels or self.tiler.ref_channel channels = channels or self.tiler.ref_channel
......