My model produces completely different results in training mode and eval mode because of BatchNormalization. What should I do?

I am trying to convert the PhysNet model to ONNX, but the BatchNormalization (BN) layers of this model only seem to work properly in training mode. In eval mode, I get completely different results from the normal (training-mode) results. What should I do?

Comparison of results

My code:

class PhysNet(keras.Model):
    """Keras model made of stacked 3D convolution blocks with a zero-mean output.

    The network applies nine ``Conv3D -> normalisation -> ReLU`` blocks
    interleaved with max-pooling, two ``Conv3DTranspose -> normalisation ->
    ELU`` blocks that upsample axis 1 by a factor of two each, an average
    pool over the two trailing spatial axes, and a final 1x1x1 convolution.
    The result is flattened per sample and its per-sample mean is subtracted.

    Args:
        norm: Which normalisation layer to place after every convolution.
            ``'batch'`` uses ``BatchNormalization``, ``'layer'`` uses
            ``LayerNormalization(axis=(1,))`` and ``'layer_frozen'`` uses the
            same layer with ``trainable=False``. A zero-argument callable that
            returns a fresh layer is also accepted.

    Raises:
        ValueError: If ``norm`` is a string other than the three listed above.
        TypeError: If ``norm`` is neither a string nor a callable.

    Note:
        With ``norm='batch'`` every block is *always* called with
        ``training=True`` (see ``call``). Per the Keras documentation,
        ``BatchNormalization`` then normalises with the statistics of the
        current batch instead of its moving averages, so running or exporting
        the same weights in inference mode gives different outputs.
    """

    def __init__(self, norm='batch'):
        # Keras expects the base initialiser to run before any attribute is
        # assigned on the model.
        super().__init__()

        if isinstance(norm, str):
            factories = {
                'batch': layers.BatchNormalization,
                'layer': lambda: layers.LayerNormalization(axis=(1,)),
                'layer_frozen': lambda: layers.LayerNormalization(
                    axis=(1,), trainable=False),
            }
            if norm not in factories:
                # Fail here with a clear message instead of a later
                # "'str' object is not callable".
                raise ValueError(
                    f"Unknown norm {norm!r}; expected one of "
                    f"{sorted(factories)} or a callable returning a layer."
                )
            make_norm = factories[norm]
        elif callable(norm):
            # A custom factory was implicitly supported before; keep it.
            make_norm = norm
        else:
            raise TypeError(
                "norm must be a string or a callable returning a layer, "
                f"got {type(norm).__name__}."
            )
        self.norm = norm

        # Attribute names and assignment order are kept stable so that
        # Keras layer tracking (and therefore saved weights) is unaffected.
        self.ConvBlock1 = self._conv_block(16, (1, 5, 5), make_norm)
        self.ConvBlock2 = self._conv_block(32, (3, 3, 3), make_norm)
        self.ConvBlock3 = self._conv_block(64, (3, 3, 3), make_norm)
        self.ConvBlock4 = self._conv_block(64, (3, 3, 3), make_norm)
        self.ConvBlock5 = self._conv_block(64, (3, 3, 3), make_norm)
        self.ConvBlock6 = self._conv_block(64, (3, 3, 3), make_norm)
        self.ConvBlock7 = self._conv_block(64, (3, 3, 3), make_norm)
        self.ConvBlock8 = self._conv_block(64, (3, 3, 3), make_norm)
        self.ConvBlock9 = self._conv_block(64, (3, 3, 3), make_norm)
        self.upsample = self._upsample_block(make_norm)
        self.upsample2 = self._upsample_block(make_norm)
        self.convBlock10 = layers.Conv3D(1, kernel_size=(1, 1, 1), strides=1)
        self.MaxpoolSpa = layers.MaxPool3D((1, 2, 2), strides=(1, 2, 2))
        self.MaxpoolSpaTem = layers.MaxPool3D((2, 2, 2), strides=2)
        self.poolspa = layers.AvgPool3D((1, 2, 2))
        self.flatten = layers.Reshape((-1,))

    @staticmethod
    def _conv_block(filters, kernel_size, make_norm):
        """Return a ``Conv3D -> normalisation -> ReLU`` block."""
        return keras.Sequential([
            layers.Conv3D(filters, kernel_size=kernel_size, strides=1,
                          padding='same'),
            make_norm(),
            layers.Activation('relu'),
        ])

    @staticmethod
    def _upsample_block(make_norm):
        """Return a ``Conv3DTranspose -> normalisation -> ELU`` block."""
        return keras.Sequential([
            layers.Conv3DTranspose(64, kernel_size=(4, 1, 1),
                                   strides=(2, 1, 1), padding='same'),
            make_norm(),
            layers.Activation('elu'),
        ])

    def call(self, x):
        """Run the forward pass.

        Args:
            x: Input tensor for the 3D convolutions; presumably laid out as
                (batch, frames, height, width, channels) -- TODO confirm
                against the caller.

        Returns:
            A tensor with one flattened row per sample, from which that row's
            mean has been subtracted.
        """
        # Batch-norm blocks are deliberately forced into training mode no
        # matter how the model is invoked; every other configuration runs its
        # blocks with training=False.
        training = self.norm == 'batch'

        x = self.ConvBlock1(x, training=training)
        x = self.MaxpoolSpa(x)  # pool the two trailing spatial axes only
        x = self.ConvBlock2(x, training=training)
        x = self.ConvBlock3(x, training=training)
        x = self.MaxpoolSpaTem(x)  # pool all three axes
        x = self.ConvBlock4(x, training=training)
        x = self.ConvBlock5(x, training=training)
        x = self.MaxpoolSpaTem(x)
        x = self.ConvBlock6(x, training=training)
        x = self.ConvBlock7(x, training=training)
        x = self.MaxpoolSpa(x)
        x = self.ConvBlock8(x, training=training)
        x = self.ConvBlock9(x, training=training)
        # Two stride-(2, 1, 1) transposed convolutions undo the two halvings
        # of axis 1 performed by MaxpoolSpaTem.
        x = self.upsample(x, training=training)
        x = self.upsample2(x, training=training)
        x = self.poolspa(x)
        x = self.convBlock10(x, training=training)
        x = self.flatten(x)
        # Centre each sample's output around zero.
        x = x-tf.expand_dims(tf.reduce_mean(x, axis=-1), -1)
        return x