Train_on_batch and train_step used in custom training loop giving different results

I have a custom model class which has a training method implemented using train_step.

class MyFancyModel(tfk.models.Model):

    ...

    def train_model(self, data, steps, message=None):
        @tf.function()
        def train_step(model_and_inputs):
            model, data = model_and_inputs
            return model.train_step((data,))

        history = {}
        from tqdm import trange
        for i in trange(steps):
            _history = train_step((self, data))

            for k, v in _history.items():
                # setdefault avoids a KeyError on the first step
                history.setdefault(k, []).append(float(v))

        return history

Recently, I realized I should probably use self.train_on_batch rather than using my own local function definition based on train_step. So, I rewrote this method as

    def train_model(self, data, steps, message=None):
        history = {}
        from tqdm import trange
        for i in trange(steps):
            _history = self.train_on_batch(data, return_dict=True)

            for k, v in _history.items():
                # setdefault avoids a KeyError on the first step
                history.setdefault(k, []).append(float(v))

        return history

I thought that would be that, but I noticed that the new version’s output is slightly worse. I’m scratching my head trying to figure out what might be the salient difference between these two implementations. I’d appreciate any relevant insights into the keras.models.Model innards.

Cheers!

Hi @kilodalton,

Sorry for not getting back to you sooner.

Both approaches operate on a single batch of data, and train_on_batch internally calls train_step. The key difference is the bookkeeping around that call: train_on_batch is a built-in method that builds the compiled train function and, in Keras 2.x, resets the model's metrics on each call (reset_metrics=True by default), whereas calling train_step directly inside your own tf.function lets the compiled metrics accumulate across steps. So the numbers you log can differ even when the gradient updates themselves are the same. Kindly refer to the GitHub source of train_on_batch and train_step to understand their functionality.
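A minimal sketch of that difference, assuming Keras 2.x and a compiled model (the tiny Dense model and random data below are placeholders, not your actual model):

```python
import numpy as np
import tensorflow as tf

# Placeholder model standing in for MyFancyModel.
model = tf.keras.Sequential([tf.keras.layers.Dense(1)])
model.compile(optimizer="sgd", loss="mse", metrics=["mae"])

x = np.random.rand(8, 4).astype("float32")
y = np.random.rand(8, 1).astype("float32")

# train_on_batch wraps train_step; the returned dict reflects
# only this call because metrics are reset each time by default.
h1 = model.train_on_batch(x, y, return_dict=True)

# Calling train_step directly skips that reset, so the compiled
# metrics keep accumulating a running mean across calls.
_ = model.train_step((tf.constant(x), tf.constant(y)))
h2 = model.train_step((tf.constant(x), tf.constant(y)))
# h2's metric values are running means over both calls above.
```

If you want the refactored train_model to report the same running-mean metrics as your original loop, Keras 2.x lets you pass reset_metrics=False to train_on_batch.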

Hope this helps. Thank you.