Skip to content
Snippets Groups Projects
Commit 59ff4a72 authored by Alán Muñoz's avatar Alán Muñoz
Browse files

feat(cells): add _cell_tp_matrix

parent b5cc5524
No related branches found
No related tags found
No related merge requests found
......@@ -201,7 +201,7 @@ class Cells:
self, timepoints: t.Iterable[int], kind="mask"
) -> t.List[t.List[np.ndarray]]:
"""
Returns a list of lists of binary masks in a given list of time points.
Returns a list of lists of binary masks for a given list of time points.
Parameters
----------
......@@ -280,7 +280,22 @@ class Cells:
return self["timepoint"].max() + 1
@property
def ncells_matrix(self):
def _cells_vs_tps(self):
# Binary matrix showing the presence of all cells in all time points
ncells_per_tile = [len(x) for x in self.labels]
cells_vs_tps = np.zeros(
(sum(ncells_per_tile), self.ntimepoints), dtype=bool
)
cumsum = np.roll(np.cumsum(ncells_per_tile), shift=1)
cumsum[0] = 0
cells_vs_tps[
cumsum[self["trap"]] + self["cell_label"] - 1, self["timepoint"]
] = True
return cells_vs_tps
@property
def _tiles_vs_cells_vs_tps(self):
ncells_mat = np.zeros(
(self.ntraps, self["cell_label"].max(), self.ntimepoints),
dtype=bool,
......@@ -290,8 +305,16 @@ class Cells:
] = True
return ncells_mat
def cell_tp_where(self, min_consecutive_tps: int = 15):
window = sliding_window_view(
self._cells_vs_tps, min_consecutive_tps, axis=1
)
tp_min = window.sum(axis=-1) == min_consecutive_tps
return tp_min
def matrix_trap_tp_where(
self, min_ncells: int = None, min_consecutive_tps: int = None
self, min_ncells: int = 2, min_consecutive_tps: int = 5
):
"""
Return a matrix of shape (ntraps x ntps - min_consecutive_tps to
......@@ -308,13 +331,9 @@ class Cells:
(ntraps x ( ntps-min_consecutive_tps )) 2D boolean numpy array where rows are trap ids and columns are timepoint windows.
If the value in a cell is true its corresponding trap and timepoint contains more than min_ncells for at least min_consecutive time-points.
"""
if min_ncells is None:
min_ncells = 2
if min_consecutive_tps is None:
min_consecutive_tps = 5
window = sliding_window_view(
self.ncells_matrix, min_consecutive_tps, axis=2
self._tiles_vs_cells_vs_tps, min_consecutive_tps, axis=2
)
tp_min = window.sum(axis=-1) == min_consecutive_tps
ncells_tp_min = tp_min.sum(axis=1) >= min_ncells
......@@ -552,8 +571,9 @@ class Cells:
def _sample_tiles_tps(
self,
size=1,
min_ncells: int = 2,
min_consecutive_ntps: int = 5,
# min_ncells: int = 2,
# max_ncells: int = 2,
min_consecutive_ntps: int = 10,
seed: int = 0,
) -> t.Tuple[np.ndarray, np.ndarray]:
"""
......@@ -575,16 +595,26 @@ class Cells:
Tuple[np.ndarray, np.ndarray]
A tuple of 1D numpy arrays containing the indices of the sampled tiles and the corresponding timepoints.
"""
cell_availability_matrix = self.matrix_trap_tp_where(
min_ncells=min_ncells, min_consecutive_tps=min_consecutive_ntps
# cell_availability_matrix = self.matrix_trap_tp_where(
# min_ncells=min_ncells, min_consecutive_tps=min_consecutive_ntps
# )
# # Find all valid tiles with min_ncells for at least min_tps
# tile_ids, tps = np.where(cell_availability_matrix)
cell_availability_matrix = self.cell_tp_where(
min_consecutive_tps=min_consecutive_ntps
)
# Find all valid tiles with min_ncells for at least min_tps
tile_ids, tps = np.where(cell_availability_matrix)
index_id, _ = np.where(cell_availability_matrix)
np.random.seed(seed)
choices = np.random.choice(len(tile_ids), size=size)
return tile_ids[choices], tps[choices]
choices = np.random.choice(index_id, size=size)
return (
self["trap"][choices],
self["cell_label"][choices],
self["timepoint"][choices],
)
def _sample_masks(
self,
......@@ -617,9 +647,8 @@ class Cells:
The second tuple contains:
- `masks`: A list of 2D numpy arrays representing the binary masks of the sampled cells at each timepoint.
"""
tile_ids, tps = self._sample_tiles_tps(
tile_ids, _, tps = self._sample_tiles_tps(
size=size,
min_ncells=min_ncells,
min_consecutive_ntps=min_consecutive_ntps,
seed=seed,
)
......@@ -638,7 +667,7 @@ class Cells:
cell_ids.append(cell_id)
masks.append(tile_masks[cell_id])
return (tile_ids, tps, cell_ids), np.stack(masks)
return (tile_ids, cell_ids, tps), np.stack(masks)
def stack_masks_in_tile(
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment