From 34c8751bef3c3a26506dc3c35073f0e5925cae77 Mon Sep 17 00:00:00 2001 From: Marcin Kirsz <mkirsz@ed.ac.uk> Date: Mon, 3 Mar 2025 16:53:23 +0000 Subject: [PATCH] WiP --- include/tadah/mlip/models/m_blr.h | 255 ++++++++++++----------- include/tadah/mlip/models/m_krr.h | 55 ++--- include/tadah/mlip/models/m_tadah_base.h | 3 +- 3 files changed, 168 insertions(+), 145 deletions(-) diff --git a/include/tadah/mlip/models/m_blr.h b/include/tadah/mlip/models/m_blr.h index db4874b..809f089 100644 --- a/include/tadah/mlip/models/m_blr.h +++ b/include/tadah/mlip/models/m_blr.h @@ -1,19 +1,19 @@ #ifndef M_BLR_H #define M_BLR_H -#include <tadah/mlip/models/m_tadah_base.h> +#include <tadah/core/config.h> #include <tadah/mlip/descriptors_calc.h> #include <tadah/mlip/design_matrix/design_matrix.h> #include <tadah/mlip/design_matrix/functions/dm_function_base.h> +#include <tadah/mlip/memory/IMLIPWorkspaceManager.h> +#include <tadah/mlip/models/m_tadah_base.h> #include <tadah/mlip/normaliser.h> #include <tadah/models/m_blr_train.h> -#include <tadah/core/config.h> -#include <tadah/mlip/memory/IMLIPWorkspaceManager.h> +#include <iostream> #include <limits> #include <stdexcept> #include <type_traits> -#include <iostream> namespace tadah { namespace mlip { @@ -21,34 +21,39 @@ namespace mlip { * @class M_BLR * @brief Bayesian Linear Regression (BLR). * - * This class implements Bayesian Linear Regression, a statistical method to make predictions using linear models with both linear and nonlinear features. + * This class implements Bayesian Linear Regression, a statistical method to + * make predictions using linear models with both linear and nonlinear features. * * **Model Supported Training Modes**: - * - **LINEAR**: Uses Ordinary Least Squares or Ridge Regression for linear relationships. - * - **NONLINEAR**: Utilizes basis functions to handle nonlinear input spaces, transforming input descriptors into higher-dimensional feature spaces. For example, polynomial transformations. 
+ * - **LINEAR**: Uses Ordinary Least Squares or Ridge Regression for linear + * relationships. + * - **NONLINEAR**: Utilizes basis functions to handle nonlinear input spaces, + * transforming input descriptors into higher-dimensional feature spaces. For + * example, polynomial transformations. * * **Prediction**: - * - Computes predictions as a weighted sum of basis functions applied to input vectors. + * - Computes predictions as a weighted sum of basis functions applied to input + * vectors. * * **Training**: - * - Employs regularized least squares, allowing for optional regularization through the \f$\lambda\f$ parameter. + * - Employs regularized least squares, allowing for optional regularization + * through the \f$\lambda\f$ parameter. * - Ordinary Least Squares (OLS) is a special case when \f$\lambda = 0\f$. * * **Configuration Options**: - * - **LAMBDA**: Set to `0` for OLS, a positive value for specified regularization, or `-1` for automatic tuning using evidence approximation. + * - **LAMBDA**: Set to `0` for OLS, a positive value for specified + * regularization, or `-1` for automatic tuning using evidence approximation. * * @tparam BF DM_BF_Base child, Basis function */ -template -<class BF=DM_Function_Base&> -class M_BLR: public M_Tadah_Base, public tadah::models::M_BLR_Train<BF> { +template <class BF = DM_Function_Base &> +class M_BLR : public M_Tadah_Base, public tadah::models::M_BLR_Train<BF> { public: - using tadah::models::M_BLR_Train<BF>::config; using tadah::models::M_BLR_Train<BF>::bf; - /** + /** * @brief Initializes for training or prediction using a configuration. * * **Example**: @@ -56,63 +61,63 @@ public: * tadah::core::Config config("tadah::core::Config"); * M_BLR<tadah::models::BF_Linear> blr(config); * \endcode - * + * * @param c Configuration object. 
*/ - M_BLR(tadah::core::Config &c): - tadah::models::M_BLR_Train<BF>(c), - desmat(tadah::models::M_BLR_Train<BF>::bf,c) - { + M_BLR(tadah::core::Config &c) + : tadah::models::M_BLR_Train<BF>(c), + desmat(tadah::models::M_BLR_Train<BF>::bf, c) { norm = Normaliser(c); } - /** - * @brief Initializes for training or prediction using a basis function and configuration. - * + /** + * @brief Initializes for training or prediction using a basis function and + * configuration. + * * @param bf Basis function. * @param c Configuration object. */ - M_BLR(BF &bf, tadah::core::Config &c): - tadah::models::M_BLR_Train<BF>(bf,c), - desmat(bf,c) - { + M_BLR(BF &bf, tadah::core::Config &c) + : tadah::models::M_BLR_Train<BF>(bf, c), desmat(bf, c) { norm = Normaliser(c); } - M_BLR(BF &bf, tadah::core::Config &c, tadah::mlip::memory::IMLIPWorkspaceManager& workspaceManager): - tadah::models::M_BLR_Train<BF>(bf,c,workspaceManager), - desmat(bf,c,workspaceManager) - { + M_BLR(BF &bf, tadah::core::Config &c, + tadah::mlip::memory::IMLIPWorkspaceManager &workspaceManager) + : tadah::models::M_BLR_Train<BF>(bf, c, workspaceManager), + desmat(bf, c, workspaceManager) { norm = Normaliser(c); } - double epredict(const tadah::core::aed_type &aed) const{ - return bf.epredict(weights,aed); + double epredict(const tadah::core::aed_type &aed) const { + return bf.epredict(weights, aed); }; - double fpredict(const tadah::core::fd_type &fdij, const tadah::core::aed_type &aedi, const size_t k) const{ - return bf.fpredict(weights,fdij,aedi,k); + double fpredict(const tadah::core::fd_type &fdij, + const tadah::core::aed_type &aedi, const size_t k) const { + return bf.fpredict(weights, fdij, aedi, k); } - tadah::core::force_type fpredict(const tadah::core::fd_type &fdij, const tadah::core::aed_type &aedi) const{ - return bf.fpredict(weights,fdij,aedi); + tadah::core::force_type fpredict(const tadah::core::fd_type &fdij, + const tadah::core::aed_type &aedi) const { + return bf.fpredict(weights, fdij, 
aedi); } void train(StDescriptorsDB &st_desc_db, const StructureDB &stdb) { - if(config.template get<bool>("NORM")) - norm = Normaliser(config,st_desc_db); + if (config.template get<bool>("NORM")) + norm = Normaliser(config, st_desc_db); - desmat.build(st_desc_db,stdb); + desmat.build(st_desc_db, stdb); train(desmat); } void train(const StructureDB &stdb, const NeighborListDB &nldb, DC_Base &dc) { - if(config.template get<bool>("NORM")) { + if (config.template get<bool>("NORM")) { - std::string force=config.template get<std::string>("FORCE"); - std::string stress=config.template get<std::string>("STRESS"); + std::string force = config.template get<std::string>("FORCE"); + std::string stress = config.template get<std::string>("STRESS"); config.remove("FORCE"); config.remove("STRESS"); @@ -121,7 +126,7 @@ public: StDescriptorsDB st_desc_db_temp = dc.calc(stdb, nldb); - if(config.template get<bool>("NORM")) { + if (config.template get<bool>("NORM")) { norm = Normaliser(config); norm.learn(st_desc_db_temp); // norm.normalise(st_desc_db_temp); @@ -133,24 +138,27 @@ public: config.add("STRESS", stress); } - desmat.build(stdb,nldb,norm,dc); + desmat.build(stdb, nldb, norm, dc); train(desmat); } - Structure predict(const tadah::core::Config &c, StDescriptors &std, const Structure &st, const StructureNeighborView &st_nb) { - if(config.template get<bool>("NORM") && !std.normalised && bf.get_label()!="tadah::models::BF_Linear") + Structure predict(const tadah::core::Config &c, StDescriptors &std, + const Structure &st, const StructureNeighborView &st_nb) { + if (config.template get<bool>("NORM") && !std.normalised && + bf.get_label() != "tadah::models::BF_Linear") norm.normalise(std); - return M_Tadah_Base::predict(c,std,st,st_nb); + return M_Tadah_Base::predict(c, std, st, st_nb); } - StructureDB predict(tadah::core::Config &c, const StructureDB &stdb, const NeighborListDB &nldb, DC_Base &dc) { - return M_Tadah_Base::predict(c,stdb,nldb,dc); + StructureDB 
predict(tadah::core::Config &c, const StructureDB &stdb, + const NeighborListDB &nldb, DC_Base &dc) { + return M_Tadah_Base::predict(c, stdb, nldb, dc); } tadah::core::Config get_param_file() { tadah::core::Config c = config; - //c.remove("ALPHA"); - //c.remove("BETA"); + // c.remove("ALPHA"); + // c.remove("BETA"); c.remove("DBFILE"); c.remove("FORCE"); c.remove("STRESS"); @@ -161,28 +169,28 @@ public: c.add("MODEL", label); c.add("MODEL", bf.get_label()); - for (size_t i=0;i<weights.size();++i) { + for (size_t i = 0; i < weights.size(); ++i) { c.add("WEIGHTS", weights(i)); } - if(config.template get<bool>("NORM")) { - for (size_t i=0;i<norm.mean.size();++i) { + if (config.template get<bool>("NORM")) { + for (size_t i = 0; i < norm.mean.size(); ++i) { c.add("NMEAN", norm.mean[i]); } - for (size_t i=0;i<norm.std_dev.size();++i) { + for (size_t i = 0; i < norm.std_dev.size(); ++i) { c.add("NSTDEV", norm.std_dev[i]); } } c.clear_internal_keys(); return c; } - StructureDB predict(tadah::core::Config config_pred, StructureDB &stdb, DC_Base &dc, - tadah::core::aed_type &predicted_error) { + StructureDB predict(tadah::core::Config config_pred, StructureDB &stdb, + DC_Base &dc, tadah::core::aed_type &predicted_error) { - tadah::models::LinearRegressor::read_sigma(config_pred,Sigma); - DesignMatrix<BF> dm(bf,config_pred); - dm.scale=false; // do not scale energy, forces and stresses - dm.build(stdb,norm,dc); + tadah::models::LinearRegressor::read_sigma(config_pred, Sigma); + DesignMatrix<BF> dm(bf, config_pred); + dm.scale = false; // do not scale energy, forces and stresses + dm.build(stdb, norm, dc); predicted_error = T_MDMT_diag(dm.getPhi(), Sigma); double pmean = sqrt(predicted_error.mean()); @@ -191,29 +199,35 @@ public: tadah::core::aed_type Tpred = T_dgemv(dm.getPhi(), weights); // Construct StructureDB object with predicted values - StructureDB stdb_; - stdb_.structures.resize(stdb.size()); - size_t i=0; - for (size_t s=0; s<stdb.size(); ++s) { - stdb_(s) = 
Structure(stdb(s)); - - predicted_error(i) = (sqrt(predicted_error(i))-pmean)/stdb(s).natoms(); + auto pinfo = stdb.exportPass1Info(); + StructureDB stdb_(pinfo, true); + size_t i = 0; + for (size_t s = 0; s < stdb.size(); ++s) { + /*stdb_(s) = Structure(stdb(s));*/ + + predicted_error(i) = + (sqrt(predicted_error(i)) - pmean) / stdb(s).natoms(); stdb_(s).energy = Tpred(i++); if (config_pred.get<bool>("FORCE")) { - for (size_t a=0; a<stdb(s).natoms(); ++a) { - for (size_t k=0; k<3; ++k) { - stdb_(s).atoms[a].force[k] = Tpred(i++); - predicted_error(i) = (sqrt(predicted_error(i))-pmean); - } + for (size_t a = 0; a < stdb(s).natoms(); ++a) { + stdb_(s).fx(a) = Tpred(i); + predicted_error(i) = (sqrt(predicted_error(i)) - pmean); + i++; + stdb_(s).fy(a) = Tpred(i); + predicted_error(i) = (sqrt(predicted_error(i)) - pmean); + i++; + stdb_(s).fz(a) = Tpred(i); + predicted_error(i) = (sqrt(predicted_error(i)) - pmean); + i++; } } if (config_pred.get<bool>("STRESS")) { - for (size_t x=0; x<3; ++x) { - for (size_t y=x; y<3; ++y) { - stdb_(s).stress(x,y) = Tpred(i++); - predicted_error(i) = (sqrt(predicted_error(i))-pmean); - if (x!=y) - stdb_(s).stress(y,x) = stdb_(s).stress(x,y); + for (size_t x = 0; x < 3; ++x) { + for (size_t y = x; y < 3; ++y) { + predicted_error(i) = (sqrt(predicted_error(i)) - pmean); + stdb_(s).stress(x, y) = Tpred(i++); + if (x != y) + stdb_(s).stress(y, x) = stdb_(s).stress(x, y); } } } @@ -221,7 +235,8 @@ public: return stdb_; } StructureDB predict(StructureDB &stdb) { - if(!trained) throw std::runtime_error("This object is not trained!\n\ + if (!trained) + throw std::runtime_error("This object is not trained!\n\ Hint: check different predict() methods."); tadah::core::phi_type &Phi = desmat.getPhi(); @@ -229,32 +244,35 @@ Hint: check different predict() methods."); // compute energy, forces and stresses tadah::core::aed_type Tpred = T_dgemv(Phi, weights); - double eweightglob=config.template get<double>("EWEIGHT"); - double 
fweightglob=config.template get<double>("FWEIGHT"); - double sweightglob=config.template get<double>("SWEIGHT"); + double eweightglob = config.template get<double>("EWEIGHT"); + double fweightglob = config.template get<double>("FWEIGHT"); + double sweightglob = config.template get<double>("SWEIGHT"); // Construct StructureDB object with predicted values - StructureDB stdb_; - stdb_.structures.resize(stdb.size()); - size_t s=0; - size_t i=0; - while (i<Phi.rows()) { - - stdb_(s).energy = Tpred(i++)*stdb(s).natoms()/eweightglob/stdb(s).eweight; + auto pinfo = stdb.exportPass1Info(); + StructureDB stdb_(pinfo, true); + size_t s = 0; + size_t i = 0; + while (i < Phi.rows()) { + + stdb_(s).energy = + Tpred(i++) * stdb(s).natoms() / eweightglob / stdb(s).eweight; if (config.template get<bool>("FORCE")) { - stdb_(s).atoms.resize(stdb(s).natoms()); - for (size_t a=0; a<stdb(s).natoms(); ++a) { - for (size_t k=0; k<3; ++k) { - stdb_(s).atoms[a].force[k] = Tpred(i++)/fweightglob/stdb(s).fweight; - } + for (size_t a = 0; a < stdb(s).natoms(); ++a) { + stdb_(s).fx(a) = Tpred(i) / fweightglob / stdb(s).fweight; + i++; + stdb_(s).fy(a) = Tpred(i) / fweightglob / stdb(s).fweight; + i++; + stdb_(s).fz(a) = Tpred(i) / fweightglob / stdb(s).fweight; + i++; } } if (config.template get<bool>("STRESS")) { - for (size_t x=0; x<3; ++x) { - for (size_t y=x; y<3; ++y) { - stdb_(s).stress(x,y) = Tpred(i++)/sweightglob/stdb(s).sweight; - if (x!=y) - stdb_(s).stress(y,x) = stdb_(s).stress(x,y); + for (size_t x = 0; x < 3; ++x) { + for (size_t y = x; y < 3; ++y) { + stdb_(s).stress(x, y) = Tpred(i++) / sweightglob / stdb(s).sweight; + if (x != y) + stdb_(s).stress(y, x) = stdb_(s).stress(x, y); } } } @@ -264,64 +282,63 @@ Hint: check different predict() methods."); } private: - std::string label="M_BLR"; + std::string label = "M_BLR"; DesignMatrix<BF> desmat; // normalise weights such that when predict is called // we can supply it with a non-normalised descriptor - tadah::core::t_type convert_to_nweights(const tadah::core::t_type &weights) const { - 
if(bf.get_label()!="tadah::models::BF_Linear") { + tadah::core::t_type + convert_to_nweights(const tadah::core::t_type &weights) const { + if (bf.get_label() != "tadah::models::BF_Linear") { throw std::runtime_error("Cannot convert weights to nweights for\n\ non linear basis function\n"); } tadah::core::t_type nw(weights.rows()); nw(0) = weights(0); - for (size_t i=1; i<weights.size(); ++i) { + for (size_t i = 1; i < weights.size(); ++i) { if (norm.std_dev[i] > std::numeric_limits<double>::min()) nw(i) = weights(i) / norm.std_dev[i]; else nw(i) = weights(i); - nw(0) -= norm.mean[i]*nw(i); - + nw(0) -= norm.mean[i] * nw(i); } return nw; } // The opposite of convert_to_nweights() tadah::core::t_type convert_to_weights(const tadah::core::t_type &nw) const { - if(bf.get_label()!="tadah::models::BF_Linear") { + if (bf.get_label() != "tadah::models::BF_Linear") { throw std::runtime_error("Cannot convert nweights to weights for\n\ non linear basis function\n"); } // convert normalised weights back to "normal" tadah::core::t_type w(nw.rows()); w(0) = nw(0); - for (size_t i=1; i<nw.size(); ++i) { + for (size_t i = 1; i < nw.size(); ++i) { if (norm.std_dev[i] > std::numeric_limits<double>::min()) w(i) = nw(i) * norm.std_dev[i]; else w(i) = nw(i); - w(0) += nw(i)*norm.mean[i]; + w(0) += nw(i) * norm.mean[i]; } return w; } - template <typename D> - void train(D &desmat) { - // TODO + template <typename D> void train(D &desmat) { + // TODO // the train method destroys the Phi matrix // In consequence, we cannot use it for quick prediction // The simple solution, for now, is to make a copy of the Phi matrix - //tadah::core::phi_type &Phi = desmat.Phi; + // tadah::core::phi_type &Phi = desmat.Phi; tadah::core::phi_type Phi = desmat.getPhi(); tadah::core::t_type T = desmat.getT(); - //tadah::core::t_type &T = desmat.T; - tadah::models::M_BLR_Train<BF>::train(Phi,T); + // tadah::core::t_type &T = desmat.T; + tadah::models::M_BLR_Train<BF>::train(Phi, T); if (config.template 
get<bool>("NORM") && - bf.get_label()=="tadah::models::BF_Linear") { + bf.get_label() == "tadah::models::BF_Linear") { weights = convert_to_nweights(weights); } } @@ -334,6 +351,6 @@ non linear basis function\n"); using tadah::models::M_BLR_Train<BF>::weights; using tadah::models::M_BLR_Train<BF>::Sigma; }; -} -} +} // namespace mlip +} // namespace tadah #endif diff --git a/include/tadah/mlip/models/m_krr.h b/include/tadah/mlip/models/m_krr.h index 63f93b3..e5907c1 100644 --- a/include/tadah/mlip/models/m_krr.h +++ b/include/tadah/mlip/models/m_krr.h @@ -110,7 +110,7 @@ public: train(desmat); } - void train(StructureDB &stdb, DC_Base &dc) { + void train(StructureDB &stdb, const NeighborListDB &nldb, DC_Base &dc) { int modelN; try { modelN=config.template get<int>("MODEL",2); @@ -121,9 +121,9 @@ public: config.add("MODEL", modelN); } if (modelN==1) - train1(stdb,dc); + train1(stdb,nldb,dc); else if (modelN==2) - train2(stdb,dc); + train2(stdb,nldb,dc); else throw std::runtime_error( @@ -191,14 +191,17 @@ public: tadah::models::M_KRR_Train<K>::train2(basis.b, basis.T); } - Structure predict(const tadah::core::Config &c, StDescriptors &std, const Structure &st) { - if(config.template get<bool>("NORM") && !std.normalised && kernel.get_label()!="tadah::models::Kern_Linear") + Structure predict(const tadah::core::Config &c, StDescriptors &std, + const Structure &st, const StructureNeighborView &st_nb) { + if (config.template get<bool>("NORM") && !std.normalised && + kernel.get_label() != "tadah::models::Kern_Linear") norm.normalise(std); - return M_Tadah_Base::predict(c,std,st); + return M_Tadah_Base::predict(c, std, st, st_nb); } - StructureDB predict(tadah::core::Config &c, const StructureDB &stdb, DC_Base &dc) { - return M_Tadah_Base::predict(c,stdb,dc); + StructureDB predict(tadah::core::Config &c, const StructureDB &stdb, + const NeighborListDB &nldb, DC_Base &dc) { + return M_Tadah_Base::predict(c, stdb, nldb, dc); } tadah::core::Config get_param_file() { @@ 
-262,20 +265,23 @@ public: tadah::core::aed_type Tpred = T_dgemv(dm.getPhi(), weights); // Construct StructureDB object with predicted values - StructureDB stdb_; - stdb_.structures.resize(stdb.size()); + auto pinfo = stdb.exportPass1Info(); + StructureDB stdb_(pinfo, true); size_t i=0; for (size_t s=0; s<stdb.size(); ++s) { - stdb_(s) = Structure(stdb(s)); - stdb_(s).energy = Tpred(i++); predicted_error(i) = (sqrt(predicted_error(i))-pmean)/stdb(s).natoms(); if (config_pred.get<bool>("FORCE")) { for (size_t a=0; a<stdb(s).natoms(); ++a) { - for (size_t k=0; k<3; ++k) { - stdb_(s).atoms[a].force[k] = Tpred(i++); - predicted_error(i) = (sqrt(predicted_error(i))-pmean); - } + stdb_(s).fx(a) = Tpred(i); + predicted_error(i) = (sqrt(predicted_error(i)) - pmean); + i++; + stdb_(s).fy(a) = Tpred(i); + predicted_error(i) = (sqrt(predicted_error(i)) - pmean); + i++; + stdb_(s).fz(a) = Tpred(i); + predicted_error(i) = (sqrt(predicted_error(i)) - pmean); + i++; } } if (config_pred.get<bool>("STRESS")) { @@ -291,7 +297,7 @@ public: } return stdb_; } - StructureDB predict(StructureDB &stdb) { + StructureDB predict(StructureDB &stdb, const NeighborListDB &nldb) { if(!trained) throw std::runtime_error("This object is not trained!\n\ Hint: check different predict() methods."); @@ -305,19 +311,21 @@ Hint: check different predict() methods."); double sweightglob=config.template get<double>("SWEIGHT"); // Construct StructureDB object with predicted values - StructureDB stdb_; - stdb_.structures.resize(stdb.size()); + auto pinfo = stdb.exportPass1Info(); + StructureDB stdb_(pinfo, true); size_t s=0; size_t i=0; while (i<Phi.rows()) { stdb_(s).energy = Tpred(i++)*stdb(s).natoms()/eweightglob/stdb(s).eweight; if (config.template get<bool>("FORCE")) { - stdb_(s).atoms.resize(stdb(s).natoms()); for (size_t a=0; a<stdb(s).natoms(); ++a) { - for (size_t k=0; k<3; ++k) { - stdb_(s).atoms[a].force[k] = Tpred(i++)/fweightglob/stdb(s).fweight; - } + stdb_(s).fx(a) = Tpred(i) / fweightglob / stdb(s).fweight; + i++; + 
stdb_(s).fy(a) = Tpred(i) / fweightglob / stdb(s).fweight; + i++; + stdb_(s).fz(a) = Tpred(i) / fweightglob / stdb(s).fweight; + i++; } } if (config.template get<bool>("STRESS")) { @@ -385,7 +393,6 @@ non linear kernel\n"); template <typename D> void train(D &desmat) { - // TODO see comments in M_BLR tadah::core::phi_type Phi = desmat.getPhi(); tadah::core::t_type T = desmat.getT(); tadah::models::M_KRR_Train<K>::train(Phi,T); diff --git a/include/tadah/mlip/models/m_tadah_base.h b/include/tadah/mlip/models/m_tadah_base.h index bad261c..fe7a15e 100644 --- a/include/tadah/mlip/models/m_tadah_base.h +++ b/include/tadah/mlip/models/m_tadah_base.h @@ -89,7 +89,7 @@ public: * * @param dc is a DescriptorCalc object */ - virtual void train(StructureDB &, DC_Base &) {}; + virtual void train(StructureDB &, const NeighborListDB &nldb, DC_Base &) {}; /** This will fit a model with precalculated StDescriptorsDB object. @@ -107,5 +107,4 @@ public: }; } } -//template<> inline Registry<M_Tadah_Base,DM_Function_Base&,tadah::core::Config&>::Map Registry<M_Tadah_Base,DM_Function_Base&,tadah::core::Config&>::registry{}; #endif -- GitLab