from pathlib import Path
from urllib.request import Request, urlopen
import json
import shutil
import zipfile
 
# Manifest URLs to process. Each manifest is a JSON document from which the loop
# below reads "expected_paths", "archive_url", optional "extract_to", and "name".
RESOURCE_MANIFESTS = ["https://assets.deeplearningnotes.com/code-support-resources/datasets/books/latest.json"]
 
def download_file(url, path):
    """Stream the resource at *url* into the local file *path* (overwritten if present).

    A browser-like User-Agent header is attached to the request (presumably
    because some hosts reject urllib's default agent).
    """
    browser_headers = {"User-Agent": "Mozilla/5.0"}
    req = Request(url, headers=browser_headers)
    with urlopen(req) as remote:
        with open(path, "wb") as local:
            # copy in chunks so large archives never sit fully in memory
            shutil.copyfileobj(remote, local)
 
def extract_zip_safely(zip_path, target_dir="."):
    """Extract every member of *zip_path* into *target_dir*, rejecting path escapes.

    All members are validated before anything is written, so an archive with a
    malicious entry leaves the target directory untouched.

    Args:
        zip_path: path to the .zip archive.
        target_dir: directory to extract into (created by extractall if missing).

    Raises:
        RuntimeError: if any member path resolves outside *target_dir*
            (e.g. "../x" or an absolute path).
    """
    target_dir = Path(target_dir).resolve()

    with zipfile.ZipFile(zip_path, "r") as zip_ref:
        for member in zip_ref.infolist():
            target_path = (target_dir / member.filename).resolve()
            # Compare path components, not string prefixes: "/data/out" is a
            # string prefix of "/data/outside.txt" but does not contain it.
            try:
                target_path.relative_to(target_dir)
            except ValueError:
                raise RuntimeError(f"Unsafe zip path: {member.filename}") from None

        zip_ref.extractall(target_dir)
 
for manifest_url in RESOURCE_MANIFESTS:
    # dataset name = the URL path segment just before the manifest file name
    name = manifest_url.rstrip("/").rsplit("/", 2)[-2]
    manifest_path = Path(f"{name}-latest.json")
    archive_path = Path(f"{name}.zip")

    # always refresh the manifest, then parse it
    download_file(manifest_url, manifest_path)
    with manifest_path.open(encoding="utf-8") as manifest_file:
        manifest = json.load(manifest_file)

    expected_paths = list(map(Path, manifest.get("expected_paths", [])))

    # fetch and unpack the archive only when at least one expected path is absent
    # (note: a manifest that lists no "expected_paths" never triggers the download)
    if any(not path.exists() for path in expected_paths):
        download_file(manifest["archive_url"], archive_path)
        extract_zip_safely(archive_path, manifest.get("extract_to", "."))

    # verify that extraction actually produced everything the manifest promised
    missing = [str(path) for path in expected_paths if not path.exists()]
    if missing:
        raise FileNotFoundError(f"Missing expected paths: {missing}")

    print(f"{manifest['name']} ready.")

Download books

Choose your corpus

# use local books from ./datasets/books
import os

books_dir = "./datasets/books"
if not os.path.isdir(books_dir):
    raise FileNotFoundError(f"Books folder not found: {books_dir}")

# every regular file with a .txt extension counts as a book, listed alphabetically
book_files = sorted(
    entry
    for entry in os.listdir(books_dir)
    if entry.endswith(".txt") and os.path.isfile(os.path.join(books_dir, entry))
)

if not book_files:
    raise FileNotFoundError(f"No .txt books found in {books_dir}")

print("Books found in ./datasets/books:")
print("\n".join(f" - {entry}" for entry in book_files))

Imports

# import the required packages
import os
import time
import torch
import torch.nn as nn
import numpy as np

Set hyperparameters and options

Set your hyperparameters here (they are used later in the code), so that you can run and compare different experiments by changing only these values.

Note

A better alternative would be to set hyperparameters and other options through command-line arguments (see the Python `argparse` module).

# hyperparameters
batch_size = 32
learning_rate = 0.002
epochs = 20
hidden_neurons = 512    # hidden size per direction in the bidirectional LSTM
embed_size = 128        # embedding size
num_layers = 1          # number of LSTM layers
attention_neurons = 256 # hidden size of the attention scoring network
dropout_prob = 0.2
seq_length = 30         # length of the context window used to predict the next word
num_samples = 1000      # number of words to be sampled at testing phase

# options
training_set = os.path.join(books_dir, "grimms_tales.txt")
device = "cuda:0"   # requested device: "cuda:0", "mps" or "cpu" (unavailable accelerators fall back below)

# Resolve the requested device: a CUDA request falls back to Apple MPS and then CPU,
# an MPS request falls back to CPU, so the notebook runs on any machine.
requested_device = device

if requested_device == "cuda:0":
    if torch.cuda.is_available():
        torch.cuda.set_device(0)
        device = "cuda:0"
    elif hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
        device = "mps"
        print("CUDA is not available. Falling back to Apple MPS.")
    else:
        device = "cpu"
        print("CUDA is not available. Falling back to CPU.")
elif requested_device == "mps":
    if hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
        device = "mps"
    else:
        device = "cpu"
        print("MPS is not available. Falling back to CPU.")
elif requested_device == "cpu":
    device = "cpu"
else:
    raise ValueError(f"Unsupported requested device '{requested_device}'.")

# quick device sanity check: allocate and compute on the chosen device.
# Raise explicitly instead of using `assert`, which is skipped when Python runs with -O.
_device_test = torch.tensor([1.0], device=device)
_device_test = _device_test * 2.0
if _device_test.item() != 2.0:
    raise RuntimeError(f"Device sanity check failed on {device}.")

if device == "cuda:0":
    print(f"Device check OK | device={device} | name={torch.cuda.get_device_name(0)}")
else:
    print(f"Device check OK | device={device}")

Define the model architecture

The network now uses a bidirectional LSTM to encode each context window of seq_length words. A gated soft-attention block learns one score for each hidden-state pair, normalizes the scores with a softmax, computes a weighted linear combination of all hidden states, concatenates this context vector with the last hidden state of the top bidirectional layer, and feeds the result to the next-word classifier.

class LanguageModel(nn.Module):
    """Next-word predictor: bidirectional LSTM encoder, gated soft attention, MLP head.

    Args:
        vocab_size: number of distinct token ids (embedding rows and output logits).
        embed_size: dimension of the word embeddings.
        hidden_size: LSTM hidden size per direction; every time step therefore
            yields 2 * hidden_size features.
        num_layers: number of stacked LSTM layers.
        attention_size: width of the tanh layer inside the attention scorer.
        dropout: dropout probability applied to the embeddings, between LSTM
            layers (only when num_layers > 1), and inside the classifier.
    """

    def __init__(self, vocab_size, embed_size, hidden_size, num_layers, attention_size, dropout=0.0):
        super().__init__()

        # PyTorch only applies LSTM dropout between stacked layers
        between_layer_dropout = dropout if num_layers > 1 else 0.0
        pair_size = 2 * hidden_size  # forward state concatenated with backward state

        # NOTE: attribute names and creation order are kept stable so that
        # state_dict keys and random initialization stay unchanged.
        self.embed = nn.Embedding(vocab_size, embed_size)
        self.dropout = nn.Dropout(dropout)
        self.lstm = nn.LSTM(
            embed_size,
            hidden_size,
            num_layers,
            batch_first=True,
            bidirectional=True,
            dropout=between_layer_dropout,
        )

        # attention scorer: sigmoid gate -> tanh projection -> scalar score per time step
        self.attention_gate = nn.Linear(pair_size, pair_size)
        self.attention_hidden = nn.Linear(pair_size, attention_size)
        self.attention_score = nn.Linear(attention_size, 1)

        # classifier input is [final hidden pair ; attention context]
        self.classifier = nn.Sequential(
            nn.Linear(2 * pair_size, pair_size),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(pair_size, vocab_size),
        )

    def forward(self, x):
        """Score the next word after each context window in the batch.

        Args:
            x: token ids, presumably of shape (batch, seq_len) given batch_first=True.

        Returns:
            (logits, attention_weights): logits of shape (batch, vocab_size) and
            one softmax-normalized weight per time step, shape (batch, seq_len).
        """
        embedded = self.dropout(self.embed(x))
        states, (h_n, _) = self.lstm(embedded)

        # the last two rows of h_n are the top layer's forward and backward final states
        summary = torch.cat([h_n[-2], h_n[-1]], dim=1)

        gate = torch.sigmoid(self.attention_gate(states))
        projected = torch.tanh(self.attention_hidden(states * gate))
        scores = self.attention_score(projected).squeeze(-1)
        weights = scores.softmax(dim=1)

        # attention context: weighted sum of the (ungated) LSTM states over time
        attended = (weights.unsqueeze(-1) * states).sum(dim=1)

        logits = self.classifier(torch.cat([summary, attended], dim=1))
        return logits, weights

Create datasets

The corpus is still flattened into token ids, but training now uses each context window of seq_length tokens to predict only the following word. This matches the bidirectional LSTM plus attention architecture, which summarizes the whole window before making one next-word prediction.

# Dictionary class
# builds the vocabulary on the fly: every time an unseen word is added,
# it is assigned the next free integer index
class Dictionary(object):
    def __init__(self):
        self.word2idx = {}  # word -> index
        self.idx2word = {}  # index -> word
        self.idx = 0        # next index to assign

    def add_word(self, word):
        """Register *word* if unseen; indices follow order of first appearance."""
        if word not in self.word2idx:
            self.word2idx[word] = self.idx
            self.idx2word[self.idx] = word
            self.idx += 1

    def __len__(self):
        return len(self.word2idx)


# Corpus class
# Scans a text file once, adding every whitespace-separated word (plus an
# end-of-line '<eos>' token) to the Dictionary while recording its index.
class Corpus(object):
    def __init__(self):
        self.dictionary = Dictionary()

    def get_data(self, path, batch_size=32, encoding="utf-8"):
        """Tokenize *path* and return its word ids as a (batch_size, -1) LongTensor.

        Args:
            path: text file to read; words are split on whitespace and every
                line is terminated with an '<eos>' token.
            batch_size: number of rows of the returned tensor; trailing ids that
                do not fill a complete column are dropped.
            encoding: text encoding of the file (explicit so results do not
                depend on the platform's default locale).

        Returns:
            LongTensor of shape (batch_size, num_tokens // batch_size).

        Raises:
            ValueError: if batch_size < 1 or the file holds fewer than
                batch_size tokens.
        """
        if batch_size < 1:
            raise ValueError(f"batch_size must be >= 1, got {batch_size}")

        # Single pass: register each word and record its index at the same time
        # (indices are identical to a build-dictionary-then-tokenize two-pass scan).
        token_ids = []
        with open(path, 'r', encoding=encoding) as f:
            for line in f:
                for word in line.split() + ['<eos>']:
                    self.dictionary.add_word(word)
                    token_ids.append(self.dictionary.word2idx[word])
        ids = torch.tensor(token_ids, dtype=torch.long)

        # make ids length a multiple of batch_size, dropping a few ids at the end if needed
        print('original ids are %d' % ids.size(0))
        num_columns = ids.size(0) // batch_size
        if num_columns == 0:
            raise ValueError(
                f"{path} contains only {ids.size(0)} tokens, fewer than batch_size={batch_size}"
            )
        ids = ids[:num_columns * batch_size]
        print('trimmed  ids are %d (now multiple of batch size %d)' % (ids.size(0), batch_size))

        # reshape ids (1D tensor) to a 2D tensor with batch_size rows
        batched = ids.view(batch_size, -1)
        print('reshape  ids from %s to %s' % (ids.shape, batched.shape))
        return batched
 
corpus = Corpus()
ids = corpus.get_data(training_set, batch_size)
vocab_size = len(corpus.dictionary)
print('dictionary size (#unique ids) is %d' % vocab_size)

# Non-overlapping windows: the window starting at column `start` covers
# columns [start, start + seq_length) and its target is column start + seq_length,
# so starts must stay strictly below ids.size(1) - seq_length to leave room for it.
# With this stride, only every seq_length-th token is ever used as a target.
window_stop = ids.size(1) - seq_length
window_starts = list(range(0, window_stop, seq_length))
num_batches = len(window_starts)
print('num_batches is %d since each batch sample is a context window of %d words used to predict the next word' % (num_batches, seq_length))

Create the building blocks for training

Create an instance of the network, the loss function, the optimizer, and learning rate scheduler.

# create the network (keyword arguments make the hyperparameter mapping explicit)
net = LanguageModel(
    vocab_size=vocab_size,
    embed_size=embed_size,
    hidden_size=hidden_neurons,
    num_layers=num_layers,
    attention_size=attention_neurons,
    dropout=dropout_prob,
)

# loss: cross-entropy between the next-word logits and the true next-word id
criterion = nn.CrossEntropyLoss()

# Adam optimizer over all network parameters
optimizer = torch.optim.Adam(net.parameters(), lr=learning_rate)

# No learning rate scheduler: the base rate stays fixed at `learning_rate`.
# Adam adapts per-parameter step sizes but does not decay the base rate;
# a torch.optim.lr_scheduler could still be added here if needed.

Sanity check

Before training, run one batch through the bidirectional LSTM and attention head. The expected outputs are one vocabulary-sized logit vector per context window, and one attention distribution per context window holding one weight for each of its seq_length time steps.

# Push the first context window of every row through the model without tracking gradients.
# The net is still in its default training mode here, so dropout is active.
net.to(device)
with torch.no_grad():
    smoke_inputs = ids[:, :seq_length].to(device)   # context windows: (batch_size, seq_length)
    smoke_targets = ids[:, seq_length].to(device)   # word right after each window: (batch_size,)
    smoke_logits, smoke_attention = net(smoke_inputs)

# one logit per vocabulary entry, one attention weight per time step
assert smoke_logits.shape == (batch_size, vocab_size)
assert smoke_attention.shape == (batch_size, seq_length)
smoke_loss = criterion(smoke_logits, smoke_targets)

for label, tensor in (("logits", smoke_logits), ("attention", smoke_attention)):
    print(f"smoke {label} shape: {tuple(tensor.shape)}")
print(f"smoke loss: {smoke_loss.item():.4f}")

Train

Each training sample is a context window of seq_length tokens, and the model predicts only the following word. This matches the bidirectional attention architecture, which summarizes the full input window before the classifier emits one next-word distribution.

# performance history: mean training loss per epoch and the matching epoch numbers
losses, ticks = [], []

net.to(device)
net.train()

for epoch in range(1, epochs + 1):
    t0 = time.time()
    loss_sum = 0.0

    # one optimization step per context window of seq_length columns
    for i in window_starts:
        target_col = i + seq_length
        inputs = ids[:, i:target_col].to(device)    # (batch_size, seq_length)
        targets = ids[:, target_col].to(device)     # (batch_size,) word right after the window

        optimizer.zero_grad()
        outputs, attention_weights = net(inputs)    # outputs: (batch_size, vocab_size)
        loss = criterion(outputs, targets)
        loss.backward()
        loss_sum += loss.item()

        # clip large gradients to avoid overshooting near steep cliffs in the loss surface
        nn.utils.clip_grad_norm_(net.parameters(), 0.5)
        optimizer.step()

    # average over windows; max(..., 1) guards against a corpus with no windows
    epoch_loss = loss_sum / max(num_batches, 1)
    losses.append(epoch_loss)
    ticks.append(epoch)

    print(f"\nEpoch {epoch}\n"
          f"...TIME: {time.time()-t0:.1f} seconds\n"
          f"...loss: {losses[-1]} (best {min(losses)} at epoch {ticks[np.argmin(losses)]})\n")

Test

Autoregression with a rolling context window: at each step the bidirectional LSTM re-encodes the most recent seq_length words, attention summarizes the whole window, and the classifier predicts the next word.

seed = 'The'  # one or more words already present in the dictionary

net.eval()
with torch.no_grad():
    vocab = corpus.dictionary.word2idx

    # map seed words to ids, failing loudly on unknown words
    seed_ids = []
    for token in seed.split():
        if token not in vocab:
            raise KeyError(f"Seed word '{token}' not found in dictionary.")
        seed_ids.append(vocab[token])

    # rolling context window: left-pad with <eos> when the seed is shorter than seq_length
    padding = max(seq_length - len(seed_ids), 0)
    history = [vocab['<eos>']] * padding + seed_ids[-seq_length:]

    # collect output fragments and join once at the end, starting with the seed
    pieces = [seed + ' ']

    for _ in range(num_samples):
        # re-encode the most recent seq_length ids as a batch of one window
        window = torch.tensor([history[-seq_length:]], device=device)
        output, attention_weights = net(window)  # output: (1, vocab_size)

        # sample the next id from the softmax distribution over the vocabulary
        probs = torch.softmax(output.squeeze(0), dim=0)
        predicted_idx = torch.multinomial(probs, num_samples=1).item()

        word = corpus.dictionary.idx2word[predicted_idx]
        pieces.append('\n' if word == '<eos>' else word + ' ')

        history.append(predicted_idx)

    text = ''.join(pieces)
    print(text)