From b9baf1d29b96f123f3d88df666aef2ab2434ed68 Mon Sep 17 00:00:00 2001
From: Yifu Qiu <yifu@MacBook-Pro-557.local>
Date: Fri, 1 Mar 2024 11:11:14 +0000
Subject: [PATCH] add cw2

---
 2024/coursework/nlu-cw2/README.md       |  23 +++
 2024/coursework/nlu-cw2/example.sh      |  41 +++++
 2024/coursework/nlu-cw2/multi-bleu.perl | 162 ++++++++++++++++++
 2024/coursework/nlu-cw2/preprocess.py   | 108 ++++++++++++
 2024/coursework/nlu-cw2/train.py        | 218 ++++++++++++++++++++++++
 2024/coursework/nlu-cw2/translate.py    | 121 +++++++++++++
 6 files changed, 673 insertions(+)
 create mode 100644 2024/coursework/nlu-cw2/README.md
 create mode 100755 2024/coursework/nlu-cw2/example.sh
 create mode 100644 2024/coursework/nlu-cw2/multi-bleu.perl
 create mode 100644 2024/coursework/nlu-cw2/preprocess.py
 create mode 100755 2024/coursework/nlu-cw2/train.py
 create mode 100644 2024/coursework/nlu-cw2/translate.py

diff --git a/2024/coursework/nlu-cw2/README.md b/2024/coursework/nlu-cw2/README.md
new file mode 100644
index 0000000..179ccb8
--- /dev/null
+++ b/2024/coursework/nlu-cw2/README.md
@@ -0,0 +1,23 @@
+NLU+ Coursework 2 (2024 Spring)
+---
+##### Yifu Qiu, Shay Cohen and Alexandra Birch-Mayne
+
+###### Prior versions from Yao Fu, Frank Keller, Tom Sherborne, Mirella Lapata, Denis Emelin, Jon Mallinson, Ida Szubert and Rico Sennrich 
+
+---
+This repository contains the data and materials for the second coursework of NLU+ 2024 Spring.
+
+To get started, clone the repository:
+```bash
+git clone https://git.ecdf.ed.ac.uk/nlu_public/course_materials.git
+```
+
+Then follow the instructions in the Coursework 2 handout.
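+
+Once the data is prepared as described in the handout, the provided
+`example.sh` runs a full train/translate/evaluate cycle. It is a starting
+point to copy and adapt, and it assumes it is run from the coursework
+directory:
+```bash
+cd course_materials/2024/coursework/nlu-cw2
+./example.sh
+```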
+
+
+##### Acknowledgements
+This code is based on [Binh Tang's NMT Tutorial](https://github.com/tangbinh/machine-translation) and the data is a sample
+of [Europarl](http://www.statmt.org/europarl/).
+
+
diff --git a/2024/coursework/nlu-cw2/example.sh b/2024/coursework/nlu-cw2/example.sh
new file mode 100755
index 0000000..a84e70c
--- /dev/null
+++ b/2024/coursework/nlu-cw2/example.sh
@@ -0,0 +1,41 @@
+#!/usr/bin/env bash
+###########
+# USAGE NOTE:
+# The following script is designed to help you get started running models with the codebase.
+# At a minimum, all you need to do is give a name to your experiment using the ${EXP_NAME} variable.
+# Additional arguments can be added to the train.py command for different functionality.
+# We recommend copying this script and modifying it for each experiment you try (multi-layer, lexical, etc.).
+###########
+
+# Activate Conda environment [assuming your Miniconda installation is in your home directory]
+source ~/miniconda3/bin/activate nlu
+
+# Define a location for all your experiments to save
+ROOT=$(git rev-parse --show-toplevel)
+RESULTS_ROOT="${ROOT}/results"
+mkdir -p "${RESULTS_ROOT}"
+
+### NAME YOUR EXPERIMENT HERE ##
+EXP_NAME="baseline"
+################################
+
+## Local variables for current experiment
+EXP_ROOT="${RESULTS_ROOT}/${EXP_NAME}"
+DATA_DIR="${ROOT}/europarl_prepared"
+TEST_EN_GOLD="${ROOT}/europarl_raw/test.en"
+TEST_EN_PRED="${EXP_ROOT}/model_translations.txt"
+mkdir -p "${EXP_ROOT}"
+
+# Train model. Defaults are used for any argument not specified here. Use "\" to
+# continue the command over multiple lines. Note that a line ending in "\" must be
+# followed by more arguments, not a comment, so add any extra arguments before the
+# final line.
+python train.py --save-dir "${EXP_ROOT}" \
+                --log-file "${EXP_ROOT}/log.out" \
+                --data "${DATA_DIR}"
+                ### ADD ADDITIONAL ARGUMENTS ABOVE (continue preceding lines with "\") ###
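+# For example, a tiny-data run with a lower learning rate would extend the
+# command above with (both flags are defined in train.py's get_args):
+#   --train-on-tiny --lr 0.0001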
+
+## Prediction step
+python translate.py \
+    --checkpoint-path "${EXP_ROOT}/checkpoint_best.pt" \
+    --output "${TEST_EN_PRED}"
+
+## Calculate BLEU score for model outputs
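+# (-lc lowercases both hypothesis and reference, i.e. case-insensitive BLEU)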
+perl multi-bleu.perl -lc "${TEST_EN_GOLD}" < "${TEST_EN_PRED}" | tee "${EXP_ROOT}/bleu.txt"
diff --git a/2024/coursework/nlu-cw2/multi-bleu.perl b/2024/coursework/nlu-cw2/multi-bleu.perl
new file mode 100644
index 0000000..61de10d
--- /dev/null
+++ b/2024/coursework/nlu-cw2/multi-bleu.perl
@@ -0,0 +1,162 @@
+#!/usr/bin/env perl
+#
+# This file is part of moses.  Its use is licensed under the GNU Lesser General
+# Public License version 2.1 or, at your option, any later version.
+
+# $Id$
+use warnings;
+use strict;
+
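+# BLEU as computed below is the geometric mean of the modified 1- to 4-gram
+# precisions p_n, scaled by a brevity penalty BP:
+#   BLEU = BP * exp((log p1 + log p2 + log p3 + log p4) / 4)
+#   BP   = 1 if hyp_len >= ref_len, else exp(1 - ref_len / hyp_len)
+# where ref_len accumulates, per sentence, the reference length closest to the
+# hypothesis length (ties resolved towards the shorter reference).
+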
+my $lowercase = 0;
+if (scalar(@ARGV) && $ARGV[0] eq "-lc") {
+  $lowercase = 1;
+  shift;
+}
+
+my $stem = $ARGV[0];
+if (!defined $stem) {
+  print STDERR "usage: multi-bleu.pl [-lc] reference < hypothesis\n";
+  print STDERR "Reads the references from reference or reference0, reference1, ...\n";
+  exit(1);
+}
+
+$stem .= ".ref" if !-e $stem && !-e $stem."0" && -e $stem.".ref0";
+
+my @REF;
+my $ref=0;
+while(-e "$stem$ref") {
+    &add_to_ref("$stem$ref",\@REF);
+    $ref++;
+}
+&add_to_ref($stem,\@REF) if -e $stem;
+die("ERROR: could not find reference file $stem") unless scalar @REF;
+
+sub add_to_ref {
+    my ($file,$REF) = @_;
+    my $s=0;
+    open(REF,$file) or die "Can't read $file";
+    while(<REF>) {
+	chomp;
+	push @{$$REF[$s++]}, $_;
+    }
+    close(REF);
+}
+
+my(@CORRECT,@TOTAL,$length_translation,$length_reference);
+my $s=0;
+while(<STDIN>) {
+    chomp;
+    $_ = lc if $lowercase;
+    my @WORD = split;
+    my %REF_NGRAM = ();
+    my $length_translation_this_sentence = scalar(@WORD);
+    my ($closest_diff,$closest_length) = (9999,9999);
+    foreach my $reference (@{$REF[$s]}) {
+#      print "$s $_ <=> $reference\n";
+  $reference = lc($reference) if $lowercase;
+	my @WORD = split(' ',$reference);
+	my $length = scalar(@WORD);
+        my $diff = abs($length_translation_this_sentence-$length);
+	if ($diff < $closest_diff) {
+	    $closest_diff = $diff;
+	    $closest_length = $length;
+	    # print STDERR "$s: closest diff ".abs($length_translation_this_sentence-$length)." = abs($length_translation_this_sentence-$length), setting len: $closest_length\n";
+	} elsif ($diff == $closest_diff) {
+            $closest_length = $length if $length < $closest_length;
+            # from two references with the same closeness to me
+            # take the *shorter* into account, not the "first" one.
+        }
+	for(my $n=1;$n<=4;$n++) {
+	    my %REF_NGRAM_N = ();
+	    for(my $start=0;$start<=$#WORD-($n-1);$start++) {
+		my $ngram = "$n";
+		for(my $w=0;$w<$n;$w++) {
+		    $ngram .= " ".$WORD[$start+$w];
+		}
+		$REF_NGRAM_N{$ngram}++;
+	    }
+	    foreach my $ngram (keys %REF_NGRAM_N) {
+		if (!defined($REF_NGRAM{$ngram}) ||
+		    $REF_NGRAM{$ngram} < $REF_NGRAM_N{$ngram}) {
+		    $REF_NGRAM{$ngram} = $REF_NGRAM_N{$ngram};
+#	    print "$i: REF_NGRAM{$ngram} = $REF_NGRAM{$ngram}<BR>\n";
+		}
+	    }
+	}
+    }
+    $length_translation += $length_translation_this_sentence;
+    $length_reference += $closest_length;
+    for(my $n=1;$n<=4;$n++) {
+	my %T_NGRAM = ();
+	for(my $start=0;$start<=$#WORD-($n-1);$start++) {
+	    my $ngram = "$n";
+	    for(my $w=0;$w<$n;$w++) {
+		$ngram .= " ".$WORD[$start+$w];
+	    }
+	    $T_NGRAM{$ngram}++;
+	}
+	foreach my $ngram (keys %T_NGRAM) {
+	    $ngram =~ /^(\d+) /;
+	    my $n = $1;
+            # my $corr = 0;
+#	print "$i e $ngram $T_NGRAM{$ngram}<BR>\n";
+	    $TOTAL[$n] += $T_NGRAM{$ngram};
+	    if (defined($REF_NGRAM{$ngram})) {
+		if ($REF_NGRAM{$ngram} >= $T_NGRAM{$ngram}) {
+		    $CORRECT[$n] += $T_NGRAM{$ngram};
+                    # $corr =  $T_NGRAM{$ngram};
+#	    print "$i e correct1 $T_NGRAM{$ngram}<BR>\n";
+		}
+		else {
+		    $CORRECT[$n] += $REF_NGRAM{$ngram};
+                    # $corr =  $REF_NGRAM{$ngram};
+#	    print "$i e correct2 $REF_NGRAM{$ngram}<BR>\n";
+		}
+	    }
+            # $REF_NGRAM{$ngram} = 0 if !defined $REF_NGRAM{$ngram};
+            # print STDERR "$ngram: {$s, $REF_NGRAM{$ngram}, $T_NGRAM{$ngram}, $corr}\n"
+	}
+    }
+    $s++;
+}
+my $brevity_penalty = 1;
+my $bleu = 0;
+
+my @bleu=();
+
+for(my $n=1;$n<=4;$n++) {
+  if (defined ($TOTAL[$n])){
+    $bleu[$n]=($TOTAL[$n])?$CORRECT[$n]/$TOTAL[$n]:0;
+    # print STDERR "CORRECT[$n]:$CORRECT[$n] TOTAL[$n]:$TOTAL[$n]\n";
+  }else{
+    $bleu[$n]=0;
+  }
+}
+
+if ($length_reference==0){
+  printf "BLEU = 0, 0/0/0/0 (BP=0, ratio=0, hyp_len=0, ref_len=0)\n";
+  exit(1);
+}
+
+if ($length_translation<$length_reference) {
+  $brevity_penalty = exp(1-$length_reference/$length_translation);
+}
+$bleu = $brevity_penalty * exp((my_log( $bleu[1] ) +
+				my_log( $bleu[2] ) +
+				my_log( $bleu[3] ) +
+				my_log( $bleu[4] ) ) / 4) ;
+printf "BLEU = %.2f, %.1f/%.1f/%.1f/%.1f (BP=%.3f, ratio=%.3f, hyp_len=%d, ref_len=%d)\n",
+    100*$bleu,
+    100*$bleu[1],
+    100*$bleu[2],
+    100*$bleu[3],
+    100*$bleu[4],
+    $brevity_penalty,
+    $length_translation / $length_reference,
+    $length_translation,
+    $length_reference;
+
+sub my_log {
+  return -9999999999 unless $_[0];
+  return log($_[0]);
+}
diff --git a/2024/coursework/nlu-cw2/preprocess.py b/2024/coursework/nlu-cw2/preprocess.py
new file mode 100644
index 0000000..cd58ec1
--- /dev/null
+++ b/2024/coursework/nlu-cw2/preprocess.py
@@ -0,0 +1,108 @@
+import argparse
+import collections
+import logging
+import os
+import sys
+import re
+import pickle
+
+from seq2seq import utils
+from seq2seq.data.dictionary import Dictionary
+
+SPACE_NORMALIZER = re.compile(r"\s+")
+
+
+def word_tokenize(line):
+    line = SPACE_NORMALIZER.sub(" ", line)
+    line = line.strip()
+    return line.split()
+
+
+def get_args():
+    parser = argparse.ArgumentParser('Data pre-processing')
+    parser.add_argument('--source-lang', default=None, metavar='SRC', help='source language')
+    parser.add_argument('--target-lang', default=None, metavar='TGT', help='target language')
+
+    parser.add_argument('--train-prefix', default=None, metavar='FP', help='train file prefix')
+    parser.add_argument('--tiny-train-prefix', default=None, metavar='FP', help='tiny train file prefix')
+    parser.add_argument('--valid-prefix', default=None, metavar='FP', help='valid file prefix')
+    parser.add_argument('--test-prefix', default=None, metavar='FP', help='test file prefix')
+    parser.add_argument('--dest-dir', default='data-bin', metavar='DIR', help='destination dir')
+
+    parser.add_argument('--threshold-src', default=2, type=int,
+                        help='map words appearing less than threshold times to unknown')
+    parser.add_argument('--num-words-src', default=-1, type=int, help='number of source words to retain')
+    parser.add_argument('--threshold-tgt', default=2, type=int,
+                        help='map words appearing less than threshold times to unknown')
+    parser.add_argument('--num-words-tgt', default=-1, type=int, help='number of target words to retain')
+    return parser.parse_args()
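+
+
+# Example invocation (file prefixes are illustrative; point them at your data):
+#     python preprocess.py --source-lang de --target-lang en \
+#         --train-prefix europarl_raw/train --valid-prefix europarl_raw/valid \
+#         --test-prefix europarl_raw/test --dest-dir europarl_prepared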
+
+
+def main(args):
+    os.makedirs(args.dest_dir, exist_ok=True)
+    src_dict = build_dictionary([args.train_prefix + '.' + args.source_lang])
+    tgt_dict = build_dictionary([args.train_prefix + '.' + args.target_lang])
+
+    src_dict.finalize(threshold=args.threshold_src, num_words=args.num_words_src)
+    src_dict.save(os.path.join(args.dest_dir, 'dict.' + args.source_lang))
+    logging.info('Built a source dictionary ({}) with {} words'.format(args.source_lang, len(src_dict)))
+
+    tgt_dict.finalize(threshold=args.threshold_tgt, num_words=args.num_words_tgt)
+    tgt_dict.save(os.path.join(args.dest_dir, 'dict.' + args.target_lang))
+    logging.info('Built a target dictionary ({}) with {} words'.format(args.target_lang, len(tgt_dict)))
+
+    def make_split_datasets(lang, dictionary):
+        if args.train_prefix is not None:
+            make_binary_dataset(args.train_prefix + '.' + lang, os.path.join(args.dest_dir, 'train.' + lang),
+                                dictionary)
+        if args.tiny_train_prefix is not None:
+            make_binary_dataset(args.tiny_train_prefix + '.' + lang, os.path.join(args.dest_dir, 'tiny_train.' + lang),
+                                dictionary)
+        if args.valid_prefix is not None:
+            make_binary_dataset(args.valid_prefix + '.' + lang, os.path.join(args.dest_dir, 'valid.' + lang),
+                                dictionary)
+        if args.test_prefix is not None:
+            make_binary_dataset(args.test_prefix + '.' + lang, os.path.join(args.dest_dir, 'test.' + lang), dictionary)
+
+    make_split_datasets(args.source_lang, src_dict)
+    make_split_datasets(args.target_lang, tgt_dict)
+
+
+def build_dictionary(filenames, tokenize=word_tokenize):
+    dictionary = Dictionary()
+    for filename in filenames:
+        with open(filename, 'r') as file:
+            for line in file:
+                for symbol in tokenize(line.strip()):
+                    dictionary.add_word(symbol)
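+                # register one EOS symbol per line so end-of-sentence keeps a
+                # realistic count when the dictionary is finalized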
+                dictionary.add_word(dictionary.eos_word)
+    return dictionary
+
+
+def make_binary_dataset(input_file, output_file, dictionary, tokenize=word_tokenize, append_eos=True):
+    nsent, ntok = 0, 0
+    unk_counter = collections.Counter()
+
+    def unk_consumer(word, idx):
+        if idx == dictionary.unk_idx and word != dictionary.unk_word:
+            unk_counter.update([word])
+
+    tokens_list = []
+    with open(input_file, 'r') as inf:
+        for line in inf:
+            tokens = dictionary.binarize(line.strip(), tokenize, append_eos, consumer=unk_consumer)
+            nsent, ntok = nsent + 1, ntok + len(tokens)
+            tokens_list.append(tokens.numpy())
+
+    with open(output_file, 'wb') as outf:
+        pickle.dump(tokens_list, outf, protocol=pickle.HIGHEST_PROTOCOL)
+        logging.info('Built a binary dataset for {}: {} sentences, {} tokens, {:.3f}% replaced by unknown token {}'.format(
+            input_file, nsent, ntok, 100.0 * sum(unk_counter.values()) / ntok, dictionary.unk_word))
+
+
+if __name__ == '__main__':
+    args = get_args()
+    utils.init_logging(args)
+    logging.info('COMMAND: %s' % ' '.join(sys.argv))
+    logging.info('Arguments: {}'.format(vars(args)))
+    main(args)
diff --git a/2024/coursework/nlu-cw2/train.py b/2024/coursework/nlu-cw2/train.py
new file mode 100755
index 0000000..2bd30d0
--- /dev/null
+++ b/2024/coursework/nlu-cw2/train.py
@@ -0,0 +1,218 @@
+import os
+import logging
+import argparse
+import numpy as np
+from tqdm import tqdm
+from collections import OrderedDict
+
+import torch
+import torch.nn as nn
+
+from seq2seq import models, utils
+from seq2seq.data.dictionary import Dictionary
+from seq2seq.data.dataset import Seq2SeqDataset, BatchSampler
+from seq2seq.models import ARCH_MODEL_REGISTRY, ARCH_CONFIG_REGISTRY
+
+
+def get_args():
+    """ Defines training-specific hyper-parameters. """
+    parser = argparse.ArgumentParser('Sequence to Sequence Model')
+
+    # Add data arguments
+    parser.add_argument('--data', default='europarl_prepared', help='path to data directory')
+    parser.add_argument('--source-lang', default='de', help='source language')
+    parser.add_argument('--target-lang', default='en', help='target language')
+    parser.add_argument('--max-tokens', default=None, type=int, help='maximum number of tokens in a batch')
+    parser.add_argument('--batch-size', default=10, type=int, help='maximum number of sentences in a batch')
+    parser.add_argument('--train-on-tiny', action='store_true', help='train model on a tiny dataset')
+
+    # Add model arguments
+    parser.add_argument('--arch', default='lstm', choices=ARCH_MODEL_REGISTRY.keys(), help='model architecture')
+
+    # Add optimization arguments
+    parser.add_argument('--max-epoch', default=100, type=int, help='force stop training at specified epoch')
+    parser.add_argument('--clip-norm', default=4.0, type=float, help='clip threshold of gradients')
+    parser.add_argument('--lr', default=0.0003, type=float, help='learning rate')
+    parser.add_argument('--patience', default=10, type=int,
+                        help='number of epochs without improvement on validation set before early stopping')
+
+    # Add checkpoint arguments
+    parser.add_argument('--log-file', default=None, help='path to save logs')
+    parser.add_argument('--save-dir', default='checkpoints', help='path to save checkpoints')
+    parser.add_argument('--restore-file', default='checkpoint_last.pt', help='filename to load checkpoint')
+    parser.add_argument('--save-interval', type=int, default=1, help='save a checkpoint every N epochs')
+    parser.add_argument('--no-save', action='store_true', help='don\'t save models or checkpoints')
+    parser.add_argument('--epoch-checkpoints', action='store_true', help='store all epoch checkpoints')
+
+    # Parse twice as model arguments are not known the first time
+    args, _ = parser.parse_known_args()
+    model_parser = parser.add_argument_group(argument_default=argparse.SUPPRESS)
+    ARCH_MODEL_REGISTRY[args.arch].add_args(model_parser)
+    args = parser.parse_args()
+    ARCH_CONFIG_REGISTRY[args.arch](args)
+    return args
+
+
+def main(args):
+    """ Main training function. Trains the translation model over the course of several epochs, including dynamic
+    learning rate adjustment and gradient clipping. """
+    logging.info('Commencing training!')
+    torch.manual_seed(42)
+    np.random.seed(42)
+
+    utils.init_logging(args)
+
+    # Load dictionaries
+    src_dict = Dictionary.load(os.path.join(args.data, 'dict.{:s}'.format(args.source_lang)))
+    logging.info('Loaded a source dictionary ({:s}) with {:d} words'.format(args.source_lang, len(src_dict)))
+    tgt_dict = Dictionary.load(os.path.join(args.data, 'dict.{:s}'.format(args.target_lang)))
+    logging.info('Loaded a target dictionary ({:s}) with {:d} words'.format(args.target_lang, len(tgt_dict)))
+
+    # Load datasets
+    def load_data(split):
+        return Seq2SeqDataset(
+            src_file=os.path.join(args.data, '{:s}.{:s}'.format(split, args.source_lang)),
+            tgt_file=os.path.join(args.data, '{:s}.{:s}'.format(split, args.target_lang)),
+            src_dict=src_dict, tgt_dict=tgt_dict)
+
+    train_dataset = load_data(split='train') if not args.train_on_tiny else load_data(split='tiny_train')
+    valid_dataset = load_data(split='valid')
+
+    # Build model and optimization criterion
+    model = models.build_model(args, src_dict, tgt_dict)
+    logging.info('Built a model with {:d} parameters'.format(sum(p.numel() for p in model.parameters())))
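+    # NB: padding positions are masked with src_dict.pad_idx, which assumes the
+    # source and target dictionaries share the same pad index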
+    criterion = nn.CrossEntropyLoss(ignore_index=src_dict.pad_idx, reduction='sum')
+
+    # Instantiate optimizer and learning rate scheduler
+    optimizer = torch.optim.Adam(model.parameters(), args.lr)
+
+    # Load last checkpoint if one exists
+    state_dict = utils.load_checkpoint(args, model, optimizer)  # lr_scheduler
+    last_epoch = state_dict['last_epoch'] if state_dict is not None else -1
+
+    # Track validation performance for early stopping
+    bad_epochs = 0
+    best_validate = float('inf')
+
+    for epoch in range(last_epoch + 1, args.max_epoch):
+        train_loader = \
+            torch.utils.data.DataLoader(train_dataset, num_workers=1, collate_fn=train_dataset.collater,
+                                        batch_sampler=BatchSampler(train_dataset, args.max_tokens, args.batch_size, 1,
+                                                                   0, shuffle=True, seed=42))
+        model.train()
+        stats = OrderedDict()
+        stats['loss'] = 0
+        stats['lr'] = 0
+        stats['num_tokens'] = 0
+        stats['batch_size'] = 0
+        stats['grad_norm'] = 0
+        stats['clip'] = 0
+        # Display progress
+        progress_bar = tqdm(train_loader, desc='| Epoch {:03d}'.format(epoch), leave=False, disable=False)
+
+        # Iterate over the training set
+        for i, sample in enumerate(progress_bar):
+
+            if len(sample) == 0:
+                continue
+            model.train()
+
+            '''
+            ___QUESTION-1-DESCRIBE-E-START___
+            1.  Add a tensor shape annotation to each of the output tensors.
+            2.  Add a line-by-line description of what the following lines of code do.
+            '''
+            output, _ = model(sample['src_tokens'], sample['src_lengths'], sample['tgt_inputs'])
+
+            loss = \
+                criterion(output.view(-1, output.size(-1)), sample['tgt_tokens'].view(-1)) / len(sample['src_lengths'])
+            loss.backward()
+            grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), args.clip_norm)
+            optimizer.step()
+            optimizer.zero_grad()
+            '''___QUESTION-1-DESCRIBE-E-END___'''
+
+            # Update statistics for progress bar
+            total_loss, num_tokens, batch_size = loss.item(), sample['num_tokens'], len(sample['src_tokens'])
+            stats['loss'] += total_loss * len(sample['src_lengths']) / sample['num_tokens']
+            stats['lr'] += optimizer.param_groups[0]['lr']
+            stats['num_tokens'] += num_tokens / len(sample['src_tokens'])
+            stats['batch_size'] += batch_size
+            stats['grad_norm'] += grad_norm
+            stats['clip'] += 1 if grad_norm > args.clip_norm else 0
+            progress_bar.set_postfix({key: '{:.4g}'.format(value / (i + 1)) for key, value in stats.items()},
+                                     refresh=True)
+
+        logging.info('Epoch {:03d}: {}'.format(epoch, ' | '.join(key + ' {:.4g}'.format(
+            value / len(progress_bar)) for key, value in stats.items())))
+
+        # Calculate validation loss
+        valid_perplexity = validate(args, model, criterion, valid_dataset, epoch)
+        model.train()
+
+        # Save checkpoints
+        if epoch % args.save_interval == 0:
+            utils.save_checkpoint(args, model, optimizer, epoch, valid_perplexity)  # lr_scheduler
+
+        # Check whether to terminate training
+        if valid_perplexity < best_validate:
+            best_validate = valid_perplexity
+            bad_epochs = 0
+        else:
+            bad_epochs += 1
+        if bad_epochs >= args.patience:
+            logging.info('No validation set improvements observed for {:d} epochs. Early stop!'.format(args.patience))
+            break
+
+
+def validate(args, model, criterion, valid_dataset, epoch):
+    """ Validates model performance on a held-out development set. """
+    valid_loader = \
+        torch.utils.data.DataLoader(valid_dataset, num_workers=1, collate_fn=valid_dataset.collater,
+                                    batch_sampler=BatchSampler(valid_dataset, args.max_tokens, args.batch_size, 1, 0,
+                                                               shuffle=False, seed=42))
+    model.eval()
+    stats = OrderedDict()
+    stats['valid_loss'] = 0
+    stats['num_tokens'] = 0
+    stats['batch_size'] = 0
+
+    # Iterate over the validation set
+    for i, sample in enumerate(valid_loader):
+        if len(sample) == 0:
+            continue
+        with torch.no_grad():
+            # Compute loss
+            output, attn_scores = model(sample['src_tokens'], sample['src_lengths'], sample['tgt_inputs'])
+            loss = criterion(output.view(-1, output.size(-1)), sample['tgt_tokens'].view(-1))
+        # Update tracked statistics
+        stats['valid_loss'] += loss.item()
+        stats['num_tokens'] += sample['num_tokens']
+        stats['batch_size'] += len(sample['src_tokens'])
+
+    # Calculate validation perplexity
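+    # stats['valid_loss'] holds summed token-level cross-entropy, so dividing by
+    # the token count gives the average per-token NLL; perplexity is exp(NLL)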
+    stats['valid_loss'] = stats['valid_loss'] / stats['num_tokens']
+    perplexity = np.exp(stats['valid_loss'])
+    stats['num_tokens'] = stats['num_tokens'] / stats['batch_size']
+
+    logging.info(
+        'Epoch {:03d}: {}'.format(epoch, ' | '.join(key + ' {:.3g}'.format(value) for key, value in stats.items())) +
+        ' | valid_perplexity {:.3g}'.format(perplexity))
+
+    return perplexity
+
+
+if __name__ == '__main__':
+    args = get_args()
+    args.device_id = 0
+
+    # Set up logging to file
+    logging.basicConfig(filename=args.log_file, filemode='a', level=logging.INFO,
+                        format='%(levelname)s: %(message)s')
+    if args.log_file is not None:
+        # Logging to console
+        console = logging.StreamHandler()
+        console.setLevel(logging.INFO)
+        logging.getLogger('').addHandler(console)
+
+    main(args)
diff --git a/2024/coursework/nlu-cw2/translate.py b/2024/coursework/nlu-cw2/translate.py
new file mode 100644
index 0000000..26549bb
--- /dev/null
+++ b/2024/coursework/nlu-cw2/translate.py
@@ -0,0 +1,121 @@
+import os
+import logging
+import argparse
+import numpy as np
+from tqdm import tqdm
+
+import torch
+from torch.serialization import default_restore_location
+
+from seq2seq import models, utils
+from seq2seq.data.dictionary import Dictionary
+from seq2seq.data.dataset import Seq2SeqDataset, BatchSampler
+
+
+def get_args():
+    """ Defines generation-specific hyper-parameters. """
+    parser = argparse.ArgumentParser('Sequence to Sequence Model')
+    parser.add_argument('--cuda', action='store_true', help='use a GPU')
+    parser.add_argument('--seed', default=42, type=int, help='pseudo random number generator seed')
+
+    # Add data arguments
+    parser.add_argument('--data', default='data-bin', help='path to data directory')
+    parser.add_argument('--checkpoint-path', default='checkpoints/checkpoint_best.pt', help='path to the model file')
+    parser.add_argument('--batch-size', default=None, type=int, help='maximum number of sentences in a batch')
+    parser.add_argument('--output', default='model_translations.txt', type=str,
+                        help='path to the output file destination')
+    parser.add_argument('--max-len', default=25, type=int, help='maximum length of generated sequence')
+
+    return parser.parse_args()
+
+
+def main(args):
+    """ Main translation function' """
+    # Load arguments from checkpoint
+    torch.manual_seed(args.seed)
+    state_dict = torch.load(args.checkpoint_path, map_location=lambda s, l: default_restore_location(s, 'cpu'))
+    args = argparse.Namespace(**{**vars(args), **vars(state_dict['args'])})
+    utils.init_logging(args)
+
+    # Load dictionaries
+    src_dict = Dictionary.load(os.path.join(args.data, 'dict.{:s}'.format(args.source_lang)))
+    logging.info('Loaded a source dictionary ({:s}) with {:d} words'.format(args.source_lang, len(src_dict)))
+    tgt_dict = Dictionary.load(os.path.join(args.data, 'dict.{:s}'.format(args.target_lang)))
+    logging.info('Loaded a target dictionary ({:s}) with {:d} words'.format(args.target_lang, len(tgt_dict)))
+
+    # Load dataset
+    test_dataset = Seq2SeqDataset(
+        src_file=os.path.join(args.data, 'test.{:s}'.format(args.source_lang)),
+        tgt_file=os.path.join(args.data, 'test.{:s}'.format(args.target_lang)),
+        src_dict=src_dict, tgt_dict=tgt_dict)
+
+    test_loader = torch.utils.data.DataLoader(test_dataset, num_workers=1, collate_fn=test_dataset.collater,
+                                              batch_sampler=BatchSampler(test_dataset, 9999999,
+                                                                         args.batch_size, 1, 0, shuffle=False,
+                                                                         seed=args.seed))
+    # Build model and criterion
+    model = models.build_model(args, src_dict, tgt_dict)
+    if args.cuda:
+        model = model.cuda()
+    model.eval()
+    model.load_state_dict(state_dict['model'])
+    logging.info('Loaded a model from checkpoint {:s}'.format(args.checkpoint_path))
+    progress_bar = tqdm(test_loader, desc='| Generation', leave=False)
+
+    # Iterate over the test set
+    all_hyps = {}
+    for i, sample in enumerate(progress_bar):
+        with torch.no_grad():
+            # Compute the encoder output
+            encoder_out = model.encoder(sample['src_tokens'], sample['src_lengths'])
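+            # seed the decoder with the EOS index, which this codebase uses as
+            # the start-of-sequence symbol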
+            go_slice = \
+                torch.ones(sample['src_tokens'].shape[0], 1).fill_(tgt_dict.eos_idx).type_as(sample['src_tokens'])
+            prev_words = go_slice
+            next_words = None
+
+        for _ in range(args.max_len):
+            with torch.no_grad():
+                # Compute the decoder output by repeatedly feeding it the decoded sentence prefix
+                decoder_out, _ = model.decoder(prev_words, encoder_out)
+            # Suppress <UNK>s
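+            # (greedy decoding with a fallback: take the top-2 candidates at each
+            # position and emit the runner-up wherever the best token is <UNK>)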
+            _, next_candidates = torch.topk(decoder_out, 2, dim=-1)
+            best_candidates = next_candidates[:, :, 0]
+            backoff_candidates = next_candidates[:, :, 1]
+            next_words = torch.where(best_candidates == tgt_dict.unk_idx, backoff_candidates, best_candidates)
+            prev_words = torch.cat([go_slice, next_words], dim=1)
+
+        # Segment into sentences
+        decoded_batch = next_words.cpu().numpy()
+        output_sentences = [decoded_batch[row, :] for row in range(decoded_batch.shape[0])]
+        assert(len(output_sentences) == len(sample['id'].data))
+
+        # Truncate each sentence at its first EOS token (left empty if none was generated)
+        temp = list()
+        for sent in output_sentences:
+            first_eos = np.where(sent == tgt_dict.eos_idx)[0]
+            if len(first_eos) > 0:
+                temp.append(sent[:first_eos[0]])
+            else:
+                temp.append([])
+        output_sentences = temp
+
+        # Convert arrays of indices into strings of words
+        output_sentences = [tgt_dict.string(sent) for sent in output_sentences]
+
+        # Save translations
+        assert(len(output_sentences) == len(sample['id'].data))
+        for ii, sent in enumerate(output_sentences):
+            all_hyps[int(sample['id'].data[ii])] = sent
+
+    # Write to file
+    if args.output is not None:
+        with open(args.output, 'w') as out_file:
+            for sent_id in range(len(all_hyps.keys())):
+                out_file.write(all_hyps[sent_id] + '\n')
+
+        logging.info("Output {:d} translations to {:s}".format(len(all_hyps), args.output))
+
+
+if __name__ == '__main__':
+    args = get_args()
+    main(args)
-- 
GitLab