diff --git a/include/tadah/mlip/design_matrix/design_matrix.h b/include/tadah/mlip/design_matrix/design_matrix.h index 43efa09d70bfd298dc62c514e8b9f3b792c6f764..c9c02cb42c38f7da8e66b6da196243f98f3873e1 100644 --- a/include/tadah/mlip/design_matrix/design_matrix.h +++ b/include/tadah/mlip/design_matrix/design_matrix.h @@ -134,6 +134,9 @@ public: void build(StDescriptorsDB &st_desc_db, const StructureDB &stdb) { calc_mn(stdb); ws = workspaceManager_->getDesignMatrixWorkspace(rows,cols); + ws->Phi.set_zero(); + ws->T.set_zero(); + ws->Tlabels.set_zero(); compute_stdevs(stdb); fill_T(stdb); std::vector<size_t> rows(stdb.size()); @@ -166,6 +169,9 @@ public: //DescriptorsCalc<D2,D3,DM,C2,C3,CM> dc(config); calc_mn(stdb); ws = workspaceManager_->getDesignMatrixWorkspace(rows,cols); + ws->Phi.set_zero(); + ws->T.set_zero(); + ws->Tlabels.set_zero(); compute_stdevs(stdb); fill_T(stdb); // for opm we need to find first rows for each structure diff --git a/include/tadah/mlip/models/m_blr.h b/include/tadah/mlip/models/m_blr.h index 596730ad23bd098f6ce2189a2e8ba1350c97a112..7ba9d85a5fe8dc269bd193b9fabf44c4475f7c06 100644 --- a/include/tadah/mlip/models/m_blr.h +++ b/include/tadah/mlip/models/m_blr.h @@ -79,7 +79,7 @@ public: M_BLR(BF &bf, Config &c, tadah::mlip::memory::IMLIPWorkspaceManager& workspaceManager): M_BLR_Train<BF>(bf,c,workspaceManager), - desmat(bf,c) + desmat(bf,c,workspaceManager) { norm = Normaliser(c); } @@ -223,7 +223,6 @@ public: Hint: check different predict() methods."); phi_type &Phi = desmat.getPhi(); - //std::cout << Phi.row(0) << std::endl; // compute energy, forces and stresses aed_type Tpred = T_dgemv(Phi, weights); diff --git a/include/tadah/mlip/trainer.h b/include/tadah/mlip/trainer.h index 629e593c6328c4c3d945ad5587d020b43a2af6c6..9bbf598130b961ffbbf1ad5e82a3dceba24eb985 100644 --- a/include/tadah/mlip/trainer.h +++ b/include/tadah/mlip/trainer.h @@ -9,6 +9,7 @@ #include <tadah/core/config.h> #include <tadah/models/dc_selector.h> #include <tadah/mlip/memory/IMLIPWorkspaceManager.h> +#include <tadah/mlip/memory/MLIPWorkspaceManager.h> #include <iostream> @@ -38,27 +39,11 @@ class Trainer { config.get<std::string>("MODEL",1),config)), model(CONFIG::factory<M_Tadah_Base,DM_Function_Base&,Config&,tadah::mlip::memory::IMLIPWorkspaceManager&> (config.get<std::string>("MODEL",0),*fb,config,workspaceManager)), - //(config.get<std::string>("MODEL",0),*fb,config)), dm(*fb, config, workspaceManager) { config.postprocess(); config.check_for_training(); } - Trainer (Config &c): - config(c), - DCS(config), - dc(config,*DCS.d2b,*DCS.d3b,*DCS.dmb, - *DCS.c2b,*DCS.c3b,*DCS.cmb), - nnf(config), - fb(CONFIG::factory<DM_Function_Base,Config&>( - config.get<std::string>("MODEL",1),config)), - model(CONFIG::factory<M_Tadah_Base,DM_Function_Base&,Config&> - (config.get<std::string>("MODEL",0),*fb,config)), - dm(*fb, config) - { - config.postprocess(); - config.check_for_training(); - } void train(StructureDB &stdb) { nnf.calc(stdb);