From 76d876f861652a5cd2be50dd50c634788d7a5f0e Mon Sep 17 00:00:00 2001 From: Richard Berger <richard.berger@temple.edu> Date: Thu, 15 Sep 2016 22:11:58 -0400 Subject: [PATCH] Allow detection of MPI_Abort condition in library call The return value of `lammps_get_last_error_message` now encodes if the last error was recoverable or should cause an `MPI_Abort`. The driving code is responsible of reacting to the error and calling `MPI_Abort` on the communicator it passed to the LAMMPS instance. --- python/lammps.py | 12 ++++++++++-- src/error.cpp | 18 ++++++++++++++---- src/error.h | 12 ++++++++++-- src/library.cpp | 20 +++++++++++++++++--- 4 files changed, 51 insertions(+), 11 deletions(-) diff --git a/python/lammps.py b/python/lammps.py index 7926fa5e4a..ef61ed1c96 100644 --- a/python/lammps.py +++ b/python/lammps.py @@ -43,6 +43,7 @@ class lammps(object): # create instance of LAMMPS def __init__(self,name="",cmdargs=None,ptr=None,comm=None): + self.comm = comm # determine module location @@ -152,8 +153,15 @@ class lammps(object): if self.lib.lammps_has_error(self.lmp): sb = create_string_buffer(100) - self.lib.lammps_get_last_error_message(self.lmp, sb, 100) - raise Exception(sb.value.decode().strip()) + error_type = self.lib.lammps_get_last_error_message(self.lmp, sb, 100) + error_msg = sb.value.decode().strip() + + if error_type == 2 and lammps.has_mpi4py_v2 and self.comm != None and self.comm.Get_size() > 1: + print(error_msg, file=sys.stderr) + print("Aborting...", file=sys.stderr) + sys.stderr.flush() + self.comm.Abort() + raise Exception(error_msg) def extract_global(self,name,type): if name: name = name.encode() diff --git a/src/error.cpp b/src/error.cpp index 614b62d5f0..237984bfaf 100644 --- a/src/error.cpp +++ b/src/error.cpp @@ -22,7 +22,7 @@ using namespace LAMMPS_NS; /* ---------------------------------------------------------------------- */ -Error::Error(LAMMPS *lmp) : Pointers(lmp), last_error_message(NULL) {} +Error::Error(LAMMPS *lmp) : Pointers(lmp), last_error_message(NULL), last_error_type(ERROR_NONE) {} /* ---------------------------------------------------------------------- called by all procs in universe @@ -208,13 +208,22 @@ char * Error::get_last_error() const return last_error_message; } +/* ---------------------------------------------------------------------- + return the type of the last error reported by LAMMPS (only used if + compiled with -DLAMMPS_EXCEPTIONS) +------------------------------------------------------------------------- */ + +ErrorType Error::get_last_error_type() const +{ + return last_error_type; +} /* ---------------------------------------------------------------------- - set the last error message (only used if compiled with - -DLAMMPS_EXCEPTIONS) + set the last error message and error type + (only used if compiled with -DLAMMPS_EXCEPTIONS) ------------------------------------------------------------------------- */ -void Error::set_last_error(const char * msg) +void Error::set_last_error(const char * msg, ErrorType type) { delete [] last_error_message; @@ -224,4 +233,5 @@ void Error::set_last_error(const char * msg) } else { last_error_message = NULL; } + last_error_type = type; } diff --git a/src/error.h b/src/error.h index 0ca9440572..0fa3475d1e 100644 --- a/src/error.h +++ b/src/error.h @@ -47,8 +47,15 @@ public: } }; +enum ErrorType { + ERROR_NONE = 0, + ERROR_NORMAL = 1, + ERROR_ABORT = 2 +}; + class Error : protected Pointers { char * last_error_message; + ErrorType last_error_type; public: Error(class LAMMPS *); @@ -63,8 +70,9 @@ class Error : protected Pointers { void message(const char *, int, const char *, int = 1); void done(int = 0); // 1 would be fully backwards compatible - char * get_last_error() const; - void set_last_error(const char * msg); + char * get_last_error() const; + ErrorType get_last_error_type() const; + void set_last_error(const char * msg, ErrorType type = ERROR_NORMAL); }; } diff --git a/src/library.cpp b/src/library.cpp index 8aa1d9978e..4d493c979b 100644 --- a/src/library.cpp +++ b/src/library.cpp @@ -113,8 +113,18 @@ char *lammps_command(void *ptr, char *str) try { return lmp->input->one(str); + } catch(LAMMPSAbortException & ae) { + int nprocs = 0; + MPI_Comm_size(ae.universe, &nprocs ); + + if (nprocs > 1) { + error->set_last_error(ae.message.c_str(), ERROR_ABORT); + } else { + error->set_last_error(ae.message.c_str(), ERROR_NORMAL); + } + return NULL; } catch(LAMMPSException & e) { - error->set_last_error(e.message.c_str()); + error->set_last_error(e.message.c_str(), ERROR_NORMAL); return NULL; } } @@ -613,6 +623,9 @@ int lammps_has_error(void *ptr) { /* ---------------------------------------------------------------------- Copy the last error message of LAMMPS into a character buffer + The return value encodes which type of error it is. + 1 = normal error (recoverable) + 2 = abort error (non-recoverable) ------------------------------------------------------------------------- */ int lammps_get_last_error_message(void *ptr, char * buffer, int buffer_size) { @@ -620,9 +633,10 @@ int lammps_get_last_error_message(void *ptr, char * buffer, int buffer_size) { Error * error = lmp->error; if(error->get_last_error()) { + int error_type = error->get_last_error_type(); strncpy(buffer, error->get_last_error(), buffer_size-1); - error->set_last_error(NULL); - return 1; + error->set_last_error(NULL, ERROR_NONE); + return error_type; } return 0; } -- GitLab