# link.py: train the LinkPred relation classifier (or load it from a checkpoint) and evaluate it on the test split.
from transformers import AutoModelForSequenceClassification, AutoTokenizer
import pandas as pd
from dataLoader import LinkPredDataset, default_collator
from models.transformers import LinkPred
from myparser import pargs
from pytorch_lightning import Trainer, seed_everything
from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader


# Relation label vocabulary. A relation's position in this list is its label
# index; the script sizes the classifier head with len(ix2rel).
ix2rel = [
    'HinderedBy',
    'isBefore',
    'isAfter',
    'Causes',
    'None',
]
# Inverse lookup: relation name -> label index.
rel2ix = dict(zip(ix2rel, range(len(ix2rel))))

def _make_loader(split, base_model, batch_size, shuffle):
    """Wrap one split of ``LinkPredDataset`` in a ``DataLoader``.

    Args:
        split: split name forwarded to ``LinkPredDataset`` ('train', 'dev' or
            'test' in this script).
        base_model: forwarded to ``LinkPredDataset`` (presumably selects the
            tokenizer/backbone -- confirm in dataLoader.py).
        batch_size: number of examples per batch.
        shuffle: whether to reshuffle every epoch (only wanted for training).

    Returns:
        A ``DataLoader`` over the split that batches with ``default_collator``.
    """
    dataset = LinkPredDataset(split, ix2rel, base_model)
    return DataLoader(dataset, batch_size=batch_size, shuffle=shuffle, collate_fn=default_collator)


def main():
    """Train LinkPred (or load it from ``args.ckpt``), then evaluate on 'test'.

    When ``args.ckpt`` is falsy, a fresh ``LinkPred`` is fitted on the 'train'
    split with validation on 'dev'. The test pass then uses the checkpoint with
    the lowest ``val_loss`` rather than whatever weights the final epoch left
    in memory. Otherwise the given checkpoint is loaded and tested directly.
    """
    args = pargs()
    # Keep the checkpoint with the lowest validation loss
    # (assumes LinkPred logs 'val_loss' -- confirm in models/transformers.py).
    checkpoint_callback = ModelCheckpoint(monitor='val_loss', mode='min')
    trainer = Trainer(
        gpus=1,
        default_root_dir=args.save,
        callbacks=[checkpoint_callback],
        max_epochs=args.epochs,
    )

    if args.ckpt:
        # No fit happens in this branch, so there is no "best" checkpoint to
        # fall back on: test the loaded weights as-is.
        model = LinkPred.load_from_checkpoint(
            args.ckpt, lr=args.lr, model=args.base_model, n_labels=len(ix2rel)
        )
        test_ckpt = None
    else:
        train_dl = _make_loader('train', args.base_model, args.bsz, shuffle=True)
        val_dl = _make_loader('dev', args.base_model, args.bsz, shuffle=False)

        model = LinkPred(lr=args.lr, model=args.base_model, n_labels=len(ix2rel))
        trainer.fit(model, train_dl, val_dl)
        # Without this, test() would score the last-epoch weights and ignore
        # the val_loss-selected checkpoint that ModelCheckpoint saved.
        test_ckpt = 'best'

    test_dl = _make_loader('test', args.base_model, args.bsz, shuffle=False)
    trainer.test(model, test_dl, ckpt_path=test_ckpt)


if __name__ == "__main__":
    main()