From 735f49670fd6f2505b3b16de795c48950d621c34 Mon Sep 17 00:00:00 2001 From: sgrover <shivani.grover@ed.ac.uk> Date: Mon, 4 Nov 2024 16:26:23 +0000 Subject: [PATCH] Upload Workshop notebook --- Workshop3/molcal_workshop3.ipynb | 982 +++++++++++++++++++++++++++++++ 1 file changed, 982 insertions(+) create mode 100644 Workshop3/molcal_workshop3.ipynb diff --git a/Workshop3/molcal_workshop3.ipynb b/Workshop3/molcal_workshop3.ipynb new file mode 100644 index 0000000..4228bc3 --- /dev/null +++ b/Workshop3/molcal_workshop3.ipynb @@ -0,0 +1,982 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 14, + "metadata": {}, + "outputs": [], + "source": [ + "#!pip install jarvis-tools" + ] + }, + { + "cell_type": "code", + "execution_count": 15, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Obtaining 3D dataset 76k ...\n", + "Reference:https://www.nature.com/articles/s41524-020-00440-1\n", + "Other versions:https://doi.org/10.6084/m9.figshare.6815699\n", + "Loading the zipfile...\n", + "Loading completed.\n" + ] + } + ], + "source": [ + "# !pip install jarvis-tools, and restart the notebook\n", + "from jarvis.db.figshare import data\n", + "\n", + "dft_3d = data('dft_3d')" + ] + }, + { + "cell_type": "code", + "execution_count": 16, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "dict_keys(['jid', 'spg_number', 'spg_symbol', 'formula', 'formation_energy_peratom', 'func', 'optb88vdw_bandgap', 'atoms', 'slme', 'magmom_oszicar', 'spillage', 'elastic_tensor', 'effective_masses_300K', 'kpoint_length_unit', 'maxdiff_mesh', 'maxdiff_bz', 'encut', 'optb88vdw_total_energy', 'epsx', 'epsy', 'epsz', 'mepsx', 'mepsy', 'mepsz', 'modes', 'magmom_outcar', 'max_efg', 'avg_elec_mass', 'avg_hole_mass', 'icsd', 'dfpt_piezo_max_eij', 'dfpt_piezo_max_dij', 'dfpt_piezo_max_dielectric', 'dfpt_piezo_max_dielectric_electronic', 'dfpt_piezo_max_dielectric_ionic', 'max_ir_mode', 'min_ir_mode', 'n-Seebeck', 'p-Seebeck', 'n-powerfact', 'p-powerfact', 'ncond', 'pcond', 'nkappa', 'pkappa', 'ehull', 'Tc_supercon', 'dimensionality', 'efg', 'xml_data_link', 'typ', 'exfoliation_energy', 'spg', 'crys', 'density', 'poisson', 'raw_files', 'nat', 'bulk_modulus_kv', 'shear_modulus_gv', 'mbj_bandgap', 'hse_gap', 'reference', 'search'])" + ] + }, + "execution_count": 16, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "dft_3d[0].keys()" + ] + }, + { + "cell_type": "code", + "execution_count": 17, + "metadata": {}, + "outputs": [], + "source": [ + "## Let's make a dataframe from this:\n", + "## !pip install pandas ## if it's not installed\n", + "import pandas as pd\n", + "import numpy as np " + ] + }, + { + "cell_type": "code", + "execution_count": 18, + "metadata": {}, + "outputs": [], + "source": [ + "df=pd.DataFrame(dft_3d)" + ] + }, + { + "cell_type": "code", + "execution_count": 19, + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "<div>\n", + "<style scoped>\n", + " .dataframe tbody tr th:only-of-type {\n", + " vertical-align: middle;\n", + " }\n", + "\n", + " .dataframe tbody tr th {\n", + " vertical-align: top;\n", + " }\n", + "\n", + " .dataframe thead th {\n", + " text-align: right;\n", + " }\n", + "</style>\n", + "<table border=\"1\" class=\"dataframe\">\n", + " <thead>\n", + " <tr style=\"text-align: right;\">\n", + " <th></th>\n", + " <th>jid</th>\n", + " <th>spg_number</th>\n", + " <th>spg_symbol</th>\n", + " <th>formula</th>\n", + " <th>formation_energy_peratom</th>\n", + " <th>func</th>\n", + " <th>optb88vdw_bandgap</th>\n", + " <th>atoms</th>\n", + " <th>slme</th>\n", + " <th>magmom_oszicar</th>\n", + " <th>...</th>\n", + " <th>density</th>\n", + " <th>poisson</th>\n", + " <th>raw_files</th>\n", + " <th>nat</th>\n", + " <th>bulk_modulus_kv</th>\n", + " <th>shear_modulus_gv</th>\n", + " <th>mbj_bandgap</th>\n", + " <th>hse_gap</th>\n", + " <th>reference</th>\n", + " <th>search</th>\n", + " </tr>\n", + " </thead>\n", + " <tbody>\n", + " <tr>\n", + " <th>0</th>\n", + " <td>JVASP-90856</td>\n", + " <td>129</td>\n", + " <td>P4/nmm</td>\n", + " <td>TiCuSiAs</td>\n", + " <td>-0.42762</td>\n", + " <td>OptB88vdW</td>\n", + " <td>0.000</td>\n", + " <td>{'lattice_mat': [[3.566933224304235, 0.0, -0.0...</td>\n", + " <td>na</td>\n", + " <td>0.0</td>\n", + " <td>...</td>\n", + " <td>5.956</td>\n", + " <td>na</td>\n", + " <td>[]</td>\n", + " <td>8</td>\n", + " <td>na</td>\n", + " <td>na</td>\n", + " <td>na</td>\n", + " <td>na</td>\n", + " <td>mp-1080455</td>\n", + " <td>-As-Cu-Si-Ti</td>\n", + " </tr>\n", + " <tr>\n", + " <th>1</th>\n", + " <td>JVASP-86097</td>\n", + " <td>221</td>\n", + " <td>Pm-3m</td>\n", + " <td>DyB6</td>\n", + " <td>-0.41596</td>\n", + " <td>OptB88vdW</td>\n", + " <td>0.000</td>\n", + " <td>{'lattice_mat': [[4.089078911208881, 0.0, 0.0]...</td>\n", + " <td>na</td>\n", + " <td>0.0</td>\n", + " <td>...</td>\n", + " <td>5.522</td>\n", + " <td>na</td>\n", + " <td>[OPT-LOPTICS,JVASP-86097.zip,https://ndownload...</td>\n", + " <td>7</td>\n", + " <td>na</td>\n", + " <td>na</td>\n", + " <td>na</td>\n", + " <td>na</td>\n", + " <td>mp-568319</td>\n", + " <td>-B-Dy</td>\n", + " </tr>\n", + " <tr>\n", + " <th>2</th>\n", + " <td>JVASP-64906</td>\n", + " <td>119</td>\n", + " <td>I-4m2</td>\n", + " <td>Be2OsRu</td>\n", + " <td>0.04847</td>\n", + " <td>OptB88vdW</td>\n", + " <td>0.000</td>\n", + " <td>{'lattice_mat': [[-1.833590720595598, 1.833590...</td>\n", + " <td>na</td>\n", + " <td>0.0</td>\n", + " <td>...</td>\n", + " <td>10.960</td>\n", + " <td>na</td>\n", + " <td>[OPT-LOPTICS,JVASP-64906.zip,https://ndownload...</td>\n", + " <td>4</td>\n", + " <td>na</td>\n", + " <td>na</td>\n", + " <td>na</td>\n", + " <td>na</td>\n", + " <td>auid-3eaf68dd483bf4f4</td>\n", + " <td>-Be-Os-Ru</td>\n", + " </tr>\n", + " <tr>\n", + " <th>3</th>\n", + " <td>JVASP-98225</td>\n", + " <td>14</td>\n", + " <td>P2_1/c</td>\n", + " <td>KBi</td>\n", + " <td>-0.44140</td>\n", + " <td>OptB88vdW</td>\n", + " <td>0.472</td>\n", + " <td>{'lattice_mat': [[7.2963518353359165, 0.0, 0.0...</td>\n", + " <td>na</td>\n", + " <td>0.0</td>\n", + " <td>...</td>\n", + " <td>5.145</td>\n", + " <td>na</td>\n", + " <td>[]</td>\n", + " <td>32</td>\n", + " <td>na</td>\n", + " <td>na</td>\n", + " <td>na</td>\n", + " <td>na</td>\n", + " <td>mp-31104</td>\n", + " <td>-Bi-K</td>\n", + " </tr>\n", + " <tr>\n", + " <th>4</th>\n", + " <td>JVASP-10</td>\n", + " <td>164</td>\n", + " <td>P-3m1</td>\n", + " <td>VSe2</td>\n", + " <td>-0.71026</td>\n", + " <td>OptB88vdW</td>\n", + " <td>0.000</td>\n", + " <td>{'lattice_mat': [[1.6777483798834445, -2.90594...</td>\n", + " <td>na</td>\n", + " <td>0.0</td>\n", + " <td>...</td>\n", + " <td>5.718</td>\n", + " <td>0.23</td>\n", + " <td>[FD-ELAST,JVASP-10.zip,https://ndownloader.fig...</td>\n", + " <td>3</td>\n", + " <td>48.79</td>\n", + " <td>33.05</td>\n", + " <td>0.0</td>\n", + " <td>na</td>\n", + " <td>mp-694</td>\n", + " <td>-Se-V</td>\n", + " </tr>\n", + " </tbody>\n", + "</table>\n", + "<p>5 rows × 64 columns</p>\n", + "</div>" + ], + "text/plain": [ + " jid spg_number spg_symbol formula formation_energy_peratom \\\n", + "0 JVASP-90856 129 P4/nmm TiCuSiAs -0.42762 \n", + "1 JVASP-86097 221 Pm-3m DyB6 -0.41596 \n", + "2 JVASP-64906 119 I-4m2 Be2OsRu 0.04847 \n", + "3 JVASP-98225 14 P2_1/c KBi -0.44140 \n", + "4 JVASP-10 164 P-3m1 VSe2 -0.71026 \n", + "\n", + " func optb88vdw_bandgap \\\n", + "0 OptB88vdW 0.000 \n", + "1 OptB88vdW 0.000 \n", + "2 OptB88vdW 0.000 \n", + "3 OptB88vdW 0.472 \n", + "4 OptB88vdW 0.000 \n", + "\n", + " atoms slme magmom_oszicar ... \\\n", + "0 {'lattice_mat': [[3.566933224304235, 0.0, -0.0... na 0.0 ... \n", + "1 {'lattice_mat': [[4.089078911208881, 0.0, 0.0]... na 0.0 ... \n", + "2 {'lattice_mat': [[-1.833590720595598, 1.833590... na 0.0 ... \n", + "3 {'lattice_mat': [[7.2963518353359165, 0.0, 0.0... na 0.0 ... \n", + "4 {'lattice_mat': [[1.6777483798834445, -2.90594... na 0.0 ... \n", + "\n", + " density poisson raw_files nat \\\n", + "0 5.956 na [] 8 \n", + "1 5.522 na [OPT-LOPTICS,JVASP-86097.zip,https://ndownload... 7 \n", + "2 10.960 na [OPT-LOPTICS,JVASP-64906.zip,https://ndownload... 4 \n", + "3 5.145 na [] 32 \n", + "4 5.718 0.23 [FD-ELAST,JVASP-10.zip,https://ndownloader.fig... 3 \n", + "\n", + " bulk_modulus_kv shear_modulus_gv mbj_bandgap hse_gap \\\n", + "0 na na na na \n", + "1 na na na na \n", + "2 na na na na \n", + "3 na na na na \n", + "4 48.79 33.05 0.0 na \n", + "\n", + " reference search \n", + "0 mp-1080455 -As-Cu-Si-Ti \n", + "1 mp-568319 -B-Dy \n", + "2 auid-3eaf68dd483bf4f4 -Be-Os-Ru \n", + "3 mp-31104 -Bi-K \n", + "4 mp-694 -Se-V \n", + "\n", + "[5 rows x 64 columns]" + ] + }, + "execution_count": 19, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "df.head()" + ] + }, + { + "cell_type": "code", + "execution_count": 20, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "jid 75993\n", + "spg_number 75993\n", + "spg_symbol 75993\n", + "formula 75993\n", + "formation_energy_peratom 75993\n", + "func 75993\n", + "optb88vdw_bandgap 75993\n", + "atoms 75993\n", + "slme 9770\n", + "magmom_oszicar 71320\n", + "spillage 11377\n", + "elastic_tensor 25513\n", + "effective_masses_300K 75993\n", + "kpoint_length_unit 75671\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/var/folders/p4/gl5hwwk13vjb1pncdz82sq4h0000gq/T/ipykernel_73354/3773379141.py:3: FutureWarning: Downcasting behavior in `replace` is deprecated and will be removed in a future version. To retain the old behavior, explicitly call `result.infer_objects(copy=False)`. To opt-in to the future behavior, set `pd.set_option('future.no_silent_downcasting', True)`\n", + " val=df[i].replace('na',np.nan).dropna().values\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "maxdiff_mesh 5861\n", + "maxdiff_bz 5861\n", + "encut 75670\n", + "optb88vdw_total_energy 75993\n", + "epsx 52168\n", + "epsy 52168\n", + "epsz 52168\n", + "mepsx 18293\n", + "mepsy 18293\n", + "mepsz 18293\n", + "modes 13910\n", + "magmom_outcar 74261\n", + "max_efg 11871\n", + "avg_elec_mass 17645\n", + "avg_hole_mass 17645\n", + "icsd 75993\n", + "dfpt_piezo_max_eij 4799\n", + "dfpt_piezo_max_dij 3347\n", + "dfpt_piezo_max_dielectric 4706\n", + "dfpt_piezo_max_dielectric_electronic 4809\n", + "dfpt_piezo_max_dielectric_ionic 4809\n", + "max_ir_mode 4805\n", + "min_ir_mode 4809\n", + "n-Seebeck 23218\n", + "p-Seebeck 23218\n", + "n-powerfact 23218\n", + "p-powerfact 23218\n", + "ncond 23218\n", + "pcond 23218\n", + "nkappa 23218\n", + "pkappa 23218\n", + "ehull 75993\n", + "Tc_supercon 1058\n", + "dimensionality 75560\n", + "efg 75993\n", + "xml_data_link 75993\n", + "typ 75993\n", + "exfoliation_energy 813\n", + "spg 75993\n", + "crys 75993\n", + "density 75993\n", + "poisson 23597\n", + "raw_files 75993\n", + "nat 75993\n", + "bulk_modulus_kv 23824\n", + "shear_modulus_gv 23824\n", + "mbj_bandgap 19805\n", + "hse_gap 56\n", + "reference 75993\n", + "search 75993\n" + ] + } + ], + "source": [ + "## Count number of entries for each property\n", + "for i in df.columns.values:\n", + " val=df[i].replace('na',np.nan).dropna().values\n", + " print(i,len(val))" + ] + }, + { + "cell_type": "code", + "execution_count": 21, + "metadata": {}, + "outputs": [], + "source": [ + "## Filter dataset based on desired property \n", + "## We will focus on elastic properties for today, i.e. Bulk modulus" + ] + }, + { + "cell_type": "code", + "execution_count": 22, + "metadata": {}, + "outputs": [], + "source": [ + "#!pip install pymatgen" + ] + }, + { + "cell_type": "code", + "execution_count": 23, + "metadata": {}, + "outputs": [], + "source": [ + "from jarvis.core.atoms import Atoms\n", + "bm=df[df.bulk_modulus_kv != 'na']\n", + "data = [(Atoms.from_dict(bm.iloc[i]['atoms']).pymatgen_converter(), bm.iloc[i].bulk_modulus_kv) for i in range(len(bm))]" + ] + }, + { + "cell_type": "code", + "execution_count": 24, + "metadata": {}, + "outputs": [], + "source": [ + "import itertools\n", + "def get_stoichiometry(elements):\n", + " return [(g[0], len(list(g[1]))) for g in itertools.groupby(elements)]" + ] + }, + { + "cell_type": "code", + "execution_count": 25, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "100%|██████████| 23824/23824 [00:02<00:00, 8443.33it/s]\n" + ] + } + ], + "source": [ + "## Use all the material dataset for training the bulk modulus\n", + "from tqdm import tqdm\n", + "\n", + "stoichs=[] #stoichiometry\n", + "bulk=[] #bulk modulus\n", + "for i in tqdm(range(len(bm))):\n", + " stoichs.append(Atoms.from_dict(bm.iloc[i]['atoms']).pymatgen_converter())\n", + " bulk.append(bm.iloc[i]['bulk_modulus_kv'])\n", + "data_ran=list(zip(stoichs,bulk))\n", + "\n", + "import pickle\n", + "with open('data_ran.pickle', 'wb') as f:\n", + " pickle.dump(data_ran, f)\n" + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [], + "source": [ + "### change environments from jarvis to matgl" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/Users/sgrover/anaconda3/envs/matgl-megnet/lib/python3.10/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n", + " from .autonotebook import tqdm as notebook_tqdm\n" + ] + } + ], + "source": [ + "from __future__ import annotations\n", + "\n", + "import os\n", + "import shutil\n", + "import warnings\n", + "import zipfile\n", + "import matgl\n", + "\n", + "import matplotlib.pyplot as plt\n", + "import pandas as pd\n", + "import pytorch_lightning as pl\n", + "import torch\n", + "import pickle\n", + "import numpy as np\n", + "from dgl.data.utils import split_dataset\n", + "from pymatgen.core import Structure\n", + "from pytorch_lightning.loggers import CSVLogger\n", + "from tqdm import tqdm\n", + "\n", + "from matgl.ext.pymatgen import Structure2Graph, get_element_list\n", + "from matgl.graph.data import MGLDataset, MGLDataLoader, collate_fn\n", + "from matgl.layers import BondExpansion\n", + "from matgl.models import MEGNet\n", + "from matgl.utils.io import RemoteFile\n", + "from matgl.utils.training import ModelLightningModule\n", + "\n", + "# To suppress warnings for clearer output\n", + "warnings.simplefilter(\"ignore\")" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": {}, + "outputs": [], + "source": [ + "data_ran=pd.read_pickle('./data_ran.pickle')" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "list" + ] + }, + "execution_count": 4, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "type(data_ran)" + ] + }, + { + "cell_type": "code", + "execution_count": 46, + "metadata": {}, + "outputs": [], + "source": [ + "import random\n", + "random.shuffle(data_ran)" + ] + }, + { + "cell_type": "code", + "execution_count": 47, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Full Formula (Al2 Cr2 Ge2)\n", + "Reduced Formula: AlCrGe\n", + "abc : 5.953030 4.915863 4.763705\n", + "angles: 75.887124 53.211418 50.901458\n", + "pbc : True True True\n", + "Sites (6)\n", + " # SP a b c\n", + "--- ---- -------- -------- --------\n", + " 0 Al 0.07707 0.42293 0.07707\n", + " 1 Al 0.42293 0.07707 0.42293\n", + " 2 Cr 0.75 0.75 0.75\n", + " 3 Cr 0 0 0\n", + " 4 Ge 0.669987 0.330013 0.669987\n", + " 5 Ge 0.330013 0.669987 0.330013 2.1178676265660163\n" + ] + } + ], + "source": [ + "import numpy as np\n", + "\n", + "structures=[d[0] for d in data_ran[:15000]]\n", + "targets=np.log10([d[1] for d in data_ran])\n", + "\n", + "print(structures[0],targets[0])" + ] + }, + { + "cell_type": "code", + "execution_count": 48, + "metadata": {}, + "outputs": [], + "source": [ + "# get element types in the dataset\n", + "elem_list = get_element_list(structures)\n", + "# setup a graph converter\n", + "converter = Structure2Graph(element_types=elem_list, cutoff=4.0)\n", + "# convert the raw dataset into MEGNetDataset\n", + "mp_dataset = MGLDataset(\n", + " structures=structures,\n", + " labels={\"bulk_modulus_kv\": targets},\n", + " converter=converter,\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": 49, + "metadata": {}, + "outputs": [], + "source": [ + "train_data, val_data, test_data = split_dataset(\n", + " mp_dataset,\n", + " frac_list=[0.8, 0.1, 0.1],\n", + " shuffle=True,\n", + " random_state=42,\n", + ")\n", + "train_loader, val_loader, test_loader = MGLDataLoader(\n", + " train_data=train_data,\n", + " val_data=val_data,\n", + " test_data=test_data,\n", + " collate_fn=collate_fn,\n", + " batch_size=64,\n", + " num_workers=0,\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": 50, + "metadata": {}, + "outputs": [], + "source": [ + "# setup the embedding layer for node attributes\n", + "node_embed = torch.nn.Embedding(len(elem_list), 16)\n", + "# define the bond expansion\n", + "bond_expansion = BondExpansion(rbf_type=\"Gaussian\", initial=0.0, final=5.0, num_centers=100, width=0.5)\n", + "\n", + "# setup the architecture of MEGNet model\n", + "model = MEGNet(\n", + " dim_node_embedding=16,\n", + " dim_edge_embedding=100,\n", + " dim_state_embedding=2,\n", + " nblocks=3,\n", + " hidden_layer_sizes_input=(64, 32),\n", + " hidden_layer_sizes_conv=(64, 64, 32),\n", + " nlayers_set2set=1,\n", + " niters_set2set=2,\n", + " hidden_layer_sizes_output=(32, 16),\n", + " is_classification=False,\n", + " activation_type=\"softplus2\",\n", + " bond_expansion=bond_expansion,\n", + " cutoff=4.0,\n", + " gauss_width=0.5,\n", + ")\n", + "\n", + "# setup the MEGNetTrainer\n", + "lit_module = ModelLightningModule(model=model)" + ] + }, + { + "cell_type": "code", + "execution_count": 51, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "GPU available: True (mps), used: False\n", + "TPU available: False, using: 0 TPU cores\n", + "IPU available: False, using: 0 IPUs\n", + "HPU available: False, using: 0 HPUs\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "\n", + " | Name | Type | Params\n", + "--------------------------------------------\n", + "0 | model | MEGNet | 189 K \n", + "1 | mae | MeanAbsoluteError | 0 \n", + "2 | rmse | MeanSquaredError | 0 \n", + "--------------------------------------------\n", + "189 K Trainable params\n", + "100 Non-trainable params\n", + "189 K Total params\n", + "0.758 Total estimated model params size (MB)\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 39: 100%|██████████| 298/298 [00:24<00:00, 11.99it/s, v_num=6, val_Total_Loss=nan.0, val_MAE=nan.0, val_RMSE=nan.0, train_Total_Loss=nan.0, train_MAE=nan.0, train_RMSE=nan.0]" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "`Trainer.fit` stopped: `max_epochs=40` reached.\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 39: 100%|██████████| 298/298 [00:24<00:00, 11.97it/s, v_num=6, val_Total_Loss=nan.0, val_MAE=nan.0, val_RMSE=nan.0, train_Total_Loss=nan.0, train_MAE=nan.0, train_RMSE=nan.0]\n" + ] + } + ], + "source": [ + "logger = CSVLogger(\"logs\", name=\"MEGNet_training\")\n", + "trainer = pl.Trainer(max_epochs=40, accelerator=\"cpu\", logger=logger)\n", + "trainer.fit(model=lit_module, train_dataloaders=train_loader, val_dataloaders=val_loader)\n", + "\n", + "warnings.simplefilter(\"ignore\")" + ] + }, + { + "cell_type": "code", + "execution_count": 52, + "metadata": {}, + "outputs": [ + { + "data": { + "image/png": "iVBORw0KGgoAAAANSUhEUgAAAjcAAAGdCAYAAADuR1K7AAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjguMywgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/H5lhTAAAACXBIWXMAAA9hAAAPYQGoP6dpAAAqKElEQVR4nO3de3SV1Z3/8c/JPVySEC45BBJAyBgESqZAQliuQU2mQRFBQ8VMuLNEKqAWsIByqVjNVLCCirJcQ6CMIBRGGQuUDgaLKJFLsA73sS5uAicBMQnXJCT794e/nPaYEAPkJDmb92utZ2H2s/d5vnsTPR+f8zzPcRhjjAAAACzh19AFAAAA1CXCDQAAsArhBgAAWIVwAwAArEK4AQAAViHcAAAAqxBuAACAVQg3AADAKgENXUBDqKio0OnTp9W8eXM5HI6GLgcAANSCMUYXLlxQdHS0/Pyuf37mtgw3p0+fVkxMTEOXAQAAbsLJkyfVvn376+6/LcNN8+bNJX2/OGFhYQ1cDQAAqI3i4mLFxMS438ev57YMN5UfRYWFhRFuAADwMT92SQkXFAMAAKsQbgAAgFUINwAAwCq35TU3AADfZ4zRtWvXVF5e3tCloI74+/srICDglh/TQrgBAPic0tJSnTlzRpcvX27oUlDHmjRporZt2yooKOimX4NwAwDwKRUVFTp69Kj8/f0VHR2toKAgHshqAWOMSktLdfbsWR09elRxcXE1PqivJoQbAIBPKS0tVUVFhWJiYtSkSZOGLgd1KDQ0VIGBgTp+/LhKS0sVEhJyU6/DBcUAAJ90s/9Xj8atLv5e+c0AAABWIdwAAACrEG4AAPBBHTt21MKFCxu6jEaJcAMAQD2555579Mwzz9TJa+3evVvjx4+vk9c6duyYHA6H/P39derUKY99Z86ccT975tixY1XGpqWlyd/fX7t3766yb/To0XI4HFW2AQMG1End10O4AQCgkah8MGFttG7dus7vFmvXrp1WrFjh0fb73/9e7dq1q7b/iRMntGPHDk2aNEnZ2dnV9hkwYIDOnDnjsb333nt1WvcPEW4AAD7PGKPLpdcaZDPG1KrG0aNHa9u2bVq0aJH7DMby5cvlcDj0pz/9Sb169VJwcLA+/fRTff311xo8eLCioqLUrFkz9enTRx999JHH6/3wYymHw6H/+I//0MMPP6wmTZooLi5OH3744Q2t46hRo7Rs2TKPtmXLlmnUqFHV9l+2bJkefPBB/eIXv9B7772nK1euVOkTHBwsp9PpsbVo0eKG6rpRPOcGAODzrpSV6645f26QYx+cl6YmQT/+drpo0SL93//9n7p376558+ZJkg4cOCBJmjFjhhYsWKA77rhDLVq00MmTJ/XAAw/opZdeUnBwsFasWKFBgwbpyJEjio2Nve4xXnjhBb3yyiuaP3++3njjDWVmZur48eOKjIys1VweeughLVmyRJ9++qnuvvtuffrpp/ruu+80aNAgvfjiix59jTFatmyZFi9erPj4eHXp0kXr1q3TiBEjanUsb+LMDQAA9SA8PFxBQUFq0qSJ+wyGv7+/JGnevHn613/9V3Xu3FmRkZHq2bOnnnjiCXXv3l1xcXF68cUX1blz5x89EzN69GhlZGSoS5cuevnll3Xx4kXt2rWr1jUGBgZq+PDh7o+YsrOzNXz4cAUGBlbp+9FHH+ny5ctKS0uTJA0fPlxLly6t0m/Dhg1q1qyZx/byyy/XuqabwZkbAIDPCw3018F5aQ127FvVu3dvj58vXryoX//619q4caPOnDmja9eu6cqVKzpx4kSNr/OTn/zE/c9NmzZVWFiYCgoKbqiWsWPHql+/fnr55Ze1du1a5ebmVnsdUHZ2toYNG6aAgO+jREZGhp599ll9/fXX6ty5s7vfvffeq7fffttjbG3PJN0swg0AwOc5HI5afTTUWDVt2tTj52nTpmnLli1asGCBunTpotDQUA0dOlSlpaU1vs4Pz7A4HA5VVFTcUC09evRQfHy8MjIy1LVrV3Xv3l1//etfPfqcP39eH3zwgcrKyjyCS3l5ubKzs/XSSy95zK1Lly43VMOt8t3fBAAAfExQUJDKy8t/tN9nn32m0aNH6+GHH5b0/Zmc6m7D9paxY8fqySefrHLGpdLKlSvVvn17rV+/3qP9f/7nf/Tqq69q3rx57o/cGgLhBgCAetKxY0ft3LlTx44dU7Nmza57ViUuLk7vv/++Bg0aJIfDodmzZ9/wGZhb8fjjj+vnP/+5IiIiqt2/dOlSDR06VN27d/doj4mJ0cyZM7V582YNHDhQklRSUiKXy+XRLyAgQK1atfJK7RIXFAMAUG+mTZsmf39/3XXXXWrduvV1r6H53e9+pxYtWqhfv34aNGiQ0tLS9NOf/rTe6qwMH5XX0/yjvLw8ffnll0pPT6+yLzw8XCkpKR4XFm/evFlt27b12O6++26v1u8wtb1B3yLFxcUKDw9XUVGRwsLCGrocAMANuHr1qo4ePapOnTopJCSkoctBHavp77e279+cuQEAAFYh3AAAYLkJEyZUedZM5TZhwoSGLq/OcUExAACWmzdvnqZNm1btPhsvzyDcAABguTZt2qhNmzYNXUa94WMpAABgFcINAACwCuEGAABYhXADAACsQrgBAABWIdwAAOAjOnbsqIULFzZ0GY0e4QYAAAt17NhRDodDq1evrrKvW7ducjgcWr58eZV9WVlZ8vf31/z586vsW758uRwOR5WtsX0NBuEGAABLxcTEaNmyZR5tn3/+uVwul5o2bVrtmOzsbP3qV79SdnZ2tfvDwsJ05swZj+348eN1XvutINwAAHyfMVLppYbZavn90++8846io6NVUVHh0T548GCNHTtWX3/9tQYPHqyoqCg1a9ZMffr00UcffXRLy5KZmalt27bp5MmT7rbs7GxlZmZW+43f27Zt05UrVzRv3jwVFxdrx44dVfo4HA45nU6PLSoq6pbqrGs8oRgA4PvKLksvRzfMsZ87LQVVfxbkH/385z/X5MmT9fHHHyslJUWSdP78eW3evFmbNm3SxYsX9cADD+ill15ScHCwVqxYoUGDBunIkSOKjY29qdKioqKUlpam3//+95o1a5YuX76sNWvWaNu2bVqxYkWV/kuXLlVGRoYCAwOVkZGhpUuXql+/fjd17IbEmRsAAOpBixYtdP/992vVqlXutnXr1qlVq1a699571bNnTz3xxBPq3r274uLi9OKLL6pz58768MMPb+m4Y8eO1fLly2WM0bp169S5c2clJCRU6VdcXKx169Zp+PDhkqThw4frD3/4gy5evOjRr6ioqMqXb95///23VGNd48wNAMD3BTb5/gxKQx27ljIzM/X444/rrbfeUnBwsFauXKnHHntMfn5+unjxon79619r48aNOnPmjK5du6YrV67oxIkTt1TewIED9cQTT+iTTz5Rdna2xo4dW22/9957T507d1bPnj0lSQkJCerQoYPWrFmjcePGufs1b95ce/fu9RgbGhp6SzXWNcINAMD3ORy1+miooQ0aNEjGGG3cuFF9+vTR9u3b9dprr0mSpk2bpi1btmjBggXq0qWLQkNDNXToUJWWlt7SMQMCAjRixAjNnTtXO3fu1AcffFBtv6VLl+rAgQMe1+JUVFQoOzvbI9z4+fmpS5cut1STtxFuAACoJyEhIXrkkUe0cuVK/e1vf9Odd96pn/70p5Kkzz77TKNHj9bDDz8sSbp48aKOHTtWJ8cdO3asFixYoGHDhqlFixZV9u/bt0979uzRX/7yF0VGRrrbz58/r3vuuUeHDx9WfHx8ndRSHwg3AADUo8zMTD344IM6cOCA+/oWSYqLi9P777+vQYMGyeFwaPbs2VXurLpZXbt21blz59SkSfUfoS1dulSJiYn6l3/5lyr7+vTpo6VLl7qfe2OMkcvlqtKvTZs28vNrHJfyNo4qAAC4Tdx3332KjIzUkSNH9G//9m/u9t/97ndq0aKF+vXrp0GDBiktLc19VqcutGzZstprY0pLS/Xuu+8qPT292nHp6elasWKFysrKJH1/4XHbtm2rbAUFBXVW661yGFPLG/QtUlxcrPDwcBUVFSksLKyhywEA3ICrV6/q6NGj6tSpU6N7Mi5uXU1/v7V9/66XMzeLFy9Wx44dFRISoqSkJO3atavG/mvXrlV8fLxCQkLUo0cPbdq06bp9J0yYIIfDwXdtAAAASfUQbtasWaMpU6Zo7ty52rt3r3r27Km0tLTrnr7asWOHMjIyNG7cOH3xxRcaMmSIhgwZov3791fp+8EHH+jzzz9XdHQDPbgJAIAGsHLlyirPmqncunXr1tDlNTivfyyVlJSkPn366M0335T0/W1lMTExmjx5smbMmFGl/7Bhw3Tp0iVt2LDB3da3b18lJCRoyZIl7rZTp04pKSlJf/7znzVw4EA988wzeuaZZ2pVEx9LAYDv4mMp6cKFC8rPz692X2BgoDp06FDPFdWduvhYyqt3S5WWliovL08zZ850t/n5+Sk1NVW5ubnVjsnNzdWUKVM82tLS0rR+/Xr3zxUVFRoxYoSeffbZWiXUkpISlZSUuH8uLi6+wZkAANB4NG/eXM2bN2/oMhotr34sde7cOZWXl1f5Qq2oqKhqbyOTJJfL9aP9f/vb3yogIEBPPfVUrerIyspSeHi4e4uJibnBmQAAGpvb8H6Y20Jd/L363K3geXl5WrRokZYvXy6Hw1GrMTNnzlRRUZF7+8dvRwUA+JbAwEBJ0uXLlxu4EnhD5d9r5d/zzfDqx1KtWrWSv79/lc8F8/Pz5XQ6qx3jdDpr7L99+3YVFBR4fENqeXm5pk6dqoULF1b7NMfg4GAFBwff4mwAAI2Bv7+/IiIi3DemNGnSpNb/s4vGyxijy5cvq6CgQBEREfL397/p1/JquAkKClKvXr2Uk5OjIUOGSPr+epmcnBxNmjSp2jHJycnKycnxuDh4y5YtSk5OliSNGDFCqampHmPS0tI0YsQIjRkzxivzAAA0LpX/w9uYHhyHuhEREXHdEyC15fWvX5gyZYpGjRql3r17KzExUQsXLtSlS5fcQWTkyJFq166dsrKyJElPP/20+vfvr1dffVUDBw7U6tWrtWfPHr3zzjuSvn/CYsuWLT2OERgYKKfTqTvvvNPb0wEANAIOh0Nt27ZVmzZt3E/Ohe8LDAy8pTM2lbweboYNG6azZ89qzpw5crlcSkhI0ObNm90XDZ84ccLjuyj69eunVatWadasWXruuecUFxen9evXq3v37t4uFQDgY/z9/evkzRB24esXeM4NAAA+oVF9/QIAAEB9IdwAAACrEG4AAIBVCDcAAMAqhBsAAGAVwg0AALAK4QYAAFiFcAMAAKxCuAEAAFYh3AAAAKsQbgAAgFUINwAAwCqEGwAAYBXCDQAAsArhBgAAWIVwAwAArEK4AQAAViHcAAAAqxBuAACAVQg3AADAKoQbAABgFcINAACwCuEGAABYhXADAACsQrgBAABWIdwAAACrEG4AAIBVCDcAAMAqhBsAAGAVwg0AALAK4QYAAFiFcAMAAKxCuAEAAFYh3AAAAKsQbgAAgFUINwAAwCqEGwAAYBXCDQAAsArhBgAAWIVwAwAArEK4AQAAViHcAAAAqxBuAACAVQg3AADAKoQbAABgFcINAACwCuEGAABYhXADAACsQrgBAABWIdwAAACrEG4AAIBVCDcAAMAqhBsAAGAVwg0AALAK4QYAAFiFcAMAAKxSL+Fm8eLF6tixo0JCQpSUlKRdu3bV2H/t2rWKj49XSEiIevTooU2bNrn3lZWVafr06erRo4eaNm2q6OhojRw5UqdPn/b2NAAAgA/werhZs2aNpkyZorlz52rv3r3q2bOn0tLSVFBQUG3/HTt2KCMjQ+PGjdMXX3yhIUOGaMiQIdq/f78k6fLly9q7d69mz56tvXv36v3339eRI0f00EMPeXsqAADABziMMcabB0hKSlKfPn305ptvSpIqKioUExOjyZMna8aMGVX6Dxs2TJcuXdKGDRvcbX379lVCQoKWLFlS7TF2796txMREHT9+XLGxsT9aU3FxscLDw1VUVKSwsLCbnBkAAKhPtX3/9uqZm9LSUuXl5Sk1NfXvB/TzU2pqqnJzc6sdk5ub69FfktLS0q7bX5KKiorkcDgUERFR7f6SkhIVFxd7bAAAwE5eDTfnzp1TeXm5oqKiPNqjoqLkcrmqHeNyuW6o/9WrVzV9+nRlZGRcN8VlZWUpPDzcvcXExNzEbAAAgC/w6bulysrK9Oijj8oYo7fffvu6/WbOnKmioiL3dvLkyXqsEgAA1KcAb754q1at5O/vr/z8fI/2/Px8OZ3Oasc4nc5a9a8MNsePH9fWrVtr/OwtODhYwcHBNzkLAADgS7x65iYoKEi9evVSTk6Ou62iokI5OTlKTk6udkxycrJHf0nasmWLR//KYPPVV1/po48+UsuWLb0zAQAA4HO8euZGkqZMmaJRo0apd+/eSkxM1MKFC3Xp0iWNGTNGkjRy5Ei1a9dOWVlZkqSnn35a/fv316uvvqqBAwdq9erV2rNnj9555x1J3weboUOHau/evdqwYYPKy8vd1+NERkYqKCjI21MCAACNmNfDzbBhw3T27FnNmTNHLpdLCQkJ2rx5s/ui4RMnTsjP7+8nkPr166dVq1Zp1qxZeu655xQXF6f169ere/fukqRTp07pww8/lCQlJCR4HOvjjz/WPffc4+0pAQCARszrz7lpjHjODQAAvqdRPOcGAACgvhFuAACAVQg3AADAKoQbAABgFcINAACwCuEGAABYhXADAACsQrgBAABWIdwAAACrEG4AAIBVCDcAAMAqhBsAAGAVwg0AALAK4QYAAFiFcAMAAKxCuAEAAFYh3AAAAKsQbgAAgFUINwAAwCqEGwAAYBXCDQAAsArhBgAAWIVwAwAArEK4AQAAViHcAAAAqxBuAACAVQg3AADAKoQbAABgFcINAACwCuEGAABYhXADAACsQrgBAABWIdwAAACrEG4AAIBVCDcAAMAqhBsAAGAVwg0AALAK4QYAAFiFcAMAAKxCuAEAAFYh3AAAAKsQbgAAgFUINwAAwCqEGwAAYBXCDQAAsArhBgAAWIVwAwAArEK4AQAAViHcAAAAqxBuAACAVQg3AADAKoQbAABgFcINAACwCuEGAABYhXADAACsQrgBAABWqZdws3jxYnXs2FEhISFKSkrSrl27auy/du1axcfHKyQkRD169NCmTZs89htjNGfOHLVt21ahoaFKTU3VV1995c0pAAAAH+H1cLNmzRpNmTJFc+fO1d69e9WzZ0+lpaWpoKCg2v47duxQRkaGxo0bpy+++EJDhgzRkCFDtH//fnefV155Ra+//rqWLFminTt3qmnTpkpLS9PVq1e9PR0AANDIOYwxxpsHSEpKUp8+ffTmm29KkioqKhQTE6PJkydrxowZVfoPGzZMly5d0oYNG9xtffv2VUJCgpYsWSJjjKKjozV16lRNmzZNklRUVKSoqCgtX75cjz322I/WVFxcrPDwcBUVFSksLKyOZgoAALyptu/fXj1zU1paqry8PKWmpv79gH5+Sk1NVW5ubrVjcnNzPfpLUlpamrv/0aNH5XK5PPqEh4crKSnpuq9ZUlKi4uJijw0AANjJq+Hm3LlzKi8vV1RUlEd7VFSUXC5XtWNcLleN/Sv/vJHXzMrKUnh4uHuLiYm5qfkAAIDG77a4W2rmzJkqKipybydPnmzokgAAgJd4Ndy0atVK/v7+ys/P92jPz8+X0+msdozT6ayxf+WfN/KawcHBCgsL89gAAICdvBpugoKC1KtXL+Xk5LjbKioqlJOTo+Tk5GrHJCcne/SXpC1btrj7d+rUSU6n06NPcXGxdu7ced3XBAAAt48Abx9gypQpGjVqlHr37q3ExEQtXLhQly5d0pgxYyRJI0eOVLt27ZSVlSVJevrpp9W/f3+9+uqrGjhwoFavXq09e/bonXfekSQ5HA4988wz+s1vfqO4uDh16tRJs2fPVnR0tIYMGeLt6QAAgEbO6+Fm2LBhOnv2rObMmSOXy6WEhARt3rzZfUHwiRMn5Of39xNI/fr106pVqzRr1iw999xziouL0/r169W9e3d3n1/96le6dOmSxo8fr8LCQt19993avHmzQkJCvD0dAADQyHn9OTeNEc+5AQDA9zSK59wAAADUN8INAACwCuEGAABYhXADAACsQrgBAABWIdwAAACrEG4AAIBVCDcAAMAqhBsAAGAVwg0AALAK4QYAAFiFcAMAAKxCuAEAAFYh3AAAAKsQbgAAgFUINwAAwCqEGwAAYBXCDQAAsArhBgAAWIVwAwAArEK4AQAAViHcAAAAqxBuAACAVQg3AADAKoQbAABgFcINAACwCuEGAABYhXADAACsQrgBAABWIdwAAACrEG4AAIBVCDcAAMAqhBsAAGAVwg0AALAK4QYAAFiFcAMAAKxCuAEAAFYh3AAAAKsQbgAAgFUINwAAwCqEGwAAYBXCDQAAsArhBgAAWIVwAwAArEK4AQAAViHcAAAAqxBuAACAVQg3AADAKoQbAABgFcINAACwCuEGAABYhXADAACsQrgBAABWIdwAAACrEG4AAIBVCDcAAMAqXgs358+fV2ZmpsLCwhQREaFx48bp4sWLNY65evWqJk6cqJYtW6pZs2ZKT09Xfn6+e/+XX36pjIwMxcTEKDQ0VF27dtWiRYu8NQUAAOCDvBZuMjMzdeDAAW3ZskUbNmzQJ598ovHjx9c45pe//KX++Mc/au3atdq2bZtOnz6tRx55xL0/Ly9Pbdq00bvvvqsDBw7o+eef18yZM/Xmm296axoAAMDHOIwxpq5f9NChQ7rrrru0e/du9e7dW5K0efNmPfDAA/rmm28UHR1dZUxRUZFat26tVatWaejQoZKkw4cPq2vXrsrNzVXfvn2rPdbEiRN16NAhbd26tdb1FRcXKzw8XEVFRQoLC7uJGQIAgPpW2/dvr5y5yc3NVUREhDvYSFJqaqr8/Py0c+fOasfk5eWprKxMqamp7rb4+HjFxsYqNzf3uscqKipSZGRk3RUPAAB8WoA3XtTlcqlNmzaeBwoIUGRkpFwu13XHBAUFKSIiwqM9KirqumN27NihNWvWaOPGjTXWU1JSopKSEvfPxcXFtZgFAADwRTd05mbGjBlyOBw1bocPH/ZWrR7279+vwYMHa+7cufrZz35WY9+srCyFh4e7t5iYmHqpEQAA1L8bOnMzdepUjR49usY+d9xxh5xOpwoKCjzar127pvPnz8vpdFY7zul0qrS0VIWFhR5nb/Lz86uMOXjwoFJSUjR+/HjNmjXrR+ueOXOmpkyZ4v65uLiYgAMAgKVuKNy0bt1arVu3/tF+ycnJKiwsVF5ennr16iVJ2rp1qyoqKpSUlFTtmF69eikwMFA5OTlKT0+XJB05ckQnTpxQcnKyu9+BAwd03333adSoUXrppZdqVXdwcLCCg4Nr1RcAAPg2r9wtJUn333+/8vPztWTJEpWVlWnMmDHq3bu3Vq1aJUk6deqUUlJStGLFCiUmJkqSfvGLX2jTpk1avny5wsLCNHnyZEnfX1sjff9R1H333ae0tDTNnz/ffSx/f/9aha5K3C0FAIDvqe37t1cuKJaklStXatKkSUpJSZGfn5/S09P1+uuvu/eXlZXpyJEjunz5srvttddec/ctKSlRWlqa3nrrLff+devW6ezZs3r33Xf17rvvuts7dOigY8eOeWsqAADAh3jtzE1jxpkbAAB8T4M+5wYAAKChEG4AAIBVCDcAAMAqhBsAAGAVwg0AALAK4QYAAFiFcAMAAKxCuAEAAFYh3AAAAKsQbgAAgFUINwAAwCqEGwAAYBXCDQAAsArhBgAAWIVwAwAArEK4AQAAViHcAAAAqxBuAACAVQg3AADAKoQbAABgFcINAACwCuEGAABYhXADAACsQrgBAABWIdwAAACrEG4AAIBVCDcAAMAqhBsAAGAVwg0AALAK4QYAAFiFcAMAAKxCuAEAAFYh3AAAAKsQbgAAgFUINwAAwCqEGwAAYBXCDQAAsArhBgAAWIVwAwAArEK4AQAAViHcAAAAqxBuAACAVQg3AADAKoQbAABgFcINAACwCuEGAABYhXADAACsQrgBAABWIdwAAACrEG4AAIBVCDcAAMAqhBsAAGAVwg0AALAK4QYAAFiFcAMAAKxCuAEAAFYh3AAAAKt4LdycP39emZmZCgsLU0REhMaNG6eLFy/WOObq1auaOHGiWrZsqWbNmik9PV35+fnV9v3222/Vvn17ORwOFRYWemEGAADAF3kt3GRmZurAgQPasmWLNmzYoE8++UTjx4+vccwvf/lL/fGPf9TatWu1bds2nT59Wo888ki1fceNG6ef/OQn3igdAAD4MIcxxtT1ix46dEh33XWXdu/erd69e0uSNm/erAceeEDffPONoqOjq4wpKipS69attWrVKg0dOlSSdPjwYXXt2lW5ubnq27evu+/bb7+tNWvWaM6cOUpJSdF3332niIiIWtdXXFys8PBwFRUVKSws7NYmCwAA6kVt37+9cuYmNzdXERER7mAjSampqfLz89POnTurHZOXl6eysjKlpqa62+Lj4xUbG6vc3Fx328GDBzVv3jytWLFCfn61K7+kpETFxcUeGwAAsJNXwo3L5VKbNm082gICAhQZGSmXy3XdMUFBQVXOwERFRbnHlJSUKCMjQ/Pnz1dsbGyt68nKylJ4eLh7i4mJubEJAQAAn3FD4WbGjBlyOBw1bocPH/ZWrZo5c6a6du2q4cOH3/C4oqIi93by5EkvVQgAABpawI10njp1qkaPHl1jnzvuuENOp1MFBQUe7deuXdP58+fldDqrHed0OlVaWqrCwkKPszf5+fnuMVu3btW+ffu0bt06SVLl5UKtWrXS888/rxdeeKHa1w4ODlZwcHBtpggAAHzcDYWb1q1bq3Xr1j/aLzk5WYWFhcrLy1OvXr0kfR9MKioqlJSUVO2YXr16KTAwUDk5OUpPT5ckHTlyRCdOnFBycrIk6b/+67905coV95jdu3dr7Nix2r59uzp37nwjUwEAAJa6oXBTW127dtWAAQP0+OOPa8mSJSorK9OkSZP02GOPue+UOnXqlFJSUrRixQolJiYqPDxc48aN05QpUxQZGamwsDBNnjxZycnJ7julfhhgzp075z7ejdwtBQAA7OWVcCNJK1eu1KRJk5SSkiI/Pz+lp6fr9ddfd+8vKyvTkSNHdPnyZXfba6+95u5bUlKitLQ0vfXWW94qEQAAWMgrz7lp7HjODQAAvqdBn3MDAADQUAg3AADAKoQbAABgFcINAACwCuEGAABYhXADAACsQrgBAABWIdwAAACrEG4AAIBVCDcAAMAqhBsAAGAVwg0AALAK4QYAAFiFcAMAAKxCuAEAAFYh3AAAAKsQbgAAgFUINwAAwCqEGwAAYBXCDQAAsArhBgAAWIVwAwAArEK4AQAAViHcAAAAqxBuAACAVQg3AADAKoQbAABgFcINAACwCuEGAABYhXADAACsQrgBAABWIdwAAACrEG4AAIBVCDcAAMAqhBsAAGAVwg0AALAK4QYAAFiFcAMAAKxCuAEAAFYh3AAAAKsQbgAAgFUINwAAwCoBDV1AQzDGSJKKi4sbuBIAAFBble/ble/j13NbhpsLFy5IkmJiYhq4EgAAcKMuXLig8PDw6+53mB+LPxaqqKjQ6dOn1bx5czkcjoYup8EVFxcrJiZGJ0+eVFhYWEOXYy3WuX6wzvWDda4frLMnY4wuXLig6Oho+fld/8qa2/LMjZ+fn9q3b9/QZTQ6YWFh/MtTD1jn+sE61w/WuX6wzn9X0xmbSlxQDAAArEK4AQAAViHcQMHBwZo7d66Cg4MbuhSrsc71g3WuH6xz/WCdb85teUExAACwF2duAACAVQg3AADAKoQbAABgFcINAACwCuHmNnD+/HllZmYqLCxMERERGjdunC5evFjjmKtXr2rixIlq2bKlmjVrpvT0dOXn51fb99tvv1X79u3lcDhUWFjohRn4Bm+s85dffqmMjAzFxMQoNDRUXbt21aJFi7w9lUZn8eLF6tixo0JCQpSUlKRdu3bV2H/t2rWKj49XSEiIevTooU2bNnnsN8Zozpw5atu2rUJDQ5WamqqvvvrKm1PwCXW5zmVlZZo+fbp69Oihpk2bKjo6WiNHjtTp06e9PY1Gr65/n//RhAkT5HA4tHDhwjqu2scYWG/AgAGmZ8+e5vPPPzfbt283Xbp0MRkZGTWOmTBhgomJiTE5OTlmz549pm/fvqZfv37V9h08eLC5//77jSTz3XffeWEGvsEb67x06VLz1FNPmb/85S/m66+/Nv/5n/9pQkNDzRtvvOHt6TQaq1evNkFBQSY7O9scOHDAPP744yYiIsLk5+dX2/+zzz4z/v7+5pVXXjEHDx40s2bNMoGBgWbfvn3uPv/+7/9uwsPDzfr1682XX35pHnroIdOpUydz5cqV+ppWo1PX61xYWGhSU1PNmjVrzOHDh01ubq5JTEw0vXr1qs9pNTre+H2u9P7775uePXua6Oho89prr3l5Jo0b4cZyBw8eNJLM7t273W1/+tOfjMPhMKdOnap2TGFhoQkMDDRr1651tx06dMhIMrm5uR5933rrLdO/f3+Tk5NzW4cbb6/zP3ryySfNvffeW3fFN3KJiYlm4sSJ7p/Ly8tNdHS0ycrKqrb/o48+agYOHOjRlpSUZJ544gljjDEVFRXG6XSa+fPnu/cXFhaa4OBg895773lhBr6hrte5Ort27TKSzPHjx+umaB/krXX+5ptvTLt27cz+/ftNhw4dbvtww8dSlsvNzVVERIR69+7tbktNTZWfn5927txZ7Zi8vDyVlZUpNTXV3RYfH6/Y2Fjl5ua62w4ePKh58+ZpxYoVNX6B2e3Am+v8Q0VFRYqMjKy74hux0tJS5eXleayRn5+fUlNTr7tGubm5Hv0lKS0tzd3/6NGjcrlcHn3Cw8OVlJRU47rbzBvrXJ2ioiI5HA5FRETUSd2+xlvrXFFRoREjRujZZ59Vt27dvFO8j7m935FuAy6XS23atPFoCwgIUGRkpFwu13XHBAUFVfkPUFRUlHtMSUmJMjIyNH/+fMXGxnqldl/irXX+oR07dmjNmjUaP358ndTd2J07d07l5eWKioryaK9pjVwuV439K/+8kde0nTfW+YeuXr2q6dOnKyMj47b9AkhvrfNvf/tbBQQE6Kmnnqr7on0U4cZHzZgxQw6Ho8bt8OHDXjv+zJkz1bVrVw0fPtxrx2gMGnqd/9H+/fs1ePBgzZ07Vz/72c/q5ZhAXSgrK9Ojjz4qY4zefvvthi7HKnl5eVq0aJGWL18uh8PR0OU0GgENXQBuztSpUzV69Oga+9xxxx1yOp0qKCjwaL927ZrOnz8vp9NZ7Tin06nS0lIVFhZ6nFXIz893j9m6dav27dundevWSfr+7hNJatWqlZ5//nm98MILNzmzxqWh17nSwYMHlZKSovHjx2vWrFk3NRdf1KpVK/n7+1e5U6+6NarkdDpr7F/5Z35+vtq2bevRJyEhoQ6r9x3eWOdKlcHm+PHj2rp162171kbyzjpv375dBQUFHmfQy8vLNXXqVC1cuFDHjh2r20n4ioa+6AfeVXmh6549e9xtf/7zn2t1oeu6devcbYcPH/a40PVvf/ub2bdvn3vLzs42ksyOHTuue9W/zby1zsYYs3//ftOmTRvz7LPPem8CjVhiYqKZNGmS++fy8nLTrl27Gi/AfPDBBz3akpOTq1xQvGDBAvf+oqIiLiiu43U2xpjS0lIzZMgQ061bN1NQUOCdwn1MXa/zuXPnPP5bvG/fPhMdHW2mT59uDh8+7L2JNHKEm9vAgAEDzD//8z+bnTt3mk8//dTExcV53KL8zTffmDvvvNPs3LnT3TZhwgQTGxtrtm7davbs2WOSk5NNcnLydY/x8ccf39Z3SxnjnXXet2+fad26tRk+fLg5c+aMe7ud3ihWr15tgoODzfLly83BgwfN+PHjTUREhHG5XMYYY0aMGGFmzJjh7v/ZZ5+ZgIAAs2DBAnPo0CEzd+7cam8Fj4iIMP/93/9t/vd//9cMHjyYW8HreJ1LS0vNQw89ZNq3b2/++te/evz+lpSUNMgcGwNv/D7/EHdLEW5uC99++63JyMgwzZo1M2FhYWbMmDHmwoUL7v1Hjx41kszHH3/sbrty5Yp58sknTYsWLUyTJk3Mww8/bM6cOXPdYxBuvLPOc+fONZKqbB06dKjHmTW8N954w8TGxpqgoCCTmJhoPv/8c/e+/v37m1GjRnn0/8Mf/mD+6Z/+yQQFBZlu3bqZjRs3euyvqKgws2fPNlFRUSY4ONikpKSYI0eO1MdUGrW6XOfK3/fqtn/8d+B2VNe/zz9EuDHGYcz/v1gCAADAAtwtBQAArEK4AQAAViHcAAAAqxBuAACAVQg3AADAKoQbAABgFcINAACwCuEGAABYhXADAACsQrgBAABWIdwAAACrEG4AAIBV/h/KrmaEItaIiQAAAABJRU5ErkJggg==", + "text/plain": [ + "<Figure size 640x480 with 1 Axes>" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "metrics = pd.read_csv(\"logs/MEGNet_training/version_0/metrics.csv\")\n", + "metrics[\"train_MAE\"].dropna().plot()\n", + "metrics[\"val_MAE\"].dropna().plot()\n", + "\n", + "_ = plt.legend()\n", + "#plt.savefig(\"loss.jpg\")" + ] + }, + { + "cell_type": "code", + "execution_count": 44, + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "<div>\n", + "<style scoped>\n", + " .dataframe tbody tr th:only-of-type {\n", + " vertical-align: middle;\n", + " }\n", + "\n", + " .dataframe tbody tr th {\n", + " vertical-align: top;\n", + " }\n", + "\n", + " .dataframe thead th {\n", + " text-align: right;\n", + " }\n", + "</style>\n", + "<table border=\"1\" class=\"dataframe\">\n", + " <thead>\n", + " <tr style=\"text-align: right;\">\n", + " <th></th>\n", + " <th>epoch</th>\n", + " <th>step</th>\n", + " <th>train_MAE</th>\n", + " <th>train_RMSE</th>\n", + " <th>train_Total_Loss</th>\n", + " <th>val_MAE</th>\n", + " <th>val_RMSE</th>\n", + " <th>val_Total_Loss</th>\n", + " </tr>\n", + " </thead>\n", + " <tbody>\n", + " <tr>\n", + " <th>0</th>\n", + " <td>0</td>\n", + " <td>297</td>\n", + " <td>NaN</td>\n", + " <td>NaN</td>\n", + " <td>NaN</td>\n", + " <td>NaN</td>\n", + " <td>NaN</td>\n", + " <td>NaN</td>\n", + " </tr>\n", + " <tr>\n", + " <th>1</th>\n", + " <td>0</td>\n", + " <td>297</td>\n", + " <td>NaN</td>\n", + " <td>NaN</td>\n", + " <td>NaN</td>\n", + " <td>NaN</td>\n", + " <td>NaN</td>\n", + " <td>NaN</td>\n", + " </tr>\n", + " <tr>\n", + " <th>2</th>\n", + " <td>1</td>\n", + " <td>595</td>\n", + " <td>NaN</td>\n", + " <td>NaN</td>\n", + " <td>NaN</td>\n", + " <td>NaN</td>\n", + " <td>NaN</td>\n", + " <td>NaN</td>\n", + " </tr>\n", + " <tr>\n", + " <th>3</th>\n", + " <td>1</td>\n", + " <td>595</td>\n", + " <td>NaN</td>\n", + " <td>NaN</td>\n", + " <td>NaN</td>\n", + " <td>NaN</td>\n", + " <td>NaN</td>\n", + " <td>NaN</td>\n", + " </tr>\n", + " <tr>\n", + " <th>4</th>\n", + " <td>2</td>\n", + " <td>893</td>\n", + " <td>NaN</td>\n", + " <td>NaN</td>\n", + " <td>NaN</td>\n", + " <td>NaN</td>\n", + " <td>NaN</td>\n", + " <td>NaN</td>\n", + " </tr>\n", + " <tr>\n", + " <th>5</th>\n", + " <td>2</td>\n", + " <td>893</td>\n", + " <td>NaN</td>\n", + " <td>NaN</td>\n", + " <td>NaN</td>\n", + " <td>NaN</td>\n", + " <td>NaN</td>\n", + " <td>NaN</td>\n", + " </tr>\n", + " <tr>\n", + " <th>6</th>\n", + " <td>3</td>\n", + " <td>1191</td>\n", + " <td>NaN</td>\n", + " <td>NaN</td>\n", + " <td>NaN</td>\n", + " <td>NaN</td>\n", + " <td>NaN</td>\n", + " <td>NaN</td>\n", + " </tr>\n", + " <tr>\n", + " <th>7</th>\n", + " <td>3</td>\n", + " <td>1191</td>\n", + " <td>NaN</td>\n", + " <td>NaN</td>\n", + " <td>NaN</td>\n", + " <td>NaN</td>\n", + " <td>NaN</td>\n", + " <td>NaN</td>\n", + " </tr>\n", + " <tr>\n", + " <th>8</th>\n", + " <td>4</td>\n", + " <td>1489</td>\n", + " <td>NaN</td>\n", + " <td>NaN</td>\n", + " <td>NaN</td>\n", + " <td>NaN</td>\n", + " <td>NaN</td>\n", + " <td>NaN</td>\n", + " </tr>\n", + " <tr>\n", + " <th>9</th>\n", + " <td>4</td>\n", + " <td>1489</td>\n", + " <td>NaN</td>\n", + " <td>NaN</td>\n", + " <td>NaN</td>\n", + " <td>NaN</td>\n", + " <td>NaN</td>\n", + " <td>NaN</td>\n", + " </tr>\n", + " </tbody>\n", + "</table>\n", + "</div>" + ], + "text/plain": [ + " epoch step train_MAE train_RMSE train_Total_Loss val_MAE val_RMSE \\\n", + "0 0 297 NaN NaN NaN NaN NaN \n", + "1 0 297 NaN NaN NaN NaN NaN \n", + "2 1 595 NaN NaN NaN NaN NaN \n", + "3 1 595 NaN NaN NaN NaN NaN \n", + "4 2 893 NaN NaN NaN NaN NaN \n", + "5 2 893 NaN NaN NaN NaN NaN \n", + "6 3 1191 NaN NaN NaN NaN NaN \n", + "7 3 1191 NaN NaN NaN NaN NaN \n", + "8 4 1489 NaN NaN NaN NaN NaN \n", + "9 4 1489 NaN NaN NaN NaN NaN \n", + "\n", + " val_Total_Loss \n", + "0 NaN \n", + "1 NaN \n", + "2 NaN \n", + "3 NaN \n", + "4 NaN \n", + "5 NaN \n", + "6 NaN \n", + "7 NaN \n", + "8 NaN \n", + "9 NaN " + ] + }, + "execution_count": 44, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "metrics" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "i=0\n", + "prediction=np.zeros(len(test_data))\n", + "for i in range(len(structures_test)):\n", + " prediction[i]=model.predict_structure(structures_test[i])" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "molcal", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.13" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} -- GitLab