
Compare revisions

Changes are shown as if the source revision was being merged into the target revision.

Commits on Source (60)
Showing 2898 additions and 2558 deletions
Source diff could not be displayed: it is too large.
......@@ -33,7 +33,7 @@ pathos = "^0.2.8" # Lambda-friendly multithreading
p-tqdm = "^1.3.3"
pandas = ">=1.3.3"
py-find-1st = "^1.1.5" # Fast indexing
scikit-learn = ">=1.0.2" # Used for an extraction metric
scikit-learn = ">=1.0.2, <1.3" # Used for an extraction metric
scipy = ">=1.7.3"
# Pipeline + I/O
......@@ -46,14 +46,11 @@ xmltodict = "^0.13.0" # read ome-tiff metadata
zarr = "^2.14.0"
GitPython = "^3.1.27"
h5py = "2.10" # File I/O
aliby-baby = "^0.1.17"
# Networking
omero-py = { version = ">=5.6.2", optional = true } # contact omero server
# Baby segmentation
aliby-baby = {version = "^0.1.17", optional=true}
# Postprocessing
[tool.poetry.group.pp.dependencies]
leidenalg = "^0.8.8"
......@@ -113,7 +110,6 @@ grid-strategy = {version = "^0.0.1", optional=true}
[tool.poetry.extras]
omero = ["omero-py"]
baby = ["aliby-baby"]
[tool.black]
line-length = 79
......
......@@ -17,16 +17,14 @@ atomic = t.Union[int, float, str, bool]
class ParametersABC(ABC):
"""
Defines parameters as attributes and allows parameters to
Define parameters as attributes and allow parameters to
be converted to either a dictionary or to yaml.
No attribute should be called "parameters"!
"""
def __init__(self, **kwargs):
"""
Defines parameters as attributes
"""
"""Define parameters as attributes."""
assert (
"parameters" not in kwargs
), "No attribute should be named parameters"
......@@ -35,8 +33,9 @@ class ParametersABC(ABC):
def to_dict(self, iterable="null") -> t.Dict:
"""
Recursive function to return a nested dictionary of the
attributes of the class instance.
Return a nested dictionary of the attributes of the class instance.
Use recursion.
"""
if isinstance(iterable, dict):
if any(
......@@ -62,7 +61,8 @@ class ParametersABC(ABC):
def to_yaml(self, path: Union[Path, str] = None):
"""
Returns a yaml stream of the attributes of the class instance.
Return a yaml stream of the attributes of the class instance.
If path is provided, the yaml stream is saved there.
Parameters
......@@ -81,9 +81,7 @@ class ParametersABC(ABC):
@classmethod
def from_yaml(cls, source: Union[Path, str]):
"""
Returns instance from a yaml filename or stdin
"""
"""Return instance from a yaml filename or stdin."""
is_buffer = True
try:
if Path(source).exists():
......@@ -107,7 +105,8 @@ class ParametersABC(ABC):
def update(self, name: str, new_value):
"""
Update values recursively
Update values recursively.
If name is a dictionary, replace data where it already exists or add it if not.
It warns against type changes.
......@@ -116,7 +115,6 @@ class ParametersABC(ABC):
If a leaf node that is to be changed is a collection, it adds the new elements.
"""
assert name not in (
"parameters",
"params",
......@@ -179,7 +177,8 @@ def add_to_collection(
class ProcessABC(ABC):
"""
Base class for processes.
Defines parameters as attributes and requires run method to be defined.
Define parameters as attributes and require a run method.
"""
def __init__(self, parameters):
......@@ -190,8 +189,8 @@ class ProcessABC(ABC):
"""
self._parameters = parameters
# convert parameters to dictionary
# and then define each parameter as an attribute
for k, v in parameters.to_dict().items():
# define each parameter as an attribute
setattr(self, k, v)
@property
......@@ -243,11 +242,9 @@ class StepABC(ProcessABC):
@timer
def run_tp(self, tp: int, **kwargs):
"""
Time and log the timing of a step.
"""
"""Time and log the timing of a step."""
return self._run_tp(tp, **kwargs)
def run(self):
# Replace run with run_tp
raise Warning("Steps use run_tp instead of run")
raise Warning("Steps use run_tp instead of run.")
......@@ -14,183 +14,170 @@ from utils_find_1st import cmp_equal, find_1st
class Cells:
"""
Extracts information from an h5 file. This class accesses:
Extract information from an h5 file.
Use output from BABY to find cells detected, get, and fill, edge masks
and retrieve mother-bud relationships.
This class accesses in the h5 file:
'cell_info', which contains 'angles', 'cell_label', 'centres',
'edgemasks', 'ellipse_dims', 'mother_assign', 'mother_assign_dynamic',
'radii', 'timepoint', 'trap'.
All of these except for 'edgemasks' are a 1D ndarray.
'radii', 'timepoint', and 'trap'. All of these, except for 'edgemasks',
are 1D ndarrays.
'trap_info', which contains 'drifts', 'trap_locations'
'trap_info', which contains 'drifts', and 'trap_locations'.
The "timepoint", "cell_label", and "trap" variables are mutually consistent
1D lists.
For example, self["timepoint"][self.get_idx(1, 3)] returns the time points
where cell 1 was present in trap 3.
"""
def __init__(self, filename, path="cell_info"):
"""Initialise from a filename."""
self.filename: t.Optional[t.Union[str, Path]] = filename
self.cinfo_path: t.Optional[str] = path
self._edgemasks: t.Optional[str] = None
self._tile_size: t.Optional[int] = None
def __getitem__(self, item):
"""
Dynamically fetch data from the h5 file and save as an attribute.
These attributes are accessed like dict keys.
"""
assert item != "edgemasks", "Edgemasks must not be loaded as a whole"
_item = "_" + item
if not hasattr(self, _item):
setattr(self, _item, self.fetch(item))
return getattr(self, _item)
def fetch(self, path):
"""Get data from the h5 file."""
with h5py.File(self.filename, mode="r") as f:
return f[self.cinfo_path][path][()]
@classmethod
def from_source(cls, source: t.Union[Path, str]):
"""Ensure initiating file is a Path object."""
return cls(Path(source))
def _log(self, message: str, level: str = "warn"):
# Log messages in the corresponding level
"""Log messages in the corresponding level."""
logger = logging.getLogger("aliby")
getattr(logger, level)(f"{self.__class__.__name__}: {message}")
@staticmethod
def _asdense(array: np.ndarray):
def asdense(array: np.ndarray):
"""Convert sparse array to dense array."""
if not isdense(array):
array = array.todense()
return array
@staticmethod
def _astype(array: np.ndarray, kind: str):
# Convert sparse arrays if needed and if kind is 'mask' it fills the outline
array = Cells._asdense(array)
def astype(array: np.ndarray, kind: str):
"""Convert sparse arrays if needed; if kind is 'mask' fill the outline."""
array = Cells.asdense(array)
if kind == "mask":
array = ndimage.binary_fill_holes(array).astype(bool)
return array
def _get_idx(self, cell_id: int, trap_id: int):
# returns boolean array of time points where both the cell with cell_id and the trap with trap_id exist
def get_idx(self, cell_id: int, trap_id: int):
"""Return boolean array giving indices for a cell_id and trap_id."""
return (self["cell_label"] == cell_id) & (self["trap"] == trap_id)
@property
def max_labels(self) -> t.List[int]:
return [max((0, *self.labels_in_trap(i))) for i in range(self.ntraps)]
"""Return the maximum cell label per tile."""
return [
max((0, *self.cell_labels_in_trap(i))) for i in range(self.ntraps)
]
@property
def max_label(self) -> int:
"""Return the maximum cell label over all tiles."""
return sum(self.max_labels)
@property
def ntraps(self) -> int:
# find the number of traps from the h5 file
"""Find the number of tiles, or traps."""
with h5py.File(self.filename, mode="r") as f:
return len(f["trap_info/trap_locations"][()])
@property
def tinterval(self):
"""Return time interval in seconds."""
with h5py.File(self.filename, mode="r") as f:
return f.attrs["time_settings/timeinterval"]
@property
def traps(self) -> t.List[int]:
# returns a list of traps
"""List unique tile, or trap, IDs."""
return list(set(self["trap"]))
@property
def tile_size(self) -> t.Union[int, t.Tuple[int], None]:
"""Give the x- and y- sizes of a tile."""
if self._tile_size is None:
with h5py.File(self.filename, mode="r") as f:
# self._tile_size = f["trap_info/tile_size"][0]
self._tile_size = f["cell_info/edgemasks"].shape[1:]
return self._tile_size
def nonempty_tp_in_trap(self, trap_id: int) -> set:
# given a trap_id returns time points in which cells are available
"""Given a tile, return time points for which cells are available."""
return set(self["timepoint"][self["trap"] == trap_id])
@property
def edgemasks(self) -> t.List[np.ndarray]:
# returns the masks per tile
"""Return a list of masks for every cell at every trap and time point."""
if self._edgemasks is None:
edgem_path: str = "edgemasks"
self._edgemasks = self._fetch(edgem_path)
self._edgemasks = self.fetch(edgem_path)
return self._edgemasks
@property
def labels(self) -> t.List[t.List[int]]:
"""
Return all cell labels in object
We use mother_assign to list traps because it is the only property that appears even
when no cells are found
"""
return [self.labels_in_trap(trap) for trap in range(self.ntraps)]
"""Return all cell labels per tile as a set for all tiles."""
return [self.cell_labels_in_trap(trap) for trap in range(self.ntraps)]
def max_labels_in_frame(self, frame: int) -> t.List[int]:
# Return the maximum label for each trap in the given frame
def max_labels_in_frame(self, final_time_point: int) -> t.List[int]:
"""Get the maximal cell label for each tile within a frame of time."""
max_labels = [
self["cell_label"][
(self["timepoint"] <= frame) & (self["trap"] == trap_id)
(self["timepoint"] <= final_time_point)
& (self["trap"] == trap_id)
]
for trap_id in range(self.ntraps)
]
return [max([0, *labels]) for labels in max_labels]
def where(self, cell_id: int, trap_id: int):
"""
Parameters
----------
cell_id: int
Cell index
trap_id: int
Trap index
Returns
----------
indices int array
boolean mask array
edge_ix int array
"""
indices = self._get_idx(cell_id, trap_id)
edgem_ix = self._edgem_where(cell_id, trap_id)
"""Return time points, indices, and edge masks for a cell and trap."""
idx = self.get_idx(cell_id, trap_id)
return (
self["timepoint"][indices],
indices,
edgem_ix,
self["timepoint"][idx],
idx,
self.edgemasks_where(cell_id, trap_id),
)
def mask(self, cell_id, trap_id):
"""
Returns the times and the binary masks of a given cell in a given tile.
Parameters
----------
cell_id : int
The unique ID of the cell.
tile_id : int
The unique ID of the tile.
Returns
-------
Tuple[np.ndarray, np.ndarray]
The times when the binary masks were taken and the binary masks of the given cell in the given tile.
"""
"""Return the times and the filled edge masks for a cell and trap."""
times, outlines = self.outline(cell_id, trap_id)
return times, np.array(
[ndimage.morphology.binary_fill_holes(o) for o in outlines]
)
def at_time(
self, timepoint: t.Iterable[int], kind="mask"
self, timepoint: int, kind="mask"
) -> t.List[t.List[np.ndarray]]:
"""
Returns a list of lists of binary masks in a given list of time points.
Parameters
----------
timepoints : Iterable[int]
The list of time points for which to return the binary masks.
kind : str, optional
The type of binary masks to return, by default "mask".
Returns
-------
List[List[np.ndarray]]
A list of lists with binary masks grouped by tile IDs.
"""
ix = self["timepoint"] == timepoint
traps = self["trap"][ix]
edgemasks = self._edgem_from_masking(ix)
"""Return a dict with traps as keys and cell masks as values for a time point."""
idx = self["timepoint"] == timepoint
traps = self["trap"][idx]
edgemasks = self.edgemasks_from_idx(idx)
masks = [
self._astype(edgemask, kind)
Cells.astype(edgemask, kind)
for edgemask in edgemasks
if edgemask.any()
]
......@@ -199,22 +186,7 @@ class Cells:
def at_times(
self, timepoints: t.Iterable[int], kind="mask"
) -> t.List[t.List[np.ndarray]]:
"""
Returns a list of lists of binary masks for a given list of time points.
Parameters
----------
timepoints : Iterable[int]
The list of time points for which to return the binary masks.
kind : str, optional
The type of binary masks to return, by default "mask".
Returns
-------
List[List[np.ndarray]]
A list of lists with binary masks grouped by tile IDs.
"""
"""Return a list of lists of cell masks one for specified time point."""
return [
[
np.stack(tile_masks) if len(tile_masks) else []
......@@ -226,91 +198,77 @@ class Cells:
def group_by_traps(
self, traps: t.Collection, cell_labels: t.Collection
) -> t.Dict[int, t.List[int]]:
"""
Returns a dict with traps as keys and list of labels as value.
Note that the total number of traps are calculated from Cells.traps.
"""
"""Return a dict with traps as keys and a list of labels as values."""
iterator = groupby(zip(traps, cell_labels), lambda x: x[0])
d = {key: [x[1] for x in group] for key, group in iterator}
d = {i: d.get(i, []) for i in self.traps}
return d
def labels_in_trap(self, trap_id: int) -> t.Set[int]:
# return set of cell ids for a given trap
def cell_labels_in_trap(self, trap_id: int) -> t.Set[int]:
"""Return unique cell labels for a given trap."""
return set((self["cell_label"][self["trap"] == trap_id]))
def labels_at_time(self, timepoint: int) -> t.Dict[int, t.List[int]]:
"""Return a dict with traps as keys and cell labels as values for a time point."""
labels = self["cell_label"][self["timepoint"] == timepoint]
traps = self["trap"][self["timepoint"] == timepoint]
return self.group_by_traps(traps, labels)
def __getitem__(self, item):
assert item != "edgemasks", "Edgemasks must not be loaded as a whole"
_item = "_" + item
if not hasattr(self, _item):
setattr(self, _item, self._fetch(item))
return getattr(self, _item)
def _fetch(self, path):
with h5py.File(self.filename, mode="r") as f:
return f[self.cinfo_path][path][()]
def _edgem_from_masking(self, mask):
def edgemasks_from_idx(self, idx):
"""Get edge masks from the h5 file."""
with h5py.File(self.filename, mode="r") as f:
edgem = f[self.cinfo_path + "/edgemasks"][mask, ...]
edgem = f[self.cinfo_path + "/edgemasks"][idx, ...]
return edgem
def _edgem_where(self, cell_id, trap_id):
id_mask = self._get_idx(cell_id, trap_id)
edgem = self._edgem_from_masking(id_mask)
return edgem
def edgemasks_where(self, cell_id, trap_id):
"""Get the edge masks for a given cell and trap for all time points."""
idx = self.get_idx(cell_id, trap_id)
edgemasks = self.edgemasks_from_idx(idx)
return edgemasks
def outline(self, cell_id: int, trap_id: int):
id_mask = self._get_idx(cell_id, trap_id)
times = self["timepoint"][id_mask]
return times, self._edgem_from_masking(id_mask)
"""Get times and edge masks for a given cell and trap."""
idx = self.get_idx(cell_id, trap_id)
times = self["timepoint"][idx]
return times, self.edgemasks_from_idx(idx)
@property
def ntimepoints(self) -> int:
"""Return total number of time points in the experiment."""
return self["timepoint"].max() + 1
@cached_property
def _cells_vs_tps(self):
# Binary matrix showing the presence of all cells in all time points
ncells_per_tile = [len(x) for x in self.labels]
cells_vs_tps = np.zeros(
(sum(ncells_per_tile), self.ntimepoints), dtype=bool
)
def cells_vs_tps(self):
"""Boolean matrix showing when cells are present for all time points."""
total_ncells = sum([len(x) for x in self.labels])
cells_vs_tps = np.zeros((total_ncells, self.ntimepoints), dtype=bool)
cells_vs_tps[
self._cell_cumsum[self["trap"]] + self["cell_label"] - 1,
self.cell_cumlsum[self["trap"]] + self["cell_label"] - 1,
self["timepoint"],
] = True
return cells_vs_tps
@cached_property
def _cell_cumsum(self):
# Cumulative sum indicating the number of cells per tile
def cell_cumlsum(self):
"""Find cumulative sum over tiles of the number of cells present."""
ncells_per_tile = [len(x) for x in self.labels]
cumsum = np.roll(np.cumsum(ncells_per_tile), shift=1)
cumsum[0] = 0
return cumsum
def _flat_index_to_tuple_location(self, idx: int) -> t.Tuple[int, int]:
# Convert a cell index to a tuple
# Note that it assumes tiles and cell labels are flattened, but
# it is agnostic to tps
tile_id = int(np.where(idx + 1 > self._cell_cumsum)[0][-1])
cell_label = idx - self._cell_cumsum[tile_id] + 1
def index_to_tile_and_cell(self, idx: int) -> t.Tuple[int, int]:
"""Convert an index to the equivalent pair of tile and cell IDs."""
tile_id = int(np.where(idx + 1 > self.cell_cumlsum)[0][-1])
cell_label = idx - self.cell_cumlsum[tile_id] + 1
return tile_id, cell_label
@property
def _tiles_vs_cells_vs_tps(self):
def tiles_vs_cells_vs_tps(self):
"""
Boolean matrix showing if a cell is present.
The matrix is indexed by trap, cell label, and time point.
"""
ncells_mat = np.zeros(
(self.ntraps, self["cell_label"].max(), self.ntimepoints),
dtype=bool,
......@@ -325,32 +283,37 @@ class Cells:
min_consecutive_tps: int = 15,
interval: None or t.Tuple[int, int] = None,
):
"""
Find cells present for all time points in a sliding window of time.
The result can be restricted to a particular interval of time.
"""
window = sliding_window_view(
self._cells_vs_tps, min_consecutive_tps, axis=1
self.cells_vs_tps, min_consecutive_tps, axis=1
)
tp_min = window.sum(axis=-1) == min_consecutive_tps
# Apply an interval filter to focucs on a slice
# apply a filter to restrict to an interval of time
if interval is not None:
interval = tuple(np.array(interval))
else:
interval = (0, window.shape[1])
low_boundary, high_boundary = interval
tp_min[:, :low_boundary] = False
tp_min[:, high_boundary:] = False
return tp_min
@lru_cache(20)
def mothers_in_trap(self, trap_id: int):
"""Return mothers at a trap."""
return self.mothers[trap_id]
@cached_property
def mothers(self):
"""
Return nested list with final prediction of mother id for each cell
Return a list of mother IDs for each cell in each tile.
Use Baby's "mother_assign_dynamic".
An ID of zero implies that no mother was assigned.
"""
return self.mother_assign_from_dynamic(
self["mother_assign_dynamic"],
......@@ -362,73 +325,71 @@ class Cells:
@cached_property
def mothers_daughters(self) -> np.ndarray:
"""
Return a single array with three columns, containing information about
the mother-daughter relationships: tile, mothers and daughters.
Return mother-daughter relationships for all tiles.
Returns
-------
np.ndarray
An array with shape (n, 3) where n is the number of mother-daughter pairs found.
The columns contain:
- tile: the tile where the mother cell is located.
- mothers: the index of the mother cell within the tile.
- daughters: the index of the daughter cell within the tile.
mothers_daughters: np.ndarray
An array with shape (n, 3) where n is the number of mother-daughter
pairs found. The first column is the tile_id for the tile where the
mother cell is located. The second column is the cell index of a
mother cell in the tile. The third column is the index of the
corresponding daughter cell.
"""
nested_massign = self.mothers
if sum([x for y in nested_massign for x in y]):
# list of arrays, one per tile, giving mothers of each cell in each tile
mothers = self.mothers
if sum([x for y in mothers for x in y]):
mothers_daughters = np.array(
[
(tid, m, d)
for tid, trapcells in enumerate(nested_massign)
for d, m in enumerate(trapcells, 1)
if m
(trap_id, mother, bud)
for trap_id, trapcells in enumerate(mothers)
for bud, mother in enumerate(trapcells, start=1)
if mother
],
dtype=np.uint16,
)
else:
mothers_daughters = np.array([])
self._log("No mother-daughters assigned")
return mothers_daughters
@staticmethod
def mother_assign_to_mb_matrix(ma: t.List[np.array]):
"""
Convert from a list of lists of mother-bud paired assignments to a
sparse matrix with a boolean dtype. The rows correspond to
to daughter buds. The values are boolean and indicate whether a
given cell is a mother cell and a given daughter bud is assigned
to the mother cell in the next timepoint.
Convert a list of mother-daughters into a boolean sparse matrix.
Each row in the matrix corresponds to a daughter bud.
If an entry is True, a given cell is a mother cell and a given
daughter bud is assigned to the mother cell in the next time point.
Parameters:
-----------
ma : list of lists of integers
A list of lists of mother-bud assignments. The i-th sublist contains the
bud assignments for the i-th tile. The integers in each sublist
represent the mother label, if it is zero no mother was found.
A list of lists of mother-bud assignments.
The i-th sublist contains the bud assignments for the i-th tile.
The integers in each sublist represent the mother label, with zero
implying no mother found.
Returns:
--------
mb_matrix : boolean numpy array of shape (n, m)
An n x m boolean numpy array where n is the total number of cells (sum
of the lengths of all sublists in ma) and m is the maximum number of buds
assigned to any mother cell in ma. The value at (i, j) is True if cell i
is a daughter cell and cell j is its mother assigned to i.
An n x m array where n is the total number of cells (sum
of the lengths of all sublists in ma) and m is the maximum
number of buds assigned to any mother cell in ma.
The value at (i, j) is True if cell i is a daughter cell and
cell j is its assigned mother.
Examples:
--------
ma = [[0, 0, 1], [0, 1, 0]]
Cells(None).mother_assign_to_mb_matrix(ma)
# array([[False, False, False, False, False, False],
# [False, False, False, False, False, False],
# [ True, False, False, False, False, False],
# [False, False, False, False, False, False],
# [False, False, False, True, False, False],
# [False, False, False, False, False, False]])
>>> ma = [[0, 0, 1], [0, 1, 0]]
>>> Cells(None).mother_assign_to_mb_matrix(ma)
>>> array([[False, False, False, False, False, False],
[False, False, False, False, False, False],
[ True, False, False, False, False, False],
[False, False, False, False, False, False],
[False, False, False, True, False, False],
[False, False, False, False, False, False]])
"""
ncells = sum([len(t) for t in ma])
mb_matrix = np.zeros((ncells, ncells), dtype=bool)
c = 0
......@@ -436,69 +397,78 @@ class Cells:
for d, m in enumerate(cells):
if m:
mb_matrix[c + d, c + m - 1] = True
c += len(cells)
return mb_matrix
@staticmethod
def mother_assign_from_dynamic(
ma: np.ndarray, cell_label: t.List[int], trap: t.List[int], ntraps: int
ma: np.ndarray,
cell_label: t.List[int],
trap: t.List[int],
ntraps: int,
) -> t.List[t.List[int]]:
"""
Interpolate the associated mothers from the 'mother_assign_dynamic' feature.
Find mothers from Baby's 'mother_assign_dynamic' variable.
Parameters
----------
ma: np.ndarray
An array with shape (n_t, n_c) containing the 'mother_assign_dynamic' feature.
An array whose length is the number of time points multiplied by the
number of cells, containing the 'mother_assign_dynamic' produced by Baby.
cell_label: List[int]
A list containing the cell labels.
A list of cell labels.
trap: List[int]
A list containing the trap labels.
A list of trap labels.
ntraps: int
The total number of traps.
Returns
-------
List[List[int]]
A list of lists containing the interpolated mother assignment for each cell in each trap.
A list giving the mothers for each cell at each trap.
"""
idlist = list(zip(trap, cell_label))
cell_gid = np.unique(idlist, axis=0)
ids = np.unique(list(zip(trap, cell_label)), axis=0)
# find when each cell last appeared at its trap
last_lin_preds = [
find_1st(
((cell_label[::-1] == lbl) & (trap[::-1] == tr)),
(
(cell_label[::-1] == cell_label_id)
& (trap[::-1] == trap_id)
),
True,
cmp_equal,
)
for tr, lbl in cell_gid
for trap_id, cell_label_id in ids
]
# find the cell's mother using the latest prediction from Baby
mother_assign_sorted = ma[::-1][last_lin_preds]
traps = cell_gid[:, 0]
# rearrange as a list of mother IDs for each cell in each tile
traps = ids[:, 0]
iterator = groupby(zip(traps, mother_assign_sorted), lambda x: x[0])
d = {key: [x[1] for x in group] for key, group in iterator}
nested_massign = [d.get(i, []) for i in range(ntraps)]
d = {trap: [x[1] for x in mothers] for trap, mothers in iterator}
mothers = [d.get(i, []) for i in range(ntraps)]
return mothers
return nested_massign
###############################################################################
# Apparently unused below here
###############################################################################
@lru_cache(maxsize=200)
def labelled_in_frame(
self, frame: int, global_id: bool = False
) -> np.ndarray:
"""
Returns labels in a 4D ndarray with the global ids with shape
(ntraps, max_nlabels, ysize, xsize) at a given frame.
Return labels in a 4D ndarray with potentially global ids.
Use lru_cache to cache the results for speed.
Parameters
----------
frame : int
The frame number.
The frame number (time point).
global_id : bool, optional
If True, the returned array contains global ids, otherwise it
contains only the local ids of the labels. Default is False.
If True, the returned array contains global ids, otherwise only
the local ids of the labels.
Returns
-------
......@@ -507,18 +477,12 @@ class Cells:
The array has dimensions (ntraps, max_nlabels, ysize, xsize),
where max_nlabels is specific for this frame, not the entire
experiment.
Notes
-----
This method uses lru_cache to cache the results for faster access.
"""
labels_in_frame = self.labels_at_time(frame)
n_labels = [
len(labels_in_frame.get(trap_id, []))
for trap_id in range(self.ntraps)
]
# maxes = self.max_labels_in_frame(frame)
stacks_in_frame = self.get_stacks_in_frame(frame, self.tile_size)
first_id = np.cumsum([0, *n_labels])
labels_mat = np.zeros(
......@@ -552,7 +516,9 @@ class Cells:
self, frame: int, tile_shape: t.Tuple[int]
) -> t.List[np.ndarray]:
"""
Returns a list of stacked masks, each corresponding to a tile at a given timepoint.
Return a list of stacked masks.
Each corresponds to a tile at a given time point.
Parameters
----------
......@@ -564,7 +530,7 @@ class Cells:
Returns
-------
List[np.ndarray]
List of stacked masks for each tile at the given timepoint.
List of stacked masks for each tile at the given time point.
"""
masks = self.at_time(frame)
return [
......@@ -574,7 +540,7 @@ class Cells:
for trap_id in range(self.ntraps)
]
def _sample_tiles_tps(
def sample_tiles_tps(
self,
size=1,
min_consecutive_ntps: int = 15,
......@@ -582,7 +548,7 @@ class Cells:
interval=None,
) -> t.Tuple[np.ndarray, np.ndarray]:
"""
Sample tiles that have a minimum number of cells and are occupied for at least a minimum number of consecutive timepoints.
Sample tiles that have a minimum number of cells and are occupied for at least a minimum number of consecutive time points.
Parameters
----------
......@@ -591,7 +557,7 @@ class Cells:
min_ncells: int, optional (default=2)
The minimum number of cells per tile.
min_consecutive_ntps: int, optional (default=5)
The minimum number of consecutive timepoints a cell must be present in a trap.
The minimum number of consecutive time points a cell must be present in a trap.
seed: int, optional (default=0)
Random seed value for reproducibility.
interval: None or Tuple(int,int), optional (default=None)
......@@ -612,21 +578,15 @@ class Cells:
min_consecutive_tps=min_consecutive_ntps,
interval=interval,
)
# Find all valid tiles with min_ncells for at least min_tps
index_id, tps = np.where(cell_availability_matrix)
if interval is None: # Limit search
interval = (0, cell_availability_matrix.shape[1])
np.random.seed(seed)
choices = np.random.randint(len(index_id), size=size)
linear_indices = np.zeros_like(self["cell_label"], dtype=bool)
for cell_index_flat, tp in zip(index_id[choices], tps[choices]):
tile_id, cell_label = self._flat_index_to_tuple_location(
cell_index_flat
)
tile_id, cell_label = self.index_to_tile_and_cell(cell_index_flat)
linear_indices[
(
(self["cell_label"] == cell_label)
......@@ -634,10 +594,9 @@ class Cells:
& (self["timepoint"] == tp)
)
] = True
return linear_indices
def _sample_masks(
def sample_masks(
self,
size: int = 1,
min_consecutive_ntps: int = 15,
......@@ -668,31 +627,28 @@ class Cells:
The second tuple contains:
- `masks`: A list of 2D numpy arrays representing the binary masks of the sampled cells at each timepoint.
"""
sampled_bitmask = self._sample_tiles_tps(
sampled_bitmask = self.sample_tiles_tps(
size=size,
min_consecutive_ntps=min_consecutive_ntps,
seed=seed,
interval=interval,
)
# Sort sampled tiles to use automatic cache when possible
tile_ids = self["trap"][sampled_bitmask]
cell_labels = self["cell_label"][sampled_bitmask]
tps = self["timepoint"][sampled_bitmask]
masks = []
for tile_id, cell_label, tp in zip(tile_ids, cell_labels, tps):
local_idx = self.labels_at_time(tp)[tile_id].index(cell_label)
tile_mask = self.at_time(tp)[tile_id][local_idx]
masks.append(tile_mask)
return (tile_ids, cell_labels, tps), np.stack(masks)
def matrix_trap_tp_where(
self, min_ncells: int = 2, min_consecutive_tps: int = 5
):
"""
NOTE CURRENLTY UNUSED WITHIN ALIBY THE MOMENT. MAY BE USEFUL IN THE FUTURE.
NOTE CURRENTLY UNUSED BUT USEFUL.
Return a matrix of shape (ntraps x ntps - min_consecutive_tps) to
indicate traps and time-points where min_ncells are available for at least min_consecutive_tps
......@@ -708,9 +664,8 @@ class Cells:
(ntraps x ( ntps-min_consecutive_tps )) 2D boolean numpy array where rows are trap ids and columns are timepoint windows.
If the value in a cell is true its corresponding trap and timepoint contains more than min_ncells for at least min_consecutive time-points.
"""
window = sliding_window_view(
self._tiles_vs_cells_vs_tps, min_consecutive_tps, axis=2
self.tiles_vs_cells_vs_tps, min_consecutive_tps, axis=2
)
tp_min = window.sum(axis=-1) == min_consecutive_tps
ncells_tp_min = tp_min.sum(axis=1) >= min_ncells
......@@ -720,7 +675,7 @@ class Cells:
def stack_masks_in_tile(
masks: t.List[np.ndarray], tile_shape: t.Tuple[int]
) -> np.ndarray:
# Stack all masks in a trap padding accordingly if no outlines found
"""Stack all masks in a trap, padding accordingly if no outlines found."""
result = np.zeros((0, *tile_shape), dtype=bool)
if len(masks):
result = np.stack(masks)
......
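A minimal usage sketch of the refactored Cells interface described above; the module path and h5 filename are assumptions for illustration.

# hypothetical file produced by the pipeline
from agora.io.cells import Cells

cells = Cells.from_source("example_position.h5")
# boolean index over all observations of cell 1 in trap 3
idx = cells.get_idx(cell_id=1, trap_id=3)
time_points = cells["timepoint"][idx]
# filled cell masks grouped by trap at time point 0
masks_at_t0 = cells.at_time(0, kind="mask")
# (trap, mother, daughter) triplets when lineage information is available
pairs = cells.mothers_daughters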
......@@ -6,17 +6,19 @@ import typing as t
from functools import wraps
def _first_arg_str_to_df(
def _first_arg_str_to_raw_df(
fn: t.Callable,
):
"""Enable Signal-like classes to convert strings to data sets."""
@wraps(fn)
def format_input(*args, **kwargs):
cls = args[0]
data = args[1]
if isinstance(data, str):
# get data from h5 file
# get data from h5 file using Signal's get_raw
data = cls.get_raw(data)
# replace path in the undecorated function with data
return fn(cls, data, *args[2:], **kwargs)
return format_input
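A short sketch of how this decorator is meant to be applied; MySignal is a hypothetical stand-in for Signal, which supplies get_raw.

class MySignal:
    def get_raw(self, name):
        # in Signal this loads a dataset from the h5 file
        return {"dataset": name}

    @_first_arg_str_to_raw_df
    def summarise(self, data):
        # receives a data set whether a string or a dataframe was passed
        return data

MySignal().summarise("extraction/GFP/max/median")  # string converted via get_raw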
......@@ -66,7 +66,7 @@ class MetaData:
# Needed because HDF5 attributes do not support dictionaries
def flatten_dict(nested_dict, separator="/"):
"""
Flattens nested dictionary. If empty return as-is.
Flatten nested dictionary. If empty return as-is.
"""
flattened = {}
if nested_dict:
......@@ -79,9 +79,7 @@ def flatten_dict(nested_dict, separator="/"):
# Needed because HDF5 attributes do not support datetime objects
# Takes care of time zones & daylight saving
def datetime_to_timestamp(time, locale="Europe/London"):
"""
Convert datetime object to UNIX timestamp
"""
"""Convert datetime object to UNIX timestamp."""
return timezone(locale).localize(time).timestamp()
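For example, a small sketch assuming pytz is available as in the module:

from datetime import datetime

# naive datetime interpreted in the Europe/London locale
datetime_to_timestamp(datetime(2020, 6, 1, 12, 30))  # float seconds since the epoch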
......@@ -189,36 +187,37 @@ def parse_swainlab_metadata(filedir: t.Union[str, Path]):
Dictionary with minimal metadata
"""
filedir = Path(filedir)
filepath = find_file(filedir, "*.log")
if filepath:
# new log files
raw_parse = parse_from_swainlab_grammar(filepath)
minimal_meta = get_meta_swainlab(raw_parse)
else:
# old log files
if filedir.is_file() or str(filedir).endswith(".zarr"):
# log file is in parent directory
filedir = filedir.parent
legacy_parse = parse_logfiles(filedir)
minimal_meta = (
get_meta_from_legacy(legacy_parse) if legacy_parse else {}
)
return minimal_meta
def dispatch_metadata_parser(filepath: t.Union[str, Path]):
"""
Function to dispatch different metadata parsers that convert logfiles into a
basic metadata dictionary. Currently only contains the swainlab log parsers.
Dispatch different metadata parsers that convert logfiles into a dictionary.
Currently only contains the swainlab log parsers.
Input:
--------
filepath: str existing file containing metadata, or folder containing naming conventions
filepath: str existing file containing metadata, or folder containing naming
conventions
"""
parsed_meta = parse_swainlab_metadata(filepath)
if parsed_meta is None:
parsed_meta = dir_to_meta
return parsed_meta
......
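A hedged sketch of calling the dispatcher; the experiment path is hypothetical.

# parse minimal metadata from a folder containing a Swain-lab log file
meta = dispatch_metadata_parser("/data/experiments/2023_example_position/")
print(sorted(meta))  # list the minimal metadata keys found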
......@@ -5,7 +5,7 @@ import h5py
import numpy as np
from agora.io.bridge import groupsort
from agora.io.writer import load_attributes
from agora.io.writer import load_meta
class DynamicReader:
......@@ -13,7 +13,7 @@ class DynamicReader:
def __init__(self, file: str):
self.file = file
self.metadata = load_attributes(file)
self.metadata = load_meta(file)
class StateReader(DynamicReader):
......
......@@ -9,9 +9,10 @@ import h5py
import numpy as np
import pandas as pd
import aliby.global_parameters as global_parameters
from agora.io.bridge import BridgeH5
from agora.io.decorators import _first_arg_str_to_df
from agora.utils.indexing import validate_association
from agora.io.decorators import _first_arg_str_to_raw_df
from agora.utils.indexing import validate_lineage
from agora.utils.kymograph import add_index_levels
from agora.utils.merge import apply_merges
......@@ -20,11 +21,14 @@ class Signal(BridgeH5):
"""
Fetch data from h5 files for post-processing.
Signal assumes that the metadata and data are accessible to perform time-adjustments and apply previously recorded post-processes.
Signal assumes that the metadata and data are accessible to
perform time-adjustments and apply previously recorded
post-processes.
"""
def __init__(self, file: t.Union[str, Path]):
"""Define index_names for dataframes, candidate fluorescence channels, and composite statistics."""
"""Define index_names for dataframes, candidate fluorescence channels,
and composite statistics."""
super().__init__(file, flag=None)
self.index_names = (
"experiment",
......@@ -33,22 +37,13 @@ class Signal(BridgeH5):
"cell_label",
"mother_label",
)
self.candidate_channels = (
"GFP",
"GFPFast",
"mCherry",
"Flavin",
"Citrine",
"mKO2",
"Cy5",
"pHluorin405",
)
self.candidate_channels = global_parameters.possible_imaging_channels
def __getitem__(self, dsets: t.Union[str, t.Collection]):
"""Get and potentially pre-process data from h5 file and return as a dataframe."""
if isinstance(dsets, str): # no pre-processing
if isinstance(dsets, str):
return self.get(dsets)
elif isinstance(dsets, list): # pre-processing
elif isinstance(dsets, list):
is_bgd = [dset.endswith("imBackground") for dset in dsets]
# Check we are not comparing tile-indexed and cell-indexed data
assert sum(is_bgd) == 0 or sum(is_bgd) == len(
......@@ -58,22 +53,23 @@ class Signal(BridgeH5):
else:
raise Exception(f"Invalid type {type(dsets)} to get datasets")
def get(self, dsets: t.Union[str, t.Collection], **kwargs):
"""Get and potentially pre-process data from h5 file and return as a dataframe."""
if isinstance(dsets, str): # no pre-processing
df = self.get_raw(dsets, **kwargs)
def get(self, dset_name: t.Union[str, t.Collection], **kwargs):
"""Return pre-processed data as a dataframe."""
if isinstance(dset_name, str):
dsets = self.get_raw(dset_name, **kwargs)
prepost_applied = self.apply_prepost(dsets, **kwargs)
return self.add_name(prepost_applied, dsets)
return self.add_name(prepost_applied, dset_name)
else:
raise Exception("Error in Signal.get")
@staticmethod
def add_name(df, name):
"""Add column of identical strings to a dataframe."""
"""Add name of the Signal as an attribute to its corresponding dataframe."""
df.name = name
return df
def cols_in_mins(self, df: pd.DataFrame):
# Convert numerical columns in a dataframe to minutes
"""Convert numerical columns in a dataframe to minutes."""
try:
df.columns = (df.columns * self.tinterval // 60).astype(int)
except Exception as e:
......@@ -94,14 +90,15 @@ class Signal(BridgeH5):
if tinterval_location in f.attrs:
return f.attrs[tinterval_location][0]
else:
logging.getlogger("aliby").warn(
logging.getLogger("aliby").warn(
f"{str(self.filename).split('/')[-1]}: using default time interval of 5 minutes"
)
return 5
@staticmethod
def get_retained(df, cutoff):
"""Return a fraction of the df, one without later time points."""
"""Return rows of df with at least cutoff fraction of the total number
of time points."""
return df.loc[bn.nansum(df.notna(), axis=1) > df.shape[1] * cutoff]
@property
......@@ -110,15 +107,17 @@ class Signal(BridgeH5):
with h5py.File(self.filename, "r") as f:
return list(f.attrs["channels"])
@_first_arg_str_to_df
def retained(self, signal, cutoff=0.8):
def retained(
self, signal, cutoff=global_parameters.signal_retained_cutoff
):
"""
Load data (via decorator) and reduce the resulting dataframe.
Load data for a signal or a list of signals and reduce the resulting
dataframes to a fraction of their original size, losing late time
points.
dataframes to rows with sufficient numbers of time points.
"""
if isinstance(signal, str):
signal = self.get_raw(signal)
if isinstance(signal, pd.DataFrame):
return self.get_retained(signal, cutoff)
elif isinstance(signal, list):
......@@ -131,17 +130,15 @@ class Signal(BridgeH5):
"""
Get lineage data from a given location in the h5 file.
Returns an array with three columns: the tile id, the mother label, and the daughter label.
Returns an array with three columns: the tile id, the mother label,
and the daughter label.
"""
if lineage_location is None:
lineage_location = "modifiers/lineage_merged"
with h5py.File(self.filename, "r") as f:
# if lineage_location not in f:
# lineage_location = lineage_location.split("_")[0]
if lineage_location not in f:
lineage_location = "postprocessing/lineage"
tile_mo_da = f[lineage_location]
if isinstance(tile_mo_da, h5py.Dataset):
lineage = tile_mo_da[()]
else:
......@@ -154,7 +151,7 @@ class Signal(BridgeH5):
).T
return lineage
@_first_arg_str_to_df
@_first_arg_str_to_raw_df
def apply_prepost(
self,
data: t.Union[str, pd.DataFrame],
......@@ -162,57 +159,40 @@ class Signal(BridgeH5):
picks: t.Union[t.Collection, bool] = True,
):
"""
Apply modifier operations (picker or merger) to a dataframe.
Apply picking and merging to a Signal data frame.
Parameters
----------
data : t.Union[str, pd.DataFrame]
DataFrame or path to one.
A data frame or a path to one.
merges : t.Union[np.ndarray, bool]
(optional) 2-D array with three columns: the tile id, the mother label, and the daughter id.
(optional) An array of pairs of (trap, cell) indices to merge.
If True, fetch merges from file.
picks : t.Union[np.ndarray, bool]
(optional) 2-D array with two columns: the tiles and
the cell labels.
(optional) An array of (trap, cell) indices.
If True, fetch picks from file.
Examples
--------
FIXME: Add docs.
"""
if isinstance(merges, bool):
merges: np.ndarray = self.load_merges() if merges else np.array([])
merges = self.load_merges() if merges else np.array([])
if merges.any():
merged = apply_merges(data, merges)
else:
merged = copy(data)
if isinstance(picks, bool):
picks = (
self.get_picks(names=merged.index.names)
self.get_picks(
names=merged.index.names, path="modifiers/picks/"
)
if picks
else set(merged.index)
else merged.index
)
with h5py.File(self.filename, "r") as f:
if "modifiers/picks" in f and picks:
if picks:
return merged.loc[
set(picks).intersection(
[tuple(x) for x in merged.index]
)
]
else:
if isinstance(merged.index, pd.MultiIndex):
empty_lvls = [[] for i in merged.index.names]
index = pd.MultiIndex(
levels=empty_lvls,
codes=empty_lvls,
names=merged.index.names,
)
else:
index = pd.Index([], name=merged.index.name)
merged = pd.DataFrame([], index=index)
return merged
if picks:
picked_indices = set(picks).intersection(
[tuple(x) for x in merged.index]
)
return merged.loc[picked_indices]
else:
return merged
@cached_property
def p_available(self):
......@@ -272,10 +252,11 @@ class Signal(BridgeH5):
Parameters
----------
dataset: str or list of strs
The name of the h5 file or a list of h5 file names
The name of the h5 file or a list of h5 file names.
in_minutes: boolean
If True,
If True, convert column headings to times in minutes.
lineage: boolean
If True, add mother_label to index.
"""
try:
if isinstance(dataset, str):
......@@ -288,15 +269,17 @@ class Signal(BridgeH5):
self.get_raw(dset, in_minutes=in_minutes, lineage=lineage)
for dset in dataset
]
if lineage: # assume that df is sorted
if lineage:
# assume that df is sorted
mother_label = np.zeros(len(df), dtype=int)
lineage = self.lineage()
a, b = validate_association(
# information on buds
valid_lineage, valid_indices = validate_lineage(
lineage,
np.array(df.index.to_list()),
match_column=1,
"daughters",
)
mother_label[b] = lineage[a, 1]
mother_label[valid_indices] = lineage[valid_lineage, 1]
df = add_index_levels(df, {"mother_label": mother_label})
return df
except Exception as e:
......@@ -316,13 +299,14 @@ class Signal(BridgeH5):
names: t.Tuple[str, ...] = ("trap", "cell_label"),
path: str = "modifiers/picks/",
) -> t.Set[t.Tuple[int, str]]:
"""Get the relevant picks based on names."""
"""Get picks from the h5 file."""
with h5py.File(self.filename, "r") as f:
picks = set()
if path in f:
picks = set(
zip(*[f[path + name] for name in names if name in f[path]])
)
else:
picks = set()
return picks
def dataset_to_df(self, f: h5py.File, path: str) -> pd.DataFrame:
......@@ -353,10 +337,7 @@ class Signal(BridgeH5):
fullname: str,
node: t.Union[h5py.Dataset, h5py.Group],
):
"""
Store the name of a signal if it is a leaf node
(a group with no more groups inside) and if it starts with extraction.
"""
"""Store the name of a signal if it is a leaf node and if it starts with extraction."""
if isinstance(node, h5py.Group) and np.all(
[isinstance(x, h5py.Dataset) for x in node.values()]
):
......
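A minimal sketch of the Signal workflow after these changes; the filename and dataset path are assumptions.

from agora.io.signal import Signal

signal = Signal("example_position.h5")
# raw data as a dataframe, one row per cell and one column per time point
raw = signal.get_raw("extraction/GFP/max/median")
# the same dataset with any recorded merges and picks applied
processed = signal["extraction/GFP/max/median"]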
......@@ -15,9 +15,10 @@ from agora.io.bridge import BridgeH5
#################### Dynamic version ##################################
def load_attributes(file: str, group="/"):
def load_meta(file: str, group="/"):
"""
Load the metadata from an h5 file and convert to a dictionary, including the "parameters" field which is stored as YAML.
Load the metadata from an h5 file and convert to a dictionary, including
the "parameters" field which is stored as YAML.
Parameters
----------
......@@ -26,8 +27,9 @@ def load_attributes(file: str, group="/"):
group: str, optional
The group in the h5 file from which to read the data
"""
# load the metadata, stored as attributes, from the h5 file and return as a dictionary
# load the metadata, stored as attributes, from the h5 file
with h5py.File(file, "r") as f:
# return as a dict
meta = dict(f[group].attrs.items())
if "parameters" in meta:
# convert from yaml format into dict
......@@ -51,7 +53,7 @@ class DynamicWriter:
self.file = file
# the metadata is stored as attributes in the h5 file
if Path(file).exists():
self.metadata = load_attributes(file)
self.metadata = load_meta(file)
def _log(self, message: str, level: str = "warn"):
# Log messages in the corresponding level
......
......@@ -9,6 +9,152 @@ This can be:
import numpy as np
import typing as t
# data type to link together trap and cell ids
i_dtype = {"names": ["trap_id", "cell_id"], "formats": [np.int64, np.int64]}
def validate_lineage(
lineage: np.ndarray, indices: np.ndarray, how: str = "families"
):
"""
Identify mother-bud pairs that exist both in lineage and a Signal's
indices.
We expect the lineage information to be unique: a bud should not have
two mothers.
Parameters
----------
lineage : np.ndarray
2D array of lineage associations where columns are
(trap, mother, daughter)
or
a 3D array, which is an array of 2 X 2 arrays comprising
[[trap_id, mother_label], [trap_id, daughter_label]].
indices : np.ndarray
A 2D array of cell indices from a Signal, (trap_id, cell_label).
This array should not include mother_label.
how: str
If "mothers", matches indicate mothers from mother-bud pairs;
If "daughters", matches indicate daughters from mother-bud pairs;
If "families", matches indicate mothers and daughters in mother-bud pairs.
Returns
-------
valid_lineage: boolean np.ndarray
1D array indicating matched elements in lineage.
valid_indices: boolean np.ndarray
1D array indicating matched elements in indices.
Examples
--------
>>> import numpy as np
>>> from agora.utils.indexing import validate_lineage
>>> lineage = np.array([ [[0, 1], [0, 3]], [[0, 1], [0, 4]], [[0, 1], [0, 6]], [[0, 4], [0, 7]] ])
>>> indices = np.array([ [0, 1], [0, 2], [0, 3]])
>>> valid_lineage, valid_indices = validate_lineage(lineage, indices)
>>> print(valid_lineage)
array([ True, False, False, False])
>>> print(valid_indices)
array([ True, False, True])
and
>>> lineage = np.array([[[0,3], [0,1]], [[0,2], [0,4]]])
>>> indices = np.array([[0,1], [0,2], [0,3]])
>>> valid_lineage, valid_indices = validate_lineage(lineage, indices)
>>> print(valid_lineage)
array([ True, False])
>>> print(valid_indices)
array([ True, False, True])
"""
if lineage.ndim == 2:
# [trap, mother, daughter] becomes [[trap, mother], [trap, daughter]]
lineage = _assoc_indices_to_3d(lineage)
if how == "mothers":
c_index = 0
elif how == "daughters":
c_index = 1
# find valid lineage
valid_lineages = index_isin(lineage, indices)
if how == "families":
# both mother and bud must be in indices
valid_lineage = valid_lineages.all(axis=1)
else:
valid_lineage = valid_lineages[:, c_index, :]
flat_valid_lineage = valid_lineage.flatten()
# find valid indices
selected_lineages = lineage[flat_valid_lineage, ...]
if how == "families":
# select only pairs of mother and bud indices
valid_indices = index_isin(indices, selected_lineages)
else:
valid_indices = index_isin(indices, selected_lineages[:, c_index, :])
flat_valid_indices = valid_indices.flatten()
if (
indices[flat_valid_indices, :].size
!= np.unique(
lineage[flat_valid_lineage, :].reshape(-1, 2), axis=0
).size
):
# all unique indices in valid_lineages should be in valid_indices
raise Exception(
"Error in validate_lineage: "
"lineage information is likely not unique."
)
return flat_valid_lineage, flat_valid_indices
def index_isin(x: np.ndarray, y: np.ndarray) -> np.ndarray:
"""
Find those elements of x that are in y.
Both arrays must be arrays of integer indices,
such as (trap_id, cell_id).
"""
x = np.ascontiguousarray(x, dtype=np.int64)
y = np.ascontiguousarray(y, dtype=np.int64)
xv = x.view(i_dtype)
inboth = np.intersect1d(xv, y.view(i_dtype))
x_bool = np.isin(xv, inboth)
return x_bool
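For example, a small sketch of the expected behaviour; the result has shape (n, 1), which is why callers flatten it.

x = np.array([[0, 1], [0, 2], [1, 3]])
y = np.array([[0, 2], [1, 3]])
index_isin(x, y)  # array([[False], [ True], [ True]])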
def _assoc_indices_to_3d(ndarray: np.ndarray):
"""
Convert the last column to a new row and repeat first column's values.
For example: [trap, mother, daughter] becomes
[[trap, mother], [trap, daughter]].
Assumes the input array has shape (N,3).
"""
result = ndarray
if len(ndarray) and ndarray.ndim > 1:
# faster indexing for single positions
if ndarray.shape[1] == 3:
result = np.transpose(
np.hstack((ndarray[:, [0]], ndarray)).reshape(-1, 2, 2),
axes=[0, 2, 1],
)
else:
# 20% slower but more general indexing
columns = np.arange(ndarray.shape[1])
result = np.stack(
(
ndarray[:, np.delete(columns, -1)],
ndarray[:, np.delete(columns, -2)],
),
axis=1,
)
return result
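For example, in the common (N, 3) case:

lineage_2d = np.array([[0, 1, 3], [2, 5, 6]])
_assoc_indices_to_3d(lineage_2d)
# array([[[0, 1], [0, 3]],
#        [[2, 5], [2, 6]]])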
###################################################################
def validate_association(
association: np.ndarray,
......@@ -104,38 +250,8 @@ def validate_association(
return valid_association, valid_indices
def _assoc_indices_to_3d(ndarray: np.ndarray):
"""
Convert the last column to a new row while repeating all previous indices.
This is useful when converting a signal multiindex before comparing association.
Assumes the input array has shape (N,3)
"""
result = ndarray
if len(ndarray) and ndarray.ndim > 1:
if ndarray.shape[1] == 3: # Faster indexing for single positions
result = np.transpose(
np.hstack((ndarray[:, [0]], ndarray)).reshape(-1, 2, 2),
axes=[0, 2, 1],
)
else: # 20% slower but more general indexing
columns = np.arange(ndarray.shape[1])
result = np.stack(
(
ndarray[:, np.delete(columns, -1)],
ndarray[:, np.delete(columns, -2)],
),
axis=1,
)
return result
def _3d_index_to_2d(array: np.ndarray):
"""
Opposite to _assoc_indices_to_3d.
"""
"""Revert _assoc_indices_to_3d."""
result = array
if len(array):
result = np.concatenate(
......
......@@ -86,16 +86,19 @@ def bidirectional_retainment_filter(
daughters_thresh: int = 7,
) -> pd.DataFrame:
"""
Retrieve families where mothers are present for more than a fraction of the experiment, and daughters for longer than some number of time-points.
Retrieve families where mothers are present for more than a fraction
of the experiment and daughters for longer than some number of
time-points.
Parameters
----------
df: pd.DataFrame
Data
mothers_thresh: float
Minimum fraction of experiment's total duration for which mothers must be present.
Minimum fraction of experiment's total duration for which mothers
must be present.
daughters_thresh: int
Minimum number of time points for which daughters must be observed
Minimum number of time points for which daughters must be observed.
"""
# daughters
all_daughters = df.loc[df.index.get_level_values("mother_label") > 0]
......@@ -170,6 +173,7 @@ def slices_from_spans(spans: t.Tuple[int], df: pd.DataFrame) -> t.List[slice]:
def drop_mother_label(index: pd.MultiIndex) -> np.ndarray:
"""Remove mother_label level from a MultiIndex."""
no_mother_label = index
if "mother_label" in index.names:
no_mother_label = index.droplevel("mother_label")
......
#!/usr/bin/env python3
import re
import typing as t
import numpy as np
import pandas as pd
from agora.io.bridge import groupsort
from itertools import groupby
def mb_array_to_dict(mb_array: np.ndarray):
......@@ -19,4 +15,3 @@ def mb_array_to_dict(mb_array: np.ndarray):
for trap, mo_da in groupsort(mb_array).items()
for mo, daughters in groupsort(mo_da).items()
}
......@@ -3,90 +3,161 @@
Functions to efficiently merge rows in DataFrames.
"""
import typing as t
from copy import copy
import numpy as np
import pandas as pd
from utils_find_1st import cmp_larger, find_1st
from agora.utils.indexing import compare_indices, validate_association
from agora.utils.indexing import index_isin
def group_merges(merges: np.ndarray) -> t.List[t.Tuple]:
"""
Convert merges into a list of merges for traps requiring multiple
merges and then for traps requiring single merges.
"""
left_tracks = merges[:, 0]
right_tracks = merges[:, 1]
# find traps requiring multiple merges
linr = merges[index_isin(left_tracks, right_tracks).flatten(), :]
rinl = merges[index_isin(right_tracks, left_tracks).flatten(), :]
# make unique and order merges for each trap
multi_merge = np.unique(np.concatenate((linr, rinl)), axis=0)
# find traps requiring a single merge
single_merge = merges[
~index_isin(merges, multi_merge).all(axis=1).flatten(), :
]
# convert to lists of arrays
single_merge_list = [[sm] for sm in single_merge]
multi_merge_list = [
multi_merge[multi_merge[:, 0, 0] == trap_id, ...]
for trap_id in np.unique(multi_merge[:, 0, 0])
]
res = [*multi_merge_list, *single_merge_list]
return res
def merge_lineage(
lineage: np.ndarray, merges: np.ndarray
) -> (np.ndarray, np.ndarray):
"""
Use merges to update lineage information.
Check if merging causes any buds to have multiple mothers and discard
those incorrect merges.
Return updated lineage and merge arrays.
"""
flat_lineage = lineage.reshape(-1, 2)
bud_mother_dict = {
tuple(bud): mother for bud, mother in zip(lineage[:, 1], lineage[:, 0])
}
left_tracks = merges[:, 0]
# find left tracks that are in lineages
valid_lineages = index_isin(flat_lineage, left_tracks).flatten()
# group into multi- and then single merges
grouped_merges = group_merges(merges)
# perform merges
if valid_lineages.any():
# indices of each left track -> indices of rightmost right track
replacement_dict = {
tuple(contig_pair[0]): merge[-1][1]
for merge in grouped_merges
for contig_pair in merge
}
# if both key and value are buds, they must have the same mother
buds = lineage[:, 1]
incorrect_merges = [
key
for key in replacement_dict
if np.any(index_isin(buds, replacement_dict[key]).flatten())
and np.any(index_isin(buds, key).flatten())
and not np.array_equal(
bud_mother_dict[key],
bud_mother_dict[tuple(replacement_dict[key])],
)
]
if incorrect_merges:
# reassign incorrect merges so that they have no effect
for key in incorrect_merges:
replacement_dict[key] = key
# find only correct merges
new_merges = merges[
~index_isin(
merges[:, 0], np.array(incorrect_merges)
).flatten(),
...,
]
else:
new_merges = merges
# correct lineage information
# replace mother or bud index with index of rightmost track
flat_lineage[valid_lineages] = [
replacement_dict[tuple(index)]
for index in flat_lineage[valid_lineages]
]
else:
new_merges = merges
# reverse flattening
new_lineage = flat_lineage.reshape(-1, 2, 2)
# remove any duplicates
new_lineage = np.unique(new_lineage, axis=0)
return new_lineage, new_merges
def apply_merges(data: pd.DataFrame, merges: np.ndarray):
"""Split data in two, one subset for rows relevant for merging and one
without them. It uses an array of source tracklets and target tracklets
to efficiently merge them.
"""
Generate a new data frame containing merged tracks.
Parameters
----------
data : pd.DataFrame
Input DataFrame.
A Signal data frame.
merges : np.ndarray
3-D ndarray where dimensions are (X,2,2): nmerges, source-target
pair and single-cell identifiers, respectively.
Examples
--------
FIXME: Add docs.
An array of pairs of (trap, cell) indices to merge.
"""
indices = data.index
if "mother_label" in indices.names:
indices = indices.droplevel("mother_label")
valid_merges, indices = validate_association(
merges, np.array(list(indices))
)
# Assign non-merged
merged = data.loc[~indices]
# Implement the merges and drop source rows.
# TODO Use matrices to perform merges in batch
# for ecficiency
indices = np.array(list(indices))
# merges in the data frame's indices
valid_merges = index_isin(merges, indices).all(axis=1).flatten()
# corresponding indices for the data frame in merges
selected_merges = merges[valid_merges, ...]
valid_indices = index_isin(indices, selected_merges).flatten()
# data not requiring merging
merged = data.loc[~valid_indices]
# merge tracks
if valid_merges.any():
to_merge = data.loc[indices]
targets, sources = zip(*merges[valid_merges])
for source, target in zip(sources, targets):
target = tuple(target)
to_merge.loc[target] = join_tracks_pair(
to_merge.loc[target].values,
to_merge.loc[tuple(source)].values,
to_merge = data.loc[valid_indices].copy()
left_indices = merges[:, 0]
right_indices = merges[:, 1]
# join left track with right track
for left_index, right_index in zip(left_indices, right_indices):
to_merge.loc[tuple(left_index)] = join_two_tracks(
to_merge.loc[tuple(left_index)].values,
to_merge.loc[tuple(right_index)].values,
)
to_merge.drop(map(tuple, sources), inplace=True)
# drop indices for right tracks
to_merge.drop(map(tuple, right_indices), inplace=True)
# add to data not requiring merges
merged = pd.concat((merged, to_merge), names=data.index.names)
return merged
def join_tracks_pair(target: np.ndarray, source: np.ndarray) -> np.ndarray:
"""
Join two tracks and return the new value of the target.
"""
target_copy = target
end = find_1st(target_copy[::-1], 0, cmp_larger)
target_copy[-end:] = source[-end:]
return target_copy
def group_merges(merges: np.ndarray) -> t.List[t.Tuple]:
# Return a list where the cell is present as source and target
# (multimerges)
sources_targets = compare_indices(merges[:, 0, :], merges[:, 1, :])
is_multimerge = sources_targets.any(axis=0) | sources_targets.any(axis=1)
is_monomerge = ~is_multimerge
multimerge_subsets = union_find(zip(*np.where(sources_targets)))
merge_groups = [merges[np.array(tuple(x))] for x in multimerge_subsets]
def join_two_tracks(
left_track: np.ndarray, right_track: np.ndarray
) -> np.ndarray:
"""Join two tracks and return the new one."""
new_track = left_track.copy()
# find last positive element by inverting track
end = find_1st(left_track[::-1], 0, cmp_larger)
# merge tracks into one
new_track[-end:] = right_track[-end:]
return new_track
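For example, a sketch of the intended behaviour, with the right track filling the trailing zeros of the left track:

left = np.array([1.0, 2.0, 3.0, 0.0, 0.0])
right = np.array([0.0, 0.0, 0.0, 4.0, 5.0])
join_two_tracks(left, right)  # array([1., 2., 3., 4., 5.])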
sorted_merges = list(map(sort_association, merge_groups))
# Ensure that source and target are at the edges
return [
*sorted_merges,
*[[event] for event in merges[is_monomerge]],
]
##################################################################
def union_find(lsts):
......@@ -120,27 +191,3 @@ def sort_association(array: np.ndarray):
[res.append(x) for x in np.flip(order).flatten() if x not in res]
sorted_array = array[np.array(res)]
return sorted_array
def merge_association(
association: np.ndarray, merges: np.ndarray
) -> np.ndarray:
grouped_merges = group_merges(merges)
flat_indices = association.reshape(-1, 2)
comparison_mat = compare_indices(merges[:, 0], flat_indices)
valid_indices = comparison_mat.any(axis=0)
if valid_indices.any(): # Where valid, perform transformation
replacement_d = {}
for dataset in grouped_merges:
for k in dataset:
replacement_d[tuple(k[0])] = dataset[-1][1]
flat_indices[valid_indices] = [
replacement_d[tuple(i)] for i in flat_indices[valid_indices]
]
merged_indices = flat_indices.reshape(-1, 2, 2)
return merged_indices
# parameters to stop the pipeline when exceeded
earlystop = dict(
min_tp=100,
thresh_pos_clogged=0.4,
thresh_trap_ncells=8,
thresh_trap_area=0.9,
ntps_to_eval=5,
)
# imaging properties of the microscope
imaging_specifications = {
"pixel_size": 0.236,
"z_size": 0.6,
"spacing": 0.6,
}
# possible imaging channels
possible_imaging_channels = [
"Citrine",
"GFP",
"GFPFast",
"mCherry",
"Flavin",
"Citrine",
"mKO2",
"Cy5",
"pHluorin405",
"pHluorin488",
]
# functions to apply to the fluorescence of each cell
fluorescence_functions = [
"mean",
"median",
"std",
"imBackground",
"max5px_median",
]
# default fraction of time a cell must be in the experiment to be kept by Signal
signal_retained_cutoff = 0.8
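# These module-level values are read directly elsewhere; the pipeline, for
# example, uses the early-stopping thresholds as its defaults:
#
# >>> import aliby.global_parameters as global_parameters
# >>> global_parameters.earlystop["thresh_pos_clogged"]
# 0.4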
......@@ -54,7 +54,7 @@ class DatasetLocalABC(ABC):
Abstract Base class to find local files, either OME-XML or raw images.
"""
_valid_suffixes = ("tiff", "png", "zarr")
_valid_suffixes = ("tiff", "png", "zarr", "tif")
_valid_meta_suffixes = ("txt", "log")
def __init__(self, dpath: t.Union[str, Path], *args, **kwargs):
......
......@@ -30,14 +30,14 @@ from agora.io.metadata import dir_to_meta, dispatch_metadata_parser
def get_examples_dir():
"""Get examples directory which stores dummy image for tiler"""
"""Get examples directory that stores dummy image for tiler."""
return files("aliby").parent.parent / "examples" / "tiler"
def instantiate_image(
source: t.Union[str, int, t.Dict[str, str], Path], **kwargs
):
"""Wrapper to instatiate the appropiate image
"""Wrapper to instantiate the appropriate image
Parameters
----------
......@@ -55,26 +55,26 @@ def instantiate_image(
def dispatch_image(source: t.Union[str, int, t.Dict[str, str], Path]):
"""
Wrapper to pick the appropiate Image class depending on the source of data.
"""
"""Pick the appropriate Image class depending on the source of data."""
if isinstance(source, (int, np.int64)):
from aliby.io.omero import Image
instatiator = Image
instantiator = Image
elif isinstance(source, dict) or (
isinstance(source, (str, Path)) and Path(source).is_dir()
):
if Path(source).suffix == ".zarr":
instatiator = ImageZarr
instantiator = ImageZarr
else:
instatiator = ImageDir
instantiator = ImageDir
elif isinstance(source, Path) and source.is_file():
# a Path pointing to a file: assume a local OME-TIFF image
instantiator = ImageLocalOME
elif isinstance(source, str) and Path(source).is_file():
instatiator = ImageLocalOME
instantiator = ImageLocalOME
else:
raise Exception(f"Invalid data source at {source}")
return instatiator
return instantiator
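# A minimal sketch of how the dispatcher is used (the path below is an
# assumed example): the returned class is instantiated and used as a
# context manager, mirroring the pipeline's usage.
#
# image_class = dispatch_image("experiment/pos001.zarr")
# with image_class("experiment/pos001.zarr") as image:
#     data = image.data  # lazy array ordered as (t, c, z, y, x) by default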
class BaseLocalImage(ABC):
......@@ -82,6 +82,7 @@ class BaseLocalImage(ABC):
Base Image class to set path and provide context management method.
"""
# default image order
_default_dimorder = "tczyx"
def __init__(self, path: t.Union[str, Path]):
......@@ -98,8 +99,7 @@ class BaseLocalImage(ABC):
return False
def rechunk_data(self, img):
# Format image using x and y size from metadata.
"""Format image using x and y size from metadata."""
self._rechunked_img = da.rechunk(
img,
chunks=(
......@@ -145,16 +145,16 @@ class ImageLocalOME(BaseLocalImage):
in which a multidimensional tiff image contains the metadata.
"""
def __init__(self, path: str, dimorder=None):
def __init__(self, path: str, dimorder=None, **kwargs):
super().__init__(path)
self._id = str(path)
self.set_meta(str(path))
def set_meta(self):
def set_meta(self, path):
meta = dict()
try:
with TiffFile(path) as f:
self._meta = xmltodict.parse(f.ome_metadata)["OME"]
for dim in self.dimorder:
meta["size_" + dim.lower()] = int(
self._meta["Image"]["Pixels"]["@Size" + dim]
......@@ -165,21 +165,19 @@ class ImageLocalOME(BaseLocalImage):
]
meta["name"] = self._meta["Image"]["@Name"]
meta["type"] = self._meta["Image"]["Pixels"]["@Type"]
except Exception as e: # Images not in OMEXML
except Exception as e:
# images not in OMEXML
print("Warning:Metadata not found: {}".format(e))
print(
f"Warning: No dimensional info provided. Assuming {self._default_dimorder}"
"Warning: No dimensional info provided. "
f"Assuming {self._default_dimorder}"
)
# Mark non-existent dimensions for padding
# mark non-existent dimensions for padding
self.base = self._default_dimorder
# self.ids = [self.index(i) for i in dimorder]
self._dimorder = base
self._dimorder = self.base
self._meta = meta
# self._meta["name"] = Path(path).name.split(".")[0]
@property
def name(self):
......@@ -246,7 +244,7 @@ class ImageDir(BaseLocalImage):
It inherits from BaseLocalImage so we only override methods that are critical.
Assumptions:
- One folders per position.
- One folder per position.
- Images are flat.
- Channel, Time, z-stack and the others are determined by filenames.
- Provides the dimensional order as set in the filenames, or expects the order during instantiation.
......@@ -318,7 +316,7 @@ class ImageZarr(BaseLocalImage):
print(f"Could not add size info to metadata: {e}")
def get_data_lazy(self) -> da.Array:
"""Return 5D dask array. For lazy-loading local multidimensional zarr files"""
"""Return 5D dask array for lazy-loading local multidimensional zarr files."""
return self._img
def add_size_to_meta(self):
......
......@@ -131,7 +131,6 @@ class BridgeOmero:
FIXME: Add docs.
"""
# metadata = load_attributes(filepath)
bridge = BridgeH5(filepath)
meta = safe_load(bridge.meta_h5["parameters"])["general"]
server_info = {k: meta[k] for k in ("host", "username", "password")}
......@@ -268,7 +267,6 @@ class Dataset(BridgeOmero):
FIXME: Add docs.
"""
# metadata = load_attributes(filepath)
bridge = BridgeH5(filepath)
dataset_keys = ("omero_id", "omero_id,", "dataset_id")
for k in dataset_keys:
......@@ -301,21 +299,21 @@ class Image(BridgeOmero):
cls,
filepath: t.Union[str, Path],
):
"""Instatiate Image from a hdf5 file.
"""
Instantiate Image from a h5 file.
Parameters
----------
cls : Image
Image class
filepath : t.Union[str, Path]
Location of hdf5 file.
Location of h5 file.
Examples
--------
FIXME: Add docs.
"""
# metadata = load_attributes(filepath)
bridge = BridgeH5(filepath)
image_id = bridge.meta_h5["image_id"]
return cls(image_id, **cls.server_info_from_h5(filepath))
......
......@@ -7,22 +7,19 @@ import typing as t
from copy import copy
from importlib.metadata import version
from pathlib import Path
from pprint import pprint
import h5py
import numpy as np
import pandas as pd
from pathos.multiprocessing import Pool
from tqdm import tqdm
import aliby.global_parameters as global_parameters
from agora.abc import ParametersABC, ProcessABC
from agora.io.metadata import MetaData, parse_logfiles
from agora.io.reader import StateReader
from agora.io.signal import Signal
from agora.io.writer import (
LinearBabyWriter,
StateWriter,
TilerWriter,
)
from agora.io.writer import LinearBabyWriter, StateWriter, TilerWriter
from aliby.baby_client import BabyParameters, BabyRunner
from aliby.haystack import initialise_tf
from aliby.io.dataset import dispatch_dataset
......@@ -32,6 +29,10 @@ from extraction.core.extractor import Extractor, ExtractorParameters
from extraction.core.functions.defaults import exparams_from_meta
from postprocessor.core.processor import PostProcessor, PostProcessorParameters
# stop warnings from TensorFlow
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "2"
logging.getLogger("tensorflow").setLevel(logging.ERROR)
class PipelineParameters(ParametersABC):
"""Define parameters for the steps of the pipeline."""
......@@ -39,7 +40,12 @@ class PipelineParameters(ParametersABC):
_pool_index = None
def __init__(
self, general, tiler, baby, extraction, postprocessing, reporting
self,
general,
tiler,
baby,
extraction,
postprocessing,
):
"""Initialise, but called by a class method - not directly."""
self.general = general
......@@ -47,7 +53,6 @@ class PipelineParameters(ParametersABC):
self.baby = baby
self.extraction = extraction
self.postprocessing = postprocessing
self.reporting = reporting
@classmethod
def default(
......@@ -76,16 +81,15 @@ class PipelineParameters(ParametersABC):
postprocessing: dict (optional)
Parameters for post-processing.
"""
expt_id = general.get("expt_id", 19993)
if isinstance(expt_id, Path):
assert expt_id.exists()
expt_id = str(expt_id)
general["expt_id"] = expt_id
if (
isinstance(general["expt_id"], Path)
and general["expt_id"].exists()
):
expt_id = str(general["expt_id"])
else:
expt_id = general["expt_id"]
directory = Path(general["directory"])
# get log files, either locally or via OMERO
# get metadata from log files either locally or via OMERO
with dispatch_dataset(
expt_id,
**{k: general.get(k) for k in ("host", "username", "password")},
......@@ -107,7 +111,6 @@ class PipelineParameters(ParametersABC):
}
# set minimal metadata
meta_d = minimal_default_meta
# define default values for general parameters
tps = meta_d.get("ntps", 2000)
defaults = {
......@@ -117,19 +120,12 @@ class PipelineParameters(ParametersABC):
tps=tps,
directory=str(directory.parent),
filter="",
earlystop=dict(
min_tp=100,
thresh_pos_clogged=0.4,
thresh_trap_ncells=8,
thresh_trap_area=0.9,
ntps_to_eval=5,
),
earlystop=global_parameters.earlystop,
logfile_level="INFO",
use_explog=True,
)
}
# update default values using inputs
# update default values for general using inputs
for k, v in general.items():
if k not in defaults["general"]:
defaults["general"][k] = v
......@@ -138,11 +134,9 @@ class PipelineParameters(ParametersABC):
defaults["general"][k][k2] = v2
else:
defaults["general"][k] = v
# define defaults and update with any inputs
# default Tiler parameters
defaults["tiler"] = TilerParameters.default(**tiler).to_dict()
# Generate a backup channel, for when logfile meta is available
# generate a backup channel for when logfile meta is available
# but not image metadata.
backup_ref_channel = None
if "channels" in meta_d and isinstance(
......@@ -152,20 +146,18 @@ class PipelineParameters(ParametersABC):
defaults["tiler"]["ref_channel"]
)
defaults["tiler"]["backup_ref_channel"] = backup_ref_channel
# default BABY parameters
defaults["baby"] = BabyParameters.default(**baby).to_dict()
defaults["extraction"] = (
exparams_from_meta(meta_d)
or BabyParameters.default(**extraction).to_dict()
)
# default Extraction parameters
defaults["extraction"] = exparams_from_meta(meta_d)
# default PostProcessing parameters
defaults["postprocessing"] = PostProcessorParameters.default(
**postprocessing
).to_dict()
defaults["reporting"] = {}
return cls(**{k: v for k, v in defaults.items()})
def load_logs(self):
"""Load and parse log files."""
parsed_flattened = parse_logfiles(self.log_dir)
return parsed_flattened
......@@ -187,7 +179,7 @@ class Pipeline(ProcessABC):
"postprocessing",
]
# Specify the group in the h5 files written by each step
# specify the group in the h5 files written by each step
writer_groups = {
"tiler": ["trap_info"],
"baby": ["cell_info"],
......@@ -228,12 +220,6 @@ class Pipeline(ProcessABC):
fh.setFormatter(formatter)
logger.addHandler(fh)
@classmethod
def from_yaml(cls, fpath):
# This is just a convenience function, think before implementing
# for other processes
return cls(parameters=PipelineParameters.from_yaml(fpath))
@classmethod
def from_folder(cls, dir_path):
"""
......@@ -304,11 +290,16 @@ class Pipeline(ProcessABC):
def run(self):
"""Run separate pipelines for all positions in an experiment."""
# general information in config
# display configuration
config = self.parameters.to_dict()
for step in config:
print("\n---\n" + step + "\n---")
pprint(config[step])
print()
# extract from configuration
expt_id = config["general"]["id"]
distributed = config["general"]["distributed"]
pos_filter = config["general"]["filter"]
position_filter = config["general"]["filter"]
root_dir = Path(config["general"]["directory"])
self.server_info = {
k: config["general"].get(k)
......@@ -320,56 +311,76 @@ class Pipeline(ProcessABC):
)
# get log files, either locally or via OMERO
with dispatcher as conn:
image_ids = conn.get_images()
position_ids = conn.get_images()
directory = self.store or root_dir / conn.unique_name
if not directory.exists():
directory.mkdir(parents=True)
# download logs to use for metadata
# get logs to use for metadata
conn.cache_logs(directory)
print("Positions available:")
for i, pos in enumerate(position_ids.keys()):
print("\t" + f"{i}: " + pos.split(".")[0])
# update configuration
self.parameters.general["directory"] = str(directory)
config["general"]["directory"] = directory
self.setLogger(directory)
# pick particular images if desired
if pos_filter is not None:
if isinstance(pos_filter, list):
image_ids = {
# pick particular positions if desired
if position_filter is not None:
if isinstance(position_filter, list):
position_ids = {
k: v
for filt in pos_filter
for k, v in self.apply_filter(image_ids, filt).items()
for filt in position_filter
for k, v in self.apply_filter(position_ids, filt).items()
}
else:
image_ids = self.apply_filter(image_ids, pos_filter)
assert len(image_ids), "No images to segment"
# create pipelines
position_ids = self.apply_filter(position_ids, position_filter)
if not len(position_ids):
raise Exception("No images to segment.")
else:
print("\nPositions selected:")
for pos in position_ids:
print("\t" + pos.split(".")[0])
# create and run pipelines
if distributed != 0:
# multiple cores
with Pool(distributed) as p:
results = p.map(
lambda x: self.run_one_position(*x),
[(k, i) for i, k in enumerate(image_ids.items())],
[
(position_id, i)
for i, position_id in enumerate(position_ids.items())
],
)
else:
# single core
results = []
for k, v in tqdm(image_ids.items()):
r = self.run_one_position((k, v), 1)
results.append(r)
results = [
self.run_one_position((position_id, position_id_path), 1)
for position_id, position_id_path in tqdm(position_ids.items())
]
return results
def apply_filter(self, image_ids: dict, filt: int or str):
"""Select images by picking a particular one or by using a regular expression to parse their file names."""
if isinstance(filt, str):
# pick images using a regular expression
image_ids = {
k: v for k, v in image_ids.items() if re.search(filt, k)
def apply_filter(self, position_ids: dict, position_filter: int or str):
"""
Select positions.
Either pick a particular position or use a regular expression
to parse their file names.
"""
if isinstance(position_filter, str):
# pick positions using a regular expression
position_ids = {
k: v
for k, v in position_ids.items()
if re.search(position_filter, k)
}
elif isinstance(filt, int):
# pick the filt'th image
image_ids = {
k: v for i, (k, v) in enumerate(image_ids.items()) if i == filt
elif isinstance(position_filter, int):
# pick a particular position
position_ids = {
k: v
for i, (k, v) in enumerate(position_ids.items())
if i == position_filter
}
return image_ids
return position_ids
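# An illustrative sketch (the position names are invented): a string is
# treated as a regular expression and keeps every matching position, whereas
# an integer keeps only the position at that index.
#
# position_ids = {"pos001.omero": 101, "pos002.omero": 102}
# self.apply_filter(position_ids, "pos00")  # keeps both positions
# self.apply_filter(position_ids, 1)        # keeps only "pos002.omero"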
def run_one_position(
self,
......@@ -379,120 +390,101 @@ class Pipeline(ProcessABC):
"""Set up and run a pipeline for one position."""
self._pool_index = index
name, image_id = name_image_id
# session and filename are defined by calling setup_pipeline.
# can they be deleted here?
# session is defined by calling setup_pipeline.
# can it be deleted here?
session = None
filename = None
#
run_kwargs = {"extraction": {"labels": None, "masks": None}}
run_kwargs = {"extraction": {"cell_labels": None, "masks": None}}
try:
(
filename,
meta,
config,
process_from,
tps,
steps,
earlystop,
session,
trackers_state,
) = self._setup_pipeline(image_id)
pipe, session = self.setup_pipeline(image_id, name)
loaded_writers = {
name: writer(filename)
name: writer(pipe["filename"])
for k in self.step_sequence
if k in self.writers
for name, writer in self.writers[k]
}
writer_ow_kwargs = {
writer_overwrite_kwargs = {
"state": loaded_writers["state"].datatypes.keys(),
"baby": ["mother_assign"],
}
# START PIPELINE
frac_clogged_traps = 0.0
min_process_from = min(process_from.values())
min_process_from = min(pipe["process_from"].values())
with dispatch_image(image_id)(
image_id, **self.server_info
) as image:
# initialise steps
if "tiler" not in steps:
steps["tiler"] = Tiler.from_image(
image, TilerParameters.from_dict(config["tiler"])
if "tiler" not in pipe["steps"]:
pipe["config"]["tiler"]["position_name"] = name.split(".")[
0
]
pipe["steps"]["tiler"] = Tiler.from_image(
image,
TilerParameters.from_dict(pipe["config"]["tiler"]),
)
if process_from["baby"] < tps:
if pipe["process_from"]["baby"] < pipe["tps"]:
session = initialise_tf(2)
steps["baby"] = BabyRunner.from_tiler(
BabyParameters.from_dict(config["baby"]),
steps["tiler"],
)
if trackers_state:
steps["baby"].crawler.tracker_states = trackers_state
# limit extraction parameters using the available channels in tiler
if process_from["extraction"] < tps:
# TODO Move this parameter validation into Extractor
av_channels = set((*steps["tiler"].channels, "general"))
config["extraction"]["tree"] = {
k: v
for k, v in config["extraction"]["tree"].items()
if k in av_channels
}
config["extraction"]["sub_bg"] = av_channels.intersection(
config["extraction"]["sub_bg"]
)
av_channels_wsub = av_channels.union(
[c + "_bgsub" for c in config["extraction"]["sub_bg"]]
pipe["steps"]["baby"] = BabyRunner.from_tiler(
BabyParameters.from_dict(pipe["config"]["baby"]),
pipe["steps"]["tiler"],
)
tmp = copy(config["extraction"]["multichannel_ops"])
for op, (input_ch, _, _) in tmp.items():
if not set(input_ch).issubset(av_channels_wsub):
del config["extraction"]["multichannel_ops"][op]
if pipe["trackers_state"]:
pipe["steps"]["baby"].crawler.tracker_states = pipe[
"trackers_state"
]
if pipe["process_from"]["extraction"] < pipe["tps"]:
exparams = ExtractorParameters.from_dict(
config["extraction"]
pipe["config"]["extraction"]
)
steps["extraction"] = Extractor.from_tiler(
exparams, store=filename, tiler=steps["tiler"]
pipe["steps"]["extraction"] = Extractor.from_tiler(
exparams,
store=pipe["filename"],
tiler=pipe["steps"]["tiler"],
)
# set up progress meter
# initiate progress bar
pbar = tqdm(
range(min_process_from, tps),
range(min_process_from, pipe["tps"]),
desc=image.name,
initial=min_process_from,
total=tps,
total=pipe["tps"],
)
# run through time points
for i in pbar:
if (
frac_clogged_traps
< earlystop["thresh_pos_clogged"]
or i < earlystop["min_tp"]
< pipe["earlystop"]["thresh_pos_clogged"]
or i < pipe["earlystop"]["min_tp"]
):
# run through steps
for step in self.pipeline_steps:
if i >= process_from[step]:
result = steps[step].run_tp(
if i >= pipe["process_from"][step]:
# perform step
result = pipe["steps"][step].run_tp(
i, **run_kwargs.get(step, {})
)
# write to h5 file using writers
# extractor writes to h5 itself
if step in loaded_writers:
loaded_writers[step].write(
data=result,
overwrite=writer_ow_kwargs.get(
overwrite=writer_overwrite_kwargs.get(
step, []
),
tp=i,
meta={"last_processed": i},
)
# perform step
# clean up
if (
step == "tiler"
and i == min_process_from
):
logging.getLogger("aliby").info(
f"Found {steps['tiler'].n_tiles} traps in {image.name}"
f"Found {pipe['steps']['tiler'].no_tiles} traps in {image.name}"
)
elif step == "baby":
# write state and pass info to Extractor
# write state
loaded_writers["state"].write(
data=steps[
data=pipe["steps"][
step
].crawler.tracker_states,
overwrite=loaded_writers[
......@@ -501,12 +493,14 @@ class Pipeline(ProcessABC):
tp=i,
)
elif step == "extraction":
# remove mask/label after extraction
for k in ["masks", "labels"]:
# remove masks and labels after extraction
for k in ["masks", "cell_labels"]:
run_kwargs[step][k] = None
# check and report clogging
frac_clogged_traps = self.check_earlystop(
filename, earlystop, steps["tiler"].tile_size
pipe["filename"],
pipe["earlystop"],
pipe["steps"]["tiler"].tile_size,
)
if frac_clogged_traps > 0.3:
self._log(
......@@ -519,18 +513,17 @@ class Pipeline(ProcessABC):
self._log(
f"{name}:Stopped early at time {i} with {frac_clogged_traps} clogged traps"
)
meta.add_fields({"end_status": "Clogged"})
pipe["meta"].add_fields({"end_status": "Clogged"})
break
meta.add_fields({"last_processed": i})
pipe["meta"].add_fields({"last_processed": i})
pipe["meta"].add_fields({"end_status": "Success"})
# run post-processing
meta.add_fields({"end_status": "Success"})
post_proc_params = PostProcessorParameters.from_dict(
config["postprocessing"]
pipe["config"]["postprocessing"]
)
PostProcessor(filename, post_proc_params).run()
PostProcessor(pipe["filename"], post_proc_params).run()
self._log("Analysis finished successfully.", "info")
return 1
except Exception as e:
# catch bugs during setup or run time
logging.exception(
......@@ -541,88 +534,12 @@ class Pipeline(ProcessABC):
traceback.print_exc()
raise e
finally:
_close_session(session)
@staticmethod
def check_earlystop(filename: str, es_parameters: dict, tile_size: int):
"""
Check recent time points for tiles with too many cells.
Returns the fraction of clogged tiles, where clogged tiles have
too many cells or too much of their area covered by cells.
Parameters
----------
filename: str
Name of h5 file.
es_parameters: dict
Parameters defining when early stopping should happen.
For example:
{'min_tp': 100,
'thresh_pos_clogged': 0.4,
'thresh_trap_ncells': 8,
'thresh_trap_area': 0.9,
'ntps_to_eval': 5}
tile_size: int
Size of tile.
"""
# get the area of the cells organised by trap and cell number
s = Signal(filename)
df = s.get_raw("/extraction/general/None/area")
# check the latest time points only
cells_used = df[
df.columns[-1 - es_parameters["ntps_to_eval"] : -1]
].dropna(how="all")
# find tiles with too many cells
traps_above_nthresh = (
cells_used.groupby("trap").count().apply(np.mean, axis=1)
> es_parameters["thresh_trap_ncells"]
)
# find tiles with cells covering too great a fraction of the tiles' area
traps_above_athresh = (
cells_used.groupby("trap").sum().apply(np.mean, axis=1)
/ tile_size**2
> es_parameters["thresh_trap_area"]
)
return (traps_above_nthresh & traps_above_athresh).mean()
close_session(session)
# FIXME: Remove this functionality. It used to be for
# older hdf5 file formats.
def _load_config_from_file(
def setup_pipeline(
self,
filename: Path,
process_from: t.Dict[str, int],
trackers_state: t.List,
overwrite: t.Dict[str, bool],
):
with h5py.File(filename, "r") as f:
for k in process_from.keys():
if not overwrite[k]:
process_from[k] = self.legacy_get_last_tp[k](f)
process_from[k] += 1
return process_from, trackers_state, overwrite
# FIXME: Remove this functionality. It used to be for
# older hdf5 file formats.
@staticmethod
def legacy_get_last_tp(step: str) -> t.Callable:
"""Get last time-point in different ways depending
on which step we are using
To support segmentation in aliby < v0.24
TODO Deprecate and replace with State method
"""
switch_case = {
"tiler": lambda f: f["trap_info/drifts"].shape[0] - 1,
"baby": lambda f: f["cell_info/timepoint"][-1],
"extraction": lambda f: f[
"extraction/general/None/area/timepoint"
][-1],
}
return switch_case[step]
def _setup_pipeline(
self, image_id: int
image_id: int,
name: str,
) -> t.Tuple[
Path,
MetaData,
......@@ -645,83 +562,100 @@ class Pipeline(ProcessABC):
Returns
-------
filename: str
Path to a h5 file to write to.
meta: object
agora.io.metadata.MetaData object
config: dict
Configuration parameters.
process_from: dict
Gives from which time point each step of the pipeline should start.
tps: int
Number of time points.
steps: dict
earlystop: dict
Parameters to check whether the pipeline should be stopped.
pipe: dict
With keys
filename: str
Path to a h5 file to write to.
meta: object
agora.io.metadata.MetaData object
config: dict
Configuration parameters.
process_from: dict
Gives time points from which each step of the
pipeline should start.
tps: int
Number of time points.
steps: dict
earlystop: dict
Parameters to check whether the pipeline should
be stopped.
trackers_state: list
States of any trackers from earlier runs.
session: None
trackers_state: list
States of any trackers from earlier runs.
"""
pipe = {}
config = self.parameters.to_dict()
# TODO Alan: Verify if session must be passed
session = None
earlystop = config["general"].get("earlystop", None)
process_from = {k: 0 for k in self.pipeline_steps}
steps = {}
pipe["earlystop"] = config["general"].get("earlystop", None)
pipe["process_from"] = {k: 0 for k in self.pipeline_steps}
pipe["steps"] = {}
# check overwriting
ow_id = config["general"].get("overwrite", 0)
ow = {step: True for step in self.step_sequence}
if ow_id and ow_id is not True:
ow = {
step: self.step_sequence.index(ow_id) < i
overwrite_id = config["general"].get("overwrite", 0)
overwrite = {step: True for step in self.step_sequence}
if overwrite_id and overwrite_id is not True:
overwrite = {
step: self.step_sequence.index(overwrite_id) < i
for i, step in enumerate(self.step_sequence, 1)
}
# Set up
# set up
directory = config["general"]["directory"]
trackers_state: t.List[np.ndarray] = []
pipe["trackers_state"] = []
with dispatch_image(image_id)(image_id, **self.server_info) as image:
filename = Path(f"{directory}/{image.name}.h5")
meta = MetaData(directory, filename)
from_start = True if np.any(ow.values()) else False
# remove existing file if overwriting
pipe["filename"] = Path(f"{directory}/{image.name}.h5")
# load metadata from h5 file
pipe["meta"] = MetaData(directory, pipe["filename"])
from_start = True if np.any(overwrite.values()) else False
# remove existing h5 file if overwriting
if (
from_start
and (
config["general"].get("overwrite", False)
or np.all(list(ow.values()))
or np.all(list(overwrite.values()))
)
and filename.exists()
and pipe["filename"].exists()
):
os.remove(filename)
os.remove(pipe["filename"])
# if the file exists with no previous segmentation use its tiler
if filename.exists():
if pipe["filename"].exists():
self._log("Result file exists.", "info")
if not ow["tiler"]:
steps["tiler"] = Tiler.from_hdf5(image, filename)
if not overwrite["tiler"]:
tiler_params_dict = TilerParameters.default().to_dict()
tiler_params_dict["position_name"] = name.split(".")[0]
tiler_params = TilerParameters.from_dict(tiler_params_dict)
pipe["steps"]["tiler"] = Tiler.from_h5(
image, pipe["filename"], tiler_params
)
try:
(
process_from,
trackers_state,
ow,
overwrite,
) = self._load_config_from_file(
filename, process_from, trackers_state, ow
pipe["filename"],
pipe["process_from"],
pipe["trackers_state"],
overwrite,
)
# get state array
trackers_state = (
pipe["trackers_state"] = (
[]
if ow["baby"]
else StateReader(filename).get_formatted_states()
if overwrite["baby"]
else StateReader(
pipe["filename"]
).get_formatted_states()
)
config["tiler"] = steps["tiler"].parameters.to_dict()
config["tiler"] = pipe["steps"][
"tiler"
].parameters.to_dict()
except Exception:
self._log(f"Overwriting tiling data")
self._log("Overwriting tiling data")
if config["general"]["use_explog"]:
meta.run()
pipe["meta"].run()
pipe["config"] = config
# add metadata not in the log file
meta.add_fields(
pipe["meta"].add_fields(
{
"aliby_version": version("aliby"),
"baby_version": version("aliby-baby"),
......@@ -734,20 +668,53 @@ class Pipeline(ProcessABC):
).to_yaml(),
}
)
tps = min(config["general"]["tps"], image.data.shape[0])
return (
filename,
meta,
config,
process_from,
tps,
steps,
earlystop,
session,
trackers_state,
)
pipe["tps"] = min(config["general"]["tps"], image.data.shape[0])
return pipe, session
@staticmethod
def check_earlystop(filename: str, es_parameters: dict, tile_size: int):
"""
Check recent time points for tiles with too many cells.
Returns the fraction of clogged tiles, where clogged tiles have
too many cells or too much of their area covered by cells.
Parameters
----------
filename: str
Name of h5 file.
es_parameters: dict
Parameters defining when early stopping should happen.
For example:
{'min_tp': 100,
'thresh_pos_clogged': 0.4,
'thresh_trap_ncells': 8,
'thresh_trap_area': 0.9,
'ntps_to_eval': 5}
tile_size: int
Size of tile.
"""
# get the area of the cells organised by trap and cell number
s = Signal(filename)
df = s.get_raw("/extraction/general/None/area")
# check the latest time points only
cells_used = df[
df.columns[-1 - es_parameters["ntps_to_eval"] : -1]
].dropna(how="all")
# find tiles with too many cells
traps_above_nthresh = (
cells_used.groupby("trap").count().apply(np.mean, axis=1)
> es_parameters["thresh_trap_ncells"]
)
# find tiles with cells covering too great a fraction of the tiles' area
traps_above_athresh = (
cells_used.groupby("trap").sum().apply(np.mean, axis=1)
/ tile_size**2
> es_parameters["thresh_trap_area"]
)
return (traps_above_nthresh & traps_above_athresh).mean()
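# An illustrative sketch (the file name is an assumed example): the pipeline
# calls this with the thresholds from global_parameters and the tiler's tile
# size, and stops early once the returned fraction exceeds
# earlystop["thresh_pos_clogged"].
#
# frac_clogged = Pipeline.check_earlystop(
#     "pos001.h5", global_parameters.earlystop, tile_size=117
# )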
def _close_session(session):
def close_session(session):
if session:
session.close()
"""
Tiler: Divides images into smaller tiles.
The tasks of the Tiler are selecting regions of interest, or tiles, of images - with one trap per tile, correcting for the drift of the microscope stage over time, and handling errors and bridging between the image data and Aliby’s image-processing steps.
The tasks of the Tiler are selecting regions of interest, or tiles, of
images - with one trap per tile, correcting for the drift of the microscope
stage over time, and handling errors and bridging between the image data
and Aliby’s image-processing steps.
Tiler subclasses deal with either network connections or local files.
To find tiles, we use a two-step process: we analyse the bright-field image to produce the template of a trap, and we fit this template to the image to find the tiles' centres.
To find tiles, we use a two-step process: we analyse the bright-field image
to produce the template of a trap, and we fit this template to the image to
find the tiles' centres.
We use texture-based segmentation (entropy) to split the image into foreground -- cells and traps -- and background, which we then identify with an Otsu filter. Two methods are used to produce a template trap from these regions: pick the trap with the smallest minor axis length and average over all validated traps.
We use texture-based segmentation (entropy) to split the image into
foreground -- cells and traps -- and background, which we then identify with
an Otsu filter. Two methods are used to produce a template trap from these
regions: pick the trap with the smallest minor axis length and average over
all validated traps.
A peak-identifying algorithm recovers the x and y-axis location of traps in the original image, and we choose the approach to template that identifies the most tiles.
A peak-identifying algorithm recovers the x and y-axis location of traps in
the original image, and we choose the approach to template that identifies
the most tiles.
The experiment is stored as an array with a standard indexing order of (Time, Channels, Z-stack, X, Y).
The experiment is stored as an array with a standard indexing order of
(Time, Channels, Z-stack, X, Y).
"""
import logging
import re
......@@ -25,28 +37,24 @@ import h5py
import numpy as np
from skimage.registration import phase_cross_correlation
import aliby.global_parameters as global_parameters
from agora.abc import ParametersABC, StepABC
from agora.io.writer import BridgeH5
from aliby.io.image import ImageDummy
from aliby.tile.traps import segment_traps
class Tile:
"""
Store a tile's location and size.
Checks to see if the tile should be padded.
Can export the tile either in OMERO or numpy formats.
"""
"""Store a tile's location and size."""
def __init__(self, centre, parent, size, max_size):
def __init__(self, centre, parent_class, size, max_size):
"""Initialise using a parent class."""
self.centre = centre
self.parent = parent # used to access drifts
self.parent_class = parent_class # used to access drifts
self.size = size
self.half_size = size // 2
self.max_size = max_size
def at_time(self, tp: int) -> t.List[int]:
def centre_at_time(self, tp: int) -> t.List[int]:
"""
Return tile's centre by applying drifts.
......@@ -55,7 +63,7 @@ class Tile:
tp: integer
Index for the time point of interest.
"""
drifts = self.parent.drifts
drifts = self.parent_class.drifts
tile_centre = self.centre - np.sum(drifts[: tp + 1], axis=0)
return list(tile_centre.astype(int))
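# An illustrative example of the drift correction (the numbers are invented):
# with centre = (100, 120) and drifts = [(2, -1), (1, 0)], the centre at
# time point 1 is (100, 120) - (3, -1) = (97, 121).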
......@@ -74,15 +82,15 @@ class Tile:
Returns
-------
x: int
x-coordinate of bottom left corner of tile
x-coordinate of bottom left corner of tile.
y: int
y-coordinate of bottom left corner of tile
y-coordinate of bottom left corner of tile.
w: int
Width of tile
Width of tile.
h: int
Height of tile
Height of tile.
"""
x, y = self.at_time(tp)
x, y = self.centre_at_time(tp)
# tile bottom corner
x = int(x - self.half_size)
y = int(y - self.half_size)
......@@ -90,8 +98,7 @@ class Tile:
def as_range(self, tp: int):
"""
Return tile in a range format: two slice objects that can
be used in arrays.
Return a horizontal and a vertical slice of a tile.
Parameters
----------
......@@ -117,6 +124,20 @@ class TileLocations:
max_size: int = 1200,
drifts: np.array = None,
):
"""
Initialise tiles as an array of Tile objects.
Parameters
----------
initial_location: array
An array of tile centres.
tile_size: int
Length of one side of a square tile.
max_size: int, optional
Default is 1200.
drifts: array
An array of translations to correct drift of the microscope.
"""
if drifts is None:
drifts = []
self.tile_size = tile_size
......@@ -129,20 +150,21 @@ class TileLocations:
self.drifts = drifts
def __len__(self):
"""Find number of tiles."""
return len(self.tiles)
def __iter__(self):
"""Return the next tile from the list of tiles."""
yield from self.tiles
@property
def shape(self):
"""Return numbers of tiles and drifts."""
"""Return the number of tiles and the number of drifts."""
return len(self.tiles), len(self.drifts)
def to_dict(self, tp: int):
"""
Export initial locations, tile_size, max_size, and drifts
as a dictionary.
Export initial locations, tile_size, max_size, and drifts as a dict.
Parameters
----------
......@@ -157,19 +179,22 @@ class TileLocations:
res["drifts"] = np.expand_dims(self.drifts[tp], axis=0)
return res
def at_time(self, tp: int) -> np.ndarray:
def centres_at_time(self, tp: int) -> np.ndarray:
"""Return an array of tile centres (x- and y-coords)."""
return np.array([tile.at_time(tp) for tile in self.tiles])
return np.array([tile.centre_at_time(tp) for tile in self.tiles])
@classmethod
def from_tiler_init(
cls, initial_location, tile_size: int = None, max_size: int = 1200
cls,
initial_location,
tile_size: int = None,
max_size: int = 1200,
):
"""Instantiate from a Tiler."""
return cls(initial_location, tile_size, max_size, drifts=[])
@classmethod
def read_hdf5(cls, file):
def read_h5(cls, file):
"""Instantiate from a h5 file."""
with h5py.File(file, "r") as hfile:
tile_info = hfile["trap_info"]
......@@ -183,30 +208,41 @@ class TileLocations:
class TilerParameters(ParametersABC):
"""
tile_size: int
ref_channel: str or int
ref_z: int
backup_ref_channel int or None, if int indicates the index for reference channel. Used when image does not include metadata, ref_channel is a string and channel names are included in parsed logfiles.
"""
"""Define default values for tile size and the reference channels."""
_defaults = {
"tile_size": 117,
"ref_channel": "Brightfield",
"ref_z": 0,
"backup_ref_channel": None,
"position_name": None,
}
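# An illustrative sketch of how these defaults are used (the position name is
# an assumed example); the pipeline overrides "position_name" per position:
#
# params_dict = TilerParameters.default().to_dict()
# params_dict["position_name"] = "pos001"
# params = TilerParameters.from_dict(params_dict)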
def find_channels_by_position(meta):
"""Parse metadata to find the imaging channels used for each group."""
channels_dict = {
position_name: [] for position_name in meta["positions/posname"]
}
imaging_channels = meta["channels"]
for i, position_name in enumerate(meta["positions/posname"]):
for imaging_channel in imaging_channels:
if (
"positions/" + imaging_channel in meta
and meta["positions/" + imaging_channel][i]
):
channels_dict[position_name].append(imaging_channel)
return channels_dict
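# An illustrative example (the metadata dictionary below is assumed for
# illustration, following the parsed-logfile layout used above):
#
# >>> meta = {
# ...     "channels": ["Brightfield", "GFP"],
# ...     "positions/posname": ["pos001", "pos002"],
# ...     "positions/Brightfield": [True, True],
# ...     "positions/GFP": [True, False],
# ... }
# >>> find_channels_by_position(meta)
# {'pos001': ['Brightfield', 'GFP'], 'pos002': ['Brightfield']}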
class Tiler(StepABC):
"""
Divide images into smaller tiles for faster processing.
Finds tiles and re-registers images if they drift.
Find tiles and re-register images if they drift.
Fetch images from an OMERO server if necessary.
Uses an Image instance, which lazily provides the data on pixels,
Uses an Image instance, which lazily provides the pixel data,
and, as an independent argument, metadata.
"""
......@@ -215,7 +251,7 @@ class Tiler(StepABC):
image: da.core.Array,
metadata: dict,
parameters: TilerParameters,
tile_locs=None,
tile_locations=None,
):
"""
Initialise.
......@@ -229,64 +265,25 @@ class Tiler(StepABC):
"""
super().__init__(parameters)
self.image = image
self._metadata = metadata
self.channels = metadata.get(
"channels",
self.position_name = parameters.to_dict()["position_name"]
# get channels for this position
channel_dict = find_channels_by_position(metadata)
self.channels = channel_dict.get(
self.position_name,
list(range(metadata.get("size_c", 0))),
)
# get reference channel - used for segmentation
self.ref_channel = self.get_channel_index(parameters.ref_channel)
if self.ref_channel is None:
self.ref_channel = self.backup_ref_channel
self.ref_channel = self.get_channel_index(parameters.ref_channel)
self.tile_locs = tile_locs
try:
self.tile_locs = tile_locations
if "zsections" in metadata:
self.z_perchannel = {
ch: zsect
for ch, zsect in zip(self.channels, metadata["zsections"])
}
except Exception as e:
self._log(f"No z_perchannel data: {e}")
self.tile_size = self.tile_size or min(self.image.shape[-2:])
@classmethod
def dummy(cls, parameters: dict):
"""
Instantiate dummy Tiler from dummy image.
If image.dimorder exists dimensions are saved in that order.
Otherwise default to "tczyx".
Parameters
----------
parameters: dict
An instance of TilerParameters converted to a dict.
"""
imgdmy_obj = ImageDummy(parameters)
dummy_image = imgdmy_obj.get_data_lazy()
# default to "tczyx" if image.dimorder is None
dummy_omero_metadata = {
f"size_{dim}": dim_size
for dim, dim_size in zip(
imgdmy_obj.dimorder or "tczyx", dummy_image.shape
)
}
dummy_omero_metadata.update(
{
"channels": [
parameters["ref_channel"],
*(["nil"] * (dummy_omero_metadata["size_c"] - 1)),
],
"name": "",
}
)
return cls(
imgdmy_obj.data,
dummy_omero_metadata,
TilerParameters.from_dict(parameters),
)
@classmethod
def from_image(cls, image, parameters: TilerParameters):
"""
......@@ -307,16 +304,16 @@ class Tiler(StepABC):
parameters: t.Optional[TilerParameters] = None,
):
"""
Instantiate from h5 files.
Instantiate from an h5 file.
Parameters
----------
image: an instance of Image
filepath: Path instance
Path to a directory of h5 files
Path to an h5 file.
parameters: an instance of TileParameters (optional)
"""
tile_locs = TileLocations.read_hdf5(filepath)
tile_locs = TileLocations.read_h5(filepath)
metadata = BridgeH5(filepath).meta_h5
metadata["channels"] = image.metadata["channels"]
if parameters is None:
......@@ -328,11 +325,11 @@ class Tiler(StepABC):
tile_locs=tile_locs,
)
if hasattr(tile_locs, "drifts"):
tiler.n_processed = len(tile_locs.drifts)
tiler.no_processed = len(tile_locs.drifts)
return tiler
@lru_cache(maxsize=2)
def get_tc(self, t: int, c: int) -> np.ndarray:
def load_image(self, tp: int, c: int) -> np.ndarray:
"""
Load image using dask.
......@@ -345,7 +342,7 @@ class Tiler(StepABC):
Parameters
----------
t: integer
tp: integer
An index for a time point
c: integer
An index for a channel
......@@ -354,32 +351,35 @@ class Tiler(StepABC):
-------
full: an array of images
"""
full = self.image[t, c]
if hasattr(full, "compute"): # If using dask fetch images here
full = self.image[tp, c]
if hasattr(full, "compute"):
# if using dask fetch images
full = full.compute(scheduler="synchronous")
return full
@property
def shape(self):
"""
Return properties of the time-lapse as shown by self.image.shape
Return the shape of the image array.
The image array is arranged as number of images, number of channels,
number of z sections, and size of the image in y and x.
"""
return self.image.shape
@property
def n_processed(self):
def no_processed(self):
"""Return the number of processed images."""
if not hasattr(self, "_n_processed"):
self._n_processed = 0
return self._n_processed
if not hasattr(self, "_no_processed"):
self._no_processed = 0
return self._no_processed
@n_processed.setter
def n_processed(self, value):
self._n_processed = value
@no_processed.setter
def no_processed(self, value):
self._no_processed = value
@property
def n_tiles(self):
def no_tiles(self):
"""Return number of tiles."""
return len(self.tile_locs)
......@@ -398,9 +398,8 @@ class Tiler(StepABC):
initial_image = self.image[0, self.ref_channel, self.ref_z]
if tile_size:
half_tile = tile_size // 2
# max_size is the minimal number of x or y pixels
# max_size is the minimum of the numbers of x and y pixels
max_size = min(self.image.shape[-2:])
# first time point, reference channel, reference z-position
# find the tiles
tile_locs = segment_traps(initial_image, tile_size)
# keep only tiles that are not near an edge
......@@ -415,6 +414,7 @@ class Tiler(StepABC):
tile_locs, tile_size
)
else:
# one tile with its centre at the image's centre
yx_shape = self.image.shape[-2:]
tile_locs = [[x // 2 for x in yx_shape]]
self.tile_locs = TileLocations.from_tiler_init(
......@@ -423,8 +423,9 @@ class Tiler(StepABC):
def find_drift(self, tp: int):
"""
Find any translational drift between two images at consecutive
time points using cross correlation.
Find any translational drift between two images.
Use cross correlation between two consecutive images.
Arguments
---------
......@@ -445,7 +446,7 @@ class Tiler(StepABC):
def get_tp_data(self, tp, c) -> np.ndarray:
"""
Returns all tiles corrected for drift.
Return all tiles corrected for drift.
Parameters
----------
......@@ -456,25 +457,24 @@ class Tiler(StepABC):
Returns
----------
Numpy ndarray of tiles with shape (tile, z, y, x)
Numpy ndarray of tiles with shape (no tiles, z-sections, y, x)
"""
tiles = []
# get OMERO image
full = self.get_tc(tp, c)
full = self.load_image(tp, c)
for tile in self.tile_locs:
# pad tile if necessary
ndtile = self.ifoob_pad(full, tile.as_range(tp))
ndtile = Tiler.if_out_of_bounds_pad(full, tile.as_range(tp))
tiles.append(ndtile)
return np.stack(tiles)
def get_tile_data(self, tile_id: int, tp: int, c: int):
"""
Return a particular tile corrected for drift and padding.
Return a tile corrected for drift and padding.
Parameters
----------
tile_id: integer
Number of tile.
Index of tile.
tp: integer
Index of time points.
c: integer
......@@ -485,14 +485,14 @@ class Tiler(StepABC):
ndtile: array
An array of (x, y) arrays, one for each z stack
"""
full = self.get_tc(tp, c)
full = self.load_image(tp, c)
tile = self.tile_locs.tiles[tile_id]
ndtile = self.ifoob_pad(full, tile.as_range(tp))
ndtile = self.if_out_of_bounds_pad(full, tile.as_range(tp))
return ndtile
def _run_tp(self, tp: int):
"""
Find tiles if they have not yet been found.
Find tiles for a given time point.
Determine any translational drift of the current image from the
previous one.
......@@ -502,41 +502,33 @@ class Tiler(StepABC):
tp: integer
The time point to tile.
"""
# assert tp >= self.n_processed, "Time point already processed"
# TODO check contiguity?
if self.n_processed == 0 or not hasattr(self.tile_locs, "drifts"):
if self.no_processed == 0 or not hasattr(self.tile_locs, "drifts"):
self.initialise_tiles(self.tile_size)
if hasattr(self.tile_locs, "drifts"):
drift_len = len(self.tile_locs.drifts)
if self.n_processed != drift_len:
warnings.warn("Tiler:n_processed and ndrifts don't match")
self.n_processed = drift_len
# determine drift
if self.no_processed != drift_len:
warnings.warn(
"Tiler: the number of processed tiles and the number of drifts"
" calculated do not match."
)
self.no_processed = drift_len
# determine drift for this time point and update tile_locs.drifts
self.find_drift(tp)
# update n_processed
self.n_processed = tp + 1
# update no_processed
self.no_processed = tp + 1
# return result for writer
return self.tile_locs.to_dict(tp)
def run(self, time_dim=None):
"""
Tile all time points in an experiment at once.
"""
"""Tile all time points in an experiment at once."""
if time_dim is None:
time_dim = 0
for frame in range(self.image.shape[time_dim]):
self.run_tp(frame)
return None
def get_traps_timepoint(self, *args, **kwargs):
self._log(
"get_traps_timepoint is deprecated; get_tiles_timepoint instead."
)
return self.get_tiles_timepoint(*args, **kwargs)
# The next set of functions are necessary for the extraction object
def get_tiles_timepoint(
self, tp: int, tile_shape=None, channels=None, z: int = 0
self, tp: int, channels=None, z: int = 0
) -> np.ndarray:
"""
Get a multidimensional array with all tiles for a set of channels
......@@ -558,33 +550,23 @@ class Tiler(StepABC):
Returns
-------
res: array
Data arranged as (tiles, channels, time points, X, Y, Z)
Data arranged as (tiles, channels, Z, X, Y)
"""
# FIXME add support for sub-tiling a tile
# FIXME can we ignore z
if channels is None:
channels = [0]
elif isinstance(channels, str):
channels = [channels]
# get the data
# get the data as a list of length of the number of channels
res = []
for c in channels:
# only return requested z
val = self.get_tp_data(tp, c)[:, z]
# starts with the order: tiles, z, y, x
# returns the order: tiles, C, T, Z, X, Y
val = np.expand_dims(val, axis=1)
res.append(val)
if tile_shape is not None:
if isinstance(tile_shape, int):
tile_shape = (tile_shape, tile_shape)
assert np.all(
[
(tile_size - ax) > -1
for tile_size, ax in zip(tile_shape, res[0].shape[-3:-2])
]
)
return np.stack(res, axis=1)
tiles = self.get_tp_data(tp, c)[:, z]
# insert new axis at index 1 for missing channel
tiles = np.expand_dims(tiles, axis=1)
res.append(tiles)
# stack over channels if more than one
final = np.stack(res, axis=1)
return final
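# An illustrative note on the output shape (the numbers are invented): with
# 10 tiles of size 117 and two channels requested at a single z-section,
#
# tiles = tiler.get_tiles_timepoint(tp, channels=["Brightfield", "GFP"], z=0)
#
# returns an array of shape (10, 2, 1, 117, 117), i.e. (tiles, channels, Z, X, Y).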
@property
def ref_channel_index(self):
......@@ -593,32 +575,35 @@ class Tiler(StepABC):
def get_channel_index(self, channel: str or int) -> int or None:
"""
Find index for channel using regex. Returns the first matched string.
If self.channels is integers (no image metadata) it returns None.
If channel is integer
Find index for channel using regex.
If channels are strings, return the first matched string.
If channels are integers, return channel unchanged if it is
an integer.
Parameters
----------
channel: string or int
The channel or index to be used.
"""
if all(map(lambda x: isinstance(x, int), self.channels)):
channel = channel if isinstance(channel, int) else None
if isinstance(channel, str):
channel = find_channel_index(self.channels, channel)
return channel
if isinstance(channel, int) and all(
map(lambda x: isinstance(x, int), self.channels)
):
return channel
elif isinstance(channel, str):
return find_channel_index(self.channels, channel)
else:
return None
@staticmethod
def ifoob_pad(full, slices):
def if_out_of_bounds_pad(image_array, slices):
"""
Return the slices padded if out of bounds.
Pad slices if out of bounds.
Parameters
----------
full: array
Slice of OMERO image (zstacks, x, y) - the entire position
Slice of image (zstacks, x, y) - the entire position
with zstacks as first axis
slices: tuple of two slices
Delineates indices for the x- and y- ranges of the tile.
......@@ -631,11 +616,11 @@ class Tiler(StepABC):
If much padding is needed, a tile of NaN is returned.
"""
# number of pixels in the y direction
max_size = full.shape[-1]
max_size = image_array.shape[-1]
# ignore parts of the tile outside of the image
y, x = [slice(max(0, s.start), min(max_size, s.stop)) for s in slices]
# get the tile including all z stacks
tile = full[:, y, x]
tile = image_array[:, y, x]
# find extent of padding needed in x and y
padding = np.array(
[(-min(0, s.start), -min(0, max_size - s.stop)) for s in slices]
......@@ -643,43 +628,31 @@ class Tiler(StepABC):
if padding.any():
tile_size = slices[0].stop - slices[0].start
if (padding > tile_size / 4).any():
# too much of the tile is outside of the image
# fill with NaN
tile = np.full((full.shape[0], tile_size, tile_size), np.nan)
# fill with NaN because too much of the tile is outside of the image
tile = np.full(
(image_array.shape[0], tile_size, tile_size), np.nan
)
else:
# pad tile with median value of the tile
tile = np.pad(tile, [[0, 0]] + padding.tolist(), "median")
return tile
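# An illustrative example (the numbers are invented): for a tile of size 117
# whose y slice runs from -10 to 107, the 10 out-of-bounds rows are padded
# with the tile's median because 10 < 117 / 4; a tile with more than a
# quarter of its extent outside the image is instead returned filled with NaN.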
# FIXME: Refactor to support both channel or index
# self._log below is not defined
def find_channel_index(image_channels: t.List[str], channel: str):
"""
Access
"""
for i, ch in enumerate(image_channels):
found = re.match(channel, ch, re.IGNORECASE)
def find_channel_index(image_channels: t.List[str], channel_regex: str):
"""Use a regex to find the index of a channel."""
for index, ch in enumerate(image_channels):
found = re.match(channel_regex, ch, re.IGNORECASE)
if found:
if len(found.string) - (found.endpos - found.start()):
logging.getLogger("aliby").log(
logging.WARNING,
f"Channel {channel} matched {ch} using regex",
f"Channel {channel_regex} matched {ch} using regex",
)
return i
return index
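# An illustrative example: matching is case-insensitive and anchored at the
# start of each channel name.
#
# >>> find_channel_index(["Brightfield", "GFPFast", "mCherry"], "gfp")
# 1
# >>> find_channel_index(["Brightfield", "GFPFast", "mCherry"], "Cy5") is None
# True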
def find_channel_name(image_channels: t.List[str], channel: str):
"""
Find the name of the channel using regex.
Parameters
----------
image_channels: list of str
Channels.
channel: str
A regular expression.
"""
index = find_channel_index(image_channels, channel)
def find_channel_name(image_channels: t.List[str], channel_regex: str):
"""Find the name of the channel using regex."""
index = find_channel_index(image_channels, channel_regex)
if index is not None:
return image_channels[index]
......@@ -169,7 +169,7 @@ class RemoteImageViewer(BaseImageViewer):
with self._image_class(self.image_id, **server_info) as image:
self.tiler.image = image.data
return self.tiler.get_tc(tp, channel)
return self.tiler.load_image(tp, channel)
def _find_channels(self, channels: str, guess: bool = True):
channels = channels or self.tiler.ref_channel
......