SGD, Adam, AdamW, learning rate schedules, warmup, and gradient clipping for training
Comes up in roughly 60% of ML interviews.
Relevant to every training and fine-tuning job.
The right optimizer and schedule can speed up training by 2-3x.
TL;DR
AdamW is the standard optimizer for transformers. Use warmup to prevent early instability, cosine decay for pre-training, linear decay for fine-tuning, and gradient clipping to prevent explosions. Fine-tuning needs 10-100x smaller learning rates than pre-training.
Visual Overview
SGD UPDATE
+-----------------------------------------------------------+
| |
| w = w - lr x gradient |
| |
| Where: |
| w = weights |
| lr = learning rate |
| gradient = dLoss/dw |
| |
+-----------------------------------------------------------+
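To make the update rule concrete, here is a minimal NumPy sketch; the toy quadratic loss is just an illustration, not taken from the section above.

```python
import numpy as np

def sgd_step(w, grad, lr=0.1):
    """Plain SGD: move the weights against the gradient."""
    return w - lr * grad

# Toy example: minimize f(w) = (w - 3)^2, so dLoss/dw = 2 * (w - 3).
w = np.array(0.0)
for _ in range(50):
    grad = 2 * (w - 3)
    w = sgd_step(w, grad)
print(w)  # converges toward 3.0
```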
SGD PROBLEMS
+-----------------------------------------------------------+
| |
| 1. OSCILLATION IN VALLEYS |
| Loss surface has steep sides, shallow floor |
| SGD bounces side-to-side, slow progress forward |
| |
| /\ /\ /\ |
| / \ / \ / \ |
| / \/ \/ \--> minimum |
| |<-bounce->|<-bounce->| |
| |
| 2. SAME LEARNING RATE FOR ALL |
| Some parameters need big updates, others small |
| Single LR can't satisfy both |
| |
+-----------------------------------------------------------+
Momentum
Add “velocity” to SGD. Accumulate gradient direction over time.
MOMENTUM UPDATE
+-----------------------------------------------------------+
| |
| velocity = beta x velocity + gradient |
| w = w - lr x velocity |
| |
| Where: |
| beta = momentum coefficient (typically 0.9) |
| velocity = accumulated gradient direction |
| |
+-----------------------------------------------------------+
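The same toy problem with momentum, following the box above (the quadratic loss is again just an illustration):

```python
import numpy as np

def momentum_step(w, grad, velocity, lr=0.1, beta=0.9):
    """Momentum: accumulate gradients into a velocity, then step along it."""
    velocity = beta * velocity + grad
    w = w - lr * velocity
    return w, velocity

# Same toy loss f(w) = (w - 3)^2 as in the SGD sketch above.
w, velocity = np.array(0.0), np.array(0.0)
for _ in range(200):
    grad = 2 * (w - 3)
    w, velocity = momentum_step(w, grad, velocity)
print(w)  # overshoots and oscillates early, then settles near 3.0
```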
MOMENTUM INTUITION
+-----------------------------------------------------------+
| |
| Ball rolling downhill. |
| |
| Without momentum: |
| Step 1: gradient = [1, 0] -> move [1, 0] |
| Step 2: gradient = [-1, 0.1] -> move [-1, 0.1] |
| (oscillating!) |
| |
| With momentum (beta=0.9): |
| Step 1: velocity = [1, 0] -> move [1, 0] |
| Step 2: velocity = 0.9x[1,0] + [-1, 0.1] |
| = [0.9, 0] + [-1, 0.1] |
| = [-0.1, 0.1] -> much smaller |
| oscillation! |
| |
| Consistent direction gets amplified. |
| Oscillating direction gets dampened. |
| |
+-----------------------------------------------------------+
Adam (Adaptive Moment Estimation)
Combines momentum with adaptive learning rates per parameter.
ADAM UPDATE
+-----------------------------------------------------------+
| |
| m = beta1 x m + (1-beta1) x gradient # 1st moment |
| v = beta2 x v + (1-beta2) x gradient^2 # 2nd moment |
| |
| m_hat = m / (1 - beta1^t) # Bias correct |
| v_hat = v / (1 - beta2^t) |
| |
| w = w - lr x m_hat / (sqrt(v_hat) + eps) |
| |
| Default hyperparameters: |
| beta1 = 0.9 (momentum decay) |
| beta2 = 0.999 (variance decay) |
| eps = 1e-8 (numerical stability) |
| |
+-----------------------------------------------------------+
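A direct NumPy transcription of the box above; treat it as a sketch of the math, not a replacement for torch.optim.Adam.

```python
import numpy as np

def adam_step(w, grad, m, v, t, lr=1e-3, beta1=0.9, beta2=0.999, eps=1e-8):
    """One Adam step. t is the 1-indexed step count used for bias correction."""
    m = beta1 * m + (1 - beta1) * grad        # 1st moment (momentum)
    v = beta2 * v + (1 - beta2) * grad ** 2   # 2nd moment (uncentered variance)
    m_hat = m / (1 - beta1 ** t)              # bias correction
    v_hat = v / (1 - beta2 ** t)
    w = w - lr * m_hat / (np.sqrt(v_hat) + eps)
    return w, m, v

# One illustrative step from w = 0 with dLoss/dw = 2 * (0 - 3) = -6:
w, m, v = np.array(0.0), np.array(0.0), np.array(0.0)
w, m, v = adam_step(w, np.array(-6.0), m, v, t=1, lr=0.1)
print(w)  # ~0.1: the first step has magnitude ~lr, regardless of gradient scale
```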
WHY ADAM WORKS WELL
+-----------------------------------------------------------+
| |
| 1. MOMENTUM (m) |
| Same as SGD+momentum -- damps oscillation |
| |
| 2. ADAPTIVE LR (v) |
| Parameters with large gradients -> smaller LR |
| Parameters with small gradients -> larger LR |
| |
| High-gradient param: lr / sqrt(large_v) = small step |
| Low-gradient param: lr / sqrt(small_v) = large step |
| |
| 3. BIAS CORRECTION |
| Early steps: m and v are biased toward 0 |
| Correction compensates for this |
| |
+-----------------------------------------------------------+
Adam is the default choice for most deep learning.
AdamW
Adam with proper weight decay. The standard for transformers.
ADAMW VS ADAM + L2
+-----------------------------------------------------------+
| |
| Adam + L2 regularization (WRONG): |
| gradient = task_gradient + lambda x w |
| ... adam update with this gradient ... |
| |
| Problem: Weight decay mixed into adaptive LR |
| High-variance params get less regularization |
| |
| AdamW (CORRECT): |
| gradient = task_gradient # NO weight decay here |
| ... adam update with this gradient ... |
| w = w - lr x lambda x w # Decay applied AFTER |
| |
| Weight decay truly decoupled from optimization. |
| |
+-----------------------------------------------------------+
Always use AdamW for transformers, not Adam with L2.
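The only change relative to the Adam sketch above is where weight decay enters; everything else is identical. In practice you would simply use torch.optim.AdamW rather than hand-rolling this.

```python
import numpy as np

def adamw_step(w, grad, m, v, t, lr=1e-3, beta1=0.9, beta2=0.999,
               eps=1e-8, weight_decay=0.01):
    """AdamW: Adam update on the task gradient only, then decoupled decay."""
    m = beta1 * m + (1 - beta1) * grad            # NO weight decay in the gradient
    v = beta2 * v + (1 - beta2) * grad ** 2
    m_hat = m / (1 - beta1 ** t)
    v_hat = v / (1 - beta2 ** t)
    w = w - lr * m_hat / (np.sqrt(v_hat) + eps)   # adaptive step
    w = w - lr * weight_decay * w                 # decay applied AFTER, directly to w
    return w, m, v

# PyTorch equivalent (decoupled decay is built in):
# torch.optim.AdamW(model.parameters(), lr=3e-4, betas=(0.9, 0.95), weight_decay=0.1)
```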
Learning Rate Schedules
Learning rate should change during training. High initially, lower later.
Linear Decay
LINEAR DECAY
+-----------------------------------------------------------+
| |
| lr = initial_lr x (1 - step/total_steps) |
| |
| lr |
| | |
| 1 |* |
| | \ |
| | \ |
| 0.5 | \ |
| | \ |
| 0 | * |
| +---------- |
| 0 steps T |
| |
+-----------------------------------------------------------+
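As a function of the training step, a direct translation of the formula above (the function name is illustrative):

```python
def linear_decay_lr(step, total_steps, initial_lr):
    """Linear decay from initial_lr at step 0 down to 0 at total_steps."""
    return initial_lr * max(0.0, 1.0 - step / total_steps)

# e.g. with initial_lr = 3e-4 and 10,000 steps:
# step 0 -> 3.0e-4, step 5,000 -> 1.5e-4, step 10,000 -> 0.0
```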
Cosine Decay
COSINE DECAY
+-----------------------------------------------------------+
| |
| lr = lr_min + 0.5 x (lr_max - lr_min) |
| x (1 + cos(pi x step/total_steps)) |
| |
| lr |
| | |
| 1 |*\ |
| | \ |
| | \ |
| 0.5 | \ |
| | \___ |
| 0 | * |
| +---------- |
| 0 steps T |
| |
| Smooth decay that stays near the peak LR early and       |
| flattens out near the minimum at the end.                |
| Often works better than linear for pre-training.         |
| |
+-----------------------------------------------------------+
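The same formula as a function (the function name is illustrative; lr_min defaults to 0 here, though it is often set to something like 10% of lr_max):

```python
import math

def cosine_decay_lr(step, total_steps, lr_max, lr_min=0.0):
    """Cosine decay from lr_max at step 0 down to lr_min at total_steps."""
    progress = min(step / total_steps, 1.0)
    return lr_min + 0.5 * (lr_max - lr_min) * (1.0 + math.cos(math.pi * progress))

# e.g. with lr_max = 3e-4: step 0 -> 3.0e-4, halfway -> 1.5e-4, end -> lr_min
```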
Warmup
WARMUP
+-----------------------------------------------------------+
| |
| Start with very low LR, ramp up, then decay. |
| |
| lr |
| | *---\ |
| | / \ |
| | / \ |
| 0.5 | / \ |
| | / \ |
| 0 |*-/ * |
| +---------------------- |
| 0 warmup peak T |
| |
| Warmup prevents early instability. |
| Gradients are noisy at start (random weights). |
| High LR + noisy gradients = explosion. |
| |
+-----------------------------------------------------------+
Typical warmup: 1-5% of total training steps.
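Warmup is usually combined with one of the decay schedules above. A hedged sketch of linear warmup into cosine decay (the function name is mine; libraries such as HuggingFace Transformers ship an equivalent get_cosine_schedule_with_warmup):

```python
import math

def warmup_cosine_lr(step, total_steps, peak_lr, warmup_steps, lr_min=0.0):
    """Linear warmup from ~0 to peak_lr, then cosine decay down to lr_min."""
    if step < warmup_steps:
        return peak_lr * (step + 1) / warmup_steps
    progress = (step - warmup_steps) / max(1, total_steps - warmup_steps)
    return lr_min + 0.5 * (peak_lr - lr_min) * (1.0 + math.cos(math.pi * progress))
```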
Common Configurations
Pre-training LLM
- Optimizer: AdamW
- beta1 = 0.9, beta2 = 0.95
- Weight decay: 0.1
- LR: 1e-4 to 3e-4
- Schedule: Cosine decay with warmup
- Warmup: 2000 steps (or 1% of total)
Fine-tuning
- Optimizer: AdamW
- beta1 = 0.9, beta2 = 0.999
- Weight decay: 0.01
- LR: 1e-5 to 5e-5 (10-100x smaller than pre-training)
- Schedule: Linear decay with warmup
- Warmup: 3-5% of total steps
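A hedged PyTorch sketch of the fine-tuning recipe above; the model, step counts, and exact LR are placeholders to adapt to your setup.

```python
import torch

model = torch.nn.Linear(768, 2)          # placeholder for your actual model
total_steps = 10_000
warmup_steps = int(0.03 * total_steps)   # 3% warmup, per the recipe above

optimizer = torch.optim.AdamW(
    model.parameters(),
    lr=2e-5,                             # 10-100x below typical pre-training LRs
    betas=(0.9, 0.999),
    weight_decay=0.01,
)

def lr_factor(step):
    """Multiplier on the base LR: linear warmup, then linear decay to 0."""
    if step < warmup_steps:
        return step / max(1, warmup_steps)
    return max(0.0, (total_steps - step) / max(1, total_steps - warmup_steps))

scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_factor)

# In the training loop: loss.backward(); optimizer.step(); scheduler.step(); optimizer.zero_grad()
```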
Quick Reference
| Scenario | LR | Schedule | Warmup |
|---|---|---|---|
| Pre-training LLM | 1e-4 - 3e-4 | Cosine | 1-2% |
| Fine-tuning LLM | 1e-5 - 5e-5 | Linear | 3-5% |
| Fine-tuning BERT | 2e-5 - 5e-5 | Linear | 10% |
| Training CNN | 1e-3 | Step decay | None |
Gradient Clipping
Limit gradient magnitude to prevent explosions.
GRADIENT CLIPPING
+-----------------------------------------------------------+
| |
| Clip by global norm (most common): |
| total_norm = sqrt(SUM(gradient^2)) |
| if total_norm > max_norm: |
| gradient = gradient x (max_norm / total_norm) |
| |
| Typical max_norm: 1.0 |
| |
| Prevents single bad batch from destroying model. |
| |
+-----------------------------------------------------------+
When to use:
- Always for transformers
- When training is unstable
- When loss spikes occasionally
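In PyTorch, clipping by global norm is a single call placed between backward() and the optimizer step; a minimal sketch with a placeholder model:

```python
import torch

model = torch.nn.Linear(16, 1)                       # placeholder model
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)

x, y = torch.randn(8, 16), torch.randn(8, 1)
loss = torch.nn.functional.mse_loss(model(x), y)
loss.backward()

# Rescale all gradients if their global norm exceeds 1.0 (returns the pre-clip norm).
total_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)

optimizer.step()
optimizer.zero_grad()
```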
Debugging Optimization
LOSS NOT DECREASING
+-----------------------------------------------------------+
| |
| Symptoms: |
| - Loss stays flat from start |
| - Or decreases very slowly |
| |
| Causes: |
| - Learning rate too low |
| - Warmup too long |
| - Wrong optimizer config |
| |
| Debug steps: |
| 1. Try 10x higher LR |
| 2. Reduce warmup steps |
| 3. Check gradient values (should be non-zero) |
| 4. Verify optimizer is updating weights |
| |
+-----------------------------------------------------------+
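For debug steps 3 and 4 above, a quick sanity check could look like this (the tiny model and data are placeholders):

```python
import torch

model = torch.nn.Linear(16, 1)                       # placeholder model
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)

x, y = torch.randn(8, 16), torch.randn(8, 1)
loss = torch.nn.functional.mse_loss(model(x), y)
loss.backward()

# Step 3: gradients should exist and be non-zero.
for name, p in model.named_parameters():
    print(name, "grad norm:", None if p.grad is None else p.grad.norm().item())

# Step 4: weights should actually change after optimizer.step().
before = {name: p.detach().clone() for name, p in model.named_parameters()}
optimizer.step()
for name, p in model.named_parameters():
    print(name, "updated:", not torch.equal(before[name], p.detach()))
```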
LOSS EXPLODES (NaN)
+-----------------------------------------------------------+
| |
| Symptoms: |
| - Loss suddenly becomes NaN |
| - Training crashes |
| |
| Causes: |
| - Learning rate too high |
| - Missing gradient clipping |
| - No warmup |
| |
| Debug steps: |
| 1. Add gradient clipping (max_norm=1.0) |
| 2. Reduce LR by 10x |
| 3. Add warmup (5% of steps) |
| 4. Check for numerical issues in data |
| |
+-----------------------------------------------------------+
FINE-TUNING DESTROYS MODEL
+-----------------------------------------------------------+
| |
| Symptoms: |
| - After fine-tuning, model is worse than base |
| - "Catastrophic forgetting" |
| |
| Causes: |
| - Learning rate too high |
| - No warmup |
| - Weight decay too high |
| |
| Debug steps: |
| 1. Use much smaller LR (1e-5 or lower) |
| 2. Add warmup |
| 3. Reduce weight decay |
| 4. Consider LoRA instead of full fine-tuning |
| |
+-----------------------------------------------------------+
When This Matters
| Situation | What to know |
|---|---|
| Training transformers | Use AdamW, not Adam |
| Fine-tuning | LR 10-100x smaller than pre-training |
| Training unstable | Add warmup, gradient clipping |
| Loss not decreasing | Try higher LR |
| Loss exploding | Lower LR, add gradient clipping |
| Understanding configs | beta1=momentum, beta2=variance averaging |
| Choosing schedule | Cosine for pre-training, linear for fine-tuning |