Wasserstein proximal learning

Wuchen Li
October 21, 2019

In this talk, I will briefly review the calculus behind optimal transport and mean-field games. In particular, I will present the Wasserstein proximal, a.k.a. the Jordan-Kinderlehrer-Otto (JKO) scheme, together with the Hopf-Lax formula in density space and the master equations (Big Mac) of mean-field games. We demonstrate the usefulness of the Wasserstein proximal in learning tasks, in terms of both generalization and optimization.


  1. Wasserstein proximal learning

  2. Hopf-Lax formula

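  3. Proximal = Hopf-Lax

    In Euclidean space, the proximal step for $\min_x f(x)$ is given by the Hopf-Lax formula

    $$u(1, y) = \inf_x \Big\{ f(x) + \frac{\|y - x\|^2}{2} \Big\},$$

    which is the value at $t = 1$ of the solution of the Hamilton-Jacobi equation

    $$\partial_t u + \frac{1}{2}\|\nabla u\|^2 = 0, \qquad u(0, x) = f(x).$$

    In density space, the Wasserstein proximal step for $\min_\rho F(\rho)$ is given analogously by

    $$U(1, \mu) = \inf_\rho \Big\{ F(\rho) + \frac{d_W(\rho, \mu)^2}{2} \Big\},$$

    which is the value at $t = 1$ of the solution of

    $$\partial_t U + \frac{1}{2}\int \Big\|\nabla_x \frac{\delta U}{\delta \rho}(x)\Big\|^2 \rho(x)\,dx = 0, \qquad U(0, \rho) = F(\rho).$$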
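As a numerical sanity check of the Euclidean statement on this slide, the sketch below evaluates the Hopf-Lax formula $u(t, y) = \inf_x \{ f(x) + \|y - x\|^2/(2t) \}$ by brute force on a 1D grid. At $t = 1$ this is the optimal value of the proximal step started at $y$ (the Moreau envelope of $f$). The test function $f(x) = a x^2/2$ with $a = 2$, the grid, and all function names are our own illustrative assumptions, not from the talk. For this $f$ the minimiser is $x^* = y/(1 + a t)$, so $u(t, y) = a y^2 / (2(1 + a t))$; the code compares the two and checks by finite differences that $u$ satisfies $\partial_t u + \frac{1}{2}|\partial_y u|^2 = 0$.

```python
import numpy as np

A = 2.0  # curvature of the assumed test function f(x) = A * x**2 / 2
X_GRID = np.linspace(-10.0, 10.0, 200_001)  # candidate minimisers x, spacing 1e-4


def f(x):
    """Assumed convex test function f(x) = A x^2 / 2 (illustrative only)."""
    return A * x**2 / 2


def hopf_lax(t, y):
    """Brute-force u(t, y) = min_x f(x) + (y - x)^2 / (2t) over the grid.

    At t = 1 this is the optimal value of the proximal step for min_x f(x)
    started at y."""
    return np.min(f(X_GRID) + (y - X_GRID) ** 2 / (2 * t))


def hopf_lax_exact(t, y):
    """Closed form for the quadratic f: minimiser x* = y / (1 + A t)."""
    return A * y**2 / (2 * (1 + A * t))


def hj_residual(t, y, eps=1e-4):
    """Central-difference residual of u_t + 0.5 * (u_y)^2 for the exact u."""
    u_t = (hopf_lax_exact(t + eps, y) - hopf_lax_exact(t - eps, y)) / (2 * eps)
    u_y = (hopf_lax_exact(t, y + eps) - hopf_lax_exact(t, y - eps)) / (2 * eps)
    return u_t + 0.5 * u_y**2


if __name__ == "__main__":
    for y in (-1.5, 0.3, 2.0):
        print(y, hopf_lax(1.0, y), hopf_lax_exact(1.0, y), hj_residual(1.0, y))
```

The brute-force minimum and the closed form should agree up to the grid resolution, and the Hamilton-Jacobi residual should be at the level of finite-difference error.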
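  4. Math Review: mean field games

    Hopf-Lax formula on density space:

    $$U(t, \mu) = \inf_{\rho \in \mathcal{P}(\mathbb{T}^d)} \Big\{ F(\rho) + \frac{d_W(\rho, \mu)^2}{2t} \Big\}.$$

    It solves the (Burgers') Hamilton-Jacobi equation on density space:

    $$\frac{\partial}{\partial t} U(t, \mu) + \frac{1}{2}\int_{\mathbb{T}^d} \Big\|\nabla_x \frac{\delta U}{\delta \mu}(t, \mu)(x)\Big\|^2 \mu(x)\,dx = 0, \qquad U(0, \mu) = F(\mu).$$

    Its characteristics on density space (the Nash equilibrium in mean field games) satisfy

    $$\partial_s \mu_s + \nabla \cdot (\mu_s \nabla \Phi_s) = 0, \qquad \partial_s \Phi_s + \frac{1}{2}\|\nabla \Phi_s\|^2 = 0.$$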
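A minimal sketch of the density-space Hopf-Lax formula, under simplifying assumptions that are ours rather than the talk's: one dimension on the real line (instead of the torus $\mathbb{T}^d$), $\mu$ an empirical measure of $N$ equally weighted particles, and a potential-energy functional $F(\rho) = \int \frac{a}{2} x^2 \rho(x)\,dx$ with $a = 1$. In 1D, the squared 2-Wasserstein distance between two equal-size empirical measures is the mean squared difference of their sorted samples, so the infimum can be approximated by gradient descent on particle positions. For this $F$ the minimiser is the push-forward of $\mu$ under $x \mapsto x/(1 + a t)$, giving $U(t, \mu) = \frac{a}{2(1 + a t)} \int x^2 \mu(x)\,dx$, and hence $\nabla_x \frac{\delta U}{\delta \mu}(t, \mu)(x) = \frac{a x}{1 + a t}$. The code compares the numerical infimum with this closed form and checks the Hamilton-Jacobi equation on density space. All function names are hypothetical.

```python
import numpy as np

A = 1.0  # assumed potential V(x) = A x^2 / 2, so F(rho) = E_rho[V]


def w2_squared_1d(x, y):
    """Squared 2-Wasserstein distance between two 1D empirical measures with
    the same number of equally weighted atoms: match sorted samples."""
    return np.mean((np.sort(x) - np.sort(y)) ** 2)


def objective(y, x, t):
    """F(rho) + W2(rho, mu)^2 / (2t), with rho, mu the empirical measures of y, x."""
    return np.mean(A * y**2 / 2) + w2_squared_1d(x, y) / (2 * t)


def hopf_lax_density(x, t, steps=500, lr=0.1):
    """Approximate U(t, mu) by gradient descent on the particles of rho.

    Starting from y = x, each step y <- (1 - lr*(A + 1/t)) y + (lr/t) x is an
    increasing map as long as lr * (A + 1/t) < 1, so y stays sorted like x and
    the identity matching remains optimal; the W2 gradient is then (y - x)/t
    per particle (up to the common 1/N factor)."""
    assert lr * (A + 1 / t) < 1, "step too large to preserve particle order"
    y = x.copy()
    for _ in range(steps):
        y = y - lr * (A * y + (y - x) / t)
    return objective(y, x, t)


def hopf_lax_density_exact(x, t):
    """Closed form: optimal rho is the push-forward of mu by x -> x / (1 + A t)."""
    return A * np.mean(x**2) / (2 * (1 + A * t))


def master_residual(x, t, eps=1e-4):
    """dU/dt + 0.5 * E_mu[(grad_x dU/dmu)^2] for the closed-form U,
    with grad_x dU/dmu (x) = A x / (1 + A t)."""
    dU_dt = (hopf_lax_density_exact(x, t + eps)
             - hopf_lax_density_exact(x, t - eps)) / (2 * eps)
    hamiltonian = 0.5 * np.mean((A * x / (1 + A * t)) ** 2)
    return dU_dt + hamiltonian


if __name__ == "__main__":
    rng = np.random.default_rng(0)
    x = rng.normal(size=1_000)  # particles of mu
    print(hopf_lax_density(x, 1.0), hopf_lax_density_exact(x, 1.0),
          master_residual(x, 1.0))
```

For a linear-in-$\rho$ functional like this one, the density-space problem decouples into the Euclidean Hopf-Lax problem for each particle; nonlinear functionals (entropies, GAN losses) do not decouple this way.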
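  5. Example III: Generative Adversarial Networks

    For each parameter $\theta \in \mathbb{R}^d$ and a given neural-network mapping function $g_\theta$ parameterized by $\theta$, consider the generated distribution $\rho_\theta = g_\theta \# p(z)$, the push-forward of the latent distribution $p(z)$ under $g_\theta$.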
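The push-forward $\rho_\theta = g_\theta \# p(z)$ is the distribution of $g_\theta(z)$ when $z \sim p(z)$, so sampling from $\rho_\theta$ means pushing latent samples through the generator. A minimal sketch, assuming a standard normal latent $p(z)$ and a hypothetical one-dimensional affine "generator" $g_\theta(z) = \theta_0 + \theta_1 z$ in place of a neural network. For this choice $\rho_\theta = \mathcal{N}(\theta_0, \theta_1^2)$, which the empirical moments should approximately reproduce.

```python
import numpy as np


def generator(theta, z):
    """Hypothetical affine generator g_theta(z) = theta[0] + theta[1] * z,
    standing in for a neural network."""
    return theta[0] + theta[1] * z


def sample_pushforward(theta, batch_size, rng):
    """Draw samples from rho_theta = g_theta # p(z) with p(z) = N(0, 1)."""
    z = rng.standard_normal(batch_size)  # z ~ p(z)
    return generator(theta, z)           # g_theta(z) ~ rho_theta


if __name__ == "__main__":
    rng = np.random.default_rng(0)
    samples = sample_pushforward(np.array([1.0, -2.0]), 100_000, rng)
    print(samples.mean(), samples.std())  # approximately 1.0 and 2.0
```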
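  6. Wasserstein natural proximal

    The update scheme is

    $$\theta_{k+1} = \arg\min_{\theta \in \Theta} \; F(\rho_\theta) + \frac{1}{2h} d_W(\theta, \theta_k)^2,$$

    where $\theta$ denotes the parameters of the generator, $F(\rho_\theta)$ is the loss function, $d_W$ is the Wasserstein metric (between the generated distributions $\rho_\theta$ and $\rho_{\theta_k}$), and $h > 0$ is the proximal step size. In practice, we approximate the Wasserstein metric to obtain the following update:

    $$\theta_{k+1} = \arg\min_{\theta \in \Theta} \; F(\rho_\theta) + \frac{1}{B}\sum_{i=1}^{B} \frac{1}{2h}\big\|g_\theta(z_i) - g_{\theta_k}(z_i)\big\|^2,$$

    where $g_\theta$ is the generator, $B$ is the batch size, and $z_i \sim p(z)$ are inputs to the generator.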
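The relaxed update on this slide can be sketched as an outer loop over $k$, where each proximal subproblem is solved approximately by inner gradient descent on a fixed batch $z_1, \dots, z_B$. To keep the sketch self-contained, it swaps the GAN loss for a hypothetical toy loss $F(\rho_\theta) = \mathbb{E}_z[\frac{1}{2}(g_\theta(z) - c)^2]$ with $c = 3$, estimated on the same batch, and uses a hypothetical affine generator $g_\theta(z) = \theta_0 + \theta_1 z$ so that gradients can be written by hand. The step counts, learning rate, and function names are illustrative assumptions. In an actual GAN, $F$ would come from the discriminator and the gradients from automatic differentiation.

```python
import numpy as np

C = 3.0  # target of the hypothetical toy loss F(rho_theta) = E[(g_theta(z) - C)^2] / 2


def generator(theta, z):
    """Hypothetical affine generator g_theta(z) = theta[0] + theta[1] * z."""
    return theta[0] + theta[1] * z


def proximal_objective_grad(theta, theta_k, z, h):
    """Gradient in theta of
        (1/B) sum_i [ (g_theta(z_i) - C)^2 / 2 + |g_theta(z_i) - g_theta_k(z_i)|^2 / (2h) ],
    i.e. the batch estimate of the toy loss plus the relaxed Wasserstein penalty."""
    g = generator(theta, z)
    g_k = generator(theta_k, z)
    residual = (g - C) + (g - g_k) / h      # derivative of the per-sample objective in g
    jac = np.stack([np.ones_like(z), z])    # dg/dtheta, shape (2, B)
    return jac @ residual / z.size


def relaxed_wasserstein_proximal(theta0, h=0.5, outer_steps=30, inner_steps=100,
                                 lr=0.1, batch_size=64, seed=0):
    """Run theta_{k+1} = argmin_theta F + (1/B) sum_i |g_theta(z_i) - g_theta_k(z_i)|^2 / (2h)."""
    rng = np.random.default_rng(seed)
    theta_k = np.asarray(theta0, dtype=float)
    for _ in range(outer_steps):
        z = rng.standard_normal(batch_size)  # z_i ~ p(z), fixed within one proximal step
        theta = theta_k.copy()               # warm start the subproblem at theta_k
        for _ in range(inner_steps):         # approximate the arg min
            theta -= lr * proximal_objective_grad(theta, theta_k, z, h)
        theta_k = theta
    return theta_k


if __name__ == "__main__":
    print(relaxed_wasserstein_proximal([0.0, 1.0]))
```

For this particular toy setup, the loss and the penalty are quadratics in the same batch features $(1, z_i)$, so each exact proximal step shrinks the distance to the minimiser $(c, 0)$ by the factor $1/(1+h)$ whatever batch is drawn; a larger $h$ therefore moves further per step. This is a property of the toy problem, not a claim about GAN training.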
  7. Examples: Jensen–Shannon divergence Loss

    Figure: The relaxed Wasserstein proximal of GANs on the CIFAR-10 (left) and CelebA (right) datasets.
  8. Examples: Wasserstein-1 Loss

    Figure: The Wasserstein proximal of the Wasserstein-1 loss function on the CIFAR-10 dataset.
  9. Example: Step size

    Figure: The Wasserstein proximal improves training by giving a lower FID when the learning rate is high. The results are based on the CelebA dataset.