Since siamese networks have become increasingly popular in deep learning research and applications, I decided to dedicate a blog post to this extremely powerful technique. I will explain what siamese networks are and conclude with a simple example of a siamese CNN in PyTorch.
What Are Siamese Networks?
Siamese networks (Bromley, Jane, et al. “Signature verification using a ‘siamese’ time delay neural network.” Advances in Neural Information Processing Systems. 1994.) are neural networks containing two or more identical subnetwork components. A siamese network may look like this:
It is important that not only the architecture of the subnetworks is identical, but the weights must be shared among them as well for the network to be called “siamese”. The main idea behind siamese networks is that they can learn useful data descriptors that can then be used to compare the inputs of the respective subnetworks. The inputs can be anything from numerical data (in which case the subnetworks are usually formed by fully-connected layers) and image data (with CNNs as subnetworks) to sequential data such as sentences or time signals (with RNNs as subnetworks).
Usually, siamese networks perform binary classification at the output, classifying whether the inputs are of the same class or not. Different loss functions may be used during training; one of the most popular is the binary cross-entropy loss. This loss can be calculated as

L(y, p) = -(y · log(p) + (1 - y) · log(1 - p)),

where L is the loss function, y is the class label (0 or 1) and p is the predicted probability. In order to train the network to distinguish between similar and dissimilar objects, we may feed it one positive and one negative example at a time and add up the losses:

L_total = L(1, p_pos) + L(0, p_neg),

where p_pos and p_neg are the predictions for the positive and the negative pair, respectively.
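In PyTorch, this summed loss might look like the following sketch (the function and variable names are my own, not taken from the post):

```python
import torch
import torch.nn.functional as F

def paired_bce_loss(p_pos, p_neg):
    """Sum of binary cross-entropy losses for one positive pair
    (target y = 1) and one negative pair (target y = 0)."""
    return F.binary_cross_entropy(p_pos, torch.ones_like(p_pos)) \
         + F.binary_cross_entropy(p_neg, torch.zeros_like(p_neg))

# illustrative predicted probabilities for one positive and one negative pair
loss = paired_bce_loss(torch.tensor([0.9]), torch.tensor([0.2]))
# -log(0.9) - log(0.8) ≈ 0.329
```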
Another possibility is to use the triplet loss (Schroff, Florian, Dmitry Kalenichenko, and James Philbin. “FaceNet: A unified embedding for face recognition and clustering.” Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition. 2015.):

L = max(d(a, p) - d(a, n) + m, 0),

where d is a distance function (e.g. the L2 distance), a is a sample of the dataset (the anchor), p is a random positive sample and n is a negative sample. The margin m is an arbitrary constant used to widen the separation between the positive and negative scores.
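A minimal PyTorch sketch of this loss, assuming the squared L2 distance for d (the helper name is illustrative):

```python
import torch

def triplet_loss(a, p, n, margin=1.0):
    """max(d(a, p) - d(a, n) + m, 0) with the squared L2 distance as d."""
    d_ap = ((a - p) ** 2).sum(dim=1)   # anchor-positive distance
    d_an = ((a - n) ** 2).sum(dim=1)   # anchor-negative distance
    return torch.clamp(d_ap - d_an + margin, min=0.0).mean()

anchor   = torch.tensor([[0.0, 0.0]])
positive = torch.tensor([[0.1, 0.0]])  # close to the anchor
negative = torch.tensor([[2.0, 0.0]])  # far from the anchor
loss = triplet_loss(anchor, positive, negative)  # already separated, so 0
```

(PyTorch also ships `torch.nn.TripletMarginLoss`, which uses the plain, non-squared L2 distance.)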
Applications Of Siamese Networks
Siamese networks have wide-ranging applications. Here are a few of them:
- One-shot learning. In this learning scenario, a new training dataset is presented to the trained (classification) network with only one sample per class, and the classification performance on this new dataset is then measured on a separate testing dataset. Since siamese networks first learn discriminative features from a large, specific dataset, they can generalize this knowledge to entirely new classes and distributions. In (Koch, Gregory, Richard Zemel, and Ruslan Salakhutdinov. “Siamese neural networks for one-shot image recognition.” ICML Deep Learning Workshop. Vol. 2. 2015.), the authors use this capability to perform one-shot learning on the MNIST dataset with a network trained on the Omniglot dataset (an entirely different image dataset).
- Pedestrian tracking for video surveillance (Leal-Taixé, Laura, Cristian Canton-Ferrer, and Konrad Schindler. “Learning by tracking: Siamese CNN for robust target association.” Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition Workshops. 2016.). In this work, a siamese CNN is combined with size and position features of image patches to track multiple persons in the camera's field of view by detecting their positions in each video frame, learning the associations between frames and computing the trajectories.
- Cosegmentation (Mukherjee, Prerana, Brejesh Lall, and Snehith Lattupally. “Object cosegmentation using deep Siamese network.” arXiv preprint arXiv:1803.02555 (2018).).
- Matching resumes to jobs (Maheshwary, Saket, and Hemant Misra. “Matching Resumes to Jobs via Deep Siamese Network.” Companion Proceedings of The Web Conference 2018. International World Wide Web Conferences Steering Committee, 2018.). In this exotic application, the network tries to find matching job postings for applicants. To do this, a trained siamese CNN extracts deep contextual information from both the postings and the resumes and computes their semantic similarity. The hypothesis is that matching resume-posting pairs will rank higher on the similarity scale than non-matching ones.
Example: Classifying MNIST Images Using A Siamese Network In PyTorch
Having explained the fundamentals of siamese networks, we will now build a network in PyTorch to classify if a pair of MNIST images is of the same number or not. We will use the binary cross entropy loss as our training loss function and we will evaluate the network on a testing dataset using the accuracy measure. Below is the entire code for this post:
As you can see, most of the code consists of building an appropriate Dataset class that provides us with random image samples. For the purpose of training the network, it is crucial that we obtain a balanced dataset with as many positive as negative samples. Therefore, on each iteration, we provide both at the same time. The code for the dataset is quite long but ultimately simple: for each number (class) 0–9, we have to provide a positive pair (another image of the same number) and a negative pair (an image of a random different number).
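A minimal sketch of what such a balanced pair dataset might look like (the class and variable names are my own; the post's actual Dataset works on MNIST images, while this sketch accepts any labeled tensors):

```python
import random
import torch
from torch.utils.data import Dataset

class BalancedPairDataset(Dataset):
    """On each access, yields one positive pair (same class) and one
    negative pair (different class), keeping the dataset balanced."""

    def __init__(self, images, labels):
        self.images = images          # tensor of samples
        self.labels = labels          # tensor of integer class labels
        self.by_class = {}            # class -> list of sample indices
        for idx, y in enumerate(labels.tolist()):
            self.by_class.setdefault(y, []).append(idx)
        self.classes = list(self.by_class)

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        y = int(self.labels[idx])
        # positive partner: another sample of the same class, if one exists
        pos_candidates = [i for i in self.by_class[y] if i != idx] or [idx]
        pos_idx = random.choice(pos_candidates)
        # negative partner: a sample of a random different class
        neg_class = random.choice([c for c in self.classes if c != y])
        neg_idx = random.choice(self.by_class[neg_class])
        return (self.images[idx], self.images[pos_idx],
                self.images[idx], self.images[neg_idx])

# toy usage: six samples, three classes; each image value equals its index
images = torch.arange(6).float().view(6, 1)
labels = torch.tensor([0, 0, 1, 1, 2, 2])
pairs = BalancedPairDataset(images, labels)
anchor, pos, _, neg = pairs[0]
```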
The network itself, defined in the Net class, is a siamese convolutional neural network consisting of 2 identical subnetworks, each containing 3 convolutional layers with kernel sizes of 7, 5 and 5 and a pooling layer in-between. After passing through the convolutional layers, we let the network build a 1-dimensional descriptor of each input by flattening the features and passing them through a linear layer with 512 output features. Note that the layers in the two subnetworks share the same weights. This allows the network to learn meaningful descriptors for each input and makes the output symmetric (the ordering of the inputs should be irrelevant to our goal).
The crucial step of the whole procedure is the next one: we calculate the squared distance between the two feature vectors. In principle, to train the network, we could use the triplet loss on these squared distances. However, I obtained better results (faster convergence) using the binary cross-entropy loss. Therefore, we attach one more linear layer with 2 output features (equal number, different number) to the network to obtain the logits.
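Putting the description together, such a network might be sketched as follows (the channel counts and class name are my assumptions; the kernel sizes 7, 5 and 5, the 512-feature descriptor, the shared weights and the 2-logit head follow the text):

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class SiameseNet(nn.Module):
    """Sketch of the siamese CNN described above, for 1x28x28 MNIST inputs."""

    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 32, kernel_size=7)    # 28x28 -> 22x22
        self.pool = nn.MaxPool2d(2)                     # 22x22 -> 11x11
        self.conv2 = nn.Conv2d(32, 64, kernel_size=5)   # 11x11 -> 7x7
        self.conv3 = nn.Conv2d(64, 128, kernel_size=5)  # 7x7   -> 3x3
        self.fc = nn.Linear(128 * 3 * 3, 512)           # 512-d descriptor
        self.head = nn.Linear(512, 2)                   # same / different logits

    def forward_once(self, x):
        # shared-weight subnetwork: both inputs pass through the same layers
        x = self.pool(F.relu(self.conv1(x)))
        x = F.relu(self.conv2(x))
        x = F.relu(self.conv3(x))
        return self.fc(x.flatten(1))

    def forward(self, x1, x2):
        f1, f2 = self.forward_once(x1), self.forward_once(x2)
        return self.head((f1 - f2) ** 2)  # squared distance, then logits

net = SiameseNet()
logits = net(torch.randn(4, 1, 28, 28), torch.randn(4, 1, 28, 28))
```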
There are three main relevant functions in the code: the train function, the test function and the predict function.
In the train function, we feed the network a positive and a negative sample (two pairs of images). We calculate the losses for each of these and add them up (with the positive sample having a target of 1 and the negative sample having a target of 0).
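A sketch of such a training step (the helper name is mine, and the stand-in pair model below exists only to make the example self-contained; the post's network takes two images instead):

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

def train_step(model, optimizer, pos_pair, neg_pair):
    """One update: cross-entropy on a positive pair (target 1) and a
    negative pair (target 0); the two losses are added up."""
    optimizer.zero_grad()
    pos_logits = model(*pos_pair)     # logits of shape (batch, 2)
    neg_logits = model(*neg_pair)
    pos_target = torch.ones(pos_logits.size(0), dtype=torch.long)
    neg_target = torch.zeros(neg_logits.size(0), dtype=torch.long)
    loss = F.cross_entropy(pos_logits, pos_target) \
         + F.cross_entropy(neg_logits, neg_target)
    loss.backward()
    optimizer.step()
    return loss.item()

# illustrative stand-in model: squared distance of the inputs + linear head
class TinyPairModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.head = nn.Linear(8, 2)

    def forward(self, x1, x2):
        return self.head((x1 - x2) ** 2)

model = TinyPairModel()
opt = torch.optim.SGD(model.parameters(), lr=0.1)
x = torch.randn(4, 8)
loss = train_step(model, opt, (x, x + 0.01), (x, x + 1.0))
```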
The test function serves to measure the accuracy of the network on the test dataset. We perform the test after each training epoch to observe the training progress and to prevent overfitting.
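The accuracy measurement might be sketched like this (the function name and the trivial stand-in model are illustrative, not from the post):

```python
import torch

@torch.no_grad()
def evaluate(model, batches):
    """Accuracy over (x1, x2, target) batches; target 1 = same, 0 = different."""
    correct = total = 0
    for x1, x2, target in batches:
        pred = model(x1, x2).argmax(dim=1)   # pick the higher of the 2 logits
        correct += (pred == target).sum().item()
        total += target.numel()
    return correct / total

# trivial stand-in model that always predicts "same" (logit index 1)
class AlwaysSame(torch.nn.Module):
    def forward(self, x1, x2):
        return torch.tensor([[0.0, 1.0]]).repeat(x1.size(0), 1)

batch = (torch.zeros(4, 1), torch.zeros(4, 1), torch.tensor([1, 1, 0, 0]))
accuracy = evaluate(AlwaysSame(), [batch])   # half the targets are "same"
```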
The predict function, given a pair of MNIST images, simply predicts if they are of the same class or not. You can use predict after training is finished by setting the global variable do_learn to False.
Using the implementation above, I was able to achieve 96% accuracy on the test MNIST dataset.
Source: Deep Learning on Medium