# Newer
# Older
import pickle
from itertools import product

import networkx as nx
import numpy as np
import pandas as pd
import torch
from tqdm import tqdm
from transformers import AutoTokenizer
# Inverse of each temporal/causal relation label. The mapping is closed under
# inversion: every value is also a key, so inverse[inverse[r]] == r holds and
# flipping an edge that already carries a flipped label (e.g. 'CausedBy',
# produced by flip()) cannot KeyError.
inverse = {'isBefore': 'isAfter',
           'isAfter': 'isBefore',
           'Causes': 'CausedBy',
           'CausedBy': 'Causes',        # added: inverse was not closed
           'HinderedBy': 'Hinders',
           'Hinders': 'HinderedBy',     # added: inverse was not closed
           }
# Relation vocabulary for the link classifier head and its label->index map.
ix2rel = ['HinderedBy', 'isBefore', 'isAfter', 'Causes', 'None']
rel2ix = {rel: ix for ix, rel in enumerate(ix2rel)}
def flip(G):
    """Return a new DiGraph with every edge of G reversed.

    Each reversed edge carries the inverse relation of the original edge's
    'label' attribute, looked up in the module-level `inverse` table.
    """
    reversed_graph = nx.DiGraph()
    for src, dst, attrs in G.edges(data=True):
        reversed_graph.add_edge(dst, src, label=inverse[attrs['label']])
    return reversed_graph
def edges_to_nodes(G, paths):
    """Interleave each node path with the edge labels that connect it.

    For a path [n0, n1, ..., nk] the result is
    [n0, label(n0, n1), n1, label(n1, n2), ..., nk], where labels come from
    G.get_edge_data(u, v)['label']. An empty `paths` list yields [].
    """
    labeled_paths = []
    for path in paths:
        expanded = []
        for u, v in zip(path, path[1:]):
            expanded.append(u)
            expanded.append(G.get_edge_data(u, v)['label'])
        expanded.append(path[-1])
        labeled_paths.append(expanded)
    return labeled_paths
if __name__ == '__main__':
    # Disabled: loading a trained link-prediction checkpoint and tokenizer.
    # link_model = [os.path.join(root, file) for root, _, files in os.walk(args.link_model) for file in files if '.ckpt' in file]
    # if len(link_model) == 0:
    # raise ValueError('link model not found')
    # model = LinkPred.load_from_checkpoint(link_model, model=args.base_model, n_labels=len(ix2rel))
    # tokenizer = AutoTokenizer.from_pretrained(args.base_model)

    # Read (argument, graph) pairs from a pickle stream until exhaustion,
    # building a lookup from argument text to its event graph.
    # NOTE(review): `f` is never opened in this file as shown — presumably an
    # open binary file handle created by removed setup code (the commented-out
    # lines reference an `args` object, suggesting argparse was cut); confirm.
    tree_dict = {}
    while True:
        try:
            arg, graph = pickle.load(f)
            tree_dict[arg] = graph
        except EOFError:
            break
    # NOTE(review): `dataset` is also undefined in this chunk — likely another
    # argparse value; verify before running.
    df = pd.read_csv(f'data/{dataset}.csv', index_col=0)
    paths = []
    paths_link = []
    # with torch.no_grad():
    # model.eval()
    for _, row in tqdm(df.iterrows(), total=df.shape[0]):
        # Compose Arg1's graph with the edge-reversed Arg2 graph so that a
        # directed path can run from Arg1 forward into (flipped) Arg2.
        g1 = tree_dict[row['Arg1']]
        g2 = tree_dict[row['Arg2']]
        g_comb = nx.compose(g1, flip(g2))
        # Disabled: model-predicted cross-graph links between g1 and g2 nodes.
        # g_comb_link = g_comb
        # for a, _ in nx.bfs_successors(g1, source=row['Arg1']):
        # if a == row['Arg1']:
        # continue
        # for b, _ in nx.bfs_successors(g2, source=row['Arg2']):
        # if b == row['Arg2']:
        # continue
        # inputs = tokenizer(a, b, truncation=True, padding=True)
        # inputs = {key: torch.tensor(val).reshape(1, -1) for key, val in inputs.items()}
        # predictions = model(inputs).numpy()
        # predictions = np.argmax(predictions, axis=-1)[0]
        # if predictions == rel2ix['None']:
        # continue
        # label = ix2rel[predictions]
        # g_comb_link.add_edge(a, b, label=label)
        # Anchor the two argument nodes under canonical names for path search.
        g_comb = nx.relabel_nodes(g_comb, {row['Arg1']:'[Arg1]', row['Arg2']:'[Arg2]'})
        # g_comb_link = nx.relabel_nodes(g_comb, {row['Arg1']:'[Arg1]', row['Arg2']:'[Arg2]'})
        try:
            p = list(nx.all_simple_paths(g_comb, '[Arg1]', '[Arg2]'))
            p = edges_to_nodes(g_comb, p)
        # NOTE(review): bare except silently maps any failure (most likely
        # nx.NodeNotFound when an anchor is absent) to "no paths"; consider
        # narrowing to the specific networkx exceptions.
        except:
            p = []
        paths.append(p)
        # Disabled: path extraction over the link-augmented graph.
        # try:
        # p = list(nx.all_simple_paths(g_comb_link, '[Arg1]', '[Arg2]'))
        # p = edges_to_nodes(g_comb_link, p)
        # except:
        # p = []
        # paths_link.append(p)
    # NOTE(review): `paths` is computed but never written back to `df` in this
    # chunk; the corresponding write for the link variant is commented out.
    # df['comet_link'] = paths_link