How to create a `tf.data.Dataset` from a subset of the elements inside another `tf.data.Dataset`

I want to achieve the same effect as the `.choose_from_datasets` method,
but applied to the `tf.string` elements of my `list_ds`, which contains 30 `tf.string` items. How can I do this?


list_ds = tf.data.Dataset.list_files(str(PATH/'Pasta_test/*'), shuffle=False)
print(list_ds)
# Show the first two file paths produced by list_files.
for file_path in list_ds.take(2):
    print(file_path)

print('--------------------------------------------------------------------------------------------------------------------')
#-------(.choose_from_datasets) method -----  https://www.tensorflow.org/api_docs/python/tf/data/Dataset#choose_from_datasets

# Four single-element datasets; choose_from_datasets indexes into this list.
datasets = [
    tf.data.Dataset.from_tensors("Tokyo"),
    tf.data.Dataset.from_tensors("Berlin"),
    tf.data.Dataset.from_tensors("Spiderman"),
    tf.data.Dataset.from_tensors("Batman"),
]

# Define a dataset containing `[2, 3]`. Each value is an index into `datasets`,
# so `result` yields "Spiderman" (datasets[2]) followed by "Batman" (datasets[3]).
choice_dataset = tf.data.Dataset.range(2, 4)

result = tf.data.Dataset.choose_from_datasets(datasets, choice_dataset)
#-----------------------------------------------------------------------------------------------------------------------------
print(f'{datasets} \n')
for element in result:
    print(element)

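Apply the `.window` method to the `tf.data.Dataset` object; more information can be found in the TensorFlow API documentation for `tf.data.Dataset` (TensorFlow v2.16.1).
`.window` groups the elements of the dataset into a dataset of windows. Each window is itself a finite dataset of (at most) `size` elements, where `size` is the first argument.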

`window(size, shift=None, stride=1, drop_remainder=False, name=None)`

list_ds = tf.data.Dataset.list_files(str(PATH/'Test/*'), shuffle=False)
print(list_ds)
for file_path in list_ds:
    print(file_path)

set_size = 0.2  # fraction of the files placed in each sub-dataset
ds_length = list_ds.cardinality().numpy()
ds_set_size = int(ds_length*set_size)
if ds_set_size < 1:
    # window() needs a positive size. This happens when there are fewer than
    # 1/set_size files, or when cardinality() is negative (tf.data reports
    # unknown/infinite cardinality with negative sentinel values).
    raise ValueError(
        f"set_size={set_size} of {ds_length} elements gives a window size of "
        f"{ds_set_size}; each window needs at least 1 element"
    )
print(ds_set_size)
print(" ")

# Split list_ds into consecutive, non-overlapping sub-datasets of exactly
# ds_set_size elements. drop_remainder=True discards a trailing partial
# window, which is what counting elements with `i % ds_set_size` achieved,
# without iterating every element of every window.
training_ds = {}
windows = list_ds.window(ds_set_size, drop_remainder=True)
for key_index, window_ds in enumerate(windows):
    print(key_index)
    training_ds[f'im_ds{key_index}'] = window_ds
print('\n')

# Print every stored window instead of a hard-coded im_ds0..im_ds4. Those keys
# only all exist when the split happens to produce exactly five windows.
for window_ds in training_ds.values():
    print_it(window_ds)