Keras Inference vs Training Results

Hello, this is my first time here, so please correct me if I’m in the wrong forum.

I’m using a custom layer that scales the output of a Dense layer into the range 0 to 1:

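```python
import tensorflow as tf


class BoundingBoxNumber(tf.keras.layers.Layer):
    """Maps an input tensor to a single value between 0 and 1."""

    def __init__(self, self_input_shape, **kwargs):
        # Keras layers must call the parent constructor before assigning attributes.
        super().__init__(**kwargs)

        self.input_shape_custom = self_input_shape

        # Number of elements in one (unbatched) input.
        self.input_shape_accu = 1
        for item in self_input_shape:
            self.input_shape_accu *= item

        self.internal_dense_layer = tf.keras.layers.Dense(1, activation='tanh')

    def call(self, inputs):
        # Flatten each sample; -1 keeps whatever batch size arrives
        # (a single unbatched input becomes a batch of one).
        inputs = tf.reshape(inputs, (-1, self.input_shape_accu))

        output = self.internal_dense_layer(inputs)

        # tanh lies in (-1, 1); halve it and shift it into (0, 1).
        output = tf.divide(output, 2)
        output = tf.math.add(output, 0.5)

        return output
```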
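
The rescaling in `call` is worth spelling out: `tanh` maps any real pre-activation `z` into (-1, 1), so `tanh(z) / 2 + 0.5` mathematically lies strictly between 0 and 1, and it equals the logistic sigmoid of `2z`. The standalone sketch below checks that identity with Python's `math` module only; `scaled_tanh` and `sigmoid` are hypothetical helper names for illustration, not part of the layer.

```python
import math


def scaled_tanh(z):
    """The layer's output mapping: tanh(z) shifted and scaled from (-1, 1) into (0, 1)."""
    return math.tanh(z) / 2 + 0.5


def sigmoid(z):
    """Logistic sigmoid, 1 / (1 + exp(-z))."""
    return 1.0 / (1.0 + math.exp(-z))


# The two mappings agree: (tanh(z) + 1) / 2 == sigmoid(2 * z).
for z in (-3.0, -0.5, 0.0, 0.5, 3.0):
    print(f"z={z:+.1f}  scaled_tanh={scaled_tanh(z):.6f}  sigmoid(2z)={sigmoid(2 * z):.6f}")
```

Because the factor of 2 can be absorbed into the Dense layer's weights and bias, a `Dense(1, activation='sigmoid')` layer can represent the same set of functions as the tanh-plus-rescaling version.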


Then I’m using a custom training loop:

```python
# `image`, `bb`, `dense_bb_num_layer` and `call_optimizer` are defined elsewhere (not shown).
num_epochs = 50

for epoch in range(num_epochs):
    print("Epoch:", epoch, "of", num_epochs)

    average_loss = 0.0

    for i, item in enumerate(image):
        # Target: the reciprocal of the number of bounding boxes in this image.
        num_bounding_boxes = tf.shape(bb[i])[0]
        float_target = 1.0 / tf.cast(num_bounding_boxes, tf.float32)

        with tf.GradientTape() as tape:
            logits = dense_bb_num_layer(item)
            loss = NumLoss(logits, float_target)

        print("Logits:", logits, "Target:", float_target, "Loss:", loss)
        average_loss += loss

        # Update only the trainable weights of the custom layer.
        call_gradients = tape.gradient(loss, dense_bb_num_layer.trainable_weights)
        call_optimizer.apply_gradients(zip(call_gradients, dense_bb_num_layer.trainable_weights))

    average_loss /= len(image)
```


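And a custom loss function:

```python
def NumLoss(logits, expected):
    # Squared error between prediction and target, averaged to a scalar loss.
    return tf.reduce_mean(tf.square(logits - expected))
```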
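
To make the per-sample update in the loop concrete, here is a dependency-free sketch of the same computation for a hypothetical layer with a single input feature: prediction `y = tanh(w*x + b)/2 + 0.5`, the squared-error loss `(y - t)**2` as in `NumLoss`, and plain gradient descent standing in for `call_optimizer` (whose actual type isn't shown). The names `train_step`, `w`, `b`, `x`, `t` and `lr` are made up for illustration. By the chain rule, with `z = w*x + b`, the gradient is `dL/dz = (y - t) * (1 - tanh(z)**2)`.

```python
import math


def train_step(w, b, x, t, lr):
    """One gradient-descent step on L = (y - t)**2 with y = tanh(w*x + b)/2 + 0.5."""
    z = w * x + b
    th = math.tanh(z)
    y = th / 2 + 0.5
    loss = (y - t) ** 2
    # Chain rule: dL/dy = 2(y - t) and dy/dz = (1 - tanh(z)**2) / 2.
    dl_dz = (y - t) * (1 - th ** 2)
    w -= lr * dl_dz * x  # dz/dw = x
    b -= lr * dl_dz      # dz/db = 1
    return w, b, loss


w, b = 0.0, 0.0
x, t = 1.0, 0.25  # hypothetical feature and target (e.g. an image with 4 bounding boxes)
for _ in range(200):
    w, b, loss = train_step(w, b, x, t, lr=0.5)

y = math.tanh(w * x + b) / 2 + 0.5
print(f"prediction={y:.4f}  target={t}  last loss={loss:.2e}")
```

With a small learning rate the prediction settles on the target instead of being pushed to 0 or 1; the factor `1 - tanh(z)**2` is what scales every gradient flowing back into the Dense weights.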

When I run the layer inside the custom training loop, the result is always exactly 0 or 1, whereas outside it (just calling `BoundingBoxNumber()(input)`) I get a random float, as expected. My training data are all floats between 0 and 1; none of them is exactly zero or one.
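
One possibility that may be worth ruling out (an assumption, not something the thread confirms): in floating point, `tanh` rounds to exactly `1.0` or `-1.0` once its input is large enough in magnitude, so `tanh(z) / 2 + 0.5` becomes exactly `1.0` or `0.0`, and the slope `1 - tanh(z)**2` that multiplies every gradient through the layer becomes exactly `0.0`. The sketch below shows this with Python's float64 `math.tanh`; TensorFlow computes in float32 by default, where the rounding already happens at much smaller `|z|`. The helper name `output_and_slope` is hypothetical.

```python
import math


def output_and_slope(z):
    """Return the rescaled output tanh(z)/2 + 0.5 and the tanh slope 1 - tanh(z)**2."""
    t = math.tanh(z)
    return t / 2 + 0.5, 1 - t ** 2


# Pre-activations of growing magnitude: once tanh rounds to exactly +/-1,
# the output is exactly 1.0 or 0.0 and the slope (hence the gradient) is 0.0.
for z in (0.5, 5.0, 20.0, -20.0):
    output, slope = output_and_slope(z)
    print(f"z={z:+6.1f}  output={output!r}  slope={slope!r}")
```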

Hi @Kevin_T

Welcome to the TensorFlow Forum!

Please let us know if this issue still persists. If so, could you please share standalone code that reproduces it, along with the type and shape of the dataset you are using for training, so we can better understand the issue? Thank you.