The accuracy of my model differs significantly between the Android environment and the Windows environment


This is the model I built. Each input has shape (batch_size, 300, 3), and the model takes 4 such inputs in total. If I feed it random inputs like this, the output is a probability for each of the 7 classes we want to predict.

# python
num_input = 4

# Build one random tensor of shape (1, 300, 3) per model input. np.random.rand
# draws float64 values in [0, 1).
list_input = [tf.convert_to_tensor(np.random.rand(1, 300, 3)) for _ in range(num_input)]

output = multi_cnn(list_input)

[Screenshot 2023-10-23 215228 (image not included)]

What I want to do is convert this model to TFLite so that it can be used in the Android environment.

So I converted it to TFLite format, loaded it back in the Windows environment, and measured the accuracy on my test dataset. Here's the code:

import numpy as np
import tensorflow as tf
import pandas as pd

x_test = pd.read_csv('.\\data2\\x_test.csv')
y_test = pd.read_csv('.\\data2\\y_test.csv')

# Load the converted model once; the same interpreter is reused for every sample.
interpreter2 = tf.lite.Interpreter(model_path='.\\tflite\\multi_model_h5_ver2.tflite')
interpreter2.allocate_tensors()

input_details = interpreter2.get_input_details()
output_details = interpreter2.get_output_details()

input_shape = input_details[0]['shape']
output_shape = output_details[0]['shape']

# 12 float32 columns per row; every 300 consecutive rows make up one sample.
x_test = x_test.to_numpy(dtype=np.float32).reshape(-1, 300, 12)

output = []
for sample in x_test:
    batch = sample.reshape(1, 300, 12)
    # Columns 3k..3k+2 are fed to input_details[k] (k = 0..3).
    # NOTE(review): this assumes input_details is ordered like the Keras model's
    # inputs. The converter does not necessarily preserve that order, so verify
    # via input_details[k]['name'].
    for k in range(4):
        interpreter2.set_tensor(input_details[k]['index'], batch[:, :, 3 * k:3 * k + 3])
    interpreter2.invoke()
    output.append(interpreter2.get_tensor(output_details[0]['index']))

When I run this, I get an accuracy of about 54%.
Now let's run the same evaluation on Android.

Here's the code:

package com.example.accuracytest;

import androidx.appcompat.app.AppCompatActivity;

import android.content.Context;
import android.os.AsyncTask;
import android.os.Build;
import android.os.Bundle;

import android.content.res.AssetFileDescriptor;
import android.util.Log;
import android.view.View;
import android.widget.Button;
import android.widget.TextView;

import org.tensorflow.lite.Interpreter;
import org.tensorflow.lite.nnapi.NnApiDelegate;

import java.io.BufferedReader;
import java.io.FileInputStream;
import java.io.IOException;
import java.io.InputStream;
import java.io.InputStreamReader;
import java.lang.reflect.Array;
import java.nio.MappedByteBuffer;
import java.nio.channels.FileChannel;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.LinkedList;
import java.util.List;
import java.util.Map;
import java.util.Queue;

/**
 * Evaluates the bundled TFLite model on CSV files from the app's assets and shows the accuracy.
 *
 * <p>The X file is consumed in windows of {@link #WINDOW_SIZE} rows and the Y file one label per
 * window. Each window is classified, and the arg-max of the output is compared with its label.
 */
public class MainActivity extends AppCompatActivity {
    /** Number of sensor streams per sample, i.e. number of model inputs. */
    private static final int NUM_SENSORS = 4;
    /** Rows (time steps) per sample window. */
    private static final int WINDOW_SIZE = 300;
    private static final int BATCH_SIZE = 1;
    private static final int NUM_CLASSES = 7;

    private TextView resultView;
    private Button startButton;
    private MultiModel cls;
    /** True only if the model loaded successfully; inference is refused otherwise. */
    private boolean modelReady = false;
    /** The evaluation currently running, if any; cancelled in onDestroy. */
    private CsvReadTask readTask;

    String csvXFile = "x1_train_ver2.csv";
    String csvYFile = "y1_train_ver2.csv";
    // Evaluation counters. They are written in doInBackground and read in onPostExecute
    // (AsyncTask makes those writes visible there). They are reset at the start of every run.
    float answerSize = 0;
    int rightAnswer = 0;

    // Initialize View, Model
    @Override
    protected void onCreate(Bundle savedInstanceState) {
        super.onCreate(savedInstanceState);
        setContentView(R.layout.activity_main);
        resultView = findViewById(R.id.resultView);
        startButton = findViewById(R.id.startButton);

        cls = new MultiModel(this);

        try {
            cls.init();
            modelReady = true;
        } catch (IOException e) {
            // Keep the cause in the log so the reason for the failure is visible.
            Log.e("Model", "Load Model Fail..", e);
        }
    }

    /**
     * Starts one evaluation pass over {@link #csvXFile} / {@link #csvYFile}.
     *
     * @param view the clicked view (unused)
     */
    public void runInference(View view) {
        if (!modelReady) {
            resultView.setText("Model is not loaded");
            return;
        }
        // Reset so that repeated runs do not accumulate onto earlier totals.
        answerSize = 0;
        rightAnswer = 0;
        // Allow only one evaluation at a time; two tasks would race on the counters.
        startButton.setEnabled(false);
        readTask = new CsvReadTask();
        readTask.execute(csvXFile, csvYFile);
        startButton.setText("Loading....");
    }

    /**
     * Background task: params[0] is the X CSV asset and params[1] is the Y CSV asset.
     * The result is null on success or an error description on failure.
     */
    private class CsvReadTask extends AsyncTask<String, Void, String> {

        // csv file read and preprocessing
        @Override
        protected String doInBackground(String... params) {
            Context context = MainActivity.this;
            CsvStreamReader xReader = new CsvStreamReader(context, params[0]);
            CsvStreamReader yReader = new CsvStreamReader(context, params[1]);
            try {
                float[][][][] inputList;
                int label;
                // The X window is read first; the loop stops as soon as either file runs out.
                while (!isCancelled()
                        && (inputList = xReader.readNextLines(WINDOW_SIZE)) != null
                        && (label = yReader.readLabel()) != -1) {
                    ++answerSize;
                    processData(inputList, label);
                }
                return null;
            } catch (RuntimeException e) {
                // Boundary catch: report malformed CSV rows or model errors in the UI
                // instead of crashing the worker thread.
                Log.e("CsvReadTask", "Evaluation failed", e);
                return e.toString();
            } finally {
                xReader.close();
                yReader.close();
            }
        }

        // When Background task ends, run 'onPostExecute'
        @Override
        protected void onPostExecute(String error) {
            startButton.setEnabled(true);
            if (error != null) {
                resultView.setText("Error : " + error);
                startButton.setText("Failed");
                return;
            }
            if (answerSize == 0) {
                // Avoid showing NaN when no complete window/label pair was read.
                resultView.setText("No samples were read");
                startButton.setText("Complete");
                return;
            }
            float accuracy = (rightAnswer / answerSize) * 100;

            resultView.setText("Accuracy : " + String.format("%.2f", accuracy) + '%');
            startButton.setText("Complete");
        }
    }

    /**
     * Classifies one window and counts it as correct if the arg-max matches {@code label}.
     *
     * @param input_list one array per model input, each shaped [1][WINDOW_SIZE][3]
     * @param label expected class index
     * @throws IllegalArgumentException if {@code input_list} does not hold one array per sensor
     */
    private void processData(float[][][][] input_list, int label) {
        // Each top-level element is bound to one model input. A shorter array would leave the
        // remaining model inputs unwritten for this call and silently corrupt the prediction,
        // so fail loudly instead.
        if (input_list.length != NUM_SENSORS) {
            throw new IllegalArgumentException(
                    "Expected " + NUM_SENSORS + " sensor inputs but got " + input_list.length);
        }

        float[][] output = new float[BATCH_SIZE][NUM_CLASSES];
        Map<Integer, Object> outputs = new HashMap<>();
        outputs.put(0, output);

        cls.classify(input_list, outputs);

        Log.d("output", "print : " + Arrays.toString(output[0]));

        if (argmax(output) == label) {
            rightAnswer += 1;
        }
    }

    /**
     * Returns the index of the largest score in a single-row output (the first one wins on ties).
     *
     * @throws IllegalArgumentException if {@code output} is not exactly one non-empty row
     */
    private int argmax(float[][] output) {
        if (output.length != 1 || output[0].length == 0) {
            throw new IllegalArgumentException("Invalid output shape");
        }
        float[] scores = output[0];
        int maxIndex = 0;
        for (int i = 1; i < scores.length; i++) {
            if (scores[i] > scores[maxIndex]) {
                maxIndex = i;
            }
        }
        return maxIndex;
    }

    @Override
    protected void onDestroy() {
        // Stop the loop before the model is closed. A classify() call that is already in flight
        // is not interrupted by this.
        if (readTask != null) {
            readTask.cancel(false);
        }
        cls.finish();
        super.onDestroy();
    }
}

Additionally, there is a CsvStreamReader class that reads the CSV files and a MultiModel class that loads the model.

CsvStreamReader.java

package com.example.accuracytest;

import android.content.Context;
import android.content.res.AssetManager;
import android.util.Log;

import java.io.BufferedReader;
import java.io.IOException;
import java.io.InputStream;
import java.io.InputStreamReader;

public class CsvStreamReader {

    private static final String TAG = "CsvStreamReader";
    private BufferedReader reader;
    int test_value = 0;
    public CsvStreamReader(Context context, String fileName) {
        AssetManager assetManager = context.getAssets();
        InputStream inputStream = null;
        try {
            inputStream = assetManager.open(fileName);
            reader = new BufferedReader(new InputStreamReader(inputStream));
        } catch (IOException e) {
            Log.e(TAG, "Error opening CSV file: " + e.getMessage());
        }
    }
    public float[][][][] readNextLines(int linesToRead) {
        int batchSize = 1;
        int inputSize = 300;
        int inputChannels = 3;

        int batchIndex = 0;

        float[][][] gravity = new float[batchSize][inputSize][inputChannels];
        float[][][] linear_accel = new float[batchSize][inputSize][inputChannels];
        float[][][] gyroscope = new float[batchSize][inputSize][inputChannels];
        float[][][] magnet = new float[batchSize][inputSize][inputChannels];

        try {
            for (int i = 0; i < linesToRead; i++) {
                String line;

                if ((line = reader.readLine()) != null) {

                    String[] values = line.split(",");

                    gravity[batchIndex][i][0] = Float.parseFloat(values[0]);
                    gravity[batchIndex][i][1] = Float.parseFloat(values[1]);
                    gravity[batchIndex][i][2] = Float.parseFloat(values[2]);

                    linear_accel[batchIndex][i][0] = Float.parseFloat(values[3]);
                    linear_accel[batchIndex][i][1] = Float.parseFloat(values[4]);
                    linear_accel[batchIndex][i][2] = Float.parseFloat(values[5]);

                    gyroscope[batchIndex][i][0] = Float.parseFloat(values[6]);
                    gyroscope[batchIndex][i][1] = Float.parseFloat(values[7]);
                    gyroscope[batchIndex][i][2] = Float.parseFloat(values[8]);

                    magnet[batchIndex][i][0] = Float.parseFloat(values[9]);
                    magnet[batchIndex][i][1] = Float.parseFloat(values[10]);
                    magnet[batchIndex][i][2] = Float.parseFloat(values[11]);

                } else {
                    return null;
                }
            }
        } catch (IOException e) {
            Log.e(TAG, "Error reading input CSV file: " + e.getMessage());
        }
        float[][][][] input_list = new float[][][][] {magnet};

        return input_list;
    }
    public int readLabel() {
        int label = 0;

        try {
            String line;

            if ((line = reader.readLine()) != null) {
                String[] values = line.split(",");
                label = (int)Float.parseFloat(values[0]);

            } else {
                return -1;
            }
        } catch (IOException e) {
            Log.e(TAG, "Error reading Label CSV file: " + e.getMessage());
        }

        return label;
    }
    public static void print4DArrayShape(float[][][][] array4D) {
        int batch = array4D.length;
        int depth = array4D[0].length;
        int height = array4D[0][0].length;
        int width = array4D[0][0][0].length;

        Log.d("ArrayShape", "Shape: [" + batch + "][" + depth + "][" + height + "][" + width + "]");
    }

    public void close() {
        try {
            if (reader != null) {
                reader.close();
            }
        } catch (IOException e) {
            Log.e(TAG, "Error closing CSV file: " + e.getMessage());
        }
    }
}

MultiModel.java

package com.example.accuracytest;

import android.content.Context;
import android.util.Log;

import org.tensorflow.lite.Tensor;
import org.tensorflow.lite.support.model.Model;

import java.io.IOException;
import java.nio.MappedByteBuffer;
import java.util.HashMap;
import java.util.Map;

/** Thin wrapper around a TFLite {@link Model} loaded from the app's assets. */
public class MultiModel {
    private static final String MODEL_NAME = "multi_model_h5_ver2.tflite";
    Context context;
    /** Null until {@link #init()} succeeds and after {@link #finish()}. */
    Model model;

    public MultiModel(Context context) {this.context = context;}

    /**
     * Loads {@value #MODEL_NAME} from assets and logs the shape of input tensor 0.
     *
     * @throws IOException if the model file cannot be read
     */
    public void init() throws IOException {
        model = Model.createModel(context, MODEL_NAME);
        printShape();
    }

    /**
     * Runs one inference.
     *
     * @param input_list one array per model input, in the same order as the model's inputs
     *     (element i is fed to input tensor i)
     * @param outputs map from output index to a pre-allocated array that receives the result
     * @throws IllegalStateException if the model is not loaded (init failed or finish was called)
     */
    public void classify(float[][][][] input_list, Map<Integer, Object> outputs) {
        if (model == null) {
            throw new IllegalStateException("Model is not loaded; call init() first");
        }
        model.run(input_list, outputs);
    }

    /** Logs every dimension of input tensor 0, one line per dimension. */
    public void printShape() {
        Tensor inputTensor = model.getInputTensor(0);
        int[] inputShape = inputTensor.shape();

        // Iterate over the actual rank instead of assuming exactly three dimensions.
        for (int i = 0; i < inputShape.length; i++) {
            Log.e("Shape", "input shape" + (i + 1) + " : " + inputShape[i]);
        }
    }

    /** Releases the model. Safe to call more than once; classify fails fast afterwards. */
    public void finish() {
        if (model != null) {
            model.close();
            model = null;
        }
    }
}