w/ FP8 :B/F = 1.7 e-3 ※ LLMにおける計算のボトルネック = 行列積 LLMにおける利用事例・課題:学習 (計算高速化) 行列積 (MatMul, y=Wx) W: M x K x: K x N y: M x N M 8192 ~ 28672 K 8192 ~ 28672 N 4096 ※ B/Fの計算には入出力の総和を利用
w/ FP8 :B/F = 1.7 e-3 ※ LLMにおける計算のボトルネック = 行列積 LLMにおける利用事例・課題:学習 (計算高速化) 行列積 (MatMul, y=Wx) W: M x K x: K x N y: M x N TensorCoreの活用によりGPUの計算能力 (FLOP/s) をあげることが重要 M 8192 ~ 28672 K 8192 ~ 28672 N 4096 ※ B/Fの計算には入出力の総和を利用
M x K x: K x N y: M x N cast (BFloat16 → FP8) cast (BFloat16 → FP8) - per-tensor scaling - forwardではE4M3をbackwardではE5M2を使う [FP8 Formats for Deep Learning]
M x K x: K x N y: M x N cast (BFloat16 → FP8) cast (BFloat16 → FP8) - per-tensor scaling - forwardではE4M3をbackwardではE5M2を使う [FP8 Formats for Deep Learning] qx = Cast(x / s_prev, to=fp8) s_prev = max(abs(x)) n_amax_history = 1の時の疑似コード Delayed Scaling [Transformer Engine] scale factorとして過去の結果を使う ⇒ メモリアクセスを1回省略できる
M x K x: K x N y: M x N cast (BFloat16 → FP8) cast (BFloat16 → FP8) M=K=8192, N=4096の時の性能 (H100 SXM + Transformer Engine) 処理 実行時間 [us] 行列積 (FP8) 451 Cast (W) 67 Cast (x) 34 行列積 350 行列積 (BF16) 670
M x K x: K x N y: M x N cast (BFloat16 → FP8) cast (BFloat16 → FP8) M=K=8192, N=4096の時の性能 (H100 SXM + Transformer Engine) 処理 実行時間 [us] 行列積 (FP8) 451 Cast (W) 67 Cast (x) 34 行列積 350 行列積 (BF16) 670 BFloat16を使う処理の時間が無視できない