
Symbolic Distillation of Neural Networks

I describe a general framework for distilling symbolic models from neural networks.

Miles Cranmer

March 01, 2023


Transcript

  1. Symbolic Distillation

    of Neural Networks
    Miles Cranmer
    Flatiron Institute


    University of Cambridge


    Princeton University



  2. Empirical fit: Kepler’s third law (P² ∝ a³)
    Problem:

  3. Empirical fit: Kepler’s third law (P² ∝ a³)
    Problem: Newton’s law of gravitation, to explain it

  4. Empirical fit: Kepler’s third law (P² ∝ a³); Planck’s law (B = (2hν³/c²)(exp(hν/(kB T)) − 1)⁻¹)
    Problem: Newton’s law of gravitation, to explain it

  5. Empirical fit: Kepler’s third law (P² ∝ a³); Planck’s law (B = (2hν³/c²)(exp(hν/(kB T)) − 1)⁻¹)
    Problem: Newton’s law of gravitation, to explain it; quantum mechanics, to (partially) explain it

  6. Empirical fit: Kepler’s third law (P² ∝ a³); Planck’s law (B = (2hν³/c²)(exp(hν/(kB T)) − 1)⁻¹); neural network weights
    Problem: Newton’s law of gravitation, to explain it; quantum mechanics, to (partially) explain it

  7. Empirical fit: Kepler’s third law (P² ∝ a³); Planck’s law (B = (2hν³/c²)(exp(hν/(kB T)) − 1)⁻¹); neural network weights
    Problem: Newton’s law of gravitation, to explain it; quantum mechanics, to (partially) explain it; ???

  8. What I want:
    I want ML to create models in a language* I can understand**


    → Insights into existing models


    → Understand biases, learned shortcuts


    → Can place stronger priors on learned functions


  9. Industry version of interpretability
    • All revolves around saliency or feature
    importance


    • Consider the “saliency map”
    Omeiza et al., 2019



  11. (- Feynman’s blackboard)


  12. Science already has a modeling language


  13. Science already has a modeling language
    Computer Vision


  14. Science already has a modeling language
    Computer Vision
    ???


  15. Science already has a modeling language
    Computer Vision Science
    ???



  17. (image-only slide)

  18. We should build interpretations in this existing
    language: mathematical expressions!


  19. Symbolic regression
    • Symbolic regression finds analytic expressions to fit a dataset.
    • (~another name for “program synthesis”)
    • Pioneering work by Langley et al., 1980s; Koza et al., 1990s; Lipson et al., 2000s
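As a toy illustration of “finding analytic expressions to fit a dataset” (a deliberately tiny, hand-picked hypothesis space; real tools like PySR search vastly larger spaces with genetic algorithms):

```python
import math

# Toy dataset generated by an unknown law y = x**1.5 (Kepler-like: P ∝ a^(3/2)).
xs = [1.0, 2.0, 3.0, 4.0, 5.0]
ys = [x ** 1.5 for x in xs]

# Tiny hypothesis space: candidate analytic expressions (name, function).
candidates = [
    ("x", lambda x: x),
    ("x^2", lambda x: x ** 2),
    ("x^(3/2)", lambda x: x ** 1.5),
    ("sqrt(x)", lambda x: math.sqrt(x)),
    ("log(x+1)", lambda x: math.log(x + 1)),
]

def mse(f):
    # Mean squared error of a candidate expression on the dataset.
    return sum((f(x) - y) ** 2 for x, y in zip(xs, ys)) / len(xs)

# Symbolic regression, brute-force style: pick the expression minimizing error.
best_name, best_f = min(candidates, key=lambda c: mse(c[1]))
print(best_name)  # → x^(3/2)
```

The real search problem is enumerating the candidates themselves, which is what genetic algorithms are for.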

  20. How can I try this?


  21. How can I try this?
    • Open-source


  22. How can I try this?
    • Open-source
    • Extensible Python API compatible with scikit-learn


  23. How can I try this?
    • Open-source
    • Extensible Python API compatible with scikit-learn
    • Can be distributed over 1000s of cores

    (w/ slurm, PBS, LSF, or Kubernetes)


  24. How can I try this?
    • Open-source
    • Extensible Python API compatible with scikit-learn
    • Can be distributed over 1000s of cores

    (w/ slurm, PBS, LSF, or Kubernetes)
    • Custom operators, losses, constraints


  25. (image-only slide)

  26. (image-only slide)

  27. Great!


    But, there’s a problem:


  28. Great!


    But, there’s a problem:
    • Genetic algorithms, like PySR, scale terribly with expression complexity.


  29. Great!


    But, there’s a problem:
    • Genetic algorithms, like PySR, scale terribly with expression complexity.
    • One must search over:


  30. Great!


    But, there’s a problem:
    • Genetic algorithms, like PySR, scale terribly with expression complexity.
    • One must search over:
    • (permutations of operators) x (permutations of variables + possible constants)



  32. Great!

    But, there’s a problem:
    • Genetic algorithms, like PySR, scale terribly with expression complexity.
    • One must search over:
    • (permutations of operators) x (permutations of variables + possible constants)
    • But, we know that neural networks can efficiently find very complex functions!

  33. Great!

    But, there’s a problem:
    • Genetic algorithms, like PySR, scale terribly with expression complexity.
    • One must search over:
    • (permutations of operators) x (permutations of variables + possible constants)
    • But, we know that neural networks can efficiently find very complex functions!
    • Can we exploit this?

  34. Symbolic Distillation
    Neural network


  35. Symbolic Distillation
    Neural network
    Approximation in my domain-specific language
    Miles Cranmer, Rui Xu, Peter Battaglia and Shirley Ho,
    ML4Physics Workshop @ NeurIPS 2019
    Miles Cranmer, Alvaro Sanchez-Gonzalez, Peter Battaglia, Rui Xu, Kyle Cranmer, David Spergel and Shirley Ho,
    NeurIPS, 2020

  36. How this works:
    1. Train NN normally, 

    and freeze parameters.


  37. How this works:
    1. Train NN normally, 

    and freeze parameters.
    2. Record input/outputs of

    network over training set.


  38. How this works:
    1. Train NN normally, 

    and freeze parameters.
    2. Record input/outputs of

    network over training set.
    PySR
    3. Fit the input/outputs of the
    neural network with PySR

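The three steps above can be sketched end-to-end. Everything here is an illustrative stand-in: a fixed callable plays the trained-and-frozen network, and a tiny candidate search plays the role of PySR:

```python
import math

# 1. A "trained and frozen" network: here just a black-box callable.
def frozen_net(x):
    return 3.0 * math.sin(x) + 0.01 * x  # pretend these weights came from training

# 2. Record input/output pairs of the network over the training inputs.
inputs = [i * 0.1 for i in range(100)]
records = [(x, frozen_net(x)) for x in inputs]

# 3. Fit the recorded pairs with a symbolic search (stand-in for PySR).
candidates = {
    "3*sin(x)": lambda x: 3.0 * math.sin(x),
    "x": lambda x: x,
    "3*cos(x)": lambda x: 3.0 * math.cos(x),
}

def mse(f):
    return sum((f(x) - y) ** 2 for x, y in records) / len(records)

best = min(candidates, key=lambda name: mse(candidates[name]))
print(best)  # the symbolic surrogate closest to the network's behavior
```

The key point is that the symbolic search only ever sees (input, output) pairs of the frozen network, never its weights.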

  39. Analogy
    “Taylor expanding the Neural Network”



  41. Full Symbolic Distillation


  42. Full Symbolic Distillation
    Learns features?
    Uses features for calculation?


  44. Full Symbolic Distillation


  45. Full Symbolic Distillation
    Re-train g, to pick up any errors in the approximation of f 🔄

  46. Full Symbolic Distillation

  50. Full Symbolic Distillation
    (g ∘ f)(x₁, x₂, x₃, x₄) =

  51. Full Symbolic Distillation
    (g ∘ f)(x₁, x₂, x₃, x₄) =
    Fully-interpretable approximation of the original neural network!


  53. Full Symbolic Distillation
    (g ∘ f)(x₁, x₂, x₃, x₄) =
    Fully-interpretable approximation of the original neural network!
    • Easier to interpret and compare with existing models in the domain-specific language

  54. Full Symbolic Distillation
    (g ∘ f)(x₁, x₂, x₃, x₄) =
    Fully-interpretable approximation of the original neural network!
    • Easier to interpret and compare with existing models in the domain-specific language
    • Easier to impose symbolic priors (can potentially get better generalization!)

  55. vs
    Instead of having to find this complex expression, I have reduced it to finding multiple, simple expressions.

  56. Searching over expressions: n² → 2n
    Instead of having to find this complex expression, I have reduced it to finding multiple, simple expressions.
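One way to see the gain (an illustrative back-of-envelope count, not the deck's exact scaling): count candidate binary expression trees of a given size, and compare one big search against two half-sized ones.

```python
from math import comb

def num_expressions(n_internal, n_operators, n_leaf_choices):
    """Rough count of binary expression trees with n_internal operator
    nodes: Catalan(n) tree shapes, each internal node labeled with one of
    n_operators, each of the n+1 leaves with a variable or constant."""
    catalan = comb(2 * n_internal, n_internal) // (n_internal + 1)
    return catalan * n_operators ** n_internal * n_leaf_choices ** (n_internal + 1)

ops, leaves = 4, 5
big = num_expressions(10, ops, leaves)       # one complex expression
split = 2 * num_expressions(5, ops, leaves)  # two simpler expressions instead
print(f"search space shrinks by ~{big / split:,.0f}x")
```

Because the candidate count grows exponentially with expression size, distilling a composition g ∘ f into two smaller searches wins by many orders of magnitude.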

  57. What about the functional degeneracy?
    Any over-complicated functional form that f learns, g could invert!

  58. Inductive bias
    • Introducing some form of inductive bias is needed to eliminate the functional degeneracy. For example:

  59. Inductive bias
    • Introducing some form of inductive bias is needed to eliminate the functional degeneracy. For example:
    • the latent space between f and g could have some aggregation over a set (e.g. a sum over i)!

  60. Inductive bias


  61. Inductive bias
    • Other inductive biases to eliminate the degeneracy:


  62. Inductive bias
    • Other inductive biases to eliminate the degeneracy:
    • Sparsity on latent space (fewer equations, fewer variables)

  63. Inductive bias
    • Other inductive biases to eliminate the degeneracy:
    • Sparsity on latent space (fewer equations, fewer variables)
    • (Also see related work of Sebastian Wetzel & Roger Melko; and Steve Brunton & Nathan Kutz!)

  64. Inductive bias
    • Other inductive biases to eliminate the degeneracy:
    • Sparsity on latent space (fewer equations, fewer variables)
    • (Also see related work of Sebastian Wetzel & Roger Melko; and Steve Brunton & Nathan Kutz!)
    • Smoothness penalty (try to encourage expression-like behavior)

  65. Inductive bias
    • Other inductive biases to eliminate the degeneracy:
    • Sparsity on latent space (fewer equations, fewer variables)
    • (Also see related work of Sebastian Wetzel & Roger Melko; and Steve Brunton & Nathan Kutz!)
    • Smoothness penalty (try to encourage expression-like behavior)
    • “Disentangled sparsity”

  66. Miles Cranmer, Can Cui, et al. “Disentangled Sparsity Networks for Explainable AI”
    Workshop on Sparse Neural Networks, 2021, p. 7
    https://astroautomata.com/data/sjnn_paper.pdf

  67. • Disentangled Sparsity:
    Miles Cranmer, Can Cui, et al. “Disentangled Sparsity Networks for Explainable AI”
    Workshop on Sparse Neural Networks, 2021, p. 7
    https://astroautomata.com/data/sjnn_paper.pdf

  68. • Disentangled Sparsity:
    • Want few latent features AND want each latent feature to have few dependencies
    Miles Cranmer, Can Cui, et al. “Disentangled Sparsity Networks for Explainable AI”
    Workshop on Sparse Neural Networks, 2021, p. 7
    https://astroautomata.com/data/sjnn_paper.pdf

  69. • Disentangled Sparsity:
    • Want few latent features AND want each latent feature to have few dependencies
    • This makes things much easier for the genetic algorithm!
    Miles Cranmer, Can Cui, et al. “Disentangled Sparsity Networks for Explainable AI”
    Workshop on Sparse Neural Networks, 2021, p. 7
    https://astroautomata.com/data/sjnn_paper.pdf

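A sketch of how such a penalty could be scored (my own minimal formulation for illustration, not the paper's exact loss): penalize both the number of active latent features and the number of inputs each feature depends on, estimated from a Jacobian-like sensitivity matrix.

```python
# sensitivity[k][j] ≈ |∂(latent feature k)/∂(input j)|, e.g. from autodiff.
sensitivity = [
    [2.1, 0.0, 0.0, 0.0],  # feature 0 depends on one input
    [0.0, 1.3, 0.4, 0.0],  # feature 1 depends on two inputs
    [0.0, 0.0, 0.0, 0.0],  # feature 2 is inactive
]

def disentangled_sparsity_penalty(S, lam_feat=1.0, lam_dep=0.1, eps=1e-6):
    # Few latent features AND few input dependencies per feature.
    active_features = sum(1 for row in S if sum(row) > eps)
    dependencies = sum(1 for row in S for s in row if s > eps)
    return lam_feat * active_features + lam_dep * dependencies

penalty = disentangled_sparsity_penalty(sensitivity)
print(penalty)
```

In practice one would use a differentiable relaxation (e.g. L1 norms) so the penalty can be minimized by gradient descent alongside the task loss.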

  71. Example: Graph neural network activations = forces, under a sparsity regularization
    Miles Cranmer, Alvaro Sanchez-Gonzalez, Peter Battaglia, Rui Xu, Kyle Cranmer, David Spergel and Shirley Ho,
    NeurIPS, 2020

  72. Example: Discovering Orbital Mechanics

  73. Example: Discovering Orbital Mechanics
    Can we learn Newton’s law of gravity simply by observing the solar system?
    Unknown masses, and unknown dynamical model.

  74. Example: Discovering Orbital Mechanics
    Can we learn Newton’s law of gravity simply by observing the solar system?
    Unknown masses, and unknown dynamical model.
    “Rediscovering orbital mechanics with machine learning” (2022)
    Pablo Lemos, Niall Jeffrey, Miles Cranmer, Shirley Ho, Peter Battaglia

  75. Simplification:

  76. Simplification:
    • At some time t:
    • Known position for each planet: xᵢ ∈ ℝ³
    • Known acceleration for each planet: ẍᵢ ∈ ℝ³
    • Unknown parameter for each planet: vᵢ ∈ ℝ
    • Unknown force: f(xᵢ − xⱼ, vᵢ, vⱼ)

  77. Simplification:

  78. Simplification:
    • Optimize:
      ẍᵢ ≈ (1/vᵢ) ∑_{j≠i} f(xᵢ − xⱼ, vᵢ, vⱼ)

  79. Simplification:
    • Optimize:
      ẍᵢ ≈ (1/vᵢ) ∑_{j≠i} f(xᵢ − xⱼ, vᵢ, vⱼ)
    (ẍᵢ: known acceleration)

  80. Simplification:
    • Optimize:
      ẍᵢ ≈ (1/vᵢ) ∑_{j≠i} f(xᵢ − xⱼ, vᵢ, vⱼ)
    (ẍᵢ: known acceleration; ≈: Newton’s laws of motion, assumed)

  81. Simplification:
    • Optimize:
      ẍᵢ ≈ (1/vᵢ) ∑_{j≠i} f(xᵢ − xⱼ, vᵢ, vⱼ)
    (ẍᵢ: known acceleration; ≈: Newton’s laws of motion, assumed; f: learned force law)

  82. Simplification:
    • Optimize:
      ẍᵢ ≈ (1/vᵢ) ∑_{j≠i} f(xᵢ − xⱼ, vᵢ, vⱼ)
    (ẍᵢ: known acceleration; ≈: Newton’s laws of motion, assumed; f: learned force law; vᵢ, vⱼ: learned parameters for planets i, j)

  83. Simplification:
    • Optimize:
      ẍᵢ ≈ (1/vᵢ) ∑_{j≠i} f(xᵢ − xⱼ, vᵢ, vⱼ)
    (ẍᵢ: known acceleration; ≈: Newton’s laws of motion, assumed; f: learned force law; vᵢ, vⱼ: learned parameters for planets i, j)
    Learn via gradient descent.
    This allows us to find both f and vᵢ simultaneously.
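The joint optimization above can be sketched in a 1-D toy version. Everything here is illustrative: a hand-written parametric force law stands in for the neural network f, and finite differences stand in for autodiff.

```python
# Toy 1-D version: fit a force-law parameter g and per-body parameters v_i
# jointly by gradient descent on the acceleration residuals.
positions = [0.0, 1.0, 2.5]
true_v = [1.0, 0.5, 2.0]  # stand-ins for the masses (unknown to the fit)

def accel(i, xs, v, g):
    # ẍ_i ≈ (1/v_i) Σ_{j≠i} f(x_i − x_j, v_i, v_j), with f = g v_i v_j Δx / |Δx|³
    total = 0.0
    for j in range(len(xs)):
        if j == i:
            continue
        dx = xs[j] - xs[i]
        total += g * v[i] * v[j] * dx / abs(dx) ** 3
    return total / v[i]

# "Observed" accelerations, generated from the true parameters (g = 1).
targets = [accel(i, positions, true_v, 1.0) for i in range(3)]

def loss(params):
    g, v = params[0], params[1:]
    return sum((accel(i, positions, v, g) - targets[i]) ** 2 for i in range(3))

params = [0.5, 1.0, 1.0, 1.0]  # initial guess for (g, v_1, v_2, v_3)
initial = loss(params)
for _ in range(500):  # plain gradient descent with finite-difference gradients
    grad = []
    for k in range(len(params)):
        bumped = params.copy()
        bumped[k] += 1e-6
        grad.append((loss(bumped) - loss(params)) / 1e-6)
    params = [p - 0.05 * gr for p, gr in zip(params, grad)]
print(initial, "→", loss(params))
```

Note the g ↔ vᵢ rescaling degeneracy: many (g, v) combinations fit equally well, which is exactly why the mass-like parameters need care later.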

  84. Training:
    • NASA’s HORIZONS ephemeris data
    • 31 bodies:
    • Sun
    • Planets
    • Moons with mass > 10¹⁸ kg
    • (Therefore: 465 connections)
    • 30 years, 1980-2010 for training
    • 2010-2013 for validation
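A quick sanity check of the connection count: 31 bodies interacting pairwise give C(31, 2) edges.

```python
from math import comb

# Number of unordered pairs among 31 bodies.
print(comb(31, 2))  # → 465
```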

  85. (image-only slide)

  86. (image-only slide)

  87. Next: interpretation
    ẍᵢ ≈ (1/vᵢ) ∑_{j≠i} f(xᵢ − xⱼ, vᵢ, vⱼ)
    Approximate input/output of f with symbolic regression.

  88. Interpretation Results for f
    Accuracy/Complexity Tradeoff* (x-axis: complexity)
    *from Cranmer+2020; similar to Schmidt & Lipson, 2009


  95. Interpretation Results for f
    Accuracy/Complexity Tradeoff* = −d(log(error))/d(complexity)
    *from Cranmer+2020; similar to Schmidt & Lipson, 2009
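The tradeoff criterion above can be applied directly to a Pareto front of (complexity, error) pairs; the numbers below are made up for illustration.

```python
import math

# Hypothetical Pareto front from a symbolic regression run: (complexity, error).
front = [(1, 10.0), (3, 5.0), (5, 0.05), (7, 0.04), (9, 0.039)]

# Score each step by −d(log(error))/d(complexity); pick the largest drop.
scores = []
for (c0, e0), (c1, e1) in zip(front, front[1:]):
    scores.append((-(math.log(e1) - math.log(e0)) / (c1 - c0), c1))

best_score, best_complexity = max(scores)
print(best_complexity)  # → 5: the big accuracy jump at modest complexity
```

This selects the expression where accuracy improves fastest per unit of added complexity, rather than simply the most accurate (and most complex) one.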

  96. Test the symbolic model:


  97. (image-only slide)

  98. (image-only slide)

  99. Why isn’t this working well?
    • Let’s look at the mass values in comparison with the true masses:



  101. Solution: re-optimize vᵢ!

  102. Solution: re-optimize vᵢ!
    • The vᵢ were optimized for the neural network.

  103. Solution: re-optimize vᵢ!
    • The vᵢ were optimized for the neural network.
    • The symbolic formula is not a *perfect* approximation of the network.

  104. Solution: re-optimize vᵢ!
    • The vᵢ were optimized for the neural network.
    • The symbolic formula is not a *perfect* approximation of the network.
    • Thus: we need to re-optimize vᵢ for the symbolic function f!
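A minimal sketch of this re-optimization step, under toy assumptions (made-up data and an illustrative frozen formula): hold the distilled symbolic f fixed and re-fit only the scalar parameter by gradient descent.

```python
# (x_i, observed ẍ_i) pairs; values made up for illustration.
data = [(1.0, 2.0), (2.0, 4.1), (3.0, 5.9)]

def symbolic_f(x, v):
    return 2.0 * v * x  # the frozen, distilled formula (illustrative)

def loss(v):
    return sum((symbolic_f(x, v) - a) ** 2 for x, a in data)

v = 0.5  # value inherited from the neural-network stage
for _ in range(200):  # gradient descent with a finite-difference gradient
    grad = (loss(v + 1e-6) - loss(v)) / 1e-6
    v -= 0.01 * grad
print(round(v, 2))  # v re-optimized for the symbolic f, not the network
```

The same idea applies to the planetary vᵢ: the formula's small approximation errors shift the best-fit parameter values away from those found for the network.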

  105. (image-only slide)

  106. (image-only slide)

  107. (image-only slide)

  108. V. Ongoing Work: Turbulence
    Work includes: Dmitrii Kochkov, Keaton Burns, Drummond Fielding,
    and others


  109. Example:
    Learned Coarse Models for Efficient Turbulence Simulation (ICLR 2022)
    Kimberly Stachenfeld, Drummond B. Fielding, Dmitrii Kochkov, Miles Cranmer, Tobias Pfaff, Jonathan Godwin, Can Cui, Shirley Ho, Peter Battaglia, Alvaro Sanchez-Gonzalez

  110. Example:
    Learned Coarse Models for Efficient Turbulence Simulation (ICLR 2022)
    Kimberly Stachenfeld, Drummond B. Fielding, Dmitrii Kochkov, Miles Cranmer, Tobias Pfaff, Jonathan Godwin, Can Cui, Shirley Ho, Peter Battaglia, Alvaro Sanchez-Gonzalez
    Trained to reproduce turbulence simulations at lower resolution:

  111. Example:
    Learned Coarse Models for Efficient Turbulence Simulation (ICLR 2022)
    Kimberly Stachenfeld, Drummond B. Fielding, Dmitrii Kochkov, Miles Cranmer, Tobias Pfaff, Jonathan Godwin, Can Cui, Shirley Ho, Peter Battaglia, Alvaro Sanchez-Gonzalez
    Trained to reproduce turbulence simulations at lower resolution:
    1000x speedup:

  112. Example:
    Learned Coarse Models for Efficient Turbulence Simulation (ICLR 2022)
    Kimberly Stachenfeld, Drummond B. Fielding, Dmitrii Kochkov, Miles Cranmer, Tobias Pfaff, Jonathan Godwin, Can Cui, Shirley Ho, Peter Battaglia, Alvaro Sanchez-Gonzalez
    Trained to reproduce turbulence simulations at lower resolution:
    1000x speedup:
    How did the model actually achieve this?

  113. (Preliminary results)
    τxx, τxy, τyx, τyy
    (Non-symmetric, since offset!)

  114. Summary
    • Symbolic distillation is a technique for translating ML models into a domain-specific language
    • Can do this for expressions/programs using PySR
    • Exciting future applications in understanding turbulence, and other physical systems