Original article was published by Neil Sinclair on Artificial Intelligence on Medium
Teaching BART to Rap: Fine-tuning Hugging Face’s BART Model
I taught BART to rap as part of the process of learning how to tweak the incredibly powerful Hugging Face Transformers models.
Transfer learning has provided an unimaginable boon to artificial intelligence over the past few years. It made waves in computer vision first and, more recently, in NLP, where researchers discovered that a model trained on a language modelling task can easily (quickly and cheaply) be adapted for other tasks. From a practitioner’s perspective, aside from the deluge of new discoveries (easily accessible on arXiv), Hugging Face have developed unbelievably easy-to-use APIs that allow anyone to access these latest developments with a few lines of code.
In spite of how easy the Hugging Face APIs make both on-the-fly inference and fine-tuning through command-line arguments, I got a little stuck trying to fine-tune the BART model. I’m aiming to use it for my Master’s thesis, and getting stuck meant that it took me an inordinate amount of time to write the fine-tuning code. Once I’d got past this, however, I was amazed at the power of this model.
This article gives a brief overview of how to fine-tune the BART model, with code rather liberally borrowed from Hugging Face’s finetuning.py script. Writing the training code out yourself, however, allows a bit more control over how you experiment with the model. I’ve used PyTorch Lightning to handle the training, and if you’re new to it, I encourage you to get familiar with it. The implementation is incredibly straightforward and may help streamline some of your projects going forward. Although I’ve taught BART to rap here, it’s really just a convenient (and fun!) example of how one can fine-tune the model on a seq2seq task.
A quick overview of where I got stuck in the training process: the loss on my model was declining rapidly from batch to batch, yet the model was learning to generate blank sentences. For a long time, I couldn’t figure out why this was happening. It turns out that you need to manually shift the target tokens to the right before you feed them to the decoder, but that you must pass the unshifted tokens to the loss function.
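To make the shift concrete, here is a minimal pure-Python sketch using plain lists instead of tensors, so it runs without PyTorch. The function name mirrors the `shift_tokens_right` helper in Hugging Face’s BART code, but this is a simplified stand-in rather than that library function, and the token ids are hypothetical. One plausible explanation for the symptom above: if the decoder is fed the unshifted targets, position i can already see the very token it is asked to predict, so the loss collapses without the model learning anything useful for generation. Depending on the transformers version, `BartForConditionalGeneration` may also build the decoder inputs from `labels` itself when only `labels` is passed, so it is worth checking your version before shifting by hand.

```python
from typing import List


def shift_tokens_right(labels: List[int], decoder_start_token_id: int) -> List[int]:
    """Return decoder inputs: the start token followed by every label except the last.

    A simplified, list-based stand-in for the tensor helper in Hugging Face's BART code.
    """
    return [decoder_start_token_id] + labels[:-1]


# Hypothetical token ids for one target lyric line, including its start/end tokens
labels = [0, 1045, 1979, 5, 2]
decoder_input_ids = shift_tokens_right(labels, decoder_start_token_id=2)

# At step i the decoder has seen decoder_input_ids[: i + 1] and is scored
# against labels[i], i.e. the *next* token, which it has not seen yet.
for seen, target in zip(decoder_input_ids, labels):
    print(f"last input {seen:>5} -> predict {target}")
```

The decoder receives `decoder_input_ids`, while the loss is computed against the original `labels`.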
So, without further ado, this is how to teach BART to rap.
The training set
I found a great set of lyrics in this GitHub repo. The author explains how they scraped the lyrics using the Genius Python API, but I just downloaded the lyrics that had already been scraped. I then spun up a quick notebook (here) where I created a set of line, next-line pairs of lyrics, such that:
L(n) -> L(n+1)
Here L(n) represents line “n”, L(n+1) represents the following line and -> indicates the lines are paired in the training data. I also did a small amount of additional processing to ensure that songs wouldn’t bleed into each other and that a verse line would never be paired with a chorus line (or vice versa). Finally, I removed a lot of duplicate lines: without this step, the model often just generated the same line over and over again, because a significant portion of the data consisted of line -> repeated-line pairs.
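A minimal sketch of the pairing step, assuming each song has already been split into sections (verse, chorus and so on), each a list of lines. The data layout and the `make_pairs` name are hypothetical, and the de-duplication rules (skip a line paired with an identical copy of itself, skip exact repeat pairs) are one reasonable reading of the clean-up described above, not the notebook’s exact logic.

```python
from typing import List, Tuple

# Hypothetical layout: a song is a list of sections; a section is a list of lines.
Song = List[List[str]]


def make_pairs(songs: List[Song]) -> List[Tuple[str, str]]:
    """Build (line, next-line) training pairs without crossing song or section boundaries."""
    pairs: List[Tuple[str, str]] = []
    seen = set()
    for song in songs:
        for section in song:
            # zip(section, section[1:]) pairs L(n) with L(n+1) inside one section only
            for src, tgt in zip(section, section[1:]):
                src, tgt = src.strip(), tgt.strip()
                if not src or not tgt:
                    continue  # skip empty lines
                if src.lower() == tgt.lower():
                    continue  # skip line -> repeated-line pairs
                if (src, tgt) in seen:
                    continue  # skip exact duplicates, e.g. a chorus sung several times
                seen.add((src, tgt))
                pairs.append((src, tgt))
    return pairs


songs = [
    [["Line one", "Line two", "Line two"], ["Hook", "Hook", "Other hook"]],
    [["Line one", "Line two"]],
]
print(make_pairs(songs))
```

Because pairing happens per section, the last verse line ("Line two") is never paired with the first chorus line ("Hook").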
From here, I noised the data. Because BART is pre-trained as a denoising autoencoder, I thought it best to pass noised data into the model during training, although I’m not sure this is necessary. I replaced 25% of the tokens with the <mask> token, but I excluded the final word of each lyric line from the replacement pool, because this word plays a crucial role in supporting a rhyming scheme.
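One way the noising could look, as a sketch: it masks whole whitespace-separated words (an assumption, since BART’s tokenizer actually works on sub-word tokens) and never touches the final word. The `mask_line` name and the rounding rule are hypothetical. This is simple token masking; BART’s own pre-training also included text infilling, where a single <mask> replaces a whole span of tokens.

```python
import random
from typing import Optional


def mask_line(line: str, ratio: float = 0.25, mask_token: str = "<mask>",
              rng: Optional[random.Random] = None) -> str:
    """Replace roughly `ratio` of the words with `mask_token`, never masking the last word."""
    rng = rng or random.Random()
    words = line.split()
    # Every index except the final (rhyming) word is eligible for masking
    candidates = list(range(len(words) - 1))
    n_mask = round(ratio * len(candidates))
    for i in rng.sample(candidates, n_mask):
        words[i] = mask_token
    return " ".join(words)


rng = random.Random(0)
print(mask_line("I keep my head up high and never look back", rng=rng))
```

Passing a seeded `random.Random` makes the noise reproducible between runs.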
I also tried to set up the training set so that a line could be predicted from more than just the previous line, in the hope that the model would then maintain greater coherence across a four-line verse during generation. Concretely, with L(n) again representing the nth line in the training set, it was set up such that:
L(n) -> L(n+1)
L(n-1), L(n) -> L(n+1)
L(n-2), L(n-1), L(n) -> L(n+1)
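The expanded training set can be sketched as follows, assuming the lines of one section are held in a list. `make_context_pairs` and its `max_context` parameter are hypothetical names. The `sep` argument defaults to a plain space; passing a marker such as " <sep> " gives the variant floated under the improvements further on.

```python
from typing import List, Tuple


def make_context_pairs(lines: List[str], max_context: int = 3,
                       sep: str = " ") -> List[Tuple[str, str]]:
    """Pair each line L(n+1) with its 1..max_context preceding lines joined by `sep`."""
    pairs: List[Tuple[str, str]] = []
    for target_idx in range(1, len(lines)):
        for k in range(1, max_context + 1):
            start = target_idx - k
            if start < 0:
                break  # not enough preceding lines for a longer context
            source = sep.join(lines[start:target_idx])
            pairs.append((source, lines[target_idx]))
    return pairs


verse = ["line a", "line b", "line c", "line d"]
for source, target in make_context_pairs(verse):
    print(f"{source!r} -> {target!r}")
```

Each target line now appears up to `max_context` times, once per context length, which is where the growth in training examples comes from.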
However, when doing this I found that although the number of training examples grew significantly (obviously), the model was learning to copy output lines from the dataset. I thus abandoned this version of the training data.
The training was relatively straightforward (once I had solved the plummeting-loss issue). I used PyTorch Lightning to simplify training, loading and saving the model. I used ‘bart-base’ as the pre-trained model because I had previously run into GPU memory issues on Google Colab with ‘bart-large’. I trained the model for around 10 epochs. The code is available here.
Generating the text
When generating the text, I did two things. First, I fed a seed line into the generate_text() method (which used the BartForConditionalGeneration generate() method) and auto-regressively generated k new lines. Second, I noised each line so that 25% to 35% of its tokens were masked. I found that noising the tokens like this generally gave more variation in the model’s outputs. In the end, I was pretty amazed by the results that I was able to get.
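The generation loop can be sketched independently of the model. Here `generate_verse`, `generate_fn` and `noise_fn` are hypothetical: with Hugging Face, `generate_fn` would tokenize the noised line, call the model’s `generate()` method and decode the result with `tokenizer.decode(..., skip_special_tokens=True)`, while `noise_fn` could be a masking function like the one used on the training data. Feeding each noised output back in as the next input is one reading of “auto-regressively generated k new lines”, not a confirmed detail of the author’s generate_text().

```python
import random
from typing import Callable, List


def generate_verse(seed_line: str, k: int,
                   generate_fn: Callable[[str], str],
                   noise_fn: Callable[[str, float], str],
                   rng: random.Random,
                   low: float = 0.25, high: float = 0.35) -> List[str]:
    """Generate k lines, each conditioned on a freshly noised copy of the previous line."""
    lines: List[str] = []
    current = seed_line
    for _ in range(k):
        ratio = rng.uniform(low, high)  # sample a noise level between 25% and 35%
        current = generate_fn(noise_fn(current, ratio))
        lines.append(current)
    return lines


# Toy stand-ins so the loop runs without a model
def echo_model(line: str) -> str:
    return line + " yeah"


def no_noise(line: str, ratio: float) -> str:
    return line


print(generate_verse("this is the seed line", 3, echo_model, no_noise, random.Random(0)))
```

Sampling a fresh noise ratio per line, rather than fixing it, is a simple way to get the extra variation described above.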
As a side note, I noticed an odd artifact: when I generated 8 lines, the 5th and 8th lines were generally very similar. I’m not quite sure what caused this.
Thoughts on improvements
Adding a longer range of lyrics for generation: although my experiment with creating more training data with longer lead-in lyrics (two and three lines leading to the target line) wasn’t successful, there might be a way to improve this, for example by adding <sep> tokens between the lead-in lines.
I think it would also be interesting to see how BART works with a different genre of music, such as Country or Punk Rock, and what happens if one doesn’t noise the source lines when training the model. As I recall, all of the data was noised when training the model in the original BART paper, so I’m not sure whether leaving the data un-noised would work.
Finally, every now and then when a really great line popped up, I’d manually search through the training data to see whether the model was just copying it or generating it. Most of the time it was generating the text, but adding a BLEU-like metric to gauge whether the model is copying or being “original” would be helpful.
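A minimal sketch of such a check, assuming whitespace tokenisation: it reports the fraction of a generated line’s n-grams that also occur anywhere in the training lyrics. This is similar in spirit to BLEU’s n-gram precision, with the whole training set acting as the reference, but without BLEU’s count clipping, brevity penalty or averaging over several n. The function names are hypothetical; a score near 1.0 suggests copying, while a low score suggests new phrasing.

```python
from typing import Iterable, List, Set, Tuple

NGram = Tuple[str, ...]


def ngrams(tokens: List[str], n: int) -> List[NGram]:
    """All contiguous n-grams in a token list (empty if the list is shorter than n)."""
    return [tuple(tokens[i:i + n]) for i in range(len(tokens) - n + 1)]


def build_ngram_index(training_lines: Iterable[str], n: int = 3) -> Set[NGram]:
    """Collect every n-gram that appears in the training lyrics."""
    index: Set[NGram] = set()
    for line in training_lines:
        index.update(ngrams(line.lower().split(), n))
    return index


def copy_score(generated: str, index: Set[NGram], n: int = 3) -> float:
    """Fraction of the generated line's n-grams that also appear in the training data."""
    grams = ngrams(generated.lower().split(), n)
    if not grams:
        return 0.0  # line too short to judge
    return sum(g in index for g in grams) / len(grams)


index = build_ngram_index(["I got the mic in my hand", "Keep it moving"])
print(copy_score("I got the mic in my hand", index))  # 1.0: copied verbatim
print(copy_score("We got the mic tonight", index))
```

Building the index once and reusing it keeps each check cheap, so every generated line can be scored automatically instead of searched by hand.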
There are a lot of options available now when utilising pre-trained models, especially with the great work that Hugging Face is doing to democratise access to the latest and greatest ones. BART shows a lot of promise for a wide range of seq2seq tasks, and having spent some time getting to know the model better, I’m very keen to see what else is possible.