Skip to content
Snippets Groups Projects

Compare revisions

Changes are shown as if the source revision was being merged into the target revision. Learn more about comparing revisions.

Source

Select target project
No results found

Target

Select target project
  • s1707343/commonsense-argmining
1 result
Show changes
Commits on Source (2)
...@@ -16,3 +16,6 @@ link/* ...@@ -16,3 +16,6 @@ link/*
*.p *.p
*.ckpt *.ckpt
*.out
*.sh
comet-atomic_2020_BART/
...@@ -22,7 +22,8 @@ class fullDataset(torch.utils.data.Dataset): ...@@ -22,7 +22,8 @@ class fullDataset(torch.utils.data.Dataset):
comet_rels = {'isBefore':'is before', 'isAfter': 'is after', 'Causes': 'causes', 'CausedBy':'is caused by', 'HinderedBy': 'is hindered by', 'Hinders': 'hinders', 'xIntent': 'because they intend to', 'xNeed': 'because they need to'} comet_rels = {'isBefore':'is before', 'isAfter': 'is after', 'Causes': 'causes', 'CausedBy':'is caused by', 'HinderedBy': 'is hindered by', 'Hinders': 'hinders', 'xIntent': 'because they intend to', 'xNeed': 'because they need to'}
if n_labels == 2: if n_labels == 2:
df = df[df['rel']<2] df = df[df['rel']>0]
df['rel'] = df['rel'] - 1
if kg != 'none': if kg != 'none':
self.knowledge = df[kg].values.tolist() self.knowledge = df[kg].values.tolist()
......
...@@ -90,7 +90,7 @@ def build_tree(head, rels, model, max_depth=2): ...@@ -90,7 +90,7 @@ def build_tree(head, rels, model, max_depth=2):
if __name__ == "__main__": if __name__ == "__main__":
args = pargs() args = pargs()
dataset = args.datadir dataset = args.datadir
df = pd.read_csv(f'data/{dataset}') df = pd.read_csv(f'data/{dataset}.csv')
print("loading comet ...") print("loading comet ...")
comet = Comet("./comet-atomic_2020_BART", progress_bar=False) comet = Comet("./comet-atomic_2020_BART", progress_bar=False)
comet.model.zero_grad() comet.model.zero_grad()
......
...@@ -40,7 +40,7 @@ def edges_to_nodes(G, paths): ...@@ -40,7 +40,7 @@ def edges_to_nodes(G, paths):
if __name__ == '__main__': if __name__ == '__main__':
datasets = ['student_essay', 'debate'] datasets = ['presidential']
args = pargs() args = pargs()
link_model = [os.path.join(root, file) for root, _, files in os.walk(args.link_model) for file in files if '.ckpt' in file] link_model = [os.path.join(root, file) for root, _, files in os.walk(args.link_model) for file in files if '.ckpt' in file]
if len(link_model) == 0: if len(link_model) == 0:
...@@ -88,13 +88,19 @@ if __name__ == '__main__': ...@@ -88,13 +88,19 @@ if __name__ == '__main__':
g_comb = nx.relabel_nodes(g_comb, {row['Arg1']:'[Arg1]', row['Arg2']:'[Arg2]'}) g_comb = nx.relabel_nodes(g_comb, {row['Arg1']:'[Arg1]', row['Arg2']:'[Arg2]'})
g_comb_link = nx.relabel_nodes(g_comb, {row['Arg1']:'[Arg1]', row['Arg2']:'[Arg2]'}) g_comb_link = nx.relabel_nodes(g_comb, {row['Arg1']:'[Arg1]', row['Arg2']:'[Arg2]'})
p = list(nx.all_simple_paths(g_comb, '[Arg1]', '[Arg2]')) try:
p = edges_to_nodes(g_comb, p) p = list(nx.all_simple_paths(g_comb, '[Arg1]', '[Arg2]'))
p = edges_to_nodes(g_comb, p)
except:
p = []
paths.append(p) paths.append(p)
p = list(nx.all_simple_paths(g_comb_link, '[Arg1]', '[Arg2]')) try:
p = edges_to_nodes(g_comb_link, p) p = list(nx.all_simple_paths(g_comb_link, '[Arg1]', '[Arg2]'))
p = edges_to_nodes(g_comb_link, p)
except:
p = []
paths_link.append(p) paths_link.append(p)
df['comet_link'] = paths_link df['comet_link'] = paths_link
......
from model.transformers import * from models.transformers import *
from torch.utils.data import DataLoader from torch.utils.data import DataLoader
from pytorch_lightning import Trainer, seed_everything from pytorch_lightning import Trainer, seed_everything
from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping
...@@ -76,7 +76,7 @@ def main(args): ...@@ -76,7 +76,7 @@ def main(args):
preds = torch.cat(preds).cpu().detach().numpy() preds = torch.cat(preds).cpu().detach().numpy()
np.save(os.path.join('results/', args.save, 'logs/', f'test_results.npy'), preds) np.save(os.path.join(args.save, 'test_results.npy'), preds)
if __name__ == '__main__': if __name__ == '__main__':
......