When I save a TF model in Python and then try to load it and run inference through the TensorFlow Java API, it fails, complaining about missing parts of the graph.
Error:
org.tensorflow.exceptions.TFFailedPreconditionException:
Could not find variable dense/kernel. This could mean that the variable has been
deleted. In TF1, it can also mean the variable is uninitialized.
Debug info: container=localhost, status error message=Resource
localhost/dense/kernel/N10tensorflow3VarE does not exist.
[[{{function_node __inference_serving_tfrecord_fn_314}}
{{node functional_1/dense_1/Cast/ReadVariableOp}}]]
I’m adding code to reproduce the problem:
Python code, using TF 2.16.2:
import tensorflow as tf
# Define the feature specifications for parsing
feature_spec = {
    'query_id': tf.io.FixedLenFeature([], tf.string),
    'rank': tf.io.FixedLenSequenceFeature([], tf.int64, allow_missing=True),
    'clicked': tf.io.FixedLenSequenceFeature([], tf.int64, allow_missing=True),
    'text_match_score': tf.io.FixedLenSequenceFeature([], tf.float32, allow_missing=True),
}
def build_model():
    # Input layers
    rank_input = tf.keras.Input(shape=(None,), name='rank', dtype=tf.int64)
    clicked_input = tf.keras.Input(shape=(None,), name='clicked', dtype=tf.int64)
    text_match_score_input = tf.keras.Input(shape=(None,), name='text_match_score', dtype=tf.float32)

    # Cast inputs to float32 using Lambda layers
    rank_float = tf.keras.layers.Lambda(lambda x: tf.cast(x, tf.float32))(rank_input)
    clicked_float = tf.keras.layers.Lambda(lambda x: tf.cast(x, tf.float32))(clicked_input)

    # Expand dimensions using Lambda layers
    rank_expanded = tf.keras.layers.Lambda(lambda x: tf.expand_dims(x, axis=-1))(rank_float)
    clicked_expanded = tf.keras.layers.Lambda(lambda x: tf.expand_dims(x, axis=-1))(clicked_float)
    text_match_expanded = tf.keras.layers.Lambda(lambda x: tf.expand_dims(x, axis=-1))(text_match_score_input)

    # Apply GlobalAveragePooling1D to reduce the sequence dimension
    rank_pooled = tf.keras.layers.GlobalAveragePooling1D()(rank_expanded)
    clicked_pooled = tf.keras.layers.GlobalAveragePooling1D()(clicked_expanded)
    text_match_pooled = tf.keras.layers.GlobalAveragePooling1D()(text_match_expanded)

    # Concatenate pooled features
    concatenated = tf.keras.layers.Concatenate()([
        rank_pooled,
        clicked_pooled,
        text_match_pooled
    ])

    # Dense layers
    x = tf.keras.layers.Dense(16, activation='relu')(concatenated)
    x = tf.keras.layers.Dense(8, activation='relu')(x)
    output = tf.keras.layers.Dense(1, activation='sigmoid', name='ranking_score')(x)

    # Build the model
    model = tf.keras.Model(inputs=[rank_input, clicked_input, text_match_score_input], outputs=output)
    return model
@tf.function(input_signature=[tf.TensorSpec([None], tf.string, name='protos')])
def serving_tfrecord_fn(serialized_examples):
    # Parse the serialized examples
    parsed_features = tf.io.parse_example(serialized_examples, features=feature_spec)

    # Extract the features
    rank = parsed_features['rank']
    clicked = parsed_features['clicked']
    text_match_score = parsed_features['text_match_score']

    # Convert to RaggedTensors
    rank_ragged = tf.RaggedTensor.from_tensor(rank)
    clicked_ragged = tf.RaggedTensor.from_tensor(clicked)
    text_match_score_ragged = tf.RaggedTensor.from_tensor(text_match_score)

    # Convert RaggedTensors to padded dense tensors
    rank_padded = rank_ragged.to_tensor(default_value=0)
    clicked_padded = clicked_ragged.to_tensor(default_value=0)
    text_match_padded = text_match_score_ragged.to_tensor(default_value=0.0)

    # Run inference
    predictions = model({
        'rank': rank_padded,
        'clicked': clicked_padded,
        'text_match_score': text_match_padded
    })
    return {'ranking_score': predictions}
# Build and compile the model
model = build_model()
model.compile(optimizer='adam', loss='binary_crossentropy')
# Dummy input data to initialize variables
import numpy as np
# Create dummy data matching the input shapes
dummy_rank = np.array([[1]], dtype=np.int64)
dummy_clicked = np.array([[0]], dtype=np.int64)
dummy_text_match_score = np.array([[0.5]], dtype=np.float32)
# Run a dummy inference to initialize variables
model({
    'rank': dummy_rank,
    'clicked': dummy_clicked,
    'text_match_score': dummy_text_match_score
})
# Save the model with the custom serving function
tf.saved_model.save(
    model,
    export_dir='saved_model',
    signatures={'serving_tfrecord': serving_tfrecord_fn}
)
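Before touching Java, I sanity-check the export from Python. This is only a minimal verification sketch (it feeds serialized tf.train.Example protos matching feature_spec, rather than SequenceExamples); reloading the SavedModel and calling the serving_tfrecord signature directly confirms the export itself, and saved_model_cli show --dir saved_model --all should list the concrete tensor names behind the signature, which is where the serving_tfrecord_protos / StatefulPartitionedCall names used in the Scala code come from:

import tensorflow as tf

# Reload the export and exercise the signature directly (verification sketch only)
loaded = tf.saved_model.load('saved_model')
serving_fn = loaded.signatures['serving_tfrecord']
print(serving_fn.structured_input_signature)  # expects a [None] tf.string tensor named 'protos'
print(serving_fn.structured_outputs)          # {'ranking_score': ...}

# One serialized tf.train.Example carrying the same features as feature_spec
example = tf.train.Example(features=tf.train.Features(feature={
    'query_id': tf.train.Feature(bytes_list=tf.train.BytesList(value=[b'dummy_query_id'])),
    'rank': tf.train.Feature(int64_list=tf.train.Int64List(value=[1])),
    'clicked': tf.train.Feature(int64_list=tf.train.Int64List(value=[0])),
    'text_match_score': tf.train.Feature(float_list=tf.train.FloatList(value=[0.5])),
}))

# Call the signature with the serialized proto, the same signature the Scala code drives through the session
print(serving_fn(protos=tf.constant([example.SerializeToString()])))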
Scala code to load and run inference:
import scala.collection.JavaConverters._
import org.junit.Test
import org.tensorflow.{Graph, SavedModelBundle, Session, Tensor}
import org.tensorflow.example._
import org.tensorflow.types.{TFloat32, TString}
import com.google.protobuf.ByteString
import org.tensorflow.ndarray.NdArrays
import org.tensorflow.ndarray.buffer.{DataBuffers, FloatDataBuffer}
class TensorFlowInference {

  @Test
  def testLoadInference(): Unit = {
    // Update the model path
    val modelPath = PATH

    // Load the model with the "serve" tag
    val savedModelBundle = SavedModelBundle.load(modelPath, "serve")

    // Access the graph and list all operations (nodes) in the model
    val graph: Graph = savedModelBundle.graph()
    println("Listing all operations (nodes) in the model:")
    for (op <- graph.operations().asScala) {
      println(op.name())
    }

    // Create a session from the SavedModelBundle
    val session = savedModelBundle.session()
    // Initialize the SequenceExample builder
    val sequenceExampleBuilder = SequenceExample.newBuilder()

    // Context features
    val contextFeatures = Features.newBuilder()

    // Add 'query_id' context feature (string)
    contextFeatures.putFeature(
      "query_id",
      Feature.newBuilder()
        .setBytesList(
          BytesList.newBuilder()
            .addValue(ByteString.copyFromUtf8("dummy_query_id"))
        )
        .build()
    )
    sequenceExampleBuilder.setContext(contextFeatures)

    // Sequence features
    val featureLists = FeatureLists.newBuilder()
    // Helper function to create a FeatureList with a single feature
    def createFeatureList(values: Seq[Any]): FeatureList = {
      val featureListBuilder = FeatureList.newBuilder()
      values.foreach {
        case v: Int =>
          featureListBuilder.addFeature(
            Feature.newBuilder()
              .setInt64List(Int64List.newBuilder().addValue(v))
              .build()
          )
        case v: Long =>
          featureListBuilder.addFeature(
            Feature.newBuilder()
              .setInt64List(Int64List.newBuilder().addValue(v))
              .build()
          )
        case v: Float =>
          featureListBuilder.addFeature(
            Feature.newBuilder()
              .setFloatList(FloatList.newBuilder().addValue(v))
              .build()
          )
        case v: Double =>
          featureListBuilder.addFeature(
            Feature.newBuilder()
              .setFloatList(FloatList.newBuilder().addValue(v.toFloat))
              .build()
          )
        case v: String =>
          featureListBuilder.addFeature(
            Feature.newBuilder()
              .setBytesList(BytesList.newBuilder().addValue(ByteString.copyFromUtf8(v)))
              .build()
          )
        case _ =>
          throw new IllegalArgumentException("Unsupported value type")
      }
      featureListBuilder.build()
    }
    // Add 'rank' sequence feature (int64)
    featureLists.putFeatureList(
      "rank",
      createFeatureList(Seq(1L)) // Use Long values for int64
    )

    // Add 'clicked' sequence feature (int64)
    featureLists.putFeatureList(
      "clicked",
      createFeatureList(Seq(0L))
    )

    // Add 'text_match_score' sequence feature (float)
    featureLists.putFeatureList(
      "text_match_score",
      createFeatureList(Seq(0.5f))
    )
    sequenceExampleBuilder.setFeatureLists(featureLists)

    // Build the SequenceExample
    val sequenceExample = sequenceExampleBuilder.build()

    // Serialize the SequenceExample to a byte array
    val serializedExampleBytes = sequenceExample.toByteArray
    // Wrap the raw serialized bytes in a string tensor directly, without going
    // through a character encoding, so the proto bytes are not mangled
    val tensorInput = TString.tensorOfBytes(NdArrays.vectorOfObjects(serializedExampleBytes))
    // Input and output tensor names
    val inputTensorName = "serving_tfrecord_protos"
    val outputTensorName = "StatefulPartitionedCall"

    // Feed the input tensor and fetch the output
    val outputTensors = session.runner()
      .feed(inputTensorName, tensorInput)
      .fetch(outputTensorName)
      .run()

    // Extract the output tensor
    val outputTensor = outputTensors.get(0).asInstanceOf[TFloat32]
    try {
      // Copy the output into a FloatDataBuffer backed by an Array[Float]
      val numElements = outputTensor.shape().size(0).toInt * outputTensor.shape().size(1).toInt
      val outputData = new Array[Float](numElements)
      val dataBuffer: FloatDataBuffer = DataBuffers.of(outputData, false, false)
      outputTensor.read(dataBuffer)

      // Print the output scores
      println(s"Model output: ${outputData.mkString(", ")}")
    } finally {
      // Close the Result and other native resources
      outputTensors.close()
      tensorInput.close()
      savedModelBundle.close()
    }
  }
}