Can I use the TensorFlow Java native core platform for a ScaNN-based model?

Hi all, I trained a ScaNN-powered model, described here, in Python.

Tf version: 2.7.0
Environment: Mac and Ubuntu

I am trying to run inference with it in Java using the native tensorflow-core-platform version 0.4.0.

Java code:

        URL modelURL = Main.class.getClassLoader().getResource("model/1");
        String modelPath = Paths.get(modelURL.toURI()).toString();

        SavedModelBundle model = SavedModelBundle.load(modelPath, "serve");

I am getting error Op type not registered 'Scann>ScannSearchBatched' in binary running on xx. Make sure the Op and Kernel are registered in the binary running in this process. Note that if you are loading a saved graph which used ops from tf.contrib, accessing (e.g.) tf.contrib.resampler should be done before importing the graph, as contrib ops are lazily registered when the module is first accessed.

Sometimes I get TensorFlowException: Could not find SavedModel .pb or .pbtxt at supplied export directory path

Could you please let me know what’s going on?
I know Mac is not supported by ScaNN, but it’s not working on Ubuntu either.
So, can I use the TensorFlow Java native core platform for a ScaNN-based model?

I’ve never played with ScaNN myself, but the error message says that this library provides custom ops that first need to be registered in TensorFlow.

To do this in Java, you must first explicitly load the native library that provides these custom ops by calling TensorFlow.loadLibrary. Normally the library can be extracted from the Python wheel; I’ve looked at its contents and I guess it is probably the scann/scann_ops/cc/_scann_ops.so file.

So to summarize:

TensorFlow.loadLibrary("/path/to/python/wheel/scann/scann_ops/cc/_scann_ops.so");

try (SavedModelBundle model = SavedModelBundle.load(modelPath, "serve")) {
   ...
}

Please let us know if that worked for you.

@karllessard: I am now getting:

2023-01-09 05:12:35.533897: I external/org_tensorflow/tensorflow/cc/saved_model/loader.cc:301] SavedModel load for tags { serve }; Status: fail: NOT_FOUND: Could not find SavedModel .pb or .pbtxt at supplied export directory path: model/1. Took 1004 microseconds.
Exception in thread "main" org.tensorflow.exceptions.TensorFlowException: Could not find SavedModel .pb or .pbtxt at supplied export directory path: model/1

This is my Dockerfile:

# Start with a base image containing Java runtime
FROM adoptopenjdk/openjdk11

# Install required compiler tools
RUN apt-get update && apt-get install -y software-properties-common

# Install Python 3.8 and pip
RUN apt-get update && apt-get install -y python3.8 python3-pip && apt-get update

# Install scann
RUN python3.8 -m pip install python-dev-tools
RUN python3.8 -m pip install tensorflow==2.7.0 && pip install scann==1.2.3
# Add a volume pointing to /tmp
VOLUME /tmp

# Make port 8080 available to the world outside this container
EXPOSE 8080


RUN mkdir /app
WORKDIR /app
COPY . /app
RUN ./gradlew clean build


COPY build/libs/core-tf-1.0-SNAPSHOT.jar /opt/core_tf.jar
CMD ["java","-jar","/opt/core_tf.jar"]

This is my main code:

        TensorFlow.loadLibrary("/usr/local/lib/python3.8/dist-packages/scann/scann_ops/cc/_scann_ops.so");
        String filePath = "model/1";
        try (SavedModelBundle model = SavedModelBundle.load(filePath, "serve")) {
            // Set up input tensors
            TString inputArray = TString.vectorOf("42");
            Tensor inputTensor = TString.tensorOf(inputArray);
            System.out.println(inputTensor.shape());

            Map<String, Tensor> feed_dict = new HashMap<>();
            feed_dict.put("input_1", inputTensor);

            // Run the model and get the output
            System.out.println(model.function("predict").call(feed_dict));
        }
        catch (Exception ex){
            ex.printStackTrace();
        }

That error seems unrelated to loading scann: it cannot find the model at model/1, and I don’t see where you copy it into your Docker container either. Make sure the container has access to your model, and use absolute paths whenever possible. Or is the model located in the class/resource path of your JAR?
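
As a quick sanity check before loading, something like the sketch below can show what a relative path such as model/1 actually resolves to inside the container. It uses only the Java standard library; the class and method names (SavedModelCheck, looksLikeSavedModel, requireSavedModel) are made up for illustration and are not part of TensorFlow Java.

```java
import java.nio.file.Files;
import java.nio.file.Path;
import java.nio.file.Paths;

// Hypothetical helper (not part of TensorFlow Java): checks that a directory
// contains the file the SavedModel loader looks for.
final class SavedModelCheck {

    /** Returns true if dir is a directory containing saved_model.pb or saved_model.pbtxt. */
    static boolean looksLikeSavedModel(Path dir) {
        return Files.isDirectory(dir)
                && (Files.isRegularFile(dir.resolve("saved_model.pb"))
                    || Files.isRegularFile(dir.resolve("saved_model.pbtxt")));
    }

    /** Resolves modelDir against the JVM working directory and fails early if it is not a SavedModel. */
    static Path requireSavedModel(String modelDir) {
        Path absolute = Paths.get(modelDir).toAbsolutePath().normalize();
        if (!looksLikeSavedModel(absolute)) {
            throw new IllegalArgumentException(
                    "No saved_model.pb or saved_model.pbtxt under " + absolute);
        }
        return absolute;
    }
}
```

The returned path could then be passed on, for example SavedModelBundle.load(SavedModelCheck.requireSavedModel("model/1").toString(), "serve"). If the check fails, the exception message contains the absolute path that was tried, which makes a working-directory mix-up easy to spot.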

@karllessard: Yes, the model is located in the resource path (src/main/resources/model/1), and I am copying it into the JAR. If I open my JAR, I can see the model in the root model directory.

build.gradle:

jar {
    duplicatesStrategy(DuplicatesStrategy.EXCLUDE)
    manifest {
        attributes(
                'Main-Class': 'org.example.Main'
        )
    }
    from {
        configurations.runtimeClasspath.filter{ it.exists() }.collect { it.isDirectory() ? it : zipTree(it) }
    }
    from 'resources', {
        into 'resources'
    }
}

Yes, that won’t work as is, because SavedModelBundle expects the model path to be a directory in your file system. You need to convert the resource path to a file path. Something like this should do:

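// Paths.get(URI) turns the resource URL into a plain file-system path. This only
// works when the resource is a real directory on disk, not an entry inside a JAR.
var modelPath = Paths.get(Main.class.getResource("/model/1").toURI()).toString();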
try (var model = SavedModelBundle.load(modelPath, "serve")) {
   ...
}
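
If the model stays packaged inside the JAR, it is not a directory on the file system, so one option is to copy it out to a temporary directory first and load it from there. Below is a minimal sketch of that idea using only the Java standard library; ModelExtractor, copyTree and extractResourceDir are hypothetical names chosen for this example, and it assumes a plain JAR (not a nested or shaded layout with a custom class loader).

```java
import java.io.IOException;
import java.io.UncheckedIOException;
import java.net.URI;
import java.net.URISyntaxException;
import java.net.URL;
import java.nio.file.FileSystem;
import java.nio.file.FileSystems;
import java.nio.file.Files;
import java.nio.file.Path;
import java.nio.file.Paths;
import java.util.Map;
import java.util.stream.Stream;

// Hypothetical helper (not part of TensorFlow Java): copies a classpath
// directory, possibly located inside a JAR, to a temp directory on disk.
final class ModelExtractor {

    /** Recursively copies sourceDir into a fresh temporary directory and returns that directory. */
    static Path copyTree(Path sourceDir) throws IOException {
        Path targetDir = Files.createTempDirectory("saved-model");
        try (Stream<Path> paths = Files.walk(sourceDir)) {
            paths.forEach(p -> {
                // Going through toString() allows copying across file systems (JAR -> disk).
                Path dest = targetDir.resolve(sourceDir.relativize(p).toString());
                try {
                    if (Files.isDirectory(p)) {
                        Files.createDirectories(dest);
                    } else {
                        Files.copy(p, dest);
                    }
                } catch (IOException e) {
                    throw new UncheckedIOException(e);
                }
            });
        }
        return targetDir;
    }

    /** Locates resourceDir (e.g. "/model/1") on the classpath and extracts it to disk. */
    static Path extractResourceDir(Class<?> anchor, String resourceDir)
            throws IOException, URISyntaxException {
        URL url = anchor.getResource(resourceDir);
        if (url == null) {
            throw new IOException("Resource not found on classpath: " + resourceDir);
        }
        URI uri = url.toURI();
        if ("jar".equals(uri.getScheme())) {
            // Mount the JAR as a file system so its entries can be walked like directories.
            try (FileSystem jarFs = FileSystems.newFileSystem(uri, Map.of())) {
                return copyTree(jarFs.getPath(resourceDir));
            }
        }
        return copyTree(Paths.get(uri));
    }
}
```

The result of ModelExtractor.extractResourceDir(Main.class, "/model/1").toString() could then be passed to SavedModelBundle.load. Mounting the model into the container and pointing at it with an absolute path avoids the copy entirely, which may be preferable for large models.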

@karllessard: OK, thanks! Exec’ing into the container and running the JAR now gives me:

Exception in thread "main" java.lang.UnsatisfiedLinkError: /usr/local/lib/python3.8/dist-packages/scann/scann_ops/cc/_scann_ops.so: undefined symbol: _ZN10tensorflow11ResourceMgr8DoCreateERKSsNS_9TypeIndexES2_PNS_12ResourceBaseE
        at org.tensorflow.TensorFlow.loadLibrary(TensorFlow.java:102)
        at org.example.Main.main(Main.java:33)

Hmm, it looks like the version of TensorFlow that the scann library was compiled with (2.6.0) is not compatible with the one TensorFlow Java 0.4.x uses (2.7.0). That’s surprising, but can you try scann==1.2.5 instead, which was compiled with TF 2.7.0?

@karllessard: OK, the resource path approach still doesn’t work for me.

When I exec into the container and pass an absolute path, I now get this exception:

java.lang.IllegalArgumentException: Function with signature [predict] not found
        at org.tensorflow.SavedModelBundle.function(SavedModelBundle.java:437)
        at org.example.Main.main(Main.java:45)

Do you see an obvious issue with the predict call? I assume the predict signature is not being exported. This is how I am exporting the SavedModel: Google Colab

I also tried using the session directly:

try (SavedModelBundle model = SavedModelBundle.load(filePath, "serve")) {
            // Set up input tensors
            Session session = model.session();

            TString inputArray = TString.vectorOf("42");
            Tensor inputTensor = TString.tensorOf(inputArray);
            System.out.println(inputTensor.shape());

            Tensor output_tensor = session.runner().feed("input_1", inputTensor).fetch("output_1").run().get(0);

            // Run the model and get the output
            System.out.println(output_tensor.toString());
        }

But I am getting java.lang.IllegalArgumentException: No Operation named [input_1] in the Graph

Here are the model specifics

signature_def['serving_default']:
  The given SavedModel SignatureDef contains the following input(s):
    inputs['input_1'] tensor_info:
        dtype: DT_STRING
        shape: (-1)
        name: serving_default_input_1:0
  The given SavedModel SignatureDef contains the following output(s):
    outputs['output_1'] tensor_info:
        dtype: DT_FLOAT
        shape: unknown_rank
        name: StatefulPartitionedCall_1:0
    outputs['output_2'] tensor_info:
        dtype: DT_STRING
        shape: unknown_rank
        name: StatefulPartitionedCall_1:1
  Method name is: tensorflow/serving/predict

Please help. I think we are close.

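You’ll probably get better results if you call your model via a function instead of calling the session runner directly. Note that your model only exports the serving_default signature, so there is no function named predict. In TensorFlow Java 0.4.0, you can try something like this: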

try (SavedModelBundle model = SavedModelBundle.load(filePath, "serve")) {
    try (TString inputTensor = TString.vectorOf("42")) {
        System.out.println(inputTensor.shape());

        try (Tensor outputTensor = model.function(Signature.DEFAULT_KEY).call(Map.of("input_1", inputTensor)).get("output_1")) {
            // output_1 is DT_FLOAT, so cast to TFloat32. getFloat() without coordinates
            // reads a scalar; pass indices (e.g. getFloat(0, 0)) for higher-rank outputs.
            System.out.println(((TFloat32) outputTensor).getFloat());
        }
    }
}

If you want to keep using the session directly, then pass and fetch the tensors by their names instead of their signature keys, i.e. serving_default_input_1:0 and StatefulPartitionedCall_1:0.
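
To make the difference between signature keys and tensor names concrete: a tensor name has the form operation_name:output_index, which is why the graph has no operation called input_1. The sketch below is plain Java with no TensorFlow dependency; the TensorName class is hypothetical, and the mapping is simply copied from the serving_default signature dump above.

```java
import java.util.Map;

// Hypothetical helper (not part of TensorFlow Java): splits a tensor name
// such as "StatefulPartitionedCall_1:1" into operation name and output index.
final class TensorName {

    // Signature key -> tensor name, as listed in the serving_default signature above.
    static final Map<String, String> SERVING_DEFAULT = Map.of(
            "input_1", "serving_default_input_1:0",
            "output_1", "StatefulPartitionedCall_1:0",
            "output_2", "StatefulPartitionedCall_1:1");

    final String operation;
    final int outputIndex;

    private TensorName(String operation, int outputIndex) {
        this.operation = operation;
        this.outputIndex = outputIndex;
    }

    /** Parses "op_name:index"; a bare "op_name" is treated as output 0. */
    static TensorName parse(String name) {
        int colon = name.lastIndexOf(':');
        if (colon < 0) {
            return new TensorName(name, 0);
        }
        return new TensorName(name.substring(0, colon),
                Integer.parseInt(name.substring(colon + 1)));
    }
}
```

With the session runner, the values of this map (for example TensorName.SERVING_DEFAULT.get("input_1")) are what would go into feed and fetch, while the keys are what the function call with Map.of("input_1", ...) expects.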

Thanks!! It works now.

Wonderful :tada: Please do not hesitate if you have any more questions!

@karllessard: Thanks :slight_smile: Sure! It would be great if I could get some help on this other thread: http://discuss.ai.google.dev/t/how-to-build-a-preprocessing-layer-with-different-preprocessing-for-each-feature/14248