From 735f49670fd6f2505b3b16de795c48950d621c34 Mon Sep 17 00:00:00 2001 From: sgrover <shivani.grover@ed.ac.uk> Date: Mon, 4 Nov 2024 16:26:23 +0000 Subject: [PATCH] Upload Workshop notebook --- Workshop3/molcal_workshop3.ipynb | 982 +++++++++++++++++++++++++++++++ 1 file changed, 982 insertions(+) create mode 100644 Workshop3/molcal_workshop3.ipynb diff --git a/Workshop3/molcal_workshop3.ipynb b/Workshop3/molcal_workshop3.ipynb new file mode 100644 index 0000000..4228bc3 --- /dev/null +++ b/Workshop3/molcal_workshop3.ipynb @@ -0,0 +1,982 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 14, + "metadata": {}, + "outputs": [], + "source": [ + "#!pip install jarvis-tools" + ] + }, + { + "cell_type": "code", + "execution_count": 15, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Obtaining 3D dataset 76k ...\n", + "Reference:https://www.nature.com/articles/s41524-020-00440-1\n", + "Other versions:https://doi.org/10.6084/m9.figshare.6815699\n", + "Loading the zipfile...\n", + "Loading completed.\n" + ] + } + ], + "source": [ + "# !pip install jarvis-tools, and restart the notebook\n", + "from jarvis.db.figshare import data\n", + "\n", + "dft_3d = data('dft_3d')" + ] + }, + { + "cell_type": "code", + "execution_count": 16, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "dict_keys(['jid', 'spg_number', 'spg_symbol', 'formula', 'formation_energy_peratom', 'func', 'optb88vdw_bandgap', 'atoms', 'slme', 'magmom_oszicar', 'spillage', 'elastic_tensor', 'effective_masses_300K', 'kpoint_length_unit', 'maxdiff_mesh', 'maxdiff_bz', 'encut', 'optb88vdw_total_energy', 'epsx', 'epsy', 'epsz', 'mepsx', 'mepsy', 'mepsz', 'modes', 'magmom_outcar', 'max_efg', 'avg_elec_mass', 'avg_hole_mass', 'icsd', 'dfpt_piezo_max_eij', 'dfpt_piezo_max_dij', 'dfpt_piezo_max_dielectric', 'dfpt_piezo_max_dielectric_electronic', 'dfpt_piezo_max_dielectric_ionic', 'max_ir_mode', 'min_ir_mode', 'n-Seebeck', 'p-Seebeck', 
'n-powerfact', 'p-powerfact', 'ncond', 'pcond', 'nkappa', 'pkappa', 'ehull', 'Tc_supercon', 'dimensionality', 'efg', 'xml_data_link', 'typ', 'exfoliation_energy', 'spg', 'crys', 'density', 'poisson', 'raw_files', 'nat', 'bulk_modulus_kv', 'shear_modulus_gv', 'mbj_bandgap', 'hse_gap', 'reference', 'search'])" + ] + }, + "execution_count": 16, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "dft_3d[0].keys()" + ] + }, + { + "cell_type": "code", + "execution_count": 17, + "metadata": {}, + "outputs": [], + "source": [ + "## Let's make a dataframe from this:\n", + "## !pip install pandas ## if it's not installed\n", + "import pandas as pd\n", + "import numpy as np " + ] + }, + { + "cell_type": "code", + "execution_count": 18, + "metadata": {}, + "outputs": [], + "source": [ + "df=pd.DataFrame(dft_3d)" + ] + }, + { + "cell_type": "code", + "execution_count": 19, + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "<div>\n", + "<style scoped>\n", + " .dataframe tbody tr th:only-of-type {\n", + " vertical-align: middle;\n", + " }\n", + "\n", + " .dataframe tbody tr th {\n", + " vertical-align: top;\n", + " }\n", + "\n", + " .dataframe thead th {\n", + " text-align: right;\n", + " }\n", + "</style>\n", + "<table border=\"1\" class=\"dataframe\">\n", + " <thead>\n", + " <tr style=\"text-align: right;\">\n", + " <th></th>\n", + " <th>jid</th>\n", + " <th>spg_number</th>\n", + " <th>spg_symbol</th>\n", + " <th>formula</th>\n", + " <th>formation_energy_peratom</th>\n", + " <th>func</th>\n", + " <th>optb88vdw_bandgap</th>\n", + " <th>atoms</th>\n", + " <th>slme</th>\n", + " <th>magmom_oszicar</th>\n", + " <th>...</th>\n", + " <th>density</th>\n", + " <th>poisson</th>\n", + " <th>raw_files</th>\n", + " <th>nat</th>\n", + " <th>bulk_modulus_kv</th>\n", + " <th>shear_modulus_gv</th>\n", + " <th>mbj_bandgap</th>\n", + " <th>hse_gap</th>\n", + " <th>reference</th>\n", + " <th>search</th>\n", + " </tr>\n", + " </thead>\n", + " <tbody>\n", 
+ " <tr>\n", + " <th>0</th>\n", + " <td>JVASP-90856</td>\n", + " <td>129</td>\n", + " <td>P4/nmm</td>\n", + " <td>TiCuSiAs</td>\n", + " <td>-0.42762</td>\n", + " <td>OptB88vdW</td>\n", + " <td>0.000</td>\n", + " <td>{'lattice_mat': [[3.566933224304235, 0.0, -0.0...</td>\n", + " <td>na</td>\n", + " <td>0.0</td>\n", + " <td>...</td>\n", + " <td>5.956</td>\n", + " <td>na</td>\n", + " <td>[]</td>\n", + " <td>8</td>\n", + " <td>na</td>\n", + " <td>na</td>\n", + " <td>na</td>\n", + " <td>na</td>\n", + " <td>mp-1080455</td>\n", + " <td>-As-Cu-Si-Ti</td>\n", + " </tr>\n", + " <tr>\n", + " <th>1</th>\n", + " <td>JVASP-86097</td>\n", + " <td>221</td>\n", + " <td>Pm-3m</td>\n", + " <td>DyB6</td>\n", + " <td>-0.41596</td>\n", + " <td>OptB88vdW</td>\n", + " <td>0.000</td>\n", + " <td>{'lattice_mat': [[4.089078911208881, 0.0, 0.0]...</td>\n", + " <td>na</td>\n", + " <td>0.0</td>\n", + " <td>...</td>\n", + " <td>5.522</td>\n", + " <td>na</td>\n", + " <td>[OPT-LOPTICS,JVASP-86097.zip,https://ndownload...</td>\n", + " <td>7</td>\n", + " <td>na</td>\n", + " <td>na</td>\n", + " <td>na</td>\n", + " <td>na</td>\n", + " <td>mp-568319</td>\n", + " <td>-B-Dy</td>\n", + " </tr>\n", + " <tr>\n", + " <th>2</th>\n", + " <td>JVASP-64906</td>\n", + " <td>119</td>\n", + " <td>I-4m2</td>\n", + " <td>Be2OsRu</td>\n", + " <td>0.04847</td>\n", + " <td>OptB88vdW</td>\n", + " <td>0.000</td>\n", + " <td>{'lattice_mat': [[-1.833590720595598, 1.833590...</td>\n", + " <td>na</td>\n", + " <td>0.0</td>\n", + " <td>...</td>\n", + " <td>10.960</td>\n", + " <td>na</td>\n", + " <td>[OPT-LOPTICS,JVASP-64906.zip,https://ndownload...</td>\n", + " <td>4</td>\n", + " <td>na</td>\n", + " <td>na</td>\n", + " <td>na</td>\n", + " <td>na</td>\n", + " <td>auid-3eaf68dd483bf4f4</td>\n", + " <td>-Be-Os-Ru</td>\n", + " </tr>\n", + " <tr>\n", + " <th>3</th>\n", + " <td>JVASP-98225</td>\n", + " <td>14</td>\n", + " <td>P2_1/c</td>\n", + " <td>KBi</td>\n", + " <td>-0.44140</td>\n", + " <td>OptB88vdW</td>\n", + " 
<td>0.472</td>\n", + " <td>{'lattice_mat': [[7.2963518353359165, 0.0, 0.0...</td>\n", + " <td>na</td>\n", + " <td>0.0</td>\n", + " <td>...</td>\n", + " <td>5.145</td>\n", + " <td>na</td>\n", + " <td>[]</td>\n", + " <td>32</td>\n", + " <td>na</td>\n", + " <td>na</td>\n", + " <td>na</td>\n", + " <td>na</td>\n", + " <td>mp-31104</td>\n", + " <td>-Bi-K</td>\n", + " </tr>\n", + " <tr>\n", + " <th>4</th>\n", + " <td>JVASP-10</td>\n", + " <td>164</td>\n", + " <td>P-3m1</td>\n", + " <td>VSe2</td>\n", + " <td>-0.71026</td>\n", + " <td>OptB88vdW</td>\n", + " <td>0.000</td>\n", + " <td>{'lattice_mat': [[1.6777483798834445, -2.90594...</td>\n", + " <td>na</td>\n", + " <td>0.0</td>\n", + " <td>...</td>\n", + " <td>5.718</td>\n", + " <td>0.23</td>\n", + " <td>[FD-ELAST,JVASP-10.zip,https://ndownloader.fig...</td>\n", + " <td>3</td>\n", + " <td>48.79</td>\n", + " <td>33.05</td>\n", + " <td>0.0</td>\n", + " <td>na</td>\n", + " <td>mp-694</td>\n", + " <td>-Se-V</td>\n", + " </tr>\n", + " </tbody>\n", + "</table>\n", + "<p>5 rows × 64 columns</p>\n", + "</div>" + ], + "text/plain": [ + " jid spg_number spg_symbol formula formation_energy_peratom \\\n", + "0 JVASP-90856 129 P4/nmm TiCuSiAs -0.42762 \n", + "1 JVASP-86097 221 Pm-3m DyB6 -0.41596 \n", + "2 JVASP-64906 119 I-4m2 Be2OsRu 0.04847 \n", + "3 JVASP-98225 14 P2_1/c KBi -0.44140 \n", + "4 JVASP-10 164 P-3m1 VSe2 -0.71026 \n", + "\n", + " func optb88vdw_bandgap \\\n", + "0 OptB88vdW 0.000 \n", + "1 OptB88vdW 0.000 \n", + "2 OptB88vdW 0.000 \n", + "3 OptB88vdW 0.472 \n", + "4 OptB88vdW 0.000 \n", + "\n", + " atoms slme magmom_oszicar ... \\\n", + "0 {'lattice_mat': [[3.566933224304235, 0.0, -0.0... na 0.0 ... \n", + "1 {'lattice_mat': [[4.089078911208881, 0.0, 0.0]... na 0.0 ... \n", + "2 {'lattice_mat': [[-1.833590720595598, 1.833590... na 0.0 ... \n", + "3 {'lattice_mat': [[7.2963518353359165, 0.0, 0.0... na 0.0 ... \n", + "4 {'lattice_mat': [[1.6777483798834445, -2.90594... na 0.0 ... 
\n", + "\n", + " density poisson raw_files nat \\\n", + "0 5.956 na [] 8 \n", + "1 5.522 na [OPT-LOPTICS,JVASP-86097.zip,https://ndownload... 7 \n", + "2 10.960 na [OPT-LOPTICS,JVASP-64906.zip,https://ndownload... 4 \n", + "3 5.145 na [] 32 \n", + "4 5.718 0.23 [FD-ELAST,JVASP-10.zip,https://ndownloader.fig... 3 \n", + "\n", + " bulk_modulus_kv shear_modulus_gv mbj_bandgap hse_gap \\\n", + "0 na na na na \n", + "1 na na na na \n", + "2 na na na na \n", + "3 na na na na \n", + "4 48.79 33.05 0.0 na \n", + "\n", + " reference search \n", + "0 mp-1080455 -As-Cu-Si-Ti \n", + "1 mp-568319 -B-Dy \n", + "2 auid-3eaf68dd483bf4f4 -Be-Os-Ru \n", + "3 mp-31104 -Bi-K \n", + "4 mp-694 -Se-V \n", + "\n", + "[5 rows x 64 columns]" + ] + }, + "execution_count": 19, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "df.head()" + ] + }, + { + "cell_type": "code", + "execution_count": 20, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "jid 75993\n", + "spg_number 75993\n", + "spg_symbol 75993\n", + "formula 75993\n", + "formation_energy_peratom 75993\n", + "func 75993\n", + "optb88vdw_bandgap 75993\n", + "atoms 75993\n", + "slme 9770\n", + "magmom_oszicar 71320\n", + "spillage 11377\n", + "elastic_tensor 25513\n", + "effective_masses_300K 75993\n", + "kpoint_length_unit 75671\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/var/folders/p4/gl5hwwk13vjb1pncdz82sq4h0000gq/T/ipykernel_73354/3773379141.py:3: FutureWarning: Downcasting behavior in `replace` is deprecated and will be removed in a future version. To retain the old behavior, explicitly call `result.infer_objects(copy=False)`. 
To opt-in to the future behavior, set `pd.set_option('future.no_silent_downcasting', True)`\n", + " val=df[i].replace('na',np.nan).dropna().values\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "maxdiff_mesh 5861\n", + "maxdiff_bz 5861\n", + "encut 75670\n", + "optb88vdw_total_energy 75993\n", + "epsx 52168\n", + "epsy 52168\n", + "epsz 52168\n", + "mepsx 18293\n", + "mepsy 18293\n", + "mepsz 18293\n", + "modes 13910\n", + "magmom_outcar 74261\n", + "max_efg 11871\n", + "avg_elec_mass 17645\n", + "avg_hole_mass 17645\n", + "icsd 75993\n", + "dfpt_piezo_max_eij 4799\n", + "dfpt_piezo_max_dij 3347\n", + "dfpt_piezo_max_dielectric 4706\n", + "dfpt_piezo_max_dielectric_electronic 4809\n", + "dfpt_piezo_max_dielectric_ionic 4809\n", + "max_ir_mode 4805\n", + "min_ir_mode 4809\n", + "n-Seebeck 23218\n", + "p-Seebeck 23218\n", + "n-powerfact 23218\n", + "p-powerfact 23218\n", + "ncond 23218\n", + "pcond 23218\n", + "nkappa 23218\n", + "pkappa 23218\n", + "ehull 75993\n", + "Tc_supercon 1058\n", + "dimensionality 75560\n", + "efg 75993\n", + "xml_data_link 75993\n", + "typ 75993\n", + "exfoliation_energy 813\n", + "spg 75993\n", + "crys 75993\n", + "density 75993\n", + "poisson 23597\n", + "raw_files 75993\n", + "nat 75993\n", + "bulk_modulus_kv 23824\n", + "shear_modulus_gv 23824\n", + "mbj_bandgap 19805\n", + "hse_gap 56\n", + "reference 75993\n", + "search 75993\n" + ] + } + ], + "source": [ + "## Count number of entries for each property\n", + "for i in df.columns.values:\n", + " val=df[i].replace('na',np.nan).dropna().values\n", + " print(i,len(val))" + ] + }, + { + "cell_type": "code", + "execution_count": 21, + "metadata": {}, + "outputs": [], + "source": [ + "## Filter dataset based on desired property \n", + "## We will focus on elastic properties for today, i.e. 
Bulk modulus" + ] + }, + { + "cell_type": "code", + "execution_count": 22, + "metadata": {}, + "outputs": [], + "source": [ + "#!pip install pymatgen" + ] + }, + { + "cell_type": "code", + "execution_count": 23, + "metadata": {}, + "outputs": [], + "source": [ + "from jarvis.core.atoms import Atoms\n", + "bm=df[df.bulk_modulus_kv != 'na']\n", + "data = [(Atoms.from_dict(bm.iloc[i]['atoms']).pymatgen_converter(), bm.iloc[i].bulk_modulus_kv) for i in range(len(bm))]" + ] + }, + { + "cell_type": "code", + "execution_count": 24, + "metadata": {}, + "outputs": [], + "source": [ + "import collections, itertools # Counter is needed below; single line keeps the cell length unchanged\n", + "def get_stoichiometry(elements):\n", + "    return list(collections.Counter(elements).items()) # (element, count) pairs in first-seen order; unlike groupby this also merges non-adjacent repeats" + ] + }, + { + "cell_type": "code", + "execution_count": 25, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "100%|██████████| 23824/23824 [00:02<00:00, 8443.33it/s]\n" + ] + } + ], + "source": [ + "## Use all the material dataset for training the bulk modulus\n", + "from tqdm import tqdm\n", + "\n", + "stoichs=[] #pymatgen structures (one per material)\n", + "bulk=[] #bulk modulus\n", + "for i in tqdm(range(len(bm))):\n", + " stoichs.append(Atoms.from_dict(bm.iloc[i]['atoms']).pymatgen_converter())\n", + " bulk.append(bm.iloc[i]['bulk_modulus_kv'])\n", + "data_ran=list(zip(stoichs,bulk))\n", + "\n", + "import pickle\n", + "with open('data_ran.pickle', 'wb') as f:\n", + " pickle.dump(data_ran, f)\n" + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [], + "source": [ + "### change environments from jarvis to matgl" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/Users/sgrover/anaconda3/envs/matgl-megnet/lib/python3.10/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. 
See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n", + " from .autonotebook import tqdm as notebook_tqdm\n" + ] + } + ], + "source": [ + "from __future__ import annotations\n", + "\n", + "import os\n", + "import shutil\n", + "import warnings\n", + "import zipfile\n", + "import matgl\n", + "\n", + "import matplotlib.pyplot as plt\n", + "import pandas as pd\n", + "import pytorch_lightning as pl\n", + "import torch\n", + "import pickle\n", + "import numpy as np\n", + "from dgl.data.utils import split_dataset\n", + "from pymatgen.core import Structure\n", + "from pytorch_lightning.loggers import CSVLogger\n", + "from tqdm import tqdm\n", + "\n", + "from matgl.ext.pymatgen import Structure2Graph, get_element_list\n", + "from matgl.graph.data import MGLDataset, MGLDataLoader, collate_fn\n", + "from matgl.layers import BondExpansion\n", + "from matgl.models import MEGNet\n", + "from matgl.utils.io import RemoteFile\n", + "from matgl.utils.training import ModelLightningModule\n", + "\n", + "# To suppress warnings for clearer output\n", + "warnings.simplefilter(\"ignore\")" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": {}, + "outputs": [], + "source": [ + "data_ran=pd.read_pickle('./data_ran.pickle')" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "list" + ] + }, + "execution_count": 4, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "type(data_ran)" + ] + }, + { + "cell_type": "code", + "execution_count": 46, + "metadata": {}, + "outputs": [], + "source": [ + "import random\n", + "random.shuffle(data_ran)" + ] + }, + { + "cell_type": "code", + "execution_count": 47, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Full Formula (Al2 Cr2 Ge2)\n", + "Reduced Formula: AlCrGe\n", + "abc : 5.953030 4.915863 4.763705\n", + "angles: 75.887124 53.211418 50.901458\n", + "pbc : 
True True True\n", + "Sites (6)\n", + " # SP a b c\n", + "--- ---- -------- -------- --------\n", + " 0 Al 0.07707 0.42293 0.07707\n", + " 1 Al 0.42293 0.07707 0.42293\n", + " 2 Cr 0.75 0.75 0.75\n", + " 3 Cr 0 0 0\n", + " 4 Ge 0.669987 0.330013 0.669987\n", + " 5 Ge 0.330013 0.669987 0.330013 2.1178676265660163\n" + ] + } + ], + "source": [ + "import numpy as np\n", + "\n", + "subset=[d for d in data_ran[:15000] if float(d[1])>0] # log10 needs positive moduli; non-positive values become NaN/-inf labels and poison the loss\n", + "structures=[d[0] for d in subset] # structures and targets come from the same subset so they stay aligned and equally long\n", + "targets=np.log10([float(d[1]) for d in subset])\n", + "print(structures[0],targets[0])" + ] + }, + { + "cell_type": "code", + "execution_count": 48, + "metadata": {}, + "outputs": [], + "source": [ + "# get element types in the dataset\n", + "elem_list = get_element_list(structures)\n", + "# setup a graph converter\n", + "converter = Structure2Graph(element_types=elem_list, cutoff=4.0)\n", + "# convert the raw dataset into MEGNetDataset\n", + "mp_dataset = MGLDataset(\n", + " structures=structures,\n", + " labels={\"bulk_modulus_kv\": targets},\n", + " converter=converter,\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": 49, + "metadata": {}, + "outputs": [], + "source": [ + "train_data, val_data, test_data = split_dataset(\n", + " mp_dataset,\n", + " frac_list=[0.8, 0.1, 0.1],\n", + " shuffle=True,\n", + " random_state=42,\n", + ")\n", + "train_loader, val_loader, test_loader = MGLDataLoader(\n", + " train_data=train_data,\n", + " val_data=val_data,\n", + " test_data=test_data,\n", + " collate_fn=collate_fn,\n", + " batch_size=64,\n", + " num_workers=0,\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": 50, + "metadata": {}, + "outputs": [], + "source": [ + "# setup the embedding layer for node attributes\n", + "node_embed = torch.nn.Embedding(len(elem_list), 16)\n", + "# define the bond expansion\n", + "bond_expansion = BondExpansion(rbf_type=\"Gaussian\", initial=0.0, final=5.0, num_centers=100, width=0.5)\n", + "\n", + "# setup the architecture of MEGNet model\n", + "model = MEGNet(\n", + " 
dim_node_embedding=16,\n", + " dim_edge_embedding=100,\n", + " dim_state_embedding=2,\n", + " nblocks=3,\n", + " hidden_layer_sizes_input=(64, 32),\n", + " hidden_layer_sizes_conv=(64, 64, 32),\n", + " nlayers_set2set=1,\n", + " niters_set2set=2,\n", + " hidden_layer_sizes_output=(32, 16),\n", + " is_classification=False,\n", + " activation_type=\"softplus2\",\n", + " bond_expansion=bond_expansion,\n", + " cutoff=4.0,\n", + " gauss_width=0.5,\n", + ")\n", + "\n", + "# setup the MEGNetTrainer\n", + "lit_module = ModelLightningModule(model=model)" + ] + }, + { + "cell_type": "code", + "execution_count": 51, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "GPU available: True (mps), used: False\n", + "TPU available: False, using: 0 TPU cores\n", + "IPU available: False, using: 0 IPUs\n", + "HPU available: False, using: 0 HPUs\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "\n", + " | Name | Type | Params\n", + "--------------------------------------------\n", + "0 | model | MEGNet | 189 K \n", + "1 | mae | MeanAbsoluteError | 0 \n", + "2 | rmse | MeanSquaredError | 0 \n", + "--------------------------------------------\n", + "189 K Trainable params\n", + "100 Non-trainable params\n", + "189 K Total params\n", + "0.758 Total estimated model params size (MB)\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 39: 100%|██████████| 298/298 [00:24<00:00, 11.99it/s, v_num=6, val_Total_Loss=nan.0, val_MAE=nan.0, val_RMSE=nan.0, train_Total_Loss=nan.0, train_MAE=nan.0, train_RMSE=nan.0]" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "`Trainer.fit` stopped: `max_epochs=40` reached.\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 39: 100%|██████████| 298/298 [00:24<00:00, 11.97it/s, v_num=6, val_Total_Loss=nan.0, val_MAE=nan.0, val_RMSE=nan.0, train_Total_Loss=nan.0, train_MAE=nan.0, train_RMSE=nan.0]\n" 
+ ] + } + ], + "source": [ + "logger = CSVLogger(\"logs\", name=\"MEGNet_training\")\n", + "trainer = pl.Trainer(max_epochs=40, accelerator=\"cpu\", logger=logger)\n", + "trainer.fit(model=lit_module, train_dataloaders=train_loader, val_dataloaders=val_loader)\n", + "\n", + "warnings.simplefilter(\"ignore\")" + ] + }, + { + "cell_type": "code", + "execution_count": 52, + "metadata": {}, + "outputs": [ + { + "data": { + "image/png": "", + "text/plain": [ + "<Figure size 640x480 with 1 Axes>" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "metrics = pd.read_csv(os.path.join(logger.log_dir, \"metrics.csv\")) # read the run that was just trained; a fixed version_0 path shows a stale earlier run\n", + "metrics[\"train_MAE\"].dropna().plot()\n", + "metrics[\"val_MAE\"].dropna().plot()\n", + "\n", + "_ = plt.legend()\n", + "#plt.savefig(\"loss.jpg\")" + ] + }, + { + "cell_type": "code", + "execution_count": 44, + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "<div>\n", + "<style scoped>\n", + " .dataframe tbody tr th:only-of-type {\n", + " vertical-align: middle;\n", + " }\n", + "\n", + " .dataframe tbody tr th {\n", + " vertical-align: top;\n", + " }\n", + "\n", + " .dataframe thead th {\n", + " text-align: right;\n", + " }\n", + "</style>\n", + "<table border=\"1\" class=\"dataframe\">\n", + " <thead>\n", + " <tr style=\"text-align: right;\">\n", + " <th></th>\n", + " <th>epoch</th>\n", + " <th>step</th>\n", + " <th>train_MAE</th>\n", + " <th>train_RMSE</th>\n", + " <th>train_Total_Loss</th>\n", + " <th>val_MAE</th>\n", + " <th>val_RMSE</th>\n", + " <th>val_Total_Loss</th>\n", + " </tr>\n", + " </thead>\n", + " <tbody>\n", + " <tr>\n", + " <th>0</th>\n", + " <td>0</td>\n", + " <td>297</td>\n", + " <td>NaN</td>\n", + " <td>NaN</td>\n", + " <td>NaN</td>\n", + " <td>NaN</td>\n", + " <td>NaN</td>\n", + " <td>NaN</td>\n", + " </tr>\n", + " <tr>\n", + " <th>1</th>\n", + " <td>0</td>\n", + " <td>297</td>\n", + " <td>NaN</td>\n", + " <td>NaN</td>\n", + " <td>NaN</td>\n", + " <td>NaN</td>\n", + " 
<td>NaN</td>\n", + " <td>NaN</td>\n", + " </tr>\n", + " <tr>\n", + " <th>2</th>\n", + " <td>1</td>\n", + " <td>595</td>\n", + " <td>NaN</td>\n", + " <td>NaN</td>\n", + " <td>NaN</td>\n", + " <td>NaN</td>\n", + " <td>NaN</td>\n", + " <td>NaN</td>\n", + " </tr>\n", + " <tr>\n", + " <th>3</th>\n", + " <td>1</td>\n", + " <td>595</td>\n", + " <td>NaN</td>\n", + " <td>NaN</td>\n", + " <td>NaN</td>\n", + " <td>NaN</td>\n", + " <td>NaN</td>\n", + " <td>NaN</td>\n", + " </tr>\n", + " <tr>\n", + " <th>4</th>\n", + " <td>2</td>\n", + " <td>893</td>\n", + " <td>NaN</td>\n", + " <td>NaN</td>\n", + " <td>NaN</td>\n", + " <td>NaN</td>\n", + " <td>NaN</td>\n", + " <td>NaN</td>\n", + " </tr>\n", + " <tr>\n", + " <th>5</th>\n", + " <td>2</td>\n", + " <td>893</td>\n", + " <td>NaN</td>\n", + " <td>NaN</td>\n", + " <td>NaN</td>\n", + " <td>NaN</td>\n", + " <td>NaN</td>\n", + " <td>NaN</td>\n", + " </tr>\n", + " <tr>\n", + " <th>6</th>\n", + " <td>3</td>\n", + " <td>1191</td>\n", + " <td>NaN</td>\n", + " <td>NaN</td>\n", + " <td>NaN</td>\n", + " <td>NaN</td>\n", + " <td>NaN</td>\n", + " <td>NaN</td>\n", + " </tr>\n", + " <tr>\n", + " <th>7</th>\n", + " <td>3</td>\n", + " <td>1191</td>\n", + " <td>NaN</td>\n", + " <td>NaN</td>\n", + " <td>NaN</td>\n", + " <td>NaN</td>\n", + " <td>NaN</td>\n", + " <td>NaN</td>\n", + " </tr>\n", + " <tr>\n", + " <th>8</th>\n", + " <td>4</td>\n", + " <td>1489</td>\n", + " <td>NaN</td>\n", + " <td>NaN</td>\n", + " <td>NaN</td>\n", + " <td>NaN</td>\n", + " <td>NaN</td>\n", + " <td>NaN</td>\n", + " </tr>\n", + " <tr>\n", + " <th>9</th>\n", + " <td>4</td>\n", + " <td>1489</td>\n", + " <td>NaN</td>\n", + " <td>NaN</td>\n", + " <td>NaN</td>\n", + " <td>NaN</td>\n", + " <td>NaN</td>\n", + " <td>NaN</td>\n", + " </tr>\n", + " </tbody>\n", + "</table>\n", + "</div>" + ], + "text/plain": [ + " epoch step train_MAE train_RMSE train_Total_Loss val_MAE val_RMSE \\\n", + "0 0 297 NaN NaN NaN NaN NaN \n", + "1 0 297 NaN NaN NaN NaN NaN \n", + "2 1 595 NaN NaN NaN NaN NaN 
\n", + "3 1 595 NaN NaN NaN NaN NaN \n", + "4 2 893 NaN NaN NaN NaN NaN \n", + "5 2 893 NaN NaN NaN NaN NaN \n", + "6 3 1191 NaN NaN NaN NaN NaN \n", + "7 3 1191 NaN NaN NaN NaN NaN \n", + "8 4 1489 NaN NaN NaN NaN NaN \n", + "9 4 1489 NaN NaN NaN NaN NaN \n", + "\n", + " val_Total_Loss \n", + "0 NaN \n", + "1 NaN \n", + "2 NaN \n", + "3 NaN \n", + "4 NaN \n", + "5 NaN \n", + "6 NaN \n", + "7 NaN \n", + "8 NaN \n", + "9 NaN " + ] + }, + "execution_count": 44, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "metrics" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "structures_test=[structures[j] for j in test_data.indices] # test_data is a Subset of mp_dataset, which was built from structures in the same order\n", + "prediction=np.zeros(len(structures_test)) # predictions are on the log10 scale used for the training targets\n", + "for i, s in enumerate(structures_test):\n", + "    prediction[i]=float(model.predict_structure(s))" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "molcal", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.13" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} -- GitLab