diff --git a/src/agora/io/cells.py b/src/agora/io/cells.py index 9735d4f8cc4601c60c06e6227d4f3f241e717877..13c540a7d5945535eb8c72bf50408944d2142edb 100644 --- a/src/agora/io/cells.py +++ b/src/agora/io/cells.py @@ -20,12 +20,15 @@ class Cells: 'cell_info', which contains 'angles', 'cell_label', 'centres', 'edgemasks', 'ellipse_dims', 'mother_assign', 'mother_assign_dynamic', - 'radii', 'timepoint', 'trap'. - All of these except for 'edgemasks' are a 1D ndarray. + 'radii', 'timepoint', and 'trap'. All of these except for 'edgemasks' + are 1D ndarrays. - and + 'trap_info', which contains 'drifts' and 'trap_locations'. - 'trap_info', which contains 'drifts', 'trap_locations' + The "timepoint", "cell_label", and "trap" variables are consistent 1D lists. + + For example, self["timepoint"][self.get_idx(1, 3)] gives the time points + where cell 1 was present in trap 3. """ @@ -36,8 +39,26 @@ class Cells: self._edgemasks: t.Optional[str] = None self._tile_size: t.Optional[int] = None + def __getitem__(self, item): + """ + Dynamically fetch data from the h5 file and save as an attribute. + + These attributes are accessed like dict keys. 
+ """ + assert item != "edgemasks", "Edgemasks must not be loaded as a whole" + _item = "_" + item + if not hasattr(self, _item): + setattr(self, _item, self.fetch(item)) + return getattr(self, _item) + + def fetch(self, path): + """Get data from the h5 file.""" + with h5py.File(self.filename, mode="r") as f: + return f[self.cinfo_path][path][()] + @classmethod def from_source(cls, source: t.Union[Path, str]): + """Create an instance from a source given as a Path or a str.""" return cls(Path(source)) def _log(self, message: str, level: str = "warn"): @@ -46,31 +67,34 @@ class Cells: getattr(logger, level)(f"{self.__class__.__name__}: {message}") @staticmethod - def _asdense(array: np.ndarray): + def asdense(array: np.ndarray): + """Convert sparse array to dense array.""" if not isdense(array): array = array.todense() return array @staticmethod - def _astype(array: np.ndarray, kind: str): + def astype(array: np.ndarray, kind: str): """Convert sparse arrays if needed; if kind is 'mask' fill the outline.""" - array = Cells._asdense(array) + array = Cells.asdense(array) if kind == "mask": array = ndimage.binary_fill_holes(array).astype(bool) return array - def _get_idx(self, cell_id: int, trap_id: int): - """Return boolean array of time points where both the cell with cell_id and the trap with trap_id exist.""" + def get_idx(self, cell_id: int, trap_id: int): + """Return boolean array marking entries matching both cell_id and trap_id.""" return (self["cell_label"] == cell_id) & (self["trap"] == trap_id) @property def max_labels(self) -> t.List[int]: """Return the maximum cell label per tile.""" - return [max((0, *self.labels_in_trap(i))) for i in range(self.ntraps)] + return [ + max((0, *self.cell_labels_in_trap(i))) for i in range(self.ntraps) + ] @property def max_label(self) -> int: - """Return the highest maximum cell label per tile.""" + """Return the sum of the per-tile maximum cell labels.""" return sum(self.max_labels) @property @@ -87,7 +111,7 @@ class Cells: @property def traps(self) -> 
t.List[int]: - """List tile, or trap, IDs.""" + """List unique tile, or trap, IDs.""" return list(set(self["trap"])) @property @@ -95,110 +119,62 @@ class Cells: """Give the x- and y- sizes of a tile.""" if self._tile_size is None: with h5py.File(self.filename, mode="r") as f: - # self._tile_size = f["trap_info/tile_size"][0] self._tile_size = f["cell_info/edgemasks"].shape[1:] return self._tile_size def nonempty_tp_in_trap(self, trap_id: int) -> set: - """Given a tile_id, return time points for which cells are available.""" + """Given a tile, return time points for which cells are available.""" return set(self["timepoint"][self["trap"] == trap_id]) @property def edgemasks(self) -> t.List[np.ndarray]: - """Return a 3D array of masks for every cell.""" + """Return a list of masks for every cell at every trap and time point.""" if self._edgemasks is None: edgem_path: str = "edgemasks" - self._edgemasks = self._fetch(edgem_path) + self._edgemasks = self.fetch(edgem_path) return self._edgemasks @property def labels(self) -> t.List[t.List[int]]: - """ - Return all cell labels per tile. - - We use mother_assign to list tiles because it is the only property - that appears even when no cells are found. 
- """ - return [self.labels_in_trap(trap) for trap in range(self.ntraps)] + """Return all cell labels per tile as a set for all tiles.""" + return [self.cell_labels_in_trap(trap) for trap in range(self.ntraps)] - def max_labels_in_frame(self, frame: int) -> t.List[int]: - """Get the maximal cell label for each tile.""" + def max_labels_in_frame(self, final_time_point: int) -> t.List[int]: + """Get the maximal cell label for each tile within a frame of time.""" max_labels = [ self["cell_label"][ - (self["timepoint"] <= frame) & (self["trap"] == trap_id) + (self["timepoint"] <= final_time_point) + & (self["trap"] == trap_id) ] for trap_id in range(self.ntraps) ] return [max([0, *labels]) for labels in max_labels] def where(self, cell_id: int, trap_id: int): - """ - Parameters - ---------- - cell_id: int - Cell index - trap_id: int - Trap index - - Returns - ---------- - indices: int array - boolean mask array - edge_ix int array - """ - indices = self._get_idx(cell_id, trap_id) - edgem_ix = self._edgem_where(cell_id, trap_id) + """Return time points, indices, and edge masks for a cell and trap.""" + idx = self.get_idx(cell_id, trap_id) return ( - self["timepoint"][indices], - indices, - edgem_ix, + self["timepoint"][idx], + idx, + self.edgemasks_where(cell_id, trap_id), ) def mask(self, cell_id, trap_id): - """ - Return the times and the binary masks of a given cell in a given tile. - - Parameters - ---------- - cell_id : int - The unique ID of the cell. - tile_id : int - The unique ID of the tile. - - Returns - ------- - Tuple[np.ndarray, np.ndarray] - The times when the binary masks were taken and the binary masks of the given cell in the given tile. 
- - """ + """Return the times and the filled edge masks for a cell and trap.""" times, outlines = self.outline(cell_id, trap_id) return times, np.array( [ndimage.morphology.binary_fill_holes(o) for o in outlines] ) def at_time( - self, timepoint: t.Iterable[int], kind="mask" + self, timepoint: int, kind="mask" ) -> t.List[t.List[np.ndarray]]: - """ - Return a list of lists of binary masks in a given list of time points. - - Parameters - ---------- - timepoints : Iterable[int] - The list of time points for which to return the binary masks. - kind : str, optional - The type of binary masks to return, by default "mask". - - Returns - ------- - List[List[np.ndarray]] - A list of lists with binary masks grouped by tile IDs. - """ - ix = self["timepoint"] == timepoint - traps = self["trap"][ix] - edgemasks = self._edgem_from_masking(ix) + """Return a dict with traps as keys and cell masks as values for a time point.""" + idx = self["timepoint"] == timepoint + traps = self["trap"][idx] + edgemasks = self.edgemasks_from_idx(idx) masks = [ - self._astype(edgemask, kind) + Cells.astype(edgemask, kind) for edgemask in edgemasks if edgemask.any() ] @@ -207,22 +183,7 @@ class Cells: def at_times( self, timepoints: t.Iterable[int], kind="mask" ) -> t.List[t.List[np.ndarray]]: - """ - Return a list of lists of binary masks for a given list of time points. - - Parameters - ---------- - timepoints : Iterable[int] - The list of time points for which to return the binary masks. - kind : str, optional - The type of binary masks to return, by default "mask". - - Returns - ------- - List[List[np.ndarray]] - A list of lists with binary masks grouped by tile IDs. 
+ """Return a list of lists of cell masks, one for each specified time point.""" return [ [ np.stack(tile_masks) if len(tile_masks) else [] @@ -234,87 +195,77 @@ class Cells: def group_by_traps( self, traps: t.Collection, cell_labels: t.Collection ) -> t.Dict[int, t.List[int]]: - """Return a dict with traps as keys and a list of labels as value.""" + """Return a dict with traps as keys and a list of labels as values.""" iterator = groupby(zip(traps, cell_labels), lambda x: x[0]) d = {key: [x[1] for x in group] for key, group in iterator} d = {i: d.get(i, []) for i in self.traps} return d - def labels_in_trap(self, trap_id: int) -> t.Set[int]: - """Return set of cell ids for a given trap.""" + def cell_labels_in_trap(self, trap_id: int) -> t.Set[int]: + """Return unique cell labels for a given trap.""" return set((self["cell_label"][self["trap"] == trap_id])) def labels_at_time(self, timepoint: int) -> t.Dict[int, t.List[int]]: - """Return cell labels for each tile at the specified time point.""" + """Return a dict with traps as keys and cell labels as values for a time point.""" labels = self["cell_label"][self["timepoint"] == timepoint] traps = self["trap"][self["timepoint"] == timepoint] return self.group_by_traps(traps, labels) - def __getitem__(self, item): - """Define and return item as a underscored attribute.""" - assert item != "edgemasks", "Edgemasks must not be loaded as a whole" - _item = "_" + item - if not hasattr(self, _item): - setattr(self, _item, self._fetch(item)) - return getattr(self, _item) - - def _fetch(self, path): - with h5py.File(self.filename, mode="r") as f: - return f[self.cinfo_path][path][()] - - def _edgem_from_masking(self, mask): + def edgemasks_from_idx(self, idx): + """Get the edge masks selected by idx from the h5 file.""" with h5py.File(self.filename, mode="r") as f: - edgem = f[self.cinfo_path + "/edgemasks"][mask, ...] + edgem = f[self.cinfo_path + "/edgemasks"][idx, ...] 
return edgem - def _edgem_where(self, cell_id, trap_id): - id_mask = self._get_idx(cell_id, trap_id) - edgem = self._edgem_from_masking(id_mask) - - return edgem + def edgemasks_where(self, cell_id, trap_id): + """Get the edge masks for a given cell and trap for all time points.""" + idx = self.get_idx(cell_id, trap_id) + edgemasks = self.edgemasks_from_idx(idx) + return edgemasks def outline(self, cell_id: int, trap_id: int): - """Get times and masks for when cell_id is in trap_id.""" - id_mask = self._get_idx(cell_id, trap_id) - times = self["timepoint"][id_mask] - return times, self._edgem_from_masking(id_mask) + """Get times and edge masks for a given cell and trap.""" + idx = self.get_idx(cell_id, trap_id) + times = self["timepoint"][idx] + return times, self.edgemasks_from_idx(idx) @property def ntimepoints(self) -> int: - """Total number of time points in the experiment.""" + """Return total number of time points in the experiment.""" return self["timepoint"].max() + 1 @cached_property - def _cells_vs_tps(self): - """Binary matrix showing all cells in all time points.""" - ncells_per_tile = [len(x) for x in self.labels] - cells_vs_tps = np.zeros( - (sum(ncells_per_tile), self.ntimepoints), dtype=bool - ) + def cells_vs_tps(self): + """Boolean matrix showing when cells are present for all time points.""" + total_ncells = sum([len(x) for x in self.labels]) + cells_vs_tps = np.zeros((total_ncells, self.ntimepoints), dtype=bool) cells_vs_tps[ - self._cell_cumsum[self["trap"]] + self["cell_label"] - 1, + self.cell_cumlsum[self["trap"]] + self["cell_label"] - 1, self["timepoint"], ] = True return cells_vs_tps @cached_property - def _cell_cumsum(self): - """Cumulative sum indicating the number of cells per tile.""" + def cell_cumlsum(self): + """Find cumulative sum over tiles of the number of cells present.""" ncells_per_tile = [len(x) for x in self.labels] cumsum = np.roll(np.cumsum(ncells_per_tile), shift=1) cumsum[0] = 0 return cumsum - def 
_flat_index_to_tuple_location(self, idx: int) -> t.Tuple[int, int]: - """Convert a cell index to a tuple.""" - # Note that we assumes tiles and cell labels are flattened, but - # are agnostic to tps. - tile_id = int(np.where(idx + 1 > self._cell_cumsum)[0][-1]) - cell_label = idx - self._cell_cumsum[tile_id] + 1 + def index_to_tile_and_cell(self, idx: int) -> t.Tuple[int, int]: + """Convert an index to the equivalent pair of tile and cell IDs.""" + tile_id = int(np.where(idx + 1 > self.cell_cumlsum)[0][-1]) + cell_label = idx - self.cell_cumlsum[tile_id] + 1 return tile_id, cell_label @property - def _tiles_vs_cells_vs_tps(self): + def tiles_vs_cells_vs_tps(self): + """ + Boolean matrix showing if a cell is present. + + The matrix is indexed by trap, cell label, and time point. + """ ncells_mat = np.zeros( (self.ntraps, self["cell_label"].max(), self.ntimepoints), dtype=bool, @@ -329,11 +280,16 @@ class Cells: min_consecutive_tps: int = 15, interval: None or t.Tuple[int, int] = None, ): + """ + Find cells present for all time points in a sliding window of time. + + The result can be restricted to a particular interval of time. + """ window = sliding_window_view( - self._cells_vs_tps, min_consecutive_tps, axis=1 + self.cells_vs_tps, min_consecutive_tps, axis=1 ) tp_min = window.sum(axis=-1) == min_consecutive_tps - # apply an interval filter to focus on a slice + # apply a filter to restrict to an interval of time if interval is not None: interval = tuple(np.array(interval)) else: @@ -345,11 +301,17 @@ class Cells: @lru_cache(20) def mothers_in_trap(self, trap_id: int): + """Return mothers at a trap.""" return self.mothers[trap_id] @cached_property def mothers(self): - """Return nested list with final prediction of mother id for each cell in each tile.""" + """ + Return a list of mother IDs for each cell in each tile. + + Use Baby's "mother_assign_dynamic". + An ID of zero implies that no mother was assigned. 
+ """ return self.mother_assign_from_dynamic( self["mother_assign_dynamic"], self["cell_label"], @@ -365,19 +327,21 @@ class Cells: Returns ------- mothers_daughters: np.ndarray - An array with shape (n, 3) where n is the number of mother-daughter pairs found. - The first column is the tile_id for the tile where the mother cell is located. - The second column is the cell index of a mother cell in the tile. - The third column is the index of the corresponding daughter cell. + An array with shape (n, 3) where n is the number of mother-daughter + pairs found. The first column is the tile_id for the tile where the + mother cell is located. The second column is the cell index of a + mother cell in the tile. The third column is the index of the + corresponding daughter cell. """ - nested_massign = self.mothers - if sum([x for y in nested_massign for x in y]): + # list of arrays, one per tile, giving mothers of each cell in each tile + mothers = self.mothers + if sum([x for y in mothers for x in y]): mothers_daughters = np.array( [ - (tid, m, d) - for tid, trapcells in enumerate(nested_massign) - for d, m in enumerate(trapcells, 1) - if m + (trap_id, mother, bud) + for trap_id, trapcells in enumerate(mothers) + for bud, mother in enumerate(trapcells, start=1) + if mother ], dtype=np.uint16, ) @@ -389,11 +353,11 @@ class Cells: @staticmethod def mother_assign_to_mb_matrix(ma: t.List[np.array]): """ - Convert a list of mother-daughter lists into a boolean sparse matrix. + Convert a list of mother-daughters into a boolean sparse matrix. Each row in the matrix correspond to daughter buds. If an entry is True, a given cell is a mother cell and a given - daughter bud is assigned to the mother cell in the next timepoint. + daughter bud is assigned to the mother cell in the next time point. 
Parameters: ----------- @@ -435,43 +399,56 @@ class Cells: @staticmethod def mother_assign_from_dynamic( - ma: np.ndarray, cell_label: t.List[int], trap: t.List[int], ntraps: int + ma: np.ndarray, + cell_label: t.List[int], + trap: t.List[int], + ntraps: int, ) -> t.List[t.List[int]]: """ - Interpolate the associated mothers from the 'mother_assign_dynamic' feature. + Find mothers from Baby's 'mother_assign_dynamic' variable. Parameters ---------- ma: np.ndarray - An array with shape (n_t, n_c) containing the 'mother_assign_dynamic' feature. + An array of length number of time points times number of cells + containing the 'mother_assign_dynamic' produced by Baby. cell_label: List[int] - A list containing the cell labels. + A list of cell labels. trap: List[int] - A list containing the trap labels. + A list of trap labels. ntraps: int The total number of traps. Returns ------- List[List[int]] - A list of lists containing the interpolated mother assignment for each cell in each trap. + A list giving the mothers for each cell at each trap. 
""" - idlist = list(zip(trap, cell_label)) - cell_gid = np.unique(idlist, axis=0) + ids = np.unique(list(zip(trap, cell_label)), axis=0) + # find when each cell last appeared at its trap last_lin_preds = [ find_1st( - ((cell_label[::-1] == lbl) & (trap[::-1] == tr)), + ( + (cell_label[::-1] == cell_label_id) + & (trap[::-1] == trap_id) + ), True, cmp_equal, ) - for tr, lbl in cell_gid + for trap_id, cell_label_id in ids ] + # find the cell's mother using the latest prediction from Baby mother_assign_sorted = ma[::-1][last_lin_preds] - traps = cell_gid[:, 0] + # rearrange as a list of mother IDs for each cell in each tile + traps = ids[:, 0] iterator = groupby(zip(traps, mother_assign_sorted), lambda x: x[0]) - d = {key: [x[1] for x in group] for key, group in iterator} - nested_massign = [d.get(i, []) for i in range(ntraps)] - return nested_massign + d = {trap: [x[1] for x in mothers] for trap, mothers in iterator} + mothers = [d.get(i, []) for i in range(ntraps)] + return mothers + + ############################################################################### + # Apparently unused below here + ############################################################################### @lru_cache(maxsize=200) def labelled_in_frame( @@ -480,13 +457,15 @@ class Cells: """ Return labels in a 4D ndarray with potentially global ids. + Use lru_cache to cache the results for speed. + Parameters ---------- frame : int The frame number (time point). global_id : bool, optional - If True, the returned array contains global ids, otherwise it - contains only the local ids of the labels. Default is False. + If True, the returned array contains global ids, otherwise only + the local ids of the labels. Returns ------- @@ -495,17 +474,12 @@ class Cells: The array has dimensions (ntraps, max_nlabels, ysize, xsize), where max_nlabels is specific for this frame, not the entire experiment. - - Notes - ----- - This method uses lru_cache to cache the results for faster access. 
""" labels_in_frame = self.labels_at_time(frame) n_labels = [ len(labels_in_frame.get(trap_id, [])) for trap_id in range(self.ntraps) ] - # maxes = self.max_labels_in_frame(frame) stacks_in_frame = self.get_stacks_in_frame(frame, self.tile_size) first_id = np.cumsum([0, *n_labels]) labels_mat = np.zeros( @@ -539,7 +513,9 @@ class Cells: self, frame: int, tile_shape: t.Tuple[int] ) -> t.List[np.ndarray]: """ - Return a list of stacked masks, each corresponding to a tile at a given time point. + Return a list of stacked masks. + + Each corresponds to a tile at a given time point. Parameters ---------- @@ -551,7 +527,7 @@ class Cells: Returns ------- List[np.ndarray] - List of stacked masks for each tile at the given timepoint. + List of stacked masks for each tile at the given time point. """ masks = self.at_time(frame) return [ @@ -561,7 +537,7 @@ class Cells: for trap_id in range(self.ntraps) ] - def _sample_tiles_tps( + def sample_tiles_tps( self, size=1, min_consecutive_ntps: int = 15, @@ -607,9 +583,7 @@ class Cells: choices = np.random.randint(len(index_id), size=size) linear_indices = np.zeros_like(self["cell_label"], dtype=bool) for cell_index_flat, tp in zip(index_id[choices], tps[choices]): - tile_id, cell_label = self._flat_index_to_tuple_location( - cell_index_flat - ) + tile_id, cell_label = self.index_to_tile_and_cell(cell_index_flat) linear_indices[ ( (self["cell_label"] == cell_label) @@ -619,7 +593,7 @@ class Cells: ] = True return linear_indices - def _sample_masks( + def sample_masks( self, size: int = 1, min_consecutive_ntps: int = 15, @@ -650,7 +624,7 @@ class Cells: The second tuple contains: - `masks`: A list of 2D numpy arrays representing the binary masks of the sampled cells at each timepoint. 
""" - sampled_bitmask = self._sample_tiles_tps( + sampled_bitmask = self.sample_tiles_tps( size=size, min_consecutive_ntps=min_consecutive_ntps, seed=seed, @@ -688,7 +662,7 @@ class Cells: If the value in a cell is true its corresponding trap and timepoint contains more than min_ncells for at least min_consecutive time-points. """ window = sliding_window_view( - self._tiles_vs_cells_vs_tps, min_consecutive_tps, axis=2 + self.tiles_vs_cells_vs_tps, min_consecutive_tps, axis=2 ) tp_min = window.sum(axis=-1) == min_consecutive_tps ncells_tp_min = tp_min.sum(axis=1) >= min_ncells diff --git a/src/aliby/pipeline.py b/src/aliby/pipeline.py index f373702cf23f4dddf3b0fc0a742b9460c6712190..b6f58f29c3b8cf82d91c2b50971e74b5387229b9 100644 --- a/src/aliby/pipeline.py +++ b/src/aliby/pipeline.py @@ -39,7 +39,12 @@ class PipelineParameters(ParametersABC): _pool_index = None def __init__( - self, general, tiler, baby, extraction, postprocessing, reporting + self, + general, + tiler, + baby, + extraction, + postprocessing, ): """Initialise, but called by a class method - not directly.""" self.general = general @@ -47,7 +52,6 @@ class PipelineParameters(ParametersABC): self.baby = baby self.extraction = extraction self.postprocessing = postprocessing - self.reporting = reporting @classmethod def default( @@ -155,8 +159,6 @@ class PipelineParameters(ParametersABC): defaults["postprocessing"] = PostProcessorParameters.default( **postprocessing ).to_dict() - # TODO reporting - defaults["reporting"] = {} return cls(**{k: v for k, v in defaults.items()}) def load_logs(self): diff --git a/src/postprocessor/core/processor.py b/src/postprocessor/core/processor.py index 008b6ebf080db88399d780b03ebe118e88133ee5..6ccd510497087f93b8d71389737a61c374a415ea 100644 --- a/src/postprocessor/core/processor.py +++ b/src/postprocessor/core/processor.py @@ -132,15 +132,17 @@ class PostProcessor(ProcessABC): def run_prepost(self): """ - Run picker and merger and get lineages. 
+ Run merger, get lineages, and then run picker. Necessary before any processes can run. """ # run merger record = self.signal.get_raw(self.targets["prepost"]["merger"]) merges = self.merger.run(record) - # get lineages from picker + # get lineages from cells object attached to picker lineage = _assoc_indices_to_3d(self.picker.cells.mothers_daughters) + if not np.any(lineage): + breakpoint() if merges.any(): # update lineages and merges after merging new_lineage, new_merges = merge_lineage(lineage, merges) diff --git a/src/postprocessor/core/reshapers/picker.py b/src/postprocessor/core/reshapers/picker.py index b333f18300689885d4af34d1cfd1bb85e1b10ab5..201edd738cc829db78ccc370a59938e535d8bd6e 100644 --- a/src/postprocessor/core/reshapers/picker.py +++ b/src/postprocessor/core/reshapers/picker.py @@ -53,8 +53,10 @@ class Picker(LineageProcess): mothers_daughters: t.Optional[np.ndarray] = None, ) -> pd.MultiIndex: """ - Return rows of a signal corresponding to either mothers, daughters, - or mother-daughter pairs using lineage information. + Return rows of a signal using lineage information. + + Rows correspond to either mothers, daughters, or mother-daughter + pairs. """ cells_present = drop_mother_label(signal.index) mothers_daughters = self.get_lineage_information(signal) @@ -65,10 +67,10 @@ class Picker(LineageProcess): def run(self, signal): """ - Pick indices from the index of a signal's dataframe and return - as an array. + Pick indices from the index of a signal's dataframe. Typically, we first pick by lineage, then by condition. + The indices are returned as an array. """ self.orig_signal = signal indices = set(signal.index)