# NOTE(review): stray "Newer"/"Older" revision-marker text removed — as bare
# expression statements they raised NameError at import time.
import os
import pickle
from itertools import product

import networkx as nx
import numpy as np
import pandas as pd
import torch
from tqdm import tqdm
from transformers import AutoTokenizer

from link import LinkPred
inverse = {'isBefore': 'isAfter',
'isAfter':'isBefore',
'Causes':'CausedBy',
'HinderedBy':'Hinders'
}
ix2rel = ['HinderedBy', 'isBefore', 'isAfter', 'Causes', 'None']
rel2ix = {rel:ix for ix, rel in enumerate(ix2rel)}
def flip(G):
g = nx.DiGraph()
for u,v in G.edges():
g.add_edge(v, u, label=inverse[G.get_edge_data(u, v)['label']])
return g
def edges_to_nodes(G, paths):
paths2 = []
for path in paths:
path2 = []
for i in range(len(path) - 1):
path2.append(path[i])
path2.append(G.get_edge_data(path[i], path[i+1])['label'])
path2.append(path[-1])
paths2.append(path2)
return paths2
if __name__ == '__main__':
datasets = ['student_essay', 'debate']
args = pargs()
link_model = [os.path.join(root, file) for root, _, files in os.walk(args.link_model) for file in files if '.ckpt' in file]
if len(link_model) == 0:
raise ValueError('link model not found')
link_model = sorted(link_model)[-1]
model = LinkPred.load_from_checkpoint(link_model, model=args.base_model, n_labels=len(ix2rel))
tokenizer = AutoTokenizer.from_pretrained(args.base_model)
for dataset in datasets:
with open(f'trees_{dataset}.p', 'rb') as f:
tree_dict = {}
while True:
try:
arg, graph = pickle.load(f)
tree_dict[arg] = graph
except EOFError:
break
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
df = pd.read_csv(f'data/{dataset}.csv', index_col=0)
paths = []
paths_link = []
with torch.no_grad():
model.eval()
for _, row in tqdm(df.iterrows(), total=df.shape[0]):
g1 = tree_dict[row['Arg1']]
g2 = tree_dict[row['Arg2']]
g_comb = nx.compose(g1, flip(g2))
g_comb_link = g_comb
for a, _ in nx.bfs_successors(g1, source=row['Arg1']):
if a == row['Arg1']:
continue
for b, _ in nx.bfs_successors(g2, source=row['Arg2']):
if b == row['Arg2']:
continue
inputs = tokenizer(a, b, truncation=True, padding=True)
inputs = {key: torch.tensor(val).reshape(1, -1) for key, val in inputs.items()}
predictions = model(inputs).numpy()
predictions = np.argmax(predictions, axis=-1)[0]
if predictions == rel2ix['None']:
continue
label = ix2rel[predictions]
g_comb_link.add_edge(a, b, label=label)
g_comb = nx.relabel_nodes(g_comb, {row['Arg1']:'[Arg1]', row['Arg2']:'[Arg2]'})
g_comb_link = nx.relabel_nodes(g_comb, {row['Arg1']:'[Arg1]', row['Arg2']:'[Arg2]'})
p = list(nx.all_simple_paths(g_comb, '[Arg1]', '[Arg2]'))
p = edges_to_nodes(g_comb, p)
paths.append(p)
p = list(nx.all_simple_paths(g_comb_link, '[Arg1]', '[Arg2]'))
p = edges_to_nodes(g_comb_link, p)
paths_link.append(p)
df['comet_link'] = paths_link
df['comet'] = paths
df.to_csv(f'data/{dataset}.csv')