Commit bb249ee9 authored by mkirsz

Refactoring

parent 52095af5
#ifndef M_BLR_Core_H
#define M_BLR_Core_H
#include "linear_regressor.h"
#include "../CORE/config/config.h"
#include "m_core.h"
#include "functions/function_base.h"
#include "functions/basis_functions/bf_base.h"
#include <iostream>
@@ -42,7 +41,14 @@ class M_BLR_Core:
}
double predict(const rvec &v) const{
return bf.epredict(get_weights(),v);
}
t_type get_weights_uncertainty() const{
double lambda=config.template get<double>("LAMBDA");
if(lambda >= 0) throw std::runtime_error(
"Sigma matrix is only computed for LAMBDA < 0");
return Sigma.diagonal();
}
private:
@@ -56,39 +62,4 @@ class M_BLR_Core:
verbose=(config.get<int>("VERBOSE"));
}
};
template
<class BF=Function_Base&>
class M_BLR_Train:
public M_BLR_Core<BF>, // MUST NOT BE VIRTUAL!!
public M_Train
{
public:
using M_BLR_Core<BF>::trained;
using M_BLR_Core<BF>::config;
using M_BLR_Core<BF>::weights;
using M_BLR_Core<BF>::Sigma;
M_BLR_Train(Config &c):
M_BLR_Core<BF>(c)
{}
M_BLR_Train(BF &bf, Config &c):
M_BLR_Core<BF>(bf, c)
{}
void train(phi_type &Phi, const t_type &T) {
if(trained) {
throw std::runtime_error("This object is already trained!");
}
LinearRegressor::train(config,Phi, T,weights,Sigma);
trained=true;
}
t_type get_weights_uncertainty() const{
double lambda=config.template get<double>("LAMBDA");
if(lambda >= 0) throw std::runtime_error(
"Sigma matrix is only computed for LAMBDA < 0");
return Sigma.diagonal();
}
};
#endif
#ifndef M_BLR_TRAIN_H
#define M_BLR_TRAIN_H
#include "m_core.h"
#include "m_blr_core.h"
#include "linear_regressor.h"
#include "functions/function_base.h"
#include "functions/basis_functions/bf_base.h"
#include "../CORE/config/config.h"
#include <iostream>
template
<class BF=Function_Base&>
class M_BLR_Train:
public M_BLR_Core<BF>, // MUST NOT BE VIRTUAL!!
public M_Train
{
public:
using M_BLR_Core<BF>::trained;
using M_BLR_Core<BF>::config;
using M_BLR_Core<BF>::weights;
using M_BLR_Core<BF>::Sigma;
M_BLR_Train(Config &c):
M_BLR_Core<BF>(c)
{}
M_BLR_Train(BF &bf, Config &c):
M_BLR_Core<BF>(bf, c)
{}
void train(phi_type &Phi, const t_type &T) {
if(trained) {
throw std::runtime_error("This object is already trained!");
}
LinearRegressor::train(config,Phi, T,weights,Sigma);
trained=true;
}
};
#endif
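For reference, a minimal usage sketch of the refactored trainer. The basis-function type BF_Linear and the free function fit_blr are illustrative assumptions, not part of this commit; Config, phi_type, t_type and rvec are the types already used above.
// Hypothetical usage sketch; BF_Linear and fit_blr are assumed names.
#include "m_blr_train.h"
void fit_blr(Config &config, phi_type &Phi, const t_type &T, const rvec &v) {
    M_BLR_Train<BF_Linear> model(config);
    model.train(Phi, T);            // throws std::runtime_error if called twice
    double y = model.predict(v);    // forwards to bf.epredict(get_weights(), v)
    // Requires LAMBDA < 0; otherwise get_weights_uncertainty() throws.
    t_type var = model.get_weights_uncertainty();
    (void)y; (void)var;
}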
@@ -22,6 +22,7 @@ class M_Core {
weights=w;
}
virtual double predict(const rvec &v)const=0;
virtual t_type get_weights_uncertainty()const=0;
};
class M_Predict {
public:
@@ -50,6 +51,5 @@ class M_Train {
public:
virtual ~M_Train() {}
virtual void train(phi_type &Phi, const t_type &T)=0;
virtual t_type get_weights_uncertainty()const=0;
};
#endif
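The two hunks above move the pure virtual get_weights_uncertainty() from M_Train to M_Core, so uncertainty can be queried through the prediction-side interface. A sketch of consuming the split interfaces after this change (the function name is illustrative):
// Sketch: after this commit, uncertainty lives on the core interface.
void train_and_report(M_Train &trainer, M_Core &core,
                      phi_type &Phi, const t_type &T) {
    trainer.train(Phi, T);                        // M_Train: training only
    t_type var = core.get_weights_uncertainty();  // M_Core: prediction-side API
    (void)var;
}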
#ifndef M_KRR_Core_H
#define M_KRR_Core_H
#include "linear_regressor.h"
#include "../CORE/config/config.h"
#include "functions/function_base.h"
#include "functions/kernels/kern_base.h"
#include "m_core.h"
#include "ekm.h"
#include <iostream>
@@ -43,7 +41,14 @@ class M_KRR_Core:
}
double predict(const rvec &v) const{
return kernel.epredict(weights,v);
}
t_type get_weights_uncertainty() const{
double lambda=config.template get<double>("LAMBDA");
if(lambda >= 0) throw std::runtime_error(
"Sigma matrix is only computed for LAMBDA < 0");
return Sigma.diagonal();
}
private:
@@ -57,53 +62,6 @@ class M_KRR_Core:
verbose=(config.get<int>("VERBOSE"));
}
};
template
<class K=Function_Base&>
class M_KRR_Train:
public M_KRR_Core<K>, // MUST NOT BE VIRTUAL!!
public M_Train
{
public:
EKM<K> ekm;
using M_KRR_Core<K>::trained;
using M_KRR_Core<K>::config;
using M_KRR_Core<K>::kernel;
using M_KRR_Core<K>::weights;
using M_KRR_Core<K>::Sigma;
M_KRR_Train(Config &c):
M_KRR_Core<K>(c),
ekm(c)
{}
M_KRR_Train(K &kernel, Config &c):
M_KRR_Core<K>(kernel, c),
ekm(kernel)
{}
void train(phi_type &Phi, const t_type &T) {
if(trained) {
throw std::runtime_error("This object is already trained!");
}
if (kernel.get_label()!="Kern_Linear") {
ekm.project(Phi);
}
LinearRegressor::train(config,Phi, T,weights,Sigma);
if (kernel.get_label()!="Kern_Linear") {
// kernelize weights
weights = ekm.KK.transpose()*weights;
}
trained=true;
}
t_type get_weights_uncertainty() const{
double lambda=config.template get<double>("LAMBDA");
if(lambda >= 0) throw std::runtime_error(
"Sigma matrix is only computed for LAMBDA < 0");
return Sigma.diagonal();
}
};
//template
//<class K=Function_Base&>
......
#ifndef M_KRR_TRAIN_H
#define M_KRR_TRAIN_H
#include "linear_regressor.h"
#include "../CORE/config/config.h"
#include "functions/function_base.h"
#include "functions/kernels/kern_base.h"
#include "m_core.h"
#include "m_krr_core.h"
#include "ekm.h"
#include <iostream>
template
<class K=Function_Base&>
class M_KRR_Train:
public M_KRR_Core<K>, // MUST NOT BE VIRTUAL!!
public M_Train
{
public:
EKM<K> ekm;
using M_KRR_Core<K>::trained;
using M_KRR_Core<K>::config;
using M_KRR_Core<K>::kernel;
using M_KRR_Core<K>::weights;
using M_KRR_Core<K>::Sigma;
M_KRR_Train(Config &c):
M_KRR_Core<K>(c),
ekm(c)
{}
M_KRR_Train(K &kernel, Config &c):
M_KRR_Core<K>(kernel, c),
ekm(kernel)
{}
void train(phi_type &Phi, const t_type &T) {
if(trained) {
throw std::runtime_error("This object is already trained!");
}
if (kernel.get_label()!="Kern_Linear") {
ekm.project(Phi);
}
LinearRegressor::train(config,Phi, T,weights,Sigma);
if (kernel.get_label()!="Kern_Linear") {
// kernelize weights
weights = ekm.KK.transpose()*weights;
}
trained=true;
}
};
#endif
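A corresponding sketch for the kernel trainer; Kern_RBF is an assumed non-linear kernel type, everything else follows the code above.
// Hypothetical usage sketch; Kern_RBF and fit_krr are assumed names.
#include "m_krr_train.h"
void fit_krr(Config &config, phi_type &Phi, const t_type &T, const rvec &v) {
    Kern_RBF kernel(config);
    M_KRR_Train<Kern_RBF> model(kernel, config);
    // For any kernel other than Kern_Linear, train() first projects Phi
    // through the empirical kernel map (ekm.project) and then maps the
    // solved weights back with ekm.KK.transpose().
    model.train(Phi, T);
    double y = model.predict(v);   // kernel.epredict(weights, v)
    (void)y;
}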