Upgrade to Pro — share decks privately, control downloads, hide ads and more …

[Journal club] Hungry Hungry Hippos: Towards Language Modeling with State Space Models

[Journal club] Hungry Hungry Hippos: Towards Language Modeling with State Space Models

More Decks by Semantic Machine Intelligence Lab., Keio Univ.

Other Decks in Technology

Transcript

  1. Hungry Hungry Hippos: Towards Language Modeling with State Space Models

    Daniel Y. Fu1, Tri Dao1, Khaled K. Saab1, Armin W. Thomas1, Atri Rudra2, Christopher Re1 (1Stanford University, 2University at Buffalo, SUNY) 和⽥唯我 / Yuiga Wada Daniel Y. Fu et al., “Hungry Hungry Hippos: Towards Language Modeling with State Space Models”, in ICLR(2023) ICLR23 notable top 25%
  2. 概要 2 ü 背景 • 状態空間モデル(state-space model; SSM)は様々なモダリティにおいて有⽤性が検証された • ⼀⽅で,未だ⾔語においてはTransformerよりも性能が低い

    • またSSMは線形にスケーリングするにも拘らず,ハードウェアの利⽤効率が悪いために Transformerよりも低速 ü 提案⼿法 • ⾼速かつ⾔語に強い新たなSSMとしてH3 (Hungry Hungry Hippos)を提案 ü 結果 • HybridモデルにおいてGPT-2およびGPT-Neoよりも低いperplexity • SuperGLUEにおいてもzero-shotで最良の結果
  3. 背景: SSMは低速 & ⾔語において性能が不⼗分 3 • 状態空間モデル(state-space model; SSM)は様々なモダリティにおいてSOTAに匹 敵する性能が検証されたが,未だ⾔語においてはTransformerよりも性能が低い

    • またSSMは線形にスケーリングするにも拘らず,ハードウェアの利⽤効率が悪い ためにTransformerよりも低速 ⾼速かつ⾔語に強い 新たなSSMが望まれる
  4. 前提: 状態空間モデルについて 4 • LSSL[Gu+, NeurIPS21]における定式化 • ⼊⼒ 𝑢 𝑡

    ,状態 𝑥 𝑡 ,出⼒ 𝑦 𝑡 に対して以下のように定義 • GBTにより離散化 (GBT; generalized bilinear transform) 連続空間 離散空間 𝜶はハイパラ, 𝑨, 𝑩, 𝑪, 𝑫 は学習可能パラメタ (BPなどで学習)
  5. 先⾏研究: SSM + 機械学習の研究は盛んに⾏われている 5 • H3までの系譜は以下の図の通り • 今回はHiPPO→LSSL→S4→H3の流れを軽く紹介 •

    特にS4は⾮常に重要な研究であり,後続の研究が盛んに⾏われている HiPPO [Gu+, NeurIPS20] LSSL [Gu+, NeurIPS21] S4 [Gu+, ICLR22] H3 [Fu+, ICLR23] S4D [Gu+, NeurIPS22] MEGA [Ma+, ICLR23] S5 [Smith+, ICLR23]
  6. 先⾏研究1: HiPPO[Gu+(Stanford Univ.), NeurIPS20(Spotlight)] 6 ü 背景: ⻑距離系列を扱うには系列の履歴を累積的に記憶する必要がある • HiPPO:

    複数の直交多項式によって⼊⼒信号を近似する⼿法を提案 • RNNに直接組み込むことができる (LSSL, S4, H3に繋がる研究) • HiPPOを導⼊するだけで,pMNISTにおける精度が92%→98%に • pMNIST: MNISTをpermuteした系列データ. ⻑距離依存を記憶する必要がある
  7. 先⾏研究1: HiPPO[Gu+(Stanford Univ.), NeurIPS20(Spotlight)] ü 背景: ⻑距離系列を扱うには系列の履歴を累積的に記憶する必要がある • HiPPO: 複数の直交多項式によって⼊⼒信号を近似する⼿法を提案

    • RNNに直接組み込むことができる (LSSL, S4, H3に繋がる研究) • HiPPOを導⼊するだけで,pMNISTにおける精度が92%→98%に • pMNIST: MNISTをpermuteした系列データ. ⻑距離依存を記憶する必要がある 7
  8. 先⾏研究1: HiPPO[Gu+(Stanford Univ.), NeurIPS20(Spotlight)] ü 背景: ⻑距離系列を扱うには系列の履歴を累積的に記憶する必要がある • HiPPO: 複数の直交多項式によって⼊⼒信号を近似する⼿法を提案

    • RNNに直接組み込むことができる (LSSL, S4, H3に繋がる研究) • HiPPOを導⼊するだけで,pMNISTにおける精度が92%→98%に • pMNIST: MNISTをpermuteした系列データ. ⻑距離依存を記憶する必要がある 8
  9. 先⾏研究2: LSSL[Gu+(Stanford Univ.), NeurIPS21] 9 ü ⻑距離系列を扱う上で,状態空間モデル(SSM)にHiPPOを導⼊し,RNNのような recurrent と CNNのようなconvolution

    の両⽅で学習できる⼿法LSSLを提案 • RNNs: 👍系列データの学習 👎⻑距離系列→勾配爆発 • CNNs: 👍⾼速かつ並列可能 👎系列データの学習に向いていない • NDEs: 👍 連続かつ⻑距離依存を扱える 👎効率が悪い • これら三者の利点を統合する形のモデルを⽬指す • SSMの⾏列AをHiPPO⾏列にするだけで, sCIFARでTransformerを上回る • sMNIST, sCIFAR: MNIST, CIFARを1次元へと flatten化.画像の帰納バイアスが使えないため, ⾃⼒で系列の依存関係を理解する必要がある
  10. 先⾏研究2: LSSL[Gu+(Stanford Univ.), NeurIPS21] 10 ü ⻑距離系列を扱う上で,状態空間モデル(SSM)にHiPPOを導⼊し,RNNのような recurrent と CNNのようなconvolution

    の両⽅で学習できる⼿法LSSLを提案 出⼒ 𝑦! を ⼊⼒𝑢" を⽤いて 書き下す → 畳み込みで表現可能
  11. 先⾏研究2: LSSL[Gu+(Stanford Univ.), NeurIPS21] ü ⻑距離系列を扱う上で,状態空間モデル(SSM)にHiPPOを導⼊し,RNNのような recurrent と CNNのようなconvolution の両⽅で学習できる⼿法LSSLを提案

    出⼒ 𝑦! を ⼊⼒𝑢" を⽤いて 書き下す → 畳み込みで表現可能 11 Backwards-Eulerで離散化したLSSLはRNN (gating mechanism)と同⼀視できる
  12. 先⾏研究3: S4[Gu+(Stanford Univ.), ICLR22] 12 ü 背景: LSSLは計算効率が悪い → 時間計算量

    O 𝑁'𝐿 / 空間計算量 O 𝑁𝐿 • S4では 𝐴 をDPLR化(diagonal plus low rank)することで計算量を削減 • recurrent が O 𝑁 ,convolutionが ) O 𝑁 + 𝐿 で計算可能に (詳細は省略) • LSSLよりも30倍⾼速 / 400倍少ないメモリで計算可 • 初めてLong Range Arena[Tay+, ICLR20]におけるPath-Xを解くことのできたモデル
  13. 先⾏研究3: S4[Gu+(Stanford Univ.), ICLR22] 13 ü 背景: LSSLは計算効率が悪い → 時間計算量

    O 𝑁'𝐿 / 空間計算量 O 𝑁𝐿 • S4では 𝐴 をDPLR化(diagonal plus low rank)することで計算量を削減 • recurrent が O 𝑁 ,convolutionが ) O 𝑁 + 𝐿 で計算可能に (詳細は省略) • LSSLよりも30倍⾼速 / 400倍少ないメモリで計算可 • 初めてLong Range Arena[Tay+, ICLR20]におけるPath-Xを解くことのできたモデル • Long Range Arena: ⻑距離依存を捉える必要のあるタスクを提供 • Path-X: ⼆点が点線で繋がっているかを画素系列から判断
  14. ⾔語処理におけるSSMの2つの問題点 14 ü SSMは①前⽅にあるトークンの記憶と②トークン間の⽐較が不得意 • 本研究では⼆つのタスクInduction HeadとAssociative Recallで検証 • Induction

    Head • 特殊なトークン” “ で囲まれた部分⽂字列の先頭の⽂字を出⼒させるタスク • 前⽅のトークンをどれほど記憶しているかを検証 • Associative Recall • key-valueでペアを成すアルファベットと数字の組に対して,与えられたkey に対応するvalueを出⼒させるタスク (e.g., “a 2 c 4 b 3 d 1”→⼊⼒’a’ 出⼒’2’) • トークン間の関係をどれほど記憶しているかを検証
  15. ⾔語処理におけるSSMの2つの問題点 15 • Induction Head • 前⽅のトークンをどれほど記憶しているかを検証 • Associative Recall

    • トークン間の関係をどれほど記憶しているかを検証 • Attentionはどちらのタスクでも100%であるのに対して,S4D[Gu+, NeurIPS22]や Gated State Spaces[Mehta+, ICLR23]の精度はかなり低い → SSMは①前⽅にあるトークンの記憶と②トークン間の⽐較が不得意
  16. 提案⼿法: H3 (Hungry Hungry Hippos) 16 ü SSMは①前⽅にあるトークンの記憶と②トークン間の⽐較が不得意 • ⼀⽅,Attentionは𝑄𝐾(

    によりトークン間の関係を記憶可能(②) & softmax 𝑄𝐾( 𝑉 によりトークン⾃体を直接記憶可能(①) ü 提案⼿法: H3 (Hungry Hungry Hippos) • この⼆つの難点を乗り越える新たなSSMとしてH3を提案 • 上記考察に基づき,Q, K, V によって Attention-Likeに設計 • またGPU上のFFT, iFFT等の⾼速化⼿法 としてFlashConvを提案 (詳細は省略) ⇒ ⾼速かつ⾔語に強いSSMを実現
  17. 提案⼿法: H3 (Hungry Hungry Hippos) 17 ü SSMは①前⽅にあるトークンの記憶と②トークン間の⽐較が不得意 ü H3

    (Hungry Hungry Hippos) • Attention-Like: Efficient Transformerの⽅法論に則り𝐾(𝑉 を先に計算 ① shift演算によりトークンを記憶 • Shift演算: [a,b,c] → [0,a,b] • 直感的理解: 常に𝐴がshift演算として機能するとき,仮に 𝐵 = 𝑒# とすると,連鎖的に 𝑚 ステップ前までの 𝑢$ が 𝑥$ に格納される (𝑥$ = [𝑢$ , . . . , 𝑢$%&'#]) ② Attention-Likeに乗算することでトークン間の関係を記憶 (like 𝑄𝐾(𝑉) • 𝐾(𝑉部分はHiPPOによって初期化された対⾓⾏列によるSSMが通される • Shiftによって過去の情報を保持した⾏列と⼊⼒を乗算することで類似度が計算できる
  18. 提案⼿法: H3 (Hungry Hungry Hippos) ü SSMは①前⽅にあるトークンの記憶と②トークン間の⽐較が不得意 ü H3 (Hungry

    Hungry Hippos) • Attention-Like: Efficient Transformerの⽅法論に則り𝐾(𝑉 を先に計算 ① shift演算によりトークンを記憶 • Shift演算: [a,b,c] → [0,a,b] • 直感的理解: 常に𝐴がshift演算として機能するとき,仮に 𝐵 = 𝑒# とすると,連鎖的に 𝑚 ステップ前までの 𝑢$ が 𝑥$ に格納される (𝑥$ = [𝑢$ , . . . , 𝑢$%&'#]) ② Attention-Likeに乗算することでトークン間の関係を記憶 (like 𝑄𝐾(𝑉) • 𝐾(𝑉部分はHiPPOによって初期化された対⾓⾏列によるSSMが通される • Shiftによって過去の情報を保持した⾏列と⼊⼒を乗算することで類似度が計算できる ① 18
  19. 提案⼿法: H3 (Hungry Hungry Hippos) 19 ü SSMは①前⽅にあるトークンの記憶と②トークン間の⽐較が不得意 ü H3

    (Hungry Hungry Hippos) • Attention-Like: Efficient Transformerの⽅法論に則り𝐾(𝑉 を先に計算 ① shift演算によりトークンを記憶 • Shift演算: [a,b,c] → [0,a,b] • 直感的理解: 常に𝐴がshift演算として機能するとき,仮に 𝐵 = 𝑒# とすると,連鎖的に 𝑚 ステップ前までの 𝑢$ が 𝑥$ に格納される (𝑥$ = [𝑢$ , . . . , 𝑢$%&'#]) ② Attention-Likeに乗算することでトークン間の関係を記憶 (like 𝑄𝐾(𝑉) • 𝐾(𝑉部分はHiPPOによって初期化された対⾓⾏列によるSSMが通される • Shiftによって過去の情報を保持した⾏列と⼊⼒を乗算することで類似度が計算できる ②
  20. 提案⼿法: H3 (Hungry Hungry Hippos) 20 • H3レイヤのアルゴリズムは以下の通り ⼊⼒ 𝑢

    に対して, Q = 𝑢𝑊! , 𝐾 = 𝑢𝑊" , 𝑉 = 𝑢𝑊# を計算 𝑄 𝐾 𝑉
  21. 提案⼿法: H3 (Hungry Hungry Hippos) 22 • H3レイヤのアルゴリズムは以下の通り 𝑄 𝐾

    𝑉 𝑄, 𝐾, 𝑉をmulti-head化 すなわち,dim⽅向に分割
  22. 提案⼿法: H3 (Hungry Hungry Hippos) 24 • H3レイヤのアルゴリズムは以下の通り 𝑄 𝐾

    𝑉 𝑄- ∈ ℝ. 𝑖 = 1, … 𝑁} ごとに 𝑄- 𝐾𝑉 - を計算してconcat
  23. 提案⼿法: H3 (Hungry Hungry Hippos) 25 • H3は系列⻑ 𝑁 に対しておおよそ

    O 𝑁log 𝑁 で動作する • 時間計算量: O 𝑑'𝑁 + 𝑑𝑁log 𝑁 • 空間計算量: O 𝑑𝑁 → Attentionは 時間計算量 O 𝑑𝑁' / 空間計算量 O 𝑁' なのでAttentionよりも⾼速
  24. 定量的結果: GPT-{Neo, 2}よりも低いperplexity 26 • Hybridモデルで実験 (トリック→実はH3単体だと若⼲負けるので) • Hybrid: AttentionとH3レイヤを交互に配置したモデル

    • データセット: The Pile, OpenWebText, WikiText103 • GPT-Neo, GPT-2よりも 低いperplexity • どのモデルサイズでも 最良の結果
  25. 推論速度: Transformerよりも2.4倍⾼速に推論 28 • A100 (80GB)における1.3Bモデルの推論速度を計測 (バッチサイズ: 64) • FlashConvの速度への寄与を

    Long Range Arena [Tay+, ICLR20] にて検証 • H3: Transformerよりも 2.4倍⾼速に推論可 • S4 w/ FlashConvでは Transformerよりも5.8倍 ⾼速
  26. PureなH3との⽐較: SuperGLUEにおいてOPT, GPT-Neoに及ばず 29 • PureなH3におけるzero-shot性能 • OPT, GPT-Neoにかなりの差で負けている •

    MultiRC, ReCoRDで⼤敗 • MultiRC: Multi-Sentence Reading Comprehension • ReCoRD: Reading Comprehension with Commonsense Reasoning Dataset
  27. 所感 30 • Strengths • ⾔語モデリングにおいて⾼速かつ強⼒なアーキテクチャを提⽰した点 • Few-shot, zero-shot, LRAでの実験や,GPT-{2,Neo},

    OPT, Perceiver AR等との ⽐較など多様な実験が⾏われており,H3の有効性に関するスコープが⽰され ている点 • Weaknesses • PureなH3だとOPT, GPT-Neoに負けている • Attentionを混ぜてしまうならSSMの利点を⼗分に活かしきれていないのでは • Others • S5やMEGAなどもICLR23に採択されておりSSMの研究はかなり盛んな印象
  28. まとめ 31 ü 背景 • 状態空間モデル(state-space model; SSM)は様々なモダリティにおいて有⽤性が検証された • ⼀⽅で,未だ⾔語においてはTransformerよりも性能が低い

    • またSSMは線形にスケーリングするにも拘らず,ハードウェアの利⽤効率が悪いために Transformerよりも低速 ü 提案⼿法 • ⾼速かつ⾔語に強い新たなSSMとしてH3 (Hungry Hungry Hippos)を提案 ü 結果 • HybridモデルにおいてGPT-2およびGPT-Neoよりも低いperplexity • SuperGLUEにおいてもzero-shotで最良の結果
  29. 参考⽂献 32 • [Gu+, NeurIPS20]: Albert Gu, Tri Dao, Stefano

    Ermon, Atri Rudra, and Christopher Re. “Hippo: Recurrent memory with optimal polynomial projections”. NeurIPS, 2020. • [Gu+, NeurIPS21]: Gu, Albert, Isys Johnson, Karan Goel, Khaled Saab, Tri Dao, Atri Rudra, and Christopher Ré. "Combining recurrent, convolutional, and continuous-time models with linear state space layers." NeurIPS, 2021. • [Gu+, ICLR22]: Albert Gu, Karan Goel, and Christopher Re. “Efficiently modeling long sequences with structured state spaces”. ICLR, 2022. • [Gu+, NeurIPS22]: Albert Gu, Ankit Gupta, Karan Goel, and Christopher Re. “On the parameterization and initialization of diagonal state space models”. NeurIPS, 2022. • [Ma+, ICLR23]: Xuezhe Ma, Chunting Zhou, Xiang Kong, Junxian He, Liangke Gui, Graham Neubig, Jonathan May, and Luke Zettlemoyer. "Mega: Moving Average Equipped Gated Attention." ICLR, 2023. • [Smith+, ICLR23]: Jimmy T.H. Smith, Andrew Warrington, Scott Linderman, “Simplified State Space Layers for Sequence Modeling” , ICLR, 2023. • [Mehta+, ICLR23]: Harsh Mehta, Ankit Gupta, Ashok Cutkosky, Behnam Neyshabur. “Long Range Language Modeling via Gated State Spaces”, ICLR, 2023 • [Tay+, ICLR20]: Yi Tay, Mostafa Dehghani, Samira Abnar, Yikang Shen, Dara Bahri, Philip Pham, Jinfeng Rao, Liu Yang, Sebastian Ruder, and Donald Metzler. “Long range arena: A benchmark for efficient transformers”, ICLR, 2020.
  30. Appendix: 動作確認 33 • H3: (B, L, H) = (8,

    6, 512) • データセット: The Pile (hacker_newsのみを使⽤) • 学習時間: 98時間, Epoch: 92 • RAM(GPU): 17.3GB, test/perplexity: 29.5 所感 • H3動かすのにCUDA関係でめちゃくちゃ苦労した… • 後続のHyenaになるとImageを扱えるので,普通に マルチモダリティ扱えそう • 学習時間が減ってる感じはあまりわからない • もう少し軽量のデータセットで試せば効果を 実感できそう