From 20367b3280f585990a22cae3259af081822ce095 Mon Sep 17 00:00:00 2001
From: mkirsz <s1351949@sms.ed.ac.uk>
Date: Mon, 17 Mar 2025 10:30:51 +0000
Subject: [PATCH] WiP

---
 src/predict_engine.cpp | 75 ++++++++++++++++++++++++++++++++++++++++--
 1 file changed, 73 insertions(+), 2 deletions(-)

diff --git a/src/predict_engine.cpp b/src/predict_engine.cpp
index a6480f9..a067e0a 100644
--- a/src/predict_engine.cpp
+++ b/src/predict_engine.cpp
@@ -7,8 +7,79 @@ namespace tadah {
 namespace mlip {
 namespace engines {
 
-int PredictEngine::predict(const tadah::core::Config &pot_config, const tadah::mlip::StructureDB &stdb) {
-  std::cout << "PredictEngine: Performing prediction (stub implementation)." << std::endl;
+int PredictEngine::predict(const tadah::core::Config &pot_config, const StructureDB &stdb) {
+  if (rank!=0) return EXIT_SUCCESS;
+
+  tadah::core::Config pot_config(pot_file);
+
+
+  DC_Selector DCS(pot_config);
+
+
+  if (is_verbose()) std::cout << "Finding nearest neighbours within: " <<
+    pot_config.get<double>("RCUTMAX") << " cutoff distance..." << std::flush;
+  NNFinder nnf(pot_config);
+  nnf.calc(stdb);
+  if (is_verbose()) std::cout << "Done!" << std::endl;
+
+  DescriptorsCalc<> dc(pot_config,*DCS.d2b,*DCS.d3b,*DCS.dmb,
+                       *DCS.c2b,*DCS.c3b,*DCS.cmb);
+
+  if (is_verbose()) std::cout << "Prediction..." << std::flush;
+  DM_Function_Base *fb = tadah::core::factory<DM_Function_Base,Config&>(
+    pot_config.get<std::string>("MODEL",1),pot_config);
+  M_Tadah_Base *modelp = tadah::core::factory<M_Tadah_Base,DM_Function_Base&,Config&>(
+    pot_config.get<std::string>("MODEL",0),*fb,pot_config);
+
+  StructureDB stpred;
+  aed_type predicted_error;
+  if (predict->count("--error")) {
+    stpred = modelp->predict(pot_config,stdb,dc,predicted_error);
+  }
+  else {
+    stpred = modelp->predict(pot_config,stdb,dc);
+  }
+
+  if (is_verbose()) std::cout << "Done!" << std::endl;
+
+  if (is_verbose()) std::cout << "Dumping output..." << std::flush;
+
+  Output output(pot_config,predict->count("--error"));
+
+  if (predict->count("--numeric"))
+    output.set_numeric(outprec);
+
+  output.print_predict_all(stdb,stpred,predicted_error);
+
+  if (is_verbose()) std::cout << "Done!" << std::endl;
+  if (is_verbose()) std::cout << timer_tot.to_string() << std::endl;
+
+  if(predict->count("--analytics")) {
+    Analytics a(stdb,stpred);
+
+    std::cout << "Energy MAE (meV/atom): " << 1000*a.calc_e_mae() << std::endl;
+    std::cout << "Energy RMSE (meV/atom): " << 1000*a.calc_e_rmse() << std::endl;
+    std::cout << "Energy R^2: " << a.calc_e_r_sq() << std::endl;
+
+    if (predict->count("--Force")) {
+      std::cout << "Force MAE (eV/A): "<< a.calc_f_mae() << std::endl;
+      std::cout << "Force RMSE (eV/A): "<< a.calc_f_rmse() << std::endl;
+      std::cout << "Force R^2: " << a.calc_f_r_sq() << std::endl;
+    }
+
+    if (predict->count("--Stress")) {
+      std::cout << "Stress MAE (eV/A^3): "<< a.calc_s_mae() << std::endl;
+      std::cout << "Stress RMSE (eV/A^3): "<< a.calc_s_rmse() << std::endl;
+      std::cout << "Stress R^2: " << a.calc_s_r_sq() << std::endl;
+    }
+  }
+
+  if(modelp)
+    delete modelp;
+  if(fb)
+    delete fb;
+
+  return 0;
   return EXIT_SUCCESS;
 }
 
-- 
GitLab