Memory leak during training

I have the following model:

import tensorflow as tf

class FRAE(tf.keras.Model):
    def __init__(self, latent_dim, shape, ht, n1, n2, n3, n4, bypass=False, trainable=True,**kwargs):
        super(FRAE, self).__init__(**kwargs)
        self.latent_dim = latent_dim
        self.shape = shape
        self.ht = ht
        self.buffer = tf.Variable(initial_value=tf.zeros(shape=(1,shape[0] * self.ht), dtype=tf.float32), trainable=False)
        # self.bufferlist =  [tf.zeros((1, shape[0]), dtype=tf.float32) for _ in range(ht)] #[tf.zeros(shape=(1,shape[0] ), dtype=tf.float32)] * ht
        self.bypass = bypass
        self.quantizer = None
        self.trainable = trainable
        
        self.l1 = tf.keras.layers.Dense(n1, activation='tanh', input_shape=shape)
        self.l2 = tf.keras.layers.Dense(n1, activation='tanh')
        self.ls = tf.keras.layers.Dense(latent_dim, activation='swish')

        self.l3 = tf.keras.layers.Dense(n3, activation='tanh')
        self.l4 = tf.keras.layers.Dense(n4, activation='tanh')
        self.l5 = tf.keras.layers.Dense(shape[-1], activation='linear')


    def get_config(self):
        config = super(FRAE, self).get_config().copy()
        config.update({'latent_dim': self.latent_dim, 'bypass': self.bypass, 'quantizer': self.quantizer,
                       'buffer': self.buffer, 'ht': self.ht, 'shape': self.shape, 'name': self.name})
        return config
          
    @tf.function(experimental_compile=True)
    def update_buffer(self, new_element):
        n = self.shape[0]
        # new_element_expanded = tf.expand_dims(new_element, axis=0)
        self.buffer.assign(tf.keras.backend.concatenate([new_element, self.buffer[:, :-n]], axis=1))

    @tf.function(experimental_compile=True)
    def resetBuffer(self):
        self.buffer[:,:].assign(tf.zeros(shape=(1,self.shape[0] * self.ht), dtype=tf.float32))

    def tensor_to_numpy(self, t):
        return tf.identity(t).numpy()

    @tf.function(experimental_compile=True)
    def call(self, x):        
        x = tf.squeeze(x,axis=0)
        decoded = tf.TensorArray(tf.float32, size=tf.shape(x)[0])
        for i in tf.range(tf.shape(x)[0]):
            xexpand = tf.expand_dims(x[i],axis=0)
            xin = tf.concat((xexpand, self.buffer), axis=1)
            encoded = self.ls(self.l2(self.l1(xin)))
            decin = tf.concat([encoded, self.buffer], axis=1)
            y = self.l5(self.l4(self.l3(decin)))
            decoded = decoded.write(i,y)
            # self.update_buffer(tf.squeeze(y))
            self.update_buffer(y)


        tmp = tf.transpose(decoded.stack(),[1,0,2])
        return tmp
    
   
    @tf.function(experimental_compile=True)
    def train_step(self, data):
        # Unpack the data. Its structure depends on your model and
        # on what you pass to `fit()`.
        x, y = data
    
        with tf.GradientTape() as tape:
            y_pred = self(x, training=True)  # Forward pass
            # Compute the loss value
            # (the loss function is configured in `compile()`)
            loss = self.compute_loss(y=y, y_pred=y_pred)
    
        # Compute gradients
        trainable_vars = self.trainable_variables
        gradients = tape.gradient(loss, trainable_vars)
        # Update weights
        self.optimizer.apply_gradients(zip(gradients, trainable_vars))
        # Update metrics (includes the metric that tracks the loss)
        for metric in self.metrics:
            if metric.name == "loss":
                metric.update_state(loss)
            else:
                metric.update_state(y, y_pred)
        # Return a dict mapping metric names to current value
        return {m.name: m.result() for m in self.metrics}

which now finally runs sufficiently fast on the CPU. However, when I train the model on the CPU, RAM usage increases monotonically by about 100 MB every 3-4 seconds. I suspect the culprit is this line:

    self.buffer.assign(tf.keras.backend.concatenate([new_element, self.buffer[:, :-n]], axis=1))

Do you have any idea how to fix this? I tried calling gc.collect(), but to no avail.
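
To put numbers on the growth, a small callback can log the resident memory of the process every few batches during fit() (a minimal sketch; psutil is assumed to be installed, it is not part of TensorFlow):

    import os
    import psutil
    import tensorflow as tf

    class MemoryLogger(tf.keras.callbacks.Callback):
        """Print the resident set size (RSS) of the training process every few batches."""

        def __init__(self, every_n_batches=100):
            super().__init__()
            self.every_n_batches = every_n_batches
            self.process = psutil.Process(os.getpid())

        def on_train_batch_end(self, batch, logs=None):
            if batch % self.every_n_batches == 0:
                print(f"batch {batch}: RSS = {self.process.memory_info().rss / 1e6:.1f} MB")

    # model.fit(x, y, callbacks=[MemoryLogger()], ...)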

OK, I tested it, and this does NOT happen during inference. So it is related to training and thus cannot really be due to the assign call I mentioned. Does anyone have an idea? Perhaps it is the train_step.
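
One way to check that would be to drive train_step directly on a single fixed batch, outside of fit(), and watch whether memory still grows (a sketch; model is assumed to be compiled already, and x_batch / y_batch stand in for one real input/target batch):

    import os
    import psutil

    process = psutil.Process(os.getpid())

    for step in range(1000):
        model.train_step((x_batch, y_batch))  # the same per-batch code path that fit() runs
        if step % 100 == 0:
            print(step, process.memory_info().rss / 1e6, "MB")

If memory stays flat in this loop, the growth would have to come from the surrounding fit() machinery rather than from train_step itself.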

Do you think the TensorBoard Profiler could help? It includes tools to inspect what is going on in memory down to individual ops (Optimize TensorFlow performance using the Profiler | TensorFlow Core).

Pasting an excerpt from the official TensorFlow documentation:

+++++
The pop-up window displays the following information:

  • timestamp(ms): The location of the selected event on the timeline.
  • event: The type of event (allocation or deallocation).
  • requested_size(GiBs): The amount of memory requested. This will be a negative number for deallocation events.
  • allocation_size(GiBs): The actual amount of memory allocated. This will be a negative number for deallocation events.
  • tf_op: The TensorFlow op that requests the allocation/deallocation.
  • step_id: The training step in which this event occurred.
  • region_type: The data entity type that this allocated memory is for. Possible values are temp for temporaries, output for activations and gradients, and persist/dynamic for weights and constants.
  • data_type: The tensor element type (e.g., uint8 for 8-bit unsigned integer).
  • tensor_shape: The shape of the tensor being allocated/deallocated.
  • memory_in_use(GiBs): The total memory that is in use at this point of time.

+++++
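
To get data into that memory profile tool, one option is to profile a few training batches through the TensorBoard callback and then open the Profile tab (a sketch; the log directory and batch range are arbitrary):

    import tensorflow as tf

    # Capture a profile for batches 10-20 of the next fit() call.
    tb_callback = tf.keras.callbacks.TensorBoard(
        log_dir="logs/frae_profile",
        profile_batch=(10, 20),
    )

    # model.fit(x, y, epochs=1, callbacks=[tb_callback])
    # afterwards: tensorboard --logdir logs/frae_profile
    # then open the Profile tab -> memory_profile tool.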