diff --git a/bin/tadah_cli.cpp b/bin/tadah_cli.cpp index ce76643fd8db0deb37b459305fb51284e4667d83..8a42413c8974ffcd27388bb74bac25ec0aed01e4 100644 --- a/bin/tadah_cli.cpp +++ b/bin/tadah_cli.cpp @@ -314,6 +314,22 @@ void TadahCLI::subcommand_hpo( int ncpu=1; MPI_Comm_size(MPI_COMM_WORLD, &ncpu); MPI_Comm_rank(MPI_COMM_WORLD, &rank); + + // Create a communicator for LAMMPS + MPI_Group world_group; + MPI_Comm_group(MPI_COMM_WORLD, &world_group); + + std::vector<int> ranks; + ranks.push_back(0); // for now we use root to run LAMMPS + // long term we want to distribute + // work among available workers + + MPI_Group lmp_group; + MPI_Group_incl(world_group, ranks.size(), ranks.data(), &lmp_group); + + MPI_Comm lmp_comm; + MPI_Comm_create(MPI_COMM_WORLD, lmp_group, &lmp_comm); + if (rank==0) { if (ncpu<2) { std::cout << "Minimum number of cpus for an mpi version is 2" << std::endl; @@ -321,7 +337,11 @@ void TadahCLI::subcommand_hpo( } if (is_verbose()) std::cout << "Optimising HPs..." << std::endl; CLI::Timer timer_tot {"HPO", CLI::Timer::Big}; - hpo_run(config, target_file, validation_file); + hpo_run(config, target_file, validation_file, lmp_comm); + + // shut all workers + int cont=0; + MPI_Bcast(&cont, 1, MPI_INT, 0, MPI_COMM_WORLD); if (is_verbose()) std::cout << timer_tot.to_string() << std::endl; } else { // WORKER @@ -364,10 +384,15 @@ void TadahCLI::subcommand_hpo( WORKER.solve(); } } + // Free the communicator and groups when done + //MPI_Comm_free(&lmp_comm); + MPI_Group_free(&world_group); + MPI_Group_free(&lmp_group); #else CLI::Timer timer_tot {"HPO", CLI::Timer::Big}; if (is_verbose()) std::cout << "Optimising HPs..." << std::endl; - hpo_run(config, target_file, validation_file); + int dummy_comm=0; + hpo_run(config, target_file, validation_file, dummy_comm); if (is_verbose()) std::cout << timer_tot.to_string() << std::endl; #endif