diff --git a/io/cells.py b/io/cells.py
index baabc05928c1dd6397a64c6fd44bb844e5047b03..0fc63fd41946903d36a8b619d27c364cbbcf3e43 100644
--- a/io/cells.py
+++ b/io/cells.py
@@ -12,82 +12,47 @@ from scipy.sparse.base import isdense
 from utils_find_1st import cmp_equal, find_1st
 
 
-class Cells:
+class CellsLinear:
     """
-    An object that gathers information about all the cells in a given
-    trap.
-    This is the abstract object, used for type testing.
+    Extracts information from an h5 file. This class accesses:
+
+    'cell_info', which contains 'angles', 'cell_label', 'centres',
+    'edgemasks', 'ellipse_dims', 'mother_assign', 'mother_assign_dynamic',
+    'radii', 'timepoint', 'trap'.
+    All of these except for 'edgemasks' are a 1D ndarray.
+
+    'trap_info', which contains 'drifts', 'trap_locations'
+
     """
 
+    def __init__(self, filename, path="cell_info"):
+        self.filename: t.Optional[t.Union[str, PosixPath]] = filename
+        self.cinfo_path: t.Optional[str] = path
+        self._edgemasks: t.Optional[str] = None
+        self._tile_size: t.Optional[int] = None
+
     @classmethod
     def from_source(cls, source: t.Union[PosixPath, str]):
         return cls(Path(source))
 
     @staticmethod
-    def _asdense(array):
+    def _asdense(array: np.ndarray):
         if not isdense(array):
             array = array.todense()
         return array
 
     @staticmethod
-    def _astype(array, kind):
+    def _astype(array: np.ndarray, kind: str):
         # Convert sparse arrays if needed and if kind is 'mask' it fills the outline
         array = Cells._asdense(array)
         if kind == "mask":
             array = ndimage.binary_fill_holes(array).astype(bool)
         return array
 
-    @classmethod
-    def hdf(cls, fpath):
-        return CellsHDF(fpath)
-
-
-class CellsHDF(Cells):
-    def __init__(
-        self, filename: t.Union[str, PosixPath], path: str = "cell_info"
-    ):
-        """
-        Extracts information from an h5 file. This class accesses:
-
-        'cell_info', which contains 'angles', 'cell_label', 'centres',
-        'edgemasks', 'ellipse_dims', 'mother_assign', 'mother_assign_dynamic',
-        'radii', 'timepoint', 'trap'.
-        All of these except for 'edgemasks' are a 1D ndarray.
-
-        'trap_info', which contains 'drifts', 'trap_locations'
-        """
-        self.filename = filename
-        self.cinfo_path = path
-        self._edgem_indices = None
-        self._edgemasks = None
-        self._tile_size = None
-
-    def __getitem__(self, item: str):
-        """
-        Defines attributes from the h5 file, which can then be accessed like items in a dictionary.
-
-        Data is accessed from /cinfo_path in the h5 file via _fetch.
-
-        Alan: is cells[X] and cells._X better than cells.X?
-        """
-        if item == "edgemasks":
-            return self.edgemasks
-        else:
-            _item = "_" + item
-            if not hasattr(self, _item):
-                # define from the h5 file
-                setattr(self, _item, self._fetch(item))
-            return getattr(self, _item)
-
     def _get_idx(self, cell_id: int, trap_id: int):
         # returns boolean array of time points where both the cell with cell_id and the trap with trap_id exist
         return (self["cell_label"] == cell_id) & (self["trap"] == trap_id)
 
-    def _fetch(self, item: str):
-        # get data from /cinfo_path in h5 file
-        with h5py.File(self.filename, mode="r") as f:
-            return f[self.cinfo_path][item][()]
-
     @property
     def max_labels(self) -> t.List[int]:
         return [max(self.labels_in_trap(i)) for i in range(self.ntraps)]
@@ -108,24 +73,13 @@ class CellsHDF(Cells):
         # returns a list of traps
         return list(set(self["trap"]))
 
-    @property
-    def ntimepoints(self) -> int:
-        return self["timepoint"].max() + 1
-
     @property
     def tile_size(self) -> t.Union[int, t.Tuple[int], None]:
         if self._tile_size is None:
             with h5py.File(self.filename, mode="r") as f:
-                self._tile_size == f["trap_info/tile_size"][0]
+                self._tile_size = f["trap_info/tile_size"][0]
         return self._tile_size
 
-    @property
-    def edgem_indices(self) -> t.Union[np.ndarray, None]:
-        if self._edgem_indices is None:
-            edgem_path = "edgemasks/indices"
-            self._edgem_indices = load_complex(self._fetch(edgem_path))
-        return self._edgem_indices
-
     def nonempty_tp_in_trap(self, trap_id: int) -> set:
         # given a trap_id returns time points in which cells are available
         return set(self["timepoint"][self["trap"] == trap_id])
@@ -134,7 +88,7 @@ class CellsHDF(Cells):
     def edgemasks(self) -> t.List[np.ndarray]:
         # returns the masks per tile
         if self._edgemasks is None:
-            edgem_path = "edgemasks"
+            edgem_path: str = "edgemasks"
             self._edgemasks = self._fetch(edgem_path)
         return self._edgemasks
 
@@ -170,27 +124,16 @@ class CellsHDF(Cells):
             edgem_ix,
         )
 
-    def outline(
-        self, cell_id: int, trap_id: int
-    ) -> t.Tuple[t.List, np.ndarray]:
-        times, indices, cell_ix = self.where(cell_id, trap_id)
-        return times, self["edgemasks"][cell_ix, times]
-
     def mask(self, cell_id, trap_id):
         times, outlines = self.outline(cell_id, trap_id)
         return times, np.array(
             [ndimage.morphology.binary_fill_holes(o) for o in outlines]
         )
 
-    def at_time(
-        self, timepoint: int, kind: str = "mask"
-    ) -> t.Dict[int, np.ndarray]:
+    def at_time(self, timepoint, kind="mask"):
         ix = self["timepoint"] == timepoint
-        cell_ix = self["cell_label"][ix]
         traps = self["trap"][ix]
-        indices = traps + 1j * cell_ix
-        choose = np.in1d(self.edgem_indices, indices)
-        edgemasks = self["edgemasks"][choose, timepoint]
+        edgemasks = self._edgem_from_masking(ix)
         masks = [
             self._astype(edgemask, kind)
             for edgemask in edgemasks
@@ -198,34 +141,25 @@ class CellsHDF(Cells):
         ]
         return self.group_by_traps(traps, masks)
 
-    def group_by_traps(self, traps, data):
-        # returns a dict with traps as keys and labels as value
-        # Alan: what is data?
-        iterator = groupby(zip(traps, data), lambda x: x[0])
+    def group_by_traps(
+        self, traps: t.Collection, cell_labels: t.Collection
+    ) -> t.Dict[int, t.List[t.int]]:
+        # returns a dict with traps as keys and list of labels as value
+        # Data is a
+        iterator = groupby(zip(traps, cell_labels), lambda x: x[0])
         d = {key: [x[1] for x in group] for key, group in iterator}
         d = {i: d.get(i, []) for i in self.traps}
         return d
 
-    def labels_in_trap(self, trap_id):
+    def labels_in_trap(self, trap_id: int) -> t.Set[int]:
         # return set of cell ids for a given trap
         return set((self["cell_label"][self["trap"] == trap_id]))
 
-    def labels_at_time(self, timepoint):
+    def labels_at_time(self, timepoint: int) -> t.Dict[int, t.List[int]]:
         labels = self["cell_label"][self["timepoint"] == timepoint]
         traps = self["trap"][self["timepoint"] == timepoint]
         return self.group_by_traps(traps, labels)
 
-
-class CellsLinear(CellsHDF):
-    """
-    Reimplement functions from CellsHDF to save edgemasks in a (N,tile_size, tile_size) array
-
-    This overrides the previous implementation of at_time.
-    """
-
-    def __init__(self, filename, path="cell_info"):
-        super().__init__(filename, path=path)
-
     def __getitem__(self, item):
         assert item != "edgemasks", "Edgemasks must not be loaded as a whole"
 
@@ -249,22 +183,11 @@ class CellsLinear(CellsHDF):
 
         return edgem
 
-    def outline(self, cell_id, trap_id):
+    def outline(self, cell_id: int, trap_id: int):
         id_mask = self._get_idx(cell_id, trap_id)
         times = self["timepoint"][id_mask]
 
-        return times, self.edgem_from_masking(id_mask)
-
-    def at_time(self, timepoint, kind="mask"):
-        ix = self["timepoint"] == timepoint
-        traps = self["trap"][ix]
-        edgemasks = self._edgem_from_masking(ix)
-        masks = [
-            self._astype(edgemask, kind)
-            for edgemask in edgemasks
-            if edgemask.any()
-        ]
-        return self.group_by_traps(traps, masks)
+        return times, self._edgem_from_masking(id_mask)
 
     @property
     def ntimepoints(self) -> int: