Affine Arithmetic LayerNorm for ML
This file bridges the ML Transformer module with the Affine Arithmetic module, providing tight bounds for LayerNorm by preserving correlations.
The Problem with Standard Interval Arithmetic
Standard interval arithmetic loses correlations between variables. In LayerNorm:
- μ = mean(x) depends on all x_i
- x_i - μ should be small when all inputs are similar
- But interval arithmetic treats x_i and μ as independent!
Example: suppose every x_i is the same uncertain value t ∈ [0.9, 1.1]:
- Standard: μ ∈ [0.9, 1.1], so x_i - μ ∈ [-0.2, 0.2], even though the true value is exactly 0
- Affine: tracks that μ comes from the same x_i, giving x_i - μ = 0
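The standard-interval side of this example can be reproduced with a minimal sketch. It assumes plain `Float` endpoints with no outward rounding. `Ival`, `Ival.sub`, `Ival.mean`, `x0`, `inputs`, and `centered0` are hypothetical names for illustration, not the library's `IntervalDyadic` API. Because an interval records only endpoints, the computation is the same whether the four inputs are one shared value or four independent ones.

```lean
/-- Hypothetical closed interval `[lo, hi]` over `Float`, with no outward rounding. -/
structure Ival where
  lo : Float
  hi : Float
  deriving Repr

namespace Ival

def add (a b : Ival) : Ival := ⟨a.lo + b.lo, a.hi + b.hi⟩

/-- Interval subtraction `[a.lo - b.hi, a.hi - b.lo]`. It assumes `a` and `b`
    vary independently, which is exactly the dependency problem. -/
def sub (a b : Ival) : Ival := ⟨a.lo - b.hi, a.hi - b.lo⟩

/-- Scale by a constant `k ≥ 0`, so the endpoints keep their order. -/
def scaleNonneg (k : Float) (a : Ival) : Ival := ⟨k * a.lo, k * a.hi⟩

/-- Interval enclosure of the mean of a nonempty list. -/
def mean (xs : List Ival) : Ival :=
  scaleNonneg (1 / xs.length.toFloat) (xs.foldl add ⟨0, 0⟩)

end Ival

/-- One input known only to lie in `[0.9, 1.1]`. -/
def x0 : Ival := ⟨0.9, 1.1⟩

/-- Four inputs that are all the same uncertain value `x0`;
    interval arithmetic cannot record that they are equal. -/
def inputs : List Ival := List.replicate 4 x0

/-- Interval enclosure of `x₀ - μ`, although the true value is 0. -/
def centered0 : Ival := x0.sub (Ival.mean inputs)

#eval centered0
```

`centered0` comes out as roughly [-0.2, 0.2]. Interval subtraction pairs the lowest x₀ with the highest μ, a combination that cannot actually occur, since μ = 1.1 requires x₀ = 1.1.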
Solution: Affine Arithmetic
Affine forms track symbolic dependencies via noise symbols:
- x_i = c₀ + c₁·ε₁ + c₂·ε₂ + ... + [-r, r], where each noise symbol ε_j ranges over [-1, 1]
- Noise symbols ε_j shared across variables preserve their correlations
- When we compute x - mean(x), the common terms cancel!
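A minimal sketch of the cancellation follows. It assumes a `Float` representation with no rounding control. The structure `Aff` and its operations are hypothetical illustrations of the form above, not the library's `AffineForm`. Noise symbols are identified by list index, so two forms share ε_j exactly when both have a coefficient at index j. The example compares inputs that share one noise symbol with inputs that each carry their own.

```lean
/-- Hypothetical affine form `c0 + Σ coeffs[j]·εⱼ + [-r, r]`, each `εⱼ ∈ [-1, 1]`.
    Float-based and without rounding control; not the library's `AffineForm`. -/
structure Aff where
  c0 : Float
  coeffs : List Float  -- index j holds the coefficient of the shared symbol εⱼ
  r : Float            -- radius of the extra error term [-r, r]
  deriving Repr

namespace Aff

/-- Pad with zeros so that index j denotes the same εⱼ in both forms. -/
def pad (n : Nat) (cs : List Float) : List Float :=
  cs ++ List.replicate (n - cs.length) 0

/-- Addition: coefficients of the same noise symbol are summed. -/
def add (a b : Aff) : Aff :=
  let n := max a.coeffs.length b.coeffs.length
  { c0 := a.c0 + b.c0
    coeffs := List.zipWith (fun (x y : Float) => x + y) (pad n a.coeffs) (pad n b.coeffs)
    r := a.r + b.r }

/-- Scaling multiplies every coefficient by `k`; the error radius scales by `|k|`. -/
def scale (k : Float) (a : Aff) : Aff :=
  { c0 := k * a.c0, coeffs := a.coeffs.map (k * ·), r := Float.abs k * a.r }

/-- Subtraction `a + (-1)·b`: coefficients of shared symbols cancel term by term. -/
def sub (a b : Aff) : Aff := add a (scale (-1) b)

/-- Mean of a nonempty list of affine forms. -/
def mean (xs : List Aff) : Aff :=
  scale (1 / xs.length.toFloat) (xs.foldl add (Aff.mk 0 [] 0))

/-- Radius of the enclosing interval: Σ |cⱼ| + r. -/
def radius (a : Aff) : Float :=
  a.coeffs.foldl (fun (acc c : Float) => acc + Float.abs c) a.r

end Aff

/-- Correlated inputs: every xᵢ = 1 + 0.1·ε₀ shares one noise symbol. -/
def shared : List Aff := List.replicate 4 (Aff.mk 1.0 [0.1] 0)

/-- Independent inputs: xᵢ = 1 + 0.1·εᵢ, one noise symbol per input. -/
def indep : List Aff :=
  (List.range 4).map fun i =>
    Aff.mk 1.0 ((List.range 4).map fun j => if j == i then 0.1 else 0) 0

/-- Radius of the affine enclosure of x₀ - μ. -/
def centeredRadius (xs : List Aff) : Float :=
  ((xs.headD (Aff.mk 0 [] 0)).sub (Aff.mean xs)).radius

#eval centeredRadius shared  -- the shared ε₀ terms cancel: radius 0 up to rounding
#eval centeredRadius indep   -- 0.075 + 3 · 0.025, about 0.15
```

With a shared noise symbol, the radius of x₀ - μ collapses to 0, where interval arithmetic gives [-0.2, 0.2]. With one symbol per input, the radius is 0.15 = 0.2·(n - 1)/n for n = 4. This is the exact range of x₀ - μ over independent inputs, because centering is linear and affine arithmetic represents it exactly. The gain over the interval radius of 0.2 is real but smaller in this independent case.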
Main Definitions
- LayerNormParams.forwardAffine: LayerNorm using affine arithmetic
- LayerNormParams.forwardIntervalTight: convert the affine result back to intervals
References
- de Figueiredo & Stolfi, "Affine Arithmetic: Concepts and Applications", 2004
Affine LayerNorm Forward Pass
Forward pass using affine arithmetic for tight bounds.
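This converts the input intervals to affine forms (via IntervalDyadic.toIntervalRat and AffineVector.ofIntervals) and computes LayerNorm on them while preserving correlations. The result is a list of affine forms; forwardIntervalTight converts them back to intervals.
Key advantage: the centering step x - μ preserves the correlation between each x_i and μ. This gives tighter bounds than standard interval arithmetic, which treats x_i and μ as independent.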
Equations
- params.forwardAffine Is = (LeanCert.Engine.Affine.AffineVector.ofIntervals (List.map LeanCert.Core.IntervalDyadic.toIntervalRat Is)).layerNorm params.gamma params.beta params.epsilon
Convert affine output back to intervals.
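This extracts conservative interval bounds from the affine forms and converts them to dyadic intervals at precision prec. The bounds are typically tighter than those from direct interval propagation, because correlations were preserved during the computation.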
Equations
- params.forwardIntervalTight Is prec = List.map (fun (af : LeanCert.Engine.Affine.AffineForm) => LeanCert.Core.IntervalDyadic.ofIntervalRat af.toInterval prec) (params.forwardAffine Is)
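The conversion can be sketched with the standard affine-arithmetic rule: a form c₀ + Σ cⱼ·εⱼ + [-r, r] is enclosed by the interval c₀ ± (Σ|cⱼ| + r). This sketch assumes that `af.toInterval` follows this rule. It skips the dyadic conversion `LeanCert.Core.IntervalDyadic.ofIntervalRat` at precision `prec` that appears in the equation. `Aff`, `Ival`, `Aff.toIval`, and `sampleBounds` are hypothetical names.

```lean
/-- Hypothetical affine form `c0 + Σ coeffs[j]·εⱼ + [-r, r]` with each `εⱼ ∈ [-1, 1]`. -/
structure Aff where
  c0 : Float
  coeffs : List Float
  r : Float
  deriving Repr

/-- Hypothetical closed interval `[lo, hi]`. -/
structure Ival where
  lo : Float
  hi : Float
  deriving Repr

/-- Enclosing interval `c0 ± (Σ |cⱼ| + r)`. The extremes occur when every εⱼ is ±1
    with the sign of its coefficient and the error term sits at ±r.
    No dyadic rounding is performed here. -/
def Aff.toIval (a : Aff) : Ival :=
  let rad := a.coeffs.foldl (fun (acc c : Float) => acc + Float.abs c) a.r
  ⟨a.c0 - rad, a.c0 + rad⟩

/-- x = 1 + 0.1·ε₀ - 0.05·ε₁ + [-0.01, 0.01], enclosed by about [0.84, 1.16]. -/
def sampleBounds : Ival := (Aff.mk 1.0 [0.1, -0.05] 0.01).toIval

#eval sampleBounds
```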
Comparison: Interval vs Affine Bounds
Compute both interval and affine bounds for comparison.
Returns (interval_bounds, affine_bounds) for the same input. The affine bounds should be tighter, mainly because of the correlation-preserving centering step.
Equations
- params.compareBounds Is prec = (params.forwardInterval Is prec, params.forwardIntervalTight Is prec)
Measure the tightness improvement from affine arithmetic.
Returns the ratio of interval width to affine width for each output dimension. Values greater than 1 indicate that the affine bounds are tighter.
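A per-dimension width ratio of this kind can be sketched as follows. The library's equation is not rendered here, so this is not its definition. `Ival`, `widthRatios`, and `ratios` are hypothetical, and the bounds are plain `Float` intervals rather than `IntervalDyadic`. Returning `none` for a zero-width affine bound is an assumption of the sketch, and the sample bounds are made up for illustration.

```lean
/-- Hypothetical closed interval `[lo, hi]` over `Float`. -/
structure Ival where
  lo : Float
  hi : Float
  deriving Repr

def Ival.width (a : Ival) : Float := a.hi - a.lo

/-- Per output dimension, `interval width / affine width`. Values greater than 1
    mean the affine bound is tighter; `none` marks a zero-width affine bound,
    where the ratio is undefined. Extra elements of the longer list are dropped. -/
def widthRatios (intervalBounds affineBounds : List Ival) : List (Option Float) :=
  List.zipWith
    (fun (i a : Ival) => if a.width > 0 then some (i.width / a.width) else none)
    intervalBounds affineBounds

/-- Two output dimensions: a 4/3 improvement, then a degenerate affine bound. -/
def ratios : List (Option Float) :=
  widthRatios [⟨-0.2, 0.2⟩, ⟨-0.2, 0.2⟩] [⟨-0.15, 0.15⟩, ⟨1.0, 1.0⟩]

#eval ratios
```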
Soundness Theorem
The affine LayerNorm is sound: if the inputs lie in the input affine forms, then the outputs lie in the resulting affine forms.
This follows from composition of:
- AffineVector.mem_centered: centering preserves membership
- AffineVector.mem_variance: variance is sound
- AffineForm.mem_add: addition is sound
- AffineForm.mem_sqrt: square root is sound
- AffineForm.mem_inv: inversion is sound
- AffineForm.mem_mul, AffineForm.mem_scale, AffineForm.mem_add: final combination
The proof also needs hypotheses that variance + epsilon is positive and that the vector lengths are compatible; these are handled in the implementation.
FFNBlock and TransformerBlock with Affine Arithmetic
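FFNBlock forward pass (linear1, then GELU, then linear2).
The FFN block contains no LayerNorm, so, as its equation shows, it uses standard interval arithmetic throughout. Affine arithmetic is reserved for LayerNorm, where correlation matters most. For a single linear layer, this costs little: each output component is a sum in which every input appears once, so interval evaluation gives its exact range up to rounding.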
Equations
- blk.forwardIntervalTight x prec = blk.linear2.forwardInterval (LeanCert.ML.Transformer.geluVector (blk.linear1.forwardInterval x prec) prec) prec
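The remark about linear layers can be checked on a toy network. Interval evaluation of one linear layer is exact per output component, because every input appears once in each sum. The correlations between the outputs are still discarded, however, and this shows up as soon as a second layer combines them. This is a minimal sketch assuming bias-free layers given as lists of weight rows. `Ival`, `Ival.mulConst`, `matVec`, `hidden`, and `output` are hypothetical names, not the library's `forwardInterval`.

```lean
/-- Hypothetical closed interval `[lo, hi]` over `Float`. -/
structure Ival where
  lo : Float
  hi : Float
  deriving Repr

/-- Multiply an interval by a constant; the endpoints swap when `k < 0`. -/
def Ival.mulConst (k : Float) (a : Ival) : Ival :=
  if k ≥ 0 then ⟨k * a.lo, k * a.hi⟩ else ⟨k * a.hi, k * a.lo⟩

def Ival.add (a b : Ival) : Ival := ⟨a.lo + b.lo, a.hi + b.hi⟩

/-- Interval matrix-vector product for a bias-free layer, one weight row per output.
    `zipWith` truncates if a row and the input have different lengths. -/
def matVec (w : List (List Float)) (x : List Ival) : List Ival :=
  w.map fun row => (List.zipWith Ival.mulConst row x).foldl Ival.add (Ival.mk 0 0)

-- h = W₁ x copies the single input x ∈ [-1, 1] into two hidden units.
def hidden : List Ival := matVec [[1], [1]] [⟨-1, 1⟩]

-- y = W₂ h = h₀ - h₁ is identically 0, but the interval result is wider.
def output : List Ival := matVec [[1, -1]] hidden

#eval hidden
#eval output
```

Each hidden unit is exactly [-1, 1], yet `output` is [-2, 2] while h₀ - h₁ is identically 0. This is the same dependency problem as in centering, arising across layers instead of inside LayerNorm.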
TransformerBlock forward pass with tight LayerNorm bounds.
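Uses affine arithmetic for both LayerNorms to reduce the dependency problem in their centering steps. Following the equation, the block computes ln2(x + ffn(ln1(x))). Here ln1 and ln2 use forwardIntervalTight, while the FFN and the residual addition use standard intervals.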
Equations
- blk.forwardIntervalTight x prec = blk.ln2.forwardIntervalTight (x.add (blk.ffn.forwardInterval (blk.ln1.forwardIntervalTight x prec) prec)) prec