Hi, I'd like some advice on a problem I'm having with an MLP built from DenseFlipout layers. I've been training the network for a regression task, using `tfp.layers.DenseFlipout` with mostly default settings:
```python
import tensorflow as tf
import tensorflow_probability as tfp
from tensorflow.keras import Model
from tensorflow.keras.layers import Input
from tensorflow.keras.optimizers import Adam

tfd = tfp.distributions


def create_flipout_bnn_model(train_size):
    def normal_sp(params):
        # First output column parameterises the mean, second the std dev
        return tfd.Normal(loc=params[:, 0:1],
                          scale=1e-3 + tf.math.softplus(0.05 * params[:, 1:2]))

    # KL divergence scaled by the training-set size
    kernel_divergence_fn = lambda q, p, _: tfp.distributions.kl_divergence(q, p) / train_size
    bias_divergence_fn = lambda q, p, _: tfp.distributions.kl_divergence(q, p) / train_size

    inputs = Input(shape=(1,), name="input layer")
    hidden = tfp.layers.DenseFlipout(50,
                                     kernel_divergence_fn=kernel_divergence_fn,
                                     activation="relu", name="DenseFlipout_layer_1")(inputs)
    hidden = tfp.layers.DenseFlipout(50,
                                     kernel_divergence_fn=kernel_divergence_fn,
                                     activation="relu", name="DenseFlipout_layer_2")(hidden)
    hidden = tfp.layers.DenseFlipout(50,
                                     kernel_divergence_fn=kernel_divergence_fn,
                                     activation="relu", name="DenseFlipout_layer_3")(hidden)
    params = tfp.layers.DenseFlipout(2,
                                     kernel_divergence_fn=kernel_divergence_fn,
                                     name="DenseFlipout_layer_5")(hidden)
    dist = tfp.layers.DistributionLambda(normal_sp, name="normal_sp")(params)
    return Model(inputs=inputs, outputs=dist)


flipout_BNN = create_flipout_bnn_model(train_size=train_size)
flipout_BNN.compile(optimizer=Adam(learning_rate=0.002),
                    loss=NLL,
                    metrics=[tf.keras.metrics.RootMeanSquaredError()])
flipout_BNN.summary()

history_flipout_BNN = flipout_BNN.fit(X_train, y_train, epochs=50000, verbose=0,
                                      batch_size=batch_size,
                                      validation_data=(X_val, y_val))
```
But the plot of the loss keeps showing spikes no matter how many epochs I train for. What can I do to avoid this? I suspect it's related to the weights being sampled from a distribution, but still, shouldn't those spikes disappear as training converges?