Commit eefd743f authored by pswain

docs for pipeline

parent f0b031db
""" """Set up and run pipelines: tiling, segmentation, extraction, and then post-processing."""
Pipeline and chaining elements.
"""
import logging import logging
import os import os
import re import re
@@ -36,12 +34,14 @@ from postprocessor.core.processor import PostProcessor, PostProcessorParameters
class PipelineParameters(ParametersABC):
    """Define parameters for the steps of the pipeline."""

    _pool_index = None

    def __init__(
        self, general, tiler, baby, extraction, postprocessing, reporting
    ):
        """Initialise, but called by a class method - not directly."""
        self.general = general
        self.tiler = tiler
        self.baby = baby
@@ -143,7 +143,8 @@ class PipelineParameters(ParametersABC):
        defaults["tiler"] = TilerParameters.default(**tiler).to_dict()
        defaults["baby"] = BabyParameters.default(**baby).to_dict()
        defaults["extraction"] = (
            exparams_from_meta(meta_d)
            or BabyParameters.default(**extraction).to_dict()
        )
        defaults["postprocessing"] = PostProcessorParameters.default(
            **postprocessing
@@ -159,10 +160,11 @@ class PipelineParameters(ParametersABC):
class Pipeline(ProcessABC):
    """
    Initialise and run tiling, segmentation, extraction and post-processing.

    Each step feeds the next one.

    To customise parameters for any step use the PipelineParameters class.
    """
    pipeline_steps = ["tiler", "baby", "extraction"]
@@ -173,7 +175,7 @@ class Pipeline(ProcessABC):
        "postprocessing",
    ]
    # Indicate step-writer groupings to perform special operations during step iteration
    # Alan: replace with - specify the group in the h5 files written by each step (?)
    writer_groups = {
        "tiler": ["trap_info"],
        "baby": ["cell_info"],
@@ -193,7 +195,9 @@
        self.store = store

    @staticmethod
    def setLogger(
        folder, file_level: str = "INFO", stream_level: str = "WARNING"
    ):
        """Initialise and format logger."""
        logger = logging.getLogger("aliby")
        logger.setLevel(getattr(logging, file_level))
@@ -237,13 +241,19 @@
        fpath = files[0]
        # TODO add support for non-standard unique folder names
        with h5py.File(fpath, "r") as f:
            pipeline_parameters = PipelineParameters.from_yaml(
                f.attrs["parameters"]
            )
        pipeline_parameters.general["directory"] = dir_path.parent
        pipeline_parameters.general["filter"] = [fpath.stem for fpath in files]
        # fix legacy post-processing parameters
        post_process_params = pipeline_parameters.postprocessing.get(
            "parameters", None
        )
        if post_process_params:
            pipeline_parameters.postprocessing["param_sets"] = copy(
                post_process_params
            )
            del pipeline_parameters.postprocessing["parameters"]
        return cls(pipeline_parameters)
@@ -260,13 +270,19 @@ class Pipeline(ProcessABC):
            Name of file.
        """
        with h5py.File(fpath, "r") as f:
            pipeline_parameters = PipelineParameters.from_yaml(
                f.attrs["parameters"]
            )
        directory = Path(fpath).parent
        pipeline_parameters.general["directory"] = directory
        pipeline_parameters.general["filter"] = Path(fpath).stem
        post_process_params = pipeline_parameters.postprocessing.get(
            "parameters", None
        )
        if post_process_params:
            pipeline_parameters.postprocessing["param_sets"] = copy(
                post_process_params
            )
            del pipeline_parameters.postprocessing["parameters"]
        return cls(pipeline_parameters, store=directory)
@@ -275,18 +291,16 @@ class Pipeline(ProcessABC):
        return logging.getLogger("aliby")

    def run(self):
        """Run separate pipelines for all positions in an experiment."""
        # general information in config
        config = self.parameters.to_dict()
        expt_id = config["general"]["id"]
        distributed = config["general"]["distributed"]
        pos_filter = config["general"]["filter"]
        root_dir = Path(config["general"]["directory"])
        self.server_info = {
            k: config["general"].get(k)
            for k in ("host", "username", "password")
        }
        dispatcher = dispatch_dataset(expt_id, **self.server_info)
        logging.getLogger("aliby").info(
@@ -305,28 +319,29 @@ class Pipeline(ProcessABC):
        config["general"]["directory"] = directory
        self.setLogger(directory)
        # pick particular images if desired
        if pos_filter:
            if isinstance(pos_filter, list):
                image_ids = {
                    k: v
                    for filt in pos_filter
                    for k, v in self.apply_filter(image_ids, filt).items()
                }
            else:
                image_ids = self.apply_filter(image_ids, pos_filter)
        assert len(image_ids), "No images to segment"
        # create pipelines
        if distributed != 0:
            # multiple cores
            with Pool(distributed) as p:
                results = p.map(
                    lambda x: self.run_one_position(*x),
                    [(k, i) for i, k in enumerate(image_ids.items())],
                )
        else:
            # single core
            results = []
            for k, v in tqdm(image_ids.items()):
                r = self.run_one_position((k, v), 1)
                results.append(r)
        return results
@@ -334,7 +349,9 @@ class Pipeline(ProcessABC):
        """Select images by picking a particular one or by using a regular expression to parse their file names."""
        if isinstance(filt, str):
            # pick images using a regular expression
            image_ids = {
                k: v for k, v in image_ids.items() if re.search(filt, k)
            }
        elif isinstance(filt, int):
            # pick the filt'th image
            image_ids = {
@@ -342,16 +359,19 @@
            }
        return image_ids
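    # Illustrative calls (a sketch, not part of this commit; the position names
    # are hypothetical):
    #
    #     self.apply_filter(image_ids, "pos00[1-3]")  # keep matching names
    #     self.apply_filter(image_ids, 0)             # keep only the first image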
    def run_one_position(
        self,
        name_image_id: t.Tuple[str, str or PosixPath or int],
        index: t.Optional[int] = None,
    ):
        """Set up and run a pipeline for one position."""
        self._pool_index = index
        name, image_id = name_image_id
        # session and filename are defined by calling setup_pipeline.
        # can they be deleted here?
        session = None
        filename = None
        run_kwargs = {"extraction": {"labels": None, "masks": None}}
        try:
            (
@@ -379,7 +399,9 @@
            # START
            frac_clogged_traps = 0
            min_process_from = min(process_from.values())
            with get_image_class(image_id)(
                image_id, **self.server_info
            ) as image:
                # initialise steps
                if "tiler" not in steps:
                    steps["tiler"] = Tiler.from_image(
@@ -412,22 +434,26 @@
                    for op, (input_ch, _, _) in tmp.items():
                        if not set(input_ch).issubset(av_channels_wsub):
                            del config["extraction"]["multichannel_ops"][op]
                    exparams = ExtractorParameters.from_dict(
                        config["extraction"]
                    )
                    steps["extraction"] = Extractor.from_tiler(
                        exparams, store=filename, tiler=steps["tiler"]
                    )
                # set up progress meter
                pbar = tqdm(
                    range(min_process_from, tps),
                    desc=image.name,
                    initial=min_process_from,
                    total=tps,
                )
                for i in pbar:
                    if (
                        frac_clogged_traps
                        < earlystop["thresh_pos_clogged"]
                        or i < earlystop["min_tp"]
                    ):
                        # run through steps
                        for step in self.pipeline_steps:
                            if i >= process_from[step]:
                                result = steps[step].run_tp(
@@ -436,48 +462,52 @@
                                if step in loaded_writers:
                                    loaded_writers[step].write(
                                        data=result,
                                        overwrite=writer_ow_kwargs.get(
                                            step, []
                                        ),
                                        tp=i,
                                        meta={"last_processed": i},
                                    )
                                # perform step
                                if (
                                    step == "tiler"
                                    and i == min_process_from
                                ):
                                    logging.getLogger("aliby").info(
                                        f"Found {steps['tiler'].n_traps} traps in {image.name}"
                                    )
                                elif step == "baby":
                                    # write state and pass info to ext (Alan: what's ext?)
                                    loaded_writers["state"].write(
                                        data=steps[
                                            step
                                        ].crawler.tracker_states,
                                        overwrite=loaded_writers[
                                            "state"
                                        ].datatypes.keys(),
                                        tp=i,
                                    )
                                elif step == "extraction":
                                    # remove mask/label after extraction
                                    for k in ["masks", "labels"]:
                                        run_kwargs[step][k] = None
                        # check and report clogging
                        frac_clogged_traps = self.check_earlystop(
                            filename, earlystop, steps["tiler"].tile_size
                        )
                        self._log(
                            f"{name}:Clogged_traps:{frac_clogged_traps}"
                        )
                        frac = np.round(frac_clogged_traps * 100)
                        pbar.set_postfix_str(f"{frac} Clogged")
                    else:
                        # stop if too many traps are clogged
                        self._log(
                            f"{name}:Stopped early at time {i} with {frac_clogged_traps} clogged traps"
                        )
                        meta.add_fields({"end_status": "Clogged"})
                        break
                    meta.add_fields({"last_processed": i})
                # run post-processing
                meta.add_fields({"end_status": "Success"})
                post_proc_params = PostProcessorParameters.from_dict(
@@ -501,22 +531,48 @@
    @staticmethod
    def check_earlystop(filename: str, es_parameters: dict, tile_size: int):
        """
        Check recent time points for tiles with too many cells.

        Returns the fraction of clogged tiles, where clogged tiles have
        too many cells or too much of their area covered by cells.

        Parameters
        ----------
        filename: str
            Name of h5 file.
        es_parameters: dict
            Parameters defining when early stopping should happen.
            For example:
                {'min_tp': 100,
                'thresh_pos_clogged': 0.4,
                'thresh_trap_ncells': 8,
                'thresh_trap_area': 0.9,
                'ntps_to_eval': 5}
        tile_size: int
            Size of tile.
        """
        # get the area of the cells organised by trap and cell number
        s = Signal(filename)
        df = s["/extraction/general/None/area"]
        # check the latest time points only
        cells_used = df[
            df.columns[-1 - es_parameters["ntps_to_eval"] : -1]
        ].dropna(how="all")
        # find tiles with too many cells
        traps_above_nthresh = (
            cells_used.groupby("trap").count().apply(np.mean, axis=1)
            > es_parameters["thresh_trap_ncells"]
        )
        # find tiles with cells covering too great a fraction of the tiles' area
        traps_above_athresh = (
            cells_used.groupby("trap").sum().apply(np.mean, axis=1)
            / tile_size**2
            > es_parameters["thresh_trap_area"]
        )
        return (traps_above_nthresh & traps_above_athresh).mean()
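    # Illustrative call (a sketch, not part of this commit; the file name and
    # tile size below are hypothetical):
    #
    #     es_parameters = {
    #         "min_tp": 100,
    #         "thresh_pos_clogged": 0.4,
    #         "thresh_trap_ncells": 8,
    #         "thresh_trap_area": 0.9,
    #         "ntps_to_eval": 5,
    #     }
    #     frac = Pipeline.check_earlystop("pos001.h5", es_parameters, tile_size=117)
    #
    # frac is the fraction of tiles whose mean cell count and mean covered area
    # over the last ntps_to_eval time points both exceed their thresholds.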
    # Alan: can both this method and the next be deleted?
    def _load_config_from_file(
        self,
        filename: PosixPath,
@@ -542,7 +598,9 @@
        switch_case = {
            "tiler": lambda f: f["trap_info/drifts"].shape[0] - 1,
            "baby": lambda f: f["cell_info/timepoint"][-1],
            "extraction": lambda f: f[
                "extraction/general/None/area/timepoint"
            ][-1],
        }
        return switch_case[step]
@@ -559,72 +617,66 @@
        t.List[np.ndarray],
    ]:
        """
        Initialise steps in a pipeline.

        If necessary use a file to re-start experiments already partly run.

        Parameters
        ----------
        image_id : int or str
            Identifier of a data set in an OMERO server or a filename.

        Returns
        -------
        filename: str
            Path to a h5 file to write to.
        meta: object
            agora.io.metadata.MetaData object
        config: dict
            Configuration parameters.
        process_from: dict
            Gives from which time point each step of the pipeline should start.
        tps: int
            Number of time points.
        steps: dict
        earlystop: dict
            Parameters to check whether the pipeline should be stopped.
        session: None
        trackers_state: list
            States of any trackers from earlier runs.
        """
        config = self.parameters.to_dict()
        # Alan: session is never changed
        session = None
        earlystop = config["general"].get("earlystop", None)
        process_from = {k: 0 for k in self.pipeline_steps}
        steps = {}
        # check overwriting
        ow_id = config["general"].get("overwrite", 0)
        ow = {step: True for step in self.step_sequence}
        if ow_id and ow_id is not True:
            ow = {
                step: self.step_sequence.index(ow_id) < i
                for i, step in enumerate(self.step_sequence, 1)
            }
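        # For example, assuming step_sequence is ["tiler", "baby", "extraction",
        # "postprocessing"] (an assumption; the full list is not shown in this
        # hunk), ow_id == "baby" gives ow == {"tiler": False, "baby": True,
        # "extraction": True, "postprocessing": True}, so "baby" and every
        # later step are redone.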
        # set up
        directory = config["general"]["directory"]
        trackers_state = []
        with get_image_class(image_id)(image_id, **self.server_info) as image:
            filename = Path(f"{directory}/{image.name}.h5")
            meta = MetaData(directory, filename)
            from_start = True if np.any(ow.values()) else False
            # remove existing file if overwriting
            if (
                from_start
                and (
                    config["general"].get("overwrite", False)
                    or np.all(list(ow.values()))
                )
                and filename.exists()
            ):
                os.remove(filename)
            # if the file exists with no previous segmentation use its tiler
            if filename.exists():
                self._log("Result file exists.", "info")
                if not ow["tiler"]:
@@ -643,15 +695,14 @@
                        if ow["baby"]
                        else StateReader(filename).get_formatted_states()
                    )
                    config["tiler"] = steps["tiler"].parameters.to_dict()
                except Exception:
                    # Alan: a warning or log here?
                    pass
            if config["general"]["use_explog"]:
                meta.run()
            # add metadata not in the log file
            meta.add_fields(
                {
                    "aliby_version": version("aliby"),
                    "baby_version": version("aliby-baby"),
@@ -659,12 +710,12 @@
                    "image_id": image_id
                    if isinstance(image_id, int)
                    else str(image_id),
                    "parameters": PipelineParameters.from_dict(
                        config
                    ).to_yaml(),
                }
            )
            tps = min(config["general"]["tps"], image.data.shape[0])
            return (
                filename,
                meta,
...