from transformers import AutoModelForSequenceClassification, AutoTokenizer
import pandas as pd
from dataLoader import LinkPredDataset, default_collator
from models.transformers import LinkPred
from myparser import pargs
from pytorch_lightning import Trainer, seed_everything
from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader



# Relation labels; a label's position in this list is its integer index.
# NOTE(review): presumably LinkPredDataset uses these indices as class ids
# (len(ix2rel) is passed as n_labels below) — confirm.
# 'None' is a literal string label (presumably "no relation"), not Python None.
ix2rel = ['HinderedBy', 'isBefore', 'isAfter', 'Causes', 'None']
# Inverse lookup: relation name -> its index in ix2rel.
rel2ix = dict(zip(ix2rel, range(len(ix2rel))))

def _make_loader(split, args, shuffle):
    """Build a DataLoader over one LinkPredDataset split using the CLI batch size."""
    dataset = LinkPredDataset(split, ix2rel, args.base_model)
    return DataLoader(dataset, batch_size=args.bsz, shuffle=shuffle,
                      collate_fn=default_collator)


def main():
    """Train (or restore) a LinkPred classifier and evaluate it on the test split.

    Options come from ``myparser.pargs``; the fields read here are ``save``,
    ``epochs``, ``ckpt``, ``base_model``, ``bsz`` and ``lr``.

    When ``args.ckpt`` is falsy, the model is trained on the 'train' split with
    'dev' as validation. The checkpoint with the lowest ``val_loss`` is then
    reloaded before testing. When ``args.ckpt`` is set, that checkpoint is
    loaded and only the test pass runs.
    """
    args = pargs()
    # Identical hyper-parameters for fresh construction and checkpoint restore,
    # so both paths build an identically configured model.
    model_kwargs = dict(lr=args.lr, model=args.base_model, n_labels=len(ix2rel))

    checkpoint_callback = ModelCheckpoint(monitor='val_loss', mode='min')
    trainer = Trainer(gpus=1, default_root_dir=args.save,
                      callbacks=[checkpoint_callback], max_epochs=args.epochs)

    if args.ckpt:
        model = LinkPred.load_from_checkpoint(args.ckpt, **model_kwargs)
    else:
        train_dl = _make_loader('train', args, shuffle=True)
        val_dl = _make_loader('dev', args, shuffle=False)
        model = LinkPred(**model_kwargs)
        trainer.fit(model, train_dl, val_dl)
        # trainer.test(model, ...) evaluates the model's in-memory (last-epoch)
        # weights, which would ignore the val_loss-selected checkpoint tracked
        # above. Reload the best checkpoint explicitly. best_model_path is ""
        # when nothing was saved; in that case keep the final weights.
        best_path = checkpoint_callback.best_model_path
        if best_path:
            model = LinkPred.load_from_checkpoint(best_path, **model_kwargs)

    test_dl = _make_loader('test', args, shuffle=False)
    trainer.test(model, test_dl)


if __name__ == "__main__":
    main()