I have a CNN with which I want to classify sleep stages. I divided the data into epochs of 30 seconds (about 1.9 million in total) and wrote all epochs into a TFRecord file (because, as far as I have researched, that should be a pretty fast way to retrieve the data). During training I load the data into tf.data datasets as shown in the code below.
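For context, each 30-second epoch was serialized roughly like this (a simplified sketch, my actual writer script also does the preprocessing and is not shown here in full; the feature names match the parsing function below):

import tensorflow as tf

def serialize_epoch(ecg, relative_time, relative_position, activity_normalized, stage):
    # one 30-second epoch -> one tf.train.Example
    feature = {
        'ecg': tf.train.Feature(float_list=tf.train.FloatList(value=ecg)),  # 7680 samples
        'relative_time': tf.train.Feature(float_list=tf.train.FloatList(value=[relative_time])),
        'relative_position': tf.train.Feature(float_list=tf.train.FloatList(value=[relative_position])),
        'activity_normalized': tf.train.Feature(float_list=tf.train.FloatList(value=[activity_normalized])),
        'stage': tf.train.Feature(int64_list=tf.train.Int64List(value=[stage]))
    }
    return tf.train.Example(features=tf.train.Features(feature=feature)).SerializeToString()

# with tf.io.TFRecordWriter(r"d:\sleep_data\tf_record\combined.tfrecord") as writer:
#     for epoch_features in all_epochs:
#         writer.write(serialize_epoch(*epoch_features))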
Now to my problem / question:
Currently one training epoch takes about two hours, plus validation. I am not sure if that's a normal time for my architecture. I run the training on my GPU, an RTX 3090, so it should be well suited for this. So, looking at my code, is there anything obvious I could improve so that performance gets better, or anything to improve in general? I am really new to this topic, so please bear with me.
My Code:
functions for parsing the record and for splitting the datasets:
import tensorflow as tf

def parse_tfrecord(example_proto):
    feature_description = {
        'ecg': tf.io.FixedLenFeature([7680], tf.float32),
        'relative_time': tf.io.FixedLenFeature([1], tf.float32),
        'relative_position': tf.io.FixedLenFeature([1], tf.float32),
        'activity_normalized': tf.io.FixedLenFeature([1], tf.float32),
        'stage': tf.io.FixedLenFeature([1], tf.int64)
    }
    example = tf.io.parse_single_example(example_proto, feature_description)
    ecg = example['ecg']
    # combine the three scalar features into one vector for the second model input
    rest = tf.concat([example['relative_time'], example['relative_position'], example['activity_normalized']], axis=0)
    stage = example['stage']
    return (ecg, rest), stage
def create_tvt_split(dataset, split):
    # split ratios are assumed to be multiples of 0.1, e.g. [0.8, 0.1, 0.1] -> blocks of 8/1/1 out of every 10 records
    train = int(split[0] * 10)
    val = int(split[1] * 10)
    test = int(split[2] * 10)
    train_ds = dataset.window(train, train + val + test).flat_map(lambda *ds: ds[0] if len(ds) == 1 else tf.data.Dataset.zip(ds))
    validation_ds = dataset.skip(train).window(val, train + val + test).flat_map(lambda *ds: ds[0] if len(ds) == 1 else tf.data.Dataset.zip(ds))
    test_ds = dataset.skip(train + val).window(test, train + val + test).flat_map(lambda *ds: ds[0] if len(ds) == 1 else tf.data.Dataset.zip(ds))
    return train_ds, validation_ds, test_ds
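To show how this splits the data: on a small toy dataset the three splits should come out interleaved in blocks like this (just a sanity check, not part of the actual pipeline):

demo = tf.data.Dataset.range(20)
tr, va, te = create_tvt_split(demo, [0.8, 0.1, 0.1])
print(list(tr.as_numpy_iterator()))  # [0, 1, 2, 3, 4, 5, 6, 7, 10, 11, 12, 13, 14, 15, 16, 17]
print(list(va.as_numpy_iterator()))  # [8, 18]
print(list(te.as_numpy_iterator()))  # [9, 19]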
retrieving the data, splitting into train, val & test, and the step calculation:
dataset = tf.data.TFRecordDataset(r"d:\sleep_data\tf_record\combined.tfrecord")
dataset = dataset.map(parse_tfrecord)
BATCH_SIZE = 400
SPLIT_RATIOS = [0.8, 0.1, 0.1]
train_ds, validation_ds, test_ds = create_tvt_split(dataset, SPLIT_RATIOS)
train_ds = train_ds.batch(BATCH_SIZE).prefetch(tf.data.AUTOTUNE).repeat()
validation_ds = validation_ds.batch(BATCH_SIZE).prefetch(tf.data.AUTOTUNE).repeat()
test_ds = test_ds.batch(BATCH_SIZE).prefetch(tf.data.AUTOTUNE).repeat()
total_records = count_records_in_pkl_files(pkl_dir)  # helper (not shown) that returns the total number of records in the original pickle files
num_train = int(total_records * SPLIT_RATIOS[0])
num_val = int(total_records * SPLIT_RATIOS[1])
num_test = int(total_records * SPLIT_RATIOS[2])
train_steps_per_epoch = num_train // BATCH_SIZE
val_steps_per_epoch = num_val // BATCH_SIZE
test_steps_per_epoch = num_test // BATCH_SIZE
print("Train steps per epoch:", train_steps_per_epoch, 'train amount:', num_train)
print("Validation steps per epoch:", val_steps_per_epoch, 'val amount:', num_val)
print("Test steps per epoch:", test_steps_per_epoch, 'test amount:', num_test)
My Model and the training:
from tensorflow.keras.layers import Input, Masking, Conv1D, MaxPooling1D, Flatten, Dense
from tensorflow.keras.models import Model
from tensorflow.keras.callbacks import EarlyStopping

def get_cnn_model():
    ecg_input = Input(shape=(7680, 1))
    combined_input = Input(shape=(3,))
    mask = Masking(mask_value=-2)(ecg_input)
    # (7680, 1)
    conv1 = Conv1D(32, kernel_size=5, activation='relu')(mask)
    # (7676, 32)
    max1 = MaxPooling1D(2)(conv1)
    # (3838, 32)
    conv2 = Conv1D(64, kernel_size=5, activation='relu')(max1)
    # (3834, 64)
    max2 = MaxPooling1D(2)(conv2)
    # (1917, 64)
    conv3 = Conv1D(128, kernel_size=5, activation='relu')(max2)
    # (1913, 128)
    max3 = MaxPooling1D(2)(conv3)
    # (956, 128)
    conv_flat = Flatten()(max3)
    x = tf.keras.layers.concatenate([conv_flat, combined_input])
    x = Dense(128, activation='relu')(x)
    x = Dense(64, activation='relu')(x)
    x = Dense(5, activation='softmax')(x)
    output = x
    model = Model(inputs=[ecg_input, combined_input], outputs=output)
    model.summary()
    model.compile(optimizer='adam', loss='sparse_categorical_crossentropy', metrics=['accuracy'])
    return model
model = get_cnn_model()
early_stopping = EarlyStopping(monitor='val_loss', patience=1, restore_best_weights=True)

with tf.device('/device:GPU:0'):
    history = model.fit(train_ds, epochs=10, validation_data=validation_ds,
                        steps_per_epoch=train_steps_per_epoch, validation_steps=val_steps_per_epoch,
                        callbacks=[early_stopping])
I hope there is something I can improve to speed things up.