diff --git a/io/cells.py b/io/cells.py
index 072ee17d55a7541274ac869c0e0babc03e36adc7..a8ca61b0e13c28b266d5bb2b9bd8fd02384106e3 100644
--- a/io/cells.py
+++ b/io/cells.py
@@ -3,6 +3,7 @@ import typing as t
 from collections.abc import Iterable
 from itertools import groupby
 from pathlib import Path, PosixPath
+from functools import lru_cache
 
 import h5py
 import numpy as np
@@ -55,7 +56,11 @@ class Cells:
 
     @property
     def max_labels(self) -> t.List[int]:
-        return [max(self.labels_in_trap(i)) for i in range(self.ntraps)]
+        return [max((0, *self.labels_in_trap(i))) for i in range(self.ntraps)]
+
+    @property
+    def max_label(self) -> int:
+        return sum(self.max_labels)
 
     @property
     def ntraps(self) -> int:
@@ -102,6 +107,16 @@ class Cells:
         """
         return [self.labels_in_trap(trap) for trap in range(self.ntraps)]
 
+    def max_labels_in_frame(self, frame: int) -> t.List[int]:
+        # Return the maximum label for each trap in the given frame
+        max_labels = [
+            self["cell_label"][
+                (self["timepoint"] <= frame) & (self["trap"] == trap_id)
+            ]
+            for trap_id in range(self.ntraps)
+        ]
+        return [max([0, *labels]) for labels in max_labels]
+
     def where(self, cell_id: int, trap_id: int):
         """
         Parameters
@@ -321,3 +336,68 @@ class Cells:
         nested_massign = [d.get(i, []) for i in range(ntraps)]
 
         return nested_massign
+
+    @lru_cache(maxsize=200)
+    def labelled_in_frame(self, frame: int, global_id=False) -> np.ndarray:
+        """
+        Return labels in a ndarray with the global ids
+        with shape (ntraps, max_nlabels, ysize, xsize)
+        at a given frame.
+
+        max_nlabels is specific for this frame, not
+        the entire experiment.
+        """
+        labels_in_frame = self.labels_at_time(frame)
+        n_labels = [
+            len(labels_in_frame.get(trap_id, []))
+            for trap_id in range(self.ntraps)
+        ]
+        # maxes = self.max_labels_in_frame(frame)
+        stacks_in_frame = self.get_stacks_in_frame(frame, self.tile_size)
+        first_id = np.cumsum([0, *n_labels])
+        labels_mat = np.zeros(
+            (
+                self.ntraps,
+                max(n_labels),
+                *self.tile_size,
+            ),
+            dtype=int,
+        )
+        for trap_id, masks in enumerate(stacks_in_frame):  # new_axis = np.pad(
+            if trap_id in labels_in_frame:
+                new_axis = np.array(labels_in_frame[trap_id], dtype=int)[
+                    :, np.newaxis, np.newaxis
+                ]
+                global_id_masks = new_axis * masks
+                if global_id:
+                    global_id_masks += first_id[trap_id] * masks
+                global_id_masks = np.pad(
+                    global_id_masks,
+                    pad_width=(
+                        (0, labels_mat.shape[1] - global_id_masks.shape[0]),
+                        (0, 0),
+                        (0, 0),
+                    ),
+                )
+                labels_mat[trap_id] += global_id_masks
+        return labels_mat
+
+    def get_stacks_in_frame(self, frame: int, tile_shape: t.Tuple[int]):
+        # Stack all cells in a trap-wise manner
+        masks = self.at_time(frame)
+        return [
+            stack_masks_in_trap(
+                masks.get(trap_id, np.array([], dtype=bool)), tile_shape
+            )
+            for trap_id in range(self.ntraps)
+        ]
+
+
+def stack_masks_in_trap(
+    masks: t.List[np.ndarray], tile_shape: t.Tuple[int]
+) -> np.ndarray:
+    # Stack all masks in a trap padding accordingly if no outlines found
+    result = np.zeros((0, *tile_shape), dtype=bool)
+    if len(masks):
+        result = np.array(masks)
+    return result