I trained an object detection model with the recently released YOLOv8 from Ultralytics and managed to convert it to the TensorFlow.js format. However, I am getting errors when I try to run the model on webcam frames.
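To check for a signature mismatch, the converted graph's actual inputs and outputs can be logged after loading (model.inputs and model.outputs are standard tf.GraphModel properties; the URL is the same one the demo loads, and the shape in the comment is only what I would expect from a 640x640 export, not verified output):

import * as tf from '@tensorflow/tfjs';
import {loadGraphModel} from '@tensorflow/tfjs-converter';

// Log the converted model's input/output signature; tensor names, shapes
// and the number of output tensors vary between exports.
async function inspectModel() {
  const model = await loadGraphModel("https://raw.githubusercontent.com/jideilori/tfjs_models/main/model.json");
  console.log(model.inputs);   // e.g. [{name: 'images', shape: [1, 640, 640, 3], dtype: 'float32'}]
  console.log(model.outputs);  // one entry per output tensor
}
inspectModel();

Here is the full demo code I am using: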
import React from "react";
import ReactDOM from "react-dom";
import * as tf from '@tensorflow/tfjs';
import {loadGraphModel} from '@tensorflow/tfjs-converter';
import "./styles.css";
tf.setBackend('webgl');
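// Note: tf.setBackend returns a promise; awaiting it (or tf.ready()) would
// guarantee the WebGL backend is initialized before the first inference.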
const threshold = 0.10;
async function load_model() {
  // The model can be loaded locally or from a repo. For a local file, serve
  // it over HTTP first (e.g. "http://127.0.0.1:8080/model.json", with any IP
  // and port) and point loadGraphModel at that URL.
  // const model = await loadGraphModel("https://raw.githubusercontent.com/jideilori/test_js_model/main/model.json");
  // const model = await loadGraphModel("../models/kangaroo-detector/model.json");
  const model = await loadGraphModel("https://raw.githubusercontent.com/jideilori/tfjs_models/main/model.json");
  // const model = await loadGraphModel("https://raw.githubusercontent.com/hugozanini/TFJS-object-detection/master/models/kangaroo-detector/model.json");
  return model;
}
let classesDir = {
  1: {
    name: 'mp',
    id: 1,
  },
  2: {
    name: 'wbc',
    id: 2,
  },
};
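// Note: class ids from a YOLOv8 export are typically 0-indexed, while this
// map is keyed from 1; if classesDir[classes[i]] comes back undefined in
// buildDetectedObjects, the lookup may need an offset (classes[i] + 1).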
class App extends React.Component {
  videoRef = React.createRef();
  canvasRef = React.createRef();

  componentDidMount() {
    if (navigator.mediaDevices && navigator.mediaDevices.getUserMedia) {
      const webCamPromise = navigator.mediaDevices
        .getUserMedia({
          audio: false,
          video: {
            facingMode: "user",
            width: 640,
            height: 640,
          },
        })
        .then(stream => {
          window.stream = stream;
          this.videoRef.current.srcObject = stream;
          // Resolve once video metadata is available so frame capture can start.
          return new Promise(resolve => {
            this.videoRef.current.onloadedmetadata = () => {
              resolve();
            };
          });
        });
      const modelPromise = load_model();
      Promise.all([modelPromise, webCamPromise])
        .then(values => {
          this.detectFrame(this.videoRef.current, values[0]);
        })
        .catch(error => {
          console.error(error);
        });
    }
  }
  detectFrame = (video, model) => {
    // startScope/endScope let TF.js dispose the intermediate tensors
    // created during each frame once rendering is done.
    tf.engine().startScope();
    model.executeAsync(this.process_input(video)).then(predictions => {
      this.renderPredictions(predictions);
      requestAnimationFrame(() => {
        this.detectFrame(video, model);
      });
      tf.engine().endScope();
    });
  };
  process_input(video_frame) {
    // A YOLOv8 export typically expects a float32 batch of shape
    // [1, 640, 640, 3] with pixel values scaled to [0, 1].
    const tfimg = tf.browser.fromPixels(video_frame).toFloat();
    const resized = tf.image.resizeBilinear(tfimg, [640, 640]);
    return resized.div(255.0).expandDims(0);
  }
  buildDetectedObjects(scores, threshold, boxes, classes, classesDir) {
    const detectionObjects = [];
    const video_frame = document.getElementById('frame');
    scores[0].forEach((score, i) => {
      if (score > threshold) {
        // Boxes arrive as normalized [ymin, xmin, ymax, xmax]; scale them
        // to the displayed video size.
        const minY = boxes[0][i][0] * video_frame.offsetHeight;
        const minX = boxes[0][i][1] * video_frame.offsetWidth;
        const maxY = boxes[0][i][2] * video_frame.offsetHeight;
        const maxX = boxes[0][i][3] * video_frame.offsetWidth;
        const bbox = [minX, minY, maxX - minX, maxY - minY];
        detectionObjects.push({
          class: classes[i],
          label: classesDir[classes[i]].name,
          score: score.toFixed(4),
          bbox: bbox,
        });
      }
    });
    return detectionObjects;
  }
  renderPredictions = predictions => {
    console.log(predictions);
    const ctx = this.canvasRef.current.getContext("2d");
    ctx.clearRect(0, 0, ctx.canvas.width, ctx.canvas.height);
    // Font options.
    const font = "16px sans-serif";
    ctx.font = font;
    ctx.textBaseline = "top";
    // These output indices come from the TF Object Detection API model in the
    // original tutorial, which returns separate box/score/class tensors; a
    // YOLOv8 export may return a single tensor instead (see the decoding
    // sketch after the listing).
    const boxes = predictions[4].arraySync();
    const scores = predictions[5].arraySync();
    const classes = predictions[6].dataSync();
    const detections = this.buildDetectedObjects(scores, threshold, boxes, classes, classesDir);
    detections.forEach(item => {
      const [x, y, width, height] = item['bbox'];
      // Draw the bounding box.
      ctx.strokeStyle = "#00FFFF";
      ctx.lineWidth = 4;
      ctx.strokeRect(x, y, width, height);
      // Draw the label background.
      ctx.fillStyle = "#00FFFF";
      const textWidth = ctx.measureText(item["label"] + " " + (100 * item["score"]).toFixed(2) + "%").width;
      const textHeight = parseInt(font, 10); // "16px sans-serif" parses to 16
      ctx.fillRect(x, y, textWidth + 4, textHeight + 4);
    });
    detections.forEach(item => {
      const x = item['bbox'][0];
      const y = item['bbox'][1];
      // Draw the text last to ensure it's on top.
      ctx.fillStyle = "#000000";
      ctx.fillText(item["label"] + " " + (100 * item["score"]).toFixed(2) + "%", x, y);
    });
  };
  render() {
    return (
      <div>
        <h1>Real-Time Object Detection</h1>
        <h3>YOLOv8</h3>
        <video
          style={{height: '640px', width: '640px'}}
          className="size"
          autoPlay
          playsInline
          muted
          ref={this.videoRef}
          width="640"
          height="640"
          id="frame"
        />
        <canvas
          className="size"
          ref={this.canvasRef}
          width="640"
          height="640"
        />
      </div>
    );
  }
}
const rootElement = document.getElementById("root");
ReactDOM.render(<App />, rootElement);
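For reference on the output parsing: renderPredictions indexes predictions[4], predictions[5] and predictions[6], which matches the multi-tensor output of the TensorFlow Object Detection API model this demo was adapted from. As far as I understand, a standard YOLOv8 TF.js export instead returns a single tensor of shape [1, 4 + numClasses, 8400], with (cx, cy, w, h) rows followed by one score row per class, which still needs decoding and non-max suppression on the JS side. Below is a minimal decoding sketch under that assumption; the tensor layout, the pixel-space coordinates, and nc = 2 are guesses about my export, and decodeYoloOutput is a name I made up:

// Decode a YOLOv8-style output tensor of assumed shape [1, 4 + nc, 8400].
// Rows 0-3 are (cx, cy, w, h) in input pixels (0-640); the remaining rows
// hold per-class scores. Returns kept boxes as [y1, x1, y2, x2] after NMS.
async function decodeYoloOutput(prediction, nc = 2, scoreThreshold = 0.25) {
  const perBox = prediction.squeeze([0]).transpose([1, 0]); // [8400, 4 + nc]
  const rows = await perBox.array();
  perBox.dispose();

  const boxes = [];
  const scores = [];
  const classIds = [];
  for (const row of rows) {
    const classScores = row.slice(4);
    const classId = classScores.indexOf(Math.max(...classScores));
    const score = classScores[classId];
    if (score < scoreThreshold) continue;
    const [cx, cy, w, h] = row;
    boxes.push([cy - h / 2, cx - w / 2, cy + h / 2, cx + w / 2]);
    scores.push(score);
    classIds.push(classId);
  }
  if (boxes.length === 0) return [];

  // tf.image.nonMaxSuppressionAsync expects boxes as [y1, x1, y2, x2].
  const boxesT = tf.tensor2d(boxes);
  const scoresT = tf.tensor1d(scores);
  const keep = await tf.image.nonMaxSuppressionAsync(boxesT, scoresT, 20, 0.45);
  const keptIdx = await keep.array();
  tf.dispose([boxesT, scoresT, keep]);

  return keptIdx.map(i => ({
    bbox: boxes[i],       // [y1, x1, y2, x2] in 640x640 input pixels
    score: scores[i],
    classId: classIds[i], // 0-based; would map via classesDir[classId + 1]
  }));
}

If the export really behaves this way, renderPredictions would call decodeYoloOutput(predictions) directly (executeAsync returns the tensor itself when the graph has a single output) and scale the 640-based coordinates to the displayed video size, instead of indexing predictions[4..6]. Is this the right way to read a YOLOv8 TF.js export, or is the error coming from somewhere else in the pipeline?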