From 898c5ed7c347a1c377990dd0f5671b10c1c91a54 Mon Sep 17 00:00:00 2001
From: Marcin Kirsz <mkirsz@ed.ac.uk>
Date: Mon, 7 Oct 2024 10:28:35 +0100
Subject: [PATCH] Added solve method to MPI trainer

---
 bin/tadah_cli.cpp | 72 ++---------------------------------------------
 1 file changed, 2 insertions(+), 70 deletions(-)

diff --git a/bin/tadah_cli.cpp b/bin/tadah_cli.cpp
index e0dd7cb..ad9c2e9 100644
--- a/bin/tadah_cli.cpp
+++ b/bin/tadah_cli.cpp
@@ -161,59 +161,17 @@ void TadahCLI::subcommand_train() {
     }
     // END HOST-WORKER
     // All local phi matrices are computed by this point
+    //
+    tr.solve();
 
     // Descriptors for scalaPACK
-    int descPHI[9], descPHI2[9];
-    int descB[9], descB2[9];
-    int info, info2;
-
-    descinit_( descPHI, &tr.PHI_rows, &tr.PHI_cols, &tr.rnb1, &tr.cnb1, &tr.izero, &tr.izero, &tr.context1, /*leading dimension*/&tr.phi_rows1, &info);
-    descinit_( descPHI2, &tr.PHI_rows, &tr.PHI_cols, &tr.rnb2, &tr.cnb2, &tr.izero, &tr.izero, &tr.context2, /*leading dimension*/&tr.phi_rows2, &info2);
-
-    if(info != 0) {
-      printf("Error in descinit 1a, info = %d\n", info);
-    }
-    if(info2 != 0) {
-      printf("Error in descinit 2a, info = %d\n", info2);
-      printf("HINT: Check these CONFIG parameters: MPIWPCKG, MBLOCK, NBLOCK\n");
-    }
-
-    descinit_( descB, &tr.PHI_rows, &tr.ione, &tr.rnb1, &tr.cnb1, &tr.izero, &tr.izero, &tr.context1, /*leading dimension*/&tr.phi_rows1, &info);
-    descinit_( descB2, &tr.PHI_rows, &tr.ione, &tr.rnb2, &tr.cnb2, &tr.izero, &tr.izero, &tr.context2, /*leading dimension*/&tr.phi_rows2, &info2);
-
-    if(info != 0) {
-      printf("Error in descinit 1b, info = %d\n", info);
-    }
-    if(info2 != 0) {
-      printf("Error in descinit 2b, info = %d\n", info2);
-      printf("HINT: Check these CONFIG parameters: MPIWPCKG, MBLOCK, NBLOCK\n");
-    }
-
-    //double MPIt1 = MPI_Wtime();
-    char trans='N';
-    int nrhs = 1;
 
     //double *b = dm.T.ptr();
     //std::cout << "BEFORE: ---b vec: rank: " << b_rank << " ";
     //for (int i=0;i<phi_rows1;++i) std::cout << b[i] << " ";
     //std::cout << std::endl;
     //std::cout << std::endl;
 
-    int ia = 1;
-    int ja = 1;
-    int ib = 1;
-    int jb = 1;
-    // Distribute data in 2D block cyclic
-    DesignMatrix<DM_Function_Base&> dm2(*tr.fb, tr.config);
-    dm2.Phi.resize(tr.phi_rows2,tr.phi_cols2);
-    dm2.T.resize(tr.phi_rows2);
-    dm2.Tlabels.resize(tr.phi_rows2);
-
-    pdgemr2d_(&tr.PHI_rows, &tr.PHI_cols, tr.dm.Phi.ptr(), &tr.ione, &tr.ione, descPHI,
-        dm2.Phi.ptr(), &tr.ione, &tr.ione, descPHI2, &tr.context2);
-
-    pdgemr2d_(&tr.PHI_rows, &tr.ione, tr.dm.T.ptr(), &tr.ione, &tr.ione, descB,
-        dm2.T.ptr(), &tr.ione, &tr.ione, descB2, &tr.context2);
 
     // make a copy of dm.Phi and dm2.Phi
     //Matrix Phi_cpy = dm.Phi;
@@ -221,12 +179,9 @@ void TadahCLI::subcommand_train() {
 
     //t_type T = dm.T;
     //t_type T2 = dm2.T;
-    double *b2 = dm2.T.ptr();
 
     //double wkopt;
-    double wkopt2;
     //int lwork = -1;
-    int lwork2 = -1;
     // query -> get size of the work matrix
     //std::cout << "phi_cols1: " << phi_cols1 << std::endl;
     //pdgels_(&trans, &PHI_rows, &PHI_cols, &nrhs, dm.Phi.ptr(), &ia, &ja, descPHI, b, &ib, &jb, descB, &wkopt, &lwork, &info);
@@ -236,13 +191,8 @@ void TadahCLI::subcommand_train() {
 
     //std::cout << "phi_rows1, phi_cols1: " << phi_rows1 << ", " << phi_cols1<< std::endl;
     //std::cout << "phi_rows2, phi_cols2: " << phi_rows2 << ", " << phi_cols2<< std::endl;
-    pdgels_(&trans, &tr.PHI_rows, &tr.PHI_cols, &nrhs, dm2.Phi.ptr(), &ia, &ja, descPHI2, b2, &ib, &jb, descB2, &wkopt2, &lwork2, &info2);
-    if (info2 != 0) {
-      printf("Error in pdgels, info = %d\n", info);
-    }
 
     //lwork = (int)wkopt;
-    lwork2 = (int)wkopt2;
    //double *work = new double[lwork];
    //pdgels_(&trans, &PHI_rows, &PHI_cols, &nrhs, dm.Phi.ptr(), &ia, &ja, descPHI, b, &ib, &jb, descB, work, &lwork, &info);
 
@@ -251,18 +201,6 @@ void TadahCLI::subcommand_train() {
    // for (int i=0;i<phi_cols1;++i) std::cout << b[i] << " ";
    // std::cout << std::endl;
    //}
-    double *work2 = new double[lwork2];
-    pdgels_(&trans, &tr.PHI_rows, &tr.PHI_cols, &nrhs, dm2.Phi.ptr(), &ia, &ja, descPHI2, b2, &ib, &jb, descB2, work2, &lwork2, &info2);
-
-    // get weight vector, for context1
-    pdgemr2d_(&tr.PHI_rows, &tr.ione, dm2.T.ptr(), &tr.ione, &tr.ione, descB2,
-        tr.dm.T.ptr(), &tr.ione, &tr.ione, descB, &tr.context1);
-
-    if (rank==0) {
-      t_type w(tr.dm.T.ptr(), tr.PHI_cols);
-      tr.model->set_weights(w);
-      tr.model->trained=true;
-    }
 
    //if (rbuf) delete rbuf;
 
@@ -302,12 +240,6 @@ void TadahCLI::subcommand_train() {
 
    // blacs_gridexit_(&context3);
    //delete[] work;
-    delete[] work2;
-    MPI_Type_free(&tr.rowvec);
-    MPI_Type_free(&tr.rowvecs);
-
-    blacs_gridexit_(&tr.context1);
-    blacs_gridexit_(&tr.context2);
 
 #else
   // NON MPI VERSION
   Trainer tr(config);
-- 
GitLab
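
For reference, the sketch below reconstructs what the new tr.solve() call presumably encapsulates, based purely on the ScaLAPACK sequence removed above. The MPI_Trainer class name, the includes, the std::vector workspace, and the MPI_Comm_rank call are assumptions for illustration; the member names (PHI_rows, PHI_cols, rnb1/cnb1, rnb2/cnb2, context1/context2, phi_rows1/phi_rows2, ione, izero, fb, config, dm, model) and the ScaLAPACK routines (descinit_, pdgemr2d_, pdgels_) are taken directly from the removed code. This is a sketch, not the trainer's actual implementation.

    #include <cstdio>
    #include <vector>
    #include <mpi.h>
    // descinit_, pdgemr2d_ and pdgels_ prototypes are assumed to come from the
    // project's existing ScaLAPACK declarations, as they did in bin/tadah_cli.cpp.

    void MPI_Trainer::solve() {  // hypothetical class/method name
      // Array descriptors: context1 holds the locally computed Phi and T,
      // context2 is the 2D block-cyclic layout used by the solver.
      int descPHI[9], descPHI2[9], descB[9], descB2[9];
      int info = 0, info2 = 0;
      descinit_(descPHI,  &PHI_rows, &PHI_cols, &rnb1, &cnb1, &izero, &izero, &context1, &phi_rows1, &info);
      descinit_(descPHI2, &PHI_rows, &PHI_cols, &rnb2, &cnb2, &izero, &izero, &context2, &phi_rows2, &info2);
      descinit_(descB,    &PHI_rows, &ione,     &rnb1, &cnb1, &izero, &izero, &context1, &phi_rows1, &info);
      descinit_(descB2,   &PHI_rows, &ione,     &rnb2, &cnb2, &izero, &izero, &context2, &phi_rows2, &info2);

      // Redistribute the design matrix and target vector into the 2D block-cyclic grid.
      DesignMatrix<DM_Function_Base&> dm2(*fb, config);
      dm2.Phi.resize(phi_rows2, phi_cols2);
      dm2.T.resize(phi_rows2);
      dm2.Tlabels.resize(phi_rows2);
      pdgemr2d_(&PHI_rows, &PHI_cols, dm.Phi.ptr(), &ione, &ione, descPHI,
                dm2.Phi.ptr(), &ione, &ione, descPHI2, &context2);
      pdgemr2d_(&PHI_rows, &ione, dm.T.ptr(), &ione, &ione, descB,
                dm2.T.ptr(), &ione, &ione, descB2, &context2);

      // Least-squares solve with pdgels_: workspace query (lwork = -1), then the real call.
      char trans = 'N';
      int nrhs = 1, ia = 1, ja = 1, ib = 1, jb = 1;
      double wkopt2;
      int lwork2 = -1;
      pdgels_(&trans, &PHI_rows, &PHI_cols, &nrhs, dm2.Phi.ptr(), &ia, &ja, descPHI2,
              dm2.T.ptr(), &ib, &jb, descB2, &wkopt2, &lwork2, &info2);
      lwork2 = static_cast<int>(wkopt2);
      std::vector<double> work2(lwork2);
      pdgels_(&trans, &PHI_rows, &PHI_cols, &nrhs, dm2.Phi.ptr(), &ia, &ja, descPHI2,
              dm2.T.ptr(), &ib, &jb, descB2, work2.data(), &lwork2, &info2);
      if (info2 != 0) printf("Error in pdgels, info = %d\n", info2);

      // Bring the weight vector back to context1 and store it in the model on rank 0.
      pdgemr2d_(&PHI_rows, &ione, dm2.T.ptr(), &ione, &ione, descB2,
                dm.T.ptr(), &ione, &ione, descB, &context1);
      int rank;
      MPI_Comm_rank(MPI_COMM_WORLD, &rank);
      if (rank == 0) {
        t_type w(dm.T.ptr(), PHI_cols);
        model->set_weights(w);
        model->trained = true;
      }
    }

The MPI_Type_free and blacs_gridexit_ calls dropped from the CLI at the end of the patch suggest the trainer now also owns that cleanup (for example in solve() itself or in its destructor); that part is not shown in the sketch.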