From 93be2d264ef8b9f7746a29ad46f67ed9e3127f57 Mon Sep 17 00:00:00 2001 From: Richard Berger <richard.berger@temple.edu> Date: Sat, 15 Jul 2017 14:52:08 -0500 Subject: [PATCH] Detect correct integer type in lammps python interface --- examples/python/py_integrate.py | 15 ++++++++++---- python/lammps.py | 36 ++++++++++++++++++++++++--------- 2 files changed, 37 insertions(+), 14 deletions(-) diff --git a/examples/python/py_integrate.py b/examples/python/py_integrate.py index ca45c0bcd0..5884ae19e8 100644 --- a/examples/python/py_integrate.py +++ b/examples/python/py_integrate.py @@ -4,9 +4,14 @@ import ctypes import traceback import numpy as np -class LAMMPSIntegrator(object): - def __init__(self, ptr): +class LAMMPSFix(object): + def __init__(self, ptr, group_name="all"): self.lmp = lammps.lammps(ptr=ptr) + self.group_name = group_name + +class LAMMPSIntegrator(LAMMPSFix): + def __init__(self, ptr, group_name="all"): + super(LAMMPSIntegrator, self).__init__(ptr, group_name) def init(self): pass @@ -29,8 +34,9 @@ class LAMMPSIntegrator(object): class NVE(LAMMPSIntegrator): """ Python implementation of fix/nve """ - def __init__(self, ptr): + def __init__(self, ptr, group_name="all"): super(NVE, self).__init__(ptr) + assert(self.group_name == "all") def init(self): dt = self.lmp.extract_global("dt", 1) @@ -66,8 +72,9 @@ class NVE(LAMMPSIntegrator): class NVE_Opt(LAMMPSIntegrator): """ Tuned Python implementation of fix/nve """ - def __init__(self, ptr): + def __init__(self, ptr, group_name="all"): super(NVE_Opt, self).__init__(ptr) + assert(self.group_name == "all") def init(self): dt = self.lmp.extract_global("dt", 1) diff --git a/python/lammps.py b/python/lammps.py index 944eaeabf5..a53bc431be 100644 --- a/python/lammps.py +++ b/python/lammps.py @@ -37,7 +37,7 @@ def get_ctypes_int(size): return c_int32 elif size == 8: return c_int64 - return c_int + return c_int class MPIAbortException(Exception): def __init__(self, message): @@ -266,25 +266,41 @@ class lammps(object): def __init__(self, lmp): self.lmp = lmp + def _ctype_to_numpy_int(self, ctype_int): + if ctype_int == c_int32: + return np.int32 + elif ctype_int == c_int64: + return np.int64 + return np.intc + def extract_atom_iarray(self, name, nelem, dim=1): + if name in ['id', 'molecule']: + c_int_type = self.lmp.c_tagint + elif name in ['image']: + c_int_type = self.lmp.c_imageint + else: + c_int_type = c_int + + np_int_type = self._ctype_to_numpy_int(c_int_type) + if dim == 1: - tmp = self.lmp.extract_atom(name, 0) - ptr = cast(tmp, POINTER(c_int * nelem)) + tmp = self.lmp.extract_atom(name, 0) + ptr = cast(tmp, POINTER(c_int_type * nelem)) else: - tmp = self.lmp.extract_atom(name, 1) - ptr = cast(tmp[0], POINTER(c_int * nelem * dim)) + tmp = self.lmp.extract_atom(name, 1) + ptr = cast(tmp[0], POINTER(c_int_type * nelem * dim)) - a = np.frombuffer(ptr.contents, dtype=np.intc) + a = np.frombuffer(ptr.contents, dtype=np_int_type) a.shape = (nelem, dim) return a def extract_atom_darray(self, name, nelem, dim=1): if dim == 1: - tmp = self.lmp.extract_atom(name, 2) - ptr = cast(tmp, POINTER(c_double * nelem)) + tmp = self.lmp.extract_atom(name, 2) + ptr = cast(tmp, POINTER(c_double * nelem)) else: - tmp = self.lmp.extract_atom(name, 3) - ptr = cast(tmp[0], POINTER(c_double * nelem * dim)) + tmp = self.lmp.extract_atom(name, 3) + ptr = cast(tmp[0], POINTER(c_double * nelem * dim)) a = np.frombuffer(ptr.contents) a.shape = (nelem, dim) -- GitLab