Slow FFT in TensorFlow compared to NumPy and JAX

Is it possible to integrate PocketFFT into TensorFlow for CPU execution, so that functions like tf.signal.stft and tf.signal.inverse_stft can take advantage of it?

Currently, TensorFlow’s CPU FFT uses Eigen’s FFT, which in the benchmark below is roughly 2.5x slower than NumPy and almost 4x slower than JAX, both of which use PocketFFT.
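Why the FFT backend matters so much for tf.signal.stft: an STFT is essentially framing, windowing, and one real FFT per frame, so the batched rfft dominates the cost. Below is a minimal NumPy sketch, assuming a periodic Hann window, no end padding, and fft_length defaulting to the smallest power of two that encloses frame_length (intended to mirror tf.signal.stft's documented defaults, not to reproduce TensorFlow's implementation). The function name `stft` here is illustrative.

```python
import numpy as np


def stft(signal, frame_length, frame_step, fft_length=None):
    """Frame, window and rfft a 1-D signal (no end padding)."""
    if fft_length is None:
        # Smallest power of two >= frame_length.
        fft_length = 1 << (frame_length - 1).bit_length()
    n_frames = 1 + (len(signal) - frame_length) // frame_step
    # Row i selects samples [i * frame_step, i * frame_step + frame_length).
    idx = np.arange(frame_length)[None, :] + frame_step * np.arange(n_frames)[:, None]
    frames = signal[idx]
    # Periodic Hann window: 0.5 - 0.5 * cos(2*pi*n / N).
    n = np.arange(frame_length)
    window = 0.5 - 0.5 * np.cos(2 * np.pi * n / frame_length)
    # One real FFT per frame; this batched call is where the FFT backend's speed shows.
    return np.fft.rfft(frames * window, n=fft_length, axis=-1)


if __name__ == "__main__":
    x = np.random.default_rng(0).standard_normal(16000)
    print(stft(x, frame_length=400, frame_step=160).shape)
```

Every frame goes through the same rfft call, so a 2.5x to 4x gap in the FFT itself carries over almost directly to STFT-heavy pipelines.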

There are many more details in the GitHub issue “tf.signal CPU FFT implementation is slower than NumPy, PyTorch, etc.” (tensorflow/tensorflow#6541).

I’m sure many projects would benefit from this investment, considering that much of today’s speech and music work uses STFT data.

import timeit
import tensorflow as tf

print("seconds (lower is better):")
print(f"Tensorflow {tf.__version__}", timeit.timeit('X = tf.signal.rfft(x)', setup='import tensorflow as tf; x = tf.random.normal([50000, 512])', number=10))
print(f"Tensorflow {tf.__version__}, double precision", timeit.timeit('X = tf.cast(tf.signal.rfft(tf.cast(x, tf.float64)), tf.complex64)', setup='import tensorflow as tf; x = tf.random.normal([50000, 512])', number=10))
print("Numpy: ", timeit.timeit('X = numpy.fft.rfft(x)', setup='import numpy.fft; import tensorflow as tf; x = tf.random.normal([50000, 512])', number=10))
print("Jax: ", timeit.timeit('jnp.fft.rfft(x).block_until_ready()', setup='import jax.numpy as jnp; import tensorflow as tf; x = tf.random.normal([50000, 512]).numpy()', number=10))

seconds (lower is better):
Tensorflow 2.9.1 5.495112890999991
Tensorflow 2.9.1, double precision 7.629201937000033
Numpy: 2.1803204349999987
Jax: 1.4081462569999985
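When re-running a comparison like this, reporting the best of several timing runs (per call) reduces the influence of one-off warm-up costs. A minimal sketch using only the standard timeit module; `best_of` is a hypothetical helper name, and the repeat/number defaults are kept small so it finishes quickly.

```python
import timeit


def best_of(stmt, setup="pass", repeat=3, number=3):
    """Best per-call time in seconds over `repeat` timing runs of `number` calls each."""
    # The minimum is the run least disturbed by warm-up, caching or other processes.
    times = timeit.repeat(stmt, setup=setup, repeat=repeat, number=number)
    return min(times) / number


# Same input shape as the benchmark above, created directly as float32 NumPy data.
SETUP = (
    "import numpy as np; "
    "x = np.random.default_rng(0).standard_normal((50000, 512), dtype=np.float32)"
)

if __name__ == "__main__":
    print("NumPy rfft, best seconds per call:", best_of("np.fft.rfft(x)", SETUP))
```

One caveat when reading the numbers: NumPy releases before 2.0 always compute np.fft in double precision and return complex128, so the NumPy row is not exactly like-for-like with TensorFlow’s single-precision rfft.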

TL;DR: it’s complicated. For individual matrix operations on CPU, JAX is often slower than NumPy, but JIT-compiled sequences of operations in JAX are often faster than NumPy, and once you move to GPU/TPU, JAX will generally be much faster than NumPy.
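To illustrate the point about JIT-compiled sequences, here is a sketch assuming JAX is installed: under jax.jit, framing, windowing, rfft and the magnitude are traced into a single XLA computation instead of being dispatched one op at a time. FRAME_LENGTH, FRAME_STEP and `stft_magnitude` are hypothetical names chosen for illustration.

```python
import jax
import jax.numpy as jnp
import numpy as np

FRAME_LENGTH = 512  # hypothetical frame size
FRAME_STEP = 128    # hypothetical hop size


@jax.jit
def stft_magnitude(signal):
    # Shapes are static under jit, so n_frames is a plain Python int here.
    n_frames = 1 + (signal.shape[-1] - FRAME_LENGTH) // FRAME_STEP
    idx = jnp.arange(FRAME_LENGTH)[None, :] + FRAME_STEP * jnp.arange(n_frames)[:, None]
    frames = signal[idx]
    # Periodic Hann window.
    window = 0.5 - 0.5 * jnp.cos(2 * jnp.pi * jnp.arange(FRAME_LENGTH) / FRAME_LENGTH)
    return jnp.abs(jnp.fft.rfft(frames * window, axis=-1))


x = np.random.default_rng(0).standard_normal(16000).astype(np.float32)
# The first call compiles; later calls with the same shape reuse the compiled program.
mag = stft_magnitude(x).block_until_ready()
print(mag.shape)
```

As in the benchmark above, block_until_ready() is needed when timing JAX, because its dispatch is asynchronous.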

@Ramyar_Jahani I noticed an almost 3x speed improvement with both JAX and NumPy on CPU compared to TensorFlow.
I’m able to wrap NumPy’s FFT in a tf.py_function and call JAX’s FFT through jax2tf.convert.
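For reference, the tf.py_function route might look like the sketch below. The wrapper name `numpy_rfft` is hypothetical, and it assumes the last dimension is statically known. Note that the wrapped NumPy call executes in the Python interpreter rather than as a TensorFlow kernel, even when traced inside tf.function.

```python
import numpy as np
import tensorflow as tf


def numpy_rfft(x):
    """rfft over the last axis, computed by NumPy (PocketFFT) via tf.py_function."""

    def _np_rfft(t):
        # Inside py_function, t is an eager tensor; .numpy() hands the data to NumPy.
        # NumPy may return complex128, so cast to match tf.signal.rfft on float32 input.
        return np.fft.rfft(t.numpy(), axis=-1).astype(np.complex64)

    y = tf.py_function(func=_np_rfft, inp=[x], Tout=tf.complex64)
    # py_function drops static shape info; check and restore it (assumes a known last dim).
    return tf.ensure_shape(y, x.shape[:-1] + [x.shape[-1] // 2 + 1])


x = tf.random.normal([1000, 512])
eager = numpy_rfft(x)
graph = tf.function(numpy_rfft)(x)  # also traceable inside tf.function
print(eager.shape, graph.shape)
```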

The problem is that, although these functions work in eager mode and give reasonable speedups in Python on CPU, there’s no way to serialize them into a protobuf for production.