From 102a92bcc8d767df0f2d9852a978b5137e6f1ccb Mon Sep 17 00:00:00 2001
From: Diane Adjavon <diane.adjavon@ed.ac.uk>
Date: Wed, 14 Oct 2020 08:53:05 +0200
Subject: [PATCH] make download from OMERO pipeline-based
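
Downloading from OMERO now happens in batches of timepoints rather than
in one pass: ExperimentOMERO.run() caches the next `keys` timepoints of
every position and records per-position progress in a SQL database
(sqlalchemy), so an interrupted or still-running experiment can be
resumed from the last cached timepoint. File and tag annotations are
cached alongside the images, and helper scripts are added for creating
experiments, fetching metadata, and running the pipeline locally.

A sketch of the intended entry point (the IDs and paths below are
examples only):

    python scripts/create_experiment.py download config.json \
        --expt_id 10863 --time 5 --save_dir ./data/ --db sqlite:///out.db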

---
 core/experiment.py               | 58 +++++++++++++------
 core/segment.py                  |  2 +-
 scripts/create_experiment.py     | 86 +++++++++++++++++++++++++++
 scripts/get_expt_metadata.py     | 58 +++++++++++++++++++
 scripts/load_test_experiments.py | 60 +++++++++++++++++++
 scripts/run_pipeline.py          | 99 ++++++++++++++++++++++++++++++++
 setup.py                         |  6 +-
 test/test_sqlalchemy.py          | 64 +++++++++++++++++++++
 test/test_tiler.py               | 63 ++++++++++++++++++++
 9 files changed, 473 insertions(+), 23 deletions(-)
 create mode 100644 scripts/create_experiment.py
 create mode 100644 scripts/get_expt_metadata.py
 create mode 100644 scripts/load_test_experiments.py
 create mode 100644 scripts/run_pipeline.py
 create mode 100644 test/test_sqlalchemy.py
 create mode 100644 test/test_tiler.py

diff --git a/core/experiment.py b/core/experiment.py
index d4fa07b4..35ccbacf 100644
--- a/core/experiment.py
+++ b/core/experiment.py
@@ -20,6 +20,7 @@ from logfile_parser import Parser
 
 from core.timelapse import TimelapseOMERO, TimelapseLocal
 from core.utils import accumulate
+from database.records import Position
 
 logger = logging.getLogger(__name__)
 
@@ -131,7 +132,9 @@ class ExperimentOMERO(Experiment):
         # Set up the current position as the first in the list
         self._current_position = self.get_position(self.positions[0])
 
-        self.save_dir = Path(kwargs.get('save_dir', './'))
+        self.save_dir = Path(kwargs.get('save_dir', './')) / self.name
+        if not self.save_dir.exists():
+            self.save_dir.mkdir(parents=True)
         self.running_tp = 0
 
     @property
@@ -141,7 +144,7 @@ class ExperimentOMERO(Experiment):
     def get_position(self, position):
         """Get a Timelapse object for a given position by name"""
         # assert position in self.positions, "Position not available."
-        img = self.connection.getObject("Image", self.positions[position])
+        img = self.connection.getObject("Image", self._positions[position])
         return TimelapseOMERO(img)
 
     def cache_locally(self, root_dir='./', positions=None, channels=None,
@@ -186,37 +189,43 @@ class ExperimentOMERO(Experiment):
         logger.info('Downloaded experiment {}'.format(self.exptID))
 
     # Todo: turn this static
-    def cache_annotations(self, save_dir):
+    def cache_annotations(self, save_dir, **kwargs):
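+        """Save the dataset's annotations under save_dir.
+
+        File annotations are written to disk, skipping .mat files unless
+        save_mat=True and skipping files that are already cached; tag
+        annotations are grouped by description into omero_tags.json.
+        """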
         # Save the file annotations
+        save_mat = kwargs.get('save_mat', False)
         tags = dict()# and the tag annotations
         for annotation in self.dataset.listAnnotations():
             if isinstance(annotation, omero.gateway.FileAnnotationWrapper):
                 filepath = save_dir / annotation.getFileName()
-                if filepath.stem.endswith('.mat'):
-                    mode = 'wb'
-                else:
-                    mode = 'w'
-                with open(str(filepath), mode) as fd:
-                    for chunk in annotation:
-                        fd.write(chunk)
+                if not filepath.exists() and (save_mat or filepath.suffix != '.mat'):
+                    with open(str(filepath), 'wb') as fd:
+                        for chunk in annotation.getFileInChunks():
+                            fd.write(chunk)
             if isinstance(annotation, omero.gateway.TagAnnotationWrapper):
-                # TODO save TagAnnotations in tags dictionary
-                pass
+                # Group tag annotations by description; values under the
+                # same description accumulate into a list
+                key = annotation.getDescription()
+                if key == '':
+                    key = 'misc. tags'
+                if key in tags:
+                    if not isinstance(tags[key], list):
+                        tags[key] = [tags[key]]
+                    tags[key].append(annotation.getValue())
+                else:
+                    tags[key] = annotation.getValue()
         with open(str(save_dir / 'omero_tags.json'), 'w') as fd:
             json.dump(tags, fd)
         return
 
     # Todo: turn this static
     def cache_set(self, save_dir, position: TimelapseOMERO,
-                  timepoints: Iterable[int]):
+                  timepoints: Iterable[int], db_pos, **kwargs):
-        # Todo: save one time point to file
-        #       save it under self.save_dir / self.exptID / self.position
-        #       save each channel, z_position separately
+        # Save each (channel, z-section) slice of every requested timepoint
+        # as a separate image, recording progress on db_pos so an
+        # interrupted download can resume where it left off
         pos_dir = save_dir / position.name
         if not pos_dir.exists():
             pos_dir.mkdir()
-        for channel in tqdm(position.channels):
-            for tp in tqdm(timepoints):
+        for tp in tqdm(timepoints):
+            for channel in tqdm(position.channels):
                 for z_pos in tqdm(range(position.size_z)):
                     ch_id = position.get_channel_index(channel)
                     image = position.get_hypercube(x=None, y=None,
@@ -230,9 +239,10 @@ class ExperimentOMERO(Experiment):
                                                                z_pos + 1)
                     cv2.imwrite(str(pos_dir / im_name), np.squeeze(
                         image))
+            # tp is 0-based, so tp + 1 timepoints are now safely on disk
+            db_pos.n_timepoints = tp + 1
         return list(itertools.product([position.name], timepoints))
 
-    def run(self, keys: Union[list, int], **kwargs):
+    def run(self, keys: Union[list, int], session, **kwargs):
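+        # Cache the next `keys` timepoints of every position, using the
+        # Position records in `session` to resume from the last saved point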
         if self.running_tp == 0:
             self.cache_annotations(self.save_dir)
         if isinstance(keys, list):
@@ -241,17 +251,27 @@ class ExperimentOMERO(Experiment):
         # Locally save `keys` images at a time for each position
         cached = []
         for pos_name in self.positions:
+            db_pos = session.query(Position).filter_by(name=pos_name).first()
+            if db_pos is None:
+                db_pos = Position(name=pos_name, n_timepoints=0)
+                session.add(db_pos)
             position = self.get_position(pos_name)
-            timepoints = list(range(self.running_tp,
-                                    min(self.running_tp + keys,
+            timepoints = list(range(db_pos.n_timepoints,
+                                    min(db_pos.n_timepoints + keys,
                                         position.size_t)))
-            cached += self.cache_set(self.save_dir, position, timepoints)
+            if timepoints:
+                try:
+                    cached += self.cache_set(self.save_dir, position,
+                                             timepoints, db_pos, **kwargs)
+                finally:
+                    # Add position to storage
+                    session.commit()
         self.running_tp += keys  # increase by number of processed time points
         return cached
 
 
 class ExperimentLocal(Experiment):
-    def __init__(self, root_dir, finished=False):
+    def __init__(self, root_dir, finished=True):
         super(ExperimentLocal, self).__init__()
         self.root_dir = Path(root_dir)
         self.exptID = self.root_dir.name
diff --git a/core/segment.py b/core/segment.py
index 01dbfab8..263285ac 100644
--- a/core/segment.py
+++ b/core/segment.py
@@ -29,7 +29,7 @@ def get_tile_shapes(x, tile_size, max_shape):
 
 
 class Tiler:
-    def __init__(self, raw_expt, finished=False, template=None):
+    def __init__(self, raw_expt, finished=True, template=None):
         self.expt = raw_expt
         self.finished = finished
         if template is None: 
diff --git a/scripts/create_experiment.py b/scripts/create_experiment.py
new file mode 100644
index 00000000..782139f3
--- /dev/null
+++ b/scripts/create_experiment.py
@@ -0,0 +1,86 @@
+import os
+import sys
+
+sys.path.insert(0, os.path.dirname(os.path.dirname(__file__)))
+import json
+import logging
+from logging.handlers import RotatingFileHandler
+
+import click
+import sqlalchemy as sa
+from sqlalchemy.orm import sessionmaker
+
+from database.records import Base
+
+from core.experiment import Experiment
+logger = logging.getLogger('core')
+logger.handlers = []
+logger.setLevel(logging.DEBUG)
+
+console = logging.StreamHandler()
+console.setLevel(logging.WARNING)
+formatter = logging.Formatter('%(name)-12s: %(levelname)-8s %(message)s')
+console.setFormatter(formatter)
+logger.addHandler(console)
+
+file_handler = RotatingFileHandler(filename='test.log',
+                                   maxBytes=1e5,
+                                   backupCount=1)
+
+file_handler.setLevel(logging.DEBUG)
+file_formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s '
+                                   '- %(message)s')
+file_handler.setFormatter(file_formatter)
+logger.addHandler(file_handler)
+
+@click.command()
+@click.argument('config_file')
+@click.option('-i', '--expt_id', type=int, help="Experiment ID")
+@click.option('-t', '--time', type=int)
+@click.option('--save_dir', default='./data/')
+@click.option('--db', default='sqlite:///out.db')
+def download(config_file, expt_id, time, save_dir, db):
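+    """Cache an OMERO experiment locally in batches of `time` timepoints,
+    tracking per-position progress in the database at `db`."""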
+    with open(config_file, 'r') as fd:
+        config = json.load(fd)
+    if not expt_id:
+        expt_id = config['experiment']
+    expt = Experiment.from_source(expt_id, config['user'],
+                                  config['password'], config['host'],
+                                  config['port'], save_dir=save_dir)
+    timepoints = time if time else 0
+    print(expt.name)
+    print(expt.shape)
+    
+    # Create SQL database
+    engine = sa.create_engine(db)
+    Base.metadata.create_all(engine)
+    Session = sessionmaker(engine)
+    session = Session()
+    try:
+        expt.run(timepoints, session)
+    finally:
+        expt.connection.close()
+
+
+@click.group()
+def cli():
+    pass
+
+cli.add_command(download)
+
+if __name__ == "__main__":
+    try:
+        cli()
+    except Exception:
+        print("Pipeline failed with an unhandled error")
+        sys.exit(1)
diff --git a/scripts/get_expt_metadata.py b/scripts/get_expt_metadata.py
new file mode 100644
index 00000000..65b46869
--- /dev/null
+++ b/scripts/get_expt_metadata.py
@@ -0,0 +1,58 @@
+import argparse 
+import os
+import sys
+
+sys.path.insert(0, os.path.dirname(os.path.dirname(__file__)))
+import json
+import logging
+from logging.handlers import RotatingFileHandler
+
+from core.experiment import Experiment
+logger = logging.getLogger('core')
+logger.handlers = []
+logger.setLevel(logging.DEBUG)
+
+console = logging.StreamHandler()
+console.setLevel(logging.WARNING)
+formatter = logging.Formatter('%(name)-12s: %(levelname)-8s %(message)s')
+console.setFormatter(formatter)
+logger.addHandler(console)
+
+file_handler = RotatingFileHandler(filename='test.log',
+                                   maxBytes=1e5,
+                                   backupCount=1)
+
+file_handler.setLevel(logging.DEBUG)
+file_formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s '
+                                   '- %(message)s')
+file_handler.setFormatter(file_formatter)
+logger.addHandler(file_handler)
+
+logger.debug('Set up the loggers as test.')
+
+parser = argparse.ArgumentParser(description='Print experiment metadata from OMERO.')
+parser.add_argument('--config', dest='config_file', type=str)
+parser.add_argument('--id', type=int)
+
+args = parser.parse_args()
+
+with open(args.config_file, 'r') as fd:
+    config = json.load(fd)
+
+if not args.id:
+    expt_id = config['experiment']
+else:
+    expt_id = args.id
+
+expt = None
+try:
+    expt = Experiment.from_source(expt_id, config['user'],
+                                  config['password'], config['host'],
+                                  config['port'], save_dir='data/')
+    print(expt.name)
+    print(expt.metadata.channels)
+    print(expt.metadata.times)
+    print(expt.metadata.switch_params)
+    print(expt.metadata.zsections)
+    print(expt.metadata.positions)
+finally:
+    if expt is not None:
+        expt.connection.close()
diff --git a/scripts/load_test_experiments.py b/scripts/load_test_experiments.py
new file mode 100644
index 00000000..2c8841c6
--- /dev/null
+++ b/scripts/load_test_experiments.py
@@ -0,0 +1,60 @@
+import sys
+import json
+import numpy as np
+from pathlib import Path
+
+import omero
+
+from core.connect import Database
+
+
+def load_experiment(expt_id, save_dir):
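+    """Download a dataset's annotations and images from OMERO, saving each
+    channel of every image as a separate .npy file under save_dir."""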
+    save_dir = Path(save_dir)
+    with open('config.json', 'r') as fd:
+        config = json.load(fd)
+    db = Database(config['user'],
+                  config['password'],
+                  config['host'],
+                  config['port'])
+    db.connect()
+    ds = db.getDataset(expt_id)
+    print('Experiment: {}'.format(ds.name))
+
+    save_dir = save_dir / ds.name
+    if not save_dir.exists():
+        save_dir.mkdir(parents=True)
+
+    # Load the annotation files
+    tag_fd = open(str(save_dir / 'tags.txt'), 'w')
+
+    for ann in ds.dataset.listAnnotations():
+        if isinstance(ann, omero.gateway.FileAnnotationWrapper):
+            # getFileInChunks yields bytes, so write in binary mode
+            with open(str(save_dir / ann.getFileName()), 'wb') as fd:
+                for chunk in ann.getFileInChunks():
+                    fd.write(chunk)
+        else:
+            tag_fd.write('{} : {}\n'.format(ann.getDescription(),
+                                            ann.getValue()))
+    tag_fd.close()
+
+    for img in list(ds.getImages()):
+        im_name = img.name
+        print("Getting image {}".format(im_name))
+        for ix, channel in enumerate(img.channels):
+            print('Getting channel {}'.format(channel))
+            channel_array = img.getHypercube(channels=[ix])
+            print('Saving to {}'.format(save_dir / (channel + str(im_name))))
+            np.save(save_dir / (channel + str(im_name)), channel_array)
+
+    db.disconnect()
+
+if __name__ == "__main__":
+    save_dir = Path('./')
+    if len(sys.argv) > 1:
+        save_dir = Path(sys.argv[1])
+        if not save_dir.exists():
+            save_dir = Path('./')
+
+    experiment_id = 10863
+    load_experiment(experiment_id, save_dir)
diff --git a/scripts/run_pipeline.py b/scripts/run_pipeline.py
new file mode 100644
index 00000000..c0eea467
--- /dev/null
+++ b/scripts/run_pipeline.py
@@ -0,0 +1,99 @@
+import argparse
+import itertools
+
+import logging
+# Log to file
+logger = logging.getLogger('run_pipeline')
+hdlr = logging.FileHandler('run_pipeline.log')
+formatter = logging.Formatter('%(asctime)s %(levelname)s %(message)s')
+hdlr.setFormatter(formatter)
+logger.addHandler(hdlr) 
+logger.setLevel(logging.INFO)
+# Also send to stdout
+logger.addHandler(logging.StreamHandler())
+
+import os
+
+import numpy as np
+
+from core.experiment import ExperimentLocal
+from core.segment import Tiler
+from core.pipeline import Pipeline, BabyClient
+
+
+def define_parser():
+    parser = argparse.ArgumentParser(description='Run microscopy pipeline')
+    parser.add_argument('root_dir', type=str, help='The experiment root directory')
+    parser.add_argument('--camera', default="prime95b")
+    parser.add_argument('--channel', default="brightfield")
+    parser.add_argument('--zoom', default="60x")
+    parser.add_argument('--n_stacks', default="5z")
+    parser.add_argument('--time', type=int, default=100)
+    return parser
+
+def setup(root_dir, config):
+    raw_expt = ExperimentLocal(root_dir, finished=False)
+    tiler = Tiler(raw_expt, finished=False)
+
+    sql_db = "sqlite:///{}.db".format(raw_expt.exptID)
+    store = "{}.hdf5".format(raw_expt.exptID)
+
+    baby_client = BabyClient(tiler, **config)
+
+    pipeline = Pipeline(pipeline_steps=[raw_expt, tiler, baby_client], 
+                        database=sql_db, 
+                        store=store)
+    return pipeline, raw_expt
+
+def run(pipeline, positions, timepoints): 
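+    # Feed (position, timepoint) pairs to the pipeline one batch at a time,
+    # then write the accumulated results out to HDF5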
+    for tp_range in timepoints:
+        logger.info("Running timepoints: {}".format(tp_range))
+        run_step = list(itertools.product(positions, tp_range))
+        pipeline.run_step(run_step)
+    pipeline.store_to_h5()
+
+def clean_up(exptID, error=False):
+    # Remove partial outputs if the run failed; keep them on success
+    if error:
+        os.remove('{}.db'.format(exptID))
+        os.remove('{}.hdf5'.format(exptID))
+
+if __name__ == '__main__':
+    parser = define_parser()
+    args = parser.parse_args()
+
+    config = {"camera" : args.camera,
+              "channel" : args.channel, 
+              "zoom" : args.zoom, 
+              "n_stacks" : args.n_stacks}
+    logger.info("Baby configuration: ", config)
+    
+    logger.info("Setting up pipeline.")
+    pipeline, raw_expt = setup(args.root_dir, config)
+    positions = raw_expt.positions
+   
+    # Todo: get the timepoints from the metadata 
+    #       or force processing even though the experiment is finished
+    tps = args.time
+    logger.info("Experiment: {}".format(raw_expt.exptID))
+    logger.info("Positions: {}, timepoints: {}".format(len(positions), tps))
+
+    # np.array_split allows uneven batches (tps need not be divisible by 12)
+    timepoints = [chunk.tolist() for chunk in np.array_split(np.arange(tps), 12)]
+    logger.info("Running pipeline")
+    try:
+        run(pipeline, positions, timepoints)
+    except Exception as e: 
+        logger.info("Cleaning up on error")
+        clean_up(raw_expt.exptID, error=True)
+        raise e
+
+    logger.info("Cleaning up.")
+    clean_up(raw_expt.exptID, error=False)
diff --git a/setup.py b/setup.py
index c916a743..61139789 100644
--- a/setup.py
+++ b/setup.py
@@ -15,12 +15,12 @@ setup(
         'numpy',
         'tqdm',
         'pandas',
-        'tables',
+        'sqlalchemy',
         'scikit-image==0.16.2',
         'opencv-python',
         'imageio==2.8.0',
-        'omero-py==5.6.2',
-        'zeroc-ice==3.6.5',
+        'omero-py>=5.6.2',
+        'zeroc-ice',
         'logfile_parser@git+https://git.ecdf.ed.ac.uk/jpietsch/logfile_parser@master'
     ]
 )
diff --git a/test/test_sqlalchemy.py b/test/test_sqlalchemy.py
new file mode 100644
index 00000000..94e30c05
--- /dev/null
+++ b/test/test_sqlalchemy.py
@@ -0,0 +1,64 @@
+import pandas as pd
+from sqlalchemy import create_engine
+from database.records import Base, Position, Trap, Cell
+from sqlalchemy.orm import sessionmaker
+
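+# Exercise the ORM records end to end: build a position/trap/cell hierarchy
+# in an in-memory SQLite database, query it back, and read the tables with
+# pandas.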
+engine = create_engine("sqlite:///:memory:", echo=False)
+
+# Create the necessary tables
+Base.metadata.create_all(engine)
+
+# Create a session
+Session = sessionmaker(bind=engine)
+session = Session()
+
+# Create a new position
+pos1 = Position(name="pos001", n_timepoints=0)
+print("Created position: ", pos1)
+
+session.add(pos1)
+
+# Create a new trap
+trap1 = Trap(position=pos1, number=1, x=1, y=10, size=96)
+print("Created trap: ",trap1)
+
+session.add(trap1)
+
+session.commit()
+
+# Query session for pos1
+queried_position = session.query(Position).filter_by(name='pos001').first()
+print("Queried for position: ", pos1)
+print("Success? ", queried_position is pos1)
+print("Traps added to this posistion: ",queried_position.traps)
+
+# Create a new cell
+cell1 = Cell(number=1, trap=trap1)
+cell2 = Cell(number=2, trap=trap1)
+session.add_all([cell1, cell2])
+print("Creating cells: ", cell1, cell2)
+
+print("Cells in trap1: ")
+print(trap1.cells)
+
+
+# Update the number of time points in the positions
+queried_position.n_timepoints += 1
+session.commit()
+
+# Check how it works with pandas
+positions = pd.read_sql_table('positions', engine)
+print(positions)
+traps = pd.read_sql_table('traps', engine)
+print(traps)
+cells = pd.read_sql_table('cells', engine)
+print(cells)
+
+# Try a joined query
+
+queried_cell = session.query(Cell).filter_by(number=2)\
+                .join(Trap).filter_by(number=1)\
+                .join(Position).filter_by(name="pos001")\
+                .first()
+
+print(queried_cell)
diff --git a/test/test_tiler.py b/test/test_tiler.py
new file mode 100644
index 00000000..7917757e
--- /dev/null
+++ b/test/test_tiler.py
@@ -0,0 +1,63 @@
+import argparse
+
+from core.experiment import ExperimentLocal
+from core.segment import Tiler
+
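+# Manual test for the Tiler: load a locally cached experiment and extract
+# both a single trap's timelapse and all traps at a single timepoint.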
+
+def define_parser():
+    parser = argparse.ArgumentParser()
+    parser.add_argument('root_dir')
+    parser.add_argument('--position', default=None)
+    parser.add_argument('--template', default=None)
+    parser.add_argument('--trap', type=int, default=0)
+    parser.add_argument('--channel', type=str, default='Brightfield')
+    parser.add_argument('-z', '--z_positions', type=int, default=5)
+    parser.add_argument('--time', type=int, default=0)
+    parser.add_argument('--tile_size', type=int, default=96)
+    return parser
+
+
+def initialise_objects(data_path, template=None):
+    expt = ExperimentLocal(data_path, finished=True)
+    tiler = Tiler(expt, finished=True, template=template)
+    return tiler
+
+
+def change_position(position, tiler):
+    tiler.current_position = position
+
+
+def get_n_traps_timepoints(tiler):
+    return tiler.n_traps, tiler.n_timepoints
+
+def trap_timelapse(tiler, trap_idx, channel, z):
+    channel_id = tiler.get_channel_index(channel)
+    timelapse = tiler.get_trap_timelapse(trap_idx, channels=[channel_id],
+                                         z=list(range(z)))
+    return timelapse
+
+def timepoint_traps(tiler, tp_idx, channel, z, tile_size):
+    channel_id = tiler.get_channel_index(channel)
+    traps = tiler.get_traps_timepoint(tp_idx, tile_size=tile_size,
+                                      channels=[channel_id], z=list(range(z)))
+    return traps
+
+
+if __name__ == '__main__':
+    parser = define_parser()
+    args = parser.parse_args()
+
+    tiler = initialise_objects(args.root_dir, template=args.template)
+
+    if args.position is not None:
+        tiler.current_position = args.position
+
+    n_traps, n_tps = get_n_traps_timepoints(tiler)
+
+    timelapse = trap_timelapse(tiler, args.trap, args.channel,
+                               args.z_positions)
+    traps = timepoint_traps(tiler, args.time, args.channel, args.z_positions,
+                            args.tile_size)
-- 
GitLab