Wuchen Li
March 23, 2024

Numerical Analysis on Neural Network Projected Schemes for Approximating One Dimensional Wasserstein Gradient Flows

We provide a numerical analysis and computation of neural network projected schemes for approximating one dimensional Wasserstein gradient flows. We approximate the Lagrangian mapping functions of gradient flows by the class of two-layer neural network functions with ReLU (rectified linear unit) activation functions. The numerical scheme is based on a projected gradient method, namely the Wasserstein natural gradient, where the projection is constructed from the mapping spaces onto the neural network parameterized mapping space. We establish theoretical guarantees for the performance of the neural projected dynamics. We derive a closed-form update for the scheme with well-posedness and explicit consistency guarantee for a particular choice of network structure. General truncation error analysis is also established on the basis of the projective nature of the dynamics. Numerical examples, including gradient drift Fokker-Planck equations, porous medium equations, and Keller-Segel models, verify the accuracy and effectiveness of the proposed neural projected algorithm.

Transcript

  1. Numerical analysis on neural network projected schemes for approximating one dimensional Wasserstein gradient flows. Wuchen Li, University of South Carolina. AMS sectional meeting, Florida State University, March 24, 2024. Supported by AFOSR MURI and YIP, and NSF RTG and FRG.
  2. Samples from image distributions. Figure: results from Wasserstein generative adversarial networks; they are samples from image distributions π.
  3. Generative probability models. Consider a class of invertible push-forward maps $\{f_\theta\}_{\theta\in\Theta}$, indexed by a parameter $\theta \in \Theta \subset \mathbb{R}^D$, with $f_\theta \colon \mathbb{R}^d \to \mathbb{R}^d$. We obtain a family of parametric distributions
      \[
      P_\Theta = \big\{ \rho_\theta = f_{\theta\#} p \mid \theta \in \Theta \big\} \subset (\mathcal{P}, g_W).
      \]
  4. Main question: neural network projected dynamics and their numerical analysis. (i) What are neural network projected dynamics for Wasserstein gradient flows in Lagrangian coordinates? (ii) In 1D, how well do the neural network projected dynamics approximate the Wasserstein gradient flow?
  5. Background: natural gradient methods. Natural gradient methods compute the gradient flow of a functional on the probability manifold P by projecting it onto a parameter space Θ, equipped with the metric tensor obtained by pulling back the canonical metric defined on P.
     • First introduced in [Amari, 1998], using the Fisher-Rao metric on P;
     • the natural gradient method on the Wasserstein manifold was introduced in [Li, Guido, 2019] and [Chen, Li, 2020];
     • studies of the Wasserstein information matrix (i.e., the pulled-back metric tensor on Θ) were conducted in [Li, Zhao, 2020], [Liu, et al., 2022], [Li, Zhao, 2023], etc.
  6. Wasserstein natural gradient flows.
     • Survey paper from Carrillo et al., Lagrangian schemes for Wasserstein gradient flows, 2021.
     • Lin et al., Wasserstein Proximal of GANs, 2021.
     • Liu et al., Neural Parametric Fokker-Planck Equations, 2022.
     • Lee et al., Deep JKO: time-implicit particle methods for general nonlinear gradient flows, 2023.
     This is an active area in the scientific computing and data science communities; see [Chen Xu, et al., 2022], [Mokrov, et al., 2021], [Fan, et al., 2022], [Bonet, et al., 2022], [Hertrich, et al., 2023]. For general evolutionary dynamics, see [Du et al., 2021], [Anderson et al., 2021], [Bruna et al., 2022], [Gaby et al., 2023], etc.
  7. "Least squares" in fluid dynamics. Optimal transport has an optimal control formulation as follows:
      \[
      \inf_{\rho_t\in \mathcal{P}(\Omega)} \int_0^1 g_W(\partial_t\rho_t, \partial_t\rho_t)\,dt
      = \int_0^1\!\!\int_\Omega (\nabla\Phi_t, \nabla\Phi_t)\,\rho_t\,dx\,dt,
      \]
      under the dynamical constraint, i.e., the continuity equation:
      \[
      \partial_t\rho_t + \nabla\cdot(\rho_t\nabla\Phi_t) = 0,\qquad \rho_0 = \rho^0,\quad \rho_1 = \rho^1.
      \]
      Here, $(\mathcal{P}(\Omega), g_W)$ forms an infinite-dimensional Riemannian manifold (John D. Lafferty: The density manifold and configuration space quantization, 1988) with the $L^2$-Wasserstein metric $g_W(p) = (-\nabla\cdot(p\nabla))^{-1}$.
  8. Wasserstein gradient flow (Eulerian formulation). Denote by grad_W the gradient operator on (P, g_W). Consider a smooth functional F: P → R. The Wasserstein gradient flow of F(p) is the time-evolution PDE:
      \[
      \partial_t p = -\mathrm{grad}_W F(p) \triangleq -g_W(p)^{-1}\frac{\delta}{\delta p}F(p)
      = \nabla\cdot\Big(p\,\nabla\frac{\delta}{\delta p}F(p)\Big). \tag{1}
      \]
  9. Famous example from Jordan-Kinderlehrer-Otto. Consider a free energy:
      \[
      F(p) = \int V(x)\,p(x)\,dx + \gamma\int p(x)\log p(x)\,dx.
      \]
      In this case, the optimal transport gradient flow satisfies
      \[
      \frac{\partial p}{\partial t} = -g_W(p)^{-1}\frac{\delta}{\delta p}F(p)
      = \nabla\cdot(p\nabla V) + \gamma\,\nabla\cdot(p\nabla\log p)
      = \nabla\cdot(p\nabla V) + \gamma\,\Delta p,
      \]
      where $p\nabla\log p = \nabla p$. In this case, the gradient flow is the Fokker-Planck equation.
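    For the reader, the intermediate first-variation step (a standard computation, spelled out here for completeness; it is not written on the slide) is
      \[
      \frac{\delta}{\delta p}F(p) = V(x) + \gamma\big(\log p(x) + 1\big),
      \qquad
      \nabla\cdot\Big(p\,\nabla\frac{\delta}{\delta p}F(p)\Big)
      = \nabla\cdot\big(p\nabla V\big) + \gamma\,\nabla\cdot\big(p\nabla\log p\big)
      = \nabla\cdot\big(p\nabla V\big) + \gamma\,\Delta p .
      \]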
  10. Wasserstein gradient flow (Lagrangian formulation). Define the pushforward of p_r by T as
      \[
      p(x) = T_{\#}p_r(x) \triangleq \frac{p_r}{|\det(D_z T)|}\circ T^{-1}(x).
      \]
      Then F(p) can be viewed as a functional of T, $F_\#(T) \triangleq F(T_\# p_r)$. The Wasserstein gradient flow of F_#(T) yields the dynamics:
      \[
      \partial_t T(t,\cdot) = -\mathrm{grad}_W F_{\#}(T(t,\cdot)) \triangleq -\frac{1}{p_r(\cdot)}\,\frac{\delta F_{\#}(T(t,\cdot))}{\delta T}(\cdot). \tag{2}
      \]
  11. Examples in Eulerian and Lagrangian coordinates. Fact [Carrillo, et al. 2021]: suppose p(t, ·) and T(t, ·) solve equations (1) and (2), respectively, and assume that the initial data are consistent, i.e., p(0, ·) = T(0, ·)_# p_r. Then p(t, ·) = T(t, ·)_# p_r for t > 0. Examples.
  12. Neural network functions. Consider a mapping function f: Z × Θ → Ω, where Z ⊂ R^l is the latent space, Ω ⊂ R^d is the sample space, and Θ ⊂ R^D is the parameter space:
      \[
      f(\theta, z) = \frac{1}{N}\sum_{i=1}^N a_i\,\sigma(z - b_i),
      \]
      where θ = (a_i, b_i) ∈ R^D with D = (l + 1)N. Here N is the number of hidden units (neurons), a_i ∈ R is the weight of unit i, b_i ∈ R^l is an offset (location variable), and σ: R → R is an activation function satisfying σ(0) = 0 and 1 ∈ ∂σ(0). Example (ReLU): let σ(x) = max{x, 0}. Suppose N = d = 1; then f(θ, z) = θ max{z, 0} with θ ∈ R_+. A minimal code sketch of this mapping follows below.
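    As an illustration only (not the authors' code), here is a minimal NumPy sketch of this two-layer ReLU mapping in one dimension; the names relu and neural_map are ours:

        import numpy as np

        def relu(x):
            # sigma(x) = max(x, 0)
            return np.maximum(x, 0.0)

        def neural_map(a, b, z):
            """Two-layer network mapping f(theta, z) = (1/N) * sum_i a_i * relu(z - b_i),
            with theta = (a, b); a and b have shape (N,), z has shape (M,)."""
            N = a.shape[0]
            # relu(z - b) has shape (M, N); contract against the weights a and average over units
            return relu(z[:, None] - b[None, :]) @ a / N

        # Sanity check of the ReLU example: N = d = 1 and a = theta > 0 gives f(theta, z) = theta * max(z, 0)
        print(neural_map(np.array([2.0]), np.array([0.0]), np.linspace(-1.0, 1.0, 5)))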
  13. Models. Definition (Neural mapping models). Let us define a fixed input reference probability density
      \[
      p_r \in \mathcal{P}(Z) = \Big\{ p(z) \in C^\infty(Z)\colon \int_Z p(z)\,dz = 1,\ p(z) \ge 0 \Big\}.
      \]
      Denote a probability density generated by a neural network mapping function by the pushforward operator, p = f_{θ#} p_r ∈ P(Ω). In other words, p satisfies the Monge-Ampère equation
      \[
      p(f(\theta, z))\,\det(D_z f(\theta, z)) = p_r(z),
      \]
      where D_z f(θ, z) is the Jacobian of the mapping function f(θ, z) with respect to the variable z.
  14. Energies. Definition (Neural mapping energies). Given an energy functional F: P(Ω) → R, we construct a neural mapping energy F: Θ → R by F(θ) = F(f_{θ#} p_r).
  15. Neural information distance. Definition (Neural mapping distance). Define a distance function Dist_W: Θ × Θ → R as
      \[
      \mathrm{Dist}_W(f_{\theta^0\#}p_r, f_{\theta^1\#}p_r)^2
      = \int_Z \big\| f(\theta^0, z) - f(\theta^1, z) \big\|^2 p_r(z)\,dz
      = \sum_{m=1}^d \mathbb{E}_{z\sim p_r} \big\| f_m(\theta^0, z) - f_m(\theta^1, z) \big\|^2,
      \]
      where θ^0, θ^1 ∈ Θ are two sets of neural network parameters and ∥·∥ is the Euclidean norm in R^d.
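    In dimension d = 1 this distance can be estimated by Monte Carlo with samples from p_r; a sketch reusing the hypothetical neural_map helper from above (the standard Gaussian reference density is our illustrative choice):

        def dist_W_squared(a0, b0, a1, b1, z_samples):
            """Estimate Dist_W(f_{theta0 #} p_r, f_{theta1 #} p_r)^2
            = E_{z ~ p_r} |f(theta0, z) - f(theta1, z)|^2   (d = 1)."""
            diff = neural_map(a0, b0, z_samples) - neural_map(a1, b1, z_samples)
            return float(np.mean(diff ** 2))

        rng = np.random.default_rng(0)
        z = rng.standard_normal(100_000)          # samples z ~ p_r (Gaussian reference, illustrative)
        print(dist_W_squared(np.array([1.0]), np.array([0.0]),
                             np.array([2.0]), np.array([0.0]), z))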
  16. Neural information matrix. Consider the Taylor expansion of the distance function; see also the sketch below. Let Δθ ∈ R^D; then
      \[
      \mathrm{Dist}_W(f_{(\theta+\Delta\theta)\#}p_r, f_{\theta\#}p_r)^2
      = \sum_{m=1}^d \mathbb{E}_{z\sim p_r}\big\| f_m(\theta+\Delta\theta, z) - f_m(\theta, z) \big\|^2
      = \sum_{m=1}^d\sum_{i=1}^D\sum_{j=1}^D \mathbb{E}_{z\sim p_r}\big[\partial_{\theta_i} f_m(\theta, z)\,\partial_{\theta_j} f_m(\theta, z)\big]\,\Delta\theta_i\,\Delta\theta_j + o(\|\Delta\theta\|^2)
      = \Delta\theta^{\mathsf T} G_W(\theta)\,\Delta\theta + o(\|\Delta\theta\|^2).
      \]
      Here G_W is a Gram-type matrix function:
      \[
      G_W(\theta) = \int_Z \nabla_\theta f(\theta, z)\,\nabla_\theta f(\theta, z)^{\mathsf T}\, p_r(z)\,dz.
      \]
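    A Monte Carlo estimate of G_W(θ) for the one-dimensional ReLU mapping sketched earlier; grad_theta_f stacks the partial derivatives with respect to (a, b), and the subgradient convention 1[z > b_i] at the ReLU kink is our choice:

        def grad_theta_f(a, b, z):
            """Rows are samples; columns are [df/da_1 .. df/da_N, df/db_1 .. df/db_N] for
            f(theta, z) = (1/N) sum_i a_i relu(z - b_i)."""
            N = a.shape[0]
            df_da = np.maximum(z[:, None] - b[None, :], 0.0) / N      # relu(z - b_i) / N
            df_db = -(z[:, None] > b[None, :]) * a[None, :] / N       # -a_i * 1[z > b_i] / N
            return np.concatenate([df_da, df_db], axis=1)

        def information_matrix(a, b, z_samples):
            """Monte Carlo estimate of G_W(theta) = E_{z ~ p_r}[grad_theta f grad_theta f^T]."""
            J = grad_theta_f(a, b, z_samples)                         # shape (M, D), D = 2N
            return J.T @ J / z_samples.shape[0]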
  17. Neural network projected gradient flows. Consider an energy functional F: P(Ω) → R. Then the gradient flow of the function F(θ) = F(f_{θ#} p_r) in (Θ, G_W) is given by
      \[
      \frac{d\theta}{dt} = -\mathrm{grad}_W F(\theta) = -G_W(\theta)^{-1}\nabla_\theta F(\theta).
      \]
      In particular,
      \[
      \frac{d\theta_i}{dt} = -\sum_{j=1}^D \sum_{m=1}^d
      \Big( \mathbb{E}_{z\sim p_r}\big[\nabla_\theta f(\theta, z)\nabla_\theta f(\theta, z)^{\mathsf T}\big] \Big)^{-1}_{ij}
      \,\mathbb{E}_{\tilde z\sim p_r}\Big[ \nabla_{x_m}\frac{\delta}{\delta p}F(p)(f(\theta, \tilde z))\cdot \partial_{\theta_j} f_m(\theta, \tilde z) \Big],
      \]
      where δ/δp(x) is the L^2 first variation with respect to the variable p(x), and x = f(θ, z).
  18. Example I: linear energy. For the potential energy
      \[
      F(p) = \int_\Omega V(x)\,p(x)\,dx,
      \]
      the Wasserstein gradient flow satisfies
      \[
      \partial_t p(t,x) = \nabla_x\cdot\big(p(t,x)\nabla_x V(x)\big).
      \]
      The neural projected dynamics satisfies
      \[
      \frac{d\theta_i}{dt} = -\sum_{j=1}^D \Big(\mathbb{E}_{z\sim p_r}\big[\nabla_\theta f(\theta,z)\nabla_\theta f(\theta,z)^{\mathsf T}\big]\Big)^{-1}_{ij}
      \,\mathbb{E}_{\tilde z\sim p_r}\big[\nabla_x V(f(\theta,\tilde z))\cdot \partial_{\theta_j} f(\theta,\tilde z)\big].
      \]
      A Monte Carlo sketch of this gradient appears below.
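    For this linear (potential) energy, the Euclidean gradient ∇_θ F(θ) = E_{z∼p_r}[V'(f(θ, z)) ∇_θ f(θ, z)] can be estimated by Monte Carlo. A sketch in 1D reusing the hypothetical helpers neural_map and grad_theta_f defined above:

        def potential_energy_grad(a, b, z_samples, V_prime):
            """Monte Carlo estimate of grad_theta F(theta) for F(p) = int V(x) p(x) dx, i.e.
            E_{z ~ p_r}[ V'(f(theta, z)) * grad_theta f(theta, z) ]."""
            x = neural_map(a, b, z_samples)                # pushforward samples x = f(theta, z)
            J = grad_theta_f(a, b, z_samples)              # shape (M, D)
            return J.T @ V_prime(x) / z_samples.shape[0]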
  19. Example II: interaction energy. For the interaction energy
      \[
      F(p) = \frac{1}{2}\int_\Omega\int_\Omega W(x,y)\,p(x)\,p(y)\,dx\,dy,
      \]
      the Wasserstein gradient flow satisfies
      \[
      \partial_t p(t,x) = \nabla_x\cdot\Big(p(t,x)\int_\Omega \nabla_x W(x,y)\,p(t,y)\,dy\Big).
      \]
      The neural projected dynamics satisfies
      \[
      \frac{d\theta_i}{dt} = -\sum_{j=1}^D \Big(\mathbb{E}_{z\sim p_r}\big[\nabla_\theta f(\theta,z)\nabla_\theta f(\theta,z)^{\mathsf T}\big]\Big)^{-1}_{ij}
      \,\mathbb{E}_{(z_1,z_2)\sim p_r\times p_r}\big[\nabla_{x_1} W(f(\theta,z_1), f(\theta,z_2))\cdot \partial_{\theta_j} f(\theta,z_1)\big].
      \]
  20. Example III: internal energy. For the internal energy
      \[
      F(p) = \int_\Omega U(p(x))\,dx,
      \]
      the Wasserstein gradient flow satisfies
      \[
      \partial_t p(t,x) = \nabla_x\cdot\big(p(t,x)\nabla_x U'(p(t,x))\big).
      \]
      The neural projected dynamics satisfies
      \[
      \frac{d\theta_i}{dt} = -\sum_{j=1}^D \Big(\mathbb{E}_{z\sim p_r}\big[\nabla_\theta f(\theta,z)\nabla_\theta f(\theta,z)^{\mathsf T}\big]\Big)^{-1}_{ij}
      \,\mathbb{E}_{z\sim p_r}\Big[ -\mathrm{tr}\big(D_z f(\theta,z)^{-1} : \partial_{\theta_j} D_z f(\theta,z)\big)
      \,\hat U'\Big(\frac{p_r(z)}{\det(D_z f(\theta,z))}\Big)\frac{p_r(z)}{\det(D_z f(\theta,z))} \Big],
      \]
      where $\hat U(z) = \frac{1}{z}U(z)$.
  21. Algorithm. We summarize the above explicit update formulas below; a code sketch follows this item.
      Algorithm: Projected Wasserstein gradient flows.
      Input: initial parameters θ ∈ R^D; step size h > 0; total number of steps L; samples {z_i}_{i=1}^M ∼ p_r for estimating G̃_W(θ) and ∇_θ F̃(θ).
      for k = 1, 2, ..., L do
          θ_{k+1} = θ_k − h G̃_W(θ_k)^{-1} ∇_θ F̃(θ_k);
      end for
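    A hedged NumPy sketch of this loop for the potential-energy example, reusing the hypothetical helpers above (grad_theta_f, information_matrix, potential_energy_grad); the Gaussian reference density and the tiny regularization of the estimated G̃_W are our additions for the sake of a runnable example, not part of the stated algorithm:

        def projected_wasserstein_flow(a, b, V_prime, h=1e-3, L=1000, M=10_000, seed=0):
            """Forward Euler on theta: theta_{k+1} = theta_k - h * G_W(theta_k)^{-1} grad_theta F(theta_k)."""
            rng = np.random.default_rng(seed)
            N = a.shape[0]
            theta = np.concatenate([a, b]).astype(float)
            for _ in range(L):
                z = rng.standard_normal(M)                 # fresh samples z ~ p_r at every step
                a_k, b_k = theta[:N], theta[N:]
                G = information_matrix(a_k, b_k, z)
                g = potential_energy_grad(a_k, b_k, z, V_prime)
                # Solve G v = g rather than forming the inverse; small Tikhonov term for stability.
                v = np.linalg.solve(G + 1e-10 * np.eye(G.shape[0]), g)
                theta = theta - h * v
            return theta[:N], theta[N:]

        # Example call with the quadratic potential V(x) = x^2 / 2, so V'(x) = x:
        # a_T, b_T = projected_wasserstein_flow(np.ones(16) / 16, np.linspace(-4.0, 4.0, 16), lambda x: x)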
  22. Numerical analysis. What is the order of consistency of the proposed method, starting from a 1D sample space?
  23. Analytic formula for the NIM. For the special case of a two-layer ReLU neural network with
      \[
      f(\theta, z) = \frac{1}{N}\sum_{i=1}^N a_i\,\sigma(z - b_i),
      \]
      where only the b_i, i = 1, ..., N, can vary, we prove that the inverse matrix has the closed form
      \[
      \frac{1}{N^2}\big(G_W^{-1}(b)\big)_{ij} =
      \begin{cases}
      \dfrac{1}{a_i^2}\Big(\dfrac{1}{F_0(b_{i+1}) - F_0(b_i)} + \dfrac{1}{F_0(b_i) - F_0(b_{i-1})}\Big), & j = i,\\[6pt]
      -\dfrac{1}{a_i a_{i-1}}\,\dfrac{1}{F_0(b_i) - F_0(b_{i-1})}, & j = i-1,\\[6pt]
      -\dfrac{1}{a_i a_{i+1}}\,\dfrac{1}{F_0(b_{i+1}) - F_0(b_i)}, & j = i+1,\\[6pt]
      0, & \text{otherwise},
      \end{cases}
      \]
      where F_0(·) is the cumulative distribution function of p_r(·), i.e., $F_0(x) = \int_{-\infty}^x p_r(y)\,dy$.
  24. Consistency via the analytic formula. Consider
      \[
      \partial_t p(t,x) = \nabla\cdot\big(p(t,x)\nabla V(x)\big).
      \]
      Its neural projected dynamics satisfies
      \[
      \frac{d\theta}{dt} = -G_W^{-1}(\theta)\,\mathbb{E}_{\tilde z\sim p_r}\big[\nabla_\theta V(f(\theta,\tilde z))\big].
      \]
      The above ODE has a closed-form update:
      \[
      \dot b_i = \frac{N}{a_i}\left( \frac{\mathbb{E}_{z\sim p_r}\big[V'(f(b,z))\,\mathbf{1}_{[b_i, b_{i+1}]}\big]}{F_0(b_{i+1}) - F_0(b_i)}
      - \frac{\mathbb{E}_{z\sim p_r}\big[V'(f(b,z))\,\mathbf{1}_{[b_{i-1}, b_i]}\big]}{F_0(b_i) - F_0(b_{i-1})}\right).
      \]
      We need to show that this closed-form numerical scheme is consistent with the original PDE above; a Monte Carlo sketch of the update follows below. Proposition (Consistency of the projected potential flow): assume the potential satisfies ‖V''‖_∞ < ∞. Then the spatial discretization is of first-order accuracy in both the mapping and the density coordinates.
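    A Monte Carlo sketch of this closed-form update for the b-parameters (only b varies, a is fixed). The standard Gaussian reference density and the boundary convention b_0 = −∞, b_{N+1} = +∞ are our assumptions, made only to have a runnable example:

        import numpy as np
        from scipy.stats import norm                        # F_0 = CDF of p_r; Gaussian p_r assumed here

        def b_dot_closed_form(a, b, V_prime, z_samples):
            """Right-hand side of the closed-form ODE for b under the potential energy:
            bdot_i = (N/a_i) * ( E[V'(f) 1_{[b_i, b_{i+1}]}] / (F0(b_{i+1}) - F0(b_i))
                                 - E[V'(f) 1_{[b_{i-1}, b_i]}] / (F0(b_i) - F0(b_{i-1})) )."""
            N = a.shape[0]
            order = np.argsort(b)                           # the formula assumes b_1 < b_2 < ... < b_N
            a_s, b_s = a[order], b[order]
            x = np.maximum(z_samples[:, None] - b[None, :], 0.0) @ a / N    # f(b, z)
            Vp = V_prime(x)
            b_pad = np.concatenate([[-np.inf], b_s, [np.inf]])              # boundary convention (ours)
            F0 = norm.cdf(b_pad)
            bdot_s = np.zeros(N)
            for i in range(N):
                right = np.mean(Vp * ((z_samples >= b_pad[i + 1]) & (z_samples <= b_pad[i + 2])))
                left = np.mean(Vp * ((z_samples >= b_pad[i]) & (z_samples <= b_pad[i + 1])))
                bdot_s[i] = (N / a_s[i]) * (right / (F0[i + 2] - F0[i + 1])
                                            - left / (F0[i + 1] - F0[i]))
            bdot = np.zeros(N)
            bdot[order] = bdot_s                            # return in the original ordering of b
            return bdot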
  25. Well-posedness for the projected heat flow. The projected gradient flow of the entropy functional, i.e. the heat flow, is given by
      \[
      \dot b_i = \frac{1}{F_0(b_i) - F_0(b_{i-1})}
      \left[ \log\Big(\frac{\sum_{j=1}^{i-1} a_j}{\sum_{j=1}^{i} a_j}\Big)\frac{p_r(b_i)}{a_i^2}
      - \log\Big(\frac{\sum_{j=1}^{i-2} a_j}{\sum_{j=1}^{i-1} a_j}\Big)\frac{p_r(b_{i-1})}{a_i a_{i-1}} \right]
      + \frac{1}{F_0(b_{i+1}) - F_0(b_i)}
      \left[ \log\Big(\frac{\sum_{j=1}^{i-1} a_j}{\sum_{j=1}^{i} a_j}\Big)\frac{p_r(b_i)}{a_i^2}
      - \log\Big(\frac{\sum_{j=1}^{i} a_j}{\sum_{j=1}^{i+1} a_j}\Big)\frac{p_r(b_{i+1})}{a_i a_{i+1}} \right].
      \]
      Although the original heat equation is a linear PDE, this is a nonlinear ODE; nevertheless, we can prove the well-posedness of the projected dynamics by analyzing the behavior of adjacent modes. Proposition: the neural projected dynamics of the heat flow is well-posed, i.e., the solution extends to arbitrary time.
  26. Consistency analysis for general cases. The previous analysis relies on either the analytic formula or a specific gradient flow equation. We can also establish consistency by viewing our method as a projection-based reduced-order model: the projected gradient is chosen to minimize the Wasserstein inner product,
      \[
      \nabla_\theta H(\theta) = \operatorname*{argmin}_{v\in T_\theta\Theta} \int \big(\nabla_\theta H(\theta)(x) - v(x)\big)^2 f_{\theta\#}p_r(x)\,dx.
      \]
      The L^2-error between the projected gradient and the original gradient is bounded by
      \[
      \big\| v(x) - \nabla_\theta H(\theta) \big\|^2_{L^2(f_{\theta\#}p_r)}
      = \frac{1}{4}\Big(\frac{\sum_{j=1}^N a_j}{N}\Big)^2 \|v''\|_\infty\, O(\Delta b^4).
      \]
      Proposition: the numerical scheme based on the ReLU network mapping is consistent of order 2 when both the a and b parameters are used, and of order 1 with either the a or the b parameters alone.
  27. Neural network structure. We focus on the two-layer neural network with ReLU activation functions,
      \[
      f(\theta, z) = \sum_{i=1}^N a_i\,\sigma(z - b_i) + \sum_{i=N+1}^{2N} a_i\,\sigma(b_i - z).
      \]
      Here θ ∈ R^{4N} represents the collection of weights {a_i}_{i=1}^{2N} and biases {b_i}_{i=1}^{2N}. At initialization, we set a_i = 1/N for i ∈ {1, ..., N} and a_i = −1/N for i ∈ {N+1, ..., 2N}. To choose the b_i, we first set b = linspace(−B, B, N) for some positive constant B (e.g. B = 4 or B = 10). We then set b_i = b[i] for i = 1, ..., N and b_j = b[j − N] + ε for j = N+1, ..., 2N, where ε = 5 × 10^{−6} is a small offset. The initialization is chosen such that f(θ, ·) approximates the identity map at initialization; a code sketch follows below.
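    A sketch of this two-sided structure and initialization (function names are ours); the quick check at the end verifies that the initial map is close to the identity on the grid range:

        import numpy as np

        def initialize_parameters(N, B=4.0, eps=5e-6):
            """a_i = 1/N for the first N units, a_i = -1/N for the second N units;
            b on a uniform grid in [-B, B], with the second half shifted by a small offset eps."""
            a = np.concatenate([np.full(N, 1.0 / N), np.full(N, -1.0 / N)])
            grid = np.linspace(-B, B, N)
            b = np.concatenate([grid, grid + eps])
            return a, b

        def forward(a, b, z, N):
            """f(theta, z) = sum_{i<=N} a_i relu(z - b_i) + sum_{i>N} a_i relu(b_i - z)."""
            first = np.maximum(z[:, None] - b[None, :N], 0.0) @ a[:N]
            second = np.maximum(b[None, N:] - z[:, None], 0.0) @ a[N:]
            return first + second

        a, b = initialize_parameters(64)
        z = np.linspace(-3.5, 3.5, 7)
        print(np.max(np.abs(forward(a, b, z, 64) - z)))     # should be close to zero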
  28. Linear transport equation. Consider V(x) = (x − 1)^4/4 − (x − 1)^2/2.
      Figure: Left: log-log plot for the linear transport PDE with a quartic polynomial potential; the y-axis is log10(error) and the x-axis is log10(N). Right: mapping comparison between the analytic solution and our computed solution.
  29. Linear transport equation. Consider V(x) = (x − 4)^6/6.
      Figure: Left: log-log plot for the linear transport PDE with a sixth-order polynomial potential; the y-axis is log10(error) and the x-axis is log10(N). Right: mapping comparison between the analytic solution and our computed solution.
  30. Fokker-Planck equation. Consider V(x) = x^2/2.
      Figure (panels t = 0.00 to t = 1.00): density evolution of the Fokker-Planck equation with a quadratic potential. The orange curve is the analytic solution; blue rectangles are the histogram of 10^6 particles in 100 bins from t = 0 to t = 1. Left panel (a): a Gaussian distribution shifting to the right with increasing variance. Right panel (b): a Gaussian distribution shifting to the right with decreasing variance.
  31. Fokker-Planck equation. Consider V(x) = (x − 1)^4/4 − (x − 1)^2/2.
      Figure: Left: log-log plot for the Fokker-Planck equation with a quartic polynomial potential; the y-axis is log10(error) and the x-axis is log10(N). Red lines: results when only the weight terms are updated; black lines: results when both weights and biases are updated. Middle: mapping comparison between T(t, z) (from an accurate numerical solver) and our computed solution f(θ_t, z). Right: density evolution of the Fokker-Planck equation with a quartic polynomial potential; the orange curve is the density p(t, x) computed by a numerical solver, and blue rectangles are the histogram of 10^6 particles in 100 bins from t = 0 to t = 0.2.
  32. Fokker-Planck equation. Consider V(x) = (x − 1)^6/6.
      Figure: Left: log-log plot for the Fokker-Planck equation with a sixth-order polynomial potential; the y-axis is log10(error) and the x-axis is log10(N). The bias terms b_i are initialized with B = 10 for the dashed lines and B = 4 for the solid lines. Red lines: results when only the weight terms are updated; black lines: results when both weights and biases are updated. Middle: mapping comparison between the analytic solution (from an accurate numerical solver) and our computed solution f(θ_t, z). Right: density evolution; the orange curve is the density p(t, x) computed by a numerical solver, and blue rectangles are the histogram of 10^6 particles in 100 bins from t = 0 to t = 10^{−3}.
  33. Porous medium equation. Consider the functional $U(p(x)) = \frac{1}{m-1}p(x)^m$ with m = 2. This choice of U yields the porous medium equation
      \[
      \partial_t p(t,x) = \Delta\big(p(t,x)^m\big). \tag{3}
      \]
      Figure: Left: log-log plot for the porous medium equation; the y-axis is log10(error) and the x-axis is log10(N). Red lines: results when only the weight terms are updated; black lines: results when both weights and biases are updated. Middle: mapping comparison between the analytic solution (from an accurate numerical solver) and our computed solution f(θ_t, z). Right: density evolution of the porous medium equation.
  34. Keller-Segel equation.
      Figure (panels t = 0.00 to t = 0.30; (a) χ = 1.5, (b) χ = 0.5): density evolution of the Keller-Segel equation for different values of χ. Blue rectangles are the histogram of 10^6 particles in 100 bins from t = 0 to t = 0.3.
  35. Discussions. (i) Accuracy of neural networks in high-dimensional (at least 2D) initial value PDE problems. (ii) Stiffness analysis of the matrix G_W(θ) and the gradient operator ∇_θ F(θ) for neural network projected dynamics. (iii) Extending the current computation and analysis to general non-gradient Wasserstein-type dynamics, such as GENERIC in complex systems. (iv) Selection of neural network structures for accuracy and long-time behavior of evolutionary PDEs.