tf.function(experimental_compile=True): Creating variables on a non-first call to a function decorated with tf.function

Hi, I am trying to compile a call function to speed up a recurrent network. However, despite my efforts tensorflow complains that variables are created after the first call. I fail to see where these variables are supposed to be created, so I hope someone here has an idea. The model tensorflow complains about is this one:

@tf.keras.saving.register_keras_serializable()
class FRAE(tf.keras.Model):
    def __init__(self,  output_dim, latent_dim, ht,  layer_config={"activations":[],"neurons":[]}, bypass=False,  vq_config=None, use_decoded=True, apply_mse_loss=False, mse_weight=1.0, window_size=5, **kwargs):
        super(FRAE, self).__init__(**kwargs)
        self.output_dim = output_dim
        self.latent_dim = latent_dim
        self.ht = ht
        self.layer_config = copy.deepcopy(layer_config)
        self.vq_config = vq_config
        self.encoder = FRAEEncoder(output_dim, latent_dim, layer_config, name=self.name+"_Encoder")
        self.decoder = FRAEDecoder(output_dim, layer_config, use_decoded=use_decoded,  name=self.name+"_Decoder")
        self.quantizer = VectorQuantizationFlattened(vq_config=vq_config, name=self.name+"_VQ")
        self.bypass = bypass
        self.use_decoded = use_decoded
        self.apply_mse_loss = apply_mse_loss
        self.mse_weight = mse_weight
        self.window_size = window_size
        self.timesteps = 7999
        # self.SetRandomStartIdx()


    def _get_input_signature(self):
        return [tf.TensorSpec(dtype=tf.float32, shape=[None, None, self.output_dim])]

    def SetBypass(self, bypass):
        self.bypass = bypass

    def SetupFRAELayers(self, latent_dim, layer_config):
        self.encoder.SetupLayers(self.output_dim, latent_dim, layer_config)
        self.decoder.SetupLayers(self.output_dim, layer_config)

        self.layer_config = copy.deepcopy(layer_config)
        self.latent_dim = latent_dim


    def CreateQuantizer(self, vq_config):
        self.vq_config = vq_config
        self.quantizer.CreateCodebook(vq_config)

    def IncreaseWindowSize(self, factor=2):
        self.window_size = min(self.window_size * factor, self.timesteps)
        # self.SetRandomStartIdx()
        tf.print(f"New window size: ", self.window_size)

    def build(self, input_shape):
        batch_size = input_shape[0] or 1
        dummy_input = tf.zeros([batch_size, 1, self.output_dim])
        dummy_state = tf.zeros([batch_size, 1, self.output_dim * self.ht])

        dec = self.encoder(dummy_input, dummy_state)
        dec_q = self.quantizer(dec)
        decoded, new_state = self.decoder(dec_q, dummy_state)

        super().build(input_shape)  # Wichtig für Keras

    # @tf.function(experimental_compile=True)
    def call(self, x, training=False):
        if self.bypass:
            return x
        else:
            if self.window_size < 7999:
                start_idx = tf.cond(tf.less(self.window_size, 7999),
                                    lambda: tf.random.uniform(shape=(), minval=0, maxval=7999 - self.window_size, dtype=tf.int32),
                                    lambda: tf.constant(0, dtype=tf.int32))
                state = tf.zeros(shape=(tf.shape(x)[0], 1, tf.shape(x)[2] * self.ht))

                timesteps = tf.shape(x)[1]
                selected_timesteps = tf.range(start_idx, start_idx + self.window_size)
                outputs = tf.TensorArray(tf.float32, timesteps)
                i = selected_timesteps[0]

                def cond(i, x, outputs, state, selected_timesteps):
                    return tf.less(i, selected_timesteps[-1])

                def body(i, x, outputs, state, selected_timesteps):
                    xin = tf.slice(x, [0, i, 0], [-1, 1, -1])
                    encoded = self.encoder(xin, state)
                    encoded_q = self.quantizer(encoded)
                    decoded, state = self.decoder(encoded_q, state)
                    outputs = outputs.write(i, decoded)
                    return tf.add(i,1), x, outputs, state, selected_timesteps

                i, x, outputs, state, selected_timesteps = tf.while_loop(cond, body, [i, x, outputs, state, selected_timesteps])
                # outputs: 7999 X bs x 1 x latent_dim -> 1 x bs x 7999 x latent_dim -> bs x 7999 x latent_dim
                partial_decoded = tf.squeeze(tf.transpose(outputs.stack(), [2, 1, 0, 3]))
                slice_A = tf.slice(x, [0, 0, 0], [-1, start_idx, -1])
                slice_B = tf.slice(x, [0, start_idx + self.window_size, 0],
                                   [-1, timesteps - (start_idx + self.window_size), -1])
                partial_decoded = tf.slice(partial_decoded, [0, start_idx, 0], [-1, self.window_size, -1])
                partial_x = tf.slice(x, [0, start_idx, 0], [-1, self.window_size, -1])

                decoded = tf.concat([slice_A, partial_decoded, slice_B], axis=1)
                if self.apply_mse_loss:
                    self.add_loss(self.mse_weight * tf.reduce_mean(tf.keras.losses.MSE(partial_decoded, partial_x)))

                self.add_metric(tf.reduce_mean(tf.keras.losses.MSE(partial_x, partial_decoded)), name='mse_' + self.name)
                # Else tf cannot deduce output shape due to random segments
                decoded.set_shape([None, None, x.shape[-1]])
                return decoded
            else:
                state = tf.zeros(shape=(tf.shape(x)[0], 1, tf.shape(x)[2] * self.ht))

                timesteps = tf.shape(x)[1]
                outputs = tf.TensorArray(tf.float32, timesteps)
                i = tf.constant(0)

                def cond(i, x, outputs, state, timesteps):
                    return tf.less(i, timesteps)

                def body(i, x, outputs, state, timesteps):
                    xin = tf.slice(x, [0, i, 0], [-1, 1, -1])
                    encoded = self.encoder(xin, state)
                    encoded_q = self.quantizer(encoded)
                    decoded, state = self.decoder(encoded_q, state)
                    outputs = outputs.write(i, decoded)
                    return tf.add(i, 1), x, outputs, state, timesteps

                i, x, outputs, state, selected_timesteps = tf.while_loop(cond, body,
                                                                         [i, x, outputs, state, timesteps])

                decoded = tf.squeeze(tf.transpose(outputs.stack(), [2, 1, 0, 3]))

                if self.apply_mse_loss:
                    self.add_loss(self.mse_weight * tf.reduce_mean(tf.keras.losses.MSE(decoded, x)))

                self.add_metric(tf.reduce_mean(tf.keras.losses.MSE(decoded, x)), name='mse_' + self.name)
                return decoded

FRAEEncoder, FRAEDecoder and the vector quantizer are defined as


@tf.function
def safe_norm(x, epsilon=1e-9, axis=None, keepdims=False):
    return tf.sqrt(tf.reduce_sum(x ** 2, axis=axis, keepdims=keepdims) + epsilon)

@tf.keras.saving.register_keras_serializable()
class VectorQuantizationFlattened(tf.keras.layers.Layer):
    def __init__(self, vq_config, **kwargs):
        super(VectorQuantizationFlattened, self).__init__(**kwargs)
        self.vq_config = vq_config
        self.num_embeddings = vq_config['num_embeddings']
        self.embedding_dim  = vq_config['embedding_dim']

        # self.CreateCodebook(vq_config)
        self.isRecording = False
        self.eps = 1e-12
        self.eps_noise = 1e-2
        self.codebook = self.add_weight(shape=(self.num_embeddings, self.embedding_dim),
                                        initializer=tf.keras.initializers.GlorotNormal(seed=None),
                                        trainable=True,
                                        name='embeddings')

        self.codebooks_used = self.add_weight(
            shape=(self.num_embeddings,),
            initializer=tf.keras.initializers.Zeros(),
            trainable=False,
            name="codebooks_used",
            dtype=tf.uint32
        )

        self.discarding_threshold = 0.25 / self.num_embeddings

    def CreateCodebook(self, vq_config):
        self.vq_config = vq_config
        self.num_embeddings = vq_config['num_embeddings']
        self.embedding_dim =  vq_config['embedding_dim']
        # self.build(None)
        if self.num_embeddings > 0:
            self.codebook = self.add_weight(shape=(self.num_embeddings, self.embedding_dim),
                                            initializer=tf.keras.initializers.GlorotNormal(seed=None),
                                            trainable=True,
                                            name='embeddings')
            # if weights is not None:
            #     self.codebook.assign(weights)

            # self.codebooks_used = tf.zeros_like(self.num_embeddings, dtype=tf.int32)
            self.codebooks_used = self.add_weight(
                shape=(self.num_embeddings,),
                initializer=tf.keras.initializers.Zeros(),
                trainable=False,
                name="codebooks_used",
                dtype=tf.uint64
            )

            self.discarding_threshold = 0.25/self.num_embeddings   # we init with appr. uniform distribution
        else:
            self.codebook = None
            self.codebooks_used = None
            self.discarding_threshold = None




    def call(self, inputs, training=None):

        if self.codebook is not None:
            input_shape = tf.shape(inputs)
            flattened = tf.reshape(inputs, [-1, self.embedding_dim])

            if training == False:
                # distances = tf.reduce_sum(tf.square(tf.expand_dims(inputs, axis=1) - self.codebook), axis=-1)
                distances = (
                        tf.reduce_sum(tf.square(flattened), axis=1, keepdims=True)
                        - 2 * tf.matmul(flattened, tf.transpose(self.codebook))
                        + tf.expand_dims(tf.reduce_sum(tf.square(self.codebook), axis=1), axis=0)
                )

                hard_assignments = tf.one_hot(tf.argmin(distances, axis=-1), self.num_embeddings)
                quantized_hard = tf.matmul(hard_assignments, self.codebook)
                quantized_hard = tf.reshape(quantized_hard, input_shape)

                return quantized_hard
                # quantized = quantized_hard

            else:
                input_shape = tf.shape(inputs)
                flattened = tf.reshape(inputs, [-1, self.embedding_dim])

                distances = (
                        tf.reduce_sum(tf.square(flattened), axis=1, keepdims=True)
                        - 2 * tf.matmul(flattened, tf.transpose(self.codebook))
                        + tf.expand_dims(tf.reduce_sum(tf.square(self.codebook), axis=1), axis=0)
                )


                min_indices = tf.argmin(distances, axis=1)  # Shape: (N,)


                hard_quantized_input = tf.gather(self.codebook, min_indices)  # Shape: (N, D)
                hard_quantized_input = tf.reshape(hard_quantized_input, input_shape)

                random_vector = tf.random.normal(tf.shape(inputs), mean=0.0, stddev=1.0)


                norm_quantization_residual = safe_norm(inputs - hard_quantized_input, axis=1, keepdims=True)  # (N, 1)
                norm_random_vector = safe_norm(random_vector, axis=1, keepdims=True)  # (N, 1)


                vq_error = tf.math.divide_no_nan(norm_quantization_residual, (norm_random_vector + self.eps)) * random_vector  # (N, D)
                quantized_input = inputs + vq_error


                encodings = tf.one_hot(min_indices, depth=self.num_embeddings, dtype=tf.float32)

                self.codebooks_used.assign_add(tf.cast(tf.reduce_sum(encodings, axis=0), tf.uint64))

                return quantized_input #, perplexity, self.codebooks_used


        else:
            return inputs

    def replace_unused_codebooks(self):
        usage_ratio = self.codebooks_used / tf.math.reduce_sum(self.codebooks_used)

        unused_indices = tf.reshape(tf.where(usage_ratio < self.discarding_threshold), [-1])
        used_indices = tf.reshape(tf.where(usage_ratio >= self.discarding_threshold), [-1])

        unused_count = tf.size(unused_indices)
        used_count = tf.size(used_indices)


        if used_count == 0:
            tf.print("####### used_indices equals zero / shuffling whole codebooks ######")
            noise = self.eps_noise * tf.random.normal(tf.shape(self.codebook), dtype=self.codebook.dtype)
            self.codebook.assign_add(noise)
        else:
            used_p = tf.gather(usage_ratio, used_indices)
            new_indices = np.random.choice(used_indices, unused_count, replace=True, p=used_p / tf.math.reduce_sum(used_p))

            noise = self.eps_noise * tf.random.normal([unused_count, self.embedding_dim], dtype=self.codebook.dtype)

            new_values = tf.gather(self.codebook, new_indices) + noise
            updated_codebooks = tf.tensor_scatter_nd_update(self.codebook,
                                                            tf.expand_dims(unused_indices, axis=1),
                                                            new_values)
            self.codebook.assign(updated_codebooks)


        self.eps_noise = max(self.eps_noise * 0.95, 1e-5)
        self.discarding_threshold = self.discarding_threshold * 0.95
        tf.print("************* Replaced", unused_count, "codebooks *************")
        # Setze die Nutzung zurück:
        self.codebooks_used.assign(tf.zeros_like(self.codebooks_used, dtype=self.codebooks_used.dtype))


    def get_config(self):
        config = super().get_config()
        config.update({
            'vq_config': self.vq_config
        })
        return config

    @classmethod
    def from_config(cls, config):

        return cls(**config)





@tf.keras.saving.register_keras_serializable()
class FRAEEncoder(tf.keras.layers.Layer):
    def __init__(self, input_shape, latent_dim,   layer_config,  **kwargs):
        super(FRAEEncoder, self).__init__(**kwargs)
        self.myinput_shape = input_shape
        self.latent_dim = latent_dim
        self.layer_config = copy.deepcopy(layer_config)
        self.SetupLayers(input_shape, latent_dim, layer_config)

    def SetupLayers(self, input_shape, latent_dim, layer_config):
        self.encoder = []
        activations  = layer_config["activations"]
        num_neuron   = layer_config["neurons"]
        l2reg = 1e-3
        if 'l2' in layer_config.keys():
            l2reg = layer_config['l2']

        for i, act in enumerate(activations):
            if i == 0:
                self.encoder.append(tf.keras.layers.Dense(num_neuron[i], activation=None, input_shape=(input_shape,), kernel_regularizer=regularizers.l2(l2reg), name='dense'))

                # self.encoder.append(tf.keras.layers.Activation(activation=act))
            else:
                self.encoder.append(tf.keras.layers.Dense(num_neuron[i], activation=None, kernel_regularizer=regularizers.l2(l2reg),  name='dense'))
                # self.encoder.append(tf.keras.layers.Activation(activation=act))

            self.encoder.append(tf.keras.layers.LayerNormalization(epsilon=1e-5))
            self.encoder.append(tf.keras.layers.Activation(activation=act, name='activation'))

        if len(activations) > 0:
            if "LatentAct" in layer_config.keys():
                self.encoder.append(tf.keras.layers.Dense(latent_dim, activation=None , kernel_regularizer=regularizers.l2(l2reg))) #layer_config["LatentAct"]
                self.encoder.append(tf.keras.layers.LayerNormalization(epsilon=1e-5))
                self.encoder.append(tf.keras.layers.Activation(activation=layer_config["LatentAct"], name='activation'))
            else:
                self.encoder.append(tf.keras.layers.Dense(latent_dim, activation=None, kernel_regularizer=regularizers.l2(l2reg)))
                self.encoder.append(tf.keras.layers.LayerNormalization(epsilon=1e-5))
                self.encoder.append(tf.keras.layers.Activation(activation="swish", name='activation'))


            #self.encoder.append(tf.keras.layers.LayerNormalization(epsilon=1e-5))            
        else:
            if "LatentAct" in layer_config.keys():
                self.encoder.append(tf.keras.layers.Dense(latent_dim, activation=None, input_shape=(input_shape,), kernel_regularizer=regularizers.l2(l2reg)))
                self.encoder.append(tf.keras.layers.LayerNormalization(epsilon=1e-7))
                self.encoder.append(tf.keras.layers.Activation(activation=layer_config["LatentAct"], name='activation'))

            else:
                self.encoder.append(tf.keras.layers.Dense(latent_dim, activation=None, input_shape=(input_shape,), kernel_regularizer=regularizers.l2(l2reg)))
                self.encoder.append(tf.keras.layers.LayerNormalization(epsilon=1e-7))
                self.encoder.append(tf.keras.layers.Activation(activation="swish", name='activation'))


            #self.encoder.append(tf.keras.layers.LayerNormalization(epsilon=1e-5))


           


    def call(self, inputs, state):
        # encoder
        encoded = tf.concat([inputs, state], axis=-1)

        for k, lrs in enumerate(self.encoder):
            encoded = lrs(encoded)
            if 'activation' in lrs.name:
                encoded = tf.concat([encoded, state], axis=-1)


        new_output = encoded
        return new_output

    def get_config(self):
        config = super().get_config()
        config.update({
            'myinput_shape': self.myinput_shape,
            'latent_dim': self.latent_dim,
            'layer_config': self.layer_config
        })
        return config

    @classmethod
    def from_config(cls, config):
        input_shape = config['myinput_shape']
        latent_dim = config['latent_dim']
        layer_config = config['layer_config']

        return cls(input_shape, latent_dim, layer_config)

@tf.keras.saving.register_keras_serializable()
class FRAEDecoder(tf.keras.layers.Layer):
    def __init__(self,   output_dim, layer_config, use_decoded = True, **kwargs):
        super(FRAEDecoder, self).__init__(**kwargs)
        self.output_dim = output_dim
        self.layer_config = copy.deepcopy(layer_config)
        self.use_decoded = use_decoded

        self.SetupLayers(output_dim, layer_config)

    def SetupLayers(self,  output_dim, layer_config):
        self.decoder = []
        activations  = layer_config["activations"]
        num_neuron   = layer_config["neurons"]

        l2reg = 1e-3
        if 'l2' in layer_config.keys():
            l2reg = layer_config['l2']

        for i, act in reversed(list(enumerate(activations))):
            self.decoder.append(tf.keras.layers.Dense(num_neuron[i],activation=None, kernel_regularizer=regularizers.l2(l2reg), name="dense") ) # kernel_initializer=tf.keras.initializers.Orthogonal()
            self.decoder.append(tf.keras.layers.LayerNormalization(epsilon=1e-7))
            self.decoder.append(tf.keras.layers.Activation(activation=act, name='activation'))

        self.decoder.append(tf.keras.layers.Dense(output_dim, kernel_regularizer=regularizers.l2(l2reg), activation='linear'))


    def call(self, encoded, state):
        y = tf.concat([encoded, state], axis=-1)
        for i, lrs in enumerate(self.decoder):
            y = lrs(y)
            if 'activation' in lrs.name:
               y = tf.concat([y, state], axis=-1)

        #update output and state
        new_output = y  # Result of
        new_state  = tf.concat([new_output, state[:, :, :-tf.shape(new_output)[-1]]], axis=-1)

        return new_output, new_state

    def get_config(self):
        config = super().get_config()
        config.update({
            'output_dim': self.output_dim,
            'layer_config': self.layer_config,
            'use_decoded': self.use_decoded
        })
        return config

    @classmethod
    def from_config(cls, config):
        output_dim = config['output_dim']
        layer_config = config['layer_config']
        if 'use_decoded' in config.keys():
            use_decoded = config['use_decoded']
        else:
            use_decoded = True

        return cls(output_dim, layer_config, use_decoded)

Due to the build() function all participating layers have their variables created then the entire model is built. Therefore I do not know why this error occurs:

        frae_decoded = self.frae(encoded) # we handle bypass inside the model

       ValueError: Exception encountered when calling layer 'frae' (type FRAE).
    
    Creating variables on a non-first call to a function decorated with tf.function.

Does anyone have an idea? Thank you :slight_smile:

Hi @Cola_Lightyear, This error occurs because tf.function builds the computation graph on the first call, which includes variable creation. tf.function only allows creating new tf.Variable objects when it is called for the first time. When the function runs again,It is trying to create new variables, which is not allowed. So, Ensure to create tf.Variable s outside of tf.function. Thanks!