diff --git a/src/agora/io/cells.py b/src/agora/io/cells.py index 6c783ee3ce82def4d1ad034598f9c777ab349771..63b5880f45938f4865a3b6ab408ff949ecd5f61c 100644 --- a/src/agora/io/cells.py +++ b/src/agora/io/cells.py @@ -278,7 +278,7 @@ class Cells: def ntimepoints(self) -> int: return self["timepoint"].max() + 1 - @property + @cached_property def _cells_vs_tps(self): # Binary matrix showing the presence of all cells in all time points ncells_per_tile = [len(x) for x in self.labels] @@ -286,13 +286,29 @@ class Cells: (sum(ncells_per_tile), self.ntimepoints), dtype=bool ) - cumsum = np.roll(np.cumsum(ncells_per_tile), shift=1) - cumsum[0] = 0 cells_vs_tps[ - cumsum[self["trap"]] + self["cell_label"] - 1, self["timepoint"] + self._cell_cumsum[self["trap"]] + self["cell_label"] - 1, + self["timepoint"], ] = True return cells_vs_tps + @cached_property + def _cell_cumsum(self): + # Cumulative sum indicating the number of cells per tile + ncells_per_tile = [len(x) for x in self.labels] + cumsum = np.roll(np.cumsum(ncells_per_tile), shift=1) + cumsum[0] = 0 + return cumsum + + def _flat_index_to_tuple_location(self, idx: int) -> t.Tuple[int, int]: + # Convert a cell index to a tuple + # Note that it assumes tiles and cell labels are flattened, but + # it is agnostic to tps + + tile_id = int(np.where(idx + 1 > self._cell_cumsum)[0][-1]) + cell_label = idx - self._cell_cumsum[tile_id] + 1 + return tile_id, cell_label + @property def _tiles_vs_cells_vs_tps(self): ncells_mat = np.zeros( @@ -304,39 +320,28 @@ class Cells: ] = True return ncells_mat - def cell_tp_where(self, min_consecutive_tps: int = 15): + def cell_tp_where( + self, + min_consecutive_tps: int = 15, + interval: None or t.Tuple[int, int] = None, + ): window = sliding_window_view( self._cells_vs_tps, min_consecutive_tps, axis=1 ) tp_min = window.sum(axis=-1) == min_consecutive_tps - return tp_min - - def matrix_trap_tp_where( - self, min_ncells: int = 2, min_consecutive_tps: int = 5 - ): - """ - Return a matrix of shape (ntraps x ntps - min_consecutive_tps to - indicate traps and time-points where min_ncells are available for at least min_consecutive_tps - Parameters - --------- - min_ncells: int Minimum number of cells - min_consecutive_tps: int - Minimum number of time-points a + # Apply an interval filter to focucs on a slice + if interval is not None: + interval = tuple(np.array(interval) - min_consecutive_tps // 2) + else: + interval = (0, window.shape[1]) - Returns - --------- - (ntraps x ( ntps-min_consecutive_tps )) 2D boolean numpy array where rows are trap ids and columns are timepoint windows. - If the value in a cell is true its corresponding trap and timepoint contains more than min_ncells for at least min_consecutive time-points. - """ + low_boundary, high_boundary = interval - window = sliding_window_view( - self._tiles_vs_cells_vs_tps, min_consecutive_tps, axis=2 - ) - tp_min = window.sum(axis=-1) == min_consecutive_tps - ncells_tp_min = tp_min.sum(axis=1) >= min_ncells - return ncells_tp_min + tp_min[:, :low_boundary] = False + tp_min[:, high_boundary:] = False + return tp_min @lru_cache(20) def mothers_in_trap(self, trap_id: int): @@ -572,10 +577,9 @@ class Cells: def _sample_tiles_tps( self, size=1, - # min_ncells: int = 2, - # max_ncells: int = 2, - min_consecutive_ntps: int = 10, + min_consecutive_ntps: int = 15, seed: int = 0, + interval=None, ) -> t.Tuple[np.ndarray, np.ndarray]: """ Sample tiles that have a minimum number of cells and are occupied for at least a minimum number of consecutive timepoints. @@ -590,10 +594,12 @@ class Cells: The minimum number of consecutive timepoints a cell must be present in a trap. seed: int, optional (default=0) Random seed value for reproducibility. + interval: None or Tuple(int,int), optional (default=None) + Random seed value for reproducibility. Returns ------- - Tuple[np.ndarray, np.ndarray] + Tuple[np.ndarray, np.ndarray,np.ndarray] A tuple of 1D numpy arrays containing the indices of the sampled tiles and the corresponding timepoints. """ # cell_availability_matrix = self.matrix_trap_tp_where( @@ -603,30 +609,43 @@ class Cells: # # Find all valid tiles with min_ncells for at least min_tps # tile_ids, tps = np.where(cell_availability_matrix) cell_availability_matrix = self.cell_tp_where( - min_consecutive_tps=min_consecutive_ntps + min_consecutive_tps=min_consecutive_ntps, + interval=interval, ) # Find all valid tiles with min_ncells for at least min_tps - index_id, _ = np.where(cell_availability_matrix) + index_id, tps = np.where(cell_availability_matrix) + + if interval is None: # Limit search + interval = (0, cell_availability_matrix.shape[1]) np.random.seed(seed) - choices = np.random.choice(index_id, size=size) - return ( - self["trap"][choices], - self["cell_label"][choices], - self["timepoint"][choices], - ) + choices = np.random.randint(len(index_id), size=size) + + linear_indices = np.zeros_like(self["cell_label"], dtype=bool) + for cell_index_flat, tp in zip(index_id[choices], tps[choices]): + tile_id, cell_label = self._flat_index_to_tuple_location( + cell_index_flat + ) + linear_indices[ + ( + (self["cell_label"] == cell_label) + & (self["trap"] == tile_id) + & (self["timepoint"] == tp) + ) + ] = True + + return linear_indices def _sample_masks( self, - size=1, - min_ncells: int = 2, - min_consecutive_ntps: int = 5, + size: int = 1, + min_consecutive_ntps: int = 15, interval: t.Union[None, t.Tuple[int, int]] = None, - seed=0, + seed: int = 0, ) -> t.Tuple[t.Tuple[t.List[int], t.List[int], t.List[int]], np.ndarray]: """ - Sample a number of cells from different tiles each. + Sample a number of cells from within an interval. Parameters ---------- @@ -649,27 +668,53 @@ class Cells: The second tuple contains: - `masks`: A list of 2D numpy arrays representing the binary masks of the sampled cells at each timepoint. """ - tile_ids, _, tps = self._sample_tiles_tps( + sampled_bitmask = self._sample_tiles_tps( size=size, min_consecutive_ntps=min_consecutive_ntps, seed=seed, + interval=interval, ) + # Sort sampled tiles to use automatic cache when possible - order = np.argsort(tps) - tile_ids = tile_ids[order].tolist() - tps = tps[order].tolist() + tile_ids = self["trap"][sampled_bitmask] + cell_labels = self["cell_label"][sampled_bitmask] + tps = self["timepoint"][sampled_bitmask] - cell_ids = [] masks = [] - for tp, tile_id in zip(tps, tile_ids): - tile_masks = self.at_time(tp)[tile_id] + for tile_id, cell_label, tp in zip(tile_ids, cell_labels, tps): + local_idx = self.labels_at_time(tp)[tile_id].index(cell_label) + tile_mask = self.at_time(tp)[tile_id][local_idx] + masks.append(tile_mask) + + return (tile_ids, cell_labels, tps), np.stack(masks) - np.random.seed(seed) - cell_id = np.random.randint(len(tile_masks)) + def matrix_trap_tp_where( + self, min_ncells: int = 2, min_consecutive_tps: int = 5 + ): + """ + NOTE CURRENLTY UNUSED WITHIN ALIBY THE MOMENT. MAY BE USEFUL IN THE FUTURE. + + Return a matrix of shape (ntraps x ntps - min_consecutive_tps) to + indicate traps and time-points where min_ncells are available for at least min_consecutive_tps - cell_ids.append(cell_id) - masks.append(tile_masks[cell_id]) - return (tile_ids, cell_ids, tps), np.stack(masks) + Parameters + --------- + min_ncells: int Minimum number of cells + min_consecutive_tps: int + Minimum number of time-points a + + Returns + --------- + (ntraps x ( ntps-min_consecutive_tps )) 2D boolean numpy array where rows are trap ids and columns are timepoint windows. + If the value in a cell is true its corresponding trap and timepoint contains more than min_ncells for at least min_consecutive time-points. + """ + + window = sliding_window_view( + self._tiles_vs_cells_vs_tps, min_consecutive_tps, axis=2 + ) + tp_min = window.sum(axis=-1) == min_consecutive_tps + ncells_tp_min = tp_min.sum(axis=1) >= min_ncells + return ncells_tp_min def stack_masks_in_tile(