Here is my code for reading `.npy` files and their labels with a generator:
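```python
import numpy as np
import tensorflow as tf

def data_generator(file_paths):
    np.random.shuffle(file_paths)
    classes = tf.constant(["class 1", "class 2", "class 3"])
    for fn in file_paths:
        # Load one sample and give it an explicit channel axis
        f = np.reshape(np.load(fn), (150, 150, 1))
        # The label is the index of the class name found in the file path
        for j in range(len(classes)):
            if classes[j].numpy() in fn:
                yield f, np.array(j)
```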
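One subtlety: according to the `tf.data.Dataset.from_generator` documentation, `args` are evaluated as tensors and passed to the generator as NumPy-array arguments, so each `fn` arrives as `bytes`, not `str`. The comparison above works because `classes[j].numpy()` is also `bytes`. A plain-Python sketch of the same lookup that avoids building a `tf.constant` inside the generator might look like this; `CLASS_NAMES` and `label_from_path` are hypothetical names, and unlike the loop above it returns only the first match (or `None`) instead of yielding once per match.

```python
CLASS_NAMES = ["class 1", "class 2", "class 3"]  # hypothetical, mirrors `classes` above

def label_from_path(path):
    """Return the index of the first class name contained in `path`, or None."""
    # from_generator hands string args to the generator as bytes, so decode first
    if isinstance(path, bytes):
        path = path.decode("utf-8")
    for index, name in enumerate(CLASS_NAMES):
        if name in path:
            return index
    # No class name found: the caller can decide to skip or raise
    return None

print(label_from_path(b"data/class 2/img_001.npy"))  # 1
```

Returning `None` makes unmatched paths explicit; in the original loop such files are silently skipped because nothing is yielded for them.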
And this is how I created my dataset:
```python
train_ds = tf.data.Dataset.from_generator(
    data_generator, args=(train_files),
    output_signature=(
        tf.TensorSpec(shape=(150, 150, 1), dtype=tf.float32),
        tf.TensorSpec(shape=(), dtype=tf.int64)
    )
)
```
However, any time I try to plot or view the generated dataset using `next(iter(train_ds.take(5)))`, I get the following error:
```
InvalidArgumentError: {{function_node __wrapped__IteratorGetNext_output_types_2_device_/job:localhost/replica:0/task:0/device:CPU:0}} TypeError: data_generator() takes 1 positional argument but 30000 were given
Traceback (most recent call last):

  File "/usr/local/lib/python3.9/dist-packages/tensorflow/python/data/ops/dataset_ops.py", line 855, in get_iterator
    return self._iterators[iterator_id]

KeyError: 1


During handling of the above exception, another exception occurred:


Traceback (most recent call last):

  File "/usr/local/lib/python3.9/dist-packages/tensorflow/python/ops/script_ops.py", line 271, in __call__
    ret = func(*args)

  File "/usr/local/lib/python3.9/dist-packages/tensorflow/python/autograph/impl/api.py", line 642, in wrapper
    return func(*args, **kwargs)

  File "/usr/local/lib/python3.9/dist-packages/tensorflow/python/data/ops/dataset_ops.py", line 1039, in generator_py_func
    values = next(generator_state.get_iterator(iterator_id))

  File "/usr/local/lib/python3.9/dist-packages/tensorflow/python/data/ops/dataset_ops.py", line 857, in get_iterator
    iterator = iter(self._generator(*self._args.pop(iterator_id)))

TypeError: data_generator() takes 1 positional argument but 30000 were given


	 [[{{node PyFunc}}]] [Op:IteratorGetNext]
```
`train_files` is a list of paths. Any help would be greatly appreciated.
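The last traceback frame calls `self._generator(*self._args.pop(iterator_id))`, so `args` is star-unpacked into positional arguments. `(train_files)` is just `train_files` in parentheses, not a one-element tuple, so the 30,000 paths become 30,000 arguments to `data_generator`. Passing `args=(train_files,)`, with a trailing comma, is the likely fix. The pure-Python sketch below reproduces the effect without TensorFlow; `call_generator` is a hypothetical stand-in for that frame, not TensorFlow's actual code.

```python
def call_generator(generator, args):
    """Hypothetical stand-in for how from_generator invokes `generator` with `args`."""
    return generator(*args)

def data_generator(file_paths):
    # Yields each path once; stands in for the real .npy-loading generator
    for fn in file_paths:
        yield fn

train_files = [f"file_{i}.npy" for i in range(3)]

print(type((train_files)).__name__)   # list: parentheses alone do not make a tuple
print(type((train_files,)).__name__)  # tuple: the trailing comma does

try:
    # Each list element becomes its own positional argument
    call_generator(data_generator, (train_files))
except TypeError as err:
    print(err)  # data_generator() takes 1 positional argument but 3 were given

# With a one-element tuple, the whole list is passed as the single argument
print(list(call_generator(data_generator, (train_files,))))
```

The same trailing-comma rule applies to any single-argument `args` tuple passed to `from_generator`.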