EEG classification using a CNN has very low accuracy and nan loss

Hello, I am trying to classify EEG signals with a CNN in TensorFlow, but the training accuracy is very low and does not improve even after several epochs. I originally trained the CNN on CPU; after switching to GPU, the loss suddenly became nan.

import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
import pandas as pd
import os
import numpy as np
import matplotlib.pyplot as plt
import json
import random

Data Preparation

def get_dataset(directory):
    feature = []
    label = []
    for foldername in os.listdir(directory):
        folder = os.path.join(directory, foldername)
        if os.path.isdir(folder):
            for filename in os.listdir(folder):
                rel_path = os.path.join(folder, filename)
                # filenames containing 'a' in the first token are alcoholic subjects
                temp_label = filename.split('.')[0].split('_')[0]
                if 'a' in temp_label:
                    label.append('alcoholic')
                else:
                    label.append('control')

                df_data = np.loadtxt(rel_path, delimiter=",")
                # decomp = np.arange(0, 366)
                # plt.plot(decomp, df_data)
                # plt.xlabel('Dimension Number')
                # plt.ylabel('Wavelet Bispectrum Energy')
                # plt.show()
                feature.append(df_data.T)
    return np.array(feature), np.array(label)
def get_batch(path):
    # load extracted features & labels
    x, y = get_dataset(path)
    y = pd.DataFrame(y)

    # encode the labels as 0/1
    label_map = {"alcoholic": 1, "control": 0}
    y[0] = y[0].map(label_map)

    # y = keras.utils.to_categorical(y[0])
    dataset = tf.data.Dataset.from_tensor_slices((x, y[0]))
    dataset = dataset.shuffle(len(y[0])).batch(32)

    return dataset

Model Definition

def create_model():
    model = keras.models.Sequential()

    model.add(layers.Input(shape=(366,)))
    model.add(layers.Reshape((366, 1)))

    model.add(layers.Conv1D(filters=16, kernel_size=4, activation="relu"))
    model.add(layers.MaxPooling1D(pool_size=2))

    model.add(layers.Conv1D(filters=8, kernel_size=2, activation="relu"))
    model.add(layers.MaxPooling1D(pool_size=2))

    model.add(layers.Flatten())

    model.add(layers.Dense(512, activation="relu"))

    model.add(layers.Dropout(0.5))

    model.add(layers.Dense(256, activation="relu"))

    model.add(layers.Dropout(0.5))

    model.add(layers.Dense(1, activation="sigmoid"))

    return model

Main program

train = get_batch('../Features/smni_cmi_train_feature_256')

train_size = int(len(list(train.as_numpy_iterator()))*0.8)

train_ds = train.take(train_size)
val_ds = train.skip(train_size)

model = create_model()
model.summary()

model.compile(loss='binary_crossentropy', optimizer=tf.keras.optimizers.Adam(0.2), metrics=['acc'])

# Train the model
history = model.fit(train_ds, epochs=100, validation_data=val_ds)

Output

_________________________________________________________________
 Layer (type)                Output Shape              Param #   
=================================================================
 reshape_4 (Reshape)         (None, 366, 1)            0         
                                                                 
 conv1d_8 (Conv1D)           (None, 363, 16)           80        
                                                                 
 max_pooling1d_8 (MaxPooling  (None, 181, 16)          0         
 1D)                                                             
                                                                 
 conv1d_9 (Conv1D)           (None, 180, 8)            264       
                                                                 
 max_pooling1d_9 (MaxPooling  (None, 90, 8)            0         
 1D)                                                             
                                                                 
 flatten_4 (Flatten)         (None, 720)               0         
                                                                 
 dense_30 (Dense)            (None, 512)               369152    
                                                                 
 dropout_20 (Dropout)        (None, 512)               0         
                                                                 
 dense_31 (Dense)            (None, 256)               131328    
                                                                 
 dropout_21 (Dropout)        (None, 256)               0         
                                                                 
 dense_32 (Dense)            (None, 1)                 257       
                                                                 
=================================================================
Total params: 501,081
Trainable params: 501,081
Non-trainable params: 0
_________________________________________________________________
Epoch 1/100
14/14 [==============================] - 1s 26ms/step - loss: nan - acc: 0.5000 - val_loss: nan - val_acc: 0.4344
Epoch 2/100
14/14 [==============================] - 0s 12ms/step - loss: nan - acc: 0.4643 - val_loss: nan - val_acc: 0.5000
Epoch 3/100
14/14 [==============================] - 0s 12ms/step - loss: nan - acc: 0.4821 - val_loss: nan - val_acc: 0.4754
[... epochs 4-99 omitted: loss stays nan and accuracy hovers around 0.47 throughout ...]
Epoch 100/100
14/14 [==============================] - 0s 9ms/step - loss: nan - acc: 0.4710 - val_loss: nan - val_acc: 0.4590

@fulky,

Welcome to the Tensorflow Forum!

It looks like a few issues in your code could be causing the low accuracy and the nan loss. Here are some things to try:

  1. Data normalization: Normalize your input features (e.g. to zero mean and unit variance) before training.
  2. Lower the learning rate: A learning rate of 0.2 is very high for Adam and can cause the loss to diverge to nan. Try something closer to the default of 0.001 and see if the model improves.
  3. Add early stopping: Use early stopping to halt training once the validation loss stops improving.
  4. Increase the batch size: A larger batch size can stabilize gradient estimates and may improve accuracy.
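The suggestions above can be sketched roughly as follows. This is a minimal illustration, not your actual pipeline: `x_train`/`y_train` are random stand-ins for the arrays returned by `get_dataset`, and the tiny dense model is just a placeholder for your CNN.

```python
import numpy as np
from tensorflow import keras

# Stand-in data with the post's feature dimension (366); values are made up.
x_train = np.random.rand(100, 366).astype("float32")
y_train = np.random.randint(0, 2, size=(100,))

# 1. Normalize features to zero mean / unit variance. Guard against
#    zero-variance columns, which would otherwise divide by zero -> nan.
mean = x_train.mean(axis=0)
std = x_train.std(axis=0)
std[std == 0] = 1.0
x_train = (x_train - mean) / std

# Placeholder model; substitute your Conv1D architecture here.
model = keras.Sequential([
    keras.layers.Input(shape=(366,)),
    keras.layers.Dense(64, activation="relu"),
    keras.layers.Dense(1, activation="sigmoid"),
])

# 2. Use a much smaller learning rate than 0.2 (Adam's default is 0.001).
model.compile(loss="binary_crossentropy",
              optimizer=keras.optimizers.Adam(learning_rate=1e-3),
              metrics=["acc"])

# 3. Stop early once the validation loss plateaus.
early_stop = keras.callbacks.EarlyStopping(monitor="val_loss", patience=10,
                                           restore_best_weights=True)

history = model.fit(x_train, y_train, validation_split=0.2, batch_size=32,
                    epochs=5, callbacks=[early_stop], verbose=0)
```

With normalized inputs and a sane learning rate, the loss should stay finite from the first epoch.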

Thank you!

Thank you for the suggestions @chunduriv. It turns out the problem was in the dataset itself: it contained many all-zero columns. After filtering those out, the model performs quite well.
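For anyone hitting the same issue, here is one way to drop all-zero columns with NumPy. The toy matrix below is illustrative; the real features have shape `(n_samples, 366)` as in the post.

```python
import numpy as np

# Toy feature matrix; the middle column is identically zero.
x = np.array([[1.0, 0.0, 3.0],
              [2.0, 0.0, 4.0],
              [5.0, 0.0, 6.0]])

# Boolean mask of columns that are NOT all zero.
nonzero_cols = ~np.all(x == 0, axis=0)
x_filtered = x[:, nonzero_cols]

print(x_filtered.shape)  # the all-zero column is dropped: (3, 2)
```

When splitting into train/validation/test sets, compute the mask on the training set only and apply the same mask everywhere, so all splits keep identical columns.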