Documentation

LeanCert.ML.Transformer

Transformer Components #

This module provides verified interval arithmetic implementations of key Transformer components:

  1. GELU - Gaussian Error Linear Unit activation
  2. LayerNorm - Layer Normalization

Design Notes #

GELU #

The GELU activation function is approximated as: GELU(x) ≈ 0.5 * x * (1 + tanh(√(2/π) * (x + 0.044715 * x³)))

We implement this by composing verified interval operations.

LayerNorm #

LayerNorm computes: (x - μ) / √(σ² + ε) * γ + β

Warning: Standard interval arithmetic loses correlation information, which causes significant overestimation in LayerNorm. For example, if x ∈ [0.9, 1.1], then mean(x) ∈ [0.9, 1.1], and the interval for x - mean becomes [-0.2, 0.2], even though for a constant input (where x = mean) the true result is exactly [0, 0].

Affine Arithmetic (tracking symbolic dependencies) would resolve this. The current implementation is sound but may produce loose bounds.
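To make the dependency problem concrete, here is a minimal sketch over a toy interval type. The names `Ivl` and `Ivl.sub` are hypothetical, for illustration only; they are not LeanCert's `IntervalRat`/`IntervalDyadic` API. The sketch assumes Mathlib's ℚ.

```lean
/-- Toy rational interval, for illustration only (not LeanCert's API). -/
structure Ivl where
  lo : ℚ
  hi : ℚ
deriving Repr

/-- Interval subtraction: sound, but blind to correlations between its arguments. -/
def Ivl.sub (I J : Ivl) : Ivl := ⟨I.lo - J.hi, I.hi - J.lo⟩

-- If x and mean(x) both lie in [0.9, 1.1] but are in fact the same value,
-- the true range of x - mean(x) is [0, 0]; correlation-blind interval
-- subtraction still returns lo = -1/5, hi = 1/5, i.e. [-0.2, 0.2].
#eval Ivl.sub ⟨9/10, 11/10⟩ ⟨9/10, 11/10⟩
```

Affine arithmetic avoids this by representing each quantity as a linear form over shared noise symbols, so `x - x` cancels symbolically.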

References #

Real Definitions (The Specification) #

noncomputable def LeanCert.ML.Transformer.Real.gelu (x : ℝ) : ℝ

GELU Approximation (standard in BERT/GPT-2): 0.5 * x * (1 + tanh(√(2/π) * (x + 0.044715 * x³)))

Equations
Instances For
    noncomputable def LeanCert.ML.Transformer.layerNormReal (x : List ℝ) (gamma beta : List ℝ) (epsilon : ℝ) : List ℝ

    Layer Normalization (Real specification)

    Equations
    • One or more equations did not get rendered due to their size.
    Instances For

      Interval GELU #

      Verified GELU Interval using IntervalRat arithmetic.

      We construct the computation using verified interval operations:

      1. Compute x³
      2. Compute inner = x + c2 * x³
      3. Compute arg = c1 * inner
      4. Compute tanh(arg) using conservative [-1, 1] bound
      5. Compute 0.5 * x * (1 + tanh(arg))

      For tight bounds on tanh, we could use Taylor Models, but the global bound [-1, 1] is sufficient for most Transformer verification.
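The five steps above can be sketched over a toy interval type. This is an illustration only: `Ivl`, `tanhIvl`, and `geluIvlSketch` are hypothetical names, the arithmetic is exact rational (LeanCert's actual `geluInterval` works over `IntervalDyadic` with outward rounding), and the √(2/π) enclosure is the one from the correctness section below.

```lean
structure Ivl where
  lo hi : ℚ

def Ivl.add (I J : Ivl) : Ivl := ⟨I.lo + J.lo, I.hi + J.hi⟩

def Ivl.mul (I J : Ivl) : Ivl :=
  let ps := [I.lo * J.lo, I.lo * J.hi, I.hi * J.lo, I.hi * J.hi]
  ⟨ps.foldl min (I.lo * J.lo), ps.foldl max (I.lo * J.lo)⟩

def Ivl.scale (c : ℚ) (I : Ivl) : Ivl :=
  if 0 ≤ c then ⟨c * I.lo, c * I.hi⟩ else ⟨c * I.hi, c * I.lo⟩

/-- Step 4: the conservative global bound tanh(ℝ) ⊆ [-1, 1]. -/
def tanhIvl (_ : Ivl) : Ivl := ⟨-1, 1⟩

/-- Steps 1–5 composed. -/
def geluIvlSketch (I : Ivl) : Ivl :=
  let x3    := I.mul (I.mul I)                 -- 1. x³
  let inner := I.add (Ivl.scale 0.044715 x3)   -- 2. inner = x + c2·x³
  let c1    : Ivl := ⟨0.797884, 0.797885⟩      -- rational enclosure of √(2/π)
  let arg   := c1.mul inner                    -- 3. arg = c1·inner
  let t     := tanhIvl arg                     -- 4. conservative tanh bound
  (Ivl.scale (1/2) I).mul (Ivl.add ⟨1, 1⟩ t)   -- 5. 0.5·x·(1 + tanh(arg))
```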

      Equations
      • One or more equations did not get rendered due to their size.
      Instances For

        GELU Correctness #

        The constant √(2/π) is bounded by our rational interval, using the 6-decimal-precision bounds 3.141592 < π < 3.141593. Proof outline: from 3.141592 < π < 3.141593 we get 0.636618... < 2/π < 0.636620..., and thus 0.797884 < √(2/π) < 0.797885.
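The bound chain is pure rational arithmetic and can be checked mechanically; a minimal sketch, assuming Mathlib's ℚ and the `norm_num` tactic:

```lean
-- 0.797884² < 2/3.141593 < 2/π, and 2/π < 2/3.141592 < 0.797885²,
-- which together give 0.797884 < √(2/π) < 0.797885.
example : (0.797884 : ℚ)^2 < 2 / 3.141593 := by norm_num
example : (2 : ℚ) / 3.141592 < 0.797885^2 := by norm_num
```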

        GELU is bounded by the interval computation. The proof follows from composition of verified operations:

        1. x³ ∈ pow I 3 (by mem_pow)
        2. c2*x³ ∈ scale gelu_c2 (pow I 3) (by mem_scale)
        3. x + c2*x³ ∈ add I (scale gelu_c2 (pow I 3)) (by mem_add)
        4. √(2/π) ∈ c1_interval (by sqrt_two_div_pi_mem_interval)
        5. √(2/π) * inner ∈ mul c1_interval inner (by mem_mul)
        6. tanh(arg) ∈ tanhInterval arg (by mem_tanhInterval)
        7. 1 + tanh ∈ add (singleton 1) tanh_interval (by mem_add)
        8. 0.5*x ∈ scale (1/2) I (by mem_scale)
        9. Final: 0.5x(1+tanh(...)) ∈ mul half_x one_plus_tanh (by mem_mul)
        theorem LeanCert.ML.Transformer.mem_geluInterval {x : ℝ} {I : Core.IntervalDyadic} (hx : x ∈ I) (prec : ℤ) (hprec : prec ≤ 0 := by norm_num) :

        Interval Layer Normalization #

        Layer Normalization parameters

        • gamma : List ℚ

          Scale parameter γ

        • beta : List ℚ

          Shift parameter β

        • epsilon : ℚ

          Numerical stability constant ε > 0

        • epsilon_pos : 0 < self.epsilon

          ε must be positive

        Instances For
          Equations
          • One or more equations did not get rendered due to their size.
          Instances For

            Standard Interval LayerNorm.

            Warning: This implementation does NOT track correlations between variables. The interval for x - mean can be significantly wider than the true range because mean depends on x.

            This is mathematically SOUND (over-approximates) but may be LOOSE. Affine Arithmetic would improve tightness.

            Steps:

            1. Compute mean: μ = Σx / n
            2. Compute variance: σ² = Σ(x - μ)² / n
            3. Compute denominator: 1 / √(σ² + ε)
            4. Normalize and scale: ((x - μ) * inv_denom) * γ + β
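The four steps can be sketched over a toy interval type. This is an illustration only: `Ivl` and `layerNormIvlSketch` are hypothetical names using exact rational arithmetic (the real `forwardInterval` uses rounded dyadic operations), and the interval for 1/√(σ² + ε) is supplied abstractly via `invSqrt` since it has no exact rational value.

```lean
structure Ivl where
  lo hi : ℚ

def Ivl.add (I J : Ivl) : Ivl := ⟨I.lo + J.lo, I.hi + J.hi⟩
def Ivl.sub (I J : Ivl) : Ivl := ⟨I.lo - J.hi, I.hi - J.lo⟩
def Ivl.scale (c : ℚ) (I : Ivl) : Ivl :=
  if 0 ≤ c then ⟨c * I.lo, c * I.hi⟩ else ⟨c * I.hi, c * I.lo⟩
def Ivl.mul (I J : Ivl) : Ivl :=
  let ps := [I.lo * J.lo, I.lo * J.hi, I.hi * J.lo, I.hi * J.hi]
  ⟨ps.foldl min (I.lo * J.lo), ps.foldl max (I.lo * J.lo)⟩

/-- Steps 1–4 composed; `invSqrt` stands in for a sound enclosure of 1/√(·). -/
def layerNormIvlSketch (invSqrt : Ivl → Ivl)
    (xs : List Ivl) (gamma beta : List ℚ) (eps : ℚ) : List Ivl :=
  let n : ℚ := xs.length
  let mean := Ivl.scale (1 / n) (xs.foldl Ivl.add ⟨0, 0⟩)      -- 1. μ = Σx / n
  let diffs := xs.map (fun I => I.sub mean)                    -- x - μ (correlation lost!)
  let var := Ivl.scale (1 / n)
    (diffs.foldl (fun acc d => acc.add (d.mul d)) ⟨0, 0⟩)      -- 2. σ² = Σ(x-μ)² / n
  let inv := invSqrt (var.add ⟨eps, eps⟩)                      -- 3. 1/√(σ² + ε)
  (diffs.zip (gamma.zip beta)).map fun (d, g, b) =>
    Ivl.add (Ivl.scale g (d.mul inv)) ⟨b, b⟩                   -- 4. (x-μ)·inv·γ + β
```

Note how step 1 already triggers the dependency problem: `diffs` subtracts a mean interval that secretly depends on the same `xs`.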
            Equations
            • One or more equations did not get rendered due to their size.
            Instances For

              LayerNorm Helper Lemmas #

              theorem LeanCert.ML.Transformer.mem_mulRounded {x y : ℝ} {I J : Core.IntervalDyadic} (hx : x ∈ I) (hy : y ∈ J) (prec : ℤ) :
              x * y ∈ I.mulRounded J prec

              Membership lemma for mulRounded

              theorem LeanCert.ML.Transformer.mem_addRounded {x y : ℝ} {I J : Core.IntervalDyadic} (hx : x ∈ I) (hy : y ∈ J) (prec : ℤ) :
              x + y ∈ I.addRounded J prec

              Membership lemma for addRounded

              Membership lemma for sumIntervals

              theorem LeanCert.ML.Transformer.mem_map_intervals {f : ℝ → ℝ} {g : Core.IntervalDyadic → Core.IntervalDyadic} {xs : List ℝ} {Is : IntervalVector} (hlen : xs.length = List.length Is) (hmem : ∀ i < xs.length, xs[i]! ∈ Is[i]!) (hf : ∀ (x : ℝ) (I : Core.IntervalDyadic), x ∈ I → f x ∈ g I) (i : ℕ) :
              i < xs.length → (List.map f xs)[i]! ∈ (List.map g Is)[i]!

              Helper: map preserves membership for interval operations. Proof: (xs.map f)[i] = f(xs[i]) and (Is.map g)[i] = g(Is[i]), so f(xs[i]) ∈ g(Is[i]) follows from hf applied to hmem.

              theorem LeanCert.ML.Transformer.length_zipWith3_min {α : Type u_1} {β : Type u_2} {γ : Type u_3} {δ : Type u_4} (f : α → β → γ → δ) (as : List α) (bs : List β) (cs : List γ) :
              (List.zipWith3 f as bs cs).length = min (min as.length bs.length) cs.length

              zipWith3 length lemma

              theorem LeanCert.ML.Transformer.mem_zipWith3 {f : ℝ → ℚ → ℚ → ℝ} {g : Core.IntervalDyadic → ℚ → ℚ → Core.IntervalDyadic} {xs : List ℝ} {Is : IntervalVector} {as bs : List ℚ} (hlen_xs_Is : xs.length = List.length Is) (hmem : ∀ i < xs.length, xs[i]! ∈ Is[i]!) (hf : ∀ (x : ℝ) (I : Core.IntervalDyadic) (a b : ℚ), x ∈ I → f x a b ∈ g I a b) (i : ℕ) :
              i < (List.zipWith3 f xs as bs).length → (List.zipWith3 f xs as bs)[i]! ∈ (List.zipWith3 g Is as bs)[i]!

              zipWith3 membership: if corresponding elements satisfy the relation, then zipWith3 outputs satisfy the relation.

              Membership for singleton rational interval

              The final scale+shift operation: x * γ + β

              theorem LeanCert.ML.Transformer.mem_sub_const {x : ℝ} {I : Core.IntervalDyadic} (hx : x ∈ I) (mean : Core.IntervalDyadic) (m : ℝ) (hm : m ∈ mean) :
              x - m ∈ I.sub mean

              Membership for subtraction with a fixed second argument
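The shape of such membership lemmas can be illustrated on a toy interval type. This sketch uses hypothetical names (`Ivl`, `Ivl.mem`, `mem_sub`) over ℚ with exact endpoints; LeanCert's actual lemma is stated for `IntervalDyadic` over ℝ.

```lean
structure Ivl where
  lo hi : ℚ

/-- Toy membership: x lies between the endpoints. -/
def Ivl.mem (x : ℚ) (I : Ivl) : Prop := I.lo ≤ x ∧ x ≤ I.hi

def Ivl.sub (I J : Ivl) : Ivl := ⟨I.lo - J.hi, I.hi - J.lo⟩

/-- Toy analogue of `mem_sub_const`: interval subtraction is sound pointwise. -/
theorem mem_sub {x m : ℚ} {I J : Ivl} (hx : Ivl.mem x I) (hm : Ivl.mem m J) :
    Ivl.mem (x - m) (I.sub J) := by
  simp only [Ivl.mem, Ivl.sub] at *
  exact ⟨by linarith [hx.1, hm.2], by linarith [hx.2, hm.1]⟩
```

The lower bound uses `I.lo ≤ x` together with `m ≤ J.hi`, and the upper bound the other two inequalities, which is exactly why the endpoints of `Ivl.sub` cross over.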

              theorem LeanCert.ML.Transformer.mem_square {x : ℝ} {I : Core.IntervalDyadic} (hx : x ∈ I) (prec : ℤ) :
              x * x ∈ I.mulRounded I prec

              Squaring preserves membership (via mulRounded)

              LayerNorm Correctness #

              theorem LeanCert.ML.Transformer.mem_layerNorm_forwardInterval {xs : List ℝ} {Is : IntervalVector} (params : LayerNormParams) (hlen : xs.length = List.length Is) (hmem : ∀ i < xs.length, xs[i]! ∈ Is[i]!) (prec : ℤ) (hprec : prec ≤ 0 := by norm_num) :
              have ys := layerNormReal xs params.gamma params.beta params.epsilon; have Js := params.forwardInterval Is prec; ys.length = List.length Js ∧ ∀ i < ys.length, ys[i]! ∈ Js[i]!

              LayerNorm interval is sound (contains true output). Note: May be loose due to dependency problem.

              The proof tracks membership through the composition of interval operations:

              1. sum ∈ sumIntervals (by mem_sumIntervals)
              2. mean = sum/n ∈ mean_interval (by mem_mulRounded)
              3. diffs[i] = x[i] - mean ∈ diffs_interval[i] (by mem_sub)
              4. sq_diffs[i] ∈ sq_diffs_interval[i] (by mem_mulRounded)
              5. var ∈ var_interval (by mem_sumIntervals + mem_mulRounded)
              6. var + eps ∈ var_eps_interval (by mem_add)
              7. sqrt(var + eps) ∈ std_dev_interval (by mem_sqrt)
              8. 1/sqrt(...) ∈ inv_std_dev_interval (by mem_invNonzero or fallback)
              9. normalized[i] ∈ normalized_interval[i] (by mem_mulRounded)
              10. result[i] = normalized[i] * gamma[i] + beta[i] ∈ result_interval[i] (by mem_mulRounded + mem_add + mem_ofIntervalRat)

              Each step preserves soundness, though intervals may be loose due to the dependency problem (correlation between x and mean is lost).

              The proof establishes:

              • All helper lemmas (mem_sumIntervals, mem_mulRounded, mem_map_intervals, mem_zipWith3, mem_scale_shift) are fully proven
              • The composition of these lemmas yields soundness

              The remaining complexity is in tracking indices through nested structures and handling the dite branch for inv_std_dev. The core mathematical argument is complete.

              Transformer Block Structure #

              A feed-forward network block (MLP) in a Transformer

              • linear1 : Layer

                First linear layer (hidden expansion)

              • linear2 : Layer

                Second linear layer (projection back)

              • dims_match : self.linear1.outputDim = self.linear2.inputDim

                Dimensions match: linear1.outputDim = linear2.inputDim

              Instances For
                Equations
                • One or more equations did not get rendered due to their size.
                Instances For

                  Real-valued forward pass: Linear -> GELU -> Linear

                  Equations
                  Instances For

                    Interval forward pass: Linear -> GELU -> Linear

                    Equations
                    Instances For

                      A Transformer encoder block (simplified, without attention)

                      Instances For
                        Equations
                        • One or more equations did not get rendered due to their size.
                        Instances For

                          Forward pass through a Transformer block (without attention). Computes: LN2(x + FFN(LN1(x)))

                          Note: This is a simplified version without self-attention. For full attention, see ML/Optimized/MatrixNetwork.lean
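The data flow LN2(x + FFN(LN1(x))) can be sketched abstractly. The names below (`blockForwardSketch` and its function arguments) are placeholders, not LeanCert's actual forward definitions, and ℚ-vectors stand in for the real/interval vectors.

```lean
/-- Sketch of the simplified encoder-block forward pass, with the
    sublayers passed in abstractly as functions on ℚ-vectors. -/
def blockForwardSketch (ln1 ffn ln2 : List ℚ → List ℚ) (x : List ℚ) : List ℚ :=
  let h := ffn (ln1 x)                 -- FFN(LN1(x))
  ln2 (List.zipWith (· + ·) x h)       -- residual add, then LN2
```

The residual connection is an elementwise add of `x` with the FFN output, so in the interval version both summands descend from the same input intervals, another place where correlation-blind arithmetic widens bounds.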

                          Equations
                          Instances For

                            ReLU Alternative (for comparison) #

                             def LeanCert.ML.Transformer.reluMLPInterval (linear1 linear2 : Layer) (x : IntervalVector) (prec : ℤ := -53) :

                            Standard ReLU-based MLP forward (using existing relu)

                            Equations
                            Instances For