Keras 3 new features overview
- New distribution API
- Enable data and model parallel training across devices
- Initially via JAX, with TensorFlow and PyTorch support in later versions
- SparseTensor support with TensorFlow
- Train a model in one framework, reload in another framework
- Use the runtime most appropriate for your hardware / environment, without any code change
- Write framework-agnostic layers, models, losses, metrics, and optimizers, and reuse them in native TF/JAX/PyTorch workflows (see the sketch below)
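As a concrete illustration of that last point, here is a minimal, hedged sketch of a custom layer written purely against keras.ops, so the same class runs on TensorFlow, JAX, or PyTorch; the layer itself is a toy example, not code from the slides.

import keras
from keras import ops


class SimpleDense(keras.layers.Layer):
    # A toy dense layer written only against the backend-agnostic ops API.

    def __init__(self, units, **kwargs):
        super().__init__(**kwargs)
        self.units = units

    def build(self, input_shape):
        self.kernel = self.add_weight(
            shape=(input_shape[-1], self.units),
            initializer="glorot_uniform",
            trainable=True,
        )

    def call(self, inputs):
        # ops.* dispatches to whichever backend KERAS_BACKEND selects.
        return ops.relu(ops.matmul(inputs, self.kernel))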
Keras 3.0 pre-release timeline
- How to install Keras 3?
- Now:
pip install keras-core
- In a few weeks:
pip install keras-nightly
- In November (tentative):
pip install keras
- Contributing:
- The Keras 3.0 code is now in the keras-team/keras repo
tf.keras compatibility
- TensorFlow 2.16 will use Keras 3 by default.
- Upgrade from Keras 2 to Keras 3 for TensorFlow users
- No changes needed in user code
- If anything breaks, fall back to the legacy tf.keras
- How to install Keras 2 (the legacy tf.keras)?
- TensorFlow 2.15 and earlier:
pip install tensorflow
- From TensorFlow 2.16 release (Q1 2024):
pip uninstall keras
pip install tf-keras
Keras multi-backend distribution API
- Lightweight data parallel / model parallel distribution API built on top of:
- jax.sharding - ready
- PyTorch/XLA sharding - coming soon
- TensorFlow DTensor - coming soon
- All the heavy lifting is already done in XLA and GSPMD!
- Main classes:
- DataParallel / ModelParallel: distribution setting for model weights and input data
- LayoutMap: encapsulates sharding for your Keras model
- Other API classes map directly onto backend primitives:
- DeviceMesh → jax.sharding.Mesh
- TensorLayout → jax.sharding.NamedSharding
- Example:
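Below is a minimal, hedged sketch of how these pieces fit together on the JAX backend; the 2x4 mesh shape (8 devices) and the weight-path regex are illustrative assumptions, not the exact code from the slide.

import keras
from keras import distribution

# Simplest case: pure data parallelism over every available device.
data_parallel = distribution.DataParallel(devices=distribution.list_devices())

# Model parallelism: a 2D device mesh plus a LayoutMap describing how
# weights are sharded along the "model" axis (assumes 8 accelerators).
mesh = distribution.DeviceMesh(
    shape=(2, 4),
    axis_names=("batch", "model"),
    devices=distribution.list_devices(),
)
layout_map = distribution.LayoutMap(mesh)
layout_map["dense.*kernel"] = (None, "model")  # illustrative path regex
model_parallel = distribution.ModelParallel(
    device_mesh=mesh, layout_map=layout_map, batch_dim_name="batch"
)

# Activate one distribution; model building and fit() need no other changes.
distribution.set_distribution(model_parallel)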
- Users can also train their Keras models using:
- A custom training loop
- Backend-specific distribution APIs directly
The roadmap
- jax.sharding implementation - DONE
- Multi-worker/process training
- Model saving and checkpointing for distributed models
- Utility for distributing datasets (multi-backend equivalent of tf.distribute.Strategy.distribute_datasets_from_function)
- Sharding constraints on intermediate values (equivalent of jax.lax.with_sharding_constraint)
- MultiSlice capability (jax.experimental.mesh_utils.create_hybrid_device_mesh)
- PyTorch/XLA sharding implementation
- TensorFlow DTensor implementation
KerasCV updates
- 0.7.0 release coming soon with some exciting new features
- Semantic segmentation support (see the usage sketch after this list)
- Full coverage of augmentation layers
- DeepLabV3Plus
- SegFormer
- Segment Anything Model (SAM)
- A guide to accompany the 0.7 release
- Various small bug fixes
- Switched to programmatic API export
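To make the segmentation support concrete, here is a hedged sketch of loading a DeepLabV3Plus model from a preset; the preset name and class count are assumptions, so check the KerasCV 0.7 docs for the presets actually published.

import keras_cv
import numpy as np

# Build a segmentation model on top of a pretrained backbone preset.
model = keras_cv.models.DeepLabV3Plus.from_preset(
    "resnet50_v2_imagenet",  # assumed backbone preset name
    num_classes=21,
)

images = np.random.uniform(size=(1, 512, 512, 3)).astype("float32")
masks = model.predict(images)  # per-pixel class scores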
KerasNLP updates
- Upcoming features for LLM workflows:
- Preconfigured model parallelism for backbones
- LoRA API for efficient fine-tuning
- Let's take a look.
- Note that these are still under development.
- APIs might change!
KerasNLP - LoRA high-level
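A hedged sketch of what the high-level LoRA workflow looks like, assuming a backbone.enable_lora(rank=...) entry point; since the API was still in flux at the time of the talk, the preset name and training data here are purely illustrative.

import keras_nlp

causal_lm = keras_nlp.models.GPT2CausalLM.from_preset("gpt2_base_en")

# enable_lora freezes the pretrained kernels and adds small trainable
# low-rank adapter matrices alongside them.
causal_lm.backbone.enable_lora(rank=4)

# Fine-tune: only the LoRA parameters (a tiny fraction of the model) update.
causal_lm.fit(x=["The quick brown fox jumped."], batch_size=1)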
KerasNLP - LoRA low-level
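The low-level slide code isn't reproduced here; as a hedged stand-in, this shows the mechanics LoRA relies on in plain Keras 3: keep a frozen layer and learn a low-rank additive update, which is the same idea enable_lora(rank) applies to Dense/EinsumDense layers. The LoraDense name is ad hoc.

import keras
from keras import ops


class LoraDense(keras.layers.Layer):
    # Wraps a frozen Dense layer with a trainable rank-`rank` update.

    def __init__(self, inner, rank=4, **kwargs):
        super().__init__(**kwargs)
        self.inner = inner            # pre-trained Dense layer
        self.inner.trainable = False  # its kernel stays frozen
        self.rank = rank

    def build(self, input_shape):
        in_dim = input_shape[-1]
        out_dim = self.inner.units
        self.lora_a = self.add_weight(
            shape=(in_dim, self.rank), initializer="he_uniform", trainable=True
        )
        self.lora_b = self.add_weight(
            shape=(self.rank, out_dim), initializer="zeros", trainable=True
        )

    def call(self, inputs):
        # Frozen path plus low-rank trainable correction B(A(x)).
        return self.inner(inputs) + ops.matmul(
            ops.matmul(inputs, self.lora_a), self.lora_b
        )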
KerasNLP - Model parallelism
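A hedged sketch of model parallelism around a KerasNLP backbone: the preconfigured layout maps mentioned above are meant to supply the sharding rules, so the single rule below is only an assumption to show the shape of the API.

import keras_nlp
from keras import distribution

mesh = distribution.DeviceMesh(
    shape=(1, 8),
    axis_names=("batch", "model"),
    devices=distribution.list_devices(),  # assumes 8 accelerators
)
layout_map = distribution.LayoutMap(mesh)
# Illustrative rule: shard the token embedding table along the "model" axis.
# A preconfigured layout map would add rules for attention / FFN kernels too.
layout_map["token_embedding/embeddings"] = (None, "model")

distribution.set_distribution(
    distribution.ModelParallel(
        device_mesh=mesh, layout_map=layout_map, batch_dim_name="batch"
    )
)

# Weights loaded after this point are laid out according to the map.
backbone = keras_nlp.models.GPT2Backbone.from_preset("gpt2_base_en")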
KerasTuner updates
- KerasTuner supports multi-backend Keras (see the sketch at the end of this section)
- From KerasTuner v1.4
- Only exporting public APIs
- From KerasTuner v1.4
- All private APIs are hidden under keras_tuner.src
- Change import keras_tuner.*** to import keras_tuner.src.***
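To close, a hedged end-to-end sketch of KerasTuner running on multi-backend Keras: the backend is chosen via the KERAS_BACKEND environment variable before importing Keras, and the model and search space below are illustrative.

import os
os.environ["KERAS_BACKEND"] = "jax"  # or "tensorflow" / "torch"

import keras
import keras_tuner
import numpy as np


def build_model(hp):
    # A tiny regression model whose width is the only hyperparameter.
    model = keras.Sequential([
        keras.layers.Dense(hp.Int("units", 32, 256, step=32), activation="relu"),
        keras.layers.Dense(1),
    ])
    model.compile(optimizer="adam", loss="mse")
    return model


tuner = keras_tuner.RandomSearch(build_model, objective="val_loss", max_trials=3)
x = np.random.rand(256, 8).astype("float32")
y = np.random.rand(256, 1).astype("float32")
tuner.search(x, y, validation_split=0.2, epochs=2)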