Slide 1

Slide 1 text

Stein’s Method for Modern Machine Learning: From Gradient Estimation to Generative Modeling
Jiaxin Shi, Google DeepMind
2024/3/14 @ OT-Berlin, jiaxins.io

Slide 2

Slide 2 text

Outline
• Stein’s method: Foundations
• Stein’s method and machine learning
  • Sampling
  • Gradient estimation
  • Score-based modeling

Slide 3

Slide 3 text

Divergences between Probability Distributions
• How well does my model fit the data?
• Parameter estimation by minimizing divergences
• Sampling as optimization
[Figure labels: GANs, Diffusion models; Transformers. Image source: https://openai.com/blog/generative-models/]

Slide 4

Slide 4 text

Integral Probability Metrics (IPM)
$d_{\mathcal{H}}(q, p) = \sup_{h \in \mathcal{H}} \left| \mathbb{E}_q[h(X)] - \mathbb{E}_p[h(Y)] \right|$
• When $\mathcal{H}$ is sufficiently large, convergence in $d_{\mathcal{H}}(q_n, p)$ implies $q_n$ weakly converges to $p$
• Examples: Total variation distance, Wasserstein distance
• Problem: Often $p$ is our model and integration under $p$ is intractable
• Idea: Only consider functions $h$ with $\mathbb{E}_p[h(Y)] = 0$
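
A minimal sketch of the IPM definition on a finite state space, assuming the function class is all h with |h| ≤ 1. The supremum is then attained at sign functions, so brute force over sign vectors recovers twice the total variation distance (with the convention TV = ½ Σ|q − p|). The distributions `q`, `p` and the helper name `ipm_bounded` are illustrative, not from the talk.

```python
import itertools
import numpy as np

def ipm_bounded(q, p):
    """Brute-force sup over h in {-1, +1}^n of |E_q[h] - E_p[h]|."""
    best = 0.0
    for signs in itertools.product([-1.0, 1.0], repeat=len(q)):
        h = np.array(signs)
        best = max(best, abs(h @ q - h @ p))
    return best

q = np.array([0.5, 0.3, 0.2])
p = np.array([0.2, 0.3, 0.5])
ipm = ipm_bounded(q, p)
tv = 0.5 * np.abs(q - p).sum()   # total variation distance
print(ipm, 2.0 * tv)             # the two values agree up to rounding
```

The loop visits 2^n sign vectors, so this only works for tiny state spaces.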

Slide 5

Slide 5 text

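Stein’s Method [Stein, 1972]
• Identify an operator $\mathcal{T}$ that generates mean-zero functions under the target distribution $p$: $\mathbb{E}_p[(\mathcal{T} g)(X)] = 0$ for all $g \in \mathcal{G}$
• Define the Stein discrepancy: $\mathcal{S}(q, \mathcal{T}, \mathcal{G}) \triangleq \sup_{g \in \mathcal{G}} \mathbb{E}_q[(\mathcal{T} g)(X)] - \mathbb{E}_p[(\mathcal{T} g)(X)]$
• Show that the Stein discrepancy is lower bounded by an IPM. For example, if for any $h \in \mathcal{H}$, a solution $g \in \mathcal{G}$ exists for the equation $h(x) - \mathbb{E}_p[h(Y)] = (\mathcal{T} g)(x)$, then $d_{\mathcal{H}}(q, p) \le \mathcal{S}(q, \mathcal{T}, \mathcal{G})$.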

Slide 6

Slide 6 text

Identifying a Stein Operator
Stein’s Lemma: If $p$ is a standard normal distribution, then $\mathbb{E}_p[g'(X) - X g(X)] = 0$ for all $g \in C_b^1$
The corresponding Stein operator: $(\mathcal{T} g)(x) = g'(x) - x g(x)$
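
Stein’s Lemma can be checked by Monte Carlo. The sketch below assumes the test function g = tanh (bounded with bounded derivative) and, for contrast, a shifted distribution q = N(1, 1); both choices are illustrative.

```python
import numpy as np

rng = np.random.default_rng(0)

def stein_op(g, dg, x):
    """(T g)(x) = g'(x) - x g(x), the Stein operator for the standard normal."""
    return dg(x) - x * g(x)

g = np.tanh
dg = lambda x: 1.0 - np.tanh(x) ** 2

x_p = rng.standard_normal(200_000)          # samples from p = N(0, 1)
x_q = 1.0 + rng.standard_normal(200_000)    # samples from q = N(1, 1)

mean_under_p = stein_op(g, dg, x_p).mean()  # close to zero
mean_under_q = stein_op(g, dg, x_q).mean()  # clearly nonzero
print(mean_under_p, mean_under_q)
```

The nonzero mean under q is what a Stein discrepancy picks up when q differs from p.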

Slide 7

Slide 7 text

Identifying a Stein Operator
Barbour’s generalization via stochastic processes [Barbour, 1988 & 1990]
• The (infinitesimal) generator $A$ of a stochastic process $(X_t)_{t \ge 0}$ is defined as $(Af)(x) = \lim_{t \to 0} \frac{\mathbb{E}[f(X_t) \mid X_0 = x] - f(x)}{t}$.
• The generator of a stochastic process with stationary distribution $p$ satisfies $\mathbb{E}_p[(Af)(X)] = 0$.
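
A small numerical illustration of the generator definition. It assumes the Ornstein-Uhlenbeck process dX_t = −X_t dt + √2 dW_t (stationary distribution N(0, 1)) and f = sin, for which E[f(X_t) | X_0 = x] has a closed form because X_t given X_0 = x is Gaussian with mean x e^{−t} and variance 1 − e^{−2t}. The finite-t quotient is compared with (Af)(x) = −x f'(x) + f''(x).

```python
import numpy as np

def ou_expect_sin(x, t):
    """E[sin(X_t) | X_0 = x] for the OU process above (closed form)."""
    mean = x * np.exp(-t)
    var = 1.0 - np.exp(-2.0 * t)
    return np.sin(mean) * np.exp(-0.5 * var)   # E[sin(Z)] for Gaussian Z

def generator_sin(x):
    """(A f)(x) = -x f'(x) + f''(x) with f = sin."""
    return -x * np.cos(x) - np.sin(x)

x = np.linspace(-2.0, 2.0, 9)
t = 1e-5
finite_t = (ou_expect_sin(x, t) - np.sin(x)) / t
max_gap = np.max(np.abs(finite_t - generator_sin(x)))
print(max_gap)   # small; shrinks as t -> 0
```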

Slide 8

Slide 8 text

Langevin Stein Operator [Gorham & Mackey, 2015]
• Langevin diffusion on $\mathbb{R}^d$: $dX_t = \nabla \log p(X_t)\,dt + \sqrt{2}\,dW_t$
• Generator: $(Af)(x) = \nabla \log p(x)^\top \nabla f(x) + \nabla \cdot \nabla f(x)$
• Convenient form with a vector-valued function $g : \mathbb{R}^d \to \mathbb{R}^d$: $(\mathcal{T}_p g)(x) = \nabla \log p(x)^\top g(x) + \nabla \cdot g(x)$
• Depends on $p$ only through $\nabla \log p$, computable even for unnormalized $p$
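
A one-dimensional sketch of the last bullet, assuming the unnormalized target p(x) ∝ exp(−x⁴/4) and the test function g = sin (both illustrative). Only the score −x³ enters the operator. The expectation is computed as a ratio of grid sums, so the unknown normalizer cancels.

```python
import numpy as np

def unnorm_p(x):
    return np.exp(-x**4 / 4.0)       # unnormalized density; Z is never computed

def score(x):
    return -x**3                     # d/dx log p(x), independent of Z

def langevin_stein(g, dg, x):
    """(T_p g)(x) = score(x) * g(x) + g'(x) in one dimension."""
    return score(x) * g(x) + dg(x)

x = np.linspace(-6.0, 6.0, 20001)    # grid; the density is negligible at the ends
w = unnorm_p(x)
tg = langevin_stein(np.sin, np.cos, x)
expectation = np.sum(w * tg) / np.sum(w)   # grid spacing and Z both cancel
print(expectation)                   # numerically zero
```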

Slide 9

Slide 9 text

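Stein Operators and Sampling [Liu & Wang, 2016]
Find the direction that most quickly decreases the KL divergence to $p$
Particle dynamics: $\dot{\theta}_t = g_t(\theta_t)$
$\frac{d}{dt} \mathrm{KL}(q_t \| p) = -\mathbb{E}_{q_t}[(\mathcal{T}_p g_t)(X)]$
[Figure labels: $q_t$, $p$]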

Slide 10

Slide 10 text

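Wasserstein Gradient Flow and SVGD [Liu & Wang, 2016]
$\inf_{g_t \in \mathcal{G}} \frac{d}{dt} \mathrm{KL}(q_t \| p) = -\sup_{g_t \in \mathcal{G}} \mathbb{E}_{q_t}[(\mathcal{T}_p g_t)(X)]$
• $\mathcal{G} = \mathcal{L}^2(q_t)$: Wasserstein Gradient Flow, $g_t^* \propto \nabla \log p - \nabla \log q_t$. Same density evolution as Langevin diffusion
• $\mathcal{G}$ = RKHS of kernel $K$: Stein Variational Gradient Descent, $g_t^* \propto \mathbb{E}_{q_t}[K(\cdot, X) \nabla \log p(X) + \nabla_X \cdot K(\cdot, X)]$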
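
A minimal sketch of the SVGD update, replacing the expectation over q_t by an average over particles. The scalar RBF kernel, fixed bandwidth `h`, step size, and the N(2, 1) target are assumptions made for illustration; they are not settings from the talk.

```python
import numpy as np

def svgd_step(x, score, step=0.1, h=1.0):
    """One SVGD update for particles x of shape (n, d) with an RBF kernel."""
    n = x.shape[0]
    diff = x[:, None, :] - x[None, :, :]                   # diff[j, i] = x_j - x_i
    k = np.exp(-np.sum(diff**2, axis=-1) / (2.0 * h**2))   # k[j, i] = K(x_j, x_i)
    grad_k = -diff / h**2 * k[..., None]                   # gradient of K(x_j, x_i) in x_j
    # phi(x_i) = mean_j [ K(x_j, x_i) * score(x_j) + grad_{x_j} K(x_j, x_i) ]
    phi = (k.T @ score(x) + grad_k.sum(axis=0)) / n
    return x + step * phi

rng = np.random.default_rng(0)
target_mean = np.array([2.0])
score = lambda x: -(x - target_mean)        # score of N(2, 1)
particles = rng.standard_normal((50, 1))    # initial particles from N(0, 1)
for _ in range(500):
    particles = svgd_step(particles, score)
print(particles.mean())                     # moves toward the target mean
```

The first term pulls particles toward high-density regions; the kernel-gradient term pushes them apart.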

Slide 11

Slide 11 text

Convergence Analysis of SVGD
• Korba, A., Salim, A., Arbel, M., Luise, G., & Gretton, A. A non-asymptotic analysis for Stein variational gradient descent. NeurIPS (2020).
• Chewi, S., Le Gouic, T., Lu, C., Maunu, T., & Rigollet, P. SVGD as a kernelized Wasserstein gradient flow of the chi-squared divergence. NeurIPS (2020).
• Shi, J., & Mackey, L. A finite-particle convergence rate for Stein variational gradient descent. NeurIPS (2023). Convergence rate for discrete-time, finite-particle SVGD

Slide 12

Slide 12 text

Stein’s Method and Gradient Estimation

Slide 13

Slide 13 text

The Gradient Estimation Problem
A common problem in training generative models and reinforcement learning: $\nabla_\eta \mathbb{E}_{q_\eta}[f(X)]$
[Figure annotations: Encoder ($q_\eta$), Decoding error ($f$)] [Lim et al., 2018]

Slide 14

Slide 14 text

The Gradient Estimation Problem
A common problem in training generative models and reinforcement learning: $\nabla_\eta \mathbb{E}_{q_\eta}[f(X)]$
[Figure annotations: Policy ($q_\eta$), Reward Model ($f$)] [Lim et al., 2018]
Image source: https://aws.amazon.com/de/what-is/reinforcement-learning-from-human-feedback/

Slide 15

Slide 15 text

Discrete Gradient Estimation
● Discrete data, states, and actions [Wei et al., 2022] [Silver et al., 2017] [Alamdari et al., 2023]
● Computing exact gradients is often intractable: $\nabla_\eta \mathbb{E}_{q_\eta}[f(X)] = \nabla_\eta \sum_{x \in \{0,1\}^d} q_\eta(x) f(x)$
[Annotations: $x$ is a $d$-dimensional binary vector; intractable sum over $2^d$ configurations; $f$ is a complex, nonlinear function]
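
To make the 2^d cost concrete, here is a sketch of the exact gradient by enumeration. It assumes a factorized Bernoulli q_η with logits η (so ∇_η log q_η(x) = x − σ(η)) and a toy objective f; both are illustrative choices.

```python
import itertools
import numpy as np

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

def exact_gradient(eta, f):
    """Gradient of E_q[f(X)] for factorized Bernoulli q with logits eta (2^d terms)."""
    mu = sigmoid(eta)
    grad = np.zeros_like(eta)
    for bits in itertools.product([0.0, 1.0], repeat=eta.size):
        x = np.array(bits)
        q_x = np.prod(np.where(x == 1.0, mu, 1.0 - mu))
        grad += f(x) * q_x * (x - mu)        # grad_eta q(x) = q(x) * (x - mu)
    return grad

eta = np.array([0.3, -1.0, 0.5, 2.0])
f = lambda x: np.sum((x - 0.45) ** 2)        # toy objective
grad = exact_gradient(eta, f)
n_terms = 2 ** eta.size                      # 16 here; grows exponentially in d
print(grad, n_terms)
```

For this separable f the result can be checked in closed form, 0.1·μ(1 − μ) with μ = σ(η); for a general nonlinear f no such shortcut exists.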

Slide 16

Slide 16 text

Gradient Estimation and Variance Reduction
$\hat{g}_1 = \frac{1}{K} \sum_{k=1}^{K} f(x_k) \nabla_\eta \log q_\eta(x_k)$ (REINFORCE). High variance!
Control Variates: $\hat{g}_2 = \frac{1}{K} \sum_{k=1}^{K} \left[ f(x_k) \nabla_\eta \log q_\eta(x_k) + cv(x_k) \right] - \mathbb{E}_{q_\eta}[cv(X)]$
• Strong correlation is required for effective variance reduction
• Fundamental tradeoff: $cv$ needs to be very flexible but still have analytic expectation under $q_\eta$.
With a Stein operator $A$: $\hat{g}_2 = \frac{1}{K} \sum_{k=1}^{K} \left[ f(x_k) \nabla_\eta \log q_\eta(x_k) + (Ah)(x_k) \right] - \mathbb{E}_{q_\eta}[(Ah)(X)]$, where $\mathbb{E}_{q_\eta}[(Ah)(X)] = 0$
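
A sketch comparing the two estimators on a toy problem. It assumes a factorized Bernoulli q_η with logits η = 0 and uses the simplest control variate, cv(x) = −b ∇_η log q_η(x) with a constant baseline b, whose expectation under q_η is zero. This is a stand-in for illustration, not the Stein-operator control variate from the talk.

```python
import numpy as np

rng = np.random.default_rng(0)
d, K, reps = 10, 10, 2000
mu = np.full(d, 0.5)                               # Bernoulli means for logits eta = 0
f = lambda x: np.sum((x - 0.45) ** 2, axis=-1)     # toy objective

x = (rng.random((reps, K, d)) < mu).astype(float)  # reps independent batches of K samples
score = x - mu                                     # grad_eta log q_eta(x) for logit parameters
fx = f(x)[..., None]

# Baseline estimated from an independent batch so the estimator stays unbiased.
b = f((rng.random((1000, d)) < mu).astype(float)).mean()
cv = -b * score                                    # control variate
expected_cv = np.zeros(d)                          # E_q[cv(X)] = 0 analytically

g1 = (fx * score).mean(axis=1)                     # REINFORCE, shape (reps, d)
g2 = (fx * score + cv).mean(axis=1) - expected_cv  # with control variate
exact = 0.1 * mu * (1.0 - mu)                      # closed form for this toy f
var1, var2 = g1.var(axis=0).sum(), g2.var(axis=0).sum()
print(var1, var2)                                  # var2 is far smaller
```

The constant baseline helps here because f has a large mean relative to its spread; more flexible control variates aim to track f itself.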

Slide 17

Slide 17 text

Discrete Stein Operators
How: Apply Barbour’s idea to discrete-state Markov chains.
• $K$: transfer operator, $\mathbb{E}_q[((K - I)h)(X)] = 0$
• Continuous time, $A$: generator, $\mathbb{E}_q[(Ah)(X)] = 0$
Shi, Zhou, Hwang, Titsias & Mackey. Gradient estimation with discrete Stein operators, NeurIPS 2022 Outstanding Paper.
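
A small check of the identity E_q[((K − I)h)(X)] = 0 on a finite state space. The sketch assumes a Metropolis kernel with a uniform proposal as the Markov chain that leaves q invariant; this is one convenient choice for illustration and not necessarily an operator used in the paper.

```python
import numpy as np

rng = np.random.default_rng(0)
n = 8                                    # e.g. the 2^3 states of a 3-bit vector
q = rng.random(n)
q /= q.sum()                             # target distribution on n states

# Metropolis kernel, uniform proposal: K[i, j] = (1/n) * min(1, q[j] / q[i]) for j != i.
K = np.minimum(1.0, q[None, :] / q[:, None]) / n
np.fill_diagonal(K, 0.0)
np.fill_diagonal(K, 1.0 - K.sum(axis=1)) # rejection mass stays on the diagonal

h = rng.standard_normal(n)               # arbitrary test function
stein_h = (K - np.eye(n)) @ h            # ((K - I) h)(x) for every state x
mean_under_q = q @ stein_h
print(mean_under_q)                      # zero up to rounding
```

Because the identity holds for any h, h can be chosen freely to correlate with the gradient estimator.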

Slide 18

Slide 18 text

Experiments: Training Binary Latent VAEs
[Figure: Ours vs. Prior SOTA; annotation: ~40% reduction]
Shi, Zhou, Hwang, Titsias & Mackey. Gradient estimation with discrete Stein operators, NeurIPS 2022 Outstanding Paper.

Slide 19

Slide 19 text

Stein’s Method and Score-Based Modeling

Slide 20

Slide 20 text

Stein Discrepancy as a Learning Rule
Model fitting: $\min_\theta \left| \mathbb{E}_q[h(x)^\top \nabla_x \log p_\theta(x) + \nabla \cdot h(x)] \right|$
[Annotations: Data distribution ($q$), Model distribution ($p_\theta$), ?]

Slide 21

Slide 21 text

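Model fitting: $\min_\theta \sup_{h \in L^2(q)} \left| \mathbb{E}_q[h(x)^\top \nabla_x \log p_\theta(x) + \nabla \cdot h(x)] \right|$
→ $\min_\theta \mathbb{E}_{q_{\mathrm{data}}}\left[ \| \nabla \log p_\theta(x) - \nabla \log q_{\mathrm{data}}(x) \|^2 \right]$: Score Matching [Hyvärinen, 2005]
[Annotations: Data distribution ($q$), Model distribution ($p_\theta$)]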
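
The score matching objective involves the unknown ∇log q_data, but under the usual regularity conditions integration by parts gives an equivalent objective, E[½ s_θ(x)² + s_θ'(x)] in one dimension, that needs only samples (Hyvärinen, 2005). The sketch below assumes a linear score model s_θ(x) = θ₀x + θ₁, for which the minimizer solves a 2×2 linear system; the data distribution N(3, 2²) is illustrative.

```python
import numpy as np

rng = np.random.default_rng(0)
x = 3.0 + 2.0 * rng.standard_normal(100_000)    # data from N(3, 2^2)

# Linear score model s_theta(x) = theta[0] * x + theta[1], features phi(x) = (x, 1).
phi = np.stack([x, np.ones_like(x)], axis=1)    # shape (n, 2)
dphi = np.array([1.0, 0.0])                     # derivative of the features in x

# Minimizing E[0.5 * s_theta(x)^2 + s_theta'(x)] gives E[phi phi^T] theta = -E[dphi].
gram = phi.T @ phi / x.size
theta = np.linalg.solve(gram, -dphi)

# A Gaussian score is -(x - m) / v, so theta = (-1/v, m/v).
var_hat = -1.0 / theta[0]
mean_hat = theta[1] * var_hat
print(mean_hat, var_hat)                        # near 3 and 4
```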

Slide 22

Slide 22 text

Training Energy-Based Models
$p_\theta(x) = \frac{e^{-E_\theta(x)}}{Z_\theta}$
Key insight: The score does not depend on the normalizing constant $Z_\theta$: $\nabla_x \log p_\theta(x) = -\nabla_x E_\theta(x) - \nabla_x \log Z_\theta$, where $\nabla_x \log Z_\theta = 0$
• Score Matching is more suitable for training such models than maximum likelihood!
[Figure labels: $x$, $E_\theta(x)$; plot axes: Time, Performance; legend: Sliced Score Matching]
Song*, Garg*, Shi & Ermon. Sliced score matching: A scalable approach to density and score estimation. UAI 2019

Slide 23

Slide 23 text

Score-Based Modeling
Idea: Model the score $s := \nabla \log p$ instead of the density
Advantages: 1. Less computation than energy-based modeling 2. Enables more flexible models
Nonparametric Score Model: $\min_{s \in \mathcal{H}} \mathbb{E}_{q_{\mathrm{data}}} \| s(x) - \nabla \log q_{\mathrm{data}}(x) \|^2 + \frac{\lambda}{2} \| s \|_{\mathcal{H}}^2$
The spectral estimator (Shi et al., 18) is a special case.
Zhou, Shi & Zhu. Nonparametric score estimators. ICML 2020

Slide 24

Slide 24 text

A Spectral Method for Score Estimation
$\nabla_x \log q(x) = -\sum_{j \ge 1} \mathbb{E}_q[\nabla \psi_j(x)] \, \psi_j(x)$
• Density gradients (score): $\langle \nabla \log q, \psi_j \rangle_{L^2(q)} = -\mathbb{E}_q[\nabla \psi_j(x)]$
• Eigenfunction: $\mathbb{E}_{x' \sim q}[k(x, x') \psi_j(x')] = \lambda_j \psi_j(x)$
Shi, Sun & Zhu. A spectral approach to gradient estimation for implicit distributions. ICML 2018
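
A one-dimensional sketch of the spectral expansion, truncated to the top few eigenfunctions. It assumes the eigenfunctions are approximated from samples by the Nyström method with an RBF kernel; the bandwidth `h`, the number of eigenfunctions `n_eig`, and the sample size are illustrative and untuned.

```python
import numpy as np

def spectral_score(x_eval, samples, n_eig=6, h=1.0):
    """Estimate d/dx log q(x) at x_eval from i.i.d. samples of q (1-D)."""
    m = samples.size
    def kern(a, b):
        return np.exp(-(a[:, None] - b[None, :]) ** 2 / (2.0 * h**2))
    lam, u = np.linalg.eigh(kern(samples, samples))        # ascending eigenvalues
    lam, u = lam[::-1][:n_eig], u[:, ::-1][:, :n_eig]      # keep the top n_eig
    def psi(a):      # Nystrom eigenfunctions: sqrt(m) / lam_j * sum_i u[i, j] k(a, x_i)
        return np.sqrt(m) * kern(a, samples) @ u / lam
    def dpsi(a):     # their derivatives in a
        dk = -(a[:, None] - samples[None, :]) / h**2 * kern(a, samples)
        return np.sqrt(m) * dk @ u / lam
    beta = -dpsi(samples).mean(axis=0)                     # -E_q[grad psi_j]
    return psi(x_eval) @ beta

rng = np.random.default_rng(0)
samples = rng.standard_normal(300)                         # q = N(0, 1), true score is -x
x_eval = np.array([-1.0, 0.0, 1.0])
est = spectral_score(x_eval, samples)
print(est)                                                 # roughly follows -x_eval
```

Only kernel evaluations on samples are needed; the density q itself is never evaluated.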

Slide 25

Slide 25 text

A Spectral Method for Score Estimation
[Figure: samples $\{x^j\}_{j=1}^{M} \overset{\text{i.i.d.}}{\sim} q$ from an unknown density $q(x)$; score function $\nabla_x \log q(x)$]
Shi, Sun & Zhu. A spectral approach to gradient estimation for implicit distributions. ICML 2018
Zhou, Shi & Zhu. Nonparametric score estimators. ICML 2020

Slide 26

Slide 26 text

Score-Based Modeling
Idea: Model the score $s := \nabla \log p$ instead of the density
Advantages: 1. Less computation than energy-based modeling 2. Enables more flexible models
Nonparametric Score Model: $\min_{s \in \mathcal{H}} \mathbb{E}_{q_{\mathrm{data}}} \| s(x) - \nabla \log q_{\mathrm{data}}(x) \|^2 + \frac{\lambda}{2} \| s \|_{\mathcal{H}}^2$
The spectral estimator (Shi et al., 18) is a special case.
Score Network: $\min_\theta \mathbb{E}_{q_{\mathrm{data}}} \| s_\theta(x) - \nabla \log q_{\mathrm{data}}(x) \|^2$, so that $s_\theta(x) \approx \nabla \log q_{\mathrm{data}}(x)$
Use neural networks to model the score, trained by sliced score matching
Song*, Garg*, Shi & Ermon. Sliced score matching: A scalable approach to density and score estimation. UAI 2019
Zhou, Shi & Zhu. Nonparametric score estimators. ICML 2020

Slide 27

Slide 27 text

From Score Networks to Diffusion Models
Updates produced by score networks transform noise to data [Song et al., ICLR’20]
Images created by OpenAI’s DALLE-2. DALLE-2 is based on diffusion models.
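
A simplified sketch of how score-driven updates move noise toward data. It assumes plain unadjusted Langevin dynamics and uses the exact score of N(2, 0.5²) as a stand-in for a trained score network. Diffusion models in practice use a sequence of noise levels or a reverse-time SDE, which this sketch omits.

```python
import numpy as np

rng = np.random.default_rng(0)

def score(x, mean=2.0, std=0.5):
    """Stand-in for a trained score network: exact score of N(mean, std^2)."""
    return -(x - mean) / std**2

x = rng.standard_normal(5000)          # start from noise
step = 0.01
for _ in range(500):
    noise = rng.standard_normal(x.size)
    x = x + step * score(x) + np.sqrt(2.0 * step) * noise   # Langevin update

print(x.mean(), x.std())               # near 2 and 0.5, up to discretization bias
```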

Slide 28

Slide 28 text

Open Problems
● Improving finite-particle rates of SVGD
● Approximately solving the Stein equation for improved gradient estimation
● Lower bounding the discrete Stein discrepancy
● Learning the features in nonparametric score models
● Finding the “right” discrete correspondence of the score matching objective

Slide 29

Slide 29 text

Main References
Joint work with Lester Mackey, Yuhao Zhou, Jessica Hwang, Michalis K. Titsias, Shengyang Sun, Jun Zhu, Yang Song, Sahaj Garg, Stefano Ermon
• Shi & Mackey. A finite-particle convergence rate for Stein variational gradient descent. NeurIPS 2023.
• Shi, Zhou, Hwang, Titsias, & Mackey. Gradient estimation with discrete Stein operators. NeurIPS 2022.
• Titsias & Shi. Double control variates for gradient estimation in discrete latent-variable models. AISTATS 2022.
• Shi, Sun, & Zhu. A spectral approach to gradient estimation for implicit distributions. ICML 2018.
• Song, Garg, Shi, & Ermon. Sliced score matching: A scalable approach to density and score estimation. UAI 2019.
• Zhou, Shi, & Zhu. Nonparametric score estimators. ICML 2020.

Slide 30

Slide 30 text

References
• Stein, C. (1972). A bound for the error in the normal approximation to the distribution of a sum of dependent random variables.
• Barbour, A. D. (1988). Stein's method and Poisson process convergence. Journal of Applied Probability, 25(A), 175-184.
• Barbour, A. D. (1990). Stein's method for diffusion approximations. Probability Theory and Related Fields, 84(3), 297-322.
• Gorham, J., & Mackey, L. (2015). Measuring sample quality with Stein's method. Advances in Neural Information Processing Systems, 28.
• Liu, Q., & Wang, D. (2016). Stein variational gradient descent: A general purpose Bayesian inference algorithm. Advances in Neural Information Processing Systems, 29.

Slide 31

Slide 31 text

References
• Lim, J., Ryu, S., Kim, J. W., & Kim, W. Y. (2018). Molecular generative model based on conditional variational autoencoder for de novo molecular design. Journal of Cheminformatics, 10, 1-9.
• Wei, J., Wang, X., Schuurmans, D., Bosma, M., Xia, F., Chi, E., ... & Zhou, D. (2022). Chain-of-thought prompting elicits reasoning in large language models. Advances in Neural Information Processing Systems, 35, 24824-24837.
• Alamdari, S., Thakkar, N., van den Berg, R., Lu, A. X., Fusi, N., Amini, A. P., & Yang, K. K. (2023). Protein generation with evolutionary diffusion: sequence is all you need. bioRxiv, 2023-09.
• Silver, D., Schrittwieser, J., Simonyan, K., Antonoglou, I., Huang, A., Guez, A., ... & Hassabis, D. (2017). Mastering the game of Go without human knowledge. Nature, 550(7676), 354-359.
• Song, Y., Sohl-Dickstein, J., Kingma, D. P., Kumar, A., Ermon, S., & Poole, B. (2020). Score-based generative modeling through stochastic differential equations. arXiv preprint arXiv:2011.13456.