diff --git a/scripts/dis_alan.bak b/scripts/dis_alan.bak
index 2443b928aa298379b1e1d4af2f36a1e1a6dcff1f..3c5dfb16fe416e272f233f485ca1d877d1d5e8bf 100644
--- a/scripts/dis_alan.bak
+++ b/scripts/dis_alan.bak
@@ -8,9 +8,6 @@ from pathos.multiprocessing import Pool
 from multiprocessing import set_start_method
 import numpy as np
 
-from extraction.core.functions.defaults import exparams_from_meta
-from core.io.signal import Signal
-
 # set_start_method("spawn")
 
 from tqdm import tqdm
@@ -27,10 +24,12 @@ from core.baby_client import DummyRunner
 from core.segment import Tiler
 from core.io.writer import TilerWriter, BabyWriter
 from core.utils import timed
+from core.io.signal import Signal
 
 from extraction.core.extractor import Extractor
 from extraction.core.parameters import Parameters
 from extraction.core.functions.defaults import get_params
+from extraction.core.functions.defaults import exparams_from_meta
 from postprocessor.core.processor import PostProcessorParameters, PostProcessor
 
 
@@ -85,21 +84,13 @@ def create_pipeline(image_id, **config):
         filename = f"{directory}/{image.name}.h5"
         # Run metadata first
         process_from = 0
-        if True:  # not Path(filename).exists():
-            meta = MetaData(directory, filename)
-            meta.run()
-            tiler = Tiler(
-                image.data,
-                image.metadata,
-                tile_size=general_config.get("tile_size", 117),
-            )
-        else:
-            tiler = Tiler.from_hdf5(image.data, filename)
-            s = Signal(filename)
-            process_from = s["/general/None/extraction/volume"].columns[-1]
-            if process_from > 2:
-                process_from = process_from - 3
-            tiler.n_processed = process_from
+        meta = MetaData(directory, filename)
+        meta.run()
+        tiler = Tiler(
+            image.data,
+            image.metadata,
+            tile_size=general_config.get("tile_size", 117),
+        )
 
         writer = TilerWriter(filename)
         baby_config = config.get("baby", None)
@@ -269,63 +260,30 @@ def visualise_timing(timings: dict, save_file: str):
     ax.figure.savefig(save_file, bbox_inches="tight", transparent=True)
     return
 
-# if __name__ == "__main__":
-
-
-strain = "YST_1512"
-# exp = 18616
-# exp = 19232
-# exp = 19995
-# exp = 19993
-exp = 20191
-# exp = 19831
-with Dataset(exp) as conn:
-    imgs = conn.get_images()
-    exp_name = conn.unique_name
-
-with Image(list(imgs.values())[0]) as im:
-    meta = im.metadata
-tps = int(meta["size_t"])
-# tps = meta["size_t"]
-config = dict(
-    general=dict(
-        id=exp,
-        distributed=4,
-        tps=tps,
-        directory="../data/",
-        strain=strain,
-        tile_size=117,
-    ),
-    # general=dict(id=19303, distributed=0, tps=tps, strain=strain, directory="../data/"),
-    tiler=dict(),
-    baby=dict(tf_version=2),
-    earlystop=dict(
-        min_tp=50,
-        thresh_pos_clogged=0.3,
-        thresh_trap_clogged=7,
-        ntps_to_eval=5,
-    ),
-)
-
-# log_file = f"../data/{exp_name}/issues.log"
-log_file = "/shared_libs/pydask/pipeline-core/data/2020_02_20_protAgg_downUpShift_2_0_2_Ura8_Ura8HA_Ura8HR_01/issues.log"
-
-# initialise_logging(log_file)
-save_timings = f"../data/{exp_name}/timings_{strain}_{tps}.pdf"
-timings_file = f"../data/{exp_name}/timings_{strain}_{tps}.json"
-
-# Run
-run_config(config)
-# Get timing results
-# timing = parse_timing(log_file)
-# # Visualise timings and save
-# visualise_timing(timing, save_timings)
-# # Dump the rest to json
-# with open(timings_file, "w") as fd:
-#     json.dump(timing, fd)
-
-# filename = "/shared_libs/pydask/pipeline-core/data/2020_02_20_protAgg_downUpShift_2_0_2_Ura8_Ura8HA_Ura8HR_01/Ura8H360R030.h5"
-# import h5py
-
-# with h5py.File(filename, "r") as f:
-#     plt.imshow(f["cell_info/edgemasks/values"][0][-1])
+if __name__ == "__main__":
+    strain = 'Vph1'
+    tps = 390
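+    # Configuration for run_config: experiment 19303, five worker processes,
+    # positions whose name starts with `strain`, for the first `tps` time points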
+    config = dict(
+        general=dict(
+            id=19303,
+            distributed=5,
+            tps=tps,
+            strain=strain,
+            directory='../data/'
+        ),
+        tiler=dict(),
+        baby=dict(tf_version=2)
+    )
+    log_file = '../data/2tozero_Hxts_02/issues.log'
+    initialise_logging(log_file)
+    save_timings = f"../data/2tozero_Hxts_02/timings_{strain}_{tps}.pdf"
+    timings_file = f"../data/2tozero_Hxts_02/timings_{strain}_{tps}.json"
+    # Run
+    # run_config(config)
+    # Get timing results
+    timing = parse_timing(log_file)
+    # Visualise timings and save
+    visualise_timing(timing, save_timings)
+    # Dump the rest to json
+    with open(timings_file, 'w') as fd:
+        json.dump(timing, fd)
diff --git a/scripts/distributed_alan.py b/scripts/distributed_alan.py
new file mode 100644
index 0000000000000000000000000000000000000000..fb6aca7fca57f8070f2b87fb922eb15a856a8a68
--- /dev/null
+++ b/scripts/distributed_alan.py
@@ -0,0 +1,311 @@
+from pathlib import Path
+import json
+from time import perf_counter
+import logging
+
+from core.experiment import MetaData
+from pathos.multiprocessing import Pool
+from multiprocessing import set_start_method
+import numpy as np
+
+from extraction.core.functions.defaults import exparams_from_meta
+from core.io.signal import Signal
+
+# set_start_method("spawn")
+
+from tqdm import tqdm
+import traceback
+import matplotlib.pyplot as plt
+import seaborn as sns
+import operator
+
+from baby.brain import BabyBrain
+
+from core.io.omero import Dataset, Image
+from core.haystack import initialise_tf
+from core.baby_client import DummyRunner
+from core.segment import Tiler
+from core.io.writer import TilerWriter, BabyWriter
+from core.utils import timed
+
+from extraction.core.extractor import Extractor
+from extraction.core.parameters import Parameters
+from extraction.core.functions.defaults import get_params
+from postprocessor.core.processor import PostProcessorParameters, PostProcessor
+
+
+def pipeline(image_id, tps=10, tf_version=2):
+    name, image_id = image_id
+    session = None  # so the finally clause is safe if initialise_tf raises
+    try:
+        # Initialise tensorflow
+        session = initialise_tf(tf_version)
+        with Image(image_id) as image:
+            print(f"Getting data for {image.name}")
+            tiler = Tiler(image.data, image.metadata, image.name)
+            writer = TilerWriter(f"../data/test2/{image.name}.h5")
+            runner = DummyRunner(tiler)
+            bwriter = BabyWriter(f"../data/test2/{image.name}.h5")
+            for i in tqdm(range(0, tps), desc=image.name):
+                trap_info = tiler.run_tp(i)
+                writer.write(trap_info, overwrite=[])
+                seg = runner.run_tp(i)
+                bwriter.write(seg, overwrite=["mother_assign"])
+            return True
+    except Exception as e:  # bug in the trap getting
+        print(f"Caught exception in worker thread (x = {name}):")
+        # This prints the type, value, and stack trace of the
+        # current exception being handled.
+        traceback.print_exc()
+        print()
+        raise e
+    finally:
+        # Close session
+        if session:
+            session.close()
+
+
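+# pipeline() above is a minimal tile-segment-write loop; create_pipeline()
+# below additionally runs metadata collection, extraction and early stopping.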
+@timed("Position")
+def create_pipeline(image_id, **config):
+    name, image_id = image_id
+    general_config = config.get("general", None)
+    assert general_config is not None
+    session = None
+    earlystop = config.get(
+        "earlystop",
+        {
+            "min_tp": 50,
+            "thresh_pos_clogged": 0.3,
+            "thresh_trap_clogged": 7,
+            "ntps_to_eval": 5,
+        },
+    )
+    try:
+        directory = general_config.get("directory", "")
+        with Image(image_id) as image:
+            filename = f"{directory}/{image.name}.h5"
+            # Run metadata first
+            process_from = 0
+            if True:  # not Path(filename).exists():
+                meta = MetaData(directory, filename)
+                meta.run()
+                tiler = Tiler(
+                    image.data,
+                    image.metadata,
+                    tile_size=general_config.get("tile_size", 117),
+                )
+            else:
+                tiler = Tiler.from_hdf5(image.data, filename)
+                s = Signal(filename)
+                process_from = s["/general/None/extraction/volume"].columns[-1]
+                if process_from > 2:
+                    process_from = process_from - 3
+                tiler.n_processed = process_from
+
+            writer = TilerWriter(filename)
+            baby_config = config.get("baby", None)
+            assert baby_config is not None  # TODO add defaults
+            tf_version = baby_config.get("tf_version", 2)
+            session = initialise_tf(tf_version)
+            runner = DummyRunner(tiler)
+            bwriter = BabyWriter(filename)
+            params = Parameters(**exparams_from_meta(filename))
+            ext = Extractor.from_tiler(params, store=filename, tiler=tiler)
+            # RUN
+            tps = general_config.get("tps", 0)
+            frac_clogged_traps = 0
+            for i in tqdm(
+                range(process_from, tps), desc=image.name, initial=process_from
+            ):
+                if frac_clogged_traps < earlystop["thresh_pos_clogged"]:
+                    t = perf_counter()
+                    trap_info = tiler.run_tp(i)
+                    logging.debug(f"Timing:Trap:{perf_counter() - t}s")
+                    t = perf_counter()
+                    writer.write(trap_info, overwrite=[])
+                    logging.debug(f"Timing:Writing-trap:{perf_counter() - t}s")
+                    t = perf_counter()
+                    seg = runner.run_tp(i)
+                    logging.debug(f"Timing:Segmentation:{perf_counter() - t}s")
+                    t = perf_counter()
+                    bwriter.write(seg, overwrite=["mother_assign"])
+                    logging.debug(f"Timing:Writing-baby:{perf_counter() - t}s")
+                    t = perf_counter()
+                    ext.extract_pos(tps=[i])
+                    logging.debug(f"Timing:Extraction:{perf_counter() - t}s")
+                else:  # Stop if too many traps are clogged (> thresh_pos_clogged)
+                    logging.debug(
+                        f"EarlyStop:{earlystop['thresh_pos_clogged']*100}% traps clogged at time point {i}"
+                    )
+                    print(
+                        f"Breaking experiment at time {i} with {frac_clogged_traps} clogged traps"
+                    )
+                    break
+
+                if i > earlystop["min_tp"]:  # Calculate the fraction of clogged traps
+                    s = Signal(filename)
+                    df = s["/extraction/general/None/area"]
+                    frac_clogged_traps = (
+                        df[df.columns[i - earlystop["ntps_to_eval"] : i]]
+                        .dropna(how="all")
+                        .notna()
+                        .groupby("trap")
+                        .apply(sum)
+                        .apply(np.nanmean, axis=1)
+                        > earlystop["thresh_trap_clogged"]
+                    ).mean()
+                    logging.debug(f"Quality:Clogged_traps:{frac_clogged_traps}")
+                    print("Frac clogged traps: ", frac_clogged_traps)
+
+            # Run post processing
+            # post_proc_params = PostProcessorParameters.default()
+            # post_process(filename, post_proc_params)
+            return True
+    except Exception as e:  # bug in the trap getting
+        print(f"Caught exception in worker thread (x = {name}):")
+        # This prints the type, value, and stack trace of the
+        # current exception being handled.
+        traceback.print_exc()
+        print()
+        raise e
+    finally:
+        if session:
+            session.close()
+
+
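+# How the early-stop check in create_pipeline decides a position is clogged:
+# over the last `ntps_to_eval` time points it takes each trap's average number
+# of detected cells; traps averaging more than `thresh_trap_clogged` cells
+# count as clogged, and the position is abandoned once the clogged fraction
+# exceeds `thresh_pos_clogged`.
+
+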
+@timed("Post-processing")
+def post_process(filepath, params):
+    pp = PostProcessor(filepath, params)
+    tmp = pp.run()
+    return tmp
+
+
+# instantiating the decorator
+@timed("Pipeline")
+def run_config(config):
+    # Config holds the general information, use in main
+    # Steps holds the description of tasks with their parameters
+    # Steps: all holds general tasks
+    # steps: strain_name holds task for a given strain
+    expt_id = config["general"].get("id")
+    distributed = config["general"].get("distributed", 0)
+    strain_filter = config["general"].get("strain", "")
+    root_dir = config["general"].get("directory", "output")
+    root_dir = Path(root_dir)
+
+    print("Searching OMERO")
+    # Do all initialisation
+    with Dataset(int(expt_id)) as conn:
+        image_ids = conn.get_images()
+        directory = root_dir / conn.unique_name
+        if not directory.exists():
+            directory.mkdir(parents=True)
+        # Download logs to use for metadata
+        conn.cache_logs(directory)
+
+    # Modify to the configuration
+    config["general"]["directory"] = directory
+    # Filter
+    image_ids = {k: v for k, v in image_ids.items() if k.startswith(strain_filter)}
+
+    if distributed != 0:  # Gives the number of simultaneous processes
+        with Pool(distributed) as p:
+            results = p.map(lambda x: create_pipeline(x, **config), image_ids.items())
+            p.terminate()
+        return results
+    else:  # Sequential
+        results = []
+        for k, v in image_ids.items():
+            r = create_pipeline((k, v), **config)
+            results.append(r)
+        return results
+
+
+def initialise_logging(log_file: str):
+    logging.basicConfig(filename=log_file, level=logging.DEBUG)
+    for v in logging.Logger.manager.loggerDict.values():
+        try:
+            # str.startswith expects a tuple of prefixes, not a list
+            if not v.name.startswith(("extraction", "core.io")):
+                v.disabled = True
+        except AttributeError:  # PlaceHolder entries have no name attribute
+            pass
+
+
+def parse_timing(log_file):
+    timings = dict()
+    # Open the log file
+    with open(log_file, "r") as f:
+        # Line by line read
+        for line in f.read().splitlines():
+            if not line.startswith("DEBUG:root"):
+                continue
+            words = line.split(":")
+            # Only keep lines that include "Timing"
+            if "Timing" in words:
+                # Split the last two into key, value
+                k, v = words[-2:]
+                # Dict[key].append(value)
+                if k not in timings:
+                    timings[k] = []
+                timings[k].append(float(v[:-1]))
+    return timings
+
+
+def visualise_timing(timings: dict, save_file: str):
+    plt.figure().clear()
+    plot_data = {
+        x: timings[x]
+        for x in timings
+        if x.startswith(("Trap", "Writing", "Segmentation", "Extraction"))
+    }
+    sorted_keys, fixed_data = zip(
+        *sorted(plot_data.items(), key=operator.itemgetter(1))
+    )
+    # Set up the graph parameters
+    sns.set(style="whitegrid")
+    # Plot the graph
+    # sns.stripplot(data=fixed_data, size=1)
+    ax = sns.boxplot(data=fixed_data, whis=np.inf, width=0.05)
+    ax.set(xlabel="Stage", ylabel="Time (s)", xticklabels=sorted_keys)
+    ax.tick_params(axis="x", rotation=90)
+    ax.figure.savefig(save_file, bbox_inches="tight", transparent=True)
+    return
+
+
+strain = "YST_1512"
+# exp = 18616
+# exp = 19232
+# exp = 19995
+# exp = 19993
+exp = 20191
+# exp = 19831
+
+with Dataset(exp) as conn:
+    imgs = conn.get_images()
+    exp_name = conn.unique_name
+
+with Image(list(imgs.values())[0]) as im:
+    meta = im.metadata
+tps = int(meta["size_t"])
+
+config = dict(
+    general=dict(
+        id=exp,
+        distributed=4,
+        tps=tps,
+        directory="../data/",
+        strain=strain,
+        tile_size=117,
+    ),
+    # general=dict(id=19303, distributed=0, tps=tps, strain=strain, directory="../data/"),
+    tiler=dict(),
+    baby=dict(tf_version=2),
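+    # Early-stop settings; these mirror the defaults hard-coded in create_pipeline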
+    earlystop=dict(
+        min_tp=50,
+        thresh_pos_clogged=0.3,
+        thresh_trap_clogged=7,
+        ntps_to_eval=5,
+    ),
+)
+
+# Run
+run_config(config)
diff --git a/setup.py b/setup.py
index ed1cf0e057af2b2e05cae46b3f77125ff5112477..a3db7d8a1c128b8e8a247809c7924e2f070c7c61 100644
--- a/setup.py
+++ b/setup.py
@@ -28,5 +28,8 @@ setup(
         'tensorflow>=1.14,<=2.3',
         'baby@git+ssh://git@git.ecdf.ed.ac.uk/jpietsch/baby.git@training',
         'logfile_parser@git+ssh://git@git.ecdf.ed.ac.uk/swain-lab/python-pipeline/logfile_parser.git',
+        # "extraction@git+ssh://git@git.ecdf.ed.ac.uk/swain-lab/python-pipeline/extraction.git@dev",
+        # "postprocessor@git+ssh://git@git.ecdf.ed.ac.uk/swain-lab/python-pipeline/post-processing.git@dev",
+
     ],
 )