Automatic differentiation in Ruby

Tom Stuart
February 08, 2016

Finding the derivative of a mathematical function on a computer can be difficult, but there’s a clever trick that makes it easy: first write a program that computes the function, then execute it under a non-standard interpretation of its values and operations. In this talk I’ll show you how that works in Ruby.

Given at the London Ruby User Group (http://lrug.org/meetings/2016/february/). A video and expanded transcript are available at https://tomstu.art/automatic-differentiation-in-ruby, and the code is available at https://github.com/tomstuart/dual_number.


Transcript

  1. time         distance(time)
     0 seconds    0 metres
     1 second     1 metre
     2 seconds    2 metres
     3 seconds    3 metres
     4 seconds    4 metres
     ⋮            ⋮

     distance(time) = time
  2. distance(time) = time × time

     time         distance(time)
     0 seconds    0 metres
     1 second     1 metre
     2 seconds    4 metres
     3 seconds    9 metres
     4 seconds    16 metres
     ⋮            ⋮
  3. time         distance(time)    speed(time)
     0 seconds    0 metres          ?
     1 second     1 metre           ?
     2 seconds    4 metres          ?
     3 seconds    9 metres          ?
     4 seconds    16 metres         ?
     ⋮            ⋮                 ⋮
  4. [Graph: distance against time. Labels: fast, cruising; 0 sec, 1 sec, 2 sec; 1 m/s, 2 m/s.]
  5. [Graph: distance against time. Labels: slow, cruising; 0 sec, 1 sec, 2 sec; 0.5 m/s, 1 m/s, 2 m/s.]
  6. time         distance(time)    speed(time)
     0 seconds    0 metres          0 m/s
     1 second     1 metre           2 m/s
     2 seconds    4 metres          4 m/s
     3 seconds    9 metres          6 m/s
     4 seconds    16 metres         8 m/s
     ⋮            ⋮                 ⋮

     distance(time) = time × time
     speed(time) = 2 × time
  7. def speed(time:)
       time_elapsed = 0.01
       distance_before = distance(time: time)
       distance_after = distance(time: time + time_elapsed)
       distance_travelled = distance_after - distance_before
       distance_travelled / time_elapsed
     end
  8. >> speed(time: 0)
     => 0.01
     >> speed(time: 1)
     => 2.0100000000000007
     >> speed(time: 2)
     => 4.009999999999891
     >> speed(time: 3)
     => 6.009999999999849
     >> speed(time: 4)
     => 8.009999999999806
  9. time         distance(time)    speed(time)
     0 seconds    0 metres          0.01 m/s
     1 second     1 metre           2.01 m/s
     2 seconds    4 metres          4.01 m/s
     3 seconds    9 metres          6.01 m/s
     4 seconds    16 metres         8.01 m/s
     ⋮            ⋮                 ⋮

     distance(time) = time × time
     speed(time) = ?
  10. class DualNumber
        attr_accessor :real, :dual

        def initialize(real:, dual:)
          self.real = real
          self.dual = dual
        end

        def to_s
          [real, (dual < 0 ? '-' : '+'), dual.abs, 'ε'].join
        end

        def inspect
          "(#{to_s})"
        end
      end
  11. module Kernel
        def DualNumber(real, dual = 0)
          case real
          when DualNumber
            real
          else
            DualNumber.new(real: real, dual: dual)
          end
        end
      end
  12. def distance(time:)
        time
      end

      >> time_now = DualNumber(3, 1)
      => (3+1ε)
      >> distance_now = distance(time: time_now)
      => (3+1ε)
      >> distance_now.real
      => 3
      >> distance_now.dual
      => 1
  13. (a + bi) + (c + di) = a + c + bi + di
                          = (a + c) + (b + d)i
  14. (a + bi) × (c + di) = (a × c) + (a × di) + (bi × c) + (bi × di)
                          = ac + (ad + bc)i + bdi²
                          = ac + (ad + bc)i + bd × -1
                          = ac + (ad + bc)i - bd
                          = (ac - bd) + (ad + bc)i
  15. (a + bε) + (c + dε) = a + c + bε + dε
                          = (a + c) + (b + d)ε
  16. (a + bε) × (c + dε) = (a × c) + (a × dε) + (bε × c) + (bε × dε)
                          = ac + (ad + bc)ε + bdε²
                          = ac + (ad + bc)ε + bd × 0
                          = ac + (ad + bc)ε
  17. class DualNumber
        def +(other)
          DualNumber.new \
            real: real + other.real,
            dual: dual + other.dual
        end

        def *(other)
          DualNumber.new \
            real: real * other.real,
            dual: real * other.dual + dual * other.real
        end
      end
  18. >> x = DualNumber(1, 2)
      => (1+2ε)
      >> y = DualNumber(3, 4)
      => (3+4ε)
      >> x + y
      => (4+6ε)
      >> x * y
      => (3+10ε)
  19. def distance(time:)
        time * time
      end

      >> time_now = DualNumber(3, 1)
      => (3+1ε)
      >> distance_now = distance(time: time_now)
      => (9+6ε)
      >> distance_now.real
      => 9
      >> distance_now.dual
      => 6
  20. >> x = DualNumber(1, 2)
      => (1+2ε)
      >> (x + 3) * 4
      NoMethodError: undefined method `dual' for 3:Fixnum
  21. class DualNumber
        def +(other)
          other = DualNumber(other)

          DualNumber.new \
            real: real + other.real,
            dual: dual + other.dual
        end

        def *(other)
          other = DualNumber(other)

          DualNumber.new \
            real: real * other.real,
            dual: real * other.dual + dual * other.real
        end
      end
  22. >> x = DualNumber(1, 2)
      => (1+2ε)
      >> 3 + (4 * x)
      TypeError: DualNumber can't be coerced into Fixnum
  23. sin(a + bε)  = sin(a) + (b × cos(a))ε
      cos(a + bε)  = cos(a) - (b × sin(a))ε
      exp(a + bε)  = exp(a) + (b × exp(a))ε
      log(a + bε)  = log(a) + (b ÷ a)ε
      sqrt(a + bε) = sqrt(a) + (b ÷ (2 × sqrt(a)))ε
      ⋮
  24. Math.singleton_class.prepend Module.new {
        def sin(x)
          case x
          when DualNumber
            DualNumber.new \
              real: sin(x.real),
              dual: x.dual * cos(x.real)
          else
            super
          end
        end

        def cos(x)
          case x
          when DualNumber
            DualNumber.new \
              real: cos(x.real),
              dual: -x.dual * sin(x.real)
          else
            super
          end
        end
      }
  25. >> x = DualNumber(Math::PI / 3, 1)
      => (1.0471975511965976+1ε)
      >> Math.sin(x)
      => (0.8660254037844386+0.5000000000000001ε)
      >> Math.sin(x) + Math.cos(x)
      => (1.3660254037844388-0.3660254037844385ε)
      >> Math.sin(x) * Math.cos(x / 2)
      => (0.75+0.21650635094610984ε)
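
Slide 22 ends on a TypeError, and the slides do not show a fix for it. One common way to handle this in Ruby is to define `coerce` on the class, which Ruby's built-in numeric operators call when the right-hand operand is a type they do not recognise. The following is a minimal sketch under that assumption, not necessarily how the talk or the dual_number repository resolves it. `DualNumber.wrap` is a hypothetical helper that stands in for the `Kernel#DualNumber` conversion from slide 11 so that the block is self-contained.

```ruby
class DualNumber
  attr_reader :real, :dual

  def initialize(real:, dual: 0)
    @real = real
    @dual = dual
  end

  # Hypothetical helper (not from the slides): leave a DualNumber alone,
  # and turn a plain number n into n + 0ε.
  def self.wrap(value)
    value.is_a?(DualNumber) ? value : new(real: value)
  end

  def +(other)
    other = DualNumber.wrap(other)
    DualNumber.new(real: real + other.real, dual: dual + other.dual)
  end

  def *(other)
    other = DualNumber.wrap(other)
    DualNumber.new(
      real: real * other.real,
      dual: real * other.dual + dual * other.real
    )
  end

  # Called by Ruby for expressions such as `3 + x` or `4 * x`, where the
  # receiver is a built-in number and x is a DualNumber. Returning
  # [converted_receiver, self] makes Ruby retry the operation as
  # converted_receiver + self (or * self).
  def coerce(other)
    [DualNumber.wrap(other), self]
  end
end

x = DualNumber.new(real: 1, dual: 2)
result = 3 + (4 * x)
puts "#{result.real} + #{result.dual}ε" # real part 7, dual part 8
```

With `coerce` in place the expression from slide 22 evaluates instead of raising: 4 × (1 + 2ε) is 4 + 8ε, and adding 3 gives 7 + 8ε.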