
Keras 3: A multi-framework API for Deep Learning

Given at IWD Mbarara 2024, the session introduces Keras v3 to intermediate and advanced AI/ML engineers and enthusiasts.

Wesley Kambale

March 09, 2024


Transcript

  1. Contents
     01 Multi-backend Keras
     02 Unified Distribution API
     03 Applied AI with KerasCV and KerasNLP
  2. What is Keras?
     A high-level deep learning library used for building neural networks. An open source, industry-standard library.
     • Large community of around 2.5 million developers
     • 61% adoption rate among ML developers and data scientists in 2022, according to Kaggle [source]
     • Used by startups and large companies alike, such as Netflix, Uber, Yelp, Instacart, etc.
     For the past 4 years, Keras has exclusively supported the TensorFlow backend. From Keras 3.0 onwards, Keras is multi-backend again!
  3. Multi-backend Keras is back
     Full rewrite of Keras
     • Now only 45k LOC instead of 135k
     Support for TensorFlow, JAX, PyTorch, and NumPy backends
     • NumPy backend is inference-only
     Drop-in replacement for tf.keras when using the TensorFlow backend
     • Minimal changes needed (see the backend-selection sketch below)
     XLA compilation by default in TF and JAX (soon PyTorch?)
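A minimal sketch of how the backend is typically picked: setting the KERAS_BACKEND environment variable before the first keras import (the model below is an illustrative placeholder).

```python
import os

# Choose "tensorflow", "jax", "torch", or "numpy" (inference-only)
# before importing Keras.
os.environ["KERAS_BACKEND"] = "jax"

import keras

# The same model code runs unchanged on any of the backends above.
model = keras.Sequential([
    keras.layers.Dense(64, activation="relu"),
    keras.layers.Dense(10, activation="softmax"),
])
model.compile(optimizer="adam", loss="sparse_categorical_crossentropy")
```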
  4. Develop cross-framework components with keras.ops
     Includes the NumPy API – same functions, same arguments
     • ops.matmul, ops.sum, ops.stack, ops.einsum, etc.
     Plus neural network-specific functions absent from NumPy
     • ops.softmax, ops.binary_crossentropy, ops.conv, etc.
     Models / layers / losses / metrics / optimizers written with Keras APIs work the same with any framework
     • They can even be used outside of Keras workflows!
     ops.matmul dispatches to tf.matmul, jax.numpy.matmul, torch.matmul, or np.matmul depending on the backend
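A minimal sketch of such a cross-framework component: a custom layer written only against keras.ops, so the same class runs on the TensorFlow, JAX, or PyTorch backend (the layer name and shapes are illustrative).

```python
import keras
from keras import ops


class ScaledDotAttention(keras.layers.Layer):
    """Scaled dot-product attention using only keras.ops calls."""

    def call(self, query, key, value):
        dim = query.shape[-1]  # static feature size
        # NumPy-style ops: matmul, transpose
        scores = ops.matmul(query, ops.transpose(key, axes=(0, 2, 1)))
        # Neural-network op absent from NumPy: softmax
        weights = ops.softmax(scores / dim ** 0.5, axis=-1)
        return ops.matmul(weights, value)
```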
  5. Seamless integration with backend-native workflows
     Write a low-level JAX training loop to train a Keras model
     • e.g. optax optimizer, jax.grad, jax.jit, jax.pmap...
     Write a low-level TensorFlow training loop to train a Keras model
     • e.g. tf.GradientTape & tf.distribute
     Write a low-level PyTorch training loop to train a Keras model (a sketch follows below)
     • e.g. torch.optim optimizer, torch loss functions
     Use a Keras layer or model as part of a torch.nn.Module
     • PyTorch users can start leveraging Keras models whether or not they use Keras APIs! You can treat a Keras model just like any other PyTorch Module.
     Etc.
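A hedged sketch of the PyTorch case, assuming KERAS_BACKEND="torch": the Keras model behaves like a torch.nn.Module, so it plugs into an ordinary torch training loop (data and shapes are placeholders).

```python
import torch
import keras  # assumes KERAS_BACKEND="torch"

model = keras.Sequential([keras.layers.Dense(1)])
model.build((None, 8))  # create weights so model.parameters() is populated

optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
loss_fn = torch.nn.MSELoss()

x = torch.randn(32, 8)
y = torch.randn(32, 1)
for step in range(100):
    optimizer.zero_grad()
    loss = loss_fn(model(x), y)  # forward pass through the Keras model
    loss.backward()              # autograd flows through Keras weights
    optimizer.step()
```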
  6. Progressive disclosure of complexity
     Start simple, then gradually gain arbitrary flexibility by "opening up the box"
     • Beginner: Sequential (model building), model.fit() (model training)
     • Intermediate: Functional API, Callbacks
     • Advanced: custom layers, custom train_step
     • Expert: subclassed Model, custom training loop
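A minimal sketch of the simplest end of that spectrum: a Sequential model trained with model.fit() and a built-in callback (the data here is random placeholder data).

```python
import numpy as np
import keras

# Beginner/intermediate workflow: Sequential + compile + fit + callback.
model = keras.Sequential([
    keras.layers.Dense(32, activation="relu"),
    keras.layers.Dense(1),
])
model.compile(optimizer="adam", loss="mse")

x, y = np.random.rand(256, 16), np.random.rand(256, 1)
model.fit(
    x, y, epochs=3,
    callbacks=[keras.callbacks.EarlyStopping(patience=1)],
)
```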
  7. You can choose how to input data!
     Keras 3.0 has full integration with useful existing data utilities across backends!
     • tf.data.Dataset objects
     • PyTorch DataLoader objects
     • NumPy arrays
     • Pandas dataframes
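A hedged sketch of the same model.fit() call fed from each of those sources; it assumes TensorFlow and PyTorch are both installed alongside Keras, and the data is synthetic.

```python
import numpy as np
import pandas as pd
import tensorflow as tf
import torch
import keras

model = keras.Sequential([keras.layers.Dense(1)])
model.compile(optimizer="adam", loss="mse")

x = np.random.rand(128, 4).astype("float32")
y = np.random.rand(128, 1).astype("float32")

model.fit(x, y)                                     # NumPy arrays
model.fit(pd.DataFrame(x), pd.DataFrame(y))         # Pandas dataframes
model.fit(tf.data.Dataset.from_tensor_slices((x, y)).batch(32))  # tf.data
loader = torch.utils.data.DataLoader(
    torch.utils.data.TensorDataset(torch.tensor(x), torch.tensor(y)),
    batch_size=32,
)
model.fit(loader)                                   # PyTorch DataLoader
```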
  8. Why Keras?
     Maximize performance
     • Pick the backend that's the fastest for your particular model
     Maximize available ecosystem surface
     • Export your model to TF SavedModel (TFLite, TF.js, TF Serving, TF-MOT, etc.) – see the sketch below
     • Instantiate your model as a PyTorch Module and use it with the PyTorch ecosystem
     • Call your model as a stateless JAX function and use it with JAX transforms
     • Keras models are usable by anyone with no framework lock-in
     Maximize data source availability
     • Use tf.data, PyTorch DataLoader, NumPy, Pandas, etc. – with any backend
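A minimal sketch of the SavedModel export path, assuming TensorFlow is installed (the model and output path are placeholders).

```python
import keras

model = keras.Sequential([
    keras.Input(shape=(32,)),
    keras.layers.Dense(10),
])

# Export a TF SavedModel artifact usable by the TF serving/conversion
# ecosystem (TF Serving, TFLite converter, TF.js, ...).
model.export("exported_model/")
```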
  9. Keras = future-proof stability
     If you were a Theano user in 2016, you had to migrate to TF 1…
     … but if you were a Keras user on top of Theano, you got TF 1 nearly for free.
     If you were a TF 1 user in 2019, you had to migrate to TF 2…
     … but if you were a Keras user on top of TF 1, you got TF 2 nearly for free.
     If you are using Keras on top of TF 2 in 2023…
     … you get JAX and PyTorch support nearly for free.
     And so on going forward. Frameworks are transient, Keras is your rock.
  10. Distribution API
     Express computation without changing the math part of the model
     [Diagram: weights and data laid out across Machines 1–4, contrasting data parallelism with model parallelism]
     • Data Parallel
     • Tensor Parallel
     • Model Parallel
     • Pipeline Parallel
     • More and more
  11. Distribution API
     01 Data Parallel: model weights are replicated across all devices in the DeviceMesh, and each device processes a portion of the input data.
     02 Model Parallel and Layout Map: split model weights or activation tensors across all the devices on the DeviceMesh, enabling horizontal scaling for large models.
     (A sketch of both options follows below.)
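A hedged sketch of both options with the keras.distribution API, assuming the JAX backend and 8 accelerators; the regex key and the exact ModelParallel signature vary slightly across Keras 3.x releases.

```python
import keras

devices = keras.distribution.list_devices()

# 01 Data Parallel: replicate weights, shard the input batch across devices.
data_parallel = keras.distribution.DataParallel(devices=devices)
keras.distribution.set_distribution(data_parallel)

# 02 Model Parallel + LayoutMap: split weights across a 2x4 DeviceMesh.
mesh = keras.distribution.DeviceMesh(
    shape=(2, 4), axis_names=("batch", "model"), devices=devices)
layout_map = keras.distribution.LayoutMap(mesh)
layout_map["dense.*kernel"] = (None, "model")   # shard Dense kernels on "model"
model_parallel = keras.distribution.ModelParallel(
    layout_map=layout_map, batch_dim_name="batch")
keras.distribution.set_distribution(model_parallel)
```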
  12. Distribution API
     Lightweight data parallel / model parallel distribution API built on top of:
     • jax.sharding (READY)
     • PyTorch/XLA sharding (COMING SOON)
     • TensorFlow DTensor (COMING SOON)
     All the heavy lifting is already done in XLA GSPMD!
  13. You can choose how to distribute!
     Keras 3.0 has complete optionality for distribution across backends:
     • A custom training loop
     • Backend-specific distribution APIs directly
     • The Keras multi-backend Distribution API
  14. Pre-trained models
     Keras 3.0 includes all Keras Applications (popular image classifiers).
     KerasCV and KerasNLP work out of the box with Keras 3 across all backends as of the latest releases:
     • YOLOv8
     • Whisper
     • BERT
     • OPT
     • Stable Diffusion
     • etc.
     [Image: Stable Diffusion output for the prompt "a photograph of an astronaut riding a horse"]
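Two hedged sketches of pretrained-model usage, one from KerasNLP and one from KerasCV; the preset name and image size are illustrative.

```python
import keras_cv
import keras_nlp

# KerasNLP: load a BERT preset and classify text (preset name is an example).
classifier = keras_nlp.models.BertClassifier.from_preset(
    "bert_base_en_uncased", num_classes=2)
classifier.predict(["Keras 3 runs on TensorFlow, JAX and PyTorch."])

# KerasCV: text-to-image with Stable Diffusion.
sd = keras_cv.models.StableDiffusion(img_width=512, img_height=512)
images = sd.text_to_image(
    "a photograph of an astronaut riding a horse", batch_size=1)
```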