Hello,
I want to implement checkpointing of my model in the “SavedModel” format (the folder-based one). Saving this model takes quite some time for my small LSTM (around 10 s). That’s why I wanted to clone the model and save the clone in a separate thread, so that training is not held up for so long. When I try to clone the model, the following error occurs. Does anybody have a clue why this happens?
My code to save the model in the callback looks like this:
def on_epoch_end(self, epoch: int, logs=None):
cloned_model = tf.keras.models.clone_model(self.model)
The error:
Traceback (most recent call last):
File "./main.py", line 35, in <module>
main()
File "./main.py", line 28, in main
compile_and_fit(model, train_gen, val_gen, timestamp, config.patience)
File "./training/compile_and_fit.py", line 28, in compile_and_fit
history = model.fit(train_dfs, epochs=MAX_EPOCHS,
File "./venv/lib/python3.9/site-packages/keras/engine/training.py", line 1230, in fit
callbacks.on_epoch_end(epoch, epoch_logs)
File "./venv/lib/python3.9/site-packages/keras/callbacks.py", line 413, in on_epoch_end
callback.on_epoch_end(epoch, logs)
File "./checkpointing/MetricsCallback.py", line 48, in on_epoch_end
cloned_model = tf.keras.models.clone_model(self.model)
File "./venv/lib/python3.9/site-packages/keras/models.py", line 448, in clone_model
return _clone_sequential_model(
File "./venv/lib/python3.9/site-packages/keras/models.py", line 332, in _clone_sequential_model
cloned_model = Sequential(layers=layers, name=model.name)
File "./venv/lib/python3.9/site-packages/tensorflow/python/training/tracking/base.py", line 530, in _method_wrapper
result = method(self, *args, **kwargs)
File "./venv/lib/python3.9/site-packages/keras/engine/sequential.py", line 134, in __init__
self.add(layer)
File "./venv/lib/python3.9/site-packages/tensorflow/python/training/tracking/base.py", line 530, in _method_wrapper
result = method(self, *args, **kwargs)
File "./venv/lib/python3.9/site-packages/keras/engine/sequential.py", line 217, in add
output_tensor = layer(self.outputs[0])
File "./venv/lib/python3.9/site-packages/keras/layers/recurrent.py", line 659, in __call__
return super(RNN, self).__call__(inputs, **kwargs)
File "./venv/lib/python3.9/site-packages/keras/engine/base_layer.py", line 976, in __call__
return self._functional_construction_call(inputs, args, kwargs,
File "./venv/lib/python3.9/site-packages/keras/engine/base_layer.py", line 1114, in _functional_construction_call
outputs = self._keras_tensor_symbolic_call(
File "./venv/lib/python3.9/site-packages/keras/engine/base_layer.py", line 848, in _keras_tensor_symbolic_call
return self._infer_output_signature(inputs, args, kwargs, input_masks)
File "./venv/lib/python3.9/site-packages/keras/engine/base_layer.py", line 886, in _infer_output_signature
self._maybe_build(inputs)
File "./venv/lib/python3.9/site-packages/keras/engine/base_layer.py", line 2659, in _maybe_build
self.build(input_shapes) # pylint:disable=not-callable
File "./venv/lib/python3.9/site-packages/keras/layers/recurrent.py", line 577, in build
self.cell.build(step_input_shape)
File "./venv/lib/python3.9/site-packages/keras/utils/tf_utils.py", line 259, in wrapper
output_shape = fn(instance, input_shape)
File "./venv/lib/python3.9/site-packages/keras/layers/recurrent.py", line 2354, in build
self.kernel = self.add_weight(
File "./venv/lib/python3.9/site-packages/keras/engine/base_layer.py", line 647, in add_weight
variable = self._add_variable_with_custom_getter(
File "./venv/lib/python3.9/site-packages/tensorflow/python/training/tracking/base.py", line 813, in _add_variable_with_custom_getter
new_variable = getter(
File "./venv/lib/python3.9/site-packages/keras/engine/base_layer_utils.py", line 117, in make_variable
return tf.compat.v1.Variable(
File "./venv/lib/python3.9/site-packages/tensorflow/python/ops/variables.py", line 266, in __call__
return cls._variable_v1_call(*args, **kwargs)
File "./venv/lib/python3.9/site-packages/tensorflow/python/ops/variables.py", line 212, in _variable_v1_call
return previous_getter(
File "./venv/lib/python3.9/site-packages/tensorflow/python/ops/variables.py", line 67, in getter
return captured_getter(captured_previous, **kwargs)
File "./venv/lib/python3.9/site-packages/tensorflow/python/distribute/distribute_lib.py", line 3547, in creator
return next_creator(**kwargs)
File "./venv/lib/python3.9/site-packages/tensorflow/python/ops/variables.py", line 205, in <lambda>
previous_getter = lambda **kwargs: default_variable_creator(None, **kwargs)
File "./venv/lib/python3.9/site-packages/tensorflow/python/ops/variable_scope.py", line 2612, in default_variable_creator
return resource_variable_ops.ResourceVariable(
File "./venv/lib/python3.9/site-packages/tensorflow/python/ops/variables.py", line 270, in __call__
return super(VariableMetaclass, cls).__call__(*args, **kwargs)
File "./venv/lib/python3.9/site-packages/tensorflow/python/ops/resource_variable_ops.py", line 1602, in __init__
self._init_from_args(
File "./venv/lib/python3.9/site-packages/tensorflow/python/ops/resource_variable_ops.py", line 1740, in _init_from_args
initial_value = initial_value()
File "./venv/lib/python3.9/site-packages/keras/initializers/initializers_v2.py", line 499, in __call__
fan_in, fan_out = _compute_fans(shape)
File "./venv/lib/python3.9/site-packages/keras/initializers/initializers_v2.py", line 1009, in _compute_fans
return int(fan_in), int(fan_out)
TypeError: int() argument must be a string, a bytes-like object or a number, not 'NoneType'