Discrepancy between code and schematic figure in TensorFlow tutorial?

I am implementing an encoder-decoder transformer for language translation and using the TensorFlow tutorial as a guide: Transformer model for language understanding  |  Text  |  TensorFlow

I am confused about one specific point: where the MultiHeadAttention class is created. There seems to be a discrepancy between the schematic figure and the code. In the figure, q, k, and v are first split and then run through the Dense layers. In the code, however, they seem to be run through the Dense layers first and only split afterwards.

This is a significant difference, right?

Thanks :slight_smile:

Hey @markdaoust, can you help here?

Hi! Assuming you’re pointing to the following:

[Image: multi-head attention (MHA) diagram]

  • The corresponding MHA explanation:

    “Each multi-head attention block gets three inputs; Q (query), K (key), V (value). These are put through linear (Dense) layers and split up into multiple heads.”

  • In the code, Q, K, and V first go through the Dense layers, like this:

    class MultiHeadAttention(tf.keras.layers.Layer):
      def __init__(self, d_model, num_heads):
        super(MultiHeadAttention, self).__init__()
        self.num_heads = num_heads
        self.d_model = d_model
    ...
        self.wq = tf.keras.layers.Dense(d_model)
        self.wk = tf.keras.layers.Dense(d_model)
        self.wv = tf.keras.layers.Dense(d_model)
    ...
      def call(self, v, k, q, mask):
        batch_size = tf.shape(q)[0]
    
        q = self.wq(q)  # (batch_size, seq_len, d_model)
        k = self.wk(k)  # (batch_size, seq_len, d_model)
        v = self.wv(v)  # (batch_size, seq_len, d_model)
    
    ...
    
  • Only then are they split into multiple heads, like this:

    class MultiHeadAttention(tf.keras.layers.Layer):
    ...
      def call(self, v, k, q, mask):
    ...
        q = self.split_heads(q, batch_size)  # (batch_size, num_heads, seq_len_q, depth)
        k = self.split_heads(k, batch_size)  # (batch_size, num_heads, seq_len_k, depth)
        v = self.split_heads(v, batch_size)  # (batch_size, num_heads, seq_len_v, depth)
    

Is this correct?

num_heads corresponds to h in the diagram, if I understand correctly:

[Image: multi-head attention (MHA) diagram]

  def split_heads(self, x, batch_size):
    """Split the last dimension into (num_heads, depth).
    Transpose the result such that the shape is (batch_size, num_heads, seq_len, depth)
    """
    x = tf.reshape(x, (batch_size, -1, self.num_heads, self.depth))
    return tf.transpose(x, perm=[0, 2, 1, 3])
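
To see how `split_heads` relates to the figure, here is a minimal NumPy sketch that mirrors the reshape and transpose above (NumPy stands in for TensorFlow, and `split_heads_np` is a hypothetical helper, not part of the tutorial). It shows that head `i` simply receives columns `i*depth` to `(i+1)*depth` of the Dense output.

```python
import numpy as np

def split_heads_np(x, num_heads):
    """Hypothetical NumPy mirror of the tutorial's split_heads.

    x: (batch_size, seq_len, d_model) -> (batch_size, num_heads, seq_len, depth)
    """
    batch_size, seq_len, d_model = x.shape
    assert d_model % num_heads == 0, "d_model must be divisible by num_heads"
    depth = d_model // num_heads
    # Split the last axis into (num_heads, depth), row-major like tf.reshape.
    x = x.reshape(batch_size, seq_len, num_heads, depth)
    # Move the head axis in front of the sequence axis.
    return x.transpose(0, 2, 1, 3)

batch_size, seq_len, d_model, num_heads = 2, 5, 8, 4
depth = d_model // num_heads
q = np.arange(batch_size * seq_len * d_model, dtype=float).reshape(batch_size, seq_len, d_model)

heads = split_heads_np(q, num_heads)
print(heads.shape)  # (2, 4, 5, 2)

# Head i is exactly a contiguous slice of the last axis of the projected tensor.
for i in range(num_heads):
    assert np.array_equal(heads[:, i], q[:, :, i * depth:(i + 1) * depth])
```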

So maybe the diagram does show the linear/Dense → split order rather than split → linear/Dense. Does that make sense? Let’s also loop in @markdaoust.

If you work through it, you’ll see that the two are equivalent. Splitting the output of a single Dense layer into num_heads chunks is the same as creating num_heads separate Dense layers (except maybe for the initialization). Draw the weight matrices to see.
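
Following the suggestion to draw the weight matrices, here is a minimal NumPy sketch (random weights, biases omitted, NumPy in place of TensorFlow; all names are illustrative). It compares one big `d_model × d_model` projection followed by a split (the code's order) against `num_heads` separate `d_model × depth` projections, one per head as drawn in the figure, where each small kernel is a column block of the big one.

```python
import numpy as np

rng = np.random.default_rng(0)
batch_size, seq_len, d_model, num_heads = 2, 5, 8, 4
depth = d_model // num_heads

x = rng.normal(size=(batch_size, seq_len, d_model))
W_big = rng.normal(size=(d_model, d_model))  # one big Dense kernel (bias omitted)

# Code order: one Dense layer, then split the output into heads.
projected = x @ W_big  # (batch, seq, d_model)
split_after = projected.reshape(batch_size, seq_len, num_heads, depth).transpose(0, 2, 1, 3)

# Figure order: one small Dense layer per head, each applied to the full input.
# Each small kernel is the corresponding column block of W_big.
per_head = np.stack(
    [x @ W_big[:, i * depth:(i + 1) * depth] for i in range(num_heads)],
    axis=1,
)  # (batch, num_heads, seq, depth)

print(np.allclose(split_after, per_head))  # True
```

The remaining difference is initialization: Keras `Dense` defaults to Glorot-uniform, whose range depends on fan-in plus fan-out, and those differ between one `(d_model, d_model)` kernel and several `(d_model, depth)` kernels.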

The code will be clearer when we can switch this to use https://www.tensorflow.org/api_docs/python/tf/keras/layers/experimental/EinsumDense (and maybe einops), but we try to avoid using experimental symbols in the tutorials.
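
As a rough illustration of the einsum-style formulation, here is a NumPy sketch that stores the kernel as `(d_model, num_heads, depth)` so the projection produces per-head outputs directly, without a separate reshape. It only approximates what a layer like `EinsumDense` would compute with an equation along the lines of `"abc,cde->abde"`; the shapes and names are assumptions for illustration, and `einops` is not used.

```python
import numpy as np

rng = np.random.default_rng(1)
batch_size, seq_len, d_model, num_heads = 2, 5, 8, 4
depth = d_model // num_heads

x = rng.normal(size=(batch_size, seq_len, d_model))
# Kernel shaped (d_model, num_heads, depth): one block of output columns per head.
kernel = rng.normal(size=(d_model, num_heads, depth))

# b=batch, t=seq, d=d_model, h=head, k=depth
heads = np.einsum("btd,dhk->bthk", x, kernel)  # (batch, seq, num_heads, depth)

# Same result as a flat (d_model, d_model) kernel followed by a reshape.
flat = (x @ kernel.reshape(d_model, d_model)).reshape(batch_size, seq_len, num_heads, depth)
print(np.allclose(heads, flat))  # True
```

The tutorial's `split_heads` would additionally transpose to `(batch, num_heads, seq, depth)`; the einsum form just makes the per-head structure of the single big kernel explicit.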

Thanks for your response! You restated my question perfectly. However, I still don’t see how the diagram shows the inputs going through a linear/Dense layer first and only then being split. If that were the case, I would expect only three linear/Dense layers on the figure, not nine.

What am I missing here? :slight_smile:

Alright, I will definitely try that. Thanks! :slight_smile:

“I would expect only three linear/Dense layers on the figure, not nine. What am I missing here? :slight_smile:”

Nothing. The code is correct, and it’s simpler (and maybe more efficient) this way, with only one big linear layer for each of q, k, and v. I didn’t have a good way to edit the figure, and I can’t switch it to EinsumDense yet.
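
To make the “three versus nine” point concrete, here is a small arithmetic check showing that the figure’s stacked Linear boxes (num_heads per input, so 3 × num_heads in total) hold exactly as many parameters as the code’s three big Dense layers. The values `d_model=512` and `num_heads=8` are the paper’s base-model settings, used here only for illustration; `dense_params` is a hypothetical helper.

```python
def dense_params(in_dim, out_dim):
    """Parameter count of a Dense layer with bias: kernel plus bias."""
    return in_dim * out_dim + out_dim

d_model, num_heads = 512, 8   # illustrative base-model settings
depth = d_model // num_heads  # 64

# Code: one big Dense per input (q, k, v), so 3 layers.
code_params = 3 * dense_params(d_model, d_model)

# Figure: num_heads small Dense layers per input, so 3 * num_heads layers.
figure_params = 3 * num_heads * dense_params(d_model, depth)

print(code_params, figure_params)  # 787968 787968
```

Because `num_heads * depth == d_model`, the kernels and biases line up one-to-one; the code just stores them as one matrix per input.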

If something is missing here, it’s a sentence or two in the notebook explaining this. Any chance you could send a PR?

Alright, cool. Thanks for helping. I’d love to help as well by sending a PR, but I’m not sure what you mean :smiley:

:+1: I see what you mean. Note that the original diagram in the Attention Is All You Need paper is the same:

[Image: multi-head attention diagram from the Attention Is All You Need paper]

This diagram of the full Transformer architecture (all in one place), from a popular blog post, may be useful in case you haven’t seen it. Also this: Reformer: The Efficient Transformer.