Hello, I recently try to classify an EEG signal using CNN on tensorflow, but somehow the training have very low accuracy and it doesn’t seem to get better even after a few epocs, before I use CPU to train the CNN after i switch to GPU the loss suddenly became nan.
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
import pandas as pd
import os
import numpy as np
import matplotlib.pyplot as plt
import json
import random
Data Preparation
def get_dataset(directory):
feature = []
label = []
for foldername in os.listdir(directory):
folder = os.path.join(directory, foldername)
if os.path.isdir(folder):
files = os.listdir(folder)
for filename in files:
rel_path = os.path.join(directory, foldername, filename)
temp_label = filename.split('.')[0].split('_')[0]
if 'a' in temp_label:
label = np.append(label,'alcoholic')
else:
label = np.append(label,'control')
df_data = np.loadtxt(rel_path, delimiter=",")
# decomp = np.arange(0, 366)
# plt.plot(decomp, df_data)
# plt.xlabel('Dimension Number')
# plt.ylabel('Wavelet Bispectrum Energy')
# plt.show()
feature.append(df_data.T)
return np.array(feature), np.array(label)
def get_batch(path):
# loading extracted feature & label
x, y = get_dataset(path)
y = pd.DataFrame(y)
# Encode the labels
label_map = {"alcoholic": 1, "control": 0}
y[0] = y[0].map(label_map)
# y = keras.utils.to_categorical(y[0])
dataset = tf.data.Dataset.from_tensor_slices((x, y[0] ))
dataset = dataset.shuffle(len(y[0] )).batch(32)
return dataset
Model Definition
def create_model():
model = keras.models.Sequential()
model.add(layers.Input(shape=(366,)))
model.add(layers.Reshape((366, 1)))
model.add(layers.Conv1D(filters=16, kernel_size=4, activation="relu"))
model.add(layers.MaxPooling1D(pool_size=2))
model.add(layers.Conv1D(filters=8, kernel_size=2, activation="relu"))
model.add(layers.MaxPooling1D(pool_size=2))
model.add(layers.Flatten())
model.add(layers.Dense(512, activation="relu"))
model.add(layers.Dropout(0.5))
model.add(layers.Dense(256, activation="relu"))
model.add(layers.Dropout(0.5))
model.add(layers.Dense(1, activation="sigmoid"))
return model
Main program
train = get_batch('../Features/smni_cmi_train_feature_256')
train_size = int(len(list(train.as_numpy_iterator()))*0.8)
train_ds = train.take(train_size)
val_ds = train.skip(train_size)
model = create_model()
model.summary()
model.compile(loss='binary_crossentropy', optimizer= tf.keras.optimizers.Adam(0.2), metrics=['acc'])
# Train the model
history = model.fit(train_ds, epochs=100, validation_data=(val_ds))
Output
_________________________________________________________________
Layer (type) Output Shape Param #
=================================================================
reshape_4 (Reshape) (None, 366, 1) 0
conv1d_8 (Conv1D) (None, 363, 16) 80
max_pooling1d_8 (MaxPooling (None, 181, 16) 0
1D)
conv1d_9 (Conv1D) (None, 180, 8) 264
max_pooling1d_9 (MaxPooling (None, 90, 8) 0
1D)
flatten_4 (Flatten) (None, 720) 0
dense_30 (Dense) (None, 512) 369152
dropout_20 (Dropout) (None, 512) 0
dense_31 (Dense) (None, 256) 131328
dropout_21 (Dropout) (None, 256) 0
dense_32 (Dense) (None, 1) 257
=================================================================
Total params: 501,081
Trainable params: 501,081
Non-trainable params: 0
_________________________________________________________________
Epoch 1/100
14/14 [==============================] - 1s 26ms/step - loss: nan - acc: 0.5000 - val_loss: nan - val_acc: 0.4344
Epoch 2/100
14/14 [==============================] - 0s 12ms/step - loss: nan - acc: 0.4643 - val_loss: nan - val_acc: 0.5000
Epoch 3/100
14/14 [==============================] - 0s 12ms/step - loss: nan - acc: 0.4821 - val_loss: nan - val_acc: 0.4754
Epoch 4/100
14/14 [==============================] - 0s 10ms/step - loss: nan - acc: 0.4866 - val_loss: nan - val_acc: 0.4918
Epoch 5/100
14/14 [==============================] - 0s 9ms/step - loss: nan - acc: 0.4598 - val_loss: nan - val_acc: 0.4426
Epoch 6/100
14/14 [==============================] - 0s 9ms/step - loss: nan - acc: 0.4955 - val_loss: nan - val_acc: 0.4836
Epoch 7/100
14/14 [==============================] - 0s 8ms/step - loss: nan - acc: 0.4754 - val_loss: nan - val_acc: 0.4754
Epoch 8/100
14/14 [==============================] - 0s 8ms/step - loss: nan - acc: 0.4643 - val_loss: nan - val_acc: 0.4836
Epoch 9/100
14/14 [==============================] - 0s 8ms/step - loss: nan - acc: 0.4799 - val_loss: nan - val_acc: 0.4754
Epoch 10/100
14/14 [==============================] - 0s 8ms/step - loss: nan - acc: 0.4978 - val_loss: nan - val_acc: 0.4590
Epoch 11/100
14/14 [==============================] - 0s 10ms/step - loss: nan - acc: 0.4688 - val_loss: nan - val_acc: 0.5410
Epoch 12/100
14/14 [==============================] - 0s 9ms/step - loss: nan - acc: 0.4844 - val_loss: nan - val_acc: 0.4918
Epoch 13/100
14/14 [==============================] - 0s 9ms/step - loss: nan - acc: 0.4665 - val_loss: nan - val_acc: 0.4672
Epoch 14/100
14/14 [==============================] - 0s 10ms/step - loss: nan - acc: 0.4621 - val_loss: nan - val_acc: 0.4508
Epoch 15/100
14/14 [==============================] - 0s 8ms/step - loss: nan - acc: 0.4821 - val_loss: nan - val_acc: 0.4508
Epoch 16/100
14/14 [==============================] - 0s 8ms/step - loss: nan - acc: 0.4710 - val_loss: nan - val_acc: 0.4508
Epoch 17/100
14/14 [==============================] - 0s 8ms/step - loss: nan - acc: 0.4754 - val_loss: nan - val_acc: 0.4590
Epoch 18/100
14/14 [==============================] - 0s 8ms/step - loss: nan - acc: 0.4844 - val_loss: nan - val_acc: 0.4426
Epoch 19/100
14/14 [==============================] - 0s 8ms/step - loss: nan - acc: 0.4643 - val_loss: nan - val_acc: 0.4590
Epoch 20/100
14/14 [==============================] - 0s 8ms/step - loss: nan - acc: 0.4665 - val_loss: nan - val_acc: 0.4754
Epoch 21/100
14/14 [==============================] - 0s 8ms/step - loss: nan - acc: 0.4844 - val_loss: nan - val_acc: 0.4672
Epoch 22/100
14/14 [==============================] - 0s 8ms/step - loss: nan - acc: 0.4688 - val_loss: nan - val_acc: 0.5328
Epoch 23/100
14/14 [==============================] - 0s 8ms/step - loss: nan - acc: 0.4799 - val_loss: nan - val_acc: 0.5082
Epoch 24/100
14/14 [==============================] - 0s 8ms/step - loss: nan - acc: 0.4732 - val_loss: nan - val_acc: 0.4262
Epoch 25/100
14/14 [==============================] - 0s 8ms/step - loss: nan - acc: 0.4799 - val_loss: nan - val_acc: 0.4262
Epoch 26/100
14/14 [==============================] - 0s 9ms/step - loss: nan - acc: 0.4688 - val_loss: nan - val_acc: 0.4672
Epoch 27/100
14/14 [==============================] - 0s 8ms/step - loss: nan - acc: 0.4933 - val_loss: nan - val_acc: 0.4672
Epoch 28/100
14/14 [==============================] - 0s 8ms/step - loss: nan - acc: 0.4643 - val_loss: nan - val_acc: 0.4426
Epoch 29/100
14/14 [==============================] - 0s 9ms/step - loss: nan - acc: 0.4665 - val_loss: nan - val_acc: 0.4262
Epoch 30/100
14/14 [==============================] - 0s 9ms/step - loss: nan - acc: 0.4754 - val_loss: nan - val_acc: 0.5082
Epoch 31/100
14/14 [==============================] - 0s 8ms/step - loss: nan - acc: 0.4754 - val_loss: nan - val_acc: 0.5000
Epoch 32/100
14/14 [==============================] - 0s 8ms/step - loss: nan - acc: 0.4487 - val_loss: nan - val_acc: 0.4836
Epoch 33/100
14/14 [==============================] - 0s 8ms/step - loss: nan - acc: 0.4754 - val_loss: nan - val_acc: 0.5492
Epoch 34/100
14/14 [==============================] - 0s 9ms/step - loss: nan - acc: 0.4665 - val_loss: nan - val_acc: 0.4672
Epoch 35/100
14/14 [==============================] - 0s 8ms/step - loss: nan - acc: 0.4754 - val_loss: nan - val_acc: 0.4180
Epoch 36/100
14/14 [==============================] - 0s 11ms/step - loss: nan - acc: 0.4844 - val_loss: nan - val_acc: 0.4508
Epoch 37/100
14/14 [==============================] - 0s 10ms/step - loss: nan - acc: 0.4821 - val_loss: nan - val_acc: 0.4672
Epoch 38/100
14/14 [==============================] - 0s 8ms/step - loss: nan - acc: 0.4911 - val_loss: nan - val_acc: 0.4262
Epoch 39/100
14/14 [==============================] - 0s 8ms/step - loss: nan - acc: 0.4710 - val_loss: nan - val_acc: 0.5410
Epoch 40/100
14/14 [==============================] - 0s 8ms/step - loss: nan - acc: 0.4732 - val_loss: nan - val_acc: 0.4918
Epoch 41/100
14/14 [==============================] - 0s 10ms/step - loss: nan - acc: 0.4732 - val_loss: nan - val_acc: 0.4590
Epoch 42/100
14/14 [==============================] - 0s 10ms/step - loss: nan - acc: 0.4777 - val_loss: nan - val_acc: 0.4508
Epoch 43/100
14/14 [==============================] - 0s 10ms/step - loss: nan - acc: 0.4754 - val_loss: nan - val_acc: 0.4836
Epoch 44/100
14/14 [==============================] - 0s 9ms/step - loss: nan - acc: 0.4799 - val_loss: nan - val_acc: 0.4590
Epoch 45/100
14/14 [==============================] - 0s 8ms/step - loss: nan - acc: 0.4710 - val_loss: nan - val_acc: 0.4672
Epoch 46/100
14/14 [==============================] - 0s 9ms/step - loss: nan - acc: 0.4777 - val_loss: nan - val_acc: 0.5328
Epoch 47/100
14/14 [==============================] - 0s 8ms/step - loss: nan - acc: 0.4911 - val_loss: nan - val_acc: 0.5410
Epoch 48/100
14/14 [==============================] - 0s 8ms/step - loss: nan - acc: 0.4732 - val_loss: nan - val_acc: 0.4918
Epoch 49/100
14/14 [==============================] - 0s 8ms/step - loss: nan - acc: 0.4777 - val_loss: nan - val_acc: 0.4754
Epoch 50/100
14/14 [==============================] - 0s 8ms/step - loss: nan - acc: 0.4754 - val_loss: nan - val_acc: 0.4426
Epoch 51/100
14/14 [==============================] - 0s 8ms/step - loss: nan - acc: 0.4799 - val_loss: nan - val_acc: 0.5328
Epoch 52/100
14/14 [==============================] - 0s 9ms/step - loss: nan - acc: 0.4799 - val_loss: nan - val_acc: 0.4262
Epoch 53/100
14/14 [==============================] - 0s 8ms/step - loss: nan - acc: 0.4799 - val_loss: nan - val_acc: 0.4918
Epoch 54/100
14/14 [==============================] - 0s 9ms/step - loss: nan - acc: 0.4710 - val_loss: nan - val_acc: 0.4344
Epoch 55/100
14/14 [==============================] - 0s 8ms/step - loss: nan - acc: 0.4799 - val_loss: nan - val_acc: 0.4344
Epoch 56/100
14/14 [==============================] - 0s 8ms/step - loss: nan - acc: 0.4732 - val_loss: nan - val_acc: 0.4590
Epoch 57/100
14/14 [==============================] - 0s 8ms/step - loss: nan - acc: 0.4621 - val_loss: nan - val_acc: 0.4180
Epoch 58/100
14/14 [==============================] - 0s 8ms/step - loss: nan - acc: 0.4688 - val_loss: nan - val_acc: 0.4180
Epoch 59/100
14/14 [==============================] - 0s 8ms/step - loss: nan - acc: 0.4665 - val_loss: nan - val_acc: 0.4918
Epoch 60/100
14/14 [==============================] - 0s 8ms/step - loss: nan - acc: 0.4777 - val_loss: nan - val_acc: 0.4754
Epoch 61/100
14/14 [==============================] - 0s 8ms/step - loss: nan - acc: 0.4866 - val_loss: nan - val_acc: 0.4180
Epoch 62/100
14/14 [==============================] - 0s 8ms/step - loss: nan - acc: 0.4844 - val_loss: nan - val_acc: 0.4754
Epoch 63/100
14/14 [==============================] - 0s 9ms/step - loss: nan - acc: 0.4754 - val_loss: nan - val_acc: 0.4344
Epoch 64/100
14/14 [==============================] - 0s 8ms/step - loss: nan - acc: 0.4710 - val_loss: nan - val_acc: 0.4344
Epoch 65/100
14/14 [==============================] - 0s 9ms/step - loss: nan - acc: 0.4688 - val_loss: nan - val_acc: 0.4754
Epoch 66/100
14/14 [==============================] - 0s 8ms/step - loss: nan - acc: 0.4754 - val_loss: nan - val_acc: 0.4918
Epoch 67/100
14/14 [==============================] - 0s 8ms/step - loss: nan - acc: 0.4821 - val_loss: nan - val_acc: 0.4590
Epoch 68/100
14/14 [==============================] - 0s 9ms/step - loss: nan - acc: 0.4844 - val_loss: nan - val_acc: 0.4344
Epoch 69/100
14/14 [==============================] - 0s 8ms/step - loss: nan - acc: 0.4777 - val_loss: nan - val_acc: 0.4590
Epoch 70/100
14/14 [==============================] - 0s 8ms/step - loss: nan - acc: 0.4799 - val_loss: nan - val_acc: 0.5082
Epoch 71/100
14/14 [==============================] - 0s 8ms/step - loss: nan - acc: 0.4732 - val_loss: nan - val_acc: 0.4836
Epoch 72/100
14/14 [==============================] - 0s 8ms/step - loss: nan - acc: 0.4509 - val_loss: nan - val_acc: 0.4918
Epoch 73/100
14/14 [==============================] - 0s 8ms/step - loss: nan - acc: 0.4777 - val_loss: nan - val_acc: 0.4426
Epoch 74/100
14/14 [==============================] - 0s 9ms/step - loss: nan - acc: 0.4777 - val_loss: nan - val_acc: 0.4098
Epoch 75/100
14/14 [==============================] - 0s 8ms/step - loss: nan - acc: 0.4420 - val_loss: nan - val_acc: 0.4508
Epoch 76/100
14/14 [==============================] - 0s 9ms/step - loss: nan - acc: 0.4777 - val_loss: nan - val_acc: 0.4262
Epoch 77/100
14/14 [==============================] - 0s 9ms/step - loss: nan - acc: 0.4777 - val_loss: nan - val_acc: 0.5000
Epoch 78/100
14/14 [==============================] - 0s 8ms/step - loss: nan - acc: 0.4621 - val_loss: nan - val_acc: 0.4426
Epoch 79/100
14/14 [==============================] - 0s 8ms/step - loss: nan - acc: 0.4844 - val_loss: nan - val_acc: 0.4180
Epoch 80/100
14/14 [==============================] - 0s 9ms/step - loss: nan - acc: 0.4732 - val_loss: nan - val_acc: 0.4426
Epoch 81/100
14/14 [==============================] - 0s 8ms/step - loss: nan - acc: 0.4911 - val_loss: nan - val_acc: 0.4508
Epoch 82/100
14/14 [==============================] - 0s 9ms/step - loss: nan - acc: 0.4710 - val_loss: nan - val_acc: 0.4426
Epoch 83/100
14/14 [==============================] - 0s 8ms/step - loss: nan - acc: 0.4487 - val_loss: nan - val_acc: 0.3852
Epoch 84/100
14/14 [==============================] - 0s 8ms/step - loss: nan - acc: 0.4821 - val_loss: nan - val_acc: 0.5164
Epoch 85/100
14/14 [==============================] - 0s 8ms/step - loss: nan - acc: 0.4955 - val_loss: nan - val_acc: 0.5082
Epoch 86/100
14/14 [==============================] - 0s 9ms/step - loss: nan - acc: 0.4821 - val_loss: nan - val_acc: 0.4836
Epoch 87/100
14/14 [==============================] - 0s 8ms/step - loss: nan - acc: 0.4688 - val_loss: nan - val_acc: 0.5492
Epoch 88/100
14/14 [==============================] - 0s 8ms/step - loss: nan - acc: 0.4710 - val_loss: nan - val_acc: 0.4508
Epoch 89/100
14/14 [==============================] - 0s 8ms/step - loss: nan - acc: 0.4866 - val_loss: nan - val_acc: 0.5000
Epoch 90/100
14/14 [==============================] - 0s 8ms/step - loss: nan - acc: 0.4777 - val_loss: nan - val_acc: 0.4590
Epoch 91/100
14/14 [==============================] - 0s 9ms/step - loss: nan - acc: 0.4799 - val_loss: nan - val_acc: 0.5000
Epoch 92/100
14/14 [==============================] - 0s 9ms/step - loss: nan - acc: 0.4844 - val_loss: nan - val_acc: 0.4672
Epoch 93/100
14/14 [==============================] - 0s 8ms/step - loss: nan - acc: 0.4688 - val_loss: nan - val_acc: 0.4508
Epoch 94/100
14/14 [==============================] - 0s 8ms/step - loss: nan - acc: 0.4866 - val_loss: nan - val_acc: 0.5000
Epoch 95/100
14/14 [==============================] - 0s 8ms/step - loss: nan - acc: 0.4665 - val_loss: nan - val_acc: 0.4590
Epoch 96/100
14/14 [==============================] - 0s 8ms/step - loss: nan - acc: 0.4464 - val_loss: nan - val_acc: 0.4754
Epoch 97/100
14/14 [==============================] - 0s 9ms/step - loss: nan - acc: 0.4710 - val_loss: nan - val_acc: 0.5656
Epoch 98/100
14/14 [==============================] - 0s 8ms/step - loss: nan - acc: 0.4844 - val_loss: nan - val_acc: 0.4426
Epoch 99/100
14/14 [==============================] - 0s 8ms/step - loss: nan - acc: 0.4643 - val_loss: nan - val_acc: 0.4754
Epoch 100/100
14/14 [==============================] - 0s 9ms/step - loss: nan - acc: 0.4710 - val_loss: nan - val_acc: 0.4590