From a4f0b5cc42becef2fd55a12ae4784e5421a32139 Mon Sep 17 00:00:00 2001 From: mkirsz <s1351949@sms.ed.ac.uk> Date: Wed, 19 Feb 2025 00:21:55 +0000 Subject: [PATCH] Integration of memoryManager --- include/tadah/mlip/models/m_blr.h | 8 ++++++++ include/tadah/mlip/models/m_krr.h | 8 ++++++++ include/tadah/mlip/trainer.h | 2 +- src/m_all.cpp | 17 ++++++++++++++--- 4 files changed, 31 insertions(+), 4 deletions(-) diff --git a/include/tadah/mlip/models/m_blr.h b/include/tadah/mlip/models/m_blr.h index 943a7fe..c8e3eb0 100644 --- a/include/tadah/mlip/models/m_blr.h +++ b/include/tadah/mlip/models/m_blr.h @@ -8,6 +8,7 @@ #include <tadah/mlip/normaliser.h> #include <tadah/models/m_blr_train.h> #include <tadah/core/config.h> +#include <tadah/models/memory/IModelsWorkspaceManager.h> #include <limits> #include <stdexcept> @@ -76,6 +77,13 @@ public: norm = Normaliser(c); } + M_BLR(BF &bf, Config &c, tadah::models::memory::IModelsWorkspaceManager& workspaceManager): + M_BLR_Train<BF>(bf,c,workspaceManager), + desmat(bf,c) + { + norm = Normaliser(c); + } + double epredict(const aed_type &aed) const{ return bf.epredict(weights,aed); }; diff --git a/include/tadah/mlip/models/m_krr.h b/include/tadah/mlip/models/m_krr.h index 92aee1d..ccd2c6c 100644 --- a/include/tadah/mlip/models/m_krr.h +++ b/include/tadah/mlip/models/m_krr.h @@ -9,6 +9,7 @@ #include <tadah/models/m_krr_train.h> #include <tadah/core/config.h> #include <tadah/mlip/models/basis.h> +#include <tadah/models/memory/IModelsWorkspaceManager.h> #include <limits> #include <stdexcept> @@ -79,6 +80,13 @@ public: norm = Normaliser(c); } + M_KRR(K &kernel, Config &c, tadah::models::memory::IModelsWorkspaceManager& workspaceManager): + M_KRR_Train<K>(kernel,c,workspaceManager), + basis(c), + desmat(kernel,c) + { + norm = Normaliser(c); + } double epredict(const aed_type &aed) const { return kernel.epredict(weights,aed); }; diff --git a/include/tadah/mlip/trainer.h b/include/tadah/mlip/trainer.h index e206917..a212a9d 100644 --- a/include/tadah/mlip/trainer.h +++ b/include/tadah/mlip/trainer.h @@ -36,7 +36,7 @@ class Trainer { nnf(config), fb(CONFIG::factory<DM_Function_Base,Config&>( config.get<std::string>("MODEL",1),config)), - model(CONFIG::factory<M_Tadah_Base,DM_Function_Base&,Config&> + model(CONFIG::factory<M_Tadah_Base,DM_Function_Base&,Config&,tadah::models::memory::IModelsWorkspaceManager&> (config.get<std::string>("MODEL",0),*fb,config,workspaceManager)), //(config.get<std::string>("MODEL",0),*fb,config)), dm(*fb, config) diff --git a/src/m_all.cpp b/src/m_all.cpp index 9c67972..fa99c8c 100644 --- a/src/m_all.cpp +++ b/src/m_all.cpp @@ -1,5 +1,16 @@ #include <tadah/mlip/models/m_all.h> +#include <tadah/models/memory/IModelsWorkspaceManager.h> -template<> CONFIG::Registry<M_Tadah_Base,DM_Function_Base&,Config&>::Map CONFIG::Registry<M_Tadah_Base,DM_Function_Base&,Config&>::registry{}; -CONFIG::Registry<M_Tadah_Base,DM_Function_Base&,Config&>::Register<M_KRR<>> M_KRR_1("M_KRR"); -CONFIG::Registry<M_Tadah_Base,DM_Function_Base&,Config&>::Register<M_BLR<>> M_BLR_1("M_BLR"); +template<> +CONFIG::Registry<M_Tadah_Base, DM_Function_Base&, Config&>::Map +CONFIG::Registry<M_Tadah_Base, DM_Function_Base&, Config&>::registry{}; + +CONFIG::Registry<M_Tadah_Base, DM_Function_Base&, Config&>::Register<M_KRR<>> M_KRR_1("M_KRR"); +CONFIG::Registry<M_Tadah_Base, DM_Function_Base&, Config&>::Register<M_BLR<>> M_BLR_1("M_BLR"); + +template<> +CONFIG::Registry<M_Tadah_Base, DM_Function_Base&, Config&, tadah::models::memory::IModelsWorkspaceManager&>::Map +CONFIG::Registry<M_Tadah_Base, DM_Function_Base&, Config&, tadah::models::memory::IModelsWorkspaceManager&>::registry{}; + +CONFIG::Registry<M_Tadah_Base, DM_Function_Base&, Config&, tadah::models::memory::IModelsWorkspaceManager&>::Register<M_BLR<>> M_BLR_2("M_BLR"); +CONFIG::Registry<M_Tadah_Base, DM_Function_Base&, Config&, tadah::models::memory::IModelsWorkspaceManager&>::Register<M_KRR<>> M_KRR_2("M_KRR"); -- GitLab