diff --git a/src/agora/io/cells.py b/src/agora/io/cells.py index 182590b7ee2637a8516c902f79f6c763c47454dd..0e2891f3151bf6ccad52156f84431c26f6ef44b3 100644 --- a/src/agora/io/cells.py +++ b/src/agora/io/cells.py @@ -409,14 +409,54 @@ class Cells: # Stack all cells in a trap-wise manner masks = self.at_time(frame) return [ - stack_masks_in_trap( + stack_masks_in_tile( masks.get(trap_id, np.array([], dtype=bool)), tile_shape ) for trap_id in range(self.ntraps) ] + def _sample_occupied_tiles_tp( + self, + size=1, + min_ncells: int = 2, + min_consecutive_ntps: int = 5, + ): + cell_availability_matrix = self.matrix_trap_tp_where( + min_ncells=min_ncells, min_consecutive_tps=min_consecutive_ntps + ) + + # Find all valid tiles with min_ncells for at least min_tps + tile_ids, tps = np.where(cell_availability_matrix) + + choices = np.random.choice(len(tile_ids), size=size) + return tile_ids[choices], tps[choices] + + def _sample_masks( + self, + size=1, + min_ncells: int = 2, + min_consecutive_ntps: int = 5, + ): + """Sample a number of cells from different traps each.""" + tile_ids, tps = self._sample_occupied_tiles_tp( + size=size, + min_ncells=min_ncells, + min_consecutive_ntps=min_consecutive_ntps, + ) + # Sort sampled tiles to use automatic cache when possible + order = np.argsort(tps) + tile_ids = tile_ids[order] + tps = tps[order] + + return (tile_ids, tps), np.array( + [ + self.at_time(tp)[tile_ids[tile_id]] + for tp, tile_id in zip(tps, tile_ids) + ] + ) + -def stack_masks_in_trap( +def stack_masks_in_tile( masks: t.List[np.ndarray], tile_shape: t.Tuple[int] ) -> np.ndarray: # Stack all masks in a trap padding accordingly if no outlines found