Slide 1

Slide 1 text

TextPruner による 大規模言語モデルの軽量化 ELYZA 三澤遼 NLP Hacks 2022/5/13

Slide 2

Slide 2 text

2 ● 三澤遼 ● Twitter: twitter.com/misawann06 ● 東京大学物理工学科3年 ● ELYZA: 2020/10 ~ ● 研究 / サービス開発チームに所属 ● 大規模言語モデルの開発・応用 ● ELYZA DIGEST 自己紹介

Slide 3

Slide 3 text

3 目次 ● TextPrunerについて ● 枝刈りとは ● 実装されているアルゴリズム・使い方 ● 既存ライブラリとの比較 ● 実験設定・結果 ● 複数設定で BERT と JaQuAD を題材に Transformer 層の枝刈り ● 定量・定性評価 ● 枝刈りされる箇所の分析 ● まとめ ● 実験できなかったことについて ● 全体のまとめ

Slide 4

Slide 4 text

4 目次 ● TextPrunerについて ● 枝刈りとは ● 実装されているアルゴリズム・使い方 ● 既存ライブラリとの比較 ● 実験設定・結果 ● 複数設定で BERT と JaQuAD を題材に Transformer 層の枝刈り ● 定量・定性評価 ● 枝刈りされる箇所の分析 ● まとめ ● 実験できなかったことについて ● 全体のまとめ

Slide 5

Slide 5 text

5 枝刈りについて ● 不要なニューロンを捨てることでパラメータ削減 ● structured pruning ● 重み行列の列または行を除去 ● 今回扱う Transformer Pruning はこれに該当 ● unstructured pruning ● 列や行単位ではなく,各パラメータが枝刈り対象 ● 訓練必要 ● e.g. Movement Pruning (Sanh et al., 2020) ● 訓練不要 ● e.g. Michel et al. (TextPruner で実装されているアルゴリズム) ※ 以前ブログで公開した大規模言語モデル軽量化に関するサーベイ https://elyza-inc.hatenablog.com/entry/2021/05/21/163349 枝刈りはモデル軽量化の1手法 TextPruner から引用

Slide 6

Slide 6 text

6 ● GitHub: https://github.com/airaria/TextPruner ○ 半年前に公開されたライブラリ ○ スターは150件程度 ● preprint: https://arxiv.org/abs/2203.15996 ● 主著の Ziqing Yang さんは蒸留ライブラリ TextBrewer も開発 ○ GitHub: https://github.com/airaria/TextBrewer ○ 2年前に公開されスターは 1100件程度 ● TextPruner はその姉妹ライブラリ ACL2022 demo で採択された枝刈りライブラリ TextPruner について

Slide 7

Slide 7 text

7 Vocabulary Pruning, Transformer Pruning, Pipeline Pruning が実装済 TextPruner について

Slide 8

Slide 8 text

8 現在9種類のモデルがサポートされている ● encoder モデルは両方実装済み ● encoder-decoder モデルは Vocabulary Pruning のみ ● Transformer Pruning は decoder にも適用できるはず。 実装されていないだけ? TextPruner について

Slide 9

Slide 9 text

9 Vocabulary Pruning のユースケース 多言語モデルの単言語化に有効 ● 多言語モデルのほとんどのパラメータは embedding が占める ● pruning することで性能劣化を抑えたまま軽量化できる XNLI に対して Vocabulary Pruning を適用した結果 TextPruner の README から引用

Slide 10

Slide 10 text

10 Vocabulary Pruning の実行方法 指定した vocabulary のみを残す from textpruner import VocabularyPruner from transformers import AutoModelForQuestionAnswering, AutoTokenizer model = AutoModelForQuestionAnswering.from_pretrained( "SkelterLabsInc/bert-base-japanese-jaquad" ) tokenizer = AutoTokenizer.from_pretrained( "SkelterLabsInc/bert-base-japanese-jaquad” ) pruner = VocabularyPruner(model, tokenizer) pruner.prune(dataiter=texts) # texts: a list of strings to leave 以下を準備 Ø transformers 形式のモデル Ø transformers 形式の トークナイザ prune の引数について Ø dataiter 残したい文字列のリスト Ø additional_tokens トークンのリスト Ø additional_token_ids トークンの ID のリスト

Slide 11

Slide 11 text

11 Michel et al. (2019) の手法を実装 ● IS (=Importance Score) が低い FFN / attention head を削る ● IS Θ = 𝔼!~# $ℒ(!) $( Θ で重要度を測る ● 𝚯 はニューロンまたはその出力 ● FFN・attention head に対して実装されている l 𝚯 の重要度 = 𝓛(𝒙)の 𝚯 に対する sensitivity l 𝚯 の変化に対して損失関数の変化が大きいなら 𝚯 が重要と考える 教師あり Transformer Pruning について Michel, P., Levy, O., & Neubig, G. (2019). Are sixteen heads really better than one?. Advances in neural information processing systems, 32.

Slide 12

Slide 12 text

12 loss ではなく,logits を使う手法を提案 教師なし Transformer Pruning について ● IS の loss として以下の KL-divergence を使用 ● ℒ!" (𝑥) = KL(stopgrad(𝑞(𝑥)||𝑝 𝑥 )) ● 𝒒(𝒙): フルモデルの予測確率分布 ● 𝒑(𝒙): 枝刈りしたモデルの予測確率分布 ● 𝚯 が重要 = 𝚯 の変化に対して 𝒒 𝒙 , 𝒑 𝒙 の差異が大きい

Slide 13

Slide 13 text

13 Transformer Pruning について 教師あり・教師なしの手法を実装 from textpruner import TransformerPruner, TransformerPruningConfig transformer_pruning_config = TransformerPruningConfig( target_ffn_size=1536, target_num_of_heads=6, pruning_method="iterative", # or "masks" n_iters=2, use_logits=False, # True for unsupervised mode ) pruner = TransformerPruner( model, transformer_pruning_config=transformer_pruning_config, ) # torch_ds = hf_ds.set_format(type='torch') # dataloader = torch.utils.data.DataLoader(torch_ds, batch_size=BATCH_SIZE) pruner.prune(dataloader=dataloader, save_model=True) # dataloader: PyTorch dataloader ● attention head の数,FFN のサイズの枝刈りが可能 以下を準備 Ø transformers 形式の モデル Ø Pytorch の dataloader 引数について Ø pruning_method iterative: IS に基づく 枝刈りを反復 masks: 指定した箇所 を枝刈りする Ø use_logits True: 教師なし枝刈り False: 教師あり枝刈り

Slide 14

Slide 14 text

14 Pipeline Pruning について Vocabulary Pruning と Transformer Pruning の 組み合わせ from textpruner import PipelinePruner, TransformerPruningConfig transformer_pruning_config = TransformerPruningConfig( target_ffn_size=1536, target_num_of_heads=6, pruning_method="iterative", n_iters=2, ) pruner = PipelinePruner( model, tokenizer, transformer_pruning_config=transformer_pruning_config, ) pruner.prune(dataloader=dataloader, dataiter=texts, save_model=True)

Slide 15

Slide 15 text

15 Pytorch では重みのノルムによる枝刈りが実装 既存のライブラリとの⽐較 torch.nn.utils.prune.LnStructured torch.nn.utils.prune.LnUnstructured Structured l モジュール,パラメータ, 枝刈りする割合,次元, ノルムを指定 l 指定した次元に沿って ノルムを計算し 値が小さいものを除去する UnStructured l 次元は指定しない torch.nn.utils.prune.L1Structured torch.nn.utils.prune.L1Unstructured l LnStructured/Unstructured で n = 1 の場合 l つまり,L1ノルムを使用 torch.nn.utils.prune.RandomStructured torch.nn.utils.prune.RandomUnStructured l ランダムに枝刈り

Slide 16

Slide 16 text

16 目次 ● TextPrunerについて ● 枝刈りとは ● 実装されているアルゴリズム・既存ライブラリとの比較 ● 使い方 ● 実験設定・結果 ● 複数設定で BERT と JaQuAD を題材に Transformer 層の枝刈り ● 定量・定性評価 ● 枝刈りされる箇所の分析 ● まとめ ● 実験できなかったことについて ● 全体のまとめ

Slide 17

Slide 17 text

17 モデルは BERT, タスクには JaQuAD を使用 ● JaQuAD l Japanese SQuAD l NLP Hacks でも紹介した QA データセット l 参考: NLP News #2 l preprint: https://arxiv.org/abs/2202.01764 ● ベースラインの BERT と実験コードを公開している l HF: hf.co/SkelterLabsInc/bert-base-japanese-jaquad l GitHub: https://github.com/SkelterLabsInc/JaQuAD l ベースラインモデルに対して枝刈りを実施 l 実験に使用した Notebook 実験設定

Slide 18

Slide 18 text

18 教師あり Transformer Pruning を試し定量・定性評価 ● 実験内容 1. attention head の数を固定し,FFN のサイズを20%ずつ落とす 2. FFN のサイズを固定し,attention head の数を20%ずつ落とす 3. attention head の数・FFNのサイズの適当な複数ペアで実験 4. 枝刈りのイテレーションを変える ● 1~3は n_iter = 2 で固定 ● 1~4で使用するデータは効率化のため訓練データのうち1024件のみ ● 評価方法 l 定量評価 l EM (Exact Match) 及び F1 で評価 l 定性評価 ● パラメータ削減による予測結果変化の特徴分析 ● ランダムに抽出した50件のデータに対する予測結果に目を通す 実験設定

Slide 19

Slide 19 text

19 実験コードを修正する必要がある ● そのまま実行すると EM が30程度になる ● 以下のように修正 ● その他軽微な修正で報告されている スコアから10ポイント程度改善する とのこと ● 参考 PR 実験設定 ※ 蛇足 JaQuAD.ipynb # 修正前 for i in range(0, input_len - max_seq_len + stride, stride): # 修正後 for i in range(0, max(input_len - max_seq_len + stride, stride), stride):

Slide 20

Slide 20 text

20 20% 程度の FFN 削減なら性能劣化を抑えられる ● 40% 削減すると3~4ポイント程度減少 ● それ以降は10ポイント以上劣化 FFN の割合 F1 EM 100% 77.9 61.4 80% 77.2 61.0 60% 74.0 58.1 40% 62.7 46.7 20% 26.3 8.63 実験結果 1

Slide 21

Slide 21 text

21 attention head も20%程度の削減なら微劣化 実験結果 2 attention head の割合 F1 EM 100% 77.9 61.4 80% 76.7 60.3 60% 71.6 54.5 40% 49.3 32.2 20% 17.4 5.9 ● FFN のサイズ削減よりも性能劣化しやすい

Slide 22

Slide 22 text

22 FFN を小さくする方が効率的にパラメータ削減できる 実験結果 3 FFN のサイズ,attention head の数 FFN のサイズ,attention head の数

Slide 23

Slide 23 text

23 枝刈りのイテレーションを増やすと 1~2ポイント程度改善する 実験結果 4 イテレーション数 F1 EM 10 75.6 58.7 6 75.3 58.5 4 75.3 58.6 2 75.1 58.6 1 74.7 58.0

Slide 24

Slide 24 text

24 余分な箇所を選択することがある 定性評価 context: ところが、マクドナルド労働党政権は、1924年2月にソ連と外交関係を樹立し、4月14日から対ソ一般条約締結を目的と した交渉をロンドンで開始した。...(省略) question: 1924年4月14日からソ連との条約締結を進めたイギリスの政権は、誰が指導者であったか? answer: マクドナルド full(ffn3072,head12): マクドナルド ffn2150,head10: マクドナルド労働党政権は、1924年2月にソ連と外交関係を樹立し、4月14日から対ソ一般条約締結を目的とし た交渉をロンドンで開始した。8月5日まで続いたこの交渉自体はイギリス人財産賠償問題を巡って決裂したのだが、その直後に労働 党左派議員 context: ドクウツボ(毒鱓)Gymnothoraxjavanicus(Bleeker,1859)は体長3メートルの記録がある大型種で、鰓孔が黒いこ とで近縁種と区別できる。 インド洋と太平洋の熱帯域に広く分布し、日本では琉球列島で見られる。...(省略) question: ドクウツボはインド洋とどの海域の熱帯域に分布しますか? answer: 太平洋 full(ffn3072,head12): 太平洋 ffn1843,head7: 太平洋の熱帯域

Slide 25

Slide 25 text

25 その他の性能劣化事例と改善事例 定性評価 context: ドイツ側の損害についてシュトロープの報告書は「16名が殺害され、85名が負傷した」としているが、この数は日々の損 害報告とまったく一致していない。...(省略) question: シュトロープの報告書によると、殺害と負傷のうち、どちらに該当するドイツ軍の数がより多かったの? answer: 負傷 full(ffn3072,head12): 負傷 ffn2150,head10: 85名 context: ...(省略)量産初期のステンレス鋼は充分な精錬ができなかったため材質のよいものではなかったが、1940年代に酸 素脱炭法、1960年代にAOD法・VOD法が実用化され、品質が飛躍的に向上した。...(省略) question: 酸素脱炭法とAOD法のうち、より早く実用化されたのは、どれ? answer: 酸素脱炭法 full(ffn3072,head12): AOD法・VOD法 ffn1843,head7: 酸素脱炭法 ※ その他の例は Notebook の出力をご覧ください context: 3か所とも天竜川水系和知野川の発電所であり、上流側から和合発電所・豊発電所・和知野発電所の順に建設された。 ...(省略) question: 和知野発電所、和合発電所、豊発電所の中で最も下流側に位置しているのはどれ? answer: 和知野発電所 full(ffn3072,head12): 和合発電所 ffn:1843,head7: 豊発電所 性能劣化事例 性能改善事例

Slide 26

Slide 26 text

26 重みが小さい箇所ほど枝刈りされやすいわけではない = 重みの大きさを重要度と見做す手法とは異なる結果 枝刈りされる head の分析 層,head 別の Frobenius ノルム ※ query を使用 各層における枝刈りされたhead の分布 枝刈りされた head : 1, されなかった箇所: 0 ※ FFN, attention head 30 % 削減したモデルでの結果

Slide 27

Slide 27 text

27 目次 ● TextPrunerについて ● 枝刈りとは ● 実装されているアルゴリズム・使い方 ● 既存ライブラリとの比較 ● 実験設定・結果 ● 複数設定で BERT と JaQuAD を題材に Transformer 層の枝刈り ● 定量・定性評価 ● 枝刈りされる箇所の分析 ● まとめ ● 実験できなかったことについて ● 全体のまとめ

Slide 28

Slide 28 text

28 ● 枝刈り後の追加学習 ● 蒸留の文脈では,パラメータ移植後の生徒モデルが追加学習されたりする ● e.g. Shleifer & Rush, 2020, Pre-trained Summarization Distillation ● 定性評価の結果,致命的な破綻はなかったため,少量の追加学習で 性能回復する見込みがある ● 枝刈りしたモデルと同サイズのまま fine-tuning したモデルの比較 ● Transformer Pruning の教師あり / 教師なしの比較 ● 教師あり Pruning に使用するラベル付きデータ量の性能依存性 ● 今回は約3万件の訓練データのうち1024件のみを使った ● 他の枝刈り手法との比較 ● e.g. Pytorch のような重みの大きさによる枝刈り手法 今回実験できなかったこと

Slide 29

Slide 29 text

29 ● ACL 2022 で採択された枝刈りライブラリを紹介 ● 教師あり/教師なし Transformer Pruning, Vocabulary Pruning が実装済 ● Vocabulary Pruning は多言語モデルの単言語化などで有用 ● Transformer Pruning では FFN のサイズ, attention head の数を削減できる ● JaQuAD で学習した BERT に対して Transformer Pruning を適用 ● FFN を削減する方がパラメータ数を多く削りつつも性能劣化しにくい ● 枝刈りのイテレーションを増やすと,1~2ポイント性能改善する ● 定性評価によると,致命的な性能劣化をしているわけではないので 追加学習により性能回復しそう まとめ 作成した1モデルを以下で公開しました misawann/bert-base-jaquad-ffn2150-head-10