From 05ea80205bf321ffd569b5990c829a0ac37efdc5 Mon Sep 17 00:00:00 2001
From: Marcin Kirsz <mkirsz@ed.ac.uk>
Date: Thu, 31 Oct 2024 16:41:18 +0000
Subject: [PATCH] find_unique... to accomodate for additional line of structure
 weights

---
 include/tadah/mlip/structure.h |   1 +
 src/structure_db.cpp           |  45 +++++-----
 tests/test_structure.cpp       | 160 +++++++++++++++++----------------
 3 files changed, 104 insertions(+), 102 deletions(-)

diff --git a/include/tadah/mlip/structure.h b/include/tadah/mlip/structure.h
index 47ef531..439bb35 100644
--- a/include/tadah/mlip/structure.h
+++ b/include/tadah/mlip/structure.h
@@ -3,6 +3,7 @@
 
 #include <tadah/mlip/atom.h>
 #include <tadah/core/core_types.h>
+#include <tadah/core/periodic_table.h>
 
 #include <string>
 #include <iostream>
diff --git a/src/structure_db.cpp b/src/structure_db.cpp
index 9f31fb6..f47b5e5 100644
--- a/src/structure_db.cpp
+++ b/src/structure_db.cpp
@@ -108,37 +108,34 @@ std::set<Element> StructureDB::get_unique_elements() const {
 }
 std::set<Element> StructureDB::find_unique_elements(const std::string &fn) {
   std::set<Element> s;
-  std::ifstream datafile(fn);
+  std::ifstream ifs(fn);
   std::string line;
 
-  if (!datafile.is_open()) {
+  if (!ifs.is_open()) {
     std::cerr << "Could not open the file." << std::endl;
   }
 
   char symbol[3];
-
-  // move to first non empty line
-  while (std::getline(datafile, line)) {
-    if (!line.empty()) {
-      break;
-    }
-  }
-  // skip 
-  for (int i = 0; i < 7; ++i) {
-    std::getline(datafile, line);
-  }
-  while (std::getline(datafile, line)) {
-    if(line.empty()) {
-      for (int i = 0; i < 8; ++i) {
-        if (!std::getline(datafile, line)) {
-          // Handle the case where there are not enough lines
-          break; // or handle the error
-        }
-      }
-      continue;
+  while (std::getline(ifs, line)) {
+    // the second line could be energy or
+    // a scalling factors eweight fweight sweight
+    std::getline(ifs,line);  
+    std::stringstream stream(line);
+    size_t count = std::distance(std::istream_iterator<std::string>(stream),
+        std::istream_iterator<std::string>());
+
+    if (count == 3)
+      std::getline(ifs,line);
+
+    for (size_t i=0; i<6; ++i)
+      std::getline(ifs,line);
+
+    while (std::getline(ifs, line)) {
+      if(line.empty()) break;
+      sscanf(line.c_str(), "%2s", symbol);
+      s.insert(PeriodicTable::find_by_symbol(symbol));
     }
-    sscanf(line.c_str(), "%2s", symbol);
-    s.insert(PeriodicTable::find_by_symbol(symbol));
+
   }
   return s;
 }
diff --git a/tests/test_structure.cpp b/tests/test_structure.cpp
index 39c3489..1a0292c 100644
--- a/tests/test_structure.cpp
+++ b/tests/test_structure.cpp
@@ -7,112 +7,116 @@
 // Conversion factor from eV/A^3 to kbar;
 double fac = 1602.1766208;
 
+
 TEST_CASE( "Testing Structure class volume", "[structure_volume]" ) {
-    //using vec=Eigen::Vector3d;
+  //using vec=Eigen::Vector3d;
 
-    //std::string symbol="Ti";
-    //std::string name="Titanium";
-    //int Z = 22;
-    //vec pos(1.0,2.0,3.0);
-    //vec force(4.0,5.0,6.0);
-    //Element element(symbol,name,Z);
+  //std::string symbol="Ti";
+  //std::string name="Titanium";
+  //int Z = 22;
+  //vec pos(1.0,2.0,3.0);
+  //vec force(4.0,5.0,6.0);
+  //Element element(symbol,name,Z);
 
-    //Atom a(element, 1.0, 2.0, 3.0, 
-    //        4.0, 5.0, 6.0);
+  //Atom a(element, 1.0, 2.0, 3.0, 
+  //        4.0, 5.0, 6.0);
 
-    Structure st;
-    // Integer volume
-    st.cell.load(2,-1,3,
-               3,2,-4,
-              -2,0,1);
+  Structure st;
+  // Integer volume
+  st.cell.load(2,-1,3,
+      3,2,-4,
+      -2,0,1);
 
-    REQUIRE(st.get_volume() == 11 );
+  REQUIRE(st.get_volume() == 11 );
 
-    // double volume
-    st.cell.load(5.0147,5.0104,-5.0018,
-              -9.9924,-3.3395,-9.9662,
-              -25.6396,38.4594,12.8198);
+  // double volume
+  st.cell.load(5.0147,5.0104,-5.0018,
+      -9.9924,-3.3395,-9.9662,
+      -25.6396,38.4594,12.8198);
 
-    REQUIRE(st.get_volume() == Approx(5980.0279772134).epsilon(1e-10));
+  REQUIRE(st.get_volume() == Approx(5980.0279772134).epsilon(1e-10));
 
 
 }
 TEST_CASE( "Testing Structure virial pressure calculations", "[structure_virial_pressures]" ) {
 
-    Structure st;
-    st.cell.load(9.374769,0.0,0.0,
-               0.0,9.374769,0.0,
-               0.0,0.0,9.374769);
+  Structure st;
+  st.cell.load(9.374769,0.0,0.0,
+      0.0,9.374769,0.0,
+      0.0,0.0,9.374769);
 
-    REQUIRE(st.get_volume() == Approx(823.91));
+  REQUIRE(st.get_volume() == Approx(823.91));
 
-    st.stress.load(257.0893807653,0.0,0.0,
-                 0.0,257.0893807653,0.0,
-                 0.0,0.0,257.0893807653);
-    REQUIRE(fac*st.get_virial_pressure() == Approx(499.93));
+  st.stress.load(257.0893807653,0.0,0.0,
+      0.0,257.0893807653,0.0,
+      0.0,0.0,257.0893807653);
+  REQUIRE(fac*st.get_virial_pressure() == Approx(499.93));
 
 }
 TEST_CASE( "Testing Structure read and write", "[structure_read_write]" ) {
-    Structure st;
-    st.read("tests_data/structure_1.dat");
+  PeriodicTable::initialize();
+  Structure st;
+  st.read("tests_data/structure_1.dat");
 
-    REQUIRE(st.get_volume() == Approx(989.521812));
+  REQUIRE(st.get_volume() == Approx(989.521812));
 
-    REQUIRE(fac*st.get_virial_pressure() == Approx(26.705578));
-    REQUIRE(fac*st.get_pressure(300) == Approx(28.965914));
-    REQUIRE(fac*st.get_pressure(0) == Approx(26.705578));
+  REQUIRE(fac*st.get_virial_pressure() == Approx(26.705578));
+  REQUIRE(fac*st.get_pressure(300) == Approx(28.965914));
+  REQUIRE(fac*st.get_pressure(0) == Approx(26.705578));
 
-    std::string tempfile = std::tmpnam(nullptr);
-    st.save(tempfile);
+  std::string tempfile = std::tmpnam(nullptr);
+  st.save(tempfile);
 
-    Structure st_temp;
-    st_temp.read(tempfile);
+  Structure st_temp;
+  st_temp.read(tempfile);
 
-    REQUIRE(st_temp.get_volume() == Approx(989.521812));
+  REQUIRE(st_temp.get_volume() == Approx(989.521812));
 
-    REQUIRE(fac*st_temp.get_virial_pressure() == Approx(26.705578));
-    REQUIRE(fac*st_temp.get_pressure(300) == Approx(28.965914));
-    REQUIRE(fac*st_temp.get_pressure(0) == Approx(26.705578));
+  REQUIRE(fac*st_temp.get_virial_pressure() == Approx(26.705578));
+  REQUIRE(fac*st_temp.get_pressure(300) == Approx(28.965914));
+  REQUIRE(fac*st_temp.get_pressure(0) == Approx(26.705578));
 
-    REQUIRE(st==st_temp);
+  REQUIRE(st==st_temp);
 
-    std::remove(tempfile.c_str());
+  std::remove(tempfile.c_str());
 
 }
 TEST_CASE( "Testing Structure compare", "[structure_compare]" ) {
-    Structure st;
-    st.read("tests_data/structure_1.dat");
-    std::string tempfile = std::tmpnam(nullptr);
-    st.save(tempfile);
-
-    Structure st_temp;
-    st_temp.read(tempfile);
-
-    SECTION("Compare unchanged") {
-        REQUIRE(st==st_temp);
-    }
-    SECTION("Compare symbols") {
-        REQUIRE(st==st_temp);
-        st_temp.atoms[0].symbol[0]='X';
-        st_temp.atoms[0].symbol[1]='X';
-        REQUIRE(!(st==st_temp));
-    }
-    SECTION("Compare position") {
-        REQUIRE(st==st_temp);
-        st_temp.atoms[0].position(0.12,0.13,10.14);
-        REQUIRE(!(st==st_temp));
-    }
-    SECTION("Compare force") {
-        REQUIRE(st==st_temp);
-        st_temp.atoms[0].force(1.12, 0.13, 0.134);
-        REQUIRE(!(st==st_temp));
-    }
-
-    std::remove(tempfile.c_str());
+  PeriodicTable::initialize();
+  Structure st;
+  st.read("tests_data/structure_1.dat");
+  std::string tempfile = std::tmpnam(nullptr);
+  st.save(tempfile);
+
+  Structure st_temp;
+  st_temp.read(tempfile);
+
+  SECTION("Compare unchanged") {
+    REQUIRE(st==st_temp);
+  }
+  SECTION("Compare symbols") {
+    REQUIRE(st==st_temp);
+    st_temp.atoms[0].symbol[0]='X';
+    st_temp.atoms[0].symbol[1]='X';
+    REQUIRE(!(st==st_temp));
+  }
+  SECTION("Compare position") {
+    REQUIRE(st==st_temp);
+    st_temp.atoms[0].position(0.12,0.13,10.14);
+    REQUIRE(!(st==st_temp));
+  }
+  SECTION("Compare force") {
+    REQUIRE(st==st_temp);
+    st_temp.atoms[0].force(1.12, 0.13, 0.134);
+    REQUIRE(!(st==st_temp));
+  }
+
+  std::remove(tempfile.c_str());
 }
 TEST_CASE( "Testing Structure copy", "[structure_copy]" ) {
-    Structure st;
-    st.read("tests_data/structure_1.dat");
-    Structure st2=st;
-        REQUIRE(st==st2);
+  PeriodicTable::initialize();
+  Structure st;
+  st.read("tests_data/structure_1.dat");
+  Structure st2=st;
+  REQUIRE(st==st2);
 }
-- 
GitLab