From e5df1fcb958411676abffa8c5bc3ff2a6f8ba55a Mon Sep 17 00:00:00 2001
From: mkirsz <s1351949@sms.ed.ac.uk>
Date: Sun, 22 Dec 2024 00:46:46 +0000
Subject: [PATCH] added bitmap for checking is calc initalised for atoms, fixed
 return to continue as it was a clear bug

---
 include/tadah/models/descriptors/d_base.h | 17 ++++++++++++-----
 src/d2_mjoin.cpp                          |  9 ++++++---
 src/d_base.cpp                            | 17 +++++++++++++++++
 src/dm_mjoin.cpp                          | 23 +++++++++++++----------
 4 files changed, 48 insertions(+), 18 deletions(-)

diff --git a/include/tadah/models/descriptors/d_base.h b/include/tadah/models/descriptors/d_base.h
index 183a58b..77f215b 100644
--- a/include/tadah/models/descriptors/d_base.h
+++ b/include/tadah/models/descriptors/d_base.h
@@ -8,6 +8,7 @@
 #include <tadah/core/registry.h>
 #include <tadah/models/cut_all.h>
 #include <vector>
+#include <bitset>
 
 /** \brief Base class for all descriptor types.
  *
@@ -30,9 +31,6 @@ public:
   // if key is not used, returns -1
   int get_arg_pos(const std::string &key) const;
 
-  /** \brief Return dimension of the descriptor.
-  */
-  virtual size_t size() { return s; };
 
   /** \brief Return label of this descriptor.
    */
@@ -60,10 +58,19 @@ public:
 
   virtual void set_fcut(Cut_Base* cut, bool manage_memory);
 
-  double get_rcut() { 
-    return fcut->get_rcut(); }
+  double get_rcut();
+
+  /** \brief Return dimension of the descriptor.
+  */
+  virtual size_t size();
+
+  virtual bool is_init_for_atom(int Z);
+  virtual bool is_init_for_atoms(int Zi, int Zj);
+  virtual void init_for_atom(int Zi);
+  virtual void uninit_for_atom(int Zi);
 
 private:
   bool manage_memory=false; // who owns fcut
+  std::bitset<119> init_for_atoms;
 };
 #endif
diff --git a/src/d2_mjoin.cpp b/src/d2_mjoin.cpp
index cbd9910..44edea5 100644
--- a/src/d2_mjoin.cpp
+++ b/src/d2_mjoin.cpp
@@ -22,7 +22,8 @@ D2_mJoin::D2_mJoin(Config &c) : D2_Base(c) {
 void D2_mJoin::calc_aed(const int Zi, const int Zj, const double rij,
                         const double rij_sq, aed_type &aed, const double scale) {
   for (auto d : ds) {
-    if (rij > d->get_rcut()) return;
+    if (!d->is_init_for_atoms(Zi,Zj)) continue;
+    if (rij > d->get_rcut()) continue;
     d->calc_aed(Zi, Zj, rij, rij_sq, aed, scale);
   }
 }
@@ -30,7 +31,8 @@ void D2_mJoin::calc_aed(const int Zi, const int Zj, const double rij,
 void D2_mJoin::calc_dXijdri(const int Zi, const int Zj, const double rij,
                             const double rij_sq, fd_type &fd_ij, const double scale) {
   for (auto d : ds) {
-    if (rij > d->get_rcut()) return;
+    if (!d->is_init_for_atoms(Zi,Zj)) continue;
+    if (rij > d->get_rcut()) continue;
     d->calc_dXijdri(Zi, Zj, rij, rij_sq, fd_ij, scale);
   }
 }
@@ -38,7 +40,8 @@ void D2_mJoin::calc_dXijdri(const int Zi, const int Zj, const double rij,
 void D2_mJoin::calc_all(const int Zi, const int Zj, const double rij,
                         const double rij_sq, aed_type &aed, fd_type &fd_ij, const double scale) {
   for (auto d : ds) {
-    if (rij > d->get_rcut()) return;
+    if (!d->is_init_for_atoms(Zi,Zj)) continue;
+    if (rij > d->get_rcut()) continue;
     d->calc_all(Zi, Zj, rij, rij_sq, aed, fd_ij, scale);
   }
 }
diff --git a/src/d_base.cpp b/src/d_base.cpp
index 3f1ca91..3999fb4 100644
--- a/src/d_base.cpp
+++ b/src/d_base.cpp
@@ -79,6 +79,7 @@ D_Base::D_Base(Config &c):
       double wi = c.get<double>("WATOMS",i);
       int Z = PeriodicTable::find_by_symbol(symbol).Z;
       weights[Z]=wi;
+      init_for_atom(Z);
     }
   }
 }
@@ -98,3 +99,19 @@ int D_Base::get_arg_pos(const std::string &key) const {
 
   return std::distance(keys.begin(), it)+(nparams-keys.size());
 }
+size_t D_Base::size() { return s; };
+double D_Base::get_rcut() { 
+  return fcut->get_rcut(); 
+}
+bool D_Base::is_init_for_atom(int Z) {
+  return init_for_atoms[Z];
+}
+bool D_Base::is_init_for_atoms(int Zi, int Zj) {
+  return init_for_atoms[Zi] && init_for_atoms[Zj];
+}
+void D_Base::init_for_atom(int Z) {
+  init_for_atoms.set(Z);
+}
+void D_Base::uninit_for_atom(int Z) {
+  init_for_atoms.reset(Z);
+}
diff --git a/src/dm_mjoin.cpp b/src/dm_mjoin.cpp
index 91bbfd5..c9e3eae 100644
--- a/src/dm_mjoin.cpp
+++ b/src/dm_mjoin.cpp
@@ -41,11 +41,12 @@ void DM_mJoin::calc_dXijdri_dXjidri(
   const double scale) {
   size_t rho_fidx = 0;
   for (auto d : ds) {
-    if (rij > d->get_rcut()) return;
     size_t rho_size = 2 * d->rhoi_size();
-    rho_type rhoi_ptr(&rhoi[rho_fidx],rho_size);
-    rho_type rhoj_ptr(&rhoj[rho_fidx],rho_size);
-    d->calc_dXijdri_dXjidri(Zi,Zj,rij,rij_sq,vec_ij,rhoi_ptr,rhoj_ptr,fd_ij,scale); 
+    if (d->is_init_for_atoms(Zi,Zj) && rij < d->get_rcut()) {
+      rho_type rhoi_ptr(&rhoi[rho_fidx],rho_size);
+      rho_type rhoj_ptr(&rhoj[rho_fidx],rho_size);
+      d->calc_dXijdri_dXjidri(Zi,Zj,rij,rij_sq,vec_ij,rhoi_ptr,rhoj_ptr,fd_ij,scale); 
+    }
     rho_fidx += rho_size;
   }
 }
@@ -60,10 +61,11 @@ void DM_mJoin::calc_dXijdri(
   const double scale) {
   size_t rho_fidx = 0;
   for (auto d : ds) { 
-    if (rij > d->get_rcut()) return;
     size_t rho_size = 2 * d->rhoi_size();
-    rho_type rhoi_ptr(&rhoi[rho_fidx],rho_size);
-    d->calc_dXijdri(Zi,Zj,rij,rij_sq,vec_ij,rhoi_ptr,fd_ij,scale);
+    if (d->is_init_for_atoms(Zi,Zj) && rij < d->get_rcut()) {
+      rho_type rhoi_ptr(&rhoi[rho_fidx],rho_size);
+      d->calc_dXijdri(Zi,Zj,rij,rij_sq,vec_ij,rhoi_ptr,fd_ij,scale);
+    }
     rho_fidx += rho_size;
   }
 }
@@ -81,10 +83,11 @@ void DM_mJoin::calc_rho(
   const double scale) {
   size_t rho_fidx = 0;
   for (auto d : ds) {
-    if (rij > d->get_rcut()) return;
     size_t rho_size = 2 * d->rhoi_size();
-    rho_type rho_ptr(&rho[rho_fidx],rho_size);
-    d->calc_rho(Zi,Zj,rij,rij_sq,vec_ij,rho_ptr,scale);
+    if (d->is_init_for_atoms(Zi,Zj) && rij < d->get_rcut()) {
+      rho_type rho_ptr(&rho[rho_fidx],rho_size);
+      d->calc_rho(Zi,Zj,rij,rij_sq,vec_ij,rho_ptr,scale);
+    }
     rho_fidx += rho_size;
   }
 }
-- 
GitLab