From 091ca6511bccbd5869576f796dded3ecda97def7 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Al=C3=A1n=20Mu=C3=B1oz?= <alan.munoz@ed.ac.uk>
Date: Mon, 27 Feb 2023 10:42:09 +0000
Subject: [PATCH] [WIP]feat(cells): add cell sampling

---
 src/agora/io/cells.py | 44 +++++++++++++++++++++++++++++++++++++++++--
 1 file changed, 42 insertions(+), 2 deletions(-)

diff --git a/src/agora/io/cells.py b/src/agora/io/cells.py
index 182590b7..0e2891f3 100644
--- a/src/agora/io/cells.py
+++ b/src/agora/io/cells.py
@@ -409,14 +409,54 @@ class Cells:
         # Stack all cells in a trap-wise manner
         masks = self.at_time(frame)
         return [
-            stack_masks_in_trap(
+            stack_masks_in_tile(
                 masks.get(trap_id, np.array([], dtype=bool)), tile_shape
             )
             for trap_id in range(self.ntraps)
         ]
 
+    def _sample_occupied_tiles_tp(
+        self,
+        size=1,
+        min_ncells: int = 2,
+        min_consecutive_ntps: int = 5,
+    ):
+        """Randomly draw `size` (tile, time point) pairs, with replacement."""
+        # Entries that pass the min_ncells / min_consecutive_ntps filter
+        occupied = self.matrix_trap_tp_where(
+            min_ncells=min_ncells, min_consecutive_tps=min_consecutive_ntps
+        )
+        # Indices of truthy entries; axes are presumably (tile, time point)
+        valid_tiles, valid_tps = np.where(occupied)
+        picks = np.random.choice(len(valid_tiles), size=size)
+        return valid_tiles[picks], valid_tps[picks]
+
+    def _sample_masks(
+        self,
+        size=1,
+        min_ncells: int = 2,
+        min_consecutive_ntps: int = 5,
+    ):
+        """Sample masks from random occupied (tile, time point) pairs.
+
+        Returns ((tile_ids, tps), masks), with samples sorted by time point.
+        """
+        tile_ids, tps = self._sample_occupied_tiles_tp(
+            size=size,
+            min_ncells=min_ncells,
+            min_consecutive_ntps=min_consecutive_ntps,
+        )
+        # Sort sampled tiles to use automatic cache when possible
+        order = np.argsort(tps)
+        tile_ids = tile_ids[order]
+        tps = tps[order]
+        masks = [
+            self.at_time(tp)[tile_id] for tp, tile_id in zip(tps, tile_ids)
+        ]
+        return (tile_ids, tps), np.array(masks)
+
 
-def stack_masks_in_trap(
+def stack_masks_in_tile(
     masks: t.List[np.ndarray], tile_shape: t.Tuple[int]
 ) -> np.ndarray:
     # Stack all masks in a trap padding accordingly if no outlines found
-- 
GitLab