Hello everyone, I am Ashok. I am a student working on an MNIST digit classification project as part of my internship. I would like to implement on-device machine learning training in an Android app.
reference:
I trained the model and got this warning:
WARNING:absl:Importing a function (__inference_internal_grad_fn_368181) with ops with unsaved custom gradients. Will likely fail if a gradient is requested.
Later, I converted it into a TensorFlow Lite model.
I got stuck while creating the application. The error I am facing is:
java.lang.IllegalArgumentException: Cannot copy to a TensorFlowLite tensor (train_y:0) with 40 bytes from a Java Buffer with 8 byte
Please help me. I am new to Python and machine learning. I truly appreciate your help. Thank you.
Java code
/**
 * Activity for an on-device training experiment: the user picks an image from the gallery,
 * it is converted into float buffers, and the "train" and "save" signatures of a
 * TensorFlow Lite model loaded from the "model.tflite" asset are run on those buffers.
 */
public class MainActivity extends AppCompatActivity {
// Views bound in onCreate().
private ImageView imageView;
private TextView textView;
// NOTE(review): Updateweights, predictButton's handler body, SelectImagesBtn and TrainModelBtn
// are not assigned or used in the visible code; confirm whether they are still needed.
private Button selectImageButton, ProcessImage;
private Button trainModelButton, Updateweights;
private Button predictButton;
// NOTE(review): not referenced in the visible code; the picked image is kept in `bitmap` below.
private Bitmap image;
// Number of passes over the prepared batches performed by trainModel().
private static final int NUM_EPOCHS = 100;
// Used to size the label buffer and to derive NUM_BATCHES.
private static final int BATCH_SIZE = 10;
// Image dimensions in pixels.
private static final int IMG_HEIGHT = 28;
private static final int IMG_WIDTH = 28;
// NOTE(review): presumably the size of the MNIST training set; only used to derive NUM_BATCHES.
private static final int NUM_TRAININGS = 60000;
private static final int NUM_BATCHES = NUM_TRAININGS / BATCH_SIZE;
// Upper bound on how many picked images are processed.
private static final int NUM_IMAGES = 1;
// Request code passed to startActivityForResult() and checked in onActivityResult().
private static final int REQUEST_CODE_GALLERY = 1;
// NOTE(review): re-created when images are picked, but nothing is added to it in the visible code.
private List<Bitmap> selectedImages;
// Per-batch model inputs built by prepareTrainingBatches() and consumed by trainModel().
private List<FloatBuffer> trainImageBatches;
private List<FloatBuffer> trainLabelBatches;
private Button SelectImagesBtn, TrainModelBtn;
// ByteBuffer modelBuffer;
// Despite its name this is the TensorFlow Lite Interpreter, not a raw model buffer.
// It stays null if loading the model fails in onCreate().
Interpreter modelBuffer;
// Most recently picked image, scaled by resizeImage() in the gallery result handler.
Bitmap bitmap;
// NOTE(review): same value as REQUEST_CODE_GALLERY and not referenced in the visible code.
private static final int IMAGE_PICK_REQUEST_CODE = 1;
/**
 * Sets up the screen: inflates the layout, looks up the views, creates the TensorFlow Lite
 * interpreter from the bundled model and attaches a click handler to every button.
 *
 * @param savedInstanceState state passed through to the superclass implementation
 */
@Override
protected void onCreate(Bundle savedInstanceState) {
    super.onCreate(savedInstanceState);
    setContentView(R.layout.activity_main);

    // View lookups.
    imageView = findViewById(R.id.selected_image_view);
    textView = findViewById(R.id.improved_learning_rate_text_view);
    selectImageButton = findViewById(R.id.select_image_button);
    trainModelButton = findViewById(R.id.train_model_button);
    predictButton = findViewById(R.id.predict_number_button);
    ProcessImage = findViewById(R.id.process_image_button);

    // A failure here is only logged, which leaves modelBuffer null.
    try {
        modelBuffer = new Interpreter(loadModelFile());
    } catch (Exception e) {
        Log.e("MainActivity", "Error loading TFLite model", e);
    }

    selectImageButton.setOnClickListener(new View.OnClickListener() {
        @Override
        public void onClick(View v) {
            selectImagesFromGallery();
        }
    });

    trainModelButton.setOnClickListener(new View.OnClickListener() {
        @Override
        public void onClick(View v) {
            // The toast is requested only after trainModel() has returned.
            trainModel();
            showShortToast(" Train button is clicked");
        }
    });

    predictButton.setOnClickListener(new View.OnClickListener() {
        @Override
        public void onClick(View v) {
            // LoadOndevicetrainedmodel();
            // predictNumber();
            showShortToast(" predict button is clicked");
        }
    });

    ProcessImage.setOnClickListener(new View.OnClickListener() {
        @Override
        public void onClick(View v) {
            prepareTrainingBatches();
            showShortToast(" process button is clicked");
        }
    });
}

/** Shows {@code message} as a short toast using the application context. */
private void showShortToast(String message) {
    Toast.makeText(getApplicationContext(), message, Toast.LENGTH_SHORT).show();
}
/**
 * Opens a chooser that lets the user pick one or more images.
 * The selection is delivered to onActivityResult() with REQUEST_CODE_GALLERY.
 */
private void selectImagesFromGallery() {
    // ACTION_GET_CONTENT is used directly. Starting from ACTION_PICK with a MediaStore data URI
    // and then calling setType()/setAction() ends up with exactly this intent, because
    // setType() clears the data URI and setAction() replaces the action.
    Intent intent = new Intent(Intent.ACTION_GET_CONTENT);
    intent.setType("image/*");
    intent.putExtra(Intent.EXTRA_ALLOW_MULTIPLE, true);
    startActivityForResult(Intent.createChooser(intent, "Select Images"), REQUEST_CODE_GALLERY);
}
/**
 * Receives the gallery selection started by selectImagesFromGallery().
 *
 * <p>Handles both result shapes: several items delivered through {@code getClipData()} and a
 * single item delivered through {@code getData()}. At most NUM_IMAGES items are processed.
 *
 * @param requestCode code supplied to startActivityForResult()
 * @param resultCode  result reported by the picker
 * @param data        result intent; may be null, in which case nothing is done
 */
@Override
protected void onActivityResult(int requestCode, int resultCode, Intent data) {
    super.onActivityResult(requestCode, resultCode, data);
    if (requestCode != REQUEST_CODE_GALLERY || resultCode != RESULT_OK || data == null) {
        return;
    }
    selectedImages = new ArrayList<>();
    ClipData clipData = data.getClipData();
    if (clipData != null) {
        int count = Math.min(clipData.getItemCount(), NUM_IMAGES);
        for (int i = 0; i < count; i++) {
            loadSelectedImage(clipData.getItemAt(i).getUri());
        }
    } else if (data.getData() != null) {
        // A single selection is reported through getData() with no ClipData attached.
        loadSelectedImage(data.getData());
    }
}

/** Decodes the image at {@code imageUri}, shows it, and stores a resized copy in {@code bitmap}. */
private void loadSelectedImage(Uri imageUri) {
    try {
        Bitmap picked = MediaStore.Images.Media.getBitmap(this.getContentResolver(), imageUri);
        if (picked == null) {
            Log.e("MainActivity", "Could not decode selected image " + imageUri);
            return;
        }
        // selectedImages.add(bitmap);
        imageView.setImageBitmap(picked);
        bitmap = resizeImage(picked);
        Toast.makeText(getApplicationContext(),"image converted to bitmap",Toast.LENGTH_LONG).show();
    } catch (IOException e) {
        Log.e("MainActivity", "Could not load selected image " + imageUri, e);
    }
}
/**
 * Builds the image and label buffers that trainModel() feeds to the model.
 *
 * <p>For each of the NUM_IMAGES iterations the currently selected {@code bitmap} is converted
 * to grayscale and then to a float buffer, and a zero-filled label buffer is allocated.
 * Does nothing except show a message when no image has been selected yet.
 */
private void prepareTrainingBatches() {
    if (bitmap == null) {
        Toast.makeText(getApplicationContext(), "Select an image first", Toast.LENGTH_LONG).show();
        return;
    }
    try {
        // A float occupies 4 bytes; allocateDirect() takes a size in bytes, not in floats.
        final int bytesPerFloat = 4;
        final int numClasses = 10;
        trainImageBatches = new ArrayList<>(NUM_IMAGES);
        trainLabelBatches = new ArrayList<>(NUM_IMAGES);
        for (int i = 0; i < NUM_IMAGES; i++) {
            // Grayscale first so the low byte read by convertBitmapToFloatBuffer() is the gray level.
            Bitmap grayscaleImage = toGrayscale(bitmap);
            trainImageBatches.add(convertBitmapToFloatBuffer(grayscaleImage));

            // NOTE(review): the byte size of this buffer must equal what the model's "y" input
            // expects, and the labels are left at 0 here; confirm the expected label shape
            // against the model's "train" signature and fill in the real label values.
            ByteBuffer labelBuffer = ByteBuffer
                    .allocateDirect(numClasses * BATCH_SIZE * bytesPerFloat)
                    .order(ByteOrder.nativeOrder());
            FloatBuffer trainLabels = labelBuffer.asFloatBuffer();
            trainLabelBatches.add((FloatBuffer) trainLabels.rewind());
        }
        Toast.makeText(getApplicationContext(), "prepareTrainingBatches is done", Toast.LENGTH_LONG).show();
    } catch (Exception e) {
        Log.e("MainActivity", "Preparing training batches failed", e);
        Toast.makeText(getApplicationContext(), "Error :"+ e, Toast.LENGTH_LONG).show();
    }
}
/**
 * Runs the model's "train" signature for NUM_EPOCHS passes over the prepared batches and then
 * runs its "save" signature, writing to "checkpoint.ckpt" in the app's files directory.
 *
 * <p>Shows a message and returns without training when the interpreter failed to load or when
 * no batches have been prepared. Any exception raised during training is logged and shown.
 *
 * <p>NOTE(review): this runs synchronously on the calling thread; when invoked from a click
 * handler that is the UI thread, so consider moving it to a background executor.
 */
public void trainModel(){
    if (modelBuffer == null) {
        Toast.makeText(getApplicationContext(), "Model is not loaded", Toast.LENGTH_LONG).show();
        return;
    }
    if (trainImageBatches == null || trainLabelBatches == null) {
        Toast.makeText(getApplicationContext(), "Prepare the training batches first", Toast.LENGTH_LONG).show();
        return;
    }
    // Only iterate over batches that actually exist; NUM_BATCHES is an upper bound.
    final int batchCount =
            Math.min(NUM_BATCHES, Math.min(trainImageBatches.size(), trainLabelBatches.size()));
    if (batchCount == 0) {
        Toast.makeText(getApplicationContext(), "No training batches available", Toast.LENGTH_LONG).show();
        return;
    }
    try {
        // Run training for a few steps.
        float[] losses = new float[NUM_EPOCHS];
        for (int epoch = 0; epoch < NUM_EPOCHS; ++epoch) {
            for (int batchIdx = 0; batchIdx < batchCount; ++batchIdx) {
                FloatBuffer batchImages = trainImageBatches.get(batchIdx);
                FloatBuffer batchLabels = trainLabelBatches.get(batchIdx);
                // The same buffers are reused every epoch; start each run from position 0.
                batchImages.rewind();
                batchLabels.rewind();

                Map<String, Object> inputs = new HashMap<>();
                inputs.put("x", batchImages);
                inputs.put("y", batchLabels);
                Map<String, Object> outputs = new HashMap<>();
                FloatBuffer loss = FloatBuffer.allocate(1);
                outputs.put("loss", loss);
                modelBuffer.runSignature(inputs, outputs, "train");
                // Record the last loss.
                if (batchIdx == batchCount - 1) losses[epoch] = loss.get(0);
            }
            // Print the loss output for every 10 epochs.
            if ((epoch + 1) % 10 == 0) {
                String progress = "Finished " + (epoch + 1) + " epochs, current loss: " + losses[epoch];
                Log.d("MainActivity", progress);
                textView.setText(progress);
            }
        }
        // Export the trained weights through the model's "save" signature.
        File outputFile = new File(getFilesDir(), "checkpoint.ckpt");
        Map<String, Object> inputs = new HashMap<>();
        inputs.put("checkpoint_path", outputFile.getAbsolutePath());
        Map<String, Object> outputs = new HashMap<>();
        modelBuffer.runSignature(inputs, outputs, "save");
    }
    catch (Exception e){
        Log.e("TRAIN MODEL:", String.valueOf(e), e);
        Toast.makeText(getApplicationContext(),"Error:"+e,Toast.LENGTH_LONG).show();
    }
}
/**
 * Memory-maps the "model.tflite" asset read-only.
 *
 * <p>The descriptor, stream and channel are closed before returning; a mapping stays valid
 * after the channel that created it has been closed.
 *
 * @return a read-only mapping covering exactly the asset's bytes
 * @throws IOException if the asset cannot be opened or mapped
 */
private MappedByteBuffer loadModelFile() throws IOException {
    try (AssetFileDescriptor fileDescriptor = getAssets().openFd("model.tflite");
         FileInputStream inputStream = new FileInputStream(fileDescriptor.getFileDescriptor());
         FileChannel fileChannel = inputStream.getChannel()) {
        // The asset lives inside a larger file, so map only its own offset/length window.
        long startOffset = fileDescriptor.getStartOffset();
        long declaredLength = fileDescriptor.getDeclaredLength();
        return fileChannel.map(FileChannel.MapMode.READ_ONLY, startOffset, declaredLength);
    }
}
/**
 * Scales an image to IMG_WIDTH x IMG_HEIGHT pixels with bilinear filtering.
 * The aspect ratio of the source is not preserved.
 *
 * @param originalImage image to scale
 * @return the scaled bitmap
 */
private Bitmap resizeImage(Bitmap originalImage){
    return Bitmap.createScaledBitmap(originalImage, IMG_WIDTH, IMG_HEIGHT, true);
}
// The toGrayscale() and convertBitmapToFloatBuffer() methods are defined as follows:
/**
 * Produces a desaturated copy of a bitmap.
 *
 * @param bmpOriginal source image
 * @return a new ARGB_8888 bitmap of the same size, drawn with a zero-saturation color filter
 */
public static Bitmap toGrayscale(Bitmap bmpOriginal) {
    final int srcHeight = bmpOriginal.getHeight();
    final int srcWidth = bmpOriginal.getWidth();

    Bitmap desaturated = Bitmap.createBitmap(srcWidth, srcHeight, Bitmap.Config.ARGB_8888);
    Canvas canvas = new Canvas(desaturated);

    // Saturation 0 strips the color and leaves only the gray level.
    ColorMatrix grayMatrix = new ColorMatrix();
    grayMatrix.setSaturation(0);
    Paint grayPaint = new Paint();
    grayPaint.setColorFilter(new ColorMatrixColorFilter(grayMatrix));

    canvas.drawBitmap(bmpOriginal, 0, 0, grayPaint);
    return desaturated;
}
/**
 * Converts a bitmap into a row-major float buffer with one value per pixel.
 *
 * <p>Each value is the lowest 8 bits of the pixel's packed color divided by 255, so it lies
 * in [0, 1].
 *
 * @param bitmap image to convert
 * @return a heap FloatBuffer wrapping width * height values
 */
public static FloatBuffer convertBitmapToFloatBuffer(Bitmap bitmap) {
    final int cols = bitmap.getWidth();
    final int rows = bitmap.getHeight();
    final float[] normalized = new float[cols * rows];

    int next = 0;
    for (int y = 0; y < rows; y++) {
        for (int x = 0; x < cols; x++) {
            final int lowByte = bitmap.getPixel(x, y) & 0xff;
            normalized[next++] = lowByte / 255.0f;
        }
    }
    return FloatBuffer.wrap(normalized);
}