LeanCert.ML.Softmax

Verified Softmax and LogSumExp #

This module implements tight interval bounds for Softmax using algebraic cancellation to handle the numerator-denominator dependency.

The Problem #

Naive Interval Softmax evaluates [e^l, e^u] / [Σe^l, Σe^u], treating numerator and denominator as independent intervals. When input widths are large, the lower bound of the result approaches 0 and the upper bound exceeds 1, even though true softmax values always lie in (0, 1).
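A concrete instance (values chosen purely for illustration): take x₁ ∈ [0, 4] and x₂ ∈ [0, 1], and bound y₁:

```latex
\text{Naive: } y_1 \in \Big[\tfrac{e^{0}}{e^{4}+e^{1}},\ \tfrac{e^{4}}{e^{0}+e^{0}}\Big] \approx [0.017,\ 27.3]
\qquad
\text{Tight: } y_1 = \frac{1}{1+e^{x_2-x_1}},\quad x_2 - x_1 \in [-4, 1]
\ \Longrightarrow\ y_1 \in \Big[\tfrac{1}{1+e^{1}},\ \tfrac{1}{1+e^{-4}}\Big] \approx [0.269,\ 0.982]
```

The naive upper bound 27.3 is not even a valid probability, while the substituted form stays inside (0, 1) by construction.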

The Solution (Algebraic Substitution) #

We compute: y_i = 1 / (1 + Σ_{j≠i} exp(x_j - x_i))

This models the dependency explicitly: since the sum Σ_{j≠i} exp(x_j - x_i) is strictly positive, the result always lies in (0, 1), and the bounds are much tighter for the dominant token.

Helper: Interval Exponentiation #

Vectorized exponential using Taylor-refined intervals.

Equations
  • One or more equations did not get rendered due to their size.
Instances For

    Verified LogSumExp #

    LogSumExp: log(Σ e^(x_i)), computed as max(x) + log(Σ e^(x_i - max(x))). This is the shift-invariance property used for numerical stability. It also keeps intervals small, since the inputs to exp are ≤ 0.

    Equations
    • One or more equations did not get rendered due to their size.
    Instances For
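The shift-invariance used above follows from factoring e^M out of the sum, with M = max(x):

```latex
\log\sum_i e^{x_i}
  \;=\; \log\Big(e^{M}\sum_i e^{x_i - M}\Big)
  \;=\; M + \log\sum_i e^{x_i - M},
\qquad M = \max_i x_i .
```

Because every exponent x_i - M is ≤ 0, each e^(x_i - M) lies in (0, 1], so the interval arguments to exp stay bounded regardless of the magnitude of the inputs.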

      Verified Softmax #

      Optimized Softmax for index k of vector x. Computes 1 / (1 + Σ_{j≠k} exp(x_j - x_k))

      Equations
      • One or more equations did not get rendered due to their size.
      Instances For
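Writing x_j ∈ [l_j, u_j], the substituted form can be evaluated with standard interval arithmetic; since t ↦ 1/(1+t) is decreasing, the endpoints swap (this is our derivation from the formula above, not a statement of the module's exact rounding behavior):

```latex
S \in \Big[\textstyle\sum_{j\neq k} e^{\,l_j - u_k},\ \sum_{j\neq k} e^{\,u_j - l_k}\Big],
\qquad
y_k = \frac{1}{1+S} \in \Big[\frac{1}{1+\sum_{j\neq k} e^{\,u_j - l_k}},\ \frac{1}{1+\sum_{j\neq k} e^{\,l_j - u_k}}\Big].
```

Both endpoints lie strictly between 0 and 1 for any finite input bounds.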

        Soundness Proofs #

        theorem LeanCert.ML.Softmax.mem_foldl_add {xs_real : List ℝ} {xs_interval : List Core.IntervalDyadic} {acc_real : ℝ} {acc_interval : Core.IntervalDyadic} (_hlen : xs_real.length = xs_interval.length) (hacc : acc_real ∈ acc_interval) (_hmem : ∀ (i : ℕ) (hi : i < xs_real.length), xs_real[i] ∈ xs_interval[i]) :
        List.foldl (fun (x1 x2 : ℝ) => x1 + x2) acc_real xs_real ∈ List.foldl Core.IntervalDyadic.add acc_interval xs_interval

        Helper: foldl add preserves interval membership. If acc_real ∈ acc_interval and for each i, xs_real[i] ∈ xs_interval[i], then (xs_real.foldl (+) acc_real) ∈ (xs_interval.foldl add acc_interval).

        Proof by induction on xs_real, using mem_add at each step.
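A plausible proof skeleton for this induction (the lemma name mem_add, the Core.IntervalDyadic API, and the exact bookkeeping are assumptions inferred from the statement above, not the actual source):

```lean
theorem mem_foldl_add {xs_real : List ℝ} {xs_interval : List Core.IntervalDyadic}
    {acc_real : ℝ} {acc_interval : Core.IntervalDyadic}
    (hlen : xs_real.length = xs_interval.length)
    (hacc : acc_real ∈ acc_interval)
    (hmem : ∀ (i : ℕ) (hi : i < xs_real.length), xs_real[i] ∈ xs_interval[i]) :
    List.foldl (· + ·) acc_real xs_real ∈
      List.foldl Core.IntervalDyadic.add acc_interval xs_interval := by
  induction xs_real generalizing xs_interval acc_real acc_interval with
  | nil =>
    -- hlen forces xs_interval = []; both folds reduce to the accumulators
    cases xs_interval <;> simp_all
  | cons x xs ih =>
    -- hlen forces xs_interval = J :: Js; step once with mem_add, then recurse
    cases xs_interval with
    | nil => simp at hlen
    | cons J Js =>
      exact ih (by simpa using hlen)
        (mem_add hacc (hmem 0 (by simp)))
        (fun i hi => by simpa using hmem (i + 1) (by simpa using Nat.succ_lt_succ hi))
```

The essential step is the cons case: membership of the new accumulator acc_real + x in Core.IntervalDyadic.add acc_interval J follows from mem_add, and the induction hypothesis closes the rest of the fold.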

        noncomputable def LeanCert.ML.Softmax.Real.softmax (x : List ℝ) (k : ℕ) : ℝ

        Real Softmax function

        Equations
        Instances For
          theorem LeanCert.ML.Softmax.softmax_algebraic_identity (x : List ℝ) (k : ℕ) (hk : k < x.length) (hsum : (List.map Real.exp x).sum ≠ 0) :
          Real.softmax x k = 1 / (List.map (fun (xj : ℝ) => Real.exp (xj - x[k]!)) x).sum

          Algebraic Identity: e^xk / Σ e^xj = 1 / Σ e^(xj - xk)

          This is the key insight that allows tight interval bounds: by dividing both numerator and denominator by e^xk, we cancel the correlation and get an expression where all terms are differences.

          Proof sketch: e^xk / Σe^xj = (e^xk * e^{-xk}) / (Σe^xj * e^{-xk}) = 1 / Σ(e^xj * e^{-xk}) = 1 / Σe^{xj - xk}

          theorem LeanCert.ML.Softmax.mem_softmax {v : List ℝ} {I : IntervalVector} (hdim : v.length = List.length I) (hmem : ∀ (i : ℕ) (hi : i < List.length I), v[i] ∈ I[i]) (prec : ℕ) (_hprec : prec ≠ 0) (k : ℕ) (hk : k < List.length I) :
          Real.softmax v k ∈ (softmax I prec)[k]

          Main Theorem: Softmax Soundness

          If inputs v are within I, then softmax(v) is within softmax(I).
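In conventional notation, the soundness theorem says that componentwise membership of the input vector propagates to the output:

```latex
\big(\forall i,\ v_i \in I_i\big)
\;\Longrightarrow\;
\forall k < |I|,\quad \mathrm{softmax}(v)_k \in \mathrm{softmax}(I, \mathrm{prec})_k .
```

Combined with the algebraic identity above, this guarantees that the computed interval for each softmax component is a rigorous enclosure of the exact real value.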