Could someone please offer some help or advice? I am working on a project that uses a TFLite model which I trained and tested in Python and then tried to integrate into my Android Studio Java application. Everything seems fine at first glance, but for some reason the output of the model is the ‘other’ class every time without fail, even though I have checked that the input data appears to have the right form, shape, etc. Here is the code that runs the model:
public class MainActivity extends AppCompatActivity {
// Text view that shows the upload state and, after inference, the actual label and raw model output.
private TextView classificationTextView;
// Rows returned by readCSVFromAssets(); row 0 is skipped when selecting input (presumably a header row — confirm).
private List<String[]> csvData;
// Zero-based index of the most recently used data row; starts at -1 so the first increment selects row 0.
private int currentRowIndex = -1;
// Generated TFLite model wrapper; created in onCreate and closed (then set to null) in onPause.
private LighterModel model;
/**
 * Sets up the layout, reads the CSV rows, loads the TFLite model and wires the upload button
 * so that each click classifies the next CSV row.
 *
 * @param savedInstanceState state passed through to the superclass implementation
 */
@Override
protected void onCreate(Bundle savedInstanceState) {
    super.onCreate(savedInstanceState);
    EdgeToEdge.enable(this);
    setContentView(R.layout.activity_main);
    classificationTextView = findViewById(R.id.classification);
    Button uploadButton = findViewById(R.id.uploadButton);
    classificationTextView.setText("Classification: Not Uploaded");
    // Load CSV data upon starting the app
    csvData = readCSVFromAssets();
    // Initialize the TFLite model eagerly; a failure is reported to the user by the helper
    ensureModelLoaded();
    uploadButton.setOnClickListener(v -> classifyNextRow());
    ViewCompat.setOnApplyWindowInsetsListener(findViewById(R.id.main), (v, insets) -> {
        Insets systemBars = insets.getInsets(WindowInsetsCompat.Type.systemBars());
        v.setPadding(systemBars.left, systemBars.top, systemBars.right, systemBars.bottom);
        return insets;
    });
}

/**
 * Makes sure {@code model} holds a usable instance, creating it when it is {@code null}
 * (first launch, a previous load failure, or after it was released in {@code onPause}).
 *
 * @return {@code true} if a model instance is available, {@code false} if loading failed
 */
private boolean ensureModelLoaded() {
    if (model != null) {
        return true;
    }
    try {
        model = LighterModel.newInstance(MainActivity.this);
        return true;
    } catch (IOException e) {
        Log.e("MainActivity", "Error loading the model", e);
        Toast.makeText(MainActivity.this, "Error loading the model: " + e.getMessage(), Toast.LENGTH_SHORT).show();
        return false;
    }
}

/**
 * Runs the model on the next CSV row and shows the row's label next to the raw output vector.
 * Does nothing when the model cannot be loaded or no input buffer is available.
 */
private void classifyNextRow() {
    // The model may have been released in onPause or may never have loaded; never dereference null.
    if (!ensureModelLoaded()) {
        return;
    }
    Pair<ByteBuffer, String> inputData = loadInputData();
    ByteBuffer byteBuffer = inputData.first;
    String actualLabel = inputData.second;
    if (byteBuffer == null) {
        return;
    }

    TensorBuffer inputFeature0 = TensorBuffer.createFixedSize(new int[]{1, 2160, 1}, DataType.FLOAT32);
    inputFeature0.loadBuffer(byteBuffer);
    Log.d("InputShape", Arrays.toString(inputFeature0.getShape()));
    Log.d("InputData", Arrays.toString(inputFeature0.getFloatArray()));
    Log.d("InputLabel", actualLabel);

    LighterModel.Outputs outputs = model.process(inputFeature0);
    TensorBuffer outputFeature0 = outputs.getOutputFeature0AsTensorBuffer();
    // Read the output once and reuse it for logging and display.
    float[] outputArray = outputFeature0.getFloatArray();
    Log.d("OutputShape", Arrays.toString(outputFeature0.getShape()));
    Log.d("RawOutput", Arrays.toString(outputArray));
    Log.d("Model Output", Arrays.toString(outputArray));

    // Display the actual label and predicted class
    String outputText = "Actual Label: " + actualLabel + "\nPrediction: " + Arrays.toString(outputArray);
    classificationTextView.setText(outputText);
}
/**
 * Closes the TFLite model when the activity is paused and clears the field.
 */
@Override
protected void onPause() {
    super.onPause();
    // Nothing to release if the model was never created or is already closed
    if (model == null) {
        return;
    }
    model.close();
    model = null;
}
/**
 * Builds the model input for the next CSV data row, wrapping around to the first data row
 * after the last one.
 *
 * <p>Row 0 of {@code csvData} is always skipped (presumably a header row — confirm against the
 * asset file). The last column of the selected row is returned as the label; every preceding
 * column is parsed as a float and written into a direct buffer.
 *
 * @return a pair of (input buffer, label); both elements are {@code null} when there are no
 *     data rows to read
 * @throws NumberFormatException if a feature column cannot be parsed as a float
 */
private Pair<ByteBuffer, String> loadInputData() {
    int numRows = (csvData == null) ? 0 : csvData.size() - 1;
    if (numRows <= 0) {
        // No data rows: bail out instead of hitting the modulo-by-zero below.
        return new Pair<ByteBuffer, String>(null, null);
    }

    currentRowIndex = (currentRowIndex + 1) % numRows;
    String[] selectedRow = csvData.get(currentRowIndex + 1);
    int numCols = selectedRow.length - 1;
    String label = selectedRow[numCols];

    // A fresh ByteBuffer is BIG_ENDIAN, but TFLite reads the raw bytes in the platform's native
    // order. Without this, Java-side reads of the buffer look correct while the interpreter
    // receives byte-swapped floats.
    ByteBuffer byteBuffer = ByteBuffer.allocateDirect(numCols * Float.BYTES)
            .order(java.nio.ByteOrder.nativeOrder());
    for (int i = 0; i < numCols; i++) {
        byteBuffer.putFloat(Float.parseFloat(selectedRow[i]));
    }
    byteBuffer.rewind();
    return new Pair<>(byteBuffer, label);
}
The output of the various logs looks like this:
InputShape [1, 2160, 1]
InputData [-0.245, -0.25, -0.245, -0.25, -0.255, -0.25, -0.235, -0.215, -0.195, -0.21, -0.2, -0.195, -0.17, -0.145, -0.15, -0.155, -0.16, -0.155, -0.145, -0.145, -0.145, -0.145, -0.15, -0.1…
InputLabel 4
OutputShape [1, 5]
RawOutput [0.0, 0.0, 0.0, 0.0, 1.0]
Model Output [0.0, 0.0, 0.0, 0.0, 1.0]