From ba9d3077bf38ce7d3ae863eded3c4c173224a155 Mon Sep 17 00:00:00 2001 From: Marcin Kirsz <mkirsz@ed.ac.uk> Date: Mon, 7 Oct 2024 13:07:23 +0100 Subject: [PATCH] moved tags to MPI trainer --- trainer.h | 37 ++++++++++++++++++++++--------------- 1 file changed, 22 insertions(+), 15 deletions(-) diff --git a/trainer.h b/trainer.h index 1b26782..81a9be1 100644 --- a/trainer.h +++ b/trainer.h @@ -54,7 +54,9 @@ class Trainer { }; #ifdef TADAH_BUILD_MPI + #include <mpi.h> + extern "C" void blacs_get_(int*, int*, int*); extern "C" void blacs_pinfo_(int*, int*); extern "C" void blacs_gridinit_(int*, char*, int*, int*); @@ -73,8 +75,13 @@ extern "C" void pdgemv_(char* transa, int* m, int* n, double* alpha, double* a, int* ia, int* ja, int* desc_a, double* x, int* ix, int* jx, int* desc_x, int* incx, double* beta, double* y, int* iy, int* jy, int* desc_y, int* incy); + class MPI_Trainer: public Trainer { public: +const static int WAIT_TAG = 3; +const static int RELEASE_TAG = 2; +const static int DATA_TAG = 1; +const static int WORK_TAG = 0; MPI_Status status; int worker; int tag; @@ -377,7 +384,7 @@ class MPI_Trainer_HOST { int w_rows_available; while (true) { MPI_Recv (&w_rows_available, 1, MPI_INT, MPI_ANY_SOURCE, - TadahCLI::WORK_TAG, MPI_COMM_WORLD, &status2); + WORK_TAG, MPI_COMM_WORLD, &status2); worker2 = status2.MPI_SOURCE; if (tr.worker==worker2) { throw std::runtime_error("worker and worker2 are the same."); @@ -386,25 +393,25 @@ class MPI_Trainer_HOST { } int rows_accepted = w_rows_available < rows_needed ? w_rows_available : rows_needed; - MPI_Send (&worker2, 1, MPI_INT, tr.worker, TadahCLI::DATA_TAG, + MPI_Send (&worker2, 1, MPI_INT, tr.worker, DATA_TAG, MPI_COMM_WORLD); MPI_Send (&rows_accepted, 1, MPI_INT, tr.worker, - TadahCLI::DATA_TAG, MPI_COMM_WORLD); + DATA_TAG, MPI_COMM_WORLD); } } void release_tag(int &count) { int rows_available; MPI_Recv (&rows_available, 1, MPI_INT, tr.worker, - TadahCLI::WORK_TAG, MPI_COMM_WORLD, &tr.status); + WORK_TAG, MPI_COMM_WORLD, &tr.status); // there is no more work so release a worker if full if (rows_available==0) { MPI_Send (0, 0, MPI_INT, tr.worker, - TadahCLI::RELEASE_TAG, MPI_COMM_WORLD); + RELEASE_TAG, MPI_COMM_WORLD); count++; } else { MPI_Send (0, 0, MPI_INT, tr.worker, - TadahCLI::WAIT_TAG, MPI_COMM_WORLD); + WAIT_TAG, MPI_COMM_WORLD); } } @@ -462,11 +469,11 @@ class MPI_Trainer_WORKER { MPI_Get_count(&tr.status, MPI_CHAR, &fn_length); char *fn = (char *) malloc(fn_length+1); - MPI_Recv (fn, fn_length, MPI_CHAR, 0, TadahCLI::WORK_TAG, + MPI_Recv (fn, fn_length, MPI_CHAR, 0, WORK_TAG, MPI_COMM_WORLD, &tr.status); - MPI_Recv (&first, 1, MPI_INT, 0, TadahCLI::WORK_TAG, + MPI_Recv (&first, 1, MPI_INT, 0, WORK_TAG, MPI_COMM_WORLD, &tr.status); - MPI_Recv (&nstruc, 1, MPI_INT, 0, TadahCLI::WORK_TAG, + MPI_Recv (&nstruc, 1, MPI_INT, 0, WORK_TAG, MPI_COMM_WORLD, &tr.status); // do work @@ -515,14 +522,14 @@ class MPI_Trainer_WORKER { // send remaining data to available processes while (rows_needed > 0) { // request host - MPI_Send (&rows_needed, 1, MPI_INT, 0, TadahCLI::DATA_TAG, MPI_COMM_WORLD); + MPI_Send (&rows_needed, 1, MPI_INT, 0, DATA_TAG, MPI_COMM_WORLD); int rows_accepted; // number of accepted rows int dest; // receiving process // host returns which dest can accept and how much - MPI_Recv (&dest, 1, MPI_INT, 0, TadahCLI::DATA_TAG, + MPI_Recv (&dest, 1, MPI_INT, 0, DATA_TAG, MPI_COMM_WORLD, &tr.status); - MPI_Recv (&rows_accepted, 1, MPI_INT, 0, TadahCLI::DATA_TAG, + MPI_Recv (&rows_accepted, 1, MPI_INT, 0, DATA_TAG, MPI_COMM_WORLD, &tr.status); // we send data to the host or a willing worker int start=temp_dm.Phi.rows()-rows_needed; @@ -538,11 +545,11 @@ class MPI_Trainer_WORKER { // ready to send MPI_Send (&temp_dm.Phi.data()[start], rows_accepted, - trowvecs, dest, TadahCLI::DATA_TAG, MPI_COMM_WORLD); + trowvecs, dest, DATA_TAG, MPI_COMM_WORLD); MPI_Send (&temp_dm.T.data()[start], rows_accepted, - MPI_DOUBLE, dest, TadahCLI::DATA_TAG, MPI_COMM_WORLD); + MPI_DOUBLE, dest, DATA_TAG, MPI_COMM_WORLD); MPI_Send (&temp_dm.Tlabels.data()[start], rows_accepted, - MPI_DOUBLE, dest, TadahCLI::DATA_TAG, MPI_COMM_WORLD); + MPI_DOUBLE, dest, DATA_TAG, MPI_COMM_WORLD); rows_needed -= rows_accepted; MPI_Type_free(&trowvec); MPI_Type_free(&trowvecs); -- GitLab