diff --git a/trainer.h b/trainer.h
index 43865b806146466cafe4db37863b4d813041b57d..7653707b02259e4b737b9c2eb3f77f990cda0ef8 100644
--- a/trainer.h
+++ b/trainer.h
@@ -283,174 +283,163 @@ class MPI_Trainer: public Trainer {
 };
 class MPI_Trainer_HOST {
-  MPI_Trainer &tr;
-  std::vector<std::tuple<std::string,int,int>> wpckgs;
+  private:
+    MPI_Trainer &tr;
+    int &rank;
+    int &ncpu;
+    MPI_Status &status;
   public:
-    MPI_Trainer_HOST(MPI_Trainer &MPIT):
-      tr(MPIT)
+    std::vector<std::tuple<std::string,int,int>> wpckgs;
+    MPI_Trainer_HOST(MPI_Trainer &MPIT, MPI_Status &status, int &rank, int &ncpu):
+      tr(MPIT), rank(rank), ncpu(ncpu), status(status)
     {}
-  void prep_wpckgs() {
-    // HOST: prepare work packages
-    // filename, first structure index, number of structures to read
-    int nstruc = tr.config.get<int>("MPIWPCKG");
-    for (const std::string &fn : tr.config("DBFILE")) {
-      // get number of structures
-      int dbsize = StructureDB::count(fn).first;
-      int first=0;
-      while(true) {
-        if (nstruc < dbsize) {
-          wpckgs.push_back(std::make_tuple(fn,first,nstruc));
-          first += nstruc;
-        } else {
-          wpckgs.push_back(std::make_tuple(fn,first,dbsize));
-          break;
-        }
-        dbsize-=nstruc;
-      }
-    }
-  }
-  void run(MPI_Status &status, int &rank, int &ncpu) {
-
-    while (true) {
-
-      if (wpckgs.size()==0) {
-        // no more packages, just skip remaining workers
-        // we will collect remaining data and release them outside of this loop
-        break;
-      }
-
-      // probe ANY request from ANY worker
-      MPI_Probe(MPI_ANY_SOURCE, MPI_ANY_TAG, MPI_COMM_WORLD, &status);
-      int worker = status.MPI_SOURCE;
-      int tag = status.MPI_TAG;
-
-      if (tag==TadahCLI::WORK_TAG) {
-        int rows_available;
-        MPI_Recv (&rows_available, 1, MPI_INT, worker, tag, MPI_COMM_WORLD, &status);
-
-        std::tuple<std::string,int,int> wpckg = wpckgs.back();
-        wpckgs.pop_back();
-
-        // send dataset filename
-        const char *fn = std::get<0>(wpckg).c_str();
-        int fn_length = std::get<0>(wpckg).length()+1; // +1 for char
-        MPI_Send (fn, fn_length, MPI_CHAR, worker, tag, MPI_COMM_WORLD);
-
-        // send index of the first structure to load
-        int first = std::get<1>(wpckg);
-        MPI_Send (&first, 1, MPI_INT, worker, tag, MPI_COMM_WORLD);
-
-        // send number of structures to load
-        int nstruc = std::get<2>(wpckg);
-        MPI_Send (&nstruc, 1, MPI_INT, worker, tag, MPI_COMM_WORLD);
-      }
-      else if (tag==TadahCLI::DATA_TAG) {
-        int rows_needed;
-        MPI_Recv (&rows_needed, 1, MPI_INT, worker, tag, MPI_COMM_WORLD, &status);
-        if (tr.rows_available>0) {
-          int rows_accepted = tr.rows_available < rows_needed ? tr.rows_available : rows_needed;
-          MPI_Send (&tr.b_rank, 1, MPI_INT, worker, tag, MPI_COMM_WORLD);
-          MPI_Send (&rows_accepted, 1, MPI_INT, worker, tag, MPI_COMM_WORLD);
-          MPI_Recv (&tr.dm.Phi.data()[tr.phi_row], rows_accepted, tr.rowvecs, worker, tag, MPI_COMM_WORLD, &status);
-          MPI_Recv (&tr.dm.T.data()[tr.phi_row], rows_accepted, MPI_DOUBLE, worker, tag, MPI_COMM_WORLD, &status);
-          MPI_Recv (&tr.dm.Tlabels.data()[tr.phi_row], rows_accepted, MPI_DOUBLE, worker, tag, MPI_COMM_WORLD, &status);
-          tr.rows_available -= rows_accepted;
-          tr.phi_row += rows_accepted;
-          if (tr.rows_available<0 ) { throw std::runtime_error(" HOST1: The number of rows in the local array is smaller than requested.");}
-        }
-        else {
-          // Host is unable to fit data we have to ask workers for their storage availability
-          // Locate a worker willing to accept at least some data.
-          int worker2;
-          MPI_Status status2;
-          // find a worker capable of accepting data
-          int w_rows_available;
-          while (true) {
-            MPI_Recv (&w_rows_available, 1, MPI_INT, MPI_ANY_SOURCE, TadahCLI::WORK_TAG, MPI_COMM_WORLD, &status2);
-            worker2 = status2.MPI_SOURCE;
-            if (worker==worker2) {throw std::runtime_error("worker and worker2 are the same."); }
-            if (w_rows_available==0 ) {
-              // give up on this worker
-              MPI_Send (&worker2, 1, MPI_INT, worker2, TadahCLI::WAIT_TAG, MPI_COMM_WORLD);
-            }
-            else {
-              // found a worker
-              break;
-            }
-          }
-          int rows_accepted = w_rows_available < rows_needed ? w_rows_available : rows_needed;
-          // tell worker about worker2
-          MPI_Send (&worker2, 1, MPI_INT, worker, tag, MPI_COMM_WORLD);
-          MPI_Send (&rows_accepted, 1, MPI_INT, worker, tag, MPI_COMM_WORLD);
-        }
-      }
-      else {
-        throw std::runtime_error("HOST1: Unexpected request from " + std::to_string(worker));
-      }
-    }
-
-    // work finised, collect remaining data and release all workers
-    int count=1; // count number of release workers, skip host
-    while(true) {
-      MPI_Probe(MPI_ANY_SOURCE, MPI_ANY_TAG, MPI_COMM_WORLD, &status);
-      int worker = status.MPI_SOURCE;
-      int tag = status.MPI_TAG;
-
-      if (tag==TadahCLI::DATA_TAG) {
-        int rows_needed;
-        MPI_Recv (&rows_needed, 1, MPI_INT, worker, tag, MPI_COMM_WORLD, &status);
-        if (tr.rows_available>0) {
-          int rows_accepted = tr.rows_available < rows_needed ? tr.rows_available : rows_needed;
-          MPI_Send (&tr.b_rank, 1, MPI_INT, worker, tag, MPI_COMM_WORLD);
-          MPI_Send (&rows_accepted, 1, MPI_INT, worker, tag, MPI_COMM_WORLD);
-          MPI_Recv (&tr.dm.Phi.data()[tr.phi_row], rows_accepted, tr.rowvecs, worker, tag, MPI_COMM_WORLD, &status);
-          MPI_Recv (&tr.dm.T.data()[tr.phi_row], rows_accepted, MPI_DOUBLE, worker, tag, MPI_COMM_WORLD, &status);
-          MPI_Recv (&tr.dm.Tlabels.data()[tr.phi_row], rows_accepted, MPI_DOUBLE, worker, tag, MPI_COMM_WORLD, &status);
-          tr.rows_available -= rows_accepted;
-          tr.phi_row += rows_accepted;
-          if (tr.rows_available<0 ) { throw std::runtime_error(" HOST2: The number of rows in the local array is smaller than requested.");}
-        }
-        else {
-          // host is unable to fit data we have to ask workers for their storage availability
-          // find a worker to accept at least some data
-          MPI_Status status2;
-          int worker2;
-          // find a worker capable of accepting data
-          int w_rows_available;
-          while (true) {
-            MPI_Recv (&w_rows_available, 1, MPI_INT, MPI_ANY_SOURCE, TadahCLI::WORK_TAG, MPI_COMM_WORLD, &status2);
-            worker2 = status2.MPI_SOURCE;
-            if (worker==worker2) {throw std::runtime_error("worker and worker2 are the same."); }
-            if (w_rows_available==0 ) {
-              // give up on this worker and release him as there is no more work to be done
-              MPI_Send (0, 0, MPI_INT, worker2, TadahCLI::RELEASE_TAG, MPI_COMM_WORLD);
-              count++;
-            }
-            else {
-              // found a worker
-              break;
-            }
-          }
-          int rows_accepted = w_rows_available < rows_needed ? w_rows_available : rows_needed;
-          MPI_Send (&worker2, 1, MPI_INT, worker, TadahCLI::DATA_TAG, MPI_COMM_WORLD);
-          MPI_Send (&rows_accepted, 1, MPI_INT, worker, TadahCLI::DATA_TAG, MPI_COMM_WORLD);
-        }
-      }
-      else {
-        int rows_available;
-        MPI_Recv (&rows_available, 1, MPI_INT, worker, TadahCLI::WORK_TAG, MPI_COMM_WORLD, &status);
-        // there is no more work so release a worker
-        if (rows_available==0) {
-          MPI_Send (0, 0, MPI_INT, worker, TadahCLI::RELEASE_TAG, MPI_COMM_WORLD);
-          count++;
-        }
-        else {
-          MPI_Send (0, 0, MPI_INT, worker, TadahCLI::WAIT_TAG, MPI_COMM_WORLD);
-        }
-      }
-      if (count==ncpu) { break; } // count starts from 1
-    }
-  }
+
+    void prep_wpckgs() {
+      // HOST: prepare work packages
+      // filename, first structure index, number of structures to read
+      int nstruc = tr.config.get<int>("MPIWPCKG");
+      for (const std::string &fn : tr.config("DBFILE")) {
+        // get number of structures
+        int dbsize = StructureDB::count(fn).first;
+        int first=0;
+        while(true) {
+          if (nstruc < dbsize) {
+            wpckgs.push_back(std::make_tuple(fn,first,nstruc));
+            first += nstruc;
+          } else {
+            wpckgs.push_back(std::make_tuple(fn,first,dbsize));
+            break;
+          }
+          dbsize-=nstruc;
+        }
+      }
+    }
+
+    void work_tag() {
+      // the worker and tag come from the message probed by the caller
+      int worker = status.MPI_SOURCE;
+      int tag = status.MPI_TAG;
+
+      int rows_available;
+      MPI_Recv (&rows_available, 1, MPI_INT, worker, tag, MPI_COMM_WORLD, &status);
+
+      std::tuple<std::string,int,int> wpckg = wpckgs.back();
+      wpckgs.pop_back();
+
+      // send dataset filename
+      const char *fn = std::get<0>(wpckg).c_str();
+      int fn_length = std::get<0>(wpckg).length()+1; // +1 for the terminating null character
+      MPI_Send (fn, fn_length, MPI_CHAR, worker, tag, MPI_COMM_WORLD);
+
+      // send index of the first structure to load
+      int first = std::get<1>(wpckg);
+      MPI_Send (&first, 1, MPI_INT, worker, tag, MPI_COMM_WORLD);
+
+      // send number of structures to load
+      int nstruc = std::get<2>(wpckg);
+      MPI_Send (&nstruc, 1, MPI_INT, worker, tag, MPI_COMM_WORLD);
+    }
+
+    void data_tag() {
+      // the worker and tag come from the message probed by the caller
+      int worker = status.MPI_SOURCE;
+      int tag = status.MPI_TAG;
+
+      int rows_needed;
+      MPI_Recv (&rows_needed, 1, MPI_INT, worker, tag, MPI_COMM_WORLD, &status);
+      if (tr.rows_available>0) {
+        int rows_accepted = tr.rows_available < rows_needed ? tr.rows_available : rows_needed;
+        MPI_Send (&tr.b_rank, 1, MPI_INT, worker, tag, MPI_COMM_WORLD);
+        MPI_Send (&rows_accepted, 1, MPI_INT, worker, tag, MPI_COMM_WORLD);
+        MPI_Recv (&tr.dm.Phi.data()[tr.phi_row], rows_accepted, tr.rowvecs, worker, tag, MPI_COMM_WORLD, &status);
+        MPI_Recv (&tr.dm.T.data()[tr.phi_row], rows_accepted, MPI_DOUBLE, worker, tag, MPI_COMM_WORLD, &status);
+        MPI_Recv (&tr.dm.Tlabels.data()[tr.phi_row], rows_accepted, MPI_DOUBLE, worker, tag, MPI_COMM_WORLD, &status);
+        tr.rows_available -= rows_accepted;
+        tr.phi_row += rows_accepted;
+        if (tr.rows_available<0) { throw std::runtime_error("HOST1: The number of rows in the local array is smaller than requested."); }
+      }
+      else {
+        // The host is unable to fit the data; ask workers for their storage availability.
+        // Locate a worker willing to accept at least some data.
+        int worker2;
+        MPI_Status status2;
+        int w_rows_available;
+        while (true) {
+          MPI_Recv (&w_rows_available, 1, MPI_INT, MPI_ANY_SOURCE, TadahCLI::WORK_TAG, MPI_COMM_WORLD, &status2);
+          worker2 = status2.MPI_SOURCE;
+          if (worker==worker2) { throw std::runtime_error("worker and worker2 are the same."); }
+          if (w_rows_available==0) {
+            // give up on this worker
+            MPI_Send (&worker2, 1, MPI_INT, worker2, TadahCLI::WAIT_TAG, MPI_COMM_WORLD);
+          }
+          else {
+            // found a worker
+            break;
+          }
+        }
+        int rows_accepted = w_rows_available < rows_needed ? w_rows_available : rows_needed;
+        // tell the worker about worker2
+        MPI_Send (&worker2, 1, MPI_INT, worker, tag, MPI_COMM_WORLD);
+        MPI_Send (&rows_accepted, 1, MPI_INT, worker, tag, MPI_COMM_WORLD);
+      }
+    }
+
+    void run() {
+      // work finished; collect remaining data and release all workers
+      int count=1; // counts released workers; starts at 1 to skip the host
+      while(true) {
+        MPI_Probe(MPI_ANY_SOURCE, MPI_ANY_TAG, MPI_COMM_WORLD, &status);
+        int worker = status.MPI_SOURCE;
+        int tag = status.MPI_TAG;
+
+        if (tag==TadahCLI::DATA_TAG) {
+          int rows_needed;
+          MPI_Recv (&rows_needed, 1, MPI_INT, worker, tag, MPI_COMM_WORLD, &status);
+          if (tr.rows_available>0) {
+            int rows_accepted = tr.rows_available < rows_needed ? tr.rows_available : rows_needed;
+            MPI_Send (&tr.b_rank, 1, MPI_INT, worker, tag, MPI_COMM_WORLD);
+            MPI_Send (&rows_accepted, 1, MPI_INT, worker, tag, MPI_COMM_WORLD);
+            MPI_Recv (&tr.dm.Phi.data()[tr.phi_row], rows_accepted, tr.rowvecs, worker, tag, MPI_COMM_WORLD, &status);
+            MPI_Recv (&tr.dm.T.data()[tr.phi_row], rows_accepted, MPI_DOUBLE, worker, tag, MPI_COMM_WORLD, &status);
+            MPI_Recv (&tr.dm.Tlabels.data()[tr.phi_row], rows_accepted, MPI_DOUBLE, worker, tag, MPI_COMM_WORLD, &status);
+            tr.rows_available -= rows_accepted;
+            tr.phi_row += rows_accepted;
+            if (tr.rows_available<0) { throw std::runtime_error("HOST2: The number of rows in the local array is smaller than requested."); }
+          }
+          else {
+            // The host is unable to fit the data; ask workers for their storage availability.
+            // Find a worker to accept at least some data.
+            MPI_Status status2;
+            int worker2;
+            int w_rows_available;
+            while (true) {
+              MPI_Recv (&w_rows_available, 1, MPI_INT, MPI_ANY_SOURCE, TadahCLI::WORK_TAG, MPI_COMM_WORLD, &status2);
+              worker2 = status2.MPI_SOURCE;
+              if (worker==worker2) { throw std::runtime_error("worker and worker2 are the same."); }
+              if (w_rows_available==0) {
+                // give up on this worker and release it, as there is no more work to be done
+                MPI_Send (0, 0, MPI_INT, worker2, TadahCLI::RELEASE_TAG, MPI_COMM_WORLD);
+                count++;
+              }
+              else {
+                // found a worker
+                break;
+              }
+            }
+            int rows_accepted = w_rows_available < rows_needed ? w_rows_available : rows_needed;
+            MPI_Send (&worker2, 1, MPI_INT, worker, TadahCLI::DATA_TAG, MPI_COMM_WORLD);
+            MPI_Send (&rows_accepted, 1, MPI_INT, worker, TadahCLI::DATA_TAG, MPI_COMM_WORLD);
+          }
+        }
+        else {
+          int rows_available;
+          MPI_Recv (&rows_available, 1, MPI_INT, worker, TadahCLI::WORK_TAG, MPI_COMM_WORLD, &status);
+          // there is no more work, so release the worker
+          if (rows_available==0) {
+            MPI_Send (0, 0, MPI_INT, worker, TadahCLI::RELEASE_TAG, MPI_COMM_WORLD);
+            count++;
+          }
+          else {
+            MPI_Send (0, 0, MPI_INT, worker, TadahCLI::WAIT_TAG, MPI_COMM_WORLD);
+          }
+        }
+        if (count==ncpu) { break; } // count starts from 1
+      }
+    }
 };
 #endif
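
Note: this refactoring removes the dispatch loop from the old run(MPI_Status&, int&, int&); only the drain-and-release phase remains in the new run(). The WORK_TAG/DATA_TAG dispatch is now expected to be driven by the caller, which probes for a message and then invokes work_tag() or data_tag() before draining with run(). That caller is not part of this hunk, so the following is only a minimal sketch of such a driver under the interface shown above; host_driver is a hypothetical name, and the includes already present in trainer.h (mpi.h, stdexcept, etc.) are assumed.

// Hypothetical host-side driver for the refactored interface (a sketch,
// not the repository's actual caller).
void host_driver(MPI_Trainer &tr) {
  int rank, ncpu;
  MPI_Comm_rank(MPI_COMM_WORLD, &rank);
  MPI_Comm_size(MPI_COMM_WORLD, &ncpu);
  MPI_Status status;

  // The host object keeps references to status/rank/ncpu, so the probe
  // below fills the same status that work_tag()/data_tag() read from.
  MPI_Trainer_HOST host(tr, status, rank, ncpu);
  host.prep_wpckgs();

  // Dispatch phase: while packages remain, serve WORK_TAG requests with
  // work and DATA_TAG requests with a destination for the computed rows.
  while (!host.wpckgs.empty()) {
    MPI_Probe(MPI_ANY_SOURCE, MPI_ANY_TAG, MPI_COMM_WORLD, &status);
    if (status.MPI_TAG == TadahCLI::WORK_TAG) {
      host.work_tag();
    } else if (status.MPI_TAG == TadahCLI::DATA_TAG) {
      host.data_tag();
    } else {
      throw std::runtime_error("HOST: unexpected tag");
    }
  }

  // Drain phase: collect remaining rows and release every worker.
  host.run();
}

Because work_tag() pops wpckgs.back() unconditionally, any driver must check wpckgs.empty() before dispatching a WORK_TAG request, as the loop above does; once the queue is empty, remaining worker requests are handled entirely by run().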