I have to use tf.function(jit_compile=true), but it seems contradicts with tf.py_function, sending out error message No registered ‘EagerPyFunc’ OpKernel for XLA_GPU_JIT devices compatible with node {{node EagerPyFunc}}){{node EagerPyFunc}}
import tensorflow as tf
import numpy as np
def my_numpy_func(x): # This function must be numpy function, because it involves lots of scipy operations
# tf.function
return np.sinh(x)
# @tf.function(input_signature=[tf.TensorSpec(None, tf.float32)]) # ok, but I have to use jit_compile parameter
@tf.function(input_signature=[tf.TensorSpec(None, tf.float32)],jit_compile=True) # error because of jit_compile parameter
def tf_function(input):
y = tf.py_function(my_numpy_func, [input], tf.float32)
return y * y
a = tf_function(tf.constant(1.))
print(a)
error message
Traceback (most recent call last):
File "test10.py", line 15, in <module>
a = tf_function(tf.constant(1.))
File "/home/sjtusmartboy/.local/lib/python3.6/site-packages/tensorflow/python/eager/def_function.py", line 885, in __call__
result = self._call(*args, **kwds)
File "/home/sjtusmartboy/.local/lib/python3.6/site-packages/tensorflow/python/eager/def_function.py", line 957, in _call
filtered_flat_args, self._concrete_stateful_fn.captured_inputs) # pylint: disable=protected-access
File "/home/sjtusmartboy/.local/lib/python3.6/site-packages/tensorflow/python/eager/function.py", line 1964, in _call_flat
ctx, args, cancellation_manager=cancellation_manager))
File "/home/sjtusmartboy/.local/lib/python3.6/site-packages/tensorflow/python/eager/function.py", line 596, in call
ctx=ctx)
File "/home/sjtusmartboy/.local/lib/python3.6/site-packages/tensorflow/python/eager/execute.py", line 60, in quick_execute
inputs, attrs, num_outputs)
tensorflow.python.framework.errors_impl.InvalidArgumentError: Detected unsupported operations when trying to compile graph __inference_tf_function_8[_XlaMustCompile=true,config_proto="\n\007\n\003CPU\020\001\n\007\n\003GPU\020\0012\005*\0010J\0008\001\202\001\000",executor_type=""] on XLA_GPU_JIT: EagerPyFunc (No registered 'EagerPyFunc' OpKernel for XLA_GPU_JIT devices compatible with node {{node EagerPyFunc}}){{node EagerPyFunc}}
The op is created at:
File "test10.py", line 15, in <module>
a = tf_function(tf.constant(1.))
File "test10.py", line 12, in tf_function
y = tf.py_function(my_numpy_func, [input], tf.float32) [Op:__inference_tf_function_8]