Jiaxin Shi (Google DeepMind, London, UK) Stein’s Method for Modern Machine Learning: From Gradient Estimation to Generative Modeling

Jia-Jie Zhu
March 25, 2024

WORKSHOP ON OPTIMAL TRANSPORT
FROM THEORY TO APPLICATIONS
INTERFACING DYNAMICAL SYSTEMS, OPTIMIZATION, AND MACHINE LEARNING
Venue: Humboldt University of Berlin, Dorotheenstraße 24

Berlin, Germany. March 11th - 15th, 2024

Transcript

  1. Stein’s Method for Modern Machine Learning: From Gradient Estimation to Generative Modeling

    Jiaxin Shi, Google DeepMind
    2024/3/14 @ OT-Berlin
    jiaxins.io
  2. Outline

    • Stein’s method: Foundations
    • Stein’s method and machine learning
      • Sampling
      • Gradient estimation
      • Score-based modeling
  3. Divergences between Probability Distributions

    [Figure labels: GANs, Diffusion models, Transformers]
    • How well does my model fit the data?
    • Parameter estimation by minimizing divergences
    • Sampling as optimization
    https://openai.com/blog/generative-models/
  4. Integral Probability Metrics (IPM)

    $$d_{\mathcal{H}}(q, p) = \sup_{h \in \mathcal{H}} \left| \mathbb{E}_q[h(X)] - \mathbb{E}_p[h(Y)] \right|$$
    • When $\mathcal{H}$ is sufficiently large, convergence in $d_{\mathcal{H}}(q_n, p)$ implies that $q_n$ converges weakly to $p$
    • Examples: Total variation distance, Wasserstein distance
    • Problem: Often $p$ is our model and integration under $p$ is intractable
    • Idea: Only consider functions with $\mathbb{E}_p[h(Y)] = 0$
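  5. Stein’s Method

    • Identify an operator $\mathcal{T}$ that generates mean-zero functions under the target distribution $p$: $\mathbb{E}_p[(\mathcal{T} g)(X)] = 0$ for all $g \in \mathcal{G}$
    • Define the Stein discrepancy: $\mathcal{S}(q, \mathcal{T}, \mathcal{G}) \triangleq \sup_{g \in \mathcal{G}} \mathbb{E}_q[(\mathcal{T} g)(X)] - \mathbb{E}_p[(\mathcal{T} g)(X)]$
    • Show that the Stein discrepancy is lower bounded by an IPM. For example, if for any $h \in \mathcal{H}$, a solution $g \in \mathcal{G}$ exists for the equation $h(x) - \mathbb{E}_p[h(Y)] = (\mathcal{T} g)(x)$, then $d_{\mathcal{H}}(q, p) \le \mathcal{S}(q, \mathcal{T}, \mathcal{G})$
    [Stein, 1972]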
  6. Identifying a Stein Operator: Stein’s Lemma

    If $p$ is a standard normal distribution, then
    $$\mathbb{E}_p[g'(X) - X g(X)] = 0 \quad \text{for all } g \in C_b^1.$$
    The corresponding Stein operator:
    $$(\mathcal{T} g)(x) = g'(x) - x g(x)$$
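Stein’s lemma on slide 6 can be checked numerically. The following is a minimal sketch, not part of the talk: it assumes the test function g = sin (bounded, with bounded derivative) and uses Gauss-Hermite quadrature from NumPy to evaluate expectations under N(0, 1). The helper names `gaussian_expectation` and `stein_operator` are made up for this illustration.

```python
import numpy as np
from numpy.polynomial.hermite_e import hermegauss


def gaussian_expectation(fn, num_nodes=60):
    """Approximate E[fn(X)] for X ~ N(0, 1) with Gauss-Hermite quadrature."""
    nodes, weights = hermegauss(num_nodes)  # weight function exp(-x^2 / 2)
    return np.sum(weights * fn(nodes)) / np.sqrt(2.0 * np.pi)


def stein_operator(g, g_prime):
    """Return the function x -> g'(x) - x * g(x)."""
    return lambda x: g_prime(x) - x * g(x)


# Hypothetical test function: g = sin, so g' = cos (both bounded).
tg = stein_operator(np.sin, np.cos)

# Under the target N(0, 1) the Stein operator output has mean zero.
mean_under_target = gaussian_expectation(tg)

# Under a shifted Gaussian N(1, 1), written as X = Z + 1, it does not.
mean_under_shifted = gaussian_expectation(lambda z: tg(z + 1.0))

print(mean_under_target, mean_under_shifted)
```

The second value is nonzero because the operator was built for N(0, 1), which is what lets a Stein discrepancy tell the two distributions apart.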
  7. Identifying a Stein Operator: Barbour’s generalization via stochastic processes

    • The (infinitesimal) generator $A$ of a stochastic process $(X_t)_{t \ge 0}$ is defined as
      $$(Af)(x) = \lim_{t \to 0} \frac{\mathbb{E}[f(X_t) \mid X_0 = x] - f(x)}{t}.$$
    • The generator of a stochastic process with stationary distribution $p$ satisfies $\mathbb{E}_p[(Af)(X)] = 0$.
    [Barbour, 1988 & 1990]
  8. Langevin Stein Operator

    • Langevin diffusion on $\mathbb{R}^d$: $dX_t = \nabla \log p(X_t)\,dt + \sqrt{2}\,dW_t$
    • Generator: $(Af)(x) = \nabla \log p(x)^\top \nabla f(x) + \nabla \cdot \nabla f(x)$
    • Convenient form with a vector-valued function $g : \mathbb{R}^d \to \mathbb{R}^d$: $(\mathcal{T}_p g)(x) = \nabla \log p(x)^\top g(x) + \nabla \cdot g(x)$
    • Depends on $p$ only through $\nabla \log p$, computable even for unnormalized $p$
    [Gorham & Mackey, 2015]
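The point that the Langevin Stein operator needs only the score can be illustrated in one dimension. This is a sketch under stated assumptions: the unnormalized target exp(-x^4 / 4), the test function g = sin, and the grid-based expectation are all choices made for this example, not taken from the slides.

```python
import numpy as np


def langevin_stein(score, g, g_prime):
    """1-D Langevin Stein operator: (T_p g)(x) = score(x) * g(x) + g'(x)."""
    return lambda x: score(x) * g(x) + g_prime(x)


def grid_expectation(fn, log_density, grid):
    """E[fn(X)] on a uniform grid; any normalizing constant cancels in the ratio."""
    weights = np.exp(log_density(grid))
    return np.sum(fn(grid) * weights) / np.sum(weights)


grid = np.linspace(-10.0, 10.0, 40001)
log_p = lambda x: -x ** 4 / 4.0          # unnormalized target, normalizer never computed
score_p = lambda x: -x ** 3              # gradient of log_p
log_q = lambda x: -x ** 2 / 2.0          # a different distribution, N(0, 1)

tg = langevin_stein(score_p, np.sin, np.cos)   # hypothetical test function g = sin
mean_under_p = grid_expectation(tg, log_p, grid)
mean_under_q = grid_expectation(tg, log_q, grid)
print(mean_under_p, mean_under_q)
```

The operator itself never touches the normalizing constant of p. The grid is used only to verify that the mean is zero under p and nonzero under q.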
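  9. Stein Operators and Sampling

    Find the direction $g_t$ that most quickly decreases the KL divergence to $p$, where particles $\theta_t \sim q_t$ move as
    $$\dot{\theta}_t = g_t(\theta_t), \qquad \frac{d}{dt} \mathrm{KL}(q_t \,\|\, p) = -\mathbb{E}_{q_t}[(\mathcal{T}_p g_t)(X)].$$
    [Liu & Wang, 2016]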
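  10. Wasserstein Gradient Flow and SVGD

    $$\inf_{g_t \in \mathcal{G}} \frac{d}{dt} \mathrm{KL}(q_t \,\|\, p) = -\sup_{g_t \in \mathcal{G}} \mathbb{E}_{q_t}[(\mathcal{T}_p g_t)(X)]$$
    • $\mathcal{G} = \mathcal{L}^2(q_t)$: Wasserstein Gradient Flow, $g_t^* \propto \nabla \log p - \nabla \log q_t$. Same density evolution as Langevin diffusion
    • $\mathcal{G}$ = RKHS of kernel $K$: Stein Variational Gradient Descent, $g_t^* \propto \mathbb{E}_{q_t}[K(\cdot, X) \nabla \log p(X) + \nabla_X \cdot K(\cdot, X)]$
    [Liu & Wang, 2016]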
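The SVGD direction on slide 10 becomes a particle update once the expectation under q_t is replaced by an average over particles. Below is a minimal NumPy sketch; the RBF kernel with fixed bandwidth 1.0, the step size 0.1, the particle count, and the Gaussian target are all assumptions chosen for illustration, not settings from the talk.

```python
import numpy as np


def svgd_direction(x, score, bandwidth):
    """Empirical SVGD direction at each particle for an RBF kernel."""
    n = x.shape[0]
    diffs = x[:, None, :] - x[None, :, :]                 # diffs[i, j] = x_i - x_j
    k = np.exp(-np.sum(diffs ** 2, axis=-1) / (2.0 * bandwidth ** 2))
    # Gradient of k(x_i, x_j) with respect to x_j: a repulsive term.
    grad_k = k[:, :, None] * diffs / bandwidth ** 2
    return (k @ score(x) + grad_k.sum(axis=1)) / n


mu = np.array([2.0, -1.0])
score = lambda x: -(x - mu)        # score of the target N(mu, I); no normalizer needed

rng = np.random.default_rng(0)
particles = 0.5 * rng.normal(size=(50, 2))
for _ in range(1000):
    particles = particles + 0.1 * svgd_direction(particles, score, bandwidth=1.0)

print(particles.mean(axis=0), particles.std(axis=0))
```

The first term pulls particles toward high-density regions of the target, and the kernel-gradient term keeps them from collapsing onto a single point.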
  11. Convergence Analysis of SVGD

    • Korba, A., Salim, A., Arbel, M., Luise, G., & Gretton, A. A non-asymptotic analysis for Stein variational gradient descent. NeurIPS (2020).
    • Chewi, S., Le Gouic, T., Lu, C., Maunu, T., & Rigollet, P. SVGD as a kernelized Wasserstein gradient flow of the chi-squared divergence. NeurIPS (2020).
    • Shi, J., & Mackey, L. A finite-particle convergence rate for Stein variational gradient descent. NeurIPS (2023). Convergence rate for discrete-time, finite-particle SVGD.
  12. The Gradient Estimation Problem

    $$\nabla_\eta \mathbb{E}_{q_\eta}[f(X)]$$
    A common problem in training generative models and reinforcement learning
    [Figure labels: Encoder, Decoding error] [Lim et al., 2018]
  13. The Gradient Estimation Problem

    $$\nabla_\eta \mathbb{E}_{q_\eta}[f(X)]$$
    A common problem in training generative models and reinforcement learning
    [Figure labels: Policy, Reward Model] [Lim et al., 2018]
    https://aws.amazon.com/de/what-is/reinforcement-learning-from-human-feedback/
  14. Discrete Gradient Estimation

    • Discrete data, states, and actions [Wei et al., 2022] [Silver et al., 2017] [Alamdari et al., 2023]
    • Computing exact gradients is often intractable:
      $$\nabla_\eta \mathbb{E}_{q_\eta}[f(X)] = \nabla_\eta \sum_{x \in \{0,1\}^d} q_\eta(x) f(x)$$
      Here $x$ is a $d$-dimensional binary vector, the sum runs over $2^d$ configurations and is intractable, and $f$ is a complex, nonlinear function.
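To make the cost of the exact gradient on slide 14 concrete, here is a sketch that enumerates all 2^d binary vectors for a small d. It assumes a factorized Bernoulli q_eta with logits eta and a toy separable f; both are hypothetical choices for this example, picked so that a closed-form gradient is available for comparison.

```python
import itertools
import numpy as np


def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))


def f(x):
    """Hypothetical toy objective; any function of a binary vector would do."""
    return np.sum((x - 0.45) ** 2)


def exact_gradient(eta, fn):
    """Sum fn(x) * grad_eta q_eta(x) over all 2^d binary vectors x."""
    probs = sigmoid(eta)
    grad = np.zeros_like(eta)
    for bits in itertools.product([0.0, 1.0], repeat=eta.shape[0]):
        x = np.array(bits)
        q_x = np.prod(np.where(x == 1.0, probs, 1.0 - probs))
        # grad_eta q_eta(x) = q_eta(x) * grad_eta log q_eta(x) = q_eta(x) * (x - probs)
        grad += fn(x) * q_x * (x - probs)
    return grad


eta = np.linspace(-1.0, 1.0, 8)     # d = 8, so 2^8 = 256 terms
grad = exact_gradient(eta, f)

# For this separable toy f the gradient also has a closed form to compare against.
probs = sigmoid(eta)
closed_form = probs * (1.0 - probs) * (0.55 ** 2 - 0.45 ** 2)
print(np.max(np.abs(grad - closed_form)))
```

The loop doubles in length with every extra dimension, which is why sampling-based estimators are needed for realistic d.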
  15. Gradient Estimation and Variance Reduction

    REINFORCE (high variance!):
    $$\hat{g}_1 = \frac{1}{K} \sum_{k=1}^{K} f(x_k) \nabla_\eta \log q_\eta(x_k)$$
    Control variates:
    $$\hat{g}_2 = \frac{1}{K} \sum_{k=1}^{K} \left[ f(x_k) \nabla_\eta \log q_\eta(x_k) + cv(x_k) \right] - \mathbb{E}_{q_\eta}[cv(X)]$$
    • Strong correlation is required for effective variance reduction
    • Fundamental tradeoff: $cv$ needs to be very flexible but still have an analytic expectation under $q_\eta$
    With a Stein operator $A$, the expectation term is zero:
    $$\hat{g}_2 = \frac{1}{K} \sum_{k=1}^{K} \left[ f(x_k) \nabla_\eta \log q_\eta(x_k) + (Ah)(x_k) \right] - \underbrace{\mathbb{E}_{q_\eta}[(Ah)(X)]}_{= 0}$$
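The effect of a control variate on the REINFORCE estimator from slide 15 can be seen on a toy problem. This sketch does not use a Stein operator: it swaps in the simplest mean-zero control variate, a constant baseline b times the score function. The factorized Bernoulli q_eta, the toy f, and the value of b are all assumptions made for illustration.

```python
import numpy as np

rng = np.random.default_rng(0)
d, K, repeats = 8, 10, 2000
eta = np.linspace(-1.0, 1.0, d)
probs = 1.0 / (1.0 + np.exp(-eta))

# Draw `repeats` independent batches of K samples from the factorized Bernoulli q_eta.
x = (rng.random((repeats, K, d)) < probs).astype(float)
fx = np.sum((x - 0.45) ** 2, axis=-1)          # hypothetical toy f, shape (repeats, K)
score = x - probs                              # grad_eta log q_eta(x), shape (repeats, K, d)

# REINFORCE estimator g1, one estimate per batch.
g1 = np.mean(fx[..., None] * score, axis=1)

# Control variate cv(x) = -b * grad_eta log q_eta(x) with a constant baseline b.
# Its expectation under q_eta is exactly zero, so the correction term vanishes.
b = 0.25 * d
cv = -b * score
g2 = np.mean(fx[..., None] * score + cv, axis=1) - 0.0

exact = 0.1 * probs * (1.0 - probs)            # closed form for this toy f
var_g1 = g1.var(axis=0).sum()
var_g2 = g2.var(axis=0).sum()
print(var_g1, var_g2)
```

Both estimators are unbiased, but the second has much lower variance because the control variate is strongly correlated with the REINFORCE term. A constant baseline is the least flexible choice; the slide's point is that Stein operators give a far richer family with the same zero-mean guarantee.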
  16. Discrete Stein Operators

    How: Apply Barbour’s idea to discrete-state Markov chains.
    • Transfer operator $K$: $\mathbb{E}_q[((K - I)h)(X)] = 0$
    • Cont. time, generator $A$: $\mathbb{E}_q[(Ah)(X)] = 0$
    Shi, Zhou, Hwang, Titsias & Mackey. Gradient estimation with discrete Stein operators, NeurIPS 2022 Outstanding Paper.
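The identity E_q[((K - I)h)(X)] = 0 on slide 16 holds for any Markov transition matrix K whose stationary distribution is q. The sketch below checks it on a six-state space with a generic Metropolis chain using uniform proposals. This is an illustrative assumption and not necessarily the construction used in the cited paper; `metropolis_kernel` is a made-up helper name.

```python
import numpy as np


def metropolis_kernel(q):
    """Transition matrix of a Metropolis chain with uniform proposals and stationary law q."""
    n = q.shape[0]
    accept = np.minimum(1.0, q[None, :] / q[:, None])   # accept[i, j] = min(1, q_j / q_i)
    K = accept / n
    np.fill_diagonal(K, 0.0)
    np.fill_diagonal(K, 1.0 - K.sum(axis=1))            # rejected moves stay put
    return K


rng = np.random.default_rng(0)
n = 6
q = rng.random(n) + 0.1
q /= q.sum()
K = metropolis_kernel(q)
A = K - np.eye(n)                # also a valid rate matrix of a continuous-time chain

h = rng.normal(size=n)           # arbitrary test function on the state space
mean_under_q = q @ (A @ h)

r = np.full(n, 1.0 / n)          # a different distribution
mean_under_r = r @ (A @ h)
print(mean_under_q, mean_under_r)
```

Because h is arbitrary, (K - I)h gives a flexible family of functions with known zero mean under q, which is exactly what a control variate needs.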
  17. Experiments: Training Binary Latent VAEs

    [Figure labels: Ours, Prior SOTA, ~40% reduction]
    Shi, Zhou, Hwang, Titsias & Mackey. Gradient estimation with discrete Stein operators, NeurIPS 2022 Outstanding Paper.
  18. Model fitting: Stein Discrepancy as a Learning Rule

    Data distribution $q$, model distribution $p_\theta$:
    $$\min_\theta \left| \mathbb{E}_q[h(x)^\top \nabla_x \log p_\theta(x) + \nabla \cdot h(x)] \right|$$
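  19. Model fitting

    Data distribution $q$, model distribution $p_\theta$:
    $$\min_\theta \sup_{h \in L^2(q)} \left| \mathbb{E}_q[h(x)^\top \nabla_x \log p_\theta(x) + \nabla \cdot h(x)] \right| \;\to\; \min_\theta \mathbb{E}_{q_{\mathrm{data}}}\left[ \|\nabla \log p_\theta(x) - \nabla \log q_{\mathrm{data}}(x)\|^2 \right]$$
    Score Matching [Hyvärinen, 2005]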
  20. Training Energy-Based Models

    $$p_\theta(x) = \frac{e^{-E_\theta(x)}}{Z_\theta}$$
    Key insight: The score does not depend on the normalizing constant $Z_\theta$:
    $$\nabla_x \log p_\theta(x) = -\nabla_x E_\theta(x) - \nabla_x \log Z_\theta = -\nabla_x E_\theta(x),$$
    since $Z_\theta$ does not depend on $x$.
    • Score matching is more suitable for training such models than maximum likelihood!
    [Figure labels: energy $E_\theta(x)$ against $x$; performance against time for Sliced Score Matching]
    Song*, Garg*, Shi & Ermon. Sliced score matching: A scalable approach to density and score estimation. UAI 2019
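As a small worked case of score matching for an energy-based model, take the quadratic energy E(x) = lam * (x - mu)^2 / 2, whose score is s(x) = -lam * (x - mu). The sketch below uses the integration-by-parts form of the score matching objective from Hyvärinen (2005), E[0.5 * s(x)^2 + s'(x)], which avoids the unknown data score. The quadratic energy and the synthetic Gaussian data are assumptions for this example; sliced score matching itself is not implemented here.

```python
import numpy as np


def score_matching_loss(mu, lam, x):
    """Empirical Hyvarinen objective E[0.5 * s(x)^2 + s'(x)] for s(x) = -lam * (x - mu)."""
    s = -lam * (x - mu)
    ds_dx = -lam
    return np.mean(0.5 * s ** 2) + ds_dx


rng = np.random.default_rng(0)
x = rng.normal(loc=1.5, scale=2.0, size=5000)   # synthetic data, true precision 0.25

# Setting the gradient of the loss to zero gives the minimizer in closed form.
mu_hat = x.mean()
lam_hat = 1.0 / x.var()
best = score_matching_loss(mu_hat, lam_hat, x)

print(mu_hat, lam_hat, best)
```

The normalizing constant of the model never appears: the loss is written entirely in terms of the score and its derivative.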
  21. Score-Based Modeling

    Idea: Model the score $s := \nabla \log p$ instead of the density
    Advantages:
    1. less computation than energy-based modeling
    2. enable more flexible models
    Nonparametric Score Model:
    $$\min_{s \in \mathcal{H}} \mathbb{E}_{q_{\mathrm{data}}} \|s(x) - \nabla \log q_{\mathrm{data}}(x)\|^2 + \frac{\lambda}{2} \|s\|_{\mathcal{H}}^2$$
    The spectral estimator (Shi et al., 18) is a special case.
    Zhou, Shi & Zhu. Nonparametric score estimators. ICML 2020
  22. A Spectral Method for Score Estimation

    Density gradients (score):
    $$\nabla_x \log q(x) = -\sum_{j \ge 1} \mathbb{E}_q[\nabla \psi_j(x)]\, \psi_j(x)$$
    Eigenfunction:
    $$\mathbb{E}_{x' \sim q}[k(x, x') \psi_j(x')] = \lambda_j \psi_j(x)$$
    Coefficients:
    $$\langle \nabla \log q, \psi_j \rangle_{L^2(q)} = -\mathbb{E}_q[\nabla \psi_j(x)]$$
    Shi, Sun & Zhu. A spectral approach to gradient estimation for implicit distributions. ICML 2018
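The expansion on slide 22 can be turned into an estimator by approximating the eigenfunctions from samples. The sketch below does this in one dimension with a Nyström approximation built from the kernel Gram matrix. The RBF kernel, its bandwidth, the number of eigenfunctions, and the sample size are assumptions for this example, and the code is a simplified illustration rather than the authors' implementation.

```python
import numpy as np


def rbf(x, y, sigma):
    """Kernel matrix k(x_i, y_j) for 1-D inputs."""
    return np.exp(-(x[:, None] - y[None, :]) ** 2 / (2.0 * sigma ** 2))


def spectral_score_estimator(samples, num_eigen, sigma):
    """Return a function estimating d/dx log q(x) from i.i.d. samples of q."""
    m = samples.shape[0]
    gram = rbf(samples, samples, sigma)
    eigvals, eigvecs = np.linalg.eigh(gram)          # eigenvalues in ascending order
    eigvals = eigvals[::-1][:num_eigen]              # keep the largest ones
    eigvecs = eigvecs[:, ::-1][:, :num_eigen]        # shape (m, num_eigen)

    def eigenfunctions(x):
        # Nystrom: psi_j(x) = sqrt(m) / lambda_j * sum_i u_ij * k(x, x_i)
        return np.sqrt(m) * rbf(x, samples, sigma) @ eigvecs / eigvals

    # d/dx k(x, x_i) evaluated at x = x_a, stored in dgram[a, i].
    dgram = -(samples[:, None] - samples[None, :]) / sigma ** 2 * gram
    dpsi = np.sqrt(m) * dgram @ eigvecs / eigvals    # psi_j'(x_a), shape (m, num_eigen)
    beta = -dpsi.mean(axis=0)                        # beta_j = -E_q[psi_j'(x)]

    return lambda x: eigenfunctions(x) @ beta


rng = np.random.default_rng(0)
samples = rng.normal(size=1000)                      # q = N(0, 1), treated as unknown
score_fn = spectral_score_estimator(samples, num_eigen=6, sigma=1.0)

query = np.array([-1.0, 0.0, 1.0])
estimate = score_fn(query)
print(estimate)   # the true score of N(0, 1) at these points is [1, 0, -1]
```

Only samples from q are used; the density is never evaluated. Truncating the sum at a small number of eigenfunctions acts as regularization.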
  23. A Spectral Method for Score Estimation

    Given samples $\{x^j\}_{j=1}^{M} \overset{\text{i.i.d.}}{\sim} q$ from an unknown $q(x)$, estimate the score function $\nabla_x \log q(x)$.
    Shi, Sun & Zhu. A spectral approach to gradient estimation for implicit distributions. ICML 2018
    Zhou, Shi & Zhu. Nonparametric score estimators. ICML 2020
  24. Score-Based Modeling

    Idea: Model the score $s := \nabla \log p$ instead of the density
    Advantages:
    1. less computation than energy-based modeling
    2. enable more flexible models
    Nonparametric Score Model:
    $$\min_{s \in \mathcal{H}} \mathbb{E}_{q_{\mathrm{data}}} \|s(x) - \nabla \log q_{\mathrm{data}}(x)\|^2 + \frac{\lambda}{2} \|s\|_{\mathcal{H}}^2$$
    The spectral estimator (Shi et al., 18) is a special case.
    Score Network:
    $$\min_\theta \mathbb{E}_{q_{\mathrm{data}}} \|s_\theta(x) - \nabla \log q_{\mathrm{data}}(x)\|^2, \qquad s_\theta(x) \approx \nabla \log q_{\mathrm{data}}(x)$$
    Use neural networks to model the score, trained by sliced score matching.
    Song*, Garg*, Shi & Ermon. Sliced score matching: A scalable approach to density and score estimation. UAI 2019
    Zhou, Shi & Zhu. Nonparametric score estimators. ICML 2020
  25. From Score Networks to Diffusion Models

    Updates produced by score networks transform noise to data. [Song et al., ICLR’20]
    Images created by OpenAI’s DALLE-2. DALLE-2 is based on diffusion models.
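How score-driven updates move noise toward data can be sketched with discretized Langevin dynamics. This is a deliberately simplified stand-in: an analytic Gaussian score plays the role of a trained score network, and the step size, step count, and target are assumptions for this example. Actual diffusion models use noise-conditional scores and a reverse-time process, which this sketch does not implement.

```python
import numpy as np


def langevin_sample(score, x0, step, num_steps, rng):
    """Unadjusted Langevin updates: x <- x + step * score(x) + sqrt(2 * step) * noise."""
    x = x0.copy()
    for _ in range(num_steps):
        x = x + step * score(x) + np.sqrt(2.0 * step) * rng.normal(size=x.shape)
    return x


mu, sigma = 3.0, 0.5
score = lambda x: -(x - mu) / sigma ** 2     # analytic score of N(mu, sigma^2)

rng = np.random.default_rng(0)
noise = rng.normal(size=5000)                # start from standard normal noise
samples = langevin_sample(score, noise, step=0.01, num_steps=2000, rng=rng)
print(samples.mean(), samples.std())
```

The finite step size introduces a small bias in the spread of the samples, which shrinks as the step size goes to zero.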
  26. Open Problems

    • Improving finite-particle rates of SVGD
    • Approximately solving the Stein equation for improved gradient estimation
    • Lower bounding the discrete Stein discrepancy
    • Learning the features in nonparametric score models
    • Finding the “right” discrete correspondence of the score matching objective
  27. Main References

    Joint work with Lester Mackey, Yuhao Zhou, Jessica Hwang, Michalis K. Titsias, Shengyang Sun, Jun Zhu, Yang Song, Sahaj Garg, Stefano Ermon
    • Shi & Mackey. A finite-particle convergence rate for Stein variational gradient descent. NeurIPS 2023.
    • Shi, Zhou, Hwang, Titsias, & Mackey. Gradient estimation with discrete Stein operators. NeurIPS 2022.
    • Titsias & Shi. Double control variates for gradient estimation in discrete latent-variable models. AISTATS 2022.
    • Shi, Sun, & Zhu. A spectral approach to gradient estimation for implicit distributions. ICML 2018.
    • Song, Garg, Shi, & Ermon. Sliced score matching: A scalable approach to density and score estimation. UAI 2019.
    • Zhou, Shi, & Zhu. Nonparametric score estimators. ICML 2020.
  28. References

    • Stein, C. (1972). A bound for the error in the normal approximation to the distribution of a sum of dependent random variables.
    • Barbour, A. D. (1988). Stein's method and Poisson process convergence. Journal of Applied Probability, 25(A), 175-184.
    • Barbour, A. D. (1990). Stein's method for diffusion approximations. Probability Theory and Related Fields, 84(3), 297-322.
    • Gorham, J., & Mackey, L. (2015). Measuring sample quality with Stein's method. Advances in Neural Information Processing Systems, 28.
    • Liu, Q., & Wang, D. (2016). Stein variational gradient descent: A general purpose Bayesian inference algorithm. Advances in Neural Information Processing Systems, 29.
  29. References

    • Lim, J., Ryu, S., Kim, J. W., & Kim, W. Y. (2018). Molecular generative model based on conditional variational autoencoder for de novo molecular design. Journal of Cheminformatics, 10, 1-9.
    • Wei, J., Wang, X., Schuurmans, D., Bosma, M., Xia, F., Chi, E., ... & Zhou, D. (2022). Chain-of-thought prompting elicits reasoning in large language models. Advances in Neural Information Processing Systems, 35, 24824-24837.
    • Alamdari, S., Thakkar, N., van den Berg, R., Lu, A. X., Fusi, N., Amini, A. P., & Yang, K. K. (2023). Protein generation with evolutionary diffusion: sequence is all you need. bioRxiv, 2023-09.
    • Silver, D., Schrittwieser, J., Simonyan, K., Antonoglou, I., Huang, A., Guez, A., ... & Hassabis, D. (2017). Mastering the game of Go without human knowledge. Nature, 550(7676), 354-359.
    • Song, Y., Sohl-Dickstein, J., Kingma, D. P., Kumar, A., Ermon, S., & Poole, B. (2020). Score-based generative modeling through stochastic differential equations. arXiv preprint arXiv:2011.13456.