From 679c69f67308418bc5de2c7588cc2ff690ed8d53 Mon Sep 17 00:00:00 2001
From: mkirsz <s1351949@sms.ed.ac.uk>
Date: Mon, 29 Jan 2024 22:24:04 +0000
Subject: [PATCH] More refactoring

---
 m_base.h     | 36 ++++++++++++++++++------------------
 m_blr_core.h | 14 +++++++++++---
 m_core.h     | 43 ++++++++++++++++++++++++++++++++++++++++++-
 m_krr_core.h | 14 +++++++++++---
 4 files changed, 82 insertions(+), 25 deletions(-)

diff --git a/m_base.h b/m_base.h
index 4ecb44b..361a1d1 100644
--- a/m_base.h
+++ b/m_base.h
@@ -10,24 +10,24 @@ class M_Base {
     public:
         //Normaliser norm;    // TODO?
         virtual ~M_Base() {};
-
-        /** \brief Predict local energy of an atom or bond energy.
-         *
-         *  The result depends on how aed is computed.
-         *
-         *  If it is computed between a pair of atoms than the result is a bond energy.
-         *
-         *  If aed contains sum over all nearest neighbours than the result is
-         *  a local atomic energy \f$ E_i \f$.
-         * */
-        virtual double epredict(const aed_rctype &aed)=0;
-
-        /** \brief Predict force between a pair of atoms in a k-direction. */
-        virtual double fpredict(const fd_type &fdij, const aed_rctype &aedi, size_t k)=0;
-
-        /** \brief Predict force between a pair of atoms. */
-        virtual force_type fpredict(const fd_type &fdij,
-                const aed_rctype &aedi)=0;
+//
+//        /** \brief Predict local energy of an atom or bond energy.
+//         *
+//         *  The result depends on how aed is computed.
+//         *
+//         *  If it is computed between a pair of atoms than the result is a bond energy.
+//         *
+//         *  If aed contains sum over all nearest neighbours than the result is
+//         *  a local atomic energy \f$ E_i \f$.
+//         * */
+//        virtual double epredict(const aed_rctype &aed)=0;
+//
+//        /** \brief Predict force between a pair of atoms in a k-direction. */
+//        virtual double fpredict(const fd_type &fdij, const aed_rctype &aedi, size_t k)=0;
+//
+//        /** \brief Predict force between a pair of atoms. */
+//        virtual force_type fpredict(const fd_type &fdij,
+//                const aed_rctype &aedi)=0;
 
 
 };
diff --git a/m_blr_core.h b/m_blr_core.h
index dea3ea2..bdd4980 100644
--- a/m_blr_core.h
+++ b/m_blr_core.h
@@ -11,13 +11,20 @@
 
 template
 <class BF=Function_Base&>
-class M_BLR_Core: public virtual M_Core {
+class M_BLR_Core:
+    public virtual M_Core,
+    public virtual M_Predict,
+    public virtual M_Train {
     public:
-        Config &config;
+        //using M_Core<BF>::weights;
+        //using M_Core<BF>::trained;
+        //using M_Core<BF>::config;
         mat Sigma;
+        Config &config;
         BF bf;
         M_BLR_Core(Config &c):
             config(c),
+            //M_Core<BF>(c),
             bf(c)
     {
         static_assert(std::is_base_of<BF_Base, BF>::value,
@@ -31,6 +38,7 @@ class M_BLR_Core: public virtual M_Core {
 
         M_BLR_Core(BF &bf, Config &c):
             config(c),
+            //M_Core<BF>(bf,c),
             bf(bf)
     {
         if (dynamic_cast<BF_Base*>(&bf) == nullptr)
@@ -67,7 +75,7 @@ class M_BLR_Core: public virtual M_Core {
                 trained=true;
             }
 
-            M_Core::verbose=(config.get<int>("VERBOSE"));
+            verbose=(config.get<int>("VERBOSE"));
         }
 };
 #endif
diff --git a/m_core.h b/m_core.h
index 1caa4b0..cced00f 100644
--- a/m_core.h
+++ b/m_core.h
@@ -3,8 +3,21 @@
 
 #include "../CORE/typedefs.h"
 
+//template
+//<class F=Function_Base&>
 class M_Core {
     public:
+//        F f;
+//        Config &config;
+//
+//        M_Core(Config &c):
+//            config(c),
+//            f(c)
+//    {}
+//        M_Core(F &f, Config &c):
+//            config(c),
+//            f(f)
+//    {}
         int verbose;
         bool trained=false;
         t_type weights;
@@ -12,7 +25,6 @@ class M_Core {
         virtual ~M_Core() {}
 
         virtual double predict(const rvec &v)const=0;
-        virtual void train(phi_type &Phi, const t_type &T)=0;
         virtual t_type get_weights_uncertainty()const=0;
 
         bool is_trained() const {
@@ -25,4 +37,33 @@ class M_Core {
             weights=w;
         }
 };
+class M_Predict {
+    public:
+
+        virtual ~M_Predict() {}
+
+        /** \brief Predict local energy of an atom or bond energy.
+         *
+         *  The result depends on how aed is computed.
+         *
+         *  If it is computed between a pair of atoms than the result is a bond energy.
+         *
+         *  If aed contains sum over all nearest neighbours than the result is
+         *  a local atomic energy \f$ E_i \f$.
+         * */
+        virtual double epredict(const aed_rctype &aed)=0;
+
+        /** \brief Predict force between a pair of atoms in a k-direction. */
+        virtual double fpredict(const fd_type &fdij, const aed_rctype &aedi, size_t k)=0;
+
+        /** \brief Predict force between a pair of atoms. */
+        virtual force_type fpredict(const fd_type &fdij,
+                const aed_rctype &aedi)=0;
+};
+class M_Train {
+    public:
+        virtual ~M_Train() {}
+
+        virtual void train(phi_type &Phi, const t_type &T)=0;
+};
 #endif
diff --git a/m_krr_core.h b/m_krr_core.h
index 3499aad..5ebf7db 100644
--- a/m_krr_core.h
+++ b/m_krr_core.h
@@ -12,14 +12,21 @@
 
 template
 <class K=Function_Base&>
-class M_KRR_Core: public virtual M_Core {
+class M_KRR_Core:
+    public virtual M_Core,
+    public virtual M_Predict,
+    public virtual M_Train {
     public:
-        Config &config;
+        //using M_Core<K>::weights;
+        //using M_Core<K>::trained;
+        //using M_Core<K>::config;
         mat Sigma;
+        Config &config;
         K kernel;
         EKM<K> ekm;
         M_KRR_Core(Config &c):
             config(c),
+            //M_Core<K>(c),
             kernel(c),
             ekm(c)
     {
@@ -33,6 +40,7 @@ class M_KRR_Core: public virtual M_Core {
 
         M_KRR_Core(K &kernel, Config &c):
             config(c),
+            //M_Core<K>(kernel, c),
             kernel(kernel),
             ekm(kernel)
     {
@@ -75,7 +83,7 @@ class M_KRR_Core: public virtual M_Core {
                 trained=true;
             }
 
-            M_Core::verbose=(config.get<int>("VERBOSE"));
+            verbose=(config.get<int>("VERBOSE"));
         }
 };
 #endif
-- 
GitLab