Automatic differentiation in Ruby

Tom Stuart
February 08, 2016

Finding the derivative of a mathematical function on a computer can be difficult, but there’s a clever trick that makes it easy: first write a program that computes the function, then execute it under a non-standard interpretation of its values and operations. In this talk I’ll show you how that works in Ruby.

Given at the London Ruby User Group (http://lrug.org/meetings/2016/february/). A video and expanded transcript are available at https://tomstu.art/automatic-differentiation-in-ruby, and the code is available at https://github.com/tomstuart/dual_number.

Transcript

  1. Automatic
    differentiation
    in Ruby
    http://codon.com/automatic-differentiation-in-ruby

  2. A clever trick that
    makes it easy to
    do differentiation
    on a computer

  3. What the blazes is
    “differentiation”?

  4. “function” =
    relationship
    between two
    quantities

  5. e.g. relationship
    between time and
    distance travelled

  6. cruising
     [graph: distance against time, 0 to 2 sec]

  7. pulling away
     [graph: distance against time, 0 to 2 sec]

  8. time         distance(time)
     0 seconds    0 metres
     1 second     1 metre
     2 seconds    2 metres
     3 seconds    3 metres
     4 seconds    4 metres
     ⋮            ⋮
     distance(time) = time

  9. distance(time) = time × time
     time         distance(time)
     0 seconds    0 metres
     1 second     1 metre
     2 seconds    4 metres
     3 seconds    9 metres
     4 seconds    16 metres
     ⋮            ⋮

  10. “differentiation” =
    finding out how
    fast a function’s
    result is changing

  11. distance(time) = …
    speed(time) = ?
    differentiate

  12. time         distance(time)   speed(time)
      0 seconds    0 metres         ?
      1 second     1 metre          ?
      2 seconds    4 metres         ?
      3 seconds    9 metres         ?
      4 seconds    16 metres        ?
      ⋮            ⋮                ⋮

  13. cruising
      [graph: distance against time, 0 to 2 sec; gradient labelled 1 m/s]

  14. fast cruising
      [graph: distance against time, 0 to 2 sec; gradients labelled 1 m/s and 2 m/s]

  15. slow cruising
      [graph: distance against time, 0 to 2 sec; gradients labelled 0.5 m/s, 1 m/s and 2 m/s]

  16. pulling away
      [graph: distance against time, 0 to 2 sec]

  17. pulling away
      [graph: distance against time, 0 to 2 sec]

  18. pulling away
      [graph: distance against time, 0 to 2 sec]

  19. distance(time) = …
    speed(time) = ?
    differentiate

  20. Symbolic
    differentiation

  21. distance(time) = time × time
      speed(time) = ?

  26. distance(time) = time × time
      speed(time) = ?

  27. distance(time) = time²
      speed(time) = ?

  29. distance(time) = time²
      speed(time) = 2 × time¹
      (by the elementary power rule)

  30. time         distance(time)   speed(time)
      0 seconds    0 metres         0 m/s
      1 second     1 metre          2 m/s
      2 seconds    4 metres         4 m/s
      3 seconds    9 metres         6 m/s
      4 seconds    16 metres        8 m/s
      ⋮            ⋮                ⋮
      distance(time) = time × time
      speed(time) = 2 × time

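Slides 21 to 29 find speed(time) by rewriting the formula by hand. As a rough illustration of doing that mechanically (not code from the talk), here is a sketch that represents a formula as a nested Ruby array, a hypothetical format chosen only for this example, and applies the sum and product rules to it. It uses the product rule rather than slide 29's power rule, but the resulting formula evaluates to the same speeds, 2 × time:

```ruby
# Hypothetical formula format (illustration only): a Numeric constant,
# the symbol :time, or an array [operator, left, right] where the
# operator is :+ or :*.

# Differentiate a formula with respect to :time.
def differentiate(formula)
  case formula
  when Numeric then 0 # constants don't change
  when :time then 1   # time changes at 1 s/s
  else
    operator, left, right = formula
    case operator
    when :+ # sum rule: (f + g)' = f' + g'
      [:+, differentiate(left), differentiate(right)]
    when :* # product rule: (f × g)' = f' × g + f × g'
      [:+, [:*, differentiate(left), right], [:*, left, differentiate(right)]]
    else
      raise ArgumentError, "unknown operator #{operator.inspect}"
    end
  end
end

# Evaluate a formula at a particular time.
def evaluate(formula, time)
  case formula
  when Numeric then formula
  when :time then time
  else
    operator, left, right = formula
    evaluate(left, time).public_send(operator, evaluate(right, time))
  end
end

distance_formula = [:*, :time, :time]
speed_formula = differentiate(distance_formula)
p speed_formula                                          # => [:+, [:*, 1, :time], [:*, :time, 1]]
p((0..4).map { |time| evaluate(speed_formula, time) })   # => [0, 2, 4, 6, 8]
```

The derived formula is left unsimplified (1 × time + time × 1); tidying it into 2 × time would need a separate set of simplification rules.
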
  31. Numerical
    differentiation

  32. distance(time) = time × time

  33. def distance(time:)
        time * time
      end

  34. pulling away
      [graph: distance against time, 0 to 2 sec]

  35. pulling away
      [graph: distance against time, 0 to 2 sec]

  36. pulling away
      [graph: distance against time, 0 to 2 sec]

  37. pulling away
      [graph: distance against time, 0 to 2 sec]

  38. def speed(time:)
        time_elapsed = 0.01
        distance_before = distance(time: time)
        distance_after = distance(time: time + time_elapsed)
        distance_travelled = distance_after - distance_before
        distance_travelled / time_elapsed
      end

  39. >> speed(time: 0)
    => 0.01
    >> speed(time: 1)
    => 2.0100000000000007
    >> speed(time: 2)
    => 4.009999999999891
    >> speed(time: 3)
    => 6.009999999999849
    >> speed(time: 4)
    => 8.009999999999806

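Every speed on slide 39 is too high by about 0.01 m/s, and that is no coincidence. For distance(time) = time × time, the forward difference is ((time + h)² − time²) ÷ h = 2 × time + h, so the error equals the step size h (time_elapsed on slide 38), apart from floating-point rounding. A small sketch to check this, assuming a hypothetical time_elapsed: keyword argument added to slide 38's speed method:

```ruby
def distance(time:)
  time * time
end

# Slide 38's forward difference, with the step size exposed as a
# (hypothetical) keyword argument so that it can be varied.
def speed(time:, time_elapsed: 0.01)
  distance_before = distance(time: time)
  distance_after = distance(time: time + time_elapsed)
  (distance_after - distance_before) / time_elapsed
end

# The exact speed at 3 seconds is 6 m/s; the forward difference
# overshoots by roughly time_elapsed.
[0.1, 0.01, 0.001].each do |time_elapsed|
  error = speed(time: 3, time_elapsed: time_elapsed) - 6
  puts "time_elapsed = #{time_elapsed}: error = #{error}"
end
```

Shrinking time_elapsed reduces this error, but only up to a point: with very small steps, distance_after − distance_before subtracts two nearly equal floating-point numbers, and rounding error starts to dominate.
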
  40. time         distance(time)   speed(time)
      0 seconds    0 metres         0.01 m/s
      1 second     1 metre          2.01 m/s
      2 seconds    4 metres         4.01 m/s
      3 seconds    9 metres         6.01 m/s
      4 seconds    16 metres        8.01 m/s
      ⋮            ⋮                ⋮
      distance(time) = time × time
      speed(time) = ?

  41. Automatic
    differentiation

  42. Big idea:
    calculate a function’s
    rate of change and its
    value all at once

  43. time        distance
      3 seconds   9 metres
      1 s/s       6 m/s
      distance(time) = time × time

  44. COMPLEX NUMBERS
      a + bi

  45. DUAL NUMBERS
      a + bε

  46. class DualNumber
        attr_accessor :real, :dual
        def initialize(real:, dual:)
          self.real = real
          self.dual = dual
        end
        def to_s
          [real, (dual < 0 ? '-' : '+'), dual.abs, 'ε'].join
        end
        def inspect
          "(#{to_s})"
        end
      end

  47. module Kernel
        def DualNumber(real, dual = 0)
          case real
          when DualNumber
            real
          else
            DualNumber.new(real: real, dual: dual)
          end
        end
      end

  48. def distance(time:)
        time
      end
    >> time_now = DualNumber(3, 1)
    => (3+1ε)
    >> distance_now = distance(time: time_now)
    => (3+1ε)
    >> distance_now.real
    => 3
    >> distance_now.dual
    => 1

  49. time        distance
      3 seconds   3 metres
      1 s/s       1 m/s
      distance(time) = time

  50. time        distance
      3 seconds   9 metres (previously 3 metres)
      1 s/s       ? (previously 1 m/s)
      distance(time) = time × time

  51. a + bi   (i × i = -1)

  52. (a + bi) + (c + di)
        = a + c + bi + di
        = (a + c) + (b + d)i

  53. (a + bi) × (c + di)
        = (a × c) + (a × di) + (bi × c) + (bi × di)
        = ac + (ad + bc)i + bdi²
        = ac + (ad + bc)i + bd × -1
        = ac + (ad + bc)i - bd
        = (ac - bd) + (ad + bc)i

  54. a + bε   (ε × ε = 0)

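Why does the ε part track the rate of change? For a function f that has a Taylor series around a, substituting a + bε and using ε × ε = 0 removes every term of order ε² and above, leaving the value and the first derivative side by side:

```latex
f(a + b\varepsilon)
  = f(a) + f'(a)\,b\varepsilon + \frac{f''(a)}{2!}\,b^2\varepsilon^2 + \cdots
  = f(a) + b\,f'(a)\,\varepsilon
```

For example, with f(time) = time × time, a = 3 and b = 1, this gives 9 + 6ε: 9 metres, changing at 6 m/s.
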
  55. (a + bε) + (c + dε)
        = a + c + bε + dε
        = (a + c) + (b + d)ε

  56. (a + bε) × (c + dε)
        = (a × c) + (a × dε) + (bε × c) + (bε × dε)
        = ac + (ad + bc)ε + bdε²
        = ac + (ad + bc)ε + bd × 0
        = ac + (ad + bc)ε

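Slides 53 and 56 use the same algebra and differ only in what happens to the bd term. A sketch that writes both multiplication rules over plain [a, b] pairs (complex_multiply and dual_multiply are hypothetical names for this comparison), and checks the complex rule against Ruby's built-in Complex class:

```ruby
# Hypothetical helpers for this comparison: each multiplies two numbers
# represented as [a, b] pairs.

# (a + bi) × (c + di) = (ac - bd) + (ad + bc)i, because i × i = -1
def complex_multiply((a, b), (c, d))
  [a * c - b * d, a * d + b * c]
end

# (a + bε) × (c + dε) = ac + (ad + bc)ε, because ε × ε = 0
def dual_multiply((a, b), (c, d))
  [a * c, a * d + b * c]
end

p complex_multiply([1, 2], [3, 4])   # => [-5, 10]
p dual_multiply([1, 2], [3, 4])      # => [3, 10]
puts Complex(1, 2) * Complex(3, 4)   # prints -5+10i
```

The dual result [3, 10] matches the (3+10ε) that the DualNumber class produces for the same pair of numbers. Its second component, ad + bc, is the product rule: if b and d are the rates of change of a and c, then ad + bc is the rate of change of their product ac.
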
  57. class DualNumber
        def +(other)
          DualNumber.new \
            real: real + other.real,
            dual: dual + other.dual
        end
        def *(other)
          DualNumber.new \
            real: real * other.real,
            dual: real * other.dual + dual * other.real
        end
      end

  58. >> x = DualNumber(1, 2)
    => (1+2ε)
    >> y = DualNumber(3, 4)
    => (3+4ε)
    >> x + y
    => (4+6ε)
    >> x * y
    => (3+10ε)

  59. def distance(time:)
        time * time
      end
    >> time_now = DualNumber(3, 1)
    => (3+1ε)
    >> distance_now = distance(time: time_now)
    => (9+6ε)
    >> distance_now.real
    => 9
    >> distance_now.dual
    => 6

  60. time        distance
      3 seconds   9 metres
      1 s/s       6 m/s
      distance(time) = time × time

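Slides 48 and 59 always seed the dual part of time with 1, because time changes at 1 second per second. A minimal sketch of wrapping that convention in a helper; value_and_derivative is a hypothetical name, not part of the talk or the dual_number gem, and the DualNumber class is cut down to the + and * from slide 62:

```ruby
# DualNumber cut down to what this example needs (slides 46, 47 and 62)
class DualNumber
  attr_accessor :real, :dual

  def initialize(real:, dual:)
    self.real = real
    self.dual = dual
  end

  def +(other)
    other = DualNumber(other)
    DualNumber.new(real: real + other.real, dual: dual + other.dual)
  end

  def *(other)
    other = DualNumber(other)
    DualNumber.new(real: real * other.real,
                   dual: real * other.dual + dual * other.real)
  end
end

module Kernel
  def DualNumber(real, dual = 0)
    real.is_a?(DualNumber) ? real : DualNumber.new(real: real, dual: dual)
  end
end

# Hypothetical helper: call the block with a dual number whose dual part is 1
# (time changes at 1 s/s) and return the function's value and its rate of
# change together. Wrapping the result in DualNumber() means a block that
# returns a plain constant reports a rate of change of 0.
def value_and_derivative(at:)
  result = DualNumber(yield(DualNumber(at, 1)))
  [result.real, result.dual]
end

p(value_and_derivative(at: 3) { |time| time * time })          # => [9, 6]
p(value_and_derivative(at: 2) { |time| time * time * time })   # => [8, 12]
```
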
  61. >> x = DualNumber(1, 2)
    => (1+2ε)
    >> (x + 3) * 4
    NoMethodError: undefined method `dual' for 3:Fixnum

  62. class DualNumber
        def +(other)
          other = DualNumber(other)
          DualNumber.new \
            real: real + other.real,
            dual: dual + other.dual
        end
        def *(other)
          other = DualNumber(other)
          DualNumber.new \
            real: real * other.real,
            dual: real * other.dual + dual * other.real
        end
      end

  63. >> x = DualNumber(1, 2)
    => (1+2ε)
    >> (x + 3) * 4
    => (16+8ε)

  64. >> x = DualNumber(1, 2)
    => (1+2ε)
    >> 3 + (4 * x)
    TypeError: DualNumber can't be coerced into Fixnum

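  65. Evaluating 3 + x when x is a DualNumber:
      3.+(x)               Integer#+ can't handle x, so it calls x.coerce(3)
      x.coerce(3)          returns [DualNumber(3), x]
      DualNumber(3).+(x)   the addition is then done between two dual numbers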

  66. class DualNumber
        def coerce(other)
          [DualNumber(other), self]
        end
      end

  67. >> x = DualNumber(1, 2)
    => (1+2ε)
    >> 3 + (4 * x)
    => (7+8ε)

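With coerce in place, dual numbers can appear on either side of an arithmetic operator. The sketch below puts slides 46, 47, 62 and 66 together in one file (condensing the multi-line DualNumber.new calls) and tries them on a small polynomial chosen for illustration; its derivative is 6 × x + 2, so at x = 2 the expected result is 17 + 14ε:

```ruby
class DualNumber
  attr_accessor :real, :dual

  def initialize(real:, dual:)
    self.real = real
    self.dual = dual
  end

  def to_s
    [real, (dual < 0 ? '-' : '+'), dual.abs, 'ε'].join
  end

  def inspect
    "(#{to_s})"
  end

  def +(other)
    other = DualNumber(other) # convert plain numbers, as on slide 62
    DualNumber.new(real: real + other.real, dual: dual + other.dual)
  end

  def *(other)
    other = DualNumber(other)
    DualNumber.new(real: real * other.real,
                   dual: real * other.dual + dual * other.real)
  end

  # Called by Integer#+, Integer#* etc. when a DualNumber is on the right (slide 66)
  def coerce(other)
    [DualNumber(other), self]
  end
end

module Kernel
  def DualNumber(real, dual = 0)
    real.is_a?(DualNumber) ? real : DualNumber.new(real: real, dual: dual)
  end
end

# A polynomial chosen for illustration; its derivative is 6 × x + 2.
def polynomial(x)
  3 * x * x + 2 * x + 1
end

x = DualNumber(1, 2)
p(3 + (4 * x))                    # => (7+8ε)
p(polynomial(DualNumber(2, 1)))   # => (17+14ε)
p([DualNumber(1, 2), 3].sum)      # => (4+2ε)
```

Array#sum works here because it starts from the Integer 0, so its first addition goes through coerce.
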
  68. sin(a + bε)  = sin(a) + (b × cos(a))ε
      cos(a + bε)  = cos(a) - (b × sin(a))ε
      exp(a + bε)  = exp(a) + (b × exp(a))ε
      log(a + bε)  = log(a) + (b ÷ a)ε
      sqrt(a + bε) = sqrt(a) + (b ÷ (2 × sqrt(a)))ε

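Slide 69 implements only the first two rules from slide 68. A sketch of the other three in the same prepend style, written for this note rather than taken from the talk or the gem. Two details are assumptions of this sketch: dividing by x.real.to_f stops Integer division from truncating b ÷ a to zero, and log keeps accepting Ruby's optional base argument for ordinary numbers:

```ruby
class DualNumber
  attr_accessor :real, :dual

  def initialize(real:, dual:)
    self.real = real
    self.dual = dual
  end
end

module Kernel
  def DualNumber(real, dual = 0)
    real.is_a?(DualNumber) ? real : DualNumber.new(real: real, dual: dual)
  end
end

Math.singleton_class.prepend Module.new {
  # exp(a + bε) = exp(a) + (b × exp(a))ε
  def exp(x)
    return super unless x.is_a?(DualNumber)
    DualNumber.new(real: exp(x.real), dual: x.dual * exp(x.real))
  end

  # log(a + bε) = log(a) + (b ÷ a)ε; only the one-argument form is handled
  # for dual numbers, and everything else goes to the original Math.log
  def log(x, *base)
    return super unless x.is_a?(DualNumber) && base.empty?
    DualNumber.new(real: log(x.real), dual: x.dual / x.real.to_f)
  end

  # sqrt(a + bε) = sqrt(a) + (b ÷ (2 × sqrt(a)))ε
  def sqrt(x)
    return super unless x.is_a?(DualNumber)
    DualNumber.new(real: sqrt(x.real), dual: x.dual / (2 * sqrt(x.real)))
  end
}

puts Math.exp(DualNumber(0, 1)).dual    # prints 1.0
puts Math.log(DualNumber(2, 1)).dual    # prints 0.5
puts Math.sqrt(DualNumber(4, 1)).dual   # prints 0.25
```
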
  69. Math.singleton_class.prepend Module.new {
        def sin(x)
          case x
          when DualNumber
            DualNumber.new \
              real: sin(x.real),
              dual: x.dual * cos(x.real)
          else
            super
          end
        end
        def cos(x)
          case x
          when DualNumber
            DualNumber.new \
              real: cos(x.real),
              dual: -x.dual * sin(x.real)
          else
            super
          end
        end
      }

  70. >> x = DualNumber(Math::PI / 3, 1)
    => (1.0471975511965976+1ε)
    >> Math.sin(x)
    => (0.8660254037844386+0.5000000000000001ε)
    >> Math.sin(x) + Math.cos(x)
    => (1.3660254037844388-0.3660254037844385ε)
    >> Math.sin(x) * Math.cos(x / 2)
    => (0.75+0.21650635094610984ε)

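Slide 70 evaluates x / 2, but none of the slides define division on dual numbers. Using ε × ε = 0 again, multiplying top and bottom by (c − dε) gives (a + bε) ÷ (c + dε) = a ÷ c + ((bc − ad) ÷ c²)ε. A sketch of division, plus subtraction and negation, under that rule; this is an assumption about what is needed, and the dual_number gem's actual implementation may differ. The last lines reproduce the final example from slide 70:

```ruby
class DualNumber
  attr_accessor :real, :dual

  def initialize(real:, dual:)
    self.real = real
    self.dual = dual
  end

  def +(other)
    other = DualNumber(other)
    DualNumber.new(real: real + other.real, dual: dual + other.dual)
  end

  # (a + bε) - (c + dε) = (a - c) + (b - d)ε
  def -(other)
    other = DualNumber(other)
    DualNumber.new(real: real - other.real, dual: dual - other.dual)
  end

  # -(a + bε) = -a - bε
  def -@
    DualNumber.new(real: -real, dual: -dual)
  end

  def *(other)
    other = DualNumber(other)
    DualNumber.new(real: real * other.real,
                   dual: real * other.dual + dual * other.real)
  end

  # (a + bε) ÷ (c + dε) = a ÷ c + ((bc - ad) ÷ c²)ε; fdiv always returns a
  # Float, so Integer operands aren't truncated by integer division
  def /(other)
    other = DualNumber(other)
    DualNumber.new(real: real.fdiv(other.real),
                   dual: (dual * other.real - real * other.dual).fdiv(other.real**2))
  end

  def coerce(other)
    [DualNumber(other), self]
  end
end

module Kernel
  def DualNumber(real, dual = 0)
    real.is_a?(DualNumber) ? real : DualNumber.new(real: real, dual: dual)
  end
end

# sin and cos for dual numbers, as on slide 69
Math.singleton_class.prepend Module.new {
  def sin(x)
    return super unless x.is_a?(DualNumber)
    DualNumber.new(real: sin(x.real), dual: x.dual * cos(x.real))
  end

  def cos(x)
    return super unless x.is_a?(DualNumber)
    DualNumber.new(real: cos(x.real), dual: -x.dual * sin(x.real))
  end
}

x = DualNumber(Math::PI / 3, 1)
y = Math.sin(x) * Math.cos(x / 2)
puts "#{y.real} + #{y.dual}ε" # should be close to slide 70's (0.75+0.21650635094610984ε)
```
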
  71. def distance(time:)
        Math.sin(time)
      end

  73. def distance(time:)
        Math.sin(time) * 0.8 + Math.cos(time * 5) / 5
      end

  75. def distance(time:)
        time * Math.sin(time * time) + 1
      end

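As a check on slide 75, the product and chain rules give the speed by hand as sin(time²) + 2 × time² × cos(time²). A sketch comparing that formula with the dual part computed automatically, reusing the + and * from slide 62 and the sin and cos from slide 69 (speed_by_hand is a hypothetical helper written only for this comparison):

```ruby
class DualNumber
  attr_accessor :real, :dual

  def initialize(real:, dual:)
    self.real = real
    self.dual = dual
  end

  def +(other)
    other = DualNumber(other)
    DualNumber.new(real: real + other.real, dual: dual + other.dual)
  end

  def *(other)
    other = DualNumber(other)
    DualNumber.new(real: real * other.real,
                   dual: real * other.dual + dual * other.real)
  end
end

module Kernel
  def DualNumber(real, dual = 0)
    real.is_a?(DualNumber) ? real : DualNumber.new(real: real, dual: dual)
  end
end

# sin and cos for dual numbers, as on slide 69
Math.singleton_class.prepend Module.new {
  def sin(x)
    return super unless x.is_a?(DualNumber)
    DualNumber.new(real: sin(x.real), dual: x.dual * cos(x.real))
  end

  def cos(x)
    return super unless x.is_a?(DualNumber)
    DualNumber.new(real: cos(x.real), dual: -x.dual * sin(x.real))
  end
}

# The function from slide 75
def distance(time:)
  time * Math.sin(time * time) + 1
end

# Hypothetical hand-derived speed (product rule plus chain rule)
def speed_by_hand(time)
  Math.sin(time * time) + 2 * time * time * Math.cos(time * time)
end

[0.5, 1.0, 1.5, 2.0].each do |time|
  automatic = distance(time: DualNumber(time, 1)).dual
  puts format('%.1f s: automatic %.6f m/s, by hand %.6f m/s',
              time, automatic, speed_by_hand(time))
end
```
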
  77. tomstuart/dual_number
    gem install dual_number

  78. Thanks!
    http://codon.com/automatic-differentiation-in-ruby
