From d236804bde5d6fa602397a20ab3f41de90228381 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Al=C3=A1n=20Mu=C3=B1oz?= <amuoz@ed.ac.uk> Date: Sat, 22 Jan 2022 16:32:35 +0000 Subject: [PATCH] fix tests after threshold modification --- tests/aliby/test_traps.py | 63 ++++++++++++++++++++++++++++++++++++--- 1 file changed, 59 insertions(+), 4 deletions(-) diff --git a/tests/aliby/test_traps.py b/tests/aliby/test_traps.py index c74d26fa..4daa2422 100644 --- a/tests/aliby/test_traps.py +++ b/tests/aliby/test_traps.py @@ -6,15 +6,70 @@ from aliby.tile.traps import identify_trap_locations class TestCase(unittest.TestCase): def setUp(self): - self.data = np.pad(np.ones((5, 5)), 10, mode="constant") - self.template = np.pad(np.ones((5, 5)), 2, mode="constant") + self.trap_size = 5 + self.tile_size = 9 + assert self.trap_size % 2 + assert self.tile_size % 2 + self.img_size = 16 + self.data = np.pad( + np.ones((self.trap_size, self.trap_size)), + (self.img_size - self.tile_size) // 2, + mode="constant", + ) + self.template = np.pad( + np.ones((self.trap_size, self.trap_size)), + (self.tile_size - self.trap_size) // 2, + mode="constant", + ) + self.expected_location = int( + (np.ceil((self.img_size - self.tile_size + self.trap_size) / 2) - 1) + ) def test_identify_trap_locations(self): coords = identify_trap_locations( - self.data, self.template, optimize_scale=False, downscale=1 + self.data, + self.template, + optimize_scale=False, + downscale=1, ) self.assertEqual(len(coords), 1) - self.assertEqual(coords[0].tolist(), [12, 12]) + self.assertEqual( + coords[0].tolist(), + [self.expected_location, self.expected_location], + ) + + +class TestMultipleCase(TestCase): + def setUp(self): + self.nrows = 4 + self.ncols = 4 + super().setUp() + row = np.concatenate([self.data for i in range(self.ncols)]) + self.data = np.concatenate([row for i in range(self.nrows)], axis=1) + + def test_identify_trap_locations(self): + coords = identify_trap_locations( + self.data, + self.template, + optimize_scale=False, + downscale=1, + ) + self.expected_locations = set( + [ + ( + self.expected_location + i * (self.img_size - self.trap_size), + self.expected_location + j * (self.img_size - self.trap_size), + ) + for i in range(self.nrows) + for j in range(self.ncols) + ] + ) + ntraps = self.nrows * self.ncols + self.assertEqual(len(coords), ntraps) + self.assertEqual( + ntraps, + len(self.expected_locations.intersection([tuple(x) for x in coords])), + ) if __name__ == "__main__": -- GitLab