From 0e417ef4f7fb6a785c90b52711bdc33dd08d80d8 Mon Sep 17 00:00:00 2001 From: Marcin Kirsz <mkirsz@ed.ac.uk> Date: Wed, 19 Feb 2025 13:00:05 +0000 Subject: [PATCH] Dataset histogram tool --- bin/tadah_cli.cpp | 246 +++++++++++++++++++++++++++++++++++----------- bin/tadah_cli.h | 5 + 2 files changed, 191 insertions(+), 60 deletions(-) diff --git a/bin/tadah_cli.cpp b/bin/tadah_cli.cpp index 1fdd5e8..0cd6010 100644 --- a/bin/tadah_cli.cpp +++ b/bin/tadah_cli.cpp @@ -393,6 +393,19 @@ TadahCLI::TadahCLI(): ->required() ->check(CLI::ExistingFile); + /*---------------------------------------------------------------------------*/ + /* DB RESCALE */ + /*---------------------------------------------------------------------------*/ + usage = + "Add energy rescaling factors to dataset(s)\n"; + rescale = db->add_subcommand("rescale", usage); + + usage = "Specify one or more Tadah! datasets\n to add energy rescaling factors\n" + "The output is saved to new dataset(s) named:\n" + "ORIGINAL_FILE_rescaled\n"; + rescale->add_option("-d,--datasets", datasets, usage) + ->required() + ->check(CLI::ExistingFile); /*---------------------------------------------------------------------------*/ /* DESCRIPTORS */ @@ -511,12 +524,11 @@ TadahCLI::TadahCLI(): /*---------------------------------------------------------------------------*/ usage = - "Plot basis functions\n(and optionally cutoff values)\nused in two- and many-body expansions.\n" - "GNUPLOT plotting IS NOT IMPLEMENTED everything else should work\n"; + "Plot basis functions\n(and optionally cutoff values)\nused in two- and many-body expansions.\n"; bf = plot->add_subcommand("bf", usage); usage = - "This function can generate a formatted text file for plotting with external software, or it can use gnuplot (if available during Tadah! 
compilation) for visualization.\n" + "This function can generate a formatted text file for plotting with external software.\n" "Input must be either a trained potential model or a configuration file with the following keys:\n" " - SGRID2B (or SGRIDMB) <--REQUIRED\n" " - CGRID2B (or CGRIDMB) <--REQUIRED\n" @@ -529,15 +541,21 @@ TadahCLI::TadahCLI(): bf->usage(usage); usage = - "Provide either a potential file\nor a config file containing\nthe necessary keys.\n"; + "Provide either a potential file\nor a config file containing\nthe necessary keys.\n" + "The file must contain the following keys:\n" + " - INIT2B | INITMB\n" + " - TYPE2B | TYPEMB\n" + " - RCUT2B | RCUTMB\n" + " - RCTYPE2B | RCTYPEMB\n" + " - CGRID2B | CGRIDMB\n" + " - SGRID2B | SGRIDMB\n"; bf->add_option("-c,--config", config_file, usage) ->check(CLI::ExistingFile) ->required(); usage = - "Specify the output file for text data\nand (optionally) a plot file using gnuplot.\n" - "Supported plot file extensions:\n.png, .pdf, .eps, .svg.\n" + "Specify the output file for text data.\n" "Pathnames must not exist.\n"; bf->add_option("-o,--outfile", out_files, usage) @@ -547,15 +565,15 @@ TadahCLI::TadahCLI(): ->required(); bf->add_option("-t,--type", types, - "Specify the type of basis function and interaction:\n" - " - \"B 2b Y/N\": Blips for two-body interactions\n" - " - \"B mb Y/N\": Blips for many-body interactions\n" - " - \"G 2b Y/N\": Gaussians for two-body interactions\n" - " - \"G mb Y/N\": Gaussians for many-body interactions\n" - "Y/N indicates whether to plot the cutoff function.\n") + "Specify whether to plot two- or many-body grid\n" + "Y/N indicates whether to also plot the cutoff function.\n" + " - \"2b Y/N\": Plot grid for TYPE2B\n" + " - \"mb Y/N\": Plot grid for TYPEMB\n" + "\n") ->required(); - bf->add_flag("-s,--scale", "Scale the heights of the basis\nfunctions by the cutoff value.\n"); + bf->add_flag("-D,--Derivative", "Will plot derivatives of\nbasis functions and cutoffs.\n"); + 
bf->add_flag("-s,--scale", "Rescale the heights of the basis\nfunctions by the cutoff value.\n"); usage = "Indices of basis functions to be plotted\n" @@ -581,8 +599,7 @@ TadahCLI::TadahCLI(): /*---------------------------------------------------------------------------*/ usage = - "Plot two-body potential.\n" - "GNUPLOT plotting IS NOT IMPLEMENTED everything else should work\n"; + "Plot two-body potential.\n"; twobody = plot->add_subcommand("twobody", usage); @@ -597,8 +614,7 @@ TadahCLI::TadahCLI(): ->required(); usage = - "Specify the output file for text data\nand (optionally) a plot file using gnuplot.\n" - "Supported plot file extensions:\n.png, .pdf, .eps, .svg.\n" + "Specify the output file for text data.\n" "Pathnames must not exist.\n"; twobody->add_option("-o,--outfile", out_files, usage) @@ -627,8 +643,7 @@ TadahCLI::TadahCLI(): /*---------------------------------------------------------------------------*/ usage = - "Plot cutoffs.\n" - "GNUPLOT plotting IS NOT IMPLEMENTED everything else should work\n"; + "Plot cutoffs.\n"; cutoff = plot->add_subcommand("cutoff", usage); @@ -638,8 +653,7 @@ TadahCLI::TadahCLI(): cutoff->usage(usage); usage = - "Specify the output file for text data\nand (optionally) a plot file using gnuplot.\n" - "Supported plot file extensions:\n.png, .pdf, .eps, .svg.\n" + "Specify the output file for text data.\n" "Pathnames must not exist.\n"; cutoff->add_option("-o,--outfile", out_files, usage) @@ -663,6 +677,39 @@ TadahCLI::TadahCLI(): cutoff->add_flag("-V,--Verbose", verbose, "Enable verbose output for detailed information.\n"); + +/*---------------------------------------------------------------------------*/ +/* PLOT DATA */ +/*---------------------------------------------------------------------------*/ + +usage = + "Visualize Tadah! datasets.\n"; + +vdata = plot->add_subcommand("vdata", usage); + +usage = + "This command provides tools for data visualization.\n" + "It generates a temperature vs. 
density histogram.\n" + "It counts the number of configurations that fall within each 2D bin (T, rho).\n" + "The formatted output is T rho count, where T and rho are the centers of the bins.\n"; + +vdata->usage(usage); + +usage = + "One or more datasets to visualize.\n"; + +vdata->add_option("-d,--datasets", datasets, usage) + ->required(); + +vdata->add_option("-b,--bins", sizes, "Number of bins in the x and y directions.\n") + ->required(); + +vdata->add_option("-n,--numeric", outprec, + "Set the number of decimal places for numerical output.\n"); + +vdata->add_flag("-V,--Verbose", verbose, + "Enable verbose output for detailed information.\n"); + } int TadahCLI::run(int argc, char** argv) { @@ -699,6 +746,9 @@ int TadahCLI::run(int argc, char** argv) { else if (*summary) { return subsub_summary(); } + else if (*rescale) { + return subsub_rescale(); + } else if (*sample) { return subsub_sample(); } @@ -721,6 +771,9 @@ int TadahCLI::run(int argc, char** argv) { else if (*cutoff) { return subsub_cutoff(); } + else if (*vdata) { + return subsub_data(); + } } return 0; } @@ -1182,6 +1235,21 @@ int TadahCLI::subsub_summary() { } return 0; } +int TadahCLI::subsub_rescale() { + if (rank!=0) return 0; + CLI::Timer timer_tot{"Rescaling", CLI::Timer::Big}; + for (const auto &infile: datasets) { + StructureDB stdb; + stdb.add(infile); + for (auto &s: stdb) { + s.eweight = std::abs(1/s.energy); + } + std::string outfile = infile + "_rescaled"; + stdb.dump_to_file(outfile); + } + if (verbose) std::cout << timer_tot.to_string() << std::endl; + return 0; +} int TadahCLI::subsub_sample() { if (rank!=0) return 0; @@ -1517,46 +1585,47 @@ int TadahCLI::subsub_bf() { Config config(config_file); DC_Selector DCS(config); - v_type cgrid; - v_type sgrid; - boost::to_upper(types[0]); boost::to_upper(types[1]); - boost::to_upper(types[2]); Cut_Base *fcut = nullptr; - if (types[1]=="2B") { - DCS.d2b->get_grid(config, "CGRID2B", cgrid); - DCS.d2b->get_grid(config, "SGRID2B", sgrid); - if 
(types[2]=="Y") { + double (*fb)(double,double,double) = nullptr; + std::string cgridtype; + std::string sgridtype; + std::string dtype; + if (types[0]=="2B") { + cgridtype="CGRID2B"; + sgridtype="SGRID2B"; + dtype = config.get<std::string>("TYPE2B"); + if (types[1]=="Y" || bf->count("--scale")) { double rcut = config.get<double>("RCUT2BMAX"); fcut = CONFIG::factory<Cut_Base,double>( config.get<std::string>("RCTYPE2B"), rcut); } - } - else if (types[1]=="MB") { - DCS.dmb->get_grid(config, "CGRIDMB", cgrid); - DCS.dmb->get_grid(config, "SGRIDMB", sgrid); - if (types[2]=="Y") { + } else if (types[0]=="MB") { + cgridtype="CGRIDMB"; + sgridtype="SGRIDMB"; + dtype = config.get<std::string>("TYPEMB"); + if (types[1]=="Y" || bf->count("--scale")) { double rcut = config.get<double>("RCUTMBMAX"); fcut = CONFIG::factory<Cut_Base,double>( config.get<std::string>("RCTYPEMB"), rcut); } - } - else { + } else { throw std::runtime_error("Unsupported -t, --type: " + types[0] + " " + types[1]); } - double (*fb)(double,double,double) = nullptr; - if (types[0]=="B") { - std::cout << "B type: " << types[0] << std::endl; - fb = &B; - } - else if (types[0]=="G") { - std::cout << "G type: " << types[0] << std::endl; - fb = &G; - } - else { - throw std::runtime_error("Unsupported -t, --type: " + types[0] + " " + types[1]); + v_type cgrid; + v_type sgrid; + D_Base::get_grid(config, cgridtype, cgrid); + D_Base::get_grid(config, sgridtype, sgrid); + + if (dtype == "D2_Blip" || dtype == "DM_Blip") { + if (bf->count("--Derivative")) { fb = &dB;} else { fb = &B; } + } else if (dtype == "D2_BP" || dtype == "DM_EAD" || dtype == "DM_mEAD") { + if (bf->count("--Derivative")) { fb = &dG;} else { fb = &G; } + } else { + throw std::runtime_error("Unsupported: " + dtype); } + std::vector<size_t> idx; if (bf->count("-i")) { idx = parse_indices(indices); @@ -1585,11 +1654,16 @@ int TadahCLI::subsub_bf() { if (!bf->count("--numeric")) { outprec = 5; } - bool scale = bf->count("--scale"); v_type 
cutoff_values (range.size(), 1.0); - if (scale) { - for (size_t i=0; i<range.size(); ++i) { - cutoff_values[i] = fcut->calc(range[i]); + if (bf->count("--scale")) { + if(bf->count("--Derivative")) { + for (size_t i=0; i<range.size(); ++i) { + cutoff_values[i] = fcut->calc_prime(range[i]); + } + } else { + for (size_t i=0; i<range.size(); ++i) { + cutoff_values[i] = fcut->calc(range[i]); + } } } @@ -1606,15 +1680,25 @@ int TadahCLI::subsub_bf() { outfile << std::endl << std::endl; } - if (types[2]=="Y") { - for (size_t i=0; i<range.size(); ++i) { - outfile << std::fixed << std::setprecision(outprec) - << std::setw(outprec+5) - << range[i] << " " - << std::setw(outprec+5) << fcut->calc(range[i]) << std::endl; + if (types[1]=="Y") { + if(bf->count("--Derivative")) { + for (size_t i=0; i<range.size(); ++i) { + outfile << std::fixed << std::setprecision(outprec) + << std::setw(outprec+5) + << range[i] << " " + << std::setw(outprec+5) << fcut->calc_prime(range[i]) << std::endl; + } + outfile << std::endl << std::endl; + } + else { + for (size_t i=0; i<range.size(); ++i) { + outfile << std::fixed << std::setprecision(outprec) + << std::setw(outprec+5) + << range[i] << " " + << std::setw(outprec+5) << fcut->calc(range[i]) << std::endl; + } + outfile << std::endl << std::endl; } - outfile << std::endl << std::endl; - } outfile.flush(); @@ -1648,7 +1732,7 @@ int TadahCLI::subsub_twobody() { size_t N = static_cast<size_t>(range[2]); v_type range = linspace(start,stop,N); - if (!bf->count("--numeric")) { + if (!twobody->count("--numeric")) { outprec = 5; } @@ -1776,3 +1860,45 @@ int TadahCLI::subsub_cutoff() { if (verbose) std::cout << timer_tot.to_string() << std::endl; return 0; } +int TadahCLI::subsub_data() { + if (rank!=0) return 0; + CLI::Timer timer_tot{"Data Visualiser", CLI::Timer::Big}; + + if(vdata->count("--Verbose")) + set_verbose(); + + if (datasets.size()<1) throw std::runtime_error("At least one dataset is required. 
Check -d, --datasets"); + if (sizes.size()!=2) throw std::runtime_error("The number of bins is required. Provide two integer values. Check -b, --bins\n"); + int x_bins = sizes[0]; + int y_bins = sizes[1]; + + std::vector<std::pair<double,double>> data_2d; + + std::vector<StructureDB> stdbs(datasets.size()); + for (size_t i=0; i < stdbs.size(); ++i) { + stdbs[i].add(datasets[i]); + if (verbose) std::cout << datasets[i] << " " << stdbs[i].summary(); + } + + for (size_t i=0; i<stdbs.size(); ++i) { + for (size_t j=0; j<stdbs[i].size(); ++j) { + Structure &st = stdbs[i](j); + double T = st.get_temperature(); + double rho = st.get_density(); + data_2d.push_back(std::make_pair(T,rho)); + } + } + + std::vector<histogram_bin> histogram = generate_2d_histogram(data_2d, x_bins, y_bins); + + for (const auto& bin : histogram) { + double x_center, y_center; + int count; + std::tie(x_center, y_center, count) = bin; + std::cout << x_center << " " << y_center << " " << count << "\n"; + } + + if (is_verbose()) std::cout << "Done!" << std::endl; + if (verbose) std::cout << timer_tot.to_string() << std::endl; + return 0; +} diff --git a/bin/tadah_cli.h b/bin/tadah_cli.h index fd2b7fe..edb0513 100644 --- a/bin/tadah_cli.h +++ b/bin/tadah_cli.h @@ -61,11 +61,13 @@ private: int subsub_join(); int subsub_split(); int subsub_summary(); + int subsub_rescale(); int subsub_dcalc(); int sub_swriter(); int subsub_bf(); int subsub_twobody(); int subsub_cutoff(); + int subsub_data(); int flag_version(); /* Return a vector of absolute paths to every file in a directory. @@ -89,6 +91,7 @@ public: CLI::App *split=nullptr; CLI::Option_group *split_group=nullptr; CLI::App *summary=nullptr; + CLI::App *rescale=nullptr; CLI::App *desc=nullptr; CLI::App *dcalc=nullptr; CLI::App *swriter=nullptr; @@ -96,6 +99,8 @@ public: CLI::App *bf=nullptr; CLI::App *cutoff=nullptr; CLI::App *twobody=nullptr; + CLI::App *vdata=nullptr; + CLI::Option_group *data_group=nullptr; }; #endif -- GitLab