Skip to content
Snippets Groups Projects
get_comet_paths.py 3.82 KiB
Newer Older
Ameer's avatar
Ameer committed
import pickle
import networkx as nx
import pandas as pd
s1707343's avatar
s1707343 committed
from link import LinkPred
Ameer's avatar
Ameer committed
from transformers import AutoTokenizer
from itertools import product
from tqdm import tqdm
import torch
import numpy as np
Ameer's avatar
Ameer committed
import os
from myparser import pargs
Ameer's avatar
Ameer committed

# Relation label -> label used for the same edge when its direction is
# reversed. flip() looks edge labels up in this table.
inverse = dict(
    isBefore='isAfter',
    isAfter='isBefore',
    Causes='CausedBy',
    HinderedBy='Hinders',
)

# Relation labels in class-index order: a label's position in the list is its
# index, and rel2ix is the reverse lookup from label to that index.
ix2rel = ['HinderedBy', 'isBefore', 'isAfter', 'Causes', 'None']
rel2ix = dict(zip(ix2rel, range(len(ix2rel))))


def flip(G):
    """Return a new directed graph containing the edges of ``G`` reversed.

    Every edge ``u -> v`` of ``G`` is added as ``v -> u``, and its ``label``
    attribute is replaced by the matching entry of ``inverse``. Only edges are
    transferred: nodes that have no edges, and edge attributes other than
    ``label``, are not carried over.

    Raises:
        KeyError: if an edge has no ``label`` attribute or its label has no
            entry in ``inverse``.
    """
    reversed_graph = nx.DiGraph()
    reversed_graph.add_edges_from(
        (dst, src, {'label': inverse[G[src][dst]['label']]})
        for src, dst in G.edges()
    )
    return reversed_graph

def edges_to_nodes(G, paths):
    """Interleave edge labels between the consecutive nodes of each path.

    A node path ``[n0, n1, n2]`` becomes
    ``[n0, label(n0, n1), n1, label(n1, n2), n2]``, where ``label(u, v)`` is
    the ``'label'`` entry of ``G.get_edge_data(u, v)``.

    Args:
        G: graph-like object exposing ``get_edge_data(u, v)``.
        paths: iterable of node sequences; each sequence must be non-empty.

    Returns:
        A list with one expanded list per input path, in the same order.
    """
    expanded = []
    for nodes in paths:
        interleaved = []
        # Walk adjacent (source, target) pairs along the path.
        for src, dst in zip(nodes, nodes[1:]):
            interleaved += [src, G.get_edge_data(src, dst)['label']]
        # The final node has no outgoing edge within the path.
        interleaved.append(nodes[-1])
        expanded.append(interleaved)
    return expanded


def _load_trees(path):
    """Read every pickled record in ``path`` into a lookup of graphs.

    The file is read as a sequence of back-to-back pickles, each unpacking to
    ``(arg, goal, graph, inv)``. ``graph`` is stored under ``(arg, goal)`` and
    ``inv`` under ``(goal, arg)``.
    """
    trees = {}
    with open(path, 'rb') as f:
        while True:
            try:
                # NOTE(security): unpickling can execute arbitrary code; only
                # load tree files that come from a trusted source.
                arg, goal, graph, inv = pickle.load(f)
            except EOFError:
                # Reached the end of the concatenated pickles.
                break
            trees[(arg, goal)] = graph
            trees[(goal, arg)] = inv
    return trees


def _labelled_paths(graph, source, target):
    """Return all simple ``source -> target`` paths with edge labels interleaved.

    Best effort: an empty list is returned when an endpoint is missing from
    ``graph`` or an edge on a path carries no ``label``. Unlike a bare
    ``except``, other errors (and Ctrl-C) are allowed to propagate.
    """
    try:
        node_paths = list(nx.all_simple_paths(graph, source, target))
        return edges_to_nodes(graph, node_paths)
    except (nx.NodeNotFound, KeyError):
        return []


if __name__ == '__main__':
    datasets = ['student_essay']
    args = pargs()

    # Collect every checkpoint file below the configured directory.
    checkpoints = [
        os.path.join(root, file)
        for root, _, files in os.walk(args.link_model)
        for file in files
        if '.ckpt' in file
    ]
    if not checkpoints:
        raise ValueError('link model not found')

    # Use the checkpoint whose path sorts last lexicographically.
    link_model = sorted(checkpoints)[-1]

    model = LinkPred.load_from_checkpoint(link_model, model=args.base_model, n_labels=len(ix2rel))
    tokenizer = AutoTokenizer.from_pretrained(args.base_model)
    model.eval()

    for dataset in datasets:
        # The tree file follows the dataset name (previously hard-coded to
        # 'trees_student_essay.p', which is what 'student_essay' resolves to).
        tree_dict = _load_trees(f'trees_{dataset}.p')

        df = pd.read_csv(f'datasets/{dataset}.csv', index_col=0)

        paths = []
        paths_link = []
        with torch.no_grad():
            for _, row in tqdm(df.iterrows(), total=df.shape[0]):
                arg1, arg2 = row['Arg1'], row['Arg2']
                g1 = tree_dict[(arg1, arg2)]
                g2 = tree_dict[(arg2, arg1)]

                # Baseline graph: Arg1's tree joined with Arg2's tree reversed.
                g_comb = nx.compose(g1, flip(g2))
                # Independent copy for predicted links, so that they do not
                # leak into the baseline graph.
                g_comb_link = g_comb.copy()

                # Candidate endpoints are the nodes reported by
                # bfs_successors in each tree, excluding the argument itself.
                # The g2 side does not depend on the g1 node, so compute both
                # lists once per row.
                nodes_a = [n for n, _succ in nx.bfs_successors(g1, source=arg1) if n != arg1]
                nodes_b = [n for n, _succ in nx.bfs_successors(g2, source=arg2) if n != arg2]

                for a in nodes_a:
                    for b in nodes_b:
                        inputs = tokenizer(a, b, truncation=True, padding=True)
                        # Add a leading batch dimension of size 1.
                        inputs = {key: torch.tensor(val).reshape(1, -1) for key, val in inputs.items()}
                        scores = model(inputs).detach().cpu().numpy()
                        predicted = int(np.argmax(scores, axis=-1)[0])
                        if predicted == rel2ix['None']:
                            continue
                        g_comb_link.add_edge(a, b, label=ix2rel[predicted])

                # Replace the argument texts by fixed placeholders in BOTH
                # graphs (the link graph used to be rebuilt from g_comb here).
                placeholders = {arg1: '[Arg1]', arg2: '[Arg2]'}
                g_comb = nx.relabel_nodes(g_comb, placeholders)
                g_comb_link = nx.relabel_nodes(g_comb_link, placeholders)

                paths.append(_labelled_paths(g_comb, '[Arg1]', '[Arg2]'))
                paths_link.append(_labelled_paths(g_comb_link, '[Arg1]', '[Arg2]'))

        df['comet_link'] = paths_link
        df['comet'] = paths
        df.to_csv(f'datasets/{dataset}-2.csv')