
Poster: Distributed Training with JAX and Kubeflow


Mahdi Khashan

September 02, 2025

Transcript

Distributed Training with JAX and Kubeflow

Motivation
This project implements key components of KEP-2170, creating TrainingRuntime and ClusterTrainingRuntime CRDs for JAX. Built on the Kubernetes JobSet API, these reusable blueprints simplify LLM and model training within cloud-native ML pipelines, letting AI practitioners submit jobs via the SDK or YAML without managing low-level Kubernetes orchestration.

Takeaways
• Simplified UX: scientists reference runtimes through a high-level Python SDK instead of writing complex YAML.
• Reusability: blueprints can be curated by admins and shared consistently.
• Framework-agnostic: the same API covers JAX, PyTorch, LLMs, and more.
• Cloud-native scalability: leverages JobSet and Kubernetes for distributed execution.

Python SDK Example

    from kubeflow.trainer import CustomTrainer, TrainerClient

    # Training function executed on each node; the poster leaves the body
    # empty (a sketch of a possible body follows after the transcript).
    def jax_train_mnist(args):
        pass

    client = TrainerClient()

    # Look up the reusable runtime blueprint curated by the cluster admin.
    jax_runtime = next(
        r for r in client.list_runtimes() if r.name == "jax-distributed"
    )

    # Submit a 4-node distributed training job against that runtime.
    job_id = client.train(
        trainer=CustomTrainer(
            func=jax_train_mnist,
            func_args={"epoch": "10"},
            num_nodes=4,
        ),
        runtime=jax_runtime,
    )

YAML using a Runtime Blueprint

    apiVersion: trainer.kubeflow.org/v1alpha1
    kind: ClusterTrainingRuntime
    metadata:
      name: jax-distributed
    spec:
      mlPolicy:
        numNodes: 4
        jax:
          backend: nccl
      template:
        spec:
          replicatedJobs:
            - name: process
              template:
                spec:
                  template:
                    spec:
                      containers:
                        - name: node
                          image: kubeflow/jax-runtime

Key Innovations from Trainer V2
• Unified CRDs: TrainJob, TrainingRuntime, and ClusterTrainingRuntime replace framework-specific controllers (e.g., JAXJob, PyTorchJob) with a single, flexible interface.
• Reusable runtime blueprints let admins standardize compute environments, while practitioners simply reference them in TrainJobs (see the TrainJob sketch after the transcript).
• Built on the Kubernetes JobSet API, enabling scalable, multi-pod distributed training across TPU / GPU / CPU.
• Full SDK support for programmatic job submission and management.

Core Components
• Kubernetes: container orchestration at scale.
• Kubeflow Trainer V2: unified API with reusable runtime abstractions.
• TrainingRuntime / ClusterTrainingRuntime: define the environment and resources for training.
• JAX: primary framework supported.

Mahdi Khashan
Master's in Artificial Intelligence at JKU
Scalable ML on Kubernetes
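The SDK example above leaves the body of jax_train_mnist empty. A minimal sketch of what such a function could do inside each training pod is shown below. It assumes the runtime injects the JAX coordinator configuration through environment variables and that func_args is delivered to the function as a dict; the linear model and synthetic data are placeholders, not part of the poster.

    import jax
    import jax.numpy as jnp

    def jax_train_mnist(args):
        # Join the multi-process JAX cluster. With no arguments,
        # jax.distributed.initialize() reads the coordinator address,
        # process count, and process index from the environment, which the
        # runtime is assumed to provide (skip this call for a single process).
        jax.distributed.initialize()

        print(f"process {jax.process_index()}/{jax.process_count()}, "
              f"local devices: {jax.local_device_count()}")

        # Placeholder model: linear regression on synthetic data, trained with
        # plain SGD. Each process runs its own loop here; a real data-parallel
        # setup would shard the batch and average gradients across devices
        # (e.g. with jax.pmap and jax.lax.pmean).
        def loss_fn(w, x, y):
            pred = x @ w
            return jnp.mean((pred - y) ** 2)

        grad_fn = jax.jit(jax.grad(loss_fn))

        key = jax.random.PRNGKey(jax.process_index())
        w = jnp.zeros((784,))
        for epoch in range(int(args.get("epoch", 1))):
            key, subkey = jax.random.split(key)
            x = jax.random.normal(subkey, (128, 784))
            y = jnp.ones((128,))
            w = w - 0.01 * grad_fn(w, x, y)
            if jax.process_index() == 0:
                print(f"epoch {epoch}: loss {float(loss_fn(w, x, y)):.4f}")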
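For YAML-based submission, a practitioner would create a TrainJob that references the blueprint above by name instead of repeating the JobSet details. The sketch below follows the Trainer V2 (KEP-2170) API shape; the job name jax-mnist is hypothetical, and field names such as runtimeRef and trainer.numNodes should be checked against the CRD version installed in the cluster.

    apiVersion: trainer.kubeflow.org/v1alpha1
    kind: TrainJob
    metadata:
      name: jax-mnist
      namespace: default
    spec:
      # Reference the admin-curated blueprint instead of an inline JobSet spec.
      runtimeRef:
        name: jax-distributed
        kind: ClusterTrainingRuntime
      trainer:
        # Override blueprint defaults where needed, e.g. the node count.
        numNodes: 4

The controller materializes the JobSet from the referenced blueprint, so the same runtime can back many TrainJobs submitted either this way or through the SDK call shown above.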