From 102a92bcc8d767df0f2d9852a978b5137e6f1ccb Mon Sep 17 00:00:00 2001
From: Diane Adjavon <diane.adjavon@ed.ac.uk>
Date: Wed, 14 Oct 2020 08:53:05 +0200
Subject: [PATCH] make download from OMERO pipeline-based

---
 core/experiment.py               | 58 +++++++++++++------
 core/segment.py                  |  2 +-
 scripts/create_experiment.py     | 86 +++++++++++++++++++++++++++
 scripts/get_expt_metadata.py     | 58 +++++++++++++++++++
 scripts/load_test_experiments.py | 60 +++++++++++++++++++
 scripts/run_pipeline.py          | 99 ++++++++++++++++++++++++++++++++
 setup.py                         |  6 +-
 test/test_sqlalchemy.py          | 64 +++++++++++++++++++++
 test/test_tiler.py               | 63 ++++++++++++++++++++
 9 files changed, 473 insertions(+), 23 deletions(-)
 create mode 100644 scripts/create_experiment.py
 create mode 100644 scripts/get_expt_metadata.py
 create mode 100644 scripts/load_test_experiments.py
 create mode 100644 scripts/run_pipeline.py
 create mode 100644 test/test_sqlalchemy.py
 create mode 100644 test/test_tiler.py

diff --git a/core/experiment.py b/core/experiment.py
index d4fa07b4..35ccbacf 100644
--- a/core/experiment.py
+++ b/core/experiment.py
@@ -20,6 +20,7 @@ from logfile_parser import Parser
 
 from core.timelapse import TimelapseOMERO, TimelapseLocal
 from core.utils import accumulate
+from database.records import Position
 
 logger = logging.getLogger(__name__)
 
@@ -131,7 +132,9 @@ class ExperimentOMERO(Experiment):
         # Set up the current position as the first in the list
         self._current_position = self.get_position(self.positions[0])
-        self.save_dir = Path(kwargs.get('save_dir', './'))
+        self.save_dir = Path(kwargs.get('save_dir', './')) / self.name
+        if not self.save_dir.exists():
+            self.save_dir.mkdir(parents=True)
         self.running_tp = 0
 
     @property
@@ -141,7 +144,7 @@ class ExperimentOMERO(Experiment):
     def get_position(self, position):
         """Get a Timelapse object for a given position by name"""
         # assert position in self.positions, "Position not available."
-        img = self.connection.getObject("Image", self.positions[position])
+        img = self.connection.getObject("Image", self._positions[position])
         return TimelapseOMERO(img)
 
     def cache_locally(self, root_dir='./', positions=None, channels=None,
@@ -186,37 +189,43 @@ class ExperimentOMERO(Experiment):
         logger.info('Downloaded experiment {}'.format(self.exptID))
 
     # Todo: turn this static
-    def cache_annotations(self, save_dir):
+    def cache_annotations(self, save_dir, **kwargs):
         # Save the file annotations
+        save_mat = kwargs.get('save_mat', False)
         tags = dict()# and the tag annotations
         for annotation in self.dataset.listAnnotations():
             if isinstance(annotation, omero.gateway.FileAnnotationWrapper):
                 filepath = save_dir / annotation.getFileName()
-                if filepath.stem.endswith('.mat'):
-                    mode = 'wb'
-                else:
-                    mode = 'w'
-                with open(str(filepath), mode) as fd:
-                    for chunk in annotation:
-                        fd.write(chunk)
+                # Skip .mat files unless requested, and never re-download
+                if (save_mat or not str(filepath).endswith('mat')) and not filepath.exists():
+                    with open(str(filepath), 'wb') as fd:
+                        for chunk in annotation.getFileInChunks():
+                            fd.write(chunk)
             if isinstance(annotation, omero.gateway.TagAnnotationWrapper):
                 # TODO save TagAnnotations in tags dictionary
-                pass
+                key = annotation.getDescription()
+                if key == '':
+                    key = 'misc. tags'
+                if key in tags:
+                    if not isinstance(tags[key], list):
+                        tags[key] = [tags[key]]
+                    tags[key].append(annotation.getValue())
+                else:
+                    tags[key] = annotation.getValue()
         with open(str(save_dir / 'omero_tags.json'), 'w') as fd:
             json.dump(tags, fd)
         return
 
     # Todo: turn this static
     def cache_set(self, save_dir, position: TimelapseOMERO,
-                  timepoints: Iterable[int]):
+                  timepoints: Iterable[int], db_pos, **kwargs):
         # Todo: save one time point to file
         # save it under self.save_dir / self.exptID / self.position
         # save each channel, z_position separately
         pos_dir = save_dir / position.name
         if not pos_dir.exists():
             pos_dir.mkdir()
-        for channel in tqdm(position.channels):
-            for tp in tqdm(timepoints):
+        for tp in tqdm(timepoints):
+            for channel in tqdm(position.channels):
                 for z_pos in tqdm(range(position.size_z)):
                     ch_id = position.get_channel_index(channel)
                     image = position.get_hypercube(x=None, y=None,
@@ -230,9 +239,10 @@ class ExperimentOMERO(Experiment):
                                                    z_pos + 1)
                     cv2.imwrite(str(pos_dir / im_name), np.squeeze(
                         image))
+        db_pos.n_timepoints = tp
         return list(itertools.product([position.name], timepoints))
 
-    def run(self, keys: Union[list, int], **kwargs):
+    def run(self, keys: Union[list, int], session, **kwargs):
         if self.running_tp == 0:
             self.cache_annotations(self.save_dir)
         if isinstance(keys, list):
@@ -241,17 +251,27 @@ class ExperimentOMERO(Experiment):
         # Locally save `keys` images at a time for each position
         cached = []
         for pos_name in self.positions:
+            db_pos = session.query(Position).filter_by(name=pos_name).first()
+            if db_pos is None:
+                db_pos = Position(name=pos_name, n_timepoints=0)
+                session.add(db_pos)
             position = self.get_position(pos_name)
-            timepoints = list(range(self.running_tp,
-                                    min(self.running_tp + keys,
+            timepoints = list(range(db_pos.n_timepoints,
+                                    min(db_pos.n_timepoints + keys,
                                         position.size_t)))
-            cached += self.cache_set(self.save_dir, position, timepoints)
+            if len(timepoints) > 0 and db_pos.n_timepoints < max(timepoints):
+                try:
+                    cached += self.cache_set(self.save_dir, position,
+                                             timepoints, db_pos, **kwargs)
+                finally:
+                    # Add position to storage
+                    session.commit()
         self.running_tp += keys  # increase by number of processed time points
         return cached
 
 
 class ExperimentLocal(Experiment):
-    def __init__(self, root_dir, finished=False):
+    def __init__(self, root_dir, finished=True):
         super(ExperimentLocal, self).__init__()
         self.root_dir = Path(root_dir)
         self.exptID = self.root_dir.name
diff --git a/core/segment.py b/core/segment.py
index 01dbfab8..263285ac 100644
--- a/core/segment.py
+++ b/core/segment.py
@@ -29,7 +29,7 @@ def get_tile_shapes(x, tile_size, max_shape):
 
 
 class Tiler:
-    def __init__(self, raw_expt, finished=False, template=None):
+    def __init__(self, raw_expt, finished=True, template=None):
         self.expt = raw_expt
         self.finished = finished
         if template is None:
diff --git a/scripts/create_experiment.py b/scripts/create_experiment.py
new file mode 100644
index 00000000..782139f3
--- /dev/null
+++ b/scripts/create_experiment.py
@@ -0,0 +1,86 @@
+import os
+import sys
+
+sys.path.insert(0, os.path.dirname(os.path.dirname(__file__)))
+import json
+import logging
+from logging.handlers import RotatingFileHandler
+
+import click
+import sqlalchemy as sa
+from sqlalchemy.orm import sessionmaker
+
+from database.records import Base
+from core.experiment import Experiment
+
+logger = logging.getLogger('core')
+logger.handlers = []
+logger.setLevel(logging.DEBUG)
+
+console = logging.StreamHandler()
+console.setLevel(logging.WARNING)
+formatter = logging.Formatter('%(name)-12s: %(levelname)-8s %(message)s')
+console.setFormatter(formatter)
+logger.addHandler(console)
+
+file_handler = RotatingFileHandler(filename='test.log',
+                                   maxBytes=1e5,
+                                   backupCount=1)
+file_handler.setLevel(logging.DEBUG)
+file_formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s '
+                                   '- %(message)s')
+file_handler.setFormatter(file_formatter)
+logger.addHandler(file_handler)
+
+
+@click.command()
+@click.argument('config_file')
+@click.option('-i', '--expt_id', type=int, help="Experiment ID")
+@click.option('-t', '--time', type=int)
+@click.option('--save_dir', default='./data/')
+@click.option('--db', default='sqlite:///out.db')
+def download(config_file, expt_id, time, save_dir, db):
+    with open(config_file, 'r') as fd:
+        config = json.load(fd)
+    if not expt_id:
+        expt_id = config['experiment']
+    expt = Experiment.from_source(expt_id, config['user'],
+                                  config['password'], config['host'],
+                                  config['port'], save_dir=save_dir)
+    if time:
+        timepoints = time
+    else:
+        timepoints = 0
+    print(expt.name)
+    print(expt.shape)
+
+    # Create SQL database
+    engine = sa.create_engine(db)
+    Base.metadata.create_all(engine)
+    Session = sessionmaker(engine)
+    session = Session()
+    try:
+        expt.run(timepoints, session)
+    finally:
+        expt.connection.close()
+
+
+@click.group()
+def cli():
+    pass
+
+
+cli.add_command(download)
+
+
+if __name__ == "__main__":
+    try:
+        cli()
+    except Exception:
+        print("Unhandled error while running the CLI")
+        sys.exit(1)
+    # Reached only if cli() did not raise
+    sys.exit(0)
diff --git a/scripts/get_expt_metadata.py b/scripts/get_expt_metadata.py
new file mode 100644
index 00000000..65b46869
--- /dev/null
+++ b/scripts/get_expt_metadata.py
@@ -0,0 +1,58 @@
+import argparse
+import os
+import sys
+
+sys.path.insert(0, os.path.dirname(os.path.dirname(__file__)))
+import json
+import logging
+from logging.handlers import RotatingFileHandler
+
+from core.experiment import Experiment
+
+logger = logging.getLogger('core')
+logger.handlers = []
+logger.setLevel(logging.DEBUG)
+
+console = logging.StreamHandler()
+console.setLevel(logging.WARNING)
+formatter = logging.Formatter('%(name)-12s: %(levelname)-8s %(message)s')
+console.setFormatter(formatter)
+logger.addHandler(console)
+
+file_handler = RotatingFileHandler(filename='test.log',
+                                   maxBytes=1e5,
+                                   backupCount=1)
+file_handler.setLevel(logging.DEBUG)
+file_formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s '
+                                   '- %(message)s')
+file_handler.setFormatter(file_formatter)
+logger.addHandler(file_handler)
+
+logger.debug('Set up the loggers as test.')
+
+parser = argparse.ArgumentParser(description='Load experiment metadata from OMERO.')
+parser.add_argument('--config', dest='config_file', type=str)
+parser.add_argument('--id', type=int)
+
+args = parser.parse_args()
+
+with open(args.config_file, 'r') as fd:
+    config = json.load(fd)
+
+if not args.id:
+    expt_id = config['experiment']
+else:
+    expt_id = args.id
+
+# Connect first, so that `expt` exists when the finally clause runs
+expt = Experiment.from_source(expt_id, config['user'],
+                              config['password'], config['host'],
+                              config['port'], save_dir='data/')
+try:
+    print(expt.name)
+    print(expt.metadata.channels)
+    print(expt.metadata.times)
+    print(expt.metadata.switch_params)
+    print(expt.metadata.zsections)
+    print(expt.metadata.positions)
+finally:
+    expt.connection.seppuku()
diff --git a/scripts/load_test_experiments.py b/scripts/load_test_experiments.py
new file mode 100644
index 00000000..2c8841c6
--- /dev/null
+++ b/scripts/load_test_experiments.py
@@ -0,0 +1,60 @@
+import sys
+import json
+from pathlib import Path
+
+import numpy as np
+import omero
+
+from core.connect import Database
+
+
+def load_experiment(expt_id, save_dir):
+    save_dir = Path(save_dir)
+    with open('config.json', 'r') as fd:
+        config = json.load(fd)
+    db = Database(config['user'],
+                  config['password'],
+                  config['host'],
+                  config['port'])
+    db.connect()
+    ds = db.getDataset(expt_id)
+    print('Experiment: {}'.format(ds.name))
+
+    save_dir = save_dir / ds.name
+    if not save_dir.exists():
+        save_dir.mkdir(parents=True)
+
+    # Save the annotation files
+    tag_fd = open(str(save_dir / 'tags.txt'), 'w')
+
+    for ann in ds.dataset.listAnnotations():
+        if isinstance(ann, omero.gateway.FileAnnotationWrapper):
+            # File chunks are bytes, so write in binary mode
+            with open(str(save_dir / ann.getFileName()), 'wb') as fd:
+                for chunk in ann.getFileInChunks():
+                    fd.write(chunk)
+        else:
+            tag_fd.write('{} : {}\n'.format(ann.getDescription(),
+                                            ann.getValue()))
+    tag_fd.close()
+
+    for img in list(ds.getImages()):
+        im_name = img.name
+        print("Getting image {}".format(im_name))
+        for ix, channel in enumerate(img.channels):
+            print('Getting channel {}'.format(channel))
+            channel_array = img.getHypercube(channels=[ix])
+            print('Saving to {}'.format(save_dir / (channel + str(im_name))))
+            np.save(save_dir / (channel + str(im_name)), channel_array)
+
+    db.disconnect()
+
+
+if __name__ == "__main__":
+    save_dir = Path(sys.argv[1]) if len(sys.argv) > 1 else Path('./')
+    if not save_dir.exists():
+        save_dir = Path('./')
+
+    experiment_id = 10863
+
+    load_experiment(experiment_id, save_dir)
diff --git a/scripts/run_pipeline.py b/scripts/run_pipeline.py
new file mode 100644
index 00000000..c0eea467
--- /dev/null
+++ b/scripts/run_pipeline.py
@@ -0,0 +1,99 @@
+import argparse
+import itertools
+import logging
+import os
+
+import numpy as np
+import sqlalchemy as sa
+
+from core.pipeline import Pipeline, ExperimentLocal, Tiler, BabyClient
+from database.records import Base
+
+# Log to file
+logger = logging.getLogger('run_pipeline')
+hdlr = logging.FileHandler('run_pipeline.log')
+formatter = logging.Formatter('%(asctime)s %(levelname)s %(message)s')
+hdlr.setFormatter(formatter)
+logger.addHandler(hdlr)
+logger.setLevel(logging.INFO)
+# Also send to stdout
+logger.addHandler(logging.StreamHandler())
+
+
+def define_parser():
+    parser = argparse.ArgumentParser(description='Run microscopy pipeline')
+    parser.add_argument('root_dir', type=str,
+                        help='The experiment root directory')
+    parser.add_argument('--camera', default="prime95b")
+    parser.add_argument('--channel', default="brightfield")
+    parser.add_argument('--zoom', default="60x")
+    parser.add_argument('--n_stacks', default="5z")
+    parser.add_argument('--time', type=int, default=100)
+    return parser
+
+
+def setup(root_dir, config):
+    raw_expt = ExperimentLocal(root_dir, finished=False)
+    tiler = Tiler(raw_expt, finished=False)
+
+    sql_db = "sqlite:///{}.db".format(raw_expt.exptID)
+    store = "{}.hdf5".format(raw_expt.exptID)
+
+    baby_client = BabyClient(tiler, **config)
+
+    pipeline = Pipeline(pipeline_steps=[raw_expt, tiler, baby_client],
+                        database=sql_db,
+                        store=store)
+    return pipeline, raw_expt
+
+
+def run(pipeline, positions, timepoints):
+    for tp_range in timepoints:
+        logger.info("Running timepoints: {}".format(tp_range))
+        run_step = list(itertools.product(positions, tp_range))
+        pipeline.run_step(run_step)
+        pipeline.store_to_h5()
+
+
+def clean_up(exptID, error=False):
+    # On error, remove the partial outputs; otherwise leave them in place
+    if error:
+        os.remove('{}.db'.format(exptID))
+        os.remove('{}.hdf5'.format(exptID))
+
+
+if __name__ == '__main__':
+    parser = define_parser()
+    args = parser.parse_args()
+
+    config = {"camera": args.camera,
+              "channel": args.channel,
+              "zoom": args.zoom,
+              "n_stacks": args.n_stacks}
+    logger.info("Baby configuration: %s", config)
+
+    logger.info("Setting up pipeline.")
+    pipeline, raw_expt = setup(args.root_dir, config)
+    positions = raw_expt.positions
+
+    # Todo: get the timepoints from the metadata
+    # or force processing even though the experiment is finished
+    tps = args.time
+    logger.info("Experiment: {}".format(raw_expt.exptID))
+    logger.info("Positions: {}, timepoints: {}".format(len(positions), tps))
+
+    # Split the timepoints into 12 batches; assumes `tps` is divisible by 12
+    timepoints = np.arange(tps).reshape(12, -1).tolist()
+    logger.info("Running pipeline")
+    try:
+        run(pipeline, positions, timepoints)
+    except Exception:
+        logger.info("Cleaning up on error")
+        clean_up(raw_expt.exptID, error=True)
+        raise
+
+    logger.info("Cleaning up.")
+    clean_up(raw_expt.exptID, error=False)
diff --git a/setup.py b/setup.py
index c916a743..61139789 100644
--- a/setup.py
+++ b/setup.py
@@ -15,12 +15,12 @@ setup(
         'numpy',
         'tqdm',
         'pandas',
-        'tables',
+        'sqlalchemy',
         'scikit-image==0.16.2',
         'opencv-python',
         'imageio==2.8.0',
-        'omero-py==5.6.2',
-        'zeroc-ice==3.6.5',
+        'omero-py>=5.6.2',
+        'zeroc-ice',
         'logfile_parser@git+https://git.ecdf.ed.ac.uk/jpietsch/logfile_parser@master'
     ]
 )
diff --git a/test/test_sqlalchemy.py b/test/test_sqlalchemy.py
new file mode 100644
index 00000000..94e30c05
--- /dev/null
+++ b/test/test_sqlalchemy.py
@@ -0,0 +1,64 @@
+import pandas as pd
+from sqlalchemy import create_engine
+from sqlalchemy.orm import sessionmaker
+
+from database.records import Base, Position, Trap, Cell
+
+engine = create_engine("sqlite:///:memory:", echo=False)
+
+# Create the necessary tables
+Base.metadata.create_all(engine)
+
+# Create a session
+Session = sessionmaker(bind=engine)
+session = Session()
+
+# Create a new position
+pos1 = Position(name="pos001", n_timepoints=0)
+print("Created position: ", pos1)
+
+session.add(pos1)
+
+# Create a new trap
+trap1 = Trap(position=pos1, number=1, x=1, y=10, size=96)
+print("Created trap: ", trap1)
+
+session.add(trap1)
+
+session.commit()
+
+# Query the session for pos1
+queried_position = session.query(Position).filter_by(name='pos001').first()
+print("Queried for position: ", queried_position)
+print("Success? ", queried_position is pos1)
", queried_position is pos1) +print("Traps added to this posistion: ",queried_position.traps) + +# Create a new cell +cell1 = Cell(number=1, trap=trap1) +cell2 = Cell(number=2, trap=trap1) +session.add_all([cell1, cell2]) +print("Creating cells: ", cell1, cell2) + +print("Cells in trap1: ") +print(trap1.cells) + + +# Update the number of time points in the positions +queried_position.n_timepoints += 1 +session.commit() + +# Check how it works with Panda +positions = pd.read_sql_table('positions', engine) +print(positions) +traps = pd.read_sql_table('traps', engine) +print(traps) +cells = pd.read_sql_table('cells', engine) +print(cells) + +# Try a joined query + +queried_cell = session.query(Cell).filter_by(number=2)\ + .join(Trap).filter_by(number=1)\ + .join(Position).filter_by(name="pos001")\ + .first() + +print(queried_cell) diff --git a/test/test_tiler.py b/test/test_tiler.py new file mode 100644 index 00000000..7917757e --- /dev/null +++ b/test/test_tiler.py @@ -0,0 +1,63 @@ +import argparse + +from core.experiment import ExperimentLocal +from core.segment import Tiler + + +def define_parser(): + parser = argparse.ArgumentParser() + parser.add_argument('root_dir') + parser.add_argument('--position', default=None) + parser.add_argument('--template', default=None) + parser.add_argument('--trap', type=int, default=0) + parser.add_argument('--channel', type=str, default='Brightfield') + parser.add_argument('-z', '--z_positions', type=int, default=5) + parser.add_argument('--time', type=int, default=0) + parser.add_argument('--tile_size', type=int, default=96) + return parser + + +def initialise_objects(data_path, template=None): + expt = ExperimentLocal(data_path, finished=True) + tiler = Tiler(expt, finished=True, template=template) + return tiler + + +def change_position(position, tiler): + tiler.current_position = position + + +def get_n_traps_timepoints(tiler): + return tiler.n_traps, tiler.n_timepoints + +def trap_timelapse(tiler, trap_idx, channel, z): + channel_id = tiler.get_channel_index(channel) + timelapse = tiler.get_trap_timelapse(trap_idx, channels=[channel_id], + z=list(range(z))) + return timelapse + +def timepoint_traps(tiler, tp_idx, channel, z, tile_size): + channel_id = tiler.get_channel_index(channel) + traps = tiler.get_traps_timepoint(tp_idx, tile_size=tile_size, + channels=[channel_id], z=list(range(z))) + return traps + + +if __name__ == '__main__': + parser = define_parser() + args = parser.parse_args() + + tiler = initialise_objects(args.root_dir, template=args.template) + + if args.position is not None: + tiler.current_position = args.position + + n_traps, n_tps = get_n_traps_timepoints(tiler) + + timelapse = trap_timelapse(tiler, args.trap, args.channel, + args.z_positions) + traps = timepoint_traps(tiler, args.time, args.channel, args.z_positions, + args.tile_size) + + + -- GitLab