TFLite convert() reduces model input shape

I am converting a TF model to TFLite, but the conversion seems to change the input shape for no apparent reason:

    import numpy as np
    import tensorflow as tf
    from basic_pitch import ICASSP_2022_MODEL_PATH

    # Load the original SavedModel and run it on a batch of 11 windows.
    model_best_dir = str(ICASSP_2022_MODEL_PATH)
    model_original = tf.saved_model.load(model_best_dir)

    input_original = tf.convert_to_tensor(np.ones([11, 43844, 1]), dtype=tf.float32)
    output_original = model_original(input_original)

    # Convert the SavedModel to a TensorFlow Lite model.
    # No optimizations are set, so despite the file name the weights are not quantized.
    converter = tf.lite.TFLiteConverter.from_saved_model(model_best_dir)
    tflite_model = converter.convert()
    with open('quantized_model.tflite', 'wb') as f:
        f.write(tflite_model)

    # Load the converted model and inspect its input tensor.
    interpreter = tf.lite.Interpreter(model_path='quantized_model.tflite')
    interpreter.allocate_tensors()
    input_quantized = interpreter.get_input_details()

    print(input_quantized[0]['shape'], input_original.shape)

The output is:

    [    1 43844     1] (11, 43844, 1)

Note that I used from_saved_model() rather than from_keras_model(). Why is this happening?

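This seems to solve the issue:

    interpreter = tf.lite.Interpreter(model_path='quantized_model.tflite')
    # Resize the input to the real batch size before allocating tensors.
    input_index = interpreter.get_input_details()[0]['index']
    interpreter.resize_tensor_input(input_index, audio_windowed.shape, strict=True)
    interpreter.allocate_tensors()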
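A self-contained sketch of why the resize call above works, assuming TensorFlow 2.x is installed and using a hypothetical toy model (`Doubler`, a `tf.Module` that just doubles its input) instead of basic_pitch. A dimension left as `None` in the SavedModel signature is reported as 1 in `shape` but kept as -1 in `shape_signature`, and `resize_tensor_input(..., strict=True)` only allows changing those -1 dimensions. The sketch also feeds data with `set_tensor()` before `invoke()`.

```python
import tempfile

import numpy as np
import tensorflow as tf


class Doubler(tf.Module):
    """Hypothetical toy model with a dynamic batch dimension: (None, 8, 1)."""

    @tf.function(input_signature=[tf.TensorSpec([None, 8, 1], tf.float32)])
    def __call__(self, x):
        return x * 2.0


# Export as a SavedModel and convert it, mirroring the from_saved_model() path.
module = Doubler()
export_dir = tempfile.mkdtemp()
tf.saved_model.save(module, export_dir, signatures=module.__call__)
tflite_model = tf.lite.TFLiteConverter.from_saved_model(export_dir).convert()

interpreter = tf.lite.Interpreter(model_content=tflite_model)
input_details = interpreter.get_input_details()[0]
print(input_details["shape"])            # placeholder shape, batch shown as 1
print(input_details["shape_signature"])  # -1 marks the dynamic batch dimension

# Resize the dynamic batch dimension, allocate, feed data, and run.
batch = np.ones([3, 8, 1], dtype=np.float32)
interpreter.resize_tensor_input(input_details["index"], batch.shape, strict=True)
interpreter.allocate_tensors()
interpreter.set_tensor(input_details["index"], batch)
interpreter.invoke()
output_details = interpreter.get_output_details()[0]
result = interpreter.get_tensor(output_details["index"])
print(result.shape)  # (3, 8, 1)
```

Every op in this toy model keeps the batch axis, so the output follows the resized input. The sketch does not reproduce the [1, 172, 88] output reported below and does not show what pins the batch axis inside the real model.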

but now the batch dimension of the output is smaller than the original's when I run:

    interpreter.set_tensor(interpreter.get_input_details()[0]['index'], audio_windowed)
    interpreter.invoke()
    output_details = interpreter.get_output_details()
    output = interpreter.get_tensor(output_details[0]['index'])

The output should be [3, 172, 88], but the TFLite model returns [1, 172, 88].
I could not find an equivalent call to fix the output shape.

Hi @Phys, AFAIK, when you convert a model to TFLite, the batch dimension is automatically set to 1, because during inference a TFLite model accepts only one input at a time. Thank you.
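If the converted model does have to run one sample at a time, as the reply suggests, a batch can still be processed by invoking the interpreter once per sample and stacking the results. A minimal NumPy sketch; `make_tflite_runner` and `run_one_at_a_time` are hypothetical helper names, and the interpreter is assumed to have a single input and a single output and to have been allocated with its default batch of 1.

```python
import numpy as np


def make_tflite_runner(interpreter):
    """Wrap an allocated tf.lite.Interpreter (batch size 1) as a callable.

    Hypothetical helper: assumes one input tensor and one output tensor.
    """
    input_details = interpreter.get_input_details()[0]
    output_details = interpreter.get_output_details()[0]

    def run_single(sample):
        # Cast to the dtype the converted model expects, then run once.
        interpreter.set_tensor(input_details["index"], sample.astype(input_details["dtype"]))
        interpreter.invoke()
        return interpreter.get_tensor(output_details["index"])

    return run_single


def run_one_at_a_time(run_single, batch):
    """Feed `batch` to a batch-1 model one slice at a time and stack the results."""
    # batch[i:i + 1] keeps the leading axis, so each call sees shape (1, ...).
    outputs = [run_single(batch[i:i + 1]) for i in range(batch.shape[0])]
    # Each result also has a leading axis of 1; concatenating restores the batch.
    return np.concatenate(outputs, axis=0)


if __name__ == "__main__":
    # Stand-in for a real model: any callable mapping (1, 8, 1) -> (1, 4, 1).
    fake_model = lambda x: x[:, :4] * 2.0
    stacked = run_one_at_a_time(fake_model, np.ones([3, 8, 1], dtype=np.float32))
    print(stacked.shape)  # (3, 4, 1)
```

With the real interpreter, `run_one_at_a_time(make_tflite_runner(interpreter), audio_windowed)` would return one output per window stacked along the first axis. It is slower than a single batched call but does not depend on resizing.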

Hi,
I was trying to convert a pretrained BERT model to TFLite, and after conversion the input shape becomes [1, 1].
Is this an issue with the conversion, or is it expected? Here are the input details:

    {'name': 'serving_default_attention_mask:0', 'index': 0, 'shape': array([1, 1]), 'shape_signature': array([-1, -1]), 'dtype': <class 'numpy.int32'>, 'quantization': (0.0, 0), 'quantization_parameters': {'scales': array([], dtype=float32), 'zero_points': array([], dtype=int32), 'quantized_dimension': 0}, 'sparsity_parameters': {}}
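In the BERT case, `shape_signature` is [-1, -1], so both the batch size and the sequence length are dynamic and [1, 1] is the placeholder `shape` for those two dimensions. A small pure-Python sketch of filling in the -1 entries before calling `resize_tensor_input`; `resize_shape` is a hypothetical helper, and the sequence length of 128 is only an example value.

```python
import numpy as np


def resize_shape(shape_signature, dynamic_sizes):
    """Fill the -1 entries of a TFLite shape_signature with concrete sizes.

    Hypothetical helper. Fixed dimensions are copied unchanged, which is the
    constraint resize_tensor_input(..., strict=True) enforces.
    """
    signature = [int(d) for d in shape_signature]
    n_dynamic = signature.count(-1)
    if len(dynamic_sizes) != n_dynamic:
        raise ValueError(
            f"expected {n_dynamic} sizes for signature {signature}, got {len(dynamic_sizes)}"
        )
    sizes = iter(dynamic_sizes)
    return [next(sizes) if d == -1 else d for d in signature]


# With a real interpreter (not run here), every input would be resized
# before allocate_tensors(), e.g. for batch 1 and sequence length 128:
#   for d in interpreter.get_input_details():
#       interpreter.resize_tensor_input(
#           d["index"], resize_shape(d["shape_signature"], [1, 128]), strict=True)
#   interpreter.allocate_tensors()

print(resize_shape(np.array([-1, -1]), [1, 128]))     # [1, 128]
print(resize_shape(np.array([-1, 43844, 1]), [3]))    # [3, 43844, 1]
```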