Every deep-learning training run, regardless of architecture, dataset, or framework, reduces at the innermost level to a handful of operations performed in a specific order on each mini-batch. Understanding the pipeline as a whole, rather than as a collection of separate concerns, is what separates a practitioner who can debug and tune from one who can only follow a template.
This note synthesizes the entire training loop. Each section is brief because the mathematical and conceptual depth lives in the dedicated notes; the value added here is the integration, the order of operations, and the cross-cutting concerns (state, mode, memory) that no single component-level note can address.
The canonical training step
Every other concern in this note revolves around five operations executed per mini-batch, in this order:
optimizer.zero_grad() # 1. reset accumulated gradients
predictions = model(inputs) # 2. forward pass
loss = loss_fn(predictions, targets) # 3. loss computation
loss.backward() # 4. backward pass (autograd)
optimizer.step() # 5. parameter update
Wrap this in a loop over mini-batches and you have one epoch; wrap that in a loop over epochs and you have an entire training run. Everything else (scheduler stepping, validation passes, checkpointing, logging) is scaffolding around this five-line atomic unit.
The order is not arbitrary
Each of the five calls depends on the state produced by the previous one. Reordering them silently produces incorrect training:
- zero_grad() before forward: necessary, because loss.backward() accumulates into .grad (it does not overwrite). Without zeroing, every step uses the sum of all past gradients. Strictly, any point before loss.backward() works; the top of the step is the convention.
- loss.backward() after forward and loss: requires the autograd graph built during the forward pass and the loss as its scalar root.
- optimizer.step() after backward: consumes the gradients now sitting in each parameter’s .grad.
- scheduler.step() (when present) after optimizer.step(): the scheduler’s effect lands on the next iteration’s learning rate, not the current one.
Most “training is broken” bugs in deep-learning code trace to a violation of one of these orderings.
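The five calls can be exercised end to end on a toy problem. The following is a minimal sketch, assuming a synthetic linear-regression dataset and a one-layer model; the data, the learning rate, and the step count are illustrative choices, not part of the original note.

```python
import torch
from torch import nn

torch.manual_seed(0)

# Synthetic regression data (hypothetical): y = 3x + 1 plus a little noise.
inputs = torch.randn(256, 1)
targets = 3.0 * inputs + 1.0 + 0.01 * torch.randn(256, 1)

model = nn.Linear(1, 1)
loss_fn = nn.MSELoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

losses = []
for step in range(100):
    optimizer.zero_grad()                 # 1. reset accumulated gradients
    predictions = model(inputs)           # 2. forward pass
    loss = loss_fn(predictions, targets)  # 3. loss computation
    loss.backward()                       # 4. backward pass (autograd)
    optimizer.step()                      # 5. parameter update
    losses.append(loss.item())

print(f"first loss {losses[0]:.4f}, last loss {losses[-1]:.4f}")
```

Here the whole dataset is used as a single batch for brevity; with a DataLoader the same five lines sit inside the inner loop unchanged.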
1. Setup and initialization
Before the loop can start, several stateful objects must be created and placed on the right compute device:
- the model (parameters with their initial values);
- the loss function (typically stateless: cross-entropy, MSE, etc.);
- the optimizer, bound to the model’s parameters;
- the (optional) learning-rate scheduler, bound to the optimizer;
- the data loaders for training and validation.
The initial values of the model parameters are not random in the colloquial sense: they are drawn from a carefully chosen distribution whose variance is calibrated to keep activations and gradients in a workable range across all layers. The wrong initialization can make the network untrainable before the first gradient step is computed.
Initialization is part of the architecture
The right initialization depends on the activation functions used. The two standard choices are derived in Xavier and He initialization:
| Activation | Recommended init | Variance | Why |
|---|---|---|---|
| tanh / sigmoid / linear | Xavier (Glorot) | $2 / (n_{\text{in}} + n_{\text{out}})$ | symmetric, near-linear around the origin |
| ReLU / Leaky ReLU / PReLU | He (Kaiming) | $2 / n_{\text{in}}$ | compensates the variance halving from ReLU |
Specialized cells have specialized initializations: an LSTM forget-gate bias should be set to about $1$ rather than $0$ so the cell state survives the first iterations; an Adam-trained Transformer benefits from a small initial-LR warmup that compensates for the early moment-estimate noise.
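As an illustration of the two standard schemes, the sketch below applies them with torch.nn.init and checks the empirical weight variance against the textbook targets, 2 / (fan_in + fan_out) for Xavier and 2 / fan_in for He. The layer sizes are arbitrary assumptions chosen only to give enough samples for a stable estimate.

```python
import torch
from torch import nn

torch.manual_seed(0)

fan_in, fan_out = 512, 256
tanh_layer = nn.Linear(fan_in, fan_out)   # imagined to feed a tanh
relu_layer = nn.Linear(fan_in, fan_out)   # imagined to feed a ReLU

# Xavier (Glorot) normal: Var[w] = 2 / (fan_in + fan_out)
nn.init.xavier_normal_(tanh_layer.weight)
# He (Kaiming) normal in fan-in mode: Var[w] = 2 / fan_in
nn.init.kaiming_normal_(relu_layer.weight, mode="fan_in", nonlinearity="relu")
nn.init.zeros_(tanh_layer.bias)
nn.init.zeros_(relu_layer.bias)

xavier_var = tanh_layer.weight.var().item()
he_var = relu_layer.weight.var().item()
print(f"Xavier: {xavier_var:.5f} (target {2 / (fan_in + fan_out):.5f})")
print(f"He:     {he_var:.5f} (target {2 / fan_in:.5f})")
```

PyTorch layers come with their own default initialization, so an explicit call like this is only needed when the default does not match the activation in use.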
import torch
from torch import nn
from torch.utils.data import DataLoader
# MyModel, device, train_dataset, val_dataset, and num_epochs are assumed to be defined elsewhere.
model = MyModel().to(device)
loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3, weight_decay=1e-2)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=num_epochs)
train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=64, shuffle=False)
2. Feed-forward pass
The forward pass evaluates the model on one mini-batch, producing predictions:
predictions = model(inputs)
Three things happen simultaneously inside that single line.
- The numerical predictions are computed layer by layer through the architecture.
- The intermediate activations of every layer are cached in memory, because backpropagation needs them.
- The framework records the autograd computation graph: a dynamic DAG that tracks every operation applied to a tensor with requires_grad=True.
The third point is the conceptual lever PyTorch provides. The graph is not a static description of the model architecture; it is built on the fly by the actual operations executed during the forward pass. This is why control flow (conditionals, loops over time steps in an RNN, dynamic shapes) is natural in PyTorch: every iteration of training rebuilds the graph fresh from whatever code ran.
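A small sketch of what "built on the fly" means in practice: the function below (a hypothetical helper, not from the original note) uses an ordinary Python loop, so the recorded graph contains as many multiply nodes as the loop ran, and the gradient follows whichever path actually executed.

```python
import torch

def power_by_loop(x, n):
    """Compute x**n with a Python loop; the graph gets n multiply nodes."""
    result = torch.ones(())
    for _ in range(n):
        result = result * x
    return result

x = torch.tensor(3.0, requires_grad=True)

y = power_by_loop(x, 3)
y.backward()
grad_n3 = x.grad.item()   # d(x^3)/dx at x=3 is 3 * 3^2 = 27

x.grad = None             # clear before the next, differently shaped graph
y = power_by_loop(x, 2)
y.backward()
grad_n2 = x.grad.item()   # d(x^2)/dx at x=3 is 2 * 3 = 6

print(grad_n3, grad_n2)
```

The same mechanism is what lets an RNN unroll over a different number of time steps on every batch without any change to the model definition.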
Activation memory is the dominant cost
For a network with $L$ layers and per-layer activation size $A$, the forward pass holds on the order of $L \cdot A$ floats in memory until the backward pass releases them. For large models this often exceeds the parameter memory, sometimes by an order of magnitude, and is the dominant constraint on usable batch size. Techniques like gradient checkpointing recompute activations on the backward pass rather than storing them, trading compute for memory; mixed precision stores activations in float16 or bfloat16, halving the cost; activation offloading moves them to CPU between forward and backward.
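Gradient checkpointing can be tried on a single block with torch.utils.checkpoint. The sketch below assumes a reasonably recent PyTorch (the use_reentrant argument is not available in old releases) and a toy block whose sizes are arbitrary; it only verifies that recomputation yields the same gradients, not the memory saving itself.

```python
import torch
from torch import nn
from torch.utils.checkpoint import checkpoint

torch.manual_seed(0)
block = nn.Sequential(nn.Linear(16, 16), nn.ReLU(), nn.Linear(16, 16))
x = torch.randn(4, 16, requires_grad=True)

# Ordinary forward/backward: intermediate activations are cached.
block(x).sum().backward()
grad_plain = x.grad.clone()

x.grad = None
block.zero_grad()

# Checkpointed forward: activations inside `block` are dropped after the
# forward pass and recomputed when backward reaches this block.
checkpoint(block, x, use_reentrant=False).sum().backward()
grad_ckpt = x.grad.clone()

print(torch.allclose(grad_plain, grad_ckpt))
```

In a deep model the call would typically wrap each repeated block, so that only block boundaries are stored between forward and backward.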
The architectures that fill in the model(inputs) call are covered in their dedicated sections: MLPs, LSTMs, GRUs, CNNs, and so on. From the pipeline’s perspective, all of them are interchangeable black boxes that produce predictions.
3. Loss function and monitoring
The loss reduces the model’s predictions and the ground-truth targets to a single scalar quantifying the prediction error on the current mini-batch:
loss = loss_fn(predictions, targets)
The choice of loss is dictated by the task and by the form of the network’s output layer. The canonical pairings are:
| Task | Output layer | Loss | Why this pairing |
|---|---|---|---|
| Binary classification | sigmoid ($1$ unit) | BCELoss / BCEWithLogitsLoss | cancels the saturating sigmoid slope in the gradient |
| Multi-class classification | softmax ($K$ units) | CrossEntropyLoss | analogous cancellation; standard in classification |
| Regression | linear ($1$ unit) | MSELoss / L1Loss | direct prediction vs target comparison |
| Multi-label classification | sigmoid ($K$ independent units) | BCEWithLogitsLoss | each label is an independent binary problem |
| Sequence modelling | softmax per position | CrossEntropyLoss per token, summed | teacher forcing during training |
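The pairing of output activation and loss is not cosmetic: certain combinations cancel a saturating term in the gradient that would otherwise produce the output-layer learning slowdown. The cross-entropy + softmax canonical pair is the textbook example. In PyTorch, CrossEntropyLoss and BCEWithLogitsLoss apply the softmax or sigmoid internally, so the model should output raw logits when these losses are used.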
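The cancellation can be checked numerically. For softmax followed by cross-entropy, the gradient of the loss with respect to the logits is simply softmax(logits) minus the one-hot target, with no saturating derivative left over. The sketch below uses a single hand-picked example (the logit values are arbitrary) and relies on nn.CrossEntropyLoss taking raw logits.

```python
import torch
from torch import nn
import torch.nn.functional as F

logits = torch.tensor([[2.0, 0.5, -1.0]], requires_grad=True)
target = torch.tensor([0])

# CrossEntropyLoss applies log-softmax internally; it expects raw logits.
loss = nn.CrossEntropyLoss()(logits, target)
loss.backward()

probs = F.softmax(logits.detach(), dim=1)
one_hot = F.one_hot(target, num_classes=3).float()
# Batch of one, so the default mean reduction leaves the gradient unscaled.
expected_grad = probs - one_hot

print(logits.grad)
print(expected_grad)
```

With a batch of N examples and the default mean reduction, each row of the gradient is the same expression divided by N.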
What to monitor, beyond the loss
The loss on the current mini-batch is noisy and not directly informative about generalization. A standard monitoring discipline maintains four series of values:
| Series | Where computed | What it tells you |
|---|---|---|
| Training loss per step | every mini-batch | optimizer stability (noisy but should trend down) |
| Training loss per epoch | average over the epoch | smoothed training progress |
| Validation loss per epoch | on held-out validation set | generalization; primary signal for overfitting |
| Task metric per epoch (accuracy, F1, BLEU, …) | on training and validation | the actual quantity of interest, which loss only approximates |
The gap between training loss and validation loss is the canonical overfitting signal; the gap between loss and task metric is a reminder that loss is a differentiable proxy, not the goal itself.
4. Backpropagation and gradients
A single line triggers the most computationally demanding part of training:
loss.backward()
This call executes the backpropagation algorithm (for feedforward networks) or backpropagation through time (for recurrent networks) on the computation graph built during the forward pass. The chain rule is applied automatically, layer by layer, from the loss back to every parameter that contributed to it. The result is stored in each parameter’s .grad attribute, ready to be consumed by the optimizer.
The mathematical content is treated in MLP backpropagation for feedforward networks and in Backpropagation through time for recurrent ones. From the pipeline’s perspective, loss.backward() is a single expensive call whose runtime is of the same order as the forward pass (a common rule of thumb puts it at roughly twice the forward cost), and which temporarily raises the memory footprint because gradient tensors are allocated while the cached activations are still being consumed.
Understanding what .backward() actually does
Calling loss.backward() looks like magic the first time, because so much is hidden inside one method invocation: graph traversal, chain-rule application, gradient accumulation. The cleanest way to demystify it is to build the autograd machinery from scratch: see Micrograd, a walkthrough of Andrej Karpathy’s miniature autodiff engine. Working through Micrograd makes concrete which data structure is traversed (the computation graph), in which order the gradients are computed (reverse topological order), how they accumulate at each node (additively, which is why zero_grad() is needed), and where exactly in the training loop the call belongs. After Micrograd, .backward() stops being magic and starts being mechanical.
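To make those three ingredients tangible, here is a deliberately tiny scalar autodiff sketch in the spirit of Micrograd. It is not Karpathy's code: the Value class and its method names are hypothetical, and only addition and multiplication are supported.

```python
class Value:
    """A scalar that remembers how it was computed (hypothetical sketch)."""

    def __init__(self, data, parents=()):
        self.data = data
        self.grad = 0.0
        self._parents = parents
        self._backward = lambda: None   # leaf nodes have nothing to propagate

    def __add__(self, other):
        out = Value(self.data + other.data, (self, other))

        def _backward():
            self.grad += out.grad       # d(a+b)/da = 1
            other.grad += out.grad      # d(a+b)/db = 1

        out._backward = _backward
        return out

    def __mul__(self, other):
        out = Value(self.data * other.data, (self, other))

        def _backward():
            self.grad += other.data * out.grad   # d(a*b)/da = b
            other.grad += self.data * out.grad   # d(a*b)/db = a

        out._backward = _backward
        return out

    def backward(self):
        # Build a topological order of the graph, then walk it in reverse.
        order, seen = [], set()

        def visit(node):
            if node not in seen:
                seen.add(node)
                for parent in node._parents:
                    visit(parent)
                order.append(node)

        visit(self)
        self.grad = 1.0                 # d(loss)/d(loss)
        for node in reversed(order):
            node._backward()


a, b = Value(2.0), Value(3.0)
loss = a * b + a                        # `a` reaches the loss along two paths
loss.backward()
print(a.grad, b.grad)
```

Because a contributes through both the product and the sum, its gradient is the sum of two path contributions (3 + 1); the `+=` inside each `_backward` is what makes that work, and it is the same additive behaviour that makes zeroing necessary between training steps.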
loss.backward() accumulates, it does not assign
One of the most common sources of silent training bugs: loss.backward() adds the new gradient to whatever is already in .grad. If optimizer.zero_grad() is not called before the backward pass, gradients from previous iterations remain summed in, and the optimizer effectively uses the sum of all gradients computed so far. This is sometimes done intentionally (gradient accumulation, to simulate a larger effective batch when memory is limited), but the intent must be explicit. Always either zero the gradients or document the accumulation pattern.
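The accumulation is easy to observe directly on a single tensor. The sketch below uses a hypothetical quadratic loss whose gradient is known in closed form (2 * w), so the doubling after a second backward() call is visible at a glance.

```python
import torch

w = torch.tensor([1.0, 2.0], requires_grad=True)

def loss_of(w):
    return (w ** 2).sum()       # gradient is 2 * w

loss_of(w).backward()
first = w.grad.clone()          # gradient of one backward pass

loss_of(w).backward()           # no zeroing in between
accumulated = w.grad.clone()    # previous gradient + new gradient

w.grad = None                   # clearing, as zero_grad() does in recent PyTorch
loss_of(w).backward()
fresh = w.grad.clone()          # back to a single gradient

print(first, accumulated, fresh)
```

Intentional gradient accumulation uses exactly this mechanism: call backward() on loss / k for k consecutive mini-batches, then call optimizer.step() and optimizer.zero_grad() once.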
Gradient clipping, checkpointing, and other interventions
Several common operations sit between the backward pass and the optimizer step:
- Gradient clipping: limits the norm of the gradient to prevent exploding-gradient instabilities, particularly important in recurrent networks. Inserted as torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm) after loss.backward().
- Gradient checkpointing: trades compute for memory by recomputing activations during the backward pass rather than storing them.
- Mixed-precision scaling: under torch.cuda.amp, the loss is scaled before backward and the gradients are unscaled before the optimizer step to avoid float16 underflow.
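Where clipping sits in the step can be sketched as follows. The model, the deliberately oversized inputs, and max_norm = 1.0 are illustrative assumptions; the point is the position of the call and the fact that clip_grad_norm_ returns the norm measured before clipping.

```python
import torch
from torch import nn

torch.manual_seed(0)
model = nn.Linear(10, 1)
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
loss_fn = nn.MSELoss()

inputs = 100.0 * torch.randn(32, 10)   # deliberately large -> large gradients
targets = torch.randn(32, 1)

optimizer.zero_grad()
loss = loss_fn(model(inputs), targets)
loss.backward()

max_norm = 1.0
# Rescales all gradients in place so their joint L2 norm is at most max_norm;
# the return value is the total norm *before* clipping.
norm_before = torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm)
norm_after = torch.norm(torch.stack([p.grad.norm() for p in model.parameters()]))

optimizer.step()                       # consumes the clipped gradients
print(f"before {norm_before.item():.1f}, after {norm_after.item():.4f}")
```

Logging the returned pre-clipping norm is a cheap way to see how often clipping is actually active during a run.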
5. Parameters update and optimizers
The optimizer reads the gradients from .grad and updates each parameter according to its rule:
optimizer.step()
The choice of optimizer determines the update rule, the per-parameter memory footprint, and the convergence behaviour. The major families:
| Optimizer | Update mechanism | Memory cost | Strongest in |
|---|---|---|---|
| SGD | $\theta \leftarrow \theta - \eta\, g$ | no extra buffers | small models, simple problems |
| SGD + Momentum | adds inertial velocity buffer | $1\times$ parameters | CNNs, late-stage training, fine-tuning |
| NAG | look-ahead gradient | $1\times$ parameters | same as Momentum, slightly better trajectory |
| AdaGrad | per-parameter scaling by accumulated $g^2$ | $1\times$ parameters | sparse-feature problems |
| RMSProp | per-parameter EMA of $g^2$ | $1\times$ parameters | RNNs, varying gradient scales |
| Adam | EMAs of $g$ and $g^2$, bias-corrected | $2\times$ parameters | random init, Transformers, modern default |
| AdamW | Adam + decoupled weight decay | $2\times$ parameters | universal modern default with weight decay |
The selection criteria are developed in Choosing an optimizer; the mathematical derivations of each algorithm are in their respective notes. The pipeline-level point is that optimizer.step() is a single call whose internal cost is small (proportional to parameter count) but whose memory cost scales with the optimizer family: with $P$ parameters in float32, Adam consumes an extra $8P$ bytes (two 4-byte values per parameter, about 8 GB per billion parameters) just for its momentum and second-moment buffers.
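The extra memory can be counted directly from the optimizer's state. The sketch below assumes a toy linear layer and reads the exp_avg and exp_avg_sq entries that torch.optim.Adam keeps per parameter; those buffers are created lazily, so one step is taken first.

```python
import torch
from torch import nn

model = nn.Linear(100, 10)    # 100*10 weights + 10 biases = 1010 parameters
num_params = sum(p.numel() for p in model.parameters())

optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
# Optimizer state is allocated on the first step, not at construction.
model(torch.randn(4, 100)).sum().backward()
optimizer.step()

state_values = 0
for p in model.parameters():
    state = optimizer.state[p]
    state_values += state["exp_avg"].numel()      # first-moment EMA
    state_values += state["exp_avg_sq"].numel()   # second-moment EMA

extra_bytes = state_values * 4                    # float32 = 4 bytes per value
print(num_params, state_values, extra_bytes)
```

The same loop over optimizer.state works for other optimizers, with different key names (for example, SGD with momentum keeps a single momentum_buffer per parameter).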
After optimizer.step(), if a learning-rate scheduler is in use, it advances:
scheduler.step()
The scheduler’s role is to modulate the global learning rate over the training run, complementary to the per-parameter adaptation built into the optimizer. The mechanics and trade-offs of scheduling are covered in Learning rate scheduling and its dedicated family notes.
6. Mode: train vs eval
PyTorch models carry a mode flag that determines the behaviour of layers whose dynamics depend on whether training is in progress:
model.train() # set the model to training mode
model.eval() # set the model to evaluation mode
The flag affects two layer families specifically. Everything else (linear layers, convolutions, activations, attention) behaves identically in both modes.
| Layer | model.train() behaviour | model.eval() behaviour |
|---|---|---|
| Dropout | active: zeroes elements with probability $p$, scales survivors by $1/(1-p)$ | inactive: passes activations through unchanged |
| Batch normalization | uses batch mean/variance for normalization; updates running stats | uses running mean/variance accumulated during training; does not update them |
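Both rows of the table can be observed on standalone layers. The sketch below assumes p = 0.5 for dropout and an arbitrary shifted and scaled batch for batch normalization; neither choice comes from the original note.

```python
import torch
from torch import nn

torch.manual_seed(0)

dropout = nn.Dropout(p=0.5)
x = torch.ones(1000)

dropout.train()
train_out = dropout(x)       # each entry is 0 or 1 / (1 - 0.5) = 2
dropout.eval()
eval_out = dropout(x)        # identity

bn = nn.BatchNorm1d(3)
batch = torch.randn(64, 3) * 5 + 10

bn.train()
bn(batch)                    # normalizes with batch stats, updates running stats
running_mean_after_train = bn.running_mean.clone()

bn.eval()
bn(batch)                    # normalizes with running stats, leaves them untouched
running_mean_after_eval = bn.running_mean.clone()

print(train_out.unique(), torch.equal(eval_out, x))
print(running_mean_after_train, running_mean_after_eval)
```

Calling model.train() or model.eval() on a parent module sets the same flag recursively on every submodule, which is why one call at the top is enough.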
The single most common deep-learning bug
Forgetting to call model.eval() before validation or inference is one of the most common silent bugs in PyTorch code. Symptoms:
- validation accuracy fluctuates randomly between runs even on the same data (dropout is still firing);
- validation loss is systematically higher than it should be;
- inference results are non-deterministic for inputs that should be deterministic;
- in production, predictions vary mysteriously between identical API calls.
The fix is one line, but the failure mode is invisible: the code runs, no exception is raised, and the model’s outputs are nonsense only in the statistical sense.
model.eval() is separate from disabling gradient computation. During validation, both should typically be active:
model.eval()
with torch.no_grad():
    for inputs, targets in val_loader:
        predictions = model(inputs)
        val_loss = loss_fn(predictions, targets)
The model.eval() call switches dropout and batchnorm into evaluation behaviour; the torch.no_grad() context disables autograd graph construction, which saves substantial memory because activations no longer need to be cached. Forgetting torch.no_grad() during validation does not produce wrong results, but it can cause out-of-memory crashes on large models because validation memory grows to match training memory.
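One way to keep the two switches together is a small validation helper. The function name evaluate, the toy model, and the random validation data below are hypothetical; the helper also assumes a loss with mean reduction, so it re-weights each batch by its size.

```python
import torch
from torch import nn
from torch.utils.data import DataLoader, TensorDataset

def evaluate(model, loader, loss_fn):
    """Return the mean per-example loss over `loader` (hypothetical helper)."""
    was_training = model.training
    model.eval()                          # dropout off, batchnorm uses running stats
    total_loss, total_examples = 0.0, 0
    with torch.no_grad():                 # no graph, no cached activations
        for inputs, targets in loader:
            predictions = model(inputs)
            batch_size = inputs.size(0)
            # Weight by batch size so a smaller final batch is not over-counted.
            total_loss += loss_fn(predictions, targets).item() * batch_size
            total_examples += batch_size
    model.train(was_training)             # restore the caller's mode
    return total_loss / total_examples

torch.manual_seed(0)
model = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Dropout(0.5), nn.Linear(8, 1))
val_loader = DataLoader(
    TensorDataset(torch.randn(100, 4), torch.randn(100, 1)), batch_size=32
)

model.train()
first = evaluate(model, val_loader, nn.MSELoss())
second = evaluate(model, val_loader, nn.MSELoss())
print(first, second, model.training)
```

Two consecutive calls give the same value because dropout is disabled inside the helper, and the model is handed back in whatever mode it arrived in.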
7. Iterations and epochs
Two units of time structure the training loop, and confusing them is a frequent source of misconfiguration. The vocabulary and its library-dependent subtleties are treated in depth in Neural Networks vocabulary §5 (“Iteration, step, update, epoch”), which also catalogs the API-level traps (for example, scikit-learn’s max_iter often counts epochs, while in PyTorch and most deep-learning theory an “iteration” is one optimizer update). The compact version used in the rest of this note is the following.
| Unit | Definition | Granularity |
|---|---|---|
| Iteration (also: step, batch) | one execution of the five-line training step on one mini-batch | finest |
| Epoch | one complete pass through the training dataset | coarsest |
If the training set contains $N$ examples and the mini-batch size is $B$, one epoch consists of $\lceil N/B \rceil$ iterations. A training run is typically specified as a number of epochs, but framework internals operate per-iteration.
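The conversion between the two units is a one-line computation. The sketch below is plain arithmetic; the drop_last flag mirrors the DataLoader option of the same name, and the dataset size and batch size are made-up numbers.

```python
import math

def iterations_per_epoch(num_examples, batch_size, drop_last=False):
    """Number of optimizer updates in one pass over the data."""
    if drop_last:
        return num_examples // batch_size        # incomplete final batch discarded
    return math.ceil(num_examples / batch_size)  # incomplete final batch kept

n_examples, batch_size, num_epochs = 1000, 64, 10

per_epoch = iterations_per_epoch(n_examples, batch_size)
per_epoch_dropped = iterations_per_epoch(n_examples, batch_size, drop_last=True)
total_iterations = num_epochs * per_epoch

print(per_epoch, per_epoch_dropped, total_iterations)
```

The total iteration count is the number any per-iteration scheduler hyperparameter should be expressed in.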
Scheduler frequency: the silent-misconfiguration trap
The learning-rate scheduler advances each time scheduler.step() is called. Two equally valid conventions coexist:
- Per-epoch stepping: scheduler.step() is called once after the inner loop over batches. The scheduler’s hyperparameters (T_max, step_size, T_0, etc.) are then expressed in epochs.
- Per-iteration stepping: scheduler.step() is called inside the inner loop, after each optimizer.step(). The scheduler’s hyperparameters are then expressed in iterations.
Both are mathematically correct. The bug is mixing them: stepping per iteration while choosing T_max = num_epochs causes the schedule to complete after only num_epochs iterations, usually within the first epoch, and then to sit at its minimum $\eta_{\min}$ (or, in periodic implementations such as PyTorch’s CosineAnnealingLR, to oscillate) for the rest of training. The detailed implications are derived in Learning rate scheduling §3.
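The unit mismatch can be illustrated with the closed-form cosine schedule, written here in plain Python rather than through a framework scheduler. The function cosine_lr and the run dimensions (10 epochs of 100 iterations) are assumptions made for the example.

```python
import math

def cosine_lr(step, t_max, base_lr, eta_min=0.0):
    """Closed-form cosine annealing; `step` and `t_max` must share a unit."""
    return eta_min + 0.5 * (base_lr - eta_min) * (1 + math.cos(math.pi * step / t_max))

num_epochs, iters_per_epoch, base_lr = 10, 100, 1e-3

# Bug: stepping once per iteration while T_max is counted in epochs.
# After only `num_epochs` iterations the schedule has already hit its minimum.
lr_bug = cosine_lr(step=num_epochs, t_max=num_epochs, base_lr=base_lr)

# Fix: when stepping per iteration, express T_max in iterations.
total_iters = num_epochs * iters_per_epoch
lr_ok = cosine_lr(step=num_epochs, t_max=total_iters, base_lr=base_lr)
lr_end = cosine_lr(step=total_iters, t_max=total_iters, base_lr=base_lr)

print(lr_bug, lr_ok, lr_end)
```

With the mismatched units the learning rate reaches its floor a tenth of the way through the first epoch; with matching units it is still close to the base value at that point and only reaches the floor at the end of the run.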
The full training loop, putting all seven sections together:
for epoch in range(num_epochs):
    model.train()
    for inputs, targets in train_loader:
        optimizer.zero_grad()
        predictions = model(inputs)
        loss = loss_fn(predictions, targets)
        loss.backward()
        optimizer.step()
    scheduler.step() # per-epoch convention
    model.eval()
    with torch.no_grad():
        val_loss = sum(loss_fn(model(x), y) for x, y in val_loader) / len(val_loader)
    log(epoch, train_loss=..., val_loss=val_loss)
    if validation_has_stopped_improving():
        break # early stopping
The holistic view: state, mode, and stopping
Three concerns become visible only from the pipeline-level view, not from any individual section.
What carries state across iterations
A training run is a state machine. To checkpoint or resume it correctly, every stateful component must be captured.
| Component | Stateful across iterations? | Must be checkpointed? |
|---|---|---|
| Model parameters | yes | yes |
| Optimizer buffers (momentum, EMAs) | yes | yes (skipping this resets the optimizer’s history) |
| Scheduler state (step counter, current learning rate $\eta$) | yes | yes (skipping resets the schedule) |
| Gradient buffers .grad | per iteration only (zeroed each step) | no |
| Activation memory | per iteration only (released after backward) | no |
| Data-loader iterator state | per epoch | optional but useful for exact reproducibility |
| RNG state (PyTorch, NumPy, Python) | continuous | yes for bit-exact reproducibility |
A standard torch.save({...}) checkpoint includes the model state_dict, the optimizer state_dict, the scheduler state_dict, and the current epoch. Omitting the optimizer state from the checkpoint is one of the most common reasons a “resumed” training run silently produces worse results than continuous training: the optimizer’s momentum and second-moment estimates are re-initialized to zero.
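A save-and-resume round trip might look like the following sketch. The build helper, the dictionary keys, the temporary file path, and the choice of AdamW with StepLR are all assumptions for illustration; RNG and data-loader state are left out for brevity.

```python
import os
import tempfile

import torch
from torch import nn

def build():
    """Create a fresh model/optimizer/scheduler triple (hypothetical setup)."""
    model = nn.Linear(4, 2)
    optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1, gamma=0.5)
    return model, optimizer, scheduler

torch.manual_seed(0)
model, optimizer, scheduler = build()

# One training step and one scheduler step so every component carries state.
model(torch.randn(8, 4)).sum().backward()
optimizer.step()
scheduler.step()

path = os.path.join(tempfile.mkdtemp(), "checkpoint.pt")
torch.save(
    {
        "epoch": 1,
        "model": model.state_dict(),
        "optimizer": optimizer.state_dict(),
        "scheduler": scheduler.state_dict(),
    },
    path,
)

# Resume into freshly constructed objects.
new_model, new_optimizer, new_scheduler = build()
checkpoint = torch.load(path)
new_model.load_state_dict(checkpoint["model"])
new_optimizer.load_state_dict(checkpoint["optimizer"])
new_scheduler.load_state_dict(checkpoint["scheduler"])
start_epoch = checkpoint["epoch"] + 1

print(start_epoch, new_optimizer.param_groups[0]["lr"], new_scheduler.last_epoch)
```

The objects must be rebuilt with the same structure before loading, because a state_dict stores values, not the architecture or the optimizer class.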
Memory peaks and where they happen
The pipeline has a characteristic memory profile that determines the maximum trainable batch size:
| Phase | Memory consumed | Notes |
|---|---|---|
| Setup | parameters + optimizer state | constant for the whole run |
| Forward pass | + cached activations | grows linearly with depth |
| Loss computation | tiny | scalar plus a few intermediates |
| Backward pass | + gradient buffers | activations released as they are consumed; gradient buffers grow |
| Optimizer step | gradients consumed | brief surge then released |
The peak typically lands at the start of the backward pass, when all forward-pass activations are still cached and the first gradient tensors begin to be allocated. This is why gradient checkpointing (recompute activations during backward) and mixed precision (halve activation memory) are the standard interventions for fitting larger models or larger batches on the same hardware.
Common mistakes catalog
The five-mistake list
Most “my model trained but produces garbage” bugs are one of these:
- Forgot optimizer.zero_grad(): gradients accumulate; the optimizer uses the sum of all gradients computed so far, not the gradient of the current batch.
- Forgot model.eval() during validation: dropout still active, batchnorm uses batch statistics; validation metrics are noisy and systematically biased.
- Forgot torch.no_grad() during validation: results are correct but validation memory grows to match training memory, often causing OOM.
- Called scheduler.step() before optimizer.step(): the first iteration uses the learning rate of step $1$ instead of step $0$, and the schedule is off by one throughout training.
- Mixed per-iteration and per-epoch scheduler conventions: the schedule completes in the first epoch (per-iteration stepping with epoch-sized T_max) or essentially doesn’t move (per-epoch stepping with iteration-sized T_max).
None of these produces a runtime error. All of them silently degrade training results. The full training loop above, built around the ordered five-line skeleton, encodes the prevention of all five.
The pipeline is small, and its individual operations are simple. Most training-time difficulty in deep learning is not in the pipeline itself but in choosing the right architecture, the right loss, the right optimizer, the right learning rate schedule, the right regularization, and the right hyperparameters for the task at hand. The pipeline is the constant chassis; the components and their settings are the variables. Mastering the pipeline frees attention to focus on those variables, where the actual experimental decisions live.