Upgrade to PRO for Only $50/Year—Limited-Time Offer! 🔥

[Journal club] Accelerating Toeplitz Neural Net...

[Journal club] Accelerating Toeplitz Neural Network with Constant-time Inference Complexity

More Decks by Semantic Machine Intelligence Lab., Keio Univ.

Other Decks in Technology

Transcript

  1. Accelerating Toeplitz Neural Network with Constant-time Inference Complexity Zhen Qin,

    Yiran Zhong (OpenNLPLab, Shanghai Artificial Intelligence Laborator) 慶應義塾⼤学 杉浦孔明研究室 M1 和⽥唯我 Zhen Qin et al., “Accelerating Toeplitz Neural Network with Constant-time Inference Complexity” in EMNLP (2023) EMNLP23
  2. 概要 2 ü 背景 • TNNは強⼒な⾔語モデリングを提供する⼀⽅,推論に O 𝑛𝑑 log 𝑛

    掛かる • SSMはTNNに性能が劣るものの,O 𝑑ℎ と⾼速に推論可 ü 提案⼿法: ETSC (Exact Toeplitz-to-SSM Conversion) • DFTを⽤いてTNNをSSMへと変換する⼿法ETSCを提案 • TNNをVandermonde linear systemとして直接SSMへと変換 ü 結果 • ETSCに基づくTNNは性能を保ちつつ⾼速な推論を実現
  3. 復習: Toeplitz⾏列によるモデリングは利点が多い 3 o Toeplitz⾏列 • 対⾓線に沿って値が⼀定である⾏列 • ⾏列積はFFTにより O

    𝑛log 𝑛 で計算可 • SSMや畳み込みはToeplitz⾏列で表現可 • これらの⼿法を⼀般化したモデリングが可 o Toeplitz⾏列によるモデリングは利点が多い • 𝑛 × 𝑛で 2𝑛 − 1のパラメタ → 低コスト • O 𝑛log 𝑛 の計算量で⾏列積が計算可能 → ⾼効率 → [Qin+, ICLR23]: Toeplitz⾏列に基づくTNNを提案
  4. 復習: Toeplitz Neural Network (TNN) 4 o Toeplitz Neural Network

    [Qin+, ICLR23] • Gated Toeplitz Units (GTU) → Channel-Mixing (+ Token-Mixing) • Toeplitz Neural Operator (TNO) → Token-Mixing • Relative Positonal Encoder (RPE) → Toeplitz⾏列演算の効率化 Topelitz⾏列により⾼速かつ 強⼒なアーキテクチャを提案
  5. 復習: State Space Model (SSM) 5 • LSSL [Gu+, NeurIPS21]

    における定式化 • ⼊⼒ 𝑢 𝑡 ,状態 𝑥 𝑡 ,出⼒ 𝑦 𝑡 に対して以下のように定義 • GBTにより離散化 (GBT; generalized bilinear transform) 連続空間 離散空間 𝜶はハイパラ, 𝑨, 𝑩, 𝑪, 𝑫 は学習可能パラメタ (BPなどで学習)
  6. 背景: TNNとSSMの統合が望まれる 6 o TNNおよびSSMの利点・⽋点 • TNNは強⼒な⾔語モデリングを提供する⼀⽅,推論に O 𝑛𝑑 log

    𝑛 掛かる • SSMはTNNに性能が劣るものの,O 𝑑ℎ と⾼速に推論可 → TNNをSSMの形式に変換できれば,両者の利点を享受できて便利 性能 速度 計算量 (推論時) Transformer [Vaswani+, NIPS17] ◎ ✗ O 𝑛'𝑑 + 𝑛𝑑' Transformer w/ KV-Cache [Pope+, 22] ◎ △ O 𝑛𝑑' Linear Attention […] △ ◎ O 𝑑ℎ State Space Model [Gu+, ICLR22] ◦ ◎ O 𝑑ℎ Toeplitz Neural Network [Qin+, ICLR23] ◦ △ O 𝑛𝑑 log 𝑛
  7. 問題設定: TNNをSSMの形式に変換 7 o SSMの⼀般的な定式化 • ⼊⼒ 𝑥! , 出⼒

    𝑦! に対して, ただし, → DSS [Gupta+, ICLR22] の定式化 o TNNの定式化 • ⼊⼒ 𝒙 , 出⼒ 𝒚 に対して, (注: 通常のSSMと異なり 𝑢! は⼊⼒でない) TNNをSSMの形式に変換したい
  8. 問題設定: TNNをSSMの形式に変換 8 • 問題設定: TNNをSSMに変換 • 既知の 𝑡 から未知の

    𝜆 , 𝑏 を推定したい • ただし 𝑐 は実質 𝑐𝑏 ← 𝑏 と書けるので,𝑐) = 1 とする • このとき,以下の条件を満たす (𝝀 , 𝒃) を求めれば良い SSM TNN
  9. 課題: 勾配ベースの⼿法は性能が不⼗分 9 o 既知の 𝑡 から(𝝀 , 𝒃)を求める⽅法は以下の⼆種に⼤別 1.

    勾配ベースの解法 • 以下の最適化問題を勾配ベースの⽅法で解くことで (𝝀 , 𝒃) を推定 L 上記最適化問題では (𝝀 , 𝒃) は収束困難 [Gu+, ICLR22] なので不⼗分 2. 閉形式解 • 次式をVandermonde linear systemとして直接計算 → 本研究では適宜仮定を置きながら,直接 (𝝀 , 𝒃) を求める
  10. 提案⼿法: ETSC (Exact Toeplitz-to-SSM Conversion) 10 o 提案⼿法: ETSC (Exact

    Toeplitz-to-SSM Conversion) • はじめに,次式から • Vandermode ⾏列 𝐕 を⽤いて下式を得る. 列⽅向に等⽐数列を成しているので注意
  11. 提案⼿法: ETSC (Exact Toeplitz-to-SSM Conversion) • 上記 Vandermonde linear system

    の解法は数値的に不安定 • 𝜆* がpairwise distinctならば逆⾏列が存在するため,解は求まる → 以下を 𝜆* とすることでDFT⾏列 𝑾𝒏 より 𝐕 を変形 11
  12. 提案⼿法: ETSC (Exact Toeplitz-to-SSM Conversion) 13 • ここで,右辺の(1,1)成分を⽐較すると • 本来

    𝑡 に関する制約は存在しないので, 「解が存在」⇔「式(1)を満たすかつ Vandermondeの逆⾏列が存在」 • 式(1)は 𝒕 = 𝑛 𝑾𝒏𝒃 において満たされ ないので,以下のように累積和 ̅ 𝑡 を項 に追加 (1)
  13. 提案⼿法: ETSC (Exact Toeplitz-to-SSM Conversion) 14 • 右辺の(1,1)成分を⽐較して, • 本来

    𝑡 に関する制約は存在しないので, 「解が存在」⇔「式(1)を満たすかつ Vandermondeの逆⾏列が存在」 • 式(1)は 𝒕 = 𝑛 𝑾𝒏𝒃 において満たされ ないので,以下のように累積和 ̅ 𝑡 を項 に追加 (1) 𝑾𝒏#𝟏 もユニタリ⾏列なので,上式より
  14. 実験設定: 多様な設定により提案⼿法の有効性を検証 19 o データセット • Wikitext-103 [Merity+, ICLR17] •

    Wikitext-book [Wettig+, EACL23] o ⽐較対象 (全て6層のTNN / A100 GPUで学習) • Origin • FFTに基づき 𝑻𝒊𝒙𝒊 を計算するオリジナルなTNN → 𝑂 𝑛𝑑 log 𝑛 • Cache • 右式を計算しつつ,KV-Cacheと同様に 𝑦6 をキャッシュ→ 𝑂 𝑛𝑑' • Gradient-based • 勾配ベースの⼿法によりSSMへ変換 • ETSC • SSMへと変換することで推論の計算量を削減 → 𝑂 ℎ𝑑
  15. 実験結果: Wikitext-103において従来のTNNよりも⾼速かつ効率的に推論可 20 o 推論速度: • Originに対して約100倍⾼速に動作 / Cacheと⽐較しても⾼速な推論を実現 o

    メモリコスト: • 系列⻑が増加するにつれ,Origin, Cacheのメモリ消費量は増⼤ • 提案⼿法に基づくTNNはOrigin等と⽐較してもconstantな空間計算量
  16. 所感 23 • Strengths • Vandermonde⾏列をDFT化し,TNNをSSMへと変換するアイデアの斬新さ • ⾼速かつ⾼効率な⼿法を提案し,適切な実験により有効性が確認されている点 • Weaknesses

    • 𝜆! を恣意的に決めてVandermondeをDFT⾏列にするなら,より詳細な実験ないしは 理論的な解析が欲しい • 特に𝜆! はSSM由来なので,SSMの⽂脈できちんと議論して欲しい • 全体的にheuristicなのが気になる • Σ 𝑡" = 0を満たすよう追加した累積和の項など,どう影響するかの解析が欲しい • Comments • ETSCはDSSベースだが,DSSよりもはるかに良い結果 → どこがクリティカルに効いているのか(empericalでも良いので)解析が欲しい
  17. まとめ 24 ü 背景 • TNNは強⼒な⾔語モデリングを提供する⼀⽅,推論に O 𝑛𝑑 log 𝑛

    掛かる • SSMはTNNに性能が劣るものの,O 𝑑ℎ と⾼速に推論可 ü 提案⼿法: ETSC (Exact Toeplitz-to-SSM Conversion) • DFTを⽤いてTNNをSSMへと変換する⼿法ETSCを提案 • TNNをVandermonde linear systemとして直接SSMへと変換 ü 結果 • ETSCに基づくTNNは性能を保ちつつ⾼速な推論を実現