I am designing a generator that inherits from keras.utils.Sequence, and I find it easier to implement by creating a list and then popping samples from it. For reasons I could not understand, if I try to train a model until the generator is exhausted (that is, without setting steps_per_epoch, or with it set equal to len(gen)), I get an error: in some implementations it simply tells me the generator ran out of data, but in others it explicitly warns that Python cannot pop from an empty list. My implementation makes sure that the list is exhausted by the end of the epoch, but it is replenished by the on_epoch_end method. If I simply get samples from the generator in a for loop, I can check that all batches are produced as expected.
If, instead, I set steps_per_epoch to len(gen)-1, then training happens, even though in some implementations I still get the error message. It seems like Keras is trying to sample for longer than needed.
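One way I could check that suspicion is to record every index that gets passed to __getitem__. Below is a minimal sketch of that idea. LoggedGetitem, Toy and LoggedToy are names made up for illustration, and Toy is only a stand-in so the snippet runs without TensorFlow. I am assuming the default single-process settings, so that the log lives in the same process as the training loop.

```python
class LoggedGetitem:
    """Hypothetical mixin: records every index passed to __getitem__."""

    def __getitem__(self, idx):
        # Create the log lazily so no __init__ cooperation is required.
        self.__dict__.setdefault("requested", []).append(idx)
        return super().__getitem__(idx)


class Toy:
    """Stand-in for the real generator (assumption: not the Keras class)."""

    def __getitem__(self, idx):
        return idx * 2


class LoggedToy(LoggedGetitem, Toy):
    pass


toy = LoggedToy()
values = [toy[i] for i in range(3)]
print(toy.requested)  # indices in the order they were requested
```

With the real generator, the same mixin would presumably be combined as class LoggedGen(LoggedGetitem, DataGenerator), and gen.requested inspected after fit to see how many times, and with which indices, a batch was requested.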
What might be going wrong?
I know that, technically, I could avoid popping from the list, but for some data structures I find popping much easier to work with.
Example code (running TF 2.8):
from tensorflow.keras.utils import Sequence
from tensorflow.keras.layers import Dense
from tensorflow.keras import Sequential
from random import shuffle
import numpy as np
class DataGenerator(Sequence):
    def __init__(self, x, y, batch_size=10):
        self.x = x
        self.y = y
        self.batch_size = batch_size
        self.indices = list(range(len(self.x)))
        shuffle(self.indices)

    def __len__(self):
        return len(self.x)//self.batch_size

    def __getitem__(self, idx):
        # idx is ignored: each call consumes the next batch_size indices
        indices = []
        for _ in range(self.batch_size):
            index = self.indices.pop(0)
            indices.append(index)
        return self.x[indices], self.y[indices]

    def on_epoch_end(self):
        # Replenish and reshuffle the index list for the next epoch
        self.indices = list(range(len(self.x)))
        shuffle(self.indices)

X = np.arange(100).reshape(100, 1)
Y = np.arange(100)
gen = DataGenerator(X, Y)
# This will show that the batches are produced as expected:
for i, (x, y) in enumerate(gen):
    print(i, x.shape, y.shape)
gen = DataGenerator(X, Y)
model = Sequential([Dense(1)])
model.compile(loss='mean_squared_error')
# Throws error and training stops:
model.fit(gen, steps_per_epoch=len(gen), epochs=10)
# Throws error but training completes:
model.fit(gen, steps_per_epoch=len(gen)-1, epochs=10)
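For comparison, this is roughly what the non-popping variant I mentioned would look like. It is only a sketch: IndexedBatches is a made-up name, it works on plain lists, and it does not inherit from Sequence so that it runs without TensorFlow (I assume adding the base class back and indexing NumPy arrays would not change the logic). The point is that __getitem__ depends only on idx, so requesting the same batch twice, or requesting one batch more often than expected, cannot empty anything.

```python
from random import shuffle


class IndexedBatches:
    """Hypothetical non-popping variant; would subclass Sequence in practice."""

    def __init__(self, x, y, batch_size=10):
        self.x = x
        self.y = y
        self.batch_size = batch_size
        self.on_epoch_end()  # build the first shuffled order

    def __len__(self):
        return len(self.x) // self.batch_size

    def __getitem__(self, idx):
        if not 0 <= idx < len(self):
            raise IndexError(idx)
        # Slice the shuffled order instead of consuming it
        start = idx * self.batch_size
        batch = self.indices[start:start + self.batch_size]
        return [self.x[i] for i in batch], [self.y[i] for i in batch]

    def on_epoch_end(self):
        # Reshuffle once per epoch; nothing is ever removed from the list
        self.indices = list(range(len(self.x)))
        shuffle(self.indices)


gen = IndexedBatches(list(range(100)), list(range(100)))
first = gen[0]
again = gen[0]  # same batch as before; nothing was consumed
```

The drawback, and the reason I would still like to understand the popping version, is that not every data structure I sample from maps this neatly onto a batch index.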