Multi-Label Classification using BERT, RoBERTa, XLNet, XLM, and DistilBERT with Simple Transformers

Source: Deep Learning on Medium


The Simple Transformers library is built on top of the excellent Transformers library by Hugging Face. You guys are incredible!

Simple Transformers now supports:

There’s plenty more in the pipeline.


Transformer models and Transfer Learning methods continue to propel the field of Natural Language Processing forward at a tremendous pace. However, state-of-the-art performance too often comes at the price of tons of (complex) code.

Simple Transformers avoids all the complexity and lets you get down to what matters: training and using Transformer models. Bypass all the complicated setup, boilerplate, and other general unpleasantness to initialize a model in one line, train it in the next, and evaluate it in the third.

This guide shows how you can use Simple Transformers to perform Multilabel Classification. In Multilabel Classification, each sample can have any combination (none, one, some, or all) of labels from a given set of labels.

All source code is available in the GitHub repo. If you have any issues or questions, that’s the place to resolve them. Please do check it out!


  1. Install Anaconda or Miniconda Package Manager from here.
  2. Create a new virtual environment and install packages.
    conda create -n simpletransformers python pandas tqdm
    conda activate simpletransformers
    If using CUDA:
      conda install pytorch cudatoolkit=10.0 -c pytorch
    Otherwise (CPU only):
      conda install pytorch cpuonly -c pytorch
    conda install -c anaconda scipy
    conda install -c anaconda scikit-learn
    pip install transformers
    pip install seqeval
    pip install tensorboardx
  3. Install Apex if you are using fp16 training. Please follow the instructions here. (Installing Apex from pip has caused issues for several people.)
  4. Install simpletransformers.
    pip install simpletransformers

Multilabel Classification

To demonstrate Multilabel Classification we will use the Toxic Comments dataset from Kaggle. Download the dataset from the link above and place the CSV files in the data/ directory.

Data Preparation

The comments in the dataset have been labelled according to the criteria below.

  • toxic
  • severe_toxic
  • obscene
  • threat
  • insult
  • identity_hate

The dataset contains a column for each criterion with a Boolean 1 or 0 indicating whether or not the comment contains the corresponding toxicity.

However, Simple Transformers requires a column named labels, which contains multi-hot encoded lists of labels (one 0/1 entry per criterion, always in the same order), as well as a column named text, which contains all the text (duh!).
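Here is a minimal sketch of that conversion using only the standard library, so it runs without pandas. It assumes each row is a dict keyed by the dataset’s CSV headers, with the comment stored under comment_text; the helper name to_text_and_labels is hypothetical, and flattening newlines is an optional cleanup step rather than something Simple Transformers requires.

```python
import csv
import io

# The six toxicity criteria, in the order used for every multi-hot vector.
LABEL_COLUMNS = ["toxic", "severe_toxic", "obscene", "threat", "insult", "identity_hate"]


def to_text_and_labels(rows):
    """Convert Kaggle-style rows into (text, labels) pairs.

    Hypothetical helper: assumes each row is a dict with a 'comment_text'
    key plus one 0/1 column per entry in LABEL_COLUMNS.
    """
    examples = []
    for row in rows:
        # Flatten newlines so each comment is a single line of text.
        text = row["comment_text"].replace("\n", " ")
        # Multi-hot list: position i is 1 if the comment has LABEL_COLUMNS[i].
        labels = [int(row[column]) for column in LABEL_COLUMNS]
        examples.append((text, labels))
    return examples


# A tiny in-memory stand-in for the training CSV in data/.
sample_csv = io.StringIO(
    "id,comment_text,toxic,severe_toxic,obscene,threat,insult,identity_hate\n"
    '1,"Nice edit,\nthanks",0,0,0,0,0,0\n'
    "2,You are an idiot,1,0,1,0,1,0\n"
)
examples = to_text_and_labels(csv.DictReader(sample_csv))
for text, labels in examples:
    print(labels, text)
```

With pandas, the same idea amounts to building a text column from comment_text and a labels column holding one such list per row.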

Let’s split the DataFrame (df) into train and eval datasets so we can validate the model easily.
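A stdlib sketch of the split, assuming an 80/20 ratio and a fixed seed (both illustrative choices). With pandas and scikit-learn installed, sklearn.model_selection.train_test_split(df, test_size=0.2) does the same job directly on the DataFrame; train_eval_split below is a hypothetical helper.

```python
import random


def train_eval_split(examples, eval_fraction=0.2, seed=42):
    """Shuffle a copy of examples and split it into (train, eval) lists.

    Hypothetical stdlib stand-in for scikit-learn's train_test_split.
    """
    if not 0 < eval_fraction < 1:
        raise ValueError("eval_fraction must be strictly between 0 and 1")
    shuffled = list(examples)               # leave the caller's list untouched
    random.Random(seed).shuffle(shuffled)   # seeded, so the split is reproducible
    n_eval = max(1, int(len(shuffled) * eval_fraction))
    return shuffled[n_eval:], shuffled[:n_eval]


data = [(f"comment {i}", [i % 2, 0, 0, 0, 0, 0]) for i in range(10)]
train, eval_examples = train_eval_split(data)
print(len(train), len(eval_examples))  # 8 2
```

A plain random split does not guarantee that rare labels such as threat show up in the eval set, so it is worth checking the label counts on both sides.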

Now the dataset is ready for use!

Multilabel Classification Model

The MultiLabelClassificationModel class creates a model that can be used for training, evaluating, and predicting on multilabel classification tasks. The first parameter is model_type, the second is model_name, and the third, num_labels, is the number of labels in the data (6 for this dataset).

  • model_type may be one of ['bert', 'xlnet', 'xlm', 'roberta', 'distilbert'].
  • For a full list of pretrained models that can be used for model_name, please refer to Current Pretrained Models.

The args parameter takes an optional Python dictionary of hyper-parameter values and configuration options. I highly recommend checking out all the options, along with their default values, here.

To load a previously saved model instead of a default model, change model_name to the path of a directory that contains the saved model (and keep model_type set to the saved model’s architecture).

from simpletransformers.classification import MultiLabelClassificationModel
model = MultiLabelClassificationModel('xlnet', 'path_to_model/', num_labels=6)


The train_model method trains the model on the training data. You can also change the hyperparameters by passing a dict containing the relevant attributes to train_model. Note that these modifications persist even after training is completed.
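To make the persistence note concrete, here is a toy class, ToyModel, that is entirely hypothetical and not Simple Transformers code. It assumes the behaviour described above: the constructor overlays your args dict on a set of defaults, and train_model writes its overrides into that same dict. The default values are placeholders.

```python
class ToyModel:
    """Hypothetical stand-in illustrating how args overrides persist."""

    # Placeholder defaults; the real keys and values are in the library's docs.
    DEFAULT_ARGS = {"num_train_epochs": 1, "learning_rate": 4e-5, "save_steps": 2000}

    def __init__(self, args=None):
        # Copy the defaults so instances never share one dict,
        # then overlay whatever the caller passed in.
        self.args = dict(self.DEFAULT_ARGS)
        self.args.update(args or {})

    def train_model(self, args=None):
        # Overrides are written into self.args itself, so they outlive this call.
        self.args.update(args or {})
        # ... a real model would train here ...
        return self.args["num_train_epochs"]


model = ToyModel({"learning_rate": 3e-5})
model.train_model({"num_train_epochs": 3})
print(model.args["num_train_epochs"])  # 3, even after training has finished
model.train_model()  # no args passed: the earlier override is reused
```

In practice, if you train once with num_train_epochs set to 3 and then call train_model again without args, the second run also uses 3 epochs; pass the original value back explicitly if you want to reset it.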

The train_model method will create a checkpoint (save) of the model at every nth step where n is self.args['save_steps']. Upon completion of training, the final model will be saved to self.args['output_dir'].
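To see what that means in practice, the sketch below counts checkpoints for a hypothetical run. It assumes one optimization step per batch (no gradient accumulation) and that a checkpoint is written whenever the global step is a multiple of save_steps; the 100,000 examples, batch size of 8, and save_steps of 2000 are made-up numbers, and both function names are hypothetical.

```python
def total_training_steps(num_examples, batch_size, epochs):
    """Optimization steps for a run, assuming one step per (possibly partial) batch."""
    batches_per_epoch = -(-num_examples // batch_size)  # ceiling division
    return batches_per_epoch * epochs


def checkpoint_steps(total_steps, save_steps):
    """Global steps (counted from 1) at which a periodic checkpoint is written."""
    if save_steps <= 0:
        return []  # treat a non-positive save_steps as "no periodic checkpoints"
    return [step for step in range(1, total_steps + 1) if step % save_steps == 0]


steps = total_training_steps(100_000, batch_size=8, epochs=1)
print(steps)                          # 12500
print(checkpoint_steps(steps, 2000))  # [2000, 4000, 6000, 8000, 10000, 12000]
```

The last 500 steps in this example fall after the final periodic checkpoint, which is why the final save to output_dir at the end of training matters.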


The eval_model method is used to perform evaluation on an evaluation dataset. This method has three return values.

  • result: The evaluation result in the form of a dict. By default, only the Label ranking average precision (LRAP) is reported for multilabel classification.
  • model_outputs: A list of model outputs for each item in the evaluation dataset. This is useful if you need probabilities for each class rather than a single prediction. Note that a sigmoid function has been applied to each output to squash the values between 0 and 1.
  • wrong_predictions: A list of InputFeature of each incorrect prediction. The text may be obtained from the InputFeature.text_a attribute. (The InputFeature class can be found in the file in the repo)
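LRAP rewards a model for scoring every true label above the false ones: for each true label, it asks what fraction of the labels ranked at or above it are also true, then averages. scikit-learn ships this as sklearn.metrics.label_ranking_average_precision_score; the stdlib version below (lrap is a hypothetical name) is only meant to make the definition concrete, and it copies scikit-learn’s convention of scoring 1 for a sample whose labels are all true or all false.

```python
def lrap(y_true, y_score):
    """Label ranking average precision over a batch of multilabel samples."""
    if not y_true:
        raise ValueError("need at least one sample")
    total = 0.0
    for truth, scores in zip(y_true, y_score):
        relevant = [j for j, t in enumerate(truth) if t]
        # Ranking carries no information if every label is true or every label is false.
        if len(relevant) in (0, len(truth)):
            total += 1.0
            continue
        sample_score = 0.0
        for j in relevant:
            # Rank of label j: how many labels score at least as high as it does.
            rank = sum(1 for s in scores if s >= scores[j])
            # How many of those labels are themselves true.
            hits = sum(1 for k in relevant if scores[k] >= scores[j])
            sample_score += hits / rank
        total += sample_score / len(relevant)
    return total / len(y_true)


# Example inputs from scikit-learn's documentation for this metric.
print(lrap([[1, 0, 0], [0, 0, 1]], [[0.75, 0.5, 1.0], [1.0, 0.2, 0.1]]))  # about 0.4167
```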

You can also include additional metrics in the evaluation. Simply pass the metric functions as keyword arguments to the eval_model method. Each metric function should take two parameters: the true labels first and the predictions second. This follows the sklearn convention.

Make sure that the metric functions are compatible with multilabel classification.
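For example, exact-match accuracy and per-label (Hamming) accuracy both work for multilabel data. The sketch below is hypothetical and assumes the predictions arrive as per-label sigmoid scores, binarized at 0.5; if they are already 0/1, the thresholding leaves them unchanged.

```python
def _binarize(rows, threshold=0.5):
    # Per-label scores become 0/1; rows that are already 0/1 pass through unchanged.
    return [[1 if value >= threshold else 0 for value in row] for row in rows]


def exact_match_ratio(labels, preds):
    """Fraction of samples whose entire label vector is predicted exactly."""
    binary = _binarize(preds)
    matches = sum(1 for truth, pred in zip(labels, binary) if list(truth) == pred)
    return matches / len(labels)


def hamming_accuracy(labels, preds):
    """Fraction of individual (sample, label) slots predicted correctly (1 - Hamming loss)."""
    binary = _binarize(preds)
    correct = total = 0
    for truth, pred in zip(labels, binary):
        for a, b in zip(truth, pred):
            correct += int(a == b)
            total += 1
    return correct / total


labels = [[1, 0, 1], [0, 1, 0]]
scores = [[0.9, 0.2, 0.7], [0.6, 0.8, 0.1]]
print(exact_match_ratio(labels, scores), hamming_accuracy(labels, scores))
```

You would then pass them as keyword arguments, for example model.eval_model(eval_df, exact_match=exact_match_ratio, hamming=hamming_accuracy), where eval_df stands for your evaluation DataFrame.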


While eval_model is useful when we know the correct labels and merely need to assess a model’s performance, we rarely know the true labels in real-world tasks (I’m sure there’s some profound philosophy there). In such cases, the predict method comes in handy. It is similar to eval_model, except that it doesn’t require the true labels and returns the predictions and the model outputs.
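Conceptually, turning raw scores into multilabel predictions comes down to a sigmoid per label followed by a threshold. The sketch below assumes a 0.5 cut-off and the six label names from this dataset; LABEL_NAMES and logits_to_labels are hypothetical names, and since predict already returns sigmoid outputs, in practice you would only need the thresholding step. Note that a comment can end up with no labels at all, which is perfectly valid in multilabel classification.

```python
import math

LABEL_NAMES = ["toxic", "severe_toxic", "obscene", "threat", "insult", "identity_hate"]


def sigmoid(x):
    """Numerically stable logistic function, mapping any real x into [0, 1]."""
    if x >= 0:
        return 1.0 / (1.0 + math.exp(-x))
    z = math.exp(x)  # x < 0, so this cannot overflow
    return z / (1.0 + z)


def logits_to_labels(logits, threshold=0.5):
    """Return (probabilities, predicted label names) for each row of raw scores."""
    probabilities = [[sigmoid(value) for value in row] for row in logits]
    names = [
        [name for name, p in zip(LABEL_NAMES, row) if p >= threshold]
        for row in probabilities
    ]
    return probabilities, names


probs, names = logits_to_labels([[2.0, -3.0, 0.1, -5.0, 1.0, -1.0], [-4.0] * 6])
print(names)  # [['toxic', 'obscene', 'insult'], []]
```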

We can try it out on the test data provided in the Toxic Comments dataset.
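To build the submission, each test comment needs its id followed by one score per label. The sketch assumes the column layout of the competition’s sample submission file (id, then the six labels) and writes the probabilities from model_outputs rather than 0/1 predictions, since the competition was scored on mean column-wise ROC AUC, which only cares about how comments are ranked. write_submission is a hypothetical helper.

```python
import csv
import io

LABEL_NAMES = ["toxic", "severe_toxic", "obscene", "threat", "insult", "identity_hate"]


def write_submission(ids, model_outputs, handle):
    """Write a header, then one row per test comment: its id and one probability per label."""
    if len(ids) != len(model_outputs):
        raise ValueError("need exactly one row of outputs per id")
    writer = csv.writer(handle)
    writer.writerow(["id"] + LABEL_NAMES)
    for comment_id, probs in zip(ids, model_outputs):
        if len(probs) != len(LABEL_NAMES):
            raise ValueError(f"expected {len(LABEL_NAMES)} scores for id {comment_id}")
        writer.writerow([comment_id] + [f"{p:.6f}" for p in probs])


buffer = io.StringIO()
write_submission(["x1"], [[0.1, 0.2, 0.3, 0.4, 0.5, 0.6]], buffer)
print(buffer.getvalue())
```

For a real file, open it with open('submission.csv', 'w', newline='') as the csv module documentation recommends.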

Submitting this to Kaggle nets me a score of 0.98468, once again demonstrating how far NLP has progressed since the advent of Transformers and Transfer Learning. Keep in mind that I haven’t done much hyper-parameter tuning here!


BERT and its derivatives are awesome! I hope Simple Transformers helps smooth out a few bumps on the road to using them.