$30 off During Our Annual Pro Sale. View Details »

[Journal club] Hungry Hungry Hippos: Towards Language Modeling with State Space Models

[Journal club] Hungry Hungry Hippos: Towards Language Modeling with State Space Models

More Decks by Semantic Machine Intelligence Lab., Keio Univ.

Other Decks in Technology

Transcript

  1. Hungry Hungry Hippos: Towards Language
    Modeling with State Space Models
    Daniel Y. Fu1, Tri Dao1, Khaled K. Saab1, Armin W. Thomas1, Atri Rudra2,
    Christopher Re1 (1Stanford University, 2University at Buffalo, SUNY)
    和⽥唯我 / Yuiga Wada
    Daniel Y. Fu et al., “Hungry Hungry Hippos: Towards Language Modeling with State Space Models”, in ICLR(2023)
    ICLR23
    notable top 25%

    View Slide

  2. 概要
    2
    ü 背景
    • 状態空間モデル(state-space model; SSM)は様々なモダリティにおいて有⽤性が検証された
    • ⼀⽅で,未だ⾔語においてはTransformerよりも性能が低い
    • またSSMは線形にスケーリングするにも拘らず,ハードウェアの利⽤効率が悪いために
    Transformerよりも低速
    ü 提案⼿法
    • ⾼速かつ⾔語に強い新たなSSMとしてH3 (Hungry Hungry Hippos)を提案
    ü 結果
    • HybridモデルにおいてGPT-2およびGPT-Neoよりも低いperplexity
    • SuperGLUEにおいてもzero-shotで最良の結果

    View Slide

  3. 背景: SSMは低速 & ⾔語において性能が不⼗分
    3
    • 状態空間モデル(state-space model; SSM)は様々なモダリティにおいてSOTAに匹
    敵する性能が検証されたが,未だ⾔語においてはTransformerよりも性能が低い
    • またSSMは線形にスケーリングするにも拘らず,ハードウェアの利⽤効率が悪い
    ためにTransformerよりも低速
    ⾼速かつ⾔語に強い
    新たなSSMが望まれる

    View Slide

  4. 前提: 状態空間モデルについて
    4
    • LSSL[Gu+, NeurIPS21]における定式化
    • ⼊⼒ 𝑢 𝑡 ,状態 𝑥 𝑡 ,出⼒ 𝑦 𝑡 に対して以下のように定義
    • GBTにより離散化 (GBT; generalized bilinear transform)
    連続空間
    離散空間
    𝜶はハイパラ, 𝑨, 𝑩, 𝑪, 𝑫 は学習可能パラメタ (BPなどで学習)

    View Slide

  5. 先⾏研究: SSM + 機械学習の研究は盛んに⾏われている
    5
    • H3までの系譜は以下の図の通り
    • 今回はHiPPO→LSSL→S4→H3の流れを軽く紹介
    • 特にS4は⾮常に重要な研究であり,後続の研究が盛んに⾏われている
    HiPPO
    [Gu+, NeurIPS20]
    LSSL
    [Gu+, NeurIPS21]
    S4
    [Gu+, ICLR22]
    H3
    [Fu+, ICLR23]
    S4D
    [Gu+, NeurIPS22]
    MEGA [Ma+, ICLR23]
    S5 [Smith+, ICLR23]

    View Slide

  6. 先⾏研究1: HiPPO[Gu+(Stanford Univ.), NeurIPS20(Spotlight)]
    6
    ü 背景: ⻑距離系列を扱うには系列の履歴を累積的に記憶する必要がある
    • HiPPO: 複数の直交多項式によって⼊⼒信号を近似する⼿法を提案
    • RNNに直接組み込むことができる (LSSL, S4, H3に繋がる研究)
    • HiPPOを導⼊するだけで,pMNISTにおける精度が92%→98%に
    • pMNIST: MNISTをpermuteした系列データ. ⻑距離依存を記憶する必要がある

    View Slide

  7. 先⾏研究1: HiPPO[Gu+(Stanford Univ.), NeurIPS20(Spotlight)]
    ü 背景: ⻑距離系列を扱うには系列の履歴を累積的に記憶する必要がある
    • HiPPO: 複数の直交多項式によって⼊⼒信号を近似する⼿法を提案
    • RNNに直接組み込むことができる (LSSL, S4, H3に繋がる研究)
    • HiPPOを導⼊するだけで,pMNISTにおける精度が92%→98%に
    • pMNIST: MNISTをpermuteした系列データ. ⻑距離依存を記憶する必要がある
    7

    View Slide

  8. 先⾏研究1: HiPPO[Gu+(Stanford Univ.), NeurIPS20(Spotlight)]
    ü 背景: ⻑距離系列を扱うには系列の履歴を累積的に記憶する必要がある
    • HiPPO: 複数の直交多項式によって⼊⼒信号を近似する⼿法を提案
    • RNNに直接組み込むことができる (LSSL, S4, H3に繋がる研究)
    • HiPPOを導⼊するだけで,pMNISTにおける精度が92%→98%に
    • pMNIST: MNISTをpermuteした系列データ. ⻑距離依存を記憶する必要がある
    8

    View Slide

  9. 先⾏研究2: LSSL[Gu+(Stanford Univ.), NeurIPS21]
    9
    ü ⻑距離系列を扱う上で,状態空間モデル(SSM)にHiPPOを導⼊し,RNNのような
    recurrent と CNNのようなconvolution の両⽅で学習できる⼿法LSSLを提案
    • RNNs: 👍系列データの学習 👎⻑距離系列→勾配爆発
    • CNNs: 👍⾼速かつ並列可能 👎系列データの学習に向いていない
    • NDEs: 👍 連続かつ⻑距離依存を扱える 👎効率が悪い
    • これら三者の利点を統合する形のモデルを⽬指す
    • SSMの⾏列AをHiPPO⾏列にするだけで,
    sCIFARでTransformerを上回る
    • sMNIST, sCIFAR: MNIST, CIFARを1次元へと
    flatten化.画像の帰納バイアスが使えないため,
    ⾃⼒で系列の依存関係を理解する必要がある

    View Slide

  10. 先⾏研究2: LSSL[Gu+(Stanford Univ.), NeurIPS21]
    10
    ü ⻑距離系列を扱う上で,状態空間モデル(SSM)にHiPPOを導⼊し,RNNのような
    recurrent と CNNのようなconvolution の両⽅で学習できる⼿法LSSLを提案
    出⼒ 𝑦!
    を ⼊⼒𝑢"
    を⽤いて
    書き下す
    → 畳み込みで表現可能

    View Slide

  11. 先⾏研究2: LSSL[Gu+(Stanford Univ.), NeurIPS21]
    ü ⻑距離系列を扱う上で,状態空間モデル(SSM)にHiPPOを導⼊し,RNNのような
    recurrent と CNNのようなconvolution の両⽅で学習できる⼿法LSSLを提案
    出⼒ 𝑦!
    を ⼊⼒𝑢"
    を⽤いて
    書き下す
    → 畳み込みで表現可能
    11
    Backwards-Eulerで離散化したLSSLはRNN (gating mechanism)と同⼀視できる

    View Slide

  12. 先⾏研究3: S4[Gu+(Stanford Univ.), ICLR22]
    12
    ü 背景: LSSLは計算効率が悪い → 時間計算量 O 𝑁'𝐿 / 空間計算量 O 𝑁𝐿
    • S4では 𝐴 をDPLR化(diagonal plus low rank)することで計算量を削減
    • recurrent が O 𝑁 ,convolutionが )
    O 𝑁 + 𝐿 で計算可能に (詳細は省略)
    • LSSLよりも30倍⾼速 / 400倍少ないメモリで計算可
    • 初めてLong Range Arena[Tay+, ICLR20]におけるPath-Xを解くことのできたモデル

    View Slide

  13. 先⾏研究3: S4[Gu+(Stanford Univ.), ICLR22]
    13
    ü 背景: LSSLは計算効率が悪い → 時間計算量 O 𝑁'𝐿 / 空間計算量 O 𝑁𝐿
    • S4では 𝐴 をDPLR化(diagonal plus low rank)することで計算量を削減
    • recurrent が O 𝑁 ,convolutionが )
    O 𝑁 + 𝐿 で計算可能に (詳細は省略)
    • LSSLよりも30倍⾼速 / 400倍少ないメモリで計算可
    • 初めてLong Range Arena[Tay+, ICLR20]におけるPath-Xを解くことのできたモデル
    • Long Range Arena: ⻑距離依存を捉える必要のあるタスクを提供
    • Path-X: ⼆点が点線で繋がっているかを画素系列から判断

    View Slide

  14. ⾔語処理におけるSSMの2つの問題点
    14
    ü SSMは①前⽅にあるトークンの記憶と②トークン間の⽐較が不得意
    • 本研究では⼆つのタスクInduction HeadとAssociative Recallで検証
    • Induction Head
    • 特殊なトークン” “ で囲まれた部分⽂字列の先頭の⽂字を出⼒させるタスク
    • 前⽅のトークンをどれほど記憶しているかを検証
    • Associative Recall
    • key-valueでペアを成すアルファベットと数字の組に対して,与えられたkey
    に対応するvalueを出⼒させるタスク (e.g., “a 2 c 4 b 3 d 1”→⼊⼒’a’ 出⼒’2’)
    • トークン間の関係をどれほど記憶しているかを検証

    View Slide

  15. ⾔語処理におけるSSMの2つの問題点
    15
    • Induction Head
    • 前⽅のトークンをどれほど記憶しているかを検証
    • Associative Recall
    • トークン間の関係をどれほど記憶しているかを検証
    • Attentionはどちらのタスクでも100%であるのに対して,S4D[Gu+, NeurIPS22]や
    Gated State Spaces[Mehta+, ICLR23]の精度はかなり低い
    → SSMは①前⽅にあるトークンの記憶と②トークン間の⽐較が不得意

    View Slide

  16. 提案⼿法: H3 (Hungry Hungry Hippos)
    16
    ü SSMは①前⽅にあるトークンの記憶と②トークン間の⽐較が不得意
    • ⼀⽅,Attentionは𝑄𝐾( によりトークン間の関係を記憶可能(②) &
    softmax 𝑄𝐾( 𝑉 によりトークン⾃体を直接記憶可能(①)
    ü 提案⼿法: H3 (Hungry Hungry Hippos)
    • この⼆つの難点を乗り越える新たなSSMとしてH3を提案
    • 上記考察に基づき,Q, K, V によって
    Attention-Likeに設計
    • またGPU上のFFT, iFFT等の⾼速化⼿法
    としてFlashConvを提案 (詳細は省略)
    ⇒ ⾼速かつ⾔語に強いSSMを実現

    View Slide

  17. 提案⼿法: H3 (Hungry Hungry Hippos)
    17
    ü SSMは①前⽅にあるトークンの記憶と②トークン間の⽐較が不得意
    ü H3 (Hungry Hungry Hippos)
    • Attention-Like: Efficient Transformerの⽅法論に則り𝐾(𝑉 を先に計算
    ① shift演算によりトークンを記憶
    • Shift演算: [a,b,c] → [0,a,b]
    • 直感的理解: 常に𝐴がshift演算として機能するとき,仮に 𝐵 = 𝑒#
    とすると,連鎖的に
    𝑚 ステップ前までの 𝑢$
    が 𝑥$
    に格納される (𝑥$ = [𝑢$ , . . . , 𝑢$%&'#])
    ② Attention-Likeに乗算することでトークン間の関係を記憶 (like 𝑄𝐾(𝑉)
    • 𝐾(𝑉部分はHiPPOによって初期化された対⾓⾏列によるSSMが通される
    • Shiftによって過去の情報を保持した⾏列と⼊⼒を乗算することで類似度が計算できる

    View Slide

  18. 提案⼿法: H3 (Hungry Hungry Hippos)
    ü SSMは①前⽅にあるトークンの記憶と②トークン間の⽐較が不得意
    ü H3 (Hungry Hungry Hippos)
    • Attention-Like: Efficient Transformerの⽅法論に則り𝐾(𝑉 を先に計算
    ① shift演算によりトークンを記憶
    • Shift演算: [a,b,c] → [0,a,b]
    • 直感的理解: 常に𝐴がshift演算として機能するとき,仮に 𝐵 = 𝑒#
    とすると,連鎖的に
    𝑚 ステップ前までの 𝑢$
    が 𝑥$
    に格納される (𝑥$ = [𝑢$ , . . . , 𝑢$%&'#])
    ② Attention-Likeに乗算することでトークン間の関係を記憶 (like 𝑄𝐾(𝑉)
    • 𝐾(𝑉部分はHiPPOによって初期化された対⾓⾏列によるSSMが通される
    • Shiftによって過去の情報を保持した⾏列と⼊⼒を乗算することで類似度が計算できる

    18

    View Slide

  19. 提案⼿法: H3 (Hungry Hungry Hippos)
    19
    ü SSMは①前⽅にあるトークンの記憶と②トークン間の⽐較が不得意
    ü H3 (Hungry Hungry Hippos)
    • Attention-Like: Efficient Transformerの⽅法論に則り𝐾(𝑉 を先に計算
    ① shift演算によりトークンを記憶
    • Shift演算: [a,b,c] → [0,a,b]
    • 直感的理解: 常に𝐴がshift演算として機能するとき,仮に 𝐵 = 𝑒#
    とすると,連鎖的に
    𝑚 ステップ前までの 𝑢$
    が 𝑥$
    に格納される (𝑥$ = [𝑢$ , . . . , 𝑢$%&'#])
    ② Attention-Likeに乗算することでトークン間の関係を記憶 (like 𝑄𝐾(𝑉)
    • 𝐾(𝑉部分はHiPPOによって初期化された対⾓⾏列によるSSMが通される
    • Shiftによって過去の情報を保持した⾏列と⼊⼒を乗算することで類似度が計算できる

    View Slide

  20. 提案⼿法: H3 (Hungry Hungry Hippos)
    20
    • H3レイヤのアルゴリズムは以下の通り
    ⼊⼒ 𝑢 に対して,
    Q = 𝑢𝑊!
    , 𝐾 = 𝑢𝑊"
    , 𝑉 = 𝑢𝑊#
    を計算
    𝑄
    𝐾
    𝑉

    View Slide

  21. 提案⼿法: H3 (Hungry Hungry Hippos)
    21
    • H3レイヤのアルゴリズムは以下の通り
    𝑄
    𝐾
    𝑉
    𝐾 を SSM$%&'(
    に通して 2
    𝐾 を得る

    View Slide

  22. 提案⼿法: H3 (Hungry Hungry Hippos)
    22
    • H3レイヤのアルゴリズムは以下の通り
    𝑄
    𝐾
    𝑉
    𝑄, 𝐾, 𝑉をmulti-head化
    すなわち,dim⽅向に分割

    View Slide

  23. 提案⼿法: H3 (Hungry Hungry Hippos)
    23
    • H3レイヤのアルゴリズムは以下の通り
    𝑄
    𝐾
    𝑉
    各headごとに
    𝐾𝑉 ≔ SSM)&*+
    2
    𝐾𝑉, を得る

    View Slide

  24. 提案⼿法: H3 (Hungry Hungry Hippos)
    24
    • H3レイヤのアルゴリズムは以下の通り
    𝑄
    𝐾
    𝑉
    𝑄-
    ∈ ℝ. 𝑖 = 1, … 𝑁} ごとに
    𝑄-
    𝐾𝑉 -
    を計算してconcat

    View Slide

  25. 提案⼿法: H3 (Hungry Hungry Hippos)
    25
    • H3は系列⻑ 𝑁 に対しておおよそ O 𝑁log 𝑁 で動作する
    • 時間計算量: O 𝑑'𝑁 + 𝑑𝑁log 𝑁
    • 空間計算量: O 𝑑𝑁
    → Attentionは 時間計算量 O 𝑑𝑁' / 空間計算量 O 𝑁' なのでAttentionよりも⾼速

    View Slide

  26. 定量的結果: GPT-{Neo, 2}よりも低いperplexity
    26
    • Hybridモデルで実験 (トリック→実はH3単体だと若⼲負けるので)
    • Hybrid: AttentionとH3レイヤを交互に配置したモデル
    • データセット: The Pile, OpenWebText, WikiText103
    • GPT-Neo, GPT-2よりも
    低いperplexity
    • どのモデルサイズでも
    最良の結果

    View Slide

  27. 定量的結果: zero-shotでもGPT-{Neo, 2}およびOPTを上回る
    27
    • SuperGLUE (⾔語理解ベンチマーク)におけるzero-shot性能を検証
    → Zero-shotにおいてもGPT-Neo, GPT-2, OPTを上回る結果

    View Slide

  28. 推論速度: Transformerよりも2.4倍⾼速に推論
    28
    • A100 (80GB)における1.3Bモデルの推論速度を計測 (バッチサイズ: 64)
    • FlashConvの速度への寄与を Long Range Arena [Tay+, ICLR20] にて検証
    • H3: Transformerよりも
    2.4倍⾼速に推論可
    • S4 w/ FlashConvでは
    Transformerよりも5.8倍
    ⾼速

    View Slide

  29. PureなH3との⽐較: SuperGLUEにおいてOPT, GPT-Neoに及ばず
    29
    • PureなH3におけるzero-shot性能
    • OPT, GPT-Neoにかなりの差で負けている
    • MultiRC, ReCoRDで⼤敗
    • MultiRC: Multi-Sentence Reading Comprehension
    • ReCoRD: Reading Comprehension with Commonsense Reasoning Dataset

    View Slide

  30. 所感
    30
    • Strengths
    • ⾔語モデリングにおいて⾼速かつ強⼒なアーキテクチャを提⽰した点
    • Few-shot, zero-shot, LRAでの実験や,GPT-{2,Neo}, OPT, Perceiver AR等との
    ⽐較など多様な実験が⾏われており,H3の有効性に関するスコープが⽰され
    ている点
    • Weaknesses
    • PureなH3だとOPT, GPT-Neoに負けている
    • Attentionを混ぜてしまうならSSMの利点を⼗分に活かしきれていないのでは
    • Others
    • S5やMEGAなどもICLR23に採択されておりSSMの研究はかなり盛んな印象

    View Slide

  31. まとめ
    31
    ü 背景
    • 状態空間モデル(state-space model; SSM)は様々なモダリティにおいて有⽤性が検証された
    • ⼀⽅で,未だ⾔語においてはTransformerよりも性能が低い
    • またSSMは線形にスケーリングするにも拘らず,ハードウェアの利⽤効率が悪いために
    Transformerよりも低速
    ü 提案⼿法
    • ⾼速かつ⾔語に強い新たなSSMとしてH3 (Hungry Hungry Hippos)を提案
    ü 結果
    • HybridモデルにおいてGPT-2およびGPT-Neoよりも低いperplexity
    • SuperGLUEにおいてもzero-shotで最良の結果

    View Slide

  32. 参考⽂献
    32
    • [Gu+, NeurIPS20]: Albert Gu, Tri Dao, Stefano Ermon, Atri Rudra, and Christopher Re. “Hippo: Recurrent
    memory with optimal polynomial projections”. NeurIPS, 2020.
    • [Gu+, NeurIPS21]: Gu, Albert, Isys Johnson, Karan Goel, Khaled Saab, Tri Dao, Atri Rudra, and
    Christopher Ré. "Combining recurrent, convolutional, and continuous-time models with linear state
    space layers." NeurIPS, 2021.
    • [Gu+, ICLR22]: Albert Gu, Karan Goel, and Christopher Re. “Efficiently modeling long sequences with
    structured state spaces”. ICLR, 2022.
    • [Gu+, NeurIPS22]: Albert Gu, Ankit Gupta, Karan Goel, and Christopher Re. “On the parameterization
    and initialization of diagonal state space models”. NeurIPS, 2022.
    • [Ma+, ICLR23]: Xuezhe Ma, Chunting Zhou, Xiang Kong, Junxian He, Liangke Gui, Graham Neubig,
    Jonathan May, and Luke Zettlemoyer. "Mega: Moving Average Equipped Gated Attention." ICLR, 2023.
    • [Smith+, ICLR23]: Jimmy T.H. Smith, Andrew Warrington, Scott Linderman, “Simplified State Space
    Layers for Sequence Modeling” , ICLR, 2023.
    • [Mehta+, ICLR23]: Harsh Mehta, Ankit Gupta, Ashok Cutkosky, Behnam Neyshabur. “Long Range
    Language Modeling via Gated State Spaces”, ICLR, 2023
    • [Tay+, ICLR20]: Yi Tay, Mostafa Dehghani, Samira Abnar, Yikang Shen, Dara Bahri, Philip Pham, Jinfeng
    Rao, Liu Yang, Sebastian Ruder, and Donald Metzler. “Long range arena: A benchmark for efficient
    transformers”, ICLR, 2020.

    View Slide

  33. Appendix: 動作確認
    33
    • H3: (B, L, H) = (8, 6, 512)
    • データセット: The Pile (hacker_newsのみを使⽤)
    • 学習時間: 98時間, Epoch: 92
    • RAM(GPU): 17.3GB, test/perplexity: 29.5
    所感
    • H3動かすのにCUDA関係でめちゃくちゃ苦労した…
    • 後続のHyenaになるとImageを扱えるので,普通に
    マルチモダリティ扱えそう
    • 学習時間が減ってる感じはあまりわからない
    • もう少し軽量のデータセットで試せば効果を
    実感できそう

    View Slide

  34. Appendix: 計算量の証明
    34

    View Slide

  35. Appendix: MultiRC, ReCoRDの具体例 (SuperGLUE)
    35

    View Slide

  36. Appendix: Few-shot / LRAの定量的結果
    36

    View Slide

  37. Appendix: FlashConvにおけるState-Passing Algorithm
    37

    View Slide