Upgrade to Pro — share decks privately, control downloads, hide ads and more …

論文紹介: Bi-Directional Block Self-Attention for Fast and Memory-Efficient Sequence Modeling

論文紹介: Bi-Directional Block Self-Attention for Fast and Memory-Efficient Sequence Modeling

首都大 小町研 春の最先端論文紹介 2018

Satoru Katsumata

December 10, 2023
Tweet

More Decks by Satoru Katsumata

Other Decks in Research

Transcript

  1. Bi-Directional Block Self-Attention for Fast and Memory-Efficient Sequence Modeling Tao

    Shen, Tianyi Zhou, Guodong Long, Jing Jiang, Chengqi Zhang (ICLR 2018) 論文紹介: 小町研 M1 勝又智 (論文は arXiv v1 のものを紹介)
  2. 概要 • sequence encoding model の話 • ネットワークとして RNN、CNN、SAN (self-attention

    networks) がある RNN: 広範囲の関係性が取れる。が、並列処理ができず、時間がかかる。 CNN: 並列処理可能。が、広範囲の関係性をみるのは大変。 SAN: 並列処理可能で広範囲の関係性も取れる。が、メモリ的に重い。 → それほどメモリを使わずにSANを設計した話 • sequence を block に分割することで 1度の self-attention を小規模に行うモデル • 評価 9つのNLPベンチマークで提案手法が同じくらいの精度を、 メモリ効率をあげながら達成した。 2
  3. Network 比較 • RNN: 多くの NLP タスクで使われる。LSTM などで勾配問題にも対応。 → time

    cost が未だしんどい。(特に長文) • CNN: 並列処理できる点からいくつかの NLP タスクに用いられていたりする。 → long-range dependency を捕まえ辛い。(階層的に積むから) • SAN: 最近流行り。sequence内の各 token について、他のtokenとのattentionを 用いてcontext-aware な表現を作る。 long-range、local dependency どちらもパッと取れる。 attentionを求めていくところはパッとできる。(並列処理) → 全ての token pairs のattention score を保持するのでめっちゃメモリ使う。 (長文になる程辛い)(DiSANの方、Multi-headはそんなに) → RNNぐらいのメモリ消費でSANのいいところを残したモデルを作りたい。 3
  4. 提案手法の気持ち 1. sequence を長さが等しいブロックに分割する。 2. それぞれのブロック内で self-attention を求める。(local dependency) 3.

    すべてのブロックの出力に対して self-attention を求める。(global dependency) ↑ self-attention の処理は小さいsequenceに対してのみ行われる。 他にやったこととして(新規性はない) • Bi-directional なモデルにした • attentionの計算をfeature-levelに行なった 4 (DiSAN: Directional Self-Attention Network for RNN/CNN-Free Language Understanding. Shen et al., arXiv 2017) 次のスライドからsequence → seq
  5. attentionについて(Vanilla Attention) seq x とquery token があって、vannila attention は query

    に対する seq 内の各トークンの alignment score を計算する。 各トークンに対する score に対してsoftmax かけて、query に対する seq 内の各トークンの 重要度合いを求める。 (f(xi, q)はxiとqのalignment scoreを求めるやつ) 5 ↑ Multiplicative attention ↑ Additive attention n: seq length
  6. attentionについて(Multi-dimentional Attention) 何をやるか: Alignment score を embedding の各次元に対して求める。 (文脈によって意味が異なる単語とかによく効きそう) 6

    de: embed size query に対する token i の feature k の重要度 Piは additive attention で計算できる。 (additive の w^T を W にすれば score が vector になるから)
  7. self-attention • token2token (Hu et al., 2017; Vaswani et al.,

    2017; Shen et al., 2017) あるseq内のトークンxjがseq内のどれに対応しているかみる話 → xjをquery tokenとして考える • source2token (Lin et al., 2017; Shen et al., 2017; Liu et al., 2016) seq全体に対する各トークンの重要性をみる話 → query を削除する 7
  8. 順序情報を捉えるために(for token2token) f(xi, xj) で求めた score をマスク処理することで token2token attention に順序情報を

    考慮させる。 8 ↑ query token xj より前にあるのか後ろにあるのかで forward、backwardを区別 ← Shen et al., 2017 より引用 Wではない; c=5でやってる
  9. 提案手法: ① Intra-block self-attention 入力 seq からr 個ずつトークンを持ってきて、合計 m このブロックを作成する。

    それぞれのブロック内で Masked self-attention をかけて長さ r の表現を作成する。 → block内の local dependency を捉えたい 11 ・必要に応じてpadding ・ブロックの長さ r はハイパーパラメータ → メモリ効率が最も良くなる値に設定 (詳細は次スライド)
  10. ブロック分割の話 使用するデータセットに対して、メモリ消費が最小化するように r を設定する。 1. データセットの文長が固定長 n であるとする。 ①メモリ消費の主な部分は masked

    self-attention ②masked self-attention のメモリ消費は文長の2乗 ③mBloSAは長さ r のブロック m 個に masked self-attention して、長さ m の系列に対して 1 回 masked self-attention す る → この時のメモリ消費量が最小になる r は r = 2. データセットの文長を考える。(固定長 n ではないはずなので) 具体的には、mini-batch (B個の文がある) における最大文長の期待値を求める。 (正確には上限を抑える) そんなわけでブロックの長さはデータセットに対して決められる。 12 ξ: メモリ使用量 Xi: mini-batch内の i 文目の文長 詳細な式変形は論文内の Appendix 参照
  11. 提案手法: ② Inter-block self-attention ・各ブロックからの長さ r の出力に対して source2token self-attention を通して、

    ブロックごとのベクトル表現(local dependency 情報が溜まってる)を獲得する。 ・ブロック間の関係 (global dependency) を獲得するために Masked self-attention をか ける。 ・local、global dependency を組み合わせたいので、Masked self-attention への 入出力を gate でくっつける。 13
  12. 提案手法: ③ context fusion ①、②で word embeddings の seq x

    から local context を表す h や、それを元にglobal context を表す e を作成した。 → gate に入れてくっつける 14
  13. 評価 • 9つの NLPベンチマークで評価した。 Natural language inference、Reading comprehension、Semantic relatedness、 Sentence

    classfication (CR、MPQA、SUBJ、TREC、SST-1、SST-2) → NLI と Reading comprehension を紹介する。 • 比較したモデル(明記されない場合は600D) ◦ Bi-LSTM: Bi-directional LSTM ◦ Bi-GRU: Bi-directional GRU ◦ Bi-SRU: Bi-directional SRU ◦ Multi-CNN: CNN sentence embedding model (3, 4, 5-gram) ◦ Hrchy-CNN: 3-layer 300D CNN (kernel: 5) ← convlolutional encoder とかのやつ ◦ Multi-head: Multi-head attention (head: 8, hidden: 75 → total 600D) ◦ DiSAN: Directional self-attention network (Shen et al., 2017) 15
  14. NLI experiments: setting Natural language inference: 文対 (premise, hypthesis sent)

    の関係性を当てる。 具体的には3つの関係性: entailment, neutral, contradiction ↑ 詳しくは https://nlp.stanford.edu/projects/snli/ dataset: Stanford Natural Language Inference (SNLI) training: 549,367 samples; dev: 9,842 samples; test: 9,824 samples 作成した文ベクトルから関係性を当てるモデル (Bowman et al., 2016): premise sent と hypothesis sent それぞれの 文ベクトルを作ってそれらを組み合わせて 関係性を当てるモデル。 16 ハイパーパラメータとかは論文参照 Conneau et al., 2017 から引用→
  15. NLI experiments: result (メモリ、時間) 18 → メモリがRNNベースと同じくらい、 training time はCNN

    ベースや Multi-head よりいくらか遅い くらい、精度は最も良い
  16. Reading comprehension experiments: setting なんか passage とそれに対応した質問があって、そのpassageの中から質問に対する 解答を探す話。 dataset: Stanford

    Question Answering Dataset (SQuAD) passage: Wikipediaの記事; 解答: 記事内のあるspan; question: 記事に基づいて人手で作成 ↑ 詳しくは https://rajpurkar.github.io/SQuAD-explorer/ → 今回 sequence encoding の比較をしたいので、answer span を当てるのではなく、 解答が含まれる文を見つけるタスクとして扱う。 というわけで今回のモデルは次のスライド 20 ハイパーパラメータとかは論文参照
  17. Analysis: 文長に対する時間とかメモリ消費とか batch size と feature num を固定して、文長を 16 刻みで変えたものを処理した際の推論に

    必要な時間とメモリ消費を調べた。 ([batch_size, seq length, feature num] の tensor を seq length について 16 to 384 で変えて入力した(batch: 64; feature: 300)) 24
  18. まとめ • Bi-directional block self-attention network を提案した。 • sequence をブロックに分けて

    self-attention をかけることでメモリ消費を抑えた。 • local and long-range context dependency を捉えるために intra-block と inter-block self-attention を用いた。 • temporal order information を考慮するために Mask を利用して directional にし た。 • 9つの NLP ベンチマークで評価した。 → 精度を保ったまま、RNNぐらいのメモリ使用で、mulit-head ぐらいの速度の性能 を示した。 25
  19. 参考にしたもの • 実装 ◦ Tensorflow (author): https://github.com/taoshen58/BiBloSA ◦ PyTorch: https://github.com/galsang/BiBloSA-pytorch

    • Natural language inference by tree-based convolution and heuristic matching. Mou et al., ACL 2016 • A Fast Unified Model for Parsing and Sentence Understanding. Bowman et al., ACL 2016 • Supervised Learning of Universal Sentence Representations from Natural Language Inference Data. Conneau et al., EMNLP 2017 • DiSAN: Directional Self-Attention Network for RNN/CNN-Free Language Understanding. Shen et al., arXiv 2017 • Reinforced mnemonic reader for machine comprehension. Hu et al., arXiv 2017 • Attention is all you need. Vaswani et al., NIPS 2017 • A structured self-attentive sentence embedding. Lin et al., ICLR 2017 • Learning natural language inference using bidirectional lstm model and inner-attention. Liu et al., arXiv 2016 26