Slide 1

Slide 1 text

大規模言語モデル (LLM) における低精度数値表現 株式会社 Preferred Networks リサーチャー 三上 裕明

Slide 2

Slide 2 text

2 ● 三上 裕明 (みかみ ひろあき) ● (主な) 業務: DNN学習等の高速化/分散処理 ○ LLMの事前学習高速化 ■ PLaMo-13Bを公開しました ○ MN-Core 向けコンパイラの開発 ○ その他GPUを用いた分散処理の最適化 ● 経歴 ○ 株式会社 Preferred Elements (2023/11 〜) ○ 株式会社 Preferred Networks (2019/9 〜) ○ ソニー株式会社 (2017/4 〜 2019/8) ○ 東京大学大学院 (2015/4 〜 2017/3) 自己紹介

Slide 3

Slide 3 text

3 ● Preferred Networks (PFN), Preferred Elements (PFE) について ● LLMと低精度数値表現 ● 低精度数値表現の手法 ○ フォーマット ○ cast (量子化) 手法 ● LLMにおける利用事例・課題 ○ 学習における事例 ○ 推論における事例 ● まとめ 目次

Slide 4

Slide 4 text

4 Preferred Networks, Preferred Elementsについて

Slide 5

Slide 5 text

5 Preferred Networks(PFN)会社概要 PFNは深層学習などのソフトウェア技術と、それを支える計算インフラなどのハードウェア技術を 融合し、様々な産業領域で最先端技術の実用化・事業化に取り組んでいます。 事業化領域 研究領域 計算インフラ
 機械学習・深層学習 
 シミュレーション
 
 
 画像認識
 
 
 
 自然言語処理
 
 
 
 
 ロボティクス
 
 
 
 最適化
 
 
 製造業
 交通
 システム
 エンタメ
 その他 プラント
 最適化
 材料探索
 創薬・
 ヘルスケア
 ロボット
 
 データ
 生成補完
 
 
 異常検知
 
 


Slide 6

Slide 6 text

6 ● LLMを含む基盤モデルの研究開発を行うPFNの子会社 ● NEDOの採択 (GENIAC, Generative AI Accelerator Challenge) を受け マルチモーダル基盤モデルの開発中 Preferred Elements (PFE) 会社概要

Slide 7

Slide 7 text

7 LLMと低精度数値表現

Slide 8

Slide 8 text

8 Deep Neural Networkにおける低精度数値表現 16bit浮動小数点 一部の最適化したケースで利用 ● ResNet-50 ● BERT ● Instant NGP ● … 8bit整数 推論高速化で主に利用 [Practical Quantization in PyTorch] [Introduction to Quantization on PyTorch] [NVIDIA TensorRT] 低精度数値表現とは? 32bit未満のフォーマット

Slide 9

Slide 9 text

9 Deep Neural Networkにおける低精度数値表現 16bit浮動小数点 一部の最適化したケースで利用 ● ResNet-50 ● BERT ● Instant NGP ● … ※ 統計はみつかりませんでしたが、PFN社内の状況やQuantizationについてのpytorchのドキュメントをもとにしています 8bit整数 推論高速化で主に利用 [Practical Quantization in PyTorch] [Introduction to Quantization on PyTorch] [NVIDIA TensorRT] LLM以外のDNNではいまだに32bit浮動小数点が主流 ※ 低精度数値表現は特殊な最適化の一つ 低精度数値表現とは? 16bit以下のフォーマット

Slide 10

Slide 10 text

10 LLM における状況の変化 BERT (large) [Accelerated Large Batch Optimization of BERT Pretraining in 54 minutes] パラメータ数 334M (1.3 GB w/ FP32) 学習時間 (16bit) 1時間 (V100 1536台)

Slide 11

Slide 11 text

11 LLM における状況の変化 BERT (large) [Accelerated Large Batch Optimization of BERT Pretraining in 54 minutes] パラメータ数 334M (1.3 GB w/ FP32) 学習時間 (16bit) 1時間 (V100 1536台) FP32でも動作する 低精度表現は必須ではない

Slide 12

Slide 12 text

12 LLM における状況の変化 BERT (large) LLaMA-65B [Accelerated Large Batch Optimization of BERT Pretraining in 54 minutes] パラメータ数 65 B (260GB w/ FP32) 学習時間 (16bit) 3週間 (A100 2048台) [LLaMA: Open and Efficient Foundation Language Models] パラメータ数 334M (1.3 GB w/ FP32) 学習時間 (16bit) 1時間 (V100 1536台) FP32でも動作する 低精度表現は必須ではない

Slide 13

Slide 13 text

13 LLM における状況の変化 BERT (large) LLaMA-65B [Accelerated Large Batch Optimization of BERT Pretraining in 54 minutes] パラメータ数 65 B (260GB w/ FP32) 学習時間 (16bit) 3週間 (A100 2048台) [LLaMA: Open and Efficient Foundation Language Models] パラメータ数 334M (1.3 GB w/ FP32) 学習時間 (16bit) 1時間 (V100 1536台) FP32でも動作する 低精度表現は必須ではない FP32では現実的な条件で動かない 低精度化は前提となることが多い

Slide 14

Slide 14 text

14 低精度数値表現の手法

Slide 15

Slide 15 text

15 低精度数値表現例: 数値フォーマット 浮動小数点 整数 その他 16 bit 1 bit 2 bit 8 bit BFloat16 FP8 (E4M3 / E5M2) FP6 [FP6-LLM] int8 int4 int2 [HQQ] 3値 [BitNet] NormalFloat4 [QLoRA] Dynamic Tree Quantization [8-bit optimizer] 4 bit FP4

Slide 16

Slide 16 text

16 低精度数値表現例: 数値フォーマット 浮動小数点 整数 その他 16 bit 1 bit 2 bit 8 bit BFloat16 FP8 (E4M3 / E5M2) FP6 [FP6-LLM] int8 int4 int2 [HQQ] 3値 [BitNet] NormalFloat4 [QLoRA] Dynamic Tree Quantization [8-bit optimizer] 4 bit FP4 GPUでの計算高速化に 利用できる

Slide 17

Slide 17 text

17 低精度数値表現例: 数値フォーマット 浮動小数点 整数 その他 16 bit 1 bit 2 bit 8 bit BFloat16 FP8 (E4M3 / E5M2) FP6 [FP6-LLM] int8 int4 int2 [HQQ] 3値 [BitNet] NormalFloat4 [QLoRA] Dynamic Tree Quantization [8-bit optimizer] 4 bit FP4 GPUでの計算高速化に 利用できる 4bit未満の精度では 整数型が一般的

Slide 18

Slide 18 text

18 低精度数値表現例: 数値フォーマット 浮動小数点 整数 その他 16 bit 1 bit 2 bit 8 bit BFloat16 FP8 (E4M3 / E5M2) FP6 [FP6-LLM] int8 int4 int2 [HQQ] 3値 [BitNet] NormalFloat4 [QLoRA] Dynamic Tree Quantization [8-bit optimizer] 4 bit FP4 GPUでの計算高速化に 利用できる 4bit未満の精度では 整数型が一般的 速度が重要でない用途 に向く

Slide 19

Slide 19 text

19 低精度数値表現例: 数値フォーマット 浮動小数点 整数 その他 16 bit 1 bit 2 bit 8 bit BFloat16 FP8 (E4M3 / E5M2) FP6 [FP6-LLM] int8 int4 int2 [HQQ] 3値 [BitNet] NormalFloat4 [QLoRA] Dynamic Tree Quantization [8-bit optimizer] 4 bit FP4

Slide 20

Slide 20 text

20 低精度数値表現例: cast手法 (量子化) 高精度 (BFloat16 or FP32) 低精度表現 (~8bit) [Mixed Precision Training]

Slide 21

Slide 21 text

21 低精度数値表現例: cast手法 (量子化) 高精度 (BFloat16 or FP32) 低精度表現 (~8bit) [Mixed Precision Training] - オーバーフロー / アンダーフロー [Mixed Precision Training] - その他の数値誤差

Slide 22

Slide 22 text

22 低精度数値表現例: cast手法 (量子化) スケーリング s = max(abs(x)) qx = Quantize(x / s) # 値を[-1:1]の範囲にしてから量子化 # ⇒ オーバーフロー/アンダーフローを防ぐ -100 1 100 1000 BFloat16 int8 -100 1 100 -24 -12 0 12 127 scale = 1000 / 127 w/o scaling w/ scaling

Slide 23

Slide 23 text

23 低精度数値表現例: cast手法 (量子化) block-wise (fine-grained) 量子化 -1 1 1 10 per-tensor (coarse-grained) 量子化 0 0 0 1 3値量子化 scale=10 0 0 0 10 復元 -1 1 1 10 block-wise (fine-grained) 量子化 -1 1 0 1 3値量子化, block-size=2 scale=1, 10 -1 -1 0 10 復元

Slide 24

Slide 24 text

24 低精度数値表現例: cast手法 (量子化) block-wise (fine-grained) 量子化 -1 1 1 10 per-tensor (coarse-grained) 量子化 0 0 0 1 3値量子化 scale=10 0 0 0 10 復元 -1 1 1 10 block-wise (fine-grained) 量子化 -1 1 0 1 3値量子化, block-size=2 scale=1, 10 -1 -1 0 10 復元 Tensorごとに単一のscaleを用いる

Slide 25

Slide 25 text

25 低精度数値表現例: cast手法 (量子化) block-wise (fine-grained) 量子化 -1 1 1 10 per-tensor (coarse-grained) 量子化 0 0 0 1 3値量子化 scale=10 0 0 0 10 復元 -1 1 1 10 block-wise (fine-grained) 量子化 -1 1 0 1 3値量子化, block-size=2 scale=1, 10 -1 -1 0 10 復元 一定の要素数ごとにscaleを用意する ⇒ 量子化誤差が小さくなる 4bit以下への量子化で特に重要 [ZeROQuant(4 + 2)] Tensorごとに単一のscaleを用いる

Slide 26

Slide 26 text

26 LLMにおける利用事例・課題

Slide 27

Slide 27 text

27 LLMにおける利用事例・課題:学習 Icon pack by Icons8 - https://icons8.com optimizer state 計算 (fwd + bwd) 作業領域 optimizer state 計算 (fwd + bwd) 作業領域 通信

Slide 28

Slide 28 text

28 LLMにおける利用事例・課題:学習 Icon pack by Icons8 - https://icons8.com optimizer state 計算 (fwd + bwd) 作業領域 optimizer state 計算 (fwd + bwd) 作業領域 通信 計算の高速化 - BF16 TensorCoreの利用 - FP8 TensorCoreの利用

Slide 29

Slide 29 text

29 LLMにおける利用事例・課題:学習 Icon pack by Icons8 - https://icons8.com optimizer state 計算 (fwd + bwd) 作業領域 optimizer state 計算 (fwd + bwd) 作業領域 通信 通信の高速化 - collectiveの量子化 - モデル並列による削減 メモリ消費の削減 - 値の量子化 - モデル並列による削減 計算の高速化 - BF16 TensorCoreの利用 - FP8 TensorCoreの利用

Slide 30

Slide 30 text

30 要素数 (Llama2 70Bの場合の一例) M=K=8192の時 :B/F = 5.0 e-4 H100 w/ FP8 :B/F = 1.7 e-3 ※ LLMにおける計算のボトルネック = 行列積 LLMにおける利用事例・課題:学習 (計算高速化) 行列積 (MatMul, y=Wx) W: M x K x: K x N y: M x N M 8192 ~ 28672 K 8192 ~ 28672 N 4096 ※ B/Fの計算には入出力の総和を利用

Slide 31

Slide 31 text

31 要素数 (Llama2 70Bの場合の一例) M=K=8192の時 :B/F = 5.0 e-4 H100 w/ FP8 :B/F = 1.7 e-3 ※ LLMにおける計算のボトルネック = 行列積 LLMにおける利用事例・課題:学習 (計算高速化) 行列積 (MatMul, y=Wx) W: M x K x: K x N y: M x N TensorCoreの活用によりGPUの計算能力 (FLOP/s) をあげることが重要 M 8192 ~ 28672 K 8192 ~ 28672 N 4096 ※ B/Fの計算には入出力の総和を利用

Slide 32

Slide 32 text

32 LLMにおける計算のボトルネック = 行列積 LLMにおける利用事例・課題:学習 (計算高速化) 行列積 (MatMul, y=Wx) W: M x K x: K x N y: M x N cast (BFloat16 → FP8) cast (BFloat16 → FP8)

Slide 33

Slide 33 text

33 LLMにおける計算のボトルネック = 行列積 LLMにおける利用事例・課題:学習 (計算高速化) 行列積 (MatMul, y=Wx) W: M x K x: K x N y: M x N cast (BFloat16 → FP8) cast (BFloat16 → FP8) - per-tensor scaling - forwardではE4M3をbackwardではE5M2を使う [FP8 Formats for Deep Learning]

Slide 34

Slide 34 text

34 LLMにおける計算のボトルネック = 行列積 LLMにおける利用事例・課題:学習 (計算高速化) 行列積 (MatMul, y=Wx) W: M x K x: K x N y: M x N cast (BFloat16 → FP8) cast (BFloat16 → FP8) - per-tensor scaling - forwardではE4M3をbackwardではE5M2を使う [FP8 Formats for Deep Learning] qx = Cast(x / s_prev, to=fp8) s_prev = max(abs(x)) n_amax_history = 1の時の疑似コード Delayed Scaling [Transformer Engine] scale factorとして過去の結果を使う ⇒ メモリアクセスを1回省略できる

Slide 35

Slide 35 text

35 LLMにおける計算のボトルネック = 行列積 LLMにおける利用事例・課題:学習 (計算高速化) 行列積 (MatMul, y=Wx) W: M x K x: K x N y: M x N cast (BFloat16 → FP8) cast (BFloat16 → FP8) M=K=8192, N=4096の時の性能 (H100 SXM + Transformer Engine) 処理 実行時間 [us] 行列積 (FP8) 451 Cast (W) 67 Cast (x) 34 行列積 350 行列積 (BF16) 670

Slide 36

Slide 36 text

36 LLMにおける計算のボトルネック = 行列積 LLMにおける利用事例・課題:学習 (計算高速化) 行列積 (MatMul, y=Wx) W: M x K x: K x N y: M x N cast (BFloat16 → FP8) cast (BFloat16 → FP8) M=K=8192, N=4096の時の性能 (H100 SXM + Transformer Engine) 処理 実行時間 [us] 行列積 (FP8) 451 Cast (W) 67 Cast (x) 34 行列積 350 行列積 (BF16) 670 BFloat16を使う処理の時間が無視できない

Slide 37

Slide 37 text

37 LLMにおける利用事例・課題:推論 Icon pack by Icons8 - https://icons8.com パラメータ KV キャッシュ KV キャッシュ KV キャッシュ 生成token 生成token 生成token

Slide 38

Slide 38 text

38 LLMにおける利用事例・課題:推論 Icon pack by Icons8 - https://icons8.com パラメータ KV キャッシュ KV キャッシュ KV キャッシュ 生成token 生成token 生成token 計算の高速化・効率化 - パラメータの量子化 - モデルサイズの削減

Slide 39

Slide 39 text

39 LLMにおける利用事例・課題:推論 Icon pack by Icons8 - https://icons8.com パラメータ KV キャッシュ KV キャッシュ KV キャッシュ 生成token 生成token 生成token 省メモリ化 - パラメータの量子化 - KVキャッシュの量子化 - モデル並列による推論 計算の高速化・効率化 - パラメータの量子化 - モデルサイズの削減

Slide 40

Slide 40 text

40 Mixed Precision Decomposition [LLM.int8()] LLMにおける利用事例・課題:推論 (省メモリ化) 学習済LLMでは一部のchannelだけ scaleが大きくなる [KVQuant] [AWQ] 外れ値のchannelだけは16bitのまま 保持することで性能を維持する

Slide 41

Slide 41 text

41 Mixed Precision Decomposition [LLM.int8()] LLMにおける利用事例・課題:推論 (省メモリ化) 学習済LLMでは一部のchannelだけ scaleが大きくなる [KVQuant] [AWQ] 外れ値のchannelだけは16bitのまま 保持することで性能を維持する どうやって外れ値をみつけるか? - 適当な閾値を決める [LLM.int8()] - 小さいデータセットを流して計算途中の値を使う [AWQ] (data dependent quantization)

Slide 42

Slide 42 text

42 パラメータ LLMにおける利用事例・課題:推論 (省メモリ化) KV キャッシュ 処理 scaling format mixed-prec. decomposition bitsandbytes block-wise NormalFloat4 no LLM.int8() block-wise int8 yes AWQ block-wise int4 yes HQQ block-wise int1/int2/int4 no 処理 scaling format mixed-prec. decomposition FlexGen block-wise int4 no KVQuant block-wise int3 yes

Slide 43

Slide 43 text

43 ロスレス符号化 [QMoE] LLMにおける利用事例・課題:推論 (省メモリ化) 高精度 (BFloat16) 低精度表現 (3値 or 2bit) 量子化

Slide 44

Slide 44 text

44 ロスレス符号化 [QMoE] LLMにおける利用事例・課題:推論 (省メモリ化) 高精度 (BFloat16) 低精度表現 (3値 or 2bit) 量子化 量子化は多くの場合性能と消費メモリのトレードオフ で優れている [Pruning vs Quantization]

Slide 45

Slide 45 text

45 ロスレス符号化 [QMoE] LLMにおける利用事例・課題:推論 (省メモリ化) 高精度 (BFloat16) 低精度表現 (3値 or 2bit) 量子化 量子化は多くの場合性能と消費メモリのトレードオフ で優れている [Pruning vs Quantization] 符号化 可逆圧縮 - 3値からのさらなる量子化は限界がある - アンダーフローにより0の割合は増えていく 可逆圧縮によりメモリ消費の削減を狙う (0.80 ~ 0.93 bit/要素にできる [QMoE])

Slide 46

Slide 46 text

46 ロスレス符号化 [QMoE] LLMにおける利用事例・課題:推論 (省メモリ化) モデル Bfloat16 HQQ (int4) +zstd HQQ (3値) +zstd 3B 5.6GiB 1.8GiB (31%) 1.5 GiB (27%) 1.1GiB (19%) 0.64GiB (11%) 10B (MoE) 21 GiB 6.7 GiB (31%) 5.4 GiB (25%) 4.0GiB (19%) 2.2GiB (10%) ※ bit幅以外の設定はhqqのデフォルトを使用

Slide 47

Slide 47 text

47 ロスレス符号化 [QMoE] LLMにおける利用事例・課題:推論 (省メモリ化) モデル Bfloat16 HQQ (int4) +zstd HQQ (3値) +zstd 3B 5.6GiB 1.8GiB (31%) 1.5 GiB (27%) 1.1GiB (19%) 0.64GiB (11%) 10B (MoE) 21 GiB 6.7 GiB (31%) 5.4 GiB (25%) 4.0GiB (19%) 2.2GiB (10%) scale factorのサイズは無視できない割合をしめる ※ bit幅以外の設定はhqqのデフォルトを使用

Slide 48

Slide 48 text

48 ロスレス符号化 [QMoE] LLMにおける利用事例・課題:推論 (省メモリ化) モデル Bfloat16 HQQ (int4) +zstd HQQ (3値) +zstd 3B 5.6GiB 1.8GiB (31%) 1.5 GiB (27%) 1.1GiB (19%) 0.64GiB (11%) 10B (MoE) 21 GiB 6.7 GiB (31%) 5.4 GiB (25%) 4.0GiB (19%) 2.2GiB (10%) ※ bit幅以外の設定はhqqのデフォルトを使用

Slide 49

Slide 49 text

49 ロスレス符号化 [QMoE] LLMにおける利用事例・課題:推論 (省メモリ化) モデル Bfloat16 HQQ (int4) +zstd HQQ (3値) +zstd 3B 5.6GiB 1.8GiB (31%) 1.5 GiB (27%) 1.1GiB (19%) 0.64GiB (11%) 10B (MoE) 21 GiB 6.7 GiB (31%) 5.4 GiB (25%) 4.0GiB (19%) 2.2GiB (10%) - 汎用的な圧縮アルゴリズムは効果が薄い - DNNモデルの種類の影響はみえない ※ bit幅以外の設定はhqqのデフォルトを使用

Slide 50

Slide 50 text

50 まとめ

Slide 51

Slide 51 text

51 ● LLMでは様々な低精度数値表現が提案・利用されている ○ 学習: 8bitでの計算が主流となりつつある ○ 推論: 1~2bit表現が実用化されつつある ● 成熟した技術ではなく、多数の課題が残っている ○ 学習: 行列積以外の処理をボトルネックにしない方法 ○ 推論: ■ 1要素あたり1bit以下での保存 ■ 低精度化したあとのLLMの精度評価 まとめ

Slide 52

Slide 52 text

Making the real world computable