From 083d287f4e36f3fd44dad58483516aa7f7536193 Mon Sep 17 00:00:00 2001
From: mkirsz <s1351949@sms.ed.ac.uk>
Date: Tue, 24 Dec 2024 23:49:26 +0000
Subject: [PATCH] Fixing of ML params, also shifted cos cutoff

---
 include/tadah/models/cutoffs.h            | 15 +++++
 include/tadah/models/descriptors/d_base.h |  1 +
 include/tadah/models/m_blr_train.h        | 63 +++++++++++++++++--
 include/tadah/models/m_krr_train.h        | 73 +++++++++++++++++++----
 include/tadah/models/m_train.h            | 64 ++++++++++++++++++--
 src/cutoffs.cpp                           | 42 +++++++++++++
 src/m_train.cpp                           |  4 ++
 tests/test_cutoffs.cpp                    | 43 +++++++++++++
 tests/test_svd.cpp                        |  1 -
 9 files changed, 286 insertions(+), 20 deletions(-)
 create mode 100644 src/m_train.cpp

diff --git a/include/tadah/models/cutoffs.h b/include/tadah/models/cutoffs.h
index 5403645..882f0c9 100644
--- a/include/tadah/models/cutoffs.h
+++ b/include/tadah/models/cutoffs.h
@@ -139,5 +139,20 @@ class Cut_Poly2 : public Cut_Base {
     double calc(double r);
     double calc_prime(double r);
 };
+class Cut_Cos_S : public Cut_Base {
+  private:
+    std::string lab = "Cut_Cos_S";
+    double rcut, rcut_sq, rcut_inv;
+    double rcut_inner;
+  public:
+    Cut_Cos_S();
+    Cut_Cos_S(double rcut);
+    std::string label();
+    void set_rcut(const double r);
+    double get_rcut();
+    double get_rcut_sq();
+    double calc(double r);
+    double calc_prime(double r);
+};
 //template<> inline Registry<Cut_Base,double>::Map Registry<Cut_Base,double>::registry{};
 #endif
diff --git a/include/tadah/models/descriptors/d_base.h b/include/tadah/models/descriptors/d_base.h
index e3499b8..e02737c 100644
--- a/include/tadah/models/descriptors/d_base.h
+++ b/include/tadah/models/descriptors/d_base.h
@@ -41,6 +41,7 @@ protected:
 public:
   std::vector<std::string> keys; // keys required by this descriptor
   size_t nparams; // number of params which follow TYPExB
+  std::vector<bool> is_optimizable;
   static void get_grid(const Config &c, std::string key, v_type &v);
   static v_type get_grid(std::vector<std::string>&); // expand joined grids
   int verbose;
diff --git a/include/tadah/models/m_blr_train.h b/include/tadah/models/m_blr_train.h
index d78bd43..fdbc457 100644
--- a/include/tadah/models/m_blr_train.h
+++ b/include/tadah/models/m_blr_train.h
@@ -1,6 +1,8 @@
 #ifndef M_BLR_TRAIN_H
 #define M_BLR_TRAIN_H
 
+#include "tadah/core/core_types.h"
+#include "tadah/core/utils.h"
 #include <tadah/models/m_train.h>
 #include <tadah/models/m_blr_core.h>
 #include <tadah/models/linear_regressor.h>
@@ -9,6 +11,7 @@
 #include <tadah/core/config.h>
 
 #include <iostream>
+#include <vector>
 
 /**
  * @class M_BLR_Train
@@ -54,13 +57,63 @@ public:
   * @param T Target vector.
   * @throws std::runtime_error if the object is already trained.
   */
-  void train(phi_type &Phi, t_type &T) {
-    if (trained) {
-      throw std::runtime_error("This object is already trained!");
+  void train(phi_type &Phi, t_type &T) {
+    if (trained) {
+      throw std::runtime_error("This object is already trained!");
+    }
+
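+    // Optional: pin selected weights during the fit. Illustrative config
+    // entries (1-based indices; the exact Config syntax is assumed here):
+    //   FIXINDEX  2 5
+    //   FIXWEIGHT 0.0 0.0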
+    if (config.exist("FIXWEIGHT") && config.exist("FIXINDEX")) {
+      std::vector<std::string> indices_str(config.size("FIXINDEX"));
+      config.get("FIXINDEX", indices_str);
+
+      // Parse indices and perform checks
+      std::vector<size_t> indices = parse_indices(indices_str);
+
+      // Check that the min index is >= 1 and max index is <= Phi.cols()
+      if (!indices.empty()) {
+        size_t max_index = *std::max_element(indices.begin(), indices.end());
+        if (*std::min_element(indices.begin(), indices.end()) < 1 || max_index > Phi.cols()) {
+          throw std::runtime_error("FIXINDEX: Indices out of bounds: valid range is from 1 to " + std::to_string(Phi.cols()));
         }
-    LinearRegressor::train(config, Phi, T, weights, Sigma);
-    trained = true;
+      }
+
+      // Adjust for 0-based indexing
+      for (auto &i : indices) i--;
+
+      t_type w_f(config.size("FIXWEIGHT"));
+      if (w_f.size() != indices.size()) {
+        throw std::runtime_error("FIXWEIGHT and FIXINDEX differ in size: " + std::to_string(w_f.size()) + " and " + std::to_string(indices.size()));
+      }
+
+      config.get("FIXWEIGHT", w_f);
+
+      t_type T_r;
+      auto move_map = prep_train_with_residuals(Phi, T, indices, w_f, T_r);
+
+      // Resize Phi for training
+      Phi.resize(Phi.rows(), Phi.cols() - indices.size());
+
+      LinearRegressor::train(config, Phi, T_r, weights, Sigma);
+
+      t_type w_temp(weights.size() + indices.size());
+
+      for (size_t i = 0; i < w_f.size(); ++i) {
+        w_temp[w_temp.size() - w_f.size() + i] = w_f[i];
+      }
+
+      for (size_t i = 0; i < weights.size(); ++i) {
+        w_temp[i] = weights[i];
+      }
+
+      weights = w_temp;
+      reverse_vector(weights, move_map);
+      trained = true;
+    }
+    else {
+      LinearRegressor::train(config, Phi, T, weights, Sigma);
+      trained = true;
     }
+  }
 };
 
 #endif // M_BLR_TRAIN_H
diff --git a/include/tadah/models/m_krr_train.h b/include/tadah/models/m_krr_train.h
index 24e7d83..d4e7903 100644
--- a/include/tadah/models/m_krr_train.h
+++ b/include/tadah/models/m_krr_train.h
@@ -60,21 +60,74 @@ public:
   * @param T Target vector.
   * @throws std::runtime_error if the object is already trained.
   */
-  void train(phi_type &Phi, t_type &T) {
-    if (trained) {
-      throw std::runtime_error("This object is already trained!");
-    }
-    if (kernel.get_label() != "Kern_Linear") {
-      ekm.project(Phi);
+  void train(phi_type &Phi, t_type &T) {
+    if (trained) {
+      throw std::runtime_error("This object is already trained!");
+    }
+
+    if (config.exist("FIXWEIGHT") && config.exist("FIXINDEX")) {
+      std::vector<std::string> indices_str(config.size("FIXINDEX"));
+      config.get("FIXINDEX", indices_str);
+
+      // Parse indices and perform checks
+      std::vector<size_t> indices = parse_indices(indices_str);
+
+      // Check that the min index is >= 1 and max index is <= Phi.cols()
+      if (!indices.empty()) {
+        size_t max_index = *std::max_element(indices.begin(), indices.end());
+        if (*std::min_element(indices.begin(), indices.end()) < 1 || max_index > Phi.cols()) {
+          throw std::runtime_error("FIXINDEX: Indices out of bounds: valid range is from 1 to " + std::to_string(Phi.cols()));
         }
-    LinearRegressor::train(config, Phi, T, weights, Sigma);
+      }
+
+      // Adjust for 0-based indexing
+      for (auto &i : indices) i--;
+
+      t_type w_f(config.size("FIXWEIGHT"));
+      if (w_f.size() != indices.size()) {
+        throw std::runtime_error("FIXWEIGHT and FIXINDEX differ in size: " + std::to_string(w_f.size()) + " and " + std::to_string(indices.size()));
+      }
+
+      config.get("FIXWEIGHT", w_f);
 
-    if (kernel.get_label() != "Kern_Linear") {
-      weights = ekm.EKM_mat * weights;
+      t_type T_r;
+      auto move_map = prep_train_with_residuals(Phi, T, indices, w_f, T_r);
+
+      reverse_columns(Phi, move_map);
+
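+      // The fixed columns' contribution is already subtracted in T_r, so
+      // zero these columns to exclude them from the fit; their fixed
+      // weights are restored after training below.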
+      for (const auto &i : indices) {
+        for (size_t j=0; j<Phi.rows(); ++j) {
+          Phi(j,i)=0;
         }
-    trained = true;
+      }
+
+      if (kernel.get_label() != "Kern_Linear") {
+        ekm.project(Phi);
+      }
+      LinearRegressor::train(config, Phi, T_r, weights, Sigma);
+      if (kernel.get_label() != "Kern_Linear") {
+        weights = ekm.EKM_mat * weights;
+      }
+
+      // reverse_vector(weights, move_map);
+      for (size_t i=0; i<indices.size(); ++i) {
+        weights[indices[i]]=w_f[i];
+      }
+      trained = true;
+    }
+    else {
+      if (kernel.get_label() != "Kern_Linear") {
+        ekm.project(Phi);
+      }
+      LinearRegressor::train(config, Phi, T, weights, Sigma);
+      if (kernel.get_label() != "Kern_Linear") {
+        weights = ekm.EKM_mat * weights;
+      }
+      trained = true;
     }
+  }
+
   /**
    * @brief Standard KRR training using covariance matrix computation.
    *
diff --git a/include/tadah/models/m_train.h b/include/tadah/models/m_train.h
index abcf824..09a2113 100644
--- a/include/tadah/models/m_train.h
+++ b/include/tadah/models/m_train.h
@@ -2,6 +2,7 @@
 #define M_TRAIN_H
 
 #include <tadah/core/core_types.h>
+#include <tadah/core/utils.h>
 
 /**
  * @class M_Train
@@ -11,12 +12,12 @@
  */
 class M_Train {
 public:
-	/**
+  /**
   * @brief Virtual destructor for polymorphic deletion.
   */
-	virtual ~M_Train() {}
+  virtual ~M_Train() {}
 
-	/**
+  /**
   * @brief Pure virtual function to train the model.
   *
   * Must be implemented by derived classes.
   *
@@ -24,7 +25,62 @@ public:
   * @param Phi Design matrix containing input features.
   * @param T Target vector for training.
   */
-	virtual void train(phi_type &Phi, t_type &T) = 0;
+  virtual void train(phi_type &Phi, t_type &T) = 0;
+
+  /**
+   * Swaps columns and fits a reduced matrix to obtain residuals.
+   *
+   * This function processes the full matrix Phi by swapping the columns
+   * specified by indices, creating a view of the fixed columns, Phi_f, and
+   * then using the w_f weights to obtain the prediction T_f = Phi_f * w_f.
+   * The function then computes the residual vector T_r = T - T_f.
+   * On exit, the Phi matrix is rearranged such that the indices.size()
+   * rightmost columns of the original matrix are the same as Phi_f.
+   * The function returns a mapping between the columns of the original Phi
+   * matrix and those of the rearranged one.
+   *
+   * @param Phi     The full matrix to be reduced and fitted to the adjusted
+   *                target vector.
+   * @param T       The full target vector to which the reduced matrix Phi
+   *                will be fitted.
+   * @param indices A vector of indices specifying which columns in the
+   *                original matrix are fixed and should be considered during
+   *                the regression.
+   * @param w_f     A vector of fixed weights corresponding to the specified
+   *                indices. These weights are used to adjust the target
+   *                vector, resulting in residuals.
+   * @param T_r     On return, the residual vector T - Phi_f * w_f.
+   */
+  template <typename M>
+  std::vector<size_t> prep_train_with_residuals(
+      Matrix_Base<M> &Phi, t_type &T, std::vector<size_t> &indices, const t_type w_f, t_type &T_r) {
+
+    if (w_f.size() != indices.size()) {
+      throw std::runtime_error("Size of w_f must be equal to size of indices.");
+    }
+
+    std::sort(indices.begin(), indices.end());
+
+    if (indices.back() >= Phi.cols()) {
+      throw std::runtime_error("Largest index is greater than the number of columns in Phi.");
+    }
+
+    std::vector<size_t> move_map = move_columns_in_place(Phi, indices);
+
+    // Example: indices.size() = 1, indices[0] = 1
+    //   1 4 7      1 7 | 4
+    //   2 5 8  ->  2 8 | 5
+    //   3 6 9      3 9 | 6
+    //   Phi        Phi_o Phi_f
+    //   move_map = {0,2,1}
+
+    // the fixed columns are to the right
+    MatrixView Phi_f(Phi, Phi.cols()-indices.size());
+
+    t_type T_f = Phi_f * w_f;
+
+    T_r = T-T_f;
+
+    return move_map;
+  }
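+
+  // Usage sketch (illustrative, using the 3-column example above):
+  //   std::vector<size_t> indices = {1};   // 0-based column to fix
+  //   t_type w_f(1); w_f[0] = 0.5;         // its fixed weight
+  //   t_type T_r;
+  //   auto move_map = prep_train_with_residuals(Phi, T, indices, w_f, T_r);
+  //   // fit the first Phi.cols()-1 columns against T_r, then use move_map
+  //   // (e.g. via reverse_vector) to restore the original column order.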
 };
 
 #endif // M_TRAIN_H
diff --git a/src/cutoffs.cpp b/src/cutoffs.cpp
index 768516a..b818cda 100644
--- a/src/cutoffs.cpp
+++ b/src/cutoffs.cpp
@@ -5,6 +5,7 @@
 template<> CONFIG::Registry<Cut_Base,double>::Map CONFIG::Registry<Cut_Base,double>::registry{};
 
 CONFIG::Registry<Cut_Base,double>::Register<Cut_Cos> Cut_Cos_1( "Cut_Cos" );
+CONFIG::Registry<Cut_Base,double>::Register<Cut_Cos_S> Cut_Cos_S_1( "Cut_Cos_S" );
 CONFIG::Registry<Cut_Base,double>::Register<Cut_Tanh> Cut_Tanh_1( "Cut_Tanh" );
 CONFIG::Registry<Cut_Base,double>::Register<Cut_Poly2> Cut_Poly_1( "Cut_Poly2" );
 CONFIG::Registry<Cut_Base,double>::Register<Cut_Dummy> Cut_Dummy_1( "Cut_Dummy" );
@@ -165,3 +166,44 @@ double Cut_Poly2::calc_prime(double r) {
   double rs=r-rcut_inner;
   return -30.0*(rs-1.0)*(rs-1.0)*rs*rs;
 }
+
+Cut_Cos_S::Cut_Cos_S() {}
+Cut_Cos_S::Cut_Cos_S(double rcut)
+{
+  set_rcut(rcut);
+  test_rcut(rcut);
+}
+
+std::string Cut_Cos_S::label() {
+  return lab;
+}
+
+void Cut_Cos_S::set_rcut(const double r) {
+  test_rcut(r);
+  rcut=r;
+  rcut_sq=r*r;
+  rcut_inv= r<=0 ? 0.0 : 1.0/r;
+  rcut_inner=rcut-1.0;
+}
+double Cut_Cos_S::get_rcut() {
+  return rcut;
+}
+
+double Cut_Cos_S::get_rcut_sq() {
+  return rcut_sq;
+}
+
+double Cut_Cos_S::calc(double r) {
+  if (r>=rcut) return 0.0;
+  else if (r<=rcut_inner) return 1.0;
+  double rs = (r-rcut_inner)/(rcut-rcut_inner);
+  return 0.5*(1+std::cos(M_PI*rs));
+}
+double Cut_Cos_S::calc_prime(double r) {
+  if (r>=rcut || r<=rcut_inner) return 0.0;
+  else {
+    double rs = (r - rcut_inner) / (rcut - rcut_inner);
+    double drs_dr = 1.0 / (rcut - rcut_inner);
+    return -0.5 * M_PI * std::sin(M_PI * rs) * drs_dr;
+  }
+}
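+// Note: rcut_inner = rcut-1, so the switching region has unit width and
+// calc(r) = 0.5*(1+cos(M_PI*(r-rcut+1))) on (rcut-1, rcut). For rcut=3.4,
+// calc(3.0) = 0.5*(1+cos(0.6*M_PI)) ~= 0.3454915, as checked in the tests.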
diff --git a/src/m_train.cpp b/src/m_train.cpp
new file mode 100644
index 0000000..3bc76b5
--- /dev/null
+++ b/src/m_train.cpp
@@ -0,0 +1,4 @@
+#include <tadah/models/m_train.h>
+
+// void M_Train::train_with_residuals(
+// }
diff --git a/tests/test_cutoffs.cpp b/tests/test_cutoffs.cpp
index 997b03d..aaac9d0 100644
--- a/tests/test_cutoffs.cpp
+++ b/tests/test_cutoffs.cpp
@@ -162,3 +162,46 @@ TEST_CASE( "Testing Cutoffs: Cut_Poly2", "[Cut_Poly2]" ) {
       Catch::Matchers::WithinAbs(-1.728, 1e-12));
   delete c2b;
 }
+TEST_CASE( "Testing Cutoffs: Cut_Cos_S", "[Cut_Cos_S]" ) {
+
+  REQUIRE_NOTHROW(Cut_Cos_S());
+  double rcut2b = 6.2;
+  double rcut2bsq = rcut2b*rcut2b;
+  using Cut = Cut_Cos_S;
+  std::string cuttype = "Cut_Cos_S";
+  Cut_Base *c2b = new Cut( rcut2b );
+
+  REQUIRE( c2b->label() == cuttype );
+
+  REQUIRE( c2b->calc(rcut2b) < std::numeric_limits<double>::min() );
+  REQUIRE( c2b->calc_prime(rcut2b) < std::numeric_limits<double>::min() );
+  REQUIRE( std::abs(c2b->get_rcut()-rcut2b)<std::numeric_limits<double>::min() );
+  REQUIRE( std::abs(c2b->get_rcut_sq()-rcut2bsq)<std::numeric_limits<double>::min() );
+
+  // cutoff cannot be negative
+  double temp = -0.1;
+  REQUIRE_THROWS(Cut( temp ));
+  REQUIRE_THROWS_AS(c2b->set_rcut(temp), std::runtime_error);
+  REQUIRE_THROWS_AS(c2b->test_rcut(temp), std::runtime_error);
+
+  // recheck after resetting cutoff
+  rcut2b=3.4;
+  rcut2bsq=rcut2b*rcut2b;
+  c2b->set_rcut(100000);
+  c2b->set_rcut(rcut2b);
+  REQUIRE( c2b->calc(rcut2b) < std::numeric_limits<double>::min() );
+  REQUIRE( c2b->calc_prime(rcut2b) < std::numeric_limits<double>::min() );
+  REQUIRE( std::abs(c2b->get_rcut()-rcut2b)<std::numeric_limits<double>::min() );
+  REQUIRE( std::abs(c2b->get_rcut_sq()-rcut2bsq)<std::numeric_limits<double>::min() );
+
+  REQUIRE_THAT(c2b->calc(2.0),
+      Catch::Matchers::WithinAbs(1, 1e-12));
+  REQUIRE_THAT(c2b->calc_prime(2.0),
+      Catch::Matchers::WithinAbs(0, 1e-12));
+
+  REQUIRE_THAT(c2b->calc(3.0),
+      Catch::Matchers::WithinAbs(0.3454915028, 1e-10));
+  REQUIRE_THAT(c2b->calc_prime(3.0),
+      Catch::Matchers::WithinAbs(-1.4939160824, 1e-10));
+  delete c2b;
+}
diff --git a/tests/test_svd.cpp b/tests/test_svd.cpp
index f216f20..d434096 100644
--- a/tests/test_svd.cpp
+++ b/tests/test_svd.cpp
@@ -1,7 +1,6 @@
 #include "catch2/catch.hpp"
 #include <tadah/core/maths.h>
 #include <tadah/models/svd.h>
-#include <iomanip>
 
 void multiplyMatrices(const double* A, const double* B, double* C, int m, int n, int p) {
   // C = A * B
--
GitLab