From 898c5ed7c347a1c377990dd0f5671b10c1c91a54 Mon Sep 17 00:00:00 2001
From: Marcin Kirsz <mkirsz@ed.ac.uk>
Date: Mon, 7 Oct 2024 10:28:35 +0100
Subject: [PATCH] Added solve method to MPI trainer

---
 bin/tadah_cli.cpp | 72 ++---------------------------------------------
 1 file changed, 2 insertions(+), 70 deletions(-)

diff --git a/bin/tadah_cli.cpp b/bin/tadah_cli.cpp
index e0dd7cb..ad9c2e9 100644
--- a/bin/tadah_cli.cpp
+++ b/bin/tadah_cli.cpp
@@ -161,59 +161,17 @@ void TadahCLI::subcommand_train() {
   }
   // END HOST-WORKER
   // All local phi matrices are computed by this point
+  //
+  tr.solve();
 
   // Descriptors for scalaPACK
-  int descPHI[9],  descPHI2[9];
-  int descB[9],    descB2[9];
-  int info,        info2;
-
-  descinit_( descPHI,  &tr.PHI_rows, &tr.PHI_cols, &tr.rnb1, &tr.cnb1, &tr.izero, &tr.izero, &tr.context1, /*leading dimension*/&tr.phi_rows1, &info);
-  descinit_( descPHI2, &tr.PHI_rows, &tr.PHI_cols, &tr.rnb2, &tr.cnb2, &tr.izero, &tr.izero, &tr.context2, /*leading dimension*/&tr.phi_rows2, &info2);
-
-  if(info != 0) {
-    printf("Error in descinit 1a, info = %d\n", info);
-  }
-  if(info2 != 0) {
-    printf("Error in descinit 2a, info = %d\n", info2);
-    printf("HINT: Check these CONFIG parameters: MPIWPCKG, MBLOCK, NBLOCK\n");
-  }
-
-  descinit_( descB,   &tr.PHI_rows, &tr.ione, &tr.rnb1, &tr.cnb1, &tr.izero, &tr.izero, &tr.context1, /*leading dimension*/&tr.phi_rows1, &info);
-  descinit_( descB2,  &tr.PHI_rows, &tr.ione, &tr.rnb2, &tr.cnb2, &tr.izero, &tr.izero, &tr.context2, /*leading dimension*/&tr.phi_rows2, &info2);
-
-  if(info != 0) {
-    printf("Error in descinit 1b, info = %d\n", info);
-  }
-  if(info2 != 0) {
-    printf("Error in descinit 2b, info = %d\n", info2);
-    printf("HINT: Check these CONFIG parameters: MPIWPCKG, MBLOCK, NBLOCK\n");
-  }
-
-  //double MPIt1 = MPI_Wtime();
-  char trans='N';
-  int nrhs = 1;
   //double *b = dm.T.ptr();
   //std::cout << "BEFORE: ---b vec: rank: " << b_rank << " ";
   //for (int i=0;i<phi_rows1;++i) std::cout << b[i] << " ";
   //std::cout << std::endl;
   //std::cout << std::endl;
 
-  int ia = 1;
-  int ja = 1;
-  int ib = 1;
-  int jb = 1;
 
-  // Distribute data in 2D block cyclic 
-  DesignMatrix<DM_Function_Base&> dm2(*tr.fb, tr.config);
-  dm2.Phi.resize(tr.phi_rows2,tr.phi_cols2);
-  dm2.T.resize(tr.phi_rows2);
-  dm2.Tlabels.resize(tr.phi_rows2);
-
-  pdgemr2d_(&tr.PHI_rows, &tr.PHI_cols, tr.dm.Phi.ptr(), &tr.ione, &tr.ione, descPHI,
-      dm2.Phi.ptr(), &tr.ione, &tr.ione, descPHI2, &tr.context2);
-
-  pdgemr2d_(&tr.PHI_rows, &tr.ione, tr.dm.T.ptr(), &tr.ione, &tr.ione, descB,
-      dm2.T.ptr(), &tr.ione, &tr.ione, descB2, &tr.context2);
 
   // make a copy of dm.Phi and dm2.Phi
   //Matrix Phi_cpy = dm.Phi;
@@ -221,12 +179,9 @@ void TadahCLI::subcommand_train() {
   //t_type T = dm.T;
   //t_type T2 = dm2.T;
 
-  double *b2 = dm2.T.ptr();
 
   //double wkopt;
-  double wkopt2;
   //int lwork = -1;
-  int lwork2 = -1; // query -> get size of the work matrix
 
   //std::cout << "phi_cols1: " << phi_cols1 << std::endl;
   //pdgels_(&trans, &PHI_rows, &PHI_cols, &nrhs, dm.Phi.ptr(), &ia, &ja, descPHI, b, &ib, &jb, descB, &wkopt, &lwork, &info);
@@ -236,13 +191,8 @@ void TadahCLI::subcommand_train() {
 
   //std::cout << "phi_rows1, phi_cols1: " << phi_rows1 << ", " << phi_cols1<< std::endl;
   //std::cout << "phi_rows2, phi_cols2: " << phi_rows2 << ", " << phi_cols2<< std::endl;
-  pdgels_(&trans, &tr.PHI_rows, &tr.PHI_cols, &nrhs, dm2.Phi.ptr(), &ia, &ja, descPHI2, b2, &ib, &jb, descB2, &wkopt2, &lwork2, &info2);
-  if (info2 != 0) {
-    printf("Error in pdgels, info = %d\n", info);
-  }
 
   //lwork = (int)wkopt;
-  lwork2 = (int)wkopt2;
 
   //double *work = new double[lwork];
   //pdgels_(&trans, &PHI_rows, &PHI_cols, &nrhs, dm.Phi.ptr(), &ia, &ja, descPHI, b, &ib, &jb, descB, work, &lwork, &info);
@@ -251,18 +201,6 @@ void TadahCLI::subcommand_train() {
   //  for (int i=0;i<phi_cols1;++i) std::cout << b[i] << " ";
   //  std::cout << std::endl;
   //}
-  double *work2 = new double[lwork2];
-  pdgels_(&trans, &tr.PHI_rows, &tr.PHI_cols, &nrhs, dm2.Phi.ptr(), &ia, &ja, descPHI2, b2, &ib, &jb, descB2, work2, &lwork2, &info2);
-
-  // get weight vector, for context1 
-  pdgemr2d_(&tr.PHI_rows, &tr.ione, dm2.T.ptr(), &tr.ione, &tr.ione, descB2,
-      tr.dm.T.ptr(), &tr.ione, &tr.ione, descB, &tr.context1);
-
-  if (rank==0) {
-    t_type w(tr.dm.T.ptr(), tr.PHI_cols);
-    tr.model->set_weights(w);
-    tr.model->trained=true;
-  }
 
   //if (rbuf) delete rbuf;
 
@@ -302,12 +240,6 @@ void TadahCLI::subcommand_train() {
   // blacs_gridexit_(&context3);
 
   //delete[] work;
-  delete[] work2;
-  MPI_Type_free(&tr.rowvec);
-  MPI_Type_free(&tr.rowvecs);
-
-  blacs_gridexit_(&tr.context1);
-  blacs_gridexit_(&tr.context2);
 
 #else // NON MPI VERSION
   Trainer tr(config);
-- 
GitLab