TF newbie here, please be kind. I am using TensorFlow and Sonnet to build a conditional GAN. My training function looks like this:
```python
def train(dataset, generator_obj, discriminator_obj):
    num_training_steps = 500000
    gen_optimizer = snt.optimizers.Adam(...)
    disc_optimizer = snt.optimizers.Adam(...)
    with tf.Session() as sess:
        for epoch in range(num_training_steps):
            print("EPOCH {}".format(epoch))
            for i in range(...):
                print("BATCH {}".format(i))
                batch = tf.stack(...)
                batch_inputs = ...
                batch_targets = ...
                concat_inputs = gen_step(batch_inputs, batch_targets, generator_obj, gen_optimizer)
                ...
```
And `gen_step` looks like this:

```python
@tf.function
def gen_step(batch_inputs, batch_targets, generator_obj, gen_optimizer):
    with tf.GradientTape() as tape:
        batch_predictions = generator_obj(batch_inputs)
        gen_sequence = tf.concat([batch_inputs, tf.cast(batch_predictions, tf.float64)], axis=1)
        real_sequence = tf.concat([batch_inputs, batch_targets], axis=1)
        concat_inputs = tf.concat([real_sequence, gen_sequence], axis=0)
        num_samples_per_input = 6
        gen_samples = [generator_obj(batch_inputs) for _ in range(num_samples_per_input)]
        grid_cell_reg = grid_cell_regularizer(tf.stack(gen_samples, axis=0), batch_targets)
        gen_sequences = [tf.concat(...) for x in gen_samples]
        gen_disc_loss = loss_hinge_gen(...)
        gen_loss = ...
    gen_optimizer.apply(tape.gradient(gen_loss, generator_obj.trainable_variables), generator_obj.trainable_variables)
    return concat_inputs
```
When I run the network, it executes `gen_step`, but when it reaches the end of the function, instead of returning, it jumps back to the top and starts running the function again. I don't understand why this happens, and it causes the error `ValueError: tf.function-decorated function tried to create variables on non-first call.` inside the `_initialize` method of the Sonnet convolution module.
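For reference, here is a minimal sketch of the mechanism that seems to be involved, assuming TensorFlow 2.x with eager execution (no `tf.Session`). Sonnet 2 modules such as convolutions create their variables lazily on their first call, which is why the error surfaces in `_initialize`. `LazyLinear`, `make_step` and the trace counter are hypothetical names made up for this illustration. When a `tf.function` creates variables on its first call, TensorFlow traces the Python body once to create them and then traces it again to build a variable-free version, so the body visibly runs twice; creating any variable during that second trace, or during any later retrace, raises the `ValueError` above. Calling the module once eagerly before the first `tf.function` call keeps variable creation out of the traced function entirely.

```python
import tensorflow as tf


class LazyLinear(tf.Module):
    """Hypothetical stand-in for a Sonnet module: weights are created on first call."""

    def __init__(self, units, name=None):
        super().__init__(name=name)
        self.units = units
        self.w = None

    def __call__(self, x):
        if self.w is None:  # lazy initialisation, similar in spirit to Sonnet modules
            self.w = tf.Variable(tf.random.normal([x.shape[-1], self.units]), name="w")
        return tf.matmul(x, self.w)


def make_step(model, lr=0.01):
    """Builds a tf.function training step that counts how often its Python body runs."""
    counter = {"traces": 0}

    @tf.function
    def step(x, y):
        counter["traces"] += 1  # Python side effect: only happens while tracing
        with tf.GradientTape() as tape:
            loss = tf.reduce_mean(tf.square(model(x) - y))
        grads = tape.gradient(loss, model.trainable_variables)
        for var, grad in zip(model.trainable_variables, grads):
            var.assign_sub(lr * grad)  # plain SGD update, standing in for an optimizer
        return loss

    return step, counter


x = tf.random.normal([4, 3])
y = tf.random.normal([4, 1])

# Case 1: variables are created inside the first tf.function call.
lazy_model = LazyLinear(units=1)
lazy_step, lazy_counter = make_step(lazy_model)
lazy_step(x, y)
print("lazy model, traces after 1 call:", lazy_counter["traces"])  # more than 1

# Case 2: create the variables eagerly first, then train.
warm_model = LazyLinear(units=1)
warm_model(x)  # eager "warm-up" call creates the variables outside tf.function
warm_step, warm_counter = make_step(warm_model)
for _ in range(3):
    warm_step(x, y)
print("warm model, traces after 3 calls:", warm_counter["traces"])
```

In the warm-up case the step is traced only once and then reused, so the Python body does not re-run. This is a sketch of the general mechanism, not a diagnosis of the exact code above.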
I am a little confused about how AutoGraph and `tf.function` work in general; it all feels very different from PyTorch, which I'm used to. Any help or advice would be appreciated.
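One difference from PyTorch that may help explain this is sketched below, again assuming TensorFlow 2.x with eager execution. Inside a `tf.function`, ordinary Python code (a `print`, appending to a list) runs only while the function is being traced into a graph, whereas TensorFlow ops such as `tf.print` run on every call. A new trace is built whenever the input shapes change or a plain Python argument takes a new value. `scaled_sum` and `trace_log` are hypothetical names used only for illustration.

```python
import tensorflow as tf

trace_log = []  # filled by Python code, i.e. only while a new trace is being built


@tf.function
def scaled_sum(x, scale):
    trace_log.append((tuple(x.shape), scale))  # Python side effect: trace time only
    print("tracing for", x.shape, scale)  # Python print: trace time only
    tf.print("executing, sum =", tf.reduce_sum(x) * scale)  # graph op: every call
    return tf.reduce_sum(x) * scale


a = tf.ones([2, 3])
scaled_sum(a, 2.0)  # first call: traces, then executes
scaled_sum(a, 2.0)  # same shape and same Python value: reuses the existing graph
scaled_sum(tf.ones([5, 3]), 2.0)  # new input shape: retraces
scaled_sum(a, 3.0)  # new Python float value: retraces again
scaled_sum(a, tf.constant(3.0))  # a tensor argument gets its own trace...
scaled_sum(a, tf.constant(4.0))  # ...which is reused for other values of the same dtype/shape

print("number of traces:", len(trace_log))
```

Passing tensors rather than Python numbers, and keeping batch shapes fixed, are the usual ways to keep retracing (and therefore re-running the Python body) to a minimum.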
Thank you!