diff --git a/src/agora/io/cells.py b/src/agora/io/cells.py index 0e2891f3151bf6ccad52156f84431c26f6ef44b3..2f5a8536beda5903521fa8c4b0682f96f3bb7cea 100644 --- a/src/agora/io/cells.py +++ b/src/agora/io/cells.py @@ -146,12 +146,47 @@ class Cells: ) def mask(self, cell_id, trap_id): + """ + Returns the times and the binary masks of a given cell in a given tile. + + Parameters + ---------- + cell_id : int + The unique ID of the cell. + trap_id : int + The unique ID of the tile. + + Returns + ------- + Tuple[np.ndarray, np.ndarray] + The times when the binary masks were taken and the binary masks of the given cell in the given tile. + + """ times, outlines = self.outline(cell_id, trap_id) return times, np.array( [ndimage.morphology.binary_fill_holes(o) for o in outlines] ) - def at_time(self, timepoint, kind="mask"): + def at_time( + self, timepoint: t.Iterable[int], kind="mask" + ) -> t.List[t.List[np.ndarray]]: + """ + Returns a list of lists of binary masks at a given time point. + + Parameters + ---------- + timepoint : int + The time point for which to return the binary masks. + kind : str, optional + The type of binary masks to return, by default "mask". + + Returns + ------- + List[List[np.ndarray]] + A list of lists with binary masks grouped by tile IDs. + + """ + ix = self["timepoint"] == timepoint traps = self["trap"][ix] edgemasks = self._edgem_from_masking(ix) @@ -162,7 +197,25 @@ class Cells: ] return self.group_by_traps(traps, masks) - def at_times(self, timepoints: t.Iterable[int], kind="mask"): + def at_times( + self, timepoints: t.Iterable[int], kind="mask" + ) -> t.List[t.List[np.ndarray]]: + """ + Returns a list of lists of binary masks in a given list of time points. + + Parameters + ---------- + timepoints : Iterable[int] + The list of time points for which to return the binary masks. + kind : str, optional + The type of binary masks to return, by default "mask".
+ + Returns + ------- + List[List[np.ndarray]] + A list of lists with binary masks grouped by tile IDs. + + """ return [ [ np.stack(tile_masks) if len(tile_masks) else [] @@ -267,17 +320,6 @@ class Cells: ncells_tp_min = tp_min.sum(axis=1) >= min_ncells return ncells_tp_min - def random_valid_trap_tp( - self, min_ncells: int = None, min_consecutive_tps: int = None - ): - # Return a randomly-selected pair of trap_id and timepoints - mat = self.matrix_trap_tp_where( - min_ncells=min_ncells, min_consecutive_tps=min_consecutive_tps - ) - traps, tps = np.where(mat) - rand = np.random.randint(mat.sum()) - return (traps[rand], tps[rand]) - @lru_cache(20) def mothers_in_trap(self, trap_id: int): return self.mothers[trap_id] @@ -297,8 +339,17 @@ class Cells: @cached_property def mothers_daughters(self) -> np.ndarray: """ - Return mothers and daugters as a single array with three columns: - trap, mothers and daughters + Return a single array with three columns, containing information about the mother-daughter relationships: + tile, mothers and daughters. + + Returns + ------- + np.ndarray + An array with shape (n, 3) where n is the number of mother-daughter pairs found. + The columns contain: + - tile: the tile where the mother cell is located. + - mothers: the index of the mother cell within the tile. If there is no mother, this is zero. + - daughters: the index of the daughter cell within the tile. If there is no daughter, this is zero. """ nested_massign = self.mothers @@ -320,7 +371,41 @@ class Cells: @staticmethod def mother_assign_to_mb_matrix(ma: t.List[np.array]): - # Convert from list of lists to mother_bud sparse matrix + """ + Convert from a list of lists of mother-bud paired assignments to a + sparse matrix with a boolean dtype. The rows correspond to + daughter buds. The values are boolean and indicate whether a + given cell is a mother cell and a given daughter bud is assigned + to the mother cell in the next timepoint.
+ + Parameters + ---------- + ma : list of lists of integers + A list of lists of mother-bud assignments. The i-th sublist contains the + bud assignments for the i-th tile. The integers in each sublist + represent the mother label, if it is zero no mother was found. + + Returns + ------- + mb_matrix : boolean numpy array of shape (n, n) + An n x n boolean numpy array where n is the total number of cells (sum + of the lengths of all sublists in ma); rows and columns both index cells + in that flattened order. The value at (i, j) is True if cell i + is a daughter cell and cell j is its mother. + + Examples + -------- + ma = [[0, 0, 1], [0, 1, 0]] + Cells(None).mother_assign_to_mb_matrix(ma) + # array([[False, False, False, False, False, False], + # [False, False, False, False, False, False], + # [ True, False, False, False, False, False], + # [False, False, False, False, False, False], + # [False, False, False, True, False, False], + # [False, False, False, False, False, False]]) + + """ + ncells = sum([len(t) for t in ma]) mb_matrix = np.zeros((ncells, ncells), dtype=bool) c = 0 @@ -335,10 +420,26 @@ class Cells: @staticmethod def mother_assign_from_dynamic( - ma, cell_label: t.List[int], trap: t.List[int], ntraps: int - ): + ma: np.ndarray, cell_label: t.List[int], trap: t.List[int], ntraps: int + ) -> t.List[t.List[int]]: """ - Interpolate the list of lists containing the associated mothers from the mother_assign_dynamic feature + Interpolate the associated mothers from the 'mother_assign_dynamic' feature. + + Parameters + ---------- + ma: np.ndarray + An array with shape (n_t, n_c) containing the 'mother_assign_dynamic' feature. + cell_label: List[int] + A list containing the cell labels. + trap: List[int] + A list containing the trap labels. + ntraps: int + The total number of traps. + + Returns + ------- + List[List[int]] + A list of lists containing the interpolated mother assignment for each cell in each trap.
""" idlist = list(zip(trap, cell_label)) cell_gid = np.unique(idlist, axis=0) @@ -363,12 +464,29 @@ class Cells: @lru_cache(maxsize=200) def labelled_in_frame(self, frame: int, global_id=False) -> np.ndarray: """ - Return labels in a ndarray with the global ids - with shape (ntraps, max_nlabels, ysize, xsize) - at a given frame. + Returns labels in a 4D ndarray with the global ids with shape + (ntraps, max_nlabels, ysize, xsize) at a given frame. + + Parameters + ---------- + frame : int + The frame number. + global_id : bool, optional + If True, the returned array contains global ids, otherwise it + contains only the local ids of the labels. Default is False. + + Returns + ------- + np.ndarray + A 4D numpy array containing the labels in the given frame. + The array has dimensions (ntraps, max_nlabels, ysize, xsize), + where max_nlabels is specific for this frame, not the entire + experiment. + + Notes + ----- + This method uses lru_cache to cache the results for faster access. - max_nlabels is specific for this frame, not - the entire experiment. """ labels_in_frame = self.labels_at_time(frame) n_labels = [ @@ -405,8 +523,24 @@ class Cells: labels_mat[trap_id] += global_id_masks return labels_mat - def get_stacks_in_frame(self, frame: int, tile_shape: t.Tuple[int]): - # Stack all cells in a trap-wise manner + def get_stacks_in_frame( + self, frame: int, tile_shape: t.Tuple[int] + ) -> t.List[np.ndarray]: + """ + Returns a list of stacked masks, each corresponding to a tile at a given timepoint. + + Parameters + ---------- + frame : int + Frame for which to obtain the stacked masks. + tile_shape : Tuple[int] + Shape of a tile to stack the masks into. + + Returns + ------- + List[np.ndarray] + List of stacked masks for each tile at the given timepoint. 
+ """ masks = self.at_time(frame) return [ stack_masks_in_tile( @@ -415,12 +549,32 @@ class Cells: for trap_id in range(self.ntraps) ] - def _sample_occupied_tiles_tp( + def _sample_tiles_tps( self, size=1, min_ncells: int = 2, min_consecutive_ntps: int = 5, - ): + seed: int = 0, + ) -> t.Tuple[np.ndarray, np.ndarray]: + """ + Sample tiles that have a minimum number of cells and are occupied for at least a minimum number of consecutive timepoints. + + Parameters + ---------- + size: int, optional (default=1) + The number of tiles to sample. + min_ncells: int, optional (default=2) + The minimum number of cells per tile. + min_consecutive_ntps: int, optional (default=5) + The minimum number of consecutive timepoints a cell must be present in a trap. + seed: int, optional (default=0) + Random seed value for reproducibility. + + Returns + ------- + Tuple[np.ndarray, np.ndarray] + A tuple of 1D numpy arrays containing the indices of the sampled tiles and the corresponding timepoints. + """ cell_availability_matrix = self.matrix_trap_tp_where( min_ncells=min_ncells, min_consecutive_tps=min_consecutive_ntps ) @@ -428,6 +582,7 @@ class Cells: # Find all valid tiles with min_ncells for at least min_tps tile_ids, tps = np.where(cell_availability_matrix) + np.random.seed(seed) choices = np.random.choice(len(tile_ids), size=size) return tile_ids[choices], tps[choices] @@ -436,24 +591,54 @@ class Cells: size=1, min_ncells: int = 2, min_consecutive_ntps: int = 5, - ): - """Sample a number of cells from different traps each.""" - tile_ids, tps = self._sample_occupied_tiles_tp( + seed=0, + ) -> t.Tuple[t.Tuple[t.List[int], t.List[int], t.List[int]], np.ndarray]: + """ + Sample a number of cells from different tiles each. + + Parameters + ---------- + size: int, optional (default=1) + The number of cells to sample. + min_ncells: int, optional (default=2) + The minimum number of cells per tile. 
+ min_consecutive_ntps: int, optional (default=5) + The minimum number of consecutive timepoints a cell must be present in a trap. + seed: int, optional (default=0) + Random seed value for reproducibility. + + Returns + ------- + Tuple[Tuple[List[int], List[int], List[int]], np.ndarray] + A tuple with two elements. The first is a tuple containing: + - `tile_ids`: A list of the ids of the sampled tiles, ordered to match `tps`. + - `tps`: A list of the timepoints at which the cells were sampled, in ascending order. + - `cell_ids`: A list of integers giving the index of each sampled cell within its tile's masks. + The second element is: + - `masks`: A numpy array stacking the binary masks of the sampled cells, one per sampled (tile, timepoint) pair. + """ + tile_ids, tps = self._sample_tiles_tps( size=size, min_ncells=min_ncells, min_consecutive_ntps=min_consecutive_ntps, + seed=seed, ) # Sort sampled tiles to use automatic cache when possible order = np.argsort(tps) - tile_ids = tile_ids[order] - tps = tps[order] + tile_ids = tile_ids[order].tolist() + tps = tps[order].tolist() - return (tile_ids, tps), np.array( - [ - self.at_time(tp)[tile_ids[tile_id]] - for tp, tile_id in zip(tps, tile_ids) - ] - ) + cell_ids = [] + masks = [] + for tp, tile_id in zip(tps, tile_ids): + tile_masks = self.at_time(tp)[tile_id] + + np.random.seed(seed) + cell_id = np.random.randint(len(tile_masks)) + + cell_ids.append(cell_id) + masks.append(tile_masks[cell_id]) + return (tile_ids, tps, cell_ids), np.stack(masks) def stack_masks_in_tile( @@ -462,5 +647,5 @@ def stack_masks_in_tile( # Stack all masks in a trap padding accordingly if no outlines found result = np.zeros((0, *tile_shape), dtype=bool) if len(masks): - result = np.array(masks) + result = np.stack(masks) return result