Commit af3b3450 authored by Marcin Kirsz

Update

parent 081eeae4
Pipeline #43295 passed
@@ -193,15 +193,13 @@ class MPI_Trainer: public Trainer {
class MPI_Trainer_HOST {
private:
MPI_Trainer &tr;
- MPI_Status &status;
int &rank;
int &ncpu;
std::vector<std::tuple<std::string,int,int>> wpckgs;
public:
- MPI_Trainer_HOST(MPI_Trainer &MPIT, MPI_Status &status, int &rank, int &ncpu):
+ MPI_Trainer_HOST(MPI_Trainer &MPIT, int &rank, int &ncpu):
tr(MPIT),
- status(status),
rank(rank),
ncpu(ncpu)
{}
@@ -233,7 +231,7 @@ class MPI_Trainer_HOST {
void work_tag() {
int rows_available;
- MPI_Recv (&rows_available, 1, MPI_INT, tr.worker, tr.tag, MPI_COMM_WORLD, &status);
+ MPI_Recv (&rows_available, 1, MPI_INT, tr.worker, tr.tag, MPI_COMM_WORLD, &tr.status);
std::tuple<std::string,int,int> wpckg = wpckgs.back();
wpckgs.pop_back();
@@ -255,14 +253,14 @@ class MPI_Trainer_HOST {
void data_tag(int &count) {
int rows_needed;
- MPI_Recv (&rows_needed, 1, MPI_INT, tr.worker, tr.tag, MPI_COMM_WORLD, &status);
+ MPI_Recv (&rows_needed, 1, MPI_INT, tr.worker, tr.tag, MPI_COMM_WORLD, &tr.status);
if (tr.rows_available>0) {
int rows_accepted = tr.rows_available < rows_needed ? tr.rows_available : rows_needed;
MPI_Send (&tr.b_rank, 1, MPI_INT, tr.worker, tr.tag, MPI_COMM_WORLD);
MPI_Send (&rows_accepted, 1, MPI_INT, tr.worker, tr.tag, MPI_COMM_WORLD);
- MPI_Recv (&tr.dm.Phi.data()[tr.phi_row], rows_accepted, tr.rowvecs, tr.worker, tr.tag, MPI_COMM_WORLD, &status);
- MPI_Recv (&tr.dm.T.data()[tr.phi_row], rows_accepted, MPI_DOUBLE, tr.worker, tr.tag, MPI_COMM_WORLD, &status);
- MPI_Recv (&tr.dm.Tlabels.data()[tr.phi_row], rows_accepted, MPI_DOUBLE, tr.worker, tr.tag, MPI_COMM_WORLD, &status);
+ MPI_Recv (&tr.dm.Phi.data()[tr.phi_row], rows_accepted, tr.rowvecs, tr.worker, tr.tag, MPI_COMM_WORLD, &tr.status);
+ MPI_Recv (&tr.dm.T.data()[tr.phi_row], rows_accepted, MPI_DOUBLE, tr.worker, tr.tag, MPI_COMM_WORLD, &tr.status);
+ MPI_Recv (&tr.dm.Tlabels.data()[tr.phi_row], rows_accepted, MPI_DOUBLE, tr.worker, tr.tag, MPI_COMM_WORLD, &tr.status);
tr.rows_available -= rows_accepted;
tr.phi_row += rows_accepted;
if (tr.rows_available<0 ) { throw std::runtime_error(" HOST2: The number of rows in the local array is smaller than requested.");}
@@ -295,7 +293,7 @@ class MPI_Trainer_HOST {
}
void release_tag(int &count) {
int rows_available;
- MPI_Recv (&rows_available, 1, MPI_INT, tr.worker, TadahCLI::WORK_TAG, MPI_COMM_WORLD, &status);
+ MPI_Recv (&rows_available, 1, MPI_INT, tr.worker, TadahCLI::WORK_TAG, MPI_COMM_WORLD, &tr.status);
// there is no more work so release a worker if full
if (rows_available==0) {
MPI_Send (0, 0, MPI_INT, tr.worker, TadahCLI::RELEASE_TAG, MPI_COMM_WORLD);
@@ -310,37 +308,35 @@ class MPI_Trainer_HOST {
class MPI_Trainer_WORKER {
private:
MPI_Trainer &tr;
- MPI_Status &status;
int &rank;
int &ncpu;
public:
- MPI_Trainer_WORKER(MPI_Trainer &MPIT, MPI_Status &status, int &rank, int &ncpu):
+ MPI_Trainer_WORKER(MPI_Trainer &MPIT, int &rank, int &ncpu):
tr(MPIT),
- status(status),
rank(rank),
ncpu(ncpu)
{}
bool release_tag() {
int temp;
- MPI_Recv (&temp, 1, MPI_INT, tr.worker, tr.tag, MPI_COMM_WORLD, &status);
+ MPI_Recv (&temp, 1, MPI_INT, tr.worker, tr.tag, MPI_COMM_WORLD, &tr.status);
if (tr.rows_available!=0) { throw std::runtime_error("Attempting to release a worker... but the worker requires more data!!");}
return true;
}
void wait_tag() {
int temp;
- MPI_Recv (&temp, 1, MPI_INT, tr.worker, tr.tag, MPI_COMM_WORLD, &status);
+ MPI_Recv (&temp, 1, MPI_INT, tr.worker, tr.tag, MPI_COMM_WORLD, &tr.status);
}
void data_tag() {
// other worker is giving me some data
int arr_size;
- MPI_Get_count(&status, MPI_DOUBLE, &arr_size);
+ MPI_Get_count(&tr.status, MPI_DOUBLE, &arr_size);
int rows_accepted = arr_size/tr.phi_cols1;
if (tr.rows_available<rows_accepted) { throw std::runtime_error("Number of rows available is smaller than number of provided rows");}
- MPI_Recv (&tr.dm.Phi.data()[tr.phi_row], tr.rows_available, tr.rowvecs, tr.worker, tr.tag, MPI_COMM_WORLD, &status);
- MPI_Recv (&tr.dm.T.data()[tr.phi_row], tr.rows_available, MPI_DOUBLE, tr.worker, tr.tag, MPI_COMM_WORLD, &status);
- MPI_Recv (&tr.dm.Tlabels.data()[tr.phi_row], tr.rows_available, MPI_DOUBLE, tr.worker, tr.tag, MPI_COMM_WORLD, &status);
+ MPI_Recv (&tr.dm.Phi.data()[tr.phi_row], tr.rows_available, tr.rowvecs, tr.worker, tr.tag, MPI_COMM_WORLD, &tr.status);
+ MPI_Recv (&tr.dm.T.data()[tr.phi_row], tr.rows_available, MPI_DOUBLE, tr.worker, tr.tag, MPI_COMM_WORLD, &tr.status);
+ MPI_Recv (&tr.dm.Tlabels.data()[tr.phi_row], tr.rows_available, MPI_DOUBLE, tr.worker, tr.tag, MPI_COMM_WORLD, &tr.status);
tr.rows_available -= rows_accepted;
tr.phi_row += rows_accepted;
}
@@ -350,12 +346,12 @@ class MPI_Trainer_WORKER {
int fn_length; // length of the filename char array
int first; // index of the first structure to read from the file
int nstruc; // number of structures to be processed
- MPI_Get_count(&status, MPI_CHAR, &fn_length);
+ MPI_Get_count(&tr.status, MPI_CHAR, &fn_length);
char *fn = (char *) malloc(fn_length+1);
- MPI_Recv (fn, fn_length, MPI_CHAR, 0, TadahCLI::WORK_TAG, MPI_COMM_WORLD, &status);
- MPI_Recv (&first, 1, MPI_INT, 0, TadahCLI::WORK_TAG, MPI_COMM_WORLD, &status);
- MPI_Recv (&nstruc, 1, MPI_INT, 0, TadahCLI::WORK_TAG, MPI_COMM_WORLD, &status);
+ MPI_Recv (fn, fn_length, MPI_CHAR, 0, TadahCLI::WORK_TAG, MPI_COMM_WORLD, &tr.status);
+ MPI_Recv (&first, 1, MPI_INT, 0, TadahCLI::WORK_TAG, MPI_COMM_WORLD, &tr.status);
+ MPI_Recv (&nstruc, 1, MPI_INT, 0, TadahCLI::WORK_TAG, MPI_COMM_WORLD, &tr.status);
// do work
StructureDB stdb;
@@ -407,9 +403,9 @@ class MPI_Trainer_WORKER {
int rows_accepted; // number of accepted rows
int dest; // receiving process
// host returns which dest can accept and how much
- MPI_Recv (&dest, 1, MPI_INT, 0, TadahCLI::DATA_TAG, MPI_COMM_WORLD, &status);
+ MPI_Recv (&dest, 1, MPI_INT, 0, TadahCLI::DATA_TAG, MPI_COMM_WORLD, &tr.status);
- MPI_Recv (&rows_accepted, 1, MPI_INT, 0, TadahCLI::DATA_TAG, MPI_COMM_WORLD, &status);
+ MPI_Recv (&rows_accepted, 1, MPI_INT, 0, TadahCLI::DATA_TAG, MPI_COMM_WORLD, &tr.status);
// we send data to the host or a willing worker
int start=temp_dm.Phi.rows()-rows_needed;
......
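
For orientation, the sketch below shows how the refactored constructors might be called after this commit. It is a minimal, hypothetical caller-side example, not code from the repository: the main() scaffolding, the MPI_Trainer construction, and the commented driver calls are assumptions. Only the MPI_Trainer_HOST(tr, rank, ncpu) and MPI_Trainer_WORKER(tr, rank, ncpu) signatures, and the fact that the MPI_Status is now reached through the shared trainer as tr.status, come from the diff above.

#include <mpi.h>

// Hypothetical usage sketch (not part of this commit): after the change the
// MPI_Status lives inside MPI_Trainer, so the host/worker wrappers are built
// from the trainer, rank and communicator size only.
int main(int argc, char **argv) {
  MPI_Init(&argc, &argv);

  int rank, ncpu;
  MPI_Comm_rank(MPI_COMM_WORLD, &rank);
  MPI_Comm_size(MPI_COMM_WORLD, &ncpu);

  MPI_Trainer tr /* (configuration arguments, assumed) */;

  if (rank == 0) {
    MPI_Trainer_HOST host(tr, rank, ncpu);      // no MPI_Status& parameter any more
    // host drives the WORK/DATA/RELEASE tag loop (assumed driver call)
  } else {
    MPI_Trainer_WORKER worker(tr, rank, ncpu);  // reads the status via tr.status internally
    // worker answers host requests until it receives RELEASE_TAG (assumed)
  }

  MPI_Finalize();
  return 0;
}

Since both wrappers already hold a reference to the shared MPI_Trainer, keeping the single MPI_Status there lets the extra constructor parameter and the duplicated local status members be dropped.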