Original article was published on Artificial Intelligence on Medium
Contrasting contrastive loss functions
A comprehensive guide to four contrastive loss functions for contrastive learning
In a previous post, I wrote about contrastive learning in supervised classification and ran some experiments on MNIST and similar datasets. I found that the two-stage method proposed by Khosla et al. (2020) does significantly improve supervised classification by learning meaningful embeddings with a contrastive loss. Later I realized that my experiments had actually used a different contrastive loss function from the one Khosla et al. proposed. Although they share the same intuition of explicitly contrasting examples against each other with respect to their labels, different contrastive loss functions have their own nuances. In this post, I review a series of contrastive loss functions and compare their performance on supervised classification tasks.
Contrastive loss functions were invented for metric learning, which aims to learn similarity functions that measure the similarity or distance between a pair of objects. In the context of classification, the desired metric would render a pair of examples with the same label more similar than a pair of examples with different labels. Deep metric learning uses deep neural networks to embed data points nonlinearly into a lower-dimensional space, then uses a contrastive loss function to optimize the network parameters. Recent research projects have applied deep metric learning to self-supervised learning, supervised learning, and even reinforcement learning, for example Contrastively-trained Structured World Models (C-SWMs).
To review different contrastive loss functions in the context of deep metric learning, I use the following formalization. Let 𝐱 be the input feature vector and 𝑦 be its label. Let 𝑓(⋅) be an encoder network mapping the input space to the embedding space and let 𝐳=𝑓(𝐱) be the embedding vector.
Types of contrastive loss functions
Here I review four contrastive loss functions in chronological order. I slightly changed the names of a few functions to highlight their distinctive characteristics.
1. Max margin contrastive loss (Hadsell et al. 2006)
The max margin contrastive loss function takes a pair of embedding vectors z_i and z_j as inputs. When the two share a label (y_i=y_j), the loss is essentially the Euclidean distance between them; otherwise it is equivalent to a hinge loss on that distance. The margin parameter m > 0 imposes a lower bound on the distance between a pair of samples with different labels.
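To make the two cases concrete, here is a minimal pure-Python sketch of the pairwise loss. It assumes the commonly used squared form (squared distance for same-label pairs, squared hinge for different-label pairs); the function names and the default margin of 1.0 are illustrative choices, not taken from my experiments.

```python
import math


def euclidean(u, v):
    """Euclidean distance between two equal-length vectors."""
    return math.sqrt(sum((a - b) ** 2 for a, b in zip(u, v)))


def max_margin_contrastive_loss(z_i, z_j, same_label, margin=1.0):
    """Pairwise max margin contrastive loss (assumed squared form).

    Same label: penalize any distance between the pair.
    Different labels: penalize only pairs closer than the margin.
    """
    d = euclidean(z_i, z_j)
    if same_label:
        return d ** 2
    return max(0.0, margin - d) ** 2


# A same-label pair that is far apart is penalized heavily.
print(max_margin_contrastive_loss([0.0, 0.0], [3.0, 4.0], same_label=True))   # 25.0
# A different-label pair already beyond the margin costs nothing.
print(max_margin_contrastive_loss([0.0, 0.0], [3.0, 4.0], same_label=False))  # 0.0
# A different-label pair inside the margin is penalized (approximately 0.16).
print(max_margin_contrastive_loss([0.0, 0.0], [0.6, 0.0], same_label=False))
```

Note that the loss only ever looks at one pair at a time, which is the main difference from the losses that follow.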
2. Triplet loss (Weinberger et al. 2006)
Triplet loss operates on a triplet of vectors whose labels follow 𝑦_𝑖=𝑦_𝑗 and 𝑦_𝑖≠𝑦_𝑘. That is to say, two of the three (𝐳_𝐢 and 𝐳_𝐣) share the same label and the third vector 𝐳_𝐤 has a different label. In the triplet learning literature, they are termed anchor (z_i), positive (z_j), and negative (z_k), respectively. Triplet loss is defined as:
$$\mathcal{L}(\mathbf{z}_i, \mathbf{z}_j, \mathbf{z}_k) = \max\left(0,\ \lVert \mathbf{z}_i - \mathbf{z}_j \rVert_2^2 - \lVert \mathbf{z}_i - \mathbf{z}_k \rVert_2^2 + m\right)$$
where 𝑚 is again a margin parameter: the loss is zero only when the anchor-negative distance exceeds the anchor-positive distance by at least 𝑚. The intuition for this loss function is to push the negative sample outside of the neighborhood by a margin while keeping positive samples within the neighborhood. Here is a great graphical demonstration showing the effect of triplet loss from the original paper:
Based on the definition of the triplet loss, a triplet falls into one of the following three categories before any training:
- easy: triplets with a loss of 0, because the negative is already farther from the anchor than the positive by more than the margin
- hard: triplets where the negative is closer to the anchor than the positive
- semi-hard: triplets where the negative is farther from the anchor than the positive, but by less than the margin, so the loss is still positive
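The three categories above can be checked directly from the two distances. The sketch below is a minimal pure-Python illustration that assumes squared Euclidean distances; the function names and the margin of 1.0 are hypothetical choices for the example.

```python
def squared_distance(u, v):
    """Squared Euclidean distance between two equal-length vectors."""
    return sum((a - b) ** 2 for a, b in zip(u, v))


def triplet_loss(anchor, positive, negative, margin=1.0):
    """max(0, d(anchor, positive) - d(anchor, negative) + margin)."""
    d_ap = squared_distance(anchor, positive)
    d_an = squared_distance(anchor, negative)
    return max(0.0, d_ap - d_an + margin)


def triplet_type(anchor, positive, negative, margin=1.0):
    """Classify a triplet as 'easy', 'semi-hard', or 'hard'."""
    d_ap = squared_distance(anchor, positive)
    d_an = squared_distance(anchor, negative)
    if d_an < d_ap:
        return "hard"        # negative is closer than the positive
    if d_an < d_ap + margin:
        return "semi-hard"   # negative is farther, but still inside the margin
    return "easy"            # negative is beyond the margin, loss is 0


anchor, positive = [0.0, 0.0], [1.0, 0.0]
for negative in ([0.5, 0.0], [1.2, 0.0], [3.0, 0.0]):
    print(triplet_type(anchor, positive, negative),
          triplet_loss(anchor, positive, negative))
```

With this anchor and positive, the three negatives come out as hard, semi-hard, and easy in that order, and only the easy triplet has a loss of exactly 0.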
Triplet loss was used to learn face embeddings in the FaceNet paper (Schroff et al. 2015). Schroff et al. argued that triplet mining is crucial for model performance and convergence. They also found that selecting the hardest triplets led to bad local minima early in training, specifically collapsed models, whereas semi-hard triplets yielded more stable results and faster convergence.
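As a rough illustration of what semi-hard mining means in practice, the sketch below selects, for a given anchor-positive pair, the negatives in a batch that are farther from the anchor than the positive but still within the margin. It is a simplified pure-Python version written for this post, not the FaceNet implementation; `semi_hard_negatives` and its arguments are hypothetical names.

```python
def squared_distance(u, v):
    """Squared Euclidean distance between two equal-length vectors."""
    return sum((a - b) ** 2 for a, b in zip(u, v))


def semi_hard_negatives(embeddings, labels, anchor, positive, margin=1.0):
    """Indices k with d(anchor, positive) < d(anchor, k) < d(anchor, positive) + margin.

    `anchor` and `positive` are indices into `embeddings`; only samples whose
    label differs from the anchor's label are considered as negatives.
    """
    d_ap = squared_distance(embeddings[anchor], embeddings[positive])
    selected = []
    for k, (z_k, y_k) in enumerate(zip(embeddings, labels)):
        if y_k == labels[anchor]:
            continue  # same label as the anchor, so not a negative
        d_an = squared_distance(embeddings[anchor], z_k)
        if d_ap < d_an < d_ap + margin:
            selected.append(k)
    return selected


embeddings = [[0.0, 0.0], [1.0, 0.0], [0.5, 0.0], [1.2, 0.0], [3.0, 0.0]]
labels = [0, 0, 1, 1, 1]
print(semi_hard_negatives(embeddings, labels, anchor=0, positive=1))  # [3]
```

Here sample 2 is a hard negative and sample 4 is an easy one, so only sample 3 is kept.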
3. Multi-class N-pair loss (Sohn 2016)
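Multi-class N-pair loss is a generalization of triplet loss that allows joint comparison among multiple negative samples. When applied to a pair of positive samples 𝐳_𝐢 and 𝐳_𝐣 sharing the same label (𝑦_𝑖=𝑦_𝑗) from a mini-batch with 2𝑁 samples, it is calculated as:
$$\mathcal{L}(\mathbf{z}_i, \mathbf{z}_j) = \log\left(1 + \sum_{k \neq i, j} \exp\left(\mathbf{z}_i \cdot \mathbf{z}_k - \mathbf{z}_i \cdot \mathbf{z}_j\right)\right)$$
where z_i · z_j is the inner product, which is equivalent to cosine similarity when both vectors have unit norm, and the sum runs over the remaining samples in the mini-batch, which serve as negatives.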
As the figure below shows, N-pair loss pushes all 2N-2 negative samples (every sample in the mini-batch other than the anchor and its positive) away simultaneously instead of one at a time:
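With some algebraic manipulation, multi-class N-pair loss can be written in the following softmax form:
$$\mathcal{L}(\mathbf{z}_i, \mathbf{z}_j) = -\log \frac{\exp\left(\mathbf{z}_i \cdot \mathbf{z}_j\right)}{\sum_{k \neq i} \exp\left(\mathbf{z}_i \cdot \mathbf{z}_k\right)}$$
where the sum in the denominator runs over every sample in the mini-batch other than the anchor z_i, including the positive z_j.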
This form of multi-class N-pair loss helps us introduce the next loss function.
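The equivalence of the two forms is easy to verify numerically. The sketch below is a pure-Python check under the assumption that every sample in the batch other than the anchor and the positive acts as a negative; the function names and the toy batch are made up for illustration.

```python
import math


def dot(u, v):
    """Inner product of two equal-length vectors."""
    return sum(a * b for a, b in zip(u, v))


def n_pair_loss(z, i, j):
    """log(1 + sum over k != i, j of exp(z_i . z_k - z_i . z_j))."""
    pos = dot(z[i], z[j])
    total = sum(math.exp(dot(z[i], z[k]) - pos)
                for k in range(len(z)) if k not in (i, j))
    return math.log(1.0 + total)


def n_pair_loss_softmax_form(z, i, j):
    """-log(exp(z_i . z_j) / sum over k != i of exp(z_i . z_k))."""
    numerator = math.exp(dot(z[i], z[j]))
    denominator = sum(math.exp(dot(z[i], z[k]))
                      for k in range(len(z)) if k != i)
    return -math.log(numerator / denominator)


# Toy batch of unit-norm embeddings: index 0 is the anchor, index 1 the positive.
batch = [[1.0, 0.0], [0.8, 0.6], [0.0, 1.0], [-1.0, 0.0]]
a = n_pair_loss(batch, 0, 1)
b = n_pair_loss_softmax_form(batch, 0, 1)
print(math.isclose(a, b))  # True
```

The second form is a softmax cross entropy over the anchor's similarities to all other samples, with the positive as the target class.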
4. Supervised NT-Xent loss (Khosla et al. 2020)
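Let’s first look at the self-supervised version of NT-Xent loss. The name NT-Xent was coined by Chen et al. 2020 in the SimCLR paper and is short for “normalized temperature-scaled cross entropy loss”. It is a modification of the multi-class N-pair loss that adds a temperature parameter (𝜏) to scale the cosine similarities:
$$\mathcal{L}(\mathbf{z}_i, \mathbf{z}_j) = -\log \frac{\exp\left(\mathbf{z}_i \cdot \mathbf{z}_j / \tau\right)}{\sum_{k \neq i} \exp\left(\mathbf{z}_i \cdot \mathbf{z}_k / \tau\right)}$$
Here the embeddings are normalized to unit length, so each inner product is a cosine similarity.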
Chen et al. found that an appropriate temperature parameter can help the model learn from hard negatives. In addition, they showed that the optimal temperature depends on the batch size and the number of training epochs.
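To see what the temperature does, here is a minimal pure-Python sketch of the NT-Xent term for one anchor-positive pair. It assumes all other samples in the batch are negatives, normalizes the embeddings inside the function, and uses a log-sum-exp for numerical stability; `nt_xent_loss` and the toy batch are illustrative names and values only.

```python
import math


def dot(u, v):
    """Inner product of two equal-length vectors."""
    return sum(a * b for a, b in zip(u, v))


def l2_normalize(v):
    """Scale a vector to unit Euclidean norm."""
    norm = math.sqrt(dot(v, v))
    return [a / norm for a in v]


def nt_xent_loss(z, i, j, temperature=0.5):
    """NT-Xent loss for anchor i and positive j within a batch z."""
    z = [l2_normalize(v) for v in z]
    # Temperature-scaled cosine similarities of the anchor to every other sample.
    logits = [dot(z[i], z[k]) / temperature for k in range(len(z)) if k != i]
    positive_logit = dot(z[i], z[j]) / temperature
    # Stable log of the softmax denominator.
    m = max(logits)
    log_denominator = m + math.log(sum(math.exp(l - m) for l in logits))
    return log_denominator - positive_logit


batch = [[1.0, 0.0], [0.8, 0.6], [0.0, 1.0], [-1.0, 0.0]]
for tau in (1.0, 0.5, 0.1):
    print(tau, nt_xent_loss(batch, 0, 1, temperature=tau))
```

In this toy batch the positive is already the most similar sample to the anchor, so lowering the temperature sharpens the softmax and shrinks the loss; when a negative is more similar than the positive, a low temperature amplifies the penalty instead.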
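Khosla et al. later extended NT-Xent loss for supervised learning. In the notation used here, it averages the NT-Xent term over all positives of an anchor:
$$\mathcal{L}(\mathbf{z}_i) = -\frac{1}{|P(i)|} \sum_{j \in P(i)} \log \frac{\exp\left(\mathbf{z}_i \cdot \mathbf{z}_j / \tau\right)}{\sum_{k \neq i} \exp\left(\mathbf{z}_i \cdot \mathbf{z}_k / \tau\right)}$$
where P(i) is the set of samples in the mini-batch, other than i itself, that share the label 𝑦_𝑖.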
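A minimal pure-Python sketch of the supervised variant for a single anchor is given below. It assumes the loss is averaged over all same-label samples in the batch and returns 0 when the anchor has no positive, which is a convention chosen for this example; `supervised_nt_xent_loss` is a hypothetical name and this is not the code used in the experiments that follow.

```python
import math


def dot(u, v):
    """Inner product of two equal-length vectors."""
    return sum(a * b for a, b in zip(u, v))


def l2_normalize(v):
    """Scale a vector to unit Euclidean norm."""
    norm = math.sqrt(dot(v, v))
    return [a / norm for a in v]


def supervised_nt_xent_loss(z, labels, i, temperature=0.5):
    """Supervised NT-Xent loss for anchor i, averaged over its positives."""
    z = [l2_normalize(v) for v in z]
    others = [k for k in range(len(z)) if k != i]
    positives = [k for k in others if labels[k] == labels[i]]
    if not positives:
        return 0.0  # assumed convention: no positives, no contribution
    logits = {k: dot(z[i], z[k]) / temperature for k in others}
    # Stable log of the softmax denominator, shared by all positives.
    m = max(logits.values())
    log_denominator = m + math.log(sum(math.exp(v - m) for v in logits.values()))
    return sum(log_denominator - logits[k] for k in positives) / len(positives)


batch = [[1.0, 0.0], [0.8, 0.6], [0.0, 1.0], [-1.0, 0.0]]
print(supervised_nt_xent_loss(batch, [0, 0, 1, 1], i=0))
print(supervised_nt_xent_loss(batch, [0, 0, 0, 1], i=0))
```

With exactly one positive per anchor, this reduces to the self-supervised NT-Xent term; with more positives, every same-label sample pulls on the anchor at once.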
Next, I assess whether these contrastive loss functions help the encoder network learn meaningful representations of the data that aid the classification task. Following exactly the same experimental settings as in my previous post, with a small batch size (32) and a low learning rate (0.001), I found that all of these contrastive loss functions except triplet loss with hard negative mining outperform the MLP baseline trained without the stage 1 pre-training:
These results confirm the benefit of using a contrastive loss function to pre-train the encoder part of the network for the subsequent classification. They also underscore the importance of triplet mining for triplet loss. Specifically, semi-hard mining works best in these experiments, which is consistent with the FaceNet paper.
Both Chen et al. (SimCLR) and Khosla et al. use very large batch sizes and higher learning rates with the NT-Xent loss to achieve better performance. I next experimented with batch sizes of 32, 256, and 2048, paired with learning rates of 0.001, 0.01, and 0.2, respectively.
The results show that performance degrades as the batch size increases for all loss functions. Although triplet loss with semi-hard negative mining performs very well on small and medium batches, it is very memory intensive, and my 16 GB of RAM could not handle it at a batch size of 2048. Supervised NT-Xent loss does tend to perform relatively better at larger batch sizes than its counterparts. There may be room for improvement for supervised NT-Xent if I were to tune the temperature parameter; the temperature I used was 0.5.
Next, I checked the PCA projections of the embeddings learned using contrastive loss functions to see if they learn any informative representations during the pre-training stage.