Slide 1

Slide 1 text

Keras 3
A multi-framework API for deep learning
Wesley Kambale, ML Engineer
kambale.dev

Slide 2

Slide 2 text

Contents
01 Multi-backend Keras
02 Unified Distribution API
03 Applied AI with KerasCV & KerasNLP

Slide 3

Slide 3 text

01 Multi-backend Keras

Slide 4

Slide 4 text

What is Keras?
A high-level deep learning library used for building neural networks. An open-source and industry standard:
● Large community presence of around 2.5 million developers
● 61% adoption rate among ML developers and data scientists in 2022, according to Kaggle [source]
● Used by startups and large companies alike, such as Netflix, Uber, Yelp, Instacart, etc.
For the past 4 years, Keras has exclusively supported the TensorFlow backend.
From Keras 3.0 onwards, Keras will be multi-backend again!

Slide 5

Slide 5 text

Multi-backend Keras is back
Full rewrite of Keras
● Now only 45k LOC instead of 135k
Support for TensorFlow, JAX, PyTorch, and NumPy backends
● NumPy backend is inference-only
Drop-in replacement for tf.keras when using the TensorFlow backend
● Minimal changes needed
XLA compilation by default in TF and JAX (soon PyTorch?)

Slide 6

Slide 6 text

Configuring your backend
● Command line
● Colab
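As a minimal sketch of both options: Keras reads the `KERAS_BACKEND` environment variable once, at import time, so it must be set before `import keras` runs.

```python
# Pick the backend BEFORE importing Keras. Valid values are
# "tensorflow", "jax", "torch", and "numpy".
import os

os.environ["KERAS_BACKEND"] = "jax"

# Equivalents:
#   command line:  export KERAS_BACKEND="jax"
#   Colab cell:    set os.environ["KERAS_BACKEND"] before `import keras`
```

A subsequent `import keras` then runs on the selected backend; changing the variable afterwards has no effect for that process.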

Slide 7

Slide 7 text

No content

Slide 8

Slide 8 text

Develop cross-framework components with keras.ops
Includes the NumPy API – same functions, same arguments
● ops.matmul, ops.sum, ops.stack, ops.einsum, etc.
Plus neural network-specific functions absent from NumPy
● ops.softmax, ops.binary_crossentropy, ops.conv, etc.
Models / layers / losses / metrics / optimizers written with Keras APIs work the same with any framework
● They can even be used outside of Keras workflows!
ops.matmul dispatches to tf.matmul, jax.numpy.matmul, torch.matmul, or np.matmul depending on the backend

Slide 9

Slide 9 text

Develop custom components that work with any framework using keras.ops (which includes the NumPy API) …

Slide 10

Slide 10 text

… or use your framework of choice for backend-specific components

Slide 11

Slide 11 text

Seamless integration with backend-native workflows
Write a low-level JAX training loop to train a Keras model
● e.g. optax optimizer, jax.grad, jax.jit, jax.pmap...
Write a low-level TensorFlow training loop to train a Keras model
● e.g. tf.GradientTape & tf.distribute
Write a low-level PyTorch training loop to train a Keras model
● e.g. torch.optim optimizer, torch loss functions
Use a Keras layer or model as part of a torch.nn.Module
● PyTorch users can start leveraging Keras models whether or not they use Keras APIs! You can treat a Keras model just like any other PyTorch Module.
Etc.

Slide 12

Slide 12 text

Customizing model.fit(): PyTorch, TensorFlow

Slide 13

Slide 13 text

Writing a custom training loop for a Keras model

Slide 14

Slide 14 text

Progressive disclosure of complexity
Start simple, then gradually gain arbitrary flexibility by "opening up the box"
Model building: Sequential (Beginner) → Functional (Intermediate) → custom layers (Advanced) → Subclassed Model (Expert)
Model training: model.fit() (Beginner) → Callbacks (Intermediate) → custom train_step (Advanced) → custom training loop (Expert)

Slide 15

Slide 15 text

You can choose how to input data!
Keras 3.0 has full integration with useful existing data utilities across backends:
● tf.data.Dataset objects
● PyTorch DataLoader objects
● NumPy arrays
● Pandas dataframes
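As a sketch (TF backend assumed, placeholder data): the same `fit()` call accepts NumPy arrays and a `tf.data` pipeline interchangeably.

```python
import os
os.environ["KERAS_BACKEND"] = "tensorflow"

import numpy as np
import tensorflow as tf
import keras

model = keras.Sequential([keras.Input((8,)), keras.layers.Dense(1)])
model.compile(optimizer="adam", loss="mse")

x = np.zeros((32, 8), "float32")
y = np.zeros((32, 1), "float32")

# NumPy arrays...
model.fit(x, y, epochs=1, verbose=0)

# ...or a tf.data pipeline, with no change to the model.
ds = tf.data.Dataset.from_tensor_slices((x, y)).batch(8)
model.fit(ds, epochs=1, verbose=0)
```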

Slide 16

Slide 16 text

Why Keras?
Maximize performance
● Pick the backend that's the fastest for your particular model
Maximize available ecosystem surface
● Export your model to TF SavedModel (TFLite, TF.js, TF Serving, TF-MOT, etc.)
● Instantiate your model as a PyTorch Module and use it with the PyTorch ecosystem
● Call your model as a stateless JAX function and use it with JAX transforms
● Keras models are usable by anyone with no framework lock-in
Maximize data source availability
● Use tf.data, PyTorch DataLoader, NumPy, Pandas, etc. – with any backend

Slide 17

Slide 17 text

Keras = future-proof stability
If you were a Theano user in 2016, you had to migrate to TF 1…
… but if you were a Keras user on top of Theano, you got TF 1 nearly for free
If you were a TF 1 user in 2019, you had to migrate to TF 2…
… but if you were a Keras user on top of TF 1, you got TF 2 nearly for free
If you are using Keras on top of TF 2 in 2023…
… you get JAX and PyTorch support nearly for free
And so on going forward
Frameworks are transient, Keras is your rock.

Slide 18

Slide 18 text

02 Distribution API

Slide 19

Slide 19 text

Distribution API
● Data Parallel
● Tensor Parallel
● Model Parallel
● Pipeline Parallel
● More and more
Express computation without changing the math part of the model
(Diagram: data parallelism vs. model parallelism across four machines)

Slide 20

Slide 20 text

Distribution API
01 Data Parallel
Model weights are replicated across all devices in the DeviceMesh, and each device processes a portion of the input data.
02 Model Parallel and Layout Map
Split model weights or activation tensors across all the devices on the DeviceMesh, enabling horizontal scaling for large models.

Slide 21

Slide 21 text

Distribution API - Data Parallel

Slide 22

Slide 22 text

Distribution API - Model Parallel

Slide 23

Slide 23 text

Distribution API - Model Parallel
Specify sharding / replication for your Keras model

Slide 24

Slide 24 text

Distribution API - Data and Model Parallel

Slide 25

Slide 25 text

Distribution API
Lightweight data parallel / model parallel distribution API built on top of:
● jax.sharding (READY)
● PyTorch/XLA sharding (COMING SOON)
● TensorFlow DTensor (COMING SOON)
All the heavy lifting is already done in XLA GSPMD!

Slide 26

Slide 26 text

You can choose how to distribute!
Keras 3.0 has complete optionality for distribution across backends:
● A custom training loop
● Backend-specific distribution APIs directly
● The Keras multi-backend distribution API

Slide 27

Slide 27 text

03 Applied AI - KerasCV & KerasNLP

Slide 28

Slide 28 text

Pre-trained models
Keras 3.0 includes all Keras Applications (popular image classifiers)
KerasCV and KerasNLP work out of the box with Keras 3 across all backends as of the latest releases:
● YOLOv8
● Whisper
● BERT
● OPT
● Stable Diffusion
● etc.
(Image: Stable Diffusion output for "a photograph of an astronaut riding a horse")

Slide 29

Slide 29 text

Hello World with KerasNLP

Slide 30

Slide 30 text

Hello World with KerasCV

Slide 31

Slide 31 text

Robert John, GDE ML and Google Cloud
"Machine learning is the future"

Slide 32

Slide 32 text

Keras 3 examples: https://keras.io/examples/

Slide 33

Slide 33 text

Thank you! Questions? Wesley Kambale ML Engineer kambale.dev