Commit 0d9f8df5 authored by Alán Muñoz

improve distributed and tiler

parent 50d77d41
@@ -7,8 +7,9 @@ from core.experiment import MetaData
from pathos.multiprocessing import Pool
from multiprocessing import set_start_method
import numpy as np
from postprocessor.core.processor import PostProcessorParameters, PostProcessor
from extraction.core.functions.defaults import exparams_from_meta
from core.io.signal import Signal

# set_start_method("spawn")
@@ -31,10 +32,6 @@ from extraction.core.extractor import Extractor
from extraction.core.parameters import Parameters
from extraction.core.functions.defaults import get_params


def pipeline(image_id, tps=10, tf_version=2):
    name, image_id = image_id
@@ -42,19 +39,19 @@ def pipeline(image_id, tps=10, tf_version=2):
        # Initialise tensorflow
        session = initialise_tf(tf_version)
        with Image(image_id) as image:
            print(f"Getting data for {image.name}")
            tiler = Tiler(image.data, image.metadata, image.name)
            writer = TilerWriter(f"../data/test2/{image.name}.h5")
            runner = DummyRunner(tiler)
            bwriter = BabyWriter(f"../data/test2/{image.name}.h5")
            for i in tqdm(range(0, tps), desc=image.name):
                trap_info = tiler.run_tp(i)
                writer.write(trap_info, overwrite=[])
                seg = runner.run_tp(i)
                bwriter.write(seg, overwrite=["mother_assign"])
            return True
    except Exception as e:  # bug in the trap getting
        print(f"Caught exception in worker thread (x = {name}):")
        # This prints the type, value, and stack trace of the
        # current exception being handled.
        traceback.print_exc()
@@ -72,50 +69,95 @@ def create_pipeline(image_id, **config):
    general_config = config.get("general", None)
    assert general_config is not None
    session = None
    earlystop = config.get(
        "earlystop",
        {
            "min_tp": 50,
            "thresh_pos_clogged": 0.3,
            "thresh_trap_clogged": 7,
            "ntps_to_eval": 5,
        },
    )
    try:
        directory = general_config.get("directory", "")
        with Image(image_id) as image:
            filename = f"{directory}/{image.name}.h5"
            # Run metadata first
            process_from = 0
            if True:  # not Path(filename).exists():
                meta = MetaData(directory, filename)
                meta.run()
                tiler = Tiler(
                    image.data,
                    image.metadata,
                    tile_size=general_config.get("tile_size", 117),
                )
            else:
                tiler = Tiler.from_hdf5(image.data, filename)
                s = Signal(filename)
                process_from = s["/general/None/extraction/volume"].columns[-1]
                if process_from > 2:
                    process_from = process_from - 3
                tiler.n_processed = process_from
            writer = TilerWriter(filename)
            baby_config = config.get("baby", None)
            assert baby_config is not None  # TODO add defaults
            tf_version = baby_config.get("tf_version", 2)
            session = initialise_tf(tf_version)
            runner = DummyRunner(tiler)
            bwriter = BabyWriter(filename)
            params = Parameters(**exparams_from_meta(filename))
            ext = Extractor.from_tiler(params, store=filename, tiler=tiler)
            # RUN
            tps = general_config.get("tps", 0)
            frac_clogged_traps = 0
            for i in tqdm(
                range(process_from, tps), desc=image.name, initial=process_from
            ):
                if frac_clogged_traps < earlystop["thresh_pos_clogged"]:
                    t = perf_counter()
                    trap_info = tiler.run_tp(i)
                    logging.debug(f"Timing:Trap:{perf_counter() - t}s")
                    t = perf_counter()
                    writer.write(trap_info, overwrite=[])
                    logging.debug(f"Timing:Writing-trap:{perf_counter() - t}s")
                    t = perf_counter()
                    seg = runner.run_tp(i)
                    logging.debug(f"Timing:Segmentation:{perf_counter() - t}s")
                    t = perf_counter()
                    bwriter.write(seg, overwrite=["mother_assign"])
                    logging.debug(f"Timing:Writing-baby:{perf_counter() - t}s")
                    t = perf_counter()
                    ext.extract_pos(tps=[i])
                    logging.debug(f"Timing:Extraction:{perf_counter() - t}s")
                else:  # Stop if the fraction of clogged traps exceeds the threshold
                    logging.debug(
                        f"EarlyStop:{earlystop['thresh_pos_clogged']*100}% traps clogged at time point {i}"
                    )
                    print(
                        f"Breaking experiment at time {i} with {frac_clogged_traps} clogged traps"
                    )
                    break
                if i > earlystop["min_tp"]:  # Calculate the fraction of clogged traps
                    s = Signal(filename)
                    df = s["/extraction/general/None/area"]
                    frac_clogged_traps = (
                        df[df.columns[i - earlystop["ntps_to_eval"] : i]]
                        .dropna(how="all")
                        .notna()
                        .groupby("trap")
                        .apply(sum)
                        .apply(np.nanmean, axis=1)
                        > earlystop["thresh_trap_clogged"]
                    ).mean()
                    logging.debug(f"Quality:Clogged_traps:{frac_clogged_traps}")
                    print("Frac clogged traps: ", frac_clogged_traps)
            # Run post processing
            # post_proc_params = PostProcessorParameters.default()
            # post_process(filename, post_proc_params)
            return True
    except Exception as e:  # bug in the trap getting
        print(f"Caught exception in worker thread (x = {name}):")
@@ -128,74 +170,77 @@ def create_pipeline(image_id, **config):
        if session:
            session.close()


@timed("Post-processing")
def post_process(filepath, params):
    pp = PostProcessor(filepath, params)
    tmp = pp.run()
    return tmp


# instantiating the decorator
@timed("Pipeline")
def run_config(config):
    # Config holds the general information, use in main
    # Steps holds the description of tasks with their parameters
    # Steps: all holds general tasks
    # steps: strain_name holds task for a given strain
    expt_id = config["general"].get("id")
    distributed = config["general"].get("distributed", 0)
    strain_filter = config["general"].get("strain", "")
    root_dir = config["general"].get("directory", "output")
    root_dir = Path(root_dir)

    print("Searching OMERO")
    # Do all initialisation
    with Dataset(int(expt_id)) as conn:
        image_ids = conn.get_images()
        directory = root_dir / conn.unique_name
        if not directory.exists():
            directory.mkdir(parents=True)
        # Download logs to use for metadata
        conn.cache_logs(directory)
    # Modify to the configuration
    config["general"]["directory"] = directory
    # Filter
    image_ids = {k: v for k, v in image_ids.items() if k.startswith(strain_filter)}
    if distributed != 0:  # Gives the number of simultaneous processes
        with Pool(distributed) as p:
            results = p.map(lambda x: create_pipeline(x, **config), image_ids.items())
        return results
    else:  # Sequential
        results = []
        for k, v in image_ids.items():
            r = create_pipeline((k, v), **config)
            results.append(r)
        return results


def initialise_logging(log_file: str):
    logging.basicConfig(filename=log_file, level=logging.DEBUG)
    for v in logging.Logger.manager.loggerDict.values():
        try:
            if not v.name.startswith(("extraction", "core.io")):
                v.disabled = True
        except:
            pass


def parse_timing(log_file):
    timings = dict()
    # Open the log file
    with open(log_file, "r") as f:
        # Line by line read
        for line in f.read().splitlines():
            # Timing records look like "DEBUG:root:Timing:<stage>:<seconds>s"
            if not line.startswith("DEBUG:root"):
                continue
            words = line.split(":")
            # Only keep lines that include "Timing"
            if "Timing" in words:
                # Split the last two into key, value
                k, v = words[-2:]
                # Dict[key].append(value)
                if k not in timings:
                    timings[k] = []
@@ -205,43 +250,81 @@ def parse_timing(log_file):
def visualise_timing(timings: dict, save_file: str):
    plt.figure().clear()
    plot_data = {
        x: timings[x]
        for x in timings
        if x.startswith(("Trap", "Writing", "Segmentation", "Extraction"))
    }
    sorted_keys, fixed_data = zip(
        *sorted(plot_data.items(), key=operator.itemgetter(1))
    )
    # Set up the graph parameters
    sns.set(style="whitegrid")
    # Plot the graph
    # sns.stripplot(data=fixed_data, size=1)
    ax = sns.boxplot(data=fixed_data, whis=np.inf, width=0.05)
    ax.set(xlabel="Stage", ylabel="Time (s)", xticklabels=sorted_keys)
    ax.tick_params(axis="x", rotation=90)
    ax.figure.savefig(save_file, bbox_inches="tight", transparent=True)
    return
# if __name__ == "__main__":
if __name__ == "__main__":
    strain = "YST_1512"
    # exp = 18616
    # exp = 19232
    # exp = 19995
    # exp = 19993
    exp = 20191
    # exp = 19831
    with Dataset(exp) as conn:
        imgs = conn.get_images()
        exp_name = conn.unique_name

    with Image(list(imgs.values())[0]) as im:
        meta = im.metadata
    tps = int(meta["size_t"])
    # tps = meta["size_t"]

    config = dict(
        general=dict(
            id=exp,
            distributed=4,
            tps=tps,
            directory="../data/",
            strain=strain,
            tile_size=117,
        ),
        # general=dict(id=19303, distributed=0, tps=tps, strain=strain, directory="../data/"),
        tiler=dict(),
        baby=dict(tf_version=2),
        earlystop=dict(
            min_tp=50,
            thresh_pos_clogged=0.3,
            thresh_trap_clogged=7,
            ntps_to_eval=5,
        ),
    )

    # log_file = f"../data/{exp_name}/issues.log"
    log_file = "/shared_libs/pydask/pipeline-core/data/2020_02_20_protAgg_downUpShift_2_0_2_Ura8_Ura8HA_Ura8HR_01/issues.log"
    # initialise_logging(log_file)
    save_timings = f"../data/{exp_name}/timings_{strain}_{tps}.pdf"
    timings_file = f"../data/{exp_name}/timings_{strain}_{tps}.json"

    # Run
    run_config(config)

    # Get timing results
    # timing = parse_timing(log_file)
    # # Visualise timings and save
    # visualise_timing(timing, save_timings)
    # # Dump the rest to json
    # with open(timings_file, "w") as fd:
    #     json.dump(timing, fd)

    # filename = "/shared_libs/pydask/pipeline-core/data/2020_02_20_protAgg_downUpShift_2_0_2_Ura8_Ura8HA_Ura8HR_01/Ura8H360R030.h5"
    # import h5py
    # with h5py.File(filename, "r") as f:
    #     plt.imshow(f["cell_info/edgemasks/values"][0][-1])
#!/usr/bin/env python3
# Imports below are assumed for this scratch script; Dataset and Image come from
# the project's OMERO I/O module (the same objects used by the pipeline above).
import numpy as np
import matplotlib.pyplot as plt
from skimage import transform
from skimage.filters import threshold_otsu
from skimage.filters.rank import entropy
from skimage.measure import label, regionprops, regionprops_table
from skimage.morphology import closing, disk, square
from skimage.segmentation import clear_border

expts = [18616, 19232, 19995, 19993, 20191, 19831]

# fetch images
test_imgs = []
for e in expts:
    with Dataset(int(e)) as conn:
        image_ids = conn.get_images()
    for im_id in image_ids.values():
        with Image(im_id) as image:
            dimg = image.data
            print("computing")
            img = dimg[
                0, image.metadata["channels"].index("Brightfield"), 2, ...
            ].compute()
            test_imgs.append(img)

from numpy import save, load

# save
for i, nd in enumerate(test_imgs):
    save("raw_" + str(i) + ".png", nd)
# load
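# A minimal load counterpart (sketch): np.save appends ".npy" when the name has
# no such suffix, so the files written above end up named "raw_<i>.png.npy".
# from glob import glob
# test_imgs = [load(f) for f in sorted(glob("raw_*.npy"))]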


def stretch_image(image):
    image = ((image - image.min()) / (image.max() - image.min())) * 255
    minval = np.percentile(image, 2)
    maxval = np.percentile(image, 98)
    image = np.clip(image, minval, maxval)
    image = (image - minval) / (maxval - minval)
    return image
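

# Minimal usage sketch (illustrative only, with a hypothetical test image):
# stretch_image clips the rescaled intensities to the 2nd-98th percentile band
# and maps them to roughly [0, 1], so bright outliers saturate instead of
# compressing the contrast range.
# rng = np.random.default_rng(0)
# noisy = rng.normal(100, 20, size=(512, 512))
# flat = stretch_image(noisy)
# assert 0.0 <= flat.min() and flat.max() <= 1.0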


def segment_traps(image, tile_size, downscale=0.4):
    # Make image go between 0 and 255
    img = image  # Keep a memory of image in case need to re-run
    stretched = stretch_image(image)
    img = stretch_image(image)
    # TODO Optimise the hyperparameters
    disk_radius = int(min([0.01 * x for x in img.shape]))
    min_area = 0.2 * (tile_size ** 2)
    if downscale != 1:
        img = transform.rescale(image, downscale)
    entropy_image = entropy(img, disk(disk_radius))
    if downscale != 1:
        entropy_image = transform.rescale(entropy_image, 1 / downscale)
    # apply threshold
    thresh = threshold_otsu(entropy_image)
    bw = closing(entropy_image > thresh, square(3))
    # remove artifacts connected to image border
    cleared = clear_border(bw)
    # label image regions
    label_image = label(cleared)
    areas = [
        region.area
        for region in regionprops(label_image)
        if region.area > min_area and region.area < tile_size ** 2 * 0.8
    ]
    traps = (
        np.array(
            [
                region.centroid
                for region in regionprops(label_image)
                if region.area > min_area and region.area < tile_size ** 2 * 0.8
            ]
        )
        .round()
        .astype(int)
    )
    rprops = regionprops_table(
        label_image,
        properties=[
            "area",
            "eccentricity",
            "convex_area",
            "feret_diameter_max",
            "orientation",
            "solidity",
            "minor_axis_length",
        ],
    )
    trapmask = (rprops["area"] > min_area) & (rprops["area"] < tile_size ** 2 * 0.8)
    candidates = [
        stretched[
            x - tile_size // 2 : x + tile_size // 2,
            y - tile_size // 2 : y + tile_size // 2,
        ]
        for x, y in np.array(traps).round().astype(int)
    ]
    # valleys = [find_valley(c) for c in candidates]
    from copy import copy

    bak = copy(candidates)
    candidates = [bak[x] for x in np.argsort(rprops["minor_axis_length"][trapmask])]
    return candidates[:5]
    # fig, axes = plt.subplots(5, 8)
    # indices = np.concatenate((np.arange(20), -np.arange(1, 21)[::-1]))
    # for i in range(5):
    #     for j in range(8):
    #         if i * 8 + j < len(candidates):
    #             # axes[i, j].imshow(candidates[i * 8 + j])
    #             axes[i, j].imshow(candidates[indices[i * 8 + j]])
    # plt.show()
    # The block below is unreachable (it sits after the return above) and relies on
    # `chosen_trap_coords`, which is only assigned in commented-out lines, so it is
    # kept commented out.
    # chosen_trap_coords = np.round(traps[np.argsort(area)[len(area) // 2]]).astype(int)
    # chosen_trap_coords = np.round(traps[np.argsort(ma)[len(ma) // 2]]).astype(int)
    # x, y = chosen_trap_coords
    # template = image[
    #     x - tile_size // 2 : x + tile_size // 2, y - tile_size // 2 : y + tile_size // 2
    # ]
    # return template
    # new_coords = identify_trap_locations(image, template)
    # def get_tile(tile_size=117):
    #     tile = np.ones((tile_size, tile_size))
    #     tile[1:-1, 1:-1] = False
    #     return tile
    # tile = get_tile(tile_size)
    # # tmp
    # mask = np.zeros_like(image, dtype="bool")
    # # for x, y in np.array(traps).round().astype(int):
    # for x, y in new_coords:
    #     dist = int(tile_size / 2)
    #     size_okay = (
    #         np.array(mask[x - dist : x + dist + 1, y - dist : y + dist + 1].shape)
    #         == np.array(tile.shape)
    #     ).all()
    #     if size_okay:
    #         maxes = np.maximum.reduce(
    #             (mask[x - dist : x + dist + 1, y - dist : y + dist + 1], tile)
    #         )
    #         mask[x - dist : x + dist + 1, y - dist : y + dist + 1] = maxes
    # from skimage.color import label2rgb
    # traps_img = label2rgb(mask, image=stretched, bg_label=0, alpha=0.5)
    # The retry-at-full-resolution branch below is also unreachable after the early
    # return, and `traps_img` is only defined in the commented-out code above.
    # if len(traps) < 10 and downscale != 1:
    #     print("Trying again.")
    #     return segment_traps(image, tile_size, downscale=1)
    # # return traps
    # return traps_img

ncols = 10
rands = np.random.randint(0, 138, ncols)
top_cands = [segment_traps(test_imgs[r], tile_size=117) for r in rands]

fig, axes = plt.subplots(5, ncols)
for i in range(ncols):
    for j in range(5):
        axes[j, i].imshow(top_cands[i][j])
plt.show()
# res = [segment_traps(im, tile_size=117) for im in test_imgs[rands]]
from scipy.signal import find_peaks


def find_valley(template):
    template = ((template - template.min()) / (template.max() - template.min())) * 255
    summed = template.sum(axis=1)
    norm = summed / summed.max()
    peaks = find_peaks(norm[20:-20])  # (peak_indices, properties)
    max1, max2 = np.argsort(norm[peaks[0]])[:2]
    if max2 < max1:
        max1, max2 = max2, max1
    return norm[max1:max2].min()
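

# find_valley appears to estimate the gap between trap pillars: it rescales the
# tile to 0-255, sums intensities row by row, normalises that profile, detects
# peaks away from the tile borders, and returns the profile minimum between two
# of the detected peaks. A hypothetical check over the candidates found above:
# valley_depths = [find_valley(c) for c in top_cands[0]]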

# `res` is produced by the commented-out segment_traps call above
for i, im in enumerate(res):
    plt.imshow(im)
    plt.axis("off")
    plt.savefig("tiles" + str(i), dpi=400)