What’s My Line? Next Sentence Prediction in RunwayML with BERT
I spent a few Fridays recently working as an artist-in-residence with RunwayML, a company building software to make heavy-duty machine learning (ML) tech easier to use in creative projects. I’m pretty much their ideal customer: someone with an interest in creative ML but no GPUs, a small computer and a punishing fear of The Cloud.
I’ll focus on language models, which make up a small but growing share of the models available in RunwayML. I’ll describe a situation (next sentence prediction) where I think large ML language models are interesting, then talk through a couple of proof-of-concept experiments in RunwayML.
What are language models good for?
A language model is a set of rules for evaluating the likelihood of sequences of text. Language models have many uses, including generating text by repeatedly answering the question: Given some text, what could come next?
If you’re designing a program to suggest the next word in a sentence, you can get pretty far with simple algorithms based on counting n-grams, sequences of words as found on the web, or in a book, or any corpus you choose.
An n-gram strategy to predict the next word in the line “Oops, I did it” might look at its last two words (“did it”) and cross-reference them with counts of all three-word sequences that start with “did it”. The algorithm’s top suggestions are the third words of those sequences, ranked by how often they occur: words like “again” and “all” and “to” and “for” and “hurt”, depending on the training data.
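Here is a minimal sketch of that counting strategy in Python. It assumes a crude lowercase tokenizer and a tiny made-up corpus. The names train_trigrams and suggest_next are hypothetical, chosen for illustration; they are not any particular library’s API.

```python
from collections import Counter, defaultdict
import re


def tokenize(text):
    # Lowercase and keep runs of letters/apostrophes; a deliberately crude tokenizer.
    return re.findall(r"[a-z']+", text.lower())


def train_trigrams(lines):
    # Map each two-word context to a Counter of the words seen right after it.
    counts = defaultdict(Counter)
    for line in lines:
        words = tokenize(line)
        for a, b, c in zip(words, words[1:], words[2:]):
            counts[(a, b)][c] += 1
    return counts


def suggest_next(counts, prompt, k=3):
    # Look only at the last two words of the prompt and return the k most frequent followers.
    words = tokenize(prompt)
    if len(words) < 2:
        return []
    followers = counts.get(tuple(words[-2:]), Counter())
    return [word for word, _ in followers.most_common(k)]


# Hypothetical toy corpus; real suggestions depend entirely on the training data.
corpus = [
    "Oops, I did it again",
    "I did it all for you",
    "They did it again last night",
    "We did it to prove a point",
]
model = train_trigrams(corpus)
print(suggest_next(model, "Oops, I did it"))  # ['again', 'all', 'to']
```

With only four training lines the suggestions are thin. That is the point of the example: the output can only ever be words the corpus has already put after “did it”.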
But say you want to use more than the last couple of words. Say you want to suggest lines based on the whole previous line.
Then the counting approach won’t do. Short sequences of words like “did it” are common enough that we can easily find examples of them in the wild, including what came after. But sequences longer than a few words tend to be rare.
Using exact occurrences, suggestions for lines to follow the line “Oops, I did it again” would be pretty thin. Even if the training data included all songs ever written, you’d be limited to lines from Britney Spears’s “Oops!… I Did It Again” plus later references to that song.
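Here is a hypothetical sketch of that exact-match approach at the level of whole lines. The corpus is a stand-in. The first song is the Britney couplet quoted below. The second is made up to show how brittle exact matching is: a single exclamation mark is enough to miss a match.

```python
from collections import Counter


def normalize(line):
    # Compare lines case-insensitively and ignore extra whitespace (punctuation still counts).
    return " ".join(line.lower().split())


def exact_next_lines(songs, prompt):
    # songs: a list of songs, each a list of lines in order.
    # Returns lines that directly followed an exact match of the prompt, most frequent first.
    target = normalize(prompt)
    followers = Counter()
    for lines in songs:
        for current, following in zip(lines, lines[1:]):
            if normalize(current) == target:
                followers[following] += 1
    return [line for line, _ in followers.most_common()]


songs = [
    ["Oops, I did it again", "I played with your heart"],
    ["Oops, I did it again!", "A made-up line that only follows the exclaimed version"],
]
print(exact_next_lines(songs, "oops, i did it again"))  # ['I played with your heart']
```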
That’s not so satisfying — more like rote regurgitation than creation. Wouldn’t it be better if you could ask not what sentences best follow this exact sentence, but what sentences best follow this kind of sentence?
BERT for Next Sentence Prediction
BERT is a huge language model that learns by hiding (“masking”) parts of the text it sees and gradually tweaking how it uses the surrounding context to fill in the blanks. In effect it makes its own flash cards from the world and quizzes itself on them billions of times. It was also trained on a second task: guessing whether one sentence really follows another.
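Here is a toy sketch of how one such flash card might be made, using whole words instead of BERT’s WordPiece tokens. The recipe follows the one described for BERT. Pick roughly 15% of positions to quiz. Replace most of those (80%) with a [MASK] token, swap 10% for a random word, and leave the remaining 10% unchanged. make_flash_card is a hypothetical name, and choosing each position independently is a simplification.

```python
import random


def make_flash_card(words, vocab, mask_rate=0.15, rng=None):
    # Toy version of BERT-style masking on whole words (real BERT works on WordPiece tokens).
    # Returns the corrupted input plus {position: original word} for the model to predict.
    rng = rng or random.Random()
    corrupted = list(words)
    answers = {}
    for i, word in enumerate(words):
        if rng.random() >= mask_rate:
            continue  # this position is not quizzed
        answers[i] = word
        roll = rng.random()
        if roll < 0.8:
            corrupted[i] = "[MASK]"  # usually hide the word outright
        elif roll < 0.9:
            corrupted[i] = rng.choice(vocab)  # sometimes swap in a random word
        # otherwise keep the original word, but still ask for it
    return corrupted, answers


sentence = "oops i did it again i played with your heart".split()
front, back = make_flash_card(sentence, vocab=sentence, rng=random.Random(0))
print("front:", " ".join(front))
print("back: ", back)
```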
I’ve always been intimidated by BERT’s unwieldiness. It uses a lot of memory and takes far more compute to run than a lightweight Markov model. But I thought it could open the door to interesting line-by-line songwriting suggestions, so during my residency at RunwayML I ran a couple of experiments to see what kinds of line suggestions I could get from it.
RunwayML takes care of the compute resources for large models, which it hosts remotely. To get BERT running, I first had to “port” a version of the model to RunwayML. This meant creating a GitHub repo with two key files: a list of what the remote machine will need to install to run the model, and a script to run the model itself. Here’s a tutorial post that explains the porting process.
I chose to port BertForNextSentencePrediction, a version of BERT from the NLP startup HuggingFace that does just what it says on the tin: given two sentences, estimate how likely the second is to follow the first.
Here’s what the repo looks like. And here’s the code for the crucial script, runway_model.py:
The command takes two inputs: the prompt sentence and a list of candidate sentences to follow it.
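The original runway_model.py isn’t reproduced here. What follows is a rough, hypothetical sketch of the scoring step alone, using Hugging Face’s transformers library. It assumes a recent version of the library and the bert-base-uncased checkpoint, and it leaves out the RunwayML wrapper entirely. The sketch computes the loss by hand from the model’s two output logits. That should give the same number as the loss the model returns when told the label is 0 (“is next”).

```python
import math


def nsp_loss_from_logits(is_next_logit, not_next_logit):
    # Cross-entropy for the label "the second sentence follows the first":
    # -log softmax([is_next, not_next])[0], computed with the log-sum-exp trick.
    m = max(is_next_logit, not_next_logit)
    log_total = m + math.log(math.exp(is_next_logit - m) + math.exp(not_next_logit - m))
    return log_total - is_next_logit


def load_nsp_model(model_name="bert-base-uncased"):
    # Heavy imports stay inside functions so the helper above runs without them.
    from transformers import BertForNextSentencePrediction, BertTokenizer

    tokenizer = BertTokenizer.from_pretrained(model_name)
    model = BertForNextSentencePrediction.from_pretrained(model_name)
    model.eval()  # inference only: disable dropout
    return tokenizer, model


def score_candidates(prompt, candidates, tokenizer, model):
    # Returns (candidate, loss) pairs, least surprising (lowest loss) first.
    import torch

    scored = []
    for candidate in candidates:
        # Encodes "[CLS] prompt [SEP] candidate [SEP]" with a segment id for each part.
        encoding = tokenizer(prompt, candidate, return_tensors="pt")
        with torch.no_grad():
            logits = model(**encoding).logits[0].tolist()
        # In this model, index 0 scores "is the next sentence" and index 1 "is random".
        scored.append((candidate, nsp_loss_from_logits(logits[0], logits[1])))
    return sorted(scored, key=lambda pair: pair[1])
```

Usage would be tokenizer, model = load_nsp_model() followed by score_candidates("Oops, I did it again", candidates, tokenizer, model). Loading the model once and reusing it matters, because loading BERT is much slower than scoring a single pair.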
First, let’s see what the model thinks about the original Britney couplet: “Oops, I did it again / I played with your heart”
Input 1: ‘Oops, I did it again’
Input 2: [‘I played with your heart’]
Output: The model’s loss for Input 1 followed by Input 2. The loss can be thought of as how much the model is surprised by the sequence. The lower the loss, the more likely it judges the sequence to be.
I’m not sure what a score of 4.0966539 means in isolation, so let’s see how it compares to loss scores across more candidates.
Input 1: ‘Oops, I did it again’
Input 2: All lines from a collection of The Beatles’ most popular songs
Output: All lines and their loss scores
The “loss” value is a measure of how surprised the model is by a given sequence, expressed as a negative logarithm of a probability. So the higher the number, the further off the model’s predictions were. A value near zero means the model considers the sequence very likely; a loss of this kind can’t go below zero.
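A small conversion makes the number easier to read. It assumes the reported value is the two-class cross-entropy loss for the “is next” label, which is what BertForNextSentencePrediction returns as its loss when given that label. Under that assumption the loss is just the negative log of the probability that the second line follows the first. A 50/50 guess therefore costs log 2 ≈ 0.693, and 4.0966539 works out to roughly a 1.7% chance. If the script was actually reporting a different number, such as a raw logit, this reading doesn’t hold.

```python
import math


def is_next_probability(loss):
    # Assumes loss = -log(p_is_next), i.e. two-class cross-entropy on the "is next" label.
    return math.exp(-loss)


# A 50/50 judgement corresponds to a loss of log(2), about 0.693.
COIN_FLIP_LOSS = math.log(2)

for loss in (0.05, COIN_FLIP_LOSS, 4.0966539):
    p = is_next_probability(loss)
    if math.isclose(p, 0.5):
        verdict = "undecided"
    elif p > 0.5:
        verdict = "leans toward 'is next'"
    else:
        verdict = "leans toward 'not next'"
    print(f"loss={loss:.4f}  p(is next)={p:.3f}  {verdict}")
```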
Sorting by loss, lowest first, gives us top contenders like these:
Oops, I did it again / They’d seen his face before
Oops, I did it again / When I feel that somethin’
These lines don’t really seem to follow. I don’t know why. Are sentences any good as a guide to how lines of a song should succeed one another? Why does the model think that first line is so likely? Am I misinterpreting the output and looking at the wrong number entirely?
One of the perils of large language models is that questions like these are often hard to investigate. Huge models with lots of parameters don’t readily yield causal explanations for their decisions; they just decide. Their complexity also raises the barrier for novice users. Even after reading the docs for Bert-NSP, I’m still shaky on the output format.
With RunwayML, I got a giant language model running remotely on a GPU. That’s a win, but I can’t take the project further until I better understand the model’s output.
I’d really like to be able to get continuation suggestions as I’m writing, so that I can choose one, run the model again with the chosen suggestion as a prompt, and get another set of suggestions fast enough that the process feels like stepping through a branching maze, not waiting for the model to finish.
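The loop I have in mind might look roughly like this hypothetical sketch. The names branch, score, and choose are all illustrative. score stands in for any model that ranks candidate lines against a prompt (a next-sentence scorer, say), and choose stands in for me picking a line. Dropping a line once it has been used is my own assumption.

```python
def branch(prompt, candidates, score, choose, steps=3, k=3):
    # score(prompt, candidates) -> list of (candidate, loss) pairs, lowest loss first.
    # choose(options) -> the option the writer picks (here just a function).
    lines = [prompt]
    remaining = list(candidates)
    for _ in range(steps):
        if not remaining:
            break
        # Rank what's left against the most recent line and keep the top k.
        options = [line for line, _ in score(lines[-1], remaining)[:k]]
        picked = choose(options)
        lines.append(picked)
        remaining.remove(picked)  # assumption: never offer the same line twice
    return lines
```

Each step calls score once over every remaining candidate. That call is what has to come back quickly for the process to feel like wandering a maze rather than waiting.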
As I keep working on this project, I’m on the lookout for more interpretable models to try on line predictions. Maybe I’ll find one in the module for generative text that RunwayML is rolling out this month.