Newer
Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
import pickle
import networkx as nx
import pandas as pd
from link import LinkPred
from transformers import AutoTokenizer
from itertools import product
from tqdm import tqdm
import torch
import numpy as np
# Mapping from an event relation to its inverse, used when reversing edge
# direction in flip().  FIX: made the mapping symmetric — 'CausedBy' and
# 'Hinders' previously appeared only as values, so flipping a graph that
# already contained an inverse-labelled edge raised KeyError.
inverse = {'isBefore': 'isAfter',
'isAfter': 'isBefore',
'Causes': 'CausedBy',
'CausedBy': 'Causes',
'HinderedBy': 'Hinders',
'Hinders': 'HinderedBy',
}
# Label index space for the link-prediction model; 'None' (= no relation)
# must stay last so rel2ix['None'] matches the model's n_labels - 1 slot.
ix2rel = ['HinderedBy', 'isBefore', 'isAfter', 'Causes', 'None']
rel2ix = {rel: ix for ix, rel in enumerate(ix2rel)}
def flip(G):
    """Return a new DiGraph with every edge of *G* reversed.

    Each reversed edge keeps a 'label' attribute holding the inverse
    relation (via the module-level ``inverse`` mapping) of the original
    edge's label.  Node-only data is not carried over; only edges (and
    their endpoints) appear in the result.
    """
    reversed_graph = nx.DiGraph()
    for src, dst in G.edges():
        original_label = G.get_edge_data(src, dst)['label']
        reversed_graph.add_edge(dst, src, label=inverse[original_label])
    return reversed_graph
def edges_to_nodes(G, paths):
    """Interleave edge labels into node paths.

    For each path ``[n0, n1, ..., nk]`` taken from graph *G*, produce
    ``[n0, label(n0,n1), n1, label(n1,n2), ..., nk]`` where ``label`` is
    the 'label' attribute of the corresponding edge.  Returns a new list
    of expanded paths; *paths* is not modified.
    """
    expanded_paths = []
    for path in paths:
        expanded = []
        for u, v in zip(path, path[1:]):
            expanded.append(u)
            expanded.append(G.get_edge_data(u, v)['label'])
        expanded.append(path[-1])
        expanded_paths.append(expanded)
    return expanded_paths
if __name__ == '__main__':
    # For each dataset: load per-argument event graphs, link the two
    # arguments' graphs with model-predicted relations, and store all
    # simple [Arg1] -> [Arg2] paths (with and without predicted links)
    # back into the dataset CSV.
    datasets = ['student_essay', 'debate']
    # NOTE(review): `pargs` is not defined or imported anywhere in this
    # file — this raises NameError as written; confirm where it lives.
    args = pargs()
    model = LinkPred.load_from_checkpoint(args.link_model, model=args.base_model, n_labels=len(ix2rel))
    tokenizer = AutoTokenizer.from_pretrained(args.base_model)
    for dataset in datasets:
        # Rebuild {argument text -> event graph} from a stream of
        # back-to-back pickled (arg, graph) records.
        tree_dict = {}
        with open(f'trees_{dataset}.p', 'rb') as f:
            while True:
                try:
                    # NOTE(review): pickle is unsafe on untrusted input;
                    # acceptable only because trees_*.p is produced locally.
                    arg, graph = pickle.load(f)
                except EOFError:
                    break
                tree_dict[arg] = graph
        df = pd.read_csv(f'data/{dataset}.csv', index_col=0)
        paths = []       # paths through the composed graphs alone
        paths_link = []  # paths after adding model-predicted cross-links
        with torch.no_grad():
            model.eval()
            for _, row in tqdm(df.iterrows(), total=df.shape[0]):
                g1 = tree_dict[row['Arg1']]
                g2 = tree_dict[row['Arg2']]
                # Arg2's tree is flipped so its edges point toward Arg2.
                g_comb = nx.compose(g1, flip(g2))
                # BUG FIX: take an independent copy.  The original aliased
                # g_comb, so every predicted edge added to g_comb_link also
                # mutated g_comb and contaminated the link-free paths.
                g_comb_link = g_comb.copy()
                # Score every (node-from-g1, node-from-g2) pair and add an
                # edge whenever the model predicts a non-'None' relation.
                for a, _ in nx.bfs_successors(g1, source=row['Arg1']):
                    if a == row['Arg1']:
                        continue
                    for b, _ in nx.bfs_successors(g2, source=row['Arg2']):
                        if b == row['Arg2']:
                            continue
                        inputs = tokenizer(a, b, truncation=True, padding=True)
                        inputs = {key: torch.tensor(val).reshape(1, -1) for key, val in inputs.items()}
                        predictions = model(inputs).numpy()
                        predictions = np.argmax(predictions, axis=-1)[0]
                        if predictions == rel2ix['None']:
                            continue
                        g_comb_link.add_edge(a, b, label=ix2rel[predictions])
                mapping = {row['Arg1']: '[Arg1]', row['Arg2']: '[Arg2]'}
                g_comb = nx.relabel_nodes(g_comb, mapping)
                # BUG FIX: relabel g_comb_link itself.  The original
                # relabeled g_comb a second time, so g_comb_link kept its
                # raw node names and the '[Arg1]' -> '[Arg2]' path search
                # below could never benefit from the predicted links.
                g_comb_link = nx.relabel_nodes(g_comb_link, mapping)
                p = list(nx.all_simple_paths(g_comb, '[Arg1]', '[Arg2]'))
                paths.append(edges_to_nodes(g_comb, p))
                p = list(nx.all_simple_paths(g_comb_link, '[Arg1]', '[Arg2]'))
                paths_link.append(edges_to_nodes(g_comb_link, p))
        df['comet_link'] = paths_link
        df['comet'] = paths
        df.to_csv(f'data/{dataset}.csv')