How to solve TensorFlow jit_compile error

I have to use tf.function(jit_compile=True), but it seems to conflict with tf.py_function and produces the error message: No registered 'EagerPyFunc' OpKernel for XLA_GPU_JIT devices compatible with node {{node EagerPyFunc}}){{node EagerPyFunc}}

```python
import tensorflow as tf
import numpy as np

def my_numpy_func(x):  # This must be a NumPy function because it involves many SciPy operations
    return np.sinh(x)

# @tf.function(input_signature=[tf.TensorSpec(None, tf.float32)])  # works, but I have to use the jit_compile parameter
@tf.function(input_signature=[tf.TensorSpec(None, tf.float32)], jit_compile=True)  # fails because of jit_compile=True
def tf_function(input):
    y = tf.py_function(my_numpy_func, [input], tf.float32)
    return y * y

a = tf_function(tf.constant(1.))
print(a)
```

Error message:

```
Traceback (most recent call last):
  File "test10.py", line 15, in <module>
    a = tf_function(tf.constant(1.))
  File "/home/sjtusmartboy/.local/lib/python3.6/site-packages/tensorflow/python/eager/def_function.py", line 885, in __call__
    result = self._call(*args, **kwds)
  File "/home/sjtusmartboy/.local/lib/python3.6/site-packages/tensorflow/python/eager/def_function.py", line 957, in _call
    filtered_flat_args, self._concrete_stateful_fn.captured_inputs)  # pylint: disable=protected-access
  File "/home/sjtusmartboy/.local/lib/python3.6/site-packages/tensorflow/python/eager/function.py", line 1964, in _call_flat
    ctx, args, cancellation_manager=cancellation_manager))
  File "/home/sjtusmartboy/.local/lib/python3.6/site-packages/tensorflow/python/eager/function.py", line 596, in call
    ctx=ctx)
  File "/home/sjtusmartboy/.local/lib/python3.6/site-packages/tensorflow/python/eager/execute.py", line 60, in quick_execute
    inputs, attrs, num_outputs)
tensorflow.python.framework.errors_impl.InvalidArgumentError: Detected unsupported operations when trying to compile graph __inference_tf_function_8[_XlaMustCompile=true,config_proto="\n\007\n\003CPU\020\001\n\007\n\003GPU\020\0012\005*\0010J\0008\001\202\001\000",executor_type=""] on XLA_GPU_JIT: EagerPyFunc (No registered 'EagerPyFunc' OpKernel for XLA_GPU_JIT devices compatible with node {{node EagerPyFunc}}){{node EagerPyFunc}}
The op is created at: 
File "test10.py", line 15, in <module>
  a = tf_function(tf.constant(1.))
File "test10.py", line 12, in tf_function
  y = tf.py_function(my_numpy_func, [input], tf.float32) [Op:__inference_tf_function_8]
```

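This is expected: jit_compile=True compiles the whole function with XLA, and there is no XLA lowering for tf.py_function (it calls back into the Python interpreter), so compilation fails with this error. See the tf.function documentation (TensorFlow v2.16.1) for the jit_compile parameter.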
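One common workaround, sketched here under the assumption that the NumPy/SciPy step itself does not need to be inside the XLA region: keep tf.py_function in an ordinary tf.function and move only the pure-TensorFlow part into a nested function decorated with jit_compile=True. The name xla_part is hypothetical, chosen for illustration.

```python
import numpy as np
import tensorflow as tf


def my_numpy_func(x):
    # Arbitrary NumPy/SciPy code; runs eagerly in the Python interpreter.
    return np.sinh(x)


# Hypothetical helper: contains only TensorFlow ops, so XLA can compile it.
@tf.function(jit_compile=True)
def xla_part(y):
    return y * y


# The outer function is NOT jit-compiled, so it may contain tf.py_function.
@tf.function(input_signature=[tf.TensorSpec(None, tf.float32)])
def tf_function(x):
    y = tf.py_function(my_numpy_func, [x], tf.float32)
    return xla_part(y)


a = tf_function(tf.constant(1.))
print(a)
```

With this split, the py_function call still runs in Python between compiled steps, so only part of the computation benefits from XLA. Whether that satisfies the original requirement depends on why jit_compile=True is needed in the first place.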

The error is telling you that XLA_GPU_JIT has no registered kernel (XLA lowering) for py_function. Is there a reason you need to use jit_compile=True?
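If jit_compile=True on the whole function is a hard requirement, the only option is to remove tf.py_function by expressing the NumPy/SciPy logic with TensorFlow ops that have XLA kernels. A minimal sketch, assuming tf.math.sinh is an acceptable stand-in for np.sinh; the real SciPy calls would each need a TensorFlow equivalent, which may not exist for every function.

```python
import tensorflow as tf


# Same computation as the question, but tf.math.sinh replaces np.sinh.
# Every op here is a TensorFlow op with an XLA kernel, so jit_compile=True works.
@tf.function(input_signature=[tf.TensorSpec(None, tf.float32)], jit_compile=True)
def tf_function(x):
    y = tf.math.sinh(x)
    return y * y


a = tf_function(tf.constant(1.))
print(a)
```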


Just to add to the jit_compile=True discussion: this recently published video has a lot of information: