Skip to content
Snippets Groups Projects

Compare revisions

Changes are shown as if the source revision was being merged into the target revision. Learn more about comparing revisions.

Source

Select target project
No results found

Target

Select target project
  • swain-lab/aliby/aliby-mirror
  • swain-lab/aliby/alibylite
2 results
Show changes
Showing
with 600 additions and 132 deletions
#!/usr/bin/env python3
from abc import ABC
class BasePlotter(ABC):
"""Base class for plotting handler classes"""
def __init__(self, trace_name, unit_scaling, xlabel, plot_title):
"""Common attributes"""
self.trace_name = trace_name
self.unit_scaling = unit_scaling
self.xlabel = xlabel
self.ylabel = None
self.plot_title = plot_title
def plot(self, ax):
"""Template for drawing on provided Axes"""
ax.set_ylabel(self.ylabel)
ax.set_xlabel(self.xlabel)
ax.set_title(self.plot_title)
# Derived classes extends this with plotting functions
# TODO: something about the plotting functions at the end of the modules.
# Decorator?
#!/usr/bin/env python3
import matplotlib.pyplot as plt
from postprocessor.routines.single_plot import _SinglePlotter
class _SingleBirthPlotter(_SinglePlotter):
"""Draw a line plot of a single time series, but with buddings overlaid"""
def __init__(
self,
trace_timepoints,
trace_values,
trace_name,
birth_mask,
unit_scaling,
trace_color,
birth_color,
trace_linestyle,
birth_linestyle,
xlabel,
ylabel,
birth_label,
plot_title,
):
# Define attributes from arguments
super().__init__(
trace_timepoints,
trace_values,
trace_name,
unit_scaling,
trace_color,
trace_linestyle,
xlabel,
ylabel,
plot_title,
)
# Add some more attributes useful for buddings
self.birth_mask = birth_mask
self.birth_color = birth_color
self.birth_linestyle = birth_linestyle
self.birth_label = birth_label
def plot(self, ax):
"""Draw the line plots on the provided Axes."""
trace_time = self.trace_timepoints * self.unit_scaling
super().plot(ax)
birth_mask_bool = self.birth_mask.astype(bool)
for occurence, birth_time in enumerate(trace_time[birth_mask_bool]):
if occurence == 0:
label = self.birth_label
else:
label = None
ax.axvline(
birth_time,
color=self.birth_color,
linestyle=self.birth_linestyle,
label=label,
)
ax.legend()
def single_birth_plot(
trace_timepoints,
trace_values,
trace_name="flavin",
birth_mask=None,
unit_scaling=1,
trace_color="b",
birth_color="k",
trace_linestyle="-",
birth_linestyle="--",
xlabel="Time (min)",
ylabel="Normalised flavin fluorescence (AU)",
birth_label="budding event",
plot_title="",
ax=None,
):
"""Plot time series of trace, overlaid with buddings
Parameters
----------
trace_timepoints : array_like
Time points (as opposed to the actual times in time units)
trace_values : array_like
Trace to plot
trace_name : string
Name of trace being plotted, e.g. 'flavin'.
birth_mask : array_like
Mask to indicate where buddings are. Expect values of '0' and '1' or
'False' and 'True' in the elements.
unit_scaling : int or float
Unit scaling factor, e.g. 1/60 to convert minutes to hours.
trace_color : string
matplotlib colour string for the trace
birth_color : string
matplotlib colour string for the vertical lines indicating buddings
trace_linestyle : string
matplotlib linestyle argument for the trace
birth_linestyle : string
matplotlib linestyle argument for the vertical lines indicating buddings
xlabel : string
x axis label.
ylabel : string
y axis label.
birth_label : string
label for budding event, 'budding event' by default.
plot_title : string
Plot title.
ax : matplotlib Axes
Axes in which to draw the plot, otherwise use the currently active Axes.
Returns
-------
ax : matplotlib Axes
Axes object with the plot.
Examples
--------
FIXME: Add docs.
"""
plotter = _SingleBirthPlotter(
trace_timepoints,
trace_values,
trace_name,
birth_mask,
unit_scaling,
trace_color,
birth_color,
trace_linestyle,
birth_linestyle,
xlabel,
ylabel,
birth_label,
plot_title,
)
if ax is None:
ax = plt.gca()
plotter.plot(ax)
return ax
#!/usr/bin/env python3
import matplotlib.pyplot as plt
from postprocessor.routines.plottingabc import BasePlotter
class _SinglePlotter(BasePlotter):
"""Draw a line plot of a single time series."""
def __init__(
self,
trace_timepoints,
trace_values,
trace_name,
unit_scaling,
trace_color,
trace_linestyle,
xlabel,
ylabel,
plot_title,
):
super().__init__(trace_name, unit_scaling, xlabel, plot_title)
# Define attributes from arguments
self.trace_timepoints = trace_timepoints
self.trace_values = trace_values
self.trace_color = trace_color
self.trace_linestyle = trace_linestyle
# Define some labels
self.ylabel = ylabel
def plot(self, ax):
"""Draw the line plot on the provided Axes."""
super().plot(ax)
ax.plot(
self.trace_timepoints * self.unit_scaling,
self.trace_values,
color=self.trace_color,
linestyle=self.trace_linestyle,
label=self.trace_name,
)
def single_plot(
trace_timepoints,
trace_values,
trace_name="flavin",
unit_scaling=1,
trace_color="b",
trace_linestyle="-",
xlabel="Time (min)",
ylabel="Normalised flavin fluorescence (AU)",
plot_title="",
ax=None,
):
"""Plot time series of trace.
Parameters
----------
trace_timepoints : array_like
Time points (as opposed to the actual times in time units).
trace_values : array_like
Trace to plot.
trace_name : string
Name of trace being plotted, e.g. 'flavin'.
unit_scaling : int or float
Unit scaling factor, e.g. 1/60 to convert minutes to hours.
trace_color : string
matplotlib colour string, specifies colour of line plot.
trace_linestyle : string
matplotlib linestyle argument.
xlabel : string
x axis label.
ylabel : string
y axis label.
plot_title : string
Plot title.
ax : matplotlib Axes
Axes in which to draw the plot, otherwise use the currently active Axes.
Returns
-------
ax : matplotlib Axes
Axes object with the plot.
Examples
--------
FIXME: Add docs.
"""
plotter = _SinglePlotter(
trace_timepoints,
trace_values,
trace_name,
unit_scaling,
trace_color,
trace_linestyle,
xlabel,
ylabel,
plot_title,
)
if ax is None:
ax = plt.gca()
plotter.plot(ax)
return ax
"""
Basic ParametersIO tests
"""
import pytest
from agora.abc import ParametersABC
class DummyParameters(ParametersABC):
def __init__(self, **kwargs):
super().__init__(**kwargs)
@classmethod
def default(cls):
# Necessary empty builder
return cls.from_dict({})
def test_file_exists(yaml_file):
assert yaml_file.exists()
def test_from_yaml(yaml_file):
# From yaml
params = DummyParameters.from_yaml(yaml_file)
def test_from_stdin(yaml_file):
# From yaml
params = DummyParameters.from_yaml(yaml_file)
# To yaml
assert isinstance(params, ParametersABC)
def test_to_yaml(yaml_file):
with open(yaml_file, "r") as fd:
yaml_data = fd.read()
params = DummyParameters.from_yaml(yaml_file)
assert params.to_yaml() == yaml_data
def test_dict(example_dict):
params = DummyParameters(**example_dict)
assert params.to_dict() == example_dict
# Remove
params.to_yaml("outfile.yml")
def test_to_dict():
DummyParameters.default().to_dict()
#!/usr/bin/env jupyter
"""
Load data necessary to test agora.
"""
import typing as t
from pathlib import Path
import pytest
@pytest.fixture(scope="module")
def data_dir():
return Path(__file__).parent / "data"
@pytest.fixture(scope="module")
def yaml_file(data_dir: Path):
data = data_dir / "parameters.yaml"
if not data.exists():
pytest.fail(f"There is no file at {str( data_dir )}.")
return data
@pytest.fixture(scope="module", autouse=True)
def example_dict() -> t.Dict:
return dict(
string="abc",
number=1,
boolean=True,
dictionary=dict(
# empty_dict=dict(),
string="def",
number=2,
),
)
{}
# TODO turn into Unittest test case
import os
import sys
sys.path.insert(0, os.path.dirname(os.path.dirname(__file__)))
import json
import logging
from logging.handlers import RotatingFileHandler
from core.experiment import Experiment
logger = logging.getLogger('core')
logger.handlers = []
logger.setLevel(logging.DEBUG)
console = logging.StreamHandler()
console.setLevel(logging.WARNING)
formatter = logging.Formatter('%(name)-12s: %(levelname)-8s %(message)s')
console.setFormatter(formatter)
logger.addHandler(console)
file_handler = RotatingFileHandler(filename='test.log',
maxBytes=1e5,
backupCount=1)
file_handler.setLevel(logging.DEBUG)
file_formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s '
'- %(message)s')
file_handler.setFormatter(file_formatter)
logger.addHandler(file_handler)
logger.debug('Set up the loggers as test.')
with open('config.json', 'r') as fd:
config = json.load(fd)
expt = Experiment.from_source(config['experiment'], config['user'],
config['password'], config['host'],
config['port'])
print(expt.metadata.channels)
print(expt.metadata.times)
print(expt.metadata.switch_params)
print(expt.metadata.zsections)
print(expt.metadata.positions)
# print(expt.get_hypercube(x=None, y=None, width=None, height=None,
# z_positions=[0], channels=[0], timepoints=[0]))
# expt.cache_locally(root_dir='/Users/s1893247/PhD/pipeline-core/data/',
# positions=['pos001', 'pos002', 'pos003'],
# channels=['Brightfield', 'GFP'],
# timepoints=range(3),
# z_positions=None)
expt.connection.seppuku()
# Example of argo experiment explorer
import pytest
from aliby.utils.argo import Argo
......
from itertools import product
import pytest
from extraction.core.extractor import Extractor, ExtractorParameters
from extraction.core.functions import cell
from extraction.core.functions.trap import imBackground
from extraction.core.functions.loaders import (
load_funs,
load_cellfuns,
load_trapfuns,
load_redfuns,
)
from extraction.examples import data
from extraction import local_data_loaders as data
dsets1z = data.load_1z()
dsets = data.load()
......@@ -32,12 +29,14 @@ def test_metrics_run(imgs, masks, f):
"""
for ch, img in imgs.items():
if ch is not "segoutlines":
assert tuple(masks.shape[:2]) == tuple(imgs[ch].shape)
if ch != "segoutlines":
assert tuple(masks.shape[-2:]) == tuple(imgs[ch].shape)
f(masks, img)
@pytest.mark.parametrize(["imgs", "masks", "tree"], product(dsets, masks, tree))
@pytest.mark.parametrize(
["imgs", "masks", "tree"], list(product(dsets, masks, tree))
)
def test_extractor(imgs, masks, tree):
"""
Test a tiler-less extractor using an instance built using default parameters.
......@@ -46,15 +45,14 @@ def test_extractor(imgs, masks, tree):
Tests reduce-extract
"""
extractor = Extractor(
ExtractorParameters.from_meta({"channels/channel": ["Brightfield", "GFP"]})
ExtractorParameters.from_meta({"channels": ["Brightfield", "GFP"]})
)
# Load all available functions
extractor._all_funs = load_funs()[2]
extractor._all_cell_funs = load_cellfuns()
extractor.load_funs()
extractor.tree = tree
traps = imgs["GFP"]
# Generate mock labels
labels = list(range(masks.shape[2]))
labels = list(range(len(masks)))
for ch_branches in extractor.params.tree.values():
print(
extractor.reduce_extract(
......
import numpy as np
from pathlib import Path
from extraction.core.extractor import Extractor, ExtractorParameters
params = ExtractorParameters.from_meta(
{"channels/channel": ["Brightfield", "GFPFast", "pHluorin405", "mCherry"]}
{"channels": ["Brightfield", "GFPFast", "pHluorin405", "mCherry"]}
)
ext = Extractor(params)
ext.load_funs()
......@@ -11,14 +11,14 @@ ext.load_funs()
def test_custom_output():
self = ext
mask = np.zeros((6, 6, 2), dtype=bool)
mask[2:4, 2:4, 0] = True
mask[3:5, 3:5, 1] = True
img = np.random.randint(1, 11, size=6 ** 2 * 5).reshape(6, 6, 5)
mask = np.zeros((2, 6, 6), dtype=bool)
mask[0, 2:4, 2:4] = True
mask[1, 3:5, 3:5] = True
img = np.random.randint(1, 11, size=6**2 * 5).reshape(5, 6, 6)
for i, f in self._custom_funs.items():
if "3d" in i:
res = f(mask, img)
else:
res = f(mask, np.maximum.reduce(img, axis=2))
assert len(res) == mask.shape[2], "Output doesn't match input"
res = f(mask, np.maximum.reduce(img, axis=0))
assert len(res) == mask.shape[0], "Output doesn't match input"
......@@ -3,29 +3,30 @@ This code requires a functional OMERO database on localhost at port 4064
See the README for instructions as to how to set these up with docker.
"""
# TODO remove and use unittest to run tests
import os
import sys
import os
import sys
sys.path.insert(0, os.path.dirname(os.path.dirname(__file__)))
USERNAME = 'root'
PASSWORD = 'omero-root-password'
HOST = 'localhost'
USERNAME = "root"
PASSWORD = "omero-root-password"
HOST = "localhost"
PORT = 4064
from omero.gateway import BlitzGateway
def print_obj(obj, indent=0):
"""
Helper method to display info about OMERO objects.
Not all objects will have a "name" or owner field.
"""
print("""%s%s:%s Name:"%s" """ % (
" " * indent,
obj.OMERO_CLASS,
obj.getId(),
obj.getName()))
print(
"""%s%s:%s Name:"%s" """
% (" " * indent, obj.OMERO_CLASS, obj.getId(), obj.getName())
)
if __name__ == '__main__':
if __name__ == "__main__":
# Connect to the Python Blitz Gateway
# ===================================
......@@ -38,9 +39,11 @@ if __name__ == '__main__':
# ==========================
if not connected:
import sys
sys.stderr.write(
"Error: Connection not available, please check your user name and"
" password.\n")
" password.\n"
)
sys.exit(1)
# Using secure connection
......@@ -58,16 +61,15 @@ if __name__ == '__main__':
# clients.
user = conn.getUser()
print( "Current user:")
print( " ID:", user.getId())
print( " Username:", user.getName())
print( " Full Name:", user.getFullName())
print("Current user:")
print(" ID:", user.getId())
print(" Username:", user.getName())
print(" Full Name:", user.getFullName())
# Check if you are an Administrator
print( " Is Admin:", conn.isAdmin())
print(" Is Admin:", conn.isAdmin())
# Close connection
# ================
# When you are done, close the session to free up server resources.
# Close connection
# ================
# When you are done, close the session to free up server resources.
conn.close()
import pytest
import logging
from logging.handlers import RotatingFileHandler
import unittest
from logging.handlers import RotatingFileHandler
from pathlib import Path
from aliby.experiment import Experiment
import pytest
# from aliby.experiment import Experiment
## LOGGING
logger = logging.getLogger("core")
......@@ -18,7 +18,9 @@ formatter = logging.Formatter("%(name)-12s: %(levelname)-8s %(message)s")
console.setFormatter(formatter)
logger.addHandler(console)
file_handler = RotatingFileHandler(filename="test.log", maxBytes=1e5, backupCount=1)
file_handler = RotatingFileHandler(
filename="test.log", maxBytes=1e5, backupCount=1
)
file_handler.setLevel(logging.DEBUG)
file_formatter = logging.Formatter(
......@@ -33,10 +35,13 @@ data_directory = Path(__file__).parent.parent / "data/"
root_directory = data_directory / "glclvl_0.1_mig1_msn2_maf1_sfp1_dot6_03"
@pytest.mark.skip(reason="No longer usable, requires local files. Kept until replaced.")
@pytest.mark.skip(
reason="No longer usable, requires local files. Kept until replaced."
)
class TestCase(unittest.TestCase):
def setUp(self):
self.expt = Experiment.from_source(root_directory, finished=True)
# self.expt = Experiment.from_source(root_directory, finished=True)
self.expt = None
def test_experiment_shape(self):
print("C: {}, T: {}, X: {}, Y: {}, Z: {}".format(*self.expt.shape))
......
import os
import unittest
from pathlib import Path
import pytest
import os
from pathlib import Path
from aliby.baby_client import BabyRunner
from aliby.experiment import ExperimentOMERO, ExperimentLocal
# from aliby.experiment import ExperimentOMERO, ExperimentLocal
from aliby.tile.tiler import Tiler
......
import unittest
import matplotlib.pyplot as plt
import numpy as np
import pytest
import skimage.morphology as morph
from scipy import ndimage
from skimage import draw
import unittest
from aliby.post_processing import (
conical,
ellipse_perimeter,
union_of_spheres,
volume_of_sphere,
circle_outline,
)
@pytest.mark.skip(
reason="No longer usable, post_processing unused inside aliby. Kept temporarily"
)
class VolumeEstimation(unittest.TestCase):
def test_conical(self):
radius = np.random.choice(range(60, 100))
......@@ -23,17 +20,23 @@ class VolumeEstimation(unittest.TestCase):
print(radius, con, b_sum)
self.assertAlmostEqual(abs(con - b_sum) / b_sum, 0, delta=0.10)
@pytest.mark.skip(
reason="No longer usable, post_processing unused inside aliby. Kept temporarily"
)
def test_conical_ellipse(self):
e = ellipse_perimeter(4, 5)
con = conical(e)
true = draw.ellipsoid_stats(4, 5, 4)[0]
print(con, true)
@pytest.mark.skip(
reason="No longer usable, post_processing unused inside aliby. Kept temporarily"
)
def test_sphere_error(self):
radii = range(3, 30)
con = [conical(circle_outline(radius)) for radius in radii]
spheres = [union_of_spheres(circle_outline(r)) for r in radii]
true = [4 * (r ** 3) * np.pi / 3 for r in radii]
true = [4 * (r**3) * np.pi / 3 for r in radii]
mVol = [
4 / 3 * np.pi * np.sqrt(morph.disk(radius).sum() / np.pi) ** 3
for radius in radii
......@@ -48,10 +51,15 @@ class VolumeEstimation(unittest.TestCase):
plt.legend()
# plt.show()
@pytest.mark.skip(
reason="No longer usable, post_processing unused inside aliby. Kept temporarily"
)
def test_ellipse_error(self):
x_radii = range(3, 30)
y_radii = [np.ceil(2.5 * r) for r in x_radii]
ellipses = [ellipse_perimeter(x_r, y_r) for x_r, y_r in zip(x_radii, y_radii)]
ellipses = [
ellipse_perimeter(x_r, y_r) for x_r, y_r in zip(x_radii, y_radii)
]
con = [conical(ellipse) for ellipse in ellipses]
spheres = [union_of_spheres(ellipse) for ellipse in ellipses]
mVol = np.array(
......@@ -59,12 +67,16 @@ class VolumeEstimation(unittest.TestCase):
4
/ 3
* np.pi
* np.sqrt(ndimage.binary_fill_holes(ellipse).sum() / np.pi) ** 3
* np.sqrt(ndimage.binary_fill_holes(ellipse).sum() / np.pi)
** 3
for ellipse in ellipses
]
)
true = np.array(
[4 * np.pi * x_r * y_r * x_r / 3 for x_r, y_r in zip(x_radii, y_radii)]
[
4 * np.pi * x_r * y_r * x_r / 3
for x_r, y_r in zip(x_radii, y_radii)
]
)
plt.scatter(true, con, label="Conical")
plt.scatter(true, spheres, label="Spheres")
......@@ -76,12 +88,17 @@ class VolumeEstimation(unittest.TestCase):
plt.legend()
# plt.show()
@pytest.mark.skip(
reason="No longer usable, post_processing unused inside aliby. Kept temporarily"
)
def test_minor_major_error(self):
r = np.random.choice(list(range(3, 30)))
x_radii = np.linspace(r / 3, r, 20)
y_radii = r ** 2 / x_radii
y_radii = r**2 / x_radii
ellipses = [ellipse_perimeter(x_r, y_r) for x_r, y_r in zip(x_radii, y_radii)]
ellipses = [
ellipse_perimeter(x_r, y_r) for x_r, y_r in zip(x_radii, y_radii)
]
con = np.array([conical(ellipse) for ellipse in ellipses])
spheres = np.array([union_of_spheres(ellipse) for ellipse in ellipses])
mVol = np.array(
......@@ -89,13 +106,17 @@ class VolumeEstimation(unittest.TestCase):
4
/ 3
* np.pi
* np.sqrt(ndimage.binary_fill_holes(ellipse).sum() / np.pi) ** 3
* np.sqrt(ndimage.binary_fill_holes(ellipse).sum() / np.pi)
** 3
for ellipse in ellipses
]
)
true = np.array(
[4 * np.pi * x_r * y_r * x_r / 3 for x_r, y_r in zip(x_radii, y_radii)]
[
4 * np.pi * x_r * y_r * x_r / 3
for x_r, y_r in zip(x_radii, y_radii)
]
)
ratio = y_radii / x_radii
......
import unittest
import pytest
import numpy as np
import pytest
from aliby.tile.traps import align_timelapse_images
from aliby.tile.tiler import Tiler, TilerParameters
@pytest.mark.skip(
......@@ -13,9 +14,10 @@ class TestCase(unittest.TestCase):
self.data = np.ones((1, 3, 5, 5, 5))
def test_align_timelapse_images(self):
drift, references = align_timelapse_images(self.data)
self.assertEqual(references, [0])
self.assertItemsEqual(drift.flatten(), np.zeros_like(drift.flatten()))
pass
# drift, references = align_timelapse_images(self.data)
# self.assertEqual(references, [0])
# self.assertItemsEqual(drift.flatten(), np.zeros_like(drift.flatten()))
if __name__ == "__main__":
......
import argparse
from aliby.experiment import ExperimentLocal
from aliby.tile.tiler import Tiler
from aliby.io.image import ImageLocalOME
# from aliby.experiment import ExperimentLocal
from aliby.tile.tiler import Tiler, TilerParameters
def define_parser():
......@@ -17,9 +19,15 @@ def define_parser():
return parser
def initialise_dummy():
tiler_parameters = TilerParameters.default().to_dict()
dummy_tiler = Tiler.dummy(tiler_parameters)
return dummy_tiler
def initialise_objects(data_path, template=None):
expt = ExperimentLocal(data_path, finished=True)
tiler = Tiler(expt, finished=True, template=template)
image = ImageLocalOME(data_path)
tiler = Tiler.from_image(image, TilerParameters.default())
return tiler
......@@ -28,7 +36,7 @@ def change_position(position, tiler):
def get_n_traps_timepoints(tiler):
return tiler.n_traps, tiler.n_timepoints
return tiler.n_traps, tiler.n
def trap_timelapse(tiler, trap_idx, channel, z):
......@@ -41,7 +49,7 @@ def trap_timelapse(tiler, trap_idx, channel, z):
def timepoint_traps(tiler, tp_idx, channel, z, tile_size):
channel_id = tiler.get_channel_index(channel)
traps = tiler.get_traps_timepoint(
traps = tiler.get_tiles_timepoint(
tp_idx, tile_size=tile_size, channels=[channel_id], z=list(range(z))
)
return traps
......@@ -51,6 +59,8 @@ if __name__ == "__main__":
parser = define_parser()
args = parser.parse_args()
dummy_tiler = initialise_dummy()
tiler = initialise_objects(args.root_dir, template=args.template)
if args.position is not None:
......@@ -58,7 +68,9 @@ if __name__ == "__main__":
n_traps, n_tps = get_n_traps_timepoints(tiler)
timelapse = trap_timelapse(tiler, args.trap, args.channel, args.z_positions)
timelapse = trap_timelapse(
tiler, args.trap, args.channel, args.z_positions
)
traps = timepoint_traps(
tiler, args.time, args.channel, args.z_positions, args.tile_size
)
import unittest
import numpy as np
from aliby.tile.traps import identify_trap_locations
......@@ -22,7 +23,10 @@ class TestCase(unittest.TestCase):
mode="constant",
)
self.expected_location = int(
(np.ceil((self.img_size - self.tile_size + self.trap_size) / 2) - 1)
(
np.ceil((self.img_size - self.tile_size + self.trap_size) / 2)
- 1
)
)
def test_identify_trap_locations(self):
......@@ -57,8 +61,10 @@ class TestMultipleCase(TestCase):
self.expected_locations = set(
[
(
self.expected_location + i * (self.img_size - self.trap_size),
self.expected_location + j * (self.img_size - self.trap_size),
self.expected_location
+ i * (self.img_size - self.trap_size),
self.expected_location
+ j * (self.img_size - self.trap_size),
)
for i in range(self.nrows)
for j in range(self.ncols)
......@@ -68,7 +74,11 @@ class TestMultipleCase(TestCase):
self.assertEqual(len(coords), ntraps)
self.assertEqual(
ntraps,
len(self.expected_locations.intersection([tuple(x) for x in coords])),
len(
self.expected_locations.intersection(
[tuple(x) for x in coords]
)
),
)
......
#!/usr/bin/env jupyter
def pytest_addoption(parser):
parser.addoption("--file", action="store", default="test_datasets")
def pytest_generate_tests(metafunc):
# This is called for every test. Only get/set command line arguments
# if the argument is specified in the list of test "fixturenames".
option_value = metafunc.config.option.file
if "file" in metafunc.fixturenames and option_value is not None:
metafunc.parametrize("file", [option_value])
#
#!/usr/bin/env python3
import numpy as np
import dask.array as da
import pytest
from aliby.io.image import ImageDummy
tiler_parameters = {"tile_size": 117, "ref_channel": "Brightfield", "ref_z": 0}
sample_da = da.from_array(np.array([[1, 2], [3, 4]]))
# Make it 5-dimensional
sample_da = da.reshape(
sample_da, (1, 1, 1, sample_da.shape[-2], sample_da.shape[-1])
)
@pytest.mark.parametrize("sample_da", [sample_da])
@pytest.mark.parametrize("dim", [2])
@pytest.mark.parametrize("n_empty_slices", [4])
@pytest.mark.parametrize("image_position", [1])
def test_pad_array(sample_da, dim, n_empty_slices, image_position):
"""Test ImageDummy.pad_array() method"""
# create object
imgdmy = ImageDummy(tiler_parameters)
# pads array
padded_da = imgdmy.pad_array(
sample_da,
dim=dim,
n_empty_slices=n_empty_slices,
image_position=image_position,
)
# select which dimension to index the multidimensional array
indices = {dim: image_position}
ix = [
indices.get(dim, slice(None))
for dim in range(padded_da.compute().ndim)
]
# Checks that original image array is there and is at the correct index
assert np.array_equal(padded_da.compute()[ix], sample_da.compute()[0])
# Checks that the additional axis is extended correctly
assert padded_da.compute().shape[dim] == n_empty_slices + 1
#!/usr/bin/env jupyter
from pathlib import Path
import pytest
from aliby.pipeline import Pipeline, PipelineParameters
def test_local_pipeline(file: str):
if Path(file).exists():
params = PipelineParameters.default(
general={
"expt_id": file,
"distributed": 0,
"directory": "test_output/",
"overwrite": True,
},
tiler={"ref_channel": 0},
)
p = Pipeline(params)
p.run()
else:
print("Test dataset not downloaded")