# get_comet_paths.py — build COMET relation paths between argument pairs
# using a trained link-prediction model.
import pickle
import networkx as nx
import pandas as pd
from link import LinkPred
from transformers import AutoTokenizer
from itertools import product
from tqdm import tqdm
import torch
import numpy as np
import os
from myparser import pargs

# Inverse label for each directed relation, used when an edge is reversed.
# Note the mapping is not symmetric: 'CausedBy' and 'Hinders' only appear
# as values (they never occur as edge labels in the source graphs).
inverse = {
    'isBefore': 'isAfter',
    'isAfter': 'isBefore',
    'Causes': 'CausedBy',
    'HinderedBy': 'Hinders',
}

# Class vocabulary of the link-prediction model; list index == class id.
ix2rel = ['HinderedBy', 'isBefore', 'isAfter', 'Causes', 'None']
rel2ix = {label: index for index, label in enumerate(ix2rel)}


def flip(G):
    """Return a new DiGraph with every edge of *G* reversed and its
    'label' attribute replaced by the inverse relation."""
    reversed_graph = nx.DiGraph()
    for src, dst, data in G.edges(data=True):
        reversed_graph.add_edge(dst, src, label=inverse[data['label']])
    return reversed_graph

def edges_to_nodes(G, paths):
    """Interleave each node path with the labels of the edges along it.

    For a path [n0, n1, ..., nk] the result is
    [n0, label(n0, n1), n1, label(n1, n2), ..., nk].
    A single-node path is returned unchanged; paths must be non-empty.
    """
    expanded_paths = []
    for path in paths:
        expanded = []
        for u, v in zip(path, path[1:]):
            expanded.append(u)
            expanded.append(G.get_edge_data(u, v)['label'])
        expanded.append(path[-1])
        expanded_paths.append(expanded)
    return expanded_paths


if __name__ == '__main__':
    datasets = ['student_essay', 'debate']
    args = pargs()

    # Collect every Lightning checkpoint found under the link-model directory.
    link_model = [os.path.join(root, file) for root, _, files in os.walk(args.link_model) for file in files if '.ckpt' in file]

    if len(link_model) == 0:
        raise ValueError('link model not found')

    # Lexicographically last checkpoint — presumably the most recent one.
    link_model = sorted(link_model)[-1]

    model = LinkPred.load_from_checkpoint(link_model, model=args.base_model, n_labels=len(ix2rel))

    tokenizer = AutoTokenizer.from_pretrained(args.base_model)
    for dataset in datasets:
        # NOTE(review): pickle.load executes arbitrary code from the file —
        # only load locally produced trees_*.p files, never untrusted input.
        with open(f'trees_{dataset}.p', 'rb') as f:
            # The file holds a sequence of pickled (argument, graph) pairs.
            tree_dict = {}
            while True:
                try:
                    arg, graph = pickle.load(f)
                    tree_dict[arg] = graph
                except EOFError:
                    break

        df = pd.read_csv(f'data/{dataset}.csv', index_col=0)

        paths = []       # paths through the composed graph as-is
        paths_link = []  # paths through the graph augmented with predicted links
        with torch.no_grad():
            model.eval()
            for _, row in tqdm(df.iterrows(), total=df.shape[0]):
                g1 = tree_dict[row['Arg1']]
                g2 = tree_dict[row['Arg2']]
                g_comb = nx.compose(g1, flip(g2))
                # BUG FIX: the original wrote `g_comb_link = g_comb`, aliasing
                # the two graphs, so every predicted edge below leaked into
                # g_comb as well. Use an independent copy instead.
                g_comb_link = g_comb.copy()
                # Predict relations between every pair of (non-root) nodes
                # drawn from the two trees and add the non-'None' ones as edges.
                for a, _ in nx.bfs_successors(g1, source=row['Arg1']):
                    if a == row['Arg1']:
                        continue
                    for b, _ in nx.bfs_successors(g2, source=row['Arg2']):
                        if b == row['Arg2']:
                            continue
                        inputs = tokenizer(a, b, truncation=True, padding=True)
                        inputs = {key: torch.tensor(val).reshape(1, -1) for key, val in inputs.items()}
                        predictions = model(inputs).numpy()
                        predictions = np.argmax(predictions, axis=-1)[0]
                        if predictions == rel2ix['None']:
                            continue
                        label = ix2rel[predictions]
                        g_comb_link.add_edge(a, b, label=label)

                g_comb = nx.relabel_nodes(g_comb, {row['Arg1']:'[Arg1]', row['Arg2']:'[Arg2]'})
                # BUG FIX: the original relabelled g_comb a second time here,
                # leaving g_comb_link's endpoints unrenamed so the simple-path
                # search below could never find '[Arg1]'/'[Arg2]' in it.
                g_comb_link = nx.relabel_nodes(g_comb_link, {row['Arg1']:'[Arg1]', row['Arg2']:'[Arg2]'})

                p = list(nx.all_simple_paths(g_comb, '[Arg1]', '[Arg2]'))
                p = edges_to_nodes(g_comb, p)
                paths.append(p)

                p = list(nx.all_simple_paths(g_comb_link, '[Arg1]', '[Arg2]'))
                p = edges_to_nodes(g_comb_link, p)
                paths_link.append(p)

        # Persist both path variants back into the dataset CSV.
        df['comet_link'] = paths_link
        df['comet'] = paths
        df.to_csv(f'data/{dataset}.csv')