Skip to content
Snippets Groups Projects
Commit a9a6291d authored by Alán Muñoz's avatar Alán Muñoz
Browse files

fix(cells): _sample_masks works properly

parent 673d9649
No related branches found
No related tags found
No related merge requests found
...@@ -278,7 +278,7 @@ class Cells: ...@@ -278,7 +278,7 @@ class Cells:
def ntimepoints(self) -> int: def ntimepoints(self) -> int:
return self["timepoint"].max() + 1 return self["timepoint"].max() + 1
@property @cached_property
def _cells_vs_tps(self): def _cells_vs_tps(self):
# Binary matrix showing the presence of all cells in all time points # Binary matrix showing the presence of all cells in all time points
ncells_per_tile = [len(x) for x in self.labels] ncells_per_tile = [len(x) for x in self.labels]
...@@ -286,13 +286,29 @@ class Cells: ...@@ -286,13 +286,29 @@ class Cells:
(sum(ncells_per_tile), self.ntimepoints), dtype=bool (sum(ncells_per_tile), self.ntimepoints), dtype=bool
) )
cumsum = np.roll(np.cumsum(ncells_per_tile), shift=1)
cumsum[0] = 0
cells_vs_tps[ cells_vs_tps[
cumsum[self["trap"]] + self["cell_label"] - 1, self["timepoint"] self._cell_cumsum[self["trap"]] + self["cell_label"] - 1,
self["timepoint"],
] = True ] = True
return cells_vs_tps return cells_vs_tps
@cached_property
def _cell_cumsum(self):
    """
    Flat-index offset at which each tile's cells start.

    Entry ``i`` is the total number of labels held by the tiles that
    precede tile ``i`` in ``self.labels`` (so the first entry is 0).
    """
    tile_sizes = list(map(len, self.labels))
    # Shift the running total one slot to the right so that every tile
    # gets the count accumulated *before* it; the wrapped-around grand
    # total that lands in slot 0 is then overwritten with zero.
    offsets = np.roll(np.cumsum(tile_sizes), shift=1)
    offsets[0] = 0
    return offsets
def _flat_index_to_tuple_location(self, idx: int) -> t.Tuple[int, int]:
    """
    Map a flattened cell index back to its (tile id, cell label) pair.

    The index is assumed to run over tiles and cell labels laid out
    consecutively; time points play no role in the conversion.
    """
    offsets = self._cell_cumsum
    # The owning tile is the last one whose starting offset is not past idx.
    starts_not_after = np.flatnonzero(idx + 1 > offsets)
    tile_id = int(starts_not_after[-1])
    # Labels are 1-based within a tile, hence the +1.
    cell_label = idx - offsets[tile_id] + 1
    return tile_id, cell_label
@property @property
def _tiles_vs_cells_vs_tps(self): def _tiles_vs_cells_vs_tps(self):
ncells_mat = np.zeros( ncells_mat = np.zeros(
...@@ -304,39 +320,28 @@ class Cells: ...@@ -304,39 +320,28 @@ class Cells:
] = True ] = True
return ncells_mat return ncells_mat
def cell_tp_where(self, min_consecutive_tps: int = 15): def cell_tp_where(
self,
min_consecutive_tps: int = 15,
interval: None or t.Tuple[int, int] = None,
):
window = sliding_window_view( window = sliding_window_view(
self._cells_vs_tps, min_consecutive_tps, axis=1 self._cells_vs_tps, min_consecutive_tps, axis=1
) )
tp_min = window.sum(axis=-1) == min_consecutive_tps tp_min = window.sum(axis=-1) == min_consecutive_tps
return tp_min
def matrix_trap_tp_where(
self, min_ncells: int = 2, min_consecutive_tps: int = 5
):
"""
Return a matrix of shape (ntraps x ntps - min_consecutive_tps to
indicate traps and time-points where min_ncells are available for at least min_consecutive_tps
Parameters # Apply an interval filter to focus on a slice
--------- if interval is not None:
min_ncells: int Minimum number of cells interval = tuple(np.array(interval) - min_consecutive_tps // 2)
min_consecutive_tps: int else:
Minimum number of time-points a interval = (0, window.shape[1])
Returns low_boundary, high_boundary = interval
---------
(ntraps x ( ntps-min_consecutive_tps )) 2D boolean numpy array where rows are trap ids and columns are timepoint windows.
If the value in a cell is true its corresponding trap and timepoint contains more than min_ncells for at least min_consecutive time-points.
"""
window = sliding_window_view( tp_min[:, :low_boundary] = False
self._tiles_vs_cells_vs_tps, min_consecutive_tps, axis=2 tp_min[:, high_boundary:] = False
) return tp_min
tp_min = window.sum(axis=-1) == min_consecutive_tps
ncells_tp_min = tp_min.sum(axis=1) >= min_ncells
return ncells_tp_min
@lru_cache(20) @lru_cache(20)
def mothers_in_trap(self, trap_id: int): def mothers_in_trap(self, trap_id: int):
...@@ -572,10 +577,9 @@ class Cells: ...@@ -572,10 +577,9 @@ class Cells:
def _sample_tiles_tps( def _sample_tiles_tps(
self, self,
size=1, size=1,
# min_ncells: int = 2, min_consecutive_ntps: int = 15,
# max_ncells: int = 2,
min_consecutive_ntps: int = 10,
seed: int = 0, seed: int = 0,
interval=None,
) -> t.Tuple[np.ndarray, np.ndarray]: ) -> t.Tuple[np.ndarray, np.ndarray]:
""" """
Sample tiles that have a minimum number of cells and are occupied for at least a minimum number of consecutive timepoints. Sample tiles that have a minimum number of cells and are occupied for at least a minimum number of consecutive timepoints.
...@@ -590,10 +594,12 @@ class Cells: ...@@ -590,10 +594,12 @@ class Cells:
The minimum number of consecutive timepoints a cell must be present in a trap. The minimum number of consecutive timepoints a cell must be present in a trap.
seed: int, optional (default=0) seed: int, optional (default=0)
Random seed value for reproducibility. Random seed value for reproducibility.
interval: None or Tuple(int,int), optional (default=None)
Time-point interval (low, high) to restrict the sampling to; if None, all time points are considered.
Returns Returns
------- -------
Tuple[np.ndarray, np.ndarray] Tuple[np.ndarray, np.ndarray,np.ndarray]
A tuple of 1D numpy arrays containing the indices of the sampled tiles and the corresponding timepoints. A tuple of 1D numpy arrays containing the indices of the sampled tiles and the corresponding timepoints.
""" """
# cell_availability_matrix = self.matrix_trap_tp_where( # cell_availability_matrix = self.matrix_trap_tp_where(
...@@ -603,30 +609,43 @@ class Cells: ...@@ -603,30 +609,43 @@ class Cells:
# # Find all valid tiles with min_ncells for at least min_tps # # Find all valid tiles with min_ncells for at least min_tps
# tile_ids, tps = np.where(cell_availability_matrix) # tile_ids, tps = np.where(cell_availability_matrix)
cell_availability_matrix = self.cell_tp_where( cell_availability_matrix = self.cell_tp_where(
min_consecutive_tps=min_consecutive_ntps min_consecutive_tps=min_consecutive_ntps,
interval=interval,
) )
# Find all valid tiles with min_ncells for at least min_tps # Find all valid tiles with min_ncells for at least min_tps
index_id, _ = np.where(cell_availability_matrix) index_id, tps = np.where(cell_availability_matrix)
if interval is None: # Limit search
interval = (0, cell_availability_matrix.shape[1])
np.random.seed(seed) np.random.seed(seed)
choices = np.random.choice(index_id, size=size) choices = np.random.randint(len(index_id), size=size)
return (
self["trap"][choices], linear_indices = np.zeros_like(self["cell_label"], dtype=bool)
self["cell_label"][choices], for cell_index_flat, tp in zip(index_id[choices], tps[choices]):
self["timepoint"][choices], tile_id, cell_label = self._flat_index_to_tuple_location(
) cell_index_flat
)
linear_indices[
(
(self["cell_label"] == cell_label)
& (self["trap"] == tile_id)
& (self["timepoint"] == tp)
)
] = True
return linear_indices
def _sample_masks( def _sample_masks(
self, self,
size=1, size: int = 1,
min_ncells: int = 2, min_consecutive_ntps: int = 15,
min_consecutive_ntps: int = 5,
interval: t.Union[None, t.Tuple[int, int]] = None, interval: t.Union[None, t.Tuple[int, int]] = None,
seed=0, seed: int = 0,
) -> t.Tuple[t.Tuple[t.List[int], t.List[int], t.List[int]], np.ndarray]: ) -> t.Tuple[t.Tuple[t.List[int], t.List[int], t.List[int]], np.ndarray]:
""" """
Sample a number of cells from different tiles each. Sample a number of cells from within an interval.
Parameters Parameters
---------- ----------
...@@ -649,27 +668,53 @@ class Cells: ...@@ -649,27 +668,53 @@ class Cells:
The second tuple contains: The second tuple contains:
- `masks`: A list of 2D numpy arrays representing the binary masks of the sampled cells at each timepoint. - `masks`: A list of 2D numpy arrays representing the binary masks of the sampled cells at each timepoint.
""" """
tile_ids, _, tps = self._sample_tiles_tps( sampled_bitmask = self._sample_tiles_tps(
size=size, size=size,
min_consecutive_ntps=min_consecutive_ntps, min_consecutive_ntps=min_consecutive_ntps,
seed=seed, seed=seed,
interval=interval,
) )
# Sort sampled tiles to use automatic cache when possible # Sort sampled tiles to use automatic cache when possible
order = np.argsort(tps) tile_ids = self["trap"][sampled_bitmask]
tile_ids = tile_ids[order].tolist() cell_labels = self["cell_label"][sampled_bitmask]
tps = tps[order].tolist() tps = self["timepoint"][sampled_bitmask]
cell_ids = []
masks = [] masks = []
for tp, tile_id in zip(tps, tile_ids): for tile_id, cell_label, tp in zip(tile_ids, cell_labels, tps):
tile_masks = self.at_time(tp)[tile_id] local_idx = self.labels_at_time(tp)[tile_id].index(cell_label)
tile_mask = self.at_time(tp)[tile_id][local_idx]
masks.append(tile_mask)
return (tile_ids, cell_labels, tps), np.stack(masks)
np.random.seed(seed) def matrix_trap_tp_where(
cell_id = np.random.randint(len(tile_masks)) self, min_ncells: int = 2, min_consecutive_tps: int = 5
):
"""
NOTE: CURRENTLY UNUSED WITHIN ALIBY AT THE MOMENT. MAY BE USEFUL IN THE FUTURE.
Return a matrix of shape (ntraps x ntps - min_consecutive_tps) to
indicate traps and time-points where min_ncells are available for at least min_consecutive_tps
cell_ids.append(cell_id) Parameters
masks.append(tile_masks[cell_id]) ---------
return (tile_ids, cell_ids, tps), np.stack(masks) min_ncells: int Minimum number of cells
min_consecutive_tps: int
Minimum number of consecutive time-points for which min_ncells must be available.
Returns
---------
(ntraps x ( ntps-min_consecutive_tps )) 2D boolean numpy array where rows are trap ids and columns are timepoint windows.
If the value in a cell is true its corresponding trap and timepoint contains more than min_ncells for at least min_consecutive time-points.
"""
window = sliding_window_view(
self._tiles_vs_cells_vs_tps, min_consecutive_tps, axis=2
)
tp_min = window.sum(axis=-1) == min_consecutive_tps
ncells_tp_min = tp_min.sum(axis=1) >= min_ncells
return ncells_tp_min
def stack_masks_in_tile( def stack_masks_in_tile(
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment