Understanding XLNet


Unsupervised representation learning has been highly successful in the NLP domain. Models are pre-trained on unsupervised tasks so that they perform well on downstream tasks without starting from scratch. Traditional language models trained networks left-to-right to predict the next word in a sequence, clearly lacking “bidirectional context”. ELMo tried to capture it, but it concatenated two separately trained models (left-to-right and right-to-left) and was therefore not simultaneously bidirectional. Then BERT arrived with both bidirectional context training and parallelization, showing state-of-the-art results on almost every NLP dataset. But BERT suffers from a few discrepancies and was outperformed by XLNet, which we will discuss in the next sections.

Unsupervised pre-training methods:

Here we will discuss two pre-training objectives.

  1. Auto-regressive (AR) language modeling:

AR language modeling seeks to estimate the probability distribution of a text corpus with an autoregressive model. Given a text sequence x = [x₁, …, x_T], AR language modeling performs pre-training by maximizing the likelihood under the forward autoregressive factorization:

$$\max_{\theta}\;\log p_{\theta}(\mathbf{x}) \;=\; \sum_{t=1}^{T}\log p_{\theta}(x_t \mid \mathbf{x}_{<t}) \;=\; \sum_{t=1}^{T}\log\frac{\exp\!\big(h_{\theta}(\mathbf{x}_{1:t-1})^{\top}e(x_t)\big)}{\sum_{x'}\exp\!\big(h_{\theta}(\mathbf{x}_{1:t-1})^{\top}e(x')\big)}$$

where h_θ(x_{1:t-1}) denotes the context representation produced by the neural model and e(x_t) is the embedding of x_t.
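To make the factorization concrete, here is a minimal NumPy sketch (an illustration, not actual Transformer code) that computes the forward AR log-likelihood; h and emb are hypothetical stand-ins for the model's context representations and word embeddings:

```python
import numpy as np

def ar_log_likelihood(h, emb, token_ids):
    """Forward autoregressive log-likelihood, sum_t log p(x_t | x_{<t}).

    h         : (T, d) array; h[t] is the context representation of the first
                t+1 tokens (what a left-to-right model would produce).
    emb       : (V, d) embedding matrix; emb[v] = e(v).
    token_ids : length-T list of token ids [x_1, ..., x_T].
    """
    log_lik = 0.0
    # The first token is skipped for simplicity: it has no left context.
    for t in range(1, len(token_ids)):
        logits = emb @ h[t - 1]                  # h[t-1] summarizes token_ids[:t]
        logits = logits - logits.max()           # numerical stability
        log_probs = logits - np.log(np.exp(logits).sum())
        log_lik += log_probs[token_ids[t]]       # add log p(x_t | x_{<t})
    return log_lik

# Toy usage with random numbers, purely for illustration.
rng = np.random.default_rng(0)
h, emb = rng.normal(size=(5, 8)), rng.normal(size=(100, 8))
print(ar_log_likelihood(h, emb, [3, 17, 42, 7, 99]))
```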

2. Auto-encoding (AE) language modeling:

An AE model does not perform explicit density estimation like AR; instead it reconstructs the original data from a corrupted input. BERT is an example of the AE approach, where a Masked Language Model (MLM) objective is applied to the original sequence: the model replaces 15% of the tokens at random with a [MASK] token and then predicts those masked tokens at the output layer.
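As a small illustration of the masking step, here is a minimal sketch (real BERT pre-processing additionally leaves some selected tokens unchanged or swaps them for random words; that detail is omitted here):

```python
import random

def mask_tokens(tokens, mask_prob=0.15, seed=None):
    """Replace roughly mask_prob of the tokens with [MASK].

    Returns the corrupted sequence and a dict {position: original token}
    that the model is trained to reconstruct.
    """
    rng = random.Random(seed)
    corrupted, targets = [], {}
    for i, tok in enumerate(tokens):
        if rng.random() < mask_prob:
            corrupted.append("[MASK]")
            targets[i] = tok
        else:
            corrupted.append(tok)
    return corrupted, targets

tokens = "new york is a city".split()
print(mask_tokens(tokens, mask_prob=0.4, seed=1))
# Which positions get masked depends on the seed and mask_prob.
```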

Let the original sequence be x = [x₁, x₂, …, x_T], let x̂ be the corrupted version in which a few tokens are replaced by [MASK], and let those masked tokens be x̄. The objective is to reconstruct x̄ from x̂:

$$\max_{\theta}\;\log p_{\theta}(\bar{\mathbf{x}} \mid \hat{\mathbf{x}}) \;\approx\; \sum_{t=1}^{T} m_t \log p_{\theta}(x_t \mid \hat{\mathbf{x}}) \;=\; \sum_{t=1}^{T} m_t \log\frac{\exp\!\big(H_{\theta}(\hat{\mathbf{x}})_t^{\top}e(x_t)\big)}{\sum_{x'}\exp\!\big(H_{\theta}(\hat{\mathbf{x}})_t^{\top}e(x')\big)}$$

where m_t = 1 indicates that x_t is masked, and H_θ is a Transformer that maps the corrupted sequence x̂ to hidden representations [H_θ(x̂)₁, …, H_θ(x̂)_T].

Pros and Cons of AR and AE modeling:

Context dependency: Since an AR language model is only trained to encode a uni-directional context (either forward or backward), it is not effective at modeling deep bidirectional contexts, while BERT accesses contextual information from both sides.

Independence Assumption: BERT assumes that the masked tokens are reconstructed independently of each other, whereas the AR language modeling objective factorizes with the product rule, which holds universally without any such independence assumption. Consider the example [New, York, is, a, city], with “New” and “York” masked. The BERT objective is:

$$\mathcal{J}_{\text{BERT}} = \log p(\text{New} \mid \text{is a city}) + \log p(\text{York} \mid \text{is a city})$$

so BERT cannot capture the dependency between “New” and “York”.

Input noise: The input to BERT contains artificial symbols like [MASK] that never occur in downstream tasks, which creates a pretrain-finetune discrepancy. In comparison, AR language modeling does not rely on any input corruption and does not suffer from this issue.

XLNet to the rescue:

As we saw, both approaches have their own pros and cons, and we would like the best of both pre-training methods. XLNet achieves that in the following way:

Permutation Language Modeling:

Permutation-based language modeling retains the benefits of an AR model and also captures bidirectional context. For a sequence of length T, there are T! different orders in which to perform a valid autoregressive factorization. Since the model parameters are shared across all factorization orders, the model learns to gather information from the context on both sides.

Let Z_T be the set of all possible permutations of the length-T index sequence [1, …, T], and let z_t and z_{<t} denote the tᵗʰ element and the first t−1 elements of a permutation z ∈ Z_T. The permutation language modeling objective can then be expressed as follows:

$$\max_{\theta}\;\mathbb{E}_{\mathbf{z}\sim\mathcal{Z}_T}\!\left[\sum_{t=1}^{T}\log p_{\theta}\big(x_{z_t} \mid \mathbf{x}_{\mathbf{z}_{<t}}\big)\right]$$

Because this objective fits into the AR framework, it naturally avoids the independence assumption and the pretrain-finetune discrepancy.
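As a toy illustration of why this captures bidirectional context, the sketch below enumerates all factorization orders of a short sequence and lists the context words each target token is ever conditioned on (in practice one order is sampled per sequence, and parameters are shared across all of them):

```python
from itertools import permutations

tokens = ["quick", "brown", "fox", "jumps"]

# Under a factorization order z, token z_t is predicted from tokens z_{<t}.
# Collect, for each token, every context token it is ever conditioned on.
contexts = {tok: set() for tok in tokens}
for z in permutations(range(len(tokens))):
    for t, idx in enumerate(z):
        contexts[tokens[idx]].update(tokens[j] for j in z[:t])

for tok, ctx in contexts.items():
    print(f"{tok!r} is predicted from contexts drawn from {sorted(ctx)}")
# Across orders, every token is conditioned on words both to its left and to
# its right, even though each individual factorization is autoregressive.
```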

Figure: example of permutation language modeling, where t iterates from 1 to T.

Permutation LM only permutes the factorization order, not the sequence order. In other words, XLNet keeps the original sequence order, uses the positional encodings corresponding to the original sequence, and relies on a proper attention mask in the Transformer to achieve the permutation of the factorization order.

Example:

Consider the example [quick, brown, fox, jumps], and assume for this sequence the permutation z = [1, 3, 4, 2]. The attention/permutation mask, in which cell (i, j) indicates whether token i attends to token j, then looks as in the sketch below.
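Here is a minimal sketch (illustrative only, not XLNet's actual implementation) that builds such a mask: token i may attend to token j exactly when j appears no later than i in the factorization order (so a token is also allowed to attend to itself):

```python
import numpy as np

def permutation_mask(z):
    """Build a (T, T) mask from a factorization order z of 1-indexed positions.

    mask[i-1, j-1] = 1 means token i may attend to token j, i.e. j appears
    no later than i in the factorization order.
    """
    T = len(z)
    mask = np.zeros((T, T), dtype=int)
    for t, i in enumerate(z):
        for j in z[: t + 1]:
            mask[i - 1, j - 1] = 1
    return mask

tokens = ["quick", "brown", "fox", "jumps"]
z = [1, 3, 4, 2]
print(permutation_mask(z))
# [[1 0 0 0]   quick : first in the order, attends only to itself
#  [1 1 1 1]   brown : last in the order, attends to everything
#  [1 0 1 0]   fox   : attends to quick and itself
#  [1 0 1 1]]  jumps : attends to quick, fox and itself
```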

For comparison, the XLNet objective for the previously mentioned [New, York, is, a, city] example, under a factorization order in which “New” precedes “York”, is:

$$\mathcal{J}_{\text{XLNet}} = \log p(\text{New} \mid \text{is a city}) + \log p(\text{York} \mid \text{New}, \text{is a city})$$

What could the permutation be for the second term of the above objective?

Answer: any permutation of [1, 3, 4, 5], followed by 2.

In this way, given the same target, XLNet always learns more dependency pairs (here, the dependency of “York” on “New”) than BERT does.

Issue in Permutation Language Modeling:

The standard Transformer parameterization may not work in the case of permutation LM; let's see why.

Consider two different factorization orders that agree on the first element: [This, great, is], i.e., z = [1, 3, 2], and [This, is, great], i.e., z′ = [1, 2, 3]. Suppose that, given x_{z<t}, we are predicting the word at the tᵗʰ position of the order (t = 2):

p(X_{z_2} = “great” | [This]), where z_2 = 3

p(X_{z′_2} = “is” | [This]), where z′_2 = 2

Here x_{z<t} is just [This] in both cases, and we predict the next word with the standard softmax formulation. Since z_{<t} is the same for the two factorization orders, the model is forced to produce exactly the same distribution, even though the target could be either “great” or “is”. It is therefore important to know where the target word sits in the original sequence (position 3 and position 2, respectively), but the standard softmax formulation does not have this information.
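To make the problem explicit, the standard softmax parameterization referred to above (written with the same notation as the earlier objectives) is:

$$p_{\theta}\big(X_{z_t} = x \mid \mathbf{x}_{\mathbf{z}_{<t}}\big) = \frac{\exp\!\big(e(x)^{\top} h_{\theta}(\mathbf{x}_{\mathbf{z}_{<t}})\big)}{\sum_{x'} \exp\!\big(e(x')^{\top} h_{\theta}(\mathbf{x}_{\mathbf{z}_{<t}})\big)}$$

The representation h_θ(x_{z<t}) depends only on the context words, not on the target position z_t, so the two predictions above are forced to share one and the same distribution regardless of which position is being predicted.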