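# Usage sketch. The script name and flag spellings below are assumptions
# (the real argument names live in myparser.pargs); shown only to illustrate
# the two modes of the script:
#   python train_linkpred.py --base_model roberta-base --bsz 16 --lr 2e-5 \
#       --epochs 10 --save checkpoints/
#   python train_linkpred.py --base_model roberta-base --bsz 16 --lr 2e-5 \
#       --ckpt path/to/best.ckpt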
import torch
from dataLoader import LinkPredDataset, default_collator
from models.transformers import LinkPred
from myparser import pargs
from pytorch_lightning import Trainer, seed_everything
from pytorch_lightning.callbacks import ModelCheckpoint
from torch.utils.data import DataLoader
# Relation label set for link prediction; rel2ix inverts it for label lookups.
ix2rel = ['HinderedBy', 'isBefore', 'isAfter', 'Causes', 'None']
rel2ix = {rel: ix for ix, rel in enumerate(ix2rel)}
if __name__ == "__main__":
    args = pargs()
    seed_everything(42)  # assumed seed value, for reproducible runs
    checkpoint_callback = ModelCheckpoint(monitor='val_loss', mode='min')
    trainer = Trainer(gpus=1, default_root_dir=args.save,
                      callbacks=[checkpoint_callback], max_epochs=args.epochs)
    # Train from scratch unless a checkpoint path is supplied (assumed switch).
    if args.ckpt is None:
        train_ds = LinkPredDataset('train', ix2rel, args.base_model)
        val_ds = LinkPredDataset('dev', ix2rel, args.base_model)
        train_dl = DataLoader(train_ds, batch_size=args.bsz, shuffle=True,
                              collate_fn=default_collator)
        val_dl = DataLoader(val_ds, batch_size=args.bsz, shuffle=False,
                            collate_fn=default_collator)
        model = LinkPred(lr=args.lr, model=args.base_model, n_labels=len(ix2rel))
        trainer.fit(model, train_dl, val_dl)
    else:
        # Evaluate a saved checkpoint on the held-out split.
        model = LinkPred.load_from_checkpoint(args.ckpt, lr=args.lr,
                                              model=args.base_model,
                                              n_labels=len(ix2rel))
        test_ds = LinkPredDataset('test', ix2rel, args.base_model)  # assumed 'test' split name
        test_dl = DataLoader(test_ds, batch_size=args.bsz, shuffle=False,
                             collate_fn=default_collator)
        trainer.test(model, test_dl)
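
# For reference, a minimal sketch of what dataLoader.default_collator might
# look like. Everything here is an assumption: it supposes each dataset item
# is a dict with 'input_ids' / 'attention_mask' lists of token ids and an int
# 'label', and that 0 is the tokenizer's pad id. It is named example_collator
# so it does not shadow the real import above.
def example_collator(batch):
    # Pad every sequence in the batch to the length of the longest one.
    max_len = max(len(item['input_ids']) for item in batch)
    input_ids, attention_mask, labels = [], [], []
    for item in batch:
        pad = max_len - len(item['input_ids'])
        input_ids.append(item['input_ids'] + [0] * pad)
        attention_mask.append(item['attention_mask'] + [0] * pad)
        labels.append(item['label'])
    # Stack into tensors the model's forward pass can consume directly.
    return {
        'input_ids': torch.tensor(input_ids),
        'attention_mask': torch.tensor(attention_mask),
        'labels': torch.tensor(labels),
    }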