From 735f49670fd6f2505b3b16de795c48950d621c34 Mon Sep 17 00:00:00 2001
From: sgrover <shivani.grover@ed.ac.uk>
Date: Mon, 4 Nov 2024 16:26:23 +0000
Subject: [PATCH] Upload Workshop notebook

---
 Workshop3/molcal_workshop3.ipynb | 982 +++++++++++++++++++++++++++++++
 1 file changed, 982 insertions(+)
 create mode 100644 Workshop3/molcal_workshop3.ipynb

diff --git a/Workshop3/molcal_workshop3.ipynb b/Workshop3/molcal_workshop3.ipynb
new file mode 100644
index 0000000..4228bc3
--- /dev/null
+++ b/Workshop3/molcal_workshop3.ipynb
@@ -0,0 +1,982 @@
+{
+ "cells": [
+  {
+   "cell_type": "code",
+   "execution_count": 14,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "#!pip install jarvis-tools"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 15,
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "Obtaining 3D dataset 76k ...\n",
+      "Reference:https://www.nature.com/articles/s41524-020-00440-1\n",
+      "Other versions:https://doi.org/10.6084/m9.figshare.6815699\n",
+      "Loading the zipfile...\n",
+      "Loading completed.\n"
+     ]
+    }
+   ],
+   "source": [
+    "# !pip install jarvis-tools, and restart the notebook\n",
+    "from jarvis.db.figshare import data\n",
+    "\n",
+    "dft_3d = data('dft_3d')"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 16,
+   "metadata": {},
+   "outputs": [
+    {
+     "data": {
+      "text/plain": [
+       "dict_keys(['jid', 'spg_number', 'spg_symbol', 'formula', 'formation_energy_peratom', 'func', 'optb88vdw_bandgap', 'atoms', 'slme', 'magmom_oszicar', 'spillage', 'elastic_tensor', 'effective_masses_300K', 'kpoint_length_unit', 'maxdiff_mesh', 'maxdiff_bz', 'encut', 'optb88vdw_total_energy', 'epsx', 'epsy', 'epsz', 'mepsx', 'mepsy', 'mepsz', 'modes', 'magmom_outcar', 'max_efg', 'avg_elec_mass', 'avg_hole_mass', 'icsd', 'dfpt_piezo_max_eij', 'dfpt_piezo_max_dij', 'dfpt_piezo_max_dielectric', 'dfpt_piezo_max_dielectric_electronic', 'dfpt_piezo_max_dielectric_ionic', 'max_ir_mode', 'min_ir_mode', 'n-Seebeck', 'p-Seebeck', 'n-powerfact', 'p-powerfact', 'ncond', 'pcond', 'nkappa', 'pkappa', 'ehull', 'Tc_supercon', 'dimensionality', 'efg', 'xml_data_link', 'typ', 'exfoliation_energy', 'spg', 'crys', 'density', 'poisson', 'raw_files', 'nat', 'bulk_modulus_kv', 'shear_modulus_gv', 'mbj_bandgap', 'hse_gap', 'reference', 'search'])"
+      ]
+     },
+     "execution_count": 16,
+     "metadata": {},
+     "output_type": "execute_result"
+    }
+   ],
+   "source": [
+    "dft_3d[0].keys()"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 17,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "## Let's make a dataframe from this:\n",
+    "## !pip install pandas     ## if it's not installed\n",
+    "import pandas as pd\n",
+    "import numpy as np "
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 18,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "df=pd.DataFrame(dft_3d)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 19,
+   "metadata": {},
+   "outputs": [
+    {
+     "data": {
+      "text/html": [
+       "<div>\n",
+       "<style scoped>\n",
+       "    .dataframe tbody tr th:only-of-type {\n",
+       "        vertical-align: middle;\n",
+       "    }\n",
+       "\n",
+       "    .dataframe tbody tr th {\n",
+       "        vertical-align: top;\n",
+       "    }\n",
+       "\n",
+       "    .dataframe thead th {\n",
+       "        text-align: right;\n",
+       "    }\n",
+       "</style>\n",
+       "<table border=\"1\" class=\"dataframe\">\n",
+       "  <thead>\n",
+       "    <tr style=\"text-align: right;\">\n",
+       "      <th></th>\n",
+       "      <th>jid</th>\n",
+       "      <th>spg_number</th>\n",
+       "      <th>spg_symbol</th>\n",
+       "      <th>formula</th>\n",
+       "      <th>formation_energy_peratom</th>\n",
+       "      <th>func</th>\n",
+       "      <th>optb88vdw_bandgap</th>\n",
+       "      <th>atoms</th>\n",
+       "      <th>slme</th>\n",
+       "      <th>magmom_oszicar</th>\n",
+       "      <th>...</th>\n",
+       "      <th>density</th>\n",
+       "      <th>poisson</th>\n",
+       "      <th>raw_files</th>\n",
+       "      <th>nat</th>\n",
+       "      <th>bulk_modulus_kv</th>\n",
+       "      <th>shear_modulus_gv</th>\n",
+       "      <th>mbj_bandgap</th>\n",
+       "      <th>hse_gap</th>\n",
+       "      <th>reference</th>\n",
+       "      <th>search</th>\n",
+       "    </tr>\n",
+       "  </thead>\n",
+       "  <tbody>\n",
+       "    <tr>\n",
+       "      <th>0</th>\n",
+       "      <td>JVASP-90856</td>\n",
+       "      <td>129</td>\n",
+       "      <td>P4/nmm</td>\n",
+       "      <td>TiCuSiAs</td>\n",
+       "      <td>-0.42762</td>\n",
+       "      <td>OptB88vdW</td>\n",
+       "      <td>0.000</td>\n",
+       "      <td>{'lattice_mat': [[3.566933224304235, 0.0, -0.0...</td>\n",
+       "      <td>na</td>\n",
+       "      <td>0.0</td>\n",
+       "      <td>...</td>\n",
+       "      <td>5.956</td>\n",
+       "      <td>na</td>\n",
+       "      <td>[]</td>\n",
+       "      <td>8</td>\n",
+       "      <td>na</td>\n",
+       "      <td>na</td>\n",
+       "      <td>na</td>\n",
+       "      <td>na</td>\n",
+       "      <td>mp-1080455</td>\n",
+       "      <td>-As-Cu-Si-Ti</td>\n",
+       "    </tr>\n",
+       "    <tr>\n",
+       "      <th>1</th>\n",
+       "      <td>JVASP-86097</td>\n",
+       "      <td>221</td>\n",
+       "      <td>Pm-3m</td>\n",
+       "      <td>DyB6</td>\n",
+       "      <td>-0.41596</td>\n",
+       "      <td>OptB88vdW</td>\n",
+       "      <td>0.000</td>\n",
+       "      <td>{'lattice_mat': [[4.089078911208881, 0.0, 0.0]...</td>\n",
+       "      <td>na</td>\n",
+       "      <td>0.0</td>\n",
+       "      <td>...</td>\n",
+       "      <td>5.522</td>\n",
+       "      <td>na</td>\n",
+       "      <td>[OPT-LOPTICS,JVASP-86097.zip,https://ndownload...</td>\n",
+       "      <td>7</td>\n",
+       "      <td>na</td>\n",
+       "      <td>na</td>\n",
+       "      <td>na</td>\n",
+       "      <td>na</td>\n",
+       "      <td>mp-568319</td>\n",
+       "      <td>-B-Dy</td>\n",
+       "    </tr>\n",
+       "    <tr>\n",
+       "      <th>2</th>\n",
+       "      <td>JVASP-64906</td>\n",
+       "      <td>119</td>\n",
+       "      <td>I-4m2</td>\n",
+       "      <td>Be2OsRu</td>\n",
+       "      <td>0.04847</td>\n",
+       "      <td>OptB88vdW</td>\n",
+       "      <td>0.000</td>\n",
+       "      <td>{'lattice_mat': [[-1.833590720595598, 1.833590...</td>\n",
+       "      <td>na</td>\n",
+       "      <td>0.0</td>\n",
+       "      <td>...</td>\n",
+       "      <td>10.960</td>\n",
+       "      <td>na</td>\n",
+       "      <td>[OPT-LOPTICS,JVASP-64906.zip,https://ndownload...</td>\n",
+       "      <td>4</td>\n",
+       "      <td>na</td>\n",
+       "      <td>na</td>\n",
+       "      <td>na</td>\n",
+       "      <td>na</td>\n",
+       "      <td>auid-3eaf68dd483bf4f4</td>\n",
+       "      <td>-Be-Os-Ru</td>\n",
+       "    </tr>\n",
+       "    <tr>\n",
+       "      <th>3</th>\n",
+       "      <td>JVASP-98225</td>\n",
+       "      <td>14</td>\n",
+       "      <td>P2_1/c</td>\n",
+       "      <td>KBi</td>\n",
+       "      <td>-0.44140</td>\n",
+       "      <td>OptB88vdW</td>\n",
+       "      <td>0.472</td>\n",
+       "      <td>{'lattice_mat': [[7.2963518353359165, 0.0, 0.0...</td>\n",
+       "      <td>na</td>\n",
+       "      <td>0.0</td>\n",
+       "      <td>...</td>\n",
+       "      <td>5.145</td>\n",
+       "      <td>na</td>\n",
+       "      <td>[]</td>\n",
+       "      <td>32</td>\n",
+       "      <td>na</td>\n",
+       "      <td>na</td>\n",
+       "      <td>na</td>\n",
+       "      <td>na</td>\n",
+       "      <td>mp-31104</td>\n",
+       "      <td>-Bi-K</td>\n",
+       "    </tr>\n",
+       "    <tr>\n",
+       "      <th>4</th>\n",
+       "      <td>JVASP-10</td>\n",
+       "      <td>164</td>\n",
+       "      <td>P-3m1</td>\n",
+       "      <td>VSe2</td>\n",
+       "      <td>-0.71026</td>\n",
+       "      <td>OptB88vdW</td>\n",
+       "      <td>0.000</td>\n",
+       "      <td>{'lattice_mat': [[1.6777483798834445, -2.90594...</td>\n",
+       "      <td>na</td>\n",
+       "      <td>0.0</td>\n",
+       "      <td>...</td>\n",
+       "      <td>5.718</td>\n",
+       "      <td>0.23</td>\n",
+       "      <td>[FD-ELAST,JVASP-10.zip,https://ndownloader.fig...</td>\n",
+       "      <td>3</td>\n",
+       "      <td>48.79</td>\n",
+       "      <td>33.05</td>\n",
+       "      <td>0.0</td>\n",
+       "      <td>na</td>\n",
+       "      <td>mp-694</td>\n",
+       "      <td>-Se-V</td>\n",
+       "    </tr>\n",
+       "  </tbody>\n",
+       "</table>\n",
+       "<p>5 rows × 64 columns</p>\n",
+       "</div>"
+      ],
+      "text/plain": [
+       "           jid spg_number spg_symbol   formula  formation_energy_peratom  \\\n",
+       "0  JVASP-90856        129     P4/nmm  TiCuSiAs                  -0.42762   \n",
+       "1  JVASP-86097        221      Pm-3m      DyB6                  -0.41596   \n",
+       "2  JVASP-64906        119      I-4m2   Be2OsRu                   0.04847   \n",
+       "3  JVASP-98225         14     P2_1/c       KBi                  -0.44140   \n",
+       "4     JVASP-10        164      P-3m1      VSe2                  -0.71026   \n",
+       "\n",
+       "        func  optb88vdw_bandgap  \\\n",
+       "0  OptB88vdW              0.000   \n",
+       "1  OptB88vdW              0.000   \n",
+       "2  OptB88vdW              0.000   \n",
+       "3  OptB88vdW              0.472   \n",
+       "4  OptB88vdW              0.000   \n",
+       "\n",
+       "                                               atoms slme magmom_oszicar  ...  \\\n",
+       "0  {'lattice_mat': [[3.566933224304235, 0.0, -0.0...   na            0.0  ...   \n",
+       "1  {'lattice_mat': [[4.089078911208881, 0.0, 0.0]...   na            0.0  ...   \n",
+       "2  {'lattice_mat': [[-1.833590720595598, 1.833590...   na            0.0  ...   \n",
+       "3  {'lattice_mat': [[7.2963518353359165, 0.0, 0.0...   na            0.0  ...   \n",
+       "4  {'lattice_mat': [[1.6777483798834445, -2.90594...   na            0.0  ...   \n",
+       "\n",
+       "  density poisson                                          raw_files nat  \\\n",
+       "0   5.956      na                                                 []   8   \n",
+       "1   5.522      na  [OPT-LOPTICS,JVASP-86097.zip,https://ndownload...   7   \n",
+       "2  10.960      na  [OPT-LOPTICS,JVASP-64906.zip,https://ndownload...   4   \n",
+       "3   5.145      na                                                 []  32   \n",
+       "4   5.718    0.23  [FD-ELAST,JVASP-10.zip,https://ndownloader.fig...   3   \n",
+       "\n",
+       "  bulk_modulus_kv shear_modulus_gv mbj_bandgap  hse_gap  \\\n",
+       "0              na               na          na       na   \n",
+       "1              na               na          na       na   \n",
+       "2              na               na          na       na   \n",
+       "3              na               na          na       na   \n",
+       "4           48.79            33.05         0.0       na   \n",
+       "\n",
+       "               reference        search  \n",
+       "0             mp-1080455  -As-Cu-Si-Ti  \n",
+       "1              mp-568319         -B-Dy  \n",
+       "2  auid-3eaf68dd483bf4f4     -Be-Os-Ru  \n",
+       "3               mp-31104         -Bi-K  \n",
+       "4                 mp-694         -Se-V  \n",
+       "\n",
+       "[5 rows x 64 columns]"
+      ]
+     },
+     "execution_count": 19,
+     "metadata": {},
+     "output_type": "execute_result"
+    }
+   ],
+   "source": [
+    "df.head()"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 20,
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "jid 75993\n",
+      "spg_number 75993\n",
+      "spg_symbol 75993\n",
+      "formula 75993\n",
+      "formation_energy_peratom 75993\n",
+      "func 75993\n",
+      "optb88vdw_bandgap 75993\n",
+      "atoms 75993\n",
+      "slme 9770\n",
+      "magmom_oszicar 71320\n",
+      "spillage 11377\n",
+      "elastic_tensor 25513\n",
+      "effective_masses_300K 75993\n",
+      "kpoint_length_unit 75671\n"
+     ]
+    },
+    {
+     "name": "stderr",
+     "output_type": "stream",
+     "text": [
+      "/var/folders/p4/gl5hwwk13vjb1pncdz82sq4h0000gq/T/ipykernel_73354/3773379141.py:3: FutureWarning: Downcasting behavior in `replace` is deprecated and will be removed in a future version. To retain the old behavior, explicitly call `result.infer_objects(copy=False)`. To opt-in to the future behavior, set `pd.set_option('future.no_silent_downcasting', True)`\n",
+      "  val=df[i].replace('na',np.nan).dropna().values\n"
+     ]
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "maxdiff_mesh 5861\n",
+      "maxdiff_bz 5861\n",
+      "encut 75670\n",
+      "optb88vdw_total_energy 75993\n",
+      "epsx 52168\n",
+      "epsy 52168\n",
+      "epsz 52168\n",
+      "mepsx 18293\n",
+      "mepsy 18293\n",
+      "mepsz 18293\n",
+      "modes 13910\n",
+      "magmom_outcar 74261\n",
+      "max_efg 11871\n",
+      "avg_elec_mass 17645\n",
+      "avg_hole_mass 17645\n",
+      "icsd 75993\n",
+      "dfpt_piezo_max_eij 4799\n",
+      "dfpt_piezo_max_dij 3347\n",
+      "dfpt_piezo_max_dielectric 4706\n",
+      "dfpt_piezo_max_dielectric_electronic 4809\n",
+      "dfpt_piezo_max_dielectric_ionic 4809\n",
+      "max_ir_mode 4805\n",
+      "min_ir_mode 4809\n",
+      "n-Seebeck 23218\n",
+      "p-Seebeck 23218\n",
+      "n-powerfact 23218\n",
+      "p-powerfact 23218\n",
+      "ncond 23218\n",
+      "pcond 23218\n",
+      "nkappa 23218\n",
+      "pkappa 23218\n",
+      "ehull 75993\n",
+      "Tc_supercon 1058\n",
+      "dimensionality 75560\n",
+      "efg 75993\n",
+      "xml_data_link 75993\n",
+      "typ 75993\n",
+      "exfoliation_energy 813\n",
+      "spg 75993\n",
+      "crys 75993\n",
+      "density 75993\n",
+      "poisson 23597\n",
+      "raw_files 75993\n",
+      "nat 75993\n",
+      "bulk_modulus_kv 23824\n",
+      "shear_modulus_gv 23824\n",
+      "mbj_bandgap 19805\n",
+      "hse_gap 56\n",
+      "reference 75993\n",
+      "search 75993\n"
+     ]
+    }
+   ],
+   "source": [
+    "## Count number of entries for each property\n",
+    "for i in df.columns.values:\n",
+    "  val=df[i].replace('na',np.nan).dropna().values\n",
+    "  print(i,len(val))"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 21,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "## Filter dataset based on desired property \n",
+    "## We will focus on elastic properties for today, i.e. Bulk modulus"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 22,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "#!pip install pymatgen"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 23,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "from jarvis.core.atoms import Atoms\n",
+    "bm=df[df.bulk_modulus_kv != 'na']\n",
+    "data = [(Atoms.from_dict(bm.iloc[i]['atoms']).pymatgen_converter(), bm.iloc[i].bulk_modulus_kv) for i in range(len(bm))]"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 24,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "import itertools\n",
+    "def get_stoichiometry(elements):\n",
+    "    return [(g[0], len(list(g[1]))) for g in itertools.groupby(elements)]"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 25,
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stderr",
+     "output_type": "stream",
+     "text": [
+      "100%|██████████| 23824/23824 [00:02<00:00, 8443.33it/s]\n"
+     ]
+    }
+   ],
+   "source": [
+    "## Use all the material dataset for training the bulk modulus\n",
+    "from tqdm import tqdm\n",
+    "\n",
+    "stoichs=[]   #stoichiometry\n",
+    "bulk=[]      #bulk modulus\n",
+    "for i in tqdm(range(len(bm))):\n",
+    "    stoichs.append(Atoms.from_dict(bm.iloc[i]['atoms']).pymatgen_converter())\n",
+    "    bulk.append(bm.iloc[i]['bulk_modulus_kv'])\n",
+    "data_ran=list(zip(stoichs,bulk))\n",
+    "\n",
+    "import pickle\n",
+    "with open('data_ran.pickle', 'wb') as f:\n",
+    "    pickle.dump(data_ran, f)\n"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 1,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "### change environments from jarvis to matgl"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 2,
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stderr",
+     "output_type": "stream",
+     "text": [
+      "/Users/sgrover/anaconda3/envs/matgl-megnet/lib/python3.10/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
+      "  from .autonotebook import tqdm as notebook_tqdm\n"
+     ]
+    }
+   ],
+   "source": [
+    "from __future__ import annotations\n",
+    "\n",
+    "import os\n",
+    "import shutil\n",
+    "import warnings\n",
+    "import zipfile\n",
+    "import matgl\n",
+    "\n",
+    "import matplotlib.pyplot as plt\n",
+    "import pandas as pd\n",
+    "import pytorch_lightning as pl\n",
+    "import torch\n",
+    "import pickle\n",
+    "import numpy as np\n",
+    "from dgl.data.utils import split_dataset\n",
+    "from pymatgen.core import Structure\n",
+    "from pytorch_lightning.loggers import CSVLogger\n",
+    "from tqdm import tqdm\n",
+    "\n",
+    "from matgl.ext.pymatgen import Structure2Graph, get_element_list\n",
+    "from matgl.graph.data import MGLDataset, MGLDataLoader, collate_fn\n",
+    "from matgl.layers import BondExpansion\n",
+    "from matgl.models import MEGNet\n",
+    "from matgl.utils.io import RemoteFile\n",
+    "from matgl.utils.training import ModelLightningModule\n",
+    "\n",
+    "# To suppress warnings for clearer output\n",
+    "warnings.simplefilter(\"ignore\")"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 3,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "data_ran=pd.read_pickle('./data_ran.pickle')"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 4,
+   "metadata": {},
+   "outputs": [
+    {
+     "data": {
+      "text/plain": [
+       "list"
+      ]
+     },
+     "execution_count": 4,
+     "metadata": {},
+     "output_type": "execute_result"
+    }
+   ],
+   "source": [
+    "type(data_ran)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 46,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "import random\n",
+    "random.shuffle(data_ran)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 47,
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "Full Formula (Al2 Cr2 Ge2)\n",
+      "Reduced Formula: AlCrGe\n",
+      "abc   :   5.953030   4.915863   4.763705\n",
+      "angles:  75.887124  53.211418  50.901458\n",
+      "pbc   :       True       True       True\n",
+      "Sites (6)\n",
+      "  #  SP           a         b         c\n",
+      "---  ----  --------  --------  --------\n",
+      "  0  Al    0.07707   0.42293   0.07707\n",
+      "  1  Al    0.42293   0.07707   0.42293\n",
+      "  2  Cr    0.75      0.75      0.75\n",
+      "  3  Cr    0         0         0\n",
+      "  4  Ge    0.669987  0.330013  0.669987\n",
+      "  5  Ge    0.330013  0.669987  0.330013 2.1178676265660163\n"
+     ]
+    }
+   ],
+   "source": [
+    "import numpy as np\n",
+    "\n",
+    "structures=[d[0] for d in data_ran[:15000]]\n",
+    "targets=np.log10([d[1] for d in data_ran])\n",
+    "\n",
+    "print(structures[0],targets[0])"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 48,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# get element types in the dataset\n",
+    "elem_list = get_element_list(structures)\n",
+    "# setup a graph converter\n",
+    "converter = Structure2Graph(element_types=elem_list, cutoff=4.0)\n",
+    "# convert the raw dataset into MEGNetDataset\n",
+    "mp_dataset = MGLDataset(\n",
+    "    structures=structures,\n",
+    "    labels={\"bulk_modulus_kv\": targets},\n",
+    "    converter=converter,\n",
+    ")"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 49,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "train_data, val_data, test_data = split_dataset(\n",
+    "    mp_dataset,\n",
+    "    frac_list=[0.8, 0.1, 0.1],\n",
+    "    shuffle=True,\n",
+    "    random_state=42,\n",
+    ")\n",
+    "train_loader, val_loader, test_loader = MGLDataLoader(\n",
+    "    train_data=train_data,\n",
+    "    val_data=val_data,\n",
+    "    test_data=test_data,\n",
+    "    collate_fn=collate_fn,\n",
+    "    batch_size=64,\n",
+    "    num_workers=0,\n",
+    ")"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 50,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# setup the embedding layer for node attributes\n",
+    "node_embed = torch.nn.Embedding(len(elem_list), 16)\n",
+    "# define the bond expansion\n",
+    "bond_expansion = BondExpansion(rbf_type=\"Gaussian\", initial=0.0, final=5.0, num_centers=100, width=0.5)\n",
+    "\n",
+    "# setup the architecture of MEGNet model\n",
+    "model = MEGNet(\n",
+    "    dim_node_embedding=16,\n",
+    "    dim_edge_embedding=100,\n",
+    "    dim_state_embedding=2,\n",
+    "    nblocks=3,\n",
+    "    hidden_layer_sizes_input=(64, 32),\n",
+    "    hidden_layer_sizes_conv=(64, 64, 32),\n",
+    "    nlayers_set2set=1,\n",
+    "    niters_set2set=2,\n",
+    "    hidden_layer_sizes_output=(32, 16),\n",
+    "    is_classification=False,\n",
+    "    activation_type=\"softplus2\",\n",
+    "    bond_expansion=bond_expansion,\n",
+    "    cutoff=4.0,\n",
+    "    gauss_width=0.5,\n",
+    ")\n",
+    "\n",
+    "# setup the MEGNetTrainer\n",
+    "lit_module = ModelLightningModule(model=model)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 51,
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stderr",
+     "output_type": "stream",
+     "text": [
+      "GPU available: True (mps), used: False\n",
+      "TPU available: False, using: 0 TPU cores\n",
+      "IPU available: False, using: 0 IPUs\n",
+      "HPU available: False, using: 0 HPUs\n"
+     ]
+    },
+    {
+     "name": "stderr",
+     "output_type": "stream",
+     "text": [
+      "\n",
+      "  | Name  | Type              | Params\n",
+      "--------------------------------------------\n",
+      "0 | model | MEGNet            | 189 K \n",
+      "1 | mae   | MeanAbsoluteError | 0     \n",
+      "2 | rmse  | MeanSquaredError  | 0     \n",
+      "--------------------------------------------\n",
+      "189 K     Trainable params\n",
+      "100       Non-trainable params\n",
+      "189 K     Total params\n",
+      "0.758     Total estimated model params size (MB)\n"
+     ]
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "Epoch 39: 100%|██████████| 298/298 [00:24<00:00, 11.99it/s, v_num=6, val_Total_Loss=nan.0, val_MAE=nan.0, val_RMSE=nan.0, train_Total_Loss=nan.0, train_MAE=nan.0, train_RMSE=nan.0]"
+     ]
+    },
+    {
+     "name": "stderr",
+     "output_type": "stream",
+     "text": [
+      "`Trainer.fit` stopped: `max_epochs=40` reached.\n"
+     ]
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "Epoch 39: 100%|██████████| 298/298 [00:24<00:00, 11.97it/s, v_num=6, val_Total_Loss=nan.0, val_MAE=nan.0, val_RMSE=nan.0, train_Total_Loss=nan.0, train_MAE=nan.0, train_RMSE=nan.0]\n"
+     ]
+    }
+   ],
+   "source": [
+    "logger = CSVLogger(\"logs\", name=\"MEGNet_training\")\n",
+    "trainer = pl.Trainer(max_epochs=40, accelerator=\"cpu\", logger=logger)\n",
+    "trainer.fit(model=lit_module, train_dataloaders=train_loader, val_dataloaders=val_loader)\n",
+    "\n",
+    "warnings.simplefilter(\"ignore\")"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 52,
+   "metadata": {},
+   "outputs": [
+    {
+     "data": {
+      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAjcAAAGdCAYAAADuR1K7AAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjguMywgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/H5lhTAAAACXBIWXMAAA9hAAAPYQGoP6dpAAAqKElEQVR4nO3de3SV1Z3/8c/JPVySEC45BBJAyBgESqZAQliuQU2mQRFBQ8VMuLNEKqAWsIByqVjNVLCCirJcQ6CMIBRGGQuUDgaLKJFLsA73sS5uAicBMQnXJCT794e/nPaYEAPkJDmb92utZ2H2s/d5vnsTPR+f8zzPcRhjjAAAACzh19AFAAAA1CXCDQAAsArhBgAAWIVwAwAArEK4AQAAViHcAAAAqxBuAACAVQg3AADAKgENXUBDqKio0OnTp9W8eXM5HI6GLgcAANSCMUYXLlxQdHS0/Pyuf37mtgw3p0+fVkxMTEOXAQAAbsLJkyfVvn376+6/LcNN8+bNJX2/OGFhYQ1cDQAAqI3i4mLFxMS438ev57YMN5UfRYWFhRFuAADwMT92SQkXFAMAAKsQbgAAgFUINwAAwCq35TU3AADfZ4zRtWvXVF5e3tCloI74+/srICDglh/TQrgBAPic0tJSnTlzRpcvX27oUlDHmjRporZt2yooKOimX4NwAwDwKRUVFTp69Kj8/f0VHR2toKAgHshqAWOMSktLdfbsWR09elRxcXE1PqivJoQbAIBPKS0tVUVFhWJiYtSkSZOGLgd1KDQ0VIGBgTp+/LhKS0sVEhJyU6/DBcUAAJ90s/9Xj8atLv5e+c0AAABWIdwAAACrEG4AAPBBHTt21MKFCxu6jEaJcAMAQD2555579Mwzz9TJa+3evVvjx4+vk9c6duyYHA6H/P39derUKY99Z86ccT975tixY1XGpqWlyd/fX7t3766yb/To0XI4HFW2AQMG1End10O4AQCgkah8MGFttG7dus7vFmvXrp1WrFjh0fb73/9e7dq1q7b/iRMntGPHDk2aNEnZ2dnV9hkwYIDOnDnjsb333nt1WvcPEW4AAD7PGKPLpdcaZDPG1KrG0aNHa9u2bVq0aJH7DMby5cvlcDj0pz/9Sb169VJwcLA+/fRTff311xo8eLCioqLUrFkz9enTRx999JHH6/3wYymHw6H/+I//0MMPP6wmTZooLi5OH3744Q2t46hRo7Rs2TKPtmXLlmnUqFHV9l+2bJkefPBB/eIXv9B7772nK1euVOkTHBwsp9PpsbVo0eKG6rpRPOcGAODzrpSV6645f26QYx+cl6YmQT/+drpo0SL93//9n7p376558+ZJkg4cOCBJmjFjhhYsWKA77rhDLVq00MmTJ/XAAw/opZdeUnBwsFasWKFBgwbpyJEjio2Nve4xXnjhBb3yyiuaP3++3njjDWVmZur48eOKjIys1VweeughLVmyRJ9++qnuvvtuffrpp/ruu+80aNAgvfjiix59jTFatmyZFi9erPj4eHXp0kXr1q3TiBEjanUsb+LMDQAA9SA8PFxBQUFq0qSJ+wyGv7+/JGnevHn613/9V3Xu3FmRkZHq2bOnnnjiCXXv3l1xcXF68cUX1blz5x89EzN69GhlZGSoS5cuevnll3Xx4kXt2rWr1jUGBgZq+PDh7o+YsrOzNXz4cAUGBlbp+9FHH+ny5ctKS0uTJA0fPlxLly6t0m/Dhg1q1qyZx/byyy/XuqabwZkbAIDPCw3018F5aQ127FvVu3dvj58vXryoX//619q4caPOnDmja9eu6cqVKzpx4kSNr/OTn/zE/c9NmzZVWFiYCgoKbqiWsWPHql+/fnr55Ze1du1a5ebmVnsdUHZ2toYNG6aAgO+jREZGhp599ll9/fXX6ty5s7vfvffeq7fffttjbG3PJN0swg0AwOc5HI5afTTUWDVt2tTj52nTpmnLli1asGCBunTpotDQUA0dOlSlpaU1vs4Pz7A4HA5VVFTcUC09evRQfHy8MjIy1LVrV3Xv3l1//etfPfqcP39eH3zwgcrKyjyCS3l5ubKzs/XSSy95zK1Lly43VMOt8t3fBAAAfExQUJDKy8t/tN9nn32m0aNH6+GHH5b0/Zmc6m7D9paxY8fqySefrHLGpdLKlSvVvn17rV+/3qP9f/7nf/Tqq69q3rx57o/cGgLhBgCAetKxY0ft3LlTx44dU7Nmza57ViUuLk7vv/++Bg0aJIfDodmzZ9/wGZhb8fjjj+vnP/+5IiIiqt2/dOlSDR06VN27d/doj4mJ0cyZM7V582YNHDhQklRSUiKXy+XRLyAgQK1atfJK7RIXFAMAUG+mTZsmf39/3XXXXWrduvV1r6H53e9+pxYtWqhfv34aNGiQ0tLS9NOf/rTe6qwMH5XX0/yjvLw8ffnll0pPT6+yLzw8XCkpKR4XFm/evFlt27b12O6++26v1u8wtb1B3yLFxcUKDw9XUVGRwsLCGrocAMANuHr1qo4ePapOnTopJCSkoctBHavp77e279+cuQEAAFYh3AAAYLkJEyZUedZM5TZhwoSGLq/OcUExAACWmzdvnqZNm1btPhsvzyDcAABguTZt2qhNmzYNXUa94WMpAABgFcINAACwCuEGAABYhXADAACsQrgBAABWIdwAAOAjOnbsqIULFzZ0GY0e4QYAAAt17NhRDodDq1evrrKvW7ducjgcWr58eZV9WVlZ8vf31/z586vsW758uRwOR5WtsX0NBuEGAABLxcTEaNmyZR5tn3/+uVwul5o2bVrtmOzsbP3qV79SdnZ2tfvDwsJ05swZj+348eN1XvutINwAAHyfMVLppYbZavn90++8846io6NVUVHh0T548GCNHTtWX3/9tQYPHqyoqCg1a9ZMffr00UcffXRLy5KZmalt27bp5MmT7rbs7GxlZmZW+43f27Zt05UrVzRv3jwVFxdrx44dVfo4HA45nU6PLSoq6pbqrGs8oRgA4PvKLksvRzfMsZ87LQVVfxbkH/385z/X5MmT9fHHHyslJUWSdP78eW3evFmbNm3SxYsX9cADD+ill15ScHCwVqxYoUGDBunIkSOKjY29qdKioqKUlpam3//+95o1a5YuX76sNWvWaNu2bVqxYkWV/kuXLlVGRoYCAwOVkZGhpUuXql+/fjd17IbEmRsAAOpBixYtdP/992vVqlXutnXr1qlVq1a699571bNnTz3xxBPq3r274uLi9OKLL6pz58768MMPb+m4Y8eO1fLly2WM0bp169S5c2clJCRU6VdcXKx169Zp+PDhkqThw4frD3/4gy5evOjRr6ioqMqXb95///23VGNd48wNAMD3BTb5/gxKQx27ljIzM/X444/rrbfeUnBwsFauXKnHHntMfn5+unjxon79619r48aNOnPmjK5du6YrV67oxIkTt1TewIED9cQTT+iTTz5Rdna2xo4dW22/9957T507d1bPnj0lSQkJCerQoYPWrFmjcePGufs1b95ce/fu9RgbGhp6SzXWNcINAMD3ORy1+miooQ0aNEjGGG3cuFF9+vTR9u3b9dprr0mSpk2bpi1btmjBggXq0qWLQkNDNXToUJWWlt7SMQMCAjRixAjNnTtXO3fu1AcffFBtv6VLl+rAgQMe1+JUVFQoOzvbI9z4+fmpS5cut1STtxFuAACoJyEhIXrkkUe0cuVK/e1vf9Odd96pn/70p5Kkzz77TKNHj9bDDz8sSbp48aKOHTtWJ8cdO3asFixYoGHDhqlFixZV9u/bt0979uzRX/7yF0VGRrrbz58/r3vuuUeHDx9WfHx8ndRSHwg3AADUo8zMTD344IM6cOCA+/oWSYqLi9P777+vQYMGyeFwaPbs2VXurLpZXbt21blz59SkSfUfoS1dulSJiYn6l3/5lyr7+vTpo6VLl7qfe2OMkcvlqtKvTZs28vNrHJfyNo4qAAC4Tdx3332KjIzUkSNH9G//9m/u9t/97ndq0aKF+vXrp0GDBiktLc19VqcutGzZstprY0pLS/Xuu+8qPT292nHp6elasWKFysrKJH1/4XHbtm2rbAUFBXVW661yGFPLG/QtUlxcrPDwcBUVFSksLKyhywEA3ICrV6/q6NGj6tSpU6N7Mi5uXU1/v7V9/66XMzeLFy9Wx44dFRISoqSkJO3atavG/mvXrlV8fLxCQkLUo0cPbdq06bp9J0yYIIfDwXdtAAAASfUQbtasWaMpU6Zo7ty52rt3r3r27Km0tLTrnr7asWOHMjIyNG7cOH3xxRcaMmSIhgwZov3791fp+8EHH+jzzz9XdHQDPbgJAIAGsHLlyirPmqncunXr1tDlNTivfyyVlJSkPn366M0335T0/W1lMTExmjx5smbMmFGl/7Bhw3Tp0iVt2LDB3da3b18lJCRoyZIl7rZTp04pKSlJf/7znzVw4EA988wzeuaZZ2pVEx9LAYDv4mMp6cKFC8rPz692X2BgoDp06FDPFdWduvhYyqt3S5WWliovL08zZ850t/n5+Sk1NVW5ubnVjsnNzdWUKVM82tLS0rR+/Xr3zxUVFRoxYoSeffbZWiXUkpISlZSUuH8uLi6+wZkAANB4NG/eXM2bN2/oMhotr34sde7cOZWXl1f5Qq2oqKhqbyOTJJfL9aP9f/vb3yogIEBPPfVUrerIyspSeHi4e4uJibnBmQAAGpvb8H6Y20Jd/L363K3geXl5WrRokZYvXy6Hw1GrMTNnzlRRUZF7+8dvRwUA+JbAwEBJ0uXLlxu4EnhD5d9r5d/zzfDqx1KtWrWSv79/lc8F8/Pz5XQ6qx3jdDpr7L99+3YVFBR4fENqeXm5pk6dqoULF1b7NMfg4GAFBwff4mwAAI2Bv7+/IiIi3DemNGnSpNb/s4vGyxijy5cvq6CgQBEREfL397/p1/JquAkKClKvXr2Uk5OjIUOGSPr+epmcnBxNmjSp2jHJycnKycnxuDh4y5YtSk5OliSNGDFCqampHmPS0tI0YsQIjRkzxivzAAA0LpX/w9uYHhyHuhEREXHdEyC15fWvX5gyZYpGjRql3r17KzExUQsXLtSlS5fcQWTkyJFq166dsrKyJElPP/20+vfvr1dffVUDBw7U6tWrtWfPHr3zzjuSvn/CYsuWLT2OERgYKKfTqTvvvNPb0wEANAIOh0Nt27ZVmzZt3E/Ohe8LDAy8pTM2lbweboYNG6azZ89qzpw5crlcSkhI0ObNm90XDZ84ccLjuyj69eunVatWadasWXruuecUFxen9evXq3v37t4uFQDgY/z9/evkzRB24esXeM4NAAA+oVF9/QIAAEB9IdwAAACrEG4AAIBVCDcAAMAqhBsAAGAVwg0AALAK4QYAAFiFcAMAAKxCuAEAAFYh3AAAAKsQbgAAgFUINwAAwCqEGwAAYBXCDQAAsArhBgAAWIVwAwAArEK4AQAAViHcAAAAqxBuAACAVQg3AADAKoQbAABgFcINAACwCuEGAABYhXADAACsQrgBAABWIdwAAACrEG4AAIBVCDcAAMAqhBsAAGAVwg0AALAK4QYAAFiFcAMAAKxCuAEAAFYh3AAAAKsQbgAAgFUINwAAwCqEGwAAYBXCDQAAsArhBgAAWIVwAwAArEK4AQAAViHcAAAAqxBuAACAVQg3AADAKoQbAABgFcINAACwCuEGAABYhXADAACsQrgBAABWIdwAAACrEG4AAIBVCDcAAMAqhBsAAGAVwg0AALAK4QYAAFiFcAMAAKxSL+Fm8eLF6tixo0JCQpSUlKRdu3bV2H/t2rWKj49XSEiIevTooU2bNrn3lZWVafr06erRo4eaNm2q6OhojRw5UqdPn/b2NAAAgA/werhZs2aNpkyZorlz52rv3r3q2bOn0tLSVFBQUG3/HTt2KCMjQ+PGjdMXX3yhIUOGaMiQIdq/f78k6fLly9q7d69mz56tvXv36v3339eRI0f00EMPeXsqAADABziMMcabB0hKSlKfPn305ptvSpIqKioUExOjyZMna8aMGVX6Dxs2TJcuXdKGDRvcbX379lVCQoKWLFlS7TF2796txMREHT9+XLGxsT9aU3FxscLDw1VUVKSwsLCbnBkAAKhPtX3/9uqZm9LSUuXl5Sk1NfXvB/TzU2pqqnJzc6sdk5ub69FfktLS0q7bX5KKiorkcDgUERFR7f6SkhIVFxd7bAAAwE5eDTfnzp1TeXm5oqKiPNqjoqLkcrmqHeNyuW6o/9WrVzV9+nRlZGRcN8VlZWUpPDzcvcXExNzEbAAAgC/w6bulysrK9Oijj8oYo7fffvu6/WbOnKmioiL3dvLkyXqsEgAA1KcAb754q1at5O/vr/z8fI/2/Px8OZ3Oasc4nc5a9a8MNsePH9fWrVtr/OwtODhYwcHBNzkLAADgS7x65iYoKEi9evVSTk6Ou62iokI5OTlKTk6udkxycrJHf0nasmWLR//KYPPVV1/po48+UsuWLb0zAQAA4HO8euZGkqZMmaJRo0apd+/eSkxM1MKFC3Xp0iWNGTNGkjRy5Ei1a9dOWVlZkqSnn35a/fv316uvvqqBAwdq9erV2rNnj9555x1J3weboUOHau/evdqwYYPKy8vd1+NERkYqKCjI21MCAACNmNfDzbBhw3T27FnNmTNHLpdLCQkJ2rx5s/ui4RMnTsjP7+8nkPr166dVq1Zp1qxZeu655xQXF6f169ere/fukqRTp07pww8/lCQlJCR4HOvjjz/WPffc4+0pAQCARszrz7lpjHjODQAAvqdRPOcGAACgvhFuAACAVQg3AADAKoQbAABgFcINAACwCuEGAABYhXADAACsQrgBAABWIdwAAACrEG4AAIBVCDcAAMAqhBsAAGAVwg0AALAK4QYAAFiFcAMAAKxCuAEAAFYh3AAAAKsQbgAAgFUINwAAwCqEGwAAYBXCDQAAsArhBgAAWIVwAwAArEK4AQAAViHcAAAAqxBuAACAVQg3AADAKoQbAABgFcINAACwCuEGAABYhXADAACsQrgBAABWIdwAAACrEG4AAIBVCDcAAMAqhBsAAGAVwg0AALAK4QYAAFiFcAMAAKxCuAEAAFYh3AAAAKsQbgAAgFUINwAAwCqEGwAAYBXCDQAAsArhBgAAWIVwAwAArEK4AQAAViHcAAAAqxBuAACAVQg3AADAKoQbAABgFcINAACwCuEGAABYhXADAACsQrgBAABWqZdws3jxYnXs2FEhISFKSkrSrl27auy/du1axcfHKyQkRD169NCmTZs89htjNGfOHLVt21ahoaFKTU3VV1995c0pAAAAH+H1cLNmzRpNmTJFc+fO1d69e9WzZ0+lpaWpoKCg2v47duxQRkaGxo0bpy+++EJDhgzRkCFDtH//fnefV155Ra+//rqWLFminTt3qmnTpkpLS9PVq1e9PR0AANDIOYwxxpsHSEpKUp8+ffTmm29KkioqKhQTE6PJkydrxowZVfoPGzZMly5d0oYNG9xtffv2VUJCgpYsWSJjjKKjozV16lRNmzZNklRUVKSoqCgtX75cjz322I/WVFxcrPDwcBUVFSksLKyOZgoAALyptu/fXj1zU1paqry8PKWmpv79gH5+Sk1NVW5ubrVjcnNzPfpLUlpamrv/0aNH5XK5PPqEh4crKSnpuq9ZUlKi4uJijw0AANjJq+Hm3LlzKi8vV1RUlEd7VFSUXC5XtWNcLleN/Sv/vJHXzMrKUnh4uHuLiYm5qfkAAIDG77a4W2rmzJkqKipybydPnmzokgAAgJd4Ndy0atVK/v7+ys/P92jPz8+X0+msdozT6ayxf+WfN/KawcHBCgsL89gAAICdvBpugoKC1KtXL+Xk5LjbKioqlJOTo+Tk5GrHJCcne/SXpC1btrj7d+rUSU6n06NPcXGxdu7ced3XBAAAt48Abx9gypQpGjVqlHr37q3ExEQtXLhQly5d0pgxYyRJI0eOVLt27ZSVlSVJevrpp9W/f3+9+uqrGjhwoFavXq09e/bonXfekSQ5HA4988wz+s1vfqO4uDh16tRJs2fPVnR0tIYMGeLt6QAAgEbO6+Fm2LBhOnv2rObMmSOXy6WEhARt3rzZfUHwiRMn5Of39xNI/fr106pVqzRr1iw999xziouL0/r169W9e3d3n1/96le6dOmSxo8fr8LCQt19993avHmzQkJCvD0dAADQyHn9OTeNEc+5AQDA9zSK59wAAADUN8INAACwCuEGAABYhXADAACsQrgBAABWIdwAAACrEG4AAIBVCDcAAMAqhBsAAGAVwg0AALAK4QYAAFiFcAMAAKxCuAEAAFYh3AAAAKsQbgAAgFUINwAAwCqEGwAAYBXCDQAAsArhBgAAWIVwAwAArEK4AQAAViHcAAAAqxBuAACAVQg3AADAKoQbAABgFcINAACwCuEGAABYhXADAACsQrgBAABWIdwAAACrEG4AAIBVCDcAAMAqhBsAAGAVwg0AALAK4QYAAFiFcAMAAKxCuAEAAFYh3AAAAKsQbgAAgFUINwAAwCqEGwAAYBXCDQAAsArhBgAAWIVwAwAArEK4AQAAViHcAAAAqxBuAACAVQg3AADAKoQbAABgFcINAACwCuEGAABYhXADAACsQrgBAABWIdwAAACrEG4AAIBVCDcAAMAqXgs358+fV2ZmpsLCwhQREaFx48bp4sWLNY65evWqJk6cqJYtW6pZs2ZKT09Xfn6+e/+XX36pjIwMxcTEKDQ0VF27dtWiRYu8NQUAAOCDvBZuMjMzdeDAAW3ZskUbNmzQJ598ovHjx9c45pe//KX++Mc/au3atdq2bZtOnz6tRx55xL0/Ly9Pbdq00bvvvqsDBw7o+eef18yZM/Xmm296axoAAMDHOIwxpq5f9NChQ7rrrru0e/du9e7dW5K0efNmPfDAA/rmm28UHR1dZUxRUZFat26tVatWaejQoZKkw4cPq2vXrsrNzVXfvn2rPdbEiRN16NAhbd26tdb1FRcXKzw8XEVFRQoLC7uJGQIAgPpW2/dvr5y5yc3NVUREhDvYSFJqaqr8/Py0c+fOasfk5eWprKxMqamp7rb4+HjFxsYqNzf3uscqKipSZGRk3RUPAAB8WoA3XtTlcqlNmzaeBwoIUGRkpFwu13XHBAUFKSIiwqM9KirqumN27NihNWvWaOPGjTXWU1JSopKSEvfPxcXFtZgFAADwRTd05mbGjBlyOBw1bocPH/ZWrR7279+vwYMHa+7cufrZz35WY9+srCyFh4e7t5iYmHqpEQAA1L8bOnMzdepUjR49usY+d9xxh5xOpwoKCjzar127pvPnz8vpdFY7zul0qrS0VIWFhR5nb/Lz86uMOXjwoFJSUjR+/HjNmjXrR+ueOXOmpkyZ4v65uLiYgAMAgKVuKNy0bt1arVu3/tF+ycnJKiwsVF5ennr16iVJ2rp1qyoqKpSUlFTtmF69eikwMFA5OTlKT0+XJB05ckQnTpxQcnKyu9+BAwd03333adSoUXrppZdqVXdwcLCCg4Nr1RcAAPg2r9wtJUn333+/8vPztWTJEpWVlWnMmDHq3bu3Vq1aJUk6deqUUlJStGLFCiUmJkqSfvGLX2jTpk1avny5wsLCNHnyZEnfX1sjff9R1H333ae0tDTNnz/ffSx/f/9aha5K3C0FAIDvqe37t1cuKJaklStXatKkSUpJSZGfn5/S09P1+uuvu/eXlZXpyJEjunz5srvttddec/ctKSlRWlqa3nrrLff+devW6ezZs3r33Xf17rvvuts7dOigY8eOeWsqAADAh3jtzE1jxpkbAAB8T4M+5wYAAKChEG4AAIBVCDcAAMAqhBsAAGAVwg0AALAK4QYAAFiFcAMAAKxCuAEAAFYh3AAAAKsQbgAAgFUINwAAwCqEGwAAYBXCDQAAsArhBgAAWIVwAwAArEK4AQAAViHcAAAAqxBuAACAVQg3AADAKoQbAABgFcINAACwCuEGAABYhXADAACsQrgBAABWIdwAAACrEG4AAIBVCDcAAMAqhBsAAGAVwg0AALAK4QYAAFiFcAMAAKxCuAEAAFYh3AAAAKsQbgAAgFUINwAAwCqEGwAAYBXCDQAAsArhBgAAWIVwAwAArEK4AQAAViHcAAAAqxBuAACAVQg3AADAKoQbAABgFcINAACwCuEGAABYhXADAACsQrgBAABWIdwAAACrEG4AAIBVCDcAAMAqhBsAAGAVwg0AALAK4QYAAFiFcAMAAKxCuAEAAFYh3AAAAKt4LdycP39emZmZCgsLU0REhMaNG6eLFy/WOObq1auaOHGiWrZsqWbNmik9PV35+fnV9v3222/Vvn17ORwOFRYWemEGAADAF3kt3GRmZurAgQPasmWLNmzYoE8++UTjx4+vccwvf/lL/fGPf9TatWu1bds2nT59Wo888ki1fceNG6ef/OQn3igdAAD4MIcxxtT1ix46dEh33XWXdu/erd69e0uSNm/erAceeEDffPONoqOjq4wpKipS69attWrVKg0dOlSSdPjwYXXt2lW5ubnq27evu+/bb7+tNWvWaM6cOUpJSdF3332niIiIWtdXXFys8PBwFRUVKSws7NYmCwAA6kVt37+9cuYmNzdXERER7mAjSampqfLz89POnTurHZOXl6eysjKlpqa62+Lj4xUbG6vc3Fx328GDBzVv3jytWLFCfn61K7+kpETFxcUeGwAAsJNXwo3L5VKbNm082gICAhQZGSmXy3XdMUFBQVXOwERFRbnHlJSUKCMjQ/Pnz1dsbGyt68nKylJ4eLh7i4mJubEJAQAAn3FD4WbGjBlyOBw1bocPH/ZWrZo5c6a6du2q4cOH3/C4oqIi93by5EkvVQgAABpawI10njp1qkaPHl1jnzvuuENOp1MFBQUe7deuXdP58+fldDqrHed0OlVaWqrCwkKPszf5+fnuMVu3btW+ffu0bt06SVLl5UKtWrXS888/rxdeeKHa1w4ODlZwcHBtpggAAHzcDYWb1q1bq3Xr1j/aLzk5WYWFhcrLy1OvXr0kfR9MKioqlJSUVO2YXr16KTAwUDk5OUpPT5ckHTlyRCdOnFBycrIk6b/+67905coV95jdu3dr7Nix2r59uzp37nwjUwEAAJa6oXBTW127dtWAAQP0+OOPa8mSJSorK9OkSZP02GOPue+UOnXqlFJSUrRixQolJiYqPDxc48aN05QpUxQZGamwsDBNnjxZycnJ7julfhhgzp075z7ejdwtBQAA7OWVcCNJK1eu1KRJk5SSkiI/Pz+lp6fr9ddfd+8vKyvTkSNHdPnyZXfba6+95u5bUlKitLQ0vfXWW94qEQAAWMgrz7lp7HjODQAAvqdBn3MDAADQUAg3AADAKoQbAABgFcINAACwCuEGAABYhXADAACsQrgBAABWIdwAAACrEG4AAIBVCDcAAMAqhBsAAGAVwg0AALAK4QYAAFiFcAMAAKxCuAEAAFYh3AAAAKsQbgAAgFUINwAAwCqEGwAAYBXCDQAAsArhBgAAWIVwAwAArEK4AQAAViHcAAAAqxBuAACAVQg3AADAKoQbAABgFcINAACwCuEGAABYhXADAACsQrgBAABWIdwAAACrEG4AAIBVCDcAAMAqhBsAAGAVwg0AALAK4QYAAFiFcAMAAKxCuAEAAFYh3AAAAKsQbgAAgFUINwAAwCoBDV1AQzDGSJKKi4sbuBIAAFBble/ble/j13NbhpsLFy5IkmJiYhq4EgAAcKMuXLig8PDw6+53mB+LPxaqqKjQ6dOn1bx5czkcjoYup8EVFxcrJiZGJ0+eVFhYWEOXYy3WuX6wzvWDda4frLMnY4wuXLig6Oho+fld/8qa2/LMjZ+fn9q3b9/QZTQ6YWFh/MtTD1jn+sE61w/WuX6wzn9X0xmbSlxQDAAArEK4AQAAViHcQMHBwZo7d66Cg4MbuhSrsc71g3WuH6xz/WCdb85teUExAACwF2duAACAVQg3AADAKoQbAABgFcINAACwCuHmNnD+/HllZmYqLCxMERERGjdunC5evFjjmKtXr2rixIlq2bKlmjVrpvT0dOXn51fb99tvv1X79u3lcDhUWFjohRn4Bm+s85dffqmMjAzFxMQoNDRUXbt21aJFi7w9lUZn8eLF6tixo0JCQpSUlKRdu3bV2H/t2rWKj49XSEiIevTooU2bNnnsN8Zozpw5atu2rUJDQ5WamqqvvvrKm1PwCXW5zmVlZZo+fbp69Oihpk2bKjo6WiNHjtTp06e9PY1Gr65/n//RhAkT5HA4tHDhwjqu2scYWG/AgAGmZ8+e5vPPPzfbt283Xbp0MRkZGTWOmTBhgomJiTE5OTlmz549pm/fvqZfv37V9h08eLC5//77jSTz3XffeWEGvsEb67x06VLz1FNPmb/85S/m66+/Nv/5n/9pQkNDzRtvvOHt6TQaq1evNkFBQSY7O9scOHDAPP744yYiIsLk5+dX2/+zzz4z/v7+5pVXXjEHDx40s2bNMoGBgWbfvn3uPv/+7/9uwsPDzfr1682XX35pHnroIdOpUydz5cqV+ppWo1PX61xYWGhSU1PNmjVrzOHDh01ubq5JTEw0vXr1qs9pNTre+H2u9P7775uePXua6Oho89prr3l5Jo0b4cZyBw8eNJLM7t273W1/+tOfjMPhMKdOnap2TGFhoQkMDDRr1651tx06dMhIMrm5uR5933rrLdO/f3+Tk5NzW4cbb6/zP3ryySfNvffeW3fFN3KJiYlm4sSJ7p/Ly8tNdHS0ycrKqrb/o48+agYOHOjRlpSUZJ544gljjDEVFRXG6XSa+fPnu/cXFhaa4OBg895773lhBr6hrte5Ort27TKSzPHjx+umaB/krXX+5ptvTLt27cz+/ftNhw4dbvtww8dSlsvNzVVERIR69+7tbktNTZWfn5927txZ7Zi8vDyVlZUpNTXV3RYfH6/Y2Fjl5ua62w4ePKh58+ZpxYoVNX6B2e3Am+v8Q0VFRYqMjKy74hux0tJS5eXleayRn5+fUlNTr7tGubm5Hv0lKS0tzd3/6NGjcrlcHn3Cw8OVlJRU47rbzBvrXJ2ioiI5HA5FRETUSd2+xlvrXFFRoREjRujZZ59Vt27dvFO8j7m935FuAy6XS23atPFoCwgIUGRkpFwu13XHBAUFVfkPUFRUlHtMSUmJMjIyNH/+fMXGxnqldl/irXX+oR07dmjNmjUaP358ndTd2J07d07l5eWKioryaK9pjVwuV439K/+8kde0nTfW+YeuXr2q6dOnKyMj47b9AkhvrfNvf/tbBQQE6Kmnnqr7on0U4cZHzZgxQw6Ho8bt8OHDXjv+zJkz1bVrVw0fPtxrx2gMGnqd/9H+/fs1ePBgzZ07Vz/72c/q5ZhAXSgrK9Ojjz4qY4zefvvthi7HKnl5eVq0aJGWL18uh8PR0OU0GgENXQBuztSpUzV69Oga+9xxxx1yOp0qKCjwaL927ZrOnz8vp9NZ7Tin06nS0lIVFhZ6nFXIz893j9m6dav27dundevWSfr+7hNJatWqlZ5//nm98MILNzmzxqWh17nSwYMHlZKSovHjx2vWrFk3NRdf1KpVK/n7+1e5U6+6NarkdDpr7F/5Z35+vtq2bevRJyEhoQ6r9x3eWOdKlcHm+PHj2rp162171kbyzjpv375dBQUFHmfQy8vLNXXqVC1cuFDHjh2r20n4ioa+6AfeVXmh6549e9xtf/7zn2t1oeu6devcbYcPH/a40PVvf/ub2bdvn3vLzs42ksyOHTuue9W/zby1zsYYs3//ftOmTRvz7LPPem8CjVhiYqKZNGmS++fy8nLTrl27Gi/AfPDBBz3akpOTq1xQvGDBAvf+oqIiLiiu43U2xpjS0lIzZMgQ061bN1NQUOCdwn1MXa/zuXPnPP5bvG/fPhMdHW2mT59uDh8+7L2JNHKEm9vAgAEDzD//8z+bnTt3mk8//dTExcV53KL8zTffmDvvvNPs3LnT3TZhwgQTGxtrtm7davbs2WOSk5NNcnLydY/x8ccf39Z3SxnjnXXet2+fad26tRk+fLg5c+aMe7ud3ihWr15tgoODzfLly83BgwfN+PHjTUREhHG5XMYYY0aMGGFmzJjh7v/ZZ5+ZgIAAs2DBAnPo0CEzd+7cam8Fj4iIMP/93/9t/vd//9cMHjyYW8HreJ1LS0vNQw89ZNq3b2/++te/evz+lpSUNMgcGwNv/D7/EHdLEW5uC99++63JyMgwzZo1M2FhYWbMmDHmwoUL7v1Hjx41kszHH3/sbrty5Yp58sknTYsWLUyTJk3Mww8/bM6cOXPdYxBuvLPOc+fONZKqbB06dKjHmTW8N954w8TGxpqgoCCTmJhoPv/8c/e+/v37m1GjRnn0/8Mf/mD+6Z/+yQQFBZlu3bqZjRs3euyvqKgws2fPNlFRUSY4ONikpKSYI0eO1MdUGrW6XOfK3/fqtn/8d+B2VNe/zz9EuDHGYcz/v1gCAADAAtwtBQAArEK4AQAAViHcAAAAqxBuAACAVQg3AADAKoQbAABgFcINAACwCuEGAABYhXADAACsQrgBAABWIdwAAACrEG4AAIBV/h/KrmaEItaIiQAAAABJRU5ErkJggg==",
+      "text/plain": [
+       "<Figure size 640x480 with 1 Axes>"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    }
+   ],
+   "source": [
+    "metrics = pd.read_csv(\"logs/MEGNet_training/version_0/metrics.csv\")\n",
+    "metrics[\"train_MAE\"].dropna().plot()\n",
+    "metrics[\"val_MAE\"].dropna().plot()\n",
+    "\n",
+    "_ = plt.legend()\n",
+    "#plt.savefig(\"loss.jpg\")"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 44,
+   "metadata": {},
+   "outputs": [
+    {
+     "data": {
+      "text/html": [
+       "<div>\n",
+       "<style scoped>\n",
+       "    .dataframe tbody tr th:only-of-type {\n",
+       "        vertical-align: middle;\n",
+       "    }\n",
+       "\n",
+       "    .dataframe tbody tr th {\n",
+       "        vertical-align: top;\n",
+       "    }\n",
+       "\n",
+       "    .dataframe thead th {\n",
+       "        text-align: right;\n",
+       "    }\n",
+       "</style>\n",
+       "<table border=\"1\" class=\"dataframe\">\n",
+       "  <thead>\n",
+       "    <tr style=\"text-align: right;\">\n",
+       "      <th></th>\n",
+       "      <th>epoch</th>\n",
+       "      <th>step</th>\n",
+       "      <th>train_MAE</th>\n",
+       "      <th>train_RMSE</th>\n",
+       "      <th>train_Total_Loss</th>\n",
+       "      <th>val_MAE</th>\n",
+       "      <th>val_RMSE</th>\n",
+       "      <th>val_Total_Loss</th>\n",
+       "    </tr>\n",
+       "  </thead>\n",
+       "  <tbody>\n",
+       "    <tr>\n",
+       "      <th>0</th>\n",
+       "      <td>0</td>\n",
+       "      <td>297</td>\n",
+       "      <td>NaN</td>\n",
+       "      <td>NaN</td>\n",
+       "      <td>NaN</td>\n",
+       "      <td>NaN</td>\n",
+       "      <td>NaN</td>\n",
+       "      <td>NaN</td>\n",
+       "    </tr>\n",
+       "    <tr>\n",
+       "      <th>1</th>\n",
+       "      <td>0</td>\n",
+       "      <td>297</td>\n",
+       "      <td>NaN</td>\n",
+       "      <td>NaN</td>\n",
+       "      <td>NaN</td>\n",
+       "      <td>NaN</td>\n",
+       "      <td>NaN</td>\n",
+       "      <td>NaN</td>\n",
+       "    </tr>\n",
+       "    <tr>\n",
+       "      <th>2</th>\n",
+       "      <td>1</td>\n",
+       "      <td>595</td>\n",
+       "      <td>NaN</td>\n",
+       "      <td>NaN</td>\n",
+       "      <td>NaN</td>\n",
+       "      <td>NaN</td>\n",
+       "      <td>NaN</td>\n",
+       "      <td>NaN</td>\n",
+       "    </tr>\n",
+       "    <tr>\n",
+       "      <th>3</th>\n",
+       "      <td>1</td>\n",
+       "      <td>595</td>\n",
+       "      <td>NaN</td>\n",
+       "      <td>NaN</td>\n",
+       "      <td>NaN</td>\n",
+       "      <td>NaN</td>\n",
+       "      <td>NaN</td>\n",
+       "      <td>NaN</td>\n",
+       "    </tr>\n",
+       "    <tr>\n",
+       "      <th>4</th>\n",
+       "      <td>2</td>\n",
+       "      <td>893</td>\n",
+       "      <td>NaN</td>\n",
+       "      <td>NaN</td>\n",
+       "      <td>NaN</td>\n",
+       "      <td>NaN</td>\n",
+       "      <td>NaN</td>\n",
+       "      <td>NaN</td>\n",
+       "    </tr>\n",
+       "    <tr>\n",
+       "      <th>5</th>\n",
+       "      <td>2</td>\n",
+       "      <td>893</td>\n",
+       "      <td>NaN</td>\n",
+       "      <td>NaN</td>\n",
+       "      <td>NaN</td>\n",
+       "      <td>NaN</td>\n",
+       "      <td>NaN</td>\n",
+       "      <td>NaN</td>\n",
+       "    </tr>\n",
+       "    <tr>\n",
+       "      <th>6</th>\n",
+       "      <td>3</td>\n",
+       "      <td>1191</td>\n",
+       "      <td>NaN</td>\n",
+       "      <td>NaN</td>\n",
+       "      <td>NaN</td>\n",
+       "      <td>NaN</td>\n",
+       "      <td>NaN</td>\n",
+       "      <td>NaN</td>\n",
+       "    </tr>\n",
+       "    <tr>\n",
+       "      <th>7</th>\n",
+       "      <td>3</td>\n",
+       "      <td>1191</td>\n",
+       "      <td>NaN</td>\n",
+       "      <td>NaN</td>\n",
+       "      <td>NaN</td>\n",
+       "      <td>NaN</td>\n",
+       "      <td>NaN</td>\n",
+       "      <td>NaN</td>\n",
+       "    </tr>\n",
+       "    <tr>\n",
+       "      <th>8</th>\n",
+       "      <td>4</td>\n",
+       "      <td>1489</td>\n",
+       "      <td>NaN</td>\n",
+       "      <td>NaN</td>\n",
+       "      <td>NaN</td>\n",
+       "      <td>NaN</td>\n",
+       "      <td>NaN</td>\n",
+       "      <td>NaN</td>\n",
+       "    </tr>\n",
+       "    <tr>\n",
+       "      <th>9</th>\n",
+       "      <td>4</td>\n",
+       "      <td>1489</td>\n",
+       "      <td>NaN</td>\n",
+       "      <td>NaN</td>\n",
+       "      <td>NaN</td>\n",
+       "      <td>NaN</td>\n",
+       "      <td>NaN</td>\n",
+       "      <td>NaN</td>\n",
+       "    </tr>\n",
+       "  </tbody>\n",
+       "</table>\n",
+       "</div>"
+      ],
+      "text/plain": [
+       "   epoch  step  train_MAE  train_RMSE  train_Total_Loss  val_MAE  val_RMSE  \\\n",
+       "0      0   297        NaN         NaN               NaN      NaN       NaN   \n",
+       "1      0   297        NaN         NaN               NaN      NaN       NaN   \n",
+       "2      1   595        NaN         NaN               NaN      NaN       NaN   \n",
+       "3      1   595        NaN         NaN               NaN      NaN       NaN   \n",
+       "4      2   893        NaN         NaN               NaN      NaN       NaN   \n",
+       "5      2   893        NaN         NaN               NaN      NaN       NaN   \n",
+       "6      3  1191        NaN         NaN               NaN      NaN       NaN   \n",
+       "7      3  1191        NaN         NaN               NaN      NaN       NaN   \n",
+       "8      4  1489        NaN         NaN               NaN      NaN       NaN   \n",
+       "9      4  1489        NaN         NaN               NaN      NaN       NaN   \n",
+       "\n",
+       "   val_Total_Loss  \n",
+       "0             NaN  \n",
+       "1             NaN  \n",
+       "2             NaN  \n",
+       "3             NaN  \n",
+       "4             NaN  \n",
+       "5             NaN  \n",
+       "6             NaN  \n",
+       "7             NaN  \n",
+       "8             NaN  \n",
+       "9             NaN  "
+      ]
+     },
+     "execution_count": 44,
+     "metadata": {},
+     "output_type": "execute_result"
+    }
+   ],
+   "source": [
+    "metrics"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "i=0\n",
+    "prediction=np.zeros(len(test_data))\n",
+    "for i in range(len(structures_test)):\n",
+    "    prediction[i]=model.predict_structure(structures_test[i])"
+   ]
+  }
+ ],
+ "metadata": {
+  "kernelspec": {
+   "display_name": "molcal",
+   "language": "python",
+   "name": "python3"
+  },
+  "language_info": {
+   "codemirror_mode": {
+    "name": "ipython",
+    "version": 3
+   },
+   "file_extension": ".py",
+   "mimetype": "text/x-python",
+   "name": "python",
+   "nbconvert_exporter": "python",
+   "pygments_lexer": "ipython3",
+   "version": "3.10.13"
+  }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 2
+}
-- 
GitLab