Following the Deep Markov Model through its equations (+ a replication in Pyro)

Koga Kobayashi

August 08, 2020

Transcript

  1. Following the Deep Markov Model through its equations (+ a replication in Pyro). Koga Kobayashi (@kajyuuen)

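  2. About me: Koga Kobayashi. GitHub and Twitter ID: @kajyuuen. Second-year master's student at the Graduate School of the University of Tsukuba. I usually work on natural language processing, mainly NER.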
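  3. About this talk
     • A PPL (probabilistic programming language) beginner follows the equations of the DMM in the Pyro tutorial in detail.
     • Not covered (cannot be covered) in this SLT:
       • the derivation of the ELBO
       • the concepts of the model and the approximate distribution themselves
       • stochastic variational inference (SVI)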
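  4. Hidden Markov Model. Consider observed data $X = \{x^{(1)}, \cdots, x^{(N)}\}$ and latent variables $Z = \{z^{(1)}, \cdots, z^{(N)}\}$. The joint distribution is

     $$p(X, Z, \theta, \pi, A) = p(\theta)\,p(\pi)\,p(A) \prod_{n=1}^{N} \Big\{ p(x_1^{(n)} \mid z_1^{(n)}, \theta)\, p(z_1^{(n)} \mid \pi) \prod_{i=2}^{T} p(x_i^{(n)} \mid z_i^{(n)}, \theta)\, p(z_i^{(n)} \mid z_{i-1}^{(n)}, A) \Big\}$$

     Graphical model: $z_1^{(n)} \to z_2^{(n)} \to \cdots \to z_T^{(n)}$, where each $z_i^{(n)}$ emits $x_i^{(n)}$; $\pi$ governs the initial state, $A$ the transitions, and $\theta$ the emissions.

     Init state: $z_1 \sim \mathrm{Cat}(z_1 \mid \pi)$, $\pi \sim \mathrm{Dir}(\pi \mid \alpha)$
     Transition: $z_i \sim \prod_{k=1}^{K} \mathrm{Cat}(z_i \mid A_{:,k})^{z_{i-1,k}}$, $A_{:,k} \sim \mathrm{Dir}(A_{:,k} \mid \beta_{:,k})$
     Emitter: $x_i \sim \prod_{k=1}^{K} \mathrm{Bern}(x_i \mid \theta_k)^{z_{i,k}}$, $\theta_k \sim \mathrm{Beta}(\theta_k \mid a, b)$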
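
The generative process of the HMM slide can be read as ancestral sampling: draw the parameters from their priors, then the states and observations in order. The following is a minimal sketch, not code from the talk. The function names, the symmetric hyperparameters, and the choice of a `D`-dimensional binary observation per step (one Bernoulli parameter per state and dimension) are assumptions made for illustration.

```python
import random

def sample_dirichlet(alpha, rng):
    """Draw from Dirichlet(alpha) by normalising independent Gamma draws."""
    draws = [rng.gammavariate(a, 1.0) for a in alpha]
    total = sum(draws)
    return [d / total for d in draws]

def sample_categorical(probs, rng):
    """Return index k with probability probs[k]."""
    return rng.choices(range(len(probs)), weights=probs, k=1)[0]

def sample_hmm(K, D, T, rng, alpha=1.0, beta=1.0, a=1.0, b=1.0):
    """Ancestral sampling of one length-T sequence from a Bernoulli-emission HMM."""
    pi = sample_dirichlet([alpha] * K, rng)                    # pi ~ Dir(alpha)
    A = [sample_dirichlet([beta] * K, rng) for _ in range(K)]  # A[k] plays the role of A_{:,k}
    theta = [[rng.betavariate(a, b) for _ in range(D)]         # theta_k ~ Beta(a, b)
             for _ in range(K)]
    zs, xs = [], []
    z = sample_categorical(pi, rng)                            # z_1 ~ Cat(pi)
    for _ in range(T):
        zs.append(z)
        xs.append([1 if rng.random() < p else 0 for p in theta[z]])  # x_i ~ Bern(theta_{z_i})
        z = sample_categorical(A[z], rng)                      # z_{i+1} ~ Cat(A_{:, z_i})
    return zs, xs

states, observations = sample_hmm(K=3, D=4, T=5, rng=random.Random(0))
```

Here the one-hot vector $z_i$ is represented by the integer index of its active component, which is why the products over $k$ in the slide reduce to a single lookup.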
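  5. Deep Markov Model. Consider observed data $X = \{x^{(1)}, \cdots, x^{(N)}\}$ and latent variables $Z = \{z^{(1)}, \cdots, z^{(N)}\}$. Below, the superscript is omitted and $z^{(n)}$ is written as $z$.

     $$p(X, Z, \psi, \xi) = \prod_{n=1}^{N} \Big\{ p(z_0^{(n)}) \prod_{i=1}^{T} p(z_i \mid \mathrm{MLP}_{\mathrm{trans}}(z_{i-1}; \psi)) \times p(x_i \mid \mathrm{MLP}_{\mathrm{emit}}(z_i; \xi)) \Big\}$$

     Graphical model: $z_0^{(n)} \to z_1^{(n)} \to \cdots \to z_T^{(n)}$ through the Gated Transition, with each $z_i^{(n)}$ producing $x_i^{(n)}$ through the Emitter; both are MLPs.

     By introducing MLPs into the conditional probabilities, the DMM can express richer dependencies than the HMM.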
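  6. Emitter ($z_i \to x_i$): the probability of generating the observation $x_i$ from the $i$-th latent variable $z_i$, with $\mathrm{MLP}_{\mathrm{emit}}(z_i) = \mu_i$.

     $$h_i^{(1)} = \mathrm{ReLU}(W^{(1)} z_i + b^{(1)})$$
     $$h_i^{(2)} = \mathrm{ReLU}(W^{(2)} h_i^{(1)} + b^{(2)})$$
     $$\mu_i = \mathrm{Sigmoid}(W^{(3)} h_i^{(2)} + b^{(3)})$$
     $$x_i \sim \mathrm{Bern}(x_i \mid \mu_i)$$

     The distribution that $x_i$ follows depends on the data being modeled.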
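
To make the Emitter concrete, here is a minimal forward-pass sketch in plain Python. It is not the Pyro tutorial's module. The helper names (`linear`, `init_linear`, `emitter`), the layer sizes, and the random initialisation are assumptions for illustration, and the biases are placed inside the activations, as in a standard MLP layer.

```python
import math
import random

def linear(W, b, v):
    """Affine map W v + b, with W stored as a list of rows."""
    return [sum(w * x for w, x in zip(row, v)) + bias for row, bias in zip(W, b)]

def relu(v):
    return [max(0.0, x) for x in v]

def sigmoid(v):
    return [1.0 / (1.0 + math.exp(-x)) for x in v]

def init_linear(n_out, n_in, rng):
    """Small random weights and zero bias (an arbitrary choice for this demo)."""
    W = [[rng.gauss(0.0, 0.1) for _ in range(n_in)] for _ in range(n_out)]
    return W, [0.0] * n_out

def emitter(z, layers):
    """Map a latent vector z_i to Bernoulli means mu_i, each in (0, 1)."""
    (W1, b1), (W2, b2), (W3, b3) = layers
    h1 = relu(linear(W1, b1, z))         # h^(1) = ReLU(W^(1) z + b^(1))
    h2 = relu(linear(W2, b2, h1))        # h^(2) = ReLU(W^(2) h^(1) + b^(2))
    return sigmoid(linear(W3, b3, h2))   # mu = Sigmoid(W^(3) h^(2) + b^(3))

rng = random.Random(0)
z_dim, hidden_dim, x_dim = 4, 8, 88      # hypothetical sizes
layers = [init_linear(hidden_dim, z_dim, rng),
          init_linear(hidden_dim, hidden_dim, rng),
          init_linear(x_dim, hidden_dim, rng)]
z_i = [rng.gauss(0.0, 1.0) for _ in range(z_dim)]
mu_i = emitter(z_i, layers)
x_i = [1 if rng.random() < m else 0 for m in mu_i]  # x_i ~ Bern(mu_i)
```

The final sigmoid is what keeps every entry of `mu_i` a valid Bernoulli parameter. For other kinds of data the last layer and the sampling line would change accordingly.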
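  7. Gated Transition ($z_{i-1} \to z_i$): the probability of generating the $i$-th latent variable $z_i$ from the $(i-1)$-th latent variable $z_{i-1}$.

     $$z_i \sim \mathrm{Normal}(\mu_i, \sigma_i), \quad \mathrm{MLP}_{\mu_i}(z_{i-1}) = \mu_i, \quad \mathrm{MLP}_{\sigma_i}(z_{i-1}) = \sigma_i$$
     $$g_i^{(1)} = \mathrm{ReLU}(W_g^{(1)} z_{i-1} + b_g^{(1)})$$
     $$g_i^{(2)} = \mathrm{Sigmoid}(W_g^{(2)} g_i^{(1)} + b_g^{(2)})$$
     $$h_i^{(1)} = \mathrm{ReLU}(W_h^{(1)} z_{i-1} + b_h^{(1)})$$
     $$h_i^{(2)} = W_h^{(2)} h_i^{(1)} + b_h^{(2)}$$
     $$\mu_i = (1 - g_i^{(2)}) \odot (W_\mu z_{i-1} + b_\mu) + g_i^{(2)} \odot h_i^{(2)}$$
     $$\sigma_i = \mathrm{Softplus}(W_\sigma\, \mathrm{ReLU}(h_i^{(2)}) + b_\sigma)$$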
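
The gate $g_i^{(2)}$ interpolates, per dimension, between a linear update of $z_{i-1}$ and the nonlinear proposal $h_i^{(2)}$. The sketch below walks through the Gated Transition equations in plain Python. The helper functions, the dictionary keys, and the dimensions are hypothetical names chosen for this example, not identifiers from the talk or from Pyro.

```python
import math
import random

def linear(W, b, v):
    """Affine map W v + b, with W stored as a list of rows."""
    return [sum(w * x for w, x in zip(row, v)) + bias for row, bias in zip(W, b)]

def relu(v):
    return [max(0.0, x) for x in v]

def sigmoid(v):
    return [1.0 / (1.0 + math.exp(-x)) for x in v]

def softplus(v):
    """Numerically stable log(1 + exp(x)); always strictly positive."""
    return [max(x, 0.0) + math.log1p(math.exp(-abs(x))) for x in v]

def init_linear(n_out, n_in, rng):
    W = [[rng.gauss(0.0, 0.1) for _ in range(n_in)] for _ in range(n_out)]
    return W, [0.0] * n_out

def gated_transition(z_prev, p):
    """Return (mu_i, sigma_i) of Normal(z_i | mu_i, sigma_i) given z_{i-1}."""
    g1 = relu(linear(*p["gate_in"], z_prev))
    g2 = sigmoid(linear(*p["gate_out"], g1))       # gate in (0, 1)
    h1 = relu(linear(*p["prop_in"], z_prev))
    h2 = linear(*p["prop_out"], h1)                # nonlinear proposed mean
    lin = linear(*p["z_to_mu"], z_prev)            # linear proposed mean
    mu = [(1.0 - g) * l + g * h for g, l, h in zip(g2, lin, h2)]
    sigma = softplus(linear(*p["sigma"], relu(h2)))
    return mu, sigma

rng = random.Random(0)
z_dim, hidden_dim = 4, 8                           # hypothetical sizes
params = {
    "gate_in": init_linear(hidden_dim, z_dim, rng),
    "gate_out": init_linear(z_dim, hidden_dim, rng),
    "prop_in": init_linear(hidden_dim, z_dim, rng),
    "prop_out": init_linear(z_dim, hidden_dim, rng),
    "z_to_mu": init_linear(z_dim, z_dim, rng),
    "sigma": init_linear(z_dim, z_dim, rng),
}
z_prev = [rng.gauss(0.0, 1.0) for _ in range(z_dim)]
mu, sigma = gated_transition(z_prev, params)
z_next = [rng.gauss(m, s) for m, s in zip(mu, sigma)]  # z_i ~ Normal(mu_i, sigma_i)
```

When the gate is close to 0 the transition is nearly linear in $z_{i-1}$, and when it is close to 1 the mean comes entirely from the MLP proposal.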
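  8. Approximate distribution: Guide. A Gaussian is assumed for the approximate distribution $q(\cdot)$ of the latent variables $Z = \{z^{(1)}, \cdots, z^{(N)}\}$.

     $$q(Z \mid X, \psi) = \prod_{n=1}^{N} \mathrm{Normal}(z^{(n)} \mid x^{(n)}, \psi)$$

     Amortized inference is applied: the variational parameters are obtained by regression with a function $f(\cdot)$.

     $$q(Z \mid X, \psi) = \prod_{n=1}^{N} \mathrm{Normal}(z^{(n)} \mid f(x^{(n)}; \psi))$$

     This $f(\cdot)$ is called the Combiner.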
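  9. Approximate distribution: Guide. Diagram: an RNN reads $x_1, x_2, \cdots, x_T$ and produces hidden states $h_1, h_2, \cdots, h_T$; at each step a Combiner takes $z_{i-1}$ (starting from $z_0^q$) and $h_i$ and outputs $(\mu_i, \sigma_i)$, from which $z_i$ is sampled.

     $$z_i \sim \mathrm{Normal}(\mu_i, \sigma_i), \quad \mathrm{MLP}_{\mu_i}(z_{i-1}) = \mu_i, \quad \mathrm{MLP}_{\sigma_i}(z_{i-1}) = \sigma_i$$
     $$h_i^{\mathrm{combined}} = \frac{1}{2} \{ \tanh(W_h z_{i-1} + b^{(1)}) + h_i^{\mathrm{rnn}} \}$$
     $$\mu_i = W_\mu h_i^{\mathrm{combined}} + b_\mu$$
     $$\sigma_i = \mathrm{Softplus}(W_\sigma h_i^{\mathrm{combined}} + b_\sigma)$$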
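
One Combiner step can be sketched as follows, assuming the RNN hidden state for step $i$ has already been computed and is passed in as a plain list. The helper names, the parameter dictionary, and the sizes are assumptions for illustration. The sample is drawn in reparameterized form, $z_i = \mu_i + \sigma_i \epsilon$, which is a common way to sample from such a Gaussian guide rather than something stated on the slide.

```python
import math
import random

def linear(W, b, v):
    """Affine map W v + b, with W stored as a list of rows."""
    return [sum(w * x for w, x in zip(row, v)) + bias for row, bias in zip(W, b)]

def softplus(v):
    """Numerically stable log(1 + exp(x)); always strictly positive."""
    return [max(x, 0.0) + math.log1p(math.exp(-abs(x))) for x in v]

def init_linear(n_out, n_in, rng):
    W = [[rng.gauss(0.0, 0.1) for _ in range(n_in)] for _ in range(n_out)]
    return W, [0.0] * n_out

def combiner(z_prev, h_rnn, p):
    """Combine z_{i-1} with the RNN state h_i into (mu_i, sigma_i)."""
    pre = linear(*p["z_to_hidden"], z_prev)
    h_combined = [0.5 * (math.tanh(a) + r) for a, r in zip(pre, h_rnn)]
    mu = linear(*p["mu"], h_combined)
    sigma = softplus(linear(*p["sigma"], h_combined))
    return mu, sigma

rng = random.Random(0)
z_dim, rnn_dim = 4, 8                                  # hypothetical sizes
params = {
    "z_to_hidden": init_linear(rnn_dim, z_dim, rng),
    "mu": init_linear(z_dim, rnn_dim, rng),
    "sigma": init_linear(z_dim, rnn_dim, rng),
}
z_prev = [0.0] * z_dim                                 # stand-in for z_0^q
h_rnn = [rng.gauss(0.0, 1.0) for _ in range(rnn_dim)]  # stand-in for an RNN output
mu, sigma = combiner(z_prev, h_rnn, params)
eps = [rng.gauss(0.0, 1.0) for _ in range(z_dim)]
z_i = [m + s * e for m, s, e in zip(mu, sigma, eps)]   # reparameterized sample
```

Because the same `params` are reused at every step and for every sequence, the variational parameters are amortized: nothing is stored per data point.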
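  10. Data used
     • Music data in which each piece consists of multiple notes.
     • Each piece is split at every quarter note, and the notes sounding in each segment are encoded as a vector $x \in \mathbb{R}^{88}$ (one entry per key of an 88-key keyboard), which is the model input. Example: x = {0,0,1,…,1,0}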
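
The encoding described on this slide amounts to a multi-hot vector per time step. A minimal sketch, assuming keys are indexed 0 to 87 from the lowest key (the indexing convention and the function names are assumptions, not taken from the dataset):

```python
NUM_KEYS = 88  # keys on the keyboard, one vector entry per key

def encode_timestep(active_keys):
    """Return an 88-dimensional 0/1 vector with ones at the sounding keys."""
    x = [0] * NUM_KEYS
    for k in active_keys:
        if not 0 <= k < NUM_KEYS:
            raise ValueError(f"key index out of range: {k}")
        x[k] = 1
    return x

def encode_sequence(chords):
    """Encode a piece given as a list of per-step collections of key indices."""
    return [encode_timestep(chord) for chord in chords]

# Three quarter-note steps: a three-note chord, a rest, and a single note.
piece = encode_sequence([[39, 43, 46], [], [51]])
```

Each entry is binary, which is why a Bernoulli likelihood per key is a natural emission distribution for this data.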
  11. Experimental results

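  12. Impressions
     • Because the latent variables are continuous, the model felt closer to a Kalman filter than to an HMM.
     • The experiment itself does not feel like it went well. In particular, there is probably some cause behind the ELBO rising almost linearly.
     • PPLs are very interesting, but there is still less information available than for deep learning. (I would appreciate pointers to applications and state-of-the-art models.)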

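  13. References
     [1] Machine Learning Startup Series: Introduction to Machine Learning by Bayesian Inference, Atsushi Suyama, Masashi Sugiyama
     [2] Machine Learning Professional Series: Bayesian Deep Learning, Atsushi Suyama
     [3] Deep Learning for Natural Language Processing, Yoav Goldberg, et al.
     [4] Structured Inference Networks for Nonlinear State Space Models, Rahul G. Krishnan, Uri Shalit, David Sontag
     [5] Deep Markov Model (http://pyro.ai/examples/dmm.html)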