diff --git a/trainer.h b/trainer.h index 69b65ba2eb19eb3435d54847f8a90a1449747724..06c43a4cbe3f913e168d9cbdc0f9427292b7861d 100644 --- a/trainer.h +++ b/trainer.h @@ -231,7 +231,7 @@ class MPI_Trainer_HOST { void work_tag() { int rows_available; - MPI_Recv (&rows_available, 1, MPI_INT, tr.worker, tag, MPI_COMM_WORLD, &status); + MPI_Recv (&rows_available, 1, MPI_INT, tr.worker, tr.tag, MPI_COMM_WORLD, &status); std::tuple<std::string,int,int> wpckg = wpckgs.back(); wpckgs.pop_back(); @@ -239,28 +239,28 @@ class MPI_Trainer_HOST { // send dataset filename const char *fn = std::get<0>(wpckg).c_str(); int fn_length = std::get<0>(wpckg).length()+1; // +1 for char - MPI_Send (fn, fn_length, MPI_CHAR, tr.worker, tag, MPI_COMM_WORLD); + MPI_Send (fn, fn_length, MPI_CHAR, tr.worker, tr.tag, MPI_COMM_WORLD); // send index of the first structure to load int first = std::get<1>(wpckg); - MPI_Send (&first, 1, MPI_INT, tr.worker, tag, MPI_COMM_WORLD); + MPI_Send (&first, 1, MPI_INT, tr.worker, tr.tag, MPI_COMM_WORLD); // send number of structures to load int nstruc = std::get<2>(wpckg); - MPI_Send (&nstruc, 1, MPI_INT, tr.worker, tag, MPI_COMM_WORLD); + MPI_Send (&nstruc, 1, MPI_INT, tr.worker, tr.tag, MPI_COMM_WORLD); } void data_tag(int &count) { int rows_needed; - MPI_Recv (&rows_needed, 1, MPI_INT, tr.worker, tag, MPI_COMM_WORLD, &status); + MPI_Recv (&rows_needed, 1, MPI_INT, tr.worker, tr.tag, MPI_COMM_WORLD, &status); if (tr.rows_available>0) { int rows_accepted = tr.rows_available < rows_needed ? tr.rows_available : rows_needed; - MPI_Send (&tr.b_rank, 1, MPI_INT, tr.worker, tag, MPI_COMM_WORLD); - MPI_Send (&rows_accepted, 1, MPI_INT, tr.worker, tag, MPI_COMM_WORLD); - MPI_Recv (&tr.dm.Phi.data()[tr.phi_row], rows_accepted, tr.rowvecs, tr.worker, tag, MPI_COMM_WORLD, &status); - MPI_Recv (&tr.dm.T.data()[tr.phi_row], rows_accepted, MPI_DOUBLE, tr.worker, tag, MPI_COMM_WORLD, &status); - MPI_Recv (&tr.dm.Tlabels.data()[tr.phi_row], rows_accepted, MPI_DOUBLE, tr.worker, tag, MPI_COMM_WORLD, &status); + MPI_Send (&tr.b_rank, 1, MPI_INT, tr.worker, tr.tag, MPI_COMM_WORLD); + MPI_Send (&rows_accepted, 1, MPI_INT, tr.worker, tr.tag, MPI_COMM_WORLD); + MPI_Recv (&tr.dm.Phi.data()[tr.phi_row], rows_accepted, tr.rowvecs, tr.worker, tr.tag, MPI_COMM_WORLD, &status); + MPI_Recv (&tr.dm.T.data()[tr.phi_row], rows_accepted, MPI_DOUBLE, tr.worker, tr.tag, MPI_COMM_WORLD, &status); + MPI_Recv (&tr.dm.Tlabels.data()[tr.phi_row], rows_accepted, MPI_DOUBLE, tr.worker, tr.tag, MPI_COMM_WORLD, &status); tr.rows_available -= rows_accepted; tr.phi_row += rows_accepted; if (tr.rows_available<0 ) { throw std::runtime_error(" HOST2: The number of rows in the local array is smaller than requested.");} @@ -322,13 +322,13 @@ class MPI_Trainer_WORKER { bool release_tag() { int temp; - MPI_Recv (&temp, 1, MPI_INT, tr.worker, tag, MPI_COMM_WORLD, &status); + MPI_Recv (&temp, 1, MPI_INT, tr.worker, tr.tag, MPI_COMM_WORLD, &status); if (tr.rows_available!=0) { throw std::runtime_error("Attempting to release a worker... but the worker requires more data!!");} return true; } bool wait_tag() { int temp; - MPI_Recv (&temp, 1, MPI_INT, tr.worker, tag, MPI_COMM_WORLD, &status); + MPI_Recv (&temp, 1, MPI_INT, tr.worker, tr.tag, MPI_COMM_WORLD, &status); } bool data_tag() { // other worker is giving me some data @@ -336,9 +336,9 @@ class MPI_Trainer_WORKER { MPI_Get_count(&status, MPI_DOUBLE, &arr_size); int rows_accepted = arr_size/tr.phi_cols1; if (tr.rows_available<rows_accepted) { throw std::runtime_error("Number of rows available is smaller than number of provided rows");} - MPI_Recv (&tr.dm.Phi.data()[tr.phi_row], tr.rows_available, tr.rowvecs, tr.worker, tag, MPI_COMM_WORLD, &status); - MPI_Recv (&tr.dm.T.data()[tr.phi_row], tr.rows_available, MPI_DOUBLE, tr.worker, tag, MPI_COMM_WORLD, &status); - MPI_Recv (&tr.dm.Tlabels.data()[tr.phi_row], tr.rows_available, MPI_DOUBLE, tr.worker, tag, MPI_COMM_WORLD, &status); + MPI_Recv (&tr.dm.Phi.data()[tr.phi_row], tr.rows_available, tr.rowvecs, tr.worker, tr.tag, MPI_COMM_WORLD, &status); + MPI_Recv (&tr.dm.T.data()[tr.phi_row], tr.rows_available, MPI_DOUBLE, tr.worker, tr.tag, MPI_COMM_WORLD, &status); + MPI_Recv (&tr.dm.Tlabels.data()[tr.phi_row], tr.rows_available, MPI_DOUBLE, tr.worker, tr.tag, MPI_COMM_WORLD, &status); tr.rows_available -= rows_accepted; tr.phi_row += rows_accepted; }