Upgrade to Pro — share decks privately, control downloads, hide ads and more …

[Journal club] Hyena Hierarchy: Towards Larger Convolutional Language Models

[Journal club] Hyena Hierarchy: Towards Larger Convolutional Language Models

More Decks by Semantic Machine Intelligence Lab., Keio Univ.

Other Decks in Technology

Transcript

  1. Hyena Hierarchy: Towards Larger Convolutional Language Models Michael Poli1, Stefano

    Massaroli2, Eric Nguyen1, Daniel Y. Fu1, Tri Dao1, Stephen Baccus1, Yoshua Bengio2, Stefano Ermon1, Christopher Ré1 (1Stanford University 2Mila and Université de Montréal) 慶應義塾⼤学 杉浦孔明研究室 M1 和⽥唯我 Michael Poli et al., “Hyena Hierarchy: Towards Larger Convolutional Language Models”, arXiv preprint arXiv:2302.10866.
  2. 概要 2 ü 背景 • Transformerは強⼒だが系列⻑に対してquadraticな計算量が掛かる • Efficient Transformerは計算量を改善するものの性能が不⼗分 ü

    提案⼿法 • Subquadraticな計算量でTransformerに匹敵するモデルHyenaを提案 • Attention-freeであり,SSMベースのH3やGSSを⼀般化したモデル ü 結果 • WikiText103やThe Pileにおいて80%の学習コストでTransformerに匹敵する結果 • ImageNetにおいてもVision Transformerと同程度のaccuracy
  3. 関連研究: H3 (Hungry Hungry Hippos) 3 ü SSMは①前⽅にあるトークンの記憶と②トークン間の⽐較が不得意 ü H3

    (Hungry Hungry Hippos) [Fu+, ICLR23] • ⾔語モデリングにおいて強⼒かつ⾼速なSSMレイヤ ① shift演算によりトークンを記憶 ② Attention-Likeに乗算することでトークン間の関係を記憶 (like 𝑄𝐾!𝑉) → AttentionとH3を交互に配置したHybrid-H3はTransformerと同程度のperplexity 課題: H3単体だと未だ性能に改善の余地あり / Hybrid-H3はAttention-Freeでない
  4. 背景: Attentionの性能に匹敵するsubquadraticな演算が望まれている 4 • Transformerは強⼒だが O 𝑑𝐿" の計算量が掛かる • 系列⻑に対してquadraticな計算量

    → ⼊⼒が⻑くなると計算コスト⼤ • Efficient Transformer: 計算量を改善するが,Transformerとの性能のギャップ • 例: Routing Transformer[Roy, TACL21], Reformer[Kitaev+, ICLR20], Linformer[Wang+, 20] Attentionの性能に匹敵する subquadraticな演算が望まれる Reformer[Kitaev+, ICLR20]
  5. 背景: Attentionの3つの特性 5 ü Data control • ⼊⼒に応じて異なる線形関数(重み)を表現できる → Attention

    Mapは⼊⼒により動的に変化 ü Sublinear parameter scaling • パラメタ数が系列⻑に対しsublinear(2次以下)でスケーリング → パラメタをFFN等の他のブロックに割り振ることが可能 ü Unrestricted context • コンテキストに対する制約が存在しない (e.g., 局所性) → ⼊⼒における任意の2点間の依存関係を近似することが可能 本論⽂の⽬的: これら3つの特性を満たす subquadratic な演算の設計
  6. 畳み込みはフィルタ ℎ の最適化⽅法によって2つに⼤別可能 6 • Explicit Convolutions • フィルタ ℎ

    を直接最適化する⽅法 (e.g., CNN) • CNNは信号処理において有限インパルス応答 (FIR) といえる • FIRのパラメタ数はカーネルサイズ 𝑀 (= step size)に関して線形にスケーリング • Implicit Convolutions • フィルタ ℎ を間接的に最適化する⽅法 (e.g., SSM) • パラメタ数がstep sizeに依存しないようフィルタ ℎ をstep 𝑡 で記述 (ℎ# ≔ 𝛾$ (𝑡)) SSM (reccurent) SSM (convolution)
  7. 畳み込みはフィルタ ℎ の最適化⽅法によって2つに⼤別可能 7 • Explicit Convolutions • フィルタ ℎ

    を直接最適化する⽅法 (e.g., CNN) • CNNは信号処理において有限インパルス応答 (FIR) といえる • FIRのパラメタ数はカーネルサイズ 𝑀 (= step size)に関して線形にスケーリング • Implicit Convolutions • フィルタ ℎ を間接的に最適化する⽅法 (e.g., SSM) • パラメタ数がstep sizeに依存しないようフィルタ ℎ を時間 𝑡 で記述 (ℎ# ≔ 𝛾$ (𝑡)) SSM (reccurent) SSM (convolution) SSMのフィルタはstep sizeである 𝑀に依存せず,学習可能パラメタ 𝐴, 𝐵, 𝐶, 𝐷 で記述可
  8. 畳み込み / H3はToeplitz⾏列の⾏列積で記述できる 8 • ⼀般的な畳み込みにおける 𝑡 成分 • ℎ!"#

    が繰り返し使⽤されるので,対⾓線に沿って 値が⼀定である Toeplitz⾏列 S$ により記述可能 • H3 [Fu+, ICLR23] 𝜓, 𝜙 : SSMでparametrizeされたフィルタ 𝐷! , 𝐷" : 𝑞, 𝑘 の要素で構成された対⾓⾏列 → SSMはImplicit Convolutionsの類なので, H3はToeplitz⾏列と対⾓⾏列の積で記述可能
  9. 提案⼿法: Hyena 9 ü 提案⼿法: Hyena • H3を⼀般化 + ⻑い畳み込みフィルタによってAttentionを代替

    • Hyena MatricesとHyena Filtersによって構成 • Hyena Matrices • Toeplitz⾏列と対⾓⾏列の積を𝑁 回に⼀般化 (H3では2回)
  10. 提案⼿法: Hyena 10 ü 提案⼿法: Hyena • H3を⼀般化 + ⻑い畳み込みフィルタによってAttentionを代替

    • Hyena MatricesとHyena Filtersによって構成 • Hyena Matrices • Toeplitz⾏列と対⾓⾏列の積を𝑁 回に⼀般化 (H3では2回) ⼊⼒に応じて重みが動的に変化 → Data controlled matrix
  11. 提案⼿法: Hyena 11 • Hyena Filters • ⼊⼒と同じ⻑さの⻑いフィルタによる畳み込み (FFN部分のみ学習可能) •

    Window: 指数関数的に減衰する関数は⾼周波フィルタと相性が良い • Positional Encoding: ⾼周波数成分の学習が安定化 [Basri+, PMLR20] • FFN: S4[Gu+, ICLR22]のフィルタを近似的に学習することが可能
  12. 提案⼿法: Hyena 12 • Hyena Filters • ⼊⼒と同じ⻑さの⻑いフィルタによる畳み込み (FFN部分のみ学習可能) •

    Window: 指数関数的に減衰する関数は⾼周波フィルタと相性が良い • Positional Encoding: ⾼周波数成分の学習が安定化 [Basri+, PMLR20] • FFN: S4[Gu+, ICLR22]のフィルタを近似的に学習することが可能 • パラメタ数が系列⻑に対して強依存していない → Sublinear parameter scaling • ⼊⼒と同系列⻑のフィルタを扱うため,任意の⼆点間の依存関係を近似可能 → Unrestricted context H3ではSSMによってフィルタが規定されたが,HyenaではHyena Filterを使⽤
  13. Hyena Algorithm 13 ⼊⼒ 𝑢 に対して,Attentionにおける Q = 𝑢𝑊! ,

    𝐾 = 𝑢𝑊" , 𝑉 = 𝑢𝑊# と同じ要領でProjection (Linear + DwConv)
  14. 定量的結果: Attention-freeであるにも拘らず,Transformerと同程度のPerplexity 16 ⽐較モデル • Hyena-3: 𝑁 = 3 としたHyena

    (12レイヤ) • Hyena-3-slim: depth⼤・width⼩ (18レイヤ) • Efficient Transformer • Performer[Choromanski+, ICLR21] • Reformer[Kitaev+, ICLR20] • Linear Attention[Katharopoulos+, ICML20] WikiText103 におけるperplexity 各モデルのパラメタについて HyenaはTransformerおよび Hybrid H3と同程度のperplexity
  15. 定量的結果: ImageNetやCIFAR10においてもViTと同程度のaccuracy 17 • データセット: ImageNet / CIFAR10 • Hyena-ViT:

    S4ND[Nguyen+, NeurIPS22]に則り設計 • AttentionをHyenaに代替 / CLSトークンとPEを削除 • 8台のA100で学習 → ViT[Dosovitskiy+, ICLR21]と⽐較して同程度のaccuracy "-ISO” = isotropic
  16. まとめ 20 ü 背景 • Transformerは強⼒だが系列⻑に対してquadraticな計算量が掛かる • Efficient Transformerは計算量を改善するものの性能が不⼗分 ü

    提案⼿法 • Subquadraticな計算量でTransformerに匹敵するモデルHyenaを提案 • Attention-freeであり,SSMベースのH3やGSSを⼀般化したモデル ü 結果 • WikiText103やThe Pileにおいて80%の学習コストでTransformerに匹敵する結果 • ImageNetにおいてもVision Transformerと同程度のaccuracy
  17. Appendix: 状態空間モデルについて 26 • LSSL[Gu+, NeurIPS21]における定式化 • ⼊⼒ 𝑢 𝑡

    ,状態 𝑥 𝑡 ,出⼒ 𝑦 𝑡 に対して以下のように定義 • GBTにより離散化 (GBT; generalized bilinear transform) 連続空間 離散空間 𝜶はハイパラ, 𝑨, 𝑩, 𝑪, 𝑫 は学習可能パラメタ (BPなどで学習)
  18. Appendix: SSM + 機械学習の研究は盛んに⾏われている 27 • H3までの系譜は以下の図の通り • HiPPO→LSSL→S4→H3→Hyena •

    特にS4は⾮常に重要な研究であり,後続の研究が盛んに⾏われている HiPPO [Gu+, NeurIPS20] LSSL [Gu+, NeurIPS21] S4 [Gu+, ICLR22] H3 [Fu+, ICLR23] S4D [Gu+, NeurIPS22] MEGA [Ma+, ICLR23] S5 [Smith+, ICLR23]
  19. Appendix: H3 (Hungry Hungry Hippos) 28 ü SSMは①前⽅にあるトークンの記憶と②トークン間の⽐較が不得意 • ⼀⽅,Attentionは𝑄𝐾!

    によりトークン間の関係を記憶可能(②) & softmax 𝑄𝐾! 𝑉 によりトークン⾃体を直接記憶可能(①) ü 提案⼿法: H3 (Hungry Hungry Hippos) • この⼆つの難点を乗り越える新たなSSMとしてH3を提案 • 上記考察に基づき,Q, K, V によって Attention-Likeに設計 • またGPU上のFFT, iFFT等の⾼速化⼿法 としてFlashConvを提案 (詳細は省略) ⇒ ⾼速かつ⾔語に強いSSMを実現