
Symbolic Distillation of Neural Networks


I describe a general framework for distilling symbolic models from neural networks.

Miles Cranmer

March 01, 2023


Transcript

  1. Symbolic Distillation

    of Neural Networks
    Miles Cranmer
    Flatiron Institute


    University of Cambridge


    Princeton University



  2–7. Problem: empirical fits, and the theories that explain them
    • Kepler’s third law, P² ∝ a³ → explained by Newton’s law of gravitation
    • Planck’s law, B = (2hν³/c²) (exp(hν/(k_B T)) − 1)⁻¹ → explained (partially) by quantum mechanics
    • Neural network weights → explained by ???

  8. What I want:
    I want ML to create models in a language* I can understand**


    → Insights into existing models


    → Understand biases, learned shortcuts


    → Can place stronger priors on learned functions


  9–10. Industry version of interpretability
    • All revolves around saliency or feature importance
    • Consider the “saliency map” (Omeiza et al., 2019)

  11. (- Feynman’s blackboard)


  12–16. Science already has a modeling language
    Computer Vision vs. Science: ???

  17. We should build interpretations in this existing
    language: mathematical expressions!


  18. Symbolic regression
    • Symbolic regression finds analytic expressions to fit a dataset.
    • (~another name for “program synthesis”)
    • Pioneering work by Langley et al., 1980s; Koza et al., 1990s; Lipson et al., 2000s

  19–23. How can I try this?
    • Open-source
    • Extensible Python API compatible with scikit-learn
    • Can be distributed over 1000s of cores (w/ slurm, PBS, LSF, or Kubernetes)
    • Custom operators, losses, constraints
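The features listed above map onto PySR's scikit-learn-style API. A minimal sketch, assuming `pip install pysr` (the Julia backend installs itself on first use); the operator lists and iteration count are illustrative, and the actual fit call is left commented out because it launches a lengthy search:

```python
import numpy as np

def fit_symbolic(X, y):
    """Fit symbolic expressions to (X, y) with PySR (illustrative settings)."""
    from pysr import PySRRegressor  # requires `pip install pysr`
    model = PySRRegressor(
        niterations=40,
        binary_operators=["+", "-", "*", "/"],
        unary_operators=["cos", "exp"],
    )
    model.fit(X, y)   # scikit-learn-style API
    return model      # model.equations_ holds the discovered Pareto front

# Toy dataset with a known ground-truth expression:
rng = np.random.default_rng(0)
X = rng.normal(size=(200, 5))
y = 2.5382 * np.cos(X[:, 3]) + X[:, 0] ** 2 - 0.5
# model = fit_symbolic(X, y)   # uncomment to run the actual search
```

After fitting, `model.equations_` is a table of candidate expressions at increasing complexity, from which one is picked by an accuracy/complexity tradeoff.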

  24–30. Great! But, there’s a problem:
    • Genetic algorithms, like PySR, scale terribly with expression complexity.
    • One must search over: (permutations of operators) × (permutations of variables + possible constants)
    • But, we know that neural networks can efficiently find very complex functions!
    • Can we exploit this?
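The blow-up above can be made concrete by counting expression trees. The recursion below is my own illustration (not PySR's internal accounting): with V leaf choices, U unary and B binary operators, the number of distinct trees with k operator nodes grows explosively:

```python
from functools import lru_cache

# V leaf choices (variables/constants), U unary ops, B binary ops.
V, U, B = 3, 2, 4

@lru_cache(maxsize=None)
def n_exprs(k):
    """Number of distinct expression trees with exactly k operator nodes."""
    if k == 0:
        return V                        # a bare leaf
    total = U * n_exprs(k - 1)          # a unary op on a smaller tree
    for left in range(k):               # a binary op splitting the budget
        total += B * n_exprs(left) * n_exprs(k - 1 - left)
    return total

counts = [n_exprs(k) for k in range(8)]
print(counts[:4])  # [3, 42, 1092, 35448]
```

Already at five operator nodes the count exceeds fifty million, which is why searching for one large expression directly is hopeless.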

  31–32. Symbolic Distillation
    Neural network → approximation in my domain-specific language
    Miles Cranmer, Rui Xu, Peter Battaglia and Shirley Ho, ML4Physics Workshop @ NeurIPS 2019
    Miles Cranmer, Alvaro Sanchez-Gonzalez, Peter Battaglia, Rui Xu, Kyle Cranmer, David Spergel and Shirley Ho, NeurIPS 2020

  33–35. How this works:
    1. Train NN normally, and freeze parameters.
    2. Record input/outputs of network over training set.
    3. Fit the input/outputs of the neural network with PySR.
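The three steps above can be sketched end-to-end. Here a tiny closed-form function stands in for the trained, frozen network, and a brute-force scan over a handful of hand-picked candidates stands in for PySR's search; all names and formulas are illustrative:

```python
import numpy as np

# Step 1 stand-in: a trained, frozen network (here a black-box callable).
def frozen_net(X):
    return np.sin(X[:, 0]) + X[:, 1] ** 2

# Step 2: record input/output pairs of the network over the training set.
rng = np.random.default_rng(0)
X_train = rng.uniform(-2, 2, size=(500, 2))
y_recorded = frozen_net(X_train)

# Step 3 stand-in: fit the recorded pairs symbolically. PySR would search a
# huge expression space; this toy scores a few hand-written candidates.
candidates = {
    "x0 + x1":        lambda X: X[:, 0] + X[:, 1],
    "x0 * x1":        lambda X: X[:, 0] * X[:, 1],
    "sin(x0) + x1^2": lambda X: np.sin(X[:, 0]) + X[:, 1] ** 2,
    "exp(x1)":        lambda X: np.exp(X[:, 1]),
}
mse = {s: np.mean((fn(X_train) - y_recorded) ** 2) for s, fn in candidates.items()}
best = min(mse, key=mse.get)
print(best)  # sin(x0) + x1^2
```

The point is that the symbolic fit only ever sees recorded (input, output) pairs; the network itself is treated as a frozen black box.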

  36–37. Analogy: “Taylor expanding the Neural Network”

  38–41. Full Symbolic Distillation
    Learns features? Uses features for calculation?

  42. Full Symbolic Distillation
    Re-train g, to pick up any errors in the approximation of f 🔄

  43–46. Full Symbolic Distillation

  47–51. Full Symbolic Distillation
    (g ∘ f)(x1, x2, x3, x4) =
    Fully-interpretable approximation of the original neural network!
    • Easier to interpret and compare with existing models in the domain-specific language
    • Easier to impose symbolic priors (can potentially get better generalization!)

  52–53. Searching over expressions: n² → 2n
    Instead of having to find this complex expression, I have reduced it to finding multiple, simple expressions.

  54. What about the functional degeneracy?
    Any over-complicated functional form that f learns, g could invert!

  55–56. Inductive bias  (Xᵢ → y)
    • Introducing some form of inductive bias is needed to eliminate the functional degeneracy. For example:
    • the latent space between f and g could have some aggregation over a set: Σᵢ
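The "aggregation over a set" structure is y = g(Σᵢ f(xᵢ)). A minimal numpy sketch, with hypothetical closed forms standing in for the two network halves; the pooled sum forces the latent space to carry per-element messages that add up meaningfully, and makes the output permutation-invariant:

```python
import numpy as np

def f(x):
    """Per-element encoder (hypothetical closed form)."""
    return np.stack([x[:, 0] * x[:, 1], x[:, 0] ** 2], axis=1)

def g(z):
    """Decoder acting on the pooled latent (hypothetical closed form)."""
    return z[0] + 2.0 * z[1]

X = np.array([[1.0, 2.0], [3.0, 0.5], [-1.0, 4.0]])  # a set of 3 elements
z = f(X).sum(axis=0)    # permutation-invariant aggregation over the set
y = g(z)
print(y)  # 21.5
```

Reordering the rows of X leaves y unchanged, which is exactly the constraint that removes much of the freedom f and g would otherwise have to hide arbitrary invertible transforms between them.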

  57–62. Inductive bias
    • Other inductive biases to eliminate the degeneracy:
    • Sparsity on latent space (fewer equations, fewer variables)
      (Also see related work of Sebastian Wetzel & Roger Melko; and Steve Brunton & Nathan Kutz!)
    • Smoothness penalty (try to encourage expression-like behavior)
    • “Disentangled sparsity”

  63–67. Disentangled Sparsity
    • Want few latent features AND want each latent feature to have few dependencies
    • This makes things much easier for the genetic algorithm!
    Miles Cranmer, Can Cui, et al. “Disentangled Sparsity Networks for Explainable AI”, Workshop on Sparse Neural Networks, 2021, pp. 7. https://astroautomata.com/data/sjnn_paper.pdf
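The two requirements above can be written as a penalty. This is my own toy formulation, not the paper's exact loss: for a linear encoder z = W x, "few active latents" is an L1 on the activations and "few dependencies per latent" is an L1 on the rows of W:

```python
import numpy as np

def disentangled_penalty(W, Z, lam_active=0.1, lam_deps=0.1):
    """Toy disentangled-sparsity penalty (illustrative formulation)."""
    active = np.abs(Z).mean(axis=0).sum()   # (a) few active latent features
    deps = np.abs(W).sum()                  # (b) few input deps per latent
    return lam_active * active + lam_deps * deps

X = np.ones((4, 3))
W_sparse = np.array([[1.0, 0.0, 0.0],      # latent 0 uses a single input
                     [0.0, 0.0, 0.0]])     # latent 1 is switched off
W_dense = np.ones((2, 3))                  # every latent uses every input
p_sparse = disentangled_penalty(W_sparse, X @ W_sparse.T)
p_dense = disentangled_penalty(W_dense, X @ W_dense.T)
print(p_sparse, p_dense)  # 0.2 1.2
```

A latent space that scores low on this penalty hands the genetic algorithm a few short, few-variable regression problems instead of one big entangled one.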

  68. Example: Graph neural network activations = forces, under a sparsity regularization
    Miles Cranmer, Alvaro Sanchez-Gonzalez, Peter Battaglia, Rui Xu, Kyle Cranmer, David Spergel and Shirley Ho, NeurIPS 2020

  69–71. Example: Discovering Orbital Mechanics
    Can we learn Newton’s law of gravity simply by observing the solar system?
    Unknown masses, and unknown dynamical model.
    “Rediscovering orbital mechanics with machine learning” (2022), Pablo Lemos, Niall Jeffrey, Miles Cranmer, Shirley Ho, Peter Battaglia

  72–73. Simplification:
    • At some time t:
    • Known position xᵢ ∈ ℝ³ for each planet
    • Known acceleration ẍᵢ ∈ ℝ³ for each planet
    • Unknown parameter vᵢ ∈ ℝ for each planet
    • Unknown force f(xᵢ − xⱼ, vᵢ, vⱼ)

  74–80. Simplification:
    • Optimize:  ẍᵢ ≈ (1/vᵢ) Σⱼ≠ᵢ f(xᵢ − xⱼ, vᵢ, vⱼ)
      (known acceleration on the left; Newton’s laws of motion assumed; f is the learned force law; vᵢ, vⱼ are learned parameters for planets i, j)
    • Learn via gradient descent. This allows us to find both f and vᵢ simultaneously.
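The joint optimization above can be sketched on a toy problem, with assumed, illustrative choices: the force family is fixed to f = θ·vᵢ·vⱼ·(xⱼ − xᵢ) (a linear stand-in for the neural network), so the unknowns are θ and the per-body scalars v, fit jointly by gradient descent as on the slide:

```python
import numpy as np

x = np.array([[0.0, 0.0], [1.0, 0.0], [0.0, 1.0]])   # 3 bodies in 2D
v_true = np.array([1.0, 1.5, 2.0])

def predict(params):
    theta, v = params[0], params[1:]
    a = np.zeros_like(x)
    for i in range(len(x)):
        for j in range(len(x)):
            if j != i:            # a_i = (1/v_i) sum_j f; the v_i cancels here
                a[i] += theta * v[j] * (x[j] - x[i])
    return a

a_obs = predict(np.concatenate([[1.0], v_true]))     # "observed" accelerations

def loss(p):
    return np.sum((predict(p) - a_obs) ** 2)

# Plain gradient descent with finite-difference gradients (autodiff stand-in).
params = np.concatenate([[1.0], np.ones(3)])
for _ in range(2000):
    grad = np.array([(loss(params + 1e-5 * e) - loss(params - 1e-5 * e)) / 2e-5
                     for e in np.eye(len(params))])
    params -= 0.02 * grad
```

Note that θ and the vᵢ are only identified up to a shared rescaling; this is exactly the kind of degeneracy the inductive biases discussed earlier are meant to tame.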

  81. Training:
    • NASA’s HORIZONS ephemeris data


    • 31 bodies:


    • Sun


    • Planets


    • Moons with mass > 1e18 kg


    • (Therefore: 465 connections)


    • 30 years, 1980-2010 for training


    • 2010-2013 for validation


  82. Next: interpretation
    ẍᵢ ≈ (1/vᵢ) Σⱼ≠ᵢ f(xᵢ − xⱼ, vᵢ, vⱼ)
    Approximate input/output of f with symbolic regression.

  83–90. Interpretation Results for f
    Accuracy/Complexity Tradeoff* (plotted against complexity)
    score = − d(log(error)) / d(complexity)
    *from Cranmer+2020; similar to Schmidt & Lipson, 2009
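The selection rule above rewards the biggest drop in log(error) per unit of added complexity along the Pareto front. A small sketch with an illustrative front of (complexity, error) pairs:

```python
import numpy as np

# Toy Pareto front of (complexity, error) pairs (illustrative numbers).
front = [(1, 1.0), (3, 0.5), (5, 0.01), (7, 0.009)]

best_score, best_expr = -np.inf, None
for (c0, e0), (c1, e1) in zip(front, front[1:]):
    score = -(np.log(e1) - np.log(e0)) / (c1 - c0)  # -d(log error)/d(complexity)
    if score > best_score:
        best_score, best_expr = score, (c1, e1)
print(best_expr)  # (5, 0.01): the large accuracy jump at modest complexity
```

The winner is the expression just after the sharp knee in the curve, not the most accurate (and most complex) one.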

  91. Test the symbolic model:


  92–93. Why isn’t this working well?
    • Let’s look at the mass values in comparison with the true masses:

  94–97. Solution: re-optimize vᵢ!
    • The vᵢ were optimized for the neural network.
    • The symbolic formula is not a *perfect* approximation of the network.
    • Thus: we need to re-optimize vᵢ for the symbolic function f!
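The re-optimization step amounts to refitting the constants of the distilled formula against the data. A one-parameter toy sketch (the formula y = v·x and all the values below are illustrative, not the talk's actual model):

```python
import numpy as np

x = np.linspace(0.1, 2.0, 50)
y_obs = 2.0 * x          # "observed" data; true constant is 2.0
v = 1.9                  # constant inherited from the neural-network stage

# Re-optimize v for the symbolic formula itself, via gradient descent
# on the squared error.
for _ in range(500):
    grad = 2.0 * np.sum((v * x - y_obs) * x)
    v -= 1e-3 * grad
print(round(v, 3))  # 2.0
```

The inherited value is a good starting point but slightly off, because it was tuned to compensate for the network's quirks rather than the symbolic formula's.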

  98. V. Ongoing Work: Turbulence
    Work includes: Dmitrii Kochkov, Keaton Burns, Drummond Fielding,
    and others


  99–102. Example: Learned Coarse Models for Efficient Turbulence Simulation (ICLR 2022)
    Kimberly Stachenfeld, Drummond B. Fielding, Dmitrii Kochkov, Miles Cranmer, Tobias Pfaff, Jonathan Godwin, Can Cui, Shirley Ho, Peter Battaglia, Alvaro Sanchez-Gonzalez
    Trained to reproduce turbulence simulations at lower resolution: 1000x speedup.
    How did the model actually achieve this?

  103. (Preliminary results)
    τxx, τxy, τyx, τyy (non-symmetric, since offset!)

  104. Summary
    • Symbolic distillation is a technique for translating ML models into a domain-specific language
    • Can do this for expressions/programs using PySR
    • Exciting future applications in understanding turbulence, and other physical systems