From fd02e37af0226e2c678815ae0ee4b94dbc857dc9 Mon Sep 17 00:00:00 2001
From: Marcin Kirsz <mkirsz@ed.ac.uk>
Date: Wed, 19 Feb 2025 16:04:01 +0000
Subject: [PATCH] Working MM for DM

---
 .../tadah/mlip/design_matrix/design_matrix.h  | 115 ++++++++++--------
 .../tadah/mlip/memory/DesignMatrixWorkspace.h |  65 ++++++++++
 .../tadah/mlip/memory/IMLIPWorkspaceManager.h |  49 ++++++++
 .../tadah/mlip/memory/MLIPWorkspaceManager.h  |  56 +++++++++
 include/tadah/mlip/models/m_blr.h             |  14 +--
 include/tadah/mlip/models/m_krr.h             |  14 +--
 include/tadah/mlip/models/m_tadah_base.h      |   1 +
 include/tadah/mlip/trainer.h                  |   8 +-
 src/DesignMatrixWorkspace.cpp                 |  28 +++++
 src/IMLIPWorkspaceManager.cpp                 |  12 ++
 src/MLIPWorkspaceManager.cpp                  |  35 ++++++
 src/m_all.cpp                                 |  16 ++-
 12 files changed, 342 insertions(+), 71 deletions(-)
 create mode 100644 include/tadah/mlip/memory/DesignMatrixWorkspace.h
 create mode 100644 include/tadah/mlip/memory/IMLIPWorkspaceManager.h
 create mode 100644 include/tadah/mlip/memory/MLIPWorkspaceManager.h
 create mode 100644 src/DesignMatrixWorkspace.cpp
 create mode 100644 src/IMLIPWorkspaceManager.cpp
 create mode 100644 src/MLIPWorkspaceManager.cpp

diff --git a/include/tadah/mlip/design_matrix/design_matrix.h b/include/tadah/mlip/design_matrix/design_matrix.h
index 5a785ec..43efa09 100644
--- a/include/tadah/mlip/design_matrix/design_matrix.h
+++ b/include/tadah/mlip/design_matrix/design_matrix.h
@@ -6,6 +6,9 @@
 #include <tadah/mlip/descriptors_calc.h>
 #include <tadah/mlip/normaliser.h>
 #include <tadah/core/config.h>
+#include <tadah/mlip/memory/IMLIPWorkspaceManager.h>
+#include <tadah/mlip/memory/MLIPWorkspaceManager.h>
+#include <tadah/mlip/memory/DesignMatrixWorkspace.h>
 
 #include <stdexcept>
 
@@ -68,9 +71,7 @@ class DesignMatrix : public DesignMatrixBase {
 public:
 
   F f;
-  phi_type Phi;
-  t_type T;
-  t_type Tlabels; // 0-Energy, 1-Force, 2-Stress  // TODO should be array of int not double
+  tadah::mlip::memory::DesignMatrixWorkspace *ws;
   bool scale=true;    // Control escale,fscale,sscale
 
   double e_std_dev=1;
@@ -89,24 +90,40 @@ public:
      *
      * \endcode
      */
-  DesignMatrix(F &f, Config &c):
-    f(f),
-    config(c),
-    force(config.get<bool>("FORCE")),
-    stress(config.get<bool>("STRESS")),
-    verbose(config.get<int>("VERBOSE")),
-    eweightglob(config.get<double>("EWEIGHT")),
-    fweightglob(config.get<double>("FWEIGHT")),
-    sweightglob(config.get<double>("SWEIGHT"))
-  {
-    if (config.exist("ESTDEV"))
-      e_std_dev = config.get<double>("ESTDEV");
-    if (config.exist("FSTDEV"))
-      f_std_dev = config.get<double>("FSTDEV");
-    if (config.exist("SSTDEV"))
-      config.get<double[6]>("SSTDEV", s_std_dev);
-  }
+    DesignMatrix(F &f, Config &c, tadah::mlip::memory::IMLIPWorkspaceManager& workspaceManager)
+        : f(f),
+          workspaceManager_(&workspaceManager),
+          ownWorkspaceManager(false),
+          config(c),
+          force(config.get<bool>("FORCE")),
+          stress(config.get<bool>("STRESS")),
+          verbose(config.get<int>("VERBOSE")),
+          eweightglob(config.get<double>("EWEIGHT")),
+          fweightglob(config.get<double>("FWEIGHT")),
+          sweightglob(config.get<double>("SWEIGHT"))
+    {
+        if (config.exist("ESTDEV"))
+            e_std_dev = config.get<double>("ESTDEV");
+        if (config.exist("FSTDEV"))
+            f_std_dev = config.get<double>("FSTDEV");
+        if (config.exist("SSTDEV"))
+            config.get<double[6]>("SSTDEV", s_std_dev);
+    }
 
+    // Constructor without workspaceManager_ parameter (delegating constructor)
+    DesignMatrix(F &f, Config &c)
+        : DesignMatrix(f, c, *new tadah::mlip::memory::MLIPWorkspaceManager())
+    {
+        ownWorkspaceManager = true;  // Set ownership flag
+    }
+
+    // Destructor
+    ~DesignMatrix() {
+        if (ownWorkspaceManager && workspaceManager_ != nullptr) {
+            delete workspaceManager_;
+            workspaceManager_ = nullptr;
+        }
+    }
 
   /** \brief Build design matrix from already calculated StDescriptorsDB
      *
@@ -115,7 +132,8 @@ public:
      * D2, D3 and DM calculators are not used.
      */
   void build(StDescriptorsDB &st_desc_db, const StructureDB &stdb) {
-    resize(stdb);
+    calc_mn(stdb);
+    ws = workspaceManager_->getDesignMatrixWorkspace(rows,cols);
     compute_stdevs(stdb);
     fill_T(stdb);
     std::vector<size_t> rows(stdb.size());
@@ -146,7 +164,8 @@ public:
   void build(const StructureDB &stdb, Normaliser &norm,
              DC &dc) {
     //DescriptorsCalc<D2,D3,DM,C2,C3,CM> dc(config);
-    resize(stdb);   // call after dc set DSIZE key
+    calc_mn(stdb);
+    ws = workspaceManager_->getDesignMatrixWorkspace(rows,cols);
     compute_stdevs(stdb);
     fill_T(stdb);
     // for opm we need to find first rows for each structure
@@ -183,18 +202,18 @@ public:
 
     double escale = 1;
     if (scale) escale = st.eweight*eweightglob/st.natoms();
-    f.calc_phi_energy_row(Phi,row,escale,st,st_d);
+    f.calc_phi_energy_row(ws->Phi,row,escale,st,st_d);
     if (force) {
       double fscale = 1;
       if (scale) fscale = st.fweight*fweightglob/st.natoms()/3.0;///f_std_dev;
-      f.calc_phi_force_rows(Phi,row,fscale,st,st_d);
+      f.calc_phi_force_rows(ws->Phi,row,fscale,st,st_d);
     }
     if (stress) {
       double sscale_arr[6] {1,1,1,1,1,1};
       if (scale)
         for(size_t xy=0;xy<6;++xy)
           sscale_arr[xy] = st.sweight*sweightglob/6.0;///s_std_dev[xy];
-      f.calc_phi_stress_rows(Phi,row,sscale_arr,st,st_d);
+      f.calc_phi_stress_rows(ws->Phi,row,sscale_arr,st,st_d);
     }
   }
   void fill_T(const StructureDB &stdb, size_t start=0) {
@@ -203,19 +222,19 @@ public:
 
       double escale = 1;
       if (scale) escale = stdb(s).eweight*eweightglob/stdb(s).natoms();///e_std_dev;
-      Tlabels(j) = 0;
-      T(j++) = stdb(s).energy*escale;
+      ws->Tlabels(j) = 0;
+      ws->T(j++) = stdb(s).energy*escale;
 
       if (force) {
         double fscale = 1;
         if (scale) fscale = stdb(s).fweight*fweightglob/stdb(s).natoms()/3.0;///e_std_dev/f_std_dev;
         for (const Atom &a : stdb(s).atoms) {
-          Tlabels(j) = 1;
-          T(j++) = a.force(0)*fscale;
-          Tlabels(j) = 1;
-          T(j++) = a.force(1)*fscale;
-          Tlabels(j) = 1;
-          T(j++) = a.force(2)*fscale;
+          ws->Tlabels(j) = 1;
+          ws->T(j++) = a.force(0)*fscale;
+          ws->Tlabels(j) = 1;
+          ws->T(j++) = a.force(1)*fscale;
+          ws->Tlabels(j) = 1;
+          ws->T(j++) = a.force(2)*fscale;
         }
       }
       if (stress) {
@@ -226,18 +245,24 @@ public:
         size_t xy=0;
         for (size_t x=0; x<3; ++x) {
           for (size_t y=x; y<3; ++y) {
-            Tlabels(j) = 2;
-            T(j++)=stdb(s).stress(x,y)*sscale_arr[xy++];
+            ws->Tlabels(j) = 2;
+            ws->T(j++)=stdb(s).stress(x,y)*sscale_arr[xy++];
           }
         }
       }
     }
   }
 
+phi_type &getPhi() { return ws->Phi; }
+t_type &getT() { return ws->T; }
+t_type &getTlabels() { return ws->Tlabels; }
+
 private:
-  Config &config;
-  size_t rows=0;
-  size_t cols=0;
+  tadah::mlip::memory::IMLIPWorkspaceManager* workspaceManager_ = nullptr;
+  bool ownWorkspaceManager = false;
+  Config & config;
+  size_t rows = 0;
+  size_t cols = 0;
   bool force;
   bool stress;
   int verbose;
@@ -246,7 +271,7 @@ private:
   double sweightglob;
 
   // resize Phi and T and set all elements to zero
-  void resize(const StructureDB &stdb) {
+  void calc_mn(const StructureDB &stdb) {
     rows=0;
     cols = f.get_phi_cols(config);
 
@@ -259,18 +284,10 @@ private:
     // Add space for augmented matrix
     rows += cols;
 
-    T.resize(rows);
-    Tlabels.resize(rows);
-    Phi.resize(rows,cols);
-
     if (verbose) {
-      std::cout << "Phi rows: "<< Phi.rows() << std::endl;
-      std::cout << "Phi cols: "<< Phi.cols() << std::endl;
+      std::cout << "Phi rows: "<< rows << std::endl;
+      std::cout << "Phi cols: "<< cols << std::endl;
     }
-
-    T.set_zero();
-    Phi.set_zero();
-
   }
 
   void compute_stdevs(const StructureDB &stdb) {
diff --git a/include/tadah/mlip/memory/DesignMatrixWorkspace.h b/include/tadah/mlip/memory/DesignMatrixWorkspace.h
new file mode 100644
index 0000000..d9c9150
--- /dev/null
+++ b/include/tadah/mlip/memory/DesignMatrixWorkspace.h
@@ -0,0 +1,65 @@
+#ifndef TADAH_MLIP_MEMORY_DESIGNMATRIXWORKSPACE_H
+#define TADAH_MLIP_MEMORY_DESIGNMATRIXWORKSPACE_H
+
+#include <tadah/core/core_types.h>
+
+namespace tadah {
+namespace mlip {
+namespace memory {
+
+/**
+ * @class DesignMatrixWorkspace
+ * @brief Manages memory allocations for DesignMatrix computations.
+ *
+ * The DesigmMatrixWorkspace class handles the memory required for various DM
+ * algorithms. It ensures that memory is allocated and deallocated appropriately
+ * and provides storage for intermediate computations.
+ */
+class DesignMatrixWorkspace {
+public:
+  /**
+   * @brief Constructor.
+   */
+  DesignMatrixWorkspace();
+
+  /**
+   * @brief Destructor.
+   *
+   * Releases all allocated memory.
+   */
+  ~DesignMatrixWorkspace();
+
+  /**
+   * @brief Allocate workspace for the specified problem size and algorithm.
+   *
+   * @param m Number of rows in the matrix.
+   * @param n Number of columns in the matrix.
+   */
+  void allocate(size_t m_, size_t n_);
+
+  /**
+   * @brief Check if the workspace is sufficient for the given problem size and
+   * algorithm.
+   *
+   * @param m Number of rows in the matrix.
+   * @param n Number of columns in the matrix.
+   * @return True if the workspace is sufficient, false otherwise.
+   */
+  bool isSufficient(size_t m_, size_t n_) const;
+
+  phi_type Phi;
+  t_type T;
+  t_type Tlabels; // 0-Energy, 1-Force, 2-Stress  // TODO should be array of int not double
+
+
+private:
+  // Disable copying
+  DesignMatrixWorkspace(const DesignMatrixWorkspace &) = delete;
+  DesignMatrixWorkspace &operator=(const DesignMatrixWorkspace &) = delete;
+};
+
+} // namespace memory
+} // namespace mlip
+} // namespace tadah
+
+#endif
diff --git a/include/tadah/mlip/memory/IMLIPWorkspaceManager.h b/include/tadah/mlip/memory/IMLIPWorkspaceManager.h
new file mode 100644
index 0000000..e4cd63b
--- /dev/null
+++ b/include/tadah/mlip/memory/IMLIPWorkspaceManager.h
@@ -0,0 +1,49 @@
+#ifndef TADAH_MLIP_MEMORY_IMLIPWORKSPACEMANAGER_H
+#define TADAH_MLIP_MEMORY_IMLIPWORKSPACEMANAGER_H
+
+#include <tadah/models/memory/IModelsWorkspaceManager.h>
+#include <tadah/core/core_types.h>
+
+namespace tadah {
+namespace mlip {
+namespace memory {
+
+class DesignMatrixWorkspace; ///< Forward declaration of DesignMatrixWorkspace
+
+/**
+ * @interface IMLIPWorkspaceManager
+ * @brief Interface for managing workspaces specific to the MLIP module.
+ *
+ * Extends the IModelsWorkspaceManager interface and provides methods for
+ * obtaining and releasing DesignMatrix workspace.
+ */
+class IMLIPWorkspaceManager : public virtual tadah::models::memory::IModelsWorkspaceManager {
+public:
+    /**
+     * @brief Virtual destructor.
+     */
+    virtual ~IMLIPWorkspaceManager();
+
+    /**
+     * @brief Get a DesignMatrix workspace suitable for the given problem size.
+     *
+     * @param m Number of rows in the matrix.
+     * @param n Number of columns in the matrix.
+     * @return Pointer to an allocated DesignMatrixWorkspace.
+     */
+    virtual DesignMatrixWorkspace* getDesignMatrixWorkspace(size_t m, size_t n) = 0;
+
+    /**
+     * @brief Release a DesignMatrix workspace.
+     *
+     * @param workspace Pointer to the DesignMatrixWorkspace to be released.
+     */
+    virtual void releaseDesignMatrixWorkspace(DesignMatrixWorkspace* workspace) = 0;
+
+};
+
+} // namespace memory
+} // namespace mlip
+} // namespace tadah
+
+#endif // TADAH_MLIP_MEMORY_IMLIPWORKSPACEMANAGER_H
diff --git a/include/tadah/mlip/memory/MLIPWorkspaceManager.h b/include/tadah/mlip/memory/MLIPWorkspaceManager.h
new file mode 100644
index 0000000..3b26fd3
--- /dev/null
+++ b/include/tadah/mlip/memory/MLIPWorkspaceManager.h
@@ -0,0 +1,56 @@
+#ifndef TADAH_MLIP_MEMORY_MLIPWORKSPACEMANAGER_H
+#define TADAH_MLIP_MEMORY_MLIPWORKSPACEMANAGER_H
+
+#include <tadah/models/memory/ModelsWorkspaceManager.h>
+#include <tadah/mlip/memory/IMLIPWorkspaceManager.h>
+
+namespace tadah {
+namespace mlip {
+namespace memory {
+
+/**
+ * @class MLIPWorkspaceManager
+ * @brief Implements the IMLIPWorkspaceManager interface to manage workspaces.
+ *
+ * Manages DesignMatrix workspace, providing methods to obtain and release them.
+ */
+class MLIPWorkspaceManager : public tadah::models::memory::ModelsWorkspaceManager, public virtual IMLIPWorkspaceManager {
+
+  public:
+  /**
+   * @brief Constructor.
+   */
+  MLIPWorkspaceManager();
+
+  /**
+   * @brief Destructor.
+   *
+   * Releases all allocated workspaces.
+   */
+  ~MLIPWorkspaceManager() override;
+
+  /**
+   * @brief Get a DesignMatrix workspace.
+   *
+   * @param m Number of rows in the matrix.
+   * @param n Number of columns in the matrix.
+   * @return Pointer to an allocated MLIPWorkspaceManager.
+   */
+  DesignMatrixWorkspace *getDesignMatrixWorkspace(size_t m, size_t n) override;
+
+  /**
+   * @brief Release a DesignMatrix workspace.
+   *
+   * @param workspace Pointer to the MLIPWorkspaceManager to be released.
+   */
+  void releaseDesignMatrixWorkspace(DesignMatrixWorkspace *workspace) override;
+
+private:
+  DesignMatrixWorkspace *designMatrixWorkspace_; ///< Pointer to the MLIPWorkspaceManager.
+};
+
+} // namespace memory
+} // namespace mlip
+} // namespace tadah
+
+#endif
diff --git a/include/tadah/mlip/models/m_blr.h b/include/tadah/mlip/models/m_blr.h
index c8e3eb0..596730a 100644
--- a/include/tadah/mlip/models/m_blr.h
+++ b/include/tadah/mlip/models/m_blr.h
@@ -8,7 +8,7 @@
 #include <tadah/mlip/normaliser.h>
 #include <tadah/models/m_blr_train.h>
 #include <tadah/core/config.h>
-#include <tadah/models/memory/IModelsWorkspaceManager.h>
+#include <tadah/mlip/memory/IMLIPWorkspaceManager.h>
 
 #include <limits>
 #include <stdexcept>
@@ -77,7 +77,7 @@ public:
     norm = Normaliser(c);
   }
 
-  M_BLR(BF &bf, Config &c, tadah::models::memory::IModelsWorkspaceManager& workspaceManager):
+  M_BLR(BF &bf, Config &c, tadah::mlip::memory::IMLIPWorkspaceManager& workspaceManager):
     M_BLR_Train<BF>(bf,c,workspaceManager),
     desmat(bf,c)
   {
@@ -182,11 +182,11 @@ public:
     dm.scale=false; // do not scale energy, forces and stresses
     dm.build(stdb,norm,dc);
 
-    predicted_error = T_MDMT_diag(dm.Phi, Sigma);
+    predicted_error = T_MDMT_diag(dm.getPhi(), Sigma);
     double pmean = sqrt(predicted_error.mean());
 
     // compute energy, forces and stresses
-    aed_type Tpred = T_dgemv(dm.Phi, weights);
+    aed_type Tpred = T_dgemv(dm.getPhi(), weights);
 
     // Construct StructureDB object with predicted values
     StructureDB stdb_;
@@ -222,7 +222,7 @@ public:
     if(!trained) throw std::runtime_error("This object is not trained!\n\
 Hint: check different predict() methods.");
 
-    phi_type &Phi = desmat.Phi;
+    phi_type &Phi = desmat.getPhi();
     //std::cout << Phi.row(0) << std::endl;
 
     // compute energy, forces and stresses
@@ -314,8 +314,8 @@ non linear basis function\n");
     // In consequence, we cannot use it for quick prediction
     // The simple solution, for now, is to make a copy of the Phi matrix
     //phi_type &Phi = desmat.Phi;
-    phi_type Phi = desmat.Phi;
-    t_type T = desmat.T;
+    phi_type Phi = desmat.getPhi();
+    t_type T = desmat.getT();
     //t_type &T = desmat.T;
     M_BLR_Train<BF>::train(Phi,T);
 
diff --git a/include/tadah/mlip/models/m_krr.h b/include/tadah/mlip/models/m_krr.h
index ccd2c6c..662e200 100644
--- a/include/tadah/mlip/models/m_krr.h
+++ b/include/tadah/mlip/models/m_krr.h
@@ -9,7 +9,7 @@
 #include <tadah/models/m_krr_train.h>
 #include <tadah/core/config.h>
 #include <tadah/mlip/models/basis.h>
-#include <tadah/models/memory/IModelsWorkspaceManager.h>
+#include <tadah/mlip/memory/IMLIPWorkspaceManager.h>
 
 #include <limits>
 #include <stdexcept>
@@ -80,7 +80,7 @@ public:
     norm = Normaliser(c);
   }
 
-  M_KRR(K &kernel, Config &c, tadah::models::memory::IModelsWorkspaceManager& workspaceManager):
+  M_KRR(K &kernel, Config &c, tadah::mlip::memory::IMLIPWorkspaceManager& workspaceManager):
     M_KRR_Train<K>(kernel,c,workspaceManager),
     basis(c),
     desmat(kernel,c)
@@ -253,11 +253,11 @@ public:
     dm.build(stdb,norm,dc);
 
     // compute error
-    predicted_error = T_MDMT_diag(dm.Phi, Sigma);
+    predicted_error = T_MDMT_diag(dm.getPhi(), Sigma);
     double pmean = sqrt(predicted_error.mean());
 
     // compute energy, forces and stresses
-    aed_type Tpred = T_dgemv(dm.Phi, weights);
+    aed_type Tpred = T_dgemv(dm.getPhi(), weights);
 
     // Construct StructureDB object with predicted values
     StructureDB stdb_;
@@ -293,7 +293,7 @@ public:
     if(!trained) throw std::runtime_error("This object is not trained!\n\
 Hint: check different predict() methods.");
 
-    phi_type &Phi = desmat.Phi;
+    phi_type &Phi = desmat.getPhi();
 
     // compute energy, forces and stresses
     aed_type Tpred = T_dgemv(Phi, weights);
@@ -384,8 +384,8 @@ non linear kernel\n");
   template <typename D>
   void train(D &desmat) {
     // TODO see comments in M_BLR
-    phi_type Phi = desmat.Phi;
-    t_type T = desmat.T;
+    phi_type Phi = desmat.getPhi();
+    t_type T = desmat.getT();
     M_KRR_Train<K>::train(Phi,T);
 
     if (config.template get<bool>("NORM") &&
diff --git a/include/tadah/mlip/models/m_tadah_base.h b/include/tadah/mlip/models/m_tadah_base.h
index d719829..3285cf0 100644
--- a/include/tadah/mlip/models/m_tadah_base.h
+++ b/include/tadah/mlip/models/m_tadah_base.h
@@ -26,6 +26,7 @@ public:
   /** \brief Predict total energy for a set of atoms. */
   using M_Predict::epredict;
   using M_Predict::fpredict;
+  using M_Core::predict;
   double epredict(const StDescriptors &std);
 
   ///** \brief Predict force between a pair of atoms in a k-direction. */
diff --git a/include/tadah/mlip/trainer.h b/include/tadah/mlip/trainer.h
index a212a9d..629e593 100644
--- a/include/tadah/mlip/trainer.h
+++ b/include/tadah/mlip/trainer.h
@@ -8,7 +8,7 @@
 #include <tadah/mlip/nn_finder.h>
 #include <tadah/core/config.h>
 #include <tadah/models/dc_selector.h>
-#include <tadah/models/memory/IModelsWorkspaceManager.h>
+#include <tadah/mlip/memory/IMLIPWorkspaceManager.h>
 
 #include <iostream>
 
@@ -28,7 +28,7 @@ class Trainer {
       if(fb)
         delete fb;
     }
-    Trainer (Config &c, tadah::models::memory::IModelsWorkspaceManager& workspaceManager):
+    Trainer (Config &c, tadah::mlip::memory::IMLIPWorkspaceManager& workspaceManager):
       config(c),
       DCS(config),
       dc(config,*DCS.d2b,*DCS.d3b,*DCS.dmb,
@@ -36,10 +36,10 @@ class Trainer {
       nnf(config),
       fb(CONFIG::factory<DM_Function_Base,Config&>(
             config.get<std::string>("MODEL",1),config)),
-      model(CONFIG::factory<M_Tadah_Base,DM_Function_Base&,Config&,tadah::models::memory::IModelsWorkspaceManager&>
+      model(CONFIG::factory<M_Tadah_Base,DM_Function_Base&,Config&,tadah::mlip::memory::IMLIPWorkspaceManager&>
           (config.get<std::string>("MODEL",0),*fb,config,workspaceManager)),
           //(config.get<std::string>("MODEL",0),*fb,config)),
-      dm(*fb, config)
+      dm(*fb, config, workspaceManager)
   {
     config.postprocess();
     config.check_for_training();
diff --git a/src/DesignMatrixWorkspace.cpp b/src/DesignMatrixWorkspace.cpp
new file mode 100644
index 0000000..8390cd6
--- /dev/null
+++ b/src/DesignMatrixWorkspace.cpp
@@ -0,0 +1,28 @@
+#include <algorithm>
+#include <cmath>
+#include <stdexcept>
+#include <tadah/core/lapack.h>
+#include <tadah/mlip/memory/DesignMatrixWorkspace.h>
+
+namespace tadah {
+namespace mlip {
+namespace memory {
+
+DesignMatrixWorkspace::DesignMatrixWorkspace() {}
+DesignMatrixWorkspace::~DesignMatrixWorkspace() {}
+
+void DesignMatrixWorkspace::allocate(size_t m_, size_t n_) {
+  Phi.resize(m_,n_);
+  T.resize(m_);
+  Tlabels.resize(m_);
+}
+
+bool DesignMatrixWorkspace::isSufficient(size_t m_, size_t n_) const {
+  return (Phi.rows() == m_ && Phi.cols() == n_
+      && T.size() == m_ && Tlabels.size() == m_);
+}
+
+} // namespace memory
+} // namespace mlip
+} // namespace tadah
+
diff --git a/src/IMLIPWorkspaceManager.cpp b/src/IMLIPWorkspaceManager.cpp
new file mode 100644
index 0000000..268d1c1
--- /dev/null
+++ b/src/IMLIPWorkspaceManager.cpp
@@ -0,0 +1,12 @@
+#include <tadah/mlip/memory/IMLIPWorkspaceManager.h>
+
+namespace tadah {
+namespace mlip {
+namespace memory {
+
+IMLIPWorkspaceManager::~IMLIPWorkspaceManager() = default;
+
+} // namespace memory
+} // namespace mlip
+} // namespace tadah
+
diff --git a/src/MLIPWorkspaceManager.cpp b/src/MLIPWorkspaceManager.cpp
new file mode 100644
index 0000000..019be87
--- /dev/null
+++ b/src/MLIPWorkspaceManager.cpp
@@ -0,0 +1,35 @@
+#include <tadah/mlip/memory/MLIPWorkspaceManager.h>
+#include <tadah/mlip/memory/DesignMatrixWorkspace.h>
+
+namespace tadah {
+namespace mlip {
+namespace memory {
+
+MLIPWorkspaceManager::MLIPWorkspaceManager()
+    : designMatrixWorkspace_(nullptr) {}
+
+MLIPWorkspaceManager::~MLIPWorkspaceManager() {
+    releaseDesignMatrixWorkspace(designMatrixWorkspace_);
+}
+
+DesignMatrixWorkspace* MLIPWorkspaceManager::getDesignMatrixWorkspace(size_t m, size_t n) {
+    if (!designMatrixWorkspace_ || !designMatrixWorkspace_->isSufficient(m, n)) {
+        releaseDesignMatrixWorkspace(designMatrixWorkspace_);
+        designMatrixWorkspace_ = new DesignMatrixWorkspace();
+        designMatrixWorkspace_->allocate(m, n);
+    }
+    return designMatrixWorkspace_;
+}
+
+void MLIPWorkspaceManager::releaseDesignMatrixWorkspace(DesignMatrixWorkspace* workspace) {
+    if (workspace) {
+        delete workspace;
+        workspace = nullptr;
+    }
+    designMatrixWorkspace_ = nullptr;
+}
+
+
+} // namespace memory
+} // namespace mlip
+} // namespace tadah
diff --git a/src/m_all.cpp b/src/m_all.cpp
index fa99c8c..c8387e6 100644
--- a/src/m_all.cpp
+++ b/src/m_all.cpp
@@ -1,5 +1,6 @@
 #include <tadah/mlip/models/m_all.h>
 #include <tadah/models/memory/IModelsWorkspaceManager.h>
+#include <tadah/mlip/memory/IMLIPWorkspaceManager.h>
 
 template<>
 CONFIG::Registry<M_Tadah_Base, DM_Function_Base&, Config&>::Map
@@ -8,9 +9,16 @@ CONFIG::Registry<M_Tadah_Base, DM_Function_Base&, Config&>::registry{};
 CONFIG::Registry<M_Tadah_Base, DM_Function_Base&, Config&>::Register<M_KRR<>> M_KRR_1("M_KRR");
 CONFIG::Registry<M_Tadah_Base, DM_Function_Base&, Config&>::Register<M_BLR<>> M_BLR_1("M_BLR");
 
+/*template<>*/
+/*CONFIG::Registry<M_Tadah_Base, DM_Function_Base&, Config&, tadah::models::memory::IModelsWorkspaceManager&>::Map*/
+/*CONFIG::Registry<M_Tadah_Base, DM_Function_Base&, Config&, tadah::models::memory::IModelsWorkspaceManager&>::registry{};*/
+/**/
+/*CONFIG::Registry<M_Tadah_Base, DM_Function_Base&, Config&, tadah::models::memory::IModelsWorkspaceManager&>::Register<M_BLR<>> M_BLR_2("M_BLR");*/
+/*CONFIG::Registry<M_Tadah_Base, DM_Function_Base&, Config&, tadah::models::memory::IModelsWorkspaceManager&>::Register<M_KRR<>> M_KRR_2("M_KRR");*/
+
 template<>
-CONFIG::Registry<M_Tadah_Base, DM_Function_Base&, Config&, tadah::models::memory::IModelsWorkspaceManager&>::Map
-CONFIG::Registry<M_Tadah_Base, DM_Function_Base&, Config&, tadah::models::memory::IModelsWorkspaceManager&>::registry{};
+CONFIG::Registry<M_Tadah_Base, DM_Function_Base&, Config&, tadah::mlip::memory::IMLIPWorkspaceManager&>::Map
+CONFIG::Registry<M_Tadah_Base, DM_Function_Base&, Config&, tadah::mlip::memory::IMLIPWorkspaceManager&>::registry{};
 
-CONFIG::Registry<M_Tadah_Base, DM_Function_Base&, Config&, tadah::models::memory::IModelsWorkspaceManager&>::Register<M_BLR<>> M_BLR_2("M_BLR");
-CONFIG::Registry<M_Tadah_Base, DM_Function_Base&, Config&, tadah::models::memory::IModelsWorkspaceManager&>::Register<M_KRR<>> M_KRR_2("M_KRR");
+CONFIG::Registry<M_Tadah_Base, DM_Function_Base&, Config&, tadah::mlip::memory::IMLIPWorkspaceManager&>::Register<M_BLR<>> M_BLR_3("M_BLR");
+CONFIG::Registry<M_Tadah_Base, DM_Function_Base&, Config&, tadah::mlip::memory::IMLIPWorkspaceManager&>::Register<M_KRR<>> M_KRR_3("M_KRR");
-- 
GitLab