Skip to content
Snippets Groups Projects
Commit ba9d3077 authored by Marcin Kirsz's avatar Marcin Kirsz
Browse files

moved tags to MPI trainer

parent d7f054e2
No related branches found
No related tags found
No related merge requests found
Pipeline #43395 passed
Pipeline: Tadah.MLIP

#43396

@@ -54,7 +54,9 @@ class Trainer {
 };
 #ifdef TADAH_BUILD_MPI
 #include <mpi.h>
 extern "C" void blacs_get_(int*, int*, int*);
 extern "C" void blacs_pinfo_(int*, int*);
 extern "C" void blacs_gridinit_(int*, char*, int*, int*);
@@ -73,8 +75,13 @@ extern "C" void pdgemv_(char* transa, int* m, int* n, double* alpha, double* a,
     int* ia, int* ja, int* desc_a, double* x, int* ix, int* jx, int* desc_x,
     int* incx, double* beta, double* y, int* iy, int* jy, int* desc_y, int* incy);
 class MPI_Trainer: public Trainer {
   public:
+    const static int WAIT_TAG = 3;
+    const static int RELEASE_TAG = 2;
+    const static int DATA_TAG = 1;
+    const static int WORK_TAG = 0;
     MPI_Status status;
     int worker;
     int tag;
@@ -377,7 +384,7 @@ class MPI_Trainer_HOST {
     int w_rows_available;
     while (true) {
       MPI_Recv (&w_rows_available, 1, MPI_INT, MPI_ANY_SOURCE,
-          TadahCLI::WORK_TAG, MPI_COMM_WORLD, &status2);
+          WORK_TAG, MPI_COMM_WORLD, &status2);
       worker2 = status2.MPI_SOURCE;
       if (tr.worker==worker2) {
         throw std::runtime_error("worker and worker2 are the same.");
@@ -386,25 +393,25 @@ class MPI_Trainer_HOST {
       }
       int rows_accepted = w_rows_available < rows_needed ?
           w_rows_available : rows_needed;
-      MPI_Send (&worker2, 1, MPI_INT, tr.worker, TadahCLI::DATA_TAG,
+      MPI_Send (&worker2, 1, MPI_INT, tr.worker, DATA_TAG,
           MPI_COMM_WORLD);
       MPI_Send (&rows_accepted, 1, MPI_INT, tr.worker,
-          TadahCLI::DATA_TAG, MPI_COMM_WORLD);
+          DATA_TAG, MPI_COMM_WORLD);
     }
   }
   void release_tag(int &count) {
     int rows_available;
     MPI_Recv (&rows_available, 1, MPI_INT, tr.worker,
-        TadahCLI::WORK_TAG, MPI_COMM_WORLD, &tr.status);
+        WORK_TAG, MPI_COMM_WORLD, &tr.status);
     // there is no more work so release a worker if full
     if (rows_available==0) {
       MPI_Send (0, 0, MPI_INT, tr.worker,
-          TadahCLI::RELEASE_TAG, MPI_COMM_WORLD);
+          RELEASE_TAG, MPI_COMM_WORLD);
       count++;
     }
     else {
       MPI_Send (0, 0, MPI_INT, tr.worker,
-          TadahCLI::WAIT_TAG, MPI_COMM_WORLD);
+          WAIT_TAG, MPI_COMM_WORLD);
     }
   }
@@ -462,11 +469,11 @@ class MPI_Trainer_WORKER {
       MPI_Get_count(&tr.status, MPI_CHAR, &fn_length);
       char *fn = (char *) malloc(fn_length+1);
-      MPI_Recv (fn, fn_length, MPI_CHAR, 0, TadahCLI::WORK_TAG,
+      MPI_Recv (fn, fn_length, MPI_CHAR, 0, WORK_TAG,
           MPI_COMM_WORLD, &tr.status);
-      MPI_Recv (&first, 1, MPI_INT, 0, TadahCLI::WORK_TAG,
+      MPI_Recv (&first, 1, MPI_INT, 0, WORK_TAG,
           MPI_COMM_WORLD, &tr.status);
-      MPI_Recv (&nstruc, 1, MPI_INT, 0, TadahCLI::WORK_TAG,
+      MPI_Recv (&nstruc, 1, MPI_INT, 0, WORK_TAG,
           MPI_COMM_WORLD, &tr.status);
       // do work
@@ -515,14 +522,14 @@ class MPI_Trainer_WORKER {
       // send remaining data to available processes
       while (rows_needed > 0) {
         // request host
-        MPI_Send (&rows_needed, 1, MPI_INT, 0, TadahCLI::DATA_TAG, MPI_COMM_WORLD);
+        MPI_Send (&rows_needed, 1, MPI_INT, 0, DATA_TAG, MPI_COMM_WORLD);
         int rows_accepted; // number of accepted rows
         int dest; // receiving process
         // host returns which dest can accept and how much
-        MPI_Recv (&dest, 1, MPI_INT, 0, TadahCLI::DATA_TAG,
+        MPI_Recv (&dest, 1, MPI_INT, 0, DATA_TAG,
            MPI_COMM_WORLD, &tr.status);
-        MPI_Recv (&rows_accepted, 1, MPI_INT, 0, TadahCLI::DATA_TAG,
+        MPI_Recv (&rows_accepted, 1, MPI_INT, 0, DATA_TAG,
            MPI_COMM_WORLD, &tr.status);
         // we send data to the host or a willing worker
         int start=temp_dm.Phi.rows()-rows_needed;
@@ -538,11 +545,11 @@ class MPI_Trainer_WORKER {
         // ready to send
         MPI_Send (&temp_dm.Phi.data()[start], rows_accepted,
-            trowvecs, dest, TadahCLI::DATA_TAG, MPI_COMM_WORLD);
+            trowvecs, dest, DATA_TAG, MPI_COMM_WORLD);
         MPI_Send (&temp_dm.T.data()[start], rows_accepted,
-            MPI_DOUBLE, dest, TadahCLI::DATA_TAG, MPI_COMM_WORLD);
+            MPI_DOUBLE, dest, DATA_TAG, MPI_COMM_WORLD);
         MPI_Send (&temp_dm.Tlabels.data()[start], rows_accepted,
-            MPI_DOUBLE, dest, TadahCLI::DATA_TAG, MPI_COMM_WORLD);
+            MPI_DOUBLE, dest, DATA_TAG, MPI_COMM_WORLD);
         rows_needed -= rows_accepted;
         MPI_Type_free(&trowvec);
         MPI_Type_free(&trowvecs);
...
    Loading…
    You are about to add 0 people to the discussion. Proceed with caution.
    Finish editing this message first!
    Please register or sign in to comment