Why GPU Memory Matters
GPU memory (VRAM) is typically the binding constraint in deep learning: the model weights, gradients, optimizer state, and activations all have to fit on the device at once. Running out of VRAM raises a CUDA out-of-memory error and crashes the training run.
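As a hedged sketch of what that failure looks like in code, the snippet below wraps a single training step and catches torch.cuda.OutOfMemoryError (available in PyTorch 1.13 and later); the model, batch, targets, and criterion names are placeholders.

import torch

def training_step(model, batch, targets, criterion):
    # One forward/backward pass that surfaces an out-of-memory failure explicitly
    try:
        output = model(batch)
        loss = criterion(output, targets)
        loss.backward()
        return loss.item()
    except torch.cuda.OutOfMemoryError:
        # Release cached allocator blocks, then let the caller decide how to
        # recover (smaller batch, or one of the techniques below)
        torch.cuda.empty_cache()
        raise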
Profiling Memory Usage
import torch
# Check current allocation
print(torch.cuda.memory_allocated() / 1e9, "GB allocated")
print(torch.cuda.memory_reserved() / 1e9, "GB reserved")
# Detailed snapshot
# memory_summary() returns a formatted string, so print it to see the report
print(torch.cuda.memory_summary(device=0, abbreviated=False))

Technique 1 — Mixed Precision Training (AMP)
from torch.cuda.amp import autocast, GradScaler

scaler = GradScaler()
for inputs, targets in dataloader:
    optimizer.zero_grad()
    # Run the forward pass in float16 where it is numerically safe
    with autocast():
        output = model(inputs)
        loss = criterion(output, targets)
    # Scale the loss to avoid float16 gradient underflow
    scaler.scale(loss).backward()
    scaler.step(optimizer)
    scaler.update()

Reduces memory by ~50% with minimal accuracy loss.
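On Ampere or newer GPUs, a common variant (not shown above) is to autocast to bfloat16, which keeps float32's exponent range and therefore needs no GradScaler. A minimal sketch under the same placeholder names:

import torch
from torch.cuda.amp import autocast

for inputs, targets in dataloader:
    optimizer.zero_grad()
    # bfloat16 has float32's dynamic range, so no loss scaling is required
    with autocast(dtype=torch.bfloat16):
        output = model(inputs)
        loss = criterion(output, targets)
    loss.backward()
    optimizer.step()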
Technique 2 — Gradient Checkpointing
from torch.utils.checkpoint import checkpoint_sequential
# Recompute activations during backward pass instead of storing them
output = checkpoint_sequential(model.layers, segments=4, input=x)

Reduces activation memory at the cost of ~30% slower training.
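checkpoint_sequential assumes model.layers is an nn.Sequential that can be split into segments. For Hugging Face Transformers models, the same idea can usually be enabled without restructuring the forward pass; a minimal sketch, with "gpt2" used only as a placeholder checkpoint:

from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained("gpt2")
# Each transformer block's activations are recomputed during the backward pass
model.gradient_checkpointing_enable()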
Technique 3 — Gradient Accumulation
accumulation_steps = 4

optimizer.zero_grad()
for i, (inputs, targets) in enumerate(dataloader):
    output = model(inputs)
    # Average the loss over the accumulation window
    loss = criterion(output, targets) / accumulation_steps
    loss.backward()
    # Step only once every accumulation_steps micro-batches
    if (i + 1) % accumulation_steps == 0:
        optimizer.step()
        optimizer.zero_grad()

Simulates a larger effective batch size without increasing memory.
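Gradient accumulation composes cleanly with the AMP loop from Technique 1. A hedged sketch of the combination, reusing the same placeholder model, criterion, optimizer, and dataloader:

from torch.cuda.amp import autocast, GradScaler

scaler = GradScaler()
accumulation_steps = 4

optimizer.zero_grad()
for i, (inputs, targets) in enumerate(dataloader):
    with autocast():
        output = model(inputs)
        loss = criterion(output, targets) / accumulation_steps
    # Scaled gradients accumulate in .grad across micro-batches
    scaler.scale(loss).backward()
    if (i + 1) % accumulation_steps == 0:
        scaler.step(optimizer)
        scaler.update()
        optimizer.zero_grad()

With a per-device batch of 8 and accumulation_steps = 4, the effective batch size is 32.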
Technique 4 — 4-bit Quantisation (QLoRA)
import torch
from transformers import AutoModelForCausalLM, BitsAndBytesConfig

bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16,
)
model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-2-70b-hf",
    quantization_config=bnb_config,
)

Reduces the weights of a 70B model from ~140 GB in FP16 to ~35 GB.
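Loading in 4-bit is only half of QLoRA; the other half is training small LoRA adapters on top of the frozen, quantised base model. A minimal sketch using the peft library, where the rank and target_modules values are illustrative assumptions rather than tuned settings:

from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training

# Freeze the quantised base weights and prepare norms/embeddings for stable training
model = prepare_model_for_kbit_training(model)

lora_config = LoraConfig(
    r=16,                                  # adapter rank (illustrative)
    lora_alpha=32,
    target_modules=["q_proj", "v_proj"],   # attention projections (illustrative)
    lora_dropout=0.05,
    task_type="CAUSAL_LM",
)
model = get_peft_model(model, lora_config)
model.print_trainable_parameters()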
Technique 5 — Offload to CPU/NVMe
# DeepSpeed ZeRO-Infinity
ds_config = {
    "zero_optimization": {
        "stage": 3,
        "offload_optimizer": {"device": "cpu"},
        "offload_param": {"device": "nvme", "nvme_path": "/mnt/nvme"},
    }
}

Summary
| Technique | Typical Memory Saving | Typical Speed Impact |
|---|---|---|
| AMP (FP16) | ~50% | ~20% faster |
| Gradient Checkpointing | ~40% | ~30% slower |
| Gradient Accumulation | None (memory stays constant as effective batch grows) | Neutral |
| 4-bit QLoRA | ~75% | ~15% slower |
| CPU Offload | ~80% | ~50% slower |
