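# get_comet_paths.py
#
# For each argument pair (Arg1, Arg2) in data/<dataset>.csv, combine the pickled
# trees of the two arguments (Arg2's tree with its edges reversed), collect every
# labelled simple path from Arg1 to Arg2, and write the rows with a new 'comet'
# column to data/<dataset>-2.csv.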
import pickle
import networkx as nx
import pandas as pd
# from link import LinkPred
from transformers import AutoTokenizer
from itertools import product
from tqdm import tqdm
import torch
import numpy as np
import os
from myparser import pargs

# Maps each relation label to the label of the same relation read in the
# opposite direction; flip() uses it when reversing a tree's edges. Both
# directions are listed so that inverting twice returns the original label and
# trees that already contain 'CausedBy' / 'Hinders' edges can be flipped too.
inverse = {'isBefore': 'isAfter',
           'isAfter': 'isBefore',
           'Causes': 'CausedBy',
           'CausedBy': 'Causes',
           'HinderedBy': 'Hinders',
           'Hinders': 'HinderedBy',
           }

# Relation vocabulary (index -> label) for a link-prediction classifier; the
# 'None' entry stands for "no relation".
ix2rel = ['HinderedBy', 'isBefore', 'isAfter', 'Causes', 'None']
# Reverse lookup: label -> its index in ix2rel.
rel2ix = {rel: ix for ix, rel in enumerate(ix2rel)}


def flip(G):
    """Return a new DiGraph with every edge of ``G`` reversed and relabelled.

    Each edge ``u -> v`` whose ``'label'`` is L becomes ``v -> u`` with label
    ``inverse[L]``. All nodes of ``G`` are copied with their attributes,
    including isolated nodes such as a tree consisting only of its root.
    Edge attributes other than ``'label'`` are also copied.

    Args:
        G: a directed graph whose edges each carry a ``'label'`` attribute
            that is a key of ``inverse``.

    Returns:
        A new ``nx.DiGraph``; ``G`` itself is not modified.

    Raises:
        KeyError: if an edge has no ``'label'`` or its label is not a key of
            ``inverse``.
    """
    g = nx.DiGraph()
    # Copy nodes first: a node with no edges would otherwise be lost.
    g.add_nodes_from(G.nodes(data=True))
    for u, v, data in G.edges(data=True):
        attrs = dict(data)  # copy so G's edge data is never mutated
        attrs['label'] = inverse[data['label']]
        # add_edges_from takes the attribute dict directly, so non-string
        # attribute keys are handled too.
        g.add_edges_from([(v, u, attrs)])
    return g

def edges_to_nodes(G, paths):
    """Expand node paths into alternating node / edge-label sequences.

    Each path ``[n0, n1, ..., nk]`` becomes
    ``[n0, label(n0, n1), n1, ..., label(n(k-1), nk), nk]``. Here ``label(u, v)``
    is ``G.get_edge_data(u, v)['label']``.

    Args:
        G: graph providing ``get_edge_data(u, v)`` for every consecutive pair
            of nodes in each path.
        paths: iterable of node sequences, e.g. the output of
            ``nx.all_simple_paths``.

    Returns:
        A list with one expanded list per input path, in input order. An empty
        path maps to an empty list.

    Raises:
        KeyError: if an edge on a path has no ``'label'`` attribute.
    """
    expanded = []
    for path in paths:
        nodes = list(path)  # accept tuples/generators as well as lists
        if not nodes:
            # An empty path has no final node to append.
            expanded.append([])
            continue
        seq = []
        for u, v in zip(nodes, nodes[1:]):
            seq.append(u)
            seq.append(G.get_edge_data(u, v)['label'])
        seq.append(nodes[-1])
        expanded.append(seq)
    return expanded


# Pickled tree file to load for each dataset. The file name cannot be derived
# from the dataset name, so the pairing is spelled out explicitly.
TREE_FILES = {'presidential_final': 'trees_presidential.p'}


def _load_trees(path):
    """Read every ``(arg, graph)`` record pickled back-to-back in ``path``.

    Returns a dict mapping each argument to its graph. A later record for the
    same argument overwrites an earlier one.

    NOTE(review): ``pickle.load`` can execute arbitrary code -- only run this on
    tree files you produced yourself, never on downloaded/untrusted files.
    """
    trees = {}
    with open(path, 'rb') as f:
        while True:
            try:
                arg, graph = pickle.load(f)
            except EOFError:
                break
            trees[arg] = graph
    return trees


def _comet_paths(tree_dict, arg1, arg2):
    """Return all labelled simple paths from ``arg1`` to ``arg2``.

    ``arg2``'s tree is flipped and composed with ``arg1``'s tree. The two
    argument nodes are then renamed ``'[Arg1]'`` and ``'[Arg2]'``, and every
    simple path between them is expanded with ``edges_to_nodes``. Returns
    ``[]`` when ``'[Arg1]'`` is absent from the combined graph, for example
    when ``arg1 == arg2``.

    Raises:
        KeyError: if ``arg1`` or ``arg2`` has no tree in ``tree_dict``.
    """
    g_comb = nx.compose(tree_dict[arg1], flip(tree_dict[arg2]))
    g_comb = nx.relabel_nodes(g_comb, {arg1: '[Arg1]', arg2: '[Arg2]'})
    try:
        node_paths = list(nx.all_simple_paths(g_comb, '[Arg1]', '[Arg2]'))
    except nx.NodeNotFound:
        # The only expected failure is a missing endpoint; any other error is
        # a bug and should surface rather than be recorded as "no paths".
        return []
    return edges_to_nodes(g_comb, node_paths)


if __name__ == '__main__':
    datasets = ['presidential_final']
    # NOTE(review): pargs() presumably parses the command line. Its result is
    # currently unused, but the call is kept so the CLI behaves as before.
    args = pargs()

    # NOTE: an earlier, disabled variant also added LinkPred-predicted edges
    # between the two trees and wrote a 'comet_link' column. That
    # commented-out code was removed; recover it from version control if
    # needed.
    for dataset in datasets:
        tree_dict = _load_trees(TREE_FILES[dataset])
        df = pd.read_csv(f'data/{dataset}.csv', index_col=0)
        # One list of labelled paths per row. Note that to_csv stores each
        # list as its string representation.
        df['comet'] = [
            _comet_paths(tree_dict, row['Arg1'], row['Arg2'])
            for _, row in tqdm(df.iterrows(), total=df.shape[0])
        ]
        df.to_csv(f'data/{dataset}-2.csv')