"""Train or evaluate a transformer-based link-prediction classifier.

Without ``--ckpt``: fit on the train/dev splits, then evaluate the best
(lowest ``val_loss``) checkpoint on the test split.
With ``--ckpt``: skip training and evaluate the given checkpoint on test.
"""
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
from pytorch_lightning import Trainer, seed_everything
from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint
from torch.utils.data import DataLoader
from transformers import AutoModelForSequenceClassification, AutoTokenizer

from dataLoader import LinkPredDataset, default_collator
from models.transformers import LinkPred
from myparser import pargs

# Relation label vocabulary; list positions are the class ids the model predicts.
ix2rel = ['HinderedBy', 'isBefore', 'isAfter', 'Causes', 'None']
rel2ix = {rel: ix for ix, rel in enumerate(ix2rel)}

if __name__ == "__main__":
    args = pargs()

    # Seed all RNGs for reproducibility. seed_everything was imported but
    # never called in the original — an apparent oversight.
    seed_everything(42)

    # Track and keep the checkpoint with the lowest validation loss.
    checkpoint_callback = ModelCheckpoint(monitor='val_loss', mode='min')
    trainer = Trainer(
        gpus=1,
        default_root_dir=args.save,
        callbacks=[checkpoint_callback],
        max_epochs=args.epochs,
    )

    if not args.ckpt:
        # Train from scratch on the train/dev splits.
        train_ds = LinkPredDataset('train', ix2rel, args.base_model)
        val_ds = LinkPredDataset('dev', ix2rel, args.base_model)
        train_dl = DataLoader(train_ds, batch_size=args.bsz, shuffle=True,
                              collate_fn=default_collator)
        val_dl = DataLoader(val_ds, batch_size=args.bsz, shuffle=False,
                            collate_fn=default_collator)
        model = LinkPred(lr=args.lr, model=args.base_model,
                         n_labels=len(ix2rel))
        trainer.fit(model, train_dl, val_dl)
        # Evaluate the best checkpoint selected by val_loss, not whatever
        # weights happen to be in memory after the final epoch.
        test_ckpt_path = 'best'
    else:
        # Evaluate a previously trained checkpoint; use its weights as-is.
        model = LinkPred.load_from_checkpoint(
            args.ckpt, lr=args.lr, model=args.base_model,
            n_labels=len(ix2rel))
        test_ckpt_path = None

    test_ds = LinkPredDataset('test', ix2rel, args.base_model)
    test_dl = DataLoader(test_ds, batch_size=args.bsz, shuffle=False,
                         collate_fn=default_collator)
    trainer.test(model, test_dl, ckpt_path=test_ckpt_path)