Jia-Jie Zhu
March 25, 2024

# Jiaxin Shi (Google DeepMind, London, UK): Stein’s Method for Modern Machine Learning: From Gradient Estimation to Generative Modeling

WORKSHOP ON OPTIMAL TRANSPORT
FROM THEORY TO APPLICATIONS
INTERFACING DYNAMICAL SYSTEMS, OPTIMIZATION, AND MACHINE LEARNING
Venue: Humboldt University of Berlin, Dorotheenstraße 24

Berlin, Germany. March 11th - 15th, 2024


## Transcript

1. ### Stein’s Method for Modern Machine Learning: From Gradient Estimation to Generative Modeling

Jiaxin Shi, Google DeepMind. 2024/3/14 @ OT-Berlin. jiaxins.io
2. ### Outline

- Stein’s method: foundations
- Stein’s method and machine learning
  - Sampling
  - Gradient estimation
  - Score-based modeling
3. ### Divergences between Probability Distributions

GANs, diffusion models, Transformers.

- How well does my model fit the data?
- Parameter estimation by minimizing divergences
- Sampling as optimization

https://openai.com/blog/generative-models/
4. ### Integral Probability Metrics (IPM)

$d_{\mathcal{H}}(q, p) = \sup_{h \in \mathcal{H}} \left| \mathbb{E}_q[h(X)] - \mathbb{E}_p[h(Y)] \right|$

- When $\mathcal{H}$ is sufficiently large, convergence in $d_{\mathcal{H}}(q_n, p)$ implies that $q_n$ converges weakly to $p$.
- Examples: total variation distance, Wasserstein distance.
- Problem: often $p$ is our model and integration under $p$ is intractable.
- Idea: only consider functions $h$ with $\mathbb{E}_p[h(Y)] = 0$.
5. ### Stein’s Method

- Identify an operator $\mathcal{T}$ that generates mean-zero functions under the target distribution $p$: $\mathbb{E}_p[(\mathcal{T} g)(X)] = 0$ for all $g \in \mathcal{G}$.
- Define the Stein discrepancy: $\mathcal{S}(q, \mathcal{T}, \mathcal{G}) \triangleq \sup_{g \in \mathcal{G}} \mathbb{E}_q[(\mathcal{T} g)(X)] - \mathbb{E}_p[(\mathcal{T} g)(X)]$.
- Show that the Stein discrepancy is lower bounded by an IPM. For example, if for every $h \in \mathcal{H}$ the Stein equation $h(x) - \mathbb{E}_p[h(Y)] = (\mathcal{T} g)(x)$ has a solution $g \in \mathcal{G}$, then $d_{\mathcal{H}}(q, p) \le \mathcal{S}(q, \mathcal{T}, \mathcal{G})$.

[Stein, 1972]
6. ### Identifying a Stein Operator

Stein’s Lemma: if $p$ is the standard normal distribution, then $\mathbb{E}_p[g'(X) - X g(X)] = 0$ for all $g \in C_b^1$.

The corresponding Stein operator: $(\mathcal{T} g)(x) = g'(x) - x g(x)$.
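Stein’s lemma is easy to verify by Monte Carlo: for samples from the standard normal, the average of $g'(X) - X g(X)$ should be near zero. A minimal sketch, where the test function $g = \tanh$ is an arbitrary choice for illustration:

```python
import numpy as np

rng = np.random.default_rng(0)
x = rng.standard_normal(1_000_000)        # samples from the standard normal

# Test function g(x) = tanh(x) (in C^1_b), with derivative g'(x) = 1 - tanh(x)^2.
g = np.tanh(x)
g_prime = 1.0 - g**2

# Stein's lemma: E[g'(X) - X g(X)] = 0 under the standard normal.
stein_mean = np.mean(g_prime - x * g)
print(stein_mean)                         # close to 0
```

Replacing the standard normal samples with samples from any other distribution makes the average drift away from zero, which is exactly what the Stein discrepancy exploits.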
7. ### Identifying a Stein Operator

Barbour’s generalization via stochastic processes:

- The (infinitesimal) generator $A$ of a stochastic process $(X_t)_{t \ge 0}$ is defined as $(Af)(x) = \lim_{t \to 0} \frac{\mathbb{E}[f(X_t) \mid X_0 = x] - f(x)}{t}$.
- The generator of a stochastic process with stationary distribution $p$ satisfies $\mathbb{E}_p[(Af)(X)] = 0$.

[Barbour, 1988 & 1990]
8. ### Langevin Stein Operator

- Langevin diffusion on $\mathbb{R}^d$: $dX_t = \nabla \log p(X_t)\,dt + \sqrt{2}\,dW_t$
- Generator: $(Af)(x) = \nabla \log p(x)^\top \nabla f(x) + \nabla \cdot \nabla f(x)$
- Convenient form with a vector-valued function $g: \mathbb{R}^d \to \mathbb{R}^d$: $(\mathcal{T}_p g)(x) = \nabla \log p(x)^\top g(x) + \nabla \cdot g(x)$
- Depends on $p$ only through $\nabla \log p$, so it is computable even for unnormalized $p$.

[Gorham & Mackey, 2015]
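The mean-zero property of the Langevin Stein operator can also be checked numerically. A minimal sketch with a Gaussian target, whose score we can write down without any normalizing constant, and the arbitrary elementwise test function $g(x) = \sin(x)$:

```python
import numpy as np

rng = np.random.default_rng(0)
mu, sigma = 1.5, 0.7
x = rng.normal(mu, sigma, size=(1_000_000, 2))   # samples from p = N(mu, sigma^2 I)

# The score is available without the normalizing constant of p:
score = -(x - mu) / sigma**2                     # grad log p(x)

# Vector-valued test function g(x) = sin(x) elementwise; div g(x) = sum_i cos(x_i).
g = np.sin(x)
div_g = np.cos(x).sum(axis=1)

# Langevin Stein operator: (T_p g)(x) = grad log p(x) . g(x) + div g(x)
tp_g = (score * g).sum(axis=1) + div_g
print(tp_g.mean())                               # close to 0
```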
9. ### Stein Operators and Sampling

Find the direction $g_t$ that most quickly decreases the KL divergence to $p$. Evolving particles as $\dot{\theta}_t = g_t(\theta_t)$ gives

$\frac{d}{dt} \mathrm{KL}(q_t \| p) = -\mathbb{E}_{q_t}[(\mathcal{T}_p g_t)(X)]$

[Liu & Wang, 2016]
10. ### Wasserstein Gradient Flow and SVGD

$\inf_{g_t \in \mathcal{G}} \frac{d}{dt} \mathrm{KL}(q_t \| p) = -\sup_{g_t \in \mathcal{G}} \mathbb{E}_{q_t}[(\mathcal{T}_p g_t)(X)]$

- $\mathcal{G} = \mathcal{L}^2(q_t)$: Wasserstein gradient flow, $g_t^* \propto \nabla \log p - \nabla \log q_t$. Same density evolution as Langevin diffusion.
- $\mathcal{G} = {}$RKHS of kernel $K$: Stein Variational Gradient Descent (SVGD), $g_t^* \propto \mathbb{E}_{q_t}[K(\cdot, X) \nabla \log p(X) + \nabla_X \cdot K(\cdot, X)]$

[Liu & Wang, 2016]
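The SVGD update can be sketched in a few lines. This is a minimal 1D illustration, not a reference implementation: the target $N(2, 1)$, the fixed RBF bandwidth, the step size, and the iteration count are all arbitrary choices.

```python
import numpy as np

rng = np.random.default_rng(0)

# Target p = N(2, 1), known only through its score grad log p(x) = -(x - 2).
def score(x):
    return -(x - 2.0)

n, h, eps = 200, 1.0, 0.1
x = rng.standard_normal(n)               # initial particles from N(0, 1)

for _ in range(1000):
    diff = x[None, :] - x[:, None]       # diff[j, i] = x_i - x_j
    k = np.exp(-diff**2 / (2 * h**2))    # RBF kernel k(x_j, x_i)
    grad_k = diff / h**2 * k             # gradient of k in x_j (repulsion term)
    # Empirical version of the optimal update direction g_t* in the RKHS:
    phi = (k * score(x)[:, None] + grad_k).mean(axis=0)
    x = x + eps * phi

print(x.mean(), x.var())                 # particles approximate the target N(2, 1)
```

The first term in `phi` drags particles toward high-density regions of $p$; the kernel-gradient term pushes them apart, which is what keeps SVGD from collapsing to the mode.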
11. ### Convergence Analysis of SVGD

- Korba, A., Salim, A., Arbel, M., Luise, G., & Gretton, A. A non-asymptotic analysis for Stein variational gradient descent. NeurIPS 2020.
- Chewi, S., Le Gouic, T., Lu, C., Maunu, T., & Rigollet, P. SVGD as a kernelized Wasserstein gradient flow of the chi-squared divergence. NeurIPS 2020.
- Shi, J., & Mackey, L. A finite-particle convergence rate for Stein variational gradient descent. NeurIPS 2023. (Convergence rate for discrete-time, finite-particle SVGD.)

13. ### The Gradient Estimation Problem

Estimate $\nabla_\eta \mathbb{E}_{q_\eta}[f(X)]$, e.g. where $q_\eta$ is an encoder and $f$ the decoding error. A common problem in training generative models and reinforcement learning. [Lim et al., 2018]
14. ### The Gradient Estimation Problem

A common problem in training generative models and reinforcement learning: $\nabla_\eta \mathbb{E}_{q_\eta}[f(X)]$, e.g. where $q_\eta$ is a policy and $f$ a reward model. [Lim et al., 2018]

https://aws.amazon.com/de/what-is/reinforcement-learning-from-human-feedback/
15. ### Discrete Gradient Estimation

- Discrete data, states, and actions. [Wei et al., 2022] [Silver et al., 2017] [Alamdari et al., 2023]
- Computing exact gradients is often intractable: $\nabla_\eta \mathbb{E}_{q_\eta}[f(X)] = \nabla_\eta \sum_{x \in \{0,1\}^d} q_\eta(x) f(x)$, where $x$ is a $d$-dimensional binary vector, the sum runs over $2^d$ configurations, and $f$ is a complex, nonlinear function.
16. ### Gradient Estimation and Variance Reduction

REINFORCE (high variance!): $\hat{g}_1 = \frac{1}{K} \sum_{k=1}^K f(x_k) \nabla_\eta \log q_\eta(x_k)$

Control variates: $\hat{g}_2 = \frac{1}{K} \sum_{k=1}^K \left[ f(x_k) \nabla_\eta \log q_\eta(x_k) + cv(x_k) \right] - \mathbb{E}_{q_\eta}[cv(X)]$

- Strong correlation between $cv$ and the REINFORCE term is required for effective variance reduction.
- Fundamental tradeoff: $cv$ needs to be very flexible but still have an analytic expectation under $q_\eta$.

Stein operators resolve the tradeoff: with $cv = Ah$ for a Stein operator $A$, the correction term $\mathbb{E}_{q_\eta}[(Ah)(X)] = 0$ vanishes, giving $\hat{g}_2 = \frac{1}{K} \sum_{k=1}^K \left[ f(x_k) \nabla_\eta \log q_\eta(x_k) + (Ah)(x_k) \right]$.
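The effect of a control variate can be seen in a toy example. The sketch below compares plain REINFORCE against the simplest mean-zero control variate, a constant baseline times the score (not the discrete Stein control variates of the next slide); the Bernoulli model, the objective $f$, and the batch sizes are all arbitrary illustration choices.

```python
import numpy as np
from itertools import product

rng = np.random.default_rng(0)
d, K, trials = 5, 10, 2000
eta = np.zeros(d)                          # q_eta: independent Bernoulli(sigmoid(eta_i))
p = 1.0 / (1.0 + np.exp(-eta))

def f(x):                                  # black-box objective (illustration choice)
    return (x.sum(axis=-1) - 1.0) ** 2

# Exact E_q[f] by enumerating all 2^d configurations (tractable only for small d);
# used below as a constant baseline, the simplest mean-zero control variate.
configs = np.array(list(product([0.0, 1.0], repeat=d)))
probs = (p**configs * (1 - p)**(1 - configs)).prod(axis=1)
b = (probs * f(configs)).sum()

g1_est, g2_est = [], []
for _ in range(trials):
    x = (rng.random((K, d)) < p).astype(float)
    score = x - p                          # grad_eta log q_eta(x) for this model
    fx = f(x)[:, None]
    g1_est.append((fx * score).mean(axis=0))         # plain REINFORCE
    g2_est.append(((fx - b) * score).mean(axis=0))   # E_q[b * score] = 0, still unbiased

var1 = np.var(g1_est, axis=0).sum()
var2 = np.var(g2_est, axis=0).sum()
print(var1, var2)                          # the baseline cuts the variance substantially
```

A constant baseline is already a big win; the flexible $cv = Ah$ family of the next slide pushes much further while keeping the zero-mean property by construction.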
17. ### Discrete Stein Operators

How: apply Barbour’s idea to discrete-state Markov chains whose stationary distribution is $q$.

- Discrete time, transfer operator $K$: $\mathbb{E}_q[((K - I)h)(X)] = 0$
- Continuous time, generator $A$: $\mathbb{E}_q[(Ah)(X)] = 0$

Shi, Zhou, Hwang, Titsias & Mackey. Gradient estimation with discrete Stein operators. NeurIPS 2022 Outstanding Paper.
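The identity $\mathbb{E}_q[((K - I)h)(X)] = 0$ holds for any transition matrix $K$ with stationary distribution $q$, since $q^\top K = q^\top$. A minimal check with a three-state Metropolis chain (the target $q$ and the test function $h$ are arbitrary choices):

```python
import numpy as np

q = np.array([0.2, 0.5, 0.3])            # target distribution over states {0, 1, 2}
n = len(q)

# Metropolis chain with symmetric +/-1 proposals; its stationary distribution is q.
K = np.zeros((n, n))
for i in range(n):
    for j in (i - 1, i + 1):
        if 0 <= j < n:
            K[i, j] = 0.5 * min(1.0, q[j] / q[i])
    K[i, i] = 1.0 - K[i].sum()           # remaining mass: stay put

h = np.array([1.0, -2.0, 0.7])           # arbitrary test function
stein_h = (K - np.eye(n)) @ h            # ((K - I)h)(x) for each state x
mean_under_q = q @ stein_h               # E_q[((K - I)h)(X)]
print(mean_under_q)                      # zero up to floating-point error
```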
18. ### Experiments: Training Binary Latent VAEs

Ours: ~40% reduction compared to the prior SOTA.

Shi, Zhou, Hwang, Titsias & Mackey. Gradient estimation with discrete Stein operators. NeurIPS 2022 Outstanding Paper.

20. ### Model Fitting: Stein Discrepancy as a Learning Rule

$\min_\theta \left| \mathbb{E}_q \left[ h(x)^\top \nabla_x \log p_\theta(x) + \nabla \cdot h(x) \right] \right|$

Data distribution $q$, model distribution $p_\theta$ (which $h$?).
21. ### Model Fitting: Score Matching

$\min_\theta \sup_{h \in L^2(q)} \left| \mathbb{E}_q \left[ h(x)^\top \nabla_x \log p_\theta(x) + \nabla \cdot h(x) \right] \right| \;\longrightarrow\; \min_\theta \mathbb{E}_{q_{\mathrm{data}}} \left[ \left\| \nabla \log p_\theta(x) - \nabla \log q_{\mathrm{data}}(x) \right\|^2 \right]$

Data distribution $q = q_{\mathrm{data}}$, model distribution $p_\theta$. [Hyvärinen, 2005]
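Although $\nabla \log q_{\mathrm{data}}$ is unknown, integration by parts [Hyvärinen, 2005] turns the objective into one that needs only samples from $q_{\mathrm{data}}$: $\mathbb{E}_{q_{\mathrm{data}}}[\tfrac{1}{2}\|s_\theta(x)\|^2 + \nabla \cdot s_\theta(x)]$, up to a constant. A minimal sketch for a one-parameter Gaussian model, where the data distribution and the grid search are illustration choices:

```python
import numpy as np

rng = np.random.default_rng(0)
data = rng.normal(3.0, 1.0, size=100_000)    # samples from an "unknown" q_data

# Model p_theta = N(theta, 1): score s_theta(x) = -(x - theta), with div s_theta = -1.
# Hyvarinen objective J(theta) = E_q[0.5 * s_theta(x)^2 + div s_theta(x)],
# which needs only samples from q_data, never grad log q_data itself.
thetas = np.linspace(0.0, 6.0, 601)
J = np.array([0.5 * np.mean((data - t) ** 2) - 1.0 for t in thetas])
theta_hat = thetas[J.argmin()]
print(theta_hat)                             # close to the true mean 3.0
```

For this toy model the minimizer is the sample mean, so score matching recovers the same estimate as maximum likelihood without ever touching a normalizing constant.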
22. ### Training Energy-Based Models

$p_\theta(x) = \frac{e^{-E_\theta(x)}}{Z_\theta}$

Key insight: the score does not depend on the normalizing constant $Z_\theta$:

$\nabla_x \log p_\theta(x) = -\nabla E_\theta(x) - \nabla_x \log Z_\theta = -\nabla E_\theta(x)$, since $Z_\theta$ is constant in $x$.

- Score matching is therefore more suitable for training such models than maximum likelihood!

Sliced Score Matching: Song*, Garg*, Shi & Ermon. Sliced score matching: A scalable approach to density and score estimation. UAI 2019. (Figure: time vs. performance comparison.)
23. ### Score-Based Modeling

Idea: model the score $s := \nabla \log p$ instead of the density.

Advantages: 1. less computation than energy-based modeling; 2. enables more flexible models.

Nonparametric score model: $\min_{s \in \mathcal{H}} \mathbb{E}_{q_{\mathrm{data}}} \|s(x) - \nabla \log q_{\mathrm{data}}(x)\|^2 + \frac{\lambda}{2} \|s\|_{\mathcal{H}}^2$. The spectral estimator (Shi et al., 2018) is a special case.

Zhou, Shi & Zhu. Nonparametric score estimators. ICML 2020.
24. ### A Spectral Method for Score Estimation

Eigenfunctions of the kernel under $q$: $\mathbb{E}_{x' \sim q}[k(x, x') \psi_j(x')] = \lambda_j \psi_j(x)$

Projection of the score onto each eigenfunction: $\langle \nabla \log q, \psi_j \rangle_{L^2(q)} = -\mathbb{E}_q[\nabla \psi_j(x)]$

Density gradient (score) estimate: $\nabla_x \log q(x) = -\sum_{j \ge 1} \mathbb{E}_q[\nabla \psi_j(x)] \psi_j(x)$

Shi, Sun & Zhu. A spectral approach to gradient estimation for implicit distributions. ICML 2018.
25. ### A Spectral Method for Score Estimation

Given i.i.d. samples $\{x_j\}_{j=1}^M \sim q$ with $q$ unknown, estimate the score function $\nabla_x \log q(x)$.

Shi, Sun & Zhu. A spectral approach to gradient estimation for implicit distributions. ICML 2018. Zhou, Shi & Zhu. Nonparametric score estimators. ICML 2020.
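A minimal Nyström-style sketch of the spectral estimator: eigendecompose the kernel Gram matrix, approximate the eigenfunctions and their derivatives, and combine them with the formula from the previous slide. The RBF kernel, 1D standard normal data, bandwidth, and truncation level are hand-picked illustration choices; the actual estimators tune these.

```python
import numpy as np

rng = np.random.default_rng(0)
M, h, J = 500, 1.0, 6                     # samples, RBF bandwidth, eigenfunctions kept
xs = rng.standard_normal(M)               # i.i.d. samples from an "unknown" q = N(0, 1)

def kern(a, b):
    """RBF kernel matrix k(a_i, b_j) and its gradient in the first argument."""
    d = a[:, None] - b[None, :]
    k = np.exp(-d**2 / (2 * h**2))
    return k, -d / h**2 * k

K, _ = kern(xs, xs)
mu, U = np.linalg.eigh(K)                 # eigh returns ascending eigenvalues
mu, U = mu[::-1][:J], U[:, ::-1][:, :J]   # keep the top-J eigenpairs

def psi_and_grad(x):
    k, dk = kern(x, xs)
    psi = np.sqrt(M) / mu * (k @ U)       # Nystrom approximation of eigenfunctions
    dpsi = np.sqrt(M) / mu * (dk @ U)     # their derivatives
    return psi, dpsi

_, dpsi_s = psi_and_grad(xs)
beta = -dpsi_s.mean(axis=0)               # beta_j = -E_q[grad psi_j], from samples

x_test = np.linspace(-2.0, 2.0, 41)
psi_test, _ = psi_and_grad(x_test)
score_hat = psi_test @ beta               # estimate of grad log q(x); truth here is -x
err = np.mean(np.abs(score_hat + x_test))
print(err)                                # small mean absolute error
```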
26. ### Score-Based Modeling

Idea: model the score $s := \nabla \log p$ instead of the density.

Advantages: 1. less computation than energy-based modeling; 2. enables more flexible models.

- Nonparametric score model: $\min_{s \in \mathcal{H}} \mathbb{E}_{q_{\mathrm{data}}} \|s(x) - \nabla \log q_{\mathrm{data}}(x)\|^2 + \frac{\lambda}{2} \|s\|_{\mathcal{H}}^2$. The spectral estimator (Shi et al., 2018) is a special case. [Zhou, Shi & Zhu. Nonparametric score estimators. ICML 2020]
- Score network: $\min_\theta \mathbb{E}_{q_{\mathrm{data}}} \|s_\theta(x) - \nabla \log q_{\mathrm{data}}(x)\|^2$, giving $s_\theta(x) \approx \nabla \log q_{\mathrm{data}}(x)$. Use neural networks to model the score, trained by sliced score matching. [Song*, Garg*, Shi & Ermon. Sliced score matching: A scalable approach to density and score estimation. UAI 2019]
27. ### From Score Networks to Diffusion Models

Updates produced by score networks transform noise to data. [Song et al., ICLR’20]

(Images created by OpenAI’s DALL·E 2, which is based on diffusion models.)
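The noise-to-data idea can be illustrated with Langevin updates driven by a known score. In this sketch the exact score of a two-mode Gaussian mixture stands in for a trained score network; the modes, step size, and iteration count are all illustration choices.

```python
import numpy as np

rng = np.random.default_rng(0)
mus = np.array([-4.0, 4.0])               # two well-separated "data" modes

def score(x):
    """Exact score of 0.5*N(-4,1) + 0.5*N(4,1), standing in for a score network."""
    d = x[:, None] - mus[None, :]
    w = np.exp(-d**2 / 2)
    w = w / w.sum(axis=1, keepdims=True)  # per-mode responsibilities
    return -(w * d).sum(axis=1)

eps = 0.05
x = 0.1 * rng.standard_normal(5000)       # start from (near) pure noise
for _ in range(2000):
    # Langevin update: x <- x + eps * grad log p(x) + sqrt(2*eps) * noise
    x = x + eps * score(x) + np.sqrt(2 * eps) * rng.standard_normal(x.size)

frac_near_modes = (np.abs(np.abs(x) - 4.0) < 3.0).mean()
print(frac_near_modes)                    # most particles end up near the modes at -4 and 4
```

Diffusion models refine this recipe by learning scores of the data at many noise levels and annealing through them, which mixes between modes far better than plain Langevin dynamics.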
28. ### Open Problems

- Improving finite-particle rates of SVGD
- Approximately solving the Stein equation for improved gradient estimation
- Lower bounding the discrete Stein discrepancy
- Learning the features in nonparametric score models
- Finding the “right” discrete correspondence of the score matching objective
29. ### Main References

Joint work with Lester Mackey, Yuhao Zhou, Jessica Hwang, Michalis K. Titsias, Shengyang Sun, Jun Zhu, Yang Song, Sahaj Garg, Stefano Ermon.

- Shi & Mackey. A finite-particle convergence rate for Stein variational gradient descent. NeurIPS 2023.
- Shi, Zhou, Hwang, Titsias & Mackey. Gradient estimation with discrete Stein operators. NeurIPS 2022.
- Titsias & Shi. Double control variates for gradient estimation in discrete latent-variable models. AISTATS 2022.
- Shi, Sun & Zhu. A spectral approach to gradient estimation for implicit distributions. ICML 2018.
- Song, Garg, Shi & Ermon. Sliced score matching: A scalable approach to density and score estimation. UAI 2019.
- Zhou, Shi & Zhu. Nonparametric score estimators. ICML 2020.
30. ### References

- Stein, C. (1972). A bound for the error in the normal approximation to the distribution of a sum of dependent random variables. Proceedings of the Sixth Berkeley Symposium on Mathematical Statistics and Probability.
- Barbour, A. D. (1988). Stein’s method and Poisson process convergence. Journal of Applied Probability, 25(A), 175-184.
- Barbour, A. D. (1990). Stein’s method for diffusion approximations. Probability Theory and Related Fields, 84(3), 297-322.
- Gorham, J., & Mackey, L. (2015). Measuring sample quality with Stein’s method. Advances in Neural Information Processing Systems, 28.
- Liu, Q., & Wang, D. (2016). Stein variational gradient descent: A general purpose Bayesian inference algorithm. Advances in Neural Information Processing Systems, 29.
31. ### References

- Lim, J., Ryu, S., Kim, J. W., & Kim, W. Y. (2018). Molecular generative model based on conditional variational autoencoder for de novo molecular design. Journal of Cheminformatics, 10, 1-9.
- Wei, J., Wang, X., Schuurmans, D., Bosma, M., Xia, F., Chi, E., ... & Zhou, D. (2022). Chain-of-thought prompting elicits reasoning in large language models. Advances in Neural Information Processing Systems, 35, 24824-24837.
- Alamdari, S., Thakkar, N., van den Berg, R., Lu, A. X., Fusi, N., Amini, A. P., & Yang, K. K. (2023). Protein generation with evolutionary diffusion: sequence is all you need. bioRxiv, 2023-09.
- Silver, D., Schrittwieser, J., Simonyan, K., Antonoglou, I., Huang, A., Guez, A., ... & Hassabis, D. (2017). Mastering the game of Go without human knowledge. Nature, 550(7676), 354-359.
- Song, Y., Sohl-Dickstein, J., Kingma, D. P., Kumar, A., Ermon, S., & Poole, B. (2020). Score-based generative modeling through stochastic differential equations. arXiv preprint arXiv:2011.13456.