Paper Explained: LoRA: Low-Rank Adaptation of Large Language Models

koharite
June 22, 2023

A presentation explaining the paper "LoRA", presented at ICLR 2022.
LoRA (Low-Rank Adaptation) is a useful method for adapting very large foundation models, such as GPT and Stable Diffusion, to a user's purpose.

Transcript

  1. Paper information
    • Title: LoRA: Low-Rank Adaptation of Large Language Models
    • Paper: https://arxiv.org/abs/2106.09685
    • Code: https://github.com/microsoft/LoRA
    • Venue: ICLR 2022
    • Authors: Edward Hu, Yelong Shen, Phillip Wallis, Zeyuan Allen-Zhu, Yuanzhi Li, Shean Wang, Lu Wang, Weizhu Chen
    • Affiliation: Microsoft Corporation
    Why I chose this paper:
    • It looks like an effective way to tune a very large model, such as a foundation model, to a desired task at low training cost.
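  2. Natural language benchmark: GLUE (General Language Understanding Evaluation)
    GLUE is treated as the standard benchmark for English natural language processing. NLP models have improved so quickly that the harder SuperGLUE is also used.
    • CoLA: whether a sentence is grammatically acceptable English
    • SST-2: sentiment (positive/negative) of movie reviews
    • MRPC: whether two sentences taken from online news have the same meaning
    • STS-B: similarity of two sentences (e.g., news headlines), scored from 0 to 5
    • QQP: whether two questions have the same meaning
    • MNLI-m/MNLI-mm: whether one sentence entails, contradicts, or is neutral toward another
    • SQuAD: extract the answer to a question from a context passage
    • QNLI: for a question/sentence pair, whether the sentence contains the correct answer
    • RTE: whether one sentence entails another or not
    • WNLI: whether a sentence with a pronoun replaced is entailed by the original sentence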
  3. Background (BERT fine-tuning)
    [Figure: BERT pre-training (Mask LM and NSP on unlabeled sentence pairs A and B) and fine-tuning on downstream tasks (SQuAD question/answer pairs, NER, MNLI).]
    Figure 1: Overall pre-training and fine-tuning procedures for BERT. Apart from output layers, the same architectures are used in both pre-training and fine-tuning. The same pre-trained model parameters are used to initialize models for different downstream tasks. During fine-tuning, all parameters are fine-tuned. [CLS] is a special symbol added in front of every input example, and [SEP] is a special separator token (e.g., separating questions/answers).
    • Pre-training: the model learns how words are strung together as language.
    • Fine-tuning: the model is tuned so that it gives appropriate answers for each task.
    The BERT era: https://aclanthology.org/N19-1423/
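  4. Background (GPT-3 few-shot learning)
    [Figure 2.1 (GPT-3 paper): Zero-shot, one-shot and few-shot, contrasted with traditional fine-tuning.] https://arxiv.org/abs/2005.14165
    Even without retraining the model, output accuracy improves when the prompt includes, for example, demonstrations of how to answer.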
  5. Background (Prompt engineering)
    (a) Few-shot
      Q: Roger has 5 tennis balls. He buys 2 more cans of tennis balls. Each can has 3 tennis balls. How many tennis balls does he have now?
      A: The answer is 11.
      Q: A juggler can juggle 16 balls. Half of the balls are golf balls, and half of the golf balls are blue. How many blue golf balls are there?
      A:
      (Output) The answer is 8. ✗
    (b) Few-shot-CoT
      Q: Roger has 5 tennis balls. He buys 2 more cans of tennis balls. Each can has 3 tennis balls. How many tennis balls does he have now?
      A: Roger started with 5 balls. 2 cans of 3 tennis balls each is 6 tennis balls. 5 + 6 = 11. The answer is 11.
      Q: A juggler can juggle 16 balls. Half of the balls are golf balls, and half of the golf balls are blue. How many blue golf balls are there?
      A:
      (Output) The juggler can juggle 16 balls. Half of the balls are golf balls. So there are 16 / 2 = 8 golf balls. Half of the golf balls are blue. So there are 8 / 2 = 4 blue golf balls. The answer is 4. ✓
    (c) Zero-shot
      Q: A juggler can juggle 16 balls. Half of the balls are golf balls, and half of the golf balls are blue. How many blue golf balls are there?
      A: The answer (arabic numerals) is
      (Output) 8 ✗
    (d) Zero-shot-CoT (Ours)
      Q: A juggler can juggle 16 balls. Half of the balls are golf balls, and half of the golf balls are blue. How many blue golf balls are there?
      A: Let's think step by step.
      (Output) There are 16 balls in total. Half of the balls are golf balls. That means that there are 8 golf balls. Half of the golf balls are blue. That means that there are 4 blue golf balls. ✓
    Figure 1: Example inputs and outputs of GPT-3 with (a) standard Few-shot ([Brown et al., 2020]), (b) Few-shot-CoT ([Wei et al., 2022]), (c) standard Zero-shot, and (d) ours (Zero-shot-CoT). Similar to Few-shot-CoT, Zero-shot-CoT facilitates multi-step reasoning (blue text) and reaches the correct answer where standard prompting fails. Unlike Few-shot-CoT using step-by-step reasoning examples per …
    Writing an appropriate prompt (the question plus instructions on how to answer) raises the chance of getting the desired answer.
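The four settings in the figure differ only in the text sent to the model. Below is a minimal Python sketch that assembles them from the figure's own examples. The newline layout and the style names are assumptions, and no model is called here.

```python
QUESTION = ("A juggler can juggle 16 balls. Half of the balls are golf balls, "
            "and half of the golf balls are blue. How many blue golf balls are there?")
DEMO_Q = ("Roger has 5 tennis balls. He buys 2 more cans of tennis balls. "
          "Each can has 3 tennis balls. How many tennis balls does he have now?")
DEMO_A = "The answer is 11."                                     # demonstration: answer only
DEMO_A_COT = ("Roger started with 5 balls. 2 cans of 3 tennis balls each is "
              "6 tennis balls. 5 + 6 = 11. The answer is 11.")   # demonstration with reasoning


def build_prompt(style: str) -> str:
    """Return the prompt text for one of the four settings in the figure (sketch)."""
    if style == "few-shot":
        return f"Q: {DEMO_Q}\nA: {DEMO_A}\n\nQ: {QUESTION}\nA:"
    if style == "few-shot-cot":
        return f"Q: {DEMO_Q}\nA: {DEMO_A_COT}\n\nQ: {QUESTION}\nA:"
    if style == "zero-shot":
        return f"Q: {QUESTION}\nA: The answer (arabic numerals) is"
    if style == "zero-shot-cot":
        # No demonstrations: only a trigger sentence that elicits step-by-step reasoning.
        return f"Q: {QUESTION}\nA: Let's think step by step."
    raise ValueError(f"unknown style: {style}")


for style in ("few-shot", "few-shot-cot", "zero-shot", "zero-shot-cot"):
    print(f"--- {style} ---\n{build_prompt(style)}\n")
```

None of these settings changes the model's weights. Only the input text differs, which is the contrast with the fine-tuning approaches that follow.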
  6. Related work (Adapter)
    Parameter-Efficient Transfer Learning for NLP
    [Figure 2: adapter module (feed-forward down-project, nonlinearity, feed-forward up-project, plus a skip connection) inserted twice into each Transformer layer.]
    Figure 2. Architecture of the adapter module and its integration with the Transformer. Left: We add the adapter module twice to each Transformer layer: after the projection following multi-headed attention and after the two feed-forward layers. Right: The adapter consists of a bottleneck which contains few parameters relative to the attention and feedforward layers in the original model. The adapter also contains a skip-connection. During adapter tuning, the green layers are trained on the downstream data; this includes the adapter, the layer normalization parameters, and the final classification layer (not shown in the figure).
    • Trainable multi-layer perceptron (MLP) layers are inserted inside the Transformer.
    • Because these layers are added in series, they cannot make good use of GPU parallelism, and they increase latency.
    Batch Size        32                   16                   1
    Sequence Length   512                  256                  128
    |Θ|               0.5M                 11M                  11M
    Fine-Tune/LoRA    1449.4±0.8           338.0±0.6            19.8±2.7
    AdapterL          1482.0±1.0 (+2.2%)   354.8±0.5 (+5.0%)    23.9±2.1 (+20.7%)
    AdapterH          1492.2±1.0 (+3.0%)   366.3±0.5 (+8.4%)    25.8±2.2 (+30.3%)
    Table 1: Inference latency of a single forward pass in GPT-2 medium measured in milliseconds, averaged over 100 trials. We use an NVIDIA Quadro RTX8000. "|Θ|" denotes the number of trainable parameters in adapter layers. AdapterL and AdapterH are two variants of adapter tuning, which we describe in Section 5.1. The inference latency introduced by adapter layers can be significant in an online, short-sequence-length scenario. See the full study in Appendix B.
    Section 4 (Our Method): We describe the simple design of LoRA and its practical benefits. The principles outlined here apply to any dense layers in deep learning models, though we only focus on certain weights in Transformer language models in our experiments as the motivating use case.
    Inference latency was measured on GPT-2 medium (average of 100 runs) on an NVIDIA Quadro RTX8000.
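The adapter in Figure 2 is a small bottleneck MLP with a skip connection. Below is a minimal NumPy sketch under stated assumptions. ReLU is used as the nonlinearity, and the up-projection is initialized to zero so that the adapter starts as an identity map. Both choices are illustrative and not taken from the slide. Because the adapter is an extra step with a nonlinearity after each sub-layer, it cannot be folded into the existing weights. This is why it adds latency at small batch sizes.

```python
import numpy as np


def adapter(h, W_down, b_down, W_up, b_up):
    """Bottleneck adapter with a skip connection (illustrative sketch)."""
    z = np.maximum(h @ W_down + b_down, 0.0)  # down-project to m dims, then ReLU (assumed)
    return h + z @ W_up + b_up                # up-project back to d dims, add skip connection


rng = np.random.default_rng(0)
n, d, m = 4, 16, 2                            # tokens, hidden size, bottleneck size (m << d)
h = rng.normal(size=(n, d))                   # hidden states from a Transformer sub-layer
W_down, b_down = rng.normal(0.0, 0.01, size=(d, m)), np.zeros(m)
W_up, b_up = np.zeros((m, d)), np.zeros(d)    # zero up-projection: output == input at start
out = adapter(h, W_down, b_down, W_up, b_up)
n_params = W_down.size + b_down.size + W_up.size + b_up.size
print(np.allclose(out, h), n_params)          # True 82
```

The trainable parameter count per adapter is 2·d·m + d + m, which is small when m is much smaller than d. The latency cost comes from the extra sequential computation, not from the parameter count.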
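  7. LoRA: Low-Rank Adaptation
    This is the method proposed in the paper.
    • A difference matrix ΔW = BA, the product of two low-rank matrices A and B, is added to the linear layers of the base model.
      A: r × d, B: d × r (for a d × d weight; r is a small number, such as 2 to 64)
    • The base model is frozen.
    • LoRA is added to linear layers such as $W_q$, $W_k$, $W_v$ in the Transformer's attention.
    "We hypothesize the updates to the weights also have a low intrinsic rank during adaptation."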
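Below is a minimal NumPy sketch of a LoRA-augmented linear layer for a general d × k weight: W0 is d × k, B is d × r, and A is r × k. The class name `LoRALinear` and the initialization standard deviation of 0.01 are illustrative assumptions. The α/r scaling that the paper applies to ΔWx is omitted. Only A and B would receive gradients. W0 is never modified.

```python
import numpy as np


class LoRALinear:
    """h = W0 x + B A x with W0 frozen; only A and B would be trained (sketch)."""

    def __init__(self, W0, r, rng):
        d, k = W0.shape
        self.W0 = W0                                  # frozen pre-trained weight, (d, k)
        self.A = rng.normal(0.0, 0.01, size=(r, k))   # trainable, (r, k)
        self.B = np.zeros((d, r))                     # trainable, (d, r)

    def __call__(self, x):
        # Compute A x first (r values), then B (A x): the d x k matrix BA is never formed.
        return self.W0 @ x + self.B @ (self.A @ x)

    def num_trainable(self):
        return self.A.size + self.B.size


rng = np.random.default_rng(0)
d = k = 1024
layer = LoRALinear(rng.normal(size=(d, k)), r=8, rng=rng)
x = rng.normal(size=(k,))
print(layer(x).shape)                   # (1024,)
print(layer.num_trainable(), d * k)     # 16384 1048576
```

With r = 8, the trainable parameters for this one matrix are r·(d + k) = 16,384, compared with d·k = 1,048,576 for full fine-tuning. That is a 64× reduction for this layer alone.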
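  8. Trick of LoRA
    At the start:
    • A is initialized from a zero-mean normal distribution, and B is initialized as a zero matrix.
    • ΔW = BA is therefore zero at first, so the LoRA layer has no effect on the original model.
    When the number of trainable parameters is increased:
    • Adapter: this amounts to fine-tuning a model with extra MLP layers added.
    • LoRA: if r is raised until ΔW can reach the rank of the original matrix, training becomes almost the same as fine-tuning the original model.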
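The two points on this slide can be checked numerically. Below is a small NumPy sketch with toy sizes. The standard deviation of 0.02 for A and the random "trained" B are illustrative assumptions. The SVD split is just one way to show that a full-rank update factorizes as BA. It is not how LoRA is trained.

```python
import numpy as np

rng = np.random.default_rng(0)
d, k, r = 6, 5, 2

# Initialization from the slide: A ~ N(0, sigma^2) (sigma = 0.02 assumed), B = 0.
A = rng.normal(0.0, 0.02, size=(r, k))
B_init = np.zeros((d, r))
delta_w_init = B_init @ A
print(np.count_nonzero(delta_w_init))          # 0: the LoRA branch adds nothing at step 0

# Whatever values B and A take during training, rank(BA) <= r.
B_trained = rng.normal(size=(d, r))            # stand-in for trained values
rank_ba = np.linalg.matrix_rank(B_trained @ A)
print(rank_ba <= r)                            # True

# With r = min(d, k), any update can be written as BA (split here via SVD),
# so LoRA can express the same updates as full fine-tuning of W0.
target = rng.normal(size=(d, k))               # an arbitrary full-rank "update"
U, S, Vt = np.linalg.svd(target, full_matrices=False)
B_full, A_full = U * S, Vt                     # shapes (d, min(d, k)) and (min(d, k), k)
print(np.allclose(B_full @ A_full, target))    # True
```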
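  9. Benefit of LoRA
    • Only one base model with large capacity and high performance is needed. A small LoRA is created for each task.
    • Inference latency does not increase. The LoRA layer is purely a linear transformation, so it can be merged into the base weights:
      $$h = W_0 x + \Delta W x = W_0 x + BAx = (W_0 + BA)x = W'x, \qquad W' = W_0 + BA$$
    • For GPT-3 175B (r = 4): the number of trainable parameters drops to 1/10,000 (checkpoint size 350GB → 35MB), GPU memory use drops to 1/3 (1.2TB → 350GB), and training is 25% faster than full fine-tuning.
    • Tuning can be done with fewer compute resources.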
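The merge equation means that a deployed model can carry W' = W0 + BA and run a single matrix multiply, exactly like the base model. Below is a minimal NumPy sketch of merging, and of swapping tasks by subtracting one update and adding another. The task names and random matrices are hypothetical stand-ins for trained LoRA weights.

```python
import numpy as np

rng = np.random.default_rng(0)
d, k, r = 64, 64, 4
W0 = rng.normal(size=(d, k))                      # shared, frozen base weight
task_loras = {                                    # hypothetical per-task (B, A) pairs
    "task_a": (rng.normal(size=(d, r)), rng.normal(size=(r, k))),
    "task_b": (rng.normal(size=(d, r)), rng.normal(size=(r, k))),
}
x = rng.normal(size=(k,))

B, A = task_loras["task_a"]
W_merged = W0 + B @ A                             # fold Delta W into the weight once
print(np.allclose(W_merged @ x, W0 @ x + B @ (A @ x)))   # True: same output, one matmul

# Switching tasks: subtract the old update, then add the new one.
W_restored = W_merged - B @ A
B2, A2 = task_loras["task_b"]
W_task_b = W_restored + B2 @ A2
print(np.allclose(W_restored, W0), np.allclose(W_task_b @ x, W0 @ x + B2 @ (A2 @ x)))
```

Only the small (B, A) pairs need to be stored per task. Here each pair has 2 · 64 · 4 = 512 values, versus 4,096 for W0.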
  10. Evaluation across various models
    Models:
    • RoBERTa base (125M) / large (355M)
    • DeBERTa XXL (1.5B)
    • GPT-2 medium (355M) / large (774M)
    • GPT-3 (175B)
    Benchmarks:
    • GLUE
    • E2E NLG Challenge
    • WikiSQL
  11. Evaluation: RoBERTa & DeBERTa
    Model & Method      # Trainable Params  MNLI      SST-2     MRPC      CoLA      QNLI      QQP       RTE       STS-B     Avg.
    RoB-base (FT)*      125.0M              87.6      94.8      90.2      63.6      92.8      91.9      78.7      91.2      86.4
    RoB-base (BitFit)*  0.1M                84.7      93.7      92.7      62.0      91.8      84.0      81.5      90.8      85.2
    RoB-base (AdptD)*   0.3M                87.1±.0   94.2±.1   88.5±1.1  60.8±.4   93.1±.1   90.2±.0   71.5±2.7  89.7±.3   84.4
    RoB-base (AdptD)*   0.9M                87.3±.1   94.7±.3   88.4±.1   62.6±.9   93.0±.2   90.6±.0   75.9±2.2  90.3±.1   85.4
    RoB-base (LoRA)     0.3M                87.5±.3   95.1±.2   89.7±.7   63.4±1.2  93.3±.3   90.8±.1   86.6±.7   91.5±.2   87.2
    RoB-large (FT)*     355.0M              90.2      96.4      90.9      68.0      94.7      92.2      86.6      92.4      88.9
    RoB-large (LoRA)    0.8M                90.6±.2   96.2±.5   90.9±1.2  68.2±1.9  94.9±.3   91.6±.1   87.4±2.5  92.6±.2   89.0
    RoB-large (AdptP)†  3.0M                90.2±.3   96.1±.3   90.2±.7   68.3±1.0  94.8±.2   91.9±.1   83.8±2.9  92.1±.7   88.4
    RoB-large (AdptP)†  0.8M                90.5±.3   96.6±.2   89.7±1.2  67.8±2.5  94.8±.3   91.7±.2   80.1±2.9  91.9±.4   87.9
    RoB-large (AdptH)†  6.0M                89.9±.5   96.2±.3   88.7±2.9  66.5±4.4  94.7±.2   92.1±.1   83.4±1.1  91.0±1.7  87.8
    RoB-large (AdptH)†  0.8M                90.3±.3   96.3±.5   87.7±1.7  66.3±2.0  94.7±.2   91.5±.1   72.9±2.9  91.5±.5   86.4
    RoB-large (LoRA)†   0.8M                90.6±.2   96.2±.5   90.2±1.0  68.2±1.9  94.8±.3   91.6±.2   85.2±1.1  92.3±.5   88.6
    DeB-XXL (FT)*       1500.0M             91.8      97.2      92.0      72.0      96.0      92.7      93.9      92.9      91.1
    DeB-XXL (LoRA)      4.7M                91.9±.2   96.9±.2   92.6±.6   72.4±1.1  96.0±.1   92.9±.1   94.9±.4   93.0±.2   91.3
    Table 2: RoBERTa-base, RoBERTa-large, and DeBERTa-XXL with different adaptation methods on the …
    On GLUE language understanding, LoRA matches or exceeds full fine-tuning while training far fewer parameters.
  12. Evaluation: GPT-2 (E2E NLG Challenge)
    Model & Method       # Trainable Params  BLEU      NIST      MET       ROUGE-L   CIDEr
    GPT-2 M (FT)*        354.92M             68.2      8.62      46.2      71.0      2.47
    GPT-2 M (AdapterL)*  0.37M               66.3      8.41      45.0      69.8      2.40
    GPT-2 M (AdapterL)*  11.09M              68.9      8.71      46.1      71.3      2.47
    GPT-2 M (AdapterH)   11.09M              67.3±.6   8.50±.07  46.0±.2   70.7±.2   2.44±.01
    GPT-2 M (FTTop2)*    25.19M              68.1      8.59      46.0      70.8      2.41
    GPT-2 M (PreLayer)*  0.35M               69.7      8.81      46.1      71.4      2.49
    GPT-2 M (LoRA)       0.35M               70.4±.1   8.85±.02  46.8±.2   71.8±.1   2.53±.02
    GPT-2 L (FT)*        774.03M             68.5      8.78      46.0      69.9      2.45
    GPT-2 L (AdapterL)   0.88M               69.1±.1   8.68±.03  46.3±.0   71.4±.2   2.49±.0
    GPT-2 L (AdapterL)   23.00M              68.9±.3   8.70±.04  46.1±.1   71.3±.2   2.45±.02
    GPT-2 L (PreLayer)*  0.77M               70.3      8.85      46.2      71.7      2.47
    GPT-2 L (LoRA)       0.77M               70.4±.1   8.89±.02  46.8±.2   72.0±.2   2.47±.02
    Table 3: GPT-2 medium (M) and large (L) with different adaptation methods on the E2E NLG Challenge. For all metrics, higher is better. LoRA outperforms several baselines with comparable or fewer trainable parameters. Confidence intervals are shown for experiments we ran. * indicates …
    On E2E NLG Challenge language generation, LoRA matches or exceeds full fine-tuning while training far fewer parameters.
  13. Evaluation: GPT-3
    Model & Method     # Trainable Params  WikiSQL Acc. (%)  MNLI-m Acc. (%)  SAMSum R1/R2/RL
    GPT-3 (FT)         175,255.8M          73.8              89.5             52.0/28.0/44.5
    GPT-3 (BitFit)     14.2M               71.3              91.0             51.3/27.4/43.5
    GPT-3 (PreEmbed)   3.2M                63.1              88.6             48.3/24.2/40.5
    GPT-3 (PreLayer)   20.2M               70.1              89.5             50.8/27.3/43.5
    GPT-3 (AdapterH)   7.1M                71.9              89.8             53.0/28.9/44.8
    GPT-3 (AdapterH)   40.1M               73.2              91.5             53.2/29.0/45.1
    GPT-3 (LoRA)       4.7M                73.4              91.7             53.8/29.8/45.9
    GPT-3 (LoRA)       37.7M               74.0              91.6             53.4/29.2/45.1
    Table 4: Performance of different adaptation methods on GPT-3 175B. We report the logical form validation accuracy on WikiSQL, validation accuracy on MultiNLI-matched, and Rouge-1/2/L on …
    On both language understanding and language generation, LoRA matches or exceeds full fine-tuning while training far fewer parameters.
  14. Choice of rank and where to apply LoRA
    We believe that our answers to question (2) and (3) shed light on the fundamental principles of using pre-trained language models for downstream tasks, which is a critical topic in NLP.
    7.1 WHICH WEIGHT MATRICES IN TRANSFORMER SHOULD WE APPLY LORA TO?
    Given a limited parameter budget, which types of weights should we adapt with LoRA to obtain the best performance on downstream tasks? As mentioned in Section 4.2, we only consider weight matrices in the self-attention module. We set a parameter budget of 18M (roughly 35MB if stored in FP16) on GPT-3 175B, which corresponds to r = 8 if we adapt one type of attention weights or r = 4 if we adapt two types, for all 96 layers. The result is presented in Table 5.
    # of Trainable Parameters = 18M
    Weight Type         Wq     Wk     Wv     Wo     Wq, Wk   Wq, Wv   Wq, Wk, Wv, Wo
    Rank r              8      8      8      8      4        4        2
    WikiSQL (±0.5%)     70.4   70.0   73.0   73.2   71.4     73.7     73.7
    MultiNLI (±0.1%)    91.0   90.8   91.0   91.3   91.3     91.3     91.7
    Table 5: Validation accuracy on WikiSQL and MultiNLI after applying LoRA to different types of attention weights in GPT-3, given the same number of trainable parameters. Adapting both Wq and Wv gives the best performance overall. We find the standard deviation across random seeds to be consistent for a given dataset, which we report in the first column.
    Note that putting all the parameters in ΔWq or ΔWk results in significantly lower performance, while adapting both Wq and Wv yields the best result. This suggests that even a rank of four captures enough information in ΔW such that it is preferable to adapt more weight matrices than adapting a single type of weights with a larger rank.
    7.2 WHAT IS THE OPTIMAL RANK r FOR LORA?
    We turn our attention to the effect of rank r on model performance. We adapt {Wq, Wv}, {Wq, Wk, Wv, Wo}, and just Wq for a comparison.
    Weight Type                        r = 1   r = 2   r = 4   r = 8   r = 64
    WikiSQL (±0.5%)   Wq               68.8    69.6    70.5    70.4    70.0
                      Wq, Wv           73.4    73.3    73.7    73.8    73.5
                      Wq, Wk, Wv, Wo   74.1    73.7    74.0    74.0    73.9
    MultiNLI (±0.1%)  Wq               90.7    90.9    91.1    90.7    90.7
                      Wq, Wv           91.3    91.4    91.3    91.6    91.4
                      Wq, Wk, Wv, Wo   91.2    91.7    91.7    91.5    91.4
    Table 6: Validation accuracy on WikiSQL and MultiNLI with different rank r.
    On GPT-3, these experiments compare how performance changes with the weight matrices ($W_q$, $W_k$, $W_v$, $W_o$) that LoRA is applied to and with the rank (the number of trainable parameters).
    • Performance holds up even at a very small rank such as r = 2.
    • Applying LoRA to several weight matrices, even with a smaller rank, is more effective than applying it to a single one.
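The 18M budget in Section 7.1 can be checked with a little arithmetic. This sketch assumes that GPT-3 175B has hidden size d_model = 12288 (from the GPT-3 paper, not stated on this slide). It also assumes that each adapted attention matrix is d_model × d_model, so one (B, A) pair has 2 · d_model · r parameters.

```python
# Assumptions: d_model = 12288 (GPT-3 175B), 96 layers, square attention matrices.
D_MODEL, N_LAYERS = 12288, 96


def lora_param_count(r, n_matrix_types, d_model=D_MODEL, n_layers=N_LAYERS):
    """Trainable LoRA parameters when n_matrix_types attention matrices per layer get rank r."""
    return n_layers * n_matrix_types * 2 * d_model * r


for n_types, r in [(1, 8), (2, 4), (4, 2)]:   # the three budget-matched settings in Table 5
    n = lora_param_count(r, n_types)
    print(n_types, r, n, f"{n * 2 / 2**20:.1f} MiB in FP16")
# Each line reports 18874368 parameters and 36.0 MiB.
```

All three settings cost the same, about 18.9M parameters. This is why halving r while doubling the number of adapted matrix types keeps the budget fixed. In FP16 (2 bytes each), that is about 36 MiB, consistent with the paper's "roughly 35MB".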
  15. Subspace similarity
    For r = 8 and r = 64, apply singular value decomposition and measure how similar the resulting subspaces are, using a measure based on the Grassmann distance:
      $$\phi(A_{r=8}, A_{r=64}, i, j) = \frac{\left\| U_{A_{r=8}}^{i\top} U_{A_{r=64}}^{j} \right\|_F^2}{\min(i, j)} \in [0, 1] \qquad (4)$$
    where $U_{A_{r=8}}^{i}$ represents the columns of $U_{A_{r=8}}$ corresponding to the top-i singular vectors. $\phi(\cdot)$ has a range of [0, 1], where 1 represents a complete overlap of subspaces and 0 a complete separation. See Figure 3 for how φ changes as we vary i and j. We only look at the 48th layer (out of 96) due to space constraints, but the conclusion holds for other layers as well, as shown in Section H.1.
    [Figure 3: Subspace similarity between column vectors of $A_{r=8}$ and $A_{r=64}$ for both ΔWq and ΔWv. The third and the fourth figures zoom in on the lower-left triangle in the first two figures. The top directions in r = 8 are included in r = 64, and vice versa.]
    • The top singular vectors are shared between the two ranks.
    • Training appears to find a task-specific subspace that is low-rank.
    (Results for the 48th layer.)
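Equation (4) can be computed directly from an SVD. Below is a minimal NumPy sketch that takes the right-singular vectors of each A as the columns of U (A is r × k, so its row space lives in the k-dimensional input space). The matrices are synthetic stand-ins, not GPT-3's learned weights. `A_r64` is deliberately constructed so that its row space contains that of `A_r8`, which should give φ = 1.

```python
import numpy as np


def top_right_singular_vectors(A, i):
    """Top-i right-singular vectors of A, as columns: (r, k) -> (k, i)."""
    _, _, Vt = np.linalg.svd(A, full_matrices=False)   # singular values sorted descending
    return Vt[:i].T


def subspace_similarity(A1, A2, i, j):
    """phi(A1, A2, i, j) = ||U1_i^T U2_j||_F^2 / min(i, j), which lies in [0, 1]."""
    U1 = top_right_singular_vectors(A1, i)
    U2 = top_right_singular_vectors(A2, j)
    return np.linalg.norm(U1.T @ U2, "fro") ** 2 / min(i, j)


rng = np.random.default_rng(0)
k = 512
A_r8 = rng.normal(size=(8, k))
A_r64 = np.vstack([A_r8, rng.normal(size=(56, k))])    # row space contains A_r8's row space
print(round(subspace_similarity(A_r8, A_r64, 8, 64), 6))            # 1.0: fully contained
print(subspace_similarity(rng.normal(size=(8, k)), A_r64, 8, 64) < 0.5)  # unrelated: small
```

For two unrelated random subspaces of dimensions 8 and 64 in 512 dimensions, φ is about 64/512 = 0.125 on average. Values near 1 for the top directions, as in Figure 3, therefore indicate genuine overlap rather than chance.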
  16. Comparing the parameter spaces of the base model and LoRA
    How "large" is ΔW compared to its corresponding directions in W? This can shed light on the underlying mechanism for adapting pre-trained language models. To answer these questions, we project W onto the r-dimensional subspace of ΔW by computing $U^\top W V^\top$, with U/V being the left/right singular-vector matrix of ΔW. Then, we compare the Frobenius norm between $\|U^\top W V^\top\|_F$ and $\|W\|_F$. As a comparison, we also compute $\|U^\top W V^\top\|_F$ by replacing U, V with the top r singular vectors of W or a random matrix.
    ‖Uᵀ Wq Vᵀ‖_F           r = 4    r = 64
    (U, V) from ΔWq        0.32     1.90
    (U, V) from Wq         21.67    37.71
    (U, V) random          0.02     0.33
    ‖Wq‖_F = 61.95;  ‖ΔWq‖_F = 6.91 (r = 4);  ‖ΔWq‖_F = 3.57 (r = 64)
    Table 7: The Frobenius norm of $U^\top W_q V^\top$ where U and V are the left/right top r singular vector directions of either (1) ΔWq, (2) Wq, or (3) a random matrix. The weight matrices are taken from the 48th layer of GPT-3.
    We draw several conclusions from Table 7. First, ΔW has a stronger correlation with W compared to a random matrix, indicating that ΔW amplifies some features that are already in W.
    Projecting W onto the subspace of ΔW measures how related the two are.
    • Compared with a random matrix, ΔW is strongly correlated with W ➔ it amplifies some features that W already has.
    • Compared with W's own top directions, ΔW's correlation is much smaller ➔ it does not amplify W's dominant directions.
    • Feature amplification factor $\|\Delta W_q\|_F / \|U^\top W_q V^\top\|_F$: r = 4: 6.91 / 0.32 ≈ 21.6; r = 64: 3.57 / 1.90 ≈ 1.88. Since 21.6 >> 1.88, r = 4 adapts the model to the task more efficiently than r = 64.
    • The base model may already hold the necessary knowledge, and LoRA may be amplifying it as each task requires.
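The amplification factor on this slide can be written as a short function. Below is a minimal NumPy sketch: it takes the top-r left/right singular vectors of ΔW, projects W onto them, and divides ‖ΔW‖_F by the norm of that projection. The function name is hypothetical. Only the Table 7 ratios come from the paper. The sanity check uses a synthetic rank-4 matrix.

```python
import numpy as np


def amplification_factor(W, dW, r):
    """||dW||_F / ||U^T W V^T||_F, with U, V^T = top-r left/right singular vectors of dW."""
    U, _, Vt = np.linalg.svd(dW, full_matrices=False)
    U_r, Vt_r = U[:, :r], Vt[:r]                  # (d, r) and (r, k)
    projected = U_r.T @ W @ Vt_r.T                # W in dW's r-dim subspace: (r, r)
    return np.linalg.norm(dW, "fro") / np.linalg.norm(projected, "fro")


# Ratios from the Table 7 numbers (GPT-3 layer 48, W_q):
print(round(6.91 / 0.32, 1), round(3.57 / 1.90, 2))     # 21.6 1.88

# Sanity check: projecting a rank-4 dW onto its own subspace gives a factor of 1.
rng = np.random.default_rng(0)
dW = rng.normal(size=(32, 4)) @ rng.normal(size=(4, 32))
print(np.isclose(amplification_factor(dW, dW, 4), 1.0))  # True
```

A factor well above 1 means that ΔW is large along directions where W itself is weak. In other words, the update emphasizes features that are present but not dominant in W.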