Understanding Automatic Differentiation through Functions and Types

lotz
November 09, 2019

Transcript

  1. Understanding Automatic Differentiation through Functions and Types
     2019/11/9, lotz, at Haskell Day 2019

  2. About me
     • I do Haskell and machine learning as a hobby
     • By day I work as an engineer in Hanzomon
     • Twitter: @lotz84_ (send any questions here)
     • GitHub: @lotz84
  3. Today's goals
     • Get a hands-on feel for automatic differentiation
     • Understand which Haskell techniques it relies on
     • Come away feeling that automatic differentiation is fun!

  4. Differentiable functional programming: Beyond Functional Programming

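  5. Sometimes you just want a derivative, right?
     • Physics simulation: you are given a potential function and want to study how the system evolves in time
     • Machine learning: you have built a model and want to estimate its parameters by gradient descent
     • Or you are simply handed a function and cannot rest until you have differentiated it!!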
  6. But nobody wants to do complicated derivative calculations by hand!
     $f(x) = \sigma(W_3\,\phi(W_2\,\phi(W_1 x + b_1) + b_2) + b_3)$
     $U = (m_1 + m_2)\,g\,l_1 (1 - \cos\theta_1) + m_2\,g\,l_2 (1 - \cos\theta_2)$
  7. Leave the boring stuff to Haskell

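  8. Automatic differentiation
     A technique that takes a function defined as a program and derives its derivative, again as a function in the program
       diff:            \x -> x^2 + sin x      ↦   \x -> 2 * x + cos x
       $\frac{d}{dx}$:  $f(x) = x^2 + \sin x$   ↦   $f'(x) = 2x + \cos x$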
  9. > :m Numeric.AD Data.Number.Symbolic
     > diff (\x -> x^2 + sin x) 1
     2.5403023058681398
     > diff (\x -> x^2 + sin x) (var "x")
     x+x+cos x
     Packages: ad, numbers. Ref: https://twitter.com/GabrielG…/status/…
  10. The type of diff
      diff :: Num a
           => (forall s. AD s (Forward a) -> AD s (Forward a))  -- the function to differentiate
           -> (a -> a)                                          -- its derivative
      • Forward?
      • forall s. AD s (…)?
  11. Understanding by implementing

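  12. The chain rule (derivative of a composite function)
      $\{f(g(x))\}' = f'(g(x)) \cdot g'(x)$
      $f(x) = \exp(\sin(x^2))$
      $f'(x) = \exp(\sin(x^2)) \cdot \cos(x^2) \cdot 2x$
      (the slide labels the factors of this product as $f'(g(x))$ and $g'(x)$)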
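
The chain-rule result above can be sanity-checked numerically. The following is a minimal sketch, not part of the original slides: the names `f`, `f'` and `centralDiff`, and the step size `h`, are chosen here for illustration only.

```haskell
-- The example function from the slide: f(x) = exp (sin (x^2)).
f :: Double -> Double
f x = exp (sin (x ^ (2 :: Int)))

-- Its derivative, written out by hand with the chain rule.
f' :: Double -> Double
f' x = exp (sin (x ^ (2 :: Int))) * cos (x ^ (2 :: Int)) * 2 * x

-- Central finite difference, used only as an approximate reference value.
-- The step size h is an arbitrary choice for this sketch.
centralDiff :: (Double -> Double) -> Double -> Double
centralDiff g x = (g (x + h) - g (x - h)) / (2 * h)
  where
    h = 1e-5
```

In GHCi, `f' 1` and `centralDiff f 1` should agree to several decimal places. The finite difference is only an approximation, which is one reason to prefer automatic differentiation.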
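  13. Modes and computational cost
      $f :: \mathbb{R}^n \to \mathbb{R}^m$
      • For multivariate, vector-valued functions the two modes differ in efficiency
      • When $n < m$, forward mode is more efficient
      • When $n > m$, reverse mode is more efficient
      • In machine learning in particular, $n > m$ is the common case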
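
A brief sketch of where the gap comes from, assuming $f$ is a composition of $k$ differentiable steps. The symbols $f_i$, $J_i$, $v$ and $u$ are introduced here for illustration and do not appear on the slide.

```latex
f = f_k \circ \cdots \circ f_1,
\qquad
J_f = J_k \, J_{k-1} \cdots J_1 \in \mathbb{R}^{m \times n}

% forward mode: push one input direction v through the chain
J_f \, v = J_k \bigl( \cdots ( J_1 v ) \bigr), \qquad v \in \mathbb{R}^n

% reverse mode: pull one output weighting u back through the chain
u^\top J_f = \bigl( ( u^\top J_k ) \cdots \bigr) J_1, \qquad u \in \mathbb{R}^m
```

One forward pass yields one column of $J_f$, so the full Jacobian takes $n$ passes. One reverse pass yields one row, so it takes $m$ passes. A scalar loss over many parameters ($m = 1$, large $n$) therefore favours reverse mode.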
  14. Implementation strategies
      • Operator overloading
      • Source code transformation
        • Analyze the program directly and generate code that implements the derivative

  15. Dual numbers
      $a + b\epsilon,\quad \epsilon^2 = 0$
      Numbers with a special element that becomes 0 when multiplied by itself
      data D a = D a a
      real, tangent :: D a -> a
      real    (D a _) = a
      tangent (D _ b) = b
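
A short worked example of why the $\epsilon$ coefficient ends up carrying the derivative. This is a sketch for polynomials only, and the symbol $p$ is introduced here for illustration.

```latex
(a + \epsilon)^2 = a^2 + 2a\,\epsilon + \epsilon^2 = a^2 + 2a\,\epsilon

% more generally, for a polynomial p, every term of order \epsilon^2 or higher vanishes:
p(a + b\epsilon) = p(a) + p'(a)\, b\, \epsilon
```

Evaluating at $a + 1\cdot\epsilon$ and reading off the coefficient of $\epsilon$ gives $p'(a)$.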
  16. Leibniz rule: $\{f(x)g(x)\}' = f'(x)g(x) + f(x)g'(x)$
      instance Num a => Num (D a) where
        D x x' + D y y' = D (x + y) (x' + y')
        D x x' * D y y' = D (x * y) (x' * y + x * y')
        negate (D x x') = D (negate x) (negate x')
        abs    (D x x') = D (abs x) (x' * signum x)
        signum (D x x') = D (signum x) 0
        fromInteger n   = D (fromInteger n) 0
      instance Fractional a => Fractional (D a) where
        recip (D x x') = D (recip x) (-1 * x' * recip (x * x))
        fromRational x = D (fromRational x) 0
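  17. $(a + b\epsilon)(c + d\epsilon) = ac + (ad + bc)\epsilon + bd\,\epsilon^2 = ac + (ad + bc)\epsilon$
      $\dfrac{1}{a + b\epsilon} = \dfrac{a - b\epsilon}{(a + b\epsilon)(a - b\epsilon)} = \dfrac{a - b\epsilon}{a^2 - b^2\epsilon^2} = \dfrac{a - b\epsilon}{a^2} = \dfrac{1}{a} - \dfrac{b}{a^2}\,\epsilon$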
  18. instance Floating a => Floating (D a) where
        pi = D pi 0
        exp   (D x x') = D (exp x)   (x' * exp x)
        log   (D x x') = D (log x)   (x' / x)
        sin   (D x x') = D (sin x)   (x' * cos x)
        cos   (D x x') = D (cos x)   (- x' * sin x)
        asin  (D x x') = D (asin x)  (x' / sqrt (1 - x ** 2))
        acos  (D x x') = D (acos x)  (- x' / sqrt (1 - x ** 2))
        atan  (D x x') = D (atan x)  (x' / (1 + x ** 2))
        sinh  (D x x') = D (sinh x)  (x' * cosh x)
        cosh  (D x x') = D (cosh x)  (x' * sinh x)
        asinh (D x x') = D (asinh x) (x' / sqrt (1 + x ** 2))
        acosh (D x x') = D (acosh x) (x' / sqrt (x ** 2 - 1))
        atanh (D x x') = D (atanh x) (x' / (1 - x ** 2))
  19. Implementing diff
      lift :: Num a => a -> D a
      lift x = D x 0
      infinitesimal :: Num a => D a
      infinitesimal = D 0 1
      diffD :: Num a => (D a -> D a) -> a -> a
      diffD f x = tangent $ f (lift x + infinitesimal)
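
For readers who want to try this, the pieces from the slides on dual numbers can be collected into one loadable file. The following is a sketch assembled from the definitions shown in the slides; the only additions are the comments and the `deriving Show` clause.

```haskell
-- Dual number: a real part and a tangent part (the coefficient of epsilon).
data D a = D a a deriving Show

real, tangent :: D a -> a
real    (D a _) = a
tangent (D _ b) = b

instance Num a => Num (D a) where
  D x x' + D y y' = D (x + y) (x' + y')
  D x x' * D y y' = D (x * y) (x' * y + x * y')   -- Leibniz rule
  negate (D x x') = D (negate x) (negate x')
  abs    (D x x') = D (abs x) (x' * signum x)
  signum (D x _)  = D (signum x) 0
  fromInteger n   = D (fromInteger n) 0

instance Fractional a => Fractional (D a) where
  recip (D x x') = D (recip x) (negate x' * recip (x * x))
  fromRational x = D (fromRational x) 0

instance Floating a => Floating (D a) where
  pi = D pi 0
  exp   (D x x') = D (exp x)   (x' * exp x)
  log   (D x x') = D (log x)   (x' / x)
  sin   (D x x') = D (sin x)   (x' * cos x)
  cos   (D x x') = D (cos x)   (negate (x' * sin x))
  asin  (D x x') = D (asin x)  (x' / sqrt (1 - x ** 2))
  acos  (D x x') = D (acos x)  (negate x' / sqrt (1 - x ** 2))
  atan  (D x x') = D (atan x)  (x' / (1 + x ** 2))
  sinh  (D x x') = D (sinh x)  (x' * cosh x)
  cosh  (D x x') = D (cosh x)  (x' * sinh x)
  asinh (D x x') = D (asinh x) (x' / sqrt (1 + x ** 2))
  acosh (D x x') = D (acosh x) (x' / sqrt (x ** 2 - 1))
  atanh (D x x') = D (atanh x) (x' / (1 - x ** 2))

-- Embed a constant: its tangent is zero.
lift :: Num a => a -> D a
lift x = D x 0

-- The unit perturbation epsilon.
infinitesimal :: Num a => D a
infinitesimal = D 0 1

-- Forward-mode derivative: evaluate f at x + epsilon and read off the tangent.
diffD :: Num a => (D a -> D a) -> a -> a
diffD f x = tangent $ f (lift x + infinitesimal)
```

Loading this in GHCi and evaluating `diffD (\x -> x^2 + sin x) 1` gives the value of `2 + cos 1`.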
  20. > diffD (\x -> x^2 + sin x) 1
      2.5403023058681398
      > diffD (\x -> x^2 + sin x) (var "x")
      x+x+cos x
  21. > diffD (\x -> x^2 + sin x) 1
        = tangent $ (\x -> x^2 + sin x) (lift 1 + infinitesimal)
        = tangent $ (\x -> x^2 + sin x) (D 1 1)
        = tangent $ (D 1 1)^2 + sin (D 1 1)
        = tangent $ D 1 2 + D (sin 1) (cos 1)
        = tangent $ D (1 + sin 1) (2 + cos 1)
        = 2 + cos 1
      2.5403023058681398
  22. The type of diff
      diffD :: Num a => (D a -> D a) -> a -> a
      diff  :: Num a => (forall s. AD s (Forward a) -> AD s (Forward a)) -> (a -> a)
      • ✅ Forward
      • forall s. AD s (…)?
  23. The problem with diffD
      $\left.\dfrac{d}{dy}\left(\left.\dfrac{d\,(x + y)^3}{dx}\right|_{x=1}\right)\right|_{y=1}$
      > diffD (\y -> diffD (\x -> (x + y)^3) 1) 1
      error:
          • Occurs check: cannot construct the infinite type: a ~ D a
  24. The problem with diffD
      > diffD (\y -> lift $ diffD (\x -> (x + y)^3) 1) 1
      0
      ↑ this version conflates dx and dy
      > diffD (\y -> diffD (\x -> (x + lift y)^3) (lift 1)) 1
      12
      Which one is correct? The problem is that the wrong version type-checks too
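
The two nested calls can be reproduced with a stripped-down version of the dual-number code. This is a minimal sketch: only the `Num` instance is included because the example needs nothing else, and the names `wrong` and `right` are introduced here for illustration.

```haskell
data D a = D a a

tangent :: D a -> a
tangent (D _ b) = b

instance Num a => Num (D a) where
  D x x' + D y y' = D (x + y) (x' + y')
  D x x' * D y y' = D (x * y) (x' * y + x * y')
  negate (D x x') = D (negate x) (negate x')
  abs    (D x x') = D (abs x) (x' * signum x)
  signum (D x _)  = D (signum x) 0
  fromInteger n   = D (fromInteger n) 0

lift :: Num a => a -> D a
lift x = D x 0

diffD :: Num a => (D a -> D a) -> a -> a
diffD f x = tangent (f (lift x + D 0 1))

-- The inner and outer perturbations share one epsilon, so the outer
-- derivative only ever sees a lifted constant.
wrong :: Double
wrong = diffD (\y -> lift (diffD (\x -> (x + y) ^ (3 :: Int)) 1)) 1

-- The inner derivative runs over D (D Double), keeping the two epsilons apart.
right :: Double
right = diffD (\y -> diffD (\x -> (x + lift y) ^ (3 :: Int)) (lift 1)) 1
```

Working the math by hand, $\frac{d}{dx}(x+y)^3 = 3(x+y)^2$, which is $3(1+y)^2$ at $x = 1$. Its derivative in $y$ is $6(1+y)$, which is 12 at $y = 1$, so `right` matches the mathematics and `wrong` does not.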
  25. Phantom types
      newtype AD s a = AD {unAD :: a}
      instance Num a => Num (AD s a) where ...
      instance Fractional a => Fractional (AD s a) where ...
      instance Floating a => Floating (AD s a) where ...
      > (1 :: AD Bool Int) + (2 :: AD Bool Int)
      AD {unAD = 3}
      > (1 :: AD Bool Int) + (2 :: AD Char Int)
      error:
          • Couldn't match type ‘Char’ with ‘Bool’
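
A compact way to experiment with the phantom tag is sketched below. Note the assumption: instead of writing the instances by hand as on the slide, this sketch derives them with the `GeneralizedNewtypeDeriving` extension. The name `sameTag` is hypothetical.

```haskell
{-# LANGUAGE GeneralizedNewtypeDeriving #-}

-- The type parameter s never appears on the right-hand side: it is a phantom tag.
newtype AD s a = AD { unAD :: a }
  deriving (Show, Eq, Num)

-- Both operands carry the same tag, so this type-checks.
sameTag :: AD Bool Int
sameTag = (1 :: AD Bool Int) + (2 :: AD Bool Int)

-- Mixing tags is rejected at compile time, for example:
--   (1 :: AD Bool Int) + (2 :: AD Char Int)
```

The tag costs nothing at runtime, since `AD` is a newtype. It exists only so the type checker can tell values from different differentiation contexts apart.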
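  26. Existential types
      liftAD :: Num a => a -> AD s (D a)
      liftAD = AD . lift
      diffAD :: Num a => (forall s. AD s (D a) -> AD s (D a)) -> (a -> a)
      diffAD f = diffD (unAD . f . AD)
      • The first argument of diffAD has an existential-style type (strictly speaking it is a rank-2 polymorphic type, since the forall s. sits inside the argument)
      • From the outside, there is no way to tell which concrete type s is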
  27. Trying out diffAD
      > diffAD (\y -> liftAD $ diffAD (\x -> (x + y)^2) 1) 1
      error:
          • Couldn't match type ‘s1’ with ‘s’
      ↑ the error is raised at x + y, where the two operands have types AD s1 (D a) and AD s (D a)
      > diffAD (\y -> diffAD (\x -> (x + liftAD y)^3) (liftAD 1)) 1
      12
      The wrong version no longer type-checks!
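
The tagged version can be tried out with the following sketch. Assumptions to note: the `Num` instance for `AD` is derived with `GeneralizedNewtypeDeriving` rather than written by hand, only the `Num` instance of `D` is included, and the name `nested` is introduced here for illustration. `RankNTypes` is needed for the `forall s.` in the argument of `diffAD`.

```haskell
{-# LANGUAGE RankNTypes, GeneralizedNewtypeDeriving #-}

data D a = D a a

tangent :: D a -> a
tangent (D _ b) = b

instance Num a => Num (D a) where
  D x x' + D y y' = D (x + y) (x' + y')
  D x x' * D y y' = D (x * y) (x' * y + x * y')
  negate (D x x') = D (negate x) (negate x')
  abs    (D x x') = D (abs x) (x' * signum x)
  signum (D x _)  = D (signum x) 0
  fromInteger n   = D (fromInteger n) 0

lift :: Num a => a -> D a
lift x = D x 0

diffD :: Num a => (D a -> D a) -> a -> a
diffD f x = tangent (f (lift x + D 0 1))

-- Phantom-tagged wrapper; s identifies one particular call to diffAD.
newtype AD s a = AD { unAD :: a }
  deriving (Num)

liftAD :: Num a => a -> AD s (D a)
liftAD = AD . lift

-- The argument must work for every tag s, so it cannot smuggle in a value
-- that was tagged by a different (outer or inner) call.
diffAD :: Num a => (forall s. AD s (D a) -> AD s (D a)) -> a -> a
diffAD f = diffD (unAD . f . AD)

-- The correctly lifted nested derivative from the slide.
nested :: Double
nested = diffAD (\y -> diffAD (\x -> (x + liftAD y) ^ (3 :: Int)) (liftAD 1)) 1
```

Replacing `x + liftAD y` with `x + y` in `nested` should make the file fail to compile with a type error relating the two tags, which is exactly the protection the slide describes.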
  28. The type of diff
      diff :: Num a => (forall s. AD s (Forward a) -> AD s (Forward a)) -> (a -> a)
      • ✅ Forward
      • ✅ forall s. AD s (…)
  29. Reverse mode and beyond

  30. diff in reverse mode
      diff :: Num a => (forall s. Reifies s Tape => Reverse s a -> Reverse s a) -> (a -> a)
      • Tape?
  31. Wengert List (Tape)
      $f(x) = \exp(\sin(x^2))$
      $f'(x) = \exp(\sin(x^2)) \cdot \cos(x^2) \cdot 2x$
      [ ("f", "exp", ["z2"])
      , ("z2", "sin", ["z1"])
      , ("z1", "square", ["x"])
      ]
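
To make the role of the tape concrete, here is a toy reverse-mode evaluator for the same function. It is a sketch only and is not how the ad package represents its tape. The `Op` type, the restriction to unary operations, and the names `forward`, `backward` and `diffTape` are all assumptions made for this illustration.

```haskell
import Data.Maybe (fromMaybe)

-- Primitive unary operations allowed on this toy tape (hypothetical set).
data Op = Square | Sin | Exp

type Name  = String
type Entry = (Name, Op, Name)   -- (result variable, operation, argument variable)
type Env   = [(Name, Double)]

-- Tape for f(x) = exp (sin (x^2)), listed output-first as on the slide.
tape :: [Entry]
tape =
  [ ("f",  Exp,    "z2")
  , ("z2", Sin,    "z1")
  , ("z1", Square, "x")
  ]

apply, deriv :: Op -> Double -> Double
apply Square v = v * v
apply Sin    v = sin v
apply Exp    v = exp v
deriv Square v = 2 * v          -- local derivative of each primitive
deriv Sin    v = cos v
deriv Exp    v = exp v

-- Look up a variable, treating missing entries as 0.
get :: Name -> Env -> Double
get n env = fromMaybe 0 (lookup n env)

-- Forward sweep: evaluate the entries input-first and record every value.
forward :: Double -> Env
forward x = foldl step [("x", x)] (reverse tape)
  where
    step env (out, op, arg) = (out, apply op (get arg env)) : env

-- Backward sweep: walk the tape output-first, accumulating adjoints df/dv.
backward :: Env -> Env
backward vals = foldl step [("f", 1)] tape
  where
    step adj (out, op, arg) =
      let contribution = get out adj * deriv op (get arg vals)
      in (arg, get arg adj + contribution) : filter ((/= arg) . fst) adj

diffTape :: Double -> Double
diffTape x = get "x" (backward (forward x))
```

The backward sweep multiplies the local derivatives in the order $\exp$, $\sin$, square, which is the left-to-right order of the factors in $f'(x)$ above.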
  32. unsafePerformIO
      binarily f di dj i b j c =
        Reverse (unsafePerformIO (modifyTape (Proxy :: Proxy s) (bin i j di dj))) $! f b c
      partials (Reverse k _) = map (sensitivities !) [0..vs] where
        Head n t = unsafePerformIO $ readIORef (getTape (reflect (Proxy :: Proxy s)))
        ...
      Side effects occur both when the tape is built and when it is run
  33. Smooth lambdas and their category
      • Conal Elliott, “The simple essence of automatic differentiation”, 2018.
        (Realizes the reverse-direction computation by using continuation-passing style)
      • Fei Wang, et al., “Demystifying Differentiable Programming: Shift/Reset the Penultimate Backpropagator”, 2018.
      • Alois Brunel, et al., “Backpropagation in the Simply Typed Lambda-calculus with Linear Negation”, 2019.
        (Uses linear types to guarantee that derivatives are not evaluated redundantly)
      • Robin Cockett, et al., “Reverse derivative categories”, 2019.
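  34. Summary
      • Automatic differentiation, which lets you differentiate functions written as programs, is amazing
      • It is still actively developing, on both the applied and the foundational side
      • Please try playing with automatic differentiation yourselves!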

  35. Thank you for listening

  36. Q&A