diff --git a/include/tadah/mlip/models/m_blr.h b/include/tadah/mlip/models/m_blr.h index 943a7fea1a65056aa517bc03e1d9b3e548fa6479..c8e3eb0c9479d52de33c8d237af1efeb9618cee7 100644 --- a/include/tadah/mlip/models/m_blr.h +++ b/include/tadah/mlip/models/m_blr.h @@ -8,6 +8,7 @@ #include <tadah/mlip/normaliser.h> #include <tadah/models/m_blr_train.h> #include <tadah/core/config.h> +#include <tadah/models/memory/IModelsWorkspaceManager.h> #include <limits> #include <stdexcept> @@ -76,6 +77,13 @@ public: norm = Normaliser(c); } + M_BLR(BF &bf, Config &c, tadah::models::memory::IModelsWorkspaceManager& workspaceManager): + M_BLR_Train<BF>(bf,c,workspaceManager), + desmat(bf,c) + { + norm = Normaliser(c); + } + double epredict(const aed_type &aed) const{ return bf.epredict(weights,aed); }; diff --git a/include/tadah/mlip/models/m_krr.h b/include/tadah/mlip/models/m_krr.h index 92aee1d7c9d56324409000cd8c69710898bd44f3..ccd2c6c903949892a8b2e932700be975b285a7ca 100644 --- a/include/tadah/mlip/models/m_krr.h +++ b/include/tadah/mlip/models/m_krr.h @@ -9,6 +9,7 @@ #include <tadah/models/m_krr_train.h> #include <tadah/core/config.h> #include <tadah/mlip/models/basis.h> +#include <tadah/models/memory/IModelsWorkspaceManager.h> #include <limits> #include <stdexcept> @@ -79,6 +80,13 @@ public: norm = Normaliser(c); } + M_KRR(K &kernel, Config &c, tadah::models::memory::IModelsWorkspaceManager& workspaceManager): + M_KRR_Train<K>(kernel,c,workspaceManager), + basis(c), + desmat(kernel,c) + { + norm = Normaliser(c); + } double epredict(const aed_type &aed) const { return kernel.epredict(weights,aed); }; diff --git a/include/tadah/mlip/trainer.h b/include/tadah/mlip/trainer.h index e2069170a13d0ea6e704751f9430cf8d2d48749b..a212a9dcf14895d8111f71aaa531535d3c219d5a 100644 --- a/include/tadah/mlip/trainer.h +++ b/include/tadah/mlip/trainer.h @@ -36,7 +36,7 @@ class Trainer { nnf(config), fb(CONFIG::factory<DM_Function_Base,Config&>( config.get<std::string>("MODEL",1),config)), - model(CONFIG::factory<M_Tadah_Base,DM_Function_Base&,Config&> + model(CONFIG::factory<M_Tadah_Base,DM_Function_Base&,Config&,tadah::models::memory::IModelsWorkspaceManager&> (config.get<std::string>("MODEL",0),*fb,config,workspaceManager)), //(config.get<std::string>("MODEL",0),*fb,config)), dm(*fb, config) diff --git a/src/m_all.cpp b/src/m_all.cpp index 9c67972b5feb51734dcf20972aec98237cbfab37..fa99c8c93ceed85089bc90613e732d6a86d47a39 100644 --- a/src/m_all.cpp +++ b/src/m_all.cpp @@ -1,5 +1,16 @@ #include <tadah/mlip/models/m_all.h> +#include <tadah/models/memory/IModelsWorkspaceManager.h> -template<> CONFIG::Registry<M_Tadah_Base,DM_Function_Base&,Config&>::Map CONFIG::Registry<M_Tadah_Base,DM_Function_Base&,Config&>::registry{}; -CONFIG::Registry<M_Tadah_Base,DM_Function_Base&,Config&>::Register<M_KRR<>> M_KRR_1("M_KRR"); -CONFIG::Registry<M_Tadah_Base,DM_Function_Base&,Config&>::Register<M_BLR<>> M_BLR_1("M_BLR"); +template<> +CONFIG::Registry<M_Tadah_Base, DM_Function_Base&, Config&>::Map +CONFIG::Registry<M_Tadah_Base, DM_Function_Base&, Config&>::registry{}; + +CONFIG::Registry<M_Tadah_Base, DM_Function_Base&, Config&>::Register<M_KRR<>> M_KRR_1("M_KRR"); +CONFIG::Registry<M_Tadah_Base, DM_Function_Base&, Config&>::Register<M_BLR<>> M_BLR_1("M_BLR"); + +template<> +CONFIG::Registry<M_Tadah_Base, DM_Function_Base&, Config&, tadah::models::memory::IModelsWorkspaceManager&>::Map +CONFIG::Registry<M_Tadah_Base, DM_Function_Base&, Config&, tadah::models::memory::IModelsWorkspaceManager&>::registry{}; + +CONFIG::Registry<M_Tadah_Base, DM_Function_Base&, Config&, tadah::models::memory::IModelsWorkspaceManager&>::Register<M_BLR<>> M_BLR_2("M_BLR"); +CONFIG::Registry<M_Tadah_Base, DM_Function_Base&, Config&, tadah::models::memory::IModelsWorkspaceManager&>::Register<M_KRR<>> M_KRR_2("M_KRR");