
AI Engineering Series

Embeddings to Attention - Relating Tokens to Each Other

Deep dive into attention mechanisms: why transformers replaced RNNs, scaled dot-product attention, multi-head attention, and how context length affects performance


Building On Previous Knowledge

In the previous progression, you learned how tokens become embeddings—vectors that capture meaning. Each token has its own embedding vector.

But there’s a problem: embeddings are independent. The vector for “bank” doesn’t know whether it’s about finance or rivers until it sees the surrounding words.

Attention solves this by letting each token’s representation incorporate information from other tokens. After attention, “bank” in “river bank” has a different representation than “bank” in “savings bank”—because it attended to different context.

What Goes Wrong Without This:

Symptom: Your model truncates long documents and misses important information.
Cause: You treated context as infinite. Attention is O(n²) in memory.
       128K context doesn't mean you can use 128K without consequences.

Symptom: Model gives inconsistent answers to the same question.
Cause: In long contexts, attention can miss relevant information.
       "Lost in the middle" - models attend more to beginning and end.

Symptom: Reasoning fails on complex multi-step problems.
Cause: Attention struggles to carry information across many hops.
       Each hop through attention layers is lossy.

Why Attention Matters

Before attention, sequence models used recurrence (RNNs, LSTMs):

Process sequentially:
  token_1 → state_1 → token_2 → state_2 → ... → token_n → state_n

Problems:
  1. Can't parallelize (each step depends on previous)
  2. Long-range dependencies are hard (gradient vanishing)
  3. Information bottleneck (fixed-size state)

A 1000-word document must compress through a single state vector.

Attention allows direct connections:

Every token can directly access every other token:

  token_1   token_2   token_3   ...   token_n
     ↑         ↑         ↑              ↑
     └─────────┴─────────┴──────────────┘
            All pairwise connections

Benefits:
  1. Fully parallelizable (all attention computed at once)
  2. Direct long-range access (no bottleneck)
  3. Dynamic weighting (attend more to relevant tokens)

This is why Transformers replaced RNNs everywhere.


The Core Idea: Weighted Mixing

Attention is surprisingly simple at its core:

Input: Sequence of token embeddings [v1, v2, v3, v4]

For each token, compute a new representation by
MIXING all tokens weighted by relevance:

  new_v2 = 0.1*v1 + 0.6*v2 + 0.2*v3 + 0.1*v4
            ↑        ↑        ↑        ↑
            weights sum to 1.0 (softmax)

The weights (attention scores) determine how much
each token contributes to the new representation.
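
A tiny sketch of this mixing step in PyTorch (the vectors and weights below are made up purely for illustration):

import torch

# Four toy token embeddings (4 tokens, dimension 3) -- arbitrary values
v1 = torch.tensor([1.0, 0.0, 0.0])
v2 = torch.tensor([0.0, 1.0, 0.0])
v3 = torch.tensor([0.0, 0.0, 1.0])
v4 = torch.tensor([1.0, 1.0, 0.0])

# Attention weights for token 2 (they sum to 1.0)
new_v2 = 0.1*v1 + 0.6*v2 + 0.2*v3 + 0.1*v4
print(new_v2)  # tensor([0.2000, 0.7000, 0.2000])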

For the sentence “The cat sat on the mat”:

When processing "sat":
  - High attention to "cat" (subject of sat)
  - Medium attention to "mat" (related to sitting)
  - Low attention to "the" (less informative)

Result: "sat" embedding now contains information
about WHAT sat (cat) and WHERE (mat).

Query, Key, Value

The Q, K, V framework formalizes how attention scores are computed:

+------------------------------------------------------------------+
|  INTUITION: Library Metaphor                                      |
+------------------------------------------------------------------+
|                                                                   |
|  Query (Q): What am I looking for?                                |
|    "I need books about machine learning"                          |
|                                                                   |
|  Key (K):   What does each item contain?                          |
|    Book 1: "Introduction to AI"                                   |
|    Book 2: "Cooking recipes"                                      |
|    Book 3: "Deep Learning fundamentals"                           |
|                                                                   |
|  Value (V): The actual content to retrieve                        |
|    The book's actual contents                                     |
|                                                                   |
|  Match Query against Keys → Weight Values by match quality       |
|                                                                   |
+------------------------------------------------------------------+

In practice, Q, K, V are linear projections of the input embeddings:

Input embedding: x (dimension d_model)

Q = x @ W_Q    # project to query space
K = x @ W_K    # project to key space
V = x @ W_V    # project to value space

Where W_Q, W_K, W_V are learned weight matrices.

Each token gets its own Q, K, V vectors.
Token i's query asks: "What should I attend to?"
Token j's key advertises: "Here's what I contain"
Token j's value provides: "Here's my information if you want it"
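
A minimal sketch of these projections in PyTorch; the sizes are illustrative, and the weight matrices here are freshly initialized rather than learned:

import torch
import torch.nn as nn

d_model, d_k, d_v = 64, 64, 64        # illustrative dimensions
n = 6                                 # sequence length

x = torch.randn(n, d_model)           # stand-in token embeddings

W_Q = nn.Linear(d_model, d_k, bias=False)   # learned in a real model
W_K = nn.Linear(d_model, d_k, bias=False)
W_V = nn.Linear(d_model, d_v, bias=False)

Q, K, V = W_Q(x), W_K(x), W_V(x)      # one Q, K, V vector per token
print(Q.shape, K.shape, V.shape)      # torch.Size([6, 64]) each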

Scaled Dot-Product Attention

The standard attention formula:

Attention(Q, K, V) = softmax(QK^T / √d_k) V

Let's break this down:

Step 1: Compute Attention Scores

scores = Q @ K^T

For a sequence of n tokens, each with d_k dimensional Q and K:
  Q: (n, d_k)
  K: (n, d_k)
  K^T: (d_k, n)
  Q @ K^T: (n, n)  ← attention scores matrix

scores[i][j] = how much token i should attend to token j

Step 2: Scale

scaled_scores = scores / √d_k

Why scale?
  Dot products grow with dimension size.
  Large dot products → softmax becomes very peaked
  → gradients vanish (all weight on one token)

  √d_k keeps variance stable regardless of dimension.
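
You can check the variance claim empirically; a quick sketch with random unit-variance vectors (the dimensions are arbitrary):

import torch

torch.manual_seed(0)
for d_k in (16, 64, 256):
    q = torch.randn(10_000, d_k)            # components have variance 1
    k = torch.randn(10_000, d_k)
    dots = (q * k).sum(dim=-1)              # raw dot products
    print(d_k, dots.std().item(), (dots / d_k**0.5).std().item())
# std of raw dot products grows like sqrt(d_k); after scaling it stays near 1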

Step 3: Softmax

attention_weights = softmax(scaled_scores)

Softmax converts scores to probabilities:
  - All values between 0 and 1
  - Each row sums to 1.0
  - High scores → high weights, low scores → near zero

Example row:    [2.1, 0.5, -1.0, 0.8]
After softmax:  [0.66, 0.13, 0.03, 0.18]
                  ↑
                  Token with score 2.1 gets most attention
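
To reproduce that row (values shown rounded):

import torch
import torch.nn.functional as F

scores = torch.tensor([2.1, 0.5, -1.0, 0.8])
print(F.softmax(scores, dim=-1))   # ≈ tensor([0.6581, 0.1329, 0.0296, 0.1794])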

Step 4: Weighted Sum

output = attention_weights @ V

Each output vector is a weighted combination of all value vectors:
  output_i = Σ (attention_weight[i][j] * V[j])

This is where information actually flows between tokens.
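
As a sanity check, the matrix product in Step 4 is exactly that row-by-row weighted sum; a small sketch with random values:

import torch

n, d_v = 4, 3
W = torch.softmax(torch.randn(n, n), dim=-1)      # toy attention weights (rows sum to 1)
V = torch.randn(n, d_v)                           # toy value vectors

out_matmul = W @ V                                # Step 4 as one matrix product
out_manual = torch.stack(                         # explicit sum_j W[i][j] * V[j]
    [sum(W[i, j] * V[j] for j in range(n)) for i in range(n)]
)
print(torch.allclose(out_matmul, out_manual))     # True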

Complete Picture

+------------------------------------------------------------------+
|               SCALED DOT-PRODUCT ATTENTION                        |
+------------------------------------------------------------------+
|                                                                    |
|    Q (n×d_k)        K (n×d_k)                                      |
|        │                │                                          |
|        │      ┌─────────┘                                          |
|        │      │ (transpose)                                        |
|        ▼      ▼                                                    |
|     ┌────────────┐                                                 |
|     │   MatMul   │  Q @ K^T = (n×n) attention scores               |
|     └─────┬──────┘                                                 |
|           │                                                        |
|           ▼                                                        |
|     ┌────────────┐                                                 |
|     │   Scale    │  divide by √d_k                                 |
|     └─────┬──────┘                                                 |
|           │                                                        |
|           ▼                                                        |
|     ┌────────────┐                                                 |
|     │  Softmax   │  convert to probabilities (each row)            |
|     └─────┬──────┘                                                 |
|           │                                                        |
|           │      V (n×d_v)                                         |
|           │         │                                              |
|           ▼         ▼                                              |
|     ┌────────────────────┐                                         |
|     │      MatMul        │  weights @ V = output (n×d_v)           |
|     └─────────┬──────────┘                                         |
|               │                                                    |
|               ▼                                                    |
|         Output (n×d_v)                                             |
|                                                                    |
+------------------------------------------------------------------+

Multi-Head Attention

One attention pattern isn’t enough. Different relationships need different attention:

"The animal didn't cross the street because it was too tired."

Different questions need different attention patterns:
  - Q: What is "it"?       → attend "it" to "animal" (coreference)
  - Q: What action?        → attend verbs to subjects
  - Q: What's the reason?  → attend "tired" to "didn't cross"

Solution: Multiple attention "heads", each learning different patterns.

Multi-head attention runs h parallel attention operations:

+------------------------------------------------------------------+
|               MULTI-HEAD ATTENTION                                |
+------------------------------------------------------------------+
|                                                                    |
|   Input X                                                          |
|      │                                                             |
|      ├───────────────┬───────────────┬─────────────────┐           |
|      │               │               │                 │           |
|   ┌──▼──┐         ┌──▼──┐         ┌──▼──┐          ┌───▼──┐        |
|   │Head1│         │Head2│         │Head3│   ...    │Head_h│        |
|   │ QKV │         │ QKV │         │ QKV │          │ QKV  │        |
|   └──┬──┘         └──┬──┘         └──┬──┘          └───┬──┘        |
|      │               │               │                 │           |
|   (n,d_v/h)       (n,d_v/h)       (n,d_v/h)        (n,d_v/h)       |
|      │               │               │                 │           |
|      └───────────────┴───────────────┴─────────────────┘           |
|                          │                                         |
|                    ┌─────▼─────┐                                   |
|                    │  Concat   │  Combine all heads                |
|                    └─────┬─────┘                                   |
|                          │                                         |
|                    ┌─────▼─────┐                                   |
|                    │   W_O     │  Project back to d_model          |
|                    └─────┬─────┘                                   |
|                          │                                         |
|                     Output (n, d_model)                            |
|                                                                    |
+------------------------------------------------------------------+

Typical configurations:

+------------------+---------------+---------------+
|  Model           |  d_model      |  Heads (h)    |
+------------------+---------------+---------------+
|  BERT-base       |  768          |  12           |
|  GPT-2           |  768          |  12           |
|  GPT-3 (175B)    |  12288        |  96           |
|  LLaMA 7B        |  4096         |  32           |
+------------------+---------------+---------------+

Each head has d_k = d_model / h dimensions.
More heads = more diverse attention patterns.
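
If you just want the mechanics, PyTorch's built-in module performs the per-head split, the h parallel attentions, the concat, and the W_O projection internally; a minimal self-attention sketch with illustrative sizes:

import torch
import torch.nn as nn

d_model, n_heads, seq_len = 64, 8, 10
x = torch.randn(2, seq_len, d_model)                 # (batch, n, d_model)

mha = nn.MultiheadAttention(embed_dim=d_model, num_heads=n_heads, batch_first=True)

# Self-attention: the same sequence supplies queries, keys, and values
out, attn = mha(x, x, x)
print(out.shape)    # torch.Size([2, 10, 64])  -- projected back to d_model
print(attn.shape)   # torch.Size([2, 10, 10])  -- weights, averaged over heads by default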

Context Window and Attention

The context window limit exists because attention is O(n²):

For sequence length n:
  - Attention matrix: n × n
  - Memory: O(n²)
  - Compute: O(n²)

+------------------+---------------+---------------+
|  Context Length  |  Attention    |  Memory (fp32)|
+------------------+---------------+---------------+
|  1K tokens       |  1M entries   |  ~4 MB        |
|  4K tokens       |  16M entries  |  ~64 MB       |
|  32K tokens      |  1B entries   |  ~4 GB        |
|  128K tokens     |  16B entries  |  ~64 GB       |
+------------------+---------------+---------------+

This is why long-context models are expensive.
128K context doesn't mean free 128K—it means 128K² cost.
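
The table is easy to reproduce; a back-of-the-envelope sketch assuming one 4-byte float32 score per token pair, for a single attention matrix (ignoring heads, layers, and activations):

def attention_matrix_gb(n_tokens: int, bytes_per_entry: int = 4) -> float:
    """Memory in GB for one full n x n attention score matrix."""
    return n_tokens ** 2 * bytes_per_entry / 1e9

for n in (1_000, 4_000, 32_000, 128_000):
    print(f"{n:>7} tokens -> {attention_matrix_gb(n):8.3f} GB")
# ~0.004, 0.064, 4.096, 65.536 GB -- the quadratic blow-up shown in the table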

Techniques for Longer Context

1. Sparse Attention

Instead of n² full attention, attend to subset:
  - Local attention: only nearby tokens
  - Strided attention: every k-th token
  - Random attention: sample positions

BigBird, Longformer use O(n) attention patterns.
Trade: some information paths are blocked.

2. Flash Attention

Not mathematically different—same result.
But implements attention in a memory-efficient way:
  - Never materializes full n×n matrix
  - Computes in tiles that fit in GPU SRAM
  - 2-4x faster; memory grows linearly with sequence length, not quadratically

This is why modern context windows keep growing.
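
You normally don't call FlashAttention yourself; in PyTorch 2.x, torch.nn.functional.scaled_dot_product_attention computes the same result as the explicit formula and dispatches to a fused, memory-efficient kernel when one is available for your hardware and dtype (a sketch, not a benchmark):

import torch
import torch.nn.functional as F

q = torch.randn(1, 8, 4096, 64)   # (batch, heads, seq_len, head_dim)
k = torch.randn(1, 8, 4096, 64)
v = torch.randn(1, 8, 4096, 64)

# Same math as softmax(QK^T / sqrt(d_k)) V, but a fused kernel avoids
# materializing the full 4096 x 4096 score matrix when one is available
out = F.scaled_dot_product_attention(q, k, v, is_causal=True)
print(out.shape)   # torch.Size([1, 8, 4096, 64])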

3. Sliding Window / RoPE

Combine:
  - Rotary Position Embeddings (RoPE) for relative positions
  - Sliding window for bounded attention
  - Global tokens that always attend everywhere

LLaMA, Mistral use these patterns.
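
RoPE itself is a rotation applied to Q and K and isn't shown here, but the sliding-window part is just a banded mask; a toy sketch (window size and shapes are arbitrary):

import torch
import torch.nn.functional as F

def sliding_window_causal_mask(n: int, window: int) -> torch.Tensor:
    """Boolean mask: token i may attend to tokens j with i - window < j <= i."""
    i = torch.arange(n).unsqueeze(1)   # (n, 1)
    j = torch.arange(n).unsqueeze(0)   # (1, n)
    return (j <= i) & (j > i - window)

mask = sliding_window_causal_mask(n=8, window=3)
print(mask.int())                      # banded lower-triangular pattern

q = k = v = torch.randn(1, 1, 8, 16)   # (batch, heads, seq_len, head_dim)
out = F.scaled_dot_product_attention(q, k, v, attn_mask=mask)  # True = may attend
print(out.shape)                       # torch.Size([1, 1, 8, 16])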

The “Lost in the Middle” Problem

Even with long context, attention has limitations:

+------------------------------------------------------------------+
|  Position in context vs attention received                        |
|                                                                   |
|  Attention │                                                      |
|    Score   │  ████                                                |
|            │  ████                          ████                  |
|            │  ████████                ████████████                |
|            │  ████████████████  ████████████████████████          |
|            └──────────────────────────────────────────────        |
|            Beginning        Middle              End               |
|                                                                   |
|  Beginning and end get more attention.                            |
|  Middle content can be "lost."                                    |
+------------------------------------------------------------------+

Practical impact:
  - Put critical information at beginning or end of prompts
  - Don't bury important context in the middle of long documents
  - Test your application with information at different positions

Attention Visualization

What attention patterns look like:

"The cat sat on the mat"

Attention weights (simplified, one head):

           The   cat   sat   on    the   mat
    The   [0.3   0.2   0.1   0.1   0.2   0.1]
    cat   [0.2   0.4   0.2   0.0   0.1   0.1]
    sat   [0.1   0.5   0.2   0.1   0.0   0.1]   ← "sat" attends heavily to "cat"
    on    [0.1   0.1   0.3   0.2   0.1   0.2]
    the   [0.1   0.1   0.1   0.2   0.3   0.2]
    mat   [0.1   0.1   0.2   0.2   0.2   0.2]

Different heads learn different patterns:
  - Head 1: Subject-verb relationships
  - Head 2: Positional (nearby tokens)
  - Head 3: Syntactic structure
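
To look at real numbers instead of the hand-written matrix above, you can print the softmax-ed score matrix row by row; with random, untrained projections (as here) the pattern won't be as interpretable as the idealized one:

import torch
import torch.nn.functional as F

tokens = ["The", "cat", "sat", "on", "the", "mat"]
d_k = 16

torch.manual_seed(0)
Q = torch.randn(len(tokens), d_k)   # stand-ins for projected queries
K = torch.randn(len(tokens), d_k)   # stand-ins for projected keys

weights = F.softmax(Q @ K.T / d_k**0.5, dim=-1)   # (6, 6) attention matrix

for tok, row in zip(tokens, weights):
    print(f"{tok:>4}: " + "  ".join(f"{w:.2f}" for w in row))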

Code Example

Minimal implementation of scaled dot-product attention:

from typing import Optional

import torch
import torch.nn.functional as F

def scaled_dot_product_attention(
    Q: torch.Tensor,  # (batch, n, d_k)
    K: torch.Tensor,  # (batch, n, d_k)
    V: torch.Tensor,  # (batch, n, d_v)
    mask: Optional[torch.Tensor] = None,  # optional mask (0 = block this position)
) -> torch.Tensor:
    """
    Compute scaled dot-product attention.

    Returns:
        Output tensor of shape (batch, n, d_v)
    """
    d_k = Q.size(-1)

    # Step 1: Compute attention scores
    scores = torch.matmul(Q, K.transpose(-2, -1))  # (batch, n, n)

    # Step 2: Scale
    scores = scores / (d_k ** 0.5)

    # Optional: Apply mask (for causal/padding)
    if mask is not None:
        scores = scores.masked_fill(mask == 0, float('-inf'))

    # Step 3: Softmax to get attention weights
    attention_weights = F.softmax(scores, dim=-1)  # (batch, n, n)

    # Step 4: Weighted sum of values
    output = torch.matmul(attention_weights, V)  # (batch, n, d_v)

    return output

# Example usage
batch_size, seq_len, d_model = 2, 10, 64

# Random Q, K, V (in practice, these come from linear projections)
Q = torch.randn(batch_size, seq_len, d_model)
K = torch.randn(batch_size, seq_len, d_model)
V = torch.randn(batch_size, seq_len, d_model)

output = scaled_dot_product_attention(Q, K, V)
print(f"Output shape: {output.shape}")  # (2, 10, 64)

Key Takeaways

1. Attention lets tokens incorporate information from all other tokens

2. Q, K, V are projections that define what to attend to and what to retrieve

3. Scaled dot-product attention: softmax(QK^T / √d_k) @ V

4. Multi-head attention runs h parallel attention operations
   - Each head can learn different relationship patterns

5. Context window limits exist because attention is O(n²)
   - 128K context = 128K² computation

6. "Lost in the middle" is real
   - Critical information should be at beginning or end

Verify Your Understanding

Before proceeding, you should be able to:

Write out the attention formula and explain each component — What does the softmax do? Why scale by √d_k? What does multiplying by V accomplish?

Explain why multi-head attention is better than single-head — Give a concrete example of different “types” of relationships different heads might learn.

Your LLM has 128K context but struggles to answer questions about content in the middle. What’s happening? How would you restructure your prompt?

Calculate the memory required for full attention with 32K tokens at float16 precision. How does this change with 64K tokens?


What’s Next

After this, you can:

  • Continue → Attention → Generation — how models produce text token by token
  • Go deeper → Explore transformer architectures, pre-training objectives