Convergence theory and application of distribution optimization: Non-convexity, particle approximation, and diffusion models

Taiji Suzuki

ICSP 2025 invited session


Transcript

  1. Convergence theory and application of distribution optimization: Non-convexity, particle approximation,
    and diffusion models. Taiji Suzuki, The University of Tokyo / AIP-RIKEN (Deep learning theory team). 28 July 2025, ICSP2025@Paris.

  2. Probability measure optimization (convex / non-convex). Applications: training a neural network in
    the mean-field regime; training a transformer for in-context learning; fine-tuning a generative model. Part 1: convex setting (mean-field NN). Part 2: non-convex setting.

  3. Presentation overview. Part 1: Convex (propagation of chaos). Mean-field Langevin dynamics for
    $F(\mu) + \lambda\,\mathrm{Ent}(\mu)$: optimization of probability measures by Wasserstein gradient flow, particle approximation, and a defective log-Sobolev inequality. Reference: Nitanda, Lee, Kai, Sakaguchi, Suzuki. Propagation of Chaos for Mean-Field Langevin Dynamics and its Application to Model Ensemble. ICML2025. Part 2: Non-convex (strict-saddle objective). (1) Gaussian process perturbation: a polynomial-time method to escape a saddle point by a "random perturbation" of a probability measure. (2) In-context learning: the objective for training a two-layer transformer is strict-saddle. References: Kim, Suzuki. Transformers Learn Nonlinear Features In Context: Nonconvex Mean-field Dynamics on the Attention Landscape. ICML2024 (oral). Yamamoto, Kim, Suzuki. Hessian-guided Perturbed Wasserstein Gradient Flows for Escaping Saddle Points. 2025.

  4. Mean-field Langevin. Linear convergence of mean-field Langevin dynamics: Nitanda, Wu, Suzuki
    (AISTATS2022); Chizat (TMLR2022). Uniform-in-time propagation of chaos: super log-Sobolev inequality, Suzuki, Nitanda, Wu (ICLR2023); Chen, Ren, Wang. Uniform-in-time propagation of chaos for mean field Langevin dynamics. 2022 (arXiv:2212.03050). Convergence analysis with a finite-particle / discrete-time algorithm (particle approximation): Suzuki, Nitanda, Wu (NeurIPS2023). Nitanda, Lee, Kai, Sakaguchi, Suzuki: Propagation of Chaos for Mean-Field Langevin Dynamics and its Application to Model Ensemble. ICML2025.

  5. Distribution optimization. Objective: convex optimization on the probability measure space,
    minimizing a nonlinear convex functional of $\mu$ plus a strictly convex regularizer (convex + strictly convex = strictly convex). Mean-field neural network: the finite-width model $\frac{1}{M}\sum_{j=1}^{M} h_{\theta_j}(x)$ has the mean-field limit ($M \to \infty$) $h_\mu(x) = \int h_\theta(x)\, d\mu(\theta)$, which is linear with respect to $\mu$; hence a convex loss of $h_\mu$ is convex w.r.t. $\mu$! [Nitanda&Suzuki, 2017][Chizat&Bach, 2018][Mei, Montanari&Nguyen, 2018][Rotskoff&Vanden-Eijnden, 2018] (A toy illustration follows below.)

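To make the mean-field limit concrete, here is a minimal numerical sketch, a toy with assumed tanh neurons and synthetic data rather than the setting of the talk: the network output is the average of $M$ neuron functions, i.e. the integral of $h_\theta$ against the empirical measure of the parameters, so mixing two parameter measures mixes the outputs linearly.

```python
import numpy as np

rng = np.random.default_rng(0)
d, M, n = 5, 1000, 200                    # input dim, number of neurons, number of data points

# Neuron function h_theta(x) = tanh(theta^T x); the mean-field network averages h_theta
# over the (empirical) parameter distribution mu.
def h(theta, X):                          # theta: (M, d), X: (n, d) -> (M, n)
    return np.tanh(theta @ X.T)

def h_mu(theta, X):                       # network output = integral of h_theta d mu_hat
    return h(theta, X).mean(axis=0)

X = rng.normal(size=(n, d))
y = np.sin(X[:, 0])                       # toy regression target

theta = rng.normal(size=(M, d))           # particles = neuron parameters, mu_hat = (1/M) sum of Diracs
loss = 0.5 * np.mean((y - h_mu(theta, X)) ** 2)
print(f"squared loss at a random mu_hat (M = {M}): {loss:.4f}")

# Linearity in mu: mixing two parameter measures mixes the network outputs linearly.
theta2 = rng.normal(size=(M, d))
mix = 0.5 * h_mu(theta, X) + 0.5 * h_mu(theta2, X)            # output under 0.5*mu1 + 0.5*mu2
direct = h(np.vstack([theta, theta2]), X).mean(axis=0)        # empirical measure of the union
print("max mixture mismatch:", np.abs(mix - direct).max())    # ~0 up to floating point error
```
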
  6. Mean-field Langevin. Objective: $\min_\mu F(\mu) = L(\mu) + \lambda\,\mathrm{Ent}(\mu)$, with $L$ convex and
    the entropy term strictly convex (convex + strictly convex = strictly convex). Def (first variation): $\frac{\delta L(\mu)}{\delta\mu}$ is the functional derivative satisfying $\lim_{\epsilon\to 0}\frac{L(\mu+\epsilon(\nu-\mu))-L(\mu)}{\epsilon} = \int \frac{\delta L(\mu)}{\delta\mu}(x)\, d(\nu-\mu)(x)$ for all $\nu$. Mean-field Langevin dynamics (a distribution-dependent SDE): $dX_t = -\nabla\frac{\delta L(\mu_t)}{\delta\mu}(X_t)\,dt + \sqrt{2\lambda}\,dW_t$ with $\mu_t = \mathrm{Law}(X_t)$; the drift $\nabla\frac{\delta L(\mu_t)}{\delta\mu}$ is the gradient appearing in the Wasserstein gradient flow of $F$ (sketched numerically below). [Hu et al. 2019][Nitanda, Wu, Suzuki, 2022][Chizat, 2022] [Nitanda&Suzuki, 2017][Chizat&Bach, 2018][Mei, Montanari&Nguyen, 2018][Rotskoff&Vanden-Eijnden, 2018]

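The following sketch instantiates the first variation and its gradient for an assumed toy loss (mean-field regression with squared error and an L2 term with weight r; all numerical choices are hypothetical). The finite-difference check at the end verifies the defining identity of the first variation, and grad_first_variation is the drift a mean-field Langevin step would use.

```python
import numpy as np

rng = np.random.default_rng(1)
d, M, n, r = 3, 400, 100, 0.1            # dims, particles, data points, L2 weight (hypothetical)

X = rng.normal(size=(n, d))
y = np.sin(X[:, 0])

def h_mu(th):                            # network output under the empirical measure of th
    return np.tanh(th @ X.T).mean(axis=0)

def first_variation(th, theta_new):
    """delta L / delta mu evaluated at a new parameter theta_new."""
    resid = y - h_mu(th)
    return -np.mean(resid * np.tanh(theta_new @ X.T)) + 0.5 * r * np.sum(theta_new ** 2)

def grad_first_variation(th, theta_new):
    """nabla_theta of delta L / delta mu: the drift used by a mean-field Langevin step."""
    resid = y - h_mu(th)
    s = np.tanh(theta_new @ X.T)
    return -((resid * (1 - s ** 2)) @ X) / n + r * theta_new

theta = rng.normal(size=(M, d))          # particles representing mu_hat
theta_new = rng.normal(size=d)

# Finite-difference check of the defining identity of the first variation:
# (L(mu + eps*(delta_{theta_new} - mu)) - L(mu)) / eps -> dL/dmu(theta_new) - int dL/dmu dmu.
eps = 1e-5
out_mu = h_mu(theta)
out_mix = (1 - eps) * out_mu + eps * np.tanh(theta_new @ X.T)
L_mu = 0.5 * np.mean((y - out_mu) ** 2) + 0.5 * r * np.mean(np.sum(theta ** 2, axis=1))
L_mix = 0.5 * np.mean((y - out_mix) ** 2) + 0.5 * r * (
    (1 - eps) * np.mean(np.sum(theta ** 2, axis=1)) + eps * np.sum(theta_new ** 2))
lhs = (L_mix - L_mu) / eps
rhs = first_variation(theta, theta_new) - np.mean([first_variation(theta, t) for t in theta])
print(f"finite difference: {lhs:.6f}   first-variation formula: {rhs:.6f}")
print("drift at theta_new:", grad_first_variation(theta, theta_new))
```
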
  7. Entropy sandwich. Proximal Gibbs measure: $p_\mu(x) \propto \exp\!\big(-\tfrac{1}{\lambda}\tfrac{\delta L(\mu)}{\delta\mu}(x)\big)$.
    Theorem (Entropy sandwich) [Nitanda, Wu, Suzuki (AISTATS2022)][Chizat (2022)]: $\lambda\,\mathrm{KL}(\mu \,\|\, \mu^*) \le F(\mu) - F(\mu^*) \le \lambda\,\mathrm{KL}(\mu \,\|\, p_\mu)$. Assumption: log-Sobolev inequality of $p_\mu$; under the LSI of $p_{\mu_t}$ with constant $\alpha$, the mean-field Langevin dynamics converges linearly, $F(\mu_t) - F(\mu^*) \le e^{-2\alpha\lambda t}\,(F(\mu_0) - F(\mu^*))$. (A grid-based check follows.)

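As a sanity check of the sandwich, the sketch below discretizes a 1-D toy version of the problem on a grid (hypothetical values of the temperature lam and the L2 weight r), builds the proximal Gibbs measure p_mu, and verifies a consequence of the upper bound: for any nu, F(mu) - F(nu) <= lam * KL(mu || p_mu), here with nu = p_mu.

```python
import numpy as np

rng = np.random.default_rng(2)
lam, r, n = 0.05, 0.1, 100               # temperature, L2 weight, data size (hypothetical)
x = rng.normal(size=n)
y = np.sin(x)

# 1-D parameter grid; a measure is represented by its density rho on the grid.
theta = np.linspace(-6.0, 6.0, 2001)
dth = theta[1] - theta[0]
H = np.tanh(np.outer(theta, x))          # H[k, i] = h_{theta_k}(x_i)

normalize = lambda rho: rho / (rho.sum() * dth)

def L(rho):                              # convex loss functional of the measure
    pred = dth * rho @ H                 # h_mu(x_i) = int h_theta(x_i) dmu(theta)
    return 0.5 * np.mean((y - pred) ** 2) + 0.5 * r * dth * np.sum(rho * theta ** 2)

def F(rho):                              # objective: L + lam * int rho log rho
    return L(rho) + lam * dth * np.sum(rho * np.log(rho + 1e-300))

def first_variation(rho):                # delta L / delta mu on the grid
    resid = y - dth * rho @ H
    return -(H @ resid) / n + 0.5 * r * theta ** 2

def proximal_gibbs(rho):                 # p_mu ~ exp(-(1/lam) * delta L / delta mu)
    g = first_variation(rho)
    return normalize(np.exp(-(g - g.min()) / lam))   # shift by the min for numerical stability

mu = normalize(np.exp(-0.5 * (theta - 1.0) ** 2))    # an arbitrary reference measure
p = proximal_gibbs(mu)
kl = dth * np.sum(mu * np.log((mu + 1e-300) / (p + 1e-300)))

# Consequence of the sandwich: for any nu (here nu = p_mu), F(mu) - F(nu) <= lam * KL(mu || p_mu).
print(f"F(mu) - F(p_mu) = {F(mu) - F(p):.5f}   lam * KL(mu || p_mu) = {lam * kl:.5f}")
```
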
  8. Log-Sobolev inequality. Theorem (Bakry-Emery criterion) [Bakry and Emery, 1985]: if $p \propto \exp(-L(x))$
    with $L(x)$ $c$-strongly convex, then $p$ satisfies the LSI with constant $\alpha = c$. Theorem (Holley-Stroock bounded perturbation lemma) [Holley and Stroock, 1987]: if $q$ satisfies the LSI with constant $\alpha'$ and $p \propto e^{h} q$ with $\|h\|_\infty \le B$, then $p$ satisfies the LSI with constant $\alpha \ge \alpha' e^{-2B}$. Issue: $\alpha$ can easily be $\exp(-O(d))$ (e.g., a Gaussian mixture); an arithmetic illustration follows.

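A small arithmetic illustration of the issue, with assumed constants: start from a 1-strongly log-concave reference (so Bakry-Emery gives alpha' = 1) and perturb by a bounded function whose sup-norm grows linearly in d, the situation the slide attributes to Gaussian mixtures; the Holley-Stroock guarantee then decays exponentially in d.

```python
import numpy as np

# Order-of-magnitude illustration with assumed constants: the reference q is 1-strongly
# log-concave, so the Bakry-Emery criterion gives an LSI constant alpha' = 1; the target
# p ~ exp(h) q has a bounded perturbation with ||h||_inf = B growing linearly in d.
# Holley-Stroock then guarantees only alpha >= alpha' * exp(-2B).
alpha_ref = 1.0
for d in [1, 5, 10, 20, 50]:
    B = 0.5 * d                           # assumed linear growth of the perturbation size
    alpha = alpha_ref * np.exp(-2 * B)
    print(f"d = {d:3d}   guaranteed LSI constant alpha >= {alpha:.3e}")
# The guaranteed constant decays like exp(-O(d)).
```
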
  9. Practical algorithm. Space discretization: $\mu_t$ is approximated by $M$ particles,
    $\mu_t \simeq \hat{\mu}_t = \frac{1}{M}\sum_{i=1}^{M}\delta_{X_t^{(i)}}$, where $(X_t^{(i)})_{i=1}^{M}$ are the $M$ particles. A naive application of Gronwall's inequality yields Error $= \Omega(\exp(t)/M)$ (not uniform in time). (A toy simulation follows.)

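A minimal finite-particle, discrete-time sketch of this space discretization (toy data; step size, temperature, and L2 weight are hypothetical): M particles follow noisy gradient steps driven by the gradient of the first variation evaluated at the empirical measure.

```python
import numpy as np

rng = np.random.default_rng(3)
d, M, n = 3, 300, 200
lam, r, eta, T = 0.01, 0.1, 0.05, 500      # temperature, L2 weight, step size, steps (hypothetical)

X = rng.normal(size=(n, d))
y = np.sin(X[:, 0])

def loss(theta):
    """L(mu_hat) for the empirical measure of the particles (entropy term not shown)."""
    pred = np.tanh(theta @ X.T).mean(axis=0)
    return 0.5 * np.mean((y - pred) ** 2) + 0.5 * r * np.mean(np.sum(theta ** 2, axis=1))

def drift(theta):
    """grad (delta L / delta mu_hat) evaluated at every particle."""
    s = np.tanh(theta @ X.T)                        # (M, n)
    resid = y - s.mean(axis=0)                      # (n,)
    return -((resid * (1 - s ** 2)) @ X) / n + r * theta

theta = rng.normal(size=(M, d))                     # particles X_0^(i)
for t in range(T):
    # Discrete-time, finite-particle mean-field Langevin step (Euler-Maruyama):
    # X_{t+1}^(i) = X_t^(i) - eta * grad(dL/dmu_hat)(X_t^(i)) + sqrt(2*lam*eta) * N(0, I)
    theta = theta - eta * drift(theta) + np.sqrt(2 * lam * eta) * rng.normal(size=theta.shape)
    if t % 100 == 0:
        print(f"step {t:4d}   L(mu_hat) = {loss(theta):.4f}")
print(f"final       L(mu_hat) = {loss(theta):.4f}")
```
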
  10. Propagation of chaos (space discretization). Assumption: (smoothness) $\big\|\nabla\frac{\delta L(\mu)}{\delta\mu}(x) - \nabla\frac{\delta L(\nu)}{\delta\mu}(y)\big\| \le C\,(W_2(\mu,\nu) + \|x-y\|)$
    and (boundedness) $\big\|\nabla\frac{\delta L(\mu)}{\delta\mu}(x)\big\| \le R$. Prop [Chen, Ren, Wang, 22][Suzuki, Wu, Nitanda, 23] (existing result): suppose that $p_\mu$ satisfies the log-Sobolev inequality with a constant $\alpha$; then, under the smoothness and boundedness of the loss function, the approximation error of the $M$-particle system (with $\mu^{(M)}$ the joint distribution of the $M$ particles) is bounded uniformly in time. However, the bound depends on $1/\alpha$, and $\alpha$ can be like $\exp(-O(d))$. [Suzuki, Wu, Nitanda: Convergence of mean-field Langevin dynamics: Time and space discretization, stochastic gradient, and variance reduction. arXiv:2306.07221] [Chen, Ren, Wang: Uniform-in-time propagation of chaos for mean field Langevin dynamics. arXiv:2212.03050, 2022.]

  11. Defective LSI and entropy sandwich. Theorem (Defective entropy sandwich) [Nitanda, Lee, Kai,
    Sakaguchi, Suzuki (ICML2025)]: an entropy-sandwich bound that uses only the LSI constant of a conditional distribution, together with a Bregman-divergence term and a defect term of order $O(1/M^2)$; no global LSI constant appears.

  12. Defective LSI. Theorem (our result) [Nitanda, Lee, Kai, Sakaguchi, Suzuki, ICML2025]: a defective LSI
    stated in terms of the Fisher divergence holds for any $M$, and, under some smoothness condition, it yields a new bound for any $M$ in which the required number of particles $M$ is independent of $\alpha \simeq \exp(-O(d))$. Existing bound: $\frac{1}{\lambda^2 \alpha M}$ [Chen, Ren, Wang, 2022]. (An order-of-magnitude comparison follows.)

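An order-of-magnitude comparison of the two bounds with purely hypothetical constants (the exact constants of the ICML2025 bound are not reproduced here): the alpha-dependent bound forces M to grow like exp(O(d)), while an alpha-free O(1/M^2)-type bound does not.

```python
import numpy as np

# Number of particles M needed to push the particle-approximation error below eps:
#   existing bound ~ 1 / (lam^2 * alpha * M)   =>  M ~ 1 / (lam^2 * alpha * eps)
#   alpha-free bound ~ C / M^2                 =>  M ~ sqrt(C / eps)
lam, eps, C = 0.1, 1e-2, 1.0                  # all constants hypothetical
for d in [5, 10, 20, 40]:
    alpha = np.exp(-0.5 * d)                  # assumed alpha ~ exp(-O(d))
    M_existing = 1.0 / (lam ** 2 * alpha * eps)
    M_new = np.sqrt(C / eps)
    print(f"d = {d:2d}   alpha-dependent M ~ {M_existing:.2e}   alpha-free M ~ {M_new:.0f}")
```
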
  13. Escaping from a saddle point. [Yamamoto, Kim, Suzuki: Hessian-guided Perturbed
    Wasserstein Gradient Flows for Escaping Saddle Points. 2025]

  14. Non-convex objective (no entropy term). Wasserstein gradient flow (continuous-time dynamics) and its
    discrete-time counterpart: do they converge? The Wasserstein GF converges to a critical point, but [second-order optimality] it can get stuck at a saddle point. How to escape the saddle point?

  15. 2nd-order stationary point / saddle point. Second-order derivative (Wasserstein Hessian) $H_\mu$ of $F$ at $\mu$,
    with the assumption $\|\widetilde{H}_\mu\|_\infty \le R$ (note that when $\nabla\frac{\delta F}{\delta\mu} = 0$, $\widetilde{H}_\mu = 0$). Def ($(\epsilon, \delta)$-second-order stationary point): $\|\nabla\frac{\delta F}{\delta\mu}\|_{L^2(\mu)} \le \epsilon$ and $\lambda_{\min}(H_\mu) \ge -\delta$. Def ($(\epsilon, \delta)$-saddle point): $\|\nabla\frac{\delta F}{\delta\mu}\|_{L^2(\mu)} \le \epsilon$ and $\lambda_{\min}(H_\mu) < -\delta$. (A particle-level check is sketched below.)

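A particle-level proxy for these definitions, on an assumed 1-D toy functional F(mu) = E_mu[theta^4]/4 - (E_mu[theta])^2/2 (not an objective from the talk): the gradient of the first variation is evaluated at each particle, and the minimum eigenvalue of the particle Hessian in the L^2(mu_hat) geometry stands in for lambda_min(H_mu).

```python
import numpy as np

# Toy functional F(mu) = E_mu[theta^4]/4 - (E_mu[theta])^2/2 (non-convex in mu, no entropy),
# for which mu = delta_0 is first-order stationary but has a negative-curvature direction.
M, eps, delta = 50, 1e-3, 0.1

theta = np.zeros(M)                                   # all particles at 0: mu_hat ~ delta_0

# First-order quantity: grad(dF/dmu)(theta_i) = theta_i^3 - mean(theta), in L^2(mu_hat).
grad = theta ** 3 - theta.mean()
grad_norm = np.sqrt(np.mean(grad ** 2))

# Second-order proxy: Hessian of the particle objective F_hat(theta_1, ..., theta_M),
# read in the L^2(mu_hat) geometry, i.e. the eigenvalues of M * (Hessian of F_hat).
H_particle = np.diag(3 * theta ** 2) / M - np.ones((M, M)) / M ** 2
lam_min = np.linalg.eigvalsh(M * H_particle).min()

print(f"L2(mu_hat) norm of grad dF/dmu = {grad_norm:.2e}   lambda_min proxy = {lam_min:.3f}")
if grad_norm <= eps and lam_min >= -delta:
    print("(eps, delta)-second-order stationary point")
elif grad_norm <= eps:
    print("(eps, delta)-saddle point: small gradient but a negative-curvature direction")
else:
    print("not first-order stationary")
```
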
  16. Escape from a saddle. The W-GF converges to a critical point. If this is an
    $(\epsilon, \delta)$-stationary point, we are done; if not, how to escape the saddle point? Finite-dimensional case: move in the direction of the minimum eigenvalue of the Hessian [Agarwal et al. 2016; Carmon et al. 2016; Nesterov&Polyak 2006], or apply a random perturbation [Jin et al. 2017; Li 2019]. How to perturb probability measures? (infinite-dimensional objective)

  17. Gaussian process perturbation. Let the "kernel" function $K_\mu$ be defined (from the Hessian at $\mu$).
    Generate a Gaussian process vector field $\xi \sim GP(0, K_\mu)$, and perturb the distribution as $\mu \leftarrow (\mathrm{Id} + \eta_p\,\xi)_\#\mu$, where $\eta_p > 0$ is a small step size. This induces a small random perturbation of the distribution. (A particle-level sketch follows.)

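A sketch of the perturbation step at the particle level. The kernel below is a plain RBF kernel used as a stand-in (the paper's K_mu is Hessian-guided and is not reproduced here); the GP vector field is sampled jointly at the particle locations and the particles are pushed forward.

```python
import numpy as np

rng = np.random.default_rng(4)
M, d, eta_p = 200, 2, 0.05                 # particles, dimension, perturbation step (hypothetical)
particles = rng.normal(size=(M, d))        # current particle approximation of mu

# Kernel for the Gaussian process vector field; a plain RBF kernel is used here as a
# stand-in for the Hessian-guided kernel K_mu of the paper.
def rbf_kernel(A, B, bandwidth=1.0):
    sq = ((A[:, None, :] - B[None, :, :]) ** 2).sum(-1)
    return np.exp(-sq / (2 * bandwidth ** 2))

K = rbf_kernel(particles, particles) + 1e-8 * np.eye(M)   # jitter for a stable Cholesky
Lchol = np.linalg.cholesky(K)

# Sample xi ~ GP(0, K) at the particle locations (independently per coordinate),
# then push the particles forward: X <- X + eta_p * xi(X), i.e. mu <- (Id + eta_p xi)_# mu.
xi = Lchol @ rng.normal(size=(M, d))
perturbed = particles + eta_p * xi

print("mean displacement:", np.linalg.norm(perturbed - particles, axis=1).mean())
```
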
  18. Escape from a saddle point. Proposition: suppose that $\mu_+$ is an $(\epsilon, \delta)$-saddle point with
    $\lambda_0 := \lambda_{\min}(H_{\mu_+}) < -\delta$ and $\epsilon = O(\delta^2)$. Then, for the GP perturbation $\xi \sim GP(0, K_{\mu_+})$, take $\mu_0 = (\mathrm{Id} + \eta_p\,\xi)_\#\mu_+$ as the initial point of the W-GF; with probability $1 - \zeta$, the flow escapes the saddle point. [Proof overview] The GP perturbation $\xi$ has a component in the negative-eigenvalue direction with positive probability, and the negative-curvature component is exponentially amplified, so the flow escapes the saddle.

  19. Algorithm. At a first-order stationary point, apply the Gaussian process perturbation and
    check whether it was a saddle point; if it was not, halt the algorithm. (A sketch of the loop follows below.)

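A sketch of the overall loop on the same assumed toy functional as above, with white noise as a stand-in for the GP with kernel K_mu: run the particle Wasserstein gradient flow to near first-order stationarity, perturb, rerun, and halt when the perturbation no longer yields a decrease.

```python
import numpy as np

rng = np.random.default_rng(5)
M, eta, eta_p, eps, n_rounds = 300, 0.05, 0.2, 1e-3, 5   # hypothetical hyperparameters

# Assumed toy non-convex objective on measures over R (no entropy term):
#   F(mu) = E_mu[theta^4]/4 - (E_mu[theta])^2/2,
# whose Wasserstein gradient at a particle is theta^3 - E_mu[theta].
def F_hat(th):
    return np.mean(th ** 4) / 4 - np.mean(th) ** 2 / 2

def wgrad(th):
    return th ** 3 - np.mean(th)

def run_wgf(th, max_iter=5000):
    """Forward-Euler particle Wasserstein gradient flow until near first-order stationarity."""
    for _ in range(max_iter):
        g = wgrad(th)
        if np.sqrt(np.mean(g ** 2)) <= eps:          # ||grad dF/dmu||_{L^2(mu_hat)} <= eps
            break
        th = th - eta * g
    return th

theta = np.zeros(M)                                  # start exactly at the saddle mu = delta_0
for k in range(n_rounds):
    theta = run_wgf(theta)
    F_before = F_hat(theta)
    # Random perturbation of the measure (white noise as a stand-in for the GP with kernel K_mu):
    # mu <- (Id + eta_p * xi)_# mu at the particle level, then rerun the flow.
    candidate = run_wgf(theta + eta_p * rng.normal(size=M))
    if F_hat(candidate) < F_before - 1e-6:           # descent found: it was a saddle point
        theta = candidate
        print(f"round {k}: escaped the saddle, F = {F_hat(theta):.4f}")
    else:                                            # no descent: approximately second-order stationary
        print(f"round {k}: no escape, halting with F = {F_before:.4f}")
        break
```
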
  20. Proof outline. Let the Hessian of $F$ at $\mu$ be $H_\mu$. Lemma: the Wasserstein GF $\mu_t$ around a
    critical point $\mu_+$ can be written as $(\mathrm{id} + \epsilon v_t)_\#\mu_+$, where the velocity field $v_t$ follows, to leading order, the linearized dynamics driven by $H_{\mu_+}$ (cf. Otto calculus). The component of $v_t$ along the negative-curvature direction grows exponentially, provided the initial field $v_0$ contains a component toward the minimal-eigenvalue direction; the Gaussian process perturbation ensures such a negative-curvature component. (A finite-dimensional illustration follows.)

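A finite-dimensional stand-in for this linearization, using an assumed 10-dimensional symmetric H with one negative eigenvalue (not the operator from the talk): integrating v' = -Hv from a random initial field, the component along the minimal-eigenvalue direction grows like exp(|lambda_0| t) while the stable components decay.

```python
import numpy as np

rng = np.random.default_rng(6)

# Finite-dimensional stand-in for the linearized dynamics dv_t/dt = -H v_t around a
# critical point: an assumed symmetric H with a single negative eigenvalue lambda_0 = -0.5.
n = 10
Q, _ = np.linalg.qr(rng.normal(size=(n, n)))
eigs = np.concatenate(([-0.5], np.linspace(0.2, 2.0, n - 1)))
H = Q @ np.diag(eigs) @ Q.T
u_neg = Q[:, 0]                                      # minimal-eigenvalue direction

v = rng.normal(size=n)                               # random initial field (proxy for a GP sample)
dt, T = 0.01, 2000
for t in range(T):
    v = v - dt * (H @ v)                             # forward Euler for v' = -H v
    if t % 500 == 0:
        print(f"t = {t * dt:5.2f}   |<v, u_neg>| = {abs(v @ u_neg):.3e}   |v| = {np.linalg.norm(v):.3e}")
# The component along u_neg grows like exp(|lambda_0| * t); the stable components decay.
```
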
  21. Karhunen-Loeve (KL) decomposition of the GP perturbation: $\xi = \sum_j \sqrt{\lambda_j}\, Z_j e_j$, where $(e_j)_j$ is an
    orthonormal system in $L^2(\mu)$ and $Z_j \sim N(0,1)$ are independent; hence the component of $\xi$ along the minimal-eigenvalue direction is a nondegenerate Gaussian. Consider a "shifted" initial point obtained by moving the perturbation along that direction; the two resulting trajectories diverge, so one of them must leave the neighborhood of the saddle.

  22. Global optimality for strict saddle. Def ($(\epsilon, \delta, \alpha)$-strict saddle): $F: \mathcal{P}_2(\mathbb{R}^d) \to \mathbb{R}$ is
    $(\epsilon, \delta, \alpha)$-strict saddle if, for any $\mu \in \mathcal{P}_2(\mathbb{R}^d)$, one of the following conditions holds: (1) the gradient is large, $\|\nabla\frac{\delta F}{\delta\mu}\|_{L^2(\mu)} > \epsilon$; (2) there is a negative-curvature direction, $\lambda_{\min}(H_\mu) < -\delta$; (3) $\mu$ is $\alpha$-close to a global optimum, $W_2(\mu, \mu^*) \le \alpha$. Examples: in-context learning (Kim&Suzuki, 2024); matrix decomposition (recommendation systems).

  23. Global optimality for strict saddle. Theorem: suppose that $F: \mathcal{P}_2(\mathbb{R}^d) \to \mathbb{R}$ is
    $(\epsilon, \delta, \alpha)$-strict saddle (as defined above). Then, after $T = \widetilde{O}\!\big(\tfrac{1}{\epsilon^2} + \tfrac{1}{\delta^4}\big)$ time, the solution achieves $W_2(\mu_T, \mu^*) \le \alpha$ for a global optimum $\mu^*$.

  24. Numerical experiment (in-context learning). We compare 3 models with $d = 20$, $k = 5$, and 500 neurons with
    sigmoid activation. All models are pre-trained using SGD on 10K prompts of 1K token pairs. (1) attention: jointly optimizes $\mathcal{L}(\mu, \Gamma)$; (2) static: directly minimizes $\mathcal{L}(\mu)$; (3) modified: the static model with birth-death and GP perturbation. We verify global convergence as well as improvement for a misaligned model ($k_{\mathrm{true}} = 7$) and nonlinear test tasks $g(x) = \max_{j \le k} h_{\mu^\circ}(x)_j$ or $g(x) = h_{\mu^\circ}(x)^2$.

  25. Presentation overview (recap). Part 1: Convex (propagation of chaos). Mean-field Langevin dynamics for
    $F(\mu) + \lambda\,\mathrm{Ent}(\mu)$: optimization of probability measures by Wasserstein gradient flow, particle approximation, and a defective log-Sobolev inequality. Nitanda, Lee, Kai, Sakaguchi, Suzuki: Propagation of Chaos for Mean-Field Langevin Dynamics and its Application to Model Ensemble. ICML2025. Part 2: Non-convex (strict-saddle objective). (1) Gaussian process perturbation: a polynomial-time method to escape a saddle point by a "random perturbation" of a probability measure. (2) In-context learning: the objective for training a two-layer transformer is strict-saddle. Kim, Suzuki: Transformers Learn Nonlinear Features In Context: Nonconvex Mean-field Dynamics on the Attention Landscape. ICML2024 (oral). Yamamoto, Kim, Suzuki: Hessian-guided Perturbed Wasserstein Gradient Flows for Escaping Saddle Points. 2025.