Our team has implemented a subclassed tf.keras.Model that overrides the train_step() method. We also implemented the call() method of the underlying model.
During evaluation (i.e., when running inference on the data passed as validation_data), we observe that a spurious dimension gets added to intermediate values. For example, the expected shape is (128, 768), where 128 is the batch size, but we get (128, 128, 768) instead. When we implement a separate test_step(), this behavior goes away.
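
For reference, here is a minimal sketch of the kind of setup described above. It is not our actual code and it does not reproduce the issue. The names Encoder, Wrapper, _loss and fit_once are hypothetical, and the mean-squared-error loss is only a placeholder. It shows a subclassed model that overrides both train_step() and test_step(), routes both through the same forward pass, and asserts the rank of the intermediate features so that an unexpected extra dimension would fail loudly.

```python
import numpy as np
import tensorflow as tf


class Encoder(tf.keras.Model):
    """Hypothetical stand-in for the underlying model."""

    def __init__(self, dim=768):
        super().__init__()
        self.dense = tf.keras.layers.Dense(dim)

    def call(self, inputs, training=False):
        return self.dense(inputs)


class Wrapper(tf.keras.Model):
    """Hypothetical subclassed model with custom train and test steps."""

    def __init__(self, encoder):
        super().__init__()
        self.encoder = encoder
        self.loss_tracker = tf.keras.metrics.Mean(name="loss")

    @property
    def metrics(self):
        # Listing the tracker here lets Keras reset it between epochs
        # and between training and evaluation.
        return [self.loss_tracker]

    def call(self, inputs, training=False):
        return self.encoder(inputs, training=training)

    def _loss(self, x, y, training):
        features = self(x, training=training)
        # Expect (batch, dim); a rank-3 tensor would fail here.
        tf.debugging.assert_rank(features, 2)
        # Placeholder loss (assumption): mean squared error.
        return tf.reduce_mean(tf.square(features - y))

    def train_step(self, data):
        x, y = data
        with tf.GradientTape() as tape:
            loss = self._loss(x, y, training=True)
        grads = tape.gradient(loss, self.trainable_variables)
        self.optimizer.apply_gradients(zip(grads, self.trainable_variables))
        self.loss_tracker.update_state(loss)
        return {"loss": self.loss_tracker.result()}

    def test_step(self, data):
        # Explicit evaluation step using the same forward path as training.
        x, y = data
        loss = self._loss(x, y, training=False)
        self.loss_tracker.update_state(loss)
        return {"loss": self.loss_tracker.result()}


def fit_once(batch_size=128, dim=768, in_dim=32):
    """Train for one epoch on random data, with validation_data supplied."""
    x = np.random.rand(2 * batch_size, in_dim).astype("float32")
    y = np.random.rand(2 * batch_size, dim).astype("float32")
    model = Wrapper(Encoder(dim=dim))
    model.compile(optimizer="adam")
    history = model.fit(x, y, batch_size=batch_size, epochs=1,
                        validation_data=(x, y), verbose=0)
    out_shape = tuple(model(x[:batch_size], training=False).shape)
    return history, out_shape
```

Calling fit_once() runs one epoch with validation, so test_step() is exercised through validation_data in the same way as in our real training loop.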
Has anyone faced something similar?
Cc: @anon1529149