Slide 1

Slide 1 text

Wasserstein proximal learning

Slide 2

Slide 2 text

Hopf-Lax formula

Slide 3

Slide 3 text

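Proximal = Hopf-Lax
Euclidean space:
$$\min_x f(x)$$
$$u(1, y) = \inf_x f(x) + \frac{\|y - x\|^2}{2}$$
$$\partial_t u + \frac{1}{2} \|\nabla u\|^2 = 0, \qquad u(0, x) = f(x)$$
Density space:
$$\min_\rho F(\rho)$$
$$U(1, \mu) = \inf_\rho F(\rho) + \frac{\mathrm{dist}_W(\rho, \mu)^2}{2}$$
$$\partial_t U + \frac{1}{2} \int (\nabla \delta_\rho U)^2 \rho(x)\,dx = 0, \qquad U(0, \rho) = F(\rho)$$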
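The Euclidean half of this slide can be checked numerically in one dimension. The sketch below is illustrative only: it assumes the particular objective f(x) = x^2 / 2 (not taken from the slides), approximates the infimum by a plain grid search, and estimates the Hamilton-Jacobi residual with central finite differences. The names `hopf_lax`, `f`, and `hj_residual` are hypothetical helpers.

```python
def hopf_lax(f, y, t, lo=-5.0, hi=5.0, n=10001):
    """Approximate u(t, y) = inf_x f(x) + (y - x)^2 / (2 t) by grid search over x."""
    step = (hi - lo) / (n - 1)
    best = float("inf")
    for i in range(n):
        x = lo + i * step
        best = min(best, f(x) + (y - x) ** 2 / (2.0 * t))
    return best


def f(x):
    """Illustrative objective (an assumption, not from the slides): f(x) = x^2 / 2."""
    return 0.5 * x * x


def hj_residual(y, t, eps=1e-2):
    """Central-difference estimate of d/dt u + 0.5 * (d/dy u)^2 at (t, y)."""
    u_t = (hopf_lax(f, y, t + eps) - hopf_lax(f, y, t - eps)) / (2.0 * eps)
    u_y = (hopf_lax(f, y + eps, t) - hopf_lax(f, y - eps, t)) / (2.0 * eps)
    return u_t + 0.5 * u_y ** 2


y, t = 1.5, 1.0
print(hopf_lax(f, y, t))          # grid-search value of u(1, y)
print(y * y / (2.0 * (1.0 + t)))  # closed form y^2 / (2 (1 + t)) for this particular f
print(hj_residual(y, t))          # expected to be close to zero
```

For this choice of f the proximal value at t = 1 has the closed form y^2 / 4, so the first two printed numbers should agree up to the grid resolution.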

Slide 4

Slide 4 text

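Math Review: mean field games
E.g. (Burgers') Hamilton-Jacobi on density space:
$$\frac{\partial}{\partial t} U(t, \mu) + \frac{1}{2} \int_{\mathbb{T}^d} (\nabla \delta_\mu U(t, \mu))^2 \mu(x)\,dx = 0, \qquad U(0, \mu) = F(\mu)$$
Hopf-Lax on density space:
$$U(t, \mu) = \inf_{\rho \in \mathcal{P}(\mathbb{T}^d)} F(\rho) + \frac{d_W(\rho, \mu)^2}{2t}$$
Characteristics on density space (Nash equilibrium in mean field games):
$$\partial_t \mu_s + \nabla \cdot (\mu_s \nabla \Phi_s) = 0, \qquad \partial_t \Phi_s + \frac{1}{2} (\nabla \Phi_s)^2 = 0$$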

Slide 5

Slide 5 text

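Example III: Generative Adversarial Networks
For each parameter $\theta \in \mathbb{R}^d$ and a given mapping function $g_\theta$ parameterized by a neural network, consider $\rho_\theta = g_{\theta\#} p(z)$, the push-forward of the latent distribution $p(z)$ under $g_\theta$.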
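Sampling from the push-forward density on this slide amounts to drawing z from p(z) and applying the generator. A minimal sketch, assuming a hypothetical one-dimensional affine map in place of the neural network and a standard normal p(z) (both are illustrative choices, not from the slides):

```python
import random
import statistics


def g(theta, z):
    """Hypothetical affine generator g_theta(z) = a + b * z, standing in for a network."""
    a, b = theta
    return a + b * z


def sample_pushforward(theta, n, rng):
    """Draw n samples from rho_theta = g_theta # p(z), assuming p(z) = N(0, 1)."""
    return [g(theta, rng.gauss(0.0, 1.0)) for _ in range(n)]


rng = random.Random(0)
samples = sample_pushforward((2.0, 0.5), 20000, rng)
# For this affine g, rho_theta is the Gaussian N(a, b^2), here N(2, 0.25).
print(statistics.mean(samples), statistics.pstdev(samples))
```

With a real generator the density of the push-forward is usually not available in closed form, which is why the later update rule is written in terms of generator outputs rather than densities.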

Slide 6

Slide 6 text

Wasserstein natural proximal
The update scheme follows:
$$\theta^{k+1} = \arg\min_{\theta \in \Theta} F(\rho_\theta) + \frac{1}{2h} d_W(\theta, \theta^k)^2,$$
where $\theta$ is the parameters of the generator, $F(\rho_\theta)$ is the loss function, and $d_W$ is the Wasserstein metric.
In practice, we approximate the Wasserstein metric to obtain the following update:
$$\theta^{k+1} = \arg\min_{\theta \in \Theta} F(\rho_\theta) + \frac{1}{B} \sum_{i=1}^{B} \frac{1}{2h} \|g_\theta(z_i) - g_{\theta^k}(z_i)\|^2,$$
where $g_\theta$ is the generator, $B$ is the batch size, and $z_i \sim p(z)$ are inputs to the generator.
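The approximate update can be sketched on a toy problem. Everything specific below is an assumption made for illustration: the generator is a hypothetical one-dimensional affine map, the loss F is a placeholder least-squares fit to a fixed target map (not a GAN loss), and the inner arg min is approximated by a few plain gradient-descent steps started at the previous parameters.

```python
import random


def g(theta, z):
    """Hypothetical affine generator g_theta(z) = a + b * z (stand-in for a network)."""
    a, b = theta
    return a + b * z


def target(z):
    """Placeholder target map defining the toy loss; not a GAN loss."""
    return 2.0 + 0.5 * z


def objective(theta, theta_k, zs, h):
    """Toy F plus the batch penalty (1/B) sum_i (1/(2h)) |g_theta(z_i) - g_theta_k(z_i)|^2."""
    B = len(zs)
    loss = sum(0.5 * (g(theta, z) - target(z)) ** 2 for z in zs) / B
    prox = sum((g(theta, z) - g(theta_k, z)) ** 2 for z in zs) / (2.0 * h * B)
    return loss + prox


def proximal_step(theta_k, zs, h, inner_steps=20, lr=0.1):
    """Approximate the arg min over theta by gradient descent started at theta_k."""
    a, b = theta_k
    B = len(zs)
    for _ in range(inner_steps):
        grad_a = grad_b = 0.0
        for z in zs:
            out = g((a, b), z)
            # Derivative of the per-sample objective with respect to the generator output:
            # loss residual plus proximal residual scaled by 1/h.
            r = (out - target(z)) + (out - g(theta_k, z)) / h
            grad_a += r / B       # d out / d a = 1
            grad_b += r * z / B   # d out / d b = z
        a, b = a - lr * grad_a, b - lr * grad_b
    return (a, b)


rng = random.Random(0)
theta = (0.0, 1.0)
for k in range(30):
    zs = [rng.gauss(0.0, 1.0) for _ in range(64)]  # fresh batch z_i ~ p(z)
    theta = proximal_step(theta, zs, h=0.5)
print(theta)  # moves toward the parameters of the placeholder target map
```

The penalty compares outputs of the current and previous generator on the same inputs z_i. Since the pair (g_theta(z), g_theta_k(z)) is one particular coupling of the two generated distributions, its mean squared distance is an upper bound on the squared Wasserstein-2 distance between them, which is one way to read the approximation on the slide.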

Slide 7

Slide 7 text

Examples: Jensen–Shannon entropy loss
Figure: The relaxed Wasserstein proximal of GANs on the CIFAR-10 (left) and CelebA (right) datasets.

Slide 8

Slide 8 text

Examples: Wasserstein-1 loss
Figure: Wasserstein proximal of the Wasserstein-1 loss function on the CIFAR-10 dataset.

Slide 9

Slide 9 text

Example: Stepsize
Figure: The Wasserstein proximal improves training by providing a lower FID when the learning rate is high. The results are based on the CelebA dataset.