Implementing Hyena Models in Keras

As part of my QARAC NLP project, I’m trying to implement Hyena models in Keras. Hyena models are intended for use with variable-length sequences, so I need to train with ragged tensors.
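For concreteness, the kind of input I mean is a `tf.RaggedTensor` holding a batch of variable-length token sequences, along these lines (the token ids here are made up):

```python
import tensorflow

# A batch of three token sequences with different lengths (illustrative ids)
tokens = tensorflow.ragged.constant([[1, 2, 3, 4],
                                     [5, 6],
                                     [7, 8, 9]])
print(tokens.shape)  # (3, None)
print(tokens.row_lengths())
```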

My implementation looks like this:

import keras
import keras_nlp
import tensorflow

def convolve(x,y):
    """FFT-based convolution along the sequence axis.
       Note that without zero-padding this is a circular convolution."""
    xT = tensorflow.transpose(x,[0,2,1])
    yT = tensorflow.transpose(y,[0,2,1])
    z = tensorflow.signal.irfft(tensorflow.signal.rfft(xT)*tensorflow.signal.rfft(yT))
    return tensorflow.transpose(z,[0,2,1])

class HyenaLayer(keras.layers.Layer):
    """Keras implementation of Hyena layer. Unlike in the original paper,
       this can be used as an encoder layer by setting the optional parameter
       `causal` to `False`"""
       
    def __init__(self,stages=3,causal=True):
        """
        

        Parameters
        ----------
        stages : int, optional
            Number of stages of convolution and Hadamard multiplication. The default is 3.
        causal : bool, optional
            Set to False for an encoder layer. The default is True.

        Returns
        -------
        None.

        """
        super(HyenaLayer,self).__init__()
        self.stages = stages
        self.causal = causal
        self.data_projection = None
        self.filters = None
        self.positional_encoding = keras_nlp.layers.SinePositionEncoding()
        
    def build(self,input_shape):
        width = input_shape[-1]
        self.data_projection = self.add_weight(shape=(width,width,self.stages+1),
                                               trainable=True)
        self.filters = self.add_weight(shape=(width,width,self.stages),
                                       trainable=True)
        
    def call(self,X,training=None):
        x = tensorflow.tensordot(X,self.data_projection,axes=1)
        f = tensorflow.tensordot(self.positional_encoding(X),self.filters,axes=1)
        if self.causal:
            # Zero-pad along the sequence axis so the FFT-based convolution
            # is linear (causal) rather than circular
            concat = keras.layers.Concatenate(axis=1)
            x = concat([x,tensorflow.zeros_like(x)])
            f = concat([f,tensorflow.zeros_like(f)])
        y = x[:,:,:,0]
        for i in range(self.stages):
            y = convolve(y,f[:,:,:,i])*x[:,:,:,i+1]
        if self.causal:
            # Trim the zero-padding, restoring each sequence's original length
            y = tensorflow.RaggedTensor.from_tensor(y,lengths=X.row_lengths())
        return y

This compiles correctly, but when I try to fit, I get an error message:

    TypeError: Exception encountered when calling layer 'hyena_layer' (type HyenaLayer).
    
    in user code:
    
        File "/home/peter/QARAC/qarac/models/layers/HyenaLayer.py", line 58, in call  *
            x = tensorflow.tensordot(X,self.data_projection,axes=1)
    
        TypeError: Failed to convert elements of tf.RaggedTensor(values=Tensor("sequential/embedding/embedding_lookup_ragged/embedding_lookup/Identity:0", shape=(None, 768), dtype=float32), row_splits=Tensor("RaggedFromVariant/RaggedTensorFromVariant:0", shape=(None,), dtype=int64)) to Tensor. Consider casting elements to a supported type. See https://www.tensorflow.org/api_docs/python/tf/dtypes for supported TF dtypes.

Can anyone suggest a remedy for this?

I’ve managed to solve that by applying the tensordot operation to the ragged tensor’s flat_values attribute. However, I’ve got another problem that’s trickier to solve. The input to the layer is a RaggedTensor of shape (None, None, 768), and I need to perform a Fourier transform along the ragged dimension. So far, my code looks like this:

def fft(x):
    y = tensorflow.signal.rfft(tensorflow.transpose(x))
    print(y.shape)
    return y

fftspec = tensorflow.RaggedTensorSpec(shape=(x.shape[-1],None))
fx = tensorflow.map_fn(fft,x,
                       fn_output_signature=fftspec)

However, the shape of y is (None,None), where I’m expecting (768,None), so the returned tensor doesn’t appear to match the RaggedTensorSpec supplied. Does anyone have any ideas what I could do about this?

After a bit of experimentation, I’ve found that the problem can be solved by using `tensorflow.vectorized_map` instead of `map_fn`.
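A minimal sketch of that approach, shown here with a dense batch of uniform lengths for simplicity (the shapes are made up: batch=2, length=6, width=4):

```python
import tensorflow

# Illustrative dense batch standing in for the ragged input
x = tensorflow.random.normal((2, 6, 4))

def fft(seq):
    # seq has shape (length, width); transpose so rfft runs along the
    # sequence dimension
    return tensorflow.signal.rfft(tensorflow.transpose(seq))

# vectorized_map applies fft to each element of the batch
fx = tensorflow.vectorized_map(fft, x)
print(fx.shape)  # (2, 4, 4): rfft of a length-6 signal has 6//2 + 1 = 4 bins
```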