Skip to content
Snippets Groups Projects

Compare revisions

Changes are shown as if the source revision was being merged into the target revision. Learn more about comparing revisions.

Source

Select target project
No results found

Target

Select target project
  • swain-lab/aliby/aliby-mirror
  • swain-lab/aliby/alibylite
2 results
Show changes
Showing
with 692 additions and 233 deletions
......@@ -6,6 +6,8 @@ import numpy as np
import pandas as pd
from sklearn.cluster import KMeans
from agora.utils.indexing import validate_association
index_row = t.Tuple[str, str, int, int]
......@@ -120,7 +122,9 @@ def bidirectional_retainment_filter(
def melt_reset(df: pd.DataFrame, additional_ids: t.Dict[str, pd.Series] = {}):
new_df = add_index_levels(df, additional_ids)
return new_df.melt(ignore_index=False).reset_index()
return new_df.melt(
ignore_index=False, var_name="time (minutes)", value_name="signal"
).reset_index()
# Drop cells that if used would reduce info the most
......@@ -163,3 +167,79 @@ def slices_from_spans(spans: t.Tuple[int], df: pd.DataFrame) -> t.List[slice]:
for start, end in zip(cumsum[:-1], cumsum[1:])
]
return slices
def drop_mother_label(index: pd.MultiIndex) -> np.ndarray:
    """
    Convert a MultiIndex into an array, ignoring any "mother_label" level.

    Parameters
    ----------
    index : pd.MultiIndex
        Index to convert. A level named "mother_label" is dropped first
        when it is present; otherwise the index is used unchanged.

    Returns
    -------
    np.ndarray
        Array built from the remaining index entries, one row per entry.
    """
    has_mother_level = "mother_label" in index.names
    trimmed = index.droplevel("mother_label") if has_mother_level else index
    return np.array(trimmed.tolist())
def get_index_as_np(signal: pd.DataFrame):
    """
    Return the index of `signal` as a numpy array.

    Every index entry becomes one element (or one row, for a MultiIndex
    whose entries are tuples) of the returned array.
    """
    index_entries = signal.index.to_list()
    return np.array(index_entries)
def standard_filtering(
raw: pd.DataFrame,
lin: np.ndarray,
presence_high: float = 0.8,
presence_low: int = 7,
):
"""
Keep well-sampled mothers together with their sufficiently long daughters.

Parameters
----------
raw : pd.DataFrame
    Signal whose rows are looked up through `raw.index`.
    (presumably one row per cell and one column per time point -- confirm)
lin : np.ndarray
    Lineage array. Columns 0 and 1 are used as the mother identifier and
    columns 0 and 2 as the daughter identifier.
presence_high : float
    A mother is kept only if its count of non-NaN values is greater than
    this fraction of the number of columns.
presence_low : int
    Rows whose count of non-NaN values is at most this number are removed.

Returns
-------
The filtered rows of `raw`, with "mother_label" appended to the index
(0 for mother rows), after dropping mothers left without any daughter.
"""
# Get all mothers
# NOTE(review): validate_association is defined elsewhere; its second return
# value is used directly as a row selector for `raw` -- confirm its type.
_, valid_indices = validate_association(
lin, np.array(raw.index.to_list()), match_column=0
)
in_lineage = raw.loc[valid_indices]
# Filter mothers by presence: keep rows with more than
# presence_high * n_columns non-NaN values
present = in_lineage.loc[
in_lineage.notna().sum(axis=1) > (in_lineage.shape[1] * presence_high)
]
# Get indices
indices = np.array(present.index.to_list())
# Pair up, for every lineage row, (col 0, col 1) and (col 0, col 2)
to_cast = np.stack((lin[:, :2], lin[:, [0, 2]]), axis=1)
ndin = to_cast[..., None] == indices.T[None, ...]
# use indices to fetch all daughters: a lineage row is valid when its
# first pair (the mother) matches one of the retained indices
valid_association = ndin.all(axis=2)[:, 0].any(axis=-1)
# Remove repeats
mothers, daughters = np.split(to_cast[valid_association], 2, axis=1)
mothers = mothers[:, 0]
daughters = daughters[:, 0]
# Map each daughter identifier to the last element of its mother identifier
d_m_dict = {tuple(d): m[-1] for m, d in zip(mothers, daughters)}
# assuming unique sorts
# NOTE(review): _as_tuples is not visible here; presumably converts array rows
# into index tuples -- confirm.
raw_mothers = raw.loc[_as_tuples(mothers)]
# NOTE(review): assigning a column on a .loc selection may raise pandas'
# SettingWithCopyWarning -- confirm this is intended.
raw_mothers["mother_label"] = 0
raw_daughters = raw.loc[_as_tuples(daughters)]
# Relies on the dict's insertion order lining up with the selected daughter
# rows, which only holds if daughter identifiers are unique (see above).
raw_daughters["mother_label"] = d_m_dict.values()
concat = pd.concat((raw_mothers, raw_daughters)).sort_index()
concat.set_index("mother_label", append=True, inplace=True)
# Last filter to remove tracklets that are too short
removed_buds = concat.notna().sum(axis=1) <= presence_low
filt = concat.loc[~removed_buds]
# We check that no mothers are left child-less
m_d_dict = {tuple(m): [] for m in mothers}
for (trap, d), m in d_m_dict.items():
m_d_dict[(trap, m)].append(d)
# Forget every daughter that was removed by the tracklet-length filter
# NOTE(review): looks like this assumes only daughter rows are removed; a
# removed mother row (mother_label 0) would be looked up as (trap, 0) -- confirm.
for trap, daughter, mother in concat.index[removed_buds]:
idx_to_delete = m_d_dict[(trap, mother)].index(daughter)
del m_d_dict[(trap, mother)][idx_to_delete]
# Collect mothers whose daughter list is now empty
bud_free = []
for m, d in m_d_dict.items():
if not d:
bud_free.append(m)
final_result = filt.drop(bud_free)
# In the end, we get the mothers whose non-NaN count exceeds the
# presence_high fraction of the columns, and the rows with more than
# presence_low non-NaN values
return final_result
......@@ -9,7 +9,7 @@ import numpy as np
import pandas as pd
from utils_find_1st import cmp_larger, find_1st
from agora.utils.association import validate_association
from agora.utils.indexing import compare_indices, validate_association
def apply_merges(data: pd.DataFrame, merges: np.ndarray):
......@@ -31,23 +31,29 @@ def apply_merges(data: pd.DataFrame, merges: np.ndarray):
"""
indices = data.index
if "mother_label" in indices.names:
indices = indices.droplevel("mother_label")
valid_merges, indices = validate_association(
merges, np.array(list(data.index))
merges, np.array(list(indices))
)
# Assign non-merged
merged = data.loc[~indices]
# Implement the merges and drop source rows.
# TODO Use matrices to perform merges in batch
# for efficiency
if valid_merges.any():
to_merge = data.loc[indices]
for target, source in merges[valid_merges]:
target, source = tuple(target), tuple(source)
targets, sources = zip(*merges[valid_merges])
for source, target in zip(sources, targets):
target = tuple(target)
to_merge.loc[target] = join_tracks_pair(
to_merge.loc[target].values,
to_merge.loc[source].values,
to_merge.loc[tuple(source)].values,
)
to_merge.drop(source, inplace=True)
to_merge.drop(map(tuple, sources), inplace=True)
merged = pd.concat((merged, to_merge), names=data.index.names)
return merged
......@@ -56,9 +62,85 @@ def apply_merges(data: pd.DataFrame, merges: np.ndarray):
def join_tracks_pair(target: np.ndarray, source: np.ndarray) -> np.ndarray:
"""
Join two tracks and return the new value of the target.
TODO replace this with arrays only.
"""
target_copy = copy(target)
target_copy = target
end = find_1st(target_copy[::-1], 0, cmp_larger)
target_copy[-end:] = source[-end:]
return target_copy
def group_merges(merges: np.ndarray) -> t.List[t.Tuple]:
    """
    Group merge events into chained groups and stand-alone events.

    A merge event is "chained" when the cell in one of its slots also
    appears in the other slot of some merge event (a multimerge).

    Parameters
    ----------
    merges : np.ndarray
        Merge events, indexed as ``merges[event, slot, :]`` with two slots
        per event.

    Returns
    -------
    list
        First the chained groups, each ordered by `sort_association`, then
        every remaining event wrapped in its own single-element list.
    """
    # NOTE(review): compare_indices is defined elsewhere; presumably it returns
    # a boolean matrix relating slot-0 entries to slot-1 entries -- confirm.
    links = compare_indices(merges[:, 0, :], merges[:, 1, :])
    in_chain = links.any(axis=0) | links.any(axis=1)

    # Connected components over the linked event pairs
    chained_subsets = union_find(zip(*np.where(links)))
    # Ensure that source and target are at the edges of each chain
    ordered_chains = [
        sort_association(merges[np.array(tuple(subset))])
        for subset in chained_subsets
    ]
    standalone = [[event] for event in merges[~in_chain]]
    return [*ordered_chains, *standalone]
def union_find(lsts):
    """
    Merge overlapping collections into disjoint sets.

    Parameters
    ----------
    lsts : iterable of iterables
        Collections of hashable items; empty collections are ignored.

    Returns
    -------
    list of set
        Sets in which any two input collections sharing an element have
        been fused. Passes are repeated until no further fusion happens.
    """
    groups = [set(item) for item in lsts if item]
    changed = True
    while changed:
        changed = False
        consolidated = []
        pending = groups
        while pending:
            head = pending[0]
            leftover = []
            for candidate in pending[1:]:
                if head.isdisjoint(candidate):
                    leftover.append(candidate)
                else:
                    # Absorb the overlapping set; later candidates are
                    # compared against the enlarged head
                    head |= candidate
                    changed = True
            consolidated.append(head)
            pending = leftover
        groups = consolidated
    return groups
def sort_association(array: np.ndarray):
    """
    Order a group of chained associations so that consecutive rows link up.

    Parameters
    ----------
    array : np.ndarray
        Associations indexed as ``array[row, slot, :]`` with two slots per
        row (shape (n, 2, k) is assumed from the indexing used below).

    Returns
    -------
    np.ndarray
        The rows of `array` that take part in a link, reordered; each row
        appears once.
    """
    # links[i, j] is True when slot 0 of row i equals slot 1 of row j
    links = (array[:, 0, ..., None] == array[:, 1].T[None, ...]).all(axis=1)
    order = np.where(links)
    # Walk the flipped (i, j) pairs and keep the first occurrence of every
    # row number; dict.fromkeys preserves insertion order.
    visit_order = list(dict.fromkeys(np.flip(order).flatten().tolist()))
    return array[np.array(visit_order)]
def merge_association(
association: np.ndarray, merges: np.ndarray
) -> np.ndarray:
"""
Rewrite the identifiers in `association` according to `merges`.

Every identifier that matches the first slot of a merge event is replaced
by the second slot of the last event in that event's merge group.

Parameters
----------
association : np.ndarray
    Array that is viewed as rows of two-element identifiers and finally
    returned with shape (-1, 2, 2).
merges : np.ndarray
    Merge events, indexed as ``merges[event, slot]``.
"""
grouped_merges = group_merges(merges)
# NOTE(review): reshape may return a view, in which case the assignment
# below also modifies the caller's `association` -- confirm this is intended.
flat_indices = association.reshape(-1, 2)
# NOTE(review): compare_indices is defined elsewhere; its result is reduced
# over axis 0 to flag which flattened identifiers match a merge -- confirm.
comparison_mat = compare_indices(merges[:, 0], flat_indices)
valid_indices = comparison_mat.any(axis=0)
if valid_indices.any(): # Where valid, perform transformation
# Map the first slot of every event to the final slot of its group
replacement_d = {}
for dataset in grouped_merges:
for k in dataset:
replacement_d[tuple(k[0])] = dataset[-1][1]
flat_indices[valid_indices] = [
replacement_d[tuple(i)] for i in flat_indices[valid_indices]
]
merged_indices = flat_indices.reshape(-1, 2, 2)
return merged_indices
......@@ -6,7 +6,7 @@ import logging
import re
import time
import typing as t
from pathlib import Path, PosixPath
from pathlib import Path
from time import perf_counter
import baby.errors
......@@ -108,9 +108,7 @@ class BabyParameters(ParametersABC):
tf_version=2,
)
def update_baby_modelset(
self, path: t.Union[str, PosixPath, t.Dict[str, str]]
):
def update_baby_modelset(self, path: t.Union[str, Path, t.Dict[str, str]]):
"""
Replace default BABY model and flattener with another one from a folder outputted
by our standard retraining script.
......@@ -141,6 +139,14 @@ class BabyRunner(StepABC):
if parameters is None
else parameters.model_config
)
tiler_z = self.tiler.shape[-3]
model_name = self.model_config["flattener_file"]
if tiler_z != 5:
assert (
f"{tiler_z}z" in model_name
), f"Tiler z-stack ({tiler_z}) and Model shape ({model_name}) do not match "
self.brain = BabyBrain(**self.model_config)
self.crawler = BabyCrawler(self.brain)
self.bf_channel = self.tiler.ref_channel_index
......
#!/usr/bin/env jupyter
"""
Command Line Interface utilities.
"""
"""
Asynchronous annotation (in one thread). Used as a base to build threading-based annotation.
Currently only works on UNIX-like systems due to using "/" to split addresses.
Usage example
From python
$ python annotator.py --image_path path/to/folder/with/images/zarr --results_path path/to/folder/with/h5files --pos position_name --ncells max_n_to_annotate
As executable (installed via poetry)
$ annotator.py --image_path path/to/folder/with/images/zarr --results_path path/to/folder/with/h5files --pos position_name --ncells max_n_to_annotate
During annotation:
- Assign a (binary) label by typing '1' or '2'.
- Type 'u' to undo.
- Type 's' to skip.
- Type 'q' to quit.
File will be saved in: ./YYYY_MM_DD_annotation/annotation.csv, where YYYY_MM_DD is the current date.
"""
import argparse
import logging
import typing as t
from copy import copy
from datetime import datetime
from pathlib import Path
import matplotlib.pyplot as plt
import numpy as np
import readchar
import trio
from agora.utils.cast import _str_to_int
from aliby.utils.vis_tools import _sample_n_tiles_masks
from aliby.utils.plot import stretch
# Silence "aliby" log records below level 40 (logging.ERROR)
logging.getLogger("aliby").setLevel(40)
# Defaults
# Required arguments, mapped to the file suffix used to auto-complete them
essential = {"image_path": "zarr", "results_path": "h5"}
# Optional arguments and their default values
param_values = dict(
out_dir=f"./{datetime.today().strftime('%Y_%m_%d')}_annotation/",
pos=None,
ncells=100,
min_tp=100,
max_tp=150,
seed=0,
)
annot_filename = "annotation.csv"
# Parsing
# NOTE: the parser is built and run at module level, so importing this
# module parses the command line.
parser = argparse.ArgumentParser(
prog="aliby-annot-binary",
description="Annotate cells in a binary manner",
)
# One --option per name; only the first len(essential) names are required
for i, arg in enumerate((*essential, *param_values)):
parser.add_argument(
f"--{arg}",
action="store",
default=param_values.get(arg),
required=i < len(essential),
)
args = parser.parse_args()
# The tuple of names is built before the loop starts, so adding the
# essential keys to param_values inside the loop does not affect iteration.
for i, k in enumerate((*essential, *param_values.keys())):
# Assign essential values as-is
if i < len(essential):
param_values[k] = getattr(args, k)
# Fill additional values
if passed_value := getattr(args, k):
param_values[k] = passed_value
# Best effort: replace the stored value with the result of _str_to_int;
# if that raises, the raw value stays in place (`exc` is unused).
try:
param_values[k] = _str_to_int(passed_value)
except Exception as exc:
pass
for k, suffix in essential.items(): # Autocomplete if fullpath not provided
# A value not ending in the suffix is treated as a folder and completed
# to <folder>/<pos>.<suffix>
if not str(param_values[k]).endswith(suffix):
param_values[k] = (
Path(param_values[k]) / f"{ param_values['pos'] }.{suffix}"
)
# Functions
async def generate_image(stack, skip: bool = False):
# Placeholder producer: sleeps for one second, then appends a random
# 10x10 integer array (values in [0, 100)) to `stack`.
# `skip` is not used.
# NOTE(review): a second `generate_image` is defined later in this module
# and replaces this one at import time -- confirm this version is still needed.
await trio.sleep(1)
result = np.random.randint(100, size=(10, 10))
stack.append(result)
async def draw(data, drawing):
# Push new image data into existing matplotlib artists and refresh the
# figure.
#
# data    : a single image, or a sequence of images when `drawing` holds
#           more than one artist.
# drawing : object(s) exposing `set_data`
#           (presumably what plot_overlay returned -- confirm against caller).
if len(drawing) > 1:
for ax, img in zip(drawing, data):
if np.isnan(img).sum(): # Stretch masked channel
img = stretch(img)
ax.set_data(img)
else:
drawing.set_data(data)
# Redraw and give the GUI event loop 0.1 s to process the update
plt.draw()
plt.pause(0.1)
def annotate_image(current_key=None, valid_values: t.Tuple[int] = (1, 2)):
    """
    Read keys until a valid label or a control key is pressed.

    Parameters
    ----------
    current_key : optional
        Initial value; when it is already one of `valid_values` it is
        returned without reading the keyboard.
    valid_values : tuple of int
        Labels accepted as annotations.

    Returns
    -------
    The accepted integer label, or the raw key when it is found in "qsu"
    (quit, skip, undo).
    """
    while True:
        if current_key is not None:
            if current_key in valid_values:
                break
            print(
                f"Invalid value. Please try with valid values {valid_values}"
            )
        pressed = readchar.readkey()
        current_key = pressed
        if pressed in "qsu":
            # Control keys are handed back to the caller untouched
            break
        # Unparsable keys come back as None, which triggers another read
        current_key = _parse_input(pressed, valid_values)
    return current_key
async def generate_image(
    generator,
    location_stack: t.List[t.Tuple[np.ndarray, t.Tuple[int, int, int]]],
):
    """
    Pull the next sample from `generator` and append it to `location_stack`.

    The first two elements of the sample are stored as a 2-tuple.

    NOTE(review): other functions in this module read entry[0] as the
    location and entry[1] as the images, which looks reversed relative to
    the annotation above -- confirm.
    """
    sample = next(generator)
    entry = (sample[0], sample[1])
    location_stack.append(entry)
def _parse_input(value: str, valid_values: t.Tuple[int]):
try:
return int(value)
except:
print(
f"Non-parsable value. Please try again with valid values {valid_values}"
)
return None
def write_annotation(
    experiment_position: str,
    out_dir: Path,
    annotation: str,
    location_stack: t.Tuple[t.Tuple[int, int, int], np.ndarray],
):
    """
    Record one annotation and save the annotated images.

    A comma-separated row made of the unpacked `experiment_position`, the
    location and the annotation is appended to the annotation file in
    `out_dir`. The two images are stacked and saved as a ``.npy`` file
    named after the same location fields joined with dots.

    Parameters
    ----------
    experiment_position
        Iterable unpacked into the leading fields of the row.
        (NOTE(review): annotated as str but unpacked element-wise --
        confirm what callers pass.)
    out_dir : Path
        Output folder.
    annotation : str
        Label to record.
    location_stack : tuple
        ``(location, stack)``; ``stack[0]`` and ``stack[1]`` are the images.
    """
    location, stack = location_stack
    unique_location = [str(part) for part in (*experiment_position, *location)]
    csv_row = ",".join([*unique_location, str(annotation)]) + "\n"
    write_into_file(out_dir / annot_filename, csv_row)

    # NaNs in the second image are zeroed on a copy before casting to int
    second_image = copy(stack[1])
    second_image[np.isnan(second_image)] = 0
    np.save(
        out_dir / f"{'.'.join( unique_location )}.npy",
        np.stack((stack[0], second_image.astype(int))),
    )
def write_into_file(file_path: str, line: str):
    """Append `line` (converted with ``str``) to the file at `file_path`."""
    text = str(line)
    with open(file_path, mode="a") as handle:
        handle.write(text)
async def annotate_images(
image_path, results_path, out_dir, ncells, seed, interval
):
"""
Interactive annotation loop: show sampled tiles and record typed labels.

Parameters
----------
image_path, results_path
    Forwarded to `_sample_n_tiles_masks`; `results_path` is also used to
    derive the leading fields of every annotation row.
out_dir
    Folder receiving the annotation file and the saved arrays.
ncells
    Number of samples requested from the sampler; also bounds the loop.
seed, interval
    Forwarded to `_sample_n_tiles_masks`.
"""
# Number of samples requested before the first one is shown
preemptive_cache = 3
location_stack = []
out_dir = Path(out_dir)
out_annot_file = str(out_dir / annot_filename)
generator = _sample_n_tiles_masks(
image_path, results_path, ncells, seed=seed, interval=interval
)
# Fetch a few positions preemptively
async with trio.open_nursery() as nursery:
for _ in range(preemptive_cache):
nursery.start_soon(generate_image, generator, location_stack)
print("parent: waiting for first annotations.")
_, ax = plt.subplots(figsize=(10, 8))
while not location_stack: # Wait until first image is loaded
await trio.sleep(0.1)
from aliby.utils.plot import plot_overlay
# drawing = ax.imshow(location_stack[0][1])
# Entry [1] of a stack item is unpacked into plot_overlay's positional
# arguments; the returned artists are reused by `draw` below
axes = plot_overlay(*location_stack[0][1], ax=ax.axes)
plt.show(block=False)
plt.draw()
plt.pause(0.5) # May be adjusted based on display speed
# NOTE(review): the bare except hides every mkdir failure, not only
# "directory already exists" -- confirm this is intended.
try:
out_dir.mkdir(parents=True)
except:
pass
# Write the CSV header only when the annotation file does not exist yet
if not Path(out_annot_file).exists():
write_into_file(
out_annot_file,
",".join(
(
"experiment",
"position",
"tile",
"cell_label",
"tp",
"annotation",
)
)
+ "\n",
)
# Loop until n_max or quit
for i in range(1, ncells - preemptive_cache + 1):
# Wait for input
print("Enter a key")
# str() so that integer labels and the control keys compare as strings
annotation = str(annotate_image())
if annotation == "q":
break
elif annotation == "s":
print("Skipping...")
# continue
elif annotation == "u":
# NOTE(review): rebinding `i` does not change the next value produced by
# range(), so this only affects the uses of `i` later in this iteration.
i -= 1
elif isinstance(_str_to_int(annotation), int):
# The last two "/"-separated parts of the path before its first "."
# become the leading fields of the row
write_annotation(
str(results_path).split(".")[0].split("/")[-2:],
out_dir,
annotation,
location_stack[i],
)
print(location_stack[i][0])
# Append into annotations file
# Request one more sample while drawing the images of entry `i`
async with trio.open_nursery() as nursery:
nursery.start_soon(generate_image, generator, location_stack)
nursery.start_soon(draw, location_stack[i][1], axes)
print("Annotation done!")
# if __name__ == "__main__":
def annotate():
    """
    Entry point: run the annotation loop with the parsed parameters.

    The time-point interval is ``(min_tp, max_tp)`` when both values are
    set in `param_values`, otherwise None. The parameters are printed
    before the loop starts.
    """
    min_tp = param_values.get("min_tp")
    max_tp = param_values.get("max_tp")
    if min_tp is None or max_tp is None:
        interval = None
    else:
        interval = (min_tp, max_tp)

    print(param_values)
    positional = [
        param_values[key]
        for key in ("image_path", "results_path", "out_dir", "ncells", "seed")
    ]
    trio.run(annotate_images, *positional, interval)
#!/usr/bin/env jupyter
import argparse
from agora.utils.cast import _str_to_int
from aliby.pipeline import Pipeline, PipelineParameters
......@@ -19,7 +21,6 @@ def run():
Examples
--------
FIXME: Add docs.
FIXME: GTP-generated. Confirm manually.
"""
parser = argparse.ArgumentParser(
prog="aliby-run",
......@@ -37,23 +38,13 @@ def run():
"password": None,
}
def _cast_str(x: str or None):
"""
Cast string as int if possible. If Nonetype return None.
"""
if x:
try:
return int(x)
except:
return x
for k in param_values:
parser.add_argument(f"--{k}", action="store")
args = parser.parse_args()
for k in param_values:
if passed_value := _cast_str(getattr(args, k)):
if passed_value := _str_to_int(getattr(args, k)):
param_values[k] = passed_value
......
......@@ -10,7 +10,7 @@ import shutil
import time
import typing as t
from abc import ABC, abstractproperty, abstractmethod
from pathlib import Path, PosixPath
from pathlib import Path
from agora.io.bridge import BridgeH5
from aliby.io.image import ImageLocalOME
......@@ -54,10 +54,10 @@ class DatasetLocalABC(ABC):
Abstract Base class to find local files, either OME-XML or raw images.
"""
_valid_suffixes = ("tiff", "png")
_valid_suffixes = ("tiff", "png", "zarr")
_valid_meta_suffixes = ("txt", "log")
def __init__(self, dpath: t.Union[str, PosixPath], *args, **kwargs):
def __init__(self, dpath: t.Union[str, Path], *args, **kwargs):
self.path = Path(dpath)
def __enter__(self):
......@@ -110,7 +110,7 @@ class DatasetLocalABC(ABC):
class DatasetLocalDir(DatasetLocalABC):
"""Find paths to a data set, comprising multiple images in different folders."""
def __init__(self, dpath: t.Union[str, PosixPath], *args, **kwargs):
def __init__(self, dpath: t.Union[str, Path], *args, **kwargs):
super().__init__(dpath)
@property
......@@ -121,24 +121,33 @@ class DatasetLocalDir(DatasetLocalABC):
)
def get_images(self):
"""Return a dictionary of folder names and their paths."""
return {
folder.name: folder
for folder in self.path.glob("*/")
if any(
"""Return a dictionary of folder or file names and their paths.
FUTURE 3.12 use pathlib is_junction to pick Dir or File
"""
images = {
item.name: item
for item in self.path.glob("*/")
if item.is_dir()
and any(
path
for suffix in self._valid_suffixes
for path in folder.glob(f"*.{suffix}")
for path in item.glob(f"*.{suffix}")
)
or item.suffix[1:] in self._valid_suffixes
}
return images
class DatasetLocalOME(DatasetLocalABC):
"""Find names of images in a folder, assuming images in OME-TIFF format."""
def __init__(self, dpath: t.Union[str, PosixPath], *args, **kwargs):
def __init__(self, dpath: t.Union[str, Path], *args, **kwargs):
super().__init__(dpath)
assert len(self.get_images()), "No .tiff files found"
assert len(
self.get_images()
), f"No valid files found. Formats are {self._valid_suffixes}"
@property
def date(self):
......
......@@ -16,9 +16,10 @@ ImageDummy is a dummy class for silent failure testing.
import typing as t
from abc import ABC, abstractmethod, abstractproperty
from datetime import datetime
from pathlib import Path, PosixPath
from pathlib import Path
import dask.array as da
import numpy as np
import xmltodict
import zarr
from dask.array.image import imread
......@@ -33,34 +34,36 @@ def get_examples_dir():
return files("aliby").parent.parent / "examples" / "tiler"
def instatiate_image(source: t.Union[str, int, t.Dict[str, str], PosixPath]):
def instantiate_image(
source: t.Union[str, int, t.Dict[str, str], Path], **kwargs
):
"""Wrapper to instantiate the appropriate image
Parameters
----------
source : t.Union[str, int, t.Dict[str, str], PosixPath]
source : t.Union[str, int, t.Dict[str, str], Path]
Image identifier
Examples
--------
image_path = "path/to/image"
with instatiate_image(image_path) as img:
with instantiate_image(image_path) as img:
print(img.data, img.metadata)
"""
return dispatch_image(source)(source)
return dispatch_image(source)(source, **kwargs)
def dispatch_image(source: t.Union[str, int, t.Dict[str, str], PosixPath]):
def dispatch_image(source: t.Union[str, int, t.Dict[str, str], Path]):
"""
Wrapper to pick the appropriate Image class depending on the source of data.
"""
if isinstance(source, int):
if isinstance(source, (int, np.int64)):
from aliby.io.omero import Image
instatiator = Image
elif isinstance(source, dict) or (
isinstance(source, (str, PosixPath)) and Path(source).is_dir()
isinstance(source, (str, Path)) and Path(source).is_dir()
):
if Path(source).suffix == ".zarr":
instatiator = ImageZarr
......@@ -81,7 +84,7 @@ class BaseLocalImage(ABC):
_default_dimorder = "tczyx"
def __init__(self, path: t.Union[str, PosixPath]):
def __init__(self, path: t.Union[str, Path]):
# If directory, assume contents are naturally sorted
self.path = Path(path)
......@@ -249,7 +252,7 @@ class ImageDir(BaseLocalImage):
- Provides Dimorder as it is set in the filenames, or expects order during instatiation
"""
def __init__(self, path: t.Union[str, PosixPath], **kwargs):
def __init__(self, path: t.Union[str, Path], **kwargs):
super().__init__(path)
self.image_id = str(self.path.stem)
......@@ -305,7 +308,7 @@ class ImageZarr(BaseLocalImage):
skeletons/scripts/howto_omero/convert_clone_zarr_to_tiff.py
"""
def __init__(self, path: t.Union[str, PosixPath], **kwargs):
def __init__(self, path: t.Union[str, Path], **kwargs):
super().__init__(path)
self.set_meta()
try:
......
......@@ -5,7 +5,7 @@ Tools to manage I/O using a remote OMERO server.
import re
import typing as t
from abc import abstractmethod
from pathlib import PosixPath
from pathlib import Path
import dask.array as da
import numpy as np
......@@ -115,7 +115,7 @@ class BridgeOmero:
@classmethod
def server_info_from_h5(
cls,
filepath: t.Union[str, PosixPath],
filepath: t.Union[str, Path],
):
"""Return server info from hdf5 file.
......@@ -123,7 +123,7 @@ class BridgeOmero:
----------
cls : BridgeOmero
BridgeOmero class
filepath : t.Union[str, PosixPath]
filepath : t.Union[str, Path]
Location of hdf5 file.
Examples
......@@ -133,9 +133,8 @@ class BridgeOmero:
"""
# metadata = load_attributes(filepath)
bridge = BridgeH5(filepath)
server_info = safe_load(bridge.meta_h5["parameters"])["general"][
"server_info"
]
meta = safe_load(bridge.meta_h5["parameters"])["general"]
server_info = {k: meta[k] for k in ("host", "username", "password")}
return server_info
def set_id(self, ome_id: int):
......@@ -151,7 +150,7 @@ class BridgeOmero:
return valid_annotations
def add_file_as_annotation(
self, file_to_upload: t.Union[str, PosixPath], **kwargs
self, file_to_upload: t.Union[str, Path], **kwargs
):
"""Upload annotation to object on OMERO server. Only valid in subclasses.
......@@ -253,7 +252,7 @@ class Dataset(BridgeOmero):
@classmethod
def from_h5(
cls,
filepath: t.Union[str, PosixPath],
filepath: t.Union[str, Path],
):
"""Instatiate Dataset from a hdf5 file.
......@@ -261,7 +260,7 @@ class Dataset(BridgeOmero):
----------
cls : Image
Image class
filepath : t.Union[str, PosixPath]
filepath : t.Union[str, Path]
Location of hdf5 file.
Examples
......@@ -300,7 +299,7 @@ class Image(BridgeOmero):
@classmethod
def from_h5(
cls,
filepath: t.Union[str, PosixPath],
filepath: t.Union[str, Path],
):
"""Instatiate Image from a hdf5 file.
......@@ -308,7 +307,7 @@ class Image(BridgeOmero):
----------
cls : Image
Image class
filepath : t.Union[str, PosixPath]
filepath : t.Union[str, Path]
Location of hdf5 file.
Examples
......
#!/usr/bin/env jupyter
"""
Models that link regions of interest, such as mothers and buds.
"""
#!/usr/bin/env jupyter
"""
Extracted from the baby repository. Bud Tracker algorithm to link
cell outlines as mothers and buds.
"""
# /usr/bin/env jupyter
import pickle
from os.path import join
......
......@@ -6,7 +6,7 @@ import traceback
import typing as t
from copy import copy
from importlib.metadata import version
from pathlib import Path, PosixPath
from pathlib import Path
import h5py
import numpy as np
......@@ -76,14 +76,14 @@ class PipelineParameters(ParametersABC):
postprocessing: dict (optional)
Parameters for post-processing.
"""
# Alan: should 19993 be updated?
expt_id = general.get("expt_id", 19993)
if isinstance(expt_id, PosixPath):
if isinstance(expt_id, Path):
assert expt_id.exists()
expt_id = str(expt_id)
general["expt_id"] = expt_id
# Alan: an error message rather than a default might be better
directory = Path(general.get("directory", "../data"))
directory = Path(general["directory"])
# get log files, either locally or via OMERO
with dispatch_dataset(
......@@ -141,6 +141,18 @@ class PipelineParameters(ParametersABC):
# define defaults and update with any inputs
defaults["tiler"] = TilerParameters.default(**tiler).to_dict()
# Generate a backup channel, for when logfile meta is available
# but not image metadata.
backup_ref_channel = None
if "channels" in meta_d and isinstance(
defaults["tiler"]["ref_channel"], str
):
backup_ref_channel = meta_d["channels"].index(
defaults["tiler"]["ref_channel"]
)
defaults["tiler"]["backup_ref_channel"] = backup_ref_channel
defaults["baby"] = BabyParameters.default(**baby).to_dict()
defaults["extraction"] = (
exparams_from_meta(meta_d)
......@@ -174,8 +186,8 @@ class Pipeline(ProcessABC):
"extraction",
"postprocessing",
]
# Indicate step-writer groupings to perform special operations during step iteration
# Alan: replace with - specify the group in the h5 files written by each step (?)
# Specify the group in the h5 files written by each step
writer_groups = {
"tiler": ["trap_info"],
"baby": ["cell_info"],
......@@ -319,7 +331,7 @@ class Pipeline(ProcessABC):
config["general"]["directory"] = directory
self.setLogger(directory)
# pick particular images if desired
if pos_filter:
if pos_filter is not None:
if isinstance(pos_filter, list):
image_ids = {
k: v
......@@ -361,7 +373,7 @@ class Pipeline(ProcessABC):
def run_one_position(
self,
name_image_id: t.Tuple[str, str or PosixPath or int],
name_image_id: t.Tuple[str, str or Path or int],
index: t.Optional[int] = None,
):
"""Set up and run a pipeline for one position."""
......@@ -478,7 +490,7 @@ class Pipeline(ProcessABC):
f"Found {steps['tiler'].n_tiles} traps in {image.name}"
)
elif step == "baby":
# write state and pass info to ext (Alan: what's ext?)
# write state and pass info to Extractor
loaded_writers["state"].write(
data=steps[
step
......@@ -496,11 +508,12 @@ class Pipeline(ProcessABC):
frac_clogged_traps = self.check_earlystop(
filename, earlystop, steps["tiler"].tile_size
)
self._log(
f"{name}:Clogged_traps:{frac_clogged_traps}"
)
frac = np.round(frac_clogged_traps * 100)
pbar.set_postfix_str(f"{frac} Clogged")
if frac_clogged_traps > 0.3:
self._log(
f"{name}:Clogged_traps:{frac_clogged_traps}"
)
frac = np.round(frac_clogged_traps * 100)
pbar.set_postfix_str(f"{frac} Clogged")
else:
# stop if too many traps are clogged
self._log(
......@@ -555,7 +568,7 @@ class Pipeline(ProcessABC):
"""
# get the area of the cells organised by trap and cell number
s = Signal(filename)
df = s["/extraction/general/None/area"]
df = s.get_raw("/extraction/general/None/area")
# check the latest time points only
cells_used = df[
df.columns[-1 - es_parameters["ntps_to_eval"] : -1]
......@@ -573,10 +586,11 @@ class Pipeline(ProcessABC):
)
return (traps_above_nthresh & traps_above_athresh).mean()
# Alan: can both this method and the next be deleted?
# FIXME: Remove this functionality. It used to be for
# older hdf5 file formats.
def _load_config_from_file(
self,
filename: PosixPath,
filename: Path,
process_from: t.Dict[str, int],
trackers_state: t.List,
overwrite: t.Dict[str, bool],
......@@ -588,6 +602,8 @@ class Pipeline(ProcessABC):
process_from[k] += 1
return process_from, trackers_state, overwrite
# FIXME: Remove this functionality. It used to be for
# older hdf5 file formats.
@staticmethod
def legacy_get_last_tp(step: str) -> t.Callable:
"""Get last time-point in different ways depending
......@@ -608,7 +624,7 @@ class Pipeline(ProcessABC):
def _setup_pipeline(
self, image_id: int
) -> t.Tuple[
PosixPath,
Path,
MetaData,
t.Dict,
int,
......@@ -647,7 +663,7 @@ class Pipeline(ProcessABC):
States of any trackers from earlier runs.
"""
config = self.parameters.to_dict()
# Alan: session is never changed
# TODO Alan: Verify if session must be passed
session = None
earlystop = config["general"].get("earlystop", None)
process_from = {k: 0 for k in self.pipeline_steps}
......@@ -662,7 +678,7 @@ class Pipeline(ProcessABC):
}
# Set up
directory = general_config["directory"]
directory = config["general"]["directory"]
trackers_state: t.List[np.ndarray] = []
with dispatch_image(image_id)(image_id, **self.server_info) as image:
......@@ -700,8 +716,8 @@ class Pipeline(ProcessABC):
)
config["tiler"] = steps["tiler"].parameters.to_dict()
except Exception:
# Alan: a warning or log here?
pass
self._log(f"Overwriting tiling data")
if config["general"]["use_explog"]:
meta.run()
# add metadata not in the log file
......
......@@ -13,11 +13,12 @@ A peak-identifying algorithm recovers the x and y-axis location of traps in the
The experiment is stored as an array with a standard indexing order of (Time, Channels, Z-stack, X, Y).
"""
import logging
import re
import typing as t
import warnings
from functools import lru_cache
from pathlib import PosixPath
from pathlib import Path
import dask.array as da
import h5py
......@@ -26,7 +27,7 @@ from skimage.registration import phase_cross_correlation
from agora.abc import ParametersABC, StepABC
from agora.io.writer import BridgeH5
from aliby.io.image import ImageLocalOME, ImageDir, ImageDummy
from aliby.io.image import ImageDummy
from aliby.tile.traps import segment_traps
......@@ -182,12 +183,19 @@ class TileLocations:
class TilerParameters(ParametersABC):
"""Set default parameters for Tiler."""
"""
tile_size: int
ref_channel: str or int
ref_z: int
backup_ref_channel: int or None. If int, it is the index of the reference channel. Used when the image does not include metadata, ref_channel is a string, and channel names are included in the parsed logfiles.
"""
_defaults = {
"tile_size": 117,
"ref_channel": "Brightfield",
"ref_z": 0,
"backup_ref_channel": None,
}
......@@ -223,9 +231,14 @@ class Tiler(StepABC):
self.image = image
self._metadata = metadata
self.channels = metadata.get(
"channels", list(range(metadata.get("size_c", 0)))
"channels",
list(range(metadata.get("size_c", 0))),
)
self.ref_channel = self.get_channel_index(parameters.ref_channel)
if self.ref_channel is None:
self.ref_channel = self.backup_ref_channel
self.ref_channel = self.get_channel_index(parameters.ref_channel)
self.tile_locs = tile_locs
try:
......@@ -290,7 +303,7 @@ class Tiler(StepABC):
def from_h5(
cls,
image,
filepath: t.Union[str, PosixPath],
filepath: t.Union[str, Path],
parameters: t.Optional[TilerParameters] = None,
):
"""
......@@ -578,23 +591,23 @@ class Tiler(StepABC):
"""Return index of reference channel."""
return self.get_channel_index(self.parameters.ref_channel)
def get_channel_index(self, channel: str or int):
def get_channel_index(self, channel: str or int) -> int or None:
"""
Find index for channel using regex.
Returns the first matched string.
Find index for channel using regex. Returns the first matched string.
If self.channels is integers (no image metadata) it returns None.
If channel is integer
Parameters
----------
channel: string or int
The channel or index to be used.
"""
if all(map(lambda x: isinstance(x, int), self.channels)):
channel = channel if isinstance(channel, int) else None
if isinstance(channel, str):
channel = find_channel_index(self.channels, channel)
if channel is None:
raise Warning(
f"Reference channel {channel} not in the available channels: {self.channels}"
)
return channel
@staticmethod
......@@ -639,7 +652,7 @@ class Tiler(StepABC):
return tile
# Alan: do we need these as well as get_channel_index and get_channel_name?
# FIXME: Refactor to support both channel or index
# self._log below is not defined
def find_channel_index(image_channels: t.List[str], channel: str):
"""
......@@ -649,7 +662,10 @@ def find_channel_index(image_channels: t.List[str], channel: str):
found = re.match(channel, ch, re.IGNORECASE)
if found:
if len(found.string) - (found.endpos - found.start()):
self._log(f"Channel {channel} matched {ch} using regex")
logging.getLogger("aliby").log(
logging.WARNING,
f"Channel {channel} matched {ch} using regex",
)
return i
......
......@@ -6,7 +6,7 @@ import pickle
import typing as t
from collections import Counter
from os.path import dirname, join
from pathlib import Path, PosixPath
from pathlib import Path
import numpy as np
from scipy.optimize import linear_sum_assignment
......@@ -70,11 +70,11 @@ class CellTracker(FeatureCalculator):
if extrafeats is None:
extrafeats = ()
if type(model) is str or type(model) is PosixPath:
if type(model) is str or type(model) is Path:
with open(Path(model), "rb") as f:
model = pickle.load(f)
if type(bak_model) is str or type(bak_model) is PosixPath:
if type(bak_model) is str or type(bak_model) is Path:
with open(Path(bak_model), "rb") as f:
bak_model = pickle.load(f)
......
......@@ -5,7 +5,7 @@ import re
import typing as t
from collections import Counter
from datetime import datetime
from pathlib import Path, PosixPath
from pathlib import Path
import numpy as np
from logfile_parser import Parser
......@@ -392,7 +392,7 @@ def get_chs(exptype):
def load_annot_from_cache(exp_id, cache_dir="cache/"):
# TODO Documentation
if type(cache_dir) is not PosixPath:
if type(cache_dir) is not Path:
cache_dir = Path(cache_dir)
annot_sets = {}
......
......@@ -6,12 +6,12 @@ Example of usage:
fpath = "/home/alan/Documents/dev/skeletons/scripts/data/16543_2019_07_16_aggregates_CTP_switch_2_0glu_0_0glu_URA7young_URA8young_URA8old_01/URA8_young018.h5"
trap_id = 9
trange = list(range(0, 30))
tile_id = 9
trange = list(range(0, 10))
ncols = 8
riv = remoteImageViewer(fpath)
riv.plot_labelled_trap(trap_id, trange, [0], ncols=ncols)
riv.plot_labelled_trap(tile_id, trange, [0], ncols=ncols)
"""
......@@ -27,9 +27,9 @@ from PIL import Image
from skimage.morphology import dilation
from agora.io.cells import Cells
from agora.io.writer import load_attributes
from aliby.io.image import dispatch_image
from agora.io.metadata import dispatch_metadata_parser
from aliby.tile.tiler import Tiler, TilerParameters
from aliby.utils.plot import stretch_clip
default_colours = {
"Brightfield": "Greys_r",
......@@ -55,37 +55,17 @@ def custom_imshow(a, norm=None, cmap=None, *args, **kwargs):
)
class localImageViewer:
    """
    Fast access to Images segmented locally without tiling
    from image.h5 objects.

    Parameters
    ----------
    h5file : str or Path
        HDF5 file whose top-level keys are positions.
    data_source : optional
        Unused; kept for backward compatibility.
    """

    def __init__(self, h5file, data_source=None):
        # Open explicitly read-only: a viewer must never modify the file.
        self._hdf = h5py.File(h5file, "r")
        self.positions = list(self._hdf.keys())
        self.current_position = self.positions[0]

    def close(self):
        """Close the underlying HDF5 file handle."""
        self._hdf.close()

    def plot_position(self, channel=0, tp=0, z=0, stretch=True):
        """
        Build an 8-bit image of the current position.

        Parameters
        ----------
        channel, tp, z : int
            Indices used as ``dataset[channel, tp, ..., z]``.
        stretch : bool
            If True, clip to the 0.5th-99.5th percentile range and rescale
            to 0-255 before conversion.

        Returns
        -------
        PIL.Image.Image
            The image (previously it was built and then discarded).
        """
        pixvals = self._hdf[self.current_position][channel, tp, ..., z]
        if stretch:
            minval = np.percentile(pixvals, 0.5)
            maxval = np.percentile(pixvals, 99.5)
            pixvals = np.clip(pixvals, minval, maxval)
            span = maxval - minval
            if span > 0:
                pixvals = ((pixvals - minval) / span) * 255
            else:
                # Flat image: avoid 0/0, which would yield NaNs.
                pixvals = np.zeros_like(pixvals)
        return Image.fromarray(pixvals.astype(np.uint8))
class BaseImageViewer(ABC):
def __init__(self, fpath):
self._fpath = fpath
attrs = load_attributes(fpath)
attrs = dispatch_metadata_parser(fpath.parent)
self._logfiles_meta = {}
self._logfiles_meta["channels"] = attrs["channels/channel"]
self.image_id = attrs.get("image_id")
if self.image_id is None:
with h5py.File(fpath, "r") as f:
self.image_id = f.attrs.get("image_id")
assert self.image_id is not None, "No valid image_id found in metadata"
......@@ -111,6 +91,7 @@ class LocalImageViewer(BaseImageViewer):
"""
Tool to generate figures from local files, either zarr or files organised
in directories.
TODO move common functionality from RemoteImageViewer to BaseImageViewer
"""
def __init__(self, results_path: str, data_path: str):
......@@ -132,30 +113,28 @@ class LocalImageViewer(BaseImageViewer):
self.cells = Cells.from_source(results_path)
class remoteImageViewer(BaseImageViewer):
class RemoteImageViewer(BaseImageViewer):
"""
This ImageViewer combines fetching remote images with tiling and outline display.
"""
_credentials = ("host", "username", "password")
def __init__(
self,
results_path: str,
server_info: t.Dict[str, str],
):
self.super().__init__(results_path)
super().__init__(results_path)
from aliby.io.omero import Image as OImage
from aliby.io.omero import UnsafeImage as OImage
self._image_class = OImage
self._server_info = server_info or {
k: attrs["parameters"]["general"][k] for k in self._credentials
}
self._server_info = (
server_info or attrs["parameters"]["general"]["server_info"]
)
with dispatch_image(self.image_id)(
self.image_id, **self.server_info
) as image:
self.tiler = Tiler.from_hdf5(image, results_path)
self._image_instance = OImage(self.image_id, **self._server_info)
self.tiler = Tiler.from_h5(self._image_instance, results_path)
self.cells = Cells.from_source(results_path)
......@@ -226,25 +205,26 @@ class remoteImageViewer(BaseImageViewer):
if z is None:
z = 0
server_info = server_info or self.server_info
server_info = server_info or self._server_info
channels = 0 or self._find_channels(channels)
z = z or self.tiler.ref_z
ch_tps = [(channels[0], tp) for tp in tps]
with self._image_class(self.image_id, **server_info) as image:
self.tiler.image = image.data
for ch, tp in ch_tps:
if (ch, tp) not in self.full:
self.full[(ch, tp)] = self.tiler.get_tiles_timepoint(
tp, channels=[ch], z=[z]
)[:, 0, 0, ..., z]
requested_trap = {tp: self.full[(ch, tp)] for ch, tp in ch_tps}
return requested_trap
image = self._image_instance
self.tiler.image = image.data
for ch, tp in ch_tps:
if (ch, tp) not in self.full:
self.full[(ch, tp)] = self.tiler.get_tiles_timepoint(
tp, channels=[ch], z=[z]
)[:, 0, 0, z, ...]
requested_trap = {tp: self.full[(ch, tp)] for ch, tp in ch_tps}
return requested_trap
def get_labelled_trap(
self,
trap_id: int,
tile_id: int,
tps: t.Union[range, t.Collection[int]],
channels=None,
concatenate=True,
......@@ -254,12 +234,12 @@ class remoteImageViewer(BaseImageViewer):
Core method to fetch traps and labels together
"""
imgs = self.get_pos_timepoints(tps, channels=channels, **kwargs)
imgs_list = [x[trap_id] for x in imgs.values()]
imgs_list = [x[tile_id] for x in imgs.values()]
outlines = [
self.cells.at_time(tp, kind="edgemask").get(trap_id, [])
self.cells.at_time(tp, kind="edgemask").get(tile_id, [])
for tp in tps
]
lbls = [self.cells.labels_at_time(tp).get(trap_id, []) for tp in tps]
lbls = [self.cells.labels_at_time(tp).get(tile_id, []) for tp in tps]
lbld_outlines = [
np.stack([mask * lbl for mask, lbl in zip(maskset, lblset)]).max(
axis=0
......@@ -273,7 +253,7 @@ class remoteImageViewer(BaseImageViewer):
imgs_list = np.concatenate(imgs_list, axis=1)
return lbld_outlines, imgs_list
def get_images(self, trap_id, trange, channels, **kwargs):
def get_images(self, tile_id, trange, channels, **kwargs):
"""
Wrapper to fetch images
"""
......@@ -282,13 +262,13 @@ class remoteImageViewer(BaseImageViewer):
for ch in self._find_channels(channels):
out, imgs[ch] = self.get_labelled_trap(
trap_id, trange, channels=[ch], **kwargs
tile_id, trange, channels=[ch], **kwargs
)
return out, imgs
def plot_labelled_trap(
self,
trap_id: int,
tile_id: int,
channels,
trange: t.Union[range, t.Collection[int]],
remove_axis: bool = False,
......@@ -308,7 +288,7 @@ class remoteImageViewer(BaseImageViewer):
Parameters
----------
trap_id : int
tile_id : int
Identifier of trap
channels : Union[str, int]
Channels to use
......@@ -345,7 +325,7 @@ class remoteImageViewer(BaseImageViewer):
nrows = int(np.ceil(len(trange) / ncols))
width = self.tiler.tile_size * ncols
out, images = self.get_images(trap_id, trange, channels, **kwargs)
out, images = self.get_images(tile_id, trange, channels, **kwargs)
# dilation makes outlines easier to see
out = dilation(out).astype(float)
......@@ -363,7 +343,7 @@ class remoteImageViewer(BaseImageViewer):
), "Invalid norm argument."
if norm and norm in ("l1", "l2", "max"):
images = {k: stretch_image(v) for k, v in images.items()}
images = {k: stretch_clip(v) for k, v in images.items()}
images = [concat_pad(img, width, nrows) for img in images.values()]
# TODO convert to RGB to draw fluorescence with colour
......@@ -443,35 +423,3 @@ def concat_pad(a: np.array, width, nrows):
axis=1,
)
)
def stretch_image(image):
    """
    Contrast-stretch an image.

    The pixel values are rescaled linearly to the range 0-255, clipped to
    the interval spanned by their 2nd and 98th percentiles, and finally
    rescaled so that this interval maps onto 0-1.

    Parameters
    ----------
    image : array-like
        Input image.

    Returns
    -------
    stretched : ndarray
        Contrast-stretched version of the input image.

    Notes
    -----
    FIXME: the original docstring was machine-generated; confirm manually.
    """
    scaled = ((image - image.min()) / (image.max() - image.min())) * 255
    lower = np.percentile(scaled, 2)
    upper = np.percentile(scaled, 98)
    clipped = np.clip(scaled, lower, upper)
    return (clipped - lower) / (upper - lower)
......@@ -11,15 +11,29 @@ from matplotlib import pyplot as plt
def plot_overlay(
bg: np.ndarray, fg: np.ndarray, alpha: float = 0.5, ax=plt
bg: np.ndarray, fg: np.ndarray, alpha: float = 1.0, ax=plt
) -> None:
"""
Plot two images, one on top of the other.
"""
ax.imshow(bg, cmap=plt.cm.gray, interpolation="none")
ax.imshow(fg, alpha=alpha, interpolation="none")
ax.axis("off")
ax1 = ax.imshow(
bg, cmap=plt.cm.gray, interpolation="none", interpolation_stage="rgba"
)
ax2 = ax.imshow(
stretch(fg),
alpha=alpha,
interpolation="none",
interpolation_stage="rgba",
)
plt.axis("off")
return ax1, ax2
def plot_overlay_in_square(data: t.Tuple[np.ndarray, np.ndarray]):
"""
Plot images in an automatically-arranged grid.
"""
specs = strategies.SquareStrategy("center").get_grid(len(data))
for i, (gs, (tile, mask)) in enumerate(zip(specs, data)):
ax = plt.subplot(gs)
......@@ -27,7 +41,52 @@ def plot_overlay_in_square(data: t.Tuple[np.ndarray, np.ndarray]):
def plot_in_square(data: t.Iterable):
    """
    Draw every element of ``data`` as an image in its own cell of an
    automatically-arranged grid (one image per cell, no overlay).
    """
    grid = strategies.SquareStrategy("center").get_grid(len(data))
    for cell, datum in zip(grid, data):
        plt.subplot(cell).imshow(datum)
def stretch_clip(image, clip=True):
    """
    Perform contrast stretching on an input image.

    Pixel values are first rescaled linearly to the range 0-255. If
    ``clip`` is true, values below the 2nd or above the 98th percentile are
    clipped and the result is rescaled to the range 0-1; otherwise the
    0-255 image is returned.

    NaNs are ignored when computing the minimum, maximum and percentiles,
    and stay NaN in the output, so the result always has the shape of the
    input. (Masking the NaNs out with boolean indexing, as was done before,
    flattened every image to 1-D.)

    Parameters
    ----------
    image : array-like
        Input image.
    clip : bool
        Whether to clip to the 2nd-98th percentile range and rescale to 0-1.

    Returns
    -------
    stretched : ndarray
        Contrast-stretched version of the input image, same shape as
        ``image``.
    """

    def rescale(values, low, high):
        # Map [low, high] onto [0, 1]; a flat range maps to 0 rather than 0/0.
        span = high - low
        if span == 0:
            return np.where(np.isnan(values), np.nan, 0.0)
        return (values - low) / span

    image = np.asarray(image, dtype=float)
    valid = ~np.isnan(image)
    if not valid.any():
        # Empty or all-NaN input: there is nothing to stretch.
        return image.copy()
    image = rescale(image, image[valid].min(), image[valid].max()) * 255
    if clip:
        minval = np.percentile(image[valid], 2)
        maxval = np.percentile(image[valid], 98)
        image = rescale(np.clip(image, minval, maxval), minval, maxval)
    return image
def stretch(image):
    """
    Rescale an image linearly so that its smallest non-NaN value maps to 0
    and its largest non-NaN value maps to 1. NaNs are left in place.
    """
    finite_values = image[np.logical_not(np.isnan(image))]
    lowest = finite_values.min()
    value_range = finite_values.max() - lowest
    return (image - lowest) / value_range
......@@ -11,7 +11,7 @@ from copy import copy
import numpy as np
from agora.io.cells import Cells
from aliby.io.image import instatiate_image
from aliby.io.image import instantiate_image
from aliby.tile.tiler import Tiler, TilerParameters
......@@ -21,7 +21,7 @@ def fetch_tc(
"""
Return 3D ndarray with (Z,Y,X) for a given pair of time point and channel.
"""
with instatiate_image(image_path) as img:
with instantiate_image(image_path) as img:
tiler = Tiler.from_h5(img, results_path, TilerParameters.default())
tc = tiler.get_tp_data(t, c)
return tc
......@@ -55,7 +55,7 @@ def get_tiles_at_times(
"""
# Get the correct tile in space and time
with instatiate_image(image_path) as image:
with instantiate_image(image_path) as image:
tiler = Tiler.from_h5(image, results_path, TilerParameters.default())
tp_channel_stack = [
_dispatch_tile_reduction(tile_reduction)(
......
......@@ -61,23 +61,6 @@ class ExtractorParameters(ParametersABC):
self.sub_bg = sub_bg
self.multichannel_ops = multichannel_ops
@staticmethod
def guess_from_meta(store_name: str, suffix="fast"):
    """
    Build a parameter-set name from the microscope recorded in an h5 file.

    Parameters
    ----------
    store_name : str or Path
        Path to the h5 file; its root "microscope" attribute is read.
    suffix : str
        Joined to the microscope name with an underscore.
    """
    with h5py.File(store_name, "r") as store:
        root_attrs = store["/"].attrs
        microscope = root_attrs.get("microscope")
    assert microscope, "No metadata found"
    name_parts = (microscope, suffix)
    return "_".join(name_parts)
@classmethod
def default(cls):
    """Return an instance of the class constructed from an empty dictionary."""
    return cls({})
......@@ -100,7 +83,7 @@ class Extractor(StepABC):
Extraction follows a three-level tree structure. Channels, such as GFP, are the root level; the reduction algorithm, such as maximum projection, is the second level; the specific metric, or operation, to apply to the masks, such as mean, is the third level.
"""
# Alan: should this data be stored here or all such data in a separate file
# TODO Alan: Move this to a location with the SwainLab defaults
default_meta = {
"pixel_size": 0.236,
"z_size": 0.6,
......@@ -407,8 +390,6 @@ class Extractor(StepABC):
reduced = img
if method is not None:
reduced = reduce_z(img, method)
if reduced.shape[0] < 10:
print("ahoy")
return reduced
def extract_tp(
......@@ -482,14 +463,8 @@ class Extractor(StepABC):
# stored as an array arranged as (traps, channels, time points, X, Y, Z)
tiles = self.get_tiles(tp, tile_shape=tile_size, channels=tree_chs)
# generate boolean masks for background as a list with one mask per trap
bgs = []
bgs = np.array([])
if self.params.sub_bg:
# bgs = [
# ~np.sum(m, axis=0).astype(bool)
# if np.any(m)
# else np.zeros((tile_size, tile_size)).astype(bool)
# for m in masks
# ]
bgs = ~np.array(
list(
map(
......
......@@ -129,7 +129,6 @@ def nuc_est_conv(
def nuc_conv_3d(cell_mask, trap_image, pixel_size=0.23, spacing=0.6):
print(cell_mask.shape, trap_image.shape)
cell_mask = np.stack([cell_mask] * trap_image.shape[0])
ratio = spacing / pixel_size
cell_fluo = trap_image[cell_mask]
......