Skip to content
Snippets Groups Projects
Commit 5ad667b2 authored by Marcin Kirsz's avatar Marcin Kirsz
Browse files

Merge branch 'develop' into 'main'

Update to random_shuffle

See merge request !10
parents f9d13ded dfa76e78
No related branches found
No related tags found
1 merge request!10Update to random_shuffle
Pipeline #48433 passed
...@@ -6,98 +6,103 @@ ...@@ -6,98 +6,103 @@
#include <tadah/core/core_types.h> #include <tadah/core/core_types.h>
#include <tadah/core/config.h> #include <tadah/core/config.h>
#include <algorithm>
#include <numeric> #include <numeric>
#include <stdexcept> #include <stdexcept>
#include <vector> #include <vector>
template <typename K> template <typename K>
class Basis { class Basis {
private: private:
Config &config; Config &config;
int verbose; int verbose;
public: public:
Matrix b; Matrix b;
t_type T; // Vectors corresponding to basis vectors t_type T; // Vectors corresponding to basis vectors
Basis(Config &c): Basis(Config &c):
config(c), config(c),
verbose(c.get<int>("VERBOSE")) verbose(c.get<int>("VERBOSE"))
{} {}
void set_basis(Matrix &b_) { void set_basis(Matrix &b_) {
b=b_; b=b_;
} }
void build_random_basis(size_t s, StDescriptorsDB &st_desc_db) { void build_random_basis(size_t s, StDescriptorsDB &st_desc_db) {
// generate indices // generate indices
std::vector<std::tuple<size_t,size_t>> indices; std::vector<std::tuple<size_t,size_t>> indices;
size_t counter=0; size_t counter=0;
for( size_t st = 0; st < st_desc_db.size(); st++ ) { for( size_t st = 0; st < st_desc_db.size(); st++ ) {
for( size_t a = 0; a < st_desc_db(st).naed() ; a++ ) { for( size_t a = 0; a < st_desc_db(st).naed() ; a++ ) {
indices.push_back(std::tuple<size_t,size_t>(st,a)); indices.push_back(std::tuple<size_t,size_t>(st,a));
counter++; counter++;
} }
} }
if (counter < s) { if (counter < s) {
throw std::runtime_error("The number of requestd basis vectors is\n \ throw std::runtime_error("The number of requestd basis vectors is\n \
larger than the amount of available AEDs\n"); larger than the amount of available AEDs\n");
} }
std::random_shuffle(indices.begin(), indices.end()); std::random_device rd;
std::default_random_engine rng(rd()); // Initialize random engine
std::shuffle(indices.begin(), indices.end(),rng);
b.resize(st_desc_db(0).dim(),s); b.resize(st_desc_db(0).dim(),s);
b.set_zero(); b.set_zero();
// set first basis function as "bias vector" // set first basis function as "bias vector"
b(0,0)=1; b(0,0)=1;
for (size_t i=1; i<s; ++i) { for (size_t i=1; i<s; ++i) {
size_t st = std::get<0>(indices[i]); size_t st = std::get<0>(indices[i]);
size_t a = std::get<1>(indices[i]); size_t a = std::get<1>(indices[i]);
const aed_type2 &aed = st_desc_db(st).get_aed(a); const aed_type2 &aed = st_desc_db(st).get_aed(a);
for (size_t j=0; j<aed.size(); ++j) { for (size_t j=0; j<aed.size(); ++j) {
b(j,i)=aed[j]; b(j,i)=aed[j];
} }
} }
} }
void prep_basis_for_krr(StDescriptorsDB &st_desc_db, void prep_basis_for_krr(StDescriptorsDB &st_desc_db,
StructureDB &stdb) { StructureDB &stdb) {
size_t s = config.get<size_t>("SBASIS"); size_t s = config.get<size_t>("SBASIS");
if (stdb.size() != st_desc_db.size()) { if (stdb.size() != st_desc_db.size()) {
throw std::runtime_error("The size of StructureDB is different \n \ throw std::runtime_error("The size of StructureDB is different \n \
from StDescriptorsDB. This is likely to be a bug.\n"); from StDescriptorsDB. This is likely to be a bug.\n");
} }
// generate indices // generate indices
std::vector<size_t> indices; std::vector<size_t> indices;
for( size_t st = 0; st < stdb.size(); st++ ) { for( size_t st = 0; st < stdb.size(); st++ ) {
indices.push_back(st); indices.push_back(st);
} }
if (indices.size() < s) { if (indices.size() < s) {
throw std::runtime_error("The number of requestd basis vectors is\n \ throw std::runtime_error("The number of requestd basis vectors is\n \
larger than the amount of available AEDs\n"); larger than the amount of available AEDs\n");
} }
std::random_shuffle(indices.begin(), indices.end()); std::random_device rd;
std::default_random_engine rng(rd()); // Initialize random engine
std::shuffle(indices.begin(), indices.end(),rng);
b.resize(st_desc_db(0).dim(),s); b.resize(st_desc_db(0).dim(),s);
T.resize(s); T.resize(s);
b.set_zero(); b.set_zero();
b(0,0)=1; b(0,0)=1;
// Here we add all atomic energy descriptors into single // Here we add all atomic energy descriptors into single
// descriptor which represents total energy of this configuration // descriptor which represents total energy of this configuration
for (size_t i=1; i<s; ++i) { for (size_t i=1; i<s; ++i) {
const size_t st = indices[i]; const size_t st = indices[i];
T(i)=stdb(st).energy/st_desc_db(st).naed(); T(i)=stdb(st).energy/st_desc_db(st).naed();
for( size_t a=0; a<st_desc_db(st).naed(); a++ ) { for( size_t a=0; a<st_desc_db(st).naed(); a++ ) {
const aed_type2 &aed = st_desc_db(st).get_aed(a); const aed_type2 &aed = st_desc_db(st).get_aed(a);
for (size_t j=0; j<aed.size(); ++j) { for (size_t j=0; j<aed.size(); ++j) {
b(j,i)+=aed[j]/st_desc_db(st).naed(); b(j,i)+=aed[j]/st_desc_db(st).naed();
}
}
}
} }
}
}
}
}; };
#endif #endif
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment