As part of my QARAC NLP project, I’m trying to implement Hyena models in Keras. Hyena models are intended for use with variable-length sequences, so I need to train with ragged tensors.
My implementation looks like this:
import keras
import keras_nlp
import tensorflow
def convolve(x,y):
    xT = tensorflow.transpose(x,[0,2,1])
    yT = tensorflow.transpose(y,[0,2,1])
    z = tensorflow.signal.irfft(tensorflow.signal.rfft(xT)*tensorflow.signal.rfft(yT))
    return tensorflow.transpose(z,[0,2,1])

class HyenaLayer(keras.layers.Layer):
    """Keras implementation of Hyena layer. Unlike in the original paper,
    this can be used as an encoder layer by setting the optional parameter
    `causal` to `False`"""
    def __init__(self,stages=3,causal=True):
        """
        Parameters
        ----------
        stages : int, optional
            Number of stages of convolution and Hadamard multiplication. The default is 3.
        causal : bool, optional
            Set to False for an encoder layer. The default is True.
        """
        super().__init__()
        self.stages = stages
        self.causal = causal
        self.data_projection = None
        self.filters = None
        self.positional_encoding = keras_nlp.layers.SinePositionEncoding()

    def build(self,input_shape):
        width = input_shape[-1]
        self.data_projection = self.add_weight(shape=(width,width,self.stages+1),
        self.filters = self.add_weight(shape=(width,width,self.stages),

    def call(self,X,training=None):
        x = tensorflow.tensordot(X,self.data_projection,axes=1)
        f = tensorflow.tensordot(self.positional_encoding(X),self.filters,axes=1)
        if self.causal:
            concat = keras.layers.Concatenate()
            x = concat(x,tensorflow.zeros_like(x))
            f = concat(f,tensorflow.zeros_like(f))
        y = x[:,:,:,0]
        for i in range(self.stages):
            y = convolve(y,f[:,:,:,i])*x[:,:,:,i+1]
        if self.causal:
            for (i,n) in enumerate(X.row_lengths()):
                y[i] = y[i,:n]
        return y
This compiles correctly, but when I try to fit the model, I get the following error message:
TypeError: Exception encountered when calling layer 'hyena_layer' (type HyenaLayer).
in user code:
File "/home/peter/QARAC/qarac/models/layers/", line 58, in call *
x = tensorflow.tensordot(X,self.data_projection,axes=1)
TypeError: Failed to convert elements of tf.RaggedTensor(values=Tensor("sequential/embedding/embedding_lookup_ragged/embedding_lookup/Identity:0", shape=(None, 768), dtype=float32), row_splits=Tensor("RaggedFromVariant/RaggedTensorFromVariant:0", shape=(None,), dtype=int64)) to Tensor. Consider casting elements to a supported type. See for supported TF dtypes.
Can anyone suggest a remedy for this?
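For reference, here is a minimal sketch of two possible workarounds for the failing `tensordot` call. It assumes the cause is that `tensorflow.tensordot` does not accept a `RaggedTensor` directly, which is what the traceback suggests. The toy input `X` and the weight `W` are hypothetical stand-ins for the embedding output and `data_projection`; this has not been tested against the full layer.

```python
import tensorflow as tf

# Toy ragged batch: two sequences of lengths 3 and 1, feature width 4.
# (Hypothetical stand-in for the ragged embedding output.)
X = tf.ragged.constant(
    [[[1.0, 0.0, 0.0, 0.0], [0.0, 1.0, 0.0, 0.0], [0.0, 0.0, 1.0, 0.0]],
     [[0.0, 0.0, 0.0, 1.0]]],
    ragged_rank=1,
)

# Hypothetical stand-in for data_projection, with stages + 1 == 2.
W = tf.ones((4, 4, 2))

# Option 1: apply the projection to the flat (total_tokens, width) values
# and let map_flat_values re-wrap the result with the original row splits.
projected_ragged = tf.ragged.map_flat_values(
    lambda values: tf.tensordot(values, W, axes=1), X
)

# Option 2: pad to a dense (batch, max_len, width) tensor first, then
# project. Padded positions are zeros, so their projections are zeros too.
projected_dense = tf.tensordot(X.to_tensor(), W, axes=1)

print(projected_ragged.shape)
print(projected_dense.shape)
```

Option 1 keeps the data ragged, while Option 2 gives up raggedness early. Since the later FFT-based `convolve` step may also need dense input, padding with `to_tensor()` at the start of `call` and masking by `X.row_lengths()` at the end might turn out to be the simpler route.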