Hello everyone,
I am trying to implement on-device training for a Fashion-MNIST model, but I am facing some runtime issues.
I am really confused about this. Can anyone please assist me? I have no idea how to implement this.
Hi @chunduriv,
Sorry for the late reply. Here is the log output I get whenever I click the train button:
2023-05-11 12:31:39.615 8542-8542 MsyncFactory com.application.mlmodeltesting E [static] ClassNotFoundException
java.lang.ClassNotFoundException: com.mediatek.view.impl.MsyncFactoryImpl
at java.lang.Class.classForName(Native Method)
at java.lang.Class.forName(Class.java:454)
at java.lang.Class.forName(Class.java:379)
at com.mediatek.view.MsyncFactory.<clinit>(MsyncFactory.java:14)
at com.mediatek.view.MsyncFactory.getInstance(MsyncFactory.java:29)
at android.view.ViewRootImpl.<init>(ViewRootImpl.java:763)
at android.view.ViewRootImpl.<init>(ViewRootImpl.java:859)
at android.view.WindowManagerGlobal.addView(WindowManagerGlobal.java:393)
at android.view.WindowManagerImpl.addView(WindowManagerImpl.java:134)
at android.app.ActivityThread.handleResumeActivity(ActivityThread.java:5012)
at android.app.servertransaction.ResumeActivityItem.execute(ResumeActivityItem.java:54)
at android.app.servertransaction.ActivityTransactionItem.execute(ActivityTransactionItem.java:45)
at android.app.servertransaction.TransactionExecutor.executeLifecycleState(TransactionExecutor.java:176)
I recently changed the code for the digit classification model. Here is the code I changed:
# Import the necessary libraries
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
import numpy as np
import matplotlib.pyplot as plt

# Load the MNIST dataset
(x_train, y_train), (x_test, y_test) = keras.datasets.mnist.load_data()

# Normalize the pixel values
x_train = x_train / 255.0
x_test = x_test / 255.0

# Define a simple sequential model
def create_model():
    model = keras.Sequential([
        layers.Flatten(input_shape=(28, 28)),
        layers.Dense(128, activation='relu'),
        layers.Dropout(0.2),
        layers.Dense(10)
    ])
    model.compile(optimizer='adam',
                  loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
                  metrics=['accuracy'])
    return model

# Create a model instance
model = create_model()

# Train the model on the training data
model.fit(x_train, y_train, epochs=5)

# Evaluate the model on the test data
model.evaluate(x_test, y_test, verbose=2)

# Define a function to add signatures for on-device training
def add_signatures(model):
    # Get the input and output tensors of the model
    input_tensor = model.input
    output_tensor = model.output

    # Define a train function that takes the input tensor and updates the weights
    @tf.function(input_signature=[tf.TensorSpec(input_tensor.shape, input_tensor.dtype)])
    def train(input_data):
        with tf.GradientTape() as tape:
            predictions = model(input_data)
            loss = keras.losses.SparseCategoricalCrossentropy(from_logits=True)(y_train[:len(input_data)], predictions)
        gradients = tape.gradient(loss, model.trainable_variables)
        optimizer = tf.keras.optimizers.Adam()
        optimizer.apply_gradients(zip(gradients, model.trainable_variables))

    # Define an infer function that takes the input tensor and returns the output tensor
    @tf.function(input_signature=[tf.TensorSpec(input_tensor.shape, input_tensor.dtype)])
    def infer(input_data):
        return output_tensor

    # Define a save function that saves the weights to a file path
    @tf.function(input_signature=[tf.TensorSpec([], tf.string)])
    def save(file_path):
        tf.io.write_file(file_path, tf.io.serialize_tensor(model.get_weights()))

    # Define a restore function that restores the weights from a file path
    @tf.function(input_signature=[tf.TensorSpec([], tf.string)])
    def restore(file_path):
        model.set_weights(tf.io.parse_tensor(tf.io.read_file(file_path), out_type=tf.float32))

    # Return a dictionary of signatures
    signatures = {
        'train': train,
        'infer': infer,
        'save': save,
        'restore': restore,
    }
    return signatures

# Convert the model to TensorFlow Lite format with signatures
converter = tf.lite.TFLiteConverter.from_keras_model(model)
converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS,
                                       tf.lite.OpsSet.SELECT_TF_OPS]
converter.experimental_enable_resource_variables = True
converter._experimental_new_converter = True
signatures = add_signatures(model)
converter._experimental_signature_defs = signatures
tflite_model = converter.convert()

# Save the TensorFlow Lite model to a file
with open('digit_model.tflite', 'wb') as f:
    f.write(tflite_model)

# Download the TensorFlow Lite model file from Colab
from google.colab import files
files.download('digit_model.tflite')
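For reference, my understanding from the TensorFlow Lite on-device training guide is that the train/infer signatures are normally attached by wrapping the model in a tf.Module, exporting it with tf.saved_model.save, and converting from the SavedModel, rather than by assigning them to the converter the way I did above. This is only a rough sketch of what I think that would look like (the wrapper class, the labels input, and the saved_model_dir path are my own assumptions, not tested code):

import tensorflow as tf

class TrainableModule(tf.Module):
    # Wraps the trained Keras model so the train/infer functions can be
    # exported as named SavedModel signatures.
    def __init__(self, model):
        super().__init__()
        self.model = model
        self.optimizer = tf.keras.optimizers.Adam()
        self.loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)

    @tf.function(input_signature=[
        tf.TensorSpec([None, 28, 28], tf.float32),
        tf.TensorSpec([None], tf.int32),
    ])
    def train(self, x, y):
        # One gradient step on a batch of images and labels.
        with tf.GradientTape() as tape:
            logits = self.model(x)
            loss = self.loss_fn(y, logits)
        grads = tape.gradient(loss, self.model.trainable_variables)
        self.optimizer.apply_gradients(zip(grads, self.model.trainable_variables))
        return {'loss': loss}

    @tf.function(input_signature=[tf.TensorSpec([None, 28, 28], tf.float32)])
    def infer(self, x):
        # Run the model and return class probabilities.
        return {'output': tf.nn.softmax(self.model(x))}

# Export the wrapper as a SavedModel with named signatures, then convert that.
module = TrainableModule(model)
saved_model_dir = 'saved_model'  # placeholder path
tf.saved_model.save(
    module, saved_model_dir,
    signatures={
        'train': module.train.get_concrete_function(),
        'infer': module.infer.get_concrete_function(),
    })

converter = tf.lite.TFLiteConverter.from_saved_model(saved_model_dir)
converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS,
                                       tf.lite.OpsSet.SELECT_TF_OPS]
converter.experimental_enable_resource_variables = True
tflite_model = converter.convert()

I assume the save and restore functions would be added to the signatures dictionary the same way. Is this the direction I should be going in, or is there a different way to get the signatures into the .tflite file?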
The Activity class in Java:
package com.application.mlmodeltesting;

import android.content.Intent;
import android.content.res.AssetFileDescriptor;
import android.graphics.Bitmap;
import android.net.Uri;
import android.os.Bundle;
import android.provider.MediaStore;
import android.view.View;
import android.widget.Button;
import android.widget.ImageView;
import android.widget.TextView;
import android.widget.Toast;

import androidx.appcompat.app.AppCompatActivity;

import org.tensorflow.lite.DataType;
import org.tensorflow.lite.Interpreter;
import org.tensorflow.lite.support.image.TensorImage;
import org.tensorflow.lite.support.tensorbuffer.TensorBuffer;

import java.io.FileInputStream;
import java.io.IOException;
import java.nio.ByteBuffer;
import java.nio.ByteOrder;
import java.nio.FloatBuffer;
import java.nio.MappedByteBuffer;
import java.nio.channels.FileChannel;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.Map;

public class MainActivity extends AppCompatActivity {

    private Interpreter interpreter;
    private ImageView imageView;
    private TextView predictionTextView;
    private TextView accuracyTextView;
    private Button chooseImageButton;
    private Button classifyButton;
    private Bitmap imagebitmap;
    private ByteBuffer inputBuffer;
    private float[][] outputBuffer;

    private static final int NUM_CLASSES = 10;
    private static final int PICK_IMAGE = 1;
    private static final int IMAGE_SIZE = 28;
    private static final String MODEL_FILE = "digit_model.tflite";
    private static final String TRAIN_SIGNATURE = "train";
    private static final String INFER_SIGNATURE = "infer";
    private static final String SAVE_SIGNATURE = "save";
    private static final String RESTORE_SIGNATURE = "restore";

    @Override
    protected void onCreate(Bundle savedInstanceState) {
        super.onCreate(savedInstanceState);
        setContentView(R.layout.activity_main);

        // Initialize TensorFlow Lite model
        try {
            interpreter = new Interpreter(loadModelFile());
            inputBuffer = ByteBuffer.allocateDirect(IMAGE_SIZE * IMAGE_SIZE * 4);
            inputBuffer.order(ByteOrder.nativeOrder());
            outputBuffer = new float[1][NUM_CLASSES];
        } catch (IOException e) {
            e.printStackTrace();
        }

        // Get references to views
        imageView = findViewById(R.id.image_view);
        predictionTextView = findViewById(R.id.prediction_text_view);
        accuracyTextView = findViewById(R.id.accuracy_text_view);
        chooseImageButton = findViewById(R.id.choose_image_button);
        classifyButton = findViewById(R.id.classify_button);

        // Set up choose image button
        chooseImageButton.setOnClickListener(new View.OnClickListener() {
            @Override
            public void onClick(View v) {
                // Open image picker
                // ...
                Intent intent = new Intent();
                intent.setType("image/*");
                intent.setAction(Intent.ACTION_GET_CONTENT);
                startActivityForResult(Intent.createChooser(intent, "Select Picture"), PICK_IMAGE);
            }
        });

        classifyButton.setOnClickListener(new View.OnClickListener() {
            @Override
            public void onClick(View v) {
                // Get the image from the imageView and convert it to a byte buffer
                inputBuffer.rewind();
                imagebitmap.copyPixelsToBuffer(inputBuffer);

                // Create a map for the input tensor
                Map<String, Object> inputMap = new HashMap<>();
                inputMap.put("input", inputBuffer);

                // Create a map for the output tensor
                Map<String, Object> outputMap = new HashMap<>();
                outputMap.put("output", outputBuffer);

                // Run inference using the infer signature
                interpreter.runSignature(inputMap, outputMap, INFER_SIGNATURE);

                // Restore the model weights from the file system
                interpreter.runSignature(null, null, RESTORE_SIGNATURE);

                // Find the index of the maximum value in the output buffer
                int maxIndex = 0;
                float maxValue = 0f;
                for (int i = 0; i < NUM_CLASSES; i++) {
                    if (outputBuffer[0][i] > maxValue) {
                        maxIndex = i;
                        maxValue = outputBuffer[0][i];
                    }
                }

                // Show the predicted digit and the confidence score
                predictionTextView.setText("Predicted digit: " + maxIndex + "\nConfidence: " + maxValue);
            }
        });

        // Set up classify button
        // classifyButton.setOnClickListener(new View.OnClickListener() {
        //     @Override
        //     public void onClick(View v) {
        //         if (imagebitmap != null) {
        //             // Preprocess the image
        //             float[] input = getPreprocessedImage(imagebitmap);
        //
        //             // Classify the image
        //             float[][] output = new float[1][10];
        //             Map<String, Object> inputs = new HashMap<>();
        //             inputs.put("x", input);
        //             inputs.put("y", output);
        //
        //             Map<String, Object> outputs = new HashMap<>();
        //             FloatBuffer loss = FloatBuffer.allocate(1);
        //             outputs.put("loss", loss);
        //
        //             // tflite.runSignature(inputs, outputs, "train");
        //             tflite.run(input, output);
        //
        //             // Find the class with the highest confidence
        //             int classIndex = 0;
        //             float maxConfidence = output[0][0];
        //             for (int i = 1; i < 10; i++) {
        //                 if (output[0][i] > maxConfidence) {
        //                     classIndex = i;
        //                     maxConfidence = output[0][i];
        //                 }
        //             }
        //
        //             // Display the prediction
        //             predictionTextView.setText(String.valueOf(classIndex));
        //             accuracyTextView.setText(String.valueOf(maxConfidence));
        //         }
        //     }
        // });
    }

    @Override
    protected void onActivityResult(int requestCode, int resultCode, Intent data) {
        super.onActivityResult(requestCode, resultCode, data);
        if (requestCode == PICK_IMAGE && resultCode == RESULT_OK && data != null && data.getData() != null) {
            // Get the image URI
            Uri imageUri = data.getData();
            try {
                // Convert the image URI to a Bitmap
                imagebitmap = MediaStore.Images.Media.getBitmap(getContentResolver(), imageUri);
                // Display the image
                imageView.setImageBitmap(imagebitmap);
                imagebitmap = getPreprocessedImage(imagebitmap);
            } catch (IOException e) {
                e.printStackTrace();
            }
        }
    }

    private Bitmap getPreprocessedImage(Bitmap image) {
        // Resize the image
        Bitmap resizedImage = Bitmap.createScaledBitmap(image, 28, 28, true);

        // Convert the image to a float array
        int width = resizedImage.getWidth();
        int height = resizedImage.getHeight();
        int[] pixels = new int[width * height];
        resizedImage.getPixels(pixels, 0, width, 0, 0, width, height);
        float[] imageData = new float[pixels.length];
        for (int i = 0; i < pixels.length; i++) {
            imageData[i] = (pixels[i] & 0xff) / 255.0f;
        }
        return resizedImage;
    }

    private MappedByteBuffer loadModelFile() throws IOException {
        // Open the model file from the assets folder
        AssetFileDescriptor fileDescriptor = getAssets().openFd(MODEL_FILE);
        FileInputStream inputStream = new FileInputStream(fileDescriptor.getFileDescriptor());
        FileChannel fileChannel = inputStream.getChannel();
        return fileChannel.map(FileChannel.MapMode.READ_ONLY, fileDescriptor.getStartOffset(), fileDescriptor.getDeclaredLength());
    }

    private float[] inferImage(Bitmap bitmap) {
        // Pass the preprocessed image to the TensorFlow Lite model for inference
        TensorImage tensorImage = new TensorImage(DataType.FLOAT32);
        tensorImage.load(bitmap);
        ByteBuffer byteBuffer = tensorImage.getBuffer();
        TensorBuffer inputBuffer = TensorBuffer.createFixedSize(new int[]{1, IMAGE_SIZE, IMAGE_SIZE, 1}, DataType.FLOAT32);
        inputBuffer.loadBuffer(byteBuffer);
        TensorBuffer outputBuffer = TensorBuffer.createFixedSize(new int[]{1, 10}, DataType.FLOAT32);
        interpreter.run(inputBuffer.getBuffer(), outputBuffer.getBuffer());
        return outputBuffer.getFloatArray();
    }
}
The error I am getting is:
2023-05-11 12:32:03.085 8542-8542 AndroidRuntime com.application.mlmodeltesting E FATAL EXCEPTION: main
Process: com.application.mlmodeltesting, PID: 8542
java.lang.IllegalArgumentException: Input error: Signature infer not found.
at org.tensorflow.lite.NativeSignatureRunnerWrapper.<init>(NativeSignatureRunnerWrapper.java:28)
at org.tensorflow.lite.NativeInterpreterWrapper.getSignatureRunnerWrapper(NativeInterpreterWrapper.java:543)
at org.tensorflow.lite.NativeInterpreterWrapper.runSignature(NativeInterpreterWrapper.java:181)
at org.tensorflow.lite.Interpreter.runSignature(Interpreter.java:253)
at com.application.mlmodeltesting.MainActivity$2.onClick(MainActivity.java:110)
at android.view.View.performClick(View.java:7751)
at com.google.android.material.button.MaterialButton.performClick(MaterialButton.java:1219)
at android.view.View.performClickInternal(View.java:7724)
at android.view.View.access$3700(View.java:858)
at android.view.View$PerformClick.run(View.java:29336)
at android.os.Handler.handleCallback(Handler.java:938)
at android.os.Handler.dispatchMessage(Handler.java:99)
at android.os.Looper.loopOnce(Looper.java:210)
at android.os.Looper.loop(Looper.java:299)
at android.app.ActivityThread.main(ActivityThread.java:8280)
at java.lang.reflect.Method.invoke(Native Method)
at com.android.internal.os.RuntimeInit$MethodAndArgsCaller.run(RuntimeInit.java:576)
at com.android.internal.os.ZygoteInit.main(ZygoteInit.java:1073)
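In case it is useful for debugging, my assumption is that the signatures that actually ended up inside the .tflite file can be listed from Python with something like this (digit_model.tflite is the file I exported above):

import tensorflow as tf

# Load the converted model and print the signatures it exposes.
interpreter = tf.lite.Interpreter(model_path='digit_model.tflite')
print(interpreter.get_signature_list())

# If 'infer' appears in that list, it should also be callable directly:
# infer = interpreter.get_signature_runner('infer')

My guess is that if 'infer' does not show up in that list, the problem is in the conversion step rather than in the Android code, but I am not sure.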
@chunduriv, can you help me figure out what the mistake is?