From 2efd979d910c9dd1fcf63b61cbc08a477a8e00de Mon Sep 17 00:00:00 2001 From: Marcin Kirsz <mkirsz@ed.ac.uk> Date: Fri, 4 Oct 2024 14:44:08 +0100 Subject: [PATCH] Fixed aggregation of weights --- bin/tadah_cli.cpp | 15 +++++++-------- 1 file changed, 7 insertions(+), 8 deletions(-) diff --git a/bin/tadah_cli.cpp b/bin/tadah_cli.cpp index c43d0a3..0ac00be 100644 --- a/bin/tadah_cli.cpp +++ b/bin/tadah_cli.cpp @@ -246,18 +246,18 @@ void TadahCLI::subcommand_train() { double *work2 = new double[lwork2]; pdgels_(&trans, &tr.PHI_rows, &tr.PHI_cols, &nrhs, dm2.Phi.ptr(), &ia, &ja, descPHI2, b2, &ib, &jb, descB2, work2, &lwork2, &info2); + // get weight vector, for context1 + pdgemr2d_(&tr.PHI_rows, &tr.ione, dm2.T.ptr(), &tr.ione, &tr.ione, descB2, + tr.dm.T.ptr(), &tr.ione, &tr.ione, descB, &tr.context1); + if (rank==0) { - //std::cout << "---b2 vec: rank: " << rank << " "; - //for (int i=0;i<phi_cols2;++i) std::cout << b2[i] << " "; - //std::cout << std::endl; - t_type w(b2,tr.phi_cols2); - //w.resize(phi_cols2); - //for (int i=0;i<phi_cols2;++i) w[i] = b2[i]; + t_type w(tr.dm.T.ptr(), tr.PHI_cols); tr.model->set_weights(w); tr.model->trained=true; - //std::cout << w << std::endl; } + //if (rbuf) delete rbuf; + // // verify // int descY[9]; // descinit_( descY, &PHI_rows, &ione, &rnb1, &cnb1, &izero, &izero, &context1, /*leading dimension*/&phi_rows1, &info); @@ -337,7 +337,6 @@ void TadahCLI::subcommand_train() { // } - //MPI_Win_free(&window); } void TadahCLI::subcommand_predict() { -- GitLab