Newer
Older
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
import os

import numpy as np
import torch
import torch.nn.functional as F
from pytorch_lightning import Trainer, seed_everything
from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping
from pytorch_lightning.loggers import CSVLogger, TensorBoardLogger
from sklearn.metrics import f1_score, confusion_matrix
from torch.utils.data import DataLoader
from transformers import AutoTokenizer

from dataLoader import fullDataset, injection_collator, default_collator
from myparser import pargs
def main(args):
    """Fine-tune, evaluate, and dump test predictions for a knowledge-augmented classifier.

    Selects a model architecture from ``args.model`` ('crossencoder', 'hybrid',
    'attn', or fallback cross-encoder), trains it on the train split while
    checkpointing on best validation F1, runs test evaluation, and saves the
    raw test-set predictions to ``results/<args.save>/logs/test_results.npy``.

    Args:
        args: parsed CLI namespace (see ``myparser.pargs``); must provide
            ``save``, ``datadir``, ``base_model``, ``model``, ``knowledge``,
            ``n_classes``, ``lr``, ``lrdecay``, ``bsz``, ``epochs``.
    """
    # seed_everything(24032022)  # NOTE(review): seeding is disabled — confirm this is intentional
    args.save = os.path.join('results/', args.save)
    # if not os.path.exists(args.save):
    #     print('overwriting save directory\n')
    #     os.mkdir(args.save)

    kg_token = None
    collate_fn = default_collator
    tokenizer = AutoTokenizer.from_pretrained(args.base_model)

    # COMET-style knowledge marks the two arguments with special vocabulary tokens.
    if args.knowledge in ('comet', 'comet_link'):
        tokenizer.add_tokens(['[Arg1]', '[Arg2]'], special_tokens=True)

    # NOTE(review): the Transformer* model classes are not imported in the visible
    # portion of this file — presumably defined/imported above; verify.
    if args.model == 'crossencoder':
        # Cross-encoder separates knowledge text with a dedicated [KG] token.
        kg_token = "[KG]"
        tokenizer.add_tokens([kg_token], special_tokens=True)
        model = TransformerCrossEncoder(len(tokenizer), lr=args.lr, decay=args.lrdecay,
                                        model=args.base_model, n_labels=args.n_classes)
    elif args.model == 'hybrid':
        model = TransformerInjection(len(tokenizer), lr=args.lr, decay=args.lrdecay,
                                     model=args.base_model, n_labels=args.n_classes)
        collate_fn = injection_collator
    elif args.model == 'attn':
        model = TransformerKnowledgeAttention(len(tokenizer), lr=args.lr, decay=args.lrdecay,
                                              model=args.base_model, n_labels=args.n_classes)
        collate_fn = injection_collator
    else:
        # Default to the plain cross-encoder (without the [KG] token).
        model = TransformerCrossEncoder(len(tokenizer), lr=args.lr, decay=args.lrdecay,
                                        model=args.base_model, n_labels=args.n_classes)

    train_dataset = fullDataset(args.datadir, tokenizer, model=args.model,
                                kg=args.knowledge, kg_token=kg_token, n_labels=args.n_classes)
    val_dataset = fullDataset(args.datadir, tokenizer, split='dev', model=args.model,
                              kg=args.knowledge, kg_token=kg_token, n_labels=args.n_classes)
    train_dl = DataLoader(train_dataset, batch_size=args.bsz, shuffle=True, collate_fn=collate_fn)
    val_dl = DataLoader(val_dataset, batch_size=args.bsz, shuffle=False, collate_fn=collate_fn)

    # Keep the checkpoint with the best validation F1 (maximized).
    checkpoint_callback = ModelCheckpoint(monitor='val_f1', mode='max')
    tb_logger = TensorBoardLogger(save_dir=os.path.join(args.save, 'logs/'))
    csv_logger = CSVLogger(save_dir=os.path.join(args.save, 'logs/'))
    trainer = Trainer(gpus=-1, default_root_dir=args.save,
                      logger=[tb_logger, csv_logger],
                      callbacks=[checkpoint_callback],
                      max_epochs=args.epochs)
    trainer.fit(model, train_dl, val_dl)

    test_dataset = fullDataset(args.datadir, tokenizer, split='test', model=args.model,
                               kg=args.knowledge, kg_token=kg_token, n_labels=args.n_classes)
    # One DataLoader serves both test() and predict() — the original built it twice.
    test_dl = DataLoader(test_dataset, batch_size=args.bsz, shuffle=False, collate_fn=collate_fn)
    trainer.test(dataloaders=test_dl)

    results = trainer.predict(model, test_dl)
    # predict() returns per-batch (preds, _) pairs; concatenate predictions only.
    preds, _ = zip(*results)
    preds = torch.cat(preds).detach().cpu().numpy()
    np.save(os.path.join(args.save, 'logs/', 'test_results.npy'), preds)