I’m getting a memory leak, and I believe it is linked to the following warning:
WARNING:tensorflow:6 out of the last 6 calls to <function _BaseOptimizer._update_step_xla at 0x7fa9f8074c20> triggered tf.function retracing. Tracing is expensive and the excessive number of tracings could be due to (1) creating @tf.function repeatedly in a loop, (2) passing tensors with different shapes, (3) passing Python objects instead of tensors. For (1), please define your @tf.function outside of the loop. For (2), @tf.function has reduce_retracing=True option that can avoid unnecessary retracing. For (3), please refer to Better performance with tf.function | TensorFlow Core and tf.function | TensorFlow v2.16.1 for more details.
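For reference, suggestion (2) in the warning points at the reduce_retracing option of tf.function. A minimal sketch of how that option is used, with a placeholder function that is not part of my code:

import tensorflow as tf

@tf.function(reduce_retracing=True)
def double(x):  # placeholder for illustration only
    return x * 2.0

double(tf.constant([1.0]))            # first call: traces for shape [1]
double(tf.constant([1.0, 2.0]))       # new shape: retraces once with a relaxed shape [None]
double(tf.constant([1.0, 2.0, 3.0]))  # further shapes reuse the relaxed trace

That option doesn't seem directly applicable here, though, because the function being retraced (_BaseOptimizer._update_step_xla) is internal to the Keras optimizer rather than a @tf.function I define myself.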
The warning occurs in the following function:
def learn(self):
    for _ in range(self.n_epochs):
        # Pull the stored rollout out of memory
        (state_arr, additional_info, action_arr, old_prob_arr, values,
         reward_arr, _, trades_complete, env_states, batches) = self.memory.generate_batches()

        # One-step TD residuals, with bootstrapping masked out on completed trades
        reward_diff = reward_arr[:-1] + values[1:] * (1 - tf.cast(trades_complete[:-1], dtype=tf.float32)) - values[:-1]
        # Advantage as a reversed cumulative sum of the residuals, zero-padded for the last step
        advantage = tf.concat([tf.cumsum(reward_diff, reverse=True), self.zero_tensor], axis=0)

        with tf.GradientTape(persistent=True) as tape:
            new_probs, new_val = self.cnn_actor_critic([state_arr, additional_info])
            masked_new_probs = ENVIRONMENT.mass_apply_mask(new_probs.numpy(), env_states)

            # Gather the probability the current policy assigns to each stored action
            rows = tf.range(tf.shape(masked_new_probs)[0])
            index_arr = tf.add(tf.cast(action_arr, dtype=tf.int32), self.one_val)
            gather_indices = tf.stack([rows, index_arr], axis=1)
            chosen_probs = tf.gather_nd(masked_new_probs, gather_indices)
            new_log_probs_of_old_actions = tf.negative(tf.math.log(chosen_probs))

            critic_value = tf.squeeze(new_val, 1)  # drop the size-1 axis 1
            values_list = tf.convert_to_tensor(values)
            returns = tf.add(advantage, values_list)
            critic_loss = tf.keras.losses.MSE(critic_value, returns)

            # Clipped surrogate objective
            prob_ratio = tf.math.exp(tf.add(new_log_probs_of_old_actions, tf.cast(tf.negative(old_prob_arr), dtype=tf.float32)))
            weighted_probs = tf.multiply(advantage, prob_ratio)
            clipped_probs = tf.clip_by_value(prob_ratio, 0.8, 1.2)
            weighted_clipped_probs = tf.multiply(clipped_probs, advantage)
            l_clip = tf.math.minimum(weighted_probs, weighted_clipped_probs)  # previously actor_loss

            entropy_term = tf.negative(tf.reduce_sum(tf.multiply(new_log_probs_of_old_actions, tf.math.log(new_log_probs_of_old_actions))))
            l_q_extension = tf.add(tf.multiply(self.c1, critic_loss), tf.negative(tf.multiply(self.c2, entropy_term)))
            l_q = tf.negative(tf.add(l_clip, l_q_extension))
            actor_critic_cnn_loss = tf.math.reduce_mean(l_q)

        cnn_actor_critic_params = self.cnn_actor_critic.trainable_variables
        actor_critic_grads = tape.gradient([actor_critic_cnn_loss, critic_loss], cnn_actor_critic_params,
                                           unconnected_gradients=tf.UnconnectedGradients.ZERO)
        self.cnn_actor_critic.optimizer.apply_gradients(zip(actor_critic_grads, cnn_actor_critic_params))

    self.memory.clear_memory()
More specifically, the warning is triggered by the second-to-last line: self.cnn_actor_critic.optimizer.apply_gradients(zip(actor_critic_grads, cnn_actor_critic_params)).
When that line is commented out, the warning no longer appears.
How do I alter the code to avoid the warning and the associated memory leak?
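For example, would moving the whole update into a single tf.function defined once, as suggestion (1) recommends, be the right change? A sketch of the pattern I have in mind, where the model and optimizer are stand-ins rather than my actual actor-critic network:

import tensorflow as tf

model = tf.keras.Sequential([tf.keras.layers.Dense(1)])
optimizer = tf.keras.optimizers.Adam()

@tf.function(reduce_retracing=True)
def train_step(x, y):
    # The tape and the optimizer update both live inside one tf.function
    # defined once, so the update step is traced a single time and reused.
    with tf.GradientTape() as tape:
        loss = tf.reduce_mean(tf.square(model(x) - y))
    grads = tape.gradient(loss, model.trainable_variables)
    optimizer.apply_gradients(zip(grads, model.trainable_variables))
    return loss

x = tf.random.normal([8, 4])
y = tf.random.normal([8, 1])
for _ in range(10):
    train_step(x, y)  # only the first call traces

If that is the right direction, I'm unsure how it interacts with the new_probs.numpy() call that ENVIRONMENT.mass_apply_mask needs, since .numpy() isn't available inside a traced function.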