Attention Mechanisms

Overview

In the context of neural networks, particularly for sequence processing tasks, an attention mechanism allows a model to dynamically focus on different parts of an input sequence when producing an output. Instead of compressing an entire input sequence into a single fixed-length context vector (as in earlier encoder-decoder architectures), attention allows the model to selectively look at relevant parts of the source sequence at each step of generating the target sequence.

The core idea is to compute a set of attention weights for the input sequence elements. These weights determine how much importance or "attention" each element should receive. The output is then typically a weighted sum of the input elements based on these attention weights.

This mechanism was a significant breakthrough, especially for tasks like machine translation, as it helped models handle long sequences more effectively and align relevant parts of the source and target sentences.

Note: This content provides a deep dive into attention mechanisms. For practical implementations:

  • Base Transformer Architecture and Implementation: api/content/deep_learning/architectures/transformers.py
  • LLM-specific Transformer Details: api/content/modern_ai/llms/transformer_architecture.py

The content here focuses on the theoretical foundations and variations of attention mechanisms, while the transformer-specific files provide concrete implementations and architectural details.

Core Concepts

  • General Attention Framework (Query, Key, Value)

    Most attention mechanisms can be described in terms of three components:

    • Queries (Q): A set of vectors representing what information the current output step is looking for or is interested in. In a sequence-to-sequence model, a query might be the hidden state of the decoder at the current time step.
    • Keys (K): A set of vectors associated with the input sequence elements. Each key corresponds to a piece of information that might be relevant to a query. In an encoder-decoder model, keys are often derived from the encoder hidden states.
    • Values (V): A set of vectors also associated with the input sequence elements. Each value contains the actual content or representation that should be retrieved if its corresponding key is deemed relevant by a query. Values are also often derived from encoder hidden states.

    The attention process involves:

    1. Calculating Compatibility Scores: For each query, a compatibility score is computed with every key. This score indicates how well the query matches each key. Common scoring functions include dot product, scaled dot product, or a small feed-forward network (additive/Bahdanau attention).
    2. Converting Scores to Weights: The compatibility scores are typically passed through a softmax function to convert them into attention weights. These weights are positive and sum to 1, representing a probability distribution over the input elements.
    3. Computing the Context Vector: The final output of the attention mechanism (the context vector) is a weighted sum of the value vectors, using the attention weights.

    $$\alpha_i = \frac{\exp\big(\text{score}(Q, K_i)\big)}{\sum_{j} \exp\big(\text{score}(Q, K_j)\big)}, \qquad \text{Attention}(Q, K, V) = \sum_{i} \alpha_i V_i$$

    The Scaled Dot-Product Attention, prominent in Transformers, uses: $$\text{Attention}(Q, K, V) = \text{softmax}(\frac{QK^T}{\sqrt{d_k}})V$$
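
    As a minimal sketch of these three steps (scoring, softmax, weighted sum), the snippet below computes attention for a single query over a handful of key/value vectors using a plain dot-product score; all dimensions and tensors are illustrative, not tied to any particular model.

    import torch
    import torch.nn.functional as F

    # Toy example: a single query attending over four key/value pairs (dot-product score)
    d = 8                                  # embedding dimension (illustrative)
    q = torch.randn(d)                     # query vector
    K = torch.randn(4, d)                  # four key vectors
    V = torch.randn(4, d)                  # four value vectors

    scores = K @ q                         # step 1: compatibility scores, shape (4,)
    weights = F.softmax(scores, dim=-1)    # step 2: attention weights, positive and summing to 1
    context = weights @ V                  # step 3: context vector = weighted sum of values, shape (d,)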

  • Additive Attention (Bahdanau Attention)

    Introduced by Bahdanau et al. (2014) for machine translation, this type of attention uses a feed-forward network to compute the compatibility score between the query (decoder hidden state \(h_t\)) and keys (encoder hidden states \(s_j\)).

    The alignment score \(e_{tj}\) is calculated as:

    $$e_{tj} = v_a^T \tanh(W_a h_t + U_a s_j)$$

    Where \(W_a\) and \(U_a\) are learnable weight matrices and \(v_a\) is a learnable weight vector. The attention weights \(\alpha_{tj}\) are then computed by applying softmax to these scores over all source positions \(j\). This mechanism allows the decoder to look at different parts of the source sentence at each step of translation.

    It's called 'additive' because the query and key are combined additively within the tanh function before the final projection by \(v_a\).
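
    A minimal sketch of this additive scoring, with hypothetical dimensions and freshly initialized (untrained) weights, might look as follows.

    import torch
    import torch.nn as nn

    # Illustrative sizes, not taken from the original paper
    dec_dim, enc_dim, attn_dim, src_len = 16, 16, 32, 10

    W_a = nn.Linear(dec_dim, attn_dim, bias=False)   # projects the decoder state h_t
    U_a = nn.Linear(enc_dim, attn_dim, bias=False)   # projects the encoder states s_j
    v_a = nn.Linear(attn_dim, 1, bias=False)         # final projection to a scalar score

    h_t = torch.randn(1, dec_dim)          # current decoder hidden state
    s = torch.randn(src_len, enc_dim)      # encoder hidden states s_1 .. s_J

    # e_tj = v_a^T tanh(W_a h_t + U_a s_j), computed for all j at once via broadcasting
    e_t = v_a(torch.tanh(W_a(h_t) + U_a(s))).squeeze(-1)   # shape (src_len,)
    alpha_t = torch.softmax(e_t, dim=-1)                   # attention weights over the source
    context = alpha_t @ s                                  # context vector, shape (enc_dim,)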

  • Dot-Product Attention (Luong Attention)

    Proposed by Luong et al. (2015), this is a simpler and often more computationally efficient attention mechanism. The compatibility score is calculated using the dot product between the query (decoder hidden state \(h_t\)) and keys (encoder hidden states \(s_j\)).

    Several variations of dot-product scoring exist:

    • General: $$\text{score}(h_t, s_j) = h_t^T W_a s_j$$ (with a learnable weight matrix \(W_a\))
    • Dot (Multiplicative): $$\text{score}(h_t, s_j) = h_t^T s_j$$ (if query and key have the same dimensionality)

    Luong attention is typically faster and uses less memory than Bahdanau attention due to the simpler scoring function.
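
    Both scoring variants can be sketched in a few lines; the dimensions and the random matrix standing in for \(W_a\) are illustrative only.

    import torch

    dim, src_len = 16, 10
    h_t = torch.randn(dim)                 # decoder hidden state (query)
    s = torch.randn(src_len, dim)          # encoder hidden states (keys and values)
    W_a = torch.randn(dim, dim)            # stand-in for the learnable matrix in the "general" score

    scores_dot = s @ h_t                   # dot score:     h_t^T s_j for every source position j
    scores_general = h_t @ W_a @ s.T       # general score: h_t^T W_a s_j for every source position j

    weights = torch.softmax(scores_dot, dim=-1)
    context = weights @ s                  # context vector: weighted sum of encoder states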

  • Scaled Dot-Product Attention (Transformer)

    This is the attention mechanism used in the Transformer architecture. It's a specific form of dot-product attention where the dot products between queries (Q) and keys (K) are scaled by the square root of the dimension of the keys (\(d_k\)).

    $$\text{Attention}(Q, K, V) = \text{softmax}(\frac{QK^T}{\sqrt{d_k}})V$$

    Reason for Scaling: If the components of the queries and keys are independent with zero mean and unit variance, their dot product has variance \(d_k\), so for large \(d_k\) the dot products grow large in magnitude and push the softmax function into regions where it has extremely small gradients. Dividing by \(\sqrt{d_k}\) keeps the scores at unit scale, leading to more stable training; a quick numerical check follows below.

    This mechanism is highly parallelizable as it involves matrix multiplications, making it well-suited for modern hardware (GPUs/TPUs).
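
    The scaling argument can be checked with a quick, illustrative experiment (sizes chosen arbitrarily): unscaled dot products of unit-variance vectors have a standard deviation of roughly \(\sqrt{d_k}\), and the resulting softmax puts almost all of its mass on one position.

    import torch

    torch.manual_seed(0)
    d_k, n = 512, 16
    q = torch.randn(d_k)        # one query with unit-variance components
    K = torch.randn(n, d_k)     # n keys with unit-variance components

    scores = K @ q
    print(scores.std())                                       # roughly sqrt(512) ≈ 22.6
    print(torch.softmax(scores, dim=-1).max())                # unscaled: close to 1.0 (saturated)
    print(torch.softmax(scores / d_k ** 0.5, dim=-1).max())   # scaled: much flatter distribution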

  • Self-Attention (Intra-Attention)

    Self-attention is an attention mechanism that relates different positions of a single sequence to compute a representation of the same sequence. In self-attention, the queries, keys, and values all come from the same input sequence (or its representations in previous layers).

    For each token in the input sequence, self-attention computes a weighted sum of all tokens in the sequence (including itself), where the weights indicate the relevance of other tokens to the current token. This allows the model to capture dependencies and context within the input sequence itself.

    For example, it can help resolve pronoun coreferences (e.g., knowing what "it" refers to in a sentence) or understand syntactic relationships. It's a fundamental component of the Transformer architecture.
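
    A minimal single-head self-attention sketch (the dimensions and projection names are illustrative): the queries, keys, and values are all linear projections of the same input sequence X.

    import torch
    import torch.nn as nn

    seq_len, d_model = 6, 32            # illustrative sizes
    X = torch.randn(seq_len, d_model)   # one input sequence (batch dimension omitted for clarity)

    # Q, K, and V are all projections of the same sequence X
    W_q = nn.Linear(d_model, d_model, bias=False)
    W_k = nn.Linear(d_model, d_model, bias=False)
    W_v = nn.Linear(d_model, d_model, bias=False)
    Q, K, V = W_q(X), W_k(X), W_v(X)

    scores = Q @ K.T / d_model ** 0.5         # (seq_len, seq_len) token-to-token scores
    weights = torch.softmax(scores, dim=-1)   # each row is a distribution over all tokens
    output = weights @ V                      # (seq_len, d_model) contextualized token representations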

  • Multi-Head Attention

    Multi-Head Attention, also integral to the Transformer, runs multiple attention mechanisms ("heads") in parallel and then combines their outputs. Instead of performing a single attention function over the queries, keys, and values, the Q, K, and V are linearly projected \(h\) times with different, learned projections. Attention is then applied to each of these projected versions in parallel.

    The outputs of the \(h\) heads are concatenated and once again projected, resulting in the final values.

    $$\text{MultiHead}(Q, K, V) = \text{Concat}(\text{head}_1, ..., \text{head}_h)W^O$$

    where $$\text{head}_i = \text{Attention}(QW_i^Q, KW_i^K, VW_i^V)$$

    This allows the model to jointly attend to information from different representation subspaces at different positions. For instance, different heads might learn to capture different types of relationships (e.g., syntactic, semantic, positional).
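
    A compact module following these equations is sketched below; it is purely illustrative (PyTorch ships a production version as nn.MultiheadAttention), and the layer names are hypothetical.

    import torch
    import torch.nn as nn

    class MultiHeadAttention(nn.Module):
        """Minimal multi-head attention: project, split into heads, attend, concatenate, project."""
        def __init__(self, d_model: int, num_heads: int):
            super().__init__()
            assert d_model % num_heads == 0
            self.num_heads = num_heads
            self.d_k = d_model // num_heads
            self.W_q = nn.Linear(d_model, d_model)
            self.W_k = nn.Linear(d_model, d_model)
            self.W_v = nn.Linear(d_model, d_model)
            self.W_o = nn.Linear(d_model, d_model)   # output projection W^O

        def forward(self, query, key, value):
            batch, seq_q, _ = query.shape
            seq_k = key.shape[1]
            # Project, then reshape to (batch, num_heads, seq, d_k)
            Q = self.W_q(query).view(batch, seq_q, self.num_heads, self.d_k).transpose(1, 2)
            K = self.W_k(key).view(batch, seq_k, self.num_heads, self.d_k).transpose(1, 2)
            V = self.W_v(value).view(batch, seq_k, self.num_heads, self.d_k).transpose(1, 2)
            # Scaled dot-product attention within each head
            scores = Q @ K.transpose(-2, -1) / self.d_k ** 0.5
            weights = torch.softmax(scores, dim=-1)
            heads = weights @ V                                  # (batch, num_heads, seq_q, d_k)
            # Concatenate the heads and apply the output projection
            concat = heads.transpose(1, 2).reshape(batch, seq_q, -1)
            return self.W_o(concat)

    # Self-attention usage: MultiHeadAttention(d_model=512, num_heads=8)(x, x, x) for x of shape (batch, seq, 512)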

  • Other Variants (Briefly)

    • Local Attention: Attends only to a small window of context tokens around the current position, useful for very long sequences where global attention might be too computationally expensive. A small windowed-mask sketch follows this list.
    • Sparse Attention: Modifies the attention mechanism to attend only to a sparse subset of positions, rather than all positions, to improve efficiency for long sequences (e.g., Longformer, BigBird).
    • Hierarchical Attention: Applies attention at multiple levels of granularity (e.g., word-level then sentence-level attention for document classification).
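
    As an illustrative sketch of the local pattern above (window size chosen arbitrarily), a banded mask can restrict each position to a fixed window of neighbours before the softmax; sparse variants generalize this idea to other masking patterns.

    import torch

    seq_len, window = 8, 2   # each position attends to itself and 2 neighbours on each side
    idx = torch.arange(seq_len)
    # local_mask[i, j] is True when position j lies within the window around position i
    local_mask = (idx[None, :] - idx[:, None]).abs() <= window

    scores = torch.randn(seq_len, seq_len)                    # raw attention scores
    scores = scores.masked_fill(~local_mask, float("-inf"))   # block out-of-window positions
    weights = torch.softmax(scores, dim=-1)                   # each row still sums to 1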

  • Advantages

    • Improved Handling of Long Sequences: Attention helps mitigate the bottleneck problem of fixed-length context vectors in older sequence-to-sequence models, allowing models to effectively process and remember information from long input sequences.
    • Better Contextual Understanding: By selectively focusing on relevant parts of the input, attention mechanisms enable models to create more nuanced and contextually aware representations.
    • Interpretability (to some extent): Attention weights can sometimes be visualized to understand which parts of the input sequence the model is focusing on when producing a particular output. This can offer insights into the model's decision-making process, although interpreting attention weights directly as explanations can be misleading.
    • Alignment in Sequence-to-Sequence Tasks: In tasks like machine translation, attention provides a natural way to align words or phrases in the source and target languages.
    • Parallelization (Self-Attention): Self-attention mechanisms, as used in Transformers, allow for highly parallelizable computation over the sequence, leading to faster training and inference on suitable hardware compared to recurrent models.

Implementation

  • Scaled Dot-Product Attention (Conceptual Python/PyTorch)

    
    import torch
    import torch.nn.functional as F
    import math
    
    def scaled_dot_product_attention(query, key, value, mask=None):
        """Compute scaled dot product attention.
        Args:
            query: Tensor of shape (batch_size, num_heads, seq_len_q, d_k)
            key: Tensor of shape (batch_size, num_heads, seq_len_k, d_k)
            value: Tensor of shape (batch_size, num_heads, seq_len_v, d_v) (seq_len_k == seq_len_v)
            mask: Optional tensor of shape (batch_size, 1, seq_len_q, seq_len_k) or broadcastable.
                  Positions where the mask is 0 (or False) are excluded from attention.
        Returns:
            output: Tensor of shape (batch_size, num_heads, seq_len_q, d_v)
            attn_weights: Tensor of shape (batch_size, num_heads, seq_len_q, seq_len_k)
        """
        d_k = query.size(-1) # Dimension of keys/queries
        
        # MatMul(Q, K^T)
        # (..., seq_len_q, d_k) @ (..., d_k, seq_len_k) -> (..., seq_len_q, seq_len_k)
        scores = torch.matmul(query, key.transpose(-2, -1))
        
        # Scale scores
        scores = scores / math.sqrt(d_k)
        
        # Apply mask (if provided)
        # Positions where mask == 0 are filled with a large negative number so that they
        # receive (near-)zero weight after the softmax
        if mask is not None:
            scores = scores.masked_fill(mask == 0, -1e9)
            # If the mask instead uses True to mark masked positions:
            # scores = scores.masked_fill(mask, -1e9)
            
        # Apply softmax to get attention weights
        # Softmax is applied on the last dimension (seq_len_k)
        attn_weights = F.softmax(scores, dim=-1)
        
        # MatMul(attn_weights, V)
        # (..., seq_len_q, seq_len_k) @ (..., seq_len_k, d_v) -> (..., seq_len_q, d_v)
        output = torch.matmul(attn_weights, value)
        
        return output, attn_weights
    
    # Example Usage (Conceptual):
    # batch_size, num_heads, seq_len_q, seq_len_k, d_k, d_v = 2, 8, 10, 12, 64, 64
    # query = torch.randn(batch_size, num_heads, seq_len_q, d_k)
    # key = torch.randn(batch_size, num_heads, seq_len_k, d_k)
    # value = torch.randn(batch_size, num_heads, seq_len_k, d_v) # seq_len_v must equal seq_len_k
    # output, weights = scaled_dot_product_attention(query, key, value)
    # print("Output shape:", output.shape)
    # print("Attention weights shape:", weights.shape)
                            

Interview Examples

What is the core problem that attention mechanisms solve in sequence-to-sequence models?

Explain the limitations of traditional encoder-decoder models that attention was designed to address.

Explain the Query, Key, and Value concepts in attention.

Describe the roles of Queries, Keys, and Values in the general attention framework.

Scaled Dot-Product Attention vs. Additive Attention: Differences?

Compare Scaled Dot-Product Attention (used in Transformers) with Additive (Bahdanau) Attention.

Practice Questions

1. Explain the multi-head attention mechanism in transformers (Hard)

Hint: Think about why multiple attention heads are better than just one
$$\text{Attention}(Q, K, V) = \text{softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right)V$$

2. How does positional encoding work in transformers? (Medium)

Hint: Consider how transformers need position information since they have no recurrence or convolution

3. Why is layer normalization important in transformer architectures? (Medium)

Hint: Think about training stability and convergence