Skip to content
Snippets Groups Projects
Commit f2bb180e authored by Rupe Nash's avatar Rupe Nash
Browse files

fix a few bugs, tidy code and vastly improve docstrings

parent e5a81216
No related branches found
No related tags found
No related merge requests found
// -*- mode: c++; -*-
//
// Copyright (c) 2019, Rupert Nash
//
// All rights reserved.
//
// Redistribution and use in source and binary forms, with or without
// modification, are permitted provided that the following conditions are met:
//
// 1. Redistributions of source code must retain the above copyright notice, this
// list of conditions and the following disclaimer.
// 2. Redistributions in binary form must reproduce the above copyright notice,
// this list of conditions and the following disclaimer in the documentation
// and/or other materials provided with the distribution.
//
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
// DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
// FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
// DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
// SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
// CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
// OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
//
#ifndef LIBMISC_ELEMENT_ITERATOR_HPP
#define LIBMISC_ELEMENT_ITERATOR_HPP
// Bare-bones iterator transformer. It wraps an iterator that when
// dereferenced can be accessed with std::get<unsigned> and will
// return the corresponding element when dereferenced.
//
// Useful to allow, for example, easy interation over the values or
// keys of a map.
//
// Base iterator B must implement LegacyIterator and LegacyForwardIterator.
//
// If B implements LegacyBidirectionalIterator then so does this.
template <typename B, std::size_t index>
struct element_iterator {
using base_iterator = B;
using base_value_type = typename std::decay<
decltype(*std::declval<B>())
>::type;
// Support iterator_traits
using value_type = typename std::decay<
decltype(std::get<index>(std::declval<base_value_type>()))
>::type;
using reference = value_type&;
using pointer = value_type*;
using difference_type = typename std::iterator_traits<B>::difference_type;
// This is probably wrong
using iterator_category = typename std::iterator_traits<B>::iterator_category;
// Pre increment
element_iterator& operator++() {
++base;
return *this;
}
// Post increment
element_iterator operator++(int) {
return element_iterator{base++};
}
// Pre decrement
element_iterator& operator--() {
static_assert(std::is_same<decltype(--B() ), B&>::value,
"base iterator does implement operator--() as standard requires");
--base;
return *this;
}
// Post decrement
element_iterator operator--(int) {
static_assert(std::is_same<decltype(--B() ), B&>::value,
"base iterator does implement operator--() as standard requires");
return element_iterator{base++};
}
// Dereference
reference operator*() {
return std::get<index>(*base);
}
// Could make this private but I'm lazy
base_iterator base;
};
template <typename B, std::size_t index>
bool operator!=(const element_iterator<B, index>& a, const element_iterator<B, index>& b) {
return a.base != b.base;
}
template <typename B, std::size_t index>
bool operator==(const element_iterator<B, index>& a, const element_iterator<B, index>& b) {
return a.base == b.base;
}
#endif
......@@ -5,6 +5,7 @@
#include "error.h"
#include "modify.h"
#include "update.h"
#include "element_iterator.h"
using namespace LAMMPS_NS;
......@@ -27,26 +28,49 @@ FixMuiConnection::FixMuiConnection(LAMMPS *lmp, int narg, char **arg) :
// Take ownership of the pointer and pass to the shared_ptr in the map
unifaces.emplace(arg[i+4], univec[i].release());
}
if (screen) {
fprintf(screen,
"MUI connection established as domain '%s' on interface(s)",
arg[3]);
int i = 4;
fprintf(screen, " '%s'", arg[i]);
for(i = 5; i < narg; ++i) {
fprintf(screen, ", '%s'", arg[i]);
}
fprintf(screen, ".\n");
}
}
FixMuiConnection::~FixMuiConnection() {
}
// The fix only does things at set up time!
int FixMuiConnection::setmask() {
return 0;
}
namespace {
// Factory function to Adapt element_iterator to iterate over the
// values of a key/val container
template<typename base_iter>
element_iterator<base_iter, 1> make_val_iter(base_iter i) {
return element_iterator<base_iter, 1>{i};
}
}
// Ensure that all MUI connections are up and running before
// continuing.
void FixMuiConnection::setup(int) {
update->update_time();
double t = update->atime;
for (auto name_unif: unifaces) {
auto& unif = name_unif.second;
unif->push("ready", 1.0);
unif->commit(t);
unif->barrier(t);
}
mui::sync_all(make_val_iter(unifaces.begin()),
make_val_iter(unifaces.end()),
0.0);
}
// Get the interface by name to serve client fixes
auto FixMuiConnection::get_uniface(const std::string& if_name) -> uniface_ptr {
auto key_val = unifaces.find(if_name);
if (key_val == unifaces.end())
......@@ -55,25 +79,20 @@ auto FixMuiConnection::get_uniface(const std::string& if_name) -> uniface_ptr {
return key_val->second;
}
// Static member function to search the supplied Modify for a fix with
// the given ID and of this type. Returns nullptr on failure.
FixMuiConnection* FixMuiConnection::find(class Modify* mod, char* fixID) {
// lookup Fix ID
int ifix = 0;
for (; ifix != mod->nfix; ifix++)
if (strcmp(fixID, mod->fix[ifix]->id) == 0)
break;
if (ifix == mod->nfix) return nullptr;
if (ifix == mod->nfix)
return nullptr;
auto connection_fix = dynamic_cast<FixMuiConnection*>(mod->fix[ifix]);
if (!connection_fix) return nullptr;
//mod->error->all(FLERR, "Found fix but it is not a FixMuiConnection");
if (!connection_fix)
return nullptr;
return connection_fix;
}
auto FixMuiConnection::find_uniface(class Modify* mod, char* fixID, const std::string& if_name) -> uniface_ptr {
auto connection_fix = find(mod, fixID);
if (connection_fix)
return connection_fix->get_uniface(if_name);
else
return nullptr;
}
......@@ -13,24 +13,36 @@ FixStyle(mui/connect,FixMuiConnection)
#include "fix.h"
#include "mui_fwd.h"
// This fix simply starts a MUI connection on 1 or more interfaces for
// use by other mui fix styles
namespace LAMMPS_NS {
// This fix simply starts a MUI connection on 1 or more interfaces for
// use by other MUI fix styles
class FixMuiConnection : public Fix {
public:
// LAMMPS is fundamentally 3D, so restricting MUI to 3D unifaces.
using uniface = mui::uniface<mui::config_3d>;
using uniface_ptr = std::shared_ptr<uniface>;
using uniface_map = std::map<std::string, uniface_ptr>;
// Construct in the standard Fix way
// Usage in an input script:
// fix fix-ID group-ID mui/connect domain [ interface ]+
FixMuiConnection(class LAMMPS *, int, char **);
virtual ~FixMuiConnection();
// Fix API overrides
virtual int setmask() override;
virtual void setup(int) override;
// Get the interface by name to serve client fixes
uniface_ptr get_uniface(const std::string&);
// Static member function to search the supplied Modify for a fix
// with the given ID and of this type. Returns nullptr on failure.
static FixMuiConnection* find(class Modify*, char*);
static uniface_ptr find_uniface(class Modify*, char*, const std::string&);
private:
// Map holding our unifaces
uniface_map unifaces;
};
}
......
......@@ -11,10 +11,6 @@
using namespace LAMMPS_NS;
double current_sim_time(Update* up) {
return up->atime + (up->ntimestep - up->atimestep) * up->dt;
}
FixMuiFetchParameter::FixMuiFetchParameter(LAMMPS *lmp, int narg, char **arg) :
Fix(lmp, narg, arg) {
// The "fix" string is not passed.
......@@ -79,6 +75,8 @@ FixMuiFetchParameter::FixMuiFetchParameter(LAMMPS *lmp, int narg, char **arg) :
if (strcmp(key, "EVERY") == 0) {
nevery = force->inumeric(FLERR, val);
if (nevery < 1)
error->all(FLERR, "fix mui/fetch_param: EVERY must be positive\n");
} else {
std::string msg = "fix mui/fetch_param: invalid option '";
msg += key;
......@@ -90,11 +88,8 @@ FixMuiFetchParameter::FixMuiFetchParameter(LAMMPS *lmp, int narg, char **arg) :
if (screen)
fprintf(screen,
"mui/fetch_param: connection fix = %s; interface = %s; parameter name = %s; dest variable = %s; checking every %d steps\n",
argv[0], argv[1], param_name.c_str(), var_name.c_str(), nevery);
fprintf(screen,
"update->dt = %e\nforce->femtosecond = %e\nupdate->ntimestep = %d\n",
update->dt, force->femtosecond, update->ntimestep);
"MUI fetch parameter on interface = '%s' '%s' -> '%s' every %d steps\n",
argv[1], param_name.c_str(), var_name.c_str(), nevery);
}
......@@ -105,16 +100,17 @@ int FixMuiFetchParameter::setmask() {
return FixConst::END_OF_STEP;
}
// the FETCH part
// Do the fetch from MUI into the variable
void FixMuiFetchParameter::end_of_step() {
update->update_time();
double tnow = update->atime;
fprintf(screen, "LAMMPS MUI barrier at t = %f\n", tnow);
uniface->barrier(tnow);
// MUI requires a barrier to actually receive data (in a field-style
// fetch this would be done internally).
uniface->barrier(tnow);
param_value = uniface->fetch<double>(param_name);
if (screen)
fprintf(screen, "MUI fetched parameter '%s' = %f\n", param_name.c_str(), param_value);
// Look up the variable and set it's value
int varind = input->variable->find(const_cast<char*>(var_name.c_str()));
input->variable->internal_set(varind, param_value);
}
......
......@@ -12,20 +12,32 @@ FixStyle(mui/fetch_param,FixMuiFetchParameter)
#include "fix.h"
#include "fix_mui_connection.h"
// This uses a mui/connect fix to fetch a parameter (of type double) every N timesteps
namespace LAMMPS_NS {
// This fix uses a mui/connect fix to fetch a parameter (of type
// double) every N timesteps
class FixMuiFetchParameter : public Fix {
public:
using uniface_ptr = FixMuiConnection::uniface_ptr;
// Construct in the standard Fix way
// Usage in an input script:
// fix fixID grpID mui/fetch_param connectionID interface param_name var_name [EVERY nsteps]"
// connectionID must be the ID of an existing mui/connect fix
// interface must be the name of an interface belonging to the connection
// param_name is the string identifying the parameter to MUI
// var_name must be the name of a LAMMPS internal style variable
// nsteps is an integer > 0
FixMuiFetchParameter(class LAMMPS *, int, char **);
virtual ~FixMuiFetchParameter();
// Fix API
virtual int setmask() override;
virtual void end_of_step() override;
private:
using uniface_ptr = FixMuiConnection::uniface_ptr;
uniface_ptr uniface;
std::string param_name;
std::string var_name;
......
......@@ -52,6 +52,8 @@ FixMuiPush::FixMuiPush(LAMMPS *lmp, int narg, char **arg) :
if (iarg + 1 >= narg)
error->all(FLERR, "keyword arg EVERY must be followed by a value");
nevery = force->inumeric(FLERR, arg[iarg + 1]);
if (nevery < 1)
error->all(FLERR, "fix mui/fetch_param: EVERY must be positive\n");
iarg += 2;
} else if (strcmp(key, "PARAM") == 0) {
if (iarg + 2 >= narg)
......@@ -67,6 +69,16 @@ FixMuiPush::FixMuiPush(LAMMPS *lmp, int narg, char **arg) :
error->all(FLERR, msg.c_str());
}
}
if (screen) {
fprintf(screen,
"MUI push on interface = '%s' every %d steps",
arg[3], nevery);
for (const auto& i : parameters)
fprintf(screen,
", parameter '%s' -> '%s'",
i.second.c_str(), i.first.c_str());
fprintf(screen, ".\n");
}
}
FixMuiPush::~FixMuiPush() {
......@@ -76,7 +88,7 @@ int FixMuiPush::setmask() {
return FixConst::END_OF_STEP;
}
// Push the messages and commit
// Compute the variable values, push as parameters, and commit.
void FixMuiPush::end_of_step() {
for (const auto& i : parameters) {
const auto& param_name = i.first;
......@@ -87,7 +99,6 @@ void FixMuiPush::end_of_step() {
}
update->update_time();
double tnow = update->atime;
fprintf(screen, "LAMMPS commit at t = %f\n", tnow);
uniface->commit(tnow);
}
......
......@@ -12,20 +12,23 @@ FixStyle(mui/push, FixMuiPush)
#include "fix.h"
#include "fix_mui_connection.h"
// This uses a mui/connect fix to fetch a parameter (of type double) every N timesteps
namespace LAMMPS_NS {
// This fix uses a mui/connect fix to push data every N timesteps
// Currently restricted to simple double parameters
class FixMuiPush : public Fix {
public:
using uniface_ptr = FixMuiConnection::uniface_ptr;
// Construct in the standard Fix way
// Usage in an input script:
// fix $fixID $grpID mui/push $connectionID $if_name [EVERY $nsteps] [PARAM $param_name $varname]*
FixMuiPush(class LAMMPS *, int, char **);
virtual ~FixMuiPush();
// Fix API
virtual int setmask() override;
virtual void end_of_step() override;
private:
using uniface_ptr = FixMuiConnection::uniface_ptr;
uniface_ptr uniface;
// MUI name -> LAMMPS source variable name
using param_map_t = std::map<std::string, std::string>;
......
......@@ -35,6 +35,7 @@ using namespace LAMMPS_NS;
int main(int argc, char **argv)
{
MPI_Init(&argc,&argv);
MPI_Comm this_world = mui::mpi_split_by_app();
// enable trapping selected floating point exceptions.
// this uses GNU extensions and is only tested on Linux
......@@ -50,7 +51,7 @@ int main(int argc, char **argv)
#ifdef LAMMPS_EXCEPTIONS
try {
LAMMPS *lammps = new LAMMPS(argc, argv, mui::mpi_split_by_app());
LAMMPS *lammps = new LAMMPS(argc, argv, this_world);
lammps->input->file();
delete lammps;
} catch(LAMMPSAbortException & ae) {
......@@ -60,11 +61,11 @@ int main(int argc, char **argv)
exit(1);
}
#else
LAMMPS *lammps = new LAMMPS(argc, argv, mui::mpi_split_by_app());
LAMMPS *lammps = new LAMMPS(argc, argv, this_world);
lammps->input->file();
delete lammps;
#endif
MPI_Barrier(MPI_COMM_WORLD);
MPI_Barrier(this_world);
MPI_Finalize();
#ifdef FFT_FFTW3
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment