
第4回 関東Kaggler会 (4th Kanto Kaggler Meetup) [Training LLMs with Limited VRAM]

tascj
August 23, 2025


Transcript

  1. Today's topic: Making LLM training accessible
     • What if you could:
       • Train 7B models on a single RTX 4090 (24GB)
       • Train 14B models on 48G GPUs
       • Train 32B models on 80G/96G GPUs
       • With full quality
         • Full-parameter training, no compromises (LoRA, QLoRA, etc.)
         • Standard Adam
     • How? Two key techniques:
       1. Using lower-precision states (50% reduction)
       2. Efficient CPU offloading (75% reduction)
  2. Agenda
     1. Understanding the Memory Challenge
     2. Using Lower Precision Correctly
     3. Efficient CPU Offloading
     4. Practical Guidelines
  3. Training memory breakdown
     • Dynamic (input-dependent)
       • Activations
         • Smaller batch + gradient accumulation
         • Activation checkpointing
         • Half precision
         • CPU offloading
     • Static (input-independent) <- Our focus today
       • Parameters
       • Gradients
       • Adam states
         • Momentum
         • Variance
  4. The Static Memory Challenge
     • 16 bytes/param
       • Parameters: 4 bytes
       • Gradient: 4 bytes
       • Momentum: 4 bytes
       • Variance: 4 bytes
     • For a 7B model: 7 × 16 = 112GB
       • Does not fit in a single A100 80G
       • amp.autocast(cache_enabled=False) does not help
         • FP32 (M, V) + FP16 (P, G) + FP32 (master P) -> 16 bytes/param
     • Our approach:
       1. Using lower precision (BFloat16) correctly -> 8 bytes/param
       2. Offloading states to CPU efficiently -> 2 bytes/param
  5. Using lower precision correctly: Case study
     • Kaggle: WSDM Cup - Multilingual Chatbot Arena
     • Qwen2.5-1.5B
     • Train for 1 epoch and track training loss
     • Try using BFloat16 to store states
       • Adam step uses float32 internally (see the sketch below)
     • Lower bit widths are possible, but we focus on BFloat16 today
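A minimal sketch, assuming illustrative names and hyperparameters (not the talk's actual code), of what "Adam step uses float32 internally" means when M and V are stored in BF16: the states are upcast to FP32 for the arithmetic and rounded back to BF16 for storage.

    import torch

    @torch.no_grad()
    def adam_step_bf16_states(p, g, m, v, step, lr=1e-5, betas=(0.9, 0.999), eps=1e-8):
        # Upcast BF16 states to FP32 and do all arithmetic in FP32.
        m32 = m.float().mul_(betas[0]).add_(g.float(), alpha=1 - betas[0])
        v32 = v.float().mul_(betas[1]).addcmul_(g.float(), g.float(), value=1 - betas[1])
        m_hat = m32 / (1 - betas[0] ** step)
        v_hat = v32 / (1 - betas[1] ** step)
        p32 = p.float() - lr * m_hat / (v_hat.sqrt() + eps)
        # Round back to BF16 for storage (slide 12 revisits what this rounding costs).
        m.copy_(m32.to(torch.bfloat16))
        v.copy_(v32.to(torch.bfloat16))
        p.copy_(p32.to(p.dtype))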
  6.-11. Using lower precision correctly: Case study (WSDM Cup - Multilingual Chatbot Arena)
     Slides 6-11 build up the convergence table below one row at a time:

     P     G     M     V     Convergence
     FP32  FP32  FP32  FP32  (baseline)
     FP32  FP32  BF16  FP32  ✅
     FP32  FP32  BF16  BF16  ✅
     FP32  BF16  BF16  BF16  ✅
     BF16  BF16  BF16  BF16  ❌
     BF16  FP32  FP32  FP32  ❌
  12. Using lower precision correctly: Case study (WSDM Cup - Multilingual Chatbot Arena)
      • Update steps:
        1. W_t^FP32 = to_float32(W_t^BF16)
        2. W_{t+1}^FP32 = adam(W_t^FP32, ...)
        3. W_{t+1}^BF16 = to_bfloat16(W_{t+1}^FP32)
      • Nearest rounding of weights cancels small updates (arxiv), as the demo below shows
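A small, self-contained illustration (not from the talk) of why round-to-nearest BF16 weights cancel small updates: when lr times the Adam direction is well below the BF16 ulp of the weight, step 3 rounds the updated value back to the old weight.

    import torch

    w_bf16 = torch.tensor([1.0], dtype=torch.bfloat16)
    update = 1e-4                        # a typical lr * adam_direction magnitude

    w_fp32 = w_bf16.float()              # step 1: upcast
    w_fp32 = w_fp32 - update             # step 2: apply the update in FP32
    w_next = w_fp32.to(torch.bfloat16)   # step 3: round back to BF16

    print(torch.equal(w_next, w_bf16))   # True: the update was rounded away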
  13. Using lower precision correctly: Case study (WSDM Cup - Multilingual Chatbot Arena)
      • Using FP32 P (with BF16 G, M, V)
        • P must be cast during forward & backward, adding time
        • PyTorch does not allow P and G in different dtypes by default
          • Inherit torch.Tensor (check TorchAO)
        • 10 bytes/param (62.5% of baseline)
      [figure: byte layout of P, G, M, V]
  14. Using lower precision correctly: Case study (WSDM Cup - Multilingual Chatbot Arena)
      • Using an extra master P (the most common approach; sketch below)
        • P for forward & backward
        • Master P for adam_step
        • Copy master P to P after adam_step
        • 12 bytes/param (75% of baseline)
      [figure: byte layout of P, G, M, V, master P]
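A minimal sketch, assuming a single nn.Linear and manual gradient routing (illustrative, not the talk's code), of the FP32 master-weight pattern: BF16 P for forward/backward, an FP32 master copy for adam_step, then copy the master back into P.

    import torch

    model = torch.nn.Linear(1024, 1024, bias=False).bfloat16().cuda()
    master = [p.detach().clone().float() for p in model.parameters()]
    opt = torch.optim.Adam(master, lr=1e-5)   # master P, M, V all live in FP32

    x = torch.randn(8, 1024, device="cuda", dtype=torch.bfloat16)
    model(x).sum().backward()

    # Route the BF16 grads into the FP32 master params, step, then copy back.
    for p, mp in zip(model.parameters(), master):
        mp.grad = p.grad.float()
    opt.step()
    opt.zero_grad()
    with torch.no_grad():
        for p, mp in zip(model.parameters(), master):
            p.copy_(mp.to(torch.bfloat16))
            p.grad = None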
  15. Using lower precision correctly: Case study (WSDM Cup - Multilingual Chatbot Arena)
      • Revisiting FP32 -> BF16 casting
        • 17 bits would be needed to save the rounding error exactly
        • We want to store the rounding error in 16 bits
        • Custom FP32 -> BF16 rounding rule
          • Sacrifice 1 bit, effectively keeping an FP31 master
        • Check the code for details
  16. Using lower precision correctly: Case study (WSDM Cup - Multilingual Chatbot Arena)
      • Using an extra buffer to store the rounding error (BF16 only; sketch below)
        • P for forward & backward
        • Reconstruct the full-precision weight before adam_step
        • Decompose it back into P + rounding error after adam_step
        • 10 bytes/param (62.5% of baseline)
        • Kahan summation can do something similar and is not limited to BF16
      [figure: byte layout of P, G, M, V, rounding error]
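An illustrative sketch of the rounding-error-buffer idea, under the assumption that the full-precision weight is reconstructed as P + error before adam_step and decomposed again afterwards; the function names are hypothetical, not the talk's code.

    import torch

    def reconstruct(p_bf16, err_bf16):
        # Recover an approximate FP32 weight before adam_step.
        return p_bf16.float() + err_bf16.float()

    def decompose(w_fp32, p_bf16, err_bf16):
        # Split the updated FP32 weight back into a BF16 weight plus a BF16 residual.
        p_bf16.copy_(w_fp32.to(torch.bfloat16))
        err_bf16.copy_((w_fp32 - p_bf16.float()).to(torch.bfloat16))

    p = torch.randn(4, dtype=torch.bfloat16)
    err = torch.zeros_like(p)
    w = reconstruct(p, err)
    w -= 1e-4                 # stand-in for the Adam update
    decompose(w, p, err)      # the small update survives in `err` even where `p` is unchanged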
  17. Using lower precision correctly: Case study (WSDM Cup - Multilingual Chatbot Arena)
      • Stochastic rounding (sketch below)
        • Probabilistic rounding instead of deterministic round-to-nearest
        • Preserves an unbiased expectation during precision conversion
        • Used at the end of adam_step
        • 8 bytes/param (50% of baseline)
      [figure: byte layout of P, G, M, V]
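A common way to implement stochastic FP32 -> BF16 rounding; this bit-level trick is an assumption for illustration and not necessarily the talk's Triton kernel. Random bits are added to the 16 low bits that will be dropped, then truncated, so the rounded value equals the input in expectation.

    import torch

    def stochastic_round_to_bf16(x_fp32: torch.Tensor) -> torch.Tensor:
        bits = x_fp32.contiguous().view(torch.int32)
        noise = torch.randint(0, 1 << 16, x_fp32.shape, dtype=torch.int32, device=x_fp32.device)
        rounded = (bits + noise) & ~0xFFFF   # the carry rounds up with probability proportional to the remainder
        return rounded.view(torch.float32).to(torch.bfloat16)

    x = torch.full((100000,), 1.0 + 1e-4)
    print(stochastic_round_to_bf16(x).float().mean())   # ~1.0001 on average; nearest rounding would give exactly 1.0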
  18. Wrap-Up
      • By default, we need 16 bytes/param to store states
      • Optimizer implementations may fail when all states are in BF16
        • PagedAdam(model.bfloat16().parameters()) fails to converge
      • With BF16 and stochastic rounding, we can reduce state VRAM to 8 bytes/param
        • 56GB of states when training a 7B model
          • Fits a single A100 / H100 / RTX PRO 6000 Blackwell
        • Simply all BF16: no autocast, no grad scaler
      • Can we do better? For:
        • 24G VRAM: RTX 3090 / RTX 4090 / RTX 5090
        • 48G VRAM: RTX A6000 / RTX 6000 Ada
  19. Efficient CPU offloading
      [figure: timeline of repeated Backward / H2D / D2H operations across layers]
  20. Efficient CPU offloading
      • LLM parameters are mainly nn.Linear parameters
        • nn.Linear(bias=False) in training (identities checked in the snippet below)
      • Forward
        • output = input @ P.t()
      • Backward
        • grad_input = grad_output @ P
        • G = grad_output.t() @ input
      • Optimizer step
        • adam_step(P, G, M, V, lr=…)

      Stage           States involved    Execution time
      Forward         P, activation      Input-dependent
      Backward        P, G, activation   Input-dependent
      Optimizer step  P, G, M, V         Input-independent; usually much faster than forward/backward
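A quick check (illustrative) of the nn.Linear(bias=False) identities listed above: the forward is input @ P.t(), grad_input is grad_output @ P, and the weight gradient G is grad_output.t() @ input.

    import torch

    P = torch.randn(64, 32, requires_grad=True)   # weight of nn.Linear(32, 64, bias=False)
    x = torch.randn(8, 32, requires_grad=True)

    out = x @ P.t()                               # forward
    grad_out = torch.randn_like(out)
    out.backward(grad_out)

    print(torch.allclose(x.grad, grad_out @ P))       # grad_input
    print(torch.allclose(P.grad, grad_out.t() @ x))   # G = grad_output.t() @ input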
  21. Efficient CPU offloading: Typical PyTorch training pipeline
      • Forward
      • Backward
      • Optimizer step
      [figure: per-layer timeline for Layer0 / Layer1 / Layer2]
  22. Efficient CPU offloading: Re-arranged training pipeline for offloading
      • Forward
      • Backward & Optimizer step
      [figure: per-layer timeline for Layer0 / Layer1 / Layer2]
  23. Efficient CPU offloading: H2D -> update -> D2H -> release (sketch below)
      • Backward & Optimizer step
        • grad_input = grad_output @ P
        • G = grad_output.t() @ input
        • Copy M, V from CPU
        • adam_step(P, G, M, V, lr=…)
        • Copy updated M, V to CPU
        • Release G, M, V on GPU
      • States on GPU are reduced to the model's P plus a single parameter's G, M, V
      [figure: per-layer timeline for Layer0 / Layer1 / Layer2]
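A simplified sketch (not the offload_adam implementation; the pinned CPU buffers and the fused adam_step callable are assumptions) of the per-parameter "H2D -> update -> D2H -> release" pattern: M and V live on the CPU and only visit the GPU for the duration of one adam_step.

    import torch

    def offloaded_update(p_gpu, g_gpu, m_cpu, v_cpu, adam_step, lr):
        # H2D: bring this parameter's Adam states onto the GPU.
        m = m_cpu.to(p_gpu.device, non_blocking=True)
        v = v_cpu.to(p_gpu.device, non_blocking=True)
        # Update in place on the GPU (adam_step is assumed to be a fused kernel).
        adam_step(p_gpu, g_gpu, m, v, lr=lr)
        # D2H: write the updated states back to the pinned CPU buffers.
        m_cpu.copy_(m, non_blocking=True)
        v_cpu.copy_(v, non_blocking=True)
        # Release: drop the GPU copies of M and V (the caller frees g_gpu afterwards).
        del m, v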
  24. Efficient CPU offloading: Overlapping computation and communication (stream sketch below)
      • Backward
        • grad_input = grad_output @ P
        • G = grad_output.t() @ input
      • H2D
        • Copy M, V from CPU
      • Optimizer step
        • Accumulate grad_weight into G
        • adam_step(P, G, M, V, lr=…)
      • D2H
        • Copy updated M, V to CPU
        • Release G, M, V on GPU
      [figure: "Sequential" vs. "Overlapping computation and communication" timelines across the default, H2D, and D2H streams]
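A minimal stream-based sketch (assumptions: pinned CPU buffers and a fused GPU adam_step) of how the H2D and D2H copies of M and V can run on side streams while the backward computation keeps the default stream busy.

    import torch

    h2d_stream = torch.cuda.Stream()
    d2h_stream = torch.cuda.Stream()

    def prefetch_states(m_cpu, v_cpu, device):
        # Issue the H2D copies on a side stream while backward keeps computing.
        with torch.cuda.stream(h2d_stream):
            m = m_cpu.to(device, non_blocking=True)
            v = v_cpu.to(device, non_blocking=True)
        return m, v

    def writeback_states(m_gpu, v_gpu, m_cpu, v_cpu):
        # Make the D2H stream wait for the optimizer step on the default stream, then copy back.
        d2h_stream.wait_stream(torch.cuda.current_stream())
        with torch.cuda.stream(d2h_stream):
            m_cpu.copy_(m_gpu, non_blocking=True)
            v_cpu.copy_(v_gpu, non_blocking=True)

    # Before adam_step uses m, v on the default stream, it must wait for the H2D stream:
    #     torch.cuda.current_stream().wait_stream(h2d_stream)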
  25. Efficient CPU offloading: Overlapping computation and communication
      • nn.Linear(bias=False)
      • Per-token backward time ≈ (4 × H_in × H_out) / (TFLOPS × 10^12) seconds
      • Weight (2 bytes/param) transfer time ≈ (H_in × H_out × 2) / (Bandwidth_GB/s × 10^9) seconds
      • Number of tokens for full overlap ≈ transfer time ÷ per-token backward time
        = TFLOPS / (2 × Bandwidth_GB/s) × 1000
      • RTX 4090 (165 TFLOPS, PCIe 4.0 32GB/s)
        • Ideally >2600 tokens per step to overlap gradient accumulation (worked example below)
      [figure: overlapped timelines across the default, H2D, and D2H streams]
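Plugging the slide's RTX 4090 numbers into the last formula reproduces the ">2600 tokens" figure; note that H_in and H_out cancel, so the threshold depends only on compute throughput and PCIe bandwidth.

    tflops = 165          # assumed RTX 4090 BF16 throughput from the slide
    bandwidth_gbs = 32    # PCIe 4.0 x16 host<->device bandwidth

    tokens_for_full_overlap = tflops / (2 * bandwidth_gbs) * 1000
    print(tokens_for_full_overlap)   # ~2578, i.e. roughly 2600 tokens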
  26. Efficient CPU offloading: Case study
      • Kaggle: WSDM Cup - Multilingual Chatbot Arena
      • Qwen2.5 without lm_head
      • Train for 1 epoch
      • Activation checkpointing enabled
      • BF16 + stochastic rounding
      • ~9.5k tokens per forward
      • ~24.7k tokens per optimizer step
  27. Efficient CPU offloading: Case study (WSDM Cup - Multilingual Chatbot Arena)

      Model         GPU       Offloading  max_memory_allocated  Epoch time (min)  Offloading overhead
      Qwen2.5-1.5B  RTX 4090  FALSE       14169M                89.3              -
      Qwen2.5-1.5B  RTX 4090  TRUE        5303M                 96.3              ~8%
      Qwen2.5-7B    RTX 4090  TRUE        18783M                382.6             -
      Qwen2.5-7B    A100      FALSE       59332M                255.8             -
      Qwen2.5-7B    A100      TRUE        18789M                271.8             ~11%
  28. Efficient CPU offloading: Case study (WSDM Cup - Multilingual Chatbot Arena)
      • ~24.7k tokens per optimizer step

      Model       GPU   Offloading  Avg. tokens per forward  Epoch time (min)  Offloading overhead
      Qwen2.5-7B  A100  FALSE       -                        255.8             -
      Qwen2.5-7B  A100  TRUE        9.5k                     271.8             ~11%
      Qwen2.5-7B  A100  TRUE        6.4k                     300.5             ~17%
      Qwen2.5-7B  A100  TRUE        4.7k                     334.2             ~31%
  29. Efficient CPU offloading: Case study (WSDM Cup - Multilingual Chatbot Arena)
      • Qwen2.5-1.5B
      • RTX 4090
      • ~9.5k tokens per forward
      • ~24.7k tokens per optimizer step
      [figure: profiler timeline showing Forward, Backward & GradAcc/OptStep, and the copy of the token embedding]
  30. Efficient CPU offloading: Case study (WSDM Cup - Multilingual Chatbot Arena)
      • Qwen2.5-1.5B
      • RTX 4090
      • ~9.5k tokens per forward
      • ~24.7k tokens per optimizer step
      [figure: profiler timelines for Backward & Gradient Accumulation (D2H: G; H2D: G),
       Backward & Optimizer Step (D2H: M, V; H2D: G, M, V), and Backward & Gradient Reset (D2H: G; H2D: none)]
  31. Efficient CPU offloading: Existing implementations
      • DeepSpeed / TorchAO / Megatron offload using a CPU adam_step
      • Pros
        • M, V are kept on the CPU; only P (H2D) and G (D2H) need to be transferred
      • Cons
        • A fused CPU SIMD adam_step is fast, but still much slower than a GPU adam_step
        • Fused CPU adam_step implementations do not handle BF16 properly (FP32 master, stochastic rounding, etc.)
        • Some do not support gradient accumulation
  32. Efficient CPU offloading: https://github.com/tascj/offload_adam
      • Pure PyTorch & Triton
        • GPU adam_step implemented in Triton
          • Fused
          • Internally uses FP32 computation
        • Handles BF16 states properly
          • FP32 master
          • Stochastic rounding
        • Computation & communication overlap
      • Extensible
        • QAT with TorchAO
          • Int8QuantizedTrainingLinearWeight supports FP32 -> Int8 with stochastic rounding (train 14B models with 24G VRAM)
          • Int4/NF4 QAT using FP32 master mode also works fine (train 32B models with 24?/32G VRAM)
  33. Wrap-Up
      • By offloading G, M, V to the CPU, we reduce state VRAM by 75%
      • By overlapping state transfer and computation, we can hide the overhead of offloading
      • The state VRAM requirement is now 2 bytes/param (BF16 P) plus a single parameter's G, M, V, which is almost 12.5% of baseline
  34. Summary
      • What we achieved:
        • VRAM reduction: 16 -> 2 bytes/param
        • 7B training on an RTX 4090
        • Acceptable performance overhead (<10% in the WSDM Cup case study)
      • Key techniques:
        • BFloat16 + stochastic rounding
        • CPU offloading
          • GPU Adam step
          • Computation-communication overlap
  35. Practical guidelines: Hardware
      • RTX 4090
        • ~8% overhead (~9k tokens per forward, ~25k tokens per optimizer step)
      • RTX 5090 versus RTX 4090
        • 2x transfer (PCIe 5.0), 1.3x compute
        • Fewer tokens are needed to overlap compute and transfer, so it is more offloading-friendly than the 4090
      • GH200
        • $1.49/hour on Lambda Cloud
        • 96G VRAM, 480GB of LPDDR5X (enough for training 32B models)
        • NVLink-C2C, >330GB/s H2D/D2H transfer
          • Fast even without computation-communication overlap
        • More than 3x the compute of an A100
        • A perfect fit for training with offloading: training a 32B model on a GH200 with offloading is much faster than A100x4 with FSDP full shard
  36. Practical guidelines: Hardware and model size
      • https://github.com/tascj/offload_adam
        • Use lower-precision states
          • BF16 + stochastic rounding
        • Offload states to the CPU
          • Computation and communication overlap

      Model size  Host RAM (GB)  VRAM (GB)  Method
      7B          >42            24/48      BF16 states + Offloading
                                 80/96      BF16 states
      14B         >84            24         Int8 QAT + BF16 states + Offloading
                                 48/80/96   BF16 states + Offloading
      32B         >192           48         Int8 QAT + BF16 states + Offloading
                                 80/96      BF16 states + Offloading