From 7861f7d6a4369ed2be9dcbc66af8cc56a8a10411 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Al=C3=A1n=20Mu=C3=B1oz?= <alan.munoz@ed.ac.uk>
Date: Mon, 27 Feb 2023 12:00:15 +0000
Subject: [PATCH] docs(Cells): Document all complex methods

---
 src/agora/io/cells.py | 265 +++++++++++++++++++++++++++++++++++-------
 1 file changed, 225 insertions(+), 40 deletions(-)

diff --git a/src/agora/io/cells.py b/src/agora/io/cells.py
index 0e2891f3..2f5a8536 100644
--- a/src/agora/io/cells.py
+++ b/src/agora/io/cells.py
@@ -146,12 +146,47 @@ class Cells:
         )
 
     def mask(self, cell_id, trap_id):
+        """
+        Returns the times and the binary masks of a given cell in a given tile.
+
+        Parameters
+        ----------
+        cell_id : int
+            The unique ID of the cell.
+        tile_id : int
+            The unique ID of the tile.
+
+        Returns
+        -------
+        Tuple[np.ndarray, np.ndarray]
+            The times when the binary masks were taken and the binary masks of the given cell in the given tile.
+
+        """
         times, outlines = self.outline(cell_id, trap_id)
         return times, np.array(
             [ndimage.morphology.binary_fill_holes(o) for o in outlines]
         )
 
-    def at_time(self, timepoint, kind="mask"):
+    def at_time(
+        self, timepoint: t.Iterable[int], kind="mask"
+    ) -> t.List[t.List[np.ndarray]]:
+        """
+        Returns a list of lists of binary masks in a given list of time points.
+
+        Parameters
+        ----------
+        timepoints : Iterable[int]
+            The list of time points for which to return the binary masks.
+        kind : str, optional
+            The type of binary masks to return, by default "mask".
+
+        Returns
+        -------
+        List[List[np.ndarray]]
+            A list of lists with binary masks grouped by tile IDs.
+
+        """
+
         ix = self["timepoint"] == timepoint
         traps = self["trap"][ix]
         edgemasks = self._edgem_from_masking(ix)
@@ -162,7 +197,25 @@ class Cells:
         ]
         return self.group_by_traps(traps, masks)
 
-    def at_times(self, timepoints: t.Iterable[int], kind="mask"):
+    def at_times(
+        self, timepoints: t.Iterable[int], kind="mask"
+    ) -> t.List[t.List[np.ndarray]]:
+        """
+        Returns a list of lists of binary masks in a given list of time points.
+
+        Parameters
+        ----------
+        timepoints : Iterable[int]
+            The list of time points for which to return the binary masks.
+        kind : str, optional
+            The type of binary masks to return, by default "mask".
+
+        Returns
+        -------
+        List[List[np.ndarray]]
+            A list of lists with binary masks grouped by tile IDs.
+
+        """
         return [
             [
                 np.stack(tile_masks) if len(tile_masks) else []
@@ -267,17 +320,6 @@ class Cells:
         ncells_tp_min = tp_min.sum(axis=1) >= min_ncells
         return ncells_tp_min
 
-    def random_valid_trap_tp(
-        self, min_ncells: int = None, min_consecutive_tps: int = None
-    ):
-        # Return a randomly-selected pair of trap_id and timepoints
-        mat = self.matrix_trap_tp_where(
-            min_ncells=min_ncells, min_consecutive_tps=min_consecutive_tps
-        )
-        traps, tps = np.where(mat)
-        rand = np.random.randint(mat.sum())
-        return (traps[rand], tps[rand])
-
     @lru_cache(20)
     def mothers_in_trap(self, trap_id: int):
         return self.mothers[trap_id]
@@ -297,8 +339,17 @@ class Cells:
     @cached_property
     def mothers_daughters(self) -> np.ndarray:
         """
-        Return mothers and daugters as a single array with three columns:
-        trap, mothers and daughters
+        Return a single array with three columns, containing information about the mother-daughter relationships:
+        tile, mothers and daughters.
+
+        Returns
+        -------
+        np.ndarray
+            An array with shape (n, 3) where n is the number of mother-daughter pairs found.
+            The columns contain:
+            - tile: the tile where the mother cell is located.
+            - mothers: the index of the mother cell within the tile. If there is no mother, this is zero.
+            - daughters: the index of the daughter cell within the tile. If there is no daughter, this is zero.
         """
         nested_massign = self.mothers
 
@@ -320,7 +371,41 @@ class Cells:
 
     @staticmethod
     def mother_assign_to_mb_matrix(ma: t.List[np.array]):
-        # Convert from list of lists to mother_bud sparse matrix
+        """
+        Convert from a list of lists of mother-bud paired assignments to a
+        sparse matrix with a boolean dtype. The rows correspond to
+        to daughter buds. The values are boolean and indicate whether a
+        given cell is a mother cell and a given daughter bud is assigned
+        to the mother cell in the next timepoint.
+
+        Parameters:
+        -----------
+        ma : list of lists of integers
+            A list of lists of mother-bud assignments. The i-th sublist contains the
+            bud assignments for the i-th tile. The integers in each sublist
+            represent the mother label, if it is zero no mother was found.
+
+        Returns:
+        --------
+        mb_matrix : boolean numpy array of shape (n, m)
+            An n x m boolean numpy array where n is the total number of cells (sum
+            of the lengths of all sublists in ma) and m is the maximum number of buds
+            assigned to any mother cell in ma. The value at (i, j) is True if cell i
+            is a daughter cell and cell j is its mother assigned to i.
+
+        Examples:
+        --------
+        ma = [[0, 0, 1], [0, 1, 0]]
+        Cells(None).mother_assign_to_mb_matrix(ma)
+        # array([[False, False, False, False, False, False],
+        #        [False, False, False, False, False, False],
+        #        [ True, False, False, False, False, False],
+        #        [False, False, False, False, False, False],
+        #        [False, False, False,  True, False, False],
+        #        [False, False, False, False, False, False]])
+
+        """
+
         ncells = sum([len(t) for t in ma])
         mb_matrix = np.zeros((ncells, ncells), dtype=bool)
         c = 0
@@ -335,10 +420,26 @@ class Cells:
 
     @staticmethod
     def mother_assign_from_dynamic(
-        ma, cell_label: t.List[int], trap: t.List[int], ntraps: int
-    ):
+        ma: np.ndarray, cell_label: t.List[int], trap: t.List[int], ntraps: int
+    ) -> t.List[t.List[int]]:
         """
-        Interpolate the list of lists containing the associated mothers from the mother_assign_dynamic feature
+        Interpolate the associated mothers from the 'mother_assign_dynamic' feature.
+
+        Parameters
+        ----------
+        ma: np.ndarray
+            An array with shape (n_t, n_c) containing the 'mother_assign_dynamic' feature.
+        cell_label: List[int]
+            A list containing the cell labels.
+        trap: List[int]
+            A list containing the trap labels.
+        ntraps: int
+            The total number of traps.
+
+        Returns
+        -------
+        List[List[int]]
+            A list of lists containing the interpolated mother assignment for each cell in each trap.
         """
         idlist = list(zip(trap, cell_label))
         cell_gid = np.unique(idlist, axis=0)
@@ -363,12 +464,29 @@ class Cells:
     @lru_cache(maxsize=200)
     def labelled_in_frame(self, frame: int, global_id=False) -> np.ndarray:
         """
-        Return labels in a ndarray with the global ids
-        with shape (ntraps, max_nlabels, ysize, xsize)
-        at a given frame.
+        Returns labels in a 4D ndarray with the global ids with shape
+        (ntraps, max_nlabels, ysize, xsize) at a given frame.
+
+        Parameters
+        ----------
+        frame : int
+            The frame number.
+        global_id : bool, optional
+            If True, the returned array contains global ids, otherwise it
+            contains only the local ids of the labels. Default is False.
+
+        Returns
+        -------
+        np.ndarray
+            A 4D numpy array containing the labels in the given frame.
+            The array has dimensions (ntraps, max_nlabels, ysize, xsize),
+            where max_nlabels is specific for this frame, not the entire
+            experiment.
+
+        Notes
+        -----
+        This method uses lru_cache to cache the results for faster access.
 
-        max_nlabels is specific for this frame, not
-        the entire experiment.
         """
         labels_in_frame = self.labels_at_time(frame)
         n_labels = [
@@ -405,8 +523,24 @@ class Cells:
                 labels_mat[trap_id] += global_id_masks
         return labels_mat
 
-    def get_stacks_in_frame(self, frame: int, tile_shape: t.Tuple[int]):
-        # Stack all cells in a trap-wise manner
+    def get_stacks_in_frame(
+        self, frame: int, tile_shape: t.Tuple[int]
+    ) -> t.List[np.ndarray]:
+        """
+        Returns a list of stacked masks, each corresponding to a tile at a given timepoint.
+
+        Parameters
+        ----------
+        frame : int
+            Frame for which to obtain the stacked masks.
+        tile_shape : Tuple[int]
+            Shape of a tile to stack the masks into.
+
+        Returns
+        -------
+        List[np.ndarray]
+            List of stacked masks for each tile at the given timepoint.
+        """
         masks = self.at_time(frame)
         return [
             stack_masks_in_tile(
@@ -415,12 +549,32 @@ class Cells:
             for trap_id in range(self.ntraps)
         ]
 
-    def _sample_occupied_tiles_tp(
+    def _sample_tiles_tps(
         self,
         size=1,
         min_ncells: int = 2,
         min_consecutive_ntps: int = 5,
-    ):
+        seed: int = 0,
+    ) -> t.Tuple[np.ndarray, np.ndarray]:
+        """
+        Sample tiles that have a minimum number of cells and are occupied for at least a minimum number of consecutive timepoints.
+
+        Parameters
+        ----------
+        size: int, optional (default=1)
+            The number of tiles to sample.
+        min_ncells: int, optional (default=2)
+            The minimum number of cells per tile.
+        min_consecutive_ntps: int, optional (default=5)
+            The minimum number of consecutive timepoints a cell must be present in a trap.
+        seed: int, optional (default=0)
+            Random seed value for reproducibility.
+
+        Returns
+        -------
+        Tuple[np.ndarray, np.ndarray]
+            A tuple of 1D numpy arrays containing the indices of the sampled tiles and the corresponding timepoints.
+        """
         cell_availability_matrix = self.matrix_trap_tp_where(
             min_ncells=min_ncells, min_consecutive_tps=min_consecutive_ntps
         )
@@ -428,6 +582,7 @@ class Cells:
         # Find all valid tiles with min_ncells for at least min_tps
         tile_ids, tps = np.where(cell_availability_matrix)
 
+        np.random.seed(seed)
         choices = np.random.choice(len(tile_ids), size=size)
         return tile_ids[choices], tps[choices]
 
@@ -436,24 +591,54 @@ class Cells:
         size=1,
         min_ncells: int = 2,
         min_consecutive_ntps: int = 5,
-    ):
-        """Sample a number of cells from different traps each."""
-        tile_ids, tps = self._sample_occupied_tiles_tp(
+        seed=0,
+    ) -> t.Tuple[t.Tuple[t.List[int], t.List[int], t.List[int]], np.ndarray]:
+        """
+        Sample a number of cells from different tiles each.
+
+        Parameters
+        ----------
+        size: int, optional (default=1)
+            The number of cells to sample.
+        min_ncells: int, optional (default=2)
+            The minimum number of cells per tile.
+        min_consecutive_ntps: int, optional (default=5)
+            The minimum number of consecutive timepoints a cell must be present in a trap.
+        seed: int, optional (default=0)
+            Random seed value for reproducibility.
+
+        Returns
+        -------
+        Tuple[Tuple[np.ndarray, np.ndarray, List[int]], List[np.ndarray]]
+            Two tuples are returned. The first tuple contains:
+            - `tile_ids`: A 1D numpy array of the tile ids that correspond to the tile identifier.
+            - `tps`: A 1D numpy array of the timepoints at which the cells were sampled.
+            - `cell_ids`: A list of integers that correspond to the local id of the sampled cells.
+            The second tuple contains:
+            - `masks`: A list of 2D numpy arrays representing the binary masks of the sampled cells at each timepoint.
+        """
+        tile_ids, tps = self._sample_tiles_tps(
             size=size,
             min_ncells=min_ncells,
             min_consecutive_ntps=min_consecutive_ntps,
+            seed=seed,
         )
         #  Sort sampled tiles to use automatic cache when possible
         order = np.argsort(tps)
-        tile_ids = tile_ids[order]
-        tps = tps[order]
+        tile_ids = tile_ids[order].tolist()
+        tps = tps[order].tolist()
 
-        return (tile_ids, tps), np.array(
-            [
-                self.at_time(tp)[tile_ids[tile_id]]
-                for tp, tile_id in zip(tps, tile_ids)
-            ]
-        )
+        cell_ids = []
+        masks = []
+        for tp, tile_id in zip(tps, tile_ids):
+            tile_masks = self.at_time(tp)[tile_id]
+
+            np.random.seed(seed)
+            cell_id = np.random.randint(len(tile_masks))
+
+            cell_ids.append(cell_id)
+            masks.append(tile_masks[cell_id])
+        return (tile_ids, tps, cell_ids), np.stack(masks)
 
 
 def stack_masks_in_tile(
@@ -462,5 +647,5 @@ def stack_masks_in_tile(
     # Stack all masks in a trap padding accordingly if no outlines found
     result = np.zeros((0, *tile_shape), dtype=bool)
     if len(masks):
-        result = np.array(masks)
+        result = np.stack(masks)
     return result
-- 
GitLab