
Optimization

12 min · Intermediate · gen-ai

SGD, Adam, AdamW, learning rate schedules, warmup, and gradient clipping for training

Interview relevance: 60% of ML interviews
Production impact: every training and fine-tuning job
Performance: the right optimizer can speed up training 2-3x

TL;DR

AdamW is the standard optimizer for transformers. Use warmup to prevent early instability, cosine decay for pre-training, linear decay for fine-tuning, and gradient clipping to prevent explosions. Fine-tuning needs 10-100x smaller learning rates than pre-training.

Visual Overview

SGD UPDATE
+-----------------------------------------------------------+
|                                                           |
|   w = w - lr x gradient                                   |
|                                                           |
|   Where:                                                  |
|     w = weights                                           |
|     lr = learning rate                                    |
|     gradient = dLoss/dw                                   |
|                                                           |
+-----------------------------------------------------------+
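The update above can be sketched in plain Python. This assumes weights and gradients are flat lists of floats (a real framework would use tensors). The toy loss sum(w_i^2), with gradient 2 * w_i, is an illustrative choice and does not come from the text.

```python
def sgd_step(weights, grads, lr):
    """Plain SGD: w = w - lr * gradient, applied element-wise."""
    return [w - lr * g for w, g in zip(weights, grads)]


# Toy example: minimize loss(w) = sum(w_i^2), whose gradient is 2 * w_i.
w = [1.0, -2.0]
for _ in range(3):
    grads = [2.0 * wi for wi in w]
    w = sgd_step(w, grads, lr=0.1)
print(w)  # each step scales w by (1 - 0.1 * 2) = 0.8
```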

SGD PROBLEMS
+-----------------------------------------------------------+
|                                                           |
|   1. OSCILLATION IN VALLEYS                               |
|      Loss surface has steep sides, shallow floor          |
|      SGD bounces side-to-side, slow progress forward      |
|                                                           |
|         /\    /\    /\                                    |
|        /  \  /  \  /  \                                   |
|       /    \/    \/    \--> minimum                       |
|       |<-bounce->|<-bounce->|                             |
|                                                           |
|   2. SAME LEARNING RATE FOR ALL                           |
|      Some parameters need big updates, others small       |
|      Single LR can't satisfy both                         |
|                                                           |
+-----------------------------------------------------------+
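Both problems appear on a toy two-parameter loss, assumed here to be 0.5 * (10 * x^2 + 0.1 * y^2). This loss is steep along x and shallow along y. The sketch below runs plain SGD on it. With lr = 0.19, each step multiplies x by 1 - 0.19 * 10 = -0.9, so x flips sign every step. The same step multiplies y by 0.981, so y barely moves.

```python
def grad(x, y):
    # Gradient of the toy loss 0.5 * (10 * x**2 + 0.1 * y**2):
    # steep along x, shallow along y (a narrow valley).
    return 10.0 * x, 0.1 * y


x, y, lr = 1.0, 1.0, 0.19
trajectory = [(x, y)]
for _ in range(10):
    gx, gy = grad(x, y)
    x, y = x - lr * gx, y - lr * gy
    trajectory.append((x, y))

for step, (px, py) in enumerate(trajectory):
    print(f"step {step:2d}: x = {px:+.3f}, y = {py:.3f}")
```

Any lr above 2 / 10 = 0.2 makes the x coordinate diverge. The steep direction therefore caps the single learning rate, and the shallow direction is left with tiny steps.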

Momentum

Add “velocity” to SGD. Accumulate gradient direction over time.

MOMENTUM UPDATE
+-----------------------------------------------------------+
|                                                           |
|   velocity = beta x velocity + gradient                   |
|   w = w - lr x velocity                                   |
|                                                           |
|   Where:                                                  |
|     beta = momentum coefficient (typically 0.9)           |
|     velocity = accumulated gradient direction             |
|                                                           |
+-----------------------------------------------------------+

MOMENTUM INTUITION
+-----------------------------------------------------------+
|                                                           |
|   Ball rolling downhill.                                  |
|                                                           |
|   Without momentum:                                       |
|     Step 1: gradient = [1, 0]  -> move [1, 0]             |
|     Step 2: gradient = [-1, 0.1] -> move [-1, 0.1]        |
|     (oscillating!)                                        |
|                                                           |
|   With momentum (beta=0.9):                               |
|     Step 1: velocity = [1, 0]        -> move [1, 0]       |
|     Step 2: velocity = 0.9x[1,0] + [-1, 0.1]              |
|                      = [0.9, 0] + [-1, 0.1]               |
|                      = [-0.1, 0.1]   -> much smaller      |
|                                          oscillation!     |
|                                                           |
|   Consistent direction gets amplified.                    |
|   Oscillating direction gets dampened.                    |
|                                                           |
+-----------------------------------------------------------+
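The worked example can be reproduced with a short sketch of the momentum update. Here lr = 1.0 is chosen only so that each step's size equals the velocity, as in the box. Representing parameters as plain lists is an assumption made to keep the sketch self-contained.

```python
def momentum_step(weights, velocity, grads, lr, beta=0.9):
    """SGD with momentum in the form used above:
    velocity = beta * velocity + gradient
    w = w - lr * velocity
    """
    velocity = [beta * v + g for v, g in zip(velocity, grads)]
    weights = [w - lr * v for w, v in zip(weights, velocity)]
    return weights, velocity


# Two steps with gradients [1, 0] then [-1, 0.1], as in the example.
w, vel = [0.0, 0.0], [0.0, 0.0]
for g in ([1.0, 0.0], [-1.0, 0.1]):
    w, vel = momentum_step(w, vel, g, lr=1.0)
print(vel)  # approximately [-0.1, 0.1]: the x oscillation is mostly cancelled
```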

Adam (Adaptive Moment Estimation)

Combines momentum with adaptive learning rates per parameter.

ADAM UPDATE
+-----------------------------------------------------------+
|                                                           |
|   m = beta1 x m + (1-beta1) x gradient     # 1st moment   |
|   v = beta2 x v + (1-beta2) x gradient^2   # 2nd moment   |
|                                                           |
|   m_hat = m / (1 - beta1^t)                # Bias correct |
|   v_hat = v / (1 - beta2^t)                               |
|                                                           |
|   w = w - lr x m_hat / (sqrt(v_hat) + eps)                |
|                                                           |
|   Default hyperparameters:                                |
|     beta1 = 0.9    (momentum decay)                       |
|     beta2 = 0.999  (variance decay)                       |
|     eps = 1e-8     (numerical stability)                  |
|                                                           |
+-----------------------------------------------------------+
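Here is a plain-Python sketch of one Adam update that follows the formulas above line by line. It assumes parameters are a flat list of scalars and that t is the 1-based step counter used in the bias correction. The quadratic test loss and lr = 0.05 are illustrative choices.

```python
import math


def adam_step(w, m, v, grads, t, lr=1e-3, beta1=0.9, beta2=0.999, eps=1e-8):
    """One Adam update for a list of scalar parameters; returns (w, m, v)."""
    new_w, new_m, new_v = [], [], []
    for wi, mi, vi, g in zip(w, m, v, grads):
        mi = beta1 * mi + (1 - beta1) * g        # 1st moment (EMA of gradient)
        vi = beta2 * vi + (1 - beta2) * g * g    # 2nd moment (EMA of gradient^2)
        m_hat = mi / (1 - beta1 ** t)            # bias correction
        v_hat = vi / (1 - beta2 ** t)
        wi = wi - lr * m_hat / (math.sqrt(v_hat) + eps)
        new_w.append(wi)
        new_m.append(mi)
        new_v.append(vi)
    return new_w, new_m, new_v


# Toy run: minimize loss(w) = sum(w_i^2), gradient 2 * w_i.
w = [1.0, -3.0]
m = [0.0] * len(w)
v = [0.0] * len(w)
for t in range(1, 501):
    grads = [2.0 * wi for wi in w]
    w, m, v = adam_step(w, m, v, grads, t, lr=0.05)
print(w)
```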

WHY ADAM WORKS WELL
+-----------------------------------------------------------+
|                                                           |
|   1. MOMENTUM (m)                                         |
|      Same idea as SGD+momentum -- damps oscillation       |
|                                                           |
|   2. ADAPTIVE LR (v)                                      |
|      Parameters with large gradients -> smaller LR        |
|      Parameters with small gradients -> larger LR         |
|                                                           |
|      High-gradient param: lr / sqrt(large_v) = small LR   |
|      Low-gradient param:  lr / sqrt(small_v) = large LR   |
|                                                           |
|   3. BIAS CORRECTION                                      |
|      Early steps: m and v are biased toward 0             |
|      Correction compensates for this                      |
|                                                           |
+-----------------------------------------------------------+
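Bias correction and adaptive scaling are easiest to see on the very first step (t = 1), where m and v start at 0. The hypothetical helper `first_adam_step` below computes that first update for gradients spanning four orders of magnitude. With correction, the step is about lr for every gradient. Without it, m is shrunk by 1 - beta1 = 0.1 but sqrt(v) only by sqrt(1 - beta2) ≈ 0.032, so the step comes out roughly 3.16x too large.

```python
import math

beta1, beta2, eps, lr = 0.9, 0.999, 1e-8, 1e-3


def first_adam_step(g, bias_correct=True):
    """Size of Adam's first update for scalar gradient g (m = v = 0 beforehand)."""
    m = (1 - beta1) * g
    v = (1 - beta2) * g * g
    if bias_correct:
        m /= 1 - beta1   # at t = 1, beta1**t == beta1
        v /= 1 - beta2
    return lr * m / (math.sqrt(v) + eps)


for g in (100.0, 1.0, 0.01):
    print(g, first_adam_step(g), first_adam_step(g, bias_correct=False))
```

On this first step, the corrected update is about lr no matter how large or small the gradient is. This is the per-parameter normalization described in point 2.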

Adam is the default choice for most deep learning.


AdamW

Adam with proper weight decay. The standard for transformers.

ADAMW VS ADAM + L2
+-----------------------------------------------------------+
|                                                           |
|   Adam + L2 regularization (WRONG):                       |
|     gradient = task_gradient + lambda x w                 |
|     ... adam update with this gradient ...                |
|                                                           |
|     Problem: Weight decay mixed into adaptive LR          |
|              High-variance params get less regularization |
|                                                           |
|   AdamW (CORRECT):                                        |
|     gradient = task_gradient       # NO weight decay here |
|     ... adam update with this gradient ...                |
|     w = w - lr x lambda x w        # Decay applied AFTER  |
|                                                           |
|     Weight decay is decoupled from the adaptive update.   |
|                                                           |
+-----------------------------------------------------------+
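The two variants can be compared directly for a single scalar parameter. The helper names (`adam_direction`, `adam_l2_step`, `adamw_step`) are illustrative, and the decay term in this sketch uses the pre-update weight. The demo takes one step with zero task gradient, so weight decay is the only force acting on the parameter.

```python
import math


def adam_direction(m, v, g, t, beta1=0.9, beta2=0.999, eps=1e-8):
    """Shared Adam machinery for one scalar: returns (normalized update, m, v)."""
    m = beta1 * m + (1 - beta1) * g
    v = beta2 * v + (1 - beta2) * g * g
    m_hat = m / (1 - beta1 ** t)
    v_hat = v / (1 - beta2 ** t)
    return m_hat / (math.sqrt(v_hat) + eps), m, v


def adam_l2_step(w, m, v, task_grad, t, lr, wd):
    # Adam + L2: the decay term is folded into the gradient, so Adam rescales it too.
    update, m, v = adam_direction(m, v, task_grad + wd * w, t)
    return w - lr * update, m, v


def adamw_step(w, m, v, task_grad, t, lr, wd):
    # AdamW: Adam sees only the task gradient; decay is applied as a separate term.
    update, m, v = adam_direction(m, v, task_grad, t)
    return w - lr * update - lr * wd * w, m, v


# One step from w = 1.0 with zero task gradient, lr = 1e-3, weight decay = 0.01.
w_l2, _, _ = adam_l2_step(1.0, 0.0, 0.0, 0.0, t=1, lr=1e-3, wd=0.01)
w_adamw, _, _ = adamw_step(1.0, 0.0, 0.0, 0.0, t=1, lr=1e-3, wd=0.01)
print(w_l2, w_adamw)
```

Under Adam + L2, the decay gradient lambda * w is normalized by sqrt(v). As a result, the weight moves by a full lr-sized step (about 1e-3, 100x the intended lr * lambda * w = 1e-5). A parameter with large task gradients would instead see its decay swamped. AdamW applies exactly lr * lambda * w regardless of gradient history.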

Always use AdamW for transformers, not Adam with L2.


Learning Rate Schedules

Learning rate should change during training. High initially, lower later.

Linear Decay

LINEAR DECAY
+-----------------------------------------------------------+
|                                                           |
|   lr = initial_lr x (1 - step/total_steps)                |
|                                                           |
|      lr                                                   |
|       |                                                   |
|     1 |*                                                  |
|       | \                                                 |
|       |  \                                                |
|   0.5 |   \                                               |
|       |    \                                              |
|     0 |     *                                             |
|       +----------                                         |
|       0   steps   T                                       |
|                                                           |
+-----------------------------------------------------------+
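The linear schedule can be written as a function of the step count. In this sketch, clamping steps past total_steps to an LR of 0 is an added assumption. The example peak of 3e-5 is simply a value in the fine-tuning range.

```python
def linear_decay(step, total_steps, initial_lr):
    """lr = initial_lr * (1 - step / total_steps), clamped at 0 past the end."""
    frac = min(step, total_steps) / total_steps
    return initial_lr * (1.0 - frac)


total = 1000
for step in (0, 250, 500, 1000):
    print(step, linear_decay(step, total, initial_lr=3e-5))
```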

Cosine Decay

COSINE DECAY
+-----------------------------------------------------------+
|                                                           |
|   lr = lr_min + 0.5 x (lr_max - lr_min)                   |
|        x (1 + cos(pi x step/total_steps))                 |
|                                                           |
|      lr                                                   |
|       |                                                   |
|     1 |*\                                                 |
|       |  \                                                |
|       |   \                                               |
|   0.5 |    \                                              |
|       |     \___                                          |
|     0 |         *                                         |
|       +----------                                         |
|       0   steps   T                                       |
|                                                           |
|   Smooth decay: stays near the peak early, drops fastest  |
|   mid-way, and flattens out near lr_min at the end.       |
|   Often works better than linear.                         |
|                                                           |
+-----------------------------------------------------------+
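The cosine formula is sketched below and printed next to linear decay for comparison. In this sketch, lr_min defaults to 0, which is an assumption. At 25% of training, cosine is still at about 85% of lr_max versus 75% for linear. At 75% of training, cosine is down to about 15% versus 25%.

```python
import math


def cosine_decay(step, total_steps, lr_max, lr_min=0.0):
    """lr = lr_min + 0.5 * (lr_max - lr_min) * (1 + cos(pi * step / total_steps))."""
    progress = min(step, total_steps) / total_steps
    return lr_min + 0.5 * (lr_max - lr_min) * (1.0 + math.cos(math.pi * progress))


total, lr_max = 1000, 3e-4
for step in (0, 250, 500, 750, 1000):
    cos_lr = cosine_decay(step, total, lr_max)
    lin_lr = lr_max * (1 - step / total)
    print(f"step {step:4d}: cosine {cos_lr:.2e}  linear {lin_lr:.2e}")
```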

Warmup

WARMUP
+-----------------------------------------------------------+
|                                                           |
|   Start with very low LR, ramp up, then decay.            |
|                                                           |
|      lr                                                   |
|       |       *---\                                       |
|       |      /     \                                      |
|       |     /       \                                     |
|   0.5 |    /         \                                    |
|       |   /           \                                   |
|     0 |*-/             *                                  |
|       +----------------------                             |
|       0  warmup  peak  T                                  |
|                                                           |
|   Warmup prevents early instability.                      |
|   Gradients are noisy at start (random weights).          |
|   High LR + noisy gradients = explosion.                  |
|                                                           |
+-----------------------------------------------------------+
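Warmup is usually combined with the decay schedule. This sketch assumes a linear ramp to lr_max over warmup_steps, followed by cosine decay to lr_min. The values used (100,000 steps, 2% warmup, lr_max = 3e-4, lr_min = 3e-5) are illustrative. Starting the ramp at lr_max / warmup_steps rather than 0 is a design choice, so that the first step still updates the weights.

```python
import math


def warmup_cosine(step, total_steps, warmup_steps, lr_max, lr_min=0.0):
    """Linear warmup up to lr_max, then cosine decay down to lr_min."""
    if step < warmup_steps:
        return lr_max * (step + 1) / warmup_steps   # ramp; first step is not 0
    progress = (step - warmup_steps) / max(1, total_steps - warmup_steps)
    progress = min(progress, 1.0)
    return lr_min + 0.5 * (lr_max - lr_min) * (1.0 + math.cos(math.pi * progress))


total = 100_000
warmup = int(0.02 * total)   # 2% of training steps
lrs = [warmup_cosine(s, total, warmup, lr_max=3e-4, lr_min=3e-5) for s in range(total)]
print(lrs[0], max(lrs), lrs[-1])
```

In PyTorch, a function like this can be passed to `torch.optim.lr_scheduler.LambdaLR`, but it must first be divided by lr_max because `LambdaLR` expects a multiplier on the optimizer's base LR.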

Typical warmup: 1-5% of total training steps.


Common Configurations

Pre-training LLM

Optimizer: AdamW
beta1 = 0.9, beta2 = 0.95
Weight decay: 0.1
LR: 1e-4 to 3e-4
Schedule: Cosine decay with warmup
Warmup: 2000 steps (or 1% of total)

Fine-tuning

Optimizer: AdamW
beta1 = 0.9, beta2 = 0.999
Weight decay: 0.01
LR: 1e-5 to 5e-5 (10-100x smaller than pre-training)
Schedule: Linear decay with warmup
Warmup: 3-5% of total steps

Quick Reference

Scenario            LR             Schedule     Warmup (% of steps)
Pre-training LLM    1e-4 to 3e-4   Cosine       1-2%
Fine-tuning LLM     1e-5 to 5e-5   Linear       3-5%
Fine-tuning BERT    2e-5 to 5e-5   Linear       10%
Training CNN        1e-3           Step decay   None

Gradient Clipping

Limit gradient magnitude to prevent explosions.

GRADIENT CLIPPING
+-----------------------------------------------------------+
|                                                           |
|   Clip by global norm (most common):                      |
|     total_norm = sqrt(SUM(gradient^2))                    |
|     if total_norm > max_norm:                             |
|         gradient = gradient x (max_norm / total_norm)     |
|                                                           |
|   Typical max_norm: 1.0                                   |
|                                                           |
|   Prevents single bad batch from destroying model.        |
|                                                           |
+-----------------------------------------------------------+
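Clipping by global norm can be sketched as follows, assuming gradients arrive as a list of per-parameter lists. The norm is taken over all parameters at once, so every gradient is scaled by the same factor and the overall update direction is preserved. In PyTorch, the built-in equivalent is `torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)`, called after `backward()` and before `optimizer.step()`.

```python
import math


def clip_by_global_norm(grads, max_norm=1.0):
    """Rescale all gradients by max_norm / total_norm if the global L2 norm is too large.

    grads: list of per-parameter gradient lists (one list per tensor).
    Returns (clipped_grads, total_norm), with total_norm measured before clipping.
    """
    total_norm = math.sqrt(sum(g * g for tensor in grads for g in tensor))
    if total_norm > max_norm:
        scale = max_norm / total_norm
        grads = [[g * scale for g in tensor] for tensor in grads]
    return grads, total_norm


clipped, norm = clip_by_global_norm([[3.0], [4.0]], max_norm=1.0)
print(norm, clipped)  # global norm 5.0; gradients scaled by 1/5
```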

When to use:

  • Always for transformers
  • When training is unstable
  • When loss spikes occasionally

Debugging Optimization

LOSS NOT DECREASING
+-----------------------------------------------------------+
|                                                           |
|   Symptoms:                                               |
|     - Loss stays flat from start                          |
|     - Or decreases very slowly                            |
|                                                           |
|   Causes:                                                 |
|     - Learning rate too low                               |
|     - Warmup too long                                     |
|     - Wrong optimizer config                              |
|                                                           |
|   Debug steps:                                            |
|     1. Try 10x higher LR                                  |
|     2. Reduce warmup steps                                |
|     3. Check gradient values (should be non-zero)         |
|     4. Verify optimizer is updating weights               |
|                                                           |
+-----------------------------------------------------------+
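Steps 3 and 4 can be automated by comparing parameter snapshots taken before and after one optimizer step. This sketch assumes a hypothetical snapshot format: a dict from parameter name to a flat list of floats. The parameter names in the example are made up. A parameter whose gradient is all zeros often never influences the loss. A parameter that has non-zero gradients but still does not change suggests it was never handed to the optimizer.

```python
def check_update(params_before, params_after, grads, tol=0.0):
    """Return (names with all-zero gradients, names whose values did not change).

    All arguments map a parameter name to a flat list of floats
    (a hypothetical snapshot format standing in for framework tensors).
    """
    zero_grad = [name for name, g in grads.items() if all(abs(x) <= tol for x in g)]
    unchanged = [
        name
        for name in params_before
        if all(abs(a - b) <= tol for a, b in zip(params_before[name], params_after[name]))
    ]
    return zero_grad, unchanged


before = {"layer1.weight": [0.5, -0.2], "layer2.weight": [1.0]}
grads = {"layer1.weight": [0.1, 0.3], "layer2.weight": [0.0]}
after = {"layer1.weight": [0.49, -0.23], "layer2.weight": [1.0]}
print(check_update(before, after, grads))  # (['layer2.weight'], ['layer2.weight'])
```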

LOSS EXPLODES (NaN)
+-----------------------------------------------------------+
|                                                           |
|   Symptoms:                                               |
|     - Loss suddenly becomes NaN                           |
|     - Training crashes                                    |
|                                                           |
|   Causes:                                                 |
|     - Learning rate too high                              |
|     - Missing gradient clipping                           |
|     - No warmup                                           |
|                                                           |
|   Debug steps:                                            |
|     1. Add gradient clipping (max_norm=1.0)               |
|     2. Reduce LR by 10x                                   |
|     3. Add warmup (5% of steps)                           |
|     4. Check for numerical issues in data                 |
|                                                           |
+-----------------------------------------------------------+
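For step 4, one simple check is to scan each batch for NaN or infinite values before it reaches the model. This sketch assumes a batch is a list of numeric feature lists, which stands in for real tensors.

```python
import math


def find_non_finite(batch):
    """Return indices of examples containing NaN or +/-inf (batch: list of feature lists)."""
    return [i for i, row in enumerate(batch) if not all(math.isfinite(x) for x in row)]


def loss_is_finite(loss_value):
    """Check to run before optimizer.step(): a NaN/inf loss means stop and debug."""
    return math.isfinite(loss_value)


batch = [[0.1, 2.0], [float("nan"), 1.0], [3.0, float("inf")]]
print(find_non_finite(batch))         # [1, 2]
print(loss_is_finite(float("nan")))   # False
```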

FINE-TUNING DESTROYS MODEL
+-----------------------------------------------------------+
|                                                           |
|   Symptoms:                                               |
|     - After fine-tuning, model is worse than base         |
|     - "Catastrophic forgetting"                           |
|                                                           |
|   Causes:                                                 |
|     - Learning rate too high                              |
|     - No warmup                                           |
|     - Weight decay too high                               |
|                                                           |
|   Debug steps:                                            |
|     1. Use much smaller LR (1e-5 or lower)                |
|     2. Add warmup                                         |
|     3. Reduce weight decay                                |
|     4. Consider LoRA instead of full fine-tuning          |
|                                                           |
+-----------------------------------------------------------+

When This Matters

Situation               What to know
Training transformers   Use AdamW, not Adam
Fine-tuning             LR 10-100x smaller than pre-training
Training unstable       Add warmup, gradient clipping
Loss not decreasing     Try higher LR
Loss exploding          Lower LR, add gradient clipping
Understanding configs   beta1 = momentum, beta2 = variance averaging
Choosing schedule       Cosine for pre-training, linear for fine-tuning