# dataLoader.py
import argparse
import ast
import json
import os
from collections import Counter
from itertools import chain

import pandas as pd
import torch
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import Dataset
from tqdm import tqdm

from model.Transformer import BaseTransformer

class fullDataset(torch.utils.data.Dataset):
    """Dataset of argument pairs (optionally augmented with KG knowledge).

    Reads a CSV with columns Arg1, Arg2, rel, split (plus a knowledge column
    named by ``kg`` when ``kg != 'none'``), filters to the requested split,
    tokenizes the text, and serves per-example tensor dicts.

    Args:
        dir: Path to the CSV file (name kept for caller compatibility even
            though it shadows the builtin).
        tokenizer: HuggingFace-style tokenizer; called with one or two text
            lists and ``padding=True, truncation=True``.
        split: Which value of the CSV ``split`` column to keep.
        model: 'hybrid'/'attn' encode knowledge separately; anything else
            concatenates knowledge into the second argument.
        kg: Knowledge column name ('comet', 'comet_link', 'conceptnet') or
            'none' to disable knowledge.
        kg_token: Marker inserted between knowledge text and Arg2.
        n_labels: When 2, keeps only rows with rel < 2 (binary setup).
    """

    def __init__(self, dir, tokenizer, split='train', model='crossencoder', kg='none', kg_token="[KG]", n_labels=2):
        self.path = dir
        cols = ['Arg1', 'Arg2', 'rel', 'split'] if kg == 'none' else ['Arg1', 'Arg2', 'rel', kg, 'split']
        df = pd.read_csv(self.path, index_col=0)[cols]
        df = df[df['split'] == split]
        self.model = model
        # Human-readable phrasings for COMET relation tags.
        comet_rels = {'isBefore': 'is before', 'isAfter': 'is after', 'Causes': 'causes', 'CausedBy': 'is caused by', 'HinderedBy': 'is hindered by', 'Hinders': 'hinders', 'xIntent': 'because they intend to', 'xNeed': 'because they need to'}

        if n_labels == 2:
            # Binary classification: keep only labels 0 and 1.
            df = df[df['rel'] < 2]

        if kg != 'none':
            self.knowledge = df[kg].values.tolist()
            if kg in ('comet', 'comet_link'):
                # The column stores the string repr of a nested list. Parse with
                # ast.literal_eval rather than eval: the CSV is external input
                # and eval would execute arbitrary code embedded in it.
                parsed = [ast.literal_eval(x) for x in self.knowledge]
                flattened = [chain.from_iterable(x) for x in parsed]
                # Replace COMET relation tags with their verbal phrasings.
                self.knowledge = [' '.join(comet_rels.get(tok, tok) for tok in k) for k in flattened]
            if kg == 'conceptnet':
                # Stored as the string repr of a list of sentences.
                self.knowledge = [". ".join(ast.literal_eval(x)) for x in self.knowledge]
            if self.model in ('hybrid', 'attn'):
                # Knowledge is encoded separately and fused inside the model.
                self.args = tokenizer(df['Arg1'].values.tolist(), df['Arg2'].values.tolist(), padding=True, truncation=True)
                self.knowledge = tokenizer(self.knowledge, padding=True, truncation=True)
            else:
                # Prepend knowledge to Arg2, separated by the KG marker token.
                arg2 = [k + ' ' + kg_token + ' ' + a for k, a in zip(self.knowledge, df['Arg2'].values.tolist())]
                self.args = tokenizer(df['Arg1'].astype(str).values.tolist(),
                    arg2,
                    padding=True, truncation=True)
        else:
            self.args = tokenizer(df['Arg1'].values.tolist(), df['Arg2'].values.tolist(), padding=True, truncation=True)

        self.labels = df['rel'].astype(int).values.tolist()

    def __len__(self):
        """Number of examples in the selected split."""
        return len(self.labels)

    def __getitem__(self, i):
        """Return tokenized inputs as a tensor dict (plus a knowledge dict for hybrid/attn)."""
        item = {key: torch.tensor(val[i]) for key, val in self.args.items()}
        item['labels'] = torch.tensor(self.labels[i])
        if self.model in ('hybrid', 'attn'):
            knowledge = {key: torch.tensor(val[i]) for key, val in self.knowledge.items()}
            return item, knowledge
        return item


def injection_collator(batch):
    """Collate a list of (text_dict, knowledge_dict) pairs into two batched dicts.

    Each per-example dict maps field names to tensors; the output maps the same
    field names to tensors stacked along a new leading batch dimension.
    """
    texts, knowledges = zip(*batch)
    batched_txt = {key: torch.stack([example[key] for example in texts]) for key in texts[0]}
    batched_kg = {key: torch.stack([example[key] for example in knowledges]) for key in knowledges[0]}
    return batched_txt, batched_kg

def default_collator(batch):
    """Collate a list of tensor dicts into one dict of batch-stacked tensors."""
    return {key: torch.stack([example[key] for example in batch]) for key in batch[0]}

class LinkPredDataset(Dataset):
    """ATOMIC-2020 link-prediction dataset with synthetic 'None' negatives.

    Loads (head, rel, tail) triples from a TSV, then augments them with n
    randomly paired (head, tail) combinations that do not occur as real
    triples, labelled 'None', where n is the mean frequency of the real
    relations. `rel2ix` / `ix2rel` and `AutoTokenizer` are expected to be
    defined elsewhere in this module.
    """

    def _sample_negative_pairs(self, df, k, seed):
        """Sample k random (head, tail) pairs and keep those absent from df.

        NOTE(review): the boolean filters below are tautologies inherited from
        the original code — (x != 'none') | (x != 'NaN') is always True, so
        effectively all non-null values pass. Kept as-is to preserve behavior;
        confirm whether `&` was intended.
        """
        heads = df[(df['head'].str.contains('none')) | (df['head'] != 'NaN')]['head'].dropna()
        tails = df[(df['tail'] != 'none') | (df['tail'] != 'NaN')]['tail'].dropna()
        heads = heads.sample(n=k, random_state=seed, ignore_index=True, replace=True)
        tails = tails.sample(n=k, random_state=seed, ignore_index=True, replace=True)
        pairs = pd.concat([heads, tails], axis=1).rename(columns={0: 'head', 1: 'tail'})
        # Left-merge with indicator: 'left_only' rows are pairs that never
        # appear as a real triple, i.e. valid negatives.
        merged = pd.merge(pairs, df[['head', 'tail']], on=['head', 'tail'], indicator='i', how='left')
        # drop(columns=...) instead of the positional-axis form drop('i', 1),
        # which was removed in pandas 2.0.
        return merged.query('i == "left_only"').drop(columns='i').reset_index(drop=True)

    def __init__(self, split='train'):
        print(f'loading {split} data')
        df = pd.read_csv(f'./data/atomic2020/{split}.tsv', delimiter='\t', names=['head', 'rel', 'tail'])
        # Target number of synthetic negatives: mean count of the real relations.
        n = int(df['rel'].value_counts().mean())
        print(df['rel'].value_counts())
        print(n)
        no_rel = None
        diff = n
        i = 1
        # Resample (with a fresh seed each round) until n unique negatives exist.
        # Unlike the original, duplicates are dropped on every iteration — the
        # first pass previously skipped drop_duplicates, so duplicate negative
        # pairs could survive while still satisfying the size assertion.
        while diff > 0:
            sampled = self._sample_negative_pairs(df, diff, i)
            no_rel = sampled if no_rel is None else pd.concat([no_rel, sampled], ignore_index=True)
            no_rel['rel'] = 'None'
            no_rel = no_rel.drop_duplicates()
            diff = max(0, n - no_rel.shape[0])
            i += 1

        assert (no_rel.shape[0] == n)
        print(f'Adding {n} none relations to the existing {df.shape[0]} samples')
        df = pd.concat([df, no_rel], ignore_index=True)
        # Keep only relations the label vocabulary knows about.
        df = df[df['rel'].isin(ix2rel)]

        self.tokenizer = AutoTokenizer.from_pretrained('bert-base-uncased')
        self.args = list(zip(df['head'].astype(str).tolist(), df['tail'].astype(str).tolist()))
        self.labels = [rel2ix[rel] for rel in df['rel']]

    def __len__(self):
        """Number of (triple, label) examples."""
        return len(self.labels)

    def __getitem__(self, idx):
        """Tokenize the (head, tail) pair on the fly and return a tensor dict."""
        item = self.tokenizer(*self.args[idx], truncation=True, padding='max_length')
        item = {key: torch.tensor(val) for key, val in item.items()}
        item['labels'] = torch.tensor(self.labels[idx])
        return item