Shrink your Tensorflow.js Web Model Size with Weight Quantization



A quick and simple Guide to Weight Quantization with Tensorflow.js

Thanks to the awesome work of the Tensorflow.js team, machine learning in the browser has become a very hot topic amongst us javascript and web developers. Tensorflow.js allows us to ship pretrained models down to the client's browser and to run inference directly on the client side.

Luckily, a client technically has to download a model only once: such a web model is typically chunked into 4MB shards, which the browser will cache. Still, we want the initial loading time to be as short as possible and to reduce the number of bytes a client has to store for any of our models.

Why you should definitely quantize your Model Weights!

Simple answer: We would rather download a 15MB model than a 60MB one, right? This is a no-brainer! Yes, we can reduce the size of a model by a factor of 4 and it’s basically for free! I am using this technique for all the models exposed by face-api.js.

Simply put, with weight quantization we compress our model parameters from Float32s (4 bytes each) to Uint8s (a single byte each) by mapping each tensor’s values from their range [min value, max value] to the integer range [0, 255]. To do so, we subtract the tensor’s minimum value and divide by a scaling factor. We store the min value and the scale for each tensor along with our model meta data, and once we load the model weights again we apply the inverse operation (dequantization).
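In code, the mapping and its inverse boil down to something like this (a minimal sketch of the idea, not the tfjs-converter implementation):

// Quantize float32 values to uint8: q = round((v - min) / scale), with scale = (max - min) / 255
function quantize(values) {
  const min = values.reduce((a, b) => Math.min(a, b))
  const max = values.reduce((a, b) => Math.max(a, b))
  const scale = (max - min) / 255 || 1 // guard against a constant tensor (max === min)
  const quantized = Uint8Array.from(values, v => Math.round((v - min) / scale))
  return { quantized, min, scale }
}

// Dequantize: v ≈ q * scale + min, reversing the mapping above
function dequantize({ quantized, min, scale }) {
  return Float32Array.from(quantized, q => q * scale + min)
}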

What about the Accuracy of the Model?

Now you might wonder: Mapping tensor values from a 32 bit representation down to 8 bits? There must be some serious loss in accuracy of my model going on, right? Well, not necessarily. More precisely, tensor values are rounded during the process, such that they end up with a lower precision (the worst-case rounding error per value is half the scale, i.e. (max value − min value) / 510), but in my experience, for most cases, the overall model accuracy is not affected by weight quantization at all.

To be fair, there are some exceptions, when some tensors have an unfortunate value distribution; I have faced this only once so far, with the face recognition model I ported from dlib to tfjs. But even in those cases it is still possible to reduce the size of our model weights. We can identify these tensors and simply leave their weights untouched, while still quantizing the remaining tensors (if you are running into this problem, see the last section).

Ok I am convinced! But how?

Good question! Now this can be either very simple or just simple. First of all, we will discuss the very simple case, which is that you already have a tensorflow or keras model. You can simply use the tfjs-converter tool and pass the quantization flag to the CLI when converting your model to a web model:

tensorflowjs_converter --quantization_bytes 1 --input_format=tf_frozen_model --output_node_names=logits/BiasAdd --saved_model_tags=serve ./model/input_graph.pb ./web_model

Obviously, this is the easiest and the preferred way to quantize your model weights. So if you can, I would recommend using this tool.

But what to do if:

  1. one or more of our tensors are of the nasty kind that messes with the accuracy of our model, as pointed out in the previous section?
  2. we have ported some existing model architecture (caffe, torch, darknet, whatever…) directly to tfjs?
  3. we trained our model with tfjs in the browser?

Short Answer:

We simply run our tensors through the tfjs-converter quantization script.

Long Answer:

Well, obviously this is a python script, hidden somewhere in the depths of the tfjs-converter source code. We would have to store all of our tensors’ data, run them through the script and build the weights_manifest.json file. Of course, you could certainly do that. Alternatively, you could just do everything in javascript:

Using the javascript quantization implementation, we can then copy and paste the following into an html page, add the logic for loading the weight tensors of the model we want to shrink, open it in the browser and call quantizeAndSave() from the console, which will download the quantized model shards and the weights_manifest.json file:
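A minimal sketch of what such a script could look like is shown below. It is only an illustration under a few assumptions: the page also loads tfjs and FileSaver.js (global saveAs), all weights end up in a single shard rather than 4MB chunks, and any helper names besides getNamedTensors and quantizeAndSave are made up:

// Return { name, tensor, isSkipQuantization } entries for the model you want to shrink.
async function getNamedTensors() {
  // TODO: load your weight tensors here
  return []
}

// Map a tensor's float32 values to uint8 and keep min/scale for dequantization.
function quantizeTensor(tensor) {
  const values = tensor.dataSync()
  let min = Infinity
  let max = -Infinity
  for (const v of values) {
    if (v < min) min = v
    if (v > max) max = v
  }
  const scale = (max - min) / 255 || 1 // guard against constant tensors
  const data = Uint8Array.from(values, v => Math.round((v - min) / scale))
  return { data, quantization: { dtype: 'uint8', scale, min } }
}

async function quantizeAndSave(shardName = 'my-model-shard1') {
  const namedTensors = await getNamedTensors()

  const weights = []
  const buffers = []
  for (const { name, tensor, isSkipQuantization } of namedTensors) {
    const entry = { name, shape: tensor.shape, dtype: 'float32' }
    if (isSkipQuantization) {
      // keep the raw float32 bytes for tensors that do not tolerate quantization
      const values = tensor.dataSync()
      buffers.push(new Uint8Array(values.buffer, values.byteOffset, values.byteLength))
    } else {
      const { data, quantization } = quantizeTensor(tensor)
      entry.quantization = quantization
      buffers.push(data)
    }
    weights.push(entry)
  }

  // concatenate all tensor data into one binary shard
  const totalBytes = buffers.reduce((sum, buf) => sum + buf.byteLength, 0)
  const shard = new Uint8Array(totalBytes)
  let offset = 0
  for (const buf of buffers) {
    shard.set(buf, offset)
    offset += buf.byteLength
  }

  // weights_manifest.json: an array of groups, each with its shard paths and weight entries
  const manifest = [{ paths: [shardName], weights }]
  saveAs(new Blob([shard]), shardName)
  saveAs(
    new Blob([JSON.stringify(manifest)], { type: 'application/json' }),
    'my-model-weights_manifest.json'
  )
}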

Note the empty function body of “getNamedTensors”. Here you should implement your own logic, which returns { name: string, tensor: tf.Tensor } pairs. If we take a look into a weights_manifest.json file, we can see that every tensor is named:

{
  "name": "conv0/filter",
  "shape": [3, 3, 3, 32],
  "dtype": "float32",
  "quantization": {
    "dtype": "uint8",
    "scale": 0.004699238725737029,
    "min": -0.7471789573921876
  }
}

Furthermore, we are using FileSaver.js to download the files. Simply get this script from its repo or by running npm i file-saver.

Afterwards we can simply load our compressed model using tf.io.loadWeights(manifest, modelBaseUri) like this:

const manifest = await (await fetch('/models/my-model-weights_manifest.json')).json()
const tensorMap = await tf.io.loadWeights(manifest, '/models')

This will apply dequantization for each tensor with a “quantization” object in its manifest entry. Finally, it returns the named tensor map and that’s it:

{
  "conv0/filter": tf.Tensor4D,
  ...
}

So, what if the Accuracy of my Model indeed drops?

If you do not notice any loss in accuracy of your model after weight quantization, everything is fine. But how do we identify the culprits, in case accuracy did indeed suffer from quantization?

Unfortunately, this is a bit of a hairy situation. In this case, you should iteratively try to exclude tensors from quantization (or iteratively include tensors, for that matter), to determine the tensors which cause the model’s accuracy to drop after their values have been quantized.

To skip quantization of a tensor, simply make sure to set the “isSkipQuantization“ flag for the corresponding tensor in getNamedTensors.
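For example, sticking with the sketch from above, getNamedTensors could flag the affected tensors by name (loadMyModelTensors and the tensor names here are made-up placeholders for your own loading logic):

// exclude specific tensors from quantization by name
const SKIP = new Set(['fc/embedding/weights', 'fc/embedding/bias'])

async function getNamedTensors() {
  const namedTensors = await loadMyModelTensors() // your own loading logic
  return namedTensors.map(({ name, tensor }) => ({
    name,
    tensor,
    isSkipQuantization: SKIP.has(name)
  }))
}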


If you liked this article you are invited to leave some claps and follow me on medium and/or twitter :). Also stay tuned for further articles and if you are interested, check out my open source work!
