lotz
November 09, 2019

Understanding Automatic Differentiation Through Functions and Types

Transcript

1. Understanding Automatic Differentiation Through Functions and Types. 2019/11/9, lotz, at Haskell Day 2019

• GitHub: @lotz84

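5. Sometimes you just want a derivative, right?
• Physics simulation: you are given a potential function and want to study the time evolution of the system
• Machine learning: you have built a learning model and want to estimate its parameters with a gradient method
• Or you were simply handed a function and cannot rest until you have differentiated it!!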
6. But I don't want to do complicated derivative calculations!
$f(x) = \sigma(W_3\,\phi(W_2\,\phi(W_1 x + b_1) + b_2) + b_3)$
$U = (m_1 + m_2)\,g\,l_1 (1 - \cos\theta_1) + m_2\,g\,l_2 (1 - \cos\theta_2)$

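8. Automatic differentiation
A technique that takes a function defined as a program and derives its derivative, again as a function in the program.
    diff:  \x -> x^2 + sin x   ↦   \x -> 2 * x + cos x
    d/dx:  $f(x) = x^2 + \sin x$   ↦   $f'(x) = 2x + \cos x$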
9.
    > :m Numeric.AD Data.Number.Symbolic
    > diff (\x -> x^2 + sin x) 1
    2.5403023058681398
    > diff (\x -> x^2 + sin x) (var "x")
    x+x+cos x
Packages: ad, numbers. Ref: https://twitter.com/GabrielG…/status/…
10. The type of diff
    diff :: Num a => (forall s. AD s (Forward a) -> AD s (Forward a)) -> (a -> a)
The argument is the function to differentiate; the result is its derivative.
• Forward?
• forall s. AD s (…)?

12. Differentiating composite functions (the chain rule)
$\{f(g(x))\}' = f'(g(x)) \cdot g'(x)$
$f(x) = \exp(\sin(x^2))$
$f'(x) = \exp(\sin(x^2)) \cdot \cos(x^2) \cdot 2x$
Labels on the factors: $f'(g(x))$, $g'(x)$
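13. Modes and computational cost
For a function $f :: \mathbb{R}^n \to \mathbb{R}^m$:
• For multi-valued functions of several variables, the two modes differ in computational efficiency
• When $n < m$, forward mode is more efficient
• When $n > m$, reverse mode is more efficient
• In machine learning in particular, $n > m$ is the common case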
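One way to see where the efficiency gap comes from is to write the chain rule in terms of Jacobians. The decomposition $f = f_k \circ \cdots \circ f_1$ and the symbols $J_i$, $v$, $u$ below are illustrative notation, not taken from the slides.

```latex
J_f = J_k J_{k-1} \cdots J_1 \in \mathbb{R}^{m \times n}

% Forward mode: push one input direction v through, right to left.
J_f\, v = J_k \bigl( \cdots (J_2 (J_1 v)) \bigr), \qquad v \in \mathbb{R}^n

% Reverse mode: pull one output direction u back, left to right.
u^{\top} J_f = \bigl( (u^{\top} J_k)\, J_{k-1} \cdots \bigr) J_1, \qquad u \in \mathbb{R}^m
```

Recovering the full Jacobian takes $n$ forward passes (one per input basis vector) or $m$ reverse passes (one per output basis vector), which matches the rule of thumb above.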

15. Dual numbers
    data D a = D a a

    real, tangent :: D a -> a
    real    (D a _) = a
    tangent (D _ b) = b
$a + b\epsilon$ with $\epsilon^2 = 0$: numbers that have a special element which becomes 0 when multiplied by itself.
16.
    instance Num a => Num (D a) where
      D x x' + D y y' = D (x + y) (x' + y')
      D x x' * D y y' = D (x * y) (x' * y + x * y')
      negate (D x x') = D (negate x) (negate x')
      abs (D x x')    = D (abs x) (x' * (signum x))
      signum (D x x') = D (signum x) 0
      fromInteger n   = D (fromInteger n) 0

    instance Fractional a => Fractional (D a) where
      recip (D x x') = D (recip x) (-1 * x' * (recip (x * x)))
      fromRational x = D (fromRational x) 0
Leibniz rule: $\{f(x)g(x)\}' = f'(x)g(x) + f(x)g'(x)$
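17.
$(a + b\epsilon)(c + d\epsilon) = ac + (ad + bc)\epsilon + bd\epsilon^2 = ac + (ad + bc)\epsilon$
$\dfrac{1}{a + b\epsilon} = \dfrac{a - b\epsilon}{(a + b\epsilon)(a - b\epsilon)} = \dfrac{a - b\epsilon}{a^2 - b^2\epsilon^2} = \dfrac{a - b\epsilon}{a^2} = \dfrac{1}{a} - \dfrac{b}{a^2}\epsilon$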
18.
    instance Floating a => Floating (D a) where
      pi = D pi 0
      exp (D x x')   = D (exp x) (x' * exp x)
      log (D x x')   = D (log x) (x' / x)
      sin (D x x')   = D (sin x) (x' * cos x)
      cos (D x x')   = D (cos x) (- x' * sin x)
      asin (D x x')  = D (asin x) (x' / (sqrt(1 - x ** 2)))
      acos (D x x')  = D (acos x) (- x' / (sqrt(1 - x ** 2)))
      atan (D x x')  = D (atan x) (x' / (1 + x ** 2))
      sinh (D x x')  = D (sinh x) (x' * cosh x)
      cosh (D x x')  = D (cosh x) (x' * sinh x)
      asinh (D x x') = D (asinh x) (x' / (sqrt(1 + x ** 2)))
      acosh (D x x') = D (acosh x) (x' / (sqrt(x ** 2 - 1)))
      atanh (D x x') = D (atanh x) (x' / (1 - x ** 2))
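Every method in this instance follows one pattern, which can be read off from a Taylor expansion. The derivation below is a standard argument added for illustration (it assumes $f$ is smooth at $a$); it is not part of the slides.

```latex
f(a + b\epsilon)
  = f(a) + f'(a)\, b\epsilon + \frac{f''(a)}{2!}\, b^2 \epsilon^2 + \cdots
  = f(a) + b\, f'(a)\, \epsilon
  \qquad (\epsilon^2 = 0)

% Example: \sin(a + b\epsilon) = \sin a + b \cos a \,\epsilon,
% which corresponds to  sin (D x x') = D (sin x) (x' * cos x).
```

So the real part carries the value $f(a)$ and the tangent part carries $b\,f'(a)$, which is the chain rule applied to the incoming tangent $b$.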
19. Implementing diff
    lift :: Num a => a -> D a
    lift x = D x 0

    infinitesimal :: Num a => D a
    infinitesimal = D 0 1

    diffD :: Num a => (D a -> D a) -> a -> a
    diffD f x = tangent $ f (lift x + infinitesimal)
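The pieces from the dual-number slides can be assembled into one loadable module. The sketch below does that and checks `diffD` against the closed-form chain-rule derivative of exp (sin (x^2)); the names `chainExample` and `chainClosedForm` are hypothetical helpers introduced here, not from the talk.

```haskell
-- Forward-mode AD with dual numbers, assembled from the slide definitions.
data D a = D a a deriving Show

tangent :: D a -> a
tangent (D _ b) = b

instance Num a => Num (D a) where
  D x x' + D y y' = D (x + y) (x' + y')
  D x x' * D y y' = D (x * y) (x' * y + x * y')   -- Leibniz rule
  negate (D x x') = D (negate x) (negate x')
  abs (D x x')    = D (abs x) (x' * signum x)
  signum (D x _)  = D (signum x) 0
  fromInteger n   = D (fromInteger n) 0

instance Fractional a => Fractional (D a) where
  recip (D x x') = D (recip x) (negate x' * recip (x * x))
  fromRational r = D (fromRational r) 0

instance Floating a => Floating (D a) where
  pi = D pi 0
  exp (D x x')   = D (exp x) (x' * exp x)
  log (D x x')   = D (log x) (x' / x)
  sin (D x x')   = D (sin x) (x' * cos x)
  cos (D x x')   = D (cos x) (negate x' * sin x)
  asin (D x x')  = D (asin x) (x' / sqrt (1 - x ** 2))
  acos (D x x')  = D (acos x) (negate x' / sqrt (1 - x ** 2))
  atan (D x x')  = D (atan x) (x' / (1 + x ** 2))
  sinh (D x x')  = D (sinh x) (x' * cosh x)
  cosh (D x x')  = D (cosh x) (x' * sinh x)
  asinh (D x x') = D (asinh x) (x' / sqrt (1 + x ** 2))
  acosh (D x x') = D (acosh x) (x' / sqrt (x ** 2 - 1))
  atanh (D x x') = D (atanh x) (x' / (1 - x ** 2))

lift :: Num a => a -> D a
lift x = D x 0

infinitesimal :: Num a => D a
infinitesimal = D 0 1

-- Evaluate f at x + epsilon and read off the tangent component.
diffD :: Num a => (D a -> D a) -> a -> a
diffD f x = tangent $ f (lift x + infinitesimal)

-- Hypothetical helper: derivative of exp (sin (x^2)) computed by diffD.
chainExample :: Double -> Double
chainExample = diffD (\x -> exp (sin (x ^ 2)))

-- Hypothetical helper: the same derivative written out with the chain rule.
chainClosedForm :: Double -> Double
chainClosedForm x = exp (sin (x ^ 2)) * cos (x ^ 2) * 2 * x
```

Loaded into GHCi, `chainExample 0.5` and `chainClosedForm 0.5` should agree up to floating-point rounding.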
20.
    > diffD (\x -> x^2 + sin x) 1
    2.5403023058681398
    > diffD (\x -> x^2 + sin x) (var "x")
    x+x+cos x
21.
    > diffD (\x -> x^2 + sin x) 1
    = tangent $ (\x -> x^2 + sin x) (lift 1 + infinitesimal)
    = tangent $ (\x -> x^2 + sin x) (D 1 1)
    = tangent $ (D 1 1)^2 + sin (D 1 1)
    = tangent $ D 1 2 + D (sin 1) (cos 1)
    = tangent $ D (1 + sin 1) (2 + cos 1)
    = 2 + cos 1
    2.5403023058681398
22. The type of diff
    diffD :: Num a => (D a -> D a) -> a -> a
    diff  :: Num a => (forall s. AD s (Forward a) -> AD s (Forward a)) -> (a -> a)
• ✅ Forward
• forall s. AD s (…)?
23. The problem with diffD
Goal: $\left.\dfrac{d}{dy}\left(\left.\dfrac{d(x + y)^3}{dx}\right|_{x=1}\right)\right|_{y=1}$
    > diffD (\y -> diffD (\x -> (x + y)^3) 1) 1
    error:
        • Occurs check: cannot construct the infinite type: a ~ D a
24. The problem with diffD
    > diffD (\y -> lift $ diffD (\x -> (x + y)^3) 1) 1
    0
    > diffD (\y -> diffD (\x -> (x + lift y)^3) (lift 1)) 1
    12
Which one is correct?
The problem is that even the wrong version type-checks: it confuses dx and dy.
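Working the nested derivative out by hand settles the question. This is ordinary calculus, added here as a check rather than taken from the slides.

```latex
\left.\frac{d}{dx}(x + y)^3\right|_{x=1} = \left. 3(x + y)^2 \right|_{x=1} = 3(1 + y)^2

\left.\frac{d}{dy}\, 3(1 + y)^2\right|_{y=1} = \left. 6(1 + y) \right|_{y=1} = 12
```

So 12 is the mathematically correct value, and the variant that returns 0 is the one that mixes up the two perturbations.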

    instance Num a => Num (AD s a) where ...
    instance Fractional a => Fractional (AD s a) where ...
    instance Floating a => Floating (AD s a) where ...
Phantom type
    > (1 :: AD Bool Int) + (2 :: AD Bool Int)
    AD {unAD = 3}
    > (1 :: AD Bool Int) + (2 :: AD Char Int)
    error:
        • Couldn't match type ‘Char’ with ‘Bool’

Try running diffAD
    … + y)^2) 1) 1
    error:
        • Couldn't match type ‘s1’ with ‘s’
    ↑ The error occurs here: AD s (D a) vs. AD s1 (D a)
    > diffAD (\y -> diffAD (\x -> (x + liftAD y)^3) (liftAD 1)) 1
    12
The wrong version no longer type-checks!
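The slide that defines `diffAD` and `liftAD` is not in this transcript, so the following is a reconstruction under assumptions: `AD` is taken to be a newtype with a phantom parameter `s` (consistent with the `AD {unAD = 3}` output above), and `diffAD` is given a rank-2 type in the style of `diff`. The definitions of `diffAD`, `liftAD` and `nested` are guesses at the shape, not the author's code.

```haskell
{-# LANGUAGE RankNTypes #-}

-- Dual numbers, as on the earlier slides (Num instance only).
data D a = D a a

tangent :: D a -> a
tangent (D _ b) = b

instance Num a => Num (D a) where
  D x x' + D y y' = D (x + y) (x' + y')
  D x x' * D y y' = D (x * y) (x' * y + x * y')
  negate (D x x') = D (negate x) (negate x')
  abs (D x x')    = D (abs x) (x' * signum x)
  signum (D x _)  = D (signum x) 0
  fromInteger n   = D (fromInteger n) 0

-- Assumed shape: the phantom parameter s tags which diffAD call a value belongs to.
newtype AD s a = AD { unAD :: a }

instance Num a => Num (AD s a) where
  AD x + AD y   = AD (x + y)
  AD x * AD y   = AD (x * y)
  negate (AD x) = AD (negate x)
  abs (AD x)    = AD (abs x)
  signum (AD x) = AD (signum x)
  fromInteger   = AD . fromInteger

-- Assumed definition: embed a constant into the current differentiation scope.
liftAD :: Num a => a -> AD s (D a)
liftAD x = AD (D x 0)

-- Assumed definition: the forall keeps s from escaping or unifying with an outer s.
diffAD :: Num a => (forall s. AD s (D a) -> AD s (D a)) -> a -> a
diffAD f x = tangent (unAD (f (AD (D x 1))))

-- The well-typed nesting from the slide.
nested :: Double
nested = diffAD (\y -> diffAD (\x -> (x + liftAD y) ^ 3) (liftAD 1)) 1

-- Rejected by the type checker: the inner x and the outer y carry different tags.
-- bad = diffAD (\y -> diffAD (\x -> (x + y) ^ 3) 1) 1
```

The design choice mirrors `runST`: because `s` is universally quantified inside the argument, each call to `diffAD` gets its own tag, and values from different calls can only be combined through an explicit `liftAD`.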
28. The type of diff
    diff :: Num a => (forall s. AD s (Forward a) -> AD s (Forward a)) -> (a -> a)
• ✅ Forward
• ✅ forall s. AD s (…)

30. Reverse-mode diff
    diff :: Num a => (forall s. Reifies s Tape => Reverse s a -> Reverse s a) -> (a -> a)
• Tape?
31. Wengert List (Tape)
    [ ("f", "exp", ["z2"])
    , ("z2", "sin", ["z1"])
    , ("z1", "square", ["x"])
    ]
$f(x) = \exp(\sin(x^2))$
$f'(x) = \exp(\sin(x^2)) \cdot \cos(x^2) \cdot 2x$
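To make the role of the tape concrete, here is a minimal pure sketch for a chain of unary operations: a forward pass records each operation with its input, and a backward pass walks the record from the output towards the input, multiplying local derivatives into an adjoint. The `Op` type and the functions `apply`, `localDeriv`, `record` and `gradTape` are illustrative names, not the data structures of the ad library, and the tape here is written input-first rather than output-first as on the slide.

```haskell
-- Primitive operations that may appear on the tape (illustrative).
data Op = Square | Sin | Exp

-- Value of an operation at a point.
apply :: Op -> Double -> Double
apply Square v = v * v
apply Sin    v = sin v
apply Exp    v = exp v

-- Derivative of an operation with respect to its input.
localDeriv :: Op -> Double -> Double
localDeriv Square v = 2 * v
localDeriv Sin    v = cos v
localDeriv Exp    v = exp v

-- Forward pass: pair each operation with the input value it received.
record :: [Op] -> Double -> [(Op, Double)]
record ops x = zip ops (scanl (flip apply) x ops)

-- Backward pass: start with adjoint 1 at the output and fold towards the input,
-- so the last recorded operation is processed first.
gradTape :: [Op] -> Double -> Double
gradTape ops x = foldr step 1 (record ops x)
  where
    step (op, v) adjoint = adjoint * localDeriv op v
```

For example, `gradTape [Square, Sin, Exp] 1` computes the derivative of exp (sin (x^2)) at x = 1 by multiplying exp'(sin 1), then cos 1, then 2, in that order.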
32. unsafePerformIO
    binarily f di dj i b j c =
      Reverse (unsafePerformIO (modifyTape (Proxy :: Proxy s) (bin i j di dj))) $! f b c

    partials (Reverse k _) = map (sensitivities !) [0..vs]
      where
        Head n t = unsafePerformIO $ readIORef (getTape (reflect (Proxy :: Proxy s)))
        ...
Side effects occur both when the tape is built and when it is run.
33. Smooth lambdas and their categories
• Conal Elliott, “The simple essence of automatic differentiation”, 2018.
  (Realizes the backward computation by using continuation-passing style)
• Fei Wang, et al., “Demystifying Differentiable Programming: Shift/Reset the Penultimate Backpropagator”, 2018.
• Alois Brunel, et al., “Backpropagation in the Simply Typed Lambda-calculus with Linear Negation”, 2019.
  (Uses linear types to guarantee that derivatives are not evaluated redundantly)
• Robin Cockett, et al., “Reverse derivative categories”, 2019.