From 14f2e1ce3b3dc972330004537cd55c977d4ad2fc Mon Sep 17 00:00:00 2001
From: Marcin Kirsz <mkirsz@ed.ac.uk>
Date: Fri, 7 Feb 2025 13:12:01 +0000
Subject: [PATCH] LAMBDA 0 rcond number

---
 .../models/descriptors/d_basis_functions.h    | 149 +++++++++---------
 include/tadah/models/linear_regressor.h       |   3 +-
 include/tadah/models/m_krr_train.h            |   1 +
 include/tadah/models/ols.h                    |   3 +-
 include/tadah/models/ridge_regression.h       |   2 +-
 tests/test_ols.cpp                            |   2 +-
 6 files changed, 82 insertions(+), 78 deletions(-)

diff --git a/include/tadah/models/descriptors/d_basis_functions.h b/include/tadah/models/descriptors/d_basis_functions.h
index 09d4683..dcf14f4 100644
--- a/include/tadah/models/descriptors/d_basis_functions.h
+++ b/include/tadah/models/descriptors/d_basis_functions.h
@@ -5,101 +5,104 @@
 
 // GAUSSIANS
 inline double G(double r, double eta, double miu) {
-    return exp(-eta*(r-miu)*(r-miu));
+  return exp(-eta*(r-miu)*(r-miu));
 }
 inline double G(double r, double eta, double miu, double f) {
-    return exp(-eta*(r-miu)*(r-miu))*f;
+  return exp(-eta*(r-miu)*(r-miu))*f;
 }
 inline double dG(double r, double eta, double miu) {
-    return -2.0*eta*(r-miu)*exp(-eta*(r-miu)*(r-miu));
+  return -2.0*eta*(r-miu)*exp(-eta*(r-miu)*(r-miu));
 }
 inline double dG(double r, double eta, double miu, double f, double fp) {
-    return exp(-eta*(r-miu)*(r-miu) )*(-2.0*f*eta*(r-miu)+fp);
+  return exp(-eta*(r-miu)*(r-miu) )*(-2.0*f*eta*(r-miu)+fp);
 }
 
 // BLIPS
 inline double B(double x_c)
-    // x_c = eta*(rij-r_s)
-    // return B(x_c)
+  // r_s = miu
+  // x_c = eta*(rij-r_s)
 {
-    if (-2.0 < x_c && x_c <= -1.0) {
-        double b = 2.0 + x_c;
-        return 0.25 * b*b*b;
-    }
-    else if (-1.0 < x_c && x_c <= 1.0) {
-        double x2 = x_c*x_c;
-        double x3 = x2*x_c;
-        return 1.0 - 1.5*x2 + 0.75*std::fabs(x3);
-    }
-    else if (1.0 < x_c && x_c < 2.0) {
-        double b = 2.0 - x_c;
-        return 0.25 * b*b*b;
-    }
-    else {
-        return 0.0;
-    }
+  if (-2.0 < x_c && x_c <= -1.0) {
+    double b = 2.0 + x_c;
+    return 0.25 * b*b*b;
+  }
+  else if (-1.0 < x_c && x_c <= 1.0) {
+    double x2 = x_c*x_c;
+    double x3 = x2*x_c;
+    return 1.0 - 1.5*x2 + 0.75*std::fabs(x3);
+  }
+  else if (1.0 < x_c && x_c < 2.0) {
+    double b = 2.0 - x_c;
+    return 0.25 * b*b*b;
+  }
+  else {
+    return 0.0;
+  }
 }
 inline double B(double r, double eta, double miu) {
   return B(eta*(r-miu));
 }
 inline double B(double x_c, double f)
-    // x_c = eta*(rij-r_s)
-    // return B(x_c)*f
+  // x_c = eta*(rij-r_s)
+  // return B(x_c)*f
 {
-    if (-2.0 < x_c && x_c <= -1.0) {
-        double b = 2.0 + x_c;
-        return 0.25 * b*b*b *f;
-    }
-    else if (-1.0 < x_c && x_c <= 1.0) {
-        double x2 = x_c*x_c;
-        double x3 = x2*x_c;
-        return f*(1.0 - 1.5*x2 + 0.75*std::fabs(x3));
-    }
-    else if (1.0 < x_c && x_c < 2.0) {
-        double b = 2.0 - x_c;
-        return f*(0.25 * b*b*b);
-    }
-    else {
-        return 0.0;
-    }
+  if (-2.0 < x_c && x_c <= -1.0) {
+    double b = 2.0 + x_c;
+    return 0.25 * b*b*b *f;
+  }
+  else if (-1.0 < x_c && x_c <= 1.0) {
+    double x2 = x_c*x_c;
+    double x3 = x2*x_c;
+    return f*(1.0 - 1.5*x2 + 0.75*std::fabs(x3));
+  }
+  else if (1.0 < x_c && x_c < 2.0) {
+    double b = 2.0 - x_c;
+    return f*(0.25 * b*b*b);
+  }
+  else {
+    return 0.0;
+  }
 }
 inline double dB(double x_c, double eta)
-    // def: x_c = eta*(rij-r_s)
-    // return d/dr_ij B(x_c) = eta*d/dx_c B(x_c)
+  // def: x_c = eta*(rij-r_s)
+  // return d/dr_ij B(x_c) = eta*d/dx_c B(x_c)
 {
-    if (-2.0 < x_c && x_c <= -1.0) {
-        double d = 2.0+x_c;
-        return 0.75*eta*d*d;
-    }
-    else if (-1.0 < x_c && x_c <= 1.0) {
-        return -3.0*eta*x_c + 2.25*eta*x_c*std::fabs(x_c);
-    }
-    else if (1.0 < x_c && x_c < 2.0) {
-        double d = 2.0-x_c;
-        return -0.75*eta*d*d;
-    }
-    else {
-        return 0.0;
-    }
+  if (-2.0 < x_c && x_c <= -1.0) {
+    double d = 2.0+x_c;
+    return 0.75*eta*d*d;
+  }
+  else if (-1.0 < x_c && x_c <= 1.0) {
+    return -3.0*eta*x_c + 2.25*eta*x_c*std::fabs(x_c);
+  }
+  else if (1.0 < x_c && x_c < 2.0) {
+    double d = 2.0-x_c;
+    return -0.75*eta*d*d;
+  }
+  else {
+    return 0.0;
+  }
 }
 inline double dB(double x_c, double eta, double f, double fp)
-    // def: x_c = eta*(rij-r_s)
-    // return d/dr_ij f(r_ij)*B(x_c)
+  // def: x_c = eta*(rij-r_s)
+  // return d/dr_ij f(r_ij)*B(x_c)
 {
-    if (-2.0 < x_c && x_c <= -1.0) {
-        double d = 2.0+x_c;
-        return f*(0.75*eta*d*d)+fp*(0.25 * d*d*d);
-    }
-    else if (-1.0 < x_c && x_c <= 1.0) {
-        return f*(-3.0*eta*x_c + 2.25*eta*x_c*std::fabs(x_c))
-            +fp*(1.0 - 1.5*x_c*x_c + 0.75*std::fabs(x_c*x_c*x_c));
-    }
-    else if (1.0 < x_c && x_c < 2.0) {
-        double d = 2.0-x_c;
-        return f*(-0.75*eta*d*d)+fp*(0.25 * d*d*d);
-    }
-    else {
-        return 0.0;
-    }
+  if (-2.0 < x_c && x_c <= -1.0) {
+    double d = 2.0+x_c;
+    return f*(0.75*eta*d*d)+fp*(0.25 * d*d*d);
+  }
+  else if (-1.0 < x_c && x_c <= 1.0) {
+    return f*(-3.0*eta*x_c + 2.25*eta*x_c*std::fabs(x_c))
+      +fp*(1.0 - 1.5*x_c*x_c + 0.75*std::fabs(x_c*x_c*x_c));
+  }
+  else if (1.0 < x_c && x_c < 2.0) {
+    double d = 2.0-x_c;
+    return f*(-0.75*eta*d*d)+fp*(0.25 * d*d*d);
+  }
+  else {
+    return 0.0;
+  }
+}
+inline double dB(double r, double eta, double miu) {
+  return dB(eta*(r-miu), eta);
 }
 #endif
diff --git a/include/tadah/models/linear_regressor.h b/include/tadah/models/linear_regressor.h
index bcb9edd..ce111bb 100644
--- a/include/tadah/models/linear_regressor.h
+++ b/include/tadah/models/linear_regressor.h
@@ -35,9 +35,10 @@ class LinearRegressor {
 
       int verbose = config.get<int>("VERBOSE");
       double lambda = config.get<double>("LAMBDA");
+      double rcond = config.size("LAMBDA")==2 ? config.get<double>("LAMBDA",1) : 1e-8;
 
       if (lambda == 0) {
-        OLS::solve(Phi, T, weights);
+        OLS::solve(Phi, T, weights, rcond);
       } else {
         double alpha = config.get<double>("ALPHA");
         double beta = config.get<double>("BETA");
diff --git a/include/tadah/models/m_krr_train.h b/include/tadah/models/m_krr_train.h
index d4e7903..f5b1c8e 100644
--- a/include/tadah/models/m_krr_train.h
+++ b/include/tadah/models/m_krr_train.h
@@ -155,6 +155,7 @@ public:
         if (M_KRR_Core<Kern>::is_verbose()) std::cout << "Matrix condition number: " << condition_number(K) << std::endl;
 
         if (M_KRR_Core<Kern>::is_verbose()) std::cout << "Solving..." << std::flush;
+        //double rcond = config.template size("LAMBDA")==2 ? config.template get<double>("LAMBDA",1) : 1e-8;
         weights = solve_posv(K, T);
         if (M_KRR_Core<Kern>::is_verbose()) std::cout << "Done" << std::endl;
 
diff --git a/include/tadah/models/ols.h b/include/tadah/models/ols.h
index c507b34..b95c774 100644
--- a/include/tadah/models/ols.h
+++ b/include/tadah/models/ols.h
@@ -26,7 +26,7 @@ class OLS {
      * @param weights Output vector containing the computed weights.
      */
     template <typename M, typename V>
-    static void solve(M &A, V &B, V &weights) {
+    static void solve(M &A, V &B, V &weights, double rcond) {
         // Resize B if necessary to match A's column count.
         if (B.size() < A.cols())
             B.resize(A.cols());
@@ -40,7 +40,6 @@ class OLS {
         double *b = B.ptr();
         int ldb = std::max(m, n);
         double *s = new double[std::min(m, n)]; // Singular values
-        double rcond = 1e-8; // Condition for singularity
         int rank;
         double *work;
         int lwork = -1; // Workspace query
diff --git a/include/tadah/models/ridge_regression.h b/include/tadah/models/ridge_regression.h
index 4eca711..f2d6691 100644
--- a/include/tadah/models/ridge_regression.h
+++ b/include/tadah/models/ridge_regression.h
@@ -46,7 +46,7 @@ class RidgeRegression {
      * @param lambda Regularization parameter.
      */
     template <typename V, typename W>
-    static void solve(const SVD &svd, V b, W &weights, const double lambda) {
+    static void solve(const SVD &svd, V b, W &weights, double lambda) {
       double *U = svd.getU(); // Matrix U from SVD (m x m)
       double *s = svd.getS(); // Singular values (as a vector)
       double *VT = svd.getVT(); // Matrix V^T from SVD (n x n)
diff --git a/tests/test_ols.cpp b/tests/test_ols.cpp
index 8e6e63e..f84c6b0 100644
--- a/tests/test_ols.cpp
+++ b/tests/test_ols.cpp
@@ -14,7 +14,7 @@ TEST_CASE("Testing OLS") {
   aed_type w(3);
   aed_type b2=b1;
 
-  OLS::solve(Phi1,b1,w);
+  OLS::solve(Phi1,b1,w, 1e-8);
   aed_type p= Phi2*w;
   REQUIRE_THAT(p[0], Catch::Matchers::WithinRel(b2[0]));
   REQUIRE_THAT(p[1], Catch::Matchers::WithinRel(b2[1]));
-- 
GitLab