Execution environment

The code below requires the dependencies used in the import block (torch, torchvision, matplotlib, numpy). A practical option is to run the notebook in Google Colab with a GPU runtime enabled.

Import the required packages

Insert all the packages you require.

# import required packages for a training/testing experiment with MNIST image classification using only standard MLPs
 
import os
import subprocess
import sys
import time

import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
 
print("PyTorch version:", torch.__version__)

Hyperparameters and Options

Set the hyperparameters.

# Single place where every hyperparameter / option of the experiment is declared.
hyperparameters = {
    # --- optimizer (SGD) ---
    'learning_rate': 0.01,
    'momentum': 0.5,
    'nesterov': True,
    'batch_size': 10,
    # --- regularization ---
    'label_smoothing': 0,
    'dropout_rate': 0.1,
    'use_batch_norm': False,
    # --- training schedule ---
    'num_epochs': 30,
    'validation_split': 0.2,
    'use_lr_scheduler': True,
    'lr_scheduler_step_size': 15,
    'lr_scheduler_gamma': 0.5,
    'early_stopping_patience': 10,
    # --- data pipeline ---
    'use_data_augmentation': True,
    'augmentation_probability': 0.5,
    'num_workers': 12,
    'use_persistent_workers': True,
    # --- reproducibility ---
    'random_seed': 42,
}

# Seed both random number generators used in this notebook with the same value.
seed = hyperparameters['random_seed']
torch.manual_seed(seed)
np.random.seed(seed)

# Prefer the GPU when one is visible to PyTorch, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

print(f'Using device: {device}')
print(f'Random seed: {seed}')

Model architecture

Define the network

  • init: initialization of the fully connected (fc) layers of the network
  • forward: phase where the input is propagated to the next levels, it provides the output of the network
class MLP(nn.Module):
    """Fully connected classifier: 784 -> 800 -> 400 -> 10.

    Each hidden layer is followed by an optional BatchNorm1d, a GELU
    activation and dropout; the last layer returns raw logits.

    Args:
        dropout_rate: probability passed to ``nn.Dropout``.
        use_batch_norm: when True, BatchNorm1d is inserted after each hidden
            linear layer; otherwise an ``nn.Identity`` placeholder is used.
    """

    def __init__(self, dropout_rate=0.2, use_batch_norm=False):
        super().__init__()

        def make_norm(num_features):
            # Identity keeps forward() free of conditionals when BN is off.
            if use_batch_norm:
                return nn.BatchNorm1d(num_features)
            return nn.Identity()

        # Registration order matters: it fixes the state_dict key order and
        # the order in which the random initializers consume the RNG.
        self.fc1 = nn.Linear(784, 800)
        self.bn1 = make_norm(800)
        self.fc2 = nn.Linear(800, 400)
        self.bn2 = make_norm(400)
        self.fc3 = nn.Linear(400, 10)
        self.act = nn.GELU()
        self.dropout = nn.Dropout(dropout_rate)
        self.apply(self._initialize_weights)

    def _initialize_weights(self, module):
        """Xavier-uniform weights and zero biases for every linear layer."""
        if not isinstance(module, nn.Linear):
            return
        nn.init.xavier_uniform_(module.weight)
        if module.bias is not None:
            nn.init.zeros_(module.bias)

    def forward(self, x):
        """Flatten the input to rows of 784 values and return the 10 logits."""
        hidden = x.view(-1, 784)
        hidden_stages = ((self.fc1, self.bn1), (self.fc2, self.bn2))
        for linear, norm in hidden_stages:
            hidden = self.dropout(self.act(norm(linear(hidden))))
        return self.fc3(hidden)

Building blocks for training

Instance of the network:

  • Loss Function: to guide learning, quantify how close the actual output is to the desired output
  • Optimizer: weight update method (stochastic gradient descent)
  • Learning Rate Scheduler: adjust the learning rate based on the number of epochs
# Instantiate the network, the cross-entropy loss, the SGD optimizer and the
# optional step-wise learning-rate scheduler from the hyperparameters.
model = MLP(
    dropout_rate=hyperparameters.get('dropout_rate', 0.2),
    use_batch_norm=hyperparameters.get('use_batch_norm', False),
)
model = model.to(device)

criterion = nn.CrossEntropyLoss(
    label_smoothing=hyperparameters.get('label_smoothing', 0.1),
)
criterion = criterion.to(device)

optimizer = optim.SGD(
    model.parameters(),
    lr=hyperparameters['learning_rate'],
    momentum=hyperparameters['momentum'],
    nesterov=hyperparameters['nesterov'],
)

# The scheduler stays None when it is switched off, so later code can test for it.
if hyperparameters.get('use_lr_scheduler', False):
    scheduler = optim.lr_scheduler.StepLR(
        optimizer,
        step_size=hyperparameters['lr_scheduler_step_size'],
        gamma=hyperparameters['lr_scheduler_gamma'],
    )
else:
    scheduler = None

Model Summary and Trainable Parameters

Use summary to inspect model structure and count trainable parameters.

# Import torchinfo, installing it on first use when it is missing.
try:
    from torchinfo import summary
except ImportError:
    # Run pip through the interpreter that executes this code: unlike the
    # IPython-only `!pip` shell escape this is plain Python and always
    # installs into the environment the import below will look in.
    subprocess.check_call([sys.executable, '-m', 'pip', 'install', '-q', 'torchinfo'])
    from torchinfo import summary

# Summarize the model for a single 1x28x28 input placed on the active device.
model_summary = summary(
    model,
    input_size=(1, 1, 28, 28),
    device=str(device),
    depth=3,
    col_names=["input_size", "output_size", "num_params", "trainable"],
    verbose=2,
)

print(f"Total parameters: {model_summary.total_params:,}")
print(f"Trainable parameters: {model_summary.trainable_params:,}")

Create datasets

Split training / validation set.

Strategy: split MNIST training set into 80% train and 20% validation (controlled by validation_split), and keep the full MNIST test set for testing.

# Fetch the official MNIST training and test sets (downloaded on first use).
mnist_train = datasets.MNIST(root='./data', train=True, download=True)
mnist_test = datasets.MNIST(root='./data', train=False, download=True)

# Shuffle the training indices once with a seeded generator and cut the
# shuffled list into a train part and a validation part.
validation_split = hyperparameters.get('validation_split', 0.2)
num_official_train = len(mnist_train)
train_size = int((1.0 - validation_split) * num_official_train)
val_size = num_official_train - train_size

split_generator = torch.Generator()
split_generator.manual_seed(hyperparameters['random_seed'])
indices = torch.randperm(num_official_train, generator=split_generator).tolist()
train_indices, val_indices = indices[:train_size], indices[train_size:]

# Train and validation wrap two distinct MNIST objects so that each split can
# be given its own transform later on.
train_dataset = torch.utils.data.Subset(
    datasets.MNIST(root='./data', train=True, download=False),
    train_indices,
)
val_dataset = torch.utils.data.Subset(
    datasets.MNIST(root='./data', train=True, download=False),
    val_indices,
)
test_dataset = mnist_test
 
# Compute the pixel mean and standard deviation on the train split only.
#
# The statistics are accumulated as a running sum and sum of squares over all
# pixels. Averaging per-batch means/stds (the previous approach) gives a std
# that depends on the batch size and over-weights a short last batch.
to_tensor = transforms.ToTensor()
# Never start more workers than there are CPUs on this machine.
stats_num_workers = max(0, min(hyperparameters.get('num_workers', 2), os.cpu_count() or 1))
stats_dataset = datasets.MNIST(root='./data', train=True, download=False, transform=to_tensor)
stats_subset = torch.utils.data.Subset(stats_dataset, train_indices)
# The result no longer depends on the batch size, so use large batches to
# keep the loader overhead low.
stats_batch_size = 1024
stats_loader = DataLoader(
    stats_subset,
    batch_size=stats_batch_size,
    shuffle=False,
    num_workers=stats_num_workers,
)

pixel_sum = 0.0
pixel_squared_sum = 0.0
num_pixels = 0
for images, _ in stats_loader:
    images = images.double()  # float64 keeps the large sums accurate
    pixel_sum += images.sum().item()
    pixel_squared_sum += images.pow(2).sum().item()
    num_pixels += images.numel()

if num_pixels == 0:
    raise ValueError('Cannot compute normalization statistics: the train split is empty.')

# mean and std are plain Python floats (population statistics over all pixels)
mean = pixel_sum / num_pixels
# Var[x] = E[x^2] - E[x]^2; clamp at 0 to absorb rounding error
std = max(pixel_squared_sum / num_pixels - mean ** 2, 0.0) ** 0.5
print(f'Mean: {mean}, Std: {std}')
print(f'Train size: {train_size}, Validation size: {val_size}, Test size: {len(test_dataset)}')

Data Check

The most recurring problem in Machine Learning and Deep Learning is that the model is fed with wrong data, caused by incorrect data loading, processing, etc.

A good practice is to calculate and print statistics.

# helper used for the visual checks of dataset samples
def imshow(img):
    """Display ``img`` (anything ``np.array`` accepts) as a grayscale image."""
    pixels = np.array(img)
    plt.imshow(pixels, cmap='gray')
    plt.show()
 
# Look at the first five entries of the training split: print the type, size,
# mode and label of each sample, then draw it.
for position in range(5):
    img, label = train_dataset[position]
    print(f'Sample {position+1}: Image type: {type(img)}, Size: {img.size}, Mode: {img.mode}, Label: {label}')
    imshow(img)

Data transforms

Data transforms are applied at batch generation time.

They serve to transform your data into what the neural network expects.

Data should be converted to tensors whose shape corresponds to the network input and possibly normalized so as to be roughly zero-centered, with values in a small range around 0 (e.g. about [-1, 1]).

# Create the transforms (PIL image -> torch tensor -> normalized tensor) and,
# separately, a more complex data augmentation pipeline for MNIST.
augmentation_probability = hyperparameters.get('augmentation_probability', 0.5)

# One fixed random permutation of the 28*28 pixel positions, drawn from a
# seeded generator so that the same permutation is used for all samples.
permutation_generator = torch.Generator()
permutation_generator.manual_seed(hyperparameters['random_seed'])
fixed_pixel_permutation = torch.randperm(28 * 28, generator=permutation_generator)
 
def apply_fixed_pixel_permutation(img_tensor):
    """Reorder the pixels of an image with the fixed permutation.

    Args:
        img_tensor: tensor of shape (channels, height, width); height * width
            must match the length of ``fixed_pixel_permutation``.

    Returns:
        A tensor of the same shape in which every channel has had its
        flattened pixels reordered by ``fixed_pixel_permutation``.
    """
    channels, height, width = img_tensor.shape
    # Flatten each channel separately so the same permutation is applied per
    # channel; for a single channel this equals permuting the flat image.
    flat_channels = img_tensor.reshape(channels, height * width)
    return flat_channels[:, fixed_pixel_permutation].reshape(channels, height, width)
 
# Tensor conversion followed by normalization with the statistics computed above.
# (To enable the fixed pixel permutation, insert
#  transforms.Lambda(apply_fixed_pixel_permutation) between the two steps.)
base_train_transforms = [
    transforms.ToTensor(),
    transforms.Normalize((mean,), (std,)),
]

# MNIST-friendly augmentation: small geometric perturbations, each applied
# with the configured probability.
random_affine = transforms.RandomAffine(
    degrees=12,
    translate=(0.10, 0.10),
    scale=(0.95, 1.05),
    shear=8,
)
random_rotation = transforms.RandomRotation(10)
augmentation_transforms = [
    transforms.RandomApply([random_affine], p=augmentation_probability),
    transforms.RandomApply([random_rotation], p=augmentation_probability),
]

# The training pipeline optionally prepends the augmentations to the base steps.
if hyperparameters.get('use_data_augmentation', False):
    train_steps = augmentation_transforms + base_train_transforms
else:
    train_steps = list(base_train_transforms)
train_transform = transforms.Compose(train_steps)

# Validation and test samples are only converted and normalized.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((mean,), (std,)),
])

# Attach the transforms: the two subsets expose their MNIST object through
# `.dataset`, the test set is an MNIST object itself. Labels stay untouched.
for subset, image_transform in ((train_dataset, train_transform), (val_dataset, transform)):
    subset.dataset.transform = image_transform
    subset.dataset.target_transform = None
test_dataset.transform = transform
test_dataset.target_transform = None

Data loaders

Dataloaders are in-built PyTorch objects that serve to sample batches from datasets.

# Prepare the data loaders for training, validation, and testing.
#
# The requested worker count is capped at the number of CPUs of this machine:
# three loaders with persistent workers keep 3 * num_workers processes alive,
# which oversubscribes small machines when the hyperparameter is set high.
requested_num_workers = hyperparameters.get('num_workers', 2)
loader_num_workers = max(0, min(requested_num_workers, os.cpu_count() or 1))
# persistent_workers is only valid when worker processes are actually used
use_persistent_workers = hyperparameters.get('use_persistent_workers', True) and loader_num_workers > 0
pin_memory = torch.cuda.is_available()

# options shared by all three loaders
common_loader_options = {
    'num_workers': loader_num_workers,
    'pin_memory': pin_memory,
    'persistent_workers': use_persistent_workers,
}

# training batches are reshuffled every epoch
train_loader = DataLoader(
    train_dataset,
    batch_size=hyperparameters['batch_size'],
    shuffle=True,
    **common_loader_options,
)
# evaluation uses large, unshuffled batches
val_loader = DataLoader(
    val_dataset,
    batch_size=512,
    shuffle=False,
    **common_loader_options,
)
test_loader = DataLoader(
    test_dataset,
    batch_size=512,
    shuffle=False,
    **common_loader_options,
)

Data check after transforms

To check what the network will see at train/test time, read the samples through the dataset (or a dataloader built on it), so that the data transforms previously defined are applied.

# Draw the first five training samples as they come out of the dataset now
# that the transforms are attached.
for position in range(5):
    img, label = train_dataset[position]
    # squeeze the tensor down to 2D and convert it to a numpy array for plotting
    img = img.squeeze().numpy()
    imshow(img)

Train function

It is preferable (but not mandatory) to embed training (1 epoch) code into a function, and call that function later during the training phase, at each epoch.

# training loop function: trains the model for one epoch and returns average loss and accuracy
def train(model, train_loader, criterion, optimizer):
    """Train ``model`` for one pass over ``train_loader``.

    Args:
        model: network to optimize; batches are moved to the device of its
            first parameter (CPU when the model has no parameters).
        train_loader: iterable yielding ``(images, labels)`` batches.
        criterion: loss called as ``criterion(outputs, labels)``; its value
            is treated as the mean loss of the batch.
        optimizer: optimizer stepped once per batch.

    Returns:
        ``(average_loss, accuracy)`` over the samples seen in this pass, or
        ``(nan, nan)`` when the loader yields no samples.
    """
    model.train()  # set the model to training mode

    # Follow the model's own device instead of relying on a notebook global.
    first_parameter = next(model.parameters(), None)
    batch_device = first_parameter.device if first_parameter is not None else torch.device('cpu')

    running_loss = 0.0
    correct = 0
    total = 0
    for images, labels in train_loader:
        images = images.to(batch_device)
        labels = labels.to(batch_device)

        optimizer.zero_grad()  # zero the parameter gradients
        outputs = model(images)  # forward pass
        loss = criterion(outputs, labels)  # calculate loss
        loss.backward()  # backward pass
        optimizer.step()  # update weights

        num_samples = labels.size(0)
        # weight the batch loss by the batch size so short batches count less
        running_loss += loss.item() * num_samples
        total += num_samples
        # the predicted class is the index of the largest output
        correct += (outputs.argmax(dim=1) == labels).sum().item()

    if total == 0:
        # nothing was processed: report "undefined" instead of dividing by zero
        return float('nan'), float('nan')

    # divide by the samples actually seen in this pass
    average_loss = running_loss / total
    accuracy = correct / total
    return average_loss, accuracy

Test function

It is preferable (but not mandatory) to embed the test code into a function, and call that function whenever needed.

For instance, during training for validation at each epoch, or after training for testing, or for deploying the model.

# test function that evaluates a given model on a given dataloader and returns the average loss and
def test(model, test_loader, criterion):
    model.eval()  # set the model to evaluation mode
    running_loss = 0.0
    correct = 0
    total = 0
    with torch.no_grad():  # disable gradient calculation
        for images, labels in test_loader:
            images = images.to(device)
            labels = labels.to(device)
            outputs = model(images)  # forward pass
            loss = criterion(outputs, labels)  # calculate loss
            running_loss += loss.item() * images.size(0)  # accumulate loss
            _, predicted = torch.max(outputs.data, 1)  # get the index of the max log-probability
            total += labels.size(0)  # update total count
            correct += (predicted == labels).sum().item()
    average_loss = running_loss / len(test_loader.dataset)  # calculate average loss
    accuracy = correct / total  # calculate accuracy
    return average_loss, accuracy

Train Model / Test Model

The code below also includes visual loss/accuracy monitoring during training, both on training and validation sets.

# Run the training experiment: train for one epoch at a time, evaluate on the
# validation set after each epoch, redraw the loss/accuracy plots, and keep a
# copy of the weights that achieved the best validation accuracy.
train_losses = []
val_losses = []
train_accuracies = []
val_accuracies = []
best_val_accuracy = 0.0
best_val_accuracy_epoch = 0
best_val_loss = float('inf')
best_val_loss_epoch = 0
best_model_state = None

num_epochs = hyperparameters['num_epochs']
# A missing or non-positive patience disables early stopping.
early_stopping_patience = hyperparameters.get('early_stopping_patience') or 0
epochs_without_improvement = 0
# Tick spacing on the epoch axis grows with the number of epochs (2 for 30 epochs).
epoch_tick_step = max(1, num_epochs // 15)

for epoch in range(num_epochs):
    epoch_start = time.time()
    train_loss, train_accuracy = train(model, train_loader, criterion, optimizer)  # train for one epoch
    val_loss, val_accuracy = test(model, val_loader, criterion)  # evaluate on validation set
    train_losses.append(train_loss)
    val_losses.append(val_loss)
    train_accuracies.append(train_accuracy)
    val_accuracies.append(val_accuracy)

    # Read the learning rate used during this epoch before the scheduler
    # advances it for the next one.
    current_lr = optimizer.param_groups[0]['lr']
    if scheduler is not None:
        scheduler.step()
    epoch_time = time.time() - epoch_start

    print(f"Epoch {epoch+1}/{num_epochs}, Time: {epoch_time:.2f}s, LR: {current_lr:.6f}, Train Loss: {train_loss:.4f}, Val Loss: {val_loss:.4f}, Train Acc: {train_accuracy:.4f}, Val Acc: {val_accuracy:.4f}")

    # track best validation metrics
    if val_loss < best_val_loss:
        best_val_loss = val_loss
        best_val_loss_epoch = epoch + 1
    if val_accuracy > best_val_accuracy:
        best_val_accuracy = val_accuracy
        best_val_accuracy_epoch = epoch + 1
        # state_dict() returns references to the live tensors, so clone them;
        # otherwise the snapshot would keep changing as training continues.
        best_model_state = {
            name: tensor.detach().clone()
            for name, tensor in model.state_dict().items()
        }
        epochs_without_improvement = 0
    else:
        epochs_without_improvement += 1

    # plot metrics per epoch (live update)
    plt.figure(figsize=(12, 4))
    epochs = list(range(1, len(train_losses) + 1))
    plt.subplot(1, 2, 1)
    plt.plot(epochs, train_losses, label='Train Loss', marker='o')
    plt.plot(epochs, val_losses, label=f'Val Loss (best {best_val_loss:.4f} @ {best_val_loss_epoch})', marker='o')
    plt.title('Loss per epoch')
    plt.xlabel('Epoch')
    plt.ylabel('Loss')
    plt.yscale('log')
    # the epoch axis follows the configured number of epochs
    plt.xlim(0, num_epochs)
    plt.xticks(range(0, num_epochs + 1, epoch_tick_step))
    plt.legend()

    plt.subplot(1, 2, 2)
    train_accuracies_pct = [x * 100 for x in train_accuracies]
    val_accuracies_pct = [x * 100 for x in val_accuracies]
    plt.plot(epochs, train_accuracies_pct, label='Train Accuracy (%)', marker='o')
    plt.plot(epochs, val_accuracies_pct, label=f'Val Accuracy (%) (best {best_val_accuracy*100:.2f}% @ {best_val_accuracy_epoch})', marker='o')
    plt.title('Accuracy per epoch')
    plt.xlabel('Epoch')
    plt.ylabel('Accuracy (%)')
    plt.xlim(0, num_epochs)
    plt.xticks(range(0, num_epochs + 1, epoch_tick_step))
    plt.yticks(range(95, 101, 1))
    plt.ylim(95, 100)
    plt.grid(axis='y', linestyle='--', alpha=0.5)
    plt.legend()
    plt.tight_layout()
    plt.show()

    # Early stopping uses the same criterion as the checkpoint (validation accuracy).
    if early_stopping_patience > 0 and epochs_without_improvement >= early_stopping_patience:
        print(f'Early stopping at epoch {epoch+1}: no validation accuracy improvement for {epochs_without_improvement} epochs.')
        break
 
# After training, load the best recorded model state and evaluate it on the test set.
if best_model_state is not None:
    model.load_state_dict(best_model_state)  # load the best model state
else:
    # No state was recorded (e.g. zero epochs were run): keep the current weights.
    print('No best model state recorded; evaluating the current model weights.')
test_loss, test_accuracy = test(model, test_loader, criterion)  # evaluate on test set
print(f'Test Loss: {test_loss:.4f}, Test Accuracy: {test_accuracy:.4f}')

Visual inspection

Visual inspection of the incorrect predictions

# Visual inspection of incorrect predictions: collect every misclassified test
# sample, draw all of them in one grid and list the true/predicted labels.
model.eval()
error_images = []
error_pred_labels = []
error_true_labels = []
with torch.no_grad():
    for images, labels in test_loader:
        images, labels = images.to(device), labels.to(device)
        outputs = model(images)

        # one 10-d output vector per sample; the prediction is its arg-max
        preds = outputs.argmax(dim=1)
        trues = labels

        mismatch = preds.ne(trues)
        if not mismatch.any():
            continue
        # keep only the wrong samples, moved back to the CPU
        error_images.append(images[mismatch].cpu())
        error_pred_labels.append(preds[mismatch].cpu())
        error_true_labels.append(trues[mismatch].cpu())

if not error_images:
    print('No errors found on test set.')
else:
    error_images = torch.cat(error_images)
    error_pred_labels = torch.cat(error_pred_labels)
    error_true_labels = torch.cat(error_true_labels)
    n_errors = len(error_images)
    print(f'NUMBERS OF ERRORS: {n_errors}')

    # undo the normalization for display: x_vis = x_norm * std + mean
    if torch.is_tensor(mean):
        vis_mean = float(mean)
    else:
        vis_mean = mean
    if torch.is_tensor(std):
        vis_std = float(std)
    else:
        vis_std = std
    vis_images = (error_images * vis_std + vis_mean).clamp(0.0, 1.0)

    # ten images per row; the number of rows is the ceiling of n_errors / n_cols
    n_cols = 10
    n_rows = -(-n_errors // n_cols)
    image_grid = torchvision.utils.make_grid(vis_images, nrow=n_cols)
    npimg = image_grid.numpy()

    plt.figure(figsize=(12, max(4, n_rows * 1.2)))
    # only the first channel of the grid is drawn, as a grayscale image
    plt.imshow(npimg[0], cmap='gray', vmin=0.0, vmax=1.0)
    plt.title(f'Errors (all {n_errors} samples, {n_rows} rows x {n_cols} cols)')
    plt.axis('off')
    plt.show()

    # list the true and predicted label of every misclassified sample
    for position, (true_label, pred_label) in enumerate(zip(error_true_labels, error_pred_labels), start=1):
        print(f'{position}: true={int(true_label)}, pred={int(pred_label)}')