Fit customization

Hello,
I tried to customize the fit method by creating a new model class for a PINN simulation. I created a train_step method decorated with @tf.function.

I wrote a method inside this class to check the performance of my train_step method, like the following code:

    def fit_custom(
        self,
        inputs,
        outputs,
        epochs
    ):
        for epoch in range(epochs):
            # Call the compiled train_step directly, one full-batch step per epoch.
            metrics = self.train_step([inputs, outputs])

and the performance is as expected. But if I use the same model with the fit method, the performance is really bad. So I imagine that the optimized version created with @tf.function is not used, or maybe some other computations reduce the performance.

Could you explain what happens?

Thanks

The performance difference occurs because TensorFlow’s built-in fit method includes additional overhead for metrics tracking, callbacks, and validation that your custom fit_custom doesn’t have. Here’s how to optimize it:

    @tf.function
    def train_step(self, data):
        inputs, outputs = data
        with tf.GradientTape() as tape:
            predictions = self(inputs, training=True)
            # self.loss_fn is assumed to be whatever loss you attached to the model
            # (for a PINN, typically the data loss plus the PDE residual terms).
            loss = self.loss_fn(outputs, predictions)
        gradients = tape.gradient(loss, self.trainable_variables)
        self.optimizer.apply_gradients(zip(gradients, self.trainable_variables))
        return loss

For production use, consider keeping the @tf.function-compiled train_step and adding only the metrics you actually need. That gives you both speed and monitoring capabilities.
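As a minimal sketch of that idea (the PINNModel name, its single Dense layer, and the mean-squared-error placeholder loss are assumptions for illustration, not your actual model), the metric can live on the model and be updated inside the compiled step:

    import tensorflow as tf

    class PINNModel(tf.keras.Model):
        """Toy model that tracks its training loss as a Keras metric."""

        def __init__(self, **kwargs):
            super().__init__(**kwargs)
            self.dense = tf.keras.layers.Dense(1)
            self.loss_tracker = tf.keras.metrics.Mean(name="loss")

        def call(self, inputs):
            return self.dense(inputs)

        @tf.function
        def train_step(self, data):
            inputs, outputs = data
            with tf.GradientTape() as tape:
                predictions = self(inputs, training=True)
                # Placeholder loss; a real PINN would add its residual terms here.
                loss = tf.reduce_mean(tf.square(outputs - predictions))
            gradients = tape.gradient(loss, self.trainable_variables)
            self.optimizer.apply_gradients(zip(gradients, self.trainable_variables))
            self.loss_tracker.update_state(loss)
            return {"loss": self.loss_tracker.result()}

        @property
        def metrics(self):
            # Listing the tracker here lets Keras reset it at the start of each epoch.
            return [self.loss_tracker]

After model.compile(optimizer=...), each call to train_step returns a dict you can log per epoch.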

Thank you for your reply.
After a bit more digging, I realized that it was the iterator TFEpochIterator that was killing the performance. I made my own fit method by copying almost all of the existing fit method, removing the iterator, and doing my own splitting based on batch_size. And I'm getting the performance back.

I don't know why the iterator TFEpochIterator hurts performance.

Your finding about TFEpochIterator’s impact on performance is valuable. Creating a custom fit method that handles batching directly is often more efficient for specialized models like PINNs. The performance difference likely comes from TFEpochIterator’s additional overhead in data pipeline management, which isn’t always necessary for simpler data structures.

Here’s a quick example of efficient direct batching:

    def custom_fit(self, inputs, outputs, batch_size):
        n_samples = tf.shape(inputs)[0]
        indices = tf.range(n_samples)
        batched_indices = tf.data.Dataset.from_tensor_slices(indices).batch(batch_size)
        for batch_idx in batched_indices:
            # tf.gather does the fancy indexing; plain tensor[batch_idx] fails for index tensors.
            self.train_step([tf.gather(inputs, batch_idx), tf.gather(outputs, batch_idx)])
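As a rough usage sketch (the PINNModel name and toy data are hypothetical, and custom_fit is assumed to be defined as a method on your model class):

    # Hypothetical usage of custom_fit as a model method.
    model = PINNModel()
    model.compile(optimizer=tf.keras.optimizers.Adam(1e-3))
    x = tf.random.uniform((1024, 1))
    y = tf.sin(x)
    model.custom_fit(x, y, batch_size=64)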

Your optimization shows how sometimes simpler approaches can outperform more general-purpose solutions!

Thanks for the tip.

Is there an easy way to add a shuffle argument?

Sure, you can shuffle the indices before batching:

    def custom_fit(self, inputs, outputs, batch_size, shuffle=True):
        n_samples = tf.shape(inputs)[0]
        indices = tf.range(n_samples)
        if shuffle:
            # Draw a fresh permutation of the sample indices.
            indices = tf.random.shuffle(indices)
        batched_indices = tf.data.Dataset.from_tensor_slices(indices).batch(batch_size)
        for batch_idx in batched_indices:
            self.train_step([tf.gather(inputs, batch_idx), tf.gather(outputs, batch_idx)])

This gives you the flexibility to toggle shuffling with a simple boolean parameter while maintaining the performance benefits of your custom implementation!
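One design note, sketched below: because the shuffle happens inside custom_fit, calling it once per epoch draws a fresh permutation each time (the epoch count and model name here are placeholders):

    # Hypothetical outer loop: each call re-shuffles, so batches differ between epochs.
    for epoch in range(1000):
        model.custom_fit(x, y, batch_size=64, shuffle=True)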

Could this also help you with your fit customization question?
https://discuss.ai.google.dev/t/support-runtimeerror-merge-call-called-while-defining-a-new-graph-or-a-tf-function/24334?u=krows

Thanks for the shuffle.

The link you provided doesn't work for me.