Learn How to Train U-Net On Your Dataset

Fig.1: A test image along with its label (semantically segmented output)

With the aim of performing semantic segmentation on a small bio-medical data-set, I made a resolute attempt at demystifying the workings of U-Net, using Keras. Since I haven’t come across any article that explains the training steps systematically, I thought of documenting them for other deep learning enthusiasts. Some of you may be wondering whether I cover the theoretical aspects of this framework. Although my primary focus is the implementation, I try to include the details needed to understand it as well. My main references are zhixuhao’s unet repository on GitHub and the paper ‘U-Net: Convolutional Networks for Biomedical Image Segmentation’ by Olaf Ronneberger et al.

About U-Net

U-Net is an encoder-decoder architecture. Essentially, it is a deep-learning framework based on fully convolutional networks (FCNs), and it comprises two parts:

  1. A contracting path similar to an encoder, to capture context via a compact feature map.
  2. A symmetric expanding path similar to a decoder, which allows precise localisation. Skip connections copy the encoder’s feature maps into this path, so boundary (spatial) information is retained despite the downsampling done by max-pooling in the encoder stage.
Fig.2: Architecture of U-Net based on the paper by Olaf Ronneberger et al.
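To make the two paths concrete, here is a small sketch that traces the feature-map size and channel count through each level. unet_shapes is a hypothetical helper, not part of zhixuhao’s code. It assumes 'same'-padded convolutions (so only pooling and up-sampling change the spatial size), four pooling steps, and the paper’s channel progression from 64 up to 1024. The paper itself uses unpadded convolutions, so its feature maps also shrink slightly at every convolution.

```python
def unet_shapes(size=256, base_filters=64, depth=4):
    """Return (stage, spatial size, channels) per level of a 'same'-padded U-Net.

    Assumption: only 2X2 pooling / up-convolution change the spatial size.
    """
    shapes = []
    filters = base_filters
    # Contracting path: record each level, then pool.
    for level in range(depth):
        shapes.append(("encoder %d" % (level + 1), size, filters))
        size //= 2     # 2X2 max-pooling halves height and width
        filters *= 2   # the channel count doubles at the next level
    shapes.append(("bottleneck", size, filters))
    # Expanding path: mirror of the contracting path.
    for level in reversed(range(depth)):
        size *= 2      # 2X2 up-convolution doubles height and width
        filters //= 2  # and halves the channel count
        shapes.append(("decoder %d" % (level + 1), size, filters))
    return shapes
```

For a 256X256 input, the bottleneck is 16X16 with 1024 channels, and every decoder level comes back to the size of its encoder counterpart, which is what lets the skip connections concatenate the two.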

Advantages of Using U-Net

  1. Computationally efficient
  2. Trainable with a small data-set
  3. Trained end-to-end
  4. Preferable for bio-medical applications

Wait, what on earth are semantic segmentation and localisation? In simple terms, they refer to pixel-level labelling, i.e. each pixel in an image is assigned a class label. The result is a segmentation map like the one in Fig.1. Prior to this, bio-medical researchers followed two approaches:

  1. Classifying the image as a whole (malignant or benign).
  2. Dividing the images into patches and classifying them.

Since patching increases the number of training samples, it was certainly better than whole-image classification; however, it had a few drawbacks. First, small strides, i.e. patches with a lot of overlap, are computationally intensive and produce redundant (repetitive) information. Second, a good tradeoff between context information and localisation is vital: small patches lose contextual information, whereas large patches degrade the localisation results. Lastly, non-overlapping patches also lose context information. Based on prior observations, an encoder-decoder architecture yields much higher intersection over union (IoU) values than feeding each pixel to a CNN for classification.
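IoU compares a predicted mask with the ground truth: the number of pixels both mark as foreground, divided by the number of pixels either marks as foreground. A minimal NumPy sketch for binary masks (the function name and the empty-mask convention are my own choices):

```python
import numpy as np


def iou(pred_mask, true_mask):
    """Intersection over union of two binary masks of the same shape, in [0, 1]."""
    pred = np.asarray(pred_mask, dtype=bool)
    true = np.asarray(true_mask, dtype=bool)
    union = np.logical_or(pred, true).sum()
    if union == 0:
        # Both masks are empty: treat as perfect agreement (a convention).
        return 1.0
    return float(np.logical_and(pred, true).sum() / union)
```

For example, a prediction [[1, 1, 0, 0]] against a ground truth [[0, 1, 1, 0]] overlaps in one pixel out of three marked pixels, giving an IoU of 1/3.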

Preparing the dataset

Let’s dive into preparing your data-set! I will be writing a separate article covering the nitty-gritty of U-Net, so don’t rack your brains or get bogged down in its architecture.

  1. Divide your original data-set and the corresponding annotations into two groups, namely training and testing (validation, to be more precise) sets. The original images are in RGB, while their masks are binary (black and white).
  2. Convert all of the image data to .tif.
  3. You won’t need the image class labels or the annotations of the test set (this is not a classification problem).
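Step 1 can be scripted. The sketch below is illustrative rather than from the original: split_pairs, the 15% default validation fraction and the fixed seed are assumptions, and it relies on images and masks sharing a naming scheme so that sorting both lists pairs them correctly. Shuffling the pairs, rather than the two lists separately, keeps every mask attached to its image.

```python
import random


def split_pairs(image_files, mask_files, val_fraction=0.15, seed=0):
    """Shuffle (image, mask) pairs together and split them into train/validation lists."""
    if len(image_files) != len(mask_files):
        raise ValueError("each image needs exactly one mask")
    # Assumption: sorted order lines up each image with its mask.
    pairs = list(zip(sorted(image_files), sorted(mask_files)))
    random.Random(seed).shuffle(pairs)  # fixed seed -> reproducible split
    n_val = max(1, int(round(len(pairs) * val_fraction)))
    return pairs[n_val:], pairs[:n_val]  # (train, validation)
```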

NOTE: Choose the image size such that the successive convolutions and max-pooling operations yield even values of x and y (i.e. the width and height of the feature map) at each stage. Although my images were 360X360 pixels, I resized them to 256X256. Crop out the borders to get an appropriate image size. I have included the code for this.
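Why 256 works and 360 does not: assuming a U-Net with four 2X2 max-pooling layers and 'same'-padded convolutions (check this against your copy of unet.py), the size must be even before every pooling, i.e. divisible by 2^4 = 16. Halving 360 gives 180, 90 and then 45, which is odd, so after pooling and up-sampling the decoder map would come back as 44 and no longer match the encoder map it is concatenated with. The two helpers below are hypothetical, for checking a candidate size:

```python
def survives_pooling(size, num_pools=4):
    """True if `size` is even before each of `num_pools` 2X2 max-poolings."""
    for _ in range(num_pools):
        if size % 2:
            return False
        size //= 2
    return True


def largest_valid_size(size, num_pools=4):
    """Largest size <= `size` that is divisible by 2**num_pools."""
    return size - size % (2 ** num_pools)
```

Under these assumptions, cropping 360X360 images to 352X352 would also satisfy the constraint, as does resizing to 256X256.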

Cropping and conversion to .tif
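A minimal sketch of this step (an illustration, not the original script). It assumes Pillow, which Keras’s load_img already relies on, source images in .png format, and a 256X256 center crop; crop_and_convert and center_crop are hypothetical names. Run it separately on the image folder and the mask folder, with source files named so that images and masks sort in the same order, and the outputs will come out as 0.tif, 1.tif and so on.

```python
from pathlib import Path

import numpy as np
from PIL import Image


def center_crop(arr, size=256):
    """Return the central size x size window of an H x W (x C) array."""
    h, w = arr.shape[:2]
    if h < size or w < size:
        raise ValueError("image is smaller than the crop size")
    top, left = (h - size) // 2, (w - size) // 2
    return arr[top:top + size, left:left + size]


def crop_and_convert(src_dir, dst_dir, size=256, pattern="*.png"):
    """Center-crop every matching image in src_dir and save it as dst_dir/<i>.tif."""
    dst = Path(dst_dir)
    dst.mkdir(parents=True, exist_ok=True)
    files = sorted(Path(src_dir).glob(pattern))  # sorted so numbering is deterministic
    for i, path in enumerate(files):
        arr = np.asarray(Image.open(path))
        cropped = np.ascontiguousarray(center_crop(arr, size))
        # The .tif extension makes Pillow write a TIFF file.
        Image.fromarray(cropped).save(dst / "{}.tif".format(i))
    return len(files)
```

If you prefer resizing over cropping, Pillow’s Image.resize does that; for masks, use nearest-neighbour resampling so the labels stay binary.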

You must be wondering what to name the folders, yada yada yada... wait a few more minutes and you’ll know exactly where to place them! First, make sure you have the following dependencies:

  • TensorFlow
  • Keras >= 1.0
  • libtiff (optional)
  • OpenCV 3 (if you’re a Mac user, follow this for the installation)

Also, this code should be compatible with Python versions 2.7–3.5.

Training & Data-augmentation

You can rotate, reflect and warp the images if need be. zhixuhao uses a deformation method available here. Follow the next few steps carefully; missing a step can make you go bonkers for hours! For simplicity, I have divided the training stage into two parts: Part A and Part B.
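Whatever transforms you use, an image and its mask must receive exactly the same one, or the labels stop lining up with the pixels. A NumPy-only sketch of joint rotation and reflection follows; augment_pair is a hypothetical helper, warping (such as the deformation method zhixuhao uses) is not covered, and square images are assumed because a 90° rotation swaps height and width.

```python
import numpy as np


def augment_pair(image, mask, rng):
    """Apply the same random 90-degree rotation and flips to an image and its mask.

    image: H x W x C array, mask: H x W x 1 array, rng: numpy Generator.
    """
    k = rng.integers(4)  # number of 90-degree rotations (0-3)
    image = np.rot90(image, k, axes=(0, 1))
    mask = np.rot90(mask, k, axes=(0, 1))
    if rng.random() < 0.5:  # horizontal reflection
        image, mask = image[:, ::-1], mask[:, ::-1]
    if rng.random() < 0.5:  # vertical reflection
        image, mask = image[::-1], mask[::-1]
    return np.ascontiguousarray(image), np.ascontiguousarray(mask)


rng = np.random.default_rng(0)  # seeded for reproducibility
```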

Part A- Modifying data.py

1. Clone zhixuhao’s repository.

$ git clone https://github.com/zhixuhao/unet

2. Enter the image folder, which lies within the train folder (../unet/data/train/image).

3. Include the training images in the image folder. Each image should be in .tif format, named consecutively, starting from 0.tif, 1.tif… and so on.

4. Get into the label folder, which lies within the train folder (../unet/data/train/label), and include the corresponding annotations for the training images. Each label should also be a .tif file with the same name as its training image (0.tif, 1.tif… and so on), so that the labels correspond one-to-one with the training image-set.

5. Enter the test folder, which lies within the data folder (../unet/data/test), and include your test images in it, also in .tif format.

6. Create a folder called npydata within the data folder (../unet/data/npydata). Leave it empty; the processed data-set will later be saved in it as three .npy files.

7. Open the data.py file in the unet folder (../unet/data.py). Steps 8, 9, 10 and 11 describe the changes you will have to make in this file for RGB images. The snippets below show only the lines I changed.

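8. Modify def create_train_data(self) as shown below.

def create_train_data(self):
    # ... (unchanged lines omitted)
    imgdatas = np.ndarray((len(imgs), self.out_rows, self.out_cols, 3), dtype=np.uint8)   # 3 channels for RGB images
    imglabels = np.ndarray((len(imgs), self.out_rows, self.out_cols, 1), dtype=np.uint8)  # masks stay single-channel
    for imgname in imgs:
        # ... (unchanged lines omitted)
        img = load_img(self.data_path + "/" + midname)  # Removed grayscale = True so the image stays RGB
        label = load_img(self.label_path + "/" + midname, grayscale=True)  # binary mask, loaded as grayscale
        # ...
# Correspond to lines 159-164 of data.py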

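9. Modify def create_test_data(self) as shown below.

def create_test_data(self):
    # ... (unchanged lines omitted)
    imgdatas = np.ndarray((len(imgs), self.out_rows, self.out_cols, 3), dtype=np.uint8)  # 3 channels for RGB images
    for imgname in imgs:
        # ... (unchanged lines omitted)
        img = load_img(self.test_path + "/" + midname)  # Removed grayscale = True so the image stays RGB
        # ...
# Correspond to lines 188 and 191 of data.py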

10. Change imgnum to the size of your training set.

def doAugmentate(self, img, save_to_dir, save_prefix, batch_size=1, save_format='tif', imgnum=26):  # I've considered 26 training images

11. In class dataProcess(object), provide the correct path details: I changed data_path, label_path, test_path and npy_path to the corresponding directories on my system. You can edit these in line 138 of data.py. In case a few errors pop up, go through my answers to issue #40 on GitHub. Make sure the path to the npydata folder is correct (a wrong path here is a common mistake).

def __init__(self, out_rows, out_cols,
             data_path="/Users/sukritipaul/Dev/newcvtestpy2/unet2/data/train/image",
             label_path="/Users/sukritipaul/Dev/newcvtestpy2/unet2/data/train/label",
             test_path="/Users/sukritipaul/Dev/newcvtestpy2/unet2/data/test",
             npy_path="/Users/sukritipaul/Dev/newcvtestpy2/unet2/data/npydata",
             img_type="tif"):
# Corresponds to line 138; replace these paths with the ones on your system

12. Run data.py

$ python data.py
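If you would rather script steps 2 to 6 than copy files by hand, here is a minimal sketch. setup_unet_data is a hypothetical helper of my own, not part of zhixuhao’s repository; it assumes your raw image and mask files are already .tif and share a naming scheme, so that sorting both lists pairs each image with its mask. Test images still go into data/test separately.

```python
import shutil
from pathlib import Path


def setup_unet_data(unet_root, image_files, label_files):
    """Copy paired .tif files into the layout data.py expects, renamed 0.tif, 1.tif, ...

    Image i and label i end up with the same name, so data.py pairs them.
    """
    if len(image_files) != len(label_files):
        raise ValueError("every training image needs exactly one label")
    root = Path(unet_root) / "data"
    img_dir = root / "train" / "image"
    lbl_dir = root / "train" / "label"
    # Create train/image, train/label, test and the (empty) npydata folder.
    for d in (img_dir, lbl_dir, root / "test", root / "npydata"):
        d.mkdir(parents=True, exist_ok=True)
    # Assumption: sorted order pairs each image with its label.
    pairs = zip(sorted(image_files), sorted(label_files))
    for i, (img, lbl) in enumerate(pairs):
        shutil.copy(str(img), str(img_dir / "{}.tif".format(i)))
        shutil.copy(str(lbl), str(lbl_dir / "{}.tif".format(i)))
    return img_dir, lbl_dir
```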

Part A- Verification

Fig.3: Output obtained on running data.py

Your output should match the output in Fig.3. If your terminal shows ‘Done: 0/<some_value> images’, then your files have not been included in the .npy files. Check if you have the following files in /unet/data/npydata:

  1. imgs_mask_train.npy
  2. imgs_test.npy
  3. imgs_train.npy
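If something looks off, a quick check of the paths and of the saved arrays helps narrow it down. Both functions below are hypothetical helpers, sketched under the assumption that the paths are the ones you passed to dataProcess and that data.py writes the three files listed above.

```python
import glob
import os

import numpy as np


def check_data_paths(data_path, label_path, test_path, npy_path, img_type="tif"):
    """Return human-readable problems with the dataProcess paths (empty list if none)."""
    problems = []
    for name, path in [("data_path", data_path), ("label_path", label_path),
                       ("test_path", test_path), ("npy_path", npy_path)]:
        if not os.path.isdir(path):
            problems.append("%s does not exist: %s" % (name, path))
    n_img = len(glob.glob(os.path.join(data_path, "*." + img_type)))
    n_lbl = len(glob.glob(os.path.join(label_path, "*." + img_type)))
    if n_img == 0:
        problems.append("no .%s files found in data_path" % img_type)
    if n_img != n_lbl:
        problems.append("%d training images but %d labels" % (n_img, n_lbl))
    return problems


def describe_npydata(npy_dir):
    """Map each of the three .npy files written by data.py to its (shape, dtype)."""
    info = {}
    for name in ("imgs_train.npy", "imgs_mask_train.npy", "imgs_test.npy"):
        arr = np.load(os.path.join(npy_dir, name))
        info[name] = (arr.shape, str(arr.dtype))
    return info
```

With the RGB changes from steps 8 and 9, imgs_train.npy and imgs_test.npy should have a last dimension of 3 and imgs_mask_train.npy a last dimension of 1; a first dimension of 0 means no images were picked up.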

Wohooo! You’re done with the data preparation bit :)

Fig.4: Result on the ISBI cell tracking challenge: input and cyan mask

Part B- Modifying unet.py

1. Create a results folder in the unet folder (../unet/results). If you’re wondering why, you’ll find out in a jiffy!

2. Open unet.py (../unet/unet.py)

3. Edit the number of rows and columns in the following line. My images had dimensions of 256X256.

def __init__(self, img_rows = 256, img_cols = 256):
#Corresponds to line 13

4. Modify the following for a 3-channel input, in get_unet(self), as shown below.

def get_unet(self):
    inputs = Input((self.img_rows, self.img_cols, 3))  # height, width, 3 channels (RGB)
# Corresponds to line 27

5. Modify the following line in train(self).

def train(self):
    # ... (unchanged lines omitted)
    np.save('/Users/sukritipaul/Dev/newcvtestpy2/unet2/results/imgs_mask_test.npy', imgs_mask_test)
    # Note that the address of the results directory must be provided
# Corresponds to line 164

6. Modify def save_img(self), keeping in mind the address of the results directory specified in step 5.

def save_img(self):
    imgs = np.load('/Users/sukritipaul/Dev/newcvtestpy2/unet2/results/imgs_mask_test.npy')  # Use the same address as in train()
    for i in range(imgs.shape[0]):
        # ... (unchanged: img is the i-th predicted map converted to an image)
        img.save("/Users/sukritipaul/Dev/newcvtestpy2/unet2/results/%d.jpg" % (i))  # Saves the segmented maps in /results
# Corresponds to lines 169 and 173

7. Run unet.py and wait; training can take anywhere from a couple of minutes to hours, depending on the data-set size and your hardware.

$ python unet.py

Part B- Verification

Visit /unet/results and voilà! You can find the generated image masks, or segmented feature maps, in grayscale :)
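These grayscale maps show the network’s per-pixel outputs, so edges can look soft. If you want hard black-and-white masks, threshold them. The sketch assumes the predictions in imgs_mask_test.npy have shape (N, H, W, 1) with values in [0, 1]; save_masks and the 0.5 threshold are my own choices. It writes PNG rather than .jpg because JPEG compression blurs the sharp edges of a binary mask.

```python
import os

import numpy as np
from PIL import Image


def save_masks(npy_file, out_dir, threshold=0.5):
    """Binarise predicted masks of shape (N, H, W, 1) and save each as out_dir/<i>.png."""
    preds = np.load(npy_file)
    os.makedirs(out_dir, exist_ok=True)
    for i, pred in enumerate(preds):
        # 0 = background, 255 = foreground
        binary = (pred[..., 0] > threshold).astype(np.uint8) * 255
        Image.fromarray(binary).save(os.path.join(out_dir, "%d.png" % i))
    return len(preds)
```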

In conclusion, I got an overall accuracy of 90.71% on 65 training images and 10 validation images of size 256X256. zhixuhao’s implementation uses binary cross-entropy as the loss function. It took me approximately 3 hours to train on an 8 GB MacBook Air with a 1.8 GHz Intel Core i5 processor (CPU).
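For reference, here is how a per-pixel accuracy and the binary cross-entropy loss can be computed for a predicted mask. This is a sketch: exactly how the 90.71% figure was computed is not specified, and thresholding at 0.5 is an assumption. On masks dominated by background, pixel accuracy can look high even when the foreground is poorly segmented, which is why IoU is often reported alongside it.

```python
import numpy as np


def pixel_accuracy(pred_probs, true_mask, threshold=0.5):
    """Fraction of pixels whose thresholded prediction matches the binary ground truth."""
    pred = np.asarray(pred_probs) > threshold
    return float(np.mean(pred == np.asarray(true_mask, dtype=bool)))


def binary_cross_entropy(pred_probs, true_mask, eps=1e-7):
    """Mean per-pixel binary cross-entropy: -mean(y*log(p) + (1-y)*log(1-p))."""
    # Clip probabilities away from 0 and 1 so the logarithms stay finite.
    p = np.clip(np.asarray(pred_probs, dtype=float), eps, 1 - eps)
    y = np.asarray(true_mask, dtype=float)
    return float(-np.mean(y * np.log(p) + (1 - y) * np.log(1 - p)))
```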

Feel free to include your queries as comments! :)

Source: Deep Learning on Medium