From 93be2d264ef8b9f7746a29ad46f67ed9e3127f57 Mon Sep 17 00:00:00 2001
From: Richard Berger <richard.berger@temple.edu>
Date: Sat, 15 Jul 2017 14:52:08 -0500
Subject: [PATCH] Detect correct integer type in lammps python interface

---
 examples/python/py_integrate.py | 15 ++++++++++----
 python/lammps.py                | 36 ++++++++++++++++++++++++---------
 2 files changed, 37 insertions(+), 14 deletions(-)

diff --git a/examples/python/py_integrate.py b/examples/python/py_integrate.py
index ca45c0bcd0..5884ae19e8 100644
--- a/examples/python/py_integrate.py
+++ b/examples/python/py_integrate.py
@@ -4,9 +4,14 @@ import ctypes
 import traceback
 import numpy as np
 
-class LAMMPSIntegrator(object):
-    def __init__(self, ptr):
+class LAMMPSFix(object):
+    def __init__(self, ptr, group_name="all"):
         self.lmp = lammps.lammps(ptr=ptr)
+        self.group_name = group_name
+
+class LAMMPSIntegrator(LAMMPSFix):
+    def __init__(self, ptr, group_name="all"):
+        super(LAMMPSIntegrator, self).__init__(ptr, group_name)
 
     def init(self):
         pass
@@ -29,8 +34,9 @@ class LAMMPSIntegrator(object):
 
 class NVE(LAMMPSIntegrator):
     """ Python implementation of fix/nve """
-    def __init__(self, ptr):
+    def __init__(self, ptr, group_name="all"):
         super(NVE, self).__init__(ptr)
+        assert(self.group_name == "all")
 
     def init(self):
         dt = self.lmp.extract_global("dt", 1)
@@ -66,8 +72,9 @@ class NVE(LAMMPSIntegrator):
 
 class NVE_Opt(LAMMPSIntegrator):
     """ Tuned Python implementation of fix/nve """
-    def __init__(self, ptr):
+    def __init__(self, ptr, group_name="all"):
         super(NVE_Opt, self).__init__(ptr)
+        assert(self.group_name == "all")
 
     def init(self):
         dt = self.lmp.extract_global("dt", 1)
diff --git a/python/lammps.py b/python/lammps.py
index 944eaeabf5..a53bc431be 100644
--- a/python/lammps.py
+++ b/python/lammps.py
@@ -37,7 +37,7 @@ def get_ctypes_int(size):
     return c_int32
   elif size == 8:
     return c_int64
-  return c_int
+  return c_int 
 
 class MPIAbortException(Exception):
   def __init__(self, message):
@@ -266,25 +266,41 @@ class lammps(object):
         def __init__(self, lmp):
           self.lmp = lmp
 
+        def _ctype_to_numpy_int(self, ctype_int):
+          if ctype_int == c_int32:
+            return np.int32
+          elif ctype_int == c_int64:
+            return np.int64
+          return np.intc
+
         def extract_atom_iarray(self, name, nelem, dim=1):
+          if name in ['id', 'molecule']:
+            c_int_type = self.lmp.c_tagint
+          elif name in ['image']:
+            c_int_type = self.lmp.c_imageint
+          else:
+            c_int_type = c_int
+
+          np_int_type = self._ctype_to_numpy_int(c_int_type)
+
           if dim == 1:
-              tmp = self.lmp.extract_atom(name, 0)
-              ptr = cast(tmp, POINTER(c_int * nelem))
+            tmp = self.lmp.extract_atom(name, 0)
+            ptr = cast(tmp, POINTER(c_int_type * nelem))
           else:
-              tmp = self.lmp.extract_atom(name, 1)
-              ptr = cast(tmp[0], POINTER(c_int * nelem * dim))
+            tmp = self.lmp.extract_atom(name, 1)
+            ptr = cast(tmp[0], POINTER(c_int_type * nelem * dim))
 
-          a = np.frombuffer(ptr.contents, dtype=np.intc)
+          a = np.frombuffer(ptr.contents, dtype=np_int_type)
           a.shape = (nelem, dim)
           return a
 
         def extract_atom_darray(self, name, nelem, dim=1):
           if dim == 1:
-              tmp = self.lmp.extract_atom(name, 2)
-              ptr = cast(tmp, POINTER(c_double * nelem))
+            tmp = self.lmp.extract_atom(name, 2)
+            ptr = cast(tmp, POINTER(c_double * nelem))
           else:
-              tmp = self.lmp.extract_atom(name, 3)
-              ptr = cast(tmp[0], POINTER(c_double * nelem * dim))
+            tmp = self.lmp.extract_atom(name, 3)
+            ptr = cast(tmp[0], POINTER(c_double * nelem * dim))
 
           a = np.frombuffer(ptr.contents)
           a.shape = (nelem, dim)
-- 
GitLab