Data Maps: Datasets Can Be Distilled Too

The original article was published by Elior Cohen on Deep Learning on Medium.

My Take

From my point of view, I’ll start by saying that it was very interesting to discover a source of information I had never dealt with before: training dynamics. The notion of ambiguous samples, and of building models that generalize better, is basically what practical ML is all about. So although this is not as “sexy” a concept as an article about a novel, exotic architecture, I think this work, and the work that will build on it, is very important. The trend toward bigger and bigger datasets will only continue as we keep collecting more data, so the ability to build high-quality datasets will become more crucial with time. Lastly, having another tool for building models that generalize better is very relevant, and what makes it more interesting is that this tool is about the data, not only the model.

Creating Data Maps in TensorFlow 2

As I found this article interesting, I had to try it for myself. As some of my readers might already know, whenever I build a TensorFlow module that I believe would be useful across different data science applications, I keep it in my tested and documented PyPI package, tavolo.

To experiment with data maps, I created a callback in tavolo that calculates all the metrics needed to create a data map. The full notebook is available on my GitHub.
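
The callback’s source isn’t reproduced here, but the three quantities it exposes can be sketched in a few lines of NumPy. This sketch assumes the definitions from the data maps paper (Dataset Cartography, Swayamdipta et al., 2020): for every training sample, record the probability the model assigns to the gold label at the end of each epoch; confidence is the mean of those probabilities across epochs, variability is their standard deviation, and correctness is the fraction of epochs in which the sample is predicted correctly. The function name training_dynamics and its input layout are hypothetical, and this is not tavolo’s actual implementation.

```python
import numpy as np


def training_dynamics(probs_per_epoch: np.ndarray, labels: np.ndarray):
    """Hypothetical helper: data-map statistics from per-epoch predictions.

    probs_per_epoch: shape (num_epochs, num_samples, num_classes), predicted
        class probabilities for every training sample, recorded at the end of
        each epoch with samples in a fixed (unshuffled) order.
    labels: shape (num_samples,), integer gold labels.
    Returns three arrays of shape (num_samples,).
    """
    num_samples = labels.shape[0]
    # Probability assigned to the gold label: shape (num_epochs, num_samples)
    gold_probs = probs_per_epoch[:, np.arange(num_samples), labels]

    # Confidence: mean gold-label probability across epochs
    confidence = gold_probs.mean(axis=0)
    # Variability: (population) standard deviation across epochs
    variability = gold_probs.std(axis=0)
    # Correctness: fraction of epochs where the argmax prediction is the gold label
    correctness = (probs_per_epoch.argmax(axis=-1) == labels).mean(axis=0)
    return confidence, variability, correctness


# Toy example with made-up numbers: 3 epochs, 2 samples, 2 classes
probs = np.array([
    [[0.6, 0.4], [0.3, 0.7]],
    [[0.8, 0.2], [0.6, 0.4]],
    [[0.9, 0.1], [0.4, 0.6]],
])
labels = np.array([0, 1])
conf, var, corr = training_dynamics(probs, labels)
# conf ≈ [0.767, 0.567], corr ≈ [1.0, 0.667]
```

Sample 0 is always predicted correctly with rising confidence, while sample 1 flips between classes, which shows up as lower confidence and correctness.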

In my experiment, I took a DistilBERT model and applied it to the QNLI dataset.
Using the callback and calculating the training dynamics is quite simple:

import tensorflow as tf
import tavolo as tvl
# Load data
train_data = ...  # Instance of
train_data_unshuffled = ...  # Another instance of train_data, without shuffling
                             # (so predictions line up with the same samples every epoch)
model = ...  # Instance of tf.keras.Model
# Creating the callback
datamap = tvl.learning.DataMapCallback(train_data_unshuffled)
# Training: callbacks are passed to fit(), not to compile()
model.compile(...)
model.fit(train_data, callbacks=[datamap])  # add epochs and other arguments as needed
# Access training dynamics
confidence = datamap.confidence  # This is a numpy array
variability = datamap.variability  # This is a numpy array
correctness = datamap.correctness  # This is a numpy array
# Plot datamap
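
The plotting step is left empty in the snippet above. One minimal way to draw a data map is a matplotlib scatter plot with variability on the x-axis and confidence on the y-axis, colored by correctness, which is the layout used in the data maps paper. This sketch assumes confidence, variability and correctness are 1-D arrays with one value per training sample; if an array turns out to hold per-epoch values instead, average it over epochs first. plot_data_map is a hypothetical helper, not part of tavolo.

```python
import matplotlib.pyplot as plt
import numpy as np


def plot_data_map(confidence, variability, correctness, ax=None):
    """Hypothetical helper: scatter plot of a data map, one point per sample."""
    if ax is None:
        _, ax = plt.subplots(figsize=(6, 5))
    points = ax.scatter(
        np.asarray(variability),
        np.asarray(confidence),
        c=np.asarray(correctness),
        cmap="coolwarm",
        vmin=0.0,  # correctness is a fraction of epochs, so fix the color range
        vmax=1.0,
        s=8,
        alpha=0.6,
    )
    ax.set_xlabel("variability")
    ax.set_ylabel("confidence")
    ax.set_title("Data map")
    ax.figure.colorbar(points, ax=ax, label="correctness")
    return ax


# Usage with the arrays from the callback (assuming per-sample values):
# plot_data_map(confidence, variability, correctness)
# plt.show()
```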

Running this code on the QNLI dataset with the DistilBERT model resulted in the following data map (taken from the notebook):
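
Once the statistics are available, regions of the map can be turned into index subsets, for example to inspect the ambiguous samples or to retrain on them. The sketch below assumes the region heuristics described in the data maps paper: ambiguous samples have the highest variability, easy-to-learn samples the highest confidence, and hard-to-learn samples the lowest confidence. The function name select_region and the default fraction of one third are assumptions for illustration, not values taken from this experiment.

```python
import numpy as np


def select_region(confidence, variability, region, fraction=1 / 3):
    """Hypothetical helper: indices of the `fraction` of samples that best
    fit a data-map region.

    region: "ambiguous" (highest variability), "easy" (highest confidence)
            or "hard" (lowest confidence).
    """
    confidence = np.asarray(confidence)
    variability = np.asarray(variability)
    k = int(round(fraction * len(confidence)))
    if region == "ambiguous":
        order = np.argsort(-variability)  # most variable first
    elif region == "easy":
        order = np.argsort(-confidence)  # most confident first
    elif region == "hard":
        order = np.argsort(confidence)  # least confident first
    else:
        raise ValueError(f"unknown region: {region!r}")
    return order[:k]


# Toy example with made-up statistics for six samples
conf = np.array([0.95, 0.10, 0.55, 0.90, 0.20, 0.50])
var = np.array([0.02, 0.05, 0.40, 0.03, 0.04, 0.35])
print(select_region(conf, var, "ambiguous"))  # [2 5]
print(select_region(conf, var, "hard"))       # [1 4]
print(select_region(conf, var, "easy"))       # [0 3]
```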

That’s it 🙂 I hope you enjoyed reading and have learnt something new ✌️