`key_dim` in multihead attention layer

Hey all,

I am looking at the documentation of the MultiHeadAttention layer, and I do not really understand the use of the `key_dim` parameter.

In the doc it says:

key_dim: Size of each attention head for query and key.

Thanks in advance :slight_smile:

Hi @ariG23498, `key_dim` is the size of the query/key projection for each attention head, i.e. the vector length each head actually works with when computing attention scores. Following the original Transformer convention, it is usually set to `embed_dim / num_heads`: for example, with an embedding dimension of 10 and `num_heads` of 5, `key_dim` would be 2. Note that in Keras the parameter is independent of the embedding dimension, so this relation is a common choice rather than a requirement. Thank You.
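To make this concrete, here is a minimal sketch using `tf.keras.layers.MultiHeadAttention` with the toy numbers above (embedding dimension 10, `num_heads=5`, `key_dim=2`). The input tensor `x` and its shapes are illustrative assumptions, not anything from the docs:

```python
import tensorflow as tf

# Toy batch: 4 sequences of length 8, each token embedded in 10 dimensions.
x = tf.random.normal((4, 8, 10))

# Each of the 5 heads projects queries and keys down to key_dim=2,
# matching the embed_dim / num_heads convention (10 / 5 = 2).
mha = tf.keras.layers.MultiHeadAttention(num_heads=5, key_dim=2)

# Self-attention: query and value are the same tensor here.
out = mha(query=x, value=x)

# The final output projection maps back to the query's last dimension,
# so the output shape matches the input: (4, 8, 10).
print(out.shape)
```

You can change `key_dim` independently of the embedding dimension (e.g. `key_dim=16` with the same input) and the layer still works; only the per-head projection size, and hence the parameter count, changes.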