MPI version of Tadah

Closed: Marcin Kirsz requested to merge develop into main
1 file  +91 −101
@@ -23,27 +23,6 @@
void TadahCLI::subcommand_train() {
- int rank = 0;
- #ifdef TADAH_BUILD_MPI
- int ncpu = 1;
- MPI_Comm_rank(MPI_COMM_WORLD, &rank);
- MPI_Comm_size(MPI_COMM_WORLD, &ncpu);
- if (ncpu<2) {
- std::cout << "Minimum number of cpus for an mpi version is 2" << std::endl;
- return;
- }
- if (train->count("--uncertainty")) {
- if (rank==0) {
- std::cout << "-----------------------------------------------------" << std::endl;
- std::cout << "The --uncertainty flag is not supported by MPI build." << std::endl;
- std::cout << "-----------------------------------------------" << std::endl;
- }
- return;
- }
- #endif
CLI::Timer timer_tot {"Training", CLI::Timer::Big};
if(train->count("--verbose"))
set_verbose();
@@ -59,10 +38,12 @@ void TadahCLI::subcommand_train() {
config.add("STRESS", "true");
}
if (rank==0)
if (is_verbose()) std::cout << "Training mode" << std::endl;
#ifdef TADAH_BUILD_MPI
+ int rank = 0;
+ int ncpu = 1;
+ MPI_Comm_rank(MPI_COMM_WORLD, &rank);
+ MPI_Comm_size(MPI_COMM_WORLD, &ncpu);
/* MPI CODE:
* The PHI matrix is divided into local phi matrices.
*
@@ -89,13 +70,25 @@ void TadahCLI::subcommand_train() {
* 6. Each worker allocates memory for a local matrix phi
*/
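As an aside on the decomposition the comment above describes: below is a minimal sketch of how a design matrix might be split into contiguous row blocks, one per worker rank. The names global_rows, local_rows, and first_row are illustrative, not Tadah's API; the actual partitioning lives inside MPI_Trainer.

#include <mpi.h>

int main(int argc, char **argv) {
  MPI_Init(&argc, &argv);
  int rank = 0, ncpu = 1;
  MPI_Comm_rank(MPI_COMM_WORLD, &rank);
  MPI_Comm_size(MPI_COMM_WORLD, &ncpu);
  const int global_rows = 1000;   // total rows of PHI (made-up value)
  const int workers = ncpu - 1;   // rank 0 acts as the host
  if (rank > 0 && workers > 0) {
    const int w = rank - 1;                 // 0-based worker index
    const int base = global_rows / workers;
    const int rem  = global_rows % workers;
    // The first rem workers take one extra row each.
    const int local_rows = base + (w < rem ? 1 : 0);
    const int first_row  = w * base + (w < rem ? w : rem);
    // A worker would allocate its local_rows x ncols block of phi here.
    (void)local_rows; (void)first_row;
  }
  MPI_Finalize();
  return 0;
}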
MPI_Trainer tr(config);
- tr.init(rank, ncpu);
// BEGIN HOST-WORKER
if (rank==0) {
+ if (ncpu<2) {
+ std::cout << "Minimum number of CPUs for the MPI version is 2" << std::endl;
+ return;
+ }
+ if (train->count("--uncertainty")) {
+ if (rank==0) {
+ std::cout << "-----------------------------------------------------" << std::endl;
+ std::cout << "The --uncertainty flag is not supported by the MPI build." << std::endl;
+ std::cout << "-----------------------------------------------------" << std::endl;
+ }
+ return;
+ }
if (is_verbose()) std::cout << "Training..." << std::endl;
CLI::Timer timer_tot {"Training", CLI::Timer::Big};
// HOST waits for worker requests
- MPI_Trainer_HOST HOST(tr, rank, ncpu);
+ MPI_Trainer_HOST HOST(config, rank, ncpu);
+ HOST.init();
HOST.prep_wpckgs();
while (true) {
if (!HOST.has_packages()) {
@@ -105,25 +98,25 @@ void TadahCLI::subcommand_train() {
}
// probe ANY request from ANY worker
- tr.probe();
+ HOST.probe();
- if (tr.tag==MPI_Trainer::WORK_TAG) {
+ if (HOST.tag==MPI_Trainer::WORK_TAG) {
HOST.work_tag();
}
- else if (tr.tag==MPI_Trainer::DATA_TAG) {
+ else if (HOST.tag==MPI_Trainer::DATA_TAG) {
HOST.data_tag();
}
else {
throw std::runtime_error("HOST1: Unexpected request from "
-   + std::to_string(tr.worker));
+   + std::to_string(HOST.worker));
}
}
// work finished, collect remaining data and release all workers
int count=1; // counts released workers; skip host
while(true) {
- tr.probe();
- if (tr.tag==MPI_Trainer::DATA_TAG) {
+ HOST.probe();
+ if (HOST.tag==MPI_Trainer::DATA_TAG) {
HOST.data_tag();
}
else {
@@ -131,40 +124,45 @@ void TadahCLI::subcommand_train() {
}
if (count==ncpu) { break; } // count starts from 1
}
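The host loop above is a standard tag-dispatch pattern: probe for any message from any rank, then branch on the tag. A self-contained sketch of that pattern follows; the tag values and the single-int payload are assumptions, and probe_and_dispatch is a hypothetical stand-in for HOST.probe() plus the branch, not Tadah's implementation.

#include <mpi.h>
#include <stdexcept>
#include <string>

enum { WORK_TAG = 1, DATA_TAG = 2 };   // example values only

void probe_and_dispatch() {
  MPI_Status status;
  // Block until any worker sends anything; status records who sent it and why.
  MPI_Probe(MPI_ANY_SOURCE, MPI_ANY_TAG, MPI_COMM_WORLD, &status);
  const int worker = status.MPI_SOURCE;
  int payload;
  MPI_Recv(&payload, 1, MPI_INT, worker, status.MPI_TAG,
           MPI_COMM_WORLD, MPI_STATUS_IGNORE);
  if (status.MPI_TAG == WORK_TAG) {
    // worker is idle: hand it the next package of rows, or a WAIT/RELEASE tag
  } else if (status.MPI_TAG == DATA_TAG) {
    // worker is returning results: receive its computed rows here
  } else {
    throw std::runtime_error("Unexpected request from " + std::to_string(worker));
  }
}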
+ HOST.solve();
+ Config param_file = HOST.model->get_param_file();
+ param_file.check_pot_file();
+ std::ofstream outfile;
+ outfile.open ("pot.tadah");
+ outfile << param_file << std::endl;
+ if (is_verbose()) std::cout << timer_tot.to_string() << std::endl;
}
else { // WORKER
- MPI_Trainer_WORKER WORKER(tr, rank, ncpu);
+ MPI_Trainer_WORKER WORKER(config, rank, ncpu);
WORKER.init();
while (true) {
// ask for more work...
- MPI_Send (&tr.rows_available, 1, MPI_INT, 0, MPI_Trainer::WORK_TAG, MPI_COMM_WORLD);
+ MPI_Send (&WORKER.rows_available, 1, MPI_INT, 0, MPI_Trainer::WORK_TAG, MPI_COMM_WORLD);
// request from root or from other workers
- tr.probe();
+ WORKER.probe();
// release a worker
- if (tr.tag == MPI_Trainer::RELEASE_TAG) {
+ if (WORKER.tag == MPI_Trainer::RELEASE_TAG) {
WORKER.release_tag();
break;
}
- else if (tr.tag == MPI_Trainer::WAIT_TAG) {
+ else if (WORKER.tag == MPI_Trainer::WAIT_TAG) {
// do nothing here; ask for more work in the next cycle
WORKER.wait_tag();
}
- else if (tr.tag == MPI_Trainer::DATA_TAG) {
+ else if (WORKER.tag == MPI_Trainer::DATA_TAG) {
WORKER.data_tag();
}
- else if (tr.tag == MPI_Trainer::WORK_TAG) {
+ else if (WORKER.tag == MPI_Trainer::WORK_TAG) {
WORKER.work_tag();
}
}
+ WORKER.solve();
}
// END HOST-WORKER
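The worker side mirrors the host: announce free capacity with WORK_TAG, probe the reply, and dispatch on its tag until RELEASE_TAG arrives. A sketch under the same caveats: the tag values and payloads are assumptions, not MPI_Trainer_WORKER's internals.

#include <mpi.h>

enum { WORK = 1, DATA = 2, WAIT = 3, RELEASE = 4 };  // example values only

void worker_loop(int rows_available) {
  while (true) {
    // Tell the host (rank 0) how many rows this worker can still take.
    MPI_Send(&rows_available, 1, MPI_INT, 0, WORK, MPI_COMM_WORLD);
    MPI_Status status;
    MPI_Probe(MPI_ANY_SOURCE, MPI_ANY_TAG, MPI_COMM_WORLD, &status);
    int payload;
    MPI_Recv(&payload, 1, MPI_INT, status.MPI_SOURCE, status.MPI_TAG,
             MPI_COMM_WORLD, MPI_STATUS_IGNORE);
    if (status.MPI_TAG == RELEASE) break;  // done: fall through to the local solve
    // WAIT: retry next cycle; WORK/DATA: compute or exchange rows here.
  }
}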
- // All local phi matrices are computed by this point
- //
- tr.solve();
#else // NON MPI VERSION
CLI::Timer timer_tot {"Training", CLI::Timer::Big};
Trainer tr(config);
if (is_verbose()) std::cout << "Loading structures..." << std::flush;
StructureDB stdb(tr.config);
@@ -182,33 +180,27 @@ void TadahCLI::subcommand_train() {
if (is_verbose()) std::cout << "Done!" << std::endl;
- #endif
- if (rank==0) {
- Config param_file = tr.model->get_param_file();
- param_file.check_pot_file();
- std::ofstream outfile;
- outfile.open ("pot.tadah");
- outfile << param_file << std::endl;;
- #ifndef TADAH_BUILD_MPI
- if(train->count("--uncertainty")) {
- t_type weights = tr.model->get_weights();
- t_type unc = tr.model->get_weights_uncertainty();
- Output(param_file,false).print_train_unc(weights, unc);
- }
- #endif
+ Config param_file = tr.model->get_param_file();
+ param_file.check_pot_file();
+ std::ofstream outfile;
+ outfile.open ("pot.tadah");
+ outfile << param_file << std::endl;
+ if (is_verbose()) std::cout << timer_tot.to_string() << std::endl;
+ if(train->count("--uncertainty")) {
+ t_type weights = tr.model->get_weights();
+ t_type unc = tr.model->get_weights_uncertainty();
+ Output(param_file,false).print_train_unc(weights, unc);
+ }
+ if (is_verbose()) std::cout << timer_tot.to_string() << std::endl;
+ #endif
}
void TadahCLI::subcommand_predict() {
CLI::Timer timer_tot {"Prediction", CLI::Timer::Big};
if(predict->count("--verbose"))
set_verbose();
if (is_verbose()) std::cout << "Prediction mode" << std::endl;
// if (is_verbose()) std::cout << "Prediction..." << std::endl;
Config pot_config(pot_file);
pot_config.check_for_predict();
pot_config.remove("VERBOSE");
@@ -256,7 +248,7 @@ void TadahCLI::subcommand_predict() {
DescriptorsCalc<> dc(pot_config,*DCS.d2b,*DCS.d3b,*DCS.dmb,
*DCS.c2b,*DCS.c3b,*DCS.cmb);
if (is_verbose()) std::cout << "Prediction start..." << std::flush;
if (is_verbose()) std::cout << "Prediction..." << std::flush;
DM_Function_Base *fb = CONFIG::factory<DM_Function_Base,Config&>(
pot_config.get<std::string>("MODEL",1),pot_config);
M_Tadah_Base *modelp = CONFIG::factory<M_Tadah_Base,DM_Function_Base&,Config&>(
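CONFIG::factory above instantiates the function and model types named in the potential file. Purely as an illustration of the general shape of such a string-keyed factory (this is not Tadah's implementation, and ModelA is a hypothetical type):

#include <functional>
#include <map>
#include <memory>
#include <stdexcept>
#include <string>

struct Base { virtual ~Base() = default; };
struct ModelA : Base {};   // hypothetical concrete type

std::unique_ptr<Base> factory(const std::string &name) {
  // Map a config string such as "ModelA" to a constructor for that type.
  static const std::map<std::string, std::function<std::unique_ptr<Base>()>>
      registry = {{"ModelA", [] { return std::make_unique<ModelA>(); }}};
  auto it = registry.find(name);
  if (it == registry.end())
    throw std::runtime_error("Unknown type: " + name);
  return it->second();
}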
@@ -328,25 +320,25 @@ void TadahCLI::subcommand_hpo(
#endif
if (rank==0)
if (is_verbose()) std::cout << "HPO mode" << std::endl;
if (is_verbose()) std::cout << "Optimising HPs..." << std::endl;
CLI::Timer timer_tot {"HPO", CLI::Timer::Big};
if(hpo->count("--verbose"))
set_verbose();
Config config(config_file);
config.remove("CHECKPRESS");
config.add("CHECKPRESS", "true");
config.check_for_training();
if(hpo->count("--verbose"))
set_verbose();
Config config(config_file);
config.remove("CHECKPRESS");
config.add("CHECKPRESS", "true");
config.check_for_training();
if(hpo->count("--Force")) {
config.remove("FORCE");
config.add("FORCE", "true");
}
if(hpo->count("--Stress")) {
config.remove("STRESS");
config.add("STRESS", "true");
}
if(hpo->count("--Force")) {
config.remove("FORCE");
config.add("FORCE", "true");
}
if(hpo->count("--Stress")) {
config.remove("STRESS");
config.add("STRESS", "true");
}
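The remove-then-add pairs above are a force-this-key idiom: drop whatever the config file set, then pin the value requested on the command line. A small helper expressing the same thing; Config::remove and Config::add match the calls visible in this diff, while the helper itself is hypothetical:

static void force_key(Config &config, const std::string &key,
                      const std::string &value) {
  config.remove(key);       // discard any value read from the config file
  config.add(key, value);   // pin the CLI-requested value
}
// usage: if (hpo->count("--Force")) force_key(config, "FORCE", "true");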
#ifdef TADAH_BUILD_MPI
if (rank==0) {
@@ -364,48 +356,46 @@ void TadahCLI::subcommand_hpo(
// TODO broadcast config to all workers instead of dumping it to disk
Config config("config.temp");
- MPI_Trainer tr(config);
- tr.init(rank, ncpu);
- MPI_Trainer_WORKER WORKER(tr, rank, ncpu);
+ MPI_Trainer_WORKER WORKER(config, rank, ncpu);
WORKER.init();
while (true) {
// ask for more work...
- MPI_Send (&tr.rows_available, 1, MPI_INT, 0, MPI_Trainer::WORK_TAG, MPI_COMM_WORLD);
+ MPI_Send (&WORKER.rows_available, 1, MPI_INT, 0, MPI_Trainer::WORK_TAG, MPI_COMM_WORLD);
// request from root or from other workers
- tr.probe();
+ WORKER.probe();
// release a worker
- if (tr.tag == MPI_Trainer::RELEASE_TAG) {
+ if (WORKER.tag == MPI_Trainer::RELEASE_TAG) {
WORKER.release_tag();
break;
}
- else if (tr.tag == MPI_Trainer::WAIT_TAG) {
+ else if (WORKER.tag == MPI_Trainer::WAIT_TAG) {
// do nothing here; ask for more work in the next cycle
WORKER.wait_tag();
}
- else if (tr.tag == MPI_Trainer::DATA_TAG) {
+ else if (WORKER.tag == MPI_Trainer::DATA_TAG) {
WORKER.data_tag();
}
- else if (tr.tag == MPI_Trainer::WORK_TAG) {
+ else if (WORKER.tag == MPI_Trainer::WORK_TAG) {
WORKER.work_tag();
}
}
- tr.solve();
+ WORKER.solve();
}
}
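On the TODO above (broadcast the config instead of writing config.temp to disk): one possible shape is to serialize the Config to text on rank 0 and MPI_Bcast it, length first. Everything below except MPI_Bcast itself is an assumption about how Config could be serialized:

#include <mpi.h>
#include <algorithm>
#include <string>
#include <vector>

std::string bcast_text(const std::string &text_on_root, int rank) {
  // Broadcast the length first so non-root ranks can size their buffers.
  int len = (rank == 0) ? static_cast<int>(text_on_root.size()) : 0;
  MPI_Bcast(&len, 1, MPI_INT, 0, MPI_COMM_WORLD);
  std::vector<char> buf(len);
  if (rank == 0) std::copy(text_on_root.begin(), text_on_root.end(), buf.begin());
  MPI_Bcast(buf.data(), len, MPI_CHAR, 0, MPI_COMM_WORLD);
  return std::string(buf.begin(), buf.end());
}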
#else
hpo_run(config, target_file);
#endif
if (rank==0)
if (is_verbose()) std::cout << timer_tot.to_string() << std::endl;
#else
std::cout << "-----------------------------------------------" << std::endl;
std::cout << "This subcommand is not supported by this build." << std::endl;
std::cout << "Tadah! Must by compiled with HPO support." << std::endl;
std::cout << "See documentation for details." << std::endl;
std::cout << "-----------------------------------------------" << std::endl;
std::cout << "-----------------------------------------------" << std::endl;
std::cout << "This subcommand is not supported by this build." << std::endl;
std::cout << "Tadah! Must by compiled with HPO support." << std::endl;
std::cout << "See documentation for details." << std::endl;
std::cout << "-----------------------------------------------" << std::endl;
#endif
}