I am trying to save a model built by the functional API with a loss function that has tf.where, but it cannot be loaded again. Here is my code:
#!/usr/bin/env python3
import tensorflow as tf
def loss(output_0, output_1, weights, label_0, label_1):
q_0 = tf.where(condition=(label_0 == 0), x=tf.fill(tf.shape(output_0), 0.0), y=output_0)
q_0 = tf.where(condition=(label_0 == 1), x=tf.fill(tf.shape(output_0), 1.0), y=q_0)
q_1 = tf.where(condition=(label_1 == 0), x=tf.fill(tf.shape(output_1), 0.0), y=output_1)
q_1 = tf.where(condition=(label_1 == 1), x=tf.fill(tf.shape(output_1), 1.0), y=q_1)
L = weights * tf.square(q_0 - q_1)
return tf.reduce_mean(L)
if __name__ == '__main__':
x_0 = tf.keras.layers.Input(shape=(2,))
x_1 = tf.keras.layers.Input(shape=(2,))
w = tf.keras.layers.Input(shape=(1,))
l_0 = tf.keras.layers.Input(shape=(1,))
l_1 = tf.keras.layers.Input(shape=(1,))
layer = tf.keras.layers.Dense(name='dense1', units=10)
y_0 = layer(x_0)
y_1 = layer(x_1)
model = tf.keras.Model(
inputs=[x_0, x_1, w, l_0, l_1],
outputs=[y_0, y_1, w, l_0, l_1])
model.add_loss(
loss(output_0=y_0, output_1=y_1, weights=w,
label_0=l_0, label_1=l_1))
model.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=1e-5))
model.save('test_save')
test_reload = tf.keras.models.load_model('test_save')
If I run the code above, then I will get the following error messages:
Traceback (most recent call last):
File "/home/hanatok/HDD/Documents/playground/python_tf_bug/./test.py", line 31, in <module>
test_reload = tf.keras.models.load_model('test_save')
File "/home/hanatok/mambaforge/lib/python3.10/site-packages/keras/utils/traceback_utils.py", line 67, in error_handler
raise e.with_traceback(filtered_tb) from None
File "/home/hanatok/mambaforge/lib/python3.10/site-packages/tensorflow/python/util/dispatch.py", line 1076, in op_dispatch_handler
result = api_dispatcher.Dispatch(args, kwargs)
TypeError: Missing required positional argument
Any ideas?
Update:
I solved the problem myself by replacing
label_0 == 0
with
tf.equal(label_0, 0)