From 31db9ef77ae092ca41ad01ce3e5e84cf3730ec81 Mon Sep 17 00:00:00 2001
From: mkirsz <s1351949@sms.ed.ac.uk>
Date: Tue, 25 Feb 2025 11:30:38 +0000
Subject: [PATCH] WiP on weighting factors

---
 .../tadah/mlip/design_matrix/design_matrix.h  | 194 ++++++++++++------
 .../functions/basis_functions/dm_bf_base.h    |   6 +-
 .../functions/basis_functions/dm_bf_linear.h  |   6 +-
 .../basis_functions/dm_bf_polynomial2.h       |   6 +-
 .../functions/dm_function_base.h              |   6 +-
 .../functions/kernels/dm_kern_base.h          |   6 +-
 .../functions/kernels/dm_kern_linear.h        |   6 +-
 .../tadah/mlip/memory/DesignMatrixWorkspace.h |   9 +-
 .../tadah/mlip/memory/IMLIPWorkspaceManager.h |   3 +-
 .../tadah/mlip/memory/MLIPWorkspaceManager.h  |   2 +-
 include/tadah/mlip/models/m_blr.h             |  23 ++-
 include/tadah/mlip/models/m_krr.h             |  23 ++-
 src/DesignMatrixWorkspace.cpp                 |   7 +-
 src/MLIPWorkspaceManager.cpp                  |   4 +-
 src/dm_bf_linear.cpp                          |  12 +-
 src/dm_bf_polynomial2.cpp                     |   9 +-
 src/dm_kern_base.cpp                          |  13 +-
 src/dm_kern_linear.cpp                        |  12 +-
 18 files changed, 210 insertions(+), 137 deletions(-)

diff --git a/include/tadah/mlip/design_matrix/design_matrix.h b/include/tadah/mlip/design_matrix/design_matrix.h
index c9c02cb..d78cc11 100644
--- a/include/tadah/mlip/design_matrix/design_matrix.h
+++ b/include/tadah/mlip/design_matrix/design_matrix.h
@@ -72,7 +72,7 @@ public:
 
   F f;
   tadah::mlip::memory::DesignMatrixWorkspace *ws;
-  bool scale=true;    // Control escale,fscale,sscale
+  bool w_copy_=false;  // keep copy of Phi and T witoout weights
 
   double e_std_dev=1;
   double f_std_dev=1;
@@ -90,40 +90,40 @@ public:
      *
      * \endcode
      */
-    DesignMatrix(F &f, Config &c, tadah::mlip::memory::IMLIPWorkspaceManager& workspaceManager)
-        : f(f),
-          workspaceManager_(&workspaceManager),
-          ownWorkspaceManager(false),
-          config(c),
-          force(config.get<bool>("FORCE")),
-          stress(config.get<bool>("STRESS")),
-          verbose(config.get<int>("VERBOSE")),
-          eweightglob(config.get<double>("EWEIGHT")),
-          fweightglob(config.get<double>("FWEIGHT")),
-          sweightglob(config.get<double>("SWEIGHT"))
-    {
-        if (config.exist("ESTDEV"))
-            e_std_dev = config.get<double>("ESTDEV");
-        if (config.exist("FSTDEV"))
-            f_std_dev = config.get<double>("FSTDEV");
-        if (config.exist("SSTDEV"))
-            config.get<double[6]>("SSTDEV", s_std_dev);
-    }
+  DesignMatrix(F &f, Config &c, tadah::mlip::memory::IMLIPWorkspaceManager& workspaceManager)
+    : f(f),
+    workspaceManager_(&workspaceManager),
+    ownWorkspaceManager(false),
+    config(c),
+    force(config.get<bool>("FORCE")),
+    stress(config.get<bool>("STRESS")),
+    verbose(config.get<int>("VERBOSE")),
+    eweightglob(config.get<double>("EWEIGHT")),
+    fweightglob(config.get<double>("FWEIGHT")),
+    sweightglob(config.get<double>("SWEIGHT"))
+  {
+    if (config.exist("ESTDEV"))
+      e_std_dev = config.get<double>("ESTDEV");
+    if (config.exist("FSTDEV"))
+      f_std_dev = config.get<double>("FSTDEV");
+    if (config.exist("SSTDEV"))
+      config.get<double[6]>("SSTDEV", s_std_dev);
+  }
 
-    // Constructor without workspaceManager_ parameter (delegating constructor)
-    DesignMatrix(F &f, Config &c)
-        : DesignMatrix(f, c, *new tadah::mlip::memory::MLIPWorkspaceManager())
-    {
-        ownWorkspaceManager = true;  // Set ownership flag
-    }
+  // Constructor without workspaceManager_ parameter (delegating constructor)
+  DesignMatrix(F &f, Config &c)
+  : DesignMatrix(f, c, *new tadah::mlip::memory::MLIPWorkspaceManager())
+  {
+    ownWorkspaceManager = true;  // Set ownership flag
+  }
 
-    // Destructor
-    ~DesignMatrix() {
-        if (ownWorkspaceManager && workspaceManager_ != nullptr) {
-            delete workspaceManager_;
-            workspaceManager_ = nullptr;
-        }
+  // Destructor
+  ~DesignMatrix() {
+    if (ownWorkspaceManager && workspaceManager_ != nullptr) {
+      delete workspaceManager_;
+      workspaceManager_ = nullptr;
     }
+  }
 
   /** \brief Build design matrix from already calculated StDescriptorsDB
      *
@@ -131,14 +131,16 @@ public:
      * The vector of targets  **T** is build from StructureDB.
      * D2, D3 and DM calculators are not used.
      */
-  void build(StDescriptorsDB &st_desc_db, const StructureDB &stdb) {
+  void build(StDescriptorsDB &st_desc_db, const StructureDB &stdb, bool wcopy) {
+    w_copy_ = wcopy;
     calc_mn(stdb);
-    ws = workspaceManager_->getDesignMatrixWorkspace(rows,cols);
+    ws = workspaceManager_->getDesignMatrixWorkspace(rows,cols, wcopy);
     ws->Phi.set_zero();
     ws->T.set_zero();
     ws->Tlabels.set_zero();
     compute_stdevs(stdb);
     fill_T(stdb);
+    compute_wfactors(stdb); // call after fill_T
     std::vector<size_t> rows(stdb.size());
     size_t row=0;
     for (size_t s=0; s<stdb.size(); ++s) {
@@ -165,15 +167,17 @@ public:
   /** \brief Calculate descriptors and build design matrix. */
   template <typename DC>
   void build(const StructureDB &stdb, Normaliser &norm,
-             DC &dc) {
+             DC &dc, bool wcopy) {
     //DescriptorsCalc<D2,D3,DM,C2,C3,CM> dc(config);
+    w_copy_ = wcopy;
     calc_mn(stdb);
-    ws = workspaceManager_->getDesignMatrixWorkspace(rows,cols);
+    ws = workspaceManager_->getDesignMatrixWorkspace(rows,cols, wcopy);
     ws->Phi.set_zero();
     ws->T.set_zero();
     ws->Tlabels.set_zero();
     compute_stdevs(stdb);
     fill_T(stdb);
+    compute_wfactors(stdb); // call after fill_T
     // for opm we need to find first rows for each structure
 
 
@@ -197,71 +201,77 @@ public:
     for (size_t s=0; s<stdb.size(); ++s) {
       StDescriptors st_d = dc.calc(stdb(s));
 
-      if(config.get<bool>("NORM"))
+      if(config.get<bool>("NORM"))  // TODO move to scale()
         norm.normalise(st_d);
 
       build(rows[s], stdb(s), st_d);
     }
+    if (w_copy_) {
+      copyAndScale();
+    }
+    else {
+      scale();
+    }
 
   }
   void build(size_t &row, const Structure &st, const StDescriptors &st_d) {
 
-    double escale = 1;
-    if (scale) escale = st.eweight*eweightglob/st.natoms();
-    f.calc_phi_energy_row(ws->Phi,row,escale,st,st_d);
+    // double escale = 1;
+    // if (scale) escale = st.eweight*eweightglob/st.natoms(); // TODO
+    f.calc_phi_energy_row(ws->Phi,row,st,st_d);
     if (force) {
-      double fscale = 1;
-      if (scale) fscale = st.fweight*fweightglob/st.natoms()/3.0;///f_std_dev;
-      f.calc_phi_force_rows(ws->Phi,row,fscale,st,st_d);
+      // double fscale = 1;
+      // if (scale) fscale = st.fweight*fweightglob/st.natoms()/3.0; // TODO
+      f.calc_phi_force_rows(ws->Phi,row,st,st_d);
     }
     if (stress) {
-      double sscale_arr[6] {1,1,1,1,1,1};
-      if (scale)
-        for(size_t xy=0;xy<6;++xy)
-          sscale_arr[xy] = st.sweight*sweightglob/6.0;///s_std_dev[xy];
-      f.calc_phi_stress_rows(ws->Phi,row,sscale_arr,st,st_d);
+      // double sscale_arr[6] {1,1,1,1,1,1};
+      // if (scale)
+      // for(size_t xy=0;xy<6;++xy)
+      // sscale_arr[xy] = st.sweight*sweightglob/6.0; // TODO
+      f.calc_phi_stress_rows(ws->Phi,row,st,st_d);
     }
   }
   void fill_T(const StructureDB &stdb, size_t start=0) {
     size_t j=start;
     for (size_t s=0; s<stdb.size(); ++s) {
 
-      double escale = 1;
-      if (scale) escale = stdb(s).eweight*eweightglob/stdb(s).natoms();///e_std_dev;
+      // double escale = 1;
       ws->Tlabels(j) = 0;
-      ws->T(j++) = stdb(s).energy*escale;
+      ws->T(j++) = stdb(s).energy;
 
       if (force) {
-        double fscale = 1;
-        if (scale) fscale = stdb(s).fweight*fweightglob/stdb(s).natoms()/3.0;///e_std_dev/f_std_dev;
+        // double fscale = 1;
+        // if (scale) fscale = stdb(s).fweight*fweightglob/stdb(s).natoms()/3.0;
         for (const Atom &a : stdb(s).atoms) {
           ws->Tlabels(j) = 1;
-          ws->T(j++) = a.force(0)*fscale;
+          ws->T(j++) = a.force(0);
           ws->Tlabels(j) = 1;
-          ws->T(j++) = a.force(1)*fscale;
+          ws->T(j++) = a.force(1);
           ws->Tlabels(j) = 1;
-          ws->T(j++) = a.force(2)*fscale;
+          ws->T(j++) = a.force(2);
         }
       }
       if (stress) {
-        double sscale_arr[6] {1,1,1,1,1,1};
-        if (scale)
-          for(size_t xy=0;xy<6;++xy)
-            sscale_arr[xy] = stdb(s).sweight*sweightglob/6.0;///e_std_dev/s_std_dev[xy];
-        size_t xy=0;
+        // double sscale_arr[6] {1,1,1,1,1,1};
+        // if (scale)
+        //   for(size_t xy=0;xy<6;++xy)
+        //     sscale_arr[xy] = stdb(s).sweight*sweightglob/6.0;
+        // size_t xy=0;
         for (size_t x=0; x<3; ++x) {
           for (size_t y=x; y<3; ++y) {
             ws->Tlabels(j) = 2;
-            ws->T(j++)=stdb(s).stress(x,y)*sscale_arr[xy++];
+            ws->T(j++)=stdb(s).stress(x,y);
           }
         }
       }
     }
   }
 
-phi_type &getPhi() { return ws->Phi; }
-t_type &getT() { return ws->T; }
-t_type &getTlabels() { return ws->Tlabels; }
+  phi_type &getPhi() { return ws->Phi; }
+  t_type &getT() { return ws->T; }
+  t_type &getTlabels() { return ws->Tlabels; }
+  bool hasCopy() { return w_copy_; }
 
 private:
   tadah::mlip::memory::IMLIPWorkspaceManager* workspaceManager_ = nullptr;
@@ -348,5 +358,59 @@ private:
     for (size_t j=0; j<6; ++j)
       config.add("SSTDEV",s_std_dev[j]);
   }
+
+  void compute_wfactors(const StructureDB &stdb) {
+    size_t j=0;
+    for (size_t s=0; s<stdb.size(); ++s) {
+      double escale = stdb(s).eweight*eweightglob;
+      ws->wfactors(j++) = escale;
+      if (force) {
+        double fscale = stdb(s).fweight*fweightglob;
+        for (size_t a=0; a<stdb(s).natoms(); ++a) {
+          ws->wfactors(j++) = fscale;
+          ws->wfactors(j++) = fscale;
+          ws->wfactors(j++) = fscale;
+        }
+      }
+      if (stress) {
+        double sscale_arr[6] {1,1,1,1,1,1};
+        for(size_t xy=0;xy<6;++xy)
+          sscale_arr[xy] = stdb(s).sweight*sweightglob/6.0;
+        int xy=0;
+        for (size_t x=0; x<3; ++x) {
+          for (size_t y=x; y<3; ++y) {
+            ws->wfactors(j++)=sscale_arr[xy++];
+          }
+        }
+      }
+    }
+  }
+
+  void copyAndScale()
+  {
+    // first copy scalled Phi to Phi_cpy and T to T_cpy
+    // followed by swapping pointers
+    for (std::size_t j = 0; j < ws->Phi.cols(); ++j) {
+      for (std::size_t i = 0; i < ws->Phi.rows(); ++i) {
+        ws->Phi_cpy(i,j) = ws->Phi(i,j)*ws->wfactors(i);
+      }
+    }
+    for (std::size_t i = 0; i < ws->Phi.rows(); ++i) {
+      ws->T_cpy(i) = ws->T(i)*ws->wfactors(i);
+    }
+    std::swap(ws->Phi, ws->Phi_cpy);
+    std::swap(ws->T, ws->T_cpy);
+  }
+  void scale()
+  {
+    for (std::size_t j = 0; j < ws->Phi.cols(); ++j) {
+      for (std::size_t i = 0; i < ws->Phi.rows(); ++i) {
+        ws->Phi(i,j) *= ws->wfactors(i);
+      }
+    }
+    for (std::size_t i = 0; i < ws->Phi.rows(); ++i) {
+      ws->T(i) *= ws->wfactors(i);
+    }
+  }
 };
 #endif
diff --git a/include/tadah/mlip/design_matrix/functions/basis_functions/dm_bf_base.h b/include/tadah/mlip/design_matrix/functions/basis_functions/dm_bf_base.h
index 6f6187d..7385b9c 100644
--- a/include/tadah/mlip/design_matrix/functions/basis_functions/dm_bf_base.h
+++ b/include/tadah/mlip/design_matrix/functions/basis_functions/dm_bf_base.h
@@ -16,10 +16,10 @@ struct DM_BF_Base: public DM_Function_Base, public virtual BF_Base {
     virtual ~DM_BF_Base();
     // virtual size_t get_phi_cols(const Config &config)=0;
     // virtual void calc_phi_energy_row(phi_type &Phi, size_t &row,
-    //         const double fac, const Structure &st, const StDescriptors &st_d)=0;
+    //         const Structure &st, const StDescriptors &st_d)=0;
     // virtual void calc_phi_force_rows(phi_type &Phi, size_t &row,
-    //         const double fac, const Structure &st, const StDescriptors &st_d)=0;
+    //         const Structure &st, const StDescriptors &st_d)=0;
     // virtual void calc_phi_stress_rows(phi_type &Phi, size_t &row,
-    //         const double fac[6], const Structure &st, const StDescriptors &st_d)=0;
+    //         const Structure &st, const StDescriptors &st_d)=0;
 };
 #endif
diff --git a/include/tadah/mlip/design_matrix/functions/basis_functions/dm_bf_linear.h b/include/tadah/mlip/design_matrix/functions/basis_functions/dm_bf_linear.h
index e0052c3..43c4834 100644
--- a/include/tadah/mlip/design_matrix/functions/basis_functions/dm_bf_linear.h
+++ b/include/tadah/mlip/design_matrix/functions/basis_functions/dm_bf_linear.h
@@ -10,10 +10,10 @@ struct DM_BF_Linear: public DM_BF_Base, public BF_Linear
     DM_BF_Linear(const Config &c);
     size_t get_phi_cols(const Config &config) override;
     void calc_phi_energy_row(phi_type &Phi, size_t &row,
-            const double fac, const Structure &st, const StDescriptors &st_d) override;
+            const Structure &st, const StDescriptors &st_d) override;
     void calc_phi_force_rows(phi_type &Phi, size_t &row,
-            const double fac, const Structure &st, const StDescriptors &st_d) override;
+            const Structure &st, const StDescriptors &st_d) override;
     void calc_phi_stress_rows(phi_type &Phi, size_t &row,
-            const double fac[6], const Structure &st, const StDescriptors &st_d) override;
+            const Structure &st, const StDescriptors &st_d) override;
 };
 #endif
diff --git a/include/tadah/mlip/design_matrix/functions/basis_functions/dm_bf_polynomial2.h b/include/tadah/mlip/design_matrix/functions/basis_functions/dm_bf_polynomial2.h
index 550f403..19bca29 100644
--- a/include/tadah/mlip/design_matrix/functions/basis_functions/dm_bf_polynomial2.h
+++ b/include/tadah/mlip/design_matrix/functions/basis_functions/dm_bf_polynomial2.h
@@ -10,10 +10,10 @@ struct DM_BF_Polynomial2: public DM_BF_Base, public BF_Polynomial2
     DM_BF_Polynomial2(const Config &c);
     size_t get_phi_cols(const Config &config) override;
     void  calc_phi_energy_row(phi_type &Phi, size_t &row,
-            const double fac, const Structure &st, const StDescriptors &st_d) override;
+            const Structure &st, const StDescriptors &st_d) override;
     void  calc_phi_force_rows(phi_type &Phi, size_t &row,
-            const double fac, const Structure &st, const StDescriptors &st_d) override;
+            const Structure &st, const StDescriptors &st_d) override;
     void  calc_phi_stress_rows(phi_type &Phi, size_t &row,
-            const double fac[6], const Structure &st, const StDescriptors &st_d) override;
+            const Structure &st, const StDescriptors &st_d) override;
 };
 #endif
diff --git a/include/tadah/mlip/design_matrix/functions/dm_function_base.h b/include/tadah/mlip/design_matrix/functions/dm_function_base.h
index 06f5321..fdbdee2 100644
--- a/include/tadah/mlip/design_matrix/functions/dm_function_base.h
+++ b/include/tadah/mlip/design_matrix/functions/dm_function_base.h
@@ -23,11 +23,11 @@ struct DM_Function_Base: public virtual Function_Base {
 
   virtual size_t get_phi_cols(const Config &)=0;
   virtual void calc_phi_energy_row(phi_type &, size_t &,
-                                   const double , const Structure &, const StDescriptors &)=0;
+                                   const Structure &, const StDescriptors &)=0;
   virtual void calc_phi_force_rows(phi_type &, size_t &,
-                                   const double , const Structure &, const StDescriptors &)=0;
+                                   const Structure &, const StDescriptors &)=0;
   virtual void calc_phi_stress_rows(phi_type &, size_t &,
-                                    const double[6], const Structure &, const StDescriptors &)=0;
+                                   const Structure &, const StDescriptors &)=0;
 };
 //template<> inline CONFIG::Registry<DM_Function_Base>::Map CONFIG::Registry<DM_Function_Base>::registry{};
 //template<> inline CONFIG::Registry<DM_Function_Base,Config&>::Map CONFIG::Registry<DM_Function_Base,Config&>::registry{};
diff --git a/include/tadah/mlip/design_matrix/functions/kernels/dm_kern_base.h b/include/tadah/mlip/design_matrix/functions/kernels/dm_kern_base.h
index 8a12489..5280252 100644
--- a/include/tadah/mlip/design_matrix/functions/kernels/dm_kern_base.h
+++ b/include/tadah/mlip/design_matrix/functions/kernels/dm_kern_base.h
@@ -24,9 +24,9 @@ class DM_Kern_Base: public DM_Function_Base, public virtual Kern_Base  {
         DM_Kern_Base(const Config&c);
         virtual ~DM_Kern_Base();
         virtual size_t get_phi_cols(const Config &config) override;
-        virtual void  calc_phi_energy_row(phi_type &Phi, size_t &row, const double fac, const Structure &st, const StDescriptors &st_d) override;
-        virtual void  calc_phi_force_rows(phi_type &Phi, size_t &row, const double fac, const Structure &st, const StDescriptors &st_d) override;
-        virtual void  calc_phi_stress_rows(phi_type &Phi, size_t &row, const double fac[6], const Structure &st, const StDescriptors &st_d) override;
+        virtual void  calc_phi_energy_row(phi_type &Phi, size_t &row, const Structure &st, const StDescriptors &st_d) override;
+        virtual void  calc_phi_force_rows(phi_type &Phi, size_t &row, const Structure &st, const StDescriptors &st_d) override;
+        virtual void  calc_phi_stress_rows(phi_type &Phi, size_t &row, const Structure &st, const StDescriptors &st_d) override;
 
 };
 #endif
diff --git a/include/tadah/mlip/design_matrix/functions/kernels/dm_kern_linear.h b/include/tadah/mlip/design_matrix/functions/kernels/dm_kern_linear.h
index cf026d0..3b35036 100644
--- a/include/tadah/mlip/design_matrix/functions/kernels/dm_kern_linear.h
+++ b/include/tadah/mlip/design_matrix/functions/kernels/dm_kern_linear.h
@@ -21,10 +21,10 @@ class DM_Kern_Linear :  public DM_Kern_Base, public Kern_Linear {
 
     size_t get_phi_cols(const Config &config) override;
     void calc_phi_energy_row(phi_type &Phi, size_t &row,
-            const double fac, const Structure &st, const StDescriptors &st_d) override;
+            const Structure &st, const StDescriptors &st_d) override;
     void calc_phi_force_rows(phi_type &Phi, size_t &row,
-            const double fac, const Structure &st, const StDescriptors &st_d) override;
+            const Structure &st, const StDescriptors &st_d) override;
     void calc_phi_stress_rows(phi_type &Phi, size_t &row,
-            const double fac[6], const Structure &st, const StDescriptors &st_d) override;
+            const Structure &st, const StDescriptors &st_d) override;
 };
 #endif
diff --git a/include/tadah/mlip/memory/DesignMatrixWorkspace.h b/include/tadah/mlip/memory/DesignMatrixWorkspace.h
index d9c9150..03b996d 100644
--- a/include/tadah/mlip/memory/DesignMatrixWorkspace.h
+++ b/include/tadah/mlip/memory/DesignMatrixWorkspace.h
@@ -20,7 +20,7 @@ public:
   /**
    * @brief Constructor.
    */
-  DesignMatrixWorkspace();
+  DesignMatrixWorkspace(bool wcopy);
 
   /**
    * @brief Destructor.
@@ -47,13 +47,14 @@ public:
    */
   bool isSufficient(size_t m_, size_t n_) const;
 
-  phi_type Phi;
-  t_type T;
+  phi_type Phi, Phi_cpy; // Design matrix and space for its copy
+  t_type T, T_cpy; // Target vector and space for its copy
   t_type Tlabels; // 0-Energy, 1-Force, 2-Stress  // TODO should be array of int not double
-
+  t_type wfactors; // Weighting factors for Phi
 
 private:
   // Disable copying
+  bool wcopy; // Keep copy of Phi and T for weighted design matrix and target vector
   DesignMatrixWorkspace(const DesignMatrixWorkspace &) = delete;
   DesignMatrixWorkspace &operator=(const DesignMatrixWorkspace &) = delete;
 };
diff --git a/include/tadah/mlip/memory/IMLIPWorkspaceManager.h b/include/tadah/mlip/memory/IMLIPWorkspaceManager.h
index e4cd63b..2cba199 100644
--- a/include/tadah/mlip/memory/IMLIPWorkspaceManager.h
+++ b/include/tadah/mlip/memory/IMLIPWorkspaceManager.h
@@ -29,9 +29,10 @@ public:
      *
      * @param m Number of rows in the matrix.
      * @param n Number of columns in the matrix.
+     * @param wcopy Keep copy of Phi and T for weighted design matrix and target vector.
      * @return Pointer to an allocated DesignMatrixWorkspace.
      */
-    virtual DesignMatrixWorkspace* getDesignMatrixWorkspace(size_t m, size_t n) = 0;
+    virtual DesignMatrixWorkspace* getDesignMatrixWorkspace(size_t m, size_t n, bool wcopy) = 0;
 
     /**
      * @brief Release a DesignMatrix workspace.
diff --git a/include/tadah/mlip/memory/MLIPWorkspaceManager.h b/include/tadah/mlip/memory/MLIPWorkspaceManager.h
index 3b26fd3..022f02b 100644
--- a/include/tadah/mlip/memory/MLIPWorkspaceManager.h
+++ b/include/tadah/mlip/memory/MLIPWorkspaceManager.h
@@ -36,7 +36,7 @@ class MLIPWorkspaceManager : public tadah::models::memory::ModelsWorkspaceManage
    * @param n Number of columns in the matrix.
    * @return Pointer to an allocated MLIPWorkspaceManager.
    */
-  DesignMatrixWorkspace *getDesignMatrixWorkspace(size_t m, size_t n) override;
+  DesignMatrixWorkspace *getDesignMatrixWorkspace(size_t m, size_t n, bool wcopy) override;
 
   /**
    * @brief Release a DesignMatrix workspace.
diff --git a/include/tadah/mlip/models/m_blr.h b/include/tadah/mlip/models/m_blr.h
index 7ba9d85..ac2996d 100644
--- a/include/tadah/mlip/models/m_blr.h
+++ b/include/tadah/mlip/models/m_blr.h
@@ -101,7 +101,8 @@ public:
     if(config.template get<bool>("NORM"))
       norm = Normaliser(config,st_desc_db);
 
-    desmat.build(st_desc_db,stdb);
+    bool wcopy=false;
+    desmat.build(st_desc_db,stdb, wcopy);
     train(desmat);
   }
 
@@ -131,7 +132,8 @@ public:
       config.add("STRESS", stress);
     }
 
-    desmat.build(stdb,norm,dc);
+    bool wcopy=false;
+    desmat.build(stdb,norm,dc,wcopy);
     train(desmat);
   }
 
@@ -179,8 +181,9 @@ public:
 
     LinearRegressor::read_sigma(config_pred,Sigma);
     DesignMatrix<BF> dm(bf,config_pred);
-    dm.scale=false; // do not scale energy, forces and stresses
-    dm.build(stdb,norm,dc);
+    // dm.scale=false; // do not scale energy, forces and stresses
+    bool wcopy=false;
+    dm.build(stdb,norm,dc,wcopy);
 
     predicted_error = T_MDMT_diag(dm.getPhi(), Sigma);
     double pmean = sqrt(predicted_error.mean());
@@ -227,9 +230,9 @@ Hint: check different predict() methods.");
     // compute energy, forces and stresses
     aed_type Tpred = T_dgemv(Phi, weights);
 
-    double eweightglob=config.template get<double>("EWEIGHT");
-    double fweightglob=config.template get<double>("FWEIGHT");
-    double sweightglob=config.template get<double>("SWEIGHT");
+    // double eweightglob=config.template get<double>("EWEIGHT");
+    // double fweightglob=config.template get<double>("FWEIGHT");
+    // double sweightglob=config.template get<double>("SWEIGHT");
 
     // Construct StructureDB object with predicted values
     StructureDB stdb_;
@@ -238,19 +241,19 @@ Hint: check different predict() methods.");
     size_t i=0;
     while (i<Phi.rows()) {
 
-      stdb_(s).energy = Tpred(i++)*stdb(s).natoms()/eweightglob/stdb(s).eweight;
+      stdb_(s).energy = Tpred(i++)*stdb(s).natoms();
       if (config.template get<bool>("FORCE")) {
         stdb_(s).atoms.resize(stdb(s).natoms());
         for (size_t a=0; a<stdb(s).natoms(); ++a) {
           for (size_t k=0; k<3; ++k) {
-            stdb_(s).atoms[a].force[k] = Tpred(i++)/fweightglob/stdb(s).fweight;
+            stdb_(s).atoms[a].force[k] = Tpred(i++);
           }
         }
       }
       if (config.template get<bool>("STRESS")) {
         for (size_t x=0; x<3; ++x) {
           for (size_t y=x; y<3; ++y) {
-            stdb_(s).stress(x,y) = Tpred(i++)/sweightglob/stdb(s).sweight;
+            stdb_(s).stress(x,y) = Tpred(i++);
             if (x!=y)
               stdb_(s).stress(y,x) = stdb_(s).stress(x,y);
           }
diff --git a/include/tadah/mlip/models/m_krr.h b/include/tadah/mlip/models/m_krr.h
index 662e200..4424854 100644
--- a/include/tadah/mlip/models/m_krr.h
+++ b/include/tadah/mlip/models/m_krr.h
@@ -104,7 +104,8 @@ public:
     if(config.template get<bool>("NORM"))
       norm = Normaliser(config,st_desc_db);
 
-    desmat.build(st_desc_db,stdb);
+    bool wcopy=false;
+    desmat.build(st_desc_db,stdb,wcopy);
     train(desmat);
   }
 
@@ -161,7 +162,8 @@ public:
         ekm.configure(basis.b);
       }
     }
-    desmat.build(stdb,norm,dc);
+    bool wcopy=false;
+    desmat.build(stdb,norm,dc,wcopy);
     train(desmat);
   }
 
@@ -249,8 +251,9 @@ public:
 
     LinearRegressor::read_sigma(config_pred,Sigma);
     DesignMatrix<K> dm(kernel,config_pred);
-    dm.scale=false; // do not scale energy, forces and stresses
-    dm.build(stdb,norm,dc);
+    // dm.scale=false; // do not scale energy, forces and stresses
+    bool wcopy=false;
+    dm.build(stdb,norm,dc,wcopy);
 
     // compute error
     predicted_error = T_MDMT_diag(dm.getPhi(), Sigma);
@@ -298,9 +301,9 @@ Hint: check different predict() methods.");
     // compute energy, forces and stresses
     aed_type Tpred = T_dgemv(Phi, weights);
 
-    double eweightglob=config.template get<double>("EWEIGHT");
-    double fweightglob=config.template get<double>("FWEIGHT");
-    double sweightglob=config.template get<double>("SWEIGHT");
+    // double eweightglob=config.template get<double>("EWEIGHT");
+    // double fweightglob=config.template get<double>("FWEIGHT");
+    // double sweightglob=config.template get<double>("SWEIGHT");
 
     // Construct StructureDB object with predicted values
     StructureDB stdb_;
@@ -309,19 +312,19 @@ Hint: check different predict() methods.");
     size_t i=0;
     while (i<Phi.rows()) {
 
-      stdb_(s).energy = Tpred(i++)*stdb(s).natoms()/eweightglob/stdb(s).eweight;
+      stdb_(s).energy = Tpred(i++)*stdb(s).natoms();
       if (config.template get<bool>("FORCE")) {
         stdb_(s).atoms.resize(stdb(s).natoms());
         for (size_t a=0; a<stdb(s).natoms(); ++a) {
           for (size_t k=0; k<3; ++k) {
-            stdb_(s).atoms[a].force[k] = Tpred(i++)/fweightglob/stdb(s).fweight;
+            stdb_(s).atoms[a].force[k] = Tpred(i++);
           }
         }
       }
       if (config.template get<bool>("STRESS")) {
         for (size_t x=0; x<3; ++x) {
           for (size_t y=x; y<3; ++y) {
-            stdb_(s).stress(x,y) = Tpred(i++)/sweightglob/stdb(s).sweight;
+            stdb_(s).stress(x,y) = Tpred(i++);
             if (x!=y)
               stdb_(s).stress(y,x) = stdb_(s).stress(x,y);
           }
diff --git a/src/DesignMatrixWorkspace.cpp b/src/DesignMatrixWorkspace.cpp
index 8390cd6..51053e9 100644
--- a/src/DesignMatrixWorkspace.cpp
+++ b/src/DesignMatrixWorkspace.cpp
@@ -8,13 +8,18 @@ namespace tadah {
 namespace mlip {
 namespace memory {
 
-DesignMatrixWorkspace::DesignMatrixWorkspace() {}
+DesignMatrixWorkspace::DesignMatrixWorkspace(bool wcopy): wcopy(wcopy) {}
 DesignMatrixWorkspace::~DesignMatrixWorkspace() {}
 
 void DesignMatrixWorkspace::allocate(size_t m_, size_t n_) {
   Phi.resize(m_,n_);
   T.resize(m_);
   Tlabels.resize(m_);
+  wfactors.resize(m_);
+  if (wcopy) {
+    Phi_cpy.resize(m_,n_);
+    T_cpy.resize(m_);
+  }
 }
 
 bool DesignMatrixWorkspace::isSufficient(size_t m_, size_t n_) const {
diff --git a/src/MLIPWorkspaceManager.cpp b/src/MLIPWorkspaceManager.cpp
index 019be87..a077f45 100644
--- a/src/MLIPWorkspaceManager.cpp
+++ b/src/MLIPWorkspaceManager.cpp
@@ -12,10 +12,10 @@ MLIPWorkspaceManager::~MLIPWorkspaceManager() {
     releaseDesignMatrixWorkspace(designMatrixWorkspace_);
 }
 
-DesignMatrixWorkspace* MLIPWorkspaceManager::getDesignMatrixWorkspace(size_t m, size_t n) {
+DesignMatrixWorkspace* MLIPWorkspaceManager::getDesignMatrixWorkspace(size_t m, size_t n, bool wcopy) {
     if (!designMatrixWorkspace_ || !designMatrixWorkspace_->isSufficient(m, n)) {
         releaseDesignMatrixWorkspace(designMatrixWorkspace_);
-        designMatrixWorkspace_ = new DesignMatrixWorkspace();
+        designMatrixWorkspace_ = new DesignMatrixWorkspace(wcopy);
         designMatrixWorkspace_->allocate(m, n);
     }
     return designMatrixWorkspace_;
diff --git a/src/dm_bf_linear.cpp b/src/dm_bf_linear.cpp
index 4f5dac2..a19ad70 100644
--- a/src/dm_bf_linear.cpp
+++ b/src/dm_bf_linear.cpp
@@ -15,18 +15,18 @@ size_t DM_BF_Linear::get_phi_cols(const Config &config)
   return cols;
 }
 void DM_BF_Linear::calc_phi_energy_row(phi_type &Phi, size_t &row,
-                                       const double fac, const Structure &, const StDescriptors &st_d)
+                                       const Structure &, const StDescriptors &st_d)
 {
   for (size_t i=0; i<st_d.naed(); ++i) {
     const aed_type &aed = st_d.get_aed(i);
     for (size_t j=0; j<aed.size(); ++j) {
-      Phi(row,j)+=aed[j]*fac;
+      Phi(row,j)+=aed[j];
     }
   }
   row++;
 }
 void DM_BF_Linear::calc_phi_force_rows(phi_type &Phi, size_t &row,
-                                       const double fac, const Structure &st, const StDescriptors &st_d)
+                                       const Structure &st, const StDescriptors &st_d)
 {
 
   for (size_t a=0; a<st.natoms(); ++a) {
@@ -37,7 +37,7 @@ void DM_BF_Linear::calc_phi_force_rows(phi_type &Phi, size_t &row,
       const fd_type &fji = st_d.fd[j][aa];
       for (size_t k=0; k<3; ++k) {
         for (size_t d=0; d<fij.rows(); ++d) {
-          Phi(row+k,d) -= fac*(fij(d,k)-fji(d,k));
+          Phi(row+k,d) -= (fij(d,k)-fji(d,k));
         }
       }
     }
@@ -45,7 +45,7 @@ void DM_BF_Linear::calc_phi_force_rows(phi_type &Phi, size_t &row,
   }
 }
 void DM_BF_Linear::calc_phi_stress_rows(phi_type &Phi, size_t &row,
-                                        const double fac[6], const Structure &st, const StDescriptors &st_d)
+                                        const Structure &st, const StDescriptors &st_d)
 {
   double V_inv = 1/st.get_volume();
   for (size_t i=0; i<st.natoms(); ++i) {
@@ -60,7 +60,7 @@ void DM_BF_Linear::calc_phi_stress_rows(phi_type &Phi, size_t &row,
       for (size_t x=0; x<3; ++x) {
         for (size_t y=x; y<3; ++y) {
           for (size_t d=0; d<fdij.rows(); ++d) {
-            Phi(row+mn,d) += V_inv*(fdij(d,y)-fdji(d,y))*0.5*fac[mn]*(ri(x)-rj(x));
+            Phi(row+mn,d) += V_inv*(fdij(d,y)-fdji(d,y))*0.5*(ri(x)-rj(x));
           }
           mn++;
         }
diff --git a/src/dm_bf_polynomial2.cpp b/src/dm_bf_polynomial2.cpp
index c7fb188..f63f0ce 100644
--- a/src/dm_bf_polynomial2.cpp
+++ b/src/dm_bf_polynomial2.cpp
@@ -18,7 +18,6 @@ size_t DM_BF_Polynomial2::get_phi_cols(const Config &config)
 }
 void DM_BF_Polynomial2::calc_phi_energy_row(phi_type &Phi,
                                             size_t &row,
-                                            const double fac,
                                             const Structure &,
                                             const StDescriptors &st_d)
 {
@@ -27,7 +26,7 @@ void DM_BF_Polynomial2::calc_phi_energy_row(phi_type &Phi,
     size_t b=0;
     for (size_t i=0; i<st_d.dim(); ++i) {
       for (size_t ii=i; ii<st_d.dim(); ++ii) {
-        Phi(row,b++) += aed(i)*aed(ii)*fac;
+        Phi(row,b++) += aed(i)*aed(ii);
       }
     }
   }
@@ -35,7 +34,6 @@ void DM_BF_Polynomial2::calc_phi_energy_row(phi_type &Phi,
 }
 void DM_BF_Polynomial2::calc_phi_force_rows(phi_type &Phi,
                                             size_t &row,
-                                            const double fac,
                                             const Structure &st,
                                             const StDescriptors &st_d)
 {
@@ -52,7 +50,7 @@ void DM_BF_Polynomial2::calc_phi_force_rows(phi_type &Phi,
         size_t b=0;
         for (size_t i=0; i<fdij.rows(); ++i) {
           for (size_t ii=i; ii<fdij.rows(); ++ii) {
-            Phi(row+k,b) -= fac*(fdij(i,k)*aedi(ii) + fdij(ii,k)*aedi(i)
+            Phi(row+k,b) -= (fdij(i,k)*aedi(ii) + fdij(ii,k)*aedi(i)
               - fdji(i,k)*aedj(ii) - fdji(ii,k)*aedj(i));
             b++;
           }
@@ -65,7 +63,6 @@ void DM_BF_Polynomial2::calc_phi_force_rows(phi_type &Phi,
 }
 void DM_BF_Polynomial2::calc_phi_stress_rows(phi_type &Phi,
                                              size_t &row,
-                                             const double fac[6],
                                              const Structure &st, 
                                              const StDescriptors &st_d)
 {
@@ -86,7 +83,7 @@ void DM_BF_Polynomial2::calc_phi_stress_rows(phi_type &Phi,
           size_t b=0;
           for (size_t i=0; i<fdij.rows(); ++i) {
             for (size_t ii=i; ii<fdij.rows(); ++ii) {
-              Phi(row+mn,b) += V_inv*0.5*fac[mn]*(ri(x)-rj(x))
+              Phi(row+mn,b) += V_inv*0.5*(ri(x)-rj(x))
                 *(fdij(i,y)*aedi(ii) + fdij(ii,y)*aedi(i)
                 - fdji(i,y)*aedj(ii) - fdji(ii,y)*aedj(i));
               b++;
diff --git a/src/dm_kern_base.cpp b/src/dm_kern_base.cpp
index f78c71e..7a07443 100644
--- a/src/dm_kern_base.cpp
+++ b/src/dm_kern_base.cpp
@@ -13,18 +13,17 @@ size_t DM_Kern_Base::get_phi_cols(const Config &)
 {
   return basis.cols();
 }
-void  DM_Kern_Base::calc_phi_energy_row(phi_type &Phi, size_t &row, const double fac,
+void  DM_Kern_Base::calc_phi_energy_row(phi_type &Phi, size_t &row,
                                         const Structure &, const StDescriptors &st_d)
 {
   for (size_t a=0; a<st_d.naed();++a) {
     for (size_t b=0; b<basis.cols(); ++b) {
-      Phi(row,b) += (*this)(basis.col(b),st_d.get_aed(a))*fac;
+      Phi(row,b) += (*this)(basis.col(b),st_d.get_aed(a));
     }
   }
   row++;
-  //Phi.row(row++) *= fac;
 }
-void  DM_Kern_Base::calc_phi_force_rows(phi_type &Phi, size_t &row, const double fac,
+void  DM_Kern_Base::calc_phi_force_rows(phi_type &Phi, size_t &row,
                                         const Structure &st, const StDescriptors &st_d) {
   for (size_t i=0; i<st.natoms(); ++i) {
     const aed_type& aedi = st_d.get_aed(i);
@@ -36,7 +35,7 @@ void  DM_Kern_Base::calc_phi_force_rows(phi_type &Phi, size_t &row, const double
       const aed_type& aedj = st_d.get_aed(j);
       for (size_t b=0; b<basis.cols(); ++b) {
         for (size_t k=0; k<3; ++k) {
-          Phi(row+k,b) -= fac*((*this).prime(basis.col(b), aedi,fdij(k)) -
+          Phi(row+k,b) -= ((*this).prime(basis.col(b), aedi,fdij(k)) -
             (*this).prime(basis.col(b),aedj,fdji(k)));
         }
       }
@@ -44,7 +43,7 @@ void  DM_Kern_Base::calc_phi_force_rows(phi_type &Phi, size_t &row, const double
     row+=3;
   }
 }
-void  DM_Kern_Base::calc_phi_stress_rows(phi_type &Phi, size_t &row, const double fac[6],
+void  DM_Kern_Base::calc_phi_stress_rows(phi_type &Phi, size_t &row,
                                          const Structure &st, const StDescriptors &st_d)
 {
   double V_inv = 1/st.get_volume();
@@ -62,7 +61,7 @@ void  DM_Kern_Base::calc_phi_stress_rows(phi_type &Phi, size_t &row, const doubl
       for (size_t x=0; x<3; ++x) {
         for (size_t y=x; y<3; ++y) {
           for (size_t b=0; b<basis.cols(); ++b) {
-            Phi(row+mn,b) += V_inv*0.5*fac[mn]*(ri(x)-rj(x))*
+            Phi(row+mn,b) += V_inv*0.5*(ri(x)-rj(x))*
               ((*this).prime(basis.col(b),aedi,fdij(y)) -
               (*this).prime(basis.col(b),aedj,fdji(y)));
           }
diff --git a/src/dm_kern_linear.cpp b/src/dm_kern_linear.cpp
index 84c484b..ca36e2a 100644
--- a/src/dm_kern_linear.cpp
+++ b/src/dm_kern_linear.cpp
@@ -16,18 +16,18 @@ size_t DM_Kern_Linear::get_phi_cols(const Config &config)
   return cols;
 }
 void DM_Kern_Linear::calc_phi_energy_row(phi_type &Phi, size_t &row,
-                                         const double fac, const Structure &, const StDescriptors &st_d)
+                                         const Structure &, const StDescriptors &st_d)
 {
   for (size_t a=0; a<st_d.naed();++a) {
     const aed_type &aed = st_d.get_aed(a);  // TODO
     for (size_t j=0; j<aed.size(); ++j) {
-      Phi(row,j)+=aed[j]*fac;
+      Phi(row,j)+=aed[j];
     }
   }
   row++;
 }
 void DM_Kern_Linear::calc_phi_force_rows(phi_type &Phi, size_t &row,
-                                         const double fac, const Structure &st, const StDescriptors &st_d)
+                                         const Structure &st, const StDescriptors &st_d)
 {
   for (size_t a=0; a<st.natoms(); ++a) {
     for (size_t jj=0; jj<st_d.fd[a].size(); ++jj) {
@@ -35,7 +35,7 @@ void DM_Kern_Linear::calc_phi_force_rows(phi_type &Phi, size_t &row,
       const size_t aa = st.get_nn_iindex(a,j,jj);
       for (size_t k=0; k<3; ++k) {
         aed_type temp = (st_d.fd[a][jj](k)-
-          st_d.fd[j][aa](k))*fac;
+          st_d.fd[j][aa](k));
         for (size_t d=0; d<temp.size(); ++d) {
           Phi(row+k,d) -= temp[d];
         }
@@ -46,7 +46,7 @@ void DM_Kern_Linear::calc_phi_force_rows(phi_type &Phi, size_t &row,
 
 }
 void DM_Kern_Linear::calc_phi_stress_rows(phi_type &Phi, size_t &row,
-                                          const double fac[6], const Structure &st, const StDescriptors &st_d)
+                                          const Structure &st, const StDescriptors &st_d)
 {
   double V_inv = 1/st.get_volume();
   for (size_t i=0; i<st.natoms(); ++i) {
@@ -60,7 +60,7 @@ void DM_Kern_Linear::calc_phi_stress_rows(phi_type &Phi, size_t &row,
       size_t mn=0;
       for (size_t x=0; x<3; ++x) {
         for (size_t y=x; y<3; ++y) {
-          aed_type temp = V_inv*(fdij(y)-fdji(y))*0.5*fac[mn]*(ri(x)-rj(x));
+          aed_type temp = V_inv*(fdij(y)-fdji(y))*0.5*(ri(x)-rj(x));
           for (size_t d=0; d<temp.size(); ++d) {
             Phi(row+mn,d) += temp[d];
           }
-- 
GitLab