Slide 1

Slide 1 text

Hungry Hungry Hippos: Towards Language Modeling with State Space Models Daniel Y. Fu1, Tri Dao1, Khaled K. Saab1, Armin W. Thomas1, Atri Rudra2, Christopher Re1 (1Stanford University, 2University at Buffalo, SUNY) 和⽥唯我 / Yuiga Wada Daniel Y. Fu et al., “Hungry Hungry Hippos: Towards Language Modeling with State Space Models”, in ICLR(2023) ICLR23 notable top 25%

Slide 2

Slide 2 text

概要 2 ü 背景 • 状態空間モデル(state-space model; SSM)は様々なモダリティにおいて有⽤性が検証された • ⼀⽅で,未だ⾔語においてはTransformerよりも性能が低い • またSSMは線形にスケーリングするにも拘らず,ハードウェアの利⽤効率が悪いために Transformerよりも低速 ü 提案⼿法 • ⾼速かつ⾔語に強い新たなSSMとしてH3 (Hungry Hungry Hippos)を提案 ü 結果 • HybridモデルにおいてGPT-2およびGPT-Neoよりも低いperplexity • SuperGLUEにおいてもzero-shotで最良の結果

Slide 3

Slide 3 text

背景: SSMは低速 & ⾔語において性能が不⼗分 3 • 状態空間モデル(state-space model; SSM)は様々なモダリティにおいてSOTAに匹 敵する性能が検証されたが,未だ⾔語においてはTransformerよりも性能が低い • またSSMは線形にスケーリングするにも拘らず,ハードウェアの利⽤効率が悪い ためにTransformerよりも低速 ⾼速かつ⾔語に強い 新たなSSMが望まれる

Slide 4

Slide 4 text

前提: 状態空間モデルについて 4 • LSSL[Gu+, NeurIPS21]における定式化 • ⼊⼒ 𝑢 𝑡 ,状態 𝑥 𝑡 ,出⼒ 𝑦 𝑡 に対して以下のように定義 • GBTにより離散化 (GBT; generalized bilinear transform) 連続空間 離散空間 𝜶はハイパラ, 𝑨, 𝑩, 𝑪, 𝑫 は学習可能パラメタ (BPなどで学習)

Slide 5

Slide 5 text

先⾏研究: SSM + 機械学習の研究は盛んに⾏われている 5 • H3までの系譜は以下の図の通り • 今回はHiPPO→LSSL→S4→H3の流れを軽く紹介 • 特にS4は⾮常に重要な研究であり,後続の研究が盛んに⾏われている HiPPO [Gu+, NeurIPS20] LSSL [Gu+, NeurIPS21] S4 [Gu+, ICLR22] H3 [Fu+, ICLR23] S4D [Gu+, NeurIPS22] MEGA [Ma+, ICLR23] S5 [Smith+, ICLR23]

Slide 6

Slide 6 text

先⾏研究1: HiPPO[Gu+(Stanford Univ.), NeurIPS20(Spotlight)] 6 ü 背景: ⻑距離系列を扱うには系列の履歴を累積的に記憶する必要がある • HiPPO: 複数の直交多項式によって⼊⼒信号を近似する⼿法を提案 • RNNに直接組み込むことができる (LSSL, S4, H3に繋がる研究) • HiPPOを導⼊するだけで,pMNISTにおける精度が92%→98%に • pMNIST: MNISTをpermuteした系列データ. ⻑距離依存を記憶する必要がある

Slide 7

Slide 7 text

先⾏研究1: HiPPO[Gu+(Stanford Univ.), NeurIPS20(Spotlight)] ü 背景: ⻑距離系列を扱うには系列の履歴を累積的に記憶する必要がある • HiPPO: 複数の直交多項式によって⼊⼒信号を近似する⼿法を提案 • RNNに直接組み込むことができる (LSSL, S4, H3に繋がる研究) • HiPPOを導⼊するだけで,pMNISTにおける精度が92%→98%に • pMNIST: MNISTをpermuteした系列データ. ⻑距離依存を記憶する必要がある 7

Slide 8

Slide 8 text

先⾏研究1: HiPPO[Gu+(Stanford Univ.), NeurIPS20(Spotlight)] ü 背景: ⻑距離系列を扱うには系列の履歴を累積的に記憶する必要がある • HiPPO: 複数の直交多項式によって⼊⼒信号を近似する⼿法を提案 • RNNに直接組み込むことができる (LSSL, S4, H3に繋がる研究) • HiPPOを導⼊するだけで,pMNISTにおける精度が92%→98%に • pMNIST: MNISTをpermuteした系列データ. ⻑距離依存を記憶する必要がある 8

Slide 9

Slide 9 text

先⾏研究2: LSSL[Gu+(Stanford Univ.), NeurIPS21] 9 ü ⻑距離系列を扱う上で,状態空間モデル(SSM)にHiPPOを導⼊し,RNNのような recurrent と CNNのようなconvolution の両⽅で学習できる⼿法LSSLを提案 • RNNs: 👍系列データの学習 👎⻑距離系列→勾配爆発 • CNNs: 👍⾼速かつ並列可能 👎系列データの学習に向いていない • NDEs: 👍 連続かつ⻑距離依存を扱える 👎効率が悪い • これら三者の利点を統合する形のモデルを⽬指す • SSMの⾏列AをHiPPO⾏列にするだけで, sCIFARでTransformerを上回る • sMNIST, sCIFAR: MNIST, CIFARを1次元へと flatten化.画像の帰納バイアスが使えないため, ⾃⼒で系列の依存関係を理解する必要がある

Slide 10

Slide 10 text

先⾏研究2: LSSL[Gu+(Stanford Univ.), NeurIPS21] 10 ü ⻑距離系列を扱う上で,状態空間モデル(SSM)にHiPPOを導⼊し,RNNのような recurrent と CNNのようなconvolution の両⽅で学習できる⼿法LSSLを提案 出⼒ 𝑦! を ⼊⼒𝑢" を⽤いて 書き下す → 畳み込みで表現可能

Slide 11

Slide 11 text

先⾏研究2: LSSL[Gu+(Stanford Univ.), NeurIPS21] ü ⻑距離系列を扱う上で,状態空間モデル(SSM)にHiPPOを導⼊し,RNNのような recurrent と CNNのようなconvolution の両⽅で学習できる⼿法LSSLを提案 出⼒ 𝑦! を ⼊⼒𝑢" を⽤いて 書き下す → 畳み込みで表現可能 11 Backwards-Eulerで離散化したLSSLはRNN (gating mechanism)と同⼀視できる

Slide 12

Slide 12 text

先⾏研究3: S4[Gu+(Stanford Univ.), ICLR22] 12 ü 背景: LSSLは計算効率が悪い → 時間計算量 O 𝑁'𝐿 / 空間計算量 O 𝑁𝐿 • S4では 𝐴 をDPLR化(diagonal plus low rank)することで計算量を削減 • recurrent が O 𝑁 ,convolutionが ) O 𝑁 + 𝐿 で計算可能に (詳細は省略) • LSSLよりも30倍⾼速 / 400倍少ないメモリで計算可 • 初めてLong Range Arena[Tay+, ICLR20]におけるPath-Xを解くことのできたモデル

Slide 13

Slide 13 text

先⾏研究3: S4[Gu+(Stanford Univ.), ICLR22] 13 ü 背景: LSSLは計算効率が悪い → 時間計算量 O 𝑁'𝐿 / 空間計算量 O 𝑁𝐿 • S4では 𝐴 をDPLR化(diagonal plus low rank)することで計算量を削減 • recurrent が O 𝑁 ,convolutionが ) O 𝑁 + 𝐿 で計算可能に (詳細は省略) • LSSLよりも30倍⾼速 / 400倍少ないメモリで計算可 • 初めてLong Range Arena[Tay+, ICLR20]におけるPath-Xを解くことのできたモデル • Long Range Arena: ⻑距離依存を捉える必要のあるタスクを提供 • Path-X: ⼆点が点線で繋がっているかを画素系列から判断

Slide 14

Slide 14 text

⾔語処理におけるSSMの2つの問題点 14 ü SSMは①前⽅にあるトークンの記憶と②トークン間の⽐較が不得意 • 本研究では⼆つのタスクInduction HeadとAssociative Recallで検証 • Induction Head • 特殊なトークン” “ で囲まれた部分⽂字列の先頭の⽂字を出⼒させるタスク • 前⽅のトークンをどれほど記憶しているかを検証 • Associative Recall • key-valueでペアを成すアルファベットと数字の組に対して,与えられたkey に対応するvalueを出⼒させるタスク (e.g., “a 2 c 4 b 3 d 1”→⼊⼒’a’ 出⼒’2’) • トークン間の関係をどれほど記憶しているかを検証

Slide 15

Slide 15 text

⾔語処理におけるSSMの2つの問題点 15 • Induction Head • 前⽅のトークンをどれほど記憶しているかを検証 • Associative Recall • トークン間の関係をどれほど記憶しているかを検証 • Attentionはどちらのタスクでも100%であるのに対して,S4D[Gu+, NeurIPS22]や Gated State Spaces[Mehta+, ICLR23]の精度はかなり低い → SSMは①前⽅にあるトークンの記憶と②トークン間の⽐較が不得意

Slide 16

Slide 16 text

提案⼿法: H3 (Hungry Hungry Hippos) 16 ü SSMは①前⽅にあるトークンの記憶と②トークン間の⽐較が不得意 • ⼀⽅,Attentionは𝑄𝐾( によりトークン間の関係を記憶可能(②) & softmax 𝑄𝐾( 𝑉 によりトークン⾃体を直接記憶可能(①) ü 提案⼿法: H3 (Hungry Hungry Hippos) • この⼆つの難点を乗り越える新たなSSMとしてH3を提案 • 上記考察に基づき,Q, K, V によって Attention-Likeに設計 • またGPU上のFFT, iFFT等の⾼速化⼿法 としてFlashConvを提案 (詳細は省略) ⇒ ⾼速かつ⾔語に強いSSMを実現

Slide 17

Slide 17 text

提案⼿法: H3 (Hungry Hungry Hippos) 17 ü SSMは①前⽅にあるトークンの記憶と②トークン間の⽐較が不得意 ü H3 (Hungry Hungry Hippos) • Attention-Like: Efficient Transformerの⽅法論に則り𝐾(𝑉 を先に計算 ① shift演算によりトークンを記憶 • Shift演算: [a,b,c] → [0,a,b] • 直感的理解: 常に𝐴がshift演算として機能するとき,仮に 𝐵 = 𝑒# とすると,連鎖的に 𝑚 ステップ前までの 𝑢$ が 𝑥$ に格納される (𝑥$ = [𝑢$ , . . . , 𝑢$%&'#]) ② Attention-Likeに乗算することでトークン間の関係を記憶 (like 𝑄𝐾(𝑉) • 𝐾(𝑉部分はHiPPOによって初期化された対⾓⾏列によるSSMが通される • Shiftによって過去の情報を保持した⾏列と⼊⼒を乗算することで類似度が計算できる

Slide 18

Slide 18 text

提案⼿法: H3 (Hungry Hungry Hippos) ü SSMは①前⽅にあるトークンの記憶と②トークン間の⽐較が不得意 ü H3 (Hungry Hungry Hippos) • Attention-Like: Efficient Transformerの⽅法論に則り𝐾(𝑉 を先に計算 ① shift演算によりトークンを記憶 • Shift演算: [a,b,c] → [0,a,b] • 直感的理解: 常に𝐴がshift演算として機能するとき,仮に 𝐵 = 𝑒# とすると,連鎖的に 𝑚 ステップ前までの 𝑢$ が 𝑥$ に格納される (𝑥$ = [𝑢$ , . . . , 𝑢$%&'#]) ② Attention-Likeに乗算することでトークン間の関係を記憶 (like 𝑄𝐾(𝑉) • 𝐾(𝑉部分はHiPPOによって初期化された対⾓⾏列によるSSMが通される • Shiftによって過去の情報を保持した⾏列と⼊⼒を乗算することで類似度が計算できる ① 18

Slide 19

Slide 19 text

提案⼿法: H3 (Hungry Hungry Hippos) 19 ü SSMは①前⽅にあるトークンの記憶と②トークン間の⽐較が不得意 ü H3 (Hungry Hungry Hippos) • Attention-Like: Efficient Transformerの⽅法論に則り𝐾(𝑉 を先に計算 ① shift演算によりトークンを記憶 • Shift演算: [a,b,c] → [0,a,b] • 直感的理解: 常に𝐴がshift演算として機能するとき,仮に 𝐵 = 𝑒# とすると,連鎖的に 𝑚 ステップ前までの 𝑢$ が 𝑥$ に格納される (𝑥$ = [𝑢$ , . . . , 𝑢$%&'#]) ② Attention-Likeに乗算することでトークン間の関係を記憶 (like 𝑄𝐾(𝑉) • 𝐾(𝑉部分はHiPPOによって初期化された対⾓⾏列によるSSMが通される • Shiftによって過去の情報を保持した⾏列と⼊⼒を乗算することで類似度が計算できる ②

Slide 20

Slide 20 text

提案⼿法: H3 (Hungry Hungry Hippos) 20 • H3レイヤのアルゴリズムは以下の通り ⼊⼒ 𝑢 に対して, Q = 𝑢𝑊! , 𝐾 = 𝑢𝑊" , 𝑉 = 𝑢𝑊# を計算 𝑄 𝐾 𝑉

Slide 21

Slide 21 text

提案⼿法: H3 (Hungry Hungry Hippos) 21 • H3レイヤのアルゴリズムは以下の通り 𝑄 𝐾 𝑉 𝐾 を SSM$%&'( に通して 2 𝐾 を得る

Slide 22

Slide 22 text

提案⼿法: H3 (Hungry Hungry Hippos) 22 • H3レイヤのアルゴリズムは以下の通り 𝑄 𝐾 𝑉 𝑄, 𝐾, 𝑉をmulti-head化 すなわち,dim⽅向に分割

Slide 23

Slide 23 text

提案⼿法: H3 (Hungry Hungry Hippos) 23 • H3レイヤのアルゴリズムは以下の通り 𝑄 𝐾 𝑉 各headごとに 𝐾𝑉 ≔ SSM)&*+ 2 𝐾𝑉, を得る

Slide 24

Slide 24 text

提案⼿法: H3 (Hungry Hungry Hippos) 24 • H3レイヤのアルゴリズムは以下の通り 𝑄 𝐾 𝑉 𝑄- ∈ ℝ. 𝑖 = 1, … 𝑁} ごとに 𝑄- 𝐾𝑉 - を計算してconcat

Slide 25

Slide 25 text

提案⼿法: H3 (Hungry Hungry Hippos) 25 • H3は系列⻑ 𝑁 に対しておおよそ O 𝑁log 𝑁 で動作する • 時間計算量: O 𝑑'𝑁 + 𝑑𝑁log 𝑁 • 空間計算量: O 𝑑𝑁 → Attentionは 時間計算量 O 𝑑𝑁' / 空間計算量 O 𝑁' なのでAttentionよりも⾼速

Slide 26

Slide 26 text

定量的結果: GPT-{Neo, 2}よりも低いperplexity 26 • Hybridモデルで実験 (トリック→実はH3単体だと若⼲負けるので) • Hybrid: AttentionとH3レイヤを交互に配置したモデル • データセット: The Pile, OpenWebText, WikiText103 • GPT-Neo, GPT-2よりも 低いperplexity • どのモデルサイズでも 最良の結果

Slide 27

Slide 27 text

定量的結果: zero-shotでもGPT-{Neo, 2}およびOPTを上回る 27 • SuperGLUE (⾔語理解ベンチマーク)におけるzero-shot性能を検証 → Zero-shotにおいてもGPT-Neo, GPT-2, OPTを上回る結果

Slide 28

Slide 28 text

推論速度: Transformerよりも2.4倍⾼速に推論 28 • A100 (80GB)における1.3Bモデルの推論速度を計測 (バッチサイズ: 64) • FlashConvの速度への寄与を Long Range Arena [Tay+, ICLR20] にて検証 • H3: Transformerよりも 2.4倍⾼速に推論可 • S4 w/ FlashConvでは Transformerよりも5.8倍 ⾼速

Slide 29

Slide 29 text

PureなH3との⽐較: SuperGLUEにおいてOPT, GPT-Neoに及ばず 29 • PureなH3におけるzero-shot性能 • OPT, GPT-Neoにかなりの差で負けている • MultiRC, ReCoRDで⼤敗 • MultiRC: Multi-Sentence Reading Comprehension • ReCoRD: Reading Comprehension with Commonsense Reasoning Dataset

Slide 30

Slide 30 text

所感 30 • Strengths • ⾔語モデリングにおいて⾼速かつ強⼒なアーキテクチャを提⽰した点 • Few-shot, zero-shot, LRAでの実験や,GPT-{2,Neo}, OPT, Perceiver AR等との ⽐較など多様な実験が⾏われており,H3の有効性に関するスコープが⽰され ている点 • Weaknesses • PureなH3だとOPT, GPT-Neoに負けている • Attentionを混ぜてしまうならSSMの利点を⼗分に活かしきれていないのでは • Others • S5やMEGAなどもICLR23に採択されておりSSMの研究はかなり盛んな印象

Slide 31

Slide 31 text

まとめ 31 ü 背景 • 状態空間モデル(state-space model; SSM)は様々なモダリティにおいて有⽤性が検証された • ⼀⽅で,未だ⾔語においてはTransformerよりも性能が低い • またSSMは線形にスケーリングするにも拘らず,ハードウェアの利⽤効率が悪いために Transformerよりも低速 ü 提案⼿法 • ⾼速かつ⾔語に強い新たなSSMとしてH3 (Hungry Hungry Hippos)を提案 ü 結果 • HybridモデルにおいてGPT-2およびGPT-Neoよりも低いperplexity • SuperGLUEにおいてもzero-shotで最良の結果

Slide 32

Slide 32 text

参考⽂献 32 • [Gu+, NeurIPS20]: Albert Gu, Tri Dao, Stefano Ermon, Atri Rudra, and Christopher Re. “Hippo: Recurrent memory with optimal polynomial projections”. NeurIPS, 2020. • [Gu+, NeurIPS21]: Gu, Albert, Isys Johnson, Karan Goel, Khaled Saab, Tri Dao, Atri Rudra, and Christopher Ré. "Combining recurrent, convolutional, and continuous-time models with linear state space layers." NeurIPS, 2021. • [Gu+, ICLR22]: Albert Gu, Karan Goel, and Christopher Re. “Efficiently modeling long sequences with structured state spaces”. ICLR, 2022. • [Gu+, NeurIPS22]: Albert Gu, Ankit Gupta, Karan Goel, and Christopher Re. “On the parameterization and initialization of diagonal state space models”. NeurIPS, 2022. • [Ma+, ICLR23]: Xuezhe Ma, Chunting Zhou, Xiang Kong, Junxian He, Liangke Gui, Graham Neubig, Jonathan May, and Luke Zettlemoyer. "Mega: Moving Average Equipped Gated Attention." ICLR, 2023. • [Smith+, ICLR23]: Jimmy T.H. Smith, Andrew Warrington, Scott Linderman, “Simplified State Space Layers for Sequence Modeling” , ICLR, 2023. • [Mehta+, ICLR23]: Harsh Mehta, Ankit Gupta, Ashok Cutkosky, Behnam Neyshabur. “Long Range Language Modeling via Gated State Spaces”, ICLR, 2023 • [Tay+, ICLR20]: Yi Tay, Mostafa Dehghani, Samira Abnar, Yikang Shen, Dara Bahri, Philip Pham, Jinfeng Rao, Liu Yang, Sebastian Ruder, and Donald Metzler. “Long range arena: A benchmark for efficient transformers”, ICLR, 2020.

Slide 33

Slide 33 text

Appendix: 動作確認 33 • H3: (B, L, H) = (8, 6, 512) • データセット: The Pile (hacker_newsのみを使⽤) • 学習時間: 98時間, Epoch: 92 • RAM(GPU): 17.3GB, test/perplexity: 29.5 所感 • H3動かすのにCUDA関係でめちゃくちゃ苦労した… • 後続のHyenaになるとImageを扱えるので,普通に マルチモダリティ扱えそう • 学習時間が減ってる感じはあまりわからない • もう少し軽量のデータセットで試せば効果を 実感できそう

Slide 34

Slide 34 text

Appendix: 計算量の証明 34

Slide 35

Slide 35 text

Appendix: MultiRC, ReCoRDの具体例 (SuperGLUE) 35

Slide 36

Slide 36 text

Appendix: Few-shot / LRAの定量的結果 36

Slide 37

Slide 37 text

Appendix: FlashConvにおけるState-Passing Algorithm 37