How Much Do You (Really) Need to Train Your Model?

Original article was published by RVisco on Deep Learning on Medium


Some approaches to reduce the time and data needed for NLP Language Modeling and Text Classification.

Photo by Jordan Benton from Pexels

In my brief role as a Data Scientist and Machine Learning engineer, I’ve come across a few projects that required an exorbitant amount of time to re-train their Text Classification (TC) models on the large volume of new data coming in.

This new data did not necessarily contain all new features (labels), but a significant amount of new feature data was certainly arriving every week, and it could negatively impact the performance (accuracy) of the trained TC model. This is typically known as “model drift”, and several good articles describe it.

For this article, I will focus on outlining some helpful ways to:

  1. reduce the amount of training time, using the library
  2. reduce the amount of data needed to train your model
  3. keep your TC model’s performance consistent, even with a relatively large amount of new feature data continually coming into your data stream.

The examples below use version 1 of the library, which handles many of the underlying PyTorch algorithms, architectures, and optimization techniques behind the scenes. However, the same data-reduction approaches can be applied with any ML framework you are comfortable with.

For the purposes of this article, I’ll use the term “label(s)” to mean “text classification label(s)” or “feature(s)”.

24-hour Re-training (Yikes!!)

On one of my Text Classification projects, I wanted to see if I could reduce the model re-training time from over 24 hours to something more reasonable. The data being collected was continuously growing, and at this point was over 800MB in size.

Some observations:

  1. There was no pruning of any “old” data, so the dataset kept growing. However, weekly “snapshots” of the growing data were taken.
  2. No distinction was made between “new” labeled data and “outdated” labeled data. This meant that each time a piece of text was classified, the model’s prediction could still be based on “outdated” labels that it probably wouldn’t see anymore in the real world.
  3. Label frequency was not taken into account. For instance, while some labels had hundreds of examples, others had only one example to learn from.
  4. There was no “sampling” of the training data. It was basically all-or-nothing.

None of the above is necessarily a bad thing in and of itself. But when re-training becomes painfully slow, or the model’s accuracy starts to degrade, it may be time to address them.

One of the first things I wanted to try was to see whether there was any significant difference between training the TC model on the entire text corpus (850k records) and training it on a sampled subset of the data. I decided to try two sample sizes: 25k and 100k records.

The library to the rescue!

As mentioned earlier, I used version 1 of the library because it has proven to be one of the best training frameworks (at least for image and text classification) in terms of accuracy and ease of implementation.

Since the team behind the library has done much of the experimentation and legwork to optimize machine learning (as well as teaching ML via their free MOOC), the library significantly reduces the time needed to train your model effectively, removing a lot of the guesswork behind most neural network projects.

I also chose the library because it significantly reduced the amount of code and infrastructure setup I would otherwise have needed for a quick assessment and proof of concept (POC) for this project.

The project I was working on was running an older, incompatible version (0.7) of the library. Since version 1.0 had many more features and optimizations, I decided to rewrite the model-training portion using the new version.

This approach made it MUCH easier to test out different ideas quickly, since I did not need to create any application-specific Python code libraries. I could do everything (test random sample file sizes, train and measure the model, and test predictions) in just three (3) simple Jupyter Notebooks!
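Drawing the random subsets described above (25k and 100k records from the roughly 850k-record corpus) can be sketched with Python’s standard library. This is not the author’s notebook code: it assumes each record is an in-memory `(text, label)` tuple, and `sample_records` is a hypothetical helper name. A fixed seed makes each sample reproducible across notebook runs.

```python
import random


def sample_records(records, sample_size, seed=42):
    """Return a reproducible random subset of `records` (hypothetical helper).

    If sample_size exceeds the corpus size, every record is returned
    in shuffled order.
    """
    rng = random.Random(seed)  # local RNG: repeatable, leaves global state alone
    k = min(sample_size, len(records))
    return rng.sample(records, k)


# Toy corpus standing in for the ~850k (text, label) records.
corpus = [(f"text {i}", f"label_{i % 7}") for i in range(1000)]

# Scaled-down analogues of the 25k and 100k samples.
for size in (25, 100):
    subset = sample_records(corpus, size)
    print(size, len(subset))
```

Using a local `random.Random` instance rather than the module-level functions keeps each sample repeatable without affecting any other code that relies on the global random state.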

The Old, the New, and the Unique

I started out (in a Jupyter Notebook) by loading the “snapshots” of both the Previous Corpus of data and the latest (New) Corpus of data.

The NEW data file contained about 55,000 more records than the Previous one.
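Loading the two snapshots and comparing their sizes might look like the sketch below. The original notebook output is not reproduced here, so this assumes each snapshot is a CSV file with `text` and `label` columns; the column names, file names, and the `load_snapshot` helper are all hypothetical. Small in-memory CSVs stand in for the real files so the example runs as-is.

```python
import csv
import io


def load_snapshot(fileobj):
    """Read a CSV snapshot (assumed 'text' and 'label' columns) into tuples."""
    reader = csv.DictReader(fileobj)
    return [(row["text"], row["label"]) for row in reader]


# In practice: open("previous_snapshot.csv", newline="") etc. (hypothetical
# file names). In-memory CSVs keep this example self-contained.
previous_csv = io.StringIO(
    "text,label\nhello world,greeting\nbye now,farewell\n"
)
new_csv = io.StringIO(
    "text,label\nhello world,greeting\nbye now,farewell\nrefund please,billing\n"
)

previous = load_snapshot(previous_csv)
new = load_snapshot(new_csv)
print(f"Previous: {len(previous)}, New: {len(new)}, "
      f"difference: {len(new) - len(previous)}")
```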

I then compared the unique labels in each dataset to get a sense of how many new labels were showing up in the data stream that I might need to re-train on.
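A unique-label comparison boils down to a few set operations. The sketch assumes single-label `(text, label)` records, and `unique_label_summary` is a hypothetical name; it reports the unique-label count of each snapshot along with the labels found in only one of them.

```python
def unique_label_summary(previous, new):
    """Compare the unique labels of two snapshots of (text, label) records."""
    prev_labels = {label for _, label in previous}
    new_labels = {label for _, label in new}
    return {
        "previous_unique": len(prev_labels),
        "new_unique": len(new_labels),
        "only_in_new": sorted(new_labels - prev_labels),
        "only_in_previous": sorted(prev_labels - new_labels),
    }


previous = [("a", "billing"), ("b", "greeting"), ("c", "billing")]
new = [("a", "billing"), ("d", "refund"), ("e", "shipping"), ("f", "billing")]
print(unique_label_summary(previous, new))
```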

I then set a threshold for how many times (frequency) a label must appear in the data to be considered for sampling. I decided that a unique label should appear at least twice to be considered; otherwise, I removed those labels from the dataset.
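A frequency threshold like this one (a label must appear at least twice) can be applied with `collections.Counter`. This is a minimal sketch assuming single-label `(text, label)` records; `filter_rare_labels` is a hypothetical helper, and the label counts are computed on the same dataset being filtered.

```python
from collections import Counter


def filter_rare_labels(records, min_count=2):
    """Drop records whose label appears fewer than `min_count` times."""
    counts = Counter(label for _, label in records)
    return [(text, label) for text, label in records if counts[label] >= min_count]


records = [
    ("refund please", "billing"),
    ("charged twice", "billing"),
    ("where is my box", "shipping"),
    ("hi there", "greeting"),
    ("hello", "greeting"),
]
filtered = filter_rare_labels(records, min_count=2)
print(len(records), "->", len(filtered))  # 5 -> 4 ("shipping" appears only once)
```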

After filtering out the under-represented labels, I like to repeat the unique-label comparison from earlier. This shaved off 349 unique labels (1,322 minus 973).

It’s OK to be “selective” with your “random”

With all of the above pre-processing and filtering done, I then wanted to see the final count of unique labels in the New dataset that were not in the Previous dataset. I did this to make sure this subset of unique labels was included in my upcoming re-training and holdout (test) data.
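Identifying the labels that appear only in the New dataset, and pulling out the records that carry them, takes one set difference and a single pass over the data. The sketch assumes both datasets have already been through the frequency filtering described earlier and hold single-label `(text, label)` records; `split_by_new_labels` is a hypothetical name.

```python
def split_by_new_labels(previous, new):
    """Partition New records by whether their label is absent from Previous.

    Returns (new_only_labels, new_label_records, remaining_records).
    """
    previous_labels = {label for _, label in previous}
    new_only = {label for _, label in new} - previous_labels
    new_label_records = [r for r in new if r[1] in new_only]
    remaining = [r for r in new if r[1] not in new_only]
    return new_only, new_label_records, remaining


previous = [("a", "billing"), ("b", "greeting")]
new = [("c", "billing"), ("d", "refund"), ("e", "refund"), ("f", "greeting")]
labels, with_new, rest = split_by_new_labels(previous, new)
print(sorted(labels), len(with_new), len(rest))  # ['refund'] 2 2
```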

I separated the records with these New unique labels from the remaining data, so that my random sampling drew only from records with non-New labels. I wanted to take 25k records as my training sample, but I also wanted this training set to include ALL of the New unique labels. In addition, I set aside 500 records for my Holdout Test data.

In the end, there were 9,079 records that contained only New labels, while 16,421 records made up the remainder of the training set, for a total of 25,500.
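The “selective” sampling step (keep every record with a New label, then fill the rest of the sample at random) could be sketched as follows. The function name `selective_sample` and the seed are hypothetical, and the toy data is scaled down. With the article’s numbers, the target total would be 25,500 (which appears to be the 25k training records plus the 500 holdout records), made up of 9,079 New-label records and 16,421 randomly sampled ones.

```python
import random


def selective_sample(new_label_records, remaining_records, total_size, seed=42):
    """Keep every record with a New label, then randomly fill up to total_size.

    Raises ValueError if the New-label records alone exceed total_size.
    """
    n_fill = total_size - len(new_label_records)
    if n_fill < 0:
        raise ValueError("New-label records alone exceed the requested size")
    rng = random.Random(seed)  # seeded so the fill is reproducible
    fill = rng.sample(remaining_records, min(n_fill, len(remaining_records)))
    return new_label_records + fill


# Scaled-down toy data (9 New-label records, 100 others, target of 25).
new_label_records = [(f"new {i}", f"new_label_{i % 3}") for i in range(9)]
remaining_records = [(f"old {i}", f"old_label_{i % 5}") for i in range(100)]
sample = selective_sample(new_label_records, remaining_records, total_size=25)
print(len(sample))  # 25: all 9 New-label records plus 16 random ones
```

Because the New-label records are added first and never sampled, every one of them is guaranteed to appear, while the random fill only ever comes from the remaining records.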

500 records were set aside for the Holdout Test data (.csv) file.
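The article does not say exactly how the 500 holdout records were chosen, so the sketch below assumes a simple seeded random split of the sampled records, followed by writing the holdout set to CSV with the standard `csv` module. `split_holdout`, `write_csv`, and the `text,label` header are hypothetical, and an in-memory buffer stands in for the real file.

```python
import csv
import io
import random


def split_holdout(records, holdout_size=500, seed=42):
    """Randomly split off `holdout_size` records; return (train, holdout)."""
    shuffled = list(records)  # copy so the caller's list is not reordered
    random.Random(seed).shuffle(shuffled)
    return shuffled[holdout_size:], shuffled[:holdout_size]


def write_csv(records, fileobj):
    """Write (text, label) tuples as CSV with a header row."""
    writer = csv.writer(fileobj)
    writer.writerow(["text", "label"])
    writer.writerows(records)


records = [(f"text {i}", f"label_{i % 4}") for i in range(30)]
train, holdout = split_holdout(records, holdout_size=5)
buffer = io.StringIO()  # stand-in for open("holdout.csv", "w", newline="")
write_csv(holdout, buffer)
print(len(train), len(holdout))  # 25 5
```

A purely random split can leave a rare label on only one side; if every New label must appear in both the training and holdout data, a per-label (stratified) split would be the safer choice.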