I’m trying to write a custom loss function where there is an extra potential penalty when the true label is 17.0. From what I understand, the code below doesn’t work because of the “if” statement. I think I need to replace it with a tf.where() statement, but everything I’ve tried creates an error. Any suggestions?
from tensorflow.keras import backend as bk
def custom_error(y_true, y_pred):
total_error = 0.0
# do mse calc
error = y_true - y_pred
sqr_error = bk.square(error)
mse = bk.mean(sqr_error)
# This is an extra penalty I was trying to code
# if y_true==17.0 and bk.round(y_pred)<17.0:
# total_error += 25.0
# add mse to to total error as well
total_error += mse
mse2 = tf.math.add(mse,25.0)
total_error = tf.where(tf.equal(y_true,17.0) and tf.less(bk.round(y_pred),17.0), mse, mse2)
return total_error
model = tf.keras.Sequential([
tf.keras.layers.Dense(1024, activation='relu', input_shape=(40,)),
tf.keras.layers.Dropout(0.1),
tf.keras.layers.Dense(512, activation='relu'),
tf.keras.layers.Dropout(0.1),
tf.keras.layers.Dense(128, activation='relu'),
tf.keras.layers.Dropout(0.1),
tf.keras.layers.Dense(64, activation='relu'),
tf.keras.layers.Dropout(0.1),
tf.keras.layers.Dense(1)
])
model.compile(optimizer='adam',
loss=custom_error
)
model.summary()
model.fit(xtrain, ytrain, epochs=10
)
Here’s the error message I’m getting:
Epoch 1/10
---------------------------------------------------------------------------
InvalidArgumentError Traceback (most recent call last)
<ipython-input-29-ec85575477da> in <module>()
44 model.summary()
45
---> 46 model.fit(xtrain, ytrain, epochs=10
47 )
48
1 frames
/usr/local/lib/python3.7/dist-packages/tensorflow/python/eager/execute.py in quick_execute(op_name, num_outputs, inputs, attrs, ctx, name)
53 ctx.ensure_initialized()
54 tensors = pywrap_tfe.TFE_Py_Execute(ctx._handle, device_name, op_name,
---> 55 inputs, attrs, num_outputs)
56 except core._NotOkStatusException as e:
57 if name is not None:
InvalidArgumentError: Graph execution error:
The second input must be a scalar, but it has shape [32]
[[{{node custom_error/cond/custom_error/Equal/_6}}]] [Op:__inference_train_function_433584]