Skip to content
Snippets Groups Projects
Commit 36b0cfdb authored by Marcin Kirsz's avatar Marcin Kirsz
Browse files

Merge branch 'develop' into 'main'

Subtracting mean for error prediction

See merge request !5
parents 1414cad1 f7587204
No related branches found
No related tags found
1 merge request!5Subtracting mean for error prediction
Pipeline #48054 passed
......@@ -210,10 +210,8 @@ class M_BLR: public M_Tadah_Base, public M_BLR_Train<BF> {
dm.scale=false; // do not scale energy, forces and stresses
dm.build(stdb,norm,dc);
predicted_error = T_MDMT_diag(dm.Phi, Sigma);
double beta = config.template get<double>("BETA");
predicted_error += 1.0/beta;
double pmean = sqrt(predicted_error.mean());
// compute energy, forces and stresses
aed_type2 Tpred = T_dgemv(dm.Phi, weights);
......@@ -225,22 +223,21 @@ class M_BLR: public M_Tadah_Base, public M_BLR_Train<BF> {
for (size_t s=0; s<stdb.size(); ++s) {
stdb_(s) = Structure(stdb(s));
predicted_error(i) = sqrt(predicted_error(i));
predicted_error(i) = (sqrt(predicted_error(i))-pmean)/stdb(s).natoms();
stdb_(s).energy = Tpred(i++);
if (config_pred.get<bool>("FORCE")) {
for (size_t a=0; a<stdb(s).natoms(); ++a) {
for (size_t k=0; k<3; ++k) {
predicted_error(i) = sqrt(predicted_error(i));
stdb_(s).atoms[a].force[k] = Tpred(i++);
predicted_error(i) = (sqrt(predicted_error(i))-pmean);
}
}
}
if (config_pred.get<bool>("STRESS")) {
for (size_t x=0; x<3; ++x) {
for (size_t y=x; y<3; ++y) {
predicted_error(i) = sqrt(predicted_error(i));
stdb_(s).stress(x,y) = Tpred(i++);
predicted_error(i) = (sqrt(predicted_error(i))-pmean);
if (x!=y)
stdb_(s).stress(y,x) = stdb_(s).stress(x,y);
}
......
This diff is collapsed.
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment