From e436e176b229d0a21f001771c7e702b7f90f2b10 Mon Sep 17 00:00:00 2001
From: mkirsz <s1351949@sms.ed.ac.uk>
Date: Mon, 24 Feb 2025 11:13:15 +0000
Subject: [PATCH] wip

---
 .../tadah/mlip/design_matrix/design_matrix.h  |  14 ++-
 .../functions/basis_functions/dm_bf_linear.h  |   6 +-
 .../basis_functions/dm_bf_polynomial2.h       |   6 +-
 .../functions/dm_function_base.h              |   7 +-
 .../functions/kernels/dm_kern_base.h          |   7 +-
 .../functions/kernels/dm_kern_linear.h        |   6 +-
 include/tadah/mlip/neighbor_list_db.h         |  14 +--
 include/tadah/mlip/structure.h                |  14 +++
 include/tadah/mlip/structure_neighbor_view.h  | 109 ++++++++++++++++--
 src/dm_bf_linear.cpp                          |   6 +-
 src/dm_bf_polynomial2.cpp                     |  35 ++++--
 src/dm_kern_base.cpp                          |   6 +-
 src/dm_kern_linear.cpp                        |   6 +-
 13 files changed, 179 insertions(+), 57 deletions(-)

diff --git a/include/tadah/mlip/design_matrix/design_matrix.h b/include/tadah/mlip/design_matrix/design_matrix.h
index 168b94b..a19ba9b 100644
--- a/include/tadah/mlip/design_matrix/design_matrix.h
+++ b/include/tadah/mlip/design_matrix/design_matrix.h
@@ -1,6 +1,7 @@
 #ifndef DESIGN_MATRIX_H
 #define DESIGN_MATRIX_H
 
+#include "tadah/mlip/structure_neighbor_view.h"
 #include <tadah/mlip/st_descriptors_db.h>
 #include <tadah/mlip/structure_db.h>
 #include <tadah/mlip/descriptors_calc.h>
@@ -11,6 +12,7 @@
 #include <tadah/mlip/memory/DesignMatrixWorkspace.h>
 
 #include <tadah/mlip/neighbor_list_db.h>
+#include <tadah/mlip/structure_neighbor_view.h>
 
 #include <stdexcept>
 
@@ -202,30 +204,32 @@ public:
 #endif
     for (size_t s=0; s<stdb.size(); ++s) {
       StDescriptors st_d = dc.calc(stdb(s), nldb);
+      // StructureNeighborView st_nb(nldb,s);  // but we need to know the structure offset in global arrays
+      StructureNeighborView st_nb(nldb.nlist(),stdb(s));
 
       if(config.get<bool>("NORM"))
         norm.normalise(st_d);
 
-      build(rows[s], stdb(s), nldb, st_d);
+      build(rows[s], stdb(s), nldb, st_nb, st_d);
     }
 
   }
-  void build(size_t &row, const Structure &st, const StDescriptors &st_d) {
+  void build(size_t &row, const Structure &st, const StructureNeighborView &st_nb, const StDescriptors &st_d) {
 
     double escale = 1;
     if (scale) escale = st.eweight*eweightglob/st.natoms();
-    f.calc_phi_energy_row(ws->Phi,row,escale,st,st_d);
+    f.calc_phi_energy_row(ws->Phi,row,escale,st,st_nb,st_d);
     if (force) {
       double fscale = 1;
       if (scale) fscale = st.fweight*fweightglob/st.natoms()/3.0;///f_std_dev;
-      f.calc_phi_force_rows(ws->Phi,row,fscale,st,st_d);
+      f.calc_phi_force_rows(ws->Phi,row,fscale,st,st_nb,st_d);
     }
     if (stress) {
       double sscale_arr[6] {1,1,1,1,1,1};
       if (scale)
         for(size_t xy=0;xy<6;++xy)
           sscale_arr[xy] = st.sweight*sweightglob/6.0;///s_std_dev[xy];
-      f.calc_phi_stress_rows(ws->Phi,row,sscale_arr,st,st_d);
+      f.calc_phi_stress_rows(ws->Phi,row,sscale_arr,st,st_nb,st_d);
     }
   }
   void fill_T(const StructureDB &stdb, size_t start=0) {
diff --git a/include/tadah/mlip/design_matrix/functions/basis_functions/dm_bf_linear.h b/include/tadah/mlip/design_matrix/functions/basis_functions/dm_bf_linear.h
index 6fa887d..96312bf 100644
--- a/include/tadah/mlip/design_matrix/functions/basis_functions/dm_bf_linear.h
+++ b/include/tadah/mlip/design_matrix/functions/basis_functions/dm_bf_linear.h
@@ -12,11 +12,11 @@ struct DM_BF_Linear: public DM_BF_Base, public BF_Linear
     DM_BF_Linear(const Config &c);
     size_t get_phi_cols(const Config &config) override;
     void calc_phi_energy_row(phi_type &Phi, size_t &row,
-            const double fac, const Structure &st, const StDescriptors &st_d) override;
+            const double fac, const Structure &st, const StructureNeighborView &st_nb, const StDescriptors &st_d) override;
     void calc_phi_force_rows(phi_type &Phi, size_t &row,
-            const double fac, const Structure &st, const StDescriptors &st_d) override;
+            const double fac, const Structure &st, const StructureNeighborView &st_nb, const StDescriptors &st_d) override;
     void calc_phi_stress_rows(phi_type &Phi, size_t &row,
-            const double fac[6], const Structure &st, const StDescriptors &st_d) override;
+            const double fac[6], const Structure &st, const StructureNeighborView &st_nb, const StDescriptors &st_d) override;
 };
 }
 }
diff --git a/include/tadah/mlip/design_matrix/functions/basis_functions/dm_bf_polynomial2.h b/include/tadah/mlip/design_matrix/functions/basis_functions/dm_bf_polynomial2.h
index bda5f5a..7bc796e 100644
--- a/include/tadah/mlip/design_matrix/functions/basis_functions/dm_bf_polynomial2.h
+++ b/include/tadah/mlip/design_matrix/functions/basis_functions/dm_bf_polynomial2.h
@@ -12,11 +12,11 @@ struct DM_BF_Polynomial2: public DM_BF_Base, public BF_Polynomial2
     DM_BF_Polynomial2(const Config &c);
     size_t get_phi_cols(const Config &config) override;
     void  calc_phi_energy_row(phi_type &Phi, size_t &row,
-            const double fac, const Structure &st, const StDescriptors &st_d) override;
+            const double fac, const Structure &st, const StructureNeighborView &st_nb, const StDescriptors &st_d) override;
     void  calc_phi_force_rows(phi_type &Phi, size_t &row,
-            const double fac, const Structure &st, const StDescriptors &st_d) override;
+            const double fac, const Structure &st, const StructureNeighborView &st_nb, const StDescriptors &st_d) override;
     void  calc_phi_stress_rows(phi_type &Phi, size_t &row,
-            const double fac[6], const Structure &st, const StDescriptors &st_d) override;
+            const double fac[6], const Structure &st, const StructureNeighborView &st_nb, const StDescriptors &st_d) override;
 };
 }
 }
diff --git a/include/tadah/mlip/design_matrix/functions/dm_function_base.h b/include/tadah/mlip/design_matrix/functions/dm_function_base.h
index 49e92bc..8d40e9f 100644
--- a/include/tadah/mlip/design_matrix/functions/dm_function_base.h
+++ b/include/tadah/mlip/design_matrix/functions/dm_function_base.h
@@ -7,6 +7,7 @@
 #include <tadah/models/functions/function_base.h>
 #include <tadah/mlip/structure.h>
 #include <tadah/mlip/st_descriptors.h>
+#include <tadah/mlip/structure_neighbor_view.h>
 
 #include <iomanip>
 #include <iostream>
@@ -25,11 +26,11 @@ struct DM_Function_Base: public virtual Function_Base {
 
   virtual size_t get_phi_cols(const Config &)=0;
   virtual void calc_phi_energy_row(phi_type &, size_t &,
-                                   const double , const Structure &, const StDescriptors &)=0;
+                                   const double , const Structure &, const StructureNeighborView &st_nb, const StDescriptors &)=0;
   virtual void calc_phi_force_rows(phi_type &, size_t &,
-                                   const double , const Structure &, const StDescriptors &)=0;
+                                   const double , const Structure &, const StructureNeighborView &st_nb, const StDescriptors &)=0;
   virtual void calc_phi_stress_rows(phi_type &, size_t &,
-                                    const double[6], const Structure &, const StDescriptors &)=0;
+                                    const double[6], const Structure &, const StructureNeighborView &st_nb, const StDescriptors &)=0;
 };
 }
 }
diff --git a/include/tadah/mlip/design_matrix/functions/kernels/dm_kern_base.h b/include/tadah/mlip/design_matrix/functions/kernels/dm_kern_base.h
index 5bc532c..996dce6 100644
--- a/include/tadah/mlip/design_matrix/functions/kernels/dm_kern_base.h
+++ b/include/tadah/mlip/design_matrix/functions/kernels/dm_kern_base.h
@@ -7,6 +7,7 @@
 #include <tadah/mlip/st_descriptors.h>
 #include <tadah/core/core_types.h>
 #include <tadah/models/functions/kernels/kern_base.h>
+#include <tadah/mlip/structure_neighbor_view.h>
 
 #include <iostream>
 namespace tadah {
@@ -26,9 +27,9 @@ class DM_Kern_Base: public DM_Function_Base, public virtual Kern_Base  {
         DM_Kern_Base(const Config&c);
         virtual ~DM_Kern_Base();
         virtual size_t get_phi_cols(const Config &config) override;
-        virtual void  calc_phi_energy_row(phi_type &Phi, size_t &row, const double fac, const Structure &st, const StDescriptors &st_d) override;
-        virtual void  calc_phi_force_rows(phi_type &Phi, size_t &row, const double fac, const Structure &st, const StDescriptors &st_d) override;
-        virtual void  calc_phi_stress_rows(phi_type &Phi, size_t &row, const double fac[6], const Structure &st, const StDescriptors &st_d) override;
+        virtual void  calc_phi_energy_row(phi_type &Phi, size_t &row, const double fac, const Structure &st, const StructureNeighborView &st_nb, const StDescriptors &st_d) override;
+        virtual void  calc_phi_force_rows(phi_type &Phi, size_t &row, const double fac, const Structure &st, const StructureNeighborView &st_nb, const StDescriptors &st_d) override;
+        virtual void  calc_phi_stress_rows(phi_type &Phi, size_t &row, const double fac[6], const Structure &st, const StructureNeighborView &st_nb, const StDescriptors &st_d) override;
 
 };
 }
diff --git a/include/tadah/mlip/design_matrix/functions/kernels/dm_kern_linear.h b/include/tadah/mlip/design_matrix/functions/kernels/dm_kern_linear.h
index fcad6b9..968478d 100644
--- a/include/tadah/mlip/design_matrix/functions/kernels/dm_kern_linear.h
+++ b/include/tadah/mlip/design_matrix/functions/kernels/dm_kern_linear.h
@@ -23,11 +23,11 @@ class DM_Kern_Linear :  public DM_Kern_Base, public Kern_Linear {
 
     size_t get_phi_cols(const Config &config) override;
     void calc_phi_energy_row(phi_type &Phi, size_t &row,
-            const double fac, const Structure &st, const StDescriptors &st_d) override;
+            const double fac, const Structure &st, const StructureNeighborView &st_nb, const StDescriptors &st_d) override;
     void calc_phi_force_rows(phi_type &Phi, size_t &row,
-            const double fac, const Structure &st, const StDescriptors &st_d) override;
+            const double fac, const Structure &st, const StructureNeighborView &st_nb, const StDescriptors &st_d) override;
     void calc_phi_stress_rows(phi_type &Phi, size_t &row,
-            const double fac[6], const Structure &st, const StDescriptors &st_d) override;
+            const double fac[6], const Structure &st, const StructureNeighborView &st_nb, const StDescriptors &st_d) override;
 };
 }
 }
diff --git a/include/tadah/mlip/neighbor_list_db.h b/include/tadah/mlip/neighbor_list_db.h
index f716070..acd4ea4 100644
--- a/include/tadah/mlip/neighbor_list_db.h
+++ b/include/tadah/mlip/neighbor_list_db.h
@@ -5,7 +5,7 @@
 
 #include <tadah/mlip/neighbor_list.h>
 #include <tadah/mlip/structure_db.h>
-#include <tadah/mlip/structure_neighbor_view.h>
+// #include <tadah/mlip/structure_neighbor_view.h>
 
 namespace tadah {
 namespace mlip {
@@ -58,7 +58,7 @@ public:
         if (localAtomIndex >= sRef.natoms()) {
             throw std::out_of_range("NeighborListDB::toGlobalAtomIndex: localAtomIndex out of range");
         }
-        return sRef.offset_ + localAtomIndex;
+        return sRef.offset() + localAtomIndex;
     }
 
     /**
@@ -112,11 +112,11 @@ public:
      * @param s A reference to the chosen Structure from the same DB.
      * @return A lightweight view object referencing the HPC data subrange for s.
      */
-    inline StructureNeighborView createStructureView(const Structure &s) const
-    {
-        // Optionally check that s.db_ == &db_ if you want to ensure they belong together.
-        return StructureNeighborView(nlist_, s);
-    }
+    // inline StructureNeighborView createStructureView(const Structure &s) const
+    // {
+    //     // Optionally check that s.db_ == &db_ if you want to ensure they belong together.
+    //     return StructureNeighborView(nlist_, s);
+    // }
 
 private:
     /**
diff --git a/include/tadah/mlip/structure.h b/include/tadah/mlip/structure.h
index f0c6653..5bafff0 100644
--- a/include/tadah/mlip/structure.h
+++ b/include/tadah/mlip/structure.h
@@ -87,6 +87,20 @@ public:
   Vec3dSoAView positionView(size_t i);
   Vec3dSoAView forceView(size_t i);
 
+  /**
+   * @brief Converts a global index to a local index within this structure.
+   *
+   */
+inline std::size_t globalToLocal(std::size_t gIdx) const {
+    // Check if gIdx belongs to this structure’s subrange
+    if (gIdx < offset_ || gIdx >= (offset_ + size_)) {
+        throw std::runtime_error(
+          "globalToLocal: global index not in this structure's subrange"
+        );
+    }
+    return gIdx - offset_;
+}
+
   // --------------------------------------------------
   // Minimal I/O
   // --------------------------------------------------
diff --git a/include/tadah/mlip/structure_neighbor_view.h b/include/tadah/mlip/structure_neighbor_view.h
index 277f317..21b9b1b 100644
--- a/include/tadah/mlip/structure_neighbor_view.h
+++ b/include/tadah/mlip/structure_neighbor_view.h
@@ -2,10 +2,10 @@
 
 #include <cstddef>
 #include <stdexcept>
-#include "neighbor_list.h"
-#include "structure_db.h"
-#include "structure.h"
-#include "neighbor_list_db.h"
+#include <tadah/mlip/neighbor_list.h>
+#include <tadah/mlip/structure_db.h>
+#include <tadah/mlip/structure.h>
+#include <tadah/mlip/neighbor_list_db.h>
 
 namespace tadah {
 namespace mlip {
@@ -35,16 +35,17 @@ public:
     }
 
     /**
-     * @brief Constructs by picking a structure from NeighborListDB.db()(structureIndex),
-     *        then referencing the HPC neighbor list from nlDB.nlist().
+     * @brief Constructs by picking a structure from StructureDB(structureIndex),
+     *        then referencing the HPC neighbor list from nldb.nlist().
      *
      * This avoids manually retrieving the Structure from StructureDB first.
      *
-     * @param nlDB The neighbor-list DB that references a single HPC neighbor list & DB.
-     * @param structureIndex The index of the desired structure in nlDB.db().
+     * @param stdb StructureDB containing the desired structure.
+     * @param nldb The neighbor-list DB that references a single HPC neighbor list & stdb.
+     * @param structureIndex The index of the desired structure in stdb.
      */
-    StructureNeighborView(const NeighborListDB &nlDB, std::size_t structureIndex)
-    : globalList_(nlDB.nlist()), str_(nlDB.db()(structureIndex))
+    StructureNeighborView(const StructureDB &stdb, const NeighborListDB &nldb, std::size_t structureIndex)
+    : globalList_(nldb.nlist()), str_(stdb(structureIndex))
     {
         // Minimal checks can be added if needed.
     }
@@ -115,5 +116,91 @@ public:
      * @brief Returns the Z-shift pointer.
      */
     inline const int* getShiftZPtr(std::size_t localAtomIndex) const {
-        if (localAtomIndex >= str_.nat
+        if (localAtomIndex >= str_.natoms()) {
+            throw std::out_of_range("StructureNeighborView::getShiftZPtr: local index out of range");
+        }
+        std::size_t gIdx = str_.offset() + localAtomIndex;
+        return globalList_.getShiftZPtr(gIdx);
+    }
+
+private:
+    /**
+     * @brief Reference to the single global HPC neighbor list (covering all atoms).
+     */
+    const NeighborList &globalList_;
+
+    /**
+     * @brief The targeted structure from the HPC DB, which provides offset() and natoms().
+     */
+    const Structure &str_;
+
+  /**
+   * @brief Returns the local index (ii) of the i-th atom within the j-th atom's neighbor list.
+   *
+   * This method uses mirror information stored in the neighbor list to identify
+   * the "reverse" index in the j-th atom's neighbor list that corresponds to
+   * the i-th atom. The parameter jj must be the position of j in i's neighbor list.
+   * i and j are local indices within the current structure (not global HPC indices).
+   * 
+   * Basic checks:
+   *   - 0 <= i  < numAtoms()
+   *   - 0 <= j  < numAtoms()
+   *   - 0 <= jj < numNeighbors(i)
+   *
+   * The logic:
+   *   1. Global HPC indices (gi, gj) are computed by adding each local index (i, j) to the structure offset.
+   *   2. mirrorOffset is fetched from the global neighbor list using getMirrorOffset(gi, jj).
+   *      This mirrorOffset indicates the flattened position in j's neighbor entries pointing back to i.
+   *   3. prefixSum_[gj] is the start offset of j's neighbors in the flattened global array.
+   *   4. Subtracting prefixSum_[gj] from mirrorOffset gives the local index of i in j's neighbor list (ii).
+   *
+   * Example usage within HPC pairwise iteration:
+   *   - One loops over each atom i in the structure.
+   *   - Then over each neighbor index jj of i (pointing to neighbor j).
+   *   - getAAIndex(i, j, jj) yields the local index ii (where i is found in j's neighbor list).
+   *   - This ensures consistent symmetrical housekeeping in pair computations.
+   *
+   * @param i  Local index of the first atom.
+   * @param j  Local index of the second atom (neighbor).
+   * @param jj Position of j in i's neighbor list.
+   * @return The local index (ii) of i in j's neighbor list.
+   */
+  inline size_t getAAIndex(const size_t i, const size_t j, const size_t jj)
+  {
+    // Verify bounds on i, j, and jj to avoid out-of-range errors.
+    // i, j, and jj must be valid for this structure.
+    if (i >= str_.natoms() || j >= str_.natoms()) {
+      throw std::out_of_range("StructureNeighborView::getAAIndex: atom index out of range");
+    }
+    if (jj >= globalList_.numNeighbors(str_.offset() + i)) {
+      throw std::out_of_range("StructureNeighborView::getAAIndex: neighbor index out of range");
+    }
+
+    // Convert local indices (i, j) to global HPC indices (gi, gj).
+    const size_t gi = str_.offset() + i;
+    const size_t gj = str_.offset() + j;
+
+    // Retrieve the flattened mirror offset for i <-> j.
+    // This offset indicates the actual position in j's neighbor array pointing back to i.
+    const size_t mirrorOffset = globalList_.getMirrorOffset(gi, jj);
+
+    // Identify j's subarray start in the flattened array.
+    // prefixData()[gj] is the offset in the neighbors array for atom gj.
+    const size_t jPrefix = globalList_.prefixData()[gj];
+
+    // The difference between mirrorOffset and j's prefix offset
+    // yields the local index of i in j's neighbor list.
+    const size_t ii = mirrorOffset - jPrefix;
+
+    // Optional check: ensure ii is within valid neighbor range for j.
+    if (ii >= globalList_.numNeighbors(gj)) {
+      throw std::out_of_range("StructureNeighborView::getAAIndex: computed mirror index out of range");
+    }
+
+    return ii;
+  }
+};
+
+} // namespace mlip
+} // namespace tadah
 
diff --git a/src/dm_bf_linear.cpp b/src/dm_bf_linear.cpp
index 7d576fa..483c263 100644
--- a/src/dm_bf_linear.cpp
+++ b/src/dm_bf_linear.cpp
@@ -17,7 +17,7 @@ size_t DM_BF_Linear::get_phi_cols(const Config &config)
   return cols;
 }
 void DM_BF_Linear::calc_phi_energy_row(phi_type &Phi, size_t &row,
-                                       const double fac, const Structure &, const StDescriptors &st_d)
+                                       const double fac, const Structure &, const StructureNeighborView &st_nb, const StDescriptors &st_d)
 {
   for (size_t i=0; i<st_d.naed(); ++i) {
     const aed_type &aed = st_d.get_aed(i);
@@ -28,7 +28,7 @@ void DM_BF_Linear::calc_phi_energy_row(phi_type &Phi, size_t &row,
   row++;
 }
 void DM_BF_Linear::calc_phi_force_rows(phi_type &Phi, size_t &row,
-                                       const double fac, const Structure &st, const StDescriptors &st_d)
+                                       const double fac, const Structure &st, const StructureNeighborView &st_nb, const StDescriptors &st_d)
 {
 
   for (size_t a=0; a<st.natoms(); ++a) {
@@ -47,7 +47,7 @@ void DM_BF_Linear::calc_phi_force_rows(phi_type &Phi, size_t &row,
   }
 }
 void DM_BF_Linear::calc_phi_stress_rows(phi_type &Phi, size_t &row,
-                                        const double fac[6], const Structure &st, const StDescriptors &st_d)
+                                        const double fac[6], const Structure &st, const StructureNeighborView &st_nb, const StDescriptors &st_d)
 {
   double V_inv = 1/st.get_volume();
   for (size_t i=0; i<st.natoms(); ++i) {
diff --git a/src/dm_bf_polynomial2.cpp b/src/dm_bf_polynomial2.cpp
index ddc417e..24f22e3 100644
--- a/src/dm_bf_polynomial2.cpp
+++ b/src/dm_bf_polynomial2.cpp
@@ -21,7 +21,7 @@ size_t DM_BF_Polynomial2::get_phi_cols(const Config &config)
 void DM_BF_Polynomial2::calc_phi_energy_row(phi_type &Phi,
                                             size_t &row,
                                             const double fac,
-                                            const Structure &,
+                                            const Structure &, const StructureNeighborView &st_nb,
                                             const StDescriptors &st_d)
 {
   for (size_t a=0; a<st_d.naed();++a) {
@@ -38,14 +38,22 @@ void DM_BF_Polynomial2::calc_phi_energy_row(phi_type &Phi,
 void DM_BF_Polynomial2::calc_phi_force_rows(phi_type &Phi,
                                             size_t &row,
                                             const double fac,
-                                            const Structure &st,
+                                            const Structure &st, const StructureNeighborView &st_nb,
                                             const StDescriptors &st_d)
 {
   for (size_t a=0; a<st.natoms(); ++a) {
+
+    const std::size_t* NbPtr    = st_nb.getNeighborsPtr(a);
+    size_t NNbrs = st_nb.numNeighbors(a);
+    const Vec3dSoAView ri = st.positionView(a);
     const aed_type& aedi = st_d.get_aed(a);
-    for (size_t jj=0; jj<st_d.fd[a].size(); ++jj) {
-      const size_t j=st.near_neigh_idx[a][jj];
-      size_t aa = st.get_nn_iindex(a,j,jj);
+
+    for (size_t jj=0; jj<NNbrs; ++jj) {
+
+      size_t gj = NbPtr[jj];
+      size_t j= st.globalToLocal(gj);  // j is local to st
+      size_t aa = getAAIndex(a, j, jj)
+      const aed_type& aedi = st_d.get_aed(a);
       const fd_type &fdji = st_d.fd[j][aa];
       const fd_type &fdij = st_d.fd[a][jj];
       const aed_type& aedj = st_d.get_aed(j);
@@ -68,16 +76,22 @@ void DM_BF_Polynomial2::calc_phi_force_rows(phi_type &Phi,
 void DM_BF_Polynomial2::calc_phi_stress_rows(phi_type &Phi,
                                              size_t &row,
                                              const double fac[6],
-                                             const Structure &st, 
+                                             const Structure &st, const StructureNeighborView &st_nb, 
                                              const StDescriptors &st_d)
 {
   double V_inv = 1/st.get_volume();
   for (size_t a=0; a<st.natoms(); ++a) {
-    const Vec3d &ri = st(a).position;
+
+    const std::size_t* NbPtr    = st_nb.getNeighborsPtr(a);
+    size_t NNbrs = st_nb.numNeighbors(a);
+    const Vec3dSoAView ri = st.positionView(a);
     const aed_type& aedi = st_d.get_aed(a);
-    for (size_t jj=0; jj<st_d.fd[a].size(); ++jj) {
-      const size_t j=st.near_neigh_idx[a][jj];
-      size_t aa = st.get_nn_iindex(a,j,jj);
+
+    for (size_t jj=0; jj<NNbrs; ++jj) {
+
+      size_t gj = NbPtr[jj];
+      size_t j= st.globalToLocal(gj);  // j is local to st
+      size_t aa = getAAIndex(a, j, jj)
       const fd_type &fdji = st_d.fd[j][aa];
       const fd_type &fdij = st_d.fd[a][jj];
       const Vec3d &rj = st.nn_pos(a,jj);
@@ -101,5 +115,6 @@ void DM_BF_Polynomial2::calc_phi_stress_rows(phi_type &Phi,
   }
   row += 6;
 }
+
 }
 }
diff --git a/src/dm_kern_base.cpp b/src/dm_kern_base.cpp
index f3fe327..28df869 100644
--- a/src/dm_kern_base.cpp
+++ b/src/dm_kern_base.cpp
@@ -16,7 +16,7 @@ size_t DM_Kern_Base::get_phi_cols(const Config &)
   return basis.cols();
 }
 void  DM_Kern_Base::calc_phi_energy_row(phi_type &Phi, size_t &row, const double fac,
-                                        const Structure &, const StDescriptors &st_d)
+                                        const Structure &, const StructureNeighborView &st_nb, const StDescriptors &st_d)
 {
   for (size_t a=0; a<st_d.naed();++a) {
     for (size_t b=0; b<basis.cols(); ++b) {
@@ -27,7 +27,7 @@ void  DM_Kern_Base::calc_phi_energy_row(phi_type &Phi, size_t &row, const double
   //Phi.row(row++) *= fac;
 }
 void  DM_Kern_Base::calc_phi_force_rows(phi_type &Phi, size_t &row, const double fac,
-                                        const Structure &st, const StDescriptors &st_d) {
+                                        const Structure &st, const StructureNeighborView &st_nb, const StDescriptors &st_d) {
   for (size_t i=0; i<st.natoms(); ++i) {
     const aed_type& aedi = st_d.get_aed(i);
     for (size_t jj=0; jj<st_d.fd[i].size(); ++jj) {
@@ -47,7 +47,7 @@ void  DM_Kern_Base::calc_phi_force_rows(phi_type &Phi, size_t &row, const double
   }
 }
 void  DM_Kern_Base::calc_phi_stress_rows(phi_type &Phi, size_t &row, const double fac[6],
-                                         const Structure &st, const StDescriptors &st_d)
+                                         const Structure &st, const StructureNeighborView &st_nb, const StDescriptors &st_d)
 {
   double V_inv = 1/st.get_volume();
   for (size_t i=0; i<st.natoms(); ++i) {
diff --git a/src/dm_kern_linear.cpp b/src/dm_kern_linear.cpp
index dc12670..dc5a898 100644
--- a/src/dm_kern_linear.cpp
+++ b/src/dm_kern_linear.cpp
@@ -18,7 +18,7 @@ size_t DM_Kern_Linear::get_phi_cols(const Config &config)
   return cols;
 }
 void DM_Kern_Linear::calc_phi_energy_row(phi_type &Phi, size_t &row,
-                                         const double fac, const Structure &, const StDescriptors &st_d)
+                                         const double fac, const Structure &, const StructureNeighborView &st_nb, const StDescriptors &st_d)
 {
   for (size_t a=0; a<st_d.naed();++a) {
     const aed_type &aed = st_d.get_aed(a);  // TODO
@@ -29,7 +29,7 @@ void DM_Kern_Linear::calc_phi_energy_row(phi_type &Phi, size_t &row,
   row++;
 }
 void DM_Kern_Linear::calc_phi_force_rows(phi_type &Phi, size_t &row,
-                                         const double fac, const Structure &st, const StDescriptors &st_d)
+                                         const double fac, const Structure &st, const StructureNeighborView &st_nb, const StDescriptors &st_d)
 {
   for (size_t a=0; a<st.natoms(); ++a) {
     for (size_t jj=0; jj<st_d.fd[a].size(); ++jj) {
@@ -48,7 +48,7 @@ void DM_Kern_Linear::calc_phi_force_rows(phi_type &Phi, size_t &row,
 
 }
 void DM_Kern_Linear::calc_phi_stress_rows(phi_type &Phi, size_t &row,
-                                          const double fac[6], const Structure &st, const StDescriptors &st_d)
+                                          const double fac[6], const Structure &st, const StructureNeighborView &st_nb, const StDescriptors &st_d)
 {
   double V_inv = 1/st.get_volume();
   for (size_t i=0; i<st.natoms(); ++i) {
-- 
GitLab