A guide to simple text classification with BERT

In late 2018, Google open-sourced BERT, a powerful deep learning model for natural language processing. BERT can be pre-trained on a massive corpus of unlabeled data and then fine-tuned for a task for which you only have a limited amount of labeled data. This post will show you how to fine-tune BERT for a simple text classification task of your own.

How BERT works: a brief overview

BERT is pre-trained on two different tasks:

1. Masked language modeling: given a sentence with some of its words masked out, BERT is trained to predict what the missing words are.

2. Next sentence prediction: given two sentences, BERT is trained to determine whether the second one follows the first in a piece of text, or whether they are just two unrelated sentences.
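
To make those two tasks concrete, here is a purely illustrative sketch (not actual BERT code; the sentences and variable names are made up) of what the inputs and targets look like:

# Task 1: masked language modeling. Some words are hidden and BERT predicts them.
masked_sentence = "the quick brown [MASK] jumps over the lazy [MASK]"
target_words = ["fox", "dog"]

# Task 2: next sentence prediction. Does sentence_b really follow sentence_a?
sentence_a = "She walked into the cafe."
sentence_b = "She ordered a coffee."
is_next = True  # would be False for an unrelated second sentence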

The beauty of using these two tasks for pre-training is that the training sets can be obtained programmatically rather than through costly human annotation. As a result, BERT can be pre-trained on a truly massive corpus of text, in the process learning rich representations of language that are impossible to learn from small labeled datasets.

BERT’s final layers can then be fine-tuned on a task of your choosing that will benefit from the rich representations of language it learned during pre-training.

How to use BERT for text classification

Google’s documentation for BERT is generally good, but how to use BERT for a simple text classification task isn’t immediately obvious. By simple text classification task, we mean a task in which you want to classify/categorize chunks of text that are roughly a sentence to a paragraph in length. The BERT documentation shows you how to classify the relationship between pairs of sentences, but it doesn’t detail how to use BERT to label single chunks of text.

Here’s how you can do it:

1. Clone the BERT GitHub repo onto your own machine. Just open up your terminal and type this:
git clone https://github.com/google-research/bert.git

2. Download the BERT model files. These are the weights and other files that represent what BERT learned during pre-training. You’ll need to pick which pre-trained BERT weights you want. If you don’t have access to a Google TPU, you’ll want to pick one of the “base” models. Pick a “cased” or an “uncased” model depending on whether you think letter casing will be helpful for the task you’re trying to solve. Save the download into the directory where you cloned the git repository and unzip it.

Download links for the English models, as well as models for other languages, can be found on the BERT project GitHub page.

3. Get your data into the format BERT expects. Make a folder in the directory you cloned BERT into (the commands below assume it is called data). You’ll be putting three separate files in it, called train.tsv, dev.tsv, and test.tsv. In train.tsv and dev.tsv you should have four columns with no header row, as follows: Column 1: an ID for the row (this can be a simple count, or even the same number or letter for every row if you don’t care to keep track of individual examples). Column 2: the label for the row as an int. This is the classification label your classifier aims to predict. Column 3: a column of all the same letter. This is a throwaway column that you need to include because the BERT code expects it. Column 4: the text example you want to classify. Here is an example of what the data in train.tsv and dev.tsv should look like:

1    0    a    an example of text that should fit in class 0
2    1    a    an example of text that should fit in class 1
3    0    a    another class 0 example
4    2    a    a class 2 example

test.tsv should have a slightly different format. It has Column 1: an ID for each example, similar to column 1 in the train and dev files, and Column 2: the text you want to classify. Also, test.tsv should have a header line (whereas train and dev should not). Here is an example of what test.tsv should look like:

id    sentence
1     my first test example
2     another test example. Yay this is fun!
3     yet another test example

If you’re looking for an easy way to get data into this format, I recommend making it into a CSV file and then using the pandas Python package to convert it to a TSV. If you don’t already have a CSV file containing your data, you can make one with a tool like Google Sheets and export it as a CSV. If you do this, make sure to put your columns in the right order before you export. Here’s code that uses pandas to convert a CSV to a TSV:

import pandas as pd

# Read the CSV you exported from your spreadsheet tool
df = pd.read_csv('path/to/your/csv/here.csv')

# Write it back out tab-separated; train.tsv and dev.tsv must not have a header row
df.to_csv('path/of/your/choice.tsv', sep='\t', index=False, header=False)
# If you are creating test.tsv, set header=True instead of False

You can make your own choice as to how much of your data goes into the train, dev, and test sets, but a good rule of thumb is 80% in train and 10% each in dev and test.
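
If you want to do the split and the formatting in one go, here is a minimal sketch. It assumes you already have a pandas DataFrame with columns named text and label (hypothetical names, so adjust them to match your data) and writes train.tsv, dev.tsv, and test.tsv into the data folder from step 3:

import pandas as pd

# Assumed starting point: a DataFrame with a 'text' column and an integer 'label' column
df = pd.read_csv('path/to/your/csv/here.csv')

# Shuffle, then split 80% / 10% / 10% into train / dev / test
df = df.sample(frac=1.0, random_state=42).reset_index(drop=True)
n = len(df)
train = df.iloc[:int(0.8 * n)]
dev = df.iloc[int(0.8 * n):int(0.9 * n)]
test = df.iloc[int(0.9 * n):]

def to_bert_format(split):
    # Four columns, no header: id, integer label, throwaway letter, text
    return pd.DataFrame({
        'id': range(len(split)),
        'label': split['label'].values,
        'alpha': ['a'] * len(split),
        'text': split['text'].values,
    })

to_bert_format(train).to_csv('data/train.tsv', sep='\t', index=False, header=False)
to_bert_format(dev).to_csv('data/dev.tsv', sep='\t', index=False, header=False)

# test.tsv keeps only an id and the text, and it keeps its header row
pd.DataFrame({'id': range(len(test)), 'sentence': test['text'].values}).to_csv(
    'data/test.tsv', sep='\t', index=False, header=True)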

4. Run training. Navigate to the directory you cloned BERT into, and type the following commands (or put them in a shell script and run the script).

export BERT_BASE_DIR=./path/to/weights/downloaded/in/step2
python bert/run_classifier.py \
--task_name=cola \
--do_train=true \
--do_eval=true \
--data_dir=./data \
--vocab_file=$BERT_BASE_DIR/vocab.txt \
--bert_config_file=$BERT_BASE_DIR/bert_config.json \
--init_checkpoint=$BERT_BASE_DIR/bert_model.ckpt \
--max_seq_length=128 \
--train_batch_size=32 \
--learning_rate=2e-5 \
--num_train_epochs=3.0 \
--output_dir=./bert_output/

If you get an out-of-memory error, you may need to run this on a machine with a GPU that has more on-board memory, or on a TPU (see the instructions for TPUs in the BERT GitHub repo). You can also try reducing train_batch_size, though training will run slower as a result. If your typical text is longer than 128 word-piece tokens, you can increase max_seq_length up to a maximum of 512, though the model will run slower and use more memory if you do.

Training can take quite a while, so this step may take some time. You should see progress output as it runs.

Once it’s finished running, you’ll find reports in the bert_output directory on how the model did.
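
For example, the evaluation metrics from the --do_eval pass end up in a plain text file in that directory. In the google-research/bert version of run_classifier.py the file is named eval_results.txt (adjust the path if your copy differs), so you can inspect it with something as simple as:

# Print the evaluation report (eval_accuracy, eval_loss, etc.) written after --do_eval
with open('./bert_output/eval_results.txt') as f:
    print(f.read())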

Using BERT to predict on new data

If you want to run prediction on new data, you can put that data into test.tsv in the same format we used in step 3 above. Then go into the bert_output directory and note the number of the highest-numbered model.ckpt file you see there; this set of files contains the weights for the model you trained (a sketch for finding that number programmatically appears after the commands). Once you’ve determined the highest checkpoint number, run the following commands in the terminal or through a shell script:

export BERT_BASE_DIR=./path/to/weights/downloaded/in/step2
export TRAINED_CLASSIFIER=./bert_output/model.ckpt-[highest checkpoint number you saw]
python bert/run_classifier.py \
--task_name=cola \
--do_predict=true \
--data_dir=./data \
--vocab_file=$BERT_BASE_DIR/vocab.txt \
--bert_config_file=$BERT_BASE_DIR/bert_config.json \
--init_checkpoint=$TRAINED_CLASSIFIER \
--max_seq_length=128 \
--output_dir=./bert_output/
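
As an aside, if you’d rather not pick out the highest checkpoint number by eyeballing the directory, a minimal sketch like this (assuming the default bert_output layout produced by the training step) will find it for you:

import glob
import os
import re

# Each checkpoint is saved as model.ckpt-<step>.index/.meta/.data-*;
# the largest <step> is the most recent one.
index_files = glob.glob('./bert_output/model.ckpt-*.index')
steps = [int(re.search(r'model\.ckpt-(\d+)', os.path.basename(p)).group(1))
         for p in index_files]
print('Highest checkpoint number:', max(steps))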

Make sure the max_seq_length parameter is set to the same value you used during training. You should now get a file in bert_output called test_results.tsv. This file has one column per class, and each row gives the model’s probability for each class for the corresponding example. The rows are in the same order as the rows of data in test.tsv.
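
To turn those per-class probabilities into predicted labels, you can take the argmax of each row. A minimal sketch, assuming (as with the standard script) that test_results.tsv is tab-separated probabilities with no header row:

import pandas as pd

# One column of probabilities per class, one row per example,
# in the same order as the rows of test.tsv
probs = pd.read_csv('./bert_output/test_results.tsv', sep='\t', header=None)

# The predicted class for each example is the column with the highest probability
predicted_labels = probs.values.argmax(axis=1)
print(predicted_labels)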

The power of BERT

BERT has greatly increased our capacity to do transfer learning in NLP, and that is an important step on the road to much more advanced NLP features. On the first problem I applied BERT to, I obtained a 66% improvement in accuracy over the best model I had tried up to that point.

If you have questions, feel free to post them in a response, and, if you found this helpful, please do give it a clap!