Java API: model loading and inference

When I save a TensorFlow model in Python and then try to load it and run inference with the Java API, it fails, complaining that parts of the graph are missing.

error:
org.tensorflow.exceptions.TFFailedPreconditionException:
Could not find variable dense/kernel. This could mean that the variable has been
deleted. In TF1, it can also mean the variable is uninitialized.
Debug info: container=localhost, status error message=Resource
localhost/dense/kernel/N10tensorflow3VarE does not exist.
[error] [[{{function_node __inference_serving_tfrecord_fn_314}}
{{node functional_1/dense_1/Cast/ReadVariableOp}}]]

I’m adding code to reproduce the problem:

Python code (TensorFlow 2.16.2):

import tensorflow as tf

# Parsing schema for the serialized input protos: one scalar string id plus
# three variable-length numeric features (a missing feature parses as empty).
feature_spec = dict(
    query_id=tf.io.FixedLenFeature([], tf.string),
    rank=tf.io.FixedLenSequenceFeature([], tf.int64, allow_missing=True),
    clicked=tf.io.FixedLenSequenceFeature([], tf.int64, allow_missing=True),
    text_match_score=tf.io.FixedLenSequenceFeature([], tf.float32, allow_missing=True),
)

def build_model():
    """Assemble the functional Keras ranking model.

    The three variable-length inputs ('rank', 'clicked', 'text_match_score')
    are brought to float32, given a trailing axis of size 1, averaged with
    GlobalAveragePooling1D, concatenated, and fed through
    Dense(16, relu) -> Dense(8, relu) -> Dense(1, sigmoid).

    Returns:
        A tf.keras.Model whose single output layer is named 'ranking_score'.
    """
    layers = tf.keras.layers

    def to_float32(tensor):
        return tf.cast(tensor, tf.float32)

    def add_trailing_axis(tensor):
        return tf.expand_dims(tensor, axis=-1)

    # Symbolic inputs; the sequence length is left unspecified.
    rank_input = tf.keras.Input(shape=(None,), name='rank', dtype=tf.int64)
    clicked_input = tf.keras.Input(shape=(None,), name='clicked', dtype=tf.int64)
    text_match_score_input = tf.keras.Input(shape=(None,), name='text_match_score', dtype=tf.float32)
    model_inputs = [rank_input, clicked_input, text_match_score_input]

    # Only the integer inputs need a cast; text_match_score is already float32.
    float_features = [layers.Lambda(to_float32)(tensor) for tensor in (rank_input, clicked_input)]
    float_features.append(text_match_score_input)

    # Add the trailing axis that GlobalAveragePooling1D expects, then pool
    # each feature over its sequence dimension.
    expanded = [layers.Lambda(add_trailing_axis)(tensor) for tensor in float_features]
    pooled = [layers.GlobalAveragePooling1D()(tensor) for tensor in expanded]

    # Dense head on top of the concatenated pooled features.
    hidden = layers.Concatenate()(pooled)
    for units in (16, 8):
        hidden = layers.Dense(units, activation='relu')(hidden)
    score = layers.Dense(1, activation='sigmoid', name='ranking_score')(hidden)

    return tf.keras.Model(inputs=model_inputs, outputs=score)

@tf.function(input_signature=[tf.TensorSpec([None], tf.string, name='protos')])
def serving_tfrecord_fn(serialized_examples):
    """Serving signature: score a batch of serialized SequenceExample protos.

    Args:
        serialized_examples: 1-D string tensor, one serialized
            ``tf.train.SequenceExample`` per element. 'query_id' is read from
            the context; 'rank', 'clicked' and 'text_match_score' are read
            from the feature lists.

    Returns:
        A dict with the single key 'ranking_score' holding the model output
        for the batch.
    """
    # Split the shared schema: scalar features live in the proto's context,
    # sequence features live in its feature_lists.
    context_spec = {
        name: spec for name, spec in feature_spec.items()
        if isinstance(spec, tf.io.FixedLenFeature)
    }
    sequence_spec = {
        name: spec for name, spec in feature_spec.items()
        if isinstance(spec, tf.io.FixedLenSequenceFeature)
    }

    # tf.io.parse_example only understands tf.train.Example; SequenceExample
    # protos have to go through parse_sequence_example, otherwise the
    # feature_lists are never read.
    _, sequences, _ = tf.io.parse_sequence_example(
        serialized_examples,
        context_features=context_spec,
        sequence_features=sequence_spec,
    )

    # The parsed sequence features are already dense tensors padded to the
    # longest sequence in the batch, so no RaggedTensor round trip is needed.
    predictions = model({
        'rank': sequences['rank'],
        'clicked': sequences['clicked'],
        'text_match_score': sequences['text_match_score'],
    })

    return {'ranking_score': predictions}

# Instantiate the network and attach the optimizer and loss.
model = build_model()
model.compile(optimizer='adam', loss='binary_crossentropy')


# Warm-up forward pass on a one-row batch, intended to make sure the model's
# variables exist before export.
import numpy as np

# One value per feature, shaped (1, 1) to match the (batch, sequence) inputs.
dummy_rank = np.array([[1]], dtype=np.int64)
dummy_clicked = np.array([[0]], dtype=np.int64)
dummy_text_match_score = np.array([[0.5]], dtype=np.float32)

dummy_batch = {
    'rank': dummy_rank,
    'clicked': dummy_clicked,
    'text_match_score': dummy_text_match_score,
}
model(dummy_batch)


# Export as a SavedModel, exposing the proto-parsing function as the
# 'serving_tfrecord' signature.
# NOTE(review): TF 2.16 bundles Keras 3; if the exported variables cannot be
# found at load time, exporting via keras.export.ExportArchive may be worth
# trying. Not verified here.
tf.saved_model.save(
    model,
    'saved_model',
    signatures={'serving_tfrecord': serving_tfrecord_fn},
)

Scala code to load and run inference:

import scala.collection.JavaConverters._
import org.junit.Test
import org.tensorflow.{SavedModelBundle, Tensor, Session}
import org.tensorflow.example._
import org.tensorflow.types.{TFloat32, TString}
import com.google.protobuf.ByteString
import org.tensorflow.ndarray.buffer.{DataBuffers, FloatDataBuffer}
import scala.collection.JavaConverters._


/**
 * Smoke test for the exported SavedModel: loads it, lists the operations of
 * its graph and runs a single inference on a hand-built serialized
 * SequenceExample.
 */
class TensorFlowInference {

  /**
   * Loads the model with the "serve" tag, prints every graph operation name,
   * feeds one serialized SequenceExample and prints the fetched scores.
   *
   * The bundle, the input tensor and the run result are all released in
   * `finally` blocks, so nothing leaks when loading or running fails.
   */
  @Test
  def testLoadInference(): Unit = {
    // Local import: only this method needs it.
    import org.tensorflow.ndarray.NdArrays

    // Update the model path
    val modelPath = PATH

    // Graph node names used for feed/fetch.
    // NOTE(review): these must match the exported signature's input/output
    // tensors; confirm against the operation names printed below.
    val inputTensorName = "serving_tfrecord_protos"
    val outputTensorName = "StatefulPartitionedCall"

    // Load the model with the "serve" tag
    val savedModelBundle = SavedModelBundle.load(modelPath, "serve")
    try {
      println("Listing all operations (nodes) in the model:")
      savedModelBundle.graph().operations().asScala.foreach(op => println(op.name()))

      val serializedExampleBytes = buildSequenceExample().toByteArray

      // Feed the raw proto bytes. Going through a String would re-encode the
      // payload as UTF-8 inside the tensor and corrupt every byte >= 0x80.
      val tensorInput =
        TString.tensorOfBytes(NdArrays.vectorOfObjects[Array[Byte]](serializedExampleBytes))
      try {
        val outputTensors = savedModelBundle.session().runner()
          .feed(inputTensorName, tensorInput)
          .fetch(outputTensorName)
          .run()
        try {
          outputTensors.get(0) match {
            case scores: TFloat32 =>
              // Shape.size() is the total element count, whatever the rank.
              val outputData = new Array[Float](scores.shape().size().toInt)
              // Arguments: readOnly = false, makeCopy = false, so the buffer
              // writes straight into outputData.
              val dataBuffer: FloatDataBuffer = DataBuffers.of(outputData, false, false)
              scores.copyTo(dataBuffer)

              val rendered = outputData.mkString(", ")
              println(s"Model output: $rendered")
            case other =>
              throw new IllegalStateException(s"Expected a TFloat32 output but got: $other")
          }
        } finally {
          outputTensors.close()
        }
      } finally {
        tensorInput.close()
      }
    } finally {
      savedModelBundle.close()
    }
  }

  /**
   * Builds the dummy request: a SequenceExample with a 'query_id' context
   * feature and one-step 'rank', 'clicked' and 'text_match_score' feature lists.
   */
  private def buildSequenceExample(): SequenceExample = {
    // Context feature 'query_id' (string)
    val queryId = Feature.newBuilder()
      .setBytesList(BytesList.newBuilder().addValue(ByteString.copyFromUtf8("dummy_query_id")))
      .build()
    val contextFeatures = Features.newBuilder().putFeature("query_id", queryId)

    // Sequence features: int64 for rank/clicked, float for text_match_score
    val featureLists = FeatureLists.newBuilder()
      .putFeatureList("rank", createFeatureList(Seq(1L)))
      .putFeatureList("clicked", createFeatureList(Seq(0L)))
      .putFeatureList("text_match_score", createFeatureList(Seq(0.5f)))

    SequenceExample.newBuilder()
      .setContext(contextFeatures)
      .setFeatureLists(featureLists)
      .build()
  }

  /**
   * Turns each value into a single-valued Feature and collects them, in
   * order, into a FeatureList.
   *
   * Int and Long become int64 features, Float and Double become float
   * features (Double is narrowed to Float), String becomes a UTF-8 bytes
   * feature.
   *
   * @param values the per-step values of one sequence feature
   * @return the FeatureList holding one Feature per input value
   * @throws IllegalArgumentException if a value has any other type
   */
  private def createFeatureList(values: Seq[Any]): FeatureList = {
    val featureListBuilder = FeatureList.newBuilder()
    values.foreach { value =>
      val feature = Feature.newBuilder()
      value match {
        case v: Int =>
          feature.setInt64List(Int64List.newBuilder().addValue(v.toLong))
        case v: Long =>
          feature.setInt64List(Int64List.newBuilder().addValue(v))
        case v: Float =>
          feature.setFloatList(FloatList.newBuilder().addValue(v))
        case v: Double =>
          feature.setFloatList(FloatList.newBuilder().addValue(v.toFloat))
        case v: String =>
          feature.setBytesList(BytesList.newBuilder().addValue(ByteString.copyFromUtf8(v)))
        case other =>
          throw new IllegalArgumentException(s"Unsupported value type: $other")
      }
      featureListBuilder.addFeature(feature.build())
    }
    featureListBuilder.build()
  }
}