From 09bb025c5919d8fa4296a9cabbab8b20dac955b5 Mon Sep 17 00:00:00 2001
From: Diane Adjavon <diane.adjavon@ed.ac.uk>
Date: Sat, 30 May 2020 13:05:15 +0200
Subject: [PATCH] Multi-session for Baby Client. Pipelining.

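BabyClient now keeps one segmentation session per trap, queues images to
the baby server in batches, and reshapes the returned segmentation into
per-cell DataFrames appended to a pandas HDFStore. A new core/pipeline.py
chains processing steps (Experiment -> Tiler -> BabyClient) through a
shared store: each step implements run(keys, store) and returns the keys
the next step should process.

A sketch of the intended wiring, adapted from test/test_pipeline.py (the
experiment path is a placeholder):

    import pandas as pd

    from core.baby_client import BabyClient
    from core.experiment import ExperimentLocal
    from core.pipeline import Pipeline
    from core.segment import Tiler

    store = pd.HDFStore('store.h5')
    raw_expt = ExperimentLocal('/path/to/experiment')
    tiler = Tiler(raw_expt)
    # Model-selection parameters forwarded to choose_model_from_params
    baby_client = BabyClient(raw_expt, camera='evolve',
                             channel='brightfield', zoom='60x',
                             n_stacks='5z')

    # Each run yields position keys, tiles them, then segments the traps
    pipeline = Pipeline(store, pipeline_steps=[raw_expt, tiler, baby_client])
    pipeline.run(max_runs=1)
    store.close()
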
---
 .gitignore               |  3 ++
 core/baby_client.py      | 81 ++++++++++++++++++++++++++++++------
 core/experiment.py       | 23 ++++++++---
 core/pipeline.py         | 88 ++++++++++++++++++++++++++++++++++++++++
 core/segment.py          | 22 ++++------
 test/test_baby_client.py |  5 ++-
 test/test_pipeline.py    | 34 ++++++++++++++++
 7 files changed, 225 insertions(+), 31 deletions(-)
 create mode 100644 core/pipeline.py
 create mode 100644 test/test_pipeline.py

diff --git a/.gitignore b/.gitignore
index 2b136197..7ef2a437 100644
--- a/.gitignore
+++ b/.gitignore
@@ -112,3 +112,6 @@ omero_py/pipeline/
 **.ipynb
 data/
 notebooks/
+*.pdf
+*.h5
+*.hdf5
diff --git a/core/baby_client.py b/core/baby_client.py
index 1f0fb2b6..6f4ba523 100644
--- a/core/baby_client.py
+++ b/core/baby_client.py
@@ -1,6 +1,10 @@
+import itertools
 import json
+import numpy as np
+import pandas as pd
 import re
 import requests
+import tables
 from requests.exceptions import Timeout, HTTPError
 from requests_toolbelt.multipart.encoder import MultipartEncoder
 
@@ -16,6 +20,7 @@ class BabyNoSilent(Exception):
     pass
 
 
+# Todo: add defaults!
 def choose_model_from_params(valid_models,
                              modelset_filter=None, camera=None, channel=None,
                              zoom=None, n_stacks=None, **kwargs):
@@ -167,8 +172,7 @@ class BabyClient:
             for trap_id in trap_locations.columns:
                 # Finish processing previously queued images
                 self.flush_processing()
-                self.process_trap(prompt, trap_id,
-                                  trap_locations.filter(items=[trap_id]))
+                self.process_trap(prompt, trap_id, trap_locations)
         except KeyError as e:
             # TODO log that this will not be processed
             raise e
@@ -176,14 +180,58 @@ class BabyClient:
         while len(self.processing) > 0:
             self.flush_processing()
 
-    def process_trap(self, prompt, trap_id, trap_location):
-        tile = get_trap_timelapse(self.raw_expt, trap_location, trap_id)
+    # Todo: decide, based on the model configuration, what z should be
+    def process_trap(self, prompt, trap_id, trap_locations, tile_size=81,
+                     z=[0, 1, 2, 3, 4]):
+        tile = get_trap_timelapse(self.raw_expt, trap_locations, trap_id,
+                                  tile_size=tile_size, z=z)
+        tile = np.squeeze(tile)
+        # try:
+        #     self.store._handle.create_group(prompt, f'trap{trap_id}')
+        # except tables.exceptions.NodeError as e:
+        #     pass
         # Get the corresponding session
-        trap_key = prompt + f'trap{trap_id}'
+        trap_key = prompt + f'/trap{trap_id}'
         session_id = self.sessions[trap_key]
-        self.queue_image(tile, session_id)
-        self.processing.append(trap_key)
+        batches = np.array_split(tile, 8, axis=0)
+        for batch in batches:
+            self.queue_image(batch, session_id)
+            self.processing.append(trap_key)
+            self.flush_processing()
+
+    def format_seg_result(self, result, time_origin=0, max_size=16):
+        # Todo: update time origin at each step.
+        for i, res in enumerate(result):
+            res['timepoint'] = [i + time_origin] * len(res['cell_label'])
+        merged = {k: list(itertools.chain.from_iterable(
+            res[k] for res in result))
+            for k in result[0].keys()}
+        df = pd.DataFrame(merged)
+        df.set_index('timepoint', inplace=True)
+        if len(df) == 0:
+            return dict()
+
+        # Todo: split more systematically
+        for k in ['angles', 'radii']:
+            values = df[k].tolist()
+            for val in values:
+                val += [np.nan] * (max_size - len(val))
+            try:
+                df[[k + str(i) for i in range(max_size)]] = \
+                    pd.DataFrame(values, index=df.index)
+            except ValueError as e:
+                print(k)
+                print([len(val) for val in values])
+                print(result)
+                raise e
+        df[['centrex', 'centrey']] = pd.DataFrame(df['centres'].tolist(),
+                                                  index=df.index)
+        df.drop(['centres', 'angles', 'radii'], axis=1, inplace=True)
+
+        per_cell_dfs = {i: x for i, x in df.groupby(df['cell_label'])}
+        return per_cell_dfs
 
+    # Todo: batching
     def flush_processing(self):
         """
         Get the results of previously queued images.
@@ -191,15 +239,25 @@ class BabyClient:
         """
         for trap_key in self.processing:
             try:
-                segmentation = self.get_segmentation(self.sessions[trap_key])
-                #TODO format the segmentation then add to a table rather
-                # than just assigning
-                self.store[trap_key] = segmentation
+                result = self.get_segmentation(self.sessions[trap_key])
+                segmentation = self.format_seg_result(result)
+                for i, seg in segmentation.items():
+                    cell_key = trap_key + f'/cell{i}'
+                    try:
+                        self.store.append(cell_key, seg)
+                    except Exception as e:
+                        print(seg)
+                        raise e
                 self.processing.remove(trap_key)
             except Timeout:
                 continue
             except HTTPError:
                 continue
+            except TypeError:
+                raise
+            except KeyError:
+                print(self.store.keys())
+                raise
 
     def run(self, keys, store):
         if self.store is None:
@@ -207,4 +265,3 @@ class BabyClient:
         for prompt in keys:
             self.process_position(prompt)
         return keys
-
diff --git a/core/experiment.py b/core/experiment.py
index b340160d..bbbed2e4 100644
--- a/core/experiment.py
+++ b/core/experiment.py
@@ -41,7 +41,9 @@ class Experiment(abc.ABC):
     #metadata_parser = AcqMetadataParser()
 
     def __init__(self):
+        self.exptID = ''
         self._current_position = None
+        self.position_to_process = 0
 
     def __getitem__(self, item):
         return self.current_position[item]
@@ -97,7 +99,19 @@ class Experiment(abc.ABC):
                                                    z_positions, channels,
                                                    timepoints)
 
+    # Pipelining
+    def run(self, keys, store):
+        try:
+            self.current_position = self.positions[self.position_to_process]
+            # Todo: check if we should use the position's id or name
+            return ['/'.join(['', self.exptID, self.current_position.name])]
+            # Todo: write to store
+        except IndexError:
+            return None
 
+
+
+# Todo: cache images like in ExperimentLocal
 class ExperimentOMERO(Experiment):
     """
     Experiment class to organise different timelapses.
@@ -128,8 +142,8 @@ class ExperimentOMERO(Experiment):
 
     def get_position(self, position):
         """Get a Timelapse object for a given position by name"""
-        assert position in self.positions, "Position not available."
-        img = self.connection.getObject("Image", self._positions[position])
+        #assert position in self.positions, "Position not available."
+        img = self.connection.getObject("Image", self.positions[position])
         return TimelapseOMERO(img)
 
     def cache_locally(self, root_dir='./', positions=None, channels=None,
@@ -266,9 +280,8 @@ class ExperimentLocal(Experiment):
         return self._positions
 
     def get_position(self, position):
-        assert position in self.positions, "Position {} not available in {" \
-                                           "}.".format(position, self.positions)
-        # TODO cache positions?
+        # assert position in self.positions, "Position {} not available in {" \
+        #                                    "}.".format(position, self.positions)
         return TimelapseLocal(position, self.root_dir)
 
 
diff --git a/core/pipeline.py b/core/pipeline.py
new file mode 100644
index 00000000..78eaa305
--- /dev/null
+++ b/core/pipeline.py
@@ -0,0 +1,88 @@
+"""
+Pipeline and chaining elements.
+"""
+from abc import ABC, abstractmethod
+from typing import Iterable, List
+
+import pandas as pd
+import tables as tb
+
+from core.experiment import Experiment, ExperimentLocal, ExperimentOMERO
+from core.segment import Tiler
+from core.baby_client import BabyClient
+
+
+class Results:
+    """
+    Object storing the data from the Pipeline.
+    Uses pandas' HDFStore object.
+
+    In addition, it implements:
+     - IO functionality (read from file, write to file)
+
+    """
+    def __init__(self):
+        pass
+
+    def to_store(self):
+        pass
+
+    def from_json(self):
+        pass
+
+
+class PipelineStep(ABC):
+    @abstractmethod
+    def run(self, keys: List[str], store: pd.HDFStore) -> List[str]:
+        """
+        Abstract run method, when implemented by subclasses, runs analysis
+        on the keys and saves results in store.
+        :param keys: list of keys on which to run analysis
+        :return: A set of keys now available for anlaysis for the next step.
+        """
+        return keys
+
+
+class Pipeline:
+    """
+    A chained set of pipeline steps connected through shared keys and a store.
+    """
+
+    def __init__(self, store: pd.HDFStore, pipeline_steps: Iterable):
+        self.store = store
+        # Setup steps
+        self.steps = pipeline_steps
+
+    def run(self, max_runs=2):
+        keys = []
+        runs = 0
+        while runs < max_runs and keys is not None:
+            # TODO make run functions return None when finished
+            for pipeline_step in self.steps:
+                keys = pipeline_step.run(keys, self.store)
+            runs += 1
+
+# Todo future: we could gain more flexibility by using pytables directly.
+#  The pandas.HDFStore used at the moment does not allow writing arrays to
+#  the file, which would be more convenient for variable-sized data.
+class Store:
+    """
+    Implements an interface to pytables.
+    """
+    def __init__(self, filename: str, title: str = "Test file"):
+        self.h5file = tb.open_file(filename, mode='w', title=title)
+
+    def add_group(self, root: str, name: str, description=""):
+        # Todo: infer root from name
+        group = self.h5file.create_group(root, name, description)
+        return group
+
+    def add_array(self, group, name, values=None, title=""):
+        self.h5file.create_array(group, name, values, title)
+
+    def close(self):
+        self.h5file.close()
+
+
+
+
diff --git a/core/segment.py b/core/segment.py
index 2ff32bca..60b148fc 100644
--- a/core/segment.py
+++ b/core/segment.py
@@ -143,20 +143,16 @@ class Tiler(object):
         for pos in self.positions:
             self.current_position = pos #So tiling occurs
             store_key = '/'.join([expt_root, pos, 'trap_locations'])
+            if store_key in store:  # Clear any previous trap_locations
+                store.remove(store_key)
             store.append(store_key, self.trap_locations[pos])
         return
 
-    def fit_to_pipe(self, pipe, split_results=True):
-        """
-        Takes a pipe of core.timelapse.Timelapse objects and their
-        corresponding experiment ID and in return yields Results objects
-        with trap location results.
-
-        :param pipe: Input generator of Timelapse objects.
-        :param split_results: Determines whether the output Results are
-        split by trap or not.
-        :returns:
-        """
-
-        return
+    def run(self, keys, store):
+        for key in keys:
+            _, pos = key.rsplit('/', maxsplit=1)
+            self.current_position = pos  # So that tiling occurs
+            trap_loc_key = key + '/trap_locations'
+            store.put(trap_loc_key, self.trap_locations[pos])
+        return keys
 
diff --git a/test/test_baby_client.py b/test/test_baby_client.py
index 0a0ad6ac..08918061 100644
--- a/test/test_baby_client.py
+++ b/test/test_baby_client.py
@@ -39,8 +39,11 @@ try:
                                          with_edgemasks=False)
         while True:
             try:
-                result = baby_client.get_segmentation()
+                print('Loading.', end='')
+                result = baby_client.get_segmentation(baby_client.sessions[
+                                                          'default'])
             except:
+                print('.', end='')
                 time.sleep(2)
                 continue
             break
diff --git a/test/test_pipeline.py b/test/test_pipeline.py
new file mode 100644
index 00000000..8782c907
--- /dev/null
+++ b/test/test_pipeline.py
@@ -0,0 +1,34 @@
+import unittest
+from core.pipeline import Pipeline
+from core.pipeline import ExperimentLocal, Tiler, BabyClient
+
+import pandas as pd
+
+
+class TestCase(unittest.TestCase):
+    def setUp(self) -> None:
+        self.store = pd.HDFStore('store.h5')
+        root_dir = '/Users/s1893247/PhD/pipeline-core/data/glclvl_0' \
+                   '.1_mig1_msn2_maf1_sfp1_dot6_03'
+        raw_expt = ExperimentLocal(root_dir)
+        tiler = Tiler(raw_expt)
+
+        config = {"camera": "evolve",
+                  "channel": "brightfield",
+                  "zoom": "60x",
+                  "n_stacks": "5z"}
+
+        baby_client = BabyClient(raw_expt, **config)
+        self.pipeline = Pipeline(self.store,
+                                 pipeline_steps=[raw_expt, tiler, baby_client])
+
+    def test_run(self):
+        self.pipeline.run(max_runs=1)
+
+    def tearDown(self) -> None:
+        self.store.close()
+
+
+
+if __name__ == '__main__':
+    unittest.main()
-- 
GitLab