Slide 1

Slide 1 text

Copyright © RevComm Inc. LoRAによるメモリ使用量削減の検証 Akihiro Katsuta Works Applications Co.,Ltd Masaki Ono RevComm Inc.

Slide 2

Slide 2 text

Copyright © RevComm Inc. LoRA: Low-Rank Adaptation of Large Language Models Edward Hu, Yelong Shen, Phillip Wallis, Zeyuan Allen-Zhu, Yuanzhi Li, Shean Wang, Lu Wang ,Weizhu Chen https://arxiv.org/abs/2106.09685 LoRAとは ● 再学習するベースモデルのパラメータを固定し、fine-tuning用に学習可能な階数分解行列(rank decomposition matrices)をモデルに挿入 ● モデルのパラメータを更新する従来のfine-tuningと比較して ○ 更新するパラメータ数を削減し、GPU要件を下げられる ○ 差分のみ保存すればよく、ストレージ要件も下げられる ○ かつ、同等の精度を出せる 2

Slide 3

Slide 3 text

Copyright © RevComm Inc. その他のメモリ削減 ● Mixed Precision (FP16) ○ 学習は32ビット浮動小数点演算(FP32)を使っているが、代替可能な部分を16ビットの低精度 dtypeに変更することでメモリや計算時間を削減 ○ しかし、若干の性能低下がネック ● Model Quantization (INT8) ○ ディープラーニングの最適化手法の一つで浮動小数点を8ビットの整数に変換 ○ 性能を維持しつつメモリの削減が可能 peftを使った導入もさほど難しくない 3

Slide 4

Slide 4 text

Copyright © RevComm Inc. 計算資源の選定基準 今回の実験で使う環境 ● AWS G5.xlarge ○ RAM: 16GB ○ GPU: NVIDIA A10G(24GB) 流行りのChatGPTと比べてG5.xlargeで実験を完結できている分にはコスパが良さそう ● ChatGPT: 520K [token/$(USD)]、 ● G5.xlarge 日本語T5 model (1GB): 5000K [token/$(USD)] 4

Slide 5

Slide 5 text

Copyright © RevComm Inc. ● 学習・評価データセット:SAMSum Corpus ○ 対話の抽象型要約データセット ● 学習モデル:flan-t5 ○ base (約 1GB, 250M param) ○ large (約 3 GB, 780M param) ○ xl (約 11.5 GB, 3B param) ● 学習時間やメモリ使用量を測るため、ハイパラは以下で固定 ○ epoch: 3 ○ train batch: 8 ○ = (train step: 5526) 実験設定 5

Slide 6

Slide 6 text

Copyright © RevComm Inc. ● flan-t5-baseのFine-TuningとLoRAを比較して、精度はほとんど変わらずGPUメモリ が約2割まで減らせている ● LoRAとfp16やint8を組み合わせるとよりメモリ効率が上がる ● LoRAを使うことで24GBのGPUでもxlサイズのモデルまで学習できる メモリ使用量削減の検証 言語モデル ROUGE-1(↑) ROUGE-2(↑) 学習時間 GPU Memory CPU Memory flan-t5-base-FT 51.3911 26.7651 40.27 min 14830 MB 4873 MB flan-t5-base-LoRA 51.6928 27.0978 49.11 min 3298 MB 4781 MB flan-t5-large-LoRA 53.8242 28.7078 139.36 min 6175 MB 3036 MB flan-t5-xl-LoRA 54.3119 30.1148 350.24 min 15491 MB 5943 MB flan-t5-xl-LoRA-fp16 53.8726 29.7212 224.69 min 10861 MB 8142 MB flan-t5-xl-LoRA-int8 54.6110 30.3880 331.53 min 9251 MB 8296 MB 6

Slide 7

Slide 7 text

Copyright © RevComm Inc. LoRAのr(次元)を4,8,16でそれぞれ15epoch程学習をさせ、その傾向を比較 ● rが小さい方がこのタスクでは過学習が抑えられるためか良さそう 付録: LoRAのハイパラ r=4 r=16 r=8 r=8 r=4 r=16 r=4 r=8 r=16 7

Slide 8

Slide 8 text

Copyright © RevComm Inc. ● GPUメモリ使用量の削減としてLoRAなどの検証を行った ● LoRA + int8などを組み合わせることでよりGPUメモリ要件を下げられる ● 所感として導入も楽でこれだけ削減できるのでかなり使い勝手は良いように思う まとめ 8

Slide 9

Slide 9 text

Copyright © RevComm Inc. Thank you!