From dfa76e78c6a6289f0c6811c7c44e1745aa8e5f9a Mon Sep 17 00:00:00 2001 From: mkirsz <s1351949@sms.ed.ac.uk> Date: Fri, 22 Nov 2024 23:55:23 +0000 Subject: [PATCH] Update to random_shuffle --- include/tadah/mlip/models/basis.h | 153 +++++++++++++++--------------- 1 file changed, 79 insertions(+), 74 deletions(-) diff --git a/include/tadah/mlip/models/basis.h b/include/tadah/mlip/models/basis.h index 74a4071..80f2bbe 100644 --- a/include/tadah/mlip/models/basis.h +++ b/include/tadah/mlip/models/basis.h @@ -6,98 +6,103 @@ #include <tadah/core/core_types.h> #include <tadah/core/config.h> +#include <algorithm> #include <numeric> #include <stdexcept> #include <vector> template <typename K> class Basis { - private: - Config &config; - int verbose; - public: - Matrix b; - t_type T; // Vectors corresponding to basis vectors - Basis(Config &c): - config(c), - verbose(c.get<int>("VERBOSE")) +private: + Config &config; + int verbose; +public: + Matrix b; + t_type T; // Vectors corresponding to basis vectors + Basis(Config &c): + config(c), + verbose(c.get<int>("VERBOSE")) - {} + {} - void set_basis(Matrix &b_) { - b=b_; - } - void build_random_basis(size_t s, StDescriptorsDB &st_desc_db) { + void set_basis(Matrix &b_) { + b=b_; + } + void build_random_basis(size_t s, StDescriptorsDB &st_desc_db) { - // generate indices - std::vector<std::tuple<size_t,size_t>> indices; - size_t counter=0; - for( size_t st = 0; st < st_desc_db.size(); st++ ) { - for( size_t a = 0; a < st_desc_db(st).naed() ; a++ ) { - indices.push_back(std::tuple<size_t,size_t>(st,a)); - counter++; - } - } + // generate indices + std::vector<std::tuple<size_t,size_t>> indices; + size_t counter=0; + for( size_t st = 0; st < st_desc_db.size(); st++ ) { + for( size_t a = 0; a < st_desc_db(st).naed() ; a++ ) { + indices.push_back(std::tuple<size_t,size_t>(st,a)); + counter++; + } + } - if (counter < s) { - throw std::runtime_error("The number of requestd basis vectors is\n \ - larger than the amount of available AEDs\n"); - } + if (counter < s) { + throw std::runtime_error("The number of requestd basis vectors is\n \ +larger than the amount of available AEDs\n"); + } - std::random_shuffle(indices.begin(), indices.end()); + std::random_device rd; + std::default_random_engine rng(rd()); // Initialize random engine + std::shuffle(indices.begin(), indices.end(),rng); - b.resize(st_desc_db(0).dim(),s); - b.set_zero(); - // set first basis function as "bias vector" - b(0,0)=1; - for (size_t i=1; i<s; ++i) { - size_t st = std::get<0>(indices[i]); - size_t a = std::get<1>(indices[i]); - const aed_type2 &aed = st_desc_db(st).get_aed(a); - for (size_t j=0; j<aed.size(); ++j) { - b(j,i)=aed[j]; - } - } - } - void prep_basis_for_krr(StDescriptorsDB &st_desc_db, - StructureDB &stdb) { + b.resize(st_desc_db(0).dim(),s); + b.set_zero(); + // set first basis function as "bias vector" + b(0,0)=1; + for (size_t i=1; i<s; ++i) { + size_t st = std::get<0>(indices[i]); + size_t a = std::get<1>(indices[i]); + const aed_type2 &aed = st_desc_db(st).get_aed(a); + for (size_t j=0; j<aed.size(); ++j) { + b(j,i)=aed[j]; + } + } + } + void prep_basis_for_krr(StDescriptorsDB &st_desc_db, + StructureDB &stdb) { - size_t s = config.get<size_t>("SBASIS"); + size_t s = config.get<size_t>("SBASIS"); - if (stdb.size() != st_desc_db.size()) { - throw std::runtime_error("The size of StructureDB is different \n \ - from StDescriptorsDB. This is likely to be a bug.\n"); - } + if (stdb.size() != st_desc_db.size()) { + throw std::runtime_error("The size of StructureDB is different \n \ +from StDescriptorsDB. This is likely to be a bug.\n"); + } - // generate indices - std::vector<size_t> indices; - for( size_t st = 0; st < stdb.size(); st++ ) { - indices.push_back(st); - } + // generate indices + std::vector<size_t> indices; + for( size_t st = 0; st < stdb.size(); st++ ) { + indices.push_back(st); + } - if (indices.size() < s) { - throw std::runtime_error("The number of requestd basis vectors is\n \ - larger than the amount of available AEDs\n"); - } + if (indices.size() < s) { + throw std::runtime_error("The number of requestd basis vectors is\n \ +larger than the amount of available AEDs\n"); + } - std::random_shuffle(indices.begin(), indices.end()); + std::random_device rd; + std::default_random_engine rng(rd()); // Initialize random engine + std::shuffle(indices.begin(), indices.end(),rng); - b.resize(st_desc_db(0).dim(),s); - T.resize(s); - b.set_zero(); - b(0,0)=1; - // Here we add all atomic energy descriptors into single - // descriptor which represents total energy of this configuration - for (size_t i=1; i<s; ++i) { - const size_t st = indices[i]; - T(i)=stdb(st).energy/st_desc_db(st).naed(); - for( size_t a=0; a<st_desc_db(st).naed(); a++ ) { - const aed_type2 &aed = st_desc_db(st).get_aed(a); - for (size_t j=0; j<aed.size(); ++j) { - b(j,i)+=aed[j]/st_desc_db(st).naed(); - } - } - } + b.resize(st_desc_db(0).dim(),s); + T.resize(s); + b.set_zero(); + b(0,0)=1; + // Here we add all atomic energy descriptors into single + // descriptor which represents total energy of this configuration + for (size_t i=1; i<s; ++i) { + const size_t st = indices[i]; + T(i)=stdb(st).energy/st_desc_db(st).naed(); + for( size_t a=0; a<st_desc_db(st).naed(); a++ ) { + const aed_type2 &aed = st_desc_db(st).get_aed(a); + for (size_t j=0; j<aed.size(); ++j) { + b(j,i)+=aed[j]/st_desc_db(st).naed(); } + } + } + } }; #endif -- GitLab