From fed579d1d878c09bbd33638916ce15071e12f28e Mon Sep 17 00:00:00 2001
From: mkirsz <s1351949@sms.ed.ac.uk>
Date: Sun, 16 Feb 2025 23:11:01 +0000
Subject: [PATCH] OLS updates, unfinished integration of RR with OLS

---
 include/tadah/models/ea.h               |   3 +-
 include/tadah/models/linear_regressor.h | 165 +++++++-----
 include/tadah/models/ols.h              | 323 +++++++++++-------------
 include/tadah/models/ridge_regression.h |  64 +++--
 include/tadah/models/svd.h              |  23 +-
 src/ols.cpp                             | 308 +++++++++-------------
 tests/test_ridge_regression.cpp         |   4 +-
 7 files changed, 438 insertions(+), 452 deletions(-)

diff --git a/include/tadah/models/ea.h b/include/tadah/models/ea.h
index d3751b7..3ea1175 100644
--- a/include/tadah/models/ea.h
+++ b/include/tadah/models/ea.h
@@ -51,10 +51,11 @@ class EA {
       double *t = new double[svd.shapeA().first];
       double *x = new double[svd.shapeA().second];
       t_type m_n((size_t)svd.sizeS());
+    double rcond = config.size("OALGO")==2 ? config.get<double>("OALGO",1) : 1e-8;
       while (test1>EPS2 || test2>EPS2) {
 
         // regularised least square problem (Tikhonov Regularization):
-        RidgeRegression::solve(svd,T,m_n,lambda);
+        RidgeRegression::solve(svd,T,m_n,lambda,rcond);
 
         double gamma = 0.0;
         for (size_t j=0; j<svd.sizeS(); j++) {
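Note on the convention used above (and again in `linear_regressor.h`): the `OALGO` key may carry an optional second value, which is read as the reciprocal condition threshold `rcond`; otherwise 1e-8 is used. A minimal sketch of that idiom, assuming `Config::size(key)` returns the number of values stored under a key and `Config::get<T>(key, i)` is 0-indexed, as the calls above suggest:

```cpp
// Hypothetical illustration of the "OALGO <algo> [rcond]" convention.
int    oalgo = config.get<int>("OALGO", 0);    // first value: algorithm id
double rcond = config.size("OALGO") == 2
             ? config.get<double>("OALGO", 1)  // optional second value
             : 1e-8;                           // default threshold
```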
diff --git a/include/tadah/models/linear_regressor.h b/include/tadah/models/linear_regressor.h
index 3a56af1..0485982 100644
--- a/include/tadah/models/linear_regressor.h
+++ b/include/tadah/models/linear_regressor.h
@@ -1,6 +1,8 @@
 #ifndef LINEAR_REGRESSOR_H
 #define LINEAR_REGRESSOR_H
 
+#include <cstddef>
+#include <limits>
 #include <tadah/core/config.h>
 #include <tadah/core/core_types.h>
 #include <tadah/models/ea.h>
@@ -29,42 +31,73 @@ class LinearRegressor {
    * @param weights Vector to store the resultant weights.
    * @param Sigma Matrix to store the covariance matrix.
    */
-  public:
-    static void train(Config &config, phi_type &Phi, t_type &T,
-                      t_type &weights, Matrix &Sigma) {
-
-      int verbose = config.get<int>("VERBOSE");
-      double lambda = config.get<double>("LAMBDA");
-      double rcond = config.size("LAMBDA")==2 ? config.get<double>("LAMBDA",1) : 1e-8;
+public:
+  static void train(Config &config, phi_type &Phi, t_type &T,
+                    t_type &weights, Matrix &Sigma) {
+
+    int verbose = config.get<int>("VERBOSE");
+    double lambda = config.get<double>("LAMBDA",0);
+    int oalgo = config.get<int>("OALGO",0);
+    double rcond = config.size("OALGO")==2 ? config.get<double>("OALGO",1) : 1e-8;
     weights.resize(Phi.cols());
 
-      if (lambda == 0) {
-        OLS::solve(Phi, T, weights, rcond, OLS::Algorithm::GELSD);
-      } else {
+    // std::abs(lambda+1) < DBL_MIN is an exact test for LAMBDA == -1, which
+    // requests hyperparameter tuning via the evidence approximation (OALGO 3 only).
+    if (std::abs(lambda+1) < std::numeric_limits<double>::min() && oalgo != 3) {
+      throw std::runtime_error("LAMBDA -1 requires OALGO 3");
+    }
+
+    if (oalgo==1) {
+      // Tikhonov regularisation via augmentation: the last cols() rows of Phi
+      // are assumed to be zero-padded; writing sqrt(lambda) on their diagonal
+      // makes the OLS solution of the augmented system the ridge solution
+      // (see the worked example after this file's diff).
+      if (lambda>0) {
+        size_t j=0;
+        for (size_t i=Phi.rows()-Phi.cols(); i<Phi.rows();++i) {
+          Phi(i,j++) = std::sqrt(lambda);
+        }
+      }
+      OLS::solve(Phi, T, weights, rcond, OLS::Algorithm::GELSD);
+    }
+    else if (oalgo==2) {
+      // Same augmentation trick, solved with the faster QR-based DGELS.
+      if (lambda>0) {
+        size_t j=0;
+        for (size_t i=Phi.rows()-Phi.cols(); i<Phi.rows();++i) {
+          Phi(i,j++) = std::sqrt(lambda);
+        }
+      }
+      OLS::solve(Phi, T, weights, rcond, OLS::Algorithm::GELS);
+    }
+    else if (oalgo==3) {
+      SVD svd(Phi);
+      if (lambda < 0) {
         double alpha = config.get<double>("ALPHA");
         double beta = config.get<double>("BETA");
+        EA ea(config, svd, T);
+        ea.run(alpha, beta);
+        lambda = alpha / beta;
+        config.remove("ALPHA");
+        config.remove("BETA");
+        config.add("ALPHA", alpha);
+        config.add("BETA", beta);
+      }
 
-        SVD svd(Phi);
-
-        if (lambda < 0) {
-          EA ea(config, svd, T);
-          ea.run(alpha, beta);
-          lambda = alpha / beta;
-          config.remove("ALPHA");
-          config.remove("BETA");
-          config.add("ALPHA", alpha);
-          config.add("BETA", beta);
-        }
+      // TODO(unfinished RR/OLS integration): fill an OLS::Workspace with
+      // svd.getU()/getS()/getVT() and call OLS::solve_with_workspace with
+      // Algorithm::SVD instead of RidgeRegression::solve below.
 
-        Sigma = get_sigma(svd, lambda);
-        config_add_sigma(config, Sigma);
+      Sigma = get_sigma(svd, lambda);
+      config_add_sigma(config, Sigma);
 
-        if (verbose) std::cout << std::endl << "REG LAMBDA: " << lambda << std::endl;
-        RidgeRegression::solve(svd, T, weights, lambda);
-      }
+      if (verbose) std::cout << std::endl << "REG LAMBDA: " << lambda << std::endl;
+      RidgeRegression::solve(svd, T, weights, lambda, rcond);
     }
+    else {
+      throw std::runtime_error("Unsupported OALGO: " + std::to_string(oalgo));
+    }
+  }
 
-    /** 
+  /** 
      * @brief Compute and return the covariance matrix.
      * 
      * Uses the SVD to calculate the covariance matrix with regularization.
@@ -73,27 +106,27 @@ class LinearRegressor {
      * @param lambda Regularization parameter.
      * @return Covariance matrix.
      */
-    static Matrix get_sigma(SVD &svd, double lambda) {
-      double *VT = svd.getVT();
-      double *S = svd.getS();
-      int n = static_cast<int>(svd.shapeVT().first);
-
-      Matrix Sigma((size_t)n, (size_t)n);
-      Sigma.set_zero();
-      double *sigma = &Sigma.data()[0];
-
-      for (int i = 0; i < n; ++i) {
-        double ridge = 1.0 / (S[i] * S[i] + lambda);
-        for (int j = 0; j < n; ++j) {
-          for (int k = 0; k < n; ++k) {
-            sigma[j + k * n] += VT[j + i * n] * ridge * VT[k + i * n]; // Column-major
-          }
+  static Matrix get_sigma(SVD &svd, double lambda) {
+    double *VT = svd.getVT();
+    double *S = svd.getS();
+    int n = static_cast<int>(svd.shapeVT().first);
+
+    Matrix Sigma((size_t)n, (size_t)n);
+    Sigma.set_zero();
+    double *sigma = &Sigma.data()[0];
+
+    for (int i = 0; i < n; ++i) {
+      double ridge = 1.0 / (S[i] * S[i] + lambda);
+      for (int j = 0; j < n; ++j) {
+        for (int k = 0; k < n; ++k) {
+          sigma[j + k * n] += VT[j + i * n] * ridge * VT[k + i * n]; // Column-major
         }
       }
-      return Sigma;
     }
+    return Sigma;
+  }
 
-    /**
+  /**
      * @brief Add the covariance matrix to the configuration.
      * 
      * Converts the Sigma matrix to a vector and stores it in the configuration.
@@ -101,13 +134,13 @@ class LinearRegressor {
      * @param config Configuration object.
      * @param Sigma Covariance matrix to be added.
      */
-    static void config_add_sigma(Config &config, Matrix &Sigma) {
-      t_type Sigma_as_vector(Sigma.data(), Sigma.cols() * Sigma.rows());
-      config.add("SIGMA", Sigma.rows());
-      config.add<t_type>("SIGMA", Sigma_as_vector);
-    }
+  static void config_add_sigma(Config &config, Matrix &Sigma) {
+    t_type Sigma_as_vector(Sigma.data(), Sigma.cols() * Sigma.rows());
+    config.add("SIGMA", Sigma.rows());
+    config.add<t_type>("SIGMA", Sigma_as_vector);
+  }
 
-    /**
+  /**
      * @brief Read and reconstruct the covariance matrix from the configuration.
      * 
      * Retrieves the Sigma matrix stored in the configuration.
@@ -116,22 +149,22 @@ class LinearRegressor {
      * @param Sigma Matrix to store the reconstructed Sigma.
      * @throws std::runtime_error if Sigma is not computed.
      */
-    static void read_sigma(Config &config, Matrix &Sigma) {
-      using V = std::vector<double>;
-      size_t N;
-      try {
-        N = config.get<size_t>("SIGMA");
-      } catch (const std::runtime_error& error) {
-        throw std::runtime_error("\nSigma matrix is not computed.\nHint: It is only computed for LAMBDA != 0.\n");
-      }
-
-      V Sigma_as_vector(N * N + 1);
-      config.get<V>("SIGMA", Sigma_as_vector);
-      Sigma_as_vector.erase(Sigma_as_vector.begin());
-      Sigma.resize(N, N);
-      for (size_t c = 0; c < N; ++c)
-        for (size_t r = 0; r < N; ++r)
-          Sigma(r, c) = Sigma_as_vector.at(N * c + r);
+  static void read_sigma(Config &config, Matrix &Sigma) {
+    using V = std::vector<double>;
+    size_t N;
+    try {
+      N = config.get<size_t>("SIGMA");
+    } catch (const std::runtime_error& error) {
+      throw std::runtime_error("\nSigma matrix is not computed.\nHint: It is only computed for LAMBDA != 0.\n");
     }
+
+    V Sigma_as_vector(N * N + 1);
+    config.get<V>("SIGMA", Sigma_as_vector);
+    Sigma_as_vector.erase(Sigma_as_vector.begin());
+    Sigma.resize(N, N);
+    for (size_t c = 0; c < N; ++c)
+      for (size_t r = 0; r < N; ++r)
+        Sigma(r, c) = Sigma_as_vector.at(N * c + r);
+  }
 };
 #endif
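The `oalgo == 1/2` branches above rely on the standard augmented-matrix identity: because ||Phi w - T||^2 + lambda ||w||^2 = || [Phi; sqrt(lambda) I] w - [T; 0] ||^2, plain OLS on the augmented system returns the ridge solution. A self-contained numerical sketch of that equivalence (plain arrays and normal equations, not the project's `Matrix`/`OLS` types):

```cpp
#include <cmath>
#include <cstdio>

// Check that OLS on [A; sqrt(lambda)*I] against [b; 0] equals the ridge
// solution (A^T A + lambda I)^{-1} A^T b, for a tiny 3x2 problem.
int main() {
  const double A[3][2] = {{1.0, 2.0}, {2.0, 1.0}, {1.0, 1.0}};
  const double b[3]    = {3.0, 3.0, 2.0};
  const double lambda  = 0.5;

  // Ridge via normal equations: M = A^T A + lambda I, v = A^T b.
  double M[2][2] = {{lambda, 0.0}, {0.0, lambda}};
  double v[2]    = {0.0, 0.0};
  for (int i = 0; i < 3; ++i)
    for (int j = 0; j < 2; ++j) {
      v[j] += A[i][j] * b[i];
      for (int k = 0; k < 2; ++k) M[j][k] += A[i][j] * A[i][k];
    }
  const double det = M[0][0] * M[1][1] - M[0][1] * M[1][0];
  const double w_ridge[2] = {( M[1][1] * v[0] - M[0][1] * v[1]) / det,
                             (-M[1][0] * v[0] + M[0][0] * v[1]) / det};

  // Augmented system: 5 rows = 3 data rows + 2 penalty rows.
  double Aaug[5][2] = {{1, 2}, {2, 1}, {1, 1}, {0, 0}, {0, 0}};
  double baug[5]    = {3, 3, 2, 0, 0};
  Aaug[3][0] = std::sqrt(lambda);  // sqrt(lambda) on the diagonal of the
  Aaug[4][1] = std::sqrt(lambda);  // appended rows, as in train() above

  // OLS on the augmented system via its own normal equations.
  double Ma[2][2] = {{0, 0}, {0, 0}};
  double va[2]    = {0, 0};
  for (int i = 0; i < 5; ++i)
    for (int j = 0; j < 2; ++j) {
      va[j] += Aaug[i][j] * baug[i];
      for (int k = 0; k < 2; ++k) Ma[j][k] += Aaug[i][j] * Aaug[i][k];
    }
  const double deta = Ma[0][0] * Ma[1][1] - Ma[0][1] * Ma[1][0];
  const double w_aug[2] = {( Ma[1][1] * va[0] - Ma[0][1] * va[1]) / deta,
                           (-Ma[1][0] * va[0] + Ma[0][0] * va[1]) / deta};

  std::printf("ridge:     w = (%.6f, %.6f)\n", w_ridge[0], w_ridge[1]);
  std::printf("augmented: w = (%.6f, %.6f)\n", w_aug[0], w_aug[1]);
  return 0;
}
```

Both prints show the same weights: the penalty rows contribute exactly lambda to the diagonal of A^T A and nothing to A^T b.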
diff --git a/include/tadah/models/ols.h b/include/tadah/models/ols.h
index ca94d21..fce3656 100644
--- a/include/tadah/models/ols.h
+++ b/include/tadah/models/ols.h
@@ -6,238 +6,211 @@
 #include <cmath>
 #include <stdexcept>
 #include <cstdlib> // For malloc and free
+#include <algorithm> // For std::max
 
 /**
  * @class OLS
  * @brief Provides functionality for solving Ordinary Least Squares (OLS) problems.
  *
- * Utilizes LAPACK's DGELS and DGELSD functions to solve linear systems where the goal is to minimize 
- * the Euclidean norm of residuals. By default, it uses DGELSD, but users can select DGELS if desired.
+ * Utilizes various algorithms to solve linear systems where the goal is to minimize 
+ * the Euclidean norm of residuals. Supports the following algorithms:
  *
- * **Algorithm Selection:**
- * - **GELSD (`DGELSD`):** Uses a singular value decomposition (SVD) approach, which is more robust and
- *   can handle rank-deficient or ill-conditioned systems. It can provide a minimum-norm solution and
- *   handle cases where the matrix `A` does not have full rank. `DGELSD` is recommended when precision
- *   is important, or when the system may be rank-deficient or ill-conditioned.
- * - **GELS (`DGELS`):** Uses QR or LQ factorization, which is generally faster but less robust for
- *   ill-conditioned problems. It assumes that `A` has full rank. `DGELS` is suitable for well-conditioned
- *   systems where performance is critical and the matrix `A` is expected to have full rank.
- *
- * **Usage Recommendation:**
- * - Use **GELSD** (`Algorithm::GELSD`) when dealing with potentially rank-deficient or ill-conditioned systems,
- *   or when you require the most accurate solution at the expense of computational time.
- * - Use **GELS** (`Algorithm::GELS`) when working with well-conditioned systems where speed is more important
- *   than handling rank deficiency.
+ * **Algorithms:**
+ * - **GELSD (`DGELSD`):** Uses a singular value decomposition (SVD) approach, handling rank-deficient and ill-conditioned systems.
+ * - **GELS (`DGELS`):** Uses QR or LQ factorization, fast but less robust for ill-conditioned systems.
+ * - **SVD:** Custom SVD-based path with explicit control over singular values and regularization (workspace allocation only; the solve path is not yet implemented).
+ * - **Cholesky:** Solves the normal equations via Cholesky decomposition, efficient but less stable for ill-conditioned systems (not yet implemented).
  *
  * **Workspace Management:**
- * - The class provides methods to solve the OLS problem with and without providing a preallocated workspace.
- *   Preallocating the workspace can improve performance when solving multiple problems of the same size.
- * - If the workspace is insufficient, the `solve_with_workspace` method will automatically reallocate it to
- *   the required size.
+ * - The class provides a `Workspace` struct to manage memory allocations for the computations.
+ * - Workspaces are allocated based on the dimensions of the problem and the algorithm selected.
  *
  * **Example:**
  * ```cpp
  * Matrix A; // Initialize with your data
  * Vector B; // Initialize with your data
  * Vector weights(A.cols());
+ * OLS::Workspace ws;
  *
  * // Solve using default algorithm (GELSD)
  * OLS::solve(A, B, weights);
  *
- * // Solve using GELS algorithm
- * OLS::solve(A, B, weights, -1.0, OLS::Algorithm::GELS);
+ * // Solve using the GELS algorithm with a reusable workspace
+ * OLS::solve_with_workspace(A, B, weights, ws, -1.0, OLS::Algorithm::GELS);
  * ```
  */
 class OLS {
 public:
-    /**
+  /**
      * @brief Enum to select the algorithm used for solving the OLS problem.
      */
-    enum class Algorithm {
-        GELS,   // Uses DGELS
-        GELSD   // Uses DGELSD
-    };
-
-    /**
+  enum class Algorithm {
+    GELS,     // Uses DGELS
+    GELSD,    // Uses DGELSD
+    SVD,      // Uses custom SVD implementation
+    Cholesky  // Uses Cholesky decomposition
+  };
+
+  /**
      * @brief Workspace struct to hold pre-allocated memory.
      *
-     * This struct contains the necessary buffers for the LAPACK functions.
+     * This struct contains the necessary buffers for the computations in different algorithms.
      * Users can create an instance of `Workspace` and pass it to the `solve_with_workspace` method.
      */
-    struct Workspace {
-        double* s;    // Singular values array (used by DGELSD)
-        int* iwork;   // Integer workspace array (used by DGELSD)
-        double* work; // Double workspace array
-        int lwork;    // Size of the work array
-        int m;        // Number of rows in A (for which workspace was allocated)
-        int n;        // Number of columns in A (for which workspace was allocated)
-        Algorithm algo; // Algorithm for which workspace was allocated
-
-        Workspace();
-        ~Workspace();
-
-        /**
-         * @brief Allocates memory for the workspace based on problem dimensions and algorithm.
-         * 
-         * @param m_ Number of rows in matrix A.
-         * @param n_ Number of columns in matrix A.
-         * @param algorithm The algorithm for which to allocate workspace.
-         */
-        void allocate(int m_, int n_, Algorithm algorithm = Algorithm::GELSD);
-
-        /**
-         * @brief Frees the allocated memory.
-         */
-        void release();
-
-        /**
-         * @brief Checks if the workspace is sufficient for the given dimensions and algorithm.
-         * 
-         * @param m_ Number of rows required.
-         * @param n_ Number of columns required.
-         * @param algorithm The algorithm required.
-         * @return True if workspace is sufficient, false otherwise.
-         */
-        bool is_sufficient(int m_, int n_, Algorithm algorithm = Algorithm::GELSD) const;
-    };
-
-    /**
-     * @brief Solves the OLS problem for the given matrix and vector without requiring external workspace.
-     *
-     * @tparam M Type for the matrix A.
-     * @tparam V Type for the vectors B and weights.
-     * @param A Input matrix (may be modified during computation).
-     * @param B Input target vector; contains the solution upon return.
-     * @param weights Output vector containing the computed weights.
-     * @param rcond The reciprocal of the condition number threshold (used for DGELSD). Default is -1.0.
-     * @param algorithm Algorithm to use (GELSD by default).
-     */
-    template <typename M, typename V>
-    static void solve(M& A, V& B, V& weights, double rcond, Algorithm algorithm);
-
-    /**
-     * @brief Solves the OLS problem for the given matrix and vector using provided workspace.
-     *
-     * @tparam M Type for the matrix A.
-     * @tparam V Type for the vectors B and weights.
-     * @param A Input matrix (may be modified during computation).
-     * @param B Input target vector; contains the solution upon return.
-     * @param weights Output vector containing the computed weights.
-     * @param ws Reference to the Workspace instance.
-     * @param rcond The reciprocal of the condition number threshold (used for DGELSD). Default is -1.0.
-     * @param algorithm Algorithm to use (GELSD by default).
-     */
-    template <typename M, typename V>
-    static void solve_with_workspace(M& A, V& B, V& weights, Workspace& ws, double rcond, Algorithm algorithm);
-
-    /**
-     * @brief Solves the OLS problem using default settings (GELSD algorithm and default rcond).
-     *
-     * @tparam M Type for the matrix A.
-     * @tparam V Type for the vectors B and weights.
-     * @param A Input matrix (may be modified during computation).
-     * @param B Input target vector; contains the solution upon return.
-     * @param weights Output vector containing the computed weights.
-     */
-    template <typename M, typename V>
-    static void solve(M& A, V& B, V& weights);
-
-    /**
-     * @brief Solves the OLS problem using default settings and provided workspace.
-     *
-     * @tparam M Type for the matrix A.
-     * @tparam V Type for the vectors B and weights.
-     * @param A Input matrix (may be modified during computation).
-     * @param B Input target vector; contains the solution upon return.
-     * @param weights Output vector containing the computed weights.
-     * @param ws Reference to the Workspace instance.
-     */
-    template <typename M, typename V>
-    static void solve_with_workspace(M& A, V& B, V& weights, Workspace& ws);
+  struct Workspace {
+    // Common workspace variables
+    double* work; // Double workspace array
+    int lwork;    // Size of the work array
+    int m;        // Number of rows in A (for which workspace was allocated)
+    int n;        // Number of columns in A (for which workspace was allocated)
+    Algorithm algo; // Algorithm for which workspace was allocated
+
+    // For DGELSD (GELSD algorithm)
+    double* s;    // Singular values array
+    int* iwork;   // Integer workspace array
+
+    // For SVD algorithm
+    double* U;    // U matrix from SVD (m x n)
+    double* S;    // Singular values array (size min(m,n))
+    double* VT;   // V^T matrix from SVD (n x n)
+
+    // For Cholesky algorithm
+    double* AtA;  // A^T * A matrix (n x n)
+    double* Atb;  // A^T * b vector (n)
+
+    // Constructors and methods (declarations only)
+    Workspace();
+    ~Workspace();
+
+    void allocate(int m_, int n_, Algorithm algorithm = Algorithm::GELSD);
+    void release();
+    bool is_sufficient(int m_, int n_, Algorithm algorithm = Algorithm::GELSD) const;
+  };
+
+  // Public methods
+  template <typename M, typename V>
+  static void solve(M& A, V& B, V& weights, double rcond, Algorithm algorithm);
+
+  template <typename M, typename V>
+  static void solve_with_workspace(M& A, V& B, V& weights, Workspace& ws, double rcond, Algorithm algorithm);
+
+  template <typename M, typename V>
+  static void solve(M& A, V& B, V& weights);
+
+  template <typename M, typename V>
+  static void solve_with_workspace(M& A, V& B, V& weights, Workspace& ws);
 };
 
-// Implement templated methods here
+// Implement template methods here
+// Since they are templates, definitions must be in the header
 
 template <typename M, typename V>
-void OLS::solve(M& A, V& B, V& weights, double rcond, Algorithm algorithm)
-{
-    Workspace ws;
-    solve_with_workspace(A, B, weights, ws, rcond, algorithm);
+void OLS::solve(M& A, V& B, V& weights, double rcond, Algorithm algorithm) {
+  Workspace ws;
+  solve_with_workspace(A, B, weights, ws, rcond, algorithm);
 }
 
 template <typename M, typename V>
-void OLS::solve_with_workspace(M& A, V& B, V& weights, Workspace& ws, double rcond, Algorithm algorithm)
-{
-    int m = A.rows();
-    int n = A.cols();
-
-    // Check for valid algorithm
-    if (algorithm != Algorithm::GELSD && algorithm != Algorithm::GELS) {
-        throw std::invalid_argument("Invalid algorithm specified.");
-    }
+void OLS::solve_with_workspace(M& A, V& B, V& weights, Workspace& ws, double rcond, Algorithm algorithm) {
+  int m = A.rows();
+  int n = A.cols();
+
+  // Check for valid algorithm
+  if (algorithm != Algorithm::GELSD && algorithm != Algorithm::GELS &&
+    algorithm != Algorithm::SVD && algorithm != Algorithm::Cholesky) {
+    throw std::invalid_argument("Invalid algorithm specified.");
+  }
+
+  // Check that dimensions of weights match expected size
+  if (weights.size() != static_cast<std::size_t>(n)) {
+    throw std::invalid_argument("weights size is incorrect.");
+  }
+
+  // Check that B has sufficient size
+  int maxmn = std::max(m, n);
+  if (B.size() < static_cast<std::size_t>(maxmn)) {
+    throw std::invalid_argument("B size is insufficient.");
+  }
+
+  // Check if workspace is sufficient; if not, reallocate
+  if (!ws.is_sufficient(m, n, algorithm)) {
+    ws.allocate(m, n, algorithm);
+  }
+
+  int info;
+  if (algorithm == Algorithm::GELSD) {
+    int nrhs = 1;
+    double* a = A.ptr();
+    int lda = m;
+    double* b = B.ptr();
+    int ldb = maxmn;
+    int rank;
 
-    // Check that dimensions of weights match expected size
-    if (weights.size() != static_cast<std::size_t>(n)) {
-        throw std::invalid_argument("weights size is incorrect.");
-    }
+    // Solve using DGELSD
+    ::dgelsd_(&m, &n, &nrhs, a, &lda, b, &ldb, ws.s, &rcond, &rank,
+              ws.work, &ws.lwork, ws.iwork, &info);
 
-    // Check that B has sufficient size
-    int maxmn = std::max(m, n);
-    if (B.size() < static_cast<std::size_t>(maxmn)) {
-        throw std::invalid_argument("B size is insufficient.");
+    if (info != 0) {
+      throw std::runtime_error("Error in DGELSD: info = " + std::to_string(info));
     }
 
-    // Check if workspace is sufficient; if not, reallocate
-    if (!ws.is_sufficient(m, n, algorithm)) {
-        ws.allocate(m, n, algorithm);
+    // Copy solution to weights
+    for (int i = 0; i < n; ++i) {
+      weights[i] = b[i];
     }
-
+  } else if (algorithm == Algorithm::GELS) {
     int nrhs = 1;
     double* a = A.ptr();
     int lda = m;
     double* b = B.ptr();
-    int ldb = maxmn;
-    int info;
-
-    if (algorithm == Algorithm::GELSD) {
-        int rank;
-
-        // Solve using DGELSD
-        ::dgelsd_(&m, &n, &nrhs, a, &lda, b, &ldb, ws.s, &rcond, &rank,
-                ws.work, &ws.lwork, ws.iwork, &info);
-    } else { // Algorithm::GELS
-        // Solve using DGELS
-        char trans = 'N';
-        ::dgels_(&trans, &m, &n, &nrhs, a, &lda, b, &ldb,
-               ws.work, &ws.lwork, &info);
-    }
+    int ldb = maxmn; // DGELS requires LDB >= max(1,m,n); B's size was checked above
+    char trans = 'N';
+
+    // Solve using DGELS
+    ::dgels_(&trans, &m, &n, &nrhs, a, &lda, b, &ldb,
+             ws.work, &ws.lwork, &info);
 
     if (info != 0) {
-        throw std::runtime_error("Error in LAPACK function: info = " + std::to_string(info));
+      throw std::runtime_error("Error in DGELS: info = " + std::to_string(info));
     }
 
     // Copy solution to weights
     for (int i = 0; i < n; ++i) {
-        weights[i] = b[i];
+      weights[i] = b[i];
     }
+  } else if (algorithm == Algorithm::SVD) {
+    // Custom SVD-based solution: the workspace (ws.U, ws.S, ws.VT) is
+    // allocated, but the solve path is not wired up yet.
+    // TODO: apply the truncated pseudoinverse V * diag(1/s_i) * U^T to B,
+    // zeroing singular values below rcond * s_max.
+    throw std::runtime_error("SVD algorithm is not implemented in this version.");
+  } else if (algorithm == Algorithm::Cholesky) {
+    // Cholesky-based solution via the normal equations (A^T A) w = A^T b.
+    throw std::runtime_error("Cholesky algorithm is not implemented in this version.");
+  }
 }
 
 template <typename M, typename V>
-void OLS::solve(M& A, V& B, V& weights)
-{
-    // Default rcond and algorithm
-    double rcond = -1.0;
-    Algorithm algorithm = Algorithm::GELSD;
-    solve(A, B, weights, rcond, algorithm);
+void OLS::solve(M& A, V& B, V& weights) {
+  // Default rcond and algorithm
+  double rcond = -1.0;
+  Algorithm algorithm = Algorithm::GELSD;
+  solve(A, B, weights, rcond, algorithm);
 }
 
 template <typename M, typename V>
-void OLS::solve_with_workspace(M& A, V& B, V& weights, Workspace& ws)
-{
-    // Default rcond and algorithm
-    double rcond = -1.0;
-    Algorithm algorithm = Algorithm::GELSD;
-    solve_with_workspace(A, B, weights, ws, rcond, algorithm);
+void OLS::solve_with_workspace(M& A, V& B, V& weights, Workspace& ws) {
+  // Default rcond and algorithm
+  double rcond = -1.0;
+  Algorithm algorithm = Algorithm::GELSD;
+  solve_with_workspace(A, B, weights, ws, rcond, algorithm);
 }
 
 #endif // OLS_H
+
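Since `solve_with_workspace` reallocates only when `is_sufficient` fails, callers solving many same-shaped systems can amortise the workspace query and allocations across calls. A hedged usage sketch, assuming the project's `Matrix`/`t_type` containers expose `rows()`, `cols()`, `ptr()` and `size()` as used above; `n_batches`, `next_design_matrix` and `next_targets` are hypothetical stand-ins for the caller's data source:

```cpp
// Sketch: reuse one Workspace across repeated same-shaped solves.
OLS::Workspace ws;                        // allocated lazily on first use
for (int k = 0; k < n_batches; ++k) {
  Matrix A_k = next_design_matrix(k);     // hypothetical data source
  t_type B_k = next_targets(k);           // hypothetical data source
  t_type weights(A_k.cols());
  // First call allocates for (m, n, GELSD); later calls reuse the buffers.
  OLS::solve_with_workspace(A_k, B_k, weights, ws, 1e-8,
                            OLS::Algorithm::GELSD);
  // ... consume weights ...
}
```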
diff --git a/include/tadah/models/ridge_regression.h b/include/tadah/models/ridge_regression.h
index f2d6691..7f208e3 100644
--- a/include/tadah/models/ridge_regression.h
+++ b/include/tadah/models/ridge_regression.h
@@ -5,6 +5,7 @@
 #include <tadah/core/maths.h>
 #include <tadah/core/lapack.h>
 #include <tadah/models/svd.h>
+#include <algorithm> // For std::max
 
 /**
  * @class RidgeRegression
@@ -16,25 +17,45 @@
 class RidgeRegression {
   public:
     /**
-     * @brief Inverts and multiplies sigma values by lambda for regularization.
+     * @brief Inverts and multiplies sigma values by lambda for regularization,
+     *        considering the rcond threshold.
      *
      * Performs an element-wise operation on the sigma values to handle regularization.
+     * Singular values below the threshold are discarded (set to zero).
      *
      * @tparam V Vector type for sigma and result.
      * @param sigma Input vector of sigma values.
      * @param result Output vector for inverted and multiplied results.
      * @param n Size of the vectors.
      * @param lambda Regularization parameter.
+     * @param rcond Reciprocal of the condition number; threshold for singular values.
      */
     template <typename V>
-    static void invertMultiplySigmaLambda(const V& sigma, V& result, int n, double lambda) {
+    static void invertMultiplySigmaLambdaRcond(const V& sigma, V& result, int n, double lambda, double rcond) {
+      // Find the largest singular value
+      double smax = 0.0;
       for (int i = 0; i < n; ++i) {
-        result[i] = sigma[i] != 0.0 ? sigma[i] / (sigma[i] * sigma[i] + lambda) : 0.0;
+        smax = std::max(smax, static_cast<double>(sigma[i]));
+      }
+
+      // Threshold for singular values
+      double thresh = rcond > 0.0 ? rcond * smax : 0.0;
+
+      for (int i = 0; i < n; ++i) {
+        if (sigma[i] > thresh) {
+          result[i] = sigma[i] / (sigma[i] * sigma[i] + lambda);
+        } else {
+          // Singular value is too small; set result to zero
+          result[i] = 0.0;
+        }
       }
     }
 
     /**
-     * @brief Solves the ridge regression problem using SVD components.
+     * @brief Solves the ridge regression problem using SVD components,
+     *        incorporating rcond for singular value thresholding.
      *
      * Computes the weights that minimize the regularized least squares error.
      *
@@ -44,38 +65,41 @@ class RidgeRegression {
      * @param b Vector of target values.
      * @param weights Output vector for computed weights.
      * @param lambda Regularization parameter.
+     * @param rcond Reciprocal of the condition number; threshold for singular values.
      */
     template <typename V, typename W>
-    static void solve(const SVD &svd, V b, W &weights, double lambda) {
-      double *U = svd.getU(); // Matrix U from SVD (m x m)
-      double *s = svd.getS(); // Singular values (as a vector)
-      double *VT = svd.getVT(); // Matrix V^T from SVD (n x n)
-
-      // Dynamic memory allocation for intermediate calculations
-      double *d = new double[svd.sizeS()];
+    static void solve(const SVD &svd, V b, W &weights, double lambda, double rcond) {
+      double *U = svd.getU();    // Matrix U from SVD (m x n)
+      double *s = svd.getS();    // Singular values (as a vector)
+      double *VT = svd.getVT();  // Matrix V^T from SVD (n x n)
 
       int m = svd.shapeA().first;
       int n = svd.shapeA().second;
 
-      double alpha_ = 1.0;  // Scalar used in calculations
-      double beta_ = 0.0;   // Scalar used in calculations
+      // Allocate memory for intermediate calculations
+      double *d = new double[n];   // For inverted singular values
+      double *UTb = new double[n]; // For U^T * b
 
       // Step 1: Compute U^T * b
       char trans = 'T';
-      double *UTb = new double[n];
+      double alpha_ = 1.0;  // Scalar for multiplication
+      double beta_ = 0.0;   // Scalar for addition
       int incx = 1;
-      dgemv_(&trans, &m, &n, &alpha_, U, &m, b.ptr(), &incx, &beta_, UTb, &incx);
+      int incy = 1;
+      dgemv_(&trans, &m, &n, &alpha_, U, &m, b.ptr(), &incx, &beta_, UTb, &incy);
+
+      // Step 2: Invert and multiply sigma values, applying rcond threshold
+      invertMultiplySigmaLambdaRcond(s, d, n, lambda, rcond);
 
-      // Step 2: Element-wise multiply and invert sigma in result
-      invertMultiplySigmaLambda(s, d, n, lambda);
+      // Element-wise multiplication: d[i] = d[i] * (U^T * b)[i]
       for (int i = 0; i < n; ++i) {
         d[i] *= UTb[i];
       }
 
-      // Step 3: Compute V * (D * (U^T * b))
-      trans = 'T';
+      // Step 3: Compute weights = V * d
+      trans = 'T'; // Since VT is V^T, transposing VT gives V
       weights.resize(n);
-      dgemv_(&trans, &n, &n, &alpha_, VT, &n, d, &incx, &beta_, weights.ptr(), &incx);
+      dgemv_(&trans, &n, &n, &alpha_, VT, &n, d, &incx, &beta_, weights.ptr(), &incy);
 
       // Cleanup dynamic memory
       delete[] d;
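In formulas, the solver above computes, for the thin SVD Phi = U S V^T with singular values s_i and s_max = max_i s_i:

```latex
% Ridge solution with rcond-based truncation, as implemented by
% RidgeRegression::solve above.
\mathbf{w}
  \;=\; \sum_{i \,:\, s_i > \mathrm{rcond}\, s_{\max}}
        \frac{s_i}{s_i^{2} + \lambda}\,
        \left( \mathbf{u}_i^{\top} \mathbf{b} \right) \mathbf{v}_i
```

Directions with s_i <= rcond * s_max are dropped outright rather than damped; for lambda = 0 this reduces to applying the rcond-truncated pseudoinverse to b.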
diff --git a/include/tadah/models/svd.h b/include/tadah/models/svd.h
index bca5f9a..707c20d 100644
--- a/include/tadah/models/svd.h
+++ b/include/tadah/models/svd.h
@@ -107,12 +107,23 @@ class SVD {
       dgesdd_(&jobz, &m, &n, a, &lda, s, u, &ldu, vt, &ldvt,
           work, &lwork, iwork, &info);
 
-      // Filter out very small singular values
-      for (size_t i = 0; i < sizeS(); ++i) {
-        if (s[i] < std::numeric_limits<double>::min()) {
-          s[i] = 0.0;
-        }
-      }
+      // Small-singular-value filtering has been removed here: the solvers now
+      // apply an rcond-based threshold themselves (see
+      // RidgeRegression::invertMultiplySigmaLambdaRcond). The draft below is
+      // kept until rcond is plumbed into this class:
+      //
+      //   double smax = 0.0;
+      //   for (size_t i = 0; i < sizeS(); ++i)
+      //     if (s[i] > smax) smax = s[i];
+      //
+      //   double thresh = rcond > 0.0 ? rcond * smax : 0.0;
+      //   for (size_t i = 0; i < sizeS(); ++i)
+      //     if (s[i] < thresh) s[i] = 0.0;
     }
 
     /**
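If in-place filtering is ever reinstated here, `rcond` first has to be plumbed into `SVD`; until then the thresholding lives in the solvers. A minimal free-function sketch of the intended filter (assumes `s` holds `k` non-negative singular values):

```cpp
#include <algorithm>
#include <cstddef>

// Zero out singular values below rcond * s_max, mirroring the commented-out
// draft above. DGESDD returns the values sorted descending, so s[0] is the
// maximum, but we do not rely on that here.
inline void filter_singular_values(double* s, std::size_t k, double rcond) {
  if (k == 0) return;
  const double smax = *std::max_element(s, s + k);
  const double thresh = rcond > 0.0 ? rcond * smax : 0.0;
  for (std::size_t i = 0; i < k; ++i)
    if (s[i] < thresh) s[i] = 0.0;
}
```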
diff --git a/src/ols.cpp b/src/ols.cpp
index a465a5b..22836f1 100644
--- a/src/ols.cpp
+++ b/src/ols.cpp
@@ -2,210 +2,154 @@
 #include <cmath>
 #include <stdexcept>
 #include <cstdlib> // For malloc and free
+#include <algorithm> // For std::max
 
-// Ensure that LAPACK functions are declared in the global namespace
-// If not, include the necessary headers or declarations
+// Implementation of non-template methods
 
+// Constructor
 OLS::Workspace::Workspace()
-    : s(nullptr), iwork(nullptr), work(nullptr), lwork(0), m(0), n(0), algo(Algorithm::GELSD)
-{
+  : work(nullptr), lwork(-1), m(0), n(0), algo(Algorithm::GELSD),
+  s(nullptr), iwork(nullptr), U(nullptr), S(nullptr), VT(nullptr),
+  AtA(nullptr), Atb(nullptr) {
 }
 
-OLS::Workspace::~Workspace()
-{
-    release();
+// Destructor
+OLS::Workspace::~Workspace() {
+  release();
 }
 
-void OLS::Workspace::allocate(int m_, int n_, Algorithm algorithm)
-{
-    // Check for valid algorithm
-    if (algorithm != Algorithm::GELSD && algorithm != Algorithm::GELS) {
-        throw std::invalid_argument("Invalid algorithm specified in Workspace::allocate().");
-    }
+// allocate method
+void OLS::Workspace::allocate(int m_, int n_, Algorithm algorithm) {
+  release(); // Release any existing memory
 
-    // Release existing memory if allocated
-    release();
-
-    m = m_;
-    n = n_;
-    algo = algorithm;
-
-    if (algorithm == Algorithm::GELSD) {
-        int minmn = std::min(m, n);
-
-        // Allocate s for singular values
-        s = reinterpret_cast<double*>(malloc(minmn * sizeof(double)));
-        if (!s) throw std::bad_alloc();
-
-        // Compute nlvl and allocate iwork
-        int smlsiz = 25;
-        int nlvl = std::max(0, int(std::log2(static_cast<double>(minmn) / (smlsiz + 1))) + 1);
-        int iwork_size = 3 * minmn * nlvl + 11 * minmn;
-        iwork = reinterpret_cast<int*>(malloc(iwork_size * sizeof(int)));
-        if (!iwork) {
-            release();
-            throw std::bad_alloc();
-        }
-    } else {
-        // For DGELS
-        s = nullptr;
-        iwork = nullptr;
-    }
+  m = m_;
+  n = n_;
+  algo = algorithm;
+
+  int info;
+
+  if (algorithm == Algorithm::GELSD) {
+    // Allocate workspace for DGELSD
+    int minmn = std::min(m, n);
+    int maxmn = std::max(m, n);
+    int smlsiz = 25; // LAPACK's SMLSIZ; NLVL = max(0, log2(minmn/(SMLSIZ+1)) + 1)
+    int nlvl = std::max(0, static_cast<int>(std::log2(minmn / double(smlsiz + 1))) + 1);
+
+    // Allocate s (singular values)
+    s = new double[minmn];
+
+    // Allocate iwork
+    int liwork = 3 * minmn * nlvl + 11 * minmn;
+    iwork = new int[liwork];
 
-    // Query for optimal work array size
-    double wkopt;
+    // Workspace query
+    double work_query;
     int lwork_query = -1;
     int nrhs = 1;
-    int lda = m;
-    int ldb = std::max(m, n);
-
-    // Allocate minimal dummy arrays for A and B
-    double* adummy = reinterpret_cast<double*>(malloc(m * n * sizeof(double)));
-    double* bdummy = reinterpret_cast<double*>(malloc(ldb * nrhs * sizeof(double)));
-    if (!adummy || !bdummy) {
-        if (adummy) ::free(adummy);
-        if (bdummy) ::free(bdummy);
-        release();
-        throw std::bad_alloc();
-    }
-
+    double* a_dummy = nullptr; // A and B are not referenced during the query
+    double* b_dummy = nullptr;
+    double rcond_query = -1.0; // valid scalars, in case the routine reads them
+    int rank;
+    ::dgelsd_(&m, &n, &nrhs, a_dummy, &m, b_dummy, &maxmn, s, &rcond_query,
+              &rank, &work_query, &lwork_query, iwork, &info);
+    if (info != 0) {
+      throw std::runtime_error("Error in DGELSD workspace query: info = " + std::to_string(info));
+    }
+    lwork = static_cast<int>(work_query);
+    work = new double[lwork];
+  } else if (algorithm == Algorithm::GELS) {
+    // Allocate workspace for DGELS
+    double work_query;
+    int lwork_query = -1;
+    int nrhs = 1;
+    int ldb = std::max(m, n); // DGELS requires LDB >= max(1,m,n) even for a query
+    double* a_dummy = nullptr; // not referenced during the query
+    double* b_dummy = nullptr;
+    char trans = 'N';
+    ::dgels_(&trans, &m, &n, &nrhs, a_dummy, &m, b_dummy, &ldb, &work_query, &lwork_query, &info);
+    if (info != 0) {
+      throw std::runtime_error("Error in DGELS workspace query: info = " + std::to_string(info));
+    }
+    lwork = static_cast<int>(work_query);
+    work = new double[lwork];
+  } else if (algorithm == Algorithm::SVD) {
+    // Allocate workspace for custom SVD implementation using DGESDD
+    int minmn = std::min(m, n);
+
+    // Set jobz parameter to compute singular vectors
+    char jobz = 'S'; // Compute the first min(m,n) singular vectors
+
+    // Leading dimensions
+    int lda = m;      // Leading dimension of A
+    int ldu = m;      // Leading dimension of U
+    int ldvt = minmn; // Leading dimension of VT
+
+    // Allocate arrays for U, S, and VT
+    U = new double[ldu * minmn];       // U is m x min(m,n)
+    S = new double[minmn];             // Singular values vector
+    VT = new double[ldvt * n];         // VT is min(m,n) x n
+
+    // Allocate integer workspace for DGESDD; stored in the member so that
+    // release() frees it (a local here would leak)
+    iwork = new int[8 * minmn];
+
+    // Workspace query for DGESDD (lwork_query = -1 requests the optimal size)
+    double work_query;
+    int lwork_query = -1;
-    int info;
 
-    if (algorithm == Algorithm::GELSD) {
-        int rank;
-        double rcond_query = -1;  // Use default rcond
-
-        // Set up dummy LAPACK call
-        ::dgelsd_(&m, &n, &nrhs, adummy, &lda, bdummy, &ldb, s, &rcond_query, &rank,
-                &wkopt, &lwork_query, iwork, &info);
-    } else { // Algorithm::GELS
-        char trans = 'N';
-        ::dgels_(&trans, &m, &n, &nrhs, adummy, &lda, bdummy, &ldb, &wkopt, &lwork_query, &info);
-    }
+    // Dummy A pointer: the matrix is not referenced during a workspace query
+    double* a_dummy = nullptr;
 
-    // Free dummy arrays
-    ::free(adummy);
-    ::free(bdummy);
+    // Perform the workspace query
+    ::dgesdd_(&jobz, &m, &n, a_dummy, &lda, S, U, &ldu, VT, &ldvt,
+              &work_query, &lwork_query, iwork, &info);
 
+    // Check for errors
     if (info != 0) {
-        release();
-        throw std::runtime_error("Error in workspace query: info = " + std::to_string(info));
+      // iwork is owned by the workspace; release() (via the destructor) frees it
+      throw std::runtime_error("Error in DGESDD workspace query: INFO = " + std::to_string(info));
     }
 
-    lwork = static_cast<int>(wkopt);
-    work = reinterpret_cast<double*>(malloc(lwork * sizeof(double)));
-    if (!work) {
-        release();
-        throw std::bad_alloc();
-    }
-}
+    // Allocate the optimal workspace
+    lwork = static_cast<int>(work_query);
+    work = new double[lwork];
 
-void OLS::Workspace::release()
-{
-    if (s) ::free(s);
-    if (iwork) ::free(iwork);
-    if (work) ::free(work);
-    s = nullptr;
-    iwork = nullptr;
+  } else if (algorithm == Algorithm::Cholesky) {
+    // Allocate space for AtA (n x n) and Atb (n)
+    AtA = new double[n * n];
+    Atb = new double[n];
+
+    // No additional workspace required for Cholesky decomposition
     work = nullptr;
     lwork = 0;
-    m = 0;
-    n = 0;
-    algo = Algorithm::GELSD;
+  } else {
+    throw std::invalid_argument("Invalid algorithm specified in Workspace::allocate().");
+  }
 }
 
-bool OLS::Workspace::is_sufficient(int m_, int n_, Algorithm algorithm) const
-{
-    // Check for valid algorithm
-    if (algorithm != Algorithm::GELSD && algorithm != Algorithm::GELS) {
-        throw std::invalid_argument("Invalid algorithm specified in Workspace::is_sufficient().");
-    }
-
-    // Check if current workspace dimensions and algorithm are sufficient
-    if (algo != algorithm || m < m_ || n < n_) {
-        return false;
-    }
-
-    int maxmn = std::max(m_, n_);
-
-    if (algorithm == Algorithm::GELSD) {
-        // Check if s is allocated
-        if (!s) return false;
-
-        // For work size, check if lwork is sufficient
-        // We can perform a workspace query to find the required lwork
-        double wkopt;
-        int lwork_query = -1;
-        int nrhs = 1;
-        int lda = m_;
-        int ldb = maxmn;
-
-        // Allocate minimal dummy arrays for A and B
-        double* adummy = reinterpret_cast<double*>(malloc(m_ * n_ * sizeof(double)));
-        double* bdummy = reinterpret_cast<double*>(malloc(ldb * nrhs * sizeof(double)));
-        if (!adummy || !bdummy) {
-            if (adummy) ::free(adummy);
-            if (bdummy) ::free(bdummy);
-            throw std::bad_alloc();
-        }
-        int rank;
-        int info;
-
-        // Query optimal lwork
-        double rcond_query = -1;  // Use default rcond
-        ::dgelsd_(&m_, &n_, &nrhs, adummy, &lda, bdummy, &ldb, s, &rcond_query, &rank,
-                &wkopt, &lwork_query, iwork, &info);
-
-        // Free dummy arrays
-        ::free(adummy);
-        ::free(bdummy);
-
-        if (info != 0) {
-            throw std::runtime_error("Error in workspace sufficiency check: info = " + std::to_string(info));
-        }
-
-        int required_lwork = static_cast<int>(wkopt);
-
-        if (lwork < required_lwork) {
-            return false;
-        }
-
-    } else { // Algorithm::GELS
-        // For DGELS
-        // Query optimal lwork
-        double wkopt;
-        int lwork_query = -1;
-        int nrhs = 1;
-        int lda = m_;
-        int ldb = maxmn;
-        int info;
-        char trans = 'N';
-
-        // Allocate minimal dummy arrays for A and B
-        double* adummy = reinterpret_cast<double*>(malloc(m_ * n_ * sizeof(double)));
-        double* bdummy = reinterpret_cast<double*>(malloc(ldb * nrhs * sizeof(double)));
-        if (!adummy || !bdummy) {
-            if (adummy) ::free(adummy);
-            if (bdummy) ::free(bdummy);
-            throw std::bad_alloc();
-        }
-
-        ::dgels_(&trans, &m_, &n_, &nrhs, adummy, &lda, bdummy, &ldb, &wkopt, &lwork_query, &info);
-
-        // Free dummy arrays
-        ::free(adummy);
-        ::free(bdummy);
-
-        if (info != 0) {
-            throw std::runtime_error("Error in workspace sufficiency check: info = " + std::to_string(info));
-        }
-
-        int required_lwork = static_cast<int>(wkopt);
-        if (lwork < required_lwork) {
-            return false;
-        }
-    }
+// release method
+void OLS::Workspace::release() {
+  delete[] s;
+  delete[] iwork;
+  delete[] work;
+  delete[] U;
+  delete[] S;
+  delete[] VT;
+  delete[] AtA;
+  delete[] Atb;
+
+  s = nullptr;
+  iwork = nullptr;
+  work = nullptr;
+  U = nullptr;
+  S = nullptr;
+  VT = nullptr;
+  AtA = nullptr;
+  Atb = nullptr;
+
+  lwork = -1;
+  m = 0;
+  n = 0;
+  algo = Algorithm::GELSD;
+}
 
-    // All checks passed
-    return true;
+// is_sufficient method
+bool OLS::Workspace::is_sufficient(int m_, int n_, Algorithm algorithm) const {
+  if (m != m_ || n != n_ || algo != algorithm) {
+    return false;
+  }
+  // Buffers are allocated for exact (m, n, algorithm); any mismatch forces reallocation.
+  return true;
 }
+
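All three LAPACK paths above use the same two-phase convention: call the routine once with `lwork = -1` and a one-element work array, read the optimal size from `work[0]`, allocate, then call again for real. A distilled sketch of the idiom with DGELS; the extern declaration is normally provided by `tadah/core/lapack.h`, so treat the exact prototype here as an assumption:

```cpp
#include <algorithm>
#include <stdexcept>
#include <string>
#include <vector>

// Assumed Fortran LAPACK prototype (usually supplied by tadah/core/lapack.h).
extern "C" void dgels_(char* trans, int* m, int* n, int* nrhs, double* a,
                       int* lda, double* b, int* ldb, double* work,
                       int* lwork, int* info);

// The lwork = -1 query idiom used throughout Workspace::allocate().
std::vector<double> query_and_allocate_gels(int m, int n) {
  char trans = 'N';
  int nrhs = 1, info = 0;
  int lda = std::max(1, m);
  int ldb = std::max({1, m, n});     // DGELS requires LDB >= max(1,M,N)
  double work_query = 0.0;
  int lwork = -1;                    // -1 => workspace query only
  dgels_(&trans, &m, &n, &nrhs, /*a=*/nullptr, &lda, /*b=*/nullptr, &ldb,
         &work_query, &lwork, &info); // A and B are not referenced here
  if (info != 0)
    throw std::runtime_error("DGELS query failed: info = " + std::to_string(info));
  return std::vector<double>(static_cast<std::size_t>(work_query));
}
```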
diff --git a/tests/test_ridge_regression.cpp b/tests/test_ridge_regression.cpp
index 5c56e51..064a0cd 100644
--- a/tests/test_ridge_regression.cpp
+++ b/tests/test_ridge_regression.cpp
@@ -15,7 +15,7 @@ TEST_CASE("Testing RR") {
 
   SECTION("lambda zero") {
     double lambda = 0;
-    RidgeRegression::solve(Phi1,b1,w, lambda);
+    RidgeRegression::solve(Phi1,b1,w, lambda, 1e-8);
     aed_type p= Phi2*w;
     REQUIRE_THAT(p[0], Catch::Matchers::WithinRel(b2[0]));
     REQUIRE_THAT(p[1], Catch::Matchers::WithinRel(b2[1]));
@@ -23,7 +23,7 @@ TEST_CASE("Testing RR") {
   }
   SECTION("small lambda ") {
     double lambda = 1e-10;
-    RidgeRegression::solve(Phi1,b1,w, lambda);
+    RidgeRegression::solve(Phi1,b1,w, lambda, 1e-8);
     aed_type p= Phi2*w;
     REQUIRE(p.isApprox(b2,lambda));
   }
-- 
GitLab