Hello everybody,
I am trying to read data from a TensorFlow dataset created with .repeat(). I want to read all the values, but because of repeat() the iteration never stops. That makes sense, since the dataset repeats its slices, but I want to take all the values before the first repetition. For example:
val_data = [[1, 2], [1, 3], [1, 2], [1, 3], [1, 2], [1, 3], …, [1, 3]]. I would like to read just [1, 2], [1, 3]. Is there any way to do this without removing the .repeat() operation? Or must I delete val_data, create it again without repeat(), and use take(-1)?
val_data = tf.data.Dataset.from_tensor_slices((x_val, y_val))
val_data = val_data.batch(batch_size).repeat()
Hi @Cavour_Martinelli, You can get the data using
import tensorflow as tf
x = [[1, 2], [1, 3], [1, 2], [1, 3], [1, 2], [1, 3]]
val_data = tf.data.Dataset.from_tensor_slices(x)
val_data = val_data.batch(2).repeat()
val_data = val_data.take(1)
You can iterate through val_data to get [1, 2], [1, 3]:
for i in val_data:
    print(i)
# output:
# tf.Tensor(
# [[1 2]
#  [1 3]], shape=(2, 2), dtype=int32)
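In the snippet above, take(1) is enough only because the first batch of 2 already holds the two distinct rows. More generally, the argument to take() would be the number of batches in one pass over the data. A minimal sketch of that arithmetic, assuming the sample count is known before repeat() is applied (the helper name batches_per_pass is made up for illustration; drop_remainder mirrors the batch() argument of the same name):

```python
import math


def batches_per_pass(num_samples, batch_size, drop_remainder=False):
    """Number of batches that cover the data once, before repeat() restarts.

    Hypothetical helper, not part of TensorFlow.
    """
    if drop_remainder:
        # batch(..., drop_remainder=True) discards the short final batch.
        return num_samples // batch_size
    # Otherwise the last, possibly shorter, batch still counts.
    return math.ceil(num_samples / batch_size)


# The six rows above, batched by 2, make 3 batches per pass.
six_rows = batches_per_pass(6, 2)

# 150 samples with batch size 32: four full batches plus one short batch.
with_short_batch = batches_per_pass(150, 32)
full_batches_only = batches_per_pass(150, 32, drop_remainder=True)

print(six_rows, with_short_batch, full_batches_only)
```

The result could then be passed to the repeated dataset, for example val_data.take(batches_per_pass(len(x), 2)).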
Thank You.
Hello @Kiran_Sai_Ramineni, I think there was a misunderstanding on my side. I don’t know which number to put inside take() because val_data repeats the data indefinitely. The scenario above is just an example. Suppose you split your dataset into training and validation sets (85% and 15%) but you don’t know which batch is the last one. How do you get that 15% of the data before it starts repeating? That’s my question.
Hi @Cavour_Martinelli, Let’s say we have 1000 elements:
import numpy as np
import tensorflow as tf
data = np.arange(1000)
Divide them into training and validation sets (85% and 15% respectively):
total_samples = len(data)
train_size = int(total_samples * 0.85)
valid_size = int(total_samples * 0.15)
train_set = data[:train_size]
valid_set = data[train_size:train_size + valid_size]
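One detail about the split above: with 1000 elements the two int(...) sizes add up exactly, but for other sizes the truncation can leave a sample in neither set. A small sketch of that arithmetic, using a hypothetical size of 1001 and a plain Python list instead of a NumPy array:

```python
total_samples = 1001  # hypothetical size that does not split evenly
data = list(range(total_samples))

train_size = int(total_samples * 0.85)  # 850
valid_size = int(total_samples * 0.15)  # 150

# Slicing with an explicit end, as in the split above, drops the leftover sample.
valid_truncated = data[train_size:train_size + valid_size]

# Slicing to the end keeps everything that is not in the training set.
train_set = data[:train_size]
valid_set = data[train_size:]

unassigned = total_samples - (train_size + valid_size)
print(len(train_set), len(valid_truncated), len(valid_set), unassigned)
```

Slicing with data[train_size:] is one way to make sure the validation set always contains every sample that is left over.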
Create a dataset using:
data = tf.data.Dataset.from_tensor_slices(train_set)
Batch the dataset using:
batch_data = data.batch(5)
To get 15% of the batches from the batched dataset:
num_elements = int(0.15 * len(batch_data))
for i in batch_data.take(num_elements):
    print(i)
If you want to get 15% of the elements from the unbatched dataset:
num_elements = int(0.15 * len(data))
for i in data.take(num_elements):
    print(i)
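For the original question (reading the validation data exactly once even though repeat() is used), one possible approach is to keep a reference to the pipeline before repeat() is applied and ask it for its length with cardinality(). The sketch below assumes TensorFlow 2.3 or later; the arrays are small made-up stand-ins for x_val and y_val, and the name val_once is hypothetical:

```python
import numpy as np
import tensorflow as tf

# Made-up stand-ins for the x_val / y_val arrays from the question.
x_val = np.arange(20).reshape(10, 2)
y_val = np.arange(10)
batch_size = 4

# Keep a handle on the finite pipeline before repeat() is applied.
val_once = tf.data.Dataset.from_tensor_slices((x_val, y_val)).batch(batch_size)
val_data = val_once.repeat()  # infinite version, as in the question

# The finite pipeline knows its own length: ceil(10 / 4) == 3 batches.
num_batches = int(val_once.cardinality().numpy())

# Option 1: iterate the finite pipeline directly; it stops by itself.
rows_once = sum(int(xb.shape[0]) for xb, _ in val_once)

# Option 2: cut exactly one pass out of the repeated pipeline.
rows_taken = sum(int(xb.shape[0]) for xb, _ in val_data.take(num_batches))

print(num_batches, rows_once, rows_taken)
```

Both options should visit each validation row once. After repeat(), cardinality() reports tf.data.INFINITE_CARDINALITY, so the count has to come from the finite pipeline or from the known sample count.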
Thank You.