
Italian-English Neural Machine Translation
This notebook adapts the original recurrent model into a classic neural machine translation pipeline based on:
- trainable token embeddings learned end-to-end
- a bidirectional LSTM encoder
- an autoregressive LSTM decoder
- additive cross-attention over encoder states
- configurable teacher forcing during training
- random train/validation/test splits
- BLEU evaluation for validation and final inference
Embedding choice
For this application the cleanest choice is a trainable embedding layer (`nn.Embedding`). It fits an LSTM encoder-decoder naturally and keeps the role of cross-attention visible. BERT would add a contextual pretrained encoder that hides the mechanics of the seq2seq model, while Word2Vec vectors could only serve as a static initialization and are usually not worth the extra complexity here.
from pathlib import Path
from urllib.request import Request, urlopen
import json
import shutil
import zipfile
# Manifest URLs to fetch. Each manifest is a JSON document; the download loop reads
# its "expected_paths", "archive_url", "extract_to" and "name" keys.
RESOURCE_MANIFESTS = [
    "https://assets.deeplearningnotes.com/code-support-resources/datasets/machine-translation/latest.json",
]
def download_file(url, path, timeout=60.0):
    """Download *url* to *path*.

    The payload is streamed into a sibling ``<name>.part`` file and only renamed
    onto *path* once the transfer has finished. A failed or interrupted download
    therefore never leaves a truncated file at *path*.

    Args:
        url: any URL understood by ``urllib.request.urlopen``.
        path: destination file (str or Path).
        timeout: socket timeout in seconds. Without one, a stalled server could
            block the notebook indefinitely.

    Raises:
        urllib.error.URLError / OSError: if the request or the write fails.
    """
    path = Path(path)
    partial_path = path.with_name(path.name + ".part")
    request = Request(url, headers={"User-Agent": "Mozilla/5.0"})
    try:
        with urlopen(request, timeout=timeout) as response, open(partial_path, "wb") as file:
            shutil.copyfileobj(response, file)
    except BaseException:
        # Remove the incomplete temporary file, then propagate the original error.
        partial_path.unlink(missing_ok=True)
        raise
    partial_path.replace(path)
def extract_zip_safely(zip_path, target_dir="."):
    """Extract *zip_path* into *target_dir*, refusing archives with escaping entries.

    Every member path is resolved and must be *target_dir* itself or lie inside it.
    The check runs before anything is written, so a malicious archive extracts nothing.

    Raises:
        RuntimeError: if any member would be written outside *target_dir*.
    """
    target_dir = Path(target_dir).resolve()
    with zipfile.ZipFile(zip_path, "r") as zip_ref:
        for member in zip_ref.infolist():
            target_path = (target_dir / member.filename).resolve()
            # Compare path components, not string prefixes: a prefix test accepts
            # siblings such as "/data/out-evil" for the target "/data/out".
            if target_path != target_dir and target_dir not in target_path.parents:
                raise RuntimeError(f"Unsafe zip path: {member.filename}")
        zip_ref.extractall(target_dir)
# Fetch every resource manifest and make sure the files it promises exist locally.
for manifest_url in RESOURCE_MANIFESTS:
    # Resource name = path segment before "latest.json" (e.g. "machine-translation").
    name = manifest_url.rstrip("/").split("/")[-2]
    manifest_path = Path(f"{name}-latest.json")
    archive_path = Path(f"{name}.zip")
    download_file(manifest_url, manifest_path)
    manifest = json.loads(manifest_path.read_text(encoding="utf-8"))
    expected_paths = [Path(path) for path in manifest.get("expected_paths", [])]
    # all([]) is True, so without this guard a manifest that lists no expected
    # paths would never trigger the archive download.
    if not expected_paths or not all(path.exists() for path in expected_paths):
        download_file(manifest["archive_url"], archive_path)
        extract_zip_safely(archive_path, manifest.get("extract_to", "."))
    missing = [str(path) for path in expected_paths if not path.exists()]
    if missing:
        raise FileNotFoundError(f"Missing expected paths: {missing}")
    # Fall back to the name derived from the URL if the manifest omits "name".
    print(f"{manifest.get('name', name)} ready.")

from pathlib import Path
from dataclasses import dataclass, asdict
from collections import Counter
from typing import Dict, List, Optional, Sequence, Tuple
import copy
import math
import random
import re
import time
import unicodedata
import torch
import torch.nn as nn
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence, pad_sequence
from torch.utils.data import DataLoader, Dataset
# matplotlib is optional: plotting helpers check HAS_MATPLOTLIB before drawing.
try:
    import matplotlib.pyplot as plt
except Exception:
    HAS_MATPLOTLIB = False
else:
    HAS_MATPLOTLIB = True
def set_seed(seed: int) -> None:
    """Seed Python's ``random`` module and PyTorch on the CPU and on every CUDA device."""
    for seeder in (random.seed, torch.manual_seed, torch.cuda.manual_seed_all):
        seeder(seed)


set_seed(42)
print(f"PyTorch version: {torch.__version__}")
# Configuration
The dataset path is relative, so it is resolved against the current working directory of the notebook kernel. That is also where the download cell extracts the archive, unless the manifest sets `extract_to`. All training settings are kept in one dataclass so that teacher forcing, model size, and split ratios are easy to change.
@dataclass
class Config:
    """Every experiment setting in one place: data, splits, model size, optimization, decoding."""

    # Sentence-pair file; a relative Path, so it is resolved against the current working directory.
    dataset_path: Path = Path("datasets/machine-translation/ita-eng.txt")
    # Seed for set_seed and for the data shuffles / test-sample selection.
    random_seed: int = 42
    # Split fractions; split_examples sizes train and val from the first two and
    # gives the test split whatever remains.
    train_ratio: float = 0.80
    val_ratio: float = 0.10
    test_ratio: float = 0.10
    # Optional cap on the number of (shuffled) examples kept; None keeps all.
    max_examples: Optional[int] = None
    # Tokens seen fewer times than this in the training split are not added to the
    # vocabulary and are therefore encoded as <unk>.
    min_freq: int = 2
    # Optional per-side length limits in tokens; longer pairs are dropped at load time.
    max_source_tokens: Optional[int] = None
    max_target_tokens: Optional[int] = None
    # Sentence pairs per DataLoader batch.
    batch_size: int = 64
    # Model sizes: embedding width used by both encoder and decoder, per-direction
    # encoder LSTM width, decoder LSTMCell width, additive-attention projection width.
    embedding_dim: int = 256
    encoder_hidden_dim: int = 256
    decoder_hidden_dim: int = 512
    attention_dim: int = 256
    # Dropout applied to the encoder and decoder token embeddings.
    dropout: float = 0.20
    # Adam learning rate and number of training epochs.
    learning_rate: float = 1e-3
    epochs: int = 12
    # Per-decoding-step probability of feeding the gold previous token during training.
    teacher_forcing_ratio: float = 0.50
    # Max global gradient norm passed to clip_grad_norm_.
    grad_clip: float = 1.0
    # Upper bound on greedy-decoding steps for validation, test and ad-hoc translation.
    max_generation_tokens: int = 40
config = Config()
# Validate with an explicit exception: `assert` statements are stripped under `python -O`.
ratio_total = config.train_ratio + config.val_ratio + config.test_ratio
if abs(ratio_total - 1.0) >= 1e-8:
    raise ValueError(f"train/val/test ratios must sum to 1.0, got {ratio_total}")
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
set_seed(config.random_seed)
print(f"Dataset path: {config.dataset_path}")
print(f"Device: {device}")
print(asdict(config))
# Load Sentence Pairs
The dataset file is parsed as a list of sentence pairs, one pair per line, separated by a tab (or by `|||`). Since the rows are ordered by increasing difficulty, the split is performed only after a random shuffle. The loader also auto-detects which column is Italian and which is English. This keeps the notebook correct even if the column order differs from what the file name (`ita-eng.txt`) suggests.
# Special tokens; their order in SPECIAL_TOKENS fixes their vocabulary ids (0..3).
PAD_TOKEN, BOS_TOKEN, EOS_TOKEN, UNK_TOKEN = "<pad>", "<bos>", "<eos>", "<unk>"
SPECIAL_TOKENS = [PAD_TOKEN, BOS_TOKEN, EOS_TOKEN, UNK_TOKEN]
# A token is either a run of word characters or a single non-space, non-word character.
TOKEN_PATTERN = re.compile(r"\w+|[^\w\s]", re.UNICODE)
# Frequent words used to vote on which column holds which language.
# NOTE(review): "a", "i" and "come" are also common words in the other language and
# may add noise to the vote; consider pruning them.
ITALIAN_HINTS = set("il la lo gli le un una che non sono sei grazie ciao come per con".split())
ENGLISH_HINTS = set("the a an and i you we they is are thank thanks hello how with for".split())
def normalize_text(text: str) -> str:
    """Canonicalize raw text before tokenization.

    Steps: NFKC compatibility normalization, lowercasing, mapping typographic
    apostrophes/backticks to "'", collapsing whitespace runs to one space, and
    trimming the ends.

    NFKC runs before lowercasing because it can yield uppercase letters from
    compatibility characters (e.g. "ℌ" -> "H"). Lowercasing first would let those
    letters slip through unchanged.
    """
    text = unicodedata.normalize("NFKC", text).lower()
    text = text.replace("’", "'").replace("‘", "'").replace("`", "'")
    return re.sub(r"\s+", " ", text).strip()
def tokenize(text: str) -> List[str]:
    """Normalize *text* and split it into word and single-punctuation tokens."""
    return TOKEN_PATTERN.findall(normalize_text(text))
def language_hint_score(text: str, hint_words: set) -> int:
    """Count the tokens of *text* that appear in *hint_words* (repeats count each time)."""
    score = 0
    for token in tokenize(text):
        if token in hint_words:
            score += 1
    return score
def detect_translation_direction(
    raw_pairs: Sequence[Tuple[str, str]],
    sample_size: int = 500,
) -> Tuple[int, int, str]:
    """Guess which column holds Italian by voting with language hint words.

    The dataset is ordered by increasing difficulty, so its first rows are the
    shortest sentences. Those tend to contain few hint words, and a tied vote
    defaults to column 0. Rows are therefore sampled at an even stride across the
    whole file instead of taken from its head.

    Args:
        raw_pairs: (column 0, column 1) sentence pairs in file order.
        sample_size: maximum number of rows to score (values < 1 are treated as 1).

    Returns:
        (source_column, target_column, description); source_column is the column
        judged to be Italian.
    """
    sample_size = max(1, sample_size)
    stride = max(1, len(raw_pairs) // sample_size)
    sample = list(raw_pairs[::stride])[:sample_size]
    first_is_ita = sum(language_hint_score(src, ITALIAN_HINTS) + language_hint_score(tgt, ENGLISH_HINTS) for src, tgt in sample)
    first_is_eng = sum(language_hint_score(src, ENGLISH_HINTS) + language_hint_score(tgt, ITALIAN_HINTS) for src, tgt in sample)
    if first_is_ita >= first_is_eng:
        return 0, 1, "Italian -> English (column 0 -> column 1)"
    return 1, 0, "Italian -> English (column 1 -> column 0)"
def split_parallel_line(line: str) -> List[str]:
    """Split one line of a parallel corpus into stripped fields.

    Tab is the preferred separator, with "|||" as a fallback; a line with neither
    yields []. Empty leading or interior fields are kept as "" so that column
    positions stay aligned. Dropping them would shift e.g. an attribution column
    into the target slot when the source field is empty. Trailing empty fields are
    removed because they cannot shift anything.
    """
    stripped = line.rstrip("\r\n")
    if "\t" in stripped:
        separator = "\t"
    elif "|||" in stripped:
        separator = "|||"
    else:
        return []
    parts = [part.strip() for part in stripped.split(separator)]
    while parts and not parts[-1]:
        parts.pop()
    return parts
def load_sentence_pairs(path: Path, cfg: Config) -> Tuple[List[Dict[str, object]], str]:
    """Read, orient, normalize, tokenize and filter the sentence pairs in *path*.

    Args:
        path: sentence-pair file (fields separated as understood by split_parallel_line).
        cfg: supplies max_source_tokens / max_target_tokens, max_examples and random_seed.

    Returns:
        (examples, mapping_description). Each example is a dict with "source_text" and
        "target_text" (normalized strings) plus "source_tokens" and "target_tokens"
        (token lists). The description states which column was detected as Italian.

    Raises:
        FileNotFoundError: if *path* does not exist.
        ValueError: if no line contains at least two fields.
    """
    if not path.exists():
        raise FileNotFoundError(f"Dataset not found: {path}")
    raw_pairs: List[Tuple[str, str]] = []
    # "utf-8-sig" drops a leading byte-order mark. Plain "utf-8" would leave it glued
    # to the first sentence, where it is neither a word character nor whitespace and
    # so becomes a token of its own. BOM-free files decode identically.
    with path.open("r", encoding="utf-8-sig") as handle:
        for line in handle:
            parts = split_parallel_line(line)
            if len(parts) >= 2:
                raw_pairs.append((parts[0], parts[1]))
    if not raw_pairs:
        raise ValueError("No sentence pairs were found in the dataset file.")
    source_column, target_column, mapping_description = detect_translation_direction(raw_pairs)
    examples: List[Dict[str, object]] = []
    for first, second in raw_pairs:
        columns = [first, second]
        source_text = normalize_text(columns[source_column])
        target_text = normalize_text(columns[target_column])
        source_tokens = tokenize(source_text)
        target_tokens = tokenize(target_text)
        # Skip pairs with an empty side or exceeding the configured length limits.
        if not source_tokens or not target_tokens:
            continue
        if cfg.max_source_tokens is not None and len(source_tokens) > cfg.max_source_tokens:
            continue
        if cfg.max_target_tokens is not None and len(target_tokens) > cfg.max_target_tokens:
            continue
        examples.append(
            {
                "source_text": source_text,
                "target_text": target_text,
                "source_tokens": source_tokens,
                "target_tokens": target_tokens,
            }
        )
    if cfg.max_examples is not None:
        # Shuffle before truncating; the file is ordered by difficulty, so a plain
        # head slice would keep only the easiest pairs.
        rng = random.Random(cfg.random_seed)
        rng.shuffle(examples)
        examples = examples[: cfg.max_examples]
    return examples, mapping_description
def split_examples(
    examples: Sequence[Dict[str, object]],
    cfg: Config,
) -> Tuple[List[Dict[str, object]], List[Dict[str, object]], List[Dict[str, object]]]:
    """Shuffle *examples* with cfg.random_seed and cut them into train/val/test lists.

    Train and validation sizes are floor(total * ratio); the test split receives
    whatever remains.

    Raises:
        ValueError: if any of the three splits would be empty.
    """
    pool = list(examples)
    random.Random(cfg.random_seed).shuffle(pool)
    n_total = len(pool)
    n_train = int(n_total * cfg.train_ratio)
    n_val = int(n_total * cfg.val_ratio)
    if min(n_train, n_val, n_total - n_train - n_val) <= 0:
        raise ValueError("Dataset is too small for the requested train/val/test split.")
    val_end = n_train + n_val
    return pool[:n_train], pool[n_train:val_end], pool[val_end:]
examples, detected_mapping = load_sentence_pairs(config.dataset_path, config)
train_examples, val_examples, test_examples = split_examples(examples, config)
print(f"Detected direction: {detected_mapping}")
print(f"Total examples: {len(examples)}")
print(f"Train / val / test: {len(train_examples)} / {len(val_examples)} / {len(test_examples)}")
# Preview up to three (already shuffled) training pairs.
for preview_idx, sample in enumerate(train_examples[:3]):
    print()
    print(f"Example {preview_idx + 1}")
    print(" source:", sample["source_text"])
    print(" target:", sample["target_text"])
class Vocab:
    """Token <-> id mapping built from tokenized sentences.

    Ids 0..3 are the special tokens in SPECIAL_TOKENS order (<pad>, <bos>, <eos>, <unk>).
    The remaining tokens follow in decreasing corpus frequency, keeping only those
    seen at least ``min_freq`` times.
    """

    def __init__(self, token_sequences: Sequence[Sequence[str]], min_freq: int = 1):
        counter = Counter(token for sequence in token_sequences for token in sequence)
        self.itos = list(SPECIAL_TOKENS)
        # A set keeps construction linear. Testing membership in the growing itos
        # list made it quadratic in the vocabulary size.
        reserved = set(self.itos)
        for token, frequency in counter.most_common():
            if frequency < min_freq:
                # most_common() is sorted by decreasing count, so no later token qualifies.
                break
            if token not in reserved:
                self.itos.append(token)
        self.stoi = {token: idx for idx, token in enumerate(self.itos)}

    def __len__(self) -> int:
        return len(self.itos)

    @property
    def pad_idx(self) -> int:
        return self.stoi[PAD_TOKEN]

    @property
    def bos_idx(self) -> int:
        return self.stoi[BOS_TOKEN]

    @property
    def eos_idx(self) -> int:
        return self.stoi[EOS_TOKEN]

    @property
    def unk_idx(self) -> int:
        return self.stoi[UNK_TOKEN]

    def encode(self, tokens: Sequence[str], add_bos: bool = False, add_eos: bool = False) -> List[int]:
        """Map *tokens* to ids (unknown tokens -> <unk>), optionally wrapping with <bos>/<eos>."""
        ids = [self.stoi.get(token, self.unk_idx) for token in tokens]
        if add_bos:
            ids = [self.bos_idx] + ids
        if add_eos:
            ids = ids + [self.eos_idx]
        return ids

    def decode(self, ids: Sequence[int], stop_at_eos: bool = True, skip_special: bool = True) -> List[str]:
        """Map ids back to tokens.

        Decoding optionally stops at the first <eos> and optionally drops special
        tokens. Note that dropping specials also removes <unk>.
        """
        tokens: List[str] = []
        for idx in ids:
            token = self.itos[int(idx)]
            if stop_at_eos and token == EOS_TOKEN:
                break
            if skip_special and token in SPECIAL_TOKENS:
                continue
            tokens.append(token)
        return tokens
class TranslationDataset(Dataset):
    """Map-style dataset that encodes each example into id tensors on access.

    Source ids end with <eos>; target ids are wrapped as <bos> ... <eos>. The
    normalized source/target strings are passed through for evaluation and display.
    """

    def __init__(self, examples: Sequence[Dict[str, object]], source_vocab: Vocab, target_vocab: Vocab):
        self.examples = list(examples)
        self.source_vocab = source_vocab
        self.target_vocab = target_vocab

    def __len__(self) -> int:
        return len(self.examples)

    def __getitem__(self, index: int) -> Dict[str, object]:
        example = self.examples[index]
        encoded_source = self.source_vocab.encode(example["source_tokens"], add_eos=True)
        encoded_target = self.target_vocab.encode(example["target_tokens"], add_bos=True, add_eos=True)
        item = {
            "source_ids": torch.tensor(encoded_source, dtype=torch.long),
            "target_ids": torch.tensor(encoded_target, dtype=torch.long),
        }
        item["source_text"] = example["source_text"]
        item["target_text"] = example["target_text"]
        return item
# Vocabularies are built from the training split only. Tokens seen fewer than
# config.min_freq times there, and tokens that occur only in val/test, are encoded as <unk>.
source_vocab = Vocab((example["source_tokens"] for example in train_examples), min_freq=config.min_freq)
target_vocab = Vocab((example["target_tokens"] for example in train_examples), min_freq=config.min_freq)
def collate_translation_batch(batch: Sequence[Dict[str, object]]) -> Dict[str, object]:
    """Pad a list of dataset items into batch-first tensors.

    Padding uses the module-level source_vocab / target_vocab pad ids. The true
    (unpadded) source lengths are returned for sequence packing in the encoder.
    """
    source_seqs = [example["source_ids"] for example in batch]
    target_seqs = [example["target_ids"] for example in batch]
    return {
        "source_tokens": pad_sequence(source_seqs, batch_first=True, padding_value=source_vocab.pad_idx),
        "source_lengths": torch.tensor([seq.size(0) for seq in source_seqs], dtype=torch.long),
        "target_tokens": pad_sequence(target_seqs, batch_first=True, padding_value=target_vocab.pad_idx),
        "source_text": [example["source_text"] for example in batch],
        "target_text": [example["target_text"] for example in batch],
    }
train_dataset, val_dataset, test_dataset = (
    TranslationDataset(split, source_vocab, target_vocab)
    for split in (train_examples, val_examples, test_examples)
)


def _make_loader(dataset: Dataset, shuffle: bool) -> DataLoader:
    """Wrap *dataset* in a DataLoader with the shared batch size and collate function."""
    return DataLoader(
        dataset,
        batch_size=config.batch_size,
        shuffle=shuffle,
        collate_fn=collate_translation_batch,
    )


# Only the training loader reshuffles each epoch; evaluation order stays fixed.
train_loader = _make_loader(train_dataset, shuffle=True)
val_loader = _make_loader(val_dataset, shuffle=False)
test_loader = _make_loader(test_dataset, shuffle=False)
print(f"Source vocab size: {len(source_vocab)}")
print(f"Target vocab size: {len(target_vocab)}")
print(f"Training batches per epoch: {len(train_loader)}")
# Encoder-Decoder with Additive Cross-Attention
The model is intentionally written step by step instead of hiding the logic inside a high-level seq2seq wrapper:
- the encoder turns the source sentence into a sequence of hidden states
- additive cross-attention scores every encoder state against the current decoder hidden state
- the attention-weighted context vector is concatenated with the current target embedding
- the decoder LSTMCell predicts the next target token autoregressively
This keeps the role of cross-attention explicit at every decoding step.
class Encoder(nn.Module):
    """Bidirectional LSTM encoder with a bridge to the decoder's initial state.

    Returns per-position encoder states (forward and backward halves concatenated),
    a boolean mask of non-pad source positions, and the decoder's initial hidden and
    cell states. The latter are tanh-projected from the final forward/backward states.
    """

    def __init__(
        self,
        source_vocab_size: int,
        embedding_dim: int,
        hidden_dim: int,
        decoder_hidden_dim: int,
        pad_idx: int,
        dropout: float,
    ):
        super().__init__()
        self.pad_idx = pad_idx
        # Submodules are created in a fixed order so parameter initialization consumes
        # the RNG identically from run to run.
        self.embedding = nn.Embedding(source_vocab_size, embedding_dim, padding_idx=pad_idx)
        self.dropout = nn.Dropout(dropout)
        self.lstm = nn.LSTM(
            input_size=embedding_dim,
            hidden_size=hidden_dim,
            batch_first=True,
            bidirectional=True,
        )
        bidirectional_dim = hidden_dim * 2
        self.hidden_bridge = nn.Linear(bidirectional_dim, decoder_hidden_dim)
        self.cell_bridge = nn.Linear(bidirectional_dim, decoder_hidden_dim)

    @staticmethod
    def _merge_directions(final_states: torch.Tensor, projection: nn.Module) -> torch.Tensor:
        """Join the last layer's forward ([-2]) and backward ([-1]) states, then project with tanh."""
        forward_state, backward_state = final_states[-2], final_states[-1]
        return torch.tanh(projection(torch.cat([forward_state, backward_state], dim=-1)))

    def forward(
        self,
        source_tokens: torch.Tensor,
        source_lengths: torch.Tensor,
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
        padded_length = source_tokens.size(1)
        # Packing lets the LSTM skip pad positions; lengths must live on the CPU.
        packed_input = pack_padded_sequence(
            self.dropout(self.embedding(source_tokens)),
            source_lengths.cpu(),
            batch_first=True,
            enforce_sorted=False,
        )
        packed_states, (final_hidden, final_cell) = self.lstm(packed_input)
        # Unpack back to the original padded width so the mask below lines up.
        states, _ = pad_packed_sequence(packed_states, batch_first=True, total_length=padded_length)
        initial_hidden = self._merge_directions(final_hidden, self.hidden_bridge)
        initial_cell = self._merge_directions(final_cell, self.cell_bridge)
        return states, source_tokens.ne(self.pad_idx), initial_hidden, initial_cell
class AdditiveCrossAttention(nn.Module):
    """Bahdanau-style additive attention: score = v^T tanh(W_e h_enc + W_d h_dec).

    Padded source positions (mask == False) receive a score of -1e9 before the
    softmax, so their weights are effectively zero.
    """

    def __init__(self, encoder_dim: int, decoder_dim: int, attention_dim: int):
        super().__init__()
        self.encoder_projection = nn.Linear(encoder_dim, attention_dim, bias=False)
        self.decoder_projection = nn.Linear(decoder_dim, attention_dim, bias=False)
        self.energy_projection = nn.Linear(attention_dim, 1, bias=False)

    def forward(
        self,
        decoder_hidden: torch.Tensor,
        encoder_outputs: torch.Tensor,
        source_mask: torch.Tensor,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        keys = self.encoder_projection(encoder_outputs)
        # Add a length-1 source axis so the query broadcasts over every source position.
        query = self.decoder_projection(decoder_hidden)[:, None, :]
        raw_scores = self.energy_projection(torch.tanh(keys + query)).squeeze(-1)
        weights = torch.softmax(raw_scores.masked_fill(source_mask.logical_not(), -1e9), dim=-1)
        context = torch.bmm(weights[:, None, :], encoder_outputs)[:, 0, :]
        return context, weights
class Decoder(nn.Module):
    """Single-step LSTMCell decoder fed with [token embedding; attention context].

    Attention is computed from the hidden state *before* the cell update. The output
    logits are then projected from the updated hidden state concatenated with that
    context.
    """

    def __init__(
        self,
        target_vocab_size: int,
        embedding_dim: int,
        decoder_hidden_dim: int,
        encoder_output_dim: int,
        attention_dim: int,
        pad_idx: int,
        dropout: float,
    ):
        super().__init__()
        # Submodule creation order is kept fixed so parameter initialization is reproducible.
        self.embedding = nn.Embedding(target_vocab_size, embedding_dim, padding_idx=pad_idx)
        self.dropout = nn.Dropout(dropout)
        self.attention = AdditiveCrossAttention(
            encoder_dim=encoder_output_dim,
            decoder_dim=decoder_hidden_dim,
            attention_dim=attention_dim,
        )
        cell_input_dim = embedding_dim + encoder_output_dim
        readout_dim = decoder_hidden_dim + encoder_output_dim
        self.lstm_cell = nn.LSTMCell(cell_input_dim, decoder_hidden_dim)
        self.output_projection = nn.Linear(readout_dim, target_vocab_size)

    def forward_step(
        self,
        input_tokens: torch.Tensor,
        hidden: torch.Tensor,
        cell: torch.Tensor,
        encoder_outputs: torch.Tensor,
        source_mask: torch.Tensor,
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
        """Advance one step; returns (logits, new hidden, new cell, attention weights)."""
        token_vectors = self.dropout(self.embedding(input_tokens))
        context, weights = self.attention(
            decoder_hidden=hidden,
            encoder_outputs=encoder_outputs,
            source_mask=source_mask,
        )
        next_hidden, next_cell = self.lstm_cell(torch.cat([token_vectors, context], dim=-1), (hidden, cell))
        readout = torch.cat([next_hidden, context], dim=-1)
        return self.output_projection(readout), next_hidden, next_cell, weights
class Seq2SeqTranslator(nn.Module):
    """Encoder-decoder translator that emits one target token per decoding step.

    Args:
        encoder: returns (encoder_outputs, source_mask, hidden, cell) for a source batch.
        decoder: exposes ``forward_step`` for a single decoding step.
        bos_idx: target id that starts greedy decoding.
        eos_idx: target id that ends a generated sequence.
    """

    def __init__(self, encoder: Encoder, decoder: Decoder, bos_idx: int, eos_idx: int):
        super().__init__()
        self.encoder = encoder
        self.decoder = decoder
        self.bos_idx = bos_idx
        self.eos_idx = eos_idx

    def forward(
        self,
        source_tokens: torch.Tensor,
        source_lengths: torch.Tensor,
        target_tokens: torch.Tensor,
        teacher_forcing_ratio: float = 0.0,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Decode over a full padded target batch (used for the training loss).

        Args:
            source_tokens: padded source ids, batch-first.
            source_lengths: unpadded source lengths.
            target_tokens: padded target ids whose first column is <bos>.
            teacher_forcing_ratio: per-step probability of feeding the gold token
                instead of the model's own argmax; only applied in training mode.

        Returns:
            (logits, attention), one entry per target position after the first,
            i.e. aligned with ``target_tokens[:, 1:]``.

        Raises:
            ValueError: if ``target_tokens`` has fewer than two columns.
        """
        if target_tokens.size(1) < 2:
            raise ValueError("target_tokens must have at least two columns (<bos> plus one target token).")
        encoder_outputs, source_mask, hidden, cell = self.encoder(source_tokens, source_lengths)
        input_tokens = target_tokens[:, 0]
        step_logits = []
        attention_history = []
        for step in range(1, target_tokens.size(1)):
            logits, hidden, cell, attention_weights = self.decoder.forward_step(
                input_tokens=input_tokens,
                hidden=hidden,
                cell=cell,
                encoder_outputs=encoder_outputs,
                source_mask=source_mask,
            )
            step_logits.append(logits.unsqueeze(1))
            attention_history.append(attention_weights.unsqueeze(1))
            # One coin flip per step for the whole batch. In eval mode the flip is
            # skipped entirely and the model always feeds back its own predictions.
            use_teacher_forcing = self.training and random.random() < teacher_forcing_ratio
            input_tokens = target_tokens[:, step] if use_teacher_forcing else logits.argmax(dim=-1)
        return torch.cat(step_logits, dim=1), torch.cat(attention_history, dim=1)

    @torch.no_grad()
    def greedy_decode(
        self,
        source_tokens: torch.Tensor,
        source_lengths: torch.Tensor,
        max_steps: int,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Greedily generate up to *max_steps* tokens per source sentence.

        Once a row has produced <eos>, it keeps emitting <eos>. The loop stops early
        when every row has finished.

        Returns:
            (generated ids, attention history), each with one entry per executed step.

        Raises:
            ValueError: if ``max_steps`` is smaller than 1.
        """
        if max_steps < 1:
            raise ValueError(f"max_steps must be at least 1, got {max_steps}.")
        encoder_outputs, source_mask, hidden, cell = self.encoder(source_tokens, source_lengths)
        batch_size = source_tokens.size(0)
        input_tokens = torch.full(
            (batch_size,),
            fill_value=self.bos_idx,
            dtype=torch.long,
            device=source_tokens.device,
        )
        generated_tokens = []
        attention_history = []
        finished = torch.zeros(batch_size, dtype=torch.bool, device=source_tokens.device)
        for _ in range(max_steps):
            logits, hidden, cell, attention_weights = self.decoder.forward_step(
                input_tokens=input_tokens,
                hidden=hidden,
                cell=cell,
                encoder_outputs=encoder_outputs,
                source_mask=source_mask,
            )
            next_tokens = logits.argmax(dim=-1)
            # Rows that already produced <eos> keep emitting <eos>.
            next_tokens = torch.where(
                finished,
                torch.full_like(next_tokens, self.eos_idx),
                next_tokens,
            )
            generated_tokens.append(next_tokens.unsqueeze(1))
            attention_history.append(attention_weights.unsqueeze(1))
            finished = finished | next_tokens.eq(self.eos_idx)
            input_tokens = next_tokens
            if finished.all():
                break
        return torch.cat(generated_tokens, dim=1), torch.cat(attention_history, dim=1)
# BLEU and Training Utilities
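Validation is done with greedy autoregressive decoding and corpus BLEU. The BLEU implementation below is explicit and dependency-free, so the notebook does not rely on external NLP packages. By default it applies add-one smoothing to the n-gram precisions, so its scores are not directly comparable to those of standard tools such as sacreBLEU.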
def move_batch_to_device(batch: Dict[str, object], device: torch.device) -> Dict[str, object]:
    """Return a new dict in which tensor values are moved to *device*; other values are shared."""
    return {
        key: (value.to(device) if torch.is_tensor(value) else value)
        for key, value in batch.items()
    }
def detokenize(tokens: Sequence[str]) -> str:
    """Join word/punctuation tokens back into readable text (for display only).

    The tokenizer splits "don't" into ["don", "'", "t"]. Contractions are re-glued
    when an apostrophe sits between a word character and a common English
    contraction suffix (t, s, re, ve, ll, d, m). Spaces before closing punctuation
    and inside parentheses are also removed.
    """
    text = " ".join(tokens)
    # "don ' t" -> "don't", "it ' s" -> "it's", "we ' re" -> "we're".
    text = re.sub(r"(\w) ' (t|s|re|ve|ll|d|m)\b", r"\1'\2", text)
    text = re.sub(r"\s+([?.!,;:])", r"\1", text)
    text = re.sub(r"\(\s+", "(", text)
    text = re.sub(r"\s+\)", ")", text)
    text = text.replace(" n't", "n't")
    return text.strip()
def ngram_counts(tokens: Sequence[str], n: int) -> Counter:
    """Count every contiguous n-gram (as a tuple) in *tokens*; empty if the sequence is shorter than n."""
    window_count = len(tokens) - n + 1
    if window_count <= 0:
        return Counter()
    return Counter(tuple(tokens[start : start + n]) for start in range(window_count))
def corpus_bleu_score(
    references: Sequence[Sequence[str]],
    hypotheses: Sequence[Sequence[str]],
    max_n: int = 4,
    smooth: float = 1.0,
) -> float:
    """Corpus-level BLEU (0-100) with one reference per hypothesis.

    Clipped n-gram matches and n-gram totals are pooled over the whole corpus for
    orders 1..max_n. Orders with no hypothesis n-grams are skipped, and the
    remaining orders are weighted equally. Each precision is smoothed as
    (matches + smooth) / (total + smooth). A brevity penalty applies when the total
    hypothesis length is shorter than the total reference length.

    Raises:
        ValueError: if the two sequences differ in length.
    """
    if len(references) != len(hypotheses):
        raise ValueError("References and hypotheses must have the same length.")
    if not references:
        return 0.0
    clipped_matches = [0] * max_n
    total_matches = [0] * max_n
    reference_length = 0
    hypothesis_length = 0
    for reference, hypothesis in zip(references, hypotheses):
        reference_length += len(reference)
        hypothesis_length += len(hypothesis)
        for n in range(1, max_n + 1):
            reference_ngrams = ngram_counts(reference, n)
            hypothesis_ngrams = ngram_counts(hypothesis, n)
            # Counter intersection keeps min counts, i.e. clipped matches.
            clipped_matches[n - 1] += sum((reference_ngrams & hypothesis_ngrams).values())
            total_matches[n - 1] += max(len(hypothesis) - n + 1, 0)
    valid_orders = [order for order, total in enumerate(total_matches, start=1) if total > 0]
    if not valid_orders or hypothesis_length == 0:
        return 0.0
    log_precision_sum = 0.0
    weight = 1.0 / len(valid_orders)
    for order in valid_orders:
        precision = (clipped_matches[order - 1] + smooth) / (total_matches[order - 1] + smooth)
        if precision <= 0.0:
            # Unsmoothed (smooth=0) order with no matches: the geometric mean is 0,
            # and math.log would raise a domain error.
            return 0.0
        log_precision_sum += weight * math.log(precision)
    brevity_penalty = 1.0
    if hypothesis_length < reference_length:
        brevity_penalty = math.exp(1.0 - reference_length / max(hypothesis_length, 1))
    bleu = brevity_penalty * math.exp(log_precision_sum)
    return bleu * 100.0
def decode_target_ids(token_ids: Sequence[int], vocab: Vocab) -> List[str]:
    """Turn generated/target ids into tokens, stopping at <eos> and dropping special tokens."""
    tokens = vocab.decode(token_ids, skip_special=True, stop_at_eos=True)
    return tokens
# Sanity Check
Before launching the full training loop, run one mini-batch through the model. This verifies that the encoder, additive cross-attention, and autoregressive decoder agree on tensor shapes.
# Take one training batch and keep its first 8 examples. Tensors and the Python
# lists of raw text both support [:8], so no type dispatch is needed.
smoke_batch = next(iter(train_loader))
smoke_batch = {key: value[:8] for key, value in smoke_batch.items()}
smoke_batch = move_batch_to_device(smoke_batch, device)
smoke_encoder = Encoder(
    source_vocab_size=len(source_vocab),
    embedding_dim=config.embedding_dim,
    hidden_dim=config.encoder_hidden_dim,
    decoder_hidden_dim=config.decoder_hidden_dim,
    pad_idx=source_vocab.pad_idx,
    dropout=config.dropout,
).to(device)
smoke_decoder = Decoder(
    target_vocab_size=len(target_vocab),
    embedding_dim=config.embedding_dim,
    decoder_hidden_dim=config.decoder_hidden_dim,
    encoder_output_dim=config.encoder_hidden_dim * 2,
    attention_dim=config.attention_dim,
    pad_idx=target_vocab.pad_idx,
    dropout=config.dropout,
).to(device)
smoke_model = Seq2SeqTranslator(
    encoder=smoke_encoder,
    decoder=smoke_decoder,
    bos_idx=target_vocab.bos_idx,
    eos_idx=target_vocab.eos_idx,
).to(device)
# Only shapes are checked here, so skip building an autograd graph.
with torch.no_grad():
    smoke_logits, smoke_attention = smoke_model(
        source_tokens=smoke_batch["source_tokens"],
        source_lengths=smoke_batch["source_lengths"],
        target_tokens=smoke_batch["target_tokens"],
        teacher_forcing_ratio=config.teacher_forcing_ratio,
    )
# Outputs cover every target position after <bos>; attention spans all source positions.
expected_target_shape = smoke_batch["target_tokens"][:, 1:].shape
assert smoke_logits.shape[:2] == expected_target_shape
assert smoke_attention.shape[:2] == expected_target_shape
assert smoke_attention.shape[-1] == smoke_batch["source_tokens"].shape[1]
print("Logits shape:", tuple(smoke_logits.shape))
print("Attention shape:", tuple(smoke_attention.shape))
print("Expected target shape:", tuple(expected_target_shape))
# Training Loop
Each epoch optimizes cross-entropy on the target sequence. Validation uses autoregressive decoding with teacher forcing disabled, so BLEU reflects actual inference behaviour. The best checkpoint is the one with highest validation BLEU.
def train_one_epoch(
    model: Seq2SeqTranslator,
    dataloader: DataLoader,
    optimizer: torch.optim.Optimizer,
    criterion: nn.Module,
    teacher_forcing_ratio: float,
    grad_clip: float,
    device: torch.device,
) -> float:
    """Run one optimization pass over *dataloader*.

    Args:
        model: translator, switched to training mode here.
        dataloader: yields collated batches with source/target tensors.
        optimizer: updated once per batch.
        criterion: loss over flattened (logits, target ids).
        teacher_forcing_ratio: forwarded to the model's forward pass.
        grad_clip: max global gradient norm; a value <= 0 (or None) disables clipping.
        device: device the batches are moved to.

    Returns:
        Mean of the per-batch losses (0.0 if the loader is empty).
    """
    model.train()
    total_loss = 0.0
    num_batches = 0
    for batch in dataloader:
        batch = move_batch_to_device(batch, device)
        optimizer.zero_grad()
        logits, _ = model(
            source_tokens=batch["source_tokens"],
            source_lengths=batch["source_lengths"],
            target_tokens=batch["target_tokens"],
            teacher_forcing_ratio=teacher_forcing_ratio,
        )
        # The model predicts every target position after <bos>.
        target_output = batch["target_tokens"][:, 1:]
        loss = criterion(logits.reshape(-1, logits.size(-1)), target_output.reshape(-1))
        loss.backward()
        # clip_grad_norm_ with max_norm=0 would scale every gradient to zero,
        # so a non-positive value means "no clipping".
        if grad_clip is not None and grad_clip > 0:
            nn.utils.clip_grad_norm_(model.parameters(), grad_clip)
        optimizer.step()
        total_loss += loss.item()
        num_batches += 1
    return total_loss / max(num_batches, 1)
@torch.no_grad()
def evaluate_bleu(
    model: Seq2SeqTranslator,
    dataloader: DataLoader,
    target_vocab: Vocab,
    device: torch.device,
    max_steps: int,
) -> float:
    """Greedy-decode every batch and return corpus BLEU (0-100).

    References are re-tokenized from each example's normalized target text, which
    is exactly how the stored target tokens were produced. Decoding the target ids
    instead would turn out-of-vocabulary reference words into <unk>, which
    decode_target_ids then drops. That silently shortens the references and
    inflates BLEU.
    """
    model.eval()
    references: List[List[str]] = []
    hypotheses: List[List[str]] = []
    for batch in dataloader:
        batch = move_batch_to_device(batch, device)
        generated_ids, _ = model.greedy_decode(
            source_tokens=batch["source_tokens"],
            source_lengths=batch["source_lengths"],
            max_steps=max_steps,
        )
        references.extend(tokenize(text) for text in batch["target_text"])
        hypotheses.extend(decode_target_ids(ids.tolist(), target_vocab) for ids in generated_ids)
    return corpus_bleu_score(references, hypotheses)
@torch.no_grad()
def translate_sentence(
    model: Seq2SeqTranslator,
    sentence: str,
    source_vocab: Vocab,
    target_vocab: Vocab,
    device: torch.device,
    max_steps: int,
) -> Tuple[str, List[str], torch.Tensor]:
    """Translate one raw sentence with greedy decoding.

    Returns:
        (detokenized translation, predicted tokens, attention history of the single
        batch row moved to the CPU).
    """
    model.eval()
    encoded = source_vocab.encode(tokenize(sentence), add_eos=True)
    # Build a batch of one: shape (1, len(encoded)).
    source_tensor = torch.tensor([encoded], dtype=torch.long, device=device)
    length_tensor = torch.tensor([len(encoded)], dtype=torch.long, device=device)
    generated_ids, attention_history = model.greedy_decode(
        source_tokens=source_tensor,
        source_lengths=length_tensor,
        max_steps=max_steps,
    )
    predicted_tokens = decode_target_ids(generated_ids[0].tolist(), target_vocab)
    return detokenize(predicted_tokens), predicted_tokens, attention_history[0].cpu()
def plot_attention(source_tokens: Sequence[str], predicted_tokens: Sequence[str], attention_matrix: torch.Tensor) -> None:
    """Draw a heatmap of cross-attention weights (generated tokens x source tokens).

    *attention_matrix* rows are decoding steps and columns are source positions. It
    is cropped to the visible tokens, so extra columns such as the <eos> appended by
    translate_sentence are not shown. Nothing is drawn (a message is printed instead)
    when matplotlib is missing or when either token list is empty.
    """
    if not HAS_MATPLOTLIB:
        print("Matplotlib is not available in this environment, skipping attention plot.")
        return
    if len(predicted_tokens) == 0:
        print("No generated tokens available for attention plotting.")
        return
    if len(source_tokens) == 0:
        # A zero-width matrix cannot be drawn by imshow.
        print("No source tokens available for attention plotting.")
        return
    trimmed_attention = attention_matrix[: len(predicted_tokens), : len(source_tokens)]
    figure, axis = plt.subplots(figsize=(max(6, 0.6 * len(source_tokens)), max(4, 0.6 * len(predicted_tokens))))
    image = axis.imshow(trimmed_attention, aspect="auto", cmap="viridis")
    axis.set_xticks(range(len(source_tokens)))
    axis.set_xticklabels(source_tokens, rotation=45, ha="right")
    axis.set_yticks(range(len(predicted_tokens)))
    axis.set_yticklabels(predicted_tokens)
    axis.set_xlabel("Source tokens")
    axis.set_ylabel("Generated target tokens")
    axis.set_title("Additive cross-attention weights")
    figure.colorbar(image, ax=axis)
    plt.show()
# Train the Translator
The training loss is optimized with teacher forcing applied at a configurable fraction of decoding steps (`teacher_forcing_ratio`). Validation BLEU is computed with greedy decoding and no teacher forcing. This separates optimization behaviour from actual autoregressive inference quality.
encoder = Encoder(
    source_vocab_size=len(source_vocab),
    embedding_dim=config.embedding_dim,
    hidden_dim=config.encoder_hidden_dim,
    decoder_hidden_dim=config.decoder_hidden_dim,
    pad_idx=source_vocab.pad_idx,
    dropout=config.dropout,
)
decoder = Decoder(
    target_vocab_size=len(target_vocab),
    embedding_dim=config.embedding_dim,
    decoder_hidden_dim=config.decoder_hidden_dim,
    encoder_output_dim=config.encoder_hidden_dim * 2,
    attention_dim=config.attention_dim,
    pad_idx=target_vocab.pad_idx,
    dropout=config.dropout,
)
model = Seq2SeqTranslator(
    encoder=encoder,
    decoder=decoder,
    bos_idx=target_vocab.bos_idx,
    eos_idx=target_vocab.eos_idx,
).to(device)
# Padding positions in the targets do not contribute to the loss.
criterion = nn.CrossEntropyLoss(ignore_index=target_vocab.pad_idx)
optimizer = torch.optim.Adam(model.parameters(), lr=config.learning_rate)
history = {"train_loss": [], "val_bleu": []}
best_val_bleu = -1.0
best_state_dict = None
for epoch in range(1, config.epochs + 1):
    # perf_counter is monotonic and high-resolution, unlike time.time().
    start_time = time.perf_counter()
    train_loss = train_one_epoch(
        model=model,
        dataloader=train_loader,
        optimizer=optimizer,
        criterion=criterion,
        teacher_forcing_ratio=config.teacher_forcing_ratio,
        grad_clip=config.grad_clip,
        device=device,
    )
    val_bleu = evaluate_bleu(
        model=model,
        dataloader=val_loader,
        target_vocab=target_vocab,
        device=device,
        max_steps=config.max_generation_tokens,
    )
    history["train_loss"].append(train_loss)
    history["val_bleu"].append(val_bleu)
    if val_bleu > best_val_bleu:
        best_val_bleu = val_bleu
        # Snapshot on the CPU: a deepcopy of the live state_dict would keep a second
        # copy of every weight on the training device.
        best_state_dict = {
            name: tensor.detach().cpu().clone() for name, tensor in model.state_dict().items()
        }
    elapsed = time.perf_counter() - start_time
    print(
        f"Epoch {epoch:02d}/{config.epochs} | "
        f"train loss = {train_loss:.4f} | "
        f"val BLEU = {val_bleu:.2f} | "
        f"time = {elapsed:.1f}s"
    )
# Restore the checkpoint with the highest validation BLEU.
if best_state_dict is not None:
    model.load_state_dict(best_state_dict)
print(f"Best validation BLEU: {best_val_bleu:.2f}")
print(history)
# Test-Time Inference
Final inference is evaluated on the held-out test split with BLEU, then a few greedy translations are printed. The first sample also displays its cross-attention map so the alignment behaviour is visible.
test_bleu = evaluate_bleu(
    model=model,
    dataloader=test_loader,
    target_vocab=target_vocab,
    device=device,
    max_steps=config.max_generation_tokens,
)
print(f"Test BLEU: {test_bleu:.2f}")

# Show up to five test pairs, picked reproducibly from the configured seed.
rng = random.Random(config.random_seed)
num_samples = min(5, len(test_examples))
sample_indices = rng.sample(range(len(test_examples)), k=num_samples)
for rank, sample_index in enumerate(sample_indices, start=1):
    sample = test_examples[sample_index]
    predicted_text, predicted_tokens, attention_matrix = translate_sentence(
        model=model,
        sentence=sample["source_text"],
        source_vocab=source_vocab,
        target_vocab=target_vocab,
        device=device,
        max_steps=config.max_generation_tokens,
    )
    print()
    print(f"Sample {rank}")
    report = (
        (" source :", sample["source_text"]),
        (" reference:", sample["target_text"]),
        (" predicted:", predicted_text),
    )
    for label, value in report:
        print(label, value)
    # Only the first sample gets an attention heatmap.
    if rank == 1:
        plot_attention(sample["source_tokens"], predicted_tokens, attention_matrix)
# Translate Your Own Sentence
Edit the sentence below with any Italian input you want to translate. This section reuses the trained encoder-decoder and also shows the attention map for the generated English translation.
user_sentence = "Mi piace molto studiare all'università le materie di informatica e intelligenza artificiale."
# Names this cell relies on; all are defined by earlier cells.
required_objects = [
    "model",
    "source_vocab",
    "target_vocab",
    "config",
    "device",
    "translate_sentence",
]
missing_objects = [name for name in required_objects if name not in globals()]
if not missing_objects:
    translated_text, predicted_tokens, attention_matrix = translate_sentence(
        model=model,
        sentence=user_sentence,
        source_vocab=source_vocab,
        target_vocab=target_vocab,
        device=device,
        max_steps=config.max_generation_tokens,
    )
    print("Input sentence :", user_sentence)
    print("Translation :", translated_text)
    plot_attention(tokenize(user_sentence), predicted_tokens, attention_matrix)
else:
    print(
        "This translation cell needs the setup, utility, and training cells above to be run first. "
        "Run those cells, then run this cell again.\n"
        f"Missing objects: {missing_objects}"
    )