Hey! I have a tensorflow.python.data.ops.dataset_ops.MapDataset object that I’m loading from TFRecord files.
It has three “columns”: encodings, label, batch_id.
I need to batch this dataset by the batch_id column.
For example, if I have 50 unique batch_ids, then after the batch operation the dataset should contain 50 batches, and each batch should contain all the data with the corresponding batch_id.
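To pin down the intended result, here is a plain-Python sketch (no TensorFlow) of the grouping described above. The rows and the helper name group_by_batch_id are hypothetical stand-ins for the TFRecord data, assuming each row is an (encodings, label, batch_id) triple.

```python
from collections import defaultdict

# Hypothetical rows standing in for the TFRecord data: (encodings, label, batch_id).
rows = [
    ([0.1, 0.2], 0, 7),
    ([0.3, 0.4], 1, 3),
    ([0.5, 0.6], 1, 7),
    ([0.7, 0.8], 0, 3),
    ([0.9, 1.0], 1, 5),
]


def group_by_batch_id(rows):
    """Hypothetical helper: return one batch per unique batch_id,
    each batch holding every row that carries that batch_id."""
    groups = defaultdict(list)
    for encodings, label, batch_id in rows:
        groups[batch_id].append((encodings, label, batch_id))
    return dict(groups)


batches = group_by_batch_id(rows)
print(len(batches))  # 3 unique batch_ids, so 3 batches
```

With 50 unique batch_ids the same function would return 50 batches, which is the behavior the dataset version should reproduce.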
I’m trying to use the group_by_window function:
key_func = lambda elem: tf.cast(elem['batch_id'], tf.float32)
reduce_func = lambda key, window: window.batch(100000)
ds = ds.group_by_window(key_func=key_func, reduce_func=reduce_func, window_size=10000)
but this raises an error: TypeError: <lambda>() takes 1 positional argument but 3 were given
How can I achieve this? Thanks!
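A minimal sketch of one likely fix, assuming each dataset element is a 3-tuple (encodings, label, batch_id) rather than a dict. The TypeError suggests that tf.data is unpacking each tuple element into three positional arguments, which a one-argument lambda cannot accept. The sketch also casts the key to tf.int64, because group_by_window expects key_func to return a scalar tf.int64 tensor. The small in-memory dataset built with from_tensor_slices is a hypothetical stand-in for the TFRecord data.

```python
import tensorflow as tf

# Hypothetical stand-in for the TFRecord-backed dataset: each element is a
# 3-tuple (encodings, label, batch_id), matching the three "columns".
encodings = tf.random.uniform((10, 4))
labels = tf.range(10)
batch_ids = tf.constant([0, 1, 0, 2, 1, 0, 2, 2, 1, 0], dtype=tf.int64)
ds = tf.data.Dataset.from_tensor_slices((encodings, labels, batch_ids))


# For tuple elements, tf.data passes each component as a separate positional
# argument, so key_func takes three arguments. The key must be a scalar
# tf.int64 tensor, so batch_id is cast to int64 (not float32).
def key_func(encodings, label, batch_id):
    return tf.cast(batch_id, tf.int64)


# reduce_func receives the key and a Dataset holding one window of elements;
# batching with a size >= window_size turns each window into a single batch.
def reduce_func(key, window):
    return window.batch(10000)


grouped = ds.group_by_window(
    key_func=key_func, reduce_func=reduce_func, window_size=10000
)

# Collect the batch_id column of every emitted batch.
batches = [bid.numpy().tolist() for _, _, bid in grouped]
print(batches)  # one list per unique batch_id; order is not guaranteed
```

Each window holds at most window_size elements, so a batch_id with more than window_size rows would be split across several batches; window_size needs to be at least as large as the biggest group to get exactly one batch per batch_id. If the parsing function returned a dict instead of a tuple, key_func would receive a single argument and elem['batch_id'] would work as written (the int64 cast is still required).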