From 679c69f67308418bc5de2c7588cc2ff690ed8d53 Mon Sep 17 00:00:00 2001 From: mkirsz <s1351949@sms.ed.ac.uk> Date: Mon, 29 Jan 2024 22:24:04 +0000 Subject: [PATCH] More refactoring --- m_base.h | 36 ++++++++++++++++++------------------ m_blr_core.h | 14 +++++++++++--- m_core.h | 43 ++++++++++++++++++++++++++++++++++++++++++- m_krr_core.h | 14 +++++++++++--- 4 files changed, 82 insertions(+), 25 deletions(-) diff --git a/m_base.h b/m_base.h index 4ecb44b..361a1d1 100644 --- a/m_base.h +++ b/m_base.h @@ -10,24 +10,24 @@ class M_Base { public: //Normaliser norm; // TODO? virtual ~M_Base() {}; - - /** \brief Predict local energy of an atom or bond energy. - * - * The result depends on how aed is computed. - * - * If it is computed between a pair of atoms than the result is a bond energy. - * - * If aed contains sum over all nearest neighbours than the result is - * a local atomic energy \f$ E_i \f$. - * */ - virtual double epredict(const aed_rctype &aed)=0; - - /** \brief Predict force between a pair of atoms in a k-direction. */ - virtual double fpredict(const fd_type &fdij, const aed_rctype &aedi, size_t k)=0; - - /** \brief Predict force between a pair of atoms. */ - virtual force_type fpredict(const fd_type &fdij, - const aed_rctype &aedi)=0; +// +// /** \brief Predict local energy of an atom or bond energy. +// * +// * The result depends on how aed is computed. +// * +// * If it is computed between a pair of atoms than the result is a bond energy. +// * +// * If aed contains sum over all nearest neighbours than the result is +// * a local atomic energy \f$ E_i \f$. +// * */ +// virtual double epredict(const aed_rctype &aed)=0; +// +// /** \brief Predict force between a pair of atoms in a k-direction. */ +// virtual double fpredict(const fd_type &fdij, const aed_rctype &aedi, size_t k)=0; +// +// /** \brief Predict force between a pair of atoms. */ +// virtual force_type fpredict(const fd_type &fdij, +// const aed_rctype &aedi)=0; }; diff --git a/m_blr_core.h b/m_blr_core.h index dea3ea2..bdd4980 100644 --- a/m_blr_core.h +++ b/m_blr_core.h @@ -11,13 +11,20 @@ template <class BF=Function_Base&> -class M_BLR_Core: public virtual M_Core { +class M_BLR_Core: + public virtual M_Core, + public virtual M_Predict, + public virtual M_Train { public: - Config &config; + //using M_Core<BF>::weights; + //using M_Core<BF>::trained; + //using M_Core<BF>::config; mat Sigma; + Config &config; BF bf; M_BLR_Core(Config &c): config(c), + //M_Core<BF>(c), bf(c) { static_assert(std::is_base_of<BF_Base, BF>::value, @@ -31,6 +38,7 @@ class M_BLR_Core: public virtual M_Core { M_BLR_Core(BF &bf, Config &c): config(c), + //M_Core<BF>(bf,c), bf(bf) { if (dynamic_cast<BF_Base*>(&bf) == nullptr) @@ -67,7 +75,7 @@ class M_BLR_Core: public virtual M_Core { trained=true; } - M_Core::verbose=(config.get<int>("VERBOSE")); + verbose=(config.get<int>("VERBOSE")); } }; #endif diff --git a/m_core.h b/m_core.h index 1caa4b0..cced00f 100644 --- a/m_core.h +++ b/m_core.h @@ -3,8 +3,21 @@ #include "../CORE/typedefs.h" +//template +//<class F=Function_Base&> class M_Core { public: +// F f; +// Config &config; +// +// M_Core(Config &c): +// config(c), +// f(c) +// {} +// M_Core(F &f, Config &c): +// config(c), +// f(f) +// {} int verbose; bool trained=false; t_type weights; @@ -12,7 +25,6 @@ class M_Core { virtual ~M_Core() {} virtual double predict(const rvec &v)const=0; - virtual void train(phi_type &Phi, const t_type &T)=0; virtual t_type get_weights_uncertainty()const=0; bool is_trained() const { @@ -25,4 +37,33 @@ class M_Core { weights=w; } }; +class M_Predict { + public: + + virtual ~M_Predict() {} + + /** \brief Predict local energy of an atom or bond energy. + * + * The result depends on how aed is computed. + * + * If it is computed between a pair of atoms than the result is a bond energy. + * + * If aed contains sum over all nearest neighbours than the result is + * a local atomic energy \f$ E_i \f$. + * */ + virtual double epredict(const aed_rctype &aed)=0; + + /** \brief Predict force between a pair of atoms in a k-direction. */ + virtual double fpredict(const fd_type &fdij, const aed_rctype &aedi, size_t k)=0; + + /** \brief Predict force between a pair of atoms. */ + virtual force_type fpredict(const fd_type &fdij, + const aed_rctype &aedi)=0; +}; +class M_Train { + public: + virtual ~M_Train() {} + + virtual void train(phi_type &Phi, const t_type &T)=0; +}; #endif diff --git a/m_krr_core.h b/m_krr_core.h index 3499aad..5ebf7db 100644 --- a/m_krr_core.h +++ b/m_krr_core.h @@ -12,14 +12,21 @@ template <class K=Function_Base&> -class M_KRR_Core: public virtual M_Core { +class M_KRR_Core: + public virtual M_Core, + public virtual M_Predict, + public virtual M_Train { public: - Config &config; + //using M_Core<K>::weights; + //using M_Core<K>::trained; + //using M_Core<K>::config; mat Sigma; + Config &config; K kernel; EKM<K> ekm; M_KRR_Core(Config &c): config(c), + //M_Core<K>(c), kernel(c), ekm(c) { @@ -33,6 +40,7 @@ class M_KRR_Core: public virtual M_Core { M_KRR_Core(K &kernel, Config &c): config(c), + //M_Core<K>(kernel, c), kernel(kernel), ekm(kernel) { @@ -75,7 +83,7 @@ class M_KRR_Core: public virtual M_Core { trained=true; } - M_Core::verbose=(config.get<int>("VERBOSE")); + verbose=(config.get<int>("VERBOSE")); } }; #endif -- GitLab