Upgrade to Pro — share decks privately, control downloads, hide ads and more …

Chainer の Trainer 解説と NStepLSTM について

Chainer の Trainer 解説と NStepLSTM について

レトリバセミナー 2017/03/15
Movie: https://www.youtube.com/watch?v=ok_bvPKAEaM

Kei Shiratsuchi

March 17, 2017
Tweet

More Decks by Kei Shiratsuchi

Other Decks in Technology

Transcript

  1. ࣗݾ঺հ • ⽩⼟慧(シラツチ ケイ) • 株式会社レトリバ • 2016年4⽉⼊社 • Ruby

    on Rails / JavaScript • フロントエンド側の⼈間 • ⼤学時代は複雑ネットワーク科学の研究 • Chainer ⼊⾨中 © 2017 Retrieva, Inc. 2
  2. Ξϯέʔτ • Chainer を使っている⽅ • Chainer の Trainer を使っている⽅ •

    LSTM を使っている⽅ • NStepLSTM を使っている⽅ • NStepLSTM と Trainer を使っている⽅ © 2017 Retrieva, Inc. 4
  3. $IBJOFS ʹ͓͚Δ 5SBJOFS • Chainer 1.11.0 から導⼊された学習フレームワーク • batchの取り出し、forward/backward が抽象化されている

    • 進捗表⽰、モデルのスナップショットなど • Trainer 後から⼊⾨した⼈(私も)は、MNIST のサンプルが Trainerで抽象化されていて、何が起きているのかわからない • 以前から Chainer を使っている⼈は、Trainer なしで動かして いることが多い © 2017 Retrieva, Inc. 6
  4. 5SBJOFSશମਤ USBJO@NOJTUQZ © 2017 Retrieva, Inc. 7 Trainer Updater (StandardUpdater)

    Iterator Optimizer Classifier model (MLP) train dataset Evaluator Iterator test dataset Extensions • dump_graph, snapshot • LogReport, PrintReport • ProgressBar • converter • loss_func • device • converter • device
  5. 5SBJOFS • Trainer フレームワークの⼤元 • 渡された Updater、(必要があれば) Evaluator を実⾏する •

    グラフのダンプ、スナップショット、レポーティング、進捗表 ⽰などを、Extension として実⾏できる © 2017 Retrieva, Inc. 8
  6. 5SBJOFS • 指定した epoch 数になるまで、Updater の update() を呼ぶ © 2017

    Retrieva, Inc. 9 # examples/mnist/train_mnist.py trainer = training.Trainer(updater, (args.epoch, 'epoch'), out=args.out) # chainer/training/trainer.py class Trainer(object): def run(self): update = self.updater.update # main training loop try: while not stop_trigger(self): self.observation = {} with reporter.scope(self.observation): update()
  7. 6QEBUFS © 2017 Retrieva, Inc. 10 Trainer Updater (StandardUpdater) Iterator

    Optimizer Classifier model (MLP) train dataset Evaluator Iterator test dataset Extensions • dump_graph, snapshot • LogReport, PrintReport • ProgressBar • converter • loss_func • device • converter • device
  8. 6QEBUFS • ⼊⼒を逐次実⾏する • ⼊⼒の Iterator と、Optimizer を持つ • Iterator

    から⼀つずつデータを読み込み、変換し、Optimizer に かける © 2017 Retrieva, Inc. 11
  9. 6QEBUFS © 2017 Retrieva, Inc. 12 # examples/mnist/train_mnist.py updater =

    training.StandardUpdater(train_iter, optimizer, device=args.gpu) # chainer/training/updater.py class StandardUpdater(Updater): def update_core(self): batch = self._iterators['main'].next() in_arrays = self.converter(batch, self.device) optimizer = self._optimizers['main'] loss_func = self.loss_func or optimizer.target if isinstance(in_arrays, tuple): in_vars = tuple(variable.Variable(x) for x in in_arrays) optimizer.update(loss_func, *in_vars)
  10. 6QEBUFS © 2017 Retrieva, Inc. 13 # examples/mnist/train_mnist.py updater =

    training.StandardUpdater(train_iter, optimizer, device=args.gpu) # chainer/training/updater.py class StandardUpdater(Updater): def update_core(self): batch = self._iterators['main'].next() in_arrays = self.converter(batch, self.device) optimizer = self._optimizers['main'] loss_func = self.loss_func or optimizer.target if isinstance(in_arrays, tuple): in_vars = tuple(variable.Variable(x) for x in in_arrays) optimizer.update(loss_func, *in_vars) Iterator から ⼀つ呼び出す
  11. 6QEBUFS © 2017 Retrieva, Inc. 14 # examples/mnist/train_mnist.py updater =

    training.StandardUpdater(train_iter, optimizer, device=args.gpu) # chainer/training/updater.py class StandardUpdater(Updater): def update_core(self): batch = self._iterators['main'].next() in_arrays = self.converter(batch, self.device) optimizer = self._optimizers['main'] loss_func = self.loss_func or optimizer.target if isinstance(in_arrays, tuple): in_vars = tuple(variable.Variable(x) for x in in_arrays) optimizer.update(loss_func, *in_vars) Iterator から ⼀つ呼び出す Converter にかける (変換し、to_gpu する)
  12. 6QEBUFS © 2017 Retrieva, Inc. 15 # examples/mnist/train_mnist.py updater =

    training.StandardUpdater(train_iter, optimizer, device=args.gpu) # chainer/training/updater.py class StandardUpdater(Updater): def update_core(self): batch = self._iterators['main'].next() in_arrays = self.converter(batch, self.device) optimizer = self._optimizers['main'] loss_func = self.loss_func or optimizer.target if isinstance(in_arrays, tuple): in_vars = tuple(variable.Variable(x) for x in in_arrays) optimizer.update(loss_func, *in_vars) Iterator から ⼀つ呼び出す Converter にかける (変換し、to_gpu する) Optimizer の update を呼ぶ
  13. *UFSBUPS © 2017 Retrieva, Inc. 16 Trainer Updater (StandardUpdater) Iterator

    Optimizer Classifier model (MLP) train dataset Evaluator Iterator test dataset Extensions • dump_graph, snapshot • LogReport, PrintReport • ProgressBar • converter • loss_func • device • converter • device
  14. *UFSBUPS © 2017 Retrieva, Inc. 17 # chainer/iterators/serial_iterator.py class SerialIterator(iterator.Iterator):

    def __next__(self): ... return batch @property def epoch_detail(self): return self.epoch + self.current_position / len(self.dataset) # examples/mnist/train_mnist.py train, test = chainer.datasets.get_mnist() train_iter = chainer.iterators.SerialIterator(train, args.batchsize) • Iterator として、batch を返す • 回している回数の管理をする
  15. 0QUJNJ[FS © 2017 Retrieva, Inc. 18 Trainer Updater (StandardUpdater) Iterator

    Optimizer Classifier model (MLP) train dataset Evaluator Iterator test dataset Extensions • dump_graph, snapshot • LogReport, PrintReport • ProgressBar • converter • loss_func • device • converter • device
  16. 0QUJNJ[FS • ⼊⼒データを model に forward し、返り値の loss を backward

    する • 最適化アルゴリズムごとに実装がある • SGD, MomentumSGD, Adam, … • Optimizer で抽象化されている © 2017 Retrieva, Inc. 19
  17. 0QUJNJ[FS © 2017 Retrieva, Inc. 20 # chainer/training/updater.py loss_func =

    self.loss_func or optimizer.target optimizer.update(loss_func, *in_vars) # examples/mnist/train_mnist.py optimizer = chainer.optimizers.Adam() optimizer.setup(model) # chainer/optimizer.py class GradientMethod(Optimizer): def update(self, lossfun=None, *args, **kwds): if lossfun is not None: use_cleargrads = getattr(self, '_use_cleargrads', False) loss = lossfun(*args, **kwds) if use_cleargrads: self.target.cleargrads() else: self.target.zerograds() loss.backward()
  18. 0QUJNJ[FS © 2017 Retrieva, Inc. 21 # chainer/training/updater.py loss_func =

    self.loss_func or optimizer.target optimizer.update(loss_func, *in_vars) # examples/mnist/train_mnist.py optimizer = chainer.optimizers.Adam() optimizer.setup(model) # chainer/optimizer.py class GradientMethod(Optimizer): def update(self, lossfun=None, *args, **kwds): if lossfun is not None: use_cleargrads = getattr(self, '_use_cleargrads', False) loss = lossfun(*args, **kwds) if use_cleargrads: self.target.cleargrads() else: self.target.zerograds() loss.backward() target は 渡された model (Classifier)
  19. 0QUJNJ[FS © 2017 Retrieva, Inc. 22 # chainer/training/updater.py loss_func =

    self.loss_func or optimizer.target optimizer.update(loss_func, *in_vars) # examples/mnist/train_mnist.py optimizer = chainer.optimizers.Adam() optimizer.setup(model) # chainer/optimizer.py class GradientMethod(Optimizer): def update(self, lossfun=None, *args, **kwds): if lossfun is not None: use_cleargrads = getattr(self, '_use_cleargrads', False) loss = lossfun(*args, **kwds) if use_cleargrads: self.target.cleargrads() else: self.target.zerograds() loss.backward() target は 渡された model (Classifier) model に forward
  20. 0QUJNJ[FS © 2017 Retrieva, Inc. 23 # chainer/training/updater.py loss_func =

    self.loss_func or optimizer.target optimizer.update(loss_func, *in_vars) # examples/mnist/train_mnist.py optimizer = chainer.optimizers.Adam() optimizer.setup(model) # chainer/optimizer.py class GradientMethod(Optimizer): def update(self, lossfun=None, *args, **kwds): if lossfun is not None: use_cleargrads = getattr(self, '_use_cleargrads', False) loss = lossfun(*args, **kwds) if use_cleargrads: self.target.cleargrads() else: self.target.zerograds() loss.backward() target は 渡された model (Classifier) model に forward backward を実⾏
  21. $MBTTJGJFS © 2017 Retrieva, Inc. 24 Trainer Updater (StandardUpdater) Iterator

    Optimizer Classifier model (MLP) train dataset Evaluator Iterator test dataset Extensions • dump_graph, snapshot • LogReport, PrintReport • ProgressBar • converter • loss_func • device • converter • device
  22. $MBTTJGJFS © 2017 Retrieva, Inc. 26 # examples/mnist/train_mnist.py model =

    L.Classifier(MLP(args.unit, 10)) # chainer/links/model/classifier.py class Classifier(link.Chain): def __init__(self, predictor, lossfun=softmax_cross_entropy.softmax_cross_entropy, accfun=accuracy.accuracy): def __call__(self, *args): self.y = self.predictor(*x) self.loss = self.lossfun(self.y, t) reporter.report({'loss': self.loss}, self) if self.compute_accuracy: self.accuracy = self.accfun(self.y, t) reporter.report({'accuracy': self.accuracy}, self) return self.loss
  23. $MBTTJGJFS © 2017 Retrieva, Inc. 27 # examples/mnist/train_mnist.py model =

    L.Classifier(MLP(args.unit, 10)) # chainer/links/model/classifier.py class Classifier(link.Chain): def __init__(self, predictor, lossfun=softmax_cross_entropy.softmax_cross_entropy, accfun=accuracy.accuracy): def __call__(self, *args): self.y = self.predictor(*x) self.loss = self.lossfun(self.y, t) reporter.report({'loss': self.loss}, self) if self.compute_accuracy: self.accuracy = self.accfun(self.y, t) reporter.report({'accuracy': self.accuracy}, self) return self.loss 損失関数を指定する
  24. $MBTTJGJFS © 2017 Retrieva, Inc. 28 # examples/mnist/train_mnist.py model =

    L.Classifier(MLP(args.unit, 10)) # chainer/links/model/classifier.py class Classifier(link.Chain): def __init__(self, predictor, lossfun=softmax_cross_entropy.softmax_cross_entropy, accfun=accuracy.accuracy): def __call__(self, *args): self.y = self.predictor(*x) self.loss = self.lossfun(self.y, t) reporter.report({'loss': self.loss}, self) if self.compute_accuracy: self.accuracy = self.accfun(self.y, t) reporter.report({'accuracy': self.accuracy}, self) return self.loss 損失関数を指定する model に forward
  25. $MBTTJGJFS © 2017 Retrieva, Inc. 29 # examples/mnist/train_mnist.py model =

    L.Classifier(MLP(args.unit, 10)) # chainer/links/model/classifier.py class Classifier(link.Chain): def __init__(self, predictor, lossfun=softmax_cross_entropy.softmax_cross_entropy, accfun=accuracy.accuracy): def __call__(self, *args): self.y = self.predictor(*x) self.loss = self.lossfun(self.y, t) reporter.report({'loss': self.loss}, self) if self.compute_accuracy: self.accuracy = self.accfun(self.y, t) reporter.report({'accuracy': self.accuracy}, self) return self.loss 損失関数を指定する model に forward loss を算出 accuracy を算出
  26. $MBTTJGJFS © 2017 Retrieva, Inc. 30 # examples/mnist/train_mnist.py model =

    L.Classifier(MLP(args.unit, 10)) # chainer/links/model/classifier.py class Classifier(link.Chain): def __init__(self, predictor, lossfun=softmax_cross_entropy.softmax_cross_entropy, accfun=accuracy.accuracy): def __call__(self, *args): self.y = self.predictor(*x) self.loss = self.lossfun(self.y, t) reporter.report({'loss': self.loss}, self) if self.compute_accuracy: self.accuracy = self.accfun(self.y, t) reporter.report({'accuracy': self.accuracy}, self) return self.loss 損失関数を指定する model に forward loss を算出 accuracy を算出 loss を返す
  27. &WBMVBUPS © 2017 Retrieva, Inc. 31 Trainer Updater (StandardUpdater) Iterator

    Optimizer Classifier model (MLP) train dataset Evaluator Iterator test dataset Extensions • dump_graph, snapshot • LogReport, PrintReport • ProgressBar • converter • loss_func • device • converter • device
  28. &WBMVBUPS • テストデータに対して、loss, accuracy などを計算し、検証す る • epoch ごとに、現在まで学習された model

    に対して検証する • ⼤まかには、Updater と対応している © 2017 Retrieva, Inc. 32
  29. &WBMVBUPS © 2017 Retrieva, Inc. 33 # examples/mnist/train_mnist.py test_iter =

    chainer.iterators.SerialIterator(test, args.batchsize, repeat=False, shuffle=False) trainer.extend(extensions.Evaluator(test_iter, model, device=args.gpu)) # chainer/training/extensions/evaluator.py class Evaluator(extension.Extension): def evaluate(self): iterator = self._iterators['main'] target = self._targets['main'] eval_func = self.eval_func or target it = copy.copy(iterator) for batch in it: observation = {} with reporter_module.report_scope(observation): in_arrays = self.converter(batch, self.device) if isinstance(in_arrays, tuple): in_vars = tuple(variable.Variable(x, volatile='on') for x in in_arrays) eval_func(*in_vars)
  30. &WBMVBUPS © 2017 Retrieva, Inc. 34 # examples/mnist/train_mnist.py test_iter =

    chainer.iterators.SerialIterator(test, args.batchsize, repeat=False, shuffle=False) trainer.extend(extensions.Evaluator(test_iter, model, device=args.gpu)) # chainer/training/extensions/evaluator.py class Evaluator(extension.Extension): def evaluate(self): iterator = self._iterators['main'] target = self._targets['main'] eval_func = self.eval_func or target it = copy.copy(iterator) for batch in it: observation = {} with reporter_module.report_scope(observation): in_arrays = self.converter(batch, self.device) if isinstance(in_arrays, tuple): in_vars = tuple(variable.Variable(x, volatile='on') for x in in_arrays) eval_func(*in_vars) Iterator から 全部呼び出す
  31. &WBMVBUPS © 2017 Retrieva, Inc. 35 # examples/mnist/train_mnist.py test_iter =

    chainer.iterators.SerialIterator(test, args.batchsize, repeat=False, shuffle=False) trainer.extend(extensions.Evaluator(test_iter, model, device=args.gpu)) # chainer/training/extensions/evaluator.py class Evaluator(extension.Extension): def evaluate(self): iterator = self._iterators['main'] target = self._targets['main'] eval_func = self.eval_func or target it = copy.copy(iterator) for batch in it: observation = {} with reporter_module.report_scope(observation): in_arrays = self.converter(batch, self.device) if isinstance(in_arrays, tuple): in_vars = tuple(variable.Variable(x, volatile='on') for x in in_arrays) eval_func(*in_vars) Iterator から 全部呼び出す Converter にかける (変換し、to_gpu する)
  32. &WBMVBUPS © 2017 Retrieva, Inc. 36 # examples/mnist/train_mnist.py test_iter =

    chainer.iterators.SerialIterator(test, args.batchsize, repeat=False, shuffle=False) trainer.extend(extensions.Evaluator(test_iter, model, device=args.gpu)) # chainer/training/extensions/evaluator.py class Evaluator(extension.Extension): def evaluate(self): iterator = self._iterators['main'] target = self._targets['main'] eval_func = self.eval_func or target it = copy.copy(iterator) for batch in it: observation = {} with reporter_module.report_scope(observation): in_arrays = self.converter(batch, self.device) if isinstance(in_arrays, tuple): in_vars = tuple(variable.Variable(x, volatile='on') for x in in_arrays) eval_func(*in_vars) Iterator から 全部呼び出す Converter にかける (変換し、to_gpu する) model に forward
  33. /4UFQ-45. ͱ͸ • RNN のための、Chainer 1.16.0 で導⼊された Link • cuDNN

    の恩恵を受けて、⾼速に動く • 既存の LSTM と使い⽅が違う • 既存の LSTM のサンプルは examples/ptb/train_ptb.py © 2017 Retrieva, Inc. 39
  34. 3// • Recurrent Neural Network • 並び⽅に意味のある、「系列データ」を扱う場合に⽤いられる • 応⽤例:⽂章の推定、⾳声認識、変動する数値の推定 •

    例:⽂章が途中まで与えられた時、次の単語を予測する問題 © 2017 Retrieva, Inc. 40 私 は ⽩い ⽝ が ? x1 x2 x3 x4 x5 y1 y2 y3 y4 y5 x1〜x5を⼊⼒データと して、y5を推定する。
  35. -45.ͱ /4UFQ-45. データ1 1 2 データ1ラベル A B データ2 1

    2 3 データ2ラベル A B C © 2017 Retrieva, Inc. 41 • LSTM(逐次渡す) • x1: Variable[1, 1] • t1: Variable[B, B] • x2: Variable[2, 2] • t2: Variable[0, C] • NStepLSTM(⼀度に渡す) • xs: [Variable[1,2], Variable[1,2,3]] • ts: [Variable[A,B], Variable[A,B,C]]
  36. -45.ͱ /4UFQ-45. データ1 1 2 データ1ラベル A B データ2 1

    2 3 データ2ラベル A B C © 2017 Retrieva, Inc. 42 • LSTM(逐次渡す) • x1: Variable[1, 1] • t1: Variable[B, B] • x2: Variable[2, 2] • t2: Variable[0, C] • NStepLSTM(⼀度に渡す) • xs: [Variable[1,2], Variable[1,2,3]] • ts: [Variable[A,B], Variable[A,B,C]] ⻑さが合っていない時、 0 などで埋める必要がある ⻑さを合わせなくて良い Variable の list を渡す
  37. /4UFQ-45. αϯϓϧ © 2017 Retrieva, Inc. 43 class RNNNStepLSTM(chainer.Chain): def

    __init__(self, n_layer, n_units, train=True): super(RNNNStepLSTM, self).__init__( l1 = L.NStepLSTM(n_layer, n_units, n_units, 0.5, True), ) self.n_layer = n_layer self.n_units = n_units def __call__(self, xs): xp = self.xp hx = chainer.Variable(xp.zeros( (self.n_layer, len(xs), self.n_units), dtype=xp.float32)) cx = chainer.Variable(xp.zeros( (self.n_layer, len(xs), self.n_units), dtype=xp.float32)) hy, cy, ys = self.l1(hx, cx, xs, train=self.train)
  38. /4UFQ-45. αϯϓϧ © 2017 Retrieva, Inc. 44 class RNNNStepLSTM(chainer.Chain): def

    __init__(self, n_layer, n_units, train=True): super(RNNNStepLSTM, self).__init__( l1 = L.NStepLSTM(n_layer, n_units, n_units, 0.5, True), ) self.n_layer = n_layer self.n_units = n_units def __call__(self, xs): xp = self.xp hx = chainer.Variable(xp.zeros( (self.n_layer, len(xs), self.n_units), dtype=xp.float32)) cx = chainer.Variable(xp.zeros( (self.n_layer, len(xs), self.n_units), dtype=xp.float32)) hy, cy, ys = self.l1(hx, cx, xs, train=self.train) レイヤー数、ユニット数、 Dropout を指定する パラメータの 初期状態を作成し、 渡す Variable のリストを ⼊⼒する 出⼒も Variable の リスト
  39. /4UFQ-45. ͱɺඪ४తͳ 5SBJOFSͷᴥᴪ • 標準的な Trainer の構成では、model には Variable を渡す

    • NStepLSTM では、「Variable のリスト」を渡さなければいけない © 2017 Retrieva, Inc. 45 # chainer/training/updater.py class StandardUpdater(Updater): def update_core(self): batch = self._iterators['main'].next() in_arrays = self.converter(batch, self.device) ... if isinstance(in_arrays, tuple): in_vars = tuple(variable.Variable(x) for x in in_arrays) optimizer.update(loss_func, *in_vars) Variableを作成し、 そのまま渡している
  40. /4UFQ-45. ͱɺඪ४తͳ 5SBJOFSͷᴥᴪ • loss の計算を、既存のメソッドに任せられない © 2017 Retrieva, Inc.

    46 # chainer/links/model/classifier.py class Classifier(link.Chain): def __init__(self, predictor, lossfun=softmax_cross_entropy.softmax_cross_entropy, accfun=accuracy.accuracy): def __call__(self, *args): self.y = self.predictor(*x) self.loss = self.lossfun(self.y, t) reporter.report({'loss': self.loss}, self) if self.compute_accuracy: self.accuracy = self.accfun(self.y, t) reporter.report({'accuracy': self.accuracy}, self) return self.loss NStepLSTM の出⼒だと、 y は Variable のリスト accuracy も同様
  41. QUCUSBJO@QUCQZ Λ /4UFQ-45. Ͱ • https://github.com/kei-s/chainer-ptb- nsteplstm/blob/master/train_ptb_nstep.py • できるだけ構成を同じにして(Trainer の上で)、

    train_ptb.py を NStepLSTM を使って実装してみる • Disclaimer • testモード(100件)では完⾛したけど、全データは⾛らせていない • 無駄なコードはありそう… © 2017 Retrieva, Inc. 47
  42. QUCUSBJO@QUCQZ Λ /4UFQ-45. Ͱ • モデル • EmbedID にかけるため、⼀度 concat

    し、split_axis でまた分ける • 系列のそれぞれの要素を Linear にかける • Iterator • bprop_len を Iterator に渡し、バッチで系列化したものを返す • Converter • NStepLSTM を使った seq2seq https://github.com/pfnet/chainer/pull/2070 を参照 • Updater • 系列をそのままモデルに渡すように変更 • Evaluator • 系列をそのままモデルに渡すように変更 • Lossfun • softmax_cross_entropy をそれぞれの系列に対してかけ、⾜し合わせる © 2017 Retrieva, Inc. 48
  43. QUCUSBJO@QUCQZ Λ /4UFQ-45. Ͱ • ⾼速化したか? • Estimated time では

    6時間 → 3時間 • とはいえ、実際に⾛らせると 3時間以上かかりそう • Estimated Time に Evaluator 部分が考慮されてなさそう • ⾼速化のために必要なこと • Chainer もしくは Numpy の世界で処理を終わらせること • Python の世界でループを回すとかなり遅くなる • Lossfun でループを回しているが、concat して渡してもよさそう(互換性が微 妙な予感) © 2017 Retrieva, Inc. 49