From d236804bde5d6fa602397a20ab3f41de90228381 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Al=C3=A1n=20Mu=C3=B1oz?= <amuoz@ed.ac.uk>
Date: Sat, 22 Jan 2022 16:32:35 +0000
Subject: [PATCH] fix tests after threshold modification

---
 tests/aliby/test_traps.py | 63 ++++++++++++++++++++++++++++++++++++---
 1 file changed, 59 insertions(+), 4 deletions(-)

diff --git a/tests/aliby/test_traps.py b/tests/aliby/test_traps.py
index c74d26fa..4daa2422 100644
--- a/tests/aliby/test_traps.py
+++ b/tests/aliby/test_traps.py
@@ -6,15 +6,70 @@ from aliby.tile.traps import identify_trap_locations
 
 class TestCase(unittest.TestCase):
     def setUp(self):
-        self.data = np.pad(np.ones((5, 5)), 10, mode="constant")
-        self.template = np.pad(np.ones((5, 5)), 2, mode="constant")
+        self.trap_size = 5
+        self.tile_size = 9
+        assert self.trap_size % 2
+        assert self.tile_size % 2
+        self.img_size = 16
+        self.data = np.pad(
+            np.ones((self.trap_size, self.trap_size)),
+            (self.img_size - self.tile_size) // 2,
+            mode="constant",
+        )
+        self.template = np.pad(
+            np.ones((self.trap_size, self.trap_size)),
+            (self.tile_size - self.trap_size) // 2,
+            mode="constant",
+        )
+        self.expected_location = int(
+            (np.ceil((self.img_size - self.tile_size + self.trap_size) / 2) - 1)
+        )
 
     def test_identify_trap_locations(self):
         coords = identify_trap_locations(
-            self.data, self.template, optimize_scale=False, downscale=1
+            self.data,
+            self.template,
+            optimize_scale=False,
+            downscale=1,
         )
         self.assertEqual(len(coords), 1)
-        self.assertEqual(coords[0].tolist(), [12, 12])
+        self.assertEqual(
+            coords[0].tolist(),
+            [self.expected_location, self.expected_location],
+        )
+
+
+class TestMultipleCase(TestCase):
+    def setUp(self):
+        self.nrows = 4
+        self.ncols = 4
+        super().setUp()
+        row = np.concatenate([self.data for i in range(self.ncols)])
+        self.data = np.concatenate([row for i in range(self.nrows)], axis=1)
+
+    def test_identify_trap_locations(self):
+        coords = identify_trap_locations(
+            self.data,
+            self.template,
+            optimize_scale=False,
+            downscale=1,
+        )
+        self.expected_locations = set(
+            [
+                (
+                    self.expected_location + i * (self.img_size - self.trap_size),
+                    self.expected_location + j * (self.img_size - self.trap_size),
+                )
+                for i in range(self.nrows)
+                for j in range(self.ncols)
+            ]
+        )
+        ntraps = self.nrows * self.ncols
+        self.assertEqual(len(coords), ntraps)
+        self.assertEqual(
+            ntraps,
+            len(self.expected_locations.intersection([tuple(x) for x in coords])),
+        )
 
 
 if __name__ == "__main__":
-- 
GitLab