Slide 1

Slide 1 text

Keras Core with JAX: Streamlining Deep Learning for Robust Acceleration
Jeongkyu Shin
ML GDE / Cloud Champion Innovator
CEO / Lablup Inc.
Sep. 8, 2023
Head image: https://www.hopefortheflowers.com

Slide 2

Slide 2 text

Hi!
• Lablup Inc.: Make AI Accessible
  - Open source machine learning cluster platform: Backend.AI (https://www.backend.ai)
• Google Developer Expert
  - ML / DL Google Developer Expert
  - Google Cloud Champion Innovator
  - Google for Startups Accelerator Mentor
• Open source
  - Textcube Project Moderator (20th year!)
• Physics / Neuroscience
  - Ph.D. in Statistical Physics
  - Complex Systems / Computational Neuroscience
  - Former Adj. Professor, Hanyang Univ. ERICA, Software Dept.

Slide 3

Slide 3 text

Hi!
• Lablup Inc.: Make AI Accessible
  - Open source machine learning cluster platform: Backend.AI (https://www.backend.ai)
• Google Developer Expert
  - ML / DL GDE (2017)
  - Google Cloud Champion Innovator (2021)
  - Google for Startups Accelerator Mentor
• Open source
  - Textcube Project Moderator (20th year!)
• Physics / Neuroscience
  - Ph.D. in Statistical Physics
  - Complex Systems / Computational Neuroscience
  - Former Adj. Professor, Hanyang Univ. ERICA, Software Dept.

Slide 4

Slide 4 text

Keras Core: Multi-backend Keras again!

Slide 5

Slide 5 text

Keras Core
• You have already heard about it in previous talks
  - So, let's look at the backstory!
• New repository from the Keras team
  - Will become Keras 3.0 this fall
• Multi-backend support
• Is it the same as Keras?
  - Well, you may have missed many parts of its history…

Slide 6

Slide 6 text

Keras: Back then...
• Long, long ago, there was Keras
  - High-level deep learning library
  - Python interface for neural network research
  - Multi-backend support: no backend of its own
    Supports TensorFlow, Theano, CNTK, MXNet…
  - Looks familiar? Welcome to the oldies club!
• And then Keras became part of TensorFlow, starting with TensorFlow 1.4
  - TensorFlow has a Keras implementation under tf.keras
  - And finally, since Keras 2.4 (2020), Keras has been bound to TensorFlow as its only backend, and keras became an alias for tf.keras
  - Is it over? But fchollet mentioned this as a temporary option
https://github.com/keras-team/keras/releases/tag/2.4.0

Slide 7

Slide 7 text

Keras: Recent changes
• Filling the bottle
  - The Keras repository came alive again in 2021
    Copy-and-paste sync from the tf.keras code since TensorFlow 2.5
    And Keras became a package separate from TensorFlow as of TensorFlow 2.6 (2021)
  - What happened? Many ideas and stories… Who knows?
• And after two years
  - Keras has been completely rewritten!
  - Keras Core

Slide 8

Slide 8 text

From Keras to Keras Core: long story short
(Timeline figure, 2015 to 2023)
• 2015: Keras released on top of Theano
• TensorFlow support added
• 2017: tf.keras announced for TensorFlow 1.4
• 2020: Drops multi-backend; Keras merged into TF
• 2021: Keras repo updates restarted; TensorFlow 2.6 separates Keras
• 2023: Keras Core announced

Slide 9

Slide 9 text

Keras Core: From the ground up
• Completely rewritten code
  - Lines of code: 135k to 45k
  - How?! Dropped TensorFlow 1.x compatibility
• Supported backends (2023)
  - TensorFlow
  - JAX
  - PyTorch
  - NumPy (not for training)
Image: Modern Times (1936)

Slide 10

Slide 10 text

Keras Core characteristics: Promising a better future

Slide 11

Slide 11 text

So what will be better with Keras Core?
• Switchable backend
• Multi-framework custom components
• Universal training loop
• Native model support
• Future-proof code

Slide 12

Slide 12 text

Switchable backend
• As easy as before
  - os.environ["KERAS_BACKEND"]
• Notes
  - Import order matters: torch should be imported before tensorflow
    Torch backend: import keras_core, then torch, then tensorflow
  - No sparsity support yet
    Sparse types vary too much across frameworks for now
  - PyTorch compatibility
    Average pooling, integer dtypes

import os
os.environ["KERAS_BACKEND"] = "jax"
import keras_core
import jax
import tensorflow as tf

import os
os.environ["KERAS_BACKEND"] = "torch"
import keras_core
import torch
import tensorflow as tf

Slide 13

Slide 13 text

Cross-framework components with keras.ops
• NumPy API
  - matmul, sum, stack, einsum…
• Network-specific functions
  - softmax, binary_crossentropy, conv…
• Models / layers / losses / metrics / optimizers
  - Work the same with any framework
  - Even outside the Keras workflow!

Slide 14

Slide 14 text

Multi-framework custom components
• Use keras.ops to develop custom components
• With backend-specific custom components
  - TensorFlow, PyTorch, JAX

Slide 16

Slide 16 text

Universal Training Loop: Seamless integration with backends
• Train a Keras model with
  - a low-level JAX training loop: optax optimizer, jax.grad, jax.jit, jax.pmap…
  - a low-level TensorFlow training loop: tf.GradientTape, tf.distribute
  - a low-level PyTorch training loop: torch.optim optimizer, torch.nn.parallel.DistributedDataParallel
• Use a Keras layer or model as part of a torch.nn.Module
  - PyTorch users can leverage Keras models like any other PyTorch module!
https://www.youtube.com/watch?v 5fTPEoeFZk

Slide 17

Slide 17 text

Native model support: Keras with pretrained models
• Supports all current Keras applications
  - KerasCV: YOLOv8 …
  - KerasNLP: BERT, OPT …
• Supports all Keras applications with all backends!

Slide 18

Slide 18 text

Keras Core with JAX: Streamlining Deep Learning for Robust Acceleration

Slide 19

Slide 19 text

JAX: What is it?
• JAX (Google, 2018)
  - Designed for high-performance machine learning research
  - Provides composable transformations of Python and NumPy programs
    Automatic differentiation, vectorization (vmap), parallelization (pmap), etc.
  - Autograd capability
    The grad function in JAX can automatically differentiate Python functions written with JAX
  - Others
    Random number generator, etc.
• Why JAX?
  - Scalable: CPU, GPU and TPU with the XLA compiler
  - XLA?
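A small sketch of the composable transformations listed above, using only core JAX APIs: `jax.grad` differentiates a toy function (chosen here only as an example), `jax.vmap` batches the derivative, `jax.jit` compiles the composition with XLA, and `jax.random` shows the explicit-key random number generator.

```python
import jax
import jax.numpy as jnp


def f(x):
    # Toy scalar function: f(x) = x^3 + 2x, so f'(x) = 3x^2 + 2
    return x**3 + 2.0 * x


df = jax.grad(f)                        # automatic differentiation
batched_df = jax.vmap(df)               # vectorize over a batch of inputs
fast_batched_df = jax.jit(batched_df)   # compile the composed function with XLA

xs = jnp.array([0.0, 1.0, 2.0])
print(df(2.0))
print(fast_batched_df(xs))

# JAX's RNG is explicit: keys are split and passed around instead of a global seed.
key = jax.random.PRNGKey(0)
key, subkey = jax.random.split(key)
noise = jax.random.normal(subkey, (3,))
```

Because each transformation returns an ordinary Python function, they can be stacked in any order; `jax.jit(jax.vmap(jax.grad(f)))` is the composed form of the last line.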

Slide 20

Slide 20 text

JAX: What lies beneath?
• XLA (Accelerated Linear Algebra, Google, 2017)
  - Fast matrix operations for CPU, NVIDIA GPU and TPU
  - Just-In-Time (JIT) compilation, Ahead-Of-Time (AOT) compilation
• MLIR (Multi-Level Intermediate Representation, Google, 2019)
  - https://mlir.llvm.org
  - LLVM-backed intermediate compilation layer for machine learning
• OpenXLA (Google, 2023)
  - Open source version of XLA
    StableHLO: operation set for high-level operations (HLO) in ML models
  - Machine learning compiler for ML accelerator hardware
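To make the JIT/AOT path visible, a minimal sketch using JAX's ahead-of-time API: `jax.jit(...).lower(...)` traces the function into an MLIR module (StableHLO in recent JAX versions; older releases may emit a different dialect), and `.compile()` hands it to XLA for the given shapes and dtypes. The affine function is only an example.

```python
import jax
import jax.numpy as jnp


def affine(w, x, b):
    return jnp.dot(w, x) + b


w = jnp.ones((2, 3))
x = jnp.arange(3.0)
b = jnp.zeros(2)

lowered = jax.jit(affine).lower(w, x, b)  # trace + lower to the IR that XLA consumes
print(lowered.as_text()[:400])           # peek at the module text

compiled = lowered.compile()              # ahead-of-time compile for these shapes/dtypes
print(compiled(w, x, b))
```

Calling `jax.jit(affine)(w, x, b)` directly does the same lowering and compilation implicitly on first use, which is the JIT path.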

Slide 21

Slide 21 text

JAX: Ecosystem
• JAX
  - Low-level API, like PyTorch without torch.nn or TensorFlow without tf.keras
• Flax (FLexible JAX)
  - Layer API from Google (excluding DeepMind)
  - High-level interface for defining and training neural networks based on JAX
• Haiku
  - Another layer API, from DeepMind
  - Inspired by Sonnet for TensorFlow
  - Not updated anymore: Flax is recommended for the JAX ecosystem
• Optax
  - Optimizers and loss functions for JAX
• ETC.

Slide 22

Slide 22 text

JAX: Trouble
• There is no free lunch
  - Steep learning curve
    JAX does not provide high-level neural network features
    Functional programming
  - JIT nature
    Immutable objects lead to low debuggability, like old TensorFlow 1.x
  - Lack of documentation (yet)!
• Does Flax/Haiku solve the issue?
  - Lack of third-party projects / code
  - Independent project: one more deep learning library

Slide 23

Slide 23 text

JAX: Double Trouble
• To unleash the power of JAX, you'd better care about distributed workloads
  - Data sharding
  - Device mesh
  - ETC.
• @jax.jit
  - Decorator for Just-In-Time compilation
  - In fact, there is a JIT option in TensorFlow too
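One source of the "JIT nature" trouble is that a `@jax.jit` function's Python body runs only while JAX traces it, and retracing happens once per new input shape/dtype. A minimal sketch that counts traces with a Python side effect (the counter is purely illustrative):

```python
import jax
import jax.numpy as jnp

trace_count = 0


@jax.jit
def scaled_sum(x):
    global trace_count
    trace_count += 1  # Python side effects run only during tracing, not on cached calls
    return jnp.sum(x * 2.0)


scaled_sum(jnp.ones(4))  # first call: trace + compile
scaled_sum(jnp.ones(4))  # same shape/dtype: reuses the compiled executable
scaled_sum(jnp.ones(8))  # new shape: traced and compiled again
print(trace_count)
```

The same behavior is why a `print` inside a jitted function shows abstract tracers instead of values; `jax.debug.print` exists for printing runtime values.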

Slide 24

Slide 24 text

Example: Casual Keras Code
• Simple image classification model
  - Define the CNN
    Layer definition: input to output
  - Compile the model
  - Train the model
  - Evaluate the resulting model

import jax
import jax.numpy as jnp
import keras_core  # Keras multi-backend

# Input definition
# Layer structure definition
# Model compilation with input / output
# Train
# Evaluate
# (Optional) Model export
# (Optional, usually as an independent project) Inference

Slide 25

Slide 25 text

Example: Casual Keras Code
• Simple image classification model
  - Define the CNN
    Layer definition: input to output
  - Compile the model
  - Train the model
  - Evaluate the resulting model

import keras_core

# Input definition
conv2d_kwargs = {
    "kernel_size": (3, 3),
    "activation": "relu",
    "padding": "same",
}
inputs = keras_core.Input(shape=(32, 32, 3), name="input_layer")

# Layer structure definition
x = inputs
for filters in [32, 64, 128]:
    x = keras_core.layers.Conv2D(filters=filters, **conv2d_kwargs)(x)
    x = keras_core.layers.BatchNormalization()(x)
    # Stride-2 convolution halves the spatial resolution
    x = keras_core.layers.Conv2D(filters=filters, strides=(2, 2), **conv2d_kwargs)(x)
    x = keras_core.layers.BatchNormalization()(x)
    x = keras_core.layers.Dropout(0.25)(x)
x = keras_core.layers.GlobalAveragePooling2D()(x)
x = keras_core.layers.Dense(128, activation="relu")(x)
x = keras_core.layers.Dropout(0.25)(x)
outputs = keras_core.layers.Dense(10, activation="softmax", name="output_layer")(x)

Slide 26

Slide 26 text

Example: Custom model with JAX
• JAX
  - Stateless: JAX provides tools, not a complete framework
    States are passed explicitly as parameters
• State
  - Trainable parameters
  - Non-trainable parameters
  - Optimizer state
• Keras Core + JAX
  - We need to follow JAX's stateless style
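Stripped of Keras, the stateless style looks like this in plain JAX. A minimal sketch with a hypothetical linear-regression problem and hand-written SGD with momentum: the parameters and the optimizer state go into the step function as arguments and come back out as new values, mirroring the trainable / optimizer state split listed above.

```python
import jax
import jax.numpy as jnp

LEARNING_RATE = 0.1  # hypothetical hyperparameters
MOMENTUM = 0.9


def loss_fn(params, x, y):
    pred = x @ params["w"] + params["b"]
    return jnp.mean((pred - y) ** 2)


@jax.jit
def train_step(params, opt_state, x, y):
    # Nothing is mutated: the step receives the current state and returns a new one.
    loss, grads = jax.value_and_grad(loss_fn)(params, x, y)
    opt_state = jax.tree_util.tree_map(
        lambda m, g: MOMENTUM * m + g, opt_state, grads
    )
    params = jax.tree_util.tree_map(
        lambda p, m: p - LEARNING_RATE * m, params, opt_state
    )
    return params, opt_state, loss


key = jax.random.PRNGKey(0)
x = jax.random.normal(key, (64, 2))
y = x @ jnp.array([2.0, -1.0]) + 0.5

params = {"w": jnp.zeros(2), "b": jnp.array(0.0)}         # trainable parameters
opt_state = jax.tree_util.tree_map(jnp.zeros_like, params)  # optimizer (momentum) state

for _ in range(200):
    params, opt_state, loss = train_step(params, opt_state, x, y)
```

Keras Core's JAX backend applies the same pattern: `stateless_call` and `stateless_apply` take variable values in and hand updated values back.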

Slide 27

Slide 27 text

Custom model with JAX: Requirements
• Required methods for a custom model
  - compute_loss_and_updates
  - train_step
  - test_step
• compute_loss_and_updates
  - Calculates the loss and the updated non-trainable variables

class CustomModel(keras_core.Model):
    def compute_loss_and_updates(
        self,
        trainable_variables,
        non_trainable_variables,
        x,
        y,
        training=False,
    ):
        # Run the model as a pure function of the given variable values
        y_pred, non_trainable_variables = self.stateless_call(
            trainable_variables,
            non_trainable_variables,
            x,
            training=training,
        )
        loss = self.compute_loss(x, y, y_pred)
        return loss, (y_pred, non_trainable_variables)

Slide 28

Slide 28 text

Custom model with JAX: Train steps
• train_step
  - Computes gradients
  - Updates trainable variables
  - Updates optimizer variables
  - Updates metrics

def train_step(self, state, data):
    # Unpack the current state
    (
        trainable_variables,
        non_trainable_variables,
        optimizer_variables,
        metrics_variables,
    ) = state
    grad_fn = jax.value_and_grad(
        self.compute_loss_and_updates, has_aux=True
    )
    # Compute the gradients.
    (loss, (y_pred, non_trainable_variables)), grads = grad_fn(…)
    # Update trainable variables and optimizer variables.
    (…) = self.optimizer.stateless_apply(…)
    # Update metrics.
    state = (…)
    return logs, state

Slide 29

Slide 29 text

Custom model with JAX: Train steps
• train_step
  - Computes gradients
  - Updates trainable variables
  - Updates optimizer variables
  - Updates metrics

def train_step(self, state, data):
    # Unpack the current state
    (
        trainable_variables,
        non_trainable_variables,
        optimizer_variables,
        metrics_variables,
    ) = state
    # Unpack the data
    x, y = data
    # Get the gradient function.
    grad_fn = jax.value_and_grad(
        self.compute_loss_and_updates, has_aux=True
    )
    # Compute the gradients.
    (loss, (y_pred, non_trainable_variables)), grads = grad_fn(
        trainable_variables,
        non_trainable_variables,
        x,
        y,
        training=True,
    )
    # Update trainable variables and optimizer variables.
    (

Slide 30

Slide 30 text

Custom model with JAX: Compute gradients
• train_step
  - Computes gradients
    Get the gradient function from JAX
    • The gradients are computed by calling grad_fn with the current trainable and non-trainable variables and the data.
    • It also returns the loss and the updated non-trainable variables.
  - Updates trainable variables
  - Updates optimizer variables
  - Updates metrics

def train_step(self, state, data):
    # Unpack the current state
    …
    # Unpack the data
    x, y = data
    # Get the gradient function.
    grad_fn = jax.value_and_grad(
        self.compute_loss_and_updates, has_aux=True
    )
    # Compute the gradients.
    (loss, (y_pred, non_trainable_variables)), grads = grad_fn(
        trainable_variables,
        non_trainable_variables,
        x,
        y,
        training=True,
    )
    …

Slide 31

Slide 31 text

Custom model with JAX: Update variables
• train_step
  - Computes gradients
  - Updates trainable variables + optimizer variables
    • The optimizer's stateless_apply method updates the trainable and optimizer variables using the computed gradients
  - Updates metrics

def train_step(self, state, data):
    # Unpack the current state
    …
    # Unpack the data
    …
    # Get the gradient function.
    …
    # Compute the gradients.
    …
    # Update trainable variables and optimizer variables.
    (
        trainable_variables,
        optimizer_variables,
    ) = self.optimizer.stateless_apply(
        optimizer_variables, grads, trainable_variables
    )
    # Update metrics.
    …

Slide 32

Slide 32 text

Custom model with JAX: Update metrics
• train_step
  - Computes gradients
  - Updates trainable variables
  - Updates optimizer variables
  - Updates metrics
    • Iterates over each metric in the model's metrics
    • Each metric's state variables are updated with its stateless_update_state method
    • The result is computed with its stateless_result method

def train_step(self, state, data):
    …
    # Update metrics.
    new_metrics_vars = []
    for metric in self.metrics:
        # Slice this metric's variables out of the flat list
        this_metric_vars = metrics_variables[
            len(new_metrics_vars) : len(new_metrics_vars) + len(metric.variables)
        ]
        if metric.name == "loss":
            this_metric_vars = metric.stateless_update_state(
                this_metric_vars, loss
            )
        else:
            this_metric_vars = metric.stateless_update_state(
                this_metric_vars, y, y_pred
            )
        logs = metric.stateless_result(this_metric_vars)
        new_metrics_vars += this_metric_vars

    # Return metric logs and updated state variables.
    state = (
        trainable_variables,
        non_trainable_variables,
        optimizer_variables,
        new_metrics_vars,
    )
    return logs, state

Slide 33

Slide 33 text

Custom model with JAX: Test steps
• test_step
  - Computes predictions + losses
  - Updates metrics
  - Updates state variables
  - Returns metric logs and state variables

def test_step(self, state, data):
    # Unpack the data.
    x, y = data
    (
        trainable_variables,
        non_trainable_variables,
        metrics_variables,
    ) = state
    # Compute predictions and loss.
    y_pred, non_trainable_variables = self.stateless_call(
        …
        training=False,
    )
    loss = self.compute_loss(x, y, y_pred)
    # Update metrics.
    …
        new_metrics_vars += this_metric_vars
    # Return metric logs and updated state variables.
    state = (…)
    return logs, state

Slide 34

Slide 34 text

Custom model with JAX: Test steps
• test_step
  - Computes predictions + losses
  - Updates metrics
  - Updates state variables
  - Returns metric logs and state variables

def test_step(self, state, data):
    # Unpack the data.
    x, y = data
    (
        trainable_variables,
        non_trainable_variables,
        metrics_variables,
    ) = state
    # Compute predictions and loss.
    y_pred, non_trainable_variables = self.stateless_call(
        trainable_variables,
        non_trainable_variables,
        x,
        training=False,
    )
    loss = self.compute_loss(x, y, y_pred)
    # Update metrics.
    new_metrics_vars = []
    for metric in self.metrics:
        this_metric_vars = metrics_variables[
            len(new_metrics_vars) : len(new_metrics_vars) + len(metric.variables)
        ]
        if metric.name == "loss":

Slide 35

Slide 35 text

Custom model with JAX: Compute predictions
• test_step
  - Computes predictions + losses
    • The model's stateless_call method computes the predictions and updates the non-trainable variables
    • The loss is computed with the model's compute_loss method
  - Updates metrics
  - Updates state variables
  - Returns metric logs and state variables

def test_step(self, state, data):
    # Unpack the data.
    …
    # Compute predictions and loss.
    y_pred, non_trainable_variables = self.stateless_call(
        trainable_variables,
        non_trainable_variables,
        x,
        training=False,
    )
    loss = self.compute_loss(x, y, y_pred)
    …

Slide 36

Slide 36 text

Custom model with JAX: Update metric states
• test_step
  - Computes predictions + losses
  - Updates metrics
  - Updates state variables
    • The metrics are updated by iterating over each metric in the model's metrics
    • Each metric's state variables are updated with its stateless_update_state method
    • The result is computed with its stateless_result method
  - Returns metric logs and state variables

def test_step(self, state, data):
    …
    # Update metrics.
    new_metrics_vars = []
    for metric in self.metrics:
        this_metric_vars = metrics_variables[
            len(new_metrics_vars) : len(new_metrics_vars) + len(metric.variables)
        ]
        if metric.name == "loss":
            this_metric_vars = metric.stateless_update_state(
                this_metric_vars, loss
            )
        else:
            this_metric_vars = metric.stateless_update_state(
                this_metric_vars, y, y_pred
            )
        logs = metric.stateless_result(this_metric_vars)
        new_metrics_vars += this_metric_vars

    # Return metric logs and updated state variables.
    state = (
        trainable_variables,
        non_trainable_variables,
        new_metrics_vars,
    )
    return logs, state

Slide 37

Slide 37 text

Custom model with JAX: Using custom model
• Simple image classification model
  - Define the CNN
  - Compile the model
  - Train the model
  - Evaluate the resulting model
• This is so Keras
  - Nothing to explain, but…
  - One line is different!

import keras_core

# Input definition
conv2d_kwargs = {
    "kernel_size": (3, 3),
    "activation": "relu",
    "padding": "same",
}
inputs = keras_core.Input(shape=(32, 32, 3), name="input_layer")

# Layer structure definition
x = inputs
for filters in [32, 64, 128]:
    x = keras_core.layers.Conv2D(filters=filters, **conv2d_kwargs)(x)
    x = keras_core.layers.BatchNormalization()(x)
    x = keras_core.layers.Conv2D(filters=filters, strides=(2, 2), **conv2d_kwargs)(x)
    x = keras_core.layers.BatchNormalization()(x)
    x = keras_core.layers.Dropout(0.25)(x)
x = keras_core.layers.GlobalAveragePooling2D()(x)
x = keras_core.layers.Dense(128, activation="relu")(x)
x = keras_core.layers.Dropout(0.25)(x)
outputs = keras_core.layers.Dense(10, activation="softmax", name="output_layer")(x)

Slide 38

Slide 38 text

Custom model with JAX: Using custom model
• JAX CustomModel, just like pure Keras
  - Brings JAX performance to Keras
  - In reusable form!
• Casual pattern
  - Write JAX-based custom model code in separate files
  - Import them when needed

...
model = CustomModel(inputs, outputs, name="image_classification_model")
...

Slide 39

Slide 39 text

Keras Core + JAX: Distributed Training
• JAX DeviceMesh + JAX Sharding

from jax.experimental import mesh_utils
from jax.sharding import Mesh
from jax.sharding import NamedSharding
from jax.sharding import PartitionSpec as P

devices = mesh_utils.create_device_mesh((8,))

# data will be split along the batch axis
data_mesh = Mesh(devices, axis_names=("batch",))  # naming axes of the mesh
data_sharding = NamedSharding(
    data_mesh,
    P(
        "batch",
    ),
)  # naming axes of the sharded partition

# all variables will be replicated on all devices
var_mesh = Mesh(devices, axis_names=("_",))
# in NamedSharding, axes that are not mentioned are replicated
var_replication = NamedSharding(var_mesh, P())

# Split the largest kernel 4-ways (and replicate 2-ways since we have 8 devices)
large_kernel_mesh = Mesh(
    devices.reshape((-1, 4)), axis_names=(None, "out_chan")
)  # naming axes of the mesh
large_kernel_sharding = NamedSharding(
    large_kernel_mesh, P(None, None, None, "out_chan")
)  # naming axes of the sharded partition
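The mesh and sharding objects can be tried without any accelerators. A minimal sketch, assuming JAX's `--xla_force_host_platform_device_count` flag is used to emulate 8 CPU devices (it only takes effect if set before JAX is imported): a batch is split along the "batch" axis, a weight matrix is replicated, and a jitted matmul runs on the sharded data.

```python
import os

# Assumption: emulate 8 CPU devices; must happen before `import jax`.
os.environ["XLA_FLAGS"] = (
    os.environ.get("XLA_FLAGS", "") + " --xla_force_host_platform_device_count=8"
)

import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

devices = np.array(jax.devices("cpu")[:8])
data_mesh = Mesh(devices, axis_names=("batch",))

batch = jnp.arange(32.0).reshape(16, 2)
# Rows are split across devices along the "batch" mesh axis.
sharded = jax.device_put(batch, NamedSharding(data_mesh, P("batch")))
# An empty PartitionSpec replicates the weights on every device.
weights = jax.device_put(jnp.ones((2, 3)), NamedSharding(data_mesh, P()))


@jax.jit
def forward(x, w):
    return x @ w


out = forward(sharded, weights)
shard_shapes = [s.data.shape for s in sharded.addressable_shards]
print(len(devices), shard_shapes)
```

With 8 emulated devices, each device holds a (2, 2) slice of the (16, 2) batch, and XLA inserts any needed communication when compiling `forward`.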

Slide 40

Slide 40 text

JAX: Performance gain
• JAX gives
  - CPU/GPU optimization
  - TPU compatibility / optimization
• Great example: Whisper JAX
  - Whisper: OpenAI's speech recognition / transcription model
  - Whisper JAX: a port of Whisper to JAX
• Results
  - Up to 2x speedups
  - Running on TPU: 5x to 10x speed gain with JAX + TPU

Average inference time (seconds) by audio length:
                OpenAI    Transformers   Whisper JAX   Whisper JAX
Framework       PyTorch   PyTorch        JAX           JAX
Backend         GPU       GPU            GPU           TPU
1 min           13.8      4.54           1.72          0.45
10 min          108.3     20.2           9.38          2.01
1 hour          1001.0    126.1          75.3          13.8

https://github.com/sanchit-gandhi/whisper-jax
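The speedup figures can be read directly off the Whisper JAX timings on this slide. A small sketch that computes them, taking the Transformers (PyTorch, GPU) column as the baseline; that choice of baseline is an assumption, and comparing against the original OpenAI implementation gives much larger ratios.

```python
# Average inference times in seconds, copied from the Whisper JAX comparison.
times = {
    "1 min": {"openai_gpu": 13.8, "transformers_gpu": 4.54, "jax_gpu": 1.72, "jax_tpu": 0.45},
    "10 min": {"openai_gpu": 108.3, "transformers_gpu": 20.2, "jax_gpu": 9.38, "jax_tpu": 2.01},
    "1 hour": {"openai_gpu": 1001.0, "transformers_gpu": 126.1, "jax_gpu": 75.3, "jax_tpu": 13.8},
}

gains = {}
for length, t in times.items():
    # Speedup = baseline time / JAX time (higher is better).
    gpu_gain = t["transformers_gpu"] / t["jax_gpu"]
    tpu_gain = t["transformers_gpu"] / t["jax_tpu"]
    gains[length] = (gpu_gain, tpu_gain)
    print(f"{length}: JAX on GPU {gpu_gain:.1f}x, JAX on TPU {tpu_gain:.1f}x")
```

On these numbers, JAX on the same GPU is roughly 1.7x to 2.6x faster than the Transformers baseline, and JAX on TPU roughly 9x to 10x.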

Slide 41

Slide 41 text

Keras Core + JAX: combination
• Win-win
  - Keras: a good demonstration of multi-backend Keras powering other frameworks
  - JAX: now has a well-documented, robust framework!

Slide 42

Slide 42 text

Keras Core: considerations

Slide 43

Slide 43 text

Keras Core: Bright side
• Performance
  - Use JAX without hassle, with ease
• Mixing ingredients
  - Use a Keras model like a PyTorch module
  - Model as a stateless JAX function
  - Ecosystem coverage: Keras models for both TensorFlow and PyTorch
• Cross-data sourcing
  - Optimized data wrappers / libraries with any backend
  - E.g. NumPy, Pandas, tf.data, PyTorch DataLoader
• Future-proof
  - Keras will remain even when others have vanished
  - Theano, TensorFlow 1.x, CNTK… remember the past

Slide 44

Slide 44 text

Keras: Dark side
• May I just use TensorFlow / PyTorch?
  - The biggest question now: why Keras?
  - Before we realized it, the ecosystem had become regulated...
• With JAX, you will soon notice that
  - It is not easy anymore
• Home of super spaghetti code?
  - Keras + other-backend ML frameworks
  - Even if Keras provides future-proof syntax, the others will not
  - The pros of Keras can also be its cons

Slide 45

Slide 45 text

End!
Contact: contact@lablup.com
Facebook: https://www.facebook.com/lablupInc
Lablup Inc.: https://www.lablup.com
Backend.AI: https://www.backend.ai
Backend.AI GitHub: https://github.com/lablup/backend.ai
Backend.AI Cloud: https://cloud.backend.ai

Slide 46

Slide 46 text

References
• A. R. Gosthipaty and R. Raha. "What Is Keras Core?" PyImageSearch, P. Chugh, S. Huot, K. Kidriavsteva, and A. Thanki, eds., 2023. https://pyimagesearch.com/2023/07/24/what-is-keras-core/
• Introduction to Keras Core with Francois Chollet. https://www.youtube.com/watch?v 5fTPEoeFZk
• Writing a custom training loop in JAX, GitHub/keras-team/keras-core. https://github.com/keras-team/keras-core/blob/main/guides/writing_a_custom_training_loop_in_jax.py
• JAX: Can It Beat PyTorch and TensorFlow? https://www.it-jim.com/blog/jax-can-it-beat-pytorch-and-tensorflow/
• JAX Distributed demo, GitHub/keras-team/keras-core. https://github.com/keras-team/keras-core/blob/main/examples/demo_jax_distributed.py