Positional Encoding Speedup

I’m working on a Transformer-based model, and I followed the great positional-encoding example from:

Since the original implementation relies heavily on NumPy, I’ve written a pure TF variant that runs about 100 times faster (timings below).

Hope it helps others too:

import tensorflow as tf

def get_angles(pos, i, d_model):
    # Angle = pos / 10000^(2*(i//2) / d_model), as in the original Transformer paper.
    return pos * (1 / tf.math.pow(10000.0, (2 * (i // 2)) / d_model))

@tf.function
def positional_encoding(pos, d_model):
    # Angle matrix of shape (pos, d_model); d_model is assumed to be even.
    angle_rads = get_angles(tf.range(pos, dtype=tf.float32)[:, tf.newaxis],
                            tf.range(d_model, dtype=tf.float32)[tf.newaxis, :],
                            d_model)
    # sin on the even columns, cos on the odd columns, interleaved back
    # together into shape (1, pos, d_model).
    return tf.reshape(tf.concat([tf.expand_dims(tf.sin(angle_rads[:, ::2]), axis=-1),
                                 tf.expand_dims(tf.cos(angle_rads[:, 1::2]), axis=-1)],
                                axis=-1),
                      [1, pos, -1])

n, d = 2048, 512
%timeit pos_encoding = positional_encoding(n, d)
137 µs ± 2.2 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)
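
In case it helps, here’s a rough sketch of how the table can be added to token embeddings in a model (the add_positional_encoding / embeddings names are just for illustration, not from the tutorial):

import tensorflow as tf

# Build the table once, up to the maximum sequence length the model will see.
pos_encoding = positional_encoding(2048, 512)

def add_positional_encoding(embeddings):
    # embeddings: (batch, seq_len, d_model); slice the table to seq_len and add it.
    seq_len = tf.shape(embeddings)[1]
    return embeddings + pos_encoding[:, :seq_len, :]

x = tf.random.uniform((8, 100, 512))     # dummy batch of embeddings
print(add_positional_encoding(x).shape)  # (8, 100, 512)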

Sorry, I omitted the link; the tutorial path is text/tutorials/transformer.

For comparison, the original timing of the tutorial implementation is:

%timeit pos_encoding = positional_encoding(n, d)
17.8 ms ± 582 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
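
For reference, the tutorial’s NumPy-based version looks roughly like this (reproduced from memory, so double-check it against the tutorial itself); the allclose check at the end confirms the pure-TF version above produces the same table up to float32 rounding:

import numpy as np
import tensorflow as tf

def get_angles_np(pos, i, d_model):
    # Same inverse-frequency formula, computed with NumPy.
    angle_rates = 1 / np.power(10000, (2 * (i // 2)) / np.float32(d_model))
    return pos * angle_rates

def positional_encoding_np(position, d_model):
    angle_rads = get_angles_np(np.arange(position)[:, np.newaxis],
                               np.arange(d_model)[np.newaxis, :],
                               d_model)
    # sin on even indices, cos on odd indices (in place).
    angle_rads[:, 0::2] = np.sin(angle_rads[:, 0::2])
    angle_rads[:, 1::2] = np.cos(angle_rads[:, 1::2])
    return tf.cast(angle_rads[np.newaxis, ...], dtype=tf.float32)

n, d = 2048, 512
# Expect True; the tolerance allows for float32 vs float64 rounding in the angles.
print(np.allclose(positional_encoding(n, d).numpy(),
                  positional_encoding_np(n, d).numpy(),
                  atol=1e-3))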