Affine Arithmetic: Vector Operations #
This file provides vectorized affine operations for efficient neural network verification. The key insight is that shared noise symbols track correlations across vector elements.
The Dependency Problem in Vectors #
Consider LayerNorm: y_i = (x_i - μ) / σ where μ = mean(x).
With standard interval arithmetic:
- x_i ∈ [0.9, 1.1] for all i
- μ ∈ [0.9, 1.1]
- x_i - μ ∈ [-0.2, 0.2]  (loose: the dependency between x_i and μ is lost)
With Affine Arithmetic:
- x_i = 1.0 + 0.1·ε_i  (each element has its own noise)
- μ = 1.0 + 0.1·(ε_1 + ... + ε_n)/n
- x_i - μ = 0.1·ε_i - 0.1·(ε_1 + ... + ε_n)/n  (TIGHT!)
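The contrast above can be made concrete with a small self-contained sketch (illustrative Python, independent of the Lean definitions on this page; the `Affine` class and function names are hypothetical):

```python
# Minimal affine-arithmetic model: an affine form is c0 + Σ coeffs[k]·ε_k,
# where each noise symbol ε_k ranges over [-1, 1].
class Affine:
    def __init__(self, c0, coeffs):
        self.c0, self.coeffs = c0, dict(coeffs)

    def __add__(self, other):
        keys = set(self.coeffs) | set(other.coeffs)
        return Affine(self.c0 + other.c0,
                      {k: self.coeffs.get(k, 0.0) + other.coeffs.get(k, 0.0)
                       for k in keys})

    def __sub__(self, other):
        keys = set(self.coeffs) | set(other.coeffs)
        return Affine(self.c0 - other.c0,
                      {k: self.coeffs.get(k, 0.0) - other.coeffs.get(k, 0.0)
                       for k in keys})

    def scale(self, s):
        return Affine(self.c0 * s, {k: s * v for k, v in self.coeffs.items()})

    def to_interval(self):
        r = sum(abs(v) for v in self.coeffs.values())  # total radius
        return (self.c0 - r, self.c0 + r)

def interval_sub_mean(intervals):
    # Plain interval arithmetic: μ's dependency on x_0 is forgotten.
    n = len(intervals)
    mu_lo = sum(lo for lo, _ in intervals) / n
    mu_hi = sum(hi for _, hi in intervals) / n
    lo, hi = intervals[0]
    return (lo - mu_hi, hi - mu_lo)

def affine_sub_mean(intervals):
    # Affine arithmetic: x_k = mid + rad·ε_k, and μ shares the same symbols.
    n = len(intervals)
    xs = [Affine((lo + hi) / 2, {k: (hi - lo) / 2})
          for k, (lo, hi) in enumerate(intervals)]
    mu = xs[0]
    for x in xs[1:]:
        mu = mu + x
    mu = mu.scale(1.0 / n)
    return (xs[0] - mu).to_interval()

ivs = [(0.9, 1.1)] * 4
print(interval_sub_mean(ivs))  # ≈ (-0.2, 0.2)
print(affine_sub_mean(ivs))    # ≈ (-0.15, 0.15): the shared part of μ cancels
```

For n = 4 the exact range of x_0 - μ is ±0.2·(n-1)/n = ±0.15, which the affine computation recovers exactly; plain interval arithmetic stays at ±0.2 regardless of n.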
Main definitions #
- AffineVector - Vector of affine forms with shared noise symbols
- AffineVector.mean - Compute mean preserving correlations
- AffineVector.sub - Element-wise subtraction
- AffineVector.layerNorm - LayerNorm with proper dependency tracking
Affine Vector #
A vector of affine forms sharing the same noise symbol space.
All elements should have the same coeffs.length (number of noise symbols).
This ensures correlations are properly tracked.
Instances For
Create an affine vector from intervals, assigning each a unique noise symbol.
If intervals = [I₀, I₁, I₂], creates:
- x₀ = mid(I₀) + rad(I₀)·ε₀
- x₁ = mid(I₁) + rad(I₁)·ε₁
- x₂ = mid(I₂) + rad(I₂)·ε₂
Equations
Instances For
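The construction described above can be sketched as follows (illustrative Python, not the library's API; `from_intervals` and the `(center, coeffs)` representation are hypothetical):

```python
# Each interval I_k becomes mid(I_k) + rad(I_k)·ε_k with a fresh noise
# symbol; element k only uses symbol k at creation time, so the coefficient
# matrix starts out diagonal.
def from_intervals(intervals):
    n = len(intervals)
    forms = []
    for k, (lo, hi) in enumerate(intervals):
        coeffs = [0.0] * n
        coeffs[k] = (hi - lo) / 2   # rad(I_k)
        forms.append(((lo + hi) / 2, coeffs))  # mid(I_k)
    return forms

forms = from_intervals([(0.0, 2.0), (1.0, 3.0), (-1.0, 1.0)])
# centers 1.0, 2.0, 0.0; each radius 1.0 sits on the diagonal
```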
Convert back to interval bounds
Equations
Instances For
Linear Operations #
Element-wise addition
Equations
Instances For
Element-wise subtraction
Equations
Instances For
Element-wise negation
Equations
Instances For
Scalar multiplication
Equations
Instances For
Aggregation Operations #
Sum of all elements (preserves correlations!)
Equations
Instances For
Mean of all elements
Equations
Instances For
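How sum and mean can keep one coefficient per shared noise symbol (illustrative Python using a `(center, coeffs-dict)` representation; names are hypothetical, not the Lean definitions):

```python
# Summing element-wise adds coefficients per noise symbol instead of
# collapsing each element to an interval radius, so a later subtraction
# against the mean can still cancel the shared parts.
def vec_sum(forms):
    c0 = sum(c for c, _ in forms)
    coeffs = {}
    for _, cs in forms:
        for k, v in cs.items():
            coeffs[k] = coeffs.get(k, 0.0) + v
    return (c0, coeffs)

x = [(1.0, {0: 0.1}), (1.0, {1: 0.1}), (1.0, {2: 0.1})]
total = vec_sum(x)  # (3.0, {0: 0.1, 1: 0.1, 2: 0.1})
mean = (total[0] / 3, {k: v / 3 for k, v in total[1].items()})
# mean keeps all three symbols with coefficient 0.1/3 — not a plain radius
```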
Dot product of two vectors
Equations
Instances For
LayerNorm Components #
Compute (x - μ) for each element, where μ = mean(x).
This is where Affine Arithmetic shines: the subtraction properly cancels the correlated parts, giving tight bounds.
Instances For
Compute variance: mean((x - μ)²)
Equations
Instances For
Layer Normalization: (x - μ) / √(σ² + ε) * γ + β
Parameters:
- v: input vector (affine)
- gamma: scale parameters
- beta: shift parameters
- eps: numerical stability constant
Equations
Instances For
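A plain floating-point reference for the formula, useful as a sanity check against the affine version (illustrative Python; `layer_norm` is a hypothetical name, not this library's function):

```python
import math

# y_i = gamma_i * (x_i - mu) / sqrt(var + eps) + beta_i,
# following the parameter list above.
def layer_norm(x, gamma, beta, eps):
    n = len(x)
    mu = sum(x) / n
    var = sum((xi - mu) ** 2 for xi in x) / n
    return [g * (xi - mu) / math.sqrt(var + eps) + b
            for xi, g, b in zip(x, gamma, beta)]

y = layer_norm([1.0, 2.0, 3.0], [1.0, 1.0, 1.0], [0.0, 0.0, 0.0], 1e-5)
# output is centered (middle element 0) and symmetric about 0
```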
Softmax Components #
Compute exp(x_i - max(x)) for numerical stability
Equations
Instances For
Softmax using algebraic cancellation.
softmax(x)_i = exp(x_i) / Σ exp(x_j) = 1 / Σ exp(x_j - x_i)
By computing differences first, we track correlations better.
Equations
Instances For
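The identity can be checked numerically (illustrative Python; both function names are hypothetical, not the Lean definitions):

```python
import math

# Standard max-shifted softmax: exp(x_i - max(x)) / Σ exp(x_j - max(x)).
def softmax_direct(x):
    m = max(x)
    e = [math.exp(xi - m) for xi in x]
    s = sum(e)
    return [ei / s for ei in e]

# Same value via the difference form: 1 / Σ_j exp(x_j - x_i).
def softmax_via_differences(x):
    return [1.0 / sum(math.exp(xj - xi) for xj in x) for xi in x]

x = [0.5, 1.5, -0.3]
a, b = softmax_direct(x), softmax_via_differences(x)
# the two agree to floating-point precision, and each sums to 1
```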
Attention Components #
Scaled dot-product attention scores for a single query.
Computes softmax(q · K^T / √d_k) where K is a list of key vectors.
Equations
Instances For
Apply attention weights to values
Equations
Instances For
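A plain-float reference for the two steps described above, scores then weighted values (illustrative Python; `attention` is a hypothetical name, not this library's function):

```python
import math

# Single-query scaled dot-product attention:
# weights = softmax(q · k_j / sqrt(d_k)); output = Σ_j weights_j · v_j.
def attention(q, keys, values):
    d_k = len(q)
    scores = [sum(qi * ki for qi, ki in zip(q, k)) / math.sqrt(d_k)
              for k in keys]
    m = max(scores)                      # max-shift for stability
    e = [math.exp(s - m) for s in scores]
    z = sum(e)
    w = [ei / z for ei in e]
    return [sum(wj * v[i] for wj, v in zip(w, values))
            for i in range(len(values[0]))]

# With one-hot keys and values, the output is just the weight vector.
out = attention([1.0, 0.0], [[1.0, 0.0], [0.0, 1.0]], [[1.0, 0.0], [0.0, 1.0]])
```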
GELU #
GELU activation: x · Φ(x) ≈ 0.5 · x · (1 + tanh(√(2/π) · (x + 0.044715 · x³)))
Using affine arithmetic preserves correlations in x · tanh(f(x)).
Equations
Instances For
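The quoted tanh approximation in plain Python, as a pointwise reference for the affine version (illustrative sketch, not the library's definition):

```python
import math

# GELU(x) ≈ 0.5·x·(1 + tanh(sqrt(2/π)·(x + 0.044715·x³)))
def gelu(x):
    return 0.5 * x * (1.0 + math.tanh(math.sqrt(2.0 / math.pi)
                                      * (x + 0.044715 * x ** 3)))

# gelu(0) = 0, and gelu(x) approaches x for large positive x
```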
Soundness #
Membership for affine vectors
Equations
- LeanCert.Engine.Affine.AffineVector.mem v_real v eps = (v_real.length = List.length v ∧ ∀ (i : ℕ) (hi_v : i < List.length v) (hi_r : i < v_real.length), v[i].mem_affine eps v_real[i])
Instances For
Sum is sound
Mean is sound
Centered is sound: x - mean(x)
Variance is sound
LayerNorm is sound for the complete operation.
Note: This theorem requires additional hypotheses that would typically be verified at the call site:
- The variance + epsilon must be positive (guaranteed by epsilon > 0)
- Length constraints for gamma and beta
- The affine form for var + eps must have positive lower bound for inv
The proof composition uses:
- mem_centered for centering
- mem_variance for variance
- mem_add for adding epsilon
- mem_sqrt for standard deviation
- mem_inv for reciprocal
- mem_layerNorm_elem for each output element