
Wasserstein proximal learning

Wuchen Li
October 21, 2019

In this talk, I will briefly review the calculus behind optimal transport and mean-field games. In particular, I present the Wasserstein proximal, a.k.a. the Jordan-Kinderlehrer-Otto (JKO) scheme, together with the Hopf-Lax formula in density space and the Master equation (Big Mac) in mean-field games. We demonstrate the usefulness of the Wasserstein proximal in learning tasks, in terms of both generalization and optimization.

Transcript

  1. Wasserstein proximal learning

  2. Hopf-Lax formula

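  3. Proximal = Hopf-Lax

    $\min_x f(x)$: $u(1, y) = \inf_x f(x) + \frac{\|y - x\|^2}{2}$, where $\partial_t u + \frac{1}{2}\|\nabla u\|^2 = 0$, $u(0, x) = f(x)$.
    $\min_\rho F(\rho)$: $U(1, \mu) = \inf_\rho F(\rho) + \frac{\mathrm{dist}_W(\rho, \mu)^2}{2}$, where $\partial_t U + \frac{1}{2}\int \Big(\nabla \frac{\delta U}{\delta \rho}\Big)^2 \rho(x)\,dx = 0$, $U(0, \rho) = F(\rho)$.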
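The identity "Proximal = Hopf-Lax" can be checked numerically in one dimension. The sketch below assumes the nonsmooth objective $f(x) = |x|$ (our choice, not from the slides), for which the time-$t$ Hopf-Lax value is the Huber function with threshold $t$ and the minimizer at $t = 1$ is soft-thresholding. It solves the proximal problem by brute force over a grid, compares against those closed forms, and checks the Hamilton-Jacobi residual $\partial_t u + \frac{1}{2}|\partial_y u|^2$ by finite differences. The grid and test points are illustrative assumptions.

```python
import numpy as np

GRID = np.linspace(-5.0, 5.0, 20001)  # candidate points x, spacing 5e-4


def f(x):
    # Assumed nonsmooth objective f(x) = |x|
    return np.abs(x)


def hopf_lax(t, y):
    # u(t, y) = min_x f(x) + (y - x)^2 / (2t) by brute force over GRID;
    # at t = 1 this is the proximal problem, and the minimizer is the proximal point
    costs = f(GRID)[None, :] + (y[:, None] - GRID[None, :]) ** 2 / (2.0 * t)
    idx = costs.argmin(axis=1)
    return costs[np.arange(len(y)), idx], GRID[idx]


def huber(t, y):
    # Closed form of u(t, y) for f = |x|: the Huber function with threshold t
    return np.where(np.abs(y) <= t, y ** 2 / (2 * t), np.abs(y) - t / 2)


ys = np.array([-2.0, -0.3, 0.0, 0.7, 1.5])
u1, prox = hopf_lax(1.0, ys)
soft = np.sign(ys) * np.maximum(np.abs(ys) - 1.0, 0.0)  # soft-thresholding at level 1
print(np.max(np.abs(u1 - huber(1.0, ys))), np.max(np.abs(prox - soft)))

# Finite-difference check of d_t u + 1/2 |d_y u|^2 = 0 at t = 1
t, d = 1.0, 1e-3
du_dt = (huber(t + d, ys) - huber(t - d, ys)) / (2 * d)
du_dy = (huber(t, ys + d) - huber(t, ys - d)) / (2 * d)
print(np.max(np.abs(du_dt + 0.5 * du_dy ** 2)))
```

Both printed gaps and the residual should be at roundoff or grid-resolution level; the test points avoid the kinks of $|x|$ in $t$ and $y$ so the central differences stay accurate.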
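  4. Math Review: mean field games

    Hopf-Lax on density space:
    $$U(t, \mu) = \inf_{\rho \in \mathcal{P}(\mathbb{T}^d)} F(\rho) + \frac{d_W(\rho, \mu)^2}{2t}$$
    E.g. (Burgers') Hamilton-Jacobi on density space:
    $$\partial_t U(t, \mu) + \frac{1}{2}\int_{\mathbb{T}^d} \Big(\nabla \frac{\delta}{\delta \mu} U(t, \mu)(x)\Big)^2 \mu(x)\,dx = 0, \quad U(0, \mu) = F(\mu).$$
    Characteristics on density space (Nash equilibrium in mean field games):
    $$\partial_t \mu_s + \nabla \cdot (\mu_s \nabla \Phi_s) = 0, \quad \partial_t \Phi_s + \frac{1}{2}(\nabla \Phi_s)^2 = 0.$$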
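For a linear functional $F(\rho) = \int V(x)\rho(x)\,dx$, the density-space Hopf-Lax problem decouples particle by particle, so $U(t, \mu) = \int u(t, y)\,\mu(dy)$, where $u$ is the scalar Hopf-Lax solution, and $\frac{\delta U}{\delta \mu} = u(t, \cdot)$. A minimal 1-D sketch, assuming $V(x) = x^2/2$ and a Gaussian $\mu$ (both our choices), and working on $\mathbb{R}$ rather than the torus $\mathbb{T}^d$, checks the Hamilton-Jacobi equation on density space by finite differences. The closed form used for comparison is $u(t, y) = y^2 / (2(1 + t))$.

```python
import numpy as np

GRID = np.linspace(-4.0, 4.0, 8001)  # candidate points x for the brute-force infimum


def V(x):
    # Assumed potential; F(rho) = int V(x) rho(x) dx is a linear functional of rho
    return 0.5 * x ** 2


def u_hopf_lax(t, y):
    # Scalar Hopf-Lax value u(t, y) = min_x V(x) + |y - x|^2 / (2t), by brute force over GRID
    costs = V(GRID)[None, :] + (y[:, None] - GRID[None, :]) ** 2 / (2.0 * t)
    return costs.min(axis=1)


def U_density(t, samples):
    # For linear F the infimum over rho decouples, so U(t, mu) = int u(t, y) mu(dy);
    # the integral is approximated by a sample mean over draws from mu
    return u_hopf_lax(t, samples).mean()


def hj_residual(t, samples, dt=1e-2, dy=1e-2):
    # Residual of d/dt U + 1/2 int (grad_x dU/dmu)^2 mu dx, using dU/dmu = u(t, .) for linear F
    dU_dt = (U_density(t + dt, samples) - U_density(t - dt, samples)) / (2 * dt)
    du_dy = (u_hopf_lax(t, samples + dy) - u_hopf_lax(t, samples - dy)) / (2 * dy)
    return dU_dt + 0.5 * np.mean(du_dy ** 2)


rng = np.random.default_rng(0)
mu_samples = rng.normal(1.0, 0.5, size=400)  # draws from an assumed Gaussian mu
t = 1.0
closed_form = np.mean(mu_samples ** 2) / (2 * (1 + t))  # int y^2 / (2(1 + t)) mu(dy)
print(U_density(t, mu_samples) - closed_form)  # small: brute force matches the closed form
print(hj_residual(t, mu_samples))              # small: the HJ equation holds
```

Along the characteristics, each particle moves to its proximal point $y / (1 + t)$, so the minimizing density is the pushforward of $\mu$ under that map; nonlinear $F$ (e.g. entropy) would not decouple this way.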
  5. Example III: Generative Adversarial Networks

    For each parameter $\theta \in \mathbb{R}^d$ and a given neural-network-parameterized mapping function $g_\theta$, consider $\rho_\theta = g_\theta \# p(z)$.
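The pushforward $\rho_\theta = g_\theta \# p(z)$ is only ever accessed through samples: draw $z \sim p(z)$ and map it through $g_\theta$. A minimal sketch, assuming a hypothetical one-hidden-layer tanh generator and $p(z) = \mathcal{N}(0, I)$ (neither is specified on the slide):

```python
import numpy as np


def init_theta(rng, latent_dim=2, hidden=16, out_dim=2):
    # Parameters theta = (W1, b1, W2, b2) of a hypothetical one-hidden-layer generator
    return (rng.standard_normal((latent_dim, hidden)) / np.sqrt(latent_dim),
            np.zeros(hidden),
            rng.standard_normal((hidden, out_dim)) / np.sqrt(hidden),
            np.zeros(out_dim))


def g_theta(theta, z):
    # g_theta: R^latent_dim -> R^out_dim, applied row-wise to a batch z of shape (n, latent_dim)
    W1, b1, W2, b2 = theta
    return np.tanh(z @ W1 + b1) @ W2 + b2


def sample_rho_theta(theta, n, rng):
    # rho_theta = g_theta # p(z): draw z ~ p(z) = N(0, I) and push each draw through g_theta
    latent_dim = theta[0].shape[0]
    z = rng.standard_normal((n, latent_dim))
    return g_theta(theta, z)


rng = np.random.default_rng(0)
theta = init_theta(rng)
x = sample_rho_theta(theta, 1000, rng)  # 1000 samples from rho_theta, shape (1000, 2)
```

Expectations under $\rho_\theta$ become sample means over $z$; the density of $\rho_\theta$ itself is never evaluated.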
  6. Wasserstein natural proximal

    The update scheme follows:
    $$\theta^{k+1} = \arg\min_{\theta \in \Theta} F(\rho_\theta) + \frac{1}{2h} d_W(\theta, \theta^k)^2,$$
    where $\theta$ denotes the parameters of the generator, $F(\rho_\theta)$ is the loss function, and $d_W$ is the Wasserstein metric. In practice, we approximate the Wasserstein metric to obtain the following update:
    $$\theta^{k+1} = \arg\min_{\theta \in \Theta} F(\rho_\theta) + \frac{1}{B}\sum_{i=1}^{B} \frac{1}{2h}\|g_\theta(z_i) - g_{\theta^k}(z_i)\|^2,$$
    where $g_\theta$ is the generator, $B$ is the batch size, and $z_i \sim p(z)$ are inputs to the generator.
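The relaxed update replaces $d_W$ by a batch average over shared inputs $z_i$. Because $(g_\theta, g_{\theta^k})\#p(z)$ is a coupling of $\rho_\theta$ and $\rho_{\theta^k}$, that average estimates $\mathbb{E}_z\|g_\theta(z) - g_{\theta^k}(z)\|^2$, an upper bound on $d_W(\rho_\theta, \rho_{\theta^k})^2$. A minimal sketch of the outer/inner loop, assuming a hypothetical 1-D affine generator, a stand-in quadratic loss in place of the GAN loss, and plain gradient descent for the inner $\arg\min$ (all our assumptions, not the talk's implementation):

```python
import numpy as np

TARGET = 2.0  # hypothetical: the stand-in loss pulls generated samples toward this value


def generator(theta, z):
    # Hypothetical 1-D affine generator g_theta(z) = a + b * z
    a, b = theta
    return a + b * z


def loss_F(theta, z):
    # Stand-in for F(rho_theta): the potential energy E_z[V(g_theta(z))], V(x) = (x - TARGET)^2 / 2.
    # In a GAN, F would be the adversarial loss; this choice is only for illustration.
    return 0.5 * np.mean((generator(theta, z) - TARGET) ** 2)


def grad_objective(theta, theta_k, z, h):
    # Gradient of F(rho_theta) + (1/B) sum_i |g_theta(z_i) - g_{theta^k}(z_i)|^2 / (2h)
    x, x_k = generator(theta, z), generator(theta_k, z)
    r = (x - TARGET) + (x - x_k) / h  # derivative of the per-sample objective w.r.t. x
    return np.array([np.mean(r), np.mean(r * z)])  # chain rule: dx/da = 1, dx/db = z


def relaxed_wasserstein_proximal(theta0, h=0.5, outer_steps=30, inner_steps=50,
                                 lr=0.2, batch=256, seed=0):
    rng = np.random.default_rng(seed)
    theta = np.asarray(theta0, dtype=float)
    for _ in range(outer_steps):
        theta_k = theta.copy()              # anchor theta^k of this proximal step
        z = rng.standard_normal(batch)      # z_i ~ p(z), assumed N(0, 1)
        for _ in range(inner_steps):        # approximate the arg min by gradient descent
            theta = theta - lr * grad_objective(theta, theta_k, z, h)
    return theta


theta = relaxed_wasserstein_proximal([0.0, 1.0])
print(theta, loss_F(theta, np.random.default_rng(1).standard_normal(1000)))
```

With this quadratic stand-in loss, an exact proximal step shrinks the distance to the minimizer $(a, b) = (2, 0)$ by a factor $1/(1 + h)$ per outer iteration, the implicit (backward Euler) behavior of a proximal scheme.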
  7. Examples: Jensen–Shannon divergence loss

    Figure: The relaxed Wasserstein proximal of GANs on the CIFAR-10 (left) and CelebA (right) datasets.
  8. Examples: Wasserstein-1 Loss

    Figure: Wasserstein proximal of the Wasserstein-1 loss function on the CIFAR-10 dataset.
  9. Example: Stepsize

    Figure: The Wasserstein proximal improves training by providing a lower FID when the learning rate is high. The results are based on the CelebA dataset.