diff --git a/bin/tadah_cli.cpp b/bin/tadah_cli.cpp
index ad9c2e9f936c38aa33dbb11d37c79b97dd0d5dce..9399254a50f6ec83dba5cb1b6f0736366d8409a3 100644
--- a/bin/tadah_cli.cpp
+++ b/bin/tadah_cli.cpp
@@ -98,7 +98,6 @@ void TadahCLI::subcommand_train() {
     MPI_Trainer_HOST HOST(tr, rank, ncpu);
     HOST.prep_wpckgs();
     while (true) {
-
       if (!HOST.has_packages()) {
         // no more packages, just skip remaining workers
         // we will collect remaining data and release them outside of this loop
@@ -164,82 +163,6 @@ void TadahCLI::subcommand_train() {
 
   // tr.solve();
 
-  // Descriptors for scalaPACK
-  //double *b = dm.T.ptr();
-  //std::cout << "BEFORE: ---b vec: rank: " << b_rank << " ";
-  //for (int i=0;i<phi_rows1;++i) std::cout << b[i] << " ";
-  //std::cout << std::endl;
-  //std::cout << std::endl;
-
-
-
-  // make a copy of dm.Phi and dm2.Phi
-  //Matrix Phi_cpy = dm.Phi;
-  //Matrix Phi2_cpy = dm2.Phi;
-  //t_type T = dm.T;
-  //t_type T2 = dm2.T;
-
-
-  //double wkopt;
-  //int lwork = -1;
-
-  //std::cout << "phi_cols1: " << phi_cols1 << std::endl;
-  //pdgels_(&trans, &PHI_rows, &PHI_cols, &nrhs, dm.Phi.ptr(), &ia, &ja, descPHI, b, &ib, &jb, descB, &wkopt, &lwork, &info);
-  //if (info != 0) {
-  //  printf("Error in pdgels, info = %d\n", info);
-  //}
-
-  //std::cout << "phi_rows1, phi_cols1: " << phi_rows1 << ", " << phi_cols1<< std::endl;
-  //std::cout << "phi_rows2, phi_cols2: " << phi_rows2 << ", " << phi_cols2<< std::endl;
-
-  //lwork = (int)wkopt;
-
-  //double *work = new double[lwork];
-  //pdgels_(&trans, &PHI_rows, &PHI_cols, &nrhs, dm.Phi.ptr(), &ia, &ja, descPHI, b, &ib, &jb, descB, work, &lwork, &info);
-  //if (rank==0) {
-  //  std::cout << "---b vec: rank: " << rank << " ";
-  //  for (int i=0;i<phi_cols1;++i) std::cout << b[i] << " ";
-  //  std::cout << std::endl;
-  //}
-
-  //if (rbuf) delete rbuf;
-
-  // // verify
-  // int descY[9];
-  // descinit_( descY, &PHI_rows, &ione, &rnb1, &cnb1, &izero, &izero, &context1, /*leading dimension*/&phi_rows1, &info);
-  // int descY2[9];
-  // descinit_( descY2, &PHI_rows, &ione, &rnb2, &cnb2, &izero, &izero, &context2, /*leading dimension*/&phi_rows2, &info2);
-
-  // double alpha=1;
-  // double beta=0;
-  // double *y = new double[phi_rows1];
-  // double *y2 = new double[phi_rows2];
-  // pdgemv_(&trans, &PHI_rows, &PHI_cols, &alpha, Phi_cpy.ptr(), &ia, &ja, descPHI, b, &ia, &ja,
-  //     descB, &ione, &beta, y, &ia, &ja, descY, &ione);
-  // std::cout << "---y vec: rank: " << b_rank << " ::: ";
-  // for (int i=0;i<phi_rows1;++i) std::cout << y[i] << " " << T[i] << " " << y[i]-T[i] << std::endl;
-  // std::cout << std::endl;
-  // delete[] y;
-  // pdgemv_(&trans, &PHI_rows, &PHI_cols, &alpha, Phi2_cpy.ptr(), &ia, &ja, descPHI2, b2, &ia, &ja,
-  //     descB2, &ione, &beta, y2, &ia, &ja, descY2, &ione);
-
-  // if (b_rank%b_ncols2==0) {
-  //   std::cout << "---y2 vec: rank: " << b_rank << " ::: ";
-  //   for (int i=0;i<phi_rows2;++i) std::cout << y2[i] << " " << T2[i] << " "<< y2[i]-T2[i] << std::endl;
-  //   std::cout << std::endl;
-  // }
-  // delete[] y2;
-
-  // // Create third context
-  // int context3;
-  // int b_row3, b_col3;
-  // blacs_get_(&izero,&izero, &context3 ); // -> Create context2
-  // blacs_gridinit_(&context3, &layout, &ncpu, &ione ); // context1 -> Initialize the grid
-  // blacs_gridinfo_(&context3, &ncpu, &ione, &b_row3, &b_col3 );
-  // // do work on context3
-  // blacs_gridexit_(&context3);
-
-  //delete[] work;
 #else
   // NON MPI VERSION
   Trainer tr(config);
@@ -433,19 +356,39 @@ void TadahCLI::subcommand_hpo(
   }
 #ifdef TADAH_BUILD_MPI
   // WORKER
-  else {
-    // 0. Obtain updated config from the host
-    //    a) serialize config as a binary and broadcast with MPI(?)
-    //       toml can serialise
-    // 1. calculate portion of descriptors
-    // 2. solve with scalaPACK
-    // 3. run lammps simulations requested be the host
-    while (true) {
-      // request work
-      // if training
-      //   train_worker()
-      // if simulations
-      //   sim_worker()
+  if (rank==0) {
+    hpo_run(config, target_file);
+  }
+  else { // WORKER
+    int TEMP=1;
+    while (TEMP--) {
+      MPI_Trainer tr(config);
+      tr.init(rank, ncpu);
+      MPI_Trainer_WORKER WORKER(tr, rank, ncpu);
+      while (true) {
+        // ask for more work...
+        MPI_Send (&tr.rows_available, 1, MPI_INT, 0, WORK_TAG, MPI_COMM_WORLD);
+
+        // request from root or from other workers
+        tr.probe();
+
+        // release a worker
+        if (tr.tag == TadahCLI::RELEASE_TAG) {
+          WORKER.release_tag();
+          break;
+        }
+        else if (tr.tag == TadahCLI::WAIT_TAG) {
+          // do nothing here; ask for more work in the next cycle
+          WORKER.wait_tag();
+        }
+        else if (tr.tag == TadahCLI::DATA_TAG) {
+          WORKER.data_tag();
+        }
+        else if (tr.tag == TadahCLI::WORK_TAG) {
+          WORKER.work_tag();
+        }
+      }
+      tr.solve();
     }
   }
 #endif
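
Note (not part of the patch): the worker branch added above follows a request/probe/tag-dispatch pattern. The minimal standalone sketch below illustrates that control flow only; the tag values, payload, and host bookkeeping are assumptions made for the example, while the real protocol lives in MPI_Trainer (probe(), tag) and MPI_Trainer_WORKER (work_tag(), data_tag(), wait_tag(), release_tag()).

// Standalone sketch of the request/probe/tag-dispatch loop (placeholder tags).
#include <mpi.h>
#include <cstdio>

namespace {
// Assumed tag values -- TadahCLI defines its own WORK/DATA/WAIT/RELEASE tags.
enum Tag { WORK_TAG = 1, DATA_TAG = 2, WAIT_TAG = 3, RELEASE_TAG = 4 };
}

int main(int argc, char **argv) {
  MPI_Init(&argc, &argv);
  int rank, ncpu;
  MPI_Comm_rank(MPI_COMM_WORLD, &rank);
  MPI_Comm_size(MPI_COMM_WORLD, &ncpu);

  if (rank == 0) {
    // Host: answer every worker request; hand out one dummy work unit per
    // worker, then release workers as they come back asking for more.
    int remaining = ncpu - 1;   // workers still running
    int units     = ncpu - 1;   // dummy work units left to distribute
    while (remaining > 0) {
      int request;
      MPI_Status st;
      MPI_Recv(&request, 1, MPI_INT, MPI_ANY_SOURCE, WORK_TAG,
               MPI_COMM_WORLD, &st);
      if (units > 0) {
        int payload = 100 + st.MPI_SOURCE;
        MPI_Send(&payload, 1, MPI_INT, st.MPI_SOURCE, WORK_TAG, MPI_COMM_WORLD);
        --units;
      } else {
        int dummy = 0;
        MPI_Send(&dummy, 1, MPI_INT, st.MPI_SOURCE, RELEASE_TAG, MPI_COMM_WORLD);
        --remaining;
      }
    }
  } else {
    // Worker: announce availability, probe the reply, dispatch on its tag.
    while (true) {
      int rows_available = 1;
      MPI_Send(&rows_available, 1, MPI_INT, 0, WORK_TAG, MPI_COMM_WORLD);

      MPI_Status status;
      MPI_Probe(0, MPI_ANY_TAG, MPI_COMM_WORLD, &status);

      int payload;
      MPI_Recv(&payload, 1, MPI_INT, 0, status.MPI_TAG, MPI_COMM_WORLD,
               MPI_STATUS_IGNORE);

      if (status.MPI_TAG == RELEASE_TAG) break;    // all done, leave the loop
      if (status.MPI_TAG == WAIT_TAG) continue;    // nothing yet, ask again
      // WORK_TAG / DATA_TAG: the actual work would happen here.
      std::printf("rank %d handled payload %d (tag %d)\n",
                  rank, payload, status.MPI_TAG);
    }
  }

  MPI_Finalize();
  return 0;
}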