Hi, I am still working with a special recurrent model, which requires me (as far as I know; if you have another idea, let me know) to store previous outputs of a batch in a buffer. In general, the last batch of my dataset is smaller than the nominal batch size. Because of this, unless I take special precautions (e.g. dropping the last batch), I get a runtime error. I am looking for a way to adapt the buffer size to the current batch size.
My code is as follows:
import tensorflow as tf

class FRAE(tf.keras.Model):
    def __init__(self, latent_dim, shape, ht, n1, n2, n3, n4, batch_size=1, bypass=True, trainable=True, **kwargs):
        super(FRAE, self).__init__(**kwargs)
        self.latent_dim = latent_dim
        self.shape = shape
        self.ht = ht
        self.batch_size = batch_size
        self.SetupBuffer(batch_size, shape[0], ht)
        self.bypass = bypass
        self.trainable = trainable
        # Encoder layers
        self.l1 = tf.keras.layers.Dense(n1, activation='tanh', input_shape=shape)
        self.l2 = tf.keras.layers.Dense(n2, activation='tanh')
        self.ls = tf.keras.layers.Dense(latent_dim, activation='swish')
        # Decoder layers
        self.l3 = tf.keras.layers.Dense(n3, activation='tanh')
        self.l4 = tf.keras.layers.Dense(n4, activation='tanh')
        self.l5 = tf.keras.layers.Dense(shape[-1], activation='linear')

    def SetupBuffer(self, batch_size, input_dim, ht):
        # Holds the last ht outputs, flattened along the feature axis.
        self.buffer = tf.Variable(
            initial_value=tf.zeros(shape=(batch_size, input_dim * ht), dtype=tf.float32),
            trainable=False)

    def call(self, x):
        if self.bypass:
            return x
        decoded = tf.TensorArray(tf.float32, size=tf.shape(x)[1])
        for i in tf.range(tf.shape(x)[1]):
            xslice = x[:, i, :]
            xin = tf.concat((xslice, self.buffer), axis=1)
            encoded = self.ls(self.l2(self.l1(xin)))
            decin = tf.concat([encoded, self.buffer], axis=1)
            y = self.l5(self.l4(self.l3(decin)))
            decoded = decoded.write(i, y)
            # Shift the buffer: newest output in front, oldest slot dropped.
            self.buffer.assign(tf.concat([y, self.buffer[:, :-self.shape[0]]], axis=1))
        tmp = tf.transpose(decoded.stack(), [1, 0, 2])
        return tmp
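To make the failure concrete, the error comes from the concat inside call once the final batch is smaller than the buffer's batch dimension (the sizes below are just illustrative):

import tensorflow as tf

buffer = tf.zeros((32, 8))      # buffer built for the nominal batch size of 32
x_last = tf.zeros((20, 1, 4))   # the final batch only contains 20 samples
tf.concat((x_last[:, 0, :], buffer), axis=1)  # InvalidArgumentError: batch dims 20 vs. 32 disagree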
I would like to do something like

def call(self, x):
    # Rebuild the buffer so its batch dimension matches the incoming batch.
    self.SetupBuffer(tf.shape(x)[0], self.shape[0], self.ht)
    ...

However, this does not run, because self.buffer cannot be reassigned from within the call method. None of the approaches I tried succeeded. Is there some way I can dynamically adjust the shape/size of self.buffer so that it matches x in dimension 0?
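One idea I have come across but not yet verified is to give the variable a partially defined shape via the shape argument of tf.Variable, which should allow assign to change the batch dimension later. A rough sketch of what I mean:

def SetupBuffer(self, batch_size, input_dim, ht):
    # shape=[None, ...] leaves the batch dimension unspecified, so later
    # assign() calls may use a different batch size.
    self.buffer = tf.Variable(
        initial_value=tf.zeros((batch_size, input_dim * ht), dtype=tf.float32),
        trainable=False,
        shape=tf.TensorShape([None, input_dim * ht]))

def call(self, x):
    if self.bypass:
        return x
    # Reset the buffer with zeros whenever the incoming batch size differs.
    if tf.shape(x)[0] != tf.shape(self.buffer)[0]:
        self.buffer.assign(tf.zeros((tf.shape(x)[0], self.shape[0] * self.ht)))
    ...

I do not know how well this interacts with tf.function tracing, though.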
The only workaround I have found so far is to use a callback during training that calls SetupBuffer right before a batch.
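Roughly what I have in mind (the class name is mine, it is untested, and it assumes I know the dataset length up front):

class BufferResizeCallback(tf.keras.callbacks.Callback):
    """Rebuild the model's buffer before batches whose size differs from the nominal one."""
    def __init__(self, num_samples, batch_size):
        super().__init__()
        self.batch_size = batch_size
        self.last_batch = num_samples // batch_size   # index of the final, smaller batch
        self.remainder = num_samples % batch_size     # size of that final batch, if any

    def on_train_batch_begin(self, batch, logs=None):
        if batch == 0:
            # Restore the nominal buffer size at the start of every epoch.
            self.model.SetupBuffer(self.batch_size, self.model.shape[0], self.model.ht)
        if self.remainder and batch == self.last_batch:
            self.model.SetupBuffer(self.remainder, self.model.shape[0], self.model.ht)

However, recreating a tf.Variable between batches feels fragile to me, which is why I am asking for alternatives.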
Any ideas are welcome.