# link.py
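"""Train and evaluate a LinkPred transformer for relation link prediction.

Trains on the 'train' split (validating on 'dev') unless a saved checkpoint
is supplied via args.ckpt, in which case training is skipped; either way,
the resulting model is evaluated on the 'test' split.
"""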
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import ModelCheckpoint
from torch.utils.data import DataLoader
# `default_collator` was used below but never imported; it is assumed here to
# be transformers' default_data_collator, which stacks dict-style features
# into batched tensors.
from transformers import default_data_collator as default_collator

from dataLoader import LinkPredDataset
from model.Transformer import LinkPred
from myparser import pargs
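# CLI arguments consumed below (parsed by myparser.pargs; descriptions
# inferred from usage in this file):
#   args.save        output directory for Lightning logs and checkpoints
#   args.epochs      maximum number of training epochs
#   args.ckpt        optional checkpoint path; when set, training is skipped
#   args.bsz         batch size for all DataLoaders
#   args.lr          learning rate passed to LinkPred
#   args.base_model  name of the pretrained backbone passed to LinkPred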

# Relation label inventory (index -> name); 'None' marks the no-relation class.
ix2rel = ['HinderedBy', 'isBefore', 'isAfter', 'Causes', 'None']

# Inverse mapping: relation name -> label index.
rel2ix = {rel: ix for ix, rel in enumerate(ix2rel)}


if __name__ == "__main__":
    args = pargs()

    # Keep the checkpoint with the lowest validation loss.
    checkpoint_callback = ModelCheckpoint(monitor='val_loss', mode='min')
    # gpus=-1 trains on all available GPUs (pytorch_lightning < 2.0 API).
    trainer = Trainer(gpus=-1, default_root_dir=args.save,
                      callbacks=[checkpoint_callback], max_epochs=args.epochs)

    if not args.ckpt:
        # No checkpoint supplied: train from scratch on the train/dev splits.
        train_ds = LinkPredDataset('train')
        val_ds = LinkPredDataset('dev')
        train_dl = DataLoader(train_ds, batch_size=args.bsz, shuffle=True, collate_fn=default_collator)
        val_dl = DataLoader(val_ds, batch_size=args.bsz, shuffle=False, collate_fn=default_collator)

        model = LinkPred(lr=args.lr, model=args.base_model, n_labels=len(ix2rel))
        trainer.fit(model, train_dl, val_dl)

    else:
        # A checkpoint was supplied: load it and skip training.
        model = LinkPred.load_from_checkpoint(args.ckpt, lr=args.lr, model=args.base_model, n_labels=len(ix2rel))

    # Evaluate the trained or loaded model on the held-out test split.
    test_ds = LinkPredDataset('test')
    test_dl = DataLoader(test_ds, batch_size=args.bsz, shuffle=False, collate_fn=default_collator)
    trainer.test(model, test_dl)