Hi,
I’m trying to implement Sobolev training [1] in TF2 and am having some problems. As far as I understand, Sobolev training adds an additional term to the loss: the difference between the derivative of the target data and the derivative of the prediction.
In my case I am just doing some simple 1D regression of some time-series data and I have precomputed the derivative of the target function.
I’m using the notes from Customizing what happens in fit() | TensorFlow Core to try and implement this in a custom train_step; see below for the current version I have, which is not working. train_step receives `data`, which is typically a 2-tuple where the first element is the inputs and the second element is the targets. My first complication is that I’m not sure how to pass additional information through model.fit.
The way I thought this would work is to pass an [N, 2] array where the first column is my target data and the second column is the derivative of my target data.
Then below I retrieve these using
y = ys[...,0]
dy = ys[...,1]
But I’m not sure if that is correct.
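For what it’s worth, here is a minimal sketch of the packing I have in mind, using a hypothetical sin/cos time series (the actual data and model are assumptions):

```python
import numpy as np

# Hypothetical 1D time series: targets y and their precomputed
# derivatives dy, each of shape [N].
t = np.linspace(0.0, 1.0, 5)
y = np.sin(t)
dy = np.cos(t)  # precomputed derivative of the target

# Stack into a single [N, 2] target array:
# column 0 is y, column 1 is dy.
ys = np.stack([y, dy], axis=-1)

print(ys.shape)  # → (5, 2)
# model.fit(t[:, None], ys, ...) would then deliver this as `data`,
# and ys[..., 0] / ys[..., 1] recover y and dy inside train_step.
```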
Sorry for the badly explained post! What I’m really after is help with computing the derivative of the model’s prediction with respect to the input values, so that I can add a term to the loss that is the difference between this and the known derivative of the target function.
Thanks in advance!
loss_tracker = tf.keras.metrics.Mean(name="loss")

class CustomModel(tf.keras.Model):
    def train_step(self, data):
        x, ys = data
        # Slice with 0:1 / 1:2 to keep shape [N, 1], matching y_pred.
        y = ys[..., 0:1]
        dy = ys[..., 1:2]
        with tf.GradientTape() as tape:
            # An inner tape computes the derivative of the prediction
            # w.r.t. the inputs; doing this inside the outer tape lets
            # the derivative loss term propagate into the weights.
            with tf.GradientTape() as inner_tape:
                inner_tape.watch(x)
                y_pred = self(x, training=True)
            d_y_pred = inner_tape.gradient(y_pred, x)
            loss_y = tf.keras.losses.mean_squared_error(y, y_pred)
            # Compare against dy (the target derivative), not y.
            loss_dy = tf.keras.losses.mean_squared_error(dy, d_y_pred)
            loss = loss_y + loss_dy
        # Compute gradients of the combined loss w.r.t. the weights
        trainable_vars = self.trainable_variables
        gradients = tape.gradient(loss, trainable_vars)
        # Update weights
        self.optimizer.apply_gradients(zip(gradients, trainable_vars))
        # Track our own loss metric
        loss_tracker.update_state(loss)
        return {"loss": loss_tracker.result()}
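As a sanity check of the tape mechanics on their own, here is a toy example independent of Keras: f(x) = x**3 stands in for the model call, and its analytic derivative 3*x**2 plays the role of the precomputed target derivative (both are assumptions, not part of the original post):

```python
import numpy as np
import tensorflow as tf

x = tf.constant([[0.5], [1.0], [2.0]])  # shape [N, 1], like the inputs

with tf.GradientTape() as tape:
    tape.watch(x)    # x is a constant tensor, so watch it explicitly
    y_pred = x ** 3  # stand-in for self(x, training=True)

d_y_pred = tape.gradient(y_pred, x)  # elementwise dy_pred/dx, shape [N, 1]

print(d_y_pred.numpy().ravel())  # → [ 0.75  3.   12.  ], i.e. 3*x**2
```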