From 031077b4fa2c62d59da6720b68c7dd633eb87377 Mon Sep 17 00:00:00 2001
From: Stefan Paquay <stefanpaquay@gmail.com>
Date: Fri, 1 Jun 2018 17:19:53 -0400
Subject: [PATCH] Made enforce2d also set rotations to in-plane.

---
 src/KOKKOS/fix_enforce2d_kokkos.cpp | 106 ++++++++++++++++++++++++----
 src/KOKKOS/fix_enforce2d_kokkos.h   |  15 ++--
 2 files changed, 103 insertions(+), 18 deletions(-)

diff --git a/src/KOKKOS/fix_enforce2d_kokkos.cpp b/src/KOKKOS/fix_enforce2d_kokkos.cpp
index b5fb964ea8..88291ead6e 100644
--- a/src/KOKKOS/fix_enforce2d_kokkos.cpp
+++ b/src/KOKKOS/fix_enforce2d_kokkos.cpp
@@ -12,13 +12,16 @@
 ------------------------------------------------------------------------- */
 
 /* ----------------------------------------------------------------------
-   Contributing authors: Stefan Paquay (Brandeis University)
+   Contributing authors: Stefan Paquay & Matthew Peterson (Brandeis University)
 ------------------------------------------------------------------------- */
 
 #include "atom_masks.h"
 #include "atom_kokkos.h"
+#include "comm.h"
+#include "error.h"
 #include "fix_enforce2d_kokkos.h"
 
+
 using namespace LAMMPS_NS;
 
 
@@ -30,14 +33,21 @@ FixEnforce2DKokkos<DeviceType>::FixEnforce2DKokkos(LAMMPS *lmp, int narg, char *
   atomKK = (AtomKokkos *) atom;
   execution_space = ExecutionSpaceFromDevice<DeviceType>::space;
 
-  datamask_read   = X_MASK | V_MASK | F_MASK | MASK_MASK;
-  datamask_modify = X_MASK | V_MASK | F_MASK;
+  datamask_read   = X_MASK | V_MASK | F_MASK | OMEGA_MASK | MASK_MASK;
+  /* TORQUE_MASK | ANGMOM_MASK | */ // MASK_MASK;
+
+  datamask_modify = X_MASK | V_MASK | F_MASK | OMEGA_MASK; // |
+	  /* TORQUE_MASK | ANGMOM_MASK */ ;
 }
 
 
 template <class DeviceType>
 void FixEnforce2DKokkos<DeviceType>::setup(int vflag)
 {
+  if( comm->me == 0 ){
+    fprintf(screen, "omega, angmom and torque flags are %d, %d, %d\n",
+            atomKK->omega_flag, atomKK->angmom_flag, atomKK->torque_flag );
+  }
   post_force(vflag);
 }
 
@@ -52,13 +62,71 @@ void FixEnforce2DKokkos<DeviceType>::post_force(int vflag)
   v = atomKK->k_v.view<DeviceType>();
   f = atomKK->k_f.view<DeviceType>();
 
+  if( atomKK->omega_flag )
+    omega  = atomKK->k_omega.view<DeviceType>();
+
+  if( atomKK->angmom_flag )
+    angmom = atomKK->k_angmom.view<DeviceType>();
+
+  if( atomKK->torque_flag )
+    torque = atomKK->k_torque.view<DeviceType>();
+
+
   mask = atomKK->k_mask.view<DeviceType>();
 
   int nlocal = atomKK->nlocal;
   if (igroup == atomKK->firstgroup) nlocal = atomKK->nfirst;
 
-  FixEnforce2DKokkosPostForceFunctor<DeviceType> functor(this);
-  Kokkos::parallel_for(nlocal,functor);
+  int flag_mask = 0;
+  if( atomKK->omega_flag ) flag_mask  |= 1;
+  if( atomKK->angmom_flag ) flag_mask |= 2;
+  if( atomKK->torque_flag ) flag_mask |= 4;
+
+  switch( flag_mask ){
+    case 0:{
+      FixEnforce2DKokkosPostForceFunctor<DeviceType,0,0,0> functor(this);
+      Kokkos::parallel_for(nlocal,functor);
+      break;
+    }
+    case 1:{
+      FixEnforce2DKokkosPostForceFunctor<DeviceType,1,0,0> functor(this);
+      Kokkos::parallel_for(nlocal,functor);
+      break;
+    }
+    case 2:{
+      FixEnforce2DKokkosPostForceFunctor<DeviceType,0,1,0> functor(this);
+      Kokkos::parallel_for(nlocal,functor);
+      break;
+    }
+    case 3:{
+      FixEnforce2DKokkosPostForceFunctor<DeviceType,1,1,0> functor(this);
+      Kokkos::parallel_for(nlocal,functor);
+      break;
+    }
+    case 4:{
+      FixEnforce2DKokkosPostForceFunctor<DeviceType,0,0,1> functor(this);
+      Kokkos::parallel_for(nlocal,functor);
+      break;
+    }
+    case 5:{
+      FixEnforce2DKokkosPostForceFunctor<DeviceType,1,0,1> functor(this);
+      Kokkos::parallel_for(nlocal,functor);
+      break;
+    }
+    case 6:{
+      FixEnforce2DKokkosPostForceFunctor<DeviceType,0,1,1> functor(this);
+      Kokkos::parallel_for(nlocal,functor);
+      break;
+    }
+    case 7:{
+      FixEnforce2DKokkosPostForceFunctor<DeviceType,1,1,1> functor(this);
+      Kokkos::parallel_for(nlocal,functor);
+      break;
+    }
+    default:
+      error->all(FLERR, "flag_mask outside of what it should be");
+  }
+
 
   // Probably sync here again?
   atomKK->sync(execution_space,datamask_read);
@@ -66,23 +134,33 @@ void FixEnforce2DKokkos<DeviceType>::post_force(int vflag)
 
   for (int m = 0; m < nfixlist; m++)
     flist[m]->enforce2d();
-
-
 }
 
 
 template <class DeviceType>
+template <int omega_flag, int angmom_flag, int torque_flag>
 void FixEnforce2DKokkos<DeviceType>::post_force_item( int i ) const
 {
-
   if (mask[i] & groupbit){
-    v(i,2) = 0;
-    x(i,2) = 0;
-    f(i,2) = 0;
-
-    // Add for omega, angmom, torque...
+    // x(i,2) = 0; // Enforce2d does not set x[2] to zero either... :/
+    v(i,2) = 0.0;
+    f(i,2) = 0.0;
+
+    if(omega_flag){
+      omega(i,0) = 0.0;
+      omega(i,1) = 0.0;
+    }
+
+    if(angmom_flag){
+      angmom(i,0) = 0.0;
+      angmom(i,1) = 0.0;
+    }
+
+    if(torque_flag){
+      torque(i,0) = 0.0;
+      torque(i,1) = 0.0;
+    }
   }
-
 }
 
 
diff --git a/src/KOKKOS/fix_enforce2d_kokkos.h b/src/KOKKOS/fix_enforce2d_kokkos.h
index 11cb213210..4130797f2c 100644
--- a/src/KOKKOS/fix_enforce2d_kokkos.h
+++ b/src/KOKKOS/fix_enforce2d_kokkos.h
@@ -37,8 +37,9 @@ class FixEnforce2DKokkos : public FixEnforce2D {
   void setup(int);
   void post_force(int);
 
+  template <int omega_flag, int angmom_flag, int torque_flag>
   KOKKOS_INLINE_FUNCTION
-  void post_force_item(int) const;
+  void post_force_item(const int i) const;
 
   // void min_setup(int);       Kokkos does not support minimization (yet)
   // void min_post_force(int);  Kokkos does not support minimization (yet)
@@ -50,20 +51,26 @@ class FixEnforce2DKokkos : public FixEnforce2D {
   typename ArrayTypes<DeviceType>::t_v_array v;
   typename ArrayTypes<DeviceType>::t_f_array f;
 
+  typename ArrayTypes<DeviceType>::t_v_array omega;
+  typename ArrayTypes<DeviceType>::t_v_array angmom;
+  typename ArrayTypes<DeviceType>::t_f_array torque;
+
   typename ArrayTypes<DeviceType>::t_int_1d mask;
 };
 
 
-template <class DeviceType>
-struct FixEnforce2DKokkosPostForceFunctor  {
+template <class DeviceType, int omega_flag, int angmom_flag, int torque_flag>
+struct FixEnforce2DKokkosPostForceFunctor {
   typedef DeviceType device_type;
   FixEnforce2DKokkos<DeviceType> c;
 
   FixEnforce2DKokkosPostForceFunctor(FixEnforce2DKokkos<DeviceType>* c_ptr):
     c(*c_ptr) {c.cleanup_copy();};
+
   KOKKOS_INLINE_FUNCTION
   void operator()(const int i) const {
-    c.post_force_item(i);
+    // c.template? Really C++?
+    c.template post_force_item <omega_flag, angmom_flag, torque_flag>(i);
   }
 };
 
-- 
GitLab