Recently, while trying to use `MultiWorkerMirroredStrategy` with Keras, we ran into the following:
- When Keras wraps the dataset we pass in with `experimental_distribute_dataset`, we cannot scale beyond x nodes. This path requires a global batch size, and the global batch size has to account for the number of workers (global batch size = per-replica batch size × number of workers × number of replicas per worker). As a result, once we have many workers, our jobs start failing with OOM, which we do not see with `MirroredStrategy`.
- We tried to work around this by using `distribute_datasets_from_function`, which gives us full control over the per-replica batch size and the sharding logic (and avoids the OOM issue). However, the job then failed at:
tensorflow/tensorflow/python/keras/engine/data_adapter.py, line 733 in 1923123: `if _is_distributed_dataset(self._dataset):`
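For reference, a minimal sketch of the kind of `dataset_fn` this workaround relies on, assuming a recent TF 2.x release (where the strategy method is `distribute_datasets_from_function`; older releases name it `experimental_distribute_datasets_from_function`). `PER_REPLICA_BATCH_SIZE`, the `tf.data.Dataset.range(1000)` source and the helper `effective_global_batch_size` are placeholders for illustration, not part of our actual pipeline. The point is that the batch size stays fixed per replica, and each worker reads its own shard.

```python
import tensorflow as tf

# Assumed per-replica batch size (placeholder); it stays fixed no matter
# how many workers join the job.
PER_REPLICA_BATCH_SIZE = 8


def dataset_fn(input_context: tf.distribute.InputContext) -> tf.data.Dataset:
    # Placeholder source; substitute the real input pipeline here.
    ds = tf.data.Dataset.range(1000)
    # Give each input pipeline (one per worker) a disjoint shard.
    ds = ds.shard(input_context.num_input_pipelines,
                  input_context.input_pipeline_id)
    # Batch with the per-replica size, not the global one.
    return ds.batch(PER_REPLICA_BATCH_SIZE)


def effective_global_batch_size(input_context: tf.distribute.InputContext) -> int:
    # Global batch implied by a fixed per-replica batch:
    # it grows with the total number of replicas in sync.
    return PER_REPLICA_BATCH_SIZE * input_context.num_replicas_in_sync


if __name__ == "__main__":
    # Simulate worker 1 of 2 in a job with 8 replicas in total.
    ctx = tf.distribute.InputContext(
        num_input_pipelines=2, input_pipeline_id=1, num_replicas_in_sync=8)
    print(next(iter(dataset_fn(ctx))).numpy().tolist())
    print(effective_global_batch_size(ctx))
```

In a real job the function would be handed to the strategy, e.g. `strategy.distribute_datasets_from_function(dataset_fn)`, and the resulting distributed dataset passed to Keras; that last step is where our job hit the line 733 check.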
When we passed in a regular (non-distributed) dataset, it had UNKNOWN cardinality, and Keras relied on `def should_recreate_iterator(self):` (tensorflow/tensorflow/python/keras/engine/data_adapter.py, line 710 in 1923123) to recreate the iterator for every epoch. Our use case is to have the validation step exhaust our dataset instead of hard-coding the number of steps. Could we relax the check at L733 altogether, together with a change to L714? That would let us support users who do not pass `steps`. If you agree, I can submit a PR with the change.
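To make the iterator-recreation question concrete, here is a hypothetical pure-Python model (not Keras code). `run_epochs` and its arguments are illustrative names, and exhaustion of a plain Python iterator stands in for the end-of-sequence signal of a `tf.data` iterator. It shows why, when no steps are supplied, exhausting the dataset every epoch only works if a fresh iterator is created each epoch.

```python
import itertools
from typing import Callable, Iterable, Iterator, List, Optional


def run_epochs(make_dataset: Callable[[], Iterable[int]], epochs: int,
               steps: Optional[int] = None, recreate: bool = True) -> List[int]:
    """Return how many batches were consumed in each epoch (hypothetical model)."""
    iterator: Iterator[int] = iter(make_dataset())
    consumed: List[int] = []
    for epoch in range(epochs):
        if recreate and epoch > 0:
            # Fresh iterator per epoch, as Keras does for non-distributed
            # datasets with unknown cardinality.
            iterator = iter(make_dataset())
        if steps is None:
            # No user-supplied steps: run until the iterator is exhausted.
            count = sum(1 for _ in iterator)
        else:
            # Fixed steps: draw at most `steps` batches from the iterator.
            count = sum(1 for _ in itertools.islice(iterator, steps))
        consumed.append(count)
    return consumed


if __name__ == "__main__":
    # Five "batches" whose count the consumer does not know in advance.
    print(run_epochs(lambda: range(5), epochs=3))                  # [5, 5, 5]
    print(run_epochs(lambda: range(5), epochs=3, recreate=False))  # [5, 0, 0]
```

The second call shows what happens when the same exhausted iterator is reused: only the first epoch sees any data.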
Please let me know if there are any downsides to doing so.
Thanks