Hey all,
I am looking at the documentation of the MultiHeadAttention layer, and I do not really understand the use of the key_dim parameter.
In the doc it says:
key_dim: Size of each attention head for query and key.
Thanks in advance
Hi @ariG23498, key_dim is the dimension of the key (and query) projection inside each head; it determines the vector length each head works with. By the usual Transformer convention you pick key_dim = embed_dim / num_heads, so with an embedding_dim of 10 and num_heads of 5, key_dim would be 2. Note that Keras does not enforce this relationship: key_dim is a free parameter, and the layer projects the concatenated heads back to the query's last dimension regardless. Thank you.
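Here is a minimal sketch using tf.keras.layers.MultiHeadAttention (the input shapes are made up for illustration) showing both the conventional choice of key_dim and a deliberately different one:

```python
import numpy as np
import tensorflow as tf

# Hypothetical shapes for illustration: a batch of 4 sequences of
# length 8, with embed_dim 10 split across 5 heads -> key_dim 2.
embed_dim, num_heads = 10, 5
x = np.random.rand(4, 8, embed_dim).astype("float32")

mha = tf.keras.layers.MultiHeadAttention(
    num_heads=num_heads, key_dim=embed_dim // num_heads
)
out = mha(query=x, value=x)  # self-attention: query and value are the same
print(out.shape)             # (4, 8, 10) -- matches the query shape

# key_dim is not forced to be embed_dim / num_heads; the layer
# projects the concatenated heads back to the query's last
# dimension either way.
mha_wide = tf.keras.layers.MultiHeadAttention(num_heads=num_heads, key_dim=16)
print(mha_wide(query=x, value=x).shape)  # still (4, 8, 10)
```

Both layers produce output with the same shape as the query; changing key_dim only changes the size of the per-head projections (and hence the parameter count), not the output shape.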