How to add a new Keras model during training with tf.function

Hi TF Community,

I’m a maintainer of RLlib, a reinforcement learning framework. I’ve been writing a new distributed training stack for the library using TensorFlow 2.11.

One requirement I have is that, for some types of reinforcement learning training, I need to add variables, or a whole new Keras model, to the existing container of models that is being trained.

I’m using tf.function to speed up the training loop of this new distributed training stack.
Rather than decorating my training loop function with @tf.function, I hold a reference to the traced function by doing something like self.update_function = tf.function(self._my_update_function) and then calling update_function.

I do this because when I add or remove Keras models from my container of models, I re-assign update_function to force retracing, with the line self.update_function = tf.function(self._my_update_function).
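A minimal sketch of the pattern described above. The names ModelContainer, add_model, make_model and trace_count are hypothetical and only for illustration; they are not RLlib's API. It assumes the models are already built before they are added, so that no variables are created while tracing.

```python
import tensorflow as tf


def make_model():
    # Hypothetical toy model: 4 inputs -> 1 output, all-ones kernel, no bias.
    # Passing an Input layer builds the model (creates its variables) up front.
    return tf.keras.Sequential([
        tf.keras.Input(shape=(4,)),
        tf.keras.layers.Dense(1, use_bias=False, kernel_initializer="ones"),
    ])


class ModelContainer:
    """Hypothetical stand-in for a container of Keras models being trained."""

    def __init__(self):
        self.models = {}
        # Python side effect: incremented only while the function is traced.
        self.trace_count = 0
        self.update_function = tf.function(self._my_update_function)

    def add_model(self, name, model):
        self.models[name] = model
        # Re-wrap so the next call traces a fresh graph that sees the new model.
        self.update_function = tf.function(self._my_update_function)

    def _my_update_function(self, x):
        self.trace_count += 1
        total = tf.zeros([], dtype=tf.float32)
        # This Python loop is unrolled at trace time, so the graph only
        # contains the models that were present when tracing happened.
        for model in self.models.values():
            total += tf.reduce_sum(model(x))
        return total


container = ModelContainer()
batch = tf.ones([2, 4])
container.add_model("a", make_model())
first = container.update_function(batch)   # traces with one model
container.add_model("b", make_model())
second = container.update_function(batch)  # traces again with two models
```

Without the re-assignment in add_model, a call with the same input shape and dtype would reuse the old graph and silently ignore the newly added model.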

This seems to have worked, but I’m wondering if any TF devs or experts can chime in and let me know whether there is a better way to force my function to retrace. After I do this, I see warnings that my function has been retraced many times. On one hand that makes sense, since I am deliberately forcing a retrace. On the other hand it doesn’t, because syntactically it looks like I’m creating a brand new traced function object every time I re-assign with self.update_function = tf.function(self._my_update_function).

Thanks!

Hi @Avnish_Narayan,

Sorry for the delay in response.
As far as I’m aware, your approach of creating a new traced function and re-assigning it with self.update_function = tf.function(self._my_update_function) is fine for this RL use case. The warnings appear because the same Python function is being traced frequently; the warning message itself lists creating tf.function repeatedly as one possible cause. So I recommend keeping model changes to a minimum, re-assigning only when the set of models actually changes, and using input_signature to fix the expected input shapes and dtypes, which helps avoid unnecessary retracing when the inputs vary.

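For example:

import tensorflow as tf
# With only input_signature given, this returns a decorator for your update function.
tf.function(
    input_signature=[tf.TensorSpec(shape=None, dtype=tf.float32)])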
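To illustrate the effect, here is a small sketch that counts traces with and without input_signature. The helper make_fn and the dictionary trace_counts are made up for this example, and the function body is a placeholder rather than a real update step.

```python
import tensorflow as tf

# Counters are updated by a Python side effect, which only runs during tracing.
trace_counts = {"plain": 0, "pinned": 0}


def make_fn(name):
    def fn(x):
        trace_counts[name] += 1
        return tf.reduce_sum(x) * 2.0
    return fn


# No signature: by default each new input shape triggers a new trace.
plain = tf.function(make_fn("plain"))

# With a signature: one trace covers every float32 tensor, whatever its shape.
pinned = tf.function(
    make_fn("pinned"),
    input_signature=[tf.TensorSpec(shape=None, dtype=tf.float32)])

for n in (1, 2, 3):
    batch = tf.ones([n])
    plain(batch)
    pinned(batch)

print(trace_counts)
```

Note that shape=None means the graph has no static shape information at all; if your inputs only vary in batch size, a spec such as shape=[None, feature_dim] (feature_dim being whatever your inputs use) keeps more shape information while still avoiding retraces.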

Kindly review this documentation regarding input_signature. In addition, I recommend checking this resource for more effective use of tf.function.

Thank You.