Approach to optimize decomposable but non-differentiable objectives

Original article was published by Kirill Kravtsov on Deep Learning on Medium


In the recent Tweet Sentiment Extraction competition, participants were given tweets and their sentiments and were asked to predict the word spans (contiguous parts of those tweets) that reflect the sentiments. Predicted spans were compared to the true spans and evaluated with the Jaccard metric calculated on a word level. Because the competition metric was non-differentiable, I tried an interesting idea: approximating the non-differentiable objective with a differentiable neural network and using that network as a loss function. It improved scores both on cross-validation and on the private leaderboard compared to other “proxy” loss functions such as cross entropy, which many other participants used. Below I will describe the approach and its implementation with PyTorch.

Problem statement

Jaccard definition

The Jaccard metric for a single sample was defined by the following Python code:

def jaccard(str1, str2):
    a = set(str1.lower().split())
    b = set(str2.lower().split())
    c = a.intersection(b)
    return float(len(c)) / (len(a) + len(b) - len(c))

The final metric was calculated as the average Jaccard across test samples.
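As a quick worked example of the definition above, here is the metric applied to a few made-up (true span, predicted span) pairs. The strings are hypothetical and only illustrate the arithmetic.

```python
def jaccard(str1, str2):
    a = set(str1.lower().split())
    b = set(str2.lower().split())
    c = a.intersection(b)
    return float(len(c)) / (len(a) + len(b) - len(c))


# Hypothetical (true span, predicted span) pairs, for illustration only.
pairs = [
    ("so happy today", "happy today"),  # 2 shared words, 3 distinct words in total
    ("love it", "love it"),             # identical word sets
    ("bad day", "great"),               # no shared words
]

scores = [jaccard(true, pred) for true, pred in pairs]
mean_score = sum(scores) / len(scores)  # the competition metric is this average

print(scores)
print(mean_score)
```

Because the strings are turned into sets of lowercased words, repeated words and word order do not affect the score. Note also that the function as written divides by zero if both strings are empty.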

Span selection algorithm

Regardless of the base model architecture, a model that predicts a text span usually outputs a tensor of shape (n_batch, seq_length, 2), where n_batch is the number of samples in the batch and seq_length is the input sequence length. So there are two vectors of length seq_length per sample. They hold, for each word, the probability of being the start word (the first word in the span) and the end word (the last word in the span) respectively. (Often the unit is not a word but a token, which can be a part of a word; for simplicity I refer to words.) The mean cross entropy over the start and end probabilities can be used as a loss function. At prediction time the start and end tokens are selected with argmax, which gives the predicted text span.

Figure 1: Span selection algorithm
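The span selection described above can be sketched in a few lines of PyTorch. This is a minimal illustration, not the competition code: the random tensor stands in for the base model output, and the true start and end positions are made up.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
n_batch, seq_length = 4, 10

# Stand-in for the base model output: raw scores of shape (n_batch, seq_length, 2).
logits = torch.randn(n_batch, seq_length, 2, requires_grad=True)
start_logits, end_logits = logits[..., 0], logits[..., 1]

# Hypothetical true start / end positions for each sample.
true_start = torch.tensor([1, 0, 3, 5])
true_end = torch.tensor([4, 2, 3, 9])

# Training: mean cross entropy over the start and end positions.
# Each sample is treated as a classification over seq_length positions.
ce_loss = 0.5 * (
    F.cross_entropy(start_logits, true_start)
    + F.cross_entropy(end_logits, true_end)
)

# Prediction: the most likely start and end positions define the span.
pred_start = start_logits.argmax(dim=1)
pred_end = end_logits.argmax(dim=1)
```

Nothing in this sketch prevents pred_end from being smaller than pred_start, so a real pipeline needs some rule for that case.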

There is also an approach that predicts a tensor of shape (n_batch, seq_length) holding, for each word, the probability of being included in the span, and uses binary cross entropy over every word separately.

Figure 2: Alternative span selection algorithm

But this approach requires additional heuristics to select the final span from the predicted probability vector (e.g. applying a threshold, as in the picture above). Also, whenever I tried this method it worked worse than the first one.
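To make the need for a heuristic concrete, here is one possible rule, chosen only for illustration: keep everything between the first and the last word whose probability exceeds a threshold. The probabilities, the threshold value and the fallback are all assumptions.

```python
import torch

# Hypothetical per-word probabilities of belonging to the span.
probs = torch.tensor([0.1, 0.7, 0.4, 0.8, 0.2])
threshold = 0.5  # assumed value

selected = (probs > threshold).nonzero(as_tuple=True)[0]
if len(selected) > 0:
    # Take the contiguous range from the first to the last selected word.
    start, end = selected[0].item(), selected[-1].item()
else:
    # Fallback when nothing passes the threshold: the single most likely word.
    start = end = probs.argmax().item()

span = list(range(start, end + 1))
print(span)
```

Word 2 ends up in the span even though its probability is below the threshold, which shows how much the result depends on the chosen rule.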

Jaccard differentiability problem

Let’s stick to the first approach described above. As the metric is calculated for a selected span of text, we first need to select this span based on the model’s output. The reason why we can’t just use Jaccard as a loss directly is that selecting the span requires argmax, which breaks the gradient flow.
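The effect of argmax on gradients is easy to check in PyTorch. In this small sketch the scores are made up; the point is only that the argmax result is an integer index that is not connected to the autograd graph, while a softmax over the same scores is.

```python
import torch

scores = torch.tensor([0.2, 1.5, -0.3], requires_grad=True)

# argmax returns an integer index; no gradient can flow back through it.
idx = scores.argmax()
print(idx.requires_grad)  # False

# A smooth function of the same scores does propagate gradients.
probs = torch.softmax(scores, dim=0)
probs[1].backward()
print(scores.grad)
```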

As direct Jaccard optimization wasn’t possible here, many participants used cross entropy as a proxy loss function. Many of them tried additional heuristics to bring CE closer to Jaccard.

The solution

The idea is basically very simple: what could be a universal differentiable approximation of any function? We all know the answer: a neural network.

Let’s train a meta model that takes as input the prediction of the main model (start and end scores) and the true span, which can be represented as a vector of length seq_length with value 0 for words not in the span and 1 for words in the span. The inputs can be represented as a tensor of shape (n_batch, seq_length, 3): start scores, end scores and true spans (the targets for the base model). These are the features of our loss model, and we want it to predict the actual Jaccard score for a given target and model prediction.

Figure 3: Jaccard loss model
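A minimal sketch of how such an input tensor could be assembled, assuming the true span is given as start and end positions. The random scores stand in for the main model output and the positions are made up.

```python
import torch

n_batch, seq_length = 2, 6

# Stand-ins for the main model's start and end scores.
start_scores = torch.randn(n_batch, seq_length)
end_scores = torch.randn(n_batch, seq_length)

# Hypothetical true span boundaries (inclusive) for each sample.
true_start = torch.tensor([1, 0])
true_end = torch.tensor([3, 5])

# Binary vector: 1 for words inside the true span, 0 outside.
positions = torch.arange(seq_length).unsqueeze(0)  # (1, seq_length)
true_span = (
    (positions >= true_start.unsqueeze(1)) & (positions <= true_end.unsqueeze(1))
).float()

# Features of the loss model: (n_batch, seq_length, 3).
features = torch.stack([start_scores, end_scores, true_span], dim=-1)
print(features.shape)
```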

Let’s call such a network a Jaccard (loss) model. After training the Jaccard model we can simply freeze its weights and use it as a loss function. Some intuition: by backpropagating through the Jaccard model and obtaining gradients with respect to the start and end scores, we learn how the base (main) model’s output needs to change in order to improve the Jaccard score (from the Jaccard model’s point of view). So we backpropagate gradients from the Jaccard model into the main model in order to optimize an approximated Jaccard score.

Figure 4: Forward step with Jaccard loss model
Figure 5: Backward step with Jaccard loss model
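The forward and backward steps with a frozen loss model can be sketched as follows. The tiny MLP is only a stand-in for a trained Jaccard model, and using 1 minus the predicted Jaccard as the loss is an assumption made for this illustration.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
n_batch, seq_length = 2, 6

# Tiny stand-in for an already trained Jaccard model (not the real architecture).
jaccard_model = nn.Sequential(
    nn.Flatten(),
    nn.Linear(seq_length * 3, 16),
    nn.ReLU(),
    nn.Linear(16, 1),
    nn.Sigmoid(),
)
# Freeze its weights so that it only acts as a loss function.
for p in jaccard_model.parameters():
    p.requires_grad_(False)

# Stand-in for the main model output (start and end scores).
scores = torch.randn(n_batch, seq_length, 2, requires_grad=True)
true_span = torch.zeros(n_batch, seq_length, 1)
true_span[:, 1:4] = 1.0  # hypothetical true span: words 1..3

# Forward step: predicted Jaccard for the current scores and targets.
features = torch.cat([scores, true_span], dim=-1)  # (n_batch, seq_length, 3)
pred_jaccard = jaccard_model(features).squeeze(-1)

# Backward step: maximizing predicted Jaccard = minimizing (1 - predicted Jaccard).
loss = (1.0 - pred_jaccard).mean()
loss.backward()

print(scores.grad.shape)  # gradients reach the main model's output
```

The gradients stop at the scores here; with a real main model they would continue into its weights, while the frozen Jaccard model receives no gradients at all.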

Literature review

Only after implementing this idea did I search for prior research in this direction, and I found one paper that proposes the same algorithm: Learning Surrogate Losses. What I find strange is that the authors claim to be able to optimize non-decomposable objectives such as ROC AUC, but I am not sure how that works when the batch size is small and/or the data is unbalanced.

Tricks

Adversarial approach

At first I just pretrained the Jaccard model on saved logs (raw model outputs and true targets for all epochs of a cross entropy training run), then froze it and simply used it as a loss. As a result, the main model quickly learned how to “trick” the Jaccard model: the predicted Jaccard was almost 1 while the true Jaccard metric (properly calculated via the formula) was almost 0.

If you are familiar with the concept of GANs, the solution to this problem is quite obvious: you need to continue training the Jaccard model in parallel with the main model. So when the main model learns to “trick” it, the Jaccard model learns how to prevent that.

Later I found out that the authors of the paper mentioned above refer to this approach as “Bilevel Programming”.

Combining with cross entropy

On its own, the Jaccard model loss gives worse results than cross entropy; these results are discussed in the results analysis. While experimenting, I found that a combination of the Jaccard model loss and cross entropy gives the best result. Below I will show the implementation of exactly this combined loss version.

Implementation

The base (main) model

Currently, state-of-the-art results in many NLP tasks are achieved with pretrained transformer-like models (though RNNs are also still used, especially on Kaggle, where one can improve results by ensembling models with diverse predictions). My best single model was a pretrained RoBERTa base, and I used it for this experiment.

Jaccard model

After a few experiments I’ve chosen the following RNN-based architecture:

It looks a bit overcomplicated, but in general there are 2 bidirectional GRU layers whose hidden states are aggregated using max and mean pooling and passed as features into a 2-layer linear head without activations. There are also skip connections over both GRU layers to improve gradient backpropagation. You can also see some new features: new_words, a sequence of binary indicators of whether each input token is the first token of its word, and bin_sentiment_words, simply an element-wise product of bin_sentiment (the true span binary vector on a token level, “true span” in Figure 3) and new_words. We need these features because the model needs to predict Jaccard on a word level.
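A rough sketch of a network matching this description is given below. It is a reconstruction from the text, not the author's code: the class name, the hidden size, the input projection, the exact placement of the skip connections and pooling only the last layer are all assumptions, and padding is ignored for brevity.

```python
import torch
import torch.nn as nn


class JaccardModel(nn.Module):
    """Hypothetical sketch: 2 BiGRU layers with skips, max/mean pooling, linear head."""

    def __init__(self, n_features=5, hidden=32):
        super().__init__()
        # Assumed input projection so that skip connections have matching sizes.
        self.proj = nn.Linear(n_features, 2 * hidden)
        self.gru1 = nn.GRU(2 * hidden, hidden, batch_first=True, bidirectional=True)
        self.gru2 = nn.GRU(2 * hidden, hidden, batch_first=True, bidirectional=True)
        # 2-layer linear head without activations in between.
        self.head = nn.Sequential(nn.Linear(4 * hidden, hidden), nn.Linear(hidden, 1))

    def forward(self, x):
        x = self.proj(x)             # (n_batch, seq_length, 2 * hidden)
        h1 = self.gru1(x)[0] + x     # skip connection over the first GRU
        h2 = self.gru2(h1)[0] + h1   # skip connection over the second GRU
        # Aggregate hidden states over the sequence with max and mean pooling.
        pooled = torch.cat([h2.max(dim=1).values, h2.mean(dim=1)], dim=-1)
        return self.head(pooled).squeeze(-1)  # one Jaccard logit per sample


n_batch, seq_length = 3, 8
start_scores = torch.randn(n_batch, seq_length)
end_scores = torch.randn(n_batch, seq_length)
bin_sentiment = torch.zeros(n_batch, seq_length)
bin_sentiment[:, 2:5] = 1.0  # made-up true span on a token level
new_words = torch.tensor([1.0, 0.0, 1.0, 1.0, 0.0, 1.0, 1.0, 0.0]).repeat(n_batch, 1)
bin_sentiment_words = bin_sentiment * new_words

features = torch.stack(
    [start_scores, end_scores, bin_sentiment, new_words, bin_sentiment_words], dim=-1
)
model = JaccardModel()
jaccard_logits = model(features)
print(jaccard_logits.shape)
```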

I found that binary cross entropy works better than MSE as the meta loss for the final model (in terms of true Jaccard score).
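Both candidate meta losses are straightforward to compute in PyTorch, since binary cross entropy also accepts soft targets between 0 and 1. The logits and target values below are made up for illustration.

```python
import torch
import torch.nn.functional as F

# Hypothetical raw outputs of the Jaccard model and true Jaccard scores in [0, 1].
pred_logits = torch.tensor([2.0, -1.0, 0.0])
true_jaccard = torch.tensor([0.9, 0.25, 0.5])

# Binary cross entropy with soft targets, computed from logits for stability.
bce = F.binary_cross_entropy_with_logits(pred_logits, true_jaccard)

# MSE between the predicted probability and the true score, for comparison.
mse = F.mse_loss(torch.sigmoid(pred_logits), true_jaccard)

print(bce.item(), mse.item())
```

With soft targets, binary cross entropy is still minimized when the predicted probability equals the target, so both losses share the same optimum and differ in how they weight errors.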

Training implementation

For my PyTorch training pipeline I’ve been using the Catalyst deep learning framework, which encapsulates the training loop, provides a callback mechanism and has many other useful features. There is also a GAN training example in its demo Colab notebook. Since the approach is very similar to GANs, my implementation was based on that example.

If you are not familiar with Catalyst, take a look at their docs. The main entity there is the Runner class and the main entry point is the train method of Runner. Runner has a method _handle_batch, which will be redefined and used for the forward and backward passes of both the main model and the Jaccard model.

But first let’s assume that we already have:

  • Data loaders: train_loader and valid_loader
  • Both models created: main_model and jaccard_model
  • Separate optimizers for both models: main_optimizer and jaccard_optimizer
  • A function for batch mean Jaccard calculation on Tensor input, jaccard_func, used to generate y_true values for the Jaccard model

And here is the adversarial logic implemented:

Note: ids, mask and token_type_ids are the default inputs for transformers.

And then to run the training:
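A framework-free sketch of the adversarial logic and the training loop, in plain PyTorch rather than Catalyst, might look like this. Everything here is an assumption made for illustration: the tiny stand-in models, the random data, the equal weighting of the two loss terms, and a jaccard_func that works on a token level and returns per-sample scores.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
n_batch, seq_length, n_feat = 8, 12, 16


def jaccard_func(start_scores, end_scores, true_span):
    """Per-sample token-level Jaccard between the argmax span and the true span."""
    pos = torch.arange(true_span.shape[1]).unsqueeze(0)
    s = start_scores.argmax(dim=1, keepdim=True)
    e = end_scores.argmax(dim=1, keepdim=True)
    pred_span = ((pos >= s) & (pos <= e)).float()  # empty if end < start
    inter = (pred_span * true_span).sum(dim=1)
    union = pred_span.sum(dim=1) + true_span.sum(dim=1) - inter
    return inter / union.clamp(min=1.0)


# Tiny stand-ins for the real models.
main_model = nn.Linear(n_feat, 2)  # token features -> start / end scores
jaccard_model = nn.Sequential(
    nn.Flatten(), nn.Linear(seq_length * 3, 32), nn.ReLU(), nn.Linear(32, 1)
)
main_optimizer = torch.optim.Adam(main_model.parameters(), lr=1e-3)
jaccard_optimizer = torch.optim.Adam(jaccard_model.parameters(), lr=1e-3)


def train_step(tokens, true_start, true_end):
    pos = torch.arange(seq_length).unsqueeze(0)
    true_span = (
        (pos >= true_start.unsqueeze(1)) & (pos <= true_end.unsqueeze(1))
    ).float()
    scores = main_model(tokens)  # (n_batch, seq_length, 2)
    start_scores, end_scores = scores[..., 0], scores[..., 1]

    # 1) Jaccard model update: learn to predict the true Jaccard of the
    #    current main model outputs (detached, so the main model is untouched).
    with torch.no_grad():
        y_true = jaccard_func(start_scores, end_scores, true_span)
    feats = torch.cat([scores.detach(), true_span.unsqueeze(-1)], dim=-1)
    jaccard_loss = F.binary_cross_entropy_with_logits(
        jaccard_model(feats).squeeze(-1), y_true
    )
    jaccard_optimizer.zero_grad()  # also clears grads left over from step 2
    jaccard_loss.backward()
    jaccard_optimizer.step()

    # 2) Main model update: cross entropy plus the Jaccard model loss.
    #    Only main_optimizer steps here, so the Jaccard model stays fixed.
    feats = torch.cat([scores, true_span.unsqueeze(-1)], dim=-1)
    pred_jaccard = torch.sigmoid(jaccard_model(feats).squeeze(-1))
    ce = 0.5 * (
        F.cross_entropy(start_scores, true_start)
        + F.cross_entropy(end_scores, true_end)
    )
    main_loss = ce + (1.0 - pred_jaccard).mean()
    main_optimizer.zero_grad()
    main_loss.backward()
    main_optimizer.step()
    return main_loss.item(), jaccard_loss.item(), y_true.mean().item()


# Run a few steps on random data in place of train_loader.
for step in range(3):
    tokens = torch.randn(n_batch, seq_length, n_feat)
    true_start = torch.randint(0, seq_length // 2, (n_batch,))
    true_end = true_start + torch.randint(0, seq_length // 2, (n_batch,))
    main_loss, jaccard_loss, true_jaccard = train_step(tokens, true_start, true_end)
    print(step, main_loss, jaccard_loss, true_jaccard)
```

The key design point is that the Jaccard model is updated on every batch with fresh outputs of the main model, so it keeps seeing the kind of predictions the main model currently produces instead of only those from a past cross entropy run.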

Results analysis

Using the combined loss improved the mean cross-validation (5-fold) score from 0.712 to 0.715 and the private leaderboard score from 0.716 to 0.718. The model with only the Jaccard model loss had a CV score of only about 0.65.

Let’s also analyze how good cross entropy and the Jaccard model loss are by calculating the absolute value of the Pearson correlation coefficient between each of them and the ground truth Jaccard score. For that, let’s just log all three values during training. The values are:

  • 0.953 between the Jaccard model loss and the ground truth Jaccard score
  • 0.704 between cross entropy and the ground truth Jaccard score
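The correlation itself can be computed with NumPy once the values are logged. The arrays below are made-up placeholders, not the logged values behind the numbers above.

```python
import numpy as np

# Hypothetical values logged at the same training steps.
true_jaccard = np.array([0.60, 0.65, 0.70, 0.72])
ce_loss = np.array([1.90, 1.70, 1.50, 1.45])

# np.corrcoef returns the 2x2 correlation matrix; take the off-diagonal entry.
pearson = np.corrcoef(ce_loss, true_jaccard)[0, 1]
abs_pearson = abs(pearson)

print(abs_pearson)
```

The absolute value is used because a loss falls as the metric rises, so the raw coefficient for a good loss is negative.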

Further ideas

  1. Although the Jaccard model loss looks like a better approximation of the Jaccard metric than cross entropy, it’s still unclear why training with the Jaccard model loss alone gives a much lower score. Understanding the reason could help to improve the approach.
  2. It would be nice to have confidence intervals for the scores in order to show that the method actually works.
  3. The Jaccard loss model might be improved by using a more sophisticated architecture, e.g. a Transformer-based one.