tf.function repeats instead of returning

TF newbie here, please be kind. I am using TensorFlow and Sonnet to build a conditional GAN. My training function looks like this:

def train(dataset, generator_obj, discriminator_obj):
  num_training_steps = 500000
  gen_optimizer = snt.optimizers.Adam(...)
  disc_optimizer = snt.optimizers.Adam(...)
  with tf.Session() as sess:
    for epoch in range(num_training_steps):
      print("EPOCH {}".format(epoch))
      for i in range(...):
        print("BATCH {}".format(i))
        batch = tf.stack(...)
        batch_inputs = ...
        batch_targets = ...
        concat_inputs = gen_step(batch_inputs, batch_targets, generator_obj, gen_optimizer)
...
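
For comparison, here is a hedged sketch of the same outer loop in TF2 style. snt.optimizers comes from Sonnet 2, which targets TF2 eager execution, and TF2 has no tf.Session (only tf.compat.v1.Session), so the loop can be plain eager Python that calls the tf.function step. The toy dataset, the step function and num_epochs are placeholders, not the real ones from the post.

```python
import tensorflow as tf

# Placeholder data: 64 examples with 4 input features and 2 target values.
features = tf.random.normal([64, 4])
targets = tf.random.normal([64, 2])
# drop_remainder=True keeps every batch the same shape, so the
# tf.function below is not retraced for a smaller final batch.
dataset = tf.data.Dataset.from_tensor_slices((features, targets)).batch(
    16, drop_remainder=True)

traces = []  # appended to only while step() is being traced

@tf.function
def step(batch_inputs, batch_targets):  # hypothetical stand-in for gen_step
    traces.append(1)
    return tf.concat([batch_inputs, batch_targets], axis=1)

num_epochs = 2  # placeholder; no tf.Session is needed in TF2
for epoch in range(num_epochs):
    print("EPOCH {}".format(epoch))
    for i, (batch_inputs, batch_targets) in enumerate(dataset):
        print("BATCH {}".format(i))
        concat_inputs = step(batch_inputs, batch_targets)
```

Because every batch has the same shape and dtype, step is traced once and then reused for all eight calls.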

And gen_step looks like this:

@tf.function
def gen_step(batch_inputs, batch_targets, generator_obj, gen_optimizer):
  with tf.GradientTape() as tape:
    batch_predictions = generator_obj(batch_inputs)
    gen_sequence = tf.concat([batch_inputs, tf.cast(batch_predictions, tf.float64)], axis=1)
    real_sequence = tf.concat([batch_inputs, batch_targets], axis=1)
    concat_inputs = tf.concat([real_sequence, gen_sequence], axis=0)
    num_samples_per_input = 6
    gen_samples = [generator_obj(batch_inputs) for _ in range(num_samples_per_input)]
    grid_cell_reg = grid_cell_regularizer(tf.stack(gen_samples, axis=0), batch_targets)
    gen_sequences = [tf.concat(...) for x in gen_samples]
    gen_disc_loss = loss_hinge_gen(...)
    gen_loss = ...
  gen_optimizer.apply(tape.gradient(gen_loss, generator_obj.trainable_variables), generator_obj.trainable_variables)
  return concat_inputs
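
A tf.function may only create variables during its first trace. One commonly used way to respect that is sketched below under assumptions: make sure every variable exists before the training step is traced, and capture the model from the enclosing scope instead of passing it in as an argument. LazyDense is a hypothetical pure-TensorFlow stand-in for a Sonnet module such as a convolution, which likewise creates its weights the first time it is called. The plain SGD update stands in for snt.optimizers.Adam to keep the sketch self-contained; it is not the Sonnet API.

```python
import tensorflow as tf

# Hypothetical stand-in for a Sonnet module: creates its weight lazily on first call.
class LazyDense(tf.Module):
    def __init__(self, units):
        super().__init__()
        self.units = units
        self.w = None

    def __call__(self, x):
        if self.w is None:  # create the variable exactly once
            self.w = tf.Variable(tf.random.normal([x.shape[-1], self.units]), name="w")
        return tf.matmul(x, self.w)

model = LazyDense(units=1)
learning_rate = 0.1

# Build the variables eagerly, outside any tf.function, before training starts.
model(tf.zeros([1, 3]))

@tf.function
def train_step(x, y):
    # model and learning_rate are captured from the enclosing scope.
    with tf.GradientTape() as tape:
        loss = tf.reduce_mean(tf.square(model(x) - y))
    grads = tape.gradient(loss, model.trainable_variables)
    for g, v in zip(grads, model.trainable_variables):
        v.assign_sub(learning_rate * g)  # plain SGD update, no new variables
    return loss  # loss computed before this step's update

x = tf.eye(3)                           # toy inputs: 3 samples, 3 features
y = tf.constant([[1.0], [2.0], [3.0]])  # toy targets
first_loss = float(train_step(x, y))
for _ in range(200):
    last_loss = float(train_step(x, y))
print(first_loss, last_loss)
```

No retrace in this sketch needs to create a variable, so repeated calls do not hit the "create variables on non-first call" error.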

When I run the network, it executes gen_step, but when it reaches the end of the function, instead of returning, it jumps back to the top and starts running it again. I don’t understand why this happens, and it leads to the error “ValueError: tf.function-decorated function tried to create variables on non-first call.”, raised in the _initialize method of the Sonnet convolution module.
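
A possible explanation, offered as a hedged illustration rather than a diagnosis of this exact code: the Python body of a tf.function runs only while TensorFlow traces it into a graph. TensorFlow traces again whenever a call does not match an existing trace (for example, a new tensor shape or dtype), so a print at the top of the function can look as if it "starts over". The toy function doubled and the trace_count list are made-up names that count how often the Python body actually runs.

```python
import tensorflow as tf

trace_count = []  # the Python body appends here, so this counts traces, not calls

@tf.function
def doubled(x):
    trace_count.append(1)          # plain Python: runs only while tracing
    tf.print("executing with", x)  # graph op: runs on every call
    return 2 * x

doubled(tf.constant(1.0))          # first call: traces, then executes
doubled(tf.constant(2.0))          # same dtype and shape: reuses the trace
doubled(tf.constant([1.0, 2.0]))   # new shape: traces again
print("traces:", len(trace_count))
```

The same applies inside gen_step. Python-side effects such as print, and lazy variable creation in Sonnet modules, happen during tracing, not on every training step. If a later trace reaches code that still wants to create a variable, TensorFlow raises the "non-first call" ValueError.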

I am a little confused about how AutoGraph and tf.function work in general; it all feels very different from PyTorch, which I’m used to. Any help or advice would be appreciated.

Thank you!

Hi @Theano_Xirouchaki, before calling your tf.function-decorated function, could you please try making functions run eagerly with tf.config.run_functions_eagerly(True)? Thank you.
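
A small sketch of what that setting changes, using a hypothetical stand-in gen_step_like. With tf.config.run_functions_eagerly(True), every tf.function call runs its Python body directly instead of a traced graph, so print statements and a debugger behave much as they do in PyTorch. It is meant as a debugging aid and slows training, so the sketch switches it off again afterwards.

```python
import tensorflow as tf

calls = []  # records every time the Python body actually runs

@tf.function
def gen_step_like(x):  # hypothetical stand-in for gen_step
    calls.append(1)
    return x + 1.0

tf.config.run_functions_eagerly(True)   # debug mode: run the Python body on every call
for _ in range(3):
    gen_step_like(tf.constant(1.0))
eager_runs = len(calls)                 # 3: one per call

tf.config.run_functions_eagerly(False)  # restore graph execution
calls.clear()
for _ in range(3):
    gen_step_like(tf.constant(1.0))
graph_runs = len(calls)                 # 1: only the single trace
print(eager_runs, graph_runs)
```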