Train on Google Colab and Run on the Browser: A Case Study

We will create a simple tool that recognizes drawings and outputs the name of the current drawing. The app runs directly in the browser without any installation. We will use Google Colab to train the model and TensorFlow.js, the JavaScript version of TensorFlow, to run it.

Code and Demo

The live demo and the code are available on GitHub. Also make sure to try the notebook on Google Colab.


We will use a CNN to recognize drawings of different types. The CNN will be trained on the Quick Draw dataset, which contains around 50 million drawings across 345 classes.

A subset of the classes


We will train the model for free on a GPU on Google Colab using Keras, then run it directly in the browser using TensorFlow.js (tfjs). I created a tutorial on tfjs; make sure to read it before continuing. Here is the pipeline of the project:

The pipeline

Train on Colab

Google Colab provides free GPU processing power. See this tutorial on how to create a notebook and enable the GPU.

Load the Data

Since memory is limited, we will not train on all the classes; we only use 100 classes of the dataset. The data for each class is available on Google Cloud as a NumPy array of shape [N, 784], where N is the number of images for that class and each row is a flattened 28x28 grayscale bitmap. We first download the dataset.
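The download step can be sketched as follows. The sketch assumes the per-class files live in the dataset's public bucket under quickdraw_dataset/full/numpy_bitmap/<class name>.npy; please check that path against the Quick Draw README. The helper names and the example class list are hypothetical, because the article's list of 100 classes is not reproduced here.

```python
import os
import urllib.parse
import urllib.request

# Assumed public location of the per-class .npy bitmap files (check the dataset README).
BASE_URL = "https://storage.googleapis.com/quickdraw_dataset/full/numpy_bitmap/"


def class_url(name: str) -> str:
    """Build the download URL for one class; spaces in the name become %20."""
    return BASE_URL + urllib.parse.quote(name) + ".npy"


def download_classes(names, out_dir="data"):
    """Download each class file once into out_dir and return the local paths."""
    os.makedirs(out_dir, exist_ok=True)
    paths = []
    for name in names:
        path = os.path.join(out_dir, name + ".npy")
        if not os.path.exists(path):  # skip files already fetched in this session
            urllib.request.urlretrieve(class_url(name), path)
        paths.append(path)
    return paths


# Usage (hypothetical class list; the article uses 100 classes):
# paths = download_classes(["cat", "hot air balloon", "sun"])
```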

Because memory is limited, we load only 5000 images per class into memory. We also hold out 20% of the data for testing, so it stays unseen during training.
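Here is a minimal NumPy sketch that loads at most 5000 images per class and holds out 20% for testing. Memory-mapping the .npy files (mmap_mode="r") means only the sliced rows are actually read. The function names, the fixed shuffle seed and the choice to shuffle before splitting are assumptions of this sketch, not details taken from the article.

```python
import numpy as np


def load_class_arrays(paths):
    """Memory-map each [N, 784] .npy file so that only the rows we slice are read."""
    return [np.load(p, mmap_mode="r") for p in paths]


def build_dataset(class_arrays, max_items_per_class=5000, test_ratio=0.2, seed=0):
    """Keep at most max_items_per_class rows per class, shuffle, then split off a test set.

    The label of each class is its position in class_arrays.
    """
    xs, ys = [], []
    for label, arr in enumerate(class_arrays):
        arr = np.asarray(arr[:max_items_per_class])
        xs.append(arr)
        ys.append(np.full(len(arr), label, dtype=np.int64))
    x = np.concatenate(xs, axis=0)
    y = np.concatenate(ys, axis=0)

    # Shuffle so the held-out split mixes all classes instead of taking whole classes from the end.
    perm = np.random.default_rng(seed).permutation(len(x))
    x, y = x[perm], y[perm]

    n_test = int(len(x) * test_ratio)
    return x[n_test:], y[n_test:], x[:n_test], y[:n_test]
```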

Preprocess the Data

We preprocess the data to prepare it for training. The model takes batches of shape [N, 28, 28, 1] and outputs probabilities of shape [N, 100], one per class.
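The reshaping and label encoding can be sketched in NumPy as below. The sketch assumes integer class labels and divides by 255 to bring pixels into [0, 1]. That is a common choice, but the exact normalization is not stated above. Keras's keras.utils.to_categorical would produce the same one-hot matrix as the np.eye indexing used here.

```python
import numpy as np


def prepare_batches(x, y, num_classes=100):
    """Reshape [N, 784] bitmaps to [N, 28, 28, 1] floats in [0, 1] and one-hot the labels."""
    x = x.reshape(-1, 28, 28, 1).astype(np.float32) / 255.0
    y = np.eye(num_classes, dtype=np.float32)[y]  # row i of the identity = one-hot of class i
    return x, y
```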

Create the Model

We will create a simple CNN. Here a smaller model with fewer parameters is preferable: after conversion the model runs in the browser, where we want predictions to be fast. Our model contains 3 conv layers and 2 dense layers.
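To make the size argument concrete, here is a worked parameter count for a hypothetical network with 3 conv layers and 2 dense layers. It assumes 3x3 kernels with 16, 32 and 64 filters, 'same' padding, 2x2 max pooling after each conv, a 128-unit hidden dense layer and a 100-way output. These layer sizes are assumptions for illustration and are not necessarily the article's exact model.

```python
def conv2d_params(kernel, in_ch, out_ch):
    """Weights (kernel x kernel x in_ch x out_ch) plus one bias per filter."""
    return kernel * kernel * in_ch * out_ch + out_ch


def dense_params(in_units, out_units):
    """Weight matrix plus one bias per output unit."""
    return in_units * out_units + out_units


def count_params(input_hw=28, in_ch=1, conv_filters=(16, 32, 64),
                 hidden_units=(128,), num_classes=100, kernel=3):
    total = 0
    hw, ch = input_hw, in_ch
    for filters in conv_filters:
        total += conv2d_params(kernel, ch, filters)  # 'same' padding keeps hw unchanged
        ch = filters
        hw //= 2                                     # 2x2 max pooling halves hw (floor)
    units = hw * hw * ch                             # size of the flattened feature map
    for out_units in (*hidden_units, num_classes):
        total += dense_params(units, out_units)
        units = out_units
    return total
```

Under these assumptions the model has 110,052 parameters, or about 440 KB of float32 weights, which is small enough to download quickly with a web page. Most of them (73,856) sit in the first dense layer. Shrinking the flattened feature map or the hidden layer is therefore the easiest way to cut the download size.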

Fit, Validate and Test

We then train the model for 5 epochs with a batch size of 256.
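As a quick sanity check on what these settings mean, the arithmetic below counts optimizer steps. It uses only numbers from the text: 100 classes, 5000 images per class, 20% held out, a batch size of 256 and 5 epochs. The optional val_ratio is a hypothetical extra validation split, which would shrink the number of images actually fitted.

```python
import math


def training_steps(num_classes=100, per_class=5000, test_ratio=0.2,
                   batch_size=256, epochs=5, val_ratio=0.0):
    """Return (images fitted per epoch, batches per epoch, total weight updates)."""
    n_total = num_classes * per_class
    n_train = n_total - int(n_total * test_ratio)      # 20% held out for testing
    n_fit = n_train - int(n_train * val_ratio)         # optional validation split
    steps = math.ceil(n_fit / batch_size)              # last batch may be partial
    return n_fit, steps, steps * epochs
```

With no validation split this gives 400,000 training images, 1,563 batches per epoch (the last one partial) and 7,815 weight updates in total.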

Here are the results of the training:

The top-5 accuracy on the test set is 92.78%.
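Top-5 accuracy counts a prediction as correct when the true class is among the five highest-scoring classes. Below is a NumPy sketch of the metric. Keras also provides a comparable built-in metric, top_k_categorical_accuracy.

```python
import numpy as np


def top_k_accuracy(probs, labels, k=5):
    """Fraction of rows whose true label is among the k highest-probability classes."""
    top_k = np.argsort(probs, axis=1)[:, -k:]           # indices of the k largest scores per row
    hits = np.any(top_k == labels[:, None], axis=1)     # is the true label among them?
    return float(hits.mean())
```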

Prepare the Model for the Web

Once we are satisfied with the model's accuracy, we save it so that we can convert it.

We install the tensorflowjs package for the conversion:

Then we convert the model:

This creates a set of binary weight files and a JSON file, model.json, which contains the architecture of the model.
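Because model.json refers to its weight shards by relative path, all of these files must be copied together. The sketch below lists the files a converted model needs and reports any that are missing. It assumes the standard TensorFlow.js layers-model layout, in which model.json has a weightsManifest whose entries list shard file names under paths. The function name is hypothetical.

```python
import json
import os


def list_model_files(model_dir="model"):
    """Return (files the model needs, files missing from model_dir)."""
    with open(os.path.join(model_dir, "model.json")) as f:
        manifest = json.load(f)
    files = ["model.json"]
    # Each weight group names its binary shard files, relative to model.json.
    for group in manifest.get("weightsManifest", []):
        files.extend(group["paths"])
    missing = [p for p in files if not os.path.isfile(os.path.join(model_dir, p))]
    return files, missing
```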

We zip the model to prepare it for download to our local machine:

Finally, we download the model:
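Zipping and downloading can be done from Python, as sketched below. The helper names are hypothetical. shutil.make_archive is part of the standard library, while google.colab.files.download is only importable inside Colab, so the sketch falls back to printing the archive path elsewhere. It also assumes the converter wrote its output, including model.json, to a folder named model.

```python
import os
import shutil


def package_model(model_dir="model", archive_base="model"):
    """Zip the converted model folder into <archive_base>.zip and return its path.

    Files are stored at the archive root, so unzip them into a folder (e.g. model) later.
    """
    # Sanity check (an addition of this sketch): refuse to zip a folder without model.json.
    if not os.path.isfile(os.path.join(model_dir, "model.json")):
        raise FileNotFoundError(f"no model.json in {model_dir!r}")
    return shutil.make_archive(archive_base, "zip", root_dir=model_dir)


def download_in_colab(path):
    """Trigger a browser download when running in Colab; otherwise just report the path."""
    try:
        from google.colab import files  # only available inside Colab
    except ImportError:
        print(f"Not running in Colab; the archive is at {path}")
        return
    files.download(path)


# Usage inside the notebook:
# download_in_colab(package_model("model", "model"))
```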

Inference in the Browser

In this section we show how to load the model and run inference. Save the model files (model.json and the weight files) to a folder, say model.

To use tfjs, first include the following script:

Now we can load the model:

We also load the class names to display to the user:

The Drawing Canvas

For drawing on the canvas we use the Fabric.js library.

When the user draws on the canvas, we use event listeners to capture the drawing coordinates. This matters because we use the coordinates to crop the minimum bounding box around the drawing. Without the crop, a small drawing would leave most of the input image as empty canvas.
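The bounding-box bookkeeping is language-agnostic, so it is sketched here in Python for consistency with the training code. In the app it would live in JavaScript, updated from a Fabric.js mouse:move listener. The class name and the optional pad margin are assumptions of this sketch.

```python
import math


class BoundingBox:
    """Smallest axis-aligned box containing every point drawn so far (pixel coordinates)."""

    def __init__(self):
        self.reset()

    def reset(self):
        # Called when the canvas is cleared.
        self.min_x = self.min_y = math.inf
        self.max_x = self.max_y = -math.inf

    def add_point(self, x, y):
        # In the browser this would run inside the mouse-move listener while drawing.
        self.min_x = min(self.min_x, x)
        self.min_y = min(self.min_y, y)
        self.max_x = max(self.max_x, x)
        self.max_y = max(self.max_y, y)

    def crop_box(self, width, height, pad=0):
        """Return inclusive (min_x, min_y, max_x, max_y) clamped to the canvas, or None if empty."""
        if self.min_x == math.inf:
            return None
        return (
            max(0, math.floor(self.min_x) - pad),
            max(0, math.floor(self.min_y) - pad),
            min(width - 1, math.ceil(self.max_x) + pad),
            min(height - 1, math.ceil(self.max_y) + pad),
        )
```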

Preprocessing and Inference

Before making a prediction, we preprocess the data: we take the current frame of the canvas, crop it to the bounding box of the drawing, resize it, and normalize it.
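Below is a NumPy sketch of the crop, resize and normalize steps, again written in Python rather than the app's JavaScript. It assumes a frame that is already grayscale and an inclusive (min_x, min_y, max_x, max_y) box. Nearest-neighbour resizing keeps the sketch dependency-free; TensorFlow.js offers tf.image.resizeBilinear for a smoother result. The optional invert flag covers the case where the canvas has dark strokes on a light background, assuming the dataset bitmaps store strokes as bright pixels on a dark background.

```python
import numpy as np


def preprocess_canvas(gray, box, size=28, invert=False):
    """Crop a grayscale frame to an inclusive box, resize to size x size, scale to [0, 1].

    Returns an array of shape [1, size, size, 1], i.e. a batch of one image.
    """
    min_x, min_y, max_x, max_y = box
    crop = gray[min_y:max_y + 1, min_x:max_x + 1].astype(np.float32)
    h, w = crop.shape
    # Nearest-neighbour resize: pick one source row/column for every target pixel.
    rows = np.arange(size) * h // size
    cols = np.arange(size) * w // size
    resized = crop[rows[:, None], cols[None, :]]
    if invert:
        resized = 255.0 - resized  # dark-on-light canvas -> bright-on-dark bitmap
    return (resized / 255.0).reshape(1, size, size, 1)
```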

Finally, the getFrame() method gets the current frame of the canvas, preprocesses it, and finds the top predictions.
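Finding the top predictions amounts to sorting the output probabilities and mapping the indices back to class names. Here is a minimal sketch, assuming class_names is in the same order as the training labels. The function name and the n=5 default are hypothetical.

```python
import numpy as np


def top_predictions(probs, class_names, n=5):
    """Return the n most probable (class name, probability) pairs, best first."""
    probs = np.asarray(probs).ravel()      # accept shape [1, num_classes] or [num_classes]
    best = np.argsort(probs)[::-1][:n]     # indices sorted by descending probability
    return [(class_names[i], float(probs[i])) for i in best]
```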

Source: Deep Learning on Medium