diff --git a/include/tadah/mlip/trainer.h b/include/tadah/mlip/trainer.h index aa56caeb096263cd1160af647d3b3a3e50e833bb..e2069170a13d0ea6e704751f9430cf8d2d48749b 100644 --- a/include/tadah/mlip/trainer.h +++ b/include/tadah/mlip/trainer.h @@ -8,6 +8,7 @@ #include <tadah/mlip/nn_finder.h> #include <tadah/core/config.h> #include <tadah/models/dc_selector.h> +#include <tadah/models/memory/IModelsWorkspaceManager.h> #include <iostream> @@ -27,6 +28,22 @@ class Trainer { if(fb) delete fb; } + Trainer (Config &c, tadah::models::memory::IModelsWorkspaceManager& workspaceManager): + config(c), + DCS(config), + dc(config,*DCS.d2b,*DCS.d3b,*DCS.dmb, + *DCS.c2b,*DCS.c3b,*DCS.cmb), + nnf(config), + fb(CONFIG::factory<DM_Function_Base,Config&>( + config.get<std::string>("MODEL",1),config)), + model(CONFIG::factory<M_Tadah_Base,DM_Function_Base&,Config&> + (config.get<std::string>("MODEL",0),*fb,config,workspaceManager)), + //(config.get<std::string>("MODEL",0),*fb,config)), + dm(*fb, config) + { + config.postprocess(); + config.check_for_training(); + } Trainer (Config &c): config(c), DCS(config),