I followed the steps in the "Training Custom Object Detector" page of the TensorFlow 2 Object Detection API tutorial documentation to train my custom object detection model. For reference, I'm using TF 2.10. However, after converting the model to TFLite and using it in an Android application written in Java, I get the following error:
EXCEPTION: Failed on interpreter inference -> Cannot copy from a TensorFlowLite tensor (StatefulPartitionedCall:1) with shape [1,10] to a Java object with shape [1,10,4].
Prior to TensorFlow 2.6, the metadata order was boxes, classes, scores, number of detections. Now, it seems to have changed to scores, boxes, number of detections, classes.
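To make the reordering concrete, here is a minimal sketch of the index mapping between the two layouts. It assumes the two orders listed above are correct; the role names and the helper `legacy_positions` are my own labels for illustration, not part of any TensorFlow API.

```python
# Order the app was written against (pre-2.6, as described above).
LEGACY_ORDER = ("boxes", "classes", "scores", "num_detections")

# Order the newer exports seem to use (assumption taken from the observation above).
NEW_ORDER = ("scores", "boxes", "num_detections", "classes")


def legacy_positions(new_order=NEW_ORDER, legacy_order=LEGACY_ORDER):
    """For each legacy output slot, return the index of that tensor in the new order."""
    return [new_order.index(role) for role in legacy_order]


positions = legacy_positions()
for slot, (role, src) in enumerate(zip(LEGACY_ORDER, positions)):
    # Example reading: legacy slot 0 (boxes) now sits at output index 1.
    print(f"legacy slot {slot} ({role}) <- new output index {src}")
```

If these orders hold, output index 0 now carries scores with shape [1,10] while the app still allocates a [1,10,4] buffer for boxes at that index, which is consistent with the exception above.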
I have tried two things: 1) Downgrading to TF 2.5, which solves this problem but raises incompatibility issues with other libraries, so I would rather not use this approach. 2) Declaring the sequence of outputs explicitly using the metadata writer, based on one of the suggestions on here; however, this still raises the same exception as stated above. After loading the model produced by the metadata writer and inspecting its output details, I see the following:
[{'name': 'StatefulPartitionedCall:1', 'index': 249, 'shape': array([ 1, 10]), 'shape_signature': array([ 1, 10]), 'dtype': <class 'numpy.float32'>, 'quantization': (0.0, 0), 'quantization_parameters': {'scales': array([], dtype=float32), 'zero_points': array([], dtype=int32), 'quantized_dimension': 0}, 'sparsity_parameters': {}},
 {'name': 'StatefulPartitionedCall:3', 'index': 247, 'shape': array([ 1, 10, 4]), 'shape_signature': array([ 1, 10, 4]), 'dtype': <class 'numpy.float32'>, 'quantization': (0.0, 0), 'quantization_parameters': {'scales': array([], dtype=float32), 'zero_points': array([], dtype=int32), 'quantized_dimension': 0}, 'sparsity_parameters': {}},
 {'name': 'StatefulPartitionedCall:0', 'index': 250, 'shape': array([1]), 'shape_signature': array([1]), 'dtype': <class 'numpy.float32'>, 'quantization': (0.0, 0), 'quantization_parameters': {'scales': array([], dtype=float32), 'zero_points': array([], dtype=int32), 'quantized_dimension': 0}, 'sparsity_parameters': {}},
 {'name': 'StatefulPartitionedCall:2', 'index': 248, 'shape': array([ 1, 10]), 'shape_signature': array([ 1, 10]), 'dtype': <class 'numpy.float32'>, 'quantization': (0.0, 0), 'quantization_parameters': {'scales': array([], dtype=float32), 'zero_points': array([], dtype=int32), 'quantized_dimension': 0}, 'sparsity_parameters': {}}]
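As a sanity check, the output details above can be compared against the legacy layout by tensor rank alone. This is a rough sketch: the names and shapes are copied from the output above, while `LEGACY_RANKS` and `rank_mismatches` are hypothetical helpers of my own, and the legacy ranks assume boxes [1,N,4], classes [1,N], scores [1,N], number of detections [1].

```python
# Names and shapes copied from the output details above, in the order reported.
# With a live interpreter, the same list could be built with:
#   [(d["name"], tuple(d["shape"])) for d in interpreter.get_output_details()]
observed = [
    ("StatefulPartitionedCall:1", (1, 10)),
    ("StatefulPartitionedCall:3", (1, 10, 4)),
    ("StatefulPartitionedCall:0", (1,)),
    ("StatefulPartitionedCall:2", (1, 10)),
]

# Ranks the app expects at each output index (assumed legacy layout):
# boxes, classes, scores, number of detections.
LEGACY_RANKS = [3, 2, 2, 1]


def rank_mismatches(details, expected_ranks=LEGACY_RANKS):
    """Return the output indices whose tensor rank differs from the legacy layout."""
    return [
        i
        for i, ((_, shape), want) in enumerate(zip(details, expected_ranks))
        if len(shape) != want
    ]


mismatches = rank_mismatches(observed)
print("output indices that disagree with the legacy layout:", mismatches)
```

Rank alone cannot tell classes from scores, since both are [1,10], so this only shows that the metadata writer step did not change the tensor order; it does not prove which [1,10] tensor is which.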
The order of the shapes displayed still does not match the order of boxes, classes, scores, number of detections. Without modifying the Android app code, is there anything else that can be done to keep the outputs in that order during the TFLite conversion?