from model.transformers import *
from torch.utils.data import DataLoader
from pytorch_lightning import Trainer, seed_everything
from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping
from pytorch_lightning.loggers import CSVLogger, TensorBoardLogger
from myparser import pargs
from dataLoader import fullDataset, injection_collator, default_collator
from transformers import AutoTokenizer
from sklearn.metrics import f1_score, confusion_matrix
import torch  # explicit import: torch.cat is used below; `import torch.nn.functional as F` binds only F
import torch.nn.functional as F
import numpy as np
import os


def main(args):
    """Train, validate, and test a knowledge-augmented transformer classifier.

    Builds the tokenizer and one of several model variants selected by
    ``args.model`` ('crossencoder', 'hybrid', 'attn', or a cross-encoder
    fallback), fits it with PyTorch Lightning, evaluates on the test split,
    and saves the raw test predictions as a NumPy array under
    ``results/<args.save>/logs/test_results.npy``.

    Args:
        args: parsed CLI namespace from ``pargs()``. Fields read here:
            save, base_model, knowledge, model, n_classes, lr, lrdecay,
            datadir, bsz, epochs.

    Side effects:
        Creates ``results/<args.save>/logs``, writes Lightning logs,
        checkpoints, and the test-prediction ``.npy`` file.
    """
    # seed_everything(24032022)
    args.save = os.path.join('results/', args.save)
    # Ensure the output tree exists up front (loggers create it too, but
    # np.save at the end would fail if nothing had created it yet).
    os.makedirs(os.path.join(args.save, 'logs'), exist_ok=True)

    kg_token = None
    collate_fn = default_collator
    tokenizer = AutoTokenizer.from_pretrained(args.base_model)
    if args.knowledge == 'comet' or args.knowledge == 'comet_link':
        # COMET-style knowledge marks the two argument spans with special tokens.
        tokenizer.add_tokens(['[Arg1]', '[Arg2]'], special_tokens=True)

    if args.model == 'crossencoder':
        kg_token = "[KG]"
        tokenizer.add_tokens([kg_token], special_tokens=True)
        model = TransformerCrossEncoder(len(tokenizer), lr=args.lr, decay=args.lrdecay,
                                        model=args.base_model, n_labels=args.n_classes)
    elif args.model == 'hybrid':
        model = TransformerInjection(len(tokenizer), lr=args.lr, decay=args.lrdecay,
                                     model=args.base_model, n_labels=args.n_classes)
        collate_fn = injection_collator
    elif args.model == 'attn':
        model = TransformerKnowledgeAttention(len(tokenizer), lr=args.lr, decay=args.lrdecay,
                                              model=args.base_model, n_labels=args.n_classes)
        collate_fn = injection_collator
    else:
        # Unknown model name: fall back to the plain cross-encoder (no [KG] token).
        model = TransformerCrossEncoder(len(tokenizer), lr=args.lr, decay=args.lrdecay,
                                        model=args.base_model, n_labels=args.n_classes)

    train_dataset = fullDataset(args.datadir, tokenizer, model=args.model, kg=args.knowledge,
                                kg_token=kg_token, n_labels=args.n_classes)
    val_dataset = fullDataset(args.datadir, tokenizer, split='dev', model=args.model,
                              kg=args.knowledge, kg_token=kg_token, n_labels=args.n_classes)
    train_dl = DataLoader(train_dataset, batch_size=args.bsz, shuffle=True, collate_fn=collate_fn)
    val_dl = DataLoader(val_dataset, batch_size=args.bsz, shuffle=False, collate_fn=collate_fn)

    # Keep only the checkpoint with the best validation F1.
    checkpoint_callback = ModelCheckpoint(monitor='val_f1', mode='max')
    tb_logger = TensorBoardLogger(save_dir=os.path.join(args.save, 'logs/'))
    csv_logger = CSVLogger(save_dir=os.path.join(args.save, 'logs/'))
    trainer = Trainer(gpus=-1, default_root_dir=args.save, logger=[tb_logger, csv_logger],
                      callbacks=[checkpoint_callback], max_epochs=args.epochs)
    trainer.fit(model, train_dl, val_dl)

    test_dataset = fullDataset(args.datadir, tokenizer, split='test', model=args.model,
                               kg=args.knowledge, kg_token=kg_token, n_labels=args.n_classes)
    test_dl = DataLoader(test_dataset, batch_size=args.bsz, shuffle=False, collate_fn=collate_fn)
    trainer.test(dataloaders=test_dl)

    # Collect per-batch (predictions, labels) tuples from predict_step;
    # labels are discarded here since only raw predictions are persisted.
    results = trainer.predict(model, test_dl)
    preds, _ = tuple(zip(*results))
    preds = torch.cat(preds).cpu().detach().numpy()
    # BUG FIX: args.save already carries the 'results/' prefix (reassigned at
    # the top of main); the original prepended 'results/' a second time and
    # wrote to 'results/results/<save>/logs/test_results.npy'.
    np.save(os.path.join(args.save, 'logs', 'test_results.npy'), preds)


if __name__ == '__main__':
    args = pargs()
    main(args)