diff --git a/trainer.h b/trainer.h index 20821372fece0fe6aa760ff19033911647e4211d..c8e2d06c086c36476cd6505af1066d2f1b595d3b 100644 --- a/trainer.h +++ b/trainer.h @@ -77,6 +77,7 @@ class MPI_Trainer: public Trainer { int phi_cols1; int phi_rows2; int phi_cols2; + size_t phi_row = 0; // next row to be filled in the local phi array MPI_Trainer(Config &c): Trainer(c) {} @@ -162,7 +163,6 @@ class MPI_Trainer: public Trainer { MPI_Type_commit(&rowvecs); // COUNTERS - size_t phi_row = 0; // next row to be filled in the local phi array rows_available=phi_rows1; // number of available rows in the local phi array // once we know the size of local phi, we can allocate memory to it @@ -344,11 +344,11 @@ class MPI_Trainer_HOST { int rows_accepted = tr.rows_available < rows_needed ? tr.rows_available : rows_needed; MPI_Send (&tr.b_rank, 1, MPI_INT, worker, tag, MPI_COMM_WORLD); MPI_Send (&rows_accepted, 1, MPI_INT, worker, tag, MPI_COMM_WORLD); - MPI_Recv (&tr.dm.Phi.data()[phi_row], rows_accepted, rowvecs, worker, tag, MPI_COMM_WORLD, &status); - MPI_Recv (&tr.dm.T.data()[phi_row], rows_accepted, MPI_DOUBLE, worker, tag, MPI_COMM_WORLD, &status); - MPI_Recv (&tr.dm.Tlabels.data()[phi_row], rows_accepted, MPI_DOUBLE, worker, tag, MPI_COMM_WORLD, &status); + MPI_Recv (&tr.dm.Phi.data()[tr.phi_row], rows_accepted, rowvecs, worker, tag, MPI_COMM_WORLD, &status); + MPI_Recv (&tr.dm.T.data()[tr.phi_row], rows_accepted, MPI_DOUBLE, worker, tag, MPI_COMM_WORLD, &status); + MPI_Recv (&tr.dm.Tlabels.data()[tr.phi_row], rows_accepted, MPI_DOUBLE, worker, tag, MPI_COMM_WORLD, &status); tr.rows_available -= rows_accepted; - phi_row += rows_accepted; + tr.phi_row += rows_accepted; if (tr.rows_available<0 ) { throw std::runtime_error(" HOST1: The number of rows in the local array is smaller than requested.");} } else { @@ -396,11 +396,11 @@ class MPI_Trainer_HOST { int rows_accepted = tr.rows_available < rows_needed ? tr.rows_available : rows_needed; MPI_Send (&tr.b_rank, 1, MPI_INT, worker, tag, MPI_COMM_WORLD); MPI_Send (&rows_accepted, 1, MPI_INT, worker, tag, MPI_COMM_WORLD); - MPI_Recv (&tr.dm.Phi.data()[phi_row], rows_accepted, rowvecs, worker, tag, MPI_COMM_WORLD, &status); - MPI_Recv (&tr.dm.T.data()[phi_row], rows_accepted, MPI_DOUBLE, worker, tag, MPI_COMM_WORLD, &status); - MPI_Recv (&tr.dm.Tlabels.data()[phi_row], rows_accepted, MPI_DOUBLE, worker, tag, MPI_COMM_WORLD, &status); + MPI_Recv (&tr.dm.Phi.data()[tr.phi_row], rows_accepted, rowvecs, worker, tag, MPI_COMM_WORLD, &status); + MPI_Recv (&tr.dm.T.data()[tr.phi_row], rows_accepted, MPI_DOUBLE, worker, tag, MPI_COMM_WORLD, &status); + MPI_Recv (&tr.dm.Tlabels.data()[tr.phi_row], rows_accepted, MPI_DOUBLE, worker, tag, MPI_COMM_WORLD, &status); tr.rows_available -= rows_accepted; - phi_row += rows_accepted; + tr.phi_row += rows_accepted; if (tr.rows_available<0 ) { throw std::runtime_error(" HOST2: The number of rows in the local array is smaller than requested.");} } else {