Is it possible to integrate PocketFFT into TensorFlow as the CPU FFT backend, so that functions like tf.signal.stft and tf.signal.inverse_stft can leverage it?
Currently, TensorFlow's CPU FFT is backed by Eigen's FFT, which is almost 3x slower than NumPy and JAX, both of which use PocketFFT.
There are heaps more details here: tf.signal CPU FFT implementation is slower than NumPy, PyTorch, etc. · Issue #6541 · tensorflow/tensorflow · GitHub
I’m sure many projects would benefit from this investment, considering how much of what’s done for speech and music these days relies on STFT data.
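For context on why the rfft path matters so much here: roughly speaking (this is the textbook decomposition, not TensorFlow's exact internals), an STFT boils down to framing the signal, applying a window, and running an rfft on every frame, so the FFT kernel is the hot spot for these workloads. The stft_sketch name below is just for illustration:

import tensorflow as tf

def stft_sketch(signal, frame_length=512, frame_step=128):
    # Rough sketch of what an STFT reduces to: frame the signal,
    # apply a window, then take an rfft of every frame. The rfft
    # is the CPU hot spot benchmarked below.
    frames = tf.signal.frame(signal, frame_length, frame_step)
    window = tf.signal.hann_window(frame_length)
    return tf.signal.rfft(frames * window)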
import timeit
import tensorflow as tf

# Benchmark: 10 runs of a length-512 rfft over a batch of 50000 rows on CPU.
print("seconds (lower is better):")
print(f"Tensorflow {tf.__version__}", timeit.timeit('X = tf.signal.rfft(x)', setup='import tensorflow as tf; x = tf.random.normal([50000, 512])', number=10))
print(f"Tensorflow {tf.__version__}, double precision", timeit.timeit('X = tf.cast(tf.signal.rfft(tf.cast(x, tf.float64)), tf.complex64)', setup='import tensorflow as tf; x = tf.random.normal([50000, 512])', number=10))
print("Numpy: ", timeit.timeit('X = numpy.fft.rfft(x)', setup='import numpy.fft; import tensorflow as tf; x = tf.random.normal([50000, 512])', number=10))
print("Jax: ", timeit.timeit('jnp.fft.rfft(x).block_until_ready()', setup='import jax.numpy as jnp; import tensorflow as tf; x = tf.random.normal([50000, 512]).numpy()', number=10))
seconds (lower is better):
Tensorflow 2.9.1 5.495112890999991
Tensorflow 2.9.1, double precision 7.629201937000033
Numpy: 2.1803204349999987
Jax: 1.4081462569999985
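In case anyone else is blocked on this in the meantime, here is a minimal sketch of a stopgap I'd consider: routing the transform through scipy.fft (which is backed by PocketFFT) via tf.numpy_function. The pocketfft_rfft name is just for illustration; this breaks gradients and graph serialization and pays a host round-trip, so it's a workaround, not a substitute for a native PocketFFT backend in tf.signal.

import tensorflow as tf
import scipy.fft

def pocketfft_rfft(x):
    # Stopgap only: hop out to SciPy's PocketFFT-backed rfft on the host.
    # tf.numpy_function is not differentiable and not serializable into
    # SavedModels, and the host round-trip has its own cost.
    return tf.numpy_function(
        lambda a: scipy.fft.rfft(a, axis=-1).astype("complex64"),
        [x], tf.complex64)

x = tf.random.normal([50000, 512])
X = pocketfft_rfft(x)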