Project 2: Use transfer learning to create a motion-controlled UI
Set up the project
The code for the second project is in the exercises/project-2
folder, so change directory to it with:
cd exercises/project-2
Install the dependencies with:
npm install
Start the demo app with:
npm run watch
A browser window should open at http://localhost:1234
Train a model online
Navigate to the Teachable Machine website (https://teachablemachine.withgoogle.com), record a few image samples for each class you want to recognize, export the model as a TensorFlow.js model, and download it to your computer. The code below assumes the extracted files are placed in a my_model folder inside the project.
Load the model
const URL = "./my_model/";
const modelURL = URL + "model.json";
const metadataURL = URL + "metadata.json";

let model = await tmImage.load(modelURL, metadataURL);
let maxPredictions = model.getTotalClasses();
Start the webcam
If you look at the index.html file, you will see a script tag importing the Teachable Machine library.
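A sketch of what that import might look like (the Teachable Machine image library also depends on TensorFlow.js; the exact versions and CDN URLs may differ from the exercise files):

<!-- Assumed CDN imports; versions and URLs may differ in the actual index.html -->
<script src="https://cdn.jsdelivr.net/npm/@tensorflow/tfjs@1.3.1/dist/tf.min.js"></script>
<script src="https://cdn.jsdelivr.net/npm/@teachablemachine/image@0.8/dist/teachablemachine-image.min.js"></script>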
This library exposes a method that makes it easy to set up the webcam feed:
webcam = new tmImage.Webcam(200, 200, true); // width, height, flip the webcam
await webcam.setup(); // request access to the webcam
await webcam.play();
Predict
Then, use requestAnimationFrame
to continuously call a function that will run predictions on the webcam feed:
window.requestAnimationFrame(loop);

async function loop() {
  webcam.update(); // update the webcam frame
  await predict();
  window.requestAnimationFrame(loop);
}
This loop calls a predict function that runs the model on the current webcam frame.
The output can then be used to extract the label with the highest probability:
async function predict() {
  // predict can take in an image, video or canvas HTML element
  const predictions = await model.predict(webcam.canvas);

  const topPrediction = Math.max(...predictions.map((p) => p.probability));
  const topPredictionIndex = predictions.findIndex(
    (p) => p.probability === topPrediction
  );

  console.log(predictions[topPredictionIndex].className);
}
Enjoy! 🎉
Try training your model with different webcam input and experiment with making interactive UIs!
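For example, here is a minimal sketch of how the top prediction could drive the page. The element id and CSS class names are assumptions; adapt them to whatever labels you trained:

// Hypothetical: move an element based on the predicted label
const updateUI = (className) => {
  const ball = document.getElementById("ball"); // assumed element in index.html
  if (className === "left") {
    ball.classList.add("move-left");
    ball.classList.remove("move-right");
  } else if (className === "right") {
    ball.classList.add("move-right");
    ball.classList.remove("move-left");
  }
};

// e.g. call it from predict() instead of console.log:
// updateUI(predictions[topPredictionIndex].className);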
Step 2: Train the model in the browser
You can also run transfer learning on the MobileNet model yourself, without using Teachable Machine (see the part2.js file).
Import the packages
import * as tf from "@tensorflow/tfjs";
import * as tfd from "@tensorflow/tfjs-data";
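The snippets that follow also rely on some module-level state declared in part2.js; here is a minimal sketch of those declarations, with names matching the code below (initial values are assumptions):

// Shared state used by the following snippets (values are assumptions)
const labels = ["left", "right"];
const totals = [0, 0];

let webcam, initialModel, newModel;
let xs = null; // training inputs (MobileNet embeddings)
let ys = null; // training targets (one-hot labels)
let mouseDown = false;
let isTraining = false;
let isPredicting = false;

const statusElement = document.getElementById("status"); // assumed element id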
Load the model
You can load the MobileNet model in your application and return a truncated model, ending at one of its internal activation layers, to use as a feature extractor when training your custom model:
const loadModel = async () => {
  const mobilenet = await tf.loadLayersModel(
    "https://storage.googleapis.com/tfjs-models/tfjs/mobilenet_v1_0.25_224/model.json"
  );

  // Use an internal activation of MobileNet as the output of the feature extractor
  const layer = mobilenet.getLayer("conv_pw_13_relu");
  return tf.model({ inputs: mobilenet.inputs, outputs: layer.output });
};

const init = async () => {
  webcam = await tfd.webcam(document.getElementById("webcam"));
  initialModel = await loadModel();

  statusElement.style.display = "none";
  document.getElementById("controller").style.display = "block";
};
Add examples
When implementing the functionality yourself, you need to gather examples that the model can be trained with. In the index.html file, there are 2 buttons to record images with the labels 'left' and 'right'. While the user presses and holds these buttons, images are kept in memory to train the model.
buttonsContainer.onmousedown = (e) => {
  if (e.target === recordButtons[0]) {
    // left
    handleAddExample(0);
  } else if (e.target === recordButtons[1]) {
    // right
    handleAddExample(1);
  }
};

const handleAddExample = async (labelIndex) => {
  mouseDown = true;
  const total = document.getElementById(labels[labelIndex] + "-total");

  while (mouseDown) {
    addExample(labelIndex);
    total.innerText = ++totals[labelIndex];
    // Returns a promise that resolves when a requestAnimationFrame has completed
    await tf.nextFrame();
  }
};

const addExample = async (index) => {
  let img = await getImage(); // Gets a snapshot from the webcam
  let example = initialModel.predict(img);

  // One-hot encode the label.
  // Turns categorical data (e.g. colors) into numerical data
  const y = tf.tidy(() =>
    tf.oneHot(tf.tensor1d([index]).toInt(), labels.length)
  );

  if (xs == null) {
    // For the first example that gets added, keep example and y so that we own the memory of the inputs.
    xs = tf.keep(example);
    ys = tf.keep(y);
  } else {
    const oldX = xs;
    xs = tf.keep(oldX.concat(example, 0));

    const oldY = ys;
    ys = tf.keep(oldY.concat(y, 0));

    oldX.dispose();
    oldY.dispose();
    y.dispose();
  }

  img.dispose();
};

const getImage = async () => {
  const img = await webcam.capture();
  const processedImg = tf.tidy(() =>
    img.expandDims(0).toFloat().div(127).sub(1)
  );
  img.dispose();
  return processedImg;
};
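The while (mouseDown) loop keeps capturing frames until the button is released, so a matching mouseup handler is assumed elsewhere in part2.js; a minimal sketch:

// Stop recording examples when the mouse button is released
buttonsContainer.onmouseup = () => {
  mouseDown = false;
};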
Train
After gathering examples, we need to create a new model and train it with the recorded samples.
const train = async () => {
  isTraining = true;

  if (!xs) {
    throw new Error("Add some examples before training!");
  }

  newModel = tf.sequential({
    layers: [
      tf.layers.flatten({
        inputShape: initialModel.outputs[0].shape.slice(1),
      }),
      // units is the output shape of the dense layer
      tf.layers.dense({
        units: denseUnits,
        activation: "relu",
        kernelInitializer: "varianceScaling",
        useBias: true,
      }),
      // The neural network should have 2 outputs so the last layer should have 2 units
      tf.layers.dense({
        units: labels.length,
        kernelInitializer: "varianceScaling",
        useBias: false,
        activation: "softmax",
      }),
    ],
  });

  const optimizer = tf.train.adam(learningRate);
  newModel.compile({ optimizer: optimizer, loss: "categoricalCrossentropy" });

  const batchSize = Math.floor(xs.shape[0] * batchSizeFraction);
  if (!(batchSize > 0)) {
    throw new Error(
      `Batch size is 0 or NaN. Please choose a non-zero fraction.`
    );
  }

  // Wait for training to finish before clearing the flag
  await newModel.fit(xs, ys, {
    batchSize,
    epochs: epochs,
    callbacks: {
      onBatchEnd: async (batch, logs) => {
        statusElement.innerHTML = "Loss: " + logs.loss.toFixed(5);
      },
    },
  });

  isTraining = false;
};
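The training code references a few hyperparameters (learningRate, batchSizeFraction, denseUnits and epochs) defined elsewhere in part2.js; the values below are assumptions to make the snippet self-contained:

// Example hyperparameter values (assumptions; tune them for your own data)
const learningRate = 0.0001; // step size for the Adam optimizer
const batchSizeFraction = 0.4; // fraction of collected examples used per batch
const denseUnits = 100; // size of the hidden dense layer
const epochs = 20; // number of passes over the training data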
Predict
After the new model has been trained, we can run predictions with live input from the webcam.
predictButton.onclick = async () => {
  isPredicting = true;

  while (isPredicting) {
    const img = await getImage();

    // Run the webcam frame through MobileNet, then through the newly trained model
    const embeddings = initialModel.predict(img);
    const predictions = newModel.predict(embeddings);

    // Index of the class with the highest probability
    const predictedClass = predictions.as1D().argMax();
    const classId = (await predictedClass.data())[0];

    img.dispose();
    console.log(labels[classId]);

    await tf.nextFrame();
  }
};
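The loop runs until isPredicting is set back to false; here is a minimal sketch of a stop handler, assuming a stop button exists in index.html (the embeddings, predictions and predictedClass tensors could also be disposed on each iteration to avoid leaking memory):

// Hypothetical: stop the prediction loop when a stop button is clicked
const stopButton = document.getElementById("stop"); // assumed element id
stopButton.onclick = () => {
  isPredicting = false;
};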
That's it!