I have a huge dataset (1 TB) made up of thousands of small HDF5 files, each containing two 3D NumPy arrays (float64 only). These are currently fetched by a generator that is passed to tf.data.Dataset.from_generator. Since I can't cache my data, fetching is quite slow. Now I want to use all my CPUs and fetch from the dataset in parallel. Here is my code:
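```python
import h5py
import tensorflow as tf

def generator(files):
    # `files` arrives as a NumPy array of (bytes) paths when passed through `args`
    for file in files:
        with h5py.File(file, 'r') as hf:
            epsilon = hf['epsilon'][()]
            field = hf['field'][()]
        yield epsilon, field

# s[0] and s[1] hold the shapes of the epsilon and field arrays
dataset = tf.data.Dataset.from_generator(generator, args=[files], output_signature=(
    tf.TensorSpec(shape=s[0], dtype=tf.float64), tf.TensorSpec(shape=s[1], dtype=tf.float64)))
```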
Is there a best practice or known solution for this problem?
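A minimal sketch of one commonly suggested approach, assuming every file holds exactly one epsilon/field pair: give each file its own small generator and let Dataset.interleave run several of them at once. The names read_one_file, make_dataset, eps_shape and field_shape and the cycle_length of 8 are illustrative choices, not from the original code. Because each generator still executes Python code (and h5py guards calls into the HDF5 library with a global lock), this mainly overlaps waiting on the disk; it may not scale with the number of cores.

```python
import h5py
import numpy as np
import tensorflow as tf

def read_one_file(path):
    """Yield the (epsilon, field) pair stored in a single HDF5 file."""
    # Tensors passed via `args` arrive as NumPy values; a string path arrives as bytes.
    if isinstance(path, bytes):
        path = path.decode("utf-8")
    with h5py.File(path, "r") as hf:
        epsilon = hf["epsilon"][()]
        field = hf["field"][()]
    yield epsilon, field

def make_dataset(files, eps_shape, field_shape, batch_size=32, cycle_length=8):
    """Hypothetical helper: one generator per file, several files read concurrently."""
    signature = (
        tf.TensorSpec(shape=eps_shape, dtype=tf.float64),
        tf.TensorSpec(shape=field_shape, dtype=tf.float64),
    )
    ds = tf.data.Dataset.from_tensor_slices(files)
    # interleave keeps up to `cycle_length` per-file datasets open and pulls from them in parallel
    ds = ds.interleave(
        lambda f: tf.data.Dataset.from_generator(
            read_one_file, args=(f,), output_signature=signature),
        cycle_length=cycle_length,
        num_parallel_calls=tf.data.AUTOTUNE,
        deterministic=False,  # allow out-of-order elements for throughput
    )
    return ds.batch(batch_size).prefetch(tf.data.AUTOTUNE)
```

Usage would be something like dataset = make_dataset(files, s[0], s[1]), reusing the shapes from the question. deterministic=False trades a reproducible element order for throughput; drop it if the order matters.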
Hi @munsteraner, you can consider using tf.data.Dataset.prefetch; it allows later elements to be prepared while the current element is being processed. Thank you.
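For reference, prefetch usually goes at the very end of the pipeline. It overlaps producing the next batch with consuming the current one, but it does not make an individual file read faster or spread the map step over more cores; that is what num_parallel_calls on map (or interleave) is for. A toy sketch, where a range dataset and the hypothetical build_pipeline helper stand in for the real file loader:

```python
import tensorflow as tf

def build_pipeline(n_items, batch_size):
    ds = tf.data.Dataset.range(n_items)
    # num_parallel_calls parallelizes the per-element work (here a trivial doubling)
    ds = ds.map(lambda x: x * 2, num_parallel_calls=tf.data.AUTOTUNE)
    ds = ds.batch(batch_size)
    # prefetch last: while the consumer processes batch i, tf.data prepares batch i + 1
    return ds.prefetch(tf.data.AUTOTUNE)

for batch in build_pipeline(5, 2):
    print(batch.numpy())
```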
Yeah, I am already using prefetch, but it doesn't help either. I suspect my disk read speed is limiting the extraction. I also read that tf.py_function only uses one core, and I don't know how to parallelize it. Here is my code (it fetches the data, but there is no speed-up):
```python
import glob

import h5py
import numpy as np
import tensorflow as tf

def load_hdf5_file(filename):
    # `filename` is a scalar string tensor; inside tf.py_function it can be read eagerly
    filename = filename.numpy().decode("utf-8")
    with h5py.File(filename, 'r') as hf:
        # Cast to float32 so the values match Tout below (the files store float64)
        epsilon = hf['epsilon'][()].astype(np.float32)
        field = hf['field'][()].astype(np.float32)
    return epsilon, field

filenames = glob.glob(f'{args.p_data}/*.h5')
# Create a tf.data.Dataset using the filenames
dataset = tf.data.Dataset.from_tensor_slices(filenames)
# Map the load_hdf5_file function to load and extract arrays for each HDF5 file
num_parallel_calls = tf.data.AUTOTUNE
dataset = dataset.map(lambda x: tf.py_function(load_hdf5_file, [x], Tout=(tf.float32, tf.float32)),
                      num_parallel_calls=num_parallel_calls)
spe = len(filenames) // 32  # steps per epoch
# No .cache(): an in-memory cache of 1 TB would not fit in RAM.
# repeat() without a count lets steps_per_epoch decide where each epoch ends.
dataset = dataset.batch(32).repeat().prefetch(num_parallel_calls)
m = var_unet_3D_test.build_unet((120, 100, 50, 1))
m.compile(run_eagerly=True)
m.fit(dataset, epochs=args.bs, steps_per_epoch=spe)
```
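A common alternative, not something proposed in this thread: convert the HDF5 files once into a few large TFRecord shards. Reading and parsing then run inside TensorFlow's C++ runtime, so num_parallel_reads and num_parallel_calls can use several cores without going through the Python GIL, and a few large files are usually friendlier to disks than thousands of small ones. This is a sketch under assumptions: each HDF5 file holds one epsilon/field pair, and the helper names hdf5_to_tfrecord and make_tfrecord_dataset, the feature keys, and the shard layout are hypothetical.

```python
import h5py
import numpy as np
import tensorflow as tf

def hdf5_to_tfrecord(h5_files, out_path):
    """One-off conversion: pack every (epsilon, field) pair into one TFRecord shard."""
    with tf.io.TFRecordWriter(out_path) as writer:
        for path in h5_files:
            with h5py.File(path, "r") as hf:
                eps, field = hf["epsilon"][()], hf["field"][()]
            # Each array is stored as a serialized tensor (dtype and shape kept)
            feature = {
                name: tf.train.Feature(bytes_list=tf.train.BytesList(
                    value=[tf.io.serialize_tensor(arr).numpy()]))
                for name, arr in (("epsilon", eps), ("field", field))
            }
            example = tf.train.Example(features=tf.train.Features(feature=feature))
            writer.write(example.SerializeToString())

_FEATURES = {
    "epsilon": tf.io.FixedLenFeature([], tf.string),
    "field": tf.io.FixedLenFeature([], tf.string),
}

def make_tfrecord_dataset(shards, eps_shape, field_shape, batch_size=32):
    def parse(record):
        ex = tf.io.parse_single_example(record, _FEATURES)
        eps = tf.io.parse_tensor(ex["epsilon"], out_type=tf.float64)
        field = tf.io.parse_tensor(ex["field"], out_type=tf.float64)
        # parse_tensor loses the static shape; restore it so Keras sees fixed shapes
        return tf.ensure_shape(eps, eps_shape), tf.ensure_shape(field, field_shape)

    # Reads several shards concurrently, then parses records on several threads
    ds = tf.data.TFRecordDataset(shards, num_parallel_reads=tf.data.AUTOTUNE)
    ds = ds.map(parse, num_parallel_calls=tf.data.AUTOTUNE)
    return ds.batch(batch_size).prefetch(tf.data.AUTOTUNE)
```

In practice the conversion would be run over chunks of the file list to produce many shards (for example a few hundred MB each), which also gives num_parallel_reads something to parallelize. Casting the arrays to float32 before serializing would halve the storage, if the model does not need float64.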