From 2b2454d6c28da845788e797c32612e7ffa79a93f Mon Sep 17 00:00:00 2001
From: Marcin Kirsz <mkirsz@ed.ac.uk>
Date: Wed, 30 Oct 2024 18:33:37 +0000
Subject: [PATCH] StructureDB to check ATOMS and WATOMS consistency and set
 defaults if needed

---
 include/tadah/mlip/structure.h    |   4 +-
 include/tadah/mlip/structure_db.h |  15 +++-
 src/structure.cpp                 |   3 +-
 src/structure_db.cpp              | 136 +++++++++++++++++++++---------
 4 files changed, 108 insertions(+), 50 deletions(-)

diff --git a/include/tadah/mlip/structure.h b/include/tadah/mlip/structure.h
index f42580e..47ef531 100644
--- a/include/tadah/mlip/structure.h
+++ b/include/tadah/mlip/structure.h
@@ -9,7 +9,7 @@
 #include <fstream>
 #include <iomanip>
 #include <vector>
-#include <unordered_set>
+#include <set>
 
 /**
  * Container for a collection of Atom(s).
@@ -109,7 +109,7 @@ struct Structure {
   double sweight=1.0;
 
   /** List of chemical elements in this structure. */
-  std::unordered_set<Element> unique_elements;
+  std::set<Element> unique_elements;
 
   /** @return a reference to the i-th Atom.
    *
diff --git a/include/tadah/mlip/structure_db.h b/include/tadah/mlip/structure_db.h
index 678a5e3..833c3cb 100644
--- a/include/tadah/mlip/structure_db.h
+++ b/include/tadah/mlip/structure_db.h
@@ -1,14 +1,15 @@
 #ifndef STRUCTURE_DB_h
 #define STRUCTURE_DB_h
 
-#include <tadah/mlip/structure.h>
+#include <tadah/core/element.h>
 #include <tadah/core/config.h>
+#include <tadah/mlip/structure.h>
 
 #include <string>
 #include <iostream>
 #include <fstream>
 #include <vector>
-#include <unordered_set>
+#include <set>
 
 /**
  * Container for a collection of Structure(s).
@@ -49,7 +50,7 @@ struct StructureDB {
      * \note
      *      Required Config key: \ref DBFILE
      */
-    StructureDB(const Config &config);
+    StructureDB(Config &config);
 
     /** Add structures listed in the config file
      *
@@ -134,7 +135,7 @@ struct StructureDB {
     size_t calc_natoms(size_t n) const;
 
     /** Return unique elements for all Structures. */
-    std::unordered_set<Element> get_unique_elements();
+    std::set<Element> get_unique_elements() const;
 
     /** Count number of structures and atoms in all datasets from the Config file. */
     static std::pair<int,int> count(const Config &c);
@@ -143,5 +144,11 @@ struct StructureDB {
     static std::pair<int,int> count(const std::string fn);
 
     void clear_nn();
+
+    /** Check consistency of the ATOMS key. */
+    void check_atoms_key(Config &config) const;
+
+    /** Check consistency of the WATOMS key. Add if missing*/
+    void check_watoms_key(Config &config) const;
 };
 #endif
diff --git a/src/structure.cpp b/src/structure.cpp
index d6e0b8d..acb4c1b 100644
--- a/src/structure.cpp
+++ b/src/structure.cpp
@@ -77,7 +77,6 @@ int Structure::read(std::ifstream &ifs) {
 
   char symbol[3];
   double px,py,pz,fx,fy,fz;
-  PeriodicTable pd;
   while(std::getline(ifs,line)) {
     if(line.empty()) break;
     //if(line == " ") break;
@@ -85,7 +84,7 @@ int Structure::read(std::ifstream &ifs) {
 
     std::istringstream tmp(line);
     tmp >> symbol >> px >> py >> pz >> fx >> fy >> fz;
-    Element e = pd.find_by_symbol(symbol);
+    Element e = PeriodicTable::find_by_symbol(symbol);
     unique_elements.insert(e);
     atoms.push_back(Atom(e,px,py,pz,fx,fy,fz));
     //Atom &a = atoms.back();
diff --git a/src/structure_db.cpp b/src/structure_db.cpp
index c6e9733..cdaf502 100644
--- a/src/structure_db.cpp
+++ b/src/structure_db.cpp
@@ -1,26 +1,28 @@
 #include <tadah/mlip/structure_db.h>
 
 StructureDB::StructureDB() {}
-StructureDB::StructureDB(const Config &config) {
-    add(config);
+StructureDB::StructureDB(Config &config) {
+  add(config);
+  check_atoms_key(config);
+  check_watoms_key(config);
 }
 
 void StructureDB::add(const std::string fn) {
-    std::ifstream ifs(fn);
-    if (!ifs.is_open()) {
-        throw std::runtime_error("DBFILE does not exist: "+fn);
-    }
-    while (true) {
-        structures.push_back(Structure());
-        int t = structures.back().read(ifs);
+  std::ifstream ifs(fn);
+  if (!ifs.is_open()) {
+    throw std::runtime_error("DBFILE does not exist: "+fn);
+  }
+  while (true) {
+    structures.push_back(Structure());
+    int t = structures.back().read(ifs);
 
-        // did we read structure succesfully?
-        // if not remove last object from the list
-        if (t==1) structures.pop_back();
+    // did we read structure succesfully?
+    // if not remove last object from the list
+    if (t==1) structures.pop_back();
 
-        if (ifs.eof()) break;
-    }
-    ifs.close();
+    if (ifs.eof()) break;
+  }
+  ifs.close();
 }
 int StructureDB::add(const std::string fn, int first, int N) {
   std::ifstream ifs(fn);
@@ -50,59 +52,59 @@ int StructureDB::add(const std::string fn, int first, int N) {
 }
 
 void StructureDB::add(const Structure &s) {
-    structures.push_back(s);
+  structures.push_back(s);
 }
 
 void StructureDB::remove(size_t i) {
-    structures.erase(structures.begin()+i);
+  structures.erase(structures.begin()+i);
 }
 
 void StructureDB::add(const Config &config) {
-    for (const std::string &s : config("DBFILE")) {
-        dbidx.push_back(size());
-        add(s);
-    }
+  for (const std::string &s : config("DBFILE")) {
     dbidx.push_back(size());
+    add(s);
+  }
+  dbidx.push_back(size());
 }
 
 size_t StructureDB::size() const {
-    return structures.size();
+  return structures.size();
 }
 
 size_t StructureDB::size(size_t n) const {
-    return dbidx[n+1]-dbidx[n];
+  return dbidx[n+1]-dbidx[n];
 }
 
 Structure &StructureDB::operator()(size_t s) {
-    return structures[s];
+  return structures[s];
 }
 
 const Structure &StructureDB::operator()(size_t s) const {
-    return structures[s];
+  return structures[s];
 }
 
 Atom &StructureDB::operator()(size_t s, size_t a) {
-    return structures[s].atoms[a];
+  return structures[s].atoms[a];
 }
 size_t StructureDB::calc_natoms() const {
-    size_t natoms=0;
-    for (auto struc: structures) natoms += struc.natoms();
-    return natoms;
+  size_t natoms=0;
+  for (auto struc: structures) natoms += struc.natoms();
+  return natoms;
 }
 size_t StructureDB::calc_natoms(size_t n) const {
-    size_t start = dbidx[n];
-    size_t stop = dbidx[n+1];
-    size_t natoms=0;
-    for (size_t i=start; i<stop; ++i) {
-        natoms += (*this)(i).natoms();
-    }
-    return natoms;
+  size_t start = dbidx[n];
+  size_t stop = dbidx[n+1];
+  size_t natoms=0;
+  for (size_t i=start; i<stop; ++i) {
+    natoms += (*this)(i).natoms();
+  }
+  return natoms;
 }
-std::unordered_set<Element> StructureDB::get_unique_elements() {
-    std::unordered_set<Element> s;
-    for (const auto & st: structures) s.insert(
-            st.unique_elements.begin(),st.unique_elements.end());
-    return s;
+std::set<Element> StructureDB::get_unique_elements() const {
+  std::set<Element> s;
+  for (const auto & st: structures) s.insert(
+      st.unique_elements.begin(),st.unique_elements.end());
+  return s;
 }
 
 template <typename T,typename U>                                                   
@@ -136,5 +138,55 @@ std::pair<int,int> StructureDB::count(const std::string fn){
   return res;
 }
 void StructureDB::clear_nn() {
-    for (auto &struc: structures) struc.clear_nn();
+  for (auto &struc: structures) struc.clear_nn();
+}
+void StructureDB::check_atoms_key(Config &config) const {
+  std::set<Element> unique_elements = get_unique_elements();
+  bool error=false;
+
+  if (config.exist("ATOMS")) {
+    // user set this key so here we check does it correspond to unique_elements
+    if (unique_elements.size()!=config.size("ATOMS")) {
+      error=true;
+    }
+
+    auto set_it = unique_elements.begin();
+    auto atoms_it = config("ATOMS").begin();
+    while (set_it != unique_elements.end() && atoms_it != config("ATOMS").end()) {
+      if (set_it->symbol != *atoms_it)  {
+        error=true;
+      }
+      ++set_it;
+      ++atoms_it;
+    }
+    if (error) {
+      throw std::runtime_error("\n"
+          "Mismatch between elements in datasets and ATOMS in the config file.\n"
+          "Please either update the ATOMS in the config file or remove ATOMS\n"
+          "key completely. Tadah! will automatically configure this key.\n"
+          );
+    }
+  } else {
+    for (const auto &s : unique_elements) config.add("ATOMS", s.symbol);
+  }
+}
+void StructureDB::check_watoms_key(Config &config) const {
+  std::set<Element> unique_elements = get_unique_elements();
+  bool error=false;
+
+  if (config.exist("WATOMS")) {
+    // user set this key so here we check does it correspond to unique_elements
+    if (unique_elements.size()!=config.size("WATOMS")) {
+      error=true;
+    }
+    if (error) {
+      throw std::runtime_error("\n"
+          "Mismatch between elements in datasets and WATOMS in the config file.\n"
+          "Please either update the WATOMS in the config file or remove WATOMS\n"
+          "key completely. In the latter case Tadah! will use default values.\n"
+          );
+    }
+  } else {
+    for (const auto &s : unique_elements) config.add("WATOMS", s.Z);
+  }
 }
-- 
GitLab