Skip to content
Snippets Groups Projects
Commit e5df1fcb authored by mkirsz's avatar mkirsz
Browse files

added bitmap for checking is calc initalised for atoms, fixed return to...

added bitmap for checking is calc initalised for atoms, fixed return to continue as it was a clear bug
parent 04d8394c
No related branches found
No related tags found
1 merge request!15added bitmap for checking is calc initalised for atoms, fixed return to...
Pipeline #49815 passed
Pipeline: MD

#49817

    Pipeline: MLIP

    #49816

      ......@@ -8,6 +8,7 @@
      #include <tadah/core/registry.h>
      #include <tadah/models/cut_all.h>
      #include <vector>
      #include <bitset>
      /** \brief Base class for all descriptor types.
      *
      ......@@ -30,9 +31,6 @@ public:
      // if key is not used, returns -1
      int get_arg_pos(const std::string &key) const;
      /** \brief Return dimension of the descriptor.
      */
      virtual size_t size() { return s; };
      /** \brief Return label of this descriptor.
      */
      ......@@ -60,10 +58,19 @@ public:
      virtual void set_fcut(Cut_Base* cut, bool manage_memory);
      double get_rcut() {
      return fcut->get_rcut(); }
      double get_rcut();
      /** \brief Return dimension of the descriptor.
      */
      virtual size_t size();
      virtual bool is_init_for_atom(int Z);
      virtual bool is_init_for_atoms(int Zi, int Zj);
      virtual void init_for_atom(int Zi);
      virtual void uninit_for_atom(int Zi);
      private:
      bool manage_memory=false; // who owns fcut
      std::bitset<119> init_for_atoms;
      };
      #endif
      ......@@ -22,7 +22,8 @@ D2_mJoin::D2_mJoin(Config &c) : D2_Base(c) {
      void D2_mJoin::calc_aed(const int Zi, const int Zj, const double rij,
      const double rij_sq, aed_type &aed, const double scale) {
      for (auto d : ds) {
      if (rij > d->get_rcut()) return;
      if (!d->is_init_for_atoms(Zi,Zj)) continue;
      if (rij > d->get_rcut()) continue;
      d->calc_aed(Zi, Zj, rij, rij_sq, aed, scale);
      }
      }
      ......@@ -30,7 +31,8 @@ void D2_mJoin::calc_aed(const int Zi, const int Zj, const double rij,
      void D2_mJoin::calc_dXijdri(const int Zi, const int Zj, const double rij,
      const double rij_sq, fd_type &fd_ij, const double scale) {
      for (auto d : ds) {
      if (rij > d->get_rcut()) return;
      if (!d->is_init_for_atoms(Zi,Zj)) continue;
      if (rij > d->get_rcut()) continue;
      d->calc_dXijdri(Zi, Zj, rij, rij_sq, fd_ij, scale);
      }
      }
      ......@@ -38,7 +40,8 @@ void D2_mJoin::calc_dXijdri(const int Zi, const int Zj, const double rij,
      void D2_mJoin::calc_all(const int Zi, const int Zj, const double rij,
      const double rij_sq, aed_type &aed, fd_type &fd_ij, const double scale) {
      for (auto d : ds) {
      if (rij > d->get_rcut()) return;
      if (!d->is_init_for_atoms(Zi,Zj)) continue;
      if (rij > d->get_rcut()) continue;
      d->calc_all(Zi, Zj, rij, rij_sq, aed, fd_ij, scale);
      }
      }
      ......
      ......@@ -79,6 +79,7 @@ D_Base::D_Base(Config &c):
      double wi = c.get<double>("WATOMS",i);
      int Z = PeriodicTable::find_by_symbol(symbol).Z;
      weights[Z]=wi;
      init_for_atom(Z);
      }
      }
      }
      ......@@ -98,3 +99,19 @@ int D_Base::get_arg_pos(const std::string &key) const {
      return std::distance(keys.begin(), it)+(nparams-keys.size());
      }
      size_t D_Base::size() { return s; };
      double D_Base::get_rcut() {
      return fcut->get_rcut();
      }
      bool D_Base::is_init_for_atom(int Z) {
      return init_for_atoms[Z];
      }
      bool D_Base::is_init_for_atoms(int Zi, int Zj) {
      return init_for_atoms[Zi] && init_for_atoms[Zj];
      }
      void D_Base::init_for_atom(int Z) {
      init_for_atoms.set(Z);
      }
      void D_Base::uninit_for_atom(int Z) {
      init_for_atoms.reset(Z);
      }
      ......@@ -41,11 +41,12 @@ void DM_mJoin::calc_dXijdri_dXjidri(
      const double scale) {
      size_t rho_fidx = 0;
      for (auto d : ds) {
      if (rij > d->get_rcut()) return;
      size_t rho_size = 2 * d->rhoi_size();
      rho_type rhoi_ptr(&rhoi[rho_fidx],rho_size);
      rho_type rhoj_ptr(&rhoj[rho_fidx],rho_size);
      d->calc_dXijdri_dXjidri(Zi,Zj,rij,rij_sq,vec_ij,rhoi_ptr,rhoj_ptr,fd_ij,scale);
      if (d->is_init_for_atoms(Zi,Zj) && rij < d->get_rcut()) {
      rho_type rhoi_ptr(&rhoi[rho_fidx],rho_size);
      rho_type rhoj_ptr(&rhoj[rho_fidx],rho_size);
      d->calc_dXijdri_dXjidri(Zi,Zj,rij,rij_sq,vec_ij,rhoi_ptr,rhoj_ptr,fd_ij,scale);
      }
      rho_fidx += rho_size;
      }
      }
      ......@@ -60,10 +61,11 @@ void DM_mJoin::calc_dXijdri(
      const double scale) {
      size_t rho_fidx = 0;
      for (auto d : ds) {
      if (rij > d->get_rcut()) return;
      size_t rho_size = 2 * d->rhoi_size();
      rho_type rhoi_ptr(&rhoi[rho_fidx],rho_size);
      d->calc_dXijdri(Zi,Zj,rij,rij_sq,vec_ij,rhoi_ptr,fd_ij,scale);
      if (d->is_init_for_atoms(Zi,Zj) && rij < d->get_rcut()) {
      rho_type rhoi_ptr(&rhoi[rho_fidx],rho_size);
      d->calc_dXijdri(Zi,Zj,rij,rij_sq,vec_ij,rhoi_ptr,fd_ij,scale);
      }
      rho_fidx += rho_size;
      }
      }
      ......@@ -81,10 +83,11 @@ void DM_mJoin::calc_rho(
      const double scale) {
      size_t rho_fidx = 0;
      for (auto d : ds) {
      if (rij > d->get_rcut()) return;
      size_t rho_size = 2 * d->rhoi_size();
      rho_type rho_ptr(&rho[rho_fidx],rho_size);
      d->calc_rho(Zi,Zj,rij,rij_sq,vec_ij,rho_ptr,scale);
      if (d->is_init_for_atoms(Zi,Zj) && rij < d->get_rcut()) {
      rho_type rho_ptr(&rho[rho_fidx],rho_size);
      d->calc_rho(Zi,Zj,rij,rij_sq,vec_ij,rho_ptr,scale);
      }
      rho_fidx += rho_size;
      }
      }
      ......
      0% Loading or .
      You are about to add 0 people to the discussion. Proceed with caution.
      Finish editing this message first!
      Please register or to comment