From f8ef1f4d4169a10fa88dc9c3b6d546843f8eb031 Mon Sep 17 00:00:00 2001
From: Marcin Kirsz <mkirsz@ed.ac.uk>
Date: Fri, 18 Oct 2024 12:41:34 +0100
Subject: [PATCH] Rename MPI trainer classes and add TrainerHost::run()

---
 trainer.h | 41 +++++++++++++++++++++++++++++++++++++----
 1 file changed, 37 insertions(+), 4 deletions(-)

diff --git a/trainer.h b/trainer.h
index a783760..3f72c8c 100644
--- a/trainer.h
+++ b/trainer.h
@@ -295,12 +295,12 @@ class MPI_Trainer: public Trainer {
     }
 };
 
-class MPI_Trainer_HOST: public MPI_Trainer {
+class TrainerHost: public MPI_Trainer {
   private:
     std::vector<std::tuple<std::string,int,int>> wpckgs;
 
   public:
-    MPI_Trainer_HOST(Config &c, int &rank, int &ncpu):
+    TrainerHost(Config &c, int &rank, int &ncpu):
       MPI_Trainer(c, rank, ncpu)
   {}
 
@@ -421,12 +421,45 @@ class MPI_Trainer_HOST: public MPI_Trainer {
       MPI_Send (&ready, 1, MPI_INT, worker, 
           MPI_Trainer::CONFIG_TAG, MPI_COMM_WORLD);
     }
+    void run() { // Host driver: serve worker requests, then release the workers
+      prep_wpckgs(); // NOTE(review): presumably fills wpckgs; defined outside this hunk
+
+      // Phase 1: answer worker requests for as long as has_packages() is true
+      while (true) {
+        // Stop dispatching as soon as has_packages() reports nothing left
+        if (!has_packages()) break;
+
+        // NOTE(review): probe() presumably blocks and sets `tag`/`worker` - confirm
+        probe();
+
+        // Dispatch on the tag of the probed request; unknown tags are fatal here
+        if (tag == MPI_Trainer::WORK_TAG) {
+          work_tag();
+        } else if (tag == MPI_Trainer::DATA_TAG) {
+          data_tag();
+        } else {
+          throw std::runtime_error("HOST: Unexpected request from "
+              + std::to_string(worker));
+        }
+      }

+      // Phase 2: keep serving DATA_TAG; any other tag goes to release_tag()
+      int count = 1; // NOTE(review): starts at 1, presumably counting the host rank
+      while (true) {
+        probe();
+        if (tag == MPI_Trainer::DATA_TAG) {
+          data_tag();
+        } else {
+          release_tag(count); // NOTE(review): presumably increments count - confirm
+        }
+        if (count == ncpu) break; // Exit when all workers are released
+      }
+    }
 };
-class Trainer_WORKER: public MPI_Trainer {
+class TrainerWorker: public MPI_Trainer {
 
   public:
-    Trainer_WORKER(Config &c, int &rank, int &ncpu):
+    TrainerWorker(Config &c, int &rank, int &ncpu):
       MPI_Trainer(c, rank, ncpu)
   {}
 
-- 
GitLab