
Import the required packages
Import all the packages required for this notebook.
# import required packages for a training/testing experiment with MNIST image classification using a CNN
import time
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
import matplotlib.pyplot as plt
import torchvision
import numpy as np
print("PyTorch version:", torch.__version__)

# Hyperparameters and Options
Set the hyperparameters
# this will contain a list of all hyperparameters that will be used in our experiment like learning rate, momentum, batch size, etc.
# Single source of truth for every tunable value used in this experiment.
hyperparameters = dict(
    # optimisation (SGD)
    learning_rate=0.01,
    momentum=0.5,
    nesterov=True,
    batch_size=10,
    # regularisation / architecture switches
    label_smoothing=0.1,
    dropout_rate=0.1,
    use_batch_norm=True,
    # training schedule and data split
    num_epochs=30,
    validation_split=0.2,
    use_lr_scheduler=True,
    lr_scheduler_step_size=15,
    lr_scheduler_gamma=0.5,
    early_stopping_patience=10,
    # data augmentation
    use_data_augmentation=True,
    augmentation_probability=0.5,
    # data loading
    num_workers=12,
    use_persistent_workers=True,
    # reproducibility
    random_seed=42,
)

# Seed the NumPy and PyTorch random number generators with the same value.
seed = hyperparameters['random_seed']
np.random.seed(seed)
torch.manual_seed(seed)

# Run on the GPU when one is available, otherwise on the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
print('Using device:', device)
print('Random seed:', seed)

# Model architecture
Define the network
- init: initialization of convolutional and linear layers of the network
- forward: phase where input is propagated through convolutional and classification blocks to produce logits
class CNN(nn.Module):
    """Small convolutional classifier that maps an image batch to 10 logits.

    Layout: two 3x3 convolutions (32 channels), 2x2 max-pooling, two 3x3
    convolutions (64 channels), 2x2 max-pooling, flatten,
    Linear(64 * 7 * 7, 512), dropout, Linear(512, 10). GELU follows every
    convolution and the first linear layer. When batch normalisation is
    enabled it is applied between each of those layers and its activation.

    The first linear layer expects 64 * 7 * 7 features, which is what a
    single-channel 28 x 28 input yields after the two pooling stages.

    Args:
        dropout_rate: dropout probability applied before the output layer.
        use_batch_norm: create and use the BatchNorm layers when True.
    """

    def __init__(self, dropout_rate=0.2, use_batch_norm=False):
        super().__init__()
        self.use_batch_norm = use_batch_norm

        # All convolutions share the same kernel size and keep the spatial size.
        same_3x3 = dict(kernel_size=3, padding='same')
        self.conv1 = nn.Conv2d(1, 32, **same_3x3)
        self.conv2 = nn.Conv2d(32, 32, **same_3x3)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv3 = nn.Conv2d(32, 64, **same_3x3)
        self.conv4 = nn.Conv2d(64, 64, **same_3x3)

        if self.use_batch_norm:
            # bn1..bn4 follow conv1..conv4; bn_fc1 follows the first linear layer.
            for position, channels in enumerate((32, 32, 64, 64), start=1):
                setattr(self, f'bn{position}', nn.BatchNorm2d(channels))
            self.bn_fc1 = nn.BatchNorm1d(512)

        self.fc1 = nn.Linear(64 * 7 * 7, 512)
        self.fc2 = nn.Linear(512, 10)
        self.act = nn.GELU()
        self.dropout = nn.Dropout(dropout_rate)

        # Re-initialise every conv / linear layer (see _initialize_weights).
        self.apply(self._initialize_weights)

    def _initialize_weights(self, module):
        """Give conv and linear layers Xavier-uniform weights and zero biases."""
        if not isinstance(module, (nn.Conv2d, nn.Linear)):
            return
        nn.init.xavier_uniform_(module.weight)
        if module.bias is not None:
            nn.init.zeros_(module.bias)

    def _normalize(self, x, layer_name):
        """Apply the named BatchNorm layer, or pass ``x`` through when disabled."""
        if not self.use_batch_norm:
            return x
        return getattr(self, layer_name)(x)

    def forward(self, x):
        """Propagate ``x`` through both conv stages and the classifier head."""
        # first convolutional stage
        x = self.act(self._normalize(self.conv1(x), 'bn1'))
        x = self.act(self._normalize(self.conv2(x), 'bn2'))
        x = self.pool(x)

        # second convolutional stage
        x = self.act(self._normalize(self.conv3(x), 'bn3'))
        x = self.act(self._normalize(self.conv4(x), 'bn4'))
        x = self.pool(x)

        # classifier head: flatten everything except the batch dimension
        x = x.view(x.size(0), -1)
        x = self.act(self._normalize(self.fc1(x), 'bn_fc1'))
        return self.fc2(self.dropout(x))

# Building blocks for training
Instance of the network:
- Loss Function: to guide learning, quantify how close the actual output is to the desired output
- Optimizer: weight update method (stochastic gradient descent)
- Learning Rate Scheduler: adjust the learning rate based on the number of epochs
# Build the training components from the hyperparameter dict:
# the network, the loss, the optimizer and an optional step-wise LR scheduler.
model = CNN(
    dropout_rate=hyperparameters.get('dropout_rate', 0.2),
    use_batch_norm=hyperparameters.get('use_batch_norm', False),
)
model = model.to(device)

# cross-entropy on logits, with label smoothing
criterion = nn.CrossEntropyLoss(
    label_smoothing=hyperparameters.get('label_smoothing', 0.1)
).to(device)

# stochastic gradient descent with (optionally Nesterov) momentum
optimizer = optim.SGD(
    model.parameters(),
    lr=hyperparameters['learning_rate'],
    momentum=hyperparameters['momentum'],
    nesterov=hyperparameters['nesterov'],
)

# StepLR multiplies the learning rate by gamma every step_size scheduler steps;
# scheduler stays None when scheduling is switched off.
if hyperparameters.get('use_lr_scheduler', False):
    scheduler = optim.lr_scheduler.StepLR(
        optimizer,
        step_size=hyperparameters['lr_scheduler_step_size'],
        gamma=hyperparameters['lr_scheduler_gamma'],
    )
else:
    scheduler = None

# Model Summary and Trainable Parameters
Use summary to inspect model structure and count trainable parameters.
# Inspect the layer structure and count parameters with torchinfo,
# installing the package on first use.
try:
    from torchinfo import summary
except ImportError:
    # Install with the interpreter that is running this code. This is plain
    # Python (no IPython "!" shell escape needed) and cannot end up in a
    # different environment than the one the notebook kernel uses.
    import importlib
    import subprocess
    import sys

    subprocess.check_call([sys.executable, "-m", "pip", "install", "-q", "torchinfo"])
    importlib.invalidate_caches()  # make the freshly installed package importable
    from torchinfo import summary

model_summary = summary(
    model,
    input_size=(1, 1, 28, 28),  # one dummy single-channel 28x28 image
    device=str(device),
    depth=3,
    col_names=["input_size", "output_size", "num_params", "trainable"],
    # Request eval mode explicitly for torchinfo's dummy forward pass. With a
    # batch of one sample, BatchNorm1d cannot compute batch statistics in
    # training mode, and a training-mode pass would also touch the BatchNorm
    # running statistics before training starts.
    mode="eval",
    verbose=2,
)
print(f"Total parameters: {model_summary.total_params:,}")
print(f"Trainable parameters: {model_summary.trainable_params:,}")

# Create datasets
Split training / validation set.
Strategy
Split MNIST training set into 80% train and 20% validation (controlled by
validation_split), and keep the full MNIST test set for testing.
# download MNIST training and test datasets separately
mnist_train = datasets.MNIST(root='./data', train=True, download=True)
mnist_test = datasets.MNIST(root='./data', train=False, download=True)

# split the official MNIST training set into train/validation using a seeded
# permutation, so the split is reproducible for a given random_seed
validation_split = hyperparameters.get('validation_split', 0.2)
train_size = int((1.0 - validation_split) * len(mnist_train))
val_size = len(mnist_train) - train_size
split_generator = torch.Generator().manual_seed(hyperparameters['random_seed'])
indices = torch.randperm(len(mnist_train), generator=split_generator).tolist()
train_indices = indices[:train_size]
val_indices = indices[train_size:]

# create separate dataset objects so train/val can use different transforms later
train_dataset = datasets.MNIST(root='./data', train=True, download=False)
val_dataset = datasets.MNIST(root='./data', train=True, download=False)
test_dataset = mnist_test
train_dataset = torch.utils.data.Subset(train_dataset, train_indices)
val_dataset = torch.utils.data.Subset(val_dataset, val_indices)

# Normalisation statistics, computed on the train split only so that nothing
# from validation/test leaks into preprocessing.
# They are computed over all training pixels at once. Averaging per-batch
# standard deviations (the previous approach) ignores the variation between
# batch means and therefore does not give the standard deviation of the split.
# The raw uint8 images are scaled by 1/255, the same [0, 1] range that
# transforms.ToTensor() produces for these images.
train_pixels = mnist_train.data[torch.as_tensor(train_indices)].float().div(255.0)
mean = train_pixels.mean().item()
std = train_pixels.std().item()
del train_pixels  # release the temporary float copy of the training images

print(f'Mean: {mean}, Std: {std}')
print(f'Train size: {train_size}, Validation size: {val_size}, Test size: {len(test_dataset)}')

# Data Check
A common problem in Machine Learning and Deep Learning is feeding incorrect data to the model, often due to mistakes in loading or preprocessing.
A good practice is to calculate and print statistics.
# visual check of random samples from training set together with their dataset labels
def imshow(img):
    """Show ``img`` as a grayscale picture.

    The input is first converted with ``np.array``; in this notebook it is
    called with PIL images and with 2-D NumPy arrays.
    """
    plt.imshow(np.array(img), cmap='gray')
    plt.show()
# Look at the first five training samples straight from the dataset: print the
# image type, size, mode and label of each one, then display it.
for sample_idx in range(5):
    img, label = train_dataset[sample_idx]
    print(f'Sample {sample_idx+1}: Image type: {type(img)}, Size: {img.size}, Mode: {img.mode}, Label: {label}')
    imshow(img)

# Data transforms
Data transforms are applied at batch generation time.
They serve to transform your data into what the neural network expects.
Data should be converted to tensors whose shape corresponds to the network input and possibly normalized so as to be roughly 0-centered with unit standard deviation (here using the mean and standard deviation computed on the training split).
# create transforms like PIL image to torch tensor and data normalization
# also create a more complex data augmentation pipeline for MNIST
# probability with which each augmentation step is applied
augmentation_probability = hyperparameters.get('augmentation_probability', 0.5)

# fixed random permutation of pixel positions (same permutation for all samples),
# seeded so it is identical on every run
permutation_generator = torch.Generator().manual_seed(hyperparameters['random_seed'])
fixed_pixel_permutation = torch.randperm(28 * 28, generator=permutation_generator)


def apply_fixed_pixel_permutation(img_tensor):
    """Reorder the pixels of an image with the fixed global permutation.

    Args:
        img_tensor: tensor of shape (channels, height, width) whose
            height * width equals the length of ``fixed_pixel_permutation``
            (28 * 28 here).

    Returns:
        A tensor of the same shape in which every channel has had its pixels
        reordered by the same ``fixed_pixel_permutation``.
    """
    channels, height, width = img_tensor.shape
    # Permute inside each channel. Indexing the fully flattened tensor would
    # select only height * width values in total, which matches the requested
    # output shape only when there is exactly one channel.
    per_channel = img_tensor.reshape(channels, -1)
    return per_channel[:, fixed_pixel_permutation].reshape(channels, height, width)
# Tensor conversion followed by normalisation with the statistics computed above.
base_train_transforms = [
    transforms.ToTensor(),
    # (disabled) transforms.Lambda(apply_fixed_pixel_permutation) would go here
    transforms.Normalize((mean,), (std,)),
]

# MNIST-friendly augmentation: small geometric perturbations, each wrapped in
# RandomApply so that it fires with probability `augmentation_probability`.
augmentation_transforms = [
    transforms.RandomApply([perturbation], p=augmentation_probability)
    for perturbation in (
        transforms.RandomAffine(
            degrees=12,
            translate=(0.10, 0.10),
            scale=(0.95, 1.05),
            shear=8,
        ),
        transforms.RandomRotation(10),
    )
]

# Training pipeline: augmentation (if enabled) runs before tensor conversion.
if hyperparameters.get('use_data_augmentation', False):
    train_transform = transforms.Compose(augmentation_transforms + base_train_transforms)
else:
    train_transform = transforms.Compose(list(base_train_transforms))

# Evaluation pipeline (validation and test): no augmentation.
transform = transforms.Compose([
    transforms.ToTensor(),
    # (disabled) transforms.Lambda(apply_fixed_pixel_permutation) would go here
    transforms.Normalize((mean,), (std,)),
])

# Attach the pipelines. train/val are Subsets, so the transform is set on the
# dataset object each one wraps; labels are left untouched everywhere.
for split_subset, split_transform in ((train_dataset, train_transform), (val_dataset, transform)):
    split_subset.dataset.transform = split_transform
    split_subset.dataset.target_transform = None
test_dataset.transform = transform
test_dataset.target_transform = None

# Data loaders
DataLoaders are built-in PyTorch objects that sample batches from datasets.
# Build the training, validation and test loaders. Worker and memory settings
# are shared by all three. Only the training loader shuffles and uses the
# configured batch size; validation and test are read in order in batches of 512.
loader_num_workers = hyperparameters.get('num_workers', 2)
# persistent workers are only requested when worker processes are actually used
use_persistent_workers = hyperparameters.get('use_persistent_workers', True) and loader_num_workers > 0
# pinned host memory is only requested when CUDA is available
pin_memory = torch.cuda.is_available()

shared_loader_options = dict(
    num_workers=loader_num_workers,
    pin_memory=pin_memory,
    persistent_workers=use_persistent_workers,
)

train_loader = DataLoader(
    train_dataset,
    batch_size=hyperparameters['batch_size'],
    shuffle=True,
    **shared_loader_options,
)
val_loader = DataLoader(val_dataset, batch_size=512, shuffle=False, **shared_loader_options)
test_loader = DataLoader(test_dataset, batch_size=512, shuffle=False, **shared_loader_options)

# Data check after transforms
To check what the network will see at train/test time, samples have to be fetched through the datasets (directly, as below, or via the dataloaders built on them), so that the data transforms previously defined are applied.
# Show the first five training samples again, this time after the transforms
# attached above have been applied by the dataset.
for sample_idx in range(5):
    img, label = train_dataset[sample_idx]
    # tensor -> numpy array with the channel dimension removed, for plotting
    img = img.numpy().squeeze()
    imshow(img)

# Train function
It is preferable (but not mandatory) to embed training (1 epoch) code into a function, and call that function later during the training phase, at each epoch.
# create training loop function that will train the model for one epoch and return average loss and accuracy
def train(model, train_loader, criterion, optimizer, device=None):
    """Train ``model`` for one epoch and return ``(average_loss, accuracy)``.

    Args:
        model: network to optimise; it is switched to training mode.
        train_loader: iterable yielding ``(images, labels)`` batches.
        criterion: loss called as ``criterion(outputs, labels)``. Its value is
            multiplied by the batch size when accumulated, i.e. it is treated
            as a per-batch mean.
        optimizer: optimizer that updates the parameters of ``model``.
        device: device each batch is moved to. Defaults to the device of the
            model's first parameter (CPU for a model without parameters), so
            the function does not depend on a global ``device`` variable.

    Returns:
        Tuple ``(average_loss, accuracy)``: loss per sample and fraction of
        correctly classified samples, both measured over the samples that were
        actually processed in this epoch.

    Raises:
        ZeroDivisionError: if ``train_loader`` yields no samples.
    """
    if device is None:
        first_param = next(model.parameters(), None)
        device = first_param.device if first_param is not None else torch.device('cpu')

    model.train()  # enable training behaviour (dropout, batch-norm updates)
    loss_sum = 0.0
    correct = 0
    total = 0
    for images, labels in train_loader:
        images = images.to(device)
        labels = labels.to(device)

        optimizer.zero_grad()              # clear gradients of the previous step
        outputs = model(images)            # forward pass
        loss = criterion(outputs, labels)  # batch loss
        loss.backward()                    # backward pass
        optimizer.step()                   # weight update

        batch_size = labels.size(0)
        loss_sum += loss.item() * batch_size  # undo the per-batch averaging
        total += batch_size
        # predicted class = index of the largest logit
        correct += (outputs.argmax(dim=1) == labels).sum().item()

    # Normalise both metrics by the number of samples seen. Using
    # len(train_loader.dataset) instead would be wrong whenever the loader
    # skips samples (e.g. drop_last=True or a custom sampler).
    return loss_sum / total, correct / total

# Test function
It is preferable (but not mandatory) to place test code in a function and call it when needed.
For instance, during training for validation at each epoch, or after training for testing, or for deploying the model.
# test function that evaluates a given model on a given dataloader and returns the average loss and accuracy
def test(model, test_loader, criterion, device=None):
    """Evaluate ``model`` on ``test_loader`` and return ``(average_loss, accuracy)``.

    No gradients are computed and no weights are changed.

    Args:
        model: network to evaluate; it is switched to evaluation mode.
        test_loader: iterable yielding ``(images, labels)`` batches.
        criterion: loss called as ``criterion(outputs, labels)``. Its value is
            multiplied by the batch size when accumulated, i.e. it is treated
            as a per-batch mean.
        device: device each batch is moved to. Defaults to the device of the
            model's first parameter (CPU for a model without parameters), so
            the function does not depend on a global ``device`` variable.

    Returns:
        Tuple ``(average_loss, accuracy)``: loss per sample and fraction of
        correctly classified samples, both measured over the samples that were
        actually processed.

    Raises:
        ZeroDivisionError: if ``test_loader`` yields no samples.
    """
    if device is None:
        first_param = next(model.parameters(), None)
        device = first_param.device if first_param is not None else torch.device('cpu')

    model.eval()  # evaluation behaviour (no dropout, batch-norm uses running stats)
    loss_sum = 0.0
    correct = 0
    total = 0
    with torch.no_grad():  # gradients are not needed for evaluation
        for images, labels in test_loader:
            images = images.to(device)
            labels = labels.to(device)

            outputs = model(images)            # forward pass
            loss = criterion(outputs, labels)  # batch loss

            batch_size = labels.size(0)
            loss_sum += loss.item() * batch_size  # undo the per-batch averaging
            total += batch_size
            # predicted class = index of the largest logit
            correct += (outputs.argmax(dim=1) == labels).sum().item()

    # Normalise both metrics by the number of samples seen. Using
    # len(test_loader.dataset) instead would be wrong whenever the loader
    # skips samples (e.g. drop_last=True or a custom sampler).
    return loss_sum / total, correct / total


# The name starts with "test", so pytest would try to collect this helper as a
# test case whenever it is imported into a test module; opt out explicitly.
test.__test__ = False

# Train Model / Test Model
The code below also includes visual loss/accuracy monitoring during training, both on training and validation sets.
# run the training experiment using the previously defined training loop function, and after each epoch evaluate the model on the validation set, and update the plot accordingly, and also save the best model (which is the one that achieved the best accuracy on the validation set)
def _plot_training_history(train_losses, val_losses, train_accuracies, val_accuracies,
                           best_val_loss, best_val_loss_epoch,
                           best_val_accuracy, best_val_accuracy_epoch,
                           num_epochs):
    """Plot loss (log scale) and accuracy (in %) for all epochs run so far.

    The x-axis always spans the full planned run (``num_epochs``), so
    successive plots are directly comparable. The accuracy axis is fixed to
    the 95-100 % range.
    """
    epochs = list(range(1, len(train_losses) + 1))
    # one tick every 2 epochs for a 30-epoch run; scales with longer runs
    epoch_ticks = range(0, num_epochs + 1, max(1, num_epochs // 15))

    plt.figure(figsize=(12, 4))

    # left panel: losses
    plt.subplot(1, 2, 1)
    plt.plot(epochs, train_losses, label='Train Loss', marker='o')
    plt.plot(epochs, val_losses, label=f'Val Loss (best {best_val_loss:.4f} @ {best_val_loss_epoch})', marker='o')
    plt.title('Loss per epoch')
    plt.xlabel('Epoch')
    plt.ylabel('Loss')
    plt.yscale('log')
    plt.xlim(0, num_epochs)
    plt.xticks(epoch_ticks)
    plt.legend()

    # right panel: accuracies in percent
    plt.subplot(1, 2, 2)
    train_accuracies_pct = [x * 100 for x in train_accuracies]
    val_accuracies_pct = [x * 100 for x in val_accuracies]
    plt.plot(epochs, train_accuracies_pct, label='Train Accuracy (%)', marker='o')
    plt.plot(epochs, val_accuracies_pct, label=f'Val Accuracy (%) (best {best_val_accuracy*100:.2f}% @ {best_val_accuracy_epoch})', marker='o')
    plt.title('Accuracy per epoch')
    plt.xlabel('Epoch')
    plt.ylabel('Accuracy (%)')
    plt.xlim(0, num_epochs)
    plt.xticks(epoch_ticks)
    plt.yticks(range(95, 101, 1))
    plt.ylim(95, 100)
    plt.grid(axis='y', linestyle='--', alpha=0.5)
    plt.legend()

    plt.tight_layout()
    plt.show()


# per-epoch history, used for plotting
train_losses = []
val_losses = []
train_accuracies = []
val_accuracies = []

# best validation metrics seen so far and the (1-based) epoch they occurred in
best_val_accuracy = 0.0
best_val_accuracy_epoch = 0
best_val_loss = float('inf')
best_val_loss_epoch = 0
best_model_state = None  # weights of the epoch with the best validation accuracy

# NOTE(review): hyperparameters['early_stopping_patience'] is defined but not
# used by this loop; training always runs for num_epochs epochs.
num_epochs = hyperparameters['num_epochs']
for epoch in range(num_epochs):
    epoch_start = time.time()

    # learning rate in effect for this epoch (read before the scheduler advances it)
    current_lr = optimizer.param_groups[0]['lr']

    train_loss, train_accuracy = train(model, train_loader, criterion, optimizer)  # train for one epoch
    val_loss, val_accuracy = test(model, val_loader, criterion)  # evaluate on validation set
    if scheduler is not None:
        scheduler.step()

    train_losses.append(train_loss)
    val_losses.append(val_loss)
    train_accuracies.append(train_accuracy)
    val_accuracies.append(val_accuracy)

    epoch_time = time.time() - epoch_start
    print(f"Epoch {epoch+1}/{hyperparameters['num_epochs']}, Time: {epoch_time:.2f}s, LR: {current_lr:.6f}, Train Loss: {train_loss:.4f}, Val Loss: {val_loss:.4f}, Train Acc: {train_accuracy:.4f}, Val Acc: {val_accuracy:.4f}")

    # track best validation metrics
    if val_loss < best_val_loss:
        best_val_loss = val_loss
        best_val_loss_epoch = epoch + 1
    if val_accuracy > best_val_accuracy:
        best_val_accuracy = val_accuracy
        best_val_accuracy_epoch = epoch + 1
        # Snapshot the weights. state_dict() returns references to the live
        # tensors, so without cloning the stored "best" state would keep
        # changing as training continues and end up equal to the last epoch.
        best_model_state = {
            name: tensor.detach().clone()
            for name, tensor in model.state_dict().items()
        }

    # plot metrics per epoch (live update)
    _plot_training_history(
        train_losses, val_losses, train_accuracies, val_accuracies,
        best_val_loss, best_val_loss_epoch,
        best_val_accuracy, best_val_accuracy_epoch,
        num_epochs,
    )

# after training, load the best model and evaluate it on the test set
if best_model_state is not None:  # stays None when no epoch improved on 0.0 accuracy (e.g. num_epochs == 0)
    model.load_state_dict(best_model_state)
test_loss, test_accuracy = test(model, test_loader, criterion)  # evaluate on test set
print(f'Test Loss: {test_loss:.4f}, Test Accuracy: {test_accuracy:.4f}')

# Visual inspection
Visual inspection of the incorrect predictions
# Collect every misclassified test sample with the trained model, show all of
# them in one image grid, then list the true/predicted label of each.
model.eval()
error_images, error_pred_labels, error_true_labels = [], [], []
with torch.no_grad():
    for images, labels in test_loader:
        images, labels = images.to(device), labels.to(device)
        # the model emits one 10-d vector per sample; the prediction is its arg-max
        preds = model(images).argmax(dim=1)
        wrong = preds.ne(labels)
        if not wrong.any():
            continue
        # keep only the misclassified samples of this batch, moved to the CPU
        error_images.append(images[wrong].cpu())
        error_pred_labels.append(preds[wrong].cpu())
        error_true_labels.append(labels[wrong].cpu())

if not error_images:
    print('No errors found on test set.')
else:
    error_images = torch.cat(error_images)
    error_pred_labels = torch.cat(error_pred_labels)
    error_true_labels = torch.cat(error_true_labels)
    n_errors = len(error_images)
    print(f'NUMBERS OF ERRORS: {n_errors}')

    # undo the normalisation for display (x_vis = x_norm * std + mean),
    # accepting the statistics either as tensors or as plain numbers
    vis_mean, vis_std = (
        float(stat) if torch.is_tensor(stat) else stat
        for stat in (mean, std)
    )
    vis_images = error_images.mul(vis_std).add(vis_mean).clamp(0.0, 1.0)

    # ten images per row; the number of rows follows from the error count
    n_cols = 10
    n_rows = (n_errors + n_cols - 1) // n_cols
    image_grid = torchvision.utils.make_grid(vis_images, nrow=n_cols)
    npimg = image_grid.numpy()

    plt.figure(figsize=(12, max(4, n_rows * 1.2)))
    # only the first channel of the grid is drawn, as a grayscale image
    plt.imshow(npimg[0], cmap='gray', vmin=0.0, vmax=1.0)
    plt.title(f'Errors (all {n_errors} samples, {n_rows} rows x {n_cols} cols)')
    plt.axis('off')
    plt.show()

    # list the labels in the same order as the samples were collected
    for position, (true_label, pred_label) in enumerate(zip(error_true_labels, error_pred_labels), start=1):
        print(f'{position}: true={int(true_label)}, pred={int(pred_label)}')