From bb249ee9f3e93cc427edac8ab2b609b1b4b3441f Mon Sep 17 00:00:00 2001
From: mkirsz <s1351949@sms.ed.ac.uk>
Date: Tue, 30 Jan 2024 11:42:39 +0000
Subject: [PATCH] Refactor: split trainer classes out of the core headers

Move M_BLR_Train and M_KRR_Train from m_blr_core.h and m_krr_core.h
into the new headers m_blr_train.h and m_krr_train.h. Hoist
get_weights_uncertainty() from the trainers into M_BLR_Core and
M_KRR_Core, declaring it pure virtual on M_Core instead of M_Train so
it is reachable from the core (prediction) interface. Trim each
header's includes accordingly.
---
 m_blr_core.h  | 49 +++++++++----------------------------------
 m_blr_train.h | 42 +++++++++++++++++++++++++++++++++++++
 m_core.h      |  2 +-
 m_krr_core.h  | 58 +++++++--------------------------------------------
 m_krr_train.h | 55 ++++++++++++++++++++++++++++++++++++++++++++++++
 5 files changed, 116 insertions(+), 90 deletions(-)
 create mode 100644 m_blr_train.h
 create mode 100644 m_krr_train.h

diff --git a/m_blr_core.h b/m_blr_core.h
index 433edbb..a53bc24 100644
--- a/m_blr_core.h
+++ b/m_blr_core.h
@@ -1,11 +1,10 @@
 #ifndef M_BLR_Core_H
 #define M_BLR_Core_H
 
-#include "linear_regressor.h"
-#include "../CORE/config/config.h"
+#include "m_core.h"
 #include "functions/function_base.h"
 #include "functions/basis_functions/bf_base.h"
-#include "m_core.h"
+#include "../CORE/config/config.h"
 
 #include <iostream>
 
@@ -42,7 +41,14 @@ class M_BLR_Core:
     }
     double predict(const rvec &v) const{
       return bf.epredict(get_weights(),v);
-    };
+    }
+
+    t_type get_weights_uncertainty() const{
+      double lambda=config.template get<double>("LAMBDA");
+      if(lambda >= 0) throw std::runtime_error(
+        "Sigma matrix is only computed for LAMBDA < 0");
+      return Sigma.diagonal();
+    }
 
 
   private:
@@ -56,39 +62,4 @@ class M_BLR_Core:
       verbose=(config.get<int>("VERBOSE"));
     }
 };
-template
-<class BF=Function_Base&>
-class M_BLR_Train:
-  public M_BLR_Core<BF>, // MUST NOT BE VIRUAL!!
-  public M_Train
-{
-
-  public:
-    using M_BLR_Core<BF>::trained;
-    using M_BLR_Core<BF>::config;
-    using M_BLR_Core<BF>::weights;
-    using M_BLR_Core<BF>::Sigma;
-
-    M_BLR_Train(Config &c):
-      M_BLR_Core<BF>(c)
-    {}
-
-    M_BLR_Train(BF &bf, Config &c):
-      M_BLR_Core<BF>(bf, c)
-    {}
-
-    void train(phi_type &Phi, const t_type &T) {
-      if(trained) {
-        throw std::runtime_error("This object is already trained!");
-      }
-      LinearRegressor::train(config,Phi, T,weights,Sigma);
-      trained=true;
-    }
-    t_type get_weights_uncertainty() const{
-      double lambda=config.template get<double>("LAMBDA");
-      if(lambda >= 0) throw std::runtime_error(
-        "Sigma matrix is only computed for LAMBDA < 0");
-      return Sigma.diagonal();
-    }
-};
 #endif
diff --git a/m_blr_train.h b/m_blr_train.h
new file mode 100644
index 0000000..60618df
--- /dev/null
+++ b/m_blr_train.h
@@ -0,0 +1,42 @@
+#ifndef M_BLR_TRAIN_H
+#define M_BLR_TRAIN_H
+
+#include "m_core.h"
+#include "m_blr_core.h"
+#include "linear_regressor.h"
+#include "functions/function_base.h"
+#include "functions/basis_functions/bf_base.h"
+#include "../CORE/config/config.h"
+
+#include <iostream>
+
+template
+<class BF=Function_Base&>
+class M_BLR_Train:
+  public M_BLR_Core<BF>, // MUST NOT BE VIRTUAL!!
+  public M_Train
+{
+
+  public:
+    using M_BLR_Core<BF>::trained;
+    using M_BLR_Core<BF>::config;
+    using M_BLR_Core<BF>::weights;
+    using M_BLR_Core<BF>::Sigma;
+
+    M_BLR_Train(Config &c):
+      M_BLR_Core<BF>(c)
+    {}
+
+    M_BLR_Train(BF &bf, Config &c):
+      M_BLR_Core<BF>(bf, c)
+    {}
+
+    void train(phi_type &Phi, const t_type &T) {
+      if(trained) {
+        throw std::runtime_error("This object is already trained!");
+      }
+      LinearRegressor::train(config,Phi, T,weights,Sigma);
+      trained=true;
+    }
+};
+#endif
diff --git a/m_core.h b/m_core.h
index aaec72b..c2783bd 100644
--- a/m_core.h
+++ b/m_core.h
@@ -22,6 +22,7 @@ class M_Core {
       weights=w;
     }
     virtual double predict(const rvec &v)const=0;
+    virtual t_type get_weights_uncertainty()const=0;
 };
 class M_Predict {
   public:
@@ -50,6 +51,5 @@ class M_Train {
   public:
     virtual ~M_Train() {}
     virtual void train(phi_type &Phi, const t_type &T)=0;
-    virtual t_type get_weights_uncertainty()const=0;
 };
 #endif
diff --git a/m_krr_core.h b/m_krr_core.h
index 2bf5bf2..ff7b370 100644
--- a/m_krr_core.h
+++ b/m_krr_core.h
@@ -1,12 +1,10 @@
 #ifndef M_KRR_Core_H
 #define M_KRR_Core_H
 
-#include "linear_regressor.h"
 #include "../CORE/config/config.h"
 #include "functions/function_base.h"
 #include "functions/kernels/kern_base.h"
 #include "m_core.h"
-#include "ekm.h"
 
 #include <iostream>
 
@@ -43,7 +41,14 @@ class M_KRR_Core:
     }
     double predict(const rvec &v) const{
       return kernel.epredict(weights,v);
-    };
+    }
+
+    t_type get_weights_uncertainty() const{
+      double lambda=config.template get<double>("LAMBDA");
+      if(lambda >= 0) throw std::runtime_error(
+        "Sigma matrix is only computed for LAMBDA < 0");
+      return Sigma.diagonal();
+    }
 
 
   private:
@@ -57,53 +62,6 @@ class M_KRR_Core:
      verbose=(config.get<int>("VERBOSE"));
     }
 };
-template
-<class K=Function_Base&>
-class M_KRR_Train:
-  public M_KRR_Core<K>, // MUST NOT BE VIRUAL!!
-  public M_Train
-{
-
-  public:
-    EKM<K> ekm;
-    using M_KRR_Core<K>::trained;
-    using M_KRR_Core<K>::config;
-    using M_KRR_Core<K>::kernel;
-    using M_KRR_Core<K>::weights;
-    using M_KRR_Core<K>::Sigma;
-
-    M_KRR_Train(Config &c):
-      M_KRR_Core<K>(c),
-      ekm(c)
-    {}
-
-    M_KRR_Train(K &kernel, Config &c):
-      M_KRR_Core<K>(kernel, c),
-      ekm(kernel)
-    {}
-
-    void train(phi_type &Phi, const t_type &T) {
-      if(trained) {
-        throw std::runtime_error("This object is already trained!");
-      }
-      if (kernel.get_label()!="Kern_Linear") {
-        ekm.project(Phi);
-      }
-      LinearRegressor::train(config,Phi, T,weights,Sigma);
-
-      if (kernel.get_label()!="Kern_Linear") {
-        //kernalize weights
-        weights = ekm.KK.transpose()*weights;
-      }
-      trained=true;
-    }
-    t_type get_weights_uncertainty() const{
-      double lambda=config.template get<double>("LAMBDA");
-      if(lambda >= 0) throw std::runtime_error(
-        "Sigma matrix is only computed for LAMBDA < 0");
-      return Sigma.diagonal();
-    }
-};
 
 //template
 //<class K=Function_Base&>
diff --git a/m_krr_train.h b/m_krr_train.h
new file mode 100644
index 0000000..4cb7da3
--- /dev/null
+++ b/m_krr_train.h
@@ -0,0 +1,55 @@
+#ifndef M_KRR_TRAIN_H
+#define M_KRR_TRAIN_H
+
+#include "linear_regressor.h"
+#include "../CORE/config/config.h"
+#include "functions/function_base.h"
+#include "functions/kernels/kern_base.h"
+#include "m_core.h"
+#include "m_krr_core.h"
+#include "ekm.h"
+
+#include <iostream>
+
+template
+<class K=Function_Base&>
+class M_KRR_Train:
+  public M_KRR_Core<K>, // MUST NOT BE VIRTUAL!!
+  public M_Train
+{
+
+  public:
+    EKM<K> ekm;
+    using M_KRR_Core<K>::trained;
+    using M_KRR_Core<K>::config;
+    using M_KRR_Core<K>::kernel;
+    using M_KRR_Core<K>::weights;
+    using M_KRR_Core<K>::Sigma;
+
+    M_KRR_Train(Config &c):
+      M_KRR_Core<K>(c),
+      ekm(c)
+    {}
+
+    M_KRR_Train(K &kernel, Config &c):
+      M_KRR_Core<K>(kernel, c),
+      ekm(kernel)
+    {}
+
+    void train(phi_type &Phi, const t_type &T) {
+      if(trained) {
+        throw std::runtime_error("This object is already trained!");
+      }
+      if (kernel.get_label()!="Kern_Linear") {
+        ekm.project(Phi);
+      }
+      LinearRegressor::train(config,Phi, T,weights,Sigma);
+
+      if (kernel.get_label()!="Kern_Linear") {
+        // kernelize weights
+        weights = ekm.KK.transpose()*weights;
+      }
+      trained=true;
+    }
+};
+#endif
-- 
GitLab
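
Usage sketch (reviewer's illustration, not part of the patch): after the
split, training code includes the new m_blr_train.h, while prediction-only
code can include m_blr_core.h alone. The snippet below assumes a concrete
basis-function type (BF_Linear is hypothetical), a Config populated
elsewhere with the LAMBDA/VERBOSE keys read above, and the phi_type,
t_type, and rvec types from m_core.h; only train(), predict(), and
get_weights_uncertainty() come from the patch itself.

    #include "m_blr_train.h"   // new header introduced by this patch

    // BF_Linear, and the caller filling Config/Phi/T/v, are assumed.
    void example(Config &config, BF_Linear &bf,
                 phi_type &Phi, const t_type &T, const rvec &v) {
      // M_BLR_Train derives from M_BLR_Core, so one object both trains
      // and predicts.
      M_BLR_Train<BF_Linear> model(bf, config);

      model.train(Phi, T);          // throws if the object is already trained

      double y = model.predict(v);  // inherited from M_BLR_Core

      // Defined only for LAMBDA < 0; throws otherwise.
      t_type sigma = model.get_weights_uncertainty();
      (void)y; (void)sigma;
    }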