Taiji Suzuki (University of Tokyo, Japan): Convergence of mean field Langevin dynamics and its application to neural network feature learning
WORKSHOP ON OPTIMAL TRANSPORT
FROM THEORY TO APPLICATIONS
INTERFACING DYNAMICAL SYSTEMS, OPTIMIZATION, AND MACHINE LEARNING
Venue: Humboldt University of Berlin, Dorotheenstraße 24
Taiji Suzuki (The University of Tokyo / RIKEN AIP, Deep learning theory team), 15 Mar 2024, Workshop on Optimal Transport, Berlin
• Entropy regularization — application: training a 2-layer NN in the mean field regime.
[Convergence]
• We introduce mean field Langevin dynamics (MFLD) to minimize the entropy-regularized objective.
• We show its linear convergence under a log-Sobolev inequality condition.
[Generalization error analysis]
• A generalization error analysis of 2-layer NNs trained by MFLD is given.
• Separation from kernel methods is shown.
• Training a neural network is basically a non-convex optimization problem.
• Noisy gradient descent (e.g., SGD) is effective for non-convex optimization: the noisy perturbation helps the iterates escape local minima.
• It likely converges to a flat global minimum.
Gradient Langevin dynamics (GLD): dX_t = −∇F(X_t) dt + √(2λ) dB_t.
Discretization: X_{k+1} = X_k − η ∇F(X_k) + √(2ηλ) ξ_k, ξ_k ∼ N(0, I) [Gelfand and Mitter (1991); Borkar and Mitter (1999); Welling and Teh (2011)].
Stationary distribution: π(x) ∝ exp(−F(x)/λ), which concentrates around the global minimum of F(x). Here F is typically a regularized loss (empirical risk plus, e.g., an ℓ2 penalty).
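The discretized GLD update above can be sketched in a few lines of Python (a minimal illustration; the quadratic potential, step size η, and temperature λ are my choices, not values from the talk):

```python
import numpy as np

# Discretized gradient Langevin dynamics:
#   X_{k+1} = X_k - eta * grad_F(X_k) + sqrt(2 * eta * lam) * xi_k,
# whose continuous-time limit has stationary distribution pi(x) ∝ exp(-F(x)/lam).

def gld(grad_F, x0, eta=0.01, lam=0.1, n_steps=5000, seed=0):
    rng = np.random.default_rng(seed)
    x = np.array(x0, dtype=float)
    for _ in range(n_steps):
        xi = rng.standard_normal(x.shape)
        x = x - eta * grad_F(x) + np.sqrt(2.0 * eta * lam) * xi
    return x

# Toy run: F(x) = ||x||^2 / 2, so the stationary law is roughly N(0, lam * I)
# and the iterates settle near the global minimum 0 with O(sqrt(lam)) noise.
x_final = gld(lambda v: v, np.full(2, 5.0))
```

Starting far from the minimum (here at (5, 5)), the iterate contracts toward 0 and then fluctuates at the noise scale √λ.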
Let μ_t be the law of X_t (we can assume it has a density). The PDE describing μ_t's dynamics is the Fokker–Planck equation
∂_t μ_t = ∇·(μ_t ∇F) + λ Δμ_t, [linear w.r.t. μ]
whose stationary distribution is π ∝ exp(−F/λ). This is the Wasserstein gradient flow minimizing
ℒ(μ) = E_μ[F(X)] + λ E_μ[log μ(X)]
(cf. the Donsker–Varadhan duality formula: min_μ ℒ(μ) = −λ log ∫ exp(−F(x)/λ) dx, attained at μ = π).
Does it converge? A naïve application of the existing theory of gradient Langevin dynamics yields an iteration complexity exponential in the dimension to achieve ε error. → It cannot be applied to wide neural networks. [Raginsky, Rakhlin and Telgarsky, 2017; Xu, Chen, Zou and Gu, 2018; Erdogdu, Mackey and Shamir, 2018; Vempala and Wibisono, 2019]
A 2-layer network f(x) = (1/M) Σ_{j=1}^M h(x; a_j, w_j) is non-linear with respect to the parameters (a_j, w_j)_{j=1}^M.
M → ∞ gives the mean field limit f_μ(x) = ∫ h(x; θ) dμ(θ), which is linear with respect to the distribution μ.
Hence the objective is convex w.r.t. μ if each loss ℓ_i is convex (e.g., squared / logistic loss).
[Nitanda & Suzuki, 2017] [Chizat & Bach, 2018] [Mei, Montanari & Nguyen, 2018] [Rotskoff & Vanden-Eijnden, 2018]
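The linearity in μ can be illustrated concretely (a toy sketch; the tanh activation and all names are my choices): the finite-width output is an average over particles, so pooling two equal-size particle sets — i.e., mixing the two empirical measures with weight 1/2 — exactly averages the two outputs.

```python
import numpy as np

# f(x) = (1/M) * sum_j a_j * tanh(w_j . x): an average over the M particles
# theta_j = (a_j, w_j), i.e. an integral against the empirical measure.

def mf_network(x, a, W):
    """a: (M,) output weights, W: (M, d) input weights, x: (d,) input."""
    return float(np.mean(a * np.tanh(W @ x)))

rng = np.random.default_rng(0)
M, d = 64, 5
x = rng.standard_normal(d)
a1, W1 = rng.standard_normal(M), rng.standard_normal((M, d))
a2, W2 = rng.standard_normal(M), rng.standard_normal((M, d))

# Pooling the particles = mixing the two empirical measures with weight 1/2,
# so the pooled network's output is the average of the two outputs.
pooled = mf_network(x, np.concatenate([a1, a2]), np.vstack([W1, W2]))
```

No such identity holds in the parameters themselves: averaging (a_j, w_j) pairs changes the function non-linearly, which is exactly the point of passing to the measure.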
Definition (first variation): the first variation δF/δμ : 𝒫 × ℝ^d → ℝ is a continuous functional such that
lim_{ε↓0} [F(μ + ε(ν − μ)) − F(μ)]/ε = ∫ (δF/δμ)(μ, x) d(ν − μ)(x) for all ν ∈ 𝒫.
Mean field Langevin dynamics: dX_t = −∇ (δF/δμ)(μ_t, X_t) dt + √(2λ) dB_t, μ_t = Law(X_t),
whose Fokker–Planck equation corresponds to the Wasserstein gradient flow of
𝓕(μ) = F(μ) + λ E_μ[log μ] (convex + strictly convex = strictly convex).
GLD is the special case where F(μ) = ∫ F(x) dμ(x) is linear in μ.
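A finite-particle, discrete-time sketch of MFLD for squared-loss regression with f_μ(x) = E_{w∼μ}[tanh(w·x)] (the activation, step size, regularization strength, and all names are illustrative choices, not the talk's exact setting). Each particle follows the gradient of the first variation δF/δμ evaluated at the current empirical measure, plus Gaussian noise:

```python
import numpy as np

def mfld(X, y, M=50, eta=0.05, lam=1e-3, n_steps=300, seed=0):
    """Noisy gradient descent on M particles = discretized MFLD.

    Model: f(x) = (1/M) sum_j tanh(w_j . x).
    Objective: F(mu) = (1/2n) sum_i (f(x_i) - y_i)^2 + (lam/2) E_mu[|w|^2];
    the injected noise accounts for the lam * entropy term.
    """
    rng = np.random.default_rng(seed)
    n, d = X.shape
    W = 0.1 * rng.standard_normal((M, d))
    for _ in range(n_steps):
        act = np.tanh(X @ W.T)                  # (n, M)
        resid = act.mean(axis=1) - y            # (n,)  f(x_i) - y_i
        # grad of dF/dmu at particle j: (1/n) sum_i resid_i * tanh'(w_j.x_i) * x_i
        g = ((resid[:, None] * (1.0 - act**2)).T @ X) / n   # (M, d)
        noise = rng.standard_normal((M, d))
        W = W - eta * (g + lam * W) + np.sqrt(2.0 * eta * lam) * noise
    return W

def predict(W, X):
    return np.tanh(X @ W.T).mean(axis=1)
```

On a simple realizable target (e.g., y = tanh(x_1)) the training loss drops far below its value at the small random initialization.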
Proximal Gibbs measure: p_μ(x) ∝ exp(−(1/λ) (δF/δμ)(μ, x)).
• The proximal Gibbs measure is a kind of "tentative" target: it minimizes the objective linearized at μ, ν ↦ ∫ (δF/δμ)(μ, x) dν(x) + λ E_ν[log ν].
• It plays an important role in the convergence analysis.
[Nitanda, Wu, Suzuki (AISTATS2022)] [Chizat (2022)]
Assumption (log-Sobolev inequality): there exists α > 0 such that, for any probability measure ν (abs. cont. w.r.t. p_μ),
KL(ν‖p_μ) ≤ (1/(2α)) E_ν[‖∇ log(ν/p_μ)‖²] (KL-div ≤ Fisher-div / (2α)).
If p_{μ_t} satisfies the LSI condition for all t ≥ 0, then
𝓕(μ_t) − 𝓕(μ*) ≤ exp(−2αλt) (𝓕(μ_0) − 𝓕(μ*)).
This is a non-linear extension of the well-known GLD convergence analysis; cf. the Polyak–Łojasiewicz condition f(x) − f(x*) ≤ C‖∇f(x)‖². The rate of convergence is characterized by the LSI constant α.
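The exponential rate follows from a short standard argument (reconstructed here under the usual normalization, writing 𝓕(μ) = F(μ) + λ E_μ[log μ] and p_μ for the proximal Gibbs measure):

```latex
% Along MFLD, \partial_t \mu_t = \nabla\cdot\bigl(\mu_t \nabla \tfrac{\delta F}{\delta\mu}(\mu_t)\bigr) + \lambda \Delta \mu_t,
% and \nabla \tfrac{\delta \mathcal{F}}{\delta\mu}(\mu_t) = \lambda \nabla \log(\mu_t / p_{\mu_t}), hence
\frac{d}{dt}\,\mathcal{F}(\mu_t)
  = -\lambda^2\, \mathbb{E}_{\mu_t}\!\left[\bigl\|\nabla \log(\mu_t / p_{\mu_t})\bigr\|^2\right].
% LSI for p_{\mu_t}:
\mathrm{KL}(\mu_t \,\|\, p_{\mu_t})
  \le \frac{1}{2\alpha}\, \mathbb{E}_{\mu_t}\!\left[\|\nabla \log(\mu_t / p_{\mu_t})\|^2\right],
% and convexity of F gives
\mathcal{F}(\mu_t) - \mathcal{F}(\mu^*) \le \lambda\, \mathrm{KL}(\mu_t \,\|\, p_{\mu_t}).
% Combining the three displays:
\frac{d}{dt}\bigl(\mathcal{F}(\mu_t)-\mathcal{F}(\mu^*)\bigr)
  \le -2\alpha\lambda\,\bigl(\mathcal{F}(\mu_t)-\mathcal{F}(\mu^*)\bigr)
\;\Longrightarrow\;
\mathcal{F}(\mu_t)-\mathcal{F}(\mu^*) \le e^{-2\alpha\lambda t}\bigl(\mathcal{F}(\mu_0)-\mathcal{F}(\mu^*)\bigr).
```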
So far: the mean field limit and continuous-time dynamics. Question: can we evaluate the finite-particle and discrete-time approximation errors? (Finite particle approximation: the empirical measure of M neurons replaces μ_t, the distribution of X_t, in the vector field.)
Propagation of chaos [Sznitman, 1991; Lacker, 2021]: the particles behave as if they were independent as the number of particles goes to infinity.
• The finite-particle approximation error can be amplified through time → it is difficult to bound the perturbation uniformly over time.
• A naïve evaluation gives exponential growth in time: exp(t)/M [Mei et al. (2018, Theorem 3)].
• Existing work assumes weak interaction or strong regularization.
[Suzuki, Wu, Nitanda: Convergence of mean-field Langevin dynamics: Time and space discretization, stochastic gradient, and variance reduction. arXiv:2306.07221]
Assumptions (+ second-order differentiability):
1. F: 𝒫 → ℝ is convex and has the form F(μ) = L(μ) + λ₁ E_μ[r(X)].
2. (smoothness) ‖∇(δL/δμ)(μ, x) − ∇(δL/δμ)(ν, y)‖ ≤ C(W₂(μ, ν) + ‖x − y‖), and (boundedness) ‖∇(δL/δμ)(μ, x)‖ ≤ R.
Theorem (one-step update) [Suzuki, Wu, Nitanda (2023)]: suppose that the proximal Gibbs measure p_μ satisfies the log-Sobolev inequality with a constant α. Then, under the smoothness and boundedness of the loss above, a one-step descent inequality holds that yields a finite-particle error of O(1/M) uniformly in time, improving on the naïve time-dependent bound.
Let 𝒳 = (X^1, …, X^M) ∼ μ^(M), the joint distribution of M particles, and let μ_𝒳 = (1/M) Σ_{j=1}^M δ_{X^j} be the empirical measure.
Potential of the joint distribution μ^(M) on ℝ^{d×M}:
𝓕^M(μ^(M)) = M E_{𝒳∼μ^(M)}[F(μ_𝒳)] + λ E_{μ^(M)}[log μ^(M)].
• The finite-particle dynamics is the Wasserstein gradient flow that minimizes 𝓕^M.
(Approximate) uniform log-Sobolev inequality [Chen, Ren, Wang. Uniform-in-time propagation of chaos for mean field Langevin dynamics. arXiv:2212.03050, 2022]: for any M, the corresponding proximal Gibbs measure satisfies a log-Sobolev inequality (a KL-divergence vs Fisher-divergence bound) with a constant that does not degrade with M.
We have established convergence of MFLD. → How effective is the feature learning of MFLD in terms of generalization error?
• Benefit of feature learning? Neural network vs kernel method (NTK vs mean field).
Theorem 1: E[Class. Error] ≤ O(exp(−O(n/d²))) if n ≥ d².
Theorem 2: under an appropriate scaling of λ, a corresponding high-probability test error bound holds.
If we have sufficiently large training data, we obtain exponential convergence of the test error. We only need to evaluate one complexity quantity of a good classifier (the KL term R in the lemma for the parity problem) to obtain a test error bound.
Example: sparse k-parity with high-dimensional data.
• x ∼ Unif({−1, 1}^d) (up to freedom of rotation)
• y = Π_{i=1}^k x_i: only the first k coordinates are informative.
※ Suppose that we don't know which coordinates the parity is aligned to.
k = 2: XOR problem. (Figure: the case d = 3, k = 2.)
Complexity of learning the XOR function (k = 2): Table 1 of [Telgarsky: Feature selection and low test error in shallow low-rotation ReLU networks, ICLR2023].
Q: Can we learn sparse k-parity with GD? Is there any benefit of neural networks?
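The k-sparse parity data described above is simple to generate (a minimal sketch; the function name is mine):

```python
import numpy as np

# k-sparse parity data: x ~ Unif({-1, +1}^d), y = x_1 * ... * x_k.
# Only the first k coordinates carry signal; the learner is not told which.
# k = 2 recovers the XOR problem.

def sparse_parity_data(n, d, k, seed=0):
    rng = np.random.default_rng(seed)
    X = rng.choice([-1.0, 1.0], size=(n, d))
    y = np.prod(X[:, :k], axis=1)
    return X, y

# Labels are +/-1 and equal the product of the first k coordinates.
X, y = sparse_parity_data(n=8, d=5, k=2)
```

Note that each single informative coordinate is uncorrelated with y (E[y x_i] = 0 for k ≥ 2), which is what makes this family hard for fixed-feature methods.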
Reminder — Theorem 1: E[Class. Error] ≤ O(exp(−O(n/d²))) if n ≥ d²; Theorem 2 gives the test error bound in terms of the quantity R below.
Lemma: suppose that there exists μ* such that KL(ν‖μ*) ≤ R and y f_{μ*}(x) ≥ m₀ (a perfect classifier with margin m₀). Then the test error bound of Theorem 2 holds with this R.
For the k-parity problem, we may take such a μ* explicitly, which lets us evaluate the R (and hence the sample size n) required for the k-sparse parity problem.
Corollary (test accuracy of MFLD)
• Setting 1 (n > d): test error (classification error) = O(exp(−n/d^k)).
• Setting 2: test error (classification error) = O(d/n).
These are better than the NTK (kernel method): sample complexity of NTK n = Ω(d^k) vs NN n = O(d) — a trade-off between computational complexity and sample complexity.
Our analysis provides:
• better sample complexity,
• a discrete-time / finite-width analysis,
• n and d are "decoupled."
(The computational complexity is exp(O(d)), but it can be relaxed if X is anisotropic.)
• Test error = O(d/n).
• Computational complexity is O(exp(d)).
If the data have an anisotropic covariance, the sample and computational complexities can be much improved. (Figure: true signal under an isotropic vs an anisotropic data distribution.)
Number of iterations: the isotropic case yields exp(d). The data structure affects the complexities.
By substituting, the number of iterations can be summarized as:
• Isotropic setting: exponential in d.
• Anisotropic setting: the exponential dependence is mitigated.
The anisotropic structure mitigates the computational complexity; in particular, there is no exponential dependency on d under the stated condition on the covariance (assuming k = O(1)).
Mean field NN can "decouple" d and k, while kernel methods have an exponentially worse relation between them.
Theorem (kernel lower bound): for arbitrary δ > 0, the sample complexity of kernel methods is lower bounded by a polynomial in d whose degree grows with k.
For isotropic input (avoiding the curse of dimensionality): we may estimate the "informative direction" from the gradients at initialization.
• The resulting statistic G estimates the informative direction.
• By the corresponding coordinate transformation, we can make the complexity independent of d (exp(d) → exp(k)).
(Figure: true signal before and after the transformation.)
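For intuition only — this is a simplified stand-in for the construction on the slide, using a label-weighted second moment rather than the exact gradient statistic G: for XOR (k = 2), E[y x x^T] is zero except at the two informative coordinates, so its empirical version recovers them from data alone.

```python
import numpy as np

def estimate_informative_pair(X, y):
    """Return the pair (i, j), i < j, maximizing |(1/n) sum_i y * x_i * x_j|."""
    n = X.shape[0]
    M = (X * y[:, None]).T @ X / n       # empirical E[y x x^T]
    np.fill_diagonal(M, 0.0)             # diagonal carries no pair information
    i, j = np.unravel_index(np.argmax(np.abs(M)), M.shape)
    return min(i, j), max(i, j)

rng = np.random.default_rng(0)
n, d = 4000, 10
X = rng.choice([-1.0, 1.0], size=(n, d))
y = X[:, 0] * X[:, 1]                    # XOR on the first two coordinates
```

Here M[0, 1] equals exactly 1 while all other off-diagonal entries are O(1/√n) noise, so the informative pair stands out clearly; the gradient-based statistic in the talk plays the analogous role for general k.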
Note: the Ω(d^{k−1}) sample complexity is optimal among methods with polynomial-order computational complexity [Abbe et al. (2023); Refinetti et al. (2021); Ben Arous et al. (2022); Damian et al. (2022)]. On the other hand, our analysis is about full-batch GD.

                         Minibatch size   # of iterations   Sample complexity
  Our analysis           n                d^k               d
  SGD (CSQ lower bound)  1                d^{k−1}           d^{k−1}

We obtain a better sample complexity than O(d^{k−1}) at the cost of higher computational complexity. → With MFLD we can obtain a polynomial-order method for anisotropic input.
Summary
[Optimization of 2-layer NNs]
• Optimizing a convex functional.
• Convergence guarantee (Wasserstein gradient flow, uniform-in-time propagation of chaos).
[Generalization error of mean field 2-layer NNs]
• Fast learning rate.
• Sparse k-parity problem: better sample complexity than kernel methods.
• The structure of the data (anisotropic covariance) can improve the complexities.
(Figure: mean field vs kernel, with the kernel lower bounds.)