Overview

This note develops a complete image-classification experiment on MNIST using a standard multilayer perceptron. The goal is to turn the abstract components of neural-network training into an executable pipeline: loading handwritten digit images, preparing batches, defining a fully connected architecture, selecting a loss function and optimizer, training the model, monitoring validation performance, and evaluating how well the network assigns each image to one of the ten digit classes. Although MNIST is intentionally simple, it provides a clean setting for understanding the mechanics of supervised learning before moving to convolutional architectures and larger vision datasets.

Execution environment

The code below requires the dependencies used in the import block (torch, torchvision, matplotlib, numpy). A practical option is to run the notebook in Google Colab with a GPU runtime enabled.

Import the required packages

All the packages used in the notebook are imported here.

# import the packages required for a training/testing experiment on MNIST image classification using only standard MLPs
import torch  
import torch.nn as nn  
import torch.optim as optim  
from torch.utils.data import DataLoader  
from torchvision import datasets, transforms  
import matplotlib.pyplot as plt  
import torchvision  
import numpy as np

Hyperparameters and Options

Experiment with hyperparameters

The values below should be treated as an experimental starting point, not as a fixed recipe. To understand the effect of each choice, it is usually best to change one hyperparameter at a time, rerun the experiment, and compare the loss curves, validation accuracy, and final test performance. Once the role of the main parameters is clear, combinations of changes can be explored more deliberately.
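The one-change-at-a-time approach can be sketched as a small helper that derives experiment configurations from a base dictionary. The function name `one_factor_variants` and the candidate values are hypothetical, chosen only for illustration; each resulting dictionary would replace `hyperparameters` for one run.

```python
import copy

def one_factor_variants(base, candidates):
    """Return (name, config) pairs that each change exactly one hyperparameter of `base`."""
    variants = [('baseline', copy.deepcopy(base))]
    for key, values in candidates.items():
        for value in values:
            if value == base[key]:
                continue  # skip the value already used by the baseline
            config = copy.deepcopy(base)
            config[key] = value
            variants.append((f'{key}={value}', config))
    return variants

# hypothetical base configuration and alternative values to compare
base = {'learning_rate': 2, 'momentum': 0.2, 'batch_size': 10}
candidates = {'learning_rate': [0.5, 2, 5], 'batch_size': [10, 64]}

for name, config in one_factor_variants(base, candidates):
    print(name, config)
```

Because every variant differs from the baseline in a single key, any change in the loss curves or accuracies can be attributed to that key alone.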

# dictionary of all hyperparameters used in the experiment: learning rate, momentum, batch size, etc.
hyperparameters = {
    'learning_rate': 2,
    'momentum': 0.2,
    'nesterov': True,
    'batch_size': 10,
    'num_epochs': 30,
    'use_lr_scheduler': True,
    'lr_scheduler_step_size': 10,
    'lr_scheduler_gamma': 0.1
}
 
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print('Using device:', device)

Model architecture

Define the network

  • __init__: creates the fully connected (fc) layers of the network
  • forward: propagates the input through the layers and returns the network output

Architecture as a design choice

The number of hidden layers and the number of neurons in each layer are not fixed rules. The architecture below is only a reasonable starting point for the experiment. Different depths and widths should be tested and compared through training behavior, validation performance, and overfitting tendencies.
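Depths and widths are easier to compare if the architecture is built from a list of hidden-layer sizes. The helper `make_mlp` below is a hypothetical sketch, not part of the original notebook; with `hidden_sizes=[1000, 1000, 200]` it reproduces the layer sizes of the `MLP` class, using `nn.Sequential` instead of a custom `forward`.

```python
import torch
import torch.nn as nn

def make_mlp(hidden_sizes, in_features=784, num_classes=10):
    """Build a sigmoid MLP with the given hidden-layer widths (hypothetical helper)."""
    layers = [nn.Flatten()]  # [batch, 1, 28, 28] -> [batch, 784]
    prev = in_features
    for width in hidden_sizes:
        layers += [nn.Linear(prev, width), nn.Sigmoid()]
        prev = width
    # output layer followed by a sigmoid, matching the MLP class and the MSE/one-hot setup
    layers += [nn.Linear(prev, num_classes), nn.Sigmoid()]
    return nn.Sequential(*layers)

def count_parameters(model):
    """Total number of weights and biases in the model."""
    return sum(p.numel() for p in model.parameters())

# compare a few candidate depths/widths by size and output shape
for hidden in ([1000, 1000, 200], [512], [256, 128]):
    net = make_mlp(hidden)
    out = net(torch.zeros(4, 1, 28, 28))
    print(hidden, 'parameters:', count_parameters(net), 'output shape:', tuple(out.shape))
```

The parameter count is a quick proxy for capacity: the original architecture has close to two million parameters, so much smaller networks are worth testing when the validation curves suggest overfitting.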

# define here a simple MLP architecture with sigmoid activation
class MLP(nn.Module):
    def __init__(self):
        super(MLP, self).__init__()
        self.fc1 = nn.Linear(784, 1000)
        self.fc2 = nn.Linear(1000, 1000)
        self.fc3 = nn.Linear(1000, 200)
        self.fc4 = nn.Linear(200, 10)
        self.sigmoid = nn.Sigmoid()
 
    def forward(self, x):
        x = x.view(-1, 784)  # flatten the input
        x = self.sigmoid(self.fc1(x))
        x = self.sigmoid(self.fc2(x))
        x = self.sigmoid(self.fc3(x))
        x = self.sigmoid(self.fc4(x))
        return x

Building blocks for training

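The cell below creates the building blocks needed for training:

  • Network instance: the MLP defined above, moved to the selected device
  • Loss function: guides learning by quantifying how far the network output is from the desired output; here the mean squared error (MSE) between the sigmoid outputs and the one-hot targets
  • Optimizer: the weight-update method, stochastic gradient descent (SGD) with Nesterov momentum
  • Learning rate scheduler: adjusts the learning rate as the epochs progress; StepLR multiplies it by gamma every step_size epochs
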
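A small pure-Python sketch can make two of these components concrete. `mse` reproduces the default mean reduction of `nn.MSELoss` (squared errors averaged over every element), and `step_lr` reproduces the StepLR rule lr = lr0 * gamma ** (epoch // step_size) with the hyperparameter values above. Both function names are hypothetical helpers for illustration only.

```python
def mse(output, target):
    """Mean of squared differences over all elements (nn.MSELoss's default 'mean' reduction)."""
    return sum((o - t) ** 2 for o, t in zip(output, target)) / len(output)

def step_lr(lr0, epoch, step_size, gamma):
    """Learning rate used in a given (0-based) epoch under a StepLR schedule."""
    return lr0 * gamma ** (epoch // step_size)

# sigmoid-like outputs for a 3-class toy example vs. the one-hot target for class 1
print(mse([0.1, 0.8, 0.1], [0.0, 1.0, 0.0]))  # about 0.02

# learning rate per epoch with learning_rate=2, step_size=10, gamma=0.1 over 30 epochs
schedule = [step_lr(2, e, 10, 0.1) for e in range(30)]
print(schedule[0], schedule[10], schedule[20])
```

Since `scheduler.step()` is called once per epoch in the training loop, the learning rate stays at 2 for epochs 1 to 10, then drops to 0.2 and finally to 0.02.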
# instantiate the model, the MSE loss, the SGD optimizer (Nesterov momentum) and, optionally, a StepLR scheduler
model = MLP().to(device)
criterion = nn.MSELoss()  # no learnable parameters, so no .to(device) is needed
optimizer = optim.SGD(model.parameters(), lr=hyperparameters['learning_rate'],
                      momentum=hyperparameters['momentum'], nesterov=hyperparameters['nesterov'])
scheduler = None
if hyperparameters.get('use_lr_scheduler', False):
    # StepLR multiplies the learning rate by gamma every step_size epochs
    scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=hyperparameters['lr_scheduler_step_size'],
                                          gamma=hyperparameters['lr_scheduler_gamma'])

Create datasets

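Split into training, validation and test sets.

Note

MNIST provides only an official training split (60,000 images) and an official test split (10,000 images), with no separate validation split. Here the full training split is used for training, and the official test split is randomly divided in half: 5,000 images for validation and 5,000 for the final test.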
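A random split like this is reproducible only if the random generator is seeded. The sketch below uses a hypothetical `split_half` helper, with `range(10000)` standing in for the official test split, to show that `random_split` with a seeded `torch.Generator` returns two disjoint halves, and the same halves every time the same seed is used.

```python
import torch
from torch.utils.data import random_split

# stand-in for the 10,000-image MNIST test split: any indexable object with a length works
official_test = range(10000)

def split_half(dataset, seed):
    """Randomly split `dataset` into two halves, reproducibly for a given seed."""
    val_size = len(dataset) // 2
    test_size = len(dataset) - val_size
    generator = torch.Generator().manual_seed(seed)
    return random_split(dataset, [val_size, test_size], generator=generator)

val_a, test_a = split_half(official_test, seed=0)
val_b, _ = split_half(official_test, seed=0)

print(len(val_a), len(test_a))                        # sizes of the two subsets
print(set(val_a.indices).isdisjoint(test_a.indices))  # no sample is in both subsets
print(list(val_a.indices) == list(val_b.indices))     # same seed, same split
```

Without a seed, every run of the notebook would validate and test on different images, which makes results across hyperparameter experiments harder to compare.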

# download MNIST training and test datasets separately
mnist_train = datasets.MNIST(root='./data', train=True, download=True)
mnist_test = datasets.MNIST(root='./data', train=False, download=True)
train_dataset = mnist_train
 
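# use the full MNIST training split as training set; split the official test split into validation/test halves
val_size = int(0.5 * len(mnist_test))
test_size = len(mnist_test) - val_size
# a seeded generator makes the split reproducible across runs (the seed 42 is an arbitrary choice)
val_dataset, test_dataset = torch.utils.data.random_split(
    mnist_test, [val_size, test_size], generator=torch.Generator().manual_seed(42))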
 
# compute the mean and standard deviation of the training pixels for normalization;
# ToTensor scales uint8 pixels to [0, 1], so the raw uint8 data is scaled the same way here.
# The statistics are taken over all pixels at once: averaging per-batch standard deviations
# would not give the dataset standard deviation.
train_pixels = mnist_train.data.float() / 255.0  # shape [60000, 28, 28]
mean = train_pixels.mean().item()
std = train_pixels.std().item()
print(f'Mean: {mean:.4f}, Std: {std:.4f}')
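When a dataset is too large to hold in memory as one tensor, the same statistics can be accumulated batch by batch from running sums of the values and of their squares. Averaging per-batch standard deviations instead does not, in general, give the dataset standard deviation, because it ignores how the batch means differ from the global mean. The `running_stats` helper below is a hypothetical NumPy sketch of the pooled computation; it returns the population standard deviation (ddof=0), which for the roughly 47 million MNIST training pixels is practically identical to the unbiased estimate returned by `Tensor.std()`.

```python
import numpy as np

def running_stats(batches):
    """Pooled mean and population std (ddof=0) over all values in an iterable of arrays."""
    count, total, total_sq = 0, 0.0, 0.0
    for batch in batches:
        batch = np.asarray(batch, dtype=np.float64)
        count += batch.size
        total += batch.sum()
        total_sq += np.square(batch).sum()
    mean = total / count
    var = max(total_sq / count - mean ** 2, 0.0)  # E[x^2] - E[x]^2, clipped against rounding
    return mean, np.sqrt(var)

# two batches whose means differ: each batch on its own has zero spread
batches = [np.zeros(100), np.ones(100)]
mean, std = running_stats(batches)
naive_std = np.mean([b.std() for b in batches])  # average of per-batch stds
print(mean, std, naive_std)  # pooled mean 0.5 and std 0.5, but the averaged per-batch std is 0.0

# on random data with unequal batch sizes the pooled result matches the global computation
rng = np.random.default_rng(0)
data = rng.random(10_000)
pooled_mean, pooled_std = running_stats(np.split(data, [3000, 7000, 9990]))
print(abs(pooled_mean - data.mean()), abs(pooled_std - data.std()))
```

The E[x^2] - E[x]^2 formula is adequate for pixel values in [0, 1]; when the mean is large compared with the spread, Welford's online algorithm is the numerically safer alternative.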

Data Check

The most frequent problem in machine learning and deep learning is a model fed with wrong data, caused by incorrect data loading, preprocessing, or labeling.

A good practice is to compute and print basic statistics and to look at a few samples before training.
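Besides looking at images, it helps to check that the labels cover all ten digits in plausible proportions. The `label_distribution` helper below is a hypothetical sketch; with torchvision's MNIST dataset the labels are available as the `targets` tensor, so `label_distribution(mnist_train.targets.tolist())` would be one way to call it on the real data.

```python
from collections import Counter

def label_distribution(labels, num_classes=10):
    """Return {class: (count, fraction)} for an iterable of integer labels."""
    counts = Counter(int(label) for label in labels)
    total = sum(counts.values())
    return {c: (counts.get(c, 0), counts.get(c, 0) / total) for c in range(num_classes)}

# toy labels standing in for a real dataset split
toy_labels = [0, 1, 1, 2, 2, 2, 9]
for cls, (count, frac) in label_distribution(toy_labels).items():
    print(f'class {cls}: {count} samples ({frac:.1%})')
```

A class with zero samples, or one that dominates the split, usually points to a loading or filtering bug rather than to a property of the data.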

# visual check of the first training samples together with their dataset labels
def imshow(img):
    npimg = np.array(img)
    plt.imshow(npimg, cmap='gray')
    plt.show()
 
# check some samples directly from the dataset
for i in range(5):
    img, label = train_dataset[i]
    print(f'Sample {i+1}: Image type: {type(img)}, Size: {img.size}, Mode: {img.mode}, Label: {label}')
    imshow(img)

Data transforms

Data transforms are applied on the fly, each time a sample is loaded, typically while a DataLoader assembles a batch.

They convert the raw data into the form the neural network expects.

Data should be converted to tensors whose shape matches the network input and, optionally, normalized so that pixel values are roughly zero-centered with unit standard deviation.
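The effect of `ToTensor` followed by `Normalize` on a single pixel can be worked out by hand: `ToTensor` maps an 8-bit value p to p / 255, and `Normalize` then computes (x - mean) / std. The NumPy sketch below assumes the commonly quoted MNIST statistics mean ≈ 0.1307 and std ≈ 0.3081 (the notebook computes its own values from the training set); the two helper names are hypothetical.

```python
import numpy as np

MEAN, STD = 0.1307, 0.3081  # assumed MNIST statistics, for illustration only

def to_tensor_like(pixels_uint8):
    """Scale uint8 pixels to [0, 1], as transforms.ToTensor does for 8-bit images."""
    return np.asarray(pixels_uint8, dtype=np.float32) / 255.0

def normalize(x, mean=MEAN, std=STD):
    """Single-channel normalization as transforms.Normalize computes it: (x - mean) / std."""
    return (x - mean) / std

pixels = np.array([0, 128, 255], dtype=np.uint8)  # black, mid-gray, white
print(normalize(to_tensor_like(pixels)))
```

The normalized pixels are zero-centered on average but span roughly [-0.42, 2.82]: MNIST images are mostly black background, so the range after normalization is far from symmetric.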

# create transform for one-hot encoding of the labels
class OneHotEncode(object):
    def __call__(self, label):
        one_hot_label = torch.zeros(10)  # 10 classes for MNIST
        one_hot_label[label] = 1
        return one_hot_label
 
# image transforms: PIL image to torch tensor in [0, 1], then normalization
# (one-hot encoding of the labels is handled separately by target_transform)
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((mean,), (std,)),  # mean and std computed on the training set
])
 
# target_transform for labels should take label only (not (image, label))
target_transform = OneHotEncode()
 
# apply the transforms to the datasets
train_dataset.transform = transform
val_dataset.dataset.transform = transform
test_dataset.dataset.transform = transform
train_dataset.target_transform = target_transform
val_dataset.dataset.target_transform = target_transform
test_dataset.dataset.target_transform = target_transform

Data loaders

DataLoader is PyTorch's built-in utility for drawing mini-batches from a dataset, with optional shuffling and parallel loading through worker processes.
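The batching behavior can be seen on a small synthetic dataset. The sketch below uses `TensorDataset` with random tensors (an assumption standing in for MNIST) to show that a DataLoader yields ceil(N / batch_size) batches, with a smaller final batch unless `drop_last=True` is passed.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

# 25 fake flattened "images" with integer labels, standing in for an MNIST split
images = torch.randn(25, 784)
labels = torch.randint(0, 10, (25,))
dataset = TensorDataset(images, labels)

loader = DataLoader(dataset, batch_size=10, shuffle=True)
print([x.shape[0] for x, _ in loader])  # [10, 10, 5]: the last batch is smaller

loader_drop = DataLoader(dataset, batch_size=10, shuffle=True, drop_last=True)
print([x.shape[0] for x, _ in loader_drop])  # [10, 10]: the incomplete batch is dropped
```

The smaller last batch is the reason the training and test functions weight each batch loss by `images.size(0)` before averaging.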

# prepare data loaders for training, validation, and testing
train_loader = DataLoader(train_dataset, batch_size=hyperparameters['batch_size'], shuffle=True, num_workers=2)
val_loader = DataLoader(val_dataset, batch_size=512, shuffle=False, num_workers=2)
test_loader = DataLoader(test_dataset, batch_size=512, shuffle=False, num_workers=2)

Data check after transforms

To check what the network will actually see at training and test time, draw batches from the data loaders, which apply the transforms defined above.

# inspect one transformed batch exactly as it will be passed to the model
def inspect_transformed_batch(loader, split_name, mean, std, num_images=8):
    images, labels = next(iter(loader))
    num_images = min(num_images, images.size(0))
 
    label_indices = labels.argmax(dim=1) if labels.ndim == 2 else labels
    image_min = images.min().item()
    image_max = images.max().item()
    batch_mean = images.mean().item()
    batch_std = images.std().item()
 
    print(f'{split_name} batch')
    print(f'  images shape: {tuple(images.shape)}')
    print(f'  labels shape: {tuple(labels.shape)}')
    print(f'  images dtype: {images.dtype}')
    print(f'  labels dtype: {labels.dtype}')
    print(f'  image value range after normalization: [{image_min:.4f}, {image_max:.4f}]')
    print(f'  batch mean after normalization: {batch_mean:.4f}')
    print(f'  batch std after normalization: {batch_std:.4f}')
    print(f'  first labels: {label_indices[:num_images].tolist()}')
 
    assert images.ndim == 4, 'Images must have shape [batch_size, channels, height, width].'
    assert images.shape[1:] == (1, 28, 28), 'MNIST images should have shape [1, 28, 28] after transforms.'
    assert labels.ndim == 2 and labels.shape[1] == 10, 'Labels should be one-hot encoded with 10 classes.'
    assert torch.isfinite(images).all(), 'Images contain NaN or infinite values.'
    assert torch.isfinite(labels).all(), 'Labels contain NaN or infinite values.'
    assert torch.allclose(labels.sum(dim=1), torch.ones(labels.size(0)), atol=1e-6), 'Each label should contain exactly one active class.'
 
    # undo normalization only for visualization
    images_to_show = images[:num_images].detach().cpu() * std + mean
    images_to_show = images_to_show.clamp(0, 1)
 
    grid = torchvision.utils.make_grid(images_to_show, nrow=num_images, padding=2)
    grid = grid.permute(1, 2, 0).numpy()
    if grid.shape[-1] == 1:
        grid = grid.squeeze(-1)
 
    plt.figure(figsize=(1.4 * num_images, 1.8))
    plt.imshow(grid, cmap='gray')
    plt.title(f'{split_name} samples after transforms')
    plt.axis('off')
    plt.show()
 
 
inspect_transformed_batch(train_loader, 'Training', mean, std)
inspect_transformed_batch(val_loader, 'Validation', mean, std)
inspect_transformed_batch(test_loader, 'Test', mean, std)

Train function

It is preferable (but not mandatory) to wrap the code for one training epoch in a function and call that function at each epoch of the training phase.

# create training loop function that will train the model for one epoch and return average loss and accuracy
def train(model, train_loader, criterion, optimizer):
    model.train()  # set the model to training mode
    running_loss = 0.0
    correct = 0
    total = 0
    for images, labels in train_loader:
        images = images.to(device)
        labels = labels.to(device)
        optimizer.zero_grad()    # zero the parameter gradients
        outputs = model(images)  # forward pass
        loss = criterion(outputs, labels)  # calculate loss
        loss.backward()   # backward pass
        optimizer.step()  # update weights
 
        running_loss += loss.item() * images.size(0)  # accumulate loss
 
  
        # compute training accuracy
        _, predicted = torch.max(outputs.data, 1)
        _, labels_max = torch.max(labels.data, 1)
        total += labels.size(0)
        correct += (predicted == labels_max).sum().item()
 
    average_loss = running_loss / len(train_loader.dataset)  # calculate average loss
    accuracy = correct / total  # calculate accuracy
    return average_loss, accuracy
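The function accumulates `loss.item() * images.size(0)` and divides by the dataset size because `nn.MSELoss` returns a per-batch mean. Weighting by batch size recovers the exact per-sample average even when the last batch is smaller, whereas averaging the batch losses directly over-weights the small batch. A pure-Python sketch with made-up per-sample losses (all three helper names are hypothetical):

```python
def per_sample_mean(batches):
    """Exact mean over all samples, given one list of per-sample losses per batch."""
    return sum(sum(b) for b in batches) / sum(len(b) for b in batches)

def weighted_batch_mean(batches):
    """Sum of (batch mean * batch size) divided by the number of samples."""
    running = sum((sum(b) / len(b)) * len(b) for b in batches)
    return running / sum(len(b) for b in batches)

def naive_batch_mean(batches):
    """Plain average of batch means; biased when batch sizes differ."""
    return sum(sum(b) / len(b) for b in batches) / len(batches)

# made-up per-sample losses: two full batches of 4 and a last batch of 1
batches = [[1.0, 1.0, 1.0, 1.0], [1.0, 1.0, 1.0, 1.0], [4.0]]
print(per_sample_mean(batches), weighted_batch_mean(batches), naive_batch_mean(batches))
```

The weighted version agrees with the exact mean (12 / 9), while the naive average of batch means gives 2.0.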

Test function

It is preferable (but not mandatory) to wrap the evaluation code in a function and call it whenever needed: for validation after each training epoch, for the final test after training, or when deploying the model.
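Two mechanisms in the test function are easy to confuse: `model.eval()` switches layers such as Dropout and BatchNorm to inference behavior (the MLP in this notebook has neither, so here it is mainly a good habit), while `torch.no_grad()` stops autograd from recording operations. The sketch below uses a hypothetical tiny model containing Dropout to show both effects.

```python
import torch
import torch.nn as nn

# a tiny stand-in model with dropout, so that train/eval mode makes a visible difference
net = nn.Sequential(nn.Linear(4, 4), nn.Dropout(p=0.5))
x = torch.ones(2, 4)

net.train()
out_train = net(x)
print(out_train.requires_grad)  # True: autograd records the graph for backward()

net.eval()  # switches layers such as Dropout and BatchNorm to inference behavior
with torch.no_grad():  # no graph is recorded, which saves memory and time
    out_eval_1 = net(x)
    out_eval_2 = net(x)
print(out_eval_1.requires_grad)             # False
print(torch.equal(out_eval_1, out_eval_2))  # True: dropout is disabled in eval mode
```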

# test function that evaluates a given model on a given dataloader and returns the average loss and accuracy
def test(model, test_loader, criterion):
    model.eval()  # set the model to evaluation mode
    running_loss = 0.0
    correct = 0
    total = 0
    with torch.no_grad():  # disable gradient calculation
        for images, labels in test_loader:
            images = images.to(device)
            labels = labels.to(device)
            outputs = model(images)  # forward pass
            loss = criterion(outputs, labels)  # calculate loss
            running_loss += loss.item() * images.size(0)  # accumulate loss
            _, predicted = torch.max(outputs.data, 1)  # predicted class: index of the largest output
            _, labels_max = torch.max(labels.data, 1)  # true class: index of the 1 in the one-hot target
            total += labels.size(0)  # update total count
            correct += (predicted == labels_max).sum().item()
    average_loss = running_loss / len(test_loader.dataset)  # calculate average loss
    accuracy = correct / total  # calculate accuracy
 
    return average_loss, accuracy

Train Model / Test Model

The code below also includes visual loss/accuracy monitoring during training, both on training and validation sets.

# training experiment: train for one epoch, evaluate on the validation set, update the plots,
# and keep the model that achieves the best validation accuracy
train_losses = []
val_losses = []
train_accuracies = []
val_accuracies = []
best_val_accuracy = 0.0
best_val_accuracy_epoch = 0
best_val_loss = float('inf')
best_val_loss_epoch = 0
best_model_state = None
 
  
for epoch in range(hyperparameters['num_epochs']):
    train_loss, train_accuracy = train(model, train_loader, criterion, optimizer)  # train for one epoch
    val_loss, val_accuracy = test(model, val_loader, criterion)  # evaluate on validation set
    train_losses.append(train_loss)
    val_losses.append(val_loss)
    train_accuracies.append(train_accuracy)
    val_accuracies.append(val_accuracy)
    current_lr = optimizer.param_groups[0]['lr']  # learning rate used during this epoch
    if scheduler is not None:
        scheduler.step()  # update the learning rate for the next epoch
    print(f"Epoch {epoch+1}/{hyperparameters['num_epochs']}, LR: {current_lr:.6f}, Train Loss: {train_loss:.4f}, Val Loss: {val_loss:.4f}, Train Acc: {train_accuracy:.4f}, Val Acc: {val_accuracy:.4f}")
 
 
    # track best validation metrics
    if val_loss < best_val_loss:
        best_val_loss = val_loss
        best_val_loss_epoch = epoch + 1
 
    if val_accuracy > best_val_accuracy:
        best_val_accuracy = val_accuracy
        best_val_accuracy_epoch = epoch + 1
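        # clone the tensors: state_dict() returns references that later training steps would overwrite
        best_model_state = {k: v.detach().clone() for k, v in model.state_dict().items()}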
 
    # plot metrics per epoch (live update)
    plt.figure(figsize=(12, 4))
    epochs = list(range(1, len(train_losses) + 1))
 
 
    plt.subplot(1, 2, 1)
    plt.plot(epochs, train_losses, label='Train Loss', marker='o')
    plt.plot(epochs, val_losses, label=f'Val Loss (best {best_val_loss:.4f} @ {best_val_loss_epoch})', marker='o')
    plt.title('Loss per epoch')
    plt.xlabel('Epoch')
    plt.ylabel('Loss')
    plt.xlim(0, hyperparameters['num_epochs'])
    plt.xticks(range(0, hyperparameters['num_epochs'] + 1, 2))
    plt.legend()
 
  
    plt.subplot(1, 2, 2)
    train_accuracies_pct = [x * 100 for x in train_accuracies]
    val_accuracies_pct = [x * 100 for x in val_accuracies]
    plt.plot(epochs, train_accuracies_pct, label='Train Accuracy (%)', marker='o')
    plt.plot(epochs, val_accuracies_pct, label=f'Val Accuracy (%) (best {best_val_accuracy*100:.2f}% @ {best_val_accuracy_epoch})', marker='o')
    plt.title('Accuracy per epoch')
    plt.xlabel('Epoch')
    plt.ylabel('Accuracy (%)')
    plt.xlim(0, hyperparameters['num_epochs'])
    plt.xticks(range(0, hyperparameters['num_epochs'] + 1, 2))
    plt.yticks(range(80, 101, 5))
    plt.ylim(80, 100)
    plt.grid(axis='y', linestyle='--', alpha=0.5)
    plt.legend()
    plt.tight_layout()
    plt.show()
 
# after training, load the best model and evaluate it on the test set
model.load_state_dict(best_model_state)  # load the best model state
test_loss, test_accuracy = test(model, test_loader, criterion)  # evaluate on test set
print(f'Test Loss: {test_loss:.4f}, Test Accuracy: {test_accuracy:.4f}')
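Keeping the best model requires a real snapshot of the weights. `state_dict()` returns tensors that share storage with the model parameters, so a stored reference keeps changing as training continues, while a cloned copy does not. The sketch below demonstrates this on a single `nn.Linear` layer, using an in-place update under `torch.no_grad()` as a stand-in for an optimizer step; `copy.deepcopy(model.state_dict())` is a common alternative to the dictionary of clones.

```python
import torch
import torch.nn as nn

layer = nn.Linear(3, 2)

aliased = layer.state_dict()  # references to the live parameter tensors
snapshot = {k: v.detach().clone() for k, v in layer.state_dict().items()}  # independent copy

with torch.no_grad():
    layer.weight.add_(1.0)  # stand-in for an optimizer step that updates weights in place

print(torch.equal(aliased['weight'], layer.weight))   # True: the "saved" state changed too
print(torch.equal(snapshot['weight'], layer.weight))  # False: the snapshot kept the old values

layer.load_state_dict(snapshot)  # restores the pre-update weights
```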

Visual inspection

Visual inspection of the incorrect predictions

# Visual inspection of incorrect predictions (uses the trained model and test loader in this notebook)
model.eval()
error_images = []
error_pred_labels = []
error_true_labels = []
with torch.no_grad():
    for images, labels in test_loader:
        images = images.to(device)
        labels = labels.to(device)
        outputs = model(images)
 
        # model outputs are 10-d vectors; targets are one-hot encoded vectors
        preds = torch.argmax(outputs, dim=1)
        trues = torch.argmax(labels, dim=1)
 
        mismatch = preds != trues
        if mismatch.any():
            error_images.append(images[mismatch].cpu())
            error_pred_labels.append(preds[mismatch].cpu())
            error_true_labels.append(trues[mismatch].cpu())
 
 
if len(error_images) == 0:
    print('No errors found on test set.')
else:
    error_images = torch.cat(error_images)
    error_pred_labels = torch.cat(error_pred_labels)
    error_true_labels = torch.cat(error_true_labels)
    
    print(f'Number of errors: {len(error_images)} out of {len(test_loader.dataset)} test images')
 
  
    # display first 100 misclassified samples (or fewer if fewer errors exist)
    n_show = min(100, len(error_images))
    image_grid = torchvision.utils.make_grid(error_images[:n_show], nrow=10, normalize=True, scale_each=True)
    npimg = image_grid.numpy()
    plt.figure(figsize=(12, 12))
    plt.imshow(np.transpose(npimg, (1, 2, 0)), cmap='gray')
    plt.title(f'Errors (first {n_show} samples)')
    plt.axis('off')
    plt.show()
 
    # print first 20 mismatch labels
    for i in range(min(20, n_show)):
        print(f'{i+1}: true={int(error_true_labels[i])}, pred={int(error_pred_labels[i])}')
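The grid of errors shows individual mistakes; a confusion matrix summarizes which digit pairs are confused most often. The `confusion_matrix` helper below is a hypothetical sketch built on `torch.bincount`. Its inputs would be the `trues` and `preds` from every test batch (not only the mismatches), and the diagonal then counts the correct predictions.

```python
import torch

def confusion_matrix(trues, preds, num_classes=10):
    """Count matrix where entry [t, p] is the number of samples of class t predicted as p."""
    trues = torch.as_tensor(trues, dtype=torch.long)
    preds = torch.as_tensor(preds, dtype=torch.long)
    flat = trues * num_classes + preds  # one bin per (true, predicted) pair
    counts = torch.bincount(flat, minlength=num_classes * num_classes)
    return counts.reshape(num_classes, num_classes)

# toy labels standing in for argmax-ed targets and outputs collected over test_loader
trues = [3, 3, 5, 5, 5, 8]
preds = [3, 5, 5, 5, 3, 8]
cm = confusion_matrix(trues, preds)
print(cm[3, 5].item(), cm[5, 3].item(), cm.diagonal().sum().item())  # 1 1 4
```

Large off-diagonal entries (for example between 4 and 9, or 3 and 5) indicate which shapes the MLP finds ambiguous, which can guide the choice of architecture changes.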