Error converting JAX model to TFLite

Hi everyone,

I’m trying to convert a fairly large JAX model (> 10 GB) to TFLite so that I can quantize it and reduce its size. However, I’ve been running into a lot of issues and was wondering if someone here could help.

I tried two approaches:

  1. Using ‘from_concrete_functions’. Unfortunately, this gave me an error due to the 2GB limit, but I read somewhere that this problem is now fixed with ‘Proto Splitter.’ Is it possible to use that here? If so, how would I go about doing it?

from jax.experimental import jax2tf

tf_predict = tf.function(
    jax2tf.convert(predict_fn, enable_xla=False),
    input_signature=[
        # tf.TensorSpec(shape=..., dtype=tf.float32, name="params"),
        tf.TensorSpec(shape=[1, 256], dtype=tf.int32, name="batch"),
        tf.TensorSpec(shape=[1, 256], dtype=tf.bool, name="vision_masks"),
    ],
    autograph=False)

converter = tf.lite.TFLiteConverter.from_concrete_functions(
    [tf_predict.get_concrete_function()], tf_predict)
converter.target_spec.supported_ops = [
    tf.lite.OpsSet.TFLITE_BUILTINS,  # enable TensorFlow Lite ops.
    tf.lite.OpsSet.SELECT_TF_OPS,    # enable TensorFlow ops.
]
tflite_float_model = converter.convert()
with open("float_model.tflite", "wb") as f:
    f.write(tflite_float_model)

converter.optimizations = [tf.lite.Optimize.DEFAULT]
tflite_quantized_model = converter.convert()
with open("./quantized.tflite", "wb") as f:
    f.write(tflite_quantized_model)

  2. When I couldn’t get the above to work, I decided to try the ‘experimental_from_jax’ function. It’s a deprecated API, but it was still available, so I wanted to give it a try. With this approach, I got a segfault; maybe it’s linked to the same 2GB limit issue.

from pdb import set_trace; set_trace()
converter = tf.lite.TFLiteConverter.experimental_from_jax(
    [predict_fn],
    [[("input1", jnp.zeros((1, 256), dtype=jnp.int32)),
      ("input2", jnp.zeros((1, 256), dtype=jnp.int32))]])
tflite_model = converter.convert()
tflite_model = converter.convert()

Is it currently possible to solve my problem or am I doing something wrong?

Here’s the ‘predict_fn’ function that is called by both approaches above:

def predict_fn(batch, vision_masks):
    params = restored_params  # this is the frozendict params file
    batch_tokens = batch
    batch_vision_masks = vision_masks
    vision_logits = model.apply(
        params, batch_tokens, batch_vision_masks, deterministic=True).logits
    return vision_logits

Let me know what you guys think, thanks!

Can anyone help? Just trying to feel out if I should explore a different route. Thanks.

Hi @ups1 ,

To be honest, I haven’t tried the two methods below myself, but could you please try them on your end and let me know whether the conversion succeeds?

Try

  1. The from_concrete_functions approach with the Proto Splitter feature. By setting experimental_lower_to_saved_model=True, the converter will split the frozen TensorFlow graph into multiple files, allowing it to handle models larger than 2GB.

  2. Another approach is to first convert your JAX model to the TensorFlow SavedModel format (use TensorFlow version 2.15.0) with tf.saved_model.save, and then use tf.lite.TFLiteConverter.from_saved_model to convert it to TFLite. A rough sketch of this route is below.
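
For (2), something like the following untested sketch is what I have in mind, reusing the predict_fn and input shapes from your first post (the JaxModule class name and the saved_model_dir path are just placeholders):

import tensorflow as tf
from jax.experimental import jax2tf

# Wrap the jax2tf-converted function in a tf.Module so it can be written
# out with tf.saved_model.save. predict_fn is the same function as above.
class JaxModule(tf.Module):
    def __init__(self):
        super().__init__()
        self.predict = tf.function(
            jax2tf.convert(predict_fn, enable_xla=False),
            input_signature=[
                tf.TensorSpec([1, 256], tf.int32, name="batch"),
                tf.TensorSpec([1, 256], tf.bool, name="vision_masks"),
            ],
            autograph=False)

module = JaxModule()
saved_model_dir = "jax_saved_model"  # placeholder path
tf.saved_model.save(
    module, saved_model_dir,
    signatures=module.predict.get_concrete_function())

converter = tf.lite.TFLiteConverter.from_saved_model(saved_model_dir)
converter.target_spec.supported_ops = [
    tf.lite.OpsSet.TFLITE_BUILTINS,
    tf.lite.OpsSet.SELECT_TF_OPS,
]
tflite_model = converter.convert()

One caveat: because predict_fn closes over restored_params, jax2tf will most likely embed the weights as constants in the graph, so the 2GB proto limit may still show up at the save step.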

Thanks.

Hello, thanks for the response.

For (1), I got the following error again:

File "/usr/local/lib/python3.10/dist-packages/tensorflow/python/framework/tensor_util.py", line 589, in make_tensor_proto
    raise ValueError(
ValueError: Cannot create a tensor proto whose content is larger than 2GB.

It errored out while trying to execute the following line:

converter = tf.lite.TFLiteConverter.from_concrete_functions(
    [tf_predict.get_concrete_function()], tf_predict, experimental_lower_to_saved_model=True)

For (2), can you help me better understand how I would go about running that test? Currently, I’m loading my params using the following:

train_state, restored_params = checkpointer.load_trainstate_checkpoint(
    FLAGS.load_checkpoint, None, None, max_buffer_size=32 * 2 ** 30
)

And for prediction (as seen in the predict_fn above), I’m using the following:

vision_logits = model.apply(params, batch_tokens, batch_vision_masks, deterministic=True).logits

So, in this case, what exactly is the JAX model I should pass into tf.saved_model.save()?

Thanks.

TensorFlow 2.16 + Python 3.12 – JARaaS Hybrid RAG - 6/16/2024

Hello,

It seems you’re running into issues while trying to convert a large JAX model to TFLite due to the 2GB limit and perhaps some deprecated functionality. Here are a few suggestions and approaches based on the available information:

Using Proto Splitter

The 2GB limit has been a common issue; however, TensorFlow’s Proto Splitter might help split large model files into smaller chunks. Unfortunately, there’s no direct mention of JAX models in the documentation, but here is an example of how you might drive the conversion from a SavedModel in TensorFlow:

import tensorflow as tf

# Assuming you have a tf.saved_model directory `saved_model_dir`
saved_model_dir = 'path_to_saved_model_dir'

# (Optional) load the SavedModel to check that it restores correctly
loaded_model = tf.saved_model.load(saved_model_dir)

# Convert it to TFLite with the new (MLIR-based) converter
converter = tf.lite.TFLiteConverter.from_saved_model(saved_model_dir)
converter.experimental_new_converter = True
tflite_model = converter.convert()

# Save the TFLite model
with open('model.tflite', 'wb') as f:
    f.write(tflite_model)

For JAX models, you might need to first convert your model to the TensorFlow SavedModel format and then proceed with the conversion. The key is to use the new converter, which handles larger models more efficiently, and to keep the weights out of the frozen graph; a sketch of one way to do that follows.
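
As a hedged sketch (not something taken from the docs): one way to keep a >2GB checkpoint out of the GraphDef is to pass the parameters to the jax2tf-converted function explicitly and hold them as tf.Variables on a tf.Module, so they are written to the SavedModel’s variable shards rather than baked in as graph constants. The names jax_predict, ParamsModule, and the output path are made up; model and restored_params are taken from your snippets above:

import jax
import numpy as np
import tensorflow as tf
from jax.experimental import jax2tf

# Take the parameters as an explicit argument instead of closing over them,
# so the weights are not embedded in the graph as constants.
def jax_predict(params, batch, vision_masks):
    return model.apply(params, batch, vision_masks, deterministic=True).logits

tf_fn = jax2tf.convert(jax_predict, enable_xla=False)

class ParamsModule(tf.Module):
    def __init__(self, params):
        super().__init__()
        # Flatten the params pytree and hold each leaf as a tf.Variable, so the
        # weights go into the SavedModel's variable files (no 2GB proto limit)
        # instead of the GraphDef.
        leaves, self._treedef = jax.tree_util.tree_flatten(params)
        self._param_vars = [tf.Variable(np.asarray(leaf)) for leaf in leaves]

    @tf.function(
        input_signature=[
            tf.TensorSpec([1, 256], tf.int32, name="batch"),
            tf.TensorSpec([1, 256], tf.bool, name="vision_masks"),
        ],
        autograph=False)
    def predict(self, batch, vision_masks):
        # Rebuild the original pytree structure with the variables as leaves.
        params = jax.tree_util.tree_unflatten(self._treedef, self._param_vars)
        return tf_fn(params, batch, vision_masks)

module = ParamsModule(restored_params)  # restored_params from your checkpoint
tf.saved_model.save(
    module, "jax_saved_model",
    signatures=module.predict.get_concrete_function())

Whether tf.lite.TFLiteConverter.from_saved_model can then get a model of this size all the way through conversion is a separate question, but at least the SavedModel step should no longer trip over the 2GB tensor proto limit.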

Converting from JAX to TFLite

Your current approach with JAX might need some fine-tuning. Consider breaking the model into smaller parts or exporting sub-modules that can be converted separately. Two things to look at:

  1. Segmenting Model Conversion: Split the model into manageable parts if possible.

  2. Using the JAX to TensorFlow Conversion:

    from jax.experimental import jax2tf
    import tensorflow as tf

    def predict_fn(batch, vision_masks):
        params = restored_params
        batch_tokens = batch
        batch_vision_masks = vision_masks
        vision_logits = model.apply(params, batch_tokens, batch_vision_masks, deterministic=True).logits
        return vision_logits

    # Convert the JAX function to a TensorFlow function
    tf_predict = tf.function(
        jax2tf.convert(predict_fn, enable_xla=False),
        input_signature=[
            tf.TensorSpec(shape=[1, 256], dtype=tf.int32, name='batch'),
            tf.TensorSpec(shape=[1, 256], dtype=tf.bool, name='vision_masks')
        ],
        autograph=False
    )

    # Build the converter from the concrete function
    concrete_function = tf_predict.get_concrete_function()
    converter = tf.lite.TFLiteConverter.from_concrete_functions([concrete_function])
    converter.target_spec.supported_ops = [
        tf.lite.OpsSet.TFLITE_BUILTINS,
        tf.lite.OpsSet.SELECT_TF_OPS
    ]

    # Convert the float model first (no optimizations set yet)
    tflite_model = converter.convert()
    with open('float_model.tflite', 'wb') as f:
        f.write(tflite_model)

    # Then enable default optimizations and convert again for the quantized model
    converter.optimizations = [tf.lite.Optimize.DEFAULT]
    tflite_quantized_model = converter.convert()
    with open('quantized_model.tflite', 'wb') as f:
        f.write(tflite_quantized_model)
    

Deprecated APIs

Using deprecated APIs like experimental_from_jax is not recommended. Instead, it’s better to follow the latest supported mechanisms which ensure better support and fewer errors.

Check TensorFlow Compatibility

Ensure you are using a compatible version of TensorFlow that supports all necessary operations for conversion and handles large models effectively. TensorFlow 2.x is recommended.

Further Debugging

The error might also be due to other issues, such as insufficient memory or compatibility problems between TensorFlow and JAX. Update to the latest stable versions of both libraries and double-check what you actually have installed, for example:
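
A quick way to confirm the versions in your environment (nothing model-specific here):

import sys
import jax
import tensorflow as tf

print("Python:", sys.version)
print("TensorFlow:", tf.__version__)
print("JAX:", jax.__version__)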

Additional Resources

For further information on conversion and troubleshooting issues, refer to the following guides:

Example Files

  • Common Issues with TensorFlow: common_issues.md (internal document)
  • TPU Configuration: using_tpu.md (internal document)
  • TFLite Conversion Guide: tflite.ipynb (internal document)

Let me know if you need any further assistance!

Got it, thanks. Sounds like there isn’t a clean way to solve this problem at the moment. Unfortunately, I don’t know enough about the internals of the model to split it up into smaller parts, but I’ll see what I can learn. Appreciate the help.