import os
import pickle
from functools import lru_cache

import numpy as np
import pandas as pd
from tqdm import tqdm

import nltk
from nltk.corpus import brown
from nltk.corpus import stopwords
from nltk.tokenize import word_tokenize

from sklearn.metrics.pairwise import cosine_distances
from sentence_transformers import SentenceTransformer
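

# Retrieve ATOMIC 2020 commonsense triples that bridge two argument spans,
# using sentence-transformer embeddings for nearest-neighbour search over
# triple heads and tails.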
class Atomic2020:
    def __init__(self, dir="data/atomic2020/"):
        # Legacy approach, kept for reference: estimate p(word) from the training
        # split and use it to rarity-weight fastText word embeddings.
        # print('Building probability dict...\n')
        #
        # data = pd.read_csv(data, index_col=0)
        # data = data[data['split'] == 'TRAIN']
        # text = []
        # for i, row in tqdm(data.iterrows()):
        #
        #     arg1 = word_tokenize(row['Arg1'])
        #     arg2 = word_tokenize(row['Arg2'])
        #     text = text + arg1 + arg2

        # fdist = nltk.FreqDist(w.lower() for w in text)
        # self.pdist = nltk.DictionaryProbDist(fdist, normalize=True)
        #
        # print('Loading fasttext model...\n')
        # self.model = fasttext.load_facebook_vectors(vecs)
        self.model = SentenceTransformer("all-mpnet-base-v2")


        # model = FastText(vector_size=300, window=3, min_count=1)
        # model.build_vocab(brown.sents())
        # model.train(brown.sents()[:10], total_examples=len(brown.sents()[:10]), epochs=1)
        # self.model = model.wv

        # Load the ATOMIC 2020 TSV splits and drop rows that cannot be used
        self.df = pd.concat(
            [pd.read_csv(os.path.join(dir, file), delimiter='\t', header=None)
             for file in os.listdir(dir) if file.endswith('.tsv')],
            axis=0,
            ignore_index=True,
        )
        self.df = self.df[~self.df[2].isna()]
        self.df = self.df[self.df[2] != 'none']
        self.df = self.df[self.df[1] != 'isFilledBy']
        self.stopwords = stopwords.words('english')
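        # Map ATOMIC 2020 relation labels to natural-language connectives used
        # when verbalising a triple in process_args()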
        self.rel_dict = {'oEffect': ', as a result others',
                          'oReact': ', as a result others feel',
                          'oWant': 'because others wanted',
                          'xAttr': ', as a result PersonX is seen as',
                          'xEffect': ', as a result PersonX',
                          'xIntent': 'because PersonX sought',
                          'xReact': ', as a result PersonX feels',
                          'xWant': 'because PersonX wanted',
                          'xNeed': 'because PersonX needed',
                          'HinderedBy': 'but is hindered by the fact that',
                          'ObjectUse': 'is used for',
                          'isBefore': 'before',
                          'HasProperty': 'has the property',
                          'AtLocation': 'is located at',
                          'HasSubEvent': 'involves',
                          'MadeUpOf': 'is made of',
                          'isAfter': 'after',
                          'CapableOf': 'is capable of',
                          'Desires': 'desires',
                          'NotDesires': 'does not desire',
                          'Causes': 'causes',
                          'xReason': 'because PersonX wanted'
                          }

        # Make sure the cache directory exists (create parent directories as needed)
        os.makedirs('embeddings/2020', exist_ok=True)

        # Load precomputed ATOMIC embeddings if they exist, otherwise compute and store them
        try:
            with open('embeddings/2020/heads-bert.p', 'rb') as f:
                self.cached_heads = pickle.load(f)
            with open('embeddings/2020/tails-bert.p', 'rb') as f:
                self.cached_tails = pickle.load(f)

        except (OSError, pickle.UnpicklingError):
            print('Computing embeddings... \n')
            self.cached_heads, self.cached_tails = self.cache_embeds()
            with open('embeddings/2020/heads-bert.p', 'wb') as f:
                pickle.dump(self.cached_heads, f)
            with open('embeddings/2020/tails-bert.p', 'wb') as f:
                pickle.dump(self.cached_tails, f)

        self.head_array = np.array(list(self.cached_heads.keys()))
        self.head_embeds = np.array(list(self.cached_heads.values()))

    # Speed up single-word embedding lookup with memoization
    # (only referenced by the legacy word-averaging code kept in get_mean_embed)
    @lru_cache(maxsize=256)
    def get_embed(self, word):
        return self.model.encode(word)

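    # Embed every unique ATOMIC head and tail phrase once, for reuse in lookups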
    def cache_embeds(self):
        heads = {}
        for head in tqdm(self.df[0].unique()):
            heads[head] = self.get_mean_embed(head)
        tails = {}
        for tail in tqdm(self.df[2].unique()):
            tails[tail] = self.get_mean_embed(tail)

        return heads, tails

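    # Embed a full phrase with the sentence-transformer (the name is historical:
    # the original implementation averaged rarity-weighted word vectors, see below)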
    def get_mean_embed(self, sent):
        # sent = [word.lower() for word in word_tokenize(sent) if word.lower() not in self.stopwords]
        # if len(sent) == 0:
        #     return np.zeros(300)
        # # Truncate anything over 50 words
        # if len(sent) > 50:
        #     sent = sent[:50]
        #
        # # Compute average of word embeddings weighted by rarity=(1-p(word))
        # sent = [self.get_embed(word) * (1 - self.pdist.prob(word) + 1e-5) for word in sent]
        # emb = np.mean(sent, axis=0)

        emb = self.model.encode(sent)
        return emb


    # Get the n closest ATOMIC heads (and their cosine distances) to an embedded argument
    def get_topics(self, arg, n=1):
        dists = cosine_distances(arg.reshape(1, -1), self.head_embeds)[0]
        closest = np.argsort(dists)
        return self.head_array[closest[:n]], dists[closest[:n]]

    # Get all triples that start with the given head
    def get_paths(self, topic):
        relevant_rows = self.df[self.df[0] == topic]
        paths = []
        for _, row in relevant_rows.iterrows():
            paths.append([row[0], row[1], row[2]])
        return paths

    # Given a list of triples, return the one whose tail embedding is closest to arg
    def filter_paths(self, arg, paths):
        path_embeds = [self.cached_tails[path[2]] for path in paths]
        path_embeds = np.array(path_embeds)
        dists = cosine_distances(arg.reshape(1, -1), path_embeds)[0]
        closest = np.argsort(dists)
        return paths[closest[0]], dists[closest[0]]

    # Find candidate triples (n per argument, searched in both directions) that best connect arg1 and arg2
    def join_args(self, arg1, arg2, n=1):

        arg1 = self.get_mean_embed(arg1)
        arg2 = self.get_mean_embed(arg2)

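        # Search in both directions: anchor on heads near arg1 and score their tails
        # against arg2, then anchor on heads near arg2 and score their tails against arg1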
        topic_1 = list(zip(*self.get_topics(arg1, n)))
        topic_2 = list(zip(*self.get_topics(arg2, n)))
        path_1 = []
        path_2 = []
        for i in range(n):
            topic = topic_1[i]
            path = self.filter_paths(arg2, self.get_paths(topic[0]))
            path_1.append((topic, path))
        for i in range(n):
            topic = topic_2[i]
            path = self.filter_paths(arg1, self.get_paths(topic[0]))
            path_2.append((topic, path))

        paths = [(path, (path_dist + topic_dist) / 2) for ((topic, topic_dist), (path, path_dist)) in path_1 + path_2]
        paths = list(sorted(paths, key=lambda x: x[1]))
        return paths

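    # Verbalise the best bridging triple as a sentence, replacing PersonX with a neutral referent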
    def process_args(self, arg1, arg2):
        path, dist = self.join_args(arg1, arg2, 3)[0]

        x = 'the person'

        path[1] = self.rel_dict[path[1]]
        rel = ' '.join(path)
        rel = rel.replace('PersonX', x)
        return (rel, dist)
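

# Minimal usage sketch (assumptions: the ATOMIC 2020 TSV splits live under the
# constructor default data/atomic2020/, and the example argument strings below
# are illustrative only). Note the first run computes and caches embeddings for
# every ATOMIC head and tail, which is slow.
if __name__ == '__main__':
    kb = Atomic2020()

    # Hypothetical argument pair; any two short text spans work
    arg1 = "PersonX buys an umbrella"
    arg2 = "it starts to rain"

    sentence, distance = kb.process_args(arg1, arg2)
    print(sentence)
    print(f"cosine distance: {distance:.3f}")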