Slide 1

Transport information Bregman divergences

Wuchen Li
University of South Carolina

Optimal transport and Mean field game seminar

Slide 2

AI and Sampling

Slide 3

Learning

Given a data measure $\rho_{\mathrm{data}}(x) = \frac{1}{N}\sum_{i=1}^{N}\delta_{X_i}(x)$ and a parameterized model $\rho(x,\theta)$, learning problems often take the form
$$\min_{\rho_\theta\in\rho(\Theta)} \mathrm{Dist}(\rho_{\mathrm{data}},\rho_\theta).$$

Mathematics behind learning:
- "Distance" between model and data in probability space, which allows efficient sampling approximations;
- Parameterizations: full space; neural network generative models; Boltzmann machines; Gaussians; Gaussian mixtures; finite volume/element methods, etc.;
- Optimizations: gradient descent; primal-dual algorithms; etc.

In this talk, we focus on the construction of "distances".

Slide 4

History

Slide 5

Information

What is "information"?

Wikipedia: Information theory is the scientific study of the quantification, storage, and communication of information. The field is at the intersection of probability theory, statistics, computer science, statistical mechanics, information engineering, and electrical engineering.

Applied mathematics: entropy; Bregman divergences; dualities.

Slide 6

Bregman divergences

Bregman divergences generalize Euclidean distances:
$$D_\psi(y\|x) = \psi(y) - \psi(x) - \langle\nabla\psi(x),\, y - x\rangle.$$

Examples:
(i) Euclidean distance. $\psi(z) = z^2$: $D_\psi(y\|x) = y^2 - x^2 - 2x(y-x) = (y-x)^2$.
(ii) KL divergence. $\psi(z) = z\log z$: $D_\psi(y\|x) = y\log\frac{y}{x} - (y-x)$.
(iii) Itakura–Saito divergence. $\psi(z) = -\log z$: $D_\psi(y\|x) = \frac{y}{x} - \log\frac{y}{x} - 1$.
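As a quick sanity check (not from the original slides), here is a small Python sketch that evaluates $D_\psi$ for the three potentials above; the function name `bregman` and the test points are my own choices.

```python
import numpy as np

def bregman(psi, grad_psi, y, x):
    """Generic Bregman divergence D_psi(y || x) = psi(y) - psi(x) - <grad psi(x), y - x>."""
    return psi(y) - psi(x) - grad_psi(x) * (y - x)

x, y = 2.0, 3.0

# (i) Euclidean: psi(z) = z^2  ->  D = (y - x)^2
d_euclid = bregman(lambda z: z**2, lambda z: 2 * z, y, x)

# (ii) KL: psi(z) = z log z  ->  D = y log(y/x) - (y - x)
d_kl = bregman(lambda z: z * np.log(z), lambda z: np.log(z) + 1.0, y, x)

# (iii) Itakura-Saito: psi(z) = -log z  ->  D = y/x - log(y/x) - 1
d_is = bregman(lambda z: -np.log(z), lambda z: -1.0 / z, y, x)

print(d_euclid, (y - x) ** 2)                    # 1.0  1.0
print(d_kl, y * np.log(y / x) - (y - x))         # matching values
print(d_is, y / x - np.log(y / x) - 1.0)         # matching values
```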

Slide 7

Bregman properties

Nonnegativity: $D_\psi(y\|x) \geq 0$.

Hessian metric: Consider the following Taylor expansion. Denote $\Delta x\in\mathbb{R}^d$; then
$$D_\psi(x+\Delta x\|x) = \tfrac{1}{2}\Delta x^{\mathsf T}\nabla^2\psi(x)\Delta x + o(\|\Delta x\|^2),$$
where $\nabla^2\psi$ is the Hessian operator of $\psi$ w.r.t. the Euclidean metric.

Asymmetry: In general, $D_\psi$ is not necessarily symmetric in $x$ and $y$, i.e. $D_\psi(y\|x)\neq D_\psi(x\|y)$.

Duality: Denote the conjugate (dual) function of $\psi$ by $\psi^*(x^*) = \sup_{x\in\Omega}\,\langle x, x^*\rangle - \psi(x)$. Then
$$D_{\psi^*}(x^*\|y^*) = D_\psi(y\|x),$$
where $x^* = \nabla\psi(x)$ and $y^* = \nabla\psi(y)$.
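The duality property can be verified numerically. The sketch below is my own, not from the slides; it uses $\psi(z) = z\log z$ on $(0,\infty)$, whose Legendre conjugate is $\psi^*(z^*) = e^{z^*-1}$.

```python
import numpy as np

# psi(z) = z log z on (0, inf); grad psi(x) = log x + 1;
# conjugate psi*(x*) = sup_x <x, x*> - psi(x) = exp(x* - 1), with grad psi*(x*) = exp(x* - 1).
psi        = lambda z: z * np.log(z)
grad_psi   = lambda z: np.log(z) + 1.0
psi_s      = lambda s: np.exp(s - 1.0)
grad_psi_s = lambda s: np.exp(s - 1.0)

def bregman(f, grad_f, a, b):
    return f(a) - f(b) - grad_f(b) * (a - b)

x, y = 0.7, 2.3
x_s, y_s = grad_psi(x), grad_psi(y)          # dual coordinates x* = grad psi(x), y* = grad psi(y)

lhs = bregman(psi_s, grad_psi_s, x_s, y_s)   # D_{psi*}(x* || y*)
rhs = bregman(psi, grad_psi, y, x)           # D_psi(y || x)
print(np.isclose(lhs, rhs))                  # True
```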

Slide 8

KL divergence

One of the most important Bregman divergences is the KL divergence:
$$D_{\mathrm{KL}}(p\|q) = \int_\Omega p(x)\log\frac{p(x)}{q(x)}\,dx.$$

Slide 9

Why KL divergence?

KL divergence = Bregman divergence + Shannon entropy + $L^2$ space:
$$D_{\mathrm{KL}}(p\|q) = \underbrace{\int_\Omega p(x)\log p(x)\,dx}_{\text{entropy}} - \underbrace{\int_\Omega p(x)\log q(x)\,dx}_{\text{cross entropy}}.$$

KL has many properties:
- Nonsymmetry: $D_{\mathrm{KL}}(p\|q)\neq D_{\mathrm{KL}}(q\|p)$;
- Separability;
- Convexity in both variables $p$ and $q$;
- Asymptotic behavior:
  $$D_{\mathrm{KL}}(q+\delta q\|q) \approx \frac{1}{2}\int_\Omega \frac{(\delta q(x))^2}{q(x)}\,dx,$$
  where $\frac{1}{q}$ is the Fisher–Rao information metric.
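A grid-based sanity check of the quadratic expansion above (my own sketch; the Gaussian reference density and the perturbation $\delta q = x\,q$ are arbitrary choices):

```python
import numpy as np

# Discretize densities on a grid and compare D_KL(q + eps*dq || q)
# with the Fisher-Rao quadratic form (eps^2 / 2) * int dq^2 / q dx.
x = np.linspace(-5, 5, 2001)
dx = x[1] - x[0]

q = np.exp(-x**2 / 2) / np.sqrt(2 * np.pi)    # reference density
dq = x * q                                    # perturbation with int dq dx = 0
eps = 1e-3

p = q + eps * dq                              # perturbed density (still positive, same mass)
kl = np.sum(p * np.log(p / q)) * dx
quad = 0.5 * eps**2 * np.sum(dq**2 / q) * dx
print(kl, quad)                               # nearly equal
```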

Slide 10

Jensen–Shannon divergence

KL divergence is a building block for other divergences. Its symmetrized version is the Jensen–Shannon divergence:
$$D_{\mathrm{JS}}(p\|q) = \tfrac{1}{2}D_{\mathrm{KL}}(p\|r) + \tfrac{1}{2}D_{\mathrm{KL}}(q\|r), \qquad r = \frac{p+q}{2},$$
where $r$ is the geodesic midpoint (barycenter) in $L^2$ space. Because of its nice duality, it serves as the original objective function used in GANs.
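For concreteness, here is a discrete-histogram version of this construction; this is a sketch of mine, not code from the talk.

```python
import numpy as np

def kl(p, q):
    """Discrete KL divergence sum p log(p/q); assumes p, q > 0 and both sum to 1."""
    return np.sum(p * np.log(p / q))

def js(p, q):
    """Jensen-Shannon divergence with the L2 midpoint r = (p + q) / 2."""
    r = 0.5 * (p + q)
    return 0.5 * kl(p, r) + 0.5 * kl(q, r)

p = np.array([0.2, 0.5, 0.3])
q = np.array([0.4, 0.4, 0.2])
print(js(p, q), js(q, p))    # symmetric, bounded by log 2
```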

Slide 11

Generalized KL divergences

Information geometry (Amari, Ay, Nielsen, et al.) studies generalizations of Bregman divergences while keeping their dualities. Using the KL divergence and the Fisher–Rao metric, various divergences can be constructed with nice duality properties.

Slide 12

Optimal transport

What is the optimal way to move or transport a mountain with shape $X$ and density $q(x)$ to another shape $Y$ with density $p(y)$? I.e.,
$$\mathrm{Dist}_T(p,q)^2 = \inf_T\Big\{\int_\Omega \|T(x)-x\|^2 q(x)\,dx\ \colon\ T_\# q = p\Big\}.$$

The problem was first introduced by Monge in 1781 and relaxed by Kantorovich in the 1940s. It introduces a metric function on the probability set, named the optimal transport distance, Wasserstein metric, or Earth Mover's distance (Ambrosio, Gangbo, Villani, Otto, Figalli, et al.).

Slide 13

Why optimal transport?

Optimal transport provides a particular transport distance among histograms, which relies on the distance on the sample space. E.g., denote $X_0\sim p = \delta_{x_0}$, $X_1\sim q = \delta_{x_1}$. Compare
$$\mathrm{Dist}_T(p,q)^2 = \inf_{\pi\in\Pi(p,q)} \mathbb{E}_{(X_0,X_1)\sim\pi}\|X_0-X_1\|^2 = \|x_0-x_1\|^2$$
vs.
$$D_{\mathrm{KL}}(p\|q) = \int_\Omega p(x)\log\frac{p(x)}{q(x)}\,dx = \infty.$$
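One way to see this numerically (my own sketch) is to smooth the two Dirac masses into Gaussians $N(x_0,\sigma^2)$ and $N(x_1,\sigma^2)$ and let $\sigma\to 0$: the closed-form squared Wasserstein distance stays $(x_0-x_1)^2$, while the closed-form KL divergence $(x_0-x_1)^2/(2\sigma^2)$ blows up.

```python
import numpy as np

x0, x1 = 0.0, 1.0
for sigma in [1.0, 0.1, 0.01]:
    # Closed forms for 1D Gaussians with equal variance:
    w2_sq = (x0 - x1) ** 2                   # squared W2 distance, independent of sigma
    kl = (x0 - x1) ** 2 / (2 * sigma ** 2)   # KL(N(x0, s^2) || N(x1, s^2))
    print(sigma, w2_sq, kl)
# As sigma -> 0 (the Dirac limit), W2^2 stays 1 while KL diverges.
```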

Slide 14

Optimal transport inference problems

Nowadays, it has been shown that optimal transport distances are useful in inference problems. Given a data distribution $p_{\mathrm{data}}$ and a probability model $p_\theta$, consider
$$\min_{\theta\in\Theta}\ \mathrm{Dist}_T(p_{\mathrm{data}}, p_\theta).$$

Benefits:
- Hopf–Lax formula and Hamilton–Jacobi equation on a sample space (Small mac);
- Transport convexity.

Drawbacks:
- Additional minimization;
- Requires finite second moments of $p_{\mathrm{data}}$ and $p_\theta$.

Slide 15

Goals

We plan to design Bregman divergences by using both transport distances and information entropies.

Natural questions:
(i) What are Bregman divergences in Wasserstein space?
(ii) What is the "KL divergence" in Wasserstein space?

Slide 16

Transport Bregman divergence

Definition (Transport Bregman divergence). Let $F\colon\mathcal{P}(\Omega)\to\mathbb{R}$ be a smooth, strictly displacement convex functional. Define $D_{T,F}\colon\mathcal{P}(\Omega)\times\mathcal{P}(\Omega)\to\mathbb{R}$ by
$$D_{T,F}(p\|q) = F(p) - F(q) - \int_\Omega \Big\langle \nabla_x\frac{\delta}{\delta q(x)}F(q),\ T(x)-x\Big\rangle\, q(x)\,dx,$$
where $T$ is the optimal transport map from $q$ to $p$, such that $T_\# q = p$ and $T(x) = \nabla_x\Phi_p(x)$. We call $D_{T,F}$ the transport Bregman divergence.

Slide 17

Transport distance + Bregman divergence

Proposition. The functional $D_{T,F}$ satisfies
$$D_{T,F}(p\|q) = F(p) - F(q) - \frac{1}{2}\int_\Omega \mathrm{grad}_T F(q)(x)\cdot \frac{\delta}{\delta q(x)}\mathrm{Dist}_T(p,q)^2\,dx.$$

Slide 18

Transport Bregman properties

(i) Nonnegativity: Suppose $F$ is displacement convex; then $D_{T,F}(p\|q)\geq 0$. Suppose $F$ is strictly displacement convex; then $D_{T,F}(p\|q)=0$ if and only if $\mathrm{Dist}_T(p,q)=0$.

(ii) Transport Hessian metric: Consider the following Taylor expansion. Denote $\sigma = -\nabla\cdot(q\nabla\Phi)\in T_q\mathcal{P}(\Omega)$ and $\varepsilon\in\mathbb{R}$; then
$$D_{T,F}\big((\mathrm{id}+\varepsilon\nabla\Phi)_\# q\,\|\,q\big) = \frac{\varepsilon^2}{2}\,\mathrm{Hess}_T F(q)(\sigma,\sigma) + o(\varepsilon^2),$$
where $\mathrm{id}\colon\Omega\to\Omega$ is the identity map, $\mathrm{id}(x)=x$, and $\mathrm{Hess}_T F(q)$ is the Hessian operator of the functional $F$ at $q\in\mathcal{P}(\Omega)$ w.r.t. the $L^2$-Wasserstein metric.

(iii) Asymmetry: In general, $D_{T,F}(p\|q)\neq D_{T,F}(q\|p)$.

Our transport duality relates to the Wasserstein Hamilton–Jacobi equation from mean field games (Big mac).

Slide 19

Transport Bregman divergence of the second moment

If $\mathcal{V}(p) = \int_\Omega \|x\|^2 p(x)\,dx$, then
$$\begin{aligned}
D_{T,\mathcal{V}}(p\|q) &= \int_\Omega \big[\|T(x)\|^2 - \|x\|^2 - 2\langle T(x)-x,\,x\rangle\big]\,q(x)\,dx\\
&= \int_\Omega \|T(x)-x\|^2\, q(x)\,dx\\
&= \int_\Omega \|\nabla_x\Phi_p(x) - \nabla_x\Phi_q(x)\|^2\, q(x)\,dx\\
&= \int_\Omega\int_\Omega \|y-x\|^2\,\pi(x,y)\,dx\,dy\\
&= \mathrm{Dist}_T(p,q)^2.
\end{aligned}$$
The transport Bregman divergence of the second moment recovers the squared Wasserstein distance.
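In one dimension this chain of identities can be checked with quantile functions, since the optimal map is $T = F_p^{-1}\circ F_q$ and $\int |T(x)-x|^2 q(x)\,dx = \int_0^1 |F_p^{-1}(u)-F_q^{-1}(u)|^2\,du$. Below is a sketch with two Gaussians (SciPy quantile functions; the particular means and variances are my own test case).

```python
import numpy as np
from scipy.stats import norm

# p = N(1, 2^2), q = N(0, 1): in 1D the optimal map is T = F_p^{-1} o F_q,
# and int |T(x) - x|^2 q(x) dx = int_0^1 |F_p^{-1}(u) - F_q^{-1}(u)|^2 du.
u = (np.arange(200000) + 0.5) / 200000
qp = norm.ppf(u, loc=1.0, scale=2.0)     # quantile function of p
qq = norm.ppf(u, loc=0.0, scale=1.0)     # quantile function of q

w2_sq_quantile = np.mean((qp - qq) ** 2)
w2_sq_closed = (1.0 - 0.0) ** 2 + (2.0 - 1.0) ** 2   # (m_p - m_q)^2 + (s_p - s_q)^2 for 1D Gaussians
print(w2_sq_quantile, w2_sq_closed)      # both close to 2
```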

Slide 20

Formulations: Linear energy

Denote $T = \nabla\Phi_p$ and $\nabla\Phi_q = x$. Consider a linear energy
$$\mathcal{V}(p) = \int_\Omega V(x)\,p(x)\,dx,$$
where the linear potential function $V\in C^\infty(\Omega)$ is strictly convex in $\mathbb{R}^d$. Then
$$D_{T,\mathcal{V}}(p\|q) = \int_\Omega D_V\big(\nabla_x\Phi_p(x)\,\|\,\nabla_x\Phi_q(x)\big)\,q(x)\,dx,$$
where $D_V\colon\Omega\times\Omega\to\mathbb{R}$ is the Euclidean Bregman divergence of $V$, defined by
$$D_V(z_1\|z_2) = V(z_1) - V(z_2) - \nabla V(z_2)\cdot(z_1-z_2)$$
for any $z_1, z_2\in\Omega$.
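In one dimension this formula can be evaluated through quantile functions, since $\nabla\Phi_q(x)=x$ and the change of variables $u=F_q(x)$ gives $D_{T,\mathcal{V}}(p\|q)=\int_0^1 D_V\big(F_p^{-1}(u)\,\|\,F_q^{-1}(u)\big)\,du$. Here is a sketch of mine with the arbitrarily chosen strictly convex potential $V(z)=e^z$ and two Gaussians.

```python
import numpy as np
from scipy.stats import norm

# Linear energy with a strictly convex potential, here V(z) = exp(z) (an arbitrary choice).
V   = lambda z: np.exp(z)
dV  = lambda z: np.exp(z)
D_V = lambda z1, z2: V(z1) - V(z2) - dV(z2) * (z1 - z2)   # Euclidean Bregman divergence of V

# In 1D, with T = F_p^{-1} o F_q and the change of variables u = F_q(x):
# D_{T,V}(p || q) = int_0^1 D_V(F_p^{-1}(u), F_q^{-1}(u)) du.
u = (np.arange(100000) + 0.5) / 100000
qp = norm.ppf(u, loc=0.5, scale=1.0)     # quantiles of p = N(0.5, 1)
qq = norm.ppf(u, loc=0.0, scale=1.0)     # quantiles of q = N(0, 1)
print(np.mean(D_V(qp, qq)))              # a nonnegative transport Bregman divergence
```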

Slide 21

Formulations: Interaction energy

Consider an interaction energy
$$\mathcal{W}(p) = \frac{1}{2}\int_\Omega\int_\Omega W(x,\tilde{x})\,p(x)p(\tilde{x})\,dx\,d\tilde{x},$$
where the interaction kernel potential satisfies $W(x,\tilde{x}) = W(\tilde{x},x)\in C^\infty(\Omega\times\Omega)$. Assume $W(x,\tilde{x}) = \tilde{W}(x-\tilde{x})$. Then
$$D_{T,\mathcal{W}}(p\|q) = \frac{1}{2}\int_\Omega\int_\Omega D_{\tilde{W}}\big(\nabla_x\Phi_p(x) - \nabla_{\tilde{x}}\Phi_p(\tilde{x})\,\|\,\nabla_x\Phi_q(x) - \nabla_{\tilde{x}}\Phi_q(\tilde{x})\big)\,q(x)q(\tilde{x})\,dx\,d\tilde{x},$$
where $D_{\tilde{W}}\colon\Omega\times\Omega\to\mathbb{R}$ is the Euclidean Bregman divergence of $\tilde{W}$, defined by
$$D_{\tilde{W}}(z_1\|z_2) = \tilde{W}(z_1) - \tilde{W}(z_2) - \nabla\tilde{W}(z_2)\cdot(z_1-z_2)$$
for any $z_1, z_2\in\Omega$.

Slide 22

Formulations: Negative entropy

Consider a negative entropy
$$\mathcal{U}(p) = \int_\Omega U(p(x))\,dx,$$
where the entropy potential $U\colon\mathbb{R}_+\to\mathbb{R}$ is twice differentiable and convex. Then
$$D_{T,\mathcal{U}}(p\|q) = \int_\Omega D_{\hat U}\big(\nabla^2_x\Phi_p(x)\,\|\,\nabla^2_x\Phi_q(x)\big)\,q(x)\,dx,$$
where $D_{\hat U}\colon\mathbb{R}^{d\times d}\times\mathbb{R}^{d\times d}\to\mathbb{R}$ is a matrix Bregman divergence. Denote $\hat U\colon\mathbb{R}_+\times\mathbb{R}^{d\times d}\to\mathbb{R}$ by
$$\hat U(q, A) = U\Big(\frac{q}{\det(A)}\Big)\,\frac{\det(A)}{q},$$
where $q\in\mathbb{R}_+$ is the given reference density, and
$$D_{\hat U}(A\|B) = \hat U(q,A) - \hat U(q,B) - \mathrm{tr}\big(\nabla_B\hat U(q,B)\cdot(A-B)\big)$$
for any $A, B\in\mathbb{R}^{d\times d}$, where $\nabla_B$ is the Fréchet derivative w.r.t. the symmetric matrix $B$.

Slide 23

Transport KL divergence

Definition. Define $D_{\mathrm{TKL}}\colon\mathcal{P}(\Omega)\times\mathcal{P}(\Omega)\to\mathbb{R}$ by
$$D_{\mathrm{TKL}}(p\|q) = \int_\Omega \Big[\Delta_x\Phi_p(x) - \log\det\big(\nabla^2_x\Phi_p(x)\big) - d\Big]\,q(x)\,dx,$$
where $\nabla_x\Phi_p$ is the diffeomorphic optimal transport map from $q$ to $p$, such that $(\nabla_x\Phi_p)_\# q = p$. We call $D_{\mathrm{TKL}}$ the transport KL divergence.

Slide 24

Transport + Bregman + Entropy = TKL

$$\begin{aligned}
D_{\mathrm{TKL}}(p\|q) &= \int_\Omega\Big[-\log\det\big(\nabla^2_x\Phi_p(x)\big) + \Delta_x\Phi_p(x) - d\Big]\,q(x)\,dx\\
&= \int_\Omega p(x)\log p(x)\,dx - \int_\Omega q(x)\log q(x)\,dx + \int_\Omega \Delta_x\Phi_p(x)\,q(x)\,dx - d\\
&= -H(p) + H_{T,q}(p),
\end{aligned}$$
i.e. an entropy term plus a transport cross entropy term, where we apply the fact that $(\nabla_x\Phi_p)_\# q = p$, i.e. $p(\nabla_x\Phi_p(x))\det\big(\nabla^2_x\Phi_p(x)\big) = q(x)$.

Slide 25

Why TKL?

Theorem. The transport KL divergence has the following properties.
(i) Nonnegativity: For any $p, q\in\mathcal{P}(\Omega)$, $D_{\mathrm{TKL}}(p\|q)\geq 0$.
(ii) Separability: The transport KL divergence is additive for independent distributions. Suppose $p(x,y) = p_1(x)p_2(y)$ and $q(x,y) = q_1(x)q_2(y)$. Then
$$D_{\mathrm{TKL}}(p\|q) = D_{\mathrm{TKL}}(p_1\|q_1) + D_{\mathrm{TKL}}(p_2\|q_2).$$
(iii) Transport Hessian information metric.
(iv) Transport convexity.

Slide 26

One dimension: TKL vs. KL divergence

Transport KL divergence:
$$D_{\mathrm{TKL}}(p\|q) := \int_0^1\Big[\frac{\nabla_x F_p^{-1}(x)}{\nabla_x F_q^{-1}(x)} - \log\frac{\nabla_x F_p^{-1}(x)}{\nabla_x F_q^{-1}(x)} - 1\Big]\,dx.$$

KL divergence:
$$D_{\mathrm{KL}}(p\|q) = \int_\Omega \nabla_x F_p(x)\,\log\frac{\nabla_x F_p(x)}{\nabla_x F_q(x)}\,dx.$$

Here $F_p = \int^x p(s)\,ds$ and $F_q = \int^x q(s)\,ds$ are the cumulative distribution functions of the probability densities $p$ and $q$, respectively.
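A numerical sketch of both one-dimensional formulas for two centered Gaussians (my own example; SciPy is used for the quantile and density functions). For $p=N(0,\sigma_p^2)$ and $q=N(0,\sigma_q^2)$ the ratio of quantile derivatives is the constant $\sigma_p/\sigma_q$, so $D_{\mathrm{TKL}} = \sigma_p/\sigma_q - \log(\sigma_p/\sigma_q) - 1$, which the grid computation should reproduce:

```python
import numpy as np
from scipy.stats import norm

sp, sq = 2.0, 1.0                      # p = N(0, sp^2), q = N(0, sq^2)
u = (np.arange(100000) + 0.5) / 100000

# d/du F_p^{-1}(u) = 1 / p(F_p^{-1}(u)); here the ratio of the two derivatives is the constant sp / sq.
dqp = 1.0 / norm.pdf(norm.ppf(u, scale=sp), scale=sp)
dqq = 1.0 / norm.pdf(norm.ppf(u, scale=sq), scale=sq)
ratio = dqp / dqq
tkl = np.mean(ratio - np.log(ratio) - 1.0)
print(tkl, sp / sq - np.log(sp / sq) - 1.0)              # both ~ 0.31

# KL divergence on a spatial grid: int p log(p/q) dx.
x = np.linspace(-12, 12, 20001)
p, q = norm.pdf(x, scale=sp), norm.pdf(x, scale=sq)
kl = np.trapz(p * np.log(p / q), x)
print(kl, np.log(sq / sp) + sp**2 / (2 * sq**2) - 0.5)   # closed form; both ~ 0.81
```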

Slide 27

Gaussian: TKL vs. KL divergence

Consider $p_X = N(0,\Sigma_X)$, $p_Y = N(0,\Sigma_Y)$.

Transport KL divergence:
$$D_{\mathrm{TKL}}(p_X\|p_Y) = \frac{1}{2}\log\frac{\det(\Sigma_Y)}{\det(\Sigma_X)} + \mathrm{tr}\Big(\Sigma_X^{\frac12}\big(\Sigma_X^{\frac12}\Sigma_Y\Sigma_X^{\frac12}\big)^{-\frac12}\Sigma_X^{\frac12}\Big) - d.$$

KL divergence:
$$D_{\mathrm{KL}}(p_X\|p_Y) = \frac{1}{2}\log\frac{\det(\Sigma_Y)}{\det(\Sigma_X)} + \frac{1}{2}\mathrm{tr}\big(\Sigma_X\Sigma_Y^{-1}\big) - \frac{d}{2}.$$
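Both closed forms are easy to evaluate in code (a sketch of mine; `scipy.linalg.sqrtm` is used for matrix square roots, and the two covariance matrices are arbitrary test inputs):

```python
import numpy as np
from scipy.linalg import sqrtm, inv

def tkl_gaussian(Sx, Sy):
    """Transport KL divergence between N(0, Sx) and N(0, Sy); the first argument plays the role of p_X."""
    d = Sx.shape[0]
    Sx_h = np.real(sqrtm(Sx))                                 # matrix square root (real for SPD input)
    A = Sx_h @ inv(np.real(sqrtm(Sx_h @ Sy @ Sx_h))) @ Sx_h   # OT map matrix sending p_Y to p_X
    return 0.5 * np.log(np.linalg.det(Sy) / np.linalg.det(Sx)) + np.trace(A) - d

def kl_gaussian(Sx, Sy):
    """KL divergence between the centered Gaussians N(0, Sx) and N(0, Sy)."""
    d = Sx.shape[0]
    return 0.5 * (np.log(np.linalg.det(Sy) / np.linalg.det(Sx))
                  + np.trace(Sx @ inv(Sy)) - d)

Sx = np.array([[2.0, 0.5], [0.5, 1.0]])
Sy = np.array([[1.0, 0.2], [0.2, 1.5]])
print(tkl_gaussian(Sx, Sy), kl_gaussian(Sx, Sy))   # both nonnegative, generally different values
```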

Slide 28

Transport Jensen–Shannon divergence

Definition. Define $D_{\mathrm{TJS}}\colon\mathcal{P}(\Omega)\times\mathcal{P}(\Omega)\to\mathbb{R}$ by
$$D_{\mathrm{TJS}}(p\|q) = \frac{1}{2}D_{\mathrm{TKL}}(p\|r) + \frac{1}{2}D_{\mathrm{TKL}}(q\|r),$$
where $r\in\mathcal{P}(\Omega)$ is the geodesic midpoint (barycenter) between $p$ and $q$ in the $L^2$-Wasserstein space, i.e.
$$r = \Big(\frac{1}{2}\big(\nabla_x\Phi_p + \nabla_x\Phi_q\big)\Big)_\# q.$$

Slide 29

One dimension: TJS vs. JS divergence

Transport Jensen–Shannon divergence:
$$D_{\mathrm{TJS}}(p\|q) = -\frac{1}{2}\int_0^1 \log\frac{\nabla_x F_p^{-1}(x)\cdot\nabla_x F_q^{-1}(x)}{\frac{1}{4}\big(\nabla_x F_p^{-1}(x) + \nabla_x F_q^{-1}(x)\big)^2}\,dx.$$

Jensen–Shannon divergence:
$$D_{\mathrm{JS}}(p\|q) = \frac{1}{2}\int_\Omega \nabla_x F_p(x)\log\frac{\nabla_x F_p(x)}{\frac{1}{2}\nabla_x F_p(x) + \frac{1}{2}\nabla_x F_q(x)}\,dx + \frac{1}{2}\int_\Omega \nabla_x F_q(x)\log\frac{\nabla_x F_q(x)}{\frac{1}{2}\nabla_x F_p(x) + \frac{1}{2}\nabla_x F_q(x)}\,dx.$$

Slide 30

Gaussian: TJS vs. JS divergence

Consider $p_X = N(0,\Sigma_X)$, $p_Y = N(0,\Sigma_Y)$.

Transport Jensen–Shannon divergence:
$$D_{\mathrm{TJS}}(p_X\|p_Y) = -\frac{1}{4}\log\frac{\det(\Sigma_X)\det(\Sigma_Y)}{\det(\Sigma_Z)^2} + \frac{1}{2}\mathrm{tr}\Big(\Sigma_X^{\frac12}\big(\Sigma_X^{\frac12}\Sigma_Z\Sigma_X^{\frac12}\big)^{-\frac12}\Sigma_X^{\frac12} + \Sigma_Y^{\frac12}\big(\Sigma_Y^{\frac12}\Sigma_Z\Sigma_Y^{\frac12}\big)^{-\frac12}\Sigma_Y^{\frac12}\Big) - d,$$
where $N(0,\Sigma_Z)$ is the $L^2$-Wasserstein geodesic midpoint (barycenter) of $p_X$ and $p_Y$.

Jensen–Shannon divergence: no closed-form solution.
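As a consistency check (my own sketch, under the assumption that $\Sigma_Z$ is the covariance of the Wasserstein midpoint as stated above), one can compute $\Sigma_Z$ from the optimal map between the two Gaussians and recover $D_{\mathrm{TJS}}$ as the average of the two transport KL divergences to $N(0,\Sigma_Z)$:

```python
import numpy as np
from scipy.linalg import sqrtm, inv

def ot_map(S_from, S_to):
    """Matrix A of the optimal transport map x -> A x from N(0, S_from) to N(0, S_to)."""
    r = sqrtm(S_from)
    A = inv(r) @ sqrtm(r @ S_to @ r) @ inv(r)
    return np.real(A)           # sqrtm can carry a negligible imaginary part

def tkl_gaussian(Sp, Sq):
    """Transport KL divergence D_TKL(N(0, Sp) || N(0, Sq))."""
    d = Sp.shape[0]
    A = ot_map(Sq, Sp)          # map from the second argument to the first
    return 0.5 * np.log(np.linalg.det(Sq) / np.linalg.det(Sp)) + np.trace(A) - d

Sx = np.array([[2.0, 0.5], [0.5, 1.0]])
Sy = np.array([[1.0, 0.2], [0.2, 1.5]])

# Wasserstein geodesic midpoint of N(0, Sx) and N(0, Sy): push N(0, Sy) forward by (I + A) / 2.
A = ot_map(Sy, Sx)
M = 0.5 * (np.eye(2) + A)
Sz = M @ Sy @ M

tjs = 0.5 * tkl_gaussian(Sx, Sz) + 0.5 * tkl_gaussian(Sy, Sz)
print(tjs)    # should match the closed-form D_TJS expression on this slide
```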

Slide 31

Discussion

- Design transport Bregman divergences for learning objectives;
- Design transport Bregman optimization algorithms.