Upgrade to Pro
— share decks privately, control downloads, hide ads and more …
Speaker Deck
Features
Speaker Deck
PRO
Sign in
Sign up for free
Search
Search
Distributed and Parallel Training for PyTorch
Search
tattaka
August 22, 2024
2
310
Distributed and Parallel Training for PyTorch
社内のAI技術共有会で使用した資料です。
PyTorchで使われている分散学習の仕組みについて紹介しました。
tattaka
August 22, 2024
Tweet
Share
More Decks by tattaka
See All by tattaka
論文紹介 DSRNet: Single Image Reflection Separation via Component Synergy (ICCV 2023)
tattaka
0
320
最近のVisual Odometry with Deep Learning
tattaka
1
1.8k
Fuzzy Metaballs: Approximate Differentiable Rendering with Algebraic Surfaces
tattaka
0
550
Featured
See All Featured
Building Your Own Lightsaber
phodgson
103
6.1k
[RailsConf 2023 Opening Keynote] The Magic of Rails
eileencodes
28
9.1k
A better future with KSS
kneath
238
17k
Large-scale JavaScript Application Architecture
addyosmani
510
110k
jQuery: Nuts, Bolts and Bling
dougneiner
61
7.5k
Scaling GitHub
holman
458
140k
個人開発の失敗を避けるイケてる考え方 / tips for indie hackers
panda_program
93
16k
Happy Clients
brianwarren
98
6.7k
The Language of Interfaces
destraynor
154
24k
A Tale of Four Properties
chriscoyier
156
23k
Bootstrapping a Software Product
garrettdimon
PRO
305
110k
Rails Girls Zürich Keynote
gr2m
94
13k
Transcript
AI 2024.8.22 @tattaka_sun GO株式会社 Distributed and Parallel Training for PyTorch
AI 2 ▪ 基盤モデルの流行などにより、大きなモデルを効率的に学 習を進める手法の需要が高まっている ▪ このスライドでわかるようになること ▪ PyTorchによる分散学習の基本的な仕組み ▪
どのような分散学習の手法があるか ▪ それぞれの分散学習手法の使い分け はじめに
AI 3 PyTorch Distributed Overview High level Low level Communication
backend Communication APIs (C10D) Sharding primitives Parallelism APIs Laucher Gloo Open MPI NCCL send recv broadcast all_reduce reduce all_gather gather scatter reduce_scatter all_to_all barrier DTensor DeviceMesh Data-Parallel Distributed Data-Parallel Fully Sharded Data-Parallel ZeRO series (DeepSpeed etc…) Tensor Parallel Pipeline Parallel torchrun (Elastic Launch) torch.distributed.launch
AI 4 PyTorch Distributed Overview High level Low level Communication
backend Communication APIs (C10D) Sharding primitives Parallelism APIs Laucher Gloo Open MPI NCCL send recv broadcast all_reduce reduce all_gather gather scatter reduce_scatter all_to_all barrier DTensor DeviceMesh Data-Parallel Distributed Data-Parallel Fully Sharded Data-Parallel ZeRO series (DeepSpeed etc…) Tensor Parallel Pipeline Parallel torchrun (Elastic Launch) torch.distributed.launch
AI ▪ それぞれのprocess間で 情報の通信を行う ▪ それぞれのprocessごとに 番号(rank)が振られ、 rank=0をmasterとして 扱う Distributed
Communications Backend Machine 2 Machine 1 5 PyTorchのDistributed Communicationの仕組み Process 1 (Rank 2) Process 2 (Rank 3) Process 1 (Rank 0) Process 2 (Rank 1)
AI 6 ▪ PyTorchでは以下の3つから分散通信に用いるbackendを 選択することができる ▪ Gloo ▪ CPU上での通信と、GPU上での一部の通信が実装されている ▪
NCCL ▪ GPU上での最適化された通信が実装されている ▪ GPUではGlooより高速 ▪ Open MPI ▪ ビルド済みパッケージに含まれないため、ソースからビルドする必要がある ▪ 上2つで十分なため、特別な理由がないかぎり使用されない 利用できるDistributed Communications Backend
AI 7 ▪ それぞれのbackendでできること が異なる ▪ 主要な操作については後述 利用できるDistributed Communications Backend
https://pytorch.org/docs/stable/distributed.html
AI ▪ torch.distributed.init_process_group を用いて初期化を行う ▪ 引数: ▪ rank: 現在のprocessのrank ▪
world_size: 全体のprocess数 ▪ backend: 分散通信にどのライブラリを使用するか defaultではgloo(cpu)とnccl(gpu)が併用される ▪ 他に環境変数で以下を設定する必要がある ▪ MASTER_PORT ▪ MASTER_ADDR ▪ (RANKとWORLD_SIZEも指定でき、その場合はinit_process_groupで指定する必要はない) 8 Distributed SettingのSet Up
AI ▪ 実装例 9 Distributed SettingのSet Up https://pytorch.org/docs/stable/distributed.html
AI 10 PyTorch Distributed Overview High level Low level Communication
backend Communication APIs (C10D) Sharding primitives Parallelism APIs Laucher Gloo Open MPI NCCL send recv broadcast all_reduce reduce all_gather gather scatter reduce_scatter all_to_all barrier DTensor DeviceMesh Data-Parallel Distributed Data-Parallel Fully Sharded Data-Parallel ZeRO series (DeepSpeed etc…) Tensor Parallel Pipeline Parallel torchrun (Elastic Launch) torch.distributed.launch
AI 11 ▪ Point-to-Point Communication • send (送信) • recv
(受信) Communication APIs (C10D) https://pytorch.org/tutorials/intermediate/dist_tuto.html
AI 12 ▪ Collective Communication ▪ 全てのrank間に対しての通信 Communication APIs (C10D)
https://pytorch.org/tutorials/intermediate/dist_tuto.html
AI 13 PyTorch Distributed Overview High level Low level Communication
backend Communication APIs (C10D) Sharding primitives Parallelism APIs Laucher Gloo Open MPI NCCL send recv broadcast all_reduce reduce all_gather gather scatter reduce_scatter all_to_all barrier DTensor DeviceMesh Data-Parallel Distributed Data-Parallel Fully Sharded Data-Parallel ZeRO series (DeepSpeed etc…) Tensor Parallel Pipeline Parallel torchrun (Elastic Launch) torch.distributed.launch
AI 14 ▪ Multi-Host, Multi-GPUを用いる場合、 設定が複雑になりがち🤯 DeviceMesh https://pytorch.org/tutorials/recipes/distributed_device_mesh.html Host 1
GPU 0 GPU 1 GPU 2 GPU 3 Host 2 GPU 0 GPU 1 GPU 2 GPU 3
AI 15 ▪ DeviceMeshを使って抽象的に2次元の ProcessGroup(processのsubset)を扱うことができる DeviceMesh https://pytorch.org/tutorials/recipes/distributed_device_mesh.html
AI 16 ▪ distributed tensor ▪ TensorやModuleを先述のDeviceMeshに基づいてprocessに 配置できる DTensor (Prototype
Release) https://github.com/pytorch/pytorch/tree/main/torch/distributed/_tensor
AI 17 PyTorch Distributed Overview High level Low level Communication
backend Communication APIs (C10D) Sharding primitives Parallelism APIs Laucher Gloo Open MPI NCCL send recv broadcast all_reduce reduce all_gather gather scatter reduce_scatter all_to_all barrier DTensor DeviceMesh Data-Parallel Distributed Data-Parallel Fully Sharded Data-Parallel ZeRO series (DeepSpeed etc…) Tensor Parallel Pipeline Parallel torchrun (Elastic Launch) torch.distributed.launch
AI 18 ▪ ここまで説明した低レイヤーな操作を抽象化して nn.Moduleを並列するAPI ▪ Data-Parallel (DP) ▪ Distributed
Data-Parallel (DDP) ▪ Tensor Parallel (TP) ▪ Pipeline Parallel (PP) ▪ Fully Sharded Data-Parallel (FSDP) ▪ ZeRO (DeepSpeedやFairScaleなどのサードパーティにて実装) Parallelism APIs
AI 19 ▪ それぞれのGPUにmodelをコピーし、Batchを分割して 学習、逆伝搬、更新後にパラメータをGPU間で同期する ▪ 単一のprocessがGPUを管理するので実装がシンプル ▪ オーバヘッドが大きいため現在は非推奨となっている Data-Parallel
(DP) Dataloader GPU:0 GPU:1 GPU:2 batch Model0 Model1 Model2 batch分割 Loss calc 勾配の計算 モデルパラメータの 更新後に分散 出力の集約 勾配の集約
AI 20 ▪ それぞれのGPU上のpipelineを別々のprocessが持つ ▪ DPと異なり、GPU間の通信は勾配の集約・分散のみ Distributed Data-Parallel (DDP) Dataloader
GPU:0 GPU:1 GPU:2 batch Model0 Model1 Model2 勾配の集約・分散 Loss calc Lossの計算 Dataloader batch Dataloader batch Loss calc Loss calc 勾配の計算 それぞれのGPUでモデルパラメータの更新
AI 21 ▪ 分散環境をsetup し、modelを DDP()でラップ ▪ checkpointの 保存・読み込み はprocess
1のみ 行うようにする Distributed Data-Parallel (DDP) https://pytorch.org/tutorials/intermediate/ddp_tutorial.html
AI 22 ▪ 個々のTensorをそれぞれのGPUに分割する ▪ Megatron-LMで提案された ▪ 行列積は行方向・列方向ともに分割でき、後で集約する Tensor Parallel
(TP) https://arxiv.org/abs/1909.08053
AI 23 ▪ modelをいくつかのユニットに分割し、 別々のprocessに配置する ▪ あるGPUで処理している間の他のGPUのidle期間を緩和す るために、batchをchunkして用いる process 1/GPU:0
Pipeline Parallel (PP) process 2/GPU:1 process 3/GPU:2 model module 1 module 2 module 3 module 4 module 5 module 6
AI 24 ▪ モデルパラメータや勾配をGPU間で分割して保持する ▪ forward・backward中、それぞれ決められたユニット内で 計算を行うためメモリを節約することができる Fully Sharded Data-Parallel
Training (FSDP) https://arxiv.org/abs/2304.11277
AI 25 ▪ モデルが1 GPUに載る ▪ DP (非推奨) ▪ DDP
▪ モデルが1 GPUに載らない ▪ 演算ごとに細かく分割したい ▪ TP ▪ モデルの段階ごとに細かく分割したい ▪ PP ▪ PyTorchに分割はお任せしたい ▪ FSDP ▪ size_based_auto_wrap_policy 使い分け
AI 26 PyTorch Distributed Overview High level Low level Communication
backend Communication APIs (C10D) Sharding primitives Parallelism APIs Laucher Gloo Open MPI NCCL send recv broadcast all_reduce reduce all_gather gather scatter reduce_scatter all_to_all barrier DTensor DeviceMesh Data-Parallel Distributed Data-Parallel Fully Sharded Data-Parallel ZeRO series (DeepSpeed etc…) Tensor Parallel Pipeline Parallel torchrun (Elastic Launch) torch.distributed.launch
AI 27 ▪ マルチプロセスでスクリプトを起動できる機能 ▪ 起動するとRANKなどの環境変数がセットされ、 スクリプトから参照できるようになる torchrun・torch.distributed.launch shell training
script https://pytorch.org/docs/stable/elastic/run.html
AI 28 ▪ Trainerにstrategy引数を指定するだけで、Single GPUの 学習コードに手を加えることなく実現できる • DDP以外のstrategyも指定できる(FSDP, DeepSpeedなど) PyTorch
Lightningを使う場合 https://lightning.ai/docs/pytorch/stable/accelerators/gpu_intermediate.html
AI 29 ▪ PyTorchにおける分散学習の概観を紹介した ▪ 1 GPUにモデルが載りきるならDDP、 分散させないといけないならFSDPを使えばOK ▪ 紹介しきれなかったDeepSpeedなどの
サードパーティライブラリについてはまた別の機会に まとめ
AI 30 ▪ TensorFlow ▪ https://www.tensorflow.org/guide/distributed_training?hl =ja ▪ PyTorchでいうところのDP・DDP・TPなどが実装されている ▪
Jax ▪ https://jax.readthedocs.io/en/latest/multi_process.html ▪ 低レベルのAPIが提供されており、適宜実装する必要がある? 余談:他のDLフレームワークでは