
Understanding Automatic Differentiation through Functions and Types (関数と型で理解する自動微分)

lotz
November 09, 2019

Transcript

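  1. But I don't want to work out complicated derivatives by hand!

    A three-layer neural network: $f(x) = \sigma(W_3\,\phi(W_2\,\phi(W_1 x + b_1) + b_2) + b_3)$
    The potential energy of a double pendulum: $U = (m_1 + m_2)\,g l_1 (1 - \cos\theta_1) + m_2\,g l_2 (1 - \cos\theta_2)$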
  2. The ad and numbers packages in GHCi:

    > :m Numeric.AD Data.Number.Symbolic
    > diff (\x -> x^2 + sin x) 1
    2.5403023058681398
    > diff (\x -> x^2 + sin x) (var "x")
    x+x+cos x

    (Ref: https://twitter.com/GabrielG…/status/…)
  3. The type of diff

    diff :: Num a => (forall s. AD s (Forward a) -> AD s (Forward a)) -> (a -> a)

    The argument is the function to differentiate; the result is its derivative.

    • Forward?
    • forall s. AD s (…)?
  4. The chain rule

    $\{f(g(x))\}' = f'(g(x)) \cdot g'(x)$

    $f(x) = \exp(\sin(x^2))$
    $f'(x) = \underbrace{\exp(\sin(x^2)) \cdot \cos(x^2)}_{f'(g(x))} \cdot \underbrace{2x}_{g'(x)}$
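Here is a minimal sketch that checks the hand-derived $f'(x)$ against a central finite difference. The names `f'` and `centralDiff`, the step `h = 1e-5`, and the test point are illustrative choices, not from the deck. Finite differences are only approximate, so any comparison needs a tolerance.

```haskell
-- f(x) = exp (sin (x^2)), the example on this slide
f :: Double -> Double
f x = exp (sin (x ^ 2))

-- Derivative worked out by hand with the chain rule:
-- f'(x) = exp (sin (x^2)) * cos (x^2) * 2x
f' :: Double -> Double
f' x = exp (sin (x ^ 2)) * cos (x ^ 2) * (2 * x)

-- Central finite difference, used only as an independent sanity check
-- (illustrative step size h = 1e-5)
centralDiff :: (Double -> Double) -> Double -> Double
centralDiff g x = (g (x + h) - g (x - h)) / (2 * h)
  where
    h = 1e-5
```

In GHCi, `f' 0.7` and `centralDiff f 0.7` agree to roughly ten significant digits.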
  5. Dual numbers

    data D a = D a a

    real, tangent :: D a -> a
    real    (D a _) = a
    tangent (D _ b) = b

    $a + b\epsilon$, $\epsilon^2 = 0$: numbers with a special element $\epsilon$ that becomes 0 when multiplied by itself.
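This is a minimal sketch of the $\epsilon^2 = 0$ rule. It represents dual numbers as plain `(real, tangent)` pairs of `Double`, a hypothetical stand-in for `D a`. Evaluating a polynomial at $x + \epsilon$ with Horner's rule then yields $p(x) + p'(x)\epsilon$. The names `Dual`, `addDual`, `mulDual`, and `hornerDual` are illustrative only.

```haskell
-- A dual number a + b*eps as a (real, tangent) pair, with eps^2 = 0
type Dual = (Double, Double)

addDual :: Dual -> Dual -> Dual
addDual (a, b) (c, d) = (a + c, b + d)

-- (a + b eps)(c + d eps) = ac + (ad + bc) eps; the bd eps^2 term vanishes
mulDual :: Dual -> Dual -> Dual
mulDual (a, b) (c, d) = (a * c, a * d + b * c)

-- Horner's rule with coefficients listed from the highest degree down.
-- Feeding in (x, 1), i.e. x + eps, returns (p(x), p'(x)).
hornerDual :: [Double] -> Dual -> Dual
hornerDual cs x = foldl step (0, 0) cs
  where
    step acc c = addDual (mulDual acc x) (c, 0)
```

For $p(x) = 2x^3 - 3x + 1$, `hornerDual [2, 0, -3, 1] (2, 1)` gives `(11.0,21.0)`: $p(2) = 11$ and $p'(2) = 6 \cdot 2^2 - 3 = 21$.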
  6. instance Num a => Num (D a) where
      D x x' + D y y' = D (x + y) (x' + y')
      D x x' * D y y' = D (x * y) (x' * y + x * y')
      negate (D x x') = D (negate x) (negate x')
      abs    (D x x') = D (abs x) (x' * (signum x))
      signum (D x x') = D (signum x) 0
      fromInteger n   = D (fromInteger n) 0

    instance Fractional a => Fractional (D a) where
      recip (D x x')   = D (recip x) (-1 * x' * (recip (x * x)))
      fromRational x   = D (fromRational x) 0

    Leibniz rule: $\{f(x)g(x)\}' = f'(x)g(x) + f(x)g'(x)$
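The instances can be exercised directly. This sketch collects `D`, `tangent`, and the `Num`/`Fractional` instances into one module loadable in GHCi, writing `-1 * x' * …` as `negate x' * …`. It then differentiates a hypothetical test function $h(x) = (x^2 + 1)/x$, whose derivative is $1 - 1/x^2$.

```haskell
-- Dual number D real tangent, as in the deck
data D a = D a a deriving Show

tangent :: D a -> a
tangent (D _ b) = b

instance Num a => Num (D a) where
  D x x' + D y y' = D (x + y) (x' + y')
  -- Leibniz (product) rule: (fg)' = f'g + fg'
  D x x' * D y y' = D (x * y) (x' * y + x * y')
  negate (D x x') = D (negate x) (negate x')
  abs (D x x')    = D (abs x) (x' * signum x)
  signum (D x _)  = D (signum x) 0
  fromInteger n   = D (fromInteger n) 0

instance Fractional a => Fractional (D a) where
  -- (1/g)' = -g' / g^2; the default (/) is x * recip y
  recip (D x x') = D (recip x) (negate x' * recip (x * x))
  fromRational r = D (fromRational r) 0

-- Hypothetical test function: h(x) = x + 1/x, so h'(x) = 1 - 1/x^2
h :: Fractional a => a -> a
h x = (x * x + 1) / x
```

Evaluating `h (D 2 1 :: D Double)` gives `D 2.5 0.75`. The real part is $h(2)$ and the tangent is $h'(2) = 1 - 1/4$.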
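  7. $(a + b\epsilon)(c + d\epsilon) = ac + (ad + bc)\epsilon + bd\epsilon^2 = ac + (ad + bc)\epsilon$

    $\dfrac{1}{a + b\epsilon} = \dfrac{a - b\epsilon}{(a + b\epsilon)(a - b\epsilon)} = \dfrac{a - b\epsilon}{a^2 - b^2\epsilon^2} = \dfrac{a - b\epsilon}{a^2} = \dfrac{1}{a} - \dfrac{b}{a^2}\epsilon$

    These are exactly the tangents computed by the * and recip definitions above.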
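Here is a worked numeric instance of the reciprocal rule with $a = 2$, $b = 3$, plus the check that it really inverts $2 + 3\epsilon$. It matches `recip (D 2 3) = D 0.5 (-0.75)` from the `Fractional` instance. Setting $b = 1$ recovers the familiar $(1/x)' = -1/x^2$.

```latex
\frac{1}{2 + 3\epsilon} = \frac{1}{2} - \frac{3}{2^2}\epsilon = \frac{1}{2} - \frac{3}{4}\epsilon

(2 + 3\epsilon)\left(\frac{1}{2} - \frac{3}{4}\epsilon\right)
  = 1 + \left(-\frac{3}{2} + \frac{3}{2}\right)\epsilon - \frac{9}{4}\epsilon^2 = 1

\frac{1}{x + \epsilon} = \frac{1}{x} - \frac{1}{x^2}\epsilon
```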
  8. instance Floating a => Floating (D a) where
      pi = D pi 0
      exp   (D x x') = D (exp x)   (x' * exp x)
      log   (D x x') = D (log x)   (x' / x)
      sin   (D x x') = D (sin x)   (x' * cos x)
      cos   (D x x') = D (cos x)   (- x' * sin x)
      asin  (D x x') = D (asin x)  (x' / (sqrt(1 - x ** 2)))
      acos  (D x x') = D (acos x)  (- x' / (sqrt(1 - x ** 2)))
      atan  (D x x') = D (atan x)  (x' / (1 + x ** 2))
      sinh  (D x x') = D (sinh x)  (x' * cosh x)
      cosh  (D x x') = D (cosh x)  (x' * sinh x)
      asinh (D x x') = D (asinh x) (x' / (sqrt(1 + x ** 2)))
      acosh (D x x') = D (acosh x) (x' / (sqrt(x ** 2 - 1)))
      atanh (D x x') = D (atanh x) (x' / (1 - x ** 2))
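With the `Floating` instance in place, the chain rule no longer has to be applied by hand. This sketch differentiates $\exp(\sin(x^2))$ from the chain-rule slide with dual numbers and compares the result with the hand-derived formula. The helper `diffAt` and the names `g` and `gHand` are hypothetical. The two results agree up to floating-point rounding, since they multiply the same factors in a different order.

```haskell
data D a = D a a deriving Show

tangent :: D a -> a
tangent (D _ b) = b

instance Num a => Num (D a) where
  D x x' + D y y' = D (x + y) (x' + y')
  D x x' * D y y' = D (x * y) (x' * y + x * y')
  negate (D x x') = D (negate x) (negate x')
  abs (D x x')    = D (abs x) (x' * signum x)
  signum (D x _)  = D (signum x) 0
  fromInteger n   = D (fromInteger n) 0

instance Fractional a => Fractional (D a) where
  recip (D x x') = D (recip x) (negate x' * recip (x * x))
  fromRational r = D (fromRational r) 0

-- Same rules as the slide; this is the minimal complete definition of Floating
instance Floating a => Floating (D a) where
  pi = D pi 0
  exp (D x x')   = D (exp x) (x' * exp x)
  log (D x x')   = D (log x) (x' / x)
  sin (D x x')   = D (sin x) (x' * cos x)
  cos (D x x')   = D (cos x) (negate x' * sin x)
  asin (D x x')  = D (asin x) (x' / sqrt (1 - x * x))
  acos (D x x')  = D (acos x) (negate x' / sqrt (1 - x * x))
  atan (D x x')  = D (atan x) (x' / (1 + x * x))
  sinh (D x x')  = D (sinh x) (x' * cosh x)
  cosh (D x x')  = D (cosh x) (x' * sinh x)
  asinh (D x x') = D (asinh x) (x' / sqrt (1 + x * x))
  acosh (D x x') = D (acosh x) (x' / sqrt (x * x - 1))
  atanh (D x x') = D (atanh x) (x' / (1 - x * x))

-- Hypothetical helper: seed the tangent with 1 and read it back
diffAt :: Num a => (D a -> D a) -> a -> a
diffAt f x = tangent (f (D x 1))

-- The chain-rule example, written once for any Floating type
g :: Floating a => a -> a
g x = exp (sin (x ^ 2))

-- The derivative worked out by hand on the chain-rule slide
gHand :: Double -> Double
gHand x = exp (sin (x ^ 2)) * cos (x ^ 2) * 2 * x
```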
  9. Implementing diff

    lift :: Num a => a -> D a
    lift x = D x 0

    infinitesimal :: Num a => D a
    infinitesimal = D 0 1

    diffD :: Num a => (D a -> D a) -> a -> a
    diffD f x = tangent $ f (lift x + infinitesimal)
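A single evaluation at $x + \epsilon$ produces both $f(x)$ and $f'(x)$, so `real` and `tangent` can be read from the same result. This sketch repeats the deck's definitions and adds a hypothetical helper `valueAndDiff`, which is not part of the deck. The test polynomial $p(x) = x^3 - 2x$ is likewise an illustrative choice.

```haskell
data D a = D a a deriving Show

real, tangent :: D a -> a
real (D a _) = a
tangent (D _ b) = b

instance Num a => Num (D a) where
  D x x' + D y y' = D (x + y) (x' + y')
  D x x' * D y y' = D (x * y) (x' * y + x * y')
  negate (D x x') = D (negate x) (negate x')
  abs (D x x')    = D (abs x) (x' * signum x)
  signum (D x _)  = D (signum x) 0
  fromInteger n   = D (fromInteger n) 0

lift :: Num a => a -> D a
lift x = D x 0

infinitesimal :: Num a => D a
infinitesimal = D 0 1

diffD :: Num a => (D a -> D a) -> a -> a
diffD f x = tangent $ f (lift x + infinitesimal)

-- Hypothetical helper: keep both components of the same forward pass
valueAndDiff :: Num a => (D a -> D a) -> a -> (a, a)
valueAndDiff f x = (real r, tangent r)
  where
    r = f (lift x + infinitesimal)

-- Illustrative test polynomial: p(x) = x^3 - 2x, p'(x) = 3x^2 - 2
p :: Num a => a -> a
p x = x ^ 3 - 2 * x
```

`valueAndDiff p (2 :: Integer)` returns `(4,10)`, since $p(2) = 4$ and $p'(2) = 3 \cdot 2^2 - 2 = 10$. Only `Num` is required, so exact `Integer` arithmetic works as well.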
  10. > diffD (\x -> x^2 + sin x) 1
    2.5403023058681398
    > diffD (\x -> x^2 + sin x) (var "x")
    x+x+cos x
  11. > diffD (\x -> x^2 + sin x) 1
    = tangent $ (\x -> x^2 + sin x) (lift 1 + infinitesimal)
    = tangent $ (\x -> x^2 + sin x) (D 1 1)
    = tangent $ (D 1 1)^2 + sin (D 1 1)
    = tangent $ D 1 2 + D (sin 1) (cos 1)
    = tangent $ D (1 + sin 1) (2 + cos 1)
    = 2 + cos 1
    2.5403023058681398
  12. The type of diff

    diffD :: Num a => (D a -> D a) -> a -> a
    diff  :: Num a => (forall s. AD s (Forward a) -> AD s (Forward a)) -> (a -> a)

    • ✅ Forward (plays the role of D)
    • forall s. AD s (…)?
  13. The problem with diffD

    > diffD (\y -> diffD (\x -> (x + y)^3) 1) 1
    error:
      • Occurs check: cannot construct the infinite type: a ~ D a

    What we want to compute: $\left.\dfrac{d}{dy}\left(\left.\dfrac{d(x + y)^3}{dx}\right|_{x=1}\right)\right|_{y=1}$
  14. The problem with diffD

    > diffD (\y -> lift $ diffD (\x -> (x + y)^3) 1) 1
    0
    ↑ this one confuses dx and dy
    > diffD (\y -> diffD (\x -> (x + lift y)^3) (lift 1)) 1
    12

    Which one is correct? The problem is that even the wrong one type-checks.
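Working it out by hand settles the question. $\frac{d}{dx}(x + y)^3 = 3(x + y)^2$, which is $3(1 + y)^2$ at $x = 1$. Its derivative with respect to $y$ is $6(1 + y)$, which is $12$ at $y = 1$. This sketch reproduces both GHCi results with the deck's `D`, `lift`, `infinitesimal`, and `diffD`, instantiated at `Integer` so the results are exact. In the first version, `x` and `y` share the type `D Integer`. The inner `diffD` therefore adds up both perturbations, giving 24, and `lift` then discards that derivative.

```haskell
data D a = D a a deriving Show

tangent :: D a -> a
tangent (D _ b) = b

instance Num a => Num (D a) where
  D x x' + D y y' = D (x + y) (x' + y')
  D x x' * D y y' = D (x * y) (x' * y + x * y')
  negate (D x x') = D (negate x) (negate x')
  abs (D x x')    = D (abs x) (x' * signum x)
  signum (D x _)  = D (signum x) 0
  fromInteger n   = D (fromInteger n) 0

lift :: Num a => a -> D a
lift x = D x 0

infinitesimal :: Num a => D a
infinitesimal = D 0 1

diffD :: Num a => (D a -> D a) -> a -> a
diffD f x = tangent $ f (lift x + infinitesimal)

-- Perturbation confusion: x and y both have type D Integer, so the inner
-- diffD sees y's perturbation as well; lift then sets the outer tangent to 0.
wrong :: Integer
wrong = diffD (\y -> lift $ diffD (\x -> (x + y) ^ 3) 1) 1

-- Correct nesting: y is lifted one level up, so the two perturbations stay apart.
right :: Integer
right = diffD (\y -> diffD (\x -> (x + lift y) ^ 3) (lift 1)) 1
```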
  15. Phantom types

    newtype AD s a = AD {unAD :: a}

    instance Num a => Num (AD s a) where ...
    instance Fractional a => Fractional (AD s a) where ...
    instance Floating a => Floating (AD s a) where ...

    > (1 :: AD Bool Int) + (2 :: AD Bool Int)
    AD {unAD = 3}
    > (1 :: AD Bool Int) + (2 :: AD Char Int)
    error:
      • Couldn't match type ‘Char’ with ‘Bool’
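A phantom type parameter costs nothing at run time. `AD s a` is a `newtype`, so only the tag `s` distinguishes two values. This sketch uses GHC's `GeneralizedNewtypeDeriving` extension to obtain the numeric instances whose bodies the slide elides with `...`. That is an assumption for the example; the deck may define them differently. The ill-typed line is left commented out because it would stop the module from compiling.

```haskell
{-# LANGUAGE GeneralizedNewtypeDeriving #-}

-- s is a phantom type: it appears in the type but not in the stored value.
-- Num/Fractional/Floating are derived by reusing the instances of a.
newtype AD s a = AD { unAD :: a }
  deriving (Show, Eq, Num, Fractional, Floating)

-- Same tag on both sides: accepted
sameTag :: AD Bool Int
sameTag = (1 :: AD Bool Int) + (2 :: AD Bool Int)

-- Different tags: rejected at compile time
-- mixedTag = (1 :: AD Bool Int) + (2 :: AD Char Int)
--   error: Couldn't match type 'Char' with 'Bool'
```

`show sameTag` prints `AD {unAD = 3}`, matching the GHCi output on the slide.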
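  16. Existential types

    liftAD :: Num a => a -> AD s (D a)
    liftAD = AD . lift

    diffAD :: Num a => (forall s. AD s (D a) -> AD s (D a)) -> (a -> a)   -- needs RankNTypes
    diffAD f = diffD (unAD . f . AD)

    • The first argument of diffAD is polymorphic in s (a rank-2 type), so s behaves like an existential type.
    • Code inside that argument cannot know which concrete type s is.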
  17. Running diffAD

    > diffAD (\y -> liftAD $ diffAD (\x -> (x + y)^3) 1) 1
    error:
      • Couldn't match type ‘s1’ with ‘s’
    ↑ the error occurs here: y :: AD s (D a), but x :: AD s1 (D a)
    > diffAD (\y -> diffAD (\x -> (x + liftAD y)^3) (liftAD 1)) 1
    12

    The wrong version no longer type-checks!
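This sketch assembles `D`, `diffD`, the phantom-typed `AD` wrapper, `liftAD`, and `diffAD` into one module and checks that the correctly nested call evaluates to 12. It assumes the `RankNTypes` extension, which the `forall s.` argument requires. It also uses `GeneralizedNewtypeDeriving` as a stand-in for the instance bodies the deck elides. The confused version is kept as a comment because GHC rejects it.

```haskell
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE GeneralizedNewtypeDeriving #-}

data D a = D a a deriving Show

tangent :: D a -> a
tangent (D _ b) = b

instance Num a => Num (D a) where
  D x x' + D y y' = D (x + y) (x' + y')
  D x x' * D y y' = D (x * y) (x' * y + x * y')
  negate (D x x') = D (negate x) (negate x')
  abs (D x x')    = D (abs x) (x' * signum x)
  signum (D x _)  = D (signum x) 0
  fromInteger n   = D (fromInteger n) 0

lift :: Num a => a -> D a
lift x = D x 0

infinitesimal :: Num a => D a
infinitesimal = D 0 1

diffD :: Num a => (D a -> D a) -> a -> a
diffD f x = tangent $ f (lift x + infinitesimal)

-- Phantom tag s; Num is reused from the wrapped type
newtype AD s a = AD { unAD :: a } deriving (Show, Num)

liftAD :: Num a => a -> AD s (D a)
liftAD = AD . lift

-- The argument must work for every s, so tags from different levels cannot mix
diffAD :: Num a => (forall s. AD s (D a) -> AD s (D a)) -> (a -> a)
diffAD f = diffD (unAD . f . AD)

nested :: Integer
nested = diffAD (\y -> diffAD (\x -> (x + liftAD y) ^ 3) (liftAD 1)) 1

-- confused = diffAD (\y -> liftAD $ diffAD (\x -> (x + y) ^ 3) 1) 1
--   rejected at compile time: Couldn't match type 's1' with 's'
```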
  18. The type of diff

    diff :: Num a => (forall s. AD s (Forward a) -> AD s (Forward a)) -> (a -> a)

    • ✅ Forward
    • ✅ forall s. AD s (…)
  19. diff in reverse mode

    diff :: Num a => (forall s. Reifies s Tape => Reverse s a -> Reverse s a) -> (a -> a)

    • Tape?
  20. Wengert list (tape)

    [ ("f",  "exp",    ["z2"])
    , ("z2", "sin",    ["z1"])
    , ("z1", "square", ["x"])
    ]

    $f(x) = \exp(\sin(x^2))$
    $f'(x) = \exp(\sin(x^2)) \cdot \cos(x^2) \cdot 2x$
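A Wengert list can be evaluated without any library support. This toy interpreter is a sketch under simplifying assumptions: unary primitives only, variables named by strings as on the slide, and the primitives `exp`, `sin`, and `square` hard-coded. A forward sweep records every intermediate value. A backward sweep then walks the tape from `f` to `x`, multiplying each adjoint by the local derivative, which reproduces $\exp(\sin(x^2)) \cdot \cos(x^2) \cdot 2x$. All function names are hypothetical.

```haskell
-- A tape entry: (output variable, primitive, input variables), as on the slide
type Entry = (String, String, [String])

-- Tape for f(x) = exp (sin (x^2)), listed from the output back to the input
tape :: [Entry]
tape =
  [ ("f",  "exp",    ["z2"])
  , ("z2", "sin",    ["z1"])
  , ("z1", "square", ["x"])
  ]

-- Each unary primitive returns (value, local derivative) at its input
prim :: String -> Double -> (Double, Double)
prim "exp"    v = (exp v, exp v)
prim "sin"    v = (sin v, cos v)
prim "square" v = (v * v, 2 * v)
prim op       _ = error ("unknown primitive: " ++ op)

valueOf :: String -> [(String, Double)] -> Double
valueOf k env = maybe (error ("unbound variable: " ++ k)) id (lookup k env)

-- Adjoints that have not been touched yet are 0
adjointOf :: String -> [(String, Double)] -> Double
adjointOf k adj = maybe 0 id (lookup k adj)

-- Forward sweep: evaluate from the input side, recording every value
forward :: Double -> [(String, Double)]
forward x = foldl step [("x", x)] (reverse tape)
  where
    step env (out, op, [inp]) = (out, fst (prim op (valueOf inp env))) : env
    step _ (out, _, _) = error ("only unary entries are supported: " ++ out)

-- Backward sweep: start from df/df = 1 and push adjoints to each input
backward :: [(String, Double)] -> [(String, Double)]
backward env = foldl step [("f", 1)] tape
  where
    step adj (out, op, [inp]) =
      let g = adjointOf out adj                -- adjoint of this entry's output
          d = snd (prim op (valueOf inp env))  -- local derivative at its input
      in (inp, adjointOf inp adj + g * d) : adj
    step _ (out, _, _) = error ("only unary entries are supported: " ++ out)

gradAt :: Double -> Double
gradAt x = adjointOf "x" (backward (forward x))
```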
  21. unsafePerformIO

    binarily f di dj i b j c =
      Reverse (unsafePerformIO (modifyTape (Proxy :: Proxy s) (bin i j di dj))) $! f b c

    partials (Reverse k _) = map (sensitivities !) [0..vs] where
      Head n t = unsafePerformIO $ readIORef (getTape (reflect (Proxy :: Proxy s)))
      ...

    Side effects occur both when the tape is built and when it is run.
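The slide shows ad recording tape entries through `unsafePerformIO`. For comparison, this sketch records the same kind of tape with the side effects kept visible in `IO`, using only `Data.IORef` from base. It is not how the ad package is implemented. The types `Var` and `Tape` and the helpers `push`, `squareV`, `sinV`, `expV`, and `gradient` are hypothetical names for this example.

```haskell
import Data.IORef

-- A traced scalar: its value and its index on the tape
data Var = Var { val :: Double, idx :: Int }

-- Node count, plus each node's (parent index, local partial) pairs, newest first
type Tape = IORef (Int, [[(Int, Double)]])

-- Append a node; this write is the side effect that happens while building
push :: Tape -> Double -> [(Int, Double)] -> IO Var
push t v parents = do
  (n, nodes) <- readIORef t
  writeIORef t (n + 1, parents : nodes)
  pure (Var v n)

input :: Tape -> Double -> IO Var
input t x = push t x []

squareV, sinV, expV :: Tape -> Var -> IO Var
squareV t (Var v i) = push t (v * v) [(i, 2 * v)]
sinV    t (Var v i) = push t (sin v) [(i, cos v)]
expV    t (Var v i) = push t (exp v) [(i, exp v)]

-- Add d to the k-th element of a list
addAt :: Int -> Double -> [Double] -> [Double]
addAt k d xs = [ if j == k then x + d else x | (j, x) <- zip [0 ..] xs ]

-- Reverse sweep: read the tape, then visit nodes newest to oldest,
-- adding (adjoint of node * local partial) to each parent
gradient :: Tape -> Var -> IO [Double]
gradient t out = do
  (n, nodes) <- readIORef t
  let seed = [ if k == idx out then 1 else 0 | k <- [0 .. n - 1] ]
      step adj (k, parents) =
        foldl (\a (p, d) -> addAt p (adj !! k * d) a) adj parents
  pure (foldl step seed (zip [n - 1, n - 2 .. 0] nodes))

-- d/dx exp (sin (x^2)), recorded on a fresh tape
gradAt :: Double -> IO Double
gradAt x = do
  t   <- newIORef (0, [])
  vx  <- input t x
  z1  <- squareV t vx
  z2  <- sinV t z1
  f   <- expV t z2
  adj <- gradient t f
  pure (adj !! idx vx)
```

Keeping everything in `IO` makes the order of effects explicit. The price is that the differentiated function can no longer be an ordinary pure `a -> a`, which is presumably why ad hides the effects behind `unsafePerformIO`.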
  22. Smooth lambdas and their categories

    • Conal Elliott, "The simple essence of automatic differentiation", 2018.
      (realizes the backward computation by using continuation-passing style)
    • Fei Wang, et al., "Demystifying Differentiable Programming: Shift/Reset the Penultimate Backpropagator", 2018.
    • Alois Brunel, et al., "Backpropagation in the Simply Typed Lambda-calculus with Linear Negation", 2019.
      (uses linear types to guarantee that derivatives are not evaluated redundantly)
    • Robin Cockett, et al., "Reverse derivative categories", 2019.
  23. Q&A