Creating an interactive artificial intelligence experience in the browser with TensorflowJS

Alessandro Cauduro
May 23, 2019

We can now run AI in the browser. Let's explore what's possible!

You can view the result of this project in the online demo.

TensorFlow is the most popular artificial intelligence framework at the moment. It was created by Google and released as an open source project four years ago. At this year's Dev Summit, Google announced version 1.0 of TensorFlowJS, a version of the framework written in JavaScript. This makes it easy to run deep learning models in the browser and opens up a whole new world of possibilities.

At the same event, the alpha of TensorFlow 2.0, the main framework, was also announced. This is a major rewrite that incorporates best practices from other frameworks, such as eager execution, and consolidates its APIs, which previously offered many ways of doing the same thing. When I heard about all of this, I was eager (pun intended) to try them out!

TensorFlowJS can run efficiently in the browser because it taps into the power of the computer's graphics processing unit (GPU) through a clever use of WebGL, an API originally intended for rendering 3D graphics in the browser. This accelerates AI processing directly on the client side.

This is a game changer in two major ways:

Privacy: running AI in the browser means we can process data intelligently on the user's device without uploading any of it to the cloud, which protects the user's privacy.

User Experience: anyone who has used an AI assistant such as Google Home has noticed a lag. This is because every time we say something, Google sends it to a data center in the cloud to be processed. If we process it locally in the browser, the response can be immediate. This enables new interactions, and because the processing runs on the user's device, it doesn't generate large server costs.

Quick intro about me

To quickly put things into context, I'm the CEO of Huia, a Design & Tech studio, and a tech enthusiast. My company helps businesses innovate and stand out from the competition by creating unique user experiences with newly available technologies.

The company is named after the Huia, a rare and sacred bird from New Zealand that has unfortunately been extinct for over 100 years. As a result, no video footage or sound recordings of this beautiful creature exist. So for this experiment with TensorflowJS, we decided to create an interactive Huia.

Illustration of a male and female Huia from Wikipedia.

Since our plan was to create an experience in the browser, we opted for ThreeJS, a 3D JavaScript library, and the VueJS framework. Our initial experience was driven by the mouse. Here is what our rendered 3D Huia model looks like responding to mouse movements:

Huia 3d model interacting with mouse movements

Together with the release of TensorflowJS 1.0, Google also made a few pre-trained models available. Among them was PoseNet, a model that takes in an image and estimates the position of your skeleton (key body points such as eyes, shoulders, elbows and wrists). This is what it looks like running in the browser, capturing real-time video from the computer's webcam:

Paulo and Mauricio from Huia helped out as body models for this project and had a lot of fun :-)

My first thought was that PoseNet could be a perfect way for people to interact with our Huia bird! I wasn't sure how it would perform, since we already had a fairly complex 3D model in place running real-time shader effects (if you hover the mouse over the 3D model of the bird, a wireframe shader of the Huia pops up; it is one of many subtle but cool effects in place). But there's nothing like a challenge to find the limits and learn a new technology!

After a few days of coding, I had a working prototype! Wow, TensorFlowJS is cool! The bird would match my head position and jump when both arms were lifted at the same time. Even more impressive, this was all running smoothly on my 2016 MacBook Pro, which doesn't even have a discrete GPU!

The Huia bird was triggered to jump once both arms were above a certain height threshold. To map the head position (the small red projection dot in the GIF below), I used a simple rule of three: when the left eye reached the left ear, the head was fully turned to the right. Vertically, I did the same using the nose as a reference: when the eyes were positioned below it, I considered the person to be looking downwards. As you can see from the demo below, it works pretty well! Simple but effective :-)

The red dot projects where the user is facing.
With just these coordinates and a simple rule of three, I was able to predict where the user was looking.
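The head-direction and jump rules above can be sketched in a few lines. This is a minimal Python sketch under stated assumptions, not the project's actual code (which ran in the browser): the `Keypoint` class is hypothetical, body parts use PoseNet's part names (`nose`, `leftEye`, `leftEar`, `leftWrist` and so on), y grows downwards as in image coordinates, and the jump threshold ("both wrists above the nose") and the 0.5 confidence cutoff are illustrative guesses.

```python
from dataclasses import dataclass


@dataclass
class Keypoint:
    x: float      # image column in pixels (grows to the right)
    y: float      # image row in pixels (grows downwards)
    score: float  # PoseNet confidence for this body part, 0..1


def _fraction(start: float, point: float, end: float) -> float:
    """Rule of three: how far `point` has travelled from `start` towards `end`."""
    span = end - start
    return 0.0 if abs(span) < 1e-6 else (point - start) / span


def horizontal_direction(kp: dict) -> float:
    """About 0 facing the camera, +1 fully turned right, -1 fully turned left."""
    nose = kp["nose"]
    # 1.0 means the left eye has reached the left ear (head fully turned right)
    left = _fraction(nose.x, kp["leftEye"].x, kp["leftEar"].x)
    right = _fraction(nose.x, kp["rightEye"].x, kp["rightEar"].x)
    return max(-1.0, min(1.0, left - right))


def is_looking_down(kp: dict) -> bool:
    """Vertical rule from the article: eyes below the nose means looking down."""
    eyes_y = (kp["leftEye"].y + kp["rightEye"].y) / 2
    return eyes_y > kp["nose"].y


def should_jump(kp: dict, min_score: float = 0.5) -> bool:
    """Both wrists confidently detected above the nose (hypothetical threshold)."""
    wrists = (kp["leftWrist"], kp["rightWrist"])
    if any(w.score < min_score for w in wrists):
        return False
    return all(w.y < kp["nose"].y for w in wrists)
```

Subtracting the right-side fraction from the left-side one makes a frontal face come out near 0 and keeps the mapping symmetric for turns in either direction.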

Technical Deep Dive: looking under the hood

Let's go over the steps required to train a neural network and run it in the browser. Our end goal is for the bird to interact with the user by detecting poses and reacting accordingly.

Step 1 — Capturing Data

As in any AI project, the first step is always to collect and organize the available data. This is by far the most important step!

PoseNet tries to detect up to 15 people in the video and gives a score of how “confident” it is in each prediction and in each body part. Below is a sample of what this looks like for a single pose:

To automate data capture, I created a simple program that captures individual poses, with a timer feature that took a picture of myself every few seconds. In 10 minutes I had gathered around 300 images for 4 different classes!

Looking at the available data, one approach to detecting the current pose could be to compute the Euclidean distance between these key points and the current pose’s key points. I wasn’t sure this would generalize well to different body types, so I decided to take another route.
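For comparison, here is a minimal sketch of the distance-based approach described above, which the project did not end up using. It assumes each pose is a list of (x, y) key point coordinates in a fixed body-part order; the helper names and the template labels are hypothetical. Position and overall size are normalized away before comparing, and the closest stored template wins.

```python
import numpy as np


def normalize_pose(points):
    """Remove position and overall size so only the pose's shape remains."""
    pts = np.asarray(points, dtype=float)
    centered = pts - pts.mean(axis=0)              # move the centroid to the origin
    size = np.linalg.norm(centered, axis=1).max()  # distance of the farthest key point
    return centered / size if size > 0 else centered


def nearest_template(pose, templates):
    """Return the label of the closest template and the distance to every template."""
    query = normalize_pose(pose)
    distances = {
        label: float(np.linalg.norm(query - normalize_pose(template)))
        for label, template in templates.items()
    }
    return min(distances, key=distances.get), distances
```

Even after this normalization, differences in body proportions (arm length relative to torso, for instance) still change the distance, which is the generalization concern mentioned above.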

As well as capturing these images, I also saved the JSON key point data.
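A minimal sketch of reading that key point data back in Python for training. It assumes each saved file holds one pose in PoseNet's output structure (an overall `score` plus a `keypoints` list whose entries have `part`, `score` and `position` with `x`/`y`); the project's actual file layout isn't shown, the function names are hypothetical, and the 0.5 cutoff is an arbitrary choice.

```python
import json


def parse_pose(pose, min_part_score=0.5):
    """Keep only confidently detected body parts as {part: (x, y)}."""
    return {
        kp["part"]: (kp["position"]["x"], kp["position"]["y"])
        for kp in pose["keypoints"]
        if kp["score"] >= min_part_score
    }


def load_pose_file(path, min_part_score=0.5):
    """Read one saved PoseNet pose from a JSON file."""
    with open(path, encoding="utf-8") as f:
        return parse_pose(json.load(f), min_part_score)
```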

Step 2 — Creating and Training our Model

The cool thing about computer vision is that the techniques used for image classification can be applied to seemingly unrelated fields, simply by converting the data to images and classifying those instead.

For example, you can use image classification with sound if you convert sound waves to images (spectrograms)! This is used in speech recognition.

I decided to go this route and converted our key point data into skeleton images, since in my mind this had the potential to be a good abstraction across many different people. My thought was that, with a little help from the Huia team, I would be able to sample different body types (short, tall, skinny, not so skinny, man, woman, etc.) and cover most people without needing thousands of them.

Start small, then expand.

As a proof of concept, I created a small neural network, and it was able to classify the 4 different poses with good success. This wasn’t an amazing feat, as random guessing alone would already be right 25% of the time :-).

Once I had checked that the approach worked, I made it easier for the neural network to tell the poses apart by adding color and thickness to the body parts and joints.

Color was added to help the neural network classify the images.
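A minimal NumPy sketch of turning key points into a colored, thick skeleton image. The limb list, palette, radii and the 224-pixel canvas are hypothetical choices, not the project's actual settings, and the key points are assumed to be already scaled to canvas pixel coordinates. Each limb is drawn as every pixel within `bone_radius` of the segment between two joints, and joints are drawn on top as white discs.

```python
import numpy as np

# Hypothetical palette: one RGB colour per limb (not the article's actual colours)
BONES = [
    ("leftShoulder", "leftElbow", (255, 0, 0)),
    ("leftElbow", "leftWrist", (255, 128, 0)),
    ("rightShoulder", "rightElbow", (0, 0, 255)),
    ("rightElbow", "rightWrist", (0, 128, 255)),
    ("leftShoulder", "rightShoulder", (0, 255, 0)),
    ("leftShoulder", "leftHip", (255, 255, 0)),
    ("rightShoulder", "rightHip", (0, 255, 255)),
    ("leftHip", "rightHip", (255, 0, 255)),
]
JOINT_COLOR = (255, 255, 255)


def thick_segment_mask(h, w, p0, p1, radius):
    """Pixels within `radius` of the segment p0-p1 (a disc if p0 == p1)."""
    ys, xs = np.mgrid[0:h, 0:w]
    px, py = xs - p0[0], ys - p0[1]
    dx, dy = p1[0] - p0[0], p1[1] - p0[1]
    length_sq = dx * dx + dy * dy
    if length_sq > 0:
        # Project each pixel onto the segment and clamp to its end points
        t = np.clip((px * dx + py * dy) / length_sq, 0.0, 1.0)
    else:
        t = np.zeros((h, w))
    return (px - t * dx) ** 2 + (py - t * dy) ** 2 <= radius ** 2


def render_skeleton(keypoints, size=224, bone_radius=4, joint_radius=6):
    """Draw coloured, thick limbs and then white joints on a black canvas."""
    canvas = np.zeros((size, size, 3), dtype=np.uint8)
    for start, end, color in BONES:
        if start in keypoints and end in keypoints:  # skip limbs that weren't detected
            canvas[thick_segment_mask(size, size, keypoints[start], keypoints[end], bone_radius)] = color
    for point in keypoints.values():
        canvas[thick_segment_mask(size, size, point, point, joint_radius)] = JOINT_COLOR
    return canvas
```

Giving each limb its own color means the classifier doesn't have to infer which line is the left forearm and which is the right, which is one plausible reason color helped.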

I also started capturing all the different poses that we were trying to classify. One of the poses, called “normal”, represents a resting position and doesn’t trigger any action.

Different poses we captured data for training.

At the same time, Rodolpho from Funn started work on the extra 3D animations for the Huia :-)

Spend time to get to know your data better!

I was anxious, so the first thing I did after collecting the data was start training the model, and things didn't go well at first. This made me panic: I had been certain it was going to work, and I only had a week left to finish the whole project and showcase the interactive Huia bird at a trade show!

I took a deep breath and two steps back to understand what was going wrong. First, I carefully went through the 1000 images I had collected and removed anything that I personally couldn’t assign to a specific pose or that had been labeled with the wrong pose. After all, if I can’t recognise the pose, I can’t expect the AI to do so. I also found a few blank images. All of this made it harder for our model to generalize. Of my original batch of 1000 images, only 381 were left after the review! Thankfully, after this thorough clean-up, things started working correctly!
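Most of that review had to be done by eye, but the blank images are easy to catch automatically. A minimal sketch, assuming skeletons are drawn on a black background and samples are (image, label) pairs; the function names and the 0.1% foreground threshold are hypothetical.

```python
import numpy as np


def is_blank(image, min_foreground_fraction=0.001):
    """True if (almost) no pixel differs from the black background."""
    foreground = np.any(image > 0, axis=-1) if image.ndim == 3 else image > 0
    return float(foreground.mean()) < min_foreground_fraction


def drop_blank(samples):
    """Keep only (image, label) pairs whose image actually contains a skeleton."""
    return [(image, label) for image, label in samples if not is_blank(image)]
```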

I cannot emphasise enough how much understanding your data improves your results. All the problems I had in this project were related to data quality!

Architecture & Transfer Learning

For our architecture, I decided to start from MobileNet, a proven and efficient neural network architecture used for image classification on low-power devices (this is important, as I didn't want users to need a supercomputer to run it in the browser :-) ).

The cool thing about MobileNet is that we can load a model pre-trained on the famous ImageNet dataset directly from the TensorFlow library and then extend it with some of our own layers. This is known as Transfer Learning: we start from an existing trained model instead of randomly initialised weights. It has been shown to give better results than starting from scratch, especially when we don't have much data. By the way, PoseNet also uses MobileNet at its core.

Transfer Learning is a proven technique to achieve better results when we have little data, especially in computer vision.
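A minimal Keras sketch of this transfer learning setup, assuming TensorFlow 2.x. The article doesn't say which MobileNet version or which extra layers were used, so `MobileNetV2`, the pooling, dropout and dense head, the 224x224 input and the dropout rate of 0.5 are all assumptions; `build_pose_classifier` is a hypothetical helper name.

```python
import tensorflow as tf


def build_pose_classifier(num_classes, input_shape=(224, 224, 3), weights="imagenet", dropout_rate=0.5):
    """Pre-trained MobileNetV2 backbone (frozen) plus a small classification head."""
    base = tf.keras.applications.MobileNetV2(
        input_shape=input_shape, include_top=False, weights=weights
    )
    base.trainable = False  # start by training only the new layers

    inputs = tf.keras.Input(shape=input_shape)
    # MobileNetV2 expects pixel values scaled from [0, 255] to [-1, 1]
    x = tf.keras.layers.Rescaling(1.0 / 127.5, offset=-1.0)(inputs)
    x = base(x, training=False)  # keep BatchNorm layers in inference mode
    x = tf.keras.layers.GlobalAveragePooling2D()(x)
    x = tf.keras.layers.Dropout(dropout_rate)(x)
    outputs = tf.keras.layers.Dense(num_classes, activation="softmax")(x)
    return tf.keras.Model(inputs, outputs), base
```

Returning `base` alongside the model makes it easy to unfreeze the backbone later for fine-tuning. With `weights="imagenet"`, Keras downloads the pre-trained weights the first time it is called.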

This is what our model looks like:

The table below has a more detailed description.

Image Data Augmentation

Since we only had 381 images left, that translates to a little under 50 images for each pose we want to classify. This might not be enough. So I decided to use Image Data Augmentation: we programmatically create variations of our data that are still valid and use them for training, covering positions that weren't in the original dataset.

Image Augmentation can help a model generalize better and we can get away with having less data.

Ideally, we would do data augmentation at runtime, so that new variations are created every time an image is requested. But at the time of writing, TensorFlow Addons, which has all the advanced image manipulation features, wasn't compatible with the tf.data.Dataset pipeline, which requires graph mode. So instead I pre-generated static images on the file system. For each captured image, I created 10 variations by randomly combining the following transformations:

  • random rotate
  • random crop / zoom
  • warp
  • affine
  • translate

Other augmentations that didn’t make sense for this use case but can make sense in other situations:

  • random flip horizontal / vertical
  • change saturation / brightness / compression

Data Augmentation can help generate more data and create a better model.

After data augmentation, we went from 381 to 4191 images, and a much more robust model could be trained.
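The article's augmentation code isn't shown, so here is a minimal NumPy-only sketch of the same idea; the helper names and the ranges are hypothetical. `random_affine` combines a random rotation, zoom, shear and translation into one affine warp with nearest-neighbour sampling; random crop is approximated by zoom plus translation, and the non-affine "warp" transform is not reproduced. Keeping each original plus 10 variations gives 381 x 11 = 4191 images, which is consistent with the numbers above.

```python
import numpy as np


def random_affine(image, rng, max_angle_deg=15.0, scale_range=(0.9, 1.1),
                  max_shift=0.1, max_shear=0.1):
    """Randomly rotate, zoom, shear and translate an image (nearest-neighbour)."""
    h, w = image.shape[:2]
    angle = np.deg2rad(rng.uniform(-max_angle_deg, max_angle_deg))
    scale = rng.uniform(*scale_range)
    shear = rng.uniform(-max_shear, max_shear)
    tx, ty = rng.uniform(-max_shift, max_shift, size=2) * (w, h)  # shift in pixels

    rotation = np.array([[np.cos(angle), -np.sin(angle)],
                         [np.sin(angle), np.cos(angle)]])
    shearing = np.array([[1.0, shear], [0.0, 1.0]])
    inverse = np.linalg.inv(rotation @ shearing * scale)

    # For every output pixel, find the source pixel it comes from (around the centre)
    cy, cx = (h - 1) / 2, (w - 1) / 2
    ys, xs = np.mgrid[0:h, 0:w]
    offsets = np.stack([xs - cx - tx, ys - cy - ty])  # shape (2, h, w)
    src = np.einsum("ij,jhw->ihw", inverse, offsets)
    src_x = np.rint(src[0] + cx).astype(int)
    src_y = np.rint(src[1] + cy).astype(int)

    inside = (src_x >= 0) & (src_x < w) & (src_y >= 0) & (src_y < h)
    out = np.zeros_like(image)  # pixels that map from outside the image stay black
    out[inside] = image[src_y[inside], src_x[inside]]
    return out


def augment_dataset(images, variations_per_image=10, seed=0):
    """Return every original image followed by its random variations."""
    rng = np.random.default_rng(seed)
    augmented = []
    for image in images:
        augmented.append(image)
        augmented.extend(random_affine(image, rng) for _ in range(variations_per_image))
    return augmented
```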

Dropout

Our goal is always to make our model generalize to real-world use cases, so I used Dropout to force the network to spread what it learns across different neurons rather than depend too much on any single one. The idea behind dropout is simple: during training, randomly discard (zero out) some of the activations in the forward pass, so a model that converges won't depend heavily on any one parameter.
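A minimal NumPy sketch of dropout in its "inverted" form, which is the variant Keras' `Dropout` layer uses: during training each activation is zeroed with probability `rate`, and the survivors are scaled up by `1 / (1 - rate)` so the expected activation stays the same; at inference time nothing is dropped. The function name and signature are illustrative.

```python
import numpy as np


def dropout(activations, rate, training, rng):
    """Inverted dropout: zero a random `rate` fraction of activations during training."""
    if not training or rate == 0.0:
        return activations  # at inference time every neuron is kept
    keep_prob = 1.0 - rate
    mask = rng.random(activations.shape) < keep_prob
    # Scale survivors by 1 / keep_prob so the expected activation is unchanged
    return np.where(mask, activations / keep_prob, 0.0)
```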

Training

For training, I started with a first run of 20 epochs in which I only trained the layers we added on top of MobileNet, with a learning rate of 3e-4. Accuracy on the training data quickly reached 0.99, but as you can see, accuracy on the test data stopped improving after a while. Training only these layers is faster than training the whole network, which has over 25 million parameters, and is a good starting strategy.

learning rate = 3e-4

For the next step, I "unfroze" the whole model and let it train for 40 epochs with a learning rate of 3e-6. We had over 800 images in the validation set and were able to achieve an outstanding 100% accuracy on both sets.

learning rate = 3e-6

I initially planned to plot a confusion matrix to explore where the network was making mistakes, but there weren't any, so there was no point :-)
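The two training stages described above could look roughly like this in Keras (TensorFlow 2.x). The epoch counts and learning rates come from the article; the Adam optimizer, the sparse categorical cross-entropy loss (integer class labels) and the helper name `train_in_two_stages` are assumptions. `base` is the pre-trained backbone nested inside `model`.

```python
import tensorflow as tf


def train_in_two_stages(model, base, train_data, val_data, head_epochs=20, finetune_epochs=40):
    """Stage 1: train only the new head. Stage 2: unfreeze and fine-tune everything."""
    # Stage 1: frozen backbone, learning rate 3e-4
    base.trainable = False
    model.compile(
        optimizer=tf.keras.optimizers.Adam(learning_rate=3e-4),
        loss="sparse_categorical_crossentropy",
        metrics=["accuracy"],
    )
    head_history = model.fit(train_data, validation_data=val_data, epochs=head_epochs)

    # Stage 2: unfreeze the whole model with a 100x smaller learning rate
    base.trainable = True
    model.compile(  # recompile so the newly trainable weights are picked up
        optimizer=tf.keras.optimizers.Adam(learning_rate=3e-6),
        loss="sparse_categorical_crossentropy",
        metrics=["accuracy"],
    )
    finetune_history = model.fit(train_data, validation_data=val_data, epochs=finetune_epochs)
    return head_history, finetune_history
```

Recompiling after changing `trainable` matters: Keras only takes the new set of trainable weights into account when the model is compiled again. The much smaller learning rate in the second stage keeps fine-tuning from wiping out the pre-trained features.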

Quantization

Our model was trained using 32-bit floating point numbers, which means 4 bytes per parameter. So when we save it, it takes up a whopping 62MB! Google found that quantizing (downsampling) the model weights to 16-bit integers doesn't lose much accuracy. Doing this, we can reduce the model size to 10MB, which is still big, but much better!

Quantization example graph from Wikipedia.
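To make the idea concrete, here is a minimal NumPy sketch of min/max (affine) quantization to 16-bit integers. It is a simplified illustration, not the TensorFlowJS converter's exact algorithm, and the function names are hypothetical. Each tensor keeps its minimum and a scale factor, and every weight is rounded to one of 65,536 evenly spaced levels, so the worst-case rounding error is half a step and the stored integers take half the bytes of the float32 original.

```python
import numpy as np


def quantize_uint16(weights):
    """Affine (min/max) quantization of a float tensor to 16-bit unsigned ints."""
    w = weights.astype(np.float64)
    w_min, w_max = float(w.min()), float(w.max())
    levels = np.iinfo(np.uint16).max  # 65535 steps between min and max
    scale = (w_max - w_min) / levels if w_max > w_min else 1.0
    q = np.clip(np.round((w - w_min) / scale), 0, levels).astype(np.uint16)
    return q, scale, w_min


def dequantize(q, scale, w_min):
    """Map the stored integers back to approximate float32 weights."""
    return (q.astype(np.float64) * scale + w_min).astype(np.float32)
```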

The TensorFlowJS converter has this feature built in, and the quantized weights are converted back automatically when the model is loaded at runtime, so to quantize our model after training we can just run:

import numpy as np, tensorflowjs as tfjs
tfjs.converters.save_keras_model(model, "models_tfjs/huia_model", quantization_dtype=np.uint16)  # model: trained Keras model

Or via the command line:

tensorflowjs_converter \
--input_format=keras \
--output_format=tfjs_layers_model \
./models/huia_model.h5 \
./models_tfjs/tfjs_model \
--quantization_bytes 2

Step 3 — Putting everything together + having fun

This is the code needed to load our model in the browser:

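import * as tf from '@tensorflow/tfjs';
// MODEL_HUIA_URL: URL of the converted model.json (top-level await needs an ES module or an async function)
const customMobilenet = await tf.loadLayersModel(MODEL_HUIA_URL);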

You can check out our online demo here

The first Huia Hadouken in the world!

Conclusions

I really enjoyed seeing where TensorFlow 2.0 and TensorflowJS are heading. I did hit some roadblocks, which was expected, as TF2 is still "alpha" software; in the coming months we can expect a preview release that should be more stable. Artificial intelligence in the browser is enabling exciting new experiences and machine learning solutions that didn't seem possible before!

Source Code

Want to look under the hood and see how it was done? Then check out the source code on GitHub. The project is split into 3 modules, following the structure of the article above:

  • 01_image_capture (Node + Tensorflow JS)
  • 02_training (Jupyter Notebooks + Tensorflow 2)
  • 03_interactive_experience (Node + ThreeJS + VueJS + TensorflowJS)

If it inspires you to create something yourself, let me know, as I'm curious to see what people will create with TensorflowJS.

Backlog

Here are some things I would try out if I had more time to investigate:

  • Compare image classification with a Siamese network on our data points / embeddings.
  • Use face landmark detection to sync the person's expressions to the bird.
  • Implement a more robust face pose estimation.
  • Implement deep video classification to recognise user activity, not just user pose.
  • Try to compress the model and create a smaller neural network that achieves similar results. We are running in the browser, so the smaller the better :-)

This project was made possible with the help from:

Art Direction & Design: Caio Ogata and Carolina Dutra Silveira
Original interactive Huia coding: Luiz Sordi
3D Modeling and Animations: Funn
Docker and Site Infrastructure: Mauricio Klagenberg
Body Data used for pose training:
Alessandro Cauduro, Carolina Dutra Silveira, André Gonçalves, Mauricio Klagenberg and Paulo Araújo.
