Using BERT For Classifying Documents with Long Texts

Source: Deep Learning on Medium


How to fine-tune BERT for inputs longer than a few words or sentences


BERT (Bidirectional Encoder Representations from Transformers) is a deep learning model developed by Google for NLP tasks, and it achieved state-of-the-art results on multiple natural language processing tasks. However, one of its “limitations” shows up when you have long inputs: in BERT the self-attention layer has quadratic complexity O(n²) with respect to the sequence length n (see this link), and the released pre-trained models accept at most 512 tokens. In this post I follow the main ideas of this paper to overcome this limitation when you want to use BERT on long sequences of text.

Getting Ready

For this article we will need Python, BERT, TensorFlow and Keras. If you do not have them yet, please install them.

1. The Dataset

The dataset consists of data extracted from Kaggle: consumer finance complaint narratives. The model attempts to predict which product each complaint is about, so this is a multi-class classification problem.

Let’s check our data:

The dataset has 18 columns; however, in this article we use only two of them: consumer_complaint_narrative and product.

2. Preprocessing the Data

As preprocessing:

1. Select only the rows where the column consumer_complaint_narrative is not null:

2. Because we are focusing on “long texts”, we select only the rows with more than 250 words:

3. Keep only the 2 columns we need:

4. Now let’s consolidate the product categories as proposed in this article:

We end up with 10 classes:

5. Rename the columns to text and label:

6. Encode the label column as integers:

7. Remove non-alphanumeric characters from the text:

8. Split the data into training (80%) and validation (20%) sets.
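The eight preprocessing steps above can be sketched with pandas and scikit-learn as follows. This is a minimal sketch, not the notebook's code: the CSV file name, the contents of `PRODUCT_MAP` (the real consolidation comes from the linked article) and the choice to replace punctuation with spaces rather than delete it are all assumptions. The tiny demo DataFrame stands in for the Kaggle data.

```python
import pandas as pd
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import LabelEncoder

# HYPOTHETICAL mapping: replace with the consolidation proposed in the linked article.
# Keys are original product names, values are the merged category names.
PRODUCT_MAP = {
    "Credit reporting": "Credit reporting (merged)",
}


def preprocess(df, min_words=250, product_map=PRODUCT_MAP, seed=42):
    """Apply preprocessing steps 1-8 to the raw complaints DataFrame."""
    # 1. Keep rows that actually have a narrative.
    df = df[df["consumer_complaint_narrative"].notnull()]
    # 2. Keep only long texts (more than `min_words` whitespace-separated words).
    n_words = df["consumer_complaint_narrative"].str.split().str.len()
    df = df[n_words > min_words]
    # 3. Keep only the two columns we use.
    df = df[["consumer_complaint_narrative", "product"]].copy()
    # 4. Merge related product categories; unmapped names are left unchanged.
    df["product"] = df["product"].replace(product_map)
    # 5. Rename the columns to text and label.
    df = df.rename(columns={"consumer_complaint_narrative": "text", "product": "label"})
    # 6. Encode the string labels as integers 0..n_classes-1.
    encoder = LabelEncoder()
    df["label"] = encoder.fit_transform(df["label"])
    # 7. Replace every character that is not a letter, digit or space with a space.
    df["text"] = df["text"].str.replace(r"[^a-zA-Z0-9 ]", " ", regex=True)
    # 8. Split into 80% training and 20% validation.
    train, val = train_test_split(df, test_size=0.2, random_state=seed)
    return train.reset_index(drop=True), val.reset_index(drop=True), encoder


# raw = pd.read_csv("consumer_complaints.csv")  # hypothetical file name
long_text = "the bank charged me twice, again! " * 60  # 360 words
raw = pd.DataFrame({
    "consumer_complaint_narrative": [long_text, None, "too short", long_text,
                                     long_text, long_text, long_text],
    "product": ["Credit reporting", "Mortgage", "Mortgage", "Mortgage",
                "Debt collection", "Mortgage", "Credit reporting"],
    "other_column": range(7),
})
train_df, val_df, label_encoder = preprocess(raw)
print(len(train_df), len(val_df))  # 4 1
```

The null row and the 2-word row are dropped, leaving 5 documents, which the 80/20 split turns into 4 training rows and 1 validation row.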

3. Format the data for BERT model

In this article, as the paper suggests, we segment the input into smaller texts and feed each of them into BERT. This means that for each row we split the text into smaller pieces (200 words long each), for example:

From a given long text:

We split it into chunks of 200 words each, with 50 words of overlap between consecutive chunks, for example:

So we need a function to split our text as explained before, and then apply it to every row in our dataset:
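A minimal sketch of such a splitting function, assuming whitespace tokenization into words: `split_into_chunks` is a hypothetical name, not the notebook's function. With `chunk_size=200` and `overlap=50`, each new chunk starts 150 words after the previous one, and the last chunk may be shorter than 200 words.

```python
import pandas as pd


def split_into_chunks(text, chunk_size=200, overlap=50):
    """Split `text` into chunks of `chunk_size` words; consecutive chunks share `overlap` words."""
    words = text.split()
    stride = chunk_size - overlap  # 150 new words per chunk with the defaults (must be > 0)
    chunks = []
    # Start a new chunk only while it would contain words not already in the previous chunk.
    for start in range(0, max(len(words) - overlap, 1), stride):
        chunks.append(" ".join(words[start:start + chunk_size]))
    return chunks


# A 500-word text gives chunks starting at words 0, 150 and 300.
text = " ".join(f"w{i}" for i in range(500))
chunks = split_into_chunks(text)
print(len(chunks), [len(c.split()) for c in chunks])  # 3 [200, 200, 200]

# Apply it to every row; short texts simply become a single chunk.
df = pd.DataFrame({"text": [text, "a short complaint"], "label": [0, 1]})
df["text_split"] = df["text"].apply(split_into_chunks)
print(df["text_split"].apply(len).tolist())  # [3, 1]
```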

As you can see, we end up with a column (text_split) in which every row holds a list of strings, each around 200 words long.

4. Fine Tuning Bert

This article is not about how BERT works; there are a lot of better articles for that, like this one, this one, or the official one. If you are unfamiliar with BERT, please check them out.

Fine-tuning BERT for classification tasks is easy; for this article I followed the official notebook on fine-tuning BERT.

Basically the main steps are:

1. Prepare the input data, i.e. create InputExample objects using BERT’s constructor:

2. Convert the InputExamples into features BERT understands:

3. Create the model as in the official notebook: basically, add a single new layer that is trained to adapt BERT to our task. Please refer to the notebook where I put all the code.

4. Train the model:

5. Evaluate to see how well it performs:

BERT is amazing: it achieved a good 85% just with fine-tuning. However, we are going to use the vector representation from this fine-tuned model as input for another, simpler model, as follows.
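Step 1 of the fine-tuning needs one training example per chunk rather than per document. A minimal sketch of that flattening with pandas, under the assumption (consistent with the paper's segment-level fine-tuning) that every chunk inherits its document's label; `chunk_rows` and `doc_id` are hypothetical names. Each resulting row could then be wrapped as `run_classifier.InputExample(guid=None, text_a=row.chunk, text_b=None, label=row.label)`, as the official notebook does for whole texts.

```python
import pandas as pd

# Toy stand-in for the training split after chunking (text_split holds lists of chunks).
train = pd.DataFrame({
    "text_split": [["chunk a1", "chunk a2", "chunk a3"], ["chunk b1"]],
    "label": [4, 7],
})

# One row per chunk: explode repeats the label and keeps the original index (the document id).
chunk_rows = train.explode("text_split").rename(columns={"text_split": "chunk"})
chunk_rows["doc_id"] = chunk_rows.index
chunk_rows = chunk_rows.reset_index(drop=True)
print(chunk_rows)
```

Keeping `doc_id` makes it possible to regroup the chunk vectors per document later on.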

5. Get the BERT vector as text representation

After fine-tuning BERT we need to extract representations from it; in other words, we need the pooled output for every text chunk.

So I modified the function below to extract the pooled output from our fine-tuned BERT:

For every 200-word chunk we extract a representation vector of size 768 from BERT.

Now let’s extract the representation of every chunk of text:

In this notebook, NumPy’s apply_along_axis helped a lot with speed when applying a function to a dataset column, compared to pandas apply, itertuples or iterrows.

And the result:

a column of vector representations and a column of labels.
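To make the shape of that result concrete, here is a sketch in which `fake_pooled_output` is a hypothetical stand-in for the fine-tuned BERT: it returns random 768-dimensional vectors instead of real pooled outputs. The notebook uses NumPy's `apply_along_axis`; this sketch uses pandas `Series.apply` for readability. Each document becomes an array of shape `(n_chunks, 768)`, which is the variable-length sequence the next model reads.

```python
import numpy as np
import pandas as pd

HIDDEN_SIZE = 768  # size of BERT-base's pooled output


def fake_pooled_output(chunks, seed=0):
    """HYPOTHETICAL stand-in for the fine-tuned BERT: one random 768-d vector per chunk.

    Replace with a call that returns your fine-tuned model's pooled output for each chunk.
    """
    rng = np.random.default_rng(seed)
    return rng.normal(size=(len(chunks), HIDDEN_SIZE)).astype("float32")


df = pd.DataFrame({
    "text_split": [["c1", "c2", "c3"], ["c1"], ["c1", "c2"]],
    "label": [2, 0, 5],
})

# One (n_chunks, 768) array per document, next to its label.
df["emb"] = df["text_split"].apply(fake_pooled_output)
print([e.shape for e in df["emb"]])  # [(3, 768), (1, 768), (2, 768)]
```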

6. Make an LSTM model over the BERT representation

Now we are going to build a simple LSTM model that takes as input the vectors created before. However, in this case, as with most long texts, the sequences have variable length: there will be texts with 300, 550, 1000 words, etc., so the number of 200-word chunks is not fixed, and our sequences of representation vectors have variable length.

For that reason we need an LSTM model that handles variable-length input, and we have 3 options:

  1. Padding to fixed length
  2. Batch size = 1
  3. Batch size > 1, padding and masking.

The first 2 are inefficient, so we choose option 3: batch size greater than one, padding to the max length, and masking. This way we pad the shorter sequences with a special value that is masked (skipped by the network) later.

In this case the special value is -99.

But I did not want to pad every sequence to the length of the longest one in the whole dataset. Instead, I used a generator function that takes batches of size 3, gets the length of the longest sequence in the batch, and extends the other 2 to that length, filling them with the special value. This is repeated over all the data during training:

This way, all sequences within a batch have the same length.
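A minimal sketch of such a generator in NumPy, assuming the chunk vectors are held in a Python list `X` of `(n_chunks, 768)` arrays with labels `y`; `pad_batch` and `batch_generator` are hypothetical names, and this version walks through the data in order without shuffling. Padding is appended after the real vectors (post-padding), and every padded row is filled entirely with -99 so a masking layer can recognize it.

```python
import numpy as np

MASK_VALUE = -99.0


def pad_batch(sequences, mask_value=MASK_VALUE):
    """Pad a list of (n_i, dim) arrays to the longest n_i in this batch."""
    max_len = max(seq.shape[0] for seq in sequences)
    dim = sequences[0].shape[1]
    batch = np.full((len(sequences), max_len, dim), mask_value, dtype="float32")
    for i, seq in enumerate(sequences):
        batch[i, : seq.shape[0]] = seq  # real vectors first, mask value after them
    return batch


def batch_generator(X, y, batch_size=3, mask_value=MASK_VALUE):
    """Yield (padded_batch, labels) pairs, looping over the data forever."""
    n = len(X)
    while True:  # loop forever so one generator can serve several epochs
        for start in range(0, n, batch_size):
            seqs = X[start:start + batch_size]
            yield pad_batch(seqs, mask_value), np.asarray(y[start:start + batch_size])


rng = np.random.default_rng(0)
X = [rng.normal(size=(n, 768)).astype("float32") for n in (3, 1, 2, 5, 4, 2)]
y = np.array([2, 0, 5, 1, 1, 3])
gen = batch_generator(X, y)
xb, yb = next(gen)
print(xb.shape, yb.tolist())  # (3, 3, 768) [2, 0, 5]
```

The second batch (lengths 5, 4, 2) is padded to 5 chunks, not to the longest document in the whole dataset.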

Why batch size 3? Because when training a model with a generator you have to fix the batch size and the number of batches per epoch, and they must be chosen so that all of your data passes through the training process in each epoch.

You must follow this equation:

number of rows in the data = batch size * batches per epoch

In this case: 13713 = 3 * 4571

For your own generator you can choose a different batch size; however, it must satisfy the equation above.
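Since the equation requires the batch size to divide the number of rows exactly, the candidates can be listed directly. A small sketch (`valid_batch_sizes` is a hypothetical helper) applied to the 13713 training rows used here:

```python
def valid_batch_sizes(n_rows, max_batch_size=32):
    """(batch size, batches per epoch) pairs with batch size * batches per epoch == n_rows."""
    return [(b, n_rows // b) for b in range(1, max_batch_size + 1) if n_rows % b == 0]


print(valid_batch_sizes(13713))  # [(1, 13713), (3, 4571), (7, 1959), (21, 653)]
```

Because 13713 = 3 * 7 * 653, batch sizes of 7 or 21 would also satisfy the equation.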

Finally, we train the model using the Keras callback ReduceLROnPlateau, which reduces the learning rate when the validation accuracy stops improving.
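A sketch of a masked LSTM classifier with this callback in Keras (via `tensorflow.keras`). The layer sizes, optimizer, `factor`, `patience` and `min_lr` are illustrative assumptions, not the notebook's values. The input shape `(None, 768)` accepts any number of chunks, and the `Masking` layer makes the LSTM skip timesteps whose features all equal -99. The last part checks that property: padding a sequence with -99 rows should leave its prediction unchanged. The `fit` call is left as a comment because `train_gen`, `val_gen` and `n_val` are hypothetical names for padded-batch generators and the validation size.

```python
import numpy as np
from tensorflow import keras

MASK_VALUE = -99.0
NUM_CLASSES = 10   # the 10 consolidated product classes
FEATURE_DIM = 768  # size of each BERT chunk vector


def build_model(num_classes=NUM_CLASSES, feature_dim=FEATURE_DIM):
    # Layer sizes are illustrative assumptions.
    model = keras.Sequential([
        keras.Input(shape=(None, feature_dim)),       # None: variable number of chunks
        keras.layers.Masking(mask_value=MASK_VALUE),  # skip timesteps filled with -99
        keras.layers.LSTM(100),
        keras.layers.Dense(30, activation="relu"),
        keras.layers.Dense(num_classes, activation="softmax"),
    ])
    model.compile(optimizer="adam", loss="sparse_categorical_crossentropy",
                  metrics=["accuracy"])
    return model


# Halve the learning rate when validation accuracy has not improved for 2 epochs.
reduce_lr = keras.callbacks.ReduceLROnPlateau(
    monitor="val_accuracy", mode="max", factor=0.5, patience=2, min_lr=1e-6, verbose=1
)

model = build_model()

# Masking check: a 2-chunk sequence vs. the same sequence padded with 3 rows of -99.
rng = np.random.default_rng(0)
seq = rng.normal(size=(1, 2, FEATURE_DIM)).astype("float32")
padded = np.concatenate([seq, np.full((1, 3, FEATURE_DIM), MASK_VALUE, "float32")], axis=1)
p_short = model.predict(seq, verbose=0)
p_padded = model.predict(padded, verbose=0)
print(np.allclose(p_short, p_padded, atol=1e-5))  # True

# model.fit(train_gen, steps_per_epoch=4571, epochs=10,
#           validation_data=val_gen, validation_steps=n_val // 3,
#           callbacks=[reduce_lr])
```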

7. Evaluating the model

We evaluated the model on unseen data and got the following results:

An accuracy of 87% and a loss of 0.41.

Final Words

Techniques for classifying long documents in most cases require truncating the text to a shorter length; however, as we have seen, you can use BERT and techniques like masking to build a model that is good enough for this task.

In the paper, the authors propose another method, ToBERT (Transformer over BERT), which you can implement and compare with this one.

The complete code can be found on this Jupyter notebook, and you can browse for more projects on my Github.

Also, my LinkedIn.