Unknown/reduced dataset length after resampling

Hi!

I’m working with an imbalanced dataset and am attempting to resample it using tf.data.Dataset.sample_from_datasets, but I’m running into a strange issue. My initial dataframe contains ~319,000 samples, and I’m sampling the two classes with 50% weight each; the minority class makes up ~9% of the data. I was initially getting TypeError: dataset length is unknown, but after asserting the cardinality to my expected value (54748), my resampled dataset appears to contain only 42 samples (split into train, val, and test). Below are my code and outputs. Any insight into why my resampled dataset is so small would be much appreciated.

Thanks in advance!

import numpy as np
import tensorflow as tf

def df_to_dataset(df, batch_size=32):
    # Split the frame by class; .copy() avoids pandas' SettingWithCopyWarning
    # when pop() mutates the slices below.
    pos_df = df[df['target'] == 1].copy()
    neg_df = df[df['target'] == 0].copy()
    pos_labels = pos_df.pop('target')
    pos_features = pos_df
    neg_labels = neg_df.pop('target')
    neg_features = neg_df
    print('before resampling: ')
    print(len(pos_features))
    print(len(neg_features))
    pos_ds = tf.data.Dataset.from_tensor_slices((dict(pos_features), pos_labels))
    neg_ds = tf.data.Dataset.from_tensor_slices((dict(neg_features), neg_labels))
    # resample to a 50/50 class mix
    resampled_ds = tf.data.Dataset.sample_from_datasets([pos_ds, neg_ds], weights=[0.5, 0.5])
    resampled_ds = resampled_ds.apply(tf.data.experimental.assert_cardinality(54748))
    resampled_ds = resampled_ds.shuffle(buffer_size=len(df))
    resampled_ds = resampled_ds.batch(batch_size).prefetch(2)
    return resampled_ds

# shuffle the full dataframe, then split 80/10/10 into train/val/test
train, val, test = np.split(dataframe.sample(frac=1), [int(0.8*len(dataframe)), int(0.9*len(dataframe))])
train_ds = df_to_dataset(df=train, batch_size=4096)
val_ds = df_to_dataset(df=val, batch_size=4096)
test_ds = df_to_dataset(df=test, batch_size=4096)
size = (len(train_ds) + len(val_ds) + len(test_ds))
print('dataset sizes:')
print(len(train_ds))
print(len(val_ds))
print(len(test_ds))

output:

before resampling: 
21913
233923
before resampling: 
2680
29299
before resampling: 
2780
29200
dataset sizes:
14
14
14

Hi @Rawan_Mahdi,

Sorry for the delayed response.

The root cause here is that sample_from_datasets returns a dataset with unknown cardinality, since TensorFlow cannot know in advance how many elements the sampling will yield; that is why len() initially raised the TypeError. assert_cardinality(54748) does not change the data at all, it only declares a length, so after batching with batch_size=4096 each split reports ceil(54748 / 4096) = 14, i.e. 14 batches rather than 14 samples; those three 14s are where your "42" comes from. Worse, if the asserted value does not match the true element count, fully iterating the dataset will raise an error at runtime.

To resolve this, I suggest using take(54748) to truncate the resampled stream to exactly the desired number of samples instead of assert_cardinality(54748). Reducing the shuffle buffer size (to 10000 instead of len(df)) also improves memory efficiency and helps prevent issues with large datasets. And since the resampled dataset is small, try a smaller batch size, like 128, so each split yields more than a handful of steps per epoch.
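
Here is a minimal sketch of the revised function, assuming the same column names as in your post. Note that the n_samples parameter and the repeat() calls are my additions: without repeat(), the minority class is exhausted early and the 50/50 mix would not hold for the full 54748-element draw.

import tensorflow as tf

def df_to_dataset(df, batch_size=128, n_samples=54748):
    pos_df = df[df['target'] == 1].copy()
    neg_df = df[df['target'] == 0].copy()
    pos_labels = pos_df.pop('target')
    neg_labels = neg_df.pop('target')
    pos_ds = tf.data.Dataset.from_tensor_slices((dict(pos_df), pos_labels))
    neg_ds = tf.data.Dataset.from_tensor_slices((dict(neg_df), neg_labels))
    # repeat() makes both inputs infinite, so the 50/50 mix holds for the
    # whole draw instead of degrading once the minority class runs out.
    resampled_ds = tf.data.Dataset.sample_from_datasets(
        [pos_ds.repeat(), neg_ds.repeat()], weights=[0.5, 0.5])
    # take() truncates the stream to exactly n_samples elements, unlike
    # assert_cardinality(), which only declares a length.
    resampled_ds = resampled_ds.take(n_samples)
    # A fixed 10000-element buffer keeps shuffle memory bounded.
    resampled_ds = resampled_ds.shuffle(buffer_size=10000)
    return resampled_ds.batch(batch_size).prefetch(2)

If len() still reports an unknown length after take() (sample_from_datasets does not always propagate cardinality), you can keep the apply(tf.data.experimental.assert_cardinality(n_samples)) step after take(); it is now safe because the stream really does contain n_samples elements. Also, since repeat() makes every split inexhaustible, you will want to pass a smaller n_samples for val and test than for train.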

Hope this helps. Thank you.