Import the required packages

Import all the packages required for this notebook.

import copy
import time
 
import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
 
# TSNE (scikit-learn) is needed for the latent-space plots. If it is missing,
# install it on the fly and retry the import. The `!pip` line is IPython /
# Jupyter shell syntax, so this cell only runs inside a notebook kernel.
try:
    from sklearn.manifold import TSNE
except ImportError:
    !pip -q install scikit-learn
    from sklearn.manifold import TSNE

# Record the PyTorch version in the notebook output for reproducibility.
print("PyTorch version:", torch.__version__)

Hyperparameters and Options

Set the hyperparameters

hyperparameters = dict(
    # Optimisation / data split
    learning_rate=1e-3,
    batch_size=256,
    num_epochs=15,
    validation_split=0.2,
    # beta-VAE configuration
    latent_dim=2,
    beta_values=[0.0, 1.0, 5.0],
    manifold_beta=5.0,
    weight_decay=1e-5,
    # DataLoader workers
    num_workers=2,
    use_persistent_workers=False,
    # ReduceLROnPlateau on the validation loss
    use_lr_scheduler=True,
    lr_scheduler_patience=2,
    lr_scheduler_factor=0.5,
    decoder_upsampling_mode='nearest',
    # t-SNE plots
    tsne_samples=2000,
    tsne_perplexity=30,
    tsne_random_state=42,
    # Decoded latent-grid ("manifold") plots
    manifold_grid_size=20,
    manifold_quantile_min=0.05,
    manifold_quantile_max=0.95,
    random_seed=42,
)

# Seed the NumPy and PyTorch (CPU and, when present, CUDA) RNGs.
seed = hyperparameters['random_seed']
cuda_available = torch.cuda.is_available()
np.random.seed(seed)
torch.manual_seed(seed)
if cuda_available:
    torch.cuda.manual_seed_all(seed)

device = torch.device('cuda' if cuda_available else 'cpu')
for caption, value in (
    ('Using device:', device),
    ('Random seed:', seed),
    ('Beta values:', hyperparameters['beta_values']),
):
    print(caption, value)

Model architecture

Define a convolutional beta-VAE.

  • encode: maps a 28x28 MNIST digit to the latent Gaussian parameters mu and log_var
  • reparameterize: samples a latent code using the reparameterization trick
  • decode: maps a latent code back to image logits
  • forward: returns reconstruction logits together with the latent statistics
class ConvBetaVAE(nn.Module):
    """Convolutional beta-VAE for single-channel 28x28 images.

    The encoder downsamples 28 -> 14 -> 7 with two stride-2 convolutions and
    projects the flattened 64x7x7 feature map onto the mean and log-variance
    of a diagonal Gaussian. The decoder mirrors it with two ``nn.Upsample``
    stages (7 -> 14 -> 28) and returns per-pixel logits (no final sigmoid).

    Args:
        latent_dim: size of the latent code.
        upsampling_mode: ``mode`` forwarded to both ``nn.Upsample`` layers.
    """

    def __init__(self, latent_dim=2, upsampling_mode='nearest'):
        super().__init__()
        self.latent_dim = latent_dim
        flat_features = 64 * 7 * 7  # channels x height x width after the encoder

        # Layers are created in a fixed order so the default-init RNG draws and
        # the state_dict keys (e.g. 'encoder.0.weight') stay stable.
        self.encoder = nn.Sequential(
            nn.Conv2d(1, 32, kernel_size=3, stride=2, padding=1),   # 28 -> 14
            nn.GELU(),
            nn.Conv2d(32, 64, kernel_size=3, stride=2, padding=1),  # 14 -> 7
            nn.GELU(),
            nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1),  # 7 -> 7
            nn.GELU(),
        )
        self.fc_mu = nn.Linear(flat_features, latent_dim)
        self.fc_log_var = nn.Linear(flat_features, latent_dim)
        self.fc_decode = nn.Linear(latent_dim, flat_features)
        self.decoder = nn.Sequential(
            nn.Conv2d(64, 64, kernel_size=3, padding=1),
            nn.GELU(),
            nn.Upsample(scale_factor=2, mode=upsampling_mode),      # 7 -> 14
            nn.Conv2d(64, 32, kernel_size=3, padding=1),
            nn.GELU(),
            nn.Upsample(scale_factor=2, mode=upsampling_mode),      # 14 -> 28
            nn.Conv2d(32, 16, kernel_size=3, padding=1),
            nn.GELU(),
            nn.Conv2d(16, 1, kernel_size=3, padding=1),             # 1-channel logits
        )
        self.apply(self._initialize_weights)

    def _initialize_weights(self, module):
        """Give every conv / linear layer Xavier-uniform weights and zero biases."""
        if not isinstance(module, (nn.Conv2d, nn.Linear)):
            return
        nn.init.xavier_uniform_(module.weight)
        if module.bias is not None:
            nn.init.zeros_(module.bias)

    def encode(self, x):
        """Return ``(mu, log_var)`` of q(z | x), each of shape (batch, latent_dim)."""
        features = self.encoder(x).flatten(start_dim=1)
        return self.fc_mu(features), self.fc_log_var(features)

    def reparameterize(self, mu, log_var):
        """Sample z = mu + sigma * eps with eps ~ N(0, I), differentiable in mu and log_var."""
        sigma = (0.5 * log_var).exp()
        noise = torch.randn_like(sigma)
        return mu + noise * sigma

    def decode(self, latent_code):
        """Map latent codes of shape (batch, latent_dim) to image logits."""
        batch = latent_code.size(0)
        feature_map = self.fc_decode(latent_code).view(batch, 64, 7, 7)
        return self.decoder(feature_map)

    def decode_to_probs(self, latent_code):
        """Decode latent codes and squash the logits to pixel probabilities in [0, 1]."""
        logits = self.decode(latent_code)
        return logits.sigmoid()

    def forward(self, x):
        """Encode, sample and decode; return (logits, mu, log_var, sampled latent code)."""
        mu, log_var = self.encode(x)
        z = self.reparameterize(mu, log_var)
        return self.decode(z), mu, log_var, z

Building blocks for training

Create a beta-VAE factory together with the loss and optimization helpers used for each beta value in the comparison.

def build_model():
    """Instantiate a fresh ConvBetaVAE configured from ``hyperparameters`` on ``device``."""
    latent_dim = hyperparameters['latent_dim']
    upsampling = hyperparameters.get('decoder_upsampling_mode', 'nearest')
    vae = ConvBetaVAE(latent_dim=latent_dim, upsampling_mode=upsampling)
    return vae.to(device)
 
 
def build_optimizer(model):
    """Create the AdamW optimiser and, if enabled, a ReduceLROnPlateau scheduler.

    Returns:
        ``(optimizer, scheduler)``; ``scheduler`` is ``None`` when
        ``hyperparameters['use_lr_scheduler']`` is falsy or absent.
    """
    optimizer = optim.AdamW(
        model.parameters(),
        lr=hyperparameters['learning_rate'],
        weight_decay=hyperparameters['weight_decay'],
    )
    if not hyperparameters.get('use_lr_scheduler', False):
        return optimizer, None

    # Stepped with the validation loss, hence mode='min'.
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(
        optimizer,
        mode='min',
        factor=hyperparameters['lr_scheduler_factor'],
        patience=hyperparameters['lr_scheduler_patience'],
    )
    return optimizer, scheduler
 
 
def vae_loss(reconstruction_logits, targets, mu, log_var, beta):
    """Compute the beta-VAE objective, normalised per example.

    reconstruction = BCE-with-logits summed over all pixels, divided by batch size
    kl             = closed-form KL(N(mu, exp(log_var)) || N(0, I)), divided by batch size
    total          = reconstruction + beta * kl

    Returns:
        ``(total, reconstruction, kl)`` as scalar tensors.
    """
    batch_size = targets.size(0)
    bce_sum = F.binary_cross_entropy_with_logits(
        reconstruction_logits,
        targets,
        reduction='sum',
    )
    reconstruction_loss = bce_sum / batch_size
    kl_integrand = 1 + log_var - mu.pow(2) - log_var.exp()
    kl_loss = -0.5 * kl_integrand.sum() / batch_size
    return reconstruction_loss + beta * kl_loss, reconstruction_loss, kl_loss
 
 
# Beta used for the smoke-test loss in the next cell.
reference_beta = hyperparameters['manifold_beta']
model = build_model()

Model Summary and Smoke Test

Inspect the beta-VAE structure and verify that one dummy forward pass yields valid reconstruction logits, latent statistics, and a finite loss.

model_parameters = list(model.parameters())
total_params = sum(p.numel() for p in model_parameters)
trainable_params = sum(p.numel() for p in model_parameters if p.requires_grad)
print(f"Total parameters: {total_params:,}")
print(f"Trainable parameters: {trainable_params:,}")

# One gradient-free forward pass on random "images" with values in [0, 1).
smoke_batch_size = 8
with torch.no_grad():
    dummy_batch = torch.rand(smoke_batch_size, 1, 28, 28, device=device)
    reconstruction_logits, mu, log_var, latent_codes = model(dummy_batch)
    total_loss, reconstruction_loss, kl_loss = vae_loss(
        reconstruction_logits, dummy_batch, mu, log_var, beta=reference_beta,
    )

for caption, stat_tensor in (
    ('Reconstruction logits shape:', reconstruction_logits),
    ('Mu shape:', mu),
    ('Log var shape:', log_var),
    ('Latent batch shape:', latent_codes),
):
    print(caption, tuple(stat_tensor.shape))
print('Smoke-test losses:', *(float(loss) for loss in (total_loss, reconstruction_loss, kl_loss)))

# Logits must match the input image shape; every latent statistic is (batch, latent_dim).
expected_latent_shape = (smoke_batch_size, hyperparameters['latent_dim'])
assert reconstruction_logits.shape == dummy_batch.shape
for latent_stat in (mu, log_var, latent_codes):
    assert latent_stat.shape == expected_latent_shape
assert torch.isfinite(total_loss)
print('Smoke test passed.')

Create datasets

Split MNIST into train, validation, and test sets.

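Strategy: split the official MNIST training set 80/20 into training and validation subsets using a fixed random seed, keep the official MNIST test set for testing, and use the raw grayscale pixels (scaled to the [0, 1] range) as reconstruction targets.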

# Download both official splits once (cached under ./data).
mnist_train = datasets.MNIST(root='./data', train=True, download=True)
mnist_test = datasets.MNIST(root='./data', train=False, download=True)

validation_split = hyperparameters.get('validation_split', 0.2)
if not 0.0 < validation_split < 1.0:
    raise ValueError(f'validation_split must lie strictly between 0 and 1, got {validation_split!r}')

num_train_examples = len(mnist_train)
# Round the validation count instead of truncating the training count:
# int((1.0 - 0.9) * 60000) == 5999 because 1.0 - 0.9 == 0.09999999999999998.
val_size = int(round(validation_split * num_train_examples))
train_size = num_train_examples - val_size

# A dedicated, seeded generator makes the split reproducible and independent of
# the global torch RNG state.
split_generator = torch.Generator().manual_seed(hyperparameters['random_seed'])
indices = torch.randperm(num_train_examples, generator=split_generator).tolist()
train_indices = indices[:train_size]
val_indices = indices[train_size:]

# Each Subset wraps its own MNIST instance (re-read from the cache, no download),
# presumably so train and validation could be given different transforms.
train_dataset = torch.utils.data.Subset(
    datasets.MNIST(root='./data', train=True, download=False),
    train_indices,
)
val_dataset = torch.utils.data.Subset(
    datasets.MNIST(root='./data', train=True, download=False),
    val_indices,
)
test_dataset = datasets.MNIST(root='./data', train=False, download=False)

print(f'Train size: {train_size}, Validation size: {val_size}, Test size: {len(test_dataset)}')
print('Inputs and reconstruction targets are the same MNIST image; KL regularization shapes the latent space.')

Data Check

A common problem in Machine Learning and Deep Learning is feeding incorrect data to the model, often due to mistakes in loading or preprocessing.

A good practice is to print basic properties of a few samples (type, size, mode, label) and look at them before training.

# Helper for the visual checks below: display one 2-D image (PIL image or array).
def imshow(img):
    """Render ``img`` in grayscale and show the figure immediately."""
    plt.imshow(np.array(img), cmap='gray')
    plt.show()
 
# Raw dataset check: no transform is attached yet, so each sample is the PIL
# image returned by torchvision together with its label.
for sample_number in range(1, 6):
    img, label = train_dataset[sample_number - 1]
    print(f'Sample {sample_number}: Image type: {type(img)}, Size: {img.size}, Mode: {img.mode}, Label: {label}')
    imshow(img)

Data transforms

Convert MNIST images to float tensors in the [0, 1] range so the VAE decoder can model Bernoulli-style pixel probabilities through BCE with logits.

# ToTensor turns each PIL image into a float tensor scaled to [0, 1]. Targets
# keep their raw labels (the training loop ignores them).
transform = transforms.ToTensor()

# train/val Subsets wrap separate MNIST objects, so each underlying dataset is configured.
for base_dataset in (train_dataset.dataset, val_dataset.dataset, test_dataset):
    base_dataset.transform = transform
    base_dataset.target_transform = None

print('Transforms configured for VAE training with pixels in the [0, 1] range.')

Data loaders

DataLoaders are built-in PyTorch objects that sample batches from datasets.

# Prepare data loaders for training, validation, and testing.
loader_num_workers = hyperparameters.get('num_workers', 2)
# persistent_workers is only valid with worker processes (num_workers > 0).
use_persistent_workers = hyperparameters.get('use_persistent_workers', True) and loader_num_workers > 0
# Page-locked host memory only speeds up host-to-GPU copies.
pin_memory = torch.cuda.is_available()
# Evaluation runs under no_grad, so it can afford larger batches; optional key, default 512.
eval_batch_size = hyperparameters.get('eval_batch_size', 512)


def _make_loader(dataset, batch_size, shuffle):
    """Build a DataLoader that shares the worker and pin-memory settings above."""
    return DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=shuffle,
        num_workers=loader_num_workers,
        pin_memory=pin_memory,
        persistent_workers=use_persistent_workers,
    )


train_loader = _make_loader(train_dataset, hyperparameters['batch_size'], shuffle=True)
val_loader = _make_loader(val_dataset, eval_batch_size, shuffle=False)
test_loader = _make_loader(test_dataset, eval_batch_size, shuffle=False)

Data check after transforms

To check what the network will see at train/test time, index the datasets again now that the transforms are attached: the DataLoaders read from these same datasets, so the transformed tensors shown here are what the model receives.

# Check the samples again now that ToTensor is attached: these are the tensors
# the DataLoaders hand to the model.
for i in range(5):
    img, label = train_dataset[i]
    print(
        f'Sample {i+1}: Tensor shape: {tuple(img.shape)}, dtype: {img.dtype}, '
        f'min: {img.min().item():.3f}, max: {img.max().item():.3f}, Label: {label}'
    )
    # Drop the leading channel dimension to get a 2-D array for plt.imshow.
    imshow(img.squeeze(0).numpy())

Train function

Train the beta-VAE for one epoch and return the total loss together with its reconstruction and KL components.

def train_epoch(model, train_loader, optimizer, beta):
    """Run one optimisation pass over ``train_loader``.

    Args:
        model: beta-VAE whose forward returns ``(logits, mu, log_var, z)``.
        train_loader: iterable of ``(images, labels)`` batches; labels are ignored.
        optimizer: optimiser stepped once per batch.
        beta: weight of the KL term in ``vae_loss``.

    Returns:
        dict with keys ``'total'``, ``'recon'`` and ``'kl'``: per-example averages
        over the samples actually seen. Counting the seen samples, rather than
        using ``len(dataset)``, keeps this correct with ``drop_last=True`` or a
        custom sampler.

    Raises:
        ValueError: if the loader yields no samples.
    """
    model.train()
    totals = {'total': 0.0, 'recon': 0.0, 'kl': 0.0}
    num_samples = 0

    for images, _ in train_loader:
        images = images.to(device, non_blocking=True)

        optimizer.zero_grad(set_to_none=True)
        reconstruction_logits, mu, log_var, _ = model(images)
        total_loss, reconstruction_loss, kl_loss = vae_loss(
            reconstruction_logits,
            images,
            mu,
            log_var,
            beta=beta,
        )
        total_loss.backward()
        optimizer.step()

        # vae_loss returns batch means; re-weight by batch size so a short
        # final batch contributes proportionally.
        batch_size = images.size(0)
        num_samples += batch_size
        totals['total'] += total_loss.item() * batch_size
        totals['recon'] += reconstruction_loss.item() * batch_size
        totals['kl'] += kl_loss.item() * batch_size

    if num_samples == 0:
        raise ValueError('train_loader yielded no samples')
    return {key: value / num_samples for key, value in totals.items()}

Validation / Latent Extraction

Evaluate the beta-VAE on validation or test data, and collect latent means for the t-SNE comparison plots.

def evaluate(model, data_loader, beta):
    """Compute average beta-VAE losses on ``data_loader`` without updating ``model``.

    ``model.forward`` samples z with the reparameterization trick, so the result
    is a single-sample Monte Carlo estimate and can vary slightly between calls.

    Args:
        model: beta-VAE whose forward returns ``(logits, mu, log_var, z)``.
        data_loader: iterable of ``(images, labels)`` batches; labels are ignored.
        beta: weight of the KL term in ``vae_loss``.

    Returns:
        dict with keys ``'total'``, ``'recon'`` and ``'kl'``: per-example averages
        over the samples actually seen.

    Raises:
        ValueError: if the loader yields no samples.
    """
    model.eval()
    totals = {'total': 0.0, 'recon': 0.0, 'kl': 0.0}
    num_samples = 0

    with torch.no_grad():
        for images, _ in data_loader:
            images = images.to(device, non_blocking=True)
            reconstruction_logits, mu, log_var, _ = model(images)
            total_loss, reconstruction_loss, kl_loss = vae_loss(
                reconstruction_logits,
                images,
                mu,
                log_var,
                beta=beta,
            )

            # Re-weight batch means by batch size (see train_epoch).
            batch_size = images.size(0)
            num_samples += batch_size
            totals['total'] += total_loss.item() * batch_size
            totals['recon'] += reconstruction_loss.item() * batch_size
            totals['kl'] += kl_loss.item() * batch_size

    if num_samples == 0:
        raise ValueError('data_loader yielded no samples')
    return {key: value / num_samples for key, value in totals.items()}
 
 
def collect_latent_codes(model, data_loader, max_samples=None):
    """Encode batches from ``data_loader`` and return posterior means with labels.

    Only ``mu`` from ``model.encode`` is kept (no sampling). Iteration stops once
    at least ``max_samples`` examples have been encoded, and the output is then
    trimmed to at most ``max_samples`` rows. ``max_samples=None`` encodes everything.

    Returns:
        ``(latent_codes, labels)`` as NumPy arrays.
    """
    model.eval()
    mu_batches, label_batches = [], []
    encoded = 0

    with torch.no_grad():
        for images, batch_labels in data_loader:
            batch_mu, _ = model.encode(images.to(device, non_blocking=True))
            mu_batches.append(batch_mu.cpu())
            label_batches.append(batch_labels.cpu())
            encoded += batch_mu.size(0)
            if max_samples is not None and encoded >= max_samples:
                break

    # slice(None) keeps every row when no limit is given.
    keep = slice(max_samples)
    latent_codes = torch.cat(mu_batches, dim=0)[keep]
    labels = torch.cat(label_batches, dim=0)[keep]
    return latent_codes.numpy(), labels.numpy()

Train Beta-VAE Models

Train one model for each beta value, keep the best checkpoint according to validation total loss, and compare how the reconstruction and KL terms evolve. The plots are generated once at the end to stay lighter on a remote server.

experiment_results = {}


def _snapshot_state(module):
    """Return a detached CPU copy of ``module``'s state_dict, unaffected by later updates."""
    return {key: value.detach().cpu().clone() for key, value in module.state_dict().items()}


for beta in hyperparameters['beta_values']:
    print(f"\n=== Training beta={beta} ===")
    model = build_model()
    optimizer, scheduler = build_optimizer(model)
    history = {
        'train_total': [],
        'train_recon': [],
        'train_kl': [],
        'val_total': [],
        'val_recon': [],
        'val_kl': [],
        'learning_rates': [],
    }
    best_val_total = float('inf')
    best_epoch = 0
    best_model_state = None

    for epoch in range(hyperparameters['num_epochs']):
        epoch_start = time.time()
        train_metrics = train_epoch(model, train_loader, optimizer, beta)
        val_metrics = evaluate(model, val_loader, beta)

        if scheduler is not None:
            scheduler.step(val_metrics['total'])

        current_lr = optimizer.param_groups[0]['lr']
        epoch_time = time.time() - epoch_start

        for split, metrics in (('train', train_metrics), ('val', val_metrics)):
            for metric_name in ('total', 'recon', 'kl'):
                history[f'{split}_{metric_name}'].append(metrics[metric_name])
        history['learning_rates'].append(current_lr)

        # Keep the checkpoint with the lowest validation total loss.
        if val_metrics['total'] < best_val_total:
            best_val_total = val_metrics['total']
            best_epoch = epoch + 1
            best_model_state = _snapshot_state(model)

        print(
            f"Epoch {epoch + 1:02d}/{hyperparameters['num_epochs']:02d} | "
            f"time {epoch_time:.1f}s | lr {current_lr:.2e} | "
            f"train total {train_metrics['total']:.3f} | val total {val_metrics['total']:.3f} | "
            f"val recon {val_metrics['recon']:.3f} | val kl {val_metrics['kl']:.3f}"
        )

    if best_model_state is None:
        # No epoch produced a validation loss below +inf (it was NaN/inf every
        # epoch, or num_epochs == 0). load_state_dict(None) would crash, so
        # keep the final weights and say so.
        print(f"Warning: no finite validation loss for beta={beta}; keeping the final weights.")
        best_model_state = _snapshot_state(model)
    else:
        # load_state_dict copies into the existing parameters, which already live on `device`.
        model.load_state_dict(best_model_state)
    test_metrics = evaluate(model, test_loader, beta)

    experiment_results[beta] = {
        'best_epoch': best_epoch,
        'best_val_total': best_val_total,
        'test_metrics': test_metrics,
        'history': history,
        'state_dict': best_model_state,
    }

    print(
        f"Best beta={beta} model at epoch {best_epoch}: "
        f"val total {best_val_total:.3f} | "
        f"test total {test_metrics['total']:.3f} | "
        f"test recon {test_metrics['recon']:.3f} | test kl {test_metrics['kl']:.3f}"
    )
 
fig, axes = plt.subplots(1, 3, figsize=(15, 4))
# (history key, panel title, y-axis label) for each subplot, left to right.
loss_panels = (
    ('val_total', 'Validation total loss', 'Loss'),
    ('val_recon', 'Validation reconstruction loss', 'BCE'),
    ('val_kl', 'Validation KL loss', 'KL'),
)

for beta in hyperparameters['beta_values']:
    history = experiment_results[beta]['history']
    epochs = np.arange(1, len(history['val_total']) + 1)
    for axis, (history_key, _, _) in zip(axes, loss_panels):
        axis.plot(epochs, history[history_key], marker='o', label=f'beta={beta}')

for axis, (_, title, ylabel) in zip(axes, loss_panels):
    axis.set_title(title)
    axis.set_xlabel('Epoch')
    axis.set_ylabel(ylabel)
    axis.grid(alpha=0.3)
    axis.legend()

plt.tight_layout()
plt.show()

t-SNE Comparison and VAE Manifolds

Compare the 2D t-SNE embeddings of the learned latent means for different beta values, then decode a grid of latent points for each beta-VAE to visualize how the manifold changes with KL regularization.

beta_values = hyperparameters['beta_values']
fig, axes = plt.subplots(1, len(beta_values), figsize=(7 * len(beta_values), 5), squeeze=False)
axes = axes.ravel()

for axis, beta in zip(axes, beta_values):
    # Rebuild the stored best checkpoint for this beta. build_model already places
    # the model on `device`, and load_state_dict copies the CPU tensors into it.
    model = build_model()
    model.load_state_dict(experiment_results[beta]['state_dict'])
    model.eval()

    latent_codes, labels = collect_latent_codes(
        model,
        test_loader,
        max_samples=hyperparameters['tsne_samples'],
    )
    # scikit-learn requires perplexity < n_samples.
    tsne_perplexity = min(hyperparameters['tsne_perplexity'], len(latent_codes) - 1)
    tsne = TSNE(
        n_components=2,
        perplexity=tsne_perplexity,
        init='pca',
        learning_rate=200.0,
        random_state=hyperparameters['tsne_random_state'],
    )
    latent_2d = tsne.fit_transform(latent_codes)

    scatter = axis.scatter(
        latent_2d[:, 0],
        latent_2d[:, 1],
        c=labels,
        cmap='tab10',
        # Pin the colour scale to the 10 digit classes: each label keeps its own
        # tab10 colour in every panel (even if a class is absent from the subsample),
        # and the colorbar ticks 0..9 sit in the middle of their colour bands.
        vmin=-0.5,
        vmax=9.5,
        s=10,
        alpha=0.75,
    )
    axis.set_title(f'MNIST latent t-SNE - beta {beta}')
    axis.set_xlabel('t-SNE component 1')
    axis.set_ylabel('t-SNE component 2')
    axis.grid(alpha=0.15)
    colorbar = plt.colorbar(scatter, ax=axis, ticks=list(range(10)))
    colorbar.set_label('Digit label')

plt.tight_layout()
plt.show()
 
if hyperparameters['latent_dim'] != 2:
    raise ValueError('Manifold visualization assumes a 2D latent space.')

grid_size = hyperparameters['manifold_grid_size']
# Latent coordinates at evenly spaced quantiles of N(0, 1). Points are denser
# near the origin, so the axis labels below are not linearly spaced.
normal_distribution = torch.distributions.Normal(0, 1)
latent_axis = normal_distribution.icdf(
    torch.linspace(
        hyperparameters['manifold_quantile_min'],
        hyperparameters['manifold_quantile_max'],
        grid_size,
        device=device,
    )
)
# indexing='ij' + row-major flattening: code index = row * grid_size + col,
# with z1 = latent_axis[col] and z2 = latent_axis[row].
mesh_y, mesh_x = torch.meshgrid(latent_axis, latent_axis, indexing='ij')
grid_latent_codes = torch.stack([mesh_x.reshape(-1), mesh_y.reshape(-1)], dim=1)

latent_min = float(latent_axis.min().cpu())
latent_max = float(latent_axis.max().cpu())

fig, axes = plt.subplots(1, len(beta_values), figsize=(5 * len(beta_values), 5), squeeze=False)
axes = axes.ravel()

for axis, beta in zip(axes, beta_values):
    # build_model already places the model on `device`; load_state_dict copies into it.
    manifold_model = build_model()
    manifold_model.load_state_dict(experiment_results[beta]['state_dict'])
    manifold_model.eval()

    with torch.no_grad():
        decoded_images = manifold_model.decode_to_probs(grid_latent_codes).cpu().squeeze(1)

    # Put the largest z2 in the top row of the canvas and draw it with
    # origin='upper'. The previous origin='lower' placed small z2 at the bottom
    # but also flipped every 28x28 tile vertically, so the digits were upside down.
    canvas = np.zeros((28 * grid_size, 28 * grid_size), dtype=np.float32)
    for row in range(grid_size):
        canvas_row = grid_size - 1 - row
        for col in range(grid_size):
            tile = decoded_images[row * grid_size + col].numpy()
            canvas[canvas_row * 28:(canvas_row + 1) * 28, col * 28:(col + 1) * 28] = tile

    axis.imshow(
        canvas,
        cmap='gray',
        origin='upper',
        # (left, right, bottom, top): with origin='upper' the canvas' first row is drawn at top = latent_max.
        extent=[latent_min, latent_max, latent_min, latent_max],
    )
    axis.set_title(f'VAE manifold - beta {beta}')
    axis.set_xlabel('z1')
    axis.set_ylabel('z2')
    axis.grid(False)

plt.tight_layout()
plt.show()