Slide 1

Slide 1 text

Wasserstein Information Matrix
Wuchen Li, University of South Carolina / UCLA
This talk is based on joint work with Jiaxi Zhao.

Slide 2

Slide 2 text

Statistical distances

Slide 3

Slide 3 text

Information matrix
The information matrix (a.k.a. Fisher information matrix, Fisher–Rao metric) plays an important role in estimation, information science, statistics, and machine learning:
- Machine learning: natural gradient (Amari); ADAM (Kingma 2014); stochastic relaxation (Malagò, Pistone); and many more in the book Information Geometry (Ay, Nielsen, et al.).
- Statistics: likelihood principle; Cramér–Rao bound; sampling complexity; etc.

Slide 4

Slide 4 text

Statistical learning
Given a data measure $p_{data}(x) = \frac{1}{N}\sum_{i=1}^N \delta_{X_i}(x)$ and a parameterized model $p(x;\theta)$, machine learning problems often take the form
$$\min_{p_\theta\in\mathcal{P}_\Theta} D(p_{data}, p_\theta).$$
Here $D$ is a statistical distance function. One typical choice of $D$ is the Kullback–Leibler divergence (relative entropy)
$$D(p_{data}, p_\theta) = \int_\Omega p_{data}(x)\log\frac{p_{data}(x)}{p(x;\theta)}\,dx.$$
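For an empirical $p_{data}$, the KL objective equals $-\frac{1}{N}\sum_i \log p(X_i;\theta)$ plus a term that does not depend on $\theta$ (formally $\int p_{data}\log p_{data}$, which is not finite for Dirac masses), so minimizing it over $\theta$ amounts to maximum likelihood. A minimal Python sketch for an assumed Gaussian model; the data and the helper names are illustrative, not from the talk:

```python
import math

def gaussian_nll(data, mu, sigma):
    """Average negative log-likelihood -(1/N) sum_i log p(X_i; mu, sigma) for N(mu, sigma^2)."""
    return sum(0.5 * math.log(2 * math.pi * sigma ** 2) + (x - mu) ** 2 / (2 * sigma ** 2)
               for x in data) / len(data)

def gaussian_mle(data):
    """Closed-form minimizer of the average NLL: sample mean and (biased) sample std."""
    n = len(data)
    mu = sum(data) / n
    sigma = math.sqrt(sum((x - mu) ** 2 for x in data) / n)
    return mu, sigma

# Illustrative data; KL(p_data || p_theta) = gaussian_nll(...) + a theta-free constant.
data = [0.3, -1.2, 2.5, 0.9, 1.4, -0.4]
mu_hat, sigma_hat = gaussian_mle(data)
best = gaussian_nll(data, mu_hat, sigma_hat)
```

Any perturbation of `(mu_hat, sigma_hat)` increases the objective, as expected from the equivalence.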

Slide 5

Slide 5 text

Fisher information matrix
The Fisher information matrix satisfies
$$G_F(\theta)_{ij} = \mathbb{E}_{X\sim p_\theta}\Big[\frac{\partial}{\partial\theta_i}\log p(X;\theta)\,\frac{\partial}{\partial\theta_j}\log p(X;\theta)\Big],$$
where $\frac{\partial}{\partial\theta_k}\log p(X;\theta)$, $k = i, j$, is called the score function.
Applications of the Fisher information matrix:
- Estimation and efficiency: Cramér–Rao bound; Hessian curvature of the KL loss function;
- Pre-conditioners for KL-divergence-related learning problems;
- Natural gradient: parameterization-invariant optimization.
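A small numerical check of the definition, assuming a Gaussian model with $\theta = (\mu, \sigma)$: the scores come from central differences of $\log p$ in $\theta$ and the expectation from a trapezoid rule on a truncated interval (both are our numerical choices, not part of the talk). The result should be close to $\mathrm{diag}(1/\sigma^2, 2/\sigma^2)$.

```python
import math

def gaussian_logpdf(x, theta):
    mu, sigma = theta
    return -0.5 * math.log(2 * math.pi * sigma ** 2) - (x - mu) ** 2 / (2 * sigma ** 2)

def fisher_information(logpdf, theta, lo, hi, n=4001, h=1e-5):
    """G_F(theta)_ij = E[d_i log p * d_j log p]: scores by central differences in theta,
    expectation by the trapezoid rule on the truncated interval [lo, hi]."""
    d = len(theta)
    dx = (hi - lo) / (n - 1)
    G = [[0.0] * d for _ in range(d)]
    for k in range(n):
        x = lo + k * dx
        weight = dx * (0.5 if k in (0, n - 1) else 1.0) * math.exp(logpdf(x, theta))
        score = []
        for i in range(d):
            up, down = list(theta), list(theta)
            up[i] += h
            down[i] -= h
            score.append((logpdf(x, up) - logpdf(x, down)) / (2 * h))
        for i in range(d):
            for j in range(d):
                G[i][j] += weight * score[i] * score[j]
    return G

mu, sigma = 0.5, 2.0
G = fisher_information(gaussian_logpdf, (mu, sigma), mu - 12 * sigma, mu + 12 * sigma)
# Analytic value for comparison: diag(1 / sigma**2, 2 / sigma**2) = diag(0.25, 0.5)
```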

Slide 6

Slide 6 text

Optimal transport
In recent years, optimal transport (a.k.a. Earth mover's distance, Monge–Kantorovich problem, Wasserstein metric) has become deeply connected with statistics and machine learning:
- Theory (Brenier, Gangbo, McCann, Villani, Figalli, et al.); gradient flows (Otto, Villani, Carrillo, Jordan, Kinderlehrer, et al.);
- Image retrieval (Rubner et al. 2000);
- Computational optimal transport (Peyré, Cuturi, Solomon, Carrillo, Benamou, Osher, Li, et al.);
- Machine learning: Wasserstein training of Boltzmann machines (Cuturi et al. 2015); learning with a Wasserstein loss (Frogner et al. 2015); Wasserstein GAN (Bottou et al. 2017); deep learning (Gu, Yau, et al.);
- Bayesian sampling by Wasserstein dynamics (Bernton, Heng, Doucet, Jacob, Liu, Amir, Mehta, Liu et al., Ma et al., Li, Wang).

Slide 7

Slide 7 text

Wasserstein loss function
Given a data distribution $p_{data}(x) = \frac{1}{N}\sum_{i=1}^N \delta_{X_i}(x)$ and a probability model $p_\theta$, consider
$$\min_{\theta\in\Theta} W(p_{data}, p_\theta).$$
This is a double minimization problem, since
$$W(p_{data}, p_\theta) = \min_{\pi\in\Pi(p_{data}, p_\theta)} \mathbb{E}_{(X_0, X_1)\sim\pi}\, c(X_0, X_1).$$
Many applications, such as Wasserstein GAN and Wasserstein loss, are built on this formulation.
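In one dimension, with equally many, equally weighted atoms and $c(x, y) = |x - y|^2$, the inner minimization has a simple solution: the sorted (monotone) matching is optimal. A Python sketch under exactly those assumptions (for a continuous $p_\theta$ one would first draw samples from it), together with a brute-force reference over permutation couplings, which suffices here because an optimal plan between uniform measures with the same number of atoms can be taken to be a permutation:

```python
import itertools

def w2_squared_sorted(xs, ys):
    """Squared 2-Wasserstein distance between two uniform empirical measures on the real
    line with the same number of atoms; the monotone (sorted) matching is optimal there."""
    if len(xs) != len(ys):
        raise ValueError("this sketch assumes equal sample sizes")
    a, b = sorted(xs), sorted(ys)
    return sum((u - v) ** 2 for u, v in zip(a, b)) / len(a)

def w2_squared_bruteforce(xs, ys):
    """Reference value: minimize the transport cost over all permutations (tiny inputs only)."""
    n = len(xs)
    return min(sum((xs[i] - ys[s[i]]) ** 2 for i in range(n)) / n
               for s in itertools.permutations(range(n)))
```

For example, shifting three points by one unit, `w2_squared_sorted([0.0, 1.0, 2.0], [1.0, 2.0, 3.0])`, gives `1.0`.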

Slide 8

Slide 8 text

Goals
Main question: can we study the information matrices induced by optimal transport, and understand their properties in statistics and machine learning problems?
Related studies:
- Wasserstein covariance (Petersen, Müller);
- Wasserstein minimal distance estimators (Bernton, Jacob, Gerber, Robert, Blanchet);
- Statistical inference for generative models with maximum mean discrepancy (Briol, Barp, Duncan, Girolami);
- Joint studies of information geometry and optimal transport (Amari, Karakida, Oizumi, Takatsu, Malagò, Pistone, Wong, Yang, Modin, Chen, Tryphon, Sanctis);
- Wasserstein natural gradient (Li, Montufar, Chen, Lin, Abel, Gretton, et al.);
- Wasserstein statistics of the location-scale family (Amari, Li, Zhao).

Slide 9

Slide 9 text

Problem formulation
- Mapping formulation: Monge problem (1781), leading to the Monge–Ampère equation;
- Static formulation: Kantorovich problem (1940), a linear program;
- Dynamical formulation: density optimal control (Nelson, Carlen, Lafferty, Otto, Villani, et al.).
In this talk, we apply density optimal control to learning problems.

Slide 10

Slide 10 text

Density manifold
Optimal transport has an optimal control reformulation, known as the Benamou–Brenier formula:
$$W_2(p^0, p^1)^2 = \inf_{p_t}\int_0^1 g_W(\partial_t p_t, \partial_t p_t)\,dt = \inf_{p_t, \Phi_t}\int_0^1\int_\Omega (\nabla\Phi_t, \nabla\Phi_t)\,p_t\,dx\,dt,$$
under the dynamical constraint given by the continuity equation
$$\partial_t p_t + \nabla\cdot(p_t\nabla\Phi_t) = 0, \qquad p_{t=0} = p^0, \quad p_{t=1} = p^1.$$
Here $(\mathcal{P}(\Omega), g_W)$ forms an infinite-dimensional Riemannian manifold [1].
[1] John D. Lafferty, The density manifold and configuration space quantization, 1988.
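A worked check of the formula restricted to one-dimensional Gaussian paths $t \mapsto (\mu(t), \sigma(t))$. Along such a path the affine field $\nabla\Phi_t(x) = \mu' + \sigma'(x - \mu)/\sigma$ solves the continuity equation and $\int |\nabla\Phi_t|^2 p_t\,dx = \mu'^2 + \sigma'^2$; we take these facts as the sketch's assumptions. The linear path should attain the closed-form value $W_2^2 = (\mu_1 - \mu_0)^2 + (\sigma_1 - \sigma_0)^2$, while another path (here geometric in $\sigma$, our choice) costs more:

```python
def kinetic_action(path, n=2000, h=1e-6):
    """Midpoint approximation of int_0^1 (mu'(t)^2 + sigma'(t)^2) dt for a path
    t -> (mu(t), sigma(t)) of 1-D Gaussians; derivatives by central differences."""
    total = 0.0
    for k in range(n):
        t = (k + 0.5) / n
        (m_up, s_up), (m_dn, s_dn) = path(t + h), path(t - h)
        dm, ds = (m_up - m_dn) / (2 * h), (s_up - s_dn) / (2 * h)
        total += (dm * dm + ds * ds) / n
    return total

mu0, s0, mu1, s1 = 0.0, 1.0, 3.0, 2.0
w2_sq = (mu1 - mu0) ** 2 + (s1 - s0) ** 2  # closed-form W_2^2 between the endpoints

def linear(t):
    """Displacement (McCann) interpolation: linear in (mu, sigma)."""
    return mu0 + t * (mu1 - mu0), s0 + t * (s1 - s0)

def geometric(t):
    """A non-optimal competitor: same mean path, geometric interpolation of sigma."""
    return mu0 + t * (mu1 - mu0), s0 ** (1 - t) * s1 ** t
```

Here `kinetic_action(linear)` is about 10, while `kinetic_action(geometric)` is about $9 + 1.5\ln 2 \approx 10.04$.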

Slide 11

Slide 11 text

Transport information statistics

Slide 12

Slide 12 text

Information matrix

Slide 13

Slide 13 text

Statistical information matrix
Definition (Statistical information matrix). Consider the density manifold $(\mathcal{P}(X), g)$ with a metric tensor $g$, and a smoothly parametrized statistical model $p_\theta$ with parameter $\theta\in\Theta\subset\mathbb{R}^d$. Then the pull-back $G$ of $g$ onto the parameter space $\Theta$ is given by
$$G(\theta) = \big\langle \nabla_\theta p_\theta,\ g(p_\theta)\nabla_\theta p_\theta\big\rangle.$$
Denote $G(\theta) = (G(\theta)_{ij})_{1\le i,j\le d}$; then
$$G(\theta)_{ij} = \int_X \frac{\partial}{\partial\theta_i}p(x;\theta)\Big(g(p_\theta)\frac{\partial}{\partial\theta_j}p\Big)(x;\theta)\,dx.$$
Here we call $g$ the statistical metric and $G$ the statistical information matrix.

Slide 14

Slide 14 text

Statistical information matrix
Definition (Score function). Denote by $\Phi_i\colon X\times\Theta\to\mathbb{R}$, $i = 1, \dots, d$, the functions satisfying
$$\Phi_i(x;\theta) = \Big(g(p_\theta)\frac{\partial}{\partial\theta_i}p\Big)(x;\theta).$$
They are the score functions associated with the statistical information matrix $G$, and they are equivalence classes in $C(X)/\mathbb{R}$. The representatives of the equivalence classes are fixed by the normalization condition
$$\mathbb{E}_{x\sim p_\theta}\,\Phi_i(x;\theta) = 0, \quad i = 1, \dots, d.$$
Then the statistical information matrix satisfies
$$G(\theta)_{ij} = \int_X \Phi_i(x;\theta)\big(g(p_\theta)^{-1}\Phi_j\big)(x;\theta)\,dx.$$

Slide 15

Slide 15 text

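Examples: Fisher information matrix
Consider $g_F(p)^{-1} = p$ (multiplication by $p$). Then
$$G_F(\theta)_{ij} = \mathbb{E}_{X\sim p_\theta}\big[\Phi_i(X;\theta)\,\Phi_j(X;\theta)\big], \quad\text{where } \Phi_k(X;\theta) = \frac{1}{p(X;\theta)}\frac{\partial}{\partial\theta_k}p(X;\theta), \ k = i, j.$$
Since $\frac{1}{p(X;\theta)}\frac{\partial}{\partial\theta_k}p(X;\theta) = \frac{\partial}{\partial\theta_k}\log p(X;\theta)$, we obtain
$$G_F(\theta)_{ij} = \mathbb{E}_{X\sim p_\theta}\Big[\frac{\partial}{\partial\theta_i}\log p(X;\theta)\,\frac{\partial}{\partial\theta_j}\log p(X;\theta)\Big].$$
In the literature, $G_F(\theta)$ is known as the Fisher information matrix and $\nabla_\theta\log p(X;\theta)$ is called the (Fisher) score function.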

Slide 16

Slide 16 text

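Examples: Wasserstein information matrix
Consider $g_W(p)^{-1} = -\nabla\cdot(p\nabla)$. Then
$$G_W(\theta)_{ij} = \mathbb{E}_{X\sim p_\theta}\Big[\big(\nabla_X\Phi^W_i(X;\theta),\ \nabla_X\Phi^W_j(X;\theta)\big)\Big],$$
where
$$-\nabla_x\cdot\big(p(x;\theta)\nabla_x\Phi^W_k(x;\theta)\big) = \frac{\partial}{\partial\theta_k}p(x;\theta), \quad k = i, j.$$
Here we call $G_W(\theta)$ the Wasserstein information matrix (WIM) and $\Phi^W$ the Wasserstein score function.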

Slide 17

Slide 17 text

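Distance and information matrix
Specifically, given a smooth family of probability densities $p(x;\theta)$ and a perturbation $\Delta\theta\in T_\theta\Theta$, consider the following Taylor expansions in $\Delta\theta$:
$$\mathrm{KL}(p_\theta\,\|\,p_{\theta+\Delta\theta}) = \frac{1}{2}\Delta\theta^T G_F(\theta)\Delta\theta + o(|\Delta\theta|^2),$$
and
$$W_2(p_{\theta+\Delta\theta}, p_\theta)^2 = \Delta\theta^T G_W(\theta)\Delta\theta + o(|\Delta\theta|^2).$$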
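For the Gaussian family both expansions can be checked in closed form, taking $G_F = \mathrm{diag}(1/\sigma^2, 2/\sigma^2)$ and $G_W = I$ in $\theta = (\mu, \sigma)$ (the values listed later in the talk). In this Python sketch the ratios of the exact divergences to their quadratic models should approach 1 as the perturbation shrinks; for this family the Wasserstein ratio is exactly 1 up to rounding.

```python
import math

def kl_gauss(m0, s0, m1, s1):
    """Closed-form KL(N(m0, s0^2) || N(m1, s1^2))."""
    return math.log(s1 / s0) + (s0 ** 2 + (m0 - m1) ** 2) / (2 * s1 ** 2) - 0.5

def w2_sq_gauss(m0, s0, m1, s1):
    """Closed-form squared 2-Wasserstein distance between 1-D Gaussians."""
    return (m0 - m1) ** 2 + (s0 - s1) ** 2

def expansion_ratios(mu, sigma, dmu, ds):
    """Exact divergence divided by its quadratic model, for KL and for W_2^2."""
    fisher_quad = 0.5 * (dmu ** 2 / sigma ** 2 + 2 * ds ** 2 / sigma ** 2)  # (1/2) dtheta^T G_F dtheta
    wim_quad = dmu ** 2 + ds ** 2                                           # dtheta^T G_W dtheta, G_W = I
    return (kl_gauss(mu, sigma, mu + dmu, sigma + ds) / fisher_quad,
            w2_sq_gauss(mu + dmu, sigma + ds, mu, sigma) / wim_quad)

for eps in (1e-1, 1e-2, 1e-3):
    print(eps, expansion_ratios(1.0, 0.7, 0.6 * eps, -0.8 * eps))
```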

Slide 18

Slide 18 text

Poisson equation
The Wasserstein score functions $\Phi^W_i(x;\theta)$ satisfy the Poisson equations
$$\nabla_x\log p(x;\theta)\cdot\nabla_x\Phi^W_i(x;\theta) + \Delta_x\Phi^W_i(x;\theta) = -\frac{\partial}{\partial\theta_i}\log p(x;\theta).$$
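A finite-difference check of the Poisson equation for the Gaussian family, assuming the convention $-\nabla\cdot(p\nabla\Phi^W_k) = \partial_{\theta_k}p$, under which the normalized scores are $\Phi^W_\mu = x - \mu$ and $\Phi^W_\sigma = ((x-\mu)^2 - \sigma^2)/(2\sigma)$. These closed forms are our own derivation, not taken from the slides; the residual computed below should vanish up to discretization error.

```python
import math

MU, SIGMA = 0.4, 1.5

def log_p(x, mu=MU, sigma=SIGMA):
    return -0.5 * math.log(2 * math.pi * sigma ** 2) - (x - mu) ** 2 / (2 * sigma ** 2)

# Candidate Wasserstein scores for theta = (mu, sigma), normalized so that E[Phi] = 0.
def phi_mu(x):
    return x - MU

def phi_sigma(x):
    return (x - MU) ** 2 / (2 * SIGMA) - SIGMA / 2

def poisson_residual(phi, param, x, h=1e-4):
    """grad log p . grad Phi + Laplacian Phi + d/dtheta log p, all by finite differences;
    zero when Phi solves the Poisson equation for the parameter `param`."""
    dlogp_dx = (log_p(x + h) - log_p(x - h)) / (2 * h)
    dphi = (phi(x + h) - phi(x - h)) / (2 * h)
    d2phi = (phi(x + h) - 2 * phi(x) + phi(x - h)) / h ** 2
    if param == "mu":
        dlogp_dtheta = (log_p(x, mu=MU + h) - log_p(x, mu=MU - h)) / (2 * h)
    else:
        dlogp_dtheta = (log_p(x, sigma=SIGMA + h) - log_p(x, sigma=SIGMA - h)) / (2 * h)
    return dlogp_dx * dphi + d2phi + dlogp_dtheta
```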

Slide 19

Slide 19 text

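Separability
If $p(x;\theta)$ is an independence model, i.e.
$$p(x;\theta) = \prod_{k=1}^n p_k(x_k;\theta), \quad x = (x_1, \dots, x_n)\in\mathbb{R}^n,$$
then there exists a set of one-dimensional functions $\Phi^{W,k}\colon X_k\times\Theta\to\mathbb{R}$ such that
$$\Phi^W(x;\theta) = \sum_{k=1}^n \Phi^{W,k}(x_k;\theta).$$
In addition, the Wasserstein information matrix is separable:
$$G_W(\theta) = \sum_{k=1}^n G^k_W(\theta), \quad\text{where } G^k_W(\theta) = \mathbb{E}_{x_k\sim p_k}\Big[\big(\nabla_{x_k}\Phi^{W,k}(x_k;\theta),\ \nabla_{x_k}\Phi^{W,k}(x_k;\theta)\big)\Big].$$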
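As a worked instance (an illustrative choice of ours, not from the slides), take independent Gaussian coordinates whose means and standard deviations depend on $\theta$. Assuming the convention $-\nabla\cdot(p\nabla\Phi) = \partial_\theta p$, the one-dimensional Gaussian scores are $x_k - \mu_k$ and $((x_k-\mu_k)^2 - \sigma_k^2)/(2\sigma_k)$, and the chain rule gives:

```latex
\begin{aligned}
&p(x;\theta)=\prod_{k=1}^n \mathcal N\big(x_k;\mu_k(\theta),\sigma_k(\theta)^2\big),\\
&\Phi^{W,k}_i(x_k;\theta)=\partial_{\theta_i}\mu_k\,(x_k-\mu_k)
  +\partial_{\theta_i}\sigma_k\,\frac{(x_k-\mu_k)^2-\sigma_k^2}{2\sigma_k},\\
&\partial_{x_k}\Phi^{W,k}_i=\partial_{\theta_i}\mu_k+\partial_{\theta_i}\sigma_k\,\frac{x_k-\mu_k}{\sigma_k},\\
&G^k_W(\theta)_{ij}=\partial_{\theta_i}\mu_k\,\partial_{\theta_j}\mu_k+\partial_{\theta_i}\sigma_k\,\partial_{\theta_j}\sigma_k,
\qquad
G_W(\theta)=\sum_{k=1}^n\big(\nabla_\theta\mu_k\nabla_\theta\mu_k^T+\nabla_\theta\sigma_k\nabla_\theta\sigma_k^T\big).
\end{aligned}
```

The cross terms vanish because $\mathbb{E}[x_k - \mu_k] = 0$ and $\mathbb{E}[(x_k-\mu_k)^2] = \sigma_k^2$.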

Slide 20

Slide 20 text

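One-dimensional sample space
If $X\subset\mathbb{R}^1$, the Wasserstein score functions satisfy
$$\Phi^W_i(x;\theta) = -\int_{-\infty}^x \frac{1}{p(z;\theta)}\frac{\partial}{\partial\theta_i}F(z;\theta)\,dz,$$
where $F(x;\theta) = \int_{-\infty}^x p(y;\theta)\,dy$ is the cumulative distribution function. The Wasserstein information matrix [2] satisfies
$$G_W(\theta)_{ij} = \mathbb{E}_{X\sim p_\theta}\Big[\frac{\frac{\partial}{\partial\theta_i}F(X;\theta)\,\frac{\partial}{\partial\theta_j}F(X;\theta)}{p(X;\theta)^2}\Big].$$
[2] Chen, Li, Wasserstein natural gradient in continuous sample space, 2018.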
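A Python sketch that evaluates the one-dimensional formula directly, in the form $G_W(\theta)_{ij} = \int \partial_{\theta_i}F\,\partial_{\theta_j}F / p\,dx$, with $\partial_\theta F$ from central differences and a trapezoid rule on a truncated interval (our numerical choices). For the Gaussian family $(\mu, \sigma)$ and the Laplacian family $(m, \lambda)$ with $p = \frac{\lambda}{2}e^{-\lambda|x-m|}$, it should return approximately the identity and $\mathrm{diag}(1, 2/\lambda^4)$, respectively.

```python
import math

def wim_1d(cdf, pdf, theta, lo, hi, n=20001, h=1e-5):
    """G_W(theta)_ij = int d_i F * d_j F / p dx on [lo, hi]: d_i F by central
    differences in theta, the integral by the trapezoid rule."""
    d = len(theta)
    dx = (hi - lo) / (n - 1)
    G = [[0.0] * d for _ in range(d)]
    for k in range(n):
        x = lo + k * dx
        p = pdf(x, theta)
        if p <= 0.0:
            continue  # skip points outside the support (this sketch assumes d_i F = 0 there)
        weight = dx * (0.5 if k in (0, n - 1) else 1.0) / p
        dF = []
        for i in range(d):
            up, down = list(theta), list(theta)
            up[i] += h
            down[i] -= h
            dF.append((cdf(x, up) - cdf(x, down)) / (2 * h))
        for i in range(d):
            for j in range(d):
                G[i][j] += weight * dF[i] * dF[j]
    return G

def gauss_pdf(x, th):
    m, s = th
    return math.exp(-(x - m) ** 2 / (2 * s * s)) / (s * math.sqrt(2 * math.pi))

def gauss_cdf(x, th):
    m, s = th
    return 0.5 * (1 + math.erf((x - m) / (s * math.sqrt(2))))

def laplace_pdf(x, th):
    m, lam = th
    return 0.5 * lam * math.exp(-lam * abs(x - m))

def laplace_cdf(x, th):
    m, lam = th
    if x < m:
        return 0.5 * math.exp(lam * (x - m))
    return 1 - 0.5 * math.exp(-lam * (x - m))

G_gauss = wim_1d(gauss_cdf, gauss_pdf, (0.5, 2.0), 0.5 - 16.0, 0.5 + 16.0)  # about the identity
G_lap = wim_1d(laplace_cdf, laplace_pdf, (0.0, 2.0), -10.0, 10.0)          # about diag(1, 0.125)
```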

Slide 21

Slide 21 text

Remark
The Wasserstein score function is an average of the cumulative Fisher score function $\partial_{\theta_i}F(x;\theta) = \int_{-\infty}^x \partial_{\theta_i}\log p(y;\theta)\,p(y;\theta)\,dy$, and the Wasserstein information matrix is the covariance of the density-averaged cumulative Fisher score function.

Slide 22

Slide 22 text

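Analytic examples: location-scale family
- Gaussian family: $p(x;\mu,\sigma) = \frac{1}{\sqrt{2\pi}\,\sigma}e^{-\frac{(x-\mu)^2}{2\sigma^2}}$, $G_W(\mu,\sigma) = \begin{pmatrix}1&0\\0&1\end{pmatrix}$.
- Laplacian family: $p(x;m,\lambda) = \frac{\lambda}{2}e^{-\lambda|x-m|}$, $G_W(m,\lambda) = \begin{pmatrix}1&0\\0&\frac{2}{\lambda^4}\end{pmatrix}$.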

Slide 23

Slide 23 text

Analytic examples: mixed family
Consider the mixed family
$$p(x;\theta) = \sum_{i=1}^N \theta_i p_i(x), \quad \sum_{i=1}^N \theta_i = 1, \quad \theta_i\ge 0.$$
The WIM satisfies
$$G_W(\theta)_{ij} = \mathbb{E}_{x\sim p_\theta}\Big[\frac{(F_{i+1}(x) - F_i(x))(F_{j+1}(x) - F_j(x))}{p(x;\theta)^2}\Big],$$
where $F_i(x)$ is the cumulative distribution function of the density $p_i$.

Slide 24

Slide 24 text

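Generative adversarial networks
Consider a class of invertible push-forward maps $\{f_\theta\}_{\theta\in\Theta}$, indexed by a parameter $\theta\in\Theta\subset\mathbb{R}^m$,
$$f_\theta\colon\mathbb{R}^d\to\mathbb{R}^d,$$
where the push-forward distribution $f_{\theta\#}p_0$ is defined by
$$\int_A f_{\theta\#}p_0\,dx = \int_{f_\theta^{-1}(A)} p_0\,dx$$
for every measurable set $A$.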

Slide 25

Slide 25 text

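Neural Wasserstein information matrix
Denote a family of parametric distributions [3, 4]
$$\mathcal{P}_\Theta = \big\{p_\theta = f_{\theta\#}p_0 \mid \theta\in\Theta\big\}.$$
In this case, the WIM is
$$G_W(\theta)_{ij} = \int_X \nabla\Phi_i(f_\theta(x))\cdot\nabla\Phi_j(f_\theta(x))\,p_0(x)\,dx,$$
where
$$\nabla\cdot\big(p_\theta\nabla\Phi_k(x)\big) = \nabla\cdot\big(p_\theta\,\partial_{\theta_k}f_\theta(f_\theta^{-1}(x))\big).$$
[3] Lin, Li, Osher, Montufar, Wasserstein proximal of GANs, 2018.
[4] Liu, Li, Zha, Zhou, Neural Fokker–Planck equations, 2020.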

Slide 26

Slide 26 text

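Analytic examples: WIM in generative models
For a continuous one-dimensional generative family $p(\cdot;\theta) = f_{\theta\#}p_0(\cdot)$, where $p_0$ is a given distribution,
$$\frac{\partial}{\partial x}\Phi_i(x;\theta) = \frac{\partial}{\partial\theta_i}f(z;\theta), \quad\text{with } x = f(z;\theta),$$
and
$$G_W(\theta)_{ij} = \int_{\mathbb{R}^1}\frac{\partial}{\partial\theta_i}f(z;\theta)\,\frac{\partial}{\partial\theta_j}f(z;\theta)\,p_0(z)\,dz.$$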
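The generative formula only needs samples of $z$ and derivatives of $f$ in $\theta$. A Python sketch that averages over midpoint quantiles of $p_0$ (a deterministic quadrature we chose for reproducibility), applied to three assumed parameterizations: Gaussian $\mu + \sigma z$, shifted exponential $m + z/\lambda$ with $z\sim\mathrm{Exp}(1)$, and uniform $a + (b-a)z$ with $z\sim U(0,1)$. Analytically these give $I$, $\begin{pmatrix}1&-1/\lambda^2\\-1/\lambda^2&2/\lambda^4\end{pmatrix}$ and $\frac{1}{3}\begin{pmatrix}1&1/2\\1/2&1\end{pmatrix}$.

```python
import math
from statistics import NormalDist

def wim_generative(base_quantile, f, theta, n=100000, h=1e-6):
    """G_W(theta)_ij = int d_i f(z) d_j f(z) p0(z) dz, approximated by averaging over
    z = base_quantile(u_k) on the midpoint grid u_k = (k + 1/2)/n of (0, 1);
    d_i f by central differences in theta."""
    d = len(theta)
    G = [[0.0] * d for _ in range(d)]
    for k in range(n):
        z = base_quantile((k + 0.5) / n)
        grad = []
        for i in range(d):
            up, down = list(theta), list(theta)
            up[i] += h
            down[i] -= h
            grad.append((f(z, up) - f(z, down)) / (2 * h))
        for i in range(d):
            for j in range(d):
                G[i][j] += grad[i] * grad[j] / n
    return G

def gaussian_map(z, t):     # N(mu, sigma^2) as the law of mu + sigma * z, z ~ N(0, 1)
    return t[0] + t[1] * z

def exponential_map(z, t):  # rate-lambda exponential shifted to m: m + z / lambda, z ~ Exp(1)
    return t[0] + z / t[1]

def uniform_map(z, t):      # U(a, b) as the law of a + (b - a) * z, z ~ U(0, 1)
    return t[0] + (t[1] - t[0]) * z

G_gauss = wim_generative(NormalDist().inv_cdf, gaussian_map, (0.5, 2.0))
G_exp = wim_generative(lambda u: -math.log1p(-u), exponential_map, (1.0, 2.0))
G_unif = wim_generative(lambda u: u, uniform_map, (0.0, 3.0))
```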

Slide 27

Slide 27 text

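Analytic examples
Consider generative models with ReLU families:
$$f_\theta(z) = \sigma(z - \theta) = \begin{cases}0, & z\le\theta,\\ z - \theta, & z > \theta.\end{cases}$$
Since $\partial_\theta f_\theta(z) = -\mathbf{1}_{\{z>\theta\}}$, the one-dimensional generative formula gives $G_W(\theta) = \mathbb{P}_{z\sim p_0}(z > \theta) = 1 - F_0(\theta)$, where $F_0$ is the cumulative distribution function of $p_0$.
Figure: two examples of the push-forward family, with $\theta_1 = 3$ and $\theta_2 = 5$.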
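A numerical check of $G_W(\theta)$ for this ReLU family through the one-dimensional generative formula, assuming $p_0 = \mathcal N(0,1)$; the estimate of $\mathbb{E}[(\partial_\theta f_\theta(z))^2]$ should match $\mathbb{P}(z > \theta)$. Midpoint quantiles and the central difference in $\theta$ are our numerical choices.

```python
from statistics import NormalDist

def relu_wim(theta, n=100000, h=1e-4):
    """G_W(theta) = E_{z ~ N(0,1)}[(d/dtheta f_theta(z))^2] for f_theta(z) = max(z - theta, 0),
    using midpoint quantiles of N(0, 1) and a central difference in theta."""
    base = NormalDist()
    total = 0.0
    for k in range(n):
        z = base.inv_cdf((k + 0.5) / n)
        df = (max(z - (theta + h), 0.0) - max(z - (theta - h), 0.0)) / (2 * h)
        total += df * df
    return total / n

theta = 0.5
estimate = relu_wim(theta)
exact = 1 - NormalDist().cdf(theta)  # P(z > theta)
```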

Slide 28

Slide 28 text

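Statistical information matrix
- Uniform: $p(x;a,b) = \frac{1}{b-a}\mathbf{1}_{(a,b)}(x)$; $G_W(a,b) = \frac{1}{3}\begin{pmatrix}1&\frac12\\\frac12&1\end{pmatrix}$; $G_F(a,b)$ not well-defined.
- Gaussian: $p(x;\mu,\sigma) = \frac{1}{\sqrt{2\pi}\,\sigma}e^{-\frac{(x-\mu)^2}{2\sigma^2}}$; $G_W(\mu,\sigma) = \begin{pmatrix}1&0\\0&1\end{pmatrix}$; $G_F(\mu,\sigma) = \begin{pmatrix}\frac{1}{\sigma^2}&0\\0&\frac{2}{\sigma^2}\end{pmatrix}$.
- Exponential: $p(x;m,\lambda) = \lambda e^{-\lambda(x-m)}$, $x\ge m$; $G_W(m,\lambda) = \begin{pmatrix}1&-\frac{1}{\lambda^2}\\-\frac{1}{\lambda^2}&\frac{2}{\lambda^4}\end{pmatrix}$; $G_F(m,\lambda)$ not well-defined.
- Laplacian: $p(x;m,\lambda) = \frac{\lambda}{2}e^{-\lambda|x-m|}$; $G_W(m,\lambda) = \begin{pmatrix}1&0\\0&\frac{2}{\lambda^4}\end{pmatrix}$; $G_F(m,\lambda) = \begin{pmatrix}\lambda^2&0\\0&\frac{1}{\lambda^2}\end{pmatrix}$.
- Location-scale: $p(x;\lambda,m) = \frac{1}{\lambda}p\big(\frac{x-m}{\lambda}\big)$; $G_W(\lambda,m) = \begin{pmatrix}\frac{\mathbb{E}_{\lambda,m}x^2 - 2m\,\mathbb{E}_{\lambda,m}x + m^2}{\lambda^2}&0\\0&1\end{pmatrix}$ (for a standard density $p$ with mean zero); $G_F(\lambda,m) = \frac{1}{\lambda^2}\begin{pmatrix}\int_{\mathbb{R}}\frac{y^2p'(y)^2}{p(y)}dy - 1&\int_{\mathbb{R}}\frac{y\,p'(y)^2}{p(y)}dy\\\int_{\mathbb{R}}\frac{y\,p'(y)^2}{p(y)}dy&\int_{\mathbb{R}}\frac{p'(y)^2}{p(y)}dy\end{pmatrix}$, with $y = \frac{x-m}{\lambda}$.
- Independent: $p(x,y;\theta) = p(x;\theta)p(y;\theta)$; $G_W(\theta) = G^1_W(\theta) + G^2_W(\theta)$; $G_F(\theta) = G^1_F(\theta) + G^2_F(\theta)$.
- ReLU push-forward: $p(x;\theta) = f_{\theta\#}p_0(x)$, with $f_\theta$ a $\theta$-parameterized ReLU; $G_W(\theta) = 1 - F_0(\theta)$, where $F_0$ is the cumulative distribution function of $p_0$; $G_F(\theta)$ not well-defined.
Table: In this table, we present Wasserstein and Fisher information matrices for various probability families.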

Slide 29

Slide 29 text

Application of WIM
Recently, we have applied the WIM in learning optimization and computational fluid dynamics:
- Wasserstein natural gradient;
- Machine learning methods for Wasserstein Hamiltonian flows, mean field games, etc.
Figures: computation of OT by variational neural ODEs (Lars et al., 2020); computation of neural Fokker–Planck equations (Liu et al., 2020).

Slide 30

Slide 30 text

Wasserstein statistics
Today, we present the statistical theory of the WIM:
- Estimation: Wasserstein–Cramér–Rao bound;
- Information inequalities: Ricci curvature in parametric statistics;
- Efficiency: Wasserstein online efficiency.
Here we develop Wasserstein statistics following the classical (Fisher) statistical approach.

Slide 31

Slide 31 text

Wasserstein covariance
Definition. Given a statistical model $\Theta$, define the Wasserstein covariance as
$$\mathrm{Cov}^W_\theta[T_1, T_2] = \mathbb{E}_{p_\theta}\big[(\nabla_x T_1(x), \nabla_x T_2(x))^T\big],$$
where $T_1, T_2$ are random variables as functions of $x$ and the expectation is taken w.r.t. $x\sim p_\theta$. Denote the Wasserstein variance by
$$(\mathrm{Var}^W_\theta[T])_{ij} = \mathrm{Cov}^W_\theta[T_i, T_j] = \mathbb{E}_{p_\theta}\big[(\nabla_x T_i(x), \nabla_x T_j(x))\big].$$
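A one-dimensional Python sketch of the definition, assuming the statistics are supplied through their $x$-derivatives and $p_\theta$ is Gaussian; the trapezoid quadrature is our choice. For $T = (x, x^2)$ under $\mathcal N(\mu, \sigma^2)$ it should return $\begin{pmatrix}1&2\mu\\2\mu&4(\mu^2+\sigma^2)\end{pmatrix}$. Note that $\mathrm{Var}^W[x] = 1$ for every distribution, since $\nabla_x x = 1$.

```python
import math

def wasserstein_cov(grads, pdf, lo, hi, n=8001):
    """Cov^W[T_a, T_b] = E[T_a'(x) T_b'(x)] for 1-D statistics given by their
    x-derivatives `grads`, via the trapezoid rule on [lo, hi]."""
    m = len(grads)
    dx = (hi - lo) / (n - 1)
    C = [[0.0] * m for _ in range(m)]
    for k in range(n):
        x = lo + k * dx
        w = dx * (0.5 if k in (0, n - 1) else 1.0) * pdf(x)
        g = [grad(x) for grad in grads]
        for a in range(m):
            for b in range(m):
                C[a][b] += w * g[a] * g[b]
    return C

mu, sigma = 1.0, 0.5

def pdf(x):
    return math.exp(-(x - mu) ** 2 / (2 * sigma ** 2)) / (sigma * math.sqrt(2 * math.pi))

# T_1(x) = x and T_2(x) = x^2, passed through their derivatives 1 and 2x.
C = wasserstein_cov([lambda x: 1.0, lambda x: 2.0 * x], pdf, mu - 10 * sigma, mu + 10 * sigma)
# Analytic value: [[1, 2 mu], [2 mu, 4 (mu^2 + sigma^2)]] = [[1, 2], [2, 5]]
```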

Slide 32

Slide 32 text

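Wasserstein–Cramér–Rao bound
Theorem. Given any set of statistics $T = (T_1, \dots, T_n)\colon X\to\mathbb{R}^n$, where $n$ is the number of statistics, define the two matrices $\mathrm{Cov}^W_\theta[T(x)]$ and $\nabla_\theta\mathbb{E}_{p_\theta}[T(x)]^T$ by
$$\mathrm{Cov}^W_\theta[T(x)]_{ij} = \mathrm{Cov}^W_\theta[T_i, T_j], \qquad \big(\nabla_\theta\mathbb{E}_{p_\theta}[T(x)]^T\big)_{ij} = \frac{\partial}{\partial\theta_j}\mathbb{E}_{p_\theta}[T_i(x)].$$
Then
$$\mathrm{Cov}^W_\theta[T(x)] \succeq \nabla_\theta\mathbb{E}_{p_\theta}[T(x)]^T\, G_W(\theta)^{-1}\, \nabla_\theta\mathbb{E}_{p_\theta}[T(x)].$$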
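A numerical illustration for the Gaussian family $\theta = (\mu, \sigma)$, where $G_W = I$, with statistics $T(x) = (x, x^3)$ chosen by us. The sketch computes $\mathrm{Cov}^W_\theta[T]$ and the Jacobian of $\mathbb{E}_{p_\theta}[T]$ by quadrature and finite differences, and returns the gap between the two sides of the bound, which must be positive semidefinite. By hand the gap is $\mathrm{diag}(0, 18\sigma^4)$: the bound is attained for $T_1 = x$ and strict for $T_2 = x^3$.

```python
import math

def gauss_pdf(x, mu, s):
    return math.exp(-(x - mu) ** 2 / (2 * s * s)) / (s * math.sqrt(2 * math.pi))

def expect(g, mu, s, n=8001):
    """E_{N(mu, s^2)}[g(x)] by the trapezoid rule on [mu - 10 s, mu + 10 s]."""
    lo, hi = mu - 10 * s, mu + 10 * s
    dx = (hi - lo) / (n - 1)
    total = 0.0
    for k in range(n):
        x = lo + k * dx
        total += (0.5 if k in (0, n - 1) else 1.0) * g(x) * gauss_pdf(x, mu, s)
    return total * dx

def wcr_gap(mu, s, h=1e-5):
    """Cov^W[T] - J G_W^{-1} J^T for T(x) = (x, x^3) under N(mu, s^2), where
    J_ij = d E[T_i] / d theta_j with theta = (mu, s), and G_W(mu, s) = I."""
    T = [lambda x: x, lambda x: x ** 3]
    dT = [lambda x: 1.0, lambda x: 3.0 * x * x]
    cov = [[expect(lambda x, a=a, b=b: dT[a](x) * dT[b](x), mu, s) for b in range(2)]
           for a in range(2)]
    J = [[(expect(T[a], mu + h, s) - expect(T[a], mu - h, s)) / (2 * h),
          (expect(T[a], mu, s + h) - expect(T[a], mu, s - h)) / (2 * h)] for a in range(2)]
    bound = [[sum(J[a][k] * J[b][k] for k in range(2)) for b in range(2)] for a in range(2)]
    return [[cov[a][b] - bound[a][b] for b in range(2)] for a in range(2)]

gap = wcr_gap(0.7, 1.3)  # expected: approximately [[0, 0], [0, 18 * 1.3**4]]
```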

Slide 33

Slide 33 text

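Cramér–Rao bound: Fisher vs Wasserstein
- Gaussian: $G_W(\mu,\sigma) = \begin{pmatrix}1&0\\0&1\end{pmatrix}$, $G_F(\mu,\sigma) = \begin{pmatrix}\frac{1}{\sigma^2}&0\\0&\frac{2}{\sigma^2}\end{pmatrix}$.
- Exponential: $G_W(m,\lambda) = \begin{pmatrix}1&-\frac{1}{\lambda^2}\\-\frac{1}{\lambda^2}&\frac{2}{\lambda^4}\end{pmatrix}$, $G_F$ not well-defined.
- Comparison: $G_W$ is well-defined for a wide range of families.
- Tighter bound on the variance of an estimator.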

Slide 34

Slide 34 text

Wasserstein natural gradient
Given a loss function $F\colon\mathcal{P}(\Omega)\to\mathbb{R}$ and a probability model $p(\cdot;\theta)$, the associated gradient flow on a Riemannian manifold is defined by
$$\frac{d\theta}{dt} = -\nabla_g F(p(\cdot;\theta)).$$
Here $\nabla_g$ is the Riemannian gradient operator satisfying
$$g_\theta\big(\nabla_g F(p(\cdot;\theta)), \xi\big) = \nabla_\theta F(p(\cdot;\theta))\cdot\xi$$
for any tangent vector $\xi\in T_\theta\Theta$, where $\nabla_\theta$ denotes the Euclidean gradient.

Slide 35

Slide 35 text

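Wasserstein natural gradient
The gradient flow of the loss function $F(p(\cdot;\theta))$ in $(\Theta, G_W(\theta))$ satisfies
$$\frac{d\theta}{dt} = -G_W(\theta)^{-1}\nabla_\theta F(p(\cdot;\theta)).$$
If the model is the full probability space, i.e. $p(x;\theta) = p(x)$, we recover the Wasserstein gradient flow
$$\partial_t p = \nabla\cdot\Big(p\,\nabla\frac{\delta}{\delta p}F(p)\Big).$$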
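A Python sketch comparing the two natural-gradient flows on the Gaussian family $\theta = (\mu, \sigma)$, with the illustrative loss $F(p_\theta) = \mathrm{KL}(p_\theta\,\|\,\mathcal N(\mu_t, \sigma_t^2))$ (our choice, as are the helper names) and the metrics $G_W = I$ and $G_F = \mathrm{diag}(1/\sigma^2, 2/\sigma^2)$. It integrates $d\theta/dt = -G(\theta)^{-1}\nabla_\theta F$ with forward Euler; both runs should reach the target parameters, along different trajectories.

```python
def kl_grad(mu, s, mu_t, s_t):
    """Euclidean gradient of F(theta) = KL(N(mu, s^2) || N(mu_t, s_t^2)) w.r.t. (mu, s)."""
    return (mu - mu_t) / s_t ** 2, -1.0 / s + s / s_t ** 2

def wasserstein_inv(mu, s):
    return 1.0, 1.0              # G_W(mu, sigma)^{-1} = I

def fisher_inv(mu, s):
    return s * s, s * s / 2.0    # G_F(mu, sigma)^{-1} = diag(sigma^2, sigma^2 / 2)

def natural_gradient_flow(inv_metric, theta0, target, dt=1e-3, steps=20000):
    """Forward Euler for d theta/dt = -G(theta)^{-1} grad_theta F with a diagonal metric."""
    mu, s = theta0
    for _ in range(steps):
        gm, gs = kl_grad(mu, s, *target)
        pm, ps = inv_metric(mu, s)
        mu, s = mu - dt * pm * gm, s - dt * ps * gs
    return mu, s

target = (0.0, 1.0)
theta_w = natural_gradient_flow(wasserstein_inv, (3.0, 0.5), target)
theta_f = natural_gradient_flow(fisher_inv, (3.0, 0.5), target)
```

With $G_W = I$ the Wasserstein natural gradient coincides with the Euclidean gradient in $(\mu, \sigma)$, whereas the Fisher flow rescales each coordinate by $\sigma^2$.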

Slide 36

Slide 36 text

Information functional inequalities
Comparing the Fisher and Wasserstein information matrices relates to well-known information functional inequalities (Lott–Sturm–Villani). Here we study them in parametric statistics.
- Dissipation of entropy along the gradient flow:
$$\frac{d}{dt}H(p|p^*) = -\int_X\Big|\nabla_x\log\frac{p(x)}{p^*(x)}\Big|^2 p(x)\,dx = -I(p|p^*),$$
$$\frac{d}{dt}\tilde H(p_\theta|p_{\theta^*}) = -\nabla_\theta\tilde H^T G_W^{-1}\nabla_\theta\tilde H = -\tilde I(p_\theta|p_{\theta^*}).$$
- Log-Sobolev inequality (LSI):
$$H(p|p^*) \le \frac{1}{2\alpha}I(p|p^*), \quad p\in\mathcal{P}(X); \qquad \tilde H(p_\theta|p_{\theta^*}) \le \frac{1}{2\alpha}\tilde I(p_\theta|p_{\theta^*}), \quad \theta\in\Theta.$$
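A check of the parametric LSI for the Gaussian family with target $p_{\theta^*} = \mathcal N(0, 1)$, using $G_W = I$ so that $\tilde I = |\nabla_\theta\tilde H|^2$. The constant $\alpha = 1$ is an assumption we verify on a grid; by hand, $\tilde H - \tilde I/2 = \frac12(\log t - t + 1) \le 0$ with $t = 1/\sigma^2$. The sketch also compares the finite-difference $\tilde I$ with the closed form $\frac{(\mu-\mu^*)^2}{(\sigma^*)^4} + \big(\frac{\sigma}{(\sigma^*)^2} - \frac{1}{\sigma}\big)^2$ for a general Gaussian target.

```python
import math

def rel_entropy(mu, s, mu_t, s_t):
    """H~(p_theta | p_theta*) = KL(N(mu, s^2) || N(mu_t, s_t^2)) in closed form."""
    return math.log(s_t / s) + (s * s + (mu - mu_t) ** 2) / (2 * s_t ** 2) - 0.5

def param_fisher_info(mu, s, mu_t, s_t, h=1e-6):
    """I~ = grad H~^T G_W^{-1} grad H~ with G_W(mu, sigma) = I; gradient by central differences."""
    gm = (rel_entropy(mu + h, s, mu_t, s_t) - rel_entropy(mu - h, s, mu_t, s_t)) / (2 * h)
    gs = (rel_entropy(mu, s + h, mu_t, s_t) - rel_entropy(mu, s - h, mu_t, s_t)) / (2 * h)
    return gm * gm + gs * gs

# Largest value of H~ - I~ / (2 alpha) on a grid, for target N(0, 1) and assumed alpha = 1.
alpha = 1.0
worst = max(rel_entropy(m, s, 0.0, 1.0) - param_fisher_info(m, s, 0.0, 1.0) / (2 * alpha)
            for m in [i / 4 for i in range(-12, 13)]
            for s in [0.2 + j / 10 for j in range(30)])
```

The inequality is tight (up to rounding) along $\sigma = 1$, so `worst` should be about zero and never clearly positive.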

Slide 37

Slide 37 text

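Ricci curvature in parametric statistics
Theorem (RIW condition [5, 6]). The information matrix criterion for LSI can be written as
$$G_F(\theta) + \int_X \nabla^2_\theta p(x;\theta)\log\frac{p(x;\theta)}{p(x;\theta^*)}\,dx - \Gamma^W\cdot\nabla_\theta\tilde H(p_\theta|p_{\theta^*}) \succeq 2\alpha G_W(\theta),$$
where $\Gamma^W$ is the Christoffel symbol of the Wasserstein statistical model $\Theta$, while the criterion for PI (Poincaré inequality) can be written as
$$G_F(\theta) + \int_X \nabla^2_\theta p(x;\theta)\log\frac{p(x;\theta)}{p(x;\theta^*)}\,dx \succeq 2\alpha G_W(\theta).$$
[5] Li, Transport information geometry I, 2018.
[6] Li, Montufar, Ricci curvature for parameter statistics, 2018.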

Slide 38

Slide 38 text

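List of functional inequalities in families
- Gaussian. Fisher information functional:
$$\tilde I(p_{\mu,\sigma}|p^*) = \frac{(\mu-\mu^*)^2}{(\sigma^*)^4} + \Big(-\frac{1}{\sigma} + \frac{\sigma}{(\sigma^*)^2}\Big)^2;$$
log-Sobolev inequality LSI($\alpha$):
$$\frac{(\mu-\mu^*)^2}{(\sigma^*)^4} + \Big(-\frac{1}{\sigma} + \frac{\sigma}{(\sigma^*)^2}\Big)^2 \ge 2\alpha\Big(-\log\sigma + \log\sigma^* - \frac12 + \frac{\sigma^2 + (\mu-\mu^*)^2}{2(\sigma^*)^2}\Big).$$
- Laplacian. Fisher information functional:
$$\tilde I(p_{m,\lambda}|p^*) = (\lambda^*)^2\big(1 - e^{-\lambda|m-m^*|}\big)^2 + \frac12\Big(\lambda - (\lambda|m-m^*| + 1)\lambda^* e^{-\lambda|m-m^*|}\Big)^2;$$
log-Sobolev inequality LSI($\alpha$):
$$(\lambda^*)^2\big(1 - e^{-\lambda|m-m^*|}\big)^2 + \frac12\Big(\lambda - (\lambda|m-m^*| + 1)\lambda^* e^{-\lambda|m-m^*|}\Big)^2 \ge 2\alpha\Big(-1 + \log\lambda - \log\lambda^* + \lambda^*|m-m^*| + \frac{\lambda^*}{\lambda}e^{-\lambda|m-m^*|}\Big).$$
Table: In this table, we continue the list with the Fisher information functional and the log-Sobolev inequality for various probability families.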

Slide 39

Slide 39 text

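List of functional inequalities in families
- Gaussian. RIW condition for LSI($\alpha$):
$$\begin{pmatrix}\frac{1}{(\sigma^*)^2}&0\\0&\frac{1}{(\sigma^*)^2}+\frac{1}{\sigma^2}\end{pmatrix}\succeq 2\alpha\begin{pmatrix}1&0\\0&1\end{pmatrix};$$
RIW condition for PI($\alpha$):
$$\begin{pmatrix}\frac{1}{(\sigma^*)^2}&0\\0&\frac{2}{(\sigma^*)^2}\end{pmatrix}\succeq 2\alpha\begin{pmatrix}1&0\\0&1\end{pmatrix}.$$
- Laplacian. RIW condition for LSI($\alpha$):
$$\begin{pmatrix}\lambda\lambda^* e^{-\lambda|m-m^*|}&\lambda^*(m-m^*)e^{-\lambda|m-m^*|}\\\lambda^*(m-m^*)e^{-\lambda|m-m^*|}&\frac{1}{\lambda^2}+\frac{\lambda^*(m^*-m)^2}{\lambda}e^{-\lambda|m-m^*|}\end{pmatrix}\succeq 2\alpha\begin{pmatrix}1&0\\0&\frac{2}{\lambda^4}\end{pmatrix};$$
RIW condition for PI($\alpha$):
$$\begin{pmatrix}(\lambda^*)^2&0\\0&\frac{1}{(\lambda^*)^2}\end{pmatrix}\succeq 2\alpha\begin{pmatrix}1&0\\0&\frac{2}{(\lambda^*)^4}\end{pmatrix}.$$
Table: In this table, we present the RIW condition for LSI and PI in various probability families.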
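A numerical check of the left-hand side of the Gaussian LSI entry. Because $G_W(\mu,\sigma) = I$ is constant, its Christoffel symbols vanish, and since $\nabla^2_\theta\tilde H = G_F + \int\nabla^2_\theta p_\theta\log(p_\theta/p_{\theta^*})\,dx$ (using $\int\nabla^2_\theta p_\theta\,dx = 0$), that left-hand side reduces to the Euclidean Hessian of $\tilde H$. The Python sketch below computes this Hessian by finite differences at one illustrative point; it should match $\mathrm{diag}\big(1/(\sigma^*)^2,\ 1/(\sigma^*)^2 + 1/\sigma^2\big)$.

```python
import math

def rel_entropy(mu, s, mu_t, s_t):
    """H~(p_{mu,sigma} | p*) = KL(N(mu, s^2) || N(mu_t, s_t^2)) in closed form."""
    return math.log(s_t / s) + (s * s + (mu - mu_t) ** 2) / (2 * s_t ** 2) - 0.5

def hessian_2d(f, x, y, h=1e-4):
    """Central-difference Hessian of a function of two variables."""
    fxx = (f(x + h, y) - 2 * f(x, y) + f(x - h, y)) / h ** 2
    fyy = (f(x, y + h) - 2 * f(x, y) + f(x, y - h)) / h ** 2
    fxy = (f(x + h, y + h) - f(x + h, y - h) - f(x - h, y + h) + f(x - h, y - h)) / (4 * h * h)
    return [[fxx, fxy], [fxy, fyy]]

mu_t, s_t = 0.5, 1.5   # target parameters (mu*, sigma*)
mu, s = 1.2, 0.8       # evaluation point
hess = hessian_2d(lambda m, sd: rel_entropy(m, sd, mu_t, s_t), mu, s)
# Expected: diag(1 / s_t**2, 1 / s_t**2 + 1 / s**2)
```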

Slide 40

Slide 40 text

Online natural gradient algorithm
We sample from the unknown distribution once in each step and use the sample $x_t$ to generate an estimator $\theta_{t+1}$:
$$\theta_{t+1} = \theta_t - \frac{1}{t}\nabla^W_\theta l(x_t;\theta_t),$$
where $l$ is the loss function. To analyze the convergence of this algorithm, we define the Wasserstein covariance matrix $V_t$ by
$$V_t = \mathbb{E}_{p_{\theta^*}}\big[\nabla_x(\theta_t - \theta^*)\cdot\nabla_x(\theta_t - \theta^*)^T\big],$$
where $\theta^*$ is the optimal value of the learning problem.

Slide 41

Slide 41 text

Wasserstein natural gradient efficiency
Definition. The Wasserstein natural gradient is asymptotically efficient if
$$V_t = \frac{1}{t}G_W^{-1}(\theta^*) + O\Big(\frac{1}{t^2}\Big).$$

Slide 42

Slide 42 text

Wasserstein online efficiency
Corollary (Wasserstein natural gradient efficiency). For the dynamics
$$\theta_{t+1} = \theta_t + \frac{1}{t}G_W^{-1}(\theta_t)\,\Phi^W(x_t;\theta_t),$$
the Wasserstein covariance updates according to
$$V_{t+1} = V_t + \frac{1}{t^2}G_W^{-1}(\theta^*) - \frac{2}{t}V_t + o\Big(\frac{1}{t^2}\Big) + o\Big(\frac{V_t}{t}\Big).$$
Then the online Wasserstein natural gradient algorithm is Wasserstein efficient, that is,
$$V_t = \frac{1}{t}G_W^{-1}(\theta^*) + O\Big(\frac{1}{t^2}\Big).$$
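A sketch of the corollary in the simplest case, assuming the convention $-\nabla\cdot(p\nabla\Phi^W) = \partial_\theta p$ (so that $\Phi^W(x;\theta) = x - \theta$ and $G_W = 1$ for the location model $\mathcal N(\theta, \sigma^2)$) and reading $\nabla_x$ in $V_t$ as the gradient with respect to all samples used by the estimator. With step $1/t$ and a plus sign the iterate is the running sample mean, which is affine in the data, so $V_t$ can be computed exactly; it equals $1/(t-1)$, i.e. $G_W^{-1}/t$ up to $O(1/t^2)$. The function name is hypothetical.

```python
def online_location_V(T):
    """V_T for theta_{t+1} = theta_t + (1/t) G_W^{-1} Phi^W(x_t; theta_t) in the model
    N(theta, sigma^2), with G_W = 1 and Phi^W(x; theta) = x - theta. The iterate is
    affine in the samples, so we propagate d theta_t / d x_s exactly and return
    V_T = sum_s (d theta_T / d x_s)^2."""
    grads = []  # grads[s] = d theta_t / d x_s for the samples used so far
    for t in range(1, T):
        # theta_{t+1} = (1 - 1/t) * theta_t + (1/t) * x_t
        grads = [(1.0 - 1.0 / t) * g for g in grads] + [1.0 / t]
    return sum(g * g for g in grads)

T = 1000
V_T = online_location_V(T)  # compare with G_W^{-1} / T = 1 / T
```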

Slide 43

Slide 43 text

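Poincaré online efficiency
Corollary. For the dynamics
$$\theta_{t+1} = \theta_t - \frac{1}{t}\nabla^W_\theta l(x_t;\theta_t),$$
where $l(x_t;\theta_t) = -\log p(x_t;\theta_t)$ is the negative log-likelihood function, the Wasserstein covariance updates according to
$$V_{t+1} = V_t + \frac{1}{t^2}G_W^{-1}(\theta^*)\,\mathbb{E}_{p_{\theta^*}}\Big[\nabla_x\big(\nabla_\theta l(x_t;\theta^*)\big)\cdot\nabla_x\big(\nabla_\theta l(x_t;\theta^*)\big)^T\Big]\,G_W^{-1}(\theta^*) - \frac{2}{t}V_t\,G_F(\theta^*)G_W^{-1}(\theta^*) + O\Big(\frac{1}{t^3}\Big) + o\Big(\frac{V_t}{t}\Big).$$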

Slide 44

Slide 44 text

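Poincaré online efficiency (continued)
Corollary. Suppose $\alpha = \sup\{a\in\mathbb{R} \mid G_F \succeq aG_W\}$. Then
$$V_t = \begin{cases} O\big(t^{-2\alpha}\big), & 2\alpha\le 1,\\[2pt] \frac{1}{t}\big(2G_FG_W^{-1} - I\big)^{-1}G_W^{-1}(\theta^*)\,\mathcal{I}\,G_W^{-1}(\theta^*) + O\big(\frac{1}{t^2}\big), & 2\alpha > 1,\end{cases}$$
where $I$ is the identity matrix and
$$\mathcal{I} = \mathbb{E}_{x_t\sim p_{\theta^*}}\Big[\nabla_x\big(\nabla_\theta\log p(x_t;\theta^*)\big)\cdot\nabla_x\big(\nabla_\theta\log p(x_t;\theta^*)\big)^T\Big].$$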
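A scalar model of the two regimes, assuming every matrix is a number: `a` stands for $G_FG_W^{-1}$ and `c` for $G_W^{-1}\mathcal{I}G_W^{-1}$, and the higher-order terms are dropped. Iterating $v_{t+1} = v_t + c/t^2 - 2av_t/t$ shows $t\,v_t \to c/(2a-1)$ when $2a > 1$ (with $a = 1$ recovering the efficient rate $c/t$), and slower decay when $2a < 1$. The function name is hypothetical.

```python
def covariance_recursion(a, c, T, v1=1.0):
    """Scalar version of v_{t+1} = v_t + c / t^2 - (2 a / t) v_t, where a stands for
    G_F G_W^{-1} and c for G_W^{-1} I G_W^{-1}; higher-order terms dropped. Returns v_T."""
    v = v1
    for t in range(1, T):
        v = v + c / t ** 2 - 2.0 * a * v / t
    return v

T = 100000
fast = T * covariance_recursion(1.5, 1.0, T)   # 2a > 1: T v_T -> c / (2a - 1) = 0.5
eff = T * covariance_recursion(1.0, 1.0, T)    # a = 1 (G_F = G_W): T v_T -> c = 1
slow = T * covariance_recursion(0.25, 1.0, T)  # 2a < 1: v_T ~ T^(-2a), so T v_T keeps growing
```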

Slide 45

Slide 45 text

Future work
- Study sampling complexity by the WIM;
- Analyze Wasserstein estimation by the WIM;
- Approximate the WIM for scientific computing.