
LeanCert.ML.Distillation

Verified Model Distillation / Equivalence Checking #

This module implements the logic to verify that two neural networks (e.g., a massive Teacher and a compressed Student) produce outputs that are within a specific tolerance ε of each other for ALL inputs in a given domain.

Main Definitions #

The Math #

We define the difference graph: Diff(x) = Teacher(x) - Student(x)

We compute interval bounds for Diff(InputBox). If the resulting interval is contained in [-ε, ε], the networks are ε-equivalent.
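This criterion can be stated as a minimal Lean sketch (schematic only — `Teacher`, `Student`, `Box`, and `epsEquivalent` are placeholder names, not this module's definitions):

```lean
-- Schematic statement of ε-equivalence (placeholder names, not the module's code).
-- If interval evaluation shows Diff(Box) ⊆ [-ε, ε], this proposition holds.
def epsEquivalent (Teacher Student : ℝ → ℝ) (Box : Set ℝ) (ε : ℝ) : Prop :=
  ∀ x ∈ Box, |Teacher x - Student x| ≤ ε
```

The interval computation over-approximates the image of Diff, so containment of the computed interval in [-ε, ε] is sufficient (but not necessary) for ε-equivalence.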

Interval Subtraction #

Negate an interval: -I = [-hi, -lo]
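Concretely, negation and subtraction can be sketched on a toy interval with rational endpoints (the module itself works on `Core.IntervalDyadic`; `RatInterval` is illustrative only):

```lean
-- Toy interval with rational endpoints, for illustration only.
structure RatInterval where
  lo : ℚ
  hi : ℚ

-- Negation negates and swaps the endpoints: -[lo, hi] = [-hi, -lo].
def RatInterval.neg (I : RatInterval) : RatInterval := ⟨-I.hi, -I.lo⟩

-- Subtraction is I + (-J): [I.lo - J.hi, I.hi - J.lo].
def RatInterval.sub (I J : RatInterval) : RatInterval := ⟨I.lo - J.hi, I.hi - J.lo⟩
```

For example, [1, 2].sub [0, 1] = [0, 2]: any x ∈ [1, 2] and y ∈ [0, 1] satisfy x - y ∈ [0, 2], which is exactly what the membership lemmas below certify for the dyadic version.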


Membership in negated interval

theorem LeanCert.ML.Distillation.IntervalDyadic.mem_sub {x y : ℝ} {I J : Core.IntervalDyadic} (hx : x ∈ I) (hy : y ∈ J) :
x - y ∈ I.sub J

Membership in subtracted interval

Sequential Network Infrastructure #

A standard feedforward network is just a list of layers

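The actual `Layer` and `SequentialNet` definitions are not shown on this page; the following is a hedged sketch of a plausible shape (all field and type names are assumptions):

```lean
-- Assumed shape only: the real `Layer` and `SequentialNet` fields are not
-- shown on this page.
structure DenseLayer where
  weights : List (List ℚ)  -- one row of weights per output neuron
  bias    : List ℚ

structure Net where
  layers : List DenseLayer
```

The forward passes then fold the layers over the input vector, mirroring the module's `List.foldl (fun acc l => l.forwardReal acc) xs layers`.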

Real-valued forward pass

Interval forward pass

A layer chain is well-formed if dimensions align

Check if all layers in the sequence are well-formed and dimensions align
theorem LeanCert.ML.Distillation.SequentialNet.forwardLength_aux (layers : List Layer) (xs : List ℝ) (Is : IntervalVector) (hdim : xs.length = List.length Is) (hwf : LayersWellFormed layers xs.length) (prec : ℤ) :
(List.foldl (fun (acc : List ℝ) (l : Layer) => l.forwardReal acc) xs layers).length = List.length (List.foldl (fun (acc : IntervalVector) (l : Layer) => l.forwardInterval acc prec) Is layers)

Helper: Lengths of foldl outputs match

theorem LeanCert.ML.Distillation.SequentialNet.mem_forwardInterval_aux (layers : List Layer) (xs : List ℝ) (Is : IntervalVector) (hdim : xs.length = List.length Is) (hwf : LayersWellFormed layers xs.length) (hmem : ∀ (i : ℕ) (hi : i < List.length Is), xs[i] ∈ Is[i]) (prec : ℤ) (hprec : prec ≤ 0) :
let outReal := List.foldl (fun (acc : List ℝ) (l : Layer) => l.forwardReal acc) xs layers; let outInt := List.foldl (fun (acc : IntervalVector) (l : Layer) => l.forwardInterval acc prec) Is layers; ∀ (i : ℕ) (hi : i < List.length outInt), outReal[i] ∈ outInt[i]

Soundness of SequentialNet forward pass

theorem LeanCert.ML.Distillation.SequentialNet.mem_forwardInterval {net : SequentialNet} {xs : List ℝ} {Is : IntervalVector} (hdim : xs.length = List.length Is) (hwf : net.WellFormed xs.length) (hmem : ∀ (i : ℕ) (hi : i < List.length Is), xs[i] ∈ Is[i]) (prec : ℤ) (hprec : prec ≤ 0 := by norm_num) :
let outReal := net.forwardReal xs; let outInt := net.forwardInterval Is prec; outReal.length = List.length outInt ∧ ∀ (i : ℕ) (hi : i < List.length outInt), outReal[i] ∈ outInt[i]

Public soundness theorem

Interval Vector Subtraction #

theorem LeanCert.ML.Distillation.mem_subVectors {ra rb : List ℝ} {ia ib : IntervalVector} (halen : ra.length = List.length ia) (hblen : rb.length = List.length ib) (ha : ∀ (i : ℕ) (hi : i < List.length ia), ra[i] ∈ ia[i]) (hb : ∀ (i : ℕ) (hi : i < List.length ib), rb[i] ∈ ib[i]) (i : ℕ) (hi : i < List.length (subVectors ia ib)) :
ra[i] - rb[i] ∈ (subVectors ia ib)[i]

Membership in subtracted vectors

The Equivalence Checker #

Check if an interval is contained within [-eps, eps]

Check if an interval vector is contained within [-eps, eps]
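The containment test itself is simple; a self-contained sketch on a toy rational-endpoint interval (`DInterval` and `boundedBy` are illustrative names, not the module's):

```lean
-- Illustrative only: I ⊆ [-eps, eps] iff -eps ≤ I.lo and I.hi ≤ eps.
structure DInterval where
  lo : ℚ
  hi : ℚ

def boundedBy (I : DInterval) (eps : ℚ) : Bool :=
  decide (-eps ≤ I.lo) && decide (I.hi ≤ eps)
```

The vector-level check simply requires this of every component.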
theorem LeanCert.ML.Distillation.intervalBoundedBy_spec {x : ℝ} {I : Core.IntervalDyadic} {eps : ℚ} (hx : x ∈ I) (hcheck : intervalBoundedBy I eps = true) :
|x| ≤ ↑eps

Soundness of intervalBoundedBy

def LeanCert.ML.Distillation.checkEquivalence (teacher student : SequentialNet) (domain : IntervalVector) (eps : ℚ) (prec : ℤ := -53) :
Bool

The Distillation Certifier

Returns true if the student network is proven to be within eps of the teacher network for all inputs in the domain.

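The definition itself is not shown on this page, but the theorems below pin down its shape: bound both networks over the domain, subtract, and check containment. A hedged reconstruction (the name of the vector-level bound check is an assumption):

```lean
-- Pseudocode reconstruction — not the module's actual definition.
-- def checkEquivalence (teacher student : SequentialNet)
--     (domain : IntervalVector) (eps : ℚ) (prec : ℤ := -53) : Bool :=
--   let tOut := teacher.forwardInterval domain prec  -- sound bounds on Teacher
--   let sOut := student.forwardInterval domain prec  -- sound bounds on Student
--   -- dimensions must agree, and every component of the difference
--   -- must land inside [-eps, eps]
--   tOut.length == sOut.length && vectorBoundedBy (subVectors tOut sOut) eps
```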

Correctness Proofs #

theorem LeanCert.ML.Distillation.checkEquivalence_dims {teacher student : SequentialNet} {domain : IntervalVector} {eps : ℚ} {prec : ℤ} (h : checkEquivalence teacher student domain eps prec = true) :
List.length (teacher.forwardInterval domain prec) = List.length (student.forwardInterval domain prec)

Helper: extract dimension equality from checkEquivalence

theorem LeanCert.ML.Distillation.verify_equivalence (teacher student : SequentialNet) (domain : IntervalVector) (eps : ℚ) (prec : ℤ) (x : List ℝ) (hprec : prec ≤ 0) (hwf_t : teacher.WellFormed x.length) (hwf_s : student.WellFormed x.length) (h_dom_len : x.length = List.length domain) (h_mem : ∀ (i : ℕ) (hi : i < List.length domain), x[i] ∈ domain[i]) (h_cert : checkEquivalence teacher student domain eps prec = true) :
let t_out := teacher.forwardReal x; let s_out := student.forwardReal x; let t_int := teacher.forwardInterval domain prec; let s_int := student.forwardInterval domain prec; List.length t_int = List.length s_int ∧ ∀ (i : ℕ) (hi_t : i < List.length t_int) (hi_s : i < List.length s_int), |t_out[i] - s_out[i]| ≤ ↑eps

Golden Theorem: Verified Model Distillation

If checkEquivalence returns true, then for ALL real inputs x in the domain, the output difference |Teacher(x) - Student(x)| is at most eps for every output neuron.
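In use, the certificate is produced by evaluating the checker once; here `myTeacher`, `myStudent`, `myDomain`, and the tolerance 1/100 are placeholders:

```lean
-- Hypothetical usage: a single boolean evaluation, discharged by the
-- compiler, certifies the bound for every real input in the box.
theorem student_close :
    checkEquivalence myTeacher myStudent myDomain (1/100) = true := by
  native_decide
```

Feeding `student_close` to `verify_equivalence` then yields |Teacher(x) - Student(x)| ≤ 1/100 on every output neuron, for all x in the domain.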