Transformer Components #
This module provides verified interval arithmetic implementations of key Transformer components:
- GELU - Gaussian Error Linear Unit activation
- LayerNorm - Layer Normalization
Design Notes #
GELU #
The GELU activation function is approximated as: GELU(x) ≈ 0.5 * x * (1 + tanh(√(2/π) * (x + 0.044715 * x³)))
We implement this by composing verified interval operations.
LayerNorm #
LayerNorm computes: (x - μ) / √(σ² + ε) * γ + β
Warning: Standard interval arithmetic loses correlation information, causing significant overestimation in LayerNorm. For example, if every entry of x is the same unknown value in [0.9, 1.1], then mean(x) ∈ [0.9, 1.1] as well, and interval subtraction gives x - mean ∈ [-0.2, 0.2] even though the true value is exactly 0.
Affine Arithmetic (tracking symbolic dependencies) would resolve this. The current implementation is sound but may produce loose bounds.
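The dependency problem above can be made concrete with a toy model. The following is an illustrative sketch, not part of the library (which uses IntervalRat/IntervalDyadic); Int endpoints are scaled by 10 so that [0.9, 1.1] becomes [9, 11] and the arithmetic stays exact:

```lean
-- Toy interval type (illustrative only)
structure Itv where
  lo : Int
  hi : Int
deriving DecidableEq

-- Standard interval subtraction: [a, b] - [c, d] = [a - d, b - c]
def Itv.sub (a b : Itv) : Itv := ⟨a.lo - b.hi, a.hi - b.lo⟩

-- x ∈ [0.9, 1.1], scaled by 10
def x : Itv := ⟨9, 11⟩

-- If every entry of the vector is this same x, mean(x) has the same interval
def mean : Itv := x

-- Interval subtraction forgets that x and mean are the same quantity:
-- the result is [-2, 2] (i.e. [-0.2, 0.2]) although the true value is 0.
example : x.sub mean = ⟨-2, 2⟩ := by decide
```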
References #
- Hendrycks & Gimpel, "Gaussian Error Linear Units (GELUs)", 2016
- Ba et al., "Layer Normalization", 2016
Real Definitions (The Specification) #
Interval GELU #
Verified GELU Interval using IntervalRat arithmetic.
We construct the computation using verified interval operations:
- Compute x³
- Compute inner = x + c2 * x³
- Compute arg = c1 * inner
- Compute tanh(arg) using the conservative [-1, 1] bound
- Compute 0.5 * x * (1 + tanh(arg))
For tighter bounds on tanh we could use Taylor models, but the global bound [-1, 1] is sufficient for most Transformer verification.
Equations
Instances For
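The composition above can be sketched outside the library as follows. This is a schematic model with Float endpoints and no directed rounding, so unlike the verified IntervalRat version it is not sound against floating-point error; the constant 0.7978845608 ≈ √(2/π), and tanh uses the global [-1, 1] bound described above:

```lean
structure Itv where
  lo : Float
  hi : Float

def Itv.add (a b : Itv) : Itv := ⟨a.lo + b.lo, a.hi + b.hi⟩

-- General interval product: min/max over the four endpoint products
def Itv.mul (a b : Itv) : Itv :=
  let p1 := a.lo * b.lo
  let p2 := a.lo * b.hi
  let p3 := a.hi * b.lo
  let p4 := a.hi * b.hi
  ⟨min (min p1 p2) (min p3 p4), max (max p1 p2) (max p3 p4)⟩

-- Scaling by a constant: a nonnegative factor preserves endpoint order
def Itv.scale (c : Float) (a : Itv) : Itv :=
  if c ≥ 0.0 then ⟨c * a.lo, c * a.hi⟩ else ⟨c * a.hi, c * a.lo⟩

-- Conservative global bound for tanh, as in the text
def tanhItv (_ : Itv) : Itv := ⟨-1.0, 1.0⟩

def geluItv (x : Itv) : Itv :=
  let x3 := (x.mul x).mul x                     -- x³
  let inner := x.add (Itv.scale 0.044715 x3)    -- x + c2·x³
  let arg := Itv.scale 0.7978845608 inner       -- ≈ √(2/π)·inner
  let t := tanhItv arg                          -- tanh(arg) ∈ [-1, 1]
  Itv.scale 0.5 (x.mul (Itv.add ⟨1.0, 1.0⟩ t))  -- 0.5·x·(1 + tanh(arg))
```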
GELU on IntervalDyadic with precision control
Equations
Instances For
Vector version of GELU
Equations
- LeanCert.ML.Transformer.geluVector v prec = List.map (fun (I : LeanCert.Core.IntervalDyadic) => LeanCert.ML.Transformer.geluInterval I prec) v
Instances For
GELU Correctness #
The constant √(2/π) is bounded by our rational interval, using the 6-decimal precision bounds 3.141592 < π < 3.141593. Proof outline: from these bounds we get 0.636618... < 2/π < 0.636620..., and thus 0.797884 < √(2/π) < 0.797885.
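Clearing denominators turns the outline into two integer inequalities that Lean can check by computation (here 636618877456 = 797884² and 636620473225 = 797885², with everything scaled by 10⁻¹⁸); this standalone sketch is not the library's actual proof:

```lean
-- 0.797884² · 3.141593 < 2, so 0.797884 < √(2/π) given π < 3.141593
example : 797884 ^ 2 * 3141593 < 2 * 10 ^ 18 := by decide

-- 2 < 0.797885² · 3.141592, so √(2/π) < 0.797885 given 3.141592 < π
example : 2 * 10 ^ 18 < 797885 ^ 2 * 3141592 := by decide
```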
GELU is bounded by the interval computation. The proof follows from composition of verified operations:
- x³ ∈ pow I 3 (by mem_pow)
- c2*x³ ∈ scale gelu_c2 (pow I 3) (by mem_scale)
- x + c2*x³ ∈ add I (scale gelu_c2 (pow I 3)) (by mem_add)
- √(2/π) ∈ c1_interval (by sqrt_two_div_pi_mem_interval)
- √(2/π) * inner ∈ mul c1_interval inner (by mem_mul)
- tanh(arg) ∈ tanhInterval arg (by mem_tanhInterval)
- 1 + tanh ∈ add (singleton 1) tanh_interval (by mem_add)
- 0.5*x ∈ scale (1/2) I (by mem_scale)
- Final: 0.5x(1+tanh(...)) ∈ mul half_x one_plus_tanh (by mem_mul)
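Each bullet above is an application of the same pattern: membership of the real inputs in their intervals yields membership of the real output in the output interval. A self-contained toy version of that pattern (illustrative names and Int endpoints; the library's lemmas are stated over IntervalRat):

```lean
structure Itv where
  lo : Int
  hi : Int

def Itv.memb (x : Int) (a : Itv) : Prop := a.lo ≤ x ∧ x ≤ a.hi

def Itv.add (a b : Itv) : Itv := ⟨a.lo + b.lo, a.hi + b.hi⟩

-- The mem_add pattern: x ∈ a and y ∈ b imply x + y ∈ a + b
theorem memb_add {x y : Int} {a b : Itv}
    (hx : Itv.memb x a) (hy : Itv.memb y b) :
    Itv.memb (x + y) (Itv.add a b) := by
  simp only [Itv.memb, Itv.add] at *
  omega
```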
Interval Layer Normalization #
Equations
Instances For
Sum of interval vector elements
Equations
Instances For
Standard Interval LayerNorm.
Warning: This implementation does NOT track correlations between
variables. The interval for x - mean can be significantly wider than
the true range because mean depends on x.
This is mathematically SOUND (over-approximates) but may be LOOSE. Affine Arithmetic would improve tightness.
Steps:
- Compute mean: μ = Σx / n
- Compute variance: σ² = Σ(x - μ)² / n
- Compute denominator: 1 / √(σ² + ε)
- Normalize and scale: ((x - μ) * inv_denom) * γ + β
Equations
Instances For
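The four steps above can be sketched schematically as follows. Float endpoints with no directed rounding, so this is illustrative only, not sound like the library's dyadic version; the clamp of the variance lower bound at 0 stands in for a dedicated interval-squaring operation:

```lean
structure Itv where
  lo : Float
  hi : Float

def Itv.add (a b : Itv) : Itv := ⟨a.lo + b.lo, a.hi + b.hi⟩
def Itv.sub (a b : Itv) : Itv := ⟨a.lo - b.hi, a.hi - b.lo⟩

def Itv.mul (a b : Itv) : Itv :=
  let p1 := a.lo * b.lo
  let p2 := a.lo * b.hi
  let p3 := a.hi * b.lo
  let p4 := a.hi * b.hi
  ⟨min (min p1 p2) (min p3 p4), max (max p1 p2) (max p3 p4)⟩

def Itv.scale (c : Float) (a : Itv) : Itv :=
  if c ≥ 0.0 then ⟨c * a.lo, c * a.hi⟩ else ⟨c * a.hi, c * a.lo⟩

def sumItv (xs : List Itv) : Itv := xs.foldl Itv.add ⟨0.0, 0.0⟩

def layerNormItv (x γ β : List Itv) (ε : Float) : List Itv :=
  let n := Float.ofNat x.length
  -- Step 1: μ = Σx / n
  let mean := Itv.scale (1.0 / n) (sumItv x)
  -- Step 2: σ² = Σ(x - μ)² / n  (x - μ is naively wide: dependency problem)
  let diffs := x.map (fun xi => xi.sub mean)
  let var := Itv.scale (1.0 / n) (sumItv (diffs.map (fun d => d.mul d)))
  -- Step 3: 1 / √(σ² + ε); naive squaring can give a negative lower bound,
  -- so clamp it at 0 before taking the square root
  let sLo := Float.sqrt (max var.lo 0.0 + ε)
  let sHi := Float.sqrt (var.hi + ε)
  let invStd : Itv := ⟨1.0 / sHi, 1.0 / sLo⟩
  -- Step 4: ((x - μ) * invStd) * γ + β
  let normalized := diffs.map (fun d => d.mul invStd)
  List.zipWith Itv.add (List.zipWith Itv.mul normalized γ) β
```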
LayerNorm Helper Lemmas #
Membership lemma for mulRounded
Membership lemma for addRounded
Membership lemma for sumIntervals
Helper: map preserves membership for interval operations. Proof: (xs.map f)[i] = f(xs[i]) and (Is.map g)[i] = g(Is[i]), so f(xs[i]) ∈ g(Is[i]) follows from hf applied to hmem.
zipWith3 membership: if corresponding elements satisfy the relation, then zipWith3 outputs satisfy the relation.
Membership for singleton rational interval
The final scale+shift operation: x * γ + β
Membership for subtraction with a fixed second argument
Squaring preserves membership (via mulRounded)
LayerNorm Correctness #
LayerNorm interval is sound (contains true output). Note: May be loose due to dependency problem.
The proof tracks membership through the composition of interval operations:
- sum ∈ sumIntervals (by mem_sumIntervals)
- mean = sum/n ∈ mean_interval (by mem_mulRounded)
- diffs[i] = x[i] - mean ∈ diffs_interval[i] (by mem_sub)
- sq_diffs[i] ∈ sq_diffs_interval[i] (by mem_mulRounded)
- var ∈ var_interval (by mem_sumIntervals + mem_mulRounded)
- var + eps ∈ var_eps_interval (by mem_add)
- sqrt(var + eps) ∈ std_dev_interval (by mem_sqrt)
- 1/sqrt(...) ∈ inv_std_dev_interval (by mem_invNonzero or fallback)
- normalized[i] ∈ normalized_interval[i] (by mem_mulRounded)
- result[i] = normalized[i] * gamma[i] + beta[i] ∈ result_interval[i] (by mem_mulRounded + mem_add + mem_ofIntervalRat)
Each step preserves soundness, though intervals may be loose due to the dependency problem (correlation between x and mean is lost).
The proof establishes:
- All helper lemmas (mem_sumIntervals, mem_mulRounded, mem_map_intervals, mem_zipWith3, mem_scale_shift) are fully proven
- The composition of these lemmas yields soundness
The remaining complexity is in tracking indices through nested structures and handling the dite branch for inv_std_dev. The core mathematical argument is complete.
Transformer Block Structure #
Equations
Instances For
Real-valued forward pass: Linear -> GELU -> Linear
Equations
- blk.forwardReal x = blk.linear2.forwardReal (List.map LeanCert.ML.Transformer.Real.gelu (blk.linear1.forwardReal x))
Instances For
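Concretely, the pipeline corresponds to this schematic Float version; linear here is a hypothetical dense layer (not the library's LinearLayer), and gelu uses the tanh approximation from the top of this file:

```lean
-- Hypothetical dense layer: y = W·x + b, with W given row by row
def linear (W : List (List Float)) (b : List Float) (x : List Float) : List Float :=
  List.zipWith (fun row bi => (List.zipWith (· * ·) row x).foldl (· + ·) 0.0 + bi) W b

-- The tanh approximation of GELU; 0.7978845608 ≈ √(2/π)
def gelu (x : Float) : Float :=
  0.5 * x * (1.0 + Float.tanh (0.7978845608 * (x + 0.044715 * x * x * x)))

-- Linear -> GELU -> Linear
def ffnForwardReal (W1 : List (List Float)) (b1 : List Float)
    (W2 : List (List Float)) (b2 : List Float) (x : List Float) : List Float :=
  linear W2 b2 ((linear W1 b1 x).map gelu)
```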
Interval forward pass: Linear -> GELU -> Linear
Equations
- blk.forwardInterval x prec = blk.linear2.forwardInterval (LeanCert.ML.Transformer.geluVector (blk.linear1.forwardInterval x prec) prec) prec
Instances For
A Transformer encoder block (simplified, without attention)
- ln1 : LayerNormParams
Pre-FFN layer normalization
- ffn : FFNBlock
Feed-forward network
- ln2 : LayerNormParams
Post-FFN layer normalization
Instances For
Equations
Instances For
Forward pass through a Transformer block (without attention). Computes: LN2(x + FFN(LN1(x)))
Note: This is a simplified version without self-attention. For full attention, see ML/Optimized/MatrixNetwork.lean
Equations
- blk.forwardInterval x prec = blk.ln2.forwardInterval (x.add (blk.ffn.forwardInterval (blk.ln1.forwardInterval x prec) prec)) prec
Instances For
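Abstracting the sub-layers as functions, the composition LN2(x + FFN(LN1(x))) reduces to a one-line residual sketch (illustrative, over plain Float vectors):

```lean
-- Residual composition with pre-FFN and post-FFN LayerNorm
def blockForward (ln1 ffn ln2 : List Float → List Float)
    (x : List Float) : List Float :=
  ln2 (List.zipWith (· + ·) x (ffn (ln1 x)))
```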
ReLU Alternative (for comparison) #
Standard ReLU-based MLP forward (using existing relu)
Equations
- LeanCert.ML.Transformer.reluMLPInterval linear1 linear2 x prec = linear2.forwardInterval (linear1.forwardInterval x prec) prec
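For comparison with the GELU path: interval ReLU needs no conservative global bound, because ReLU is monotone and maps endpoints to endpoints (Float sketch, illustrative only):

```lean
structure Itv where
  lo : Float
  hi : Float

-- ReLU is monotone, so it acts directly on the interval endpoints
def reluItv (a : Itv) : Itv := ⟨max 0.0 a.lo, max 0.0 a.hi⟩
```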