latentbrief

Concept

Gradient Checkpointing

A memory-efficiency technique that trades extra computation for reduced memory usage during training by discarding intermediate activations and recomputing them during the backward pass rather than storing them throughout.

Added May 18, 2026

During training, the forward pass through a neural network generates intermediate activations at each layer - values that must be stored because they are needed during the backward pass to compute gradients. For large models trained on long sequences, these stored activations can consume more GPU memory than the model parameters themselves.

Consider a transformer trained on 4096-token sequences: in a standard attention implementation, the attention matrices at each layer grow quadratically with sequence length, and a model with 80 layers stores 80 sets of them for every sequence in the batch. The memory required can reach hundreds of gigabytes, more than a single GPU can hold.
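As a rough illustration of that arithmetic, the sketch below counts the bytes needed to keep one full attention score matrix per head and per layer. The batch size, head count, and 2-byte (fp16/bf16) storage are assumptions chosen for illustration, not figures from any particular model, and the count applies only to implementations that materialise the full score matrix; fused attention kernels avoid storing it.

```python
def attention_score_bytes(batch, heads, seq_len, layers, bytes_per_value=2):
    """Bytes to keep one seq_len x seq_len score matrix per head, per layer."""
    per_layer = batch * heads * seq_len * seq_len * bytes_per_value
    return per_layer * layers


GIB = 1024 ** 3

# Hypothetical configuration: only seq_len and layers come from the text above.
total = attention_score_bytes(batch=2, heads=64, seq_len=4096, layers=80)
print(f"{total / GIB:.0f} GiB")  # prints "320 GiB"
```

Doubling the sequence length in this model quadruples the total, which is why long-context training runs into activation memory limits so quickly.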

Gradient checkpointing (also called activation checkpointing or rematerialisation) addresses this by selectively discarding intermediate activations during the forward pass and recomputing them during the backward pass. Only a subset of activations - the "checkpoints" - are retained. When the backward pass reaches a segment between checkpoints, it recomputes the forward pass for that segment from the nearest checkpoint, regenerating the needed activations on demand.
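The mechanism can be sketched without any framework. The toy below is an assumption-laden illustration, not how any library implements it: each "layer" is the scalar function tanh(w * x), the helper names are hypothetical, and memory is measured simply as the number of layer inputs held at once. It compares a baseline that stores every layer input with a version that stores only every `segment`-th input and recomputes the rest during the backward pass.

```python
import math


def layer_forward(w, x):
    return math.tanh(w * x)


def layer_backward(w, x_in, grad_out):
    # d/dx tanh(w * x) = w * (1 - tanh(w * x)^2)
    y = math.tanh(w * x_in)
    return grad_out * w * (1.0 - y * y)


def grad_full_storage(weights, x0):
    """Baseline: keep every layer input, then backpropagate."""
    inputs = []
    x = x0
    for w in weights:
        inputs.append(x)
        x = layer_forward(w, x)
    grad = 1.0
    for w, x_in in zip(reversed(weights), reversed(inputs)):
        grad = layer_backward(w, x_in, grad)
    return grad, len(inputs)


def grad_checkpointed(weights, x0, segment=4):
    """Keep only the input of every `segment`-th layer; recompute the rest."""
    checkpoints = {}
    x = x0
    for i, w in enumerate(weights):
        if i % segment == 0:
            checkpoints[i] = x  # the only values kept from the forward pass
        x = layer_forward(w, x)

    grad = 1.0
    peak_live = 0
    for start in sorted(checkpoints, reverse=True):
        end = min(start + segment, len(weights))
        # Re-run the forward pass for this segment from its checkpoint.
        seg_inputs = []
        x = checkpoints[start]
        for w in weights[start:end]:
            seg_inputs.append(x)
            x = layer_forward(w, x)
        # seg_inputs[0] is the checkpoint itself, so count it only once.
        peak_live = max(peak_live, len(checkpoints) + len(seg_inputs) - 1)
        # Backpropagate through the segment, then drop its activations.
        for w, x_in in zip(reversed(weights[start:end]), reversed(seg_inputs)):
            grad = layer_backward(w, x_in, grad)
    return grad, peak_live
```

With 16 layers and `segment=4`, both functions return the same gradient, but the checkpointed version holds at most 7 layer inputs at a time instead of 16, and pays for it by running each layer's forward computation a second time.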

The tradeoff is explicit: gradient checkpointing reduces memory usage at the cost of additional compute. A simple implementation that checkpoints at every layer computes each layer's forward pass twice, adding about one extra forward pass of work and increasing training time by roughly 30-40%. Where the checkpoints are placed (every N layers, or at the boundaries of model components) mainly determines how much memory is saved. The compute penalty shrinks only when part of the network is left un-checkpointed, so that its activations are stored as usual and never recomputed.
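The 30-40% figure can be motivated with a back-of-the-envelope estimate. This assumes the common rule of thumb that a backward pass costs about twice as much as a forward pass; the symbol F for the cost of one forward pass is introduced here for illustration, and real overheads vary with the model and hardware.

```latex
\begin{align*}
C_{\text{baseline}}     &= \underbrace{F}_{\text{forward}} + \underbrace{2F}_{\text{backward}} = 3F \\
C_{\text{checkpointed}} &= F + \underbrace{F}_{\text{recompute}} + 2F = 4F \\
\text{overhead}         &= \frac{C_{\text{checkpointed}} - C_{\text{baseline}}}{C_{\text{baseline}}}
                         = \frac{F}{3F} \approx 33\%
\end{align*}
```

On the memory side, a simplified count for a chain of n equal layers with a checkpoint every k layers is roughly n/k stored checkpoints plus k recomputed activations for the segment currently being processed:

```latex
M(k) \approx \frac{n}{k} + k,
\qquad
\frac{dM}{dk} = -\frac{n}{k^{2}} + 1 = 0
\;\Rightarrow\;
k = \sqrt{n},
\quad
M\!\left(\sqrt{n}\right) \approx 2\sqrt{n}
```

Under these assumptions, storage falls from about n activations to about 2√n while the recompute cost stays near one extra forward pass.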

In practice, gradient checkpointing makes it possible to train models that would otherwise not fit in GPU memory. A training run that would need 80GB for activations alone can fit on a 40GB GPU once checkpointing shrinks the activation footprint enough to leave room for parameters and optimizer state, at the cost of ~30% longer training time. For research teams without access to the largest GPUs, this tradeoff is often worth it.

Modern training frameworks (PyTorch, JAX, Hugging Face Accelerate) provide gradient checkpointing as a high-level option that can be enabled without rewriting the model code. In PyTorch, `torch.utils.checkpoint.checkpoint()` takes a callable, such as a module, together with its inputs. It runs the callable without saving its intermediate activations and recomputes them when the backward pass needs them.
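A minimal sketch of how this might look in PyTorch follows. The `Block` and `TinyModel` classes, their sizes, and the `use_checkpoint` flag are hypothetical names made up for illustration; only `torch.utils.checkpoint.checkpoint` and its `use_reentrant` argument come from PyTorch itself.

```python
import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint


class Block(nn.Module):
    """A small residual MLP block standing in for a transformer layer."""

    def __init__(self, dim):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(dim, 4 * dim), nn.GELU(), nn.Linear(4 * dim, dim)
        )

    def forward(self, x):
        return x + self.net(x)


class TinyModel(nn.Module):
    def __init__(self, dim=32, depth=4, use_checkpoint=True):
        super().__init__()
        self.blocks = nn.ModuleList([Block(dim) for _ in range(depth)])
        self.use_checkpoint = use_checkpoint

    def forward(self, x):
        for block in self.blocks:
            if self.use_checkpoint and self.training:
                # Only the block's input is kept; its internal activations
                # are recomputed when the backward pass reaches this block.
                x = checkpoint(block, x, use_reentrant=False)
            else:
                x = block(x)
        return x
```

Calling `TinyModel()(torch.randn(8, 32)).sum().backward()` should produce the same parameter gradients with `use_checkpoint` set to `True` or `False`; the difference is in peak activation memory and step time. Checkpointing is skipped outside training mode here because nothing needs to be recomputed when no backward pass will run.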

Gradient checkpointing is often combined with other memory-efficiency techniques to balance memory usage against training speed: mixed precision training (storing activations in lower precision), ZeRO sharding (partitioning optimizer states, gradients, and parameters across devices), and CPU offloading (storing some activations in CPU RAM rather than GPU VRAM).

Analogy

Picture a student taking an exam with only a small sheet of scratch paper. Rather than keeping every intermediate calculation on the page, they work through each step, erase the intermediate results, and keep only the key ones. If they later need to check how they reached a conclusion, they redo the intermediate steps from the last saved result. They use more time but far less space - exactly the tradeoff gradient checkpointing makes.

Real-world example

A team training a language model finds that a 7B parameter model with 8K context length requires 120GB of GPU VRAM for activations alone at the batch size needed for efficient training. With gradient checkpointing enabled at layer boundaries (every 8 layers), activation memory drops to 35GB - fitting on a single A100 80GB GPU - at the cost of 28% longer training time. The extra compute is an acceptable price; a training run that does not fit on the available hardware is not.

Why it matters

Gradient checkpointing is one of the core techniques that made training large models feasible on the GPU hardware available during the early scaling era. Understanding it explains a key tradeoff in ML infrastructure: memory and compute can be traded for each other, and the right balance depends on which one constrains you. It is also relevant for practitioners fine-tuning large models, where activation memory is often the binding constraint.
