I have created two instances of the same custom model in TensorFlow 2.9 (i.e., model = Model() and ema_model = Model()). While training model in a custom loop, I want to compute the exponential moving average (EMA) of its variables and update ema_model with the result.
I have tried this solution and also ema_model.set_weights(model.get_weights()), but neither attempt was successful. To be specific, I called them right after the optimization step in the train_step function.
In other words, I want the parameters of model to be updated in the training loop as usual, while the parameters of ema_model are updated in each epoch as a decayed (moving-average) version of the parameters of model.
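To make the goal concrete, here is a minimal sketch of the update I have in mind. The helper name update_ema and the decay value 0.999 are placeholders of mine, not part of TensorFlow. The sketch assumes both models are already built, so their variable lists line up one-to-one, and it relies only on each variable supporting scalar multiplication and an assign method, as tf.Variable does.

```python
def update_ema(ema_variables, model_variables, decay=0.999):
    """Move each EMA variable a small step toward the trained variable.

    Hypothetical helper: ema <- decay * ema + (1 - decay) * current.
    Intended for lists of tf.Variable, e.g. ema_model.weights and
    model.weights, which must have the same order and shapes.
    """
    for ema_var, var in zip(ema_variables, model_variables):
        # Only ema_var is written to; the trained variable is left untouched.
        ema_var.assign(decay * ema_var + (1.0 - decay) * var)
```

I would expect to copy the weights once before training with ema_model.set_weights(model.get_weights()), and then call update_ema(ema_model.weights, model.weights) after optimizer.apply_gradients(...) in train_step (or once per epoch), but I am not sure this is the right way to do it.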
Any hints or solutions to this problem?