Semantic segmentation for dummies

Semantic segmentation is one of the most popular topics in computer vision, even more so in recent years with the adoption of deep learning techniques. The image below shows an example of a semantic segmentation result.

There are many models capable of solving the problem, and while they are very accurate, the network topologies behind them are quite complex: things like RPNs, FPNs, ROIs, and anchors can be confusing for anyone just starting out with segmentation. That is why I decided to share the simple approach I took.

The thing is, a while ago I faced a problem at work that I knew was best solved with semantic segmentation, but at the time there were no complex networks like Mask R-CNN, so I had to come up with a valid topology on my own: one easy enough to explain to my coworkers and not too computationally expensive (the latter due to timing constraints imposed by the problem itself).

After many tries, the winning idea came to mind. I thought: if autoencoders are used to compress an image's underlying information, and later for denoising, why not create an encoder/decoder topology whose input is the image to segment and whose output is a binary mask representing the segmentation (for a single class, that is…)? Something similar to the topology shown in the image below.

Encoder/decoder topology for segmentation
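
The post shows the idea as a diagram only; just to make it concrete, here is a minimal Keras sketch of that single-mask encoder/decoder. Layer counts and sizes are my own placeholders, not the exact model:

```python
# Minimal single-mask encoder/decoder sketch, assuming Keras.
# Filter counts and depths are illustrative placeholders.
from tensorflow.keras import layers, models

inp = layers.Input(shape=(512, 512, 1))  # grayscale input image

# Encoder: shrink spatial resolution, grow feature channels
x = layers.Conv2D(16, 3, padding='same', activation='relu')(inp)
x = layers.MaxPooling2D(2)(x)
x = layers.Conv2D(32, 3, padding='same', activation='relu')(x)
x = layers.MaxPooling2D(2)(x)

# Decoder: recover spatial resolution
x = layers.Conv2DTranspose(32, 3, padding='same', activation='relu')(x)
x = layers.UpSampling2D(2)(x)
x = layers.Conv2DTranspose(16, 3, padding='same', activation='relu')(x)
x = layers.UpSampling2D(2)(x)

# Single-channel binary segmentation mask
out = layers.Conv2D(1, 1, activation='sigmoid')(x)

model = models.Model(inp, out)
model.compile(optimizer='adam', loss='binary_crossentropy')
```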

As my initial problem was to perform segmentation of only one object per image, the first model returned a single binary mask. Over time, though, the problem evolved and I had to make small changes to support multiple classes. The first thing I did was use a C-channel output mask: one channel for each object class plus one for unknown objects.

I first tried to frame the problem as a probability distribution (that is, a softmax activation for the last layer), meaning the channel values at each pixel were the probabilities of that pixel belonging to each class. But since those probabilities must sum to one, I realised this would not allow objects from different classes to overlap.

So I ended up using the classic C independent per-channel binary regressions, where for each channel 0 means the pixel does not belong to that class and 1 means it does.
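
In Keras terms, the difference between the two options boils down to the final layer's activation and the matching loss; a quick sketch:

```python
# The two output heads discussed above, assuming Keras.
from tensorflow.keras import layers

C = 5  # number of classes

# Option 1: probability distribution per pixel. Channel values sum
# to 1, so overlapping objects of different classes cannot be expressed.
softmax_head = layers.Conv2D(C, 1, activation='softmax')
# paired with categorical cross-entropy

# Option 2: C independent binary regressions per pixel. Each channel
# is its own 0/1 mask, so classes may overlap freely.
sigmoid_head = layers.Conv2D(C, 1, activation='sigmoid')
# paired with binary cross-entropy
```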

Dataset

So, for demonstration purposes, I built a synthetic dataset to test the multi-class semantic segmentation model. It consists of 512×512 images: a grayscale one for input (it could be color, you would probably just need more parameters in your model) and an output image with 5 channels, one per figure type (circles, triangles, quads and pentagons) plus 1 for unknown objects. The images below show samples of input/output pairs.

Dataset input/output pair samples. The left part is the input image and the right part shows the output masks
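
The post does not include the generation code, but a hypothetical OpenCV sketch of how one such input/mask pair could be drawn looks like this (shapes and positions are arbitrary):

```python
# Hypothetical synthetic-sample sketch; not the actual dataset code.
import numpy as np
import cv2

H = W = 512
C = 5  # circles, triangles, quads, pentagons + unknown

image = np.zeros((H, W), dtype=np.uint8)     # grayscale input
# channels-first so each mask slice is contiguous for OpenCV drawing
masks = np.zeros((C, H, W), dtype=np.uint8)

# a filled circle, drawn on both the input and its class channel
center, radius = (128, 128), 40
cv2.circle(image, center, radius, 255, -1)
cv2.circle(masks[0], center, radius, 1, -1)

# a filled triangle, likewise
pts = np.array([[300, 100], [360, 200], [240, 200]], dtype=np.int32)
cv2.fillPoly(image, [pts], 255)
cv2.fillPoly(masks[1], [pts], 1)

# channels-last (H, W, C) for Keras
masks = np.transpose(masks, (1, 2, 0))
```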

An important detail you may have noticed is that the output image does not look like a 5-channel image. That is because, for visualization purposes, I assigned a color to each channel and then composed a single color image.
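
One possible way to do that composition, assuming channels-last masks of shape (H, W, 5) and an arbitrary palette of my choosing:

```python
# Hypothetical mask-to-color visualization helper.
import numpy as np

PALETTE = np.array([
    [255, 0, 0],    # circles
    [0, 255, 0],    # triangles
    [0, 0, 255],    # quads
    [255, 255, 0],  # pentagons
    [255, 0, 255],  # unknown
], dtype=np.float32)

def masks_to_color(masks):
    """Blend each binary channel with its color into one RGB image."""
    # (H, W, 5) x (5, 3) -> (H, W, 3); overlaps may exceed 255, so clip
    rgb = np.tensordot(masks.astype(np.float32), PALETTE, axes=([-1], [0]))
    return np.clip(rgb, 0, 255).astype(np.uint8)
```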

Topology

With the dataset described, the next step is to describe the network topology. As already mentioned, it is mainly an encoder/decoder with a few particularities.

First off, some notes about several structures used in the model that you may need to be familiar with:

  • Transposed convolution. This layer is also known as fractionally strided convolution or deconvolution, and it basically performs the “inverse” of what a convolution does.
  • Batch normalization. A commonly used mechanism that keeps mini-batch values from spreading too widely, normalizing them against a moving mean and moving standard deviation. It comes in handy both to reduce overfitting and to speed up training.
  • Residual connections. A mechanism first used in ResNet that carries information from the block input to its output, helping to improve the model's performance during the training phase (see the sketch after this list).
  • Dropout. You may already be familiar with this layer, but just in case: it probabilistically drops certain connections during training, which helps reduce overfitting.
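
Here is the promised sketch of a residual connection in Keras (illustrative only, not a block taken from the actual model):

```python
# Minimal residual-connection sketch, assuming Keras.
from tensorflow.keras import layers

def residual_block(x, filters=32):
    """Conv path plus identity path, merged with an element-wise add.

    Assumes `x` already has `filters` channels so the add is
    shape-compatible.
    """
    shortcut = x
    y = layers.Conv2D(filters, 3, padding='same', activation='relu')(x)
    y = layers.Conv2D(filters, 3, padding='same')(y)
    return layers.Activation('relu')(layers.Add()([shortcut, y]))
```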

On one side, the encoder has N encoding blocks (3 in this example). Each block applies multiple convolutional layers with different kernel sizes (3×3 and 7×7 in this case) to its input, and their outputs are glued together with a concat/residual lambda function. After the convolutional layers come batch normalization, activation, dropout and max-pooling layers. The encoder topology can be seen in the image below.

Encoder topology
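
Translating that description into a rough Keras helper (filter counts, dropout rate and activation are my assumptions; the diagram above is the authoritative reference):

```python
# Sketch of one encoding block, using the concat variant of the glue.
from tensorflow.keras import layers

def encoder_block(x, filters=32, dropout=0.2):
    # parallel convolutions with different receptive fields
    b3 = layers.Conv2D(filters, 3, padding='same')(x)
    b7 = layers.Conv2D(filters, 7, padding='same')(x)
    # glue the branches together
    x = layers.Concatenate()([b3, b7])
    x = layers.BatchNormalization()(x)
    x = layers.Activation('relu')(x)
    x = layers.Dropout(dropout)(x)
    return layers.MaxPooling2D(2)(x)  # halve spatial resolution
```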

On the other side, the decoder is pretty similar to the encoder, but it uses transposed convolutional layers instead of convolutional layers and, naturally, upsampling layers instead of pooling layers. The last major difference is a final layer with as many channels as classes (5) and a sigmoid activation. The decoder topology can be seen in the image below.

Decoder topology
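
And the matching decoder-side sketch, under the same assumptions, including the final sigmoid head:

```python
# Sketch of one decoding block plus the multi-class output head.
from tensorflow.keras import layers

def decoder_block(x, filters=32, dropout=0.2):
    # transposed convolutions in place of regular convolutions
    b3 = layers.Conv2DTranspose(filters, 3, padding='same')(x)
    b7 = layers.Conv2DTranspose(filters, 7, padding='same')(x)
    x = layers.Concatenate()([b3, b7])
    x = layers.BatchNormalization()(x)
    x = layers.Activation('relu')(x)
    x = layers.Dropout(dropout)(x)
    return layers.UpSampling2D(2)(x)  # double spatial resolution

def output_head(x, classes=5):
    # one independent sigmoid channel per class, as discussed earlier
    return layers.Conv2D(classes, 1, activation='sigmoid')(x)
```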

Results

After describing the network topology, it is time to show results. Predictions were made with a model trained for ~900 epochs over ~10 h on a p2.xlarge AWS instance, reaching ~4e-3 loss, ~3e-3 MSE and ~99.90% accuracy. The following images show the ground truth on the left and the predicted masks on the right.
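
The exact training code lives in the notebook linked below; as a rough idea of the setup, a Keras run matching the reported metrics could look like this, where `model` and the data arrays (`train_images`, `train_masks`, `val_images`, `val_masks`) are assumed to exist:

```python
# Hypothetical training setup; only the metrics mirror the ones reported.
model.compile(optimizer='adam',
              loss='binary_crossentropy',   # per-channel binary regression
              metrics=['mse', 'accuracy'])  # the MSE/accuracy quoted above
model.fit(train_images, train_masks,
          batch_size=8,                     # assumed; not stated in the post
          epochs=900,
          validation_data=(val_images, val_masks))
```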

Prediction sample 1
Prediction sample 2
Prediction sample 3
Prediction sample 4
Prediction sample 5
Prediction sample 6

As you can see, the results for this dataset are almost perfect, though admittedly the dataset is very simple. Still, I am confident the network topology can be extended to more complex datasets.

Further work

Many segmentation models use a mechanism called skip connections: picking up feature maps from the encoding phase and connecting them to their counterparts in the decoding phase. This is supposed to counter the loss of resolution caused by the pooling layers. It may be an improvement to try in the future, in case the model stops learning or the results are not good enough for your dataset.
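
To make the idea concrete, here is one way skip connections could be retrofitted onto this topology, reusing the hypothetical block helpers sketched earlier (a U-Net-style variant, not something the original model does):

```python
# Sketch of the encoder/decoder with U-Net-style skip connections,
# reusing the encoder_block/decoder_block/output_head helpers above.
from tensorflow.keras import layers, models

def build_with_skips(input_shape=(512, 512, 1), classes=5):
    inp = layers.Input(shape=input_shape)
    e1 = encoder_block(inp)            # 256x256 features
    e2 = encoder_block(e1)             # 128x128 features
    x = encoder_block(e2)              # 64x64 bottleneck
    x = decoder_block(x)               # back to 128x128
    x = layers.Concatenate()([x, e2])  # skip connection from the encoder
    x = decoder_block(x)               # back to 256x256
    x = layers.Concatenate()([x, e1])  # second skip connection
    x = decoder_block(x)               # back to 512x512
    return models.Model(inp, output_head(x, classes))
```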

On the other hand, there is an issue I noticed that comes from the network topology itself: if two objects of the same class overlap, there is no way to tell that there are two objects should you run a computer vision post-process. This is the main reason Mask R-CNN uses ROIs: it focuses on one object region at a time rather than on a connected group of pixels of the same class. This issue may be impossible to overcome with this topology, but maybe someone can come up with a solution.

Conclusions

I have explained one approach to semantic segmentation that I hope comes in handy for you; it certainly has for me.

A Jupyter notebook to test it is in this GitHub repository; feel free to fork, comment and give it some stars.

Thank you so much for reading, and please let me know if you are interested in the topic so we can exchange knowledge.
