I’m trying to convert a fairly large JAX model (> 10 GB) to TFLite so I can quantize it and reduce its size. However, I’ve been running into a lot of issues and was wondering if someone here could help.
I tried two approaches:
Using ‘from_concrete_functions’. Unfortunately, this gave me an error due to the 2GB limit, but I read somewhere that this problem is now fixed with ‘Proto Splitter’. Is it possible to use that here? If so, how would I go about doing that?
converter.optimizations = [tf.lite.Optimize.DEFAULT]
tflite_quantized_model = converter.convert()
with open('./quantized.tflite', 'wb') as f:
    f.write(tflite_quantized_model)
When I couldn’t get the above to work, I decided to try the ‘experimental_from_jax’ function. It’s a deprecated API, but it was still available, so I wanted to give it a try. Unfortunately, this approach gave me a segfault; maybe it’s linked to the same 2GB limit issue.
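For reference, here is roughly what I ran with that deprecated API (a minimal sketch; predict_fn is a placeholder for my model’s apply function, and the example arrays just match my input shapes):

import jax.numpy as jnp
import tensorflow as tf

# Example inputs used only to trace the function (shapes match my model's inputs)
example_batch = jnp.zeros((1, 256), dtype=jnp.int32)
example_masks = jnp.zeros((1, 256), dtype=jnp.bool_)

# experimental_from_jax takes a list of serving functions and, per function,
# a list of (name, example_value) pairs
converter = tf.lite.TFLiteConverter.experimental_from_jax(
    [predict_fn],  # placeholder for my model's predict function
    [[('batch', example_batch), ('vision_masks', example_masks)]],
)
tflite_model = converter.convert()  # this is where the segfault occurs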
To be honest, I have not tried the two methods below myself, but could you please try them on your end and let me know whether the conversion succeeds?
Try the from_concrete_functions approach with the Proto Splitter feature. By setting experimental_lower_to_saved_model=True, the converter will split the frozen TensorFlow graph into multiple files, allowing it to handle models larger than 2GB.
Another approach you can try is to first convert your JAX model to the TensorFlow SavedModel format (use TensorFlow version 2.15.0) with tf.saved_model.save, and then use tf.lite.TFLiteConverter.from_saved_model to convert it to TFLite.
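A minimal sketch of that second approach (predict_fn, the input shapes, and the directory name are placeholders for your actual model):

import tensorflow as tf
from jax.experimental import jax2tf

# Wrap the JAX predict function as a tf.function (placeholder shapes/dtypes)
tf_predict = tf.function(
    jax2tf.convert(predict_fn, enable_xla=False),
    input_signature=[
        tf.TensorSpec(shape=[1, 256], dtype=tf.int32, name='batch'),
        tf.TensorSpec(shape=[1, 256], dtype=tf.bool, name='vision_masks'),
    ],
    autograph=False,
)

# Export as a SavedModel, then convert from the SavedModel directory
module = tf.Module()
module.predict = tf_predict
tf.saved_model.save(
    module, 'jax_saved_model',
    signatures=module.predict.get_concrete_function(),
)

converter = tf.lite.TFLiteConverter.from_saved_model('jax_saved_model')
tflite_model = converter.convert()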
File "/usr/local/lib/python3.10/dist-packages/tensorflow/python/framework/tensor_util.py", line 589, in make_tensor_proto
    raise ValueError(
ValueError: Cannot create a tensor proto whose content is larger than 2GB.
It errored out while trying to execute the following line:
It seems you’re running into issues while trying to convert a large JAX model to TFLite due to the 2GB limit and perhaps some deprecated functionality. Here are a few suggestions and approaches based on the available information:
Using Proto Splitter
The 2GB limit has been a common issue; TensorFlow’s Proto Splitter can help by splitting large model files into smaller chunks. There’s no direct mention of JAX models in that context, but here is an example of the usual SavedModel-to-TFLite conversion flow in TensorFlow:
import tensorflow as tf

# Assuming you already have a SavedModel directory `saved_model_dir`
saved_model_dir = 'path_to_saved_model_dir'

# Optionally load the SavedModel to check that it restores correctly
loaded_model = tf.saved_model.load(saved_model_dir)

# Convert the SavedModel to TFLite
converter = tf.lite.TFLiteConverter.from_saved_model(saved_model_dir)
converter.experimental_new_converter = True  # MLIR converter; already the default in TF 2.x
tflite_model = converter.convert()

# Save the TFLite model
with open('model.tflite', 'wb') as f:
    f.write(tflite_model)
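Note that the snippet above is just the standard SavedModel-to-TFLite path; the Proto Splitter itself is applied when the SavedModel is written. As far as I know (this is an assumption on my part, so please verify against the Proto Splitter guide), TensorFlow 2.15+ lets you request the split format through SaveOptions:

# Assumption: TF >= 2.15 can write a SavedModel in the split (Proto Splitter)
# format, which avoids the 2GB single-proto limit at save time
tf.saved_model.save(
    loaded_model,
    'path_to_split_saved_model',
    options=tf.saved_model.SaveOptions(experimental_image_format=True),
)

Whether tf.lite.TFLiteConverter.from_saved_model can consume that split format directly is something I have not verified.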
For JAX models, you might need to first convert your model to TensorFlow SavedModel format and then proceed with the conversion. The key is to use the new converter which handles the larger models more efficiently.
Converting from JAX to TFLite
Your current approach with JAX might need some fine-tuning. Consider breaking the model down into smaller parts or exporting sub-modules that can be converted separately. Here is how you might approach it:
Segmenting Model Conversion: Split the model into manageable parts if possible (see the sketch after the conversion code below).
Using the JAX-to-TensorFlow Conversion:
from jax.experimental import jax2tf
import tensorflow as tf
def predict_fn(batch, vision_masks):
    # restored_params and model come from your existing JAX/Flax setup
    params = restored_params
    batch_tokens = batch
    batch_vision_masks = vision_masks
    vision_logits = model.apply(
        params, batch_tokens, batch_vision_masks, deterministic=True
    ).logits
    return vision_logits

# Define JAX function
def predict_function(*args):
    return predict_fn(*args)
# Convert JAX function to TensorFlow function
tf_predict = tf.function(
    jax2tf.convert(predict_function, enable_xla=False),
    input_signature=[
        tf.TensorSpec(shape=[1, 256], dtype=tf.int32, name='batch'),
        tf.TensorSpec(shape=[1, 256], dtype=tf.bool, name='vision_masks')
    ],
    autograph=False
)
# Trace a concrete function and build the TFLite converter
concrete_function = tf_predict.get_concrete_function()
converter = tf.lite.TFLiteConverter.from_concrete_functions([concrete_function])
converter.target_spec.supported_ops = [
    tf.lite.OpsSet.TFLITE_BUILTINS,
    tf.lite.OpsSet.SELECT_TF_OPS
]

# Convert the float model first (no optimizations yet)
tflite_model = converter.convert()
with open('float_model.tflite', 'wb') as f:
    f.write(tflite_model)

# Enable default optimizations and convert again to produce the quantized model
converter.optimizations = [tf.lite.Optimize.DEFAULT]
tflite_quantized_model = converter.convert()
with open('quantized_model.tflite', 'wb') as f:
    f.write(tflite_quantized_model)
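On the “segment the model” point above: if the network can be expressed as a composition of smaller JAX functions, each piece can be converted on its own and chained at inference time. A rough sketch under that assumption (encoder_fn, head_fn, and the intermediate [1, 256, 2048] shape are made-up placeholders, not your model’s real structure):

import tensorflow as tf
from jax.experimental import jax2tf

def convert_part(fn, specs, out_path):
    # Wrap one JAX sub-function, trace it, and convert it to its own .tflite file
    tf_fn = tf.function(jax2tf.convert(fn, enable_xla=False),
                        input_signature=specs, autograph=False)
    converter = tf.lite.TFLiteConverter.from_concrete_functions(
        [tf_fn.get_concrete_function()])
    converter.target_spec.supported_ops = [
        tf.lite.OpsSet.TFLITE_BUILTINS,
        tf.lite.OpsSet.SELECT_TF_OPS,
    ]
    converter.optimizations = [tf.lite.Optimize.DEFAULT]
    with open(out_path, 'wb') as f:
        f.write(converter.convert())

# encoder_fn and head_fn are hypothetical sub-functions of your model
convert_part(encoder_fn,
             [tf.TensorSpec([1, 256], tf.int32, name='batch'),
              tf.TensorSpec([1, 256], tf.bool, name='vision_masks')],
             'encoder.tflite')
convert_part(head_fn,
             [tf.TensorSpec([1, 256, 2048], tf.float32, name='hidden')],
             'head.tflite')

At inference you would run encoder.tflite first and feed its output into head.tflite; the benefit is that each part’s weights stay under the 2GB proto limit, provided the parameters are split accordingly.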
Deprecated APIs
Using deprecated APIs like experimental_from_jax is not recommended. It’s better to follow the currently supported conversion paths, which are better maintained and less error-prone.
Check TensorFlow Compatibility
Ensure you are using a compatible version of TensorFlow that supports all necessary operations for conversion and handles large models effectively. TensorFlow 2.x is recommended.
Further Debugging
The error might be due to other issues such as insufficient memory or compatibility issues with TensorFlow and JAX. Update to the latest stable versions of both libraries.
Additional Resources
For further information on conversion and troubleshooting, refer to the TensorFlow Lite converter guide and the JAX-to-TFLite conversion guide on tensorflow.org.
Got it, thanks. Sounds like there isn’t a clean way to solve this problem at the moment. Unfortunately, I don’t know enough about the internals of the model to split it up into smaller parts, but I’ll see what I can learn. Appreciate the help.