From 76d876f861652a5cd2be50dd50c634788d7a5f0e Mon Sep 17 00:00:00 2001
From: Richard Berger <richard.berger@temple.edu>
Date: Thu, 15 Sep 2016 22:11:58 -0400
Subject: [PATCH] Allow detection of MPI_Abort condition in library call

The return value of `lammps_get_last_error_message` now encodes whether the last
error was recoverable or should cause an `MPI_Abort`. The driving code is
responsible for reacting to the error and calling `MPI_Abort` on the
communicator it passed to the LAMMPS instance.
---
 python/lammps.py | 12 ++++++++++--
 src/error.cpp    | 18 ++++++++++++++----
 src/error.h      | 12 ++++++++++--
 src/library.cpp  | 20 +++++++++++++++++---
 4 files changed, 51 insertions(+), 11 deletions(-)

diff --git a/python/lammps.py b/python/lammps.py
index 7926fa5e4a..ef61ed1c96 100644
--- a/python/lammps.py
+++ b/python/lammps.py
@@ -43,6 +43,7 @@ class lammps(object):
   # create instance of LAMMPS
 
   def __init__(self,name="",cmdargs=None,ptr=None,comm=None):
+    self.comm = comm
 
     # determine module location
 
@@ -152,8 +153,15 @@ class lammps(object):
 
     if self.lib.lammps_has_error(self.lmp):
       sb = create_string_buffer(100)
-      self.lib.lammps_get_last_error_message(self.lmp, sb, 100)
-      raise Exception(sb.value.decode().strip())
+      error_type = self.lib.lammps_get_last_error_message(self.lmp, sb, 100)
+      error_msg = sb.value.decode().strip()
+
+      if error_type == 2 and lammps.has_mpi4py_v2 and self.comm != None and self.comm.Get_size() > 1:
+        print(error_msg, file=sys.stderr)
+        print("Aborting...", file=sys.stderr)
+        sys.stderr.flush()
+        self.comm.Abort()
+      raise Exception(error_msg)
 
   def extract_global(self,name,type):
     if name: name = name.encode()
diff --git a/src/error.cpp b/src/error.cpp
index 614b62d5f0..237984bfaf 100644
--- a/src/error.cpp
+++ b/src/error.cpp
@@ -22,7 +22,7 @@ using namespace LAMMPS_NS;
 
 /* ---------------------------------------------------------------------- */
 
-Error::Error(LAMMPS *lmp) : Pointers(lmp), last_error_message(NULL) {}
+Error::Error(LAMMPS *lmp) : Pointers(lmp), last_error_message(NULL), last_error_type(ERROR_NONE) {}
 
 /* ----------------------------------------------------------------------
    called by all procs in universe
@@ -208,13 +208,22 @@ char * Error::get_last_error() const
   return last_error_message;
 }
 
+/* ----------------------------------------------------------------------
+   return the type of the last error reported by LAMMPS (only used if
+   compiled with -DLAMMPS_EXCEPTIONS)
+------------------------------------------------------------------------- */
+
+ErrorType Error::get_last_error_type() const
+{
+  return last_error_type;
+}
 
 /* ----------------------------------------------------------------------
-   set the last error message (only used if compiled with
-   -DLAMMPS_EXCEPTIONS)
+   set the last error message and error type
+   (only used if compiled with -DLAMMPS_EXCEPTIONS)
 ------------------------------------------------------------------------- */
 
-void Error::set_last_error(const char * msg)
+void Error::set_last_error(const char * msg, ErrorType type)
 {
   delete [] last_error_message;
 
@@ -224,4 +233,5 @@ void Error::set_last_error(const char * msg)
   } else {
     last_error_message = NULL;
   }
+  last_error_type = type;
 }
diff --git a/src/error.h b/src/error.h
index 0ca9440572..0fa3475d1e 100644
--- a/src/error.h
+++ b/src/error.h
@@ -47,8 +47,15 @@ public:
   }
 };
 
+enum ErrorType {
+   ERROR_NONE   = 0,  // no error pending (initial state; restored once the message has been fetched)
+   ERROR_NORMAL = 1,  // recoverable error
+   ERROR_ABORT  = 2   // non-recoverable error; the driving code is expected to call MPI_Abort
+};
+
 class Error : protected Pointers {
   char * last_error_message;
+  ErrorType last_error_type;
 
  public:
   Error(class LAMMPS *);
@@ -63,8 +70,9 @@ class Error : protected Pointers {
   void message(const char *, int, const char *, int = 1);
   void done(int = 0); // 1 would be fully backwards compatible
 
-  char * get_last_error() const;
-  void   set_last_error(const char * msg);
+  char *    get_last_error() const;
+  ErrorType get_last_error_type() const;
+  void   set_last_error(const char * msg, ErrorType type = ERROR_NORMAL);
 };
 
 }
diff --git a/src/library.cpp b/src/library.cpp
index 8aa1d9978e..4d493c979b 100644
--- a/src/library.cpp
+++ b/src/library.cpp
@@ -113,8 +113,18 @@ char *lammps_command(void *ptr, char *str)
 
   try {
     return lmp->input->one(str);
+  } catch(LAMMPSAbortException & ae) {
+    int nprocs = 0;
+    MPI_Comm_size(ae.universe, &nprocs );
+
+    if (nprocs > 1) {
+      error->set_last_error(ae.message.c_str(), ERROR_ABORT);
+    } else {
+      error->set_last_error(ae.message.c_str(), ERROR_NORMAL);
+    }
+    return NULL;
   } catch(LAMMPSException & e) {
-    error->set_last_error(e.message.c_str());
+    error->set_last_error(e.message.c_str(), ERROR_NORMAL);
     return NULL;
   }
 }
@@ -613,6 +623,9 @@ int lammps_has_error(void *ptr) {
 
 /* ----------------------------------------------------------------------
    Copy the last error message of LAMMPS into a character buffer
+   The return value encodes which type of error it is:
+   0 = no error pending, 1 = normal error (recoverable),
+   2 = abort error (non-recoverable)
 ------------------------------------------------------------------------- */
 
 int lammps_get_last_error_message(void *ptr, char * buffer, int buffer_size) {
@@ -620,9 +633,10 @@ int lammps_get_last_error_message(void *ptr, char * buffer, int buffer_size) {
   Error * error = lmp->error;
 
   if(error->get_last_error()) {
+    int error_type = error->get_last_error_type();
     strncpy(buffer, error->get_last_error(), buffer_size-1);
-    error->set_last_error(NULL);
-    return 1;
+    error->set_last_error(NULL, ERROR_NONE);
+    return error_type;
   }
   return 0;
 }
-- 
GitLab