Skip to content
Snippets Groups Projects
Commit 757a3637 authored by Marcin Kirsz's avatar Marcin Kirsz
Browse files

Update

parent 2abf2b31
No related branches found
No related tags found
No related merge requests found
Pipeline #42848 passed
Pipeline: Tadah.MLIP

#42849

    ......@@ -62,4 +62,136 @@ class Trainer {
    // Return the parameter-file Config of the wrapped model
    // (pure delegation to model->get_param_file()).
    Config get_param_file() {
    return model->get_param_file();
    }
    #ifdef TADAH_BUILD_MPI
    int train_worker(int &rows_available, MPI_Status &status, size_t &phi_row,
    int &phi_cols1, DesignMatrix<DM_Function_Base&> &dm, MPI_Datatype &rowvecs,
    int &worker, int &tag) {
    // release a worker
    if (status.MPI_TAG == TadahCLI::RELEASE_TAG) {
    int temp;
    MPI_Recv (&temp, 1, MPI_INT, worker, tag, MPI_COMM_WORLD, &status);
    if (rows_available!=0) { throw std::runtime_error("Attempting to release a worker... but the worker requires more data!!");}
    return 0;
    }
    else if (status.MPI_TAG == TadahCLI::WAIT_TAG) {
    int temp;
    MPI_Recv (&temp, 1, MPI_INT, worker, tag, MPI_COMM_WORLD, &status);
    // do nothing here; ask for more work in the next cycle
    }
    else if (status.MPI_TAG == TadahCLI::DATA_TAG) {
    // other worker is giving me some data
    int arr_size;
    MPI_Get_count(&status, MPI_DOUBLE, &arr_size);
    int rows_accepted = arr_size/phi_cols1;
    if (rows_available<rows_accepted) { throw std::runtime_error("Number of rows available is smaller than number of provided rows");}
    MPI_Recv (&dm.Phi.data()[phi_row], rows_available, rowvecs, worker, tag, MPI_COMM_WORLD, &status);
    MPI_Recv (&dm.T.data()[phi_row], rows_available, MPI_DOUBLE, worker, tag, MPI_COMM_WORLD, &status);
    MPI_Recv (&dm.Tlabels.data()[phi_row], rows_available, MPI_DOUBLE, worker, tag, MPI_COMM_WORLD, &status);
    rows_available -= rows_accepted;
    phi_row += rows_accepted;
    }
    else if (status.MPI_TAG == TadahCLI::WORK_TAG) {
    // otherwise get work package
    int fn_length; // length of the filename char array
    int first; // index of the first structure to read from the file
    int nstruc; // number of structures to be processed
    MPI_Get_count(&status, MPI_CHAR, &fn_length);
    char *fn = (char *) malloc(fn_length+1);
    MPI_Recv (fn, fn_length, MPI_CHAR, 0, TadahCLI::WORK_TAG, MPI_COMM_WORLD, &status);
    MPI_Recv (&first, 1, MPI_INT, 0, TadahCLI::WORK_TAG, MPI_COMM_WORLD, &status);
    MPI_Recv (&nstruc, 1, MPI_INT, 0, TadahCLI::WORK_TAG, MPI_COMM_WORLD, &status);
    // do work
    StructureDB stdb;
    stdb.add(std::string(fn,fn_length),first,nstruc);
    nnf.calc(stdb);
    // compute number of rows needed for a given StructureDB
    int rows_needed = 0;
    for (size_t s=0; s<stdb.size(); ++s) {
    int natoms = stdb(s).natoms();
    rows_needed += DesignMatrixBase::phi_rows_num(config, 1, natoms);
    }
    if (rows_available<rows_needed) {
    // we do not have enough rows in the local phi matrix
    // so we create temp DM of required size
    DesignMatrix<DM_Function_Base&> temp_dm(*fb, config);
    temp_dm.Phi.resize(rows_needed,phi_cols1);
    temp_dm.T.resize(rows_needed);
    temp_dm.Tlabels.resize(rows_needed);
    // and compute all rows
    size_t temp_phi_row=0;
    temp_dm.fill_T(stdb);
    for (size_t s=0; s<stdb.size(); ++s) {
    StDescriptors st_d = dc.calc(stdb(s));
    temp_dm.build(temp_phi_row,stdb(s),st_d); // phi_row++
    }
    // first we try to fill remaining rows in the local phi matrix
    // copy top of temp_dm.Phi to the bottom of dm. Phi in reverse order
    if (rows_available>0) {
    for (rows_available; rows_available>0; rows_available--) {
    for (int c=0; c<phi_cols1; c++) {
    dm.Phi(phi_row,c) = temp_dm.Phi(rows_available-1,c);
    dm.T(phi_row) = temp_dm.T(rows_available-1);
    dm.Tlabels(phi_row) = temp_dm.Tlabels(rows_available-1);
    }
    phi_row++;
    rows_needed--;
    }
    }
    // there are no more available rows
    // send remaining data to available processes
    while (rows_needed > 0) {
    // request host
    MPI_Send (&rows_needed, 1, MPI_INT, 0, TadahCLI::DATA_TAG, MPI_COMM_WORLD);
    int rows_accepted; // number of accepted rows
    int dest; // receiving process
    // host returns which dest can accept and how much
    MPI_Recv (&dest, 1, MPI_INT, 0, TadahCLI::DATA_TAG, MPI_COMM_WORLD, &status);
    MPI_Recv (&rows_accepted, 1, MPI_INT, 0, TadahCLI::DATA_TAG, MPI_COMM_WORLD, &status);
    // we send data to the host or a willing worker
    int start=temp_dm.Phi.rows()-rows_needed;
    // Define temp data type for temp Phi matrix
    // Phi is stored in a column-major order
    MPI_Datatype trowvec, trowvecs;
    MPI_Type_vector( temp_dm.Phi.cols(), 1, temp_dm.Phi.rows(), MPI_DOUBLE, &trowvec);
    MPI_Type_commit(&trowvec);
    MPI_Type_create_resized(trowvec, 0, 1*sizeof(double), &trowvecs);
    MPI_Type_commit(&trowvecs);
    // ready to send
    MPI_Send (&temp_dm.Phi.data()[start], rows_accepted, trowvecs, dest, TadahCLI::DATA_TAG, MPI_COMM_WORLD);
    MPI_Send (&temp_dm.T.data()[start], rows_accepted, MPI_DOUBLE, dest, TadahCLI::DATA_TAG, MPI_COMM_WORLD);
    MPI_Send (&temp_dm.Tlabels.data()[start], rows_accepted, MPI_DOUBLE, dest, TadahCLI::DATA_TAG, MPI_COMM_WORLD);
    rows_needed -= rows_accepted;
    MPI_Type_free(&trowvec);
    MPI_Type_free(&trowvecs);
    }
    }
    else {
    // just fill local phi array as it is large enough
    // fill_T must be called before phi_row is incremented
    dm.fill_T(stdb,phi_row); // phi_row is not incremented by this method
    for (size_t s=0; s<stdb.size(); ++s) {
    StDescriptors st_d = dc.calc(stdb(s));
    dm.build(phi_row,stdb(s),st_d); // build() increments phi_row++
    }
    rows_available-=rows_needed;
    }
    if (fn)
    delete fn;
    }
    return 1;
    }
    #endif
    };
    0% Loading or .
    You are about to add 0 people to the discussion. Proceed with caution.
    Finish editing this message first!
    Please register or sign in to comment