tf.function(experimental_compile=True): Creating variables on a non-first call to a function decorated with tf.function

Hi, I am trying to compile the call function of a recurrent network with tf.function(experimental_compile=True) to speed it up. However, despite my efforts, TensorFlow complains that variables are created on a non-first call to the function. I fail to see where these variables could possibly be created, so I hope someone here has an idea. In essence, this is the decoration I want to enable (it is currently commented out in the code below while I debug; as far as I know, experimental_compile is the older spelling of jit_compile):
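
    # Sketch of what I am trying to enable on FRAE.call; experimental_compile
    # requests XLA (JIT) compilation of the traced function.
    @tf.function(experimental_compile=True)
    def call(self, x, training=False):
        ...

The model TensorFlow complains about is this one: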

import copy

import numpy as np
import tensorflow as tf
from tensorflow.keras import regularizers


@tf.keras.saving.register_keras_serializable()
class FRAE(tf.keras.Model):
    def __init__(self,  output_dim, latent_dim, ht,  layer_config={"activations":[],"neurons":[]}, bypass=False,  vq_config=None, use_decoded=True, apply_mse_loss=False, mse_weight=1.0, window_size=5, **kwargs):
        super(FRAE, self).__init__(**kwargs)
        self.output_dim = output_dim
        self.latent_dim = latent_dim
        self.ht = ht
        self.layer_config = copy.deepcopy(layer_config)
        self.vq_config = vq_config
        self.encoder = FRAEEncoder(output_dim, latent_dim, layer_config, name=self.name+"_Encoder")
        self.decoder = FRAEDecoder(output_dim, layer_config, use_decoded=use_decoded,  name=self.name+"_Decoder")
        self.quantizer = VectorQuantizationFlattened(vq_config=vq_config, name=self.name+"_VQ")
        self.bypass = bypass
        self.use_decoded = use_decoded
        self.apply_mse_loss = apply_mse_loss
        self.mse_weight = mse_weight
        self.window_size = window_size
        self.timesteps = 7999
        # self.SetRandomStartIdx()


    def _get_input_signature(self):
        return [tf.TensorSpec(dtype=tf.float32, shape=[None, None, self.output_dim])]

    def SetBypass(self, bypass):
        self.bypass = bypass

    def SetupFRAELayers(self, latent_dim, layer_config):
        self.encoder.SetupLayers(self.output_dim, latent_dim, layer_config)
        self.decoder.SetupLayers(self.output_dim, layer_config)

        self.layer_config = copy.deepcopy(layer_config)
        self.latent_dim = latent_dim


    def CreateQuantizer(self, vq_config):
        self.vq_config = vq_config
        self.quantizer.CreateCodebook(vq_config)

    def IncreaseWindowSize(self, factor=2):
        self.window_size = min(self.window_size * factor, self.timesteps)
        # self.SetRandomStartIdx()
        tf.print("New window size:", self.window_size)

    def build(self, input_shape):
        batch_size = input_shape[0] or 1
        dummy_input = tf.zeros([batch_size, 1, self.output_dim])
        dummy_state = tf.zeros([batch_size, 1, self.output_dim * self.ht])

        dec = self.encoder(dummy_input, dummy_state)
        dec_q = self.quantizer(dec)
        decoded, new_state = self.decoder(dec_q, dummy_state)

        super().build(input_shape)  # important so Keras marks the model as built

    # @tf.function(experimental_compile=True)
    def call(self, x, training=False):
        if self.bypass:
            return x
        else:
            if self.window_size < self.timesteps:
                # decode only a random window of the sequence
                start_idx = tf.random.uniform(shape=(), minval=0, maxval=self.timesteps - self.window_size, dtype=tf.int32)
                state = tf.zeros(shape=(tf.shape(x)[0], 1, tf.shape(x)[2] * self.ht))

                timesteps = tf.shape(x)[1]
                selected_timesteps = tf.range(start_idx, start_idx + self.window_size)
                outputs = tf.TensorArray(tf.float32, timesteps)
                i = selected_timesteps[0]

                def cond(i, x, outputs, state, selected_timesteps):
                    return tf.less(i, selected_timesteps[-1])

                def body(i, x, outputs, state, selected_timesteps):
                    xin = tf.slice(x, [0, i, 0], [-1, 1, -1])
                    encoded = self.encoder(xin, state)
                    encoded_q = self.quantizer(encoded)
                    decoded, state = self.decoder(encoded_q, state)
                    outputs = outputs.write(i, decoded)
                    return tf.add(i,1), x, outputs, state, selected_timesteps

                i, x, outputs, state, selected_timesteps = tf.while_loop(cond, body, [i, x, outputs, state, selected_timesteps])
                # outputs: 7999 X bs x 1 x latent_dim -> 1 x bs x 7999 x latent_dim -> bs x 7999 x latent_dim
                partial_decoded = tf.squeeze(tf.transpose(outputs.stack(), [2, 1, 0, 3]))
                slice_A = tf.slice(x, [0, 0, 0], [-1, start_idx, -1])
                slice_B = tf.slice(x, [0, start_idx + self.window_size, 0],
                                   [-1, timesteps - (start_idx + self.window_size), -1])
                partial_decoded = tf.slice(partial_decoded, [0, start_idx, 0], [-1, self.window_size, -1])
                partial_x = tf.slice(x, [0, start_idx, 0], [-1, self.window_size, -1])

                decoded = tf.concat([slice_A, partial_decoded, slice_B], axis=1)
                if self.apply_mse_loss:
                    self.add_loss(self.mse_weight * tf.reduce_mean(tf.keras.losses.MSE(partial_decoded, partial_x)))

                self.add_metric(tf.reduce_mean(tf.keras.losses.MSE(partial_x, partial_decoded)), name='mse_' + self.name)
                # Else tf cannot deduce output shape due to random segments
                decoded.set_shape([None, None, x.shape[-1]])
                return decoded
            else:
                state = tf.zeros(shape=(tf.shape(x)[0], 1, tf.shape(x)[2] * self.ht))

                timesteps = tf.shape(x)[1]
                outputs = tf.TensorArray(tf.float32, timesteps)
                i = tf.constant(0)

                def cond(i, x, outputs, state, timesteps):
                    return tf.less(i, timesteps)

                def body(i, x, outputs, state, timesteps):
                    xin = tf.slice(x, [0, i, 0], [-1, 1, -1])
                    encoded = self.encoder(xin, state)
                    encoded_q = self.quantizer(encoded)
                    decoded, state = self.decoder(encoded_q, state)
                    outputs = outputs.write(i, decoded)
                    return tf.add(i, 1), x, outputs, state, timesteps

                i, x, outputs, state, timesteps = tf.while_loop(cond, body,
                                                                [i, x, outputs, state, timesteps])

                decoded = tf.squeeze(tf.transpose(outputs.stack(), [2, 1, 0, 3]))

                if self.apply_mse_loss:
                    self.add_loss(self.mse_weight * tf.reduce_mean(tf.keras.losses.MSE(decoded, x)))

                self.add_metric(tf.reduce_mean(tf.keras.losses.MSE(decoded, x)), name='mse_' + self.name)
                return decoded
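
For completeness, this is roughly how I construct and build the model before training; the concrete dimensions here are placeholders rather than my real configuration:

    layer_config = {"activations": ["swish", "swish"], "neurons": [128, 64]}
    vq_config = {"num_embeddings": 256, "embedding_dim": 16}
    frae = FRAE(output_dim=8, latent_dim=16, ht=4,
                layer_config=layer_config, vq_config=vq_config)
    frae.build((None, 7999, 8))  # runs encoder, quantizer and decoder once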

FRAEEncoder, FRAEDecoder and the vector quantizer are defined as follows:


@tf.function
def safe_norm(x, epsilon=1e-9, axis=None, keepdims=False):
    return tf.sqrt(tf.reduce_sum(x ** 2, axis=axis, keepdims=keepdims) + epsilon)

@tf.keras.saving.register_keras_serializable()
class VectorQuantizationFlattened(tf.keras.layers.Layer):
    def __init__(self, vq_config, **kwargs):
        super(VectorQuantizationFlattened, self).__init__(**kwargs)
        self.vq_config = vq_config
        self.num_embeddings = vq_config['num_embeddings']
        self.embedding_dim  = vq_config['embedding_dim']

        # self.CreateCodebook(vq_config)
        self.isRecording = False
        self.eps = 1e-12
        self.eps_noise = 1e-2
        self.codebook = self.add_weight(shape=(self.num_embeddings, self.embedding_dim),
                                        initializer=tf.keras.initializers.GlorotNormal(seed=None),
                                        trainable=True,
                                        name='embeddings')

        self.codebooks_used = self.add_weight(
            shape=(self.num_embeddings,),
            initializer=tf.keras.initializers.Zeros(),
            trainable=False,
            name="codebooks_used",
            dtype=tf.uint64  # match the tf.uint64 cast in call()
        )

        self.discarding_threshold = 0.25 / self.num_embeddings

    def CreateCodebook(self, vq_config):
        self.vq_config = vq_config
        self.num_embeddings = vq_config['num_embeddings']
        self.embedding_dim =  vq_config['embedding_dim']
        # self.build(None)
        if self.num_embeddings > 0:
            self.codebook = self.add_weight(shape=(self.num_embeddings, self.embedding_dim),
                                            initializer=tf.keras.initializers.GlorotNormal(seed=None),
                                            trainable=True,
                                            name='embeddings')
            # if weights is not None:
            #     self.codebook.assign(weights)

            # self.codebooks_used = tf.zeros_like(self.num_embeddings, dtype=tf.int32)
            self.codebooks_used = self.add_weight(
                shape=(self.num_embeddings,),
                initializer=tf.keras.initializers.Zeros(),
                trainable=False,
                name="codebooks_used",
                dtype=tf.uint64
            )

            self.discarding_threshold = 0.25 / self.num_embeddings  # we initialize with an approximately uniform distribution
        else:
            self.codebook = None
            self.codebooks_used = None
            self.discarding_threshold = None




    def call(self, inputs, training=None):

        if self.codebook is not None:
            input_shape = tf.shape(inputs)
            flattened = tf.reshape(inputs, [-1, self.embedding_dim])

            if training == False:
                # distances = tf.reduce_sum(tf.square(tf.expand_dims(inputs, axis=1) - self.codebook), axis=-1)
                distances = (
                        tf.reduce_sum(tf.square(flattened), axis=1, keepdims=True)
                        - 2 * tf.matmul(flattened, tf.transpose(self.codebook))
                        + tf.expand_dims(tf.reduce_sum(tf.square(self.codebook), axis=1), axis=0)
                )

                hard_assignments = tf.one_hot(tf.argmin(distances, axis=-1), self.num_embeddings)
                quantized_hard = tf.matmul(hard_assignments, self.codebook)
                quantized_hard = tf.reshape(quantized_hard, input_shape)

                return quantized_hard
                # quantized = quantized_hard

            else:
                # training path: input_shape and flattened were computed above

                distances = (
                        tf.reduce_sum(tf.square(flattened), axis=1, keepdims=True)
                        - 2 * tf.matmul(flattened, tf.transpose(self.codebook))
                        + tf.expand_dims(tf.reduce_sum(tf.square(self.codebook), axis=1), axis=0)
                )


                min_indices = tf.argmin(distances, axis=1)  # Shape: (N,)


                hard_quantized_input = tf.gather(self.codebook, min_indices)  # Shape: (N, D)
                hard_quantized_input = tf.reshape(hard_quantized_input, input_shape)

                random_vector = tf.random.normal(tf.shape(inputs), mean=0.0, stddev=1.0)


                norm_quantization_residual = safe_norm(inputs - hard_quantized_input, axis=1, keepdims=True)  # (N, 1)
                norm_random_vector = safe_norm(random_vector, axis=1, keepdims=True)  # (N, 1)


                vq_error = tf.math.divide_no_nan(norm_quantization_residual, (norm_random_vector + self.eps)) * random_vector  # (N, D)
                quantized_input = inputs + vq_error


                encodings = tf.one_hot(min_indices, depth=self.num_embeddings, dtype=tf.float32)

                self.codebooks_used.assign_add(tf.cast(tf.reduce_sum(encodings, axis=0), tf.uint64))

                return quantized_input #, perplexity, self.codebooks_used


        else:
            return inputs

    def replace_unused_codebooks(self):
        usage_ratio = self.codebooks_used / tf.math.reduce_sum(self.codebooks_used)

        unused_indices = tf.reshape(tf.where(usage_ratio < self.discarding_threshold), [-1])
        used_indices = tf.reshape(tf.where(usage_ratio >= self.discarding_threshold), [-1])

        unused_count = tf.size(unused_indices)
        used_count = tf.size(used_indices)


        if used_count == 0:
            tf.print("####### used_indices equals zero / shuffling whole codebooks ######")
            noise = self.eps_noise * tf.random.normal(tf.shape(self.codebook), dtype=self.codebook.dtype)
            self.codebook.assign_add(noise)
        else:
            used_p = tf.gather(usage_ratio, used_indices)
            # np.random.choice needs concrete NumPy values, so this method must run eagerly
            p = (used_p / tf.math.reduce_sum(used_p)).numpy()
            new_indices = np.random.choice(used_indices.numpy(), int(unused_count), replace=True, p=p)

            noise = self.eps_noise * tf.random.normal([unused_count, self.embedding_dim], dtype=self.codebook.dtype)

            new_values = tf.gather(self.codebook, new_indices) + noise
            updated_codebooks = tf.tensor_scatter_nd_update(self.codebook,
                                                            tf.expand_dims(unused_indices, axis=1),
                                                            new_values)
            self.codebook.assign(updated_codebooks)


        self.eps_noise = max(self.eps_noise * 0.95, 1e-5)
        self.discarding_threshold = self.discarding_threshold * 0.95
        tf.print("************* Replaced", unused_count, "codebooks *************")
        # reset the usage counts
        self.codebooks_used.assign(tf.zeros_like(self.codebooks_used, dtype=self.codebooks_used.dtype))


    def get_config(self):
        config = super().get_config()
        config.update({
            'vq_config': self.vq_config
        })
        return config

    @classmethod
    def from_config(cls, config):

        return cls(**config)





@tf.keras.saving.register_keras_serializable()
class FRAEEncoder(tf.keras.layers.Layer):
    def __init__(self, input_shape, latent_dim,   layer_config,  **kwargs):
        super(FRAEEncoder, self).__init__(**kwargs)
        self.myinput_shape = input_shape
        self.latent_dim = latent_dim
        self.layer_config = copy.deepcopy(layer_config)
        self.SetupLayers(input_shape, latent_dim, layer_config)

    def SetupLayers(self, input_shape, latent_dim, layer_config):
        self.encoder = []
        activations  = layer_config["activations"]
        num_neuron   = layer_config["neurons"]
        l2reg = 1e-3
        if 'l2' in layer_config.keys():
            l2reg = layer_config['l2']

        for i, act in enumerate(activations):
            if i == 0:
                self.encoder.append(tf.keras.layers.Dense(num_neuron[i], activation=None, input_shape=(input_shape,), kernel_regularizer=regularizers.l2(l2reg), name='dense'))

                # self.encoder.append(tf.keras.layers.Activation(activation=act))
            else:
                self.encoder.append(tf.keras.layers.Dense(num_neuron[i], activation=None, kernel_regularizer=regularizers.l2(l2reg),  name='dense'))
                # self.encoder.append(tf.keras.layers.Activation(activation=act))

            self.encoder.append(tf.keras.layers.LayerNormalization(epsilon=1e-5))
            self.encoder.append(tf.keras.layers.Activation(activation=act, name='activation'))

        if len(activations) > 0:
            if "LatentAct" in layer_config.keys():
                self.encoder.append(tf.keras.layers.Dense(latent_dim, activation=None , kernel_regularizer=regularizers.l2(l2reg))) #layer_config["LatentAct"]
                self.encoder.append(tf.keras.layers.LayerNormalization(epsilon=1e-5))
                self.encoder.append(tf.keras.layers.Activation(activation=layer_config["LatentAct"], name='activation'))
            else:
                self.encoder.append(tf.keras.layers.Dense(latent_dim, activation=None, kernel_regularizer=regularizers.l2(l2reg)))
                self.encoder.append(tf.keras.layers.LayerNormalization(epsilon=1e-5))
                self.encoder.append(tf.keras.layers.Activation(activation="swish", name='activation'))


        else:
            if "LatentAct" in layer_config.keys():
                self.encoder.append(tf.keras.layers.Dense(latent_dim, activation=None, input_shape=(input_shape,), kernel_regularizer=regularizers.l2(l2reg)))
                self.encoder.append(tf.keras.layers.LayerNormalization(epsilon=1e-7))
                self.encoder.append(tf.keras.layers.Activation(activation=layer_config["LatentAct"], name='activation'))

            else:
                self.encoder.append(tf.keras.layers.Dense(latent_dim, activation=None, input_shape=(input_shape,), kernel_regularizer=regularizers.l2(l2reg)))
                self.encoder.append(tf.keras.layers.LayerNormalization(epsilon=1e-7))
                self.encoder.append(tf.keras.layers.Activation(activation="swish", name='activation'))



    def call(self, inputs, state):
        # encoder
        encoded = tf.concat([inputs, state], axis=-1)

        for k, lrs in enumerate(self.encoder):
            encoded = lrs(encoded)
            if 'activation' in lrs.name:
                encoded = tf.concat([encoded, state], axis=-1)


        new_output = encoded
        return new_output

    def get_config(self):
        config = super().get_config()
        config.update({
            'myinput_shape': self.myinput_shape,
            'latent_dim': self.latent_dim,
            'layer_config': self.layer_config
        })
        return config

    @classmethod
    def from_config(cls, config):
        input_shape = config['myinput_shape']
        latent_dim = config['latent_dim']
        layer_config = config['layer_config']

        return cls(input_shape, latent_dim, layer_config)

@tf.keras.saving.register_keras_serializable()
class FRAEDecoder(tf.keras.layers.Layer):
    def __init__(self,   output_dim, layer_config, use_decoded = True, **kwargs):
        super(FRAEDecoder, self).__init__(**kwargs)
        self.output_dim = output_dim
        self.layer_config = copy.deepcopy(layer_config)
        self.use_decoded = use_decoded

        self.SetupLayers(output_dim, layer_config)

    def SetupLayers(self,  output_dim, layer_config):
        self.decoder = []
        activations  = layer_config["activations"]
        num_neuron   = layer_config["neurons"]

        l2reg = 1e-3
        if 'l2' in layer_config.keys():
            l2reg = layer_config['l2']

        for i, act in reversed(list(enumerate(activations))):
            self.decoder.append(tf.keras.layers.Dense(num_neuron[i],activation=None, kernel_regularizer=regularizers.l2(l2reg), name="dense") ) # kernel_initializer=tf.keras.initializers.Orthogonal()
            self.decoder.append(tf.keras.layers.LayerNormalization(epsilon=1e-7))
            self.decoder.append(tf.keras.layers.Activation(activation=act, name='activation'))

        self.decoder.append(tf.keras.layers.Dense(output_dim, kernel_regularizer=regularizers.l2(l2reg), activation='linear'))


    def call(self, encoded, state):
        y = tf.concat([encoded, state], axis=-1)
        for i, lrs in enumerate(self.decoder):
            y = lrs(y)
            if 'activation' in lrs.name:
               y = tf.concat([y, state], axis=-1)

        # update output and state
        new_output = y  # output of the final linear Dense layer
        new_state  = tf.concat([new_output, state[:, :, :-tf.shape(new_output)[-1]]], axis=-1)

        return new_output, new_state

    def get_config(self):
        config = super().get_config()
        config.update({
            'output_dim': self.output_dim,
            'layer_config': self.layer_config,
            'use_decoded': self.use_decoded
        })
        return config

    @classmethod
    def from_config(cls, config):
        output_dim = config['output_dim']
        layer_config = config['layer_config']
        if 'use_decoded' in config.keys():
            use_decoded = config['use_decoded']
        else:
            use_decoded = True

        return cls(output_dim, layer_config, use_decoded)

Due to the build() function, all participating layers have their variables created when the entire model is built. Therefore I do not understand why this error occurs:

    frae_decoded = self.frae(encoded)  # we handle bypass inside the model

    ValueError: Exception encountered when calling layer 'frae' (type FRAE).

    Creating variables on a non-first call to a function decorated with tf.function.
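
For what it is worth, a check along these lines (dimensions again placeholders) is what I would expect to pass, since build() already runs the encoder, quantizer and decoder once and should therefore create every variable up front:

    frae.build((None, 7999, 8))
    n_before = len(frae.variables)
    _ = frae(tf.zeros([2, 7999, 8]), training=False)  # first call
    assert len(frae.variables) == n_before  # nothing should be created lazily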

Does anyone have an idea? Thank you :slight_smile: