From c0c4db5a3b88506f5ae9de7a44dcaf2fd454786d Mon Sep 17 00:00:00 2001 From: mkirsz <s1351949@sms.ed.ac.uk> Date: Fri, 7 Mar 2025 20:00:03 +0000 Subject: [PATCH] New classes with tests --- include/tadah/mlip/atom.h | 9 + include/tadah/mlip/structure_properties.h | 53 ++ .../tadah/mlip/structure_transformations.h | 149 ++++++ src/atom.cpp | 6 + src/structure_properties.cpp | 71 +++ src/structure_transformations.cpp | 357 +++++++++++++ tests/test_structure_properties.cpp | 155 ++++++ tests/test_structure_transformations.cpp | 493 ++++++++++++++++++ 8 files changed, 1293 insertions(+) create mode 100644 include/tadah/mlip/structure_properties.h create mode 100644 include/tadah/mlip/structure_transformations.h create mode 100644 src/structure_properties.cpp create mode 100644 src/structure_transformations.cpp create mode 100644 tests/test_structure_properties.cpp create mode 100644 tests/test_structure_transformations.cpp diff --git a/include/tadah/mlip/atom.h b/include/tadah/mlip/atom.h index 04d82a4..385eef7 100644 --- a/include/tadah/mlip/atom.h +++ b/include/tadah/mlip/atom.h @@ -64,6 +64,15 @@ struct Atom: public Element { const double px, const double py, const double pz, const double fx, const double fy, const double fz); + /** This constructor fully initialise this object with zero force + * + * @param[in] element Chemical Element + * @param[in] px,py,pz Atomic coordinates + * + */ + Atom(const Element &element, + const double px, const double py, const double pz); + /** Hold position of the atom. */ Vec3d position; diff --git a/include/tadah/mlip/structure_properties.h b/include/tadah/mlip/structure_properties.h new file mode 100644 index 0000000..1af9760 --- /dev/null +++ b/include/tadah/mlip/structure_properties.h @@ -0,0 +1,53 @@ +#ifndef STRUCTURE_PROPERTIES_H +#define STRUCTURE_PROPERTIES_H + +#include <tadah/mlip/structure.h> + +/** + * @class StructureProperties + * @brief Collects read-only property computations for a Structure, such as: + * - Volume (getVolume) + * - Density (getDensity) + * - Center of mass (getCentreOfMass) + */ +class StructureProperties +{ +public: + /** + * @brief Computes the volume of the Structure from its lattice vectors. + * + * Interprets the columns of st.cell as vectors a1, a2, a3 and + * computes the scalar triple product: volume = a1 · (a2 × a3). + * Returns volume in Å^3. + */ + static double getVolume(const Structure &st); + + /** + * @brief Computes the density of the Structure in g/cm^3. + * + * Procedure: + * - Accumulate atomic mass from PeriodicTable::get_mass(atom.Z) in amu. + * - Convert volume from Å^3 to cm^3 (1 Å^3 = 1e-24 cm^3). + * - Multiply total mass (amu) by 1.66053906660e-24 g/amu. + * - Divide by volume in cm^3 => density. + * If the volume is non-positive, returns 0.0 or throws an exception (user’s choice). + */ + static double getDensity(const Structure &st); + + /** + * @brief Computes the center of mass in Cartesian coordinates. + * + * Sums (mass_i * position_i) for each atom, divides by total mass. + * If no atoms exist, returns (0,0,0). + */ + static Vec3d getCentreOfMass(const Structure &st); + + /** + * @brief Computes the *geometric center* (centroid) of all atoms. + * + * If no atoms exist, returns (0,0,0). + */ + static Vec3d getGeometricCenter(const Structure &st); +}; + +#endif // STRUCTURE_PROPERTIES_H diff --git a/include/tadah/mlip/structure_transformations.h b/include/tadah/mlip/structure_transformations.h new file mode 100644 index 0000000..d3144d2 --- /dev/null +++ b/include/tadah/mlip/structure_transformations.h @@ -0,0 +1,149 @@ +#ifndef STRUCTURE_TRANSFORMATION_H +#define STRUCTURE_TRANSFORMATION_H + +#include <tadah/mlip/structure.h> + +/** + * @class StructureTransformation + * @brief Methods that modify a Structure: scaling, shearing, rotating, replicating, etc. + * The stored virial stress is updated for linear transforms as S' = M * S * M^T + * if the transform has a physically consistent interpretation (det(M) > 0). + * For negative determinant transforms (e.g., reflection), the code still does + * S' = M*S*M^T, but it might be physically questionable. + * + * New additions: + * - scaleVolumeToDensity: resizes the cell to achieve a target density. + * - addAtom / removeAtom: insert or remove atoms from st. + * - shakeAtoms: random displacement to mimic thermal agitation. + * - remapToCell: ensures all atoms lie in the principal box. + */ +class StructureTransformation +{ +public: + /** + * Scales cell vectors and atomic positions by a factor > 0. + * Virial stress => S' = factor^2 * S (common approach for isotropic scaling). + */ + static void scale(Structure &st, double factor); + + /** + * Scales the cell to achieve a target volume [Å^3]. + * factor = cbrt(targetVolume / currentVolume). + * Then calls scale(st, factor). + */ + static void scaleToVolume(Structure &st, double targetVolume); + + /** + * Scales the cell to achieve a target density [g/cm^3]. + * + * 1) Compute total mass (amu -> grams). + * 2) volume = mass / density => in cm^3 => convert to Å^3. + * 3) scaleToVolume(...). + */ + static void scaleVolumeToDensity(Structure &st, double targetDensity); + + /** + * Rotates cell and atoms by a 3x3 matrix R: + * cell' = R * cell, + * pos' = R * pos. + * Stress => S' = R * S * R^T. + */ + static void rotate(Structure &st, const Matrix3d &R); + + /** + * Shears cell and atoms by shearMatrix: + * cell' = shearMatrix * cell, + * pos' = shearMatrix * pos, + * S' = shearMatrix * S * (shearMatrix^T). + */ + static void shear(Structure &st, const Matrix3d &shearMatrix); + + /** + * Replicates the structure nx, ny, nz times along cell vectors. + * Creates a supercell with cell dimension scaled accordingly and + * updates the atom list. Stress scaling is context-dependent. + */ + static void replicate(Structure &st, int nx, int ny, int nz); + + /** + * Shifts all atomic positions by shiftVec. Cell and stress remain unchanged. + */ + static void shift(Structure &st, const Vec3d &shiftVec); + + /** + * Mirrors about origin => pos' = -pos. + * By default, leaves stress unchanged or uses M = -I => S' = S mathematically. + */ + static void mirrorOrigin(Structure &st); + + /** + * Mirrors about plane normal to normalVec (through origin). + * Reflects positions via p' = p - 2*(p·n_hat) * n_hat. + * Then S' = M*S*M^T, M = (I - 2n_hat n_hat^T). + */ + static void mirrorPlane(Structure &st, const Vec3d &normalVec); + + /** + * Mirrors about a plane with Miller indices (h,k,l) through origin. + * The plane normal is computed from the reciprocal lattice. + * Similar reflection for stress => S' = M*S*M^T (det(M)<0). + */ + static void mirrorHkl(Structure &st, int h, int k, int l); + + /** + * Centers atoms so that either the geometric center or center of mass + * is placed at the origin. If shiftCell is true, also shift the cell origin + * (typical usage might not require that). + */ + static void center(Structure &st, bool useCenterOfMass = false, bool shiftCell = false); + + /** + * Realigns the cell’s principal directions to the Cartesian axes + * (e.g., diagonalizing the metric). A typical approach: + * - compute G = cell^T * cell + * - eigen-decompose G + * - build rotation R from eigenvectors + * - rotate(st, R) + * Provided as a stub. + */ + static void alignAxes(Structure &st); + + /** + * Remaps each atom into the principal unit cell by fractional wrapping. + */ + static void remapToCell(Structure &st); + + /** + * Adds an atom to the structure. + * If useRandomPosition is true, places it randomly inside the cell. + * If false, uses the provided pos argument. + * Checks distance from all existing atoms. If below minDistance, prints warning. + */ + static void addAtom(Structure &st, const Atom &atomTemplate, bool useRandomPosition = false, + const Vec3d &pos = Vec3d(0.0, 0.0, 0.0), double minDistance = 0.8); + + /** + * Removes an atom with a given global index from the structure if valid. + * If index is out of range, does nothing or throws a warning. + */ + static void removeAtom(Structure &st, size_t index); + + /** + * Applies random displacements to each atom’s position, up to maxDisplacement in each coordinate, + * to simulate thermal agitation. + * + * Example: shakeAtoms(st, 0.05) => random ±0.05 Å displacements in x,y,z. + */ + static void shakeAtoms(Structure &st, double maxDisplacement); + +private: + static void transformStress(Structure &st, const Matrix3d &M); + static void checkNearSingular(const Structure &st); + + /** + * Helper that returns a uniform random real in [0,1]. + */ + static double rand01(); +}; + +#endif // STRUCTURE_TRANSFORMATION_H diff --git a/src/atom.cpp b/src/atom.cpp index 164a6c4..38abf45 100644 --- a/src/atom.cpp +++ b/src/atom.cpp @@ -11,6 +11,12 @@ Atom::Atom(const Element &element, position(px,py,pz), force(fx,fy,fz) {} +Atom::Atom(const Element &element, + const double px, const double py, const double pz): + Element(element), + position(px,py,pz), + force(0,0,0) +{} std::ostream& operator<<(std::ostream& os, const Atom& atom) { diff --git a/src/structure_properties.cpp b/src/structure_properties.cpp new file mode 100644 index 0000000..3f61064 --- /dev/null +++ b/src/structure_properties.cpp @@ -0,0 +1,71 @@ +#include <tadah/mlip/structure_properties.h> +#include <tadah/core/periodic_table.h> +#include <stdexcept> +#include <cmath> + +double StructureProperties::getVolume(const Structure &st) +{ + // Columns of st.cell => a1, a2, a3 + Vec3d a1(st.cell(0,0), st.cell(1,0), st.cell(2,0)); + Vec3d a2(st.cell(0,1), st.cell(1,1), st.cell(2,1)); + Vec3d a3(st.cell(0,2), st.cell(1,2), st.cell(2,2)); + + double volume = a1 * (a2.cross(a3)); // scalar triple product + return volume; // [Å^3] +} + +double StructureProperties::getDensity(const Structure &st) +{ + double volumeA3 = getVolume(st); + if (volumeA3 <= 0.0) { + // Negative or zero volume => invalid + return 0.0; // or throw std::runtime_error("Volume <= 0.0, invalid cell.") + } + + // Sum atomic mass in amu + double totalMassAmu = 0.0; + for (const auto &atom : st.atoms) { + totalMassAmu += PeriodicTable::get_mass(atom.Z); + } + // Convert Å^3 => cm^3 + double volumeCm3 = volumeA3 * 1.0e-24; + + // Convert amu => grams + double totalMassGrams = totalMassAmu * 1.66053906660e-24; + + // density = mass / volume + double rho = totalMassGrams / volumeCm3; // [g/cm^3] + return rho; +} + +Vec3d StructureProperties::getCentreOfMass(const Structure &st) +{ + if (st.atoms.empty()) { + return Vec3d(0.0, 0.0, 0.0); + } + double totalMass = 0.0; + Vec3d com(0.0, 0.0, 0.0); + + for (const auto &atom : st.atoms) { + double mass = PeriodicTable::get_mass(atom.Z); // amu + com += atom.position * mass; + totalMass += mass; + } + if (totalMass > 1e-14) { + com /= totalMass; + } + return com; +} + +Vec3d StructureProperties::getGeometricCenter(const Structure &st) +{ + if (st.atoms.empty()) { + return Vec3d(0.0, 0.0, 0.0); + } + Vec3d c(0.0, 0.0, 0.0); + for (auto &atom : st.atoms) { + c += atom.position; + } + c /= (double)st.atoms.size(); + return c; +} diff --git a/src/structure_transformations.cpp b/src/structure_transformations.cpp new file mode 100644 index 0000000..f0c8ecd --- /dev/null +++ b/src/structure_transformations.cpp @@ -0,0 +1,357 @@ +#include <tadah/mlip/structure_transformations.h> +#include <tadah/mlip/structure_properties.h> +#include <random> +#include <iostream> +#include <stdexcept> +#include <cmath> + +namespace { + +/** Local utility function: multiply 3x3 matrices. */ +Matrix3d matMul(const Matrix3d &A, const Matrix3d &B) +{ + Matrix3d C; + for (int i=0; i<3; i++) { + for (int j=0; j<3; j++) { + C(i,j) = A(i,0)*B(0,j) + A(i,1)*B(1,j) + A(i,2)*B(2,j); + } + } + return C; +} + +/** A simple global RNG. Replace with thread_local or a better approach if needed. */ +std::mt19937 &globalRng() +{ + static std::random_device rd; + static std::mt19937 rng{rd()}; + return rng; +} + +} // end unnamed namespace + +double StructureTransformation::rand01() +{ + static std::uniform_real_distribution<double> dist(0.0, 1.0); + return dist(globalRng()); +} + +void StructureTransformation::transformStress(Structure &st, const Matrix3d &M) +{ + Matrix3d tmp = matMul(M, st.stress); + Matrix3d Mt = M.transpose(); + st.stress = matMul(tmp, Mt); +} + +void StructureTransformation::checkNearSingular(const Structure &st) +{ + double vol = StructureProperties::getVolume(st); + if (std::fabs(vol) < 1e-14) { + throw std::runtime_error("Cell volume is near zero after transformation. Invalid operation."); + } +} + +void StructureTransformation::scale(Structure &st, double factor) +{ + if (factor <= 0.0) { + throw std::invalid_argument("Scale factor must be positive."); + } + st.cell = factor * st.cell; + for (auto &atom : st.atoms) { + atom.position = factor * atom.position; + } + // Common approach => stress is factor^2 * stress + st.stress = (factor*factor) * st.stress; + + checkNearSingular(st); +} + +void StructureTransformation::scaleToVolume(Structure &st, double targetVolume) +{ + if (targetVolume <= 0.0) { + throw std::invalid_argument("Target volume must be positive."); + } + double currentVolume = StructureProperties::getVolume(st); + if (currentVolume <= 0.0) { + throw std::runtime_error("Current volume is invalid (<= 0); cannot scale."); + } + double factor = std::cbrt(targetVolume / currentVolume); + scale(st, factor); +} + +void StructureTransformation::scaleVolumeToDensity(Structure &st, double targetDensity) +{ + if (targetDensity <= 0.0) { + throw std::invalid_argument("Target density must be positive."); + } + // total mass in grams + double totalMassAmu = 0.0; + for (auto &atom : st.atoms) { + totalMassAmu += PeriodicTable::get_mass(atom.Z); + } + double totalMassGrams = totalMassAmu * 1.66053906660e-24; // amu->g + + // volume in cm^3 => mass/density + double volumeCm3 = totalMassGrams / targetDensity; + // convert to Å^3 + double volumeA3 = volumeCm3 * 1.0e24; + + scaleToVolume(st, volumeA3); +} + +void StructureTransformation::rotate(Structure &st, const Matrix3d &R) +{ + st.cell = matMul(R, st.cell); + for (auto &atom : st.atoms) { + atom.position = R * atom.position; + } + transformStress(st, R); + checkNearSingular(st); +} + +void StructureTransformation::shear(Structure &st, const Matrix3d &shearMatrix) +{ + st.cell = matMul(shearMatrix, st.cell); + for (auto &atom : st.atoms) { + atom.position = shearMatrix * atom.position; + } + transformStress(st, shearMatrix); + checkNearSingular(st); +} + +void StructureTransformation::replicate(Structure &st, int nx, int ny, int nz) +{ + if (nx < 1) nx = 1; + if (ny < 1) ny = 1; + if (nz < 1) nz = 1; + + size_t origNatoms = st.atoms.size(); + std::vector<Atom> originalAtoms = st.atoms; + Matrix3d oldCell = st.cell; + + // Scale cell + st.cell(0,0) *= nx; + st.cell(1,1) *= ny; + st.cell(2,2) *= nz; + + // Clear and reserve memory + st.atoms.clear(); + st.atoms.reserve(origNatoms * nx * ny * nz); + + for (int ix=0; ix<nx; ix++) { + for (int iy=0; iy<ny; iy++) { + for (int iz=0; iz<nz; iz++) { + // shift vector in oldCell coords + Vec3d shiftVec = + ix*Vec3d(oldCell(0,0), oldCell(1,0), oldCell(2,0)) + + iy*Vec3d(oldCell(0,1), oldCell(1,1), oldCell(2,1)) + + iz*Vec3d(oldCell(0,2), oldCell(1,2), oldCell(2,2)); + for (const auto &atom : originalAtoms) { + Atom newAtom = atom; + newAtom.position += shiftVec; + st.atoms.push_back(newAtom); + } + } + } + } + // Stress scaling is context dependent. E.g. st.stress might remain the same per cell, + // or scale by 1/(nx*ny*nz), etc. Not explicitly changed here. + checkNearSingular(st); +} + +void StructureTransformation::shift(Structure &st, const Vec3d &shiftVec) +{ + for (auto &atom : st.atoms) { + atom.position += shiftVec; + } + // cell, stress unchanged +} + +void StructureTransformation::mirrorOrigin(Structure &st) +{ + for (auto &atom : st.atoms) { + atom.position = -atom.position; + } + // Reflection M = -I => S'=S. (No net effect if we do M*S*M^T.) +} + +void StructureTransformation::mirrorPlane(Structure &st, const Vec3d &normalVec) +{ + Vec3d n = normalVec; + double len = n.norm(); + if (len < 1e-14) { + throw std::invalid_argument("Plane normal is near zero length."); + } + n /= len; + + for (auto &atom : st.atoms) { + double proj = atom.position.dot(n); + atom.position -= 2.0 * proj * n; + } + + // M = I - 2 n n^T + Matrix3d M = Matrix3d::Identity(); + for (int i=0; i<3; i++) { + for (int j=0; j<3; j++) { + M(i,j) -= 2.0 * n[i]*n[j]; + } + } + transformStress(st, M); +} + +void StructureTransformation::mirrorHkl(Structure &st, int h, int k, int l) +{ + if ((h==0) && (k==0) && (l==0)) { + throw std::invalid_argument("Miller indices (0,0,0) do not define a plane."); + } + // Build a1,a2,a3 from st.cell columns: + Vec3d a1(st.cell(0,0), st.cell(1,0), st.cell(2,0)); + Vec3d a2(st.cell(0,1), st.cell(1,1), st.cell(2,1)); + Vec3d a3(st.cell(0,2), st.cell(1,2), st.cell(2,2)); + + Vec3d crossA2A3 = a2.cross(a3); + double denomA1 = a1.dot(crossA2A3); + Vec3d b1 = crossA2A3 / denomA1; + + Vec3d crossA3A1 = a3.cross(a1); + double denomA2 = a2.dot(crossA3A1); + Vec3d b2 = crossA3A1 / denomA2; + + Vec3d crossA1A2 = a1.cross(a2); + double denomA3 = a3.dot(crossA1A2); + Vec3d b3 = crossA1A2 / denomA3; + + Vec3d g = (double)h*b1 + (double)k*b2 + (double)l*b3; + double normG = g.norm(); + if (normG < 1e-14) { + throw std::runtime_error("Plane normal is near zero. Check cell or (h,k,l)."); + } + Vec3d n = g / normG; + + // reflection + for (auto &atom : st.atoms) { + double proj = atom.position.dot(n); + atom.position -= 2.0 * proj * n; + } + + // reflection matrix + Matrix3d M = Matrix3d::Identity(); + for (int i=0; i<3; i++) { + for (int j=0; j<3; j++) { + M(i,j) -= 2.0 * n[i]*n[j]; + } + } + transformStress(st, M); +} + +void StructureTransformation::center(Structure &st, bool useCenterOfMass, bool shiftCell) +{ + Vec3d c; + if (useCenterOfMass) { + c = StructureProperties::getCentreOfMass(st); + } else { + c = StructureProperties::getGeometricCenter(st); + } + + for (auto &atom : st.atoms) { + atom.position -= c; + } + + if (shiftCell) { + // The cell matrix typically has its origin at (0,0,0). + // Shifting it requires some reference. + // If st had an explicit 'origin' vector, we'd do e.g. st.origin -= c; + // This snippet does nothing by default because st.cell is typically + // just directions from (0,0,0). + } +} + +void StructureTransformation::alignAxes(Structure &st) +{ + // Typical approach: + // 1) G = cell^T * cell + // 2) Eigen-decompose G => principal directions + // 3) build rotation R + // 4) rotate(st, R) + // Left unimplemented here. +} + +void StructureTransformation::remapToCell(Structure &st) +{ + double vol = StructureProperties::getVolume(st); + if (std::fabs(vol) < 1e-14) { + throw std::runtime_error("Volume is near zero, cannot remap to cell."); + } + Matrix3d invCell = st.cell.inverse(); + + for (auto &atom : st.atoms) { + Vec3d frac = invCell * atom.position; + for (int i=0; i<3; i++) { + frac[i] -= std::floor(frac[i]); + } + atom.position = st.cell * frac; + } +} + +void StructureTransformation::addAtom( + Structure &st, + const Atom &atomTemplate, + bool useRandomPosition, + const Vec3d &pos, + double minDistance +){ + Atom newAtom = atomTemplate; + + if (useRandomPosition) { + // place randomly in [0,1) fractional, then convert => Cartesian + Matrix3d invCell = st.cell.inverse(); + double vol = StructureProperties::getVolume(st); + if (std::fabs(vol) < 1e-14) { + throw std::runtime_error("Cannot add atom randomly: volume near zero."); + } + + Vec3d frac(rand01(), rand01(), rand01()); + newAtom.position = st.cell * frac; + } else { + newAtom.position = pos; + } + + // Check distance to existing atoms + for (const auto &a : st.atoms) { + double dist = (newAtom.position - a.position).norm(); + if (dist < minDistance) { + std::cerr << "[WARNING] New atom is only " << dist << " Å away from an existing atom. " + << "Minimum suggested distance: " << minDistance << " Å.\n"; + break; + } + } + + st.atoms.push_back(newAtom); +} + +void StructureTransformation::removeAtom(Structure &st, size_t index) +{ + if (index < st.atoms.size()) { + st.atoms.erase(st.atoms.begin() + index); + } else { + std::cerr << "[WARNING] removeAtom: index out of range (" << index << "). No atom removed.\n"; + } +} + +void StructureTransformation::shakeAtoms(Structure &st, double maxDisplacement) +{ + if (maxDisplacement <= 0.0) { + std::cerr << "[WARNING] shakeAtoms: non-positive maxDisplacement => no action.\n"; + return; + } + + for (auto &atom : st.atoms) { + // random in [-maxDisplacement, +maxDisplacement] + double rx = (2.0*rand01() - 1.0) * maxDisplacement; + double ry = (2.0*rand01() - 1.0) * maxDisplacement; + double rz = (2.0*rand01() - 1.0) * maxDisplacement; + atom.position[0] += rx; + atom.position[1] += ry; + atom.position[2] += rz; + } +} + diff --git a/tests/test_structure_properties.cpp b/tests/test_structure_properties.cpp new file mode 100644 index 0000000..bddbfaf --- /dev/null +++ b/tests/test_structure_properties.cpp @@ -0,0 +1,155 @@ +#include "catch2/catch.hpp" +#include <tadah/mlip/structure.h> +#include <tadah/mlip/structure_properties.h> +#include <tadah/mlip/atom.h> +#include <tadah/core/periodic_table.h> + +// Helper to create an element from symbol +static Element makeElement(const std::string& sym) { + return PeriodicTable::find_by_symbol(sym); +} + +TEST_CASE("StructureProperties: getVolume() with default/zeroed cell", "[structure_props][new]") { + Structure st; + // By default, st.cell is zero-initialized -> volume = 0 + REQUIRE(StructureProperties::getVolume(st) == Approx(0.0)); +} + +TEST_CASE("StructureProperties: getVolume() with known cell", "[structure_props][new]") { + Structure st; + + // Let's define a simple 3x3 identity cell => volume = 1 Å^3 + st.cell(0, 0) = 1.0; st.cell(0, 1) = 0.0; st.cell(0, 2) = 0.0; + st.cell(1, 0) = 0.0; st.cell(1, 1) = 1.0; st.cell(1, 2) = 0.0; + st.cell(2, 0) = 0.0; st.cell(2, 1) = 0.0; st.cell(2, 2) = 1.0; + + double vol = StructureProperties::getVolume(st); + REQUIRE(vol == Approx(1.0)); +} + +TEST_CASE("StructureProperties: getDensity() with zero-volume cell", "[structure_props][new]") { + Structure st; + // By default, st.cell is zero => volume is 0 => getDensity returns 0.0 + // We also add a few atoms to ensure it doesn't crash. + Atom a1(makeElement("H"), 0.0, 0.0, 0.0,0,0,0); + Atom a2(makeElement("C"), 1.0, 0.0, 0.0,0,0,0); + st.add_atom(a1); + st.add_atom(a2); + + double rho = StructureProperties::getDensity(st); + REQUIRE(rho == Approx(0.0)); // The code is written to return 0 on <=0 volume +} + +TEST_CASE("StructureProperties: getDensity() with known cell and atoms", "[structure_props][new]") { + // We'll define a 10Å x 10Å x 10Å cubic cell => volume = 1000 Å^3 + // Add 2 atoms: H (1 amu) and He (4 amu). Total = ~5 amu + // volume in cm^3 = 1000 * 1e-24 = 1e-21 cm^3 + // mass in grams = 5 amu * 1.66053906660e-24 g/amu = ~8.3026953e-24 g + // => density = (8.3026953e-24 g) / (1e-21 cm^3) = 8.3026953e-03 g/cm^3 + // => 0.008320297 + Structure st; + + // 10x10x10 cubic + st.cell(0,0) = 10.0; st.cell(0,1) = 0.0; st.cell(0,2) = 0.0; + st.cell(1,0) = 0.0; st.cell(1,1) = 10.0; st.cell(1,2) = 0.0; + st.cell(2,0) = 0.0; st.cell(2,1) = 0.0; st.cell(2,2) = 10.0; + + Atom h(makeElement("H"), 0.0, 0.0, 0.0,0,0,0); + Atom he(makeElement("He"), 5.0, 5.0, 5.0,0,0,0); + st.add_atom(h); + st.add_atom(he); + + double rho = StructureProperties::getDensity(st); + REQUIRE(rho == Approx(0.008320297).margin(1e-7)); +} + +TEST_CASE("StructureProperties: getCentreOfMass() with no atoms", "[structure_props][new]") { + Structure st; + // No atoms => centre = (0,0,0) by definition + Vec3d com = StructureProperties::getCentreOfMass(st); + REQUIRE(com[0] == Approx(0.0)); + REQUIRE(com[1] == Approx(0.0)); + REQUIRE(com[2] == Approx(0.0)); +} + +TEST_CASE("StructureProperties: getCentreOfMass() with single atom", "[structure_props][new]") { + // If one atom at (1,2,3), CoM should be (1,2,3) + Structure st; + Atom a(makeElement("C"), 1.0, 2.0, 3.0,0,0,0); // Carbon, but mass won't matter for single-atom + st.add_atom(a); + + Vec3d com = StructureProperties::getCentreOfMass(st); + REQUIRE(com[0] == Approx(1.0)); + REQUIRE(com[1] == Approx(2.0)); + REQUIRE(com[2] == Approx(3.0)); +} + +TEST_CASE("StructureProperties: getCentreOfMass() with multiple atoms", "[structure_props][new]") { + // Suppose we place 2 identical atoms of mass M at: + // ( 2, 0, 0 ) and ( -2, 0, 0 ) + // Then CoM = ( (M*(2) + M*(-2))/ (2M), 0, 0 ) = (0,0,0). + // We'll pick O for both, but in practice same mass anyway. + Structure st; + st.cell.set_zero(); // not used for CoM + Atom O1(makeElement("O"), 2.0, 0.0, 0.0,0,0,0); + Atom O2(makeElement("O"), -2.0, 0.0, 0.0,0,0,0); + st.add_atom(O1); + st.add_atom(O2); + + Vec3d com = StructureProperties::getCentreOfMass(st); + REQUIRE(com[0] == Approx(0.0)); + REQUIRE(com[1] == Approx(0.0)); + REQUIRE(com[2] == Approx(0.0)); +} + +TEST_CASE("StructureProperties: getCentreOfMass() with different masses", "[structure_props][new]") { + // Let's put H (mass ~1) at (0,0,0) and O (mass ~16) at (0, 3, 0) + // total mass ~17 + // Weighted sum of positions = 1*(0,0,0) + 16*(0,3,0) = (0,48,0) + // => CoM = (0,48,0)/17 ~ (0,2.823529...,0) + Structure st; + + Atom H(makeElement("H"), 0.0, 0.0, 0.0,0,0,0); + Atom O(makeElement("O"), 0.0, 3.0, 0.0,0,0,0); + st.add_atom(H); + st.add_atom(O); + + Vec3d com = StructureProperties::getCentreOfMass(st); + // 48 / 17 ~ 2.823529... + REQUIRE(com[0] == Approx(0.0)); + REQUIRE(com[1] == Approx(2.8221908626).margin(1e-8)); + REQUIRE(com[2] == Approx(0.0)); +} + +TEST_CASE("StructureProperties getGeometricCenter() vs. getCentreOfMass()", "[properties][new]") +{ + // We'll build a small structure with 2 atoms: + // H(0,0,0) => mass ~1.008 + // He(1,1,1) => mass ~4.0026 + // The geometric center => (0.5, 0.5, 0.5). + // The mass-weighted center => ~ ( (4.0026*1,1,1) + (1.008*0,0,0 ) ) / 5.0106 => ~ (0.80,0.80,0.80). + Structure st; + st.cell.set_zero(); + st.cell(0,0) = 1.0; + st.cell(1,1) = 1.0; + st.cell(2,2) = 1.0; + + Atom h(PeriodicTable::find_by_symbol("H"), 0.0, 0.0, 0.0); + Atom he(PeriodicTable::find_by_symbol("He"), 1.0, 1.0, 1.0); + st.add_atom(h); + st.add_atom(he); + + Vec3d geom = StructureProperties::getGeometricCenter(st); + Vec3d com = StructureProperties::getCentreOfMass(st); + + // Geometric center => (0.5,0.5,0.5) + REQUIRE(geom[0] == Approx(0.5)); + REQUIRE(geom[1] == Approx(0.5)); + REQUIRE(geom[2] == Approx(0.5)); + + // COM => near (0.80,0.80,0.80) + REQUIRE(com[0] == Approx(0.80).margin(0.03)); + REQUIRE(com[1] == Approx(0.80).margin(0.03)); + REQUIRE(com[2] == Approx(0.80).margin(0.03)); +} + diff --git a/tests/test_structure_transformations.cpp b/tests/test_structure_transformations.cpp new file mode 100644 index 0000000..b147320 --- /dev/null +++ b/tests/test_structure_transformations.cpp @@ -0,0 +1,493 @@ +// test_structure_transformation.cpp +#include "catch2/catch.hpp" +#include <tadah/mlip/structure.h> +#include <tadah/mlip/structure_transformations.h> +#include <tadah/mlip/structure_properties.h> +#include <tadah/mlip/atom.h> +#include <tadah/core/periodic_table.h> +#include <cmath> +#include <stdexcept> +#include <vector> +#include <iostream> + +// Helper function to build a simple 2×2×2 box with two atoms. +static Structure makeSimpleStructure() +{ + // 2x2x2 cell => volume 8 Å^3 + Structure st; + st.cell.set_zero(); + st.cell(0,0) = 2.0; + st.cell(1,1) = 2.0; + st.cell(2,2) = 2.0; + st.stress.set_zero(); + st.label = "Simple"; + + // Add H at (0,0,0) and He at (1,1,1) + Atom h(PeriodicTable::find_by_symbol("H"), 0.0, 0.0, 0.0); + Atom he(PeriodicTable::find_by_symbol("He"), 1.0, 1.0, 1.0); + st.add_atom(h); + st.add_atom(he); + + return st; +} + +TEST_CASE("StructureTransformation scale()", "[transform]") +{ + Structure st = makeSimpleStructure(); + + SECTION("Invalid scale factor throws std::invalid_argument") { + REQUIRE_THROWS_AS(StructureTransformation::scale(st, 0.0), std::invalid_argument); + REQUIRE_THROWS_AS(StructureTransformation::scale(st, -1.0), std::invalid_argument); + } + + SECTION("Valid scale factor modifies positions, cell, and stress") { + st.stress(0,0) = 1.0; + st.stress(1,1) = 1.5; + + double oldVol = StructureProperties::getVolume(st); + REQUIRE(oldVol == Approx(8.0)); + + StructureTransformation::scale(st, 2.0); + + double newVol = StructureProperties::getVolume(st); + REQUIRE(newVol == Approx(64.0)); + + // Positions doubled + REQUIRE(st.atoms[0].position[0] == Approx(0.0)); + REQUIRE(st.atoms[1].position[0] == Approx(2.0)); + + // Stress ~ factor^2 => 4× for factor=2 + REQUIRE(st.stress(0,0) == Approx(4.0)); + REQUIRE(st.stress(1,1) == Approx(6.0)); + } +} + +TEST_CASE("StructureTransformation scaleToVolume()", "[transform]") +{ + Structure st = makeSimpleStructure(); + double initVol = StructureProperties::getVolume(st); + REQUIRE(initVol == Approx(8.0)); + + SECTION("Throws if target volume <= 0") { + REQUIRE_THROWS_AS(StructureTransformation::scaleToVolume(st, 0.0), std::invalid_argument); + REQUIRE_THROWS_AS(StructureTransformation::scaleToVolume(st, -10.0), std::invalid_argument); + } + + SECTION("Scales to new volume") { + double targetVol = 125.0; // 5×5×5 + StructureTransformation::scaleToVolume(st, targetVol); + double newVol = StructureProperties::getVolume(st); + REQUIRE(newVol == Approx(targetVol).margin(1e-10)); + } +} + +TEST_CASE("StructureTransformation scaleVolumeToDensity()", "[transform]") +{ + // This structure has 2 atoms: H (~1.008 amu), He (~4.0026 amu). + // total mass ~5.0106 amu => ~8.31567e-24 g + // volume=8.0 Å^3 => ~8.0e-24 cm^3 => init density ~1.04 g/cm^3 + Structure st = makeSimpleStructure(); + + SECTION("Throws if target density <= 0") { + REQUIRE_THROWS_AS(StructureTransformation::scaleVolumeToDensity(st, 0.0), std::invalid_argument); + REQUIRE_THROWS_AS(StructureTransformation::scaleVolumeToDensity(st, -1.0), std::invalid_argument); + } + + SECTION("Scales to a target density") { + double targetDensity = 2.0; // g/cm^3 + StructureTransformation::scaleVolumeToDensity(st, targetDensity); + double newDensity = StructureProperties::getDensity(st); + REQUIRE(newDensity == Approx(targetDensity).margin(1e-3)); + } +} + +TEST_CASE("StructureTransformation rotate()", "[transform]") +{ + Structure st = makeSimpleStructure(); + st.stress(0,0) = 1.0; + st.stress(1,1) = 2.0; + st.stress(2,2) = 3.0; + + // 90° rotation about z-axis + Matrix3d R; + R.set_zero(); + R(0,1) = -1.0; + R(1,0) = 1.0; + R(2,2) = 1.0; + + StructureTransformation::rotate(st, R); + + // positions: H(0,0,0)->(0,0,0), He(1,1,1)->(-1,1,1) + REQUIRE(st.atoms[1].position[0] == Approx(-1.0)); + REQUIRE(st.atoms[1].position[1] == Approx(1.0)); + REQUIRE(st.atoms[1].position[2] == Approx(1.0)); + + // Volume unchanged + double vol = StructureProperties::getVolume(st); + REQUIRE(vol == Approx(8.0)); +} + +TEST_CASE("StructureTransformation shear()", "[transform]") +{ + Structure st = makeSimpleStructure(); + // Shear matrix => x' = x + 0.5y + Matrix3d shearMat; + shearMat.set_zero(); + shearMat(0,0) = 1.0; + shearMat(1,1) = 1.0; + shearMat(2,2) = 1.0; + shearMat(0,1) = 0.5; + + double oldVol = StructureProperties::getVolume(st); + StructureTransformation::shear(st, shearMat); + + double newVol = StructureProperties::getVolume(st); + REQUIRE(oldVol == Approx(8.0)); + REQUIRE(newVol == Approx(8.0)); +} + +TEST_CASE("StructureTransformation replicate()", "[transform]") +{ + Structure st = makeSimpleStructure(); + REQUIRE(st.natoms() == 2); + + // replicate(2,1,1) + StructureTransformation::replicate(st, 2,1,1); + REQUIRE(st.natoms() == 4); + + double vol = StructureProperties::getVolume(st); + REQUIRE(vol == Approx(16.0)); // 4×2×2 => 16 + + REQUIRE(st.cell(0,0) == Approx(4.0)); + REQUIRE(st.cell(1,1) == Approx(2.0)); + REQUIRE(st.cell(2,2) == Approx(2.0)); +} + +TEST_CASE("StructureTransformation shift()", "[transform]") +{ + Structure st = makeSimpleStructure(); + Vec3d shiftVec(1.0, -2.0, 0.5); + + StructureTransformation::shift(st, shiftVec); + REQUIRE(st.atoms[0].position[0] == Approx(1.0)); + REQUIRE(st.atoms[1].position[1] == Approx(-1.0)); +} + +TEST_CASE("StructureTransformation mirrorOrigin()", "[transform]") +{ + Structure st = makeSimpleStructure(); + StructureTransformation::mirrorOrigin(st); + // H(0,0,0)->(0,0,0), He(1,1,1)->(-1,-1,-1) + REQUIRE(st.atoms[1].position[0] == Approx(-1.0)); + REQUIRE(st.atoms[1].position[1] == Approx(-1.0)); + REQUIRE(st.atoms[1].position[2] == Approx(-1.0)); +} + +TEST_CASE("StructureTransformation mirrorPlane()", "[transform]") +{ + // Mirror about plane normal to x => x-> -x + Structure st = makeSimpleStructure(); + Vec3d n(1.0, 0.0, 0.0); + StructureTransformation::mirrorPlane(st, n); + REQUIRE(st.atoms[1].position[0] == Approx(-1.0)); + REQUIRE(st.atoms[1].position[1] == Approx(1.0)); +} + +TEST_CASE("StructureTransformation mirrorHkl()", "[transform]") +{ + // Mirror about (1,0,0) => effectively x-> -x for a simple diagonal cell + Structure st = makeSimpleStructure(); + SECTION("Invalid (0,0,0) => throw") { + REQUIRE_THROWS_AS(StructureTransformation::mirrorHkl(st, 0,0,0), std::invalid_argument); + } + SECTION("Mirror along (1,0,0)") { + StructureTransformation::mirrorHkl(st, 1,0,0); + REQUIRE(st.atoms[1].position[0] == Approx(-1.0)); + } +} + +TEST_CASE("StructureTransformation alignAxes()", "[transform]") +{ + // Stub: no effect + Structure st = makeSimpleStructure(); + auto oldPos = st.atoms[1].position; + REQUIRE_NOTHROW(StructureTransformation::alignAxes(st)); + REQUIRE(st.atoms[1].position[0] == Approx(oldPos[0])); +} + +TEST_CASE("StructureTransformation remapToCell()", "[transform]") +{ + Structure st = makeSimpleStructure(); + // place He outside => (2.5, -0.5, 3.0) + st.atoms[1].position = Vec3d(2.5, -0.5, 3.0); + + StructureTransformation::remapToCell(st); + // expected => (0.5,1.5,1.0) + REQUIRE(st.atoms[1].position[0] == Approx(0.5)); + REQUIRE(st.atoms[1].position[1] == Approx(1.5)); + REQUIRE(st.atoms[1].position[2] == Approx(1.0)); +} + +TEST_CASE("StructureTransformation addAtom()", "[transform]") +{ + Structure st = makeSimpleStructure(); + REQUIRE(st.natoms() == 2); + + SECTION("Add atom at fixed position") { + Atom c(PeriodicTable::find_by_symbol("C"), 0,0,0); + StructureTransformation::addAtom(st, c, false, Vec3d(3.0,4.0,5.0)); + REQUIRE(st.natoms() == 3); + REQUIRE(st.atoms.back().position[0] == Approx(3.0)); + REQUIRE(st.atoms.back().Z == 6); // Carbon + } + + SECTION("Add atom randomly") { + Atom c(PeriodicTable::find_by_symbol("C"), 0,0,0); + StructureTransformation::addAtom(st, c, true); + auto &pos = st.atoms.back().position; + // Can't be exact, but must lie within [0,2] + REQUIRE(pos[0] >= 0.0); REQUIRE(pos[0] < 2.0); + REQUIRE(pos[1] >= 0.0); REQUIRE(pos[1] < 2.0); + REQUIRE(pos[2] >= 0.0); REQUIRE(pos[2] < 2.0); + } +} + +TEST_CASE("StructureTransformation removeAtom()", "[transform]") +{ + Structure st = makeSimpleStructure(); + REQUIRE(st.natoms() == 2); + + SECTION("Valid index") { + StructureTransformation::removeAtom(st, 1); + REQUIRE(st.natoms() == 1); + } + + SECTION("Invalid index => no throw, prints warning") { + REQUIRE_NOTHROW(StructureTransformation::removeAtom(st, 999)); + REQUIRE(st.natoms() == 2); + } +} + +TEST_CASE("StructureTransformation shakeAtoms()", "[transform]") +{ + Structure st = makeSimpleStructure(); + + SECTION("Non-positive maxDisplacement => no action") { + StructureTransformation::shakeAtoms(st, 0.0); + REQUIRE(st.atoms[0].position[0] == Approx(0.0)); + REQUIRE(st.atoms[1].position[0] == Approx(1.0)); + } + + SECTION("Positive displacement => positions change randomly") { + StructureTransformation::shakeAtoms(st, 0.5); + // Each coordinate of each atom should be within ±0.5 of original + for (size_t i=0; i<st.natoms(); i++) { + for (int d=0; d<3; d++) { + double val = st.atoms[i].position[d]; + double expectedMin = double(i) - 0.5; + double expectedMax = double(i) + 0.5; + REQUIRE(val >= Approx(expectedMin).margin(0.001)); + REQUIRE(val <= Approx(expectedMax).margin(0.001)); + } + } + } +} + +// Helper to create a 3-atom structure in a non-cubic cell for more robust checking. +static Structure makeTrioStructure() +{ + // Create a 3×3×2 cell => volume = 3*3*2 = 18 Å^3 + // Place 3 atoms in somewhat triangular arrangement. + + Structure st; + st.cell.set_zero(); + st.cell(0,0) = 3.0; // x-axis + st.cell(1,1) = 3.0; // y-axis + st.cell(2,2) = 2.0; // z-axis + st.stress.set_zero(); + st.label = "TrioBox"; + + // Suppose we have H at (0.0, 0.0, 0.0), + // He at (1.5, 1.5, 1.0), + // C at (2.9, 2.5, 0.5) + // At least one of them is near the boundary to test partial wrapping or expansions. + Atom h(PeriodicTable::find_by_symbol("H"), 0.0, 0.0, 0.0); + Atom he(PeriodicTable::find_by_symbol("He"), 1.5, 1.5, 1.0); + Atom c (PeriodicTable::find_by_symbol("C"), 2.9, 2.5, 0.5); + + st.add_atom(h); + st.add_atom(he); + st.add_atom(c); + + return st; +} + +TEST_CASE("StructureTransformation replicate() more complex", "[transform][extended]") +{ + // We replicate a 3-atom system in a 2×2×1 manner => 3×2×2×1 = 12 atoms total + Structure st = makeTrioStructure(); + REQUIRE(st.natoms() == 3); + + double oldVolume = StructureProperties::getVolume(st); + REQUIRE(oldVolume == Approx(18.0).margin(1e-12)); + + // replicate(2,2,1) + StructureTransformation::replicate(st, 2, 2, 1); + REQUIRE(st.natoms() == 3 * 2 * 2 * 1); + + double newVolume = StructureProperties::getVolume(st); + // The original cell is 3×3×2 => replicate in x=2 => new x=6, in y=2 => new y=6, z unchanged => new z=2 + // => 6×6×2 = 72 + REQUIRE(newVolume == Approx(72.0).margin(1e-12)); +} + +TEST_CASE("StructureTransformation remapToCell() wide out-of-box atom", "[transform][extended]") +{ + // Place an atom far outside the box in negative/positive directions + Structure st = makeTrioStructure(); + // Insert a new atom at e.g. (-5, 10, 4). This is well outside the 3×3×2 cell. + Atom extra(PeriodicTable::find_by_symbol("O"), -5.0, 10.0, 4.0); + st.add_atom(extra); + REQUIRE(st.natoms() == 4); + + // Now remap + StructureTransformation::remapToCell(st); + + // The cell is 3 in x, 3 in y, 2 in z + // We'll convert fractional => xFrac = (-5 / 3) => -1.666..., wrap => 0.333... => final => ~1.0 + // yFrac = 10/3 => 3.333..., wrap => 0.333..., => final => ~1.0 + // zFrac = 4/2 => 2.0 => wrap => 0.0 => final => 0.0 + // => final ~ (1.0, 1.0, 0.0) + Vec3d pos = st.atoms.back().position; + REQUIRE(pos[0] == Approx(1.0).margin(1e-6)); + REQUIRE(pos[1] == Approx(1.0).margin(1e-6)); + REQUIRE(pos[2] == Approx(0.0).margin(1e-6)); +} + +TEST_CASE("StructureTransformation mirrorHkl() with (1,1,0) plane", "[transform][extended]") +{ + // We'll test reflection about a plane that is normal to (1,1,0) in reciprocal-lattice space. + // For our trio cell => a1=(3,0,0), a2=(0,3,0), a3=(0,0,2) + // (h,k,l)=(1,1,0) => normal is basically along the sum of b1+b2 => x+y direction in direct space. + Structure st = makeTrioStructure(); + REQUIRE_NOTHROW(StructureTransformation::mirrorHkl(st, 1,1,0)); + + // We can't easily do a direct numeric check unless we replicate the reflection math, + // but we can at least check volume remains the same magnitude (reflection => det < 0, but |det| is same). + double vol = StructureProperties::getVolume(st); + REQUIRE(std::fabs(vol) == Approx(18.0).margin(1e-12)); + + // Also confirm that the third atom (near x=2.9,y=2.5) now has x or y negated in some consistent reflection sense. + // We'll do a sanity check that the new positions differ from old by more than some threshold + const auto &a2 = st.atoms[2].position; + // It's presumably "mirrored" in a plane with normal ~ (1/sqrt(2),1/sqrt(2),0). + // We won't match exact numbers, just confirm it's not the same as original 2.9,2.5,0.5 + REQUIRE(std::fabs(a2[0] - 2.9) > 0.1); + REQUIRE(std::fabs(a2[1] - 2.5) > 0.1); +} + +TEST_CASE("StructureTransformation shakeAtoms() large triple-atom structure", "[transform][extended]") +{ + // Similar to existing test, but we do it for a 3-atom structure with bigger displacement + Structure st = makeTrioStructure(); + double maxDisp = 1.0; // more significant shake + + StructureTransformation::shakeAtoms(st, maxDisp); + + // Each original atom was at (0,0,0), (1.5,1.5,1.0), (2.9,2.5,0.5). + // Now each coordinate is in [orig - 1.0, orig + 1.0]. + for (size_t i=0; i < st.natoms(); ++i) { + const Vec3d &pos = st.atoms[i].position; + // We'll do a quick bounding check + // e.g., for the second atom's x=1.5 => must be in [0.5, 2.5] + // We'll just do a generic approach using the original positions from makeTrioStructure() + // For each dimension: + // low = orig - 1.0 + // high = orig + 1.0 + // We'll store them in an array for quick reference: + } + + // Let's re-implement the bounding logic more explicitly: + std::vector<Vec3d> originalPositions{ + {0.0, 0.0, 0.0}, + {1.5, 1.5, 1.0}, + {2.9, 2.5, 0.5} + }; + for (size_t i = 0; i < st.natoms(); i++) { + double ox = originalPositions[i][0]; + double oy = originalPositions[i][1]; + double oz = originalPositions[i][2]; + + Vec3d &curr = st.atoms[i].position; + // Check if inside [ox - 1.0, ox + 1.0], etc. + REQUIRE(curr[0] >= Approx(ox - 1.0).margin(1e-7)); + REQUIRE(curr[0] <= Approx(ox + 1.0).margin(1e-7)); + REQUIRE(curr[1] >= Approx(oy - 1.0).margin(1e-7)); + REQUIRE(curr[1] <= Approx(oy + 1.0).margin(1e-7)); + REQUIRE(curr[2] >= Approx(oz - 1.0).margin(1e-7)); + REQUIRE(curr[2] <= Approx(oz + 1.0).margin(1e-7)); + } +} + +TEST_CASE("StructureTransformation addAtom() checks minDistance in a 3-atom structure", "[transform][extended]") +{ + Structure st = makeTrioStructure(); + REQUIRE(st.natoms() == 3); + + // We'll attempt to add a new H atom near the second atom => we expect a warning + // The code won't throw, but we can verify it doesn't fail catastrophically. + // We'll do some extremely close position => e.g. (1.45,1.45,1.05) + Atom newH(PeriodicTable::find_by_symbol("H"), 0.0,0.0,0.0); // template + Vec3d closePos(1.45, 1.45, 1.05); + REQUIRE_NOTHROW(StructureTransformation::addAtom(st, newH, false, closePos, 0.5)); + // We expect a console warning but no exception. Now there's 4 atoms + REQUIRE(st.natoms() == 4); + + // The newly added atom's position is set to closePos + const auto &lastPos = st.atoms.back().position; + REQUIRE(lastPos[0] == Approx(1.45).margin(1e-6)); + REQUIRE(lastPos[1] == Approx(1.45).margin(1e-6)); + REQUIRE(lastPos[2] == Approx(1.05).margin(1e-6)); +} + +TEST_CASE("StructureTransformation::center() now calls either getCentreOfMass or getGeometricCenter", "[transform][new]") +{ + // We'll do a 2-atom system again, to illustrate difference. + Structure st; + st.cell.set_zero(); + st.cell(0,0) = 1.0; + st.cell(1,1) = 1.0; + st.cell(2,2) = 1.0; + Atom h(PeriodicTable::find_by_symbol("H"), 0.0, 0.0, 0.0); + Atom he(PeriodicTable::find_by_symbol("He"), 1.0, 1.0, 1.0); + st.add_atom(h); + st.add_atom(he); + + SECTION("Center by geometric center => He => (0.5,0.5,0.5)") { + StructureTransformation::center(st, /*useCenterOfMass=*/false, /*shiftCell=*/false); + // Atom 0 => old(0,0,0) - (0.5,0.5,0.5) => (-0.5,-0.5,-0.5) + // Atom 1 => old(1,1,1) - (0.5,0.5,0.5) => (0.5,0.5,0.5) + REQUIRE(st.atoms[0].position[0] == Approx(-0.5)); + REQUIRE(st.atoms[1].position[0] == Approx(0.5)); + } + + SECTION("Center by center of mass => He => ~ (0.2,0.2,0.2)") { + // We'll create a fresh structure since the previous one is offset + Structure st2; + st2.cell.set_zero(); + st2.cell(0,0) = 1.0; + st2.cell(1,1) = 1.0; + st2.cell(2,2) = 1.0; + st2.add_atom(h); + st2.add_atom(he); + + StructureTransformation::center(st2, /*useCenterOfMass=*/true, /*shiftCell=*/false); + + // The heavier He is ~ 4x the mass => new He pos = old(1,1,1) - COM => Should be around (0.2,0.2,0.2) + // Let's check the actual final position: + Vec3d hePos = st2.atoms[1].position; + REQUIRE(hePos[0] == Approx(0.2).margin(0.1)); + REQUIRE(hePos[1] == Approx(0.2).margin(0.1)); + REQUIRE(hePos[2] == Approx(0.2).margin(0.1)); + } +} -- GitLab