
Understanding Automatic Differentiation through Functions and Types

lotz
November 09, 2019


Transcript

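  1. But we don't want to do complicated derivative calculations by hand!

    $f(x) = \sigma(W_3\,\phi(W_2\,\phi(W_1 x + b_1) + b_2) + b_3)$ (a three-layer neural network)

    $U = (m_1 + m_2)\,g\,l_1\,(1 - \cos\theta_1) + m_2\,g\,l_2\,(1 - \cos\theta_2)$ (potential energy of a double pendulum)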
  2. Using the ad and numbers packages in GHCi:

        > :m Numeric.AD Data.Number.Symbolic
        > diff (\x -> x^2 + sin x) 1
        2.5403023058681398
        > diff (\x -> x^2 + sin x) (var "x")
        x+x+cos x

    Ref: https://twitter.com/GabrielG…/status/…
  3. The type of diff

        diff :: Num a => (forall s. AD s (Forward a) -> AD s (Forward a)) -> (a -> a)

    The first argument is the function to differentiate; the result is its derivative.
    • What is Forward?
    • What is forall s. AD s (…)?
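  4. The chain rule (differentiating composite functions)

    $\{f(g(x))\}' = f'(g(x)) \cdot g'(x)$

    Example: $f(x) = \exp(\sin(x^2))$

    $f'(x) = \underbrace{\exp(\sin(x^2)) \cdot \cos(x^2)}_{f'(g(x))} \cdot \underbrace{2x}_{g'(x)}$ (outer function $\exp \circ \sin$, inner function $g(x) = x^2$)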
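The chain rule can be made concrete by pairing each function with its derivative and composing the pairs. The following is a minimal sketch. The `Fn` type, `compose`, and the primitive names are hypothetical and do not come from the slides. It is applied to $f(x) = \exp(\sin(x^2))$:

```haskell
-- A function paired with its derivative (hypothetical helper type).
data Fn = Fn { apply :: Double -> Double, deriv :: Double -> Double }

-- Chain rule: (f . g)'(x) = f'(g x) * g'(x)
compose :: Fn -> Fn -> Fn
compose f g = Fn (apply f . apply g) (\x -> deriv f (apply g x) * deriv g x)

-- Primitives with hand-written derivatives.
square, sinFn, expFn :: Fn
square = Fn (\x -> x * x) (\x -> 2 * x)
sinFn  = Fn sin cos
expFn  = Fn exp exp

-- f(x) = exp (sin (x^2)), built by composing the primitives.
fExample :: Fn
fExample = expFn `compose` (sinFn `compose` square)
```

Here `deriv fExample 1` agrees with `exp (sin 1) * cos 1 * 2`. The derivative is assembled mechanically from the primitives, which is the idea the rest of the deck automates.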
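  5. Dual numbers

        data D a = D a a

        real, tangent :: D a -> a
        real    (D a _) = a
        tangent (D _ b) = b

    Numbers of the form $a + b\epsilon$ with $\epsilon^2 = 0$: a number system with a special element that becomes 0 when multiplied by itself.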
  6. Num and Fractional instances

        instance Num a => Num (D a) where
          D x x' + D y y' = D (x + y) (x' + y')
          D x x' * D y y' = D (x * y) (x' * y + x * y')
          negate (D x x') = D (negate x) (negate x')
          abs    (D x x') = D (abs x) (x' * (signum x))
          signum (D x x') = D (signum x) 0
          fromInteger n   = D (fromInteger n) 0

        instance Fractional a => Fractional (D a) where
          recip (D x x')   = D (recip x) (-1 * x' * (recip (x * x)))
          fromRational x   = D (fromRational x) 0

    Leibniz rule: $\{f(x)g(x)\}' = f'(x)g(x) + f(x)g'(x)$, which is exactly what the tangent part of `*` computes.
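  7. $(a + b\epsilon)(c + d\epsilon) = ac + (ad + bc)\epsilon + bd\epsilon^2 = ac + (ad + bc)\epsilon$

    $\dfrac{1}{a + b\epsilon} = \dfrac{a - b\epsilon}{(a + b\epsilon)(a - b\epsilon)} = \dfrac{a - b\epsilon}{a^2 - b^2\epsilon^2} = \dfrac{a - b\epsilon}{a^2} = \dfrac{1}{a} - \dfrac{b}{a^2}\epsilon$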
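The two identities above can be checked directly against the instances from slide 6. The following is a minimal sketch that reuses the slide definitions. Only `deriving (Show, Eq)` is added, so that results can be printed and compared. `infinitesimal` is the name slide 9 gives to $\epsilon$:

```haskell
-- Dual numbers a + bε (slide 5), with Show/Eq added for inspection.
data D a = D a a deriving (Show, Eq)

real, tangent :: D a -> a
real    (D a _) = a
tangent (D _ b) = b

-- Slide 6: (*) is exactly (a + bε)(c + dε) = ac + (ad + bc)ε.
instance Num a => Num (D a) where
  D x x' + D y y' = D (x + y) (x' + y')
  D x x' * D y y' = D (x * y) (x' * y + x * y')
  negate (D x x') = D (negate x) (negate x')
  abs    (D x x') = D (abs x) (x' * signum x)
  signum (D x _)  = D (signum x) 0
  fromInteger n   = D (fromInteger n) 0

-- Slide 6: recip is 1 / (a + bε) = 1/a - (b / a^2) ε.
instance Fractional a => Fractional (D a) where
  recip (D x x')  = D (recip x) (-1 * x' * recip (x * x))
  fromRational r  = D (fromRational r) 0

-- ε itself (slide 9).
infinitesimal :: Num a => D a
infinitesimal = D 0 1
```

In GHCi, `infinitesimal * infinitesimal :: D Double` shows `D 0.0 0.0` (so $\epsilon^2 = 0$), `D 1 2 * D 3 4 :: D Double` shows `D 3.0 10.0`, and `recip (D 2 3) :: D Double` shows `D 0.5 (-0.75)`, i.e. $\frac{1}{2} - \frac{3}{4}\epsilon$.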
  8. Floating instance

        instance Floating a => Floating (D a) where
          pi = D pi 0
          exp   (D x x') = D (exp x)   (x' * exp x)
          log   (D x x') = D (log x)   (x' / x)
          sin   (D x x') = D (sin x)   (x' * cos x)
          cos   (D x x') = D (cos x)   (- x' * sin x)
          asin  (D x x') = D (asin x)  (x' / (sqrt (1 - x ** 2)))
          acos  (D x x') = D (acos x)  (- x' / (sqrt (1 - x ** 2)))
          atan  (D x x') = D (atan x)  (x' / (1 + x ** 2))
          sinh  (D x x') = D (sinh x)  (x' * cosh x)
          cosh  (D x x') = D (cosh x)  (x' * sinh x)
          asinh (D x x') = D (asinh x) (x' / (sqrt (1 + x ** 2)))
          acosh (D x x') = D (acosh x) (x' / (sqrt (x ** 2 - 1)))
          atanh (D x x') = D (atanh x) (x' / (1 - x ** 2))
  9. Implementing diff

        lift :: Num a => a -> D a
        lift x = D x 0

        infinitesimal :: Num a => D a
        infinitesimal = D 0 1

        diffD :: Num a => (D a -> D a) -> a -> a
        diffD f x = tangent $ f (lift x + infinitesimal)
  10. Running diffD

        > diffD (\x -> x^2 + sin x) 1
        2.5403023058681398
        > diffD (\x -> x^2 + sin x) (var "x")
        x+x+cos x
  11. Step-by-step evaluation

        diffD (\x -> x^2 + sin x) 1
          = tangent $ (\x -> x^2 + sin x) (lift 1 + infinitesimal)
          = tangent $ (\x -> x^2 + sin x) (D 1 1)
          = tangent $ (D 1 1)^2 + sin (D 1 1)
          = tangent $ D 1 2 + D (sin 1) (cos 1)
          = tangent $ D (1 + sin 1) (2 + cos 1)
          = 2 + cos 1
          = 2.5403023058681398
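Slides 5, 6, 8 and 9 can be assembled into one file and loaded into GHCi to replay the trace above. The following is a minimal sketch that reuses the slide definitions. Two things are added: `deriving (Show, Eq)`, and a hypothetical `finiteDiff` helper that serves only as an independent sanity check (it is not part of the slides):

```haskell
-- Dual numbers x + x'ε, assembled from slides 5, 6, 8 and 9.
data D a = D a a deriving (Show, Eq)

tangent :: D a -> a
tangent (D _ b) = b

instance Num a => Num (D a) where
  D x x' + D y y' = D (x + y) (x' + y')
  D x x' * D y y' = D (x * y) (x' * y + x * y')   -- Leibniz rule
  negate (D x x') = D (negate x) (negate x')
  abs    (D x x') = D (abs x) (x' * signum x)
  signum (D x _)  = D (signum x) 0
  fromInteger n   = D (fromInteger n) 0

instance Fractional a => Fractional (D a) where
  recip (D x x')  = D (recip x) (-1 * x' * recip (x * x))
  fromRational r  = D (fromRational r) 0

instance Floating a => Floating (D a) where
  pi = D pi 0
  exp   (D x x') = D (exp x)   (x' * exp x)
  log   (D x x') = D (log x)   (x' / x)
  sin   (D x x') = D (sin x)   (x' * cos x)
  cos   (D x x') = D (cos x)   (- x' * sin x)
  asin  (D x x') = D (asin x)  (x' / sqrt (1 - x ** 2))
  acos  (D x x') = D (acos x)  (- x' / sqrt (1 - x ** 2))
  atan  (D x x') = D (atan x)  (x' / (1 + x ** 2))
  sinh  (D x x') = D (sinh x)  (x' * cosh x)
  cosh  (D x x') = D (cosh x)  (x' * sinh x)
  asinh (D x x') = D (asinh x) (x' / sqrt (1 + x ** 2))
  acosh (D x x') = D (acosh x) (x' / sqrt (x ** 2 - 1))
  atanh (D x x') = D (atanh x) (x' / (1 - x ** 2))

lift :: Num a => a -> D a
lift x = D x 0

infinitesimal :: Num a => D a
infinitesimal = D 0 1

diffD :: Num a => (D a -> D a) -> a -> a
diffD f x = tangent $ f (lift x + infinitesimal)

-- Central finite difference (hypothetical helper, for comparison only).
finiteDiff :: (Double -> Double) -> Double -> Double
finiteDiff f x = (f (x + h) - f (x - h)) / (2 * h)
  where h = 1e-5
```

With this file loaded, `diffD (\x -> x^2 + sin x) 1` evaluates to exactly `2 + cos 1`, as in the trace. `finiteDiff` only approximates the same value, with an error that depends on the step size `h`.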
  12. The type of diff, revisited

        diffD :: Num a => (D a -> D a) -> a -> a
        diff  :: Num a => (forall s. AD s (Forward a) -> AD s (Forward a)) -> (a -> a)

    • ✅ Forward
    • What is forall s. AD s (…)?
  13. The problem with diffD

    Trying to compute $\left.\dfrac{d}{dy}\left(\left.\dfrac{d(x + y)^3}{dx}\right|_{x=1}\right)\right|_{y=1}$:

        > diffD (\y -> diffD (\x -> (x + y)^3) 1) 1
        error:
          • Occurs check: cannot construct the infinite type: a ~ D a
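  14. The problem with diffD (continued)

        > diffD (\y -> lift $ diffD (\x -> (x + y)^3) 1) 1
        0
        > diffD (\y -> diffD (\x -> (x + lift y)^3) (lift 1)) 1
        12

    Which one is correct? Since $\frac{d(x+y)^3}{dx} = 3(x+y)^2$, which is $3(1+y)^2$ at $x = 1$, its derivative in $y$ is $6(1+y)$, i.e. 12 at $y = 1$.
    The first version (returning 0) confuses dx and dy. The real problem is that even the wrong version type-checks.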
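Both expressions above can be reproduced with a self-contained copy of `D`, `lift`, `infinitesimal` and `diffD` (slides 5, 6 and 9). Only the Num instance is needed, because `(x + y)^3` uses nothing else. The names `confused` and `correct` are ours, chosen for illustration:

```haskell
data D a = D a a deriving (Show, Eq)

tangent :: D a -> a
tangent (D _ b) = b

instance Num a => Num (D a) where
  D x x' + D y y' = D (x + y) (x' + y')
  D x x' * D y y' = D (x * y) (x' * y + x * y')
  negate (D x x') = D (negate x) (negate x')
  abs    (D x x') = D (abs x) (x' * signum x)
  signum (D x _)  = D (signum x) 0
  fromInteger n   = D (fromInteger n) 0

lift :: Num a => a -> D a
lift x = D x 0

infinitesimal :: Num a => D a
infinitesimal = D 0 1

diffD :: Num a => (D a -> D a) -> a -> a
diffD f x = tangent $ f (lift x + infinitesimal)

-- Both type-check. In `confused`, x and y share the same ε, so the inner
-- derivative already absorbs y's perturbation and `lift` then discards it.
confused :: Double
confused = diffD (\y -> lift $ diffD (\x -> (x + y)^3) 1) 1

-- Here the inner computation runs on D (D Double), so x and y get separate ε's.
correct :: Double
correct = diffD (\y -> diffD (\x -> (x + lift y)^3) (lift 1)) 1
```

Evaluating them gives `confused == 0` and `correct == 12`, matching the GHCi session on the slide. This bug is commonly called perturbation confusion.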
  15. Phantom types

        newtype AD s a = AD {unAD :: a}

        instance Num a => Num (AD s a) where ...
        instance Fractional a => Fractional (AD s a) where ...
        instance Floating a => Floating (AD s a) where ...

        > (1 :: AD Bool Int) + (2 :: AD Bool Int)
        AD {unAD = 3}
        > (1 :: AD Bool Int) + (2 :: AD Char Int)
        error:
          • Couldn't match type ‘Char’ with ‘Bool’

    The parameter s never appears on the right-hand side, so it exists only at the type level. Even so, values tagged with different s cannot be mixed.
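  16. Existential types

        {-# LANGUAGE RankNTypes #-}

        liftAD :: Num a => a -> AD s (D a)
        liftAD = AD . lift

        diffAD :: Num a => (forall s. AD s (D a) -> AD s (D a)) -> (a -> a)
        diffAD f = diffD (unAD . f . AD)

    • The first argument of diffAD has the rank-2 type forall s. …, so from the function's point of view s behaves like an existential type.
    • From the outside, there is no way to tell which concrete type s is.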
  17. Running diffAD

        > diffAD (\y -> liftAD $ diffAD (\x -> (x + y)^2) 1) 1
        error:
          • Couldn't match type ‘s1’ with ‘s’
        > diffAD (\y -> diffAD (\x -> (x + liftAD y)^3) (liftAD 1)) 1
        12

    The error occurs in the first expression, because x :: AD s1 (D a) and y :: AD s (D a) carry different tags. The wrong version no longer type-checks!
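The rank-2 version can be tried end to end with the following minimal sketch. It assumes the `RankNTypes` extension. The slide elides the `Num (AD s a)` instance body, so the one below is our own assumption: it simply unwraps and rewraps. `nested` is a hypothetical name for the correct expression from the slide:

```haskell
{-# LANGUAGE RankNTypes #-}

data D a = D a a deriving (Show, Eq)

tangent :: D a -> a
tangent (D _ b) = b

instance Num a => Num (D a) where
  D x x' + D y y' = D (x + y) (x' + y')
  D x x' * D y y' = D (x * y) (x' * y + x * y')
  negate (D x x') = D (negate x) (negate x')
  abs    (D x x') = D (abs x) (x' * signum x)
  signum (D x _)  = D (signum x) 0
  fromInteger n   = D (fromInteger n) 0

lift :: Num a => a -> D a
lift x = D x 0

infinitesimal :: Num a => D a
infinitesimal = D 0 1

diffD :: Num a => (D a -> D a) -> a -> a
diffD f x = tangent $ f (lift x + infinitesimal)

-- Phantom-tagged wrapper: s is never stored, it only separates perturbations.
newtype AD s a = AD { unAD :: a } deriving (Show, Eq)

-- Assumed instance body (elided on the slide): unwrap, compute, rewrap.
instance Num a => Num (AD s a) where
  AD x + AD y   = AD (x + y)
  AD x * AD y   = AD (x * y)
  negate (AD x) = AD (negate x)
  abs    (AD x) = AD (abs x)
  signum (AD x) = AD (signum x)
  fromInteger n = AD (fromInteger n)

liftAD :: Num a => a -> AD s (D a)
liftAD = AD . lift

-- The caller's function must work for every s, so tags from different
-- diffAD calls can never be unified with each other.
diffAD :: Num a => (forall s. AD s (D a) -> AD s (D a)) -> (a -> a)
diffAD f = diffD (unAD . f . AD)

-- The correct nested derivative from the slide.
nested :: Double
nested = diffAD (\y -> diffAD (\x -> (x + liftAD y)^3) (liftAD 1)) 1
```

Here `nested` evaluates to 12. The confused version from the slide, `diffAD (\y -> liftAD $ diffAD (\x -> (x + y)^2) 1) 1`, is rejected at compile time, so it cannot be turned into a runtime test.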
  18. The type of diff, revisited

        diff :: Num a => (forall s. AD s (Forward a) -> AD s (Forward a)) -> (a -> a)

    • ✅ Forward
    • ✅ forall s. AD s (…)
  19. diff in reverse mode

        diff :: Num a => (forall s. Reifies s Tape => Reverse s a -> Reverse s a) -> (a -> a)

    • What is Tape?
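  20. Wengert list (Tape)

        [ ("f",  "exp",    ["z2"])
        , ("z2", "sin",    ["z1"])
        , ("z1", "square", ["x"])
        ]

    Each entry records an intermediate variable, the primitive that produced it, and its inputs.

    $f(x) = \exp(\sin(x^2))$

    $f'(x) = \exp(\sin(x^2)) \cdot \cos(x^2) \cdot 2x$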
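The tape above can be evaluated with a hand-rolled reverse mode. A forward pass computes every intermediate value together with its local derivative. A backward pass then walks the list from the top (output first) and multiplies in one chain-rule factor per entry. The following is a minimal sketch. The op table `prim` and the helper names are hypothetical, it supports only the three unary primitives on the slide, and it is not how ad represents its tape:

```haskell
import Data.Maybe (fromMaybe)

-- One tape entry as on the slide: (output variable, primitive, input variables).
type Entry = (String, String, [String])

-- Wengert list for f(x) = exp (sin (x^2)), written output-first.
tape :: [Entry]
tape =
  [ ("f",  "exp",    ["z2"])
  , ("z2", "sin",    ["z1"])
  , ("z1", "square", ["x"])
  ]

-- Value and local derivative of each unary primitive (hypothetical op table).
prim :: String -> Double -> (Double, Double)
prim "exp"    v = (exp v, exp v)
prim "sin"    v = (sin v, cos v)
prim "square" v = (v * v, 2 * v)
prim op       _ = error ("unknown primitive: " ++ op)

lookupOr :: Double -> String -> [(String, Double)] -> Double
lookupOr def k env = fromMaybe def (lookup k env)

-- Forward pass: foldr handles the last entry (nearest the input) first and
-- returns (output, input, d output / d input) for every entry, in tape order.
forwardPass :: Double -> [(String, String, Double)]
forwardPass x = snd (foldr step ([("x", x)], []) tape)
  where
    step (out, op, [inp]) (vals, nodes) =
      let v         = fromMaybe (error ("unbound: " ++ inp)) (lookup inp vals)
          (y, dydv) = prim op v
      in ((out, y) : vals, (out, inp, dydv) : nodes)
    step entry _ = error ("this sketch only handles unary entries: " ++ show entry)

-- Backward pass: seed adjoint(f) = 1, walk the tape output-first, and add
-- adjoint(out) * local derivative to adjoint(inp).
gradAt :: Double -> Double
gradAt x = lookupOr 0 "x" (foldl push [("f", 1)] (forwardPass x))
  where
    push adj (out, inp, local) =
      let contrib = lookupOr 0 out adj * local
      in (inp, lookupOr 0 inp adj + contrib) : filter ((/= inp) . fst) adj
```

Here `gradAt 1` matches $\exp(\sin 1) \cdot \cos 1 \cdot 2$ from the formula above. The backward pass visits `f`, `z2`, `z1` in exactly the order in which the list is written.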
  21. unsafePerformIO

    Excerpt from the reverse-mode internals of ad:

        binarily f di dj i b j c =
          Reverse (unsafePerformIO (modifyTape (Proxy :: Proxy s) (bin i j di dj))) $! f b c

        partials (Reverse k _) = map (sensitivities !) [0..vs]
          where
            Head n t = unsafePerformIO $ readIORef (getTape (reflect (Proxy :: Proxy s)))
            ...

    Side effects happen while the tape is being built, i.e. at run time.
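To see why a library would reach for `unsafePerformIO` here, consider a tape recorded with explicit IO instead. Every arithmetic step becomes an IO action that appends a node to an `IORef`, and a reverse sweep then accumulates adjoints. The following is a minimal sketch. All names (`record`, `mulV`, `gradients`, …) are hypothetical, and this is not ad's implementation. Judging from the excerpt above, ad performs a similar mutation behind `unsafePerformIO` so that `Reverse` can keep ordinary pure `Num`/`Floating` instances:

```haskell
import Data.IORef (IORef, newIORef, readIORef, writeIORef)

-- A tape node: its index and, for each parent, (parent index, local partial).
type Tape = IORef [(Int, [(Int, Double)])]

-- A variable is a position on the tape together with its value.
data Var = Var Int Double

-- Append a node to the tape (the side effect) and return a handle to it.
record :: Tape -> Double -> [(Int, Double)] -> IO Var
record tape v parents = do
  nodes <- readIORef tape
  let ix = length nodes
  writeIORef tape ((ix, parents) : nodes)  -- newest node first
  pure (Var ix v)

input :: Tape -> Double -> IO Var
input tape v = record tape v []

mulV :: Tape -> Var -> Var -> IO Var
mulV tape (Var i a) (Var j b) = record tape (a * b) [(i, b), (j, a)]

sinV, expV :: Tape -> Var -> IO Var
sinV tape (Var i a) = record tape (sin a) [(i, cos a)]
expV tape (Var i a) = record tape (exp a) [(i, exp a)]

-- Reverse sweep. Nodes are stored newest-first, so each node's adjoint is
-- complete before it is pushed to its parents. Lists and (!!) keep the
-- sketch short; a real tape would use an array.
gradients :: Tape -> Var -> IO [Double]
gradients tape (Var out _) = do
  nodes <- readIORef tape
  let n    = length nodes
      seed = [if k == out then 1 else 0 | k <- [0 .. n - 1]]
      step adj (ix, parents) =
        foldl (\acc (p, d) -> addAt p ((adj !! ix) * d) acc) adj parents
  pure (foldl step seed nodes)

addAt :: Int -> Double -> [Double] -> [Double]
addAt k d xs = [if i == k then x + d else x | (i, x) <- zip [0 ..] xs]

-- d/dx exp (sin (x * x)), recording the same steps as the Wengert list.
gradExample :: Double -> IO Double
gradExample x0 = do
  tape <- newIORef []
  x@(Var xi _) <- input tape x0
  z1 <- mulV tape x x
  z2 <- sinV tape z1
  f  <- expV tape z2
  grads <- gradients tape f
  pure (grads !! xi)
```

Because every operation here lives in IO, `mulV` and friends cannot serve as the `(*)` of a `Num` instance. That restriction is what the excerpt above works around.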
  22. Smooth lambdas and their categories

    • Conal Elliott, "The simple essence of automatic differentiation", 2018.
      (Realizes reverse-mode computation by using continuation-passing style.)
    • Fei Wang et al., "Demystifying Differentiable Programming: Shift/Reset the Penultimate Backpropagator", 2018.
    • Alois Brunel et al., "Backpropagation in the Simply Typed Lambda-calculus with Linear Negation", 2019.
      (Uses linear types to guarantee that derivatives are not evaluated redundantly.)
    • Robin Cockett et al., "Reverse derivative categories", 2019.
  23. Q&A