Verified Softmax and LogSumExp #
This module implements tight interval bounds for Softmax using algebraic cancellation to handle the numerator-denominator dependency.
The Problem #
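Naive interval softmax evaluates [e^l, e^u] / [Σe^l, Σe^u], treating the numerator and denominator as independent intervals. The resulting bounds are e^l / Σe^u (lower) and e^u / Σe^l (upper). If the input intervals are wide, the lower bound approaches 0 and the upper bound can exceed 1.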
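The blow-up can be reproduced with a few lines of Python. This is only an illustrative sketch: the function name `naive_interval_softmax` and the sample logits are hypothetical, and it uses ordinary floats with no outward rounding, so it is not a verified computation.

```python
import math

def naive_interval_softmax(lo, hi):
    """Divide the numerator interval by the denominator interval,
    treating the two as independent (plain floats, no outward rounding)."""
    exp_lo = [math.exp(v) for v in lo]
    exp_hi = [math.exp(v) for v in hi]
    den_lo, den_hi = sum(exp_lo), sum(exp_hi)
    # For positive intervals, [a, b] / [c, d] = [a / d, b / c].
    return [(a / den_hi, b / den_lo) for a, b in zip(exp_lo, exp_hi)]

# Hypothetical input: three logits, each known only to within +/- 2.
center = [3.0, 0.0, -1.0]
lo = [c - 2.0 for c in center]
hi = [c + 2.0 for c in center]

naive = naive_interval_softmax(lo, hi)
for k, (a, b) in enumerate(naive):
    print(k, a, b)
```

For the dominant logit the upper bound comes out far above 1, even though every true softmax value is a probability.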
The Solution (Algebraic Substitution) #
We compute: y_i = 1 / (1 + Σ_{j≠i} exp(x_j - x_i))
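Because x_i now appears only inside the differences x_j - x_i, the numerator and denominator no longer vary independently. The denominator is 1 plus a sum of exponentials, so the result always lies in (0, 1), and the bounds are much tighter, especially for the dominant token.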
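A minimal float-based sketch of the substituted form follows. The name `softmax_component_interval` and the sample logits are assumptions made for illustration, and plain floating point is used instead of the module's verified arithmetic.

```python
import math

def softmax_component_interval(lo, hi, i):
    """Interval for y_i = 1 / (1 + sum_{j != i} exp(x_j - x_i))."""
    n = len(lo)
    # x_j - x_i ranges over [lo[j] - hi[i], hi[j] - lo[i]], and exp is increasing.
    sum_lo = sum(math.exp(lo[j] - hi[i]) for j in range(n) if j != i)
    sum_hi = sum(math.exp(hi[j] - lo[i]) for j in range(n) if j != i)
    # 1 / (1 + s) is decreasing in s, so the endpoints swap.
    return (1.0 / (1.0 + sum_hi), 1.0 / (1.0 + sum_lo))

# Hypothetical input: three logits, each known only to within +/- 2.
lo = [1.0, -2.0, -3.0]
hi = [5.0, 2.0, 1.0]
bounds = [softmax_component_interval(lo, hi, i) for i in range(len(lo))]
for i, (a, b) in enumerate(bounds):
    print(i, a, b)
```

Every bound stays inside (0, 1) because the denominator is always greater than 1.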
Helper: Interval Exponentiation #
Vectorized exponential using Taylor-refined intervals.
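The definition itself is not rendered on this page, so the following is only a sketch of one way a Taylor-based enclosure of exp can be built, not the module's actual code. It assumes x ≤ 0, where the Lagrange remainder of the degree-n Taylor polynomial is bounded by |x|^(n+1) / (n+1)!, and it uses exact rationals. The names `exp_taylor_bounds` and `exp_interval` are hypothetical.

```python
from fractions import Fraction
import math

def exp_taylor_bounds(x, n=12):
    """Rational enclosure of exp(x) for x <= 0: degree-n Taylor polynomial
    plus or minus the Lagrange remainder bound |x|^(n+1) / (n+1)!."""
    x = Fraction(x)
    if x > 0:
        raise ValueError("this sketch only handles x <= 0")
    term, total = Fraction(1), Fraction(1)
    for k in range(1, n + 1):
        term = term * x / k          # term is now x^k / k!
        total += term
    remainder = abs(term * x / (n + 1))  # |x|^(n+1) / (n+1)!
    return max(total - remainder, Fraction(0)), total + remainder

def exp_interval(lo, hi, n=12):
    """exp is increasing, so the enclosure of [lo, hi] is taken endpoint-wise."""
    return exp_taylor_bounds(lo, n)[0], exp_taylor_bounds(hi, n)[1]

a, b = exp_taylor_bounds(-1)
print(float(a), float(b))
```

Restricting to nonpositive inputs is a simplification for this sketch. It matches the shifted form used for LogSumExp, where the arguments of exp are at most 0.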
Verified LogSumExp #
LogSumExp: log(Σ e^x_i)
Computed as: max(x) + log(Σ e^(x_i - max(x)))
This is the "shift-invariance" property used for numerical stability. It also keeps intervals small, since the inputs to exp are ≤ 0.
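The effect of the shift can be seen on real-valued inputs with a short Python sketch. This works on plain floats rather than intervals, and the function names are hypothetical.

```python
import math

def logsumexp_naive(xs):
    """Direct evaluation of log(sum(exp(x))); overflows for large inputs."""
    return math.log(sum(math.exp(x) for x in xs))

def logsumexp_shifted(xs):
    """Shifted evaluation: max(x) + log(sum(exp(x - max(x))))."""
    m = max(xs)
    # Every argument x - m is <= 0, so each exponential lies in (0, 1].
    return m + math.log(sum(math.exp(x - m) for x in xs))

xs = [1000.0, 999.0, 998.0]
try:
    logsumexp_naive(xs)
    naive_overflowed = False
except OverflowError:
    naive_overflowed = True

shifted = logsumexp_shifted(xs)
print(naive_overflowed, shifted)
```

The shifted form never exponentiates a positive number, which is also why an interval version of it only has to bound exp on nonpositive arguments.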
Verified Softmax #
Optimized Softmax for index k of vector x.
Computes 1 / (1 + Σ_{j≠k} exp(x_j - x_k))
Full Softmax layer
Equations
- LeanCert.ML.Softmax.softmax x prec = List.map (fun (k : ℕ) => LeanCert.ML.Softmax.softmaxComponent x k prec) (List.range (List.length x))
Soundness Proofs #
Helper: foldl add preserves interval membership. If acc_real ∈ acc_interval and for each i, xs_real[i] ∈ xs_interval[i], then (xs_real.foldl (+) acc_real) ∈ (xs_interval.foldl add acc_interval).
Proof by induction on xs_real, using mem_add at each step.
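The statement of the helper can be illustrated concretely in Python, with `functools.reduce` standing in for foldl. This is a check on one example, not a proof. The tuple representation of intervals and the helper names `add` and `mem` are assumptions for the sketch, and exact rationals are used so that rounding plays no role.

```python
from fractions import Fraction
from functools import reduce
import operator

def add(a, b):
    """Interval addition on (lo, hi) pairs."""
    return (a[0] + b[0], a[1] + b[1])

def mem(x, iv):
    """Membership of a real number in a closed interval."""
    return iv[0] <= x <= iv[1]

# Hypothetical data: each real value lies in the matching interval.
xs_interval = [(Fraction(0), Fraction(1)),
               (Fraction(-2), Fraction(3)),
               (Fraction(5), Fraction(5))]
xs_real = [Fraction(1, 2), Fraction(1), Fraction(5)]
acc_interval = (Fraction(0), Fraction(0))
acc_real = Fraction(0)

assert mem(acc_real, acc_interval)
assert all(mem(x, iv) for x, iv in zip(xs_real, xs_interval))

folded_real = reduce(operator.add, xs_real, acc_real)
folded_interval = reduce(add, xs_interval, acc_interval)
print(folded_real, folded_interval, mem(folded_real, folded_interval))
```

Each step of the fold is one application of the addition rule, which mirrors the inductive structure of the proof.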
Algebraic Identity: e^xk / Σ e^xj = 1 / Σ e^(xj - xk)
This is the key insight that allows tight interval bounds: by dividing both numerator and denominator by e^xk, we remove the correlation between them and get an expression in which every term is a difference.
Proof sketch: e^xk / Σe^xj = (e^xk * e^{-xk}) / (Σe^xj * e^{-xk}) = 1 / Σ(e^xj * e^{-xk}) = 1 / Σe^{xj - xk}
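The link between this identity and the form used by the softmax component is the j = k term of the sum, which equals e^0 = 1. Written out in LaTeX, with the same symbols as the sketch above:

```latex
\frac{e^{x_k}}{\sum_j e^{x_j}}
  = \frac{1}{\sum_j e^{x_j - x_k}}
  = \frac{1}{e^{x_k - x_k} + \sum_{j \neq k} e^{x_j - x_k}}
  = \frac{1}{1 + \sum_{j \neq k} e^{x_j - x_k}}
```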
Main Theorem: Softmax Soundness
If every input v_i lies within its interval I_i, then every component of softmax(v) lies within the corresponding component of softmax(I).
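The theorem can be spot-checked numerically, which is of course weaker than the formal proof. The sketch below samples points from an input box and tests containment. All names are hypothetical, and because it uses plain floats a small tolerance `tol` is allowed at the endpoints.

```python
import math
import random

def softmax(v):
    """Reference softmax on real inputs (shifted for stability)."""
    m = max(v)
    e = [math.exp(x - m) for x in v]
    s = sum(e)
    return [t / s for t in e]

def interval_softmax(lo, hi):
    """Component-wise bounds from 1 / (1 + sum_{j != k} exp(x_j - x_k))."""
    n = len(lo)
    out = []
    for k in range(n):
        s_lo = sum(math.exp(lo[j] - hi[k]) for j in range(n) if j != k)
        s_hi = sum(math.exp(hi[j] - lo[k]) for j in range(n) if j != k)
        out.append((1.0 / (1.0 + s_hi), 1.0 / (1.0 + s_lo)))
    return out

def check_soundness(lo, hi, trials=1000, seed=0, tol=1e-12):
    """Sample v inside the box [lo, hi] and test softmax(v) against the bounds."""
    rng = random.Random(seed)
    bounds = interval_softmax(lo, hi)
    for _ in range(trials):
        v = [rng.uniform(a, b) for a, b in zip(lo, hi)]
        y = softmax(v)
        if not all(a - tol <= yi <= b + tol for yi, (a, b) in zip(y, bounds)):
            return False
    return True

print(check_soundness([1.0, -2.0, -3.0], [5.0, 2.0, 1.0]))
```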