From eb171bae929370a27b6b461ce03bb1660e4a6382 Mon Sep 17 00:00:00 2001
From: Marcin Kirsz <mkirsz@ed.ac.uk>
Date: Tue, 8 Oct 2024 14:17:06 +0100
Subject: [PATCH] fix for solver when N > M

---
 trainer.h | 4 +++-
 1 file changed, 3 insertions(+), 1 deletion(-)

diff --git a/trainer.h b/trainer.h
index 70956fd..035d152 100644
--- a/trainer.h
+++ b/trainer.h
@@ -246,6 +246,8 @@ class MPI_Trainer: public Trainer {
 
       descinit_( sol,  &PHI_cols, &ione, &rnb1, &cnb1, &izero, 
           &izero, &context1, /*leading dimension*/&phi_cols1, &info3);
+      descinit_( sol2,  &PHI_cols, &ione, &rnb2, &cnb2, &izero, 
+          &izero, &context2, /*leading dimension*/&phi_cols2, &info3);
 
       if(info != 0) {
         printf("Error in descinit 1b, info = %d\n", info);
@@ -293,7 +295,7 @@ class MPI_Trainer: public Trainer {
           descPHI2, dm2.T.ptr(), &ib, &jb, descB2, work2, &lwork2, &info2);
 
       // get weight vector, for context1 
-      pdgemr2d_(&PHI_cols, &ione, dm2.T.ptr(), &ione, &ione, descB2,
+      pdgemr2d_(&PHI_cols, &ione, dm2.T.ptr(), &ione, &ione, sol2,
           dm.T.ptr(), &ione, &ione, sol, &context1);
 
       if (rank==0) {
-- 
GitLab