How To Train Your Chatbot With Simple Transformers

Source: Artificial Intelligence on Medium


Chatbots and virtual assistants, once found mostly in sci-fi, are becoming increasingly common. The Google Assistants and Siris of today still have a long, long way to go to reach Iron Man’s J.A.R.V.I.S. and the like, but the journey has begun. While the current crop of Conversational AI is far from perfect, it is also a far cry from its humble beginnings in simple programs like ELIZA.

Moving away from the typical rule-based chatbots, Hugging Face came up with a Transformer-based way to build chatbots that lets us leverage the state-of-the-art language modelling capabilities of models like BERT and OpenAI GPT. Using this method, we can quickly build powerful and impressive Conversational AIs that can outperform most rule-based chatbots. It also eliminates the tedious rule building and script writing needed to build a good rule-based chatbot.

Simple Transformers offers a way to build these Conversational AI models quickly, efficiently, and easily. The Simple Transformers implementation is built on the Hugging Face implementation given here.


Getting the environment set up is fairly straightforward.

  1. Install Anaconda or Miniconda Package Manager from here
  2. Create a new virtual environment and install packages.
    conda create -n transformers python
    conda activate transformers
    If using CUDA:
    conda install pytorch cudatoolkit=10.1 -c pytorch
    Otherwise (CPU only):
    conda install pytorch cpuonly -c pytorch
  3. Install Apex if you are using fp16 training. Please follow the instructions here. (Installing Apex from pip has caused issues for several people.)
  4. Install simpletransformers.
    pip install simpletransformers


We’ll be using the Persona-Chat dataset. Note that you don’t need to manually download the dataset as the formatted JSON version of the dataset (provided by Hugging Face) will be automatically downloaded by Simple Transformers if no dataset is specified when training the model.

Conversational AI Model

ConvAIModel is the class used in Simple Transformers to do all things related to conversational AI models. This includes training, evaluating, and interacting with the models.

At the moment, you can use any of the OpenAI GPT or GPT-2 models with ConvAIModel. However, the pre-trained model provided by Hugging Face performs well out of the box and will likely require less fine-tuning when creating your own chatbot. You can download the model from here and extract the archive to follow along with the tutorial (which assumes you have downloaded the model and extracted it to gpt_personachat_cache).
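A minimal sketch of creating the model, assuming the archive was extracted to gpt_personachat_cache. The helper name build_model and the training argument value are illustrative choices for this tutorial, not part of the library. The import sits inside the function only so the sketch can be loaded before the library is installed.

```python
def build_model(model_dir="gpt_personachat_cache", use_cuda=True, train_args=None):
    """Create a ConvAIModel from a directory of pre-trained GPT weights.

    build_model is a hypothetical helper; the default argument values
    are illustrative, not recommended settings.
    """
    # Imported lazily so this sketch loads even where the library
    # has not been installed yet.
    from simpletransformers.conv_ai import ConvAIModel

    if train_args is None:
        # Example setting only; see the configuration options in the docs.
        train_args = {"num_train_epochs": 1}

    # "gpt" is the model type; model_dir holds the extracted weights.
    return ConvAIModel("gpt", model_dir, use_cuda=use_cuda, args=train_args)
```

Calling model = build_model() gives the model object used in the rest of the tutorial. Pass use_cuda=False on a machine without a GPU.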

The code snippet above creates a ConvAIModel and loads the Transformer with the pre-trained weights.

The ConvAIModel comes with a wide range of configuration options, which can be found in the documentation here. You can also find the list of globally available configuration options in the Simple Transformers library here.


You can further fine-tune the model on the Persona-Chat training data by simply calling the train_model() method.
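In its simplest form this is a single method call. The sketch below wraps it in a hypothetical helper that takes an already created ConvAIModel as `model`.

```python
def fine_tune_on_persona_chat(model):
    """Fine-tune `model` (assumed to be a ConvAIModel) on Persona-Chat.

    No training file is passed, so train_model() falls back to the
    Persona-Chat data provided by Hugging Face.
    """
    model.train_model()
```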

This will download the dataset (if it hasn’t already been downloaded) and start the training.

To train the model on your own data, you must create a JSON file with the following structure.
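Here is a small illustrative example of that structure, built in Python and written out as JSON. The conversation text and the helper name write_example_dataset are made up for illustration. A real dataset would contain many entries and more candidates per utterance.

```python
import json
from pathlib import Path

# One entry: a personality plus the utterances of a single conversation.
# All text here is invented for illustration.
EXAMPLE_DATASET = [
    {
        "personality": [
            "i like computers .",
            "i like reading books .",
        ],
        "utterances": [
            {
                # Candidate replies; the last one is the ground truth.
                "candidates": [
                    "i am going to the beach tomorrow .",
                    "i am fine . i just finished a good book .",
                ],
                # Dialog so far; the other user speaks first.
                "history": ["hello how are you today ?"],
            },
            {
                "candidates": [
                    "my dog is three years old .",
                    "mostly science fiction . do you read much ?",
                ],
                "history": [
                    "hello how are you today ?",
                    "i am fine . i just finished a good book .",
                    "what kind of books do you like ?",
                ],
            },
        ],
    }
]


def write_example_dataset(path):
    """Write EXAMPLE_DATASET to `path` as JSON and return the path."""
    path = Path(path)
    path.parent.mkdir(parents=True, exist_ok=True)
    path.write_text(json.dumps(EXAMPLE_DATASET, indent=2), encoding="utf-8")
    return path
```

For example, write_example_dataset("data/train.json") creates the file at the location used later in this tutorial.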

This structure follows the structure used in the Persona-Chat dataset as explained below. (Docs here)

Each entry in Persona-Chat is a dict with two keys, personality and utterances, and the dataset is a list of entries.

  • personality: list of strings containing the personality of the agent
  • utterances: list of dictionaries, each of which has two keys whose values are lists of strings:
      • candidates: [next_utterance_candidate_1, …, next_utterance_candidate_19]
        The last candidate is the ground truth response observed in the conversational data
      • history: [dialog_turn_0, … dialog_turn N], where N is an odd number since the other user starts every conversation.
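The structure described above can be checked before training. The function below is a hypothetical helper, not part of Simple Transformers. It only verifies the keys and types listed here and makes no claim about what the library itself validates.

```python
def check_persona_chat_structure(dataset):
    """Return True if `dataset` follows the structure above, else raise ValueError."""

    def is_list_of_str(value):
        return isinstance(value, list) and all(isinstance(s, str) for s in value)

    if not isinstance(dataset, list):
        raise ValueError("dataset must be a list of entries")

    for i, entry in enumerate(dataset):
        if not isinstance(entry, dict):
            raise ValueError(f"entry {i} must be a dict")
        if not is_list_of_str(entry.get("personality")):
            raise ValueError(f"entry {i}: 'personality' must be a list of strings")

        utterances = entry.get("utterances")
        if not isinstance(utterances, list):
            raise ValueError(f"entry {i}: 'utterances' must be a list")

        for j, utterance in enumerate(utterances):
            if not isinstance(utterance, dict):
                raise ValueError(f"entry {i}, utterance {j} must be a dict")
            for key in ("candidates", "history"):
                value = utterance.get(key)
                # Each key must hold a non-empty list of strings.
                if not is_list_of_str(value) or not value:
                    raise ValueError(
                        f"entry {i}, utterance {j}: '{key}' must be a "
                        "non-empty list of strings"
                    )
    return True
```

Running this on the loaded JSON before training catches a missing key or a stray non-string value early, with a message that points at the offending entry.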


The text in the dataset is preprocessed as follows:

  • Spaces before periods at end of sentences
  • Everything lowercase
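The two conventions above are easy to apply to your own text. The function below is a hypothetical helper, not something the library provides. Treating "!" and "?" the same way as periods is an assumption on my part.

```python
import re


def to_persona_chat_style(text):
    """Lowercase `text` and put one space before sentence-ending punctuation."""
    text = text.strip().lower()
    # Match ".", "!" or "?" only when followed by whitespace or the end of
    # the string, so decimals such as "3.5" are left alone. Any whitespace
    # already in front of the mark is replaced by a single space.
    return re.sub(r"\s*([.!?])(?=\s|$)", r" \1", text)


print(to_persona_chat_style("I like computers."))  # i like computers .
```

The function leaves text that is already in this style unchanged, so it is safe to run over a dataset more than once.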

Assuming you have created a JSON file with the given structure and saved it in data/train.json, you can train the model by executing the line below.
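The call takes the path to the training file as its first argument. The sketch below wraps it in a hypothetical helper that first checks the file exists. Here `model` is assumed to be a ConvAIModel created as earlier in the tutorial.

```python
from pathlib import Path


def fine_tune_on_custom_data(model, train_file="data/train.json"):
    """Fine-tune `model` (assumed to be a ConvAIModel) on a custom JSON file."""
    train_file = Path(train_file)
    if not train_file.is_file():
        # Fail early with a clear message instead of deep inside training.
        raise FileNotFoundError(f"training file not found: {train_file}")
    # Passing a file path makes train_model() use it instead of Persona-Chat.
    model.train_model(str(train_file))
```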



Evaluation can be performed on the Persona-Chat dataset just as easily as the training by calling the eval_model() method.

As with training, you may provide a different evaluation dataset as long as it follows the correct structure.
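Both cases can be sketched with one small helper. The name evaluate is hypothetical, and `model` is assumed to be a ConvAIModel. The helper simply passes back whatever eval_model() returns.

```python
def evaluate(model, eval_file=None):
    """Evaluate `model` (assumed to be a ConvAIModel).

    With eval_file=None the Persona-Chat evaluation data is used.
    Otherwise eval_file should point to a JSON file that follows the
    same structure as the training data.
    """
    if eval_file is None:
        return model.eval_model()
    return model.eval_model(eval_file)
```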


Although you can get a numerical score by calculating metrics on an evaluation dataset, the best way to judge how good a Conversational AI is remains to actually converse with it.

To talk with the model you have just trained, simply call model.interact(). This will pick a random personality from the dataset and let you talk with it from the terminal.


Alternatively, you can create a personality on the fly by giving the interact() method a list of strings to build a personality from!
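Both ways of starting a conversation can be sketched together. The helper name chat and the example personality are illustrative, and `model` is assumed to be a ConvAIModel.

```python
def chat(model, personality=None):
    """Start an interactive terminal session with `model` (a ConvAIModel).

    personality: optional list of strings. When omitted, interact()
    picks a random personality from the dataset.
    """
    if personality is None:
        model.interact()
    else:
        model.interact(personality=personality)


# Example personality, written in the lowercase, space-before-period style.
EXAMPLE_PERSONALITY = [
    "i like computers .",
    "i like reading books .",
    "i love classical music .",
    "i am a software engineer .",
]
```

Calling chat(model, EXAMPLE_PERSONALITY) starts a session with the personality above, while chat(model) leaves the choice to the library.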

Wrapping Up

Tip: To load a trained model, you need to provide the path to the directory containing the model file when creating the ConvAIModel object.

from simpletransformers.conv_ai import ConvAIModel

model = ConvAIModel("gpt", "outputs")

That’s it! I hope this tutorial helps you on your way to creating your own chatbot!