I have a custom layer, GlobalAttentionPoolingHead, which is meant to reduce the outputs from a transformer base model to a single vector. The theory can be seen at QARAC: Models and Corpora.
import keras
import tensorflow


class GlobalAttentionPoolingHead(keras.layers.Layer):

    def __init__(self):
        """
        Creates the layer

        Returns
        -------
        None.
        """
        super(GlobalAttentionPoolingHead, self).__init__()
        self.global_projection = None
        self.local_projection = None

    def build(self, input_shape):
        """
        Initialises layer weights

        Parameters
        ----------
        input_shape : tuple
            Shape of the input layer

        Returns
        -------
        None.
        """
        width = input_shape[-1]
        self.global_projection = self.add_weight('global projection', shape=(width, width))
        self.local_projection = self.add_weight('local projection', shape=(width, width))
        self.built = True

    def call(self, X, training=None):
        """
        Parameters
        ----------
        X : tensorflow.Tensor
            Base model vectors to apply pooling to.
        training : bool, optional
            Not used. The default is None.

        Returns
        -------
        tensorflow.Tensor
            The pooled value.
        """
        # Global projection of the summed input vectors, L2-normalised per sample
        gp = tensorflow.linalg.l2_normalize(tensorflow.tensordot(tensorflow.reduce_sum(X, axis=1),
                                                                 self.global_projection,
                                                                 axes=1),
                                            axis=1)
        # Local projection of each vector, L2-normalised along the feature axis
        lp = tensorflow.linalg.l2_normalize(tensorflow.tensordot(X,
                                                                 self.local_projection,
                                                                 axes=1),
                                            axis=2)
        # Attention scores: intended to be the dot product of each local
        # projection with its sample's global projection
        attention = tensorflow.tensordot(lp, gp, axes=1)
        return tensorflow.reduce_sum(attention * X, axis=1)
I expect the input shape to be (batch_size, samples, width), with a batch size of 32 and a width of 768.
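For reference, the failing batch in the traceback below has shape (3, 14, 768) rather than the expected batch size of 32. With the class above in scope, I can reproduce the error in isolation by calling the layer directly on a dummy tensor of that shape (the shapes here are just copied from the traceback):

import tensorflow

head = GlobalAttentionPoolingHead()
dummy = tensorflow.random.normal((3, 14, 768))  # (batch_size, samples, width) as in the traceback
pooled = head(dummy)  # raises the same InvalidArgumentError shown below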
When I try to fit this model, I get the following error:
File "/home/peter/QARAC/qarac/models/layers/GlobalAttentionPoolingHead.py", line 75, in call
attention = tensorflow.tensordot(lp,gp,axes=1)
tensorflow.python.framework.errors_impl.InvalidArgumentError: Exception encountered when calling layer 'global_attention_pooling_head_1' (type GlobalAttentionPoolingHead).
{{function_node __wrapped__MatMul_device_/job:localhost/replica:0/task:0/device:CPU:0}} Matrix size-incompatible: In[0]: [42,768], In[1]: [3,768] [Op:MatMul] name:
Call arguments received by layer 'global_attention_pooling_head_1' (type GlobalAttentionPoolingHead):
• X=tf.Tensor(shape=(3, 14, 768), dtype=float32)
• training=False
What is happening here and how can I fix it?