From d5fede06ba8998b0f49cbe73a9a83bd9554dbcd5 Mon Sep 17 00:00:00 2001 From: mkirsz <s1351949@sms.ed.ac.uk> Date: Tue, 25 Feb 2025 15:26:07 +0000 Subject: [PATCH] Improved training weighting --- .../tadah/mlip/design_matrix/design_matrix.h | 39 +++------------- include/tadah/mlip/structure_db.h | 2 +- include/tadah/mlip/trainer.h | 1 - src/nn_finder.cpp | 4 +- src/structure_db.cpp | 44 ++++++++++++------- 5 files changed, 37 insertions(+), 53 deletions(-) diff --git a/include/tadah/mlip/design_matrix/design_matrix.h b/include/tadah/mlip/design_matrix/design_matrix.h index d78cc11..f4a2e71 100644 --- a/include/tadah/mlip/design_matrix/design_matrix.h +++ b/include/tadah/mlip/design_matrix/design_matrix.h @@ -1,6 +1,7 @@ #ifndef DESIGN_MATRIX_H #define DESIGN_MATRIX_H +#include <cmath> #include <tadah/mlip/st_descriptors_db.h> #include <tadah/mlip/structure_db.h> #include <tadah/mlip/descriptors_calc.h> @@ -178,9 +179,8 @@ public: compute_stdevs(stdb); fill_T(stdb); compute_wfactors(stdb); // call after fill_T - // for opm we need to find first rows for each structure - + // for opm we need to find first rows for each structure std::vector<size_t> rows(stdb.size()); size_t row=0; for (size_t s=0; s<stdb.size(); ++s) { @@ -216,19 +216,11 @@ public: } void build(size_t &row, const Structure &st, const StDescriptors &st_d) { - // double escale = 1; - // if (scale) escale = st.eweight*eweightglob/st.natoms(); // TODO f.calc_phi_energy_row(ws->Phi,row,st,st_d); if (force) { - // double fscale = 1; - // if (scale) fscale = st.fweight*fweightglob/st.natoms()/3.0; // TODO f.calc_phi_force_rows(ws->Phi,row,st,st_d); } if (stress) { - // double sscale_arr[6] {1,1,1,1,1,1}; - // if (scale) - // for(size_t xy=0;xy<6;++xy) - // sscale_arr[xy] = st.sweight*sweightglob/6.0; // TODO f.calc_phi_stress_rows(ws->Phi,row,st,st_d); } } @@ -236,13 +228,10 @@ public: size_t j=start; for (size_t s=0; s<stdb.size(); ++s) { - // double escale = 1; ws->Tlabels(j) = 0; ws->T(j++) = stdb(s).energy; if (force) { - // double fscale = 1; - // if (scale) fscale = stdb(s).fweight*fweightglob/stdb(s).natoms()/3.0; for (const Atom &a : stdb(s).atoms) { ws->Tlabels(j) = 1; ws->T(j++) = a.force(0); @@ -253,11 +242,6 @@ public: } } if (stress) { - // double sscale_arr[6] {1,1,1,1,1,1}; - // if (scale) - // for(size_t xy=0;xy<6;++xy) - // sscale_arr[xy] = stdb(s).sweight*sweightglob/6.0; - // size_t xy=0; for (size_t x=0; x<3; ++x) { for (size_t y=x; y<3; ++y) { ws->Tlabels(j) = 2; @@ -323,7 +307,6 @@ private: num_forces +=3*stdb(s).natoms(); } - //e_std_dev = std::sqrt((evec - evec.mean()).square().sum()/(evec.size()-1)); e_std_dev = evec.std_dev(evec.mean(), evec.size()-1); if (verbose) std::cout << "Energy standard deviation (eV/atom): " << e_std_dev << std::endl; @@ -334,7 +317,6 @@ private: } } - if (force) { size_t j=0; t_type fvec(num_forces); @@ -345,12 +327,8 @@ private: fvec(j++) = a.force(2); } } - //fvec /= e_std_dev; - //svec /= e_std_dev; - // e_std_dev has units of energy - // f_std_dev has units of inverse distance f_std_dev = fvec.std_dev(fvec.mean(),fvec.size()-1); - if (verbose) std::cout << "Force standard deviation (A^-1): " << f_std_dev << std::endl; + if (verbose) std::cout << "Force standard deviation (ev/A): " << f_std_dev << std::endl; } config.add("ESTDEV",e_std_dev); @@ -362,10 +340,9 @@ private: void compute_wfactors(const StructureDB &stdb) { size_t j=0; for (size_t s=0; s<stdb.size(); ++s) { - double escale = stdb(s).eweight*eweightglob; - ws->wfactors(j++) = escale; + ws->wfactors(j++) = std::sqrt(stdb(s).eweight*eweightglob); if (force) { - double fscale = stdb(s).fweight*fweightglob; + double fscale = std::sqrt(stdb(s).fweight*fweightglob); for (size_t a=0; a<stdb(s).natoms(); ++a) { ws->wfactors(j++) = fscale; ws->wfactors(j++) = fscale; @@ -373,13 +350,9 @@ private: } } if (stress) { - double sscale_arr[6] {1,1,1,1,1,1}; - for(size_t xy=0;xy<6;++xy) - sscale_arr[xy] = stdb(s).sweight*sweightglob/6.0; - int xy=0; for (size_t x=0; x<3; ++x) { for (size_t y=x; y<3; ++y) { - ws->wfactors(j++)=sscale_arr[xy++]; + ws->wfactors(j++) = std::sqrt(stdb(s).sweight*sweightglob); } } } diff --git a/include/tadah/mlip/structure_db.h b/include/tadah/mlip/structure_db.h index 0cab4a6..b96eac0 100644 --- a/include/tadah/mlip/structure_db.h +++ b/include/tadah/mlip/structure_db.h @@ -169,7 +169,7 @@ struct StructureDB { std::vector<Structure>::const_iterator end() const; /** Method to dump class content to a file */ - void dump_to_file(const std::string& filepath, size_t prec=12) const; + void dump_to_file(const std::string& filepath, bool append, size_t prec=12) const; // Public method that reads the file, counts blocks and line counts per block, // then prints the results to std::cout. diff --git a/include/tadah/mlip/trainer.h b/include/tadah/mlip/trainer.h index 9bbf598..e7a7776 100644 --- a/include/tadah/mlip/trainer.h +++ b/include/tadah/mlip/trainer.h @@ -2,7 +2,6 @@ #define MPI_TRAINER_H #include <tadah/mlip/descriptors_calc.h> #include <tadah/mlip/design_matrix/design_matrix.h> -#include <tadah/mlip/trainer.h> #include <tadah/mlip/design_matrix/functions/dm_function_base.h> #include <tadah/mlip/models/m_tadah_base.h> #include <tadah/mlip/nn_finder.h> diff --git a/src/nn_finder.cpp b/src/nn_finder.cpp index 81569cb..651077f 100644 --- a/src/nn_finder.cpp +++ b/src/nn_finder.cpp @@ -280,7 +280,7 @@ void NNFinder::calc(StructureDB &stdb) } auto t1 = std::chrono::steady_clock::now(); double seconds = std::chrono::duration<double>(t1 - t0).count(); - std::cout << "calc(StructureDB &stdb) for-loop took " - << seconds << " seconds\n"; + // std::cout << "calc(StructureDB &stdb) for-loop took " + // << seconds << " seconds\n"; } diff --git a/src/structure_db.cpp b/src/structure_db.cpp index f1aaa36..698ad42 100644 --- a/src/structure_db.cpp +++ b/src/structure_db.cpp @@ -256,16 +256,28 @@ std::vector<Structure>::const_iterator StructureDB::begin() const { std::vector<Structure>::const_iterator StructureDB::end() const { return structures.cend(); } -void StructureDB::dump_to_file(const std::string& filepath, size_t prec) const { - std::ofstream file(filepath, std::ios::app); // Open in append mode - if (!file.is_open()) { - std::cerr << "Error: Could not open file for writing: " << filepath << std::endl; - return; - } - for (const auto &s : structures) { - s.dump_to_file(file,prec); - } - file.close(); +void StructureDB::dump_to_file(const std::string& filepath, bool append, size_t prec) const +{ + // Determines the file open mode based on append flag + // - ios::out means file is opened in output mode + // - ios::app means all write operations append at the end + // - ios::trunc (the default if append is false) removes existing contents + std::ios_base::openmode mode = std::ios::out; + if (append) { + mode |= std::ios::app; + } else { + mode |= std::ios::trunc; + } + + std::ofstream file(filepath, mode); + if (!file.is_open()) { + std::cerr << "Error: Could not open file for writing: " << filepath << std::endl; + return; + } + for (const auto &s : structures) { + s.dump_to_file(file, prec); + } + file.close(); } std::string StructureDB::summary() const { @@ -335,12 +347,12 @@ void StructureDB::parseFile(const std::string& filename) delete[] buffer; // Print the results - std::cout << "Found " << blockLineCounts.size() << " blocks.\n"; - for (size_t i = 0; i < blockLineCounts.size(); i+=1000) - { - std::cout << "Block " << i << " has " - << blockLineCounts[i] << " atoms\n"; - } + // std::cout << "Found " << blockLineCounts.size() << " blocks.\n"; + // for (size_t i = 0; i < blockLineCounts.size(); i+=1000) + // { + // std::cout << "Block " << i << " has " + // << blockLineCounts[i] << " atoms\n"; + // } } bool StructureDB::isBlankLine(const std::string& line) const -- GitLab