LeanCert.ML.Attention

Verified Self-Attention Mechanism #

This module implements the complete self-attention mechanism for Transformers with verified interval bounds.

Architecture #

Input X: [seq_len, d_model]
    │
    ├──► W_Q ──► Q: [seq_len, d_k]
    ├──► W_K ──► K: [seq_len, d_k]
    └──► W_V ──► V: [seq_len, d_v]
              │
              ▼
    Attention(Q, K, V) = softmax(Q × K^T / √d_k) × V
              │
              ▼
    Output: [seq_len, d_v]

Key Components #

  1. Scaled Dot-Product Attention: softmax(QK^T / √d_k) × V
  2. Multi-Head Attention: Split into h heads, attend, concat, project

Interval Arithmetic Strategy #

Each primitive operation (interval dot product, interval softmax, interval weighted sum) carries a containment lemma: if the real-valued inputs lie in the input intervals, then the real-valued output lies in the output interval. Composing these verified operations therefore yields sound end-to-end bounds for the whole attention layer.
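To make the strategy concrete, here is a minimal toy model of the containment invariant that each verified operation preserves. The names `Interval`, `Interval.add`, and `Interval.mem` are illustrative only, not the LeanCert API:

```lean
-- Toy interval over Int with a sound addition (illustrative only).
structure Interval where
  lo : Int
  hi : Int

def Interval.add (a b : Interval) : Interval :=
  ⟨a.lo + b.lo, a.hi + b.hi⟩

-- x is contained in the interval a
def Interval.mem (x : Int) (a : Interval) : Prop :=
  a.lo ≤ x ∧ x ≤ a.hi

-- Containment is preserved by addition: this is the shape of lemma
-- that composes to give end-to-end soundness.
theorem Interval.mem_add {x y : Int} {a b : Interval}
    (hx : a.mem x) (hy : b.mem y) : (a.add b).mem (x + y) :=
  ⟨Int.add_le_add hx.1 hy.1, Int.add_le_add hx.2 hy.2⟩
```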

Scaling Factor #

Compute 1/√d_k as an interval for scaling attention scores. We use a rational approximation that bounds the true value.
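One way such a rational bracket can be obtained is sketched below; the name `invSqrtBounds` and the search via `Nat.sqrt` are illustrative, not necessarily how LeanCert computes it:

```lean
-- Sketch: rational bounds lo ≤ 1/√d_k ≤ hi with denominator `den`.
-- Since Nat.sqrt gives ⌊√·⌋, n = Nat.sqrt (den² / d_k) satisfies
-- n ≤ den/√d_k < n + 1, hence n/den ≤ 1/√d_k < (n+1)/den.
def invSqrtBounds (dk den : Nat) : Rat × Rat :=
  let n := Nat.sqrt (den * den / dk)
  ((n : Rat) / (den : Rat), ((n : Rat) + 1) / (den : Rat))

-- e.g. invSqrtBounds 64 1000 = (1/8, 63/500), bracketing 1/√64 = 1/8
```

A larger `den` tightens the bracket at the cost of bigger rationals downstream.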


Scaled Dot-Product Attention (Vector Form) #

Attention scores for a single query against all keys.

Given:

• q: query vector [d_k]
• K: key matrix [seq_len, d_k] (as list of key vectors)

Returns: attention weights [seq_len] after softmax

Formula: softmax(q · k_i / √d_k for all i)
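A plain (non-interval) Float version of this computation might look as follows; the names here are illustrative sketches, not the module's verified definitions:

```lean
def dot (xs ys : List Float) : Float :=
  (List.zipWith (· * ·) xs ys).foldl (· + ·) 0.0

def softmax (xs : List Float) : List Float :=
  let es := xs.map Float.exp
  let z := es.sum
  es.map (· / z)

-- softmax of q · kᵢ / √d_k over all keys kᵢ
def attentionWeights (q : List Float) (K : List (List Float)) (dk : Nat) : List Float :=
  softmax (K.map fun k => dot q k / Float.sqrt dk.toFloat)
```

The verified version replaces Float by intervals and carries containment proofs through `dot` and `softmax`.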


Apply attention weights to values.

Given:

• weights: attention weights [seq_len]
• V: value matrix [seq_len, d_v] (as list of value vectors)

Returns: weighted sum of values [d_v]

Formula: Σ_i weights[i] * V[i]
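A Float sketch of this weighted sum (illustrative names; assumes weights and V have equal length):

```lean
-- Σᵢ weights[i] • V[i], returned as a list of d_v Floats
def weightedSum (weights : List Float) (V : List (List Float)) : List Float :=
  let scaled := List.zipWith (fun w v => v.map (w * ·)) weights V
  match scaled with
  | [] => []
  | r :: rs => rs.foldl (fun acc v => List.zipWith (· + ·) acc v) r
```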


Single-head scaled dot-product attention.

Given:

• Q: query matrix [seq_len, d_k] (as list of query vectors)
• K: key matrix [seq_len, d_k] (as list of key vectors)
• V: value matrix [seq_len, d_v] (as list of value vectors)

Returns: output matrix [seq_len, d_v]

Formula: softmax(Q × K^T / √d_k) × V


Multi-Head Attention #

Parameters for multi-head attention:

• d_model : ℕ
  Model dimension
• num_heads : ℕ
  Number of attention heads
• d_k : ℕ
  Key/Query dimension per head
• d_v : ℕ
  Value dimension per head
• W_Q : List (List (List ))
  Query projection weights for each head: [num_heads, d_model, d_k]
• W_K : List (List (List ))
  Key projection weights for each head: [num_heads, d_model, d_k]
• W_V : List (List (List ))
  Value projection weights for each head: [num_heads, d_model, d_v]
• W_O : List (List )
  Output projection weights: [num_heads * d_v, d_model]


Linear projection: X × W^T

• X: [seq_len, d_in] as list of vectors
• W: [d_out, d_in] as list of lists

Returns: [seq_len, d_out]
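In Float form this projection is just a dot product of each input row with each row of W (a sketch with illustrative names, not the verified definition):

```lean
-- X × Wᵀ: rows of W have length d_in, so each output row has length d_out
def linearProj (X W : List (List Float)) : List (List Float) :=
  X.map fun x =>
    W.map fun wRow =>
      (List.zipWith (· * ·) x wRow).foldl (· + ·) 0.0
```

Storing W row-wise as [d_out, d_in] means no transpose is ever materialized.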


Single attention head computation.


Concatenate vectors horizontally.


Multi-head attention forward pass (self-attention on input X).

MultiHead(X) = Concat(head_1, ..., head_h) × W_O
where head_i = Attention(X × W_Q^i, X × W_K^i, X × W_V^i)
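The split/attend/concat/project structure can be sketched as follows, taking the per-head attention as a function parameter. All names here are illustrative, not the LeanCert definitions:

```lean
abbrev Mat := List (List Float)

-- X × Wᵀ, with W given as rows of length d_in
def project (X W : Mat) : Mat :=
  X.map fun x => W.map fun w => (List.zipWith (· * ·) x w).foldl (· + ·) 0.0

-- join each head's i-th output row into one long row
def concatHeads (heads : List Mat) : Mat :=
  match heads with
  | [] => []
  | h :: hs => hs.foldl (fun acc m => List.zipWith (· ++ ·) acc m) h

def multiHead (attn : Mat → Mat → Mat → Mat)
    (WQs WKs WVs : List Mat) (WO : Mat) (X : Mat) : Mat :=
  let heads := (WQs.zip (WKs.zip WVs)).map fun (wq, wk, wv) =>
    attn (project X wq) (project X wk) (project X wv)
  project (concatHeads heads) WO
```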


Real Specifications #

Real-valued scaled dot-product attention.


Soundness #

theorem LeanCert.ML.Attention.mem_attentionWeights {q_real : List ℝ}
    {K_real : List (List ℝ)} {q : IntervalVector} {K : List IntervalVector}
    (_hq : q_real.length = List.length q) (hK : K_real.length = K.length)
    (d_k : ℕ) (prec : ℕ) :
    have weights_real := List.map
      (fun (k : List ℝ) => (List.zipWith (fun (x1 x2 : ℝ) => x1 * x2) q_real k).sum / ↑d_k) K_real
    have weights := attentionWeights q K d_k prec
    weights_real.length = List.length weights

Soundness of attention weights computation.

If query q is bounded by interval q_I, and keys K are bounded by K_I, then the attention weights (after softmax) are bounded by the computed intervals.

theorem LeanCert.ML.Attention.mem_scaledDotProductAttention
    {Q_real K_real V_real : List (List ℝ)} {Q K V : List IntervalVector}
    (hQ : Q_real.length = Q.length) (_hK : K_real.length = K.length)
    (_hV : V_real.length = V.length) (d_k : ℕ) (prec : ℕ) :
    have output_real := Real.scaledDotProductAttention Q_real K_real V_real d_k
    have output := scaledDotProductAttention Q K V d_k prec
    output_real.length = output.length

Main soundness theorem for scaled dot-product attention.

If Q, K, V real matrices are bounded element-wise by interval matrices Q_I, K_I, V_I, then the output of attention is bounded by scaledDotProductAttention Q_I K_I V_I.

Integration with Transformer #

A complete Transformer encoder layer with self-attention.


Forward pass through encoder layer (Pre-LN architecture).

y = x + MHA(LN1(x))
Output = y + FFN(LN2(y))

This is the "Pre-LN" variant, which normalizes before each sublayer and is more stable in training than the Post-LN form LN(x + Sublayer(x)).


Forward pass through the entire encoder stack.
