Upgrade to Pro — share decks privately, control downloads, hide ads and more …

[論文紹介] Transformer-based World Models Are Happy With 100k Interactions

tt1717
January 31, 2024

[論文紹介] Transformer-based World Models Are Happy With 100k Interactions

PDFファイルをダウンロードすると,スライド内のリンクを見ることができます.

tt1717

January 31, 2024
Tweet

More Decks by tt1717

Other Decks in Research

Transcript

  1. どんなもの?
    先行研究と比べて何がすごい?
    技術の手法や肝は?
    どうやって有効だと検証した?
    ・Atari 100kベンチマークを使用して評価し,「中央値,四分位平
    均 (IQM),平均スコア」で高い性能を示した
    ・予測された報酬を世界モデルにフィードバックすることで,現在
    どれだけの報酬が出力されているかという情報を提供する
    ・Dreamerv2の損失関数を修正して,関係するエントロピー項とク
    ロスエントロピー項の相対的な重みを微調整した
    ・強化学習におけるサンプル効率の向上を目指し,Transformer-XL
    アーキテクチャを基にした新しい自己回帰型の世界モデル (TWM)を
    提案した
    ・提案されたTWMは,Atari 100kベンチマークで既存のモデルフ
    リー or モデルベースの強化学習アルゴリズムを上回る性能を示した
    Transformer-based World Models Are Happy With 100k Interactions
    (ICLR 2023) Jan Robine, Marc Höftmann, Tobias Uelwer, Stefan Harmeling
    https://arxiv.org/abs/2303.07109
    2024/01/31
    論文を表す画像
    被引用数:13
    1/9
    ・Transformer-XLアーキテクチャを活用することで長期依存関係を
    学習し,計算効率を保持している
    ・TWMは推論時にTransformerを必要としないため,計算コストを
    削減している

    View full-size slide

  2. ❖ 観測のエンコード:
    ➢ 観測otはCNNを使用して潜在状態ztに変換
    ❖ 潜在状態,行動,報酬の埋め込み:
    ➢ 生成された潜在状態zt,行動at,報酬rtはそれぞれ線形埋め込みを通して
    処理される
    ❖ Transformerの活用:
    ➢ 埋め込まれた潜在状態,行動,報酬はTransformerに入力され,各時間に
    おいて決定論的な隠れ状態htを計算する
    モデル
    2/9

    View full-size slide

  3. モデル
    3/9
    ❖ MLPを使用した予測
    ➢ Transformerによって計算された隠れ状態htを元に,MLPを使用して次の
    潜在状態zt+1^,報酬rt^,割引率γt^の予測を行う
    ❖ 時系列データの処理
    ➢ Transformerはht-Lからhtまでのシーケンスを処理することで過去のデー
    タに基づいて現在の隠れ状態htを更新する

    View full-size slide

  4. 損失関数の設計 (観測モデル)
    4/9
    ❖ decoder:観測デコーダ
    ➢ モデルがデータをどれだけうまく再構成できているかを測る項
    ❖ entropy regularizer:エントロピー正則化項
    ➢ 潜在状態の分布が一様になりすぎることを防ぐための項
    ❖ consistency:一貫性損失
    ➢ エンコーダとダイナミクスモデルが生成する潜在状態の分布の一貫性を測
    る項
    ❖ α1, α2:ハイパラ
    ➢ エントロピー正則化項と一貫性損失の重みを制御する

    View full-size slide

  5. ❖ latent state predictor:潜在状態予測器
    ➢ 次の時間における潜在状態 zt+1 の予測のクロスエントロピー
    ❖ reward predictor:報酬予測器
    ➢ モデルが予測する報酬 rt の負の対数尤度
    ❖ discount predictor:割引予測器
    ➢ 割引率 γt の予測の負の対数尤度,エピソード終了時 dt=1 のときγt=0で
    それ以外のときは,γt=γとなる
    ❖ β1, β2:ハイパラ
    ➢ 報酬予測器と割引予測器の重みを制御する
    損失関数の設計 (ダイナミクスモデル)
    5/9

    View full-size slide

  6. Atari 100kベンチマーク結果 (定量評価)
    6/9
    ❖ 100エピソードで訓練したモデ
    ルで5回評価したスコアから
    「中央値と平均値」を算出
    ❖ Normalized Mean
    ➢ 人間プレイヤーの平均スコア
    に対する各アルゴリズムのス
    コアの正規化平均
    ❖ Normalized Median
    ➢ 人間プレイヤーの平均スコア
    に対する各アルゴリズムのス
    コアの正規化中央値
    ❖ ほとんどのゲームで従来手法を
    上回る性能
    ❖ Normalized Meanのスコアが
    高いことから人間プレイヤーに
    匹敵する性能を示している

    View full-size slide

  7. ❖ Boxing
    ➢ プレイヤー (白) が攻撃 (赤フレーム)を行い,次のフレームで報酬を獲得
    している (緑フレーム)
    ❖ Freeway
    ➢ プレイヤーは上方向に移動するアクションを継続して選択している (赤い
    横枠)
    ❖ モデルは行動を取り,その結果として期待される報酬を計算し,ゲー
    ムの進行を「想像」することができている
    ゲームタスクの観測軌道 (定性評価)
    7/9

    View full-size slide

  8. まとめ
    8/9
    ❖ World model × Transformerによるモデルを提案した
    ❖ Dreamerv2の損失関数の設計を修正した
    ❖ 定量評価において,平均スコアは人間とほぼ同等性能
    ❖ 定性評価では,提案モデルが観測ot,行動at,報酬rtを予測しゲーム
    進行を再現できている

    View full-size slide

  9. 感想
    9/9
    ❖ 推論時にTransformerを使用しないことで,計算コスト削減しているの
    がIRISとの違い (だと思う)
    ❖ このモデルをオフラインデータで実験したらどのようになるのか気に
    なる
    ➢ githubを見た限りデータセットはないのでオンライン学習だと思う

    View full-size slide