Upgrade to PRO for Only $50/Year—Limited-Time Offer! 🔥

[Journal club] Simplified State Space Layers fo...

[Journal club] Simplified State Space Layers for Sequence Modeling

More Decks by Semantic Machine Intelligence Lab., Keio Univ.

Other Decks in Technology

Transcript

  1. Simplified State Space Layers for Sequence Modeling Jimmy T.H. Smith,

    Andrew Warrington, Scott W. Linderman (Stanford University) 慶應義塾大学 杉浦孔明研究室 D1 和田唯我 Jimmy T.H. Smith et al., “Simplified State Space Layers for Sequence Modeling” in ICLR (2023) ICLR23
  2. 概要 2 ✓ 背景 • SSMは長系列のモデリングに有望 • S4はSISO (single-input, single-output)でありrecurrent-modeでの学習は非効率

    ✓ 提案手法 • 状態サイズを小さくすることで S4 をMIMO (multi-input, multi-output) へ • parallel scanによってrecurrent-modeでも効率的に計算 ✓ 結果 • Long Range Arenaにおいて良好な結果(16kの系列長を持つPath-XにてSOTA)
  3. 前提: 状態空間モデルについて 3 • LSSL[Gu+, NeurIPS21]における定式化 • 入力 𝑢 𝑡

    ,状態 𝑥 𝑡 ,出力 𝑦 𝑡 に対して以下のように定義 • GBTにより離散化 (GBT; generalized bilinear transform) \begin{align*} x(t + \Delta t) = (I - \alpha \Delta t \cdot A)^{-1}(I + (1 - \alpha) \Delta t \cdot A) x(t) + \Delta t (I - \alpha \Delta t \cdot A)^{-1} B \cdot u(t) \end{align*} \begin{align*} \dot{x}(t) = Ax(t) + Bu(t), \\ y(t) = Cx(t)+Du(t) \end{align*} 連続空間 離散空間 𝜶はハイパラ, 𝑨, 𝑩, 𝑪, 𝑫 は学習可能パラメタ (BPなどで学習)
  4. 先行研究: SSM + 機械学習の研究は盛んに行われている 4 HiPPO [Gu+, NeurIPS20] LSSL [Gu+,

    NeurIPS21] S4 [Gu+, ICLR22] Mamba [Gu+, 24] S4D [Gu+, NeurIPS22] H3 [Fu+, ICLR23] S5 [Smith+, ICLR23] ← 詳しくはwadaの解説動画へ (SONY nnabla channel) https://youtu.be/G4pBhdv3RWk?si=4gINyCcW8CYx9Z_X
  5. S4はrecurrent-mode / convolution-modeを持つ 5 o LSSL [Gu+, NeurIPS21] 〜 S4

    [Gu+, ICLR22] • SSMは畳み込み形式で記述可能 → covolution-mode・recurrent-mode } \overline{A} \right)^k \overline{B} u_0 + C \left( \overline{A} \right)^{k-1} \overline{B} u_1 + \dots + C \overline{A} \overline{B} u_{k-1} + \overline{B} u_k + D u_k \\ K}_L(\overline{A}, \overline{B}, C) \ast u + D u \\ L(A, B, C) = \left(C A^i B\right)_{i \in \lbrack L\rbrack} \in \mathbb{R}^L = (CB, CAB, \dots, CA^{L-1}B) 出力 𝑦𝑘 を 入力𝑢0 を用いて 書き下す → 畳み込みで表現可能
  6. S4はSISOでありrecurrent-modeでの学習は非効率 6 o S4: LSSLは計算効率が悪い → 時間計算量 O 𝑁2𝐿 /

    空間計算量 O 𝑁𝐿 • S4では 𝐴 をDPLR化(diagonal plus low rank)することで計算量を削減 o S4はSISO (single-{input, output}) • 入力 𝑢 ∈ ℝ𝐿×𝐻 を処理するため, 𝐻個の独立したSSM (w/ 𝑁次元の状態)を使用 \begin{align*} A = \Lambda - PQ^* \end{align*} \begin{align*} (\Lambda - PQ^*, B,C) \end{align*} 状態 𝑥 ∈ ℝ𝐻×𝑁 入力 𝑢 ∈ ℝ𝐿×𝐻 チャネル方向へ独立に処理したため ここでチャネル方向にmixing 状態を𝐻 個に分割 → xk ∶ 𝑘 ∈ 0, 𝐻 番目の状態
  7. S4はSISOでありrecurrent-modeでの学習は非効率 7 o S4: LSSLは計算効率が悪い → 時間計算量 O 𝑁2𝐿 /

    空間計算量 O 𝑁𝐿 • S4では 𝐴 をDPLR化(diagonal plus low rank)することで計算量を削減 o S4はSISO (single-{input, output}) • 入力 𝑢 ∈ ℝ𝐿×𝐻 を処理するため, 𝐻個の独立したSSM (w/ 𝑁次元の状態)を使用 \begin{align*} A = \Lambda - PQ^* \end{align*} \begin{align*} (\Lambda - PQ^*, B,C) \end{align*} 状態 𝑥 ∈ ℝ𝐻×𝑁 入力 𝑢 ∈ ℝ𝐿×𝐻 チャネル方向へ独立に処理したため ここでチャネル方向にmixing 状態を𝐻 個に分割 → xk ∶ 𝑘 ∈ 0, 𝐻 番目の状態 通常のS4はSISOなので,recurrentでは効率的に学習できず → S4をMIMO化できれば,parrallel scanで高速化可能
  8. 提案手法 S5 はparallel scanにより効率的に学習可能 8 o 提案手法: S5 • S4では

    𝐻 × 𝑁 の状態を保持 → 低次元 𝑃 ≪ 𝐻 × 𝑁 の状態によりMIMO化 • MIMO化によって … • MIMOなのでS4で使われるChannel mixingは不要 • parallel scan(後述)により効率的な学習が可能に
  9. 提案手法 S5 はparallel scanにより効率的に学習可能 9 o 提案手法: S5 • S4では

    𝐻 × 𝑁 の状態を保持 → 低次元 𝑃 ≪ 𝐻 × 𝑁 の状態によりMIMO化 • MIMO化によって … • MIMOなのでS4で使われるChannel mixingは不要 • parallel scan(後述)により効率的な学習が可能に 状態 𝑥 ∈ ℝ𝑃 入力 𝑢 ∈ ℝ𝐿×𝐻 チャネル方向にすでにmixされている → Channel mixingは不要 ↑ S4では 𝑥 ∈ ℝ𝐻×𝑁 だった 状態を軽量化することで,単一SSM でモデリング可能(MIMO化)
  10. 提案手法 S5 はparallel scanにより効率的に学習可能 10 o Parallel scan algorithm •

    ある二項演算子・を持つモノイドに対して,以下の累積和を並列に計算 https://x.com/scott_linderman/status/1587142541949878273 • モノイド: 逆元を保証しない群と考えて 良い (cf. segment-tree) 1. 演算が集合内で閉じていて, 2. 単位元が存在して, 3. 結合律が成り立つ • 以下のモノイドから,SSMにparallel scanが適用可能なことが証明できる
  11. o Parallel scan algorithm • ある二項演算子・を持つモノイドに対して,累積和を並列に計算 • 行列 𝐴 ∈

    ℝ𝑃×𝑃における行列積の計算量を𝑇⊙ とすると O(𝑇⊙ log 𝐿) で計算可 • 通常の行列であれば𝑇⊙ = O (𝑃3)だが,対角行列であれば𝑇⊙ = O (𝑃 log 𝐿) • 空間計算量は O 𝑃𝐿 であり,かなり高速 → SSMでは対角化されているので, O 𝑃𝐿 の計算量で累積和を計算 提案手法 S5 はparallel scanにより効率的に学習可能 11 recurrent \begin{align*} y_k = C \left( \overline{A} \right)^k \overline{B} u_0 + C \left( \overline{A} \right)^{k-1} \overline{B} u_1 + \dots + C \overline{A} \overline{B} u_{k-1} + \overline{B} u_k + D u_k \\ \Rightarrow y = \mathcal{K}_L(\overline{A}, \overline{B}, C) \ast u + D u \\ \mathcal{K}_L(A, B, C) = \left(C A^i B\right)_{i \in \lbrack L\rbrack} \in \mathbb{R}^L = (CB, CAB, \dots, CA^{L-1}B) \end{align*} convolution cumulative sum S4 / LSSL 提案手法 S5
  12. o Parallel scan algorithm • ある二項演算子・を持つモノイドに対して,累積和を並列に計算 • 行列 𝐴 ∈

    ℝ𝑃×𝑃における行列積の計算量を𝑇⊙ とすると O(𝑇⊙ log 𝐿) で計算可 • 通常の行列であれば𝑇⊙ = O (𝑃3)だが,対角行列であれば𝑇⊙ = O (𝑃 log 𝐿) • 空間計算量は O 𝑃𝐿 であり,かなり高速 → SSMでは対角化されているので, O 𝑃𝐿 の計算量で累積和を計算 提案手法 S5 はparallel scanにより効率的に学習可能 12 recurrent cumulative sum 提案手法 S5 ❑ S5は以下の計算量で高速に学習可 O 𝑃𝐻𝐿 + 𝐻𝐿 = O 𝐻2𝐿 + 𝐻𝐿 𝑃 = 𝑂 𝐻
  13. 定量的結果: LRAにおいて良好な結果 / 特にPath-XでSOTA 13 o Long Range Arena [Tay+,

    ICLR21]: 長距離依存を扱う様々なタスクで構成されたベンチマーク • ListOps: 逆ポーランド記法により,入れ子になった数式を解く (len = 2,048) • Text: byte-levelのテキスト分類 (len = 4,096) • Retrieval: byte-levelの文書分類 (len = 4,000) • Image: pixel-levelの画像分類 (len = 1,024) • Pathfinder, Path-X: 白点同士が繋がっているかをpixel-levelで分類 (len = 1,024 / 16,384) • 既存手法とcomparableな結果 • 特にPath-XではSOTA
  14. 所感 15 • Strengths • 適切な理論に裏付けられ良好な結果を得ている点 • MIMO化しparralel scan を使うというアイデア

    • Weaknesses • 状態サイズが小さいので,S4に劣るタスクもあり • Comments • 結局Mambaでは,SISOのまま高速にparallel scanを行っていて,Tri Daoって偉大だ なという気持ち.アルゴリズムは全てを解決する! > Our proposed S6 shares the scan, but differs by (i) keeping the SISO dimensions, which provides a larger effective recurrent state, (ii) using a hardware-aware algorithm to overcome the computation issue, (iii) adding the selection mechanism. 引用: Mamba [Gu+, 24]
  15. まとめ 16 ✓ 背景 • SSMは長系列のモデリングに有望 • S4はSISO (single-input, single-output)でありrecurrent-modeでの学習は非効率

    ✓ 提案手法 • 状態サイズを小さくすることで S4 をMIMO (multi-input, multi-output) へ • parallel scanによってrecurrent-modeでも効率的に計算 ✓ 結果 • Long Range Arenaにおいて良好な結果(16kの系列長を持つPath-XにてSOTA)