From a4f0b5cc42becef2fd55a12ae4784e5421a32139 Mon Sep 17 00:00:00 2001
From: mkirsz <s1351949@sms.ed.ac.uk>
Date: Wed, 19 Feb 2025 00:21:55 +0000
Subject: [PATCH] Integration of memoryManager

---
 include/tadah/mlip/models/m_blr.h |  8 ++++++++
 include/tadah/mlip/models/m_krr.h |  8 ++++++++
 include/tadah/mlip/trainer.h      |  2 +-
 src/m_all.cpp                     | 17 ++++++++++++++---
 4 files changed, 31 insertions(+), 4 deletions(-)

diff --git a/include/tadah/mlip/models/m_blr.h b/include/tadah/mlip/models/m_blr.h
index 943a7fe..c8e3eb0 100644
--- a/include/tadah/mlip/models/m_blr.h
+++ b/include/tadah/mlip/models/m_blr.h
@@ -8,6 +8,7 @@
 #include <tadah/mlip/normaliser.h>
 #include <tadah/models/m_blr_train.h>
 #include <tadah/core/config.h>
+#include <tadah/models/memory/IModelsWorkspaceManager.h>
 
 #include <limits>
 #include <stdexcept>
@@ -76,6 +77,13 @@ public:
     norm = Normaliser(c);
   }
 
+  M_BLR(BF &bf, Config &c, tadah::models::memory::IModelsWorkspaceManager& workspaceManager):
+    M_BLR_Train<BF>(bf,c,workspaceManager),
+    desmat(bf,c)
+  {
+    norm = Normaliser(c);
+  }
+
   double epredict(const aed_type &aed) const{
     return bf.epredict(weights,aed);
   };
diff --git a/include/tadah/mlip/models/m_krr.h b/include/tadah/mlip/models/m_krr.h
index 92aee1d..ccd2c6c 100644
--- a/include/tadah/mlip/models/m_krr.h
+++ b/include/tadah/mlip/models/m_krr.h
@@ -9,6 +9,7 @@
 #include <tadah/models/m_krr_train.h>
 #include <tadah/core/config.h>
 #include <tadah/mlip/models/basis.h>
+#include <tadah/models/memory/IModelsWorkspaceManager.h>
 
 #include <limits>
 #include <stdexcept>
@@ -79,6 +80,13 @@ public:
     norm = Normaliser(c);
   }
 
+  M_KRR(K &kernel, Config &c, tadah::models::memory::IModelsWorkspaceManager& workspaceManager):
+    M_KRR_Train<K>(kernel,c,workspaceManager),
+    basis(c),
+    desmat(kernel,c)
+  {
+    norm = Normaliser(c);
+  }
   double epredict(const aed_type &aed) const {
     return kernel.epredict(weights,aed);
   };
diff --git a/include/tadah/mlip/trainer.h b/include/tadah/mlip/trainer.h
index e206917..a212a9d 100644
--- a/include/tadah/mlip/trainer.h
+++ b/include/tadah/mlip/trainer.h
@@ -36,7 +36,7 @@ class Trainer {
       nnf(config),
       fb(CONFIG::factory<DM_Function_Base,Config&>(
             config.get<std::string>("MODEL",1),config)),
-      model(CONFIG::factory<M_Tadah_Base,DM_Function_Base&,Config&>
+      model(CONFIG::factory<M_Tadah_Base,DM_Function_Base&,Config&,tadah::models::memory::IModelsWorkspaceManager&>
           (config.get<std::string>("MODEL",0),*fb,config,workspaceManager)),
           //(config.get<std::string>("MODEL",0),*fb,config)),
       dm(*fb, config)
diff --git a/src/m_all.cpp b/src/m_all.cpp
index 9c67972..fa99c8c 100644
--- a/src/m_all.cpp
+++ b/src/m_all.cpp
@@ -1,5 +1,16 @@
 #include <tadah/mlip/models/m_all.h>
+#include <tadah/models/memory/IModelsWorkspaceManager.h>
 
-template<> CONFIG::Registry<M_Tadah_Base,DM_Function_Base&,Config&>::Map CONFIG::Registry<M_Tadah_Base,DM_Function_Base&,Config&>::registry{};
-CONFIG::Registry<M_Tadah_Base,DM_Function_Base&,Config&>::Register<M_KRR<>> M_KRR_1("M_KRR");
-CONFIG::Registry<M_Tadah_Base,DM_Function_Base&,Config&>::Register<M_BLR<>> M_BLR_1("M_BLR");
+template<>
+CONFIG::Registry<M_Tadah_Base, DM_Function_Base&, Config&>::Map
+CONFIG::Registry<M_Tadah_Base, DM_Function_Base&, Config&>::registry{};
+
+CONFIG::Registry<M_Tadah_Base, DM_Function_Base&, Config&>::Register<M_KRR<>> M_KRR_1("M_KRR");
+CONFIG::Registry<M_Tadah_Base, DM_Function_Base&, Config&>::Register<M_BLR<>> M_BLR_1("M_BLR");
+
+template<>
+CONFIG::Registry<M_Tadah_Base, DM_Function_Base&, Config&, tadah::models::memory::IModelsWorkspaceManager&>::Map
+CONFIG::Registry<M_Tadah_Base, DM_Function_Base&, Config&, tadah::models::memory::IModelsWorkspaceManager&>::registry{};
+
+CONFIG::Registry<M_Tadah_Base, DM_Function_Base&, Config&, tadah::models::memory::IModelsWorkspaceManager&>::Register<M_BLR<>> M_BLR_2("M_BLR");
+CONFIG::Registry<M_Tadah_Base, DM_Function_Base&, Config&, tadah::models::memory::IModelsWorkspaceManager&>::Register<M_KRR<>> M_KRR_2("M_KRR");
-- 
GitLab