From dd2dceeb978e79416972f710d252a03c7ce2c98f Mon Sep 17 00:00:00 2001
From: Marcin Kirsz <mkirsz@ed.ac.uk>
Date: Thu, 3 Oct 2024 14:26:03 +0100
Subject: [PATCH] Update

---
 trainer.h | 18 +++++++++---------
 1 file changed, 9 insertions(+), 9 deletions(-)

diff --git a/trainer.h b/trainer.h
index 2082137..c8e2d06 100644
--- a/trainer.h
+++ b/trainer.h
@@ -77,6 +77,7 @@ class MPI_Trainer: public Trainer {
     int phi_cols1;
     int phi_rows2;
     int phi_cols2;
+    size_t phi_row = 0; // next row to be filled in the local phi array
     MPI_Trainer(Config &c):
       Trainer(c)
   {}
@@ -162,7 +163,6 @@ class MPI_Trainer: public Trainer {
       MPI_Type_commit(&rowvecs);
 
       // COUNTERS
-      size_t phi_row = 0; // next row to be filled in the local phi array
       rows_available=phi_rows1;  // number of available rows in the local phi array
 
       // once we know the size of local phi, we can allocate memory to it
@@ -344,11 +344,11 @@ class MPI_Trainer_HOST {
           int rows_accepted = tr.rows_available < rows_needed ? tr.rows_available : rows_needed;
           MPI_Send (&tr.b_rank, 1, MPI_INT, worker, tag, MPI_COMM_WORLD);
           MPI_Send (&rows_accepted, 1, MPI_INT, worker, tag, MPI_COMM_WORLD);
-          MPI_Recv (&tr.dm.Phi.data()[phi_row], rows_accepted, rowvecs, worker, tag, MPI_COMM_WORLD, &status);
-          MPI_Recv (&tr.dm.T.data()[phi_row], rows_accepted, MPI_DOUBLE, worker, tag, MPI_COMM_WORLD, &status);
-          MPI_Recv (&tr.dm.Tlabels.data()[phi_row], rows_accepted, MPI_DOUBLE, worker, tag, MPI_COMM_WORLD, &status);
+          MPI_Recv (&tr.dm.Phi.data()[tr.phi_row], rows_accepted, rowvecs, worker, tag, MPI_COMM_WORLD, &status);
+          MPI_Recv (&tr.dm.T.data()[tr.phi_row], rows_accepted, MPI_DOUBLE, worker, tag, MPI_COMM_WORLD, &status);
+          MPI_Recv (&tr.dm.Tlabels.data()[tr.phi_row], rows_accepted, MPI_DOUBLE, worker, tag, MPI_COMM_WORLD, &status);
           tr.rows_available -= rows_accepted;
-          phi_row += rows_accepted;
+          tr.phi_row += rows_accepted;
           if (tr.rows_available<0 ) { throw std::runtime_error(" HOST1: The number of rows in the local array is smaller than requested.");}
         }
         else {
@@ -396,11 +396,11 @@ class MPI_Trainer_HOST {
           int rows_accepted = tr.rows_available < rows_needed ? tr.rows_available : rows_needed;
           MPI_Send (&tr.b_rank, 1, MPI_INT, worker, tag, MPI_COMM_WORLD);
           MPI_Send (&rows_accepted, 1, MPI_INT, worker, tag, MPI_COMM_WORLD);
-          MPI_Recv (&tr.dm.Phi.data()[phi_row], rows_accepted, rowvecs, worker, tag, MPI_COMM_WORLD, &status);
-          MPI_Recv (&tr.dm.T.data()[phi_row], rows_accepted, MPI_DOUBLE, worker, tag, MPI_COMM_WORLD, &status);
-          MPI_Recv (&tr.dm.Tlabels.data()[phi_row], rows_accepted, MPI_DOUBLE, worker, tag, MPI_COMM_WORLD, &status);
+          MPI_Recv (&tr.dm.Phi.data()[tr.phi_row], rows_accepted, rowvecs, worker, tag, MPI_COMM_WORLD, &status);
+          MPI_Recv (&tr.dm.T.data()[tr.phi_row], rows_accepted, MPI_DOUBLE, worker, tag, MPI_COMM_WORLD, &status);
+          MPI_Recv (&tr.dm.Tlabels.data()[tr.phi_row], rows_accepted, MPI_DOUBLE, worker, tag, MPI_COMM_WORLD, &status);
           tr.rows_available -= rows_accepted;
-          phi_row += rows_accepted;
+          tr.phi_row += rows_accepted;
           if (tr.rows_available<0 ) { throw std::runtime_error(" HOST2: The number of rows in the local array is smaller than requested.");}
         }
         else {
-- 
GitLab