diff --git a/bin/tadah_cli.cpp b/bin/tadah_cli.cpp index aabbced45922e0f4a65e2a7a8e61be5eb1cc1eb1..9f3451ee810569a222b7eb789bfb488dd25781c0 100644 --- a/bin/tadah_cli.cpp +++ b/bin/tadah_cli.cpp @@ -23,27 +23,6 @@ void TadahCLI::subcommand_train() { - int rank = 0; - -#ifdef TADAH_BUILD_MPI - int ncpu = 1; - MPI_Comm_rank(MPI_COMM_WORLD, &rank); - MPI_Comm_size(MPI_COMM_WORLD, &ncpu); - if (ncpu<2) { - std::cout << "Minimum number of cpus for an mpi version is 2" << std::endl; - return; - } - if (train->count("--uncertainty")) { - if (rank==0) { - std::cout << "-----------------------------------------------------" << std::endl; - std::cout << "The --uncertainty flag is not supported by MPI build." << std::endl; - std::cout << "-----------------------------------------------" << std::endl; - } - return; - } -#endif - - CLI::Timer timer_tot {"Training", CLI::Timer::Big}; if(train->count("--verbose")) set_verbose(); @@ -59,10 +38,12 @@ void TadahCLI::subcommand_train() { config.add("STRESS", "true"); } - if (rank==0) - if (is_verbose()) std::cout << "Training mode" << std::endl; - #ifdef TADAH_BUILD_MPI + int rank = 0; + int ncpu = 1; + MPI_Comm_rank(MPI_COMM_WORLD, &rank); + MPI_Comm_size(MPI_COMM_WORLD, &ncpu); + /* MPI CODE: * The PHI matrix is divided into local phi matrices. * @@ -89,13 +70,25 @@ void TadahCLI::subcommand_train() { * 6. Each worker allocates memory for a local matrix phi */ - MPI_Trainer tr(config); - tr.init(rank, ncpu); - // BEGIN HOST-WORKER if (rank==0) { + if (ncpu<2) { + std::cout << "Minimum number of cpus for an mpi version is 2" << std::endl; + return; + } + if (train->count("--uncertainty")) { + if (rank==0) { + std::cout << "-----------------------------------------------------" << std::endl; + std::cout << "The --uncertainty flag is not supported by MPI build." << std::endl; + std::cout << "-----------------------------------------------------" << std::endl; + } + return; + } + if (is_verbose()) std::cout << "Training..." 
<< std::endl; + CLI::Timer timer_tot {"Training", CLI::Timer::Big}; // HOST is waiting for workers requests - MPI_Trainer_HOST HOST(tr, rank, ncpu); + MPI_Trainer_HOST HOST(config, rank, ncpu); + HOST.init(); HOST.prep_wpckgs(); while (true) { if (!HOST.has_packages()) { @@ -105,25 +98,25 @@ void TadahCLI::subcommand_train() { } // probe ANY request from ANY worker - tr.probe(); + HOST.probe(); - if (tr.tag==MPI_Trainer::WORK_TAG) { + if (HOST.tag==MPI_Trainer::WORK_TAG) { HOST.work_tag(); } - else if (tr.tag==MPI_Trainer::DATA_TAG) { + else if (HOST.tag==MPI_Trainer::DATA_TAG) { HOST.data_tag(); } else { throw std::runtime_error("HOST1: Unexpected request from " - + std::to_string(tr.worker)); + + std::to_string(HOST.worker)); } } // work finised, collect remaining data and release all workers int count=1; // count number of release workers, skip host while(true) { - tr.probe(); - if (tr.tag==MPI_Trainer::DATA_TAG) { + HOST.probe(); + if (HOST.tag==MPI_Trainer::DATA_TAG) { HOST.data_tag(); } else { @@ -131,40 +124,45 @@ void TadahCLI::subcommand_train() { } if (count==ncpu) { break; } // count starts from 1 } + + HOST.solve(); + Config param_file = HOST.model->get_param_file(); + param_file.check_pot_file(); + std::ofstream outfile; + outfile.open ("pot.tadah"); + outfile << param_file << std::endl; + if (is_verbose()) std::cout << timer_tot.to_string() << std::endl; } else { // WORKER - MPI_Trainer_WORKER WORKER(tr, rank, ncpu); + MPI_Trainer_WORKER WORKER(config, rank, ncpu); + WORKER.init(); while (true) { // ask for more work... 
- MPI_Send (&tr.rows_available, 1, MPI_INT, 0, MPI_Trainer::WORK_TAG, MPI_COMM_WORLD); + MPI_Send (&WORKER.rows_available, 1, MPI_INT, 0, MPI_Trainer::WORK_TAG, MPI_COMM_WORLD); // request from root or from other workers - tr.probe(); + WORKER.probe(); // release a worker - if (tr.tag == MPI_Trainer::RELEASE_TAG) { + if (WORKER.tag == MPI_Trainer::RELEASE_TAG) { WORKER.release_tag(); break; } - else if (tr.tag == MPI_Trainer::WAIT_TAG) { + else if (WORKER.tag == MPI_Trainer::WAIT_TAG) { // do nothing here; ask for more work in the next cycle WORKER.wait_tag(); } - else if (tr.tag == MPI_Trainer::DATA_TAG) { + else if (WORKER.tag == MPI_Trainer::DATA_TAG) { WORKER.data_tag(); } - else if (tr.tag == MPI_Trainer::WORK_TAG) { + else if (WORKER.tag == MPI_Trainer::WORK_TAG) { WORKER.work_tag(); } } + WORKER.solve(); } - // END HOST-WORKER - // All local phi matrices are computed by this point - // - tr.solve(); - - #else // NON MPI VERSION + CLI::Timer timer_tot {"Training", CLI::Timer::Big}; Trainer tr(config); if (is_verbose()) std::cout << "Loading structures..." << std::flush; StructureDB stdb(tr.config); @@ -182,33 +180,27 @@ void TadahCLI::subcommand_train() { if (is_verbose()) std::cout << "Done!" 
<< std::endl; -#endif - - if (rank==0) { - Config param_file = tr.model->get_param_file(); - param_file.check_pot_file(); - std::ofstream outfile; - outfile.open ("pot.tadah"); - outfile << param_file << std::endl;; - -#ifndef TADAH_BUILD_MPI - if(train->count("--uncertainty")) { - t_type weights = tr.model->get_weights(); - t_type unc = tr.model->get_weights_uncertainty(); - Output(param_file,false).print_train_unc(weights, unc); - } -#endif + Config param_file = tr.model->get_param_file(); + param_file.check_pot_file(); + std::ofstream outfile; + outfile.open ("pot.tadah"); + outfile << param_file << std::endl; - if (is_verbose()) std::cout << timer_tot.to_string() << std::endl; + if(train->count("--uncertainty")) { + t_type weights = tr.model->get_weights(); + t_type unc = tr.model->get_weights_uncertainty(); + Output(param_file,false).print_train_unc(weights, unc); } + if (is_verbose()) std::cout << timer_tot.to_string() << std::endl; +#endif } void TadahCLI::subcommand_predict() { CLI::Timer timer_tot {"Prediction", CLI::Timer::Big}; if(predict->count("--verbose")) set_verbose(); - if (is_verbose()) std::cout << "Prediction mode" << std::endl; + // if (is_verbose()) std::cout << "Prediction..." << std::endl; Config pot_config(pot_file); pot_config.check_for_predict(); pot_config.remove("VERBOSE"); @@ -256,7 +248,7 @@ void TadahCLI::subcommand_predict() { DescriptorsCalc<> dc(pot_config,*DCS.d2b,*DCS.d3b,*DCS.dmb, *DCS.c2b,*DCS.c3b,*DCS.cmb); - if (is_verbose()) std::cout << "Prediction start..." << std::flush; + if (is_verbose()) std::cout << "Prediction..." << std::flush; DM_Function_Base *fb = CONFIG::factory<DM_Function_Base,Config&>( pot_config.get<std::string>("MODEL",1),pot_config); M_Tadah_Base *modelp = CONFIG::factory<M_Tadah_Base,DM_Function_Base&,Config&>( @@ -328,25 +320,25 @@ void TadahCLI::subcommand_hpo( #endif if (rank==0) - if (is_verbose()) std::cout << "HPO mode" << std::endl; + if (is_verbose()) std::cout << "Optimising HPs..." 
<< std::endl; - CLI::Timer timer_tot {"HPO", CLI::Timer::Big}; + CLI::Timer timer_tot {"HPO", CLI::Timer::Big}; - if(hpo->count("--verbose")) - set_verbose(); - Config config(config_file); - config.remove("CHECKPRESS"); - config.add("CHECKPRESS", "true"); - config.check_for_training(); + if(hpo->count("--verbose")) + set_verbose(); + Config config(config_file); + config.remove("CHECKPRESS"); + config.add("CHECKPRESS", "true"); + config.check_for_training(); - if(hpo->count("--Force")) { - config.remove("FORCE"); - config.add("FORCE", "true"); - } - if(hpo->count("--Stress")) { - config.remove("STRESS"); - config.add("STRESS", "true"); - } + if(hpo->count("--Force")) { + config.remove("FORCE"); + config.add("FORCE", "true"); + } + if(hpo->count("--Stress")) { + config.remove("STRESS"); + config.add("STRESS", "true"); + } #ifdef TADAH_BUILD_MPI if (rank==0) { @@ -364,48 +356,46 @@ void TadahCLI::subcommand_hpo( // TODO broadcase config to all workers instead of dumping it on the disk Config config("config.temp"); - MPI_Trainer tr(config); - tr.init(rank, ncpu); - MPI_Trainer_WORKER WORKER(tr, rank, ncpu); + MPI_Trainer_WORKER WORKER(config, rank, ncpu); + WORKER.init(); while (true) { // ask for more work... 
- MPI_Send (&tr.rows_available, 1, MPI_INT, 0, MPI_Trainer::WORK_TAG, MPI_COMM_WORLD); + MPI_Send (&WORKER.rows_available, 1, MPI_INT, 0, MPI_Trainer::WORK_TAG, MPI_COMM_WORLD); // request from root or from other workers - tr.probe(); + WORKER.probe(); // release a worker - if (tr.tag == MPI_Trainer::RELEASE_TAG) { + if (WORKER.tag == MPI_Trainer::RELEASE_TAG) { WORKER.release_tag(); break; } - else if (tr.tag == MPI_Trainer::WAIT_TAG) { + else if (WORKER.tag == MPI_Trainer::WAIT_TAG) { // do nothing here; ask for more work in the next cycle WORKER.wait_tag(); } - else if (tr.tag == MPI_Trainer::DATA_TAG) { + else if (WORKER.tag == MPI_Trainer::DATA_TAG) { WORKER.data_tag(); } - else if (tr.tag == MPI_Trainer::WORK_TAG) { + else if (WORKER.tag == MPI_Trainer::WORK_TAG) { WORKER.work_tag(); } } - - tr.solve(); + WORKER.solve(); } } #else - hpo_run(config, target_file); + hpo_run(config, target_file); #endif - if (rank==0) - if (is_verbose()) std::cout << timer_tot.to_string() << std::endl; + if (rank==0) + if (is_verbose()) std::cout << timer_tot.to_string() << std::endl; #else - std::cout << "-----------------------------------------------" << std::endl; - std::cout << "This subcommand is not supported by this build." << std::endl; - std::cout << "Tadah! Must by compiled with HPO support." << std::endl; - std::cout << "See documentation for details." << std::endl; - std::cout << "-----------------------------------------------" << std::endl; + std::cout << "-----------------------------------------------" << std::endl; + std::cout << "This subcommand is not supported by this build." << std::endl; + std::cout << "Tadah! Must be compiled with HPO support." << std::endl; + std::cout << "See documentation for details." << std::endl; + std::cout << "-----------------------------------------------" << std::endl; #endif }