from pathlib import Path
from urllib.request import Request, urlopen
import json
import shutil
import zipfile
 
# Manifest URLs fetched by the setup loop below. Each manifest is a JSON
# document from which that loop reads "expected_paths", "archive_url", the
# optional "extract_to", and "name". The second-to-last URL segment ("books"
# here) is used to name the local "<name>-latest.json" / "<name>.zip" files.
RESOURCE_MANIFESTS = [
    "https://assets.deeplearningnotes.com/code-support-resources/datasets/books/latest.json",
]
 
def download_file(url, path, timeout=60):
    """Download ``url`` to ``path``, replacing the destination only on success.

    The response body is streamed into a sibling ``<name>.part`` file, which
    is renamed over ``path`` once the copy completes. An interrupted or failed
    transfer therefore never leaves a truncated file at ``path``.

    Args:
        url: Source URL (any scheme ``urllib.request`` supports).
        path: Destination file path (``str`` or ``Path``).
        timeout: Socket timeout in seconds passed to ``urlopen``; ``None``
            disables the timeout (blocking reads).

    Raises:
        OSError: Propagated from the request (``urllib.error.URLError`` is a
            subclass) or from local file I/O.
    """
    path = Path(path)
    partial_path = path.with_name(path.name + ".part")
    # Some hosts reject the default urllib User-Agent.
    request = Request(url, headers={"User-Agent": "Mozilla/5.0"})
    try:
        with urlopen(request, timeout=timeout) as response, open(partial_path, "wb") as file:
            shutil.copyfileobj(response, file)
        partial_path.replace(path)
    finally:
        # After a successful replace the .part file no longer exists; after a
        # failure this removes the incomplete download.
        partial_path.unlink(missing_ok=True)
 
def extract_zip_safely(zip_path, target_dir="."):
    """Extract every member of ``zip_path`` into ``target_dir`` after a path check.

    All member names are validated before anything is written. If any member
    would resolve outside ``target_dir`` (absolute names, ``..`` components),
    nothing is extracted.

    Args:
        zip_path: Path to the ZIP archive.
        target_dir: Destination directory; defaults to the current directory.

    Raises:
        RuntimeError: If a member would land outside ``target_dir``.
    """
    target_dir = Path(target_dir).resolve()

    with zipfile.ZipFile(zip_path, "r") as zip_ref:
        for member in zip_ref.infolist():
            target_path = (target_dir / member.filename).resolve()
            # Compare path components, not string prefixes: "/x/out2/f" starts
            # with the string "/x/out" but is not inside that directory.
            try:
                target_path.relative_to(target_dir)
            except ValueError:
                raise RuntimeError(f"Unsafe zip path: {member.filename}") from None

        zip_ref.extractall(target_dir)
 
for manifest_url in RESOURCE_MANIFESTS:
    # Manifest URLs end in ".../<name>/latest.json"; <name> labels local files.
    name = manifest_url.rstrip("/").split("/")[-2]
    manifest_path = Path(f"{name}-latest.json")
    archive_path = Path(f"{name}.zip")

    # The manifest is refreshed on every run; the archive only when some
    # expected path is absent.
    download_file(manifest_url, manifest_path)
    manifest = json.loads(manifest_path.read_text(encoding="utf-8"))
    expected_paths = list(map(Path, manifest.get("expected_paths", [])))

    if not all(map(Path.exists, expected_paths)):
        download_file(manifest["archive_url"], archive_path)
        extract_zip_safely(archive_path, manifest.get("extract_to", "."))

    # Re-check after a possible extraction: the archive must supply every path.
    missing = [str(expected) for expected in expected_paths if not expected.exists()]
    if missing:
        raise FileNotFoundError(f"Missing expected paths: {missing}")

    print(manifest["name"], "ready.")

Use local books

List the .txt corpus files available in the local ./datasets/books folder.

import os

books_dir = "./datasets/books"
if not os.path.isdir(books_dir):
    raise FileNotFoundError(f"Books folder not found: {books_dir}")

# Collect regular files with a .txt extension, in sorted name order.
available_book_files = []
for file_name in sorted(os.listdir(books_dir)):
    if not file_name.endswith(".txt"):
        continue
    if os.path.isfile(os.path.join(books_dir, file_name)):
        available_book_files.append(file_name)

if not available_book_files:
    raise FileNotFoundError(f"No .txt books found in {books_dir}")

print("Books found in ./datasets/books:")
for file_name in available_book_files:
    print(" - " + file_name)

Imports

# import the required packages
import os
import time
import copy
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import matplotlib.pyplot as plt
from IPython.display import clear_output
from tokenizers import ByteLevelBPETokenizer

Set hyperparameters and options

Set your hyperparameters here (they are used later in the code), so that you can run and compare experiments with different values.

Note

A better alternative would be to set hyperparameters and other options through command-line arguments (see the Python argparse module).

# hyperparameters (Transformer-only, quality-first)
# Every name below is a module-level global read by later cells (data loading,
# model construction, training, generation); edit here and re-run downstream.
batch_size = 32          # number of independent token streams per optimization step
learning_rate = 3e-4     # initial AdamW learning rate
epochs = 200              # maximum number of training epochs (training cell uses max(epochs, 60))
seq_length = 128         # context window used for next-token prediction during training
num_samples = 700        # default maximum number of generated tokens at inference
num_patience = 10        # early-stopping patience in epochs (training cell uses max(num_patience, 16))
min_delta = 1e-4         # minimum val-loss improvement to reset early-stopping counter
weight_decay = 1e-2      # decoupled weight-decay strength passed to AdamW
lr_factor = 0.5          # multiplicative LR drop when ReduceLROnPlateau triggers
lr_patience = 3          # epochs with no val-loss improvement before LR is reduced
input_token_dropout = 0.05  # probability of replacing an input token with <unk> during training

# transformer architecture defaults (used when building TransformerLanguageModel)
transformer_d_model = 384          # token/position embedding size and hidden width
transformer_nhead = 8              # number of self-attention heads per layer
transformer_num_layers = 4         # number of stacked Transformer encoder blocks
transformer_ff_mult = 4            # feed-forward expansion factor (ff_dim = ff_mult * d_model)
transformer_dropout = 0.15         # dropout applied inside Transformer blocks
transformer_max_seq_len = max(512, seq_length * 4)  # maximum context length supported by positional embeddings
transformer_label_smoothing = 0.05 # label smoothing used in cross-entropy during training

# generation defaults (aligned with current best config)
# These are fallbacks for generate_transformer_text when a sampling argument is None.
generation_seed = 'Deep Learning is'  # default prompt seed used in generation cells
temperature = 0.80       # softmax temperature: lower = safer, higher = more random
top_k = 60               # sample only from the k most likely tokens
top_p = 0.9              # nucleus sampling threshold (smallest token set with cumulative prob >= p)
repetition_penalty = 1.0 # >1 discourages repeating seen tokens; 1.0 disables this penalty
presence_penalty = 0.0   # subtracts a fixed penalty for tokens already seen in the generated text
frequency_penalty = 0.0  # subtracts a count-scaled penalty for frequently repeated tokens
min_gen_tokens = 80      # minimum generated length before allowing EOS-based early stop

# data options: hardcoded WikiText train/test corpora
# NOTE(review): the *test* split serves as the validation set, and training
# picks the best checkpoint / early-stops on it, so its metrics are optimistic
# as a held-out estimate. Consider a separate WikiText-2 validation file if one
# is available.
train_book_name = 'wikitext2-train.txt'
val_book_name = 'wikitext2-test.txt'
train_book_file = os.path.join(books_dir, train_book_name)
val_book_file = os.path.join(books_dir, val_book_name)
print('Training corpus:', train_book_name)
print('Validation corpus:', val_book_name)

# local checkpoint location
checkpoint_dir = os.path.abspath('transformer_results')
checkpoint_filename = 'transformer_best_model.pt'

device = "cuda" if torch.cuda.is_available() else "cpu"  # training/inference device

# print selected device (also GPU id if available)
print(f"Selected device: {device}")
if device == "cuda":
    print(f"GPU available: {torch.cuda.get_device_name(0)}")

Define the Transformer architecture

This notebook uses only the Transformer language model (the earlier LSTM model has been removed).

class TransformerLanguageModel(nn.Module):
    """Decoder-style language model built from a causally masked TransformerEncoder.

    Token embeddings and learned absolute position embeddings are summed, passed
    through pre-norm (``norm_first=True``) GELU encoder blocks under a causal
    mask, layer-normalized, and projected to vocabulary logits with a bias-free
    linear layer whose weight is shared with the token embedding.
    """

    def __init__(self, vocab_size, d_model=256, nhead=8, num_layers=4, ff_mult=4, dropout=0.2, max_seq_len=1024):
        super().__init__()
        # Submodules are created in a fixed order: this fixes the RNG draws at
        # initialization and the state_dict keys that saved checkpoints rely on.
        self.token_emb = nn.Embedding(vocab_size, d_model)
        self.pos_emb = nn.Embedding(max_seq_len, d_model)
        self.drop = nn.Dropout(dropout)

        self.encoder = nn.TransformerEncoder(
            nn.TransformerEncoderLayer(
                d_model=d_model,
                nhead=nhead,
                dim_feedforward=d_model * ff_mult,
                dropout=dropout,
                activation='gelu',
                batch_first=True,
                norm_first=True,
            ),
            num_layers=num_layers,
        )
        self.norm = nn.LayerNorm(d_model)

        # Output projection shares its weight matrix with the input embedding.
        self.fc = nn.Linear(d_model, vocab_size, bias=False)
        self.fc.weight = self.token_emb.weight

        self.vocab_size = vocab_size
        self.max_seq_len = max_seq_len

    def _causal_mask(self, seq_len, device):
        """Boolean (seq_len, seq_len) mask; True above the diagonal blocks future positions."""
        upper = torch.ones(seq_len, seq_len, dtype=torch.bool, device=device)
        return upper.triu(diagonal=1)

    def forward(self, x):
        """Map token ids of shape (batch, seq) to logits of shape (batch * seq', vocab).

        Inputs longer than ``max_seq_len`` keep only their last ``max_seq_len``
        positions, so ``seq'`` is ``min(seq, max_seq_len)``.
        """
        batch, length = x.shape
        if length > self.max_seq_len:
            x = x[:, -self.max_seq_len:]
            length = x.size(1)

        positions = torch.arange(length, device=x.device).expand(batch, length)
        hidden = self.drop(self.token_emb(x) + self.pos_emb(positions))
        hidden = self.encoder(hidden, mask=self._causal_mask(length, x.device))
        logits = self.fc(self.norm(hidden))
        return logits.reshape(-1, self.vocab_size)

Create datasets

Build the tokenizer and the train/validation token tensors.

# use hardcoded WikiText train/test corpora
missing_corpus_files = [
    path for path in (train_book_file, val_book_file) if not os.path.exists(path)
]
if missing_corpus_files:
    raise FileNotFoundError(
        'Required corpus files not found in ./datasets/books: '
        f'{missing_corpus_files}. Add wikitext2-train.txt and wikitext2-test.txt.'
    )

print(f'Using training corpus: {os.path.basename(train_book_file)}')
print(f'Using validation corpus: {os.path.basename(val_book_file)}')

# Fit the byte-level BPE vocabulary on the training split only, so the
# validation text does not influence which merges are learned.
tokenizer = ByteLevelBPETokenizer()
tokenizer.train(
    files=[train_book_file],
    special_tokens=['<pad>', '<eos>', '<unk>'],
    vocab_size=5000,
    min_frequency=2,
)

eos_token_id, unk_token_id = (tokenizer.token_to_id(token) for token in ('<eos>', '<unk>'))
 
def _load_token_grid(file_path, split_name):
    """Tokenize one corpus split and fold it into a ``(batch_size, n)`` grid.

    Non-blank lines are stripped and joined with ``' <eos> '`` (plus a trailing
    ``' <eos>'``), then encoded with the module-level ``tokenizer``. The token
    sequence is trimmed to a multiple of the module-level ``batch_size`` and
    viewed as ``batch_size`` rows, each a contiguous segment of the stream.

    Args:
        file_path: Path to a UTF-8 text file.
        split_name: Label used in progress messages and error text.

    Returns:
        A ``torch.long`` tensor of shape ``(batch_size, n)``.

    Raises:
        ValueError: If the file has no non-blank lines, or encodes to fewer
            than ``batch_size`` tokens.
    """
    with open(file_path, 'r', encoding='utf-8') as f:
        split_lines = [text for text in (line.strip() for line in f) if text]

    if not split_lines:
        raise ValueError(f'{split_name} corpus is empty: {file_path}')

    split_text = ' <eos> '.join(split_lines) + ' <eos>'
    split_ids = torch.tensor(tokenizer.encode(split_text).ids, dtype=torch.long)

    print(f'{split_name} original ids are {split_ids.size(0)}')
    # Exact integer division (no float round-trip) gives the number of full columns.
    num_full_batches = split_ids.size(0) // batch_size
    if num_full_batches == 0:
        raise ValueError(
            f'{split_name} corpus is too small for batch_size={batch_size}. '
            'Choose a smaller batch size or a larger corpus.'
        )

    split_ids = split_ids[:num_full_batches * batch_size]
    token_grid = split_ids.view(batch_size, -1)
    print(f'{split_name} trimmed  ids are {split_ids.size(0)} (now multiple of batch size {batch_size})')
    print(f'{split_name} reshape  ids from {split_ids.shape} to {token_grid.shape}')
    return token_grid
 
train_ids, val_ids = (
    _load_token_grid(train_book_file, 'train'),
    _load_token_grid(val_book_file, 'val'),
)

vocab_size = tokenizer.get_vocab_size()
print(f'tokenizer vocab size is {vocab_size}')

# Each window needs seq_length input columns plus one shifted target column;
# at least one window per epoch is always scheduled.
num_train_batches = max(1, (train_ids.size(1) - 1) // seq_length)
num_val_batches = max(1, (val_ids.size(1) - 1) // seq_length)
print(f'train shape: {tuple(train_ids.shape)} | val shape: {tuple(val_ids.shape)}')
print(f'num_train_batches: {num_train_batches} | num_val_batches: {num_val_batches}')

Language-model utilities

Define model-agnostic decoding and generation helpers (usable beyond Transformers).

def _get_banned_tokens_ngram(generated_ids, ngram_size):
    if len(generated_ids) < ngram_size - 1 or ngram_size < 2:
        return set()
    prefix = tuple(generated_ids[-(ngram_size - 1):])
    banned = set()
    for i in range(len(generated_ids) - ngram_size + 1):
        if tuple(generated_ids[i:i + ngram_size - 1]) == prefix:
            banned.add(generated_ids[i + ngram_size - 1])
    return banned
 
 
def _sample_next_token_with_constraints(
    logits,
    generated_ids,
    temperature=0.9,
    top_k=50,
    top_p=0.9,
    repetition_penalty=1.18,
    presence_penalty=0.20,
    frequency_penalty=0.25,
    no_repeat_ngram_size=4,
    max_same_token_run=2
):
    if len(generated_ids) > 0 and (presence_penalty > 0 or frequency_penalty > 0):
        token_counts = {}
        for token_id in generated_ids:
            token_counts[token_id] = token_counts.get(token_id, 0) + 1
        penalty_ids = torch.tensor(list(token_counts.keys()), device=logits.device, dtype=torch.long)
        counts = torch.tensor([token_counts[t] for t in token_counts.keys()], device=logits.device, dtype=logits.dtype)
        logits[penalty_ids] = logits[penalty_ids] - presence_penalty - frequency_penalty * counts
 
    if repetition_penalty != 1.0 and len(generated_ids) > 0:
        unique_tokens = set(generated_ids)
        token_ids = torch.tensor(list(unique_tokens), device=logits.device, dtype=torch.long)
        penalized_logits = logits[token_ids]
        penalized_logits = torch.where(
            penalized_logits < 0,
            penalized_logits * repetition_penalty,
            penalized_logits / repetition_penalty
        )
        logits[token_ids] = penalized_logits
 
    if len(generated_ids) >= max_same_token_run:
        recent = generated_ids[-max_same_token_run:]
        if all(tok == recent[0] for tok in recent):
            logits[recent[0]] = -1e10
 
    for banned_id in _get_banned_tokens_ngram(generated_ids, no_repeat_ngram_size):
        logits[banned_id] = -1e10
 
    logits = logits / max(temperature, 1e-6)
 
    if top_k is not None and top_k > 0:
        topk_vals, _ = torch.topk(logits, min(top_k, logits.size(-1)))
        kth = topk_vals[-1]
        logits = torch.where(logits < kth, torch.full_like(logits, -1e10), logits)
 
    if top_p is not None and 0 < top_p < 1.0:
        sorted_logits, sorted_indices = torch.sort(logits, descending=True)
        sorted_probs = torch.softmax(sorted_logits, dim=-1)
        cumulative_probs = torch.cumsum(sorted_probs, dim=-1)
        remove_mask = cumulative_probs > top_p
        remove_mask[0] = False
        sorted_logits = sorted_logits.masked_fill(remove_mask, -1e10)
        logits_filtered = torch.full_like(logits, -1e10)
        logits_filtered[sorted_indices] = sorted_logits
        logits = logits_filtered
 
    probs = torch.softmax(logits, dim=-1)
    return torch.multinomial(probs, num_samples=1).item()
 
 
def generate_transformer_text(
    model,
    seed='Winter is coming',
    max_new_tokens=None,
    temperature=None,
    top_k=None,
    top_p=None,
    repetition_penalty=None,
    presence_penalty=None,
    frequency_penalty=None,
    no_repeat_ngram_size=4,
    max_same_token_run=2,
    context_window=None,
    blocked_token_ids=None,
    block_special_tokens=True
):
    """Autoregressively sample text from ``model`` starting at ``seed``.

    Sampling options left as ``None`` fall back to the module-level value of the
    same name, or to the literal default shown below when that global is
    missing. ``max_new_tokens`` falls back to ``num_samples``. The prompt is
    encoded with the module-level ``tokenizer``. At each step only the last
    ``context_window`` ids are fed to the model, blocked ids are masked, and
    the next id is drawn by ``_sample_next_token_with_constraints``.

    Args:
        model: Module mapping ``(1, seq)`` token ids to per-position logits,
            returned either flat ``(seq, vocab)`` or as ``(1, seq, vocab)``.
            It must expose ``max_seq_len``. It is switched to eval mode.
        seed: Prompt text. An empty encoding is replaced by a single ``<unk>``.
        max_new_tokens: Number of sampling steps.
        temperature, top_k, top_p, repetition_penalty, presence_penalty,
        frequency_penalty, no_repeat_ngram_size, max_same_token_run:
            Forwarded to ``_sample_next_token_with_constraints``.
        context_window: Maximum number of trailing ids fed to the model. It is
            clamped to ``[1, model.max_seq_len]``. Defaults to
            ``min(seq_length, model.max_seq_len)``.
        blocked_token_ids: Extra ids that are never sampled.
        block_special_tokens: Also block ``<pad>`` and ``<unk>``.

    Returns:
        Decoded prompt + continuation with ``<unk>``/``<pad>`` removed,
        ``<eos>`` rendered as a newline, and outer whitespace stripped.
        Generation stops early when ``<eos>`` is sampled at step
        ``>= min_gen_tokens``.
    """
    if max_new_tokens is None:
        max_new_tokens = num_samples
    if temperature is None:
        temperature = globals().get('temperature', 0.82)
    if top_k is None:
        top_k = globals().get('top_k', 60)
    if top_p is None:
        top_p = globals().get('top_p', 0.9)
    if repetition_penalty is None:
        repetition_penalty = globals().get('repetition_penalty', 1.0)
    if presence_penalty is None:
        presence_penalty = globals().get('presence_penalty', 0.0)
    if frequency_penalty is None:
        frequency_penalty = globals().get('frequency_penalty', 0.0)

    if context_window is None:
        context_window = min(seq_length, model.max_seq_len)
    else:
        context_window = min(max(1, int(context_window)), model.max_seq_len)

    blocked_token_ids = set(blocked_token_ids or [])
    if block_special_tokens:
        pad_token_id = tokenizer.token_to_id('<pad>')
        if pad_token_id is not None:
            blocked_token_ids.add(pad_token_id)
        if unk_token_id is not None:
            blocked_token_ids.add(unk_token_id)

    # Feed inputs on the device the model's weights live on. Fall back to the
    # notebook-level `device` for a parameter-free model.
    first_param = next(model.parameters(), None)
    model_device = first_param.device if first_param is not None else device
    # Built once instead of looping over the blocked ids at every step.
    blocked_index = torch.tensor(sorted(blocked_token_ids), dtype=torch.long, device=model_device)

    model.eval()
    with torch.no_grad():
        seed_ids = tokenizer.encode(seed).ids
        if len(seed_ids) == 0:
            seed_ids = [unk_token_id]

        generated_ids = list(seed_ids)

        for step in range(max_new_tokens):
            context_ids = generated_ids[-context_window:]
            x = torch.tensor(context_ids, device=model_device).unsqueeze(0)
            output = model(x)
            # Use the model's own output width rather than the global vocab
            # size, so a checkpoint with a different vocabulary still works.
            logits = output.reshape(-1, output.size(-1))[-1].clone()

            if blocked_index.numel() > 0:
                logits[blocked_index] = -1e10

            next_id = _sample_next_token_with_constraints(
                logits,
                generated_ids,
                temperature=temperature,
                top_k=top_k,
                top_p=top_p,
                repetition_penalty=repetition_penalty,
                presence_penalty=presence_penalty,
                frequency_penalty=frequency_penalty,
                no_repeat_ngram_size=no_repeat_ngram_size,
                max_same_token_run=max_same_token_run
            )
            generated_ids.append(next_id)

            if next_id == eos_token_id and step >= min_gen_tokens:
                break

        text = tokenizer.decode(generated_ids, skip_special_tokens=False)
        text = text.replace('<unk>', '').replace('<eos>', '\n').replace('<pad>', '').strip()
        return text

Train the Transformer

First run the next cell to define the training function. Then run the following cell to launch the training experiment, which saves the best checkpoint in the transformer_results directory.

def train_transformer_baseline(
    model,
    train_ids,
    val_ids,
    epochs_to_run=12,
    early_stop_patience=None,
    label_smoothing=0.05,
    checkpoint_path='transformer_best_model.pt'
):
    """Train ``model`` with AdamW + ReduceLROnPlateau, keeping the best-val weights.

    Each epoch does the following:
      1. Takes a chunk of ``num_train_batches * seq_length + 1`` columns from
         ``train_ids`` at a random start offset, and a matching chunk from
         ``val_ids``.
      2. Runs one pass of non-overlapping ``seq_length`` windows on the train
         chunk, using label-smoothed cross-entropy and input-token dropout to
         ``<unk>``.
      3. Re-measures plain (unsmoothed) NLL and perplexity on both chunks.

    When validation NLL improves by more than ``min_delta``, the weights are
    deep-copied and saved to ``checkpoint_path``. Training stops after
    ``patience_limit`` epochs without improvement. The best weights are
    restored before returning.

    The function also reads these module-level settings: ``learning_rate``,
    ``weight_decay``, ``lr_factor``, ``lr_patience``, ``num_patience``,
    ``min_delta``, ``seq_length``, ``device``, ``input_token_dropout``,
    ``unk_token_id``, ``num_train_batches``, ``num_val_batches``,
    ``vocab_size``, and the ``transformer_*`` architecture values written into
    the checkpoint.

    Args:
        model: Language model mapping ``(batch, seq)`` ids to
            ``(batch * seq, vocab)`` logits. It is moved to ``device``.
        train_ids, val_ids: ``(batch, n)`` token grids.
        epochs_to_run: Maximum number of epochs.
        early_stop_patience: Epochs without improvement before stopping;
            ``None`` uses ``num_patience``.
        label_smoothing: Label smoothing for the training loss only.
        checkpoint_path: Where the best checkpoint is saved. Parent
            directories are created as needed.

    Returns:
        Dict with the restored ``model``, per-epoch histories
        (``train_objective_losses``, ``train_losses``, ``val_losses``,
        ``train_perplexities``, ``val_perplexities``), ``best_val_loss``,
        ``best_val_perplexity``, ``best_epoch``, and the absolute
        ``checkpoint_path``.

    Raises:
        ValueError: If either grid has ``seq_length`` columns or fewer, so not
            even one (input, shifted-target) window fits.
    """
    # Without this guard the window loops below run zero steps and report a
    # loss of 0 (perplexity 1), which early stopping would treat as ideal.
    for grid_name, grid in (('train_ids', train_ids), ('val_ids', val_ids)):
        if grid.size(1) <= seq_length:
            raise ValueError(
                f'{grid_name} has {grid.size(1)} columns; need more than '
                f'seq_length={seq_length} to form one (input, target) window.'
            )

    checkpoint_path = os.path.abspath(checkpoint_path)
    checkpoint_parent = os.path.dirname(checkpoint_path)
    if checkpoint_parent:
        os.makedirs(checkpoint_parent, exist_ok=True)

    criterion_train = nn.CrossEntropyLoss(label_smoothing=label_smoothing)
    criterion_eval = nn.CrossEntropyLoss()
    optimizer_t = torch.optim.AdamW(model.parameters(), lr=learning_rate, weight_decay=weight_decay)
    scheduler_t = torch.optim.lr_scheduler.ReduceLROnPlateau(
        optimizer_t, mode='min', factor=lr_factor, patience=lr_patience, min_lr=1e-5
    )

    train_objective_hist = []
    train_hist, val_hist = [], []
    train_ppl_hist, val_ppl_hist = [], []
    best_val = float('inf')
    best_val_perplexity = float('inf')
    best_epoch_t = 0
    patience_counter_t = 0
    best_state = None
    patience_limit = num_patience if early_stop_patience is None else early_stop_patience

    model.to(device)

    def _evaluate_chunk(token_grid):
        """Return (mean unsmoothed NLL, perplexity) over every full window of ``token_grid``."""
        loss_sum = 0.0
        steps = 0

        model.eval()
        with torch.no_grad():
            # A window at column i uses columns i .. i + seq_length (inputs plus
            # the shifted target), so the exclusive stop is width - seq_length.
            for i in range(0, token_grid.size(1) - seq_length, seq_length):
                inputs = token_grid[:, i:i + seq_length].to(device)
                targets = token_grid[:, (i + 1):(i + 1) + seq_length].to(device)
                logits = model(inputs)
                loss = criterion_eval(logits, targets.reshape(-1))
                loss_sum += loss.item()
                steps += 1

        avg_loss = loss_sum / max(1, steps)
        # Avoid overflow in exp for very large losses.
        perplexity = float(np.exp(avg_loss)) if avg_loss < 20 else float('inf')
        return avg_loss, perplexity

    for epoch in range(1, epochs_to_run + 1):
        t0 = time.time()

        # Random offset within the leftover columns varies window boundaries per epoch.
        train_max_offset = max(0, train_ids.size(1) - (num_train_batches * seq_length + 1))
        train_offset = np.random.randint(0, train_max_offset + 1) if train_max_offset > 0 else 0
        train_end = train_offset + num_train_batches * seq_length + 1
        train_chunk = train_ids[:, train_offset:train_end]

        val_max_offset = max(0, val_ids.size(1) - (num_val_batches * seq_length + 1))
        val_offset = np.random.randint(0, val_max_offset + 1) if val_max_offset > 0 else 0
        val_end = val_offset + num_val_batches * seq_length + 1
        val_chunk = val_ids[:, val_offset:val_end]

        model.train()
        train_objective_sum = 0.0
        train_steps = 0

        # Same window bounds as _evaluate_chunk: the last full window is included.
        for i in range(0, train_chunk.size(1) - seq_length, seq_length):
            inputs = train_chunk[:, i:i + seq_length].to(device)
            targets = train_chunk[:, (i + 1):(i + 1) + seq_length].to(device)

            if input_token_dropout > 0:
                dropout_mask = torch.rand_like(inputs.float()) < input_token_dropout
                inputs = inputs.masked_fill(dropout_mask, unk_token_id)

            optimizer_t.zero_grad()
            logits = model(inputs)
            loss = criterion_train(logits, targets.reshape(-1))
            loss.backward()
            nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            optimizer_t.step()

            train_objective_sum += loss.item()
            train_steps += 1

        epoch_train_objective_loss = train_objective_sum / max(1, train_steps)
        epoch_train_loss, epoch_train_perplexity = _evaluate_chunk(train_chunk)
        epoch_val_loss, epoch_val_perplexity = _evaluate_chunk(val_chunk)

        scheduler_t.step(epoch_val_loss)
        current_lr = optimizer_t.param_groups[0]['lr']

        train_objective_hist.append(epoch_train_objective_loss)
        train_hist.append(epoch_train_loss)
        val_hist.append(epoch_val_loss)
        train_ppl_hist.append(epoch_train_perplexity)
        val_ppl_hist.append(epoch_val_perplexity)

        improved = epoch_val_loss < (best_val - min_delta)
        if improved:
            best_val = epoch_val_loss
            best_val_perplexity = epoch_val_perplexity
            best_epoch_t = epoch
            patience_counter_t = 0
            best_state = copy.deepcopy(model.state_dict())

            checkpoint_payload = {
                'model_state_dict': best_state,
                'best_val_loss': best_val,
                'best_val_perplexity': best_val_perplexity,
                'best_epoch': best_epoch_t,
                'vocab_size': vocab_size,
                'model_config': {
                    'd_model': transformer_d_model,
                    'nhead': transformer_nhead,
                    'num_layers': transformer_num_layers,
                    'ff_mult': transformer_ff_mult,
                    'dropout': transformer_dropout,
                    'max_seq_len': transformer_max_seq_len
                }
            }
            torch.save(checkpoint_payload, checkpoint_path)
            print(
                f"[Transformer] New best model saved to {checkpoint_path} "
                f"(val_nll={best_val:.4f}, val_ppl={best_val_perplexity:.2f}, epoch={best_epoch_t})."
            )
        else:
            patience_counter_t += 1

        print(
            f"[Transformer] Epoch {epoch:02d}/{epochs_to_run} | time: {time.time()-t0:.1f}s | "
            f"train_obj: {epoch_train_objective_loss:.4f} | train_nll: {epoch_train_loss:.4f} | "
            f"train_ppl: {epoch_train_perplexity:.2f} | val_nll: {epoch_val_loss:.4f} | "
            f"val_ppl: {epoch_val_perplexity:.2f} | best_val_ppl: {best_val_perplexity:.2f} | "
            f"lr: {current_lr:.6f} | patience: {patience_counter_t}/{patience_limit}"
        )

        if patience_counter_t >= patience_limit:
            print(f"[Transformer] Early stopping at epoch {epoch}.")
            break

    if best_state is not None:
        model.load_state_dict(best_state)
        print(
            f"[Transformer] Restored best model from epoch {best_epoch_t} "
            f"(val_nll={best_val:.4f}, val_ppl={best_val_perplexity:.2f})."
        )

    return {
        'model': model,
        'train_objective_losses': train_objective_hist,
        'train_losses': train_hist,
        'val_losses': val_hist,
        'train_perplexities': train_ppl_hist,
        'val_perplexities': val_ppl_hist,
        'best_val_loss': best_val,
        'best_val_perplexity': best_val_perplexity,
        'best_epoch': best_epoch_t,
        'checkpoint_path': checkpoint_path
    }
 
# training experiment block (no generation)
# Both limits are floored here: at least 60 epochs and at least 16 epochs of
# early-stopping patience, whatever smaller values the options cell holds.
transformer_epochs = max(epochs, 60)
transformer_patience = max(num_patience, 16)

# transformer_checkpoint_path is left as a global; the generation cell picks it
# up via globals() and falls back to checkpoint_dir/checkpoint_filename otherwise.
os.makedirs(checkpoint_dir, exist_ok=True)
transformer_checkpoint_path = os.path.join(checkpoint_dir, checkpoint_filename)
print(f'Current working directory: {os.getcwd()}')
print(f'Checkpoint target path: {transformer_checkpoint_path}')

transformer_model = TransformerLanguageModel(
    vocab_size=vocab_size,
    d_model=transformer_d_model,
    nhead=transformer_nhead,
    num_layers=transformer_num_layers,
    ff_mult=transformer_ff_mult,
    dropout=transformer_dropout,
    max_seq_len=transformer_max_seq_len
)
# parameters() yields the tied embedding/output weight only once.
num_trainable_params = sum(
    parameter.numel() for parameter in transformer_model.parameters()
    if parameter.requires_grad
)
print(f'Trainable parameters: {num_trainable_params:,} ({num_trainable_params / 1e6:.2f}M)')

# Returns the best-restored model plus per-epoch histories and best metrics.
transformer_results = train_transformer_baseline(
    transformer_model,
    train_ids=train_ids,
    val_ids=val_ids,
    epochs_to_run=transformer_epochs,
    early_stop_patience=transformer_patience,
    label_smoothing=transformer_label_smoothing,
    checkpoint_path=transformer_checkpoint_path
)

# Left panel: per-epoch train/val NLL as returned in 'train_losses' /
# 'val_losses'. Right panel: the matching perplexities. The x axis is the
# 0-based index into those lists.
fig, axes = plt.subplots(1, 2, figsize=(13, 4))
axes[0].plot(transformer_results['train_losses'], marker='o', linewidth=2, label='Train NLL')
axes[0].plot(transformer_results['val_losses'], marker='o', linewidth=2, label='Validation NLL')
axes[0].set_title('Transformer: Evaluation Loss')
axes[0].set_xlabel('Epoch')
axes[0].set_ylabel('NLL')
axes[0].grid(alpha=0.3)
axes[0].legend()

axes[1].plot(transformer_results['train_perplexities'], marker='o', linewidth=2, label='Train PPL')
axes[1].plot(transformer_results['val_perplexities'], marker='o', linewidth=2, label='Validation PPL')
axes[1].set_title('Transformer: Perplexity')
axes[1].set_xlabel('Epoch')
axes[1].set_ylabel('Perplexity')
axes[1].grid(alpha=0.3)
axes[1].legend()

plt.tight_layout()
plt.show()

print(f"Best Transformer val NLL: {transformer_results['best_val_loss']:.4f} at epoch {transformer_results['best_epoch']}")
print(f"Best Transformer val perplexity: {transformer_results['best_val_perplexity']:.2f}")
print(f"Best checkpoint saved at: {transformer_results['checkpoint_path']}")

Generate text (no retraining)

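Run this cell any time after training to generate new text from the saved Transformer checkpoint, which is loaded from disk (no retraining). The tokenizer built earlier in this session must still be in memory, because generation encodes and decodes with it.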

# generation-only block (no training)
# Rebuilds the model from the saved checkpoint and samples one text. It reuses
# the session's `tokenizer`. The checkpoint holds weights, metrics and the
# model config, but no tokenizer, so the tokenizer must be the one used in
# training.
default_checkpoint_path = os.path.join(os.path.abspath(checkpoint_dir), checkpoint_filename)

# Prefer the path set by the training cell in this session, if it ran.
transformer_checkpoint_path = globals().get('transformer_checkpoint_path', default_checkpoint_path)
transformer_checkpoint_path = os.path.abspath(transformer_checkpoint_path)

print(f'Current working directory: {os.getcwd()}')
print(f'Looking for checkpoint at: {transformer_checkpoint_path}')

if not os.path.exists(transformer_checkpoint_path):
    raise RuntimeError(
        f"Checkpoint not found at {transformer_checkpoint_path!r}. Run training first to save the best model."
    )

checkpoint = torch.load(transformer_checkpoint_path, map_location=device)
checkpoint_config = checkpoint.get('model_config', {})

# Architecture comes from the checkpoint; the options-cell values are only
# fallbacks for checkpoints saved without a model_config.
model_for_generation = TransformerLanguageModel(
    vocab_size=checkpoint.get('vocab_size', vocab_size),
    d_model=checkpoint_config.get('d_model', transformer_d_model),
    nhead=checkpoint_config.get('nhead', transformer_nhead),
    num_layers=checkpoint_config.get('num_layers', transformer_num_layers),
    ff_mult=checkpoint_config.get('ff_mult', transformer_ff_mult),
    dropout=checkpoint_config.get('dropout', transformer_dropout),
    max_seq_len=checkpoint_config.get('max_seq_len', transformer_max_seq_len)
).to(device)
model_for_generation.load_state_dict(checkpoint['model_state_dict'])
model_for_generation.eval()

# The prompt is set only by sample_seed below; generation_seed from the
# options cell is not read here.
sample_seed = 'The Mediterranean Sea'  # default seed prompt for generation
sample_tokens = min(num_samples, 220)

# These cell-local sampling settings are passed explicitly, so they take
# precedence over temperature/top_k/top_p/... from the options cell.
generation_context_window = min(seq_length, model_for_generation.max_seq_len)
generation_temperature = 0.65
generation_top_k = 20
generation_top_p = 0.85
generation_repetition_penalty = 1.08
generation_presence_penalty = 0.05
generation_frequency_penalty = 0.05

sample_text = generate_transformer_text(
    model_for_generation,
    seed=sample_seed,
    max_new_tokens=sample_tokens,
    temperature=generation_temperature,
    top_k=generation_top_k,
    top_p=generation_top_p,
    repetition_penalty=generation_repetition_penalty,
    presence_penalty=generation_presence_penalty,
    frequency_penalty=generation_frequency_penalty,
    context_window=generation_context_window
)

print(f"Loaded checkpoint: {transformer_checkpoint_path}")
if 'best_val_loss' in checkpoint and 'best_epoch' in checkpoint:
    print(f"Checkpoint metrics -> best_val_loss: {checkpoint['best_val_loss']:.4f}, best_epoch: {checkpoint['best_epoch']}")
if 'best_val_perplexity' in checkpoint:
    print(f"Checkpoint perplexity -> best_val_perplexity: {checkpoint['best_val_perplexity']:.2f}")

print(
    f"Generation config -> context_window: {generation_context_window}, "
    f"temperature: {generation_temperature}, top_k: {generation_top_k}, top_p: {generation_top_p}"
)
print(f'Seed: {sample_seed!r}')
print(f'Generated tokens (max): {sample_tokens}')
print('\n--- SAMPLE START ---\n')
print(sample_text)
print('\n--- SAMPLE END ---')