diff --git a/trainer.h b/trainer.h
index 7a1eece225c1813d49635040f641a45b32dee149..501a97f8b87a472c367d0855213de714166c40ba 100644
--- a/trainer.h
+++ b/trainer.h
@@ -29,9 +29,9 @@ class Trainer {
         *DCS.c2b,*DCS.c3b,*DCS.cmb),
     nnf(config),
     fb(CONFIG::factory<DM_Function_Base,Config&>(
-          config.get<std::string>("MODEL",1),config)),
+        config.get<std::string>("MODEL",1),config)),
     model(CONFIG::factory<M_Tadah_Base,DM_Function_Base&,Config&>
-          (config.get<std::string>("MODEL",0),*fb,config)),
+        (config.get<std::string>("MODEL",0),*fb,config)),
     dm(*fb, config)
   {
     config.postprocess();
@@ -39,8 +39,8 @@ class Trainer {
   }
 
   void train(StructureDB &stdb) {
-      nnf.calc(stdb);
-      model->train(stdb,dc);
+    nnf.calc(stdb);
+    model->train(stdb,dc);
   }
 
   Config get_param_file() {
@@ -280,5 +280,147 @@ class MPI_Trainer_HOST {
   MPI_Trainer_HOST(MPI_Trainer &MPIT):
     MPIT(MPIT)
   {}
+  void run() {
+
+    while (true) {
+
+      if (wpckgs.size()==0) {
+        // no more packages; just skip the remaining workers
+        // the remaining data is collected and the workers are released outside of this loop
+        break;
+      }
+
+      // probe ANY request from ANY worker
+      MPI_Probe(MPI_ANY_SOURCE, MPI_ANY_TAG, MPI_COMM_WORLD, &status);
+      int worker = status.MPI_SOURCE;
+      int tag = status.MPI_TAG;
+
+      if (tag==WORK_TAG) {
+        int rows_available;
+        MPI_Recv (&rows_available, 1, MPI_INT, worker, tag, MPI_COMM_WORLD, &status);
+
+        std::tuple<std::string,int,int> wpckg = wpckgs.back();
+        wpckgs.pop_back();
+
+        // send dataset filename
+        const char *fn = std::get<0>(wpckg).c_str();
+        int fn_length = std::get<0>(wpckg).length()+1;  // +1 for the terminating null character
+        MPI_Send (fn, fn_length, MPI_CHAR, worker, tag, MPI_COMM_WORLD);
+
+        // send index of the first structure to load
+        int first = std::get<1>(wpckg);
+        MPI_Send (&first, 1, MPI_INT, worker, tag, MPI_COMM_WORLD);
+
+        // send number of structures to load
+        int nstruc = std::get<2>(wpckg);
+        MPI_Send (&nstruc, 1, MPI_INT, worker, tag, MPI_COMM_WORLD);
+      }
+      else if (tag==DATA_TAG) {
+        int rows_needed;
+        MPI_Recv (&rows_needed, 1, MPI_INT, worker, tag, MPI_COMM_WORLD, &status);
+        if (rows_available>0) {
+          int rows_accepted = rows_available < rows_needed ? rows_available : rows_needed;
+          MPI_Send (&b_rank, 1, MPI_INT, worker, tag, MPI_COMM_WORLD);
+          MPI_Send (&rows_accepted, 1, MPI_INT, worker, tag, MPI_COMM_WORLD);
+          MPI_Recv (&tr.dm.Phi.data()[phi_row], rows_accepted, rowvecs, worker, tag, MPI_COMM_WORLD, &status);
+          MPI_Recv (&tr.dm.T.data()[phi_row], rows_accepted, MPI_DOUBLE, worker, tag, MPI_COMM_WORLD, &status);
+          MPI_Recv (&tr.dm.Tlabels.data()[phi_row], rows_accepted, MPI_DOUBLE, worker, tag, MPI_COMM_WORLD, &status);
+          rows_available -= rows_accepted;
+          phi_row += rows_accepted;
+          if (rows_available<0) { throw std::runtime_error("HOST1: The number of rows in the local array is smaller than requested."); }
+        }
+        else {
+          // The host cannot fit this data, so ask the workers for their storage availability.
+          // Locate a worker willing to accept at least some data.
+          int worker2;
+          MPI_Status status2;
+          // find a worker capable of accepting data
+          int w_rows_available;
+          while (true) {
+            MPI_Recv (&w_rows_available, 1, MPI_INT, MPI_ANY_SOURCE, WORK_TAG, MPI_COMM_WORLD, &status2);
+            worker2 = status2.MPI_SOURCE;
+            if (worker==worker2) { throw std::runtime_error("worker and worker2 are the same."); }
+            if (w_rows_available==0) {
+              // give up on this worker
+              MPI_Send (&worker2, 1, MPI_INT, worker2, WAIT_TAG, MPI_COMM_WORLD);
+            }
+            else {
+              // found a worker
+              break;
+            }
+          }
+          int rows_accepted = w_rows_available < rows_needed ? w_rows_available : rows_needed;
+          // tell worker about worker2
+          MPI_Send (&worker2, 1, MPI_INT, worker, tag, MPI_COMM_WORLD);
+          MPI_Send (&rows_accepted, 1, MPI_INT, worker, tag, MPI_COMM_WORLD);
+        }
+      }
+      else {
+        throw std::runtime_error("HOST1: Unexpected request from " + std::to_string(worker));
+      }
+    }
+
+    // work finished; collect the remaining data and release all workers
+    int count=1;  // counts released workers; starts at 1 to skip the host
+    while (true) {
+      MPI_Probe(MPI_ANY_SOURCE, MPI_ANY_TAG, MPI_COMM_WORLD, &status);
+      int worker = status.MPI_SOURCE;
+      int tag = status.MPI_TAG;
+
+      if (tag==DATA_TAG) {
+        int rows_needed;
+        MPI_Recv (&rows_needed, 1, MPI_INT, worker, tag, MPI_COMM_WORLD, &status);
+        if (rows_available>0) {
+          int rows_accepted = rows_available < rows_needed ? rows_available : rows_needed;
+          MPI_Send (&b_rank, 1, MPI_INT, worker, tag, MPI_COMM_WORLD);
+          MPI_Send (&rows_accepted, 1, MPI_INT, worker, tag, MPI_COMM_WORLD);
+          MPI_Recv (&tr.dm.Phi.data()[phi_row], rows_accepted, rowvecs, worker, tag, MPI_COMM_WORLD, &status);
+          MPI_Recv (&tr.dm.T.data()[phi_row], rows_accepted, MPI_DOUBLE, worker, tag, MPI_COMM_WORLD, &status);
+          MPI_Recv (&tr.dm.Tlabels.data()[phi_row], rows_accepted, MPI_DOUBLE, worker, tag, MPI_COMM_WORLD, &status);
+          rows_available -= rows_accepted;
+          phi_row += rows_accepted;
+          if (rows_available<0) { throw std::runtime_error("HOST2: The number of rows in the local array is smaller than requested."); }
+        }
+        else {
+          // The host cannot fit this data, so ask the workers for their storage availability.
+          // Find a worker willing to accept at least some data.
+          MPI_Status status2;
+          int worker2;
+          // find a worker capable of accepting data
+          int w_rows_available;
+          while (true) {
+            MPI_Recv (&w_rows_available, 1, MPI_INT, MPI_ANY_SOURCE, WORK_TAG, MPI_COMM_WORLD, &status2);
+            worker2 = status2.MPI_SOURCE;
+            if (worker==worker2) { throw std::runtime_error("worker and worker2 are the same."); }
+            if (w_rows_available==0) {
+              // give up on this worker and release it, as there is no more work to be done
+              MPI_Send (0, 0, MPI_INT, worker2, RELEASE_TAG, MPI_COMM_WORLD);
+              count++;
+            }
+            else {
+              // found a worker
+              break;
+            }
+          }
+          int rows_accepted = w_rows_available < rows_needed ? w_rows_available : rows_needed;
+          MPI_Send (&worker2, 1, MPI_INT, worker, DATA_TAG, MPI_COMM_WORLD);
+          MPI_Send (&rows_accepted, 1, MPI_INT, worker, DATA_TAG, MPI_COMM_WORLD);
+        }
+      }
+      else {
+        int rows_available;
+        MPI_Recv (&rows_available, 1, MPI_INT, worker, WORK_TAG, MPI_COMM_WORLD, &status);
+        // there is no more work, so release the worker
+        if (rows_available==0) {
+          MPI_Send (0, 0, MPI_INT, worker, RELEASE_TAG, MPI_COMM_WORLD);
+          count++;
+        }
+        else {
+          MPI_Send (0, 0, MPI_INT, worker, WAIT_TAG, MPI_COMM_WORLD);
+        }
+      }
+      if (count==ncpu) { break; }  // count starts from 1
+    }
+  }
 };
 #endif