"""
Neural network initialisation.
"""
from pathlib import Path
from time import perf_counter
import numpy as np
import tensorflow as tf
from agora.io.writer import DynamicWriter
def initialise_tf(version):
# Initialise tensorflow
if version == 1:
core_config = tf.ConfigProto()
core_config.gpu_options.allow_growth = True
session = tf.Session(config=core_config)
return session
# TODO this only works for TF2
if version == 2:
gpus = tf.config.experimental.list_physical_devices("GPU")
if gpus:
for gpu in gpus:
tf.config.experimental.set_memory_growth(gpu, True)
logical_gpus = tf.config.experimental.list_logical_devices("GPU")
print(
len(gpus), "Physical GPUs,", len(logical_gpus), "Logical GPUs"
)
return None
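# A usage sketch (not part of the module): under TF1 the returned session
# should be kept and closed when done; under TF2 the call only enables GPU
# memory growth and returns None.
# session = initialise_tf(version=2)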
def timer(func, *args, **kwargs):
start = perf_counter()
result = func(*args, **kwargs)
print(f"Function {func.__name__}: {perf_counter() - start}s")
return result
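# For example, timer(np.sum, np.arange(10)) prints something like
# "Function sum: 0.0001s" and returns 45 (illustrative numbers).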
################## CUSTOM OBJECTS ##################################
class ModelPredictor:
"""Generic object that takes a NN and returns the prediction.
Use for predicting fluorescence/other from bright field.
This does not do instance segmentations of anything.
"""
def __init__(self, tiler, model, name):
self.tiler = tiler
self.model = model
self.name = name
def get_data(self, tp):
# Change axes to X,Y,Z rather than Z,Y,X
return (
self.tiler.get_tp_data(tp, self.bf_channel)
.swapaxes(1, 3)
.swapaxes(1, 2)
)
def format_result(self, result, tp):
return {self.name: result, "timepoints": [tp] * len(result)}
def run_tp(self, tp):
"""Simulating processing time with sleep"""
# Access the image
segmentation = self.model.predict(self.get_data(tp))
return self._format_result(segmentation, tp)
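# Hypothetical usage (the model and name are assumptions):
# predictor = ModelPredictor(tiler, keras_model, "fluorescence")
# result = predictor.run_tp(0)
# # -> {"fluorescence": <predictions>, "timepoints": [0, 0, ...]}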
class ModelPredictorWriter(DynamicWriter):
def __init__(self, file, name, shape, dtype):
super().__init__(file)
self.name = name  # used to define the h5 group below
self.datatypes = {
name: (shape, dtype),
"timepoint": ((None,), np.uint16),
}
self.group = f"{self.name}_info"
@@ -54,7 +54,7 @@ class DatasetLocalABC(ABC):
Abstract Base class to find local files, either OME-XML or raw images.
"""
_valid_suffixes = ("tiff", "png", "zarr")
_valid_suffixes = ("tiff", "png", "zarr", "tif")
_valid_meta_suffixes = ("txt", "log")
def __init__(self, dpath: t.Union[str, Path], *args, **kwargs):
@@ -29,15 +29,11 @@ from tifffile import TiffFile
from agora.io.metadata import dir_to_meta, dispatch_metadata_parser
def get_examples_dir():
"""Get examples directory which stores dummy image for tiler"""
return files("aliby").parent.parent / "examples" / "tiler"
def instantiate_image(
source: t.Union[str, int, t.Dict[str, str], Path], **kwargs
):
"""Wrapper to instatiate the appropiate image
"""
Instantiate the image.
Parameters
----------
@@ -46,42 +42,38 @@ def instantiate_image(
Examples
--------
image_path = "path/to/image"]
image_path = "path/to/image"
with instantiate_image(image_path) as img:
print(img.data, img.metadata)
"""
return dispatch_image(source)(source, **kwargs)
def dispatch_image(source: t.Union[str, int, t.Dict[str, str], Path]):
"""
Wrapper to pick the appropiate Image class depending on the source of data.
"""
"""Pick the appropriate Image class depending on the source of data."""
if isinstance(source, (int, np.int64)):
from aliby.io.omero import Image
instantiator = Image
elif isinstance(source, dict) or (
isinstance(source, (str, Path)) and Path(source).is_dir()
):
# zarr files are considered directories
if Path(source).suffix == ".zarr":
instantiator = ImageZarr
else:
instantiator = ImageDir
elif isinstance(source, (str, Path)) and Path(source).is_file():
instantiator = ImageLocalOME
else:
raise Exception(f"Invalid data source at {source}")
return instantiator
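# A minimal usage sketch of dispatch_image (the path is hypothetical):
# Image = dispatch_image("path/to/position.zarr")  # -> ImageZarr
# with Image("path/to/position.zarr") as img:
#     pixels, meta = img.data, img.metadata  # lazy dask array and a dict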
class BaseLocalImage(ABC):
"""
Base Image class to set path and provide context management method.
"""
"""Set path and provide method for context management."""
# default image order
_default_dimorder = "tczyx"
def __init__(self, path: t.Union[str, Path]):
@@ -98,8 +90,7 @@ class BaseLocalImage(ABC):
return False
def rechunk_data(self, img):
"""Format image using x and y size from metadata."""
self._rechunked_img = da.rechunk(
img,
chunks=(
@@ -112,30 +103,35 @@
)
return self._rechunked_img
@property
def data(self):
"""Get data."""
return self.get_data_lazy()
@property
def metadata(self):
"""Get metadata."""
return self._meta
def set_meta(self):
"""Load metadata using parser dispatch."""
self._meta = dispatch_metadata_parser(self.path)
@abstractmethod
def get_data_lazy(self) -> da.Array:
"""Define in child class."""
pass
@abstractproperty
def name(self):
"""Define in child class."""
pass
@abstractproperty
def dimorder(self):
"""Define in child class."""
pass
class ImageLocalOME(BaseLocalImage):
"""
@@ -145,16 +141,18 @@ class ImageLocalOME(BaseLocalImage):
in which a multidimensional tiff image contains the metadata.
"""
def __init__(self, path: str, dimorder=None, **kwargs):
"""Initialise using file name."""
super().__init__(path)
self._id = str(path)
self.set_meta(str(path))
def set_meta(self, path):
"""Get metadata from the associated tiff file."""
meta = dict()
try:
with TiffFile(path) as f:
self._meta = xmltodict.parse(f.ome_metadata)["OME"]
for dim in self.dimorder:
meta["size_" + dim.lower()] = int(
self._meta["Image"]["Pixels"]["@Size" + dim]
@@ -165,21 +163,19 @@ class ImageLocalOME(BaseLocalImage):
]
meta["name"] = self._meta["Image"]["@Name"]
meta["type"] = self._meta["Image"]["Pixels"]["@Type"]
except Exception as e:
# images not in OMEXML
print("Warning: Metadata not found: {}".format(e))
print(
"Warning: No dimensional info provided. "
f"Assuming {self._default_dimorder}"
)
# mark non-existent dimensions for padding
self.base = self._default_dimorder
self._dimorder = self.base
self._meta = meta
# self._meta["name"] = Path(path).name.split(".")[0]
@property
def name(self):
@@ -196,7 +192,7 @@ class ImageLocalOME(BaseLocalImage):
@property
def dimorder(self):
"""Order of dimensions in image"""
"""Return order of dimensions in the image."""
if not hasattr(self, "_dimorder"):
self._dimorder = self._meta["Image"]["Pixels"]["@DimensionOrder"]
return self._dimorder
@@ -207,16 +203,16 @@ class ImageLocalOME(BaseLocalImage):
return self._dimorder
def get_data_lazy(self) -> da.Array:
"""Return 5D dask array. For lazy-loading multidimensional tiff files"""
"""Return 5D dask array via lazy-loading of tiff files."""
if not hasattr(self, "formatted_img"):
if not hasattr(self, "ids"): # Standard dimension order
if not hasattr(self, "ids"):
# standard order of image dimensions
img = (imread(str(self.path))[0],)
else:
# bespoke order, so rearrange axes for compatibility
img = imread(str(self.path))[0]
for i, d in enumerate(self._dimorder):
self._meta["size_" + d.lower()] = img.shape[i]
target_order = (
*self.ids,
*[
@@ -235,42 +231,38 @@ class ImageLocalOME(BaseLocalImage):
img = da.moveaxis(
reshaped, range(len(reshaped.shape)), target_order
)
return self.rechunk_data(img)
class ImageDir(BaseLocalImage):
"""
Standard image class for tiff files.
Image class for the case in which all images are split in one or
multiple folders with time-points and channels as independent files.
It inherits from BaseLocalImage so we only override methods that are critical.
Assumptions:
- One folders per position.
- One folder per position.
- Images are flat.
- Channel, Time, z-stack and the others are determined by filenames.
- Provides Dimorder as it is set in the filenames, or expects order during instatiation
- Provides Dimorder as it is set in the filenames, or expects order
"""
def __init__(self, path: t.Union[str, Path], **kwargs):
"""Initialise using file name."""
super().__init__(path)
self.image_id = str(self.path.stem)
self._meta = dir_to_meta(self.path)
def get_data_lazy(self) -> da.Array:
"""Return 5D dask array. For lazy-loading local multidimensional tiff files"""
"""Return 5D dask array."""
img = imread(str(self.path / "*.tiff"))
# If extra channels, pick the first stack of the last dimensions
while len(img.shape) > 3:
img = img[..., 0]
if self._meta:
self._meta["size_x"], self._meta["size_y"] = img.shape[-2:]
# Reshape using metadata
img = da.reshape(img, self._meta.values())
@@ -291,6 +283,7 @@ class ImageDir(BaseLocalImage):
@property
def name(self):
"""Return name of image directory."""
return self.path.stem
@property
@@ -304,24 +297,27 @@ class ImageDir(BaseLocalImage):
class ImageZarr(BaseLocalImage):
"""
Read zarr compressed files.
These files are generated by the script
skeletons/scripts/howto_omero/convert_clone_zarr_to_tiff.py
"""
def __init__(self, path: t.Union[str, Path], **kwargs):
"""Initialise using file name."""
super().__init__(path)
self.set_meta()
try:
self._img = zarr.open(self.path)
self.add_size_to_meta()
except Exception as e:
print(f"Could not add size info to metadata: {e}")
print(f"ImageZarr: Could not add size info to metadata: {e}.")
def get_data_lazy(self) -> da.Array:
"""Return 5D dask array. For lazy-loading local multidimensional zarr files"""
"""Return 5D dask array for lazy-loading local multidimensional zarr files."""
return self._img
def add_size_to_meta(self):
"""Add shape of image array to metadata."""
self._meta.update(
{
f"size_{dim}": shape
@@ -331,16 +327,13 @@ class ImageZarr(BaseLocalImage):
@property
def name(self):
"""Return name of zarr directory."""
return self.path.stem
@property
def dimorder(self):
"""Impose a hard-coded order of dimensions based on the zarr compression script."""
return "TCZYX"
class ImageDummy(BaseLocalImage):
@@ -131,7 +131,6 @@ class BridgeOmero:
FIXME: Add docs.
"""
bridge = BridgeH5(filepath)
meta = safe_load(bridge.meta_h5["parameters"])["general"]
server_info = {k: meta[k] for k in ("host", "username", "password")}
@@ -208,6 +207,13 @@ class Dataset(BridgeOmero):
im.getName(): im.getId() for im in self.ome_class.listChildren()
}
def get_channels(self):
"""Get channels from OMERO."""
for im in self.ome_class.listChildren():
channels = [ch.getLabel() for ch in im.getChannels()]
break
return channels
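# Channels are read from the first image only, on the assumption that all
# images in the data set share the same channel list.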
@property
def files(self):
if not hasattr(self, "_files"):
@@ -254,7 +260,8 @@ class Dataset(BridgeOmero):
cls,
filepath: t.Union[str, Path],
):
"""Instatiate Dataset from a hdf5 file.
"""
Instantiate data set from a h5 file.
Parameters
----------
@@ -268,7 +275,6 @@ class Dataset(BridgeOmero):
FIXME: Add docs.
"""
bridge = BridgeH5(filepath)
dataset_keys = ("omero_id", "omero_id,", "dataset_id")
for k in dataset_keys:
@@ -301,21 +307,21 @@ class Image(BridgeOmero):
cls,
filepath: t.Union[str, Path],
):
"""Instatiate Image from a hdf5 file.
"""
Instantiate Image from a h5 file.
Parameters
----------
cls : Image
Image class
filepath : t.Union[str, Path]
Location of h5 file.
Examples
--------
FIXME: Add docs.
"""
bridge = BridgeH5(filepath)
image_id = bridge.meta_h5["image_id"]
return cls(image_id, **cls.server_info_from_h5(filepath))
"""Set up and run pipelines: tiling, segmentation, extraction, and then post-processing."""
import logging
import multiprocessing
import os
import re
import traceback
import typing as t
from copy import copy
from importlib.metadata import version
from pathlib import Path
from pprint import pprint
import h5py
import baby
import baby.errors
import numpy as np
import pandas as pd
import tensorflow as tf
from pathos.multiprocessing import Pool
from tqdm import tqdm
try:
if baby.__version__:
from aliby.baby_sitter import BabyParameters, BabyRunner
except AttributeError:
from aliby.baby_client import BabyParameters, BabyRunner
import aliby.global_parameters as global_parameters
from agora.abc import ParametersABC, ProcessABC
from agora.io.metadata import MetaData
from agora.io.signal import Signal
from agora.io.writer import LinearBabyWriter, StateWriter, TilerWriter
from aliby.io.dataset import dispatch_dataset
from aliby.io.image import dispatch_image
from aliby.tile.tiler import Tiler, TilerParameters
from extraction.core.extractor import (
Extractor,
ExtractorParameters,
extraction_params_from_meta,
)
from postprocessor.core.postprocessing import (
PostProcessor,
PostProcessorParameters,
)
# stop warnings from TensorFlow
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "2"
logging.getLogger("tensorflow").setLevel(logging.ERROR)
class PipelineParameters(ParametersABC):
"""Define parameters for the steps of the pipeline."""
def __init__(
self,
general,
tiler,
baby,
extraction,
postprocessing,
):
"""Initialise parameter sets using passed dictionaries."""
self.general = general
self.tiler = tiler
self.baby = baby
self.extraction = extraction
self.postprocessing = postprocessing
@classmethod
def default(
@@ -76,16 +89,15 @@ class PipelineParameters(ParametersABC):
postprocessing: dict (optional)
Parameters for post-processing.
"""
if (
isinstance(general["expt_id"], Path)
and general["expt_id"].exists()
):
expt_id = str(general["expt_id"])
else:
expt_id = general["expt_id"]
directory = Path(general["directory"])
# get metadata from log files either locally or via OMERO
with dispatch_dataset(
expt_id,
**{k: general.get(k) for k in ("host", "username", "password")},
@@ -107,7 +119,6 @@ class PipelineParameters(ParametersABC):
}
# set minimal metadata
meta_d = minimal_default_meta
# define default values for general parameters
tps = meta_d.get("ntps", 2000)
defaults = {
@@ -117,19 +128,12 @@ class PipelineParameters(ParametersABC):
tps=tps,
directory=str(directory.parent),
filter="",
earlystop=global_parameters.earlystop,
logfile_level="INFO",
use_explog=True,
)
}
# update default values for general using inputs
for k, v in general.items():
if k not in defaults["general"]:
defaults["general"][k] = v
@@ -138,11 +142,9 @@ class PipelineParameters(ParametersABC):
defaults["general"][k][k2] = v2
else:
defaults["general"][k] = v
# default Tiler parameters
defaults["tiler"] = TilerParameters.default(**tiler).to_dict()
# generate a backup channel for when logfile meta is available
# but not image metadata.
backup_ref_channel = None
if "channels" in meta_d and isinstance(
@@ -152,69 +154,42 @@ class PipelineParameters(ParametersABC):
defaults["tiler"]["ref_channel"]
)
defaults["tiler"]["backup_ref_channel"] = backup_ref_channel
# default parameters
defaults["baby"] = BabyParameters.default(**baby).to_dict()
defaults["extraction"] = (
exparams_from_meta(meta_d)
or BabyParameters.default(**extraction).to_dict()
)
defaults["extraction"] = extraction_params_from_meta(meta_d)
defaults["postprocessing"] = PostProcessorParameters.default(
**postprocessing
).to_dict()
defaults["reporting"] = {}
return cls(**{k: v for k, v in defaults.items()})
class Pipeline(ProcessABC):
"""
Initialise and run tiling, segmentation, extraction and post-processing.
Each step feeds the next one.
To customise parameters for any step use the PipelineParameters class.stem
"""
pipeline_steps = ["tiler", "baby", "extraction"]
step_sequence = [
"tiler",
"baby",
"extraction",
"postprocessing",
]
# Specify the group in the h5 files written by each step
writer_groups = {
"tiler": ["trap_info"],
"baby": ["cell_info"],
"extraction": ["extraction"],
"postprocessing": ["postprocessing", "modifiers"],
}
writers = { # TODO integrate Extractor and PostProcessing in here
"tiler": [("tiler", TilerWriter)],
"baby": [("baby", LinearBabyWriter), ("state", StateWriter)],
}
"""Initialise and run tiling, segmentation, extraction and post-processing."""
def __init__(self, parameters: PipelineParameters, store=None):
"""Initialise - not usually called directly."""
"""Initialise using Pipeline parameters."""
super().__init__(parameters)
if store is not None:
store = Path(store)
# h5 file
self.store = store
config = self.parameters.to_dict()
self.server_info = {
k: config["general"].get(k)
for k in ("host", "username", "password")
}
self.expt_id = config["general"]["id"]
self.setLogger(config["general"]["directory"])
@staticmethod
def setLogger(
folder, file_level: str = "INFO", stream_level: str = "WARNING"
folder, file_level: str = "INFO", stream_level: str = "INFO"
):
"""Initialise and format logger."""
logger = logging.getLogger("aliby")
logger.setLevel(getattr(logging, file_level))
formatter = logging.Formatter(
"%(asctime)s - %(levelname)s:%(message)s",
"%(asctime)s - %(levelname)s: %(message)s",
datefmt="%Y-%m-%dT%H:%M:%S%z",
)
# for streams - stdout, files, etc.
@@ -228,526 +203,288 @@ class Pipeline(ProcessABC):
fh.setFormatter(formatter)
logger.addHandler(fh)
@classmethod
def from_yaml(cls, fpath):
# This is just a convenience function, think before implementing
# for other processes
return cls(parameters=PipelineParameters.from_yaml(fpath))
@classmethod
def from_folder(cls, dir_path):
"""
Re-process all h5 files in a given folder.
All files must share the same parameters, even if they have different channels.
Parameters
---------
dir_path : str or Path
Folder containing the files.
"""
# find h5 files
dir_path = Path(dir_path)
files = list(dir_path.rglob("*.h5"))
assert len(files), "No valid files found in folder"
fpath = files[0]
# TODO add support for non-standard unique folder names
with h5py.File(fpath, "r") as f:
pipeline_parameters = PipelineParameters.from_yaml(
f.attrs["parameters"]
)
pipeline_parameters.general["directory"] = dir_path.parent
pipeline_parameters.general["filter"] = [fpath.stem for fpath in files]
# fix legacy post-processing parameters
post_process_params = pipeline_parameters.postprocessing.get(
"parameters", None
)
if post_process_params:
pipeline_parameters.postprocessing["param_sets"] = copy(
post_process_params
)
del pipeline_parameters.postprocessing["parameters"]
return cls(pipeline_parameters)
@classmethod
def from_existing_h5(cls, fpath):
"""
Re-process an existing h5 file.
Not suitable for more than one file.
Parameters
---------
fpath: str
Name of file.
"""
with h5py.File(fpath, "r") as f:
pipeline_parameters = PipelineParameters.from_yaml(
f.attrs["parameters"]
)
directory = Path(fpath).parent
pipeline_parameters.general["directory"] = directory
pipeline_parameters.general["filter"] = Path(fpath).stem
post_process_params = pipeline_parameters.postprocessing.get(
"parameters", None
)
if post_process_params:
pipeline_parameters.postprocessing["param_sets"] = copy(
post_process_params
)
del pipeline_parameters.postprocessing["parameters"]
return cls(pipeline_parameters, store=directory)
@property
def _logger(self):
return logging.getLogger("aliby")
def setup(self):
"""Get meta data and identify each position."""
config = self.parameters.to_dict()
expt_id = config["general"]["id"]
distributed = config["general"]["distributed"]
pos_filter = config["general"]["filter"]
# print configuration
self.log("Using alibylite.", "info")
try:
self.log(f"Using Baby {baby.__version__}.", "info")
except AttributeError:
self.log("Using original Baby.", "info")
# extract from configuration
root_dir = Path(config["general"]["directory"])
dispatcher = dispatch_dataset(self.expt_id, **self.server_info)
self.log(
f"Fetching data using {dispatcher.__class__.__name__}.", "info"
)
with dispatcher as conn:
position_ids = conn.get_images()
directory = self.store or root_dir / conn.unique_name
if not directory.exists():
directory.mkdir(parents=True)
# get logs to use for metadata
conn.cache_logs(directory)
print("Positions available:")
for i, pos in enumerate(position_ids.keys()):
print("\t" + f"{i}: " + pos.split(".")[0])
# add directory to configuration
self.parameters.general["directory"] = str(directory)
config["general"]["directory"] = directory
self.setLogger(directory)
return position_ids
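# position_ids maps position names to image identifiers, e.g.
# {"pos001.h5": 10001, "pos002.h5": 10002} (hypothetical values).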
def channels_from_OMERO(self):
"""Get a definitive list of channels from OMERO."""
dispatcher = dispatch_dataset(self.expt_id, **self.server_info)
with dispatcher as conn:
if hasattr(conn, "get_channels"):
channels = conn.get_channels()
else:
channels = None
return channels
def filter_positions(self, position_filter, position_ids):
"""Select particular positions."""
if isinstance(position_filter, list):
selected_ids = {
k: v
for filt in position_filter
for k, v in self.apply_filter(position_ids, filt).items()
}
else:
selected_ids = self.apply_filter(position_ids, position_filter)
return selected_ids
def apply_filter(self, position_ids: dict, position_filter: int or str):
"""
Select positions.
Either pick a particular position or use a regular expression
to parse their file names.
"""
if isinstance(position_filter, str):
# pick positions using a regular expression
position_ids = {
k: v
for k, v in position_ids.items()
if re.search(position_filter, k)
}
elif isinstance(position_filter, int):
# pick a particular position
position_ids = {
k: v
for i, (k, v) in enumerate(position_ids.items())
if i == position_filter
}
return position_ids
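# A sketch with hypothetical names: given {"pos001": 1, "pos002": 2},
# apply_filter(position_ids, "001") keeps entries whose names match the
# regex -> {"pos001": 1}, while apply_filter(position_ids, 1) keeps the
# position at index 1 -> {"pos002": 2}.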
def run(self):
"""Run separate pipelines for all positions in an experiment."""
self.OMERO_channels = self.channels_from_OMERO()
config = self.parameters.to_dict()
position_ids = self.setup()
# pick particular positions if desired
position_filter = config["general"]["filter"]
if position_filter is not None:
position_ids = self.filter_positions(position_filter, position_ids)
if not len(position_ids):
raise Exception("No images to segment.")
else:
print("Positions selected:")
for pos in position_ids:
print("\t" + pos.split(".")[0])
print(f"Number of CPU cores available: {multiprocessing.cpu_count()}")
# create and run pipelines
distributed = config["general"]["distributed"]
if distributed != 0:
# multiple cores
with Pool(distributed) as p:
results = p.map(
self.run_one_position,
[position_id for position_id in position_ids.items()],
)
else:
# single core
results = [
self.run_one_position(position_id)
for position_id in position_ids.items()
]
return results
def generate_h5file(self, image_id):
"""Delete any existing and then create h5file for one position."""
config = self.parameters.to_dict()
out_dir = config["general"]["directory"]
with dispatch_image(image_id)(image_id, **self.server_info) as image:
out_file = Path(f"{out_dir}/{image.name}.h5")
# remove existing h5 file
if out_file.exists():
os.remove(out_file)
meta = MetaData(out_dir, out_file)
# generate h5 file using meta data from logs
if config["general"]["use_explog"]:
meta.run()
return out_file
def run_one_position(
self, name_image_id: t.Tuple[str, str or Path or int]
):
"""Run a pipeline for one position."""
name, image_id = name_image_id
try:
config = self.parameters.to_dict()
config["tiler"]["position_name"] = name.split(".")[0]
earlystop = config["general"].get("earlystop", None)
out_file = self.generate_h5file(image_id)
# instantiate writers
tiler_writer = TilerWriter(out_file)
baby_writer = LinearBabyWriter(out_file)
babystate_writer = StateWriter(out_file)
# start pipeline
initialise_tensorflow()
frac_clogged_traps = 0.0
with dispatch_image(image_id)(image_id, **self.server_info) as image:
# initialise tiler; load local meta data from image
tiler = Tiler.from_image(
image,
TilerParameters.from_dict(config["tiler"]),
OMERO_channels=self.OMERO_channels,
)
# initialise Baby
babyrunner = BabyRunner.from_tiler(
BabyParameters.from_dict(config["baby"]), tiler=tiler
)
# initialise extraction
extraction = Extractor.from_tiler(
ExtractorParameters.from_dict(config["extraction"]),
store=out_file,
tiler=tiler,
)
# initiate progress bar
tps = min(config["general"]["tps"], image.data.shape[0])
progress_bar = tqdm(range(tps), desc=image.name)
# run through time points
for i in progress_bar:
if (
frac_clogged_traps < earlystop["thresh_pos_clogged"]
or i < earlystop["min_tp"]
):
# run tiler
result = tiler.run_tp(i)
tiler_writer.write(
data=result,
overwrite=[],
tp=i,
meta={"last_processed:": i},
)
if i == 0:
self.log(
f"Found {tiler.no_tiles} traps in {image.name}.",
"info",
)
# run Baby
try:
result = babyrunner.run_tp(i)
except baby.errors.Clogging:
self.log(
"WARNING:Clogging threshold exceeded in BABY."
)
baby_writer.write(
data=result,
tp=i,
overwrite=["mother_assign"],
meta={"last_processed": i},
)
steps["extraction"] = Extractor.from_tiler(
exparams, store=filename, tiler=steps["tiler"]
babystate_writer.write(
data=babyrunner.crawler.tracker_states,
overwrite=babystate_writer.datatypes.keys(),
tp=i,
)
# run extraction
result = extraction.run_tp(i, cell_labels=None, masks=None)
# check and report clogging
frac_clogged_traps = check_earlystop(
out_file,
earlystop,
tiler.tile_size,
)
if frac_clogged_traps > 0.3:
self.log(f"{name}:Clogged_traps:{frac_clogged_traps}")
frac = np.round(frac_clogged_traps * 100)
progress_bar.set_postfix_str(f"{frac} Clogged")
else:
# stop if too many clogged traps
self.log(
f"{name}:Stopped early at time {i} with {frac_clogged_traps} clogged traps"
)
break
# run post-processing
PostProcessor(
out_file,
PostProcessorParameters.from_dict(config["postprocessing"]),
).run()
self.log("Analysis finished successfully.", "info")
return 1
except Exception as e:
# catch bugs during setup or run time
logging.exception(
f"{name}: Exception caught.",
exc_info=True,
)
# print the type, value, and stack trace of the exception
traceback.print_exc()
raise e
@property
def display_config(self):
"""Show all parameters for each step of the pipeline."""
config = self.parameters.to_dict()
for step in config:
print("\n---\n" + step + "\n---")
pprint(config[step])
print()
def check_earlystop(filename: str, es_parameters: dict, tile_size: int):
"""
Check recent time points for tiles with too many cells.
Returns the fraction of clogged tiles, where clogged tiles have
too many cells or too much of their area covered by cells.
Parameters
----------
filename: str
Name of h5 file.
es_parameters: dict
Parameters defining when early stopping should happen.
For example:
{'min_tp': 100,
'thresh_pos_clogged': 0.4,
'thresh_trap_ncells': 8,
'thresh_trap_area': 0.9,
'ntps_to_eval': 5}
tile_size: int
Size of tile.
"""
# get the area of the cells organised by trap and cell number
s = Signal(filename)
df = s.get_raw("/extraction/general/None/area")
# check the latest time points only
cells_used = df[
df.columns[-1 - es_parameters["ntps_to_eval"] : -1]
].dropna(how="all")
# find tiles with too many cells
traps_above_nthresh = (
cells_used.groupby("trap").count().apply(np.mean, axis=1)
> es_parameters["thresh_trap_ncells"]
)
# find tiles with cells covering too great a fraction of the tiles' area
traps_above_athresh = (
cells_used.groupby("trap").sum().apply(np.mean, axis=1) / tile_size**2
> es_parameters["thresh_trap_area"]
)
return (traps_above_nthresh & traps_above_athresh).mean()
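# A self-contained sketch of the clogging criterion with made-up numbers
# (not aliby's API): two of five tiles exceed both thresholds, so the
# clogged fraction is 0.4, enough to reach thresh_pos_clogged=0.4 and
# trigger early stopping.
# import pandas as pd
# mean_ncells = pd.Series([2, 10, 9, 1, 12])  # mean cells per tile
# mean_area_frac = pd.Series([0.2, 0.95, 0.85, 0.1, 0.99])  # of tile area
# clogged = (mean_ncells > 8) & (mean_area_frac > 0.9)
# print(clogged.mean())  # 0.4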
def initialise_tensorflow(version=2):
"""Initialise tensorflow."""
if version == 2:
gpus = tf.config.experimental.list_physical_devices("GPU")
if gpus:
for gpu in gpus:
tf.config.experimental.set_memory_growth(gpu, True)
logical_gpus = tf.config.experimental.list_logical_devices("GPU")
print(
len(gpus), "physical GPUs,", len(logical_gpus), "logical GPUs"
)
"""
Tiler: Divides images into smaller tiles.
The tasks of the Tiler are selecting regions of interest, or tiles, of images - with one trap per tile, correcting for the drift of the microscope stage over time, and handling errors and bridging between the image data and Aliby’s image-processing steps.
The tasks of the Tiler are selecting regions of interest, or tiles, of
images - with one trap per tile, correcting for the drift of the microscope
stage over time, and handling errors and bridging between the image data
and Aliby’s image-processing steps.
Tiler subclasses deal with either network connections or local files.
To find tiles, we use a two-step process: we analyse the bright-field image to produce the template of a trap, and we fit this template to the image to find the tiles' centres.
To find tiles, we use a two-step process: we analyse the bright-field image
to produce the template of a trap, and we fit this template to the image to
find the tiles' centres.
We use texture-based segmentation (entropy) to split the image into foreground -- cells and traps -- and background, which we then identify with an Otsu filter. Two methods are used to produce a template trap from these regions: pick the trap with the smallest minor axis length and average over all validated traps.
We use texture-based segmentation (entropy) to split the image into
foreground -- cells and traps -- and background, which we then identify with
an Otsu filter. Two methods are used to produce a template trap from these
regions: pick the trap with the smallest minor axis length and average over
all validated traps.
A peak-identifying algorithm recovers the x and y-axis location of traps in the original image, and we choose the approach to template that identifies the most tiles.
A peak-identifying algorithm recovers the x and y-axis location of traps in
the original image, and we choose the approach to template that identifies
the most tiles.
The experiment is stored as an array with a standard indexing order of (Time, Channels, Z-stack, X, Y).
The experiment is stored as an array with a standard indexing order of
(Time, Channels, Z-stack, X, Y).
"""
import logging
import re
@@ -27,26 +39,22 @@ from skimage.registration import phase_cross_correlation
from agora.abc import ParametersABC, StepABC
from agora.io.writer import BridgeH5
from agora.io.metadata import find_channels_by_position
from aliby.tile.traps import segment_traps
class Tile:
"""
Store a tile's location and size.
Checks to see if the tile should be padded.
Can export the tile either in OMERO or numpy formats.
"""
"""Store a tile's location and size."""
def __init__(self, centre, parent, size, max_size):
def __init__(self, centre, parent_class, size, max_size):
"""Initialise using a parent class."""
self.centre = centre
self.parent_class = parent_class  # used to access drifts
self.size = size
self.half_size = size // 2
self.max_size = max_size
def centre_at_time(self, tp: int) -> t.List[int]:
"""
Return tile's centre by applying drifts.
@@ -55,7 +63,7 @@ class Tile:
tp: integer
Index for the time point of interest.
"""
drifts = self.parent_class.drifts
tile_centre = self.centre - np.sum(drifts[: tp + 1], axis=0)
return list(tile_centre.astype(int))
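# For example, with centre (100, 200) and drifts [(1, 0), (2, -1), (0, 3)],
# the centre at tp=1 is (100, 200) - (3, -1) = (97, 201).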
@@ -74,15 +82,15 @@ class Tile:
Returns
-------
x: int
x-coordinate of bottom left corner of tile.
y: int
y-coordinate of bottom left corner of tile.
w: int
Width of tile.
h: int
Height of tile.
"""
x, y = self.centre_at_time(tp)
# tile bottom corner
x = int(x - self.half_size)
y = int(y - self.half_size)
@@ -90,8 +98,7 @@
def as_range(self, tp: int):
"""
Return a horizontal and a vertical slice of a tile.
Parameters
----------
@@ -117,6 +124,20 @@ class TileLocations:
max_size: int = 1200,
drifts: np.array = None,
):
"""
Initialise tiles as an array of Tile objects.
Parameters
----------
initial_location: array
An array of tile centres.
tile_size: int
Length of one side of a square tile.
max_size: int, optional
Default is 1200.
drifts: array
An array of translations to correct drift of the microscope.
"""
if drifts is None:
drifts = []
self.tile_size = tile_size
@@ -129,20 +150,21 @@
self.drifts = drifts
def __len__(self):
"""Find number of tiles."""
return len(self.tiles)
def __iter__(self):
"""Return the next tile from the list of tiles."""
yield from self.tiles
@property
def shape(self):
"""Return numbers of tiles and drifts."""
"""Return the number of tiles and the number of drifts."""
return len(self.tiles), len(self.drifts)
def to_dict(self, tp: int):
"""
Export initial locations, tile_size, max_size, and drifts as a dict.
Parameters
----------
@@ -157,19 +179,22 @@
res["drifts"] = np.expand_dims(self.drifts[tp], axis=0)
return res
def centres_at_time(self, tp: int) -> np.ndarray:
"""Return an array of tile centres (x- and y-coords)."""
return np.array([tile.centre_at_time(tp) for tile in self.tiles])
@classmethod
def from_tiler_init(
cls,
initial_location,
tile_size: int = None,
max_size: int = 1200,
):
"""Instantiate from a Tiler."""
return cls(initial_location, tile_size, max_size, drifts=[])
@classmethod
def read_h5(cls, file):
"""Instantiate from a h5 file."""
with h5py.File(file, "r") as hfile:
tile_info = hfile["trap_info"]
@@ -183,19 +208,13 @@
class TilerParameters(ParametersABC):
"""
tile_size: int
ref_channel: str or int
ref_z: int
backup_ref_channel int or None, if int indicates the index for reference channel. Used when image does not include metadata, ref_channel is a string and channel names are included in parsed logfiles.
"""
"""Define default values for tile size and the reference channels."""
_defaults = {
"tile_size": 117,
"ref_channel": "Brightfield",
"ref_z": 0,
"backup_ref_channel": None,
"position_name": None,
}
@@ -203,10 +222,10 @@ class Tiler(StepABC):
"""
Divide images into smaller tiles for faster processing.
Find tiles and re-register images if they drift.
Fetch images from an OMERO server if necessary.
Uses an Image instance, which lazily provides the pixel data,
and, as an independent argument, metadata.
"""
@@ -215,7 +234,8 @@
image: da.core.Array,
metadata: dict,
parameters: TilerParameters,
tile_locations=None,
OMERO_channels=None,
):
"""
Initialise.
@@ -226,69 +246,55 @@
metadata: dictionary
parameters: an instance of TilerParameters
tile_locations: (optional)
OMERO_channels: list of str
A definitive list of channels from OMERO to order channels in tiler.
"""
super().__init__(parameters)
self.image = image
try:
self.position_name = parameters.to_dict()["position_name"]
# get channels for this position
if "channels_by_group" in metadata:
channel_dict = metadata["channels_by_group"]
elif "positions/posname" in metadata:
# old meta data from image
channel_dict = find_channels_by_position(
metadata["positions/posname"]
)
else:
channel_dict = {}
if channel_dict:
channels = channel_dict.get(
self.position_name,
list(range(metadata.get("size_c", 0))),
)
else:
# new image meta data contains channels for that image
channels = metadata.get(
"channels", list(range(metadata.get("size_c", 0)))
)
# sort channels based on OMERO's channel order
if OMERO_channels is not None:
channels = [
ch for och in OMERO_channels for ch in channels if ch == och
]
self.channels = channels
# get reference channel - used for segmentation
self.ref_channel_index = self.channels.index(parameters.ref_channel)
self.tile_locs = tile_locations
if "zsections" in metadata:
self.z_perchannel = {
ch: zsect
for ch, zsect in zip(self.channels, metadata["zsections"])
}
except Exception as e:
self._log(f"No z_perchannel data: {e}")
self.tile_size = self.tile_size or min(self.image.shape[-2:])
@classmethod
def from_image(
cls,
image,
parameters: TilerParameters,
OMERO_channels: t.List[str] = None,
):
"""
Instantiate from an Image instance.
@@ -297,7 +303,12 @@ class Tiler(StepABC):
image: an instance of Image
parameters: an instance of TilerParameters
"""
return cls(
image.data,
image.metadata,
parameters,
OMERO_channels=OMERO_channels,
)
@classmethod
def from_h5(
@@ -305,18 +316,19 @@
image,
filepath: t.Union[str, Path],
parameters: t.Optional[TilerParameters] = None,
OMERO_channels: t.List[str] = None,
):
"""
Instantiate from an h5 file.
Parameters
----------
image: an instance of Image
filepath: Path instance
Path to an h5 file.
parameters: an instance of TilerParameters (optional)
"""
tile_locs = TileLocations.read_h5(filepath)
metadata = BridgeH5(filepath).meta_h5
metadata["channels"] = image.metadata["channels"]
if parameters is None:
@@ -325,14 +337,15 @@
image.data,
metadata,
parameters,
tile_locations=tile_locs,
OMERO_channels=OMERO_channels,
)
if hasattr(tile_locs, "drifts"):
tiler.no_processed = len(tile_locs.drifts)
return tiler
@lru_cache(maxsize=2)
def load_image(self, tp: int, c: int) -> np.ndarray:
"""
Load image using dask.
@@ -345,7 +358,7 @@
Parameters
----------
tp: integer
An index for a time point
c: integer
An index for a channel
@@ -354,32 +367,35 @@
-------
full: an array of images
"""
full = self.image[tp, c]
if hasattr(full, "compute"):
# if using dask fetch images
full = full.compute(scheduler="synchronous")
return full
@property
def shape(self):
"""
Return the shape of the image array.
The image array is arranged as number of images, number of channels,
number of z sections, and size of the image in y and x.
"""
return self.image.shape
@property
def no_processed(self):
"""Return the number of processed images."""
if not hasattr(self, "_n_processed"):
self._n_processed = 0
return self._n_processed
if not hasattr(self, "_no_processed"):
self._no_processed = 0
return self._no_processed
@no_processed.setter
def no_processed(self, value):
self._no_processed = value
@property
def no_tiles(self):
"""Return number of tiles."""
return len(self.tile_locs)
@@ -395,12 +411,11 @@
tile_size: integer
The size of a tile.
"""
initial_image = self.image[0, self.ref_channel_index, self.ref_z]
if tile_size:
half_tile = tile_size // 2
# max_size is the minimum of the numbers of x and y pixels
max_size = min(self.image.shape[-2:])
# find the tiles
tile_locs = segment_traps(initial_image, tile_size)
# keep only tiles that are not near an edge
@@ -415,6 +430,7 @@
tile_locs, tile_size
)
else:
# one tile with its centre at the image's centre
yx_shape = self.image.shape[-2:]
tile_locs = [[x // 2 for x in yx_shape]]
self.tile_locs = TileLocations.from_tiler_init(
@@ -423,8 +439,9 @@
def find_drift(self, tp: int):
"""
Find any translational drift between two images.
Use cross correlation between two consecutive images.
Arguments
---------
@@ -434,8 +451,8 @@
prev_tp = max(0, tp - 1)
# cross-correlate
drift, _, _ = phase_cross_correlation(
self.image[prev_tp, self.ref_channel_index, self.ref_z],
self.image[tp, self.ref_channel_index, self.ref_z],
)
# store drift
if 0 < tp < len(self.tile_locs.drifts):
@@ -445,7 +462,7 @@
def get_tp_data(self, tp, c) -> np.ndarray:
"""
Return all tiles corrected for drift.
Parameters
----------
@@ -456,25 +473,24 @@
Returns
----------
Numpy ndarray of tiles with shape (no tiles, z-sections, y, x)
"""
tiles = []
full = self.load_image(tp, c)
for tile in self.tile_locs:
# pad tile if necessary
ndtile = Tiler.if_out_of_bounds_pad(full, tile.as_range(tp))
tiles.append(ndtile)
return np.stack(tiles)
def get_tile_data(self, tile_id: int, tp: int, c: int):
"""
Return a tile corrected for drift and padding.
Parameters
----------
tile_id: integer
Index of tile.
tp: integer
Index of time points.
c: integer
@@ -485,14 +501,14 @@
ndtile: array
An array of (x, y) arrays, one for each z stack
"""
full = self.load_image(tp, c)
tile = self.tile_locs.tiles[tile_id]
ndtile = self.if_out_of_bounds_pad(full, tile.as_range(tp))
return ndtile
def _run_tp(self, tp: int):
"""
Find tiles for a given time point.
Determine any translational drift of the current image from the
previous one.
@@ -502,41 +518,33 @@
tp: integer
The time point to tile.
"""
if self.no_processed == 0 or not hasattr(self.tile_locs, "drifts"):
self.initialise_tiles(self.tile_size)
if hasattr(self.tile_locs, "drifts"):
drift_len = len(self.tile_locs.drifts)
if self.no_processed != drift_len:
warnings.warn(
"Tiler: the number of processed tiles and the number of drifts"
" calculated do not match."
)
self.no_processed = drift_len
# determine drift for this time point and update tile_locs.drifts
self.find_drift(tp)
# update no_processed
self.no_processed = tp + 1
# return result for writer
return self.tile_locs.to_dict(tp)
def run(self, time_dim=None):
"""
Tile all time points in an experiment at once.
"""
"""Tile all time points in an experiment at once."""
if time_dim is None:
time_dim = 0
for frame in range(self.image.shape[time_dim]):
self.run_tp(frame)
return None
# The next set of functions are necessary for the extraction object
def get_tiles_timepoint(
self, tp: int, channels=None, z: int = 0
) -> np.ndarray:
"""
Get a multidimensional array with all tiles for a set of channels
......@@ -558,67 +566,62 @@ class Tiler(StepABC):
Returns
-------
res: array
Data arranged as (tiles, channels, time points, X, Y, Z)
Data arranged as (tiles, channels, Z, X, Y)
"""
# FIXME add support for sub-tiling a tile
# FIXME can we ignore z
if channels is None:
channels = [0]
elif isinstance(channels, str):
channels = [channels]
# get the data
# convert to indices
channels = [
self.channels.index(channel)
if isinstance(channel, str)
else channel
for channel in channels
]
# get the data as a list of length of the number of channels
res = []
for c in channels:
# only return requested z
val = self.get_tp_data(tp, c)[:, z]
# starts with the order: tiles, z, y, x
# returns the order: tiles, C, T, Z, X, Y
val = np.expand_dims(val, axis=1)
res.append(val)
if tile_shape is not None:
if isinstance(tile_shape, int):
tile_shape = (tile_shape, tile_shape)
assert np.all(
[
(tile_size - ax) > -1
for tile_size, ax in zip(tile_shape, res[0].shape[-3:-2])
]
)
return np.stack(res, axis=1)
@property
def ref_channel_index(self):
"""Return index of reference channel."""
return self.get_channel_index(self.parameters.ref_channel)
tiles = self.get_tp_data(tp, c)[:, z]
# insert new axis at index 1 for missing time point
tiles = np.expand_dims(tiles, axis=1)
res.append(tiles)
# stack at time-point axis if more than one channel
final = np.stack(res, axis=1)
return final
def get_channel_index(self, channel: str or int) -> int or None:
"""
Find index for channel using regex. Returns the first matched string.
If self.channels is integers (no image metadata) it returns None.
If channel is integer
Find index for channel using regex.
If channels are strings, return the first matched string.
If channels are integers, return channel unchanged if it is
an integer.
Parameters
----------
channel: string or int
The channel or index to be used.
"""
if all(map(lambda x: isinstance(x, int), self.channels)):
channel = channel if isinstance(channel, int) else None
if isinstance(channel, str):
channel = find_channel_index(self.channels, channel)
return channel
if isinstance(channel, int) and all(
map(lambda x: isinstance(x, int), self.channels)
):
return channel
elif isinstance(channel, str):
return find_channel_index(self.channels, channel)
else:
return None
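A hedged sketch of the three branches above, assuming a hypothetical tiler whose channels are ["Brightfield", "GFP", "mCherry"]:

tiler.get_channel_index("gfp")  # -> 1, via a case-insensitive regex
# if tiler.channels were [0, 1, 2] (no image metadata), an integer
# request would be returned unchanged: tiler.get_channel_index(2) -> 2
# any other combination falls through to None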
@staticmethod
def ifoob_pad(full, slices):
def if_out_of_bounds_pad(image_array, slices):
"""
Return the slices padded if out of bounds.
Pad slices if out of bounds.
Parameters
----------
full: array
Slice of OMERO image (zstacks, x, y) - the entire position
Slice of image (zstacks, x, y) - the entire position
with zstacks as first axis
slices: tuple of two slices
Delineates indices for the x- and y- ranges of the tile.
......@@ -631,11 +634,11 @@ class Tiler(StepABC):
If more than a quarter of the tile lies outside the image, a tile of NaN is returned.
"""
# number of pixels in the y direction
max_size = full.shape[-1]
max_size = image_array.shape[-1]
# ignore parts of the tile outside of the image
y, x = [slice(max(0, s.start), min(max_size, s.stop)) for s in slices]
# get the tile including all z stacks
tile = full[:, y, x]
tile = image_array[:, y, x]
# find extent of padding needed in x and y
padding = np.array(
[(-min(0, s.start), -min(0, max_size - s.stop)) for s in slices]
......@@ -643,43 +646,31 @@ class Tiler(StepABC):
if padding.any():
tile_size = slices[0].stop - slices[0].start
if (padding > tile_size / 4).any():
# too much of the tile is outside of the image
# fill with NaN
tile = np.full((full.shape[0], tile_size, tile_size), np.nan)
# fill with NaN because too much of the tile is outside of the image
tile = np.full(
(image_array.shape[0], tile_size, tile_size), np.nan
)
else:
# pad tile with median value of the tile
tile = np.pad(tile, [[0, 0]] + padding.tolist(), "median")
return tile
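The padding logic can be checked with a toy image; a minimal sketch, assuming only numpy and the static method above:

import numpy as np

image = np.arange(100, dtype=float).reshape(1, 10, 10)
# a 6x6 tile whose top-left corner lies one pixel outside the image
slices = (slice(-1, 5), slice(-1, 5))
tile = Tiler.if_out_of_bounds_pad(image, slices)
# the out-of-bounds row and column are filled with the tile's median,
# because the overlap (1 pixel) is less than a quarter of the tile size
print(tile.shape)  # (1, 6, 6)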
# FIXME: Refactor to support both channel or index
# self._log below is not defined
def find_channel_index(image_channels: t.List[str], channel: str):
"""
Access
"""
for i, ch in enumerate(image_channels):
found = re.match(channel, ch, re.IGNORECASE)
def find_channel_index(image_channels: t.List[str], channel_regex: str):
"""Use a regex to find the index of a channel."""
for index, ch in enumerate(image_channels):
found = re.match(channel_regex, ch, re.IGNORECASE)
if found:
# warn if the regex matched only part of the channel name
if len(found.string) - (found.end() - found.start()):
logging.getLogger("aliby").log(
logging.WARNING,
f"Channel {channel} matched {ch} using regex",
f"Channel {channel_regex} matched {ch} using regex",
)
return i
return index
def find_channel_name(image_channels: t.List[str], channel: str):
"""
Find the name of the channel using regex.
Parameters
----------
image_channels: list of str
Channels.
channel: str
A regular expression.
"""
index = find_channel_index(image_channels, channel)
def find_channel_name(image_channels: t.List[str], channel_regex: str):
"""Find the name of the channel using regex."""
index = find_channel_index(image_channels, channel_regex)
if index is not None:
return image_channels[index]
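A short sketch of the matching behaviour, using hypothetical channel names:

channels = ["Brightfield", "GFPFast", "mCherry"]
find_channel_index(channels, "GFP")     # -> 1 (partial match, logs a warning)
find_channel_name(channels, "mcherry")  # -> "mCherry" (case-insensitive)
find_channel_name(channels, "Cy5")      # -> None (no match)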
"""
ImageViewer class, used to look at individual or multiple traps over time.
Example of usage:
fpath = "/home/alan/Documents/dev/skeletons/scripts/data/16543_2019_07_16_aggregates_CTP_switch_2_0glu_0_0glu_URA7young_URA8young_URA8old_01/URA8_young018.h5"
tile_id = 9
trange = list(range(0, 10))
ncols = 8
riv = RemoteImageViewer(fpath, server_info)
riv.plot_labelled_trap(tile_id, [0], trange, ncols=ncols)
"""
import re
import typing as t
import h5py
from abc import ABC
from pathlib import Path
import h5py
import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns
from PIL import Image
from skimage.morphology import dilation
from agora.io.cells import Cells
from agora.io.metadata import dispatch_metadata_parser
from aliby.tile.tiler import Tiler, TilerParameters
from aliby.io.image import dispatch_image
from aliby.tile.tiler import Tiler
from aliby.utils.plot import stretch_clip
default_colours = {
......@@ -40,9 +24,7 @@ default_colours = {
def custom_imshow(a, norm=None, cmap=None, *args, **kwargs):
"""
Wrapper on plt.imshow function.
"""
"""Wrap plt.imshow."""
if cmap is None:
cmap = "Greys_r"
return plt.imshow(
......@@ -56,232 +38,145 @@ def custom_imshow(a, norm=None, cmap=None, *args, **kwargs):
class BaseImageViewer(ABC):
def __init__(self, fpath):
"""Base class with routines common to all ImageViewers."""
self._fpath = fpath
attrs = dispatch_metadata_parser(fpath.parent)
self._logfiles_meta = {}
self.image_id = attrs.get("image_id")
def __init__(self, h5file_path):
"""Initialise from a Path to a h5 file."""
self.h5file_path = h5file_path
self.logfiles_meta = dispatch_metadata_parser(h5file_path.parent)
self.image_id = self.logfiles_meta.get("image_id")
if self.image_id is None:
with h5py.File(fpath, "r") as f:
with h5py.File(h5file_path, "r") as f:
self.image_id = f.attrs.get("image_id")
assert self.image_id is not None, "No valid image_id found in metadata"
if self.image_id is None:
raise ("No valid image_id found in metadata.")
self.full = {}
@property
def shape(self):
"""Return shape of image array."""
return self.tiler.image.shape
@property
def ntraps(self):
"""Find the number of traps available."""
return self.cells.ntraps
@property
def max_labels(self):
# Print max cell label in whole experiment
"""Find maximum cell label in whole experiment."""
return [max(x) for x in self.cells.labels]
def labels_at_time(self, tp: int):
# Print cell label at a given time-point
"""Find cell label at a given time point."""
return self.cells.labels_at_time(tp)
class LocalImageViewer(BaseImageViewer):
"""
Tool to generate figures from local files, either zarr or files organised
in directories.
TODO move common functionality from RemoteImageViewer to BaseImageViewer
"""
def __init__(self, results_path: str, data_path: str):
super().__init__(results_path)
from aliby.io.image import ImageDir, ImageZarr
self._image_class = (
ImageZarr if data_path.endswith(".zarr") else ImageDir
)
with dispatch_image(data_path)(data_path) as image:
self.tiler = Tiler(
image.data,
self._meta if hasattr(self, "_meta") else self._logfiles_meta,
TilerParameters.default(),
)
self.cells = Cells.from_source(results_path)
class RemoteImageViewer(BaseImageViewer):
"""
This ImageViewer combines fetching remote images with tiling and outline display.
"""
_credentials = ("host", "username", "password")
def __init__(
self,
results_path: str,
server_info: t.Dict[str, str],
):
super().__init__(results_path)
from aliby.io.omero import UnsafeImage as OImage
self._server_info = server_info or {
k: attrs["parameters"]["general"][k] for k in self._credentials
}
self._image_instance = OImage(self.image_id, **self._server_info)
self.tiler = Tiler.from_h5(self._image_instance, results_path)
self.cells = Cells.from_source(results_path)
def random_valid_trap_tp(
self,
min_ncells: int = None,
min_consecutive_tps: int = None,
label_modulo: int = None,
def find_channel_indices(
self, channels: t.Union[str, t.Collection[str]], guess=True
):
# Call Cells convenience function to pick a random trap and tp
# containing cells for x cells for y
return self.cells.random_valid_trap_tp(
min_ncells=min_ncells,
min_consecutive_tps=min_consecutive_tps,
)
def get_entire_position(self):
raise NotImplementedError
def get_position_timelapse(self):
raise NotImplementedError
@property
def full(self):
if not hasattr(self, "_full"):
self._full = {}
return self._full
def get_tc(self, tp, channel=None, server_info=None):
server_info = server_info or self._server_info
channel = channel or self.tiler.ref_channel
with self._image_class(self.image_id, **server_info) as image:
self.tiler.image = image.data
return self.tiler.get_tc(tp, channel)
def _find_channels(self, channels: str, guess: bool = True):
"""Find index for particular channels."""
channels = channels or self.tiler.ref_channel
if isinstance(channels, (int, str)):
channels = [channels]
if isinstance(channels[0], str):
if guess:
channels = [self.tiler.channels.index(ch) for ch in channels]
indices = [self.tiler.channels.index(ch) for ch in channels]
else:
channels = [
indices = [
re.search(ch, tiler_channels)
for ch in channels
for tiler_channels in self.tiler.channels
]
return indices
else:
return channels
def get_outlines_tiles_dict(self, tile_id, trange, channels):
"""Get outlines and dict of tiles with channel indices as keys."""
outlines = None
tile_dict = {}
for ch in self.find_channel_indices(channels):
outlines, tile_dict[ch] = self.get_outlines_tiles(
tile_id, trange, channels=[ch]
)
return outlines, tile_dict
def get_outlines_tiles(
self,
tile_id: int,
tps: t.Union[range, t.Collection[int]],
channels=None,
concatenate=True,
**kwargs,
) -> t.Tuple[np.array]:
"""
Get masks uniquely labelled for each cell with the corresponding tiles.
return channels
Returns a list of masks, each an array with distinct masks for each cell,
and an array of tiles for the given channel.
"""
tile_dict = self.get_tiles(tps, channels=channels, **kwargs)
# get tiles of interest
tiles = [x[tile_id] for x in tile_dict.values()]
# get outlines for each time point
outlines = [
self.cells.at_time(tp, kind="edgemask").get(tile_id, []) for tp in tps
]
# get cell labels for each time point
cell_labels = [self.cells.labels_at_time(tp).get(tile_id, []) for tp in tps]
# generate one image with all cell outlines uniquely labelled per tile
labelled_outlines = [
np.stack(
[outline * label for outline, label in zip(outlines_tp, labels_tp)]
).max(axis=0)
if len(labels_tp)
else np.zeros_like(tiles[0]).astype(bool)
for outlines_tp, labels_tp in zip(outlines, cell_labels)
]
if concatenate:
# concatenate to allow potential image processing
labelled_outlines = np.concatenate(labelled_outlines, axis=1)
tiles = np.concatenate(tiles, axis=1)
return labelled_outlines, tiles
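The unique labelling above reduces to multiplying each boolean outline by its cell label and taking a pixel-wise maximum; a toy numpy sketch:

import numpy as np

outline_a = np.zeros((5, 5), dtype=bool)
outline_a[1, 1:4] = True  # outline of the cell labelled 3
outline_b = np.zeros((5, 5), dtype=bool)
outline_b[3, 1:4] = True  # outline of the cell labelled 7
labelled = np.stack([outline_a * 3, outline_b * 7]).max(axis=0)
# pixels on the first outline are 3, on the second 7, elsewhere 0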
def get_pos_timepoints(
def get_tiles(
self,
tps: t.Union[int, t.Collection[int]],
channels: t.Union[str, t.Collection[str]] = None,
channels=None,
z: int = None,
server_info=None,
):
"""Get dict with time points as keys and all available tiles as values."""
if tps and not isinstance(tps, t.Collection):
tps = range(tps)
# TODO add support for multiple channels or refactor
if channels and not isinstance(channels, t.Collection):
channels = [channels]
if z is None:
z = 0
server_info = server_info or self._server_info
channels = 0 or self._find_channels(channels)
z = z or self.tiler.ref_z
ch_tps = [(channels[0], tp) for tp in tps]
image = self._image_instance
self.tiler.image = image.data
channel_indices = self.find_channel_indices(channels)
ch_tps = [(channel_indices[0], tp) for tp in tps]
for ch, tp in ch_tps:
if (ch, tp) not in self.full:
self.full[(ch, tp)] = self.tiler.get_tiles_timepoint(
tp, channels=[ch], z=[z]
)[:, 0, 0, z, ...]
requested_trap = {tp: self.full[(ch, tp)] for ch, tp in ch_tps}
return requested_trap
def get_labelled_trap(
self,
tile_id: int,
tps: t.Union[range, t.Collection[int]],
channels=None,
concatenate=True,
**kwargs,
) -> t.Tuple[np.array]:
"""
Core method to fetch traps and labels together
"""
imgs = self.get_pos_timepoints(tps, channels=channels, **kwargs)
imgs_list = [x[tile_id] for x in imgs.values()]
outlines = [
self.cells.at_time(tp, kind="edgemask").get(tile_id, [])
for tp in tps
]
lbls = [self.cells.labels_at_time(tp).get(tile_id, []) for tp in tps]
lbld_outlines = [
np.stack([mask * lbl for mask, lbl in zip(maskset, lblset)]).max(
axis=0
)
if len(lblset)
else np.zeros_like(imgs_list[0]).astype(bool)
for maskset, lblset in zip(outlines, lbls)
]
if concatenate:
lbld_outlines = np.concatenate(lbld_outlines, axis=1)
imgs_list = np.concatenate(imgs_list, axis=1)
return lbld_outlines, imgs_list
def get_images(self, tile_id, trange, channels, **kwargs):
"""
Wrapper to fetch images
"""
out = None
imgs = {}
for ch in self._find_channels(channels):
out, imgs[ch] = self.get_labelled_trap(
tile_id, trange, channels=[ch], **kwargs
)
return out, imgs
tile_dict = {tp: self.full[(ch, tp)] for ch, tp in ch_tps}
return tile_dict
def plot_labelled_trap(
self,
tile_id: int,
tile_id,
channels,
trange: t.Union[range, t.Collection[int]],
remove_axis: bool = False,
savefile: str = None,
skip_outlines: bool = False,
norm: str = None,
norm=True,
ncols: int = None,
local_colours: bool = True,
img_plot_kwargs: dict = {},
lbl_plot_kwargs: dict = {"alpha": 0.8},
**kwargs,
):
"""Wrapper to plot time-lapses of individual traps
"""
Plot time-lapses of individual tiles.
Use Cells and Tiler to generate images of cells with their resulting
outlines.
......@@ -290,15 +185,13 @@ class RemoteImageViewer(BaseImageViewer):
----------
tile_id : int
Identifier of trap
channels : Union[str, int]
channel : Union[str, int]
Channels to use
trange : t.Union[range, t.Collection[int]]
Range or collection indicating the time-points to use.
remove_axis : bool
None, "off", or "x". Determines whether to remove the x-axis, both
axes or none.
savefile : str
Saves file to a location.
skip_outlines : bool
Do not add overlay with outlines
norm : str
......@@ -312,68 +205,40 @@ class RemoteImageViewer(BaseImageViewer):
Arguments to pass to plt.imshow used for images.
lbl_plot_kwargs : dict
Keyword arguments to pass to label plots.
**kwargs : dict
Additional keyword arguments passed to ImageViewer.get_images.
Examples
--------
FIXME: Add docs.
"""
# set up for plotting
if ncols is None:
ncols = len(trange)
nrows = int(np.ceil(len(trange) / ncols))
width = self.tiler.tile_size * ncols
out, images = self.get_images(tile_id, trange, channels, **kwargs)
# dilation makes outlines easier to see
out = dilation(out).astype(float)
out[out == 0] = np.nan
outlines, tiles_dict = self.get_outlines_tiles_dict(tile_id, trange, channels)
channel_labels = [
self.tiler.channels[ch] if isinstance(ch, int) else ch
for ch in channels
self.tiler.channels[ch] if isinstance(ch, int) else ch for ch in channels
]
# dilate to make outlines easier to see
outlines = dilation(outlines).astype(float)
outlines[outlines == 0] = np.nan
# split concatenated tiles into one tile per time point in a row
tiles = [
into_image_time_series(tile, width, nrows) for tile in tiles_dict.values()
]
assert not norm or norm in (
"l1",
"l2",
"max",
), "Invalid norm argument."
if norm and norm in ("l1", "l2", "max"):
images = {k: stretch_clip(v) for k, v in images.items()}
images = [concat_pad(img, width, nrows) for img in images.values()]
# TODO convert to RGB to draw fluorescence with colour
tiled_imgs = {}
tiled_imgs["img"] = np.concatenate(images, axis=0)
tiled_imgs["cell_labels"] = np.concatenate(
[concat_pad(out, width, nrows) for _ in images], axis=0
)
custom_imshow(
tiled_imgs["img"],
**img_plot_kwargs,
res = {}
# concatenate different channels vertically for display
res["tiles"] = np.concatenate(tiles, axis=0)
res["cell_labels"] = np.concatenate(
[into_image_time_series(outlines, width, nrows) for _ in tiles], axis=0
)
custom_imshow(res["tiles"], **img_plot_kwargs)
custom_imshow(
tiled_imgs["cell_labels"],
cmap=sns.color_palette("Paired", as_cmap=True),
**lbl_plot_kwargs,
res["cell_labels"], cmap=default_colours["cell_label"], **lbl_plot_kwargs
)
if remove_axis is True:
plt.axis("off")
elif remove_axis == "x":
plt.tick_params(
axis="x",
which="both",
bottom=False,
top=False,
labelbottom=False,
axis="x", which="both", bottom=False, top=False, labelbottom=False
)
if remove_axis != "True":
plt.yticks(
ticks=[
......@@ -383,39 +248,63 @@ class RemoteImageViewer(BaseImageViewer):
],
labels=channel_labels,
)
if not remove_axis:
xlabels = (
["+ {} ".format(i) for i in range(ncols)]
if nrows > 1
else list(trange)
["+ {} ".format(i) for i in range(ncols)] if nrows > 1 else list(trange)
)
plt.xlabel("Time-point")
plt.xticks(
ticks=[self.tiler.tile_size * (i + 0.5) for i in range(ncols)],
labels=xlabels,
)
if not np.any(outlines):
print("ImageViewer:Warning: No cell outlines found.")
plt.tight_layout()
plt.show(block=False)
if not np.any(out):
print("ImageViewer:Warning:No cell outlines found")
if savefile:
plt.savefig(savefile, bbox_inches="tight", dpi=300)
plt.close()
else:
plt.show()
def concat_pad(a: np.array, width, nrows):
class LocalImageViewer(BaseImageViewer):
"""
Melt an array into having multiple blocks as rows
View images from local files.
File are either zarr or organised in directories.
"""
def __init__(self, h5file: str, image_direc: str):
"""Initialise using a h5file and a local directory of images."""
h5file_path = Path(h5file)
image_direc_path = Path(image_direc)
super().__init__(h5file_path)
with dispatch_image(image_direc_path)(image_direc_path) as image:
self.tiler = Tiler.from_h5(image, h5file_path)
self.cells = Cells.from_source(h5file_path)
class RemoteImageViewer(BaseImageViewer):
"""Fetching remote images with tiling and outline display."""
credentials = ("host", "username", "password")
def __init__(self, h5file: str, server_info: t.Dict[str, str]):
"""Initialise using a h5file and importing aliby.io.omero."""
from aliby.io.omero import UnsafeImage as OImage
h5file_path = Path(h5file)
super().__init__(h5file_path)
self.server_info = server_info or {
k: self.attrs["parameters"]["general"][k] for k in self.credentials
}
image = OImage(self.image_id, **self.server_info)
self.tiler = Tiler.from_h5(image, h5file_path)
self.cells = Cells.from_source(h5file_path)
def into_image_time_series(a: np.array, width, nrows):
"""Split into sub-arrays and then concatenate into one."""
return np.concatenate(
np.array_split(
np.pad(
a,
# ((0, 0), (0, width - (a.shape[1] % width))),
((0, 0), (0, a.shape[1] % width)),
constant_values=np.nan,
),
......
......@@ -52,7 +52,7 @@ def plot_in_square(data: t.Iterable):
def stretch_clip(image, clip=True):
"""
Performs contrast stretching on an input image.
Perform contrast stretching on an input image.
This function takes an array-like input image and enhances its contrast by adjusting
the dynamic range of pixel values. It first scales the pixel values between 0 and 255,
......
__version__ = "0.1.64 lite"
import typing as t
from pathlib import Path
import bottleneck as bn
import h5py
import numpy as np
import pandas as pd
import aliby.global_parameters as global_parameters
from agora.abc import ParametersABC, StepABC
from agora.io.cells import Cells
from agora.io.writer import Writer, load_attributes
from aliby.tile.tiler import Tiler
from extraction.core.functions.defaults import exparams_from_meta
from agora.io.writer import Writer, load_meta
from aliby.tile.tiler import Tiler, find_channel_name
from extraction.core.functions.distributors import reduce_z, trap_apply
from extraction.core.functions.loaders import (
load_custom_args,
......@@ -26,13 +27,44 @@ extraction_result = t.Dict[
str, t.Dict[reduction_method, t.Dict[str, t.Dict[str, pd.Series]]]
]
# Global variables used to load functions that either analyse cells or their background. These global variables both allow the functions to be stored in a dictionary for access only on demand and to be defined simply in extraction/core/functions.
CELL_FUNS, TRAPFUNS, FUNS = load_funs()
# Global variables used to load functions that either analyse cells
# or their background. These global variables both allow the functions
# to be stored in a dictionary for access only on demand and to be
# defined simply in extraction/core/functions.
CELL_FUNS, TRAP_FUNS, ALL_FUNS = load_funs()
CUSTOM_FUNS, CUSTOM_ARGS = load_custom_args()
RED_FUNS = load_redfuns()
# Assign datatype depending on the metric used
# m2type = {"mean": np.float32, "median": np.ubyte, "imBackground": np.ubyte}
REDUCTION_FUNS = load_redfuns()
def extraction_params_from_meta(
meta: t.Union[dict, Path, str], extras: t.Collection[str] = ["ph"]
):
"""Obtain parameters for extraction from meta data."""
if not isinstance(meta, dict):
# load meta data
with h5py.File(meta, "r") as f:
meta = dict(f["/"].attrs.items())
base = {
"tree": {"general": {"None": ["area", "volume", "eccentricity"]}},
"multichannel_ops": {},
}
candidate_channels = set(global_parameters.possible_imaging_channels)
default_reductions = {"max"}
default_metrics = set(global_parameters.fluorescence_functions)
default_reduction_metrics = {
r: default_metrics for r in default_reductions
}
# default_rm["None"] = ["nuc_conv_3d"] # Uncomment this to add nuc_conv_3d (slow)
extant_fluorescence_ch = []
for av_channel in candidate_channels:
# find matching channels in metadata
found_channel = find_channel_name(meta.get("channels", []), av_channel)
if found_channel is not None:
extant_fluorescence_ch.append(found_channel)
for ch in extant_fluorescence_ch:
base["tree"][ch] = default_reduction_metrics
base["sub_bg"] = extant_fluorescence_ch
return base
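A hedged sketch of the dictionary this returns, assuming metadata listing a single matching fluorescence channel named "GFP":

# extraction_params_from_meta({"channels": ["Brightfield", "GFP"]})
# -> {
#     "tree": {
#         "general": {"None": ["area", "volume", "eccentricity"]},
#         "GFP": {"max": <set of fluorescence functions>},
#     },
#     "multichannel_ops": {},
#     "sub_bg": ["GFP"],
# }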
class ExtractorParameters(ParametersABC):
......@@ -67,28 +99,30 @@ class ExtractorParameters(ParametersABC):
@classmethod
def from_meta(cls, meta):
return cls(**exparams_from_meta(meta))
"""Instantiate from the meta data; used by Pipeline."""
return cls(**extraction_params_from_meta(meta))
class Extractor(StepABC):
"""
Apply a metric to cells identified in the tiles.
Using the cell masks, the Extractor applies a metric, such as area or median, to cells identified in the image tiles.
Using the cell masks, the Extractor applies a metric, such as
area or median, to cells identified in the image tiles.
Its methods require both tile images and masks.
Usually the metric is applied to only a tile's masked area, but some metrics depend on the whole tile.
Usually the metric is applied to only a tile's masked area, but
some metrics depend on the whole tile.
Extraction follows a three-level tree structure. Channels, such as GFP, are the root level; the reduction algorithm, such as maximum projection, is the second level; the specific metric, or operation, to apply to the masks, such as mean, is the third level.
Extraction follows a three-level tree structure. Channels, such
as GFP, are the root level; the reduction algorithm, such as
maximum projection, is the second level; the specific metric,
or operation, to apply to the masks, such as mean, is the third
or leaf level.
"""
# TODO Alan: Move this to a location with the SwainLab defaults
default_meta = {
"pixel_size": 0.236,
"z_size": 0.6,
"spacing": 0.6,
}
default_meta = global_parameters.imaging_specifications
def __init__(
self,
......@@ -107,17 +141,36 @@ class Extractor(StepABC):
store: str
Path to the h5 file containing the cell masks.
tiler: pipeline-core.core.segmentation tiler
Class that contains or fetches the images used for segmentation.
Class that contains or fetches the images used for
segmentation.
"""
self.params = parameters
if store:
self.local = store
self.load_meta()
self.h5path = store
self.meta = load_meta(self.h5path)
else:
# if no h5 file, use the parameters directly
self.meta = {"channel": parameters.to_dict()["tree"].keys()}
if tiler:
self.tiler = tiler
available_channels = set((*tiler.channels, "general"))
# only extract for channels available
self.params.tree = {
k: v
for k, v in self.params.tree.items()
if k in available_channels
}
self.params.sub_bg = available_channels.intersection(
self.params.sub_bg
)
# add background subtracted channels to those available
available_channels_bgsub = available_channels.union(
[c + "_bgsub" for c in self.params.sub_bg]
)
# remove any multichannel operations requiring a missing channel
for op, (input_ch, _, _) in self.params.multichannel_ops.items():
if not set(input_ch).issubset(available_channels_bgsub):
self.params.multichannel_ops.pop(op)
self.load_funs()
@classmethod
......@@ -150,43 +203,53 @@ class Extractor(StepABC):
@property
def current_position(self):
return str(self.local).split("/")[-1][:-3]
"""Return position being analysed."""
return str(self.h5path).split("/")[-1][:-3]
@property
def group(self):
"""Return path within the h5 file."""
"""Return out path to write in the h5 file."""
if not hasattr(self, "_out_path"):
self._group = "/extraction/"
return self._group
def load_funs(self):
"""Define all functions, including custom ones."""
self.load_custom_funs()
self.all_cell_funs = set(self.custom_funs.keys()).union(CELL_FUNS)
# merge the two dicts
self.all_funs = {**self.custom_funs, **ALL_FUNS}
def load_custom_funs(self):
"""
Incorporate the extra arguments of custom functions into their definitions.
Incorporate extra arguments of custom functions into their definitions.
Normal functions only have cell_masks and trap_image as their
arguments, and here custom functions are made the same by
setting the values of their extra arguments.
Any other parameters are taken from the experiment's metadata and automatically applied. These parameters therefore must be loaded within an Extractor instance.
Any other parameters are taken from the experiment's metadata
and automatically applied. These parameters therefore must be
loaded within an Extractor instance.
"""
# find functions specified in params.tree
funs = set(
[
fun
for ch in self.params.tree.values()
for red in ch.values()
for fun in red
for channel in self.params.tree.values()
for reduction in channel.values()
for fun in reduction
]
)
# consider only those already loaded from CUSTOM_FUNS
funs = funs.intersection(CUSTOM_FUNS.keys())
# find their arguments
self._custom_arg_vals = {
self.custom_arg_vals = {
k: {k2: self.get_meta(k2) for k2 in v}
for k, v in CUSTOM_ARGS.items()
}
# define custom functions
self._custom_funs = {}
self.custom_funs = {}
for k, f in CUSTOM_FUNS.items():
def tmp(f):
......@@ -196,33 +259,22 @@ class Extractor(StepABC):
f,
cell_masks,
trap_image,
**self._custom_arg_vals.get(k, {}),
**self.custom_arg_vals.get(k, {}),
)
self._custom_funs[k] = tmp(f)
def load_funs(self):
"""Define all functions, including custum ones."""
self.load_custom_funs()
self._all_cell_funs = set(self._custom_funs.keys()).union(CELL_FUNS)
# merge the two dicts
self._all_funs = {**self._custom_funs, **FUNS}
def load_meta(self):
"""Load metadata from h5 file."""
self.meta = load_attributes(self.local)
self.custom_funs[k] = tmp(f)
def get_tiles(
self,
tp: int,
channels: t.Optional[t.List[t.Union[str, int]]] = None,
z: t.Optional[t.List[str]] = None,
**kwargs,
) -> t.Optional[np.ndarray]:
"""
Find tiles for a given time point, channels, and z-stacks.
Any additional keyword arguments are passed to tiler.get_tiles_timepoint
Any additional keyword arguments are passed to
tiler.get_tiles_timepoint
Parameters
----------
......@@ -243,27 +295,26 @@ class Extractor(StepABC):
# a list of the indices of the z stacks
channel_ids = None
if z is None:
# gets the tiles data via tiler
# include all Z channels
z = list(range(self.tiler.shape[-3]))
res = (
self.tiler.get_tiles_timepoint(
tp, channels=channel_ids, z=z, **kwargs
)
# get the image data via tiler
tiles = (
self.tiler.get_tiles_timepoint(tp, channels=channel_ids, z=z)
if channel_ids
else None
)
# data arranged as (tiles, channels, time points, X, Y, Z)
return res
# tiles has dimensions (tiles, channels, 1, Z, X, Y)
return tiles
def extract_traps(
def apply_cell_function(
self,
traps: t.List[np.ndarray],
masks: t.List[np.ndarray],
metric: str,
labels: t.Dict[int, t.List[int]],
cell_function: str,
cell_labels: t.Dict[int, t.List[int]],
) -> t.Tuple[t.Union[t.Tuple[float], t.Tuple[t.Tuple[int]]]]:
"""
Apply a function to a whole position.
Apply a cell function to all cells at all traps for one time point.
Parameters
----------
......@@ -271,35 +322,35 @@ class Extractor(StepABC):
t.List of images.
masks: list of arrays
t.List of masks.
metric: str
Metric to extract.
labels: dict
A dict of cell labels with trap_ids as keys and a list of cell labels as values.
pos_info: bool
Whether to add the position as an index or not.
cell_function: str
Function to apply.
cell_labels: dict
A dict with trap_ids as keys and a list of cell labels as
values.
Returns
-------
res_idx: a tuple of tuples
A two-tuple comprising a tuple of results and a tuple of the tile_id and cell labels
A two-tuple comprising a tuple of results and a tuple of
the tile_id and cell labels
"""
if labels is None:
self._log("No labels given. Sorting cells using index.")
cell_fun = True if metric in self._all_cell_funs else False
if cell_labels is None:
self.log("No cell labels given. Sorting cells using index.")
cell_fun = True if cell_function in self.all_cell_funs else False
idx = []
results = []
for trap_id, (mask_set, trap, lbl_set) in enumerate(
zip(masks, traps, labels.values())
for trap_id, (mask_set, trap, local_cell_labels) in enumerate(
zip(masks, traps, cell_labels.values())
):
# ignore empty traps
if len(mask_set):
# apply metric either a cell function or otherwise
result = self._all_funs[metric](mask_set, trap)
# find property from the tile
result = self.all_funs[cell_function](mask_set, trap)
if cell_fun:
# store results for each cell separately
for lbl, val in zip(lbl_set, result):
for cell_label, val in zip(local_cell_labels, result):
results.append(val)
idx.append((trap_id, lbl))
idx.append((trap_id, cell_label))
else:
# background (trap) function
results.append(result)
......@@ -307,68 +358,74 @@ class Extractor(StepABC):
res_idx = (tuple(results), tuple(idx))
return res_idx
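A toy sketch of the two-tuple returned for a cell function, assuming trap 0 holds cells labelled 1 and 2 and trap 1 is empty:

# results, idx = extractor.apply_cell_function(traps, masks, "mean", cell_labels)
# results -> (mean_of_cell_1, mean_of_cell_2)
# idx     -> ((0, 1), (0, 2))  # (trap_id, cell_label) pairs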
def extract_funs(
def apply_cell_funs(
self,
traps: t.List[np.array],
tiles: t.List[np.array],
masks: t.List[np.array],
metrics: t.List[str],
cell_funs: t.List[str],
**kwargs,
) -> t.Dict[str, pd.Series]:
"""
Return dict with metrics as key and metrics applied to data as values.
Return dict with cell_funs as keys and the corresponding results as values.
Data from one time point is used.
"""
d = {
metric: self.extract_traps(
traps=traps, masks=masks, metric=metric, **kwargs
cell_fun: self.apply_cell_function(
traps=tiles, masks=masks, cell_function=cell_fun, **kwargs
)
for metric in metrics
for cell_fun in cell_funs
}
return d
def reduce_extract(
self,
traps: np.ndarray,
tiles: np.ndarray,
masks: t.List[np.ndarray],
red_metrics: t.Dict[reduction_method, t.Collection[str]],
reduction_cell_funs: t.Dict[reduction_method, t.Collection[str]],
**kwargs,
) -> t.Dict[str, t.Dict[reduction_method, t.Dict[str, pd.Series]]]:
"""
Wrapper to apply reduction and then extraction.
Reduce to a 2D image and then extract.
Parameters
----------
tiles_data: array
tiles: array
An array of image data arranged as (tiles, X, Y, Z)
masks: list of arrays
An array of masks for each trap: one per cell at the trap
red_metrics: dict
dict for which keys are reduction functions and values are either a list or a set of strings giving the metric functions.
reduction_cell_funs: dict
An upper branch of the extraction tree: a dict for which
keys are reduction functions and values are either a list
or a set of strings giving the cell functions to apply.
For example: {'np_max': {'max5px', 'mean', 'median'}}
**kwargs: dict
All other arguments passed to Extractor.extract_funs.
All other arguments passed to Extractor.apply_cell_funs.
Returns
------
Dict of dataframes with the corresponding reductions and metrics nested.
"""
# create dict with keys naming the reduction in the z-direction and the reduced data as values
reduced_tiles_data = {}
if traps is not None:
for red_fun in red_metrics.keys():
reduced_tiles_data[red_fun] = [
self.reduce_dims(tile_data, method=RED_FUNS[red_fun])
for tile_data in traps
# create dict with keys naming the reduction in the z-direction
# and the reduced data as values
reduced_tiles = {}
if tiles is not None:
for reduction in reduction_cell_funs.keys():
reduced_tiles[reduction] = [
self.reduce_dims(
tile_data, method=REDUCTION_FUNS[reduction]
)
for tile_data in tiles
]
# calculate cell and tile properties
d = {
red_fun: self.extract_funs(
metrics=metrics,
traps=reduced_tiles_data.get(red_fun, [None for _ in masks]),
reduction: self.apply_cell_funs(
tiles=reduced_tiles.get(reduction, [None for _ in masks]),
masks=masks,
cell_funs=cell_funs,
**kwargs,
)
for red_fun, metrics in red_metrics.items()
for reduction, cell_funs in reduction_cell_funs.items()
}
return d
......@@ -392,13 +449,141 @@ class Extractor(StepABC):
reduced = reduce_z(img, method)
return reduced
def make_tree_dict(self, tree: extraction_tree):
"""Put extraction tree into a dict."""
if tree is None:
# use default
tree = self.params.tree
tree_dict = {
# the whole extraction tree
"tree": tree,
# the extraction tree for fluorescence channels
"channels_tree": {
ch: v for ch, v in tree.items() if ch != "general"
},
}
# tuple of the fluorescence channels
tree_dict["channels"] = (*tree_dict["channels_tree"],)
return tree_dict
def get_masks(self, tp, masks, cells):
"""Get the masks as a list with an array of masks for each trap."""
# find the cell masks for a given trap as a dict with trap_ids as keys
if masks is None:
raw_masks = cells.at_time(tp, kind="mask")
masks = {trap_id: [] for trap_id in range(cells.ntraps)}
for trap_id, trap_masks in raw_masks.items():
if len(trap_masks):
masks[trap_id] = np.stack(np.array(trap_masks)).astype(bool)
# convert to a list of masks
# one array of size (no cells, tile_size, tile_size) per trap
masks = [np.array(v) for v in masks.values()]
return masks
def get_cell_labels(self, tp, cell_labels, cells):
"""Get the cell labels per trap as a dict with trap_ids as keys."""
if cell_labels is None:
raw_cell_labels = cells.labels_at_time(tp)
cell_labels = {
trap_id: raw_cell_labels.get(trap_id, [])
for trap_id in range(cells.ntraps)
}
return cell_labels
def get_background_masks(self, masks, tile_size):
"""
Generate boolean background masks.
Combine masks per trap and then take the logical inverse.
"""
if self.params.sub_bg:
bgs = ~np.array(
list(
map(
# sum over masks for each cell
lambda x: (
np.sum(x, axis=0)
if np.any(x)
else np.zeros((tile_size, tile_size))
),
masks,
)
)
).astype(bool)
else:
bgs = np.array([])
return bgs
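The background mask is the logical inverse of the union of all cell masks in a trap; a minimal numpy sketch for one trap:

import numpy as np

cell_masks = np.zeros((2, 4, 4), dtype=bool)  # (no cells, tile, tile)
cell_masks[0, :2, :2] = True
cell_masks[1, 2:, 2:] = True
background = ~np.sum(cell_masks, axis=0).astype(bool)
# True wherever no cell was segmented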
def extract_one_channel(
self, tree_dict, cell_labels, img, img_bgsub, masks, **kwargs
):
"""Extract as dict all metrics requiring only a single channel."""
d = {}
for ch, reduction_cell_funs in tree_dict["tree"].items():
# extract from all images including bright field
d[ch] = self.reduce_extract(
# use None for "general"; no fluorescence image
tiles=img.get(ch, None),
masks=masks,
reduction_cell_funs=reduction_cell_funs,
cell_labels=cell_labels,
**kwargs,
)
if ch != "general":
# extract from background-corrected fluorescence images
d[ch + "_bgsub"] = self.reduce_extract(
tiles=img_bgsub[ch + "_bgsub"],
masks=masks,
reduction_cell_funs=reduction_cell_funs,
cell_labels=cell_labels,
**kwargs,
)
return d
def extract_multiple_channels(self, cell_labels, img, img_bgsub, masks):
"""Extract as a dict all metrics requiring multiple channels."""
# NB multichannel functions do not use tree_dict
available_channels = set(list(img.keys()) + list(img_bgsub.keys()))
d = {}
for multichannel_fun_name, (
channels,
reduction,
multichannel_function,
) in self.params.multichannel_ops.items():
common_channels = set(channels).intersection(available_channels)
# all required channels should be available
if len(common_channels) == len(channels):
for images, suffix in zip([img, img_bgsub], ["", "_bgsub"]):
# channels
channels_stack = np.stack(
[images[ch + suffix] for ch in channels],
axis=-1,
)
# reduce in Z
tiles = REDUCTION_FUNS[reduction](channels_stack, axis=1)
# set up dict
if multichannel_fun_name not in d:
d[multichannel_fun_name] = {}
if reduction not in d[multichannel_fun_name]:
d[multichannel_fun_name][reduction] = {}
# apply multichannel function
d[multichannel_fun_name][reduction][
multichannel_function + suffix
] = self.apply_cell_function(
tiles,
masks,
multichannel_function,
cell_labels,
)
return d
def extract_tp(
self,
tp: int,
tree: t.Optional[extraction_tree] = None,
tile_size: int = 117,
masks: t.Optional[t.List[np.ndarray]] = None,
labels: t.Optional[t.List[int]] = None,
cell_labels: t.Optional[t.List[int]] = None,
**kwargs,
) -> t.Dict[str, t.Dict[str, t.Dict[str, tuple]]]:
"""
......@@ -409,15 +594,15 @@ class Extractor(StepABC):
tp : int
Time point being analysed.
tree : dict
Nested dictionary indicating channels, reduction functions and
metrics to be used.
Nested dictionary indicating channels, reduction functions
and metrics to be used.
For example: {'general': {'None': ['area', 'volume', 'eccentricity']}}
tile_size : int
Size of the tile to be extracted.
masks : list of arrays
A list of masks per trap with each mask having dimensions (ncells, tile_size,
tile_size).
labels : dict
A list of masks per trap with each mask having dimensions
(ncells, tile_size, tile_size) and with one mask per cell.
cell_labels : dict
A dictionary with trap_ids as keys and cell_labels as values.
**kwargs : keyword arguments
Passed to extractor.reduce_extract.
......@@ -428,125 +613,86 @@ class Extractor(StepABC):
Dictionary of the results with three levels of dictionaries.
The first level has channels as keys.
The second level has reduction metrics as keys.
The third level has cell or background metrics as keys and a two-tuple as values.
The first tuple is the result of applying the metrics to a particular cell or trap; the second tuple is either (trap_id, cell_label) for a metric applied to a cell or a trap_id for a metric applied to a trap.
An example is d["GFP"]["np_max"]["mean"][0], which gives a tuple of the calculated mean GFP fluorescence for all cells.
The third level has cell or background metrics as keys and a
two-tuple as values.
The first tuple is the result of applying the metrics to a
particular cell or trap; the second tuple is either
(trap_id, cell_label) for a metric applied to a cell or a
trap_id for a metric applied to a trap.
An example is d["GFP"]["np_max"]["mean"][0], which gives a tuple
of the calculated mean GFP fluorescence for all cells.
"""
# TODO Can we split the different extraction types into sub-methods to make this easier to read?
if tree is None:
# use default
tree: extraction_tree = self.params.tree
# dictionary with channel: {reduction algorithm : metric}
ch_tree = {ch: v for ch, v in tree.items() if ch != "general"}
# tuple of the channels
tree_chs = (*ch_tree,)
# dict of information from extraction tree
tree_dict = self.make_tree_dict(tree)
# create a Cells object to extract information from the h5 file
cells = Cells(self.local)
# find the cell labels and store as dict with trap_ids as keys
if labels is None:
raw_labels = cells.labels_at_time(tp)
labels = {
trap_id: raw_labels.get(trap_id, [])
for trap_id in range(cells.ntraps)
}
# find the cell masks for a given trap as a dict with trap_ids as keys
if masks is None:
raw_masks = cells.at_time(tp, kind="mask")
masks = {trap_id: [] for trap_id in range(cells.ntraps)}
for trap_id, cells in raw_masks.items():
if len(cells):
masks[trap_id] = np.stack(np.array(cells)).astype(bool)
# convert to a list of masks
masks = [np.array(v) for v in masks.values()]
# find image data at the time point
# stored as an array arranged as (traps, channels, time points, X, Y, Z)
tiles = self.get_tiles(tp, tile_shape=tile_size, channels=tree_chs)
# generate boolean masks for background as a list with one mask per trap
bgs = np.array([])
if self.params.sub_bg:
bgs = ~np.array(
list(
map(
lambda x: np.sum(x, axis=0)
if np.any(x)
else np.zeros((tile_size, tile_size)),
masks,
)
)
).astype(bool)
# perform extraction by applying metrics
d = {}
self.img_bgsub = {}
for ch, red_metrics in tree.items():
cells = Cells(self.h5path)
# find the cell labels as dict with trap_ids as keys
cell_labels = self.get_cell_labels(tp, cell_labels, cells)
# get masks one per cell per trap
masks = self.get_masks(tp, masks, cells)
# find image data for all traps at the time point
# stored as an array arranged as (traps, channels, 1, Z, X, Y)
tiles = self.get_tiles(tp, channels=tree_dict["channels"])
# generate boolean masks for background for each trap
bgs = self.get_background_masks(masks, tile_size)
# get images and background corrected images as dicts
# with fluorescence channels as keys
img, img_bgsub = self.get_imgs_background_subtract(
tree_dict, tiles, bgs
)
# perform extraction
res_one = self.extract_one_channel(
tree_dict, cell_labels, img, img_bgsub, masks, **kwargs
)
res_multiple = self.extract_multiple_channels(
cell_labels, img, img_bgsub, masks
)
res = {**res_one, **res_multiple}
return res
def get_imgs_background_subtract(self, tree_dict, tiles, bgs):
"""
Get two dicts of fluorescence images.
Return images and background subtracted image for all traps
for one time point.
"""
img = {}
img_bgsub = {}
for ch, _ in tree_dict["channels_tree"].items():
# NB checking ch != "general" is necessary for threading
if ch != "general" and tiles is not None and len(tiles):
# image data for all traps and z sections for a particular channel
# as an array arranged as (tiles, Z, X, Y, )
img = tiles[:, tree_chs.index(ch), 0]
if tiles is not None and len(tiles):
# image data for all traps for a particular channel and
# time point arranged as (traps, Z, X, Y)
# we use 0 here to access the single time point available
img[ch] = tiles[:, tree_dict["channels"].index(ch), 0]
if (
bgs.any()
and ch in self.params.sub_bg
and img[ch] is not None
):
# subtract median background
bgsub_mapping = map(
# move Z to last column to allow subtraction
lambda tile, bg: np.moveaxis(tile, 0, -1)
# median of background over all pixels for each Z section
- bn.median(tile[:, bg], axis=1),
img[ch],
bgs,
)
# apply map and convert to array
mapping_result = np.stack(list(bgsub_mapping))
# move Z axis back to the second column
img_bgsub[ch + "_bgsub"] = np.moveaxis(
mapping_result, -1, 1
)
else:
img = None
# apply metrics to image data
d[ch] = self.reduce_extract(
traps=img,
masks=masks,
red_metrics=red_metrics,
labels=labels,
**kwargs,
)
# apply metrics to image data with the background subtracted
if bgs.any() and ch in self.params.sub_bg and img is not None:
# calculate metrics with subtracted bg
ch_bs = ch + "_bgsub"
# subtract median background
self.img_bgsub[ch_bs] = np.moveaxis(
np.stack(
list(
map(
lambda tile, mask: np.moveaxis(tile, 0, -1)
- bn.median(tile[:, mask], axis=1),
img,
bgs,
)
)
),
-1,
1,
) # End with tiles, z, y, x
# apply metrics to background-corrected data
d[ch_bs] = self.reduce_extract(
red_metrics=ch_tree[ch],
traps=self.img_bgsub[ch_bs],
masks=masks,
labels=labels,
**kwargs,
)
# apply any metrics using multiple channels, such as pH calculations
for name, (
chs,
merge_fun,
red_metrics,
) in self.params.multichannel_ops.items():
if len(
set(chs).intersection(
set(self.img_bgsub.keys()).union(tree_chs)
)
) == len(chs):
channels_stack = np.stack(
[self.get_imgs(ch, tiles, tree_chs) for ch in chs], axis=-1
)
merged = RED_FUNS[merge_fun](channels_stack, axis=-1)
d[name] = self.reduce_extract(
red_metrics=red_metrics,
traps=merged,
masks=masks,
labels=labels,
**kwargs,
)
return d
img[ch] = None
img_bgsub[ch] = None
return img, img_bgsub
def get_imgs(self, channel: t.Optional[str], tiles, channels=None):
def get_imgs_old(self, channel: t.Optional[str], tiles, channels=None):
"""
Return image from a correct source, either raw or bgsub.
......@@ -555,14 +701,16 @@ class Extractor(StepABC):
channel: str
Name of channel to get.
tiles: ndarray
An array of the image data having dimensions of (tile_id, channel, tp, tile_size, tile_size, n_zstacks).
An array of the image data having dimensions of
(tile_id, channel, tp, tile_size, tile_size, n_zstacks).
channels: list of str (optional)
t.List of available channels.
Returns
-------
img: ndarray
An array of image data with dimensions (no tiles, X, Y, no Z channels)
An array of image data with dimensions
(no tiles, X, Y, no Z channels)
"""
if channels is None:
channels = (*self.params.tree,)
......@@ -579,7 +727,9 @@ class Extractor(StepABC):
**kwargs,
) -> dict:
"""
Wrapper to add compatibility with other steps of the pipeline.
Run extraction for one position and for the specified time points.
Save the results to a h5 file.
Parameters
----------
......@@ -597,7 +747,9 @@ class Extractor(StepABC):
Returns
-------
d: dict
A dict of the extracted data with a concatenated string of channel, reduction metric, and cell metric as keys and pd.Series of the extracted data as values.
A dict of the extracted data for one position with a concatenated
string of channel, reduction metric, and cell metric as keys and
pd.DataFrame of the extracted data for all time points as values.
"""
if tree is None:
tree = self.params.tree
......@@ -614,7 +766,7 @@ class Extractor(StepABC):
to="series",
tp=tp,
)
# concatenate with data extracted from early time points
# concatenate with data extracted from earlier time points
for k in new.keys():
d[k] = pd.concat((d.get(k, None), new[k]), axis=1)
# add indices to pd.Series containing the extracted data
......@@ -628,12 +780,12 @@ class Extractor(StepABC):
d[k].index.names = idx
# save
if save:
self.save_to_hdf(d)
self.save_to_h5(d)
return d
def save_to_hdf(self, dict_series, path=None):
def save_to_h5(self, dict_series, path=None):
"""
Save the extracted data to the h5 file.
Save the extracted data for one position to the h5 file.
Parameters
----------
......@@ -643,7 +795,7 @@ class Extractor(StepABC):
To the h5 file.
"""
if path is None:
path = self.local
path = self.h5path
self.writer = Writer(path)
for extract_name, series in dict_series.items():
dset_path = "/extraction/" + extract_name
......@@ -660,7 +812,6 @@ class Extractor(StepABC):
}
### Helpers
def flatten_nesteddict(
nest: dict, to="series", tp: int = None
) -> t.Dict[str, pd.Series]:
......@@ -672,14 +823,17 @@ def flatten_nesteddict(
nest: dict of dicts
Contains the nested results of extraction.
to: str (optional)
Specifies the format of the output, either pd.Series (default) or a list
Specifies the format of the output, either pd.Series (default)
or a list
tp: int
Time point used to name the pd.Series
Returns
-------
d: dict
A dict with a concatenated string of channel, reduction metric, and cell metric as keys and either a pd.Series or a list of the corresponding extracted data as values.
A dict with a concatenated string of channel, reduction metric,
and cell metric as keys and either a pd.Series or a list of the
corresponding extracted data as values.
"""
d = {}
for k0, v0 in nest.items():
......@@ -689,14 +843,3 @@ def flatten_nesteddict(
pd.Series(*v2, name=tp) if to == "series" else v2
)
return d
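A toy sketch of the flattening, assuming a single channel, reduction, and metric:

# nest = {"GFP": {"max": {"mean": ((0.5, 0.7), ((0, 1), (0, 2)))}}}
# flatten_nesteddict(nest, to="series", tp=3)
# -> {"GFP/max/mean": pd.Series((0.5, 0.7), index=((0, 1), (0, 2)), name=3)}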
class hollowExtractor(Extractor):
"""
Extractor that only cares about receiving images and masks.
Used for testing.
"""
def __init__(self, parameters):
self.params = parameters
......@@ -90,9 +90,12 @@ def max2p5pc(cell_mask, trap_image) -> float:
return np.mean(top_values)
def max5px(cell_mask, trap_image) -> float:
def max5px_median(cell_mask, trap_image) -> float:
"""
Find the mean of the five brightest pixels in the cell.
Estimate the degree of localisation.
Find the mean of the five brightest pixels in the cell divided by the
median of all pixels.
Parameters
----------
......@@ -102,10 +105,17 @@ def max5px(cell_mask, trap_image) -> float:
"""
# sort pixels in cell
pixels = trap_image[cell_mask]
top_values = bn.partition(pixels, len(pixels) - 5)[-5:]
# find mean of five brightest pixels
max5px = np.mean(top_values)
return max5px
if len(pixels) > 5:
top_values = bn.partition(pixels, len(pixels) - 5)[-5:]
# find mean of five brightest pixels
max5px = np.mean(top_values)
med = np.median(pixels)
if med == 0:
return np.nan
else:
return max5px / med
else:
return np.nan
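A runnable toy example of the metric, assuming the module's imports:

import numpy as np

trap_image = np.arange(36, dtype=float).reshape(6, 6)
cell_mask = np.ones((6, 6), dtype=bool)
# the mean of the five brightest pixels is 33 and the median of all
# pixels is 17.5, so the result is 33 / 17.5
print(max5px_median(cell_mask, trap_image))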
def std(cell_mask, trap_image):
......@@ -193,3 +203,53 @@ def min_maj_approximation(cell_mask) -> t.Tuple[int]:
# + distance from the center of cone top to edge of cone top
maj_ax = np.round(np.max(dn) + np.sum(cone_top) / 2)
return min_ax, maj_ax
def moment_of_inertia(cell_mask, trap_image):
"""
Find moment of inertia - a measure of homogeneity.
From iopscience.iop.org/article/10.1088/1742-6596/1962/1/012028
which cites ieeexplore.ieee.org/document/1057692.
"""
# set pixels not in cell to zero
trap_image[~cell_mask] = 0
x = trap_image
if np.any(x):
# x-axis : column=x-axis
columnvec = np.arange(1, x.shape[1] + 1, 1)[:, None].T
# y-axis : row=y-axis
rowvec = np.arange(1, x.shape[0] + 1, 1)[:, None]
# find raw moments
M00 = np.sum(x)
M10 = np.sum(np.multiply(x, columnvec))
M01 = np.sum(np.multiply(x, rowvec))
# find centroid
Xm = M10 / M00
Ym = M01 / M00
# find central moments
Mu00 = M00
Mu20 = np.sum(np.multiply(x, (columnvec - Xm) ** 2))
Mu02 = np.sum(np.multiply(x, (rowvec - Ym) ** 2))
# find invariants
Eta20 = Mu20 / Mu00 ** (1 + (2 + 0) / 2)
Eta02 = Mu02 / Mu00 ** (1 + (0 + 2) / 2)
# find moments of inertia
moi = Eta20 + Eta02
return moi
else:
return np.nan
def ratio(cell_mask, trap_image):
"""Find the median ratio between two fluorescence channels."""
if trap_image.ndim == 3 and trap_image.shape[-1] == 2:
fl_0 = trap_image[..., 0][cell_mask]
fl_1 = trap_image[..., 1][cell_mask]
if np.any(fl_1 == 0):
div = np.nan
else:
div = np.median(fl_0 / fl_1)
else:
div = np.nan
return div
""" How to do the nuc Est Conv from MATLAB
"""
How to do the nuc Est Conv from MATLAB
Based on the code in MattSegCode/Matt Seg
GUI/@timelapseTraps/extractCellDataStacksParfor.m
Especially lines 342 to 399.
Especially lines 342 to 399.
This part only replicates the method to get the nuc_est_conv values
"""
import typing as t
......
# File with defaults for ease of use
import re
import typing as t
from pathlib import Path
import h5py
# should we move these functions here?
from aliby.tile.tiler import find_channel_name
def exparams_from_meta(
meta: t.Union[dict, Path, str], extras: t.Collection[str] = ["ph"]
):
"""
Obtain parameters from metadata of the h5 file.
Compares a list of candidate channels using case-insensitive
REGEX to identify valid channels.
"""
meta = meta if isinstance(meta, dict) else load_metadata(meta)
base = {
"tree": {"general": {"None": ["area", "volume", "eccentricity"]}},
"multichannel_ops": {},
}
candidate_channels = {
"Citrine",
"GFP",
"GFPFast",
"mCherry",
"pHluorin405",
"pHluorin488",
"Flavin",
"Cy5",
"mKO2",
}
default_reductions = {"max"}
default_metrics = {
"mean",
"median",
"std",
"imBackground",
"max5px",
# "nuc_est_conv",
}
# define ratiometric combinations
# key is numerator and value is denominator
# add more to support additional channel names
ratiometric_combinations = {"phluorin405": ("phluorin488", "gfpfast")}
default_reduction_metrics = {
r: default_metrics for r in default_reductions
}
# default_rm["None"] = ["nuc_conv_3d"] # Uncomment this to add nuc_conv_3d (slow)
extant_fluorescence_ch = []
for av_channel in candidate_channels:
# find matching channels in metadata
found_channel = find_channel_name(meta.get("channels", []), av_channel)
if found_channel is not None:
extant_fluorescence_ch.append(found_channel)
for ch in extant_fluorescence_ch:
base["tree"][ch] = default_reduction_metrics
base["sub_bg"] = extant_fluorescence_ch
# additional extraction defaults if the channels are available
if "ph" in extras:
# SWAINLAB specific names
# find first valid combination of ratiometric fluorescence channels
numerator_channel, denominator_channel = (None, None)
for ch1, chs2 in ratiometric_combinations.items():
found_channel1 = find_channel_name(extant_fluorescence_ch, ch1)
if found_channel1 is not None:
numerator_channel = found_channel1
for ch2 in chs2:
found_channel2 = find_channel_name(
extant_fluorescence_ch, ch2
)
if found_channel2:
denominator_channel = found_channel2
break
# if two compatible ratiometric channels are available
if numerator_channel is not None and denominator_channel is not None:
sets = {
b + a: (x, y)
for a, x in zip(
["", "_bgsub"],
(
[numerator_channel, denominator_channel],
[
f"{numerator_channel}_bgsub",
f"{denominator_channel}_bgsub",
],
),
)
for b, y in zip(["em_ratio", "gsum"], ["div0", "add"])
}
for i, v in sets.items():
base["multichannel_ops"][i] = [
*v,
default_reduction_metrics,
]
return base
def load_metadata(file: t.Union[str, Path], group="/"):
"""Get meta data from an h5 file."""
with h5py.File(file, "r") as f:
meta = dict(f[group].attrs.items())
return meta
......@@ -44,5 +44,6 @@ def reduce_z(trap_image: np.ndarray, fun: t.Callable, axis: int = 0):
elif isinstance(fun, np.ufunc):
# optimise the reduction function if possible
return fun.reduce(trap_image, axis=axis)
else: # WARNING: Very slow, only use when no alternatives exist
else:
# WARNING: Very slow, only use when no alternatives exist
return np.apply_along_axis(fun, axis, trap_image)
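The special-casing of np.ufunc exists because ufunc.reduce runs as a single C-level loop, whereas np.apply_along_axis calls the function once per 1D slice in Python; a small sketch of the equivalence:

import numpy as np

stack = np.random.rand(5, 64, 64)  # (z, y, x)
fast = np.maximum.reduce(stack, axis=0)    # the ufunc path in reduce_z
slow = np.apply_along_axis(max, 0, stack)  # the generic fallback
assert np.allclose(fast, slow)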
......@@ -11,8 +11,10 @@ from extraction.core.functions.math_utils import div0
"""
Load functions for analysing cells and their background.
Note that inspect.getmembers returns a list of function names and functions,
and inspect.getfullargspec returns a function's arguments.
Note that inspect.getmembers returns a list of function names
and functions, and inspect.getfullargspec returns a
function's arguments.
"""
......@@ -66,7 +68,7 @@ def load_cellfuns():
# create dict of the core functions from cell.py - these functions apply to a single mask
cell_funs = load_cellfuns_core()
# create a dict of functions that apply the core functions to an array of cell_masks
CELLFUNS = {}
CELL_FUNS = {}
for f_name, f in cell_funs.items():
if isfunction(f):
......@@ -79,27 +81,27 @@ def load_cellfuns():
# function that applies f to m and img, the trap_image
return lambda m, img: trap_apply(f, m, img)
CELLFUNS[f_name] = tmp(f)
return CELLFUNS
CELL_FUNS[f_name] = tmp(f)
return CELL_FUNS
def load_trapfuns():
"""Load functions that are applied to an entire tile."""
TRAPFUNS = {
TRAP_FUNS = {
f[0]: f[1]
for f in getmembers(trap)
if isfunction(f[1])
and f[1].__module__.startswith("extraction.core.functions")
}
return TRAPFUNS
return TRAP_FUNS
def load_funs():
"""Combine all automatically loaded functions."""
CELLFUNS = load_cellfuns()
TRAPFUNS = load_trapfuns()
CELL_FUNS = load_cellfuns()
TRAP_FUNS = load_trapfuns()
# return dict of cell funs, dict of trap funs, and dict of both
return CELLFUNS, TRAPFUNS, {**TRAPFUNS, **CELLFUNS}
return CELL_FUNS, TRAP_FUNS, {**TRAP_FUNS, **CELL_FUNS}
def load_redfuns(
......
......@@ -53,17 +53,58 @@ from pyparsing import (
atomic = t.Union[str, int, float, bool]
# specify grammar for the Swain lab
sl_grammar = {
"general": {
"start_trigger": Literal("Swain Lab microscope experiment log file"),
"data_type": "fields",
"end_trigger": "-----Acquisition settings-----",
},
"image_config": {
"start_trigger": "Image Configs:",
"data_type": "table",
},
"device_properties": {
"start_trigger": "Device properties:",
"data_type": "table",
},
"group": {
"position": {
"start_trigger": Group(
Group(Literal("group:") + Word(printables))
+ Group(Literal("field:") + "position")
),
"data_type": "table",
},
**{
key: {
"start_trigger": Group(
Group(Literal("group:") + Word(printables))
+ Group(Literal("field:") + key)
),
"data_type": "fields",
}
for key in ("time", "config")
},
},
}
ACQ_START = "-----Acquisition settings-----"
HEADER_END = "-----Experiment started-----"
MAX_NLINES = 2000 # In case of malformed logfile
ParserElement.setDefaultWhitespaceChars(" \t")
class HeaderEndNotFound(Exception):
def __init__(self, message, errors):
super().__init__(message)
self.errors = errors
def extract_header(filepath: Path):
# header_contents = ""
with open(filepath, "r") as f:
"""Extract content of log file before the experiment starts."""
with open(filepath, "r", errors="ignore") as f:
try:
header = ""
for _ in range(MAX_NLINES):
......@@ -72,16 +113,50 @@ def extract_header(filepath: Path):
if HEADER_END in line:
break
except HeaderEndNotFound as e:
print(f"{MAX_NLINES} checked and no header found")
print(f"{MAX_NLINES} checked and no header found.")
raise (e)
return header
def parse_from_swainlab_grammar(filepath: t.Union[str, Path]):
"""Parse using a grammar for the Swain lab."""
return parse_from_grammar(filepath, sl_grammar)
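A usage sketch (the log path is hypothetical); the keys of the result mirror the grammar, with `group` subkeys flattened by an underscore:

metadata = parse_from_swainlab_grammar("experiment.log")  # hypothetical path
print(sorted(metadata))
# ['device_properties', 'general', 'group_config', 'group_position',
#  'group_time', 'image_config']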
def parse_from_grammar(filepath: str, grammar: t.Dict):
"""Parse a file using the specified grammar."""
header = extract_header(filepath)
d = {}
for key, values in grammar.items():
try:
if "data_type" in values:
# data_type for parse_x defined in values
d[key] = parse_x(header, **values)
else:
# use sub keys to parse groups
for subkey, subvalues in values.items():
subkey = "_".join((key, subkey))
d[subkey] = parse_x(header, **subvalues)
except Exception as e:
logging.getLogger("aliby").critical(
f"Parsing failed for key {key} and values {values}."
)
raise (e)
return d
def parse_x(string, data_type, **kwargs):
"""Parse a string for data of a specified type."""
res_dict = eval(f"parse_{data_type}(string, **kwargs)")
return res_dict
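`parse_x` dispatches with `eval`, which is safe here only because callers pass trusted grammar values. An equivalent lookup-based alternative, shown as a sketch rather than what the module does:

def parse_x_alt(string, data_type, **kwargs):
    """Dispatch to parse_<data_type> without eval."""
    parsers = {"table": parse_table, "fields": parse_fields}
    return parsers[data_type](string, **kwargs)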
def parse_table(
string: str,
start_trigger: t.Union[str, Keyword],
) -> pd.DataFrame:
"""Parse csv-like table
"""
Parse csv-like table.
Parameters
----------
......@@ -98,12 +173,9 @@ def parse_table(
Examples
--------
>>> table = parse_table(header, start_trigger="Image Configs:")
"""
if isinstance(start_trigger, str):
start_trigger: Keyword = Keyword(start_trigger)
EOL = LineEnd().suppress()
field = OneOrMore(CharsNotIn(":,\n"))
line = LineStart() + Group(
......@@ -116,11 +188,9 @@ def parse_table(
+ EOL # end_trigger.suppress()
)
parser_result = parser.search_string(string)
assert all(
[len(row) == len(parser_result[0]) for row in parser_result]
), f"Table {start_trigger} has unequal number of columns"
assert len(parser_result), f"Parsing is empty. {parser}"
return table_to_df(parser_result.as_list())
......@@ -139,16 +209,12 @@ def parse_fields(
start: 0
interval: 300
frames: 180
"""
EOL = LineEnd().suppress()
if end_trigger is None:
end_trigger = EOL
elif isinstance(end_trigger, str):
end_trigger = Literal(end_trigger)
field = OneOrMore(CharsNotIn(":\n"))
line = (
LineStart()
......@@ -164,79 +230,6 @@ def parse_fields(
return fields_to_dict_or_table(results)
# Grammar specification
grammar = {
"general": {
"start_trigger": Literal("Swain Lab microscope experiment log file"),
"type": "fields",
"end_trigger": "-----Acquisition settings-----",
},
"image_config": {
"start_trigger": "Image Configs:",
"type": "table",
},
"device_properties": {
"start_trigger": "Device properties:",
"type": "table",
},
"group": {
"position": {
"start_trigger": Group(
Group(Literal("group:") + Word(printables))
+ Group(Literal("field:") + "position")
),
"type": "table",
},
**{
key: {
"start_trigger": Group(
Group(Literal("group:") + Word(printables))
+ Group(Literal("field:") + key)
),
"type": "fields",
}
for key in ("time", "config")
},
},
}
ACQ_START = "-----Acquisition settings-----"
HEADER_END = "-----Experiment started-----"
MAX_NLINES = 2000 # In case of malformed logfile
# test_file = "/home/alan/Downloads/pH_med_to_low.log"
# test_file = "/home/alan/Documents/dev/skeletons/scripts/dev/C1_60x.log"
ParserElement.setDefaultWhitespaceChars(" \t")
def parse_from_grammar(filepath: str, grammar: t.Dict):
header = extract_header(filepath)
d = {}
for key, values in grammar.items():
try:
if "type" in values:
d[key] = parse_x(header, **values)
else: # Use subkeys to parse groups
for subkey, subvalues in values.items():
subkey = "_".join((key, subkey))
d[subkey] = parse_x(header, **subvalues)
except Exception as e:
logging.getLogger("aliby").critical(
f"Parsing failed for key {key} and values {values}"
)
raise (e)
return d
def table_to_df(result: t.List[t.List]):
if len(result) > 1: # Multiple tables with ids to append
# Generate multiindex from "Name column"
......@@ -292,12 +285,3 @@ def _cast_type(x: str) -> t.Union[str, int, float, bool]:
except:
pass
return x
def parse_x(string: str, type: str, **kwargs):
return eval(f"parse_{type}(string, **kwargs)")
def parse_from_swainlab_grammar(filepath: t.Union[str, Path]):
return parse_from_grammar(filepath, grammar)
#!/usr/bin/env jupyter
import re
import typing as t
from copy import copy
import pandas as pd
from agora.io.signal import Signal
from agora.utils.kymograph import bidirectional_retainment_filter
from postprocessor.core.abc import get_process
class Chainer(Signal):
"""
Extend Signal by applying post-processes and allowing composite signals that combine basic signals.
It "chains" multiple processes upon fetching a dataset to produce the desired datasets.
Instead of reading processes previously applied, it executes
them when called.
"""
_synonyms = {
"m5m": ("extraction/GFP/max/max5px", "extraction/GFP/max/median")
}
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
def replace_path(path: str, bgsub: bool = ""):
# function to add bgsub to paths
channel = path.split("/")[1]
suffix = "_bgsub" if bgsub else ""
path = re.sub(channel, f"{channel}{suffix}", path)
return path
# add chains with and without bgsub for composite statistics;
# bind loop variables as lambda defaults to avoid late binding
self.common_chains = {
    alias
    + bgsub: lambda numerator=numerator, denominator=denominator, bgsub=bgsub, **kwargs: self.get(
        replace_path(numerator, bgsub), **kwargs
    )
    / self.get(replace_path(denominator, bgsub), **kwargs)
    for alias, (numerator, denominator) in self._synonyms.items()
    for bgsub in ("", "_bgsub")
}
def get(
self,
dataset: str,
chain: t.Collection[str] = ("standard", "interpolate", "savgol"),
in_minutes: bool = True,
stages: bool = True,
retain: t.Optional[float] = None,
**kwargs,
):
"""Load data from an h5 file."""
if dataset in self.common_chains:
# get dataset for composite chains
data = self.common_chains[dataset](**kwargs)
else:
# use Signal's get_raw
data = self.get_raw(dataset, in_minutes=in_minutes, lineage=True)
if chain:
data = self.apply_chain(data, chain, **kwargs)
if retain:
# keep data only from early time points
data = self.get_retained(data, retain)
if stages and "stage" not in data.columns.names:
# return stages as additional column level
stages_index = [
x
for i, (name, span) in enumerate(self.stages_span_tp)
for x in (f"{i} { name }",) * span
]
data.columns = pd.MultiIndex.from_tuples(
zip(stages_index, data.columns),
names=("stage", "time"),
)
return data
def apply_chain(
self, input_data: pd.DataFrame, chain: t.Tuple[str, ...], **kwargs
):
"""
Apply a series of processes to a data set.
Like postprocessing, Chainer consecutively applies processes.
Parameters can be passed as kwargs.
Chainer does not support applying the same process multiple times with different parameters.
Parameters
----------
input_data : pd.DataFrame
Input data to process.
chain : t.Tuple[str, ...]
Tuple of strings with the names of the processes
**kwargs : kwargs
Arguments passed on to Process.as_function() method to modify the parameters.
Examples
--------
FIXME: Add docs.
"""
result = copy(input_data)
self._intermediate_steps = []
for process in chain:
if process == "standard":
result = bidirectional_retainment_filter(result)
else:
params = kwargs.get(process, {})
process_cls = get_process(process)
result = process_cls.as_function(result, **params)
process_type = process_cls.__module__.split(".")[-2]
if process_type == "reshapers":
if process == "merger":
raise (NotImplementedError)
self._intermediate_steps.append(result)
return result
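A sketch of driving Chainer, assuming an aliby h5 file (the path is hypothetical; the dataset names follow `_synonyms` above):

chainer = Chainer("experiment.h5")  # hypothetical file
# composite signal from _synonyms: max5px / median for GFP
m5m = chainer.get("m5m")
# a raw dataset run through an explicit chain of processes
gfp = chainer.get("extraction/GFP/max/median", chain=("standard",))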
......@@ -46,18 +46,16 @@ def get_process(process, suffix="") -> PostProcessABC or ParametersABC or None:
_to_snake_case(process),
_to_pascal_case(_to_snake_case(process)),
)
found = None
for possible_location, process_syntax in product(
possible_locations, valid_syntaxes
):
location = f"{base_location}.{possible_location}.{_to_snake_case(process)}.{process_syntax}{suffix}"
# instantiate class but not a class object
found = locate(location)
if found is not None:
break
else:
raise Exception(
f"{process} not found in locations {possible_locations} at {base_location}"
)
......
"""
Functions to process, filter and merge tracks.
"""
Functions to process, filter, and merge tracks.
We call two tracks contiguous if they are adjacent in time: the
maximal time point of one is one time point less than the
minimal time point of the other.
A right track can have multiple potential left tracks. We pick the best.
"""
import typing as t
from copy import copy
......@@ -17,8 +21,385 @@ from utils_find_1st import cmp_larger, find_1st
from postprocessor.core.processes.savgol import non_uniform_savgol
def get_merges(tracks, smooth=False, tol=0.2, window=5, degree=3) -> dict:
"""
Find all pairs of tracks that should be joined.
Each track is defined by (trap_id, cell_id).
If there are multiple choices of which, say, left tracks to join to a
right track, pick the best using the Signal values to do so.
To score two tracks, we predict the future value of a left track and
compare with the mean initial values of a right track.
Parameters
----------
tracks: pd.DataFrame
A Signal, usually area, where rows are cell tracks and columns are
time points.
smooth: boolean
If True, smooth tracks with a savgol_filter.
tol: float < 1 or int
If int, compare the absolute distance between predicted values
for the left and right end points of two contiguous tracks.
If float, compare the distance relative to the magnitude of the
end point of the left track.
window: int
Length of window used for predictions and for any savgol_filter.
degree: int
The degree of the polynomial used by the savgol_filter.
"""
# only consider time series with more than two non-NaN data points
tracks = tracks.loc[tracks.notna().sum(axis=1) > 2]
# get contiguous tracks
if smooth:
# specialise to tracks with growing cells and of long duration
clean = clean_tracks(tracks, min_duration=window + 1, min_gr=0.9)
contigs = clean.groupby(["trap"]).apply(get_contiguous_pairs)
else:
contigs = tracks.groupby(["trap"]).apply(get_contiguous_pairs)
# remove traps with no contiguous tracks
contigs = contigs.loc[contigs.apply(len) > 0]
# flatten to (trap, cell_id) pairs
flat = set([k for v in contigs.values for i in v for j in i for k in j])
# make a data frame of contiguous tracks with the tracks as arrays
if smooth:
smoothed_tracks = clean.loc[flat].apply(
lambda x: non_uniform_savgol(x.index, x.values, window, degree),
axis=1,
)
else:
smoothed_tracks = tracks.loc[flat].apply(
lambda x: np.array(x.values), axis=1
)
# get the Signal values for neighbouring end points of contiguous tracks
actual_edge_values = contigs.apply(
lambda x: get_edge_values(x, smoothed_tracks)
)
# get the predicted values
predicted_edge_values = contigs.apply(
lambda x: get_predicted_edge_values(x, smoothed_tracks, window)
)
# score predicted edge values: low values are best
prediction_scores = predicted_edge_values.apply(get_dMetric_wrap)
# find contiguous tracks to join for each trap
trap_contigs_to_join = []
for idx in contigs.index:
local_contigs = contigs.loc[idx]
# find indices of best left and right tracks to join
best_indices = find_best_from_scores_wrap(
prediction_scores.loc[idx], actual_edge_values.loc[idx], tol=tol
)
# find tracks from the indices
trap_contigs_to_join.append(
[
(contig[0][left], contig[1][right])
for best_index, contig in zip(best_indices, local_contigs)
for (left, right) in best_index
if best_index
]
)
# return only the pairs of contiguous tracks
contigs_to_join = [
contigs
for trap_tracks in trap_contigs_to_join
for contigs in trap_tracks
]
merges = np.array(contigs_to_join, dtype=int)
return merges
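A minimal sketch of `get_merges` on a toy Signal-like frame; the index names are invented to satisfy the `groupby(["trap"])` calls:

import numpy as np
import pandas as pd

idx = pd.MultiIndex.from_tuples([(1, 1), (1, 2)], names=["trap", "cell_label"])
tracks = pd.DataFrame(
    [
        [2.0, 3.0, 4.0, np.nan, np.nan, np.nan],  # left track, ends at tp 2
        [np.nan, np.nan, np.nan, 5.0, 6.0, 7.0],  # right track, starts at tp 3
    ],
    index=idx,
)
merges = get_merges(tracks, smooth=False, tol=0.5)
# expect a single merge pairing (trap 1, cell 1) with (trap 1, cell 2)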
def clean_tracks(
tracks, min_duration: int = 15, min_gr: float = 1.0
) -> pd.DataFrame:
"""Remove small non-growing tracks and return the reduced data frame."""
ntps = tracks.apply(max_ntps, axis=1)
grs = tracks.apply(get_avg_gr, axis=1)
growing_long_tracks = tracks.loc[(ntps >= min_duration) & (grs > min_gr)]
return growing_long_tracks
def get_contiguous_pairs(tracks: pd.DataFrame) -> list:
"""
Get all pairs of contiguous track ids from a tracks data frame.
For two tracks to be contiguous, they must be exactly adjacent.
Parameters
----------
tracks: pd.Dataframe
A dataframe where rows are cell tracks and columns are time
points.
"""
# TODO add support for skipping time points
# find time points bounding tracks of non-NaN values
mins, maxs = [
tracks.notna().apply(np.where, axis=1).apply(fn)
for fn in (np.min, np.max)
]
# flip so that time points become the index
mins_d = mins.groupby(mins).apply(lambda x: x.index.tolist())
maxs_d = maxs.groupby(maxs).apply(lambda x: x.index.tolist())
# reduce minimal time point to make a right track overlap with a left track
mins_d.index = mins_d.index - 1
# find common end points
common = sorted(set(mins_d.index).intersection(maxs_d.index), reverse=True)
contigs = [(maxs_d[t], mins_d[t]) for t in common]
return contigs
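A worked example of the adjacency test, with an invented two-track trap:

import numpy as np
import pandas as pd

toy = pd.DataFrame(
    [[2.0, 3.0, np.nan, np.nan], [np.nan, np.nan, 5.0, 6.0]],
    index=pd.MultiIndex.from_tuples(
        [(1, 1), (1, 2)], names=["trap", "cell_label"]
    ),
)
# the left track ends at tp 1 and the right starts at tp 2 = 1 + 1
print(get_contiguous_pairs(toy))  # [([(1, 1)], [(1, 2)])]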
def get_edge_values(contigs_ids, smoothed_tracks):
"""Get Signal values for adjacent end points for each contiguous track."""
values = [
(
[get_value(smoothed_tracks.loc[pre_id], -1) for pre_id in pre_ids],
[
get_value(smoothed_tracks.loc[post_id], 0)
for post_id in post_ids
],
)
for pre_ids, post_ids in contigs_ids
]
return values
def get_predicted_edge_values(contigs_ids, smoothed_tracks, window):
"""
Find neighbouring values of two contiguous tracks.
Predict the next value for the leftmost track using window values
and find the mean of the initial window values of the rightmost
track.
"""
result = []
for pre_ids, post_ids in contigs_ids:
pre_res = []
# left contiguous tracks
for pre_id in pre_ids:
# get last window values of a track
y = get_values_i(smoothed_tracks.loc[pre_id], -window)
# predict next value
pre_res.append(
np.poly1d(np.polyfit(range(len(y)), y, 1))(len(y) + 1),
)
# right contiguous tracks
pos_res = [
# mean value of initial window values of a track
get_mean_value_i(smoothed_tracks.loc[post_id], window)
for post_id in post_ids
]
result.append([pre_res, pos_res])
return result
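The prediction is a plain linear extrapolation: fit a degree-1 polynomial to the window and evaluate it past the end, exactly as the `np.poly1d(np.polyfit(...))` call above does. In isolation:

import numpy as np

y = np.array([10.0, 12.0, 14.0, 16.0, 18.0])  # last `window` values of a left track
predicted = np.poly1d(np.polyfit(range(len(y)), y, 1))(len(y) + 1)
print(predicted)  # 22.0: the fitted line y = 10 + 2x evaluated at x = 6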
def get_dMetric_wrap(lst: List, **kwargs):
"""Calculate dMetric on a list."""
return [get_dMetric(*sublist, **kwargs) for sublist in lst]
def get_dMetric(pre_values: List[float], post_values: List[float]):
"""
Calculate a scoring matrix from the differences between two sets of
Signal values.
We generate one score per pair of contiguous tracks.
Lower scores are better.
Parameters
----------
pre_values: list of floats
Values of the Signal for left contiguous tracks.
post_values: list of floats
Values of the Signal for right contiguous tracks.
"""
if len(pre_values) > len(post_values):
dMetric = np.abs(np.subtract.outer(post_values, pre_values))
else:
dMetric = np.abs(np.subtract.outer(pre_values, post_values))
# replace NaNs with maximal values
dMetric[np.isnan(dMetric)] = 1 + np.nanmax(dMetric)
return dMetric
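The scoring matrix is an outer absolute difference, with the shorter list kept on the first axis. For instance:

import numpy as np

pre_values = [6.0]  # predicted next values of left tracks
post_values = [6.1, 9.0]  # mean initial values of right tracks
print(get_dMetric(pre_values, post_values))
# [[0.1 3. ]] - the (left, right) pair scoring 0.1 is the best candidate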
def find_best_from_scores_wrap(dMetric: List, edges: List, **kwargs):
"""Calculate solve_matrices on a list."""
return [
find_best_from_scores(mat, edgeset, **kwargs)
for mat, edgeset in zip(dMetric, edges)
]
def find_best_from_scores(
scores: np.ndarray, actual_edge_values: List, tol: Union[float, int] = 1
):
"""Find indices for left and right contiguous tracks with scores below a tolerance."""
ids = find_best_indices(scores)
if len(ids[0]):
pre_value, post_value = actual_edge_values
# score with relative or absolute distance
norm = (
np.array(pre_value)[ids[len(pre_value) > len(post_value)]]
if tol < 1
else 1
)
best_scores = scores[ids] / norm
ids = ids if len(pre_value) < len(post_value) else ids[::-1]
# keep only indices with best_score less than the tolerance
indices = [
idx for idx, score in zip(zip(*ids), best_scores) if score <= tol
]
return indices
else:
return []
def find_best_indices(dMetric):
"""Find indices for left and right contiguous tracks with minimal scores."""
glob_is = []
glob_js = []
if (~np.isnan(dMetric)).any():
lMetric = copy(dMetric)
sortedMetric = sorted(lMetric[~np.isnan(lMetric)])
while (~np.isnan(sortedMetric)).any():
# indices of point with the lowest score
i_s, j_s = np.where(lMetric == sortedMetric[0])
i = i_s[0]
j = j_s[0]
# store this point
glob_is.append(i)
glob_js.append(j)
# remove from lMetric
lMetric[i, :] += np.nan
lMetric[:, j] += np.nan
sortedMetric = sorted(lMetric[~np.isnan(lMetric)])
indices = (np.array(glob_is), np.array(glob_js))
return indices
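`find_best_indices` is a greedy assignment: take the globally smallest score, knock out its row and column, and repeat. Greedy is not always globally optimal, as this sketch shows:

import numpy as np

scores = np.array([[0.1, 0.5], [0.2, 0.9]])
rows, cols = find_best_indices(scores)
print(rows, cols)  # [0 1] [0 1]
# greedy takes 0.1 first and is then forced onto 0.9 (total 1.0),
# although pairing 0.5 with 0.2 would cost only 0.7 overall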
def get_value(x, n):
"""Get value from an array ignoring NaN."""
val = x[~np.isnan(x)][n] if len(x[~np.isnan(x)]) else np.nan
return val
def get_mean_value_i(x, i):
"""Get track's mean Signal value from values either from or up to an index."""
if not len(x[~np.isnan(x)]):
return np.nan
else:
if i > 0:
v = x[~np.isnan(x)][:i]
else:
v = x[~np.isnan(x)][i:]
return np.nanmean(v)
def get_values_i(x, i):
"""Get track's Signal values either from or up to an index."""
if not len(x[~np.isnan(x)]):
return np.nan
else:
if i > 0:
v = x[~np.isnan(x)][:i]
else:
v = x[~np.isnan(x)][i:]
return v
def get_avg_gr(track: pd.Series) -> float:
"""Get average growth rate for a track."""
ntps = max_ntps(track)
vals = track.dropna().values
gr = (vals[-1] - vals[0]) / ntps
return gr
######################################################################
def get_joinable_original(
tracks, smooth=False, tol=0.1, window=5, degree=3
) -> dict:
"""
Get the pairs of tracks (without repeats) that have a smaller error than the
tolerance. If a track can be assigned to two or more others, choose the one
with the lowest error.
Parameters
----------
tracks: (m x n) Signal
A Signal, usually area, dataframe where rows are cell tracks and
columns are time points.
tol: float or int
threshold of average (prediction error/std) necessary
to consider two tracks the same. If float is fraction of first track,
if int it is absolute units.
window: int
value of window used for savgol_filter
degree: int
value of polynomial degree passed to savgol_filter
"""
# only consider time series with more than two non-NaN data points
tracks = tracks.loc[tracks.notna().sum(axis=1) > 2]
# get contiguous tracks
if smooth:
# specialise to tracks with growing cells and of long duration
clean = clean_tracks(tracks, min_duration=window + 1, min_gr=0.9)
contigs = clean.groupby(["trap"]).apply(get_contiguous_pairs)
else:
contigs = tracks.groupby(["trap"]).apply(get_contiguous_pairs)
# remove traps with no contiguous tracks
contigs = contigs.loc[contigs.apply(len) > 0]
# flatten to (trap, cell_id) pairs
flat = set([k for v in contigs.values for i in v for j in i for k in j])
# make a data frame of contiguous tracks with the tracks as arrays
if smooth:
smoothed_tracks = clean.loc[flat].apply(
lambda x: non_uniform_savgol(x.index, x.values, window, degree),
axis=1,
)
else:
smoothed_tracks = tracks.loc[flat].apply(
lambda x: np.array(x.values), axis=1
)
# get the Signal values for neighbouring end points of contiguous tracks
actual_edge_values = contigs.apply(
lambda x: get_edge_values(x, smoothed_tracks)
)
# get the predicted values
predicted_edge_values = contigs.apply(
lambda x: get_predicted_edge_values(x, smoothed_tracks, window)
)
# score predicted edge values
prediction_scores = predicted_edge_values.apply(get_dMetric_wrap, tol=tol)
solutions = [
# for all sets of contigs at a trap
find_best_from_scores_wrap(cost, edge_values, tol=tol)
for (trap_id, cost), edge_values in zip(
prediction_scores.items(), actual_edge_values
)
]
closest_pairs = pd.Series(
solutions,
index=prediction_scores.index,
)
# match local with global ids
joinable_ids = [
localid_to_idx(closest_pairs.loc[i], contigs.loc[i])
for i in closest_pairs.index
]
contigs_to_join = [
contigs for trap_tracks in joinable_ids for contigs in trap_tracks
]
return contigs_to_join
def load_test_dset():
"""Load development dataset to test functions."""
"""Load test data set."""
return pd.DataFrame(
{
("a", 1, 1): [2, 5, np.nan, 6, 8] + [np.nan] * 5,
......@@ -45,49 +426,6 @@ def max_nonstop_ntps(track: pd.Series) -> int:
return max(consecutive_nonas_grouped)
def get_tracks_ntps(tracks: pd.DataFrame) -> pd.Series:
return tracks.apply(max_ntps, axis=1)
def get_avg_gr(track: pd.Series) -> int:
"""
Get average growth rate for a track.
:param tracks: Series with volume and timepoints as indices
"""
ntps = max_ntps(track)
vals = track.dropna().values
gr = (vals[-1] - vals[0]) / ntps
return gr
def get_avg_grs(tracks: pd.DataFrame) -> pd.DataFrame:
"""
Get average growth rate for a group of tracks
:param tracks: (m x n) dataframe where rows are cell tracks and
columns are timepoints
"""
return tracks.apply(get_avg_gr, axis=1)
def clean_tracks(
tracks, min_len: int = 15, min_gr: float = 1.0
) -> pd.DataFrame:
"""
Clean small non-growing tracks and return the reduced dataframe
:param tracks: (m x n) dataframe where rows are cell tracks and
columns are timepoints
:param min_len: int number of timepoints cells must have not to be removed
:param min_gr: float Minimum mean growth rate to assume an outline is growing
"""
ntps = get_tracks_ntps(tracks)
grs = get_avg_grs(tracks)
growing_long_tracks = tracks.loc[(ntps >= min_len) & (grs > min_gr)]
return growing_long_tracks
def merge_tracks(
tracks, drop=False, **kwargs
) -> t.Tuple[pd.DataFrame, t.Collection]:
......@@ -191,141 +529,6 @@ def join_track_pair(target, source):
return tgt_copy
def get_joinable(tracks, smooth=False, tol=0.1, window=5, degree=3) -> dict:
"""
Get the pairs of tracks (without repeats) that have a smaller error than the
tolerance. If a track can be assigned to two or more others, choose the one
with the lowest error.
:param tracks: (m x n) dataframe where rows are cell tracks and
columns are timepoints
:param tol: float or int threshold of average (prediction error/std) necessary
to consider two tracks the same. If float is fraction of first track,
if int it is absolute units.
:param window: int value of window used for savgol_filter
:param degree: int value of polynomial degree passed to savgol_filter
"""
tracks = tracks.loc[tracks.notna().sum(axis=1) > 2]
# smooth all relevant tracks
if smooth: # Apply savgol filter TODO fix nans affecting edge placing
clean = clean_tracks(
tracks, min_len=window + 1, min_gr=0.9
) # get useful tracks
def savgol_on_srs(x):
return non_uniform_savgol(x.index, x.values, window, degree)
contig = clean.groupby(["trap"]).apply(get_contiguous_pairs)
contig = contig.loc[contig.apply(len) > 0]
flat = set([k for v in contig.values for i in v for j in i for k in j])
smoothed_tracks = clean.loc[flat].apply(savgol_on_srs, 1)
else:
contig = tracks.groupby(["trap"]).apply(get_contiguous_pairs)
contig = contig.loc[contig.apply(len) > 0]
flat = set([k for v in contig.values for i in v for j in i for k in j])
smoothed_tracks = tracks.loc[flat].apply(
lambda x: np.array(x.values), axis=1
)
# fetch edges from ids TODO (IF necessary, here we can compare growth rates)
def idx_to_edge(preposts):
return [
(
[get_val(smoothed_tracks.loc[pre], -1) for pre in pres],
[get_val(smoothed_tracks.loc[post], 0) for post in posts],
)
for pres, posts in preposts
]
def idx_to_pred(preposts):
result = []
for pres, posts in preposts:
pre_res = []
for pre in pres:
y = get_last_i(smoothed_tracks.loc[pre], -window)
pre_res.append(
np.poly1d(np.polyfit(range(len(y)), y, 1))(len(y) + 1),
)
pos_res = [
get_means(smoothed_tracks.loc[post], window) for post in posts
]
result.append([pre_res, pos_res])
return result
edges = contig.apply(idx_to_edge) # Raw edges
pre_pred = contig.apply(idx_to_pred) # Prediction of pre and mean of post
edges_dMetric_pred = pre_pred.apply(get_dMetric_wrap, tol=tol)
solutions = []
for (i, dMetrics), edgeset in zip(edges_dMetric_pred.items(), edges):
solutions.append(solve_matrices_wrap(dMetrics, edgeset, tol=tol))
closest_pairs = pd.Series(
solutions,
index=edges_dMetric_pred.index,
)
# match local with global ids
joinable_ids = [
localid_to_idx(closest_pairs.loc[i], contig.loc[i])
for i in closest_pairs.index
]
return [pair for pairset in joinable_ids for pair in pairset]
def get_val(x, n):
return x[~np.isnan(x)][n] if len(x[~np.isnan(x)]) else np.nan
def get_means(x, i):
if not len(x[~np.isnan(x)]):
return np.nan
if i > 0:
v = x[~np.isnan(x)][:i]
else:
v = x[~np.isnan(x)][i:]
return np.nanmean(v)
def get_last_i(x, i):
if not len(x[~np.isnan(x)]):
return np.nan
if i > 0:
v = x[~np.isnan(x)][:i]
else:
v = x[~np.isnan(x)][i:]
return v
def localid_to_idx(local_ids, contig_trap):
"""
Fetch the original ids from a nested list with joinable local_ids.
......@@ -351,60 +554,6 @@ def get_vec_closest_pairs(lst: List, **kwargs):
return [get_closest_pairs(*sublist, **kwargs) for sublist in lst]
def get_dMetric_wrap(lst: List, **kwargs):
return [get_dMetric(*sublist, **kwargs) for sublist in lst]
def solve_matrices_wrap(dMetric: List, edges: List, **kwargs):
return [
solve_matrices(mat, edgeset, **kwargs)
for mat, edgeset in zip(dMetric, edges)
]
def get_dMetric(
pre: List[float], post: List[float], tol: Union[float, int] = 1
):
"""Calculate a cost matrix
input
:param pre: list of floats with edges on left
:param post: list of floats with edges on right
:param tol: int or float if int metrics of tolerance, if float fraction
returns
:: list of indices corresponding to the best solutions for matrices
"""
if len(pre) > len(post):
dMetric = np.abs(np.subtract.outer(post, pre))
else:
dMetric = np.abs(np.subtract.outer(pre, post))
dMetric[np.isnan(dMetric)] = (
tol + 1 + np.nanmax(dMetric)
) # nans will be filtered
return dMetric
def solve_matrices(
dMetric: np.ndarray, prepost: List, tol: Union[float, int] = 1
):
"""
Solve the distance matrices obtained in get_dMetric and/or merged from
independent dMetric matrices.
"""
ids = solve_matrix(dMetric)
if not len(ids[0]):
return []
pre, post = prepost
norm = (
np.array(pre)[ids[len(pre) > len(post)]] if tol < 1 else 1
) # relative or absolute tol
result = dMetric[ids] / norm
ids = ids if len(pre) < len(post) else ids[::-1]
return [idx for idx, res in zip(zip(*ids), result) if res <= tol]
def get_closest_pairs(
pre: List[float], post: List[float], tol: Union[float, int] = 1
):
......@@ -422,41 +571,11 @@ def get_closest_pairs(
"""
dMetric = get_dMetric(pre, post, tol)
return solve_matrices(dMetric, pre, post, tol)
def solve_matrix(dMetric):
"""
Solve cost matrix focusing on getting the smallest cost at each iteration.
input
:param dMetric: np.array cost matrix
returns
tuple of np.arrays indicating picks with lowest individual value
"""
glob_is = []
glob_js = []
if (~np.isnan(dMetric)).any():
tmp = copy(dMetric)
std = sorted(tmp[~np.isnan(tmp)])
while (~np.isnan(std)).any():
v = std[0]
i_s, j_s = np.where(tmp == v)
i = i_s[0]
j = j_s[0]
tmp[i, :] += np.nan
tmp[:, j] += np.nan
glob_is.append(i)
glob_js.append(j)
std = sorted(tmp[~np.isnan(tmp)])
return (np.array(glob_is), np.array(glob_js))
return find_best_from_scores(dMetric, pre, post, tol)
def plot_joinable(tracks, joinable_pairs):
"""
Convenience plotting function for debugging and data vis
"""
"""Convenience plotting function for debugging."""
nx = 8
ny = 8
_, axes = plt.subplots(nx, ny)
......@@ -475,59 +594,3 @@ def plot_joinable(tracks, joinable_pairs):
ax.plot(post_srs.index, post_srs.values, "g")
plt.show()
def get_contiguous_pairs(tracks: pd.DataFrame) -> list:
"""
Get all pair of contiguous track ids from a tracks dataframe.
:param tracks: (m x n) dataframe where rows are cell tracks and
columns are timepoints
:param min_dgr: float minimum difference in growth rate from
the interpolation
"""
mins, maxes = [
tracks.notna().apply(np.where, axis=1).apply(fn)
for fn in (np.min, np.max)
]
mins_d = mins.groupby(mins).apply(lambda x: x.index.tolist())
mins_d.index = mins_d.index - 1 # make indices equal
# TODO add support for skipping time points
maxes_d = maxes.groupby(maxes).apply(lambda x: x.index.tolist())
common = sorted(
set(mins_d.index).intersection(maxes_d.index), reverse=True
)
return [(maxes_d[t], mins_d[t]) for t in common]
......@@ -7,24 +7,24 @@ import numpy as np
import pandas as pd
from agora.abc import ParametersABC
from agora.utils.kymograph import get_index_as_np
from postprocessor.core.abc import PostProcessABC
class LineageProcessParameters(ParametersABC):
"""
Parameters
"""
"""Parameters - none are necessary."""
_defaults = {}
class LineageProcess(PostProcessABC):
"""
Lineage process that must be passed a (N,3) lineage matrix (where the columns are trap, mother, daughter respectively)
To analyse lineage data.
Currently bare bones, but extracts lineage information from a Signal or Cells object.
"""
def __init__(self, parameters: LineageProcessParameters):
"""Initialise using PostProcessABC."""
super().__init__(parameters)
@abstractmethod
......@@ -34,6 +34,7 @@ class LineageProcess(PostProcessABC):
lineage: np.ndarray,
*args,
):
"""Implement method required by PostProcessABC - undefined."""
pass
@classmethod
......@@ -45,8 +46,9 @@ class LineageProcess(PostProcessABC):
**kwargs,
):
"""
Overrides PostProcess.as_function classmethod.
Lineage functions require lineage information to be passed if run as function.
Override PostProcessABC.as_function method.
Lineage functions require lineage information to be passed when run as functions.
"""
parameters = cls.default_parameters(**kwargs)
return cls(parameters=parameters).run(
......@@ -54,9 +56,10 @@ class LineageProcess(PostProcessABC):
)
def get_lineage_information(self, signal=None, merged=True):
"""Get lineage as an array with tile IDs, mother and bud labels."""
if signal is not None and "mother_label" in signal.index.names:
lineage = get_index_as_np(signal)
# from kymograph
lineage = np.array(signal.index.to_list())
elif hasattr(self, "lineage"):
lineage = self.lineage
elif hasattr(self, "cells"):
......@@ -68,5 +71,5 @@ class LineageProcess(PostProcessABC):
elif self.cells is not None:
lineage = self.cells.mothers_daughters
else:
raise Exception("No linage information found")
raise Exception("No lineage information found")
return lineage