Newer
Older
from transformers import AutoTokenizer
from itertools import product
from tqdm import tqdm
import torch
import numpy as np
# Inverse of each relation label: flip() uses this to relabel an edge u -> v as
# v -> u. The mapping is kept symmetric so that any label produced by a flip can
# itself be flipped back (previously 'CausedBy' and 'Hinders' had no entry and
# flip() raised KeyError on them).
inverse = {
    'isBefore': 'isAfter',
    'isAfter': 'isBefore',
    'Causes': 'CausedBy',
    'CausedBy': 'Causes',
    'HinderedBy': 'Hinders',
    'Hinders': 'HinderedBy',
}
# Link-prediction class index -> relation label; the model is built with
# n_labels=len(ix2rel). 'None' is the "no link" class (such predictions are skipped).
ix2rel = ['HinderedBy', 'isBefore', 'isAfter', 'Causes', 'None']
# Relation label -> class index (reverse lookup of ix2rel).
rel2ix = {rel: ix for ix, rel in enumerate(ix2rel)}
def flip(G):
    """Return a new directed graph with every edge of ``G`` reversed.

    Each reversed edge ``v -> u`` is labelled with the inverse (looked up in
    ``inverse``) of the original edge's ``'label'`` attribute. Every node of
    ``G`` is kept, including nodes with no edges. Only the ``'label'`` edge
    attribute is carried over; other edge and node attributes are not copied.

    Raises:
        KeyError: if an edge has no ``'label'`` attribute, or its label has no
            entry in ``inverse``.
    """
    g = nx.DiGraph()
    # Copy the node set first; adding edges alone would drop isolated nodes.
    g.add_nodes_from(G)
    for u, v, data in G.edges(data=True):
        g.add_edge(v, u, label=inverse[data['label']])
    return g
def edges_to_nodes(G, paths):
    """Interleave each node path with the labels of the edges it traverses.

    A path ``[n0, n1, ..., nk]`` becomes
    ``[n0, label(n0, n1), n1, ..., label(n_{k-1}, nk), nk]``, where ``label``
    is the ``'label'`` attribute returned by ``G.get_edge_data(u, v)``.

    Args:
        G: graph providing ``get_edge_data(u, v)`` for each consecutive pair
            of nodes in every path.
        paths: iterable of node sequences (lists, tuples or other iterables).

    Returns:
        A list with one interleaved list per input path. An empty path maps to
        ``[]`` (the original index-based version raised IndexError).
    """
    labelled_paths = []
    for path in paths:
        nodes = list(path)
        if not nodes:
            labelled_paths.append([])
            continue
        seq = [nodes[0]]
        # Walk consecutive (u, v) pairs: emit the edge label, then the next node.
        for u, v in zip(nodes, nodes[1:]):
            seq.append(G.get_edge_data(u, v)['label'])
            seq.append(v)
        labelled_paths.append(seq)
    return labelled_paths
if __name__ == '__main__':
    # Locate the fine-tuned link-prediction checkpoint(s) under args.link_model.
    # NOTE(review): `args`, `os` and `LinkPred` are defined/imported in code not
    # visible in this chunk.
    link_model = sorted(
        os.path.join(root, file)
        for root, _, files in os.walk(args.link_model)
        for file in files
        if file.endswith('.ckpt')
    )
    if not link_model:
        raise ValueError('link model not found')
    # load_from_checkpoint takes a single checkpoint path, not a list. Use the
    # first match; sorting makes the choice independent of os.walk order.
    model = LinkPred.load_from_checkpoint(
        link_model[0], model=args.base_model, n_labels=len(ix2rel)
    )
    tokenizer = AutoTokenizer.from_pretrained(args.base_model)
# Body of a loop whose header is in elided code: read one pickled record from the
# open file handle `f` and index its two graphs in `tree_dict`.
# The record unpacks to (arg, goal, graph, inv); `graph` is stored under the
# (arg, goal) key and `inv` under the reversed (goal, arg) key, so the main loop
# can fetch the tree for either direction of an argument pair.
# NOTE(review): `f` and `tree_dict` are defined outside this chunk; pickle.load
# executes arbitrary code, so these files must come from a trusted source.
arg, goal, graph, inv = pickle.load(f)
tree_dict[(arg, goal)] = graph
tree_dict[(goal, arg)] = inv
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
# For every (Arg1, Arg2) row of `df`, collect two sets of labelled paths from
# '[Arg1]' to '[Arg2]':
#   * `paths`      - over the composed tree graph only;
#   * `paths_link` - over a copy of that graph augmented with edges predicted by
#                    the link model between nodes of the two trees.
# NOTE(review): `df`, `tree_dict`, `paths` and `paths_link` are created in code not
# visible in this chunk; presumably `df` is a DataFrame with 'Arg1'/'Arg2'
# columns and `paths`/`paths_link` are lists - confirm.
with torch.no_grad():
    model.eval()
    for _, row in tqdm(df.iterrows(), total=df.shape[0]):
        arg1, arg2 = row['Arg1'], row['Arg2']
        g1 = tree_dict[(arg1, arg2)]
        g2 = tree_dict[(arg2, arg1)]
        # g2 is traversed from arg2 below; flipping it reverses its edges so that
        # paths in the composition can end at arg2.
        g_comb = nx.compose(g1, flip(g2))
        # Work on a copy: predicted links must not leak into g_comb, otherwise
        # `paths` and `paths_link` end up identical.
        g_comb_link = g_comb.copy()

        # Candidate link endpoints: the non-root nodes visited by BFS in each tree.
        # NOTE(review): bfs_successors only yields nodes that have successors, so
        # leaf nodes are never scored - confirm this is intended.
        g1_nodes = [a for a, _ in nx.bfs_successors(g1, source=arg1) if a != arg1]
        # Only traverse g2 when there is something to pair it with.
        g2_nodes = (
            [b for b, _ in nx.bfs_successors(g2, source=arg2) if b != arg2]
            if g1_nodes else []
        )
        for a, b in product(g1_nodes, g2_nodes):
            inputs = tokenizer(a, b, truncation=True, padding=True)
            # Single (a, b) pair -> batch of size 1.
            inputs = {key: torch.tensor(val).reshape(1, -1) for key, val in inputs.items()}
            logits = model(inputs).detach().cpu().numpy()
            pred = np.argmax(logits, axis=-1)[0]
            if pred == rel2ix['None']:
                continue
            g_comb_link.add_edge(a, b, label=ix2rel[pred])

        mapping = {arg1: '[Arg1]', arg2: '[Arg2]'}
        g_comb = nx.relabel_nodes(g_comb, mapping)
        g_comb_link = nx.relabel_nodes(g_comb_link, mapping)

        for graph, out in ((g_comb, paths), (g_comb_link, paths_link)):
            try:
                node_paths = list(nx.all_simple_paths(graph, '[Arg1]', '[Arg2]'))
            except nx.NetworkXException:
                # '[Arg1]' or '[Arg2]' is not in the graph (NodeNotFound): no paths.
                node_paths = []
            out.append(edges_to_nodes(graph, node_paths))