From ba9d3077bf38ce7d3ae863eded3c4c173224a155 Mon Sep 17 00:00:00 2001
From: Marcin Kirsz <mkirsz@ed.ac.uk>
Date: Mon, 7 Oct 2024 13:07:23 +0100
Subject: [PATCH] moved tags to MPI trainer

---
 trainer.h | 37 ++++++++++++++++++++++---------------
 1 file changed, 22 insertions(+), 15 deletions(-)

diff --git a/trainer.h b/trainer.h
index 1b26782..81a9be1 100644
--- a/trainer.h
+++ b/trainer.h
@@ -54,7 +54,9 @@ class Trainer {
 };
 
 #ifdef TADAH_BUILD_MPI
+
 #include <mpi.h>
+
 extern "C" void blacs_get_(int*, int*, int*);
 extern "C" void blacs_pinfo_(int*, int*);
 extern "C" void blacs_gridinit_(int*, char*, int*, int*);
@@ -73,8 +75,13 @@ extern "C" void pdgemv_(char* transa, int* m, int* n, double* alpha, double* a,
     int* ia, int* ja, int* desc_a, double* x, int* ix, int* jx, int* desc_x,
     int* incx, double* beta, double* y, int* iy, int* jy, int* desc_y, int* incy);	
 
+
 class MPI_Trainer: public Trainer {
   public:
+const static int WAIT_TAG = 3;
+const static int RELEASE_TAG = 2;
+const static int DATA_TAG = 1;
+const static int WORK_TAG = 0;
     MPI_Status status;
     int worker;
     int tag;
@@ -377,7 +384,7 @@ class MPI_Trainer_HOST {
         int w_rows_available;
         while (true) {
           MPI_Recv (&w_rows_available, 1, MPI_INT, MPI_ANY_SOURCE, 
-              TadahCLI::WORK_TAG, MPI_COMM_WORLD, &status2);
+              WORK_TAG, MPI_COMM_WORLD, &status2);
           worker2 = status2.MPI_SOURCE;
           if (tr.worker==worker2) {
             throw std::runtime_error("worker and worker2 are the same.");
@@ -386,25 +393,25 @@ class MPI_Trainer_HOST {
         }
         int rows_accepted = w_rows_available < rows_needed ? 
           w_rows_available : rows_needed;
-        MPI_Send (&worker2, 1, MPI_INT, tr.worker, TadahCLI::DATA_TAG, 
+        MPI_Send (&worker2, 1, MPI_INT, tr.worker, DATA_TAG, 
             MPI_COMM_WORLD);
         MPI_Send (&rows_accepted, 1, MPI_INT, tr.worker, 
-            TadahCLI::DATA_TAG, MPI_COMM_WORLD);
+            DATA_TAG, MPI_COMM_WORLD);
       }
     }
     void release_tag(int &count) {
       int rows_available;
       MPI_Recv (&rows_available, 1, MPI_INT, tr.worker, 
-          TadahCLI::WORK_TAG, MPI_COMM_WORLD, &tr.status);
+          MPI_Trainer::WORK_TAG, MPI_COMM_WORLD, &tr.status); // tag is a static member of MPI_Trainer, so it must be qualified here
       // there is no more work so release a worker if full
       if (rows_available==0) {
         MPI_Send (0, 0, MPI_INT, tr.worker, 
-            TadahCLI::RELEASE_TAG, MPI_COMM_WORLD);
+            MPI_Trainer::RELEASE_TAG, MPI_COMM_WORLD); // empty message: only the tag carries information
         count++;
       }
       else {
         MPI_Send (0, 0, MPI_INT, tr.worker, 
-            TadahCLI::WAIT_TAG, MPI_COMM_WORLD);
+            MPI_Trainer::WAIT_TAG, MPI_COMM_WORLD); // worker reported rows_available != 0, so it is not released
       }
     }
 
@@ -462,11 +469,11 @@ class MPI_Trainer_WORKER {
       MPI_Get_count(&tr.status, MPI_CHAR, &fn_length);
 
       char *fn  = (char *) malloc(fn_length+1);
-      MPI_Recv (fn, fn_length, MPI_CHAR, 0, TadahCLI::WORK_TAG,
+      MPI_Recv (fn, fn_length, MPI_CHAR, 0, WORK_TAG,
           MPI_COMM_WORLD, &tr.status);
-      MPI_Recv (&first, 1, MPI_INT, 0, TadahCLI::WORK_TAG, 
+      MPI_Recv (&first, 1, MPI_INT, 0, WORK_TAG, 
           MPI_COMM_WORLD, &tr.status);
-      MPI_Recv (&nstruc, 1, MPI_INT, 0, TadahCLI::WORK_TAG, 
+      MPI_Recv (&nstruc, 1, MPI_INT, 0, WORK_TAG, 
           MPI_COMM_WORLD, &tr.status);
 
       // do work
@@ -515,14 +522,14 @@ class MPI_Trainer_WORKER {
         // send remaining data to available processes
         while (rows_needed > 0) {
           // request host 
-          MPI_Send (&rows_needed, 1, MPI_INT, 0, TadahCLI::DATA_TAG, MPI_COMM_WORLD);
+          MPI_Send (&rows_needed, 1, MPI_INT, 0, DATA_TAG, MPI_COMM_WORLD);
           int rows_accepted; // number of accepted rows
           int dest; // receiving process
                     // host returns which dest can accept and how much
-          MPI_Recv (&dest, 1, MPI_INT, 0, TadahCLI::DATA_TAG, 
+          MPI_Recv (&dest, 1, MPI_INT, 0, DATA_TAG, 
               MPI_COMM_WORLD, &tr.status);
 
-          MPI_Recv (&rows_accepted, 1, MPI_INT, 0, TadahCLI::DATA_TAG, 
+          MPI_Recv (&rows_accepted, 1, MPI_INT, 0, DATA_TAG, 
               MPI_COMM_WORLD, &tr.status);
           // we send data to the host or a willing worker
           int start=temp_dm.Phi.rows()-rows_needed;
@@ -538,11 +545,11 @@ class MPI_Trainer_WORKER {
 
           // ready to send
           MPI_Send (&temp_dm.Phi.data()[start], rows_accepted, 
-              trowvecs, dest, TadahCLI::DATA_TAG, MPI_COMM_WORLD);
+              trowvecs, dest, DATA_TAG, MPI_COMM_WORLD);
           MPI_Send (&temp_dm.T.data()[start], rows_accepted, 
-              MPI_DOUBLE, dest, TadahCLI::DATA_TAG, MPI_COMM_WORLD);
+              MPI_DOUBLE, dest, DATA_TAG, MPI_COMM_WORLD);
           MPI_Send (&temp_dm.Tlabels.data()[start], rows_accepted, 
-              MPI_DOUBLE, dest, TadahCLI::DATA_TAG, MPI_COMM_WORLD);
+              MPI_DOUBLE, dest, DATA_TAG, MPI_COMM_WORLD);
           rows_needed -= rows_accepted;
           MPI_Type_free(&trowvec);
           MPI_Type_free(&trowvecs);
-- 
GitLab