Hi, I am trying to compile the call function of a recurrent network with tf.function to speed it up. However, despite my efforts, TensorFlow complains that variables are created after the first call. I fail to see where these variables could possibly be created, so I hope someone here has an idea. The model TensorFlow complains about is this one (a short, simplified sketch of how I build and invoke it follows the class definitions further down):
import copy
import numpy as np
import tensorflow as tf
from tensorflow.keras import regularizers

@tf.keras.saving.register_keras_serializable()
class FRAE(tf.keras.Model):
def __init__(self, output_dim, latent_dim, ht, layer_config={"activations":[],"neurons":[]}, bypass=False, vq_config=None, use_decoded=True, apply_mse_loss=False, mse_weight=1.0, window_size=5, **kwargs):
super(FRAE, self).__init__(**kwargs)
self.output_dim = output_dim
self.latent_dim = latent_dim
self.ht = ht
self.layer_config = copy.deepcopy(layer_config)
self.vq_config = vq_config
self.encoder = FRAEEncoder(output_dim, latent_dim, layer_config, name=self.name+"_Encoder")
self.decoder = FRAEDecoder(output_dim, layer_config, use_decoded=use_decoded, name=self.name+"_Decoder")
self.quantizer = VectorQuantizationFlattened(vq_config=vq_config, name=self.name+"_VQ")
self.bypass = bypass
self.use_decoded = use_decoded
self.apply_mse_loss = apply_mse_loss
self.mse_weight = mse_weight
self.window_size = window_size
self.timesteps = 7999
# self.SetRandomStartIdx()
def _get_input_signature(self):
return [tf.TensorSpec(dtype=tf.float32, shape=[None, None, self.output_dim])]
def SetBypass(self, bypass):
self.bypass = bypass
def SetupFRAELayers(self, latent_dim, layer_config):
self.encoder.SetupLayers(self.output_dim, latent_dim, layer_config)
self.decoder.SetupLayers(self.output_dim, layer_config)
self.layer_config = copy.deepcopy(layer_config)
self.latent_dim = latent_dim
def CreateQuantizer(self, vq_config):
self.vq_config = vq_config
self.quantizer.CreateCodebook(vq_config)
def IncreaseWindowSize(self, factor=2):
self.window_size = min(self.window_size * factor, self.timesteps)
# self.SetRandomStartIdx()
tf.print(f"New window size: ", self.window_size)
def build(self, input_shape):
batch_size = input_shape[0] or 1
dummy_input = tf.zeros([batch_size, 1, self.output_dim])
dummy_state = tf.zeros([batch_size, 1, self.output_dim * self.ht])
dec = self.encoder(dummy_input, dummy_state)
dec_q = self.quantizer(dec)
decoded, new_state = self.decoder(dec_q, dummy_state)
super().build(input_shape)  # important for Keras
# @tf.function(experimental_compile=True)
def call(self, x, training=False):
if self.bypass:
return x
else:
if self.window_size < 7999:
start_idx = tf.cond(tf.less(self.window_size, 7999),
lambda: tf.random.uniform(shape=(), minval=0, maxval=7999 - self.window_size, dtype=tf.int32),
lambda: tf.constant(0, dtype=tf.int32))
state = tf.zeros(shape=(tf.shape(x)[0], 1, tf.shape(x)[2] * self.ht))
timesteps = tf.shape(x)[1]
selected_timesteps = tf.range(start_idx, start_idx + self.window_size)
outputs = tf.TensorArray(tf.float32, timesteps)
i = selected_timesteps[0]
def cond(i, x, outputs, state, selected_timesteps):
return tf.less(i, selected_timesteps[-1])
def body(i, x, outputs, state, selected_timesteps):
xin = tf.slice(x, [0, i, 0], [-1, 1, -1])
encoded = self.encoder(xin, state)
encoded_q = self.quantizer(encoded)
decoded, state = self.decoder(encoded_q, state)
outputs = outputs.write(i, decoded)
return tf.add(i,1), x, outputs, state, selected_timesteps
i, x, outputs, state, selected_timesteps = tf.while_loop(cond, body, [i, x, outputs, state, selected_timesteps])
# outputs: 7999 X bs x 1 x latent_dim -> 1 x bs x 7999 x latent_dim -> bs x 7999 x latent_dim
partial_decoded = tf.squeeze(tf.transpose(outputs.stack(), [2, 1, 0, 3]))
slice_A = tf.slice(x, [0, 0, 0], [-1, start_idx, -1])
slice_B = tf.slice(x, [0, start_idx + self.window_size, 0],
[-1, timesteps - (start_idx + self.window_size), -1])
partial_decoded = tf.slice(partial_decoded, [0, start_idx, 0], [-1, self.window_size, -1])
partial_x = tf.slice(x, [0, start_idx, 0], [-1, self.window_size, -1])
decoded = tf.concat([slice_A, partial_decoded, slice_B], axis=1)
if self.apply_mse_loss:
self.add_loss(self.mse_weight * tf.reduce_mean(tf.keras.losses.MSE(partial_decoded, partial_x)))
self.add_metric(tf.reduce_mean(tf.keras.losses.MSE(partial_x, partial_decoded)), name='mse_' + self.name)
# Else tf cannot deduce output shape due to random segments
decoded.set_shape([None, None, x.shape[-1]])
return decoded
else:
state = tf.zeros(shape=(tf.shape(x)[0], 1, tf.shape(x)[2] * self.ht))
timesteps = tf.shape(x)[1]
outputs = tf.TensorArray(tf.float32, timesteps)
i = tf.constant(0)
def cond(i, x, outputs, state, timesteps):
return tf.less(i, timesteps)
def body(i, x, outputs, state, timesteps):
xin = tf.slice(x, [0, i, 0], [-1, 1, -1])
encoded = self.encoder(xin, state)
encoded_q = self.quantizer(encoded)
decoded, state = self.decoder(encoded_q, state)
outputs = outputs.write(i, decoded)
return tf.add(i, 1), x, outputs, state, timesteps
i, x, outputs, state, timesteps = tf.while_loop(cond, body,
[i, x, outputs, state, timesteps])
decoded = tf.squeeze(tf.transpose(outputs.stack(), [2, 1, 0, 3]))
if self.apply_mse_loss:
self.add_loss(self.mse_weight * tf.reduce_mean(tf.keras.losses.MSE(decoded, x)))
self.add_metric(tf.reduce_mean(tf.keras.losses.MSE(decoded, x)), name='mse_' + self.name)
return decoded
FRAEEncoder, FRAEDecoder and the vector quantizer are defined as follows:
@tf.function
def safe_norm(x, epsilon=1e-9, axis=None, keepdims=False):
return tf.sqrt(tf.reduce_sum(x ** 2, axis=axis, keepdims=keepdims) + epsilon)
@tf.keras.saving.register_keras_serializable()
class VectorQuantizationFlattened(tf.keras.layers.Layer):
def __init__(self, vq_config, **kwargs):
super(VectorQuantizationFlattened, self).__init__(**kwargs)
self.vq_config = vq_config
self.num_embeddings = vq_config['num_embeddings']
self.embedding_dim = vq_config['embedding_dim']
# self.CreateCodebook(vq_config)
self.isRecording = False
self.eps = 1e-12
self.eps_noise = 1e-2
self.codebook = self.add_weight(shape=(self.num_embeddings, self.embedding_dim),
initializer=tf.keras.initializers.GlorotNormal(seed=None),
trainable=True,
name='embeddings')
self.codebooks_used = self.add_weight(
shape=(self.num_embeddings,),
initializer=tf.keras.initializers.Zeros(),
trainable=False,
name="codebooks_used",
dtype=tf.uint64  # keep consistent with the tf.uint64 cast in call()
)
self.discarding_threshold = 0.25 / self.num_embeddings
def CreateCodebook(self, vq_config):
self.vq_config = vq_config
self.num_embeddings = vq_config['num_embeddings']
self.embedding_dim = vq_config['embedding_dim']
# self.build(None)
if self.num_embeddings > 0:
self.codebook = self.add_weight(shape=(self.num_embeddings, self.embedding_dim),
initializer=tf.keras.initializers.GlorotNormal(seed=None),
trainable=True,
name='embeddings')
# if weights is not None:
# self.codebook.assign(weights)
# self.codebooks_used = tf.zeros_like(self.num_embeddings, dtype=tf.int32)
self.codebooks_used = self.add_weight(
shape=(self.num_embeddings,),
initializer=tf.keras.initializers.Zeros(),
trainable=False,
name="codebooks_used",
dtype=tf.uint64
)
self.discarding_threshold = 0.25/self.num_embeddings # we init with appr. uniform distribution
else:
self.codebook = None
self.codebooks_used = None
self.discarding_threshold = None
def call(self, inputs, training=None):
if self.codebook is not None:
input_shape = tf.shape(inputs)
flattened = tf.reshape(inputs, [-1, self.embedding_dim])
if training == False:
# distances = tf.reduce_sum(tf.square(tf.expand_dims(inputs, axis=1) - self.codebook), axis=-1)
distances = (
tf.reduce_sum(tf.square(flattened), axis=1, keepdims=True)
- 2 * tf.matmul(flattened, tf.transpose(self.codebook))
+ tf.expand_dims(tf.reduce_sum(tf.square(self.codebook), axis=1), axis=0)
)
hard_assignments = tf.one_hot(tf.argmin(distances, axis=-1), self.num_embeddings)
quantized_hard = tf.matmul(hard_assignments, self.codebook)
quantized_hard = tf.reshape(quantized_hard, input_shape)
return quantized_hard
# quantized = quantized_hard
else:
input_shape = tf.shape(inputs)
flattened = tf.reshape(inputs, [-1, self.embedding_dim])
distances = (
tf.reduce_sum(tf.square(flattened), axis=1, keepdims=True)
- 2 * tf.matmul(flattened, tf.transpose(self.codebook))
+ tf.expand_dims(tf.reduce_sum(tf.square(self.codebook), axis=1), axis=0)
)
min_indices = tf.argmin(distances, axis=1) # Shape: (N,)
hard_quantized_input = tf.gather(self.codebook, min_indices) # Shape: (N, D)
hard_quantized_input = tf.reshape(hard_quantized_input, input_shape)
random_vector = tf.random.normal(tf.shape(inputs), mean=0.0, stddev=1.0)
norm_quantization_residual = safe_norm(inputs - hard_quantized_input, axis=1, keepdims=True) # (N, 1)
norm_random_vector = safe_norm(random_vector, axis=1, keepdims=True) # (N, 1)
vq_error = tf.math.divide_no_nan(norm_quantization_residual, (norm_random_vector + self.eps)) * random_vector # (N, D)
quantized_input = inputs + vq_error
encodings = tf.one_hot(min_indices, depth=self.num_embeddings, dtype=tf.float32)
self.codebooks_used.assign_add(tf.cast(tf.reduce_sum(encodings, axis=0), tf.uint64))
return quantized_input #, perplexity, self.codebooks_used
else:
return inputs
def replace_unused_codebooks(self):
usage_ratio = self.codebooks_used / tf.math.reduce_sum(self.codebooks_used)
unused_indices = tf.reshape(tf.where(usage_ratio < self.discarding_threshold), [-1])
used_indices = tf.reshape(tf.where(usage_ratio >= self.discarding_threshold), [-1])
unused_count = tf.size(unused_indices)
used_count = tf.size(used_indices)
if used_count == 0:
tf.print("####### used_indices equals zero / shuffling whole codebooks ######")
noise = self.eps_noise * tf.random.normal(tf.shape(self.codebook), dtype=self.codebook.dtype)
self.codebook.assign_add(noise)
else:
used_p = tf.gather(usage_ratio, used_indices)
# NOTE: NumPy sampling works only in eager mode, so this method must be called outside tf.function
new_indices = np.random.choice(used_indices, unused_count, replace=True, p=used_p / tf.math.reduce_sum(used_p))
noise = self.eps_noise * tf.random.normal([unused_count, self.embedding_dim], dtype=self.codebook.dtype)
new_values = tf.gather(self.codebook, new_indices) + noise
updated_codebooks = tf.tensor_scatter_nd_update(self.codebook,
tf.expand_dims(unused_indices, axis=1),
new_values)
self.codebook.assign(updated_codebooks)
self.eps_noise = max(self.eps_noise * 0.95, 1e-5)
self.discarding_threshold = self.discarding_threshold * 0.95
tf.print("************* Replaced", unused_count, "codebooks *************")
# reset the usage counters:
self.codebooks_used.assign(tf.zeros_like(self.codebooks_used, dtype=self.codebooks_used.dtype))
def get_config(self):
config = super().get_config()
config.update({
'vq_config': self.vq_config
})
return config
@classmethod
def from_config(cls, config):
return cls(**config)
@tf.keras.saving.register_keras_serializable()
class FRAEEncoder(tf.keras.layers.Layer):
def __init__(self, input_shape, latent_dim, layer_config, **kwargs):
super(FRAEEncoder, self).__init__(**kwargs)
self.myinput_shape = input_shape
self.latent_dim = latent_dim
self.layer_config = copy.deepcopy(layer_config)
self.SetupLayers(input_shape, latent_dim, layer_config)
def SetupLayers(self, input_shape, latent_dim, layer_config):
self.encoder = []
activations = layer_config["activations"]
num_neuron = layer_config["neurons"]
l2reg = 1e-3
if 'l2' in layer_config.keys():
l2reg = layer_config['l2']
for i, act in enumerate(activations):
if i == 0:
self.encoder.append(tf.keras.layers.Dense(num_neuron[i], activation=None, input_shape=(input_shape,), kernel_regularizer=regularizers.l2(l2reg), name='dense'))
# self.encoder.append(tf.keras.layers.Activation(activation=act))
else:
self.encoder.append(tf.keras.layers.Dense(num_neuron[i], activation=None, kernel_regularizer=regularizers.l2(l2reg), name='dense'))
# self.encoder.append(tf.keras.layers.Activation(activation=act))
self.encoder.append(tf.keras.layers.LayerNormalization(epsilon=1e-5))
self.encoder.append(tf.keras.layers.Activation(activation=act, name='activation'))
if len(activations) > 0:
if "LatentAct" in layer_config.keys():
self.encoder.append(tf.keras.layers.Dense(latent_dim, activation=None , kernel_regularizer=regularizers.l2(l2reg))) #layer_config["LatentAct"]
self.encoder.append(tf.keras.layers.LayerNormalization(epsilon=1e-5))
self.encoder.append(tf.keras.layers.Activation(activation=layer_config["LatentAct"], name='activation'))
else:
self.encoder.append(tf.keras.layers.Dense(latent_dim, activation=None, kernel_regularizer=regularizers.l2(l2reg)))
self.encoder.append(tf.keras.layers.LayerNormalization(epsilon=1e-5))
self.encoder.append(tf.keras.layers.Activation(activation="swish", name='activation'))
#self.encoder.append(tf.keras.layers.LayerNormalization(epsilon=1e-5))
else:
if "LatentAct" in layer_config.keys():
self.encoder.append(tf.keras.layers.Dense(latent_dim, activation=None, input_shape=(input_shape,), kernel_regularizer=regularizers.l2(l2reg)))
self.encoder.append(tf.keras.layers.LayerNormalization(epsilon=1e-7))
self.encoder.append(tf.keras.layers.Activation(activation=layer_config["LatentAct"], name='activation'))
else:
self.encoder.append(tf.keras.layers.Dense(latent_dim, activation=None, input_shape=(input_shape,), kernel_regularizer=regularizers.l2(l2reg)))
self.encoder.append(tf.keras.layers.LayerNormalization(epsilon=1e-7))
self.encoder.append(tf.keras.layers.Activation(activation="swish", name='activation'))
#self.encoder.append(tf.keras.layers.LayerNormalization(epsilon=1e-5))
def call(self, inputs, state):
# encoder
encoded = tf.concat([inputs, state], axis=-1)
for k, lrs in enumerate(self.encoder):
encoded = lrs(encoded)
if 'activation' in lrs.name:
encoded = tf.concat([encoded, state], axis=-1)
new_output = encoded
return new_output
def get_config(self):
config = super().get_config()
config.update({
'myinput_shape': self.myinput_shape,
'latent_dim': self.latent_dim,
'layer_config': self.layer_config
})
return config
@classmethod
def from_config(cls, config):
input_shape = config['myinput_shape']
latent_dim = config['latent_dim']
layer_config = config['layer_config']
return cls(input_shape, latent_dim, layer_config)
@tf.keras.saving.register_keras_serializable()
class FRAEDecoder(tf.keras.layers.Layer):
def __init__(self, output_dim, layer_config, use_decoded = True, **kwargs):
super(FRAEDecoder, self).__init__(**kwargs)
self.output_dim = output_dim
self.layer_config = copy.deepcopy(layer_config)
self.use_decoded = use_decoded
self.SetupLayers(output_dim, layer_config)
def SetupLayers(self, output_dim, layer_config):
self.decoder = []
activations = layer_config["activations"]
num_neuron = layer_config["neurons"]
l2reg = 1e-3
if 'l2' in layer_config.keys():
l2reg = layer_config['l2']
for i, act in reversed(list(enumerate(activations))):
self.decoder.append(tf.keras.layers.Dense(num_neuron[i],activation=None, kernel_regularizer=regularizers.l2(l2reg), name="dense") ) # kernel_initializer=tf.keras.initializers.Orthogonal()
self.decoder.append(tf.keras.layers.LayerNormalization(epsilon=1e-7))
self.decoder.append(tf.keras.layers.Activation(activation=act, name='activation'))
self.decoder.append(tf.keras.layers.Dense(output_dim, kernel_regularizer=regularizers.l2(l2reg), activation='linear'))
def call(self, encoded, state):
y = tf.concat([encoded, state], axis=-1)
for i, lrs in enumerate(self.decoder):
y = lrs(y)
if 'activation' in lrs.name:
y = tf.concat([y, state], axis=-1)
#update output and state
new_output = y  # output of the final linear Dense layer
new_state = tf.concat([new_output, state[:, :, :-tf.shape(new_output)[-1]]], axis=-1)
return new_output, new_state
def get_config(self):
config = super().get_config()
config.update({
'output_dim': self.output_dim,
'layer_config': self.layer_config,
'use_decoded': self.use_decoded
})
return config
@classmethod
def from_config(cls, config):
output_dim = config['output_dim']
layer_config = config['layer_config']
if 'use_decoded' in config.keys():
use_decoded = config['use_decoded']
else:
use_decoded = True
return cls(output_dim, layer_config, use_decoded)
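To give an idea of the usage: below is roughly how I build the model and compile the forward pass. This is heavily simplified, all shapes and config values are placeholders, and in the real code the FRAE sits inside a larger model, which is where the self.frae(encoded) line in the error below comes from.

# Simplified usage sketch; the dimensions and configs here are placeholders.
layer_config = {"activations": ["swish"], "neurons": [64]}
vq_config = {"num_embeddings": 64, "embedding_dim": 4}
frae = FRAE(output_dim=12, latent_dim=8, ht=4,
            layer_config=layer_config, vq_config=vq_config)
frae.build((None, 7999, 12))  # all sub-layer variables should exist after this

@tf.function  # compiling the forward pass like this is what triggers the error
def forward(x):
    return frae(x, training=True)

y = forward(tf.random.normal((2, 7999, 12)))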
Due to the build() function, all participating layers should already have their variables created once the entire model is built. Therefore I do not understand why this error occurs:
frae_decoded = self.frae(encoded)  # we handle bypass inside the model
ValueError: Exception encountered when calling layer 'frae' (type FRAE). Creating variables on a non-first call to a function decorated with tf.function.
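For reference, here is a minimal example (completely unrelated to my model) that, as far as I understand, trips the same kind of check, just to illustrate what the message refers to; the exact wording may differ between TF versions:

class LeakyLayer(tf.keras.layers.Layer):
    def call(self, x):
        # a weight created inside call() means every retrace creates a new variable
        w = self.add_weight(shape=(), initializer="ones", name="w")
        return x * w

@tf.function
def run(layer, x):
    return layer(x)

layer = LeakyLayer()
run(layer, tf.constant([1.0]))         # first trace: creating the variable is allowed
run(layer, tf.constant([[1.0, 2.0]]))  # new input rank forces a retrace -> ValueError about
                                       # creating variables on a non-first call

But as far as I can tell, my model does not do this, since everything is created in __init__ and build.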
Does anyone have an idea? Thank you
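If it helps, this is the kind of sanity check I can run eagerly (using the frae instance from the sketch above) to see whether any new variables appear between two calls; I would expect both counts to be identical:

# eager sanity check: the variable count should not change between two calls
n_before = len(frae.variables)
_ = frae(tf.random.normal((2, 7999, 12)), training=True)
_ = frae(tf.random.normal((2, 7999, 12)), training=True)
n_after = len(frae.variables)
print(n_before, n_after)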