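"""Dataset and collator utilities: an argument-pair classification dataset with
optional knowledge injection (COMET / ConceptNet) and an ATOMIC-2020
link-prediction dataset with sampled negative relations."""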
import torch
import pandas as pd
from ast import literal_eval
from itertools import chain
from transformers import AutoTokenizer

class fullDataset(torch.utils.data.Dataset):
    def __init__(self, dir, tokenizer, split='train', model='crossencoder', kg='none', kg_token="[KG]", n_labels=2):
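        """Load a CSV with 'Arg1', 'Arg2', 'rel' and 'split' columns (plus a
        knowledge column named by `kg` when `kg != 'none'`) and tokenize it.

        `model` controls how knowledge is attached: 'hybrid' and 'attn' keep it
        as a separate encoding, anything else prepends it to Arg2 after
        `kg_token`.
        """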

        self.path = dir
        if kg != 'none':
            df = pd.read_csv(self.path, index_col=0)[['Arg1', 'Arg2', 'rel', kg, 'split']]
        else:
            df = pd.read_csv(self.path, index_col=0)[['Arg1', 'Arg2', 'rel', 'split']]
        df = df[df['split'] == split]
        self.model = model
        # Map COMET relation tags to natural-language connectives for verbalising triples.
        comet_rels = {'isBefore': 'is before', 'isAfter': 'is after', 'Causes': 'causes',
                      'CausedBy': 'is caused by', 'HinderedBy': 'is hindered by', 'Hinders': 'hinders',
                      'xIntent': 'because they intend to', 'xNeed': 'because they need to'}

        # With two labels, keep only rows whose relation id is 0 or 1.
        if n_labels == 2:
            df = df[df['rel'] < 2]

        if kg != 'none':
            self.knowledge = df[kg].values.tolist()
            if kg in ('comet', 'comet_link'):
                # Each cell is a stringified list of COMET triples: parse it
                # safely, flatten, and verbalise the relation tags.
                self.knowledge = [literal_eval(x) for x in self.knowledge]
                self.knowledge = [list(chain.from_iterable(x)) for x in self.knowledge]
                self.knowledge = [[comet_rels.get(x, x) for x in k] for k in self.knowledge]
                self.knowledge = [' '.join(x) for x in self.knowledge]
            if kg == 'conceptnet':
                self.knowledge = ['. '.join(literal_eval(x)) for x in self.knowledge]
            if self.model in ('hybrid', 'attn'):
                # Keep the argument pair and the knowledge as separate encodings.
                self.args = tokenizer(df['Arg1'].values.tolist(), df['Arg2'].values.tolist(), padding=True, truncation=True)
                self.knowledge = tokenizer(self.knowledge, padding=True, truncation=True)
            else:
                # Prepend the verbalised knowledge to Arg2, separated by the KG token.
                arg2 = [k + ' ' + kg_token + ' ' + a for k, a in zip(self.knowledge, df['Arg2'].values.tolist())]
                self.args = tokenizer(df['Arg1'].astype(str).values.tolist(), arg2,
                                      padding=True, truncation=True)

        else:
            # No external knowledge: encode the argument pair directly.
            self.args = tokenizer(df['Arg1'].values.tolist(), df['Arg2'].values.tolist(), padding=True, truncation=True)

        self.labels = df['rel'].astype(int).values.tolist()

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, i):
        item = {key: torch.tensor(val[i]) for key, val in self.args.items()}
        item['labels'] = torch.tensor(self.labels[i])
        if self.model in ('hybrid', 'attn'):
            # Hybrid/attention models also receive the knowledge encoding.
            knowledge = {key: torch.tensor(val[i]) for key, val in self.knowledge.items()}
            return item, knowledge
        return item


def injection_collator(batch):
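    """Stack a batch of (text, knowledge) pairs into two batched tensor dicts."""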
    txt = {}
    kg = {}
    txt_keys = batch[0][0].keys()
    kg_keys = batch[0][1].keys()
    for k in txt_keys:
        txt[k] = torch.stack([t[0][k] for t in batch])

    for k in kg_keys:
        kg[k] = torch.stack([t[1][k] for t in batch])

    return txt, kg

def default_collator(batch):
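    """Stack a batch of single-encoding items into one batched tensor dict."""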
    txt = {}
    txt_keys = batch[0].keys()
    for k in txt_keys:
        txt[k] = torch.stack([t[k] for t in batch])

    return txt

class LinkPredDataset(torch.utils.data.Dataset):
    def __init__(self, split='train', ix2rel=None, model='bert-base-uncased'):
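        """Load ATOMIC-2020 triples for `split` and append sampled 'None'
        negatives so the classifier can also predict the absence of a relation.
        `ix2rel` lists the relation names in label-index order.
        """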
        print(f'loading {split} data')
        df = pd.read_csv(f'./data/atomic2020/{split}.tsv', delimiter='\t', names=['head', 'rel', 'tail'])
        # Number of 'None' negatives to generate: the mean count per relation.
        n = int(df['rel'].value_counts().mean())
        print(df['rel'].value_counts())
        print(f'sampling {n} negative pairs')
        # Build n negative samples: random (head, tail) pairs that never occur
        # as an edge in the graph, labelled 'None'. The original head filter
        # used `.str.contains('none') | (df['head'] != 'NaN')` and the tail
        # filter OR-ed two inequalities (a tautology); both are tightened here
        # to exclude 'none' placeholders and missing values.
        heads = df.loc[df['head'] != 'none', 'head'].dropna()
        tails = df.loc[df['tail'] != 'none', 'tail'].dropna()
        no_rel = pd.DataFrame(columns=['head', 'tail'])
        diff = n
        i = 1
        while diff > 0:
            sampled = pd.concat([
                heads.sample(n=diff, random_state=i, ignore_index=True, replace=True),
                tails.sample(n=diff, random_state=i, ignore_index=True, replace=True)
            ], axis=1)
            # Keep only pairs that are absent from the original graph.
            sampled = pd.merge(sampled, df[['head', 'tail']], on=['head', 'tail'], indicator='i', how='left')
            sampled = sampled.query('i == "left_only"').drop(columns='i').reset_index(drop=True)
            no_rel = pd.concat([no_rel, sampled], ignore_index=True).drop_duplicates()
            diff = max(0, n - no_rel.shape[0])
            i += 1
        no_rel['rel'] = 'None'

        assert no_rel.shape[0] == n
        print(f'Adding {n} None relations to the existing {df.shape[0]} samples')
        df = pd.concat([df, no_rel], ignore_index=True)
        # Keep only the relations the model is asked to predict.
        df = df[df['rel'].isin(ix2rel)]

        self.tokenizer = AutoTokenizer.from_pretrained(model)
        self.args = list(zip(df['head'].astype(str).tolist(), df['tail'].astype(str).tolist()))

        rel2ix = {rel: ix for ix, rel in enumerate(ix2rel)}
        self.labels = [rel2ix[rel] for rel in df['rel']]

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, idx):
        # Tokenise lazily, one (head, tail) pair per item.
        item = self.tokenizer(*self.args[idx], truncation=True, padding='max_length')
        item = {key: torch.tensor(val) for key, val in item.items()}
        item['labels'] = torch.tensor(self.labels[idx])
        return item
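

# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative only: the CSV path and tokenizer
# checkpoint below are assumptions, not fixed anywhere in this module).
# ---------------------------------------------------------------------------
if __name__ == '__main__':
    tokenizer = AutoTokenizer.from_pretrained('bert-base-uncased')
    # The file is expected to hold 'Arg1', 'Arg2', 'rel' and 'split' columns.
    dataset = fullDataset('./data/train.csv', tokenizer, split='train',
                          model='crossencoder', kg='none')
    loader = torch.utils.data.DataLoader(dataset, batch_size=16,
                                         collate_fn=default_collator)
    batch = next(iter(loader))
    print({key: val.shape for key, val in batch.items()})
    # For model='hybrid' or 'attn', pair the dataset with injection_collator
    # instead: it returns separate text and knowledge batches.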