From 03111c1d358d540ba23ddc5adad8c47b44ca6018 Mon Sep 17 00:00:00 2001 From: mkirsz <s1351949@sms.ed.ac.uk> Date: Sun, 22 Dec 2024 10:57:23 +0000 Subject: [PATCH] Descriptors to manage which atoms to calculate --- include/tadah/models/descriptors/d2/d2_base.h | 1 + include/tadah/models/descriptors/d3/d3_base.h | 62 +++++++++---------- include/tadah/models/descriptors/d_base.h | 5 +- include/tadah/models/descriptors/dm/dm_base.h | 1 + .../tadah/models/descriptors/dm/dm_mEAD.hpp | 2 + src/d2_base.cpp | 7 +++ src/d2_blip.cpp | 3 + src/d2_bp.cpp | 4 +- src/d2_eam.cpp | 3 + src/d2_lj.cpp | 3 + src/d2_mie.cpp | 2 + src/d2_zbl.cpp | 2 + src/d3_base.cpp | 3 + src/d_base.cpp | 23 +++++-- src/d_mjoin.cpp | 31 +++++++++- src/dm_base.cpp | 7 +++ src/dm_blip.cpp | 2 + src/dm_ead.cpp | 2 + src/dm_eam.cpp | 2 + 19 files changed, 126 insertions(+), 39 deletions(-) diff --git a/include/tadah/models/descriptors/d2/d2_base.h b/include/tadah/models/descriptors/d2/d2_base.h index c70580c..8e7917c 100644 --- a/include/tadah/models/descriptors/d2/d2_base.h +++ b/include/tadah/models/descriptors/d2/d2_base.h @@ -22,6 +22,7 @@ public: } } virtual ~D2_Base() {}; + virtual std::vector<std::string> get_init_atoms(Config &c) override; /** \brief Calculate \ref AED * diff --git a/include/tadah/models/descriptors/d3/d3_base.h b/include/tadah/models/descriptors/d3/d3_base.h index ff4efec..c64a76b 100644 --- a/include/tadah/models/descriptors/d3/d3_base.h +++ b/include/tadah/models/descriptors/d3/d3_base.h @@ -4,37 +4,37 @@ #include <tadah/models/descriptors/d_base.h> class D3_Base: public D_Base { - public: - virtual ~D3_Base() {}; +public: + virtual ~D3_Base() {}; - virtual void calc_aed( - const size_t fidx, - const double rij, - const double rik, - const double fc_ij, - const double fc_ik, - aed_type& aed)=0; - virtual void calc_fd( - const size_t fidx, - const double rij, - const double rik, - const double fc_ij, - const double fc_ik, - const double fcp_ij, - const double fcp_ik, - fd_type &fd_ij)=0; - virtual void calc_all( - const size_t fidx, - const double rij, - const double rik, - const double fc_ij, - const double fc_ik, - const double fcp_ij, - const double fcp_ik, - aed_type& aed, - fd_type &fd_ij)=0; - virtual size_t size()=0; - virtual std::string label()=0; + virtual void calc_aed( + const size_t fidx, + const double rij, + const double rik, + const double fc_ij, + const double fc_ik, + aed_type& aed)=0; + virtual void calc_fd( + const size_t fidx, + const double rij, + const double rik, + const double fc_ij, + const double fc_ik, + const double fcp_ij, + const double fcp_ik, + fd_type &fd_ij)=0; + virtual void calc_all( + const size_t fidx, + const double rij, + const double rik, + const double fc_ij, + const double fc_ik, + const double fcp_ij, + const double fcp_ik, + aed_type& aed, + fd_type &fd_ij)=0; + virtual size_t size()override=0; + virtual std::string label()override=0; + virtual std::vector<std::string> get_init_atoms(Config &c) override; }; -//template<> inline Registry<D3_Base,Config&>::Map Registry<D3_Base,Config&>::registry{}; #endif diff --git a/include/tadah/models/descriptors/d_base.h b/include/tadah/models/descriptors/d_base.h index 77f215b..8685b31 100644 --- a/include/tadah/models/descriptors/d_base.h +++ b/include/tadah/models/descriptors/d_base.h @@ -67,10 +67,13 @@ public: virtual bool is_init_for_atom(int Z); virtual bool is_init_for_atoms(int Zi, int Zj); virtual void init_for_atom(int Zi); + virtual void init_for_atoms(const std::vector<std::string> &Zs); virtual void uninit_for_atom(int Zi); + static std::vector<std::string> get_init_atoms(Config &c, std::string type); + virtual std::vector<std::string> get_init_atoms(Config &c)=0; private: bool manage_memory=false; // who owns fcut - std::bitset<119> init_for_atoms; + std::bitset<119> init_for_atoms_map; }; #endif diff --git a/include/tadah/models/descriptors/dm/dm_base.h b/include/tadah/models/descriptors/dm/dm_base.h index 03ca072..825a1c5 100644 --- a/include/tadah/models/descriptors/dm/dm_base.h +++ b/include/tadah/models/descriptors/dm/dm_base.h @@ -26,6 +26,7 @@ public: } } virtual ~DM_Base() {}; + virtual std::vector<std::string> get_init_atoms(Config &c) override; /** \brief Calculate \ref AED * diff --git a/include/tadah/models/descriptors/dm/dm_mEAD.hpp b/include/tadah/models/descriptors/dm/dm_mEAD.hpp index c28018a..b549362 100644 --- a/include/tadah/models/descriptors/dm/dm_mEAD.hpp +++ b/include/tadah/models/descriptors/dm/dm_mEAD.hpp @@ -48,6 +48,8 @@ DM_mEAD<F>::DM_mEAD(Config &config): DM_Base(config), std::cout << std::endl; std::cout << "rhoisize: " << rhoisize << std::endl; } + auto init_atoms = get_init_atoms(config); + init_for_atoms(init_atoms); } template <typename F> diff --git a/src/d2_base.cpp b/src/d2_base.cpp index 00abcd0..aefca65 100644 --- a/src/d2_base.cpp +++ b/src/d2_base.cpp @@ -1 +1,8 @@ #include <tadah/models/descriptors/d2/d2_base.h> + +std::vector<std::string> D2_Base::get_init_atoms(Config &c) { + std::vector<std::string> init_atoms(c.size("TYPE2B")); + c.get("TYPE2B", init_atoms); + init_atoms.erase(init_atoms.begin(), init_atoms.begin() + nparams+1); + return init_atoms; +} diff --git a/src/d2_blip.cpp b/src/d2_blip.cpp index 7cf352f..3c16d5b 100644 --- a/src/d2_blip.cpp +++ b/src/d2_blip.cpp @@ -41,6 +41,9 @@ D2_Blip::D2_Blip(Config &c): D2_Base(c) } s=mius.size(); + + auto init_atoms = get_init_atoms(c); + init_for_atoms(init_atoms); } void D2_Blip::calc_aed( diff --git a/src/d2_bp.cpp b/src/d2_bp.cpp index 383995c..2b564f7 100644 --- a/src/d2_bp.cpp +++ b/src/d2_bp.cpp @@ -37,7 +37,9 @@ D2_BP::D2_BP(Config &c): D2_Base(c) { throw std::runtime_error("At least one of SGRID2B values is zero.\n"); } - s=mius.size(); + s = mius.size(); + auto init_atoms = get_init_atoms(c); + init_for_atoms(init_atoms); } void D2_BP::calc_aed( diff --git a/src/d2_eam.cpp b/src/d2_eam.cpp index c168631..90b8b79 100644 --- a/src/d2_eam.cpp +++ b/src/d2_eam.cpp @@ -31,6 +31,9 @@ D2_EAM::D2_EAM(Config &c): D2_Base(c) gen_splines(ef.nrho, ef.drho, ef.frho, frho_spline); gen_splines(ef.nr, ef.dr, ef.rhor, rhor_spline); gen_splines(ef.nr, ef.dr, ef.z2r, z2r_spline); + + auto init_atoms = get_init_atoms(c); + init_for_atoms(init_atoms); } void D2_EAM::calc_aed( diff --git a/src/d2_lj.cpp b/src/d2_lj.cpp index 64f33a3..3d41583 100644 --- a/src/d2_lj.cpp +++ b/src/d2_lj.cpp @@ -9,6 +9,9 @@ D2_LJ::D2_LJ(Config &c): D2_Base(c) if (!c.get<bool>("INIT2B")) return; init(); s=2; + + auto init_atoms = get_init_atoms(c); + init_for_atoms(init_atoms); } void D2_LJ::calc_aed( diff --git a/src/d2_mie.cpp b/src/d2_mie.cpp index 2d35258..450c2d9 100644 --- a/src/d2_mie.cpp +++ b/src/d2_mie.cpp @@ -14,6 +14,8 @@ D2_MIE::D2_MIE(Config &c): D2_Base(c) if (n<0 || m<0) { throw std::runtime_error("Both Mie exponents must by positive\n"); } + auto init_atoms = get_init_atoms(c); + init_for_atoms(init_atoms); } void D2_MIE::calc_aed( diff --git a/src/d2_zbl.cpp b/src/d2_zbl.cpp index e158b5e..90542c7 100644 --- a/src/d2_zbl.cpp +++ b/src/d2_zbl.cpp @@ -37,6 +37,8 @@ D2_ZBL::D2_ZBL(Config &c): D2_Base(c) a[i][j] = screening_length(Zi,Zj); } } + auto init_atoms = get_init_atoms(c); + init_for_atoms(init_atoms); } D2_ZBL::~D2_ZBL() { if (a) { diff --git a/src/d3_base.cpp b/src/d3_base.cpp index bf59bc8..fe5fa69 100644 --- a/src/d3_base.cpp +++ b/src/d3_base.cpp @@ -1,3 +1,6 @@ #include <tadah/models/descriptors/d3/d3_base.h> //template struct Registry<D3_Base, Config &>; +std::vector<std::string> D3_Base::get_init_atoms(Config &c) { + return {}; +} diff --git a/src/d_base.cpp b/src/d_base.cpp index 3999fb4..2443396 100644 --- a/src/d_base.cpp +++ b/src/d_base.cpp @@ -79,7 +79,6 @@ D_Base::D_Base(Config &c): double wi = c.get<double>("WATOMS",i); int Z = PeriodicTable::find_by_symbol(symbol).Z; weights[Z]=wi; - init_for_atom(Z); } } } @@ -104,14 +103,28 @@ double D_Base::get_rcut() { return fcut->get_rcut(); } bool D_Base::is_init_for_atom(int Z) { - return init_for_atoms[Z]; + return init_for_atoms_map[Z]; } bool D_Base::is_init_for_atoms(int Zi, int Zj) { - return init_for_atoms[Zi] && init_for_atoms[Zj]; + return init_for_atoms_map[Zi] && init_for_atoms_map[Zj]; } void D_Base::init_for_atom(int Z) { - init_for_atoms.set(Z); + init_for_atoms_map.set(Z); +} +void D_Base::init_for_atoms(const std::vector<std::string> &Zs) { + for (const auto &Z_str : Zs){ + int Z = PeriodicTable::find_by_symbol(Z_str).Z; + init_for_atom(Z); + } } void D_Base::uninit_for_atom(int Z) { - init_for_atoms.reset(Z); + init_for_atoms_map.reset(Z); +} +std::vector<std::string> D_Base::get_init_atoms(Config &c, std::string type) { + std::vector<std::string> init_atoms(c.size(type)); + c.get(type, init_atoms); + std::unique_ptr<D_Base> d(CONFIG::factory<D_Base>(init_atoms[0])); + size_t nparams = d->nparams; + init_atoms.erase(init_atoms.begin(), init_atoms.begin() + nparams); + return init_atoms; } diff --git a/src/d_mjoin.cpp b/src/d_mjoin.cpp index 1e67b93..9c1e368 100644 --- a/src/d_mjoin.cpp +++ b/src/d_mjoin.cpp @@ -147,11 +147,40 @@ std::vector<Config> D_mJoin::parse_config(Config &c, std::string type_str) { for (const auto& k : d->keys) { int n = std::stoi(*it++); + c1.add(type_str,n); for (int j = 0; j < n; ++j) { - c1.add(k,*(data_map[k]++)); /// Problem? + c1.add(k,*(data_map[k]++)); } } + // read init atoms + size_t found_init_atoms=0; + while (it != types.end()) { + if (it->find(prefix) == 0) { + break; + } + c1.add(type_str,*it++); + found_init_atoms++; + } + // if (it != types.end() && it != types.begin()) { + // --it; + // } + if (!found_init_atoms) { + throw std::runtime_error( + "Error: No element types specified. " + "Please provide element types for this calculator. " + "For example: '"+token+" Ti Ti Ti Nb' sets the calculator to compute Ti-Ti and Ti-Nb interactions. " + "Token: " + token); + } + + if (found_init_atoms % 2 != 0) { + throw std::runtime_error( + "Error: Element types must be provided in pairs. " + "Ensure each element is paired correctly. " + "For instance, '" + token + " Ti * Nb Zr' means that interactions of Ti with any atom, as well as Nb-Zr, will be computed. " + "Token: " + token); + } + configs.push_back(c1); delete d; } diff --git a/src/dm_base.cpp b/src/dm_base.cpp index 7fe63ac..c6247e5 100644 --- a/src/dm_base.cpp +++ b/src/dm_base.cpp @@ -1,3 +1,10 @@ #include <tadah/models/descriptors/dm/dm_base.h> void DM_Base::set_rfidx(size_t rfidx_) { rfidx=rfidx_; } size_t DM_Base::get_rfidx() { return rfidx; } + +std::vector<std::string> DM_Base::get_init_atoms(Config &c) { + std::vector<std::string> init_atoms(c.size("TYPEMB")); + c.get("TYPEMB", init_atoms); + init_atoms.erase(init_atoms.begin(), init_atoms.begin() + nparams+1); + return init_atoms; +} diff --git a/src/dm_blip.cpp b/src/dm_blip.cpp index c68b8ec..48408d2 100644 --- a/src/dm_blip.cpp +++ b/src/dm_blip.cpp @@ -40,6 +40,8 @@ DM_Blip::DM_Blip(Config &c): DM_Base(c), std::cout << std::endl; std::cout << "rhoisize: " << rhoisize << std::endl; } + auto init_atoms = get_init_atoms(c); + init_for_atoms(init_atoms); } void DM_Blip::calc_aed( diff --git a/src/dm_ead.cpp b/src/dm_ead.cpp index b7419b3..8cb83aa 100644 --- a/src/dm_ead.cpp +++ b/src/dm_ead.cpp @@ -47,6 +47,8 @@ DM_EAD::DM_EAD(Config &config): DM_Base(config), std::cout << std::endl; std::cout << "rhoisize: " << rhoisize << std::endl; } + auto init_atoms = get_init_atoms(config); + init_for_atoms(init_atoms); } void DM_EAD::calc_aed( diff --git a/src/dm_eam.cpp b/src/dm_eam.cpp index 8450987..3408b29 100644 --- a/src/dm_eam.cpp +++ b/src/dm_eam.cpp @@ -32,6 +32,8 @@ DM_EAM::DM_EAM(Config &c): DM_Base(c) gen_splines(ef.nr, ef.dr, ef.z2r, z2r_spline); rhoisize = 1; + auto init_atoms = get_init_atoms(c); + init_for_atoms(init_atoms); } void DM_EAM::calc_aed( -- GitLab