diff --git a/python/lammps.py b/python/lammps.py index 0ab218bf7e5cea784dcbb034e408c25b7ce0a3be..ee921a1f48af89a5b9cd756a50e01a2627faa9b3 100644 --- a/python/lammps.py +++ b/python/lammps.py @@ -53,6 +53,7 @@ class lammps(object): def __init__(self,name="",cmdargs=None,ptr=None,comm=None): self.comm = comm + self.opened = 0 # determine module location @@ -133,21 +134,20 @@ class lammps(object): # self.lmp = self.lib.lammps_open_no_mpi(0,None) else: - if isinstance(ptr,lammps): - # magic to convert ptr to ctypes ptr - pythonapi.PyCObject_AsVoidPtr.restype = c_void_p - pythonapi.PyCObject_AsVoidPtr.argtypes = [py_object] - self.lmp = c_void_p(pythonapi.PyCObject_AsVoidPtr(ptr)) - else: - self.lmp = None - raise TypeError('Unsupported type passed as "ptr"') + # magic to convert ptr to ctypes ptr + pythonapi.PyCObject_AsVoidPtr.restype = c_void_p + pythonapi.PyCObject_AsVoidPtr.argtypes = [py_object] + self.lmp = c_void_p(pythonapi.PyCObject_AsVoidPtr(ptr)) def __del__(self): - if self.lmp and self.opened: self.lib.lammps_close(self.lmp) + if self.lmp and self.opened: + self.lib.lammps_close(self.lmp) + self.opened = 0 def close(self): if self.opened: self.lib.lammps_close(self.lmp) self.lmp = None + self.opened = 0 def version(self): return self.lib.lammps_version(self.lmp) @@ -507,8 +507,7 @@ class PyLammps(object): elif isinstance(ptr,lammps): self.lmp = ptr else: - self.lmp = None - raise TypeError('Unsupported type passed as "ptr"') + self.lmp = lammps(name=name,cmdargs=cmdargs,ptr=ptr,comm=comm) else: self.lmp = lammps(name=name,cmdargs=cmdargs,ptr=None,comm=comm) print("LAMMPS output is captured by PyLammps wrapper")