Transformer Attention Patterns
This notebook provides visualizations of the self-attention mechanism in Transformers: how scaled dot-product attention works, what multi-head attention adds, and how positional encodings inject order information.
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from matplotlib.patches import Rectangle
from matplotlib.gridspec import GridSpec
import warnings
warnings.filterwarnings('ignore')
# Set style for better visualizations
plt.style.use('seaborn-v0_8-whitegrid')
sns.set_palette("husl")
plt.rcParams['figure.figsize'] = (15, 10)
plt.rcParams['font.size'] = 11
1. Understanding Self-Attention
Self-attention is the core mechanism that allows Transformers to process sequences without recurrence. Each position in the sequence can directly attend to every other position, computing relevance scores through dot products of learned query and key vectors.
What to observe:
- Notice how each word (row) distributes its attention across all words (columns)
- The diagonal often has high values (each word attending to its own position), but words also attend to semantically related words
- The attention weights sum to 1 for each row (softmax normalization)
Significance: This mechanism allows the model to capture both local and long-range dependencies in a single operation, unlike RNNs, which must process sequences step by step.
def scaled_dot_product_attention(Q, K, V, mask=None):
"""
Compute scaled dot-product attention.
Args:
Q: Query matrix (seq_len, d_k)
K: Key matrix (seq_len, d_k)
V: Value matrix (seq_len, d_v)
mask: Optional mask (seq_len, seq_len)
Returns:
output: Attention output (seq_len, d_v)
attention_weights: Attention weights (seq_len, seq_len)
"""
d_k = Q.shape[-1]
# Compute attention scores
scores = np.matmul(Q, K.T) / np.sqrt(d_k)
# Apply mask if provided
if mask is not None:
scores = scores + (mask * -1e9)
# Apply softmax
attention_weights = np.exp(scores - np.max(scores, axis=-1, keepdims=True))
attention_weights = attention_weights / np.sum(attention_weights, axis=-1, keepdims=True)
# Apply attention to values
output = np.matmul(attention_weights, V)
return output, attention_weights
# Example: Simple attention on a sentence
sentence = "The cat sat on the mat"
words = sentence.split()
seq_len = len(words)
d_model = 64
# Create random embeddings for demonstration
np.random.seed(42)
embeddings = np.random.randn(seq_len, d_model)
# Compute Q, K, V (in practice, these would be learned projections)
Q = embeddings @ np.random.randn(d_model, d_model)
K = embeddings @ np.random.randn(d_model, d_model)
V = embeddings @ np.random.randn(d_model, d_model)
# Compute attention
output, attention_weights = scaled_dot_product_attention(Q, K, V)
# Visualize attention weights
plt.figure(figsize=(8, 6))
sns.heatmap(attention_weights, annot=True, fmt='.2f', cmap='Blues',
xticklabels=words, yticklabels=words, cbar_kws={'label': 'Attention Weight'})
plt.title('Self-Attention Weights: "' + sentence + '"', fontsize=14, fontweight='bold')
plt.xlabel('Keys (Attended to)', fontsize=12)
plt.ylabel('Queries (Attending from)', fontsize=12)
plt.tight_layout()
plt.show()
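As a quick sanity check (a minimal sketch reusing output and attention_weights from the cell above), we can confirm the claims above: every row of the attention matrix is a probability distribution, and the output has one vector per position.
# Verify row-wise softmax normalization and output shape
print("Rows sum to 1:", np.allclose(attention_weights.sum(axis=1), 1.0))
print("Output shape (seq_len, d_v):", output.shape)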
2. Visualizing Attention Patterns
Different attention patterns emerge in transformers for different purposes. Understanding these patterns helps explain how transformers process information and what inductive biases can be built into the architecture.
What to observe:
- Uniform: Equal attention to all positions - useful for aggregating global information
- Self-Focus: Strong diagonal pattern - preserves position-specific information
- Causal: Lower triangular - ensures autoregressive property for generation
- Local: Window-based attention - captures local dependencies efficiently
- Strided: Sparse attention pattern - reduces computational complexity
- Global+Local: Hybrid pattern - combines local context with global anchors
Significance: These patterns represent different strategies for information flow. Modern efficient transformers often use combinations of these patterns to balance expressiveness with computational efficiency.
def create_attention_patterns():
"""Create different types of attention patterns."""
seq_len = 10
patterns = {}
# 1. Uniform attention
patterns['Uniform'] = np.ones((seq_len, seq_len)) / seq_len
# 2. Diagonal (self) attention
patterns['Self-Focus'] = np.eye(seq_len) * 0.8 + np.ones((seq_len, seq_len)) * 0.02
# 3. Lower triangular (causal) attention
causal = np.tril(np.ones((seq_len, seq_len)))
patterns['Causal'] = causal / causal.sum(axis=1, keepdims=True)
# 4. Local attention (attending to nearby positions)
local = np.zeros((seq_len, seq_len))
window_size = 3
for i in range(seq_len):
start = max(0, i - window_size // 2)
end = min(seq_len, i + window_size // 2 + 1)
local[i, start:end] = 1
patterns['Local'] = local / local.sum(axis=1, keepdims=True)
# 5. Strided attention
strided = np.zeros((seq_len, seq_len))
for i in range(seq_len):
strided[i, ::2] = 1 # Attend to every other position
patterns['Strided'] = strided / strided.sum(axis=1, keepdims=True)
# 6. Global + Local attention
global_local = local.copy()
global_local[:, 0] = 1 # All positions attend to first position
global_local[0, :] = 1 # First position attends to all
patterns['Global+Local'] = global_local / global_local.sum(axis=1, keepdims=True)
return patterns
# Create and visualize patterns
patterns = create_attention_patterns()
fig, axes = plt.subplots(2, 3, figsize=(15, 10))
axes = axes.flatten()
for idx, (name, pattern) in enumerate(patterns.items()):
ax = axes[idx]
im = ax.imshow(pattern, cmap='Blues', aspect='auto', vmin=0, vmax=0.5)
ax.set_title(f'{name} Attention', fontsize=12, fontweight='bold')
ax.set_xlabel('Key Position', fontsize=10)
ax.set_ylabel('Query Position', fontsize=10)
ax.set_xticks(range(10))
ax.set_yticks(range(10))
plt.colorbar(im, ax=ax, fraction=0.046, pad=0.04)
plt.suptitle('Common Attention Patterns in Transformers', fontsize=16, fontweight='bold')
plt.tight_layout()
plt.show()
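One way to see why the sparse patterns reduce cost is to count how many key positions each pattern actually uses. The sketch below (reusing the patterns dictionary above) reports the fraction of nonzero entries, which a sparse-attention implementation could exploit; note that only Causal, Local, Strided, and Global+Local contain exact zeros here, while Uniform and Self-Focus are dense by construction.
# Fraction of key positions with nonzero attention, per pattern
for name, pattern in patterns.items():
    density = np.count_nonzero(pattern) / pattern.size
    print(f"{name:>12}: attends to {density:.0%} of key positions")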
3. Multi-Head Attention
Multi-head attention allows the model to jointly attend to information from different representation subspaces. Instead of having a single attention function, the model uses multiple "heads" that can learn different types of relationships.
What to observe:
- Each head shows a different attention pattern - some might be more local, others more distributed
- The patterns are learned during training and often specialize for different linguistic phenomena
- Notice how different heads might focus on different aspects of the sequence
Significance: Multi-head attention is crucial for transformer expressiveness. Different heads can capture different types of relationships (syntactic, semantic, positional) simultaneously, which a single attention function cannot do as effectively.
class MultiHeadAttention:
def __init__(self, d_model, n_heads):
"""
Multi-Head Attention mechanism.
Args:
d_model: Model dimension
n_heads: Number of attention heads
"""
self.d_model = d_model
self.n_heads = n_heads
self.d_k = d_model // n_heads
# Initialize weight matrices
self.W_q = np.random.randn(n_heads, d_model, self.d_k) * 0.1
self.W_k = np.random.randn(n_heads, d_model, self.d_k) * 0.1
self.W_v = np.random.randn(n_heads, d_model, self.d_k) * 0.1
self.W_o = np.random.randn(n_heads * self.d_k, d_model) * 0.1
def forward(self, x):
"""
Forward pass of multi-head attention.
Args:
x: Input tensor (seq_len, d_model)
Returns:
output: Output tensor (seq_len, d_model)
attention_weights: Attention weights for each head
"""
seq_len = x.shape[0]
attention_weights = []
head_outputs = []
for h in range(self.n_heads):
# Project to Q, K, V
Q = x @ self.W_q[h]
K = x @ self.W_k[h]
V = x @ self.W_v[h]
# Compute attention for this head
head_out, attn_weights = scaled_dot_product_attention(Q, K, V)
head_outputs.append(head_out)
attention_weights.append(attn_weights)
# Concatenate heads
concat_output = np.concatenate(head_outputs, axis=-1)
# Final projection
output = concat_output @ self.W_o
return output, attention_weights
# Example: Multi-head attention on a sequence
seq_len = 8
d_model = 64
n_heads = 4
# Create sample input
x = np.random.randn(seq_len, d_model)
# Apply multi-head attention
mha = MultiHeadAttention(d_model, n_heads)
output, attention_weights_heads = mha.forward(x)
# Visualize attention weights for each head
fig, axes = plt.subplots(1, n_heads, figsize=(16, 4))
for h in range(n_heads):
axes[h].imshow(attention_weights_heads[h], cmap='Blues', aspect='auto')
axes[h].set_title(f'Head {h+1}', fontsize=12)
axes[h].set_xlabel('Key Position', fontsize=10)
if h == 0:
axes[h].set_ylabel('Query Position', fontsize=10)
axes[h].set_xticks(range(seq_len))
axes[h].set_yticks(range(seq_len))
plt.suptitle('Multi-Head Attention: Different Heads Learn Different Patterns',
fontsize=14, fontweight='bold')
plt.tight_layout()
plt.show()
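The per-head loop above is easy to read, but real implementations compute all heads at once. The vectorized sketch below (an np.einsum formulation, not part of the class above) uses the same weights and should match mha.forward(x) up to floating-point error.
# Vectorized multi-head attention over all heads at once
Qh = np.einsum('sd,hdk->hsk', x, mha.W_q)  # (n_heads, seq_len, d_k)
Kh = np.einsum('sd,hdk->hsk', x, mha.W_k)
Vh = np.einsum('sd,hdk->hsk', x, mha.W_v)
scores = np.einsum('hqk,hsk->hqs', Qh, Kh) / np.sqrt(mha.d_k)
weights = np.exp(scores - scores.max(axis=-1, keepdims=True))
weights = weights / weights.sum(axis=-1, keepdims=True)
heads_out = np.einsum('hqs,hsk->hqk', weights, Vh)
# Concatenate heads and apply the output projection, as in forward()
vectorized = heads_out.transpose(1, 0, 2).reshape(seq_len, -1) @ mha.W_o
print("Matches the loop implementation:", np.allclose(vectorized, output))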
4. Positional Encoding
Since attention mechanisms are permutation-invariant (order doesn't matter), transformers need a way to incorporate position information. Positional encodings add position-dependent signals to embeddings using sinusoidal functions.
What to observe:
- The alternating patterns in the encoding matrix - different frequencies for different dimensions
- The similarity matrix shows how positional encodings relate across positions
- Notice the periodic nature of the sinusoidal functions at different scales
- The similarity decreases smoothly with distance between positions
Significance: Positional encodings allow the model to use position information without learning position-specific parameters. The sinusoidal pattern enables the model to extrapolate to sequence lengths not seen during training.
def get_positional_encoding(seq_len, d_model):
"""
Generate positional encoding for a sequence.
Args:
seq_len: Sequence length
d_model: Model dimension
Returns:
PE: Positional encoding matrix (seq_len, d_model)
"""
PE = np.zeros((seq_len, d_model))
position = np.arange(seq_len).reshape(-1, 1)
# Create div_term for the sinusoidal pattern
div_term = np.exp(np.arange(0, d_model, 2) * -(np.log(10000.0) / d_model))
# Apply sin to even indices
PE[:, 0::2] = np.sin(position * div_term)
# Apply cos to odd indices
if d_model % 2 == 0:
PE[:, 1::2] = np.cos(position * div_term)
else:
PE[:, 1::2] = np.cos(position * div_term[:-1])
return PE
# Generate positional encodings
seq_len = 50
d_model = 128
PE = get_positional_encoding(seq_len, d_model)
# Visualize positional encodings
fig, axes = plt.subplots(2, 2, figsize=(15, 10))
# 1. Full positional encoding matrix
axes[0, 0].imshow(PE, cmap='RdBu_r', aspect='auto')
axes[0, 0].set_title('Positional Encoding Matrix', fontsize=12, fontweight='bold')
axes[0, 0].set_xlabel('Dimension', fontsize=10)
axes[0, 0].set_ylabel('Position', fontsize=10)
plt.colorbar(axes[0, 0].images[0], ax=axes[0, 0])
# 2. First few dimensions
for i in range(8):
axes[0, 1].plot(PE[:, i], label=f'Dim {i}', alpha=0.8)
axes[0, 1].set_title('First 8 Dimensions of Positional Encoding', fontsize=12, fontweight='bold')
axes[0, 1].set_xlabel('Position', fontsize=10)
axes[0, 1].set_ylabel('Encoding Value', fontsize=10)
axes[0, 1].legend(loc='upper right', fontsize=8)
axes[0, 1].grid(True, alpha=0.3)
# 3. Similarity matrix (dot product between positions)
similarity = PE @ PE.T
axes[1, 0].imshow(similarity, cmap='viridis', aspect='auto')
axes[1, 0].set_title('Position Similarity Matrix', fontsize=12, fontweight='bold')
axes[1, 0].set_xlabel('Position', fontsize=10)
axes[1, 0].set_ylabel('Position', fontsize=10)
# 4. Distance-based similarity
positions = [0, 5, 10, 20, 30, 40]
for pos in positions:
similarity_to_pos = PE @ PE[pos]
axes[1, 1].plot(similarity_to_pos, label=f'Pos {pos}', alpha=0.7)
axes[1, 1].set_title('Similarity to Different Positions', fontsize=12, fontweight='bold')
axes[1, 1].set_xlabel('Position', fontsize=10)
axes[1, 1].set_ylabel('Similarity', fontsize=10)
axes[1, 1].legend(loc='upper right', fontsize=8)
axes[1, 1].grid(True, alpha=0.3)
plt.suptitle('Positional Encoding Visualization', fontsize=16, fontweight='bold')
plt.tight_layout()
plt.show()
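A property worth checking numerically (a small sketch reusing PE and d_model from above): for any fixed offset k, the encoding at position p + k is a fixed linear transform, a per-frequency rotation, of the encoding at position p. This is the property the original Transformer paper cites as the reason the sinusoidal form should make it easy to attend by relative position.
# Rebuild the frequencies and rotate each (sin, cos) pair by k positions
k = 7
div_term = np.exp(np.arange(0, d_model, 2) * -(np.log(10000.0) / d_model))
s, c = np.sin(k * div_term), np.cos(k * div_term)
shifted = np.zeros_like(PE)
shifted[:, 0::2] = c * PE[:, 0::2] + s * PE[:, 1::2]  # sin((p+k)w) = sin(pw)cos(kw) + cos(pw)sin(kw)
shifted[:, 1::2] = c * PE[:, 1::2] - s * PE[:, 0::2]  # cos((p+k)w) = cos(pw)cos(kw) - sin(pw)sin(kw)
print("PE(p + k) is a linear function of PE(p):", np.allclose(shifted[:-k], PE[k:]))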
5. Attention with Positional Encoding
This section compares attention patterns with and without positional encoding to understand how position information affects attention distribution.
What to observe:
- Without positional encoding: attention is based purely on content similarity
- With positional encoding: nearby positions tend to have slightly higher attention
- The difference plot shows where positional information changes attention patterns
- Notice how position encoding creates more structured, position-aware patterns
Significance: Positional encoding helps the model understand word order and distance relationships. This is critical for tasks like understanding grammar, where word order matters significantly.
# Demonstrate how positional encoding affects attention
sentence = "The quick brown fox jumps over the lazy dog"
words = sentence.split()
seq_len = len(words)
d_model = 64
# Create word embeddings (random for demonstration)
np.random.seed(42)
word_embeddings = np.random.randn(seq_len, d_model)
# Get positional encodings
pos_encodings = get_positional_encoding(seq_len, d_model)
# Compare attention with and without positional encoding
fig, axes = plt.subplots(1, 3, figsize=(18, 6))
# 1. Without positional encoding
Q1 = word_embeddings @ np.random.randn(d_model, d_model) * 0.1
K1 = word_embeddings @ np.random.randn(d_model, d_model) * 0.1
V1 = word_embeddings @ np.random.randn(d_model, d_model) * 0.1
_, attn_no_pos = scaled_dot_product_attention(Q1, K1, V1)
im1 = axes[0].imshow(attn_no_pos, cmap='Blues', aspect='auto')
axes[0].set_title('Attention WITHOUT Positional Encoding', fontsize=12, fontweight='bold')
axes[0].set_xticks(range(seq_len))
axes[0].set_yticks(range(seq_len))
axes[0].set_xticklabels(words, rotation=45, ha='right', fontsize=9)
axes[0].set_yticklabels(words, fontsize=9)
plt.colorbar(im1, ax=axes[0], fraction=0.046, pad=0.04)
# 2. With positional encoding
embeddings_with_pos = word_embeddings + pos_encodings * 0.1
Q2 = embeddings_with_pos @ np.random.randn(d_model, d_model) * 0.1
K2 = embeddings_with_pos @ np.random.randn(d_model, d_model) * 0.1
V2 = embeddings_with_pos @ np.random.randn(d_model, d_model) * 0.1
_, attn_with_pos = scaled_dot_product_attention(Q2, K2, V2)
im2 = axes[1].imshow(attn_with_pos, cmap='Blues', aspect='auto')
axes[1].set_title('Attention WITH Positional Encoding', fontsize=12, fontweight='bold')
axes[1].set_xticks(range(seq_len))
axes[1].set_yticks(range(seq_len))
axes[1].set_xticklabels(words, rotation=45, ha='right', fontsize=9)
axes[1].set_yticklabels(words, fontsize=9)
plt.colorbar(im2, ax=axes[1], fraction=0.046, pad=0.04)
# 3. Difference
diff = attn_with_pos - attn_no_pos
im3 = axes[2].imshow(diff, cmap='RdBu_r', aspect='auto', vmin=-0.1, vmax=0.1)
axes[2].set_title('Difference (Effect of Positional Encoding)', fontsize=12, fontweight='bold')
axes[2].set_xticks(range(seq_len))
axes[2].set_yticks(range(seq_len))
axes[2].set_xticklabels(words, rotation=45, ha='right', fontsize=9)
axes[2].set_yticklabels(words, fontsize=9)
plt.colorbar(im3, ax=axes[2], fraction=0.046, pad=0.04)
plt.suptitle('Impact of Positional Encoding on Attention', fontsize=14, fontweight='bold')
plt.tight_layout()
plt.show()
6. Masked Self-Attention (Causal Attention)
Causal attention is used in autoregressive models like GPT, where each position can only attend to previous positions. This ensures the model can't "cheat" by looking at future tokens during training.
What to observe:
- The mask prevents attention to future positions (upper triangle is blocked)
- Each position can only see itself and previous positions
- The attention weight matrix is lower triangular: all weights above the diagonal are zero
- Later tokens have access to more context than earlier ones
Significance: Causal masking is essential for language generation tasks. It ensures the model generates text left-to-right without accessing future information, maintaining the autoregressive property needed for coherent generation.
def create_causal_mask(seq_len):
"""
Create a causal mask for autoregressive attention.
Args:
seq_len: Sequence length
Returns:
mask: Upper triangular mask (seq_len, seq_len)
"""
mask = np.triu(np.ones((seq_len, seq_len)), k=1)
return mask
# Example sequence for generation
generation_sequence = "Once upon a time in a land"
gen_words = generation_sequence.split()
gen_seq_len = len(gen_words)
# Create embeddings
gen_embeddings = np.random.randn(gen_seq_len, d_model)
# Create causal mask
causal_mask = create_causal_mask(gen_seq_len)
# Compute masked attention
Q = gen_embeddings @ np.random.randn(d_model, d_model) * 0.1
K = gen_embeddings @ np.random.randn(d_model, d_model) * 0.1
V = gen_embeddings @ np.random.randn(d_model, d_model) * 0.1
_, masked_attention = scaled_dot_product_attention(Q, K, V, mask=causal_mask)
# Visualize masked attention
fig, axes = plt.subplots(1, 2, figsize=(14, 6))
# 1. Causal mask
axes[0].imshow(causal_mask, cmap='Greys', aspect='auto')  # 1 (masked) renders black, 0 (allowed) renders white
axes[0].set_title('Causal Mask (White = Allowed, Black = Masked)', fontsize=12, fontweight='bold')
axes[0].set_xticks(range(gen_seq_len))
axes[0].set_yticks(range(gen_seq_len))
axes[0].set_xticklabels(gen_words, rotation=45, ha='right')
axes[0].set_yticklabels(gen_words)
axes[0].set_xlabel('Can Attend To', fontsize=10)
axes[0].set_ylabel('Position', fontsize=10)
# 2. Masked attention weights
im = axes[1].imshow(masked_attention, cmap='Blues', aspect='auto')
axes[1].set_title('Causal (Masked) Self-Attention', fontsize=12, fontweight='bold')
axes[1].set_xticks(range(gen_seq_len))
axes[1].set_yticks(range(gen_seq_len))
axes[1].set_xticklabels(gen_words, rotation=45, ha='right')
axes[1].set_yticklabels(gen_words)
axes[1].set_xlabel('Key Position', fontsize=10)
axes[1].set_ylabel('Query Position', fontsize=10)
plt.colorbar(im, ax=axes[1], fraction=0.046, pad=0.04)
plt.suptitle('Causal Attention for Autoregressive Generation', fontsize=14, fontweight='bold')
plt.tight_layout()
plt.show()
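A quick check that the mask did its job (a small sketch reusing masked_attention from above): every weight above the diagonal should be numerically zero, while each row remains a valid distribution.
# No attention to future positions; rows still normalized
print("Future positions blocked:", np.allclose(np.triu(masked_attention, k=1), 0.0, atol=1e-6))
print("Rows sum to 1:", np.allclose(masked_attention.sum(axis=1), 1.0))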
7. Cross-Attention (Encoder-Decoder Attention)
Cross-attention connects encoder and decoder in sequence-to-sequence models. The decoder queries attend to encoder keys and values, allowing the decoder to focus on relevant parts of the input.
What to observe:
- Queries come from the target (decoder) sequence
- Keys and values come from the source (encoder) sequence
- Each target word distributes attention over source words
- The pattern often shows rough word alignment in translation
Significance: Cross-attention is the key mechanism for tasks like translation, where the model needs to align output with relevant parts of the input. It allows the decoder to dynamically focus on different parts of the source sequence.
def cross_attention_example():
"""
Demonstrate cross-attention between encoder and decoder.
"""
# Source sequence (encoder)
source = "Hello world"
source_words = source.split()
source_len = len(source_words)
# Target sequence (decoder)
target = "Bonjour le monde"
target_words = target.split()
target_len = len(target_words)
# Create embeddings
source_embeddings = np.random.randn(source_len, d_model)
target_embeddings = np.random.randn(target_len, d_model)
# Cross-attention: Q from decoder, K and V from encoder
Q = target_embeddings @ np.random.randn(d_model, d_model) * 0.1
K = source_embeddings @ np.random.randn(d_model, d_model) * 0.1
V = source_embeddings @ np.random.randn(d_model, d_model) * 0.1
    # Compute cross-attention by reusing the attention function defined earlier
    _, attention_weights = scaled_dot_product_attention(Q, K, V)
return attention_weights, source_words, target_words
# Generate cross-attention
cross_attn, source_words, target_words = cross_attention_example()
# Visualize cross-attention
plt.figure(figsize=(8, 6))
im = plt.imshow(cross_attn, cmap='Reds', aspect='auto')
plt.colorbar(im, fraction=0.046, pad=0.04, label='Attention Weight')
# Set labels
plt.xticks(range(len(source_words)), source_words, fontsize=11)
plt.yticks(range(len(target_words)), target_words, fontsize=11)
plt.xlabel('Source (Encoder) Tokens', fontsize=12)
plt.ylabel('Target (Decoder) Tokens', fontsize=12)
plt.title('Cross-Attention: Translation Example', fontsize=14, fontweight='bold')
# Add annotations
for i in range(len(target_words)):
for j in range(len(source_words)):
plt.text(j, i, f'{cross_attn[i, j]:.2f}',
ha='center', va='center', color='white' if cross_attn[i, j] > 0.5 else 'black',
fontsize=10)
plt.tight_layout()
plt.show()
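Unlike self-attention, the cross-attention matrix is generally not square. A brief check on the example above shows that rows index decoder tokens, columns index encoder tokens, and each row is still a distribution over the source sequence.
# Shape is (target_len, source_len); each decoder row sums to 1
print("Cross-attention shape:", cross_attn.shape)
print("Rows sum to 1:", np.allclose(cross_attn.sum(axis=1), 1.0))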
8. Attention Head Specialization
Different attention heads learn to focus on different types of relationships. Research has shown that heads often specialize for specific linguistic tasks without explicit supervision.
What to observe:
- Previous Word: Captures sequential/local dependencies
- Determiners: Focuses on articles and determiners (linguistic structure)
- Nouns: Attends to content words (semantic information)
- Broad: Maintains global context
Significance: Head specialization emerges naturally during training. Some heads learn syntactic patterns (e.g., subject-verb agreement), while others capture semantic relationships or positional patterns. This specialization allows transformers to handle multiple aspects of language simultaneously.
def create_specialized_heads():
"""
Simulate how different attention heads might specialize.
"""
sentence = "The cat sat on the mat near the door"
words = sentence.split()
seq_len = len(words)
# Define specialized attention patterns
heads = {}
# Head 1: Attending to previous word (local)
head1 = np.zeros((seq_len, seq_len))
for i in range(seq_len):
if i > 0:
head1[i, i-1] = 0.7
head1[i, i] = 0.3
heads['Previous Word'] = head1
# Head 2: Attending to determiners
head2 = np.ones((seq_len, seq_len)) * 0.05
det_indices = [0, 4, 7] # "The" positions
for idx in det_indices:
head2[:, idx] = 0.3
heads['Determiners'] = head2 / head2.sum(axis=1, keepdims=True)
# Head 3: Attending to nouns
head3 = np.ones((seq_len, seq_len)) * 0.05
noun_indices = [1, 5, 8] # "cat", "mat", "door"
for idx in noun_indices:
head3[:, idx] = 0.3
heads['Nouns'] = head3 / head3.sum(axis=1, keepdims=True)
# Head 4: Broad attention
head4 = np.ones((seq_len, seq_len))
heads['Broad'] = head4 / head4.sum(axis=1, keepdims=True)
return heads, words
# Create specialized heads
specialized_heads, words = create_specialized_heads()
# Visualize
fig, axes = plt.subplots(2, 2, figsize=(14, 12))
axes = axes.flatten()
for idx, (head_name, attention) in enumerate(specialized_heads.items()):
im = axes[idx].imshow(attention, cmap='Purples', aspect='auto', vmin=0, vmax=0.4)
axes[idx].set_title(f'Head Specialization: {head_name}', fontsize=12, fontweight='bold')
axes[idx].set_xticks(range(len(words)))
axes[idx].set_yticks(range(len(words)))
axes[idx].set_xticklabels(words, rotation=45, ha='right', fontsize=9)
axes[idx].set_yticklabels(words, fontsize=9)
axes[idx].set_xlabel('Attending To', fontsize=10)
axes[idx].set_ylabel('Query Position', fontsize=10)
plt.colorbar(im, ax=axes[idx], fraction=0.046, pad=0.04)
plt.suptitle('Attention Head Specialization Patterns', fontsize=14, fontweight='bold')
plt.tight_layout()
plt.show()
9. Attention Distance Analysis
This analysis explores how attention weights vary with the distance between positions, revealing different types of locality biases that can emerge or be designed into attention mechanisms.
What to observe:
- Linear Decay: Attention decreases linearly with distance
- Exponential Decay: Rapid falloff creates strong local focus
- Gaussian: Smooth, bell-curve distribution around each position
- The graph shows how average attention changes with distance
Significance: Understanding distance-based attention patterns helps in designing efficient transformers. Many efficient variants restrict attention by distance (e.g., Longformer's sliding-window attention) or otherwise approximate the full attention matrix (e.g., Linformer's low-rank projections), reducing computational complexity while maintaining performance.
def analyze_attention_distance():
"""
Analyze how attention varies with distance between positions.
"""
seq_len = 20
# Create different attention patterns based on distance
patterns = {}
# Linear decay
linear = np.zeros((seq_len, seq_len))
for i in range(seq_len):
for j in range(seq_len):
dist = abs(i - j)
linear[i, j] = max(0, 1 - dist / 10)
patterns['Linear Decay'] = linear / linear.sum(axis=1, keepdims=True)
# Exponential decay
exponential = np.zeros((seq_len, seq_len))
for i in range(seq_len):
for j in range(seq_len):
dist = abs(i - j)
exponential[i, j] = np.exp(-dist / 3)
patterns['Exponential Decay'] = exponential / exponential.sum(axis=1, keepdims=True)
# Gaussian
gaussian = np.zeros((seq_len, seq_len))
for i in range(seq_len):
for j in range(seq_len):
dist = abs(i - j)
gaussian[i, j] = np.exp(-(dist ** 2) / (2 * 3 ** 2))
patterns['Gaussian'] = gaussian / gaussian.sum(axis=1, keepdims=True)
return patterns
# Generate patterns
distance_patterns = analyze_attention_distance()
# Visualize
fig, axes = plt.subplots(1, 3, figsize=(15, 5))
for idx, (name, pattern) in enumerate(distance_patterns.items()):
im = axes[idx].imshow(pattern, cmap='YlOrRd', aspect='auto')
axes[idx].set_title(f'{name}', fontsize=12, fontweight='bold')
axes[idx].set_xlabel('Position', fontsize=10)
axes[idx].set_ylabel('Position', fontsize=10)
plt.colorbar(im, ax=axes[idx], fraction=0.046, pad=0.04)
plt.suptitle('Attention Patterns Based on Position Distance', fontsize=14, fontweight='bold')
plt.tight_layout()
plt.show()
# Plot distance vs attention weight
plt.figure(figsize=(10, 6))
colors = ['blue', 'red', 'green']
for (name, pattern), color in zip(distance_patterns.items(), colors):
# Get attention weights by distance
distances = []
weights = []
for i in range(len(pattern)):
for j in range(len(pattern)):
distances.append(abs(i - j))
weights.append(pattern[i, j])
# Average weights by distance
unique_distances = sorted(set(distances))
avg_weights = []
for d in unique_distances:
d_weights = [w for dist, w in zip(distances, weights) if dist == d]
avg_weights.append(np.mean(d_weights))
plt.plot(unique_distances, avg_weights, marker='o', label=name, color=color, linewidth=2)
plt.xlabel('Distance Between Positions', fontsize=12)
plt.ylabel('Average Attention Weight', fontsize=12)
plt.title('Attention Weight vs Position Distance', fontsize=14, fontweight='bold')
plt.legend(fontsize=10)
plt.grid(True, alpha=0.3)
plt.tight_layout()
plt.show()
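The distance-averaging loop above can also be written in a few vectorized lines. The sketch below uses np.bincount (under the assumption that every distance from 0 to seq_len-1 occurs at least once, which holds for these square patterns) and should reproduce the same averages.
# Average attention weight per |i - j|, vectorized with bincount
for name, pattern in distance_patterns.items():
    n = pattern.shape[0]
    dist = np.abs(np.arange(n)[:, None] - np.arange(n)[None, :]).ravel()
    avg_by_dist = np.bincount(dist, weights=pattern.ravel()) / np.bincount(dist)
    print(f"{name}: avg weight at distances 0-4 -> {np.round(avg_by_dist[:5], 3)}")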
10. Layer-wise Attention Evolution
Attention patterns change systematically across transformer layers, with different layers capturing different types of information. This evolution from local to global patterns is key to the model's hierarchical processing.
What to observe:
- Early layers (1-2): More local, position-based attention
- Middle layers (3-4): Syntactic patterns emerge (grammatical relationships)
- Later layers (5-6): Semantic and task-specific patterns dominate
- Notice how patterns become more specialized and task-oriented in deeper layers
Significance: The layer-wise evolution shows how transformers build up representations hierarchically: from local patterns to syntax to semantics. This mirrors findings in neuroscience about hierarchical processing in the brain.
def simulate_layer_attention():
"""
Simulate how attention patterns might evolve across transformer layers.
"""
sentence = "The quick brown fox jumps"
words = sentence.split()
seq_len = len(words)
n_layers = 6
# Initialize random embeddings
embeddings = np.random.randn(seq_len, d_model)
layer_attentions = []
for layer in range(n_layers):
# Simulate different behavior in different layers
if layer < 2:
# Early layers: more local attention
noise_level = 0.3
attention = np.eye(seq_len) * 0.5
for i in range(seq_len):
for j in range(max(0, i-1), min(seq_len, i+2)):
attention[i, j] += 0.2
elif layer < 4:
# Middle layers: syntactic patterns
noise_level = 0.2
attention = np.ones((seq_len, seq_len)) * 0.1
# Simulate grammatical dependencies
attention[2, 3] += 0.3 # "brown" -> "fox"
attention[3, 4] += 0.3 # "fox" -> "jumps"
attention[1, 3] += 0.2 # "quick" -> "fox"
else:
# Late layers: semantic/global patterns
noise_level = 0.1
attention = np.ones((seq_len, seq_len)) * 0.15
attention[4, 3] += 0.3 # "jumps" -> "fox"
attention[:, 0] += 0.1 # Global attention to "The"
# Add noise and normalize
attention += np.random.rand(seq_len, seq_len) * noise_level
attention = attention / attention.sum(axis=1, keepdims=True)
layer_attentions.append(attention)
return layer_attentions, words
# Generate layer-wise attention
layer_attentions, words = simulate_layer_attention()
# Visualize
fig, axes = plt.subplots(2, 3, figsize=(15, 10))
axes = axes.flatten()
for layer_idx, attention in enumerate(layer_attentions):
im = axes[layer_idx].imshow(attention, cmap='Blues', aspect='auto', vmin=0, vmax=0.5)
axes[layer_idx].set_title(f'Layer {layer_idx + 1}', fontsize=12, fontweight='bold')
axes[layer_idx].set_xticks(range(len(words)))
axes[layer_idx].set_yticks(range(len(words)))
axes[layer_idx].set_xticklabels(words, rotation=45, ha='right')
axes[layer_idx].set_yticklabels(words)
if layer_idx >= 3:
axes[layer_idx].set_xlabel('Key', fontsize=10)
if layer_idx % 3 == 0:
axes[layer_idx].set_ylabel('Query', fontsize=10)
plt.suptitle('Attention Pattern Evolution Across Transformer Layers', fontsize=14, fontweight='bold')
plt.tight_layout()
plt.show()
11. Attention Entropy Analysis
Entropy measures the "focus" of attention distributions. Low entropy means concentrated attention (focusing on few positions), while high entropy means distributed attention (attending broadly).
What to observe:
- Highly Focused: Low entropy - attention concentrated on single positions
- Moderately Focused: Medium entropy - attention on a few key positions
- Distributed: Maximum entropy - uniform attention across all positions
- Mixed: Variable entropy - different focusing strategies per position
Significance: Entropy analysis helps understand model behavior and can be used for model interpretation, pruning decisions, and understanding which parts of the input are most important for predictions.
def compute_attention_entropy(attention_weights):
"""
Compute entropy of attention distribution.
Higher entropy = more distributed attention
Lower entropy = more focused attention
"""
# Add small epsilon to avoid log(0)
eps = 1e-10
entropy = -np.sum(attention_weights * np.log(attention_weights + eps), axis=-1)
return entropy
# Create different attention patterns
seq_len = 10
patterns = {
'Highly Focused': np.eye(seq_len), # Attention only on diagonal
'Moderately Focused': np.eye(seq_len) * 0.7 + np.ones((seq_len, seq_len)) * 0.03,
'Distributed': np.ones((seq_len, seq_len)) / seq_len, # Uniform attention
'Mixed': np.random.dirichlet(np.ones(seq_len), size=seq_len) # Random distribution
}
# Normalize patterns
for name in patterns:
patterns[name] = patterns[name] / patterns[name].sum(axis=1, keepdims=True)
# Compute entropy for each pattern
fig, axes = plt.subplots(2, 4, figsize=(18, 8))
for idx, (name, pattern) in enumerate(patterns.items()):
# Compute entropy
entropy = compute_attention_entropy(pattern)
# Plot attention pattern
axes[0, idx].imshow(pattern, cmap='Blues', aspect='auto')
axes[0, idx].set_title(f'{name}', fontsize=11, fontweight='bold')
axes[0, idx].set_xlabel('Key Position', fontsize=9)
axes[0, idx].set_ylabel('Query Position', fontsize=9)
# Plot entropy
axes[1, idx].bar(range(seq_len), entropy, color='coral')
axes[1, idx].set_title(f'Entropy (avg: {np.mean(entropy):.2f})', fontsize=11)
axes[1, idx].set_xlabel('Position', fontsize=9)
axes[1, idx].set_ylabel('Entropy', fontsize=9)
axes[1, idx].set_ylim([0, np.log(seq_len)])
axes[1, idx].grid(True, alpha=0.3)
plt.suptitle('Attention Entropy Analysis: Measuring Focus vs Distribution',
fontsize=14, fontweight='bold')
plt.tight_layout()
plt.show()
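The entropy values are bounded, which is why the bar charts above use ln(seq_len) as the y-limit; a quick check with compute_attention_entropy confirms the two extremes.
# A one-hot row has (approximately) zero entropy; a uniform row hits ln(seq_len)
one_hot = np.eye(seq_len)[:1]
uniform = np.ones((1, seq_len)) / seq_len
print("One-hot entropy:", compute_attention_entropy(one_hot)[0])
print("Uniform entropy:", compute_attention_entropy(uniform)[0], "vs ln(n) =", np.log(seq_len))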
12. Attention Flow Visualization
This visualization shows how information flows through the attention mechanism, making the abstract concept of attention more intuitive by showing it as directed connections between words.
What to observe:
- The heatmap shows traditional attention weights
- The flow diagram visualizes the same information as arrows
- Arrow thickness represents attention strength
- Only strong connections (above threshold) are shown for clarity
- Notice how grammatically related words have stronger connections
Significance: Understanding attention as information flow helps explain how transformers process sequences. Strong attention connections indicate information pathways - which words influence the representation of other words. This visualization is particularly useful for debugging and interpreting model behavior.
def visualize_attention_flow():
"""
Visualize how information flows through attention.
"""
# Simple sentence for visualization
sentence = "Dogs love to play fetch"
words = sentence.split()
seq_len = len(words)
# Create attention pattern with clear dependencies
attention = np.ones((seq_len, seq_len)) * 0.05
# Add specific attention patterns
attention[0, 0] = 0.4 # "Dogs" to self
attention[1, 0] = 0.5 # "love" to "Dogs"
attention[2, 1] = 0.3 # "to" to "love"
attention[3, 1] = 0.4 # "play" to "love"
attention[3, 2] = 0.3 # "play" to "to"
attention[4, 3] = 0.6 # "fetch" to "play"
attention[4, 0] = 0.2 # "fetch" to "Dogs"
# Normalize
attention = attention / attention.sum(axis=1, keepdims=True)
# Create flow visualization
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(16, 6))
# 1. Traditional heatmap
im = ax1.imshow(attention, cmap='Purples', aspect='auto')
ax1.set_xticks(range(seq_len))
ax1.set_yticks(range(seq_len))
ax1.set_xticklabels(words)
ax1.set_yticklabels(words)
ax1.set_xlabel('Attending To', fontsize=11)
ax1.set_ylabel('Query', fontsize=11)
ax1.set_title('Attention Weights Matrix', fontsize=12, fontweight='bold')
plt.colorbar(im, ax=ax1, fraction=0.046, pad=0.04)
# 2. Flow diagram
ax2.set_xlim(-1, seq_len)
ax2.set_ylim(-0.5, 1.5)
# Draw words
y_positions = [0, 1] * (seq_len // 2 + 1)
for i, word in enumerate(words):
y = y_positions[i]
ax2.text(i, y, word, ha='center', va='center',
bbox=dict(boxstyle='round,pad=0.3', facecolor='lightblue', alpha=0.7),
fontsize=11, fontweight='bold')
# Draw attention connections
threshold = 0.15 # Only show strong connections
for i in range(seq_len):
for j in range(seq_len):
if attention[i, j] > threshold and i != j:
y_from = y_positions[j]
y_to = y_positions[i]
# Draw arrow with clamped alpha
alpha_val = min(1.0, attention[i, j] * 2) # Clamp alpha to [0, 1]
ax2.annotate('', xy=(i, y_to), xytext=(j, y_from),
arrowprops=dict(arrowstyle='->', lw=attention[i, j] * 5,
color='purple', alpha=alpha_val))
ax2.set_title('Attention Flow Diagram', fontsize=12, fontweight='bold')
ax2.axis('off')
plt.suptitle('Information Flow Through Attention', fontsize=14, fontweight='bold')
plt.tight_layout()
plt.show()
visualize_attention_flow()
Summary
This notebook demonstrated key concepts in Transformer attention mechanisms:
- Self-Attention: How each position attends to all other positions
- Attention Patterns: Various types (uniform, causal, local, etc.)
- Multi-Head Attention: Multiple attention mechanisms learning different patterns
- Positional Encoding: How position information is incorporated
- Masked Attention: Causal masking for autoregressive generation
- Cross-Attention: Attention between encoder and decoder
- Head Specialization: How different heads learn different functions
- Distance Analysis: How attention varies with position distance
- Layer Evolution: How attention patterns change across layers
- Entropy Analysis: Measuring attention focus vs distribution
- Attention Flow: Visualizing information flow through the network
These visualizations help understand how Transformers process and relate information across sequences, which is fundamental to their success in NLP and other domains.