Skip to content
Snippets Groups Projects
Commit dd2dceeb authored by Marcin Kirsz's avatar Marcin Kirsz
Browse files

Update

parent 0be4c726
No related branches found
No related tags found
No related merge requests found
Pipeline #43069 failed
...@@ -77,6 +77,7 @@ class MPI_Trainer: public Trainer { ...@@ -77,6 +77,7 @@ class MPI_Trainer: public Trainer {
int phi_cols1; int phi_cols1;
int phi_rows2; int phi_rows2;
int phi_cols2; int phi_cols2;
size_t phi_row = 0; // next row to be filled in the local phi array
MPI_Trainer(Config &c): MPI_Trainer(Config &c):
Trainer(c) Trainer(c)
{} {}
...@@ -162,7 +163,6 @@ class MPI_Trainer: public Trainer { ...@@ -162,7 +163,6 @@ class MPI_Trainer: public Trainer {
MPI_Type_commit(&rowvecs); MPI_Type_commit(&rowvecs);
// COUNTERS // COUNTERS
size_t phi_row = 0; // next row to be filled in the local phi array
rows_available=phi_rows1; // number of available rows in the local phi array rows_available=phi_rows1; // number of available rows in the local phi array
// once we know the size of local phi, we can allocate memory to it // once we know the size of local phi, we can allocate memory to it
...@@ -344,11 +344,11 @@ class MPI_Trainer_HOST { ...@@ -344,11 +344,11 @@ class MPI_Trainer_HOST {
int rows_accepted = tr.rows_available < rows_needed ? tr.rows_available : rows_needed; int rows_accepted = tr.rows_available < rows_needed ? tr.rows_available : rows_needed;
MPI_Send (&tr.b_rank, 1, MPI_INT, worker, tag, MPI_COMM_WORLD); MPI_Send (&tr.b_rank, 1, MPI_INT, worker, tag, MPI_COMM_WORLD);
MPI_Send (&rows_accepted, 1, MPI_INT, worker, tag, MPI_COMM_WORLD); MPI_Send (&rows_accepted, 1, MPI_INT, worker, tag, MPI_COMM_WORLD);
MPI_Recv (&tr.dm.Phi.data()[phi_row], rows_accepted, rowvecs, worker, tag, MPI_COMM_WORLD, &status); MPI_Recv (&tr.dm.Phi.data()[tr.phi_row], rows_accepted, rowvecs, worker, tag, MPI_COMM_WORLD, &status);
MPI_Recv (&tr.dm.T.data()[phi_row], rows_accepted, MPI_DOUBLE, worker, tag, MPI_COMM_WORLD, &status); MPI_Recv (&tr.dm.T.data()[tr.phi_row], rows_accepted, MPI_DOUBLE, worker, tag, MPI_COMM_WORLD, &status);
MPI_Recv (&tr.dm.Tlabels.data()[phi_row], rows_accepted, MPI_DOUBLE, worker, tag, MPI_COMM_WORLD, &status); MPI_Recv (&tr.dm.Tlabels.data()[tr.phi_row], rows_accepted, MPI_DOUBLE, worker, tag, MPI_COMM_WORLD, &status);
tr.rows_available -= rows_accepted; tr.rows_available -= rows_accepted;
phi_row += rows_accepted; tr.phi_row += rows_accepted;
if (tr.rows_available<0 ) { throw std::runtime_error(" HOST1: The number of rows in the local array is smaller than requested.");} if (tr.rows_available<0 ) { throw std::runtime_error(" HOST1: The number of rows in the local array is smaller than requested.");}
} }
else { else {
...@@ -396,11 +396,11 @@ class MPI_Trainer_HOST { ...@@ -396,11 +396,11 @@ class MPI_Trainer_HOST {
int rows_accepted = tr.rows_available < rows_needed ? tr.rows_available : rows_needed; int rows_accepted = tr.rows_available < rows_needed ? tr.rows_available : rows_needed;
MPI_Send (&tr.b_rank, 1, MPI_INT, worker, tag, MPI_COMM_WORLD); MPI_Send (&tr.b_rank, 1, MPI_INT, worker, tag, MPI_COMM_WORLD);
MPI_Send (&rows_accepted, 1, MPI_INT, worker, tag, MPI_COMM_WORLD); MPI_Send (&rows_accepted, 1, MPI_INT, worker, tag, MPI_COMM_WORLD);
MPI_Recv (&tr.dm.Phi.data()[phi_row], rows_accepted, rowvecs, worker, tag, MPI_COMM_WORLD, &status); MPI_Recv (&tr.dm.Phi.data()[tr.phi_row], rows_accepted, rowvecs, worker, tag, MPI_COMM_WORLD, &status);
MPI_Recv (&tr.dm.T.data()[phi_row], rows_accepted, MPI_DOUBLE, worker, tag, MPI_COMM_WORLD, &status); MPI_Recv (&tr.dm.T.data()[tr.phi_row], rows_accepted, MPI_DOUBLE, worker, tag, MPI_COMM_WORLD, &status);
MPI_Recv (&tr.dm.Tlabels.data()[phi_row], rows_accepted, MPI_DOUBLE, worker, tag, MPI_COMM_WORLD, &status); MPI_Recv (&tr.dm.Tlabels.data()[tr.phi_row], rows_accepted, MPI_DOUBLE, worker, tag, MPI_COMM_WORLD, &status);
tr.rows_available -= rows_accepted; tr.rows_available -= rows_accepted;
phi_row += rows_accepted; tr.phi_row += rows_accepted;
if (tr.rows_available<0 ) { throw std::runtime_error(" HOST2: The number of rows in the local array is smaller than requested.");} if (tr.rows_available<0 ) { throw std::runtime_error(" HOST2: The number of rows in the local array is smaller than requested.");}
} }
else { else {
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment