From d5fede06ba8998b0f49cbe73a9a83bd9554dbcd5 Mon Sep 17 00:00:00 2001
From: mkirsz <s1351949@sms.ed.ac.uk>
Date: Tue, 25 Feb 2025 15:26:07 +0000
Subject: [PATCH] Improved training weighting

---
 .../tadah/mlip/design_matrix/design_matrix.h  | 39 +++-------------
 include/tadah/mlip/structure_db.h             |  2 +-
 include/tadah/mlip/trainer.h                  |  1 -
 src/nn_finder.cpp                             |  4 +-
 src/structure_db.cpp                          | 44 ++++++++++++-------
 5 files changed, 37 insertions(+), 53 deletions(-)

diff --git a/include/tadah/mlip/design_matrix/design_matrix.h b/include/tadah/mlip/design_matrix/design_matrix.h
index d78cc11..f4a2e71 100644
--- a/include/tadah/mlip/design_matrix/design_matrix.h
+++ b/include/tadah/mlip/design_matrix/design_matrix.h
@@ -1,6 +1,7 @@
 #ifndef DESIGN_MATRIX_H
 #define DESIGN_MATRIX_H
 
+#include <cmath>
 #include <tadah/mlip/st_descriptors_db.h>
 #include <tadah/mlip/structure_db.h>
 #include <tadah/mlip/descriptors_calc.h>
@@ -178,9 +179,8 @@ public:
     compute_stdevs(stdb);
     fill_T(stdb);
     compute_wfactors(stdb); // call after fill_T
-    // for opm we need to find first rows for each structure
-
 
+    // for opm we need to find first rows for each structure
     std::vector<size_t> rows(stdb.size());
     size_t row=0;
     for (size_t s=0; s<stdb.size(); ++s) {
@@ -216,19 +216,11 @@ public:
   }
   void build(size_t &row, const Structure &st, const StDescriptors &st_d) {
 
-    // double escale = 1;
-    // if (scale) escale = st.eweight*eweightglob/st.natoms(); // TODO
     f.calc_phi_energy_row(ws->Phi,row,st,st_d);
     if (force) {
-      // double fscale = 1;
-      // if (scale) fscale = st.fweight*fweightglob/st.natoms()/3.0; // TODO
       f.calc_phi_force_rows(ws->Phi,row,st,st_d);
     }
     if (stress) {
-      // double sscale_arr[6] {1,1,1,1,1,1};
-      // if (scale)
-      // for(size_t xy=0;xy<6;++xy)
-      // sscale_arr[xy] = st.sweight*sweightglob/6.0; // TODO
       f.calc_phi_stress_rows(ws->Phi,row,st,st_d);
     }
   }
@@ -236,13 +228,10 @@ public:
     size_t j=start;
     for (size_t s=0; s<stdb.size(); ++s) {
 
-      // double escale = 1;
       ws->Tlabels(j) = 0;
       ws->T(j++) = stdb(s).energy;
 
       if (force) {
-        // double fscale = 1;
-        // if (scale) fscale = stdb(s).fweight*fweightglob/stdb(s).natoms()/3.0;
         for (const Atom &a : stdb(s).atoms) {
           ws->Tlabels(j) = 1;
           ws->T(j++) = a.force(0);
@@ -253,11 +242,6 @@ public:
         }
       }
       if (stress) {
-        // double sscale_arr[6] {1,1,1,1,1,1};
-        // if (scale)
-        //   for(size_t xy=0;xy<6;++xy)
-        //     sscale_arr[xy] = stdb(s).sweight*sweightglob/6.0;
-        // size_t xy=0;
         for (size_t x=0; x<3; ++x) {
           for (size_t y=x; y<3; ++y) {
             ws->Tlabels(j) = 2;
@@ -323,7 +307,6 @@ private:
       num_forces +=3*stdb(s).natoms();
     }
 
-    //e_std_dev = std::sqrt((evec - evec.mean()).square().sum()/(evec.size()-1));
     e_std_dev = evec.std_dev(evec.mean(), evec.size()-1);
     if (verbose) std::cout << "Energy standard deviation (eV/atom): " << e_std_dev << std::endl;
 
@@ -334,7 +317,6 @@ private:
       }
     }
 
-
     if (force) {
       size_t j=0;
       t_type fvec(num_forces);
@@ -345,12 +327,8 @@ private:
           fvec(j++) = a.force(2);
         }
       }
-      //fvec /= e_std_dev;
-      //svec /= e_std_dev;
-      // e_std_dev has units of energy
-      // f_std_dev has units of inverse distance
       f_std_dev = fvec.std_dev(fvec.mean(),fvec.size()-1);
-      if (verbose) std::cout << "Force standard deviation (A^-1): " << f_std_dev << std::endl;
+      if (verbose) std::cout << "Force standard deviation (ev/A): " << f_std_dev << std::endl;
     }
 
     config.add("ESTDEV",e_std_dev);
@@ -362,10 +340,9 @@ private:
   void compute_wfactors(const StructureDB &stdb) {
     size_t j=0;
     for (size_t s=0; s<stdb.size(); ++s) {
-      double escale = stdb(s).eweight*eweightglob;
-      ws->wfactors(j++) = escale;
+      ws->wfactors(j++) = std::sqrt(stdb(s).eweight*eweightglob);
       if (force) {
-        double fscale = stdb(s).fweight*fweightglob;
+        double fscale = std::sqrt(stdb(s).fweight*fweightglob);
         for (size_t a=0; a<stdb(s).natoms(); ++a) {
           ws->wfactors(j++) = fscale;
           ws->wfactors(j++) = fscale;
@@ -373,13 +350,9 @@ private:
         }
       }
       if (stress) {
-        double sscale_arr[6] {1,1,1,1,1,1};
-        for(size_t xy=0;xy<6;++xy)
-          sscale_arr[xy] = stdb(s).sweight*sweightglob/6.0;
-        int xy=0;
         for (size_t x=0; x<3; ++x) {
           for (size_t y=x; y<3; ++y) {
-            ws->wfactors(j++)=sscale_arr[xy++];
+            ws->wfactors(j++) = std::sqrt(stdb(s).sweight*sweightglob);
           }
         }
       }
diff --git a/include/tadah/mlip/structure_db.h b/include/tadah/mlip/structure_db.h
index 0cab4a6..b96eac0 100644
--- a/include/tadah/mlip/structure_db.h
+++ b/include/tadah/mlip/structure_db.h
@@ -169,7 +169,7 @@ struct StructureDB {
   std::vector<Structure>::const_iterator end() const;
 
   /** Method to dump class content to a file */
-  void dump_to_file(const std::string& filepath, size_t prec=12) const;
+  void dump_to_file(const std::string& filepath, bool append, size_t prec=12) const;
 
   // Public method that reads the file, counts blocks and line counts per block,
   // then prints the results to std::cout.
diff --git a/include/tadah/mlip/trainer.h b/include/tadah/mlip/trainer.h
index 9bbf598..e7a7776 100644
--- a/include/tadah/mlip/trainer.h
+++ b/include/tadah/mlip/trainer.h
@@ -2,7 +2,6 @@
 #define MPI_TRAINER_H
 #include <tadah/mlip/descriptors_calc.h>
 #include <tadah/mlip/design_matrix/design_matrix.h>
-#include <tadah/mlip/trainer.h>
 #include <tadah/mlip/design_matrix/functions/dm_function_base.h>
 #include <tadah/mlip/models/m_tadah_base.h>
 #include <tadah/mlip/nn_finder.h>
diff --git a/src/nn_finder.cpp b/src/nn_finder.cpp
index 81569cb..651077f 100644
--- a/src/nn_finder.cpp
+++ b/src/nn_finder.cpp
@@ -280,7 +280,7 @@ void NNFinder::calc(StructureDB &stdb)
   }
   auto t1 = std::chrono::steady_clock::now();
   double seconds = std::chrono::duration<double>(t1 - t0).count();
-  std::cout << "calc(StructureDB &stdb) for-loop took "
-            << seconds << " seconds\n";
+  // std::cout << "calc(StructureDB &stdb) for-loop took "
+  //           << seconds << " seconds\n";
 }
 
diff --git a/src/structure_db.cpp b/src/structure_db.cpp
index f1aaa36..698ad42 100644
--- a/src/structure_db.cpp
+++ b/src/structure_db.cpp
@@ -256,16 +256,28 @@ std::vector<Structure>::const_iterator StructureDB::begin() const {
 std::vector<Structure>::const_iterator StructureDB::end() const { 
     return structures.cend(); 
 }
-void StructureDB::dump_to_file(const std::string& filepath, size_t prec) const {
-  std::ofstream file(filepath, std::ios::app);  // Open in append mode
-  if (!file.is_open()) {
-    std::cerr << "Error: Could not open file for writing: " << filepath << std::endl;
-    return;
-  }
-  for (const auto &s : structures) {
-    s.dump_to_file(file,prec);
-  }
-  file.close();
+void StructureDB::dump_to_file(const std::string& filepath, bool append, size_t prec) const
+{
+    // Determines the file open mode based on append flag
+    //   - ios::out means file is opened in output mode
+    //   - ios::app means all write operations append at the end
+    //   - ios::trunc (the default if append is false) removes existing contents
+    std::ios_base::openmode mode = std::ios::out;
+    if (append) {
+        mode |= std::ios::app;
+    } else {
+        mode |= std::ios::trunc;
+    }
+
+    std::ofstream file(filepath, mode);
+    if (!file.is_open()) {
+        std::cerr << "Error: Could not open file for writing: " << filepath << std::endl;
+        return;
+    }
+    for (const auto &s : structures) {
+        s.dump_to_file(file, prec);  
+    }
+    file.close();
 }
 
 std::string StructureDB::summary() const {
@@ -335,12 +347,12 @@ void StructureDB::parseFile(const std::string& filename)
     delete[] buffer;
 
     // Print the results
-    std::cout << "Found " << blockLineCounts.size() << " blocks.\n";
-    for (size_t i = 0; i < blockLineCounts.size(); i+=1000)
-    {
-        std::cout << "Block " << i << " has "
-                  << blockLineCounts[i] << " atoms\n";
-    }
+    // std::cout << "Found " << blockLineCounts.size() << " blocks.\n";
+    // for (size_t i = 0; i < blockLineCounts.size(); i+=1000)
+    // {
+    //     std::cout << "Block " << i << " has "
+    //               << blockLineCounts[i] << " atoms\n";
+    // }
 }
 
 bool StructureDB::isBlankLine(const std::string& line) const
-- 
GitLab