From 5fc6d58b28d659d31b683106f470e2957ef7b59d Mon Sep 17 00:00:00 2001
From: mkirsz <s1351949@sms.ed.ac.uk>
Date: Wed, 18 Dec 2024 16:36:19 +0000
Subject: [PATCH] Zi Zj for all calcs

---
 ML-TADAH/pair_tadah.cpp | 31 ++++++++++++++++---------------
 1 file changed, 16 insertions(+), 15 deletions(-)

diff --git a/ML-TADAH/pair_tadah.cpp b/ML-TADAH/pair_tadah.cpp
index ec3e1ec..db13f42 100644
--- a/ML-TADAH/pair_tadah.cpp
+++ b/ML-TADAH/pair_tadah.cpp
@@ -196,7 +196,7 @@ void PairTadah::compute_2b_mb_half(int eflag, int vflag)
       int Zj = lt->Z[type[j]-1];
 
       if (lt->S.d2b->get_rcut() > rij) {
-        lt->S.d2b->calc_aed(Zj,rij,rij_sq,lt->aeds.col(i),0.5);
+        lt->S.d2b->calc_aed(Zi,Zj,rij,rij_sq,lt->aeds.col(i),0.5);
 
         if (newton_pair || j < nlocal) {
           lt->S.d2b->calc_aed(Zi,rij,rij_sq,lt->aeds.col(j),0.5);
@@ -204,10 +204,10 @@ void PairTadah::compute_2b_mb_half(int eflag, int vflag)
       }
 
       if (lt->S.dmb->get_rcut() > rij) {
-        lt->S.dmb->calc_rho(Zj,rij,rij_sq,delij,lt->rhos.col(i));
+        lt->S.dmb->calc_rho(Zi,Zj,rij,rij_sq,delij,lt->rhos.col(i));
 
         if (newton_pair || j < nlocal) {
-          lt->S.dmb->calc_rho(Zi,rij,rij_sq,-delij,lt->rhos.col(j));
+          lt->S.dmb->calc_rho(Zi,Zj,rij,rij_sq,-delij,lt->rhos.col(j));
         }
       }
     }
@@ -272,7 +272,7 @@ void PairTadah::compute_2b_mb_half(int eflag, int vflag)
       int Zj = lt->Z[type[j]-1];
 
       if (lt->S.d2b->get_rcut() > rij) {
-        lt->S.d2b->calc_dXijdri(Zj,rij,rij_sq,fd);
+        lt->S.d2b->calc_dXijdri(Zi,Zj,rij,rij_sq,fd);
       }
 
       if (lt->S.dmb->get_rcut() > rij) {
@@ -373,11 +373,11 @@ void PairTadah::compute_2b_mb_full(int eflag, int vflag)
       double Zj = lt->Z[type[j]-1];
 
       if (lt->S.d2b->get_rcut() > rij) {
-        lt->S.d2b->calc_aed(Zj,rij,rij_sq,lt->aeds.col(i),0.5);
+        lt->S.d2b->calc_aed(Zi,Zj,rij,rij_sq,lt->aeds.col(i),0.5);
       }
 
       if (lt->S.dmb->get_rcut() > rij) {
-        lt->S.dmb->calc_rho(Zj,rij,rij_sq,delij,lt->rhos.col(i));
+        lt->S.dmb->calc_rho(Zi,Zj,rij,rij_sq,delij,lt->rhos.col(i));
       }
     }
   }
@@ -419,6 +419,7 @@ void PairTadah::compute_2b_mb_full(int eflag, int vflag)
 
     jlist = firstneigh[i];
     jnum = numneigh[i];
+    int Zi = lt->Z[type[i]-1];
     for (jj = 0; jj < jnum; jj++) {
       fd.set_zero();
       j = jlist[jj];
@@ -432,11 +433,11 @@ void PairTadah::compute_2b_mb_full(int eflag, int vflag)
       int Zj = lt->Z[type[j]-1];
 
       if (lt->S.d2b->get_rcut() > rij) {
-        lt->S.d2b->calc_dXijdri(Zj,rij,rij_sq,fd,0.5);
+        lt->S.d2b->calc_dXijdri(Zi,Zj,rij,rij_sq,fd,0.5);
       }
 
       if (lt->S.dmb->get_rcut() > rij) {
-        lt->S.dmb->calc_dXijdri(Zj,rij,rij_sq,delij,lt->rhos.col(i),fd);
+        lt->S.dmb->calc_dXijdri(Zi,Zj,rij,rij_sq,delij,lt->rhos.col(i),fd);
       }
 
       // copy 2b descriptors first and
@@ -679,12 +680,12 @@ void PairTadah::compute_dimers(int eflag, int vflag)
       // Compute densities for every atom.
       for (int n=Nstart; n<6; ++n) {
         if (lt->initmb && lt->S.dmb->get_rcut() >= r[n]) {
-          lt->S.dmb->calc_rho(1,r[n],r_sq[n],delM.row(n),lt->rhos.col(midx[n].first));
-          lt->S.dmb->calc_rho(1,r[n],r_sq[n],-delM.row(n),lt->rhos.col(midx[n].second));
+          lt->S.dmb->calc_rho(1,1,r[n],r_sq[n],delM.row(n),lt->rhos.col(midx[n].first));
+          lt->S.dmb->calc_rho(1,1,r[n],r_sq[n],-delM.row(n),lt->rhos.col(midx[n].second));
         }
         if (lt->init2b && lt->S.d2b->get_rcut() >= r[n]) {
-          lt->S.d2b->calc_aed(1,r[n],r_sq[n],lt->aeds.col(midx[n].first));
-          lt->S.d2b->calc_aed(1,r[n],r_sq[n],lt->aeds.col(midx[n].second));
+          lt->S.d2b->calc_aed(1,1,r[n],r_sq[n],lt->aeds.col(midx[n].first));
+          lt->S.d2b->calc_aed(1,1,r[n],r_sq[n],lt->aeds.col(midx[n].second));
         }
       }
 
@@ -712,13 +713,13 @@ void PairTadah::compute_dimers(int eflag, int vflag)
         int I=IJ[midx[n].first];
         int J=IJ[midx[n].second];
         if (lt->init2b && lt->S.d2b->get_rcut() > r[n]) {
-          lt->S.d2b->calc_dXijdri(1,r[n],r_sq[n],fdIJ);
+          lt->S.d2b->calc_dXijdri(1,1,r[n],r_sq[n],fdIJ);
         }
 
         if (lt->initmb && lt->S.dmb->get_rcut() > r[n]) {
-          lt->S.dmb->calc_dXijdri(1,r[n],r_sq[n],delM.row(n),
+          lt->S.dmb->calc_dXijdri(1,1,r[n],r_sq[n],delM.row(n),
                                   lt->rhos.col(midx[n].first), fdIJ);
-          lt->S.dmb->calc_dXijdri(1,r[n],r_sq[n],-delM.row(n),
+          lt->S.dmb->calc_dXijdri(1,1,r[n],r_sq[n],-delM.row(n),
                                   lt->rhos.col(midx[n].second), fdJI);
         }
 
-- 
GitLab