From 827592a8fdb9752f4b6202dcc2a4cfd0da322d28 Mon Sep 17 00:00:00 2001
From: Marcin Kirsz <mkirsz@ed.ac.uk>
Date: Mon, 10 Feb 2025 16:48:59 +0000
Subject: [PATCH] Added ErRMSE

---
 include/tadah/mlip/analytics/analytics.h |  3 +++
 src/analytics.cpp                        | 20 ++++++++++++++++++++
 2 files changed, 23 insertions(+)

diff --git a/include/tadah/mlip/analytics/analytics.h b/include/tadah/mlip/analytics/analytics.h
index 4a7dc4a..7dd1016 100644
--- a/include/tadah/mlip/analytics/analytics.h
+++ b/include/tadah/mlip/analytics/analytics.h
@@ -28,6 +28,9 @@ class Analytics {
          */
         t_type calc_s_mae() const;
 
+        /** Return Energy/atom Relative Root Mean Square Error for each DBFILE. */
+        t_type calc_e_rrmse() const;
+
         /** Return Energy/atom Root Mean Square Error for each DBFILE. */
         t_type calc_e_rmse() const;
 
diff --git a/src/analytics.cpp b/src/analytics.cpp
index 4865127..57d1dca 100644
--- a/src/analytics.cpp
+++ b/src/analytics.cpp
@@ -78,6 +78,26 @@ t_type Analytics::calc_s_mae() const {
     return smae_vec;
 }
 
+t_type Analytics::calc_e_rrmse() const {
+  t_type errmse_vec(st.dbidx.size()-1);
+  double errmse=0;
+  size_t dbidx=0;
+  size_t N=0;
+  for (size_t i=0; i<st.size(); ++i) {
+    if (st(i).energy != 0.0) {
+      errmse += std::pow((st(i).energy - stp(i).energy)/st(i).energy,2);
+      N++;
+    }
+    if (i+1==st.dbidx[dbidx+1]) {
+      errmse_vec(dbidx)=std::sqrt(errmse/N);
+      errmse=0;
+      N=0;
+      dbidx++;
+    }
+  }
+  return errmse_vec;
+}
+
 t_type Analytics::calc_e_rmse() const{
 
     t_type ermse_vec(st.dbidx.size()-1);
-- 
GitLab