
Transport Information Bregman divergences

Wuchen Li
January 12, 2021


In this talk, we study Bregman divergences in probability density space embedded with the Wasserstein-2 metric. Several properties and dualities of transport Bregman divergences are provided. Concretely, we derive the transport Kullback-Leibler (KL) divergence as the Bregman divergence of the negative Boltzmann-Shannon entropy in Wasserstein-2 space. We also derive analytical formulas of the transport KL divergence for one-dimensional probability densities and Gaussian families.


Transcript

  1. Learning

    Given a data measure $\rho_{\mathrm{data}}(x) = \frac{1}{N}\sum_{i=1}^{N}\delta_{X_i}(x)$ and a parameterized model $\rho(x,\theta)$, learning problems often refer to $\min_{\rho_\theta\in\rho(\Theta)}\mathrm{Dist}(\rho_{\mathrm{data}}, \rho_\theta)$. Mathematics behind learning: a "distance" between model and data in probability space, which allows efficient sampling approximations; parameterizations: full space, neural network generative models, Boltzmann machines, Gaussians, Gaussian mixtures, finite volume/element, etc.; optimization: gradient descent, primal-dual algorithms, etc. In this talk, we focus on the construction of "distances".
  2. Information

    What is "information"? Wiki: Information theory is the scientific study of the quantification, storage, and communication of information. The field is at the intersection of probability theory, statistics, computer science, statistical mechanics, information engineering, and electrical engineering. Applied mathematics: entropy; Bregman divergences; dualities.
  3. Bregman divergences

    Bregman divergences generalize the squared Euclidean distance: $D_\psi(y\|x) = \psi(y) - \psi(x) - (\nabla\psi(x), y - x)$. Examples: (i) Squared Euclidean distance. $\psi(z) = z^2$: $D_\psi(y\|x) = y^2 - x^2 - 2x(y - x) = (y - x)^2$. (ii) KL divergence. $\psi(z) = z\log z$: $D_\psi(y\|x) = y\log\frac{y}{x} - (y - x)$. (iii) Itakura–Saito divergence. $\psi(z) = -\log z$: $D_\psi(y\|x) = \frac{y}{x} - \log\frac{y}{x} - 1$.
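The three examples can be checked numerically. The following is a minimal Python sketch, assuming scalar arguments; `bregman` and the potential names are illustrative helpers, not from the talk. It evaluates $D_\psi(y\|x)$ directly from the definition and compares it with each closed form.

```python
import math


def bregman(psi, dpsi, y, x):
    """Scalar Bregman divergence D_psi(y || x) = psi(y) - psi(x) - psi'(x) * (y - x)."""
    return psi(y) - psi(x) - dpsi(x) * (y - x)


# (i) psi(z) = z^2 gives the squared Euclidean distance.
def square(z):
    return z * z


def d_square(z):
    return 2.0 * z


# (ii) psi(z) = z log z (z > 0) gives the KL-type divergence.
def z_log_z(z):
    return z * math.log(z)


def d_z_log_z(z):
    return math.log(z) + 1.0


# (iii) psi(z) = -log z (z > 0) gives the Itakura-Saito divergence.
def neg_log(z):
    return -math.log(z)


def d_neg_log(z):
    return -1.0 / z


y, x = 3.0, 1.5
# Each line prints the definition-based value next to the closed form.
print(bregman(square, d_square, y, x), (y - x) ** 2)
print(bregman(z_log_z, d_z_log_z, y, x), y * math.log(y / x) - (y - x))
print(bregman(neg_log, d_neg_log, y, x), y / x - math.log(y / x) - 1.0)
```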
  4. Bregman properties

    Nonnegativity: $D_\psi(y\|x) \ge 0$. Hessian metric: consider the following Taylor expansion. Denote $\Delta x\in\mathbb R^d$; then $D_\psi(x + \Delta x\|x) = \frac12\Delta x^{T}\nabla^2\psi(x)\Delta x + o(\|\Delta x\|^2)$, where $\nabla^2\psi$ is the Hessian operator of $\psi$ w.r.t. the Euclidean metric. Asymmetry: in general, $D_\psi$ is not necessarily symmetric in $x$ and $y$, i.e. $D_\psi(y\|x) \ne D_\psi(x\|y)$. Duality: denote the conjugate (dual) function of $\psi$ by $\psi^*(x^*) = \sup_{x\in\Omega}\,(x, x^*) - \psi(x)$. Then $D_{\psi^*}(x^*\|y^*) = D_\psi(y\|x)$, where $x^* = \nabla\psi(x)$ and $y^* = \nabla\psi(y)$.
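A minimal numerical check of the duality, the Hessian expansion, and the asymmetry, assuming $\psi(z) = z\log z$ on $(0,\infty)$, whose conjugate is $\psi^*(s) = e^{s-1}$ (a standard computation, not stated on the slide). The helper names are illustrative.

```python
import math


def bregman(f, df, a, b):
    """Scalar Bregman divergence D_f(a || b)."""
    return f(a) - f(b) - df(b) * (a - b)


def psi(z):          # psi(z) = z log z, z > 0
    return z * math.log(z)


def dpsi(z):         # psi'(z) = log z + 1
    return math.log(z) + 1.0


def psi_star(s):     # convex conjugate psi*(s) = exp(s - 1)
    return math.exp(s - 1.0)


def dpsi_star(s):    # (psi*)'(s) = exp(s - 1), the inverse map of psi'
    return math.exp(s - 1.0)


x, y = 0.7, 2.3
x_star, y_star = dpsi(x), dpsi(y)

primal = bregman(psi, dpsi, y, x)                    # D_psi(y || x)
dual = bregman(psi_star, dpsi_star, x_star, y_star)  # D_psi*(x* || y*)
print(primal, dual)  # agree up to rounding

# Hessian metric: D_psi(x + h || x) ~ (1/2) psi''(x) h^2 with psi''(x) = 1/x.
h = 1e-3
print(bregman(psi, dpsi, x + h, x), 0.5 * h * h / x)

# Asymmetry: swapping the arguments changes the value.
print(bregman(psi, dpsi, y, x), bregman(psi, dpsi, x, y))
```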
  5. KL divergence

    One of the most important Bregman divergences is the KL divergence: $D_{KL}(p\|q) = \int_\Omega p(x)\log\frac{p(x)}{q(x)}\,dx$.
  6. Why KL divergence?

    KL divergence = Bregman divergence + Shannon entropy + $L^2$ space:
    $D_{KL}(p\|q) = \underbrace{\int_\Omega p(x)\log p(x)\,dx}_{\text{Entropy}}\;\underbrace{-\int_\Omega p(x)\log q(x)\,dx}_{\text{Cross entropy}}$.
    KL has many properties. Nonsymmetry: $D_{KL}(p\|q) \ne D_{KL}(q\|p)$; separability; convexity in both variables $p$ and $q$; asymptotic behavior: $D_{KL}(q + \delta q\|q) \approx \frac12\int_\Omega\frac{(\delta q(x))^2}{q(x)}\,dx$, where $\frac{1}{q}$ is called the Fisher–Rao information metric.
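A discrete analogue makes the entropy/cross-entropy split, the nonsymmetry, and the Fisher–Rao approximation concrete. This sketch assumes probability vectors on a finite grid in place of densities on $\Omega$; `kl`, `neg_entropy`, and `cross_entropy` are illustrative helpers.

```python
import math


def kl(p, q):
    """Discrete KL divergence sum_i p_i log(p_i / q_i).

    Terms with p_i = 0 contribute 0; mass of p where q vanishes gives +inf.
    """
    total = 0.0
    for pi, qi in zip(p, q):
        if pi > 0.0:
            if qi == 0.0:
                return math.inf
            total += pi * math.log(pi / qi)
    return total


def neg_entropy(p):
    return sum(pi * math.log(pi) for pi in p if pi > 0.0)


def cross_entropy(p, q):
    return -sum(pi * math.log(qi) for pi, qi in zip(p, q) if pi > 0.0)


# KL = (negative) entropy + cross entropy
p, q = [0.9, 0.1], [0.5, 0.5]
print(kl(p, q), neg_entropy(p) + cross_entropy(p, q))

# Nonsymmetry
print(kl(p, q), kl(q, p))

# Fisher-Rao asymptotics: KL(q + dq || q) ~ (1/2) sum_i dq_i^2 / q_i
q = [0.2, 0.5, 0.3]
dq = [0.01, -0.015, 0.005]            # sums to zero, so q + dq stays normalized
q_pert = [qi + di for qi, di in zip(q, dq)]
print(kl(q_pert, q), 0.5 * sum(d * d / qi for d, qi in zip(dq, q)))
```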
  7. Jensen–Shannon divergence

    KL divergence is a building block for other divergences. Its symmetrized version is called the Jensen–Shannon divergence: $D_{JS}(p\|q) = \frac12 D_{KL}(p\|r) + \frac12 D_{KL}(q\|r)$, where $r = \frac{p+q}{2}$ is the geodesic midpoint (barycenter) in $L^2$ space. Because of its nice duality, it served as the original objective function of GANs.
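A short sketch of the Jensen–Shannon divergence on discrete probability vectors (an assumption standing in for densities; helper names are illustrative). Because it compares each argument with the $L^2$ midpoint $r$, it stays finite even for disjoint supports, where it equals $\log 2$.

```python
import math


def kl(p, q):
    """Discrete KL divergence; assumes q_i > 0 wherever p_i > 0."""
    return sum(pi * math.log(pi / qi) for pi, qi in zip(p, q) if pi > 0.0)


def js(p, q):
    """Jensen-Shannon divergence with r = (p + q) / 2, the L2 midpoint of p and q."""
    r = [0.5 * (pi + qi) for pi, qi in zip(p, q)]
    return 0.5 * kl(p, r) + 0.5 * kl(q, r)


p, q = [0.9, 0.1, 0.0], [0.2, 0.3, 0.5]
print(js(p, q), js(q, p))            # symmetric
print(js([1.0, 0.0], [0.0, 1.0]))    # disjoint supports give log 2
```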
  8. Generalized KL divergences

    Information geometry (Amari, Ay, Nielsen, et al.) studies generalizations of Bregman divergences that keep their dualities. Using the KL divergence and the Fisher–Rao metric, various divergences with nice duality properties can be constructed. E.g.
  9. Optimal transport

    What is the optimal way to move or transport a mountain with shape $X$ and density $q(x)$ to another shape $Y$ with density $p(y)$? I.e. $\mathrm{Dist}_T(p,q)^2 = \inf_T\big\{\int_\Omega\|T(x) - x\|^2 q(x)\,dx : T_\# q = p\big\}$. The problem was first introduced by Monge in 1781 and relaxed by Kantorovich in 1942. It defines a metric on the set of probability densities, called the optimal transport distance, Wasserstein metric, or Earth Mover's distance (Ambrosio, Gangbo, Villani, Otto, Figalli, et al.).
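In one dimension the optimal map is monotone, so for two empirical measures with the same number of equally weighted samples, the optimal coupling pairs sorted samples. A minimal Python sketch under that assumption (the function name is illustrative, and it does not handle unequal weights or sizes):

```python
def w2_squared_1d(xs, ys):
    """Squared Wasserstein-2 distance between the 1-D empirical measures
    (1/N) sum_i delta_{x_i} and (1/N) sum_i delta_{y_i} (same N, equal weights).

    The monotone (sorted) pairing is the optimal transport map in 1-D.
    """
    if len(xs) != len(ys):
        raise ValueError("this sketch assumes equal sample sizes")
    xs, ys = sorted(xs), sorted(ys)
    return sum((x - y) ** 2 for x, y in zip(xs, ys)) / len(xs)


print(w2_squared_1d([0.0, 1.0, 2.0], [5.0, 3.0, 4.0]))  # 9.0: every sample moves by 3
```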
  10. Why optimal transport?

    Optimal transport provides a particular distance among histograms, which relies on the distance on the sample space. E.g. denote $X_0\sim p = \delta_{x_0}$ and $X_1\sim q = \delta_{x_1}$ with $x_0\ne x_1$. Compare $\mathrm{Dist}_T(p,q)^2 = \inf_{\pi\in\Pi(p,q)}\mathbb E_{(X_0,X_1)\sim\pi}\|X_0 - X_1\|^2 = \|x_0 - x_1\|^2$ vs. $D_{KL}(p\|q) = \int_\Omega p(x)\log\frac{p(x)}{q(x)}\,dx = \infty$.
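The Dirac comparison can be reproduced on a small grid. This sketch assumes the two point masses sit on grid nodes (the grid and helper are illustrative): the KL divergence is infinite because $p$ puts mass where $q$ has none, while the transport cost is the squared distance between the atoms.

```python
import math


def kl(p, q):
    """Discrete KL divergence; +inf if p puts mass where q has none."""
    total = 0.0
    for pi, qi in zip(p, q):
        if pi > 0.0:
            if qi == 0.0:
                return math.inf
            total += pi * math.log(pi / qi)
    return total


grid = [0.0, 1.0, 2.0, 3.0]
p = [1.0, 0.0, 0.0, 0.0]  # p = delta_{x0} with x0 = 0
q = [0.0, 0.0, 0.0, 1.0]  # q = delta_{x1} with x1 = 3

# The only coupling of two Dirac masses is delta_{(x0, x1)}, so the
# transport cost is just the squared distance between the two atoms.
x0 = grid[p.index(1.0)]
x1 = grid[q.index(1.0)]
print(kl(p, q), (x0 - x1) ** 2)  # inf 9.0
```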
  11. Optimal transport inference problems

    It has been shown that optimal transport distances are useful in inference problems. Given a data distribution $p_{\mathrm{data}}$ and a probability model $p_\theta$, consider $\min_{\theta\in\Theta}\mathrm{Dist}_T(p_{\mathrm{data}}, p_\theta)$. Benefits: Hopf–Lax formula and Hamilton–Jacobi equations on the sample space (small mac); transport convexity. Drawbacks: an additional minimization; $p_{\mathrm{data}}$ and $p_\theta$ must have finite second moments.
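A toy version of $\min_\theta\mathrm{Dist}_T(p_{\mathrm{data}}, p_\theta)$ for a one-dimensional location family. The sketch assumes $p_\theta$ is the law of $\theta + Z$, represented by fixed hypothetical reference samples of $Z$, and uses a plain grid search instead of gradient descent; for this family the minimizer is the difference of sample means.

```python
def w2_squared_1d(xs, ys):
    """Squared W2 between equal-size 1-D empirical measures via sorted pairing."""
    xs, ys = sorted(xs), sorted(ys)
    return sum((x - y) ** 2 for x, y in zip(xs, ys)) / len(xs)


data = [2.1, 2.9, 3.4, 1.8, 2.6]    # hypothetical observed samples
ref = [-1.2, -0.4, 0.0, 0.5, 1.1]   # hypothetical reference samples of Z


def loss(theta):
    # Samples of p_theta are theta + z; shifting does not change their order.
    return w2_squared_1d(data, [theta + z for z in ref])


thetas = [k / 100 for k in range(501)]   # grid on [0, 5]
best = min(thetas, key=loss)
mean_shift = sum(data) / len(data) - sum(ref) / len(ref)
print(best, mean_shift)   # both approximately 2.56
```

Because a shift preserves the sorted order, the loss is a parabola in $\theta$, which is why the grid minimizer matches the mean difference.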
  12. Goals

    We plan to design Bregman divergences using both transport distances and information entropies. Natural questions: (i) What are Bregman divergences in Wasserstein space? (ii) What is the "KL divergence" in Wasserstein space?
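  13. Transport Bregman divergence

    Definition (Transport Bregman divergence). Let $\mathcal F:\mathcal P(\Omega)\to\mathbb R$ be a smooth, strictly displacement convex functional. Define $D_{T,\mathcal F}:\mathcal P(\Omega)\times\mathcal P(\Omega)\to\mathbb R$ by
    $D_{T,\mathcal F}(p\|q) = \mathcal F(p) - \mathcal F(q) - \int_\Omega\big(\nabla_x\frac{\delta}{\delta q(x)}\mathcal F(q), T(x) - x\big)\,q(x)\,dx$,
    where $T$ is the optimal transport map from $q$ to $p$, such that $T_\# q = p$ and $T(x) = \nabla_x\Phi_p(x)$. We call $D_{T,\mathcal F}$ the transport Bregman divergence.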
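  14. Transport distance + Bregman divergence

    Proposition. The functional $D_{T,\mathcal F}$ satisfies the equality
    $D_{T,\mathcal F}(p\|q) = \mathcal F(p) - \mathcal F(q) - \frac12\int_\Omega\mathrm{grad}_T\mathcal F(q)(x)\cdot\frac{\delta}{\delta q(x)}\mathrm{Dist}_T(p,q)^2\,dx$.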
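  15. Transport Bregman properties

    (i) Nonnegativity: suppose $\mathcal F$ is displacement convex; then $D_{T,\mathcal F}(p\|q)\ge 0$. Suppose $\mathcal F$ is strictly displacement convex; then $D_{T,\mathcal F}(p\|q) = 0$ if and only if $\mathrm{Dist}_T(p,q) = 0$. (ii) Transport Hessian metric: consider the following Taylor expansion. Denote $\sigma = -\nabla\cdot(q\nabla\Phi)\in T_q\mathcal P(\Omega)$ and $\epsilon\in\mathbb R$; then $D_{T,\mathcal F}\big((\mathrm{id} + \epsilon\nabla\Phi)_\# q\,\|\,q\big) = \frac{\epsilon^2}{2}\mathrm{Hess}_T\mathcal F(q)(\sigma,\sigma) + o(\epsilon^2)$, where $\mathrm{id}:\Omega\to\Omega$ is the identity map, $\mathrm{id}(x) = x$, and $\mathrm{Hess}_T\mathcal F(q)$ is the Hessian operator of the functional $\mathcal F$ at $q\in\mathcal P(\Omega)$ w.r.t. the $L^2$–Wasserstein metric. (iii) Asymmetry: in general, $D_{T,\mathcal F}(p\|q)\ne D_{T,\mathcal F}(q\|p)$. Our transport duality relates to the Wasserstein Hamilton–Jacobi equation of mean field games (big mac).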
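  16. Transport Bregman divergence of second moment

    If $\mathcal V(p) = \int_\Omega\|x\|^2 p(x)\,dx$, then
    $D_{T,\mathcal V}(p\|q) = \int_\Omega\big(\|T(x)\|^2 - \|x\|^2 - 2(T(x) - x, x)\big)q(x)\,dx = \int_\Omega\|T(x) - x\|^2 q(x)\,dx = \int_\Omega\|\nabla_x\Phi_p(x) - \nabla_x\Phi_q(x)\|^2 q(x)\,dx = \int_\Omega\int_\Omega\|y - x\|^2\pi(x,y)\,dx\,dy = \mathrm{Dist}_T(p,q)^2$,
    where $\pi$ is the optimal transport plan between $q$ and $p$. The transport Bregman divergence of the second moment recovers the squared Wasserstein-2 distance.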
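The identity can be checked on samples in one dimension, where the optimal map from $q$ to $p$ sends the $i$-th smallest sample of $q$ to the $i$-th smallest sample of $p$. The sketch below assumes equal-size empirical measures and evaluates the definition of $D_{T,\mathcal V}$ term by term, using $\nabla_x\frac{\delta\mathcal V}{\delta q}(x) = 2x$; the function names are illustrative.

```python
def transport_bregman_second_moment(q_samples, p_samples):
    """D_{T,V}(p || q) for V(p) = int |x|^2 p(x) dx, from equal-size 1-D samples."""
    xs, ys = sorted(q_samples), sorted(p_samples)   # T(x_(i)) = y_(i)
    n = len(xs)
    v_p = sum(y * y for y in ys) / n                # V(p)
    v_q = sum(x * x for x in xs) / n                # V(q)
    # int (grad_x dV/dq (x), T(x) - x) q(x) dx with grad_x dV/dq (x) = 2x
    linear = sum(2.0 * x * (y - x) for x, y in zip(xs, ys)) / n
    return v_p - v_q - linear


def w2_squared_1d(xs, ys):
    """Squared W2 between equal-size 1-D empirical measures via sorted pairing."""
    xs, ys = sorted(xs), sorted(ys)
    return sum((x - y) ** 2 for x, y in zip(xs, ys)) / len(xs)


q_samples = [0.0, 1.0, 2.0, -0.5]
p_samples = [5.0, 3.0, 4.0, 1.5]
print(transport_bregman_second_moment(q_samples, p_samples),
      w2_squared_1d(q_samples, p_samples))   # equal up to rounding
```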
  17. Formulations: Linear energy

    Denote $T = \nabla\Phi_p$ and $\nabla\Phi_q(x) = x$. Consider the linear energy $\mathcal V(p) = \int_\Omega V(x)p(x)\,dx$, where the linear potential function $V\in C^\infty(\Omega)$ is strictly convex in $\mathbb R^d$. Then
    $D_{T,\mathcal V}(p\|q) = \int_\Omega D_V\big(\nabla_x\Phi_p(x)\,\|\,\nabla_x\Phi_q(x)\big)\,q(x)\,dx$,
    where $D_V:\Omega\times\Omega\to\mathbb R$ is the Euclidean Bregman divergence of $V$, defined by $D_V(z_1\|z_2) = V(z_1) - V(z_2) - \nabla V(z_2)\cdot(z_1 - z_2)$ for any $z_1, z_2\in\Omega$.
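For a general strictly convex potential the same sample-based computation applies. A minimal sketch, assuming one-dimensional equal-size samples and the illustrative choice $V(x) = e^x$ (any strictly convex $V$ with a known derivative would do); it evaluates both the definition and the pointwise Euclidean Bregman form.

```python
import math


def V(z):
    return math.exp(z)      # illustrative strictly convex potential


def dV(z):
    return math.exp(z)


def bregman_V(z1, z2):
    """Euclidean Bregman divergence D_V(z1 || z2)."""
    return V(z1) - V(z2) - dV(z2) * (z1 - z2)


def transport_bregman_linear(q_samples, p_samples):
    xs, ys = sorted(q_samples), sorted(p_samples)   # 1-D map: T(x_(i)) = y_(i)
    n = len(xs)
    # Definition: V(p) - V(q) - int V'(x) (T(x) - x) q(x) dx
    by_definition = (sum(V(y) for y in ys) - sum(V(x) for x in xs)
                     - sum(dV(x) * (y - x) for x, y in zip(xs, ys))) / n
    # Slide formula: int D_V(T(x) || x) q(x) dx, using grad Phi_q(x) = x
    by_formula = sum(bregman_V(y, x) for x, y in zip(xs, ys)) / n
    return by_definition, by_formula


print(transport_bregman_linear([0.0, 0.3, -1.0], [0.5, 1.2, 0.1]))
```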
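  18. Formulations: Interaction energy

    Consider the interaction energy $\mathcal W(p) = \frac12\int_\Omega\int_\Omega W(x,\tilde x)p(x)p(\tilde x)\,dx\,d\tilde x$, where the interaction kernel potential is $W(x,\tilde x) = W(\tilde x, x)\in C^\infty(\Omega\times\Omega)$. Assume $W(x,\tilde x) = \tilde W(x - \tilde x)$. Then
    $D_{T,\mathcal W}(p\|q) = \frac12\int_\Omega\int_\Omega D_{\tilde W}\big(\nabla_x\Phi_p(x) - \nabla_{\tilde x}\Phi_p(\tilde x)\,\|\,\nabla_x\Phi_q(x) - \nabla_{\tilde x}\Phi_q(\tilde x)\big)\,q(x)q(\tilde x)\,dx\,d\tilde x$,
    where $D_{\tilde W}:\Omega\times\Omega\to\mathbb R$ is the Euclidean Bregman divergence of $\tilde W$, defined by $D_{\tilde W}(z_1\|z_2) = \tilde W(z_1) - \tilde W(z_2) - \nabla\tilde W(z_2)\cdot(z_1 - z_2)$ for any $z_1, z_2\in\Omega$.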
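The pairwise formula can be compared with the definition on samples. A sketch assuming one-dimensional equal-size samples and the illustrative even, convex kernel $\tilde W(z) = z^4/4$; since $W$ is symmetric, $\nabla_x\frac{\delta\mathcal W}{\delta q}(x) = \int_\Omega\tilde W'(x - \tilde x)q(\tilde x)\,d\tilde x$. All names are hypothetical.

```python
def W_tilde(z):
    return 0.25 * z ** 4    # illustrative even, convex kernel


def dW_tilde(z):
    return z ** 3


def bregman_W(z1, z2):
    """Euclidean Bregman divergence of W_tilde."""
    return W_tilde(z1) - W_tilde(z2) - dW_tilde(z2) * (z1 - z2)


def transport_bregman_interaction(q_samples, p_samples):
    xs, ys = sorted(q_samples), sorted(p_samples)   # 1-D map: T(x_(i)) = y_(i)
    n = len(xs)
    idx = range(n)
    # Definition: W(p) - W(q) - int grad_x (dW/dq)(x) (T(x) - x) q(x) dx
    w_p = 0.5 * sum(W_tilde(ys[i] - ys[j]) for i in idx for j in idx) / n ** 2
    w_q = 0.5 * sum(W_tilde(xs[i] - xs[j]) for i in idx for j in idx) / n ** 2
    grad = [sum(dW_tilde(xs[i] - xs[j]) for j in idx) / n for i in idx]
    linear = sum(grad[i] * (ys[i] - xs[i]) for i in idx) / n
    by_definition = w_p - w_q - linear
    # Slide formula: (1/2) double integral of D_W(T(x) - T(x~) || x - x~)
    by_pairs = 0.5 * sum(bregman_W(ys[i] - ys[j], xs[i] - xs[j])
                         for i in idx for j in idx) / n ** 2
    return by_definition, by_pairs


print(transport_bregman_interaction([0.0, 0.4, 1.0, -0.3], [0.2, 1.5, 2.5, -1.0]))
```

The two values coincide because $\tilde W'$ is odd, so the cross terms of the pairwise sum collapse onto the linear term of the definition.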
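  19. Formulations: Negative entropy

    Consider the negative entropy $\mathcal U(p) = \int_\Omega U(p(x))\,dx$, where the entropy potential $U:\mathbb R_+\to\mathbb R$ is twice differentiable and convex. Then
    $D_{T,\mathcal U}(p\|q) = \int_\Omega D_{\hat U}\big(\nabla^2_x\Phi_p(x)\,\|\,\nabla^2_x\Phi_q(x)\big)\,q(x)\,dx$,
    where $D_{\hat U}:\mathbb R^{d\times d}\times\mathbb R^{d\times d}\to\mathbb R$ is a matrix Bregman divergence. Denote the function $\hat U:\mathbb R_+\times\mathbb R^{d\times d}\to\mathbb R$ by $\hat U(q, A) = U\big(\frac{q}{\det(A)}\big)\frac{\det(A)}{q}$, where $q\in\mathbb R_+$ is the given reference density, and $D_{\hat U}(A\|B) = \hat U(q,A) - \hat U(q,B) - \mathrm{tr}\big(\nabla_B\hat U(q,B)\cdot(A - B)\big)$ for any $A, B\in\mathbb R^{d\times d}$, where $\nabla_B$ is the Fréchet derivative w.r.t. the symmetric matrix $B$.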
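A worked special case, assuming the Boltzmann–Shannon choice $U(s) = s\log s$ and using $\nabla_B\log\det B = B^{-1}$ for symmetric positive definite $B$:

```latex
\hat U(q, A) = \frac{q}{\det A}\log\frac{q}{\det A}\cdot\frac{\det A}{q}
            = \log q - \log\det A,
\qquad
\nabla_B \hat U(q, B) = -B^{-1},

D_{\hat U}(A\,\|\,B) = -\log\det A + \log\det B + \operatorname{tr}\big(B^{-1}(A - B)\big)
                   = \operatorname{tr}(B^{-1}A) - \log\det(B^{-1}A) - d.
```

With $B = \nabla^2_x\Phi_q = I$ and $A = \nabla^2_x\Phi_p(x)$, the integrand becomes $\Delta_x\Phi_p(x) - \log\det\nabla^2_x\Phi_p(x) - d$, which is the integrand of the transport KL divergence defined on the next slide.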
  20. Transport KL divergence

    Definition. Define $D_{TKL}:\mathcal P(\Omega)\times\mathcal P(\Omega)\to\mathbb R$ by
    $D_{TKL}(p\|q) = \int_\Omega\big(\Delta_x\Phi_p(x) - \log\det(\nabla^2_x\Phi_p(x)) - d\big)\,q(x)\,dx$,
    where $\nabla_x\Phi_p$ is the diffeomorphism mapping $q$ to $p$, such that $(\nabla_x\Phi_p)_\# q = p$. We call $D_{TKL}$ the transport KL divergence.
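  21. Transport + Bregman + Entropy = TKL

    $D_{TKL}(p\|q) = \int_\Omega\big(-\log\det(\nabla^2_x\Phi_p(x)) + \Delta_x\Phi_p(x) - d\big)q(x)\,dx = \underbrace{\int_\Omega p(x)\log p(x)\,dx}_{\text{Entropy}}\;\underbrace{-\int_\Omega q(x)\log q(x)\,dx + \int_\Omega\Delta_x\Phi_p(x)q(x)\,dx - d}_{\text{Transport cross entropy}} = -H(p) + H_{T,q}(p)$,
    where we apply the fact that $(\nabla_x\Phi_p)_\# q = p$, i.e. $p(\nabla_x\Phi_p(x))\det(\nabla^2_x\Phi_p(x)) = q(x)$. Hence $-\log\det(\nabla^2_x\Phi_p(x)) = \log p(\nabla_x\Phi_p(x)) - \log q(x)$, and $\int_\Omega\log p(\nabla_x\Phi_p(x))\,q(x)\,dx = \int_\Omega p(x)\log p(x)\,dx$.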
  22. Why TKL?

    Theorem. The transport KL divergence has the following properties. (i) Nonnegativity: for any $p, q\in\mathcal P(\Omega)$, $D_{TKL}(p\|q)\ge 0$. (ii) Separability: the transport KL divergence is additive for independent distributions. Suppose $p(x,y) = p_1(x)p_2(y)$ and $q(x,y) = q_1(x)q_2(y)$. Then $D_{TKL}(p\|q) = D_{TKL}(p_1\|q_1) + D_{TKL}(p_2\|q_2)$. (iii) Transport Hessian information metric. (iv) Transport convexity.
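  23. One dimension: TKL vs KL divergence

    Transport KL divergence:
    $D_{TKL}(p\|q) = \int_0^1\Big(\frac{\nabla_x F_p^{-1}(x)}{\nabla_x F_q^{-1}(x)} - \log\frac{\nabla_x F_p^{-1}(x)}{\nabla_x F_q^{-1}(x)} - 1\Big)\,dx$.
    KL divergence:
    $D_{KL}(p\|q) = \int_\Omega\nabla_x F_p(x)\log\frac{\nabla_x F_p(x)}{\nabla_x F_q(x)}\,dx$.
    Here $F_p(x) = \int_{-\infty}^{x}p(s)\,ds$ and $F_q(x) = \int_{-\infty}^{x}q(s)\,ds$ are the cumulative distribution functions of the probability densities $p$ and $q$, respectively.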
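Both one-dimensional formulas can be evaluated with the Python standard library. The sketch below assumes Gaussian inputs via `statistics.NormalDist` (chosen only because its `pdf` and `inv_cdf` are available), uses $\nabla_x F^{-1}(u) = 1/p(F^{-1}(u))$, and approximates the integrals with a midpoint rule; step counts and the truncation interval are illustrative.

```python
import math
from statistics import NormalDist


def quantile_derivative(dist, u):
    """d/du F^{-1}(u) = 1 / p(F^{-1}(u))."""
    return 1.0 / dist.pdf(dist.inv_cdf(u))


def log_pdf(dist, x):
    """Log-density of a NormalDist, computed directly to avoid underflow."""
    z = (x - dist.mean) / dist.stdev
    return -0.5 * z * z - math.log(dist.stdev * math.sqrt(2.0 * math.pi))


def tkl_1d(p, q, n=20_000):
    """Transport KL divergence: integral over (0, 1) of r - log r - 1,
    with r = (F_p^{-1})' / (F_q^{-1})'."""
    total = 0.0
    for k in range(n):
        u = (k + 0.5) / n
        r = quantile_derivative(p, u) / quantile_derivative(q, u)
        total += r - math.log(r) - 1.0
    return total / n


def kl_1d(p, q, lo=-20.0, hi=20.0, n=40_000):
    """KL divergence: integral of p log(p / q) over a truncated interval."""
    h = (hi - lo) / n
    total = 0.0
    for k in range(n):
        x = lo + (k + 0.5) * h
        total += p.pdf(x) * (log_pdf(p, x) - log_pdf(q, x))
    return total * h


p, q = NormalDist(0.0, 1.0), NormalDist(1.0, 2.0)
print(tkl_1d(p, q), kl_1d(p, q))
```

For two Gaussians the ratio $\nabla_x F_p^{-1}/\nabla_x F_q^{-1}$ is the constant $\sigma_p/\sigma_q$, so the transport KL divergence reduces to $r - \log r - 1$ with $r = \sigma_p/\sigma_q$. In particular it vanishes for two Gaussians that differ only by a translation, whereas the KL divergence does not.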
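  24. Gaussian: TKL vs KL divergence

    Consider $p_X = \mathcal N(0,\Sigma_X)$ and $p_Y = \mathcal N(0,\Sigma_Y)$. Transport KL divergence:
    $D_{TKL}(p_X\|p_Y) = \frac12\log\frac{\det(\Sigma_Y)}{\det(\Sigma_X)} + \mathrm{tr}\Big(\Sigma_X^{1/2}\big(\Sigma_X^{1/2}\Sigma_Y\Sigma_X^{1/2}\big)^{-1/2}\Sigma_X^{1/2}\Big) - d$.
    KL divergence:
    $D_{KL}(p_X\|p_Y) = \frac12\log\frac{\det(\Sigma_Y)}{\det(\Sigma_X)} + \frac12\mathrm{tr}(\Sigma_X\Sigma_Y^{-1}) - \frac d2$.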
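The Gaussian formulas can be evaluated without a linear-algebra library in the special case $d = 2$. This sketch assumes $2\times 2$ symmetric positive definite covariances and uses the Cayley–Hamilton identity $\sqrt{M} = (M + \sqrt{\det M}\,I)/\sqrt{\operatorname{tr}M + 2\sqrt{\det M}}$ for the matrix square root; all helper names are illustrative.

```python
import math

D = 2  # dimension handled by these 2x2 helpers


def mat_mul(A, B):
    return [[sum(A[i][k] * B[k][j] for k in range(2)) for j in range(2)]
            for i in range(2)]


def det2(A):
    return A[0][0] * A[1][1] - A[0][1] * A[1][0]


def trace2(A):
    return A[0][0] + A[1][1]


def inv2(A):
    d = det2(A)
    return [[A[1][1] / d, -A[0][1] / d], [-A[1][0] / d, A[0][0] / d]]


def sqrtm2(A):
    """SPD square root (A + sI) / t with s = sqrt(det A), t = sqrt(tr A + 2s)."""
    s = math.sqrt(det2(A))
    t = math.sqrt(trace2(A) + 2.0 * s)
    return [[(A[0][0] + s) / t, A[0][1] / t], [A[1][0] / t, (A[1][1] + s) / t]]


def gaussian_map(Sx, Sy):
    """Optimal map N(0, Sy) -> N(0, Sx): Sx^{1/2} (Sx^{1/2} Sy Sx^{1/2})^{-1/2} Sx^{1/2}."""
    Sx_half = sqrtm2(Sx)
    M = mat_mul(mat_mul(Sx_half, Sy), Sx_half)
    return mat_mul(mat_mul(Sx_half, inv2(sqrtm2(M))), Sx_half)


def tkl_gaussian(Sx, Sy):
    return 0.5 * math.log(det2(Sy) / det2(Sx)) + trace2(gaussian_map(Sx, Sy)) - D


def kl_gaussian(Sx, Sy):
    return 0.5 * math.log(det2(Sy) / det2(Sx)) + 0.5 * trace2(mat_mul(Sx, inv2(Sy))) - D / 2


Sx = [[2.0, 0.5], [0.5, 1.0]]
Sy = [[1.0, 0.2], [0.2, 1.5]]
print(tkl_gaussian(Sx, Sy), kl_gaussian(Sx, Sy))
```

For diagonal covariances both quantities split into one-dimensional terms, consistent with separability: with $r_i = \sigma_{X,i}/\sigma_{Y,i}$, the transport KL divergence is $\sum_i(r_i - \log r_i - 1)$ and the KL divergence is $\frac12\sum_i(r_i^2 - \log r_i^2 - 1)$.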
  25. Transport Jensen–Shannon divergence

    Definition. Define $D_{TJS}:\mathcal P(\Omega)\times\mathcal P(\Omega)\to\mathbb R$ by $D_{TJS}(p\|q) = \frac12 D_{TKL}(p\|r) + \frac12 D_{TKL}(q\|r)$, where $r\in\mathcal P(\Omega)$ is the geodesic midpoint (barycenter) between $p$ and $q$ in $L^2$–Wasserstein space, i.e. $r = \big(\frac12(\nabla_x\Phi_p + \nabla_x\Phi_q)\big)_\# q$.
  26. One dimension: TJS vs JS divergence

    Transport Jensen–Shannon divergence:
    $D_{TJS}(p\|q) = -\frac12\int_0^1\log\frac{\nabla_x F_p^{-1}(x)\,\nabla_x F_q^{-1}(x)}{\frac14\big(\nabla_x F_p^{-1}(x) + \nabla_x F_q^{-1}(x)\big)^2}\,dx$.
    Jensen–Shannon divergence:
    $D_{JS}(p\|q) = \frac12\int_\Omega\nabla_x F_p(x)\log\frac{\nabla_x F_p(x)}{\frac12\nabla_x F_p(x) + \frac12\nabla_x F_q(x)}\,dx + \frac12\int_\Omega\nabla_x F_q(x)\log\frac{\nabla_x F_q(x)}{\frac12\nabla_x F_p(x) + \frac12\nabla_x F_q(x)}\,dx$.
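The two one-dimensional expressions can be evaluated the same way. A sketch assuming Gaussian inputs from `statistics.NormalDist` and midpoint-rule quadrature with illustrative step counts and truncation interval:

```python
import math
from statistics import NormalDist


def quantile_derivative(dist, u):
    """d/du F^{-1}(u) = 1 / p(F^{-1}(u))."""
    return 1.0 / dist.pdf(dist.inv_cdf(u))


def tjs_1d(p, q, n=20_000):
    """Transport JS divergence: -1/2 int_0^1 log(a b / ((a + b)^2 / 4)) du."""
    total = 0.0
    for k in range(n):
        u = (k + 0.5) / n
        a, b = quantile_derivative(p, u), quantile_derivative(q, u)
        total += math.log(a * b / (0.25 * (a + b) ** 2))
    return -0.5 * total / n


def js_1d(p, q, lo=-30.0, hi=30.0, n=60_000):
    """JS divergence with the L2 midpoint (p + q) / 2, on a truncated interval."""
    h = (hi - lo) / n
    total = 0.0
    for k in range(n):
        x = lo + (k + 0.5) * h
        px, qx = p.pdf(x), q.pdf(x)
        m = 0.5 * (px + qx)
        if px > 0.0:
            total += 0.5 * px * math.log(px / m)
        if qx > 0.0:
            total += 0.5 * qx * math.log(qx / m)
    return total * h


p, q = NormalDist(0.0, 1.0), NormalDist(2.0, 3.0)
print(tjs_1d(p, q), js_1d(p, q))
```

For two Gaussians with equal variances the transport JS divergence is zero (the midpoint map is a translation), while the JS divergence is positive.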
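  27. Gaussian: TJS vs JS divergence

    Consider $p_X = \mathcal N(0,\Sigma_X)$ and $p_Y = \mathcal N(0,\Sigma_Y)$. Transport Jensen–Shannon divergence:
    $D_{TJS}(p_X\|p_Y) = -\frac14\log\frac{\det(\Sigma_X)\det(\Sigma_Y)}{\det(\Sigma_Z)^2} + \frac12\mathrm{tr}\Big(\Sigma_X^{1/2}\big(\Sigma_X^{1/2}\Sigma_Z\Sigma_X^{1/2}\big)^{-1/2}\Sigma_X^{1/2} + \Sigma_Y^{1/2}\big(\Sigma_Y^{1/2}\Sigma_Z\Sigma_Y^{1/2}\big)^{-1/2}\Sigma_Y^{1/2}\Big) - d$,
    where $\Sigma_Z$ is the covariance of the Wasserstein midpoint $r$ between $p_X$ and $p_Y$. Jensen–Shannon divergence: no closed-form expression.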
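As a consistency check in $d = 1$ (a worked computation, assuming $\Sigma_X = \sigma_X^2$, $\Sigma_Y = \sigma_Y^2$, and that the Wasserstein midpoint of two centered Gaussians is the centered Gaussian with standard deviation $\sigma_Z = \frac{\sigma_X + \sigma_Y}{2}$), each trace term reduces to $\sigma_X/\sigma_Z$ or $\sigma_Y/\sigma_Z$:

```latex
D_{TJS}(p_X\|p_Y)
  = -\frac14\log\frac{\sigma_X^2\sigma_Y^2}{\sigma_Z^4}
    + \frac12\Big(\frac{\sigma_X}{\sigma_Z} + \frac{\sigma_Y}{\sigma_Z}\Big) - 1
  = -\frac12\log\frac{\sigma_X\sigma_Y}{\sigma_Z^2}
  = \log\frac{\sigma_X + \sigma_Y}{2\sqrt{\sigma_X\sigma_Y}} \;\ge\; 0.
```

This agrees with the one-dimensional quantile formula, since for Gaussians $\nabla_x F_p^{-1}/\nabla_x F_q^{-1} = \sigma_X/\sigma_Y$ is constant; nonnegativity follows from the arithmetic–geometric mean inequality.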