From 2efd979d910c9dd1fcf63b61cbc08a477a8e00de Mon Sep 17 00:00:00 2001
From: Marcin Kirsz <mkirsz@ed.ac.uk>
Date: Fri, 4 Oct 2024 14:44:08 +0100
Subject: [PATCH] Fixed aggregation of weights

---
 bin/tadah_cli.cpp | 15 +++++++--------
 1 file changed, 7 insertions(+), 8 deletions(-)

diff --git a/bin/tadah_cli.cpp b/bin/tadah_cli.cpp
index c43d0a3..0ac00be 100644
--- a/bin/tadah_cli.cpp
+++ b/bin/tadah_cli.cpp
@@ -246,18 +246,18 @@ void TadahCLI::subcommand_train() {
   double *work2 = new double[lwork2];
   pdgels_(&trans, &tr.PHI_rows, &tr.PHI_cols, &nrhs, dm2.Phi.ptr(), &ia, &ja, descPHI2, b2, &ib, &jb, descB2, work2, &lwork2, &info2);
 
+  // get weight vector, for context1 
+  pdgemr2d_(&tr.PHI_rows, &tr.ione, dm2.T.ptr(), &tr.ione, &tr.ione, descB2,
+      tr.dm.T.ptr(), &tr.ione, &tr.ione, descB, &tr.context1);
+
   if (rank==0) {
-    //std::cout << "---b2 vec: rank: " << rank << " ";
-    //for (int i=0;i<phi_cols2;++i) std::cout << b2[i] << " ";
-    //std::cout << std::endl;
-    t_type w(b2,tr.phi_cols2);
-    //w.resize(phi_cols2);
-    //for (int i=0;i<phi_cols2;++i) w[i] = b2[i];
+    t_type w(tr.dm.T.ptr(), tr.PHI_cols);
     tr.model->set_weights(w);
     tr.model->trained=true;
-    //std::cout << w << std::endl;
   }
 
+  //if (rbuf) delete rbuf;
+
   // // verify
   // int descY[9];
   // descinit_( descY,  &PHI_rows, &ione, &rnb1, &cnb1, &izero, &izero, &context1, /*leading dimension*/&phi_rows1, &info);
@@ -337,7 +337,6 @@ void TadahCLI::subcommand_train() {
     //
   }
 
-  //MPI_Win_free(&window);
 }
 
 void TadahCLI::subcommand_predict() {
-- 
GitLab