Skip to content
Snippets Groups Projects
Commit 96c53f04 authored by Alán Muñoz's avatar Alán Muñoz
Browse files

[WIP] refactor(cells): clean up and add typing

parent 337bb7ea
No related branches found
No related tags found
No related merge requests found
......@@ -12,82 +12,47 @@ from scipy.sparse.base import isdense
from utils_find_1st import cmp_equal, find_1st
class Cells:
class CellsLinear:
"""
An object that gathers information about all the cells in a given
trap.
This is the abstract object, used for type testing.
Extracts information from an h5 file. This class accesses:
'cell_info', which contains 'angles', 'cell_label', 'centres',
'edgemasks', 'ellipse_dims', 'mother_assign', 'mother_assign_dynamic',
'radii', 'timepoint', 'trap'.
All of these except for 'edgemasks' are a 1D ndarray.
'trap_info', which contains 'drifts', 'trap_locations'
"""
def __init__(self, filename, path="cell_info"):
    """
    Record where the cell data lives; nothing is read from disk here.

    Parameters
    ----------
    filename
        Location of the h5 file (a string or a path object).
    path
        Name of the group inside the h5 file that `_fetch` reads from.
    """
    self.filename: t.Optional[t.Union[str, PosixPath]] = filename
    self.cinfo_path: t.Optional[str] = path
    # Lazily filled caches; None means "not loaded yet".
    # NOTE(review): `_edgemasks` stores whatever `_fetch("edgemasks")`
    # returns, which is array data rather than a string -- confirm the
    # exact container type.
    self._edgemasks: t.Optional[np.ndarray] = None
    self._tile_size: t.Optional[int] = None
@classmethod
def from_source(cls, source: t.Union[PosixPath, str]):
    """
    Alternate constructor: wrap `source` in a `Path` and build an instance.
    """
    source_path = Path(source)
    return cls(source_path)
@staticmethod
def _asdense(array: np.ndarray):
    """
    Return `array` in dense form.

    Dense inputs are handed back untouched; anything else is converted
    through its `todense()` method.
    """
    # NOTE(review): the annotation says ndarray, but the `todense()` branch
    # only makes sense for sparse inputs -- confirm and widen the hint.
    if not isdense(array):
        array = array.todense()
    return array
@classmethod
def _astype(cls, array: np.ndarray, kind: str):
    """
    Densify `array` and, when `kind` is "mask", fill the outline.

    Parameters
    ----------
    array
        Edge mask; it is first passed through `_asdense`.
    kind
        "mask" returns the outline with its holes filled, as a boolean
        array. Any other value returns the densified input unchanged.
    """
    # Resolve the helper through `cls` rather than a hard-coded class name,
    # so the lookup keeps working if the class is renamed or subclassed.
    array = cls._asdense(array)
    if kind == "mask":
        array = ndimage.binary_fill_holes(array).astype(bool)
    return array
@classmethod
def hdf(cls, fpath):
    """
    Build a `CellsHDF` reader for the file at `fpath`.
    """
    reader = CellsHDF(fpath)
    return reader
class CellsHDF(Cells):
def __init__(
    self, filename: t.Union[str, PosixPath], path: str = "cell_info"
):
    """
    Remember which h5 file and group to read; no data is loaded here.

    This class accesses:
    'cell_info', which contains 'angles', 'cell_label', 'centres',
    'edgemasks', 'ellipse_dims', 'mother_assign', 'mother_assign_dynamic',
    'radii', 'timepoint', 'trap'.
    All of these except for 'edgemasks' are a 1D ndarray.
    'trap_info', which contains 'drifts', 'trap_locations'
    """
    # Caches start empty and are filled on first access.
    self._tile_size = None
    self._edgemasks = None
    self._edgem_indices = None
    # Where to read from.
    self.cinfo_path = path
    self.filename = filename
def __getitem__(self, item: str):
    """
    Dictionary-style access to datasets stored in the h5 file.

    "edgemasks" is served by the `edgemasks` attribute. Any other name is
    read once through `_fetch` and cached on the instance as `_<item>`.

    NOTE (Alan): is cells[X] and cells._X better than cells.X?
    """
    if item == "edgemasks":
        return self.edgemasks
    cache_name = "_" + item
    if hasattr(self, cache_name):
        return getattr(self, cache_name)
    # First access: load from the h5 file and keep it for later lookups.
    setattr(self, cache_name, self._fetch(item))
    return getattr(self, cache_name)
def _get_idx(self, cell_id: int, trap_id: int):
    """
    Boolean selector for the entries where the cell label is `cell_id`
    and the trap is `trap_id`.
    """
    has_label = self["cell_label"] == cell_id
    in_trap = self["trap"] == trap_id
    return has_label & in_trap
def _fetch(self, item: str):
    """
    Read dataset `item` from the `cinfo_path` group of the h5 file.

    The file is opened read-only and closed again before returning.
    """
    with h5py.File(self.filename, mode="r") as h5_file:
        group = h5_file[self.cinfo_path]
        return group[item][()]
@property
def max_labels(self) -> t.List[int]:
    """
    Largest cell label in each trap, for trap numbers 0 .. ntraps - 1.
    """
    maxima = []
    for trap_id in range(self.ntraps):
        maxima.append(max(self.labels_in_trap(trap_id)))
    return maxima
......@@ -108,24 +73,13 @@ class CellsHDF(Cells):
# returns a list of traps
return list(set(self["trap"]))
@property
def ntimepoints(self) -> int:
    """
    Number of time points: one more than the largest "timepoint" value.
    """
    last_timepoint = self["timepoint"].max()
    return last_timepoint + 1
@property
def tile_size(self) -> t.Union[int, t.Tuple[int], None]:
    """
    Tile size, read from "trap_info/tile_size" in the h5 file.

    The first entry of the dataset is read on first access and cached in
    `_tile_size`, so later accesses do not reopen the file.
    """
    if self._tile_size is None:
        with h5py.File(self.filename, mode="r") as f:
            # Assignment, not comparison: a bare `==` here would discard
            # the value and leave the cache empty.
            self._tile_size = f["trap_info/tile_size"][0]
    return self._tile_size
@property
def edgem_indices(self) -> t.Union[np.ndarray, None]:
    """
    Indices stored under "edgemasks/indices", loaded once and cached.

    The raw dataset is obtained through `_fetch` and converted with
    `load_complex` before being kept in `_edgem_indices`.
    """
    if self._edgem_indices is not None:
        return self._edgem_indices
    raw_indices = self._fetch("edgemasks/indices")
    self._edgem_indices = load_complex(raw_indices)
    return self._edgem_indices
def nonempty_tp_in_trap(self, trap_id: int) -> set:
    """
    Set of time points that have at least one entry for trap `trap_id`.
    """
    timepoints = self["timepoint"]
    in_trap = self["trap"] == trap_id
    return set(timepoints[in_trap])
......@@ -134,7 +88,7 @@ class CellsHDF(Cells):
def edgemasks(self) -> t.List[np.ndarray]:
    """
    Return the masks per tile.

    The "edgemasks" dataset is read through `_fetch` on first use and
    cached in `_edgemasks`; later calls return the cached value.
    """
    if self._edgemasks is None:
        self._edgemasks = self._fetch("edgemasks")
    return self._edgemasks
......@@ -170,27 +124,16 @@ class CellsHDF(Cells):
edgem_ix,
)
def outline(
    self, cell_id: int, trap_id: int
) -> t.Tuple[t.List, np.ndarray]:
    """
    Time points and edge masks for one cell in one trap.

    `where` supplies the time points and the cell index; both are used to
    index into the "edgemasks" data as `[cell_ix, times]`.
    """
    # The middle value returned by `where` is not needed here.
    times, _, cell_ix = self.where(cell_id, trap_id)
    all_edgemasks = self["edgemasks"]
    return times, all_edgemasks[cell_ix, times]
def mask(self, cell_id, trap_id):
    """
    Time points and filled masks for one cell in one trap.

    Each outline returned by `outline` has its holes filled; the results
    are stacked into a single array.

    Returns
    -------
    tuple
        `(times, masks)`, with `times` exactly as returned by `outline`.
    """
    times, outlines = self.outline(cell_id, trap_id)
    # Use the top-level `ndimage` function, as `_astype` does, rather than
    # the deprecated `ndimage.morphology` namespace.
    filled = [ndimage.binary_fill_holes(o) for o in outlines]
    return times, np.array(filled)
def at_time(
self, timepoint: int, kind: str = "mask"
) -> t.Dict[int, np.ndarray]:
def at_time(self, timepoint, kind="mask"):
ix = self["timepoint"] == timepoint
cell_ix = self["cell_label"][ix]
traps = self["trap"][ix]
indices = traps + 1j * cell_ix
choose = np.in1d(self.edgem_indices, indices)
edgemasks = self["edgemasks"][choose, timepoint]
edgemasks = self._edgem_from_masking(ix)
masks = [
self._astype(edgemask, kind)
for edgemask in edgemasks
......@@ -198,34 +141,25 @@ class CellsHDF(Cells):
]
return self.group_by_traps(traps, masks)
def group_by_traps(
    self, traps: t.Collection, cell_labels: t.Collection
) -> t.Dict[int, t.List[t.Any]]:
    """
    Group per-cell values by the trap they belong to.

    Parameters
    ----------
    traps
        Trap id of each entry.
    cell_labels
        Value attached to each entry, paired one-to-one with `traps`.
        The visible callers pass cell labels or masks, positionally.

    Returns
    -------
    dict
        One key per trap in `self.traps`, each mapped to the list of its
        values in input order; traps without entries map to an empty list.
    """
    # Accumulate by key instead of using itertools.groupby: groupby only
    # merges consecutive equal keys, so unsorted input would silently
    # overwrite earlier groups for the same trap.
    grouped: t.Dict[int, t.List[t.Any]] = {}
    for trap, value in zip(traps, cell_labels):
        grouped.setdefault(trap, []).append(value)
    return {trap: grouped.get(trap, []) for trap in self.traps}
def labels_in_trap(self, trap_id: int) -> t.Set[int]:
    """
    Return the set of cell labels recorded for trap `trap_id`.
    """
    return set(self["cell_label"][self["trap"] == trap_id])
def labels_at_time(self, timepoint: int) -> t.Dict[int, t.List[int]]:
    """
    Cell labels present at `timepoint`, grouped by trap.

    Returns the result of `group_by_traps` applied to the traps and labels
    of the entries recorded at that time point.
    """
    # Build the time-point selector once and reuse it for both lookups.
    at_timepoint = self["timepoint"] == timepoint
    labels = self["cell_label"][at_timepoint]
    traps = self["trap"][at_timepoint]
    return self.group_by_traps(traps, labels)
class CellsLinear(CellsHDF):
"""
Reimplement functions from CellsHDF to save edgemasks in a (N,tile_size, tile_size) array
This overrides the previous implementation of at_time.
"""
def __init__(self, filename, path="cell_info"):
    # Adds no state of its own: both arguments are forwarded unchanged to
    # the parent initialiser.
    super().__init__(filename, path=path)
def __getitem__(self, item):
assert item != "edgemasks", "Edgemasks must not be loaded as a whole"
......@@ -249,22 +183,11 @@ class CellsLinear(CellsHDF):
return edgem
def outline(self, cell_id: int, trap_id: int):
    """
    Time points and edge masks for one cell in one trap.

    Returns
    -------
    tuple
        `(times, edgemasks)`: the "timepoint" values selected by
        `_get_idx`, and the result of `_edgem_from_masking` for the same
        selector.
    """
    id_mask = self._get_idx(cell_id, trap_id)
    times = self["timepoint"][id_mask]
    # The helper is the private `_edgem_from_masking`, as used by `at_time`.
    return times, self._edgem_from_masking(id_mask)

def at_time(self, timepoint, kind="mask"):
    """
    Masks of the cells present at `timepoint`, grouped by trap.

    Edge masks with no set pixel are dropped; the rest are converted with
    `_astype` according to `kind` and passed to `group_by_traps`.
    """
    ix = self["timepoint"] == timepoint
    traps = self["trap"][ix]
    edgemasks = self._edgem_from_masking(ix)
    # NOTE(review): dropping empty edge masks can leave `masks` shorter
    # than `traps`, so the pairing in `group_by_traps` may shift -- confirm
    # whether empty masks can occur here.
    masks = [
        self._astype(edgemask, kind)
        for edgemask in edgemasks
        if edgemask.any()
    ]
    return self.group_by_traps(traps, masks)
@property
def ntimepoints(self) -> int:
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment