Skip to content
Snippets Groups Projects
Commit fe48ac41 authored by Marcin Kirsz's avatar Marcin Kirsz
Browse files

Update

parent 50847c10
No related branches found
No related tags found
No related merge requests found
Pipeline #43194 passed
Pipeline: Tadah.MLIP

#43196

    ......@@ -179,107 +179,6 @@ class MPI_Trainer: public Trainer {
    dm.Tlabels.resize(phi_rows1);
    }
    int calc_desc(MPI_Status &status, int &rows_available, size_t &phi_row) {
    // get work package
    int fn_length; // length of the filename char array
    int first; // index of the first structure to read from the file
    int nstruc; // number of structures to be processed
    MPI_Get_count(&status, MPI_CHAR, &fn_length);
    char *fn = (char *) malloc(fn_length+1);
    MPI_Recv (fn, fn_length, MPI_CHAR, 0, TadahCLI::WORK_TAG, MPI_COMM_WORLD, &status);
    MPI_Recv (&first, 1, MPI_INT, 0, TadahCLI::WORK_TAG, MPI_COMM_WORLD, &status);
    MPI_Recv (&nstruc, 1, MPI_INT, 0, TadahCLI::WORK_TAG, MPI_COMM_WORLD, &status);
    // do work
    StructureDB stdb;
    stdb.add(std::string(fn,fn_length),first,nstruc);
    nnf.calc(stdb);
    // compute number of rows needed for a given StructureDB
    int rows_needed = 0;
    for (size_t s=0; s<stdb.size(); ++s) {
    int natoms = stdb(s).natoms();
    rows_needed += DesignMatrixBase::phi_rows_num(config, 1, natoms);
    }
    if (rows_available<rows_needed) {
    // we do not have enough rows in the local phi matrix
    // so we create temp DM of required size
    DesignMatrix<DM_Function_Base&> temp_dm(*fb, config);
    temp_dm.Phi.resize(rows_needed,dm.Phi.cols());
    temp_dm.T.resize(rows_needed);
    temp_dm.Tlabels.resize(rows_needed);
    // and compute all rows
    size_t temp_phi_row=0;
    temp_dm.fill_T(stdb);
    for (size_t s=0; s<stdb.size(); ++s) {
    StDescriptors st_d = dc.calc(stdb(s));
    temp_dm.build(temp_phi_row,stdb(s),st_d); // phi_row++
    }
    // first we try to fill remaining rows in the local phi matrix
    // copy top of temp_dm.Phi to the bottom of dm. Phi in reverse order
    if (rows_available>0) {
    for (rows_available; rows_available>0; rows_available--) {
    for (int c=0; c<dm.Phi.cols(); c++) {
    dm.Phi(phi_row,c) = temp_dm.Phi(rows_available-1,c);
    dm.T(phi_row) = temp_dm.T(rows_available-1);
    dm.Tlabels(phi_row) = temp_dm.Tlabels(rows_available-1);
    }
    phi_row++;
    rows_needed--;
    }
    }
    // there are no more available rows
    // send remaining data to available processes
    while (rows_needed > 0) {
    // request host
    MPI_Send (&rows_needed, 1, MPI_INT, 0, TadahCLI::DATA_TAG, MPI_COMM_WORLD);
    int rows_accepted; // number of accepted rows
    int dest; // receiving process
    // host returns which dest can accept and how much
    MPI_Recv (&dest, 1, MPI_INT, 0, TadahCLI::DATA_TAG, MPI_COMM_WORLD, &status);
    MPI_Recv (&rows_accepted, 1, MPI_INT, 0, TadahCLI::DATA_TAG, MPI_COMM_WORLD, &status);
    // we send data to the host or a willing worker
    int start=temp_dm.Phi.rows()-rows_needed;
    // Define temp data type for temp Phi matrix
    // Phi is stored in a column-major order
    MPI_Datatype trowvec, trowvecs;
    MPI_Type_vector( temp_dm.Phi.cols(), 1, temp_dm.Phi.rows(), MPI_DOUBLE, &trowvec);
    MPI_Type_commit(&trowvec);
    MPI_Type_create_resized(trowvec, 0, 1*sizeof(double), &trowvecs);
    MPI_Type_commit(&trowvecs);
    // ready to send
    MPI_Send (&temp_dm.Phi.data()[start], rows_accepted, trowvecs, dest, TadahCLI::DATA_TAG, MPI_COMM_WORLD);
    MPI_Send (&temp_dm.T.data()[start], rows_accepted, MPI_DOUBLE, dest, TadahCLI::DATA_TAG, MPI_COMM_WORLD);
    MPI_Send (&temp_dm.Tlabels.data()[start], rows_accepted, MPI_DOUBLE, dest, TadahCLI::DATA_TAG, MPI_COMM_WORLD);
    rows_needed -= rows_accepted;
    MPI_Type_free(&trowvec);
    MPI_Type_free(&trowvecs);
    }
    }
    else {
    // just fill local phi array as it is large enough
    // fill_T must be called before phi_row is incremented
    dm.fill_T(stdb,phi_row); // phi_row is not incremented by this method
    for (size_t s=0; s<stdb.size(); ++s) {
    StDescriptors st_d = dc.calc(stdb(s));
    dm.build(phi_row,stdb(s),st_d); // build() increments phi_row++
    }
    rows_available-=rows_needed;
    }
    if (fn)
    delete fn;
    return 0;
    }
    };
    class MPI_Trainer_HOST {
    ......@@ -436,5 +335,106 @@ class MPI_Trainer_WORKER {
    tr.rows_available -= rows_accepted;
    tr.phi_row += rows_accepted;
    }
    int work_tag(int &worker, int &tag) {
    // get work package
    int fn_length; // length of the filename char array
    int first; // index of the first structure to read from the file
    int nstruc; // number of structures to be processed
    MPI_Get_count(&status, MPI_CHAR, &fn_length);
    char *fn = (char *) malloc(fn_length+1);
    MPI_Recv (fn, fn_length, MPI_CHAR, 0, TadahCLI::WORK_TAG, MPI_COMM_WORLD, &status);
    MPI_Recv (&first, 1, MPI_INT, 0, TadahCLI::WORK_TAG, MPI_COMM_WORLD, &status);
    MPI_Recv (&nstruc, 1, MPI_INT, 0, TadahCLI::WORK_TAG, MPI_COMM_WORLD, &status);
    // do work
    StructureDB stdb;
    stdb.add(std::string(fn,fn_length),first,nstruc);
    nnf.calc(stdb);
    // compute number of rows needed for a given StructureDB
    int rows_needed = 0;
    for (size_t s=0; s<stdb.size(); ++s) {
    int natoms = stdb(s).natoms();
    rows_needed += DesignMatrixBase::phi_rows_num(config, 1, natoms);
    }
    if (rows_available<rows_needed) {
    // we do not have enough rows in the local phi matrix
    // so we create temp DM of required size
    DesignMatrix<DM_Function_Base&> temp_dm(*fb, config);
    temp_dm.Phi.resize(rows_needed,dm.Phi.cols());
    temp_dm.T.resize(rows_needed);
    temp_dm.Tlabels.resize(rows_needed);
    // and compute all rows
    size_t temp_phi_row=0;
    temp_dm.fill_T(stdb);
    for (size_t s=0; s<stdb.size(); ++s) {
    StDescriptors st_d = dc.calc(stdb(s));
    temp_dm.build(temp_phi_row,stdb(s),st_d); // phi_row++
    }
    // first we try to fill remaining rows in the local phi matrix
    // copy top of temp_dm.Phi to the bottom of dm. Phi in reverse order
    if (rows_available>0) {
    for (rows_available; rows_available>0; rows_available--) {
    for (int c=0; c<dm.Phi.cols(); c++) {
    dm.Phi(phi_row,c) = temp_dm.Phi(rows_available-1,c);
    dm.T(phi_row) = temp_dm.T(rows_available-1);
    dm.Tlabels(phi_row) = temp_dm.Tlabels(rows_available-1);
    }
    phi_row++;
    rows_needed--;
    }
    }
    // there are no more available rows
    // send remaining data to available processes
    while (rows_needed > 0) {
    // request host
    MPI_Send (&rows_needed, 1, MPI_INT, 0, TadahCLI::DATA_TAG, MPI_COMM_WORLD);
    int rows_accepted; // number of accepted rows
    int dest; // receiving process
    // host returns which dest can accept and how much
    MPI_Recv (&dest, 1, MPI_INT, 0, TadahCLI::DATA_TAG, MPI_COMM_WORLD, &status);
    MPI_Recv (&rows_accepted, 1, MPI_INT, 0, TadahCLI::DATA_TAG, MPI_COMM_WORLD, &status);
    // we send data to the host or a willing worker
    int start=temp_dm.Phi.rows()-rows_needed;
    // Define temp data type for temp Phi matrix
    // Phi is stored in a column-major order
    MPI_Datatype trowvec, trowvecs;
    MPI_Type_vector( temp_dm.Phi.cols(), 1, temp_dm.Phi.rows(), MPI_DOUBLE, &trowvec);
    MPI_Type_commit(&trowvec);
    MPI_Type_create_resized(trowvec, 0, 1*sizeof(double), &trowvecs);
    MPI_Type_commit(&trowvecs);
    // ready to send
    MPI_Send (&temp_dm.Phi.data()[start], rows_accepted, trowvecs, dest, TadahCLI::DATA_TAG, MPI_COMM_WORLD);
    MPI_Send (&temp_dm.T.data()[start], rows_accepted, MPI_DOUBLE, dest, TadahCLI::DATA_TAG, MPI_COMM_WORLD);
    MPI_Send (&temp_dm.Tlabels.data()[start], rows_accepted, MPI_DOUBLE, dest, TadahCLI::DATA_TAG, MPI_COMM_WORLD);
    rows_needed -= rows_accepted;
    MPI_Type_free(&trowvec);
    MPI_Type_free(&trowvecs);
    }
    }
    else {
    // just fill local phi array as it is large enough
    // fill_T must be called before phi_row is incremented
    dm.fill_T(stdb,phi_row); // phi_row is not incremented by this method
    for (size_t s=0; s<stdb.size(); ++s) {
    StDescriptors st_d = dc.calc(stdb(s));
    dm.build(phi_row,stdb(s),st_d); // build() increments phi_row++
    }
    rows_available-=rows_needed;
    }
    if (fn)
    delete fn;
    return 0;
    }
    };
    #endif
    0% Loading or .
    You are about to add 0 people to the discussion. Proceed with caution.
    Finish editing this message first!
    Please register or to comment