Upgrade to Pro — share decks privately, control downloads, hide ads and more …

[Journal club] Transformers Learn In-Context by Gradient Descent

[Journal club] Transformers Learn In-Context by Gradient Descent

More Decks by Semantic Machine Intelligence Lab., Keio Univ.

Other Decks in Technology

Transcript

  1. Johannes von Oswald1,2 Eyvind Niklasson2 Ettore Randazzo2 Joao Sacramento1 Alexander

    Mordvintsev2 Andrey Zhmoginov2 Max Vladymyrov2 1. ETH Zurich, 2. Google Research 杉浦孔明研究室 小槻誠太郎 ICML23 OralPoster J. von Oswald, E. Niklasson, E. Randazzo, J. Sacramento, A. Mordvintsev, A. Zhmoginov, and M. Vladymyrov, “Transformers Learn In-Context by Gradient Descent,” in ICML, 2023, pp. 35151–35174.
  2. Induction heads [Olsson+, 22]: Input: [A], [B], …, [A] →

    Output: [B] のようにtoken列を完成させる単純なアルゴリズムを実現するattention head - In-context learningの仕組みの大半はinduction headによって構成されるという仮説, - その間接的な証拠を提示 5 関連研究 – In-context learningの説明を試みた研究
  3. Meta Learning: “学習の方法” を学習 目標: ◦たくさんのタスクで学習し, 新しいタスクが与えられたときに当該タスクに急速に適応 ×タスク特化の性能を向上させる e.g., MAML

    [Finn+, ICML17] Fast Weights [Schmidhuber+, Neural Computation92] ↓ Linear self-attention [Schlag+, ICML21] 6 関連研究 – 学習の方法を獲得する “Meta Learning” Softmaxを取り除いたself-attentionである Linear self-attentionがFast weight controllerと等価であることを示した
  4. Recap: Self-attention ↓ Linear self-attention [Schlag+, ICML21] 8 準備1/2 –

    Linear self-attention (LSA) の導入 Simplify (Softmax → 恒等写像)
  5. Gradient-induced dynamics: ある𝜃GD が存在して, 任意の 𝑒𝑗 , ( 𝑗 ∈

    1, … , 𝑁 ) に対して以下が成立し, さらに e𝑁+1 , 即ち 𝑒test についても上式が成立する. E.g., 13 命題1 – 勾配降下を表現できるLSAのパラメータ設定𝜃GD が存在 [warn]: 𝑉, 𝐾の計算には𝑒𝑁+1 が含まれていない. Wを0で(実用的には十分小さく)初期 化すれば𝑒𝑁+1 を含んでもok.
  6. 線形教師モデルの入出力をもとにデータ生成 ← ( ”LSA == 線形回帰のGD” を検証 ) 入力: In-context

    data, と, Test point, 出力: 損失関数: → を獲得 理想: が と一致, 𝑁個の例示 をもとに の予測に成功 14 実験設定 – 線形教師モデルをin-context Learningで獲得 In-context learning
  7. 線形教師モデルの入出力をもとにデータ生成 ← ( ”LSA == 線形回帰のGD” を検証 ) 入力: In-context

    data, と, Test point, 出力: 損失関数: → を獲得 理想: が と一致, 𝑁個の例示 をもとに の予測に成功 15 実験設定 – 線形教師モデルをin-context Learningで獲得 In-context learning
  8. を決める際の学習ステップ数を増やすと Lossが の場合と一致, 一致度を測る指標も良好 一致度を測る指標 Preds diff : Model cos:

    sensitivity同士のcosine類似度 Model diff: sensitivity同士のL2 norm Sensitivity: (in-context learningで得たモデルの特性を反映) 16 結果 – 命題1の主張の裏付け: 学習済みLSAがGDと等価
  9. LSAを複数層用意した場合の挙動 MLPを追加した場合の挙動 → Appendix Softmaxを追加した場合 (LSA → SA) の挙動 →

    Appendix LayerNormを追加した場合の挙動 → Appendix 非線形回帰への対応 特殊なtoken構成を導入する妥当性 18 追加の議論, 実験 – ここまでは様々なものを無視している
  10. MLPを通した入力を1層のLSAに通す →1stepのGDの性能と一致 GD init, TF init: MLPを通した出力 GD step1: GD

    initに対してGDを適用 TF step1: TF initに対してLSAを適用 21 実験 – 正弦関数の回帰問題において, 1stepのGDに一致