diff --git a/trainer.h b/trainer.h
index dec5b520f87997409cf0de00055d9085407d7425..6c3a95522977b04adeaf09c5d112ce0d9d72474b 100644
--- a/trainer.h
+++ b/trainer.h
@@ -385,68 +385,59 @@ class MPI_Trainer_HOST {
     }
   }
-  void run() {
-    // work finised, collect remaining data and release all workers
-    int count=1; // count number of release workers, skip host
-    while(true) {
-      MPI_Probe(MPI_ANY_SOURCE, MPI_ANY_TAG, MPI_COMM_WORLD, &status);
-      int worker = status.MPI_SOURCE;
-      int tag = status.MPI_TAG;
-
-      if (tag==TadahCLI::DATA_TAG) {
-        int rows_needed;
-        MPI_Recv (&rows_needed, 1, MPI_INT, worker, tag, MPI_COMM_WORLD, &status);
-        if (tr.rows_available>0) {
-          int rows_accepted = tr.rows_available < rows_needed ? tr.rows_available : rows_needed;
-          MPI_Send (&tr.b_rank, 1, MPI_INT, worker, tag, MPI_COMM_WORLD);
-          MPI_Send (&rows_accepted, 1, MPI_INT, worker, tag, MPI_COMM_WORLD);
-          MPI_Recv (&tr.dm.Phi.data()[tr.phi_row], rows_accepted, tr.rowvecs, worker, tag, MPI_COMM_WORLD, &status);
-          MPI_Recv (&tr.dm.T.data()[tr.phi_row], rows_accepted, MPI_DOUBLE, worker, tag, MPI_COMM_WORLD, &status);
-          MPI_Recv (&tr.dm.Tlabels.data()[tr.phi_row], rows_accepted, MPI_DOUBLE, worker, tag, MPI_COMM_WORLD, &status);
-          tr.rows_available -= rows_accepted;
-          tr.phi_row += rows_accepted;
-          if (tr.rows_available<0 ) { throw std::runtime_error(" HOST2: The number of rows in the local array is smaller than requested.");}
-        }
-        else {
-          // host is unable to fit data we have to ask workers for their storage availability
-          // find a worker to accept at least some data
-          MPI_Status status2;
-          int worker2;
-          // find a worker capable of accepting data
-          int w_rows_available;
-          while (true) {
-            MPI_Recv (&w_rows_available, 1, MPI_INT, MPI_ANY_SOURCE, TadahCLI::WORK_TAG, MPI_COMM_WORLD, &status2);
-            worker2 = status2.MPI_SOURCE;
-            if (worker==worker2) {throw std::runtime_error("worker and worker2 are the same."); }
-            if (w_rows_available==0 ) {
-              // give up on this worker and release him as there is no more work to be done
-              MPI_Send (0, 0, MPI_INT, worker2, TadahCLI::RELEASE_TAG, MPI_COMM_WORLD);
-              count++;
-            }
-            else {
-              // found a worker
-              break;
-            }
-          }
-          int rows_accepted = w_rows_available < rows_needed ? w_rows_available : rows_needed;
-          MPI_Send (&worker2, 1, MPI_INT, worker, TadahCLI::DATA_TAG, MPI_COMM_WORLD);
-          MPI_Send (&rows_accepted, 1, MPI_INT, worker, TadahCLI::DATA_TAG, MPI_COMM_WORLD);
-        }
-      }
-      else {
-        int rows_available;
-        MPI_Recv (&rows_available, 1, MPI_INT, worker, TadahCLI::WORK_TAG, MPI_COMM_WORLD, &status);
-        // there is no more work so release a worker
-        if (rows_available==0) {
-          MPI_Send (0, 0, MPI_INT, worker, TadahCLI::RELEASE_TAG, MPI_COMM_WORLD);
+  void a(int &worker, int &tag, int &count) {
+
+    int rows_needed;
+    MPI_Recv (&rows_needed, 1, MPI_INT, worker, tag, MPI_COMM_WORLD, &status);
+    if (tr.rows_available>0) {
+      int rows_accepted = tr.rows_available < rows_needed ? tr.rows_available : rows_needed;
+      MPI_Send (&tr.b_rank, 1, MPI_INT, worker, tag, MPI_COMM_WORLD);
+      MPI_Send (&rows_accepted, 1, MPI_INT, worker, tag, MPI_COMM_WORLD);
+      MPI_Recv (&tr.dm.Phi.data()[tr.phi_row], rows_accepted, tr.rowvecs, worker, tag, MPI_COMM_WORLD, &status);
+      MPI_Recv (&tr.dm.T.data()[tr.phi_row], rows_accepted, MPI_DOUBLE, worker, tag, MPI_COMM_WORLD, &status);
+      MPI_Recv (&tr.dm.Tlabels.data()[tr.phi_row], rows_accepted, MPI_DOUBLE, worker, tag, MPI_COMM_WORLD, &status);
+      tr.rows_available -= rows_accepted;
+      tr.phi_row += rows_accepted;
+      if (tr.rows_available<0 ) { throw std::runtime_error(" HOST2: The number of rows in the local array is smaller than requested.");}
+    }
+    else {
+      // host is unable to fit data we have to ask workers for their storage availability
+      // find a worker to accept at least some data
+      MPI_Status status2;
+      int worker2;
+      // find a worker capable of accepting data
+      int w_rows_available;
+      while (true) {
+        MPI_Recv (&w_rows_available, 1, MPI_INT, MPI_ANY_SOURCE, TadahCLI::WORK_TAG, MPI_COMM_WORLD, &status2);
+        worker2 = status2.MPI_SOURCE;
+        if (worker==worker2) {throw std::runtime_error("worker and worker2 are the same."); }
+        if (w_rows_available==0 ) {
+          // give up on this worker and release him as there is no more work to be done
+          MPI_Send (0, 0, MPI_INT, worker2, TadahCLI::RELEASE_TAG, MPI_COMM_WORLD);
           count++;
-          }
+        }
         else {
-          MPI_Send (0, 0, MPI_INT, worker, TadahCLI::WAIT_TAG, MPI_COMM_WORLD);
+          // found a worker
+          break;
         }
       }
-      if (count==ncpu) { break; } // count starts from 1
+      int rows_accepted = w_rows_available < rows_needed ? w_rows_available : rows_needed;
+      MPI_Send (&worker2, 1, MPI_INT, worker, TadahCLI::DATA_TAG, MPI_COMM_WORLD);
+      MPI_Send (&rows_accepted, 1, MPI_INT, worker, TadahCLI::DATA_TAG, MPI_COMM_WORLD);
     }
   }
+  void b(int &worker, int &tag, int &count) {
+    int rows_available;
+    MPI_Recv (&rows_available, 1, MPI_INT, worker, TadahCLI::WORK_TAG, MPI_COMM_WORLD, &status);
+    // there is no more work so release a worker
+    if (rows_available==0) {
+      MPI_Send (0, 0, MPI_INT, worker, TadahCLI::RELEASE_TAG, MPI_COMM_WORLD);
+      count++;
+    }
+    else {
+      MPI_Send (0, 0, MPI_INT, worker, TadahCLI::WAIT_TAG, MPI_COMM_WORLD);
+    }
+  }
+
 };
 #endif