From 132fb500a0f7c99b8a2301b575c3f6416faf488e Mon Sep 17 00:00:00 2001
From: Marcin Kirsz <mkirsz@ed.ac.uk>
Date: Tue, 22 Oct 2024 11:02:50 +0100
Subject: [PATCH] Refactor MPI trainer into TrainerHost/TrainerWorker; update HPO worker and example targets

---
 bin/tadah_cli.cpp         | 145 ++++++++++----------------------------
 examples/ex_0/targets_new |  34 ++++-----
 2 files changed, 54 insertions(+), 125 deletions(-)

diff --git a/bin/tadah_cli.cpp b/bin/tadah_cli.cpp
index 8a42413..f88d560 100644
--- a/bin/tadah_cli.cpp
+++ b/bin/tadah_cli.cpp
@@ -15,6 +15,7 @@
 #include "../MLIP/trainer.h"
 #ifdef TADAH_ENABLE_HPO
 #include "../HPO/hpo.h"
+#include "../HPO/hpo_worker.h"
 #endif
 
 #include <iostream>
@@ -68,43 +69,13 @@ void TadahCLI::subcommand_train() {
     CLI::Timer timer_tot{"Training", CLI::Timer::Big};
 
     // Initialize host MPI trainer
-    MPI_Trainer_HOST HOST(config, rank, ncpu);
-    HOST.prep_wpckgs();
-
-    // Process requests from workers
-    while (true) {
-      // Exit loop if there are no more packages
-      if (!HOST.has_packages()) break;
-
-      // Probe for incoming requests from any worker
-      HOST.probe();
-
-      // Handle requests based on their tags
-      if (HOST.tag == MPI_Trainer::WORK_TAG) {
-        HOST.work_tag();
-      } else if (HOST.tag == MPI_Trainer::DATA_TAG) {
-        HOST.data_tag();
-      } else {
-        throw std::runtime_error("HOST: Unexpected request from "
-            + std::to_string(HOST.worker));
-      }
-    }
-
-    // Collect remaining data and release all workers
-    int count = 1;
-    while (true) {
-      HOST.probe();
-      if (HOST.tag == MPI_Trainer::DATA_TAG) {
-        HOST.data_tag();
-      } else {
-        HOST.release_tag(count);
-      }
-      if (count == ncpu) break; // Exit when all workers are released
-    }
+    TrainerHost host(config, rank, ncpu);
+    host.run();
 
     // Perform the final computation and save parameters
-    HOST.solve();
-    Config param_file = HOST.model->get_param_file();
+    host.solve();
+
+    Config param_file = host.model->get_param_file();
     param_file.check_pot_file();
     std::ofstream outfile("pot.tadah");
     if (outfile.is_open()) {
@@ -113,29 +84,12 @@ void TadahCLI::subcommand_train() {
       std::cerr << "Error: Unable to open file pot.tadah" << std::endl;
     }
     if (is_verbose()) std::cout << timer_tot.to_string() << std::endl;
-  } else { // Worker code
-           // Initialize worker MPI trainer
-    MPI_Trainer_WORKER WORKER(config, rank, ncpu);
-    while (true) {
-      // Send a request for more work to the host
-      MPI_Send(&WORKER.rows_available, 1, MPI_INT, 0, MPI_Trainer::WORK_TAG, MPI_COMM_WORLD);
-
-      // Probe and respond to host or worker requests
-      WORKER.probe();
-
-      // Handle requests based on their tags
-      if (WORKER.tag == MPI_Trainer::RELEASE_TAG) {
-        WORKER.release_tag();
-        break;
-      } else if (WORKER.tag == MPI_Trainer::WAIT_TAG) {
-        WORKER.wait_tag();
-      } else if (WORKER.tag == MPI_Trainer::DATA_TAG) {
-        WORKER.data_tag();
-      } else if (WORKER.tag == MPI_Trainer::WORK_TAG) {
-        WORKER.work_tag();
-      }
-    }
-    WORKER.solve(); // Perform worker computations
+  }
+  else { // Worker code
+         // Initialize worker MPI trainer
+    TrainerWorker worker(config, rank, ncpu);
+    worker.run();
+    worker.solve(); // Perform worker computations
   }
 #else // Non-MPI version
       // Start training with a timer in the non-MPI mode
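
Note: the tag-dispatch loops deleted above are presumably what the new
TrainerHost::run() and TrainerWorker::run() encapsulate. A minimal sketch of
that assumption, reconstructed from the deleted code (the probe(), tag, and
*_tag() members are taken from the old loops; the run() bodies themselves
are conjecture):

    // Conjectured body of TrainerHost::run(), assembled from the two host
    // loops removed above: phase 1 hands out work packages, phase 2 drains
    // leftover data and releases every worker.
    void TrainerHost::run() {
      prep_wpckgs();
      while (has_packages()) {
        probe();  // block on the next worker request
        if (tag == MPI_Trainer::WORK_TAG)      work_tag();
        else if (tag == MPI_Trainer::DATA_TAG) data_tag();
        else throw std::runtime_error("HOST: Unexpected request from "
                                      + std::to_string(worker));
      }
      int count = 1;
      while (true) {
        probe();
        if (tag == MPI_Trainer::DATA_TAG) data_tag();
        else                              release_tag(count);
        if (count == ncpu) break;  // exit once all workers are released
      }
    }

    // Conjectured body of TrainerWorker::run(), from the removed worker loop.
    void TrainerWorker::run() {
      while (true) {
        // Ask the host for more work.
        MPI_Send(&rows_available, 1, MPI_INT, 0,
                 MPI_Trainer::WORK_TAG, MPI_COMM_WORLD);
        probe();  // wait for a response from the host (or another worker)
        if (tag == MPI_Trainer::RELEASE_TAG) { release_tag(); break; }
        else if (tag == MPI_Trainer::WAIT_TAG) wait_tag();
        else if (tag == MPI_Trainer::DATA_TAG) data_tag();
        else if (tag == MPI_Trainer::WORK_TAG) work_tag();
      }
    }
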
@@ -244,11 +198,9 @@ void TadahCLI::subcommand_predict() {
     stpred = modelp->predict(pot_config,stdb,dc,predicted_error);
   }
   else {
-
     stpred = modelp->predict(pot_config,stdb,dc);
   }
 
-
   if (is_verbose()) std::cout << "Done!" << std::endl;
 
   if (is_verbose()) std::cout << "Dumping output..." << std::flush;
@@ -289,7 +241,6 @@ void TadahCLI::subcommand_hpo(
     [[maybe_unused]]int argc,
     [[maybe_unused]]char**argv) {
 
-
 #ifdef TADAH_ENABLE_HPO
 
   if(hpo->count("--Verbose"))
@@ -310,34 +261,25 @@ void TadahCLI::subcommand_hpo(
   }
 
 #ifdef TADAH_BUILD_MPI
-  int rank=0;
-  int ncpu=1;
+
+  int rank;
+  int ncpu;
   MPI_Comm_size(MPI_COMM_WORLD, &ncpu);
   MPI_Comm_rank(MPI_COMM_WORLD, &rank);
 
-  // Create a communicator for LAMMPS
+  MPI_Status status;
   MPI_Group world_group;
   MPI_Comm_group(MPI_COMM_WORLD, &world_group);
 
-  std::vector<int> ranks;
-  ranks.push_back(0); // for now we use root to run LAMMPS
-                      // long term we want to distribute
-                      // work among available workers
-
-  MPI_Group lmp_group;
-  MPI_Group_incl(world_group, ranks.size(), ranks.data(), &lmp_group);
-
-  MPI_Comm lmp_comm;
-  MPI_Comm_create(MPI_COMM_WORLD, lmp_group, &lmp_comm);
-
   if (rank==0) {
     if (ncpu<2) {
-      std::cout << "Minimum number of cpus for an mpi version is 2" << std::endl;
+      std::cout << "Minimum number of cpus for an MPI version is 2" << std::endl;
       return;
     }
     if (is_verbose()) std::cout << "Optimising HPs..." << std::endl;
     CLI::Timer timer_tot {"HPO", CLI::Timer::Big};
-    hpo_run(config, target_file, validation_file, lmp_comm);
+    hpo_run(config, target_file, validation_file);
+    if (is_verbose()) std::cout << "Done" << std::endl;
 
     // shut all workers
     int cont=0;
@@ -345,54 +287,38 @@ void TadahCLI::subcommand_hpo(
     if (is_verbose()) std::cout << timer_tot.to_string() << std::endl;
   }
   else { // WORKER
+    HPO_Worker hpo_worker(rank, ncpu, world_group);
+
     while (true) {
       int cs_size;
       MPI_Bcast(&cs_size, 1, MPI_INT, 0, MPI_COMM_WORLD);
       if (!cs_size)
         break;
 
+      // Obtain updated config file from the HOST
       std::vector<char> serialized(cs_size);
       MPI_Bcast(serialized.data(), cs_size, MPI_CHAR, 0, MPI_COMM_WORLD);
       Config config;
       config.deserialize(serialized);
       config.postprocess();
 
-      MPI_Trainer_WORKER WORKER(config, rank, ncpu);
-      while (true) {
-        // ask for more work...
-        MPI_Send (&WORKER.rows_available, 1, MPI_INT, 0, MPI_Trainer::WORK_TAG, MPI_COMM_WORLD);
-
-        // request from root or from other workers
-        WORKER.probe();
-
-        // release a worker
-        if (WORKER.tag == MPI_Trainer::RELEASE_TAG) {
-          WORKER.release_tag();
-          break;
-        }
-        else if (WORKER.tag == MPI_Trainer::WAIT_TAG) {
-          // do nothing here; ask for more work in the next cycle
-          WORKER.wait_tag();
-        }
-        else if (WORKER.tag == MPI_Trainer::DATA_TAG) {
-          WORKER.data_tag();
-        }
-        else if (WORKER.tag == MPI_Trainer::WORK_TAG) {
-          WORKER.work_tag();
-        }
-      }
-      WORKER.solve();
-    }
+      TrainerWorker worker(config, rank, ncpu);
+      worker.run();
+      worker.solve();
+
+      // Run LAMMPS
+      if (hpo_worker.use_lammps)
+        hpo_worker.run();
+    } // end of WORKER external while loop
   }
   // Free the communicator and groups when done
   //MPI_Comm_free(&lmp_comm);
   MPI_Group_free(&world_group);
-  MPI_Group_free(&lmp_group);
 #else
   CLI::Timer timer_tot {"HPO", CLI::Timer::Big};
   if (is_verbose()) std::cout << "Optimising HPs..." << std::endl;
-  int dummy_comm=0;
-  hpo_run(config, target_file, validation_file, dummy_comm);
+  //int dummy_comm=0;
+  hpo_run(config, target_file, validation_file);
   if (is_verbose()) std::cout << timer_tot.to_string() << std::endl;
 #endif
 
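
Note: the worker side of the HPO handshake is visible above; every cycle
starts with a broadcast of the serialized config size, and a zero size ends
the loop. The host-side counterpart is not part of this patch. A minimal
sketch of what it plausibly looks like, assuming Config exposes a
serialize() that mirrors the deserialize() used by the workers (both
function names below are illustrative):

    // Hypothetical host-side counterpart to the worker Bcast loop above.
    void broadcast_config(const Config& config) {
      std::vector<char> buf = config.serialize();          // assumed API
      int cs_size = static_cast<int>(buf.size());
      MPI_Bcast(&cs_size, 1, MPI_INT, 0, MPI_COMM_WORLD);  // announce size
      MPI_Bcast(buf.data(), cs_size, MPI_CHAR, 0, MPI_COMM_WORLD);
    }

    // Shutting the workers down matches the `int cont=0;` in the host
    // branch: broadcasting a zero makes each worker's
    // `if (!cs_size) break;` fire.
    void release_workers() {
      int cont = 0;
      MPI_Bcast(&cont, 1, MPI_INT, 0, MPI_COMM_WORLD);
    }
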
@@ -510,7 +436,9 @@ TadahCLI::TadahCLI():
   /*     Hyperparameter Optimizer                                              */
   /*---------------------------------------------------------------------------*/
   ss.str(std::string());
-#ifdef TADAH_ENABLE_HPO
+#ifndef TADAH_ENABLE_HPO
+  ss << "(UNAVAILABLE)" << std::endl;
+#endif
   ss << "Optimize the model's architecture and determine\n";
   ss << "the best hyperparameters within predefined constraints.\n";
   ss << "This option uses hyperparameter optimization\n";
@@ -519,6 +447,7 @@ TadahCLI::TadahCLI():
 
   hpo = app.add_subcommand("hpo", ss.str());
 
+#ifdef TADAH_ENABLE_HPO
   ss.str(std::string());
   ss << "A config file containing model inital parameters\n";
   ss << "and training dataset(s).\n";
@@ -552,8 +481,6 @@ TadahCLI::TadahCLI():
   //hpo->add_flag("-u,--uncertainty",
   //        "Dump uncertainty on weights."); // TODO check this
 
-#else
-  ss << "(UNAVAILABLE)" << std::endl;
 #endif
 }
 
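Note: the help-text hunks above invert the preprocessor guard: instead of
hiding the description in non-HPO builds, the subcommand now always carries
its full description, prefixed with (UNAVAILABLE) when HPO support is
compiled out. The net shape after the patch, condensed (the description
string is abbreviated):

    std::stringstream ss;
    #ifndef TADAH_ENABLE_HPO
      ss << "(UNAVAILABLE)" << std::endl;  // flag disabled builds up front
    #endif
      ss << "Optimize the model's architecture and determine\n" /* ... */;
      hpo = app.add_subcommand("hpo", ss.str());  // registered in all builds
    #ifdef TADAH_ENABLE_HPO
      // per-option definitions (config file, targets, ...) HPO builds only
    #endif
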
diff --git a/examples/ex_0/targets_new b/examples/ex_0/targets_new
index c4e2ed8..cce6296 100644
--- a/examples/ex_0/targets_new
+++ b/examples/ex_0/targets_new
@@ -1,19 +1,19 @@
 # Global settings for the hyperparameter optimiser
-MAXCALLS  100  # Maximum number of optimisation cycles
-EPS  1e-3      # Accuracy convergence threshold.
-               # Smaller values increase precision but slow global exploration.
+MAXCALLS  10   # Maximum number of optimisation cycles
+EPS  1e-3      # Accuracy convergence threshold (default 1e-2).
+               # Smaller values increase precision but slow down global exploration.
 
 # Model parameters to be optimised.
 # Model is specified in a separate config file.
-OPTIM RCUT2B 5.0 6.0       # Optimise RCUT2B within these bounds
-OPTIM CGRID2B 1.1 6.0      # Optimise CGRID2B within these bounds
-OPTIM SGRID2B 0.1 4.11     # Optimise SGRID2B within these bounds
-OPTIM AGRIDMB 0 2          # Optimise AGRIDMB within these bounds
-OPTIM EWEIGHT 1e-2 1e0     # Optimise EWEIGHT within these bounds
+OPTIM RCUT2B 3.0 6.0       # Optimise RCUT2B within these bounds
+OPTIM CGRID2B 1.0 6.0      # Optimise CGRID2B within these bounds
+OPTIM SGRID2B 0.01 6.0     # Optimise SGRID2B within these bounds
+#OPTIM AGRIDMB 0 2          # Optimise AGRIDMB within these bounds
+#OPTIM EWEIGHT 1e-2 1e0     # Optimise EWEIGHT within these bounds
 
 # Basic optimisation targets with weighted inclusion in global loss
 ERMSE  0 100   # Energy RMSE with weight 100, printed to properties.dat
-SRMSE  0 20    # Stress RMSE with weight 20, printed to properties.dat
+#SRMSE  0 20    # Stress RMSE with weight 20, printed to properties.dat
 
 # Default settings
 LOSSFUNC 2     # Loss function type (default 1): 
@@ -22,8 +22,8 @@ LOSSFUNC 2     # Loss function type (default 1):
 	       # [3] -> log(tloss)
 
 # Default settings for LAMMPS simulations
-FAILSCORE 100  # Score assigned if simulation fails (default 100)
-MAXTIME 2 1    # Maximum of 2 seconds for LAMMPS simulations, checked per step (default 10 1)
+FAILSCORE 100  # Score assigned if simulation fails (default 1000)
+MAXTIME 2      # Maximum of 2 seconds for LAMMPS simulations, checked every step (default 10)
 
 # Advanced optimisation targets used with LAMMPS
 # Default settings may be overridden with specific flags
@@ -37,16 +37,18 @@ MAXTIME 2 1    # Maximum of 2 seconds for LAMMPS simulations, checked per step (
 # 	-o, --outvar    # Output variable produced by the script, printed to properties.dat
 # 	-w, --weight    # Weighting factor for tloss: w*tloss
 # 	-f, --failscore # Custom FAILSCORE for this simulation
+# 	-l, --lossfunc  # Custom LOSSFUNC for this simulation
 # 	-m, --maxtime   # Max time for simulation; returns failscore if exceeded
+# 	-n, --ncpu      # Number of cores to be used by this LAMMPS simulation (MPI only)
 
 # Simulation 1: Optimise lattice constant using FCC lattice.
 # Executes LAMMPS with 'in.lata', fully defining the simulation and 'tloss'.
-LAMMPS --script in.lata --varloss tloss
+#LAMMPS --script in.lata --varloss tloss
 
-# Simulation 2: Similar setup as simulation 1, at varying pressures.
+# Simulation 2: The same setup as simulation 1, but at different pressures.
 # 'tloss' is defined in 'in.mod_lata' as:
 #   tloss = abs(lata_exp - lata) / lata_exp
-LAMMPS --script in.mod_lata --varloss tloss --invar P 1 --invar lata_exp 3.3 --outvar lata
+LAMMPS --script in.mod_lata --varloss tloss --invar P 1 --invar lata_exp 3.3 --outvar lata --ncpu 4
 LAMMPS --script in.mod_lata --varloss tloss --invar P 100 --invar lata_exp 3.2 --outvar lata
-LAMMPS --script in.mod_lata --varloss tloss --invar P 1000 --invar lata_exp 3.1 --outvar lata
-
+LAMMPS --script in.mod_lata --varloss tloss --invar P 1000 --invar lata_exp 3.1 --outvar lata --ncpu 1
+LAMMPS --script in.mod2_lata --varloss tloss --invar P 1000 --invar lata_exp 3.1 --outvar lata --ncpu 4
-- 
GitLab