Upgrade to Pro — share decks privately, control downloads, hide ads and more …

Chainer の Trainer 解説と NStepLSTM について

Chainer の Trainer 解説と NStepLSTM について

レトリバセミナー 2017/03/15
Movie: https://www.youtube.com/watch?v=ok_bvPKAEaM

Kei Shiratsuchi

March 17, 2017
Tweet

More Decks by Kei Shiratsuchi

Other Decks in Technology

Transcript

  1. $IBJOFS ͷ 5SBJOFSղઆͱ
    /4UFQ-45. ʹ͍ͭͯ
    株式会社レトリバ
    © 2017 Retrieva, Inc.

    View Slide

  2. ࣗݾ঺հ
    • ⽩⼟慧(シラツチ ケイ)
    • 株式会社レトリバ
    • 2016年4⽉⼊社
    • Ruby on Rails / JavaScript
    • フロントエンド側の⼈間
    • ⼤学時代は複雑ネットワーク科学の研究
    • Chainer ⼊⾨中
    © 2017 Retrieva, Inc. 2

    View Slide

  3. ΞδΣϯμ
    • 第1部 Chainer における Trainer の解説
    • 第2部 NStepLSTM との格闘
    © 2017 Retrieva, Inc. 3

    View Slide

  4. Ξϯέʔτ
    • Chainer を使っている⽅
    • Chainer の Trainer を使っている⽅
    • LSTM を使っている⽅
    • NStepLSTM を使っている⽅
    • NStepLSTM と Trainer を使っている⽅
    © 2017 Retrieva, Inc. 4

    View Slide

  5. ୈ̍෦ $IBJOFS ʹ͓͚Δ 5SBJOFS
    © 2017 Retrieva, Inc. 5

    View Slide

  6. $IBJOFS ʹ͓͚Δ 5SBJOFS
    • Chainer 1.11.0 から導⼊された学習フレームワーク
    • batchの取り出し、forward/backward が抽象化されている
    • 進捗表⽰、モデルのスナップショットなど
    • Trainer 後から⼊⾨した⼈(私も)は、MNIST のサンプルが
    Trainerで抽象化されていて、何が起きているのかわからない
    • 以前から Chainer を使っている⼈は、Trainer なしで動かして
    いることが多い
    © 2017 Retrieva, Inc. 6

    View Slide

  7. 5SBJOFSશମਤ [email protected]

    © 2017 Retrieva, Inc. 7
    Trainer
    Updater (StandardUpdater)
    Iterator
    Optimizer
    Classifier
    model (MLP)
    train dataset
    Evaluator
    Iterator
    test dataset
    Extensions
    • dump_graph, snapshot
    • LogReport, PrintReport
    • ProgressBar
    • converter
    • loss_func
    • device
    • converter
    • device

    View Slide

  8. 5SBJOFS
    • Trainer フレームワークの⼤元
    • 渡された Updater、(必要があれば) Evaluator を実⾏する
    • グラフのダンプ、スナップショット、レポーティング、進捗表
    ⽰などを、Extension として実⾏できる
    © 2017 Retrieva, Inc. 8

    View Slide

  9. 5SBJOFS
    • 指定した epoch 数になるまで、Updater の update() を呼ぶ
    © 2017 Retrieva, Inc. 9
    # examples/mnist/train_mnist.py
    trainer = training.Trainer(updater, (args.epoch, 'epoch'), out=args.out)
    # chainer/training/trainer.py
    class Trainer(object):
    def run(self):
    update = self.updater.update
    # main training loop
    try:
    while not stop_trigger(self):
    self.observation = {}
    with reporter.scope(self.observation):
    update()

    View Slide

  10. 6QEBUFS
    © 2017 Retrieva, Inc. 10
    Trainer
    Updater (StandardUpdater)
    Iterator
    Optimizer
    Classifier
    model (MLP)
    train dataset
    Evaluator
    Iterator
    test dataset
    Extensions
    • dump_graph, snapshot
    • LogReport, PrintReport
    • ProgressBar
    • converter
    • loss_func
    • device
    • converter
    • device

    View Slide

  11. 6QEBUFS
    • ⼊⼒を逐次実⾏する
    • ⼊⼒の Iterator と、Optimizer を持つ
    • Iterator から⼀つずつデータを読み込み、変換し、Optimizer に
    かける
    © 2017 Retrieva, Inc. 11

    View Slide

  12. 6QEBUFS
    © 2017 Retrieva, Inc. 12
    # examples/mnist/train_mnist.py
    updater = training.StandardUpdater(train_iter, optimizer, device=args.gpu)
    # chainer/training/updater.py
    class StandardUpdater(Updater):
    def update_core(self):
    batch = self._iterators['main'].next()
    in_arrays = self.converter(batch, self.device)
    optimizer = self._optimizers['main']
    loss_func = self.loss_func or optimizer.target
    if isinstance(in_arrays, tuple):
    in_vars = tuple(variable.Variable(x) for x in in_arrays)
    optimizer.update(loss_func, *in_vars)

    View Slide

  13. 6QEBUFS
    © 2017 Retrieva, Inc. 13
    # examples/mnist/train_mnist.py
    updater = training.StandardUpdater(train_iter, optimizer, device=args.gpu)
    # chainer/training/updater.py
    class StandardUpdater(Updater):
    def update_core(self):
    batch = self._iterators['main'].next()
    in_arrays = self.converter(batch, self.device)
    optimizer = self._optimizers['main']
    loss_func = self.loss_func or optimizer.target
    if isinstance(in_arrays, tuple):
    in_vars = tuple(variable.Variable(x) for x in in_arrays)
    optimizer.update(loss_func, *in_vars)
    Iterator から
    ⼀つ呼び出す

    View Slide

  14. 6QEBUFS
    © 2017 Retrieva, Inc. 14
    # examples/mnist/train_mnist.py
    updater = training.StandardUpdater(train_iter, optimizer, device=args.gpu)
    # chainer/training/updater.py
    class StandardUpdater(Updater):
    def update_core(self):
    batch = self._iterators['main'].next()
    in_arrays = self.converter(batch, self.device)
    optimizer = self._optimizers['main']
    loss_func = self.loss_func or optimizer.target
    if isinstance(in_arrays, tuple):
    in_vars = tuple(variable.Variable(x) for x in in_arrays)
    optimizer.update(loss_func, *in_vars)
    Iterator から
    ⼀つ呼び出す
    Converter にかける
    (変換し、to_gpu する)

    View Slide

  15. 6QEBUFS
    © 2017 Retrieva, Inc. 15
    # examples/mnist/train_mnist.py
    updater = training.StandardUpdater(train_iter, optimizer, device=args.gpu)
    # chainer/training/updater.py
    class StandardUpdater(Updater):
    def update_core(self):
    batch = self._iterators['main'].next()
    in_arrays = self.converter(batch, self.device)
    optimizer = self._optimizers['main']
    loss_func = self.loss_func or optimizer.target
    if isinstance(in_arrays, tuple):
    in_vars = tuple(variable.Variable(x) for x in in_arrays)
    optimizer.update(loss_func, *in_vars)
    Iterator から
    ⼀つ呼び出す
    Converter にかける
    (変換し、to_gpu する)
    Optimizer の
    update を呼ぶ

    View Slide

  16. *UFSBUPS
    © 2017 Retrieva, Inc. 16
    Trainer
    Updater (StandardUpdater)
    Iterator
    Optimizer
    Classifier
    model (MLP)
    train dataset
    Evaluator
    Iterator
    test dataset
    Extensions
    • dump_graph, snapshot
    • LogReport, PrintReport
    • ProgressBar
    • converter
    • loss_func
    • device
    • converter
    • device

    View Slide

  17. *UFSBUPS
    © 2017 Retrieva, Inc. 17
    # chainer/iterators/serial_iterator.py
    class SerialIterator(iterator.Iterator):
    def __next__(self):
    ...
    return batch
    @property
    def epoch_detail(self):
    return self.epoch + self.current_position / len(self.dataset)
    # examples/mnist/train_mnist.py
    train, test = chainer.datasets.get_mnist()
    train_iter = chainer.iterators.SerialIterator(train, args.batchsize)
    • Iterator として、batch を返す
    • 回している回数の管理をする

    View Slide

  18. 0QUJNJ[FS
    © 2017 Retrieva, Inc. 18
    Trainer
    Updater (StandardUpdater)
    Iterator
    Optimizer
    Classifier
    model (MLP)
    train dataset
    Evaluator
    Iterator
    test dataset
    Extensions
    • dump_graph, snapshot
    • LogReport, PrintReport
    • ProgressBar
    • converter
    • loss_func
    • device
    • converter
    • device

    View Slide

  19. 0QUJNJ[FS
    • ⼊⼒データを model に forward し、返り値の loss を
    backward する
    • 最適化アルゴリズムごとに実装がある
    • SGD, MomentumSGD, Adam, …
    • Optimizer で抽象化されている
    © 2017 Retrieva, Inc. 19

    View Slide

  20. 0QUJNJ[FS
    © 2017 Retrieva, Inc. 20
    # chainer/training/updater.py
    loss_func = self.loss_func or optimizer.target
    optimizer.update(loss_func, *in_vars)
    # examples/mnist/train_mnist.py
    optimizer = chainer.optimizers.Adam()
    optimizer.setup(model)
    # chainer/optimizer.py
    class GradientMethod(Optimizer):
    def update(self, lossfun=None, *args, **kwds):
    if lossfun is not None:
    use_cleargrads = getattr(self, '_use_cleargrads', False)
    loss = lossfun(*args, **kwds)
    if use_cleargrads:
    self.target.cleargrads()
    else:
    self.target.zerograds()
    loss.backward()

    View Slide

  21. 0QUJNJ[FS
    © 2017 Retrieva, Inc. 21
    # chainer/training/updater.py
    loss_func = self.loss_func or optimizer.target
    optimizer.update(loss_func, *in_vars)
    # examples/mnist/train_mnist.py
    optimizer = chainer.optimizers.Adam()
    optimizer.setup(model)
    # chainer/optimizer.py
    class GradientMethod(Optimizer):
    def update(self, lossfun=None, *args, **kwds):
    if lossfun is not None:
    use_cleargrads = getattr(self, '_use_cleargrads', False)
    loss = lossfun(*args, **kwds)
    if use_cleargrads:
    self.target.cleargrads()
    else:
    self.target.zerograds()
    loss.backward()
    target は
    渡された model
    (Classifier)

    View Slide

  22. 0QUJNJ[FS
    © 2017 Retrieva, Inc. 22
    # chainer/training/updater.py
    loss_func = self.loss_func or optimizer.target
    optimizer.update(loss_func, *in_vars)
    # examples/mnist/train_mnist.py
    optimizer = chainer.optimizers.Adam()
    optimizer.setup(model)
    # chainer/optimizer.py
    class GradientMethod(Optimizer):
    def update(self, lossfun=None, *args, **kwds):
    if lossfun is not None:
    use_cleargrads = getattr(self, '_use_cleargrads', False)
    loss = lossfun(*args, **kwds)
    if use_cleargrads:
    self.target.cleargrads()
    else:
    self.target.zerograds()
    loss.backward()
    target は
    渡された model
    (Classifier)
    model に forward

    View Slide

  23. 0QUJNJ[FS
    © 2017 Retrieva, Inc. 23
    # chainer/training/updater.py
    loss_func = self.loss_func or optimizer.target
    optimizer.update(loss_func, *in_vars)
    # examples/mnist/train_mnist.py
    optimizer = chainer.optimizers.Adam()
    optimizer.setup(model)
    # chainer/optimizer.py
    class GradientMethod(Optimizer):
    def update(self, lossfun=None, *args, **kwds):
    if lossfun is not None:
    use_cleargrads = getattr(self, '_use_cleargrads', False)
    loss = lossfun(*args, **kwds)
    if use_cleargrads:
    self.target.cleargrads()
    else:
    self.target.zerograds()
    loss.backward()
    target は
    渡された model
    (Classifier)
    model に forward
    backward を実⾏

    View Slide

  24. $MBTTJGJFS
    © 2017 Retrieva, Inc. 24
    Trainer
    Updater (StandardUpdater)
    Iterator
    Optimizer
    Classifier
    model (MLP)
    train dataset
    Evaluator
    Iterator
    test dataset
    Extensions
    • dump_graph, snapshot
    • LogReport, PrintReport
    • ProgressBar
    • converter
    • loss_func
    • device
    • converter
    • device

    View Slide

  25. $MBTTJGJFS
    • 教師あり学習⽤の model のラッパー
    • ⼊⼒と正解データから、loss と accuracy を計算する
    © 2017 Retrieva, Inc. 25

    View Slide

  26. $MBTTJGJFS
    © 2017 Retrieva, Inc. 26
    # examples/mnist/train_mnist.py
    model = L.Classifier(MLP(args.unit, 10))
    # chainer/links/model/classifier.py
    class Classifier(link.Chain):
    def __init__(self, predictor,
    lossfun=softmax_cross_entropy.softmax_cross_entropy,
    accfun=accuracy.accuracy):
    def __call__(self, *args):
    self.y = self.predictor(*x)
    self.loss = self.lossfun(self.y, t)
    reporter.report({'loss': self.loss}, self)
    if self.compute_accuracy:
    self.accuracy = self.accfun(self.y, t)
    reporter.report({'accuracy': self.accuracy}, self)
    return self.loss

    View Slide

  27. $MBTTJGJFS
    © 2017 Retrieva, Inc. 27
    # examples/mnist/train_mnist.py
    model = L.Classifier(MLP(args.unit, 10))
    # chainer/links/model/classifier.py
    class Classifier(link.Chain):
    def __init__(self, predictor,
    lossfun=softmax_cross_entropy.softmax_cross_entropy,
    accfun=accuracy.accuracy):
    def __call__(self, *args):
    self.y = self.predictor(*x)
    self.loss = self.lossfun(self.y, t)
    reporter.report({'loss': self.loss}, self)
    if self.compute_accuracy:
    self.accuracy = self.accfun(self.y, t)
    reporter.report({'accuracy': self.accuracy}, self)
    return self.loss
    損失関数を指定する

    View Slide

  28. $MBTTJGJFS
    © 2017 Retrieva, Inc. 28
    # examples/mnist/train_mnist.py
    model = L.Classifier(MLP(args.unit, 10))
    # chainer/links/model/classifier.py
    class Classifier(link.Chain):
    def __init__(self, predictor,
    lossfun=softmax_cross_entropy.softmax_cross_entropy,
    accfun=accuracy.accuracy):
    def __call__(self, *args):
    self.y = self.predictor(*x)
    self.loss = self.lossfun(self.y, t)
    reporter.report({'loss': self.loss}, self)
    if self.compute_accuracy:
    self.accuracy = self.accfun(self.y, t)
    reporter.report({'accuracy': self.accuracy}, self)
    return self.loss
    損失関数を指定する
    model に forward

    View Slide

  29. $MBTTJGJFS
    © 2017 Retrieva, Inc. 29
    # examples/mnist/train_mnist.py
    model = L.Classifier(MLP(args.unit, 10))
    # chainer/links/model/classifier.py
    class Classifier(link.Chain):
    def __init__(self, predictor,
    lossfun=softmax_cross_entropy.softmax_cross_entropy,
    accfun=accuracy.accuracy):
    def __call__(self, *args):
    self.y = self.predictor(*x)
    self.loss = self.lossfun(self.y, t)
    reporter.report({'loss': self.loss}, self)
    if self.compute_accuracy:
    self.accuracy = self.accfun(self.y, t)
    reporter.report({'accuracy': self.accuracy}, self)
    return self.loss
    損失関数を指定する
    model に forward
    loss を算出
    accuracy を算出

    View Slide

  30. $MBTTJGJFS
    © 2017 Retrieva, Inc. 30
    # examples/mnist/train_mnist.py
    model = L.Classifier(MLP(args.unit, 10))
    # chainer/links/model/classifier.py
    class Classifier(link.Chain):
    def __init__(self, predictor,
    lossfun=softmax_cross_entropy.softmax_cross_entropy,
    accfun=accuracy.accuracy):
    def __call__(self, *args):
    self.y = self.predictor(*x)
    self.loss = self.lossfun(self.y, t)
    reporter.report({'loss': self.loss}, self)
    if self.compute_accuracy:
    self.accuracy = self.accfun(self.y, t)
    reporter.report({'accuracy': self.accuracy}, self)
    return self.loss
    損失関数を指定する
    model に forward
    loss を算出
    accuracy を算出
    loss を返す

    View Slide

  31. &WBMVBUPS
    © 2017 Retrieva, Inc. 31
    Trainer
    Updater (StandardUpdater)
    Iterator
    Optimizer
    Classifier
    model (MLP)
    train dataset
    Evaluator
    Iterator
    test dataset
    Extensions
    • dump_graph, snapshot
    • LogReport, PrintReport
    • ProgressBar
    • converter
    • loss_func
    • device
    • converter
    • device

    View Slide

  32. &WBMVBUPS
    • テストデータに対して、loss, accuracy などを計算し、検証す

    • epoch ごとに、現在まで学習された model に対して検証する
    • ⼤まかには、Updater と対応している
    © 2017 Retrieva, Inc. 32

    View Slide

  33. &WBMVBUPS
    © 2017 Retrieva, Inc. 33
    # examples/mnist/train_mnist.py
    test_iter = chainer.iterators.SerialIterator(test, args.batchsize,
    repeat=False, shuffle=False)
    trainer.extend(extensions.Evaluator(test_iter, model, device=args.gpu))
    # chainer/training/extensions/evaluator.py
    class Evaluator(extension.Extension):
    def evaluate(self):
    iterator = self._iterators['main']
    target = self._targets['main']
    eval_func = self.eval_func or target
    it = copy.copy(iterator)
    for batch in it:
    observation = {}
    with reporter_module.report_scope(observation):
    in_arrays = self.converter(batch, self.device)
    if isinstance(in_arrays, tuple):
    in_vars = tuple(variable.Variable(x, volatile='on')
    for x in in_arrays)
    eval_func(*in_vars)

    View Slide

  34. &WBMVBUPS
    © 2017 Retrieva, Inc. 34
    # examples/mnist/train_mnist.py
    test_iter = chainer.iterators.SerialIterator(test, args.batchsize,
    repeat=False, shuffle=False)
    trainer.extend(extensions.Evaluator(test_iter, model, device=args.gpu))
    # chainer/training/extensions/evaluator.py
    class Evaluator(extension.Extension):
    def evaluate(self):
    iterator = self._iterators['main']
    target = self._targets['main']
    eval_func = self.eval_func or target
    it = copy.copy(iterator)
    for batch in it:
    observation = {}
    with reporter_module.report_scope(observation):
    in_arrays = self.converter(batch, self.device)
    if isinstance(in_arrays, tuple):
    in_vars = tuple(variable.Variable(x, volatile='on')
    for x in in_arrays)
    eval_func(*in_vars)
    Iterator から
    全部呼び出す

    View Slide

  35. &WBMVBUPS
    © 2017 Retrieva, Inc. 35
    # examples/mnist/train_mnist.py
    test_iter = chainer.iterators.SerialIterator(test, args.batchsize,
    repeat=False, shuffle=False)
    trainer.extend(extensions.Evaluator(test_iter, model, device=args.gpu))
    # chainer/training/extensions/evaluator.py
    class Evaluator(extension.Extension):
    def evaluate(self):
    iterator = self._iterators['main']
    target = self._targets['main']
    eval_func = self.eval_func or target
    it = copy.copy(iterator)
    for batch in it:
    observation = {}
    with reporter_module.report_scope(observation):
    in_arrays = self.converter(batch, self.device)
    if isinstance(in_arrays, tuple):
    in_vars = tuple(variable.Variable(x, volatile='on')
    for x in in_arrays)
    eval_func(*in_vars)
    Iterator から
    全部呼び出す
    Converter にかける
    (変換し、to_gpu する)

    View Slide

  36. &WBMVBUPS
    © 2017 Retrieva, Inc. 36
    # examples/mnist/train_mnist.py
    test_iter = chainer.iterators.SerialIterator(test, args.batchsize,
    repeat=False, shuffle=False)
    trainer.extend(extensions.Evaluator(test_iter, model, device=args.gpu))
    # chainer/training/extensions/evaluator.py
    class Evaluator(extension.Extension):
    def evaluate(self):
    iterator = self._iterators['main']
    target = self._targets['main']
    eval_func = self.eval_func or target
    it = copy.copy(iterator)
    for batch in it:
    observation = {}
    with reporter_module.report_scope(observation):
    in_arrays = self.converter(batch, self.device)
    if isinstance(in_arrays, tuple):
    in_vars = tuple(variable.Variable(x, volatile='on')
    for x in in_arrays)
    eval_func(*in_vars)
    Iterator から
    全部呼び出す
    Converter にかける
    (変換し、to_gpu する)
    model に forward

    View Slide

  37. આ໌͍ͯ͠ͳ͍͜ͱ
    • Reporter 周り
    • Evaluator で、eval_func しているが戻り値を使っていない理由
    • (ざっくり⾔うと)Classifier 内で、reporter に loss, accuracy を登
    録している
    • Extension 周り
    © 2017 Retrieva, Inc. 37

    View Slide

  38. ୈ̎෦ /4USFQ-45. ͱͷ֨ಆ
    © 2017 Retrieva, Inc. 38

    View Slide

  39. /4UFQ-45. ͱ͸
    • RNN のための、Chainer 1.16.0 で導⼊された Link
    • cuDNN の恩恵を受けて、⾼速に動く
    • 既存の LSTM と使い⽅が違う
    • 既存の LSTM のサンプルは examples/ptb/train_ptb.py
    © 2017 Retrieva, Inc. 39

    View Slide

  40. 3//
    • Recurrent Neural Network
    • 並び⽅に意味のある、「系列データ」を扱う場合に⽤いられる
    • 応⽤例:⽂章の推定、⾳声認識、変動する数値の推定
    • 例:⽂章が途中まで与えられた時、次の単語を予測する問題
    © 2017 Retrieva, Inc. 40
    私 は ⽩い ⽝ が ?
    x1 x2 x3 x4 x5
    y1 y2 y3 y4 y5
    x1〜x5を⼊⼒データと
    して、y5を推定する。

    View Slide

  41. -45.ͱ /4UFQ-45.
    データ1 1 2
    データ1ラベル A B
    データ2 1 2 3
    データ2ラベル A B C
    © 2017 Retrieva, Inc. 41
    • LSTM(逐次渡す)
    • x1: Variable[1, 1]
    • t1: Variable[B, B]
    • x2: Variable[2, 2]
    • t2: Variable[0, C]
    • NStepLSTM(⼀度に渡す)
    • xs: [Variable[1,2], Variable[1,2,3]]
    • ts: [Variable[A,B], Variable[A,B,C]]

    View Slide

  42. -45.ͱ /4UFQ-45.
    データ1 1 2
    データ1ラベル A B
    データ2 1 2 3
    データ2ラベル A B C
    © 2017 Retrieva, Inc. 42
    • LSTM(逐次渡す)
    • x1: Variable[1, 1]
    • t1: Variable[B, B]
    • x2: Variable[2, 2]
    • t2: Variable[0, C]
    • NStepLSTM(⼀度に渡す)
    • xs: [Variable[1,2], Variable[1,2,3]]
    • ts: [Variable[A,B], Variable[A,B,C]]
    ⻑さが合っていない時、
    0 などで埋める必要がある
    ⻑さを合わせなくて良い
    Variable の list を渡す

    View Slide

  43. /4UFQ-45. αϯϓϧ
    © 2017 Retrieva, Inc. 43
    class RNNNStepLSTM(chainer.Chain):
    def __init__(self, n_layer, n_units, train=True):
    super(RNNNStepLSTM, self).__init__(
    l1 = L.NStepLSTM(n_layer, n_units, n_units, 0.5, True),
    )
    self.n_layer = n_layer
    self.n_units = n_units
    def __call__(self, xs):
    xp = self.xp
    hx = chainer.Variable(xp.zeros(
    (self.n_layer, len(xs), self.n_units), dtype=xp.float32))
    cx = chainer.Variable(xp.zeros(
    (self.n_layer, len(xs), self.n_units), dtype=xp.float32))
    hy, cy, ys = self.l1(hx, cx, xs, train=self.train)

    View Slide

  44. /4UFQ-45. αϯϓϧ
    © 2017 Retrieva, Inc. 44
    class RNNNStepLSTM(chainer.Chain):
    def __init__(self, n_layer, n_units, train=True):
    super(RNNNStepLSTM, self).__init__(
    l1 = L.NStepLSTM(n_layer, n_units, n_units, 0.5, True),
    )
    self.n_layer = n_layer
    self.n_units = n_units
    def __call__(self, xs):
    xp = self.xp
    hx = chainer.Variable(xp.zeros(
    (self.n_layer, len(xs), self.n_units), dtype=xp.float32))
    cx = chainer.Variable(xp.zeros(
    (self.n_layer, len(xs), self.n_units), dtype=xp.float32))
    hy, cy, ys = self.l1(hx, cx, xs, train=self.train)
    レイヤー数、ユニット数、
    Dropout を指定する
    パラメータの
    初期状態を作成し、
    渡す
    Variable のリストを
    ⼊⼒する
    出⼒も Variable の
    リスト

    View Slide

  45. /4UFQ-45. ͱɺඪ४తͳ 5SBJOFSͷᴥᴪ
    • 標準的な Trainer の構成では、model には Variable を渡す
    • NStepLSTM では、「Variable のリスト」を渡さなければいけない
    © 2017 Retrieva, Inc. 45
    # chainer/training/updater.py
    class StandardUpdater(Updater):
    def update_core(self):
    batch = self._iterators['main'].next()
    in_arrays = self.converter(batch, self.device)
    ...
    if isinstance(in_arrays, tuple):
    in_vars = tuple(variable.Variable(x) for x in in_arrays)
    optimizer.update(loss_func, *in_vars)
    Variableを作成し、
    そのまま渡している

    View Slide

  46. /4UFQ-45. ͱɺඪ४తͳ 5SBJOFSͷᴥᴪ
    • loss の計算を、既存のメソッドに任せられない
    © 2017 Retrieva, Inc. 46
    # chainer/links/model/classifier.py
    class Classifier(link.Chain):
    def __init__(self, predictor,
    lossfun=softmax_cross_entropy.softmax_cross_entropy,
    accfun=accuracy.accuracy):
    def __call__(self, *args):
    self.y = self.predictor(*x)
    self.loss = self.lossfun(self.y, t)
    reporter.report({'loss': self.loss}, self)
    if self.compute_accuracy:
    self.accuracy = self.accfun(self.y, t)
    reporter.report({'accuracy': self.accuracy}, self)
    return self.loss
    NStepLSTM の出⼒だと、
    y は Variable のリスト
    accuracy も同様

    View Slide

  47. [email protected] Λ /4UFQ-45. Ͱ
    • https://github.com/kei-s/chainer-ptb-
    nsteplstm/blob/master/train_ptb_nstep.py
    • できるだけ構成を同じにして(Trainer の上で)、
    train_ptb.py を NStepLSTM を使って実装してみる
    • Disclaimer
    • testモード(100件)では完⾛したけど、全データは⾛らせていない
    • 無駄なコードはありそう…
    © 2017 Retrieva, Inc. 47

    View Slide

  48. [email protected] Λ /4UFQ-45. Ͱ
    • モデル
    • EmbedID にかけるため、⼀度 concat し、split_axis でまた分ける
    • 系列のそれぞれの要素を Linear にかける
    • Iterator
    • bprop_len を Iterator に渡し、バッチで系列化したものを返す
    • Converter
    • NStepLSTM を使った seq2seq https://github.com/pfnet/chainer/pull/2070 を参照
    • Updater
    • 系列をそのままモデルに渡すように変更
    • Evaluator
    • 系列をそのままモデルに渡すように変更
    • Lossfun
    • softmax_cross_entropy をそれぞれの系列に対してかけ、⾜し合わせる
    © 2017 Retrieva, Inc. 48

    View Slide

  49. [email protected] Λ /4UFQ-45. Ͱ
    • ⾼速化したか?
    • Estimated time では 6時間 → 3時間
    • とはいえ、実際に⾛らせると 3時間以上かかりそう
    • Estimated Time に Evaluator 部分が考慮されてなさそう
    • ⾼速化のために必要なこと
    • Chainer もしくは Numpy の世界で処理を終わらせること
    • Python の世界でループを回すとかなり遅くなる
    • Lossfun でループを回しているが、concat して渡してもよさそう(互換性が微
    妙な予感)
    © 2017 Retrieva, Inc. 49

    View Slide

  50. © 2017 Retrieva, Inc.

    View Slide