This is the model I built. Each input has shape (batch_size, 300, 3), and the model takes 4 inputs in total. When I feed it random inputs like the ones below, it outputs the probabilities of the 7 target classes.
# python
num_input = 4
# One random (1, 300, 3) tensor per model input. np.random.rand returns float64,
# so cast explicitly to float32 -- the dtype the TFLite interpreter is fed later.
list_input = [
    tf.convert_to_tensor(np.random.rand(1, 300, 3).astype(np.float32))
    for _ in range(num_input)
]
# multi_cnn is the Keras model defined elsewhere; the output holds the scores for the 7 classes.
output = multi_cnn(list_input)
I want to convert this model to TFLite so that it can be used on Android.
I converted it to the TFLite format, loaded it back on Windows, and measured its accuracy on my test dataset. Here’s the code:
import numpy as np
import tensorflow as tf
import pandas as pd
# NOTE(review): pd.read_csv treats the first row as a header by default, whereas the
# Android CsvStreamReader parses every line as data. Confirm both readers agree on
# whether these files have a header row, otherwise samples/labels are shifted.
x_test = pd.read_csv('.\\data2\\x_test.csv')
y_test = pd.read_csv('.\\data2\\y_test.csv')

interpreter2 = tf.lite.Interpreter(model_path='.\\tflite\\multi_model_h5_ver2.tflite')
interpreter2.allocate_tensors()
input_details = interpreter2.get_input_details()
output_details = interpreter2.get_output_details()
input_shape = input_details[0]['shape']
output_shape = output_details[0]['shape']

# Each CSV row holds 12 values (4 sensors x 3 axes); every 300 rows form one sample.
x_test = x_test.to_numpy(dtype=np.float32).reshape(-1, 300, 12)

output = []
for input_data in x_test:
    input_data = input_data.reshape(1, 300, 12)
    # Interpreter input k receives columns 3k..3k+2:
    # x1 = [:, :, 0:3], x2 = [:, :, 3:6], x3 = [:, :, 6:9], x4 = [:, :, 9:12].
    # The Android code must feed its inputs in this same order.
    for k, detail in enumerate(input_details):
        interpreter2.set_tensor(detail['index'], input_data[:, :, 3 * k:3 * k + 3])
    interpreter2.invoke()
    output.append(interpreter2.get_tensor(output_details[0]['index']))

# Accuracy: argmax of the class scores vs. the first column of y_test
# (the Android side likewise reads the label from the first CSV column).
predictions = np.argmax(np.concatenate(output, axis=0), axis=1)
labels = y_test.iloc[:, 0].to_numpy().astype(int)
if len(labels) != len(predictions):
    raise ValueError(f"{len(predictions)} samples but {len(labels)} labels")
accuracy = np.mean(predictions == labels) * 100
print(f"Accuracy : {accuracy:.2f}%")
With this code, the accuracy is about 54%.
Now let’s run the same evaluation on Android.
Here’s the code:
package com.example.accuracytest;
import androidx.appcompat.app.AppCompatActivity;
import android.content.Context;
import android.os.AsyncTask;
import android.os.Build;
import android.os.Bundle;
import android.content.res.AssetFileDescriptor;
import android.util.Log;
import android.view.View;
import android.widget.Button;
import android.widget.TextView;
import org.tensorflow.lite.Interpreter;
import org.tensorflow.lite.nnapi.NnApiDelegate;
import java.io.BufferedReader;
import java.io.FileInputStream;
import java.io.IOException;
import java.io.InputStream;
import java.io.InputStreamReader;
import java.lang.reflect.Array;
import java.nio.MappedByteBuffer;
import java.nio.channels.FileChannel;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.LinkedList;
import java.util.List;
import java.util.Map;
import java.util.Queue;
/**
 * Evaluates the bundled TFLite model on CSV assets and shows the resulting accuracy.
 *
 * <p>The feature CSV is read in windows of {@link #WINDOW_SIZE} rows; one label line is
 * consumed per window. Reading and inference run on a background {@link AsyncTask}.
 */
public class MainActivity extends AppCompatActivity {
    /** Rows (time steps) per inference window; each model input is (1, 300, 3). */
    private static final int WINDOW_SIZE = 300;
    /** Number of class scores produced by the model's single output. */
    private static final int NUM_CLASSES = 7;

    private TextView resultView;
    private Button startButton;
    private MultiModel cls;
    /** True only after cls.init() succeeded; inference is refused otherwise. */
    private boolean modelLoaded = false;
    /** Most recently started evaluation, kept so it can be cancelled in onDestroy(). */
    private CsvReadTask task;

    // NOTE(review): the desktop evaluation reads data2/x_test.csv and data2/y_test.csv,
    // while these assets are named *_train_*. Confirm they hold the same samples,
    // otherwise the Python and Android accuracies are not comparable.
    String csvXFile = "x1_train_ver2.csv";
    String csvYFile = "y1_train_ver2.csv";
    /** Number of evaluated windows (float so the accuracy division is floating-point). */
    float answerSize = 0;
    /** Number of windows whose argmax prediction matched the label. */
    int rightAnswer = 0;

    /** Initializes the views and loads the model. */
    @Override
    protected void onCreate(Bundle savedInstanceState) {
        super.onCreate(savedInstanceState);
        setContentView(R.layout.activity_main);
        resultView = findViewById(R.id.resultView);
        startButton = findViewById(R.id.startButton);
        cls = new MultiModel(this);
        try {
            cls.init();
            modelLoaded = true;
        } catch (IOException e) {
            Log.e("Model", "Load Model Fail..", e);
        }
    }

    /** Click handler: starts one full evaluation pass over the CSV assets. */
    public void runInference(View view) {
        if (!modelLoaded) {
            resultView.setText("Model not loaded");
            return;
        }
        // Reset so a second run does not accumulate the previous run's counts, and
        // disable the button so two tasks never update the counters concurrently.
        answerSize = 0;
        rightAnswer = 0;
        startButton.setEnabled(false);
        startButton.setText("Loading....");
        task = new CsvReadTask();
        task.execute(csvXFile, csvYFile);
    }

    /** Streams both CSV assets off the UI thread and accumulates the accuracy counters. */
    private class CsvReadTask extends AsyncTask<String, Void, Void> {
        /** Set when the evaluation aborted (e.g. malformed CSV); reported in onPostExecute. */
        private RuntimeException failure;

        /**
         * Reads window/label pairs until either file is exhausted or the task is cancelled.
         *
         * @param params params[0] = feature CSV asset name, params[1] = label CSV asset name
         */
        @Override
        protected Void doInBackground(String... params) {
            Context context = MainActivity.this;
            CsvStreamReader xReader = new CsvStreamReader(context, params[0]);
            CsvStreamReader yReader = new CsvStreamReader(context, params[1]);
            try {
                float[][][][] inputList;
                int label;
                while (!isCancelled()
                        && (inputList = xReader.readNextLines(WINDOW_SIZE)) != null
                        && (label = yReader.readLabel()) != -1) {
                    ++answerSize;
                    processData(inputList, label);
                }
            } catch (RuntimeException e) {
                Log.e("CsvReadTask", "Evaluation aborted", e);
                failure = e;
            } finally {
                xReader.close();
                yReader.close();
            }
            return null;
        }

        /** Runs on the UI thread after doInBackground finishes. */
        @Override
        protected void onPostExecute(Void result) {
            startButton.setEnabled(true);
            if (failure != null) {
                resultView.setText("Error : " + failure.getMessage());
                startButton.setText("Failed");
                return;
            }
            if (answerSize == 0) {
                // Avoid displaying "NaN%" when no complete window/label pair was read.
                resultView.setText("No samples evaluated");
            } else {
                float accuracy = (rightAnswer / answerSize) * 100;
                resultView.setText("Accuracy : " + String.format("%.2f", accuracy) + '%');
            }
            startButton.setText("Complete");
        }
    }

    /**
     * Classifies one window and counts it as correct if the argmax matches the label.
     *
     * @param input_list one array per model input, in model-input order
     * @param label expected class index
     */
    private void processData(float[][][][] input_list, int label) {
        float[][] output = new float[1][NUM_CLASSES];
        Map<Integer, Object> outputs = new HashMap<>();
        outputs.put(0, output);
        cls.classify(input_list, outputs);
        Log.d("output", "print : " + Arrays.toString(output[0]));
        if (argmax(output) == label) {
            rightAnswer += 1;
        }
    }

    /**
     * Returns the index of the largest score in a [1][NUM_CLASSES] output (first wins on ties).
     *
     * @throws IllegalArgumentException if the output is not shaped [1][NUM_CLASSES]
     */
    private int argmax(float[][] output) {
        if (output.length != 1 || output[0].length != NUM_CLASSES) {
            throw new IllegalArgumentException("Invalid output shape");
        }
        float[] scores = output[0];
        int maxIndex = 0;
        for (int i = 1; i < scores.length; i++) {
            if (scores[i] > scores[maxIndex]) {
                maxIndex = i;
            }
        }
        return maxIndex;
    }

    @Override
    protected void onDestroy() {
        if (task != null) {
            // Stops the background loop between windows. An inference already in
            // progress may still race with cls.finish() below.
            task.cancel(false);
        }
        if (cls != null) {
            cls.finish();
        }
        super.onDestroy();
    }
}
In addition, there is a CsvStreamReader class that reads the CSV files and a MultiModel class that loads the model.
CsvStreamReader.java
package com.example.accuracytest;
import android.content.Context;
import android.content.res.AssetManager;
import android.util.Log;
import java.io.BufferedReader;
import java.io.IOException;
import java.io.InputStream;
import java.io.InputStreamReader;
public class CsvStreamReader {
private static final String TAG = "CsvStreamReader";
private BufferedReader reader;
int test_value = 0;
public CsvStreamReader(Context context, String fileName) {
AssetManager assetManager = context.getAssets();
InputStream inputStream = null;
try {
inputStream = assetManager.open(fileName);
reader = new BufferedReader(new InputStreamReader(inputStream));
} catch (IOException e) {
Log.e(TAG, "Error opening CSV file: " + e.getMessage());
}
}
public float[][][][] readNextLines(int linesToRead) {
int batchSize = 1;
int inputSize = 300;
int inputChannels = 3;
int batchIndex = 0;
float[][][] gravity = new float[batchSize][inputSize][inputChannels];
float[][][] linear_accel = new float[batchSize][inputSize][inputChannels];
float[][][] gyroscope = new float[batchSize][inputSize][inputChannels];
float[][][] magnet = new float[batchSize][inputSize][inputChannels];
try {
for (int i = 0; i < linesToRead; i++) {
String line;
if ((line = reader.readLine()) != null) {
String[] values = line.split(",");
gravity[batchIndex][i][0] = Float.parseFloat(values[0]);
gravity[batchIndex][i][1] = Float.parseFloat(values[1]);
gravity[batchIndex][i][2] = Float.parseFloat(values[2]);
linear_accel[batchIndex][i][0] = Float.parseFloat(values[3]);
linear_accel[batchIndex][i][1] = Float.parseFloat(values[4]);
linear_accel[batchIndex][i][2] = Float.parseFloat(values[5]);
gyroscope[batchIndex][i][0] = Float.parseFloat(values[6]);
gyroscope[batchIndex][i][1] = Float.parseFloat(values[7]);
gyroscope[batchIndex][i][2] = Float.parseFloat(values[8]);
magnet[batchIndex][i][0] = Float.parseFloat(values[9]);
magnet[batchIndex][i][1] = Float.parseFloat(values[10]);
magnet[batchIndex][i][2] = Float.parseFloat(values[11]);
} else {
return null;
}
}
} catch (IOException e) {
Log.e(TAG, "Error reading input CSV file: " + e.getMessage());
}
float[][][][] input_list = new float[][][][] {magnet};
return input_list;
}
public int readLabel() {
int label = 0;
try {
String line;
if ((line = reader.readLine()) != null) {
String[] values = line.split(",");
label = (int)Float.parseFloat(values[0]);
} else {
return -1;
}
} catch (IOException e) {
Log.e(TAG, "Error reading Label CSV file: " + e.getMessage());
}
return label;
}
public static void print4DArrayShape(float[][][][] array4D) {
int batch = array4D.length;
int depth = array4D[0].length;
int height = array4D[0][0].length;
int width = array4D[0][0][0].length;
Log.d("ArrayShape", "Shape: [" + batch + "][" + depth + "][" + height + "][" + width + "]");
}
public void close() {
try {
if (reader != null) {
reader.close();
}
} catch (IOException e) {
Log.e(TAG, "Error closing CSV file: " + e.getMessage());
}
}
}
MultiModel.java
package com.example.accuracytest;
import android.content.Context;
import android.util.Log;
import org.tensorflow.lite.Tensor;
import org.tensorflow.lite.support.model.Model;
import java.io.IOException;
import java.nio.MappedByteBuffer;
import java.util.Arrays;
import java.util.HashMap;
import java.util.Map;
/** Thin wrapper around the TFLite Support {@link Model} for the bundled multi-input model. */
public class MultiModel {
    private static final String MODEL_NAME = "multi_model_h5_ver2.tflite";
    Context context;
    /** Null before init() succeeds and after finish(). */
    Model model;

    public MultiModel(Context context) {
        this.context = context;
    }

    /**
     * Loads the model from the app's assets and logs the shape of its first input.
     *
     * @throws IOException if the model file cannot be loaded
     */
    public void init() throws IOException {
        model = Model.createModel(context, MODEL_NAME);
        printShape();
    }

    /**
     * Runs one inference.
     *
     * @param input_list one array per model input, in the model's input order. Element i is
     *     fed to input i, so this must match the order used in the Python evaluation.
     * @param outputs map from output index to a pre-allocated array that receives the result
     * @throws IllegalStateException if init() has not succeeded or finish() was called
     */
    public void classify(float[][][][] input_list, Map<Integer, Object> outputs) {
        if (model == null) {
            throw new IllegalStateException("Model is not loaded; call init() first");
        }
        model.run(input_list, outputs);
    }

    /** Logs the full shape of input tensor 0, whatever its rank. */
    public void printShape() {
        if (model == null) {
            Log.e("Shape", "Model is not loaded");
            return;
        }
        Tensor inputTensor = model.getInputTensor(0);
        Log.d("Shape", "input shape : " + Arrays.toString(inputTensor.shape()));
    }

    /** Releases the model's native resources. Safe to call more than once. */
    public void finish() {
        if (model != null) {
            model.close();
            model = null;
        }
    }
}