Slide 1

Slide 1 text

0 2025-03-21 第116回NearMe技術勉強会 Takuma KAKINOUE RustでDeepQNetworkを実装する (おまけ)tch-rsからcandleに移⾏した結果

Slide 2

Slide 2 text

1 はじめに ● 今回のスライドは、Zennにアップロードした記事(以下、url)のダイジェストになり ます ○ https://zenn.dev/kakky_hacker/articles/652bd7f9a1e6c1

Slide 3

Slide 3 text

2 API解説 ● DQNのインターフェース ○ actメソッド ■ 推論時に毎ステップ呼ぶ ○ act_and_trainメソッド ■ 訓練時に毎ステップ呼ぶ

Slide 4

Slide 4 text

3 actメソッドの実装 ● 推論時に呼ばれる想定なので、tch::no_gradで勾配計算を無効化 ● あとは、Qネットワークが出⼒したQ値の最⼤値のindexを返しているだけ

Slide 5

Slide 5 text

4 act_and_trainメソッドの実装 ● コードは⻑いので貼れないが、やっていることは以下。 ○ 観測した状態に対して、最⼤Q値の⾏動を算出する ○ 最⼤Q値の⾏動を選択するかランダム⾏動を選択するかを決める(ε-greedy法など) ○ リプレイバッファに状態‧⾏動‧報酬を記録する ○ update間隔に達していたら、 _updateメソッド(次スライドで解説)で重みを更新する ○ 選択した⾏動を返す

Slide 6

Slide 6 text

5 _updateメソッドの実装 ● ⼤まかな流れ ○ リプレイバッファから経験をサンプリング ○ Q値の更新式の各変数の値を求める ○ 損失を計算して、重みを更新する ● Q値の更新式は以下

Slide 7

Slide 7 text

6 _updateメソッドの実装 ● 損失の計算式 ○ 現状は平均⼆乗誤差を使っているが、Huber損失も実装予定 ■ Huber損失 ● https://ja.wikipedia.org/wiki/Huber%E6%90%8D%E5%A4%B1

Slide 8

Slide 8 text

7 tch-rsとcandle ● tch-rs(https://github.com/LaurentMazare/tch-rs) ○ メリット ■ コアの部分がPytorchなので実績と信頼性は⼗分 ○ デメリット ■ Pytorchの全機能をRustから呼べるわけではない ■ しかし、全機能を含んだコア部分をinstallするので重くなりがち ● candle(https://github.com/huggingface/candle) ○ メリット ■ Pure Rustなので、Rustから使う場合は型推論周りは良い ■ パッケージが軽い、WebAssenbly対応 ○ デメリット ■ Pytorchと⽐べてまだ実績が少ない

Slide 9

Slide 9 text

8 tch-rsからcandle移⾏した結果 ● packageの容量を1.7GB → 0.8GBに削減できた! ● しかし、訓練時のメモリ使⽤量や実⾏時間はtch-rsの⽅が若⼲性能が良かった ○ 結局、しばらくはtch-rsで開発を進めることに

Slide 10

Slide 10 text

9 今後の展望 ● [WIP] Proximal Policy Optimizationの実装 ○ PPO論⽂:https://arxiv.org/abs/1707.06347 ● Soft Actor Criticの実装 ○ SAC論⽂:https://arxiv.org/abs/1801.01290 ● リプレイバッファから経験をサンプリングするときに優先度を設ける ○ Prioritized Replay Buffer論⽂:https://arxiv.org/abs/1511.05952 ● 好奇⼼報酬による探索の効率化の導⼊ ○ RND論⽂:https://arxiv.org/abs/1810.12894 ○ SND論⽂:https://arxiv.org/abs/2302.11563

Slide 11

Slide 11 text

10 Thank you