

Learning How LLM Fine-Tuning Works Through an Implementation in Go / fukuokago22-llm-peft


monochromegane

August 25, 2025



Transcript

  1. Yusuke Miyake / Pepabo R&D Institute, GMO Pepabo, Inc. 2025.08.25 Fukuoka.go#22
     Learning How LLM Fine-Tuning Works Through an Implementation in Go
  2. LLM model serving
     • On the other hand, the GPU specs demanded for serving a model as an inference API remain high
     • → Request-processing performance evaluation for serving the gpt-oss model: a comparison of NVIDIA H100, A100, and L4
     Instance types:
     • a3-highgpu-1g (26 vCPU, 234 GB memory, 1 × NVIDIA H100 80GB)
     • a2-ultragpu-1g (12 vCPU, 170 GB memory, 1 × NVIDIA A100 80GB)
     • g2-standard-4 (4 vCPU, 16 GB memory, 1 × NVIDIA L4 24GB)
     https://rand.pepabo.com/article/2025/08/18/gpt-oss/
  3. Low-rank adaptation (LoRA)
     • For efficient fine-tuning, reduce the number of trainable parameters
     • From 4x to x + 3y (x ≫ y), where x and y are the data sizes of the parameters, not parameter counts
     • Define the weight W with an added update term formed by the product of two low-rank matrices B and A:
       W′ ≜ W + sBA,  B ∈ ℝ^(m×r), A ∈ ℝ^(r×n), W ∈ ℝ^(m×n), s ∈ ℝ
     • Trainable parameters: 100,352 (m = 784, n = 128) → 7,296 = 6,272 + 1,024 (m = 784, n = 128, r = 8)
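The definition above can be sketched in Go with plain slices. This is a didactic sketch, not the talk's gonum-based code; the function name `applyLoRA` and the row-major layout are my own illustrative choices.

```go
package main

import "fmt"

// applyLoRA computes W' = W + s*B*A over row-major flat slices:
// W is m×n, B is m×r, A is r×n. Only B and A would be trained;
// W stays frozen.
func applyLoRA(W, B, A []float64, m, n, r int, s float64) []float64 {
	Wp := make([]float64, m*n)
	for i := 0; i < m; i++ {
		for j := 0; j < n; j++ {
			ba := 0.0
			for k := 0; k < r; k++ {
				ba += B[i*r+k] * A[k*n+j]
			}
			Wp[i*n+j] = W[i*n+j] + s*ba
		}
	}
	return Wp
}

func main() {
	m, n, r := 784, 128, 8
	fmt.Println("full fine-tuning params:", m*n) // 100352
	fmt.Println("LoRA params:", m*r+r*n)         // 7296 = 6272 + 1024
}
```

With the slide's shapes (m = 784, n = 128, r = 8) this reproduces the 100,352 → 7,296 parameter reduction.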
  4. Quantized LoRA
     • For efficient fine-tuning, reduce the size of the frozen parameters
     • From 4x to x/q + 3y (x ≫ y)
     • Store the frozen weights W quantized as Q(W); at training time use the dequantized W̃ = D(Q(W)) and add the LoRA update term
     • 100,352 × 4 bytes (m = 784, n = 128, fp32) → Q(W) → 100,352 × 4 bits + 784 × (4 + 1) bytes (m = 784, n = 128, int4) — in this case, reduced to roughly 14% of the original
     • Example of per-row asymmetric quantization, assigning each value to a discretized integer level (the scale and zero point are stored for dequantization):
       scale_r = (max_r − min_r) / (q_max − q_min),  z_r = round(−min_r / scale_r),  q_(r,c) = clip(round(w_(r,c) / scale_r + z_r))
     • ※ The real QLoRA applies more efficient quantization and various other refinements
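The per-row asymmetric scheme above can be sketched as follows. Like the talk's version this is deliberately simplified (the name `quantRow` is mine; real QLoRA uses more sophisticated quantization such as NF4).

```go
package main

import (
	"fmt"
	"math"
)

// quantRow quantizes one weight row into integer codes in [qmin, qmax]
// (0..15 for 4-bit), following scale_r = (max−min)/(qmax−qmin),
// z_r = round(−min/scale_r), q = clip(round(w/scale_r + z_r)).
// The scale and zero point are returned so the row can be dequantized.
func quantRow(w []float64, qmin, qmax int) (q []uint8, scale float64, zero int) {
	lo, hi := w[0], w[0]
	for _, v := range w {
		lo, hi = math.Min(lo, v), math.Max(hi, v)
	}
	scale = (hi - lo) / float64(qmax-qmin)
	if scale == 0 {
		scale = 1 // degenerate row with all values equal
	}
	zero = int(math.Round(-lo / scale))
	q = make([]uint8, len(w))
	for i, v := range w {
		c := int(math.Round(v/scale + float64(zero)))
		if c < qmin {
			c = qmin
		}
		if c > qmax {
			c = qmax
		}
		q[i] = uint8(c)
	}
	return q, scale, zero
}

func main() {
	q, scale, zero := quantRow([]float64{-0.9, -0.3, 0.0, 0.4, 1.2}, 0, 15)
	fmt.Println(q, scale, zero) // five values collapsed onto 16 integer levels
}
```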
  5. Quantized LoRA (dequantization)
     • Dequantization: using the stored scale and zero point, map each code back to one of the discrete floating-point levels of the quantization grid:
       w̃_(r,c) = (q_(r,c) − z_r) · scale_r,  so W̃ = D(Q(W)) is again 100,352 × 4 bytes (m = 784, n = 128, fp32)
     • Since W is not a training target, the discreteness has little impact; the LoRA-side adapter absorbs the discretization error
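The dequantization step is the inverse mapping through the saved scale and zero point; a minimal sketch (the name `dequantRow` is mine):

```go
package main

import "fmt"

// dequantRow reconstructs approximate weights from quantized codes via
// w̃ = (q − z) · scale, the slide's formula. The result is discrete, but
// since W is frozen, the trainable LoRA adapter absorbs this error.
func dequantRow(q []uint8, scale float64, zero int) []float64 {
	w := make([]float64, len(q))
	for i, c := range q {
		w[i] = (float64(int(c)) - float64(zero)) * scale
	}
	return w
}

func main() {
	// codes as produced by a 4-bit per-row quantizer with scale=0.125, zero=8
	fmt.Println(dequantRow([]uint8{0, 4, 8, 12, 15}, 0.125, 8)) // [-1 -0.5 0 0.5 0.875]
}
```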
  6. Quantized LoRA (update form)
     • The fine-tuned weight combines the dequantized frozen weights with the LoRA term:
       W′ = W̃ + sBA = D(Q(W)) + sBA
     • Trainable parameters: 100,352 (m = 784, n = 128) → 7,296 = 6,272 + 1,024 (m = 784, n = 128, r = 8)
  7. Implementation in Go
     • Inference and training of a neural network involve computations over vectors and matrices, so use a matrix library that can exploit SIMD for speed
     • Comparing implementations and execution time of BA (n = 128, m = 784, r = 8):
       - Go slices (a triple loop over n, m, r): 2.241916ms
       - gonum's mat.Dense: 578.209µs
     • Implementation-wise it is mostly repacking from slices; only the BA product differs. Internally, gonum executes operations on multiple elements at once with SIMD
  8. Training the weights W^(1)
     • That said, the definition in math comes first; the implementation just writes it down as-is…
     • The update rule for the weight W^(1) is:
       W^(1) ← W^(1) − η ∂L/∂W^(1)
       ∂L/∂W^(1) = δ^(1) x^⊤,  δ^(1) = ((W^(2))^⊤ δ^(2)) ⊙ a′(z^(1)),  δ^(2) = ŷ − y
     • Backpropagation: deriving the relation between the loss and a deep layer's weights directly is hard, but introducing the error signal δ lets the update be computed from the information of the layer one step ahead
     (figure: a two-layer network with weights W^(1) and W^(2))
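The update rule above can indeed be written down almost verbatim; a single-sample sketch with flat slices (the function name, layout, and pluggable activation derivative are my own choices):

```go
package main

import "fmt"

// updateW1 performs one SGD step on the first-layer weights of a two-layer
// net, following δ² = ŷ−y, δ¹ = ((W²)ᵀδ²) ⊙ a′(z¹), ∂L/∂W¹ = δ¹xᵀ,
// W¹ ← W¹ − η ∂L/∂W¹. Shapes: x has n entries, z¹ has h, W¹ is h×n, W² is o×h.
func updateW1(W1, W2, x, z1, yhat, y []float64, h, n, o int, eta float64, dAct func(float64) float64) {
	d2 := make([]float64, o) // δ² = ŷ − y
	for i := range d2 {
		d2[i] = yhat[i] - y[i]
	}
	d1 := make([]float64, h) // δ¹ = ((W²)ᵀ δ²) ⊙ a′(z¹)
	for j := 0; j < h; j++ {
		sum := 0.0
		for i := 0; i < o; i++ {
			sum += W2[i*h+j] * d2[i]
		}
		d1[j] = sum * dAct(z1[j])
	}
	for j := 0; j < h; j++ { // W¹ ← W¹ − η δ¹xᵀ
		for k := 0; k < n; k++ {
			W1[j*n+k] -= eta * d1[j] * x[k]
		}
	}
}

func main() {
	// tiny example: n=2 inputs, h=2 hidden units, o=1 output, identity activation
	W1 := []float64{0.1, 0.2, 0.3, 0.4}
	updateW1(W1, []float64{0.5, 0.6}, []float64{1, 2}, []float64{0.5, 1.1},
		[]float64{1}, []float64{0}, 2, 2, 1, 0.1,
		func(float64) float64 { return 1 })
	fmt.Println(W1)
}
```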
  9. Quantizing the weights W^(1)
     • After training, quantize the weights W^(1) as needed
     • In this implementation and evaluation, the dequantization W̃^(1) = D(Q(W^(1))) is applied in advance
     • This quantization simply applies per-row asymmetric quantization
     • The real QLoRA applies more efficient quantization and various other refinements
     (figure: W^(1) replaced by D(Q(W^(1))) → W̃^(1); W^(2) unchanged)
  10. LoRA: training the low-rank matrices B, A
     • Fine-tune by applying the low-rank matrices B, A to the trained model's weights W^(1) (or the quantization-applied weights W̃^(1)):
       W′^(1) = W^(1) + sBA
     • The update rules for B and A are as follows (this should be correct):
       B ← B − η ∂L/∂B,  ∂L/∂B = s · (∂L/∂W′^(1)) A^⊤
       A ← A − η ∂L/∂A,  ∂L/∂A = s · B^⊤ (∂L/∂W′^(1))
       with ∂L/∂W′^(1) = δ^(1) x^⊤
     • In the quantized case the training procedure is unchanged: W′^(1) = W̃^(1) + sBA
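Those two update rules can be sketched the same way (a didactic flat-slice version with names of my own choosing; in actual LoRA, B is initialized to zero so W′ initially equals the frozen weights):

```go
package main

import "fmt"

// loraStep applies one SGD step to the LoRA factors, following
// ∂L/∂B = s·(∂L/∂W′)Aᵀ and ∂L/∂A = s·Bᵀ(∂L/∂W′); the frozen W
// (or dequantized W̃) is never touched. gradW is ∂L/∂W′ (m×n),
// B is m×r, A is r×n, all row-major flat slices.
func loraStep(B, A, gradW []float64, m, n, r int, s, eta float64) {
	gradB := make([]float64, m*r) // s · gradW · Aᵀ
	for i := 0; i < m; i++ {
		for k := 0; k < r; k++ {
			sum := 0.0
			for j := 0; j < n; j++ {
				sum += gradW[i*n+j] * A[k*n+j]
			}
			gradB[i*r+k] = s * sum
		}
	}
	gradA := make([]float64, r*n) // s · Bᵀ · gradW
	for k := 0; k < r; k++ {
		for j := 0; j < n; j++ {
			sum := 0.0
			for i := 0; i < m; i++ {
				sum += B[i*r+k] * gradW[i*n+j]
			}
			gradA[k*n+j] = s * sum
		}
	}
	// apply both updates only after both gradients are computed
	for i := range B {
		B[i] -= eta * gradB[i]
	}
	for i := range A {
		A[i] -= eta * gradA[i]
	}
}

func main() {
	B, A := []float64{2}, []float64{3} // m = n = r = 1 toy case
	loraStep(B, A, []float64{1}, 1, 1, 1, 1.0, 0.1)
	fmt.Println(B, A)
}
```

Computing both gradients before updating either factor matters: ∂L/∂A depends on the pre-update B.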
  11. Base model training results
     2025/08/25 15:52:12 base-train=54000, base-test=9000, ft-train=6000, ft-test=1000
     2025/08/25 15:52:13 epoch 1 | base-train loss 0.4883 acc 0.868 | base-test loss 0.2869 acc 0.919
     2025/08/25 15:52:15 epoch 2 | base-train loss 0.2747 acc 0.922 | base-test loss 0.2310 acc 0.935
     2025/08/25 15:52:16 epoch 3 | base-train loss 0.2240 acc 0.937 | base-test loss 0.1957 acc 0.942
     2025/08/25 15:52:18 epoch 4 | base-train loss 0.1899 acc 0.946 | base-test loss 0.1724 acc 0.951
     2025/08/25 15:52:20 epoch 5 | base-train loss 0.1654 acc 0.952 | base-test loss 0.1562 acc 0.958
     2025/08/25 15:52:21 epoch 6 | base-train loss 0.1465 acc 0.959 | base-test loss 0.1406 acc 0.961
     2025/08/25 15:52:23 epoch 7 | base-train loss 0.1313 acc 0.963 | base-test loss 0.1321 acc 0.964
     2025/08/25 15:52:25 epoch 8 | base-train loss 0.1196 acc 0.967 | base-test loss 0.1223 acc 0.966
     2025/08/25 15:52:26 epoch 9 | base-train loss 0.1095 acc 0.969 | base-test loss 0.1199 acc 0.966
     2025/08/25 15:52:28 epoch 10 | base-train loss 0.1013 acc 0.972 | base-test loss 0.1107 acc 0.970
     sample 0: pred=0 true=0 prob=1.000
     sample 1: pred=9 true=9 prob=0.942
     sample 2: pred=2 true=2 prob=0.997
     sample 3: pred=2 true=2 prob=0.989
     sample 4: pred=7 true=7 prob=1.000
     • About 97% correct on the original test data — the model seems to answer with confidence
  12. Change in classification performance when the data's characteristics change
     2025/08/25 15:52:28 [before FT] ft-test loss 1.1719 acc 0.681
     • About 68% correct on the masked test data — a large drop from the performance on the original test data (base-test loss 0.1107 acc 0.970)
  13. Fine-tuning with LoRA
     2025/08/25 15:52:28 [before FT] ft-test loss 1.1719 acc 0.681
     2025/08/25 15:52:28 [LoRA ep 1] ft-train loss 0.5404 acc 0.851 | ft-test loss 0.1591 acc 0.954
     2025/08/25 15:52:28 [LoRA ep 2] ft-train loss 0.1744 acc 0.953 | ft-test loss 0.1452 acc 0.961
     2025/08/25 15:52:28 [LoRA ep 3] ft-train loss 0.1640 acc 0.955 | ft-test loss 0.1474 acc 0.961
     2025/08/25 15:52:29 [LoRA ep 4] ft-train loss 0.1588 acc 0.957 | ft-test loss 0.1441 acc 0.961
     2025/08/25 15:52:29 [LoRA ep 5] ft-train loss 0.1550 acc 0.958 | ft-test loss 0.1449 acc 0.959
     2025/08/25 15:52:29 [LoRA done] time=845.954709ms | ft-test loss 0.1449 acc 0.959
     • About 96% correct on the masked test data — a quick recovery to roughly the original classification performance (base-test loss 0.1107 acc 0.970)
  14. Change in classification performance due to quantization
     2025/08/25 15:52:29 == Quantization test: 4-bit ==
     2025/08/25 15:52:29 [Q4] base-test loss 0.1111 acc 0.970 | ft-test (before FT) loss 1.2172 acc 0.667
     • About 97% correct on the original test data (unquantized: base-test loss 0.1107 acc 0.970) — the degradation from quantization is slight
     • About 67% correct on the masked test data (unquantized: ft-test loss 1.1719 acc 0.681) — here too the degradation from quantization is slight
  15. Fine-tuning with the simplified QLoRA
     2025/08/25 15:52:29 == Quantization test: 4-bit ==
     2025/08/25 15:52:29 [Q4] base-test loss 0.1111 acc 0.970 | ft-test (before FT) loss 1.2172 acc 0.667
     2025/08/25 15:52:30 [Q4-LoRA ep 1] ft-train loss 0.5481 acc 0.848 | ft-test loss 0.1627 acc 0.955
     2025/08/25 15:52:30 [Q4-LoRA ep 2] ft-train loss 0.1741 acc 0.954 | ft-test loss 0.1495 acc 0.960
     2025/08/25 15:52:30 [Q4-LoRA ep 3] ft-train loss 0.1655 acc 0.956 | ft-test loss 0.1470 acc 0.960
     2025/08/25 15:52:30 [Q4-LoRA ep 4] ft-train loss 0.1605 acc 0.957 | ft-test loss 0.1428 acc 0.963
     2025/08/25 15:52:30 [Q4-LoRA ep 5] ft-train loss 0.1571 acc 0.958 | ft-test loss 0.1431 acc 0.961
     2025/08/25 15:52:30 [Q4-LoRA done] time=911.142792ms | ft-test loss 0.1431 acc 0.961
     • About 96% correct on the masked test data — a quick recovery to roughly the original classification performance (base-test loss 0.1107 acc 0.970), and on par with LoRA's accuracy (ft-test loss 0.1449 acc 0.959)
  16. References
     • To learn how neural networks work:
       - Kengo Tateishi (SmartNews, Inc.), やさしく学ぶ ディープラーニングがわかる数学のきほん ~アヤノ&ミオと学ぶ ディープラーニングの理論と数学、実装~ (The Math Basics Behind Deep Learning), Mynavi Publishing, July 31, 2019. ISBN: 978-4-8399-6837-3 (the speaker took part in reviewing this book)
       - 誤差逆伝播法 (Backpropagation) ~GMO Pepabo new-graduate training 2020, Introduction to Machine Learning, supplementary material #06~, https://speakerdeck.com/monochromegane/machine-learning-lecture-backpropagation
     • To learn about speeding up neural networks:
       - Ryoma Sato, 深層ニューラルネットワークの高速化 (Speeding Up Deep Neural Networks), Gijutsu-Hyohron, August 26, 2024. ISBN: 978-4-297-14309-1