Commit df414e25 authored by Marcin Kirsz

Update

parent 7e03c320
Pipeline #43057 (Tadah.MLIP) passed

@@ -29,9 +29,9 @@ class Trainer {
        *DCS.c2b,*DCS.c3b,*DCS.cmb),
    nnf(config),
    fb(CONFIG::factory<DM_Function_Base,Config&>(
        config.get<std::string>("MODEL",1),config)),
    model(CONFIG::factory<M_Tadah_Base,DM_Function_Base&,Config&>
        (config.get<std::string>("MODEL",0),*fb,config)),
    dm(*fb, config)
  {
    config.postprocess();
@@ -39,8 +39,8 @@ class Trainer {
  }
  void train(StructureDB &stdb) {
    nnf.calc(stdb);
    model->train(stdb,dc);
  }
  Config get_param_file() {
@@ -280,5 +280,147 @@ class MPI_Trainer_HOST {
  MPI_Trainer_HOST(MPI_Trainer &MPIT):
    MPIT(MPIT)
  {}
  void run() {
    // status, wpckgs, rows_available, phi_row, b_rank, rowvecs, tr and ncpu
    // are defined elsewhere in the class; the diff does not show them.
    while (true) {
      if (wpckgs.size()==0) {
        // No more work packages: skip the remaining workers. Their data is
        // collected and the workers are released outside of this loop.
        break;
      }
      // Probe a request from any worker, with any tag.
      MPI_Probe(MPI_ANY_SOURCE, MPI_ANY_TAG, MPI_COMM_WORLD, &status);
      int worker = status.MPI_SOURCE;
      int tag = status.MPI_TAG;
      if (tag==WORK_TAG) {
        // The worker reports its spare capacity along with the work request;
        // the value is not needed here but the message must be received.
        int worker_rows_available;
        MPI_Recv(&worker_rows_available, 1, MPI_INT, worker, tag, MPI_COMM_WORLD, &status);
        std::tuple<std::string,int,int> wpckg = wpckgs.back();
        wpckgs.pop_back();
        // Send the dataset filename.
        const char *fn = std::get<0>(wpckg).c_str();
        int fn_length = std::get<0>(wpckg).length()+1;  // +1 for the terminating null character
        MPI_Send(fn, fn_length, MPI_CHAR, worker, tag, MPI_COMM_WORLD);
        // Send the index of the first structure to load.
        int first = std::get<1>(wpckg);
        MPI_Send(&first, 1, MPI_INT, worker, tag, MPI_COMM_WORLD);
        // Send the number of structures to load.
        int nstruc = std::get<2>(wpckg);
        MPI_Send(&nstruc, 1, MPI_INT, worker, tag, MPI_COMM_WORLD);
      }
      else if (tag==DATA_TAG) {
        int rows_needed;
        MPI_Recv(&rows_needed, 1, MPI_INT, worker, tag, MPI_COMM_WORLD, &status);
        if (rows_available>0) {
          // The host still has room: accept as many rows as fit and receive
          // them directly into the local Phi, T and Tlabels arrays.
          int rows_accepted = rows_available < rows_needed ? rows_available : rows_needed;
          MPI_Send(&b_rank, 1, MPI_INT, worker, tag, MPI_COMM_WORLD);
          MPI_Send(&rows_accepted, 1, MPI_INT, worker, tag, MPI_COMM_WORLD);
          MPI_Recv(&tr.dm.Phi.data()[phi_row], rows_accepted, rowvecs, worker, tag, MPI_COMM_WORLD, &status);
          MPI_Recv(&tr.dm.T.data()[phi_row], rows_accepted, MPI_DOUBLE, worker, tag, MPI_COMM_WORLD, &status);
          MPI_Recv(&tr.dm.Tlabels.data()[phi_row], rows_accepted, MPI_DOUBLE, worker, tag, MPI_COMM_WORLD, &status);
          rows_available -= rows_accepted;
          phi_row += rows_accepted;
          if (rows_available<0) { throw std::runtime_error("HOST1: The number of rows in the local array is smaller than requested."); }
        }
        else {
          // The host cannot fit the data, so the workers are asked for their
          // storage availability. Locate one willing to accept at least some rows.
          int worker2;
          MPI_Status status2;
          int w_rows_available;
          while (true) {
            MPI_Recv(&w_rows_available, 1, MPI_INT, MPI_ANY_SOURCE, WORK_TAG, MPI_COMM_WORLD, &status2);
            worker2 = status2.MPI_SOURCE;
            if (worker==worker2) { throw std::runtime_error("worker and worker2 are the same."); }
            if (w_rows_available==0) {
              // This worker has no spare capacity either; tell it to wait.
              MPI_Send(&worker2, 1, MPI_INT, worker2, WAIT_TAG, MPI_COMM_WORLD);
            }
            else {
              // Found a worker that can accept data.
              break;
            }
          }
          int rows_accepted = w_rows_available < rows_needed ? w_rows_available : rows_needed;
          // Tell the requesting worker which rank will accept its rows.
          MPI_Send(&worker2, 1, MPI_INT, worker, tag, MPI_COMM_WORLD);
          MPI_Send(&rows_accepted, 1, MPI_INT, worker, tag, MPI_COMM_WORLD);
        }
      }
      else {
        throw std::runtime_error("HOST1: Unexpected request from " + std::to_string(worker));
      }
    }
    // Work is finished: collect the remaining data and release all workers.
    int count=1;  // counts released workers; starts at 1 to account for the host
    while (true) {
      MPI_Probe(MPI_ANY_SOURCE, MPI_ANY_TAG, MPI_COMM_WORLD, &status);
      int worker = status.MPI_SOURCE;
      int tag = status.MPI_TAG;
      if (tag==DATA_TAG) {
        int rows_needed;
        MPI_Recv(&rows_needed, 1, MPI_INT, worker, tag, MPI_COMM_WORLD, &status);
        if (rows_available>0) {
          int rows_accepted = rows_available < rows_needed ? rows_available : rows_needed;
          MPI_Send(&b_rank, 1, MPI_INT, worker, tag, MPI_COMM_WORLD);
          MPI_Send(&rows_accepted, 1, MPI_INT, worker, tag, MPI_COMM_WORLD);
          MPI_Recv(&tr.dm.Phi.data()[phi_row], rows_accepted, rowvecs, worker, tag, MPI_COMM_WORLD, &status);
          MPI_Recv(&tr.dm.T.data()[phi_row], rows_accepted, MPI_DOUBLE, worker, tag, MPI_COMM_WORLD, &status);
          MPI_Recv(&tr.dm.Tlabels.data()[phi_row], rows_accepted, MPI_DOUBLE, worker, tag, MPI_COMM_WORLD, &status);
          rows_available -= rows_accepted;
          phi_row += rows_accepted;
          if (rows_available<0) { throw std::runtime_error("HOST2: The number of rows in the local array is smaller than requested."); }
        }
        else {
          // The host cannot fit the data; find a worker that can accept at least some rows.
          MPI_Status status2;
          int worker2;
          int w_rows_available;
          while (true) {
            MPI_Recv(&w_rows_available, 1, MPI_INT, MPI_ANY_SOURCE, WORK_TAG, MPI_COMM_WORLD, &status2);
            worker2 = status2.MPI_SOURCE;
            if (worker==worker2) { throw std::runtime_error("worker and worker2 are the same."); }
            if (w_rows_available==0) {
              // No spare capacity and no more work to hand out: release this worker.
              MPI_Send(NULL, 0, MPI_INT, worker2, RELEASE_TAG, MPI_COMM_WORLD);
              count++;
            }
            else {
              // Found a worker that can accept data.
              break;
            }
          }
          int rows_accepted = w_rows_available < rows_needed ? w_rows_available : rows_needed;
          MPI_Send(&worker2, 1, MPI_INT, worker, DATA_TAG, MPI_COMM_WORLD);
          MPI_Send(&rows_accepted, 1, MPI_INT, worker, DATA_TAG, MPI_COMM_WORLD);
        }
      }
      else {
        // WORK_TAG: the worker reports how many rows it still holds locally.
        int worker_rows_available;
        MPI_Recv(&worker_rows_available, 1, MPI_INT, worker, WORK_TAG, MPI_COMM_WORLD, &status);
        if (worker_rows_available==0) {
          // Nothing left to transfer and no more work, so release the worker.
          MPI_Send(NULL, 0, MPI_INT, worker, RELEASE_TAG, MPI_COMM_WORLD);
          count++;
        }
        else {
          MPI_Send(NULL, 0, MPI_INT, worker, WAIT_TAG, MPI_COMM_WORLD);
        }
      }
      if (count==ncpu) { break; }  // count started from 1, so this releases all ncpu-1 workers
    }
  }
};
#endif
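
The diff shows only the host side of the exchange. For orientation, below is a minimal sketch of the worker loop this protocol appears to assume: the worker repeatedly reports its spare capacity with WORK_TAG and then reacts to the host's WORK_TAG, WAIT_TAG or RELEASE_TAG reply. The tag values, the assumed host rank, and the load/compute steps are hypothetical; only the message sequence mirrors MPI_Trainer_HOST::run() above.

#include <mpi.h>
#include <vector>

// Hypothetical tag values; the real constants are defined elsewhere in the codebase.
enum { WORK_TAG = 1, DATA_TAG = 2, WAIT_TAG = 3, RELEASE_TAG = 4 };

void worker_loop(int host_rank /* assumed to be 0 */) {
  MPI_Status status;
  while (true) {
    // Report spare local capacity and ask the host for instructions.
    int rows_available = 0;  // hypothetical: spare rows in the local buffer
    MPI_Send(&rows_available, 1, MPI_INT, host_rank, WORK_TAG, MPI_COMM_WORLD);

    // The host answers with WORK_TAG (a package), WAIT_TAG or RELEASE_TAG.
    MPI_Probe(host_rank, MPI_ANY_TAG, MPI_COMM_WORLD, &status);
    if (status.MPI_TAG == RELEASE_TAG) {
      MPI_Recv(NULL, 0, MPI_INT, host_rank, RELEASE_TAG, MPI_COMM_WORLD, &status);
      break;  // all work done, leave the loop
    }
    if (status.MPI_TAG == WAIT_TAG) {
      int dummy;  // the host sends either zero or one int with WAIT_TAG
      MPI_Recv(&dummy, 1, MPI_INT, host_rank, WAIT_TAG, MPI_COMM_WORLD, &status);
      continue;  // ask again later
    }

    // WORK_TAG: receive the three messages sent by MPI_Trainer_HOST::run():
    // dataset filename, index of the first structure, number of structures.
    MPI_Probe(host_rank, WORK_TAG, MPI_COMM_WORLD, &status);
    int fn_length;
    MPI_Get_count(&status, MPI_CHAR, &fn_length);
    std::vector<char> fn(fn_length);
    MPI_Recv(fn.data(), fn_length, MPI_CHAR, host_rank, WORK_TAG, MPI_COMM_WORLD, &status);
    int first, nstruc;
    MPI_Recv(&first, 1, MPI_INT, host_rank, WORK_TAG, MPI_COMM_WORLD, &status);
    MPI_Recv(&nstruc, 1, MPI_INT, host_rank, WORK_TAG, MPI_COMM_WORLD, &status);

    // ... load structures [first, first+nstruc) from fn.data(), compute
    // descriptor rows, then offload them with DATA_TAG requests ...
  }
}

Presumably rank 0 constructs MPI_Trainer_HOST and calls run() while every other rank executes a loop of this shape; that split is an inference from the host/worker naming, not something the diff states.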
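The work packages themselves are consumed via wpckgs.back()/pop_back() and unpacked with std::get<0..2>, which suggests a std::vector of (filename, first index, count) tuples. A small sketch of how such a list could be built, assuming one knows the structure count per dataset file; make_packages, its input format, and the chunk size are illustrative only:

#include <algorithm>
#include <string>
#include <tuple>
#include <utility>
#include <vector>

// Hypothetical builder: split each dataset into fixed-size chunks of structures.
std::vector<std::tuple<std::string,int,int>>
make_packages(const std::vector<std::pair<std::string,int>> &datasets,
              int chunk) {  // chunk: structures per package
  std::vector<std::tuple<std::string,int,int>> wpckgs;
  for (const auto &d : datasets) {
    for (int first = 0; first < d.second; first += chunk) {
      int nstruc = std::min(chunk, d.second - first);
      // Fields match std::get<0..2>(wpckg) in run(): filename, first, count.
      wpckgs.emplace_back(d.first, first, nstruc);
    }
  }
  return wpckgs;
}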