I have the following model: an autoencoder with feedback to the input layers of both the encoder and the decoder. It is currently very slow because of the for loop — but it seems too slow even accounting for that. Is it possible to speed up inference/training?
The model is:
class FRAE(tf.keras.Model):
    """Feedback Recurrent AutoEncoder.

    Both the encoder and the decoder receive, concatenated to their own
    input, a buffer holding the last `ht` decoder outputs. Because each
    time step's decoder output feeds the next step's inputs, decoding is
    inherently sequential over the time dimension.
    """

    def __init__(self, latent_dim, shape, ht, n1, n2, n3, n4, bypass=False,
                 trainable=True, **kwargs):
        """
        Args:
            latent_dim: width of the bottleneck layer.
            shape: per-step feature shape; shape[0] is the feedback width.
            ht: number of past decoder outputs kept in the feedback buffer.
            n1..n4: hidden-layer widths (encoder: n1, n2; decoder: n3, n4).
            bypass: if True, call() returns its input unchanged.
        """
        super(FRAE, self).__init__(**kwargs)
        self.latent_dim = latent_dim
        self.shape = shape
        self.ht = ht
        # Keep the layer sizes so get_config() can fully describe the model.
        self.n1, self.n2, self.n3, self.n4 = n1, n2, n3, n4
        # Feedback buffer: last `ht` decoder outputs, newest first.
        # BUG FIX: marked non-trainable — it is recurrent state, not a weight,
        # and previously appeared in trainable_variables.
        self.buffer = tf.Variable(
            initial_value=tf.zeros(shape=(1, shape[0] * self.ht), dtype=tf.float32),
            trainable=False)
        self.bypass = bypass
        self.quantizer = None
        self.trainable = trainable
        # Encoder: l1 -> l2 -> ls (bottleneck).
        self.l1 = tf.keras.layers.Dense(n1, activation='swish', input_shape=shape)
        # BUG FIX: this layer was built with n1, silently ignoring the n2 argument.
        self.l2 = tf.keras.layers.Dense(n2, activation='swish')
        self.ls = tf.keras.layers.Dense(latent_dim, activation='swish')
        # Decoder: l3 -> l4 -> l5 (linear output matching the input width).
        self.l3 = tf.keras.layers.Dense(n3, activation='swish')
        self.l4 = tf.keras.layers.Dense(n4, activation='swish')
        self.l5 = tf.keras.layers.Dense(shape[-1], activation='linear')

    def get_config(self):
        """Return a JSON-serializable config sufficient to rebuild the model.

        BUG FIX: the previous version referenced self.encoder/self.decoder,
        which were never defined (AttributeError on any save), and embedded
        the tf.Variable buffer, which is not serializable. Only constructor
        arguments are stored now.
        """
        config = super(FRAE, self).get_config().copy()
        config.update({'latent_dim': self.latent_dim,
                       'shape': self.shape,
                       'ht': self.ht,
                       'n1': self.n1, 'n2': self.n2,
                       'n3': self.n3, 'n4': self.n4,
                       'bypass': self.bypass,
                       'quantizer': self.quantizer,
                       'name': self.name})
        return config

    def update_buffer(self, new_element):
        """Shift the feedback buffer: prepend `new_element` (length shape[0])
        and drop the oldest entry."""
        n = self.shape[0]
        new_element_expanded = tf.expand_dims(new_element, axis=0)
        self.buffer.assign(
            tf.keras.backend.concatenate([new_element_expanded, self.buffer[:, :-n]],
                                         axis=1))

    def resetBuffer(self):
        """Zero the feedback buffer (call between independent sequences)."""
        self.buffer[:, :].assign(
            tf.zeros(shape=(1, self.shape[0] * self.ht), dtype=tf.float32))

    @tf.function
    def call(self, x):
        # assumes x is (1, time, features) — TODO confirm; the batch dim is
        # squeezed away and re-added on return.
        if self.bypass is True:
            print("Bypassing FRAE", flush=True)
            return x
        x = tf.squeeze(x, axis=0)
        steps = tf.shape(x)[0]
        decoded = tf.TensorArray(tf.float32, size=steps)
        # NOTE: this loop cannot be vectorized over time — each step's decoder
        # output is fed back into the next step's encoder/decoder inputs.
        for i in tf.range(steps):
            xexpand = tf.expand_dims(x[i], axis=0)
            xin = tf.concat((xexpand, self.buffer), axis=1)
            encoded = self.ls(self.l2(self.l1(xin)))
            decin = tf.concat([encoded, self.buffer], axis=1)
            y = self.l5(self.l4(self.l3(decin)))
            decoded = decoded.write(i, y)
            # BUG FIX: removed the dead `i += 1` — tf.range drives the index;
            # reassigning the loop variable had no effect.
            self.update_buffer(tf.squeeze(y))
        return tf.transpose(decoded.stack(), [1, 0, 2])