From bb249ee9f3e93cc427edac8ab2b609b1b4b3441f Mon Sep 17 00:00:00 2001
From: mkirsz <s1351949@sms.ed.ac.uk>
Date: Tue, 30 Jan 2024 11:42:39 +0000
Subject: [PATCH] Refactoring

---
 m_blr_core.h  | 49 +++++++++----------------------------------
 m_blr_train.h | 42 +++++++++++++++++++++++++++++++++++++
 m_core.h      |  2 +-
 m_krr_core.h  | 58 +++++++--------------------------------------------
 m_krr_train.h | 55 ++++++++++++++++++++++++++++++++++++++++++++++++
 5 files changed, 116 insertions(+), 90 deletions(-)
 create mode 100644 m_blr_train.h
 create mode 100644 m_krr_train.h

diff --git a/m_blr_core.h b/m_blr_core.h
index 433edbb..a53bc24 100644
--- a/m_blr_core.h
+++ b/m_blr_core.h
@@ -1,11 +1,10 @@
 #ifndef M_BLR_Core_H
 #define M_BLR_Core_H
 
-#include "linear_regressor.h"
-#include "../CORE/config/config.h"
+#include "m_core.h"
 #include "functions/function_base.h"
 #include "functions/basis_functions/bf_base.h"
-#include "m_core.h"
+#include "../CORE/config/config.h"
 
 #include <iostream>
 
@@ -42,7 +41,14 @@ class M_BLR_Core:
     }
         double predict(const rvec &v) const{
             return bf.epredict(get_weights(),v);
-        };
+        }
+
+        t_type get_weights_uncertainty() const{
+            double lambda=config.template get<double>("LAMBDA");
+            if(lambda >= 0) throw std::runtime_error(
+                    "Sigma matrix is only computed for LAMBDA < 0");
+            return Sigma.diagonal();
+        }
 
 
     private:
@@ -56,39 +62,4 @@ class M_BLR_Core:
             verbose=(config.get<int>("VERBOSE"));
         }
 };
-template
-<class BF=Function_Base&>
-class M_BLR_Train:
-    public M_BLR_Core<BF>,    // MUST NOT BE VIRUAL!!
-    public M_Train
-{
-
-    public:
-        using M_BLR_Core<BF>::trained;
-        using M_BLR_Core<BF>::config;
-        using M_BLR_Core<BF>::weights;
-        using M_BLR_Core<BF>::Sigma;
-
-        M_BLR_Train(Config &c):
-            M_BLR_Core<BF>(c)
-    {}
-
-        M_BLR_Train(BF &bf, Config &c):
-            M_BLR_Core<BF>(bf, c)
-    {}
-
-        void train(phi_type &Phi, const t_type &T) {
-            if(trained) {
-                throw std::runtime_error("This object is already trained!");
-            }
-            LinearRegressor::train(config,Phi, T,weights,Sigma);
-            trained=true;
-        }
-        t_type get_weights_uncertainty() const{
-            double lambda=config.template get<double>("LAMBDA");
-            if(lambda >= 0) throw std::runtime_error(
-                    "Sigma matrix is only computed for LAMBDA < 0");
-            return Sigma.diagonal();
-        }
-};
 #endif
diff --git a/m_blr_train.h b/m_blr_train.h
new file mode 100644
index 0000000..60618df
--- /dev/null
+++ b/m_blr_train.h
@@ -0,0 +1,42 @@
+#ifndef M_BLR_TRAIN_H
+#define M_BLR_TRAIN_H
+
+#include "m_core.h"
+#include "m_blr_core.h"
+#include "linear_regressor.h"
+#include "functions/function_base.h"
+#include "functions/basis_functions/bf_base.h"
+#include "../CORE/config/config.h"
+
+#include <iostream>
+
+template
+<class BF=Function_Base&>
+class M_BLR_Train:
+    public M_BLR_Core<BF>,    // MUST NOT BE VIRUAL!!
+    public M_Train
+{
+
+    public:
+        using M_BLR_Core<BF>::trained;
+        using M_BLR_Core<BF>::config;
+        using M_BLR_Core<BF>::weights;
+        using M_BLR_Core<BF>::Sigma;
+
+        M_BLR_Train(Config &c):
+            M_BLR_Core<BF>(c)
+    {}
+
+        M_BLR_Train(BF &bf, Config &c):
+            M_BLR_Core<BF>(bf, c)
+    {}
+
+        void train(phi_type &Phi, const t_type &T) {
+            if(trained) {
+                throw std::runtime_error("This object is already trained!");
+            }
+            LinearRegressor::train(config,Phi, T,weights,Sigma);
+            trained=true;
+        }
+};
+#endif
diff --git a/m_core.h b/m_core.h
index aaec72b..c2783bd 100644
--- a/m_core.h
+++ b/m_core.h
@@ -22,6 +22,7 @@ class M_Core {
             weights=w;
         }
         virtual double predict(const rvec &v)const=0;
+        virtual t_type get_weights_uncertainty()const=0;
 };
 class M_Predict {
     public:
@@ -50,6 +51,5 @@ class M_Train {
     public:
         virtual ~M_Train() {}
         virtual void train(phi_type &Phi, const t_type &T)=0;
-        virtual t_type get_weights_uncertainty()const=0;
 };
 #endif
diff --git a/m_krr_core.h b/m_krr_core.h
index 2bf5bf2..ff7b370 100644
--- a/m_krr_core.h
+++ b/m_krr_core.h
@@ -1,12 +1,10 @@
 #ifndef M_KRR_Core_H
 #define M_KRR_Core_H
 
-#include "linear_regressor.h"
 #include "../CORE/config/config.h"
 #include "functions/function_base.h"
 #include "functions/kernels/kern_base.h"
 #include "m_core.h"
-#include "ekm.h"
 
 #include <iostream>
 
@@ -43,7 +41,14 @@ class M_KRR_Core:
     }
         double predict(const rvec &v) const{
             return kernel.epredict(weights,v);
-        };
+        }
+
+        t_type get_weights_uncertainty() const{
+            double lambda=config.template get<double>("LAMBDA");
+            if(lambda >= 0) throw std::runtime_error(
+                    "Sigma matrix is only computed for LAMBDA < 0");
+            return Sigma.diagonal();
+        }
 
 
     private:
@@ -57,53 +62,6 @@ class M_KRR_Core:
             verbose=(config.get<int>("VERBOSE"));
         }
 };
-template
-<class K=Function_Base&>
-class M_KRR_Train:
-    public M_KRR_Core<K>,    // MUST NOT BE VIRUAL!!
-    public M_Train
-{
-
-    public:
-        EKM<K> ekm;
-        using M_KRR_Core<K>::trained;
-        using M_KRR_Core<K>::config;
-        using M_KRR_Core<K>::kernel;
-        using M_KRR_Core<K>::weights;
-        using M_KRR_Core<K>::Sigma;
-
-        M_KRR_Train(Config &c):
-            M_KRR_Core<K>(c),
-            ekm(c)
-    {}
-
-        M_KRR_Train(K &kernel, Config &c):
-            M_KRR_Core<K>(kernel, c),
-            ekm(kernel)
-    {}
-
-        void train(phi_type &Phi, const t_type &T) {
-            if(trained) {
-                throw std::runtime_error("This object is already trained!");
-            }
-            if (kernel.get_label()!="Kern_Linear") {
-                ekm.project(Phi);
-            }
-            LinearRegressor::train(config,Phi, T,weights,Sigma);
-
-            if (kernel.get_label()!="Kern_Linear") {
-                //kernalize weights
-                weights = ekm.KK.transpose()*weights;
-            }
-            trained=true;
-        }
-        t_type get_weights_uncertainty() const{
-            double lambda=config.template get<double>("LAMBDA");
-            if(lambda >= 0) throw std::runtime_error(
-                    "Sigma matrix is only computed for LAMBDA < 0");
-            return Sigma.diagonal();
-        }
-};
 
 //template
 //<class K=Function_Base&>
diff --git a/m_krr_train.h b/m_krr_train.h
new file mode 100644
index 0000000..4cb7da3
--- /dev/null
+++ b/m_krr_train.h
@@ -0,0 +1,55 @@
+#ifndef M_KRR_TRAIN_H
+#define M_KRR_TRAIN_H
+
+#include "linear_regressor.h"
+#include "../CORE/config/config.h"
+#include "functions/function_base.h"
+#include "functions/kernels/kern_base.h"
+#include "m_core.h"
+#include "m_krr_core.h"
+#include "ekm.h"
+
+#include <iostream>
+
+template
+<class K=Function_Base&>
+class M_KRR_Train:
+    public M_KRR_Core<K>,    // MUST NOT BE VIRUAL!!
+    public M_Train
+{
+
+    public:
+        EKM<K> ekm;
+        using M_KRR_Core<K>::trained;
+        using M_KRR_Core<K>::config;
+        using M_KRR_Core<K>::kernel;
+        using M_KRR_Core<K>::weights;
+        using M_KRR_Core<K>::Sigma;
+
+        M_KRR_Train(Config &c):
+            M_KRR_Core<K>(c),
+            ekm(c)
+    {}
+
+        M_KRR_Train(K &kernel, Config &c):
+            M_KRR_Core<K>(kernel, c),
+            ekm(kernel)
+    {}
+
+        void train(phi_type &Phi, const t_type &T) {
+            if(trained) {
+                throw std::runtime_error("This object is already trained!");
+            }
+            if (kernel.get_label()!="Kern_Linear") {
+                ekm.project(Phi);
+            }
+            LinearRegressor::train(config,Phi, T,weights,Sigma);
+
+            if (kernel.get_label()!="Kern_Linear") {
+                //kernalize weights
+                weights = ekm.KK.transpose()*weights;
+            }
+            trained=true;
+        }
+};
+#endif
-- 
GitLab