
Keras Core with JAX: Streamlining Deep Learning for Robust Acceleration

Jeongkyu Shin
September 09, 2023

Have you heard of JAX? DeepMind's matrix manipulation library, JAX, harnesses the power of diverse devices, delivering robust deep learning performance. However, its relatively weak infrastructure has posed challenges for its use. In this presentation, we introduce JAX, and explore how, through Keras Core and Keras 3, we leverage the powerful performance of JAX on a range of accelerators. Join us as we delve into streamlining the utilization of JAX's performance, making deep learning more accessible and efficient for all.

-
This presentation was given as an invited talk for the Keras Community Day Kuala Lumpur 2023.


Transcript

  1. Keras Core with JAX: Streamlining Deep Learning for Robust Acceleration

    Jeongkyu Shin, ML GDE / Cloud Champion Innovator, CEO / Lablup Inc.
    Sep. 8, 2023
    Head image: https://www.hopefortheflowers.com
  3. Hi!

    • Lablup Inc.: Make AI Accessible
      - Open source machine learning cluster platform: Backend.AI
      - https://www.backend.ai
    • Google Developer Expert
      - ML / DL GDE (2017)
      - Google Cloud Champion Innovator (2021)
      - Google for Startup Accelerator Mentor
    • Open source
      - Textcube Project Moderator (20th year!)
    • Physics / Neuroscience
      - Ph.D. in Statistical Physics (Complex System / Computational Neuroscience)
      - Former Adj. Professor, Hanyang Univ. ERICA, Software Dept.
  4. Keras Core

    • You already heard about it from previous talks
      - So, let's see the backstories!
    • New Keras team repository
      - Will be Keras 3.0 this fall
    • Multi-backend support
    • Is it the same as Keras?
      - Well, you seem to have missed many parts of the history…
  5. Keras: Back then...

    • Long, long time ago, there was Keras
      - High-end deep learning library
      - Python interface for neural network research
      - Multi-backend support: without its own backend
      - Supported TensorFlow, Theano, CNTK, MXNet…
      - Looks familiar? Welcome to the oldies club!
    • And then Keras became part of TensorFlow from 1.4
      - And finally it was bound to the TensorFlow backend only, since Keras 2.4 (2020. 6)
      - TensorFlow has a Keras implementation under tf.keras, and Keras became an alias for tf.keras
      - Is it over? But fchollet mentioned it as a temporary option
    https://github.com/keras-team/keras/releases/tag/2.4.0
  6. Keras: Recent changes

    • Fill the bottle
      - Keras repository became alive again from 2021
      - Copy-paste based sync from tf.keras code since TensorFlow 2.5
      - And Keras became a separate package from TensorFlow 2.6 (2021. 7)
      - What happened? Many ideas and stories… Who knows?
    • And after two years
      - Keras is completely rewritten!
      - Keras Core
  7. From Keras to Keras Core: long story short

    [Timeline figure, 2015 / 2017 / 2019 / 2021 / 2023. Labels as extracted: Keras released on top of Theano; Keras Core announced; keras core repo update restarted; TensorFlow support added; tf.keras announced for TensorFlow 1.4; TensorFlow 2.6 separates Keras; drops multi-backend, Keras merged into TF.]
  8. Keras Core: From ground up

    • Completely rewritten code
      - Lines of code: 135k to 45k
      - How?! Drop TensorFlow 1.X compatibility
    • Supported backends (2023)
      - TensorFlow
      - JAX
      - PyTorch
      - NumPy* (*not for training)
    Image: Modern Times (1936)
  9. So what will be better with Keras Core?

    • Switchable backend
    • Multi-framework custom components
    • Universal training loop
    • Native models support
    • Future-proof code
  10. Switchable backend

    • As easy as before
      - os.environ["KERAS_BACKEND"]
    • Notes
      - Important: import order. torch should be imported before tensorflow
      - Torch backend: import keras_core before tensorflow
      - No sparsity support yet: sparse types across frameworks are too broad now
      - PyTorch compatibility: average pooling, integer dtypes

        # JAX backend
        import os
        os.environ["KERAS_BACKEND"] = "jax"
        import keras_core
        import jax
        import tensorflow as tf

        # PyTorch backend
        import os
        os.environ["KERAS_BACKEND"] = "torch"
        import keras_core
        import torch
        import tensorflow as tf
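The backend is read from the `KERAS_BACKEND` environment variable when the library is first imported, so the variable has to be set before the import. A minimal sketch of that ordering follows; `select_backend` and `SUPPORTED_BACKENDS` are hypothetical names for illustration, not part of Keras Core, and the backend names are the ones listed on the earlier slide.

```python
import os

# Backend names listed in this deck; NumPy is not for training.
SUPPORTED_BACKENDS = ("tensorflow", "jax", "torch", "numpy")


def select_backend(name: str) -> str:
    """Hypothetical helper: validate a backend name and export it.

    Call this before the first `import keras_core`, because the
    library reads KERAS_BACKEND at import time.
    """
    if name not in SUPPORTED_BACKENDS:
        raise ValueError(
            f"unknown backend {name!r}; expected one of {SUPPORTED_BACKENDS}"
        )
    os.environ["KERAS_BACKEND"] = name
    return name


selected = select_backend("jax")
# import keras_core  # import only after the variable is set
```

Changing the variable after the import has no effect on the already-loaded library, which is why the slide puts the `os.environ` line first.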
  11. Cross-framework components with keras.ops

    • NumPy API
      - matmul, sum, stack, einsum…
    • Network-specific functions
      - softmax, binary crossentropy, conv…
    • Models / layers / losses / metrics / optimizers
      - Work the same with any framework
      - Even outside the Keras workflow!
  12. Multi-framework custom components

    • Use keras.ops to develop custom components
    • With backend-specific custom components
      - TensorFlow, PyTorch, JAX
  14. Universal Training Loop: Seamless integration w/ backends

    • Train a Keras model with
      - Low-level JAX training loop: optax optimizer, jax.grad, jax.jit, jax.pmap…
      - Low-level TensorFlow training loop: tf.GradientTape, tf.distribute
      - Low-level PyTorch training loop: torch.optim optimizer, torch.nn.parallel.DistributedDataParallel
    • Use a Keras layer or model as part of a torch.nn.Module
      - PyTorch users can leverage Keras models like any other PyTorch module!
    https://www.youtube.com/watch?v 5fTPEoeFZk
  15. Native model support: Keras with pretrained models

    • Supports all current Keras applications
      - KerasCV: YOLOv8 …
      - KerasNLP: BERT, OPT …
    • Supports all Keras applications with all backends!
  16. JAX: What is it?

    • JAX (DeepMind, 2018)
      - Designed for high-performance machine learning research
      - Provides composable transformations of Python and NumPy programs: automatic differentiation, vectorization (vmap), parallelization (pmap), etc.
      - Autograd capability: the grad function within JAX can automatically differentiate any function
      - Others: random number generator, etc.
    • Why JAX?
      - Scalable: CPU, GPU and TPUs with the XLA compiler
      - XLA?
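The "composable transformations" on this slide can be illustrated with a toy function. This is a small sketch rather than anything from the talk: `f` is an arbitrary example function, and `jax.grad` and `jax.vmap` are stacked to get a batched derivative.

```python
import jax
import jax.numpy as jnp


def f(x):
    # Scalar function: f(x) = x^2 + 3x
    return x ** 2 + 3.0 * x


# grad turns f into a function that returns df/dx = 2x + 3.
df = jax.grad(f)

# vmap maps the scalar derivative over a batch without a Python loop.
batched_df = jax.vmap(df)

xs = jnp.array([0.0, 1.0, 2.0])
slopes = batched_df(xs)  # derivative of f at each point in xs
```

Because each transformation returns an ordinary function, the result can be passed on to `jax.jit` or `jax.pmap` in the same way.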
  17. JAX: What lies beneath?

    • XLA (Accelerated Linear Algebra, Google, 2017)
      - Fast matrix operations for CPU, Nvidia GPU and TPU
      - Just-In-Time (JIT) compilation, Ahead-Of-Time (AOT) compilation
    • MLIR (Multi-Level Intermediate Representation, Google, 2019)
      - https://mlir.llvm.org
      - LLVM-backed intermediate compilation layer for machine learning
    • OpenXLA (Google, 2023)
      - Open source version of XLA + StableHLO
      - StableHLO: operation set for high-level operations (HLO) in ML models
      - Machine learning compiler for ML accelerator hardware
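The difference between JIT and AOT compilation can be seen from JAX itself. In the sketch below (the function `affine` and its shapes are made up for illustration), `jax.jit` compiles lazily on the first call, while `.lower(...).compile()` does the same work explicitly ahead of the call.

```python
import jax
import jax.numpy as jnp


def affine(x, w, b):
    return jnp.dot(x, w) + b


x = jnp.ones((4, 3))
w = jnp.ones((3, 2))
b = jnp.zeros((2,))

# JIT: compilation happens lazily on the first call with these shapes.
jit_affine = jax.jit(affine)
y_jit = jit_affine(x, w, b)

# AOT: lower and compile explicitly for these argument shapes, then call.
lowered = jax.jit(affine).lower(x, w, b)
compiled = lowered.compile()
y_aot = compiled(x, w, b)

# The lowered module can be inspected as compiler IR text.
ir_text = lowered.as_text()
```

Printing `ir_text` shows the intermediate representation that is handed to XLA, which is a convenient way to see what the layers on this slide operate on.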
  18. JAX: Ecosystem

    • JAX
      - Low-level API, like PyTorch without torch.nn or TensorFlow without tf.keras
    • FLAX (FLexible JAX)
      - Layer API from Google (excluding DeepMind)
      - High-level interface for defining and training neural networks based on JAX
    • Haiku
      - Another layer API, from DeepMind
      - Inspired by Sonnet for TensorFlow
      - Not updated anymore: FLAX is recommended for the JAX ecosystem
    • OPTAX
      - Optimizers and loss functions for JAX
    • ETC.
  19. JAX: Trouble

    • There is no free lunch
      - Steep learning curve: JAX does not provide high-level neural network features, and it requires functional programming
      - JIT nature: immutable objects lead to low debuggability (like old TensorFlow 1.X)
      - Lack of documentation yet!
    • Does FLAX/Haiku solve the issue?
      - Lack of third-party projects / code
      - Independent project: one more deep learning library
  20. JAX: Double Trouble

    • To unleash the power of JAX, you'd better care about distributed workloads
      - Data sharding
      - Device mesh
      - ETC.
    • @jax.jit
      - Decorator to use Just-In-Time compilation
      - In fact, there is a jit option in TensorFlow too
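A short sketch of what the `@jax.jit` decorator does in practice. The function `normalize` and the `trace_log` list are illustrative names only; the list records when Python actually executes the function body, which happens during tracing and not on cached calls.

```python
import jax
import jax.numpy as jnp

trace_log = []  # one entry per trace of the function body


@jax.jit
def normalize(x):
    # Python side effect: runs only while JAX traces the function,
    # not when a cached compiled version is reused.
    trace_log.append(x.shape)
    return (x - jnp.mean(x)) / (jnp.std(x) + 1e-6)


a = normalize(jnp.arange(4.0))        # first call: trace and compile
b = normalize(jnp.arange(4.0) + 1.0)  # same shape and dtype: cached
c = normalize(jnp.arange(8.0))        # new shape: traced again
```

This behaviour is also one source of the debuggability concerns mentioned earlier: a `print` placed inside a jitted function fires at trace time, not on every call.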
  21. Example: Casual Keras Code

    • Simple image classification model
      - Define the CNN (layer definition: input to output)
      - Compile the model
      - Train the model
      - Evaluate the resulting model

        import jax
        import jax.numpy as jnp
        import keras_core as keras  # Keras multi-backend

        # Input definition
        # Layer structure definition
        # Model compilation with input / output
        # Train
        # Evaluate
        # (Optional) Model export
        # (Optional, usually as indep. project) Inference
  22. Example: Casual Keras Code

    • Simple image classification model
      - Define the CNN (layer definition: input to output)
      - Compile the model
      - Train the model
      - Evaluate the resulting model

        import keras_core

        # Input definition
        conv2d_kwargs = {
            "kernel_size": (3, 3),
            "activation": "relu",
            "padding": "same",
        }
        inputs = keras_core.Input(shape=(32, 32, 3), name="input_layer")

        # Layer structure definition
        x = inputs
        for filters in [32, 64, 128]:
            x = keras_core.layers.Conv2D(filters=filters, **conv2d_kwargs)(x)
            x = keras_core.layers.BatchNormalization()(x)
            # Strided convolution halves the spatial resolution.
            x = keras_core.layers.Conv2D(filters=filters, strides=(2, 2), **conv2d_kwargs)(x)
            x = keras_core.layers.BatchNormalization()(x)
            x = keras_core.layers.Dropout(0.25)(x)
        x = keras_core.layers.GlobalAveragePooling2D()(x)
        x = keras_core.layers.Dense(128, activation="relu")(x)
        x = keras_core.layers.Dropout(0.25)(x)
        outputs = keras_core.layers.Dense(10, activation="softmax", name="output_layer")(x)
  23. Example: Custom model with JAX

    • JAX
      - Stateless: JAX provides tools, not a complete framework
      - State is passed explicitly as a parameter
    • State
      - Trainable parameters
      - Non-trainable parameters
      - Optimizer state
    • Keras Core + JAX
      - We need to take the stateless manner of JAX into account
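"State is passed explicitly" means a function never mutates hidden variables: it receives the old state and returns the new one. The sketch below shows the pattern with a made-up running-mean update (similar in spirit to a non-trainable statistic); `running_mean_step` and the state layout are assumptions for illustration only.

```python
import jax
import jax.numpy as jnp


def running_mean_step(state, x, momentum=0.9):
    """Stateless update: read the old state, return the new one."""
    new_mean = momentum * state["mean"] + (1.0 - momentum) * jnp.mean(x)
    new_state = {"mean": new_mean, "steps": state["steps"] + 1}
    return x - new_mean, new_state


state = {"mean": jnp.array(0.0), "steps": jnp.array(0)}
step = jax.jit(running_mean_step)

for batch in (jnp.array([1.0, 3.0]), jnp.array([2.0, 6.0])):
    # The caller threads the state through every call.
    centered, state = step(state, batch)
```

Because nothing is mutated in place, the same function can be jitted, differentiated or sharded without surprises, which is the property the custom Keras model has to respect.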
  24. Custom model with JAX: Requirements

    • Required methods for a custom model
      - compute_loss_and_updates
      - train_step
      - test_step
    • compute_loss_and_updates
      - Calculates the loss and the updated variables

        class CustomModel(keras_core.Model):
            def compute_loss_and_updates(
                self,
                trainable_variables,
                non_trainable_variables,
                x,
                y,
                training=False,
            ):
                y_pred, non_trainable_variables = self.stateless_call(
                    trainable_variables,
                    non_trainable_variables,
                    x,
                    training=training,
                )
                loss = self.compute_loss(x, y, y_pred)
                return loss, (y_pred, non_trainable_variables)
  25. Custom model with JAX: Train steps

    • train_step
      - Computes gradients
      - Updates trainable variables
      - Updates optimizer variables
      - Updates metrics

        def train_step(self, state, data):
            # Unpack the current state
            (
                trainable_variables,
                non_trainable_variables,
                optimizer_variables,
                metrics_variables,
            ) = state
            grad_fn = jax.value_and_grad(
                self.compute_loss_and_updates, has_aux=True
            )
            # Compute the gradients.
            (loss, (y_pred, non_trainable_variables)), grads = grad_fn(…)
            # Update trainable variables and optimizer variables.
            (…) = self.optimizer.stateless_apply(…)
            # Update metrics.
            state = (…)
            return logs, state
  26. Custom model with JAX: Train steps

    • train_step
      - Computes gradients
      - Updates trainable variables
      - Updates optimizer variables
      - Updates metrics

        def train_step(self, state, data):
            # Unpack the current state
            (
                trainable_variables,
                non_trainable_variables,
                optimizer_variables,
                metrics_variables,
            ) = state
            # Unpack the data
            x, y = data
            # Get the gradient function.
            grad_fn = jax.value_and_grad(
                self.compute_loss_and_updates, has_aux=True
            )
            # Compute the gradients.
            (loss, (y_pred, non_trainable_variables)), grads = grad_fn(
                trainable_variables,
                non_trainable_variables,
                x,
                y,
                training=True,
            )
            # Update trainable variables and optimizer variables.
            (
  27. Custom model with JAX: compute gradients

    • train_step
      - Computes gradients
        • Get the grad function from JAX
        • The gradients are computed by calling grad_fn with the current trainable and non-trainable variables and the data.
        • Also returns the loss and the updated non-trainable variables.
      - Updates trainable variables
      - Updates optimizer variables
      - Updates metrics

        def train_step(self, state, data):
            # Unpack the current state
            …
            # Unpack the data
            x, y = data
            # Get the gradient function.
            grad_fn = jax.value_and_grad(
                self.compute_loss_and_updates, has_aux=True
            )
            # Compute the gradients.
            (loss, (y_pred, non_trainable_variables)), grads = grad_fn(
                trainable_variables,
                non_trainable_variables,
                x,
                y,
                training=True,
            )
            …
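The `has_aux=True` pattern used here can be tried in isolation. The following sketch uses a made-up linear model (`loss_and_aux`, `w`, `b` are illustrative names): the function returns a `(loss, aux)` pair, only the loss is differentiated, and by default the gradient is taken with respect to the first argument.

```python
import jax
import jax.numpy as jnp


def loss_and_aux(w, b, x, y):
    """Returns (loss, aux); only the first output is differentiated."""
    y_pred = x @ w + b
    loss = jnp.mean((y_pred - y) ** 2)
    return loss, y_pred


# has_aux=True tells JAX that the function returns a (loss, aux) pair.
# The gradient is taken w.r.t. the first argument only (w).
grad_fn = jax.value_and_grad(loss_and_aux, has_aux=True)

x = jnp.array([[1.0], [2.0]])
y = jnp.array([2.0, 4.0])
w = jnp.array([0.0])
b = jnp.array(0.0)

(loss, y_pred), grads = grad_fn(w, b, x, y)
```

This mirrors the slide: the trainable variables come first so they are the ones differentiated, while the predictions and updated non-trainable variables ride along as auxiliary outputs.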
  28. Custom model with JAX: update variables

    • train_step
      - Computes gradients
      - Updates trainable variables and optimizer variables
        • The stateless_apply method of the optimizer is used to update the trainable and optimizer variables using the computed gradients
      - Updates metrics

        def train_step(self, state, data):
            # Unpack the current state
            …
            # Unpack the data
            …
            # Get the gradient function.
            …
            # Compute the gradients.
            …
            # Update trainable variables and optimizer variables.
            (
                trainable_variables,
                optimizer_variables,
            ) = self.optimizer.stateless_apply(
                optimizer_variables, grads, trainable_variables
            )
            # Update metrics.
            …
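To make the idea of a stateless optimizer step concrete, here is a hand-written SGD-with-momentum update in plain JAX. It is not the Keras implementation of `stateless_apply`; `sgd_momentum_apply` is a hypothetical function that only follows the same shape: optimizer state, gradients and parameters go in, new parameters and new optimizer state come out.

```python
import jax
import jax.numpy as jnp


def sgd_momentum_apply(velocities, grads, params, lr=0.1, momentum=0.9):
    """Hypothetical stateless update, not the Keras implementation.

    Returns new parameters and new optimizer state without mutating
    any of the inputs.
    """
    new_velocities = jax.tree_util.tree_map(
        lambda v, g: momentum * v - lr * g, velocities, grads
    )
    new_params = jax.tree_util.tree_map(
        lambda p, v: p + v, params, new_velocities
    )
    return new_params, new_velocities


params = [jnp.array([1.0, 2.0]), jnp.array(0.5)]
velocities = [jnp.zeros(2), jnp.array(0.0)]
grads = [jnp.array([10.0, -10.0]), jnp.array(1.0)]

new_params, new_velocities = sgd_momentum_apply(velocities, grads, params)
```

The caller then stores `new_params` and `new_velocities` in the state tuple, exactly as `train_step` does with the values returned by the optimizer.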
  29. Custom model with JAX: update metrics

    • train_step
      - Computes gradients
      - Updates trainable variables
      - Updates optimizer variables
      - Updates metrics
        • Iterating over each metric in the model's metrics
        • Each metric's state variables are updated using its stateless_update_state method
        • Result is computed using its stateless_result method.

        def train_step(self, state, data):
            # Update metrics.
            new_metrics_vars = []
            for metric in self.metrics:
                this_metric_vars = metrics_variables[
                    len(new_metrics_vars) : len(new_metrics_vars) + len(metric.variables)
                ]
                if metric.name == "loss":
                    this_metric_vars = metric.stateless_update_state(
                        this_metric_vars, loss
                    )
                else:
                    this_metric_vars = metric.stateless_update_state(
                        this_metric_vars, y, y_pred
                    )
                logs = metric.stateless_result(this_metric_vars)
                new_metrics_vars += this_metric_vars
            # Return metric logs and updated state variables.
            state = (
                trainable_variables,
                non_trainable_variables,
                optimizer_variables,
                new_metrics_vars,
            )
            return logs, state
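The slicing expression in the metrics loop is easier to follow on its own. In this pure-Python sketch (the helper `split_metric_variables` and the variable names are hypothetical), all metric variables live in one flat list, and the number of variables consumed so far is the start offset of the next metric's slice.

```python
def split_metric_variables(flat_vars, sizes):
    """Slice a flat list of metric variables into one chunk per metric.

    Mirrors the slicing on the slide: the count of variables already
    consumed is the start offset for the next metric.
    """
    chunks, consumed = [], 0
    for size in sizes:
        chunks.append(flat_vars[consumed : consumed + size])
        consumed += size
    return chunks


# Hypothetical layout: a loss tracker followed by an accuracy metric,
# each holding two variables (a running total and a count).
flat = ["loss_total", "loss_count", "acc_total", "acc_count"]
per_metric = split_metric_variables(flat, sizes=[2, 2])
```

On the slide, `len(new_metrics_vars)` plays the role of `consumed`, since the updated variables are appended in the same order they are read.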
  30. Custom model with JAX: Test steps

    • test_step
      - Computes predictions and losses
      - Updates metrics
      - Updates state variables
      - Returns metric logs and state variables

        def test_step(self, state, data):
            # Unpack the data.
            x, y = data
            (
                trainable_variables,
                non_trainable_variables,
                metrics_variables,
            ) = state
            # Compute predictions and loss.
            y_pred, non_trainable_variables = self.stateless_call(
                …
                training=False,
            )
            loss = self.compute_loss(x, y, y_pred)
            # Update metrics.
            …
            new_metrics_vars += this_metric_vars
            # Return metric logs and updated state variables.
            state = (…)
            return logs, state
  31. Custom model with JAX: Test steps

    • test_step
      - Computes predictions and losses
      - Updates metrics
      - Updates state variables
      - Returns metric logs and state variables

        def test_step(self, state, data):
            # Unpack the data.
            x, y = data
            (
                trainable_variables,
                non_trainable_variables,
                metrics_variables,
            ) = state
            # Compute predictions and loss.
            y_pred, non_trainable_variables = self.stateless_call(
                trainable_variables,
                non_trainable_variables,
                x,
                training=False,
            )
            loss = self.compute_loss(x, y, y_pred)
            # Update metrics.
            new_metrics_vars = []
            for metric in self.metrics:
                this_metric_vars = metrics_variables[
                    len(new_metrics_vars) : len(new_metrics_vars) + len(metric.variables)
                ]
                if metric.name == "loss":
  32. Custom model with JAX: Compute predictions

    • test_step
      - Computes predictions and losses
        • The stateless_call method of the model is used to compute the predictions and update the non-trainable variables
        • The loss is computed using the compute_loss method of the model
      - Updates metrics
      - Updates state variables
      - Returns metric logs and state variables

        def test_step(self, state, data):
            # Unpack the data.
            …
            # Compute predictions and loss.
            y_pred, non_trainable_variables = self.stateless_call(
                trainable_variables,
                non_trainable_variables,
                x,
                training=False,
            )
            loss = self.compute_loss(x, y, y_pred)
            …
  33. Custom model with JAX: Update metrics states

    • test_step
      - Computes predictions and losses
      - Updates metrics
      - Updates state variables
        • The metrics are updated by iterating over each metric in the model's metrics.
        • Each metric's state variables are updated using its stateless_update_state method
        • Result is computed using its stateless_result method
      - Returns metric logs and state variables

        def test_step(self, state, data):
            …
            # Update metrics.
            new_metrics_vars = []
            for metric in self.metrics:
                this_metric_vars = metrics_variables[
                    len(new_metrics_vars) : len(new_metrics_vars) + len(metric.variables)
                ]
                if metric.name == "loss":
                    this_metric_vars = metric.stateless_update_state(
                        this_metric_vars, loss
                    )
                else:
                    this_metric_vars = metric.stateless_update_state(
                        this_metric_vars, y, y_pred
                    )
                logs = metric.stateless_result(this_metric_vars)
                new_metrics_vars += this_metric_vars
            # Return metric logs and updated state variables.
            state = (
                trainable_variables,
                non_trainable_variables,
                new_metrics_vars,
            )
            return logs, state
  34. Custom model with JAX: Using custom model

    • Simple image classification model
      - Define the CNN
      - Compile the model
      - Train the model
      - Evaluate the resulting model
    • This is so Keras
      - Nothing to explain but…
      - One line is different!

        import keras_core

        # Input definition
        conv2d_kwargs = {
            "kernel_size": (3, 3),
            "activation": "relu",
            "padding": "same",
        }
        inputs = keras_core.Input(shape=(32, 32, 3), name="input_layer")

        # Layer structure definition
        x = inputs
        for filters in [32, 64, 128]:
            x = keras_core.layers.Conv2D(filters=filters, **conv2d_kwargs)(x)
            x = keras_core.layers.BatchNormalization()(x)
            # Strided convolution halves the spatial resolution.
            x = keras_core.layers.Conv2D(filters=filters, strides=(2, 2), **conv2d_kwargs)(x)
            x = keras_core.layers.BatchNormalization()(x)
            x = keras_core.layers.Dropout(0.25)(x)
        x = keras_core.layers.GlobalAveragePooling2D()(x)
        x = keras_core.layers.Dense(128, activation="relu")(x)
        x = keras_core.layers.Dropout(0.25)(x)
        outputs = keras_core.layers.Dense(10, activation="softmax", name="output_layer")(x)
  35. Custom model with JAX: Using custom model

    • JAX CustomModel like pure Keras
      - Brings JAX performance to Keras
      - In a reusable form!
    • Casual pattern
      - Write JAX-based custom model code as separate files
      - Import them when needed

        ...
        model = CustomModel(inputs, outputs, name="image_classification_model")
        ...
  36. Keras Core + JAX: Distributed Training

    • JAX DeviceMesh + JAX Sharding
      - from jax.experimental import mesh_utils
      - from jax.sharding import Mesh
      - from jax.sharding import NamedSharding

        # P is PartitionSpec (this import is not listed on the slide)
        from jax.sharding import PartitionSpec as P

        devices = mesh_utils.create_device_mesh((8,))

        # data will be split along the batch axis
        data_mesh = Mesh(devices, axis_names=("batch",))  # naming axes of the mesh
        # naming axes of the sharded partition
        data_sharding = NamedSharding(data_mesh, P("batch"))

        # all variables will be replicated on all devices
        var_mesh = Mesh(devices, axis_names=("_",))
        # in NamedSharding, axes that are not mentioned are replicated
        var_replication = NamedSharding(var_mesh, P())

        # Split the largest kernel 4-ways (and replicate 2-ways since we have 8 devices)
        large_kernel_mesh = Mesh(
            devices.reshape((-1, 4)), axis_names=(None, "out_chan")
        )  # naming axes of the mesh
        large_kernel_sharding = NamedSharding(
            large_kernel_mesh, P(None, None, None, "out_chan")
        )  # naming axes of the sharded partition
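The slide assumes 8 devices. The sketch below adapts the same data-parallel layout to whatever devices are available (including a single CPU), so it can be run anywhere. The array shapes are made up for illustration; the sharding API calls are the ones named on the slide plus `PartitionSpec` and `jax.device_put`.

```python
import jax
import jax.numpy as jnp
from jax.experimental import mesh_utils
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

# Use however many devices are present (the slide assumes 8).
n_devices = jax.device_count()
devices = mesh_utils.create_device_mesh((n_devices,))
mesh = Mesh(devices, axis_names=("batch",))

# Split the leading (batch) axis across devices; replicate the rest.
data_sharding = NamedSharding(mesh, P("batch"))
replicated = NamedSharding(mesh, P())

# Batch size is chosen to be divisible by the device count.
batch = jnp.ones((4 * n_devices, 32, 32, 3))
kernel = jnp.ones((3, 3, 3, 16))

batch = jax.device_put(batch, data_sharding)
kernel = jax.device_put(kernel, replicated)

# Each device holds a slice of the batch and a full copy of the kernel.
batch_shard_shape = batch.addressable_shards[0].data.shape
kernel_shard_shape = kernel.addressable_shards[0].data.shape
```

With 8 devices the same code would place 4 images on each device while every device keeps the whole kernel, which is the "data split, variables replicated" strategy described in the comments on the slide.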
  37. JAX: Performance gain

    • JAX gives
      - CPU/GPU optimization
      - TPU compatibility / optimization
    • Great example: Whisper JAX
      - Whisper: OpenAI's speech / transcription model
      - Whisper JAX: a port to JAX
    • Results
      - Up to 2x speed-ups
      - Run on TPU: 5x-10x speed gain with JAX + TPU

                   OpenAI    Transformers  Whisper JAX  Whisper JAX
        Framework  PyTorch   PyTorch       JAX          JAX
        Backend    GPU       GPU           GPU          TPU
        1 min      13.8      4.54          1.72         0.45
        10 min     108.3     20.2          9.38         2.01
        1 hour     1001.0    126.1         75.3         13.8

    https://github.com/sanchit-gandhi/whisper-jax
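The speed-up figures can be checked against the timings on this slide. The slide does not state the unit or which column its ratios are measured against, so the sketch below assumes the Transformers (PyTorch, GPU) column as the baseline; the dictionary keys are made-up labels for the four columns.

```python
# Timings copied from the slide (unit not stated there).
timings = {
    "1 min": {"openai_gpu": 13.8, "transformers_gpu": 4.54, "jax_gpu": 1.72, "jax_tpu": 0.45},
    "10 min": {"openai_gpu": 108.3, "transformers_gpu": 20.2, "jax_gpu": 9.38, "jax_tpu": 2.01},
    "1 hour": {"openai_gpu": 1001.0, "transformers_gpu": 126.1, "jax_gpu": 75.3, "jax_tpu": 13.8},
}


def speedup(row, baseline, candidate):
    """How many times faster `candidate` is than `baseline`."""
    return row[baseline] / row[candidate]


for audio_length, row in timings.items():
    gpu_gain = speedup(row, "transformers_gpu", "jax_gpu")
    tpu_gain = speedup(row, "transformers_gpu", "jax_tpu")
    print(f"{audio_length}: JAX GPU {gpu_gain:.1f}x, JAX TPU {tpu_gain:.1f}x")
```

Under that assumption the GPU gain comes out at roughly 1.7x to 2.6x and the TPU gain at roughly 9x to 10x, depending on audio length; against the original OpenAI implementation the ratios are considerably larger.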
  38. Keras Core + JAX: combination

    • Win-win
      - Keras: a good demonstration of multi-backend Keras powering other frameworks
      - JAX: now has a well-documented, robust framework!
  39. Keras Core: Bright side

    • Performance
      - Use JAX without hassle, with ease
    • Mixing ingredients
      - Use a Keras model like a PyTorch module
      - Model as a stateless JAX function
      - Ecosystem coverage: Keras models for both TensorFlow and PyTorch
    • Cross data sourcing
      - Optimized data wrappers / libraries with any backend
      - E.g. NumPy, Pandas, tf.data, PyTorch DataLoader
    • Future-proof
      - Keras will remain even when others have vanished
      - Theano, TensorFlow 1.X, CNTK… remember the past
  40. Keras: Dark side

    • May I just use TensorFlow / PyTorch?
      - The biggest question now: why Keras?
      - Before we realized it, the ecosystem had become regulated...
    • With JAX, you will soon notice that
      - It is not easy anymore
    • Home of super spaghetti code?
      - Keras + other backend ML frameworks
      - Even if Keras provides a future-proof grammar, others will not
      - The pros of Keras can also be cons
  41. References

    • A. R. Gosthipaty and R. Raha. What Is Keras Core? PyImageSearch, P. Chugh, S. Huot, K. Kidriavsteva, and A. Thanki, eds., 2023. https://pyimagesearch.com/2023/07/24/what-is-keras-core/
    • Introduction to Keras Core with Francois Chollet. https://www.youtube.com/watch?v 5fTPEoeFZk
    • Writing a custom training loop in JAX, GitHub/keras/keras-core. https://github.com/keras-team/keras-core/blob/main/guides/writing_a_custom_training_loop_in_jax.py
    • JAX: Can It Beat PyTorch and TensorFlow? https://www.it-jim.com/blog/jax-can-it-beat-pytorch-and-tensorflow/
    • JAX Distributed demo, GitHub/keras/keras-core. https://github.com/keras-team/keras-core/blob/main/examples/demo_jax_distributed.py