From aabf70395ee4f9d06a90e768328a6436758aad12 Mon Sep 17 00:00:00 2001
From: mkirsz <s1351949@sms.ed.ac.uk>
Date: Tue, 30 Jan 2024 10:50:05 +0000
Subject: [PATCH] Refactoring

---
 m_blr_core.h |  59 +++++++++++++++++------------
 m_core.h     |  16 +-------
 m_krr_core.h | 104 ++++++++++++++++++++++++++++++++++++++-------------
 3 files changed, 115 insertions(+), 64 deletions(-)

diff --git a/m_blr_core.h b/m_blr_core.h
index bdd4980..433edbb 100644
--- a/m_blr_core.h
+++ b/m_blr_core.h
@@ -12,19 +12,14 @@
 template
 <class BF=Function_Base&>
 class M_BLR_Core:
-    public virtual M_Core,
-    public virtual M_Predict,
-    public virtual M_Train {
+    public virtual M_Core
+{
     public:
-        //using M_Core<BF>::weights;
-        //using M_Core<BF>::trained;
-        //using M_Core<BF>::config;
         mat Sigma;
         Config &config;
         BF bf;
         M_BLR_Core(Config &c):
             config(c),
-            //M_Core<BF>(c),
             bf(c)
     {
         static_assert(std::is_base_of<BF_Base, BF>::value,
@@ -38,21 +33,50 @@ class M_BLR_Core:
 
         M_BLR_Core(BF &bf, Config &c):
             config(c),
-            //M_Core<BF>(bf,c),
             bf(bf)
     {
         if (dynamic_cast<BF_Base*>(&bf) == nullptr)
             throw std::invalid_argument("Provided object is not of BF_Base type");
 
-        //static_assert(std::is_same<Function_Base&, BF>::value,
-        //        "This constructor requires BF=Function_Base&\n");
-
         init();
     }
         double predict(const rvec &v) const{
             return bf.epredict(get_weights(),v);
         };
 
+
+    private:
+        void init() {
+            if (config.exist("WEIGHTS")) {
+                weights.resize(config.size("WEIGHTS"));
+                config.get<t_type>("WEIGHTS",weights);
+                trained=true;
+            }
+
+            verbose=(config.get<int>("VERBOSE"));
+        }
+};
+template
+<class BF=Function_Base&>
+class M_BLR_Train:
+    public M_BLR_Core<BF>,    // MUST NOT BE VIRUAL!!
+    public M_Train
+{
+
+    public:
+        using M_BLR_Core<BF>::trained;
+        using M_BLR_Core<BF>::config;
+        using M_BLR_Core<BF>::weights;
+        using M_BLR_Core<BF>::Sigma;
+
+        M_BLR_Train(Config &c):
+            M_BLR_Core<BF>(c)
+    {}
+
+        M_BLR_Train(BF &bf, Config &c):
+            M_BLR_Core<BF>(bf, c)
+    {}
+
         void train(phi_type &Phi, const t_type &T) {
             if(trained) {
                 throw std::runtime_error("This object is already trained!");
@@ -61,21 +85,10 @@ class M_BLR_Core:
             trained=true;
         }
         t_type get_weights_uncertainty() const{
-            double lambda=config.get<double>("LAMBDA");
+            double lambda=config.template get<double>("LAMBDA");
             if(lambda >= 0) throw std::runtime_error(
                     "Sigma matrix is only computed for LAMBDA < 0");
             return Sigma.diagonal();
         }
-
-    private:
-        void init() {
-            if (config.exist("WEIGHTS")) {
-                weights.resize(config.size("WEIGHTS"));
-                config.get<t_type>("WEIGHTS",weights);
-                trained=true;
-            }
-
-            verbose=(config.get<int>("VERBOSE"));
-        }
 };
 #endif
diff --git a/m_core.h b/m_core.h
index cced00f..56f827b 100644
--- a/m_core.h
+++ b/m_core.h
@@ -3,28 +3,14 @@
 
 #include "../CORE/typedefs.h"
 
-//template
-//<class F=Function_Base&>
 class M_Core {
     public:
-//        F f;
-//        Config &config;
-//
-//        M_Core(Config &c):
-//            config(c),
-//            f(c)
-//    {}
-//        M_Core(F &f, Config &c):
-//            config(c),
-//            f(f)
-//    {}
         int verbose;
         bool trained=false;
         t_type weights;
 
         virtual ~M_Core() {}
 
-        virtual double predict(const rvec &v)const=0;
         virtual t_type get_weights_uncertainty()const=0;
 
         bool is_trained() const {
@@ -36,6 +22,7 @@ class M_Core {
         void set_weights(const t_type w) {
             weights=w;
         }
+        virtual double predict(const rvec &v)const=0;
 };
 class M_Predict {
     public:
@@ -63,7 +50,6 @@ class M_Predict {
 class M_Train {
     public:
         virtual ~M_Train() {}
-
         virtual void train(phi_type &Phi, const t_type &T)=0;
 };
 #endif
diff --git a/m_krr_core.h b/m_krr_core.h
index 5ebf7db..2bf5bf2 100644
--- a/m_krr_core.h
+++ b/m_krr_core.h
@@ -13,22 +13,16 @@
 template
 <class K=Function_Base&>
 class M_KRR_Core:
-    public virtual M_Core,
-    public virtual M_Predict,
-    public virtual M_Train {
+    public virtual M_Core
+{
     public:
-        //using M_Core<K>::weights;
-        //using M_Core<K>::trained;
-        //using M_Core<K>::config;
         mat Sigma;
         Config &config;
         K kernel;
-        EKM<K> ekm;
+
         M_KRR_Core(Config &c):
             config(c),
-            //M_Core<K>(c),
-            kernel(c),
-            ekm(c)
+            kernel(c)
     {
         static_assert(std::is_base_of<Kern_Base, K>::value,
                 "\nThe provided Kernel K is not of a Kernel type.\n\
@@ -40,9 +34,7 @@ class M_KRR_Core:
 
         M_KRR_Core(K &kernel, Config &c):
             config(c),
-            //M_Core<K>(kernel, c),
-            kernel(kernel),
-            ekm(kernel)
+            kernel(kernel)
     {
         if (dynamic_cast<Kern_Base*>(&kernel) == nullptr)
             throw std::invalid_argument("Provided object is not of Kern_Base type");
@@ -50,9 +42,46 @@ class M_KRR_Core:
         init();
     }
         double predict(const rvec &v) const{
-            return kernel.epredict(get_weights(),v);
+            return kernel.epredict(weights,v);
         };
 
+
+    private:
+        void init() {
+            if (config.exist("WEIGHTS")) {
+                weights.resize(config.size("WEIGHTS"));
+                config.get<t_type>("WEIGHTS",weights);
+                trained=true;
+            }
+
+            verbose=(config.get<int>("VERBOSE"));
+        }
+};
+template
+<class K=Function_Base&>
+class M_KRR_Train:
+    public M_KRR_Core<K>,    // MUST NOT BE VIRUAL!!
+    public M_Train
+{
+
+    public:
+        EKM<K> ekm;
+        using M_KRR_Core<K>::trained;
+        using M_KRR_Core<K>::config;
+        using M_KRR_Core<K>::kernel;
+        using M_KRR_Core<K>::weights;
+        using M_KRR_Core<K>::Sigma;
+
+        M_KRR_Train(Config &c):
+            M_KRR_Core<K>(c),
+            ekm(c)
+    {}
+
+        M_KRR_Train(K &kernel, Config &c):
+            M_KRR_Core<K>(kernel, c),
+            ekm(kernel)
+    {}
+
         void train(phi_type &Phi, const t_type &T) {
             if(trained) {
                 throw std::runtime_error("This object is already trained!");
@@ -69,21 +98,44 @@ class M_KRR_Core:
             trained=true;
         }
         t_type get_weights_uncertainty() const{
-            double lambda=config.get<double>("LAMBDA");
+            double lambda=config.template get<double>("LAMBDA");
             if(lambda >= 0) throw std::runtime_error(
                     "Sigma matrix is only computed for LAMBDA < 0");
             return Sigma.diagonal();
         }
-
-    private:
-        void init() {
-            if (config.exist("WEIGHTS")) {
-                weights.resize(config.size("WEIGHTS"));
-                config.get<t_type>("WEIGHTS",weights);
-                trained=true;
-            }
-
-            verbose=(config.get<int>("VERBOSE"));
-        }
 };
+
+//template
+//<class K=Function_Base&>
+//class M_KRR_Predict:
+//    public M_KRR_Core<K>,    // MUST NOT BE VIRUAL!!
+//    public virtual M_Predict
+//{
+//
+//    public:
+//        using M_KRR_Core<K>::config;
+//        using M_KRR_Core<K>::kernel;
+//        using M_KRR_Core<K>::weights;
+//
+//        M_KRR_Predict(Config &c):
+//            M_KRR_Core<K>(c)
+//    {}
+//
+//        M_KRR_Predict(K &kernel, Config &c):
+//            M_KRR_Core<K>(kernel, c)
+//    {}
+//
+//        double epredict(const aed_rctype &aed) {
+//            return kernel.epredict(weights,aed);
+//        };
+//
+//        double fpredict(const fd_type &fdij, const aed_rctype &aedi, const size_t k) {
+//            return kernel.fpredict(weights,fdij,aedi,k);
+//        }
+//
+//        force_type fpredict(const fd_type &fdij, const aed_rctype &aedi) {
+//            return kernel.fpredict(weights,fdij,aedi).array();
+//        }
+//
+//};
 #endif
-- 
GitLab