"""Build COMET relation paths between argument pairs.

For each (Arg1, Arg2) row in a dataset CSV, composes Arg1's event graph
with the edge-flipped graph of Arg2, then enumerates all simple paths
from Arg1 to Arg2 (nodes interleaved with edge labels) and writes them
back to ``data/{dataset}-2.csv`` in a new ``comet`` column.

NOTE(review): a commented-out variant that augmented the graph with a
LinkPred model (``link``/``transformers``/``torch``) was removed as dead
code; recover it from version control if needed. Several imports below
(AutoTokenizer, product, torch, numpy, os) were only used by that
variant but are kept to avoid changing the file's import surface.
"""
import pickle
import networkx as nx
import pandas as pd
# from link import LinkPred
from transformers import AutoTokenizer
from itertools import product
from tqdm import tqdm
import torch
import numpy as np
import os
from myparser import pargs

# Each directed relation and its reverse, used when flipping a graph's edges.
inverse = {'isBefore': 'isAfter',
           'isAfter': 'isBefore',
           'Causes': 'CausedBy',
           'HinderedBy': 'Hinders'}
ix2rel = ['HinderedBy', 'isBefore', 'isAfter', 'Causes', 'None']
rel2ix = {rel: ix for ix, rel in enumerate(ix2rel)}


def flip(G):
    """Return a new DiGraph with every edge of *G* reversed.

    Each reversed edge carries the inverse relation label (per ``inverse``)
    of the original edge's ``label`` attribute.
    """
    g = nx.DiGraph()
    for u, v in G.edges():
        g.add_edge(v, u, label=inverse[G.get_edge_data(u, v)['label']])
    return g


def edges_to_nodes(G, paths):
    """Interleave edge labels into node-only paths.

    Given node paths through *G*, return paths of the form
    ``[n0, label(n0,n1), n1, label(n1,n2), ..., nk]``.
    """
    paths2 = []
    for path in paths:
        path2 = []
        for i in range(len(path) - 1):
            path2.append(path[i])
            path2.append(G.get_edge_data(path[i], path[i + 1])['label'])
        path2.append(path[-1])
        paths2.append(path2)
    return paths2


if __name__ == '__main__':
    datasets = ['presidential_final']
    args = pargs()

    for dataset in datasets:
        # The pickle file holds a stream of (argument, graph) records;
        # read until EOF. Filename is fixed (not per-dataset) — the f-prefix
        # on the original literal was a no-op and has been dropped.
        tree_dict = {}
        with open('trees_presidential.p', 'rb') as f:
            while True:
                try:
                    arg, graph = pickle.load(f)
                    tree_dict[arg] = graph
                except EOFError:
                    break

        df = pd.read_csv(f'data/{dataset}.csv', index_col=0)
        paths = []
        for _, row in tqdm(df.iterrows(), total=df.shape[0]):
            g1 = tree_dict[row['Arg1']]
            g2 = tree_dict[row['Arg2']]
            # Compose Arg1's graph with the reversed Arg2 graph so paths can
            # run from Arg1 toward Arg2.
            g_comb = nx.compose(g1, flip(g2))
            g_comb = nx.relabel_nodes(
                g_comb, {row['Arg1']: '[Arg1]', row['Arg2']: '[Arg2]'})
            try:
                p = list(nx.all_simple_paths(g_comb, '[Arg1]', '[Arg2]'))
                p = edges_to_nodes(g_comb, p)
            except nx.NodeNotFound:
                # Was a bare `except:`; all_simple_paths raises NodeNotFound
                # when either endpoint is missing from the composed graph.
                p = []
            paths.append(p)

        df['comet'] = paths
        df.to_csv(f'data/{dataset}-2.csv')