diff --git a/include/tadah/mlip/design_matrix/design_matrix.h b/include/tadah/mlip/design_matrix/design_matrix.h index b98239c66056dbd8a30e9594fbb8afe12a658a39..168b94b3a9409140dd408c150489b323ec192ab7 100644 --- a/include/tadah/mlip/design_matrix/design_matrix.h +++ b/include/tadah/mlip/design_matrix/design_matrix.h @@ -10,6 +10,8 @@ #include <tadah/mlip/memory/MLIPWorkspaceManager.h> #include <tadah/mlip/memory/DesignMatrixWorkspace.h> +#include <tadah/mlip/neighbor_list_db.h> + #include <stdexcept> namespace tadah { @@ -73,7 +75,7 @@ class DesignMatrix : public DesignMatrixBase { public: F f; - tadah::mlip::memory::DesignMatrixWorkspace *ws; + memory::DesignMatrixWorkspace *ws; bool scale=true; // Control escale,fscale,sscale double e_std_dev=1; @@ -92,7 +94,7 @@ public: * * \endcode */ - DesignMatrix(F &f, Config &c, tadah::mlip::memory::IMLIPWorkspaceManager& workspaceManager) + DesignMatrix(F &f, Config &c, memory::IMLIPWorkspaceManager& workspaceManager) : f(f), workspaceManager_(&workspaceManager), ownWorkspaceManager(false), @@ -114,7 +116,7 @@ public: // Constructor without workspaceManager_ parameter (delegating constructor) DesignMatrix(F &f, Config &c) - : DesignMatrix(f, c, *new tadah::mlip::memory::MLIPWorkspaceManager()) + : DesignMatrix(f, c, *new memory::MLIPWorkspaceManager()) { ownWorkspaceManager = true; // Set ownership flag } @@ -166,8 +168,10 @@ public: /** \brief Calculate descriptors and build design matrix. */ template <typename DC> - void build(const StructureDB &stdb, Normaliser &norm, + void build(const StructureDB &stdb, NeighbourListDB &nldb, + Normaliser &norm, DC &dc) { + //DescriptorsCalc<D2,D3,DM,C2,C3,CM> dc(config); calc_mn(stdb); ws = workspaceManager_->getDesignMatrixWorkspace(rows,cols); @@ -197,12 +201,12 @@ public: #pragma omp parallel for #endif for (size_t s=0; s<stdb.size(); ++s) { - StDescriptors st_d = dc.calc(stdb(s)); + StDescriptors st_d = dc.calc(stdb(s), nldb); if(config.get<bool>("NORM")) norm.normalise(st_d); - build(rows[s], stdb(s), st_d); + build(rows[s], stdb(s), nldb, st_d); } } @@ -266,7 +270,7 @@ t_type &getT() { return ws->T; } t_type &getTlabels() { return ws->Tlabels; } private: - tadah::mlip::memory::IMLIPWorkspaceManager* workspaceManager_ = nullptr; + memory::IMLIPWorkspaceManager* workspaceManager_ = nullptr; bool ownWorkspaceManager = false; Config & config; size_t rows = 0; diff --git a/include/tadah/mlip/design_matrix/functions/basis_functions/dm_bf_base.h b/include/tadah/mlip/design_matrix/functions/basis_functions/dm_bf_base.h index 8dc9cc4f46469270a69e71dfbc9ad0555364deb1..0cd2f6daccdf1c8ffb95dba6b240ae1647aa8e49 100644 --- a/include/tadah/mlip/design_matrix/functions/basis_functions/dm_bf_base.h +++ b/include/tadah/mlip/design_matrix/functions/basis_functions/dm_bf_base.h @@ -2,6 +2,7 @@ #define DM_BASIS_FUNCTIONS_H #include <tadah/mlip/design_matrix/functions/dm_function_base.h> +#include <tadah/mlip/atom.h> #include <tadah/mlip/structure.h> #include <tadah/mlip/st_descriptors.h> #include <tadah/core/core_types.h> diff --git a/include/tadah/mlip/models/basis.h b/include/tadah/mlip/models/basis.h index f00caf0953e911cc133987bca99070a44e24334e..d5b699b122276647e6a4fc5a686e158f9a72d1c8 100644 --- a/include/tadah/mlip/models/basis.h +++ b/include/tadah/mlip/models/basis.h @@ -11,6 +11,9 @@ #include <stdexcept> #include <vector> +namespace tadah { +namespace mlip { + template <typename K> class Basis { private: @@ -105,4 +108,6 @@ larger than the amount of available AEDs\n"); } } }; +} +} #endif diff --git a/include/tadah/mlip/models/m_all.h b/include/tadah/mlip/models/m_all.h index d155789207e9ecc83fe25187a50bd95dd15d526b..5a4e798ecfbb142ea2b76579caa4b4f32a6ec0fe 100644 --- a/include/tadah/mlip/models/m_all.h +++ b/include/tadah/mlip/models/m_all.h @@ -1,3 +1,7 @@ //#include "m_tadah_base.h" #include <tadah/mlip/models/m_blr.h> #include <tadah/mlip/models/m_krr.h> +namespace tadah { +namespace mlip { +} +} diff --git a/include/tadah/mlip/models/m_blr.h b/include/tadah/mlip/models/m_blr.h index 7ba9d85a5fe8dc269bd193b9fabf44c4475f7c06..ceca9024b920f987d0e6c9211107e7576aa2df91 100644 --- a/include/tadah/mlip/models/m_blr.h +++ b/include/tadah/mlip/models/m_blr.h @@ -15,6 +15,8 @@ #include <type_traits> #include <iostream> +namespace tadah { +namespace mlip { /** * @class M_BLR * @brief Bayesian Linear Regression (BLR). @@ -105,7 +107,7 @@ public: train(desmat); } - void train(StructureDB &stdb, DC_Base &dc) { + void train(const StructureDB &stdb, const NeighborListDB &nldb, DC_Base &dc) { if(config.template get<bool>("NORM")) { @@ -131,7 +133,7 @@ public: config.add("STRESS", stress); } - desmat.build(stdb,norm,dc); + desmat.build(stdb,nldb,norm,dc); train(desmat); } @@ -332,4 +334,6 @@ non linear basis function\n"); using M_BLR_Train<BF>::weights; using M_BLR_Train<BF>::Sigma; }; +} +} #endif diff --git a/include/tadah/mlip/models/m_krr.h b/include/tadah/mlip/models/m_krr.h index 662e200b73e97b45269638fee12edf3d8592ba36..609c3803a35613cb35c2f288ea6a30e6cb4c11ba 100644 --- a/include/tadah/mlip/models/m_krr.h +++ b/include/tadah/mlip/models/m_krr.h @@ -16,6 +16,8 @@ #include <type_traits> #include <iostream> +namespace tadah { +namespace mlip { /** * @class M_KRR * @brief Kernel Ridge Regression (KRR) with Empirical Kernel Map (EKM). @@ -405,4 +407,6 @@ non linear kernel\n"); using M_KRR_Train<K>::kernel; using M_KRR_Train<K>::ekm; }; +} +} #endif diff --git a/include/tadah/mlip/models/m_tadah_base.h b/include/tadah/mlip/models/m_tadah_base.h index 3285cf0bf5572a1551c60ffaa0d78e74a66b2304..4522eefdd94a1cc54ffe39b68cd3adb6a5699cf7 100644 --- a/include/tadah/mlip/models/m_tadah_base.h +++ b/include/tadah/mlip/models/m_tadah_base.h @@ -11,6 +11,8 @@ #include <tadah/models/m_core.h> #include <tadah/models/m_predict.h> +namespace tadah { +namespace mlip { /** This interface provides functionality required from all models. */ class M_Tadah_Base: @@ -102,5 +104,7 @@ public: virtual StructureDB predict(StructureDB &stdb)=0; }; +} +} //template<> inline Registry<M_Tadah_Base,DM_Function_Base&,Config&>::Map Registry<M_Tadah_Base,DM_Function_Base&,Config&>::registry{}; #endif diff --git a/include/tadah/mlip/neighbor_calc.h b/include/tadah/mlip/neighbor_calc.h index e2e7e5962e78594d4155296f2a875c00e8048762..a26b9a7d48df5b2c3213c5cf8293f77f39171e8f 100644 --- a/include/tadah/mlip/neighbor_calc.h +++ b/include/tadah/mlip/neighbor_calc.h @@ -1,6 +1,6 @@ #pragma once -#include "neighbor_list.h" +#include <tadah/mlip/neighbor_list.h> #include <vector> #include <cmath> @@ -103,6 +103,8 @@ public: } } } + void build (const StructureDB &stdb, NeighborList &nList, double cutoff) { + } }; } // end namespace mlip diff --git a/include/tadah/mlip/neighbor_list_db.h b/include/tadah/mlip/neighbor_list_db.h index 1143607341788e9031f0e3666711eed406ce3495..f7160700cef8a7f4a4b42e44a3a278ce2477f009 100644 --- a/include/tadah/mlip/neighbor_list_db.h +++ b/include/tadah/mlip/neighbor_list_db.h @@ -3,8 +3,9 @@ #include <cstddef> #include <stdexcept> -#include "neighbor_list.h" -#include "structure_db.h" +#include <tadah/mlip/neighbor_list.h> +#include <tadah/mlip/structure_db.h> +#include <tadah/mlip/structure_neighbor_view.h> namespace tadah { namespace mlip { @@ -101,6 +102,22 @@ public: return nlist_.getShiftXPtr(gIdx); } + /** + * @brief Creates a StructureNeighborView for a given Structure, enabling local-atom queries. + * + * This method constructs a view object that references nlist_ (the global HPC arrays) + * and the selected structure. The code can then call numNeighbors(localIdx) or getNeighborsPtr(localIdx) + * without manual offset logic. + * + * @param s A reference to the chosen Structure from the same DB. + * @return A lightweight view object referencing the HPC data subrange for s. + */ + inline StructureNeighborView createStructureView(const Structure &s) const + { + // Optionally check that s.db_ == &db_ if you want to ensure they belong together. + return StructureNeighborView(nlist_, s); + } + private: /** * @brief A read-only reference to the HPC structure database. diff --git a/include/tadah/mlip/structure_neighbor_view.h b/include/tadah/mlip/structure_neighbor_view.h new file mode 100644 index 0000000000000000000000000000000000000000..78dc5c72e56efc88a6a2e9be34ac1657961a21a6 --- /dev/null +++ b/include/tadah/mlip/structure_neighbor_view.h @@ -0,0 +1,122 @@ +#pragma once + +#include <cstddef> +#include <stdexcept> +#include "neighbor_list.h" +#include "structure_db.h" + +namespace tadah { +namespace mlip { + +/** + * @class StructureNeighborView + * @brief Creates a read-only slice of the global NeighborList for a single Structure. + * + * This design references a specific "Structure" (with known offset_ and natoms()) + * and a global "NeighborList" that covers all atoms in the HPC arrays. + * Methods compute each atom's global index on the fly (structure.offset_ + localIndex) + * and retrieve the neighbor data from the HPC neighbor list. + * + * No duplication of neighbor arrays occurs. + * This allows queries on "per-structure" neighbor data while still leveraging + * the single global HPC neighbor list in memory. + */ +class StructureNeighborView +{ +public: + /** + * @brief Constructs the view with references to a single Structure and the global HPC list. + * + * The code checks that the Structure's database pointer matches the DB + * that the "NeighborList" belongs to, if desired. That step is optional + * and can be enforced by additional runtime checks if multiple DBs exist. + * + * @param globalList Reference to the global HPC neighbor list. + * @param structure Reference to the specific Structure for local atom offsets. + */ + StructureNeighborView(const NeighborList &globalList, const Structure &structure) + : globalList_(globalList), str_(structure) + { + // Optional: check that str_.db_ matches the DB used by the globalList if needed. + } + + /** + * @brief Returns how many atoms this structure possesses in its subrange. + */ + inline std::size_t numAtoms() const + { + return str_.natoms(); + } + + /** + * @brief Returns how many neighbors atom (localAtomIndex) has within this structure's sub-slice. + * + * Internally, the HPC neighbor list is indexed by a global atom index. + * The method computes: globalIdx = str_.offset_ + localAtomIndex. + */ + inline std::size_t numNeighbors(std::size_t localAtomIndex) const + { + if (localAtomIndex >= str_.natoms()) { + throw std::out_of_range("StructureNeighborView::numNeighbors: local atom index out of range"); + } + std::size_t gIdx = str_.offset_ + localAtomIndex; + return globalList_.numNeighbors(gIdx); + } + + /** + * @brief Returns a pointer to the neighbor indices for atom (localAtomIndex) in this structure. + * + * The caller typically pairs this with numNeighbors() to iterate over neighbor indices. + */ + inline const std::size_t* getNeighborsPtr(std::size_t localAtomIndex) const + { + if (localAtomIndex >= str_.natoms()) { + throw std::out_of_range("StructureNeighborView::getNeighborsPtr: local atom index out of range"); + } + std::size_t gIdx = str_.offset_ + localAtomIndex; + return globalList_.getNeighborsPtr(gIdx); + } + + /** + * @brief Returns a pointer to the mirror index array for symmetrical pairs, if needed. + */ + inline const std::size_t* getMirrorIndexPtr(std::size_t localAtomIndex) const + { + if (localAtomIndex >= str_.natoms()) { + throw std::out_of_range("StructureNeighborView::getMirrorIndexPtr: local atom index out of range"); + } + std::size_t gIdx = str_.offset_ + localAtomIndex; + return globalList_.getMirrorIndexPtr(gIdx); + } + + /** + * @brief Retrieves X shift array pointer for atom (localAtomIndex). + * Similar methods can be added for Y, Z, etc. + */ + inline const int* getShiftXPtr(std::size_t localAtomIndex) const + { + if (localAtomIndex >= str_.natoms()) { + throw std::out_of_range("StructureNeighborView::getShiftXPtr: local atom index out of range"); + } + std::size_t gIdx = str_.offset_ + localAtomIndex; + return globalList_.getShiftXPtr(gIdx); + } + + // Additional shiftYPtr, shiftZPtr, etc., can be provided if your code needs them. + +private: + /** + * @brief Reference to the global HPC neighbor list that covers all atoms. + */ + const NeighborList &globalList_; + + /** + * @brief Reference to the specific structure. + * offset_ + localIndex => global HPC neighbor index. + */ + const Structure &str_; +}; + +} // namespace mlip +} // namespace tadah + diff --git a/include/tadah/mlip/trainer.h b/include/tadah/mlip/trainer.h index ef20679ba3a99c4c9bf1a8897659379a80d9ea8a..8ebc323b3484b1ef55684cc0f5c9d8696a126a19 100644 --- a/include/tadah/mlip/trainer.h +++ b/include/tadah/mlip/trainer.h @@ -11,6 +11,9 @@ #include <tadah/mlip/memory/IMLIPWorkspaceManager.h> #include <tadah/mlip/memory/MLIPWorkspaceManager.h> +#include <tadah/mlip/neighbor_calc.h> +#include <tadah/mlip/neighbor_list_db.h> + #include <iostream> namespace tadah { @@ -20,10 +23,11 @@ class Trainer { Config config; DC_Selector DCS; DescriptorsCalc<> dc; - NNFinder nnf; + // NNFinder nnf; DM_Function_Base *fb; M_Tadah_Base *model; DesignMatrix<DM_Function_Base&> dm; + StructureDB stdb; ~Trainer() { if(model) @@ -36,7 +40,7 @@ class Trainer { DCS(config), dc(config,*DCS.d2b,*DCS.d3b,*DCS.dmb, *DCS.c2b,*DCS.c3b,*DCS.cmb), - nnf(config), + // nnf(config), fb(CONFIG::factory<DM_Function_Base,Config&>( config.get<std::string>("MODEL",1),config)), model(CONFIG::factory<M_Tadah_Base,DM_Function_Base&,Config&,tadah::mlip::memory::IMLIPWorkspaceManager&> @@ -48,7 +52,7 @@ class Trainer { } void train(StructureDB &stdb) { - nnf.calc(stdb); + // nnf.calc(stdb); model->train(stdb,dc); }