diff --git a/io/cells.py b/io/cells.py index baabc05928c1dd6397a64c6fd44bb844e5047b03..0fc63fd41946903d36a8b619d27c364cbbcf3e43 100644 --- a/io/cells.py +++ b/io/cells.py @@ -12,82 +12,47 @@ from scipy.sparse.base import isdense from utils_find_1st import cmp_equal, find_1st -class Cells: +class CellsLinear: """ - An object that gathers information about all the cells in a given - trap. - This is the abstract object, used for type testing. + Extracts information from an h5 file. This class accesses: + + 'cell_info', which contains 'angles', 'cell_label', 'centres', + 'edgemasks', 'ellipse_dims', 'mother_assign', 'mother_assign_dynamic', + 'radii', 'timepoint', 'trap'. + All of these except for 'edgemasks' are a 1D ndarray. + + 'trap_info', which contains 'drifts', 'trap_locations' + """ + def __init__(self, filename, path="cell_info"): + self.filename: t.Optional[t.Union[str, PosixPath]] = filename + self.cinfo_path: t.Optional[str] = path + self._edgemasks: t.Optional[str] = None + self._tile_size: t.Optional[int] = None + @classmethod def from_source(cls, source: t.Union[PosixPath, str]): return cls(Path(source)) @staticmethod - def _asdense(array): + def _asdense(array: np.ndarray): if not isdense(array): array = array.todense() return array @staticmethod - def _astype(array, kind): + def _astype(array: np.ndarray, kind: str): # Convert sparse arrays if needed and if kind is 'mask' it fills the outline array = Cells._asdense(array) if kind == "mask": array = ndimage.binary_fill_holes(array).astype(bool) return array - @classmethod - def hdf(cls, fpath): - return CellsHDF(fpath) - - -class CellsHDF(Cells): - def __init__( - self, filename: t.Union[str, PosixPath], path: str = "cell_info" - ): - """ - Extracts information from an h5 file. This class accesses: - - 'cell_info', which contains 'angles', 'cell_label', 'centres', - 'edgemasks', 'ellipse_dims', 'mother_assign', 'mother_assign_dynamic', - 'radii', 'timepoint', 'trap'. - All of these except for 'edgemasks' are a 1D ndarray. - - 'trap_info', which contains 'drifts', 'trap_locations' - """ - self.filename = filename - self.cinfo_path = path - self._edgem_indices = None - self._edgemasks = None - self._tile_size = None - - def __getitem__(self, item: str): - """ - Defines attributes from the h5 file, which can then be accessed like items in a dictionary. - - Data is accessed from /cinfo_path in the h5 file via _fetch. - - Alan: is cells[X] and cells._X better than cells.X? - """ - if item == "edgemasks": - return self.edgemasks - else: - _item = "_" + item - if not hasattr(self, _item): - # define from the h5 file - setattr(self, _item, self._fetch(item)) - return getattr(self, _item) - def _get_idx(self, cell_id: int, trap_id: int): # returns boolean array of time points where both the cell with cell_id and the trap with trap_id exist return (self["cell_label"] == cell_id) & (self["trap"] == trap_id) - def _fetch(self, item: str): - # get data from /cinfo_path in h5 file - with h5py.File(self.filename, mode="r") as f: - return f[self.cinfo_path][item][()] - @property def max_labels(self) -> t.List[int]: return [max(self.labels_in_trap(i)) for i in range(self.ntraps)] @@ -108,24 +73,13 @@ class CellsHDF(Cells): # returns a list of traps return list(set(self["trap"])) - @property - def ntimepoints(self) -> int: - return self["timepoint"].max() + 1 - @property def tile_size(self) -> t.Union[int, t.Tuple[int], None]: if self._tile_size is None: with h5py.File(self.filename, mode="r") as f: - self._tile_size == f["trap_info/tile_size"][0] + self._tile_size = f["trap_info/tile_size"][0] return self._tile_size - @property - def edgem_indices(self) -> t.Union[np.ndarray, None]: - if self._edgem_indices is None: - edgem_path = "edgemasks/indices" - self._edgem_indices = load_complex(self._fetch(edgem_path)) - return self._edgem_indices - def nonempty_tp_in_trap(self, trap_id: int) -> set: # given a trap_id returns time points in which cells are available return set(self["timepoint"][self["trap"] == trap_id]) @@ -134,7 +88,7 @@ class CellsHDF(Cells): def edgemasks(self) -> t.List[np.ndarray]: # returns the masks per tile if self._edgemasks is None: - edgem_path = "edgemasks" + edgem_path: str = "edgemasks" self._edgemasks = self._fetch(edgem_path) return self._edgemasks @@ -170,27 +124,16 @@ class CellsHDF(Cells): edgem_ix, ) - def outline( - self, cell_id: int, trap_id: int - ) -> t.Tuple[t.List, np.ndarray]: - times, indices, cell_ix = self.where(cell_id, trap_id) - return times, self["edgemasks"][cell_ix, times] - def mask(self, cell_id, trap_id): times, outlines = self.outline(cell_id, trap_id) return times, np.array( [ndimage.morphology.binary_fill_holes(o) for o in outlines] ) - def at_time( - self, timepoint: int, kind: str = "mask" - ) -> t.Dict[int, np.ndarray]: + def at_time(self, timepoint, kind="mask"): ix = self["timepoint"] == timepoint - cell_ix = self["cell_label"][ix] traps = self["trap"][ix] - indices = traps + 1j * cell_ix - choose = np.in1d(self.edgem_indices, indices) - edgemasks = self["edgemasks"][choose, timepoint] + edgemasks = self._edgem_from_masking(ix) masks = [ self._astype(edgemask, kind) for edgemask in edgemasks @@ -198,34 +141,25 @@ class CellsHDF(Cells): ] return self.group_by_traps(traps, masks) - def group_by_traps(self, traps, data): - # returns a dict with traps as keys and labels as value - # Alan: what is data? - iterator = groupby(zip(traps, data), lambda x: x[0]) + def group_by_traps( + self, traps: t.Collection, cell_labels: t.Collection + ) -> t.Dict[int, t.List[t.int]]: + # returns a dict with traps as keys and list of labels as value + # Data is a + iterator = groupby(zip(traps, cell_labels), lambda x: x[0]) d = {key: [x[1] for x in group] for key, group in iterator} d = {i: d.get(i, []) for i in self.traps} return d - def labels_in_trap(self, trap_id): + def labels_in_trap(self, trap_id: int) -> t.Set[int]: # return set of cell ids for a given trap return set((self["cell_label"][self["trap"] == trap_id])) - def labels_at_time(self, timepoint): + def labels_at_time(self, timepoint: int) -> t.Dict[int, t.List[int]]: labels = self["cell_label"][self["timepoint"] == timepoint] traps = self["trap"][self["timepoint"] == timepoint] return self.group_by_traps(traps, labels) - -class CellsLinear(CellsHDF): - """ - Reimplement functions from CellsHDF to save edgemasks in a (N,tile_size, tile_size) array - - This overrides the previous implementation of at_time. - """ - - def __init__(self, filename, path="cell_info"): - super().__init__(filename, path=path) - def __getitem__(self, item): assert item != "edgemasks", "Edgemasks must not be loaded as a whole" @@ -249,22 +183,11 @@ class CellsLinear(CellsHDF): return edgem - def outline(self, cell_id, trap_id): + def outline(self, cell_id: int, trap_id: int): id_mask = self._get_idx(cell_id, trap_id) times = self["timepoint"][id_mask] - return times, self.edgem_from_masking(id_mask) - - def at_time(self, timepoint, kind="mask"): - ix = self["timepoint"] == timepoint - traps = self["trap"][ix] - edgemasks = self._edgem_from_masking(ix) - masks = [ - self._astype(edgemask, kind) - for edgemask in edgemasks - if edgemask.any() - ] - return self.group_by_traps(traps, masks) + return times, self._edgem_from_masking(id_mask) @property def ntimepoints(self) -> int: