Slide 1

Slide 1 text

Wasserstein Gradient Flows of Moreau Envelopes of f-Divergences in Reproducing Kernel Hilbert Spaces
Viktor Stein (TU Berlin), joint work with Sebastian Neumayer (TU Chemnitz), Gabriele Steidl (TU Berlin), and Nicolaj Rux (TU Berlin)
UCLA Level Set Seminar (Stan Osher), August 19, 2024

Slide 2

Slide 2 text

Goal. Recover ν ∈ P(R^d) from samples by minimizing an f-divergence D_{f,ν} to ν, e.g. KL(· | ν).
Problem. Only samples are available ⇝ empirical measures µ, but µ ̸≪ ν =⇒ D_{f,ν}(µ) = ∞.
Our solution. Regularize D_{f,ν}: M(R^d) → [0, ∞] by combining
1. the "kernel trick": the kernel mean embedding m: M(R^d) → H_K, µ ↦ ∫_{R^d} K(x, ·) dµ(x), and
2. Moreau envelope regularization of the H_K-extension "D_{f,ν} ∘ m^{−1}" = G_{f,ν}: H_K → [0, ∞]:
   (^λ G_{f,ν})(m(µ)) = min_{σ ∈ M_+(R^d)} D_{f,ν}(σ) + 1/(2λ) ∥m(σ) − m(µ)∥²_{H_K},  λ > 0.
We prove existence & uniqueness of W_2 gradient flows of (^λ G_{f,ν}) ∘ m and simulate particle flows, i.e. W_2 gradient flows starting at an empirical measure.

Slide 3

Slide 3 text

Literature review of prior work
• KALE functional = MMD-regularized KL divergence [Glaser, Arbel, Gretton, NeurIPS'21]. No Moreau envelope interpretation.
• Kernel methods of moments = f-divergence-regularized MMD [Kremer, Nemmour, Schölkopf, Zhu, ICML'23]. Does not cover all f-divergences.
• (f, Γ)-divergence = Pasch-Hausdorff envelope of f-divergences [Birrell, Dupuis, Katsoulakis, Pantazis, Rey-Bellet, JMLR'23]. Yields only a Lipschitz, not a differentiable functional.
• W_1-Moreau envelope of f-divergences [Terjék, ICML'21]. No RKHS; in our setting it is the RKHS that makes the optimization finite-dimensional, hence tractable.

Slide 4

Slide 4 text

1. RKHS & MMD  2. Moreau envelopes  3. f-divergences  4. MMD-Moreau envelopes of f-divergences  5. Wasserstein gradient flow  6. WGF of MMD-Moreau envelopes of f-divergences

Slide 5

Slide 5 text

Reproducing Kernel Hilbert Spaces
"Kernel trick": embed data into a high-dimensional Hilbert space.
K: R^d × R^d → R symmetric, positive definite. We consider radial kernels K(x, y) = ϕ(∥x − y∥²_2) with ϕ ∈ C^∞((0, ∞)) ∩ C²([0, ∞)) and (−1)^k ϕ^(k)(r) ≥ 0 for all k ∈ N, r > 0.
⇝ reproducing kernel Hilbert space (RKHS) H_K := closure of span({K(x, ·) : x ∈ R^d}). Key property: the point evaluation h ↦ h(x) is continuous.
[Fig. 1: "Kernel trick". Source: songcy.net/posts/story-of-basis-and-kernel-part-2/]
[Plot: the kernel profiles (1 − √x)³_+, √s (s + x)^{−1/2}, and exp(−x/(2s)).]
Examples (with parameter s > 0):
• Gaussian: ϕ(r) = exp(−r/(2s))
• inverse multiquadric: ϕ(r) := (s + r)^{−1/2}
• spline: ϕ(r) = max(0, 1 − √r)^{s+2}
Nonexamples:
• Laplace: ϕ(r) = exp(−√r/(2s)) (not smooth enough)
• K(x, y) = ∥x∥ + ∥y∥ − ∥x − y∥ (not radial)
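Below is a minimal sketch (assumed, not from the slides) of the three radial profiles ϕ and the induced kernel K(x, y) = ϕ(∥x − y∥²_2); the function names are illustrative.

```python
import numpy as np

def phi_gauss(r, s=1.0):                     # Gaussian profile
    return np.exp(-r / (2 * s))

def phi_imq(r, s=1.0):                       # inverse multiquadric profile
    return (s + r) ** (-0.5)

def phi_spline(r, s=1.0):                    # compactly supported spline profile
    return np.maximum(0.0, 1.0 - np.sqrt(r)) ** (s + 2)

def K(x, y, phi=phi_gauss, **kw):
    """Radial kernel K(x, y) = phi(||x - y||_2^2)."""
    x, y = np.asarray(x, float), np.asarray(y, float)
    return phi(np.sum((x - y) ** 2), **kw)

print(K(np.zeros(2), np.ones(2), phi=phi_imq, s=1.0))   # phi_imq(2) = 3 ** (-1/2)
```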

Slide 6

Slide 6 text

Kernel mean embedding and Maximum Mean Discrepancy
"Kernel trick" for signed measures µ ∈ M(R^d) (instead of points): the kernel mean embedding (KME)
   m: M(R^d) → H_K,  µ ↦ ∫_{R^d} K(x, ·) dµ(x).
[Commutative diagram: R^d → M(R^d) via x ↦ δ_x, and m(δ_x) = K(x, ·).]
We require m to be injective (H_K "characteristic") ⇐⇒ H_K ⊂ C_0(R^d) is dense.
⇝ Instead of comparing measures, compare their embeddings in H_K: the maximum mean discrepancy (MMD)
   d_K: M(R^d) × M(R^d) → [0, ∞),  (µ, ν) ↦ ∥m(µ − ν)∥_{H_K}.
m injective =⇒ d_K is a metric, but (M(R^d), d_K) is not complete.
Easy to evaluate, e.g. for discrete measures, since
   d_K(µ, ν)² = ∫_{R^d × R^d} K(x, y) d(µ − ν)(x) d(µ − ν)(y)  for all µ, ν ∈ M(R^d).
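As a quick illustration (a sketch assuming a Gaussian kernel; not the authors' code), the squared MMD between two empirical measures follows directly from the double-integral formula above:

```python
import numpy as np
from scipy.spatial.distance import cdist

def mmd_squared(X, Y, s=1.0):
    """d_K(mu, nu)^2 for mu = (1/N) sum_i delta_{x_i}, nu = (1/M) sum_j delta_{y_j},
    with the Gaussian kernel K(x, y) = exp(-||x - y||^2 / (2s))."""
    Kxx = np.exp(-cdist(X, X, "sqeuclidean") / (2 * s))
    Kyy = np.exp(-cdist(Y, Y, "sqeuclidean") / (2 * s))
    Kxy = np.exp(-cdist(X, Y, "sqeuclidean") / (2 * s))
    return Kxx.mean() + Kyy.mean() - 2 * Kxy.mean()

rng = np.random.default_rng(0)
X, Y = rng.normal(size=(200, 2)), rng.normal(loc=1.0, size=(300, 2))
print(mmd_squared(X, Y))   # > 0, since the two samples come from different distributions
```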

Slide 7

Slide 7 text

1. RKHS & MMD  2. Moreau envelopes  3. f-divergences  4. MMD-Moreau envelopes of f-divergences  5. Wasserstein gradient flow  6. WGF of MMD-Moreau envelopes of f-divergences

Slide 8

Slide 8 text

Regularization in convex analysis: Moreau envelopes
Let (H, ⟨·, ·⟩, ∥·∥) be a Hilbert space and f ∈ Γ_0(H), i.e. f: H → (−∞, ∞] convex, lower semicontinuous, with dom(f) := {x ∈ H : f(x) < ∞} ≠ ∅.
For ε > 0, the ε-Moreau envelope of f,
   ^ε f: H → R,  x ↦ min_{x' ∈ H} f(x') + 1/(2ε) ∥x − x'∥²,
is a convex, differentiable regularization of f preserving its minimizers.
Asymptotics: ^ε f(x) ↗ f(x) as ε ↘ 0 and ^ε f(x) ↘ inf(f) as ε → ∞.
(ε, x) ↦ ^ε f(x) is the viscosity solution of the Hamilton-Jacobi equation
   ∂_ε(^ε f)(x) + 1/2 ∥∇(^ε f)(x)∥² = 0,   ^ε f(x) → f(x) as ε ↘ 0
[Osher, Heaton, Fung, PNAS 120(14), 2023].
[Figure: Moreau envelope of an extended-real-valued non-differentiable function (top) and of |·| for different ε (bottom). © Trygve U. Helgaker, Pontus Giselsson]
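For intuition, here is a small sketch (assumed, not from the slides): the ε-Moreau envelope of f = |·| on R is the Huber function, and a brute-force minimization reproduces both it and the two asymptotic regimes.

```python
import numpy as np

def huber(x, eps):
    """Closed form of the eps-Moreau envelope of the absolute value."""
    return np.where(np.abs(x) <= eps, x ** 2 / (2 * eps), np.abs(x) - eps / 2)

def envelope_numeric(x, eps, grid=np.linspace(-10, 10, 20001)):
    """Direct minimization over x' of |x'| + (x - x')^2 / (2 eps)."""
    return np.min(np.abs(grid) + (x - grid) ** 2 / (2 * eps))

x = 1.3
for eps in [1e-3, 0.5, 2.0, 100.0]:
    print(eps, huber(x, eps), envelope_numeric(x, eps))
# the two columns agree; values increase to |x| = 1.3 as eps -> 0 and decrease to inf(f) = 0 as eps -> infinity
```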

Slide 9

Slide 9 text

1. RKHS & MMD  2. Moreau envelopes  3. f-divergences  4. MMD-Moreau envelopes of f-divergences  5. Wasserstein gradient flow  6. WGF of MMD-Moreau envelopes of f-divergences

Slide 10

Slide 10 text

Entropy functions
We consider f ∈ Γ_0(R) with f|_{(−∞,0)} ≡ ∞, unique minimizer at 1 with f(1) = 0, and positive recession constant f'_∞ := lim_{t→∞} f(t)/t > 0.
Examples. f_KL(x) := x ln(x) − x + 1 for x ≥ 0 yields the Kullback-Leibler divergence, and f_α(x) := (x^α − αx + α − 1)/(α − 1) yields the Tsallis-α divergence T_α for α > 0. In the limit α → 1: T_1 = KL.
[Figure, left: examples of entropy functions (all but the red curve), e.g. x ln(x) − x + 1, |x − 1|, (x − 1) ln(x), x ln(x) − (x + 1) ln((x + 1)/2), max(0, 1 − x)². Right: the functions f_α for α ∈ [0.1, 2.5].]
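A small sketch (assumed, not from the slides): the entropy functions f_KL and f_α in code, checking f(1) = 0 and the limit f_α → f_KL as α → 1.

```python
import numpy as np

def f_kl(x):
    x = np.asarray(x, dtype=float)
    safe = np.where(x > 0, x, 1.0)                       # avoid log(0); x ln x -> 0 as x -> 0
    return np.where(x > 0, x * np.log(safe) - x + 1, np.where(x == 0, 1.0, np.inf))

def f_alpha(x, alpha):
    x = np.asarray(x, dtype=float)
    return np.where(x >= 0, (x ** alpha - alpha * x + alpha - 1) / (alpha - 1), np.inf)

x = np.array([0.0, 0.5, 1.0, 2.0])
print(f_kl(x))                 # [1, 0.153..., 0, 0.386...], minimum 0 at x = 1
print(f_alpha(x, 1.0001))      # close to f_kl(x)
# recession constants: f_KL and f_alpha with alpha > 1 have f'_inf = inf;
# f_alpha has f'_inf = alpha / (1 - alpha) for alpha in (0, 1)
```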

Slide 11

Slide 11 text

f-divergences
The f-divergence of µ = ρν + µ_s ∈ M_+(R^d) (unique Lebesgue decomposition) to ν ∈ M_+(R^d) is
   D_{f,ν}(ρν + µ_s) := ∫_{R^d} f ∘ ρ dν + f'_∞ · µ_s(R^d)   (with the convention ∞ · 0 := 0)
   = sup_{h ∈ C_b(R^d; dom(f*))} E_µ[h] − E_ν[f* ∘ h],   where E_σ[h] := ∫_{R^d} h(x) dσ(x).
The convex conjugate of f is f*: R → (−∞, ∞], s ↦ sup{st − f(t) : t ≥ 0}.
Theorem (Properties of D_{f,ν}). D_{f,ν}: M_+(R^d) → [0, ∞] is convex and weak* lower semicontinuous, and D_{f,ν}(µ) = 0 ⇐⇒ µ = ν.
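For measures supported on a common finite set, the primal formula becomes a finite sum. The following sketch (assumed; the names are illustrative) computes D_{f,ν}(µ) including the singular part and the ∞ · 0 := 0 convention.

```python
import numpy as np

def f_divergence(mu, nu, f, f_inf):
    """D_{f,nu}(mu) for nonnegative weight vectors mu, nu on a common finite set."""
    mu, nu = np.asarray(mu, float), np.asarray(nu, float)
    ac = nu > 0                                   # indices of the absolutely continuous part
    rho = mu[ac] / nu[ac]                         # density of mu w.r.t. nu on {nu > 0}
    sing = mu[~ac].sum()                          # singular mass mu_s(R^d)
    return np.sum(f(rho) * nu[ac]) + (f_inf * sing if sing > 0 else 0.0)

f_kl = lambda x: np.where(x > 0, x * np.log(np.where(x > 0, x, 1.0)) - x + 1, 1.0)
mu = np.array([0.2, 0.3, 0.5, 0.0])
nu = np.array([0.25, 0.25, 0.25, 0.25])
print(f_divergence(mu, nu, f_kl, np.inf))         # finite, since mu << nu
print(f_divergence(nu, mu, f_kl, np.inf))         # = inf: nu is not absolutely continuous w.r.t. mu
```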

Slide 12

Slide 12 text

MMD-regularized f-divergence: Moreau envelope interpretation
We define the MMD-regularized f-divergence functional
   D^λ_{f,ν}(µ) := min { D_{f,ν}(σ) + 1/(2λ) d_K(µ, σ)² : σ ∈ M(R^d) },   λ > 0, µ ∈ M(R^d).   (1)
Theorem (Moreau envelope interpretation of D^λ_{f,ν} [NSSR24]). The H_K-extension of D_{f,ν},
   G_{f,ν}: H_K → [0, ∞],  h ↦ D_{f,ν}(µ) if h = m(µ) for some µ ∈ M_+(R^d), and ∞ otherwise,
is convex and lower semicontinuous, and its Moreau envelope composed with m is the MMD-regularized f-divergence:
   ^λ G_{f,ν} ∘ m = D^λ_{f,ν}.
[Commutative diagram: D_{f,ν} and D^λ_{f,ν} on M(R^d) correspond to G_{f,ν} and ^λ G_{f,ν} on H_K via the embedding m.]

Slide 13

Slide 13 text

Properties of D^λ_{f,ν} [NSSR24]
• Dual formulation:
   D^λ_{f,ν}(µ) = max { E_µ[p] − E_ν[f* ∘ p] − λ/2 ∥p∥²_{H_K} : p ∈ H_K, p ≤ f'_∞ }.   (2)
   p̂ ∈ H_K maximizes (2) ⇐⇒ ĝ = m(µ) − λ p̂ is the primal solution.
   Moreover λ/2 ∥p̂∥²_{H_K} ≤ D^λ_{f,ν}(µ) ≤ ∥p̂∥_{H_K} (∥m_µ∥_{H_K} + ∥m_ν∥_{H_K}) and ∥p̂∥_{H_K} ≤ (2/λ) d_K(µ, ν).
• D^λ_{f,ν} is Fréchet differentiable on M(R^d) and its gradient is λ-Lipschitz with respect to d_K:
   ∇D^λ_{f,ν}(µ) = argmax (2).

Slide 14

Slide 14 text

Theorem (Properties of D^λ_{f,ν}) [NSSR24]
• Asymptotic regimes (Mosco resp. pointwise convergence, if 0 ∈ int(dom(f*)) resp. f* is differentiable at 0):
   D^λ_{f,ν} → D_{f,ν} as λ ↘ 0   and   (1 + λ) D^λ_{f,ν} → 1/2 d_K(·, ν)² as λ → ∞.
• Divergence property: D^λ_{f,ν}(µ) = 0 ⇐⇒ µ = ν.
• If f* is differentiable at 0, then (µ, ν) ↦ D^λ_{f,ν}(µ) metrizes weak convergence on M_+(R^d)-balls.

Slide 15

Slide 15 text

1. RKHS & MMD  2. Moreau envelopes  3. f-divergences  4. MMD-Moreau envelopes of f-divergences  5. Wasserstein gradient flow  6. WGF of MMD-Moreau envelopes of f-divergences

Slide 16

Slide 16 text

Wasserstein space and generalized geodesics
P_2(R^d) := {µ ∈ P(R^d) : ∫_{R^d} ∥x∥²_2 dµ(x) < ∞}, with ∥·∥_2 the Euclidean norm, and
   W_2(µ, ν)² = min_{π ∈ Γ(µ,ν)} ∫_{R^d × R^d} ∥x − y∥²_2 dπ(x, y),   µ, ν ∈ P_2(R^d).
[Fig. 2: vertical (L²) vs. horizontal (W_2) mass displacement. © A. Korba]
[Fig. 3: generalized geodesic from µ_2 to µ_3 with base µ_1 [AGS08].]
Definition (Generalized geodesic convexity). A function F: P_2(R^d) → (−∞, ∞] is M-convex along generalized geodesics if, for every σ, µ, ν ∈ dom(F), there exists α ∈ P_2(R^{3d}) with (P_{1,2})_# α ∈ Γ_opt(σ, µ) and (P_{1,3})_# α ∈ Γ_opt(σ, ν) such that
   F(((1 − t)P_2 + tP_3)_# α) ≤ (1 − t) F(µ) + t F(ν) − M/2 t(1 − t) ∫_{R^d × R^d × R^d} ∥y − z∥²_2 dα(x, y, z)   for all t ∈ [0, 1].
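As an aside (a sketch assuming equal sample sizes and uniform weights; not from the slides), W_2 between two empirical measures reduces to an optimal assignment problem:

```python
import numpy as np
from scipy.optimize import linear_sum_assignment
from scipy.spatial.distance import cdist

def w2_empirical(X, Y):
    """W_2 between (1/N) sum_i delta_{x_i} and (1/N) sum_j delta_{y_j} (equal N, uniform weights)."""
    C = cdist(X, Y, "sqeuclidean")           # cost c(x, y) = ||x - y||_2^2
    rows, cols = linear_sum_assignment(C)    # an optimal coupling is a permutation here
    return np.sqrt(C[rows, cols].mean())

rng = np.random.default_rng(1)
X = rng.normal(size=(100, 2))
Y = rng.normal(loc=2.0, size=(100, 2))
print(w2_empirical(X, Y))                    # roughly the mean shift 2 * sqrt(2) for large N
```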

Slide 17

Slide 17 text

Wasserstein gradient flows
Definition (Fréchet subdifferential in Wasserstein space). The (reduced) Fréchet subdifferential of F: P_2(R^d) → (−∞, ∞] at µ ∈ dom(F) is
   ∂F(µ) := { ξ ∈ L²(R^d; µ) : F(ν) − F(µ) ≥ inf_{π ∈ Γ_opt(µ,ν)} ∫_{R^d × R^d} ⟨ξ(x), y − x⟩ dπ(x, y) + o(W_2(µ, ν)) }.
A curve γ: (0, ∞) → P_2(R^d) is absolutely continuous if there exists an L²-Borel velocity field v: R^d × (0, ∞) → R^d such that the continuity equation
   ∂_t γ_t + ∇ · (v_t γ_t) = 0,   (t, x) ∈ (0, ∞) × R^d,
holds weakly.
Definition (Wasserstein gradient flow). A locally absolutely continuous curve γ: (0, ∞) → P_2(R^d) with velocity field v_t ∈ T_{γ_t} P_2(R^d) is a Wasserstein gradient flow with respect to F: P_2(R^d) → (−∞, ∞] if v_t ∈ −∂F(γ_t) for a.e. t > 0.
[Animation © Petr Mokrov]

Slide 18

Slide 18 text

Wasserstein gradient flow with respect to D^λ_{f,ν}
Theorem (Convexity and gradient of D^λ_{f,ν} [NSSR24]). Since K is radial and smooth, D^λ_{f,ν} is M-convex along generalized geodesics with M := −8λ^{−1}(d + 2)ϕ''(0)ϕ(0), and its (reduced) Fréchet subdifferential is ∂D^λ_{f,ν}(µ) = {∇ argmax (2)}.
Remark. M seems non-optimal: for λ → 0 we have D^λ_{f,ν} → D_{f,ν}, which is 0-convex, yet M → −∞.
Corollary. There exists a unique Wasserstein gradient flow (γ_t)_{t>0} of D^λ_{f,ν} starting at µ_0 ∈ P_2(R^d), fulfilling the continuity equation
   ∂_t γ_t = ∇ · (γ_t ∂D^λ_{f,ν}(γ_t)),   γ_0 = µ_0.
Lemma (Particle flows are W_2 gradient flows). If µ_0 is an empirical measure, then so is γ_t for all t > 0.

Slide 19

Slide 19 text

Numerical experiments: particle descent algorithm
Take i.i.d. samples (x_j^{(0)})_{j=1}^N ∼ µ_0 and (y_j)_{j=1}^M ∼ ν. Forward Euler discretization in time with step size τ > 0 yields
   γ_{n+1} := (id − τ ∇p̂_n)_# γ_n,   p̂_n = argmax in (2) for D^λ_{f,ν}(γ_n),
so γ_n = (1/N) Σ_{j=1}^N δ_{x_j^{(n)}} with the gradient step
   x_j^{(n+1)} = x_j^{(n)} − τ ∇p̂_n(x_j^{(n)}),   j ∈ {1, ..., N}, n ∈ N.
Theorem (Representer-type theorem [NSSR24]). If f'_∞ = ∞ or if λ > 2 d_K(γ_n, ν) √ϕ(0) / f'_∞, then finding p̂_n is a finite-dimensional, strongly convex problem.
To find p̂_n, we use L-BFGS-B, a quasi-Newton method. We use an annealing strategy for λ if f'_∞ < ∞.
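The following is a minimal sketch of one particle-descent step for the KL case (f*(s) = e^s − 1, f'_∞ = ∞), assuming a Gaussian kernel and parametrizing p by a representer ansatz over the particle and target-sample locations; the function names are illustrative and this is not the authors' reference implementation.

```python
import numpy as np
from scipy.optimize import minimize
from scipy.spatial.distance import cdist

def gauss_K(A, B, s=1.0):
    """Gaussian kernel matrix K(a, b) = exp(-||a - b||^2 / (2s))."""
    return np.exp(-cdist(A, B, "sqeuclidean") / (2 * s))

def particle_step(X, Y, lam=0.01, tau=0.001, s=1.0):
    """One explicit Euler step x <- x - tau * grad p_hat(x) for the MMD-regularized KL flow."""
    N, M = len(X), len(Y)
    Z = np.vstack([X, Y])                    # anchors of the representer ansatz p = sum_i c_i K(z_i, .)
    Kzz = gauss_K(Z, Z, s)
    Kxz, Kyz = Kzz[:N], Kzz[N:]

    def neg_dual(c):                         # negative of the dual objective (2) and its gradient
        p_x, p_y = Kxz @ c, Kyz @ c          # p evaluated at the particles and the target samples
        val = p_x.mean() - (np.exp(p_y) - 1).mean() - 0.5 * lam * c @ Kzz @ c
        grad = Kxz.T @ np.full(N, 1.0 / N) - Kyz.T @ (np.exp(p_y) / M) - lam * Kzz @ c
        return -val, -grad

    c = minimize(neg_dual, np.zeros(N + M), jac=True, method="L-BFGS-B").x

    # grad p_hat(x) = sum_i c_i * phi'(||x - z_i||^2) * 2 (x - z_i) with phi'(r) = -phi(r) / (2s)
    diff = X[:, None, :] - Z[None, :, :]                 # shape (N, N + M, d)
    w = c * gauss_K(X, Z, s) * (-1.0 / s)                # shape (N, N + M)
    return X - tau * np.einsum("nm,nmd->nd", w, diff)

rng = np.random.default_rng(2)
X = rng.normal(loc=-2.0, size=(50, 2))                   # particles ~ mu_0
Y = rng.normal(loc=2.0, size=(80, 2))                    # samples ~ nu
for _ in range(5):
    X = particle_step(X, Y)                              # particles drift (slowly, for this tau) towards the target
```

For entropy functions with f'_∞ < ∞ (e.g. Tsallis-α with α < 1), the constraint p ≤ f'_∞ and the λ-annealing mentioned on the slide would additionally have to be handled.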

Slide 20

Slide 20 text

Numerical experiments
[Fig. 4: IMQ kernel, λ = 1/100, τ = 1/1000. Top: Tsallis-3 divergence; bottom: Tsallis-1/2 divergence, with annealing.]
[Fig. 5: The number N of starting particles is smaller than the number M of target samples ⇝ quantization.]

Slide 21

Slide 21 text

Further work
• Non-differentiable (e.g. Laplace = 1/2-Matérn) and unbounded (e.g. Riesz, Coulomb) kernels.
• Convergence rates in a suitable metric.
• Prove consistency bounds [Leclerc, Mérigot, Santambrogio, Stra, 2020] and better M-convexity estimates.
• Convergence for the annealing strategy?
• Different domains, e.g. compact subsets of R^d (manifolds like the sphere or torus), groups, infinite-dimensional spaces.
• Regularize other divergences, e.g. Rényi divergences, Bregman divergences.
• Gradient flow of D^λ_{f,ν} with respect to other metrics, like Hellinger-Kantorovich (related to unbalanced OT), MMD, Fisher-Rao, or Wasserstein-p for p ∈ [1, ∞].
• More elaborate time discretizations, variable step sizes.

Slide 22

Slide 22 text

Conclusion
• We introduce a novel objective whose minimization allows sampling from a target measure of which only samples are known.
• Clear, rigorous interpretation using convex analysis and RKHS theory.
• The theory covers (almost) all f-divergences.
• Best of both worlds: D^λ_{f,ν} interpolates between D_{f,ν} and d_K(·, ν)².
• Effective algorithms due to a (modified) representer theorem & GPU / PyTorch.

Slide 23

Slide 23 text

Thank you for your attention! I am happy to take any questions.
Paper link: arxiv.org/abs/2402.04613
My website: viktorajstein.github.io
References: [AGS08, BDK+22, GAG21, HWAH24, KYSZ23, LMSS20, LMS17, Ter21]

Slide 24

Slide 24 text

References I
[AGS08] Luigi Ambrosio, Nicola Gigli, and Giuseppe Savaré, Gradient flows: in metric spaces and in the space of probability measures, 2nd ed., Springer Science & Business Media, 2008.
[BDK+22] Jeremiah Birrell, Paul Dupuis, Markos A. Katsoulakis, Yannis Pantazis, and Luc Rey-Bellet, (f, Γ)-divergences: Interpolating between f-divergences and integral probability metrics, J. Mach. Learn. Res. 23 (2022), no. 39, 1–70.
[GAG21] Pierre Glaser, Michael Arbel, and Arthur Gretton, KALE flow: A relaxed KL gradient flow for probabilities with disjoint support, Advances in Neural Information Processing Systems, vol. 34, 2021, pp. 8018–8031.
[HWAH24] J. Hertrich, C. Wald, F. Altekrüger, and P. Hagemann, Generative sliced MMD flows with Riesz kernels, International Conference on Learning Representations (ICLR), Vienna, Austria, 2024.

Slide 25

Slide 25 text

References II
[KYSZ23] H. Kremer, Y. Nemmour, B. Schölkopf, and J.-J. Zhu, Estimation beyond data reweighting: kernel methods of moments, Proceedings of the 40th International Conference on Machine Learning (ICML'23), Honolulu, Hawaii, USA, vol. 202, 2023, pp. 17745–17783.
[LMS17] Matthias Liero, Alexander Mielke, and Giuseppe Savaré, Optimal entropy-transport problems and a new Hellinger–Kantorovich distance between positive measures, Invent. Math. 211 (2017), no. 3, 969–1117.
[LMSS20] Hugo Leclerc, Quentin Mérigot, Filippo Santambrogio, and Federico Stra, Lagrangian discretization of crowd motion and linear diffusion, SIAM J. Numer. Anal. 58 (2020), no. 4, 2093–2118. MR 4123686
[Ter21] Dávid Terjék, Moreau-Yosida f-divergences, International Conference on Machine Learning (ICML), PMLR, 2021, pp. 10214–10224.

Slide 26

Slide 26 text

Shameless plug: other works
Interpolating between OT and KL-regularized OT using Rényi divergences. The Rényi divergence (neither an f-divergence nor a Bregman divergence), for α ∈ (0, 1), is
   R_α(µ | ν) := 1/(α − 1) ln ∫_X (dµ/dτ)^α (dν/dτ)^{1−α} dτ,
and
   OT_{ε,α}(µ, ν) := min_{π ∈ Π(µ,ν)} ⟨c, π⟩ + ε R_α(π | µ ⊗ ν)
is a metric, where ε > 0 and µ, ν ∈ P(X) with X compact. It interpolates:
   OT(µ, ν)  ←(α ↘ 0 or ε → 0)—  OT_{ε,α}(µ, ν)  —(α ↗ 1)→  OT^KL_ε(µ, ν).
In the works: the debiased Rényi-Sinkhorn divergence OT_{ε,α}(µ, ν) − 1/2 OT_{ε,α}(µ, µ) − 1/2 OT_{ε,α}(ν, ν).
W_2 gradient flows of d_K(·, ν)² with K(x, y) := −|x − y| in 1D: reformulation as a maximal monotone inclusion Cauchy problem in L²(0, 1) via quantile functions; comprehensive description of the solutions' behavior; instantaneous measure-to-L^∞ regularization; the implicit Euler scheme is simple.
[Plots: the initial measure µ_0, and iteration 0 comparing initial, target, explicit and implicit schemes.]
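As a small illustration (assumed, not from the slides), the Rényi divergence between discrete measures on a common finite set, with the counting measure as reference τ:

```python
import numpy as np

def renyi(p, q, alpha):
    """R_alpha(p | q) for discrete densities p, q w.r.t. the counting measure, alpha in (0, 1)."""
    p, q = np.asarray(p, float), np.asarray(q, float)
    return np.log(np.sum(p ** alpha * q ** (1 - alpha))) / (alpha - 1)

p = np.array([0.2, 0.3, 0.5])
q = np.array([0.4, 0.4, 0.2])
print(renyi(p, q, 0.5))
print(renyi(p, q, 0.999), np.sum(p * np.log(p / q)))   # approaches KL(p | q) as alpha -> 1
```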