

Jia-Jie Zhu

March 27, 2024

Taiji Suzuki (University of Tokyo, Japan): Convergence of mean field Langevin dynamics and its application to neural network feature learning

WORKSHOP ON OPTIMAL TRANSPORT
FROM THEORY TO APPLICATIONS
INTERFACING DYNAMICAL SYSTEMS, OPTIMIZATION, AND MACHINE LEARNING
Venue: Humboldt University of Berlin, Dorotheenstraße 24

Berlin, Germany. March 11th - 15th, 2024

Transcript

  1. Convergence of mean field Langevin dynamics and its application to neural network feature learning. Taiji Suzuki, The University of Tokyo / AIP-RIKEN (Deep Learning Theory Team). 15th/Mar/2024, Workshop on Optimal Transport, Berlin.
  2. Outline of this talk
     • F is convex. • Entropy regularization.
     Application: training a 2-layer NN in the mean field regime.
     [Convergence] • We introduce mean field Langevin dynamics (MFLD) to minimize ℒ. • We show its linear convergence under a log-Sobolev inequality condition.
     [Generalization error analysis] • A generalization error analysis of a 2-layer NN trained by MFLD is given. • Separation from kernel methods is shown.
  3. Feature learning of NN
     Benefit of feature learning with optimization guarantee.
     • [Computation] Suzuki, Wu, Nitanda: "Convergence of mean-field Langevin dynamics: Time and space discretization, stochastic gradient, and variance reduction." NeurIPS 2023.
     • [Generalization]
       ➢ Suzuki, Wu, Oko, Nitanda: "Feature learning via mean-field Langevin dynamics: Classifying sparse parities and beyond." NeurIPS 2023.
       ➢ Nitanda, Oko, Suzuki, Wu: "Anisotropy helps: improved statistical and computational complexity of the mean-field Langevin dynamics under structured data." ICLR 2024.
     In particular, we compare the generalization error of neural networks and kernel methods. Trade-off: statistical complexity vs. computational complexity.
     Neural network: feature learning ✓, but optimization is non-convex. Kernel method: no feature learning (×), but optimization is convex.
  4. Noisy gradient descent
     Optimization of a neural network is basically non-convex.
     ➢ Noisy gradient descent (e.g., SGD) is effective for non-convex optimization: the noisy perturbation helps escape local minima.
     ➢ It likely converges to a flat global minimum.
  5. Gradient Langevin Dynamics (GLD)
     Gradient Langevin dynamics for a non-convex objective F, and its Euler-Maruyama discretization [Gelfand and Mitter (1991); Borkar and Mitter (1999); Welling and Teh (2011)].
     Its stationary distribution can stay around the global minimum of F(x), and it minimizes an entropy-regularized loss.
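The dynamics and its Euler-Maruyama discretization fit in a few lines of code; the double-well objective, step size, and temperature below are illustrative choices for this sketch, not values from the talk.

```python
import numpy as np

def gld(grad_F, x0, eta=1e-2, lam=0.05, n_steps=5000, rng=None):
    """Euler-Maruyama discretization of gradient Langevin dynamics:
    X_{k+1} = X_k - eta * grad F(X_k) + sqrt(2 * eta * lam) * xi_k,
    whose stationary density is proportional to exp(-F(x) / lam)."""
    rng = np.random.default_rng(rng)
    x = np.array(x0, dtype=float)
    for _ in range(n_steps):
        x = x - eta * grad_F(x) + np.sqrt(2 * eta * lam) * rng.standard_normal(x.shape)
    return x

# Toy non-convex objective (illustrative): F(x) = (x^2 - 1)^2,
# a double well with global minima at x = +1 and x = -1.
grad_F = lambda x: 4 * x * (x**2 - 1)
x_final = gld(grad_F, x0=[2.0], rng=0)
```

At small temperature `lam` the iterates concentrate near a minimum instead of diverging under the noise, which is the "can stay around the global minimum" behavior above.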
  6. GLD as a Wasserstein gradient flow
     μ_t: distribution of X_t (we can assume it has a density). The PDE describing μ_t's dynamics is the Fokker-Planck equation, which is linear w.r.t. μ; its stationary distribution minimizes the objective. This is the Wasserstein gradient flow minimizing the entropy-regularized objective ℒ; cf. the Donsker-Varadhan duality formula.
  7. 2-layer NN in mean-field scaling
     • 2-layer neural network: non-linear with respect to the parameters (r_j, w_j), j = 1, …, M, where X^(j) = (r_j, w_j).
     • Regularized empirical risk: a (non-convex) loss plus L2 regularization.
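A minimal sketch of the mean-field scaling (1/M rather than 1/√M), with tanh as the activation; the parameter layout is an assumption for illustration.

```python
import numpy as np

def forward(params, z):
    """Mean-field 2-layer NN: f(z) = (1/M) * sum_j r_j * tanh(w_j . z).
    params has shape (M, d + 1): column 0 holds r_j, the rest holds w_j.
    The 1/M scaling is what makes the infinite-width limit a linear
    functional of the empirical parameter distribution."""
    r, w = params[:, 0], params[:, 1:]
    return np.mean(r * np.tanh(w @ z))

M, d = 512, 3
rng = np.random.default_rng(0)
params = rng.standard_normal((M, d + 1))
y = forward(params, np.ones(d))
```

The output is an average over neurons, i.e. an integral of h_x(z) = r · tanh(w · z) against the empirical measure of the M parameter vectors.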
  8. Noisy gradient descent
     Noisy gradient descent update (GLD): does it converge? A naïve application of existing theory for gradient Langevin dynamics yields an iteration complexity exponential in the dimension to achieve ε error → it cannot be applied to wide neural networks. [Raginsky, Rakhlin and Telgarsky, 2017; Xu, Chen, Zou, and Gu, 2018; Erdogdu, Mackey and Shamir, 2018; Vempala and Wibisono, 2019]
  9. Mean field limit
     Loss function (empirical risk + regularization): non-linear with respect to the parameters (r_j, w_j), j = 1, …, M.
     ★ Mean field limit (M → ∞): the model becomes linear with respect to the distribution μ, and the objective is convex w.r.t. μ if the loss ℓ_i is convex (e.g., squared / logistic loss). [Nitanda & Suzuki, 2017] [Chizat & Bach, 2018] [Mei, Montanari & Nguyen, 2018] [Rotskoff & Vanden-Eijnden, 2018]
  10. General form of mean field LD
     ➢ Mean field Langevin dynamics: the SDE whose Fokker-Planck equation corresponds to the Wasserstein gradient flow of ℒ = F (convex) + entropy (strictly convex), which is strictly convex.
     Definition (first variation): the first variation δF/δμ : 𝒫 × ℝ^d → ℝ is a continuous functional such that lim_{ε→0} [F(μ + ε(ν − μ)) − F(μ)]/ε = ∫ (δF/δμ)(μ, x) d(ν − μ)(x).
     ➢ GLD is recovered as the special case where F is linear in μ.
  11. MF-LD to optimize mean field NN
     Loss function over μ_k (the distribution of X_k), with each neuron h_x(·) parametrized by x. The discrete-time MFLD performs a noisy gradient descent update on each particle.
  12. Proximal Gibbs measure
     ➢ The proximal Gibbs measure is the minimizer of the objective linearized at μ; it is a kind of "tentative" target.
     ➢ It plays an important role in the convergence analysis.
  13. Convergence rate
     Assumption (Log-Sobolev inequality): there exists α > 0 such that, for any probability measure ν (abs. cont. w.r.t. the proximal Gibbs measure p_μ), the KL divergence KL(ν, p_μ) is bounded by 1/(2α) times the Fisher divergence.
     Theorem (Linear convergence) [Nitanda, Wu, Suzuki (AISTATS 2022)] [Chizat (2022)]: if p_{μ_t} satisfies the LSI condition for all t ≥ 0, then ℒ(μ_t) converges to the optimum at a linear (exponential) rate, characterized by the LSI constant α.
     This is a non-linear extension of the well-known GLD convergence analysis; cf. the Polyak-Lojasiewicz condition f(x) − f(x*) ≤ C‖∇f(x)‖².
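In symbols, a standard way to write the assumption and the conclusion (following the conventions of the cited works; treating λ as the entropy-regularization strength is an assumption of this writeup):

```latex
% Log-Sobolev inequality for the proximal Gibbs measure p_mu:
\mathrm{KL}(\nu \,\|\, p_\mu)
  \;\le\; \frac{1}{2\alpha}
  \int \left\| \nabla \log \frac{\mathrm{d}\nu}{\mathrm{d}p_\mu} \right\|^2 \mathrm{d}\nu
  \qquad \text{(the right-hand side is the Fisher divergence).}

% If p_{mu_t} satisfies this for all t >= 0, a Gronwall argument yields
% linear (exponential) convergence of the entropy-regularized objective:
\mathcal{L}(\mu_t) - \mathcal{L}(\mu^*)
  \;\le\; e^{-2\alpha\lambda t}\,
  \bigl( \mathcal{L}(\mu_0) - \mathcal{L}(\mu^*) \bigr).
```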
  14. Log-Sobolev inequality
     For the L2-regularized loss function of the mean field 2-layer NN: if sup_z |ℓ_i′(f_μ(z)) h_x(z)| ≤ B, then the proximal Gibbs measure p_μ satisfies the LSI with a constant α.
     Reason: p_μ is a bounded (≤ B) perturbation of a strongly convex (Gaussian) potential, so the Bakry-Emery criterion (1985) and the Holley-Stroock bounded perturbation lemma (1987) apply.
  15. We have obtained convergence of the infinite-width, continuous-time dynamics.
     Question: can we evaluate the finite-particle and discrete-time approximation errors?
  16. Difficulty
     • SDEs of interacting particles (McKean, Kac, …, 1960s). The finite-particle approximation error can be amplified through time → it is difficult to bound the perturbation uniformly over time.
     • Propagation of chaos [Sznitman, 1991; Lacker, 2021]: the particles behave as if they were independent as the number of particles increases to infinity.
     • A naïve evaluation gives exponential growth in time: exp(t)/M [Mei et al. (2018, Theorem 3)].
     ➢ Existing work assumes weak interaction / strong regularization.
  17. Practical algorithm
     • Time discretization: t → kη (η: step size, k: number of steps).
     • Space discretization: μ_t is approximated by M particles (X_k^(i)), i = 1, …, M: μ_t → μ̂_k = (1/M) Σ_i δ_{X_k^(i)}.
     • Stochastic gradient: ∇(δF/δμ) → v_k^(i).
     ➢ This is noisy gradient descent on a 2-layer NN with finite width.
  18. Convergence analysis
     Assumption (+ second-order differentiability):
     1. F: 𝒫 → ℝ is convex and has the form F(μ) = L(μ) + λ₁ E_μ[‖x‖²].
     2. (Smoothness) ‖∇(δL(μ)/δμ)(x) − ∇(δL(ν)/δμ)(y)‖ ≤ C(W₂(μ, ν) + ‖x − y‖), and (boundedness) ‖∇(δL(μ)/δμ)(x)‖ ≤ R.
     Theorem (One-step update) [Suzuki, Wu, Nitanda (2023)]: suppose the proximal Gibbs measure p_μ satisfies the log-Sobolev inequality with a constant α. Under smoothness and boundedness of the loss function, the one-step bound controls the time-discretization, space-discretization, and stochastic-approximation errors, with a space-discretization error of O(1/M) improving on the naïve bound.
     [Suzuki, Wu, Nitanda: Convergence of mean-field Langevin dynamics: Time and space discretization, stochastic gradient, and variance reduction. arXiv:2306.07221]
  19. Uniform log-Sobolev inequality
     𝒳_k = (X_k^(i)), i = 1, …, M ∼ μ_k^(M): the joint distribution of the M particles, with a potential defined on ℝ^(d×M).
     ➢ The finite particle dynamics is the Wasserstein gradient flow that minimizes the corresponding joint objective.
     (Approximate) uniform log-Sobolev inequality [Chen, Ren, Wang. Uniform-in-time propagation of chaos for mean field Langevin dynamics. arXiv:2212.03050, 2022]: the log-Sobolev inequality (with the associated Fisher divergence) holds uniformly over any M.
  20. Computational complexity
     SG-MFLD (finite sum, stochastic gradient with mini-batch size B): iteration complexity to achieve ε + O(1/(λ²αN)) accuracy, accounting for time discretization, space discretization, and stochastic approximation.
     By setting the mini-batch size B = 1/(λ²αε) (the optimal choice), the iteration complexity becomes k = O(log(ε⁻¹)/ε).
     ➢ The approximation errors are uniform in time.
     ➢ There is no exponential dependency on M (the number of neurons).
  21. Numerical experiment
     Test error vs. number of steps (regularization term: r(x) = ‖x‖²), for several widths up to the limit M → ∞.
  22. Generalization error analysis
     So far, we have obtained convergence of MFLD. ⇒ How effective is the feature learning of MFLD in terms of generalization error?
     • Benefit of feature learning? Neural network vs. kernel method (NTK vs. mean field).
  23. Classification task
     Problem setting (binary classification, labels ±1). Loss function and model:
     ➢ Logistic loss: ℓ(yf) = log(1 + exp(−yf)).
     ➢ tanh activation: h_x(z) = R̄ · [tanh(⟨x₁, z⟩ + x₂) + 2 · tanh(x₃)]/3.
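The loss and the bounded neuron can be transcribed directly; the value of the scale constant R̄ below is a placeholder, not a value from the talk.

```python
import numpy as np

R_BAR = 1.0  # output scale R-bar; the value here is illustrative

def h_x(x1, x2, x3, z):
    """Bounded neuron: h_x(z) = R_bar * [tanh(<x1, z> + x2) + 2*tanh(x3)] / 3.
    Since |tanh| <= 1, we get |h_x(z)| <= R_bar for any parameters."""
    return R_BAR * (np.tanh(x1 @ z + x2) + 2.0 * np.tanh(x3)) / 3.0

def logistic_loss(y, f):
    """l(y f) = log(1 + exp(-y f)), written in a numerically stable form."""
    return np.logaddexp(0.0, -y * f)
```

The uniform bound |h_x| ≤ R̄ is what feeds the constant B in the log-Sobolev argument on the earlier slide.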
  24. Assumptions
     Assumption: there exists μ* such that (1) KL(ν, μ*) ≤ R, and (2) Y f_{μ*}(Z) ≥ c₀ almost surely, for some constants R, c₀ > 0, where ν = N(0, λ/(2λ₁)) is the Gaussian base measure of the KL regularization in the MFLD objective (plus a classification calibration condition).
     In words: the Bayes classifier is attained by a μ* with a bounded KL divergence from ν, and μ* classifies with margin c₀.
  25. Main theorem
     Theorem 1: suppose λ = Θ(1/R); then, with probability 1 − exp(−t), the classification error is O(R̄²R/n).
     • h_x(z) = R̄ · [tanh(⟨x₁, z⟩ + x₂) + 2 · tanh(x₃)]/3
     • μ*: KL(ν, μ*) ≤ R, Y f_{μ*}(Z) ≥ c₀
     Existing bound [Chen et al. (2020); Nitanda, Wu, Suzuki (2021)]: classification error ≤ O(1/√n) (Rademacher complexity bound).
     • Our bound provides a fast learning rate (faster than 1/√n): O(R/n) ≪ O(1/√n).
  26. Main theorem 2
     Theorem 1: E[classification error] ≤ O(R/n).
     Theorem 2: suppose λ = Θ(1/R) and n ≥ R²; then, with high probability, E[classification error] ≤ O(exp(−O(n/R²))).
     If we have sufficiently large training data, the test error converges exponentially. We only need to evaluate R to obtain a test error bound.
  27. Example: k-sparse parity problem
     • k-sparse parity problem on high dimensional data:
       ➢ Z ∼ Unif({−1, 1}^d) (up to freedom of rotation)
       ➢ Y = ∏_{j=1}^k Z_j; only the first k coordinates are informative. ※ Suppose we don't know which coordinates Z_j are aligned to.
     k = 2 is the XOR problem (e.g., d = 3, k = 2).
     Q: Can we learn sparse k-parity with GD? Is there any benefit of the neural network? Cf. the complexity comparison for learning the XOR function (k = 2) in Table 1 of [Telgarsky: Feature selection and low test error in shallow low-rotation ReLU networks, ICLR 2023].
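Sampling the task is straightforward; `k_sparse_parity` is a hypothetical helper name for this sketch.

```python
import numpy as np

def k_sparse_parity(n, d, k, seed=0):
    """Sample the k-sparse parity task: Z ~ Unif({-1, +1}^d),
    Y = product of the first k coordinates. Only those k coordinates
    are informative, and the learner is not told which ones they are."""
    rng = np.random.default_rng(seed)
    Z = rng.choice([-1.0, 1.0], size=(n, d))
    Y = np.prod(Z[:, :k], axis=1)
    return Z, Y

Z, Y = k_sparse_parity(n=1000, d=20, k=2)  # k = 2 is the XOR problem
```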
  28. Generalization bound
     Reminder: suppose there exists μ* such that KL(ν, μ*) ≤ R and Y f_{μ*}(Z) ≥ c₀ (a perfect classifier with margin c₀). Then:
     Theorem 1: E[classification error] ≤ O(R/n).
     Theorem 2: E[classification error] ≤ O(exp(−O(n/R²))) if n ≥ R².
     Lemma: for the k-sparse parity problem, we can take a suitable μ* and evaluate the required R.
  29. Generalization error bound
     Corollary (test accuracy of MFLD):
     • Setting 1: n > d ➢ test error (classification error) = O(d/n)
     • Setting 2: n > d² ➢ test error (classification error) = O(exp(−n/d²))
     These are better than NTK (kernel methods): sample complexity of NTK is n = Ω(d^k) vs. NN n = O(d).
     Trade-off between computational complexity and sample complexity: the computational complexity is exp(O(d)) (but can be relaxed to O(1) if X is anisotropic).
     Our analysis provides • better sample complexity, • a discrete-time/finite-width analysis, and • "decoupling" of d and k.
  30. Anisotropic data structure
     • Isotropic data: ➢ test error bound O(d/n); ➢ computational complexity O(exp(d)) (R = d yields exp(d) iterations).
     If the data has an anisotropic covariance, the sample / computational complexities can be much improved: the data structure affects the complexities.
  31. Anisotropic k-parity setting
     Input: Z_j = ±s_j (coordinate-wise scaling), where s_j is a scaling factor such that Σ_{j=1}^d s_j² = 1. Label: the k-sparse parity of the signs.
     Examples:
     • Power law decay: s_j² ≍ j^(−α)/d^(1−α), α ∈ [0, 1]
     • Spiked covariance: s_j² ≍ d^(α−1) for j ∈ [k]; s_j² ≍ d^(−1) for j ∈ [k + 1, d]
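A sketch of the coordinate-wise scaling with power-law decay (normalized so that Σ_j s_j² = 1 holds exactly); the helper name and defaults are assumptions for illustration.

```python
import numpy as np

def anisotropic_parity(n, d, k, alpha=1.0, seed=0):
    """Anisotropic k-sparse parity with coordinate-wise scaling:
    Z_j = +/- s_j with power-law decay s_j^2 proportional to j^(-alpha),
    normalized so that sum_j s_j^2 = 1. The label is the parity of the
    signs of the first k (informative) coordinates."""
    rng = np.random.default_rng(seed)
    s2 = np.arange(1, d + 1, dtype=float) ** (-alpha)
    s = np.sqrt(s2 / s2.sum())                  # enforce sum_j s_j^2 = 1
    signs = rng.choice([-1.0, 1.0], size=(n, d))
    Z = signs * s                               # Z_j = +/- s_j
    Y = np.prod(signs[:, :k], axis=1)
    return Z, Y

Z, Y = anisotropic_parity(n=100, d=10, k=2)
```

Setting `alpha=0` recovers the isotropic case s_j² = 1/d; larger `alpha` concentrates the input energy on the leading (informative) coordinates.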
  32. Generalization error bound
     Anisotropic k-sparse parity with coordinate-wise scaling (we assume Σ_{j=1}^d s_j² = 1).
     Lemma: for this setting we can evaluate the required R in terms of a quantity S_k determined by the scaling.
     Recall: (1) E[classification error] ≤ O(R/n); (2) E[classification error] ≤ O(exp(−O(n/R²))) if n ≥ R². Then:
     • Setting 1: n > S_k ➢ test error (classification error) bound of order S_k/n
     • Setting 2: n > S_k² ➢ exponentially small test error (classification error)
  33. Example
     Input: Z_j = ±s_j with Σ_{j=1}^d s_j² = 1; Setting 1: n > S_k, Setting 2: n > S_k².
     • Isotropic setting (s_j² = 1/d): the test error bounds of the previous slides.
     • Anisotropic setting: 1. power law decay (s_j² ≍ j^(−α)) and 2. spiked covariance both yield improved test error bounds.
  34. Anisotropic vs. isotropic: in the anisotropic case the k signal components carry a large signal while the noisy components carry small noise (decaying like j^(−α)); in the isotropic case the signal is small and the noise is large.
  35. Computational complexity
     When λ = O(1/R), the number of iterations can be bounded; substituting R for each setting (assuming k = O(1)):
     • Isotropic setting: exponential in d.
     • Anisotropic setting: the anisotropic structure mitigates the computational complexity. In particular, there is no exponential dependency on d when α = 1.
  36. Kernel lower bound
     Theorem: for arbitrary δ > 0, the sample complexity of kernel methods is lower bounded as n = Ω(d^k) (up to δ-dependent factors).
     When k = k*, the mean field NN can "decouple" k and d, while the kernel method has an exponential relation between them.
  37. Coordinate transform
     • When the input Z is isotropic, we may estimate the "informative direction" by the gradients at the initialization; the matrix G estimates the informative direction.
     • By the corresponding coordinate transformation, we may take R independent of d (exp(d) → exp(k)), avoiding the curse of dimensionality.
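One simple way to realize this idea in code is to build a label-weighted second-moment matrix and take its top eigenvectors. This statistic is a simplified stand-in for the gradient-based construction of G (not the talk's exact estimator), chosen because for the XOR label its top eigenvectors span the informative coordinates.

```python
import numpy as np

def estimate_informative_directions(Z, Y, k):
    """Estimate the informative subspace of a parity-type problem.
    We use the label-weighted second moment G = E[Y * Z Z^T]; for the
    XOR label Y = Z_1 * Z_2, G concentrates on the (1, 2) off-diagonal
    block, so its top eigenvectors (by absolute eigenvalue) span the
    informative coordinates. (A simplified stand-in, not the talk's
    exact gradient-at-initialization construction.)"""
    G = (Z * Y[:, None]).T @ Z / len(Y)          # symmetric (d, d) matrix
    vals, vecs = np.linalg.eigh(G)               # eigh: ascending eigenvalues
    order = np.argsort(-np.abs(vals))[:k]        # top-k by |eigenvalue|
    return vecs[:, order]

rng = np.random.default_rng(0)
Z = rng.choice([-1.0, 1.0], size=(4000, 8))
Y = Z[:, 0] * Z[:, 1]                            # XOR label (k = 2)
U = estimate_informative_directions(Z, Y, k=2)
```

After rotating the data into the estimated subspace, the effective dimension drops from d to roughly k, which is the exp(d) → exp(k) improvement described above.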
  38. Sample complexity to compute G
     By using training data of a certain size, we can compute G, and R can then be reduced accordingly. Isotropic setting without the G estimation: R = Ω(d).
  39. Summary of the result
     Upper bound for the NN (our result) vs. lower bound for the kernel method (our result): improved sample complexity.
  40. Discussion
     • The CSQ lower bound states that O(d^(k−1)) sample complexity is optimal for methods with polynomial-order computational complexity. [Abbe et al. (2023); Refinetti et al. (2021); Ben Arous et al. (2022); Damian et al. (2022)]
     • On the other hand, our analysis is about full-batch GD:
       ➢ Our analysis: mini-batch size n, e^d iterations, sample complexity d.
       ➢ SGD (CSQ lower bound): mini-batch size 1, d^(k−1) iterations, sample complexity d^(k−1).
     We obtain a better sample complexity than O(d^(k−1)) at the cost of higher computational complexity. → With MFLD we can obtain a polynomial-order method for anisotropic input.
  41. Conclusion
     • Mean field Langevin dynamics
       ➢ Mean field representation of 2-layer NNs
       ➢ Optimization of a convex functional
       ➢ Convergence guarantee (Wasserstein gradient flow, uniform-in-time propagation of chaos)
     • Generalization error of the mean field 2-layer NN
       ➢ Fast learning rate
       ➢ Sparse k-parity problem
       ➢ Better sample complexity than kernel methods
       ➢ The structure of the data (anisotropic covariance) can improve the complexities.