Skip to content
Snippets Groups Projects
Commit f1c29ec5 authored by Alán Muñoz's avatar Alán Muñoz
Browse files

feat(cells): get labels as 4-D array

parent 32bc8ffa
No related branches found
No related tags found
No related merge requests found
...@@ -3,6 +3,7 @@ import typing as t ...@@ -3,6 +3,7 @@ import typing as t
from collections.abc import Iterable from collections.abc import Iterable
from itertools import groupby from itertools import groupby
from pathlib import Path, PosixPath from pathlib import Path, PosixPath
from functools import lru_cache
import h5py import h5py
import numpy as np import numpy as np
...@@ -55,7 +56,11 @@ class Cells: ...@@ -55,7 +56,11 @@ class Cells:
@property @property
def max_labels(self) -> t.List[int]: def max_labels(self) -> t.List[int]:
return [max(self.labels_in_trap(i)) for i in range(self.ntraps)] return [max((0, *self.labels_in_trap(i))) for i in range(self.ntraps)]
@property
def max_label(self) -> int:
return sum(self.max_labels)
@property @property
def ntraps(self) -> int: def ntraps(self) -> int:
...@@ -102,6 +107,16 @@ class Cells: ...@@ -102,6 +107,16 @@ class Cells:
""" """
return [self.labels_in_trap(trap) for trap in range(self.ntraps)] return [self.labels_in_trap(trap) for trap in range(self.ntraps)]
def max_labels_in_frame(self, frame: int) -> t.List[int]:
# Return the maximum label for each trap in the given frame
max_labels = [
self["cell_label"][
(self["timepoint"] <= frame) & (self["trap"] == trap_id)
]
for trap_id in range(self.ntraps)
]
return [max([0, *labels]) for labels in max_labels]
def where(self, cell_id: int, trap_id: int): def where(self, cell_id: int, trap_id: int):
""" """
Parameters Parameters
...@@ -321,3 +336,68 @@ class Cells: ...@@ -321,3 +336,68 @@ class Cells:
nested_massign = [d.get(i, []) for i in range(ntraps)] nested_massign = [d.get(i, []) for i in range(ntraps)]
return nested_massign return nested_massign
@lru_cache(maxsize=200)
def labelled_in_frame(self, frame: int, global_id=False) -> np.ndarray:
"""
Return labels in a ndarray with the global ids
with shape (ntraps, max_nlabels, ysize, xsize)
at a given frame.
max_nlabels is specific for this frame, not
the entire experiment.
"""
labels_in_frame = self.labels_at_time(frame)
n_labels = [
len(labels_in_frame.get(trap_id, []))
for trap_id in range(self.ntraps)
]
# maxes = self.max_labels_in_frame(frame)
stacks_in_frame = self.get_stacks_in_frame(frame, self.tile_size)
first_id = np.cumsum([0, *n_labels])
labels_mat = np.zeros(
(
self.ntraps,
max(n_labels),
*self.tile_size,
),
dtype=int,
)
for trap_id, masks in enumerate(stacks_in_frame): # new_axis = np.pad(
if trap_id in labels_in_frame:
new_axis = np.array(labels_in_frame[trap_id], dtype=int)[
:, np.newaxis, np.newaxis
]
global_id_masks = new_axis * masks
if global_id:
global_id_masks += first_id[trap_id] * masks
global_id_masks = np.pad(
global_id_masks,
pad_width=(
(0, labels_mat.shape[1] - global_id_masks.shape[0]),
(0, 0),
(0, 0),
),
)
labels_mat[trap_id] += global_id_masks
return labels_mat
def get_stacks_in_frame(self, frame: int, tile_shape: t.Tuple[int]):
# Stack all cells in a trap-wise manner
masks = self.at_time(frame)
return [
stack_masks_in_trap(
masks.get(trap_id, np.array([], dtype=bool)), tile_shape
)
for trap_id in range(self.ntraps)
]
def stack_masks_in_trap(
masks: t.List[np.ndarray], tile_shape: t.Tuple[int]
) -> np.ndarray:
# Stack all masks in a trap padding accordingly if no outlines found
result = np.zeros((0, *tile_shape), dtype=bool)
if len(masks):
result = np.array(masks)
return result
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment