$30 off During Our Annual Pro Sale. View Details »

Distributed and Parallel Training for PyTorch

tattaka
August 22, 2024
340

Distributed and Parallel Training for PyTorch

社内のAI技術共有会で使用した資料です。
PyTorchで使われている分散学習の仕組みについて紹介しました。

tattaka

August 22, 2024
Tweet

Transcript

  1. AI 3 PyTorch Distributed Overview High level Low level Communication

    backend Communication APIs (C10D) Sharding primitives Parallelism APIs Laucher Gloo Open MPI NCCL send recv broadcast all_reduce reduce all_gather gather scatter reduce_scatter all_to_all barrier DTensor DeviceMesh Data-Parallel Distributed Data-Parallel Fully Sharded Data-Parallel ZeRO series (DeepSpeed etc…) Tensor Parallel Pipeline Parallel torchrun (Elastic Launch) torch.distributed.launch
  2. AI 4 PyTorch Distributed Overview High level Low level Communication

    backend Communication APIs (C10D) Sharding primitives Parallelism APIs Laucher Gloo Open MPI NCCL send recv broadcast all_reduce reduce all_gather gather scatter reduce_scatter all_to_all barrier DTensor DeviceMesh Data-Parallel Distributed Data-Parallel Fully Sharded Data-Parallel ZeRO series (DeepSpeed etc…) Tensor Parallel Pipeline Parallel torchrun (Elastic Launch) torch.distributed.launch
  3. AI ▪ それぞれのprocess間で 情報の通信を行う ▪ それぞれのprocessごとに 番号(rank)が振られ、 rank=0をmasterとして 扱う Distributed

    Communications Backend Machine 2 Machine 1 5 PyTorchのDistributed Communicationの仕組み Process 1 (Rank 2) Process 2 (Rank 3) Process 1 (Rank 0) Process 2 (Rank 1)
  4. AI 6 ▪ PyTorchでは以下の3つから分散通信に用いるbackendを 選択することができる ▪ Gloo ▪ CPU上での通信と、GPU上での一部の通信が実装されている ▪

    NCCL ▪ GPU上での最適化された通信が実装されている ▪ GPUではGlooより高速 ▪ Open MPI ▪ ビルド済みパッケージに含まれないため、ソースからビルドする必要がある ▪ 上2つで十分なため、特別な理由がないかぎり使用されない 利用できるDistributed Communications Backend
  5. AI ▪ torch.distributed.init_process_group を用いて初期化を行う ▪ 引数: ▪ rank: 現在のprocessのrank ▪

    world_size: 全体のprocess数 ▪ backend: 分散通信にどのライブラリを使用するか defaultではgloo(cpu)とnccl(gpu)が併用される ▪ 他に環境変数で以下を設定する必要がある ▪ MASTER_PORT ▪ MASTER_ADDR ▪ (RANKとWORLD_SIZEも指定でき、その場合はinit_process_groupで指定する必要はない) 8 Distributed SettingのSet Up
  6. AI 10 PyTorch Distributed Overview High level Low level Communication

    backend Communication APIs (C10D) Sharding primitives Parallelism APIs Laucher Gloo Open MPI NCCL send recv broadcast all_reduce reduce all_gather gather scatter reduce_scatter all_to_all barrier DTensor DeviceMesh Data-Parallel Distributed Data-Parallel Fully Sharded Data-Parallel ZeRO series (DeepSpeed etc…) Tensor Parallel Pipeline Parallel torchrun (Elastic Launch) torch.distributed.launch
  7. AI 11 ▪ Point-to-Point Communication • send (送信) • recv

    (受信) Communication APIs (C10D) https://pytorch.org/tutorials/intermediate/dist_tuto.html
  8. AI 13 PyTorch Distributed Overview High level Low level Communication

    backend Communication APIs (C10D) Sharding primitives Parallelism APIs Laucher Gloo Open MPI NCCL send recv broadcast all_reduce reduce all_gather gather scatter reduce_scatter all_to_all barrier DTensor DeviceMesh Data-Parallel Distributed Data-Parallel Fully Sharded Data-Parallel ZeRO series (DeepSpeed etc…) Tensor Parallel Pipeline Parallel torchrun (Elastic Launch) torch.distributed.launch
  9. AI 17 PyTorch Distributed Overview High level Low level Communication

    backend Communication APIs (C10D) Sharding primitives Parallelism APIs Laucher Gloo Open MPI NCCL send recv broadcast all_reduce reduce all_gather gather scatter reduce_scatter all_to_all barrier DTensor DeviceMesh Data-Parallel Distributed Data-Parallel Fully Sharded Data-Parallel ZeRO series (DeepSpeed etc…) Tensor Parallel Pipeline Parallel torchrun (Elastic Launch) torch.distributed.launch
  10. AI 18 ▪ ここまで説明した低レイヤーな操作を抽象化して nn.Moduleを並列するAPI ▪ Data-Parallel (DP) ▪ Distributed

    Data-Parallel (DDP) ▪ Tensor Parallel (TP) ▪ Pipeline Parallel (PP) ▪ Fully Sharded Data-Parallel (FSDP) ▪ ZeRO (DeepSpeedやFairScaleなどのサードパーティにて実装) Parallelism APIs
  11. AI 20 ▪ それぞれのGPU上のpipelineを別々のprocessが持つ ▪ DPと異なり、GPU間の通信は勾配の集約・分散のみ Distributed Data-Parallel (DDP) Dataloader

    GPU:0 GPU:1 GPU:2 batch Model0 Model1 Model2 勾配の集約・分散 Loss calc Lossの計算 Dataloader batch Dataloader batch Loss calc Loss calc 勾配の計算 それぞれのGPUでモデルパラメータの更新
  12. AI 21 ▪ 分散環境をsetup し、modelを DDP()でラップ ▪ checkpointの 保存・読み込み はprocess

    1のみ 行うようにする Distributed Data-Parallel (DDP) https://pytorch.org/tutorials/intermediate/ddp_tutorial.html
  13. AI 25 ▪ モデルが1 GPUに載る ▪ DP (非推奨) ▪ DDP

    ▪ モデルが1 GPUに載らない ▪ 演算ごとに細かく分割したい ▪ TP ▪ モデルの段階ごとに細かく分割したい ▪ PP ▪ PyTorchに分割はお任せしたい ▪ FSDP ▪ size_based_auto_wrap_policy 使い分け
  14. AI 26 PyTorch Distributed Overview High level Low level Communication

    backend Communication APIs (C10D) Sharding primitives Parallelism APIs Laucher Gloo Open MPI NCCL send recv broadcast all_reduce reduce all_gather gather scatter reduce_scatter all_to_all barrier DTensor DeviceMesh Data-Parallel Distributed Data-Parallel Fully Sharded Data-Parallel ZeRO series (DeepSpeed etc…) Tensor Parallel Pipeline Parallel torchrun (Elastic Launch) torch.distributed.launch
  15. AI 30 ▪ TensorFlow ▪ https://www.tensorflow.org/guide/distributed_training?hl =ja ▪ PyTorchでいうところのDP・DDP・TPなどが実装されている ▪

    Jax ▪ https://jax.readthedocs.io/en/latest/multi_process.html ▪ 低レベルのAPIが提供されており、適宜実装する必要がある? 余談:他のDLフレームワークでは