import pickle
import networkx as nx
import pandas as pd
# from link import LinkPred
from transformers import AutoTokenizer
from itertools import product
from tqdm import tqdm
import torch
import numpy as np
import os
from myparser import pargs

# Relation label to use when an edge is traversed in the opposite direction.
# The table is not closed under inversion: 'CausedBy' and 'Hinders' appear
# only as values and have no entries of their own.
inverse = dict(
    isBefore='isAfter',
    isAfter='isBefore',
    Causes='CausedBy',
    HinderedBy='Hinders',
)

# Relation labels by class index ('None' is the last class), and the
# reverse lookup from label to index.
ix2rel = ['HinderedBy', 'isBefore', 'isAfter', 'Causes', 'None']
rel2ix = dict(zip(ix2rel, range(len(ix2rel))))


def flip(G):
    """Return a new directed graph with every edge of ``G`` reversed.

    Each edge ``u -> v`` of ``G`` becomes ``v -> u`` in the result, and its
    ``'label'`` attribute is replaced by the opposite relation looked up in
    the module-level ``inverse`` table.  Only the ``'label'`` edge attribute
    is carried over; node attributes are not copied.

    Args:
        G: graph whose edges each carry a ``'label'`` attribute that is a
            key of ``inverse``.

    Returns:
        A fresh ``nx.DiGraph`` with the same node set as ``G``; ``G`` itself
        is not modified.

    Raises:
        KeyError: if an edge has no ``'label'`` or its label is not a key
            of ``inverse``.
    """
    g = nx.DiGraph()
    for u, v, label in G.edges(data='label'):
        g.add_edge(v, u, label=inverse[label])

    # Nodes without any incident edge (e.g. a graph made of a single node)
    # are not reached by the edge loop; add them afterwards so the flipped
    # graph keeps the full node set.  Adding them last leaves the order of
    # the edge-derived nodes untouched.
    g.add_nodes_from(G.nodes())

    return g

def edges_to_nodes(G, paths):
    """Interleave edge labels into node paths.

    Every path ``[n0, n1, ..., nk]`` is turned into
    ``[n0, label(n0, n1), n1, ..., label(nk-1, nk), nk]`` where
    ``label(u, v)`` is the ``'label'`` entry of ``G.get_edge_data(u, v)``.

    Args:
        G: graph-like object providing ``get_edge_data(u, v)``.
        paths: iterable of non-empty node sequences.

    Returns:
        A list with one expanded list per input path, in input order.
    """
    labelled = []
    for nodes in paths:
        hops = []
        # Walk consecutive node pairs; emit the source node and then the
        # label of the edge leaving it.
        for src, dst in zip(nodes, nodes[1:]):
            hops.extend((src, G.get_edge_data(src, dst)['label']))
        # The final node has no outgoing hop, so it is appended on its own.
        hops.append(nodes[-1])
        labelled.append(hops)
    return labelled


if __name__ == '__main__':
    # Script entry point: for every (Arg1, Arg2) row of each dataset CSV,
    # join the two per-argument graphs and store all labelled simple paths
    # from Arg1 to Arg2 in a new 'comet' column.
    datasets = ['presidential_final']
    # Parsed command-line options; currently only the disabled
    # link-prediction step described below would read them.
    args = pargs()

    # NOTE: a link-prediction variant used to be sketched here as
    # commented-out code. It loaded the newest '.ckpt' under args.link_model
    # as a LinkPred model (with a tokenizer for args.base_model), predicted a
    # relation for pairs of non-root nodes taken from the two graphs, added
    # every non-'None' prediction as an extra labelled edge, and stored the
    # resulting paths in a 'comet_link' column. It is not active.

    # The tree file is a stream of consecutively pickled (argument, graph)
    # records, so keep unpickling until the stream is exhausted. The file
    # name does not depend on the dataset, hence it is read once up front.
    # NOTE(security): unpickling can execute arbitrary code; only load a
    # tree file that comes from a trusted source.
    tree_dict = {}
    with open('trees_presidential.p', 'rb') as f:
        while True:
            try:
                arg, graph = pickle.load(f)
            except EOFError:
                break
            tree_dict[arg] = graph

    for dataset in datasets:
        df = pd.read_csv(f'data/{dataset}.csv', index_col=0)

        paths = []
        for _, row in tqdm(df.iterrows(), total=df.shape[0]):
            g1 = tree_dict[row['Arg1']]
            g2 = tree_dict[row['Arg2']]
            # Reverse the second graph before merging so that a path can run
            # through g1's edges and then along g2's edges in the opposite
            # direction (presumably each graph is rooted at its argument).
            g_comb = nx.compose(g1, flip(g2))

            # Replace the concrete argument texts by fixed placeholders so
            # the endpoints can be addressed uniformly.
            g_comb = nx.relabel_nodes(
                g_comb, {row['Arg1']: '[Arg1]', row['Arg2']: '[Arg2]'})

            try:
                p = list(nx.all_simple_paths(g_comb, '[Arg1]', '[Arg2]'))
                p = edges_to_nodes(g_comb, p)
            except (nx.NetworkXException, KeyError):
                # Best effort: a missing endpoint in the merged graph or an
                # edge without a 'label' yields no paths for this row.
                # Unlike a bare ``except``, this lets KeyboardInterrupt and
                # SystemExit stop the run.
                p = []
            paths.append(p)

        df['comet'] = paths
        df.to_csv(f'data/{dataset}-2.csv')