From a9a6291d8c36c6ed2134a51f16e73d1530fc50e0 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Al=C3=A1n=20Mu=C3=B1oz?= <alan.munoz@ed.ac.uk>
Date: Tue, 7 Mar 2023 19:59:47 +0000
Subject: [PATCH] fix(cells): _sample_masks works properly

---
 src/agora/io/cells.py | 161 +++++++++++++++++++++++++++---------------
 1 file changed, 103 insertions(+), 58 deletions(-)

diff --git a/src/agora/io/cells.py b/src/agora/io/cells.py
index 6c783ee3..63b5880f 100644
--- a/src/agora/io/cells.py
+++ b/src/agora/io/cells.py
@@ -278,7 +278,7 @@ class Cells:
     def ntimepoints(self) -> int:
         return self["timepoint"].max() + 1
 
-    @property
+    @cached_property
     def _cells_vs_tps(self):
         # Binary matrix showing the presence of all cells in all time points
         ncells_per_tile = [len(x) for x in self.labels]
@@ -286,13 +286,29 @@ class Cells:
             (sum(ncells_per_tile), self.ntimepoints), dtype=bool
         )
 
-        cumsum = np.roll(np.cumsum(ncells_per_tile), shift=1)
-        cumsum[0] = 0
         cells_vs_tps[
-            cumsum[self["trap"]] + self["cell_label"] - 1, self["timepoint"]
+            self._cell_cumsum[self["trap"]] + self["cell_label"] - 1,
+            self["timepoint"],
         ] = True
         return cells_vs_tps
 
+    @cached_property
+    def _cell_cumsum(self):
+        # Cumulative sum indicating the number of cells per tile
+        ncells_per_tile = [len(x) for x in self.labels]
+        cumsum = np.roll(np.cumsum(ncells_per_tile), shift=1)
+        cumsum[0] = 0
+        return cumsum
+
+    def _flat_index_to_tuple_location(self, idx: int) -> t.Tuple[int, int]:
+        # Convert a cell index to a tuple
+        # Note that it assumes tiles and cell labels are flattened, but
+        # it is agnostic to tps
+
+        tile_id = int(np.where(idx + 1 > self._cell_cumsum)[0][-1])
+        cell_label = idx - self._cell_cumsum[tile_id] + 1
+        return tile_id, cell_label
+
     @property
     def _tiles_vs_cells_vs_tps(self):
         ncells_mat = np.zeros(
@@ -304,39 +320,28 @@ class Cells:
         ] = True
         return ncells_mat
 
-    def cell_tp_where(self, min_consecutive_tps: int = 15):
+    def cell_tp_where(
+        self,
+        min_consecutive_tps: int = 15,
+        interval: None or t.Tuple[int, int] = None,
+    ):
         window = sliding_window_view(
             self._cells_vs_tps, min_consecutive_tps, axis=1
         )
 
         tp_min = window.sum(axis=-1) == min_consecutive_tps
-        return tp_min
-
-    def matrix_trap_tp_where(
-        self, min_ncells: int = 2, min_consecutive_tps: int = 5
-    ):
-        """
-        Return a matrix of shape (ntraps x ntps - min_consecutive_tps to
-        indicate traps and time-points where min_ncells are available for at least min_consecutive_tps
 
-        Parameters
-        ---------
-            min_ncells: int Minimum number of cells
-            min_consecutive_tps: int
-                Minimum number of time-points a
+        # Apply an interval filter to focucs on a slice
+        if interval is not None:
+            interval = tuple(np.array(interval) - min_consecutive_tps // 2)
+        else:
+            interval = (0, window.shape[1])
 
-        Returns
-        ---------
-            (ntraps x ( ntps-min_consecutive_tps )) 2D boolean numpy array where rows are trap ids and columns are timepoint windows.
-            If the value in a cell is true its corresponding trap and timepoint contains more than min_ncells for at least min_consecutive time-points.
-        """
+        low_boundary, high_boundary = interval
 
-        window = sliding_window_view(
-            self._tiles_vs_cells_vs_tps, min_consecutive_tps, axis=2
-        )
-        tp_min = window.sum(axis=-1) == min_consecutive_tps
-        ncells_tp_min = tp_min.sum(axis=1) >= min_ncells
-        return ncells_tp_min
+        tp_min[:, :low_boundary] = False
+        tp_min[:, high_boundary:] = False
+        return tp_min
 
     @lru_cache(20)
     def mothers_in_trap(self, trap_id: int):
@@ -572,10 +577,9 @@ class Cells:
     def _sample_tiles_tps(
         self,
         size=1,
-        # min_ncells: int = 2,
-        # max_ncells: int = 2,
-        min_consecutive_ntps: int = 10,
+        min_consecutive_ntps: int = 15,
         seed: int = 0,
+        interval=None,
     ) -> t.Tuple[np.ndarray, np.ndarray]:
         """
         Sample tiles that have a minimum number of cells and are occupied for at least a minimum number of consecutive timepoints.
@@ -590,10 +594,12 @@ class Cells:
             The minimum number of consecutive timepoints a cell must be present in a trap.
         seed: int, optional (default=0)
             Random seed value for reproducibility.
+        interval: None or Tuple(int,int), optional (default=None)
+            Random seed value for reproducibility.
 
         Returns
         -------
-        Tuple[np.ndarray, np.ndarray]
+        Tuple[np.ndarray, np.ndarray,np.ndarray]
             A tuple of 1D numpy arrays containing the indices of the sampled tiles and the corresponding timepoints.
         """
         # cell_availability_matrix = self.matrix_trap_tp_where(
@@ -603,30 +609,43 @@ class Cells:
         # # Find all valid tiles with min_ncells for at least min_tps
         # tile_ids, tps = np.where(cell_availability_matrix)
         cell_availability_matrix = self.cell_tp_where(
-            min_consecutive_tps=min_consecutive_ntps
+            min_consecutive_tps=min_consecutive_ntps,
+            interval=interval,
         )
 
         # Find all valid tiles with min_ncells for at least min_tps
-        index_id, _ = np.where(cell_availability_matrix)
+        index_id, tps = np.where(cell_availability_matrix)
+
+        if interval is None:  # Limit search
+            interval = (0, cell_availability_matrix.shape[1])
 
         np.random.seed(seed)
-        choices = np.random.choice(index_id, size=size)
-        return (
-            self["trap"][choices],
-            self["cell_label"][choices],
-            self["timepoint"][choices],
-        )
+        choices = np.random.randint(len(index_id), size=size)
+
+        linear_indices = np.zeros_like(self["cell_label"], dtype=bool)
+        for cell_index_flat, tp in zip(index_id[choices], tps[choices]):
+            tile_id, cell_label = self._flat_index_to_tuple_location(
+                cell_index_flat
+            )
+            linear_indices[
+                (
+                    (self["cell_label"] == cell_label)
+                    & (self["trap"] == tile_id)
+                    & (self["timepoint"] == tp)
+                )
+            ] = True
+
+        return linear_indices
 
     def _sample_masks(
         self,
-        size=1,
-        min_ncells: int = 2,
-        min_consecutive_ntps: int = 5,
+        size: int = 1,
+        min_consecutive_ntps: int = 15,
         interval: t.Union[None, t.Tuple[int, int]] = None,
-        seed=0,
+        seed: int = 0,
     ) -> t.Tuple[t.Tuple[t.List[int], t.List[int], t.List[int]], np.ndarray]:
         """
-        Sample a number of cells from different tiles each.
+        Sample a number of cells from within an interval.
 
         Parameters
         ----------
@@ -649,27 +668,53 @@ class Cells:
             The second tuple contains:
             - `masks`: A list of 2D numpy arrays representing the binary masks of the sampled cells at each timepoint.
         """
-        tile_ids, _, tps = self._sample_tiles_tps(
+        sampled_bitmask = self._sample_tiles_tps(
             size=size,
             min_consecutive_ntps=min_consecutive_ntps,
             seed=seed,
+            interval=interval,
         )
+
         #  Sort sampled tiles to use automatic cache when possible
-        order = np.argsort(tps)
-        tile_ids = tile_ids[order].tolist()
-        tps = tps[order].tolist()
+        tile_ids = self["trap"][sampled_bitmask]
+        cell_labels = self["cell_label"][sampled_bitmask]
+        tps = self["timepoint"][sampled_bitmask]
 
-        cell_ids = []
         masks = []
-        for tp, tile_id in zip(tps, tile_ids):
-            tile_masks = self.at_time(tp)[tile_id]
+        for tile_id, cell_label, tp in zip(tile_ids, cell_labels, tps):
+            local_idx = self.labels_at_time(tp)[tile_id].index(cell_label)
+            tile_mask = self.at_time(tp)[tile_id][local_idx]
+            masks.append(tile_mask)
+
+        return (tile_ids, cell_labels, tps), np.stack(masks)
 
-            np.random.seed(seed)
-            cell_id = np.random.randint(len(tile_masks))
+    def matrix_trap_tp_where(
+        self, min_ncells: int = 2, min_consecutive_tps: int = 5
+    ):
+        """
+        NOTE CURRENLTY UNUSED WITHIN ALIBY THE MOMENT. MAY BE USEFUL IN THE FUTURE.
+
+        Return a matrix of shape (ntraps x ntps - min_consecutive_tps) to
+        indicate traps and time-points where min_ncells are available for at least min_consecutive_tps
 
-            cell_ids.append(cell_id)
-            masks.append(tile_masks[cell_id])
-        return (tile_ids, cell_ids, tps), np.stack(masks)
+        Parameters
+        ---------
+            min_ncells: int Minimum number of cells
+            min_consecutive_tps: int
+                Minimum number of time-points a
+
+        Returns
+        ---------
+            (ntraps x ( ntps-min_consecutive_tps )) 2D boolean numpy array where rows are trap ids and columns are timepoint windows.
+            If the value in a cell is true its corresponding trap and timepoint contains more than min_ncells for at least min_consecutive time-points.
+        """
+
+        window = sliding_window_view(
+            self._tiles_vs_cells_vs_tps, min_consecutive_tps, axis=2
+        )
+        tp_min = window.sum(axis=-1) == min_consecutive_tps
+        ncells_tp_min = tp_min.sum(axis=1) >= min_ncells
+        return ncells_tp_min
 
 
 def stack_masks_in_tile(
-- 
GitLab