From 0871db3a18a2a2b3b5c4a203be7e69056cb08cb7 Mon Sep 17 00:00:00 2001
From: mkirsz <s1351949@sms.ed.ac.uk>
Date: Wed, 18 Dec 2024 16:35:48 +0000
Subject: [PATCH] Zi Zj for all calcs

---
 include/tadah/mlip/descriptors_calc.hpp | 19 +++++++++++--------
 1 file changed, 11 insertions(+), 8 deletions(-)

diff --git a/include/tadah/mlip/descriptors_calc.hpp b/include/tadah/mlip/descriptors_calc.hpp
index 01ae99d..f9659b4 100644
--- a/include/tadah/mlip/descriptors_calc.hpp
+++ b/include/tadah/mlip/descriptors_calc.hpp
@@ -168,6 +168,7 @@ void DescriptorsCalc<D2,D3,DM,C2,C3,CM>::calc_rho(const Structure &st, StDescrip
   for (size_t i=0; i<st.natoms(); ++i) {
     const Atom &a1 = st(i);
 
+    int Zi = a1.Z;
     for (size_t jj=0; jj<st.nn_size(i); ++jj) {
       const Vec3d &a2pos = st.nn_pos(i,jj);
       //delij = a1.position - a2pos;
@@ -180,7 +181,7 @@ void DescriptorsCalc<D2,D3,DM,C2,C3,CM>::calc_rho(const Structure &st, StDescrip
       if (rij_sq > rcut_mb_sq) continue;
       int Zj = st.near_neigh_atoms[i][jj].Z;
       double rij = sqrt(rij_sq);
-      dm.calc_rho(Zj,rij,rij_sq,delij,st_d.get_rho(i));
+      dm.calc_rho(Zi,Zj,rij,rij_sq,delij,st_d.get_rho(i));
     }
   }
 }
@@ -240,6 +241,7 @@ void DescriptorsCalc<D2,D3,DM,C2,C3,CM>::calc(const Structure &st, StDescriptors
     const Atom &a1 = st(i);
     aed_type &aed = st_d.get_aed(i);
 
+    int Zi = a1.Z;
     for (size_t jj=0; jj<st.nn_size(i); ++jj) {
       const Vec3d &a2pos = st.nn_pos(i,jj);
 
@@ -258,7 +260,7 @@ void DescriptorsCalc<D2,D3,DM,C2,C3,CM>::calc(const Structure &st, StDescriptors
       if (use_force || use_stress) {
         fd_type &fd_ij = st_d.fd[i][jj];
         if (rij_sq <= rcut_2b_sq && init2b) {
-          d2.calc_all(Zj,rij,rij_sq,aed,fd_ij,0.5);
+          d2.calc_all(Zi,Zj,rij,rij_sq,aed,fd_ij,0.5);
           // Two-body descriptor calculates x-direction only - fd_ij(n,0)
           // so we have to copy x-dir to y- and z-dir
           // and scale them by the unit directional vector delij/rij.
@@ -272,12 +274,12 @@ void DescriptorsCalc<D2,D3,DM,C2,C3,CM>::calc(const Structure &st, StDescriptors
         // CALCULATE MANY-BODY TERM
         if (rij_sq <= rcut_mb_sq && initmb) {
           rho_type& rhoi = st_d.get_rho(i);
-          dm.calc_dXijdri(Zj,rij,rij_sq,delij,rhoi,fd_ij);
+          dm.calc_dXijdri(Zi,Zj,rij,rij_sq,delij,rhoi,fd_ij);
         }
       }
       else {
         if (rij_sq <= rcut_2b_sq && init2b) {
-          d2.calc_aed(Zj,rij,rij_sq,aed,0.5);
+          d2.calc_aed(Zi, Zj,rij,rij_sq,aed,0.5);
         }
       }
 
@@ -342,6 +344,7 @@ void DescriptorsCalc<D2,D3,DM,C2,C3,CM>::calc_dimer(const Structure &st, StDescr
   // TODO weighting factors
   // For now assume all are the same type
   //int Zj = st.near_neigh_atoms[0][0].Z;
+  int Zi = 1;
   int Zj = 1;
 
   // map of atom label and distances
@@ -388,13 +391,13 @@ void DescriptorsCalc<D2,D3,DM,C2,C3,CM>::calc_dimer(const Structure &st, StDescr
   size_t N = bond_bool ? 6 : 4;
   for (size_t n=0; n<N; ++n) {
     if (r_sq[n] <= rcut_mb_sq && initmb) {
-      dm.calc_rho(Zj, r[n],r_sq[n],delM.row(n),st_d.get_rho(idx[n].first));
-      dm.calc_rho(Zj, r[n],r_sq[n],-delM.row(n),st_d.get_rho(idx[n].second));
+      dm.calc_rho(Zi,Zj, r[n],r_sq[n],delM.row(n),st_d.get_rho(idx[n].first));
+      dm.calc_rho(Zi,Zj, r[n],r_sq[n],-delM.row(n),st_d.get_rho(idx[n].second));
     }
     // Do not compute 2b term between bonded atoms
     if (r_sq[n] <= rcut_2b_sq && init2b) {
-      d2.calc_aed(Zj,r[n],r_sq[n],st_d.get_aed(idx[n].first));
-      d2.calc_aed(Zj,r[n],r_sq[n],st_d.get_aed(idx[n].second));
+      d2.calc_aed(Zi, Zj,r[n],r_sq[n],st_d.get_aed(idx[n].first));
+      d2.calc_aed(Zi, Zj,r[n],r_sq[n],st_d.get_aed(idx[n].second));
     }
   }
 
-- 
GitLab