Execution environment

The code below requires the packages imported in the import block (torch, torchvision, matplotlib, numpy, etc.); scikit-learn and torchinfo are installed on the fly if they are missing. A practical option is to run the notebook in Google Colab with a GPU runtime enabled.

Import the required packages

Import all the packages required for this notebook.

import copy
import time
 
import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
 
try:
    from sklearn.manifold import TSNE
except ImportError:
    !pip -q install scikit-learn
    from sklearn.manifold import TSNE
    
print("PyTorch version:", torch.__version__)

Hyperparameters and Options

Set the hyperparameters

# Single source of truth for every tunable value used in this notebook.
hyperparameters = dict(
    # optimisation
    learning_rate=1e-3,
    batch_size=128,
    num_epochs=20,
    validation_split=0.2,
    # model / denoising task
    latent_dim=32,
    noise_std=0.20,
    weight_decay=1e-5,
    # data loading
    num_workers=2,
    use_persistent_workers=False,
    # learning-rate schedule
    use_lr_scheduler=True,
    lr_scheduler_patience=3,
    lr_scheduler_factor=0.5,
    decoder_upsampling_mode='nearest',
    # t-SNE visualisation
    tsne_samples=3000,
    tsne_perplexity=30,
    tsne_random_state=42,
    # reproducibility
    random_seed=42,
)
 
# Seed every random number generator used below so runs are repeatable.
seed = hyperparameters['random_seed']
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

np.random.seed(seed)
torch.manual_seed(seed)
if device.type == 'cuda':
    # Only reached when CUDA is available (see the device selection above).
    torch.cuda.manual_seed_all(seed)

print('Using device:', device)
print('Random seed:', seed)

Model architecture

Define a convolutional denoising autoencoder.

  • encode: maps a 28x28 MNIST digit to a compact latent code
  • decode: reconstructs the clean digit using upsampling + Conv2d layers
  • forward: returns both the reconstruction and the latent code
class ConvDenoisingAutoencoder(nn.Module):
    """Convolutional denoising autoencoder for single-channel 28x28 images.

    Three stride-2 convolutions reduce the input to a 128x4x4 feature map,
    which a linear layer projects to ``latent_dim`` values. The decoder maps
    the code back to 128x4x4 and restores the 28x28 resolution with
    Upsample + Conv2d stages, ending in a Sigmoid so outputs lie in [0, 1].
    Conv2d and Linear weights are Xavier-uniform initialised with zero biases.

    Args:
        latent_dim: number of values in the latent code.
        upsampling_mode: ``mode`` passed to every ``nn.Upsample`` layer.
    """

    # Shape of the feature map on both sides of the latent bottleneck.
    _BOTTLENECK_SHAPE = (128, 4, 4)

    def __init__(self, latent_dim=32, upsampling_mode='nearest'):
        super().__init__()
        channels, height, width = self._BOTTLENECK_SHAPE
        flat_features = channels * height * width

        # Encoder: each stage halves the resolution (28 -> 14 -> 7 -> 4).
        encoder_layers = []
        in_channels = 1
        for out_channels in (32, 64, channels):
            encoder_layers.append(
                nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=2, padding=1)
            )
            encoder_layers.append(nn.GELU())
            in_channels = out_channels
        self.encoder = nn.Sequential(*encoder_layers)

        self.to_latent = nn.Linear(flat_features, latent_dim)
        self.from_latent = nn.Linear(latent_dim, flat_features)

        # Decoder: 4 -> 7 (explicit size), then two 2x upsamplings (14, 28).
        decoder_stages = (
            ({'size': (7, 7)}, channels, 64),
            ({'scale_factor': 2}, 64, 32),
            ({'scale_factor': 2}, 32, 16),
        )
        decoder_layers = []
        for upsample_kwargs, stage_in, stage_out in decoder_stages:
            decoder_layers.append(nn.Upsample(mode=upsampling_mode, **upsample_kwargs))
            decoder_layers.append(nn.Conv2d(stage_in, stage_out, kernel_size=3, padding=1))
            decoder_layers.append(nn.GELU())
        decoder_layers.append(nn.Conv2d(16, 1, kernel_size=3, padding=1))
        decoder_layers.append(nn.Sigmoid())
        self.decoder = nn.Sequential(*decoder_layers)

        self.apply(self._initialize_weights)

    def _initialize_weights(self, module):
        """Xavier-initialise Conv2d/Linear weights and zero their biases."""
        if not isinstance(module, (nn.Conv2d, nn.Linear)):
            return
        nn.init.xavier_uniform_(module.weight)
        if module.bias is not None:
            nn.init.zeros_(module.bias)

    def encode(self, x):
        """Map a batch of images to latent codes of shape (batch, latent_dim)."""
        features = self.encoder(x).flatten(start_dim=1)
        return self.to_latent(features)

    def decode(self, latent_code):
        """Reconstruct images from a batch of latent codes."""
        features = self.from_latent(latent_code)
        feature_map = features.view(latent_code.size(0), *self._BOTTLENECK_SHAPE)
        return self.decoder(feature_map)

    def forward(self, x):
        """Return ``(reconstruction, latent_code)`` for the input batch."""
        latent_code = self.encode(x)
        return self.decode(latent_code), latent_code

Building blocks for training

Create the objects required for reconstruction training:

  • the denoising autoencoder
  • the reconstruction loss
  • the optimizer
  • an optional learning-rate scheduler
# Network, moved to the selected device.
model = ConvDenoisingAutoencoder(
    latent_dim=hyperparameters['latent_dim'],
    upsampling_mode=hyperparameters.get('decoder_upsampling_mode', 'nearest'),
)
model = model.to(device)

# Pixel-wise mean squared error between reconstruction and clean target.
criterion = nn.MSELoss().to(device)

# AdamW with the configured learning rate and weight decay.
optimizer = optim.AdamW(
    params=model.parameters(),
    weight_decay=hyperparameters['weight_decay'],
    lr=hyperparameters['learning_rate'],
)

# Optional scheduler: shrink the LR when the monitored loss stops improving.
scheduler = None
if hyperparameters.get('use_lr_scheduler', False):
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(
        optimizer,
        mode='min',
        patience=hyperparameters['lr_scheduler_patience'],
        factor=hyperparameters['lr_scheduler_factor'],
    )

Model Summary and Smoke Test

Inspect the model structure and verify that it returns both a reconstruction with shape 1x28x28 and a latent vector with the requested dimensionality.

try:
    from torchinfo import summary
except ImportError:
    !pip -q install torchinfo
    from torchinfo import summary
 
# Layer-by-layer summary for a single 1x28x28 input (nothing is printed by
# torchinfo itself because verbose=0; only the totals below are shown).
model_summary = summary(
    model,
    verbose=0,
    depth=3,
    device=str(device),
    input_size=(1, 1, 28, 28),
    col_names=["input_size", "output_size", "num_params", "trainable"],
)

print(f"Total parameters: {model_summary.total_params:,}")
print(f"Trainable parameters: {model_summary.trainable_params:,}")

# Smoke test: push a random batch through the model without tracking gradients.
dummy_batch = torch.rand(8, 1, 28, 28, device=device)
with torch.no_grad():
    reconstructions, latent_codes = model(dummy_batch)

print('Reconstruction batch shape:', tuple(reconstructions.shape))
print('Latent batch shape:', tuple(latent_codes.shape))

# The reconstruction must mirror the input; the code must be (batch, latent_dim).
assert reconstructions.shape == dummy_batch.shape
assert latent_codes.shape == (8, hyperparameters['latent_dim'])
print('Smoke test passed.')

Create datasets

Split training / validation set.

Strategy: split MNIST training set into 80% train and 20% validation (controlled by validation_split), and keep the full MNIST test set for testing.

# Make sure both MNIST splits are on disk (downloads on first run).
mnist_train = datasets.MNIST(root='./data', train=True, download=True)
mnist_test = datasets.MNIST(root='./data', train=False, download=True)

# Sizes of the train/validation partitions of the official training set.
validation_split = hyperparameters.get('validation_split', 0.2)
train_size = int((1.0 - validation_split) * len(mnist_train))
val_size = len(mnist_train) - train_size

# A dedicated, seeded generator keeps the split identical across runs.
split_generator = torch.Generator()
split_generator.manual_seed(hyperparameters['random_seed'])
indices = torch.randperm(len(mnist_train), generator=split_generator).tolist()
train_indices, val_indices = indices[:train_size], indices[train_size:]

# Each subset wraps its own MNIST instance, so the two splits do not share
# a dataset object.
train_dataset, val_dataset = (
    torch.utils.data.Subset(
        datasets.MNIST(root='./data', train=True, download=False),
        split_indices,
    )
    for split_indices in (train_indices, val_indices)
)
test_dataset = datasets.MNIST(root='./data', train=False, download=False)

print(f'Train size: {train_size}, Validation size: {val_size}, Test size: {len(test_dataset)}')
print('Targets remain the clean MNIST digits; Gaussian noise is injected only into the input batch during training/evaluation.')

Data Check

A common problem in Machine Learning and Deep Learning is feeding incorrect data to the model, often due to mistakes in loading or preprocessing.

A good practice is to calculate and print basic statistics of the data and to look at a few samples before training.

# visual check of random samples from training set together with their dataset labels
def imshow(img):
    """Render ``img`` as a grayscale picture and show the figure."""
    # np.array accepts anything array-like (the notebook passes PIL images
    # and 2D numpy arrays).
    plt.imshow(np.array(img), cmap='gray')
    plt.show()
 
# Look at the first five training samples as stored in the dataset; the
# attributes printed below (size, mode) assume PIL images, i.e. that no
# tensor transform has been attached yet.
for i, (img, label) in enumerate(train_dataset[k] for k in range(5)):
    print(f'Sample {i+1}: Image type: {type(img)}, Size: {img.size}, Mode: {img.mode}, Label: {label}')
    imshow(img)

Data transforms

Data transforms are applied at batch generation time.

They serve to transform your data into what the neural network expects.

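Data should be converted to tensors whose shape corresponds to the network input and possibly normalized (a common choice is to make the values roughly 0-centered, e.g. in the range [-1, 1]). In this notebook only `ToTensor()` is applied, so pixel values stay in the range [0, 1], which matches the Sigmoid output of the decoder.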

# ToTensor converts each PIL image to a float tensor with values in [0, 1].
transform = transforms.ToTensor()

# Attach the image transform to every underlying MNIST dataset and make sure
# no target transform is set (labels are returned as stored).
for mnist_split in (train_dataset.dataset, val_dataset.dataset, test_dataset):
    mnist_split.transform = transform
    mnist_split.target_transform = None
 
def add_gaussian_noise(images, noise_std):
    """Corrupt ``images`` with zero-mean Gaussian noise, clipped to [0, 1].

    Args:
        images: tensor of pixel values.
        noise_std: standard deviation of the noise. When it is ``<= 0`` the
            input tensor is returned as is (same object, no clipping).

    Returns:
        A new tensor ``clamp(images + N(0, noise_std^2), 0, 1)``, or
        ``images`` itself when noise is disabled.
    """
    if noise_std <= 0:
        return images
    noise = torch.mul(torch.randn_like(images), noise_std)
    return torch.clamp(torch.add(images, noise), 0.0, 1.0)


print('Transforms configured for reconstruction in the [0, 1] pixel range.')

Data loaders

DataLoaders are built-in PyTorch objects that sample batches from datasets.

# Build the three loaders from one shared set of worker/memory options.
loader_num_workers = hyperparameters.get('num_workers', 2)
# Persistent workers are only requested when there is at least one worker.
use_persistent_workers = hyperparameters.get('use_persistent_workers', True) and loader_num_workers > 0
# Pinned host memory is enabled only when a CUDA device is available.
pin_memory = torch.cuda.is_available()

shared_loader_options = dict(
    num_workers=loader_num_workers,
    pin_memory=pin_memory,
    persistent_workers=use_persistent_workers,
)

# Training batches are reshuffled every epoch; evaluation loaders keep a
# fixed order and use a larger batch of 512.
train_loader = DataLoader(
    train_dataset,
    batch_size=hyperparameters['batch_size'],
    shuffle=True,
    **shared_loader_options,
)
val_loader = DataLoader(val_dataset, batch_size=512, shuffle=False, **shared_loader_options)
test_loader = DataLoader(test_dataset, batch_size=512, shuffle=False, **shared_loader_options)

Data check after transforms

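To check what the network will see at train/test time, fetch samples after the data transforms defined above have been attached: the transforms are applied by the dataset every time a sample is read, whether it is indexed directly (as below) or drawn through a dataloader.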

# Check a few samples exactly as the network will receive them, i.e. after the
# tensor transform has been attached to the dataset.
for i in range(5):
    img, label = train_dataset[i]
    # Report shape, dtype and value range: these are the properties that go
    # wrong most often when loading or preprocessing is misconfigured.
    print(
        f'Sample {i+1}: Tensor shape: {tuple(img.shape)}, dtype: {img.dtype}, '
        f'min: {img.min().item():.3f}, max: {img.max().item():.3f}, Label: {label}'
    )
    # Drop the channel dimension so matplotlib receives a 2D array.
    imshow(img.squeeze().numpy())

Train function

Train the autoencoder for one epoch by feeding noisy digits to the encoder and comparing the reconstruction against the clean image.

def train(model, train_loader, criterion, optimizer, noise_std):
    """Train the autoencoder for one epoch on noisy inputs.

    Every clean batch is corrupted with ``add_gaussian_noise`` and the model
    is optimised to reconstruct the clean images from the noisy ones.

    Args:
        model: network returning ``(reconstruction, latent_code)``.
        train_loader: iterable of ``(images, labels)`` batches; labels are
            ignored.
        criterion: loss called as ``criterion(reconstruction, clean)``; it is
            expected to return the batch mean (the notebook uses
            ``nn.MSELoss`` with its default reduction).
        optimizer: optimizer updating ``model``'s parameters.
        noise_std: standard deviation of the injected Gaussian noise.

    Returns:
        Mean per-sample loss over the samples processed in this epoch.

    Raises:
        ValueError: if ``train_loader`` yields no samples.
    """
    model.train()
    running_loss = 0.0
    num_samples = 0

    for images, _ in train_loader:
        clean_images = images.to(device, non_blocking=True)
        noisy_images = add_gaussian_noise(clean_images, noise_std)

        optimizer.zero_grad(set_to_none=True)
        reconstructed_images, _ = model(noisy_images)
        loss = criterion(reconstructed_images, clean_images)
        loss.backward()
        optimizer.step()

        # Weight the batch mean by the batch size so a smaller final batch
        # contributes proportionally.
        batch_size = clean_images.size(0)
        running_loss += loss.item() * batch_size
        num_samples += batch_size

    # Normalise by the samples actually seen rather than len(dataset): the
    # two differ when the loader drops the last batch or uses a sampler.
    if num_samples == 0:
        raise ValueError('train_loader yielded no samples')
    return running_loss / num_samples

Validation / Latent Extraction

Use one function to evaluate reconstruction loss and another to collect latent codes for the final t-SNE visualization.

def evaluate(model, data_loader, criterion, noise_std):
    """Compute the mean denoising reconstruction loss without updating weights.

    Inputs are corrupted with freshly sampled Gaussian noise on every call, so
    the returned value can vary slightly between calls when ``noise_std > 0``.

    Args:
        model: network returning ``(reconstruction, latent_code)``.
        data_loader: iterable of ``(images, labels)`` batches; labels are
            ignored.
        criterion: loss called as ``criterion(reconstruction, clean)``; it is
            expected to return the batch mean.
        noise_std: standard deviation of the injected Gaussian noise.

    Returns:
        Mean per-sample loss over the samples processed.

    Raises:
        ValueError: if ``data_loader`` yields no samples.
    """
    model.eval()
    running_loss = 0.0
    num_samples = 0

    with torch.no_grad():
        for images, _ in data_loader:
            clean_images = images.to(device, non_blocking=True)
            noisy_images = add_gaussian_noise(clean_images, noise_std)
            reconstructed_images, _ = model(noisy_images)
            loss = criterion(reconstructed_images, clean_images)

            batch_size = clean_images.size(0)
            running_loss += loss.item() * batch_size
            num_samples += batch_size

    # Normalise by the samples actually seen rather than len(dataset): the
    # two differ when the loader drops the last batch or uses a sampler.
    if num_samples == 0:
        raise ValueError('data_loader yielded no samples')
    return running_loss / num_samples
 
 
def collect_latent_codes(model, data_loader, noise_std=0.0, max_samples=None):
    """Gather latent codes and labels from ``data_loader``.

    Args:
        model: network returning ``(reconstruction, latent_code)``.
        data_loader: iterable of ``(images, labels)`` batches.
        noise_std: standard deviation of the noise added to the inputs before
            encoding (no noise by default).
        max_samples: optional cap; iteration stops once at least this many
            codes were gathered and the outputs are truncated to the cap.

    Returns:
        Tuple ``(latent_codes, labels)`` of numpy arrays, concatenated along
        the first axis in loader order.
    """
    model.eval()
    code_chunks, label_chunks = [], []
    num_collected = 0
    has_limit = max_samples is not None

    with torch.no_grad():
        for images, batch_labels in data_loader:
            model_input = add_gaussian_noise(
                images.to(device, non_blocking=True), noise_std
            )
            _, batch_codes = model(model_input)

            # Keep results on the CPU so they can be converted to numpy later.
            code_chunks.append(batch_codes.cpu())
            label_chunks.append(batch_labels.cpu())
            num_collected += batch_codes.size(0)

            if has_limit and num_collected >= max_samples:
                break

    all_codes = torch.cat(code_chunks, dim=0)
    all_labels = torch.cat(label_chunks, dim=0)

    # The last batch may overshoot the cap, so trim both arrays to it.
    if has_limit:
        all_codes, all_labels = all_codes[:max_samples], all_labels[:max_samples]

    return all_codes.numpy(), all_labels.numpy()

Train Model / Evaluate Reconstruction

Train on noisy MNIST digits, monitor reconstruction loss on the validation split, and keep the best model according to validation loss. The plots are generated once at the end to stay lighter on a remote server.

# Per-epoch history used for the plots generated after training.
train_losses = []
val_losses = []
learning_rates = []

# Best checkpoint so far (by validation loss), kept in memory.
best_val_loss = float('inf')
best_val_loss_epoch = 0
best_model_state = copy.deepcopy(model.state_dict())
noise_std = hyperparameters.get('noise_std', 0.0)

for epoch in range(hyperparameters['num_epochs']):
    epoch_start = time.time()

    # Read the learning rate BEFORE the scheduler step so the value that is
    # logged and plotted is the one this epoch was actually trained with.
    current_lr = optimizer.param_groups[0]['lr']

    train_loss = train(model, train_loader, criterion, optimizer, noise_std)
    val_loss = evaluate(model, val_loader, criterion, noise_std)

    # ReduceLROnPlateau is driven by the monitored metric (validation loss).
    if scheduler is not None:
        scheduler.step(val_loss)

    epoch_time = time.time() - epoch_start
    train_losses.append(train_loss)
    val_losses.append(val_loss)
    learning_rates.append(current_lr)

    if val_loss < best_val_loss:
        best_val_loss = val_loss
        best_val_loss_epoch = epoch + 1
        # Deep copy: state_dict() returns references to the live tensors,
        # which later optimizer steps would overwrite.
        best_model_state = copy.deepcopy(model.state_dict())

    print(
        f"Epoch {epoch + 1:02d}/{hyperparameters['num_epochs']:02d} | "
        f"time {epoch_time:.1f}s | lr {current_lr:.2e} | "
        f"train loss {train_loss:.6f} | val loss {val_loss:.6f}"
    )

# Restore the best weights before measuring performance on the test set.
model.load_state_dict(best_model_state)
test_loss = evaluate(model, test_loader, criterion, noise_std)
print(f'Best validation loss: {best_val_loss:.6f} at epoch {best_val_loss_epoch}')
print(f'Test reconstruction loss: {test_loss:.6f}')
 
# One figure, two panels: loss curves on the left, learning rate on the right.
epochs = np.arange(1, len(train_losses) + 1)
fig, axes = plt.subplots(1, 2, figsize=(12, 4))
loss_ax, lr_ax = axes

for loss_history, curve_label in ((train_losses, 'Train'), (val_losses, 'Validation')):
    loss_ax.plot(epochs, loss_history, marker='o', label=curve_label)
loss_ax.set_title('Reconstruction loss')
loss_ax.set_xlabel('Epoch')
loss_ax.set_ylabel('MSE loss')
loss_ax.grid(alpha=0.3)
loss_ax.legend()

lr_ax.plot(epochs, learning_rates, marker='o', color='tab:orange')
lr_ax.set_title('Learning rate')
lr_ax.set_xlabel('Epoch')
lr_ax.set_ylabel('LR')
# Log scale: scheduler reductions are multiplicative.
lr_ax.set_yscale('log')
lr_ax.grid(alpha=0.3)

plt.tight_layout()
plt.show()

Reconstructions and Latent Space

Visualize clean inputs, noisy inputs, reconstructions, and a 2D t-SNE projection of the latent codes.

# Show clean / noisy / reconstructed versions of the first test digits.
model.eval()
num_examples = 10
example_images, example_labels = next(iter(test_loader))
# Never request more columns than the batch actually contains.
num_examples = min(num_examples, example_images.size(0))
clean_images = example_images[:num_examples].to(device)
example_labels = example_labels[:num_examples]
noisy_images = add_gaussian_noise(clean_images, hyperparameters['noise_std'])

with torch.no_grad():
    reconstructed_images, _ = model(noisy_images)

# Move everything back to the CPU for plotting.
clean_images = clean_images.cpu()
noisy_images = noisy_images.cpu()
reconstructed_images = reconstructed_images.cpu()

# squeeze=False keeps `axes` two-dimensional even for a single column.
fig, axes = plt.subplots(
    3, num_examples, figsize=(1.5 * num_examples, 4.5), squeeze=False
)
for idx in range(num_examples):
    axes[0, idx].imshow(clean_images[idx, 0], cmap='gray')
    axes[0, idx].set_title(f'label {int(example_labels[idx])}')
    axes[1, idx].imshow(noisy_images[idx, 0], cmap='gray')
    axes[2, idx].imshow(reconstructed_images[idx, 0], cmap='gray')

# Hide ticks and frames only. axis('off') would also hide the axis labels,
# so the row captions set below would never be drawn.
for ax in axes.flat:
    ax.set_xticks([])
    ax.set_yticks([])
    for spine in ax.spines.values():
        spine.set_visible(False)

axes[0, 0].set_ylabel('Clean', fontsize=12)
axes[1, 0].set_ylabel('Noisy', fontsize=12)
axes[2, 0].set_ylabel('Recon', fontsize=12)
plt.suptitle('MNIST denoising autoencoder examples')
plt.tight_layout()
plt.show()
 
# Encode clean (noise-free) test digits and keep at most `tsne_samples` codes.
latent_codes, labels = collect_latent_codes(
    model,
    test_loader,
    max_samples=hyperparameters['tsne_samples'],
    noise_std=0.0,
)

# Cap the perplexity at one less than the number of collected samples.
tsne_perplexity = min(hyperparameters['tsne_perplexity'], len(latent_codes) - 1)
tsne_options = {
    'n_components': 2,
    'perplexity': tsne_perplexity,
    'init': 'pca',
    'learning_rate': 200.0,
    'random_state': hyperparameters['tsne_random_state'],
}
tsne = TSNE(**tsne_options)
latent_2d = tsne.fit_transform(latent_codes)

# Scatter plot of the 2D embedding, coloured by digit label.
tsne_fig, tsne_ax = plt.subplots(figsize=(10, 8))
scatter = tsne_ax.scatter(
    latent_2d[:, 0],
    latent_2d[:, 1],
    c=labels,
    cmap='tab10',
    s=8,
    alpha=0.75,
)
tsne_ax.set_title('t-SNE of MNIST latent codes')
tsne_ax.set_xlabel('t-SNE component 1')
tsne_ax.set_ylabel('t-SNE component 2')
colorbar = tsne_fig.colorbar(scatter, ax=tsne_ax, ticks=list(range(10)))
colorbar.set_label('Digit label')
tsne_ax.grid(alpha=0.15)
plt.show()