Hello everyone,
I am having trouble debugging my code. I recently switched from PyTorch to TensorFlow, and I am trying to reproduce the model from the paper Attention Is All You Need from scratch in tf.keras. I am not using the predefined tf.keras.layers.MultiHeadAttention. Instead, I created my own version, modelled after the diagrams in the paper. The code is quite lengthy, so I will show it later; first I want to specify the problem. I am getting the following:
WARNING:tensorflow:Gradients do not exist for variables
['transformer_1/encoder_1/positional_embedding_layer_3/embedding_3/embeddings:0','transformer_1/encoder_1/encoder_layer_4/sub_module_20/multi_headed_attention_12/attention_12/dense_65/kernel:0',
'transformer_1/encoder_1/encoder_layer_4/sub_module_20/multi_headed_attention_12/attention_12/dense_65/bias:0',
.......
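To see the full list of affected variables, the gradients can also be checked by hand (a rough sketch; model, loss_fn and the batch tensors below are placeholders for my actual training setup, not code from the model):

with tf.GradientTape() as tape:
    pred = model((context_batch, target_batch))   # placeholders for one training batch
    loss = loss_fn(label_batch, pred)             # placeholder loss function
grads = tape.gradient(loss, model.trainable_variables)
for var, grad in zip(model.trainable_variables, grads):
    if grad is None:                              # these are the variables the warning lists
        print("no gradient:", var.name)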
When looking online, I found two possible reasons why this might happen:
1. I am converting some tensors to NumPy arrays and thus dropping their gradients.
2. I am not properly connecting the graph.
I looked carefully and I am sure that I am not converting anything from the graph to NumPy, so the problem must be case 2 (or something else that I don't know about yet). A minimal example of case 2 is sketched below.
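For reference, here is a toy model (not my code, just an illustration) where a tracked sub-layer never contributes to the loss; it produces exactly this kind of warning during fit():

import tensorflow as tf

class Toy(tf.keras.Model):
    def __init__(self):
        super().__init__()
        self.used = tf.keras.layers.Dense(4)
        self.unused = tf.keras.layers.Dense(4)  # tracked, but its output never reaches the loss

    def call(self, x):
        self.unused(x)        # computed and thrown away -> no gradient for its weights
        return self.used(x)

model = Toy()
model.compile(optimizer="adam", loss="mse")
model.fit(tf.random.normal((8, 3)), tf.random.normal((8, 4)), epochs=1, verbose=0)
# WARNING:tensorflow:Gradients do not exist for variables [... the unused Dense ...]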
I will show you the code and provide a brief explanation for some of the choices (it is quite long, but bear with me).
- In the paper they say: "We employ a residual connection [10] around each of the two sub-layers, followed by layer normalization [1]. That is, the output of each sub-layer is LayerNorm(x + Sublayer(x)), where Sublayer(x) is the function implemented by the sub-layer itself." So I modelled that as a layer:
import numpy as np
import tensorflow as tf
from tensorflow import keras

class SubModule(tf.keras.layers.Layer):
    def __init__(self, Sublayer):
        """Residual connection around Sublayer, followed by layer normalization.

        Args:
            Sublayer (keras.layers.Layer): a sub-layer of the encoder or decoder
        """
        super(SubModule, self).__init__()
        self.Sublayer = Sublayer
        self.LayerNorm = tf.keras.layers.LayerNormalization()
        self.Add = tf.keras.layers.Add()

    def call(self, x, k=None, v=None):
        if k is not None and v is not None:
            return self.LayerNorm(self.Add([x, self.Sublayer(x, k, v)]))
        else:
            return self.LayerNorm(self.Add([x, self.Sublayer(x)]))
The reason I have an if statement there is that, when I implement the cross-attention, I also need to pass in the encoder outputs. The plain branch can be exercised on its own, as shown below.
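For example (toy sizes, with a Dense layer standing in for a sub-layer, purely for illustration):

sub = SubModule(tf.keras.layers.Dense(8))
out = sub(tf.random.normal((2, 5, 8)))   # no k, v -> LayerNorm(x + Sublayer(x))
print(out.shape)                         # (2, 5, 8)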
- My implementation of the multi-head attention follows Figure 2 in the paper:
class Attention(tf.keras.layers.Layer):
    def __init__(self, dims, use_mask=False):
        """Scaled dot-product attention.

        Args:
            dims (int): the number of dimensions to attend to
            use_mask (bool, optional): apply a causal mask to the attention scores
        """
        super(Attention, self).__init__()
        self.d_k = tf.keras.layers.Dense(dims)
        self.d_q = tf.keras.layers.Dense(dims)
        self.d_v = tf.keras.layers.Dense(dims)
        self.dims = dims
        self.use_mask = use_mask
        self.last_attention_score = None

    def ScaledDotProductAttention(self, Q, V, K):
        score = tf.matmul(Q, K, transpose_b=True) / tf.math.sqrt(self.dims / 1)
        if self.use_mask:
            mask = self.compute_causal_mask(score.shape[-1])
            score = score + mask
        score = keras.activations.softmax(score)
        return tf.matmul(score, V)

    def compute_causal_mask(self, score_dims):
        mask = np.triu(np.ones((score_dims, score_dims)) * -np.inf, 1)
        return tf.cast(mask, dtype=tf.float32)

    def call(self, q, k, v):
        Q, K, V = self.d_q(q), self.d_k(q), self.d_v(q)
        return self.ScaledDotProductAttention(Q, K, V)
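As a quick check, this is what compute_causal_mask builds for a length-3 score matrix (plain NumPy, nothing model-specific):

import numpy as np
mask = np.triu(np.ones((3, 3)) * -np.inf, 1)
# [[  0., -inf, -inf],
#  [  0.,   0., -inf],
#  [  0.,   0.,   0.]]
# Adding this to the scores makes softmax give zero weight to future positions.

The multi-head wrapper around it looks like this: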
class MultiHeadedAttention(tf.keras.layers.Layer):
    def __init__(self, d_model, h, use_mask=False):
        """Multi-headed attention, following Figure 2 of the paper.

        Args:
            d_model (int): length of the embedding
            h (int): number of heads
            use_mask (bool, optional): apply a causal mask inside each head
        """
        super(MultiHeadedAttention, self).__init__()
        self.h = h
        self.heads = Attention(d_model / h, use_mask)
        self.reshape = tf.keras.layers.Reshape(target_shape=(-1, h, d_model // h))
        self.reverse = tf.keras.layers.Reshape(target_shape=(-1, d_model))
        self.WO = tf.keras.layers.Dense(d_model)

    def split(self, x):
        """Split the embedding into h heads.

        Args:
            x (tensor): a tensor of shape (batch_size, seq_length, d_model)
        Returns:
            (tensor): a tensor of shape (batch_size, h, seq_length, d_model // h)
        """
        return tf.transpose(self.reshape(x), perm=[0, 2, 1, 3])

    def concat(self, x):
        """Merge the heads back into the last dimension.

        Args:
            x (tensor): a tensor of shape (batch_size, h, seq_length, d_model // h)
        Returns:
            (tensor): a tensor of shape (batch_size, seq_length, d_model)
        """
        return self.reverse(tf.transpose(x, perm=[0, 2, 1, 3]))

    def call(self, q, k, v):
        q = self.split(q)
        k = self.split(k)
        v = self.split(v)
        x = self.heads(q, k, v)  # tensor of shape (batch_size, h, seq_len, d_model // h)
        x = self.concat(x)
        return self.WO(x)
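For reference, the split/concat round trip can be checked with plain ops (toy sizes, just to illustrate the reshapes; this is not part of the model):

import tensorflow as tf

batch, seq, d_model, h = 2, 5, 16, 4
x = tf.random.normal((batch, seq, d_model))
# (batch, seq, d_model) -> (batch, h, seq, d_model // h)
split = tf.transpose(tf.reshape(x, (batch, seq, h, d_model // h)), perm=[0, 2, 1, 3])
print(split.shape)  # (2, 4, 5, 4)
# ... and back again
merged = tf.reshape(tf.transpose(split, perm=[0, 2, 1, 3]), (batch, seq, d_model))
print(bool(tf.reduce_all(merged == x)))  # True, so no information is lost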
- The position-wise feed-forward network is simple and looks as follows:
class PositionWiseFeedForward(tf.keras.layers.Layer):
    def __init__(self, d_model, d_ff):
        super(PositionWiseFeedForward, self).__init__()
        self.l1 = tf.keras.layers.Dense(d_ff, activation="relu")
        self.l2 = tf.keras.layers.Dense(d_model)

    def call(self, x):
        return self.l2(self.l1(x))
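This is the FFN(x) = max(0, xW1 + b1)W2 + b2 from the paper; the ReLU on the first Dense layer provides the max(0, ·).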
- The encoder is a stack of N = 6 EncoderLayers:
class EncoderLayer(tf.keras.layers.Layer):
    def __init__(self, d_model, h, dff):
        super(EncoderLayer, self).__init__()
        self.fsm = SubModule(MultiHeadedAttention(d_model, h))
        self.ssm = SubModule(PositionWiseFeedForward(d_model, dff))

    def call(self, x):
        x = self.fsm(x, x, x)
        x = self.ssm(x)
        return x

class Encoder(tf.keras.layers.Layer):
    def __init__(self, N, d_model, max_seq_len, D, h, dff):
        super(Encoder, self).__init__()
        self.embedding = PositionalEmbeddingLayer(
            d_model=d_model,
            max_seq_len=max_seq_len,
            D=D)
        self.encoderStack = [
            EncoderLayer(
                d_model=d_model,
                h=h,
                dff=dff)
            for _ in range(N)]

    def call(self, x):
        posX = self.embedding(x)
        for i in range(len(self.encoderStack)):
            posX = self.encoderStack[i](posX)
        return posX
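Every EncoderLayer maps a (batch, seq_len, d_model) tensor to another tensor of the same shape, which is why the N layers can simply be chained in the loop above.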
- To save space, assume that the decoder is built from the same blocks as well. The final Transformer looks as follows:
class Transformer(keras.Model):
    def __init__(self, N, d_model, h, dff, max_seq_len, in_D, out_D):
        """The full encoder-decoder model.

        Args:
            N (int): Number of layers in the encoder and decoder
            d_model (int): Dimensions of the embedding
            h (int): Number of heads
            dff (int): Dimensions of the position-wise feed-forward network
            max_seq_len (int): Length of the sequence
            in_D (int): Size of the input vocabulary
            out_D (int): Size of the output vocabulary
        """
        super().__init__()
        self.encoder = Encoder(
            N=N,
            d_model=d_model,
            max_seq_len=max_seq_len,
            D=in_D,
            h=h,
            dff=dff)
        self.decoder = Decoder(
            N=N,
            d_model=d_model,
            max_seq_len=max_seq_len,
            D=out_D,
            h=h,
            dff=dff)

    def call(self, inputs):
        context, x = inputs
        z = self.encoder(context)
        y = self.decoder(x, z)
        # Something I saw on the internet to save memory
        try:
            del y._keras_mask
        except AttributeError:
            pass
        return y
The positional embedding is also a Keras layer; it transforms an input of shape (batch, seq_len) into (batch, seq_len, d_model). I did not paste it, but a rough sketch of what it does is below.
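Something along these lines (simplified; the learned position embedding here is an assumption, not my exact code):

class PositionalEmbeddingLayer(tf.keras.layers.Layer):
    def __init__(self, d_model, max_seq_len, D):
        super().__init__()
        self.token_emb = tf.keras.layers.Embedding(D, d_model)          # token ids -> vectors
        self.pos_emb = tf.keras.layers.Embedding(max_seq_len, d_model)  # assumed: learned positions

    def call(self, x):
        # x: (batch, seq_len) integer tokens -> (batch, seq_len, d_model)
        positions = tf.range(tf.shape(x)[1])
        return self.token_emb(x) + self.pos_emb(positions)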
Can anyone help me pinpoint the mistake?
Thank you in advance.