import torch
from collections import Counter
import argparse
import os
import json
from tqdm import tqdm
from torch.nn.utils.rnn import pad_sequence
import pandas as pd
from itertools import chain
from transformers import AutoTokenizer

class fullDataset(torch.utils.data.Dataset):
    """Argument-pair classification dataset loaded from a single CSV file.

    Rows provide ``Arg1``, ``Arg2``, an integer-castable ``rel`` label and a
    ``split`` tag, plus optionally a knowledge column named by ``kg``.
    Depending on ``model``, the knowledge text is either tokenized on its own
    ('hybrid' / 'attn') or placed in front of ``Arg2`` behind ``kg_token``.
    """

    def __init__(self, dir, tokenizer, split='train', model='crossencoder', kg='none', kg_token="[KG]", n_labels=2):
        """Load, filter and tokenize the rows belonging to ``split``.

        Args:
            dir: path of the CSV file; its first column is used as the index.
            tokenizer: callable used like a Hugging Face tokenizer (positional
                text lists plus ``padding``/``truncation`` kwargs) that returns
                a mapping of per-row sequences.
            split: value of the ``split`` column to keep.
            model: 'hybrid' or 'attn' tokenizes the knowledge separately; any
                other value merges it into the second text segment.
            kg: name of the knowledge column, or 'none' for no knowledge.
                'comet' / 'comet_link' and 'conceptnet' cells are parsed and
                turned into plain text first.
            kg_token: separator inserted between the knowledge text and ``Arg2``.
            n_labels: when 2, only rows with ``rel < 2`` are kept.
        """
        self.path = dir
        use_kg = kg != 'none'
        columns = ['Arg1', 'Arg2', 'rel', kg, 'split'] if use_kg else ['Arg1', 'Arg2', 'rel', 'split']
        frame = pd.read_csv(self.path, index_col=0)[columns]
        frame = frame[frame['split'] == split]
        self.model = model
        # Readable phrases substituted for COMET relation names in the knowledge text.
        comet_rels = {'isBefore':'is before', 'isAfter': 'is after', 'Causes': 'causes', 'CausedBy':'is caused by', 'HinderedBy': 'is hindered by', 'Hinders': 'hinders', 'xIntent': 'because they intend to', 'xNeed': 'because they need to'}

        if n_labels == 2:
            frame = frame[frame['rel'] < 2]

        if not use_kg:
            self.args = tokenizer(frame['Arg1'].values.tolist(), frame['Arg2'].values.tolist(), padding=True, truncation=True)
        else:
            facts = frame[kg].values.tolist()
            # NOTE(review): eval() executes arbitrary Python stored in the CSV cells;
            # ast.literal_eval would be safer if these cells only hold list literals.
            if kg in ('comet', 'comet_link'):
                parsed = [eval(cell) for cell in facts]
                facts = [
                    ' '.join(comet_rels.get(token, token) for token in chain.from_iterable(nested))
                    for nested in parsed
                ]
            elif kg == 'conceptnet':
                facts = [". ".join(eval(cell)) for cell in facts]
            self.knowledge = facts

            if self.model in ('hybrid', 'attn'):
                # Arguments and knowledge become two independent encodings.
                self.args = tokenizer(frame['Arg1'].values.tolist(), frame['Arg2'].values.tolist(), padding=True, truncation=True)
                self.knowledge = tokenizer(self.knowledge, padding=True, truncation=True)
            else:
                # Single encoding: "<knowledge> <kg_token> <Arg2>" as the second segment.
                second = [
                    ' '.join((fact, kg_token, arg))
                    for fact, arg in zip(self.knowledge, frame['Arg2'].values.tolist())
                ]
                self.args = tokenizer(frame['Arg1'].astype(str).values.tolist(), second, padding=True, truncation=True)

        self.labels = frame['rel'].astype(int).values.tolist()

    def __len__(self):
        """Number of retained rows."""
        return len(self.labels)

    def __getitem__(self, i):
        """Return row ``i`` as tensors.

        Returns the tokenized argument fields plus ``labels``. For 'hybrid' /
        'attn' models, also returns a second dict with the tokenized knowledge.
        """
        item = {name: torch.tensor(column[i]) for name, column in self.args.items()}
        item['labels'] = torch.tensor(self.labels[i])
        if self.model not in ('hybrid', 'attn'):
            return item
        knowledge = {name: torch.tensor(column[i]) for name, column in self.knowledge.items()}
        return item, knowledge


def injection_collator(batch):
    """Collate ``(text, knowledge)`` samples into two dicts of stacked tensors.

    The keys of the first sample's text and knowledge dicts determine which
    fields are stacked. Each field is stacked along a new leading batch
    dimension with ``torch.stack``.
    """
    text_fields = batch[0][0].keys()
    kg_fields = batch[0][1].keys()
    txt = {field: torch.stack([sample[0][field] for sample in batch]) for field in text_fields}
    kg = {field: torch.stack([sample[1][field] for sample in batch]) for field in kg_fields}
    return txt, kg

def default_collator(batch):
    """Stack each field of a list of tensor dicts along a new batch dimension.

    The keys of the first sample decide which fields are collated.
    """
    return {field: torch.stack([sample[field] for sample in batch]) for field in batch[0]}

class LinkPredDataset(torch.utils.data.Dataset):
    """ATOMIC-2020 link-prediction dataset augmented with sampled negative pairs.

    Reads ``./data/atomic2020/{split}.tsv`` (columns head, rel, tail). It then
    adds as many negative ``(head, tail)`` pairs, labelled ``'None'``, as the
    mean per-relation row count. Only relations listed in ``ix2rel`` are kept,
    and each is mapped to its index in ``ix2rel``. Tokenization happens lazily
    in ``__getitem__``.
    """

    def __init__(self, split='train', ix2rel=None, model='bert-base-uncased'):
        """Load a split, add negatives and prepare labels.

        Args:
            split: TSV file name (without extension) under ``./data/atomic2020``.
            ix2rel: ordered relation names; the relation at index ``i`` gets
                label ``i``. Must contain ``'None'`` for the sampled negatives
                to be kept.
            model: name or path passed to ``AutoTokenizer.from_pretrained``.

        Raises:
            TypeError: if ``ix2rel`` is None.
        """
        if ix2rel is None:
            # ``Series.isin`` needs a list-like; fail before doing any I/O.
            raise TypeError('ix2rel must be a list of relation names')
        print(f'loading {split} data')
        df = pd.read_csv(f'./data/atomic2020/{split}.tsv', delimiter='\t', names=['head', 'rel', 'tail'])
        rel_counts = df['rel'].value_counts()
        n = int(rel_counts.mean())
        print(rel_counts)
        print(n)

        no_rel = self._sample_negative_pairs(df, n)
        print(f'Adding {n} none relations to the existing {df.shape[0]} samples')
        df = pd.concat([df, no_rel], ignore_index=True)
        df = df[df['rel'].isin(ix2rel)]

        self.tokenizer = AutoTokenizer.from_pretrained(model)
        self.args = list(zip(df['head'].astype(str).tolist(), df['tail'].astype(str).tolist()))

        rel2ix = {rel: ix for ix, rel in enumerate(ix2rel)}
        self.labels = [rel2ix[rel] for rel in df['rel']]

    @staticmethod
    def _sample_negative_pairs(df, n, seed=1, max_rounds=1000):
        """Draw ``n`` distinct (head, tail) pairs that never co-occur in ``df``.

        Heads and tails are sampled independently, with replacement, from the
        non-null, non-``'none'`` values of their columns. A candidate is
        rejected if it already appears as a (head, tail) row of ``df`` or was
        drawn earlier. Sampling repeats until ``n`` pairs are collected.

        Args:
            df: frame with at least ``head`` and ``tail`` columns.
            n: number of negative pairs to return.
            seed: seed of the RNG shared by the head and tail draws.
            max_rounds: maximum number of sampling rounds before giving up.

        Returns:
            DataFrame with columns ``head``, ``tail`` and ``rel`` (always the
            string ``'None'``) and exactly ``n`` rows.

        Raises:
            ValueError: if ``n > 0`` but a candidate pool is empty.
            RuntimeError: if ``n`` unique negatives were not found within
                ``max_rounds`` rounds (the pools are too small).
        """
        import numpy as np  # numpy is a hard dependency of pandas

        heads = df['head'].dropna()
        heads = heads[heads != 'none']
        tails = df['tail'].dropna()
        tails = tails[tails != 'none']
        if n > 0 and (heads.empty or tails.empty):
            raise ValueError('no candidate heads or tails to build negative pairs from')

        known = set(zip(df['head'], df['tail']))
        # One shared RNG: the tail draw continues the stream after the head draw.
        # Seeding both samples with the same integer would pick identical
        # positions for equal-size pools (correlated ones otherwise), biasing
        # candidates towards rows that already exist.
        rng = np.random.RandomState(seed)
        seen = set()
        pairs = []
        rounds = 0
        while len(pairs) < n:
            if rounds >= max_rounds:
                raise RuntimeError(
                    f'only found {len(pairs)} of {n} negative pairs after {max_rounds} rounds'
                )
            rounds += 1
            missing = n - len(pairs)
            drawn_heads = heads.sample(n=missing, replace=True, random_state=rng)
            drawn_tails = tails.sample(n=missing, replace=True, random_state=rng)
            for pair in zip(drawn_heads, drawn_tails):
                if pair not in known and pair not in seen:
                    seen.add(pair)
                    pairs.append(pair)

        no_rel = pd.DataFrame(pairs, columns=['head', 'tail'])
        no_rel['rel'] = 'None'
        return no_rel

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, idx):
        """Tokenize pair ``idx`` (padded to max length) and attach its label."""
        item = self.tokenizer(*self.args[idx], truncation=True, padding='max_length')
        item = {key: torch.tensor(val) for key, val in item.items()}
        item['labels'] = torch.tensor(self.labels[idx])
        return item