Java API: model loading and inference

When I save a TensorFlow model in Python and then try to load it and run inference with the Java API, it fails, complaining about missing parts of the graph.

error:
org.tensorflow.exceptions.TFFailedPreconditionException:
Could not find variable dense/kernel. This could mean that the variable has been
deleted. In TF1, it can also mean the variable is uninitialized.
Debug info: container=localhost, status error message=Resource
localhost/dense/kernel/N10tensorflow3VarE does not exist.
[error] [[{{function_node __inference_serving_tfrecord_fn_314}}
{{node functional_1/dense_1/Cast/ReadVariableOp}}]]

Here is code that reproduces the problem:

Python code (TensorFlow 2.16.2):

import tensorflow as tf

# Parsing spec shared with the serving signature: 'query_id' is a single
# string value, while the three numeric features are variable-length lists
# (FixedLenSequenceFeature with allow_missing=True).
feature_spec = dict(
    query_id=tf.io.FixedLenFeature([], tf.string),
    rank=tf.io.FixedLenSequenceFeature([], tf.int64, allow_missing=True),
    clicked=tf.io.FixedLenSequenceFeature([], tf.int64, allow_missing=True),
    text_match_score=tf.io.FixedLenSequenceFeature(
        [], tf.float32, allow_missing=True
    ),
)

def build_model():
    """Build the three-input ranking model.

    Inputs are named ``rank`` (int64), ``clicked`` (int64) and
    ``text_match_score`` (float32), each of shape ``(batch, None)``. The int64
    inputs are cast to float32, every input gets a trailing channel axis, is
    mean-pooled over the sequence axis, and the three pooled values are
    concatenated and fed through Dense(16, relu) -> Dense(8, relu) ->
    Dense(1, sigmoid) named ``ranking_score``.

    NOTE(review): GlobalAveragePooling1D is applied without a mask, so any
    zero padding added when batching ragged sequences is included in the mean
    -- confirm that is intended.

    Returns:
        An uncompiled ``tf.keras.Model``.
    """
    inputs = [
        tf.keras.Input(shape=(None,), name=input_name, dtype=input_dtype)
        for input_name, input_dtype in (
            ('rank', tf.int64),
            ('clicked', tf.int64),
            ('text_match_score', tf.float32),
        )
    ]
    rank_in, clicked_in, score_in = inputs

    # Only the integer inputs need a cast; the score is already float32.
    as_float = [
        tf.keras.layers.Lambda(lambda x: tf.cast(x, tf.float32))(tensor)
        for tensor in (rank_in, clicked_in)
    ]
    # Add a channel axis so the 1-D pooling layer sees (batch, steps, 1).
    with_channel = [
        tf.keras.layers.Lambda(lambda x: tf.expand_dims(x, axis=-1))(tensor)
        for tensor in (*as_float, score_in)
    ]
    pooled = [tf.keras.layers.GlobalAveragePooling1D()(tensor) for tensor in with_channel]

    hidden = tf.keras.layers.Concatenate()(pooled)
    for units in (16, 8):
        hidden = tf.keras.layers.Dense(units, activation='relu')(hidden)
    score = tf.keras.layers.Dense(1, activation='sigmoid', name='ranking_score')(hidden)

    return tf.keras.Model(inputs=inputs, outputs=score)

@tf.function(input_signature=[tf.TensorSpec([None], tf.string, name='protos')])
def serving_tfrecord_fn(serialized_examples):
    """Score a batch of serialized ``tf.train.SequenceExample`` protos.

    Uses the module-level ``feature_spec`` and ``model``.

    Args:
        serialized_examples: 1-D string tensor. Each element is one serialized
            SequenceExample whose context carries ``query_id`` and whose
            feature lists carry ``rank``, ``clicked`` and ``text_match_score``.

    Returns:
        ``{'ranking_score': predictions}`` where ``predictions`` is the model
        output for the batch.
    """
    # A SequenceExample keeps scalar per-record features in its context and
    # per-step features in feature_lists. tf.io.parse_example only understands
    # the plain-Example layout and would never see the feature_lists, so split
    # the shared spec and parse with parse_sequence_example instead.
    context_spec = {
        key: spec
        for key, spec in feature_spec.items()
        if not isinstance(spec, tf.io.FixedLenSequenceFeature)
    }
    sequence_spec = {
        key: spec
        for key, spec in feature_spec.items()
        if isinstance(spec, tf.io.FixedLenSequenceFeature)
    }
    _context, sequences, _lengths = tf.io.parse_sequence_example(
        serialized_examples,
        context_features=context_spec,
        sequence_features=sequence_spec,
    )

    # Batched parsing already returns dense tensors padded to the longest
    # sequence in the batch, so no ragged -> dense round trip is needed.
    predictions = model({
        'rank': sequences['rank'],
        'clicked': sequences['clicked'],
        'text_match_score': sequences['text_match_score'],
    })

    return {'ranking_score': predictions}

import numpy as np

# Assemble and compile the Keras model. serving_tfrecord_fn reads this global
# `model` when it is traced during export, so it must be bound before saving.
model = build_model()
model.compile(optimizer='adam', loss='binary_crossentropy')

# One-row, one-step batch per input, matching each input's declared dtype.
dummy_rank, dummy_clicked = (np.array([[flag]], dtype=np.int64) for flag in (1, 0))
dummy_text_match_score = np.array([[0.5]], dtype=np.float32)

# Forward pass on the dummy batch before export.
# NOTE(review): layers of a Functional model are normally built when the model
# is constructed, so this call may not be what creates the variables -- confirm.
model(dict(
    rank=dummy_rank,
    clicked=dummy_clicked,
    text_match_score=dummy_text_match_score,
))

# Export as a SavedModel exposing the custom proto-parsing signature.
# NOTE(review): with Keras 3 (bundled with TF 2.16), model.export() may be the
# more reliable way to produce a SavedModel for non-Python loaders -- confirm.
tf.saved_model.save(
    model,
    export_dir='saved_model',
    signatures=dict(serving_tfrecord=serving_tfrecord_fn),
)

Scala code that loads the model and runs inference:

import scala.collection.JavaConverters._
import org.junit.Test
import org.tensorflow.{SavedModelBundle, Tensor, Session}
import org.tensorflow.example._
import org.tensorflow.types.{TFloat32, TString}
import com.google.protobuf.ByteString
import org.tensorflow.ndarray.buffer.{DataBuffers, FloatDataBuffer}
import scala.collection.JavaConverters._


/**
 * Loads the SavedModel written by the Python export script and scores one
 * serialized SequenceExample through its `serving_tfrecord` signature.
 */
class TensorFlowInference {

  // Builds a FeatureList with one single-value Feature per element of `values`:
  // Int/Long -> Int64List, Float/Double -> FloatList, String -> BytesList.
  // Any other element type throws IllegalArgumentException.
  private def createFeatureList(values: Seq[Any]): FeatureList = {
    val featureListBuilder = FeatureList.newBuilder()
    values.foreach {
      case v: Int =>
        featureListBuilder.addFeature(
          Feature.newBuilder()
            .setInt64List(Int64List.newBuilder().addValue(v))
            .build()
        )
      case v: Long =>
        featureListBuilder.addFeature(
          Feature.newBuilder()
            .setInt64List(Int64List.newBuilder().addValue(v))
            .build()
        )
      case v: Float =>
        featureListBuilder.addFeature(
          Feature.newBuilder()
            .setFloatList(FloatList.newBuilder().addValue(v))
            .build()
        )
      case v: Double =>
        featureListBuilder.addFeature(
          Feature.newBuilder()
            .setFloatList(FloatList.newBuilder().addValue(v.toFloat))
            .build()
        )
      case v: String =>
        featureListBuilder.addFeature(
          Feature.newBuilder()
            .setBytesList(BytesList.newBuilder().addValue(ByteString.copyFromUtf8(v)))
            .build()
        )
      case _ =>
        throw new IllegalArgumentException("Unsupported value type")
    }
    featureListBuilder.build()
  }

  @Test
  def testLoadInference(): Unit = {
    import org.tensorflow.ndarray.NdArrays

    // SavedModel directory written by tf.saved_model.save; override with -Dmodel.path=...
    val modelPath = sys.props.getOrElse("model.path", "saved_model")
    val savedModelBundle = SavedModelBundle.load(modelPath, "serve")
    try {
      println("Listing all operations (nodes) in the model:")
      for (op <- savedModelBundle.graph().operations().asScala) {
        println(op.name())
      }

      // Resolve feed/fetch tensor names from the exported signature instead of
      // hard-coding graph-internal names such as "StatefulPartitionedCall".
      val signature = Option(savedModelBundle.metaGraphDef().getSignatureDefMap().get("serving_tfrecord"))
        .getOrElse(throw new IllegalStateException("SavedModel has no 'serving_tfrecord' signature"))
      val inputTensorName = Option(signature.getInputsMap().get("protos"))
        .map(_.getName())
        .getOrElse(throw new IllegalStateException("Signature has no 'protos' input"))
      val outputTensorName = Option(signature.getOutputsMap().get("ranking_score"))
        .map(_.getName())
        .getOrElse(throw new IllegalStateException("Signature has no 'ranking_score' output"))

      // Context: the per-record 'query_id' string.
      val contextFeatures = Features.newBuilder()
        .putFeature(
          "query_id",
          Feature.newBuilder()
            .setBytesList(
              BytesList.newBuilder().addValue(ByteString.copyFromUtf8("dummy_query_id"))
            )
            .build()
        )

      // Feature lists: one time step per feature (Long values for int64 features).
      val featureLists = FeatureLists.newBuilder()
        .putFeatureList("rank", createFeatureList(Seq(1L)))
        .putFeatureList("clicked", createFeatureList(Seq(0L)))
        .putFeatureList("text_match_score", createFeatureList(Seq(0.5f)))

      val serializedExampleBytes: Array[Byte] = SequenceExample.newBuilder()
        .setContext(contextFeatures)
        .setFeatureLists(featureLists)
        .build()
        .toByteArray

      // Feed the raw proto bytes. Converting them to a Java String and using
      // TString.vectorOf(String) re-encodes the text as UTF-8, which expands
      // every byte >= 0x80 into two bytes and corrupts the serialized proto.
      val tensorInput =
        TString.tensorOfBytes(NdArrays.vectorOfObjects[Array[Byte]](serializedExampleBytes))
      try {
        val outputTensors = savedModelBundle.session().runner()
          .feed(inputTensorName, tensorInput)
          .fetch(outputTensorName)
          .run()
        try {
          val outputTensor = outputTensors.get(0).asInstanceOf[TFloat32]
          // Copy every element, whatever the output shape, into a flat array.
          val outputData = new Array[Float](outputTensor.size().toInt)
          val dataBuffer: FloatDataBuffer = DataBuffers.of(outputData, false, false)
          outputTensor.copyTo(dataBuffer)
          println(s"Model output: ${outputData.mkString(", ")}")
        } finally {
          outputTensors.close()
        }
      } finally {
        tensorInput.close()
      }
    } finally {
      savedModelBundle.close()
    }
  }
}

Which version of TF Java are you using? Make sure you use the official version (from the tensorflow/java repository) and not the legacy one that is still in the main TensorFlow repo; the legacy one won’t work with TF 2.x.

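Also, I don’t see anywhere where you fit your model — is this intended? If so, I’m not sure all variables are initialized. Finally, since you work with Keras, you might want to try saving your model with the Keras API, e.g. `model.save('/path/', save_format='tf')`. (Note that TF 2.16 ships Keras 3, which no longer accepts `save_format='tf'`; with Keras 3, use `model.export('/path/')` to write a SavedModel instead.)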