From f8ef1f4d4169a10fa88dc9c3b6d546843f8eb031 Mon Sep 17 00:00:00 2001
From: Marcin Kirsz <mkirsz@ed.ac.uk>
Date: Fri, 18 Oct 2024 12:41:34 +0100
Subject: [PATCH] Update

---
Renames MPI_Trainer_HOST to TrainerHost and Trainer_WORKER to
TrainerWorker, and adds TrainerHost::run().

NOTE(review): line breaks and indentation in this patch were
reconstructed from a whitespace-collapsed copy, and the blob ids on the
"index" line predate the restyled run(). Confirm the context lines
against trainer.h before applying (e.g. apply with --ignore-whitespace).

 trainer.h | 48 ++++++++++++++++++++++++++++++++++++++++++++----
 1 file changed, 44 insertions(+), 4 deletions(-)

diff --git a/trainer.h b/trainer.h
index a783760..3f72c8c 100644
--- a/trainer.h
+++ b/trainer.h
@@ -295,12 +295,12 @@ class MPI_Trainer: public Trainer {
     }
 };
 
-class MPI_Trainer_HOST: public MPI_Trainer {
+class TrainerHost: public MPI_Trainer {
   private:
     std::vector<std::tuple<std::string,int,int>> wpckgs;
 
   public:
-    MPI_Trainer_HOST(Config &c, int &rank, int &ncpu):
+    TrainerHost(Config &c, int &rank, int &ncpu):
       MPI_Trainer(c, rank, ncpu) {}
 
 
@@ -421,12 +421,52 @@ class MPI_Trainer_HOST: public MPI_Trainer {
 
       MPI_Send (&ready, 1, MPI_INT, worker, MPI_Trainer::CONFIG_TAG, MPI_COMM_WORLD);
     }
+    /**
+     * Host-side driver: serve worker requests, then release workers.
+     *
+     * Stage 1 runs while has_packages() is true: after each probe(),
+     * WORK_TAG goes to work_tag(), DATA_TAG goes to data_tag(), and any
+     * other tag throws std::runtime_error. Stage 2 keeps probing:
+     * DATA_TAG still goes to data_tag(), every other tag goes to
+     * release_tag(), and the loop ends once the counter equals ncpu.
+     */
+    void run() {
+      prep_wpckgs();
+
+      // Stage 1: answer worker requests for as long as packages remain.
+      while (has_packages()) {
+        probe();
+        if (tag == MPI_Trainer::WORK_TAG) {
+          work_tag();
+          continue;
+        }
+        if (tag == MPI_Trainer::DATA_TAG) {
+          data_tag();
+          continue;
+        }
+        throw std::runtime_error("HOST: Unexpected request from "
+            + std::to_string(worker));
+      }
+
+      // Stage 2: collect remaining data and release the workers.
+      // The counter starts at 1 and is handed to release_tag(), which
+      // presumably advances it per released worker -- TODO confirm.
+      int released = 1;
+      do {
+        probe();
+        if (tag == MPI_Trainer::DATA_TAG) {
+          data_tag();
+        } else {
+          release_tag(released);
+        }
+      } while (released != ncpu);
+    }
 };
 
-class Trainer_WORKER: public MPI_Trainer {
+class TrainerWorker: public MPI_Trainer {
 
   public:
-    Trainer_WORKER(Config &c, int &rank, int &ncpu):
+    TrainerWorker(Config &c, int &rank, int &ncpu):
       MPI_Trainer(c, rank, ncpu) {}
 
 
-- 
GitLab