diff --git a/src/agora/io/signal.py b/src/agora/io/signal.py index 4ea45539c8a5905e366604ce2c643271ffa00b22..322bb5ee1929650653e93d7bbcab3ae2e72aebd2 100644 --- a/src/agora/io/signal.py +++ b/src/agora/io/signal.py @@ -51,7 +51,7 @@ class Signal(BridgeH5): return self.add_name(df, dsets) elif isinstance(dsets, list): # pre-processing is_bgd = [dset.endswith("imBackground") for dset in dsets] - # Check we are not comaring tile-indexed and cell-indexed data + # Check we are not comparing tile-indexed and cell-indexed data assert sum(is_bgd) == 0 or sum(is_bgd) == len( dsets ), "Tile data and cell data can't be mixed" diff --git a/src/aliby/io/dataset.py b/src/aliby/io/dataset.py index eb383bf2c88987c07b74984f3eb654258d2dd2c0..f26baf7f93b361cc0b9e734ca2155e20ea70605f 100644 --- a/src/aliby/io/dataset.py +++ b/src/aliby/io/dataset.py @@ -12,41 +12,46 @@ import typing as t from abc import ABC, abstractproperty, abstractmethod from pathlib import Path, PosixPath - +from agora.io.bridge import BridgeH5 from aliby.io.image import ImageLocalOME def dispatch_dataset(expt_id: int or str, **kwargs): """ - Choose a subtype of dataset based on the identifier. + Find paths to the data. - Input: - -------- - expt_id: int or string serving as dataset identifier. + Connects to OMERO if data is remotely available. - Returns: - -------- - Callable Dataset instance, either network-dependent or local. - """ - if isinstance(expt_id, int): # Is an experiment online + Parameters + ---------- + expt_id: int or str + To identify the data, either an OMERO ID or an OME-TIFF file or a local directory. + Returns + ------- + A callable Dataset instance, either network-dependent or local. + """ + if isinstance(expt_id, int): + # data available online from aliby.io.omero import Dataset return Dataset(expt_id, **kwargs) - - elif isinstance(expt_id, str): # Files or Dir + elif isinstance(expt_id, str): + # data available locally expt_path = Path(expt_id) if expt_path.is_dir(): + # data in multiple folders return DatasetLocalDir(expt_path) else: + # data in one folder as OME-TIFF files return DatasetLocalOME(expt_path) else: - raise Warning("Invalid expt_id") + raise Warning(f"{expt_id} is an invalid expt_id") class DatasetLocalABC(ABC): """ - Abstract Base class to fetch local files, either OME-XML or raw images. + Abstract Base class to find local files, either OME-XML or raw images. """ _valid_suffixes = ("tiff", "png") @@ -73,12 +78,9 @@ class DatasetLocalABC(ABC): def unique_name(self): return self.path.name - @abstractproperty - def date(self): - pass - @property def files(self): + """Return a dictionary with any available metadata files.""" if not hasattr(self, "_files"): self._files = { f: f @@ -91,34 +93,35 @@ class DatasetLocalABC(ABC): return self._files def cache_logs(self, root_dir): - # Copy metadata files to results folder + """Copy metadata files to results folder.""" for name, annotation in self.files.items(): shutil.copy(annotation, root_dir / name.name) return True + @abstractproperty + def date(self): + pass + @abstractmethod def get_images(self): - # Return a dictionary with the name of images and their unique identifiers pass class DatasetLocalDir(DatasetLocalABC): - """ - Organise an entire dataset, composed of multiple images, as a directory containing directories with individual files. - It relies on ImageDir to manage images. 
- """ + """Find paths to a data set, comprising multiple images in different folders.""" def __init__(self, dpath: t.Union[str, PosixPath], *args, **kwargs): super().__init__(dpath) @property def date(self): - # Use folder creation date, for cases where metadata is minimal + """Find date when a folder was created.""" return time.strftime( "%Y%m%d", time.strptime(time.ctime(os.path.getmtime(self.path))) ) def get_images(self): + """Return a dictionary of folder names and their paths.""" return { folder.name: folder for folder in self.path.glob("*/") @@ -131,13 +134,7 @@ class DatasetLocalDir(DatasetLocalABC): class DatasetLocalOME(DatasetLocalABC): - """Load a dataset from a folder - - We use a given image of a dataset to obtain the metadata, - as we cannot expect folders to contain this information. - - It uses the standard OME-TIFF file format. - """ + """Find names of images in a folder, assuming images in OME-TIFF format.""" def __init__(self, dpath: t.Union[str, PosixPath], *args, **kwargs): super().__init__(dpath) @@ -145,11 +142,11 @@ class DatasetLocalOME(DatasetLocalABC): @property def date(self): - # Access the date from the metadata of the first position + """Get the date from the metadata of the first position.""" return ImageLocalOME(list(self.get_images().values())[0]).date def get_images(self): - # Fetches all valid formats and overwrites if duplicates with different suffix + """Return a dictionary with the names of the image files.""" return { f.name: str(f) for suffix in self._valid_suffixes diff --git a/src/aliby/pipeline.py b/src/aliby/pipeline.py index 2977cada71b37a131f30e25103d6d668853878d0..2edcc52c55e5ed72b3bd8d9bc9b740847ed64397 100644 --- a/src/aliby/pipeline.py +++ b/src/aliby/pipeline.py @@ -1,6 +1,4 @@ -""" -Pipeline and chaining elements. -""" +"""Set up and run pipelines: tiling, segmentation, extraction, and then post-processing.""" import logging import os import re @@ -36,23 +34,14 @@ from postprocessor.core.processor import PostProcessor, PostProcessorParameters class PipelineParameters(ParametersABC): - """ - Parameters that host what is run and how. It takes a list of dictionaries, one for - general in collection: - pass dictionary for each step - -------------------- - expt_id: int or str Experiment id (if integer) or local path (if string). - directory: str Directory into which results are dumped. Default is "../data" - - Provides default parameters for the entire pipeline. This downloads the logfiles and sets the default - timepoints and extraction parameters from there. - """ + """Define parameters for the steps of the pipeline.""" _pool_index = None def __init__( self, general, tiler, baby, extraction, postprocessing, reporting ): + """Initialise, but called by a class method - not directly.""" self.general = general self.tiler = tiler self.baby = baby @@ -69,13 +58,34 @@ class PipelineParameters(ParametersABC): extraction={}, postprocessing={}, ): + """ + Initialise parameters for steps of the pipeline. + + Some parameters are extracted from the log files. + + Parameters + --------- + general: dict + Parameters to set up the pipeline. + tiler: dict + Parameters for tiler. + baby: dict (optional) + Parameters for Baby. + extraction: dict (optional) + Parameters for extraction. + postprocessing: dict (optional) + Parameters for post-processing. + """ + # Alan: should 19993 be updated? 
expt_id = general.get("expt_id", 19993) if isinstance(expt_id, PosixPath): expt_id = str(expt_id) general["expt_id"] = expt_id + # Alan: an error message rather than a default might be better directory = Path(general.get("directory", "../data")) + # get log files, either locally or via OMERO with dispatch_dataset( expt_id, **{k: general.get(k) for k in ("host", "username", "password")}, @@ -83,7 +93,7 @@ class PipelineParameters(ParametersABC): directory = directory / conn.unique_name if not directory.exists(): directory.mkdir(parents=True) - # Download logs to use for metadata + # download logs for metadata conn.cache_logs(directory) try: meta_d = MetaData(directory, None).load_logs() @@ -95,9 +105,10 @@ class PipelineParameters(ParametersABC): "channels": ["Brightfield"], "ntps": [2000], } - # Set minimal metadata + # set minimal metadata meta_d = minimal_default_meta + # define default values for general parameters tps = meta_d.get("ntps", 2000) defaults = { "general": dict( @@ -118,7 +129,8 @@ class PipelineParameters(ParametersABC): ) } - for k, v in general.items(): # Overwrite general parameters + # update default values using inputs + for k, v in general.items(): if k not in defaults["general"]: defaults["general"][k] = v elif isinstance(v, dict): @@ -127,15 +139,13 @@ class PipelineParameters(ParametersABC): else: defaults["general"][k] = v + # define defaults and update with any inputs defaults["tiler"] = TilerParameters.default(**tiler).to_dict() defaults["baby"] = BabyParameters.default(**baby).to_dict() defaults["extraction"] = ( exparams_from_meta(meta_d) or BabyParameters.default(**extraction).to_dict() ) - defaults["postprocessing"] = {} - defaults["reporting"] = {} - defaults["postprocessing"] = PostProcessorParameters.default( **postprocessing ).to_dict() @@ -150,22 +160,22 @@ class PipelineParameters(ParametersABC): class Pipeline(ProcessABC): """ - A chained set of Pipeline elements connected through pipes. - Tiling, Segmentation,Extraction and Postprocessing should use their own default parameters. - These can be overriden passing the key:value of parameters to override to a PipelineParameters class + Initialise and run tiling, segmentation, extraction and post-processing. - """ + Each step feeds the next one. - iterative_steps = ["tiler", "baby", "extraction"] + To customise parameters for any step use the PipelineParameters class.stem + """ + pipeline_steps = ["tiler", "baby", "extraction"] step_sequence = [ "tiler", "baby", "extraction", "postprocessing", ] - # Indicate step-writer groupings to perform special operations during step iteration + # Alan: replace with - specify the group in the h5 files written by each step (?) writer_groups = { "tiler": ["trap_info"], "baby": ["cell_info"], @@ -178,8 +188,8 @@ class Pipeline(ProcessABC): } def __init__(self, parameters: PipelineParameters, store=None): + """Initialise - not usually called directly.""" super().__init__(parameters) - if store is not None: store = Path(store) self.store = store @@ -188,20 +198,19 @@ class Pipeline(ProcessABC): def setLogger( folder, file_level: str = "INFO", stream_level: str = "WARNING" ): - + """Initialise and format logger.""" logger = logging.getLogger("aliby") logger.setLevel(getattr(logging, file_level)) formatter = logging.Formatter( "%(asctime)s - %(levelname)s:%(message)s", datefmt="%Y-%m-%dT%H:%M:%S%z", ) - + # for streams - stdout, files, etc. 
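Note: `setLogger` (continued below) follows the standard two-handler pattern, a terse stream handler for warnings plus a verbose file handler for the full record. A self-contained sketch of the same pattern; only the log-file location is an assumption.

```python
import logging
from pathlib import Path

# logger named as in setLogger; the folder is an assumption for this sketch
logger = logging.getLogger("aliby")
logger.setLevel(logging.INFO)
formatter = logging.Formatter(
    "%(asctime)s - %(levelname)s:%(message)s", datefmt="%Y-%m-%dT%H:%M:%S%z"
)
# stream handler: terse, warnings and above only
ch = logging.StreamHandler()
ch.setLevel(logging.WARNING)
ch.setFormatter(formatter)
logger.addHandler(ch)
# file handler: verbose, keeps the full record
fh = logging.FileHandler(Path(".") / "aliby.log", "w+")
fh.setLevel(logging.INFO)
fh.setFormatter(formatter)
logger.addHandler(fh)

logger.info("reaches the file only")
logger.warning("reaches both the console and the file")
```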
ch = logging.StreamHandler() ch.setLevel(getattr(logging, stream_level)) ch.setFormatter(formatter) logger.addHandler(ch) - - # create file handler which logs even debug messages + # create file handler that logs even debug messages fh = logging.FileHandler(Path(folder) / "aliby.log", "w+") fh.setLevel(getattr(logging, file_level)) fh.setFormatter(formatter) @@ -216,20 +225,20 @@ class Pipeline(ProcessABC): @classmethod def from_folder(cls, dir_path): """ - Constructor to re-process all files in a given folder. + Re-process all h5 files in a given folder. - Assumes all files share the same parameters (even if they don't share - the same channel set). + All files must share the same parameters, even if they have different channels. Parameters --------- - dir_path : str or Pathlib indicating the folder containing the files to process + dir_path : str or Pathlib + Folder containing the files. """ + # find h5 files dir_path = Path(dir_path) files = list(dir_path.rglob("*.h5")) assert len(files), "No valid files found in folder" fpath = files[0] - # TODO add support for non-standard unique folder names with h5py.File(fpath, "r") as f: pipeline_parameters = PipelineParameters.from_yaml( @@ -237,8 +246,7 @@ class Pipeline(ProcessABC): ) pipeline_parameters.general["directory"] = dir_path.parent pipeline_parameters.general["filter"] = [fpath.stem for fpath in files] - - # Fix legacy postprocessing parameters + # fix legacy post-processing parameters post_process_params = pipeline_parameters.postprocessing.get( "parameters", None ) @@ -247,16 +255,19 @@ class Pipeline(ProcessABC): post_process_params ) del pipeline_parameters.postprocessing["parameters"] - return cls(pipeline_parameters) @classmethod def from_existing_h5(cls, fpath): """ - Constructor to process an existing hdf5 file. - Notice that it forces a single file, not suitable for multiprocessing of certain positions. + Re-process an existing h5 file. + + Not suitable for more than one file. - It i s also used as a base for a folder-wide reprocessing. + Parameters + --------- + fpath: str + Name of file. 
""" with h5py.File(fpath, "r") as f: pipeline_parameters = PipelineParameters.from_yaml( @@ -265,7 +276,6 @@ class Pipeline(ProcessABC): directory = Path(fpath).parent pipeline_parameters.general["directory"] = directory pipeline_parameters.general["filter"] = Path(fpath).stem - post_process_params = pipeline_parameters.postprocessing.get( "parameters", None ) @@ -274,7 +284,6 @@ class Pipeline(ProcessABC): post_process_params ) del pipeline_parameters.postprocessing["parameters"] - return cls(pipeline_parameters, store=directory) @property @@ -282,12 +291,8 @@ class Pipeline(ProcessABC): return logging.getLogger("aliby") def run(self): - """ - Config holds the general information, use in main - Steps: all holds general tasks - steps: strain_name holds task for a given strain - """ - + """Run separate pipelines for all positions in an experiment.""" + # general information in config config = self.parameters.to_dict() expt_id = config["general"]["id"] distributed = config["general"]["distributed"] @@ -297,82 +302,76 @@ class Pipeline(ProcessABC): k: config["general"].get(k) for k in ("host", "username", "password") } - dispatcher = dispatch_dataset(expt_id, **self.server_info) logging.getLogger("aliby").info( f"Fetching data using {dispatcher.__class__.__name__}" ) - # Do all all initialisations - + # get log files, either locally or via OMERO with dispatcher as conn: image_ids = conn.get_images() - directory = self.store or root_dir / conn.unique_name - if not directory.exists(): directory.mkdir(parents=True) - - # Download logs to use for metadata + # download logs to use for metadata conn.cache_logs(directory) - - # Modify to the configuration + # update configuration self.parameters.general["directory"] = str(directory) config["general"]["directory"] = directory - self.setLogger(directory) - - # Filter TODO integrate filter onto class and add regex - def filt_int(d: dict, filt: int): - return {k: v for i, (k, v) in enumerate(d.items()) if i == filt} - - def filt_str(image_ids: dict, filt: str): - return {k: v for k, v in image_ids.items() if re.search(filt, k)} - - def pick_filter(image_ids: dict, filt: int or str): - if isinstance(filt, str): - image_ids = filt_str(image_ids, filt) - elif isinstance(filt, int): - image_ids = filt_int(image_ids, filt) - return image_ids - - if isinstance(pos_filter, list): - image_ids = { - k: v - for filt in pos_filter - for k, v in pick_filter(image_ids, filt).items() - } - else: - image_ids = pick_filter(image_ids, pos_filter) - + # pick particular images if desired + if pos_filter: + if isinstance(pos_filter, list): + image_ids = { + k: v + for filt in pos_filter + for k, v in self.apply_filter(image_ids, filt).items() + } + else: + image_ids = self.apply_filter(image_ids, pos_filter) assert len(image_ids), "No images to segment" - - if distributed != 0: # Gives the number of simultaneous processes + # create pipelines + if distributed != 0: + # multiple cores with Pool(distributed) as p: results = p.map( - lambda x: self.create_pipeline(*x), + lambda x: self.run_one_position(*x), [(k, i) for i, k in enumerate(image_ids.items())], - # num_cpus=distributed, - # position=0, ) - - else: # Sequential + else: + # single core results = [] for k, v in tqdm(image_ids.items()): - r = self.create_pipeline((k, v), 1) + r = self.run_one_position((k, v), 1) results.append(r) - return results - def create_pipeline( + def apply_filter(self, image_ids: dict, filt: int or str): + """Select images by picking a particular one or by using a regular expression to parse 
their file names.""" + if isinstance(filt, str): + # pick images using a regular expression + image_ids = { + k: v for k, v in image_ids.items() if re.search(filt, k) + } + elif isinstance(filt, int): + # pick the filt'th image + image_ids = { + k: v for i, (k, v) in enumerate(image_ids.items()) if i == filt + } + return image_ids + + def run_one_position( self, - image_id: t.Tuple[str, str or PosixPath or int], + name_image_id: t.Tuple[str, str or PosixPath or int], index: t.Optional[int] = None, ): - """ """ + """Set up and run a pipeline for one position.""" self._pool_index = index - name, image_id = image_id + name, image_id = name_image_id + # session and filename are defined by calling setup_pipeline. + # can they be deleted here? session = None filename = None + # run_kwargs = {"extraction": {"labels": None, "masks": None}} try: ( @@ -386,7 +385,6 @@ class Pipeline(ProcessABC): session, trackers_state, ) = self._setup_pipeline(image_id) - loaded_writers = { name: writer(filename) for k in self.step_sequence @@ -398,20 +396,17 @@ class Pipeline(ProcessABC): "baby": ["mother_assign"], } - # START PIPELINE + # START frac_clogged_traps = 0 min_process_from = min(process_from.values()) - with get_image_class(image_id)( image_id, **self.server_info ) as image: - - # Initialise Steps + # initialise steps if "tiler" not in steps: steps["tiler"] = Tiler.from_image( image, TilerParameters.from_dict(config["tiler"]) ) - if process_from["baby"] < tps: session = initialise_tf(2) steps["baby"] = BabyRunner.from_tiler( @@ -420,8 +415,7 @@ class Pipeline(ProcessABC): ) if trackers_state: steps["baby"].crawler.tracker_states = trackers_state - - # Limit extraction parameters during run using the available channels in tiler + # limit extraction parameters using the available channels in tiler if process_from["extraction"] < tps: # TODO Move this parameter validation into Extractor av_channels = set((*steps["tiler"].channels, "general")) @@ -433,7 +427,6 @@ class Pipeline(ProcessABC): config["extraction"]["sub_bg"] = av_channels.intersection( config["extraction"]["sub_bg"] ) - av_channels_wsub = av_channels.union( [c + "_bgsub" for c in config["extraction"]["sub_bg"]] ) @@ -441,29 +434,27 @@ class Pipeline(ProcessABC): for op, (input_ch, _, _) in tmp.items(): if not set(input_ch).issubset(av_channels_wsub): del config["extraction"]["multichannel_ops"][op] - exparams = ExtractorParameters.from_dict( config["extraction"] ) steps["extraction"] = Extractor.from_tiler( exparams, store=filename, tiler=steps["tiler"] ) + # set up progress meter pbar = tqdm( range(min_process_from, tps), desc=image.name, initial=min_process_from, total=tps, - # position=index + 1, ) for i in pbar: - if ( frac_clogged_traps < earlystop["thresh_pos_clogged"] or i < earlystop["min_tp"] ): - - for step in self.iterative_steps: + # run through steps + for step in self.pipeline_steps: if i >= process_from[step]: result = steps[step].run_tp( i, **run_kwargs.get(step, {}) @@ -477,18 +468,16 @@ class Pipeline(ProcessABC): tp=i, meta={"last_processed": i}, ) - - # Step-specific actions + # perform step if ( step == "tiler" and i == min_process_from ): logging.getLogger("aliby").info( - f"Found {steps['tiler'].n_traps} traps in {image.name}" + f"Found {steps['tiler'].n_tiles} traps in {image.name}" ) - elif ( - step == "baby" - ): # Write state and pass info to ext + elif step == "baby": + # write state and pass info to ext (Alan: what's ext?) 
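Note: to make `apply_filter`'s two modes concrete, a toy example with made-up position names and ids.

```python
import re

# made-up position names and ids
image_ids = {"pos001_WT": 101, "pos002_WT": 102, "pos003_mut": 103}

# an int picks the filt'th position
by_index = {k: v for i, (k, v) in enumerate(image_ids.items()) if i == 1}
print(by_index)  # {'pos002_WT': 102}

# a str is a regular expression matched against position names
by_regex = {k: v for k, v in image_ids.items() if re.search("WT", k)}
print(by_regex)  # {'pos001_WT': 101, 'pos002_WT': 102}
```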
loaded_writers["state"].write( data=steps[ step @@ -498,47 +487,43 @@ class Pipeline(ProcessABC): ].datatypes.keys(), tp=i, ) - elif ( - step == "extraction" - ): # Remove mask/label after ext + elif step == "extraction": + # remove mask/label after extraction for k in ["masks", "labels"]: run_kwargs[step][k] = None - + # check and report clogging frac_clogged_traps = self.check_earlystop( filename, earlystop, steps["tiler"].tile_size ) self._log( f"{name}:Clogged_traps:{frac_clogged_traps}" ) - frac = np.round(frac_clogged_traps * 100) pbar.set_postfix_str(f"{frac} Clogged") - else: # Stop if more than X% traps are clogged + else: + # stop if too many traps are clogged self._log( - f"{name}:Analysis stopped early at time {i} with {frac_clogged_traps} clogged traps" + f"{name}:Stopped early at time {i} with {frac_clogged_traps} clogged traps" ) meta.add_fields({"end_status": "Clogged"}) break - meta.add_fields({"last_processed": i}) - - # Run post-processing + # run post-processing meta.add_fields({"end_status": "Success"}) post_proc_params = PostProcessorParameters.from_dict( config["postprocessing"] ) PostProcessor(filename, post_proc_params).run() - self._log("Analysis finished successfully.", "info") return 1 - except Exception as e: # Catch bugs during setup or runtime + except Exception as e: + # catch bugs during setup or run time logging.exception( f"{name}: Exception caught.", exc_info=True, ) - # This prints the type, value, and stack trace of the - # current exception being handled. + # print the type, value, and stack trace of the exception traceback.print_exc() raise e finally: @@ -546,23 +531,48 @@ class Pipeline(ProcessABC): @staticmethod def check_earlystop(filename: str, es_parameters: dict, tile_size: int): + """ + Check recent time points for tiles with too many cells. + + Returns the fraction of clogged tiles, where clogged tiles have + too many cells or too much of their area covered by cells. + + Parameters + ---------- + filename: str + Name of h5 file. + es_parameters: dict + Parameters defining when early stopping should happen. + For example: + {'min_tp': 100, + 'thresh_pos_clogged': 0.4, + 'thresh_trap_ncells': 8, + 'thresh_trap_area': 0.9, + 'ntps_to_eval': 5} + tile_size: int + Size of tile. + """ + # get the area of the cells organised by trap and cell number s = Signal(filename) df = s["/extraction/general/None/area"] + # check the latest time points only cells_used = df[ df.columns[-1 - es_parameters["ntps_to_eval"] : -1] ].dropna(how="all") + # find tiles with too many cells traps_above_nthresh = ( cells_used.groupby("trap").count().apply(np.mean, axis=1) > es_parameters["thresh_trap_ncells"] ) + # find tiles with cells covering too great a fraction of the tiles' area traps_above_athresh = ( cells_used.groupby("trap").sum().apply(np.mean, axis=1) / tile_size**2 > es_parameters["thresh_trap_area"] ) - return (traps_above_nthresh & traps_above_athresh).mean() + # Alan: can both this method and the next be deleted? def _load_config_from_file( self, filename: PosixPath, @@ -607,73 +617,66 @@ class Pipeline(ProcessABC): t.List[np.ndarray], ]: """ - Initialise pipeline components and if necessary use - exising file to continue existing experiments. + Initialise steps in a pipeline. + If necessary use a file to re-start experiments already partly run. Parameters ---------- - image_id : int - identifier of image in OMERO server, or filename + image_id : int or str + Identifier of a data set in an OMERO server or a filename. 
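Note: the pandas logic in `check_earlystop` above is easiest to follow on a toy table. A small sketch that mimics a (trap, cell) indexed area table over five recent time points and computes the clogged fraction; the data and thresholds are illustrative.

```python
import numpy as np
import pandas as pd

# a toy area table: rows are (trap, cell_label), columns are time points
index = pd.MultiIndex.from_tuples(
    [(0, 1), (0, 2), (1, 1), (1, 2), (1, 3)], names=["trap", "cell_label"]
)
cells_used = pd.DataFrame(
    np.random.randint(50, 120, size=(5, 5)), index=index
)
tile_size, thresh_ncells, thresh_area = 117, 8, 0.9
# mean number of cells per trap over the recent time points
too_many = (
    cells_used.groupby("trap").count().apply(np.mean, axis=1) > thresh_ncells
)
# mean fraction of the tile's area covered by cells
too_full = (
    cells_used.groupby("trap").sum().apply(np.mean, axis=1) / tile_size**2
    > thresh_area
)
print((too_many & too_full).mean())  # fraction of clogged traps (0.0 here)
```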
Returns - --------- + ------- filename: str - meta: - config: - process_from: - tps: - steps: - earlystop: - session: - trackers_state: - - Examples - -------- - FIXME: Add docs. - + Path to a h5 file to write to. + meta: object + agora.io.metadata.MetaData object + config: dict + Configuration parameters. + process_from: dict + Gives from which time point each step of the pipeline should start. + tps: int + Number of time points. + steps: dict + earlystop: dict + Parameters to check whether the pipeline should be stopped. + session: None + trackers_state: list + States of any trackers from earlier runs. """ config = self.parameters.to_dict() - pparams = config - image_id = image_id - general_config = config["general"] + # Alan: session is never changed session = None - earlystop = general_config.get("earlystop", None) - process_from = {k: 0 for k in self.iterative_steps} + earlystop = config["general"].get("earlystop", None) + process_from = {k: 0 for k in self.pipeline_steps} steps = {} - ow = {k: 0 for k in self.step_sequence} - # check overwriting - ow_id = general_config.get("overwrite", 0) + ow_id = config["general"].get("overwrite", 0) ow = {step: True for step in self.step_sequence} if ow_id and ow_id is not True: ow = { step: self.step_sequence.index(ow_id) < i for i, step in enumerate(self.step_sequence, 1) } - - # Set up - directory = general_config["directory"] - - trackers_state: t.List[np.ndarray] = [] + # set up + directory = config["general"]["directory"] + trackers_state = [] with get_image_class(image_id)(image_id, **self.server_info) as image: filename = Path(f"{directory}/{image.name}.h5") meta = MetaData(directory, filename) - from_start = True if np.any(ow.values()) else False - - # New experiment or overwriting + # remove existing file if overwriting if ( from_start and ( - config.get("overwrite", False) == True + config["general"].get("overwrite", False) or np.all(list(ow.values())) ) and filename.exists() ): os.remove(filename) - - # If no previous segmentation and keep tiler + # if the file exists with no previous segmentation use its tiler if filename.exists(): self._log("Result file exists.", "info") if not ow["tiler"]: @@ -692,15 +695,14 @@ class Pipeline(ProcessABC): if ow["baby"] else StateReader(filename).get_formatted_states() ) - config["tiler"] = steps["tiler"].parameters.to_dict() except Exception: + # Alan: a warning or log here? pass - if config["general"]["use_explog"]: meta.run() - - meta.add_fields( # Add non-logfile metadata + # add metadata not in the log file + meta.add_fields( { "aliby_version": version("aliby"), "baby_version": version("aliby-baby"), @@ -709,13 +711,11 @@ class Pipeline(ProcessABC): if isinstance(image_id, int) else str(image_id), "parameters": PipelineParameters.from_dict( - pparams + config ).to_yaml(), } ) - - tps = min(general_config["tps"], image.data.shape[0]) - + tps = min(config["general"]["tps"], image.data.shape[0]) return ( filename, meta, diff --git a/src/aliby/tile/tiler.py b/src/aliby/tile/tiler.py index 34aa89d1c2af76976f134a0fd38d58919cedc9bb..9c1c08ce247f2aa146e9af4f3e17c5473afb9901 100644 --- a/src/aliby/tile/tiler.py +++ b/src/aliby/tile/tiler.py @@ -1,19 +1,15 @@ """ -Tiler: Tiles and tracks traps. +Tiler: Divides images into smaller tiles. -The tasks of the Tiler are selecting regions of interest, or tiles, of an image - with one tile per trap, tracking and correcting for the drift of the microscope stage over time, and handling errors and bridging between the image data and ALIBY’s image-processing steps. 
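Note: the overwrite bookkeeping in `_setup_pipeline` above is compact; here it is in isolation, with a hypothetical `ow_id`, to show that overwriting from a given step forces every later step to be re-run.

```python
step_sequence = ["tiler", "baby", "extraction", "postprocessing"]
ow_id = "baby"  # hypothetical: re-run from the baby step onwards
ow = {
    step: step_sequence.index(ow_id) < i
    for i, step in enumerate(step_sequence, 1)
}
print(ow)
# {'tiler': False, 'baby': True, 'extraction': True, 'postprocessing': True}
```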
+The tasks of the Tiler are selecting regions of interest, or tiles, of images - with one trap per tile, correcting for the drift of the microscope stage over time, and handling errors and bridging between the image data and Aliby’s image-processing steps. Tiler subclasses deal with either network connections or local files. -To find traps, we use a two-step process: we analyse the bright-field image to produce the template of a trap, and we fit this template to the image to find the traps' centres. +To find tiles, we use a two-step process: we analyse the bright-field image to produce the template of a trap, and we fit this template to the image to find the tiles' centres. We use texture-based segmentation (entropy) to split the image into foreground -- cells and traps -- and background, which we then identify with an Otsu filter. Two methods are used to produce a template trap from these regions: pick the trap with the smallest minor axis length and average over all validated traps. -A peak-identifying algorithm recovers the x and y-axis location of traps in the original image, and we choose the templating approach that identifies the most traps - -One key method is Tiler.run. - -The image-processing is performed by traps/segment_traps. +A peak-identifying algorithm recovers the x and y-axis location of traps in the original image, and we choose the approach to template that identifies the most tiles. The experiment is stored as an array with a standard indexing order of (Time, Channels, Z-stack, X, Y). """ @@ -34,11 +30,12 @@ from aliby.io.image import ImageLocalOME, ImageDir, ImageDummy from aliby.tile.traps import segment_traps -class Trap: +class Tile: """ - Stores a trap's location and size. - Allows checks to see if the trap should be padded. - Can export the trap either in OMERO or numpy formats. + Store a tile's location and size. + + Checks to see if the tile should be padded. + Can export the tile either in OMERO or numpy formats. """ def __init__(self, centre, parent, size, max_size): @@ -50,31 +47,28 @@ class Trap: def at_time(self, tp: int) -> t.List[int]: """ - Return trap centre at time tp by applying drifts + Return tile's centre by applying drifts. Parameters ---------- tp: integer - Index for a time point - - Returns - ------- - trap_centre: + Index for the time point of interest. """ drifts = self.parent.drifts - trap_centre = self.centre - np.sum(drifts[: tp + 1], axis=0) - return list(trap_centre.astype(int)) + tile_centre = self.centre - np.sum(drifts[: tp + 1], axis=0) + return list(tile_centre.astype(int)) - def as_tile(self, tp): + def as_tile(self, tp: int): """ - Return trap in the OMERO tile format of x, y, w, h - where x, y are at the bottom left corner of the tile + Return tile in the OMERO tile format of x, y, w, h. + + Here x, y are at the bottom left corner of the tile and w and h are the tile width and height. Parameters ---------- tp: integer - Index for a time point + Index for the time point of interest. Returns ------- @@ -93,10 +87,10 @@ class Trap: y = int(y - self.half_size) return x, y, self.size, self.size - def as_range(self, tp): + def as_range(self, tp: int): """ - Return trap in a range format: two slice objects that can - be used in arrays + Return tile in a range format: two slice objects that can + be used in arrays. Parameters ---------- @@ -112,11 +106,8 @@ class Trap: return slice(x, x + w), slice(y, y + h) -class TrapLocations: - """ - Stores each trap as an instance of Trap. - Traps can be iterated. 
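Note: the drift arithmetic in `Tile.at_time` below is worth seeing in isolation; a toy numpy example with made-up coordinates.

```python
import numpy as np

centre = np.array([100, 120])  # initial tile centre
drifts = np.array([[0, 0], [2, -1], [1, 3]])  # one drift per time point
tp = 2
# subtract the cumulative drift up to and including tp
tile_centre = centre - np.sum(drifts[: tp + 1], axis=0)
print(list(tile_centre.astype(int)))  # [97, 118]
```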
- """ +class TileLocations: + """Store each tile as an instance of Tile.""" def __init__( self, @@ -130,29 +121,27 @@ class TrapLocations: self.tile_size = tile_size self.max_size = max_size self.initial_location = initial_location - self.traps = [ - Trap(centre, self, tile_size or max_size, max_size) + self.tiles = [ + Tile(centre, self, tile_size or max_size, max_size) for centre in initial_location ] self.drifts = drifts def __len__(self): - return len(self.traps) + return len(self.tiles) def __iter__(self): - yield from self.traps + yield from self.tiles @property def shape(self): - """ - Returns no of traps and no of drifts - """ - return len(self.traps), len(self.drifts) + """Return numbers of tiles and drifts.""" + return len(self.tiles), len(self.drifts) - def to_dict(self, tp): + def to_dict(self, tp: int): """ - Export inital locations, tile_size, max_size, and drifts - as a dictionary + Export initial locations, tile_size, max_size, and drifts + as a dictionary. Parameters ---------- @@ -168,47 +157,49 @@ class TrapLocations: return res def at_time(self, tp: int) -> np.ndarray: - # Returns ( ntraps, 2 ) ndarray with the trap centres as individual rows - return np.array([trap.at_time(tp) for trap in self.traps]) + """Return an array of tile centres (x- and y-coords).""" + return np.array([tile.at_time(tp) for tile in self.tiles]) @classmethod def from_tiler_init( cls, initial_location, tile_size: int = None, max_size: int = 1200 ): - """ - Instantiate class from an instance of the Tiler class - """ + """Instantiate from a Tiler.""" return cls(initial_location, tile_size, max_size, drifts=[]) @classmethod def read_hdf5(cls, file): - """ - Instantiate class from a hdf5 file - """ + """Instantiate from a h5 file.""" with h5py.File(file, "r") as hfile: - trap_info = hfile["trap_info"] - initial_locations = trap_info["trap_locations"][()] - drifts = trap_info["drifts"][()].tolist() - max_size = trap_info.attrs["max_size"] - tile_size = trap_info.attrs["tile_size"] - trap_locs = cls(initial_locations, tile_size, max_size=max_size) - trap_locs.drifts = drifts - return trap_locs + tile_info = hfile["trap_info"] + initial_locations = tile_info["trap_locations"][()] + drifts = tile_info["drifts"][()].tolist() + max_size = tile_info.attrs["max_size"] + tile_size = tile_info.attrs["tile_size"] + tile_loc_cls = cls(initial_locations, tile_size, max_size=max_size) + tile_loc_cls.drifts = drifts + return tile_loc_cls class TilerParameters(ParametersABC): - _defaults = {"tile_size": 117, "ref_channel": "Brightfield", "ref_z": 0} + """Set default parameters for Tiler.""" + + _defaults = { + "tile_size": 117, + "ref_channel": "Brightfield", + "ref_z": 0, + } class Tiler(StepABC): """ - Remote Timelapse Tiler. + Divide images into smaller tiles for faster processing. - Finds traps and re-registers images if there is any drifting. - Fetches images from a server. + Finds tiles and re-registers images if they drift. + Fetch images from an OMERO server if necessary. - Uses an Image instance, which lazily provides the data on pixels, and, as - an independent argument, metadata. + Uses an Image instance, which lazily provides the data on pixels, + and, as an independent argument, metadata. """ def __init__( @@ -216,17 +207,17 @@ class Tiler(StepABC): image: da.core.Array, metadata: dict, parameters: TilerParameters, - trap_locs=None, + tile_locs=None, ): """ - Initialise Tiler + Initialise. 
Parameters ---------- image: an instance of Image metadata: dictionary - parameters: an instance of TilerPameters - trap_locs: (optional) + parameters: an instance of TilerParameters + tile_locs: (optional) """ super().__init__(parameters) self.image = image @@ -235,8 +226,7 @@ class Tiler(StepABC): "channels", list(range(metadata["size_c"])) ) self.ref_channel = self.get_channel_index(parameters.ref_channel) - - self.trap_locs = trap_locs + self.tile_locs = tile_locs try: self.z_perchannel = { ch: zsect @@ -244,24 +234,24 @@ class Tiler(StepABC): } except Exception as e: self._log(f"No z_perchannel data: {e}") - self.tile_size = self.tile_size or min(self.image.shape[-2:]) @classmethod def dummy(cls, parameters: dict): """ - Instantiate dummy Tiler from dummy image + Instantiate dummy Tiler from dummy image. If image.dimorder exists dimensions are saved in that order. Otherwise default to "tczyx". Parameters ---------- - parameters: dictionary output of an instance of TilerParameters + parameters: dict + An instance of TilerParameters converted to a dict. """ imgdmy_obj = ImageDummy(parameters) dummy_image = imgdmy_obj.get_data_lazy() - # Default to "tczyx" if image.dimorder is None + # default to "tczyx" if image.dimorder is None dummy_omero_metadata = { f"size_{dim}": dim_size for dim, dim_size in zip( @@ -277,7 +267,6 @@ class Tiler(StepABC): "name": "", } ) - return cls( imgdmy_obj.data, dummy_omero_metadata, @@ -287,7 +276,7 @@ class Tiler(StepABC): @classmethod def from_image(cls, image, parameters: TilerParameters): """ - Instantiate Tiler from an Image instance + Instantiate from an Image instance. Parameters ---------- @@ -306,7 +295,7 @@ class Tiler(StepABC): parameters: TilerParameters = None, ): """ - Instantiate Tiler from hdf5 files + Instantiate from h5 files. Parameters ---------- @@ -315,7 +304,7 @@ class Tiler(StepABC): Path to a directory of h5 files parameters: an instance of TileParameters (optional) """ - trap_locs = TrapLocations.read_hdf5(filepath) + tile_locs = TileLocations.read_hdf5(filepath) metadata = BridgeH5(filepath).meta_h5 metadata["channels"] = image.metadata["channels"] if parameters is None: @@ -324,16 +313,17 @@ class Tiler(StepABC): image.data, metadata, parameters, - trap_locs=trap_locs, + tile_locs=tile_locs, ) - if hasattr(trap_locs, "drifts"): - tiler.n_processed = len(trap_locs.drifts) + if hasattr(tile_locs, "drifts"): + tiler.n_processed = len(tile_locs.drifts) return tiler @lru_cache(maxsize=2) - def get_tc(self, t, c): + def get_tc(self, t: int, c: int): """ Load image using dask. 
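Note: `get_tc` (continued below) relies on dask's lazy slicing, so no pixels are read until they are needed. A self-contained sketch of the idea, with an in-memory array standing in for the experiment.

```python
import dask.array as da
import numpy as np

# an in-memory stand-in for the experiment: (time, channel, z, y, x)
stack = da.from_array(
    np.random.rand(10, 2, 5, 128, 128), chunks=(1, 1, 5, 128, 128)
)
tc = stack[3, 1]  # lazy: one time point, one channel, nothing loaded yet
full = tc.compute()  # pixels are only read here
print(full.shape)  # (5, 128, 128)
```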
+ Assumes the image is arranged as no of time points no of channels @@ -348,7 +338,7 @@ class Tiler(StepABC): c: integer An index for a channel - Retruns + Returns ------- full: an array of images """ @@ -358,16 +348,13 @@ class Tiler(StepABC): @property def shape(self): """ - Returns properties of the time-lapse as shown by self.image.shape - + Return properties of the time-lapse as shown by self.image.shape """ return self.image.shape @property def n_processed(self): - """ - Returns the number of images that have been processed - """ + """Return the number of processed images.""" if not hasattr(self, "_n_processed"): self._n_processed = 0 return self._n_processed @@ -377,22 +364,21 @@ class Tiler(StepABC): self._n_processed = value @property - def n_traps(self): - """ - Returns number of traps - """ - return len(self.trap_locs) + def n_tiles(self): + """Return number of tiles.""" + return len(self.tile_locs) - def initialise_traps(self, tile_size: int = None): + def initialise_tiles(self, tile_size: int = None): """ - Find initial trap positions if they have not been initialised. - Removes all those that are too close to the edge so no padding - is necessary. + Find initial positions of tiles. + + Remove tiles that are too close to the edge of the image + so no padding is necessary. Parameters ---------- tile_size: integer - The size of a tile + The size of a tile. """ initial_image = self.image[0, self.ref_channel, self.ref_z] if tile_size: @@ -400,27 +386,27 @@ class Tiler(StepABC): # max_size is the minimal number of x or y pixels max_size = min(self.image.shape[-2:]) # first time point, reference channel, reference z-position - # find the traps - trap_locs = segment_traps(initial_image, tile_size) - # keep only traps that are not near an edge - trap_locs = [ + # find the tiles + tile_locs = segment_traps(initial_image, tile_size) + # keep only tiles that are not near an edge + tile_locs = [ [x, y] - for x, y in trap_locs + for x, y in tile_locs if half_tile < x < max_size - half_tile and half_tile < y < max_size - half_tile ] - # store traps in an instance of TrapLocations - self.trap_locs = TrapLocations.from_tiler_init( - trap_locs, tile_size + # store tiles in an instance of TileLocations + self.tile_locs = TileLocations.from_tiler_init( + tile_locs, tile_size ) else: yx_shape = self.image.shape[-2:] - trap_locs = [[x // 2 for x in yx_shape]] - self.trap_locs = TrapLocations.from_tiler_init( - trap_locs, max_size=min(yx_shape) + tile_locs = [[x // 2 for x in yx_shape]] + self.tile_locs = TileLocations.from_tiler_init( + tile_locs, max_size=min(yx_shape) ) - def find_drift(self, tp): + def find_drift(self, tp: int): """ Find any translational drift between two images at consecutive time points using cross correlation. @@ -428,7 +414,7 @@ class Tiler(StepABC): Arguments --------- tp: integer - Index for a time point + Index for a time point. """ prev_tp = max(0, tp - 1) # cross-correlate @@ -437,14 +423,14 @@ class Tiler(StepABC): self.image[tp, self.ref_channel, self.ref_z], ) # store drift - if 0 < tp < len(self.trap_locs.drifts): - self.trap_locs.drifts[tp] = drift.tolist() + if 0 < tp < len(self.tile_locs.drifts): + self.tile_locs.drifts[tp] = drift.tolist() else: - self.trap_locs.drifts.append(drift.tolist()) + self.tile_locs.drifts.append(drift.tolist()) def get_tp_data(self, tp, c): """ - Returns all traps corrected for drift. + Return all tiles corrected for drift. 
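Note: `find_drift` above registers consecutive reference images by cross-correlation. A toy sketch, assuming scikit-image's phase_cross_correlation as the registration routine; a circular shift stands in for stage drift.

```python
import numpy as np
from skimage.registration import phase_cross_correlation

rng = np.random.default_rng(0)
reference = rng.random((256, 256))
# simulate stage drift with a circular shift of (3, -2)
drifted = np.roll(reference, shift=(3, -2), axis=(0, 1))
drift, error, _ = phase_cross_correlation(reference, drifted)
print(drift)  # [-3.  2.]: shift registering the drifted image to the reference
```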
Parameters ---------- @@ -453,41 +439,42 @@ class Tiler(StepABC): c: integer An index for a channel """ - traps = [] + tiles = [] # get OMERO image full = self.get_tc(tp, c) - for trap in self.trap_locs: - # pad trap if necessary - ndtrap = self.ifoob_pad(full, trap.as_range(tp)) - traps.append(ndtrap) - return np.stack(traps) + for tile in self.tile_locs: + # pad tile if necessary + ndtile = self.ifoob_pad(full, tile.as_range(tp)) + tiles.append(ndtile) + return np.stack(tiles) - def get_trap_data(self, trap_id, tp, c): + def get_tile_data(self, tile_id: int, tp: int, c: int): """ - Returns a particular trap corrected for drift and padding + Return a particular tile corrected for drift and padding. Parameters ---------- - trap_id: integer - Number of trap + tile_id: integer + Number of tile. tp: integer - Index of time points + Index of time points. c: integer - Index of channel + Index of channel. Returns ------- - ndtrap: array + ndtile: array An array of (x, y) arrays, one for each z stack """ full = self.get_tc(tp, c) - trap = self.trap_locs.traps[trap_id] - ndtrap = self.ifoob_pad(full, trap.as_range(tp)) - return ndtrap + tile = self.tile_locs.tiles[tile_id] + ndtile = self.ifoob_pad(full, tile.as_range(tp)) + return ndtile - def _run_tp(self, tp): + def _run_tp(self, tp: int): """ - Find traps if they have not yet been found. + Find tiles if they have not yet been found. + Determine any translational drift of the current image from the previous one. @@ -498,10 +485,10 @@ class Tiler(StepABC): """ # assert tp >= self.n_processed, "Time point already processed" # TODO check contiguity? - if self.n_processed == 0 or not hasattr(self.trap_locs, "drifts"): - self.initialise_traps(self.tile_size) - if hasattr(self.trap_locs, "drifts"): - drift_len = len(self.trap_locs.drifts) + if self.n_processed == 0 or not hasattr(self.tile_locs, "drifts"): + self.initialise_tiles(self.tile_size) + if hasattr(self.tile_locs, "drifts"): + drift_len = len(self.tile_locs.drifts) if self.n_processed != drift_len: warnings.warn("Tiler:n_processed and ndrifts don't match") self.n_processed = drift_len @@ -510,7 +497,7 @@ class Tiler(StepABC): # update n_processed self.n_processed = tp + 1 # return result for writer - return self.trap_locs.to_dict(tp) + return self.tile_locs.to_dict(tp) def run(self, time_dim=None): """ @@ -524,14 +511,13 @@ class Tiler(StepABC): def get_traps_timepoint(self, *args, **kwargs): self._log( - "get_trap_timepoints is deprecated; get_tiles_timepoint instead." + "get_traps_timepoint is deprecated; get_tiles_timepoint instead." 
) - return self.get_tiles_timepoint(*args, **kwargs) # The next set of functions are necessary for the extraction object def get_tiles_timepoint( - self, tp, tile_shape=None, channels=None, z: int = 0 + self, tp: int, tile_shape=None, channels=None, z: int = 0 ) -> np.ndarray: """ Get a multidimensional array with all tiles for a set of channels @@ -553,10 +539,10 @@ class Tiler(StepABC): Returns ------- res: array - Data arranged as (traps, channels, timepoints, X, Y, Z) + Data arranged as (tiles, channels, time points, X, Y, Z) """ - # FIXME add support for subtiling trap - # FIXME can we ignore z(always give) + # FIXME add support for sub-tiling a tile + # FIXME can we ignore z if channels is None: channels = [0] elif isinstance(channels, str): @@ -566,8 +552,8 @@ class Tiler(StepABC): for c in channels: # only return requested z val = self.get_tp_data(tp, c)[:, z] - # starts with the order: traps, z, y, x - # returns the order: trap, C, T, X, Y, Z + # starts with the order: tiles, z, y, x + # returns the order: tiles, C, T, X, Y, Z val = val.swapaxes(1, 3).swapaxes(1, 2) val = np.expand_dims(val, axis=1) res.append(val) @@ -584,16 +570,19 @@ class Tiler(StepABC): @property def ref_channel_index(self): + """Return index of reference channel.""" return self.get_channel_index(self.parameters.ref_channel) def get_channel_index(self, channel: str or int): """ - Find index for channel using regex. Returns the first matched string. + Find index for channel using regex. + + Returns the first matched string. Parameters ---------- channel: string or int - The channel or index to be used + The channel or index to be used. """ if isinstance(channel, str): channel = find_channel_index(self.channels, channel) @@ -606,7 +595,7 @@ class Tiler(StepABC): @staticmethod def ifoob_pad(full, slices): """ - Returns the slices padded if it is out of bounds. + Return the slices padded if out of bounds. Parameters ---------- @@ -614,11 +603,11 @@ class Tiler(StepABC): Slice of OMERO image (zstacks, x, y) - the entire position with zstacks as first axis slices: tuple of two slices - Delineates indiceds for the x- and y- ranges of the tile. + Delineates indices for the x- and y- ranges of the tile. Returns ------- - trap: array + tile: array A tile with all z stacks for the given slices. If some padding is needed, the median of the image is used. If much padding is needed, a tile of NaN is returned. @@ -628,7 +617,7 @@ class Tiler(StepABC): # ignore parts of the tile outside of the image y, x = [slice(max(0, s.start), min(max_size, s.stop)) for s in slices] # get the tile including all z stacks - trap = full[:, y, x] + tile = full[:, y, x] # find extent of padding needed in x and y padding = np.array( [(-min(0, s.start), -min(0, max_size - s.stop)) for s in slices] @@ -638,13 +627,15 @@ class Tiler(StepABC): if (padding > tile_size / 4).any(): # too much of the tile is outside of the image # fill with NaN - trap = np.full((full.shape[0], tile_size, tile_size), np.nan) + tile = np.full((full.shape[0], tile_size, tile_size), np.nan) else: - # pad tile with median value of trap image - trap = np.pad(trap, [[0, 0]] + padding.tolist(), "median") - return trap + # pad tile with median value of the tile + tile = np.pad(tile, [[0, 0]] + padding.tolist(), "median") + return tile +# Alan: do we need these as well as get_channel_index and get_channel_name? 
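Note: a toy walk-through of the median-padding branch of `ifoob_pad` above, with made-up sizes; the slices are chosen to fall slightly out of bounds.

```python
import numpy as np

tile_size = 16
full = np.random.rand(3, 100, 100)  # whole position: (z-stacks, y, x)
slices = (slice(-2, 14), slice(88, 104))  # a tile partly out of bounds
max_size = min(full.shape[-2:])
# clip the slices to the image
y, x = [slice(max(0, s.start), min(max_size, s.stop)) for s in slices]
tile = full[:, y, x]  # (3, 14, 12)
# padding needed in y and x to restore the full tile size
padding = np.array(
    [(-min(0, s.start), -min(0, max_size - s.stop)) for s in slices]
)
if (padding > tile_size / 4).any():
    # too far out of bounds: return a tile of NaNs
    tile = np.full((full.shape[0], tile_size, tile_size), np.nan)
else:
    # mildly out of bounds: pad with the median of the in-bounds pixels
    tile = np.pad(tile, [[0, 0]] + padding.tolist(), "median")
print(tile.shape)  # (3, 16, 16)
```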
+# self._log below is not defined def find_channel_index(image_channels: t.List[str], channel: str): """ Access @@ -659,7 +650,14 @@ def find_channel_index(image_channels: t.List[str], channel: str): def find_channel_name(image_channels: t.List[str], channel: str): """ - Find the name of the channel according to a given channel regex. + Find the name of the channel using regex. + + Parameters + ---------- + image_channels: list of str + Channels. + channel: str + A regular expression. """ index = find_channel_index(image_channels, channel) if index is not None: diff --git a/src/aliby/tile/traps.py b/src/aliby/tile/traps.py index 4eddeb7e45a0f39ea0de28c865b79b685061da5b..65b21b9ce744d59bd87b2232b12b6971a05a8395 100644 --- a/src/aliby/tile/traps.py +++ b/src/aliby/tile/traps.py @@ -1,7 +1,4 @@ -""" -A set of utilities for dealing with ALCATRAS traps -""" - +"""Functions for identifying and dealing with ALCATRAS traps.""" import numpy as np from skimage import feature, transform @@ -31,10 +28,10 @@ def segment_traps( **identify_traps_kwargs, ): """ - Uses an entropy filter and Otsu thresholding to find a trap template, + Use an entropy filter and Otsu thresholding to find a trap template, which is then passed to identify_trap_locations. - To obtain candidate traps it the major axis length of a tile must be smaller than tilesize. + To obtain candidate traps the major axis length of a tile must be smaller than tilesize. The hyperparameters have not been optimised. @@ -60,7 +57,7 @@ def segment_traps( Returns ------- traps: an array of pairs of integers - The coordinates of the centroids of the traps + The coordinates of the centroids of the traps. """ # keep a memory of image in case need to re-run img = image @@ -144,17 +141,18 @@ def identify_trap_locations( image, trap_template, optimize_scale=True, downscale=0.35, trap_size=None ): """ - Identify the traps in a single image based on a trap template, - which requires the trap template to be similar to the image - (same camera, same magification - ideally the same experiment). + Identify the traps in a single image based on a trap template. + + Requires the trap template to be similar to the image + (same camera, same magnification - ideally the same experiment). - Uses normalised correlation in scikit-image's match_template. + Use normalised correlation in scikit-image's to match_template. - The search is speeded up by downscaling both the image and + The search is sped up by down-scaling both the image and the trap template before running the template matching. The trap template is rotated and re-scaled to improve matching. - The parameters of the rotation and rescaling are optimised, although + The parameters of the rotation and re-scaling are optimised, although over restricted ranges. 
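Note: the template-matching-plus-peak-picking scheme described above, sketched on synthetic data; the image, template, and thresholds are invented for illustration and are not the tuned hyperparameters.

```python
import numpy as np
from skimage.feature import match_template, peak_local_max

# a synthetic image holding two copies of a small, non-constant template
template = np.pad(np.ones((6, 6)), 2)  # bright square on a dark border
image = np.zeros((100, 100))
image[20:30, 20:30] = template
image[60:70, 40:50] = template
# normalised cross-correlation; pad_input makes the output image-sized
corr = match_template(image, template, pad_input=True)
# candidate trap centres are local peaks of the correlation
peaks = peak_local_max(corr, min_distance=10, threshold_abs=0.8)
print(peaks)  # near (25, 25) and (65, 45)
```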
Parameters @@ -243,4 +241,4 @@ def stretch_image(image): maxval = np.percentile(image, 98) image = np.clip(image, minval, maxval) image = (image - minval) / (maxval - minval) - return image \ No newline at end of file + return image diff --git a/src/extraction/core/extractor.py b/src/extraction/core/extractor.py index eee3185a613eb19af8faa5b88566d76d1f197d53..c12f581fc5aaf47652d131601dd0cf8ec683467e 100644 --- a/src/extraction/core/extractor.py +++ b/src/extraction/core/extractor.py @@ -18,7 +18,7 @@ from extraction.core.functions.loaders import ( load_redfuns, ) -# Define types +# define types reduction_method = t.Union[t.Callable, str, None] extraction_tree = t.Dict[ str, t.Dict[reduction_method, t.Dict[str, t.Collection]] @@ -27,7 +27,7 @@ extraction_result = t.Dict[ str, t.Dict[reduction_method, t.Dict[str, t.Dict[str, pd.Series]]] ] -# Global parameters used to load functions that either analyse cells or their background. These global parameters both allow the functions to be stored in a dictionary for access only on demand and to be defined simply in extraction/core/functions. +# Global variables used to load functions that either analyse cells or their background. These global variables both allow the functions to be stored in a dictionary for access only on demand and to be defined simply in extraction/core/functions. CELL_FUNS, TRAPFUNS, FUNS = load_funs() CUSTOM_FUNS, CUSTOM_ARGS = load_custom_args() RED_FUNS = load_redfuns() @@ -37,9 +37,7 @@ RED_FUNS = load_redfuns() class ExtractorParameters(ParametersABC): - """ - Base class to define parameters for extraction. - """ + """Base class to define parameters for extraction.""" def __init__( self, @@ -48,12 +46,14 @@ class ExtractorParameters(ParametersABC): multichannel_ops: t.Dict = {}, ): """ + Initialise. + Parameters ---------- tree: dict Nested dictionary indicating channels, reduction functions and metrics to be used. - str channel -> U(function,None) reduction -> str metric + str channel -> U(function, None) reduction -> str metric If not of depth three, tree will be filled with None. sub_bg: set multichannel_ops: dict @@ -65,7 +65,7 @@ class ExtractorParameters(ParametersABC): @staticmethod def guess_from_meta(store_name: str, suffix="fast"): """ - Find the microscope used from the h5 metadata. + Find the microscope name from the h5 metadata. Parameters ---------- @@ -98,17 +98,7 @@ class Extractor(StepABC): Usually the metric is applied to only a tile's masked area, but some metrics depend on the whole tile. - Extraction follows a three-level tree structure. Channels, such as GFP, are the root level; the reduction algorithm, such as maximum projection, is the second level; the specific metric, or operation, to apply to the masks is the third level. - - Parameters - ---------- - parameters: core.extractor Parameters - Parameters that include the channels, and reduction and - extraction functions. - store: str - Path to the h5 file, which must contain the cell masks. - tiler: pipeline-core.core.segmentation tiler - Class that contains or fetches the images used for segmentation. + Extraction follows a three-level tree structure. Channels, such as GFP, are the root level; the reduction algorithm, such as maximum projection, is the second level; the specific metric, or operation, to apply to the masks, such as mean, is the third level. 
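Note: a hand-written, illustrative instance of that three-level tree (channel, then z-reduction, then metrics), using names that appear elsewhere in this diff: `np_max` from the `extract_tp` example, `area` and `mean` from extraction/core/functions/cell.py.

```python
# channel -> z-reduction -> metrics; names drawn from elsewhere in this diff
tree = {
    # "general" metrics need no z-reduction, hence the None key
    "general": {None: ["area", "eccentricity"]},
    # maximum projection over z, then per-cell statistics
    "GFP": {"np_max": ["mean", "median"]},
}
```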
""" # Alan: should this data be stored here or all such data in a separate file @@ -129,10 +119,13 @@ class Extractor(StepABC): Parameters ---------- - parameters: ExtractorParameters object + parameters: core.extractor Parameters + Parameters that include the channels, reduction and + extraction functions. store: str - Name of h5 file - tiler: Tiler object + Path to the h5 file containing the cell masks. + tiler: pipeline-core.core.segmentation tiler + Class that contains or fetches the images used for segmentation. """ self.params = parameters if store: @@ -179,14 +172,18 @@ class Extractor(StepABC): @property def group(self): - # returns path within h5 file + """Return path within the h5 file.""" if not hasattr(self, "_out_path"): self._group = "/extraction/" return self._group def load_custom_funs(self): """ - Define any custom functions to be functions of cell_masks and trap_image only. + Incorporate the extra arguments of custom functions into their definitions. + + Normal functions only have cell_masks and trap_image as their + arguments, and here custom functions are made the same by + setting the values of their extra arguments. Any other parameters are taken from the experiment's metadata and automatically applied. These parameters therefore must be loaded within an Extractor instance. """ @@ -206,12 +203,13 @@ class Extractor(StepABC): k: {k2: self.get_meta(k2) for k2 in v} for k, v in CUSTOM_ARGS.items() } - # define custom functions - those with extra arguments other than cell_masks and trap_image - as functions of two variables + # define custom functions self._custom_funs = {} for k, f in CUSTOM_FUNS.items(): def tmp(f): # pass extra arguments to custom function + # return a function of cell_masks and trap_image return lambda cell_masks, trap_image: trap_apply( f, cell_masks, @@ -222,6 +220,7 @@ class Extractor(StepABC): self._custom_funs[k] = tmp(f) def load_funs(self): + """Define all functions, including custum ones.""" self.load_custom_funs() self._all_cell_funs = set(self._custom_funs.keys()).union(CELL_FUNS) # merge the two dicts @@ -239,20 +238,18 @@ class Extractor(StepABC): **kwargs, ) -> t.Optional[np.ndarray]: """ - Find tiles for a given time point and given channels and z-stacks. - - Returns None if no tiles are found. + Find tiles for a given time point, channels, and z-stacks. Any additional keyword arguments are passed to tiler.get_tiles_timepoint Parameters ---------- tp: int - Time point of interest + Time point of interest. channels: list of strings (optional) - Channels of interest + Channels of interest. z: list of integers (optional) - Indices for the z-stacks of interest + Indices for the z-stacks of interest. 
""" if channels is None: # find channels from tiler @@ -265,16 +262,16 @@ class Extractor(StepABC): channel_ids = None if z is None: # gets the tiles data via tiler - z: t.List[int] = list(range(self.tiler.shape[-3])) - tiles = ( + z = list(range(self.tiler.shape[-3])) + res = ( self.tiler.get_tiles_timepoint( tp, channels=channel_ids, z=z, **kwargs ) if channel_ids else None ) - # data arranged as (traps, channels, timepoints, X, Y, Z) - return tiles + # data arranged as (tiles, channels, time points, X, Y, Z) + return res def extract_traps( self, @@ -302,11 +299,10 @@ class Extractor(StepABC): Returns ------- res_idx: a tuple of tuples - A two-tuple of a tuple of results and a tuple with the corresponding trap_id and cell labels + A two-tuple comprising a tuple of results and a tuple of the tile_id and cell labels """ if labels is None: self._log("No labels given. Sorting cells using index.") - cell_fun = True if metric in self._all_cell_funs else False idx = [] results = [] @@ -337,7 +333,9 @@ class Extractor(StepABC): **kwargs, ) -> t.Dict[str, pd.Series]: """ - Returns dict with metrics as key and metrics applied to data as values for data from one timepoint. + Return dict with metrics as key and metrics applied to data as values. + + Data from one time point is used. """ d = { metric: self.extract_traps( @@ -359,8 +357,8 @@ class Extractor(StepABC): Parameters ---------- - traps: array - An array of image data arranged as (traps, X, Y, Z) + tiles_data: array + An array of image data arranged as (tiles, X, Y, Z) masks: list of arrays An array of masks for each trap: one per cell at the trap red_metrics: dict @@ -371,20 +369,20 @@ class Extractor(StepABC): Returns ------ - Dictionary of dataframes with the corresponding reductions and metrics nested. + Dict of dataframes with the corresponding reductions and metrics nested. """ # create dict with keys naming the reduction in the z-direction and the reduced data as values - reduced_traps = {} + reduced_tiles_data = {} if traps is not None: for red_fun in red_metrics.keys(): - reduced_traps[red_fun] = [ - self.reduce_dims(trap, method=RED_FUNS[red_fun]) - for trap in traps + reduced_tiles_data[red_fun] = [ + self.reduce_dims(tile_data, method=RED_FUNS[red_fun]) + for tile_data in traps ] d = { red_fun: self.extract_funs( metrics=metrics, - traps=reduced_traps.get(red_fun, [None for _ in masks]), + traps=reduced_tiles_data.get(red_fun, [None for _ in masks]), masks=masks, **kwargs, ) @@ -403,9 +401,9 @@ class Extractor(StepABC): Parameters ---------- img: array - An array of the image data arranged as (X, Y, Z) + An array of the image data arranged as (X, Y, Z). method: function - The reduction function + The reduction function. """ reduced = img if method is not None: @@ -422,7 +420,7 @@ class Extractor(StepABC): **kwargs, ) -> t.Dict[str, t.Dict[str, t.Dict[str, tuple]]]: """ - Extract for an individual time-point. + Extract for an individual time point. Parameters ---------- @@ -452,7 +450,6 @@ class Extractor(StepABC): The first tuple is the result of applying the metrics to a particular cell or trap; the second tuple is either (trap_id, cell_label) for a metric applied to a cell or a trap_id for a metric applied to a trap. An example is d["GFP"]["np_max"]["mean"][0], which gives a tuple of the calculated mean GFP fluorescence for all cells. - """ # TODO Can we split the different extraction types into sub-methods to make this easier to read? 
@@ -452,7 +450,6 @@ class Extractor(StepABC):
         The first tuple is the result of applying the metrics to a particular cell or trap;
         the second tuple is either (trap_id, cell_label) for a metric applied to a cell or a trap_id for a metric applied to a trap.
         An example is d["GFP"]["np_max"]["mean"][0], which gives a tuple of the calculated mean GFP fluorescence for all cells.
-
         """
         # TODO Can we split the different extraction types into sub-methods to make this easier to read?
         if tree is None:
@@ -464,7 +461,6 @@ class Extractor(StepABC):
         tree_chs = (*ch_tree,)
         # create a Cells object to extract information from the h5 file
         cells = Cells(self.local)
-
         # find the cell labels and store as dict with trap_ids as keys
         if labels is None:
             raw_labels = cells.labels_at_time(tp)
@@ -472,7 +468,6 @@ class Extractor(StepABC):
                 trap_id: raw_labels.get(trap_id, [])
                 for trap_id in range(cells.ntraps)
             }
-
         # find the cell masks for a given trap as a dict with trap_ids as keys
         if masks is None:
             raw_masks = cells.at_time(tp, kind="mask")
@@ -482,11 +477,9 @@ class Extractor(StepABC):
                 masks[trap_id] = np.dstack(np.array(cells)).astype(bool)
         # convert to a list of masks
         masks = [np.array(v) for v in masks.values()]
-
         # find image data at the time point
-        # stored as an array arranged as (traps, channels, timepoints, X, Y, Z)
+        # stored as an array arranged as (traps, channels, time points, X, Y, Z)
         tiles = self.get_tiles(tp, tile_shape=tile_size, channels=tree_chs)
-
         # generate boolean masks for background as a list with one mask per trap
         bgs = []
         if self.params.sub_bg:
@@ -496,7 +489,6 @@ class Extractor(StepABC):
                 else np.zeros((tile_size, tile_size))
                 for m in masks
             ]
-
         # perform extraction by applying metrics
         d = {}
         self.img_bgsub = {}
@@ -510,9 +502,9 @@ class Extractor(StepABC):
                 img = None
             # apply metrics to image data
             d[ch] = self.reduce_extract(
-                red_metrics=red_metrics,
                 traps=img,
                 masks=masks,
+                red_metrics=red_metrics,
                 labels=labels,
                 **kwargs,
             )
@@ -537,8 +529,7 @@ class Extractor(StepABC):
                 labels=labels,
                 **kwargs,
             )
-
-        # apply any metrics that use multiple channels (eg pH calculations)
+        # apply any metrics using multiple channels, such as pH calculations
         for name, (
             chs,
             merge_fun,
@@ -560,10 +551,9 @@ class Extractor(StepABC):
                 labels=labels,
                 **kwargs,
             )
-
         return d

-    def get_imgs(self, channel: t.Optional[str], traps, channels=None):
+    def get_imgs(self, channel: t.Optional[str], tiles, channels=None):
         """
         Return image from a correct source, either raw or bgsub.

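For orientation, the nested structure that extract_tp returns can be mimicked with plain dicts; the values below are made up purely to show the shape described in the docstring:

```python
# d[channel][reduction][metric] = (results, (trap_id, cell_label) pairs)
d = {
    "GFP": {
        "np_max": {
            "mean": (
                (101.3, 98.7, 110.2),      # one value per cell
                ((0, 1), (0, 2), (3, 1)),  # the matching (trap_id, cell_label)
            )
        }
    }
}
mean_gfp = d["GFP"]["np_max"]["mean"][0]   # tuple of mean GFP for all cells
```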
@@ -571,20 +561,20 @@ class Extractor(StepABC):
         ----------
         channel: str
             Name of channel to get.
-        traps: ndarray
-            An array of the image data having dimensions of (trap_id, channel, tp, tile_size, tile_size, n_zstacks).
+        tiles: ndarray
+            An array of the image data having dimensions of (tile_id, channel, tp, tile_size, tile_size, n_zstacks).
         channels: list of str (optional)
             List of available channels.

         Returns
         -------
         img: ndarray
-            An array of image data with dimensions (no traps, X, Y, no Z channels)
+            An array of image data with dimensions (no tiles, X, Y, no Z channels)
         """
         if channels is None:
             channels = (*self.params.tree,)
         if channel in channels:
             # TODO start here to fetch channel using regex
-            return traps[:, channels.index(channel), 0]
+            return tiles[:, channels.index(channel), 0]
         elif channel in self.img_bgsub:
             return self.img_bgsub[channel]
@@ -622,7 +612,6 @@ class Extractor(StepABC):
             tps = list(range(self.meta["time_settings/ntimepoints"][0]))
         elif isinstance(tps, int):
             tps = [tps]
-
         # store results in dict
         d = {}
         for tp in tps:
@@ -669,7 +658,7 @@ class Extractor(StepABC):
         self.writer.id_cache.clear()

     def get_meta(self, flds: t.Union[str, t.Collection]):
-        # Obtain metadata for one or multiple fields
+        """Obtain metadata for one or multiple fields."""
         if isinstance(flds, str):
             flds = [flds]
         meta_short = {k.split("/")[-1]: v for k, v in self.meta.items()}
@@ -692,7 +681,7 @@
     to: str (optional)
         Specifies the format of the output, either pd.Series (default) or a list
     tp: int
-        Timepoint used to name the pd.Series
+        Time point used to name the pd.Series

     Returns
     -------
diff --git a/src/extraction/core/functions/cell.py b/src/extraction/core/functions/cell.py
index c3d99d27ade073aef681a80ee77a50db090297c6..0e7b9fe854d5c18efdf9fd09b69a2a93dbf05032 100644
--- a/src/extraction/core/functions/cell.py
+++ b/src/extraction/core/functions/cell.py
@@ -1,13 +1,15 @@
 """
-Base functions to extract information from a single cell
+Base functions to extract information from a single cell.

-These functions are automatically read by extractor.py, and so can only have the cell_mask and trap_image as inputs and must return only one value.
+These functions are automatically read by extractor.py, and
+so can only have the cell_mask and trap_image as inputs. They
+must return only one value.

 They assume that there are no NaNs in the image.
-
-We use bottleneck when it performs faster than numpy:
-- Median
-- values containing NaNs (We make sure this does not happen)
+We use the module bottleneck when it performs faster than numpy:
+- Median
+- values containing NaNs (but we make sure this does not happen)
 """
 import math
 import typing as t
@@ -19,24 +21,24 @@ from scipy import ndimage

 def area(cell_mask) -> int:
     """
-    Find the area of a cell mask
+    Find the area of a cell mask.

     Parameters
     ----------
     cell_mask: 2d array
-        Segmentation mask for the cell
+        Segmentation mask for the cell.
     """
     return np.sum(cell_mask)


 def eccentricity(cell_mask) -> float:
     """
-    Find the eccentricity using the approximate major and minor axes
+    Find the eccentricity using the approximate major and minor axes.

     Parameters
     ----------
     cell_mask: 2d array
-        Segmentation mask for the cell
+        Segmentation mask for the cell.
     """
     min_ax, maj_ax = min_maj_approximation(cell_mask)
     return np.sqrt(maj_ax**2 - min_ax**2) / maj_ax
@@ -44,12 +46,12 @@ def eccentricity(cell_mask) -> float:

 def mean(cell_mask, trap_image) -> float:
     """
-    Finds the mean of the pixels in the cell.
+    Find the mean of the pixels in the cell.

     Parameters
     ----------
     cell_mask: 2d array
-        Segmentation mask for the cell
+        Segmentation mask for the cell.
     trap_image: 2d array
     """
     return np.mean(trap_image[cell_mask])
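These single-cell functions take only a boolean mask and a tile image, so they are easy to check in isolation; a quick toy run (not part of the library):

```python
import numpy as np

cell_mask = np.zeros((8, 8), dtype=bool)
cell_mask[2:5, 2:5] = True                  # a 3x3 "cell"
trap_image = np.arange(64, dtype=float).reshape(8, 8)

print(np.sum(cell_mask))                    # area: 9 pixels
print(np.mean(trap_image[cell_mask]))       # mean pixel value inside the cell
```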
@@ -57,12 +59,12 @@ def mean(cell_mask, trap_image) -> float:

 def median(cell_mask, trap_image) -> int:
     """
-    Finds the median of the pixels in the cell.
+    Find the median of the pixels in the cell.

     Parameters
     ----------
     cell_mask: 2d array
-        Segmentation mask for the cell
+        Segmentation mask for the cell.
     trap_image: 2d array
     """
     return bn.median(trap_image[cell_mask])
@@ -70,12 +72,12 @@ def median(cell_mask, trap_image) -> int:

 def max2p5pc(cell_mask, trap_image) -> float:
     """
-    Finds the mean of the brightest 2.5% of pixels in the cell.
+    Find the mean of the brightest 2.5% of pixels in the cell.

     Parameters
     ----------
     cell_mask: 2d array
-        Segmentation mask for the cell
+        Segmentation mask for the cell.
     trap_image: 2d array
     """
     # number of pixels in mask
@@ -84,19 +86,18 @@ def max2p5pc(cell_mask, trap_image) -> float:
     # sort pixels in cell and find highest 2.5%
     pixels = trap_image[cell_mask]
     top_values = bn.partition(pixels, len(pixels) - n_top)[-n_top:]
-
     # find mean of these highest pixels
     return np.mean(top_values)


 def max5px(cell_mask, trap_image) -> float:
     """
-    Finds the mean of the five brightest pixels in the cell.
+    Find the mean of the five brightest pixels in the cell.

     Parameters
     ----------
     cell_mask: 2d array
-        Segmentation mask for the cell
+        Segmentation mask for the cell.
     trap_image: 2d array
     """
     # sort pixels in cell
@@ -109,12 +110,12 @@ def max5px(cell_mask, trap_image) -> float:

 def std(cell_mask, trap_image):
     """
-    Finds the standard deviation of the values of the pixels in the cell.
+    Find the standard deviation of the values of the pixels in the cell.

     Parameters
     ----------
     cell_mask: 2d array
-        Segmentation mask for the cell
+        Segmentation mask for the cell.
     trap_image: 2d array
     """
     return np.std(trap_image[cell_mask])
@@ -122,12 +123,15 @@ def std(cell_mask, trap_image):

 def volume(cell_mask) -> float:
     """
-    Estimates the volume of the cell assuming it is an ellipsoid with the mask providing a cross-section through the median plane of the ellipsoid.
+    Estimate the volume of the cell.
+
+    Assumes the cell is an ellipsoid with the mask providing
+    a cross-section through its median plane.

     Parameters
     ----------
     cell_mask: 2d array
-        Segmentation mask for the cell
+        Segmentation mask for the cell.
     """
     min_ax, maj_ax = min_maj_approximation(cell_mask)
     return (4 * np.pi * min_ax**2 * maj_ax) / 3
@@ -135,7 +139,7 @@ def volume(cell_mask) -> float:

 def conical_volume(cell_mask):
     """
-    Estimates the volume of the cell
+    Estimate the volume of the cell.

     Parameters
     ----------
@@ -151,7 +155,10 @@ def conical_volume(cell_mask):

 def spherical_volume(cell_mask):
     """
-    Estimates the volume of the cell assuming it is a sphere with the mask providing a cross-section through the median plane of the sphere.
+    Estimate the volume of the cell.
+
+    Assumes the cell is a sphere with the mask providing
+    a cross-section through its median plane.

     Parameters
     ----------
@@ -165,7 +172,7 @@ def spherical_volume(cell_mask):

 def min_maj_approximation(cell_mask) -> t.Tuple[int]:
     """
-    Finds the lengths of the minor and major axes of an ellipse from a cell mask.
+    Find the lengths of the minor and major axes of an ellipse from a cell mask.

     Parameters
     ----------
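The ellipsoid estimate in `volume` is V = (4/3)·π·r_min²·r_maj, with the two radii taken from `min_maj_approximation`. A worked check with made-up radii:

```python
import numpy as np

min_ax, maj_ax = 3.0, 5.0  # hypothetical minor and major radii in pixels
volume = (4 * np.pi * min_ax**2 * maj_ax) / 3
print(volume)              # 60*pi, roughly 188.5 cubic pixels
```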
diff --git a/src/extraction/core/functions/defaults.py b/src/extraction/core/functions/defaults.py
index 847986807b62bf52f308110497e62c169801d6b4..4fcb4094039f60d852f015a191e1af0aee759e5b 100644
--- a/src/extraction/core/functions/defaults.py
+++ b/src/extraction/core/functions/defaults.py
@@ -2,23 +2,26 @@ import re
 import typing as t
 from pathlib import PosixPath
-
 import h5py
+# should we move these functions here?
+from aliby.tile.tiler import find_channel_name
+


 def exparams_from_meta(
     meta: t.Union[dict, PosixPath, str], extras: t.Collection[str] = ["ph"]
 ):
     """
-    Obtain parameters from metadata of hdf5 file.
-    It compares a list of candidate channels using case-inspecific REGEX to identify valid channels.
+    Obtain parameters from metadata of the h5 file.
+
+    Compares a list of candidate channels using case-insensitive
+    REGEX to identify valid channels.
     """
-    meta = meta if isinstance(meta, dict) else load_attributes(meta)
+    meta = meta if isinstance(meta, dict) else load_metadata(meta)
     base = {
         "tree": {"general": {"None": ["area", "volume", "eccentricity"]}},
         "multichannel_ops": {},
     }
-
     candidate_channels = {
         "Citrine",
         "GFP",
@@ -30,7 +33,6 @@ def exparams_from_meta(
         "Cy5",
         "mKO2",
     }
-
     default_reductions = {"max"}
     default_metrics = {
         "mean",
@@ -40,33 +42,26 @@ def exparams_from_meta(
         "max5px",
         # "nuc_est_conv",
     }
-
-    # Defined ratiometric combinations that can be used as ratio
-    # key is numerator and value is denominator; add more to support additional channel names
+    # define ratiometric combinations
+    # key is numerator and value is denominator
+    # add more to support additional channel names
     ratiometric_combinations = {"phluorin405": ("phluorin488", "gfpfast")}
-
     default_reduction_metrics = {
         r: default_metrics for r in default_reductions
     }
     # default_rm["None"] = ["nuc_conv_3d"] # Uncomment this to add nuc_conv_3d (slow)
-
-    from aliby.tile.tiler import find_channel_name
-
     extant_fluorescence_ch = []
     for av_channel in candidate_channels:
-        # Find channels in metadata whose names match
+        # find matching channels in metadata
         found_channel = find_channel_name(meta.get("channels", []), av_channel)
         if found_channel is not None:
             extant_fluorescence_ch.append(found_channel)
-
     for ch in extant_fluorescence_ch:
         base["tree"][ch] = default_reduction_metrics
-
     base["sub_bg"] = extant_fluorescence_ch
-
-    # Additional extraction defaults when channels available
+    # additional extraction defaults if the channels are available
     if "ph" in extras:
-        # SWAINLAB-specific names
+        # SWAINLAB specific names
         # find first valid combination of ratiometric fluorescence channels
         numerator_channel, denominator_channel = (None, None)
         for ch1, chs2 in ratiometric_combinations.items():
@@ -80,8 +75,7 @@ def exparams_from_meta(
             if found_channel2:
                 denominator_channel = found_channel2
                 break
-
-        # If two compatible ratiometric channels are available
+        # if two compatible ratiometric channels are available
         if numerator_channel is not None and denominator_channel is not None:
             sets = {
                 b + a: (x, y)
@@ -102,11 +96,11 @@ def exparams_from_meta(
                     *v,
                     default_reduction_metrics,
                 ]
-
     return base


-def load_attributes(file: t.Union[str, PosixPath], group="/"):
+def load_metadata(file: t.Union[str, PosixPath], group="/"):
+    """Get metadata from an h5 file."""
     with h5py.File(file, "r") as f:
         meta = dict(f[group].attrs.items())
         return meta
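`find_channel_name` itself is imported from aliby.tile.tiler and is not shown in this diff; the case-insensitive matching the docstring describes might look roughly like this sketch (a hypothetical helper, not the real implementation):

```python
import re
import typing as t

def find_channel_name_sketch(
    channels: t.Sequence[str], candidate: str
) -> t.Optional[str]:
    # return the first channel whose name contains candidate, ignoring case
    pattern = re.compile(re.escape(candidate), re.IGNORECASE)
    for ch in channels:
        if pattern.search(ch):
            return ch
    return None

print(find_channel_name_sketch(["Brightfield", "GFPFast"], "gfp"))  # GFPFast
```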
diff --git a/src/extraction/core/functions/loaders.py b/src/extraction/core/functions/loaders.py
index 9d2e9c479428db9db1189726a7ba8478b9923646..ff83b20c39fdb6db0f78b1d8e0474db52fae67c7 100644
--- a/src/extraction/core/functions/loaders.py
+++ b/src/extraction/core/functions/loaders.py
@@ -11,14 +11,13 @@ from extraction.core.functions.math_utils import div0
 """
 Load functions for analysing cells and their background.

-Note that inspect.getmembers returns a list of function names and functions, and inspect.getfullargspec returns a function's arguments.
+Note that inspect.getmembers returns a list of function names and functions,
+and inspect.getfullargspec returns a function's arguments.
 """


 def load_cellfuns_core():
-    """
-    Load functions from the cell module and return as a dict.
-    """
+    """Load functions from the cell module and return as a dict."""
     return {
         f[0]: f[1]
         for f in getmembers(cell)
@@ -31,7 +30,10 @@ def load_custom_args() -> t.Tuple[
     (t.Dict[str, t.Callable], t.Dict[str, t.List[str]])
 ]:
     """
-    Load custom functions from the localisation module and return the functions and any additional arguments, other than cell_mask and trap_image, as dictionaries.
+    Load custom functions from the localisation module.
+
+    Return the functions and any additional arguments other
+    than cell_mask and trap_image as dictionaries.
     """
     # load functions from module
     funs = {
@@ -57,7 +59,8 @@ def load_custom_args() -> t.Tuple[


 def load_cellfuns():
     """
-    Creates a dict of core functions that can be used on an array of cell_masks.
+    Create a dict of core functions for use on cell_masks.
+    The core functions only work on a single mask.
     """
     # create dict of the core functions from cell.py - these functions apply to a single mask
@@ -81,9 +84,7 @@ def load_cellfuns():


 def load_trapfuns():
-    """
-    Load functions that are applied to an entire trap or tile or subsection of an image rather than to single cells.
-    """
+    """Load functions that are applied to an entire tile."""
     TRAPFUNS = {
         f[0]: f[1]
         for f in getmembers(trap)
@@ -94,9 +95,7 @@ def load_trapfuns():


 def load_funs():
-    """
-    Combine all automatically loaded functions
-    """
+    """Combine all automatically loaded functions."""
     CELLFUNS = load_cellfuns()
     TRAPFUNS = load_trapfuns()
     # return dict of cell funs, dict of trap funs, and dict of both
@@ -111,7 +110,10 @@ def load_redfuns(
     """
     Load functions to reduce a multidimensional image by one dimension.

-    It can take custom functions as arguments.
+    Parameters
+    ----------
+    additional_reducers: function or a dict of functions (optional)
+        Functions to perform the reduction.
     """
     RED_FUNS = {
         "max": bn.nanmax,
@@ -121,12 +123,10 @@ def load_redfuns(
         "add": bn.nansum,
         "None": None,
     }
-
     if additional_reducers is not None:
         if isinstance(additional_reducers, FunctionType):
             additional_reducers = [
                 (additional_reducers.__name__, additional_reducers)
             ]
-        RED_FUNS.update(name, fun)
-
+        RED_FUNS.update(additional_reducers)
     return RED_FUNS
diff --git a/src/extraction/core/functions/math_utils.py b/src/extraction/core/functions/math_utils.py
index eeae8e0c432e698f27936cec738b17d906b2f59b..a6216ea9d26194009c1611ef3e788f530faa970f 100644
--- a/src/extraction/core/functions/math_utils.py
+++ b/src/extraction/core/functions/math_utils.py
@@ -20,7 +20,6 @@ def div0(array, fill=0, axis=-1):
     slices_0, slices_1 = [[slice(None)] * len(array.shape)] * 2
     slices_0[axis] = 0
     slices_1[axis] = 1
-
     with np.errstate(divide="ignore", invalid="ignore"):
         c = np.true_divide(
             array[tuple(slices_0)],
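The corrected update call in load_redfuns works because a lone function is first wrapped as a (name, function) pair, and dict.update accepts an iterable of such pairs. A self-contained sketch of the same logic, with numpy reducers standing in for bottleneck's:

```python
from types import FunctionType
import numpy as np

def midrange(arr, axis=0):
    # a custom reduction: halfway between min and max along an axis
    return (np.min(arr, axis=axis) + np.max(arr, axis=axis)) / 2

red_funs = {"max": np.nanmax, "mean": np.nanmean}
additional_reducers = midrange
if isinstance(additional_reducers, FunctionType):
    additional_reducers = [(additional_reducers.__name__, additional_reducers)]
red_funs.update(additional_reducers)
print(sorted(red_funs))  # ['max', 'mean', 'midrange']
```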
diff --git a/src/extraction/core/functions/trap.py b/src/extraction/core/functions/trap.py
index b3cd7d13c8be2942d92caa238216e628085937c6..f1f491e0940d7e994f17f52b8c0f7cfef098070c 100644
--- a/src/extraction/core/functions/trap.py
+++ b/src/extraction/core/functions/trap.py
@@ -5,14 +5,14 @@ import numpy as np

 def imBackground(cell_masks, trap_image):
     """
-    Finds the median background (pixels not comprising cells) from trap_image
+    Find the median background (pixels not comprising cells) from trap_image.

     Parameters
     ----------
     cell_masks: 3d array
         Segmentation masks for cells
     trap_image:
-        The image (all channels) for the tile containing the cell
+        The image (all channels) for the tile containing the cell.
     """
     if not len(cell_masks):
         # create cell_masks if none are given
@@ -25,14 +25,14 @@ def imBackground(cell_masks, trap_image):

 def background_max5(cell_masks, trap_image):
     """
-    Finds the mean of the maximum five pixels of the background (pixels not comprising cells) from trap_image
+    Find the mean of the five brightest pixels of the background.

     Parameters
     ----------
     cell_masks: 3d array
-        Segmentation masks for cells
+        Segmentation masks for cells.
     trap_image:
-        The image (all channels) for the tile containing the cell
+        The image (all channels) for the tile containing the cell.
     """
     if not len(cell_masks):
         # create cell_masks if none are given
diff --git a/src/postprocessor/grouper.py b/src/postprocessor/grouper.py
index 6b9fa9ee56a1d77fa24ff3b29299ca352db34466..92cd969acbeafc1519be92043b3060eeeac6debe 100644
--- a/src/postprocessor/grouper.py
+++ b/src/postprocessor/grouper.py
@@ -24,7 +24,7 @@ class Grouper(ABC):
     def __init__(self, dir: Union[str, PosixPath]):
         """Find h5 files and load a chain for each one."""
         path = Path(dir)
-        assert path.exists(), "Dir does not exist"
+        assert path.exists(), f"{str(dir)} does not exist"
         self.name = path.name
         self.files = list(path.glob("*.h5"))
         assert len(self.files), "No valid h5 files in dir"
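The background functions can be exercised the same way as the cell functions; a toy version of the median-background idea, mirroring imBackground with plain numpy rather than importing it:

```python
import numpy as np

cell_masks = np.zeros((8, 8, 2), dtype=bool)  # two cells in one tile
cell_masks[1:4, 1:4, 0] = True
cell_masks[5:7, 5:7, 1] = True
trap_image = np.full((8, 8), 100.0)
trap_image[cell_masks.any(axis=2)] = 500.0    # cells brighter than background

background = ~cell_masks.any(axis=2)          # pixels not in any cell
print(np.median(trap_image[background]))      # 100.0
```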