Automatic differentiation in Ruby

Tom Stuart
February 08, 2016

Finding the derivative of a mathematical function on a computer can be difficult, but there’s a clever trick that makes it easy: first write a program that computes the function, then execute it under a non-standard interpretation of its values and operations. In this talk I’ll show you how that works in Ruby.

Given at the London Ruby User Group (http://lrug.org/meetings/2016/february/). A video and expanded transcript are available at https://tomstu.art/automatic-differentiation-in-ruby, and the code is available at https://github.com/tomstuart/dual_number.

Transcript

  1. distance(time) = time

    | time      | distance(time) |
    | --------- | -------------- |
    | 0 seconds | 0 metres       |
    | 1 second  | 1 metre        |
    | 2 seconds | 2 metres       |
    | 3 seconds | 3 metres       |
    | 4 seconds | 4 metres       |
    | ⋮         | ⋮              |

  2. distance(time) = time × time

    | time      | distance(time) |
    | --------- | -------------- |
    | 0 seconds | 0 metres       |
    | 1 second  | 1 metre        |
    | 2 seconds | 4 metres       |
    | 3 seconds | 9 metres       |
    | 4 seconds | 16 metres      |
    | ⋮         | ⋮              |

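  3. distance(time) = time × time; speed(time) is unknown

    | time      | distance(time) | speed(time) |
    | --------- | -------------- | ----------- |
    | 0 seconds | 0 metres       | ?           |
    | 1 second  | 1 metre        | ?           |
    | 2 seconds | 4 metres       | ?           |
    | 3 seconds | 9 metres       | ?           |
    | 4 seconds | 16 metres      | ?           |
    | ⋮         | ⋮              | ⋮           |
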
  4. [Figure: "fast cruising". Distance against time, 0 sec to 2 sec; labels 1 m/s and 2 m/s.]

  5. [Figure: "slow cruising". Distance against time, 0 sec to 2 sec; labels 0.5 m/s, 1 m/s and 2 m/s.]

  6. distance(time) = time × time; speed(time) = 2 × time

    | time      | distance(time) | speed(time) |
    | --------- | -------------- | ----------- |
    | 0 seconds | 0 metres       | 0 m/s       |
    | 1 second  | 1 metre        | 2 m/s       |
    | 2 seconds | 4 metres       | 4 m/s       |
    | 3 seconds | 9 metres       | 6 m/s       |
    | 4 seconds | 16 metres      | 8 m/s       |
    | ⋮         | ⋮              | ⋮           |

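  7. Estimating speed from the distance travelled over a short interval:

        # distance(time) = time × time, as on slide 2
        def distance(time:)
          time * time
        end

        def speed(time:)
          time_elapsed = 0.01
          distance_before = distance(time: time)
          distance_after = distance(time: time + time_elapsed)
          distance_travelled = distance_after - distance_before
          distance_travelled / time_elapsed
        end
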
  8. Calling it in IRB:

        >> speed(time: 0)
        => 0.01
        >> speed(time: 1)
        => 2.0100000000000007
        >> speed(time: 2)
        => 4.009999999999891
        >> speed(time: 3)
        => 6.009999999999849
        >> speed(time: 4)
        => 8.009999999999806

  9. distance(time) = time × time; speed(time) = ?

    | time      | distance(time) | speed(time) |
    | --------- | -------------- | ----------- |
    | 0 seconds | 0 metres       | 0.01 m/s    |
    | 1 second  | 1 metre        | 2.01 m/s    |
    | 2 seconds | 4 metres       | 4.01 m/s    |
    | 3 seconds | 9 metres       | 6.01 m/s    |
    | 4 seconds | 16 metres      | 8.01 m/s    |
    | ⋮         | ⋮              | ⋮           |

  10. A class with a real part and a dual part:

        class DualNumber
          attr_accessor :real, :dual

          def initialize(real:, dual:)
            self.real = real
            self.dual = dual
          end

          def to_s
            [real, (dual < 0 ? '-' : '+'), dual.abs, 'ε'].join
          end

          def inspect
            "(#{to_s})"
          end
        end

  11. A conversion method in Kernel, in the style of Integer() and Float():

        module Kernel
          def DualNumber(real, dual = 0)
            case real
            when DualNumber
              real
            else
              DualNumber.new(real: real, dual: dual)
            end
          end
        end

  12. With distance(time) = time, passing in a dual number:

        def distance(time:)
          time
        end

        >> time_now = DualNumber(3, 1)
        => (3+1ε)
        >> distance_now = distance(time: time_now)
        => (3+1ε)
        >> distance_now.real
        => 3
        >> distance_now.dual
        => 1

  13. $(a + bi) + (c + di) = a + c + bi + di = (a + c) + (b + d)i$

  14. $$
    \begin{aligned}
    (a + bi) \times (c + di)
      &= (a \times c) + (a \times di) + (bi \times c) + (bi \times di) \\
      &= ac + (ad + bc)i + bdi^2 \\
      &= ac + (ad + bc)i + bd \times (-1) \\
      &= ac + (ad + bc)i - bd \\
      &= (ac - bd) + (ad + bc)i
    \end{aligned}
    $$

  15. $(a + b\varepsilon) + (c + d\varepsilon) = a + c + b\varepsilon + d\varepsilon = (a + c) + (b + d)\varepsilon$

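  16. $$
    \begin{aligned}
    (a + b\varepsilon) \times (c + d\varepsilon)
      &= (a \times c) + (a \times d\varepsilon) + (b\varepsilon \times c) + (b\varepsilon \times d\varepsilon) \\
      &= ac + (ad + bc)\varepsilon + bd\varepsilon^2 \\
      &= ac + (ad + bc)\varepsilon + bd \times 0 \\
      &= ac + (ad + bc)\varepsilon
    \end{aligned}
    $$
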
  17. Addition and multiplication, following slides 15 and 16:

        class DualNumber
          def +(other)
            DualNumber.new \
              real: real + other.real,
              dual: dual + other.dual
          end

          def *(other)
            DualNumber.new \
              real: real * other.real,
              dual: real * other.dual + dual * other.real
          end
        end

  18. In IRB:

        >> x = DualNumber(1, 2)
        => (1+2ε)
        >> y = DualNumber(3, 4)
        => (3+4ε)
        >> x + y
        => (4+6ε)
        >> x * y
        => (3+10ε)

  19. With distance(time) = time × time:

        def distance(time:)
          time * time
        end

        >> time_now = DualNumber(3, 1)
        => (3+1ε)
        >> distance_now = distance(time: time_now)
        => (9+6ε)
        >> distance_now.real
        => 9
        >> distance_now.dual
        => 6

  20. Mixing in ordinary numbers fails:

        >> x = DualNumber(1, 2)
        => (1+2ε)
        >> (x + 3) * 4
        NoMethodError: undefined method `dual' for 3:Fixnum

  21. Converting the other operand with DualNumber() first:

        class DualNumber
          def +(other)
            other = DualNumber(other)
            DualNumber.new \
              real: real + other.real,
              dual: dual + other.dual
          end

          def *(other)
            other = DualNumber(other)
            DualNumber.new \
              real: real * other.real,
              dual: real * other.dual + dual * other.real
          end
        end

  22. An ordinary number on the left-hand side still fails:

        >> x = DualNumber(1, 2)
        => (1+2ε)
        >> 3 + (4 * x)
        TypeError: DualNumber can't be coerced into Fixnum

  23. $$
    \begin{aligned}
    \sin(a + b\varepsilon) &= \sin(a) + (b \times \cos(a))\varepsilon \\
    \cos(a + b\varepsilon) &= \cos(a) - (b \times \sin(a))\varepsilon \\
    \exp(a + b\varepsilon) &= \exp(a) + (b \times \exp(a))\varepsilon \\
    \log(a + b\varepsilon) &= \log(a) + (b \div a)\varepsilon \\
    \sqrt{a + b\varepsilon} &= \sqrt{a} + \bigl(b \div (2 \times \sqrt{a})\bigr)\varepsilon \\
    &\;\;\vdots
    \end{aligned}
    $$

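  24. Teaching Math.sin and Math.cos about dual numbers:

        # Prepend an anonymous module to Math's singleton class so these
        # definitions run before the built-in Math.sin and Math.cos.
        Math.singleton_class.prepend Module.new {
          # sin(a + bε) = sin(a) + (b × cos(a))ε
          def sin(x)
            case x
            when DualNumber
              DualNumber.new \
                real: sin(x.real),
                dual: x.dual * cos(x.real)
            else
              super # ordinary numbers: use the built-in Math.sin
            end
          end

          # cos(a + bε) = cos(a) - (b × sin(a))ε
          def cos(x)
            case x
            when DualNumber
              DualNumber.new \
                real: cos(x.real),
                dual: -x.dual * sin(x.real)
            else
              super # ordinary numbers: use the built-in Math.cos
            end
          end
        }
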
  25. In IRB:

        >> x = DualNumber(Math::PI / 3, 1)
        => (1.0471975511965976+1ε)
        >> Math.sin(x)
        => (0.8660254037844386+0.5000000000000001ε)
        >> Math.sin(x) + Math.cos(x)
        => (1.3660254037844388-0.3660254037844385ε)
        >> Math.sin(x) * Math.cos(x / 2)
        => (0.75+0.21650635094610984ε)
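
Slides 7 to 9 estimate speed using a fixed interval of 0.01 seconds. The sketch below makes the interval a parameter so you can see its effect. It assumes distance(time) = time × time, as on slide 2. `estimate_speed` and its `time_elapsed:` keyword are hypothetical names, not taken from the talk.

```ruby
# Assumption: distance(time) = time × time, as on slide 2.
def distance(time:)
  time * time
end

# Hypothetical generalisation of slide 7's speed method: the interval is a
# parameter instead of being fixed at 0.01 seconds.
def estimate_speed(time:, time_elapsed:)
  distance_before = distance(time: time)
  distance_after = distance(time: time + time_elapsed)
  (distance_after - distance_before) / time_elapsed
end

# The exact speed at 3 seconds is 2 × 3 = 6 m/s; print how far each estimate is off.
[0.1, 0.01, 0.001, 0.0001].each do |time_elapsed|
  estimate = estimate_speed(time: 3, time_elapsed: time_elapsed)
  puts "time_elapsed: #{time_elapsed}, estimate: #{estimate}, error: #{estimate - 6}"
end
```

The error shrinks as the interval shrinks, but it never reaches zero. Very small intervals also start to lose precision to floating-point cancellation in `distance_after - distance_before`.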
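
The extra 0.01 in slides 8 and 9 can be derived exactly. Write t for time and h for time_elapsed, and assume distance(time) = time × time as on slide 2:

$$
\frac{\text{distance}(t + h) - \text{distance}(t)}{h}
= \frac{(t + h)^2 - t^2}{h}
= \frac{2th + h^2}{h}
= 2t + h
$$

With h = 0.01 this gives 2t + 0.01, which matches the values in slide 9. As h shrinks toward 0 it approaches 2t, the speed(time) = 2 × time of slide 6. The remaining digits in slide 8, such as 2.0100000000000007, are floating-point rounding.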
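
Ruby's core `Complex` class already implements the rules of slides 13 and 14. That makes it a useful comparison with the dual-number version. This sketch uses the same operands as slide 18:

```ruby
# Ruby's built-in complex numbers: a + bi is written Complex(a, b).
x = Complex(1, 2) # 1 + 2i
y = Complex(3, 4) # 3 + 4i

sum = x + y     # (1 + 3) + (2 + 4)i, by slide 13
product = x * y # (1×3 - 2×4) + (1×4 + 2×3)i, by slide 14

p sum     # prints (4+6i)
p product # prints (-5+10i)
```

The sum matches slide 18's dual-number result, (4+6ε). The products differ only in the real part:

- The complex product picks up -bd = -8, because i² = -1.
- The dual product, (3+10ε), drops that term, because ε² = 0.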
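
Slides 12 and 19 pass in DualNumber(3, 1), which has a dual part of 1, and read the derivative from the dual part of the result. The sketch below wraps that pattern in a helper. `derivative` is a hypothetical name, not taken from the talk. The class repeats slides 10, 11 and 21 so that the block runs on its own.

```ruby
class DualNumber
  attr_accessor :real, :dual

  def initialize(real:, dual:)
    self.real = real
    self.dual = dual
  end

  # Arithmetic from slide 21: convert the other operand first.
  def +(other)
    other = DualNumber(other)
    DualNumber.new real: real + other.real, dual: dual + other.dual
  end

  def *(other)
    other = DualNumber(other)
    DualNumber.new real: real * other.real,
                   dual: real * other.dual + dual * other.real
  end

  def to_s
    [real, (dual < 0 ? '-' : '+'), dual.abs, 'ε'].join
  end

  def inspect
    "(#{to_s})"
  end
end

module Kernel
  # Conversion method from slide 11.
  def DualNumber(real, dual = 0)
    case real
    when DualNumber then real
    else DualNumber.new(real: real, dual: dual)
    end
  end
end

# Hypothetical helper: evaluate the block at x + 1ε and return the dual part,
# which is the derivative of the block's function at x.
def derivative(x)
  yield(DualNumber(x, 1)).dual
end

p(derivative(3) { |t| t * t })             # d/dt t² at 3, prints 6
p(derivative(2) { |t| t * t * t + t * 5 }) # d/dt (t³ + 5t) at 2, prints 17
```

Until the coercion problem of slide 22 is handled, two restrictions apply:

- Ordinary numbers must stay on the right (`t * 5`, not `5 * t`).
- A block that ignores `t` and returns a plain number will fail, because Integer has no `dual` method.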
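
Slide 22 fails because Integer#+ does not know about DualNumber. Ruby's numeric operators handle an unfamiliar right-hand operand by calling `coerce` on it: `3 + x` calls `x.coerce(3)`, then retries the addition using the two elements of the returned pair. The sketch below applies that standard protocol to DualNumber. The talk may solve the problem differently.

```ruby
class DualNumber
  attr_accessor :real, :dual

  def initialize(real:, dual:)
    self.real = real
    self.dual = dual
  end

  def +(other)
    other = DualNumber(other)
    DualNumber.new real: real + other.real, dual: dual + other.dual
  end

  def *(other)
    other = DualNumber(other)
    DualNumber.new real: real * other.real,
                   dual: real * other.dual + dual * other.real
  end

  # Called by Integer#+, Integer#* and friends when the right-hand operand is
  # a DualNumber: return [left operand as a DualNumber, self] so Ruby can
  # retry the operation as, for example, DualNumber(3) + self.
  def coerce(other)
    [DualNumber(other), self]
  end

  def to_s
    [real, (dual < 0 ? '-' : '+'), dual.abs, 'ε'].join
  end

  def inspect
    "(#{to_s})"
  end
end

module Kernel
  def DualNumber(real, dual = 0)
    case real
    when DualNumber then real
    else DualNumber.new(real: real, dual: dual)
    end
  end
end

x = DualNumber(1, 2)
p(3 + (4 * x)) # slide 22's expression, prints (7+8ε)
p((x + 3) * 4) # slide 20's expression, prints (16+8ε)
```

Returning the converted left operand first preserves the operand order. That matters for any non-commutative operation you add later, such as subtraction.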
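
All the rules in slide 23 follow one pattern. Take a function f with a Taylor expansion around a, substitute a + bε, and use ε² = 0 from slide 16. Every term beyond the first derivative then vanishes:

$$
f(a + b\varepsilon)
= f(a) + f'(a)\,b\varepsilon + \frac{f''(a)}{2!}\,b^2\varepsilon^2 + \cdots
= f(a) + \bigl(b \times f'(a)\bigr)\varepsilon
$$

For example, f = sin has f′ = cos, which gives sin(a) + (b × cos(a))ε. Similarly, f = sqrt has f′(a) = 1 ÷ (2 × sqrt(a)). When b = 1, as in DualNumber(3, 1) on slides 12 and 19, the dual part of the result is exactly f′(a).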
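
Slide 24 implements only the sin and cos rows of slide 23. The sketch below extends the same prepend pattern to the exp, log and sqrt rows, using only slide 23's formulas. It is an assumption, not the talk's code:

- It uses a cut-down DualNumber class so that it runs on its own.
- It handles only the one-argument form of Math.log.

```ruby
class DualNumber
  attr_accessor :real, :dual

  def initialize(real:, dual:)
    self.real = real
    self.dual = dual
  end

  def to_s
    [real, (dual < 0 ? '-' : '+'), dual.abs, 'ε'].join
  end

  def inspect
    "(#{to_s})"
  end
end

Math.singleton_class.prepend Module.new {
  # exp(a + bε) = exp(a) + (b × exp(a))ε
  def exp(x)
    case x
    when DualNumber
      value = exp(x.real)
      DualNumber.new real: value, dual: x.dual * value
    else
      super
    end
  end

  # log(a + bε) = log(a) + (b ÷ a)ε; a call with a base argument goes to super.
  def log(x, *base)
    if x.is_a?(DualNumber) && base.empty?
      # fdiv avoids integer division when both parts are Integers.
      DualNumber.new real: log(x.real), dual: x.dual.fdiv(x.real)
    else
      super
    end
  end

  # sqrt(a + bε) = sqrt(a) + (b ÷ (2 × sqrt(a)))ε
  def sqrt(x)
    case x
    when DualNumber
      root = sqrt(x.real)
      DualNumber.new real: root, dual: x.dual / (2 * root)
    else
      super
    end
  end
}

p Math.exp(DualNumber.new(real: 0, dual: 1))  # prints (1.0+1.0ε)
p Math.log(DualNumber.new(real: 2, dual: 1))  # dual part is 0.5
p Math.sqrt(DualNumber.new(real: 4, dual: 1)) # prints (2.0+0.25ε)
```

Ordinary numbers still reach the built-in methods through `super`, so `Math.sqrt(9)` and `Math.log(8, 2)` behave as before.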
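
Slide 25 evaluates `x / 2`, but none of the slides shown define a division method for DualNumber. One way to supply it is to multiply top and bottom by the conjugate c − dε. Because ε² = 0, this gives (a + bε) ÷ (c + dε) = a ÷ c + ((b × c − a × d) ÷ c²)ε, valid for c ≠ 0. The `/` method below is an assumption about how such a method could look, not the talk's code. It uses `fdiv` to avoid Ruby's integer division.

```ruby
class DualNumber
  attr_accessor :real, :dual

  def initialize(real:, dual:)
    self.real = real
    self.dual = dual
  end

  # Assumed rule: (a + bε) ÷ (c + dε) = a ÷ c + ((b × c - a × d) ÷ c²)ε, for c ≠ 0.
  def /(other)
    other = DualNumber(other)
    DualNumber.new \
      real: real.fdiv(other.real),
      dual: (dual * other.real - real * other.dual).fdiv(other.real * other.real)
  end

  def to_s
    [real, (dual < 0 ? '-' : '+'), dual.abs, 'ε'].join
  end

  def inspect
    "(#{to_s})"
  end
end

module Kernel
  def DualNumber(real, dual = 0)
    case real
    when DualNumber then real
    else DualNumber.new(real: real, dual: dual)
    end
  end
end

x = DualNumber(Math::PI / 3, 1)
p(x / 2)                             # real part π/6, dual part 0.5
p(DualNumber(1) / DualNumber(2, 1))  # d/dt (1 ÷ t) at t = 2 is -1 ÷ 4, prints (0.5-0.25ε)
```

With slide 24's Math.cos in place, `x / 2` yields π/6 + 0.5ε. That is the argument slide 25 passes to `Math.cos`.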