TensorFlow Memory Growth Issue in Training Loop
Summary
I encountered a progressive memory growth issue when repeatedly creating, training, and deleting a tf.keras.Model inside a loop. Despite explicitly clearing the session, deleting the model, and forcing garbage collection, memory usage keeps increasing over time.
This behavior is consistent across:
- Operating Systems: Linux, Windows 11
- Python Versions: 3.11.15, 3.12.15
- TensorFlow Variants: `tensorflow`, `tensorflow-cpu`
Minimal Reproducible Example
Code
```python
import gc
import os
import time

import psutil
import tensorflow as tf

p = psutil.Process(os.getpid())


class MyModel(tf.keras.Model):
    def __init__(self):
        super().__init__()
        self.dense1 = tf.keras.layers.Dense(100, activation=tf.nn.relu)
        self.dense2 = tf.keras.layers.Dense(100, activation=tf.nn.softmax)
        self.dense3 = tf.keras.layers.Dense(100, activation=tf.nn.softmax)
        self.dense4 = tf.keras.layers.Dense(100, activation=tf.nn.softmax)

    def call(self, inputs):
        x = self.dense1(inputs)
        x = self.dense2(x)
        x = self.dense3(x)
        x = self.dense4(x)
        return x


mem = []
for r in range(200):
    mem.append(round(p.memory_info().rss / 1024**2, 3))
    model = MyModel()
    ds = tf.data.Dataset.from_tensor_slices(
        (tf.random.uniform((64 * 4, 1000)), tf.ones((64 * 4)))
    )
    model.compile(
        optimizer='sgd',
        loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)
    )
    model.fit(ds.batch(64), verbose=0)
    del model
    tf.keras.backend.clear_session()
    gc.collect()
    time.sleep(3)
```
Observed Behavior
- Memory usage (RSS) steadily increases with each loop iteration.
- This occurs despite:
  - `del model`
  - `tf.keras.backend.clear_session()`
  - `gc.collect()`
  - no persistent references to the model or dataset
Plotting the `mem` list confirms a clear linear or step-wise increase in memory consumption over the 200 iterations.
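The growth rate can also be quantified directly from the `mem` list; a minimal sketch (the warm-up cutoff of 10 iterations is an arbitrary choice, since the first few iterations legitimately allocate caches and kernels):

```python
def per_iteration_growth_mb(mem, warmup=10):
    # Average RSS growth per loop iteration in MiB, skipping the
    # warm-up iterations where one-time allocations are expected.
    steady = mem[warmup:]
    return (steady[-1] - steady[0]) / (len(steady) - 1)
```

On a healthy run this should approach zero after warm-up; here it stays positive.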
Expected Behavior
Memory should:
- Stabilize after a few iterations, or
- Be reclaimed after session clearing and garbage collection
Additional Notes
- I found this issue during hyperparameter optimization, which trains many models in the same Python process.
- The dataset is recreated every loop but is small and should not cause accumulation.
- No custom training loop is used, only `model.fit()`.
- The issue appears independent of:
- Hardware
- OS
- Python version
- TensorFlow CPU/GPU variant
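One diagnostic I can run to narrow this down: `tracemalloc` tracks Python-level allocations only, so if RSS keeps growing while the `tracemalloc` deltas stay flat, the growth is in native (C++) memory such as TensorFlow's allocator or graph caches rather than in Python objects. A sketch, where `workload` stands in for a few iterations of the loop above:

```python
import tracemalloc

def top_python_allocations(workload, limit=10):
    # Run `workload` and return the Python allocation sites that grew
    # the most between two snapshots. Native (C/C++) allocations made
    # by TensorFlow internals will NOT appear here, which is exactly
    # what makes the comparison with RSS informative.
    tracemalloc.start()
    before = tracemalloc.take_snapshot()
    workload()
    after = tracemalloc.take_snapshot()
    stats = after.compare_to(before, "lineno")
    tracemalloc.stop()
    return stats[:limit]
```

Happy to post the output of this against the reproduction if it helps.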
Questions
- Is this expected behavior due to internal TensorFlow caching or graph tracing?
- Could this be related to:
  - `tf.function` retracing?
  - Dataset pipeline caching?
  - Backend allocator behavior?
Attempted Mitigations
- Forcing garbage collection (`gc.collect()`) → no improvement
- Clearing the session (`tf.keras.backend.clear_session()`) → no improvement
- Deleting the model (`del model`) → no improvement
Feedback
Any insights or suggestions would be greatly appreciated.
Happy to provide additional diagnostics if needed!