MultiHeadAttention output shape & use_causal_mask

A couple of silly questions:

Let’s say I am using MultiHeadAttention in an encoder-decoder architecture. In this case, the decoder generates a single output token at a time. If that is correct, then why does the documentation say this:

attention_output: The result of the computation, of shape (B, T, E), where T is for target sequence shapes and E is the query input last dimension

Is T == 1?

Also, how does use_causal_mask work? During inference, in order for TF to know what to mask, it has to know the current time step (how many output tokens have already been generated) so it can mask the rest, right? Where does it get that information?

Hi @dimanne,

As far as I know, the decoder generates one output token at a time during inference, but T is not fixed to 1: it is simply the length of the query sequence you pass to the layer. During training with teacher forcing, the entire target sequence is fed in at once, so T equals the target length; during autoregressive inference, you re-run the decoder on the tokens generated so far, so T grows by one at each step.
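
Here is a minimal sketch showing that attention_output follows the length of the query you pass in (the shapes B=2, T=5, S=7, E=8 and the random weights are my own toy assumptions, just for illustration):

```python
import tensorflow as tf

B, T, E = 2, 5, 8  # batch, target (query) length, model dim -- arbitrary toy values
S = 7              # source (encoder) length

mha = tf.keras.layers.MultiHeadAttention(num_heads=2, key_dim=4)

# Decoder-style cross-attention: the query comes from the target sequence,
# key/value come from the encoder output.
query = tf.random.normal((B, T, E))    # target sequence fed to the decoder
enc_out = tf.random.normal((B, S, E))  # encoder output

attention_output = mha(query=query, value=enc_out)
print(attention_output.shape)  # (2, 5, 8) == (B, T, E): T tracks the query length
```

Feed a query of length 1 and you get T == 1 back; feed the whole target sequence and T is its full length.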

Causal self-attention ensures that the output at each position depends only on the previous positions, not the future ones. Kindly refer to this documentation for more information. The layer does not keep any decoder state or step counter: when use_causal_mask=True, it builds a lower-triangular (T, T) mask from the shape of the query tensor at call time, so during inference the mask simply covers however many tokens you have generated and passed in so far.
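
As a sanity check, here is a small sketch (again with toy shapes and random weights of my own choosing): with the causal mask on, the output at position i cannot see positions > i, so feeding only the first 3 tokens gives the same first 3 outputs as feeding all 5. That is why no explicit time step needs to be passed.

```python
import numpy as np
import tensorflow as tf

E = 8
mha = tf.keras.layers.MultiHeadAttention(num_heads=2, key_dim=4)

# Self-attention over a full 5-token sequence with the causal mask on.
x = tf.random.normal((1, 5, E))
full = mha(query=x, value=x, use_causal_mask=True)

# "Inference" with only the first 3 tokens generated so far: the mask is
# rebuilt from this shorter sequence; the layer stores no step counter.
prefix = x[:, :3, :]
part = mha(query=prefix, value=prefix, use_causal_mask=True)

# Position i attends only to positions <= i, so the first 3 output
# positions match whether we feed 3 tokens or all 5.
np.testing.assert_allclose(part.numpy(), full[:, :3, :].numpy(),
                           rtol=1e-5, atol=1e-5)
print("first 3 positions match")

# The mask itself is just lower-triangular ones, e.g. for T = 4:
print(tf.linalg.band_part(tf.ones((4, 4)), -1, 0))
```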

Hope this clarifies. Thank you.