
Import the required packages
Import all the packages required for this notebook.
import copy
import time
import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
try:
    from sklearn.manifold import TSNE
except ImportError:
    # `!pip` is IPython shell syntax (not valid Python) and may target a different
    # interpreter than the running kernel; install into this interpreter instead.
    import importlib
    import subprocess
    import sys

    subprocess.check_call([sys.executable, '-m', 'pip', 'install', '-q', 'scikit-learn'])
    # Refresh import-system caches so the freshly installed package is found.
    importlib.invalidate_caches()
    from sklearn.manifold import TSNE
print("PyTorch version:", torch.__version__)
# Hyperparameters and Options
Set the hyperparameters
# Central configuration for every experiment in this notebook.
hyperparameters = dict(
    learning_rate=1e-3,
    batch_size=256,
    num_epochs=15,
    validation_split=0.2,
    latent_dim=2,
    beta_values=[0.0, 1.0, 5.0],
    manifold_beta=5.0,
    weight_decay=1e-5,
    num_workers=2,
    use_persistent_workers=False,
    use_lr_scheduler=True,
    lr_scheduler_patience=2,
    lr_scheduler_factor=0.5,
    decoder_upsampling_mode='nearest',
    tsne_samples=2000,
    tsne_perplexity=30,
    tsne_random_state=42,
    manifold_grid_size=20,
    manifold_quantile_min=0.05,
    manifold_quantile_max=0.95,
    random_seed=42,
)

# Seed NumPy and PyTorch (CPU and, when present, all CUDA devices) for reproducibility.
seed = hyperparameters['random_seed']
np.random.seed(seed)
torch.manual_seed(seed)
cuda_available = torch.cuda.is_available()
if cuda_available:
    torch.cuda.manual_seed_all(seed)

device = torch.device('cuda') if cuda_available else torch.device('cpu')
print('Using device:', device)
print('Random seed:', seed)
print('Beta values:', hyperparameters['beta_values'])
# Model architecture
Define a convolutional beta-VAE.
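- `encode`: maps a 28x28 MNIST digit to the latent Gaussian parameters `mu` and `log_var`
- `reparameterize`: samples a latent code using the reparameterization trick
- `decode`: maps a latent code back to image logits
- `decode_to_probs`: applies a sigmoid to the decoded logits to obtain pixel probabilities
- `forward`: returns the reconstruction logits together with `mu`, `log_var`, and the sampled latent code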
class ConvBetaVAE(nn.Module):
    """Convolutional beta-VAE for single-channel 28x28 images.

    The encoder downsamples 28x28 -> 14x14 -> 7x7 with strided convolutions and
    projects the flattened 64x7x7 feature map to the Gaussian parameters ``mu``
    and ``log_var``. The decoder mirrors it with two x2 upsampling stages and
    returns per-pixel logits (no final sigmoid).

    Args:
        latent_dim: Size of the latent code.
        upsampling_mode: ``mode`` passed to every ``nn.Upsample`` in the decoder.
    """

    def __init__(self, latent_dim=2, upsampling_mode='nearest'):
        super().__init__()
        self.latent_dim = latent_dim
        flat_features = 64 * 7 * 7  # channels x height x width after the encoder

        # Module registration order (encoder, fc_mu, fc_log_var, fc_decode, decoder)
        # determines state_dict keys and RNG consumption, so it is kept fixed.
        encoder_layers = []
        for in_channels, out_channels, stride in ((1, 32, 2), (32, 64, 2), (64, 64, 1)):
            encoder_layers.append(
                nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=stride, padding=1)
            )
            encoder_layers.append(nn.GELU())
        self.encoder = nn.Sequential(*encoder_layers)

        self.fc_mu = nn.Linear(flat_features, latent_dim)
        self.fc_log_var = nn.Linear(flat_features, latent_dim)
        self.fc_decode = nn.Linear(latent_dim, flat_features)

        self.decoder = nn.Sequential(
            nn.Conv2d(64, 64, kernel_size=3, padding=1),
            nn.GELU(),
            nn.Upsample(scale_factor=2, mode=upsampling_mode),   # 7x7 -> 14x14
            nn.Conv2d(64, 32, kernel_size=3, padding=1),
            nn.GELU(),
            nn.Upsample(scale_factor=2, mode=upsampling_mode),   # 14x14 -> 28x28
            nn.Conv2d(32, 16, kernel_size=3, padding=1),
            nn.GELU(),
            nn.Conv2d(16, 1, kernel_size=3, padding=1),          # logits, 1 channel
        )

        self.apply(self._initialize_weights)

    def _initialize_weights(self, module):
        """Xavier-uniform weights and zero biases for every Conv2d/Linear layer."""
        if not isinstance(module, (nn.Conv2d, nn.Linear)):
            return
        nn.init.xavier_uniform_(module.weight)
        if module.bias is not None:
            nn.init.zeros_(module.bias)

    def encode(self, x):
        """Map an image batch to the posterior parameters ``(mu, log_var)``."""
        features = self.encoder(x).flatten(start_dim=1)
        return self.fc_mu(features), self.fc_log_var(features)

    def reparameterize(self, mu, log_var):
        """Draw ``z = mu + eps * sigma`` with ``eps ~ N(0, I)`` and ``sigma = exp(log_var / 2)``."""
        sigma = (0.5 * log_var).exp()
        noise = torch.randn_like(sigma)
        return mu + noise * sigma

    def decode(self, latent_code):
        """Map a batch of latent codes to image logits."""
        batch_size = latent_code.shape[0]
        feature_map = self.fc_decode(latent_code).view(batch_size, 64, 7, 7)
        return self.decoder(feature_map)

    def decode_to_probs(self, latent_code):
        """Decode latent codes and squash the logits to pixel probabilities."""
        return self.decode(latent_code).sigmoid()

    def forward(self, x):
        """Return ``(reconstruction_logits, mu, log_var, sampled_latent_code)``."""
        mu, log_var = self.encode(x)
        z = self.reparameterize(mu, log_var)
        return self.decode(z), mu, log_var, z
# Building blocks for training
Create a beta-VAE factory together with the loss and optimization helpers used for each beta value in the comparison.
def build_model():
    """Create a fresh ConvBetaVAE from the global hyperparameters, placed on `device`."""
    model_kwargs = {
        'latent_dim': hyperparameters['latent_dim'],
        'upsampling_mode': hyperparameters.get('decoder_upsampling_mode', 'nearest'),
    }
    vae = ConvBetaVAE(**model_kwargs)
    return vae.to(device)
def build_optimizer(model):
    """Return ``(optimizer, scheduler)`` for `model`.

    The optimizer is AdamW configured from the global hyperparameters. The
    scheduler is a ReduceLROnPlateau on a minimized metric when
    ``use_lr_scheduler`` is truthy, otherwise ``None``.
    """
    optimizer = optim.AdamW(
        model.parameters(),
        lr=hyperparameters['learning_rate'],
        weight_decay=hyperparameters['weight_decay'],
    )
    if not hyperparameters.get('use_lr_scheduler', False):
        return optimizer, None

    scheduler = optim.lr_scheduler.ReduceLROnPlateau(
        optimizer,
        mode='min',
        factor=hyperparameters['lr_scheduler_factor'],
        patience=hyperparameters['lr_scheduler_patience'],
    )
    return optimizer, scheduler
def vae_loss(reconstruction_logits, targets, mu, log_var, beta):
    """Compute the beta-VAE objective and its two terms, each averaged over the batch.

    Args:
        reconstruction_logits: Decoder outputs before the sigmoid.
        targets: Reconstruction targets with the same shape as the logits.
        mu: Posterior means.
        log_var: Posterior log-variances.
        beta: Weight of the KL term (``beta=0`` disables KL regularization).

    Returns:
        ``(total_loss, reconstruction_loss, kl_loss)`` with
        ``total_loss = reconstruction_loss + beta * kl_loss``.
    """
    batch_size = targets.size(0)

    # Binary cross-entropy summed over all pixels, then averaged per sample.
    bce_sum = F.binary_cross_entropy_with_logits(
        reconstruction_logits,
        targets,
        reduction='sum',
    )
    reconstruction_loss = bce_sum / batch_size

    # Closed-form KL(N(mu, exp(log_var)) || N(0, I)), summed over latent dims, averaged per sample.
    kl_sum = torch.sum(1 + log_var - mu.pow(2) - log_var.exp())
    kl_loss = -0.5 * kl_sum / batch_size

    return reconstruction_loss + beta * kl_loss, reconstruction_loss, kl_loss
# Beta used for the smoke-test loss below; the model is a fresh, untrained instance.
reference_beta = hyperparameters['manifold_beta']
model = build_model()
# Model Summary and Smoke Test
Inspect the beta-VAE structure and verify that one dummy forward pass yields valid reconstruction logits, latent statistics, and a finite loss.
# Parameter counts.
total_params = 0
trainable_params = 0
for parameter in model.parameters():
    total_params += parameter.numel()
    if parameter.requires_grad:
        trainable_params += parameter.numel()
print(f"Total parameters: {total_params:,}")
print(f"Trainable parameters: {trainable_params:,}")

# One forward pass + loss on random inputs in [0, 1), without tracking gradients.
with torch.no_grad():
    dummy_batch = torch.rand(8, 1, 28, 28, device=device)
    reconstruction_logits, mu, log_var, latent_codes = model(dummy_batch)
    total_loss, reconstruction_loss, kl_loss = vae_loss(
        reconstruction_logits, dummy_batch, mu, log_var, beta=reference_beta,
    )

for _caption, _tensor in (
    ('Reconstruction logits shape:', reconstruction_logits),
    ('Mu shape:', mu),
    ('Log var shape:', log_var),
    ('Latent batch shape:', latent_codes),
):
    print(_caption, tuple(_tensor.shape))
print('Smoke-test losses:', float(total_loss), float(reconstruction_loss), float(kl_loss))

expected_latent_shape = (8, hyperparameters['latent_dim'])
assert reconstruction_logits.shape == dummy_batch.shape
for _latent_tensor in (mu, log_var, latent_codes):
    assert _latent_tensor.shape == expected_latent_shape
assert torch.isfinite(total_loss)
print('Smoke test passed.')
# Create datasets
Split MNIST into train, validation, and test sets.
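Strategy: hold out a fixed, seeded random 20% of the official MNIST training set for validation (an 80/20 split controlled by `validation_split`), use the official MNIST test set for testing, and use the raw grayscale pixels in the [0, 1] range as reconstruction targets.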
# Download (if needed) the official MNIST train and test splits.
mnist_train = datasets.MNIST(root='./data', train=True, download=True)
mnist_test = datasets.MNIST(root='./data', train=False, download=True)

validation_split = hyperparameters.get('validation_split', 0.2)
train_size = int((1.0 - validation_split) * len(mnist_train))
val_size = len(mnist_train) - train_size
# Both splits must be non-empty: evaluate() divides by len(dataset), so an empty
# validation split would otherwise fail later with a ZeroDivisionError.
if train_size <= 0 or val_size <= 0:
    raise ValueError(
        f'validation_split={validation_split} gives train_size={train_size} and '
        f'val_size={val_size}; both must be positive.'
    )

# Dedicated, seeded generator: the split is reproducible and does not consume the global RNG.
split_generator = torch.Generator().manual_seed(hyperparameters['random_seed'])
indices = torch.randperm(len(mnist_train), generator=split_generator).tolist()
train_indices = indices[:train_size]
val_indices = indices[train_size:]

# Separate MNIST instances so the train and validation subsets can carry independent transforms.
train_dataset = torch.utils.data.Subset(
    datasets.MNIST(root='./data', train=True, download=False),
    train_indices,
)
val_dataset = torch.utils.data.Subset(
    datasets.MNIST(root='./data', train=True, download=False),
    val_indices,
)
# Reuse the already-loaded test split instead of constructing a second MNIST test instance.
test_dataset = mnist_test
print(f'Train size: {train_size}, Validation size: {val_size}, Test size: {len(test_dataset)}')
print('Inputs and reconstruction targets are the same MNIST image; KL regularization shapes the latent space.')
# Data Check
A common problem in Machine Learning and Deep Learning is feeding incorrect data to the model, often due to mistakes in loading or preprocessing.
A good practice is to inspect a few samples and print basic information about them (type, size, mode, label) before training.
# visual check of the first few training samples together with their dataset labels
def imshow(img):
    """Show `img` (a PIL image or 2D array) as a grayscale matplotlib figure."""
    plt.imshow(np.array(img), cmap='gray')
    plt.show()
# check some samples directly from the dataset (no transform is set yet, so these are PIL images)
for sample_idx in range(5):
    pil_image, digit = train_dataset[sample_idx]
    print(
        f'Sample {sample_idx+1}: Image type: {type(pil_image)}, Size: {pil_image.size}, '
        f'Mode: {pil_image.mode}, Label: {digit}'
    )
    imshow(pil_image)
# Data transforms
Convert MNIST images to tensors with pixel values in the [0, 1] range so the VAE decoder can model Bernoulli-style pixel probabilities through BCE with logits.
transform = transforms.ToTensor()
# Train/val are Subsets, so the transform goes on the MNIST datasets they wrap.
for _mnist_split in (train_dataset.dataset, val_dataset.dataset, test_dataset):
    _mnist_split.transform = transform
    _mnist_split.target_transform = None
print('Transforms configured for VAE training with pixels in the [0, 1] range.')
# Data loaders
DataLoaders are built-in PyTorch objects that sample batches from datasets.
# prepare data loaders for training, validation, and testing
loader_num_workers = hyperparameters.get('num_workers', 2)
# persistent_workers is only allowed with at least one worker process.
use_persistent_workers = hyperparameters.get('use_persistent_workers', True) and loader_num_workers > 0
pin_memory = torch.cuda.is_available()

shared_loader_kwargs = {
    'num_workers': loader_num_workers,
    'pin_memory': pin_memory,
    'persistent_workers': use_persistent_workers,
}
train_loader = DataLoader(
    train_dataset,
    batch_size=hyperparameters['batch_size'],
    shuffle=True,
    **shared_loader_kwargs,
)
# Evaluation loaders keep a fixed order and use larger batches.
val_loader = DataLoader(val_dataset, batch_size=512, shuffle=False, **shared_loader_kwargs)
test_loader = DataLoader(test_dataset, batch_size=512, shuffle=False, **shared_loader_kwargs)
# Data check after transforms
To check what the network will see at train/test time, index the dataset again: the transforms defined above are now applied on every access, exactly as they are when the DataLoaders draw batches.
# check some samples again: indexing the dataset now applies the ToTensor transform
for sample_idx in range(5):
    tensor_image, _ = train_dataset[sample_idx]
    # remove the channel dimension and convert to a 2D numpy array for matplotlib
    imshow(tensor_image.squeeze().numpy())
# Train function
Train the beta-VAE for one epoch and return the total loss together with its reconstruction and KL components.
def train_epoch(model, train_loader, optimizer, beta):
    """Run one training epoch and return per-sample mean losses.

    Returns:
        Dict with keys ``'total'``, ``'recon'`` and ``'kl'``. Each batch loss is
        weighted by its batch size and the sum is divided by the dataset length.
    """
    model.train()
    term_names = ('total', 'recon', 'kl')
    running = dict.fromkeys(term_names, 0.0)

    for images, _ in train_loader:
        images = images.to(device, non_blocking=True)
        optimizer.zero_grad(set_to_none=True)
        logits, mu, log_var, _ = model(images)
        losses = vae_loss(logits, images, mu, log_var, beta=beta)
        losses[0].backward()  # backpropagate the total (recon + beta * KL) loss
        optimizer.step()

        n_in_batch = images.size(0)
        for name, loss in zip(term_names, losses):
            running[name] += loss.item() * n_in_batch

    n_samples = len(train_loader.dataset)
    return {name: value / n_samples for name, value in running.items()}
# Validation / Latent Extraction
Evaluate the beta-VAE on validation or test data, and collect latent means for the t-SNE comparison plots.
def evaluate(model, data_loader, beta):
    """Evaluate `model` on `data_loader` without gradients.

    Returns:
        Dict with per-sample mean ``'total'``, ``'recon'`` and ``'kl'`` losses.
    """
    model.eval()
    sum_total = sum_recon = sum_kl = 0.0

    with torch.no_grad():
        for images, _ in data_loader:
            images = images.to(device, non_blocking=True)
            logits, mu, log_var, _ = model(images)
            batch_total, batch_recon, batch_kl = vae_loss(logits, images, mu, log_var, beta=beta)
            weight = images.size(0)
            sum_total += batch_total.item() * weight
            sum_recon += batch_recon.item() * weight
            sum_kl += batch_kl.item() * weight

    n_samples = len(data_loader.dataset)
    return {
        'total': sum_total / n_samples,
        'recon': sum_recon / n_samples,
        'kl': sum_kl / n_samples,
    }
def collect_latent_codes(model, data_loader, max_samples=None):
    """Encode batches from `data_loader` and return the posterior means and labels.

    Iteration stops once at least `max_samples` codes have been gathered, and the
    result is trimmed to exactly `max_samples` (when given).

    Returns:
        ``(latent_means, labels)`` as NumPy arrays.
    """
    model.eval()
    mu_chunks, label_chunks = [], []
    n_seen = 0

    with torch.no_grad():
        for images, batch_labels in data_loader:
            mu, _ = model.encode(images.to(device, non_blocking=True))
            mu_chunks.append(mu.cpu())
            label_chunks.append(batch_labels.cpu())
            n_seen += mu.size(0)
            if max_samples is not None and n_seen >= max_samples:
                break

    codes = torch.cat(mu_chunks, dim=0)
    targets = torch.cat(label_chunks, dim=0)
    if max_samples is not None:
        codes, targets = codes[:max_samples], targets[:max_samples]
    return codes.numpy(), targets.numpy()
# Train Beta-VAE Models
Train one model for each beta value, keep the best checkpoint according to validation total loss, and compare how the reconstruction and KL terms evolve. The plots are generated once at the end to stay lighter on a remote server.
def _snapshot_state(model):
    """Return a detached CPU copy of ``model.state_dict()`` that later training cannot modify."""
    return {key: value.detach().cpu().clone() for key, value in model.state_dict().items()}


def _train_one_beta(beta):
    """Train a fresh beta-VAE, restore its best-validation weights and evaluate it on the test set.

    Returns:
        Dict with ``best_epoch``, ``best_val_total``, ``test_metrics``, ``history``
        and ``state_dict`` (the restored CPU weights).
    """
    num_epochs = hyperparameters['num_epochs']
    model = build_model()
    optimizer, scheduler = build_optimizer(model)
    history = {
        key: []
        for key in (
            'train_total', 'train_recon', 'train_kl',
            'val_total', 'val_recon', 'val_kl',
            'learning_rates',
        )
    }
    best_val_total = float('inf')
    best_epoch = 0
    best_model_state = None

    for epoch in range(num_epochs):
        epoch_start = time.time()
        train_metrics = train_epoch(model, train_loader, optimizer, beta)
        val_metrics = evaluate(model, val_loader, beta)
        if scheduler is not None:
            scheduler.step(val_metrics['total'])
        current_lr = optimizer.param_groups[0]['lr']
        epoch_time = time.time() - epoch_start

        for split, metrics in (('train', train_metrics), ('val', val_metrics)):
            for term in ('total', 'recon', 'kl'):
                history[f'{split}_{term}'].append(metrics[term])
        history['learning_rates'].append(current_lr)

        if val_metrics['total'] < best_val_total:
            best_val_total = val_metrics['total']
            best_epoch = epoch + 1
            best_model_state = _snapshot_state(model)

        print(
            f"Epoch {epoch + 1:02d}/{num_epochs:02d} | "
            f"time {epoch_time:.1f}s | lr {current_lr:.2e} | "
            f"train total {train_metrics['total']:.3f} | val total {val_metrics['total']:.3f} | "
            f"val recon {val_metrics['recon']:.3f} | val kl {val_metrics['kl']:.3f}"
        )

    if best_model_state is None:
        # No epoch ran, or every validation loss was NaN/inf (any comparison with NaN is
        # False), so no checkpoint was stored. Keep the final weights instead of crashing
        # in load_state_dict(None).
        print(f"Warning: no finite validation loss for beta={beta}; keeping the final weights.")
        best_epoch = len(history['val_total'])
        best_model_state = _snapshot_state(model)

    # load_state_dict copies into the existing (already on-device) parameters.
    model.load_state_dict(best_model_state)
    test_metrics = evaluate(model, test_loader, beta)
    print(
        f"Best beta={beta} model at epoch {best_epoch}: "
        f"val total {best_val_total:.3f} | "
        f"test total {test_metrics['total']:.3f} | "
        f"test recon {test_metrics['recon']:.3f} | test kl {test_metrics['kl']:.3f}"
    )
    return {
        'best_epoch': best_epoch,
        'best_val_total': best_val_total,
        'test_metrics': test_metrics,
        'history': history,
        'state_dict': best_model_state,
    }


def _plot_validation_curves(results, betas):
    """Plot validation total/reconstruction/KL loss per epoch for every trained beta."""
    panels = (
        ('val_total', 'Validation total loss', 'Loss'),
        ('val_recon', 'Validation reconstruction loss', 'BCE'),
        ('val_kl', 'Validation KL loss', 'KL'),
    )
    fig, axes = plt.subplots(1, len(panels), figsize=(15, 4))
    for axis, (key, title, ylabel) in zip(axes, panels):
        for beta in betas:
            values = results[beta]['history'][key]
            axis.plot(np.arange(1, len(values) + 1), values, marker='o', label=f'beta={beta}')
        axis.set_title(title)
        axis.set_xlabel('Epoch')
        axis.set_ylabel(ylabel)
        axis.grid(alpha=0.3)
        axis.legend()
    plt.tight_layout()
    plt.show()


experiment_results = {}
for beta in hyperparameters['beta_values']:
    print(f"\n=== Training beta={beta} ===")
    experiment_results[beta] = _train_one_beta(beta)

_plot_validation_curves(experiment_results, hyperparameters['beta_values'])
# t-SNE Comparison and VAE Manifolds
Compare the 2D t-SNE embeddings of the learned latent means for different beta values, then decode a grid of latent points for each beta-VAE to visualize how the manifold changes with KL regularization.
beta_values = hyperparameters['beta_values']
n_betas = len(beta_values)
fig, axes = plt.subplots(1, n_betas, figsize=(7 * n_betas, 5), squeeze=False)
axes = axes.ravel()

for axis, beta in zip(axes, beta_values):
    # Rebuild the architecture and restore the best checkpoint stored for this beta.
    model = build_model()
    model.load_state_dict(experiment_results[beta]['state_dict'])
    model = model.to(device)
    model.eval()

    # Posterior means of the first `tsne_samples` test images (test_loader is not shuffled).
    latent_codes, labels = collect_latent_codes(
        model,
        test_loader,
        max_samples=hyperparameters['tsne_samples'],
    )

    # Cap the perplexity below the number of samples.
    tsne_perplexity = min(hyperparameters['tsne_perplexity'], len(latent_codes) - 1)
    latent_2d = TSNE(
        n_components=2,
        perplexity=tsne_perplexity,
        init='pca',
        learning_rate=200.0,
        random_state=hyperparameters['tsne_random_state'],
    ).fit_transform(latent_codes)

    scatter = axis.scatter(
        latent_2d[:, 0], latent_2d[:, 1],
        c=labels, cmap='tab10', s=10, alpha=0.75,
    )
    axis.set_title(f'MNIST latent t-SNE - beta {beta}')
    axis.set_xlabel('t-SNE component 1')
    axis.set_ylabel('t-SNE component 2')
    axis.grid(alpha=0.15)
    plt.colorbar(scatter, ax=axis, ticks=list(range(10))).set_label('Digit label')

plt.tight_layout()
plt.show()
assert hyperparameters['latent_dim'] == 2, 'Manifold visualization assumes a 2D latent space.'
beta_values = hyperparameters['beta_values']
grid_size = hyperparameters['manifold_grid_size']

# Evenly spaced quantiles of N(0, 1) mapped through the inverse CDF, so the grid
# covers the central probability mass of the prior.
normal_distribution = torch.distributions.Normal(0, 1)
latent_axis = normal_distribution.icdf(
    torch.linspace(
        hyperparameters['manifold_quantile_min'],
        hyperparameters['manifold_quantile_max'],
        grid_size,
        device=device,
    )
)
# Row-major grid: code index r * grid_size + c is (z1=latent_axis[c], z2=latent_axis[r]).
mesh_y, mesh_x = torch.meshgrid(latent_axis, latent_axis, indexing='ij')
grid_latent_codes = torch.stack([mesh_x.reshape(-1), mesh_y.reshape(-1)], dim=1)
latent_min = float(latent_axis.min().cpu())
latent_max = float(latent_axis.max().cpu())

fig, axes = plt.subplots(1, len(beta_values), figsize=(5 * len(beta_values), 5), squeeze=False)
axes = axes.ravel()
for axis, beta in zip(axes, beta_values):
    manifold_model = build_model()
    manifold_model.load_state_dict(experiment_results[beta]['state_dict'])
    manifold_model.eval()
    with torch.no_grad():
        decoded_images = manifold_model.decode_to_probs(grid_latent_codes).cpu().squeeze(1)

    tile_h, tile_w = decoded_images.shape[-2:]
    # Reverse the tile rows so the largest z2 is in the top row. Then stitch the
    # (row, col, h, w) tiles into one (row*h, col*w) canvas. Drawing with
    # origin='upper' keeps each digit upright; origin='lower' would flip every
    # tile vertically, rendering the digits upside-down.
    tiles = decoded_images.reshape(grid_size, grid_size, tile_h, tile_w).flip(0)
    canvas = tiles.permute(0, 2, 1, 3).reshape(grid_size * tile_h, grid_size * tile_w).numpy()

    axis.imshow(
        canvas,
        cmap='gray',
        origin='upper',
        extent=[latent_min, latent_max, latent_min, latent_max],  # (left, right, bottom, top)
    )
    axis.set_title(f'VAE manifold - beta {beta}')
    axis.set_xlabel('z1')
    axis.set_ylabel('z2')
    axis.grid(False)
plt.tight_layout()
plt.show()