From aabf70395ee4f9d06a90e768328a6436758aad12 Mon Sep 17 00:00:00 2001 From: mkirsz <s1351949@sms.ed.ac.uk> Date: Tue, 30 Jan 2024 10:50:05 +0000 Subject: [PATCH] Refactoring --- m_blr_core.h | 59 +++++++++++++++++------------ m_core.h | 16 +------- m_krr_core.h | 104 ++++++++++++++++++++++++++++++++++++++------------- 3 files changed, 115 insertions(+), 64 deletions(-) diff --git a/m_blr_core.h b/m_blr_core.h index bdd4980..433edbb 100644 --- a/m_blr_core.h +++ b/m_blr_core.h @@ -12,19 +12,14 @@ template <class BF=Function_Base&> class M_BLR_Core: - public virtual M_Core, - public virtual M_Predict, - public virtual M_Train { + public virtual M_Core +{ public: - //using M_Core<BF>::weights; - //using M_Core<BF>::trained; - //using M_Core<BF>::config; mat Sigma; Config &config; BF bf; M_BLR_Core(Config &c): config(c), - //M_Core<BF>(c), bf(c) { static_assert(std::is_base_of<BF_Base, BF>::value, @@ -38,21 +33,50 @@ class M_BLR_Core: M_BLR_Core(BF &bf, Config &c): config(c), - //M_Core<BF>(bf,c), bf(bf) { if (dynamic_cast<BF_Base*>(&bf) == nullptr) throw std::invalid_argument("Provided object is not of BF_Base type"); - //static_assert(std::is_same<Function_Base&, BF>::value, - // "This constructor requires BF=Function_Base&\n"); - init(); } double predict(const rvec &v) const{ return bf.epredict(get_weights(),v); }; + + private: + void init() { + if (config.exist("WEIGHTS")) { + weights.resize(config.size("WEIGHTS")); + config.get<t_type>("WEIGHTS",weights); + trained=true; + } + + verbose=(config.get<int>("VERBOSE")); + } +}; +template +<class BF=Function_Base&> +class M_BLR_Train: + public M_BLR_Core<BF>, // MUST NOT BE VIRUAL!! + public M_Train +{ + + public: + using M_BLR_Core<BF>::trained; + using M_BLR_Core<BF>::config; + using M_BLR_Core<BF>::weights; + using M_BLR_Core<BF>::Sigma; + + M_BLR_Train(Config &c): + M_BLR_Core<BF>(c) + {} + + M_BLR_Train(BF &bf, Config &c): + M_BLR_Core<BF>(bf, c) + {} + void train(phi_type &Phi, const t_type &T) { if(trained) { throw std::runtime_error("This object is already trained!"); @@ -61,21 +85,10 @@ class M_BLR_Core: trained=true; } t_type get_weights_uncertainty() const{ - double lambda=config.get<double>("LAMBDA"); + double lambda=config.template get<double>("LAMBDA"); if(lambda >= 0) throw std::runtime_error( "Sigma matrix is only computed for LAMBDA < 0"); return Sigma.diagonal(); } - - private: - void init() { - if (config.exist("WEIGHTS")) { - weights.resize(config.size("WEIGHTS")); - config.get<t_type>("WEIGHTS",weights); - trained=true; - } - - verbose=(config.get<int>("VERBOSE")); - } }; #endif diff --git a/m_core.h b/m_core.h index cced00f..56f827b 100644 --- a/m_core.h +++ b/m_core.h @@ -3,28 +3,14 @@ #include "../CORE/typedefs.h" -//template -//<class F=Function_Base&> class M_Core { public: -// F f; -// Config &config; -// -// M_Core(Config &c): -// config(c), -// f(c) -// {} -// M_Core(F &f, Config &c): -// config(c), -// f(f) -// {} int verbose; bool trained=false; t_type weights; virtual ~M_Core() {} - virtual double predict(const rvec &v)const=0; virtual t_type get_weights_uncertainty()const=0; bool is_trained() const { @@ -36,6 +22,7 @@ class M_Core { void set_weights(const t_type w) { weights=w; } + virtual double predict(const rvec &v)const=0; }; class M_Predict { public: @@ -63,7 +50,6 @@ class M_Predict { class M_Train { public: virtual ~M_Train() {} - virtual void train(phi_type &Phi, const t_type &T)=0; }; #endif diff --git a/m_krr_core.h b/m_krr_core.h index 5ebf7db..2bf5bf2 100644 --- a/m_krr_core.h +++ b/m_krr_core.h @@ -13,22 +13,16 @@ template <class K=Function_Base&> class M_KRR_Core: - public virtual M_Core, - public virtual M_Predict, - public virtual M_Train { + public virtual M_Core +{ public: - //using M_Core<K>::weights; - //using M_Core<K>::trained; - //using M_Core<K>::config; mat Sigma; Config &config; K kernel; - EKM<K> ekm; + M_KRR_Core(Config &c): config(c), - //M_Core<K>(c), - kernel(c), - ekm(c) + kernel(c) { static_assert(std::is_base_of<Kern_Base, K>::value, "\nThe provided Kernel K is not of a Kernel type.\n\ @@ -40,9 +34,7 @@ class M_KRR_Core: M_KRR_Core(K &kernel, Config &c): config(c), - //M_Core<K>(kernel, c), - kernel(kernel), - ekm(kernel) + kernel(kernel) { if (dynamic_cast<Kern_Base*>(&kernel) == nullptr) throw std::invalid_argument("Provided object is not of Kern_Base type"); @@ -50,9 +42,46 @@ class M_KRR_Core: init(); } double predict(const rvec &v) const{ - return kernel.epredict(get_weights(),v); + return kernel.epredict(weights,v); }; + + private: + void init() { + if (config.exist("WEIGHTS")) { + weights.resize(config.size("WEIGHTS")); + config.get<t_type>("WEIGHTS",weights); + trained=true; + } + + verbose=(config.get<int>("VERBOSE")); + } +}; +template +<class K=Function_Base&> +class M_KRR_Train: + public M_KRR_Core<K>, // MUST NOT BE VIRUAL!! + public M_Train +{ + + public: + EKM<K> ekm; + using M_KRR_Core<K>::trained; + using M_KRR_Core<K>::config; + using M_KRR_Core<K>::kernel; + using M_KRR_Core<K>::weights; + using M_KRR_Core<K>::Sigma; + + M_KRR_Train(Config &c): + M_KRR_Core<K>(c), + ekm(c) + {} + + M_KRR_Train(K &kernel, Config &c): + M_KRR_Core<K>(kernel, c), + ekm(kernel) + {} + void train(phi_type &Phi, const t_type &T) { if(trained) { throw std::runtime_error("This object is already trained!"); @@ -69,21 +98,44 @@ class M_KRR_Core: trained=true; } t_type get_weights_uncertainty() const{ - double lambda=config.get<double>("LAMBDA"); + double lambda=config.template get<double>("LAMBDA"); if(lambda >= 0) throw std::runtime_error( "Sigma matrix is only computed for LAMBDA < 0"); return Sigma.diagonal(); } - - private: - void init() { - if (config.exist("WEIGHTS")) { - weights.resize(config.size("WEIGHTS")); - config.get<t_type>("WEIGHTS",weights); - trained=true; - } - - verbose=(config.get<int>("VERBOSE")); - } }; + +//template +//<class K=Function_Base&> +//class M_KRR_Predict: +// public M_KRR_Core<K>, // MUST NOT BE VIRUAL!! +// public virtual M_Predict +//{ +// +// public: +// using M_KRR_Core<K>::config; +// using M_KRR_Core<K>::kernel; +// using M_KRR_Core<K>::weights; +// +// M_KRR_Predict(Config &c): +// M_KRR_Core<K>(c) +// {} +// +// M_KRR_Predict(K &kernel, Config &c): +// M_KRR_Core<K>(kernel, c) +// {} +// +// double epredict(const aed_rctype &aed) { +// return kernel.epredict(weights,aed); +// }; +// +// double fpredict(const fd_type &fdij, const aed_rctype &aedi, const size_t k) { +// return kernel.fpredict(weights,fdij,aedi,k); +// } +// +// force_type fpredict(const fd_type &fdij, const aed_rctype &aedi) { +// return kernel.fpredict(weights,fdij,aedi).array(); +// } +// +//}; #endif -- GitLab