# train.py — training / evaluation entry point for the transformer-based
# relation classifiers (crossencoder / hybrid / attn variants).
import os

import numpy as np
import torch
import torch.nn.functional as F
from pytorch_lightning import Trainer, seed_everything
from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping
from pytorch_lightning.loggers import CSVLogger, TensorBoardLogger
from sklearn.metrics import f1_score, confusion_matrix
from torch.utils.data import DataLoader
from transformers import AutoTokenizer

from dataLoader import fullDataset, injection_collator, default_collator
from models.transformers import *
from myparser import pargs





def main(args):
    """Train, validate and test a transformer model on the full dataset.

    Selects the model architecture and collator from ``args.model``, extends
    the tokenizer with any knowledge-specific special tokens, runs a
    PyTorch Lightning fit/test cycle, and saves the test predictions to
    ``results/<args.save>/logs/test_results.npy``.

    Args:
        args: parsed CLI namespace from ``myparser.pargs`` — expects at least
            ``save``, ``base_model``, ``knowledge``, ``model``, ``lr``,
            ``lrdecay``, ``n_classes``, ``datadir``, ``bsz``, ``epochs``.
    """
    # seed_everything(24032022)
    args.save = os.path.join('results/', args.save)
    # Make sure the output directory (and its logs/ subdir) exists before any
    # logger or np.save writes into it — previously this was commented out and
    # np.save could fail on a fresh run.
    os.makedirs(os.path.join(args.save, 'logs/'), exist_ok=True)

    kg_token = None
    collate_fn = default_collator
    tokenizer = AutoTokenizer.from_pretrained(args.base_model)

    # COMET-derived knowledge marks the two discourse arguments explicitly.
    if args.knowledge in ('comet', 'comet_link'):
        tokenizer.add_tokens(['[Arg1]', '[Arg2]'], special_tokens=True)

    if args.model == 'crossencoder':
        # Cross-encoder separates the knowledge text with a dedicated token.
        kg_token = "[KG]"
        tokenizer.add_tokens([kg_token], special_tokens=True)
        model = TransformerCrossEncoder(len(tokenizer), lr=args.lr, decay=args.lrdecay,
                                        model=args.base_model, n_labels=args.n_classes)
    elif args.model == 'hybrid':
        model = TransformerInjection(len(tokenizer), lr=args.lr, decay=args.lrdecay,
                                     model=args.base_model, n_labels=args.n_classes)
        collate_fn = injection_collator
    elif args.model == 'attn':
        model = TransformerKnowledgeAttention(len(tokenizer), lr=args.lr, decay=args.lrdecay,
                                              model=args.base_model, n_labels=args.n_classes)
        collate_fn = injection_collator
    else:
        # Fallback: plain cross-encoder, but without the [KG] special token.
        model = TransformerCrossEncoder(len(tokenizer), lr=args.lr, decay=args.lrdecay,
                                        model=args.base_model, n_labels=args.n_classes)

    train_dataset = fullDataset(args.datadir, tokenizer, model=args.model,
                                kg=args.knowledge, kg_token=kg_token, n_labels=args.n_classes)
    val_dataset = fullDataset(args.datadir, tokenizer, split='dev', model=args.model,
                              kg=args.knowledge, kg_token=kg_token, n_labels=args.n_classes)

    train_dl = DataLoader(train_dataset, batch_size=args.bsz, shuffle=True, collate_fn=collate_fn)
    val_dl = DataLoader(val_dataset, batch_size=args.bsz, shuffle=False, collate_fn=collate_fn)

    # Keep the checkpoint with the best validation F1 (logged by the model).
    checkpoint_callback = ModelCheckpoint(monitor='val_f1', mode='max')

    tb_logger = TensorBoardLogger(save_dir=os.path.join(args.save, 'logs/'))
    csv_logger = CSVLogger(save_dir=os.path.join(args.save, 'logs/'))

    trainer = Trainer(gpus=-1, default_root_dir=args.save,
                      logger=[tb_logger, csv_logger],
                      callbacks=[checkpoint_callback], max_epochs=args.epochs)
    trainer.fit(model, train_dl, val_dl)

    test_dataset = fullDataset(args.datadir, tokenizer, split='test', model=args.model,
                               kg=args.knowledge, kg_token=kg_token, n_labels=args.n_classes)
    test_dl = DataLoader(test_dataset, batch_size=args.bsz, shuffle=False, collate_fn=collate_fn)

    trainer.test(dataloaders=test_dl)

    # Reuse the same loader for prediction (it was rebuilt redundantly before).
    # Each predict step is assumed to return a (logits/preds, extra) pair —
    # TODO confirm against the model's predict_step.
    results = trainer.predict(model, test_dl)
    preds, _ = tuple(zip(*results))
    preds = torch.cat(preds).cpu().detach().numpy()

    np.save(os.path.join(args.save, 'logs/', 'test_results.npy'), preds)


if __name__ == '__main__':
    # Parse CLI arguments and launch the training run.
    main(pargs())