diff --git a/examples/python/py_integrate.py b/examples/python/py_integrate.py index ca45c0bcd0a9e847d6857530445aa024af085bc8..5884ae19e8ced0ab169265c250821c6298c42e0f 100644 --- a/examples/python/py_integrate.py +++ b/examples/python/py_integrate.py @@ -4,9 +4,14 @@ import ctypes import traceback import numpy as np -class LAMMPSIntegrator(object): - def __init__(self, ptr): +class LAMMPSFix(object): + def __init__(self, ptr, group_name="all"): self.lmp = lammps.lammps(ptr=ptr) + self.group_name = group_name + +class LAMMPSIntegrator(LAMMPSFix): + def __init__(self, ptr, group_name="all"): + super(LAMMPSIntegrator, self).__init__(ptr, group_name) def init(self): pass @@ -29,8 +34,9 @@ class LAMMPSIntegrator(object): class NVE(LAMMPSIntegrator): """ Python implementation of fix/nve """ - def __init__(self, ptr): + def __init__(self, ptr, group_name="all"): super(NVE, self).__init__(ptr) + assert(self.group_name == "all") def init(self): dt = self.lmp.extract_global("dt", 1) @@ -66,8 +72,9 @@ class NVE(LAMMPSIntegrator): class NVE_Opt(LAMMPSIntegrator): """ Tuned Python implementation of fix/nve """ - def __init__(self, ptr): + def __init__(self, ptr, group_name="all"): super(NVE_Opt, self).__init__(ptr) + assert(self.group_name == "all") def init(self): dt = self.lmp.extract_global("dt", 1) diff --git a/python/lammps.py b/python/lammps.py index 944eaeabf5a982afddfb79cc8a9f4f286c8aa94e..a53bc431be3c4c841f596ccdf4dc0f1998425d54 100644 --- a/python/lammps.py +++ b/python/lammps.py @@ -37,7 +37,7 @@ def get_ctypes_int(size): return c_int32 elif size == 8: return c_int64 - return c_int + return c_int class MPIAbortException(Exception): def __init__(self, message): @@ -266,25 +266,41 @@ class lammps(object): def __init__(self, lmp): self.lmp = lmp + def _ctype_to_numpy_int(self, ctype_int): + if ctype_int == c_int32: + return np.int32 + elif ctype_int == c_int64: + return np.int64 + return np.intc + def extract_atom_iarray(self, name, nelem, dim=1): + if name in ['id', 'molecule']: + c_int_type = self.lmp.c_tagint + elif name in ['image']: + c_int_type = self.lmp.c_imageint + else: + c_int_type = c_int + + np_int_type = self._ctype_to_numpy_int(c_int_type) + if dim == 1: - tmp = self.lmp.extract_atom(name, 0) - ptr = cast(tmp, POINTER(c_int * nelem)) + tmp = self.lmp.extract_atom(name, 0) + ptr = cast(tmp, POINTER(c_int_type * nelem)) else: - tmp = self.lmp.extract_atom(name, 1) - ptr = cast(tmp[0], POINTER(c_int * nelem * dim)) + tmp = self.lmp.extract_atom(name, 1) + ptr = cast(tmp[0], POINTER(c_int_type * nelem * dim)) - a = np.frombuffer(ptr.contents, dtype=np.intc) + a = np.frombuffer(ptr.contents, dtype=np_int_type) a.shape = (nelem, dim) return a def extract_atom_darray(self, name, nelem, dim=1): if dim == 1: - tmp = self.lmp.extract_atom(name, 2) - ptr = cast(tmp, POINTER(c_double * nelem)) + tmp = self.lmp.extract_atom(name, 2) + ptr = cast(tmp, POINTER(c_double * nelem)) else: - tmp = self.lmp.extract_atom(name, 3) - ptr = cast(tmp[0], POINTER(c_double * nelem * dim)) + tmp = self.lmp.extract_atom(name, 3) + ptr = cast(tmp[0], POINTER(c_double * nelem * dim)) a = np.frombuffer(ptr.contents) a.shape = (nelem, dim)