diff --git a/bin/tadah_cli.cpp b/bin/tadah_cli.cpp index ee7351d44686f1f339f8d68a4a26d676a3c9208d..1fe6ee6973fbf3b39fabeca31d54640b682dcb5e 100644 --- a/bin/tadah_cli.cpp +++ b/bin/tadah_cli.cpp @@ -23,6 +23,10 @@ void TadahCLI::subcommand_train() { /* MPI CODE: * 1. each rank reads config file * 2. rank 0 should calculate total number of structures + * 3. rank 0 should be able to calculate the size of the PHI matrix + * based on force and stress flags + * Who should be able to do this? We do not want to load all data + * 4. rank 0 should send the local matrix sizes to worker processes */ @@ -31,8 +35,6 @@ void TadahCLI::subcommand_train() { set_verbose(); Config config(config_file); config.check_for_training(); - if (is_verbose()) std::cout << "Training mode" << std::endl; - DC_Selector DCS(config); if(train->count("--Force")) { config.remove("FORCE"); @@ -43,10 +45,40 @@ void TadahCLI::subcommand_train() { config.add("STRESS", "true"); } + if (is_verbose()) std::cout << "Training mode" << std::endl; + DC_Selector DCS(config); + + // must set DSIZE key as early as possible, + // this will allow us to querry for the number of columns later on + DescriptorsCalc<> dc(config,*DCS.d2b,*DCS.d3b,*DCS.dmb, + *DCS.c2b,*DCS.c3b,*DCS.cmb); + + DM_Function_Base *fb = CONFIG::factory<DM_Function_Base,Config&>( + config.get<std::string>("MODEL",1),config); + + + M_Tadah_Base *model = CONFIG::factory<M_Tadah_Base,DM_Function_Base&,Config&> + (config.get<std::string>("MODEL",0),*fb,config); + + std::cout << "PHI COLS: " << fb->get_phi_cols(config) << std::endl; + //std::cout << "PHI ROWS: " << ??? << std::endl; + + if (is_verbose()) std::cout << "Loading structures..." << std::flush; StructureDB stdb(config); if (is_verbose()) std::cout << "Done!" << std::endl; - std::cout << "Test count: " << StructureDB::count_structures(config); + + std::cout << "Total structure count: " << StructureDB::count(config).first << std::endl; + std::cout << "Total natoms count: " << StructureDB::count(config).second << std::endl; + + StructureDB temp; + int nst = temp.add("db.train", 99, 3); + std::cout << "Number of structures loaded: " << temp.size() << std::endl; + std::cout << "Number of structures loaded: " << nst << std::endl; + std::cout << temp << std::endl; + std::cout << temp(0) << std::endl; + + if (is_verbose()) std::cout << "Finding nearest neighbours within: " << config.get<double>("RCUTMAX") << " cutoff distance..." << std::flush; @@ -55,13 +87,8 @@ void TadahCLI::subcommand_train() { if (is_verbose()) std::cout << "Done!" << std::endl; if (is_verbose()) std::cout << "Training start..." << std::flush; - DM_Function_Base *fb = CONFIG::factory<DM_Function_Base,Config&>( - config.get<std::string>("MODEL",1),config); - M_Tadah_Base *model = CONFIG::factory<M_Tadah_Base,DM_Function_Base&,Config&> - (config.get<std::string>("MODEL",0),*fb,config); - DescriptorsCalc<> dc(config,*DCS.d2b,*DCS.d3b,*DCS.dmb, - *DCS.c2b,*DCS.c3b,*DCS.cmb); + model->train(stdb,dc);