JAX: Accelerated Machine Learning Research via Composable Function Transformations in Python (Matt Johnson, Google Brain)

This talk is about JAX, a system for high-performance machine learning research and numerical computing. JAX offers the familiarity of Python+NumPy together with hardware acceleration, and it combines these features with user-wielded function transformations, including automatic differentiation, automatic vectorized batching, end-to-end compilation (via XLA), parallelization over multiple accelerators, and more. Composing these transformations is the key to JAX's power and simplicity. Researchers use it for a wide range of advanced applications, from large-scale neural net training to probabilistic programming to scientific applications in physics and biology.

Anyscale

July 20, 2021

Transcript

  1. Motivating JAX

    import numpy as np

    def predict(params, inputs):
        for W, b in params:
            outputs = np.dot(inputs, W) + b
            inputs = np.tanh(outputs)
        return outputs
  2. Motivating JAX

    import numpy as np

    def predict(params, inputs):
        for W, b in params:
            outputs = np.dot(inputs, W) + b
            inputs = np.tanh(outputs)
        return outputs

    def loss(params, batch):
        inputs, targets = batch
        preds = predict(params, inputs)
        return np.sum((preds - targets) ** 2)
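    For context, params is a list of (W, b) weight/bias pairs, one per layer. A minimal sketch of building such parameters and evaluating the loss, assuming the predict and loss functions above are in scope (the layer sizes and random initialization are illustrative, not from the slides):

    import numpy as np

    # Hypothetical MLP sizes: 2 inputs -> 32 hidden units -> 1 output.
    layer_sizes = [2, 32, 1]
    rng = np.random.default_rng(0)

    # One (W, b) pair per layer, matching the loop in predict.
    params = [(0.1 * rng.normal(size=(m, n)), np.zeros(n))
              for m, n in zip(layer_sizes[:-1], layer_sizes[1:])]

    inputs = rng.normal(size=(128, 2))       # a batch of 128 examples
    targets = rng.normal(size=(128, 1))
    print(loss(params, (inputs, targets)))   # plain NumPy: CPU only, no gradients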
  3. Motivating JAX: What's missing?

    • Accelerator hardware (GPU/TPU)
    • Training via automatic differentiation
    • Optimized compilation with fusion, memory layout, remat, …
    • Vectorized batching of operations
    • Parallelization over multiple accelerators

    import numpy as np

    def predict(params, inputs):
        for W, b in params:
            outputs = np.dot(inputs, W) + b
            inputs = np.tanh(outputs)
        return outputs

    def loss(params, batch):
        inputs, targets = batch
        preds = predict(params, inputs)
        return np.sum((preds - targets) ** 2)
  4. Motivating JAX

    import jax.numpy as jnp
    from jax import grad, jit, vmap

    def predict(params, inputs):
        for W, b in params:
            outputs = jnp.dot(inputs, W) + b
            inputs = jnp.tanh(outputs)
        return outputs

    def loss(params, batch):
        inputs, targets = batch
        preds = predict(params, inputs)
        return jnp.sum((preds - targets) ** 2)

    gradient_fun = jit(grad(loss))
    perexample_grads = jit(vmap(grad(loss), in_axes=(None, 0)))
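    A minimal usage sketch, assuming the functions on this slide are in scope and using illustrative shapes that are not from the slides: gradient_fun returns gradients in the same (W, b) structure as params, summed over the batch, while perexample_grads adds a leading per-example axis because vmap with in_axes=(None, 0) holds params fixed and maps over axis 0 of the batch.

    import jax
    import jax.numpy as jnp

    k1, k2, k3 = jax.random.split(jax.random.PRNGKey(0), 3)

    # Hypothetical single-layer net: 2 inputs -> 1 output.
    params = [(0.1 * jax.random.normal(k1, (2, 1)), jnp.zeros(1))]
    batch = (jax.random.normal(k2, (128, 2)),    # inputs
             jax.random.normal(k3, (128, 1)))    # targets

    grads = gradient_fun(params, batch)
    print(grads[0][0].shape)               # (2, 1): one gradient, summed over the batch

    pex = perexample_grads(params, batch)
    print(pex[0][0].shape)                 # (128, 2, 1): one gradient per example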
  5. Vision Transformer

    An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale. Dosovitskiy, Beyer, Kolesnikov, Weissenborn, Zhai, Unterthiner, Dehghani, Minderer, Heigold, Gelly, Uszkoreit, Houlsby. arXiv 2021.
  6. NeRF

    Learned Initializations for Optimizing Coordinate-Based Neural Representations. Tancik, Mildenhall, et al. CVPR 2021.
    Mip-NeRF: A Multiscale Representation for Anti-Aliasing Neural Radiance Fields. Barron et al. arXiv 2021.
    Nerfies: Deformable Neural Radiance Fields. Park et al. arXiv 2020.
  7. JAX, MD for molecular dynamics and simulation

    JAX, MD: a framework for differentiable physics. Schoenholz and Cubuk. NeurIPS 2020.
    Designing self-assembling kinetics with differentiable physics models. Goodrich, King, Schoenholz, Cubuk, and Brenner. PNAS 2021.
  8. DeepMind: FermiNet

    Ab-Initio Solution of the Many-Electron Schrödinger Equation with Deep Neural Networks. David Pfau,* James S. Spencer,* Alex G. de G. Matthews, and W. M. C. Foulkes. Physical Review Research 2(3), 033429, September 2020.
    FermiNet: Quantum Physics and Chemistry from First Principles. DeepMind Blog, 2020.
  9. DeepMind: next generation of RL / MuZero work in JAX

    Podracer architectures for scalable Reinforcement Learning. Hessel, Kroiss, et al. arXiv 2021.
    MuZero Sampled and MuZero Unplugged. Hubert, Schrittwieser, et al. arXiv 2021.
  10. DeepMind: next generation of AlphaFold work in JAX

    AlphaFold: a solution to a 50-year-old grand challenge in biology. DeepMind Blog, 2020. (The post refers to work done in TF.)
  11. MLPerf Training v0.7 results (in seconds, lower is better)

    * Google, Research category
    † NVIDIA, Available On-Premise category
    MLPerf v0.7 Training, closed division. Retrieved from www.mlperf.org on 1 December 2020; entries 0.7-64, 0.7-65, 0.7-67, 0.7-30, 0.7-33, 0.7-37, 0.7-38. MLPerf name and logo are trademarks. See www.mlperf.org for more information.

    50,000x speedup over 5 years!
  12. What are Cloud TPUs?

    One Cloud TPU = 4 TPU v3 chips (8 cores) attached to a CPU host + high-speed interconnects + compiler magic 🦄
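    On a Cloud TPU VM with JAX installed, those cores show up as ordinary JAX devices. A quick sketch of checking what is attached; the counts in the comments are what one would typically expect on a single v3 host, not values from the slides:

    import jax

    print(jax.device_count())   # e.g. 8 TPU cores on one v3 host (4 chips x 2 cores)
    print(jax.devices())        # e.g. [TpuDevice(id=0), ..., TpuDevice(id=7)]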
  13. What are Cloud TPU Pods?

    One Cloud TPU v3 Pod = 1,024 TPU v3 chips (2,048 cores) attached to many CPU hosts + high-speed interconnects + compiler magic 🦄
  14. A Cloud TPU Pod Slice running JAX

    [diagram: several CPU hosts reached over ssh, each with attached TPUs and an optional local disk, all reading and writing Cloud Storage (datasets, checkpoints, etc.)]
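    Within a single host, the same composable-transformations story extends across the attached TPU cores via pmap. A minimal data-parallel sketch, assuming the predict and loss functions from the earlier slides are in scope and using purely illustrative parameters and shapes (on CPU this still runs, just with a single device):

    from functools import partial

    import jax
    import jax.numpy as jnp
    from jax import grad, pmap

    n_dev = jax.local_device_count()             # e.g. 8 cores on one TPU v3 host

    # Replicate the parameters onto every local device and shard the batch
    # along a leading device axis.
    params = [(jnp.zeros((2, 1)), jnp.zeros(1))]
    replicated = jax.device_put_replicated(params, jax.local_devices())
    inputs = jnp.ones((n_dev, 16, 2))            # n_dev shards of 16 examples each
    targets = jnp.ones((n_dev, 16, 1))

    @partial(pmap, axis_name='devices')
    def parallel_grads(params, inputs, targets):
        g = grad(loss)(params, (inputs, targets))
        # Average gradients across devices so every replica sees the same update.
        return jax.lax.pmean(g, axis_name='devices')

    grads = parallel_grads(replicated, inputs, targets)
    print(grads[0][0].shape)                     # (n_dev, 2, 1): one copy per device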