Overview
This note develops a complete image-classification experiment on MNIST using a standard multilayer perceptron. The goal is to turn the abstract components of neural-network training into an executable pipeline: loading handwritten digit images, preparing batches, defining a fully connected architecture, selecting a loss function and optimizer, training the model, monitoring validation performance, and evaluating how well the network assigns each image to one of the ten digit classes. Although MNIST is intentionally simple, it provides a clean setting for understanding the mechanics of supervised learning before moving to convolutional architectures and larger vision datasets.
Execution environment
The code below requires the dependencies used in the import block (torch, torchvision, matplotlib, numpy). A practical option is to run the notebook in Google Colab with a GPU runtime enabled.

Import the required packages
Import all the packages the experiment needs.
# import required packages for a training/testing experiment with MNIST image classification using only standard MLPs
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
import matplotlib.pyplot as plt
import torchvision
import numpy as np
Hyperparameters and Options
Experiment with hyperparameters
The values below should be treated as an experimental starting point, not as a fixed recipe. To understand the effect of each choice, it is usually best to change one hyperparameter at a time, rerun the experiment, and compare the loss curves, validation accuracy, and final test performance. Once the role of the main parameters is clear, combinations of changes can be explored more deliberately.
# this will contain a list of all hyperparameters that will be used in our experiment like learning rate, momentum, batch size, etc.
hyperparameters = {
    'learning_rate': 2,
    'momentum': 0.2,
    'nesterov': True,
    'batch_size': 10,
    'num_epochs': 30,
    'use_lr_scheduler': True,
    'lr_scheduler_step_size': 10,
    'lr_scheduler_gamma': 0.1
}
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print('Using device:', device)
Model architecture
Define the network
- __init__: creates the fully connected (fc) layers of the network
- forward: propagates the input through the layers and returns the output of the network
Architecture as a design choice
The number of hidden layers and the number of neurons in each layer are not fixed rules. The architecture below is only a reasonable starting point for the experiment. Different depths and widths should be tested and compared through training behavior, validation performance, and overfitting tendencies.
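One quick way to compare candidate architectures before training them is to count their trainable parameters. The sketch below does this in plain Python for a stack of fully connected layers. The helper name count_mlp_parameters and the narrower configuration are hypothetical and only serve as an illustration; the first configuration uses the layer sizes of the network defined below.

```python
def count_mlp_parameters(layer_sizes):
    """Count weights and biases of a fully connected network.

    layer_sizes lists the width of every layer, input and output included.
    """
    total = 0
    for n_in, n_out in zip(layer_sizes[:-1], layer_sizes[1:]):
        total += n_in * n_out  # weight matrix of the layer
        total += n_out         # bias vector of the layer
    return total


# layer sizes of the network in this note
baseline = count_mlp_parameters([784, 1000, 1000, 200, 10])
# a hypothetical smaller alternative, for comparison only
narrower = count_mlp_parameters([784, 256, 128, 10])

print(f'baseline: {baseline:,} parameters')
print(f'narrower: {narrower:,} parameters')
```

A large gap in parameter count is one reason two architectures may differ in training time and in their tendency to overfit, so it is worth recording next to the validation results.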
# define here a simple MLP architecture with sigmoid activation
class MLP(nn.Module):
    def __init__(self):
        super(MLP, self).__init__()
        self.fc1 = nn.Linear(784, 1000)
        self.fc2 = nn.Linear(1000, 1000)
        self.fc3 = nn.Linear(1000, 200)
        self.fc4 = nn.Linear(200, 10)
        self.sigmoid = nn.Sigmoid()
    def forward(self, x):
        x = x.view(-1, 784) # flatten the input
        x = self.sigmoid(self.fc1(x))
        x = self.sigmoid(self.fc2(x))
        x = self.sigmoid(self.fc3(x))
        x = self.sigmoid(self.fc4(x))
        return x
Building blocks for training
Create an instance of the network, then define:
- Loss Function: quantifies how close the actual output is to the desired output, which is what guides learning
- Optimizer: weight update method (stochastic gradient descent)
- Learning Rate Scheduler: adjusts the learning rate as a function of the epoch number
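To make the scheduler's role concrete, the following plain-Python sketch reproduces a step decay rule: the learning rate is multiplied by gamma once every step_size epochs. The function step_lr is a hypothetical helper, not a PyTorch API, and the values passed to it are assumed to match the hyperparameter dictionary defined earlier.

```python
def step_lr(initial_lr, epoch, step_size, gamma):
    """Learning rate in effect during a given 0-based epoch under step decay."""
    return initial_lr * gamma ** (epoch // step_size)


# assumed values: learning_rate=2, lr_scheduler_step_size=10, lr_scheduler_gamma=0.1
schedule = [step_lr(2.0, epoch, step_size=10, gamma=0.1) for epoch in range(30)]

# print the rate at the first and last epoch of each plateau
for epoch in (0, 9, 10, 19, 20, 29):
    print(f'epoch {epoch + 1:2d}: lr = {schedule[epoch]:.4f}')
```

With these settings the rate stays constant for ten epochs, then drops by a factor of ten, twice over a 30-epoch run.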
# create the model, the MSE loss function, the SGD optimizer, and the optional learning rate scheduler
model = MLP().to(device)
criterion = nn.MSELoss().to(device)
optimizer = optim.SGD(model.parameters(), lr=hyperparameters['learning_rate'], momentum=hyperparameters['momentum'], nesterov=hyperparameters['nesterov'])
scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=hyperparameters['lr_scheduler_step_size'], gamma=hyperparameters['lr_scheduler_gamma']) if hyperparameters.get('use_lr_scheduler', False) else None
Create datasets
Split training / validation / test sets.
Note
MNIST does not ship with a validation set, so the official test set is split in half: one half is used as the validation set and the other half is kept as the final test set.
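The idea behind such a split can be sketched in plain Python: shuffle the sample indices once, then cut the list in two. The helper split_indices and the fixed seed are assumptions introduced here for illustration. A seed makes the validation/test partition identical across runs; in PyTorch the same effect is usually obtained by passing a seeded generator to random_split.

```python
import random


def split_indices(num_samples, val_fraction=0.5, seed=0):
    """Shuffle indices with a fixed seed and split them into two disjoint parts."""
    indices = list(range(num_samples))
    random.Random(seed).shuffle(indices)  # local RNG, global random state untouched
    val_size = int(val_fraction * num_samples)
    return indices[:val_size], indices[val_size:]


# 10,000 is the size of the official MNIST test set
val_idx, test_idx = split_indices(10_000)
print(len(val_idx), len(test_idx))
```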
# download MNIST training and test datasets separately
mnist_train = datasets.MNIST(root='./data', train=True, download=True)
mnist_test = datasets.MNIST(root='./data', train=False, download=True)
train_dataset = mnist_train
# we use full MNIST training as train set; split official MNIST test set into validation/test
val_size = int(0.5 * len(mnist_test))
test_size = len(mnist_test) - val_size
val_dataset, test_dataset = torch.utils.data.random_split(mnist_test, [val_size, test_size])
# calculate the mean and standard deviation of the training pixels for normalization
# mnist_train.data is a uint8 tensor of shape [60000, 28, 28]; dividing by 255 matches the [0, 1] scaling of ToTensor()
train_pixels = mnist_train.data.float() / 255.0
mean = train_pixels.mean().item() # global mean over all training pixels
std = train_pixels.std().item() # global standard deviation over all training pixels
print(f'Mean: {mean:.4f}, Std: {std:.4f}')
Data Check
One of the most common problems in machine learning and deep learning is that the model is fed wrong data because of incorrect data loading, processing, etc.
A good practice is to calculate and print statistics.
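As one example of such a statistic, the sketch below counts how often each class appears in a list of labels and reports classes that never occur. The helper summarize_labels and the toy label list are hypothetical; in the notebook the labels would come from the dataset itself.

```python
from collections import Counter


def summarize_labels(labels, num_classes=10):
    """Return per-class counts and the list of classes with no samples."""
    counts = Counter(labels)
    missing = [c for c in range(num_classes) if counts[c] == 0]
    return counts, missing


# toy labels for illustration only, not read from the dataset
toy_labels = [5, 0, 4, 1, 9, 2, 1, 3, 1, 4]
counts, missing = summarize_labels(toy_labels)

for digit in range(10):
    print(f'digit {digit}: {counts[digit]} samples')
print('classes with no samples:', missing)
```

A strongly unbalanced or incomplete histogram on the full training set would be an early sign that the data was loaded or filtered incorrectly.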
# visual check of random samples from training set together with their dataset labels
def imshow(img):
    npimg = np.array(img)
    plt.imshow(npimg, cmap='gray')
    plt.show()
# check some samples directly from the dataset
for i in range(5):
    img, label = train_dataset[i]
    print(f'Sample {i+1}: Image type: {type(img)}, Size: {img.size}, Mode: {img.mode}, Label: {label}')
    imshow(img)
Data transforms
Data transforms are applied at batch generation time.
They serve to transform your data into what the neural network expects.
Data should be converted to tensors whose shape corresponds to the network input, and possibly normalized so that the values are roughly zero-centered.
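The arithmetic of the normalization step is simple enough to check by hand. The sketch below applies (x - mean) / std to the two extreme pixel values and then inverts it. The function names are hypothetical, and the statistics 0.1307 and 0.3081 are the values commonly quoted for the MNIST training set, assumed here instead of the ones computed by the notebook.

```python
def normalize(pixel, mean, std):
    """Standardize a pixel value that was already scaled to [0, 1]."""
    return (pixel - mean) / std


def denormalize(value, mean, std):
    """Invert normalize, e.g. before displaying an image."""
    return value * std + mean


# commonly quoted MNIST training statistics, used here as an assumption
assumed_mean, assumed_std = 0.1307, 0.3081

low = normalize(0.0, assumed_mean, assumed_std)   # black background pixel
high = normalize(1.0, assumed_mean, assumed_std)  # brightest stroke pixel
print(f'normalized range: [{low:.4f}, {high:.4f}]')
print(f'round trip of 0.5: {denormalize(normalize(0.5, assumed_mean, assumed_std), assumed_mean, assumed_std):.4f}')
```

Because most MNIST pixels are background, the normalized range is not symmetric around zero even though the mean of the data is.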
# create transform for one-hot encoding of the labels
class OneHotEncode(object):
    def __call__(self, label):
        one_hot_label = torch.zeros(10) # 10 classes for MNIST
        one_hot_label[label] = 1
        return one_hot_label
# create transforms like PIL image to torch tensor, data normalization, and one-hot encoding for the labels
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((mean,), (std,)), # mean and std for MNIST
])
# target_transform for labels should take label only (not (image, label))
target_transform = OneHotEncode()
# apply the transforms to the datasets
train_dataset.transform = transform
val_dataset.dataset.transform = transform
test_dataset.dataset.transform = transform
train_dataset.target_transform = target_transform
val_dataset.dataset.target_transform = target_transform
test_dataset.dataset.target_transform = target_transform
Data loaders
Dataloaders are built-in PyTorch objects that sample batches from datasets.
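Conceptually, a data loader shuffles the sample indices (for training) and then hands them out in consecutive chunks of batch_size. The plain-Python generator below is a minimal sketch of that idea only; iterate_batches is a hypothetical name and this is not how PyTorch implements DataLoader internally.

```python
import random


def iterate_batches(num_samples, batch_size, shuffle=True, seed=0):
    """Yield lists of sample indices, one list per batch."""
    indices = list(range(num_samples))
    if shuffle:
        random.Random(seed).shuffle(indices)
    for start in range(0, num_samples, batch_size):
        # the final slice is shorter when num_samples is not a multiple of batch_size
        yield indices[start:start + batch_size]


batches = list(iterate_batches(num_samples=25, batch_size=10))
print([len(batch) for batch in batches])  # → [10, 10, 5]
```

The shorter last batch is the reason per-batch losses are later weighted by batch size when an average over the whole dataset is computed.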
# prepare data loaders for training, validation, and testing
train_loader = DataLoader(train_dataset, batch_size=hyperparameters['batch_size'], shuffle=True, num_workers=2)
val_loader = DataLoader(val_dataset, batch_size=512, shuffle=False, num_workers=2)
test_loader = DataLoader(test_dataset, batch_size=512, shuffle=False, num_workers=2)
Data check after transforms
To check what the network will see at train/test time, you have to use dataloaders which will apply the data transforms previously defined.
# inspect one transformed batch exactly as it will be passed to the model
def inspect_transformed_batch(loader, split_name, mean, std, num_images=8):
    images, labels = next(iter(loader))
    num_images = min(num_images, images.size(0))
    label_indices = labels.argmax(dim=1) if labels.ndim == 2 else labels
    image_min = images.min().item()
    image_max = images.max().item()
    batch_mean = images.mean().item()
    batch_std = images.std().item()
    print(f'{split_name} batch')
    print(f' images shape: {tuple(images.shape)}')
    print(f' labels shape: {tuple(labels.shape)}')
    print(f' images dtype: {images.dtype}')
    print(f' labels dtype: {labels.dtype}')
    print(f' image value range after normalization: [{image_min:.4f}, {image_max:.4f}]')
    print(f' batch mean after normalization: {batch_mean:.4f}')
    print(f' batch std after normalization: {batch_std:.4f}')
    print(f' first labels: {label_indices[:num_images].tolist()}')
    assert images.ndim == 4, 'Images must have shape [batch_size, channels, height, width].'
    assert images.shape[1:] == (1, 28, 28), 'MNIST images should have shape [1, 28, 28] after transforms.'
    assert labels.ndim == 2 and labels.shape[1] == 10, 'Labels should be one-hot encoded with 10 classes.'
    assert torch.isfinite(images).all(), 'Images contain NaN or infinite values.'
    assert torch.isfinite(labels).all(), 'Labels contain NaN or infinite values.'
    assert torch.allclose(labels.sum(dim=1), torch.ones(labels.size(0)), atol=1e-6), 'Each label should contain exactly one active class.'
    # undo normalization only for visualization
    images_to_show = images[:num_images].detach().cpu() * std + mean
    images_to_show = images_to_show.clamp(0, 1)
    grid = torchvision.utils.make_grid(images_to_show, nrow=num_images, padding=2)
    grid = grid.permute(1, 2, 0).numpy()
    if grid.shape[-1] == 1:
        grid = grid.squeeze(-1)
    plt.figure(figsize=(1.4 * num_images, 1.8))
    plt.imshow(grid, cmap='gray')
    plt.title(f'{split_name} samples after transforms')
    plt.axis('off')
    plt.show()
inspect_transformed_batch(train_loader, 'Training', mean, std)
inspect_transformed_batch(val_loader, 'Validation', mean, std)
inspect_transformed_batch(test_loader, 'Test', mean, std)
Train function
It is preferable (but not mandatory) to wrap the training code for one epoch in a function, and to call that function at each epoch during the training phase.
# create training loop function that will train the model for one epoch and return average loss and accuracy
def train(model, train_loader, criterion, optimizer):
    model.train() # set the model to training mode
    running_loss = 0.0
    correct = 0
    total = 0
    for images, labels in train_loader:
        images = images.to(device)
        labels = labels.to(device)
        optimizer.zero_grad() # zero the parameter gradients
        outputs = model(images) # forward pass
        loss = criterion(outputs, labels) # calculate loss
        loss.backward() # backward pass
        optimizer.step() # update weights
        running_loss += loss.item() * images.size(0) # accumulate loss
        # compute training accuracy
        _, predicted = torch.max(outputs.data, 1)
        _, labels_max = torch.max(labels.data, 1)
        total += labels.size(0)
        correct += (predicted == labels_max).sum().item()
    average_loss = running_loss / len(train_loader.dataset) # calculate average loss
    accuracy = correct / total # calculate accuracy
    return average_loss, accuracy
Test function
It is preferable (but not mandatory) to embed the test code into a function, and call that function whenever needed.
For instance, during training for validation at each epoch, or after training for testing, or for deploying the model.
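Both the training and the evaluation function in this note multiply each batch loss by the batch size before dividing by the dataset size. The plain-Python sketch below shows why: when the last batch is smaller than the others, a simple mean of the per-batch losses gives it too much weight. The helper name dataset_average and the numbers are made up for illustration.

```python
def dataset_average(batch_losses, batch_sizes):
    """Average of per-sample losses, recovered from per-batch mean losses."""
    total = sum(loss * size for loss, size in zip(batch_losses, batch_sizes))
    return total / sum(batch_sizes)


# illustrative values: two full batches and a final batch holding a single hard sample
batch_losses = [0.10, 0.10, 0.40]
batch_sizes = [512, 512, 1]

weighted = dataset_average(batch_losses, batch_sizes)
naive = sum(batch_losses) / len(batch_losses)
print(f'weighted by batch size: {weighted:.4f}')
print(f'naive mean of batches:  {naive:.4f}')
```

The weighted value stays close to the loss of the two large batches, while the naive mean is pulled strongly toward the single sample in the last batch.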
# test function that evaluates a given model on a given dataloader and returns the average loss and accuracy
def test(model, test_loader, criterion):
    model.eval() # set the model to evaluation mode
    running_loss = 0.0
    correct = 0
    total = 0
    with torch.no_grad(): # disable gradient calculation
        for images, labels in test_loader:
            images = images.to(device)
            labels = labels.to(device)
            outputs = model(images) # forward pass
            loss = criterion(outputs, labels) # calculate loss
            running_loss += loss.item() * images.size(0) # accumulate loss
            _, predicted = torch.max(outputs.data, 1) # predicted class: index of the largest output
            _, labels_max = torch.max(labels.data, 1) # true class: index of the 1 in the one-hot label
            total += labels.size(0) # update total count
            correct += (predicted == labels_max).sum().item()
    average_loss = running_loss / len(test_loader.dataset) # calculate average loss
    accuracy = correct / total # calculate accuracy
    return average_loss, accuracy
Train Model / Test Model
The code below also includes visual loss/accuracy monitoring during training, both on training and validation sets.
# run the training experiment using the previously defined training loop function, and after each epoch evaluate the model on the validation set, and update the plot accordingly, and also save the best model (which is the one that achieved the best accuracy on the validation set)
train_losses = []
val_losses = []
train_accuracies = []
val_accuracies = []
best_val_accuracy = 0.0
best_val_accuracy_epoch = 0
best_val_loss = float('inf')
best_val_loss_epoch = 0
best_model_state = None
for epoch in range(hyperparameters['num_epochs']):
    current_lr = optimizer.param_groups[0]['lr'] # learning rate in effect during this epoch
    train_loss, train_accuracy = train(model, train_loader, criterion, optimizer) # train for one epoch
    val_loss, val_accuracy = test(model, val_loader, criterion) # evaluate on validation set
    train_losses.append(train_loss)
    val_losses.append(val_loss)
    if scheduler is not None:
        scheduler.step() # update the learning rate for the next epoch
    train_accuracies.append(train_accuracy)
    val_accuracies.append(val_accuracy)
    print(f"Epoch {epoch+1}/{hyperparameters['num_epochs']}, LR: {current_lr:.6f}, Train Loss: {train_loss:.4f}, Val Loss: {val_loss:.4f}, Train Acc: {train_accuracy:.4f}, Val Acc: {val_accuracy:.4f}")
    # track best validation metrics
    if val_loss < best_val_loss:
        best_val_loss = val_loss
        best_val_loss_epoch = epoch + 1
    if val_accuracy > best_val_accuracy:
        best_val_accuracy = val_accuracy
        best_val_accuracy_epoch = epoch + 1
        # state_dict() returns references to the live parameters, so clone them to keep a real snapshot
        best_model_state = {k: v.detach().clone() for k, v in model.state_dict().items()}
    # plot metrics per epoch (live update)
    plt.figure(figsize=(12, 4))
    epochs = list(range(1, len(train_losses) + 1))
    plt.subplot(1, 2, 1)
    plt.plot(epochs, train_losses, label='Train Loss', marker='o')
    plt.plot(epochs, val_losses, label=f'Val Loss (best {best_val_loss:.4f} @ {best_val_loss_epoch})', marker='o')
    plt.title('Loss per epoch')
    plt.xlabel('Epoch')
    plt.ylabel('Loss')
    plt.xlim(0, hyperparameters['num_epochs'])
    plt.xticks(range(0, hyperparameters['num_epochs'] + 1, 2))
    plt.legend()
    plt.subplot(1, 2, 2)
    train_accuracies_pct = [x * 100 for x in train_accuracies]
    val_accuracies_pct = [x * 100 for x in val_accuracies]
    plt.plot(epochs, train_accuracies_pct, label='Train Accuracy (%)', marker='o')
    plt.plot(epochs, val_accuracies_pct, label=f'Val Accuracy (%) (best {best_val_accuracy*100:.2f}% @ {best_val_accuracy_epoch})', marker='o')
    plt.title('Accuracy per epoch')
    plt.xlabel('Epoch')
    plt.ylabel('Accuracy (%)')
    plt.xlim(0, hyperparameters['num_epochs'])
    plt.xticks(range(0, hyperparameters['num_epochs'] + 1, 2))
    plt.yticks(range(80, 101, 5))
    plt.ylim(80, 100)
    plt.grid(axis='y', linestyle='--', alpha=0.5)
    plt.legend()
    plt.tight_layout()
    plt.show()
# after training, load the best model and evaluate it on the test set
model.load_state_dict(best_model_state) # load the best model state
test_loss, test_accuracy = test(model, test_loader, criterion) # evaluate on test set
print(f'Test Loss: {test_loss:.4f}, Test Accuracy: {test_accuracy:.4f}')
Visual inspection
Visual inspection of the incorrect predictions
# Visual inspection of incorrect predictions (uses the trained model and test loader in this notebook)
model.eval()
error_images = []
error_pred_labels = []
error_true_labels = []
with torch.no_grad():
    for images, labels in test_loader:
        images = images.to(device)
        labels = labels.to(device)
        outputs = model(images)
        # model outputs are 10-d vectors; targets are one-hot encoded vectors
        preds = torch.argmax(outputs, dim=1)
        trues = torch.argmax(labels, dim=1)
        mismatch = preds != trues
        if mismatch.any():
            error_images.append(images[mismatch].cpu())
            error_pred_labels.append(preds[mismatch].cpu())
            error_true_labels.append(trues[mismatch].cpu())
if len(error_images) == 0:
    print('No errors found on test set.')
else:
    error_images = torch.cat(error_images)
    error_pred_labels = torch.cat(error_pred_labels)
    error_true_labels = torch.cat(error_true_labels)
    print(f'Number of errors: {len(error_images)}')
    # display first 100 misclassified samples (or fewer if fewer errors exist)
    n_show = min(100, len(error_images))
    image_grid = torchvision.utils.make_grid(error_images[:n_show], nrow=10, normalize=True, scale_each=True)
    npimg = image_grid.numpy()
    plt.figure(figsize=(12, 12))
    plt.imshow(np.transpose(npimg, (1, 2, 0)), cmap='gray')
    plt.title(f'Errors (first {n_show} samples)')
    plt.axis('off')
    plt.show()
    # print first 20 mismatch labels
    for i in range(min(20, n_show)):
        print(f'{i+1}: true={int(error_true_labels[i])}, pred={int(error_pred_labels[i])}')