From 403b483d77d56aea0215eba19f88ff30b6ab95ec Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Al=C3=A1n=20Mu=C3=B1oz?= <alan.munoz@ed.ac.uk>
Date: Tue, 4 Oct 2022 20:04:10 +0100
Subject: [PATCH] refactor(cells): make lineage consistent

---
 src/agora/io/cells.py | 26 ++++++++++++++++----------
 1 file changed, 16 insertions(+), 10 deletions(-)

diff --git a/src/agora/io/cells.py b/src/agora/io/cells.py
index 095b59c4..d1108691 100644
--- a/src/agora/io/cells.py
+++ b/src/agora/io/cells.py
@@ -3,7 +3,7 @@ import typing as t
 from collections.abc import Iterable
 from itertools import groupby
 from pathlib import Path, PosixPath
-from functools import lru_cache
+from functools import lru_cache, cached_property
 
 import h5py
 import numpy as np
@@ -264,10 +264,11 @@ class Cells:
         rand = np.random.randint(mat.sum())
         return (traps[rand], tps[rand])
 
+    @lru_cache(20)
     def mothers_in_trap(self, trap_id: int):
         return self.mothers[trap_id]
 
-    @property
+    @cached_property
     def mothers(self):
         """
         Return nested list with final prediction of mother id for each cell
@@ -279,24 +280,29 @@ class Cells:
             self.ntraps,
         )
 
-    @property
-    def mothers_daughters(self):
+    @cached_property
+    def mothers_daughters(self) -> np.ndarray:
+        """
+        Return mothers and daugters as a single array with three columns:
+        trap, mothers and daughters
+        """
         nested_massign = self.mothers
 
         if sum([x for y in nested_massign for x in y]):
-            mothers, daughters = zip(
-                *[
-                    ((tid, m), (tid, d))
+            mothers_daughters = np.array(
+                [
+                    (tid, m, d)
                     for tid, trapcells in enumerate(nested_massign)
                     for d, m in enumerate(trapcells, 1)
                     if m
-                ]
+                ],
+                dtype=np.uint16,
             )
         else:
-            mothers, daughters = ([], [])
+            mothers_daughters = np.array([])
             # print("Warning:Cells: No mother-daughters assigned")
 
-        return mothers, daughters
+        return mothers_daughters
 
     @staticmethod
     def mother_assign_to_mb_matrix(ma: t.List[np.array]):
-- 
GitLab