Automatic differentiation in Ruby

Tom Stuart
February 08, 2016

Finding the derivative of a mathematical function on a computer can be difficult, but there’s a clever trick that makes it easy: first write a program that computes the function, then execute it under a non-standard interpretation of its values and operations. In this talk I’ll show you how that works in Ruby.

Given at the London Ruby User Group (http://lrug.org/meetings/2016/february/). A video and expanded transcript are available at http://codon.com/automatic-differentiation-in-ruby, and the code is available at https://github.com/tomstuart/dual_number.

Transcript

  1. Automatic differentiation in Ruby http://codon.com/automatic-differentiation-in-ruby

  2. A clever trick that makes it easy to do differentiation on a computer

  3. What the blazes is “differentiation”?

  4. “function” = relationship between two quantities

  5. e.g. relationship between time and distance travelled

  6. [Graph “cruising”: distance against time]

  7. [Graph “pulling away”: distance against time]

  8. distance(time) = time

     time        distance(time)
     0 seconds   0 metres
     1 second    1 metre
     2 seconds   2 metres
     3 seconds   3 metres
     4 seconds   4 metres
     ⋮           ⋮

  9. distance(time) = time × time

     time        distance(time)
     0 seconds   0 metres
     1 second    1 metre
     2 seconds   4 metres
     3 seconds   9 metres
     4 seconds   16 metres
     ⋮           ⋮

  10. “differentiation” = finding out how fast a function’s result is changing

  11. distance(time) = …  → differentiate →  speed(time) = ?

  12. time        distance(time)   speed(time)
      0 seconds   0 metres         ?
      1 second    1 metre          ?
      2 seconds   4 metres         ?
      3 seconds   9 metres         ?
      4 seconds   16 metres        ?
      ⋮           ⋮                ⋮

  13. [Graph “cruising”: distance against time, speed 1 m/s]

  14. [Graph “fast cruising”: distance against time, speeds 1 m/s and 2 m/s]

  15. [Graph “slow cruising”: distance against time, speeds 0.5 m/s, 1 m/s and 2 m/s]

  16. [Graph “pulling away”: distance against time]

  17. [Graph “pulling away”: distance against time]

  18. [Graph “pulling away”: distance against time]

  19. distance(time) = …  → differentiate →  speed(time) = ?

  20. Symbolic differentiation

  21. distance(time) = time × time, speed(time) = ?

  22. (image-only slide)

  23. (image-only slide)

  24. (image-only slide)

  25. (image-only slide)

  26. distance(time) = time × time, speed(time) = ?

  27. distance(time) = time², speed(time) = ?

  28. (image-only slide)

  29. distance(time) = time², speed(time) = 2 × time¹ (by the elementary power rule)

  30. distance(time) = time × time
      speed(time) = 2 × time

      time        distance(time)   speed(time)
      0 seconds   0 metres         0 m/s
      1 second    1 metre          2 m/s
      2 seconds   4 metres         4 m/s
      3 seconds   9 metres         6 m/s
      4 seconds   16 metres        8 m/s
      ⋮           ⋮                ⋮

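For polynomials, the power rule can be applied mechanically. The sketch below is an illustration rather than code from the talk: it assumes a polynomial is stored as an array of coefficients, lowest power first, and the method names `differentiate_polynomial` and `evaluate_polynomial` are hypothetical.

```ruby
# Hypothetical representation: [c0, c1, c2] stands for c0 + c1 × t + c2 × t².
def differentiate_polynomial(coefficients)
  # Power rule: c × tⁿ becomes (n × c) × tⁿ⁻¹, so the constant term is dropped
  # and every other coefficient is multiplied by its power and shifted down.
  coefficients.each_with_index.drop(1).map { |c, n| n * c }
end

# Sums c × tⁿ over every term of the polynomial.
def evaluate_polynomial(coefficients, t)
  coefficients.each_with_index.sum { |c, n| c * t**n }
end

distance = [0, 0, 1]                        # time × time
speed = differentiate_polynomial(distance)  # [0, 2], i.e. 2 × time
speeds = (0..4).map { |t| evaluate_polynomial(speed, t) }
p speeds   # prints [0, 2, 4, 6, 8]
```

This only covers polynomials; general symbolic differentiation needs a rule for every kind of expression (products, quotients, function composition and so on).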
  31. Numerical differentiation

  32. distance(time) = time × time

  33. def distance(time:)
        time * time
      end

  34. [Graph “pulling away”: distance against time]

  35. [Graph “pulling away”: distance against time]

  36. [Graph “pulling away”: distance against time]

  37. [Graph “pulling away”: distance against time]

  38. def speed(time:)
        time_elapsed = 0.01
        distance_before = distance(time: time)
        distance_after = distance(time: time + time_elapsed)
        distance_travelled = distance_after - distance_before
        distance_travelled / time_elapsed
      end

  39. >> speed(time: 0)
      => 0.01
      >> speed(time: 1)
      => 2.0100000000000007
      >> speed(time: 2)
      => 4.009999999999891
      >> speed(time: 3)
      => 6.009999999999849
      >> speed(time: 4)
      => 8.009999999999806

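The extra 0.01 in every result is not just floating-point noise. Writing h for `time_elapsed` and working the same calculation by hand shows that, for distance(time) = time × time, this forward difference gives exactly 2 × time + h, so with h = 0.01 each speed is 0.01 too high before any rounding error is added:

```latex
\frac{\text{distance}(t + h) - \text{distance}(t)}{h}
  = \frac{(t + h)^2 - t^2}{h}
  = \frac{2th + h^2}{h}
  = 2t + h
```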
  40. distance(time) = time × time
      speed(time) = ?

      time        distance(time)   speed(time)
      0 seconds   0 metres         0.01 m/s
      1 second    1 metre          2.01 m/s
      2 seconds   4 metres         4.01 m/s
      3 seconds   9 metres         6.01 m/s
      4 seconds   16 metres        8.01 m/s
      ⋮           ⋮                ⋮

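A common way to shrink that error, which the talk does not use, is a central difference: sample the function a little before and a little after `time` and divide by the total gap. The sketch below reuses the `distance` method from slide 33; `central_speed` is a hypothetical name. For time × time the h terms cancel, so the result is 2 × time up to floating-point rounding, but for other functions it is still an approximation whose accuracy depends on the step size.

```ruby
def distance(time:)
  time * time
end

# Hypothetical variant of `speed` (not from the talk): a central difference
# looks time_elapsed either side of `time` rather than only ahead of it.
def central_speed(time:, time_elapsed: 0.01)
  distance_after  = distance(time: time + time_elapsed)
  distance_before = distance(time: time - time_elapsed)
  (distance_after - distance_before) / (2 * time_elapsed)
end

speeds = (0..4).map { |t| central_speed(time: t) }
p speeds.map { |s| s.round(9) }   # prints [0.0, 2.0, 4.0, 6.0, 8.0]
```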
  41. Automatic differentiation

  42. Big idea: calculate a function’s rate of change and its value all at once

  43. distance(time) = time × time
      input: 3 seconds, changing at 1 s/s
      output: 9 metres, changing at 6 m/s

  44. COMPLEX NUMBERS: a + bi

  45. DUAL NUMBERS: a + bε

  46. class DualNumber
        attr_accessor :real, :dual

        def initialize(real:, dual:)
          self.real = real
          self.dual = dual
        end

        def to_s
          [real, (dual < 0 ? '-' : '+'), dual.abs, 'ε'].join
        end

        def inspect
          "(#{to_s})"
        end
      end

  47. module Kernel
        def DualNumber(real, dual = 0)
          case real
          when DualNumber
            real
          else
            DualNumber.new(real: real, dual: dual)
          end
        end
      end

  48. def distance(time:)
        time
      end

      >> time_now = DualNumber(3, 1)
      => (3+1ε)
      >> distance_now = distance(time: time_now)
      => (3+1ε)
      >> distance_now.real
      => 3
      >> distance_now.dual
      => 1

  49. distance(time) = time
      input: 3 seconds, changing at 1 s/s
      output: 3 metres, changing at 1 m/s

  50. distance(time) = time × time
      input: 3 seconds, changing at 1 s/s
      output: 9 metres (previously 3 metres), changing at ? (previously 1 m/s)

  51. a + bi   (where i × i = -1)

  52. (a + bi) + (c + di) = a + c + bi + di
                          = (a + c) + (b + d)i

  53. (a + bi) × (c + di) = (a × c) + (a × di) + (bi × c) + (bi × di)
                          = ac + (ad + bc)i + bdi²
                          = ac + (ad + bc)i + bd × -1
                          = ac + (ad + bc)i - bd
                          = (ac - bd) + (ad + bc)i

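Ruby's core `Complex` class already follows this multiplication rule, so the slide's algebra can be checked directly. A quick illustration with a = 1, b = 2, c = 3, d = 4:

```ruby
x = Complex(1, 2)   # 1 + 2i
y = Complex(3, 4)   # 3 + 4i

# (ac - bd) + (ad + bc)i = (3 - 8) + (4 + 6)i
product = x * y
p product   # prints (-5+10i)
```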
  54. a + bε   (where ε × ε = 0)

  55. (a + bε) + (c + dε) = a + c + bε + dε
                          = (a + c) + (b + d)ε

  56. (a + bε) × (c + dε) = (a × c) + (a × dε) + (bε × c) + (bε × dε)
                          = ac + (ad + bc)ε + bdε²
                          = ac + (ad + bc)ε + bd × 0
                          = ac + (ad + bc)ε

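The ε² = 0 rule is what makes dual numbers carry derivatives. Squaring a + bε leaves the derivative of t² (scaled by b) in the ε part, and applying the same cancellation to a Taylor expansion gives the general identity on the second line for any smooth f, because every term containing ε² or a higher power vanishes. With b = 1 the ε part is exactly f′(a).

```latex
(a + b\varepsilon)^2 = a^2 + 2ab\,\varepsilon + b^2\varepsilon^2 = a^2 + 2ab\,\varepsilon
\\
f(a + b\varepsilon) = f(a) + f'(a)\,b\varepsilon + \frac{f''(a)}{2!}\,(b\varepsilon)^2 + \cdots = f(a) + b\,f'(a)\,\varepsilon
```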
  57. class DualNumber
        def +(other)
          DualNumber.new \
            real: real + other.real,
            dual: dual + other.dual
        end

        def *(other)
          DualNumber.new \
            real: real * other.real,
            dual: real * other.dual + dual * other.real
        end
      end

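Combining slide 46's constructor with these two methods is already enough to differentiate any expression built from + and × of dual numbers. The sketch below is a condensed, self-contained version (it uses `attr_reader` and omits `to_s`/`inspect`); the example function t × t + t is chosen only for illustration.

```ruby
class DualNumber
  attr_reader :real, :dual

  def initialize(real:, dual:)
    @real = real
    @dual = dual
  end

  # (a + bε) + (c + dε) = (a + c) + (b + d)ε
  def +(other)
    DualNumber.new(real: real + other.real, dual: dual + other.dual)
  end

  # (a + bε) × (c + dε) = ac + (ad + bc)ε, because ε × ε = 0
  def *(other)
    DualNumber.new(real: real * other.real, dual: real * other.dual + dual * other.real)
  end
end

# f(t) = t × t + t, evaluated at t = 3 with a dual part of 1
t = DualNumber.new(real: 3, dual: 1)
result = t * t + t
p [result.real, result.dual]   # prints [12, 7]: f(3) = 12 and f'(3) = 2 × 3 + 1 = 7
```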
  58. >> x = DualNumber(1, 2)
      => (1+2ε)
      >> y = DualNumber(3, 4)
      => (3+4ε)
      >> x + y
      => (4+6ε)
      >> x * y
      => (3+10ε)

  59. def distance(time:)
        time * time
      end

      >> time_now = DualNumber(3, 1)
      => (3+1ε)
      >> distance_now = distance(time: time_now)
      => (9+6ε)
      >> distance_now.real
      => 9
      >> distance_now.dual
      => 6

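The pattern on this slide (seed the input with a dual part of 1, run the unmodified function, read off the dual part) can be wrapped in a helper. The sketch below assumes a hypothetical `value_and_derivative` method that is not part of the talk or the gem, and uses a `Struct`-based stand-in for `DualNumber` so it runs on its own. Its results reproduce the speed column of the symbolic table on slide 30.

```ruby
# Minimal stand-in for DualNumber with only + and *.
DualNumber = Struct.new(:real, :dual) do
  def +(other)
    DualNumber.new(real + other.real, dual + other.dual)
  end

  def *(other)
    DualNumber.new(real * other.real, real * other.dual + dual * other.real)
  end
end

# Hypothetical helper: evaluates the block at x with a dual part of 1
# and returns [value, derivative].
def value_and_derivative(x)
  result = yield DualNumber.new(x, 1)
  [result.real, result.dual]
end

def distance(time:)
  time * time
end

table = (0..4).map { |t| value_and_derivative(t) { |time| distance(time: time) } }
p table   # prints [[0, 0], [1, 2], [4, 4], [9, 6], [16, 8]]
```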
  60. distance(time) = time × time
      input: 3 seconds, changing at 1 s/s
      output: 9 metres, changing at 6 m/s

  61. >> x = DualNumber(1, 2)
      => (1+2ε)
      >> (x + 3) * 4
      NoMethodError: undefined method `dual' for 3:Fixnum

  62. class DualNumber
        def +(other)
          other = DualNumber(other)
          DualNumber.new \
            real: real + other.real,
            dual: dual + other.dual
        end

        def *(other)
          other = DualNumber(other)
          DualNumber.new \
            real: real * other.real,
            dual: real * other.dual + dual * other.real
        end
      end

  63. >> x = DualNumber(1, 2)
      => (1+2ε)
      >> (x + 3) * 4
      => (16+8ε)

  64. >> x = DualNumber(1, 2)
      => (1+2ε)
      >> 3 + (4 * x)
      TypeError: DualNumber can't be coerced into Fixnum

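  65. [Diagram, symbols lost in extraction: how Ruby evaluates 3 + x when x is a DualNumber. Ruby calls x.coerce(3), which returns a pair [3 converted to a DualNumber, x], then calls + on the first element of the pair with the second as its argument.]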

  66. class DualNumber
        def coerce(other)
          [DualNumber(other), self]
        end
      end

  67. >> x = DualNumber(1, 2)
      => (1+2ε)
      >> 3 + (4 * x)
      => (7+8ε)

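Slides 46, 47, 62 and 66 together make a working class. The sketch below assembles them (dropping `to_s`/`inspect` and using `attr_reader`) so that both failing examples can be rerun, and it also calls `coerce` by hand to show the pair Ruby works with when the left-hand operand is a plain `Integer`.

```ruby
module Kernel
  # Conversion function from slide 47: wrap a plain number, pass DualNumbers through.
  def DualNumber(real, dual = 0)
    case real
    when DualNumber
      real
    else
      DualNumber.new(real: real, dual: dual)
    end
  end
end

class DualNumber
  attr_reader :real, :dual

  def initialize(real:, dual:)
    @real = real
    @dual = dual
  end

  def +(other)
    other = DualNumber(other)
    DualNumber.new(real: real + other.real, dual: dual + other.dual)
  end

  def *(other)
    other = DualNumber(other)
    DualNumber.new(real: real * other.real, dual: real * other.dual + dual * other.real)
  end

  # Integer#+ and Integer#* call this when their argument is a DualNumber;
  # Ruby then retries the operation as first_element.op(second_element).
  def coerce(other)
    [DualNumber(other), self]
  end
end

x = DualNumber(1, 2)
a = (x + 3) * 4      # DualNumber on the left: handled by DualNumber#+ and #*
b = 3 + (4 * x)      # Integer on the left: goes through DualNumber#coerce
pair = x.coerce(4)   # [4 as a DualNumber, x]

p [a.real, a.dual]               # prints [16, 8]
p [b.real, b.dual]               # prints [7, 8]
p [pair[0].real, pair[0].dual]   # prints [4, 0]
```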
  68. sin(a + bε)  = sin(a) + (b × cos(a))ε
      cos(a + bε)  = cos(a) - (b × sin(a))ε
      exp(a + bε)  = exp(a) + (b × exp(a))ε
      log(a + bε)  = log(a) + (b ÷ a)ε
      sqrt(a + bε) = sqrt(a) + (b ÷ (2 × sqrt(a)))ε
      ⋮

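Slide 69 only patches `sin` and `cos`. The remaining rows of this table can be added in the same style; the sketch below is an assumption about how that might look rather than code from the talk, and it uses a `Struct` stand-in for `DualNumber`. Two details matter: `Math.log` accepts an optional base, so the override passes any extra arguments through to `super`, and `fdiv` keeps b ÷ a from turning into integer division when both parts are integers.

```ruby
DualNumber = Struct.new(:real, :dual)

Math.singleton_class.prepend Module.new {
  # exp(a + bε) = exp(a) + (b × exp(a))ε
  def exp(x)
    return super unless x.is_a?(DualNumber)
    DualNumber.new(exp(x.real), x.dual * exp(x.real))
  end

  # log(a + bε) = log(a) + (b ÷ a)ε; an explicit base is only supported for plain numbers
  def log(x, *base)
    return super unless x.is_a?(DualNumber) && base.empty?
    DualNumber.new(log(x.real), x.dual.fdiv(x.real))
  end

  # sqrt(a + bε) = sqrt(a) + (b ÷ (2 × sqrt(a)))ε
  def sqrt(x)
    return super unless x.is_a?(DualNumber)
    DualNumber.new(sqrt(x.real), x.dual.fdiv(2 * sqrt(x.real)))
  end
}

p Math.exp(DualNumber.new(0, 1)).to_a    # prints [1.0, 1.0]
p Math.log(DualNumber.new(2, 1)).dual    # prints 0.5
p Math.sqrt(DualNumber.new(4, 1)).to_a   # prints [2.0, 0.25]
p Math.log(8, 2).round(12)               # prints 3.0 (plain numbers are unaffected)
```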
  69. Math.singleton_class.prepend Module.new {
        def sin(x)
          case x
          when DualNumber
            DualNumber.new \
              real: sin(x.real),
              dual: x.dual * cos(x.real)
          else
            super
          end
        end

        def cos(x)
          case x
          when DualNumber
            DualNumber.new \
              real: cos(x.real),
              dual: -x.dual * sin(x.real)
          else
            super
          end
        end
      }

  70. >> x = DualNumber(Math::PI / 3, 1)
      => (1.0471975511965976+1ε)
      >> Math.sin(x)
      => (0.8660254037844386+0.5000000000000001ε)
      >> Math.sin(x) + Math.cos(x)
      => (1.3660254037844388-0.3660254037844385ε)
      >> Math.sin(x) * Math.cos(x / 2)
      => (0.75+0.21650635094610984ε)

  71. def distance(time:)
        Math.sin(time)
      end

  72. (image-only slide)

  73. def distance(time:)
        Math.sin(time) * 0.8 + Math.cos(time * 5) / 5
      end

  74. (image-only slide)

  75. def distance(time:)
        time * Math.sin(time * time) + 1
      end
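The dual part can be checked against a derivative worked out by hand: by the product and chain rules, the derivative of time × sin(time × time) + 1 is sin(time²) + 2 × time² × cos(time²). The sketch below assumes a `Struct` stand-in for `DualNumber` with the slide 62/66 conversions (through a hypothetical `lift` class method standing in for slide 47's `DualNumber()`) and the slide 69 `sin`/`cos` patch, and compares the two at an arbitrarily chosen time of 1.5 seconds.

```ruby
DualNumber = Struct.new(:real, :dual) do
  # Hypothetical helper: plain numbers become dual numbers with a dual part of 0.
  def self.lift(value)
    value.is_a?(self) ? value : new(value, 0)
  end

  def +(other)
    other = DualNumber.lift(other)
    DualNumber.new(real + other.real, dual + other.dual)
  end

  def *(other)
    other = DualNumber.lift(other)
    DualNumber.new(real * other.real, real * other.dual + dual * other.real)
  end

  def coerce(other)
    [DualNumber.lift(other), self]
  end
end

# sin and cos for dual numbers, following slide 69.
Math.singleton_class.prepend Module.new {
  def sin(x)
    return super unless x.is_a?(DualNumber)
    DualNumber.new(sin(x.real), x.dual * cos(x.real))
  end

  def cos(x)
    return super unless x.is_a?(DualNumber)
    DualNumber.new(cos(x.real), -x.dual * sin(x.real))
  end
}

def distance(time:)
  time * Math.sin(time * time) + 1
end

t = 1.5
result = distance(time: DualNumber.new(t, 1))
by_hand = Math.sin(t * t) + 2 * t * t * Math.cos(t * t)
error = (result.dual - by_hand).abs
p error < 1e-9   # prints true
```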

  76. (image-only slide)

  77. tomstuart/dual_number
      gem install dual_number

  78. Thanks! http://codon.com/automatic-differentiation-in-ruby