From 3d3eb49a753451bc1056c42e26d938e0c992303e Mon Sep 17 00:00:00 2001 From: Marcin Kirsz <mkirsz@ed.ac.uk> Date: Tue, 8 Oct 2024 14:08:45 +0100 Subject: [PATCH] fix for solver when N > M --- trainer.h | 20 ++++++++++++++------ 1 file changed, 14 insertions(+), 6 deletions(-) diff --git a/trainer.h b/trainer.h index 0c005bb..91de420 100644 --- a/trainer.h +++ b/trainer.h @@ -198,8 +198,8 @@ class MPI_Trainer: public Trainer { // workers. //DesignMatrix<DM_Function_Base&> dm(*fb, config); dm.Phi.resize(phi_rows1,phi_cols1); - int b_lda1 = phi_rows1 > phi_cols1 ? phi_rows1 : phi_cols1; - dm.T.resize(b_lda1); + int lda1 = phi_rows1 > phi_cols1 ? phi_rows1 : phi_cols1; + dm.T.resize(lda1); dm.Tlabels.resize(phi_rows1); } void probe() { @@ -218,6 +218,9 @@ class MPI_Trainer: public Trainer { int descB[9], descB2[9]; int info, info2; + int sol[9], sol2[9]; + int info3; + int lda1 = phi_rows1 > phi_cols1 ? phi_rows1 : phi_cols1; int lda2 = phi_rows2 > phi_cols2 ? phi_rows2 : phi_cols2; @@ -236,11 +239,14 @@ class MPI_Trainer: public Trainer { int temp = PHI_rows > PHI_cols ? PHI_rows : PHI_cols; - descinit_( descB, &temp, &ione, &rnb1, &cnb1, &izero, + descinit_( descB, &PHI_rows, &ione, &rnb1, &cnb1, &izero, &izero, &context1, /*leading dimension*/&lda1, &info); - descinit_( descB2, &temp, &ione, &rnb2, &cnb2, &izero, + descinit_( descB2, &PHI_rows, &ione, &rnb2, &cnb2, &izero, &izero, &context2, /*leading dimension*/&lda2, &info2); + descinit_( sol, &PHI_cols, &ione, &rnb1, &cnb1, &izero, + &izero, &context1, /*leading dimension*/&lda1, &info3); + if(info != 0) { printf("Error in descinit 1b, info = %d\n", info); } @@ -248,8 +254,10 @@ class MPI_Trainer: public Trainer { printf("Error in descinit 2b, info = %d\n", info2); printf("HINT: Check these CONFIG parameters: MPIWPCKG, MBLOCK, NBLOCK\n"); } + if(info3 != 0) { + printf("Error in descinit 1b, info = %d\n", info); + } - //double MPIt1 = MPI_Wtime(); char trans='N'; int nrhs = 1; @@ -286,7 +294,7 @@ class MPI_Trainer: public Trainer { // get weight vector, for context1 pdgemr2d_(&PHI_cols, &ione, dm2.T.ptr(), &ione, &ione, descB2, - dm.T.ptr(), &ione, &ione, descB, &context1); + dm.T.ptr(), &ione, &ione, sol, &context1); if (rank==0) { t_type w(dm.T.ptr(), PHI_cols); -- GitLab