Execution environment
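The code below requires the packages imported in the import cell below (torch, torchvision, matplotlib, numpy, scikit-learn, torchinfo); scikit-learn and torchinfo are installed automatically if they are missing. A practical option is to run the notebook in Google Colab with a GPU runtime enabled.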

Import the required packages
Import all the packages required for this notebook.
import copy
import time
import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
try:
from sklearn.manifold import TSNE
except ImportError:
!pip -q install scikit-learn
from sklearn.manifold import TSNE
print("PyTorch version:", torch.__version__)Hyperparameters and Options
Set the hyperparameters
# Every tunable setting of the notebook is collected in this one dictionary.
hyperparameters = dict(
    learning_rate=1e-3,              # AdamW step size
    batch_size=128,                  # training batch size
    num_epochs=20,
    validation_split=0.2,            # fraction of the MNIST training set held out
    latent_dim=32,                   # size of the autoencoder bottleneck
    noise_std=0.20,                  # std of the Gaussian noise added to inputs
    weight_decay=1e-5,               # AdamW weight decay
    num_workers=2,                   # DataLoader worker processes
    use_persistent_workers=False,
    use_lr_scheduler=True,
    lr_scheduler_patience=3,
    lr_scheduler_factor=0.5,
    decoder_upsampling_mode='nearest',
    tsne_samples=3000,               # number of latent codes projected with t-SNE
    tsne_perplexity=30,
    tsne_random_state=42,
    random_seed=42,
)

seed = hyperparameters['random_seed']
use_cuda = torch.cuda.is_available()

# Seed the PyTorch (CPU and, when present, CUDA) and NumPy global RNGs.
torch.manual_seed(seed)
if use_cuda:
    torch.cuda.manual_seed_all(seed)
np.random.seed(seed)

device = torch.device('cuda' if use_cuda else 'cpu')
print('Using device:', device)
print('Random seed:', seed)

# Model architecture
Define a convolutional denoising autoencoder.
- `encode`: maps a 28x28 MNIST digit to a compact latent code
- `decode`: reconstructs the clean digit using upsampling + Conv2d layers
- `forward`: returns both the reconstruction and the latent code
class ConvDenoisingAutoencoder(nn.Module):
    """Convolutional denoising autoencoder for single-channel 28x28 images.

    The encoder applies three stride-2 3x3 convolutions (spatial 28 -> 14 -> 7 -> 4,
    channels 1 -> 32 -> 64 -> 128) and projects the flattened 128x4x4 feature map to a
    ``latent_dim`` vector. The decoder projects the code back to a 128x4x4 map, then
    alternates ``nn.Upsample`` (to 7, 14 and 28 pixels) with 3x3 convolutions, and ends
    with a single-channel convolution followed by a sigmoid.

    Args:
        latent_dim: number of features in the latent code.
        upsampling_mode: interpolation ``mode`` given to every ``nn.Upsample`` layer.
    """

    # (channels, height, width) of the encoder's final feature map for a 28x28 input.
    _FEATURE_SHAPE = (128, 4, 4)

    def __init__(self, latent_dim=32, upsampling_mode='nearest'):
        super().__init__()
        channels, height, width = self._FEATURE_SHAPE
        flat_size = channels * height * width

        self.encoder = nn.Sequential(
            nn.Conv2d(1, 32, kernel_size=3, stride=2, padding=1),         # 28 -> 14
            nn.GELU(),
            nn.Conv2d(32, 64, kernel_size=3, stride=2, padding=1),        # 14 -> 7
            nn.GELU(),
            nn.Conv2d(64, channels, kernel_size=3, stride=2, padding=1),  # 7 -> 4
            nn.GELU(),
        )
        self.to_latent = nn.Linear(flat_size, latent_dim)
        self.from_latent = nn.Linear(latent_dim, flat_size)
        self.decoder = nn.Sequential(
            nn.Upsample(size=(7, 7), mode=upsampling_mode),          # 4 -> 7
            nn.Conv2d(channels, 64, kernel_size=3, padding=1),
            nn.GELU(),
            nn.Upsample(scale_factor=2, mode=upsampling_mode),       # 7 -> 14
            nn.Conv2d(64, 32, kernel_size=3, padding=1),
            nn.GELU(),
            nn.Upsample(scale_factor=2, mode=upsampling_mode),       # 14 -> 28
            nn.Conv2d(32, 16, kernel_size=3, padding=1),
            nn.GELU(),
            nn.Conv2d(16, 1, kernel_size=3, padding=1),
            nn.Sigmoid(),  # squash reconstructions into the [0, 1] pixel range
        )
        # Re-initialise every Conv2d / Linear layer (Xavier-uniform weights, zero biases).
        self.apply(self._initialize_weights)

    def _initialize_weights(self, module):
        """Xavier-uniform weights and zero biases for Conv2d/Linear; other modules untouched."""
        if not isinstance(module, (nn.Conv2d, nn.Linear)):
            return
        nn.init.xavier_uniform_(module.weight)
        if module.bias is not None:
            nn.init.zeros_(module.bias)

    def encode(self, x):
        """Map a batch of images to latent codes of shape (batch, latent_dim)."""
        feature_map = self.encoder(x)
        return self.to_latent(feature_map.flatten(1))

    def decode(self, latent_code):
        """Map latent codes back to images in [0, 1] with shape (batch, 1, 28, 28)."""
        flat_features = self.from_latent(latent_code)
        feature_map = flat_features.view(latent_code.shape[0], *self._FEATURE_SHAPE)
        return self.decoder(feature_map)

    def forward(self, x):
        """Return ``(reconstruction, latent_code)`` for the input batch ``x``."""
        latent_code = self.encode(x)
        return self.decode(latent_code), latent_code

# Building blocks for training
Create the objects required for reconstruction training:
- the denoising autoencoder
- the reconstruction loss
- the optimizer
- an optional learning-rate scheduler
# Denoising autoencoder configured from the hyperparameter table, placed on `device`.
model = ConvDenoisingAutoencoder(
    latent_dim=hyperparameters['latent_dim'],
    upsampling_mode=hyperparameters.get('decoder_upsampling_mode', 'nearest'),
)
model = model.to(device)

# Mean squared error between reconstruction and clean target.
criterion = nn.MSELoss().to(device)

optimizer = optim.AdamW(
    model.parameters(),
    lr=hyperparameters['learning_rate'],
    weight_decay=hyperparameters['weight_decay'],
)

# Optional plateau scheduler: multiplies the LR by `lr_scheduler_factor` once the value
# passed to scheduler.step() has not decreased for `lr_scheduler_patience` epochs.
scheduler = None
if hyperparameters.get('use_lr_scheduler', False):
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(
        optimizer,
        mode='min',
        factor=hyperparameters['lr_scheduler_factor'],
        patience=hyperparameters['lr_scheduler_patience'],
    )

# Model Summary and Smoke Test
Inspect the model structure and verify that it returns both a reconstruction with shape 1x28x28 and a latent vector with the requested dimensionality.
try:
from torchinfo import summary
except ImportError:
!pip -q install torchinfo
from torchinfo import summary
model_summary = summary(
model,
input_size=(1, 1, 28, 28),
device=str(device),
depth=3,
col_names=["input_size", "output_size", "num_params", "trainable"],
verbose=0,
)
print(f"Total parameters: {model_summary.total_params:,}")
print(f"Trainable parameters: {model_summary.trainable_params:,}")
with torch.no_grad():
dummy_batch = torch.rand(8, 1, 28, 28, device=device)
reconstructions, latent_codes = model(dummy_batch)
print('Reconstruction batch shape:', tuple(reconstructions.shape))
print('Latent batch shape:', tuple(latent_codes.shape))
assert reconstructions.shape == dummy_batch.shape
assert latent_codes.shape == (8, hyperparameters['latent_dim'])
print('Smoke test passed.')Create datasets
Split training / validation set.
Strategy: split the MNIST training set into 80% train and 20% validation (controlled by `validation_split`, using a seeded permutation so the split is reproducible), and keep the full MNIST test set for testing.
# Download MNIST (train and test splits); the instances created below read the cached files.
mnist_train = datasets.MNIST(root='./data', train=True, download=True)
mnist_test = datasets.MNIST(root='./data', train=False, download=True)

validation_split = hyperparameters.get('validation_split', 0.2)
# round() rather than int() so floating-point error in the product cannot drop a sample.
train_size = round((1.0 - validation_split) * len(mnist_train))
val_size = len(mnist_train) - train_size
# Both subsets must be non-empty: an empty split would make the per-epoch loss averages
# meaningless. This also rejects validation_split <= 0 or >= 1.
if not 0 < train_size < len(mnist_train):
    raise ValueError(
        f'validation_split={validation_split!r} yields train/val sizes '
        f'{train_size}/{val_size}; both must be non-empty'
    )

# Seeded permutation so the train/validation split is reproducible.
split_generator = torch.Generator().manual_seed(hyperparameters['random_seed'])
indices = torch.randperm(len(mnist_train), generator=split_generator).tolist()
train_indices = indices[:train_size]
val_indices = indices[train_size:]

# A Subset returns items from its underlying dataset, so each subset gets its own MNIST
# instance; that way the two can receive different transforms later if needed.
train_dataset = torch.utils.data.Subset(
    datasets.MNIST(root='./data', train=True, download=False),
    train_indices,
)
val_dataset = torch.utils.data.Subset(
    datasets.MNIST(root='./data', train=True, download=False),
    val_indices,
)
test_dataset = datasets.MNIST(root='./data', train=False, download=False)

print(f'Train size: {train_size}, Validation size: {val_size}, Test size: {len(test_dataset)}')
print('Targets remain the clean MNIST digits; Gaussian noise is injected only into the input batch during training/evaluation.')

# Data Check
A common problem in Machine Learning and Deep Learning is feeding incorrect data to the model, often due to mistakes in loading or preprocessing.
A good practice is to look at a few samples and print basic statistics about them (type, size, value range, label) before training.
def imshow(img):
    """Display one grayscale image (a PIL image or a 2-D array) in its own figure."""
    pixels = np.array(img)
    plt.imshow(pixels, cmap='gray')
    plt.show()


# Visual check of the first five samples of the (randomly permuted) training subset.
# No transform has been attached at this point, so items are the raw PIL images.
for sample_number in range(1, 6):
    img, label = train_dataset[sample_number - 1]
    print(f'Sample {sample_number}: Image type: {type(img)}, Size: {img.size}, Mode: {img.mode}, Label: {label}')
    imshow(img)

# Data transforms
Data transforms are applied at batch generation time.
They serve to transform your data into what the neural network expects.
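Data should be converted to tensors whose shape corresponds to the network input. Inputs are often also normalized to be roughly zero-centered (e.g. in the range $[-1, 1]$); here, however, `ToTensor` keeps pixels in $[0, 1]$, which matches the decoder's final Sigmoid and the clean reconstruction targets.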
# ToTensor converts each PIL image into a float tensor scaled to [0, 1].
transform = transforms.ToTensor()

# Attach the same transform to every underlying MNIST instance (the Subsets read their
# items through `.dataset`) and leave the labels untouched.
for mnist_source in (train_dataset.dataset, val_dataset.dataset, test_dataset):
    mnist_source.transform = transform
    mnist_source.target_transform = None
def add_gaussian_noise(images, noise_std):
    """Corrupt ``images`` with additive Gaussian noise and clip the result to [0, 1].

    Args:
        images: tensor of pixel values (assumed to lie in [0, 1], e.g. ToTensor output).
        noise_std: standard deviation of the noise; any value ``<= 0`` disables it.

    Returns:
        ``clamp(images + N(0, 1) * noise_std, 0, 1)`` with the same shape, dtype and device
        as ``images``. When noise is disabled, the input tensor itself is returned.
    """
    if noise_std <= 0:
        return images
    gaussian = torch.randn_like(images)
    corrupted = images + gaussian * noise_std
    return torch.clamp(corrupted, 0.0, 1.0)


print('Transforms configured for reconstruction in the [0, 1] pixel range.')

# Data loaders
DataLoaders are built-in PyTorch objects that sample batches from datasets.
# Prepare data loaders for training, validation, and testing.
loader_num_workers = hyperparameters.get('num_workers', 2)
# persistent_workers is only valid with worker processes; the fallback default matches the
# hyperparameter table (False).
use_persistent_workers = hyperparameters.get('use_persistent_workers', False) and loader_num_workers > 0
pin_memory = torch.cuda.is_available()
# Batch size of the validation/test loaders; override with hyperparameters['eval_batch_size'].
eval_batch_size = hyperparameters.get('eval_batch_size', 512)

# Worker / memory settings shared by all three loaders.
common_loader_kwargs = dict(
    num_workers=loader_num_workers,
    pin_memory=pin_memory,
    persistent_workers=use_persistent_workers,
)

train_loader = DataLoader(
    train_dataset,
    batch_size=hyperparameters['batch_size'],
    shuffle=True,
    **common_loader_kwargs,
)
val_loader = DataLoader(
    val_dataset,
    batch_size=eval_batch_size,
    shuffle=False,
    **common_loader_kwargs,
)
test_loader = DataLoader(
    test_dataset,
    batch_size=eval_batch_size,
    shuffle=False,
    **common_loader_kwargs,
)

# Data check after transforms
To check what the network will see at train/test time, inspect the data after the transforms defined above have been attached: indexing the datasets now returns tensors, and the data loaders yield batches of these same tensors.
# With the transform attached, dataset items are the float tensors the network receives.
# Print basic statistics for the first five training samples and display each one.
for i in range(5):
    img, label = train_dataset[i]
    print(
        f'Sample {i + 1}: shape {tuple(img.shape)}, dtype {img.dtype}, '
        f'min {img.min().item():.3f}, max {img.max().item():.3f}, Label: {label}'
    )
    # Drop the channel dimension and convert to a 2-D NumPy array for matplotlib.
    imshow(img.squeeze(0).numpy())

# Train function
Train the autoencoder for one epoch by feeding noisy digits to the encoder and comparing the reconstruction against the clean image.
def train(model, train_loader, criterion, optimizer, noise_std, device=None):
    """Run one training epoch of the denoising autoencoder.

    Each clean batch is corrupted with ``add_gaussian_noise``; the model reconstructs the
    noisy batch, and the loss is computed against the *clean* batch.

    Args:
        model: autoencoder whose ``forward`` returns ``(reconstruction, latent_code)``.
        train_loader: iterable of ``(images, labels)`` batches; labels are ignored.
        criterion: loss called as ``criterion(reconstruction, clean_images)``.
        optimizer: optimizer over ``model``'s parameters.
        noise_std: standard deviation of the input noise (``<= 0`` disables it).
        device: where batches are moved; defaults to the device of ``model``'s parameters.

    Returns:
        The epoch's average loss, weighting each batch by its number of samples.

    Raises:
        ValueError: if ``train_loader`` yields no samples.
    """
    if device is None:
        device = next(model.parameters()).device
    model.train()
    running_loss = 0.0
    num_samples = 0
    for images, _ in train_loader:
        clean_images = images.to(device, non_blocking=True)
        noisy_images = add_gaussian_noise(clean_images, noise_std)
        optimizer.zero_grad(set_to_none=True)
        reconstructed_images, _ = model(noisy_images)
        loss = criterion(reconstructed_images, clean_images)
        loss.backward()
        optimizer.step()
        batch_size = clean_images.size(0)
        # Assumes criterion averages over the batch (e.g. nn.MSELoss's default 'mean'):
        # weight by batch size and divide by the samples actually seen, which stays correct
        # with samplers or drop_last, unlike len(train_loader.dataset).
        running_loss += loss.item() * batch_size
        num_samples += batch_size
    if num_samples == 0:
        raise ValueError('train_loader yielded no samples')
    return running_loss / num_samples

# Validation / Latent Extraction
Use one function to evaluate reconstruction loss and another to collect latent codes for the final t-SNE visualization.
def evaluate(model, data_loader, criterion, noise_std, device=None):
    """Return the average reconstruction loss of ``model`` over ``data_loader``.

    Inputs are corrupted with ``add_gaussian_noise`` exactly as in training; the loss is
    measured against the clean images. No gradients are computed.

    Args:
        model: autoencoder whose ``forward`` returns ``(reconstruction, latent_code)``.
        data_loader: iterable of ``(images, labels)`` batches; labels are ignored.
        criterion: loss called as ``criterion(reconstruction, clean_images)``.
        noise_std: standard deviation of the input noise (``<= 0`` disables it).
        device: where batches are moved; defaults to the device of ``model``'s parameters.

    Returns:
        The average loss, weighting each batch by its number of samples.

    Raises:
        ValueError: if ``data_loader`` yields no samples.
    """
    if device is None:
        device = next(model.parameters()).device
    model.eval()
    running_loss = 0.0
    num_samples = 0
    with torch.no_grad():
        for images, _ in data_loader:
            clean_images = images.to(device, non_blocking=True)
            noisy_images = add_gaussian_noise(clean_images, noise_std)
            reconstructed_images, _ = model(noisy_images)
            loss = criterion(reconstructed_images, clean_images)
            batch_size = clean_images.size(0)
            # Assumes criterion averages over the batch; count the samples actually seen.
            running_loss += loss.item() * batch_size
            num_samples += batch_size
    if num_samples == 0:
        raise ValueError('data_loader yielded no samples')
    return running_loss / num_samples
def collect_latent_codes(model, data_loader, noise_std=0.0, max_samples=None, device=None):
    """Encode samples from ``data_loader`` and return their latent codes with their labels.

    Args:
        model: autoencoder whose ``forward`` returns ``(reconstruction, latent_code)``.
        data_loader: iterable of ``(images, labels)`` batches.
        noise_std: noise added to the inputs before encoding (``0`` = clean inputs).
        max_samples: if given, stop once this many samples are collected; the final batch
            is trimmed so exactly ``min(max_samples, available)`` rows are returned.
        device: where the model runs; defaults to the device of ``model``'s parameters.

    Returns:
        ``(latent_codes, labels)`` as NumPy arrays, stacked along the first dimension.

    Raises:
        ValueError: if ``max_samples`` is negative or ``data_loader`` yields no batches.
    """
    if max_samples is not None and max_samples < 0:
        raise ValueError(f'max_samples must be non-negative, got {max_samples}')
    if device is None:
        device = next(model.parameters()).device
    model.eval()
    latent_chunks = []
    label_chunks = []
    collected = 0
    with torch.no_grad():
        for images, batch_labels in data_loader:
            input_images = images.to(device, non_blocking=True)
            input_images = add_gaussian_noise(input_images, noise_std)
            _, batch_latent_codes = model(input_images)
            latent_chunks.append(batch_latent_codes.cpu())
            label_chunks.append(batch_labels.cpu())
            collected += batch_latent_codes.size(0)
            if max_samples is not None and collected >= max_samples:
                break
    if not latent_chunks:
        # torch.cat would otherwise fail with an unhelpful "non-empty list" error.
        raise ValueError('data_loader yielded no batches')
    latent_codes = torch.cat(latent_chunks, dim=0)
    labels = torch.cat(label_chunks, dim=0)
    if max_samples is not None:
        # The last batch may overshoot max_samples; trim it.
        latent_codes = latent_codes[:max_samples]
        labels = labels[:max_samples]
    return latent_codes.numpy(), labels.numpy()

# Train Model / Evaluate Reconstruction
Train on noisy MNIST digits, monitor reconstruction loss on the validation split, and keep the best model according to validation loss. The plots are generated once at the end to stay lighter on a remote server.
train_losses = []
val_losses = []
learning_rates = []
best_val_loss = float('inf')
best_val_loss_epoch = 0
best_model_state = copy.deepcopy(model.state_dict())
noise_std = hyperparameters.get('noise_std', 0.0)
num_epochs = hyperparameters['num_epochs']

for epoch in range(num_epochs):
    epoch_start = time.time()
    # Read the LR before training: this is the rate this epoch's updates actually use.
    # Reading it after scheduler.step() would report the next epoch's (possibly reduced) LR.
    epoch_lr = optimizer.param_groups[0]['lr']
    train_loss = train(model, train_loader, criterion, optimizer, noise_std)
    val_loss = evaluate(model, val_loader, criterion, noise_std)
    if scheduler is not None:
        # The scheduler is driven by the validation loss; a new LR applies from next epoch.
        scheduler.step(val_loss)
    epoch_time = time.time() - epoch_start

    train_losses.append(train_loss)
    val_losses.append(val_loss)
    learning_rates.append(epoch_lr)

    if val_loss < best_val_loss:
        best_val_loss = val_loss
        best_val_loss_epoch = epoch + 1
        # Deep copy: state_dict() returns references to the live parameter tensors.
        best_model_state = copy.deepcopy(model.state_dict())

    print(
        f"Epoch {epoch + 1:02d}/{num_epochs:02d} | "
        f"time {epoch_time:.1f}s | lr {epoch_lr:.2e} | "
        f"train loss {train_loss:.6f} | val loss {val_loss:.6f}"
    )

# Restore the weights from the epoch with the lowest validation loss before testing.
model.load_state_dict(best_model_state)
test_loss = evaluate(model, test_loader, criterion, noise_std)
print(f'Best validation loss: {best_val_loss:.6f} at epoch {best_val_loss_epoch}')
print(f'Test reconstruction loss: {test_loss:.6f}')

epochs = np.arange(1, len(train_losses) + 1)
fig, axes = plt.subplots(1, 2, figsize=(12, 4))
axes[0].plot(epochs, train_losses, marker='o', label='Train')
axes[0].plot(epochs, val_losses, marker='o', label='Validation')
axes[0].set_title('Reconstruction loss')
axes[0].set_xlabel('Epoch')
axes[0].set_ylabel('MSE loss')
axes[0].grid(alpha=0.3)
axes[0].legend()
axes[1].plot(epochs, learning_rates, marker='o', color='tab:orange')
axes[1].set_title('Learning rate')
axes[1].set_xlabel('Epoch')
axes[1].set_ylabel('LR')
axes[1].set_yscale('log')
axes[1].grid(alpha=0.3)
plt.tight_layout()
plt.show()

# Reconstructions and Latent Space
Visualize clean inputs, noisy inputs, reconstructions, and a 2D t-SNE projection of the latent codes.
# Qualitative check: clean digits, their noisy versions, and the model's reconstructions.
model.eval()
example_images, example_labels = next(iter(test_loader))
# Show up to 10 digits from the first test batch (fewer if that batch is smaller).
num_examples = min(10, example_images.size(0))
clean_images = example_images[:num_examples].to(device)
example_labels = example_labels[:num_examples]
noisy_images = add_gaussian_noise(clean_images, hyperparameters['noise_std'])
with torch.no_grad():
    reconstructed_images, _ = model(noisy_images)
clean_images = clean_images.cpu()
noisy_images = noisy_images.cpu()
reconstructed_images = reconstructed_images.cpu()

row_batches = (clean_images, noisy_images, reconstructed_images)
row_titles = ('Clean', 'Noisy', 'Recon')
# squeeze=False keeps `axes` 2-D even when num_examples == 1.
fig, axes = plt.subplots(3, num_examples, figsize=(1.5 * num_examples, 4.5), squeeze=False)
for idx in range(num_examples):
    axes[0, idx].set_title(f'label {int(example_labels[idx])}')
    for row, row_batch in enumerate(row_batches):
        ax = axes[row, idx]
        ax.imshow(row_batch[idx, 0], cmap='gray')
        # Hide ticks and frame but keep the axis: ax.axis('off') would also hide the
        # row labels set below.
        ax.set_xticks([])
        ax.set_yticks([])
        for spine in ax.spines.values():
            spine.set_visible(False)
for row, title in enumerate(row_titles):
    axes[row, 0].set_ylabel(title, fontsize=12)
plt.suptitle('MNIST denoising autoencoder examples')
plt.tight_layout()
plt.show()

# 2-D t-SNE projection of the latent codes of clean test digits.
latent_codes, labels = collect_latent_codes(
    model,
    test_loader,
    noise_std=0.0,
    max_samples=hyperparameters['tsne_samples'],
)
# scikit-learn's TSNE requires perplexity < n_samples.
tsne_perplexity = min(hyperparameters['tsne_perplexity'], len(latent_codes) - 1)
tsne = TSNE(
    n_components=2,
    perplexity=tsne_perplexity,
    init='pca',
    learning_rate=200.0,
    random_state=hyperparameters['tsne_random_state'],
)
latent_2d = tsne.fit_transform(latent_codes)
plt.figure(figsize=(10, 8))
scatter = plt.scatter(
    latent_2d[:, 0],
    latent_2d[:, 1],
    c=labels,
    cmap='tab10',
    s=8,
    alpha=0.75,
)
plt.title('t-SNE of MNIST latent codes')
plt.xlabel('t-SNE component 1')
plt.ylabel('t-SNE component 2')
colorbar = plt.colorbar(scatter, ticks=list(range(10)))
colorbar.set_label('Digit label')
plt.grid(alpha=0.15)
plt.show()