From dfa76e78c6a6289f0c6811c7c44e1745aa8e5f9a Mon Sep 17 00:00:00 2001
From: mkirsz <s1351949@sms.ed.ac.uk>
Date: Fri, 22 Nov 2024 23:55:23 +0000
Subject: [PATCH] Replace std::random_shuffle with std::shuffle

---
 include/tadah/mlip/models/basis.h | 189 ++++++++++++++++++------------
 1 file changed, 115 insertions(+), 74 deletions(-)

diff --git a/include/tadah/mlip/models/basis.h b/include/tadah/mlip/models/basis.h
index 74a4071..80f2bbe 100644
--- a/include/tadah/mlip/models/basis.h
+++ b/include/tadah/mlip/models/basis.h
@@ -6,98 +6,139 @@
 #include <tadah/core/core_types.h>
 #include <tadah/core/config.h>
 
+#include <algorithm>
 #include <numeric>
+#include <random>
 #include <stdexcept>
 #include <vector>
 
+/**
+ * Holds the basis vectors used by a model and builds them from descriptors.
+ *
+ * Each column of `b` is one basis vector. `T` has one entry per basis
+ * vector and is only written by prep_basis_for_krr().
+ *
+ * @tparam K not used by any member of this class.
+ */
 template <typename K>
 class Basis {
-    private:
-        Config &config;
-        int verbose;
-    public:
-       Matrix b;
-       t_type T;    // Vectors corresponding to basis vectors
-        Basis(Config &c):
-            config(c),
-            verbose(c.get<int>("VERBOSE"))
+private:
+  Config &config;
+  int verbose;
+public:
+  Matrix b;
+  t_type T;    // Vectors corresponding to basis vectors
+  /// Keeps a reference to the configuration and reads its VERBOSE value.
+  Basis(Config &c):
+    config(c),
+    verbose(c.get<int>("VERBOSE"))
 
-    {}
+  {}
 
-        void set_basis(Matrix &b_) {
-            b=b_;
-        }
-        void build_random_basis(size_t s, StDescriptorsDB &st_desc_db) {
+  /// Assigns b_ to the stored basis matrix `b`.
+  void set_basis(Matrix &b_) {
+    b=b_;
+  }
+  /**
+   * Fills `b` with a bias column followed by s-1 randomly chosen AEDs.
+   *
+   * Column 0 is zero except for b(0,0)=1. The remaining columns are AEDs
+   * selected by shuffling every (structure, atom) index pair.
+   *
+   * @param s           number of basis vectors, i.e. columns of `b`
+   * @param st_desc_db  source of the AEDs; st_desc_db(0) is accessed
+   * @throws std::runtime_error if fewer than s AEDs are available
+   */
+  void build_random_basis(size_t s, StDescriptorsDB &st_desc_db) {
 
-            // generate indices
-            std::vector<std::tuple<size_t,size_t>> indices;
-            size_t counter=0;
-            for( size_t st = 0; st < st_desc_db.size(); st++ ) {
-                for( size_t a = 0; a < st_desc_db(st).naed() ; a++ ) {
-                    indices.push_back(std::tuple<size_t,size_t>(st,a));
-                    counter++;
-                }
-            }
+    // generate indices
+    std::vector<std::tuple<size_t,size_t>> indices;
+    size_t counter=0;
+    for( size_t st = 0; st < st_desc_db.size(); st++ ) {
+      for( size_t a = 0; a < st_desc_db(st).naed() ; a++ ) {
+        indices.push_back(std::tuple<size_t,size_t>(st,a));
+        counter++;
+      }
+    }
 
-            if (counter < s) {
-                throw std::runtime_error("The number of requestd basis vectors is\n \
-                        larger than the amount of available AEDs\n");
-            }
+    if (counter < s) {
+      throw std::runtime_error("The number of requestd basis vectors is\n \
+larger than the amount of available AEDs\n");
+    }
 
-            std::random_shuffle(indices.begin(), indices.end());
+    std::random_device rd;
+    std::default_random_engine rng(rd()); // Initialize random engine
+    std::shuffle(indices.begin(), indices.end(),rng);
 
-            b.resize(st_desc_db(0).dim(),s);
-            b.set_zero();
-            // set first basis function as "bias vector"
-            b(0,0)=1;
-            for (size_t i=1; i<s; ++i) {
-                size_t st = std::get<0>(indices[i]);
-                size_t a = std::get<1>(indices[i]);
-                const aed_type2 &aed = st_desc_db(st).get_aed(a);
-                for (size_t j=0; j<aed.size(); ++j) {
-                    b(j,i)=aed[j];
-                }
-            }
-        }
-        void prep_basis_for_krr(StDescriptorsDB &st_desc_db,
-                StructureDB &stdb) {
+    b.resize(st_desc_db(0).dim(),s);
+    b.set_zero();
+    // set first basis function as "bias vector"
+    b(0,0)=1;
+    // Start at 1 to keep the bias column; indices[0] is never used.
+    for (size_t i=1; i<s; ++i) {
+      size_t st = std::get<0>(indices[i]);
+      size_t a = std::get<1>(indices[i]);
+      const aed_type2 &aed = st_desc_db(st).get_aed(a);
+      for (size_t j=0; j<aed.size(); ++j) {
+        b(j,i)=aed[j];
+      }
+    }
+  }
+  /**
+   * Fills `b` and `T` from randomly chosen structures.
+   *
+   * The number of basis vectors is read from the SBASIS config key.
+   * Column 0 of `b` is the bias vector. Every other column is the mean
+   * of the AEDs of one chosen structure, and the matching entry of `T`
+   * is that structure's energy divided by its number of AEDs.
+   *
+   * @param st_desc_db  descriptors; must have the same size as stdb
+   * @param stdb        structures that supply the energies
+   * @throws std::runtime_error if the two databases differ in size or
+   *         contain fewer than SBASIS structures
+   */
+  void prep_basis_for_krr(StDescriptorsDB &st_desc_db,
+                          StructureDB &stdb) {
 
-            size_t s = config.get<size_t>("SBASIS");
+    size_t s = config.get<size_t>("SBASIS");
 
-            if (stdb.size() != st_desc_db.size()) {
-                throw std::runtime_error("The size of StructureDB is different \n \
-                        from StDescriptorsDB. This is likely to be a bug.\n");
-            }
+    if (stdb.size() != st_desc_db.size()) {
+      throw std::runtime_error("The size of StructureDB is different \n \
+from StDescriptorsDB. This is likely to be a bug.\n");
+    }
 
-            // generate indices
-            std::vector<size_t> indices;
-            for( size_t st = 0; st < stdb.size(); st++ ) {
-                indices.push_back(st);
-            }
+    // generate indices
+    std::vector<size_t> indices;
+    for( size_t st = 0; st < stdb.size(); st++ ) {
+      indices.push_back(st);
+    }
 
-            if (indices.size() < s) {
-                throw std::runtime_error("The number of requestd basis vectors is\n \
-                        larger than the amount of available AEDs\n");
-            }
+    if (indices.size() < s) {
+      throw std::runtime_error("The number of requestd basis vectors is\n \
+larger than the amount of available AEDs\n");
+    }
 
-            std::random_shuffle(indices.begin(), indices.end());
+    std::random_device rd;
+    std::default_random_engine rng(rd()); // Initialize random engine
+    std::shuffle(indices.begin(), indices.end(),rng);
 
-            b.resize(st_desc_db(0).dim(),s);
-            T.resize(s);
-            b.set_zero();
-            b(0,0)=1;
-            // Here we add all atomic energy descriptors into single
-            // descriptor which represents total energy of this configuration
-            for (size_t i=1; i<s; ++i) {
-                const size_t st = indices[i];
-                T(i)=stdb(st).energy/st_desc_db(st).naed();
-                for( size_t a=0; a<st_desc_db(st).naed(); a++ ) {
-                    const aed_type2 &aed = st_desc_db(st).get_aed(a);
-                    for (size_t j=0; j<aed.size(); ++j) {
-                        b(j,i)+=aed[j]/st_desc_db(st).naed();
-                    }
-                }
-            }
+    b.resize(st_desc_db(0).dim(),s);
+    T.resize(s);
+    b.set_zero();
+    // Column 0 is the bias vector; T(0) is not assigned in this function.
+    b(0,0)=1;
+    // Here we add all atomic energy descriptors into single
+    // descriptor which represents total energy of this configuration
+    for (size_t i=1; i<s; ++i) {
+      const size_t st = indices[i];
+      T(i)=stdb(st).energy/st_desc_db(st).naed();
+      for( size_t a=0; a<st_desc_db(st).naed(); a++ ) {
+        const aed_type2 &aed = st_desc_db(st).get_aed(a);
+        for (size_t j=0; j<aed.size(); ++j) {
+          b(j,i)+=aed[j]/st_desc_db(st).naed();
         }
+      }
+    }
+  }
 };
 #endif
-- 
GitLab