diff --git a/2024/coursework/nlu-cw2/README.md b/2024/coursework/nlu-cw2/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..179ccb82203e7f9415aaf71a44b59442223ff36f
--- /dev/null
+++ b/2024/coursework/nlu-cw2/README.md
@@ -0,0 +1,23 @@
+NLU+ Coursework 2 (2024 Spring)
+---
+##### Yifu Qiu, Shay Cohen and Alexandra Birch-Mayne
+
+###### Prior versions from Yao Fu, Frank Keller, Tom Sherborne, Mirella Lapata, Denis Emelin, Jon Mallinson, Ida Szubert and Rico Sennrich
+
+---
+This repository contains the data and materials for the second coursework of NLU+ 2024 Spring.
+
+To get started, clone the repository:
+```bash
+git clone https://git.ecdf.ed.ac.uk/nlu_public/course_materials.git
+```
+
+Then follow the instructions in the Coursework 2 handout.
+
+
+##### Acknowledgements
+This code is based on [Binh Tang's NMT Tutorial](https://github.com/tangbinh/machine-translation) and the data is a sample
+of [Europarl](http://www.statmt.org/europarl/).
+
+
+
diff --git a/2024/coursework/nlu-cw2/example.sh b/2024/coursework/nlu-cw2/example.sh
new file mode 100755
index 0000000000000000000000000000000000000000..a84e70c491830b696327f05668eb41ee895cce42
--- /dev/null
+++ b/2024/coursework/nlu-cw2/example.sh
@@ -0,0 +1,41 @@
+#!/usr/bin/env bash
+###########
+# USAGE NOTE:
+# The following script is designed to help you get started running models with the codebase.
+# At a minimum, all you need to do is give a name to your experiment using the ${EXP_NAME} variable.
+# Additional arguments can be added to the train.py command for different functionality.
+# We recommend copying this script and modifying it for each experiment you try (multi-layer, lexical, etc.)
+###########
+
+# Activate Conda Environment [assuming your Miniconda installation is in your root directory]
+source ~/miniconda3/bin/activate nlu
+
+# Define a location where all your experiments will be saved
+ROOT=$(git rev-parse --show-toplevel)
+RESULTS_ROOT="${ROOT}/results"
+mkdir -p ${RESULTS_ROOT}
+
+### NAME YOUR EXPERIMENT HERE ##
+EXP_NAME="baseline"
+################################
+
+## Local variables for current experiment
+EXP_ROOT="${RESULTS_ROOT}/${EXP_NAME}"
+DATA_DIR="${ROOT}/europarl_prepared"
+TEST_EN_GOLD="${ROOT}/europarl_raw/test.en"
+TEST_EN_PRED="${EXP_ROOT}/model_translations.txt"
+mkdir -p ${EXP_ROOT}
+
+# Train model. Defaults are used for any argument not specified here. Use "\" to add arguments over multiple lines.
+python train.py --save-dir "${EXP_ROOT}" \
+    --log-file "${EXP_ROOT}/log.out" \
+    --data "${DATA_DIR}" \
+    ### ADDITIONAL ARGUMENTS HERE ###
+
+## Prediction step
+python translate.py \
+    --checkpoint-path "${EXP_ROOT}/checkpoint_best.pt" \
+    --output "${TEST_EN_PRED}"
+
+## Calculate BLEU score for model outputs
+perl multi-bleu.perl -lc ${TEST_EN_GOLD} < ${TEST_EN_PRED} | tee "${EXP_ROOT}/bleu.txt"
diff --git a/2024/coursework/nlu-cw2/multi-bleu.perl b/2024/coursework/nlu-cw2/multi-bleu.perl
new file mode 100644
index 0000000000000000000000000000000000000000..61de10d4594268d657e4f08780794ff7f58122dc
--- /dev/null
+++ b/2024/coursework/nlu-cw2/multi-bleu.perl
@@ -0,0 +1,162 @@
+#!/usr/bin/env perl
+#
+# This file is part of moses. Its use is licensed under the GNU Lesser General
+# Public License version 2.1 or, at your option, any later version.
+ +# $Id$ +use warnings; +use strict; + +my $lowercase = 0; +if ($ARGV[0] eq "-lc") { + $lowercase = 1; + shift; +} + +my $stem = $ARGV[0]; +if (!defined $stem) { + print STDERR "usage: multi-bleu.pl [-lc] reference < hypothesis\n"; + print STDERR "Reads the references from reference or reference0, reference1, ...\n"; + exit(1); +} + +$stem .= ".ref" if !-e $stem && !-e $stem."0" && -e $stem.".ref0"; + +my @REF; +my $ref=0; +while(-e "$stem$ref") { + &add_to_ref("$stem$ref",\@REF); + $ref++; +} +&add_to_ref($stem,\@REF) if -e $stem; +die("ERROR: could not find reference file $stem") unless scalar @REF; + +sub add_to_ref { + my ($file,$REF) = @_; + my $s=0; + open(REF,$file) or die "Can't read $file"; + while(<REF>) { + chop; + push @{$$REF[$s++]}, $_; + } + close(REF); +} + +my(@CORRECT,@TOTAL,$length_translation,$length_reference); +my $s=0; +while(<STDIN>) { + chop; + $_ = lc if $lowercase; + my @WORD = split; + my %REF_NGRAM = (); + my $length_translation_this_sentence = scalar(@WORD); + my ($closest_diff,$closest_length) = (9999,9999); + foreach my $reference (@{$REF[$s]}) { +# print "$s $_ <=> $reference\n"; + $reference = lc($reference) if $lowercase; + my @WORD = split(' ',$reference); + my $length = scalar(@WORD); + my $diff = abs($length_translation_this_sentence-$length); + if ($diff < $closest_diff) { + $closest_diff = $diff; + $closest_length = $length; + # print STDERR "$s: closest diff ".abs($length_translation_this_sentence-$length)." = abs($length_translation_this_sentence-$length), setting len: $closest_length\n"; + } elsif ($diff == $closest_diff) { + $closest_length = $length if $length < $closest_length; + # from two references with the same closeness to me + # take the *shorter* into account, not the "first" one. + } + for(my $n=1;$n<=4;$n++) { + my %REF_NGRAM_N = (); + for(my $start=0;$start<=$#WORD-($n-1);$start++) { + my $ngram = "$n"; + for(my $w=0;$w<$n;$w++) { + $ngram .= " ".$WORD[$start+$w]; + } + $REF_NGRAM_N{$ngram}++; + } + foreach my $ngram (keys %REF_NGRAM_N) { + if (!defined($REF_NGRAM{$ngram}) || + $REF_NGRAM{$ngram} < $REF_NGRAM_N{$ngram}) { + $REF_NGRAM{$ngram} = $REF_NGRAM_N{$ngram}; +# print "$i: REF_NGRAM{$ngram} = $REF_NGRAM{$ngram}<BR>\n"; + } + } + } + } + $length_translation += $length_translation_this_sentence; + $length_reference += $closest_length; + for(my $n=1;$n<=4;$n++) { + my %T_NGRAM = (); + for(my $start=0;$start<=$#WORD-($n-1);$start++) { + my $ngram = "$n"; + for(my $w=0;$w<$n;$w++) { + $ngram .= " ".$WORD[$start+$w]; + } + $T_NGRAM{$ngram}++; + } + foreach my $ngram (keys %T_NGRAM) { + $ngram =~ /^(\d+) /; + my $n = $1; + # my $corr = 0; +# print "$i e $ngram $T_NGRAM{$ngram}<BR>\n"; + $TOTAL[$n] += $T_NGRAM{$ngram}; + if (defined($REF_NGRAM{$ngram})) { + if ($REF_NGRAM{$ngram} >= $T_NGRAM{$ngram}) { + $CORRECT[$n] += $T_NGRAM{$ngram}; + # $corr = $T_NGRAM{$ngram}; +# print "$i e correct1 $T_NGRAM{$ngram}<BR>\n"; + } + else { + $CORRECT[$n] += $REF_NGRAM{$ngram}; + # $corr = $REF_NGRAM{$ngram}; +# print "$i e correct2 $REF_NGRAM{$ngram}<BR>\n"; + } + } + # $REF_NGRAM{$ngram} = 0 if !defined $REF_NGRAM{$ngram}; + # print STDERR "$ngram: {$s, $REF_NGRAM{$ngram}, $T_NGRAM{$ngram}, $corr}\n" + } + } + $s++; +} +my $brevity_penalty = 1; +my $bleu = 0; + +my @bleu=(); + +for(my $n=1;$n<=4;$n++) { + if (defined ($TOTAL[$n])){ + $bleu[$n]=($TOTAL[$n])?$CORRECT[$n]/$TOTAL[$n]:0; + # print STDERR "CORRECT[$n]:$CORRECT[$n] TOTAL[$n]:$TOTAL[$n]\n"; + }else{ + $bleu[$n]=0; + } +} + +if ($length_reference==0){ + printf "BLEU = 0, 0/0/0/0 (BP=0, 
ratio=0, hyp_len=0, ref_len=0)\n"; + exit(1); +} + +if ($length_translation<$length_reference) { + $brevity_penalty = exp(1-$length_reference/$length_translation); +} +$bleu = $brevity_penalty * exp((my_log( $bleu[1] ) + + my_log( $bleu[2] ) + + my_log( $bleu[3] ) + + my_log( $bleu[4] ) ) / 4) ; +printf "BLEU = %.2f, %.1f/%.1f/%.1f/%.1f (BP=%.3f, ratio=%.3f, hyp_len=%d, ref_len=%d)\n", + 100*$bleu, + 100*$bleu[1], + 100*$bleu[2], + 100*$bleu[3], + 100*$bleu[4], + $brevity_penalty, + $length_translation / $length_reference, + $length_translation, + $length_reference; + +sub my_log { + return -9999999999 unless $_[0]; + return log($_[0]); +} diff --git a/2024/coursework/nlu-cw2/preprocess.py b/2024/coursework/nlu-cw2/preprocess.py new file mode 100644 index 0000000000000000000000000000000000000000..cd58ec1fa350b694da4a8ba1a447593eb73ba751 --- /dev/null +++ b/2024/coursework/nlu-cw2/preprocess.py @@ -0,0 +1,108 @@ +import argparse +import collections +import logging +import os +import sys +import re +import pickle + +from seq2seq import utils +from seq2seq.data.dictionary import Dictionary + +SPACE_NORMALIZER = re.compile("\s+") + + +def word_tokenize(line): + line = SPACE_NORMALIZER.sub(" ", line) + line = line.strip() + return line.split() + + +def get_args(): + parser = argparse.ArgumentParser('Data pre-processing)') + parser.add_argument('--source-lang', default=None, metavar='SRC', help='source language') + parser.add_argument('--target-lang', default=None, metavar='TGT', help='target language') + + parser.add_argument('--train-prefix', default=None, metavar='FP', help='train file prefix') + parser.add_argument('--tiny-train-prefix', default=None, metavar='FP', help='tiny train file prefix') + parser.add_argument('--valid-prefix', default=None, metavar='FP', help='valid file prefix') + parser.add_argument('--test-prefix', default=None, metavar='FP', help='test file prefix') + parser.add_argument('--dest-dir', default='data-bin', metavar='DIR', help='destination dir') + + parser.add_argument('--threshold-src', default=2, type=int, + help='map words appearing less than threshold times to unknown') + parser.add_argument('--num-words-src', default=-1, type=int, help='number of source words to retain') + parser.add_argument('--threshold-tgt', default=2, type=int, + help='map words appearing less than threshold times to unknown') + parser.add_argument('--num-words-tgt', default=-1, type=int, help='number of target words to retain') + return parser.parse_args() + + +def main(args): + os.makedirs(args.dest_dir, exist_ok=True) + src_dict = build_dictionary([args.train_prefix + '.' + args.source_lang]) + tgt_dict = build_dictionary([args.train_prefix + '.' + args.target_lang]) + + src_dict.finalize(threshold=args.threshold_src, num_words=args.num_words_src) + src_dict.save(os.path.join(args.dest_dir, 'dict.' + args.source_lang)) + logging.info('Built a source dictionary ({}) with {} words'.format(args.source_lang, len(src_dict))) + + tgt_dict.finalize(threshold=args.threshold_tgt, num_words=args.num_words_tgt) + tgt_dict.save(os.path.join(args.dest_dir, 'dict.' + args.target_lang)) + logging.info('Built a target dictionary ({}) with {} words'.format(args.target_lang, len(tgt_dict))) + + def make_split_datasets(lang, dictionary): + if args.train_prefix is not None: + make_binary_dataset(args.train_prefix + '.' + lang, os.path.join(args.dest_dir, 'train.' + lang), + dictionary) + if args.tiny_train_prefix is not None: + make_binary_dataset(args.tiny_train_prefix + '.' 
+ lang, os.path.join(args.dest_dir, 'tiny_train.' + lang), + dictionary) + if args.valid_prefix is not None: + make_binary_dataset(args.valid_prefix + '.' + lang, os.path.join(args.dest_dir, 'valid.' + lang), + dictionary) + if args.test_prefix is not None: + make_binary_dataset(args.test_prefix + '.' + lang, os.path.join(args.dest_dir, 'test.' + lang), dictionary) + + make_split_datasets(args.source_lang, src_dict) + make_split_datasets(args.target_lang, tgt_dict) + + +def build_dictionary(filenames, tokenize=word_tokenize): + dictionary = Dictionary() + for filename in filenames: + with open(filename, 'r') as file: + for line in file: + for symbol in word_tokenize(line.strip()): + dictionary.add_word(symbol) + dictionary.add_word(dictionary.eos_word) + return dictionary + + +def make_binary_dataset(input_file, output_file, dictionary, tokenize=word_tokenize, append_eos=True): + nsent, ntok = 0, 0 + unk_counter = collections.Counter() + + def unk_consumer(word, idx): + if idx == dictionary.unk_idx and word != dictionary.unk_word: + unk_counter.update([word]) + + tokens_list = [] + with open(input_file, 'r') as inf: + for line in inf: + tokens = dictionary.binarize(line.strip(), word_tokenize, append_eos, consumer=unk_consumer) + nsent, ntok = nsent + 1, ntok + len(tokens) + tokens_list.append(tokens.numpy()) + + with open(output_file, 'wb') as outf: + pickle.dump(tokens_list, outf, protocol=pickle.HIGHEST_PROTOCOL) + logging.info('Built a binary dataset for {}: {} sentences, {} tokens, {:.3f}% replaced by unknown token'.format( + input_file, nsent, ntok, 100.0 * sum(unk_counter.values()) / ntok, dictionary.unk_word)) + + +if __name__ == '__main__': + args = get_args() + utils.init_logging(args) + logging.info('COMMAND: %s' % ' '.join(sys.argv)) + logging.info('Arguments: {}'.format(vars(args))) + main(args) diff --git a/2024/coursework/nlu-cw2/train.py b/2024/coursework/nlu-cw2/train.py new file mode 100755 index 0000000000000000000000000000000000000000..2bd30d0a9626e7f7e78c2534fdadd16f8fb27569 --- /dev/null +++ b/2024/coursework/nlu-cw2/train.py @@ -0,0 +1,218 @@ +import os +import logging +import argparse +import numpy as np +from tqdm import tqdm +from collections import OrderedDict + +import torch +import torch.nn as nn + +from seq2seq import models, utils +from seq2seq.data.dictionary import Dictionary +from seq2seq.data.dataset import Seq2SeqDataset, BatchSampler +from seq2seq.models import ARCH_MODEL_REGISTRY, ARCH_CONFIG_REGISTRY + + +def get_args(): + """ Defines training-specific hyper-parameters. 
""" + parser = argparse.ArgumentParser('Sequence to Sequence Model') + + # Add data arguments + parser.add_argument('--data', default='europarl_prepared', help='path to data directory') + parser.add_argument('--source-lang', default='de', help='source language') + parser.add_argument('--target-lang', default='en', help='target language') + parser.add_argument('--max-tokens', default=None, type=int, help='maximum number of tokens in a batch') + parser.add_argument('--batch-size', default=10, type=int, help='maximum number of sentences in a batch') + parser.add_argument('--train-on-tiny', action='store_true', help='train model on a tiny dataset') + + # Add model arguments + parser.add_argument('--arch', default='lstm', choices=ARCH_MODEL_REGISTRY.keys(), help='model architecture') + + # Add optimization arguments + parser.add_argument('--max-epoch', default=100, type=int, help='force stop training at specified epoch') + parser.add_argument('--clip-norm', default=4.0, type=float, help='clip threshold of gradients') + parser.add_argument('--lr', default=0.0003, type=float, help='learning rate') + parser.add_argument('--patience', default=10, type=int, + help='number of epochs without improvement on validation set before early stopping') + + # Add checkpoint arguments + parser.add_argument('--log-file', default=None, help='path to save logs') + parser.add_argument('--save-dir', default='checkpoints', help='path to save checkpoints') + parser.add_argument('--restore-file', default='checkpoint_last.pt', help='filename to load checkpoint') + parser.add_argument('--save-interval', type=int, default=1, help='save a checkpoint every N epochs') + parser.add_argument('--no-save', action='store_true', help='don\'t save models or checkpoints') + parser.add_argument('--epoch-checkpoints', action='store_true', help='store all epoch checkpoints') + + # Parse twice as model arguments are not known the first time + args, _ = parser.parse_known_args() + model_parser = parser.add_argument_group(argument_default=argparse.SUPPRESS) + ARCH_MODEL_REGISTRY[args.arch].add_args(model_parser) + args = parser.parse_args() + ARCH_CONFIG_REGISTRY[args.arch](args) + return args + + +def main(args): + """ Main training function. Trains the translation model over the course of several epochs, including dynamic + learning rate adjustment and gradient clipping. 
""" + logging.info('Commencing training!') + torch.manual_seed(42) + np.random.seed(42) + + utils.init_logging(args) + + # Load dictionaries + src_dict = Dictionary.load(os.path.join(args.data, 'dict.{:s}'.format(args.source_lang))) + logging.info('Loaded a source dictionary ({:s}) with {:d} words'.format(args.source_lang, len(src_dict))) + tgt_dict = Dictionary.load(os.path.join(args.data, 'dict.{:s}'.format(args.target_lang))) + logging.info('Loaded a target dictionary ({:s}) with {:d} words'.format(args.target_lang, len(tgt_dict))) + + # Load datasets + def load_data(split): + return Seq2SeqDataset( + src_file=os.path.join(args.data, '{:s}.{:s}'.format(split, args.source_lang)), + tgt_file=os.path.join(args.data, '{:s}.{:s}'.format(split, args.target_lang)), + src_dict=src_dict, tgt_dict=tgt_dict) + + train_dataset = load_data(split='train') if not args.train_on_tiny else load_data(split='tiny_train') + valid_dataset = load_data(split='valid') + + # Build model and optimization criterion + model = models.build_model(args, src_dict, tgt_dict) + logging.info('Built a model with {:d} parameters'.format(sum(p.numel() for p in model.parameters()))) + criterion = nn.CrossEntropyLoss(ignore_index=src_dict.pad_idx, reduction='sum') + + # Instantiate optimizer and learning rate scheduler + optimizer = torch.optim.Adam(model.parameters(), args.lr) + + # Load last checkpoint if one exists + state_dict = utils.load_checkpoint(args, model, optimizer) # lr_scheduler + last_epoch = state_dict['last_epoch'] if state_dict is not None else -1 + + # Track validation performance for early stopping + bad_epochs = 0 + best_validate = float('inf') + + for epoch in range(last_epoch + 1, args.max_epoch): + train_loader = \ + torch.utils.data.DataLoader(train_dataset, num_workers=1, collate_fn=train_dataset.collater, + batch_sampler=BatchSampler(train_dataset, args.max_tokens, args.batch_size, 1, + 0, shuffle=True, seed=42)) + model.train() + stats = OrderedDict() + stats['loss'] = 0 + stats['lr'] = 0 + stats['num_tokens'] = 0 + stats['batch_size'] = 0 + stats['grad_norm'] = 0 + stats['clip'] = 0 + # Display progress + progress_bar = tqdm(train_loader, desc='| Epoch {:03d}'.format(epoch), leave=False, disable=False) + + # Iterate over the training set + for i, sample in enumerate(progress_bar): + + if len(sample) == 0: + continue + model.train() + + ''' + ___QUESTION-1-DESCRIBE-E-START___ + 1. Add tensor shape annotation to each of the output tensor + 2. Add line-by-line description about the following lines of code do. 
+ ''' + output, _ = model(sample['src_tokens'], sample['src_lengths'], sample['tgt_inputs']) + + loss = \ + criterion(output.view(-1, output.size(-1)), sample['tgt_tokens'].view(-1)) / len(sample['src_lengths']) + loss.backward() + grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), args.clip_norm) + optimizer.step() + optimizer.zero_grad() + '''___QUESTION-1-DESCRIBE-E-END___''' + + # Update statistics for progress bar + total_loss, num_tokens, batch_size = loss.item(), sample['num_tokens'], len(sample['src_tokens']) + stats['loss'] += total_loss * len(sample['src_lengths']) / sample['num_tokens'] + stats['lr'] += optimizer.param_groups[0]['lr'] + stats['num_tokens'] += num_tokens / len(sample['src_tokens']) + stats['batch_size'] += batch_size + stats['grad_norm'] += grad_norm + stats['clip'] += 1 if grad_norm > args.clip_norm else 0 + progress_bar.set_postfix({key: '{:.4g}'.format(value / (i + 1)) for key, value in stats.items()}, + refresh=True) + + logging.info('Epoch {:03d}: {}'.format(epoch, ' | '.join(key + ' {:.4g}'.format( + value / len(progress_bar)) for key, value in stats.items()))) + + # Calculate validation loss + valid_perplexity = validate(args, model, criterion, valid_dataset, epoch) + model.train() + + # Save checkpoints + if epoch % args.save_interval == 0: + utils.save_checkpoint(args, model, optimizer, epoch, valid_perplexity) # lr_scheduler + + # Check whether to terminate training + if valid_perplexity < best_validate: + best_validate = valid_perplexity + bad_epochs = 0 + else: + bad_epochs += 1 + if bad_epochs >= args.patience: + logging.info('No validation set improvements observed for {:d} epochs. Early stop!'.format(args.patience)) + break + + +def validate(args, model, criterion, valid_dataset, epoch): + """ Validates model performance on a held-out development set. 
""" + valid_loader = \ + torch.utils.data.DataLoader(valid_dataset, num_workers=1, collate_fn=valid_dataset.collater, + batch_sampler=BatchSampler(valid_dataset, args.max_tokens, args.batch_size, 1, 0, + shuffle=False, seed=42)) + model.eval() + stats = OrderedDict() + stats['valid_loss'] = 0 + stats['num_tokens'] = 0 + stats['batch_size'] = 0 + + # Iterate over the validation set + for i, sample in enumerate(valid_loader): + if len(sample) == 0: + continue + with torch.no_grad(): + # Compute loss + output, attn_scores = model(sample['src_tokens'], sample['src_lengths'], sample['tgt_inputs']) + loss = criterion(output.view(-1, output.size(-1)), sample['tgt_tokens'].view(-1)) + # Update tracked statistics + stats['valid_loss'] += loss.item() + stats['num_tokens'] += sample['num_tokens'] + stats['batch_size'] += len(sample['src_tokens']) + + # Calculate validation perplexity + stats['valid_loss'] = stats['valid_loss'] / stats['num_tokens'] + perplexity = np.exp(stats['valid_loss']) + stats['num_tokens'] = stats['num_tokens'] / stats['batch_size'] + + logging.info( + 'Epoch {:03d}: {}'.format(epoch, ' | '.join(key + ' {:.3g}'.format(value) for key, value in stats.items())) + + ' | valid_perplexity {:.3g}'.format(perplexity)) + + return perplexity + + +if __name__ == '__main__': + args = get_args() + args.device_id = 0 + + # Set up logging to file + logging.basicConfig(filename=args.log_file, filemode='a', level=logging.INFO, + format='%(levelname)s: %(message)s') + if args.log_file is not None: + # Logging to console + console = logging.StreamHandler() + console.setLevel(logging.INFO) + logging.getLogger('').addHandler(console) + + main(args) diff --git a/2024/coursework/nlu-cw2/translate.py b/2024/coursework/nlu-cw2/translate.py new file mode 100644 index 0000000000000000000000000000000000000000..26549bb1c9e7f0703137ee0f3352b45e1d200073 --- /dev/null +++ b/2024/coursework/nlu-cw2/translate.py @@ -0,0 +1,121 @@ +import os +import logging +import argparse +import numpy as np +from tqdm import tqdm + +import torch +from torch.serialization import default_restore_location + +from seq2seq import models, utils +from seq2seq.data.dictionary import Dictionary +from seq2seq.data.dataset import Seq2SeqDataset, BatchSampler + + +def get_args(): + """ Defines generation-specific hyper-parameters. 
""" + parser = argparse.ArgumentParser('Sequence to Sequence Model') + parser.add_argument('--cuda', default=False, help='Use a GPU') + parser.add_argument('--seed', default=42, type=int, help='pseudo random number generator seed') + + # Add data arguments + parser.add_argument('--data', default='data-bin', help='path to data directory') + parser.add_argument('--checkpoint-path', default='checkpoints/checkpoint_best.pt', help='path to the model file') + parser.add_argument('--batch-size', default=None, type=int, help='maximum number of sentences in a batch') + parser.add_argument('--output', default='model_translations.txt', type=str, + help='path to the output file destination') + parser.add_argument('--max-len', default=25, type=int, help='maximum length of generated sequence') + + return parser.parse_args() + + +def main(args): + """ Main translation function' """ + # Load arguments from checkpoint + torch.manual_seed(args.seed) + state_dict = torch.load(args.checkpoint_path, map_location=lambda s, l: default_restore_location(s, 'cpu')) + args = argparse.Namespace(**{**vars(args), **vars(state_dict['args'])}) + utils.init_logging(args) + + # Load dictionaries + src_dict = Dictionary.load(os.path.join(args.data, 'dict.{:s}'.format(args.source_lang))) + logging.info('Loaded a source dictionary ({:s}) with {:d} words'.format(args.source_lang, len(src_dict))) + tgt_dict = Dictionary.load(os.path.join(args.data, 'dict.{:s}'.format(args.target_lang))) + logging.info('Loaded a target dictionary ({:s}) with {:d} words'.format(args.target_lang, len(tgt_dict))) + + # Load dataset + test_dataset = Seq2SeqDataset( + src_file=os.path.join(args.data, 'test.{:s}'.format(args.source_lang)), + tgt_file=os.path.join(args.data, 'test.{:s}'.format(args.target_lang)), + src_dict=src_dict, tgt_dict=tgt_dict) + + test_loader = torch.utils.data.DataLoader(test_dataset, num_workers=1, collate_fn=test_dataset.collater, + batch_sampler=BatchSampler(test_dataset, 9999999, + args.batch_size, 1, 0, shuffle=False, + seed=args.seed)) + # Build model and criterion + model = models.build_model(args, src_dict, tgt_dict) + if args.cuda: + model = model.cuda() + model.eval() + model.load_state_dict(state_dict['model']) + logging.info('Loaded a model from checkpoint {:s}'.format(args.checkpoint_path)) + progress_bar = tqdm(test_loader, desc='| Generation', leave=False) + + # Iterate over the test set + all_hyps = {} + for i, sample in enumerate(progress_bar): + with torch.no_grad(): + # Compute the encoder output + encoder_out = model.encoder(sample['src_tokens'], sample['src_lengths']) + go_slice = \ + torch.ones(sample['src_tokens'].shape[0], 1).fill_(tgt_dict.eos_idx).type_as(sample['src_tokens']) + prev_words = go_slice + next_words = None + + for _ in range(args.max_len): + with torch.no_grad(): + # Compute the decoder output by repeatedly feeding it the decoded sentence prefix + decoder_out, _ = model.decoder(prev_words, encoder_out) + # Suppress <UNK>s + _, next_candidates = torch.topk(decoder_out, 2, dim=-1) + best_candidates = next_candidates[:, :, 0] + backoff_candidates = next_candidates[:, :, 1] + next_words = torch.where(best_candidates == tgt_dict.unk_idx, backoff_candidates, best_candidates) + prev_words = torch.cat([go_slice, next_words], dim=1) + + # Segment into sentences + decoded_batch = next_words.numpy() + output_sentences = [decoded_batch[row, :] for row in range(decoded_batch.shape[0])] + assert(len(output_sentences) == len(sample['id'].data)) + + # Remove padding + temp = list() + for sent in 
output_sentences: + first_eos = np.where(sent == tgt_dict.eos_idx)[0] + if len(first_eos) > 0: + temp.append(sent[:first_eos[0]]) + else: + temp.append([]) + output_sentences = temp + + # Convert arrays of indices into strings of words + output_sentences = [tgt_dict.string(sent) for sent in output_sentences] + + # Save translations + assert(len(output_sentences) == len(sample['id'].data)) + for ii, sent in enumerate(output_sentences): + all_hyps[int(sample['id'].data[ii])] = sent + + # Write to file + if args.output is not None: + with open(args.output, 'w') as out_file: + for sent_id in range(len(all_hyps.keys())): + out_file.write(all_hyps[sent_id] + '\n') + + logging.info("Output {:d} translations to {:s}".format(len(all_hyps), args.output)) + + +if __name__ == '__main__': + args = get_args() + main(args)
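
For readers who find the Perl above hard to follow, here is a minimal Python sketch of the corpus-level BLEU that `multi-bleu.perl` reports: clipped ("modified") 1–4-gram precisions, a brevity penalty based on reference length, and a geometric mean of the four precisions. It is deliberately simplified to a single reference per hypothesis and plain whitespace tokenisation, and the helper names (`ngram_counts`, `corpus_bleu`) are illustrative only, not part of the coursework codebase; coursework scores should still be produced with the Perl script as invoked in `example.sh`.

```python
import math
from collections import Counter


def ngram_counts(tokens, max_n=4):
    """Count all n-grams of order 1..max_n in a token list."""
    counts = {}
    for n in range(1, max_n + 1):
        counts[n] = Counter(tuple(tokens[i:i + n]) for i in range(len(tokens) - n + 1))
    return counts


def corpus_bleu(hypotheses, references, max_n=4):
    """Corpus-level BLEU with a single reference per hypothesis (illustrative sketch)."""
    correct = [0] * (max_n + 1)   # clipped n-gram matches, indexed by order
    total = [0] * (max_n + 1)     # n-grams proposed by the hypotheses, indexed by order
    hyp_len = ref_len = 0
    for hyp, ref in zip(hypotheses, references):
        hyp_tokens, ref_tokens = hyp.split(), ref.split()
        hyp_len += len(hyp_tokens)
        ref_len += len(ref_tokens)  # with one reference, this is also the "closest" reference length
        hyp_ngrams = ngram_counts(hyp_tokens, max_n)
        ref_ngrams = ngram_counts(ref_tokens, max_n)
        for n in range(1, max_n + 1):
            for ngram, count in hyp_ngrams[n].items():
                total[n] += count
                correct[n] += min(count, ref_ngrams[n].get(ngram, 0))  # clip by the reference count
    precisions = [correct[n] / total[n] if total[n] else 0.0 for n in range(1, max_n + 1)]
    # Brevity penalty: only hypotheses shorter than the references are penalised
    bp = 1.0 if hyp_len >= ref_len else math.exp(1.0 - ref_len / hyp_len)
    # Geometric mean of the four precisions; any zero precision drives BLEU to zero,
    # mirroring the large negative constant returned by my_log() in the Perl script
    bleu = bp * math.exp(sum(math.log(p) for p in precisions) / max_n) if min(precisions) > 0 else 0.0
    return bleu, precisions, bp, hyp_len, ref_len


if __name__ == '__main__':
    hyps = ['the cat is on the mat today']
    refs = ['the cat is on the mat']
    bleu, precisions, bp, hyp_len, ref_len = corpus_bleu(hyps, refs)
    print('BLEU = {:.2f}, {} (BP={:.3f}, ratio={:.3f}, hyp_len={:d}, ref_len={:d})'.format(
        100 * bleu, '/'.join('{:.1f}'.format(100 * p) for p in precisions),
        bp, hyp_len / ref_len, hyp_len, ref_len))
```

Because tokenisation, lowercasing (`-lc`) and multi-reference handling differ, this sketch will not necessarily reproduce the Perl script's numbers exactly; it is only meant to make the quantities in its output line (the four precisions, BP, ratio, hyp_len, ref_len) concrete.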