Implement MultiHeadAttention() in a simple Model

Hello,

I am trying to implement the MultiHeadAttention layer in a small model.

I would like to build the equivalent of self-attention with this layer, in a model similar to this one:

import tensorflow as tf

inp = tf.keras.layers.Input((10, 64))
layer = tf.keras.layers.AdditiveAttention()([inp, inp])  # [query, value]
model = tf.keras.Model(inputs=inp, outputs=layer)

but replacing the AdditiveAttention with MultiHeadAttention like this:

inp = tf.keras.layers.Input((10, 64))
layer = tf.keras.layers.MultiHeadAttention(num_heads=2, key_dim=32)([inp, inp, inp])
model = tf.keras.Model(inputs=inp, outputs=layer)

However, when I do this I get the error “call() missing 1 required positional argument: ‘value’”.

I have done some research on how the MultiHeadAttention layer works, but I am not sure what the key_dim and value_dim parameters are.
I would have thought that key_dim was used to change the output shape, but when you change its value in the example from the “tf.keras.layers.MultiHeadAttention  |  TensorFlow v2.16.1” documentation, this is not the case.

layer = tf.keras.layers.MultiHeadAttention(num_heads=2, key_dim=32)
target = tf.keras.Input(shape=[8, 16])
source = tf.keras.Input(shape=[4, 16])
output_tensor, weights = layer(target, source,
                               return_attention_scores=True)
print(output_tensor.shape)  # (None, 8, 16)
print(weights.shape)        # (None, 2, 8, 4)
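
For example, if I lower key_dim, the output shape stays the same as far as I can tell; the shapes below are what I would expect, assuming the output width simply follows the query's last dimension:

import tensorflow as tf

# Same example as above, but with a smaller key_dim
layer_small = tf.keras.layers.MultiHeadAttention(num_heads=2, key_dim=4)
target = tf.keras.Input(shape=[8, 16])
source = tf.keras.Input(shape=[4, 16])
output_tensor, weights = layer_small(target, source,
                                     return_attention_scores=True)
print(output_tensor.shape)  # still (None, 8, 16): follows the query's last dimension
print(weights.shape)        # still (None, 2, 8, 4): (batch, num_heads, query_len, key_len)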

Thanks in advance


Hi @coco, instead of passing the query, key, and value arguments in a list, pass them individually:

layer = tf.keras.layers.MultiHeadAttention(num_heads=2, key_dim=32)(inp, inp, inp)

which will not produce that error. Please refer to this gist for a working code example. Thank you.
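
Putting that into your model, a minimal self-attention sketch (assuming TF 2.x and the same shapes as in your post) would look like this:

import tensorflow as tf

inp = tf.keras.layers.Input((10, 64))
# Pass query, value (and optionally key) as separate arguments; using inp for all three gives self-attention
attn = tf.keras.layers.MultiHeadAttention(num_heads=2, key_dim=32)(inp, inp, inp)
model = tf.keras.Model(inputs=inp, outputs=attn)
model.summary()  # the attention output shape should be (None, 10, 64)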