diff --git a/trainer.h b/trainer.h
index 1813f343c0343992d5df6d0bc89ffecaa4566fd5..20821372fece0fe6aa760ff19033911647e4211d 100644
--- a/trainer.h
+++ b/trainer.h
@@ -72,7 +72,11 @@ class MPI_Trainer: public Trainer {
   int PHI_cols;
   int PHI_rows;
   int rows_available;
-
+  int b_rank;
+  int phi_rows1;
+  int phi_cols1;
+  int phi_rows2;
+  int phi_cols2;
 
   MPI_Trainer(Config &c): Trainer(c) {}
 
@@ -91,7 +95,6 @@ class MPI_Trainer: public Trainer {
     // We create two contexts.
     // context1 is used for the computation of phi matrices
     // context2 is used for distribution of local phi to "block cyclic phi"
-    int b_rank;
    blacs_pinfo_(&b_rank, &ncpu) ;    // BLACS rank and world size
 
     int context1, context2;
@@ -145,10 +148,10 @@ class MPI_Trainer: public Trainer {
     blacs_gridinfo_(&context2, &b_nrows2, &b_ncols2, &b_row2, &b_col2 );
 
     // Compute the size of the local phi matrices
-    int phi_rows1 = numroc_( &PHI_rows, &rnb1, &b_row1, &izero, &b_nrows1 );
-    int phi_cols1 = numroc_( &PHI_cols, &cnb1, &b_col1, &izero, &b_ncols1 );
-    int phi_rows2 = numroc_( &PHI_rows, &rnb2, &b_row2, &izero, &b_nrows2 );
-    int phi_cols2 = numroc_( &PHI_cols, &cnb2, &b_col2, &izero, &b_ncols2 );
+    phi_rows1 = numroc_( &PHI_rows, &rnb1, &b_row1, &izero, &b_nrows1 );
+    phi_cols1 = numroc_( &PHI_cols, &cnb1, &b_col1, &izero, &b_ncols1 );
+    phi_rows2 = numroc_( &PHI_rows, &rnb2, &b_row2, &izero, &b_nrows2 );
+    phi_cols2 = numroc_( &PHI_cols, &cnb2, &b_col2, &izero, &b_ncols2 );
 
     // Define MPI datatype to send rows from the PHI matrix with column-major order
     // used only in context1
@@ -339,7 +342,7 @@ class MPI_Trainer_HOST {
         MPI_Recv (&rows_needed, 1, MPI_INT, worker, tag, MPI_COMM_WORLD, &status);
         if (tr.rows_available>0) {
           int rows_accepted = tr.rows_available < rows_needed ? tr.rows_available : rows_needed;
-          MPI_Send (&b_rank, 1, MPI_INT, worker, tag, MPI_COMM_WORLD);
+          MPI_Send (&tr.b_rank, 1, MPI_INT, worker, tag, MPI_COMM_WORLD);
           MPI_Send (&rows_accepted, 1, MPI_INT, worker, tag, MPI_COMM_WORLD);
           MPI_Recv (&tr.dm.Phi.data()[phi_row], rows_accepted, rowvecs, worker, tag, MPI_COMM_WORLD, &status);
           MPI_Recv (&tr.dm.T.data()[phi_row], rows_accepted, MPI_DOUBLE, worker, tag, MPI_COMM_WORLD, &status);
@@ -391,7 +394,7 @@ class MPI_Trainer_HOST {
         MPI_Recv (&rows_needed, 1, MPI_INT, worker, tag, MPI_COMM_WORLD, &status);
         if (tr.rows_available>0) {
           int rows_accepted = tr.rows_available < rows_needed ? tr.rows_available : rows_needed;
-          MPI_Send (&b_rank, 1, MPI_INT, worker, tag, MPI_COMM_WORLD);
+          MPI_Send (&tr.b_rank, 1, MPI_INT, worker, tag, MPI_COMM_WORLD);
           MPI_Send (&rows_accepted, 1, MPI_INT, worker, tag, MPI_COMM_WORLD);
           MPI_Recv (&tr.dm.Phi.data()[phi_row], rows_accepted, rowvecs, worker, tag, MPI_COMM_WORLD, &status);
           MPI_Recv (&tr.dm.T.data()[phi_row], rows_accepted, MPI_DOUBLE, worker, tag, MPI_COMM_WORLD, &status);
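
For context on the hunks above: b_rank and the local phi dimensions are promoted from function-local variables to MPI_Trainer members so that MPI_Trainer_HOST can read them through its trainer reference (tr.b_rank), which the MPI_Send calls in the last two hunks rely on. The sizes themselves come from ScaLAPACK's numroc_, which reports how many rows or columns of the global PHI matrix a given process owns under the block-cyclic distribution. The snippet below is a minimal standalone sketch of that calculation, using a hypothetical helper name (local_extent) rather than the library routine, purely to illustrate what phi_rows1, phi_cols1 and friends end up holding.

#include <iostream>

// Sketch of the block-cyclic local-extent computation that numroc_ performs.
// n        : global number of rows (or columns), e.g. PHI_rows
// nb       : blocking factor, e.g. rnb1
// iproc    : this process's row (or column) coordinate in the grid
// isrcproc : coordinate of the process owning the first block (izero above)
// nprocs   : number of processes along this grid dimension
int local_extent(int n, int nb, int iproc, int isrcproc, int nprocs) {
  int mydist  = (nprocs + iproc - isrcproc) % nprocs; // distance from the source process
  int nblocks = n / nb;                               // whole blocks in this dimension
  int extent  = (nblocks / nprocs) * nb;              // blocks every process receives
  int extra   = nblocks % nprocs;                     // leftover whole blocks
  if (mydist < extra)
    extent += nb;                                     // one extra whole block lands here
  else if (mydist == extra)
    extent += n % nb;                                 // the trailing partial block lands here
  return extent;
}

int main() {
  // Example: 1000 global rows, block size 32, 2 process rows in the grid.
  for (int p = 0; p < 2; ++p)
    std::cout << "process row " << p << " holds "
              << local_extent(1000, 32, p, 0, 2) << " local rows\n";
  // prints 512 and 488, which sum back to the 1000 global rows
}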