diff --git a/trainer.h b/trainer.h
index 9218404cce5d24ed3ff97e932d324075ed3c607c..dfc477d8c0cf4892e77857ec95324f6534a9f4cb 100644
--- a/trainer.h
+++ b/trainer.h
@@ -193,15 +193,13 @@ class MPI_Trainer: public Trainer {
 class MPI_Trainer_HOST {
   private:
     MPI_Trainer &tr;
-    MPI_Status &status;
     int &rank;
     int &ncpu;
     std::vector<std::tuple<std::string,int,int>> wpckgs;
 
   public:
-    MPI_Trainer_HOST(MPI_Trainer &MPIT, MPI_Status &status, int &rank, int &ncpu):
+    MPI_Trainer_HOST(MPI_Trainer &MPIT, int &rank, int &ncpu):
       tr(MPIT),
-      status(status),
       rank(rank),
       ncpu(ncpu)
     {}
@@ -233,7 +231,7 @@ class MPI_Trainer_HOST {
 
     void work_tag() {
      int rows_available;
-      MPI_Recv (&rows_available, 1, MPI_INT, tr.worker, tr.tag, MPI_COMM_WORLD, &status);
+      MPI_Recv (&rows_available, 1, MPI_INT, tr.worker, tr.tag, MPI_COMM_WORLD, &tr.status);
 
       std::tuple<std::string,int,int> wpckg = wpckgs.back();
       wpckgs.pop_back();
@@ -255,14 +253,14 @@ class MPI_Trainer_HOST {
 
     void data_tag(int &count) {
       int rows_needed;
-      MPI_Recv (&rows_needed, 1, MPI_INT, tr.worker, tr.tag, MPI_COMM_WORLD, &status);
+      MPI_Recv (&rows_needed, 1, MPI_INT, tr.worker, tr.tag, MPI_COMM_WORLD, &tr.status);
       if (tr.rows_available>0) {
         int rows_accepted = tr.rows_available < rows_needed ? tr.rows_available : rows_needed;
         MPI_Send (&tr.b_rank, 1, MPI_INT, tr.worker, tr.tag, MPI_COMM_WORLD);
         MPI_Send (&rows_accepted, 1, MPI_INT, tr.worker, tr.tag, MPI_COMM_WORLD);
-        MPI_Recv (&tr.dm.Phi.data()[tr.phi_row], rows_accepted, tr.rowvecs, tr.worker, tr.tag, MPI_COMM_WORLD, &status);
-        MPI_Recv (&tr.dm.T.data()[tr.phi_row], rows_accepted, MPI_DOUBLE, tr.worker, tr.tag, MPI_COMM_WORLD, &status);
-        MPI_Recv (&tr.dm.Tlabels.data()[tr.phi_row], rows_accepted, MPI_DOUBLE, tr.worker, tr.tag, MPI_COMM_WORLD, &status);
+        MPI_Recv (&tr.dm.Phi.data()[tr.phi_row], rows_accepted, tr.rowvecs, tr.worker, tr.tag, MPI_COMM_WORLD, &tr.status);
+        MPI_Recv (&tr.dm.T.data()[tr.phi_row], rows_accepted, MPI_DOUBLE, tr.worker, tr.tag, MPI_COMM_WORLD, &tr.status);
+        MPI_Recv (&tr.dm.Tlabels.data()[tr.phi_row], rows_accepted, MPI_DOUBLE, tr.worker, tr.tag, MPI_COMM_WORLD, &tr.status);
         tr.rows_available -= rows_accepted;
         tr.phi_row += rows_accepted;
         if (tr.rows_available<0 ) { throw std::runtime_error(" HOST2: The number of rows in the local array is smaller than requested.");}
@@ -295,7 +293,7 @@ class MPI_Trainer_HOST {
     }
     void release_tag(int &count) {
       int rows_available;
-      MPI_Recv (&rows_available, 1, MPI_INT, tr.worker, TadahCLI::WORK_TAG, MPI_COMM_WORLD, &status);
+      MPI_Recv (&rows_available, 1, MPI_INT, tr.worker, TadahCLI::WORK_TAG, MPI_COMM_WORLD, &tr.status);
       // there is no more work so release a worker if full
       if (rows_available==0) {
         MPI_Send (0, 0, MPI_INT, tr.worker, TadahCLI::RELEASE_TAG, MPI_COMM_WORLD);
@@ -310,37 +308,35 @@ class MPI_Trainer_HOST {
 class MPI_Trainer_WORKER {
   private:
     MPI_Trainer &tr;
-    MPI_Status &status;
     int &rank;
     int &ncpu;
   public:
-    MPI_Trainer_WORKER(MPI_Trainer &MPIT, MPI_Status &status, int &rank, int &ncpu):
+    MPI_Trainer_WORKER(MPI_Trainer &MPIT, int &rank, int &ncpu):
       tr(MPIT),
-      status(status),
       rank(rank),
       ncpu(ncpu)
     {}
 
     bool release_tag() {
       int temp;
-      MPI_Recv (&temp, 1, MPI_INT, tr.worker, tr.tag, MPI_COMM_WORLD, &status);
+      MPI_Recv (&temp, 1, MPI_INT, tr.worker, tr.tag, MPI_COMM_WORLD, &tr.status);
       if (tr.rows_available!=0) { throw std::runtime_error("Attempting to release a worker... but the worker requires more data!!");}
       return true;
     }
     void wait_tag() {
       int temp;
-      MPI_Recv (&temp, 1, MPI_INT, tr.worker, tr.tag, MPI_COMM_WORLD, &status);
+      MPI_Recv (&temp, 1, MPI_INT, tr.worker, tr.tag, MPI_COMM_WORLD, &tr.status);
     }
 
     void data_tag() {
       // other worker is giving me some data
       int arr_size;
-      MPI_Get_count(&status, MPI_DOUBLE, &arr_size);
+      MPI_Get_count(&tr.status, MPI_DOUBLE, &arr_size);
       int rows_accepted = arr_size/tr.phi_cols1;
       if (tr.rows_available<rows_accepted) { throw std::runtime_error("Number of rows available is smaller than number of provided rows");}
-      MPI_Recv (&tr.dm.Phi.data()[tr.phi_row], tr.rows_available, tr.rowvecs, tr.worker, tr.tag, MPI_COMM_WORLD, &status);
-      MPI_Recv (&tr.dm.T.data()[tr.phi_row], tr.rows_available, MPI_DOUBLE, tr.worker, tr.tag, MPI_COMM_WORLD, &status);
-      MPI_Recv (&tr.dm.Tlabels.data()[tr.phi_row], tr.rows_available, MPI_DOUBLE, tr.worker, tr.tag, MPI_COMM_WORLD, &status);
+      MPI_Recv (&tr.dm.Phi.data()[tr.phi_row], tr.rows_available, tr.rowvecs, tr.worker, tr.tag, MPI_COMM_WORLD, &tr.status);
+      MPI_Recv (&tr.dm.T.data()[tr.phi_row], tr.rows_available, MPI_DOUBLE, tr.worker, tr.tag, MPI_COMM_WORLD, &tr.status);
+      MPI_Recv (&tr.dm.Tlabels.data()[tr.phi_row], tr.rows_available, MPI_DOUBLE, tr.worker, tr.tag, MPI_COMM_WORLD, &tr.status);
       tr.rows_available -= rows_accepted;
       tr.phi_row += rows_accepted;
     }
@@ -350,12 +346,12 @@ class MPI_Trainer_WORKER {
       int fn_length; // length of the filename char array
       int first; // index of the first structure to read from the file
       int nstruc; // number of structures to be processed
-      MPI_Get_count(&status, MPI_CHAR, &fn_length);
+      MPI_Get_count(&tr.status, MPI_CHAR, &fn_length);
       char *fn = (char *) malloc(fn_length+1);
 
-      MPI_Recv (fn, fn_length, MPI_CHAR, 0, TadahCLI::WORK_TAG, MPI_COMM_WORLD, &status);
-      MPI_Recv (&first, 1, MPI_INT, 0, TadahCLI::WORK_TAG, MPI_COMM_WORLD, &status);
-      MPI_Recv (&nstruc, 1, MPI_INT, 0, TadahCLI::WORK_TAG, MPI_COMM_WORLD, &status);
+      MPI_Recv (fn, fn_length, MPI_CHAR, 0, TadahCLI::WORK_TAG, MPI_COMM_WORLD, &tr.status);
+      MPI_Recv (&first, 1, MPI_INT, 0, TadahCLI::WORK_TAG, MPI_COMM_WORLD, &tr.status);
+      MPI_Recv (&nstruc, 1, MPI_INT, 0, TadahCLI::WORK_TAG, MPI_COMM_WORLD, &tr.status);
 
       // do work
       StructureDB stdb;
@@ -407,9 +403,9 @@ class MPI_Trainer_WORKER {
       int rows_accepted; // number of accepted rows
       int dest; // receiving process
       // host returns which dest can accept and how much
-      MPI_Recv (&dest, 1, MPI_INT, 0, TadahCLI::DATA_TAG, MPI_COMM_WORLD, &status);
+      MPI_Recv (&dest, 1, MPI_INT, 0, TadahCLI::DATA_TAG, MPI_COMM_WORLD, &tr.status);
 
-      MPI_Recv (&rows_accepted, 1, MPI_INT, 0, TadahCLI::DATA_TAG, MPI_COMM_WORLD, &status);
+      MPI_Recv (&rows_accepted, 1, MPI_INT, 0, TadahCLI::DATA_TAG, MPI_COMM_WORLD, &tr.status);
 
       // we send data to the host or a willing worker
       int start=temp_dm.Phi.rows()-rows_needed;
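
Note on call sites (not shown in this patch): because MPI_Trainer_HOST and MPI_Trainer_WORKER no longer take an MPI_Status& and instead read tr.status, the status filled by the driver's MPI_Probe/MPI_Recv must live on the shared MPI_Trainer object. Below is a minimal sketch of how a driver loop could construct the helpers after this change; the function name run_training_loop and the probe-then-dispatch structure are assumptions for illustration, while the constructor signatures and the *_tag() methods come from the diff.

#include <mpi.h>
#include "trainer.h"   // assumed to provide MPI_Trainer, MPI_Trainer_HOST, MPI_Trainer_WORKER

// Hypothetical driver sketch; the real dispatch lives in the CLI code, not in this patch.
void run_training_loop(MPI_Trainer &tr, int rank, int ncpu) {
  if (rank == 0) {
    MPI_Trainer_HOST host(tr, rank, ncpu);        // was: MPI_Trainer_HOST(tr, status, rank, ncpu)
    // ... probe into tr.status, then dispatch on the received tag:
    //     host.work_tag(), host.data_tag(count), host.release_tag(count), ...
  } else {
    MPI_Trainer_WORKER worker(tr, rank, ncpu);    // was: MPI_Trainer_WORKER(tr, status, rank, ncpu)
    // ... probe into tr.status, e.g.
    //     MPI_Probe(MPI_ANY_SOURCE, MPI_ANY_TAG, MPI_COMM_WORLD, &tr.status);
    //     then call worker.data_tag(), worker.wait_tag(), or worker.release_tag().
  }
}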