diff --git a/include/tadah/models/descriptors/d_base.h b/include/tadah/models/descriptors/d_base.h index 183a58bc2d6f0c24860b89bbbfcf400ccfc4043d..77f215b56465fcc23dff3d23443296b578f04d52 100644 --- a/include/tadah/models/descriptors/d_base.h +++ b/include/tadah/models/descriptors/d_base.h @@ -8,6 +8,7 @@ #include <tadah/core/registry.h> #include <tadah/models/cut_all.h> #include <vector> +#include <bitset> /** \brief Base class for all descriptor types. * @@ -30,9 +31,6 @@ public: // if key is not used, returns -1 int get_arg_pos(const std::string &key) const; - /** \brief Return dimension of the descriptor. - */ - virtual size_t size() { return s; }; /** \brief Return label of this descriptor. */ @@ -60,10 +58,19 @@ public: virtual void set_fcut(Cut_Base* cut, bool manage_memory); - double get_rcut() { - return fcut->get_rcut(); } + double get_rcut(); + + /** \brief Return dimension of the descriptor. + */ + virtual size_t size(); + + virtual bool is_init_for_atom(int Z); + virtual bool is_init_for_atoms(int Zi, int Zj); + virtual void init_for_atom(int Zi); + virtual void uninit_for_atom(int Zi); private: bool manage_memory=false; // who owns fcut + std::bitset<119> init_for_atoms; }; #endif diff --git a/src/d2_mjoin.cpp b/src/d2_mjoin.cpp index cbd991047f117acb3a853ff2e4ed5ce337d5fe42..44edea5da3f32a9f9b45db37c6f11433e13331eb 100644 --- a/src/d2_mjoin.cpp +++ b/src/d2_mjoin.cpp @@ -22,7 +22,8 @@ D2_mJoin::D2_mJoin(Config &c) : D2_Base(c) { void D2_mJoin::calc_aed(const int Zi, const int Zj, const double rij, const double rij_sq, aed_type &aed, const double scale) { for (auto d : ds) { - if (rij > d->get_rcut()) return; + if (!d->is_init_for_atoms(Zi,Zj)) continue; + if (rij > d->get_rcut()) continue; d->calc_aed(Zi, Zj, rij, rij_sq, aed, scale); } } @@ -30,7 +31,8 @@ void D2_mJoin::calc_aed(const int Zi, const int Zj, const double rij, void D2_mJoin::calc_dXijdri(const int Zi, const int Zj, const double rij, const double rij_sq, fd_type &fd_ij, const double scale) { for (auto d : ds) { - if (rij > d->get_rcut()) return; + if (!d->is_init_for_atoms(Zi,Zj)) continue; + if (rij > d->get_rcut()) continue; d->calc_dXijdri(Zi, Zj, rij, rij_sq, fd_ij, scale); } } @@ -38,7 +40,8 @@ void D2_mJoin::calc_dXijdri(const int Zi, const int Zj, const double rij, void D2_mJoin::calc_all(const int Zi, const int Zj, const double rij, const double rij_sq, aed_type &aed, fd_type &fd_ij, const double scale) { for (auto d : ds) { - if (rij > d->get_rcut()) return; + if (!d->is_init_for_atoms(Zi,Zj)) continue; + if (rij > d->get_rcut()) continue; d->calc_all(Zi, Zj, rij, rij_sq, aed, fd_ij, scale); } } diff --git a/src/d_base.cpp b/src/d_base.cpp index 3f1ca91c4a5e69e0576d311e2d15ca1ea2f62faa..3999fb4f1630abc51795f0fc525fa6cb7bdca6d9 100644 --- a/src/d_base.cpp +++ b/src/d_base.cpp @@ -79,6 +79,7 @@ D_Base::D_Base(Config &c): double wi = c.get<double>("WATOMS",i); int Z = PeriodicTable::find_by_symbol(symbol).Z; weights[Z]=wi; + init_for_atom(Z); } } } @@ -98,3 +99,19 @@ int D_Base::get_arg_pos(const std::string &key) const { return std::distance(keys.begin(), it)+(nparams-keys.size()); } +size_t D_Base::size() { return s; }; +double D_Base::get_rcut() { + return fcut->get_rcut(); +} +bool D_Base::is_init_for_atom(int Z) { + return init_for_atoms[Z]; +} +bool D_Base::is_init_for_atoms(int Zi, int Zj) { + return init_for_atoms[Zi] && init_for_atoms[Zj]; +} +void D_Base::init_for_atom(int Z) { + init_for_atoms.set(Z); +} +void D_Base::uninit_for_atom(int Z) { + init_for_atoms.reset(Z); +} diff --git a/src/dm_mjoin.cpp b/src/dm_mjoin.cpp index 91bbfd53e98cb6b49411cb4f6b8fbc718eac2bb2..c9e3eaea382171b8a98e5a837683834b3b50ad13 100644 --- a/src/dm_mjoin.cpp +++ b/src/dm_mjoin.cpp @@ -41,11 +41,12 @@ void DM_mJoin::calc_dXijdri_dXjidri( const double scale) { size_t rho_fidx = 0; for (auto d : ds) { - if (rij > d->get_rcut()) return; size_t rho_size = 2 * d->rhoi_size(); - rho_type rhoi_ptr(&rhoi[rho_fidx],rho_size); - rho_type rhoj_ptr(&rhoj[rho_fidx],rho_size); - d->calc_dXijdri_dXjidri(Zi,Zj,rij,rij_sq,vec_ij,rhoi_ptr,rhoj_ptr,fd_ij,scale); + if (d->is_init_for_atoms(Zi,Zj) && rij < d->get_rcut()) { + rho_type rhoi_ptr(&rhoi[rho_fidx],rho_size); + rho_type rhoj_ptr(&rhoj[rho_fidx],rho_size); + d->calc_dXijdri_dXjidri(Zi,Zj,rij,rij_sq,vec_ij,rhoi_ptr,rhoj_ptr,fd_ij,scale); + } rho_fidx += rho_size; } } @@ -60,10 +61,11 @@ void DM_mJoin::calc_dXijdri( const double scale) { size_t rho_fidx = 0; for (auto d : ds) { - if (rij > d->get_rcut()) return; size_t rho_size = 2 * d->rhoi_size(); - rho_type rhoi_ptr(&rhoi[rho_fidx],rho_size); - d->calc_dXijdri(Zi,Zj,rij,rij_sq,vec_ij,rhoi_ptr,fd_ij,scale); + if (d->is_init_for_atoms(Zi,Zj) && rij < d->get_rcut()) { + rho_type rhoi_ptr(&rhoi[rho_fidx],rho_size); + d->calc_dXijdri(Zi,Zj,rij,rij_sq,vec_ij,rhoi_ptr,fd_ij,scale); + } rho_fidx += rho_size; } } @@ -81,10 +83,11 @@ void DM_mJoin::calc_rho( const double scale) { size_t rho_fidx = 0; for (auto d : ds) { - if (rij > d->get_rcut()) return; size_t rho_size = 2 * d->rhoi_size(); - rho_type rho_ptr(&rho[rho_fidx],rho_size); - d->calc_rho(Zi,Zj,rij,rij_sq,vec_ij,rho_ptr,scale); + if (d->is_init_for_atoms(Zi,Zj) && rij < d->get_rcut()) { + rho_type rho_ptr(&rho[rho_fidx],rho_size); + d->calc_rho(Zi,Zj,rij,rij_sq,vec_ij,rho_ptr,scale); + } rho_fidx += rho_size; } }