From 0e417ef4f7fb6a785c90b52711bdc33dd08d80d8 Mon Sep 17 00:00:00 2001
From: Marcin Kirsz <mkirsz@ed.ac.uk>
Date: Wed, 19 Feb 2025 13:00:05 +0000
Subject: [PATCH] Dataset histogram tool

---
 bin/tadah_cli.cpp | 246 +++++++++++++++++++++++++++++++++++-----------
 bin/tadah_cli.h   |   5 +
 2 files changed, 191 insertions(+), 60 deletions(-)

diff --git a/bin/tadah_cli.cpp b/bin/tadah_cli.cpp
index 1fdd5e8..0cd6010 100644
--- a/bin/tadah_cli.cpp
+++ b/bin/tadah_cli.cpp
@@ -393,6 +393,19 @@ TadahCLI::TadahCLI():
     ->required()
     ->check(CLI::ExistingFile);
 
+  /*---------------------------------------------------------------------------*/
+  /*     DB RESCALE                                                            */
+  /*---------------------------------------------------------------------------*/
+  usage = 
+    "Add energy rescaling factors to dataset(s)\n";
+  rescale = db->add_subcommand("rescale", usage);
+
+  usage = "Specify one or more Tadah! datasets\n to add energy rescaling factors\n"
+    "The output is saved to new datasets(s) named:\n"
+    "ORIGNAL_FILE_rescaled\n";
+  rescale->add_option("-d,--datasets", datasets, usage)
+    ->required()
+    ->check(CLI::ExistingFile);
 
   /*---------------------------------------------------------------------------*/
   /*     DESCRIPTORS                                                           */
@@ -511,12 +524,11 @@ TadahCLI::TadahCLI():
   /*---------------------------------------------------------------------------*/
 
   usage = 
-    "Plot basis functions\n(and optionally cutoff values)\nused in two- and many-body expansions.\n"
-    "GNUPLOT plotting IS NOT IMPLEMENTED everything else should work\n";
+    "Plot basis functions\n(and optionally cutoff values)\nused in two- and many-body expansions.\n";
   bf = plot->add_subcommand("bf", usage);
 
   usage = 
-    "This function can generate a formatted text file for plotting with external software, or it can use gnuplot  (if available during Tadah! compilation) for visualization.\n"
+    "This function can generate a formatted text file for plotting with external software.\n"
     "Input must be either a trained potential model or a configuration file with the following keys:\n"
     "  - SGRID2B (or SGRIDMB) <--REQUIRED\n"
     "  - CGRID2B (or CGRIDMB) <--REQUIRED\n"
@@ -529,15 +541,21 @@ TadahCLI::TadahCLI():
   bf->usage(usage);
 
   usage = 
-    "Provide either a potential file\nor a config file containing\nthe necessary keys.\n";
+    "Provide either a potential file\nor a config file containing\nthe necessary keys.\n"
+    "The file must contain the following keys:\n"
+    " - INIT2B | INITMB\n"
+    " - TYPE2B | TYPEMB\n"
+    " - RCUT2B | RCUTMB\n"
+    " - RCTYPE2B | RCTYPEMB\n"
+    " - CGRID2B | CGRIDMB\n"
+    " - SGRID2B | SGRIDMB\n";
 
   bf->add_option("-c,--config", config_file, usage)
     ->check(CLI::ExistingFile)
     ->required();
 
   usage =
-    "Specify the output file for text data\nand (optionally) a plot file using gnuplot.\n"
-    "Supported plot file extensions:\n.png, .pdf, .eps, .svg.\n"
+    "Specify the output file for text data.\n"
     "Pathnames must not exist.\n";
 
   bf->add_option("-o,--outfile", out_files, usage)
@@ -547,15 +565,15 @@ TadahCLI::TadahCLI():
     ->required();
 
   bf->add_option("-t,--type", types, 
-                 "Specify the type of basis function and interaction:\n"
-                 "  - \"B 2b Y/N\": Blips for two-body interactions\n"
-                 "  - \"B mb Y/N\": Blips for many-body interactions\n"
-                 "  - \"G 2b Y/N\": Gaussians for two-body interactions\n"
-                 "  - \"G mb Y/N\": Gaussians for many-body interactions\n"
-                 "Y/N indicates whether to plot the cutoff function.\n")
+                 "Specify whether to plot two- or many-body grid\n"
+                 "Y/N indicates whether to also plot the cutoff function.\n"
+                 "  - \"2b Y/N\": Plot grid for TYPE2B\n"
+                 "  - \"mb Y/N\": Plot grid for TYPEMB\n"
+                 "\n")
     ->required();
 
-  bf->add_flag("-s,--scale", "Scale the heights of the basis\nfunctions by the cutoff value.\n");
+  bf->add_flag("-D,--Derivative", "Will plot derivatives of\nbasis functions and cutoffs.\n");
+  bf->add_flag("-s,--scale", "Rescale the heights of the basis\nfunctions by the cutoff value.\n");
 
   usage =
     "Indices of basis functions to be plotted\n"
@@ -581,8 +599,7 @@ TadahCLI::TadahCLI():
   /*---------------------------------------------------------------------------*/
 
   usage = 
-    "Plot two-body potential.\n"
-    "GNUPLOT plotting IS NOT IMPLEMENTED everything else should work\n";
+    "Plot two-body potential.\n";
 
   twobody = plot->add_subcommand("twobody", usage);
 
@@ -597,8 +614,7 @@ TadahCLI::TadahCLI():
     ->required();
 
   usage =
-    "Specify the output file for text data\nand (optionally) a plot file using gnuplot.\n"
-    "Supported plot file extensions:\n.png, .pdf, .eps, .svg.\n"
+    "Specify the output file for text data.\n"
     "Pathnames must not exist.\n";
 
   twobody->add_option("-o,--outfile", out_files, usage)
@@ -627,8 +643,7 @@ TadahCLI::TadahCLI():
   /*---------------------------------------------------------------------------*/
 
   usage = 
-    "Plot cutoffs.\n"
-    "GNUPLOT plotting IS NOT IMPLEMENTED everything else should work\n";
+    "Plot cutoffs.\n";
 
   cutoff = plot->add_subcommand("cutoff", usage);
 
@@ -638,8 +653,7 @@ TadahCLI::TadahCLI():
   cutoff->usage(usage);
 
   usage =
-    "Specify the output file for text data\nand (optionally) a plot file using gnuplot.\n"
-    "Supported plot file extensions:\n.png, .pdf, .eps, .svg.\n"
+    "Specify the output file for text data.\n"
     "Pathnames must not exist.\n";
 
   cutoff->add_option("-o,--outfile", out_files, usage)
@@ -663,6 +677,39 @@ TadahCLI::TadahCLI():
   cutoff->add_flag("-V,--Verbose", verbose, 
                  "Enable verbose output for detailed information.\n");
 
+
+/*---------------------------------------------------------------------------*/
+/*     PLOT DATA                                                             */
+/*---------------------------------------------------------------------------*/
+
+usage = 
+  "Visualize Tadah! datasets.\n";
+
+vdata = plot->add_subcommand("vdata", usage);
+
+usage = 
+  "This command provides tools for data visualization.\n"
+  "It generates a temperature vs. density histogram.\n"
+  "It counts the number of configurations that fall within each 2D bin (T, rho).\n"
+  "The formatted output is T rho count, where T and rho are the centers of the bins.\n";
+
+vdata->usage(usage);
+
+usage =
+  "One or more datasets for prediction.\n";
+
+vdata->add_option("-d,--datasets", datasets, usage)
+  ->required();
+
+vdata->add_option("-b,--bins", sizes, "Number of bins in the x and y directions.\n")
+  ->required();
+
+vdata->add_option("-n,--numeric", outprec,
+                  "Set the number of decimal places for numerical output.\n");
+
+vdata->add_flag("-V,--verbose", verbose, 
+                "Enable verbose output for detailed information.\n");
+
 }
 
 int TadahCLI::run(int argc, char** argv) {
@@ -699,6 +746,9 @@ int TadahCLI::run(int argc, char** argv) {
     else if (*summary) {
       return subsub_summary();
     }
+    else if (*rescale) {
+      return subsub_rescale();
+    }
     else if (*sample) {
       return subsub_sample();
     }
@@ -721,6 +771,9 @@ int TadahCLI::run(int argc, char** argv) {
     else if (*cutoff) {
       return subsub_cutoff();
     }
+    else if (*vdata) {
+      return subsub_data();
+    }
   }
   return 0;
 }
@@ -1182,6 +1235,21 @@ int TadahCLI::subsub_summary() {
   }
   return 0;
 }
+int TadahCLI::subsub_rescale() {
+  if (rank!=0) return 0;
+  CLI::Timer timer_tot{"Rescaling", CLI::Timer::Big};
+  for (const auto &infile: datasets) {
+    StructureDB stdb;
+    stdb.add(infile);
+    for (auto &s: stdb) {
+      s.eweight = std::abs(1/s.energy);
+    }
+    std::string outfile = infile + "_rescaled";
+    stdb.dump_to_file(outfile);
+  }
+  if (verbose) std::cout << timer_tot.to_string() << std::endl;
+  return 0;
+}
 
 int TadahCLI::subsub_sample() {
   if (rank!=0) return 0;
@@ -1517,46 +1585,47 @@ int TadahCLI::subsub_bf() {
   Config config(config_file);
   DC_Selector DCS(config);
 
-  v_type cgrid;
-  v_type sgrid;
-
   boost::to_upper(types[0]);
   boost::to_upper(types[1]);
-  boost::to_upper(types[2]);
   Cut_Base *fcut = nullptr;
-  if (types[1]=="2B") {
-    DCS.d2b->get_grid(config, "CGRID2B", cgrid);
-    DCS.d2b->get_grid(config, "SGRID2B", sgrid);
-    if (types[2]=="Y") {
+  double (*fb)(double,double,double) = nullptr;
+  std::string cgridtype;
+  std::string sgridtype;
+  std::string dtype;
+  if (types[0]=="2B") {
+    cgridtype="CGRID2B";
+    sgridtype="SGRID2B";
+    dtype = config.get<std::string>("TYPE2B");
+    if (types[1]=="Y" || bf->count("--scale")) {
       double rcut = config.get<double>("RCUT2BMAX");
       fcut = CONFIG::factory<Cut_Base,double>( config.get<std::string>("RCTYPE2B"), rcut);
     }
-  }
-  else if (types[1]=="MB") {
-    DCS.dmb->get_grid(config, "CGRIDMB", cgrid);
-    DCS.dmb->get_grid(config, "SGRIDMB", sgrid);
-    if (types[2]=="Y") {
+  } else if (types[0]=="MB") {
+    cgridtype="CGRIDMB";
+    sgridtype="SGRIDMB";
+    dtype = config.get<std::string>("TYPEMB");
+    if (types[1]=="Y" || bf->count("--scale")) {
       double rcut = config.get<double>("RCUTMBMAX");
       fcut = CONFIG::factory<Cut_Base,double>( config.get<std::string>("RCTYPEMB"), rcut);
     }
-  }
-  else {
+  } else {
     throw std::runtime_error("Unsupported -t, --type: " + types[0] + " " + types[1]);
   }
 
-  double (*fb)(double,double,double) = nullptr;
-  if (types[0]=="B") {
-    std::cout << "B type: " << types[0] << std::endl;
-    fb = &B;
-  }
-  else if (types[0]=="G") {
-    std::cout << "G type: " << types[0] << std::endl;
-    fb = &G;
-  }
-  else {
-    throw std::runtime_error("Unsupported -t, --type: " + types[0] + " " + types[1]);
+  v_type cgrid;
+  v_type sgrid;
+  D_Base::get_grid(config, cgridtype, cgrid);
+  D_Base::get_grid(config, sgridtype, sgrid);
+
+  if (dtype == "D2_Blip" || dtype == "DM_Blip") {
+    if (bf->count("--Derivative")) { fb = &dB;} else { fb = &B; }
+  } else if (dtype == "D2_BP" || dtype == "DM_EAD" || dtype == "DM_mEAD") {
+    if (bf->count("--Derivative")) { fb = &dG;} else { fb = &G; }
+  } else {
+    throw std::runtime_error("Unsupported: " + dtype);
   }
 
+
   std::vector<size_t> idx;
   if (bf->count("-i")) {
     idx = parse_indices(indices);
@@ -1585,11 +1654,16 @@ int TadahCLI::subsub_bf() {
   if (!bf->count("--numeric")) {
     outprec = 5;
   }
-  bool scale = bf->count("--scale");
   v_type cutoff_values (range.size(), 1.0);
-  if (scale) {
-    for (size_t i=0; i<range.size(); ++i) {
-      cutoff_values[i] = fcut->calc(range[i]);
+  if (bf->count("--scale")) {
+    if(bf->count("--Derivative")) {
+      for (size_t i=0; i<range.size(); ++i) {
+        cutoff_values[i] = fcut->calc_prime(range[i]);
+      }
+    } else {
+      for (size_t i=0; i<range.size(); ++i) {
+        cutoff_values[i] = fcut->calc(range[i]);
+      }
     }
   }
 
@@ -1606,15 +1680,25 @@ int TadahCLI::subsub_bf() {
     outfile << std::endl << std::endl;
   }
 
-  if (types[2]=="Y") {
-    for (size_t i=0; i<range.size(); ++i) {
-      outfile << std::fixed << std::setprecision(outprec)
-        << std::setw(outprec+5)
-        << range[i] << "    "
-        << std::setw(outprec+5) << fcut->calc(range[i]) << std::endl;
+  if (types[1]=="Y") {
+    if(bf->count("--Derivative")) {
+      for (size_t i=0; i<range.size(); ++i) {
+        outfile << std::fixed << std::setprecision(outprec)
+          << std::setw(outprec+5)
+          << range[i] << "    "
+          << std::setw(outprec+5) << fcut->calc_prime(range[i]) << std::endl;
+      }
+      outfile << std::endl << std::endl;
+    }
+    else {
+      for (size_t i=0; i<range.size(); ++i) {
+        outfile << std::fixed << std::setprecision(outprec)
+          << std::setw(outprec+5)
+          << range[i] << "    "
+          << std::setw(outprec+5) << fcut->calc(range[i]) << std::endl;
+      }
+      outfile << std::endl << std::endl;
     }
-    outfile << std::endl << std::endl;
-
   }
   outfile.flush();
 
@@ -1648,7 +1732,7 @@ int TadahCLI::subsub_twobody() {
   size_t N = static_cast<size_t>(range[2]);
   v_type range = linspace(start,stop,N);
 
-  if (!bf->count("--numeric")) {
+  if (!twobody->count("--numeric")) {
     outprec = 5;
   }
 
@@ -1776,3 +1860,45 @@ int TadahCLI::subsub_cutoff() {
   if (verbose) std::cout << timer_tot.to_string() << std::endl;
   return 0;
 }
+int TadahCLI::subsub_data() {
+  if (rank!=0) return 0;
+  CLI::Timer timer_tot{"Data Visualiser", CLI::Timer::Big};
+
+  if(vdata->count("--Verbose"))
+    set_verbose();
+
+  if (datasets.size()<1) throw std::runtime_error("At least one dataset is required. Check -d, --datasets");
+  if (sizes.size()!=2) throw std::runtime_error("Number of bins are requires. Provide two integer values. Check -b, --bins\n");
+  int x_bins = sizes[0];
+  int y_bins = sizes[1];
+
+  std::vector<std::pair<double,double>> data_2d;
+
+  std::vector<StructureDB> stdbs(datasets.size());
+  for (size_t i=0; i < stdbs.size(); ++i) {
+    stdbs[i].add(datasets[i]);
+    if (verbose) std::cout << datasets[i] << " " << stdbs[i].summary();
+  }
+
+  for (size_t i=0; i<stdbs.size(); ++i) {
+    for (size_t j=0; j<stdbs[i].size(); ++j) {
+      Structure &st = stdbs[i](j);
+      double T = st.get_temperature();
+      double rho = st.get_density();
+      data_2d.push_back(std::make_pair(T,rho));
+    }
+  }
+
+  std::vector<histogram_bin> histogram = generate_2d_histogram(data_2d, x_bins, y_bins);
+
+  for (const auto& bin : histogram) {
+    double x_center, y_center;
+    int count;
+    std::tie(x_center, y_center, count) = bin;
+    std::cout << x_center << " " << y_center << " " << count << "\n";
+  }
+
+  if (is_verbose()) std::cout << "Done!" << std::endl;
+  if (verbose) std::cout << timer_tot.to_string() << std::endl;
+  return 0;
+}
diff --git a/bin/tadah_cli.h b/bin/tadah_cli.h
index fd2b7fe..edb0513 100644
--- a/bin/tadah_cli.h
+++ b/bin/tadah_cli.h
@@ -61,11 +61,13 @@ private:
   int subsub_join();
   int subsub_split();
   int subsub_summary();
+  int subsub_rescale();
   int subsub_dcalc();
   int sub_swriter();
   int subsub_bf();
   int subsub_twobody();
   int subsub_cutoff();
+  int subsub_data();
   int flag_version();
 
   /* Return a vector of absolute paths to every file in a directory.
@@ -89,6 +91,7 @@ public:
       CLI::App *split=nullptr;
         CLI::Option_group *split_group=nullptr;
       CLI::App *summary=nullptr;
+      CLI::App *rescale=nullptr;
     CLI::App *desc=nullptr;
       CLI::App *dcalc=nullptr;
     CLI::App *swriter=nullptr;
@@ -96,6 +99,8 @@ public:
       CLI::App *bf=nullptr;
       CLI::App *cutoff=nullptr;
       CLI::App *twobody=nullptr;
+      CLI::App *vdata=nullptr;
+        CLI::Option_group *data_group=nullptr;
 
 };
 #endif
-- 
GitLab