For example:
loss = tf.reduce_mean(tf.square(heatmap_outs - gt_heatmap) * valid_mask)
To compute this loss I need a valid_mask in addition to y_pred and y_true, and valid_mask is not a fixed parameter, so I cannot just set it once in the constructor. Is there a way to achieve this by subclassing tf.keras.losses.Loss?
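To make the question concrete, here is a minimal sketch of one possible workaround, not a confirmed or official solution: since `call` only receives `y_true` and `y_pred`, the mask is packed into `y_true` along the last axis and split off again inside the loss. The class name `MaskedHeatmapLoss` and the packing convention are assumptions made for illustration, and the sketch assumes `valid_mask` has the same shape as `gt_heatmap`.

```python
import tensorflow as tf


class MaskedHeatmapLoss(tf.keras.losses.Loss):
    """Masked squared error; the mask travels inside y_true (hypothetical convention)."""

    def call(self, y_true, y_pred):
        # Assumption: y_true = tf.concat([gt_heatmap, valid_mask], axis=-1),
        # so its last axis is twice as large as that of y_pred.
        gt_heatmap, valid_mask = tf.split(y_true, num_or_size_splits=2, axis=-1)
        gt_heatmap = tf.cast(gt_heatmap, y_pred.dtype)
        valid_mask = tf.cast(valid_mask, y_pred.dtype)
        # Same expression as in the question, but averaged over the last axis only;
        # the base class then averages the remaining values.
        masked_sq_err = tf.square(y_pred - gt_heatmap) * valid_mask
        return tf.reduce_mean(masked_sq_err, axis=-1)


# Small check against the original one-line formula.
gt_heatmap = tf.ones([2, 4, 4, 3])
valid_mask = tf.concat([tf.ones([1, 4, 4, 3]), tf.zeros([1, 4, 4, 3])], axis=0)
heatmap_outs = tf.zeros([2, 4, 4, 3])

packed_y_true = tf.concat([gt_heatmap, valid_mask], axis=-1)
loss = MaskedHeatmapLoss()(packed_y_true, heatmap_outs)
reference = tf.reduce_mean(tf.square(heatmap_outs - gt_heatmap) * valid_mask)
```

With this convention the dataset would have to yield the packed tensor as the label, and the loss would be passed as `model.compile(loss=MaskedHeatmapLoss())`. Whether Keras accepts a `y_true` whose shape differs from `y_pred` during `fit` may depend on the version, so that part is untested here.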