From 9a0860d0f77ae1628036ada953b02006f76ec139 Mon Sep 17 00:00:00 2001
From: mkirsz <s1351949@sms.ed.ac.uk>
Date: Fri, 21 Feb 2025 13:31:29 +0000
Subject: [PATCH] NN list with binning

---
 include/tadah/mlip/descriptors_calc.hpp |   8 +-
 include/tadah/mlip/nn_finder.h          | 178 ++++++++++--
 include/tadah/mlip/structure.h          |  17 +-
 include/tadah/mlip/structure_db.h       |   9 +
 src/nn_finder.cpp                       | 348 ++++++++++++++++++------
 src/structure.cpp                       |  28 +-
 src/structure_db.cpp                    |  77 ++++++
 7 files changed, 539 insertions(+), 126 deletions(-)

diff --git a/include/tadah/mlip/descriptors_calc.hpp b/include/tadah/mlip/descriptors_calc.hpp
index f9659b4..b5bc9b1 100644
--- a/include/tadah/mlip/descriptors_calc.hpp
+++ b/include/tadah/mlip/descriptors_calc.hpp
@@ -1,6 +1,7 @@
 #ifndef DESCRIPTORS_CALC_HPP
 #define DESCRIPTORS_CALC_HPP
 
+#include <cstddef>
 #include <tadah/mlip/descriptors_calc.h>
 #include <tadah/core/periodic_table.h>
 
@@ -179,7 +180,8 @@ void DescriptorsCalc<D2,D3,DM,C2,C3,CM>::calc_rho(const Structure &st, StDescrip
       //double rij_sq = delij * delij;
       double rij_sq = delij[0]*delij[0] + delij[1]*delij[1] + delij[2]*delij[2];
       if (rij_sq > rcut_mb_sq) continue;
-      int Zj = st.near_neigh_atoms[i][jj].Z;
+      size_t neighIdx = st.near_neigh_idx[i][jj];
+      int Zj = st(neighIdx).Z;
       double rij = sqrt(rij_sq);
       dm.calc_rho(Zi,Zj,rij,rij_sq,delij,st_d.get_rho(i));
     }
@@ -252,7 +254,8 @@ void DescriptorsCalc<D2,D3,DM,C2,C3,CM>::calc(const Structure &st, StDescriptors
       double rij_sq = delij[0]*delij[0] + delij[1]*delij[1] + delij[2]*delij[2];
 
       if (rij_sq > rcut_max_sq) continue;
-      int Zj = st.near_neigh_atoms[i][jj].Z;
+      size_t neighIdx = st.near_neigh_idx[i][jj];
+      int Zj = st(neighIdx).Z;
       double rij = sqrt(rij_sq);
       double rij_inv = 1.0/rij;
 
@@ -343,7 +346,6 @@ void DescriptorsCalc<D2,D3,DM,C2,C3,CM>::calc_dimer(const Structure &st, StDescr
 
   // TODO weighting factors
   // For now assume all are the same type
-  //int Zj = st.near_neigh_atoms[0][0].Z;
   int Zi = 1;
   int Zj = 1;
 
diff --git a/include/tadah/mlip/nn_finder.h b/include/tadah/mlip/nn_finder.h
index e263514..2593334 100644
--- a/include/tadah/mlip/nn_finder.h
+++ b/include/tadah/mlip/nn_finder.h
@@ -2,39 +2,161 @@
 #define NN_FINDER_H
 
 #include <tadah/core/config.h>
+#include <tadah/core/lapack.h>
 #include <tadah/mlip/structure.h>
 #include <tadah/mlip/structure_db.h>
 
-/** Nearest Neighbour Finder
+/**
+ * @class NNFinder
  *
- * Construct a full nearest neighbour list
- * for every atom in a structure.
+ * Nearest Neighbor Finder that constructs a full nearest neighbor list
+ * for every atom in a structure, using:
+ *  - a **binned** (linked-cell) approach if every cell dimension ≥ cutoff,
+ *  - a **naive** approach otherwise (fallback).
  *
- * The lists are stored with a Structure object provided.
- *
- * The cutoff used is  max(\ref RCUT2B,\ref RCUT3B,\ref RCUTMB)
+ * The cutoff is from config("RCUTMAX").
  */
 class NNFinder {
-    private:
-        double cutoff_sq;
-        double cutoff;
-        /** Return false if cutoff is larger than
-         * one of cell dimensions.
-         */
-        bool check_box(Structure &st);
-        void num_shifts(Structure &st, int N[3]);
-    public:
-        /** Constructor to initalise this object
-         *
-         *  Required keys: at least one of:
-         *  \ref RCUT2B,\ref RCUT3B,\ref RCUTMB
-         */
-        NNFinder(Config &config);
-
-        /** \brief Find nearest neighbours for all atoms in a Structure */
-        void calc(Structure &st);
-
-        /** \brief Find nearest neighbours for all atoms in all Structure(s) */
-        void calc(StructureDB &stdb);
+private:
+    double cutoff;      ///< The cutoff distance
+    double cutoff_sq;   ///< cutoff^2 for distance checks
+
+    /**
+     * Return false if any cell dimension < cutoff => fallback to naive.
+     */
+    bool check_box(Structure &st);
+
+    /**
+     * Invert a 3×3 matrix (column-major) 
+     */
+/**
+ * Invert a 3x3 matrix using the cofactor (adjugate) method,
+ * matching the layout:
+ *   M(r,c) stored at index [r*3 + c]
+ *
+ * So, for example:
+ *   M(0,0) -> inM[0],
+ *   M(1,0) -> inM[3],
+ *   M(2,0) -> inM[6],
+ *   M(0,1) -> inM[1],
+ *   ...
+ *
+ * The output is also stored in the same pattern.
+ *
+ * Throws std::runtime_error if the matrix is near-singular.
+ */
+inline void inverse_3x3_direct(const double* inM, double* outM)
+{
+    // Map input array -> matrix elements:
+    // M(r,c) => inM[r*3 + c]
+    const double m00 = inM[0];  // M(0,0)
+    const double m10 = inM[3];  // M(1,0)
+    const double m20 = inM[6];  // M(2,0)
+    
+    const double m01 = inM[1];  // M(0,1)
+    const double m11 = inM[4];  // M(1,1)
+    const double m21 = inM[7];  // M(2,1)
+    
+    const double m02 = inM[2];  // M(0,2)
+    const double m12 = inM[5];  // M(1,2)
+    const double m22 = inM[8];  // M(2,2)
+
+    // ----------------------------------------------------------------------
+    // 1) Compute determinant:
+    //    det(M) = m00*(m11*m22 - m12*m21)
+    //           - m01*(m10*m22 - m12*m20)
+    //           + m02*(m10*m21 - m11*m20)
+    // ----------------------------------------------------------------------
+    double det = m00*(m11*m22 - m12*m21)
+               - m01*(m10*m22 - m12*m20)
+               + m02*(m10*m21 - m11*m20);
+
+    if (std::fabs(det) < 1e-14) {
+        throw std::runtime_error(
+            "inverse_3x3_cofactor_custom: matrix is near-singular."
+        );
+    }
+    double invDet = 1.0 / det;
+
+    // ----------------------------------------------------------------------
+    // 2) Cofactor / adjugate formula:
+    //
+    //    inv(0,0) = (m11*m22 - m12*m21)
+    //    inv(0,1) = -(m01*m22 - m02*m21)
+    //    inv(0,2) = (m01*m12 - m02*m11)
+    //
+    //    inv(1,0) = -(m10*m22 - m12*m20)
+    //    inv(1,1) = (m00*m22 - m02*m20)
+    //    inv(1,2) = -(m00*m12 - m02*m10)
+    //
+    //    inv(2,0) = (m10*m21 - m11*m20)
+    //    inv(2,1) = -(m00*m21 - m01*m20)
+    //    inv(2,2) = (m00*m11 - m01*m10)
+    //
+    // Multiply each cofactor by invDet.
+    // ----------------------------------------------------------------------
+    double i00 =  (m11*m22 - m12*m21) * invDet;
+    double i01 = -(m01*m22 - m02*m21) * invDet;
+    double i02 =  (m01*m12 - m02*m11) * invDet;
+
+    double i10 = -(m10*m22 - m12*m20) * invDet;
+    double i11 =  (m00*m22 - m02*m20) * invDet;
+    double i12 = -(m00*m12 - m02*m10) * invDet;
+
+    double i20 =  (m10*m21 - m11*m20) * invDet;
+    double i21 = -(m00*m21 - m01*m20) * invDet;
+    double i22 =  (m00*m11 - m01*m10) * invDet;
+
+    // ----------------------------------------------------------------------
+    // 3) Store the inverse back into outM in the same layout:
+    //    inv(r,c) => outM[r*3 + c]
+    // ----------------------------------------------------------------------
+    outM[0] = i00;  // inv(0,0)
+    outM[1] = i01;  // inv(0,1)
+    outM[2] = i02;  // inv(0,2)
+
+    outM[3] = i10;  // inv(1,0)
+    outM[4] = i11;  // inv(1,1)
+    outM[5] = i12;  // inv(1,2)
+
+    outM[6] = i20;  // inv(2,0)
+    outM[7] = i21;  // inv(2,1)
+    outM[8] = i22;  // inv(2,2)
+}
+
+    /**
+     * Naive approach to build neighbor lists.
+     */
+    void calc_naive(Structure &st);
+
+    /**
+     * Binning-based approach. If any dimension < cutoff, fallback to naive.
+     */
+    void calc_binned(Structure &st);
+
+    /**
+     * For naive approach, compute ±N to consider for image shifts.
+     */
+    void num_shifts(Structure &st, int N[3]);
+
+public:
+    /**
+     * Construct with cutoff from config("RCUTMAX").
+     */
+    NNFinder(Config &config);
+
+    /**
+     * Build nearest neighbors for all atoms in one Structure.
+     * Uses binned approach if possible, else naive fallback.
+     */
+    void calc(Structure &st);
+
+    /**
+     * Build nearest neighbors for each structure in a DB.
+     */
+    void calc(StructureDB &stdb);
 };
-#endif
+
+#endif // NN_FINDER_H
+
+
diff --git a/include/tadah/mlip/structure.h b/include/tadah/mlip/structure.h
index b264a1e..9ab8d6a 100644
--- a/include/tadah/mlip/structure.h
+++ b/include/tadah/mlip/structure.h
@@ -68,10 +68,6 @@ struct Structure {
    */
   double T=0;
 
-  /**
-   * Container for nearest neighbour atoms for every atom in the structure.
-   */
-  std::vector<std::vector<Atom>> near_neigh_atoms;
 
   /** Periodic image flag for neigherest neighbours.
    *
@@ -162,8 +158,17 @@ struct Structure {
    */
   double get_pressure(const double T, const double kB=8.617333262145e-5) const;
 
-  /** @return position of the n-th nearest neighbour of the i-th Atom. */
-  const Vec3d& nn_pos(const size_t i, const size_t n) const;
+/**
+ * Return the position of the n-th neighbor of atom i,
+ * computed via the neighbor's global index and periodic shifts.
+ *
+ * Assumes:
+ *   near_neigh_idx[i][n]  -> global index of neighbor
+ *   near_neigh_shift[i][n] -> integer triple (n1, n2, n3) OR real shift
+ *
+ * Returns by value (Vec3d) so we don't reference a temporary.
+ */
+  Vec3d nn_pos(const size_t i, const size_t n) const;
 
   /** @return a number of nearest neighbours of the i-th Atom. */
   size_t nn_size(const size_t i) const;
diff --git a/include/tadah/mlip/structure_db.h b/include/tadah/mlip/structure_db.h
index 8354890..0cab4a6 100644
--- a/include/tadah/mlip/structure_db.h
+++ b/include/tadah/mlip/structure_db.h
@@ -170,5 +170,14 @@ struct StructureDB {
 
   /** Method to dump class content to a file */
   void dump_to_file(const std::string& filepath, size_t prec=12) const;
+
+  // Public method that reads the file, counts blocks and line counts per block,
+  // then prints the results to std::cout.
+  void parseFile(const std::string& filename);
+
+private:
+  // Checks if a line is empty or contains only whitespace
+  bool isBlankLine(const std::string& line) const;
+
 };
 #endif
diff --git a/src/nn_finder.cpp b/src/nn_finder.cpp
index c2840d9..81569cb 100644
--- a/src/nn_finder.cpp
+++ b/src/nn_finder.cpp
@@ -1,108 +1,286 @@
 #include <tadah/mlip/nn_finder.h>
 #include <limits>
+#include <cmath>
+#include <algorithm>
+#include <stdexcept>
+#include <cstring>
+#include <chrono>
+#include <iostream>
+#include <vector>
 
-NNFinder::NNFinder(Config &config):
-  cutoff_sq(pow(config.get<double>("RCUTMAX"),2)),
-  cutoff(config.get<double>("RCUTMAX"))
-{}
+// Constructor
+NNFinder::NNFinder(Config &config)
+{
+  cutoff     = config.get<double>("RCUTMAX");
+  cutoff_sq  = cutoff * cutoff;
+}
+
+// ---------------------------------------------------------------------------
+bool NNFinder::check_box(Structure &st)
+{
+  for (int i = 0; i < 3; i++) {
+    double vx = st.cell(i, 0);
+    double vy = st.cell(i, 1);
+    double vz = st.cell(i, 2);
+    double len2 = vx*vx + vy*vy + vz*vz;
+    if (len2 < cutoff_sq) {
+      return false;
+    }
+  }
+  return true;
+}
+
+// ---------------------------------------------------------------------------
+void NNFinder::num_shifts(Structure &st, int N[3]) {
+  Matrix3d cell_inv = st.cell.inverse();
+
+  double l1 = cell_inv.col(0).norm();
+  double l2 = cell_inv.col(1).norm();
+  double l3 = cell_inv.col(2).norm();
+
+  double f1 = (l1 > 0) ? 1.0/l1 : 1.0;
+  double f2 = (l2 > 0) ? 1.0/l2 : 1.0;
+  double f3 = (l3 > 0) ? 1.0/l3 : 1.0;
+
+  int b1 = std::max(int(f1/cutoff),1);
+  int b2 = std::max(int(f2/cutoff),1);
+  int b3 = std::max(int(f3/cutoff),1);
+
+  N[0] = (int)std::round(0.5 + cutoff*b1/f1);
+  N[1] = (int)std::round(0.5 + cutoff*b2/f2);
+  N[2] = (int)std::round(0.5 + cutoff*b3/f3);
+}
 
-void NNFinder::calc(Structure &st) {
+// ---------------------------------------------------------------------------
+// Naive approach - store only neighbor indices and shifts, no local Atom copies
+void NNFinder::calc_naive(Structure &st)
+{
+  st.near_neigh_shift.resize(st.natoms());
+  st.near_neigh_idx.resize(st.natoms());
+  for (size_t i = 0; i < st.natoms(); i++) {
+    st.near_neigh_shift[i].reserve(100);
+    st.near_neigh_idx[i].reserve(100);
+  }
 
+  // Compute shift bounds
   int N[3];
   num_shifts(st, N);
 
-  Matrix shiftedpos(st.natoms(),3);
-  // for convenience only:
-  std::vector<std::vector<Atom>> &nnatoms = st.near_neigh_atoms;
-  std::vector<std::vector<Vec3d>> &nnshift = st.near_neigh_shift;
-  std::vector<std::vector<size_t>> &nnidx = st.near_neigh_idx;
-  nnatoms.resize(st.natoms());
-  nnshift.resize(st.natoms());
-  nnidx.resize(st.natoms());
-
-  Vec3d displacement;
-  Vec3d delij;
-  Vec3d shift;
-  Atom atom1;
-  Atom atom2;
-  double min_double = std::numeric_limits<double>::min();
-  for (int n1=-N[0]; n1<=N[0]; n1++) {
-    for (int n2=-N[1]; n2<=N[1]; n2++) {
-      for (int n3=-N[2]; n3<=N[2]; n3++) {
-
-        shift(n1,n2,n3);
-        displacement[0] = st.cell(0,0)*n1 + st.cell(1,0)*n2 +  st.cell(2,0)*n3 ;
-        displacement[1] = st.cell(0,1)*n1 + st.cell(1,1)*n2 +  st.cell(2,1)*n3 ;
-        displacement[2] = st.cell(0,2)*n1 + st.cell(1,2)*n2 +  st.cell(2,2)*n3 ;
-
-        for (size_t a=0; a<st.natoms(); ++a) {
-          shiftedpos(a,0) = st(a).position[0] + displacement[0];
-          shiftedpos(a,1) = st(a).position[1] + displacement[1];
-          shiftedpos(a,2) = st(a).position[2] + displacement[2];
+  // Precompute shifts (both real displacement and integer triple if needed)
+  std::vector<Vec3d> shifts;
+  std::vector<Vec3d> shiftIdx;
+  shifts.reserve((2*N[0]+1)*(2*N[1]+1)*(2*N[2]+1));
+  shiftIdx.reserve((2*N[0]+1)*(2*N[1]+1)*(2*N[2]+1));
+
+  for (int n1 = -N[0]; n1 <= N[0]; n1++) {
+    for (int n2 = -N[1]; n2 <= N[1]; n2++) {
+      for (int n3 = -N[2]; n3 <= N[2]; n3++) {
+        Vec3d disp;
+        disp[0] = st.cell(0,0)*n1 + st.cell(1,0)*n2 + st.cell(2,0)*n3;
+        disp[1] = st.cell(0,1)*n1 + st.cell(1,1)*n2 + st.cell(2,1)*n3;
+        disp[2] = st.cell(0,2)*n1 + st.cell(1,2)*n2 + st.cell(2,2)*n3;
+        shifts.push_back(disp);
+        shiftIdx.push_back(Vec3d(n1, n2, n3));
+      }
+    }
+  }
+
+  // Extract positions in contiguous arrays (optional but faster for distance checks)
+  const size_t natoms = st.natoms();
+  std::vector<double> xPos(natoms), yPos(natoms), zPos(natoms);
+  for (size_t i = 0; i < natoms; i++) {
+    xPos[i] = st(i).position[0];
+    yPos[i] = st(i).position[1];
+    zPos[i] = st(i).position[2];
+  }
+
+  // Distance checks
+  for (size_t s = 0; s < shifts.size(); s++) {
+    const Vec3d &disp = shifts[s];
+    const Vec3d &dispIdxVal = shiftIdx[s];
+    bool selfShift = (dispIdxVal[0] == 0 && dispIdxVal[1] == 0 && dispIdxVal[2] == 0);
+    size_t startA2 = (selfShift ? 1ul : 0ul);
+
+    for (size_t a1 = 0; a1 < natoms; a1++) {
+      for (size_t a2 = a1 + startA2; a2 < natoms; a2++) {
+        double dx = xPos[a1] - (xPos[a2] + disp[0]);
+        double dy = yPos[a1] - (yPos[a2] + disp[1]);
+        double dz = zPos[a1] - (zPos[a2] + disp[2]);
+        double rij_sq = dx*dx + dy*dy + dz*dz;
+        if (rij_sq < cutoff_sq) {
+          // forward
+          st.near_neigh_idx[a1].push_back(a2);
+          st.near_neigh_shift[a1].push_back(dispIdxVal);
+
+          // reverse
+          st.near_neigh_idx[a2].push_back(a1);
+          st.near_neigh_shift[a2].push_back(-dispIdxVal);
         }
+      }
+    }
+  }
 
-        // calculate all neighbours of a1 for this shift
-        size_t start = n1==0 && n2==0 && n3==0 ? 1 : 0;
-        for (size_t a1=0; a1<st.natoms(); ++a1) {
-          for (size_t a2=a1+start; a2<st.natoms(); ++a2) {
-            delij[0] = st(a1).position[0] - shiftedpos(a2,0);
-            delij[1] = st(a1).position[1] - shiftedpos(a2,1);
-            delij[2] = st(a1).position[2] - shiftedpos(a2,2);
-            double rij_sq = delij[0]*delij[0] + delij[1]*delij[1] + delij[2]*delij[2];
-
-            if(rij_sq<cutoff_sq && rij_sq>min_double) {
-              atom1 = st(a1);
-              atom2 = st(a2);
-              for (size_t i=0; i<3; ++i) {
-                atom2.position[i] = shiftedpos(a2,i);
-                atom1.position[i] = st(a1).position[i]-displacement[i];
-              }
-              nnatoms[a1].push_back(atom2);
-              nnidx[a1].push_back(a2);
-              nnshift[a1].push_back(shift);
+  // shrink neighbor arrays
+  for (size_t i = 0; i < natoms; i++) {
+    st.near_neigh_shift[i].shrink_to_fit();
+    st.near_neigh_idx[i].shrink_to_fit();
+  }
+}
+
+// ---------------------------------------------------------------------------
+// Binned approach - similarly remove local copies, store just indices & shifts
+void NNFinder::calc_binned(Structure &st)
+{
+
+  st.near_neigh_shift.resize(st.natoms());
+  st.near_neigh_idx.resize(st.natoms());
+  for (size_t i = 0; i < st.natoms(); i++) {
+    st.near_neigh_shift[i].reserve(100);
+    st.near_neigh_idx[i].reserve(100);
+  }
+
+  // invert cell
+  double invC[9];
+  inverse_3x3_direct(st.cell.data(), invC);
+
+  auto rowLength = [&](int row){
+    double vx = st.cell(row, 0);
+    double vy = st.cell(row, 1);
+    double vz = st.cell(row, 2);
+    return std::sqrt(vx*vx + vy*vy + vz*vz);
+  };
 
-              nnatoms[a2].push_back(atom1);
-              nnidx[a2].push_back(a1);
-              nnshift[a2].push_back(-shift);
+  double cellLenA = rowLength(0);
+  double cellLenB = rowLength(1);
+  double cellLenC = rowLength(2);
+
+  int nBinsA = std::max(1, (int)std::floor(cellLenA / cutoff));
+  int nBinsB = std::max(1, (int)std::floor(cellLenB / cutoff));
+  int nBinsC = std::max(1, (int)std::floor(cellLenC / cutoff));
+
+  struct BinCell {
+    std::vector<size_t> atomIndices;
+  };
+  std::vector<BinCell> bins(nBinsA * nBinsB * nBinsC);
+
+  auto binIndex = [&](int ia, int ib, int ic){
+    auto wrap = [&](int k, int n){ return ( (k % n) + n ) % n; };
+    ia = wrap(ia, nBinsA);
+    ib = wrap(ib, nBinsB);
+    ic = wrap(ic, nBinsC);
+    return size_t(ia + nBinsA * (ib + nBinsB * ic));
+  };
+
+  // fill bins
+  for (size_t i = 0; i < st.natoms(); i++) {
+    const auto &atm = st(i);
+    double fx = invC[0]*atm.position[0] + invC[3]*atm.position[1] + invC[6]*atm.position[2];
+    double fy = invC[1]*atm.position[0] + invC[4]*atm.position[1] + invC[7]*atm.position[2];
+    double fz = invC[2]*atm.position[0] + invC[5]*atm.position[1] + invC[8]*atm.position[2];
+
+    fx -= std::floor(fx);
+    fy -= std::floor(fy);
+    fz -= std::floor(fz);
+
+    int ia = (int)std::floor(fx*nBinsA);
+    int ib = (int)std::floor(fy*nBinsB);
+    int ic = (int)std::floor(fz*nBinsC);
+    bins[ binIndex(ia, ib, ic) ].atomIndices.push_back(i);
+  }
+
+  double cutSQ = cutoff_sq;
+
+  // search
+  for (int ia = 0; ia < nBinsA; ia++) {
+    for (int ib = 0; ib < nBinsB; ib++) {
+      for (int ic = 0; ic < nBinsC; ic++) {
+        size_t b0 = binIndex(ia, ib, ic);
+        auto &vec0 = bins[b0].atomIndices;
+
+        for (int ja = ia - 1; ja <= ia + 1; ja++) {
+          for (int jb = ib - 1; jb <= ib + 1; jb++) {
+            for (int jc = ic - 1; jc <= ic + 1; jc++) {
+              size_t b1 = binIndex(ja, jb, jc);
+              auto &vec1 = bins[b1].atomIndices;
+
+              int dA = ja - ia;
+              int dB = jb - ib;
+              int dC = jc - ic;
+
+              for (size_t idxA : vec0) {
+                for (size_t idxB : vec1) {
+                  // avoid double counting
+                  if ((b0 == b1) && (idxB <= idxA)) continue;
+
+                  const Atom &a1 = st(idxA);
+                  const Atom &a2 = st(idxB);
+
+                  double fracX = double(dA)/double(nBinsA);
+                  double fracY = double(dB)/double(nBinsB);
+                  double fracZ = double(dC)/double(nBinsC);
+
+                  // Real shift
+                  Vec3d shiftDisp(
+                    st.cell(0,0)*fracX + st.cell(0,1)*fracY + st.cell(0,2)*fracZ,
+                    st.cell(1,0)*fracX + st.cell(1,1)*fracY + st.cell(1,2)*fracZ,
+                    st.cell(2,0)*fracX + st.cell(2,1)*fracY + st.cell(2,2)*fracZ
+                  );
+
+                  double dx = a1.position[0] - (a2.position[0] + shiftDisp[0]);
+                  double dy = a1.position[1] - (a2.position[1] + shiftDisp[1]);
+                  double dz = a1.position[2] - (a2.position[2] + shiftDisp[2]);
+                  double dist2 = dx*dx + dy*dy + dz*dz;
+
+                  if (dist2 < cutSQ) {
+                    // forward
+                    st.near_neigh_idx[idxA].push_back(idxB);
+                    st.near_neigh_shift[idxA].push_back(Vec3d(dA, dB, dC));
+
+                    // reverse
+                    st.near_neigh_idx[idxB].push_back(idxA);
+                    st.near_neigh_shift[idxB].push_back(Vec3d(-dA, -dB, -dC));
+                  }
+                }
+              }
             }
           }
         }
       }
     }
   }
+
+  // shrink
+  for (size_t i = 0; i < st.natoms(); i++) {
+    st.near_neigh_shift[i].shrink_to_fit();
+    st.near_neigh_idx[i].shrink_to_fit();
+  }
 }
-void NNFinder::calc(StructureDB &stdb) {
+
+// ---------------------------------------------------------------------------
+// Master calc that chooses naive or binned
+void NNFinder::calc(Structure &st)
+{
+  if(!check_box(st)) {
+    calc_naive(st);
+  } else {
+    calc_binned(st);
+  }
+}
+
+// ---------------------------------------------------------------------------
+// Parallel loop over structure database
+void NNFinder::calc(StructureDB &stdb)
+{
+  auto t0 = std::chrono::steady_clock::now();
 #ifdef _OPENMP
-#pragma omp parallel for
+  #pragma omp parallel for
 #endif
-  for (size_t s=0; s<stdb.size(); ++s) {
-    calc(stdb(s));
+  for(size_t i = 0; i < stdb.size(); i++){
+    calc(stdb(i));
   }
+  auto t1 = std::chrono::steady_clock::now();
+  double seconds = std::chrono::duration<double>(t1 - t0).count();
+  std::cout << "calc(StructureDB &stdb) for-loop took "
+            << seconds << " seconds\n";
 }
-bool NNFinder::check_box(Structure &st) {
-  double f = 1.05;   // extra safety measure
-  for (size_t i=0; i<3; ++i)
-    if (st.cell.row(i)*st.cell.row(i) < f*cutoff_sq)
-      return false;
-  return true;
-}
-void NNFinder::num_shifts(Structure &st, int N[3]) {
-  Matrix3d cell_inv = st.cell.inverse();
-
-  double l1 = cell_inv.col(0).norm();
-  double l2 = cell_inv.col(1).norm();
-  double l3 = cell_inv.col(2).norm();
-
-  double f1 = l1 > 0 ? 1.0/l1 : 1.0;
-  double f2 = l2 > 0 ? 1.0/l2 : 1.0;
-  double f3 = l3 > 0 ? 1.0/l3 : 1.0;
-
-  int b1 = std::max(int(f1/cutoff),1);
-  int b2 = std::max(int(f2/cutoff),1);
-  int b3 = std::max(int(f3/cutoff),1);
 
-  N[0] = (int)std::round(0.5+cutoff*b1/(f1));
-  N[1] = (int)std::round(0.5+cutoff*b2/(f2));
-  N[2] = (int)std::round(0.5+cutoff*b3/(f3));
-
-}
diff --git a/src/structure.cpp b/src/structure.cpp
index b759203..ee3ca19 100644
--- a/src/structure.cpp
+++ b/src/structure.cpp
@@ -19,8 +19,29 @@ void Structure::remove_atom(const size_t i) {
   atoms.erase(atoms.begin()+i);
 }
 
-const Vec3d& Structure::nn_pos(const size_t i, const size_t n) const {
-  return near_neigh_atoms[i][n].position;
+Vec3d Structure::nn_pos(const size_t i, const size_t n) const
+{
+    // (A) Global index of this neighbor
+    const size_t neighborIndex = near_neigh_idx[i][n];
+
+    // (B) Atom's original "unshifted" position:
+    const Vec3d &posNeighbor = atoms[neighborIndex].position;
+
+    // (C) Convert the stored shift -> real displacement shiftDisp
+    //     If near_neigh_shift[i][n] is an integer triple (n1, n2, n3),
+    //     multiply by the cell. Otherwise, if it's already in real space,
+    //     you can just do: Vec3d shiftDisp = near_neigh_shift[i][n].
+    Vec3d shift = near_neigh_shift[i][n]; // might be (n1, n2, n3)
+
+    // If shift is integer triple (n1, n2, n3), multiply with cell:
+    Vec3d shiftDisp;
+    shiftDisp[0] = shift[0]*cell(0,0) + shift[1]*cell(0,1) + shift[2]*cell(0,2);
+    shiftDisp[1] = shift[0]*cell(1,0) + shift[1]*cell(1,1) + shift[2]*cell(1,2);
+    shiftDisp[2] = shift[0]*cell(2,0) + shift[1]*cell(2,1) + shift[2]*cell(2,2);
+
+    // (D) Final neighbor position = unshifted position + shiftDisp
+    Vec3d posShifted = posNeighbor + shiftDisp;
+    return posShifted;
 }
 
 size_t Structure::natoms() const {
@@ -28,7 +49,7 @@ size_t Structure::natoms() const {
 }
 
 size_t Structure::nn_size(size_t i) const {
-  return near_neigh_atoms[i].size();
+  return near_neigh_idx[i].size();
 }
 
 int Structure::read(std::ifstream &ifs) {
@@ -258,7 +279,6 @@ int Structure::next_structure(std::ifstream &ifs) {
   return natoms;
 }
 void Structure::clear_nn() {
-  near_neigh_atoms.clear();
   near_neigh_shift.clear();
   near_neigh_idx.clear();
 }
diff --git a/src/structure_db.cpp b/src/structure_db.cpp
index 384e0a7..f1aaa36 100644
--- a/src/structure_db.cpp
+++ b/src/structure_db.cpp
@@ -2,6 +2,7 @@
 #include <tadah/mlip/structure_db.h>
 #include <tadah/core/periodic_table.h>
 #include <cstdio>
+#include <cctype>   // For std::isspace
 
 StructureDB::StructureDB() {
   PeriodicTable::initialize();
@@ -16,6 +17,7 @@ void StructureDB::add(const std::string fn) {
   if (!ifs.is_open()) {
     throw std::runtime_error("DBFILE does not exist: "+fn);
   }
+  parseFile(fn);
   while (true) {
     structures.push_back(Structure());
     int t = structures.back().read(ifs);
@@ -277,3 +279,78 @@ std::string StructureDB::summary() const {
   str+="\n";
   return str;
 }
+void StructureDB::parseFile(const std::string& filename)
+{
+    std::ifstream fin(filename, std::ios::in | std::ios::binary);
+    if (!fin.is_open())
+    {
+        std::cerr << "Error: could not open file " << filename << "\n";
+        return;
+    }
+
+  std::size_t header_size = 9;
+
+    // Increase buffer size to speed up I/O on large files.
+    static const size_t BUFSIZE = 100ULL << 20; // 100 MiB
+    char* buffer = new char[BUFSIZE];
+    fin.rdbuf()->pubsetbuf(buffer, BUFSIZE);
+
+    std::vector<size_t> blockLineCounts;
+    blockLineCounts.reserve(10000); // Pre-allocate to reduce repeated allocations
+
+    size_t currentBlockCount = 0;
+    std::string line;
+
+    while (true)
+    {
+        if (!std::getline(fin, line))
+        {
+            // End of file or read error
+            break;
+        }
+
+        if (isBlankLine(line))
+        {
+            // We reached the end of the current block
+            if (currentBlockCount > 0)
+            {
+                blockLineCounts.push_back(currentBlockCount-header_size);
+                currentBlockCount = 0;
+            }
+        }
+        else
+        {
+            // Non-empty line => belongs to the current block
+            currentBlockCount++;
+        }
+    }
+
+    // If the last block didn’t end with a blank line, close it out
+    if (currentBlockCount > 0)
+    {
+        blockLineCounts.push_back(currentBlockCount-header_size);
+    }
+
+    fin.close();
+    delete[] buffer;
+
+    // Print the results
+    std::cout << "Found " << blockLineCounts.size() << " blocks.\n";
+    for (size_t i = 0; i < blockLineCounts.size(); i+=1000)
+    {
+        std::cout << "Block " << i << " has "
+                  << blockLineCounts[i] << " atoms\n";
+    }
+}
+
+bool StructureDB::isBlankLine(const std::string& line) const
+{
+    for (char c : line)
+    {
+        if (!std::isspace(static_cast<unsigned char>(c)))
+        {
+            return false;
+        }
+    }
+    return true;
+}
-- 
GitLab