Verified Model Distillation / Equivalence Checking #
This module verifies that two neural networks
(e.g., a large Teacher and a compressed Student) produce outputs
within a given tolerance ε of each other for ALL inputs
in a given domain.
Main Definitions #
- SequentialNet: a network with N layers
- checkEquivalence: computable certificate for network equivalence
- verify_equivalence: the main theorem proving |T(x) - S(x)| ≤ ε
The Math #
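We define the difference graph Diff(x) = Teacher(x) - Student(x).
We compute interval bounds for Diff(InputBox). If every component of the resulting interval vector is contained in [-ε, ε], the networks are ε-equivalent on InputBox. Interval bounds over-approximate the true range, so the check is sound but conservative: it can fail for networks that are in fact ε-equivalent.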
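A worked instance of the check, as a minimal core-Lean 4 sketch. It assumes, as the Interval Subtraction section suggests, that the bound on Diff(InputBox) is obtained by subtracting the two networks' output intervals. The Ival structure, its Int endpoints, and the 1-D teacher/student maps are hypothetical stand-ins, not this module's API.

```lean
-- Hypothetical toy interval with Int endpoints (not the library's interval types).
structure Ival where
  lo : Int
  hi : Int

/-- Interval subtraction: [a, b] - [c, d] = [a - d, b - c]. -/
def Ival.sub (I J : Ival) : Ival := ⟨I.lo - J.hi, I.hi - J.lo⟩

/-- Is the interval contained in [-eps, eps]? -/
def Ival.boundedBy (I : Ival) (eps : Int) : Bool :=
  decide (-eps ≤ I.lo ∧ I.hi ≤ eps)

-- Hypothetical 1-D networks on InputBox = [0, 1]:
--   Teacher(x) = 2x      so Teacher(InputBox) = [0, 2]
--   Student(x) = 2x + 1  so Student(InputBox) = [1, 3]
def teacherOut : Ival := ⟨0, 2⟩
def studentOut : Ival := ⟨1, 3⟩

-- [0, 2] - [1, 3] = [0 - 3, 2 - 1] = [-3, 1]
def diffBound : Ival := teacherOut.sub studentOut

#eval diffBound.boundedBy 3  -- true
#eval diffBound.boundedBy 1  -- false
```

Here the true difference is Diff(x) = -1 for every x, so the networks are 1-equivalent, yet the interval check only certifies ε = 3: subtracting the two boxes forgets that both outputs depend on the same x. A true answer can be trusted; a false answer is inconclusive.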
Interval Subtraction #
Negate an interval: -I = [-hi, -lo]
Membership in negated interval
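The negation rule and its membership lemma, sketched in core Lean 4. Ival, its Int endpoints, and the lemma name Ival.mem_neg are hypothetical; the library uses its own interval types. The sketch shows why swapping the endpoints is exactly what makes membership carry over.

```lean
-- Hypothetical toy interval with Int endpoints.
structure Ival where
  lo : Int
  hi : Int

/-- Negation swaps the endpoints: -[lo, hi] = [-hi, -lo]. -/
def Ival.neg (I : Ival) : Ival := ⟨-I.hi, -I.lo⟩

/-- If lo ≤ x ≤ hi, then -hi ≤ -x ≤ -lo, i.e. -x lies in the negated interval. -/
theorem Ival.mem_neg {x : Int} {I : Ival} (h : I.lo ≤ x ∧ x ≤ I.hi) :
    I.neg.lo ≤ -x ∧ -x ≤ I.neg.hi := by
  have h1 : I.lo ≤ x := h.1
  have h2 : x ≤ I.hi := h.2
  -- The projections of I.neg unfold definitionally to -I.hi and -I.lo.
  show -I.hi ≤ -x ∧ -x ≤ -I.lo
  exact ⟨by omega, by omega⟩

#eval (Ival.neg ⟨1, 3⟩).lo  -- -3
```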
Subtract two intervals: I - J = I + (-J)
Membership in subtracted interval
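A core-Lean 4 sketch of subtraction defined through negation, as stated above, together with its membership lemma. Ival, Ival.add, and Ival.mem_sub are hypothetical names over Int endpoints, not the library's definitions.

```lean
-- Hypothetical toy interval with Int endpoints.
structure Ival where
  lo : Int
  hi : Int

def Ival.neg (I : Ival) : Ival := ⟨-I.hi, -I.lo⟩

/-- Endpoint-wise addition: [a, b] + [c, d] = [a + c, b + d]. -/
def Ival.add (I J : Ival) : Ival := ⟨I.lo + J.lo, I.hi + J.hi⟩

/-- Subtraction exactly as stated: I - J = I + (-J). -/
def Ival.sub (I J : Ival) : Ival := I.add J.neg

/-- If x ∈ I and y ∈ J, then x - y ∈ I - J. -/
theorem Ival.mem_sub {x y : Int} {I J : Ival}
    (hx : I.lo ≤ x ∧ x ≤ I.hi) (hy : J.lo ≤ y ∧ y ≤ J.hi) :
    (I.sub J).lo ≤ x - y ∧ x - y ≤ (I.sub J).hi := by
  have hx1 := hx.1
  have hx2 := hx.2
  have hy1 := hy.1
  have hy2 := hy.2
  -- (I.sub J).lo unfolds to I.lo + -J.hi, and (I.sub J).hi to I.hi + -J.lo.
  show I.lo + -J.hi ≤ x - y ∧ x - y ≤ I.hi + -J.lo
  exact ⟨by omega, by omega⟩

-- [0, 2] - [1, 3] = [0 + -3, 2 + -1] = [-3, 1]
#eval (Ival.sub ⟨0, 2⟩ ⟨1, 3⟩).lo  -- -3
#eval (Ival.sub ⟨0, 2⟩ ⟨1, 3⟩).hi  -- 1
```

Defining subtraction as I + (-J) lets its correctness reduce to the negation and addition facts instead of requiring a separate endpoint argument.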
Sequential Network Infrastructure #
A standard feedforward network is just a list of layers
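A minimal core-Lean 4 sketch of the data layout. The Layer fields (weights as rows, bias), the Int entries, Layer.outDim, and tinyNet are all hypothetical; the library's LeanCert.ML.Layer may store its parameters differently.

```lean
-- Hypothetical toy layer: an affine map y = W x + b, rows of W stored as lists.
structure Layer where
  weights : List (List Int)  -- one row per output neuron
  bias    : List Int         -- one entry per output neuron

/-- The network is nothing more than its layers, applied first to last. -/
structure SequentialNet where
  layers : List Layer

/-- Output width of a layer = number of weight rows. -/
def Layer.outDim (l : Layer) : Nat := l.weights.length

-- A 2-input network with a hidden width of 2 and a single output.
def tinyNet : SequentialNet :=
  { layers := [ { weights := [[1, -1], [2, 0]], bias := [0, -1] },
                { weights := [[1, 1]], bias := [3] } ] }

#eval tinyNet.layers.map Layer.outDim  -- [2, 1]
```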
Real-valued forward pass
Equations
- net.forwardReal x = List.foldl (fun (acc : List ℝ) (l : LeanCert.ML.Layer) => l.forwardReal acc) x net.layers
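The displayed equation is a left fold: each layer consumes the previous layer's output. A core-Lean 4 sketch of the same shape, assuming a hypothetical ReLU-affine layer and using Int as an exact stand-in for ℝ (Mathlib's reals are not computable with #eval). Layer.forward, dotProd, and tinyNet are illustrative names only.

```lean
-- Hypothetical toy layer: ReLU(W x + b), with Int standing in for ℝ.
structure Layer where
  weights : List (List Int)  -- one row per output neuron
  bias    : List Int         -- one entry per output neuron

/-- Dot product; zipWith stops at the shorter list. -/
def dotProd (u v : List Int) : Int :=
  List.foldl (fun (s p : Int) => s + p) 0 (List.zipWith (fun (a b : Int) => a * b) u v)

/-- One layer: y_i = max 0 (row_i · x + b_i). -/
def Layer.forward (l : Layer) (x : List Int) : List Int :=
  List.zipWith (fun (row : List Int) (b : Int) => max 0 (dotProd row x + b)) l.weights l.bias

structure SequentialNet where
  layers : List Layer

/-- Same shape as the displayed equation: fold each layer over the running activation. -/
def SequentialNet.forward (net : SequentialNet) (x : List Int) : List Int :=
  List.foldl (fun (acc : List Int) (l : Layer) => l.forward acc) x net.layers

def tinyNet : SequentialNet :=
  { layers := [ { weights := [[1, -1], [2, 0]], bias := [0, -1] },
                { weights := [[1, 1]], bias := [3] } ] }

-- Layer 1: [max 0 (2 - 1 + 0), max 0 (4 + 0 - 1)] = [1, 3]; layer 2: 1 + 3 + 3 = 7
#eval tinyNet.forward [2, 1]  -- [7]
```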
Interval forward pass
Equations
- net.forwardInterval x prec = List.foldl (fun (acc : LeanCert.Engine.IntervalVector) (l : LeanCert.ML.Layer) => l.forwardInterval acc prec) x net.layers
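The interval pass has the same fold shape as the real one, but pushes a box of intervals through each layer. A core-Lean 4 sketch under stated assumptions: Ival, Layer, and the ReLU-affine layer are hypothetical; Int endpoints make the arithmetic exact, so the library's prec argument (presumably a rounding precision for its dyadic endpoints) has no counterpart here.

```lean
-- Hypothetical toy interval with Int endpoints.
structure Ival where
  lo : Int
  hi : Int

/-- Scale by a constant weight; a negative weight swaps the endpoints. -/
def Ival.scale (w : Int) (I : Ival) : Ival :=
  if 0 ≤ w then ⟨w * I.lo, w * I.hi⟩ else ⟨w * I.hi, w * I.lo⟩

def Ival.add (I J : Ival) : Ival := ⟨I.lo + J.lo, I.hi + J.hi⟩

/-- ReLU is monotone, so it can be applied to each endpoint separately. -/
def Ival.relu (I : Ival) : Ival := ⟨max 0 I.lo, max 0 I.hi⟩

structure Layer where
  weights : List (List Int)
  bias    : List Int

/-- Interval version of y_i = relu(row_i · x + b_i). -/
def Layer.forwardInterval (l : Layer) (xs : List Ival) : List Ival :=
  List.zipWith
    (fun (row : List Int) (b : Int) =>
      ((List.zipWith Ival.scale row xs).foldl Ival.add (⟨b, b⟩ : Ival)).relu)
    l.weights l.bias

structure SequentialNet where
  layers : List Layer

def SequentialNet.forwardInterval (net : SequentialNet) (xs : List Ival) : List Ival :=
  List.foldl (fun (acc : List Ival) (l : Layer) => l.forwardInterval acc) xs net.layers

def tinyNet : SequentialNet :=
  { layers := [ { weights := [[1, -1], [2, 0]], bias := [0, -1] },
                { weights := [[1, 1]], bias := [3] } ] }

-- Input box: x1 ∈ [2, 2], x2 ∈ [0, 1]. Hidden box: [1, 2] × [3, 3]; output: [7, 8].
#eval (tinyNet.forwardInterval [⟨2, 2⟩, ⟨0, 1⟩]).map (fun I => (I.lo, I.hi))  -- [(7, 8)]
```

Every concrete input in the box lands inside the output interval (x2 = 1 gives 7, x2 = 0 gives 8), which is the property the soundness theorems below establish for the real pass.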
Check that every layer in the sequence is well-formed and that consecutive layer dimensions align
Equations
- net.WellFormed inputDim = LeanCert.ML.Distillation.SequentialNet.LayersWellFormed net.layers inputDim
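One plausible reading of "dimensions align", sketched as a Bool-valued check in core Lean 4: each layer must accept the width produced by its predecessor, starting from inputDim. The Layer fields and the exact conditions are assumptions; the module's LayersWellFormed may check more (or state it as a Prop).

```lean
-- Hypothetical toy layer: rows of W as lists, one bias per row.
structure Layer where
  weights : List (List Int)
  bias    : List Int

/-- Hypothetical per-layer check: every row has inDim entries, one bias per row. -/
def Layer.wellFormed (l : Layer) (inDim : Nat) : Bool :=
  l.weights.all (fun (row : List Int) => row.length == inDim) &&
    l.weights.length == l.bias.length

/-- Walk the layers in order; each layer's output width is the next input width. -/
def layersWellFormed : List Layer → Nat → Bool
  | [], _ => true
  | l :: ls, inDim => l.wellFormed inDim && layersWellFormed ls l.weights.length

def good : List Layer :=
  [ { weights := [[1, -1], [2, 0]], bias := [0, -1] },
    { weights := [[1, 1]], bias := [3] } ]

-- Second layer expects 3 inputs, but the first layer only produces 2.
def bad : List Layer :=
  [ { weights := [[1, -1], [2, 0]], bias := [0, -1] },
    { weights := [[1, 1, 1]], bias := [3] } ]

#eval layersWellFormed good 2  -- true
#eval layersWellFormed bad 2   -- false
```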
Helper: Lengths of foldl outputs match
Soundness of SequentialNet forward pass
Public soundness theorem
Interval Vector Subtraction #
Subtract two interval vectors (pointwise)
Length of subtracted vectors
Membership in subtracted vectors
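A core-Lean 4 sketch of pointwise subtraction, assuming it is built with List.zipWith (an assumption; the equation is not shown). Ival, subVec, and subVec_length are hypothetical names. The length lemma shows why a separate length fact matters: zipWith silently truncates to the shorter vector.

```lean
-- Hypothetical toy interval with Int endpoints.
structure Ival where
  lo : Int
  hi : Int

def Ival.sub (I J : Ival) : Ival := ⟨I.lo - J.hi, I.hi - J.lo⟩

/-- Pointwise subtraction of interval vectors. -/
def subVec (u v : List Ival) : List Ival := List.zipWith Ival.sub u v

/-- Output length is the shorter input length, so it equals both only when they agree. -/
theorem subVec_length (u v : List Ival) :
    (subVec u v).length = min u.length v.length := by
  simp [subVec, List.length_zipWith]

-- [0, 2] - [1, 3] = [-3, 1] and [5, 5] - [4, 6] = [-1, 1]
#eval (subVec [⟨0, 2⟩, ⟨5, 5⟩] [⟨1, 3⟩, ⟨4, 6⟩]).map (fun I => (I.lo, I.hi))
  -- [(-3, 1), (-1, 1)]
```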
The Equivalence Checker #
Check if an interval is contained within [-eps, eps]
Equations
- LeanCert.ML.Distillation.intervalBoundedBy I eps = decide (-eps ≤ I.toIntervalRat.lo ∧ I.toIntervalRat.hi ≤ eps)
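The displayed equation reduces the check to a decide on two endpoint comparisons (after converting the dyadic interval to a rational one). A core-Lean 4 sketch with hypothetical Int endpoints, including the soundness argument: if the Boolean check passes, every point of the interval is within [-eps, eps]. The names Ival and intervalBoundedBy_sound are illustrative.

```lean
-- Hypothetical toy interval with Int endpoints (the library uses toIntervalRat here).
structure Ival where
  lo : Int
  hi : Int

/-- Mirrors the displayed equation: decide (-eps ≤ lo ∧ hi ≤ eps). -/
def intervalBoundedBy (I : Ival) (eps : Int) : Bool :=
  decide (-eps ≤ I.lo ∧ I.hi ≤ eps)

/-- Soundness: a passing check bounds every member x of I. -/
theorem intervalBoundedBy_sound {I : Ival} {eps x : Int}
    (hI : intervalBoundedBy I eps = true) (hx : I.lo ≤ x ∧ x ≤ I.hi) :
    -eps ≤ x ∧ x ≤ eps := by
  unfold intervalBoundedBy at hI
  -- Turn the Boolean result back into the two inequalities.
  have h := of_decide_eq_true hI
  have h1 := h.1
  have h2 := h.2
  have hx1 := hx.1
  have hx2 := hx.2
  exact ⟨by omega, by omega⟩

#eval intervalBoundedBy ⟨-3, 1⟩ 3  -- true
#eval intervalBoundedBy ⟨-3, 1⟩ 1  -- false
```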
Check if an interval vector is contained within [-eps, eps]
Equations
- LeanCert.ML.Distillation.isBoundedBy v eps = List.all v fun (x : LeanCert.Core.IntervalDyadic) => LeanCert.ML.Distillation.intervalBoundedBy x eps
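The vector check is List.all over the scalar check, exactly as in the displayed equation. A core-Lean 4 sketch with hypothetical Int intervals; the last line shows a subtlety worth keeping in mind: an empty vector passes vacuously, so a certifier also needs some guarantee that the output vectors have the expected length.

```lean
-- Hypothetical toy interval with Int endpoints.
structure Ival where
  lo : Int
  hi : Int

def intervalBoundedBy (I : Ival) (eps : Int) : Bool :=
  decide (-eps ≤ I.lo ∧ I.hi ≤ eps)

/-- Every component must pass the scalar check. -/
def isBoundedBy (v : List Ival) (eps : Int) : Bool :=
  v.all fun x => intervalBoundedBy x eps

#eval isBoundedBy [⟨-3, 1⟩, ⟨-1, 1⟩] 3  -- true
#eval isBoundedBy [⟨-3, 1⟩, ⟨-1, 1⟩] 1  -- false
#eval isBoundedBy [] 0                   -- true (vacuously)
```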
Soundness of intervalBoundedBy
The Distillation Certifier
Returns true if the student network is proven to be within eps
of the teacher network for all inputs in the domain.
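The certifier's equations are not rendered, so the following core-Lean 4 sketch is only a guess at its shape, assembled from the pieces documented above: propagate the input box through both networks, compare output lengths (the correctness section extracts a dimension equality from checkEquivalence), subtract pointwise, and bound every component. All names, the Int endpoints, and the ReLU-affine layer are hypothetical.

```lean
-- Hypothetical toy interval with Int endpoints.
structure Ival where
  lo : Int
  hi : Int

def Ival.scale (w : Int) (I : Ival) : Ival :=
  if 0 ≤ w then ⟨w * I.lo, w * I.hi⟩ else ⟨w * I.hi, w * I.lo⟩
def Ival.add (I J : Ival) : Ival := ⟨I.lo + J.lo, I.hi + J.hi⟩
def Ival.sub (I J : Ival) : Ival := ⟨I.lo - J.hi, I.hi - J.lo⟩
def Ival.relu (I : Ival) : Ival := ⟨max 0 I.lo, max 0 I.hi⟩

structure Layer where
  weights : List (List Int)
  bias    : List Int

def Layer.forwardInterval (l : Layer) (xs : List Ival) : List Ival :=
  List.zipWith
    (fun (row : List Int) (b : Int) =>
      ((List.zipWith Ival.scale row xs).foldl Ival.add (⟨b, b⟩ : Ival)).relu)
    l.weights l.bias

structure SequentialNet where
  layers : List Layer

def SequentialNet.forwardInterval (net : SequentialNet) (xs : List Ival) : List Ival :=
  List.foldl (fun (acc : List Ival) (l : Layer) => l.forwardInterval acc) xs net.layers

def isBoundedBy (v : List Ival) (eps : Int) : Bool :=
  v.all fun I => decide (-eps ≤ I.lo ∧ I.hi ≤ eps)

/-- Hypothetical certifier: equal output widths, and every component of
    Teacher(box) - Student(box) inside [-eps, eps]. -/
def checkEquivalence (teacher student : SequentialNet) (box : List Ival) (eps : Int) : Bool :=
  let t := teacher.forwardInterval box
  let s := student.forwardInterval box
  t.length == s.length && isBoundedBy (List.zipWith Ival.sub t s) eps

def teacherNet : SequentialNet :=
  { layers := [ { weights := [[1, -1], [2, 0]], bias := [0, -1] },
                { weights := [[1, 1]], bias := [3] } ] }
def studentNet : SequentialNet :=
  { layers := [ { weights := [[3, 0]], bias := [1] } ] }
def inputBox : List Ival := [⟨2, 2⟩, ⟨0, 1⟩]

-- Teacher box [7, 8], student box [7, 7], difference [0, 1].
#eval checkEquivalence teacherNet studentNet inputBox 1  -- true
#eval checkEquivalence teacherNet studentNet inputBox 0  -- false
```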
Correctness Proofs #
Helper: extract dimension equality from checkEquivalence
Golden Theorem: Verified Model Distillation
If checkEquivalence returns true, then for ALL real inputs x in the domain,
the output difference |Teacher(x) - Student(x)| is at most eps for every output neuron.