Skip to content
Snippets Groups Projects
Commit 5ad667b2 authored by Marcin Kirsz's avatar Marcin Kirsz
Browse files

Merge branch 'develop' into 'main'

Update to random_shuffle

See merge request !10
parents f9d13ded dfa76e78
No related branches found
No related tags found
1 merge request!10Update to random_shuffle
Pipeline #48433 passed
......@@ -6,98 +6,103 @@
#include <tadah/core/core_types.h>
#include <tadah/core/config.h>
#include <algorithm>
#include <numeric>
#include <stdexcept>
#include <vector>
template <typename K>
class Basis {
private:
Config &config;
int verbose;
public:
Matrix b;
t_type T; // Vectors corresponding to basis vectors
Basis(Config &c):
config(c),
verbose(c.get<int>("VERBOSE"))
private:
Config &config;
int verbose;
public:
Matrix b;
t_type T; // Vectors corresponding to basis vectors
Basis(Config &c):
config(c),
verbose(c.get<int>("VERBOSE"))
{}
{}
void set_basis(Matrix &b_) {
b=b_;
}
void build_random_basis(size_t s, StDescriptorsDB &st_desc_db) {
void set_basis(Matrix &b_) {
b=b_;
}
void build_random_basis(size_t s, StDescriptorsDB &st_desc_db) {
// generate indices
std::vector<std::tuple<size_t,size_t>> indices;
size_t counter=0;
for( size_t st = 0; st < st_desc_db.size(); st++ ) {
for( size_t a = 0; a < st_desc_db(st).naed() ; a++ ) {
indices.push_back(std::tuple<size_t,size_t>(st,a));
counter++;
}
}
// generate indices
std::vector<std::tuple<size_t,size_t>> indices;
size_t counter=0;
for( size_t st = 0; st < st_desc_db.size(); st++ ) {
for( size_t a = 0; a < st_desc_db(st).naed() ; a++ ) {
indices.push_back(std::tuple<size_t,size_t>(st,a));
counter++;
}
}
if (counter < s) {
throw std::runtime_error("The number of requestd basis vectors is\n \
larger than the amount of available AEDs\n");
}
if (counter < s) {
throw std::runtime_error("The number of requestd basis vectors is\n \
larger than the amount of available AEDs\n");
}
std::random_shuffle(indices.begin(), indices.end());
std::random_device rd;
std::default_random_engine rng(rd()); // Initialize random engine
std::shuffle(indices.begin(), indices.end(),rng);
b.resize(st_desc_db(0).dim(),s);
b.set_zero();
// set first basis function as "bias vector"
b(0,0)=1;
for (size_t i=1; i<s; ++i) {
size_t st = std::get<0>(indices[i]);
size_t a = std::get<1>(indices[i]);
const aed_type2 &aed = st_desc_db(st).get_aed(a);
for (size_t j=0; j<aed.size(); ++j) {
b(j,i)=aed[j];
}
}
}
void prep_basis_for_krr(StDescriptorsDB &st_desc_db,
StructureDB &stdb) {
b.resize(st_desc_db(0).dim(),s);
b.set_zero();
// set first basis function as "bias vector"
b(0,0)=1;
for (size_t i=1; i<s; ++i) {
size_t st = std::get<0>(indices[i]);
size_t a = std::get<1>(indices[i]);
const aed_type2 &aed = st_desc_db(st).get_aed(a);
for (size_t j=0; j<aed.size(); ++j) {
b(j,i)=aed[j];
}
}
}
void prep_basis_for_krr(StDescriptorsDB &st_desc_db,
StructureDB &stdb) {
size_t s = config.get<size_t>("SBASIS");
size_t s = config.get<size_t>("SBASIS");
if (stdb.size() != st_desc_db.size()) {
throw std::runtime_error("The size of StructureDB is different \n \
from StDescriptorsDB. This is likely to be a bug.\n");
}
if (stdb.size() != st_desc_db.size()) {
throw std::runtime_error("The size of StructureDB is different \n \
from StDescriptorsDB. This is likely to be a bug.\n");
}
// generate indices
std::vector<size_t> indices;
for( size_t st = 0; st < stdb.size(); st++ ) {
indices.push_back(st);
}
// generate indices
std::vector<size_t> indices;
for( size_t st = 0; st < stdb.size(); st++ ) {
indices.push_back(st);
}
if (indices.size() < s) {
throw std::runtime_error("The number of requestd basis vectors is\n \
larger than the amount of available AEDs\n");
}
if (indices.size() < s) {
throw std::runtime_error("The number of requestd basis vectors is\n \
larger than the amount of available AEDs\n");
}
std::random_shuffle(indices.begin(), indices.end());
std::random_device rd;
std::default_random_engine rng(rd()); // Initialize random engine
std::shuffle(indices.begin(), indices.end(),rng);
b.resize(st_desc_db(0).dim(),s);
T.resize(s);
b.set_zero();
b(0,0)=1;
// Here we add all atomic energy descriptors into single
// descriptor which represents total energy of this configuration
for (size_t i=1; i<s; ++i) {
const size_t st = indices[i];
T(i)=stdb(st).energy/st_desc_db(st).naed();
for( size_t a=0; a<st_desc_db(st).naed(); a++ ) {
const aed_type2 &aed = st_desc_db(st).get_aed(a);
for (size_t j=0; j<aed.size(); ++j) {
b(j,i)+=aed[j]/st_desc_db(st).naed();
}
}
}
b.resize(st_desc_db(0).dim(),s);
T.resize(s);
b.set_zero();
b(0,0)=1;
// Here we add all atomic energy descriptors into single
// descriptor which represents total energy of this configuration
for (size_t i=1; i<s; ++i) {
const size_t st = indices[i];
T(i)=stdb(st).energy/st_desc_db(st).naed();
for( size_t a=0; a<st_desc_db(st).naed(); a++ ) {
const aed_type2 &aed = st_desc_db(st).get_aed(a);
for (size_t j=0; j<aed.size(); ++j) {
b(j,i)+=aed[j]/st_desc_db(st).naed();
}
}
}
}
};
#endif
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment