diff --git a/bin/tadah_cli.cpp b/bin/tadah_cli.cpp index 0460b339ffb38f4b702e54db59c734dc100f3e05..402a8166c8a85bf44d40f8b11130b7a0a26423af 100644 --- a/bin/tadah_cli.cpp +++ b/bin/tadah_cli.cpp @@ -195,10 +195,10 @@ void TadahCLI::subcommand_train() { // once we know the size of local phi, we can allocate memory to it // including host as well. The host will collect excees computations from // workers. - DesignMatrix<DM_Function_Base&> dm(*tr.fb, tr.config); - dm.Phi.resize(phi_rows1,phi_cols1); - dm.T.resize(phi_rows1); - dm.Tlabels.resize(phi_rows1); + //DesignMatrix<DM_Function_Base&> dm(*tr.fb, tr.config); + tr.dm.Phi.resize(phi_rows1,phi_cols1); + tr.dm.T.resize(phi_rows1); + tr.dm.Tlabels.resize(phi_rows1); // BEGIN HOST-WORKER if (rank==0) { @@ -263,9 +263,9 @@ void TadahCLI::subcommand_train() { int rows_accepted = rows_available < rows_needed ? rows_available : rows_needed; MPI_Send (&b_rank, 1, MPI_INT, worker, tag, MPI_COMM_WORLD); MPI_Send (&rows_accepted, 1, MPI_INT, worker, tag, MPI_COMM_WORLD); - MPI_Recv (&dm.Phi.data()[phi_row], rows_accepted, rowvecs, worker, tag, MPI_COMM_WORLD, &status); - MPI_Recv (&dm.T.data()[phi_row], rows_accepted, MPI_DOUBLE, worker, tag, MPI_COMM_WORLD, &status); - MPI_Recv (&dm.Tlabels.data()[phi_row], rows_accepted, MPI_DOUBLE, worker, tag, MPI_COMM_WORLD, &status); + MPI_Recv (&tr.dm.Phi.data()[phi_row], rows_accepted, rowvecs, worker, tag, MPI_COMM_WORLD, &status); + MPI_Recv (&tr.dm.T.data()[phi_row], rows_accepted, MPI_DOUBLE, worker, tag, MPI_COMM_WORLD, &status); + MPI_Recv (&tr.dm.Tlabels.data()[phi_row], rows_accepted, MPI_DOUBLE, worker, tag, MPI_COMM_WORLD, &status); rows_available -= rows_accepted; phi_row += rows_accepted; if (rows_available<0 ) { throw std::runtime_error(" HOST1: The number of rows in the local array is smaller than requested.");} @@ -315,9 +315,9 @@ void TadahCLI::subcommand_train() { int rows_accepted = rows_available < rows_needed ? rows_available : rows_needed; MPI_Send (&b_rank, 1, MPI_INT, worker, tag, MPI_COMM_WORLD); MPI_Send (&rows_accepted, 1, MPI_INT, worker, tag, MPI_COMM_WORLD); - MPI_Recv (&dm.Phi.data()[phi_row], rows_accepted, rowvecs, worker, tag, MPI_COMM_WORLD, &status); - MPI_Recv (&dm.T.data()[phi_row], rows_accepted, MPI_DOUBLE, worker, tag, MPI_COMM_WORLD, &status); - MPI_Recv (&dm.Tlabels.data()[phi_row], rows_accepted, MPI_DOUBLE, worker, tag, MPI_COMM_WORLD, &status); + MPI_Recv (&tr.dm.Phi.data()[phi_row], rows_accepted, rowvecs, worker, tag, MPI_COMM_WORLD, &status); + MPI_Recv (&tr.dm.T.data()[phi_row], rows_accepted, MPI_DOUBLE, worker, tag, MPI_COMM_WORLD, &status); + MPI_Recv (&tr.dm.Tlabels.data()[phi_row], rows_accepted, MPI_DOUBLE, worker, tag, MPI_COMM_WORLD, &status); rows_available -= rows_accepted; phi_row += rows_accepted; if (rows_available<0 ) { throw std::runtime_error(" HOST2: The number of rows in the local array is smaller than requested.");} @@ -391,14 +391,14 @@ void TadahCLI::subcommand_train() { MPI_Get_count(&status, MPI_DOUBLE, &arr_size); int rows_accepted = arr_size/phi_cols1; if (rows_available<rows_accepted) { throw std::runtime_error("Number of rows available is smaller than number of provided rows");} - MPI_Recv (&dm.Phi.data()[phi_row], rows_available, rowvecs, worker, tag, MPI_COMM_WORLD, &status); - MPI_Recv (&dm.T.data()[phi_row], rows_available, MPI_DOUBLE, worker, tag, MPI_COMM_WORLD, &status); - MPI_Recv (&dm.Tlabels.data()[phi_row], rows_available, MPI_DOUBLE, worker, tag, MPI_COMM_WORLD, &status); + MPI_Recv (&tr.dm.Phi.data()[phi_row], rows_available, rowvecs, worker, tag, MPI_COMM_WORLD, &status); + MPI_Recv (&tr.dm.T.data()[phi_row], rows_available, MPI_DOUBLE, worker, tag, MPI_COMM_WORLD, &status); + MPI_Recv (&tr.dm.Tlabels.data()[phi_row], rows_available, MPI_DOUBLE, worker, tag, MPI_COMM_WORLD, &status); rows_available -= rows_accepted; phi_row += rows_accepted; } else if (tag == TadahCLI::WORK_TAG) { - tr.train(status, rows_available, phi_row, phi_cols1, dm, rowvecs, worker); + tr.train(status, rows_available, phi_row, phi_cols1, rowvecs, worker); } } } @@ -452,10 +452,10 @@ void TadahCLI::subcommand_train() { dm2.T.resize(phi_rows2); dm2.Tlabels.resize(phi_rows2); - pdgemr2d_(&PHI_rows, &PHI_cols, dm.Phi.ptr(), &ione, &ione, descPHI, + pdgemr2d_(&PHI_rows, &PHI_cols, tr.dm.Phi.ptr(), &ione, &ione, descPHI, dm2.Phi.ptr(), &ione, &ione, descPHI2, &context2); - pdgemr2d_(&PHI_rows, &ione, dm.T.ptr(), &ione, &ione, descB, + pdgemr2d_(&PHI_rows, &ione, tr.dm.T.ptr(), &ione, &ione, descB, dm2.T.ptr(), &ione, &ione, descB2, &context2); // make a copy of dm.Phi and dm2.Phi