diff --git a/io/cells.py b/io/cells.py index 072ee17d55a7541274ac869c0e0babc03e36adc7..a8ca61b0e13c28b266d5bb2b9bd8fd02384106e3 100644 --- a/io/cells.py +++ b/io/cells.py @@ -3,6 +3,7 @@ import typing as t from collections.abc import Iterable from itertools import groupby from pathlib import Path, PosixPath +from functools import lru_cache import h5py import numpy as np @@ -55,7 +56,11 @@ class Cells: @property def max_labels(self) -> t.List[int]: - return [max(self.labels_in_trap(i)) for i in range(self.ntraps)] + return [max((0, *self.labels_in_trap(i))) for i in range(self.ntraps)] + + @property + def max_label(self) -> int: + return sum(self.max_labels) @property def ntraps(self) -> int: @@ -102,6 +107,16 @@ class Cells: """ return [self.labels_in_trap(trap) for trap in range(self.ntraps)] + def max_labels_in_frame(self, frame: int) -> t.List[int]: + # Return the maximum label for each trap in the given frame + max_labels = [ + self["cell_label"][ + (self["timepoint"] <= frame) & (self["trap"] == trap_id) + ] + for trap_id in range(self.ntraps) + ] + return [max([0, *labels]) for labels in max_labels] + def where(self, cell_id: int, trap_id: int): """ Parameters @@ -321,3 +336,68 @@ class Cells: nested_massign = [d.get(i, []) for i in range(ntraps)] return nested_massign + + @lru_cache(maxsize=200) + def labelled_in_frame(self, frame: int, global_id=False) -> np.ndarray: + """ + Return labels in a ndarray with the global ids + with shape (ntraps, max_nlabels, ysize, xsize) + at a given frame. + + max_nlabels is specific for this frame, not + the entire experiment. + """ + labels_in_frame = self.labels_at_time(frame) + n_labels = [ + len(labels_in_frame.get(trap_id, [])) + for trap_id in range(self.ntraps) + ] + # maxes = self.max_labels_in_frame(frame) + stacks_in_frame = self.get_stacks_in_frame(frame, self.tile_size) + first_id = np.cumsum([0, *n_labels]) + labels_mat = np.zeros( + ( + self.ntraps, + max(n_labels), + *self.tile_size, + ), + dtype=int, + ) + for trap_id, masks in enumerate(stacks_in_frame): # new_axis = np.pad( + if trap_id in labels_in_frame: + new_axis = np.array(labels_in_frame[trap_id], dtype=int)[ + :, np.newaxis, np.newaxis + ] + global_id_masks = new_axis * masks + if global_id: + global_id_masks += first_id[trap_id] * masks + global_id_masks = np.pad( + global_id_masks, + pad_width=( + (0, labels_mat.shape[1] - global_id_masks.shape[0]), + (0, 0), + (0, 0), + ), + ) + labels_mat[trap_id] += global_id_masks + return labels_mat + + def get_stacks_in_frame(self, frame: int, tile_shape: t.Tuple[int]): + # Stack all cells in a trap-wise manner + masks = self.at_time(frame) + return [ + stack_masks_in_trap( + masks.get(trap_id, np.array([], dtype=bool)), tile_shape + ) + for trap_id in range(self.ntraps) + ] + + +def stack_masks_in_trap( + masks: t.List[np.ndarray], tile_shape: t.Tuple[int] +) -> np.ndarray: + # Stack all masks in a trap padding accordingly if no outlines found + result = np.zeros((0, *tile_shape), dtype=bool) + if len(masks): + result = np.array(masks) + return result