Upgrade to PRO for Only $50/Year—Limited-Time Offer! 🔥

大規模言語モデル (LLM)における低精度数値表現

大規模言語モデル (LLM)における低精度数値表現

2024年5月8日のHPC研究会で使用したスライドです。
https://www.ipsj.or.jp/kenkyukai/event/hpc194.html

Preferred Networks

May 13, 2024
Tweet

More Decks by Preferred Networks

Other Decks in Technology

Transcript

  1. 2 • 三上 裕明 (みかみ ひろあき) • (主な) 業務: DNN学習等の高速化/分散処理

    ◦ LLMの事前学習高速化 ▪ PLaMo-13Bを公開しました ◦ MN-Core 向けコンパイラの開発 ◦ その他GPUを用いた分散処理の最適化 • 経歴 ◦ 株式会社 Preferred Elements (2023/11 〜) ◦ 株式会社 Preferred Networks (2019/9 〜) ◦ ソニー株式会社 (2017/4 〜 2019/8) ◦ 東京大学大学院 (2015/4 〜 2017/3) 自己紹介
  2. 3 • Preferred Networks (PFN), Preferred Elements (PFE) について •

    LLMと低精度数値表現 • 低精度数値表現の手法 ◦ フォーマット ◦ cast (量子化) 手法 • LLMにおける利用事例・課題 ◦ 学習における事例 ◦ 推論における事例 • まとめ 目次
  3. 5 Preferred Networks(PFN)会社概要 PFNは深層学習などのソフトウェア技術と、それを支える計算インフラなどのハードウェア技術を 融合し、様々な産業領域で最先端技術の実用化・事業化に取り組んでいます。 事業化領域 研究領域 計算インフラ
 機械学習・深層学習 


    シミュレーション
 
 
 画像認識
 
 
 
 自然言語処理
 
 
 
 
 ロボティクス
 
 
 
 最適化
 
 
 製造業
 交通
 システム
 エンタメ
 その他 プラント
 最適化
 材料探索
 創薬・
 ヘルスケア
 ロボット
 
 データ
 生成補完
 
 
 異常検知
 
 

  4. 6 • LLMを含む基盤モデルの研究開発を行うPFNの子会社 • NEDOの採択 (GENIAC, Generative AI Accelerator Challenge)

    を受け マルチモーダル基盤モデルの開発中 Preferred Elements (PFE) 会社概要
  5. 8 Deep Neural Networkにおける低精度数値表現 16bit浮動小数点 一部の最適化したケースで利用 • ResNet-50 • BERT

    • Instant NGP • … 8bit整数 推論高速化で主に利用 [Practical Quantization in PyTorch] [Introduction to Quantization on PyTorch] [NVIDIA TensorRT] 低精度数値表現とは? 32bit未満のフォーマット
  6. 9 Deep Neural Networkにおける低精度数値表現 16bit浮動小数点 一部の最適化したケースで利用 • ResNet-50 • BERT

    • Instant NGP • … ※ 統計はみつかりませんでしたが、PFN社内の状況やQuantizationについてのpytorchのドキュメントをもとにしています 8bit整数 推論高速化で主に利用 [Practical Quantization in PyTorch] [Introduction to Quantization on PyTorch] [NVIDIA TensorRT] LLM以外のDNNではいまだに32bit浮動小数点が主流 ※ 低精度数値表現は特殊な最適化の一つ 低精度数値表現とは? 16bit以下のフォーマット
  7. 10 LLM における状況の変化 BERT (large) [Accelerated Large Batch Optimization of

    BERT Pretraining in 54 minutes] パラメータ数 334M (1.3 GB w/ FP32) 学習時間 (16bit) 1時間 (V100 1536台)
  8. 11 LLM における状況の変化 BERT (large) [Accelerated Large Batch Optimization of

    BERT Pretraining in 54 minutes] パラメータ数 334M (1.3 GB w/ FP32) 学習時間 (16bit) 1時間 (V100 1536台) FP32でも動作する 低精度表現は必須ではない
  9. 12 LLM における状況の変化 BERT (large) LLaMA-65B [Accelerated Large Batch Optimization

    of BERT Pretraining in 54 minutes] パラメータ数 65 B (260GB w/ FP32) 学習時間 (16bit) 3週間 (A100 2048台) [LLaMA: Open and Efficient Foundation Language Models] パラメータ数 334M (1.3 GB w/ FP32) 学習時間 (16bit) 1時間 (V100 1536台) FP32でも動作する 低精度表現は必須ではない
  10. 13 LLM における状況の変化 BERT (large) LLaMA-65B [Accelerated Large Batch Optimization

    of BERT Pretraining in 54 minutes] パラメータ数 65 B (260GB w/ FP32) 学習時間 (16bit) 3週間 (A100 2048台) [LLaMA: Open and Efficient Foundation Language Models] パラメータ数 334M (1.3 GB w/ FP32) 学習時間 (16bit) 1時間 (V100 1536台) FP32でも動作する 低精度表現は必須ではない FP32では現実的な条件で動かない 低精度化は前提となることが多い
  11. 15 低精度数値表現例: 数値フォーマット 浮動小数点 整数 その他 16 bit 1 bit

    2 bit 8 bit BFloat16 FP8 (E4M3 / E5M2) FP6 [FP6-LLM] int8 int4 int2 [HQQ] 3値 [BitNet] NormalFloat4 [QLoRA] Dynamic Tree Quantization [8-bit optimizer] 4 bit FP4
  12. 16 低精度数値表現例: 数値フォーマット 浮動小数点 整数 その他 16 bit 1 bit

    2 bit 8 bit BFloat16 FP8 (E4M3 / E5M2) FP6 [FP6-LLM] int8 int4 int2 [HQQ] 3値 [BitNet] NormalFloat4 [QLoRA] Dynamic Tree Quantization [8-bit optimizer] 4 bit FP4 GPUでの計算高速化に 利用できる
  13. 17 低精度数値表現例: 数値フォーマット 浮動小数点 整数 その他 16 bit 1 bit

    2 bit 8 bit BFloat16 FP8 (E4M3 / E5M2) FP6 [FP6-LLM] int8 int4 int2 [HQQ] 3値 [BitNet] NormalFloat4 [QLoRA] Dynamic Tree Quantization [8-bit optimizer] 4 bit FP4 GPUでの計算高速化に 利用できる 4bit未満の精度では 整数型が一般的
  14. 18 低精度数値表現例: 数値フォーマット 浮動小数点 整数 その他 16 bit 1 bit

    2 bit 8 bit BFloat16 FP8 (E4M3 / E5M2) FP6 [FP6-LLM] int8 int4 int2 [HQQ] 3値 [BitNet] NormalFloat4 [QLoRA] Dynamic Tree Quantization [8-bit optimizer] 4 bit FP4 GPUでの計算高速化に 利用できる 4bit未満の精度では 整数型が一般的 速度が重要でない用途 に向く
  15. 19 低精度数値表現例: 数値フォーマット 浮動小数点 整数 その他 16 bit 1 bit

    2 bit 8 bit BFloat16 FP8 (E4M3 / E5M2) FP6 [FP6-LLM] int8 int4 int2 [HQQ] 3値 [BitNet] NormalFloat4 [QLoRA] Dynamic Tree Quantization [8-bit optimizer] 4 bit FP4
  16. 21 低精度数値表現例: cast手法 (量子化) 高精度 (BFloat16 or FP32) 低精度表現 (~8bit)

    [Mixed Precision Training] - オーバーフロー / アンダーフロー [Mixed Precision Training] - その他の数値誤差
  17. 22 低精度数値表現例: cast手法 (量子化) スケーリング s = max(abs(x)) qx =

    Quantize(x / s) # 値を[-1:1]の範囲にしてから量子化 # ⇒ オーバーフロー/アンダーフローを防ぐ -100 1 100 1000 BFloat16 int8 -100 1 100 -24 -12 0 12 127 scale = 1000 / 127 w/o scaling w/ scaling
  18. 23 低精度数値表現例: cast手法 (量子化) block-wise (fine-grained) 量子化 -1 1 1

    10 per-tensor (coarse-grained) 量子化 0 0 0 1 3値量子化 scale=10 0 0 0 10 復元 -1 1 1 10 block-wise (fine-grained) 量子化 -1 1 0 1 3値量子化, block-size=2 scale=1, 10 -1 -1 0 10 復元
  19. 24 低精度数値表現例: cast手法 (量子化) block-wise (fine-grained) 量子化 -1 1 1

    10 per-tensor (coarse-grained) 量子化 0 0 0 1 3値量子化 scale=10 0 0 0 10 復元 -1 1 1 10 block-wise (fine-grained) 量子化 -1 1 0 1 3値量子化, block-size=2 scale=1, 10 -1 -1 0 10 復元 Tensorごとに単一のscaleを用いる
  20. 25 低精度数値表現例: cast手法 (量子化) block-wise (fine-grained) 量子化 -1 1 1

    10 per-tensor (coarse-grained) 量子化 0 0 0 1 3値量子化 scale=10 0 0 0 10 復元 -1 1 1 10 block-wise (fine-grained) 量子化 -1 1 0 1 3値量子化, block-size=2 scale=1, 10 -1 -1 0 10 復元 一定の要素数ごとにscaleを用意する ⇒ 量子化誤差が小さくなる 4bit以下への量子化で特に重要 [ZeROQuant(4 + 2)] Tensorごとに単一のscaleを用いる
  21. 27 LLMにおける利用事例・課題:学習 Icon pack by Icons8 - https://icons8.com optimizer state

    計算 (fwd + bwd) 作業領域 optimizer state 計算 (fwd + bwd) 作業領域 通信
  22. 28 LLMにおける利用事例・課題:学習 Icon pack by Icons8 - https://icons8.com optimizer state

    計算 (fwd + bwd) 作業領域 optimizer state 計算 (fwd + bwd) 作業領域 通信 計算の高速化 - BF16 TensorCoreの利用 - FP8 TensorCoreの利用
  23. 29 LLMにおける利用事例・課題:学習 Icon pack by Icons8 - https://icons8.com optimizer state

    計算 (fwd + bwd) 作業領域 optimizer state 計算 (fwd + bwd) 作業領域 通信 通信の高速化 - collectiveの量子化 - モデル並列による削減 メモリ消費の削減 - 値の量子化 - モデル並列による削減 計算の高速化 - BF16 TensorCoreの利用 - FP8 TensorCoreの利用
  24. 30 要素数 (Llama2 70Bの場合の一例) M=K=8192の時 :B/F = 5.0 e-4 H100

    w/ FP8 :B/F = 1.7 e-3 ※ LLMにおける計算のボトルネック = 行列積 LLMにおける利用事例・課題:学習 (計算高速化) 行列積 (MatMul, y=Wx) W: M x K x: K x N y: M x N M 8192 ~ 28672 K 8192 ~ 28672 N 4096 ※ B/Fの計算には入出力の総和を利用
  25. 31 要素数 (Llama2 70Bの場合の一例) M=K=8192の時 :B/F = 5.0 e-4 H100

    w/ FP8 :B/F = 1.7 e-3 ※ LLMにおける計算のボトルネック = 行列積 LLMにおける利用事例・課題:学習 (計算高速化) 行列積 (MatMul, y=Wx) W: M x K x: K x N y: M x N TensorCoreの活用によりGPUの計算能力 (FLOP/s) をあげることが重要 M 8192 ~ 28672 K 8192 ~ 28672 N 4096 ※ B/Fの計算には入出力の総和を利用
  26. 33 LLMにおける計算のボトルネック = 行列積 LLMにおける利用事例・課題:学習 (計算高速化) 行列積 (MatMul, y=Wx) W:

    M x K x: K x N y: M x N cast (BFloat16 → FP8) cast (BFloat16 → FP8) - per-tensor scaling - forwardではE4M3をbackwardではE5M2を使う [FP8 Formats for Deep Learning]
  27. 34 LLMにおける計算のボトルネック = 行列積 LLMにおける利用事例・課題:学習 (計算高速化) 行列積 (MatMul, y=Wx) W:

    M x K x: K x N y: M x N cast (BFloat16 → FP8) cast (BFloat16 → FP8) - per-tensor scaling - forwardではE4M3をbackwardではE5M2を使う [FP8 Formats for Deep Learning] qx = Cast(x / s_prev, to=fp8) s_prev = max(abs(x)) n_amax_history = 1の時の疑似コード Delayed Scaling [Transformer Engine] scale factorとして過去の結果を使う ⇒ メモリアクセスを1回省略できる
  28. 35 LLMにおける計算のボトルネック = 行列積 LLMにおける利用事例・課題:学習 (計算高速化) 行列積 (MatMul, y=Wx) W:

    M x K x: K x N y: M x N cast (BFloat16 → FP8) cast (BFloat16 → FP8) M=K=8192, N=4096の時の性能 (H100 SXM + Transformer Engine) 処理 実行時間 [us] 行列積 (FP8) 451 Cast (W) 67 Cast (x) 34 行列積 350 行列積 (BF16) 670
  29. 36 LLMにおける計算のボトルネック = 行列積 LLMにおける利用事例・課題:学習 (計算高速化) 行列積 (MatMul, y=Wx) W:

    M x K x: K x N y: M x N cast (BFloat16 → FP8) cast (BFloat16 → FP8) M=K=8192, N=4096の時の性能 (H100 SXM + Transformer Engine) 処理 実行時間 [us] 行列積 (FP8) 451 Cast (W) 67 Cast (x) 34 行列積 350 行列積 (BF16) 670 BFloat16を使う処理の時間が無視できない
  30. 37 LLMにおける利用事例・課題:推論 Icon pack by Icons8 - https://icons8.com パラメータ KV

    キャッシュ KV キャッシュ KV キャッシュ 生成token 生成token 生成token
  31. 38 LLMにおける利用事例・課題:推論 Icon pack by Icons8 - https://icons8.com パラメータ KV

    キャッシュ KV キャッシュ KV キャッシュ 生成token 生成token 生成token 計算の高速化・効率化 - パラメータの量子化 - モデルサイズの削減
  32. 39 LLMにおける利用事例・課題:推論 Icon pack by Icons8 - https://icons8.com パラメータ KV

    キャッシュ KV キャッシュ KV キャッシュ 生成token 生成token 生成token 省メモリ化 - パラメータの量子化 - KVキャッシュの量子化 - モデル並列による推論 計算の高速化・効率化 - パラメータの量子化 - モデルサイズの削減
  33. 41 Mixed Precision Decomposition [LLM.int8()] LLMにおける利用事例・課題:推論 (省メモリ化) 学習済LLMでは一部のchannelだけ scaleが大きくなる [KVQuant]

    [AWQ] 外れ値のchannelだけは16bitのまま 保持することで性能を維持する どうやって外れ値をみつけるか? - 適当な閾値を決める [LLM.int8()] - 小さいデータセットを流して計算途中の値を使う [AWQ] (data dependent quantization)
  34. 42 パラメータ LLMにおける利用事例・課題:推論 (省メモリ化) KV キャッシュ 処理 scaling format mixed-prec.

    decomposition bitsandbytes block-wise NormalFloat4 no LLM.int8() block-wise int8 yes AWQ block-wise int4 yes HQQ block-wise int1/int2/int4 no 処理 scaling format mixed-prec. decomposition FlexGen block-wise int4 no KVQuant block-wise int3 yes
  35. 44 ロスレス符号化 [QMoE] LLMにおける利用事例・課題:推論 (省メモリ化) 高精度 (BFloat16) 低精度表現 (3値 or

    2bit) 量子化 量子化は多くの場合性能と消費メモリのトレードオフ で優れている [Pruning vs Quantization]
  36. 45 ロスレス符号化 [QMoE] LLMにおける利用事例・課題:推論 (省メモリ化) 高精度 (BFloat16) 低精度表現 (3値 or

    2bit) 量子化 量子化は多くの場合性能と消費メモリのトレードオフ で優れている [Pruning vs Quantization] 符号化 可逆圧縮 - 3値からのさらなる量子化は限界がある - アンダーフローにより0の割合は増えていく 可逆圧縮によりメモリ消費の削減を狙う (0.80 ~ 0.93 bit/要素にできる [QMoE])
  37. 46 ロスレス符号化 [QMoE] LLMにおける利用事例・課題:推論 (省メモリ化) モデル Bfloat16 HQQ (int4) +zstd

    HQQ (3値) +zstd 3B 5.6GiB 1.8GiB (31%) 1.5 GiB (27%) 1.1GiB (19%) 0.64GiB (11%) 10B (MoE) 21 GiB 6.7 GiB (31%) 5.4 GiB (25%) 4.0GiB (19%) 2.2GiB (10%) ※ bit幅以外の設定はhqqのデフォルトを使用
  38. 47 ロスレス符号化 [QMoE] LLMにおける利用事例・課題:推論 (省メモリ化) モデル Bfloat16 HQQ (int4) +zstd

    HQQ (3値) +zstd 3B 5.6GiB 1.8GiB (31%) 1.5 GiB (27%) 1.1GiB (19%) 0.64GiB (11%) 10B (MoE) 21 GiB 6.7 GiB (31%) 5.4 GiB (25%) 4.0GiB (19%) 2.2GiB (10%) scale factorのサイズは無視できない割合をしめる ※ bit幅以外の設定はhqqのデフォルトを使用
  39. 48 ロスレス符号化 [QMoE] LLMにおける利用事例・課題:推論 (省メモリ化) モデル Bfloat16 HQQ (int4) +zstd

    HQQ (3値) +zstd 3B 5.6GiB 1.8GiB (31%) 1.5 GiB (27%) 1.1GiB (19%) 0.64GiB (11%) 10B (MoE) 21 GiB 6.7 GiB (31%) 5.4 GiB (25%) 4.0GiB (19%) 2.2GiB (10%) ※ bit幅以外の設定はhqqのデフォルトを使用
  40. 49 ロスレス符号化 [QMoE] LLMにおける利用事例・課題:推論 (省メモリ化) モデル Bfloat16 HQQ (int4) +zstd

    HQQ (3値) +zstd 3B 5.6GiB 1.8GiB (31%) 1.5 GiB (27%) 1.1GiB (19%) 0.64GiB (11%) 10B (MoE) 21 GiB 6.7 GiB (31%) 5.4 GiB (25%) 4.0GiB (19%) 2.2GiB (10%) - 汎用的な圧縮アルゴリズムは効果が薄い - DNNモデルの種類の影響はみえない ※ bit幅以外の設定はhqqのデフォルトを使用
  41. 51 • LLMでは様々な低精度数値表現が提案・利用されている ◦ 学習: 8bitでの計算が主流となりつつある ◦ 推論: 1~2bit表現が実用化されつつある •

    成熟した技術ではなく、多数の課題が残っている ◦ 学習: 行列積以外の処理をボトルネックにしない方法 ◦ 推論: ▪ 1要素あたり1bit以下での保存 ▪ 低精度化したあとのLLMの精度評価 まとめ