LayerNorm, BatchNorm, RMSNorm: what they do, when to use them, and Pre-Norm vs Post-Norm
TL;DR
Normalization keeps activations in a reasonable range during training. LayerNorm is the standard for transformers; RMSNorm is a faster variant used in modern LLMs like Llama. Pre-norm (normalize before each sublayer) is more stable than post-norm for deep networks.
Visual Overview
THE PROBLEM: INTERNAL COVARIATE SHIFT
+-----------------------------------------------------------+
| |
| During training, each layer's input distribution |
| changes as previous layers update. |
| |
| Epoch 1: Layer 3 receives inputs with mean=0, std=1 |
| Epoch 2: Layer 2 weights changed -> Layer 3 now sees |
| mean=0.5, std=2 |
| Epoch 3: More drift -> Layer 3 sees mean=1.2, std=3.5 |
| |
| Layer 3 keeps having to re-adapt to shifting inputs. |
| Training is slower and less stable. |
| |
| SOLUTION: Normalize activations to have consistent |
| statistics. |
| |
+-----------------------------------------------------------+
Batch Normalization
Normalizes across the batch dimension. Each feature is normalized using statistics from the current batch.
BATCH NORMALIZATION
+-----------------------------------------------------------+
| |
| Input: x with shape (batch_size, features) |
| |
| For each feature f: |
| mu_f = mean(x[:, f]) # mean across batch |
| sigma_f = std(x[:, f]) # std across batch |
| |
| x_norm[:, f] = (x[:, f] - mu_f) / (sigma_f + eps) |
| |
| Then apply learnable scale and shift: |
| output = gamma * x_norm + beta |
| |
| gamma and beta are learned per feature. |
| |
+-----------------------------------------------------------+
VISUAL: NORMALIZE ACROSS BATCH
+-----------------------------------------------------------+
| |
| Feature 1 Feature 2 Feature 3 |
| +----------+----------+----------+ |
| Batch 1| 2.1 | 0.5 | -1.2 | |
| +----------+----------+----------+ |
| Batch 2| 1.8 | 0.7 | -0.9 | <- Normalize |
| +----------+----------+----------+ down each |
| Batch 3| 2.3 | 0.4 | -1.1 | column |
| +----------+----------+----------+ |
| v v v |
| mu=2.07 mu=0.53 mu=-1.07 |
| |
+-----------------------------------------------------------+
When it works well:
- CNNs (computer vision)
- Large batch sizes (stable statistics)
- Training (has batch to compute stats)
Problems:
- Needs batch statistics at inference (use running average)
- Small batches -> noisy statistics -> unstable
- Batch size 1 -> undefined (no batch to normalize over)
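A minimal sketch of the column-wise computation above, checked against PyTorch's nn.BatchNorm1d in training mode (note that PyTorch adds eps to the variance inside the square root rather than to the standard deviation):
# Manual batch norm vs. nn.BatchNorm1d (training mode) -- a minimal sketch
import torch
import torch.nn as nn

x = torch.randn(32, 8)                            # (batch_size, features)

# Normalize each feature (column) with batch statistics
mu = x.mean(dim=0, keepdim=True)
var = x.var(dim=0, unbiased=False, keepdim=True)
x_manual = (x - mu) / torch.sqrt(var + 1e-5)

bn = nn.BatchNorm1d(8)                            # gamma=1, beta=0 at init
bn.train()
print(torch.allclose(x_manual, bn(x), atol=1e-5)) # True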
Layer Normalization
Normalizes across the feature dimension. Each sample is normalized independently.
LAYER NORMALIZATION
+-----------------------------------------------------------+
| |
| Input: x with shape (batch_size, features) |
| |
| For each sample i: |
| mu_i = mean(x[i, :]) # mean across features |
| sigma_i = std(x[i, :]) # std across features |
| |
| x_norm[i, :] = (x[i, :] - mu_i) / (sigma_i + eps) |
| |
| Then apply learnable scale and shift: |
| output = gamma * x_norm + beta |
| |
+-----------------------------------------------------------+
VISUAL: NORMALIZE ACROSS FEATURES
+-----------------------------------------------------------+
| |
| Feature 1 Feature 2 Feature 3 |
| +----------+----------+----------+ |
| Batch 1| 2.1 | 0.5 | -1.2 | -> Normalize |
| +----------+----------+----------+ this row |
| Batch 2| 1.8 | 0.7 | -0.9 | -> Normalize |
| +----------+----------+----------+ this row |
| Batch 3| 2.3 | 0.4 | -1.1 | -> Normalize |
| +----------+----------+----------+ this row |
| |
+-----------------------------------------------------------+
When it works well:
- Transformers (the standard)
- RNNs, LSTMs
- Any batch size (including 1)
- Inference (no batch dependency)
Why transformers use LayerNorm:
- Sequence length varies -> batch statistics meaningless
- Inference often batch_size=1
- Each token normalized independently
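The same kind of check for the row-wise computation, this time against nn.LayerNorm; it works even with batch_size=1:
# Manual layer norm vs. nn.LayerNorm -- a minimal sketch
import torch
import torch.nn as nn

x = torch.randn(1, 8)                             # works even with batch_size=1

# Normalize each sample (row) with its own feature statistics
mu = x.mean(dim=-1, keepdim=True)
var = x.var(dim=-1, unbiased=False, keepdim=True)
x_manual = (x - mu) / torch.sqrt(var + 1e-5)

ln = nn.LayerNorm(8)                              # gamma=1, beta=0 at init
print(torch.allclose(x_manual, ln(x), atol=1e-5)) # True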
RMSNorm (Root Mean Square Normalization)
Simplified LayerNorm: only variance normalization, no mean centering.
RMSNORM
+-----------------------------------------------------------+
| |
| Standard LayerNorm: |
| x_norm = (x - mean(x)) / std(x) |
| |
| RMSNorm: |
| x_norm = x / RMS(x) |
| |
| where RMS(x) = sqrt(mean(x^2)) |
| |
| No mean subtraction. Just scale by root-mean-square. |
| |
+-----------------------------------------------------------+
Why it works:
- Mean centering turns out to be less important than variance scaling
- Removing mean computation saves ~7% training time
- Quality is equivalent or better in practice
Used in: Llama, Llama 2, Mistral, most modern LLMs
# RMSNorm implementation
import torch

def rmsnorm(x, weight, eps=1e-6):
    # root-mean-square over the last (feature) dimension; eps for stability
    rms = torch.sqrt(torch.mean(x ** 2, dim=-1, keepdim=True) + eps)
    return weight * (x / rms)  # learned per-feature scale, no shift
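A quick sanity check of the function above (the shapes here are just illustrative): with a weight of all ones, every row of the output has an RMS close to 1.
# Sanity check: per-row RMS of the output is ~1 when weight is all ones
x = torch.randn(4, 512)                     # (batch, features) -- illustrative
out = rmsnorm(x, torch.ones(512))
print(out.pow(2).mean(dim=-1).sqrt())       # ~1.0 for each row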
Pre-Norm vs Post-Norm
Where you place normalization matters for training stability.
POST-NORM (Original Transformer)
+-----------------------------------------------------------+
| |
| x = x + Attention(x) |
| x = LayerNorm(x) <- Norm AFTER residual |
| x = x + FFN(x) |
| x = LayerNorm(x) |
| |
| Problem: Gradients must flow through LayerNorm |
| Can cause instability in deep networks |
| |
+-----------------------------------------------------------+
PRE-NORM (Modern Transformers)
+-----------------------------------------------------------+
| |
| x = x + Attention(LayerNorm(x)) <- Norm BEFORE |
| x = x + FFN(LayerNorm(x)) |
| |
| Advantages: |
| - Residual stream is "clean" (just additions) |
| - Gradients flow directly through residual path |
| - More stable for deep networks |
| - Easier to train without careful LR tuning |
| |
+-----------------------------------------------------------+
Which to use: pre-norm for new models (with a final LayerNorm/RMSNorm after the last block). Post-norm only if replicating the original Transformer or BERT.
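A minimal pre-norm block sketch in PyTorch (dimensions and module choices here are illustrative, not taken from any particular model):
# Pre-norm transformer block -- a sketch; dims and modules are illustrative
import torch
import torch.nn as nn

class PreNormBlock(nn.Module):
    def __init__(self, d_model=512, n_heads=8):
        super().__init__()
        self.norm1 = nn.LayerNorm(d_model)
        self.attn = nn.MultiheadAttention(d_model, n_heads, batch_first=True)
        self.norm2 = nn.LayerNorm(d_model)
        self.ffn = nn.Sequential(
            nn.Linear(d_model, 4 * d_model),
            nn.GELU(),
            nn.Linear(4 * d_model, d_model),
        )

    def forward(self, x):
        # Normalize BEFORE each sublayer; the residual stream stays a pure sum
        h = self.norm1(x)
        x = x + self.attn(h, h, h, need_weights=False)[0]
        x = x + self.ffn(self.norm2(x))
        return x

x = torch.randn(2, 16, 512)            # (batch, seq_len, d_model)
print(PreNormBlock()(x).shape)         # torch.Size([2, 16, 512])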
Comparison Table
| Aspect | BatchNorm | LayerNorm | RMSNorm |
|---|---|---|---|
| Normalizes across | Batch | Features | Features |
| Works with batch=1 | No | Yes | Yes |
| Needs running stats | Yes | No | No |
| Mean centering | Yes | Yes | No |
| Speed | Baseline | Baseline | ~7% faster |
| Used in | CNNs | Transformers | Modern LLMs |
Debugging Normalization Issues
TRAINING INSTABILITY (LOSS SPIKES)
+-----------------------------------------------------------+
| |
| Symptoms: |
| - Loss suddenly spikes during training |
| - Gradients explode intermittently |
| |
| Causes: |
| - Post-norm architecture with deep network |
| - Missing normalization somewhere |
| - Norm placed incorrectly |
| |
| Debug steps: |
| 1. Switch to pre-norm if using post-norm |
| 2. Check every sublayer has normalization |
| 3. Verify norm is before attention/FFN, not after |
| 4. Reduce learning rate |
| |
+-----------------------------------------------------------+
ACTIVATIONS GROWING UNBOUNDED
+-----------------------------------------------------------+
| |
| Symptoms: |
| - Activation magnitudes grow over layers |
| - Eventually overflow to NaN |
| |
| Causes: |
| - Missing normalization layer |
| - Residual accumulation without norm |
| - Wrong norm dimension |
| |
| Debug steps: |
| 1. Print activation statistics per layer |
| 2. Verify norm is applied (gamma, beta params exist) |
| 3. Check norm dimension matches input shape |
| |
+-----------------------------------------------------------+
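For debug step 1 above (printing activation statistics per layer), one approach is to register forward hooks on the modules you care about; a minimal sketch:
# Per-layer activation statistics via forward hooks -- a debugging sketch
import torch
import torch.nn as nn

def add_stats_hooks(model):
    def make_hook(name):
        def hook(module, inputs, output):
            if isinstance(output, torch.Tensor):
                print(f"{name:12s} mean={output.mean().item():+.3f} "
                      f"std={output.std().item():.3f} "
                      f"max|x|={output.abs().max().item():.3f}")
        return hook
    for name, module in model.named_modules():
        if isinstance(module, (nn.Linear, nn.LayerNorm)):
            module.register_forward_hook(make_hook(name))

# Toy stack -- in a real model, watch for magnitudes growing layer over layer
model = nn.Sequential(nn.Linear(64, 64), nn.LayerNorm(64), nn.Linear(64, 64))
add_stats_hooks(model)
model(torch.randn(8, 64))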
When This Matters
| Situation | What to know |
|---|---|
| Reading transformer code | LayerNorm before attention/FFN (pre-norm) |
| Understanding Llama/Mistral | RMSNorm, not LayerNorm |
| Training instability | Switch to pre-norm, check norm placement |
| Batch size constraints | LayerNorm works with any batch size |
| Optimizing inference speed | RMSNorm is slightly faster |
| Porting CNN techniques | BatchNorm doesn’t work for transformers |
| Understanding model configs | "norm_eps" is the epsilon in the denominator |