Transformer Attention Patterns

This notebook provides interactive visualizations of the self-attention mechanism in Transformers, showing how attention scores are computed, how multi-head attention works, and how positional encodings inject order information.

In [1]:
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from matplotlib.patches import Rectangle
from matplotlib.gridspec import GridSpec
import warnings
warnings.filterwarnings('ignore')

# Set style for better visualizations
plt.style.use('seaborn-v0_8-whitegrid')
sns.set_palette("husl")
plt.rcParams['figure.figsize'] = (15, 10)
plt.rcParams['font.size'] = 11

1. Understanding Self-Attention

Self-attention is the core mechanism that allows Transformers to process sequences without recurrence. Each position in the sequence can directly attend to every other position, computing relevance scores through dot products of learned query and key vectors.

What to observe:

  • Notice how each word (row) distributes its attention across all words (columns)
  • The diagonal often has high values (self-attention), but words also attend to semantically related words
  • The attention weights sum to 1 for each row (softmax normalization)

Significance: This mechanism allows the model to capture both local and long-range dependencies in a single operation, unlike RNNs, which must process sequences step by step.

In [2]:
def scaled_dot_product_attention(Q, K, V, mask=None):
    """
    Compute scaled dot-product attention.
    
    Args:
        Q: Query matrix (seq_len, d_k)
        K: Key matrix (seq_len, d_k)
        V: Value matrix (seq_len, d_v)
        mask: Optional mask (seq_len, seq_len)
    
    Returns:
        output: Attention output (seq_len, d_v)
        attention_weights: Attention weights (seq_len, seq_len)
    """
    d_k = Q.shape[-1]
    
    # Compute attention scores
    scores = np.matmul(Q, K.T) / np.sqrt(d_k)
    
    # Apply mask if provided
    if mask is not None:
        scores = scores + (mask * -1e9)
    
    # Apply softmax
    attention_weights = np.exp(scores - np.max(scores, axis=-1, keepdims=True))
    attention_weights = attention_weights / np.sum(attention_weights, axis=-1, keepdims=True)
    
    # Apply attention to values
    output = np.matmul(attention_weights, V)
    
    return output, attention_weights

# Example: Simple attention on a sentence
sentence = "The cat sat on the mat"
words = sentence.split()
seq_len = len(words)
d_model = 64

# Create random embeddings for demonstration
np.random.seed(42)
embeddings = np.random.randn(seq_len, d_model)

# Compute Q, K, V (in practice, these would be learned projections)
Q = embeddings @ np.random.randn(d_model, d_model)
K = embeddings @ np.random.randn(d_model, d_model)
V = embeddings @ np.random.randn(d_model, d_model)

# Compute attention
output, attention_weights = scaled_dot_product_attention(Q, K, V)

# Visualize attention weights
plt.figure(figsize=(8, 6))
sns.heatmap(attention_weights, annot=True, fmt='.2f', cmap='Blues',
            xticklabels=words, yticklabels=words, cbar_kws={'label': 'Attention Weight'})
plt.title('Self-Attention Weights: "' + sentence + '"', fontsize=14, fontweight='bold')
plt.xlabel('Keys (Attended to)', fontsize=12)
plt.ylabel('Queries (Attending from)', fontsize=12)
plt.tight_layout()
plt.show()
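
The bullet above notes that each row of the attention matrix is a probability distribution. As a quick sanity check, a minimal sketch reusing the attention_weights computed in the cell above:

In [ ]:
# Verify the softmax property: each query's attention distribution sums to 1
row_sums = attention_weights.sum(axis=-1)
print("Row sums:", np.round(row_sums, 4))
print("All rows sum to 1:", np.allclose(row_sums, 1.0))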

2. Visualizing Attention Patterns

Different attention patterns emerge in transformers for different purposes. Understanding these patterns helps explain how transformers process information and what inductive biases can be built into the architecture.

What to observe:

  • Uniform: Equal attention to all positions - useful for aggregating global information
  • Self-Focus: Strong diagonal pattern - preserves position-specific information
  • Causal: Lower triangular - ensures autoregressive property for generation
  • Local: Window-based attention - captures local dependencies efficiently
  • Strided: Sparse attention pattern - reduces computational complexity
  • Global+Local: Hybrid pattern - combines local context with global anchors

Significance: These patterns represent different strategies for information flow. Modern efficient transformers often use combinations of these patterns to balance expressiveness with computational efficiency.

In [3]:
def create_attention_patterns():
    """Create different types of attention patterns."""
    seq_len = 10
    patterns = {}
    
    # 1. Uniform attention
    patterns['Uniform'] = np.ones((seq_len, seq_len)) / seq_len
    
    # 2. Diagonal (self) attention
    patterns['Self-Focus'] = np.eye(seq_len) * 0.8 + np.ones((seq_len, seq_len)) * 0.02
    
    # 3. Lower triangular (causal) attention
    causal = np.tril(np.ones((seq_len, seq_len)))
    patterns['Causal'] = causal / causal.sum(axis=1, keepdims=True)
    
    # 4. Local attention (attending to nearby positions)
    local = np.zeros((seq_len, seq_len))
    window_size = 3
    for i in range(seq_len):
        start = max(0, i - window_size // 2)
        end = min(seq_len, i + window_size // 2 + 1)
        local[i, start:end] = 1
    patterns['Local'] = local / local.sum(axis=1, keepdims=True)
    
    # 5. Strided attention
    strided = np.zeros((seq_len, seq_len))
    for i in range(seq_len):
        strided[i, ::2] = 1  # Attend to every other position
    patterns['Strided'] = strided / strided.sum(axis=1, keepdims=True)
    
    # 6. Global + Local attention
    global_local = local.copy()
    global_local[:, 0] = 1  # All positions attend to first position
    global_local[0, :] = 1  # First position attends to all
    patterns['Global+Local'] = global_local / global_local.sum(axis=1, keepdims=True)
    
    return patterns

# Create and visualize patterns
patterns = create_attention_patterns()

fig, axes = plt.subplots(2, 3, figsize=(15, 10))
axes = axes.flatten()

for idx, (name, pattern) in enumerate(patterns.items()):
    ax = axes[idx]
    im = ax.imshow(pattern, cmap='Blues', aspect='auto', vmin=0, vmax=0.5)
    ax.set_title(f'{name} Attention', fontsize=12, fontweight='bold')
    ax.set_xlabel('Key Position', fontsize=10)
    ax.set_ylabel('Query Position', fontsize=10)
    ax.set_xticks(range(10))
    ax.set_yticks(range(10))
    plt.colorbar(im, ax=ax, fraction=0.046, pad=0.04)

plt.suptitle('Common Attention Patterns in Transformers', fontsize=16, fontweight='bold')
plt.tight_layout()
plt.show()
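
The same patterns can also be imposed on learned attention as hard masks rather than hand-built weight matrices. A minimal sketch, assuming the mask convention of scaled_dot_product_attention above (1 marks a blocked position) and purely illustrative random Q/K/V:

In [ ]:
# Local-window attention via masking: block every key more than `win` steps away
win_len, win, d_k = 10, 1, 16   # illustrative sizes
idx = np.arange(win_len)
local_mask = (np.abs(idx[:, None] - idx[None, :]) > win).astype(float)  # 1 = blocked

Qw, Kw, Vw = (np.random.randn(win_len, d_k) for _ in range(3))
_, local_attn = scaled_dot_product_attention(Qw, Kw, Vw, mask=local_mask)

# Outside the window the attention weights are (numerically) zero
print("Max weight outside window:", local_attn[local_mask.astype(bool)].max())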

3. Multi-Head Attention

Multi-head attention allows the model to jointly attend to information from different representation subspaces. Instead of having a single attention function, the model uses multiple "heads" that can learn different types of relationships.

What to observe:

  • Each head shows a different attention pattern - some might be more local, others more distributed
  • The patterns are learned during training and often specialize for different linguistic phenomena
  • Notice how different heads might focus on different aspects of the sequence

Significance: Multi-head attention is crucial for transformer expressiveness. Different heads can capture different types of relationships (syntactic, semantic, positional) simultaneously, which a single attention head cannot do as effectively.

In [4]:
class MultiHeadAttention:
    def __init__(self, d_model, n_heads):
        """
        Multi-Head Attention mechanism.
        
        Args:
            d_model: Model dimension
            n_heads: Number of attention heads
        """
        self.d_model = d_model
        self.n_heads = n_heads
        self.d_k = d_model // n_heads
        
        # Initialize weight matrices
        self.W_q = np.random.randn(n_heads, d_model, self.d_k) * 0.1
        self.W_k = np.random.randn(n_heads, d_model, self.d_k) * 0.1
        self.W_v = np.random.randn(n_heads, d_model, self.d_k) * 0.1
        self.W_o = np.random.randn(n_heads * self.d_k, d_model) * 0.1
    
    def forward(self, x):
        """
        Forward pass of multi-head attention.
        
        Args:
            x: Input tensor (seq_len, d_model)
        
        Returns:
            output: Output tensor (seq_len, d_model)
            attention_weights: Attention weights for each head
        """
        seq_len = x.shape[0]
        attention_weights = []
        head_outputs = []
        
        for h in range(self.n_heads):
            # Project to Q, K, V
            Q = x @ self.W_q[h]
            K = x @ self.W_k[h]
            V = x @ self.W_v[h]
            
            # Compute attention for this head
            head_out, attn_weights = scaled_dot_product_attention(Q, K, V)
            head_outputs.append(head_out)
            attention_weights.append(attn_weights)
        
        # Concatenate heads
        concat_output = np.concatenate(head_outputs, axis=-1)
        
        # Final projection
        output = concat_output @ self.W_o
        
        return output, attention_weights

# Example: Multi-head attention on a sequence
seq_len = 8
d_model = 64
n_heads = 4

# Create sample input
x = np.random.randn(seq_len, d_model)

# Apply multi-head attention
mha = MultiHeadAttention(d_model, n_heads)
output, attention_weights_heads = mha.forward(x)

# Visualize attention weights for each head
fig, axes = plt.subplots(1, n_heads, figsize=(16, 4))

for h in range(n_heads):
    axes[h].imshow(attention_weights_heads[h], cmap='Blues', aspect='auto')
    axes[h].set_title(f'Head {h+1}', fontsize=12)
    axes[h].set_xlabel('Key Position', fontsize=10)
    if h == 0:
        axes[h].set_ylabel('Query Position', fontsize=10)
    axes[h].set_xticks(range(seq_len))
    axes[h].set_yticks(range(seq_len))

plt.suptitle('Multi-Head Attention: Different Heads Learn Different Patterns', 
             fontsize=14, fontweight='bold')
plt.tight_layout()
plt.show()
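
The loop over heads above is written for clarity; in practice the heads are usually computed in parallel with single projection matrices and a reshape of the feature dimension. A vectorized sketch of the same idea (illustrative, not the exact weight layout of the class above):

In [ ]:
# Vectorized multi-head attention: one big projection per Q/K/V, then split the
# feature dimension into (n_heads, d_k) and attend over all heads in parallel.
def multi_head_attention_vectorized(x, W_q, W_k, W_v, W_o, n_heads):
    seq_len, d_model = x.shape
    d_k = d_model // n_heads

    def split_heads(t):
        # (seq_len, d_model) -> (n_heads, seq_len, d_k)
        return t.reshape(seq_len, n_heads, d_k).transpose(1, 0, 2)

    Q, K, V = split_heads(x @ W_q), split_heads(x @ W_k), split_heads(x @ W_v)
    scores = Q @ K.transpose(0, 2, 1) / np.sqrt(d_k)             # (n_heads, seq_len, seq_len)
    weights = np.exp(scores - scores.max(axis=-1, keepdims=True))
    weights = weights / weights.sum(axis=-1, keepdims=True)
    heads = weights @ V                                           # (n_heads, seq_len, d_k)
    concat = heads.transpose(1, 0, 2).reshape(seq_len, d_model)   # concatenate heads
    return concat @ W_o, weights

W_q, W_k, W_v, W_o = (np.random.randn(d_model, d_model) * 0.1 for _ in range(4))
out_vec, attn_vec = multi_head_attention_vectorized(x, W_q, W_k, W_v, W_o, n_heads)
print(out_vec.shape, attn_vec.shape)  # expected: (8, 64) (4, 8, 8)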

4. Positional Encoding

Since the attention operation itself has no inherent notion of token order (permuting the input tokens simply permutes the outputs), transformers need a way to incorporate position information. Positional encodings add position-dependent signals to the embeddings using sinusoidal functions.

What to observe:

  • The alternating patterns in the encoding matrix - different frequencies for different dimensions
  • The similarity matrix shows how positional encodings relate across positions
  • Notice the periodic nature of the sinusoidal functions at different scales
  • The similarity decreases smoothly with distance between positions

Significance: Positional encodings allow the model to use position information without learning position-specific parameters. The sinusoidal pattern enables the model to extrapolate to sequence lengths not seen during training.

In [5]:
def get_positional_encoding(seq_len, d_model):
    """
    Generate positional encoding for a sequence.
    
    Args:
        seq_len: Sequence length
        d_model: Model dimension
    
    Returns:
        PE: Positional encoding matrix (seq_len, d_model)
    """
    PE = np.zeros((seq_len, d_model))
    position = np.arange(seq_len).reshape(-1, 1)
    
    # Create div_term for the sinusoidal pattern
    div_term = np.exp(np.arange(0, d_model, 2) * -(np.log(10000.0) / d_model))
    
    # Apply sin to even indices
    PE[:, 0::2] = np.sin(position * div_term)
    
    # Apply cos to odd indices
    if d_model % 2 == 0:
        PE[:, 1::2] = np.cos(position * div_term)
    else:
        PE[:, 1::2] = np.cos(position * div_term[:-1])
    
    return PE

# Generate positional encodings
seq_len = 50
d_model = 128
PE = get_positional_encoding(seq_len, d_model)

# Visualize positional encodings
fig, axes = plt.subplots(2, 2, figsize=(15, 10))

# 1. Full positional encoding matrix
im0 = axes[0, 0].imshow(PE, cmap='RdBu_r', aspect='auto')
axes[0, 0].set_title('Positional Encoding Matrix', fontsize=12, fontweight='bold')
axes[0, 0].set_xlabel('Dimension', fontsize=10)
axes[0, 0].set_ylabel('Position', fontsize=10)
plt.colorbar(im0, ax=axes[0, 0], fraction=0.046, pad=0.04)

# 2. First few dimensions
for i in range(8):
    axes[0, 1].plot(PE[:, i], label=f'Dim {i}', alpha=0.8)
axes[0, 1].set_title('First 8 Dimensions of Positional Encoding', fontsize=12, fontweight='bold')
axes[0, 1].set_xlabel('Position', fontsize=10)
axes[0, 1].set_ylabel('Encoding Value', fontsize=10)
axes[0, 1].legend(loc='upper right', fontsize=8)
axes[0, 1].grid(True, alpha=0.3)

# 3. Similarity matrix (dot product between positions)
similarity = PE @ PE.T
axes[1, 0].imshow(similarity, cmap='viridis', aspect='auto')
axes[1, 0].set_title('Position Similarity Matrix', fontsize=12, fontweight='bold')
axes[1, 0].set_xlabel('Position', fontsize=10)
axes[1, 0].set_ylabel('Position', fontsize=10)

# 4. Distance-based similarity
positions = [0, 5, 10, 20, 30, 40]
for pos in positions:
    similarity_to_pos = PE @ PE[pos]
    axes[1, 1].plot(similarity_to_pos, label=f'Pos {pos}', alpha=0.7)

axes[1, 1].set_title('Similarity to Different Positions', fontsize=12, fontweight='bold')
axes[1, 1].set_xlabel('Position', fontsize=10)
axes[1, 1].set_ylabel('Similarity', fontsize=10)
axes[1, 1].legend(loc='upper right', fontsize=8)
axes[1, 1].grid(True, alpha=0.3)

plt.suptitle('Positional Encoding Visualization', fontsize=16, fontweight='bold')
plt.tight_layout()
plt.show()
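
The extrapolation claim follows from the encoding being a fixed function of position rather than a learned lookup table: generating encodings for a longer sequence reproduces the shorter ones exactly on the shared positions. A quick check with the helper above:

In [ ]:
# Deterministic encodings: a longer sequence agrees with the shorter one on
# every shared position, so longer sequences simply extend the table.
PE_long = get_positional_encoding(2 * seq_len, d_model)
print("Shared positions identical:", np.allclose(PE_long[:seq_len], PE))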

5. Attention with Positional Encoding

This section compares attention patterns with and without positional encoding to understand how position information affects attention distribution.

What to observe:

  • Without positional encoding: attention is based purely on content similarity
  • With positional encoding: nearby positions tend to have slightly higher attention
  • The difference plot shows where positional information changes attention patterns
  • Notice how position encoding creates more structured, position-aware patterns

Significance: Positional encoding helps the model understand word order and distance relationships. This is critical for tasks like understanding grammar, where word order matters significantly.

In [6]:
# Demonstrate how positional encoding affects attention
sentence = "The quick brown fox jumps over the lazy dog"
words = sentence.split()
seq_len = len(words)
d_model = 64

# Create word embeddings (random for demonstration)
np.random.seed(42)
word_embeddings = np.random.randn(seq_len, d_model)

# Get positional encodings
pos_encodings = get_positional_encoding(seq_len, d_model)

# Compare attention with and without positional encoding
fig, axes = plt.subplots(1, 3, figsize=(18, 6))

# 1. Without positional encoding
Q1 = word_embeddings @ np.random.randn(d_model, d_model) * 0.1
K1 = word_embeddings @ np.random.randn(d_model, d_model) * 0.1
V1 = word_embeddings @ np.random.randn(d_model, d_model) * 0.1
_, attn_no_pos = scaled_dot_product_attention(Q1, K1, V1)

im1 = axes[0].imshow(attn_no_pos, cmap='Blues', aspect='auto')
axes[0].set_title('Attention WITHOUT Positional Encoding', fontsize=12, fontweight='bold')
axes[0].set_xticks(range(seq_len))
axes[0].set_yticks(range(seq_len))
axes[0].set_xticklabels(words, rotation=45, ha='right', fontsize=9)
axes[0].set_yticklabels(words, fontsize=9)
plt.colorbar(im1, ax=axes[0], fraction=0.046, pad=0.04)

# 2. With positional encoding
embeddings_with_pos = word_embeddings + pos_encodings * 0.1
Q2 = embeddings_with_pos @ np.random.randn(d_model, d_model) * 0.1
K2 = embeddings_with_pos @ np.random.randn(d_model, d_model) * 0.1
V2 = embeddings_with_pos @ np.random.randn(d_model, d_model) * 0.1
_, attn_with_pos = scaled_dot_product_attention(Q2, K2, V2)

im2 = axes[1].imshow(attn_with_pos, cmap='Blues', aspect='auto')
axes[1].set_title('Attention WITH Positional Encoding', fontsize=12, fontweight='bold')
axes[1].set_xticks(range(seq_len))
axes[1].set_yticks(range(seq_len))
axes[1].set_xticklabels(words, rotation=45, ha='right', fontsize=9)
axes[1].set_yticklabels(words, fontsize=9)
plt.colorbar(im2, ax=axes[1], fraction=0.046, pad=0.04)

# 3. Difference
diff = attn_with_pos - attn_no_pos
im3 = axes[2].imshow(diff, cmap='RdBu_r', aspect='auto', vmin=-0.1, vmax=0.1)
axes[2].set_title('Difference (Effect of Positional Encoding)', fontsize=12, fontweight='bold')
axes[2].set_xticks(range(seq_len))
axes[2].set_yticks(range(seq_len))
axes[2].set_xticklabels(words, rotation=45, ha='right', fontsize=9)
axes[2].set_yticklabels(words, fontsize=9)
plt.colorbar(im3, ax=axes[2], fraction=0.046, pad=0.04)

plt.suptitle('Impact of Positional Encoding on Attention', fontsize=14, fontweight='bold')
plt.tight_layout()
plt.show()
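
One simple way to quantify the effect is to compare how much attention mass each token places on itself and its immediate neighbours in the two matrices above. A minimal sketch (the exact numbers depend on the random projections used in this cell):

In [ ]:
# Average attention mass per token on the band |i - j| <= 1, with and without
# positional encoding (illustrative; values depend on the random projections)
n = len(words)
near = np.abs(np.arange(n)[:, None] - np.arange(n)[None, :]) <= 1
print("Local mass without PE:", round(float(attn_no_pos[near].sum() / n), 3))
print("Local mass with PE:   ", round(float(attn_with_pos[near].sum() / n), 3))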

6. Masked Self-Attention (Causal Attention)

Causal attention is used in autoregressive models like GPT, where each position can only attend to previous positions. This ensures the model can't "cheat" by looking at future tokens during training.

What to observe:

  • The mask prevents attention to future positions (upper triangle is blocked)
  • Each position can only see itself and previous positions
  • The attention pattern is strictly lower triangular
  • Later tokens have access to more context than earlier ones

Significance: Causal masking is essential for language generation tasks. It ensures the model generates text left-to-right without accessing future information, maintaining the autoregressive property needed for coherent generation.

In [7]:
def create_causal_mask(seq_len):
    """
    Create a causal mask for autoregressive attention.
    
    Args:
        seq_len: Sequence length
    
    Returns:
        mask: Upper triangular mask (seq_len, seq_len)
    """
    mask = np.triu(np.ones((seq_len, seq_len)), k=1)
    return mask

# Example sequence for generation
generation_sequence = "Once upon a time in a land"
gen_words = generation_sequence.split()
gen_seq_len = len(gen_words)

# Create embeddings
gen_embeddings = np.random.randn(gen_seq_len, d_model)

# Create causal mask
causal_mask = create_causal_mask(gen_seq_len)

# Compute masked attention
Q = gen_embeddings @ np.random.randn(d_model, d_model) * 0.1
K = gen_embeddings @ np.random.randn(d_model, d_model) * 0.1
V = gen_embeddings @ np.random.randn(d_model, d_model) * 0.1

_, masked_attention = scaled_dot_product_attention(Q, K, V, mask=causal_mask)

# Visualize masked attention
fig, axes = plt.subplots(1, 2, figsize=(14, 6))

# 1. Causal mask
axes[0].imshow(causal_mask, cmap='Greys', aspect='auto')  # 0 (allowed) -> white, 1 (masked) -> black
axes[0].set_title('Causal Mask (White = Allowed, Black = Masked)', fontsize=12, fontweight='bold')
axes[0].set_xticks(range(gen_seq_len))
axes[0].set_yticks(range(gen_seq_len))
axes[0].set_xticklabels(gen_words, rotation=45, ha='right')
axes[0].set_yticklabels(gen_words)
axes[0].set_xlabel('Can Attend To', fontsize=10)
axes[0].set_ylabel('Position', fontsize=10)

# 2. Masked attention weights
im = axes[1].imshow(masked_attention, cmap='Blues', aspect='auto')
axes[1].set_title('Causal (Masked) Self-Attention', fontsize=12, fontweight='bold')
axes[1].set_xticks(range(gen_seq_len))
axes[1].set_yticks(range(gen_seq_len))
axes[1].set_xticklabels(gen_words, rotation=45, ha='right')
axes[1].set_yticklabels(gen_words)
axes[1].set_xlabel('Key Position', fontsize=10)
axes[1].set_ylabel('Query Position', fontsize=10)
plt.colorbar(im, ax=axes[1], fraction=0.046, pad=0.04)

plt.suptitle('Causal Attention for Autoregressive Generation', fontsize=14, fontweight='bold')
plt.tight_layout()
plt.show()
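
A quick check that the mask behaves as described: no weight lands above the diagonal, so no query attends to a future key, and each row is still a valid distribution. A small sketch using the masked_attention computed above:

In [ ]:
# Causal property: zero mass on future positions, rows still normalized
print("Max weight on future positions:", np.triu(masked_attention, k=1).max())
print("Rows sum to 1:", np.allclose(masked_attention.sum(axis=-1), 1.0))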

7. Cross-Attention (Encoder-Decoder Attention)

Cross-attention connects encoder and decoder in sequence-to-sequence models. The decoder queries attend to encoder keys and values, allowing the decoder to focus on relevant parts of the input.

What to observe:

  • Queries come from the target (decoder) sequence
  • Keys and values come from the source (encoder) sequence
  • Each target word distributes attention over source words
  • The pattern often shows rough word alignment in translation

Significance: Cross-attention is the key mechanism for tasks like translation, where the model needs to align output with relevant parts of the input. It allows the decoder to dynamically focus on different parts of the source sequence.

In [8]:
def cross_attention_example():
    """
    Demonstrate cross-attention between encoder and decoder.
    """
    # Source sequence (encoder)
    source = "Hello world"
    source_words = source.split()
    source_len = len(source_words)
    
    # Target sequence (decoder)
    target = "Bonjour le monde"
    target_words = target.split()
    target_len = len(target_words)
    
    # Create embeddings
    source_embeddings = np.random.randn(source_len, d_model)
    target_embeddings = np.random.randn(target_len, d_model)
    
    # Cross-attention: Q from decoder, K and V from encoder
    Q = target_embeddings @ np.random.randn(d_model, d_model) * 0.1
    K = source_embeddings @ np.random.randn(d_model, d_model) * 0.1
    V = source_embeddings @ np.random.randn(d_model, d_model) * 0.1
    
    # Compute cross-attention
    d_k = Q.shape[-1]
    scores = np.matmul(Q, K.T) / np.sqrt(d_k)
    attention_weights = np.exp(scores - np.max(scores, axis=-1, keepdims=True))
    attention_weights = attention_weights / np.sum(attention_weights, axis=-1, keepdims=True)
    
    return attention_weights, source_words, target_words

# Generate cross-attention
cross_attn, source_words, target_words = cross_attention_example()

# Visualize cross-attention
plt.figure(figsize=(8, 6))
im = plt.imshow(cross_attn, cmap='Reds', aspect='auto')
plt.colorbar(im, fraction=0.046, pad=0.04, label='Attention Weight')

# Set labels
plt.xticks(range(len(source_words)), source_words, fontsize=11)
plt.yticks(range(len(target_words)), target_words, fontsize=11)
plt.xlabel('Source (Encoder) Tokens', fontsize=12)
plt.ylabel('Target (Decoder) Tokens', fontsize=12)
plt.title('Cross-Attention: Translation Example', fontsize=14, fontweight='bold')

# Add annotations
for i in range(len(target_words)):
    for j in range(len(source_words)):
        plt.text(j, i, f'{cross_attn[i, j]:.2f}',
                ha='center', va='center', color='white' if cross_attn[i, j] > 0.5 else 'black',
                fontsize=10)

plt.tight_layout()
plt.show()
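
Unlike self-attention, the cross-attention matrix is generally rectangular: one row per target (decoder) token, one column per source (encoder) token, and each row is still a distribution over the source. A quick check on the matrix above:

In [ ]:
# Cross-attention shape is (target_len, source_len); rows normalize over source tokens
print("Shape (target_len, source_len):", cross_attn.shape)
print("Rows sum to 1:", np.allclose(cross_attn.sum(axis=-1), 1.0))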

8. Attention Head Specialization

Different attention heads learn to focus on different types of relationships. Research has shown that heads often specialize for specific linguistic tasks without explicit supervision.

What to observe:

  • Previous Word: Captures sequential/local dependencies
  • Determiners: Focuses on articles and determiners (linguistic structure)
  • Nouns: Attends to content words (semantic information)
  • Broad: Maintains global context

Significance: Head specialization emerges naturally during training. Some heads learn syntactic patterns (e.g., subject-verb agreement), while others capture semantic relationships or positional patterns. This specialization allows transformers to handle multiple aspects of language simultaneously.

In [9]:
def create_specialized_heads():
    """
    Simulate how different attention heads might specialize.
    """
    sentence = "The cat sat on the mat near the door"
    words = sentence.split()
    seq_len = len(words)
    
    # Define specialized attention patterns
    heads = {}
    
    # Head 1: Attending to previous word (local)
    head1 = np.zeros((seq_len, seq_len))
    for i in range(seq_len):
        if i > 0:
            head1[i, i-1] = 0.7
        head1[i, i] = 0.3
    heads['Previous Word'] = head1 / head1.sum(axis=1, keepdims=True)  # normalize rows (the first row would otherwise sum to 0.3)
    
    # Head 2: Attending to determiners
    head2 = np.ones((seq_len, seq_len)) * 0.05
    det_indices = [0, 4, 7]  # "The" positions
    for idx in det_indices:
        head2[:, idx] = 0.3
    heads['Determiners'] = head2 / head2.sum(axis=1, keepdims=True)
    
    # Head 3: Attending to nouns
    head3 = np.ones((seq_len, seq_len)) * 0.05
    noun_indices = [1, 5, 8]  # "cat", "mat", "door"
    for idx in noun_indices:
        head3[:, idx] = 0.3
    heads['Nouns'] = head3 / head3.sum(axis=1, keepdims=True)
    
    # Head 4: Broad attention
    head4 = np.ones((seq_len, seq_len))
    heads['Broad'] = head4 / head4.sum(axis=1, keepdims=True)
    
    return heads, words

# Create specialized heads
specialized_heads, words = create_specialized_heads()

# Visualize
fig, axes = plt.subplots(2, 2, figsize=(14, 12))
axes = axes.flatten()

for idx, (head_name, attention) in enumerate(specialized_heads.items()):
    im = axes[idx].imshow(attention, cmap='Purples', aspect='auto', vmin=0, vmax=0.4)
    axes[idx].set_title(f'Head Specialization: {head_name}', fontsize=12, fontweight='bold')
    axes[idx].set_xticks(range(len(words)))
    axes[idx].set_yticks(range(len(words)))
    axes[idx].set_xticklabels(words, rotation=45, ha='right', fontsize=9)
    axes[idx].set_yticklabels(words, fontsize=9)
    axes[idx].set_xlabel('Attending To', fontsize=10)
    axes[idx].set_ylabel('Query Position', fontsize=10)
    plt.colorbar(im, ax=axes[idx], fraction=0.046, pad=0.04)

plt.suptitle('Attention Head Specialization Patterns', fontsize=14, fontweight='bold')
plt.tight_layout()
plt.show()
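
A simple way to summarize what each head is doing is to look at its column mass, i.e. how much total attention every key position receives. For the simulated heads above, the 'Determiners' and 'Nouns' heads should concentrate their mass on the corresponding word positions; a short sketch:

In [ ]:
# Column mass: total attention each word receives, per simulated head
for head_name, attention in specialized_heads.items():
    col_mass = attention.sum(axis=0)
    top3 = np.argsort(col_mass)[::-1][:3]
    print(f"{head_name:13s} -> most-attended words:", [words[i] for i in top3])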

9. Attention Distance Analysis

This analysis explores how attention weights vary with the distance between positions, revealing different types of locality biases that can emerge or be designed into attention mechanisms.

What to observe:

  • Linear Decay: Attention decreases linearly with distance
  • Exponential Decay: Rapid falloff creates strong local focus
  • Gaussian: Smooth, bell-curve distribution around each position
  • The graph shows how average attention changes with distance

Significance: Understanding distance-based attention patterns helps in designing efficient transformers. Many efficient variants (e.g., Longformer, Sparse Transformers) restrict attention to local windows or strided patterns, exploiting this locality to reduce computational complexity while maintaining performance.

In [10]:
def analyze_attention_distance():
    """
    Analyze how attention varies with distance between positions.
    """
    seq_len = 20
    
    # Create different attention patterns based on distance
    patterns = {}
    
    # Linear decay
    linear = np.zeros((seq_len, seq_len))
    for i in range(seq_len):
        for j in range(seq_len):
            dist = abs(i - j)
            linear[i, j] = max(0, 1 - dist / 10)
    patterns['Linear Decay'] = linear / linear.sum(axis=1, keepdims=True)
    
    # Exponential decay
    exponential = np.zeros((seq_len, seq_len))
    for i in range(seq_len):
        for j in range(seq_len):
            dist = abs(i - j)
            exponential[i, j] = np.exp(-dist / 3)
    patterns['Exponential Decay'] = exponential / exponential.sum(axis=1, keepdims=True)
    
    # Gaussian
    gaussian = np.zeros((seq_len, seq_len))
    for i in range(seq_len):
        for j in range(seq_len):
            dist = abs(i - j)
            gaussian[i, j] = np.exp(-(dist ** 2) / (2 * 3 ** 2))
    patterns['Gaussian'] = gaussian / gaussian.sum(axis=1, keepdims=True)
    
    return patterns

# Generate patterns
distance_patterns = analyze_attention_distance()

# Visualize
fig, axes = plt.subplots(1, 3, figsize=(15, 5))

for idx, (name, pattern) in enumerate(distance_patterns.items()):
    im = axes[idx].imshow(pattern, cmap='YlOrRd', aspect='auto')
    axes[idx].set_title(f'{name}', fontsize=12, fontweight='bold')
    axes[idx].set_xlabel('Position', fontsize=10)
    axes[idx].set_ylabel('Position', fontsize=10)
    plt.colorbar(im, ax=axes[idx], fraction=0.046, pad=0.04)

plt.suptitle('Attention Patterns Based on Position Distance', fontsize=14, fontweight='bold')
plt.tight_layout()
plt.show()

# Plot distance vs attention weight
plt.figure(figsize=(10, 6))
colors = ['blue', 'red', 'green']

for (name, pattern), color in zip(distance_patterns.items(), colors):
    # Get attention weights by distance
    distances = []
    weights = []
    
    for i in range(len(pattern)):
        for j in range(len(pattern)):
            distances.append(abs(i - j))
            weights.append(pattern[i, j])
    
    # Average weights by distance
    unique_distances = sorted(set(distances))
    avg_weights = []
    for d in unique_distances:
        d_weights = [w for dist, w in zip(distances, weights) if dist == d]
        avg_weights.append(np.mean(d_weights))
    
    plt.plot(unique_distances, avg_weights, marker='o', label=name, color=color, linewidth=2)

plt.xlabel('Distance Between Positions', fontsize=12)
plt.ylabel('Average Attention Weight', fontsize=12)
plt.title('Attention Weight vs Position Distance', fontsize=14, fontweight='bold')
plt.legend(fontsize=10)
plt.grid(True, alpha=0.3)
plt.tight_layout()
plt.show()
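
The distance-averaging loop above can also be written in a vectorized form with np.bincount, which scales better to long sequences; a minimal equivalent sketch for one of the patterns:

In [ ]:
# Vectorized mean attention weight per |i - j| distance using np.bincount
pattern = distance_patterns['Gaussian']
n = pattern.shape[0]
dist = np.abs(np.arange(n)[:, None] - np.arange(n)[None, :]).ravel()
avg_by_dist = np.bincount(dist, weights=pattern.ravel()) / np.bincount(dist)
print(np.round(avg_by_dist[:6], 4))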

10. Layer-wise Attention Evolution

Attention patterns change systematically across transformer layers, with different layers capturing different types of information. This evolution from local to global patterns is key to the model's hierarchical processing.

What to observe:

  • Early layers (1-2): More local, position-based attention
  • Middle layers (3-4): Syntactic patterns emerge (grammatical relationships)
  • Later layers (5-6): Semantic and task-specific patterns dominate
  • Notice how patterns become more specialized and task-oriented in deeper layers

Significance: The layer-wise evolution shows how transformers build up representations hierarchically: from local patterns to syntax to semantics. This mirrors findings in neuroscience about hierarchical processing in the brain.

In [11]:
def simulate_layer_attention():
    """
    Simulate how attention patterns might evolve across transformer layers.
    """
    sentence = "The quick brown fox jumps"
    words = sentence.split()
    seq_len = len(words)
    n_layers = 6
    
    # Initialize random embeddings
    embeddings = np.random.randn(seq_len, d_model)
    
    layer_attentions = []
    
    for layer in range(n_layers):
        # Simulate different behavior in different layers
        if layer < 2:
            # Early layers: more local attention
            noise_level = 0.3
            attention = np.eye(seq_len) * 0.5
            for i in range(seq_len):
                for j in range(max(0, i-1), min(seq_len, i+2)):
                    attention[i, j] += 0.2
        elif layer < 4:
            # Middle layers: syntactic patterns
            noise_level = 0.2
            attention = np.ones((seq_len, seq_len)) * 0.1
            # Simulate grammatical dependencies
            attention[2, 3] += 0.3  # "brown" -> "fox"
            attention[3, 4] += 0.3  # "fox" -> "jumps"
            attention[1, 3] += 0.2  # "quick" -> "fox"
        else:
            # Late layers: semantic/global patterns
            noise_level = 0.1
            attention = np.ones((seq_len, seq_len)) * 0.15
            attention[4, 3] += 0.3  # "jumps" -> "fox"
            attention[:, 0] += 0.1  # Global attention to "The"
        
        # Add noise and normalize
        attention += np.random.rand(seq_len, seq_len) * noise_level
        attention = attention / attention.sum(axis=1, keepdims=True)
        layer_attentions.append(attention)
    
    return layer_attentions, words

# Generate layer-wise attention
layer_attentions, words = simulate_layer_attention()

# Visualize
fig, axes = plt.subplots(2, 3, figsize=(15, 10))
axes = axes.flatten()

for layer_idx, attention in enumerate(layer_attentions):
    im = axes[layer_idx].imshow(attention, cmap='Blues', aspect='auto', vmin=0, vmax=0.5)
    axes[layer_idx].set_title(f'Layer {layer_idx + 1}', fontsize=12, fontweight='bold')
    axes[layer_idx].set_xticks(range(len(words)))
    axes[layer_idx].set_yticks(range(len(words)))
    axes[layer_idx].set_xticklabels(words, rotation=45, ha='right')
    axes[layer_idx].set_yticklabels(words)
    
    if layer_idx >= 3:
        axes[layer_idx].set_xlabel('Key', fontsize=10)
    if layer_idx % 3 == 0:
        axes[layer_idx].set_ylabel('Query', fontsize=10)

plt.suptitle('Attention Pattern Evolution Across Transformer Layers', fontsize=14, fontweight='bold')
plt.tight_layout()
plt.show()

11. Attention Entropy Analysis

Entropy measures the "focus" of attention distributions. Low entropy means concentrated attention (focusing on few positions), while high entropy means distributed attention (attending broadly).

What to observe:

  • Highly Focused: Low entropy - attention concentrated on single positions
  • Moderately Focused: Medium entropy - attention on a few key positions
  • Distributed: Maximum entropy - uniform attention across all positions
  • Mixed: Variable entropy - different focusing strategies per position

Significance: Entropy analysis helps understand model behavior and can be used for model interpretation, pruning decisions, and understanding which parts of the input are most important for predictions.

In [12]:
def compute_attention_entropy(attention_weights):
    """
    Compute entropy of attention distribution.
    Higher entropy = more distributed attention
    Lower entropy = more focused attention
    """
    # Add small epsilon to avoid log(0)
    eps = 1e-10
    entropy = -np.sum(attention_weights * np.log(attention_weights + eps), axis=-1)
    return entropy

# Create different attention patterns
seq_len = 10
patterns = {
    'Highly Focused': np.eye(seq_len),  # Attention only on diagonal
    'Moderately Focused': np.eye(seq_len) * 0.7 + np.ones((seq_len, seq_len)) * 0.03,
    'Distributed': np.ones((seq_len, seq_len)) / seq_len,  # Uniform attention
    'Mixed': np.random.dirichlet(np.ones(seq_len), size=seq_len)  # Random distribution
}

# Normalize patterns
for name in patterns:
    patterns[name] = patterns[name] / patterns[name].sum(axis=1, keepdims=True)

# Compute entropy for each pattern
fig, axes = plt.subplots(2, 4, figsize=(18, 8))

for idx, (name, pattern) in enumerate(patterns.items()):
    # Compute entropy
    entropy = compute_attention_entropy(pattern)
    
    # Plot attention pattern
    axes[0, idx].imshow(pattern, cmap='Blues', aspect='auto')
    axes[0, idx].set_title(f'{name}', fontsize=11, fontweight='bold')
    axes[0, idx].set_xlabel('Key Position', fontsize=9)
    axes[0, idx].set_ylabel('Query Position', fontsize=9)
    
    # Plot entropy
    axes[1, idx].bar(range(seq_len), entropy, color='coral')
    axes[1, idx].set_title(f'Entropy (avg: {np.mean(entropy):.2f})', fontsize=11)
    axes[1, idx].set_xlabel('Position', fontsize=9)
    axes[1, idx].set_ylabel('Entropy', fontsize=9)
    axes[1, idx].set_ylim([0, np.log(seq_len)])
    axes[1, idx].grid(True, alpha=0.3)

plt.suptitle('Attention Entropy Analysis: Measuring Focus vs Distribution', 
             fontsize=14, fontweight='bold')
plt.tight_layout()
plt.show()
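
The uniform pattern attains the maximum possible entropy, log(seq_len), which is why the bar plots above are capped at that value; a one-hot (perfectly focused) distribution sits at the other extreme. A quick numerical check with the helper defined above:

In [ ]:
# Entropy extremes: uniform row -> log(seq_len); one-hot row -> (numerically) zero
uniform_row = np.ones((1, seq_len)) / seq_len
one_hot_row = np.eye(1, seq_len)
print("Uniform entropy:", compute_attention_entropy(uniform_row)[0],
      "  log(seq_len):", np.log(seq_len))
print("One-hot entropy:", compute_attention_entropy(one_hot_row)[0])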

12. Attention Flow Visualization

This visualization shows how information flows through the attention mechanism, making the abstract concept of attention more intuitive by showing it as directed connections between words.

What to observe:

  • The heatmap shows traditional attention weights
  • The flow diagram visualizes the same information as arrows
  • Arrow thickness represents attention strength
  • Only strong connections (above threshold) are shown for clarity
  • Notice how grammatically related words have stronger connections

Significance: Understanding attention as information flow helps explain how transformers process sequences. Strong attention connections indicate information pathways - which words influence the representation of other words. This visualization is particularly useful for debugging and interpreting model behavior.

In [13]:
def visualize_attention_flow():
    """
    Visualize how information flows through attention.
    """
    # Simple sentence for visualization
    sentence = "Dogs love to play fetch"
    words = sentence.split()
    seq_len = len(words)
    
    # Create attention pattern with clear dependencies
    attention = np.ones((seq_len, seq_len)) * 0.05
    
    # Add specific attention patterns
    attention[0, 0] = 0.4  # "Dogs" to self
    attention[1, 0] = 0.5  # "love" to "Dogs"
    attention[2, 1] = 0.3  # "to" to "love"
    attention[3, 1] = 0.4  # "play" to "love"
    attention[3, 2] = 0.3  # "play" to "to"
    attention[4, 3] = 0.6  # "fetch" to "play"
    attention[4, 0] = 0.2  # "fetch" to "Dogs"
    
    # Normalize
    attention = attention / attention.sum(axis=1, keepdims=True)
    
    # Create flow visualization
    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(16, 6))
    
    # 1. Traditional heatmap
    im = ax1.imshow(attention, cmap='Purples', aspect='auto')
    ax1.set_xticks(range(seq_len))
    ax1.set_yticks(range(seq_len))
    ax1.set_xticklabels(words)
    ax1.set_yticklabels(words)
    ax1.set_xlabel('Attending To', fontsize=11)
    ax1.set_ylabel('Query', fontsize=11)
    ax1.set_title('Attention Weights Matrix', fontsize=12, fontweight='bold')
    plt.colorbar(im, ax=ax1, fraction=0.046, pad=0.04)
    
    # 2. Flow diagram
    ax2.set_xlim(-1, seq_len)
    ax2.set_ylim(-0.5, 1.5)
    
    # Draw words
    y_positions = [0, 1] * (seq_len // 2 + 1)
    for i, word in enumerate(words):
        y = y_positions[i]
        ax2.text(i, y, word, ha='center', va='center', 
                bbox=dict(boxstyle='round,pad=0.3', facecolor='lightblue', alpha=0.7),
                fontsize=11, fontweight='bold')
    
    # Draw attention connections
    threshold = 0.15  # Only show strong connections
    for i in range(seq_len):
        for j in range(seq_len):
            if attention[i, j] > threshold and i != j:
                y_from = y_positions[j]
                y_to = y_positions[i]
                
                # Draw arrow with clamped alpha
                alpha_val = min(1.0, attention[i, j] * 2)  # Clamp alpha to [0, 1]
                ax2.annotate('', xy=(i, y_to), xytext=(j, y_from),
                           arrowprops=dict(arrowstyle='->', lw=attention[i, j] * 5,
                                         color='purple', alpha=alpha_val))
    
    ax2.set_title('Attention Flow Diagram', fontsize=12, fontweight='bold')
    ax2.axis('off')
    
    plt.suptitle('Information Flow Through Attention', fontsize=14, fontweight='bold')
    plt.tight_layout()
    plt.show()

visualize_attention_flow()

Summary

This notebook demonstrated key concepts in Transformer attention mechanisms:

  1. Self-Attention: How each position attends to all other positions
  2. Attention Patterns: Various types (uniform, causal, local, etc.)
  3. Multi-Head Attention: Multiple attention mechanisms learning different patterns
  4. Positional Encoding: How position information is incorporated
  5. Masked Attention: Causal masking for autoregressive generation
  6. Cross-Attention: Attention between encoder and decoder
  7. Head Specialization: How different heads learn different functions
  8. Distance Analysis: How attention varies with position distance
  9. Layer Evolution: How attention patterns change across layers
  10. Entropy Analysis: Measuring attention focus vs distribution
  11. Attention Flow: Visualizing information flow through the network

These visualizations help understand how Transformers process and relate information across sequences, which is fundamental to their success in NLP and other domains.