From 874c2dfb5fe9f1e07c302a3c2345c98d3a7161f9 Mon Sep 17 00:00:00 2001
From: mkirsz <s1351949@sms.ed.ac.uk>
Date: Sun, 15 Dec 2024 01:32:50 +0000
Subject: [PATCH] Update

---
 include/tadah/mlip/descriptors_calc.h   |  3 --
 include/tadah/mlip/descriptors_calc.hpp | 47 +++++++------------------
 2 files changed, 12 insertions(+), 38 deletions(-)

diff --git a/include/tadah/mlip/descriptors_calc.h b/include/tadah/mlip/descriptors_calc.h
index daaf78c..14d1c4b 100644
--- a/include/tadah/mlip/descriptors_calc.h
+++ b/include/tadah/mlip/descriptors_calc.h
@@ -105,7 +105,6 @@ class DescriptorsCalc: public DC_Base {
         /** Calculate all descriptors in a StructureDB st_db */
         StDescriptorsDB calc(const StructureDB &st_db);
 
-
         /** Calculate density vector for the structure */
         void calc_rho(const Structure &st, StDescriptors &std);
 
@@ -122,8 +121,6 @@ class DescriptorsCalc: public DC_Base {
         void calc(const Structure &st, StDescriptors &std);
         void calc_dimer(const Structure &st, StDescriptors &std);
         void common_constructor();
-        double weights[119];  // ignore zero index; w[1]->H
-
 };
 
 #include "descriptors_calc.hpp"
diff --git a/include/tadah/mlip/descriptors_calc.hpp b/include/tadah/mlip/descriptors_calc.hpp
index 380401a..0d3813f 100644
--- a/include/tadah/mlip/descriptors_calc.hpp
+++ b/include/tadah/mlip/descriptors_calc.hpp
@@ -14,12 +14,6 @@ DescriptorsCalc<D2,D3,DM,C2,C3,CM>::DescriptorsCalc(Config &c, T1 &t1, T2 &t2, T
   d3(t2),
   dm(t3)
 {
-  for (size_t i=0; i<c.size("ATOMS"); ++i) {
-    std::string symbol = c.get<std::string>("ATOMS",i);
-    double wi = c.get<double>("WATOMS",i);
-    int Z = PeriodicTable::find_by_symbol(symbol).Z;
-    weights[Z]=wi;
-  }
   if (!config.exist("DSIZE"))
     common_constructor();
   else {
@@ -75,12 +69,6 @@ DescriptorsCalc<D2,D3,DM,C2,C3,CM>::DescriptorsCalc(Config &c, T1 &d2, T2 &d3, T
   d3(d3),
   dm(dm)
 {
-  for (size_t i=0; i<c.size("ATOMS"); ++i) {
-    std::string symbol = c.get<std::string>("ATOMS",i);
-    double wi = c.get<double>("WATOMS",i);
-    int Z = PeriodicTable::find_by_symbol(symbol).Z;
-    weights[Z]=wi;
-  }
   if (c.get<bool>("INIT2B")) {
     if (!config.exist("RCTYPE2B"))
       config.add("RCTYPE2B",c2.label());
@@ -170,10 +158,8 @@ void DescriptorsCalc<D2,D3,DM,C2,C3,CM>::calc_rho(const Structure &st, StDescrip
       double rij_sq = delij[0]*delij[0] + delij[1]*delij[1] + delij[2]*delij[2];
       if (rij_sq > rcut_mb_sq) continue;
       int Zj = st.near_neigh_atoms[i][jj].Z;
-      double wj = weights[Zj];
       double rij = sqrt(rij_sq);
-      double fc_ij = wj*cm.calc(rij);
-      dm.calc_rho(rij,rij_sq,fc_ij,delij,st_d.get_rho(i));
+      dm.calc_rho(Zj,rij,rij_sq,delij,st_d.get_rho(i));
     }
   }
 }
@@ -204,8 +190,8 @@ void DescriptorsCalc<D2,D3,DM,C2,C3,CM>::calc(const Structure &st, StDescriptors
   double rcut_2b_sq = 0.0;
   double rcut_mb_sq = 0.0;
 
-  if (init2b) rcut_2b_sq = pow(config.get<double>("RCUT2B"),2);
-  if (initmb) rcut_mb_sq = pow(config.get<double>("RCUTMB"),2);
+  if (init2b) rcut_2b_sq = pow(d2.get_rcut(),2);
+  if (initmb) rcut_mb_sq = pow(dm.get_rcut(),2);
 
   // zero all aeds and set bias
   for (size_t i=0; i<st.natoms(); ++i) {
@@ -244,7 +230,6 @@ void DescriptorsCalc<D2,D3,DM,C2,C3,CM>::calc(const Structure &st, StDescriptors
 
       if (rij_sq > rcut_max_sq) continue;
       int Zj = st.near_neigh_atoms[i][jj].Z;
-      double wj = weights[Zj];
       double rij = sqrt(rij_sq);
       double rij_inv = 1.0/rij;
 
@@ -252,9 +237,7 @@ void DescriptorsCalc<D2,D3,DM,C2,C3,CM>::calc(const Structure &st, StDescriptors
       if (use_force || use_stress) {
         fd_type &fd_ij = st_d.fd[i][jj];
         if (rij_sq <= rcut_2b_sq && init2b) {
-          double fc_ij = wj*c2.calc(rij);
-          double fcp_ij = wj*c2.calc_prime(rij);
-          d2.calc_all(rij,rij_sq,fc_ij,fcp_ij,aed,fd_ij);
+          d2.calc_all(Zj,rij,rij_sq,aed,fd_ij);
           // Two-body descriptor calculates x-direction only - fd_ij(n,0)
           // so we have to copy x-dir to y- and z-dir
           // and scale them by the unit directional vector delij/rij.
@@ -267,11 +250,8 @@ void DescriptorsCalc<D2,D3,DM,C2,C3,CM>::calc(const Structure &st, StDescriptors
         }
         // CALCULATE MANY-BODY TERM
         if (rij_sq <= rcut_mb_sq && initmb) {
-          double fc_ij = wj*cm.calc(rij);
-          double fcp_ij = wj*cm.calc_prime(rij);
           rho_type& rhoi = st_d.get_rho(i);
-          int mode = dm.calc_dXijdri(rij,rij_sq,delij,
-              fc_ij,fcp_ij,rhoi,fd_ij);
+          int mode = dm.calc_dXijdri(Zj,rij,rij_sq,delij,rhoi,fd_ij);
           if (mode==0) {
             // some dm compute x-dir only, similarly to d2 above
             for (size_t n=size2b+bias; n<size2b+sizemb+bias; ++n) {
@@ -287,8 +267,7 @@ void DescriptorsCalc<D2,D3,DM,C2,C3,CM>::calc(const Structure &st, StDescriptors
       }
       else {
         if (rij_sq <= rcut_2b_sq && init2b) {
-          double fc_ij = wj*c2.calc(rij);
-          d2.calc_aed(rij,rij_sq,fc_ij,aed);
+          d2.calc_aed(Zj,rij,rij_sq,aed);
         }
       }
 
@@ -332,8 +311,8 @@ void DescriptorsCalc<D2,D3,DM,C2,C3,CM>::calc_dimer(const Structure &st, StDescr
   double rcut_2b_sq = 0.0;
   double rcut_mb_sq = 0.0;
 
-  if (init2b) rcut_2b_sq = pow(config.get<double>("RCUT2B"),2);
-  if (initmb) rcut_mb_sq = pow(config.get<double>("RCUTMB"),2);
+  if (init2b) rcut_2b_sq = pow(d2.get_rcut(),2);
+  if (initmb) rcut_mb_sq = pow(dm.get_rcut(),2);
 
   // Max distance between CoM of two interacting molecules
   double rcut_com_sq = pow(config.get<double>("RCUTMAX")-r_b,2);
@@ -399,15 +378,13 @@ void DescriptorsCalc<D2,D3,DM,C2,C3,CM>::calc_dimer(const Structure &st, StDescr
   size_t N = bond_bool ? 6 : 4;
   for (size_t n=0; n<N; ++n) {
     if (r_sq[n] <= rcut_mb_sq && initmb) {
-      double fcmb = Zj*cm.calc(r[n]);
-      dm.calc_rho(r[n],r_sq[n],fcmb,delM.row(n),st_d.get_rho(idx[n].first));
-      dm.calc_rho(r[n],r_sq[n],fcmb,-delM.row(n),st_d.get_rho(idx[n].second));
+      dm.calc_rho(Zj, r[n],r_sq[n],delM.row(n),st_d.get_rho(idx[n].first));
+      dm.calc_rho(Zj, r[n],r_sq[n],-delM.row(n),st_d.get_rho(idx[n].second));
     }
     // Do not compute 2b term between bonded atoms
     if (r_sq[n] <= rcut_2b_sq && init2b) {
-      double fc2b = Zj*c2.calc(r[n]);
-      d2.calc_aed(r[n],r_sq[n],fc2b,st_d.get_aed(idx[n].first));
-      d2.calc_aed(r[n],r_sq[n],fc2b,st_d.get_aed(idx[n].second));
+      d2.calc_aed(Zj,r[n],r_sq[n],st_d.get_aed(idx[n].first));
+      d2.calc_aed(Zj,r[n],r_sq[n],st_d.get_aed(idx[n].second));
     }
   }
 
-- 
GitLab