From dd2dceeb978e79416972f710d252a03c7ce2c98f Mon Sep 17 00:00:00 2001 From: Marcin Kirsz <mkirsz@ed.ac.uk> Date: Thu, 3 Oct 2024 14:26:03 +0100 Subject: [PATCH] Update --- trainer.h | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/trainer.h b/trainer.h index 2082137..c8e2d06 100644 --- a/trainer.h +++ b/trainer.h @@ -77,6 +77,7 @@ class MPI_Trainer: public Trainer { int phi_cols1; int phi_rows2; int phi_cols2; + size_t phi_row = 0; // next row to be filled in the local phi array MPI_Trainer(Config &c): Trainer(c) {} @@ -162,7 +163,6 @@ class MPI_Trainer: public Trainer { MPI_Type_commit(&rowvecs); // COUNTERS - size_t phi_row = 0; // next row to be filled in the local phi array rows_available=phi_rows1; // number of available rows in the local phi array // once we know the size of local phi, we can allocate memory to it @@ -344,11 +344,11 @@ class MPI_Trainer_HOST { int rows_accepted = tr.rows_available < rows_needed ? tr.rows_available : rows_needed; MPI_Send (&tr.b_rank, 1, MPI_INT, worker, tag, MPI_COMM_WORLD); MPI_Send (&rows_accepted, 1, MPI_INT, worker, tag, MPI_COMM_WORLD); - MPI_Recv (&tr.dm.Phi.data()[phi_row], rows_accepted, rowvecs, worker, tag, MPI_COMM_WORLD, &status); - MPI_Recv (&tr.dm.T.data()[phi_row], rows_accepted, MPI_DOUBLE, worker, tag, MPI_COMM_WORLD, &status); - MPI_Recv (&tr.dm.Tlabels.data()[phi_row], rows_accepted, MPI_DOUBLE, worker, tag, MPI_COMM_WORLD, &status); + MPI_Recv (&tr.dm.Phi.data()[tr.phi_row], rows_accepted, rowvecs, worker, tag, MPI_COMM_WORLD, &status); + MPI_Recv (&tr.dm.T.data()[tr.phi_row], rows_accepted, MPI_DOUBLE, worker, tag, MPI_COMM_WORLD, &status); + MPI_Recv (&tr.dm.Tlabels.data()[tr.phi_row], rows_accepted, MPI_DOUBLE, worker, tag, MPI_COMM_WORLD, &status); tr.rows_available -= rows_accepted; - phi_row += rows_accepted; + tr.phi_row += rows_accepted; if (tr.rows_available<0 ) { throw std::runtime_error(" HOST1: The number of rows in the local array is smaller than requested.");} } else { @@ -396,11 +396,11 @@ class MPI_Trainer_HOST { int rows_accepted = tr.rows_available < rows_needed ? tr.rows_available : rows_needed; MPI_Send (&tr.b_rank, 1, MPI_INT, worker, tag, MPI_COMM_WORLD); MPI_Send (&rows_accepted, 1, MPI_INT, worker, tag, MPI_COMM_WORLD); - MPI_Recv (&tr.dm.Phi.data()[phi_row], rows_accepted, rowvecs, worker, tag, MPI_COMM_WORLD, &status); - MPI_Recv (&tr.dm.T.data()[phi_row], rows_accepted, MPI_DOUBLE, worker, tag, MPI_COMM_WORLD, &status); - MPI_Recv (&tr.dm.Tlabels.data()[phi_row], rows_accepted, MPI_DOUBLE, worker, tag, MPI_COMM_WORLD, &status); + MPI_Recv (&tr.dm.Phi.data()[tr.phi_row], rows_accepted, rowvecs, worker, tag, MPI_COMM_WORLD, &status); + MPI_Recv (&tr.dm.T.data()[tr.phi_row], rows_accepted, MPI_DOUBLE, worker, tag, MPI_COMM_WORLD, &status); + MPI_Recv (&tr.dm.Tlabels.data()[tr.phi_row], rows_accepted, MPI_DOUBLE, worker, tag, MPI_COMM_WORLD, &status); tr.rows_available -= rows_accepted; - phi_row += rows_accepted; + tr.phi_row += rows_accepted; if (tr.rows_available<0 ) { throw std::runtime_error(" HOST2: The number of rows in the local array is smaller than requested.");} } else { -- GitLab