From bcec556de7e91cbf3bbb294e6d7613078182a093 Mon Sep 17 00:00:00 2001
From: Marcin Kirsz <mkirsz@ed.ac.uk>
Date: Wed, 25 Sep 2024 17:07:00 +0100
Subject: [PATCH] Fixed MPI handle leak. Added CONFIG keys to control block and
 work package sizes for MPI training

---
 bin/tadah_cli.cpp | 10 ++++++----
 1 file changed, 6 insertions(+), 4 deletions(-)

diff --git a/bin/tadah_cli.cpp b/bin/tadah_cli.cpp
index 0d658c8..c0864fc 100644
--- a/bin/tadah_cli.cpp
+++ b/bin/tadah_cli.cpp
@@ -150,9 +150,8 @@ void TadahCLI::subcommand_train() {
   int b_col1,                        b_col2;
   int b_nrows1,                      b_nrows2;    // Number of row procs
   int b_ncols1,                      b_ncols2;    // Number of column procs
-  int rnb1 = ceil(PHI_rows/ncpu),    rnb2 = 64;   // Row block size
-  int cnb1 = PHI_cols,               cnb2 = 64;   // Column block size
-  // TODO block size is system dependent, allow user to overwrite default settings
+  int rnb1 = ceil(PHI_rows/ncpu),    rnb2 = config.get<int>("MBLOCK");   // Row block size
+  int cnb1 = PHI_cols,               cnb2 = config.get<int>("NBLOCK");   // Column block size
 
   b_ncols1 = 1;        b_ncols2 = 2;
   b_nrows1 = ncpu;     b_nrows2 = ncpu/b_ncols2;
@@ -232,7 +231,7 @@ void TadahCLI::subcommand_train() {
     // HOST: prepare work packages
     // filename, first structure index, number of structures to read
     std::vector<std::tuple<std::string,int,int>> wpckgs;
-    int nstruc = 10;  // TODO: read from Config, the number of structures in a single work package
+    int nstruc = config.get<int>("MPIWPCKG");  // number of structures in a single work package (CONFIG key MPIWPCKG)
     for (const std::string &fn : config("DBFILE")) {
       // get number of structures
       int dbsize = StructureDB::count(fn).first;
@@ -507,6 +506,7 @@ void TadahCLI::subcommand_train() {
             MPI_Send (&temp_dm.Tlabels.data()[start], rows_accepted, MPI_DOUBLE, dest, DATA_TAG, MPI_COMM_WORLD);
             rows_needed -= rows_accepted;
             MPI_Type_free(&trowvec);
+            MPI_Type_free(&trowvecs);
           }
         }
         else {
@@ -543,6 +543,7 @@ void TadahCLI::subcommand_train() {
   }
   if(info2 != 0) {
     printf("Error in descinit 2a, info = %d\n", info2);
+    printf("HINT: Check these CONFIG parameters: MPIWPCKG, MBLOCK, NBLOCK\n");
   }
 
   descinit_( descB,   &PHI_rows, &ione, &rnb1, &cnb1, &izero, &izero, &context1, /*leading dimension*/&phi_rows1, &info);
@@ -553,6 +554,7 @@ void TadahCLI::subcommand_train() {
   }
   if(info2 != 0) {
     printf("Error in descinit 2b, info = %d\n", info2);
+    printf("HINT: Check these CONFIG parameters: MPIWPCKG, MBLOCK, NBLOCK\n");
   }
 
   //double MPIt1 = MPI_Wtime();
-- 
GitLab