From e5df1fcb958411676abffa8c5bc3ff2a6f8ba55a Mon Sep 17 00:00:00 2001 From: mkirsz <s1351949@sms.ed.ac.uk> Date: Sun, 22 Dec 2024 00:46:46 +0000 Subject: [PATCH] added bitmap for checking is calc initalised for atoms, fixed return to continue as it was a clear bug --- include/tadah/models/descriptors/d_base.h | 17 ++++++++++++----- src/d2_mjoin.cpp | 9 ++++++--- src/d_base.cpp | 17 +++++++++++++++++ src/dm_mjoin.cpp | 23 +++++++++++++---------- 4 files changed, 48 insertions(+), 18 deletions(-) diff --git a/include/tadah/models/descriptors/d_base.h b/include/tadah/models/descriptors/d_base.h index 183a58b..77f215b 100644 --- a/include/tadah/models/descriptors/d_base.h +++ b/include/tadah/models/descriptors/d_base.h @@ -8,6 +8,7 @@ #include <tadah/core/registry.h> #include <tadah/models/cut_all.h> #include <vector> +#include <bitset> /** \brief Base class for all descriptor types. * @@ -30,9 +31,6 @@ public: // if key is not used, returns -1 int get_arg_pos(const std::string &key) const; - /** \brief Return dimension of the descriptor. - */ - virtual size_t size() { return s; }; /** \brief Return label of this descriptor. */ @@ -60,10 +58,19 @@ public: virtual void set_fcut(Cut_Base* cut, bool manage_memory); - double get_rcut() { - return fcut->get_rcut(); } + double get_rcut(); + + /** \brief Return dimension of the descriptor. + */ + virtual size_t size(); + + virtual bool is_init_for_atom(int Z); + virtual bool is_init_for_atoms(int Zi, int Zj); + virtual void init_for_atom(int Zi); + virtual void uninit_for_atom(int Zi); private: bool manage_memory=false; // who owns fcut + std::bitset<119> init_for_atoms; }; #endif diff --git a/src/d2_mjoin.cpp b/src/d2_mjoin.cpp index cbd9910..44edea5 100644 --- a/src/d2_mjoin.cpp +++ b/src/d2_mjoin.cpp @@ -22,7 +22,8 @@ D2_mJoin::D2_mJoin(Config &c) : D2_Base(c) { void D2_mJoin::calc_aed(const int Zi, const int Zj, const double rij, const double rij_sq, aed_type &aed, const double scale) { for (auto d : ds) { - if (rij > d->get_rcut()) return; + if (!d->is_init_for_atoms(Zi,Zj)) continue; + if (rij > d->get_rcut()) continue; d->calc_aed(Zi, Zj, rij, rij_sq, aed, scale); } } @@ -30,7 +31,8 @@ void D2_mJoin::calc_aed(const int Zi, const int Zj, const double rij, void D2_mJoin::calc_dXijdri(const int Zi, const int Zj, const double rij, const double rij_sq, fd_type &fd_ij, const double scale) { for (auto d : ds) { - if (rij > d->get_rcut()) return; + if (!d->is_init_for_atoms(Zi,Zj)) continue; + if (rij > d->get_rcut()) continue; d->calc_dXijdri(Zi, Zj, rij, rij_sq, fd_ij, scale); } } @@ -38,7 +40,8 @@ void D2_mJoin::calc_dXijdri(const int Zi, const int Zj, const double rij, void D2_mJoin::calc_all(const int Zi, const int Zj, const double rij, const double rij_sq, aed_type &aed, fd_type &fd_ij, const double scale) { for (auto d : ds) { - if (rij > d->get_rcut()) return; + if (!d->is_init_for_atoms(Zi,Zj)) continue; + if (rij > d->get_rcut()) continue; d->calc_all(Zi, Zj, rij, rij_sq, aed, fd_ij, scale); } } diff --git a/src/d_base.cpp b/src/d_base.cpp index 3f1ca91..3999fb4 100644 --- a/src/d_base.cpp +++ b/src/d_base.cpp @@ -79,6 +79,7 @@ D_Base::D_Base(Config &c): double wi = c.get<double>("WATOMS",i); int Z = PeriodicTable::find_by_symbol(symbol).Z; weights[Z]=wi; + init_for_atom(Z); } } } @@ -98,3 +99,19 @@ int D_Base::get_arg_pos(const std::string &key) const { return std::distance(keys.begin(), it)+(nparams-keys.size()); } +size_t D_Base::size() { return s; }; +double D_Base::get_rcut() { + return fcut->get_rcut(); +} +bool D_Base::is_init_for_atom(int Z) { + return init_for_atoms[Z]; +} +bool D_Base::is_init_for_atoms(int Zi, int Zj) { + return init_for_atoms[Zi] && init_for_atoms[Zj]; +} +void D_Base::init_for_atom(int Z) { + init_for_atoms.set(Z); +} +void D_Base::uninit_for_atom(int Z) { + init_for_atoms.reset(Z); +} diff --git a/src/dm_mjoin.cpp b/src/dm_mjoin.cpp index 91bbfd5..c9e3eae 100644 --- a/src/dm_mjoin.cpp +++ b/src/dm_mjoin.cpp @@ -41,11 +41,12 @@ void DM_mJoin::calc_dXijdri_dXjidri( const double scale) { size_t rho_fidx = 0; for (auto d : ds) { - if (rij > d->get_rcut()) return; size_t rho_size = 2 * d->rhoi_size(); - rho_type rhoi_ptr(&rhoi[rho_fidx],rho_size); - rho_type rhoj_ptr(&rhoj[rho_fidx],rho_size); - d->calc_dXijdri_dXjidri(Zi,Zj,rij,rij_sq,vec_ij,rhoi_ptr,rhoj_ptr,fd_ij,scale); + if (d->is_init_for_atoms(Zi,Zj) && rij < d->get_rcut()) { + rho_type rhoi_ptr(&rhoi[rho_fidx],rho_size); + rho_type rhoj_ptr(&rhoj[rho_fidx],rho_size); + d->calc_dXijdri_dXjidri(Zi,Zj,rij,rij_sq,vec_ij,rhoi_ptr,rhoj_ptr,fd_ij,scale); + } rho_fidx += rho_size; } } @@ -60,10 +61,11 @@ void DM_mJoin::calc_dXijdri( const double scale) { size_t rho_fidx = 0; for (auto d : ds) { - if (rij > d->get_rcut()) return; size_t rho_size = 2 * d->rhoi_size(); - rho_type rhoi_ptr(&rhoi[rho_fidx],rho_size); - d->calc_dXijdri(Zi,Zj,rij,rij_sq,vec_ij,rhoi_ptr,fd_ij,scale); + if (d->is_init_for_atoms(Zi,Zj) && rij < d->get_rcut()) { + rho_type rhoi_ptr(&rhoi[rho_fidx],rho_size); + d->calc_dXijdri(Zi,Zj,rij,rij_sq,vec_ij,rhoi_ptr,fd_ij,scale); + } rho_fidx += rho_size; } } @@ -81,10 +83,11 @@ void DM_mJoin::calc_rho( const double scale) { size_t rho_fidx = 0; for (auto d : ds) { - if (rij > d->get_rcut()) return; size_t rho_size = 2 * d->rhoi_size(); - rho_type rho_ptr(&rho[rho_fidx],rho_size); - d->calc_rho(Zi,Zj,rij,rij_sq,vec_ij,rho_ptr,scale); + if (d->is_init_for_atoms(Zi,Zj) && rij < d->get_rcut()) { + rho_type rho_ptr(&rho[rho_fidx],rho_size); + d->calc_rho(Zi,Zj,rij,rij_sq,vec_ij,rho_ptr,scale); + } rho_fidx += rho_size; } } -- GitLab