Slide 1

Slide 1 text

νϡʔτϦΞϧɿ.BNCB 7JTJPO.BNCB 7JN ౻٢߂࿱ Ԭຊ௚थʢத෦େֶɾ.13(ʣ ࣲాܓٱʢגࣜձࣾσϯιʔʣ IUUQNQSHKQ

Slide 2

Slide 2 text

w 5SBOTGPSNFSɿେن໛ݴޠϞσϧʢ--.ʣͷج൫Ϟσϧ ஫ҙػߏͷܭࢉྔ͕ೖྗαΠζʹରͯ͠ೋ࣍తʹ૿Ճˠ௕͍γʔέϯεͷॲཧʹ͓͍ͯܭࢉίετ͕ߴ͍ w ঢ়ଶۭؒϞσϧʢ44.Tʣ ݹయతͳঢ়ଶۭؒϞσϧʹண૝Λಘͨʮߏ଄Խঢ়ଶۭؒγʔέϯεϞσϧʯ͕஫໨ w .BNCB 5SBOTGPSNFSͱಉ౳ͷϞσϦϯάೳྗΛ࣋ͪͳ͕Βɺγʔέϯε௕ʹରͯ͠ઢܗతͳεέʔϥϏϦςΟΛ࣮ݱ ෆཁͳ৘ใΛഉআ͠ඞཁͳσʔλΛอ࣋͢Δબ୒ϝΧχζϜͱɺϋʔυ΢ΣΞʹ࠷దԽ͞ΕͨΞϧΰϦζϜΛ ಋೖ͢Δ͜ͱͰܭࢉޮ཰Λେ෯ʹ޲্ w .BNCBͷԠ༻ ίϯϐϡʔλϏδϣϯɺࣗવݴޠॲཧɺϔϧεέΞͳͲ༷ʑͳ෼໺Ͱ׆ൃͳݚڀ 7JNͱ͍͏Ϟσϧ͕%FJ5ΑΓ΋ഒߴ଎Ͱߴղ૾౓ը૾ͷಛ௃நग़Λߦ͍ɺ(16ϝϞϦΛˋઅ໿ .BNCBͷഎܠ

Slide 3

Slide 3 text

w 5SBOTGPSNFS<7BTXBOJ /FVS*14> ςΩετೖྗΛτʔΫϯͰߏ੒͞Εͨγʔέϯεͱͯ͠ଊ͑Δ ͲͷΑ͏ͳೖྗΛड͚औͬͯ΋ɺγʔέϯε಺ͷ೚ҙͷલͷτʔΫϯΛࢀরͯ͠දݱՄೳ w ܽ఺ ࣍τʔΫϯΛੜ੒͢Δʹ͸γʔέϯεશମͷΞςϯγϣϯΛ࠶ܭࢉ͢Δඞཁ͋Γ ௕͞ ͷγʔέϯεʹରͯ͠τʔΫϯΛੜ੒ˠ ͷܭࢉ͕ඞཁ L L2 5SBOTGPSNFS 4FMGBUUFOUJPO 4FMGBUUFOUJPO 4FMGBUUFOUJPO 4FMGBUUFOUJPO 4FMGBUUFOUJPO ӳࠃ ͷ ट౎ ͸ ΠΪϦε ௕͞L ௕͞ º࣍ݩ਺ L D

Slide 4

Slide 4 text

w &MNBO/FUXPSL<&MNBO $4> ࣌ࠁ ͷೖྗͱ ͷӅΕঢ়ଶΛड͚औΓɺ࣍ͷӅΕঢ়ଶΛੜ੒ͯ͠ग़ྗΛ༧ଌ લͷӅΕঢ়ଶͱݱࡏͷೖྗͷΈΛߟྀ͢Ε͹Α͍ͨΊɺ5SBOTGPSNFSͷΑ͏ʹશͯͷҎલͷӅΕঢ়ଶ Λ࠶ܭࢉ͢Δඞཁ͕ͳ͍ ˠγʔέϯεͷ௕͞ʹରͯ͠ઢܗʹεέʔϧ͢ΔͨΊߴ଎ͳਪ࿦Մೳ t t − 1 ϦΧϨϯτχϡʔϥϧωοτϫʔΫ 3// ˠલͷঢ়ଶͷΈΛߟྀ͢ΔͨΊɺ࣌ؒͱͱ΋ʹ৘ใΛ๨ΕΔ܏޲ ӳࠃ P(y1 |ӳࠃ) ͷ ट౎ P(y2 |ӳࠃͷ) ͸ P(y3 |ӳࠃͷट౎) P(y4 |ӳࠃͷट౎͸) ΠΪϦε P(y5 |ӳࠃͷट౎͸ΠΪϦε) 3// ҎલͷӅΕঢ়ଶͷू໿ ͷͨΊѹॖ͞Ε͍ͯΔ ࣍ݩ਺D

Slide 5

Slide 5 text

w ࣌ؒʹ൐͏γεςϜͷಈతͳڍಈΛදݱ͢ΔͨΊͷ਺ֶతϑϨʔϜϫʔΫ<,BMNBO > ݱࡏͷ࣌ࠁ ʹ͓͚Δೖྗ ͱग़ྗ ͷؔ܎Λঢ়ଶ Λհͯ͠ϞσϧԽ ੍ޚ޻ֶ෼໺Ͱݹ͔͘Βར༻ Ұൠతʹɺଟ͘ͷ44.TͰ͸ग़ྗํఔࣜͷୈ߲Λলུʢ ʣˠਂ૚ֶशϞσϧʹ͓͚ΔεΩοϓ઀ଓͱͯ͠ղऍ t x(t) ∈ ℝ y(t) ∈ ℝ h(t) ∈ ℝN Dx(t) = 0 ঢ়ଶۭؒϞσϧɿ4UBUF4QBDF.PEFMT 44.T 0VUQVU TFRVFODF *OQVU TFRVFODF 4UBUF4QBDF.PEFMT 44.T y(t) x(t) h′  (t) = Ah(t) + Bx(t) y(t) = Ch(t) + Dx(t) ঢ়ଶભҠߦྻ ɿঢ়ଶͷ࣌ؒมԽ A ∈ ℝN×N ೖྗߦྻ ɿೖྗ͕ঢ়ଶมԽʹ༩͑ΔӨڹΛ੍ޚ B ∈ ℝN×1 ग़ྗߦྻ ɿݱࡏͷঢ়ଶʹج͍ͮͯੜ੒ C ∈ ℝ1×N ೖྗ͕ग़ྗʹ௚઀༩͑ΔӨڹΛܾఆ͢Δ܎਺D ∈ ℝ ঢ়ଶํఔࣜɿ ग़ྗํఔࣜɿ ঢ়ଶ ͷ࣌ؒඍ෼ h(t) A h B C x(t) y(t) D

Slide 6

Slide 6 text

w ܥྻσʔλʢFH ୯ޠྻʣΛѻ͏ʹ͸࿈ଓදݱΛ཭ࢄԽ͢Δඞཁ͋Γ ࿈ଓ࣌ؒΛ౳͍͠ੵ෼ྖҬΛ࣋ͭ ݸͷ཭ࢄతͳ۠ؒʹ෼ׂʢ;FSP0SEFS)PME;0)ʣ ؔ਺஋͕۠ؒ ͷؒͰҰఆͰ͋ΔͱԾఆ ;0)ʹΑΔ཭ࢄԽޙͷ44.Tɿ ཭ࢄԽ͞Εͨ44.T͸࠶ؼతͳදݱͱͳΓɺ3//ʹྨࣅͨ͠ߏ଄Λ࣋ͭ ˠશͯͷೖྗʹରͯ͠஫ҙػߏΛܭࢉ͢Δ5SBOTGPSNFSϕʔεͷϞσϧΑΓߴޮ཰ͳਪ࿦͕Մೳ K Δ = [tk−1 , tk ] 44.Tͷ཭ࢄԽ hk = ¯ Ahk−1 + ¯ Bxk yk = ¯ Chk ¯ A = exp(ΔA) ¯ B = (ΔA)−1(exp(ΔA) − I) ⋅ ΔB xt xt−1 xt+1 ht ht−1 ht+1 yt yt−1 yt+1 ¯ A ¯ A ¯ A ¯ A ¯ B ¯ B ¯ B C C C ʢ ཭ࢄతͳ࣌ؒεςοϓʣ k ঢ়ଶํఔࣜɿ ग़ྗํఔࣜɿ

Slide 7

Slide 7 text

w ཭ࢄ44.T͸ઢܗγεςϜͰ͋Γ݁߹ੑ࣭Λ࣋ͭͨΊɺ৞ΈࠐΈܭࢉͱγʔϜϨεʹ౷߹ ৞ΈࠐΈΧʔωϧ ͱ͢Δͱˠ࠶ؼతͳܭࢉΛ࣍ݩ৞ΈࠐΈͱͯ͠ܭࢉՄೳ ৞ΈࠐΈܭࢉ͸ֶशϓϩηεͷฒྻܭࢉ (16 ͕Մೳʢඇઢܗ׆ੑԽؔ਺Λར༻͢Δ3//Ͱ͸࣮ݱͰ͖ͳ͍ʣ ೖྗ ͕ ࣍ݩͷ৔߹ɺ44.ͷܭࢉ͸֤࣍ݩ͝ͱʹݸผʹߦΘΕɺ ࣍ݩͷग़ྗ Λੜ੒ ೖྗߦྻ ग़ྗߦྻ ίϚϯυߦྻ ͱͳΓɺঢ়ଶભҠߦྻ͸มߋ͞Εͣ ¯ K = ( ¯ C ¯ B, ¯ C ¯ A ¯ B, …, ¯ C ¯ Ak ¯ B, …) x(k) D D y(t) B ∈ ℝN×D C ∈ ℝD×N D ∈ ℝD×D A ∈ ℝN×N 44.Tͷ৞ΈࠐΈܭࢉͱͷ౷߹ ೖྗγʔέϯεɿx = [x0 , x1 , …] ग़ྗγʔέϯεɿy = [y0 , y1 , …] ∈ ℝL y2 = ¯ Ch2 = ¯ C ¯ A ¯ A ¯ Bx0 + ¯ C ¯ A ¯ Bx1 + ¯ C ¯ Bx2 y0 = ¯ Ch0 = ¯ C ¯ Bx0 y1 = ¯ Ch1 = ¯ C ¯ A ¯ Bx0 + ¯ C ¯ Bx1 yk = ¯ Chk = ¯ C ¯ Ak ¯ Bx0 + ¯ C ¯ Ak−1 ¯ Bx1 + ¯ C ¯ Ak−2 ¯ Bx2 + … + ¯ C ¯ Bxk … (h−1 = 0) hk = ¯ Ahk−1 + ¯ Bxk yk = ¯ Chk y = x * ¯ K

Slide 8

Slide 8 text

࣍ݩ৞ΈࠐΈʹΑΔ44.Tͷܭࢉॲཧ ࣍ݩ৞ΈࠐΈͷॏΈύϥϝʔλ C ¯ A0 ¯ B C ¯ A1 ¯ B C ¯ A2 ¯ B x0 x1 x2 ʜ ʜ ʜ ೖྗγʔέϯεɿx = [x0 , x1 , …] ग़ྗγʔέϯεɿy = [y0 , y1 , …] ∈ ℝL y2 = ¯ Ch2 = ¯ C ¯ A ¯ A ¯ Bx0 + ¯ C ¯ A ¯ Bx1 + ¯ C ¯ Bx2 y0 = ¯ Ch0 = ¯ C ¯ Bx0 y1 = ¯ Ch1 = ¯ C ¯ A ¯ Bx0 + ¯ C ¯ Bx1 yk = ¯ Chk = ¯ C ¯ Ak ¯ Bx0 + ¯ C ¯ Ak−1 ¯ Bx1 + ¯ C ¯ Ak−2 ¯ Bx2 + … + ¯ C ¯ Bxk … (h−1 = 0) hk = ¯ Ahk−1 + ¯ Bxk yk = ¯ Chk y = x * ¯ K 0 0 y0

Slide 9

Slide 9 text

࣍ݩ৞ΈࠐΈʹΑΔ44.Tͷܭࢉॲཧ 0 x1 x2 ʜ ʜ ೖྗγʔέϯεɿx = [x0 , x1 , …] ग़ྗγʔέϯεɿy = [y0 , y1 , …] ∈ ℝL y2 = ¯ Ch2 = ¯ C ¯ A ¯ A ¯ Bx0 + ¯ C ¯ A ¯ Bx1 + ¯ C ¯ Bx2 y0 = ¯ Ch0 = ¯ C ¯ Bx0 y1 = ¯ Ch1 = ¯ C ¯ A ¯ Bx0 + ¯ C ¯ Bx1 yk = ¯ Chk = ¯ C ¯ Ak ¯ Bx0 + ¯ C ¯ Ak−1 ¯ Bx1 + ¯ C ¯ Ak−2 ¯ Bx2 + … + ¯ C ¯ Bxk … (h−1 = 0) hk = ¯ Ahk−1 + ¯ Bxk yk = ¯ Chk y = x * ¯ K ࣍ݩ৞ΈࠐΈͷॏΈύϥϝʔλ C ¯ A0 ¯ B C ¯ A1 ¯ B C ¯ A2 ¯ B ʜ y1 0 x0

Slide 10

Slide 10 text

࣍ݩ৞ΈࠐΈʹΑΔ44.Tͷܭࢉॲཧ 0 0 x2 ʜ ʜ ೖྗγʔέϯεɿx = [x0 , x1 , …] ग़ྗγʔέϯεɿy = [y0 , y1 , …] ∈ ℝL y2 = ¯ Ch2 = ¯ C ¯ A ¯ A ¯ Bx0 + ¯ C ¯ A ¯ Bx1 + ¯ C ¯ Bx2 y0 = ¯ Ch0 = ¯ C ¯ Bx0 y1 = ¯ Ch1 = ¯ C ¯ A ¯ Bx0 + ¯ C ¯ Bx1 yk = ¯ Chk = ¯ C ¯ Ak ¯ Bx0 + ¯ C ¯ Ak−1 ¯ Bx1 + ¯ C ¯ Ak−2 ¯ Bx2 + … + ¯ C ¯ Bxk … (h−1 = 0) hk = ¯ Ahk−1 + ¯ Bxk yk = ¯ Chk y = x * ¯ K ࣍ݩ৞ΈࠐΈͷॏΈύϥϝʔλ C ¯ A0 ¯ B C ¯ A1 ¯ B C ¯ A2 ¯ B ʜ y2 x0 x1

Slide 11

Slide 11 text

w ௕͍ܥྻσʔλΛޮ཰తʹϞσϦϯά͢ΔͨΊʹઃܭ͞Εͨߏ଄Խঢ়ଶۭؒϞσϧʢ4ʣ ߦྻ Λ)J110ʹΑΔϝϞϦॳظԽ ˠաڈͷೖྗΛྑ͘هԱͰ͖Δ)JHIPSEFS1PMZOPNJBM1SPKFDUJPO0QFSBUPSʢ)J110-FHTʣΛར༻ ޮ཰తͳܭࢉख๏ ˠ࣍ݩ৞ΈࠐΈͷΧʔωϧ Λߴ଎ʹܭࢉ͢ΔͨΊʹɺप೾਺ྖҬͰͷߴ଎ͳ৞ΈࠐΈܭࢉ ¯ A ¯ K 4USVDUVSFE4QBDF4UBUF.PEFMT 4 <(V *$-3> yk = ¯ Chk hk = ¯ Ahk−1 + ¯ Bxk N = 4 ¯ A = −1.00 0 0 0 −1.73 −3.00 0 0 −2.24 −3.87 −5.00 0 −2.65 −4.58 −5.92 −7 An,k = − (2n + 1), JG n = k (2n + 1)(2k + 1), JG n > k 0, JG n < k ॳظ஋ Algorithm 1 S4 Convolution Kernel (Sketch) Input: S4 parameters ⇤, P , Q, B, C 2 N and step size Output: SSM convolution kernel K = KL (A, B, C) for A = ⇤ P Q⇤ (equation (5)) 1: e C ⇣ I AL ⌘⇤ C . Truncate SSM generating function (SSMGF) to length L 2:  k00 (!) k01 (!) k10 (!) k11 (!) h e C Q i⇤ ⇣ 2 1 ! 1+! ⇤ ⌘ 1 [B P ] . Black-box Cauchy kernel 3: ˆ K(!) 2 1+! ⇥ k00 (!) k01 (!)(1 + k11 (!)) 1k10 (!) ⇤ . Woodbury Identity 4: ˆ K = { ˆ K(!) : ! = exp(2⇡i k L )} . Evaluate SSMGF at all roots of unity ! 2 ⌦L 5: K iFFT( ˆ K) . Inverse Fourier Transform

Slide 12

Slide 12 text

w ཭ࢄ44.T͸ઢܗಛੑʹΑΓɺ࠶ؼతܭࢉͱ࣍ݩ৞ΈࠐΈܭࢉͷ྆ํΛॊೈʹαϙʔτ 44.T 3// 5SBOTGPSNFSͷؔ܎ A Survey of Mamba 7 Linearity Multi-Head Attention Add & Norm Feed Forward Add & Norm Linear Softmax Masked Multi-Head Attention Add & Norm Multi-Head Attention Add & Norm Feed Forward Add & Norm Input Embedding Input Embedding Inputs Outputs Output Probabilities Positional Encoding Positional Encoding (b) Transformer (Attention Mechanism) Unfold (a) Recurrent Neural Network Unfold (c) State Space Model (Time-Invariance) Unfold Non-Linearity Recurrent form Convolutional form ... ... ... ... ... = = = Parallel Computation ... ... tanh ... ... ... Softmax Multi-head Attention Scores Softmax Softmax Scaling Scaling Scaling ... ... Parallel Computation Fig. 2. An illustration of representative model architectures, namely Recurrent Neural Network (RNN), Transformer, and State Space Model (SSM). (a) RNNs function within a nonlinear recurrent framework, facilitating rapid outputs during auto-regressive inference. (b) Transformers execute matrix multiplications concurrently across numerous query-key pairs, facilitating parallel training. (c) SSMs exhibit versatility by accommodating both recurrent and convolutional computations due to their linear nature. This fusion harnesses <2V BS9JW`>ΑΓҾ༻ 44.Tɿ 👍ઢܗಛੑʹΑΓɺ࠶ؼతܭࢉͱ৞ΈࠐΈܭࢉͷ྆ํʹରԠͰ͖Δ൚༻ੑˠ࠶ؼతͳਪ࿦ͱฒྻֶशΛՄೳ 👎ैདྷͷ44.͸࣌ؒෆมˠ ͕Ϟσϧͷೖྗ ʹؔ࿈͠ͳ͍ˠίϯςΩετΛߟྀͨ͠ϞσϦϯά͕ྼΓɺಛఆͷλεΫͰੑೳ͕௿Լ A, B, C, Δ x

Slide 13

Slide 13 text

44.T 3// 5SBOTGPSNFSͷؔ܎ <2V BS9JW`>ΑΓҾ༻ that the most conventional SSMs are time-invariant, meaning that their A, B, C, and are unrelated to the model input G. This would limit context-aware modeling, which leads to inferior performance of SSMs in certain tasks such as selective copying [55]. Table 1. Pros and cons of three primary architectures-RNNs, Transformers, and SSMs-in auto-regressive sequential modeling tasks. Comparison RNNs Transformers SSMs Training Speed Slow (Recurrent) Fast (Parallel) Fast (Convolutional) Inference Speed Fast (Recurrent) Slow (Quadratic-Time) Fast (Recurrent) Complexity $(!⇡2) $(!2⇡) $(!⇡2) Modeling Capabilities (Hidden State) (Attention) (Time-Invariance) Manuscript submitted to ACM x2 x1 x3 y2 y1 y3 "UUFOUJPO աڈͷೖྗશͯͱͷܭࢉ x2 x1 x3 h2 h1 h3 y2 y1 y3 ¯ A ¯ A ¯ B ¯ B ¯ B C C C x0 h0 y0 ¯ A ¯ B C 44.T ݱࡏͷೖྗͱҰͭલͷঢ়ଶ͔Βܭࢉ ࣗݾճؼλεΫʹ͓͚Δ֤Ϟσϧͷಛ௃ x0 y0

Slide 14

Slide 14 text

w ैདྷͷঢ়ଶۭؒϞσϧʢ44.ʣͷ໰୊఺ ςΩετ΍৘ใີ౓ͷߴ͍σʔλͷϞσϦϯάʹ͓͍ͯޮՌ͕ݶఆత w .BNCB ߏ଄Խঢ়ଶۭؒϞσϧ 4 Λج൫ͱͨͭ͠ͷख๏Λಋೖ )JHIPSEFS1PMZOPNJBM1SPKFDUJPO0QFSBUPSʢ)J110ʣʹΑΔϝϞϦॳظԽ ˠҰ؏ͨ͠ӅΕঢ়ଶߦྻΛߏங͠ɺ௕ظతͳهԱΛଅਐ બ୒ϝΧχζϜ ˠίϯςΩετʹԠͨ͡දݱΛ֫ಘ ϋʔυ΢ΣΞʹ࠷దԽ͞ΕͨܭࢉΞϧΰϦζϜ ˠϋʔυ΢ΣΞ࠷దԽ͞Εͨܭࢉख๏ʢฒྻ࿈૝εΩϟϯͱϝϞϦ࠶ܭࢉʣʹΑΓֶशޮ཰Λ޲্ .BNCB<(VBOE%BP BS9JW>

Slide 15

Slide 15 text

w /-1෼໺ʹ͓͚Δঢ়ଶۭؒϞσϧʢ44.Tʣʹجͮ͘৽͍͠ϞσϧΞʔΩςΫνϟ w ಛ௃ બ୒ϝΧχζϜʹΑΓɺϞσϧ͕ೖྗʹԠͯ͡ύϥϝʔλΛಈతʹௐ੔ 4FMFDUJWF44. ϋʔυ΢ΣΞʹదͨ͠ΞϧΰϦζϜʹΑΓɺޮ཰తͳֶशͱਪ࿦Λ࣮ݱ γϯϓϧͳωοτϫʔΫΞʔΩςΫνϟ .BNCB#MPDL .BNCB<(VBOE%BP BS9JW> H3 Gated MLP Mamba Linear projection Sequence transformation Nonlinearity (activation or multiplication) X X X ! X Conv SSM X ! ! Conv SSM ⨂ Project Discretize !! ℎ!"# ℎ! "! # $! %! Selection Mechanism GPU SRAM GPU HBM ∆! Selective State Space Model with Hardware-aware State Expansion 4FMFDUJWF44.T .BNCB#MPDL ϋʔυ΢ΣΞʹదͨ͠ΞϧΰϦζϜ

Slide 16

Slide 16 text

w 44.Tʹબ୒ϝΧχζϜΛಋೖ ೖྗ Λ༻͍ͨઢܗ૚ͷग़ྗΛ44.ͷύϥϝʔλ ʹ࢖༻ xt Bt Ct Δt .BNCBɿ4FMFDUJWF44.T 4 ˠೖྗ ʹԠͯ͡อ࣋͞ΕΔ৘ใΛ੍ޚ͢Δ xt Project Discretize !! ℎ!"# ℎ! "! # $! %! Selection Mechanism GPU SRAM GPU HBM ∆! Selective State Space Model with Hardware-aware State Expansion ࣍ݩ D જࡏঢ়ଶ(N = 4) ࣍ݩ D (B ∈ ℝN×D) (C ∈ ℝD×N) (A ∈ ℝN×N) 1 − gt gt ೖྗ ग़ྗ જࡏঢ়ଶ ht = (1 − gt )ht−1 + gt xt ͱԾఆͨ͠৔߹: N = 1, A = − 1, B = 1, gt = σ(Linear(xt ))

Slide 17

Slide 17 text

w બ୒ϝΧχζϜͷ࢓૊Έ ೖྗʹԠͯͭ͡લͷঢ়ଶͱೖྗͷͲͪΒΛॏࢹ͢Δ͔Λબ୒ .BNCBɿ4FMFDUJWF44.T 4 ht = Aht−1 + Bxt ht = (1 − gt )ht−1 + gt xt gt = σ(Linear(xt )) • Recurrent Memory Transformer (Bulatov, Kuratov, and Burtsev 2023), a lightweight wrapper around a Transformer backbone. It showed ability to generalize up to 1M sequences but only on synthetic memorization tasks; their main result is similar to our Induction Heads extrapolation experiment (Table 2). • LongNet (Ding et al. 2023), which claimed to scale to 1B length but only evaluated on length < 100 for actual tasks. • Hyena and HyenaDNA (Nguyen, Poli, et al. 2023; Poli et al. 2023), which claimed to leverage up to 1M context. How- ever, their experiments trained on proportionally more data at longer contexts, making it hard to conclude if quality improvements at 1M context are due to context length or due to more data and computation. • Sparse Transformer (Child et al. 2019) showed a proof-of-concept of using a strided sparse attention Transformer to model audio waveforms of length 220 = 1048576, although did not discuss performance tradeos when controlling for computation and model size. In contrast, we believe this work presents one of the rst approaches to meaningfully demonstrate increasing performance with longer context. C Mechanics of Selective SSMs Proof of Theorem 1. Consider a selective SSM (Algorithm 2) with # = 1, G = 1, H = 1,B = Linear(G),g = soplus. The corresponding continuous-time SSM (1) is ⌘(C) = ⌘(C) + G(C) which is also called a leaky integrator. The discretization step size is C = g (Parameter + B (GC )) = soplus(Parameter + Linear(GC )) = soplus(Linear(GC )) where we observe that the parameter can be viewed as a learnable bias and folded into the linear projection. Now applying the zero-order hold (ZOH) discretization formulas: GC = exp( G) = 1 1 + exp(Linear(GC )) = f( Linear(GC )) = 1 f(Linear(GC )) HC = ( G) 1(exp( G) O) · H = (exp( G) O) = 1 G = f(Linear(GC )). 27 Thus the nal discrete recurrence (2a) is 6C = f(Linear(GC )) ⌘C = (1 6C )⌘C 1 + 6CGC as desired. ⇤ D Hardware-aware Algorithm For Selective SSMs ࣜมܗ ͷ৔߹ɿ gt → 0 ht = ht−1 ೖྗΛແࢹ l͋ʙz z͑ʙzͱ͍ͬͨؒ౤ࢺΛޮ཰తʹഉআ ͷ৔߹ɿ gt → 1 ht = xt ঢ়ଶΛϦηοτ lͱ͜ΖͰz౳ͷ࿩ͷల։ʹ߹ΘͤͯϦηοτ͢Δ ͜ͱͰɺίϯςΩετΛਂ͘ཧղ

Slide 18

Slide 18 text

w ฒྻ࿈૝εΩϟϯ εΩϟϯΞϧΰϦζϜΛಋೖͯ͠Ұ෦ͷܭࢉΛฒྻԽ ˠฒྻԽʹΑΓܭࢉ࣌ؒΛ࡟ݮ w Χʔωϧ༥߹ αΠζͷେ͖͍ Λ)#.͔Β43".ʹϩʔυ͢Δ୅ΘΓʹ Λϩʔυ Λ43".Ͱܭࢉ͠ɺͦͷ··࠷ऴग़ྗ ͷܭࢉʹར༻ ˠϝϞϦ*0ͷྔ͕ݮΔ͜ͱͰܭࢉ࣌ؒΛ࡟ݮ w ϝϞϦͷ࠶ܭࢉ 'PSXBSE࣌ ɿޯ഑ܭࢉʹඞཁ͕ͩ࠶ܭࢉ͕ߴ଎ͳதؒग़ྗΛ)#.ʹอଘ͠ͳ͍ #BDLXBSE࣌ ɿதؒग़ྗΛ43".Ͱ࠶ܭࢉ͠ޯ഑ܭࢉʹར༻ʢதؒग़ྗͷϩʔυ࣌ؒ࠶ܭࢉͷ࣌ؒʣ ˠதؒग़ྗΛอଘ͠ͳ͍͜ͱͰϝϞϦ࢖༻ྔΛ࡟ݮɺϩʔυΑΓߴ଎ͳ࠶ܭࢉʹΑΓܭࢉ࣌ؒΛ࡟ݮ ¯ A, ¯ B Δ, A, B, C ¯ A, ¯ B y .BNCBɿϋʔυ΢ΣΞʹదͨ͠ΞϧΰϦζϜ Project Discretize !! ℎ!"# ℎ! "! # $! %! Selection Mechanism GPU SRAM GPU HBM ∆! Selective State Space Model with Hardware-aware State Expansion Project Discretize !! ℎ!"# ℎ! "! $! %! Selection Mechanism GPU SRAM GPU HBM ∆! ϝϞϦྔ͕গͳ͍͕ߴ଎ ϝϞϦྔ͕ଟ͍͕௿଎

Slide 19

Slide 19 text

w 44.ΞʔΩςΫνϟͷجૅͰ͋Δ)ͱ(BUFE.-1ͷ૊Έ߹Θͤ .BNCBɿϒϩοΫ )ˠ.BNCBɿϝΠϯϒϥϯνͷ࠷ॳͷ৐ࢉήʔτΛ׆ੑԽؔ਺ʹஔ͖׵͑ (BUFE.-1ˠ.BNCBɿ44.ͷ௥Ճͱ4J-64XJTI׆ੑԽؔ਺Λ࢖༻ H3 Gated MLP Mamba Linear projection Sequence transformation Nonlinearity (activation or multiplication) X X X ! X Conv SSM X ! ! Conv SSM ⨂ Figure 3: (Architecture.) Our simplied block design combines the H3 block, which is the basis of most SSM architectures, with the ubiquitous MLP block of modern neural networks. Instead of interleaving these two blocks, we simply repeat the Mamba block homogenously. Compared to the H3 block, Mamba replaces the rst multiplicative gate with an activation function. Compared to the MLP block, Mamba adds an SSM to the main branch. For f we use the SiLU / Swish activation (Hendrycks and Gimpel 2016; Ramachandran, Zoph, and Quoc V Le 2017). -JOFS1SPKFDUJPO -JOFS1SPKFDUJPO -JOFS1SPKFDUJPO -JOFS1SPKFDUJPO -JOFS1SPKFDUJPO -JOFS1SPKFDUJPO -JOFS1SPKFDUJPO -JOFS1SPKFDUJPO -JOFS1SPKFDUJPO -JOFS1SPKFDUJPO ήʔτػߏ ɿ4J-6 4JHNPJE-JOFBS6OJU Λ࢖༻ σ() ೖྗ৘ใͷऔࣺબ୒ ΤϯδχΞϦϯάख๏Ͱߴ଎Խ

Slide 20

Slide 20 text

w ૚਺ͱτʔΫϯͷग़ྗ࣍ݩ਺ʹԠ༷ͯ͡ʑͳύϥϝʔλ਺ͷ.BNCBϞσϧΛ༻ҙ .BNCBɿϞσϧαΠζ Ϟσϧ໊ ύϥϝʔλ਺ ૚਺ τʔΫϯͷ࣍ݩ਺ .BNCB. . .BNCB. . .BNCB. . .BNCB# # .BNCB# #

Slide 21

Slide 21 text

.BNCBͷޮՌ ίϐʔλεΫ λεΫɿ4FMFDUJWF$PQZJOH هԱೳྗΛςετ͢ΔͨΊʹઃܭ͞Βͨ߹੒λεΫ ΞϧϑΝϕοτɿ\B C D ʜ [ ۭന τϦΨʔ^ ೖྗγʔέϯεɿ ظ଴͞ΕΔग़ྗɿ<ۭന ۭന ۭന ۭന ۭന ۭന ۭന I F M M P> M A. L A. S4 No gate S4 18.3 - No gate S6 97.0 H3 H3 S4 57.0 Hyena H3 Hyena 30.1 - H3 S6 99.7 - Mamba S4 56.4 - Mamba Hyena 28.4 Mamba Mamba S6 99.8 Table 1: (Selective Copying.) Accuracy for combinations of architectures and inner sequence layers. Table 2: (Induction Heads.) Models are trained on sequence length 28 = 256, and tested on increasing sequence lengths of 26 = 64 up to 220 = 1048576. Full numbers in Table 11. ˠ.BNCB 4FMFDUJWF44. ʹΑΓಈతͳਪ࿦ が Մೳʹ *ODPOUFYUMFBSOJOHͷೳྗΛධՁ͢ΔͨΊͷ߹੒λεΫ ೖྗγʔέϯεɿ"#"#" ύλʔϯͷൃݟɿγʔέϯε͸l"#z͕܁Γฦ͞Ε͍ͯΔ Ϟσϧͷظ଴͞ΕΔग़ྗɿ࣍ʹདྷΔͷ͸l#z ϞσϧͷλεΫɿ͜ͷύλʔϯΛೝࣝ͠ɺl#zΛग़ྗ M A. L A. S4 No gate S4 18.3 - No gate S6 97.0 H3 H3 S4 57.0 Hyena H3 Hyena 30.1 - H3 S6 99.7 - Mamba S4 56.4 - Mamba Hyena 28.4 Mamba Mamba S6 99.8 Table 1: (Selective Copying.) Accuracy for combinations of architectures and inner sequence layers. Table 2: (Induction Heads.) Models are trained on sequence length 28 = 256, and tested on increasing sequence lengths of 26 = 64 up to 220 = 1048576. Full numbers in Table 11.

Slide 22

Slide 22 text

.BNCBͷޮՌ ˠͷϞ デ ϧαΠ ズで 5SBOTGPSNFSΑΓߴ͍ੑೳ ༷ʑͳԼྲྀλεΫͷθϩγϣοτධՁ against the most well-known open source models at these sizes, most importantly Pythia (Biderman et al. 2023) and RWKV (B. Peng et al. 2023) which were trained with the same tokenizer, dataset, and training length (300B tokens) as our models. (Note that Mamba and Pythia are trained with context length 2048, while RWKV was trained with context length 1024.) Table 3: (Zero-shot Evaluations.) Best results for each size in bold. We compare against open source LMs with various tokenizers, trained for up to 300B tokens. Pile refers to the validation split, comparing only against models trained on the same dataset and tokenizer (GPT-NeoX-20B). For each model size, Mamba is best-in-class on every single evaluation result, and generally matches baselines at twice the model size. M T. P LAMBADA LAMBADA HS PIQA AE AC WG A # # " " " " " " " Hybrid H3-130M GPT2 — 89.48 25.77 31.7 64.2 44.4 24.2 50.6 40.1 Pythia-160M NeoX 29.64 38.10 33.0 30.2 61.4 43.2 24.1 51.9 40.6 Mamba-130M NeoX 10.56 16.07 44.3 35.3 64.5 48.0 24.3 51.9 44.7 Hybrid H3-360M GPT2 — 12.58 48.0 41.5 68.1 51.4 24.7 54.1 48.0 Pythia-410M NeoX 9.95 10.84 51.4 40.6 66.9 52.1 24.6 53.8 48.2 Mamba-370M NeoX 8.28 8.14 55.6 46.5 69.5 55.1 28.0 55.3 50.0 Pythia-1B NeoX 7.82 7.92 56.1 47.2 70.7 57.0 27.1 53.5 51.9 Mamba-790M NeoX 7.33 6.02 62.7 55.1 72.1 61.2 29.5 56.1 57.1 GPT-Neo 1.3B GPT2 — 7.50 57.2 48.9 71.1 56.2 25.9 54.9 52.4 Hybrid H3-1.3B GPT2 — 11.25 49.6 52.6 71.3 59.2 28.1 56.9 53.0 OPT-1.3B OPT — 6.64 58.0 53.7 72.4 56.7 29.6 59.5 55.0 Pythia-1.4B NeoX 7.51 6.08 61.7 52.1 71.0 60.5 28.5 57.2 55.2 RWKV-1.5B NeoX 7.70 7.04 56.4 52.5 72.4 60.5 29.4 54.6 54.3 Mamba-1.4B NeoX 6.80 5.04 64.9 59.1 74.2 65.5 32.8 61.5 59.7 GPT-Neo 2.7B GPT2 — 5.63 62.2 55.8 72.1 61.1 30.2 57.6 56.5 Hybrid H3-2.7B GPT2 — 7.92 55.7 59.7 73.3 65.6 32.3 61.4 58.0 OPT-2.7B OPT — 5.12 63.6 60.6 74.8 60.8 31.3 61.0 58.7 Pythia-2.8B NeoX 6.73 5.04 64.7 59.3 74.0 64.1 32.9 59.7 59.1 RWKV-3B NeoX 7.00 5.24 63.9 59.6 73.7 67.8 33.1 59.6 59.6 Mamba-2.8B NeoX 6.22 4.23 69.2 66.1 75.2 69.7 36.3 63.5 63.3 GPT-J-6B GPT2 – 4.10 68.3 66.3 75.4 67.0 36.6 64.1 63.0 OPT-6.7B OPT – 4.25 67.7 67.2 76.3 65.6 34.9 65.5 62.9 Pythia-6.9B NeoX 6.51 4.45 67.1 64.0 75.2 67.3 35.5 61.3 61.7 RWKV-7.4B NeoX 6.31 4.38 67.2 65.5 76.1 67.8 37.5 61.0 62.5 4.3 DNA Modeling Motivated by the success of large language models, there has been recent exploration into using the foundation model

Slide 23

Slide 23 text

X B C X B C Parallel Mamba Block Linear projection Sequence transformation Nonlinearity (activation, normalization, multiplication) X ! ! Conv SSM X ! Conv SSM A A N Y Y Sequential Mamba Block ! .BNCB .BNCB -JOFBS -JOFBS -JOFBS -JOFBS -JOFBS -JOFBS -JOFBS w .BNCB#MPDLͷػೳΛ͞Βʹ֦ு 44.ͷܭࢉΛߦྻੵΞϧΰϦζϜͱͯ͠࠶ఆٛ͢Δ͜ͱͰ(16্Ͱͷܭࢉ଎౓Λվળ 5SBOTGPSNFSͷ஫ҙػߏͱ4FMFDUJWF44.Λ4USVDUVSFEߦྻͰ౷Ұతʹදݱ w ஫ҙػߏͱ4FMFDUJWF44.͕਺ֶతʹ౳Ձˠ5SBOTGPSNFSͷςΫχοΫΛ44.ʹಋೖՄೳ .BNCB<%BPBOE(V *$.-> ୯ҰͷઢܗࣹӨͰ Λฒྻʹࢉग़ A, X, B, C Λࢉग़͢ΔઢܗࣹӨͷޙʹ Λࢉग़͢ΔઢܗࣹӨΛద༻ X A, B, C /PSN'PSNFSΞʔΩςΫνϟ ʹج͍ͮͯਖ਼نԽΛಋೖ X B C X B C Linear projection Sequence transformation Nonlinearity (activation, normalization, multiplication) X ! ! Conv SSM X ! Conv SSM A A N Y Y !

Slide 24

Slide 24 text

.BNCBͷޮՌ .VMUJRVFSZBTTPDJBUJWFSFDBMMλεΫ ίϯςΩετ಺Ͱ৘ใΛݕࡧ͢ΔೳྗΛςετ͢ΔλεΫ .BNCB͕ࠔ೉ͳλεΫΛ.BNCB͸ղ͘͜ͱ͕Մೳ ܭࢉ଎౓ͷධՁ 'MBTI"UUFOUJPOɺ$POWPMVUJPOɺ ҟͳΔΞϧΰϦζϜʹΑΔ44.ͷܭࢉ଎౓Λൺֱ .BNCB͸.BNCBͱൺ΂ͯʙഒߴ଎ ೖྗɿ"#$'&ˠ" $ ' & # λεΫɿ֤ΫΤϦʹରԠ͢Δ৘ใΛग़ྗ ظ଴͞ΕΔग़ྗɿ \ \ ΩʔͱόϦϡʔ ΫΤϦ

Slide 25

Slide 25 text

.BNCBͷޮՌ ༷ʑͳԼྲྀλεΫͷθϩγϣοτධՁ ˠ.BNCB͸.BNCBͱಉ౳ͷੑೳ • ARC-challenge (Clark et al., 2018) • ARC-easy: an easy subset of ARC-challenge • WinoGrande (Sakaguchi et al., 2021) • OpenBookQA (Mihaylov et al., 2018) Table 3: (Zero-shot Evaluations.) Best results for each size in bold, second best unlined. We compare against open source LMs with various tokenizers, trained for up to 300B tokens. Pile refers to the validation split, comparing only against models trained on the same dataset and tokenizer (GPT-NeoX-20B). For each model size, Mamba-2 outperforms Mamba, and generally matches Pythia at twice the model size. MODEL TOKEN. PILE LAMBADA LAMBADA HELLASWAG PIQA ARC-E ARC-C WINOGRANDE OPENBOOKQA AVERAGE PPL → PPL → ACC ↑ ACC ↑ ACC ↑ ACC ↑ ACC ↑ ACC ↑ ACC ↑ ACC ↑ Hybrid H3-130M GPT2 — 89.48 25.8 31.7 64.2 44.4 24.2 50.6 27.0 38.2 Pythia-160M NeoX 29.64 38.10 33.0 30.2 61.4 43.2 24.1 51.9 29.2 39.0 Mamba-130M NeoX 10.56 16.07 44.3 35.2 64.5 48.0 24.2 51.9 28.8 42.4 Mamba-2-130M NeoX 10.48 16.86 43.9 35.3 64.9 47.4 24.2 52.1 30.6 42.6 Hybrid H3-360M GPT2 — 12.58 48.0 41.5 68.1 51.4 24.7 54.1 31.6 45.6 Pythia-410M NeoX 9.95 10.84 51.4 40.6 66.9 52.1 24.6 53.8 30.0 45.6 Mamba-370M NeoX 8.28 8.14 55.6 46.5 69.5 55.1 28.0 55.3 30.8 48.7 Mamba-2-370M NeoX 8.21 8.02 55.8 46.9 70.5 54.9 26.9 55.7 32.4 49.0 Pythia-1B NeoX 7.82 7.92 56.1 47.2 70.7 57.0 27.1 53.5 31.4 49.0 Mamba-790M NeoX 7.33 6.02 62.7 55.1 72.1 61.2 29.5 56.1 34.2 53.0 Mamba-2-780M NeoX 7.26 5.86 61.7 54.9 72.0 61.0 28.5 60.2 36.2 53.5 GPT-Neo 1.3B GPT2 — 7.50 57.2 48.9 71.1 56.2 25.9 54.9 33.6 49.7 Hybrid H3-1.3B GPT2 — 11.25 49.6 52.6 71.3 59.2 28.1 56.9 34.4 50.3 OPT-1.3B OPT — 6.64 58.0 53.7 72.4 56.7 29.6 59.5 33.2 51.9 Pythia-1.4B NeoX 7.51 6.08 61.7 52.1 71.0 60.5 28.5 57.2 30.8 51.7 RWKV4-1.5B NeoX 7.70 7.04 56.4 52.5 72.4 60.5 29.4 54.6 34.0 51.4 Mamba-1.4B NeoX 6.80 5.04 65.0 59.1 74.2 65.5 32.8 61.5 36.4 56.4 Mamba-2-1.3B NeoX 6.66 5.02 65.7 59.9 73.2 64.3 33.3 60.9 37.8 56.4 GPT-Neo 2.7B GPT2 — 5.63 62.2 55.8 72.1 61.1 30.2 57.6 33.2 53.2 Hybrid H3-2.7B GPT2 — 7.92 55.7 59.7 73.3 65.6 32.3 61.4 33.6 54.5 OPT-2.7B OPT — 5.12 63.6 60.6 74.8 60.8 31.3 61.0 35.2 55.3 Pythia-2.8B NeoX 6.73 5.04 64.7 59.3 74.0 64.1 32.9 59.7 35.2 55.7 RWKV4-3B NeoX 7.00 5.24 63.9 59.6 73.7 67.8 33.1 59.6 37.0 56.4 Mamba-2.8B NeoX 6.22 4.23 69.2 66.1 75.2 69.7 36.3 63.5 39.6 59.9 Mamba-2-2.7B NeoX 6.09 4.10 69.7 66.6 76.4 69.6 36.4 64.0 38.8 60.2 GPT-J-6B GPT2 – 4.10 68.3 66.3 75.4 67.0 36.6 64.1 38.2 59.4 OPT-6.7B OPT – 4.25 67.7 67.2 76.3 65.6 34.9 65.5 37.4 59.2 Pythia-6.9B NeoX 6.51 4.45 67.1 64.0 75.2 67.3 35.5 61.3 38.0 58.3 RWKV4-7.4B NeoX 6.31 4.38 67.2 65.5 76.1 67.8 37.5 61.0 40.2 59.3

Slide 26

Slide 26 text

1BQFSMJTUIUUQTQBQFSESPQCPYDPNEPD.BNCBQBQFSMJTU$650EY6/I-'WTBH#PJd/"2T)JR.RZKSMGRP";HK.W .BNCBͷ$7෼໺΁ͷల։ʢ೥݄̔ʣ ೥݄ ݄ ೥݄ ݄ ݄ ݄ ݄ .BNCB <(VBOE%BP BS9JW`> .BNCB <(VBOE%BP *$.-`> ݄ /-1෼໺ .BNCB/% <-J &$$7`> 7JTJPO38,7 <%VBO BS9JW`> -PDBM.BNCB <)VBOH &$$78`> & ffi DJFOU7.BNCB <1FJ BS9JW`> 1MBJO.BNCB <:BOH #.7$`> .VMUJ4DBMF7.BNCB <4IJ /FVS*14`> .BNCB3 <8BOH BS9JW`> %FNZTUJGZ.BNCB <)BO /FVS*14`> W)FBU <8BOH BS9JW`> 7JN' <;IBOH BS9JW`> 7JTJPO.BNCB <;IV *$.-`> 7.BNCB <-JV BS9JW`> .BNCB7JTJPO <)BUBNJ[BEFIBOE,BVU[ BS9JW`> (SPVQ.BNCB <4IBLFS BS9JW`> $7෼໺΁ͷల։ 6OJ fi FE*NQMJDJU"UUFOUJPO'PSNVMBUJPO <;JNFSNBO BS9JW`> .BNCB-31 <+BGBSJ /FVS*14`> ࢹ֮తઆ໌ "VUPSFHSFTTJWF1SFUSBJOJOH <3FO BS9JW`> ࣗݾڭࢣ͋Γֶश .BNCB:0-0 <8BOH BS9JW`> 'VTJPO.BNCB <%POH BS9JW`> ෺ମମݕग़ˍηάϝϯςʔγϣϯ 3F.BNCFS <:BOH &$$7`> 4JBNFTF.BNCB <8BO BS9JW`> .BNCBPS38,7 <:VBO BS9JW`> ϚϧνϞʔμϧ 7-.BNCB <2JBP BS9JW`> $-*1.BNCB <)VBOH BS9JW`> (SPPU7- <9JBP /FVS*14`> 4).BNCB <:BOH BS9JW`> .-.BNCB <)VBOHBOE)V BS9JW`> #&7.BNCB <-JV 5FYI3YJW`> 0&#FW4FH <4VO BS9JW`> '&3:0-0.BNCB <.B BS9JW`> 40"3 <7FSNB BS9JW`> ηάϝϯςʔγϣϯ ෺ମݕग़ #&71FSDFQUJPO

Slide 27

Slide 27 text

w .BNCBΛίϯϐϡʔλϏδϣϯλεΫʹద༻ͨ͠Ϟσϧ σʔλґଘܕͷάϩʔόϧͳࢹ֮తίϯςΩετͷϞσϦϯά ࢹ֮తཧղͷͨΊͷҐஔ৘ใͷ૊ΈࠐΈ 1PTJUJPOFNCFEEJOHT w ஫ҙػߏΛ࢖༻ͤͣʹ7JTJPO5SBOTGPSNFSʢ7J5ʣͱಉ౳ͷೳྗ %FJ5<5PVWSPO *$.-`>ΑΓ΋ഒߴ଎Ͱɺ(16ϝϞϦΛઅ໿ 7JTJPO.BNCB 7JN <;IV *$.-> Embedded Patches Norm ! Forward Conv1d Backward Conv1d Forward SSM Backward SSM Vision Mamba Encoder Input Image Vision Mamba Encoder Flatten & Linear Projection Projection Layer Patch Tokens Position Embed. Class Token 0 1 * Vision Mamba (Vim) Activation # MLP & Prediction 0 1 2 3 4 5 * 6 7 8 9 Embedded Patches Norm ! Forward Conv1d Backward Conv1d Forward SSM Backward SSM L× Vision Mamba Encoder Input Image Vision Mamba Encoder Flatten & Linear Projection Projection Layer Patch Tokens Position Embed. Class Token 0 1 * Vision Mamba (Vim) Activation # MLP & Prediction 0 1 2 3 4 5 * 6 7 8 9 -JOFBS -JOFBS -JOFBS 7JN#MPDL

Slide 28

Slide 28 text

w ҎԼͷύʔπͰߏ੒ -JOFBS1SPKFDUJPO 7JTJPO.BNCB&ODPEFS 7JN#MPDLɿ'PSXBSE$POWE 'PSXBSE44. #BDLXBSE$POWE #BDLXBSE44. .-11SFEJDUJPOIFBE 7JNɿΞʔΩςΫνϟ Embedded Patches Norm ! Forward Conv1d Backward Conv1d Forward SSM Backward SSM Vision Mamba Encoder Input Image Vision Mamba Encoder Flatten & Linear Projection Projection Layer Patch Tokens Position Embed. Class Token 0 1 * Vision Mamba (Vim) Activation # MLP & Prediction 0 1 2 3 4 5 * 6 7 8 9 Embedded Patches Norm ! Forward Conv1d Backward Conv1d Forward SSM Backward SSM L× Vision Mamba Encoder Input Image Vision Mamba Encoder Flatten & Linear Projection Projection Layer Patch Tokens Position Embed. Class Token 0 1 * Vision Mamba (Vim) Activation # MLP & Prediction 0 1 2 3 4 5 * 6 7 8 9 -JOFBS -JOFBS -JOFBS 7JN#MPDL

Slide 29

Slide 29 text

w 7JNͷਪ࿦ϓϩηε ը૾ ΛύονԽͯ͠ฏୱԽ ʢ ɿը૾αΠζ ɿνϟϯωϧ਺ ɿύοναΠζ ɿύον਺ʣ t ∈ ℝH×W×C xp ∈ ℝJ×(P2⋅C) (H, W) C P J 7JNɿΞʔΩςΫνϟ Embedded Patches Norm ! Forward Conv1d Backward Conv1d Forward SSM Backward SSM Vision Mamba Encoder Input Image Vision Mamba Encoder Flatten & Linear Projection Projection Layer Patch Tokens Position Embed. Class Token 0 1 * Vision Mamba (Vim) Activation # MLP & Prediction 0 1 2 3 4 5 * 6 7 8 9 Embedded Patches Norm ! Forward Conv1d Backward Conv1d Forward SSM Backward SSM L× Vision Mamba Encoder Input Image Vision Mamba Encoder Flatten & Linear Projection Projection Layer Patch Tokens Position Embed. Class Token 0 1 * Vision Mamba (Vim) Activation # MLP & Prediction 0 1 2 3 4 5 * 6 7 8 9 -JOFBS -JOFBS -JOFBS 7JN#MPDL

Slide 30

Slide 30 text

w 7JNͷਪ࿦ϓϩηε ը૾ ΛύονԽͯ͠ฏୱԽ ʢ ɿը૾αΠζ ɿνϟϯωϧ਺ ɿύοναΠζ ɿύον਺ʣ Λ ࣍ݩͷϕΫτϧʹࣹӨͯ͠ҐஔϕΫτϧ Λ෇༩ t ∈ ℝH×W×C xp ∈ ℝJ×(P2⋅C) (H, W) C P J xp D Epos ∈ ℝ(J+1)×D 7JNɿΞʔΩςΫνϟ Embedded Patches Norm ! Forward Conv1d Backward Conv1d Forward SSM Backward SSM Vision Mamba Encoder Input Image Vision Mamba Encoder Flatten & Linear Projection Projection Layer Patch Tokens Position Embed. Class Token 0 1 * Vision Mamba (Vim) Activation # MLP & Prediction 0 1 2 3 4 5 * 6 7 8 9 T0 = [tcls ; t1 p W, t2 p W; ⋅ ⋅ ⋅ , tJ p W] + Epos ΫϥετʔΫϯ ֶशՄೳͳࣹӨߦྻ T0 Embedded Patches Norm ! Forward Conv1d Backward Conv1d Forward SSM Backward SSM L× Vision Mamba Encoder Input Image Vision Mamba Encoder Flatten & Linear Projection Projection Layer Patch Tokens Position Embed. Class Token 0 1 * Vision Mamba (Vim) Activation # MLP & Prediction 0 1 2 3 4 5 * 6 7 8 9 -JOFBS -JOFBS -JOFBS 7JN#MPDL

Slide 31

Slide 31 text

w 7JNͷਪ࿦ϓϩηε ը૾ ΛύονԽͯ͠ฏୱԽ ʢ ɿը૾αΠζ ɿνϟϯωϧ਺ ɿύοναΠζ ɿύον਺ʣ Λ ࣍ݩͷϕΫτϧʹࣹӨͯ͠ҐஔϕΫτϧ Λ෇༩ τʔΫϯྻ Λ7JNFODPEFSͷ ൪໨ͷϨΠϠʔʹೖྗͯ͠ग़ྗτʔΫϯྻ Λܭࢉ t ∈ ℝH×W×C xp ∈ ℝJ×(P2⋅C) (H, W) C P J xp D Epos ∈ ℝ(J+1)×D Tl−1 l Tl 7JNɿΞʔΩςΫνϟ Embedded Patches Norm ! Forward Conv1d Backward Conv1d Forward SSM Backward SSM Vision Mamba Encoder Input Image Vision Mamba Encoder Flatten & Linear Projection Projection Layer Patch Tokens Position Embed. Class Token 0 1 * Vision Mamba (Vim) Activation # MLP & Prediction 0 1 2 3 4 5 * 6 7 8 9 Tl = Vim(Tl−1 ) + Tl−1 Embedded Patches Norm ! Forward Conv1d Backward Conv1d Forward SSM Backward SSM L× Vision Mamba Encoder Input Image Vision Mamba Encoder Flatten & Linear Projection Projection Layer Patch Tokens Position Embed. Class Token 0 1 * Vision Mamba (Vim) Activation # MLP & Prediction 0 1 2 3 4 5 * 6 7 8 9 -JOFBS -JOFBS -JOFBS 7JN#MPDL

Slide 32

Slide 32 text

w 7JNͷਪ࿦ϓϩηε ը૾ ΛύονԽͯ͠ฏୱԽ ʢ ɿը૾αΠζ ɿνϟϯωϧ਺ ɿύοναΠζ ɿύον਺ʣ Λ ࣍ݩͷϕΫτϧʹࣹӨͯ͠ҐஔϕΫτϧ Λ෇༩ τʔΫϯྻ Λ7JNFODPEFSͷ ൪໨ͷϨΠϠʔʹೖྗͯ͠ग़ྗτʔΫϯྻ Λܭࢉ ΫϥετʔΫϯΛਖ਼نԽ ɺ.-1ϔουʹೖྗͯ͠༧ଌΛग़ྗ t ∈ ℝH×W×C xp ∈ ℝJ×(P2⋅C) (H, W) C P J xp D Epos ∈ ℝ(J+1)×D Tl−1 l Tl f = Norm(T0 L ) ̂ p = MLP(f) 7JNɿΞʔΩςΫνϟ Embedded Patches Norm ! Forward Conv1d Backward Conv1d Forward SSM Backward SSM Vision Mamba Encoder Input Image Vision Mamba Encoder Flatten & Linear Projection Projection Layer Patch Tokens Position Embed. Class Token 0 1 * Vision Mamba (Vim) Activation # MLP & Prediction 0 1 2 3 4 5 * 6 7 8 9 Embedded Patches Norm ! Forward Conv1d Backward Conv1d Forward SSM Backward SSM L× Vision Mamba Encoder Input Image Vision Mamba Encoder Flatten & Linear Projection Projection Layer Patch Tokens Position Embed. Class Token 0 1 * Vision Mamba (Vim) Activation # MLP & Prediction 0 1 2 3 4 5 * 6 7 8 9 -JOFBS -JOFBS -JOFBS 7JN#MPDL

Slide 33

Slide 33 text

w ϏδϣϯλεΫͷͨΊʹ૒ํ޲ͷγʔέϯεϞσϦϯάΛಋೖ τʔΫϯྻΛલํ޲ʢϑΥϫʔυʣͱޙํ޲ʢόοΫϫʔυʣͷ྆ํͰॲཧ ը૾಺ͷ֤ύονؒͷ૒ํ޲తͳґଘؔ܎ΛޮՌతʹଊ͑Δ 7JNɿϒϩοΫ ॱํ޲ Embedded Patches Norm ! Forward Conv1d Backward Conv1d Forward SSM Backward SSM L× Vision Mamba Encoder Encoder Projection ba (Vim) Activation # iction 6 7 8 9 -JOFBS -JOFBS -JOFBS 7JN#MPDL ٯํ޲ τʔΫϯྻ τʔΫϯྻ ˠը૾෼ྨ ෺ମݕग़ ηάϝϯςʔγϣϯͳͲɺ͞·͟·ͳϏδϣϯλεΫʹॊೈʹదԠՄೳ

Slide 34

Slide 34 text

w .-1ϔουʹೖྗ͞ΕΔτʔΫϯΛ࣮ݧతʹઃܭ .FBOQPPM ɿ7JNϒϩοΫ͔Βͷग़ྗʹରͯ͠ฏۉϓʔϦϯά .BYQPPM ɿ.-1ϔου͔Βͷ֤τʔΫϯͷग़ྗʹରͯ͠.BYϓʔϦϯά )FBEDMBTTUPLFO ɿΫϥετʔΫϯΛτʔΫϯͷઌ಄ʹ࿈݁ %PVCMFDMBTTUPLFO ɿΫϥετʔΫϯΛτʔΫϯͷઌ಄ͱ຤ඌʹ࿈݁ .JEEMFDMBTTUPLFO ɿΫϥετʔΫϯΛτʔΫϯͷதԝʹ௥Ճ 7JNɿը૾෼ྨλεΫͷͨΊͷग़ྗઃܭ Classification strategy ImageNet top-1 acc. Mean pool 73.9 Max pool 73.4 Head class token 75.2 Double class token 74.3 Middle class token 76.1 Table 5. Ablation study on the classification design. The default setting for Vim is marked in blue . modeling power as Transfor putation complexity. Benefi designs of Mamba, the infe age of Vim are significantly cessing high-resolution imag dard computer vision benchm ing power and high efficiency great potential to be the next In future works, Vim with ing with position embedding tasks such as mask image ˠ.JEEMFDMBTTUPLFO͕࠷ߴਫ਼౓Λୡ੒

Slide 35

Slide 35 text

w ࣄલֶशޙʹ௕͍γʔέϯεͷઃఆͰ7JNΛϑΝΠϯνϡʔχϯά 7JNͷޮ཰తͳ௕͍γʔέϯεϞσϦϯάೳྗΛ࠷େݶʹ׆༻ ࣄલֶश࣌ͷύονநग़ετϥΠυΛมߋ ࣄલֶशͱಉ͡σʔληοτΛ࢖༻ 7JNɿ-POH4FRVFODF'JOFUVOJOH 1SFUSBJOJOH -POH4FRVFODF'JOFUVOJOH σʔληοτ *NBHF/FU, *NBHF/FU, 0CKFFDUJWF ڭࢣ͋Γֶश ڭࢣ͋Γֶश ೖྗαΠζ º º ύοναΠζ º º ετϥΠυ

Slide 36

Slide 36 text

7JNɿ-POH4FRVFODF'JOFUVOJOH ɾɾɾ ɾɾɾ ɾɾɾ 1SFUSBJOJOH -POH4FRVFODF'JOFUVOJOH ύον਺ɿ ετϥΠυɿ ʢύονը૾ؒʹΦʔόϥοϓແ͠ʣ ύον਺ɿ ετϥΠυɿ ʢύονը૾ؒʹΦʔόϥοϓ͋Γʣ

Slide 37

Slide 37 text

7JNͷޮՌ B ɿࣄલֶशλεΫͱϑΝΠϯνϡʔχϯάλεΫͷ྆ํͰ%FJ5ΑΓ޲্ C ɿߴղ૾౓ը૾ͷॲཧʹ͓͍ͯ%FJ5ΑΓ΋ܭࢉޮ཰ͱϝϞϦޮ཰͕ߴ͍ Lianghui Zhu1⇤, Bencheng Liao1⇤, Qian Zhang2, Xinlong Wang3, Wenyu Liu1, Xinggang Wang1 1 Huazhong University of Science and Technology 2 Horizon Robotics 3 Beijing Academy of Artificial Intelligence Code & Models: hustvl/Vim 42 43 44 45 46 Detection mAP (%) 36 37 38 39 40 Ins. Seg. mAP (%) 71 73 75 77 Classification Top-1 Acc. (%) 38 39 40 41 Sem. Seg. mIoU (%) (a) Accuracy Comparison 1 1.4 1.8 2.2 2.6 512 640 738 1024 1248 FPS w/ log scale Resolution DeiT-Ti Vim-Ti 2.54 2.25 2.05 1.57 1.26 2.29 2.07 1.91 1.71 (b) Speed Comparison 0 20 40 60 80 512 640 738 1024 1248 GPU Memory (GB) Resolution DeiT-Ti Vim-Ti 4.56 4.22 12.48 8.13 11.14 8.09 5.03 40.09 OOM (c) GPU Memory Comparison 3.32 DeiT-Ti Vim-Ti Faster Smaller 2.8× faster -86.8% memory Figure 1. Performance and efficiency comparisons between DeiT [59] and our Vim model. For the accuracy comparison, we first pretrain DeiT and Vim on IN1K classification dataset [9], then we finetune the generic backbones on different downstream dense prediction tasks, i.e., semantic segmentation, object detection, instance segmentation. Results show that the proposed Vim outperforms DeiT on both pretraining and finetuning tasks. Vim is also more computation and memory efficient than DeiT in dealing with high-resolution images. For example, Vim is 2.8⇥ faster than DeiT and saves 86.8% GPU memory when performing batch inference to extract features on images with a resolution of 1248⇥1248, i.e., 6084 tokens per image. Abstract tion & memory efficiency. For example, Vim is 2.8⇥ faster 9417v2 [cs.CV] 10 Feb 2024

Slide 38

Slide 38 text

7JNͷಛ௃දݱ ܗঢ়ʹ஫໨ ςΫενϟʹ஫໨ ܗঢ়ͷΧςΰϦ $//˔͸ςΫενϟɺ5SBOTGPSNFS˛͸ܗঢ়ʹண໨ͯ͠ਪ࿦ .BNCB ♦︎ ͸$//ͱ5SBOTGPSNFSͷதؒతͳ܏޲

Slide 39

Slide 39 text

w .BNCBCMPDLΛ࠶ઃܭͯ͠ϋΠϒϦουΞʔΩςΫνϟΛఏҊ εςʔδ͔ΒͳΔߏ੒ w 4UBHF ɿ$POWPMVUJPOCMPDL w 4UBHF ɿ)ZCSJE.BNCB5SBOTGPSNFSCMPDL .BNCB7JTJPO<)BUBNJ[BEFIBOE,BVU[ BS9JW> Stem Conv Block Downsample MambaVision Mixer MLP Downsample Stage 1 Conv Block Stage 2 Self-Attention MLP Stage 3 Downsample MambaVision Mixer MLP Self-Attention MLP Stage 4 2D Avg Pool Linear )ZCSJE.BNCB5SBOTGPSNFS#MPDL

Slide 40

Slide 40 text

w ը૾ύονؒͷॱংؔ܎Λߏங͠ɺॱংؔ܎ʹج͍ͮͯ͜ΕΒͷύονΛ.BNCBʹೖྗ ը૾ύον͸7JTJPO5SBOTGPSNFSͱಉ༷ʹը૾Λύονʹ෼ղ w ̎ͭͷҟͳΔ؍఺͔Β༷ʑͳํ๏͕ఏҊ γʔέϯεͷલޙؔ܎ ϞσϧΞʔΩςΫνϟ .BNCBͷ$7λεΫ΁ͷల։ 7JTVBM.BNCB"4VSWFZBOE/FX0VUMPPLT<9V BS9JW> Raster Bidirectional (BD) Horizontal (H) Raster Vim Scan Raster Bidirectional (BD) Atrous Sampling Horizontal (H) Vertical (V) Raster EVSS Scan Bidirectional (BD) Horizontal (H) Vertical (V) Raster VSS Scan Bidirectional (BD) Atrous Sampling Horizontal (H) Vertical (V) Raster EVSS Scan Bidirectional (BD) Horizontal (H) Vertical (V) Zigzag Local Mamba Scan Bidirectional (BD) Local Sampling Horizontal (H) Vertical (V) Raster Scan Plain Mamba

Slide 41

Slide 41 text

w .BNCB ෆཁͳ৘ใΛഉআ͠ඞཁͳσʔλΛอ࣋͢Δબ୒ϝΧχζϜ ϋʔυ΢ΣΞʹ࠷దԽ͞ΕͨΞϧΰϦζϜΛಋೖ͢Δ͜ͱͰ ܭࢉޮ཰Λେ෯ʹ޲্ w 7JTJPO.BNCB .BNCBΛ$7λεΫʹల։ ૒ํ޲ͷγʔέϯεϞσϦϯάΛಋೖ ·ͱΊɿ.BNCB 7JTJPO.BNCB 7JN 2 Fig. 1 The statistics of Mamba-based papers released to date on vision tasks, spanning di↵erent modalities including Image, Video, Point Cloud, and Multi-Modal of image patches, have demonstrated remarkable mod- eling capabilities across various visual tasks (Liu et al, 2021). Self-attention enables ViTs to capture long-range dependencies within images, providing a significant ad- vantage over traditional CNNs that rely on local recep- tive fields. This capability allows ViTs to exhibit robust and can concept for sequ models CNNs. for proc can be their ab parallel been wi putatio state re the adv these li matrice transfor corpora et al, 20 et al, 2 ever, SS context the e c 2017). I pose to nism int propaga scan pa e cient .BNCBͷޮՌΛൃشͰ͖Δϋʔυ΢ΣΞͷબผ͸ॏཁ ˠ$7λεΫ༻ͷ.BNCBؔ࿈ͷ࿦จ͕૿Ճத ஫ҙ఺