Slide 1

Slide 1 text

JAX: accelerating ML with composable transformations
Matthew Johnson ([email protected]), on behalf of the JAX team

Slide 2

Slide 2 text

How might you implement a deep neural network from scratch in Python?

Slide 3

Slide 3 text

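Motivating JAX

import numpy as np

def predict(params, inputs):
    # params is a list of (W, b) pairs, one per layer
    for W, b in params:
        outputs = np.dot(inputs, W) + b
        inputs = np.tanh(outputs)
    # the last layer's pre-activation is returned, so the output layer is linear
    return outputs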

Slide 4

Slide 4 text

Motivating JAX

import numpy as np

def predict(params, inputs):
    for W, b in params:
        outputs = np.dot(inputs, W) + b
        inputs = np.tanh(outputs)
    return outputs

def loss(params, batch):
    inputs, targets = batch
    preds = predict(params, inputs)
    # sum of squared errors over the whole batch
    return np.sum((preds - targets) ** 2)
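To make the parameter layout concrete, here is a small sketch that runs the slide's NumPy `predict` and `loss` on toy data. The `init_params` helper, the layer sizes, and the all-ones/all-zeros data are illustrative assumptions, not part of the talk.

```python
import numpy as np

def predict(params, inputs):
    for W, b in params:
        outputs = np.dot(inputs, W) + b
        inputs = np.tanh(outputs)
    return outputs

def loss(params, batch):
    inputs, targets = batch
    preds = predict(params, inputs)
    return np.sum((preds - targets) ** 2)

def init_params(layer_sizes, seed=0):
    # Hypothetical helper: one (W, b) pair per consecutive pair of layer sizes.
    rng = np.random.default_rng(seed)
    return [(rng.normal(size=(m, n)) / np.sqrt(m), np.zeros(n))
            for m, n in zip(layer_sizes[:-1], layer_sizes[1:])]

params = init_params([3, 4, 2])   # 3 inputs -> 4 hidden units -> 2 outputs
inputs = np.ones((5, 3))          # a batch of 5 examples
targets = np.zeros((5, 2))

preds = predict(params, inputs)
print(preds.shape)                        # (5, 2)
print(loss(params, (inputs, targets)))    # a non-negative scalar
```

Nothing here runs on an accelerator or computes gradients; that gap is what the next slide lists.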

Slide 5

Slide 5 text

Motivating JAX

What’s missing?
● Accelerator hardware (GPU/TPU)
● Training via automatic differentiation
● Optimized compilation with fusion, memory layout, remat, …
● Vectorized batching of operations
● Parallelization over multiple accelerators

import numpy as np

def predict(params, inputs):
    for W, b in params:
        outputs = np.dot(inputs, W) + b
        inputs = np.tanh(outputs)
    return outputs

def loss(params, batch):
    inputs, targets = batch
    preds = predict(params, inputs)
    return np.sum((preds - targets) ** 2)
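The list mentions remat (rematerialization): recomputing intermediate values during the backward pass instead of storing them, trading extra compute for lower memory. As a hedged sketch of how this looks in JAX, `jax.checkpoint` wraps a function so its internals are recomputed when differentiated. The toy `layer` function, the weight sharing across two layers, and the shapes are assumptions chosen for illustration.

```python
import jax
import jax.numpy as jnp

def layer(W, b, x):
    return jnp.tanh(jnp.dot(x, W) + b)

def loss_plain(W, b, x):
    # Two applications of the same layer; activations are stored for backprop.
    return jnp.sum(layer(W, b, layer(W, b, x)))

# jax.checkpoint ("remat"): intermediates inside `layer` are recomputed
# during the backward pass rather than saved from the forward pass.
layer_remat = jax.checkpoint(layer)

def loss_remat(W, b, x):
    return jnp.sum(layer_remat(W, b, layer_remat(W, b, x)))

k1, k2 = jax.random.split(jax.random.PRNGKey(0))
W = jax.random.normal(k1, (4, 4))
b = jnp.zeros(4)
x = jax.random.normal(k2, (2, 4))

g_plain = jax.grad(loss_plain)(W, b, x)   # gradient with respect to W
g_remat = jax.grad(loss_remat)(W, b, x)
print(jnp.allclose(g_plain, g_remat, atol=1e-5))  # True
```

The gradients are the same either way; remat only changes what is kept in memory versus recomputed.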

Slide 6

Slide 6 text

Motivating JAX

import jax.numpy as jnp
from jax import grad, jit, vmap

def predict(params, inputs):
    for W, b in params:
        outputs = jnp.dot(inputs, W) + b
        inputs = jnp.tanh(outputs)
    return outputs

def loss(params, batch):
    inputs, targets = batch
    preds = predict(params, inputs)
    return jnp.sum((preds - targets) ** 2)

# compiled gradient of the loss with respect to params
gradient_fun = jit(grad(loss))
# per-example gradients: params are shared (None), the batch is mapped over axis 0
perexample_grads = jit(vmap(grad(loss), in_axes=(None, 0)))
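A sketch of calling the slide's two transformed functions on toy data. `init_params`, `sgd_step`, the layer sizes, and the learning rate are hypothetical choices made for illustration, not part of the talk. Because `loss` sums over the batch, summing the per-example gradients should recover the full-batch gradient, which is a handy sanity check for `vmap`.

```python
import jax
import jax.numpy as jnp
from jax import grad, jit, vmap

def predict(params, inputs):
    for W, b in params:
        outputs = jnp.dot(inputs, W) + b
        inputs = jnp.tanh(outputs)
    return outputs

def loss(params, batch):
    inputs, targets = batch
    preds = predict(params, inputs)
    return jnp.sum((preds - targets) ** 2)

gradient_fun = jit(grad(loss))
perexample_grads = jit(vmap(grad(loss), in_axes=(None, 0)))

def init_params(key, layer_sizes):
    # Hypothetical helper: random (W, b) pairs, one per layer.
    params = []
    for m, n in zip(layer_sizes[:-1], layer_sizes[1:]):
        key, subkey = jax.random.split(key)
        params.append((jax.random.normal(subkey, (m, n)) / jnp.sqrt(m), jnp.zeros(n)))
    return params

params = init_params(jax.random.PRNGKey(0), [3, 4, 2])
batch = (jnp.ones((5, 3)), jnp.zeros((5, 2)))   # 5 examples

grads = gradient_fun(params, batch)          # same list-of-(W, b) structure as params
ex_grads = perexample_grads(params, batch)   # every leaf gains a leading axis of size 5
print(grads[0][0].shape)      # (3, 4)
print(ex_grads[0][0].shape)   # (5, 3, 4)

def sgd_step(params, batch, lr=1e-2):
    # Hypothetical plain gradient-descent update applied leaf by leaf.
    g = gradient_fun(params, batch)
    return jax.tree_util.tree_map(lambda p, dp: p - lr * dp, params, g)

params = sgd_step(params, batch)
```

`grad` returns gradients in the same nested structure as `params`, which is why a single `tree_map` is enough to write the update.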

Slide 7

Slide 7 text

Demo!

Slide 8

Slide 8 text

Vision Transformer
An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale. Dosovitskiy, Beyer, Kolesnikov, Weissenborn, Zhai, Unterthiner, Dehghani, Minderer, Heigold, Gelly, Uszkoreit, Houlsby. arXiv 2021.

Slide 9

Slide 9 text

NeRF
Learned Initializations for Optimizing Coordinate-Based Neural Representations. Tancik, Mildenhall, et al. CVPR 2021.
Mip-NeRF: A Multiscale Representation for Anti-Aliasing Neural Radiance Fields. Barron et al. arXiv 2021.
Nerfies: Deformable Neural Radiance Fields. Park et al. arXiv 2020.

Slide 10

Slide 10 text

JAX, MD for molecular dynamics and simulation
JAX, MD: a framework for differentiable physics. Schoenholz and Cubuk. NeurIPS 2020.
Designing self-assembling kinetics with differentiable physics models. Goodrich, King, Schoenholz, Cubuk, and Brenner. PNAS 2021.

Slide 11

Slide 11 text

DeepMind: FermiNet
Ab-Initio Solution of the Many-Electron Schrödinger Equation with Deep Neural Networks. David Pfau,* James S. Spencer,* Alex G. de G. Matthews, and W. M. C. Foulkes. Physical Review Research 2(3), 033429, September 2020.
FermiNet: Quantum Physics and Chemistry from First Principles. DeepMind Blog, 2020.

Slide 12

Slide 12 text

DeepMind: next generation of RL / MuZero work in JAX
Podracer architectures for scalable Reinforcement Learning. Hessel, Kroiss, et al. arXiv 2021.
MuZero Sampled and MuZero Unplugged. Hubert, Schrittwieser, et al. arXiv 2021.

Slide 13

Slide 13 text

DeepMind: next generation of AlphaFold work in JAX
AlphaFold: a solution to a 50-year-old grand challenge in biology. DeepMind Blog, 2020. (The post refers to work done in TF.)

Slide 14

Slide 14 text

No content

Slide 15

Slide 15 text

No content

Slide 16

Slide 16 text

What about scale?

Slide 17

Slide 17 text

MLPerf Training v0.7 results (in seconds, lower is better)
50000x speedup over 5 years!
* Google, Research category
† NVIDIA, Available On-Premise category
MLPerf v0.7 Training, closed division. Retrieved from www.mlperf.org 1 December 2020, entries 0.7-64, 0.7-65, 0.7-67, 0.7-30, 0.7-33, 0.7-37, 0.7-38. MLPerf name and logo are trademarks. See www.mlperf.org for more information.

Slide 18

Slide 18 text

What are Cloud TPUs?
= 4 TPU v3 chips (8 cores) attached to a CPU host
+ high-speed interconnects
+ compiler magic 🦄

Slide 19

Slide 19 text

What are Cloud TPU Pods?
= 1,024 TPU v3 chips (2,048 cores) attached to many CPU hosts
+ high-speed interconnects
+ compiler magic 🦄
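A minimal data-parallel sketch using `jax.pmap`, the multi-device API contemporary with this talk (newer JAX releases also offer sharding-based APIs). Each device computes the gradient on its own shard of the batch, and `jax.lax.pmean` averages the gradients across devices. The linear model, shard size, and axis name `"devices"` are assumptions for illustration. On the single-host Cloud TPU described above, JAX is expected to report 8 local devices (one per core); on an ordinary CPU it reports 1, and the same code still runs.

```python
import functools
import jax
import jax.numpy as jnp

n_dev = jax.local_device_count()

def loss(w, x, y):
    # Mean squared error of a toy linear model.
    return jnp.mean((jnp.dot(x, w) - y) ** 2)

@functools.partial(jax.pmap, axis_name="devices")
def parallel_grad(w, x, y):
    g = jax.grad(loss)(w, x, y)                      # gradient on this device's shard
    return jax.lax.pmean(g, axis_name="devices")     # average across all devices

w = jnp.zeros(3)
ws = jnp.stack([w] * n_dev)    # replicate the parameters on every device
x = jnp.ones((n_dev, 4, 3))    # one shard of 4 examples per device
y = jnp.ones((n_dev, 4))

g = parallel_grad(ws, x, y)
print(g.shape)   # (n_dev, 3): every row holds the same averaged gradient
```

With `w = 0` and all-ones data, each residual is -1, so every gradient entry is (2/4) · 4 · (-1) = -2 regardless of the device count.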

Slide 20

Slide 20 text

A Cloud TPU Pod Slice
[Diagram: several CPU hosts, each attached to its own TPU chips]

Slide 21

Slide 21 text

A Cloud TPU Pod Slice running JAX
[Diagram: several CPU hosts, each attached to its own TPU chips and reached over ssh; the hosts use Cloud Storage (datasets, checkpoints, etc.) and optional local disk]
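The diagram suggests the same JAX program being started on every host over ssh. A hedged sketch of what such a per-host script might check: each process sees only its local devices, but a collective like `jax.lax.psum` inside `pmap` spans the devices of all hosts. On Cloud TPU pod slices JAX is expected to discover the other hosts on its own; on other platforms `jax.distributed.initialize()` may be needed first (an assumption to verify for your JAX version). The function name `count_devices` and axis name `"i"` are illustrative.

```python
import functools
import jax
import jax.numpy as jnp

# Every host runs this same script; each one is a separate JAX process.
print("process", jax.process_index(), "of", jax.process_count(),
      "| local devices:", jax.local_device_count(),
      "| global devices:", jax.device_count())

@functools.partial(jax.pmap, axis_name="i")
def count_devices(x):
    # psum over the pmapped axis sums across every device on every host.
    return jax.lax.psum(x, axis_name="i")

ones = jnp.ones(jax.local_device_count())   # one entry per local device
total = count_devices(ones)
print(total)   # each entry equals jax.device_count()
```

On a single machine this degenerates to one process, and the global and local device counts coincide.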

Slide 22

Slide 22 text

Demo!
https://twitter.com/jekbradbury/status/1337528357517291520

Slide 23

Slide 23 text

JAX: accelerating ML with composable transformations Matthew Johnson ([email protected]) on behalf of the JAX team

Slide 24

Slide 24 text

:D :}