Skip to content
Snippets Groups Projects
get_comet_paths.py 3.25 KiB
Newer Older
Ameer's avatar
Ameer committed
import pickle
import networkx as nx
import pandas as pd
from link import LinkPred
from transformers import AutoTokenizer
from itertools import product
from tqdm import tqdm
import torch
import numpy as np


# Maps each relation label to the label of the same relation read in the
# opposite direction; flip() uses it when it reverses an edge.
inverse = {'isBefore': 'isAfter',
           'isAfter': 'isBefore',
           'Causes': 'CausedBy',
           'HinderedBy': 'Hinders',
           }
# Add the reverse entries so the mapping is its own inverse: an edge that is
# already labelled 'CausedBy' or 'Hinders' can be flipped back instead of
# raising KeyError (e.g. flip(flip(G)) now works).
inverse.update({v: k for k, v in list(inverse.items())})

# Class-index order of the link-prediction model's output; 'None' means
# "no relation" and is the last class.
ix2rel = ['HinderedBy', 'isBefore', 'isAfter', 'Causes', 'None']
# Reverse lookup: relation label -> class index.
rel2ix = {rel: ix for ix, rel in enumerate(ix2rel)}


def flip(G):
    """Return a new DiGraph with every edge of *G* reversed and its label inverted.

    Each edge ``u -> v`` with ``label=L`` becomes ``v -> u`` with
    ``label=inverse[L]``. Only the ``label`` edge attribute is carried over.
    Every node of *G* is copied, including nodes with no edges (for example,
    a tree that is only its root). Callers can therefore still find such
    nodes in the result.

    Raises:
        KeyError: if an edge has no ``label`` or its label is not a key of
            ``inverse``.
    """
    g = nx.DiGraph()
    # Copy nodes explicitly: adding edges alone would silently drop isolated
    # nodes, and a later path search from/to them would fail.
    g.add_nodes_from(G)
    for u, v, label in G.edges(data='label'):
        g.add_edge(v, u, label=inverse[label])
    return g

def edges_to_nodes(G, paths):
    """Expand node-only paths into alternating node / edge-label lists.

    Args:
        G: graph whose edges carry a ``'label'`` attribute; looked up with
            ``G.get_edge_data(u, v)``.
        paths: iterable of node sequences, each a walk along edges of *G*.

    Returns:
        One list per path, shaped ``[n0, label(n0, n1), n1, ..., nk]``.
    """
    def _interleave(nodes):
        # Walk consecutive (src, dst) pairs, emitting src followed by the
        # label of the edge that leaves it; the final node is appended last.
        out = []
        for src, dst in zip(nodes, nodes[1:]):
            out.extend((src, G.get_edge_data(src, dst)['label']))
        out.append(nodes[-1])
        return out

    return [_interleave(p) for p in paths]


def pargs():
    """Parse the command-line options used by this script."""
    import argparse

    parser = argparse.ArgumentParser(
        description='Extract tree paths between Arg1 and Arg2, with and '
                    'without link-model edges, into data/<dataset>.csv.')
    parser.add_argument('--link_model', required=True,
                        help='Checkpoint path passed to LinkPred.load_from_checkpoint.')
    parser.add_argument('--base_model', required=True,
                        help='Pretrained model name/path for LinkPred and AutoTokenizer.')
    return parser.parse_args()


def _load_trees(path):
    """Return ``{argument: graph}`` from a file of consecutively pickled pairs.

    Each record in *path* is one ``pickle.dump((arg, graph))``; reading stops
    at end of file.
    """
    trees = {}
    # NOTE(review): pickle can execute arbitrary code on load; only point this
    # at files produced by this project.
    with open(path, 'rb') as f:
        while True:
            try:
                arg, graph = pickle.load(f)
            except EOFError:
                break
            trees[arg] = graph
    return trees


def _tree_nodes(G, root):
    """Return every node reachable from *root* in BFS order, excluding *root*.

    Uses ``bfs_edges`` because ``bfs_successors`` only yields nodes that have
    children, which would leave leaf nodes out of the pairing.
    """
    return [v for _, v in nx.bfs_edges(G, source=root)]


def _predict_relation(model, tokenizer, a, b):
    """Return the relation label (an ``ix2rel`` entry) predicted for ``a -> b``.

    ``'None'`` means the model predicts no relation.
    """
    inputs = tokenizer(a, b, truncation=True, padding=True)
    # Single pair -> add a batch dimension of 1.
    inputs = {key: torch.tensor(val).reshape(1, -1) for key, val in inputs.items()}
    logits = model(inputs).cpu().numpy()
    return ix2rel[int(np.argmax(logits, axis=-1)[0])]


def _labelled_paths(G, arg1, arg2):
    """Return all simple paths from *arg1* to *arg2* as node/label lists.

    The two endpoints are renamed to the placeholders ``'[Arg1]'`` and
    ``'[Arg2]'`` in the output.
    """
    H = nx.relabel_nodes(G, {arg1: '[Arg1]', arg2: '[Arg2]'})
    return edges_to_nodes(H, list(nx.all_simple_paths(H, '[Arg1]', '[Arg2]')))


if __name__ == '__main__':
    datasets = ['student_essay', 'debate']
    args = pargs()
    model = LinkPred.load_from_checkpoint(args.link_model, model=args.base_model, n_labels=len(ix2rel))
    tokenizer = AutoTokenizer.from_pretrained(args.base_model)
    model.eval()
    for dataset in datasets:
        tree_dict = _load_trees(f'trees_{dataset}.p')
        # Expects at least 'Arg1' and 'Arg2' columns whose values are keys of tree_dict.
        df = pd.read_csv(f'data/{dataset}.csv', index_col=0)

        paths = []
        paths_link = []
        with torch.no_grad():
            for _, row in tqdm(df.iterrows(), total=df.shape[0]):
                arg1, arg2 = row['Arg1'], row['Arg2']
                g1 = tree_dict[arg1]
                g2 = tree_dict[arg2]
                # g1 keeps its direction, g2 is reversed, so paths run
                # Arg1 -> ... -> Arg2.
                g_comb = nx.compose(g1, flip(g2))
                # Guarantee both endpoints exist even when a tree has no edges.
                g_comb.add_nodes_from((arg1, arg2))
                # A real copy: predicted links must not leak into g_comb,
                # whose paths are reported separately as 'comet'.
                g_comb_link = g_comb.copy()

                nodes2 = _tree_nodes(g2, arg2)  # hoisted: same for every `a`
                for a in _tree_nodes(g1, arg1):
                    for b in nodes2:
                        label = _predict_relation(model, tokenizer, a, b)
                        if label != 'None':
                            g_comb_link.add_edge(a, b, label=label)

                paths.append(_labelled_paths(g_comb, arg1, arg2))
                paths_link.append(_labelled_paths(g_comb_link, arg1, arg2))

        # Overwrites the input CSV in place; list-valued cells are written as
        # their str() repr, so readers must parse them back (e.g. ast.literal_eval).
        df['comet_link'] = paths_link
        df['comet'] = paths
        df.to_csv(f'data/{dataset}.csv')