Keras Newsletter (August 4, 2023)

Keras Core beta release

  • Full rewrite of Keras
    • Now only ~45k lines of code instead of 135k
  • Support for TensorFlow, JAX, PyTorch, NumPy backends
    • NumPy backend is inference-only
  • Drop-in replacement for tf.keras when using TensorFlow backend
    • Just change your imports!
  • Will become Keras 3 in the Fall

Learn more about Keras Core

The same Keras code runs with different frameworks

Use NumPy APIs to create models

Use native framework APIs with your model

  • Example: Writing custom training loop for Keras Core model with native PyTorch or TensorFlow APIs

Why Keras Core?

  • Maximize performance
    • Pick the backend that’s the fastest for your particular model
      *Typically, PyTorch < TensorFlow < JAX (by 10-20% jumps between frameworks)
  • Maximize available ecosystem surface
    • Export your model to TF SavedModel (TFLite, TF.js, TF Serving, TF-MOT, etc.)
    • Instantiate your model as a PyTorch Module and use it with the PyTorch ecosystem
    • Call your model as a stateless JAX function and use it with JAX transforms
  • Maximize addressable market for your OSS model releases
    • PyTorch, TF have only 40-60% of the market each
    • Keras models are usable by anyone with no framework lock-in
  • Maximize data source availability
    • Use tf.data, PyTorch DataLoader, NumPy, Pandas, etc. – with any backend

Keras Core benchmarks on BERT *

                          Training       Inference
Keras Core (JAX)          229 ms/step    70 ms/step
Keras Core (TensorFlow)   227 ms/step    69 ms/step
Keras Core (PyTorch)      301 ms/step    88 ms/step
HuggingFace + PyTorch     261 ms/step    75 ms/step
tf.keras                  364 ms/step   112 ms/step

* Tested on a V100 GPU in Google Colab. Preliminary results; subject to further optimization.

KerasCV & KerasNLP now support Keras Core

  • KerasCV and KerasNLP 0.6 Releases
    • Support for tf.keras will continue until Keras Core becomes Keras 3.0
  • All KerasCV components support all backends with Keras Core
    • Except:
      • StableDiffusion (coming in the next release)
      • CenterPillar
  • All KerasNLP components support all backends with Keras Core

Using Keras Core with KerasNLP

  • Switch back to tf.keras by unsetting KERAS_BACKEND
  • Persist configuration by editing ~/.keras/keras_nlp.json
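The backend switch described above, sketched as a shell session (the variable name is from the bullets; the script invocation is a placeholder):

```shell
# Select the JAX backend; keras_core reads this on first import:
export KERAS_BACKEND=jax

# ... run your Keras Core / KerasNLP script here ...

# Switch back to tf.keras by unsetting the variable:
unset KERAS_BACKEND
```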

Using Keras Core with KerasCV

A contradiction seems to exist between the speed-ranking footnote ("Typically, PyTorch < TensorFlow < JAX") and the benchmark table: the benchmarks show Keras Core with a TensorFlow backend is fastest, contrary to the claim that the JAX backend is fastest.

Sorry for the confusion. The result here is a specific case for BERT; the "JAX is faster" conclusion is based on the more comprehensive experiments we have done.

We will release more results soon.


*Typically, PyTorch < TensorFlow < JAX (by 10-20% jumps between frameworks)

Does this mean PyTorch is faster or JAX is faster?