I was training a neural network on the CIFAR-10 dataset, but the loss became NaN during the first epoch when running on my laptop's GTX 1650. I tried normalizing the data with tf.keras.layers.Normalization (mean 0, standard deviation 1), and I also tried tf.keras.layers.Rescaling(1./255) to get values between 0 and 1. I added a LossScaleOptimizer to prevent underflow and set clipnorm = 1 on the optimizer to prevent overflow, but none of these helped. However, when I copied the code to Colab and used a GPU runtime, training succeeded without any NaN loss.
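For reference, the normalization attempts looked roughly like this (a minimal sketch of what I tried, assuming x_train from the CIFAR-10 load below; only the Rescaling variant is kept in the model code):

import tensorflow as tf
from tensorflow.keras.layers import Normalization, Rescaling

# Attempt 1: standardize the inputs to mean 0 and standard deviation 1
norm = Normalization()
norm.adapt(x_train.astype('float32'))   # compute per-channel mean/variance from the training set

# Attempt 2: simply rescale pixel values from [0, 255] into [0, 1]
rescale = Rescaling(1./255)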
Code
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
import tensorflow as tf
import seaborn as sns
sns.set(style="dark")
from tensorflow.keras.datasets.cifar10 import load_data
from tensorflow.keras import Model, Input, Sequential
from tensorflow.keras.layers import Add, Rescaling, Dense, Conv2D, GlobalAveragePooling2D, MaxPool2D, Dropout, BatchNormalization, ReLU, Layer, Reshape, Flatten, Activation, Normalization, Multiply, AveragePooling2D
from tensorflow.keras.utils import to_categorical
from tensorflow.keras.callbacks import ModelCheckpoint, EarlyStopping, TensorBoard, TerminateOnNaN, CSVLogger
from tensorflow.keras.optimizers import SGD, Adam
from tensorflow.keras.losses import CategoricalCrossentropy
from tensorflow.keras.regularizers import l2
# from tensorflow.keras.applications import efficientnet_v2
from functools import partial
from tensorflow.image import random_flip_left_right, random_crop, resize_with_crop_or_pad
from tensorflow.keras.models import load_model
from tensorflow.keras import mixed_precision
physical_devices = tf.config.list_physical_devices('GPU')
tf.config.experimental.set_memory_growth(physical_devices[0], True)
mixed_precision.set_global_policy('mixed_float16')  # if this line is deleted, there will be an out-of-memory error
(x_train, y_train), (x_test, y_test) = load_data()
regparam = 0.0005
class WideResNet(Model):
    def __init__(self, activation, numfilters, identity=False):
        super().__init__()
        self.add = Add()
        self.activation = Activation(activation)
        self.batchnorm = BatchNormalization()
        self.batchnorm2 = BatchNormalization()
        self.mainconv1 = Conv2D(numfilters, (3, 3), padding='same', kernel_regularizer=l2(regparam))
        self.mainconv2 = Conv2D(numfilters, (3, 3), padding='same', kernel_regularizer=l2(regparam))
        # 1x1 convolution on the shortcut when the channel count changes, identity otherwise
        self.sideconv = Conv2D(numfilters, (1, 1), padding='same', kernel_regularizer=l2(regparam)) if not identity else None

    def call(self, X):
        # Pre-activation residual block: BN -> activation -> conv, applied twice
        mainbranch = self.batchnorm(X)
        mainbranch = self.activation(mainbranch)
        mainbranch = self.mainconv1(mainbranch)
        mainbranch = self.batchnorm2(mainbranch)
        mainbranch = self.activation(mainbranch)
        mainbranch = self.mainconv2(mainbranch)
        sidebranch = self.sideconv(X) if self.sideconv is not None else X
        return self.add([mainbranch, sidebranch])
def buildwrnmodel(k, shapes=[16, 32, 64], n_inner_layers=4, n_classes=10, imageshape=(32, 32, 3),
                  loss='sparse_categorical_crossentropy', activation='relu', optimizer='adam', metric='accuracy'):
    inputs = Input(imageshape)
    x = Rescaling(1./255)(inputs)
    x = Conv2D(16, (3, 3))(x)
    x = BatchNormalization()(x)
    x = Activation(activation)(x)
    # Three groups of residual blocks; the filter count is widened by the factor k
    for i, length in enumerate(shapes):
        for j in range(n_inner_layers):
            x = WideResNet(activation, length * k, identity=(j != 0))(x)
    x = BatchNormalization()(x)
    x = Activation('tanh')(x)
    x = GlobalAveragePooling2D()(x)
    x = Flatten()(x)
    # keep the final layer in float32 so the softmax stays numerically stable under mixed precision
    outputs = Dense(n_classes, activation="softmax", dtype='float32')(x)
    model = Model(inputs=inputs, outputs=outputs)
    model.compile(loss=loss, optimizer=optimizer, metrics=[metric])
    model.summary()
    return model
from tensorflow.keras.mixed_precision import LossScaleOptimizer
from tensorflow.keras.layers import LeakyReLU

wrn = buildwrnmodel(7, optimizer=LossScaleOptimizer(tf.keras.optimizers.Adam(clipnorm=1), initial_scale=2**30), activation=LeakyReLU(0.1))
history = wrn.fit(x=x_train, y=y_train, epochs=50, batch_size=16, validation_split=0.2, verbose=1)
Output for Local / GTX 1650:
Epoch 1/50
2500/2500 [==============================] - ETA: 0s - loss: nan - accuracy: 0.0997 # <---- the nan cost here
Output for Colab:
2500/2500 [==============================] - 350s 131ms/step - loss: 2.0917 - accuracy: 0.3268 - val_loss: 2.2110 - val_accuracy: 0.2778
Epoch 2/50
2500/2500 [==============================] - 326s 131ms/step - loss: 1.6841 - accuracy: 0.4144 - val_loss: 2.0834 - val_accuracy: 0.2945
Epoch 3/50
2500/2500 [==============================] - 326s 130ms/step - loss: 1.5518 - accuracy: 0.4737 - val_loss: 1.7759 - val_accuracy: 0.4252
Epoch 4/50
2500/2500 [==============================] - 325s 130ms/step - loss: 1.4494 - accuracy: 0.5221 - val_loss: 1.7152 - val_accuracy: 0.4548
Epoch 5/50
2500/2500 [==============================] - 325s 130ms/step - loss: 1.3813 - accuracy: 0.5464 - val_loss: 1.9141 - val_accuracy: 0.3800
Epoch 6/50
2500/2500 [==============================] - 325s 130ms/step - loss: 1.3320 - accuracy: 0.5684 - val_loss: 1.5846 - val_accuracy: 0.4920
Epoch 7/50
2500/2500 [==============================] - 324s 129ms/step - loss: 1.2882 - accuracy: 0.5847 - val_loss: 1.7444 - val_accuracy: 0.4798
Epoch 8/50
2500/2500 [==============================] - 324s 130ms/step - loss: 1.2460 - accuracy: 0.6057 - val_loss: 1.2865 - val_accuracy: 0.5981
Epoch 9/50
2500/2500 [==============================] - 324s 130ms/step - loss: 1.2215 - accuracy: 0.6112 - val_loss: 1.5941 - val_accuracy: 0.4577
Epoch 10/50
2500/2500 [==============================] - 323s 129ms/step - loss: 1.1926 - accuracy: 0.6244 - val_loss: 1.5356 - val_accuracy: 0.5154
Epoch 11/50
2500/2500 [==============================] - 323s 129ms/step - loss: 1.1703 - accuracy: 0.6353 - val_loss: 1.6718 - val_accuracy: 0.4706
Epoch 12/50
2500/2500 [==============================] - 323s 129ms/step - loss: 1.1521 - accuracy: 0.6450 - val_loss: 1.4850 - val_accuracy: 0.5209
Epoch 13/50
2500/2500 [==============================] - 324s 130ms/step - loss: 1.1241 - accuracy: 0.6562 - val_loss: 1.7300 - val_accuracy: 0.4685
Epoch 14/50
2500/2500 [==============================] - 324s 130ms/step - loss: 1.1133 - accuracy: 0.6625 - val_loss: 2.5892 - val_accuracy: 0.3180
Epoch 15/50
2500/2500 [==============================] - 324s 130ms/step - loss: 1.0970 - accuracy: 0.6719 - val_loss: 1.2511 - val_accuracy: 0.6163
Epoch 16/50
2500/2500 [==============================] - 324s 129ms/step - loss: 1.0848 - accuracy: 0.6785 - val_loss: 1.6947 - val_accuracy: 0.5217
Epoch 17/50
2500/2500 [==============================] - 323s 129ms/step - loss: 1.0742 - accuracy: 0.6823 - val_loss: 2.1976 - val_accuracy: 0.4288
Epoch 18/50
2500/2500 [==============================] - 324s 130ms/step - loss: 1.0602 - accuracy: 0.6896 - val_loss: 1.5810 - val_accuracy: 0.5695
Epoch 19/50
2500/2500 [==============================] - 323s 129ms/step - loss: 1.0435 - accuracy: 0.6962 - val_loss: 1.4429 - val_accuracy: 0.5653
Epoch 20/50
2500/2500 [==============================] - 323s 129ms/step - loss: 1.0381 - accuracy: 0.6973 - val_loss: 1.5911 - val_accuracy: 0.5423
Epoch 21/50
2500/2500 [==============================] - 323s 129ms/step - loss: 1.0231 - accuracy: 0.7044 - val_loss: 1.4593 - val_accuracy: 0.5889
Epoch 22/50
2500/2500 [==============================] - 322s 129ms/step - loss: 1.0160 - accuracy: 0.7096 - val_loss: 1.4631 - val_accuracy: 0.5841
Epoch 23/50
2500/2500 [==============================] - 322s 129ms/step - loss: 1.0110 - accuracy: 0.7095 - val_loss: 1.8995 - val_accuracy: 0.5124
Epoch 24/50
2500/2500 [==============================] - 322s 129ms/step - loss: 0.9988 - accuracy: 0.7141 - val_loss: 1.1256 - val_accuracy: 0.6848
Epoch 25/50
2500/2500 [==============================] - 322s 129ms/step - loss: 0.9927 - accuracy: 0.7188 - val_loss: 1.9539 - val_accuracy: 0.4719
Epoch 26/50
2500/2500 [==============================] - 322s 129ms/step - loss: 0.9923 - accuracy: 0.7165 - val_loss: 1.4381 - val_accuracy: 0.6026
Epoch 27/50
2500/2500 [==============================] - 323s 129ms/step - loss: 0.9826 - accuracy: 0.7223 - val_loss: 2.3859 - val_accuracy: 0.4096
Epoch 28/50
2500/2500 [==============================] - 322s 129ms/step - loss: 0.9820 - accuracy: 0.7217 - val_loss: 1.7952 - val_accuracy: 0.5303
Epoch 29/50
2500/2500 [==============================] - 321s 129ms/step - loss: 0.9767 - accuracy: 0.7260 - val_loss: 1.5632 - val_accuracy: 0.5590
Epoch 30/50
2500/2500 [==============================] - 321s 129ms/step - loss: 0.9643 - accuracy: 0.7307 - val_loss: 2.1064 - val_accuracy: 0.4547
Epoch 31/50
2500/2500 [==============================] - 322s 129ms/step - loss: 0.9550 - accuracy: 0.7322 - val_loss: 4.3578 - val_accuracy: 0.3707
Epoch 32/50
2500/2500 [==============================] - 321s 128ms/step - loss: 0.9596 - accuracy: 0.7326 - val_loss: 2.3511 - val_accuracy: 0.4620
Epoch 33/50
2500/2500 [==============================] - 321s 128ms/step - loss: 0.9552 - accuracy: 0.7345 - val_loss: 3.2045 - val_accuracy: 0.3117
Epoch 34/50
2500/2500 [==============================] - 322s 129ms/step - loss: 0.9454 - accuracy: 0.7367 - val_loss: 2.4369 - val_accuracy: 0.4574
Epoch 35/50
1970/2500 [======================>.......] - ETA: 1:04 - loss: 0.9404 - accuracy: 0.7360
I suspected that mixed_precision.set_global_policy('mixed_float16') caused the problem in the local environment, since float16 has a much smaller range than float32. However, the same mixed-precision policy (with the same float16 limitations as on my laptop) does not seem to cause any problem on Colab.
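To illustrate the float16 range concern, here is a minimal sketch (not part of the training code) of how a value that fits comfortably in float32 overflows in float16 and then turns into NaN:

import numpy as np
import tensorflow as tf

print(np.finfo(np.float16).max)                 # 65504.0, the largest finite float16 value
x16 = tf.constant(70000.0, dtype=tf.float16)    # overflows to inf
print(x16)
print(x16 - x16)                                # inf - inf produces nan
print(tf.constant(70000.0, dtype=tf.float32))   # the same value is fine in float32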