"""Augment argument-pair data with COMET-style reasoning paths.

For each (Arg1, Arg2) row the script composes the two pre-built event
trees, optionally adds cross-tree edges predicted by a trained LinkPred
model, and stores all simple Arg1->Arg2 paths (with edge labels
interleaved) back into the dataset CSV.
"""
import pickle
import networkx as nx
import pandas as pd
from link import LinkPred
from transformers import AutoTokenizer
from itertools import product
from tqdm import tqdm
import torch
import numpy as np
import os
from myparser import pargs

# Each directed relation and its inverse; used when reversing a tree's edges.
# NOTE(review): 'CausedBy'/'Hinders' are valid *edge* labels but are not in
# ix2rel, i.e. the classifier never predicts them directly — they only arise
# from flipping.
inverse = {
    'isBefore': 'isAfter',
    'isAfter': 'isBefore',
    'Causes': 'CausedBy',
    'HinderedBy': 'Hinders',
}

# Classifier label space: index -> relation name, and the reverse lookup.
ix2rel = ['HinderedBy', 'isBefore', 'isAfter', 'Causes', 'None']
rel2ix = {rel: ix for ix, rel in enumerate(ix2rel)}


def flip(G):
    """Return a new DiGraph with every edge of ``G`` reversed.

    Each reversed edge carries the inverse relation label, e.g. an
    ``isBefore`` edge u->v becomes an ``isAfter`` edge v->u.

    Note: nodes with no incident edges in ``G`` do not appear in the result.
    """
    g = nx.DiGraph()
    for u, v in G.edges():
        g.add_edge(v, u, label=inverse[G.get_edge_data(u, v)['label']])
    return g


def edges_to_nodes(G, paths):
    """Interleave edge labels into node paths.

    Given paths expressed as node sequences ``[n0, n1, ..., nk]`` over ``G``,
    return sequences ``[n0, label(n0,n1), n1, ..., label(nk-1,nk), nk]``.
    """
    paths2 = []
    for path in paths:
        path2 = []
        for i in range(len(path) - 1):
            path2.append(path[i])
            path2.append(G.get_edge_data(path[i], path[i + 1])['label'])
        path2.append(path[-1])
        paths2.append(path2)
    return paths2


if __name__ == '__main__':
    datasets = ['student_essay', 'debate']
    args = pargs()

    # Locate the most recent (lexicographically last) checkpoint under the
    # link-model directory.
    link_model = [os.path.join(root, file)
                  for root, _, files in os.walk(args.link_model)
                  for file in files if '.ckpt' in file]
    if len(link_model) == 0:
        raise ValueError('link model not found')
    link_model = sorted(link_model)[-1]

    model = LinkPred.load_from_checkpoint(
        link_model, model=args.base_model, n_labels=len(ix2rel))
    tokenizer = AutoTokenizer.from_pretrained(args.base_model)

    for dataset in datasets:
        # The pickle file holds a stream of (argument, graph) records.
        # SECURITY: pickle.load executes arbitrary code — only run on
        # trusted, locally produced files.
        with open(f'trees_{dataset}.p', 'rb') as f:
            tree_dict = {}
            while True:
                try:
                    arg, graph = pickle.load(f)
                    tree_dict[arg] = graph
                except EOFError:
                    break

        df = pd.read_csv(f'data/{dataset}.csv', index_col=0)
        paths = []       # paths over the composed trees only
        paths_link = []  # paths over trees + model-predicted cross links

        with torch.no_grad():
            model.eval()
            for _, row in tqdm(df.iterrows(), total=df.shape[0]):
                g1 = tree_dict[row['Arg1']]
                g2 = tree_dict[row['Arg2']]
                g_comb = nx.compose(g1, flip(g2))
                # BUG FIX: was `g_comb_link = g_comb`, which aliased the two
                # graphs so every predicted edge also landed in g_comb and the
                # 'comet' column silently contained the link edges too.
                g_comb_link = g_comb.copy()

                # Predict cross-tree relations between node pairs (a, b).
                # NOTE(review): bfs_successors yields (parent, children)
                # pairs, so only non-leaf nodes are considered here — confirm
                # this is intentional.
                for a, _ in nx.bfs_successors(g1, source=row['Arg1']):
                    if a == row['Arg1']:
                        continue
                    for b, _ in nx.bfs_successors(g2, source=row['Arg2']):
                        if b == row['Arg2']:
                            continue
                        inputs = tokenizer(a, b, truncation=True, padding=True)
                        inputs = {key: torch.tensor(val).reshape(1, -1)
                                  for key, val in inputs.items()}
                        predictions = model(inputs).numpy()
                        predictions = np.argmax(predictions, axis=-1)[0]
                        if predictions == rel2ix['None']:
                            continue
                        label = ix2rel[predictions]
                        g_comb_link.add_edge(a, b, label=label)

                mapping = {row['Arg1']: '[Arg1]', row['Arg2']: '[Arg2]'}
                g_comb = nx.relabel_nodes(g_comb, mapping)
                # BUG FIX: was relabeling g_comb a second time (a no-op copy),
                # leaving g_comb_link identical to g_comb.
                g_comb_link = nx.relabel_nodes(g_comb_link, mapping)

                p = list(nx.all_simple_paths(g_comb, '[Arg1]', '[Arg2]'))
                p = edges_to_nodes(g_comb, p)
                paths.append(p)

                p = list(nx.all_simple_paths(g_comb_link, '[Arg1]', '[Arg2]'))
                p = edges_to_nodes(g_comb_link, p)
                paths_link.append(p)

        df['comet_link'] = paths_link
        df['comet'] = paths
        df.to_csv(f'data/{dataset}.csv')