From 94e358f596df35bde1a456611924829c34b88f6f Mon Sep 17 00:00:00 2001
From: mkirsz <s1351949@sms.ed.ac.uk>
Date: Mon, 24 Feb 2025 00:11:55 +0000
Subject: [PATCH] wip

---
 .../tadah/mlip/design_matrix/design_matrix.h  |  18 ++-
 .../functions/basis_functions/dm_bf_base.h    |   1 +
 include/tadah/mlip/models/basis.h             |   5 +
 include/tadah/mlip/models/m_all.h             |   4 +
 include/tadah/mlip/models/m_blr.h             |   8 +-
 include/tadah/mlip/models/m_krr.h             |   4 +
 include/tadah/mlip/models/m_tadah_base.h      |   4 +
 include/tadah/mlip/neighbor_calc.h            |   4 +-
 include/tadah/mlip/neighbor_list_db.h         |  21 ++-
 include/tadah/mlip/structure_neighbor_view.h  | 122 ++++++++++++++++++
 include/tadah/mlip/trainer.h                  |  10 +-
 11 files changed, 186 insertions(+), 15 deletions(-)
 create mode 100644 include/tadah/mlip/structure_neighbor_view.h

diff --git a/include/tadah/mlip/design_matrix/design_matrix.h b/include/tadah/mlip/design_matrix/design_matrix.h
index b98239c..168b94b 100644
--- a/include/tadah/mlip/design_matrix/design_matrix.h
+++ b/include/tadah/mlip/design_matrix/design_matrix.h
@@ -10,6 +10,8 @@
 #include <tadah/mlip/memory/MLIPWorkspaceManager.h>
 #include <tadah/mlip/memory/DesignMatrixWorkspace.h>
 
+#include <tadah/mlip/neighbor_list_db.h>
+
 #include <stdexcept>
 
 namespace tadah {
@@ -73,7 +75,7 @@ class DesignMatrix : public DesignMatrixBase {
 public:
 
   F f;
-  tadah::mlip::memory::DesignMatrixWorkspace *ws;
+  memory::DesignMatrixWorkspace *ws;
   bool scale=true;    // Control escale,fscale,sscale
 
   double e_std_dev=1;
@@ -92,7 +94,7 @@ public:
      *
      * \endcode
      */
-    DesignMatrix(F &f, Config &c, tadah::mlip::memory::IMLIPWorkspaceManager& workspaceManager)
+    DesignMatrix(F &f, Config &c, memory::IMLIPWorkspaceManager& workspaceManager)
         : f(f),
           workspaceManager_(&workspaceManager),
           ownWorkspaceManager(false),
@@ -114,7 +116,7 @@ public:
 
     // Constructor without workspaceManager_ parameter (delegating constructor)
     DesignMatrix(F &f, Config &c)
-        : DesignMatrix(f, c, *new tadah::mlip::memory::MLIPWorkspaceManager())
+        : DesignMatrix(f, c, *new memory::MLIPWorkspaceManager())
     {
         ownWorkspaceManager = true;  // Set ownership flag
     }
@@ -166,8 +168,10 @@ public:
 
   /** \brief Calculate descriptors and build design matrix. */
   template <typename DC>
-  void build(const StructureDB &stdb, Normaliser &norm,
+  void build(const StructureDB &stdb, NeighbourListDB &nldb,
+             Normaliser &norm,
              DC &dc) {
+
     //DescriptorsCalc<D2,D3,DM,C2,C3,CM> dc(config);
     calc_mn(stdb);
     ws = workspaceManager_->getDesignMatrixWorkspace(rows,cols);
@@ -197,12 +201,12 @@ public:
     #pragma omp parallel for
 #endif
     for (size_t s=0; s<stdb.size(); ++s) {
-      StDescriptors st_d = dc.calc(stdb(s));
+      StDescriptors st_d = dc.calc(stdb(s), nldb);
 
       if(config.get<bool>("NORM"))
         norm.normalise(st_d);
 
-      build(rows[s], stdb(s), st_d);
+      build(rows[s], stdb(s), nldb, st_d);
     }
 
   }
@@ -266,7 +270,7 @@ t_type &getT() { return ws->T; }
 t_type &getTlabels() { return ws->Tlabels; }
 
 private:
-  tadah::mlip::memory::IMLIPWorkspaceManager* workspaceManager_ = nullptr;
+  memory::IMLIPWorkspaceManager* workspaceManager_ = nullptr;
   bool ownWorkspaceManager = false;
   Config & config;
   size_t rows = 0;
diff --git a/include/tadah/mlip/design_matrix/functions/basis_functions/dm_bf_base.h b/include/tadah/mlip/design_matrix/functions/basis_functions/dm_bf_base.h
index 8dc9cc4..0cd2f6d 100644
--- a/include/tadah/mlip/design_matrix/functions/basis_functions/dm_bf_base.h
+++ b/include/tadah/mlip/design_matrix/functions/basis_functions/dm_bf_base.h
@@ -2,6 +2,7 @@
 #define DM_BASIS_FUNCTIONS_H
 
 #include <tadah/mlip/design_matrix/functions/dm_function_base.h>
+#include <tadah/mlip/atom.h>
 #include <tadah/mlip/structure.h>
 #include <tadah/mlip/st_descriptors.h>
 #include <tadah/core/core_types.h>
diff --git a/include/tadah/mlip/models/basis.h b/include/tadah/mlip/models/basis.h
index f00caf0..d5b699b 100644
--- a/include/tadah/mlip/models/basis.h
+++ b/include/tadah/mlip/models/basis.h
@@ -11,6 +11,9 @@
 #include <stdexcept>
 #include <vector>
 
+namespace tadah {
+namespace mlip {
+
 template <typename K>
 class Basis {
 private:
@@ -105,4 +108,6 @@ larger than the amount of available AEDs\n");
     }
   }
 };
+}
+}
 #endif
diff --git a/include/tadah/mlip/models/m_all.h b/include/tadah/mlip/models/m_all.h
index d155789..5a4e798 100644
--- a/include/tadah/mlip/models/m_all.h
+++ b/include/tadah/mlip/models/m_all.h
@@ -1,3 +1,7 @@
 //#include "m_tadah_base.h"
 #include <tadah/mlip/models/m_blr.h>
 #include <tadah/mlip/models/m_krr.h>
+namespace tadah {
+namespace mlip {
+}
+}
diff --git a/include/tadah/mlip/models/m_blr.h b/include/tadah/mlip/models/m_blr.h
index 7ba9d85..ceca902 100644
--- a/include/tadah/mlip/models/m_blr.h
+++ b/include/tadah/mlip/models/m_blr.h
@@ -15,6 +15,8 @@
 #include <type_traits>
 #include <iostream>
 
+namespace tadah {
+namespace mlip {
 /**
  * @class M_BLR
  * @brief Bayesian Linear Regression (BLR).
@@ -105,7 +107,7 @@ public:
     train(desmat);
   }
 
-  void train(StructureDB &stdb, DC_Base &dc) {
+  void train(const StructureDB &stdb, const NeighborListDB &nldb, DC_Base &dc) {
 
     if(config.template get<bool>("NORM")) {
 
@@ -131,7 +133,7 @@ public:
       config.add("STRESS", stress);
     }
 
-    desmat.build(stdb,norm,dc);
+    desmat.build(stdb,nldb,norm,dc);
     train(desmat);
   }
 
@@ -332,4 +334,6 @@ non linear basis function\n");
   using M_BLR_Train<BF>::weights;
   using M_BLR_Train<BF>::Sigma;
 };
+}
+}
 #endif
diff --git a/include/tadah/mlip/models/m_krr.h b/include/tadah/mlip/models/m_krr.h
index 662e200..609c380 100644
--- a/include/tadah/mlip/models/m_krr.h
+++ b/include/tadah/mlip/models/m_krr.h
@@ -16,6 +16,8 @@
 #include <type_traits>
 #include <iostream>
 
+namespace tadah {
+namespace mlip {
 /**
  * @class M_KRR
  * @brief Kernel Ridge Regression (KRR) with Empirical Kernel Map (EKM).
@@ -405,4 +407,6 @@ non linear kernel\n");
   using M_KRR_Train<K>::kernel;
   using M_KRR_Train<K>::ekm;
 };
+}
+}
 #endif
diff --git a/include/tadah/mlip/models/m_tadah_base.h b/include/tadah/mlip/models/m_tadah_base.h
index 3285cf0..4522eef 100644
--- a/include/tadah/mlip/models/m_tadah_base.h
+++ b/include/tadah/mlip/models/m_tadah_base.h
@@ -11,6 +11,8 @@
 #include <tadah/models/m_core.h>
 #include <tadah/models/m_predict.h>
 
+namespace tadah {
+namespace mlip {
 /** This interface provides functionality required from all models.
  */
 class M_Tadah_Base:
@@ -102,5 +104,7 @@ public:
   virtual StructureDB predict(StructureDB &stdb)=0;
 
 };
+}
+}
 //template<> inline Registry<M_Tadah_Base,DM_Function_Base&,Config&>::Map Registry<M_Tadah_Base,DM_Function_Base&,Config&>::registry{};
 #endif
diff --git a/include/tadah/mlip/neighbor_calc.h b/include/tadah/mlip/neighbor_calc.h
index e2e7e59..a26b9a7 100644
--- a/include/tadah/mlip/neighbor_calc.h
+++ b/include/tadah/mlip/neighbor_calc.h
@@ -1,6 +1,6 @@
 #pragma once
 
-#include "neighbor_list.h"
+#include <tadah/mlip/neighbor_list.h>
 #include <vector>
 #include <cmath>
 
@@ -103,6 +103,8 @@ public:
             }
         }
     }
+  void build (const StructureDB &stdb, NeighborList &nList, double cutoff) {
+  }
 };
 
 } // end namespace mlip
diff --git a/include/tadah/mlip/neighbor_list_db.h b/include/tadah/mlip/neighbor_list_db.h
index 1143607..f716070 100644
--- a/include/tadah/mlip/neighbor_list_db.h
+++ b/include/tadah/mlip/neighbor_list_db.h
@@ -3,8 +3,9 @@
 #include <cstddef>
 #include <stdexcept>
 
-#include "neighbor_list.h"
-#include "structure_db.h"
+#include <tadah/mlip/neighbor_list.h>
+#include <tadah/mlip/structure_db.h>
+#include <tadah/mlip/structure_neighbor_view.h>
 
 namespace tadah {
 namespace mlip {
@@ -101,6 +102,22 @@ public:
         return nlist_.getShiftXPtr(gIdx);
     }
 
+    /**
+     * @brief Creates a StructureNeighborView for a given Structure, enabling local-atom queries.
+     *
+     * This method constructs a view object that references nlist_ (the global HPC arrays)
+     * and the selected structure. The code can then call numNeighbors(localIdx) or getNeighborsPtr(localIdx)
+     * without manual offset logic.
+     *
+     * @param s A reference to the chosen Structure from the same DB.
+     * @return A lightweight view object referencing the HPC data subrange for s.
+     */
+    inline StructureNeighborView createStructureView(const Structure &s) const
+    {
+        // Optionally check that s.db_ == &db_ if you want to ensure they belong together.
+        return StructureNeighborView(nlist_, s);
+    }
+
 private:
     /**
      * @brief A read-only reference to the HPC structure database.
diff --git a/include/tadah/mlip/structure_neighbor_view.h b/include/tadah/mlip/structure_neighbor_view.h
new file mode 100644
index 0000000..78dc5c7
--- /dev/null
+++ b/include/tadah/mlip/structure_neighbor_view.h
@@ -0,0 +1,122 @@
+#pragma once
+
+#include <cstddef>
+#include <stdexcept>
+#include "neighbor_list.h"
+#include "structure_db.h"
+
+namespace tadah {
+namespace mlip {
+
+/**
+ * @class StructureNeighborView
+ * @brief Creates a read-only slice of the global NeighborList for a single Structure.
+ *
+ * This design references a specific "Structure" (with known offset_ and natoms()) 
+ * and a global "NeighborList" that covers all atoms in the HPC arrays.
+ * Methods compute each atom's global index on the fly (structure.offset_ + localIndex)
+ * and retrieve the neighbor data from the HPC neighbor list. 
+ *
+ * No duplication of neighbor arrays occurs. 
+ * This allows queries on "per-structure" neighbor data while still leveraging 
+ * the single global HPC neighbor list in memory.
+ */
+class StructureNeighborView
+{
+public:
+    /**
+     * @brief Constructs the view with references to a single Structure and the global HPC list.
+     *
+     * The code checks that the Structure's database pointer matches the DB 
+     * that the "NeighborList" belongs to, if desired. That step is optional 
+     * and can be enforced by additional runtime checks if multiple DBs exist.
+     *
+     * @param globalList Reference to the global HPC neighbor list.
+     * @param structure  Reference to the specific Structure for local atom offsets.
+     */
+    StructureNeighborView(const NeighborList &globalList, const Structure &structure)
+    : globalList_(globalList), str_(structure)
+    {
+        // Optional: check that str_.db_ matches the DB used by the globalList if needed.
+    }
+
+    /**
+     * @brief Returns how many atoms this structure possesses in its subrange.
+     */
+    inline std::size_t numAtoms() const
+    {
+        return str_.natoms();
+    }
+
+    /**
+     * @brief Returns how many neighbors atom (localAtomIndex) has within this structure's sub-slice.
+     *
+     * Internally, the HPC neighbor list is indexed by a global atom index.
+     * The method computes: globalIdx = str_.offset_ + localAtomIndex.
+     */
+    inline std::size_t numNeighbors(std::size_t localAtomIndex) const
+    {
+        if (localAtomIndex >= str_.natoms()) {
+            throw std::out_of_range("StructureNeighborView::numNeighbors: local atom index out of range");
+        }
+        std::size_t gIdx = str_.offset_ + localAtomIndex;
+        return globalList_.numNeighbors(gIdx);
+    }
+
+    /**
+     * @brief Returns a pointer to the neighbor indices for atom (localAtomIndex) in this structure.
+     *
+     * The caller typically pairs this with numNeighbors() to iterate over neighbor indices.
+     */
+    inline const std::size_t* getNeighborsPtr(std::size_t localAtomIndex) const
+    {
+        if (localAtomIndex >= str_.natoms()) {
+            throw std::out_of_range("StructureNeighborView::getNeighborsPtr: local atom index out of range");
+        }
+        std::size_t gIdx = str_.offset_ + localAtomIndex;
+        return globalList_.getNeighborsPtr(gIdx);
+    }
+
+    /**
+     * @brief Returns a pointer to the mirror index array for symmetrical pairs, if needed.
+     */
+    inline const std::size_t* getMirrorIndexPtr(std::size_t localAtomIndex) const
+    {
+        if (localAtomIndex >= str_.natoms()) {
+            throw std::out_of_range("StructureNeighborView::getMirrorIndexPtr: local atom index out of range");
+        }
+        std::size_t gIdx = str_.offset_ + localAtomIndex;
+        return globalList_.getMirrorIndexPtr(gIdx);
+    }
+
+    /**
+     * @brief Retrieves X shift array pointer for atom (localAtomIndex).
+     *        Similar methods can be added for Y, Z, etc.
+     */
+    inline const int* getShiftXPtr(std::size_t localAtomIndex) const
+    {
+        if (localAtomIndex >= str_.natoms()) {
+            throw std::out_of_range("StructureNeighborView::getShiftXPtr: local atom index out of range");
+        }
+        std::size_t gIdx = str_.offset_ + localAtomIndex;
+        return globalList_.getShiftXPtr(gIdx);
+    }
+
+    // Additional shiftYPtr, shiftZPtr, etc., can be provided if your code needs them.
+
+private:
+    /**
+     * @brief Reference to the global HPC neighbor list that covers all atoms.
+     */
+    const NeighborList &globalList_;
+
+    /**
+     * @brief Reference to the specific structure. 
+     *        offset_ + localIndex => global HPC neighbor index.
+     */
+    const Structure &str_;
+};
+
+} // namespace mlip
+} // namespace tadah
+
diff --git a/include/tadah/mlip/trainer.h b/include/tadah/mlip/trainer.h
index ef20679..8ebc323 100644
--- a/include/tadah/mlip/trainer.h
+++ b/include/tadah/mlip/trainer.h
@@ -11,6 +11,9 @@
 #include <tadah/mlip/memory/IMLIPWorkspaceManager.h>
 #include <tadah/mlip/memory/MLIPWorkspaceManager.h>
 
+#include <tadah/mlip/neighbor_calc.h>
+#include <tadah/mlip/neighbor_list_db.h>
+
 #include <iostream>
 
 namespace tadah {
@@ -20,10 +23,11 @@ class Trainer {
     Config config;
     DC_Selector DCS;
     DescriptorsCalc<> dc;
-    NNFinder nnf;
+    // NNFinder nnf;
     DM_Function_Base *fb;
     M_Tadah_Base *model;
     DesignMatrix<DM_Function_Base&> dm;
+  StructureDB stdb;
 
     ~Trainer() {
       if(model)
@@ -36,7 +40,7 @@ class Trainer {
       DCS(config),
       dc(config,*DCS.d2b,*DCS.d3b,*DCS.dmb,
           *DCS.c2b,*DCS.c3b,*DCS.cmb),
-      nnf(config),
+      // nnf(config),
       fb(CONFIG::factory<DM_Function_Base,Config&>(
             config.get<std::string>("MODEL",1),config)),
       model(CONFIG::factory<M_Tadah_Base,DM_Function_Base&,Config&,tadah::mlip::memory::IMLIPWorkspaceManager&>
@@ -48,7 +52,7 @@ class Trainer {
   }
 
     void train(StructureDB &stdb) {
-      nnf.calc(stdb);
+      // nnf.calc(stdb);
       model->train(stdb,dc);
     }
 
-- 
GitLab