From 1ac0e0300ced7cc759cd917b0d357dad0c4573c0 Mon Sep 17 00:00:00 2001
From: mkirsz <s1351949@sms.ed.ac.uk>
Date: Sat, 14 Dec 2024 02:18:24 +0000
Subject: [PATCH 01/15] Major

---
 include/tadah/mlip/descriptors_calc.hpp | 8 ++++----
 1 file changed, 4 insertions(+), 4 deletions(-)

diff --git a/include/tadah/mlip/descriptors_calc.hpp b/include/tadah/mlip/descriptors_calc.hpp
index 7cf86c1..380401a 100644
--- a/include/tadah/mlip/descriptors_calc.hpp
+++ b/include/tadah/mlip/descriptors_calc.hpp
@@ -27,10 +27,10 @@ DescriptorsCalc<D2,D3,DM,C2,C3,CM>::DescriptorsCalc(Config &c, T1 &t1, T2 &t2, T
     if (config.get<bool>("BIAS"))
       bias++;
     if (config.get<bool>("INIT2B")) {
-      d2.fidx=bias;
+      d2.set_fidx(bias);
     }
     if (config.get<bool>("INITMB")) {
-      dm.fidx=bias+d2.size();
+      dm.set_fidx(bias+d2.size());
     }
   }
 }
@@ -112,7 +112,7 @@ void DescriptorsCalc<D2,D3,DM,C2,C3,CM>::common_constructor() {
     if (!config.exist("TYPE2B"))
       config.add("TYPE2B",d2.label());
     dsize+=d2.size();
-    d2.fidx=bias;
+    d2.set_fidx(bias);
   }
   else {
     config.add("SIZE2B",0);
@@ -133,7 +133,7 @@ void DescriptorsCalc<D2,D3,DM,C2,C3,CM>::common_constructor() {
     if (!config.exist("TYPEMB"))
       config.add("TYPEMB",dm.label());
     dsize+=dm.size();
-    dm.fidx=bias+d2.size();
+    dm.set_fidx(bias+d2.size());
   }
   else {
     config.add("SIZEMB",0);
-- 
GitLab


From 874c2dfb5fe9f1e07c302a3c2345c98d3a7161f9 Mon Sep 17 00:00:00 2001
From: mkirsz <s1351949@sms.ed.ac.uk>
Date: Sun, 15 Dec 2024 01:32:50 +0000
Subject: [PATCH 02/15] Update

---
 include/tadah/mlip/descriptors_calc.h   |  3 --
 include/tadah/mlip/descriptors_calc.hpp | 47 +++++++------------------
 2 files changed, 12 insertions(+), 38 deletions(-)

diff --git a/include/tadah/mlip/descriptors_calc.h b/include/tadah/mlip/descriptors_calc.h
index daaf78c..14d1c4b 100644
--- a/include/tadah/mlip/descriptors_calc.h
+++ b/include/tadah/mlip/descriptors_calc.h
@@ -105,7 +105,6 @@ class DescriptorsCalc: public DC_Base {
         /** Calculate all descriptors in a StructureDB st_db */
         StDescriptorsDB calc(const StructureDB &st_db);
 
-
         /** Calculate density vector for the structure */
         void calc_rho(const Structure &st, StDescriptors &std);
 
@@ -122,8 +121,6 @@ class DescriptorsCalc: public DC_Base {
         void calc(const Structure &st, StDescriptors &std);
         void calc_dimer(const Structure &st, StDescriptors &std);
         void common_constructor();
-        double weights[119];  // ignore zero index; w[1]->H
-
 };
 
 #include "descriptors_calc.hpp"
diff --git a/include/tadah/mlip/descriptors_calc.hpp b/include/tadah/mlip/descriptors_calc.hpp
index 380401a..0d3813f 100644
--- a/include/tadah/mlip/descriptors_calc.hpp
+++ b/include/tadah/mlip/descriptors_calc.hpp
@@ -14,12 +14,6 @@ DescriptorsCalc<D2,D3,DM,C2,C3,CM>::DescriptorsCalc(Config &c, T1 &t1, T2 &t2, T
   d3(t2),
   dm(t3)
 {
-  for (size_t i=0; i<c.size("ATOMS"); ++i) {
-    std::string symbol = c.get<std::string>("ATOMS",i);
-    double wi = c.get<double>("WATOMS",i);
-    int Z = PeriodicTable::find_by_symbol(symbol).Z;
-    weights[Z]=wi;
-  }
   if (!config.exist("DSIZE"))
     common_constructor();
   else {
@@ -75,12 +69,6 @@ DescriptorsCalc<D2,D3,DM,C2,C3,CM>::DescriptorsCalc(Config &c, T1 &d2, T2 &d3, T
   d3(d3),
   dm(dm)
 {
-  for (size_t i=0; i<c.size("ATOMS"); ++i) {
-    std::string symbol = c.get<std::string>("ATOMS",i);
-    double wi = c.get<double>("WATOMS",i);
-    int Z = PeriodicTable::find_by_symbol(symbol).Z;
-    weights[Z]=wi;
-  }
   if (c.get<bool>("INIT2B")) {
     if (!config.exist("RCTYPE2B"))
       config.add("RCTYPE2B",c2.label());
@@ -170,10 +158,8 @@ void DescriptorsCalc<D2,D3,DM,C2,C3,CM>::calc_rho(const Structure &st, StDescrip
       double rij_sq = delij[0]*delij[0] + delij[1]*delij[1] + delij[2]*delij[2];
       if (rij_sq > rcut_mb_sq) continue;
       int Zj = st.near_neigh_atoms[i][jj].Z;
-      double wj = weights[Zj];
       double rij = sqrt(rij_sq);
-      double fc_ij = wj*cm.calc(rij);
-      dm.calc_rho(rij,rij_sq,fc_ij,delij,st_d.get_rho(i));
+      dm.calc_rho(Zj,rij,rij_sq,delij,st_d.get_rho(i));
     }
   }
 }
@@ -204,8 +190,8 @@ void DescriptorsCalc<D2,D3,DM,C2,C3,CM>::calc(const Structure &st, StDescriptors
   double rcut_2b_sq = 0.0;
   double rcut_mb_sq = 0.0;
 
-  if (init2b) rcut_2b_sq = pow(config.get<double>("RCUT2B"),2);
-  if (initmb) rcut_mb_sq = pow(config.get<double>("RCUTMB"),2);
+  if (init2b) rcut_2b_sq = pow(d2.get_rcut(),2);
+  if (initmb) rcut_mb_sq = pow(dm.get_rcut(),2);
 
   // zero all aeds and set bias
   for (size_t i=0; i<st.natoms(); ++i) {
@@ -244,7 +230,6 @@ void DescriptorsCalc<D2,D3,DM,C2,C3,CM>::calc(const Structure &st, StDescriptors
 
       if (rij_sq > rcut_max_sq) continue;
       int Zj = st.near_neigh_atoms[i][jj].Z;
-      double wj = weights[Zj];
       double rij = sqrt(rij_sq);
       double rij_inv = 1.0/rij;
 
@@ -252,9 +237,7 @@ void DescriptorsCalc<D2,D3,DM,C2,C3,CM>::calc(const Structure &st, StDescriptors
       if (use_force || use_stress) {
         fd_type &fd_ij = st_d.fd[i][jj];
         if (rij_sq <= rcut_2b_sq && init2b) {
-          double fc_ij = wj*c2.calc(rij);
-          double fcp_ij = wj*c2.calc_prime(rij);
-          d2.calc_all(rij,rij_sq,fc_ij,fcp_ij,aed,fd_ij);
+          d2.calc_all(Zj,rij,rij_sq,aed,fd_ij);
           // Two-body descriptor calculates x-direction only - fd_ij(n,0)
           // so we have to copy x-dir to y- and z-dir
           // and scale them by the unit directional vector delij/rij.
@@ -267,11 +250,8 @@ void DescriptorsCalc<D2,D3,DM,C2,C3,CM>::calc(const Structure &st, StDescriptors
         }
         // CALCULATE MANY-BODY TERM
         if (rij_sq <= rcut_mb_sq && initmb) {
-          double fc_ij = wj*cm.calc(rij);
-          double fcp_ij = wj*cm.calc_prime(rij);
           rho_type& rhoi = st_d.get_rho(i);
-          int mode = dm.calc_dXijdri(rij,rij_sq,delij,
-              fc_ij,fcp_ij,rhoi,fd_ij);
+          int mode = dm.calc_dXijdri(Zj,rij,rij_sq,delij,rhoi,fd_ij);
           if (mode==0) {
             // some dm compute x-dir only, similarly to d2 above
             for (size_t n=size2b+bias; n<size2b+sizemb+bias; ++n) {
@@ -287,8 +267,7 @@ void DescriptorsCalc<D2,D3,DM,C2,C3,CM>::calc(const Structure &st, StDescriptors
       }
       else {
         if (rij_sq <= rcut_2b_sq && init2b) {
-          double fc_ij = wj*c2.calc(rij);
-          d2.calc_aed(rij,rij_sq,fc_ij,aed);
+          d2.calc_aed(Zj,rij,rij_sq,aed);
         }
       }
 
@@ -332,8 +311,8 @@ void DescriptorsCalc<D2,D3,DM,C2,C3,CM>::calc_dimer(const Structure &st, StDescr
   double rcut_2b_sq = 0.0;
   double rcut_mb_sq = 0.0;
 
-  if (init2b) rcut_2b_sq = pow(config.get<double>("RCUT2B"),2);
-  if (initmb) rcut_mb_sq = pow(config.get<double>("RCUTMB"),2);
+  if (init2b) rcut_2b_sq = pow(d2.get_rcut(),2);
+  if (initmb) rcut_mb_sq = pow(dm.get_rcut(),2);
 
   // Max distance between CoM of two interacting molecules
   double rcut_com_sq = pow(config.get<double>("RCUTMAX")-r_b,2);
@@ -399,15 +378,13 @@ void DescriptorsCalc<D2,D3,DM,C2,C3,CM>::calc_dimer(const Structure &st, StDescr
   size_t N = bond_bool ? 6 : 4;
   for (size_t n=0; n<N; ++n) {
     if (r_sq[n] <= rcut_mb_sq && initmb) {
-      double fcmb = Zj*cm.calc(r[n]);
-      dm.calc_rho(r[n],r_sq[n],fcmb,delM.row(n),st_d.get_rho(idx[n].first));
-      dm.calc_rho(r[n],r_sq[n],fcmb,-delM.row(n),st_d.get_rho(idx[n].second));
+      dm.calc_rho(Zj, r[n],r_sq[n],delM.row(n),st_d.get_rho(idx[n].first));
+      dm.calc_rho(Zj, r[n],r_sq[n],-delM.row(n),st_d.get_rho(idx[n].second));
     }
     // Do not compute 2b term between bonded atoms
     if (r_sq[n] <= rcut_2b_sq && init2b) {
-      double fc2b = Zj*c2.calc(r[n]);
-      d2.calc_aed(r[n],r_sq[n],fc2b,st_d.get_aed(idx[n].first));
-      d2.calc_aed(r[n],r_sq[n],fc2b,st_d.get_aed(idx[n].second));
+      d2.calc_aed(Zj,r[n],r_sq[n],st_d.get_aed(idx[n].first));
+      d2.calc_aed(Zj,r[n],r_sq[n],st_d.get_aed(idx[n].second));
     }
   }
 
-- 
GitLab


From 7fec2a9d2c758f6f8e2297e4db011d8a07f7ace7 Mon Sep 17 00:00:00 2001
From: mkirsz <s1351949@sms.ed.ac.uk>
Date: Sun, 15 Dec 2024 20:12:11 +0000
Subject: [PATCH 03/15] Fixes

---
 include/tadah/mlip/descriptors_calc.hpp | 33 ++++++++++++++++++++-----
 1 file changed, 27 insertions(+), 6 deletions(-)

diff --git a/include/tadah/mlip/descriptors_calc.hpp b/include/tadah/mlip/descriptors_calc.hpp
index 0d3813f..f0bbe1d 100644
--- a/include/tadah/mlip/descriptors_calc.hpp
+++ b/include/tadah/mlip/descriptors_calc.hpp
@@ -34,27 +34,42 @@ DescriptorsCalc<D2,D3,DM,C2,C3,CM>::DescriptorsCalc(Config &c):
 {
   if (c.get<bool>("INIT2B")) {
     c2 = C2(c.get<double>("RCUT2B"));
-    if (!config.exist("RCTYPE2B"))
+    if (!config.exist("RCTYPE2B")) {
       config.add("RCTYPE2B",c2.label());
+      d2.set_fcut(&c2,false);
+    }
   }
   else {
     c2 = C2(0);
+    if (!config.exist("RCTYPE2B")) {
+      d2.set_fcut(&c2,false);
+    }
   }
   if (c.get<bool>("INIT3B")) {
     c3 = C3(c.get<double>("RCUT3B"));
-    if (!config.exist("RCTYPE3B"))
+    if (!config.exist("RCTYPE3B")) {
       config.add("RCTYPE3B",c3.label());
+      d3.set_fcut(&c3,false);
+    }
   }
   else {
     c3 = C3(0);
+    if (!config.exist("RCTYPE3B")) {
+      d3.set_fcut(&c3,false);
+    }
   }
   if (c.get<bool>("INITMB")) {
     cm = CM(c.get<double>("RCUTMB"));
-    if (!config.exist("RCTYPEMB"))
+    if (!config.exist("RCTYPEMB")) {
       config.add("RCTYPEMB",cm.label());
+      dm.set_fcut(&cm,false);
+    }
   }
   else {
     cm = CM(0);
+    if (!config.exist("RCTYPEMB")) {
+      dm.set_fcut(&cm,false);
+    }
   }
 }
 
@@ -70,16 +85,22 @@ DescriptorsCalc<D2,D3,DM,C2,C3,CM>::DescriptorsCalc(Config &c, T1 &d2, T2 &d3, T
   dm(dm)
 {
   if (c.get<bool>("INIT2B")) {
-    if (!config.exist("RCTYPE2B"))
+    if (!config.exist("RCTYPE2B")) {
       config.add("RCTYPE2B",c2.label());
+      d2.set_fcut(&c2,false);
+    }
   }
   if (c.get<bool>("INIT3B")) {
-    if (!config.exist("RCTYPE3B"))
+    if (!config.exist("RCTYPE3B")) {
       config.add("RCTYPE3B",c3.label());
+      d3.set_fcut(&c3,false);
+    }
   }
   if (c.get<bool>("INITMB")) {
-    if (!config.exist("RCTYPEMB"))
+    if (!config.exist("RCTYPEMB")) {
       config.add("RCTYPEMB",cm.label());
+      dm.set_fcut(&cm,false);
+    }
   }
 
   if (!config.exist("DSIZE"))
-- 
GitLab


From 305976d0424a10d09e39987ac6292a86e3abe135 Mon Sep 17 00:00:00 2001
From: mkirsz <s1351949@sms.ed.ac.uk>
Date: Mon, 16 Dec 2024 02:25:42 +0000
Subject: [PATCH 04/15] cut fixes

---
 include/tadah/mlip/descriptors_calc.hpp | 8 ++++----
 include/tadah/mlip/models/m_blr.h       | 2 +-
 include/tadah/mlip/models/m_krr.h       | 2 +-
 3 files changed, 6 insertions(+), 6 deletions(-)

diff --git a/include/tadah/mlip/descriptors_calc.hpp b/include/tadah/mlip/descriptors_calc.hpp
index f0bbe1d..7f334d8 100644
--- a/include/tadah/mlip/descriptors_calc.hpp
+++ b/include/tadah/mlip/descriptors_calc.hpp
@@ -33,7 +33,7 @@ DescriptorsCalc<D2,D3,DM,C2,C3,CM>::DescriptorsCalc(Config &c):
   DescriptorsCalc(c,c,c,c)
 {
   if (c.get<bool>("INIT2B")) {
-    c2 = C2(c.get<double>("RCUT2B"));
+    c2 = C2(c.get<double>("RCUT2BMAX"));
     if (!config.exist("RCTYPE2B")) {
       config.add("RCTYPE2B",c2.label());
       d2.set_fcut(&c2,false);
@@ -46,7 +46,7 @@ DescriptorsCalc<D2,D3,DM,C2,C3,CM>::DescriptorsCalc(Config &c):
     }
   }
   if (c.get<bool>("INIT3B")) {
-    c3 = C3(c.get<double>("RCUT3B"));
+    c3 = C3(c.get<double>("RCUT3BMAX"));
     if (!config.exist("RCTYPE3B")) {
       config.add("RCTYPE3B",c3.label());
       d3.set_fcut(&c3,false);
@@ -59,7 +59,7 @@ DescriptorsCalc<D2,D3,DM,C2,C3,CM>::DescriptorsCalc(Config &c):
     }
   }
   if (c.get<bool>("INITMB")) {
-    cm = CM(c.get<double>("RCUTMB"));
+    cm = CM(c.get<double>("RCUTMBMAX"));
     if (!config.exist("RCTYPEMB")) {
       config.add("RCTYPEMB",cm.label());
       dm.set_fcut(&cm,false);
@@ -158,7 +158,7 @@ void DescriptorsCalc<D2,D3,DM,C2,C3,CM>::common_constructor() {
 }
 template <typename D2, typename D3, typename DM, typename C2, typename C3, typename CM>
 void DescriptorsCalc<D2,D3,DM,C2,C3,CM>::calc_rho(const Structure &st, StDescriptors &st_d) {
-  double rcut_mb_sq = pow(config.get<double>("RCUTMB"),2);
+  double rcut_mb_sq = pow(config.get<double>("RCUTMBMAX"),2);
   rhos_type &rhos = st_d.rhos;
   size_t s = dm.rhoi_size()+dm.rhoip_size();
   rhos.resize(s,st.natoms());
diff --git a/include/tadah/mlip/models/m_blr.h b/include/tadah/mlip/models/m_blr.h
index 043ec62..d687282 100644
--- a/include/tadah/mlip/models/m_blr.h
+++ b/include/tadah/mlip/models/m_blr.h
@@ -147,7 +147,6 @@ public:
     c.remove("VERBOSE");
     c.add("VERBOSE", 0);
 
-    c.clear_internal_keys();
     c.remove("MODEL");
     c.add("MODEL", label);
     c.add("MODEL", bf.get_label());
@@ -164,6 +163,7 @@ public:
         c.add("NSTDEV", norm.std_dev[i]);
       }
     }
+    c.clear_internal_keys();
     return c;
   }
   StructureDB predict(Config config_pred, StructureDB &stdb, DC_Base &dc,
diff --git a/include/tadah/mlip/models/m_krr.h b/include/tadah/mlip/models/m_krr.h
index c3f2057..610c1fc 100644
--- a/include/tadah/mlip/models/m_krr.h
+++ b/include/tadah/mlip/models/m_krr.h
@@ -201,7 +201,6 @@ public:
     c.remove("VERBOSE");
     c.add("VERBOSE", 0);
 
-    c.clear_internal_keys();
     c.remove("MODEL");
     c.add("MODEL", label);
     c.add("MODEL", kernel.get_label());
@@ -234,6 +233,7 @@ public:
         }
       }
     }
+    c.clear_internal_keys();
     return c;
   }
   StructureDB predict(Config config_pred, StructureDB &stdb, DC_Base &dc,
-- 
GitLab


From 4aa3839d51a12c51a5dd7802fc546d889778a454 Mon Sep 17 00:00:00 2001
From: mkirsz <s1351949@sms.ed.ac.uk>
Date: Mon, 16 Dec 2024 22:09:15 +0000
Subject: [PATCH 05/15] All DM descriptors to compute xyz components

---
 include/tadah/mlip/descriptors_calc.hpp | 13 +------------
 1 file changed, 1 insertion(+), 12 deletions(-)

diff --git a/include/tadah/mlip/descriptors_calc.hpp b/include/tadah/mlip/descriptors_calc.hpp
index 7f334d8..2f1ab7f 100644
--- a/include/tadah/mlip/descriptors_calc.hpp
+++ b/include/tadah/mlip/descriptors_calc.hpp
@@ -272,18 +272,7 @@ void DescriptorsCalc<D2,D3,DM,C2,C3,CM>::calc(const Structure &st, StDescriptors
         // CALCULATE MANY-BODY TERM
         if (rij_sq <= rcut_mb_sq && initmb) {
           rho_type& rhoi = st_d.get_rho(i);
-          int mode = dm.calc_dXijdri(Zj,rij,rij_sq,delij,rhoi,fd_ij);
-          if (mode==0) {
-            // some dm compute x-dir only, similarly to d2 above
-            for (size_t n=size2b+bias; n<size2b+sizemb+bias; ++n) {
-              fd_ij(n,0) *= rij_inv;
-              fd_ij(n,1) = fd_ij(n,0);
-              fd_ij(n,2) = fd_ij(n,0);
-              fd_ij(n,0) *= delij[0];
-              fd_ij(n,1) *= delij[1];
-              fd_ij(n,2) *= delij[2];
-            }
-          }
+          dm.calc_dXijdri(Zj,rij,rij_sq,delij,rhoi,fd_ij);
         }
       }
       else {
-- 
GitLab


From e72c34820462fdb05c47b85b179b78fa969675ac Mon Sep 17 00:00:00 2001
From: mkirsz <s1351949@sms.ed.ac.uk>
Date: Tue, 17 Dec 2024 09:11:46 +0000
Subject: [PATCH 06/15] units updated for printing

---
 include/tadah/mlip/design_matrix/design_matrix.h | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/include/tadah/mlip/design_matrix/design_matrix.h b/include/tadah/mlip/design_matrix/design_matrix.h
index 0289b7a..2edf859 100644
--- a/include/tadah/mlip/design_matrix/design_matrix.h
+++ b/include/tadah/mlip/design_matrix/design_matrix.h
@@ -294,7 +294,7 @@ private:
     if (stress) {
       for (size_t j=0; j<6; ++j) {
         s_std_dev[j] = svec.col(j).std_dev(svec.col(j).mean(), svec.col(j).size()-1);
-        if (verbose) std::cout << "Stress standard deviation (eV): " << s_std_dev[j] << std::endl;
+        if (verbose) std::cout << "Stress standard deviation (eV/A^3): " << s_std_dev[j] << std::endl;
       }
     }
 
@@ -314,7 +314,7 @@ private:
       // e_std_dev has units of energy
       // f_std_dev has units of inverse distance
       f_std_dev = fvec.std_dev(fvec.mean(),fvec.size()-1);
-      if (verbose) std::cout << "Force standard deviation (1/A): " << f_std_dev << std::endl;
+      if (verbose) std::cout << "Force standard deviation (A^-1): " << f_std_dev << std::endl;
     }
 
     config.add("ESTDEV",e_std_dev);
-- 
GitLab


From 23ac22b4cac2a85e99fedeef63c4cabff50c37b9 Mon Sep 17 00:00:00 2001
From: mkirsz <s1351949@sms.ed.ac.uk>
Date: Tue, 17 Dec 2024 13:07:58 +0000
Subject: [PATCH 07/15] Inheritance fixes

---
 .../functions/basis_functions/dm_bf_base.h    | 17 +++++++++-------
 .../functions/basis_functions/dm_bf_linear.h  |  8 ++++----
 .../basis_functions/dm_bf_polynomial2.h       |  8 ++++----
 .../mlip/design_matrix/functions/dm_f_all.h   |  2 +-
 .../functions/dm_function_base.h              | 20 ++++++++++---------
 .../functions/kernels/dm_kern_base.h          | 16 +++++++--------
 .../functions/kernels/dm_kern_linear.h        |  8 ++++----
 src/dm_bf_base.cpp                            |  4 ++++
 src/dm_bf_linear.cpp                          |  2 +-
 src/dm_bf_polynomial2.cpp                     |  3 ++-
 src/dm_function_base.cpp                      |  5 ++++-
 src/dm_kern_base.cpp                          |  5 +++++
 src/dm_kern_linear.cpp                        |  3 ++-
 src/dm_kern_lq.cpp                            |  5 +++--
 src/dm_kern_polynomial.cpp                    |  4 ++--
 src/dm_kern_quadratic.cpp                     |  4 ++--
 src/dm_kern_rbf.cpp                           |  4 ++--
 src/dm_kern_sigmoid.cpp                       |  4 ++--
 18 files changed, 71 insertions(+), 51 deletions(-)

diff --git a/include/tadah/mlip/design_matrix/functions/basis_functions/dm_bf_base.h b/include/tadah/mlip/design_matrix/functions/basis_functions/dm_bf_base.h
index 98b1c37..6f6187d 100644
--- a/include/tadah/mlip/design_matrix/functions/basis_functions/dm_bf_base.h
+++ b/include/tadah/mlip/design_matrix/functions/basis_functions/dm_bf_base.h
@@ -10,13 +10,16 @@
 #include <iostream>
 
 struct DM_BF_Base: public DM_Function_Base, public virtual BF_Base {
+   
+    DM_BF_Base();
+    DM_BF_Base(const Config &c);
     virtual ~DM_BF_Base();
-    virtual size_t get_phi_cols(const Config &config)=0;
-    virtual void calc_phi_energy_row(phi_type &Phi, size_t &row,
-            const double fac, const Structure &st, const StDescriptors &st_d)=0;
-    virtual void calc_phi_force_rows(phi_type &Phi, size_t &row,
-            const double fac, const Structure &st, const StDescriptors &st_d)=0;
-    virtual void calc_phi_stress_rows(phi_type &Phi, size_t &row,
-            const double fac[6], const Structure &st, const StDescriptors &st_d)=0;
+    // virtual size_t get_phi_cols(const Config &config)=0;
+    // virtual void calc_phi_energy_row(phi_type &Phi, size_t &row,
+    //         const double fac, const Structure &st, const StDescriptors &st_d)=0;
+    // virtual void calc_phi_force_rows(phi_type &Phi, size_t &row,
+    //         const double fac, const Structure &st, const StDescriptors &st_d)=0;
+    // virtual void calc_phi_stress_rows(phi_type &Phi, size_t &row,
+    //         const double fac[6], const Structure &st, const StDescriptors &st_d)=0;
 };
 #endif
diff --git a/include/tadah/mlip/design_matrix/functions/basis_functions/dm_bf_linear.h b/include/tadah/mlip/design_matrix/functions/basis_functions/dm_bf_linear.h
index 107ac26..e0052c3 100644
--- a/include/tadah/mlip/design_matrix/functions/basis_functions/dm_bf_linear.h
+++ b/include/tadah/mlip/design_matrix/functions/basis_functions/dm_bf_linear.h
@@ -8,12 +8,12 @@ struct DM_BF_Linear: public DM_BF_Base, public BF_Linear
 {
     DM_BF_Linear();
     DM_BF_Linear(const Config &c);
-    size_t get_phi_cols(const Config &config);
+    size_t get_phi_cols(const Config &config) override;
     void calc_phi_energy_row(phi_type &Phi, size_t &row,
-            const double fac, const Structure &st, const StDescriptors &st_d);
+            const double fac, const Structure &st, const StDescriptors &st_d) override;
     void calc_phi_force_rows(phi_type &Phi, size_t &row,
-            const double fac, const Structure &st, const StDescriptors &st_d);
+            const double fac, const Structure &st, const StDescriptors &st_d) override;
     void calc_phi_stress_rows(phi_type &Phi, size_t &row,
-            const double fac[6], const Structure &st, const StDescriptors &st_d);
+            const double fac[6], const Structure &st, const StDescriptors &st_d) override;
 };
 #endif
diff --git a/include/tadah/mlip/design_matrix/functions/basis_functions/dm_bf_polynomial2.h b/include/tadah/mlip/design_matrix/functions/basis_functions/dm_bf_polynomial2.h
index b0ce413..550f403 100644
--- a/include/tadah/mlip/design_matrix/functions/basis_functions/dm_bf_polynomial2.h
+++ b/include/tadah/mlip/design_matrix/functions/basis_functions/dm_bf_polynomial2.h
@@ -8,12 +8,12 @@ struct DM_BF_Polynomial2: public DM_BF_Base, public BF_Polynomial2
 {
     DM_BF_Polynomial2();
     DM_BF_Polynomial2(const Config &c);
-    size_t get_phi_cols(const Config &config);
+    size_t get_phi_cols(const Config &config) override;
     void  calc_phi_energy_row(phi_type &Phi, size_t &row,
-            const double fac, const Structure &st, const StDescriptors &st_d);
+            const double fac, const Structure &st, const StDescriptors &st_d) override;
     void  calc_phi_force_rows(phi_type &Phi, size_t &row,
-            const double fac, const Structure &st, const StDescriptors &st_d);
+            const double fac, const Structure &st, const StDescriptors &st_d) override;
     void  calc_phi_stress_rows(phi_type &Phi, size_t &row,
-            const double fac[6], const Structure &st, const StDescriptors &st_d);
+            const double fac[6], const Structure &st, const StDescriptors &st_d) override;
 };
 #endif
diff --git a/include/tadah/mlip/design_matrix/functions/dm_f_all.h b/include/tadah/mlip/design_matrix/functions/dm_f_all.h
index 201905a..56c4bae 100644
--- a/include/tadah/mlip/design_matrix/functions/dm_f_all.h
+++ b/include/tadah/mlip/design_matrix/functions/dm_f_all.h
@@ -1,2 +1,2 @@
 #include <tadah/mlip/design_matrix/functions/basis_functions/dm_bf_all.h>
-#include <tadah/mlip/design_matrix/functions/kernels/dm_kern_all.h>
+ #include <tadah/mlip/design_matrix/functions/kernels/dm_kern_all.h>
diff --git a/include/tadah/mlip/design_matrix/functions/dm_function_base.h b/include/tadah/mlip/design_matrix/functions/dm_function_base.h
index 6b7b5a9..06f5321 100644
--- a/include/tadah/mlip/design_matrix/functions/dm_function_base.h
+++ b/include/tadah/mlip/design_matrix/functions/dm_function_base.h
@@ -16,16 +16,18 @@
 /** Base class for Kernels and Basis Functions */
 struct DM_Function_Base: public virtual Function_Base {
 
-    // Derived classes must implement Derived() and Derived(Config)
-    virtual ~DM_Function_Base() {};
+  // Derived classes must implement Derived() and Derived(Config)
+  DM_Function_Base();
+  DM_Function_Base(const Config &c);
+  virtual ~DM_Function_Base();
 
-    virtual size_t get_phi_cols(const Config &)=0;
-    virtual void calc_phi_energy_row(phi_type &, size_t &,
-            const double , const Structure &, const StDescriptors &)=0;
-    virtual void calc_phi_force_rows(phi_type &, size_t &,
-            const double , const Structure &, const StDescriptors &)=0;
-    virtual void calc_phi_stress_rows(phi_type &, size_t &,
-            const double[6], const Structure &, const StDescriptors &)=0;
+  virtual size_t get_phi_cols(const Config &)=0;
+  virtual void calc_phi_energy_row(phi_type &, size_t &,
+                                   const double , const Structure &, const StDescriptors &)=0;
+  virtual void calc_phi_force_rows(phi_type &, size_t &,
+                                   const double , const Structure &, const StDescriptors &)=0;
+  virtual void calc_phi_stress_rows(phi_type &, size_t &,
+                                    const double[6], const Structure &, const StDescriptors &)=0;
 };
 //template<> inline CONFIG::Registry<DM_Function_Base>::Map CONFIG::Registry<DM_Function_Base>::registry{};
 //template<> inline CONFIG::Registry<DM_Function_Base,Config&>::Map CONFIG::Registry<DM_Function_Base,Config&>::registry{};
diff --git a/include/tadah/mlip/design_matrix/functions/kernels/dm_kern_base.h b/include/tadah/mlip/design_matrix/functions/kernels/dm_kern_base.h
index cf0c12f..8a12489 100644
--- a/include/tadah/mlip/design_matrix/functions/kernels/dm_kern_base.h
+++ b/include/tadah/mlip/design_matrix/functions/kernels/dm_kern_base.h
@@ -1,6 +1,7 @@
 #ifndef DM_KERN_BASE_H
 #define DM_KERN_BASE_H
 
+#include "tadah/core/config.h"
 #include <tadah/mlip/design_matrix/functions/dm_function_base.h>
 #include <tadah/mlip/structure.h>
 #include <tadah/mlip/st_descriptors.h>
@@ -16,17 +17,16 @@
  *  - ff = force descriptor
  *  - all derivatives are defined wrt to the second argument
  */
-class DM_Kern_Base: public DM_Function_Base, public virtual Kern_Base {
+class DM_Kern_Base: public DM_Function_Base, public virtual Kern_Base  {
     public:
 
+        DM_Kern_Base();
+        DM_Kern_Base(const Config&c);
         virtual ~DM_Kern_Base();
-        virtual size_t get_phi_cols(const Config &config);
-        virtual void  calc_phi_energy_row(phi_type &Phi, size_t &row, const double fac,
-                const Structure &st, const StDescriptors &st_d);
-        virtual void  calc_phi_force_rows(phi_type &Phi, size_t &row, const double fac,
-                const Structure &st, const StDescriptors &st_d);
-        virtual void  calc_phi_stress_rows(phi_type &Phi, size_t &row, const double fac[6],
-                const Structure &st, const StDescriptors &st_d);
+        virtual size_t get_phi_cols(const Config &config) override;
+        virtual void  calc_phi_energy_row(phi_type &Phi, size_t &row, const double fac, const Structure &st, const StDescriptors &st_d) override;
+        virtual void  calc_phi_force_rows(phi_type &Phi, size_t &row, const double fac, const Structure &st, const StDescriptors &st_d) override;
+        virtual void  calc_phi_stress_rows(phi_type &Phi, size_t &row, const double fac[6], const Structure &st, const StDescriptors &st_d) override;
 
 };
 #endif
diff --git a/include/tadah/mlip/design_matrix/functions/kernels/dm_kern_linear.h b/include/tadah/mlip/design_matrix/functions/kernels/dm_kern_linear.h
index 5c75dad..cf026d0 100644
--- a/include/tadah/mlip/design_matrix/functions/kernels/dm_kern_linear.h
+++ b/include/tadah/mlip/design_matrix/functions/kernels/dm_kern_linear.h
@@ -19,12 +19,12 @@ class DM_Kern_Linear :  public DM_Kern_Base, public Kern_Linear {
     DM_Kern_Linear ();
     DM_Kern_Linear (const Config &c);
 
-    size_t get_phi_cols(const Config &config);
+    size_t get_phi_cols(const Config &config) override;
     void calc_phi_energy_row(phi_type &Phi, size_t &row,
-            const double fac, const Structure &st, const StDescriptors &st_d);
+            const double fac, const Structure &st, const StDescriptors &st_d) override;
     void calc_phi_force_rows(phi_type &Phi, size_t &row,
-            const double fac, const Structure &st, const StDescriptors &st_d);
+            const double fac, const Structure &st, const StDescriptors &st_d) override;
     void calc_phi_stress_rows(phi_type &Phi, size_t &row,
-            const double fac[6], const Structure &st, const StDescriptors &st_d);
+            const double fac[6], const Structure &st, const StDescriptors &st_d) override;
 };
 #endif
diff --git a/src/dm_bf_base.cpp b/src/dm_bf_base.cpp
index e68de39..21543fa 100644
--- a/src/dm_bf_base.cpp
+++ b/src/dm_bf_base.cpp
@@ -1,2 +1,6 @@
+#include "tadah/mlip/design_matrix/functions/dm_function_base.h"
+#include "tadah/models/functions/basis_functions/bf_base.h"
 #include <tadah/mlip/design_matrix/functions/basis_functions/dm_bf_base.h>
+DM_BF_Base::DM_BF_Base() {}
+DM_BF_Base::DM_BF_Base(const Config &c): BF_Base(c), DM_Function_Base(c) {}
 DM_BF_Base::~DM_BF_Base() {}
diff --git a/src/dm_bf_linear.cpp b/src/dm_bf_linear.cpp
index ab467e8..9bb9846 100644
--- a/src/dm_bf_linear.cpp
+++ b/src/dm_bf_linear.cpp
@@ -4,7 +4,7 @@
 //CONFIG::Registry<DM_Function_Base,Config&>::Register<DM_BF_Linear> DM_BF_Linear_2( "BF_Linear" );
 
 DM_BF_Linear::DM_BF_Linear() {}
-DM_BF_Linear::DM_BF_Linear(const Config &c): BF_Linear(c)
+DM_BF_Linear::DM_BF_Linear(const Config &c): DM_BF_Base(c), BF_Linear(c)
 {}
 size_t DM_BF_Linear::get_phi_cols(const Config &config)
 {
diff --git a/src/dm_bf_polynomial2.cpp b/src/dm_bf_polynomial2.cpp
index 6638fdb..9ef36d3 100644
--- a/src/dm_bf_polynomial2.cpp
+++ b/src/dm_bf_polynomial2.cpp
@@ -1,10 +1,11 @@
+#include "tadah/mlip/design_matrix/functions/kernels/dm_kern_base.h"
 #include <tadah/mlip/design_matrix/functions/basis_functions/dm_bf_polynomial2.h>
 
 //CONFIG::Registry<DM_Function_Base>::Register<DM_BF_Polynomial2> DM_BF_Polynomial2_1( "BF_Polynomial2" );
 //CONFIG::Registry<DM_Function_Base,Config&>::Register<DM_BF_Polynomial2> DM_BF_Polynomial2_2( "BF_Polynomial2" );
 
 DM_BF_Polynomial2::DM_BF_Polynomial2() {}
-DM_BF_Polynomial2::DM_BF_Polynomial2(const Config &c): BF_Polynomial2(c)
+DM_BF_Polynomial2::DM_BF_Polynomial2(const Config &c): DM_BF_Base(c), BF_Polynomial2(c)
 {}
 size_t DM_BF_Polynomial2::get_phi_cols(const Config &config)
 {
diff --git a/src/dm_function_base.cpp b/src/dm_function_base.cpp
index 8c7af9b..945c434 100644
--- a/src/dm_function_base.cpp
+++ b/src/dm_function_base.cpp
@@ -1,2 +1,5 @@
+#include "tadah/models/functions/function_base.h"
 #include <tadah/mlip/design_matrix/functions/dm_function_base.h>
-
+DM_Function_Base::DM_Function_Base() {}
+DM_Function_Base::DM_Function_Base(const Config &c): Function_Base(c) {}
+DM_Function_Base::~DM_Function_Base() {}
diff --git a/src/dm_kern_base.cpp b/src/dm_kern_base.cpp
index ecf679c..773778a 100644
--- a/src/dm_kern_base.cpp
+++ b/src/dm_kern_base.cpp
@@ -1,6 +1,11 @@
+#include "tadah/core/config.h"
+#include "tadah/mlip/design_matrix/functions/dm_function_base.h"
+#include "tadah/models/functions/kernels/kern_base.h"
 #include <tadah/mlip/design_matrix/functions/kernels/dm_kern_base.h>
 
 DM_Kern_Base::~DM_Kern_Base() {}
+DM_Kern_Base::DM_Kern_Base() {}
+DM_Kern_Base::DM_Kern_Base(const Config &c): DM_Function_Base(c), Kern_Base(c) {}
 size_t DM_Kern_Base::get_phi_cols(const Config &)
 {
   return basis.cols();
diff --git a/src/dm_kern_linear.cpp b/src/dm_kern_linear.cpp
index 90f97a7..2788656 100644
--- a/src/dm_kern_linear.cpp
+++ b/src/dm_kern_linear.cpp
@@ -1,10 +1,11 @@
+#include "tadah/mlip/design_matrix/functions/kernels/dm_kern_base.h"
 #include <tadah/mlip/design_matrix/functions/kernels/dm_kern_linear.h>
 
 //CONFIG::Registry<DM_Function_Base>::Register<DM_Kern_Linear> DM_Kern_Linear_1( "Kern_Linear" );
 //CONFIG::Registry<DM_Function_Base,Config&>::Register<DM_Kern_Linear> DM_Kern_Linear_2( "Kern_Linear" );
 
 DM_Kern_Linear::DM_Kern_Linear() {}
-DM_Kern_Linear::DM_Kern_Linear (const Config &c): Kern_Linear(c)
+DM_Kern_Linear::DM_Kern_Linear (const Config &c): DM_Kern_Base(c), Kern_Linear(c)
 {}
 size_t DM_Kern_Linear::get_phi_cols(const Config &config)
 {
diff --git a/src/dm_kern_lq.cpp b/src/dm_kern_lq.cpp
index 3c14dfa..40fdb4d 100644
--- a/src/dm_kern_lq.cpp
+++ b/src/dm_kern_lq.cpp
@@ -1,11 +1,12 @@
+#include "tadah/mlip/design_matrix/functions/kernels/dm_kern_base.h"
 #include <tadah/mlip/design_matrix/functions/kernels/dm_kern_lq.h>
 
 //CONFIG::Registry<DM_Function_Base>::Register<DM_Kern_LQ> DM_Kern_LQ_1( "Kern_LQ" );
 //CONFIG::Registry<DM_Function_Base,Config&>::Register<DM_Kern_LQ> DM_Kern_LQ_2( "Kern_LQ" );
 
-DM_Kern_LQ::DM_Kern_LQ():
-  Kern_LQ()
+DM_Kern_LQ::DM_Kern_LQ()
 {}
 DM_Kern_LQ::DM_Kern_LQ(const Config &c):
+  DM_Kern_Base(c),
   Kern_LQ(c)
 {}
diff --git a/src/dm_kern_polynomial.cpp b/src/dm_kern_polynomial.cpp
index b60e9d9..5227ca7 100644
--- a/src/dm_kern_polynomial.cpp
+++ b/src/dm_kern_polynomial.cpp
@@ -3,9 +3,9 @@
 //CONFIG::Registry<DM_Function_Base>::Register<DM_Kern_Polynomial> DM_Kern_Polynomial_1( "Kern_Polynomial" );
 //CONFIG::Registry<DM_Function_Base,Config&>::Register<DM_Kern_Polynomial> DM_Kern_Polynomial_2( "Kern_Polynomial" );
 
-DM_Kern_Polynomial::DM_Kern_Polynomial():
-  Kern_Polynomial()
+DM_Kern_Polynomial::DM_Kern_Polynomial()
 {}
 DM_Kern_Polynomial::DM_Kern_Polynomial(const Config &c):
+  DM_Kern_Base(c),
   Kern_Polynomial(c)
 {}
diff --git a/src/dm_kern_quadratic.cpp b/src/dm_kern_quadratic.cpp
index 83a7198..9c4ecb2 100644
--- a/src/dm_kern_quadratic.cpp
+++ b/src/dm_kern_quadratic.cpp
@@ -3,9 +3,9 @@
 //CONFIG::Registry<DM_Function_Base>::Register<DM_Kern_Quadratic> DM_Kern_Quadratic_1( "Kern_Quadratic" );
 //CONFIG::Registry<DM_Function_Base,Config&>::Register<DM_Kern_Quadratic> DM_Kern_Quadratic_2( "Kern_Quadratic" );
 
-DM_Kern_Quadratic::DM_Kern_Quadratic():
-  Kern_Quadratic()
+DM_Kern_Quadratic::DM_Kern_Quadratic()
 {}
 DM_Kern_Quadratic::DM_Kern_Quadratic(const Config &c):
+  DM_Kern_Base(c),
   Kern_Quadratic(c)
 {}
diff --git a/src/dm_kern_rbf.cpp b/src/dm_kern_rbf.cpp
index 1870190..2c9b9e0 100644
--- a/src/dm_kern_rbf.cpp
+++ b/src/dm_kern_rbf.cpp
@@ -3,9 +3,9 @@
 //CONFIG::Registry<DM_Function_Base>::Register<DM_Kern_RBF> DM_Kern_RBF_1( "Kern_RBF" );
 //CONFIG::Registry<DM_Function_Base,Config&>::Register<DM_Kern_RBF> DM_Kern_RBF_2( "Kern_RBF" );
 
-DM_Kern_RBF::DM_Kern_RBF():
-  Kern_RBF()
+DM_Kern_RBF::DM_Kern_RBF()
 {}
 DM_Kern_RBF::DM_Kern_RBF(const Config &c):
+  DM_Kern_Base(c),
   Kern_RBF(c)
 {}
diff --git a/src/dm_kern_sigmoid.cpp b/src/dm_kern_sigmoid.cpp
index 2da37c9..7b547b0 100644
--- a/src/dm_kern_sigmoid.cpp
+++ b/src/dm_kern_sigmoid.cpp
@@ -3,9 +3,9 @@
 //CONFIG::Registry<DM_Function_Base>::Register<DM_Kern_Sigmoid> DM_Kern_Sigmoid_1( "Kern_Sigmoid" );
 //CONFIG::Registry<DM_Function_Base,Config&>::Register<DM_Kern_Sigmoid> DM_Kern_Sigmoid_2( "Kern_Sigmoid" );
 
-DM_Kern_Sigmoid::DM_Kern_Sigmoid():
-  Kern_Sigmoid()
+DM_Kern_Sigmoid::DM_Kern_Sigmoid()
 {}
 DM_Kern_Sigmoid::DM_Kern_Sigmoid(const Config &c):
+  DM_Kern_Base(c),
   Kern_Sigmoid(c)
 {}
-- 
GitLab


From 520c2a8567706e476946bd540ebfdde913d75c7a Mon Sep 17 00:00:00 2001
From: mkirsz <s1351949@sms.ed.ac.uk>
Date: Tue, 17 Dec 2024 13:12:10 +0000
Subject: [PATCH 08/15] Inheritance fixes update

---
 src/dm_kern_base.cpp | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/dm_kern_base.cpp b/src/dm_kern_base.cpp
index 773778a..e36fa6b 100644
--- a/src/dm_kern_base.cpp
+++ b/src/dm_kern_base.cpp
@@ -5,7 +5,7 @@
 
 DM_Kern_Base::~DM_Kern_Base() {}
 DM_Kern_Base::DM_Kern_Base() {}
-DM_Kern_Base::DM_Kern_Base(const Config &c): DM_Function_Base(c), Kern_Base(c) {}
+DM_Kern_Base::DM_Kern_Base(const Config &c): Kern_Base(c), DM_Function_Base(c) {}
 size_t DM_Kern_Base::get_phi_cols(const Config &)
 {
   return basis.cols();
-- 
GitLab


From cb47d330b822844be7c74abd43ce25da64ba7a8a Mon Sep 17 00:00:00 2001
From: mkirsz <s1351949@sms.ed.ac.uk>
Date: Tue, 17 Dec 2024 13:50:27 +0000
Subject: [PATCH 09/15] Proper init of FB

---
 src/dm_bf_base.cpp         | 5 ++++-
 src/dm_bf_linear.cpp       | 5 ++++-
 src/dm_bf_polynomial2.cpp  | 6 +++++-
 src/dm_kern_base.cpp       | 5 ++++-
 src/dm_kern_linear.cpp     | 5 ++++-
 src/dm_kern_lq.cpp         | 2 ++
 src/dm_kern_polynomial.cpp | 1 +
 src/dm_kern_quadratic.cpp  | 2 ++
 src/dm_kern_rbf.cpp        | 1 +
 src/dm_kern_sigmoid.cpp    | 1 +
 10 files changed, 28 insertions(+), 5 deletions(-)

diff --git a/src/dm_bf_base.cpp b/src/dm_bf_base.cpp
index 21543fa..946f455 100644
--- a/src/dm_bf_base.cpp
+++ b/src/dm_bf_base.cpp
@@ -2,5 +2,8 @@
 #include "tadah/models/functions/basis_functions/bf_base.h"
 #include <tadah/mlip/design_matrix/functions/basis_functions/dm_bf_base.h>
 DM_BF_Base::DM_BF_Base() {}
-DM_BF_Base::DM_BF_Base(const Config &c): BF_Base(c), DM_Function_Base(c) {}
+DM_BF_Base::DM_BF_Base(const Config &c):
+  Function_Base(c), 
+  BF_Base(c),
+  DM_Function_Base(c) {}
 DM_BF_Base::~DM_BF_Base() {}
diff --git a/src/dm_bf_linear.cpp b/src/dm_bf_linear.cpp
index 9bb9846..3aea01b 100644
--- a/src/dm_bf_linear.cpp
+++ b/src/dm_bf_linear.cpp
@@ -4,7 +4,10 @@
 //CONFIG::Registry<DM_Function_Base,Config&>::Register<DM_BF_Linear> DM_BF_Linear_2( "BF_Linear" );
 
 DM_BF_Linear::DM_BF_Linear() {}
-DM_BF_Linear::DM_BF_Linear(const Config &c): DM_BF_Base(c), BF_Linear(c)
+DM_BF_Linear::DM_BF_Linear(const Config &c): 
+  Function_Base(c), 
+  DM_BF_Base(c),
+  BF_Linear(c)
 {}
 size_t DM_BF_Linear::get_phi_cols(const Config &config)
 {
diff --git a/src/dm_bf_polynomial2.cpp b/src/dm_bf_polynomial2.cpp
index 9ef36d3..9c48624 100644
--- a/src/dm_bf_polynomial2.cpp
+++ b/src/dm_bf_polynomial2.cpp
@@ -1,11 +1,15 @@
 #include "tadah/mlip/design_matrix/functions/kernels/dm_kern_base.h"
+#include "tadah/models/functions/function_base.h"
 #include <tadah/mlip/design_matrix/functions/basis_functions/dm_bf_polynomial2.h>
 
 //CONFIG::Registry<DM_Function_Base>::Register<DM_BF_Polynomial2> DM_BF_Polynomial2_1( "BF_Polynomial2" );
 //CONFIG::Registry<DM_Function_Base,Config&>::Register<DM_BF_Polynomial2> DM_BF_Polynomial2_2( "BF_Polynomial2" );
 
 DM_BF_Polynomial2::DM_BF_Polynomial2() {}
-DM_BF_Polynomial2::DM_BF_Polynomial2(const Config &c): DM_BF_Base(c), BF_Polynomial2(c)
+DM_BF_Polynomial2::DM_BF_Polynomial2(const Config &c): 
+  Function_Base(c),
+  DM_BF_Base(c),
+  BF_Polynomial2(c)
 {}
 size_t DM_BF_Polynomial2::get_phi_cols(const Config &config)
 {
diff --git a/src/dm_kern_base.cpp b/src/dm_kern_base.cpp
index e36fa6b..9e569e0 100644
--- a/src/dm_kern_base.cpp
+++ b/src/dm_kern_base.cpp
@@ -5,7 +5,10 @@
 
 DM_Kern_Base::~DM_Kern_Base() {}
 DM_Kern_Base::DM_Kern_Base() {}
-DM_Kern_Base::DM_Kern_Base(const Config &c): Kern_Base(c), DM_Function_Base(c) {}
+DM_Kern_Base::DM_Kern_Base(const Config &c): 
+  Function_Base(c), 
+  Kern_Base(c), 
+  DM_Function_Base(c) {}
 size_t DM_Kern_Base::get_phi_cols(const Config &)
 {
   return basis.cols();
diff --git a/src/dm_kern_linear.cpp b/src/dm_kern_linear.cpp
index 2788656..df6126a 100644
--- a/src/dm_kern_linear.cpp
+++ b/src/dm_kern_linear.cpp
@@ -5,7 +5,10 @@
 //CONFIG::Registry<DM_Function_Base,Config&>::Register<DM_Kern_Linear> DM_Kern_Linear_2( "Kern_Linear" );
 
 DM_Kern_Linear::DM_Kern_Linear() {}
-DM_Kern_Linear::DM_Kern_Linear (const Config &c): DM_Kern_Base(c), Kern_Linear(c)
+DM_Kern_Linear::DM_Kern_Linear (const Config &c): 
+  Function_Base(c),
+  DM_Kern_Base(c),
+  Kern_Linear(c)
 {}
 size_t DM_Kern_Linear::get_phi_cols(const Config &config)
 {
diff --git a/src/dm_kern_lq.cpp b/src/dm_kern_lq.cpp
index 40fdb4d..dbef59e 100644
--- a/src/dm_kern_lq.cpp
+++ b/src/dm_kern_lq.cpp
@@ -1,4 +1,5 @@
 #include "tadah/mlip/design_matrix/functions/kernels/dm_kern_base.h"
+#include "tadah/models/functions/function_base.h"
 #include <tadah/mlip/design_matrix/functions/kernels/dm_kern_lq.h>
 
 //CONFIG::Registry<DM_Function_Base>::Register<DM_Kern_LQ> DM_Kern_LQ_1( "Kern_LQ" );
@@ -7,6 +8,7 @@
 DM_Kern_LQ::DM_Kern_LQ()
 {}
 DM_Kern_LQ::DM_Kern_LQ(const Config &c):
+  Function_Base(c), 
   DM_Kern_Base(c),
   Kern_LQ(c)
 {}
diff --git a/src/dm_kern_polynomial.cpp b/src/dm_kern_polynomial.cpp
index 5227ca7..e8637b1 100644
--- a/src/dm_kern_polynomial.cpp
+++ b/src/dm_kern_polynomial.cpp
@@ -6,6 +6,7 @@
 DM_Kern_Polynomial::DM_Kern_Polynomial()
 {}
 DM_Kern_Polynomial::DM_Kern_Polynomial(const Config &c):
+  Function_Base(c), 
   DM_Kern_Base(c),
   Kern_Polynomial(c)
 {}
diff --git a/src/dm_kern_quadratic.cpp b/src/dm_kern_quadratic.cpp
index 9c4ecb2..72ab0d4 100644
--- a/src/dm_kern_quadratic.cpp
+++ b/src/dm_kern_quadratic.cpp
@@ -1,3 +1,4 @@
+#include "tadah/models/functions/function_base.h"
 #include <tadah/mlip/design_matrix/functions/kernels/dm_kern_quadratic.h>
 
 //CONFIG::Registry<DM_Function_Base>::Register<DM_Kern_Quadratic> DM_Kern_Quadratic_1( "Kern_Quadratic" );
@@ -6,6 +7,7 @@
 DM_Kern_Quadratic::DM_Kern_Quadratic()
 {}
 DM_Kern_Quadratic::DM_Kern_Quadratic(const Config &c):
+  Function_Base(c),
   DM_Kern_Base(c),
   Kern_Quadratic(c)
 {}
diff --git a/src/dm_kern_rbf.cpp b/src/dm_kern_rbf.cpp
index 2c9b9e0..41d57be 100644
--- a/src/dm_kern_rbf.cpp
+++ b/src/dm_kern_rbf.cpp
@@ -6,6 +6,7 @@
 DM_Kern_RBF::DM_Kern_RBF()
 {}
 DM_Kern_RBF::DM_Kern_RBF(const Config &c):
+  Function_Base(c), 
   DM_Kern_Base(c),
   Kern_RBF(c)
 {}
diff --git a/src/dm_kern_sigmoid.cpp b/src/dm_kern_sigmoid.cpp
index 7b547b0..1c3d25b 100644
--- a/src/dm_kern_sigmoid.cpp
+++ b/src/dm_kern_sigmoid.cpp
@@ -6,6 +6,7 @@
 DM_Kern_Sigmoid::DM_Kern_Sigmoid()
 {}
 DM_Kern_Sigmoid::DM_Kern_Sigmoid(const Config &c):
+  Function_Base(c), 
   DM_Kern_Base(c),
   Kern_Sigmoid(c)
 {}
-- 
GitLab


From 172b434c8e060028e8bb9d8ca11c5e8ee079357f Mon Sep 17 00:00:00 2001
From: mkirsz <s1351949@sms.ed.ac.uk>
Date: Wed, 18 Dec 2024 00:08:10 +0000
Subject: [PATCH 10/15] moved scaling to individual descriptors

---
 include/tadah/mlip/descriptors_calc.hpp | 20 ++++++++++----------
 1 file changed, 10 insertions(+), 10 deletions(-)

diff --git a/include/tadah/mlip/descriptors_calc.hpp b/include/tadah/mlip/descriptors_calc.hpp
index 2f1ab7f..10eb0bf 100644
--- a/include/tadah/mlip/descriptors_calc.hpp
+++ b/include/tadah/mlip/descriptors_calc.hpp
@@ -258,12 +258,12 @@ void DescriptorsCalc<D2,D3,DM,C2,C3,CM>::calc(const Structure &st, StDescriptors
       if (use_force || use_stress) {
         fd_type &fd_ij = st_d.fd[i][jj];
         if (rij_sq <= rcut_2b_sq && init2b) {
-          d2.calc_all(Zj,rij,rij_sq,aed,fd_ij);
+          d2.calc_all(Zj,rij,rij_sq,aed,fd_ij,0.5);
           // Two-body descriptor calculates x-direction only - fd_ij(n,0)
           // so we have to copy x-dir to y- and z-dir
           // and scale them by the unit directional vector delij/rij.
           for (size_t n=bias; n<size2b+bias; ++n) {
-            fd_ij(n,0) *= 0.5*rij_inv;
+            fd_ij(n,0) *= rij_inv;
             fd_ij(n,1) = fd_ij(n,0)*delij[1];
             fd_ij(n,2) = fd_ij(n,0)*delij[2];
             fd_ij(n,0) *= delij[0];
@@ -277,19 +277,19 @@ void DescriptorsCalc<D2,D3,DM,C2,C3,CM>::calc(const Structure &st, StDescriptors
       }
       else {
         if (rij_sq <= rcut_2b_sq && init2b) {
-          d2.calc_aed(Zj,rij,rij_sq,aed);
+          d2.calc_aed(Zj,rij,rij_sq,aed,0.5);
         }
       }
 
     }
   }
-  if (init2b) {
-    for (size_t n=0; n<st.natoms(); ++n) {
-      for(size_t s=bias; s<bias+d2.size(); ++s) {
-        st_d.get_aed(n)(s) *= 0.5;
-      }
-    }
-  }
+  // if (init2b) {
+  //   for (size_t n=0; n<st.natoms(); ++n) {
+  //     for(size_t s=bias; s<bias+d2.size(); ++s) {
+  //       st_d.get_aed(n)(s) *= 0.5;
+  //     }
+  //   }
+  // }
 }
 template <typename D2, typename D3, typename DM, typename C2, typename C3, typename CM>
 void DescriptorsCalc<D2,D3,DM,C2,C3,CM>::calc_dimer(const Structure &st, StDescriptors &st_d) {
-- 
GitLab


From 40b996207e6c37b2f0919f53e3e32b626a513477 Mon Sep 17 00:00:00 2001
From: mkirsz <s1351949@sms.ed.ac.uk>
Date: Wed, 18 Dec 2024 12:25:52 +0000
Subject: [PATCH 11/15] rename aed

---
 include/tadah/mlip/analytics/statistics.h |  2 +-
 include/tadah/mlip/descriptors_calc.hpp   | 10 +++++-----
 include/tadah/mlip/models/basis.h         |  4 ++--
 include/tadah/mlip/models/m_blr.h         | 12 ++++++------
 include/tadah/mlip/models/m_krr.h         | 12 ++++++------
 include/tadah/mlip/models/m_tadah_base.h  |  6 +++---
 include/tadah/mlip/output/output.h        |  2 +-
 include/tadah/mlip/st_descriptors.h       |  4 ++--
 src/dm_bf_linear.cpp                      |  2 +-
 src/dm_bf_polynomial2.cpp                 | 10 +++++-----
 src/dm_kern_base.cpp                      |  8 ++++----
 src/dm_kern_linear.cpp                    |  6 +++---
 src/m_tadah_base.cpp                      | 12 ++++++------
 src/st_descriptors.cpp                    |  4 ++--
 14 files changed, 47 insertions(+), 47 deletions(-)

diff --git a/include/tadah/mlip/analytics/statistics.h b/include/tadah/mlip/analytics/statistics.h
index 0d86664..4bd75bf 100644
--- a/include/tadah/mlip/analytics/statistics.h
+++ b/include/tadah/mlip/analytics/statistics.h
@@ -5,7 +5,7 @@
 
 /** Some basis statistical tools */
 class Statistics {
-    using vec = aed_type2;
+    using vec = aed_type;
     public:
         /** Residual sum of squares. */
         static double res_sum_sq(const vec &obs, const vec &pred);
diff --git a/include/tadah/mlip/descriptors_calc.hpp b/include/tadah/mlip/descriptors_calc.hpp
index 10eb0bf..01ae99d 100644
--- a/include/tadah/mlip/descriptors_calc.hpp
+++ b/include/tadah/mlip/descriptors_calc.hpp
@@ -216,7 +216,7 @@ void DescriptorsCalc<D2,D3,DM,C2,C3,CM>::calc(const Structure &st, StDescriptors
 
   // zero all aeds and set bias
   for (size_t i=0; i<st.natoms(); ++i) {
-    aed_type2 &aed = st_d.get_aed(i);
+    aed_type &aed = st_d.get_aed(i);
     aed.set_zero();
     aed(0)=static_cast<double>(bias);   // set bias
   }
@@ -226,7 +226,7 @@ void DescriptorsCalc<D2,D3,DM,C2,C3,CM>::calc(const Structure &st, StDescriptors
   // can be calculated
   if (initmb) {
     for (size_t i=0; i<st.natoms(); ++i) {
-      aed_type2 &aed = st_d.get_aed(i);
+      aed_type &aed = st_d.get_aed(i);
       rho_type& rhoi = st_d.get_rho(i);
       dm.calc_aed(rhoi,aed);
     }
@@ -238,7 +238,7 @@ void DescriptorsCalc<D2,D3,DM,C2,C3,CM>::calc(const Structure &st, StDescriptors
   Vec3d delij;
   for (size_t i=0; i<st.natoms(); ++i) {
     const Atom &a1 = st(i);
-    aed_type2 &aed = st_d.get_aed(i);
+    aed_type &aed = st_d.get_aed(i);
 
     for (size_t jj=0; jj<st.nn_size(i); ++jj) {
       const Vec3d &a2pos = st.nn_pos(i,jj);
@@ -356,7 +356,7 @@ void DescriptorsCalc<D2,D3,DM,C2,C3,CM>::calc_dimer(const Structure &st, StDescr
 
   // zero all aeds+rho and set bias
   for (size_t i=0; i<4; ++i) {
-    aed_type2 &aed = st_d.get_aed(i);
+    aed_type &aed = st_d.get_aed(i);
     aed.set_zero();
     aed(0)=static_cast<double>(bias);   // set bias
   }
@@ -401,7 +401,7 @@ void DescriptorsCalc<D2,D3,DM,C2,C3,CM>::calc_dimer(const Structure &st, StDescr
   // calculate many-body aed
   if (initmb) {
     for (size_t i=0; i<4; ++i) {
-      aed_type2 &aed = st_d.get_aed(i);
+      aed_type &aed = st_d.get_aed(i);
       rho_type& rhoi = st_d.get_rho(i);
       dm.calc_aed(rhoi,aed);
     }
diff --git a/include/tadah/mlip/models/basis.h b/include/tadah/mlip/models/basis.h
index 80f2bbe..f00caf0 100644
--- a/include/tadah/mlip/models/basis.h
+++ b/include/tadah/mlip/models/basis.h
@@ -56,7 +56,7 @@ larger than the amount of available AEDs\n");
     for (size_t i=1; i<s; ++i) {
       size_t st = std::get<0>(indices[i]);
       size_t a = std::get<1>(indices[i]);
-      const aed_type2 &aed = st_desc_db(st).get_aed(a);
+      const aed_type &aed = st_desc_db(st).get_aed(a);
       for (size_t j=0; j<aed.size(); ++j) {
         b(j,i)=aed[j];
       }
@@ -97,7 +97,7 @@ larger than the amount of available AEDs\n");
       const size_t st = indices[i];
       T(i)=stdb(st).energy/st_desc_db(st).naed();
       for( size_t a=0; a<st_desc_db(st).naed(); a++ ) {
-        const aed_type2 &aed = st_desc_db(st).get_aed(a);
+        const aed_type &aed = st_desc_db(st).get_aed(a);
         for (size_t j=0; j<aed.size(); ++j) {
           b(j,i)+=aed[j]/st_desc_db(st).naed();
         }
diff --git a/include/tadah/mlip/models/m_blr.h b/include/tadah/mlip/models/m_blr.h
index d687282..943a7fe 100644
--- a/include/tadah/mlip/models/m_blr.h
+++ b/include/tadah/mlip/models/m_blr.h
@@ -76,15 +76,15 @@ public:
     norm = Normaliser(c);
   }
 
-  double epredict(const aed_type2 &aed) const{
+  double epredict(const aed_type &aed) const{
     return bf.epredict(weights,aed);
   };
 
-  double fpredict(const fd_type &fdij, const aed_type2 &aedi, const size_t k) const{
+  double fpredict(const fd_type &fdij, const aed_type &aedi, const size_t k) const{
     return bf.fpredict(weights,fdij,aedi,k);
   }
 
-  force_type fpredict(const fd_type &fdij, const aed_type2 &aedi) const{
+  force_type fpredict(const fd_type &fdij, const aed_type &aedi) const{
     return bf.fpredict(weights,fdij,aedi);
   }
 
@@ -167,7 +167,7 @@ public:
     return c;
   }
   StructureDB predict(Config config_pred, StructureDB &stdb, DC_Base &dc,
-                      aed_type2 &predicted_error) {
+                      aed_type &predicted_error) {
 
     LinearRegressor::read_sigma(config_pred,Sigma);
     DesignMatrix<BF> dm(bf,config_pred);
@@ -178,7 +178,7 @@ public:
     double pmean = sqrt(predicted_error.mean());
 
     // compute energy, forces and stresses
-    aed_type2 Tpred = T_dgemv(dm.Phi, weights);
+    aed_type Tpred = T_dgemv(dm.Phi, weights);
 
     // Construct StructureDB object with predicted values
     StructureDB stdb_;
@@ -218,7 +218,7 @@ Hint: check different predict() methods.");
     //std::cout << Phi.row(0) << std::endl;
 
     // compute energy, forces and stresses
-    aed_type2 Tpred = T_dgemv(Phi, weights);
+    aed_type Tpred = T_dgemv(Phi, weights);
 
     double eweightglob=config.template get<double>("EWEIGHT");
     double fweightglob=config.template get<double>("FWEIGHT");
diff --git a/include/tadah/mlip/models/m_krr.h b/include/tadah/mlip/models/m_krr.h
index 610c1fc..92aee1d 100644
--- a/include/tadah/mlip/models/m_krr.h
+++ b/include/tadah/mlip/models/m_krr.h
@@ -79,15 +79,15 @@ public:
     norm = Normaliser(c);
   }
 
-  double epredict(const aed_type2 &aed) const {
+  double epredict(const aed_type &aed) const {
     return kernel.epredict(weights,aed);
   };
 
-  double fpredict(const fd_type &fdij, const aed_type2 &aedi, const size_t k) const {
+  double fpredict(const fd_type &fdij, const aed_type &aedi, const size_t k) const {
     return kernel.fpredict(weights,fdij,aedi,k);
   }
 
-  force_type fpredict(const fd_type &fdij, const aed_type2 &aedi) const {
+  force_type fpredict(const fd_type &fdij, const aed_type &aedi) const {
     return kernel.fpredict(weights,fdij,aedi);
   }
 
@@ -237,7 +237,7 @@ public:
     return c;
   }
   StructureDB predict(Config config_pred, StructureDB &stdb, DC_Base &dc,
-                      aed_type2 &predicted_error) {
+                      aed_type &predicted_error) {
 
     LinearRegressor::read_sigma(config_pred,Sigma);
     DesignMatrix<K> dm(kernel,config_pred);
@@ -249,7 +249,7 @@ public:
     double pmean = sqrt(predicted_error.mean());
 
     // compute energy, forces and stresses
-    aed_type2 Tpred = T_dgemv(dm.Phi, weights);
+    aed_type Tpred = T_dgemv(dm.Phi, weights);
 
     // Construct StructureDB object with predicted values
     StructureDB stdb_;
@@ -288,7 +288,7 @@ Hint: check different predict() methods.");
     phi_type &Phi = desmat.Phi;
 
     // compute energy, forces and stresses
-    aed_type2 Tpred = T_dgemv(Phi, weights);
+    aed_type Tpred = T_dgemv(Phi, weights);
 
     double eweightglob=config.template get<double>("EWEIGHT");
     double fweightglob=config.template get<double>("FWEIGHT");
diff --git a/include/tadah/mlip/models/m_tadah_base.h b/include/tadah/mlip/models/m_tadah_base.h
index c98e4e6..d719829 100644
--- a/include/tadah/mlip/models/m_tadah_base.h
+++ b/include/tadah/mlip/models/m_tadah_base.h
@@ -29,11 +29,11 @@ public:
   double epredict(const StDescriptors &std);
 
   ///** \brief Predict force between a pair of atoms in a k-direction. */
-  //virtual double fpredict(const fd_type &fdij, const aed_type2 &aedi, size_t k)=0;
+  //virtual double fpredict(const fd_type &fdij, const aed_type &aedi, size_t k)=0;
 
   ///** \brief Predict force between a pair of atoms. */
   //virtual force_type fpredict(const fd_type &fdij,
-  //        const aed_type2 &aedi)=0;
+  //        const aed_type &aedi)=0;
 
   /** \brief Predict total force on an atom a. */
   virtual void fpredict(const size_t a, force_type &v,
@@ -96,7 +96,7 @@ public:
   virtual void train(StDescriptorsDB &, const StructureDB &) {};
 
   virtual StructureDB predict(Config config_pred, StructureDB &stdb, DC_Base &dc,
-                              aed_type2 &predicted_error)=0;
+                              aed_type &predicted_error)=0;
 
   virtual StructureDB predict(StructureDB &stdb)=0;
 
diff --git a/include/tadah/mlip/output/output.h b/include/tadah/mlip/output/output.h
index 0f0a7a3..1a08850 100644
--- a/include/tadah/mlip/output/output.h
+++ b/include/tadah/mlip/output/output.h
@@ -45,7 +45,7 @@ class Output {
 
             out_unc.close();
         }
-        void print_predict_all(StructureDB &stdb, StructureDB &stpred, aed_type2 & predicted_error) {
+        void print_predict_all(StructureDB &stdb, StructureDB &stpred, aed_type & predicted_error) {
             std::ofstream out_error("error.pred");
             std::ofstream out_energy("energy.pred");
             std::ofstream out_force("forces.pred");
diff --git a/include/tadah/mlip/st_descriptors.h b/include/tadah/mlip/st_descriptors.h
index 56618d5..f5cf2ad 100644
--- a/include/tadah/mlip/st_descriptors.h
+++ b/include/tadah/mlip/st_descriptors.h
@@ -69,8 +69,8 @@ struct StDescriptors {
 
     rhos_type rhos;
 
-    aed_type2 & get_aed(const size_t i);
-    const aed_type2 &get_aed(const size_t i) const;
+    aed_type & get_aed(const size_t i);
+    const aed_type &get_aed(const size_t i) const;
     rho_type& get_rho(const size_t i);
 
     size_t naed() const;
diff --git a/src/dm_bf_linear.cpp b/src/dm_bf_linear.cpp
index 3aea01b..4f5dac2 100644
--- a/src/dm_bf_linear.cpp
+++ b/src/dm_bf_linear.cpp
@@ -18,7 +18,7 @@ void DM_BF_Linear::calc_phi_energy_row(phi_type &Phi, size_t &row,
                                        const double fac, const Structure &, const StDescriptors &st_d)
 {
   for (size_t i=0; i<st_d.naed(); ++i) {
-    const aed_type2 &aed = st_d.get_aed(i);
+    const aed_type &aed = st_d.get_aed(i);
     for (size_t j=0; j<aed.size(); ++j) {
       Phi(row,j)+=aed[j]*fac;
     }
diff --git a/src/dm_bf_polynomial2.cpp b/src/dm_bf_polynomial2.cpp
index 9c48624..c7fb188 100644
--- a/src/dm_bf_polynomial2.cpp
+++ b/src/dm_bf_polynomial2.cpp
@@ -23,7 +23,7 @@ void DM_BF_Polynomial2::calc_phi_energy_row(phi_type &Phi,
                                             const StDescriptors &st_d)
 {
   for (size_t a=0; a<st_d.naed();++a) {
-    const aed_type2& aed = st_d.get_aed(a);
+    const aed_type& aed = st_d.get_aed(a);
     size_t b=0;
     for (size_t i=0; i<st_d.dim(); ++i) {
       for (size_t ii=i; ii<st_d.dim(); ++ii) {
@@ -40,13 +40,13 @@ void DM_BF_Polynomial2::calc_phi_force_rows(phi_type &Phi,
                                             const StDescriptors &st_d)
 {
   for (size_t a=0; a<st.natoms(); ++a) {
-    const aed_type2& aedi = st_d.get_aed(a);
+    const aed_type& aedi = st_d.get_aed(a);
     for (size_t jj=0; jj<st_d.fd[a].size(); ++jj) {
       const size_t j=st.near_neigh_idx[a][jj];
       size_t aa = st.get_nn_iindex(a,j,jj);
       const fd_type &fdji = st_d.fd[j][aa];
       const fd_type &fdij = st_d.fd[a][jj];
-      const aed_type2& aedj = st_d.get_aed(j);
+      const aed_type& aedj = st_d.get_aed(j);
 
       for (size_t k=0; k<3; ++k) {
         size_t b=0;
@@ -72,14 +72,14 @@ void DM_BF_Polynomial2::calc_phi_stress_rows(phi_type &Phi,
   double V_inv = 1/st.get_volume();
   for (size_t a=0; a<st.natoms(); ++a) {
     const Vec3d &ri = st(a).position;
-    const aed_type2& aedi = st_d.get_aed(a);
+    const aed_type& aedi = st_d.get_aed(a);
     for (size_t jj=0; jj<st_d.fd[a].size(); ++jj) {
       const size_t j=st.near_neigh_idx[a][jj];
       size_t aa = st.get_nn_iindex(a,j,jj);
       const fd_type &fdji = st_d.fd[j][aa];
       const fd_type &fdij = st_d.fd[a][jj];
       const Vec3d &rj = st.nn_pos(a,jj);
-      const aed_type2& aedj = st_d.get_aed(j);
+      const aed_type& aedj = st_d.get_aed(j);
       size_t mn=0;
       for (size_t x=0; x<3; ++x) {
         for (size_t y=x; y<3; ++y) {
diff --git a/src/dm_kern_base.cpp b/src/dm_kern_base.cpp
index 9e569e0..f78c71e 100644
--- a/src/dm_kern_base.cpp
+++ b/src/dm_kern_base.cpp
@@ -27,13 +27,13 @@ void  DM_Kern_Base::calc_phi_energy_row(phi_type &Phi, size_t &row, const double
 void  DM_Kern_Base::calc_phi_force_rows(phi_type &Phi, size_t &row, const double fac,
                                         const Structure &st, const StDescriptors &st_d) {
   for (size_t i=0; i<st.natoms(); ++i) {
-    const aed_type2& aedi = st_d.get_aed(i);
+    const aed_type& aedi = st_d.get_aed(i);
     for (size_t jj=0; jj<st_d.fd[i].size(); ++jj) {
       size_t j=st.near_neigh_idx[i][jj];
       size_t ii = st.get_nn_iindex(i,j,jj);
       const fd_type &fdji = st_d.fd[j][ii];
       const fd_type &fdij = st_d.fd[i][jj];
-      const aed_type2& aedj = st_d.get_aed(j);
+      const aed_type& aedj = st_d.get_aed(j);
       for (size_t b=0; b<basis.cols(); ++b) {
         for (size_t k=0; k<3; ++k) {
           Phi(row+k,b) -= fac*((*this).prime(basis.col(b), aedi,fdij(k)) -
@@ -50,13 +50,13 @@ void  DM_Kern_Base::calc_phi_stress_rows(phi_type &Phi, size_t &row, const doubl
   double V_inv = 1/st.get_volume();
   for (size_t i=0; i<st.natoms(); ++i) {
     const Vec3d &ri = st(i).position;
-    const aed_type2& aedi = st_d.get_aed(i);
+    const aed_type& aedi = st_d.get_aed(i);
     for (size_t jj=0; jj<st_d.fd[i].size(); ++jj) {
       size_t j=st.near_neigh_idx[i][jj];
       size_t ii = st.get_nn_iindex(i,j,jj);
       const fd_type &fdji = st_d.fd[j][ii];
       const fd_type &fdij = st_d.fd[i][jj];
-      const aed_type2& aedj = st_d.get_aed(j);
+      const aed_type& aedj = st_d.get_aed(j);
       const Vec3d &rj = st.nn_pos(i,jj);
       size_t mn=0;
       for (size_t x=0; x<3; ++x) {
diff --git a/src/dm_kern_linear.cpp b/src/dm_kern_linear.cpp
index df6126a..84c484b 100644
--- a/src/dm_kern_linear.cpp
+++ b/src/dm_kern_linear.cpp
@@ -19,7 +19,7 @@ void DM_Kern_Linear::calc_phi_energy_row(phi_type &Phi, size_t &row,
                                          const double fac, const Structure &, const StDescriptors &st_d)
 {
   for (size_t a=0; a<st_d.naed();++a) {
-    const aed_type2 &aed = st_d.get_aed(a);  // TODO
+    const aed_type &aed = st_d.get_aed(a);  // TODO
     for (size_t j=0; j<aed.size(); ++j) {
       Phi(row,j)+=aed[j]*fac;
     }
@@ -34,7 +34,7 @@ void DM_Kern_Linear::calc_phi_force_rows(phi_type &Phi, size_t &row,
       const size_t j=st.near_neigh_idx[a][jj];
       const size_t aa = st.get_nn_iindex(a,j,jj);
       for (size_t k=0; k<3; ++k) {
-        aed_type2 temp = (st_d.fd[a][jj](k)-
+        aed_type temp = (st_d.fd[a][jj](k)-
           st_d.fd[j][aa](k))*fac;
         for (size_t d=0; d<temp.size(); ++d) {
           Phi(row+k,d) -= temp[d];
@@ -60,7 +60,7 @@ void DM_Kern_Linear::calc_phi_stress_rows(phi_type &Phi, size_t &row,
       size_t mn=0;
       for (size_t x=0; x<3; ++x) {
         for (size_t y=x; y<3; ++y) {
-          aed_type2 temp = V_inv*(fdij(y)-fdji(y))*0.5*fac[mn]*(ri(x)-rj(x));
+          aed_type temp = V_inv*(fdij(y)-fdji(y))*0.5*fac[mn]*(ri(x)-rj(x));
           for (size_t d=0; d<temp.size(); ++d) {
             Phi(row+mn,d) += temp[d];
           }
diff --git a/src/m_tadah_base.cpp b/src/m_tadah_base.cpp
index ca18ccb..2efa8f1 100644
--- a/src/m_tadah_base.cpp
+++ b/src/m_tadah_base.cpp
@@ -9,8 +9,8 @@ fpredict(const size_t a, force_type &v,
     const size_t aa = st.get_nn_iindex(a,j,jj);
     const fd_type &fdji = std.fd[j][aa];
     const fd_type &fdij = std.fd[a][jj];
-    const aed_type2 &aedi = std.get_aed(a);
-    const aed_type2 &aedj = std.get_aed(j);
+    const aed_type &aedi = std.get_aed(a);
+    const aed_type &aedj = std.get_aed(j);
     v += fpredict(fdij,aedi);
     v -= fpredict(fdji,aedj);
   }
@@ -111,14 +111,14 @@ spredict(const size_t a, stress_type &s,
 {
   double V_inv = 1/st.get_volume();
   const Vec3d &ri = st.atoms[a].position;
-  const aed_type2 &aedi = std.get_aed(a);
+  const aed_type &aedi = std.get_aed(a);
   for (size_t jj=0; jj<st.nn_size(a); ++jj) {
     size_t j=st.near_neigh_idx[a][jj];
     const size_t aa = st.get_nn_iindex(a,j,jj);
     const Vec3d &rj = st.nn_pos(a,jj);
     const fd_type &fdij = std.fd[a][jj];
     const fd_type &fdji = std.fd[j][aa];
-    const aed_type2 &aedj = std.get_aed(j);
+    const aed_type &aedj = std.get_aed(j);
     const force_type fij = fpredict(fdij,aedi);
     const force_type fji = fpredict(fdji,aedj);
 
@@ -142,14 +142,14 @@ stress_force_predict(const StDescriptors &std, Structure &st_)
   for (size_t a=0; a<st_.natoms(); ++a) {
     force_type v;
     const Vec3d &ri = st_.atoms[a].position;
-    const aed_type2 &aedi = std.get_aed(a);
+    const aed_type &aedi = std.get_aed(a);
     for (size_t jj=0; jj<st_.nn_size(a); ++jj) {
       size_t j=st_.near_neigh_idx[a][jj];
       const size_t aa = st_.get_nn_iindex(a,j,jj);
       const Vec3d &rj = st_.nn_pos(a,jj);
       const fd_type &fdij = std.fd[a][jj];
       const fd_type &fdji = std.fd[j][aa];
-      const aed_type2 &aedj = std.get_aed(j);
+      const aed_type &aedj = std.get_aed(j);
       const force_type fij = fpredict(fdij,aedi);
       const force_type fji = fpredict(fdji,aedj);
 
diff --git a/src/st_descriptors.cpp b/src/st_descriptors.cpp
index e7a2840..a0407e8 100644
--- a/src/st_descriptors.cpp
+++ b/src/st_descriptors.cpp
@@ -31,11 +31,11 @@ StDescriptors::StDescriptors(const Structure &s, const Config &c):
 }
 StDescriptors::StDescriptors() {}
 
-aed_type2 & StDescriptors::get_aed(const size_t i) {
+aed_type & StDescriptors::get_aed(const size_t i) {
     return aeds.col(i);
 }
 
-const aed_type2 &StDescriptors::get_aed(const size_t i) const {
+const aed_type &StDescriptors::get_aed(const size_t i) const {
     return aeds.col(i);
 }
 rho_type& StDescriptors::get_rho(const size_t i) {
-- 
GitLab


From 0871db3a18a2a2b3b5c4a203be7e69056cb08cb7 Mon Sep 17 00:00:00 2001
From: mkirsz <s1351949@sms.ed.ac.uk>
Date: Wed, 18 Dec 2024 16:35:48 +0000
Subject: [PATCH 12/15] Zi Zj for all calcs

---
 include/tadah/mlip/descriptors_calc.hpp | 19 +++++++++++--------
 1 file changed, 11 insertions(+), 8 deletions(-)

diff --git a/include/tadah/mlip/descriptors_calc.hpp b/include/tadah/mlip/descriptors_calc.hpp
index 01ae99d..f9659b4 100644
--- a/include/tadah/mlip/descriptors_calc.hpp
+++ b/include/tadah/mlip/descriptors_calc.hpp
@@ -168,6 +168,7 @@ void DescriptorsCalc<D2,D3,DM,C2,C3,CM>::calc_rho(const Structure &st, StDescrip
   for (size_t i=0; i<st.natoms(); ++i) {
     const Atom &a1 = st(i);
 
+    int Zi = a1.Z;
     for (size_t jj=0; jj<st.nn_size(i); ++jj) {
       const Vec3d &a2pos = st.nn_pos(i,jj);
       //delij = a1.position - a2pos;
@@ -180,7 +181,7 @@ void DescriptorsCalc<D2,D3,DM,C2,C3,CM>::calc_rho(const Structure &st, StDescrip
       if (rij_sq > rcut_mb_sq) continue;
       int Zj = st.near_neigh_atoms[i][jj].Z;
       double rij = sqrt(rij_sq);
-      dm.calc_rho(Zj,rij,rij_sq,delij,st_d.get_rho(i));
+      dm.calc_rho(Zi,Zj,rij,rij_sq,delij,st_d.get_rho(i));
     }
   }
 }
@@ -240,6 +241,7 @@ void DescriptorsCalc<D2,D3,DM,C2,C3,CM>::calc(const Structure &st, StDescriptors
     const Atom &a1 = st(i);
     aed_type &aed = st_d.get_aed(i);
 
+    int Zi = a1.Z;
     for (size_t jj=0; jj<st.nn_size(i); ++jj) {
       const Vec3d &a2pos = st.nn_pos(i,jj);
 
@@ -258,7 +260,7 @@ void DescriptorsCalc<D2,D3,DM,C2,C3,CM>::calc(const Structure &st, StDescriptors
       if (use_force || use_stress) {
         fd_type &fd_ij = st_d.fd[i][jj];
         if (rij_sq <= rcut_2b_sq && init2b) {
-          d2.calc_all(Zj,rij,rij_sq,aed,fd_ij,0.5);
+          d2.calc_all(Zi,Zj,rij,rij_sq,aed,fd_ij,0.5);
           // Two-body descriptor calculates x-direction only - fd_ij(n,0)
           // so we have to copy x-dir to y- and z-dir
           // and scale them by the unit directional vector delij/rij.
@@ -272,12 +274,12 @@ void DescriptorsCalc<D2,D3,DM,C2,C3,CM>::calc(const Structure &st, StDescriptors
         // CALCULATE MANY-BODY TERM
         if (rij_sq <= rcut_mb_sq && initmb) {
           rho_type& rhoi = st_d.get_rho(i);
-          dm.calc_dXijdri(Zj,rij,rij_sq,delij,rhoi,fd_ij);
+          dm.calc_dXijdri(Zi,Zj,rij,rij_sq,delij,rhoi,fd_ij);
         }
       }
       else {
         if (rij_sq <= rcut_2b_sq && init2b) {
-          d2.calc_aed(Zj,rij,rij_sq,aed,0.5);
+          d2.calc_aed(Zi, Zj,rij,rij_sq,aed,0.5);
         }
       }
 
@@ -342,6 +344,7 @@ void DescriptorsCalc<D2,D3,DM,C2,C3,CM>::calc_dimer(const Structure &st, StDescr
   // TODO weighting factors
   // For now assume all are the same type
   //int Zj = st.near_neigh_atoms[0][0].Z;
+  int Zi = 1;
   int Zj = 1;
 
   // map of atom label and distances
@@ -388,13 +391,13 @@ void DescriptorsCalc<D2,D3,DM,C2,C3,CM>::calc_dimer(const Structure &st, StDescr
   size_t N = bond_bool ? 6 : 4;
   for (size_t n=0; n<N; ++n) {
     if (r_sq[n] <= rcut_mb_sq && initmb) {
-      dm.calc_rho(Zj, r[n],r_sq[n],delM.row(n),st_d.get_rho(idx[n].first));
-      dm.calc_rho(Zj, r[n],r_sq[n],-delM.row(n),st_d.get_rho(idx[n].second));
+      dm.calc_rho(Zi,Zj, r[n],r_sq[n],delM.row(n),st_d.get_rho(idx[n].first));
+      dm.calc_rho(Zi,Zj, r[n],r_sq[n],-delM.row(n),st_d.get_rho(idx[n].second));
     }
     // Do not compute 2b term between bonded atoms
     if (r_sq[n] <= rcut_2b_sq && init2b) {
-      d2.calc_aed(Zj,r[n],r_sq[n],st_d.get_aed(idx[n].first));
-      d2.calc_aed(Zj,r[n],r_sq[n],st_d.get_aed(idx[n].second));
+      d2.calc_aed(Zi, Zj,r[n],r_sq[n],st_d.get_aed(idx[n].first));
+      d2.calc_aed(Zi, Zj,r[n],r_sq[n],st_d.get_aed(idx[n].second));
     }
   }
 
-- 
GitLab


From adeb4277a79745ef9340127e173e1ceb40e973ab Mon Sep 17 00:00:00 2001
From: mkirsz <s1351949@sms.ed.ac.uk>
Date: Sat, 21 Dec 2024 17:01:30 +0000
Subject: [PATCH 13/15] one less computation

---
 src/m_tadah_base.cpp | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/src/m_tadah_base.cpp b/src/m_tadah_base.cpp
index 2efa8f1..bdba357 100644
--- a/src/m_tadah_base.cpp
+++ b/src/m_tadah_base.cpp
@@ -153,8 +153,8 @@ stress_force_predict(const StDescriptors &std, Structure &st_)
       const force_type fij = fpredict(fdij,aedi);
       const force_type fji = fpredict(fdji,aedj);
 
-      v += fpredict(fdij,aedi);
-      v -= fpredict(fdji,aedj);
+      v += fij;
+      v -= fji;
 
       for (size_t x=0; x<3; ++x) {
         for (size_t y=x; y<3; ++y) {
-- 
GitLab


From 8c978f4468425df8a456ecc9331ea8fdf5225d62 Mon Sep 17 00:00:00 2001
From: mkirsz <s1351949@sms.ed.ac.uk>
Date: Thu, 26 Dec 2024 22:58:58 +0000
Subject: [PATCH 14/15] Fix bug when reading stress with vaspreader. Added T to
 structure and adjusted readers/writers accordingly

---
 .../mlip/dataset_readers/vasp_outcar_reader.h |  1 +
 include/tadah/mlip/structure.h                |  6 +++++
 src/castep_castep_reader.cpp                  | 10 +++++++-
 src/castep_md_reader.cpp                      |  9 ++++----
 src/dataset_reader_selector.cpp               | 11 ++++-----
 src/structure.cpp                             | 15 +++++++++---
 src/vasp_outcar_reader.cpp                    | 23 +++++++++++++++++--
 src/vasp_vasprun_reader.cpp                   |  6 ++++-
 8 files changed, 64 insertions(+), 17 deletions(-)

diff --git a/include/tadah/mlip/dataset_readers/vasp_outcar_reader.h b/include/tadah/mlip/dataset_readers/vasp_outcar_reader.h
index b5bc064..34c861a 100644
--- a/include/tadah/mlip/dataset_readers/vasp_outcar_reader.h
+++ b/include/tadah/mlip/dataset_readers/vasp_outcar_reader.h
@@ -73,6 +73,7 @@ public:
 
 private:
   std::string raw_data_;  // Stores raw file data
+  std::string filename_;  // Stores raw file data
   double s_conv = 6.241509074e-4; // kbar -> eV/A^3
 };
 
diff --git a/include/tadah/mlip/structure.h b/include/tadah/mlip/structure.h
index 9648e65..306ca2c 100644
--- a/include/tadah/mlip/structure.h
+++ b/include/tadah/mlip/structure.h
@@ -62,6 +62,12 @@ struct Structure {
   /** Container for atoms which belong to this structure */
   std::vector<Atom> atoms;
 
+  /** Temperature of this structure.
+   *
+   * Default is 0.0
+   */
+  double T=0;
+
   /**
    * Container for nearest neighbour atoms for every atom in the structure.
    */
diff --git a/src/castep_castep_reader.cpp b/src/castep_castep_reader.cpp
index 89fed3a..b27d0bd 100644
--- a/src/castep_castep_reader.cpp
+++ b/src/castep_castep_reader.cpp
@@ -245,13 +245,21 @@ void CastepCastepReader::parse_data() {
     }
 
     else if (line.find("Potential Energy:") != std::string::npos) {
-      // MD: end of iteration
       std::istringstream iss(line);
       std::string tmp;
       if (!(iss >> tmp >> tmp >> tmp >> s.energy)) {
         std::cerr << "Warning, file" << filename << " line: " << counter << std::endl;
         std::cerr << "Warning: Unexpected end of data when reading total energy" << std::endl;
       }
+    }
+    else if (line.find("Temperature:") != std::string::npos) {
+      // MD: end of iteration
+      std::istringstream iss(line);
+      std::string tmp;
+      if (!(iss >> tmp >> tmp >> s.T)) {
+        std::cerr << "Warning, file" << filename << " line: " << counter << std::endl;
+        std::cerr << "Warning: Unexpected end of data when reading temperature" << std::endl;
+      }
 
       if (!label.size())label = "CASTEP MD, const. volume: false, step: 0"; // the last option
       s.label = label;
diff --git a/src/castep_md_reader.cpp b/src/castep_md_reader.cpp
index a592172..d75b260 100644
--- a/src/castep_md_reader.cpp
+++ b/src/castep_md_reader.cpp
@@ -46,7 +46,7 @@ void CastepMDReader::postproc_structure(Structure &s) {
   // finish conversion
   s.cell *= d_conv;
   s.stress *= s_conv;
-  // T /= k_b;
+  s.T = T/k_b;
 
   // add to database
   stdb.add(s);
@@ -101,17 +101,18 @@ void CastepMDReader::parse_data() {
   int S_flag=0;
   bool R_flag=false;
   bool F_flag=false;
+  bool T_flag=false;
   bool complete_structure=false;
 
   while (std::getline(stream, line)) {
     if (ends_with(line,"<-- T")) {
-      if (/* T_flag || */ complete_structure) {
+      if (T_flag || complete_structure) {
         error=true;
         continue;
       }
       std::istringstream iss(line);
       iss >> T;
-      //T_flag=true;
+      T_flag=true;
     }
     else if (ends_with(line,"<-- E")) {
       if (/* T_flag || */ E_flag || complete_structure) {
@@ -187,7 +188,7 @@ void CastepMDReader::parse_data() {
         postproc_structure(s);
 
       error=false;
-      //T_flag=false;
+      T_flag=false;
       E_flag=false;
       H_flag=0;
       S_flag=0;
diff --git a/src/dataset_reader_selector.cpp b/src/dataset_reader_selector.cpp
index bdc6d20..ef5d8b1 100644
--- a/src/dataset_reader_selector.cpp
+++ b/src/dataset_reader_selector.cpp
@@ -41,12 +41,11 @@ std::string DatasetReaderSelector::determine_file_type_by_content(const std::str
   std::string line;
   while (std::getline(file, line)) {
 
-    if (line.find("vasp") != std::string::npos) {
-      if (line.find("incar:") != std::string::npos || line.find("outcar") != std::string::npos) {
-        return "VASP.OUTCAR";
-      } else if (line.find("<modeling>") != std::string::npos || line.find("<calculation>") != std::string::npos) {
-        return "VASP.VASPRUN";
-      }
+    if (line.find("incar:") != std::string::npos ||  line.find("OUTCAR:") != std::string::npos ||
+      line.find("outcar") != std::string::npos || line.find("POTCAR:") != std::string::npos) {
+      return "VASP.OUTCAR";
+    } else if (line.find("<modeling>") != std::string::npos || line.find("<calculation>") != std::string::npos) {
+      return "VASP.VASPRUN";
     }
     else if (line.find("<-- c") != std::string::npos) {
       return "CASTEP.GEOM";
diff --git a/src/structure.cpp b/src/structure.cpp
index f5c6748..2d870bb 100644
--- a/src/structure.cpp
+++ b/src/structure.cpp
@@ -53,8 +53,17 @@ int Structure::read(std::ifstream &ifs) {
     stream.clear();
     stream.seekg(0, std::ios::beg);
     stream >> eweight >> fweight >> sweight;
-    // energy
-    ifs >> energy;
+    std::getline(ifs,line);
+    std::stringstream stream2(line);
+    // energy and T
+    stream2 >> energy;
+    if(stream2) stream2 >> T;
+  }
+  else if (count == 2) {
+    stream.clear();
+    stream.seekg(0, std::ios::beg);
+    // energy and T
+    stream >> energy >> T;
   }
   else {
     energy = std::stod(line);
@@ -262,7 +271,7 @@ void Structure::dump_to_file(std::ostream& file, size_t prec) const {
   file << label << std::endl;
   file << std::fixed << std::setprecision(prec);
   file << eweight << " " << fweight << " " << sweight << std::endl;
-  file << energy << std::endl;
+  file << energy << " " << T << std::endl;
 
   file
     << std::setw(prec+n) << cell(0,0) << " "
diff --git a/src/vasp_outcar_reader.cpp b/src/vasp_outcar_reader.cpp
index d5fccef..d1ead35 100644
--- a/src/vasp_outcar_reader.cpp
+++ b/src/vasp_outcar_reader.cpp
@@ -10,7 +10,7 @@
 VaspOutcarReader::VaspOutcarReader(StructureDB& db) : DatasetReader(db) {}
 
 VaspOutcarReader::VaspOutcarReader(StructureDB& db, const std::string& filename) 
-: DatasetReader(db, filename) {
+: DatasetReader(db, filename), filename_(filename) {
   read_data(filename);
 }
 
@@ -38,14 +38,21 @@ void VaspOutcarReader::parse_data() {
   size_t natoms;;
   bool stress_tensor_bool = false;
   bool complete_structure = false;
+  bool is_md = false;
 
   Structure s;
+  size_t counter=0;
 
   while (std::getline(stream, line)) {
-    if (line.find("VRHFIN") != std::string::npos) {
+    if (line.find("molecular dynamics") != std::string::npos) {
+      is_md = true;
+      s.label += "MD ";
+    }
+    else if (line.find("VRHFIN") != std::string::npos) {
       std::string type = line.substr(line.find("=") + 1);
       type = type.substr(0, type.find(":"));
       atom_types.push_back(type);
+      s.label += filename_ + " ";
     }
 
     else if (line.find("NIONS") != std::string::npos) {
@@ -164,8 +171,20 @@ void VaspOutcarReader::parse_data() {
       if (!(iss >> tmp >> tmp >> tmp >> s.energy)) {
         std::cerr << "Warning: Unexpected end of data when reading total energy" << std::endl;
       }
+      else if (!is_md) {
+        complete_structure = true;
+        s.label += "Structure " + std::to_string(++counter);
+      }
+    }
+    else if (line.find("EKIN_LAT=") != std::string::npos) {
+      std::istringstream iss(line);
+      std::string tmp;
+      if (!(iss >> tmp >> tmp >> tmp >> tmp >> tmp >> s.T)) {
+        std::cerr << "Warning: Unexpected end of data when reading temperature" << std::endl;
+      }
       else {
         complete_structure = true;
+        s.label += "Structure " + std::to_string(++counter);
       }
     }
 
diff --git a/src/vasp_vasprun_reader.cpp b/src/vasp_vasprun_reader.cpp
index cc8e662..c6917d3 100644
--- a/src/vasp_vasprun_reader.cpp
+++ b/src/vasp_vasprun_reader.cpp
@@ -25,6 +25,8 @@ void VaspVasprunReader::read_data(const std::string& filename) {
   } catch (std::exception &e) {
     std::cerr << "Error reading file: " << e.what() << std::endl;
   }
+
+  _s.label = filename + " ";
 }
 
 void VaspVasprunReader::parse_data() {
@@ -83,6 +85,7 @@ void VaspVasprunReader::extract_atom_types(rx::xml_node<> *root_node) {
 }
 
 void VaspVasprunReader::extract_calculations(rx::xml_node<> *root_node) {
+  size_t counter=0;
   for (auto calculation_node = root_node->first_node("calculation");
   calculation_node; calculation_node = calculation_node->next_sibling("calculation")) {
 
@@ -102,6 +105,7 @@ void VaspVasprunReader::extract_calculations(rx::xml_node<> *root_node) {
 
     extract_forces(calculation_node);
 
+    _s.label += "Structure " + std::to_string(++counter);
     stdb.add(_s);
     _s = Structure(); // reset
   }
@@ -140,12 +144,12 @@ void VaspVasprunReader::extract_stress_tensor(rx::xml_node<> *calculation_node)
           _s.stress(r, 0) = x;
           _s.stress(r, 1) = y;
           _s.stress(r, 2) = z;
-          _s.stress *= s_conv;
         } else {
           std::cerr << "Error parsing stress tensor components." << std::endl;
         }
         r++;
       }
+      _s.stress *= s_conv;
       break;
     }
     varray_node = varray_node->next_sibling("varray");
-- 
GitLab


From edf014d74ace834ce90f4a1aed26f0181f98b668 Mon Sep 17 00:00:00 2001
From: mkirsz <s1351949@sms.ed.ac.uk>
Date: Thu, 26 Dec 2024 23:26:45 +0000
Subject: [PATCH 15/15] Added density calculation

---
 include/tadah/mlip/structure.h | 3 +++
 src/structure.cpp              | 8 ++++++++
 2 files changed, 11 insertions(+)

diff --git a/include/tadah/mlip/structure.h b/include/tadah/mlip/structure.h
index 306ca2c..c36baf2 100644
--- a/include/tadah/mlip/structure.h
+++ b/include/tadah/mlip/structure.h
@@ -137,6 +137,9 @@ struct Structure {
   /** @return volume of this structure. */
   double get_volume() const;
 
+  /** @return density of this structure in g/cm^3 */
+  double get_density() const;
+
   /** @return virial pressure calculated from the stress tensor.
    *
    *  Units: energy/distance^3
diff --git a/src/structure.cpp b/src/structure.cpp
index 2d870bb..9db209c 100644
--- a/src/structure.cpp
+++ b/src/structure.cpp
@@ -173,6 +173,14 @@ size_t Structure::get_nn_iindex(const size_t i, const size_t j, const size_t jj)
 double Structure::get_volume() const {
   return cell.row(0)*(cell.row(1).cross(cell.row(2)));
 }
+double Structure::get_density() const {
+  double V = cell.row(0)*(cell.row(1).cross(cell.row(2)));
+  V*=1e-24; // convert to cm^3
+  double amu = 1.66053906660e-24; // g
+  double mass = 0;
+  for (const auto& a:atoms) mass += PeriodicTable::get_mass(a.Z);
+  return amu*mass/V;
+}
 
 double Structure::get_virial_pressure() const {
   return stress.trace()/get_volume()/3;
-- 
GitLab