Commit af3b3450 authored by Marcin Kirsz

Update

parent 081eeae4
Pipeline #43295 passed
Pipeline: Tadah.MLIP #43296

@@ -193,15 +193,13 @@ class MPI_Trainer: public Trainer {
 class MPI_Trainer_HOST {
 private:
   MPI_Trainer &tr;
-  MPI_Status &status;
   int &rank;
   int &ncpu;
   std::vector<std::tuple<std::string,int,int>> wpckgs;
 public:
-  MPI_Trainer_HOST(MPI_Trainer &MPIT, MPI_Status &status, int &rank, int &ncpu):
+  MPI_Trainer_HOST(MPI_Trainer &MPIT, int &rank, int &ncpu):
     tr(MPIT),
-    status(status),
     rank(rank),
     ncpu(ncpu)
   {}
@@ -233,7 +231,7 @@ class MPI_Trainer_HOST {
   void work_tag() {
     int rows_available;
-    MPI_Recv (&rows_available, 1, MPI_INT, tr.worker, tr.tag, MPI_COMM_WORLD, &status);
+    MPI_Recv (&rows_available, 1, MPI_INT, tr.worker, tr.tag, MPI_COMM_WORLD, &tr.status);
     std::tuple<std::string,int,int> wpckg = wpckgs.back();
     wpckgs.pop_back();
@@ -255,14 +253,14 @@ class MPI_Trainer_HOST {
   void data_tag(int &count) {
     int rows_needed;
-    MPI_Recv (&rows_needed, 1, MPI_INT, tr.worker, tr.tag, MPI_COMM_WORLD, &status);
+    MPI_Recv (&rows_needed, 1, MPI_INT, tr.worker, tr.tag, MPI_COMM_WORLD, &tr.status);
     if (tr.rows_available>0) {
       int rows_accepted = tr.rows_available < rows_needed ? tr.rows_available : rows_needed;
       MPI_Send (&tr.b_rank, 1, MPI_INT, tr.worker, tr.tag, MPI_COMM_WORLD);
       MPI_Send (&rows_accepted, 1, MPI_INT, tr.worker, tr.tag, MPI_COMM_WORLD);
-      MPI_Recv (&tr.dm.Phi.data()[tr.phi_row], rows_accepted, tr.rowvecs, tr.worker, tr.tag, MPI_COMM_WORLD, &status);
-      MPI_Recv (&tr.dm.T.data()[tr.phi_row], rows_accepted, MPI_DOUBLE, tr.worker, tr.tag, MPI_COMM_WORLD, &status);
-      MPI_Recv (&tr.dm.Tlabels.data()[tr.phi_row], rows_accepted, MPI_DOUBLE, tr.worker, tr.tag, MPI_COMM_WORLD, &status);
+      MPI_Recv (&tr.dm.Phi.data()[tr.phi_row], rows_accepted, tr.rowvecs, tr.worker, tr.tag, MPI_COMM_WORLD, &tr.status);
+      MPI_Recv (&tr.dm.T.data()[tr.phi_row], rows_accepted, MPI_DOUBLE, tr.worker, tr.tag, MPI_COMM_WORLD, &tr.status);
+      MPI_Recv (&tr.dm.Tlabels.data()[tr.phi_row], rows_accepted, MPI_DOUBLE, tr.worker, tr.tag, MPI_COMM_WORLD, &tr.status);
       tr.rows_available -= rows_accepted;
       tr.phi_row += rows_accepted;
       if (tr.rows_available<0 ) { throw std::runtime_error(" HOST2: The number of rows in the local array is smaller than requested.");}
@@ -295,7 +293,7 @@ class MPI_Trainer_HOST {
   }
   void release_tag(int &count) {
     int rows_available;
-    MPI_Recv (&rows_available, 1, MPI_INT, tr.worker, TadahCLI::WORK_TAG, MPI_COMM_WORLD, &status);
+    MPI_Recv (&rows_available, 1, MPI_INT, tr.worker, TadahCLI::WORK_TAG, MPI_COMM_WORLD, &tr.status);
     // there is no more work so release a worker if full
     if (rows_available==0) {
       MPI_Send (0, 0, MPI_INT, tr.worker, TadahCLI::RELEASE_TAG, MPI_COMM_WORLD);
@@ -310,37 +308,35 @@ class MPI_Trainer_HOST {
 class MPI_Trainer_WORKER {
 private:
   MPI_Trainer &tr;
-  MPI_Status &status;
   int &rank;
   int &ncpu;
 public:
-  MPI_Trainer_WORKER(MPI_Trainer &MPIT, MPI_Status &status, int &rank, int &ncpu):
+  MPI_Trainer_WORKER(MPI_Trainer &MPIT, int &rank, int &ncpu):
    tr(MPIT),
-   status(status),
    rank(rank),
    ncpu(ncpu)
   {}
   bool release_tag() {
     int temp;
-    MPI_Recv (&temp, 1, MPI_INT, tr.worker, tr.tag, MPI_COMM_WORLD, &status);
+    MPI_Recv (&temp, 1, MPI_INT, tr.worker, tr.tag, MPI_COMM_WORLD, &tr.status);
     if (tr.rows_available!=0) { throw std::runtime_error("Attempting to release a worker... but the worker requires more data!!");}
     return true;
   }
   void wait_tag() {
     int temp;
-    MPI_Recv (&temp, 1, MPI_INT, tr.worker, tr.tag, MPI_COMM_WORLD, &status);
+    MPI_Recv (&temp, 1, MPI_INT, tr.worker, tr.tag, MPI_COMM_WORLD, &tr.status);
   }
   void data_tag() {
     // other worker is giving me some data
     int arr_size;
-    MPI_Get_count(&status, MPI_DOUBLE, &arr_size);
+    MPI_Get_count(&tr.status, MPI_DOUBLE, &arr_size);
     int rows_accepted = arr_size/tr.phi_cols1;
     if (tr.rows_available<rows_accepted) { throw std::runtime_error("Number of rows available is smaller than number of provided rows");}
-    MPI_Recv (&tr.dm.Phi.data()[tr.phi_row], tr.rows_available, tr.rowvecs, tr.worker, tr.tag, MPI_COMM_WORLD, &status);
-    MPI_Recv (&tr.dm.T.data()[tr.phi_row], tr.rows_available, MPI_DOUBLE, tr.worker, tr.tag, MPI_COMM_WORLD, &status);
-    MPI_Recv (&tr.dm.Tlabels.data()[tr.phi_row], tr.rows_available, MPI_DOUBLE, tr.worker, tr.tag, MPI_COMM_WORLD, &status);
+    MPI_Recv (&tr.dm.Phi.data()[tr.phi_row], tr.rows_available, tr.rowvecs, tr.worker, tr.tag, MPI_COMM_WORLD, &tr.status);
+    MPI_Recv (&tr.dm.T.data()[tr.phi_row], tr.rows_available, MPI_DOUBLE, tr.worker, tr.tag, MPI_COMM_WORLD, &tr.status);
+    MPI_Recv (&tr.dm.Tlabels.data()[tr.phi_row], tr.rows_available, MPI_DOUBLE, tr.worker, tr.tag, MPI_COMM_WORLD, &tr.status);
     tr.rows_available -= rows_accepted;
     tr.phi_row += rows_accepted;
   }
@@ -350,12 +346,12 @@ class MPI_Trainer_WORKER {
     int fn_length; // length of the filename char array
     int first; // index of the first structure to read from the file
     int nstruc; // number of structures to be processed
-    MPI_Get_count(&status, MPI_CHAR, &fn_length);
+    MPI_Get_count(&tr.status, MPI_CHAR, &fn_length);
     char *fn = (char *) malloc(fn_length+1);
-    MPI_Recv (fn, fn_length, MPI_CHAR, 0, TadahCLI::WORK_TAG, MPI_COMM_WORLD, &status);
-    MPI_Recv (&first, 1, MPI_INT, 0, TadahCLI::WORK_TAG, MPI_COMM_WORLD, &status);
-    MPI_Recv (&nstruc, 1, MPI_INT, 0, TadahCLI::WORK_TAG, MPI_COMM_WORLD, &status);
+    MPI_Recv (fn, fn_length, MPI_CHAR, 0, TadahCLI::WORK_TAG, MPI_COMM_WORLD, &tr.status);
+    MPI_Recv (&first, 1, MPI_INT, 0, TadahCLI::WORK_TAG, MPI_COMM_WORLD, &tr.status);
+    MPI_Recv (&nstruc, 1, MPI_INT, 0, TadahCLI::WORK_TAG, MPI_COMM_WORLD, &tr.status);
     // do work
     StructureDB stdb;
@@ -407,9 +403,9 @@ class MPI_Trainer_WORKER {
     int rows_accepted; // number of accepted rows
     int dest; // receiving process
     // host returns which dest can accept and how much
-    MPI_Recv (&dest, 1, MPI_INT, 0, TadahCLI::DATA_TAG, MPI_COMM_WORLD, &status);
-    MPI_Recv (&rows_accepted, 1, MPI_INT, 0, TadahCLI::DATA_TAG, MPI_COMM_WORLD, &status);
+    MPI_Recv (&dest, 1, MPI_INT, 0, TadahCLI::DATA_TAG, MPI_COMM_WORLD, &tr.status);
+    MPI_Recv (&rows_accepted, 1, MPI_INT, 0, TadahCLI::DATA_TAG, MPI_COMM_WORLD, &tr.status);
     // we send data to the host or a willing worker
     int start=temp_dm.Phi.rows()-rows_needed;
...
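In short, this commit removes the separate MPI_Status& member (and constructor parameter) from MPI_Trainer_HOST and MPI_Trainer_WORKER and has both helpers use the status object stored on MPI_Trainer, i.e. tr.status. The following minimal sketch illustrates the resulting pattern; the types and the helper name recv_rows_available are simplified stand-ins for illustration only, not the real Tadah.MLIP classes.

// Sketch only: MPI_Status lives on the trainer object and the helpers reach
// it through tr.status instead of holding their own MPI_Status reference.
#include <mpi.h>
#include <cstdio>

struct MPI_Trainer {          // simplified stand-in for the real trainer
  MPI_Status status;          // shared status, formerly passed to each helper
  int worker = 0;
  int tag = 0;
};

class MPI_Trainer_HOST {
  MPI_Trainer &tr;
  int &rank;
  int &ncpu;
public:
  // New constructor shape: no MPI_Status& parameter.
  MPI_Trainer_HOST(MPI_Trainer &MPIT, int &rank, int &ncpu):
    tr(MPIT), rank(rank), ncpu(ncpu) {}

  // Receives now record into the trainer-owned status object.
  void recv_rows_available(int &rows_available) {
    MPI_Recv(&rows_available, 1, MPI_INT, tr.worker, tr.tag,
             MPI_COMM_WORLD, &tr.status);
  }
};

int main(int argc, char **argv) {
  MPI_Init(&argc, &argv);
  int rank, ncpu;
  MPI_Comm_rank(MPI_COMM_WORLD, &rank);
  MPI_Comm_size(MPI_COMM_WORLD, &ncpu);

  MPI_Trainer tr;
  MPI_Trainer_HOST host(tr, rank, ncpu);   // no MPI_Status threaded through
  std::printf("rank %d of %d\n", rank, ncpu);

  MPI_Finalize();
  return 0;
}

A likely side benefit, suggested by the MPI_Get_count(&tr.status, ...) calls in the diff, is that a status filled in elsewhere on the trainer (for example by a probe) is now directly visible to the helpers without passing an extra reference around.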