diff --git a/src/error.cpp b/src/error.cpp index 6989e00034cc9b96af9aefe872484483d19c04f9..614b62d5f0e9754001a6063c5aea50d093b380db 100644 --- a/src/error.cpp +++ b/src/error.cpp @@ -13,6 +13,7 @@ #include <mpi.h> #include <stdlib.h> +#include <string.h> #include "error.h" #include "universe.h" #include "output.h" @@ -21,7 +22,7 @@ using namespace LAMMPS_NS; /* ---------------------------------------------------------------------- */ -Error::Error(LAMMPS *lmp) : Pointers(lmp) {} +Error::Error(LAMMPS *lmp) : Pointers(lmp), last_error_message(NULL) {} /* ---------------------------------------------------------------------- called by all procs in universe @@ -47,8 +48,14 @@ void Error::universe_all(const char *file, int line, const char *str) } if (universe->ulogfile) fclose(universe->ulogfile); +#ifdef LAMMPS_EXCEPTIONS + char msg[100]; + sprintf(msg, "ERROR: %s (%s:%d)\n", str, file, line); + throw LAMMPSException(msg); +#else MPI_Finalize(); exit(1); +#endif } /* ---------------------------------------------------------------------- @@ -61,7 +68,14 @@ void Error::universe_one(const char *file, int line, const char *str) if (universe->uscreen) fprintf(universe->uscreen,"ERROR on proc %d: %s (%s:%d)\n", universe->me,str,file,line); + +#ifdef LAMMPS_EXCEPTIONS + char msg[100]; + sprintf(msg, "ERROR: %s (%s:%d)\n", str, file, line); + throw LAMMPSAbortException(msg, universe->uworld); +#else MPI_Abort(universe->uworld,1); +#endif } /* ---------------------------------------------------------------------- @@ -95,6 +109,16 @@ void Error::all(const char *file, int line, const char *str) if (logfile) fprintf(logfile,"ERROR: %s (%s:%d)\n",str,file,line); } +#ifdef LAMMPS_EXCEPTIONS + char msg[100]; + sprintf(msg, "ERROR: %s (%s:%d)\n", str, file, line); + + if (universe->nworlds > 1) { + throw LAMMPSAbortException(msg, universe->uworld); + } + + throw LAMMPSException(msg); +#else if (output) delete output; if (screen && screen != stdout) fclose(screen); if (logfile) fclose(logfile); @@ -102,6 +126,7 @@ void Error::all(const char *file, int line, const char *str) if (universe->nworlds > 1) MPI_Abort(universe->uworld,1); MPI_Finalize(); exit(1); +#endif } /* ---------------------------------------------------------------------- @@ -121,7 +146,14 @@ void Error::one(const char *file, int line, const char *str) if (universe->uscreen) fprintf(universe->uscreen,"ERROR on proc %d: %s (%s:%d)\n", universe->me,str,file,line); + +#ifdef LAMMPS_EXCEPTIONS + char msg[100]; + sprintf(msg, "ERROR on proc %d: %s (%s:%d)\n", me, str, file, line); + throw LAMMPSAbortException(msg, world); +#else MPI_Abort(world,1); +#endif } /* ---------------------------------------------------------------------- @@ -165,3 +197,31 @@ void Error::done(int status) MPI_Finalize(); exit(status); } + +/* ---------------------------------------------------------------------- + return the last error message reported by LAMMPS (only used if + compiled with -DLAMMPS_EXCEPTIONS) +------------------------------------------------------------------------- */ + +char * Error::get_last_error() const +{ + return last_error_message; +} + + +/* ---------------------------------------------------------------------- + set the last error message (only used if compiled with + -DLAMMPS_EXCEPTIONS) +------------------------------------------------------------------------- */ + +void Error::set_last_error(const char * msg) +{ + delete [] last_error_message; + + if(msg) { + last_error_message = new char[strlen(msg)+1]; + strcpy(last_error_message, msg); + } else { + last_error_message = NULL; + } +} diff --git a/src/error.h b/src/error.h index fb44c966dbdc9c2f31d3e4a00cc332dc4b476e53..0ca94405720ccbaec0e7267391c843d8d289c893 100644 --- a/src/error.h +++ b/src/error.h @@ -15,10 +15,41 @@ #define LMP_ERROR_H #include "pointers.h" +#include <string> +#include <exception> namespace LAMMPS_NS { +class LAMMPSException : public std::exception +{ +public: + std::string message; + + LAMMPSException(std::string msg) : message(msg) { + } + + ~LAMMPSException() throw() { + } + + virtual const char * what() const throw() { + return message.c_str(); + } +}; + +class LAMMPSAbortException : public LAMMPSException { +public: + MPI_Comm universe; + + LAMMPSAbortException(std::string msg, MPI_Comm universe) : + LAMMPSException(msg), + universe(universe) + { + } +}; + class Error : protected Pointers { + char * last_error_message; + public: Error(class LAMMPS *); @@ -31,6 +62,9 @@ class Error : protected Pointers { void warning(const char *, int, const char *, int = 1); void message(const char *, int, const char *, int = 1); void done(int = 0); // 1 would be fully backwards compatible + + char * get_last_error() const; + void set_last_error(const char * msg); }; } diff --git a/src/library.cpp b/src/library.cpp index be407c8bbc5ab61dfd1a7211784ac990f9782b6b..8aa1d9978e6abfed8c5eef4281215df724f6eeba 100644 --- a/src/library.cpp +++ b/src/library.cpp @@ -108,8 +108,15 @@ void lammps_file(void *ptr, char *str) char *lammps_command(void *ptr, char *str) { - LAMMPS *lmp = (LAMMPS *) ptr; - return lmp->input->one(str); + LAMMPS * lmp = (LAMMPS *) ptr; + Error * error = lmp->error; + + try { + return lmp->input->one(str); + } catch(LAMMPSException & e) { + error->set_last_error(e.message.c_str()); + return NULL; + } } /* ---------------------------------------------------------------------- @@ -593,3 +600,29 @@ void lammps_scatter_atoms(void *ptr, char *name, } } } + +/* ---------------------------------------------------------------------- + Check if a new error message +------------------------------------------------------------------------- */ + +int lammps_has_error(void *ptr) { + LAMMPS * lmp = (LAMMPS *) ptr; + Error * error = lmp->error; + return error->get_last_error() ? 1 : 0; +} + +/* ---------------------------------------------------------------------- + Copy the last error message of LAMMPS into a character buffer +------------------------------------------------------------------------- */ + +int lammps_get_last_error_message(void *ptr, char * buffer, int buffer_size) { + LAMMPS * lmp = (LAMMPS *) ptr; + Error * error = lmp->error; + + if(error->get_last_error()) { + strncpy(buffer, error->get_last_error(), buffer_size-1); + error->set_last_error(NULL); + return 1; + } + return 0; +} diff --git a/src/library.h b/src/library.h index 8b7dc43f084a7bd72d7bc45c1ce146f27d23a2b2..5cb128fdb99b0b769792387e76ee031de6ccb437 100644 --- a/src/library.h +++ b/src/library.h @@ -45,6 +45,9 @@ int lammps_get_natoms(void *); void lammps_gather_atoms(void *, char *, int, int, void *); void lammps_scatter_atoms(void *, char *, int, int, void *); +int lammps_has_error(void *); +int lammps_get_last_error_message(void *, char *, int); + #ifdef __cplusplus } #endif diff --git a/src/main.cpp b/src/main.cpp index fdf7a791d0c6229c8ffd33be31aea9d878941ff6..cc8f8be906b819d6caa5d1506cea3ec97dadccdf 100644 --- a/src/main.cpp +++ b/src/main.cpp @@ -14,7 +14,9 @@ #include <mpi.h> #include "lammps.h" #include "input.h" +#include "error.h" #include <stdio.h> +#include <stdlib.h> using namespace LAMMPS_NS; @@ -26,11 +28,22 @@ int main(int argc, char **argv) { MPI_Init(&argc,&argv); +#ifdef LAMMPS_EXCEPTIONS + try { + LAMMPS *lammps = new LAMMPS(argc,argv,MPI_COMM_WORLD); + lammps->input->file(); + delete lammps; + } catch(LAMMPSAbortException & ae) { + MPI_Abort(ae.universe, 1); + } catch(LAMMPSException & e) { + MPI_Finalize(); + exit(1); + } +#else LAMMPS *lammps = new LAMMPS(argc,argv,MPI_COMM_WORLD); - lammps->input->file(); delete lammps; - +#endif MPI_Barrier(MPI_COMM_WORLD); MPI_Finalize(); }