diff --git a/include/tadah/models/descriptors/d2/d2_base.h b/include/tadah/models/descriptors/d2/d2_base.h index c70580c38d896f11e62a80f3d636da7606a78bce..8e7917c74c17ba471468ed0b9f2bff27e72ce343 100644 --- a/include/tadah/models/descriptors/d2/d2_base.h +++ b/include/tadah/models/descriptors/d2/d2_base.h @@ -22,6 +22,7 @@ public: } } virtual ~D2_Base() {}; + virtual std::vector<std::string> get_init_atoms(Config &c) override; /** \brief Calculate \ref AED * diff --git a/include/tadah/models/descriptors/d3/d3_base.h b/include/tadah/models/descriptors/d3/d3_base.h index ff4efecfe93062cc8b0941379848e60733c0d193..c64a76bfd218d1ff385f95fe2935d2358e3e8ef2 100644 --- a/include/tadah/models/descriptors/d3/d3_base.h +++ b/include/tadah/models/descriptors/d3/d3_base.h @@ -4,37 +4,37 @@ #include <tadah/models/descriptors/d_base.h> class D3_Base: public D_Base { - public: - virtual ~D3_Base() {}; +public: + virtual ~D3_Base() {}; - virtual void calc_aed( - const size_t fidx, - const double rij, - const double rik, - const double fc_ij, - const double fc_ik, - aed_type& aed)=0; - virtual void calc_fd( - const size_t fidx, - const double rij, - const double rik, - const double fc_ij, - const double fc_ik, - const double fcp_ij, - const double fcp_ik, - fd_type &fd_ij)=0; - virtual void calc_all( - const size_t fidx, - const double rij, - const double rik, - const double fc_ij, - const double fc_ik, - const double fcp_ij, - const double fcp_ik, - aed_type& aed, - fd_type &fd_ij)=0; - virtual size_t size()=0; - virtual std::string label()=0; + virtual void calc_aed( + const size_t fidx, + const double rij, + const double rik, + const double fc_ij, + const double fc_ik, + aed_type& aed)=0; + virtual void calc_fd( + const size_t fidx, + const double rij, + const double rik, + const double fc_ij, + const double fc_ik, + const double fcp_ij, + const double fcp_ik, + fd_type &fd_ij)=0; + virtual void calc_all( + const size_t fidx, + const double rij, + const double rik, + const double fc_ij, + const double fc_ik, + const double fcp_ij, + const double fcp_ik, + aed_type& aed, + fd_type &fd_ij)=0; + virtual size_t size()override=0; + virtual std::string label()override=0; + virtual std::vector<std::string> get_init_atoms(Config &c) override; }; -//template<> inline Registry<D3_Base,Config&>::Map Registry<D3_Base,Config&>::registry{}; #endif diff --git a/include/tadah/models/descriptors/d_base.h b/include/tadah/models/descriptors/d_base.h index 77f215b56465fcc23dff3d23443296b578f04d52..8685b31b3c22cd130ef9d7b06ad68e91faa50831 100644 --- a/include/tadah/models/descriptors/d_base.h +++ b/include/tadah/models/descriptors/d_base.h @@ -67,10 +67,13 @@ public: virtual bool is_init_for_atom(int Z); virtual bool is_init_for_atoms(int Zi, int Zj); virtual void init_for_atom(int Zi); + virtual void init_for_atoms(const std::vector<std::string> &Zs); virtual void uninit_for_atom(int Zi); + static std::vector<std::string> get_init_atoms(Config &c, std::string type); + virtual std::vector<std::string> get_init_atoms(Config &c)=0; private: bool manage_memory=false; // who owns fcut - std::bitset<119> init_for_atoms; + std::bitset<119> init_for_atoms_map; }; #endif diff --git a/include/tadah/models/descriptors/dm/dm_base.h b/include/tadah/models/descriptors/dm/dm_base.h index 03ca07240c126db60eaee5062f157ec712518287..825a1c52061f7763b196cc4855a41526c8fee7b0 100644 --- a/include/tadah/models/descriptors/dm/dm_base.h +++ b/include/tadah/models/descriptors/dm/dm_base.h @@ -26,6 +26,7 @@ public: } } virtual ~DM_Base() {}; + virtual std::vector<std::string> get_init_atoms(Config &c) override; /** \brief Calculate \ref AED * diff --git a/include/tadah/models/descriptors/dm/dm_mEAD.hpp b/include/tadah/models/descriptors/dm/dm_mEAD.hpp index c28018a6fdc4a33bbd4b2c7dbe1378fbef44e4cb..b5493621e7d4de3725b8d4f913b32ac75592af48 100644 --- a/include/tadah/models/descriptors/dm/dm_mEAD.hpp +++ b/include/tadah/models/descriptors/dm/dm_mEAD.hpp @@ -48,6 +48,8 @@ DM_mEAD<F>::DM_mEAD(Config &config): DM_Base(config), std::cout << std::endl; std::cout << "rhoisize: " << rhoisize << std::endl; } + auto init_atoms = get_init_atoms(config); + init_for_atoms(init_atoms); } template <typename F> diff --git a/src/d2_base.cpp b/src/d2_base.cpp index 00abcd05957115e8b4026f6b4ad322c8ae77e57b..aefca65f63901cb38dbc8058769f3fba31b97809 100644 --- a/src/d2_base.cpp +++ b/src/d2_base.cpp @@ -1 +1,8 @@ #include <tadah/models/descriptors/d2/d2_base.h> + +std::vector<std::string> D2_Base::get_init_atoms(Config &c) { + std::vector<std::string> init_atoms(c.size("TYPE2B")); + c.get("TYPE2B", init_atoms); + init_atoms.erase(init_atoms.begin(), init_atoms.begin() + nparams+1); + return init_atoms; +} diff --git a/src/d2_blip.cpp b/src/d2_blip.cpp index 7cf352fa0d617bdffc63265c8ea45c67221a8713..3c16d5b275c25ba4b9c0092c316138f8a181836c 100644 --- a/src/d2_blip.cpp +++ b/src/d2_blip.cpp @@ -41,6 +41,9 @@ D2_Blip::D2_Blip(Config &c): D2_Base(c) } s=mius.size(); + + auto init_atoms = get_init_atoms(c); + init_for_atoms(init_atoms); } void D2_Blip::calc_aed( diff --git a/src/d2_bp.cpp b/src/d2_bp.cpp index 383995c8ad9d0656436b05814f0825e715aee4ad..2b564f7f404ff8cdb184124bf0057a8a47c5f8e4 100644 --- a/src/d2_bp.cpp +++ b/src/d2_bp.cpp @@ -37,7 +37,9 @@ D2_BP::D2_BP(Config &c): D2_Base(c) { throw std::runtime_error("At least one of SGRID2B values is zero.\n"); } - s=mius.size(); + s = mius.size(); + auto init_atoms = get_init_atoms(c); + init_for_atoms(init_atoms); } void D2_BP::calc_aed( diff --git a/src/d2_eam.cpp b/src/d2_eam.cpp index c168631ed917a39aaaec2956fb34c519e67a7f96..90b8b7969ddf2e1b8a7f325982d26a62b621c11c 100644 --- a/src/d2_eam.cpp +++ b/src/d2_eam.cpp @@ -31,6 +31,9 @@ D2_EAM::D2_EAM(Config &c): D2_Base(c) gen_splines(ef.nrho, ef.drho, ef.frho, frho_spline); gen_splines(ef.nr, ef.dr, ef.rhor, rhor_spline); gen_splines(ef.nr, ef.dr, ef.z2r, z2r_spline); + + auto init_atoms = get_init_atoms(c); + init_for_atoms(init_atoms); } void D2_EAM::calc_aed( diff --git a/src/d2_lj.cpp b/src/d2_lj.cpp index 64f33a3039a89c5e30e8dae9c01133df2fd5f156..3d4158322588b40d8d411a81a25ed85d4a1c380e 100644 --- a/src/d2_lj.cpp +++ b/src/d2_lj.cpp @@ -9,6 +9,9 @@ D2_LJ::D2_LJ(Config &c): D2_Base(c) if (!c.get<bool>("INIT2B")) return; init(); s=2; + + auto init_atoms = get_init_atoms(c); + init_for_atoms(init_atoms); } void D2_LJ::calc_aed( diff --git a/src/d2_mie.cpp b/src/d2_mie.cpp index 2d35258e8356343af9d67012e8f74e73561382de..450c2d9f7f5e14818c7498112969cd5ae0a63114 100644 --- a/src/d2_mie.cpp +++ b/src/d2_mie.cpp @@ -14,6 +14,8 @@ D2_MIE::D2_MIE(Config &c): D2_Base(c) if (n<0 || m<0) { throw std::runtime_error("Both Mie exponents must by positive\n"); } + auto init_atoms = get_init_atoms(c); + init_for_atoms(init_atoms); } void D2_MIE::calc_aed( diff --git a/src/d2_zbl.cpp b/src/d2_zbl.cpp index e158b5e63d63124eb54e40c6bcbd9b74d756996d..90542c77445c885ee858139fe391348f1fe3ddd8 100644 --- a/src/d2_zbl.cpp +++ b/src/d2_zbl.cpp @@ -37,6 +37,8 @@ D2_ZBL::D2_ZBL(Config &c): D2_Base(c) a[i][j] = screening_length(Zi,Zj); } } + auto init_atoms = get_init_atoms(c); + init_for_atoms(init_atoms); } D2_ZBL::~D2_ZBL() { if (a) { diff --git a/src/d3_base.cpp b/src/d3_base.cpp index bf59bc896569e2267f1bc1deb8cde5d32d56e769..fe5fa6932ef38b78dac3fd4d1bdf9f5f990629df 100644 --- a/src/d3_base.cpp +++ b/src/d3_base.cpp @@ -1,3 +1,6 @@ #include <tadah/models/descriptors/d3/d3_base.h> //template struct Registry<D3_Base, Config &>; +std::vector<std::string> D3_Base::get_init_atoms(Config &c) { + return {}; +} diff --git a/src/d_base.cpp b/src/d_base.cpp index 3999fb4f1630abc51795f0fc525fa6cb7bdca6d9..2443396f4d988c57d8fcfab1448c5e8c3fefe514 100644 --- a/src/d_base.cpp +++ b/src/d_base.cpp @@ -79,7 +79,6 @@ D_Base::D_Base(Config &c): double wi = c.get<double>("WATOMS",i); int Z = PeriodicTable::find_by_symbol(symbol).Z; weights[Z]=wi; - init_for_atom(Z); } } } @@ -104,14 +103,28 @@ double D_Base::get_rcut() { return fcut->get_rcut(); } bool D_Base::is_init_for_atom(int Z) { - return init_for_atoms[Z]; + return init_for_atoms_map[Z]; } bool D_Base::is_init_for_atoms(int Zi, int Zj) { - return init_for_atoms[Zi] && init_for_atoms[Zj]; + return init_for_atoms_map[Zi] && init_for_atoms_map[Zj]; } void D_Base::init_for_atom(int Z) { - init_for_atoms.set(Z); + init_for_atoms_map.set(Z); +} +void D_Base::init_for_atoms(const std::vector<std::string> &Zs) { + for (const auto &Z_str : Zs){ + int Z = PeriodicTable::find_by_symbol(Z_str).Z; + init_for_atom(Z); + } } void D_Base::uninit_for_atom(int Z) { - init_for_atoms.reset(Z); + init_for_atoms_map.reset(Z); +} +std::vector<std::string> D_Base::get_init_atoms(Config &c, std::string type) { + std::vector<std::string> init_atoms(c.size(type)); + c.get(type, init_atoms); + std::unique_ptr<D_Base> d(CONFIG::factory<D_Base>(init_atoms[0])); + size_t nparams = d->nparams; + init_atoms.erase(init_atoms.begin(), init_atoms.begin() + nparams); + return init_atoms; } diff --git a/src/d_mjoin.cpp b/src/d_mjoin.cpp index 1e67b93ef368568d88cee564821ff37f137d3efc..9c1e3682847470463bac4b166858ce768017ac11 100644 --- a/src/d_mjoin.cpp +++ b/src/d_mjoin.cpp @@ -147,11 +147,40 @@ std::vector<Config> D_mJoin::parse_config(Config &c, std::string type_str) { for (const auto& k : d->keys) { int n = std::stoi(*it++); + c1.add(type_str,n); for (int j = 0; j < n; ++j) { - c1.add(k,*(data_map[k]++)); /// Problem? + c1.add(k,*(data_map[k]++)); } } + // read init atoms + size_t found_init_atoms=0; + while (it != types.end()) { + if (it->find(prefix) == 0) { + break; + } + c1.add(type_str,*it++); + found_init_atoms++; + } + // if (it != types.end() && it != types.begin()) { + // --it; + // } + if (!found_init_atoms) { + throw std::runtime_error( + "Error: No element types specified. " + "Please provide element types for this calculator. " + "For example: '"+token+" Ti Ti Ti Nb' sets the calculator to compute Ti-Ti and Ti-Nb interactions. " + "Token: " + token); + } + + if (found_init_atoms % 2 != 0) { + throw std::runtime_error( + "Error: Element types must be provided in pairs. " + "Ensure each element is paired correctly. " + "For instance, '" + token + " Ti * Nb Zr' means that interactions of Ti with any atom, as well as Nb-Zr, will be computed. " + "Token: " + token); + } + configs.push_back(c1); delete d; } diff --git a/src/dm_base.cpp b/src/dm_base.cpp index 7fe63ac8afe5d8129e9ba4d86a7cdc63fff294a1..c6247e5c42c297c4f63de0fd61b358ca073ffc99 100644 --- a/src/dm_base.cpp +++ b/src/dm_base.cpp @@ -1,3 +1,10 @@ #include <tadah/models/descriptors/dm/dm_base.h> void DM_Base::set_rfidx(size_t rfidx_) { rfidx=rfidx_; } size_t DM_Base::get_rfidx() { return rfidx; } + +std::vector<std::string> DM_Base::get_init_atoms(Config &c) { + std::vector<std::string> init_atoms(c.size("TYPEMB")); + c.get("TYPEMB", init_atoms); + init_atoms.erase(init_atoms.begin(), init_atoms.begin() + nparams+1); + return init_atoms; +} diff --git a/src/dm_blip.cpp b/src/dm_blip.cpp index c68b8ec96c721f00e126b5cdc6914c5e66c0c7ad..48408d27b4d4d1cd18eaaa0d2451773f165d97f0 100644 --- a/src/dm_blip.cpp +++ b/src/dm_blip.cpp @@ -40,6 +40,8 @@ DM_Blip::DM_Blip(Config &c): DM_Base(c), std::cout << std::endl; std::cout << "rhoisize: " << rhoisize << std::endl; } + auto init_atoms = get_init_atoms(c); + init_for_atoms(init_atoms); } void DM_Blip::calc_aed( diff --git a/src/dm_ead.cpp b/src/dm_ead.cpp index b7419b34846a3015495060dcbb5027e0481669d1..8cb83aaa524773dd2052e86b6829eed84d8fbb88 100644 --- a/src/dm_ead.cpp +++ b/src/dm_ead.cpp @@ -47,6 +47,8 @@ DM_EAD::DM_EAD(Config &config): DM_Base(config), std::cout << std::endl; std::cout << "rhoisize: " << rhoisize << std::endl; } + auto init_atoms = get_init_atoms(config); + init_for_atoms(init_atoms); } void DM_EAD::calc_aed( diff --git a/src/dm_eam.cpp b/src/dm_eam.cpp index 84509872c72448fdf8cdc9c45ad4b768193ff614..3408b29ec065c682f7d3e17ac63758d7b4621365 100644 --- a/src/dm_eam.cpp +++ b/src/dm_eam.cpp @@ -32,6 +32,8 @@ DM_EAM::DM_EAM(Config &c): DM_Base(c) gen_splines(ef.nr, ef.dr, ef.z2r, z2r_spline); rhoisize = 1; + auto init_atoms = get_init_atoms(c); + init_for_atoms(init_atoms); } void DM_EAM::calc_aed(