Hi,
I'm using tensorflow==2.18.0 alongside tf_keras==2.18.0. I want to train a model on a huge dataset on an A100 80GB. With a batch size of 4096, training works and uses only about 20% of the GPU memory. When I double the batch size, however, some tensors have more than 2^31 - 1 elements (the int32 maximum), and I get an illegal memory access error because some variables, for example work_element_count, are declared as int32. Has anyone faced this issue and found a fix? Do I need to rebuild TensorFlow from source with specific compiler options?
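To show why doubling the batch size can cross the limit even with lots of free memory, here is a small sketch of the arithmetic. The per-sample shape (256, 1024) is hypothetical, picked only so that 4096 fits in int32 and 8192 does not; it is not the shape from my model.

```python
import math

INT32_MAX = 2**31 - 1  # largest value a signed 32-bit integer can hold

def element_count(batch_size, per_sample_shape):
    """Total number of elements in a tensor of shape (batch_size, *per_sample_shape)."""
    return batch_size * math.prod(per_sample_shape)

def exceeds_int32(batch_size, per_sample_shape):
    """True if the element count no longer fits in a signed 32-bit counter."""
    return element_count(batch_size, per_sample_shape) > INT32_MAX

# Hypothetical per-sample activation shape, for illustration only.
shape = (256, 1024)
for bs in (4096, 8192):
    print(bs, element_count(bs, shape), exceeds_int32(bs, shape))
```

With this shape, 4096 gives 2^30 elements (fits), while 8192 gives exactly 2^31, one more than INT32_MAX. A float32 tensor of that size is only 8 GiB, which is why the overflow can happen long before the 80 GB are used.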