According to the documentation, Keras Model.fit() runs in graph mode by default, even though eager execution is the default in TF 2.x. So I expect that training a simple Keras model (13 parameters) should be fast. But it is very slow on my computer (~30 s). However, it becomes about 10 times faster (~3 s) if I add this line to the code: tf.compat.v1.disable_eager_execution()
It seems that the Keras fit function does not run in graph mode as expected (?). Adding the "run_eagerly=False" option when compiling the model does not solve the problem. Is this a bug?
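For reference, here is a quick way to check the expectation on a throwaway model (not the model from my script below): model.run_eagerly reports False after compile(), so fit() should be tracing its train step into a tf.function. The steps_per_execution argument of compile() (available since TF 2.4) is a related knob that fuses several train steps into one graph call; I mention it only as a possible mitigation for tiny models, not as an explanation:

import tensorflow as tf

# Throwaway model, only used to inspect the default behaviour.
probe = tf.keras.Sequential([tf.keras.layers.Dense(1, input_shape=(1,))])
probe.compile(loss="mse", optimizer="adam")
print(probe.run_eagerly)  # False -> train_step should be wrapped in tf.function

# Possible mitigation for very small models: run 100 train steps per graph call.
probe.compile(loss="mse", optimizer="adam", steps_per_execution=100)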
Run the code below to see the difference:
import numpy as np
import matplotlib.pyplot as plt
import tensorflow as tf
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense
from tensorflow.keras import Input
from tensorflow.keras.initializers import RandomNormal
from tensorflow.keras.optimizers import Adam
from datetime import datetime
# Do not use GPU
tf.config.set_visible_devices([], "GPU")
# tf.compat.v1.disable_eager_execution() # adding this line makes keras training much faster
# Generate data
def f(x):
    return (x + 1) * np.sin(5 * x)
x_plot = np.arange(-1, 1 + 0.001, 0.001)
y_plot = f(x_plot)
x_train = np.arange(-1 + 0.05, 1, 0.2)
y_train = f(x_train)
x_val = np.arange(-1 + 0.15, 1, 0.2)
y_val = f(x_val)
# Plot the problem
plt.figure()
plt.plot(x_plot, y_plot, "-", label="Original function")
plt.plot(x_train, y_train, "o", label="Training points")
plt.plot(x_val, y_val, "s", label="Validation points")
plt.xlim(-1, 1)
plt.ylim(-2, 2)
plt.xlabel("x")
plt.ylabel("f")
plt.grid()
plt.legend()
plt.show(block=False)
# Reshape
X_train = x_train.reshape(x_train.shape[0], 1)
Y_train = y_train.reshape(x_train.shape[0], 1)
X_val = x_val.reshape(x_val.shape[0], 1)
Y_val = y_val.reshape(x_val.shape[0], 1)
# Simple model
tf.keras.utils.set_random_seed(1)
model = Sequential()
model.add(Input(shape=(1,))) # Input layer
model.add(
    Dense(
        4,
        activation="sigmoid",
        kernel_initializer=RandomNormal(mean=0.0, stddev=1.0, seed=1),
    )
)
model.add(Dense(1))
model.compile(loss="mean_squared_error", optimizer=Adam(learning_rate=3e-1))
start_time = datetime.now()
history = model.fit(
    X_train,
    Y_train,
    validation_split=0.0,
    validation_data=(X_val, Y_val),
    validation_freq=1,
    batch_size=X_train.shape[0],
    epochs=2000,
    verbose=0,
)
run_time = datetime.now() - start_time
print("Training time : {:.4f} s".format(run_time.total_seconds()))
Of course, I could simply run with eager execution disabled. But then another problem arises: when eager mode is disabled, I cannot use the predict function of my trained model, which is cached by the Streamlit framework (using @st.cache(allow_output_mutation=True) or @st.experimental_singleton). The following error arises:
InvalidArgumentError: Tensor input_1:0, specified in either feed_devices or fetch_devices was not found in the Graph
Traceback:
File "/home/vinh/miniconda3/envs/data/lib/python3.8/site-packages/streamlit/scriptrunner/script_runner.py", line 557, in _run_script
exec(code, module.__dict__)
File "dashboard.py", line 227, in <module>
Y_train_pred = model.predict(X_train)
File "/home/vinh/miniconda3/envs/data/lib/python3.8/site-packages/keras/engine/training_v1.py", line 969, in predict
return func.predict(
File "/home/vinh/miniconda3/envs/data/lib/python3.8/site-packages/keras/engine/training_arrays_v1.py", line 700, in predict
return predict_loop(
File "/home/vinh/miniconda3/envs/data/lib/python3.8/site-packages/keras/engine/training_arrays_v1.py", line 377, in model_iteration
batch_outs = f(ins_batch)
File "/home/vinh/miniconda3/envs/data/lib/python3.8/site-packages/keras/backend.py", line 4282, in __call__
self._make_callable(feed_arrays, feed_symbols, symbol_vals, session)
File "/home/vinh/miniconda3/envs/data/lib/python3.8/site-packages/keras/backend.py", line 4218, in _make_callable
callable_fn = session._make_callable_from_options(callable_opts)
File "/home/vinh/miniconda3/envs/data/lib/python3.8/site-packages/tensorflow/python/client/session.py", line 1513, in _make_callable_from_options
return BaseSession._Callable(self, callable_options)
File "/home/vinh/miniconda3/envs/data/lib/python3.8/site-packages/tensorflow/python/client/session.py", line 1471, in __init__
self._handle = tf_session.TF_SessionMakeCallable(
It seems like the trained model is somehow corrupted. This does not happen when eager mode is enabled. Where does this difference come from, and how can I fix it?
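One workaround I am considering, following the old TF1 pattern of caching the graph and session alongside a model and restoring them before predict, is sketched below. It is untested; the names train_pinned, graph and session are mine, and the model and data are simplified placeholders for the real ones:

import numpy as np
import tensorflow as tf
import streamlit as st

tf.compat.v1.disable_eager_execution()

X = np.linspace(-1, 1, 10).reshape(-1, 1)
Y = (X + 1) * np.sin(5 * X)

@st.experimental_singleton
def train_pinned():
    # Remember the graph and session the model is built in, and cache them too.
    graph = tf.compat.v1.get_default_graph()
    session = tf.compat.v1.keras.backend.get_session()
    model = tf.keras.Sequential(
        [tf.keras.layers.Dense(4, activation="sigmoid", input_shape=(1,)),
         tf.keras.layers.Dense(1)]
    )
    model.compile(loss="mean_squared_error", optimizer="adam")
    model.fit(X, Y, batch_size=X.shape[0], epochs=100, verbose=0)
    return model, graph, session

model_cached, graph, session = train_pinned()
with graph.as_default():
    # Restore the session the cached model was trained in before calling predict.
    tf.compat.v1.keras.backend.set_session(session)
    st.write(model_cached.predict(X))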
Run the following code to reproduce the error in Streamlit:
import numpy as np
import matplotlib.pyplot as plt
import tensorflow as tf
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense
from tensorflow.keras import Input
from tensorflow.keras.initializers import RandomNormal
from tensorflow.keras.optimizers import Adam
from datetime import datetime
import streamlit as st
# Do not use GPU
tf.config.set_visible_devices([], "GPU")
tf.compat.v1.disable_eager_execution()  # adding this line makes keras training much faster, but causes the predict error below
# Generate data
def f(x):
    return (x + 1) * np.sin(5 * x)
x_plot = np.arange(-1, 1 + 0.001, 0.001)
y_plot = f(x_plot)
x_train = np.arange(-1 + 0.05, 1, 0.2)
y_train = f(x_train)
x_val = np.arange(-1 + 0.15, 1, 0.2)
y_val = f(x_val)
# Plot the problem
plt.figure()
plt.plot(x_plot, y_plot, "-", label="Original function")
plt.plot(x_train, y_train, "o", label="Training points")
plt.plot(x_val, y_val, "s", label="Validation points")
plt.xlim(-1, 1)
plt.ylim(-2, 2)
plt.xlabel("x")
plt.ylabel("f")
plt.grid()
plt.legend()
plt.show(block=False)
# Reshape
X_train = x_train.reshape(x_train.shape[0], 1)
Y_train = y_train.reshape(x_train.shape[0], 1)
X_val = x_val.reshape(x_val.shape[0], 1)
Y_val = y_val.reshape(x_val.shape[0], 1)
# Simple model
start_time = datetime.now()
@st.experimental_singleton
def train():
    tf.keras.utils.set_random_seed(1)
    model = Sequential()
    model.add(Input(shape=(1,)))  # Input layer
    model.add(
        Dense(
            4,
            activation="sigmoid",
            kernel_initializer=RandomNormal(mean=0.0, stddev=1.0, seed=1),
        )
    )
    model.add(Dense(1))
    model.compile(
        loss="mean_squared_error",
        optimizer=Adam(learning_rate=3e-1),
        run_eagerly=False,
    )
    history = model.fit(
        X_train,
        Y_train,
        validation_split=0.0,
        validation_data=(X_val, Y_val),
        validation_freq=1,
        batch_size=X_train.shape[0],
        epochs=2000,
        verbose=0,
    )
    return model
model2 = train()
run_time = datetime.now() - start_time
print("Training time : {:.4f} s".format(run_time.total_seconds()))
st.write("Please rerun the app to see the error")
model2.predict(X_val)
Thank you for any ideas!
I also created an issue in the TF repo.
By the way, I am using TF 2.9 and Python 3.8.13.