All,
I am encountering the following error in the BulkInferrer
component of a TFX pipeline that trains a RandomForestModel
on both dense and sparse features:
/home/ryan/workspace/models/ml-toolkit/venv/lib/python3.9/site-packages/tfx_bsl/beam/run_inference.py:225: BeamDeprecationWarning: options is deprecated since First stable release. References to <pipeline>.options will not be supported
examples.pipeline.options.view_as(GoogleCloudOptions).project)
INFO:absl:Path of output examples split `train` is /home/ryan/workspace/models/ml-toolkit/data/tfx_pipeline_output/penguins/BulkInferrer/output_examples/699/Split-train.
INFO:absl:RunInference on model: saved_model_spec {
model_path: "/home/ryan/workspace/models/ml-toolkit/data/tfx_pipeline_output/penguins/Trainer/model/698/Format-Serving"
}
INFO:absl:Path of output examples split `eval` is /home/ryan/workspace/models/ml-toolkit/data/tfx_pipeline_output/penguins/BulkInferrer/output_examples/699/Split-eval.
INFO:absl:tensorflow_text is not available.
INFO:absl:struct2tensor is not available.
WARNING:tensorflow:From /home/ryan/workspace/models/ml-toolkit/venv/lib/python3.9/site-packages/tfx_bsl/beam/run_inference.py:615: load (from tensorflow.python.saved_model.loader_impl) is deprecated and will be removed in a future version.
Instructions for updating:
Use `tf.saved_model.load` instead.
WARNING:tensorflow:From /home/ryan/workspace/models/ml-toolkit/venv/lib/python3.9/site-packages/tfx_bsl/beam/run_inference.py:615: load (from tensorflow.python.saved_model.loader_impl) is deprecated and will be removed in a future version.
Instructions for updating:
Use `tf.saved_model.load` instead.
2022-12-27 11:17:40.443051: I tensorflow/compiler/mlir/mlir_graph_optimization_pass.cc:357] MLIR V1 optimization pass is not enabled
[INFO 2022-12-27T11:17:40.502242451-05:00 kernel.cc:1175] Loading model from path /home/ryan/workspace/models/ml-toolkit/data/tfx_pipeline_output/penguins/Trainer/model/698/Format-Serving/assets/ with prefix fb2f3a13b14c4b82
[INFO 2022-12-27T11:17:40.504761455-05:00 decision_forest.cc:640] Model loaded with 81 root(s), 2021 node(s), and 3 input feature(s).
[INFO 2022-12-27T11:17:40.505051299-05:00 kernel.cc:1021] Use fast generic engine
INFO:absl:tensorflow_text is not available.
INFO:absl:struct2tensor is not available.
[INFO 2022-12-27T11:17:40.842664137-05:00 kernel.cc:1175] Loading model from path /home/ryan/workspace/models/ml-toolkit/data/tfx_pipeline_output/penguins/Trainer/model/698/Format-Serving/assets/ with prefix fb2f3a13b14c4b82
[INFO 2022-12-27T11:17:40.847428263-05:00 decision_forest.cc:640] Model loaded with 81 root(s), 2021 node(s), and 3 input feature(s).
[INFO 2022-12-27T11:17:40.847493745-05:00 kernel.cc:1021] Use fast generic engine
2022-12-27 11:17:40.963031: W tensorflow/core/framework/op_kernel.cc:1830] OP_REQUIRES failed at strided_slice_op.cc:105 : INVALID_ARGUMENT: slice index 0 of dimension 1 out of bounds.
INFO:absl:MetadataStore with DB connection initialized
ERROR:absl:Execution 699 failed.
INFO:absl:Cleaning up stateless execution info.
Traceback (most recent call last):
File "/home/ryan/workspace/models/ml-toolkit/venv/lib/python3.9/site-packages/tensorflow/python/client/session.py", line 1378, in _do_call
return fn(*args)
File "/home/ryan/workspace/models/ml-toolkit/venv/lib/python3.9/site-packages/tensorflow/python/client/session.py", line 1361, in _run_fn
return self._call_tf_sessionrun(options, feed_dict, fetch_list,
File "/home/ryan/workspace/models/ml-toolkit/venv/lib/python3.9/site-packages/tensorflow/python/client/session.py", line 1454, in _call_tf_sessionrun
return tf_session.TF_SessionRun_wrapper(self._session, options, feed_dict,
tensorflow.python.framework.errors_impl.InvalidArgumentError: assertion failed: [sex with tensor SparseTensor(indices=Tensor(\"inputs_10:0\", shape=(None, 2), dtype=int64), values=Tensor(\"inputs_11:0\", shape=(None,), dtype=string), dense_shape=Tensor(\"inputs_12:0\", shape=(2,), dtype=int64)) is provided as a sparse tensor with dynamic shape. Such feature can only be scalar but multiple values have been observed at the same time.]
[[{{function_node __inference__build_normalized_inputs_93104}}{{node Assert_1/Assert}}]]
During handling of the above exception, another exception occurred:
Traceback (most recent call last):
File "apache_beam/runners/common.py", line 1417, in apache_beam.runners.common.DoFnRunner.process
File "apache_beam/runners/common.py", line 837, in apache_beam.runners.common.PerWindowInvoker.invoke_process
File "apache_beam/runners/common.py", line 983, in apache_beam.runners.common.PerWindowInvoker._invoke_process_per_window
File "/home/ryan/workspace/models/ml-toolkit/venv/lib/python3.9/site-packages/apache_beam/ml/inference/base.py", line 429, in process
result_generator = self._model_handler.run_inference(
File "/home/ryan/workspace/models/ml-toolkit/venv/lib/python3.9/site-packages/tfx_bsl/beam/run_inference.py", line 1170, in run_inference
predictions = self._model_handler.run_inference(examples, model)
File "/home/ryan/workspace/models/ml-toolkit/venv/lib/python3.9/site-packages/tfx_bsl/beam/run_inference.py", line 383, in run_inference
outputs = self._run_inference(examples, serialized_examples, model)
File "/home/ryan/workspace/models/ml-toolkit/venv/lib/python3.9/site-packages/tfx_bsl/beam/run_inference.py", line 652, in _run_inference
result = model.run(
File "/home/ryan/workspace/models/ml-toolkit/venv/lib/python3.9/site-packages/tensorflow/python/client/session.py", line 968, in run
result = self._run(None, fetches, feed_dict, options_ptr,
File "/home/ryan/workspace/models/ml-toolkit/venv/lib/python3.9/site-packages/tensorflow/python/client/session.py", line 1191, in _run
results = self._do_run(handle, final_targets, final_fetches,
File "/home/ryan/workspace/models/ml-toolkit/venv/lib/python3.9/site-packages/tensorflow/python/client/session.py", line 1371, in _do_run
return self._do_call(_run_fn, feeds, fetches, targets, options,
File "/home/ryan/workspace/models/ml-toolkit/venv/lib/python3.9/site-packages/tensorflow/python/client/session.py", line 1397, in _do_call
raise type(e)(node_def, op, message) # pylint: disable=no-value-for-parameter
tensorflow.python.framework.errors_impl.InvalidArgumentError: Graph execution error:
assertion failed: [sex with tensor SparseTensor(indices=Tensor(\"inputs_10:0\", shape=(None, 2), dtype=int64), values=Tensor(\"inputs_11:0\", shape=(None,), dtype=string), dense_shape=Tensor(\"inputs_12:0\", shape=(2,), dtype=int64)) is provided as a sparse tensor with dynamic shape. Such feature can only be scalar but multiple values have been observed at the same time.]
[[{{node Assert_1/Assert}}]]
During handling of the above exception, another exception occurred:
Traceback (most recent call last):
File "<string>", line 1, in <module>
File "/home/ryan/workspace/models/ml-toolkit/venv/lib/python3.9/site-packages/click/core.py", line 829, in __call__
return self.main(*args, **kwargs)
File "/home/ryan/workspace/models/ml-toolkit/venv/lib/python3.9/site-packages/click/core.py", line 782, in main
rv = self.invoke(ctx)
File "/home/ryan/workspace/models/ml-toolkit/venv/lib/python3.9/site-packages/click/core.py", line 1066, in invoke
return ctx.invoke(self.callback, **ctx.params)
File "/home/ryan/workspace/models/ml-toolkit/venv/lib/python3.9/site-packages/click/core.py", line 610, in invoke
return callback(*args, **kwargs)
File "/home/ryan/workspace/models/ml-toolkit/ml_toolkit/local_runner.py", line 65, in app
run(config_path, cache)
File "/home/ryan/workspace/models/ml-toolkit/ml_toolkit/local_runner.py", line 45, in run
tfx.orchestration.LocalDagRunner().run(
File "/home/ryan/workspace/models/ml-toolkit/venv/lib/python3.9/site-packages/tfx/orchestration/portable/tfx_runner.py", line 124, in run
return self.run_with_ir(pipeline_pb, run_options=run_options_pb, **kwargs)
File "/home/ryan/workspace/models/ml-toolkit/venv/lib/python3.9/site-packages/tfx/orchestration/local/local_dag_runner.py", line 109, in run_with_ir
component_launcher.launch()
File "/home/ryan/workspace/models/ml-toolkit/venv/lib/python3.9/site-packages/tfx/orchestration/portable/launcher.py", line 573, in launch
executor_output = self._run_executor(execution_info)
File "/home/ryan/workspace/models/ml-toolkit/venv/lib/python3.9/site-packages/tfx/orchestration/portable/launcher.py", line 448, in _run_executor
executor_output = self._executor_operator.run_executor(execution_info)
File "/home/ryan/workspace/models/ml-toolkit/venv/lib/python3.9/site-packages/tfx/orchestration/portable/beam_executor_operator.py", line 112, in run_executor
return python_executor_operator.run_with_executor(execution_info, executor)
File "/home/ryan/workspace/models/ml-toolkit/venv/lib/python3.9/site-packages/tfx/orchestration/portable/python_executor_operator.py", line 58, in run_with_executor
result = executor.Do(execution_info.input_dict, output_dict,
File "/home/ryan/workspace/models/ml-toolkit/venv/lib/python3.9/site-packages/tfx/components/bulk_inferrer/executor.py", line 117, in Do
self._run_model_inference(
File "/home/ryan/workspace/models/ml-toolkit/venv/lib/python3.9/site-packages/tfx/components/bulk_inferrer/executor.py", line 212, in _run_model_inference
_ = (
File "/home/ryan/workspace/models/ml-toolkit/venv/lib/python3.9/site-packages/apache_beam/pipeline.py", line 597, in __exit__
self.result = self.run()
File "/home/ryan/workspace/models/ml-toolkit/venv/lib/python3.9/site-packages/apache_beam/pipeline.py", line 574, in run
return self.runner.run_pipeline(self, self._options)
File "/home/ryan/workspace/models/ml-toolkit/venv/lib/python3.9/site-packages/apache_beam/runners/direct/direct_runner.py", line 131, in run_pipeline
return runner.run_pipeline(pipeline, options)
File "/home/ryan/workspace/models/ml-toolkit/venv/lib/python3.9/site-packages/apache_beam/runners/portability/fn_api_runner/fn_runner.py", line 199, in run_pipeline
self._latest_run_result = self.run_via_runner_api(
File "/home/ryan/workspace/models/ml-toolkit/venv/lib/python3.9/site-packages/apache_beam/runners/portability/fn_api_runner/fn_runner.py", line 212, in run_via_runner_api
return self.run_stages(stage_context, stages)
File "/home/ryan/workspace/models/ml-toolkit/venv/lib/python3.9/site-packages/apache_beam/runners/portability/fn_api_runner/fn_runner.py", line 442, in run_stages
bundle_results = self._execute_bundle(
File "/home/ryan/workspace/models/ml-toolkit/venv/lib/python3.9/site-packages/apache_beam/runners/portability/fn_api_runner/fn_runner.py", line 770, in _execute_bundle
self._run_bundle(
File "/home/ryan/workspace/models/ml-toolkit/venv/lib/python3.9/site-packages/apache_beam/runners/portability/fn_api_runner/fn_runner.py", line 999, in _run_bundle
result, splits = bundle_manager.process_bundle(
File "/home/ryan/workspace/models/ml-toolkit/venv/lib/python3.9/site-packages/apache_beam/runners/portability/fn_api_runner/fn_runner.py", line 1309, in process_bundle
result_future = self._worker_handler.control_conn.push(process_bundle_req)
File "/home/ryan/workspace/models/ml-toolkit/venv/lib/python3.9/site-packages/apache_beam/runners/portability/fn_api_runner/worker_handlers.py", line 379, in push
response = self.worker.do_instruction(request)
File "/home/ryan/workspace/models/ml-toolkit/venv/lib/python3.9/site-packages/apache_beam/runners/worker/sdk_worker.py", line 596, in do_instruction
return getattr(self, request_type)(
File "/home/ryan/workspace/models/ml-toolkit/venv/lib/python3.9/site-packages/apache_beam/runners/worker/sdk_worker.py", line 634, in process_bundle
bundle_processor.process_bundle(instruction_id))
File "/home/ryan/workspace/models/ml-toolkit/venv/lib/python3.9/site-packages/apache_beam/runners/worker/bundle_processor.py", line 1003, in process_bundle
input_op_by_transform_id[element.transform_id].process_encoded(
File "/home/ryan/workspace/models/ml-toolkit/venv/lib/python3.9/site-packages/apache_beam/runners/worker/bundle_processor.py", line 227, in process_encoded
self.output(decoded_value)
File "apache_beam/runners/worker/operations.py", line 526, in apache_beam.runners.worker.operations.Operation.output
File "apache_beam/runners/worker/operations.py", line 528, in apache_beam.runners.worker.operations.Operation.output
File "apache_beam/runners/worker/operations.py", line 237, in apache_beam.runners.worker.operations.SingletonElementConsumerSet.receive
File "apache_beam/runners/worker/operations.py", line 240, in apache_beam.runners.worker.operations.SingletonElementConsumerSet.receive
File "apache_beam/runners/worker/operations.py", line 1021, in apache_beam.runners.worker.operations.SdfProcessSizedElements.process
File "apache_beam/runners/worker/operations.py", line 1030, in apache_beam.runners.worker.operations.SdfProcessSizedElements.process
File "apache_beam/runners/common.py", line 1432, in apache_beam.runners.common.DoFnRunner.process_with_sized_restriction
File "apache_beam/runners/common.py", line 817, in apache_beam.runners.common.PerWindowInvoker.invoke_process
File "apache_beam/runners/common.py", line 981, in apache_beam.runners.common.PerWindowInvoker._invoke_process_per_window
File "apache_beam/runners/common.py", line 1581, in apache_beam.runners.common._OutputHandler.handle_process_outputs
File "apache_beam/runners/common.py", line 1694, in apache_beam.runners.common._OutputHandler._write_value_to_tag
File "apache_beam/runners/worker/operations.py", line 240, in apache_beam.runners.worker.operations.SingletonElementConsumerSet.receive
File "apache_beam/runners/worker/operations.py", line 1308, in apache_beam.runners.worker.operations.FlattenOperation.process
File "apache_beam/runners/worker/operations.py", line 1311, in apache_beam.runners.worker.operations.FlattenOperation.process
File "apache_beam/runners/worker/operations.py", line 528, in apache_beam.runners.worker.operations.Operation.output
File "apache_beam/runners/worker/operations.py", line 237, in apache_beam.runners.worker.operations.SingletonElementConsumerSet.receive
File "apache_beam/runners/worker/operations.py", line 240, in apache_beam.runners.worker.operations.SingletonElementConsumerSet.receive
File "apache_beam/runners/worker/operations.py", line 907, in apache_beam.runners.worker.operations.DoOperation.process
File "apache_beam/runners/worker/operations.py", line 908, in apache_beam.runners.worker.operations.DoOperation.process
File "apache_beam/runners/common.py", line 1419, in apache_beam.runners.common.DoFnRunner.process
File "apache_beam/runners/common.py", line 1491, in apache_beam.runners.common.DoFnRunner._reraise_augmented
File "apache_beam/runners/common.py", line 1417, in apache_beam.runners.common.DoFnRunner.process
File "apache_beam/runners/common.py", line 623, in apache_beam.runners.common.SimpleInvoker.invoke_process
File "apache_beam/runners/common.py", line 1581, in apache_beam.runners.common._OutputHandler.handle_process_outputs
File "apache_beam/runners/common.py", line 1694, in apache_beam.runners.common._OutputHandler._write_value_to_tag
File "apache_beam/runners/worker/operations.py", line 240, in apache_beam.runners.worker.operations.SingletonElementConsumerSet.receive
File "apache_beam/runners/worker/operations.py", line 907, in apache_beam.runners.worker.operations.DoOperation.process
File "apache_beam/runners/worker/operations.py", line 908, in apache_beam.runners.worker.operations.DoOperation.process
File "apache_beam/runners/common.py", line 1419, in apache_beam.runners.common.DoFnRunner.process
File "apache_beam/runners/common.py", line 1491, in apache_beam.runners.common.DoFnRunner._reraise_augmented
File "apache_beam/runners/common.py", line 1417, in apache_beam.runners.common.DoFnRunner.process
File "apache_beam/runners/common.py", line 623, in apache_beam.runners.common.SimpleInvoker.invoke_process
File "apache_beam/runners/common.py", line 1581, in apache_beam.runners.common._OutputHandler.handle_process_outputs
File "apache_beam/runners/common.py", line 1694, in apache_beam.runners.common._OutputHandler._write_value_to_tag
File "apache_beam/runners/worker/operations.py", line 240, in apache_beam.runners.worker.operations.SingletonElementConsumerSet.receive
File "apache_beam/runners/worker/operations.py", line 907, in apache_beam.runners.worker.operations.DoOperation.process
File "apache_beam/runners/worker/operations.py", line 908, in apache_beam.runners.worker.operations.DoOperation.process
File "apache_beam/runners/common.py", line 1419, in apache_beam.runners.common.DoFnRunner.process
File "apache_beam/runners/common.py", line 1491, in apache_beam.runners.common.DoFnRunner._reraise_augmented
File "apache_beam/runners/common.py", line 1417, in apache_beam.runners.common.DoFnRunner.process
File "apache_beam/runners/common.py", line 623, in apache_beam.runners.common.SimpleInvoker.invoke_process
File "apache_beam/runners/common.py", line 1581, in apache_beam.runners.common._OutputHandler.handle_process_outputs
File "apache_beam/runners/common.py", line 1694, in apache_beam.runners.common._OutputHandler._write_value_to_tag
File "apache_beam/runners/worker/operations.py", line 240, in apache_beam.runners.worker.operations.SingletonElementConsumerSet.receive
File "apache_beam/runners/worker/operations.py", line 907, in apache_beam.runners.worker.operations.DoOperation.process
File "apache_beam/runners/worker/operations.py", line 908, in apache_beam.runners.worker.operations.DoOperation.process
File "apache_beam/runners/common.py", line 1419, in apache_beam.runners.common.DoFnRunner.process
File "apache_beam/runners/common.py", line 1507, in apache_beam.runners.common.DoFnRunner._reraise_augmented
File "apache_beam/runners/common.py", line 1417, in apache_beam.runners.common.DoFnRunner.process
File "apache_beam/runners/common.py", line 837, in apache_beam.runners.common.PerWindowInvoker.invoke_process
File "apache_beam/runners/common.py", line 983, in apache_beam.runners.common.PerWindowInvoker._invoke_process_per_window
File "/home/ryan/workspace/models/ml-toolkit/venv/lib/python3.9/site-packages/apache_beam/ml/inference/base.py", line 429, in process
result_generator = self._model_handler.run_inference(
File "/home/ryan/workspace/models/ml-toolkit/venv/lib/python3.9/site-packages/tfx_bsl/beam/run_inference.py", line 1170, in run_inference
predictions = self._model_handler.run_inference(examples, model)
File "/home/ryan/workspace/models/ml-toolkit/venv/lib/python3.9/site-packages/tfx_bsl/beam/run_inference.py", line 383, in run_inference
outputs = self._run_inference(examples, serialized_examples, model)
File "/home/ryan/workspace/models/ml-toolkit/venv/lib/python3.9/site-packages/tfx_bsl/beam/run_inference.py", line 652, in _run_inference
result = model.run(
File "/home/ryan/workspace/models/ml-toolkit/venv/lib/python3.9/site-packages/tensorflow/python/client/session.py", line 968, in run
result = self._run(None, fetches, feed_dict, options_ptr,
File "/home/ryan/workspace/models/ml-toolkit/venv/lib/python3.9/site-packages/tensorflow/python/client/session.py", line 1191, in _run
results = self._do_run(handle, final_targets, final_fetches,
File "/home/ryan/workspace/models/ml-toolkit/venv/lib/python3.9/site-packages/tensorflow/python/client/session.py", line 1371, in _do_run
return self._do_call(_run_fn, feeds, fetches, targets, options,
File "/home/ryan/workspace/models/ml-toolkit/venv/lib/python3.9/site-packages/tensorflow/python/client/session.py", line 1397, in _do_call
raise type(e)(node_def, op, message) # pylint: disable=no-value-for-parameter
RuntimeError: tensorflow.python.framework.errors_impl.InvalidArgumentError: Graph execution error:
assertion failed: [sex with tensor SparseTensor(indices=Tensor(\"inputs_10:0\", shape=(None, 2), dtype=int64), values=Tensor(\"inputs_11:0\", shape=(None,), dtype=string), dense_shape=Tensor(\"inputs_12:0\", shape=(2,), dtype=int64)) is provided as a sparse tensor with dynamic shape. Such feature can only be scalar but multiple values have been observed at the same time.]
[[{{node Assert_1/Assert}}]] [while running 'RunInference[eval]/RunInference/RunInferenceImpl/BulkInference/BeamML_RunInference']
Although the Beam run inference fails, I am able to successfully load the model and invoke the predict function on a dataset containing a SparseTensor. Curious if anyone has experienced this as well or has any advice on where to look. Thank you.
Using python 3.9.16 with the following pip package versions:
apache-beam 2.43.0
tensorboard 2.11.0
tensorboard-data-server 0.6.1
tensorboard-plugin-wit 1.8.1
tensorflow 2.11.0
tensorflow-data-validation 1.12.0
tensorflow-decision-forests 1.1.0
tensorflow-estimator 2.11.0
tensorflow-hub 0.12.0
tensorflow-io 0.27.0
tensorflow-io-gcs-filesystem 0.27.0
tensorflow-metadata 1.12.0
tensorflow-model-analysis 0.43.0
tensorflow-probability 0.18.0
tensorflow-serving-api 2.11.0
tensorflow-transform 1.12.0
UPDATE: I believe I have a better understanding of the issue. Something (the signature being a tf.function?) requires the SparseTensor shape to not be completely dynamic. Here is the signature creation:
import absl.logging
import tensorflow as tf
import tensorflow_transform as tft


def make_serving_signatures(
    model: tf.keras.Model,
    tf_transform_output: tft.TFTransformOutput,
):
    # Attach the transform graph to the model so it is exported with it.
    model.tft_layer = tf_transform_output.transform_features_layer()

    @tf.function(
        input_signature=[
            tf.TensorSpec(shape=[None], dtype=tf.string, name="examples")
        ]
    )
    def serve_tf_examples_fn(serialized_tf_examples):
        """Returns the output to be used in the serving signature."""
        raw_feature_spec = tf_transform_output.raw_feature_spec().copy()
        parsed_features = tf.io.parse_example(
            serialized_tf_examples, raw_feature_spec
        )
        transformed_features = model.tft_layer(parsed_features)
        absl.logging.info("serve_transformed_features = %s", transformed_features)
        return model(transformed_features)

    @tf.function(
        input_signature=[
            tf.TensorSpec(shape=[None], dtype=tf.string, name="examples")
        ]
    )
    def transform_features_fn(serialized_tf_examples):
        """Returns the transformed_features to be fed as input to evaluator."""
        raw_feature_spec = tf_transform_output.raw_feature_spec()
        parsed_features = tf.io.parse_example(
            serialized_tf_examples, raw_feature_spec
        )
        transformed_features = model.tft_layer(parsed_features)
        absl.logging.info("eval_transformed_features = %s", transformed_features)
        return transformed_features

    return {
        "serving_default": serve_tf_examples_fn,
        "transform": transform_features_fn,
    }
The features of interest are interpreted as tf.io.VarLenFeature by default when determining the FeatureSpecType. There is some documentation on this in the tfx-bsl repo. Since each one is a tf.io.VarLenFeature, it will be parsed as a SparseTensor with completely dynamic shape, as tf.io.parse_example needs to look at all supplied tf.train.Examples to determine the shape.
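To make the dynamic shape concrete, here is a minimal pure-Python sketch (no TensorFlow required) that mimics the (indices, values, dense_shape) layout produced when a batch of variable-length features is parsed. The helper name batch_to_coo is hypothetical and only illustrates the layout; it is not part of any library. The second dimension of dense_shape is the largest number of values seen in the batch, so it can be 1 for one batch and 0 for a batch in which the feature is missing everywhere. That last case might be related to the "slice index 0 of dimension 1 out of bounds" warning in the log above, though that is a guess.

```python
from typing import List, Optional, Tuple


def batch_to_coo(
    batch: List[Optional[List[str]]],
) -> Tuple[List[List[int]], List[str], List[int]]:
    """Mimic the sparse layout of a parsed variable-length feature.

    Hypothetical helper for illustration only. Each element of `batch`
    is the list of values for one example, or None when the feature is
    absent from that example.
    """
    indices: List[List[int]] = []
    values: List[str] = []
    max_len = 0
    for row, feature_values in enumerate(batch):
        feature_values = feature_values or []  # missing feature: no entries
        for col, value in enumerate(feature_values):
            indices.append([row, col])
            values.append(value)
        max_len = max(max_len, len(feature_values))
    # The second dimension depends on the batch contents, hence "dynamic".
    dense_shape = [len(batch), max_len]
    return indices, values, dense_shape


# One value per example at most: second dimension is 1.
print(batch_to_coo([["male"], None, ["female"]]))
# Feature missing in every example of the batch: second dimension is 0.
print(batch_to_coo([None, None]))
```

Because the second dimension is only known once a batch has been inspected, a consumer that expects a scalar per example cannot verify that expectation from the static shape alone.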
It looks like the most direct solution is to provide your raw data in the form described in the tfx-bsl documentation mentioned above, but this is cumbersome and may eliminate some of the features that the StatisticsGen, ExampleValidator, etc. components provide. What would be most convenient is some schema representation of nullable scalar features that corresponds to sparse feature and tensor specs with non-dynamic dense shapes.