LLM-specific Transformer Architecture

Overview

This section focuses on how Large Language Models (LLMs) build upon and extend the basic Transformer architecture. For the fundamental Transformer architecture, please refer to api/content/deep_learning/architectures/transformers.py.

LLMs have introduced several key innovations and modifications to the original Transformer architecture to handle the challenges of training and deploying models at unprecedented scales:

Core Concepts

  • Decoder-Only vs Encoder-Decoder

    While the original Transformer used both encoder and decoder components, many modern LLMs use simplified architectures:

    • Decoder-Only (GPT-style): Used for generative tasks and autoregressive language modeling
      • Examples: GPT series, LLaMA, Claude
      • Advantages: Simpler architecture, more focused on generation tasks
      • Typically uses causal masking so that no position can attend to future tokens (a minimal masking sketch follows this list)
    • Encoder-Only (BERT-style): Used for understanding and feature extraction
      • Examples: BERT, RoBERTa
      • Better for tasks requiring bidirectional context
      • Often used for classification, named entity recognition, etc.
    • Encoder-Decoder (T5-style): Used for sequence-to-sequence tasks
      • Examples: T5, BART
      • Most similar to original Transformer
      • Well-suited for translation, summarization, etc.
  • Scaling Techniques

    LLMs have introduced various techniques to handle the challenges of scale:

    • Parallel Training:
      • Pipeline Parallelism: Different layers on different devices
      • Tensor Parallelism: Single operations split across devices
      • Data Parallelism: Different batches on different devices, with gradients synchronized across replicas (see the DDP sketch after this list)
    • Memory Optimization:
      • Gradient Checkpointing: Trade recomputation for activation memory (see the Implementation section below)
      • Mixed Precision Training: Using FP16/BF16 for most operations while keeping FP32 where precision matters (see the AMP sketch after this list)
      • Parameter Sharing: Reducing total parameter count
  • Advanced Attention Mechanisms

    LLMs have developed several variations on the original attention mechanism:

    • Sparse Attention Patterns:
      • Sliding Window Attention: Each token attends only to a local window of neighbors (sketched after this list)
      • Longformer Attention: Global + sliding window
      • Big Bird: Random + global + sliding window
    • Memory-Efficient Attention:
      • Flash Attention: Tiled, IO-aware exact attention that avoids materializing the full score matrix (see the Implementation section below)
      • Sparse Memory: Compressed key/value storage
      • Linear Attention: O(n) complexity variants
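
The causal masking used by decoder-only models can be expressed as a triangular mask over the attention scores. Below is a minimal PyTorch sketch (the function name and shapes are illustrative, not taken from any particular model):

    import math

    import torch

    def causal_attention(q, k, v):
        # q, k, v: (batch_size, num_heads, seq_len, head_dim)
        seq_len, head_dim = q.shape[-2], q.shape[-1]
        scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(head_dim)
        # Boolean mask that is True strictly above the diagonal, i.e. for the
        # "future" positions each query must not attend to
        future = torch.triu(torch.ones(seq_len, seq_len, device=q.device), diagonal=1).bool()
        scores = scores.masked_fill(future, float("-inf"))
        return torch.matmul(torch.softmax(scores, dim=-1), v)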
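
Of the parallelism strategies above, data parallelism is the simplest to sketch with stock PyTorch; pipeline and tensor parallelism are usually provided by frameworks such as Megatron-LM or DeepSpeed. A minimal single-node DistributedDataParallel sketch, assuming a launch via torchrun and using nn.Linear as a stand-in for a real Transformer stack:

    import torch
    import torch.distributed as dist
    import torch.nn as nn
    from torch.nn.parallel import DistributedDataParallel as DDP

    # Assumes `torchrun --nproc_per_node=N train.py`, which sets the
    # environment variables that init_process_group reads
    dist.init_process_group(backend="nccl")
    rank = dist.get_rank()
    torch.cuda.set_device(rank)

    model = nn.Linear(512, 512).cuda()       # placeholder for a Transformer stack
    model = DDP(model, device_ids=[rank])    # every rank holds a full replica

    optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
    x = torch.randn(8, 512, device="cuda")   # in practice each rank loads a different batch

    loss = model(x).pow(2).mean()
    loss.backward()                          # DDP all-reduces gradients across ranks here
    optimizer.step()
    dist.destroy_process_group()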
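
Mixed precision training is commonly handled with PyTorch's automatic mixed precision (AMP) utilities. A minimal single-GPU sketch, again with a placeholder model rather than a real Transformer:

    import torch
    import torch.nn as nn

    model = nn.Linear(512, 512).cuda()          # placeholder model; parameters stay in FP32
    optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
    scaler = torch.cuda.amp.GradScaler()

    x = torch.randn(8, 512, device="cuda")
    target = torch.randn(8, 512, device="cuda")

    for step in range(10):
        optimizer.zero_grad()
        with torch.cuda.amp.autocast():         # selected ops run in FP16/BF16
            loss = nn.functional.mse_loss(model(x), target)
        scaler.scale(loss).backward()           # loss scaling avoids FP16 gradient underflow
        scaler.step(optimizer)                  # unscales gradients, skips the step on inf/NaN
        scaler.update()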
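
Sliding window attention restricts each query to a fixed-size neighborhood of keys. The sketch below applies the pattern as a banded mask over the full score matrix, so it illustrates the idea rather than the memory savings real implementations obtain (the function name and default window size are made up for the example):

    import math

    import torch

    def sliding_window_attention(q, k, v, window_size=4):
        # q, k, v: (batch_size, num_heads, seq_len, head_dim)
        seq_len, head_dim = q.shape[-2], q.shape[-1]
        scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(head_dim)
        # Keep only key positions within `window_size` of each query position
        positions = torch.arange(seq_len, device=q.device)
        distance = (positions.unsqueeze(0) - positions.unsqueeze(1)).abs()
        scores = scores.masked_fill(distance > window_size, float("-inf"))
        return torch.matmul(torch.softmax(scores, dim=-1), v)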

Implementation

  • Code Example

    
    import math

    import torch
    import torch.nn as nn
    import torch.utils.checkpoint

    # Flash Attention-style blocked attention (illustrative online-softmax
    # version; the real FlashAttention is a fused GPU kernel)
    def flash_attention(q, k, v, sm_scale=None):
        # q, k, v: (batch_size, num_heads, seq_len, head_dim)
        # Keys/values are processed in blocks with a running softmax, so the
        # full (seq_len x seq_len) attention matrix is never stored at once.
        # No causal mask is applied in this sketch.
        batch_size, num_heads, seq_len, head_dim = q.shape
        if sm_scale is None:
            sm_scale = 1.0 / math.sqrt(head_dim)

        # Running output, row-wise running max (m) and normalizer (l)
        out = torch.zeros_like(q)
        m = torch.full((batch_size, num_heads, seq_len), float("-inf"),
                       device=q.device, dtype=q.dtype)
        l = torch.zeros(batch_size, num_heads, seq_len,
                        device=q.device, dtype=q.dtype)

        # Process keys/values in blocks to save memory
        block_size = min(256, seq_len)
        for block_start in range(0, seq_len, block_size):
            block_end = min(block_start + block_size, seq_len)

            # Load the current block of keys and values
            k_block = k[:, :, block_start:block_end]
            v_block = v[:, :, block_start:block_end]

            # Attention scores of all queries against this key block
            current_scores = torch.matmul(q, k_block.transpose(-2, -1)) * sm_scale

            # Update the running maximum and rescale the old statistics
            m_new = torch.maximum(m, current_scores.max(dim=-1)[0])
            exp_diff = torch.exp(m - m_new)
            current_exp = torch.exp(current_scores - m_new.unsqueeze(-1))

            # Fold this block into the running normalizer and output
            l = l * exp_diff + current_exp.sum(dim=-1)
            out = out * exp_diff.unsqueeze(-1) + torch.matmul(current_exp, v_block)
            m = m_new

        # Normalize once all blocks have been accumulated
        return out / l.unsqueeze(-1)
    
    # Memory-Efficient Transformer with Gradient Checkpointing
    # (TransformerLayer is assumed to be defined elsewhere, e.g. the standard
    # layer from the transformers module referenced in the Overview)
    class MemoryEfficientTransformer(nn.Module):
        def __init__(self, config):
            super().__init__()
            self.layers = nn.ModuleList([
                TransformerLayer(config) for _ in range(config.num_layers)
            ])

        def forward(self, x, use_checkpointing=True):
            # Wrapper so checkpoint() can re-run the layer during backward
            def create_custom_forward(module):
                def custom_forward(*inputs):
                    return module(*inputs)
                return custom_forward

            if use_checkpointing:
                # Activations are discarded in the forward pass and recomputed
                # during backward, trading extra compute for memory
                for layer in self.layers:
                    x = torch.utils.checkpoint.checkpoint(
                        create_custom_forward(layer),
                        x
                    )
            else:
                for layer in self.layers:
                    x = layer(x)

            return x
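
    # As a quick sanity check, the blocked attention above can be compared with
    # ordinary softmax attention on random tensors (a usage sketch that assumes
    # the code above has already been executed):
    q = torch.randn(2, 4, 512, 64)
    k = torch.randn(2, 4, 512, 64)
    v = torch.randn(2, 4, 512, 64)

    blocked = flash_attention(q, k, v)
    reference = torch.softmax(q @ k.transpose(-2, -1) / math.sqrt(64), dim=-1) @ v
    print(torch.allclose(blocked, reference, atol=1e-5))    # expected: True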
    

Practice Questions

1. How does positional encoding work in transformers? Medium

Hint: Consider how transformers need position information since they have no recurrence or convolution
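
For reference, the sinusoidal scheme from the original Transformer paper:
$$PE_{(pos,\,2i)} = \sin\!\left(\frac{pos}{10000^{2i/d_{\text{model}}}}\right), \qquad PE_{(pos,\,2i+1)} = \cos\!\left(\frac{pos}{10000^{2i/d_{\text{model}}}}\right)$$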

2. Why is layer normalization important in transformer architectures? Medium

Hint: Think about training stability and convergence

3. Explain the multi-head attention mechanism in transformers Hard

Hint: Think about why multiple attention heads are better than just one
$$\text{Attention}(Q, K, V) = \text{softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right)V$$