Project 2: Use transfer learning to create a motion-controlled UI

Set up the project

The code for this project is in the exercises/project-2 folder, so change directory to it with:

cd exercises/project-2

Install the dependencies with:

npm install

Start the demo app with:

npm run watch

A browser window should open at http://localhost:1234

Step 1: Train a model online

Navigate to the Teachable Machine website (https://teachablemachine.withgoogle.com), create an image project, record a few image samples for each class, and download the trained model to your computer. The code below assumes the model files are unzipped into a my_model folder in the project.

Load the model

// Path to the model files exported from Teachable Machine
const URL = "./my_model/";

const modelURL = URL + "model.json";
const metadataURL = URL + "metadata.json";

// Load the model and its metadata (class names, etc.)
let model = await tmImage.load(modelURL, metadataURL);
// Number of classes the model was trained with
let maxPredictions = model.getTotalClasses();

Start the webcam

If you look at the index.html file, you will see a script tag importing the teachable machine library.
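
These script tags typically look something like the following (the exact versions may differ in the exercise files):

<script src="https://cdn.jsdelivr.net/npm/@tensorflow/tfjs@1.3.1/dist/tf.min.js"></script>
<script src="https://cdn.jsdelivr.net/npm/@teachablemachine/image@0.8/dist/teachablemachine-image.min.js"></script>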

This library exposes a method that makes it easy to set up the webcam feed:

webcam = new tmImage.Webcam(200, 200, true); // width, height, flip the webcam
await webcam.setup(); // request access to the webcam
await webcam.play();
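
To display the feed on the page, you also need to append the webcam's canvas element to the DOM. A minimal sketch, assuming a container element with the id "webcam-container" (an assumption about the markup):

document.getElementById("webcam-container").appendChild(webcam.canvas);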

Predict

Then, use requestAnimationFrame to continuously call a function that runs the model on the current webcam frame.

window.requestAnimationFrame(loop);

async function loop() {
  webcam.update(); // update the webcam frame
  await predict();
  window.requestAnimationFrame(loop);
}

The predict function runs the model on the webcam's canvas. The output can then be used to extract the label with the highest probability.

async function predict() {
  // predict can take in an image, video or canvas html element
  const predictions = await model.predict(webcam.canvas);

  const topPrediction = Math.max(...predictions.map((p) => p.probability));

  const topPredictionIndex = predictions.findIndex(
    (p) => p.probability === topPrediction
  );
  console.log(predictions[topPredictionIndex].className);
}

Enjoy! 🎉

Try to train your model with different input from the webcam and experiment making interactive UIs!
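
For example, the predicted label could drive a UI element instead of being logged. A minimal sketch, assuming the two classes are named "left" and "right" and that a hypothetical element with the id "box" exists in the page:

const box = document.getElementById("box"); // hypothetical element
let position = 0;

const moveBox = (className) => {
  // Move the element left or right depending on the predicted label
  position += className === "right" ? 10 : -10;
  box.style.transform = `translateX(${position}px)`;
};

You could then call moveBox with the top prediction's className at the end of the predict function.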

Step 2: Train the model in the browser

You can also run transfer learning on the MobileNet model yourself, without using Teachable Machine (see the part2.js file).

Import the packages

import * as tf from "@tensorflow/tfjs";
import * as tfd from "@tensorflow/tfjs-data";

Load the model

You can load the MobileNet model in your application and create a truncated model that returns the activations of one of its internal layers. These activations (embeddings) are used as input when training your custom model.

const loadModel = async () => {
  const mobilenet = await tf.loadLayersModel(
    "https://storage.googleapis.com/tfjs-models/tfjs/mobilenet_v1_0.25_224/model.json"
  );

  // Return a model that outputs the activations of an internal layer,
  // used as embeddings for the new classifier
  const layer = mobilenet.getLayer("conv_pw_13_relu");
  return tf.model({ inputs: mobilenet.inputs, outputs: layer.output });
};

const init = async () => {
  // Create a webcam iterator from the video element in index.html
  webcam = await tfd.webcam(document.getElementById("webcam"));

  initialModel = await loadModel();
  statusElement.style.display = "none";
  document.getElementById("controller").style.display = "block";
};

Add examples

When implementing the functionality yourself, you need to collect examples the model can be trained with. In the index.html file, there are two buttons to record images with the labels 'left' and 'right'. While the user holds down one of these buttons, frames are captured and kept in memory to train the model.

buttonsContainer.onmousedown = (e) => {
  if (e.target === recordButtons[0]) {
    // left
    handleAddExample(0);
  } else if (e.target === recordButtons[1]) {
    // right
    handleAddExample(1);
  }
};

// Stop recording examples when the button is released
buttonsContainer.onmouseup = () => {
  mouseDown = false;
};

const handleAddExample = async (labelIndex) => {
  mouseDown = true;
  const total = document.getElementById(labels[labelIndex] + "-total");

  while (mouseDown) {
    addExample(labelIndex);
    total.innerText = ++totals[labelIndex];
    // Returns a promise that resolves when a requestAnimationFrame has completed
    await tf.nextFrame();
  }
};

const addExample = async (index) => {
  let img = await getImage(); // Gets a snapshot from the webcam
  let example = initialModel.predict(img);

  // One-hot encode the label.
  // Turns categorical data (e.g. colors) into numerical data
  const y = tf.tidy(() =>
    tf.oneHot(tf.tensor1d([index]).toInt(), labels.length)
  );

  if (xs == null) {
    // For the first example that gets added, keep example and y so that we own the memory of the inputs.
    xs = tf.keep(example);
    ys = tf.keep(y);
  } else {
    const oldX = xs;
    xs = tf.keep(oldX.concat(example, 0));

    const oldY = ys;
    ys = tf.keep(oldY.concat(y, 0));

    // Dispose the intermediate tensors now that they have been concatenated
    oldX.dispose();
    oldY.dispose();
    y.dispose();
    example.dispose();
  }

  img.dispose();
};

const getImage = async () => {
  const img = await webcam.capture();
  // Add a batch dimension and normalize pixel values from [0, 255]
  // to [-1, 1], the range MobileNet expects
  const processedImg = tf.tidy(() =>
    img.expandDims(0).toFloat().div(127).sub(1)
  );
  img.dispose();
  return processedImg;
};

Train

After gathering new examples, we need to create a new model and train it on the recorded samples.

const train = async () => {
  if (!xs) {
    throw new Error("Add some examples before training!");
  }
  isTraining = true;

  newModel = tf.sequential({
    layers: [
      // Flatten the embeddings produced by the truncated MobileNet model
      tf.layers.flatten({
        inputShape: initialModel.outputs[0].shape.slice(1),
      }),
      // units is the number of neurons in the hidden dense layer
      tf.layers.dense({
        units: denseUnits,
        activation: "relu",
        kernelInitializer: "varianceScaling",
        useBias: true,
      }),
      // The neural network should have 2 outputs, so the last layer should have 2 units
      tf.layers.dense({
        units: labels.length,
        kernelInitializer: "varianceScaling",
        useBias: false,
        activation: "softmax",
      }),
    ],
  });

  const optimizer = tf.train.adam(learningRate);
  newModel.compile({ optimizer: optimizer, loss: "categoricalCrossentropy" });

  const batchSize = Math.floor(xs.shape[0] * batchSizeFraction);
  if (!(batchSize > 0)) {
    throw new Error(
      `Batch size is 0 or NaN. Please choose a non-zero fraction.`
    );
  }

  // Wait for training to finish before marking it as done
  await newModel.fit(xs, ys, {
    batchSize,
    epochs: epochs,
    callbacks: {
      onBatchEnd: async (batch, logs) => {
        statusElement.innerHTML = "Loss: " + logs.loss.toFixed(5);
      },
    },
  });

  isTraining = false;
};
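
The train function can then be triggered from the UI. A minimal sketch, assuming a hypothetical trainButton element exists in index.html:

// trainButton is a hypothetical element; adapt it to the markup in index.html
trainButton.onclick = async () => {
  await train();
  statusElement.innerHTML = "Training complete";
};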

Predict

After the new model has been trained, we can run predictions with live input from the webcam.

predictButton.onclick = async () => {
  isPredicting = true;
  while (isPredicting) {
    const img = await getImage();
    // Wrap the intermediate tensors in tidy() so they are cleaned up
    const predictedClass = tf.tidy(() => {
      const embeddings = initialModel.predict(img);
      const predictions = newModel.predict(embeddings);
      return predictions.as1D().argMax();
    });
    const classId = (await predictedClass.data())[0];
    predictedClass.dispose();
    img.dispose();

    console.log(labels[classId]);
    await tf.nextFrame();
  }
};
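
To stop the prediction loop, set isPredicting back to false. A minimal sketch, assuming a hypothetical stopButton element:

// stopButton is a hypothetical element; adapt it to the markup in index.html
stopButton.onclick = () => {
  isPredicting = false;
};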

That's it!

Additional resources