Upgrade to Pro — share decks privately, control downloads, hide ads and more …

論文紹介:Proximity Variational Inference (近接性変分推論)

論文紹介:Proximity Variational Inference (近接性変分推論)

Takahiro Kawashima

May 19, 2022
Tweet

More Decks by Takahiro Kawashima

Other Decks in Research

Transcript

  1. 論文概要 論文情報 • Title: Proximity Variational Inference • Author: Jaan

    Altosaar, Rajesh Ranganath, David Blei • Published in: AISTATS 2018 TL;DR: 変分推論に距離構造(近接性:proximity)を入れて一般化し, 学習性能を向上 2
  2. Factor Models 例として,次のような factor model を考える: 𝑧𝑖𝑘 ∼ Bernoulli(𝜋) 𝑥𝑖

    |{𝑧𝑖𝑘 }𝑘 ∼ N (∑ 𝑘 𝑧𝑖𝑘 𝜇𝑘 , 1) 𝑖 はサンプル,𝑘 は “因子” のインデックス 𝑧𝑖𝑘 ∈ {0, 1}:𝑖 番目のサンプルが 𝑘 番目の因子に依存するか否か 潜在変数 𝑧𝑖𝑘 に変分事後分布 𝑞(𝑧𝑖𝑘 ; 𝜆𝑖𝑘 ) = Bernoulli(𝜆𝑖𝑘 ) を与え, 𝜇𝑘 は(たぶん経験ベイズ的に)ELBO の勾配をとって学習 3
  3. Factor Models の学習 Factor Models 𝑧𝑖𝑘 ∼ Bernoulli(𝜋), 𝑥𝑖 |{𝑧𝑖𝑘

    }𝑘 ∼ N (∑ 𝑘 𝑧𝑖𝑘 𝜇𝑘 , 1) 変分 EM の方法により,変分事後分布 𝑞(𝑧𝑖𝑘 ; 𝜆𝑖𝑘 ) = Bernoulli(𝜆𝑖𝑘 ) のパラメータは 𝜆𝑖𝑘 ∝ exp {− 1 2𝜎2 𝔼\𝑧𝑖𝑘 [( 𝑥𝑖 − ∑ 𝑗 𝑧𝑖𝑗 𝜇𝑗 )2]} , と学習でき,ELBOL の 𝜇𝑘 による勾配は 𝜕L 𝜕𝜇𝑘 = − 1 𝜎2 ∑ 𝑖 𝜆𝑖𝑘 (−𝑥𝑖 + 𝜇𝑘 + ∑ 𝑗≠𝑘 𝜆𝑖𝑗 𝜇𝑗 ) ととれる. 4
  4. Factor Models の変分推論の難しさ Factor Models の学習則 𝜆𝑖𝑘 ∝ exp{−(2𝜎2)−1𝔼\𝑧𝑖𝑘 [(

    𝑥𝑖 − ∑ 𝑗 𝑧𝑖𝑗 𝜇𝑗 )2]}, (1) 𝜕L/𝜕𝜇𝑘 = −𝜎−2 ∑ 𝑖 𝜆𝑖𝑘 (−𝑥𝑖 + 𝜇𝑘 + ∑ 𝑗≠𝑘 𝜆𝑖𝑗 𝜇𝑗 ) (2) • {𝜇𝑘 } の初期値がデータと大きく乖離している場合,(1) 内の 二乗距離が大きくなる  𝑞(𝑧𝑖𝑘 = 1|𝜆𝑖𝑘 ) が非常に小さくなる • {𝜆𝑖𝑘 } が小さいと (2) の勾配が小さくなる  学習速度の低下 5
  5. 変分推論再考 一般に観測 𝒙 と潜在変数 𝒛 に対し 𝑝(𝒙, 𝒛) なるモデルを考え, 変分事後分布

    𝑞(𝒛; 𝝀) を学習する.𝝀 の勾配上昇法では 𝝀𝒕+1 = 𝝀𝑡 + 𝜌∇L(𝝀𝑡 ) と更新される.これは次の目的関数 𝑈 の最大化ともみなせる: 𝑈(𝝀𝑡+1 ) = L(𝝀𝑡 ) + ∇L(𝝀𝑡 )⊤(𝝀𝑡+1 − 𝝀𝑡 ) − 1 2𝜌 ‖𝝀𝑡+1 − 𝝀𝑡 ‖2 さらに max 𝑈(𝝀𝑡+1 ) の双対問題として max 𝝀𝑡+1 L(𝝀𝑡 ) + ∇L(𝝀𝑡 )⊤(𝝀𝑡+1 − 𝝀𝑡 ) s.t. ‖𝝀𝑡+1 − 𝝀𝑡 ‖2 ≤ 𝐶 が導かれる1. 1双対問題の目的関数は 𝝀𝑡 のまわりでの ELBO の一次近似になっている. 7
  6. Proximity Variational Inference へ 変分パラメータ最適化の双対問題 max 𝝀𝑡+1 L(𝝀𝑡 ) +

    ∇L(𝝀𝑡 )⊤(𝝀𝑡+1 − 𝝀𝑡 ) s.t. ‖𝝀𝑡+1 − 𝝀𝑡 ‖2 ≤ 𝐶 提案手法 Proximity Variational Inference (PVI) のアイデア: ∠ 上の問題に proximity constraints(近接性制約)を加える 1. proximity statistic𝑓(⋅) でパラメータを適当な空間へ飛ばす 2. 飛ばした先の空間で距離 𝑑(𝑓(𝝀𝑡 ), 𝑓(𝝀𝑡+1 )) を測る 3. 𝑑(𝑓(𝝀𝑡 ), 𝑓(𝝀𝑡+1 )) に関する制約を付け足す 8
  7. Proximity Variational Inference 提案手法 PVI の目的関数を次で定義: PVI の目的関数 𝑈(𝝀𝑡+1 )

    = L(𝝀𝑡 ) + ∇L(𝝀𝑡 )⊤(𝝀𝑡+1 − 𝝀𝑡 ) − 1 2𝜌 ‖𝝀𝑡+1 − 𝝀𝑡 ‖2 − 𝑘 ⋅ 𝑑(𝑓( ̃ 𝝀𝑡 ), 𝑓(𝝀𝑡+1 )) ここで ̃ 𝝀𝑡 は 𝝀1 , … , 𝝀𝑡 から適当に構成する量. ̃ 𝝀𝑡 の例:指数移動平均 ̃ 𝝀𝑡 = 𝛼 ̃ 𝝀𝑡−1 + (1 − 𝛼)𝝀𝑡 . ∠ 各更新において misstep への頑健性の向上が見込める 9
  8. PVI のアルゴリズム Algorithm 1: Proximity Variational Inference Input : 𝝀0

    , 𝑓(⋅), 𝑑(⋅, ⋅) Output: 𝝀 while L not converged do 𝝀𝑡+1 ← 𝝀𝑡 + Noise while 𝑈 not converged do Update 𝝀𝑡+1 ← 𝝀𝑡+1 + 𝜌∇𝑈(𝝀𝑡+1 ) end 𝝀𝑡 ← 𝝀𝑡+1 end return 𝝀 ELBO L(𝝀) の線形近似を inner loop で最適化するイメージ? 10
  9. proximity functions proximity function 𝑓(⋅) の具体例 • エントロピー:𝑓(𝝀) = ℍ[𝑞(𝒛;

    𝝀)] ∠ 変分推論 (reverse KL) の“zero-forcing” 性の改善(後述) • KL ダイバージェンス ∠ L(𝝀) = 𝔼𝑞 [log 𝑝(𝒙|𝒛)] − KL(𝑞𝝀 ‖𝑝) において,柔軟なモデルで は KL の項がすぐに poor optima に引っかかりやすい. 尤度項に対する 𝐾𝐿 項の強さを調整し,この問題を改善 • 平均・分散 • 直交性:𝑓(𝑊) = 𝑊𝑊⊤ • VAE などの NN 向け.重みに直交性を課すと学習がしやすく なるらしい.初期値を 𝑊𝑊⊤ = 𝐼 とする. 11
  10. zero-forcing 性と entropy proximity functions 変分推論では次の reverse KL をパラメータ 𝝀

    について最小化: KL(𝑞‖𝑝) = ∫ 𝑞(𝒛; 𝝀) log 𝑞(𝒛; 𝝀) 𝑝(𝒛|𝒙) 𝑑𝒛 𝑞(𝒛; 𝝀) が 𝒛 = 𝟎 周辺に “縮退” していると KL ≈ 0 に ∠  poor local minima にトラップ(zero-forcing 性) proximity function 𝑓(⋅) に 𝑞(𝒛; 𝝀) のエントロピーを設定し, エントロピーの大きい初期値を与える ∠  𝑞(𝒛; 𝝀) の縮退を防げる 12
  11. 計算量の改善 ナイーブな PVI のアルゴリズムは二重のループでつらい ∠ PVI の制約を一次近似して解析的に解く (Fast PVI) 𝑑(𝑓(

    ̃ 𝝀𝑡 ), 𝑓(𝝀𝑡+1 )) ≈ 𝑑(𝑓( ̃ 𝝀𝑡 ), 𝑓(𝝀𝑡 )) + ∇𝑑(𝑓( ̃ 𝝀𝑡 ), 𝑓(𝝀𝑡 ))∇𝑓(𝝀𝑡 )⊤(𝝀𝑡+1 − 𝝀𝑡 ) の右辺を左辺の代わりに PVI の目的関数 𝑈(𝝀𝑡+1 に組み込むと, 𝝀𝑡+1 = 𝝀𝑡 + 𝜌(∇L(𝝀𝑡 ) − 𝑘∇𝑑(𝑓( ̃ 𝝀𝑡 ), 𝑓(𝝀𝑡 ))∇𝑓(𝝀𝑡 )) なる解析的な更新則が得られる ∠  一重のループで済む 13
  12. Fast PVI のアルゴリズム Algorithm 2: Fast Proximity Variational Inference Input

    : 𝝀0 , 𝑓(⋅), 𝑑(⋅, ⋅) Output: 𝝀 while Lproximity not converged do 𝝀𝑡+1 = 𝝀𝑡 + 𝜌(∇L(𝝀𝑡 ) − 𝑘∇𝑑(𝑓( ̃ 𝝀𝑡 ), 𝑓(𝝀𝑡 ))∇𝑓(𝝀𝑡 )) ̃ 𝝀𝑡+1 = 𝛼 ̃ 𝝀𝑡 + (1 − 𝛼)𝝀𝑡+1 𝑡 ← 𝑡 + 1 end return 𝝀 Lproximity は Fast PVI の大域的な目的関数2で, Lproximity (𝝀𝑡+1 ) = 𝔼𝑞 [log 𝑝(𝒙, 𝒛)] − 𝔼𝑞 [log 𝑞(𝒛; 𝝀𝑡+1 )] − 𝑘 ⋅ 𝑑(𝑓( ̃ 𝝀𝑡 ), 𝑓(𝝀𝑡+1 )). 2𝑑(⋅, ⋅) ≥ 0 より,通常の ELBO L を modify した Lproximity も依然 エビデンスの lower bound. 14
  13. ハイパラ設定 • proximity statistic 𝑓(⋅): 各実験にあわせていろいろ変える • 距離関数 𝑑(⋅, ⋅):

    inverse Huber distance 𝑑(𝑥, 𝑦) = { |𝑥 − 𝑦| (|𝑥 − 𝑦| < 1), 0.5(𝑥 − 𝑦)2 + 0.5 (otherwise). • ̃ 𝝀𝑡 : 指数移動平均 with 𝛼 = 0.9999 • proximity condition の係数 𝑘: ELBO の初期値の絶対値 • Adam で最適化 15
  14. 実験 1: MNIST with Sigmoid Belief Nets Sigmoid Belief Nets

    を MNIST から学習し,validation set から ELBO/対数周辺尤度を算出3 1 層 200 ユニットの SBN 3 層 ×200 ユニットの SBN ∠  PVI with entropy constraint が最良 3たぶん classification はやってない. 16
  15. 実験 2: Binary MNIST with VAE 100 ユニットの潜在変数,2 層 ×200

    ユニットの隠れ層,ReLU からなる VAE で binary MNIST を学習.𝑘 は複数候補から選択. VAE による結果 ∠  PVI with orthogonal constraint がベター 上記結果の考察として,ReLU の overpruning 問題の改善を指摘 ∠  “dead ReLU” が隠れ層/出力層でそれぞれ 1.6%/3.2% 削減 17
  16. 実験 3: テキスト生成モデル 単語の出現回数を Poisson 潜在変数でモデリング: 𝒛 ∼ Poisson(𝝀), 𝒙|𝒛

    ∼ Poisson(𝑔(𝑊)⊤𝒛). ここで 𝑊 ∈ ℝ𝐾×𝑉 はモデルパラメータ,𝑔(𝑊) は element-wise softplus. V はボキャブラリ数. 100 次元の潜在変数を設定し,Science 誌のコーパスから学習 ∠ #train/#test: 138K/1K documents, #terms: 5.9K 評価時,𝑊 は固定のまま test set の各 document 内 10% の単語 で 𝝀 を学習し,残りの 90% から perplexity を算出 18
  17. 実験 3: テキスト生成モデル 学習結果. PVI がベター. PVI で学習された 100 次元潜在変数のうち

    3 つについて,関連の強い 単語上位 10 個を表示.学習結果の妥当性がわかる. 19
  18. むすび まとめ: • 変分推論を proximity constraints によって拡張し,学習性能 を向上 • entropy/KL/mean-variance/orthogonal

    の 4 種の proximity statisics の提案 • proximity constraints の一次近似による効率的な学習 • 実験的に ELBO/周辺尤度の改善を確認 所感: 従来の変分推論の性能を劇的に改善! とまではいかなそう. 変分推論に対して拡張可能性の高そうな perspective を与えた ことが本質的っぽい. 20