Every deep-learning training run, regardless of architecture, dataset, or framework, reduces at the innermost level to a handful of operations performed in a specific order on each mini-batch. Understanding the pipeline as a whole, rather than as a collection of separate concerns, is what separates a practitioner who can debug and tune from one who can only follow a template.

This note synthesizes the entire training loop. Each section is brief because the mathematical and conceptual depth lives in the dedicated notes; the value added here is the integration, the order of operations, and the cross-cutting concerns (state, mode, memory) that no single component-level note can address.

The canonical training step

Every other concern in this note revolves around five operations executed per mini-batch, in this order:

optimizer.zero_grad()              # 1. reset accumulated gradients
predictions = model(inputs)        # 2. forward pass
loss = loss_fn(predictions, targets)  # 3. loss computation
loss.backward()                    # 4. backward pass (autograd)
optimizer.step()                   # 5. parameter update

Wrap this in a loop over mini-batches and you have one epoch; wrap that in a loop over epochs and you have an entire training run. Everything else (scheduler stepping, validation passes, checkpointing, logging) is scaffolding around this five-line atomic unit.
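Below is a minimal, runnable sketch of the five calls on synthetic data. The toy regression problem, the one-layer model, the learning rate, and the step count are illustrative assumptions. To keep the example short, the whole dataset serves as a single mini-batch.

```python
import torch
from torch import nn

torch.manual_seed(0)

# Hypothetical toy problem: recover y = 3x - 1 from slightly noisy samples.
inputs = torch.randn(256, 1)
targets = 3 * inputs - 1 + 0.01 * torch.randn(256, 1)

model = nn.Linear(1, 1)
loss_fn = nn.MSELoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

losses = []
for step in range(200):                      # the full dataset is one mini-batch here
    optimizer.zero_grad()                    # 1. reset accumulated gradients
    predictions = model(inputs)              # 2. forward pass
    loss = loss_fn(predictions, targets)     # 3. loss computation
    loss.backward()                          # 4. backward pass (autograd)
    optimizer.step()                         # 5. parameter update
    losses.append(loss.item())

print(f"first loss {losses[0]:.3f}, last loss {losses[-1]:.5f}")
print(model.weight.item(), model.bias.item())  # close to 3 and -1
```

Swapping in a real model, loss, and data loader changes none of the five lines.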

The order is not arbitrary

Each of the five calls depends on the state produced by the previous one. Reordering them silently produces incorrect training:

  • zero_grad() before backward: necessary, because loss.backward() accumulates into .grad (it does not overwrite). The common convention places it at the top of the step, before the forward pass. Without zeroing, every step uses the sum of all past gradients.
  • loss.backward() after forward and loss: requires the autograd graph built during the forward pass and the loss as its scalar root.
  • optimizer.step() after backward: consumes the gradients now sitting in each parameter’s .grad.
  • scheduler.step() (when present) after optimizer.step(): the scheduler’s effect lands on the next iteration’s learning rate, not the current one.

Most “training is broken” bugs in deep-learning code trace to a violation of one of these orderings.

1. Setup and initialization

Before the loop can start, several stateful objects must be created and placed on the right compute device:

  • the model (parameters with their initial values);
  • the loss function (typically stateless: cross-entropy, MSE, etc.);
  • the optimizer, bound to the model’s parameters;
  • the (optional) learning-rate scheduler, bound to the optimizer;
  • the data loaders for training and validation.

The initial values of the model parameters are not random in the colloquial sense: they are drawn from a carefully chosen distribution whose variance is calibrated to keep activations and gradients in a workable range across all layers. The wrong initialization can make the network untrainable before the first gradient step is computed.

Initialization is part of the architecture

The right initialization depends on the activation functions used. The two standard choices are derived in Xavier and He initialization:

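| Activation | Recommended init | $\mathrm{Var}(W)$ | Why |
|---|---|---|---|
| tanh / sigmoid / linear | Xavier (Glorot) | $\frac{2}{n_{\text{in}} + n_{\text{out}}}$ | symmetric, near-linear around the origin |
| ReLU / Leaky ReLU / PReLU | He (Kaiming) | $\frac{2}{n_{\text{in}}}$ | compensates the variance halving from ReLU zeroing negative inputs |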

Specialized cells have specialized initializations: an LSTM forget-gate bias should be set to about $1$ rather than $0$ so the cell state survives the first iterations; an Adam-trained Transformer benefits from a short learning-rate warmup that compensates for the noisy early estimates of Adam’s moments.
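These initializations can be applied explicitly with torch.nn.init. The layer sizes in this sketch are arbitrary. To locate the forget-gate slice, the code relies on PyTorch’s LSTM gate layout (input, forget, cell, output). The empirical standard deviations it prints should land close to the theoretical values from the table.

```python
import math
import torch
from torch import nn

torch.manual_seed(0)

# Xavier/Glorot for a tanh-style layer: Var(W) = 2 / (fan_in + fan_out).
tanh_layer = nn.Linear(1000, 500)        # weight shape (500, 1000): fan_in=1000, fan_out=500
nn.init.xavier_normal_(tanh_layer.weight)
nn.init.zeros_(tanh_layer.bias)

# He/Kaiming for a ReLU layer: Var(W) = 2 / fan_in.
relu_layer = nn.Linear(1000, 500)
nn.init.kaiming_normal_(relu_layer.weight, mode="fan_in", nonlinearity="relu")
nn.init.zeros_(relu_layer.bias)

# LSTM forget-gate bias: PyTorch stacks the gates as (input, forget, cell, output),
# so the forget gate occupies rows [H, 2H) of each bias vector.
H = 32
lstm = nn.LSTM(input_size=16, hidden_size=H)
with torch.no_grad():
    lstm.bias_ih_l0[H:2 * H].fill_(1.0)
    lstm.bias_hh_l0[H:2 * H].fill_(0.0)  # the two bias vectors are summed inside the cell

print(tanh_layer.weight.std().item(), math.sqrt(2 / 1500))
print(relu_layer.weight.std().item(), math.sqrt(2 / 1000))
```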

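import torch
from torch import nn
from torch.utils.data import DataLoader

# MyModel, train_dataset, val_dataset and num_epochs are defined elsewhere.
device = "cuda" if torch.cuda.is_available() else "cpu"
model = MyModel().to(device)
loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3, weight_decay=1e-2)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=num_epochs)
train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)
val_loader   = DataLoader(val_dataset,   batch_size=64, shuffle=False)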

2. Feed-forward pass

The forward pass evaluates the model on one mini-batch, producing predictions:

predictions = model(inputs)

Three things happen simultaneously inside that single line.

  1. The numerical predictions are computed layer by layer through the architecture.
  2. The intermediate activations of every layer are cached in memory, because backpropagation needs them.
  3. The framework records the autograd computation graph: a dynamic DAG that tracks every operation applied to a tensor with requires_grad=True.

The third point is the conceptual lever PyTorch provides. The graph is not a static description of the model architecture; it is built on the fly by the actual operations executed during the forward pass. This is why control flow (conditionals, loops over time steps in an RNN, dynamic shapes) is natural in PyTorch: every iteration of training rebuilds the graph fresh from whatever code ran.

Activation memory is the dominant cost

For a network with $L$ layers and per-layer activation size $A$ (for the whole mini-batch), the forward pass holds roughly $L \cdot A$ floats in memory until the backward pass releases them. For large models this can exceed the parameter memory by an order of magnitude, and it is typically the dominant constraint on usable batch size. Three techniques reduce it. Gradient checkpointing recomputes activations during the backward pass rather than storing them, trading compute for memory. Mixed precision stores activations in float16 or bfloat16, halving the cost relative to float32. Activation offloading moves them to CPU memory between the forward and backward passes.
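Gradient checkpointing can be sketched with torch.utils.checkpoint.checkpoint. The eight-block stack is a hypothetical example. Checkpointing changes what is stored, not what is computed, so the gradients should match those of the plain forward pass.

```python
import torch
from torch import nn
from torch.utils.checkpoint import checkpoint

torch.manual_seed(0)

# A hypothetical stack of blocks whose internal activations we choose not to keep.
blocks = nn.ModuleList(
    [nn.Sequential(nn.Linear(64, 64), nn.Tanh()) for _ in range(8)]
)
x = torch.randn(32, 64)

def forward(use_checkpoint):
    h = x
    for block in blocks:
        if use_checkpoint:
            # Only the block input is saved; the block is re-run during backward.
            h = checkpoint(block, h, use_reentrant=False)
        else:
            h = block(h)
    return h.pow(2).mean()

def grads(use_checkpoint):
    blocks.zero_grad()
    forward(use_checkpoint).backward()
    return [p.grad.clone() for p in blocks.parameters()]

plain, ckpt = grads(False), grads(True)
print(all(torch.allclose(a, b) for a, b in zip(plain, ckpt)))  # same gradients either way
```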

The architectures that fill in the model(inputs) call are covered in their dedicated sections: MLPs, LSTMs, GRUs, CNNs, and so on. From the pipeline’s perspective, all of them are interchangeable black boxes that produce predictions.

3. Loss function and monitoring

The loss reduces the model’s predictions and the ground-truth targets to a single scalar quantifying the prediction error on the current mini-batch:

loss = loss_fn(predictions, targets)

The choice of loss is dictated by the task and by the form of the network’s output layer. The canonical pairings are:

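| Task | Output layer | Loss | Why this pairing |
|---|---|---|---|
| Binary classification | sigmoid (1 unit) | BCELoss / BCEWithLogitsLoss | cancels the saturating sigmoid slope in the gradient |
| Multi-class classification | softmax ($K$ units) | CrossEntropyLoss | analogous cancellation; standard in classification |
| Regression | linear (1 unit) | MSELoss / L1Loss | direct prediction-vs-target comparison |
| Multi-label classification | sigmoid ($K$ independent units) | BCEWithLogitsLoss | each label is an independent binary problem |
| Sequence modelling | softmax per position | CrossEntropyLoss per token, summed or averaged | teacher forcing during training |

In PyTorch, CrossEntropyLoss and BCEWithLogitsLoss expect raw logits and apply the softmax or sigmoid internally, which is also more numerically stable. BCELoss, by contrast, expects probabilities that have already passed through a sigmoid.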

The pairing of output activation and loss is not cosmetic: certain combinations cancel a saturating term in the gradient that would otherwise produce the output-layer learning slowdown. The cross-entropy + softmax canonical pair is the textbook example.
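The cancellation can be checked numerically. For softmax plus cross-entropy, the gradient of the summed loss with respect to the logits is $p - y$ (predicted probabilities minus the one-hot target), with no softmax-derivative factor left to saturate. The batch size, class count, and targets below are arbitrary.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
logits = torch.randn(4, 5, requires_grad=True)   # batch of 4, K = 5 classes
targets = torch.tensor([0, 3, 1, 4])

# F.cross_entropy applies log-softmax internally, so it takes raw logits.
loss = F.cross_entropy(logits, targets, reduction="sum")
loss.backward()

probs = F.softmax(logits.detach(), dim=1)
one_hot = F.one_hot(targets, num_classes=5).float()
print(torch.allclose(logits.grad, probs - one_hot, atol=1e-6))  # gradient equals p - y
```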

What to monitor, beyond the loss

The loss on the current mini-batch is noisy and not directly informative about generalization. A standard monitoring discipline maintains four series of values:

| Series | Where computed | What it tells you |
|---|---|---|
| Training loss per step | every mini-batch | optimizer stability (noisy but should trend down) |
| Training loss per epoch | average over the epoch | smoothed training progress |
| Validation loss per epoch | on the held-out validation set | generalization; primary signal for overfitting |
| Task metric per epoch (accuracy, F1, BLEU, …) | on training and validation | the actual quantity of interest, which loss only approximates |

The gap between training loss and validation loss is the canonical overfitting signal; the gap between loss and task metric is a reminder that loss is a differentiable proxy, not the goal itself.
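When the last batch is smaller than the others, a per-epoch series should average over examples rather than over batches. The helper below is a hypothetical sketch that computes the example-weighted mean loss and accuracy for a classification loader. The test data are random logits fed through nn.Identity, which makes the expected values easy to check.

```python
import torch
from torch import nn
from torch.utils.data import DataLoader, TensorDataset

def evaluate(model, loader, loss_fn):
    """Hypothetical helper: example-weighted mean loss and accuracy over a loader."""
    model.eval()
    total_loss, correct, count = 0.0, 0, 0
    with torch.no_grad():
        for inputs, targets in loader:
            logits = model(inputs)
            batch = targets.size(0)
            # loss_fn averages over the batch; re-weight so a short last batch counts correctly.
            total_loss += loss_fn(logits, targets).item() * batch
            correct += (logits.argmax(dim=1) == targets).sum().item()
            count += batch
    return total_loss / count, correct / count

torch.manual_seed(0)
logits = torch.randn(10, 3)
labels = torch.randint(0, 3, (10,))
loader = DataLoader(TensorDataset(logits, labels), batch_size=4)  # batches of 4, 4, 2
mean_loss, accuracy = evaluate(nn.Identity(), loader, nn.CrossEntropyLoss())
print(mean_loss, accuracy)
```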

4. Backpropagation and gradients

A single line triggers the most computationally demanding part of training:

loss.backward()

This call executes the backpropagation algorithm (for feedforward networks) or backpropagation through time (for recurrent networks) on the computation graph built during the forward pass. The chain rule is applied automatically, layer by layer, from the loss back to every parameter that contributed to it. The result is stored in each parameter’s .grad attribute, ready to be consumed by the optimizer.

The mathematical content is treated in MLP backpropagation for feedforward networks and in Backpropagation through time for recurrent ones. From the pipeline’s perspective, loss.backward() is a single expensive call. Its runtime is typically one to two times that of the forward pass: for dense layers it computes gradients with respect to both the weights and the layer inputs, roughly twice the forward work. It also temporarily raises memory use, because gradient tensors are allocated while the cached activations are still held.

Understanding what .backward() actually does

Calling loss.backward() looks like magic the first time, because so much is hidden inside one method invocation: graph traversal, chain-rule application, gradient accumulation. The cleanest way to demystify it is to build the autograd machinery from scratch: see Micrograd, a walkthrough of Andrej Karpathy’s miniature autodiff engine. Working through Micrograd makes several things concrete: which data structure the call traverses (the computation graph), in which order the gradients are computed (reverse topological order), how they accumulate at each node (additively, which is why zero_grad() is needed), and where exactly in the training loop the call belongs. After Micrograd, .backward() stops being magic and starts being mechanical.
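In the spirit of Micrograd (but not its actual code), a scalar autograd engine fits in a few dozen lines. This simplified sketch supports only + and *. It builds a topological order, walks it in reverse, and accumulates gradients with +=, which is exactly why a node used twice receives two contributions.

```python
class Value:
    """A scalar that records the operations that produced it (simplified sketch)."""

    def __init__(self, data, parents=(), backward_fn=lambda: None):
        self.data = data
        self.grad = 0.0
        self._parents = parents
        self._backward = backward_fn

    def __add__(self, other):
        out = Value(self.data + other.data, (self, other))
        def backward_fn():
            self.grad += out.grad            # d(a + b)/da = 1, accumulated with +=
            other.grad += out.grad
        out._backward = backward_fn
        return out

    def __mul__(self, other):
        out = Value(self.data * other.data, (self, other))
        def backward_fn():
            self.grad += other.data * out.grad
            other.grad += self.data * out.grad
        out._backward = backward_fn
        return out

    def backward(self):
        # Topologically sort the graph, then apply each local chain rule in reverse order.
        order, seen = [], set()
        def visit(node):
            if id(node) not in seen:
                seen.add(id(node))
                for parent in node._parents:
                    visit(parent)
                order.append(node)
        visit(self)
        self.grad = 1.0
        for node in reversed(order):
            node._backward()

a, b = Value(2.0), Value(-3.0)
f = a * b + a           # a is used twice, so its gradient receives two contributions
f.backward()
print(a.grad, b.grad)   # df/da = b + 1 = -2.0, df/db = a = 2.0
```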

loss.backward() accumulates, it does not assign

The most common source of silent training bugs: loss.backward() adds the new gradient to whatever is already in .grad. If optimizer.zero_grad() is not called before the backward pass, gradients from previous iterations remain summed in, and the optimizer effectively uses the sum of all gradients computed so far. This is sometimes done intentionally (gradient accumulation, to simulate a larger effective batch when memory is limited), but the intent must be explicit. Always either zero the gradients or document the accumulation pattern.
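Intentional accumulation can be sketched as follows. Dividing each micro-batch loss by the number of micro-batches and skipping zero_grad() between them reproduces the full-batch gradient. Splitting a batch of 8 into 4 micro-batches is an arbitrary illustrative choice. In a real loop, optimizer.step() and then zero_grad() would follow the last micro-batch.

```python
import torch
from torch import nn

torch.manual_seed(0)
model = nn.Linear(4, 1)
loss_fn = nn.MSELoss()
inputs, targets = torch.randn(8, 4), torch.randn(8, 1)

# Reference: one backward pass on the full batch of 8.
model.zero_grad()
loss_fn(model(inputs), targets).backward()
full_grad = model.weight.grad.clone()

# Accumulation: 4 micro-batches of 2, each loss scaled by 1/4,
# with no zero_grad() in between, so .grad sums the contributions.
accum_steps = 4
model.zero_grad()
for x, y in zip(inputs.chunk(accum_steps), targets.chunk(accum_steps)):
    (loss_fn(model(x), y) / accum_steps).backward()
accum_grad = model.weight.grad.clone()

print(torch.allclose(full_grad, accum_grad, atol=1e-6))  # same gradient
```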

Gradient checkpointing, clipping, and other interventions

Several common interventions act on the backward pass or sit between it and the optimizer step:

  • Gradient clipping: limits the norm of the gradient to prevent exploding-gradient instabilities, particularly important in recurrent networks. Inserted as torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm) after loss.backward() and before optimizer.step().
  • Gradient checkpointing: set up during the forward pass, it trades compute for memory by recomputing activations during the backward pass rather than storing them.
  • Mixed-precision scaling: under torch.cuda.amp, a GradScaler multiplies the loss by a scale factor before backward and unscales the gradients before the optimizer step, so that small float16 gradients do not underflow to zero. When clipping is combined with scaling, the gradients must be unscaled (scaler.unscale_(optimizer)) before clip_grad_norm_ is applied.
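Here is a sketch of clipping in place, between loss.backward() and optimizer.step(). The threshold max_norm=0.1 and the inflated loss are assumptions, chosen so that clipping visibly triggers. Note that clip_grad_norm_ returns the total norm measured before clipping.

```python
import torch
from torch import nn

torch.manual_seed(0)
model = nn.Linear(10, 1)
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
inputs, targets = torch.randn(16, 10), torch.randn(16, 1)
max_norm = 0.1   # assumed threshold, small enough that clipping actually happens

optimizer.zero_grad()
loss = nn.functional.mse_loss(model(inputs), targets) * 100   # inflated to force large gradients
loss.backward()
# Rescales every .grad in place so the global L2 norm is at most max_norm.
norm_before = nn.utils.clip_grad_norm_(model.parameters(), max_norm)
norm_after = torch.norm(torch.stack([p.grad.norm() for p in model.parameters()]))
optimizer.step()

print(norm_before.item(), norm_after.item())   # norm_after is at most max_norm
```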

5. Parameter update and optimizers

The optimizer reads the gradients from .grad and updates each parameter according to its rule:

optimizer.step()

The choice of optimizer determines the update rule, the per-parameter memory footprint, and the convergence behaviour. The major families:

| Optimizer | Update mechanism | Extra memory (per-parameter buffers) | Strongest in |
|---|---|---|---|
| SGD | $\theta \leftarrow \theta - \eta\, g$ | none | small models, simple problems |
| SGD + Momentum | adds inertial velocity buffer | 1 | CNNs, late-stage training, fine-tuning |
| NAG | look-ahead gradient | 1 (same as Momentum) | slightly better trajectory than Momentum |
| AdaGrad | per-parameter scaling by accumulated $g^2$ | 1 | sparse-feature problems |
| RMSProp | per-parameter EMA of $g^2$ | 1 | RNNs, varying gradient scales |
| Adam | EMAs of $g$ and $g^2$, bias-corrected | 2 | random init, Transformers, modern default |
| AdamW | Adam + decoupled weight decay | 2 | universal modern default with weight decay |

The selection criteria are developed in Choosing an optimizer; the mathematical derivations of each algorithm are in their respective notes. At the pipeline level, optimizer.step() is a single call whose compute cost is small (proportional to parameter count), but whose memory cost scales with the optimizer family. At $N$ parameters in float32, Adam consumes an extra $2 \times 4N = 8N$ bytes just for its momentum and second-moment buffers, i.e., 8 GB for a one-billion-parameter model.
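This overhead can be observed directly by counting the tensors each optimizer keeps in optimizer.state after one step. This is a hedged sketch: the helper counts only tensors shaped like their parameter, which skips scalar bookkeeping such as Adam’s step counter, and the layer size is arbitrary. The expected counts are 0 for plain SGD, one parameter-sized buffer for momentum, and two for Adam.

```python
import torch
from torch import nn

def optimizer_state_floats(optimizer):
    """Count floats stored in per-parameter buffers that have the parameter's shape."""
    total = 0
    for param, state in optimizer.state.items():
        for value in state.values():
            # Skips scalar bookkeeping such as Adam's 'step' counter and None placeholders.
            if isinstance(value, torch.Tensor) and value.shape == param.shape:
                total += value.numel()
    return total

def params_and_state_after_one_step(make_optimizer):
    torch.manual_seed(0)
    model = nn.Linear(100, 10)            # 100*10 weights + 10 biases = 1010 parameters
    optimizer = make_optimizer(model.parameters())
    model(torch.randn(4, 100)).pow(2).mean().backward()
    optimizer.step()                      # buffers are created lazily on the first step
    n_params = sum(p.numel() for p in model.parameters())
    return n_params, optimizer_state_floats(optimizer)

for name, make in [
    ("SGD", lambda ps: torch.optim.SGD(ps, lr=0.1)),
    ("SGD + Momentum", lambda ps: torch.optim.SGD(ps, lr=0.1, momentum=0.9)),
    ("Adam", lambda ps: torch.optim.Adam(ps, lr=1e-3)),
]:
    n_params, extra = params_and_state_after_one_step(make)
    print(f"{name}: {extra} extra floats for {n_params} parameters")
```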

After optimizer.step(), if a learning-rate scheduler is in use, it advances:

scheduler.step()

The scheduler’s role is to modulate the global learning rate over the training run, complementary to the per-parameter adaptation built into the optimizer. The mechanics and trade-offs of scheduling are covered in Learning rate scheduling and its dedicated family notes.

6. Mode: train vs eval

PyTorch models carry a mode flag that determines the behaviour of layers whose dynamics depend on whether training is in progress:

model.train()    # set the model to training mode
model.eval()     # set the model to evaluation mode

The flag affects two layer families specifically. Everything else (linear layers, convolutions, activations, attention) behaves identically in both modes.

| Layer | model.train() behaviour | model.eval() behaviour |
|---|---|---|
| Dropout | active: zeroes elements with probability $p$, scales survivors by $1/(1-p)$ | inactive: passes activations through unchanged |
| Batch normalization | uses batch mean/variance for normalization; updates running stats | uses running mean/variance accumulated during training; does not update them |
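Both behaviours can be observed on standalone layers. This sketch uses arbitrary sizes and $p = 0.5$. In train mode, dropout zeroes elements and doubles the survivors; in eval mode it is the identity. Batch normalization updates its running mean only in train mode, using PyTorch’s default momentum of 0.1.

```python
import torch
from torch import nn

torch.manual_seed(0)

# Dropout: stochastic masking in train mode, identity in eval mode.
dropout = nn.Dropout(p=0.5)
x = torch.ones(1000)
dropout.train()
y_train = dropout(x)          # each element is 0.0 or 1 / (1 - p) = 2.0
dropout.eval()
y_eval = dropout(x)           # input passes through unchanged

# Batch normalization: running statistics move only in train mode.
bn = nn.BatchNorm1d(3)        # running_mean starts at zeros; momentum defaults to 0.1
batch = torch.randn(8, 3) + 5.0
bn.eval()
bn(batch)
mean_after_eval = bn.running_mean.clone()    # still zeros
bn.train()
bn(batch)
mean_after_train = bn.running_mean.clone()   # 0.9 * 0 + 0.1 * batch mean

print(sorted(y_train.unique().tolist()), torch.equal(y_eval, x))
print(mean_after_eval, mean_after_train)
```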

The single most common deep-learning bug

Forgetting to call model.eval() before validation or inference is the most common silent bug in PyTorch code. Symptoms:

  • validation accuracy fluctuates randomly between runs even on the same data (dropout is still firing);
  • validation loss is systematically higher than it should be;
  • inference results are non-deterministic for inputs that should be deterministic;
  • in production, predictions vary mysteriously between identical API calls.

The fix is one line, but the failure mode is invisible: the code runs, no exception is raised, and the model’s outputs are wrong only in a statistical sense.

model.eval() is separate from disabling gradient computation. During validation, both should typically be active:

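model.eval()
total_loss, n_batches = 0.0, 0
with torch.no_grad():
    for inputs, targets in val_loader:
        inputs, targets = inputs.to(device), targets.to(device)
        predictions = model(inputs)
        total_loss += loss_fn(predictions, targets).item()
        n_batches += 1
val_loss = total_loss / n_batches    # mean of the per-batch losses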

The model.eval() call switches dropout and batchnorm into evaluation behaviour; the torch.no_grad() context disables autograd graph construction, which saves substantial memory because activations no longer need to be cached. Forgetting torch.no_grad() during validation does not produce wrong results, but it can cause out-of-memory crashes on large models because validation memory grows to match training memory.
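The difference between the two switches shows up in the output tensor. Outside no_grad(), the output carries a grad_fn node linking it to the graph; inside, it does not. The layer and input sizes below are arbitrary. PyTorch also offers torch.inference_mode() as a stricter alternative for pure inference.

```python
import torch
from torch import nn

model = nn.Linear(4, 2)
x = torch.randn(3, 4)

out_train = model(x)              # graph recorded: activations kept for backward
with torch.no_grad():
    out_val = model(x)            # no graph, nothing cached

print(out_train.grad_fn is not None, out_train.requires_grad)  # True True
print(out_val.grad_fn is None, out_val.requires_grad)          # True False
```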

7. Iterations and epochs

Two units of time structure the training loop, and confusing them is a frequent source of misconfiguration. The vocabulary and its library-dependent subtleties are treated in depth in Neural Networks vocabulary §5 (“Iteration, step, update, epoch”), which also catalogs the API-level traps (for example, scikit-learn’s max_iter often counts epochs, while in PyTorch and most deep-learning theory an “iteration” is one optimizer update). The compact version used in the rest of this note is the following.

| Unit | Definition | Granularity |
|---|---|---|
| Iteration (also: step, batch) | one execution of the five-line training step on one mini-batch | finest |
| Epoch | one complete pass through the training dataset | coarsest |

If the training set contains $N$ examples and the mini-batch size is $B$, one epoch consists of $\lceil N/B \rceil$ iterations, or $\lfloor N/B \rfloor$ if the last incomplete batch is dropped. A training run is typically specified in epochs (e.g., “train for $E$ epochs”), but framework internals operate per iteration.
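PyTorch’s DataLoader reports the number of iterations per epoch through len(). With the hypothetical sizes $N = 1000$ and $B = 64$, the last batch holds 40 examples, and drop_last=True discards it.

```python
import math
import torch
from torch.utils.data import DataLoader, TensorDataset

N, B = 1000, 64   # hypothetical dataset size and batch size
dataset = TensorDataset(torch.arange(N, dtype=torch.float32))

keep_last = DataLoader(dataset, batch_size=B)                  # last batch: 1000 - 15*64 = 40 items
drop_last = DataLoader(dataset, batch_size=B, drop_last=True)  # incomplete last batch discarded

print(len(keep_last), math.ceil(N / B))   # 16 16
print(len(drop_last), N // B)             # 15 15
```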

Scheduler frequency: the silent-misconfiguration trap

The learning-rate scheduler advances each time scheduler.step() is called. Two equally valid conventions coexist:

  • Per-epoch stepping: scheduler.step() is called once after the inner loop over batches. The scheduler’s hyperparameters (T_max, step_size, T_0, etc.) are then expressed in epochs.
  • Per-iteration stepping: scheduler.step() is called inside the inner loop, after each optimizer.step(). The scheduler’s hyperparameters are then expressed in iterations.

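Both are mathematically correct. The bug is mixing them: stepping per iteration while choosing T_max = num_epochs makes the schedule run its full course within the first num_epochs iterations, usually well inside the first epoch. What follows depends on the scheduler. A schedule that clamps at its endpoint stays at $\eta_{\min}$ for the rest of training. PyTorch’s CosineAnnealingLR, by contrast, keeps following the cosine past T_max, so the learning rate oscillates between $\eta_{\min}$ and its initial value. The detailed implications are derived in Learning rate scheduling §3.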
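The mismatch is easy to reproduce with a dummy parameter. This sketch assumes 5 epochs of 10 iterations and CosineAnnealingLR with T_max=num_epochs. Stepping per epoch decays the learning rate smoothly across the run. Stepping per iteration drives it to zero within the first epoch, after which it climbs back up.

```python
import torch

def lr_trace(step_per_iteration, num_epochs=5, iters_per_epoch=10, base_lr=0.1):
    """Learning rate in effect at every iteration for CosineAnnealingLR(T_max=num_epochs)."""
    param = torch.nn.Parameter(torch.zeros(1))
    param.grad = torch.zeros(1)                      # zero gradient: the update is a no-op
    optimizer = torch.optim.SGD([param], lr=base_lr)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=num_epochs)
    used = []
    for epoch in range(num_epochs):
        for _ in range(iters_per_epoch):
            used.append(optimizer.param_groups[0]["lr"])
            optimizer.step()
            if step_per_iteration:
                scheduler.step()                     # wrong unit for an epoch-sized T_max
        if not step_per_iteration:
            scheduler.step()                         # matches T_max expressed in epochs
    return used

per_epoch = lr_trace(step_per_iteration=False)
per_iter = lr_trace(step_per_iteration=True)
print([round(lr, 4) for lr in per_epoch[::10]])  # one value per epoch, decaying smoothly
print([round(lr, 4) for lr in per_iter[:10]])    # reaches 0 after 5 iterations, then rises
```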

The full training loop, putting all seven sections together:

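best_val_loss, patience, bad_epochs = float("inf"), 5, 0   # patience=5 is an illustrative choice

for epoch in range(num_epochs):
    model.train()
    running_loss = 0.0
    for inputs, targets in train_loader:
        inputs, targets = inputs.to(device), targets.to(device)
        optimizer.zero_grad()
        predictions = model(inputs)
        loss = loss_fn(predictions, targets)
        loss.backward()
        optimizer.step()
        running_loss += loss.item()
    scheduler.step()                      # per-epoch convention
    train_loss = running_loss / len(train_loader)

    model.eval()
    with torch.no_grad():
        val_loss = sum(loss_fn(model(x.to(device)), y.to(device)).item()
                       for x, y in val_loader) / len(val_loader)

    print(f"epoch {epoch}: train_loss={train_loss:.4f} val_loss={val_loss:.4f}")
    if val_loss < best_val_loss:          # early stopping with a patience counter
        best_val_loss, bad_epochs = val_loss, 0
    else:
        bad_epochs += 1
        if bad_epochs >= patience:
            break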

The holistic view: state, mode, and stopping

Three concerns become visible only from the pipeline-level view, not from any individual section.

What carries state across iterations

A training run is a state machine. To checkpoint or resume it correctly, every stateful component must be captured.

| Component | Stateful across iterations? | Must be checkpointed? |
|---|---|---|
| Model parameters $\theta$ | yes | yes |
| Optimizer buffers (momentum, EMAs) | yes | yes (skipping this resets the optimizer’s history) |
| Scheduler state (step counter, current $\eta$) | yes | yes (skipping resets the schedule) |
| Gradient buffers .grad | per iteration only (zeroed each step) | no |
| Activation memory | per iteration only (released after backward) | no |
| Data-loader iterator state | per epoch | optional but useful for exact reproducibility |
| RNG state (PyTorch, NumPy, Python) | continuous | yes for bit-exact reproducibility |

A standard torch.save({...}) checkpoint includes the model state_dict, the optimizer state_dict, the scheduler state_dict, and the current epoch. Omitting the optimizer state from the checkpoint is one of the most common reasons a “resumed” training run silently produces worse results than continuous training: the optimizer’s momentum and second-moment estimates are silently re-initialized to zero.
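A checkpoint round trip might look like the sketch below. The tiny model, the three training steps, and the dictionary keys are illustrative assumptions, and an in-memory io.BytesIO buffer stands in for a file such as checkpoint.pt. Only PyTorch’s own RNG state is saved here; full bit-exact reproducibility would also require the NumPy and Python RNG states.

```python
import io
import torch
from torch import nn

def make_run():
    """Hypothetical factory: a fresh model, optimizer and scheduler."""
    model = nn.Linear(3, 1)
    optimizer = torch.optim.AdamW(model.parameters(), lr=1e-2)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=10)
    return model, optimizer, scheduler

torch.manual_seed(0)
model, optimizer, scheduler = make_run()
for epoch in range(3):                       # stand-in for three epochs of training
    x, y = torch.randn(8, 3), torch.randn(8, 1)
    optimizer.zero_grad()
    nn.functional.mse_loss(model(x), y).backward()
    optimizer.step()
    scheduler.step()

buffer = io.BytesIO()                        # stands in for a path such as "checkpoint.pt"
torch.save({
    "epoch": epoch,
    "model": model.state_dict(),
    "optimizer": optimizer.state_dict(),     # AdamW's exp_avg / exp_avg_sq live here
    "scheduler": scheduler.state_dict(),
    "torch_rng": torch.get_rng_state(),
}, buffer)

# Resume into freshly constructed objects, as a new process would.
buffer.seek(0)
checkpoint = torch.load(buffer)
model2, optimizer2, scheduler2 = make_run()
model2.load_state_dict(checkpoint["model"])
optimizer2.load_state_dict(checkpoint["optimizer"])
scheduler2.load_state_dict(checkpoint["scheduler"])
torch.set_rng_state(checkpoint["torch_rng"])
start_epoch = checkpoint["epoch"] + 1        # continue the epoch loop from here
```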

Memory peaks and where they happen

The pipeline has a characteristic memory profile that determines the maximum trainable batch size:

| Phase | Memory consumed | Notes |
|---|---|---|
| Setup | parameters + optimizer state | constant for the whole run |
| Forward pass | setup memory + cached activations | grows linearly with depth |
| Loss computation | tiny | scalar plus a few intermediates |
| Backward pass | forward-pass memory + gradient buffers | activations released as they are consumed; gradient buffers grow |
| Optimizer step | gradients consumed | brief surge then released |

The peak typically lands at the start of the backward pass, when all forward-pass activations are still cached and the first gradient tensors begin to be allocated. This is why gradient checkpointing (recompute activations during backward) and mixed precision (halve activation memory) are the standard interventions for fitting larger models or larger batches on the same hardware.

Common mistakes catalog

The five-mistake list

Most “my model trained but produces garbage” bugs are one of these:

  1. Forgot optimizer.zero_grad(): gradients accumulate; the optimizer uses the sum of all gradients computed so far, not the gradient of the current batch.
  2. Forgot model.eval() during validation: dropout is still active and batchnorm uses batch statistics, so validation metrics are noisy and typically worse than the model’s true performance.
  3. Forgot torch.no_grad() during validation: results are correct, but the autograd graph and its cached activations push memory toward training-time levels, often causing OOM.
  4. Called scheduler.step() before optimizer.step(): the first iteration uses the learning rate of step $1$ instead of step $0$, and the schedule is off by one throughout training (recent PyTorch versions emit a UserWarning for this order).
  5. Mixed per-iteration and per-epoch scheduler conventions: the schedule completes in the first epoch (per-iteration stepping with epoch-sized T_max) or essentially doesn’t move (per-epoch stepping with iteration-sized T_max).

None of these produce a runtime error. All of them silently degrade training results. The full loop in section 7, built around the ordered five-line skeleton, guards against all five.
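The off-by-one from mistake 4 can be reproduced with a dummy parameter and StepLR(step_size=1, gamma=0.5); both are assumptions, chosen so the learning rates are easy to read. The sketch records the learning rate in effect at each optimizer.step().

```python
import warnings
import torch

def lrs_used(scheduler_first, steps=3, base_lr=0.1):
    """Learning rate in effect at each optimizer.step() under StepLR(step_size=1, gamma=0.5)."""
    param = torch.nn.Parameter(torch.zeros(1))
    param.grad = torch.zeros(1)                # zero gradient: the update is a no-op
    optimizer = torch.optim.SGD([param], lr=base_lr)
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1, gamma=0.5)
    used = []
    with warnings.catch_warnings():
        warnings.simplefilter("ignore")        # silence PyTorch's wrong-order warning for the demo
        for _ in range(steps):
            if scheduler_first:
                scheduler.step()               # mistake 4: schedule advanced too early
            used.append(optimizer.param_groups[0]["lr"])
            optimizer.step()
            if not scheduler_first:
                scheduler.step()               # correct order
    return used

print(lrs_used(scheduler_first=False))  # [0.1, 0.05, 0.025]
print(lrs_used(scheduler_first=True))   # [0.05, 0.025, 0.0125]
```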

The pipeline is small, and its individual operations are simple. Most training-time difficulty in deep learning is not in the pipeline itself but in choosing the right architecture, the right loss, the right optimizer, the right learning rate schedule, the right regularization, and the right hyperparameters for the task at hand. The pipeline is the constant chassis; the components and their settings are the variables. Mastering the pipeline frees attention to focus on those variables, where the actual experimental decisions live.