How to train Detectron2 with Custom COCO Datasets

Source: Deep Learning on Medium

The PyTorch 1.3 release arrived together with Detectron2, a next-generation, ground-up rewrite of the earlier Detectron object detection framework. This tutorial will help you get started with the framework by training an instance segmentation model on your own COCO-format dataset. If you don’t know how to create a COCO dataset, please read my previous post: How to create custom COCO data set for instance segmentation.

For a quick start, we will run the experiment in a Colab notebook, so you don’t need to set up a development environment on your own machine before getting comfortable with PyTorch 1.3 and Detectron2.

Install Detectron2

In the Colab notebook, run the four install commands to get the latest PyTorch 1.3 and Detectron2.

Click “RESTART RUNTIME” in the cell’s output to let your installation take effect.

Register a COCO dataset

To tell Detectron2 how to obtain your dataset, we are going to “register” it.

To demonstrate this process, we use the fruits nuts segmentation dataset, which has only 3 classes: date, fig, and hazelnut. We’ll train a segmentation model starting from an existing model pre-trained on the COCO dataset, available in Detectron2’s model zoo.

You can download the dataset like this.

Or you can upload your own dataset from here.

Register the fruits_nuts dataset to detectron2, following the detectron2 custom dataset tutorial.
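A minimal sketch of the registration step, assuming the dataset was extracted so that the annotation file is ./data/trainval.json and the images live in ./data/images (the paths that appear in the metadata printed for this dataset). The helper name register_fruits_nuts is hypothetical; register_coco_instances is the function Detectron2 provides for COCO-format data. The import is deferred so the helper can be defined before Detectron2 is installed.

```python
DATASET_NAME = "fruits_nuts"
ANNOTATION_FILE = "./data/trainval.json"  # COCO-format annotation file
IMAGE_ROOT = "./data/images"              # directory that holds the images


def register_fruits_nuts(name=DATASET_NAME, json_file=ANNOTATION_FILE,
                         image_root=IMAGE_ROOT):
    """Register a COCO-format dataset under `name` (hypothetical helper)."""
    # Imported lazily so this file loads even without Detectron2 installed.
    from detectron2.data.datasets import register_coco_instances

    # The second argument is extra metadata; an empty dict is enough here.
    register_coco_instances(name, {}, json_file, image_root)
    return name
```

Calling register_fruits_nuts() once per session should be enough; registering the same name twice raises an error in Detectron2.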

Each dataset is associated with some metadata. In our case, you can access it by calling fruits_nuts_metadata = MetadataCatalog.get("fruits_nuts"), which returns:

Metadata(evaluator_type='coco', image_root='./data/images', json_file='./data/trainval.json', name='fruits_nuts',
thing_classes=['date', 'fig', 'hazelnut'], thing_dataset_id_to_contiguous_id={1: 0, 2: 1, 3: 2})
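The thing_dataset_id_to_contiguous_id entry maps the category ids stored in the COCO file to the zero-based class indices the model predicts. A plain-Python sketch of how such a mapping can be derived, assuming the category ids 1, 2, and 3 implied by the output above; contiguous_id_map is a hypothetical helper, not a Detectron2 function.

```python
def contiguous_id_map(categories):
    """Map COCO category ids to 0-based contiguous ids, ordered by id."""
    ordered = sorted(categories, key=lambda c: c["id"])
    return {c["id"]: i for i, c in enumerate(ordered)}


categories = [
    {"id": 1, "name": "date"},
    {"id": 2, "name": "fig"},
    {"id": 3, "name": "hazelnut"},
]
id_map = contiguous_id_map(categories)
thing_classes = [c["name"] for c in sorted(categories, key=lambda c: c["id"])]
print(id_map)         # {1: 0, 2: 1, 3: 2}
print(thing_classes)  # ['date', 'fig', 'hazelnut']
```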

The catalog stores information about the datasets and how to obtain them. To get the actual internal representation of the dataset, call dataset_dicts = DatasetCatalog.get("fruits_nuts"). The internal format uses one dict to represent the annotations of one image.
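A simplified sketch of that one-dict-per-image layout, built from a COCO-style dictionary in plain Python. The field names follow Detectron2's documented dataset format; the helper to_image_dicts and the sample values are made up for illustration, and the real loader additionally prefixes file_name with the image root and sets a bbox_mode field on every annotation.

```python
from collections import defaultdict


def to_image_dicts(coco, id_map):
    """Group COCO annotations by image, one dict per image (illustrative)."""
    anns_by_image = defaultdict(list)
    for ann in coco["annotations"]:
        anns_by_image[ann["image_id"]].append({
            "bbox": ann["bbox"],                        # [x, y, width, height]
            "category_id": id_map[ann["category_id"]],  # contiguous class index
            "segmentation": ann["segmentation"],        # list of polygons
        })
    return [
        {
            "file_name": img["file_name"],
            "height": img["height"],
            "width": img["width"],
            "image_id": img["id"],
            "annotations": anns_by_image[img["id"]],
        }
        for img in coco["images"]
    ]


# Made-up sample data, not taken from the fruits_nuts dataset.
coco = {
    "images": [{"id": 0, "file_name": "0.jpg", "height": 600, "width": 800}],
    "annotations": [
        {"image_id": 0, "category_id": 2, "bbox": [10, 20, 100, 50],
         "segmentation": [[10, 20, 110, 20, 110, 70, 10, 70]]},
    ],
}
records = to_image_dicts(coco, {1: 0, 2: 1, 3: 2})
print(records[0]["annotations"][0]["category_id"])  # 1
```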

To verify the data loading is correct, let’s visualize the annotations of randomly selected samples in the dataset:
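A sketch of this check, assuming dataset_dicts and fruits_nuts_metadata were obtained from the catalogs as described. Visualizer and its draw_dataset_dict method are part of Detectron2; the two helper names are hypothetical, and the imports are deferred so the sampling helper works on its own.

```python
import random


def sample_records(dataset_dicts, k=3, seed=None):
    """Pick up to k random records from the dataset."""
    rng = random.Random(seed)
    return rng.sample(dataset_dicts, min(k, len(dataset_dicts)))


def draw_annotations(record, metadata, scale=0.5):
    """Return an RGB array of the image with its ground-truth annotations drawn."""
    # Lazy imports: these need OpenCV and Detectron2 installed.
    import cv2
    from detectron2.utils.visualizer import Visualizer

    img = cv2.imread(record["file_name"])  # OpenCV loads images as BGR
    visualizer = Visualizer(img[:, :, ::-1], metadata=metadata, scale=scale)
    vis = visualizer.draw_dataset_dict(record)
    return vis.get_image()                 # RGB array
```

In a notebook you could loop over sample_records(dataset_dicts) and display each result of draw_annotations(d, fruits_nuts_metadata); if you display with an OpenCV-based viewer, reverse the channel order back to BGR first.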

One of the images might show this.

Train the model

Now, let’s fine-tune a COCO-pretrained R50-FPN Mask R-CNN model on the fruits_nuts dataset. Training for 300 iterations takes about 6 minutes on Colab’s K80 GPU.

In case you switch to your own datasets, change the number of classes, learning rate, or max iterations accordingly.
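A hedged sketch of such a training setup. The 3 classes and 300 iterations come from this post; the learning rate and batch size are illustrative placeholders to tune, and the model zoo config name is an assumption about which R50-FPN Mask R-CNN variant is used. The helper names are hypothetical, and the dataset must already be registered as fruits_nuts.

```python
def finetune_settings(num_classes=3, max_iter=300, base_lr=0.00025, ims_per_batch=2):
    """Collect the values that usually change when you switch datasets."""
    return {
        "num_classes": num_classes,      # date, fig, hazelnut
        "max_iter": max_iter,
        "base_lr": base_lr,              # placeholder value, tune for your data
        "ims_per_batch": ims_per_batch,  # placeholder value
    }


def train(settings, config_name="COCO-InstanceSegmentation/mask_rcnn_R_50_FPN_3x.yaml"):
    """Fine-tune from COCO-pretrained weights and return the config used."""
    import os
    from detectron2 import model_zoo
    from detectron2.config import get_cfg
    from detectron2.engine import DefaultTrainer

    cfg = get_cfg()
    cfg.merge_from_file(model_zoo.get_config_file(config_name))
    cfg.MODEL.WEIGHTS = model_zoo.get_checkpoint_url(config_name)
    cfg.DATASETS.TRAIN = ("fruits_nuts",)
    cfg.DATASETS.TEST = ()  # no evaluation during training
    cfg.SOLVER.IMS_PER_BATCH = settings["ims_per_batch"]
    cfg.SOLVER.BASE_LR = settings["base_lr"]
    cfg.SOLVER.MAX_ITER = settings["max_iter"]
    cfg.MODEL.ROI_HEADS.NUM_CLASSES = settings["num_classes"]

    os.makedirs(cfg.OUTPUT_DIR, exist_ok=True)
    trainer = DefaultTrainer(cfg)
    trainer.resume_or_load(resume=False)
    trainer.train()
    return cfg
```

For another dataset, something like train(finetune_settings(num_classes=5, max_iter=1000)) would be the starting point, after changing the dataset name inside train.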

Make a prediction

Now, we perform inference with the trained model on the fruits_nuts dataset. First, let’s create a predictor using the model we just trained:
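A minimal sketch, assuming cfg is the config object used for training and that the trainer wrote its final checkpoint as model_final.pth into cfg.OUTPUT_DIR. The 0.5 score threshold is an illustrative choice and the helper names are hypothetical.

```python
import os


def final_weights_path(output_dir="./output"):
    """Path of the checkpoint written when training finishes."""
    return os.path.join(output_dir, "model_final.pth")


def build_predictor(cfg, score_threshold=0.5):
    """Create a predictor from a copy of the training config."""
    from detectron2.engine import DefaultPredictor

    cfg = cfg.clone()  # leave the caller's config untouched
    cfg.defrost()      # make sure the copy can be modified
    cfg.MODEL.WEIGHTS = final_weights_path(cfg.OUTPUT_DIR)
    cfg.MODEL.ROI_HEADS.SCORE_THRESH_TEST = score_threshold  # drop weak detections
    cfg.DATASETS.TEST = ("fruits_nuts",)
    return DefaultPredictor(cfg)
```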

Then, we randomly select several samples to visualize the prediction results.
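A sketch of the visualization step, assuming a predictor created with DefaultPredictor and the fruits_nuts metadata. The helpers are hypothetical; draw_instance_predictions is the Visualizer method for model outputs. The small counting helper is an extra convenience for sanity-checking what was detected.

```python
from collections import Counter


def count_predictions(pred_classes, class_names):
    """Count predicted instances per class name."""
    return Counter(class_names[i] for i in pred_classes)


def draw_predictions(predictor, image_path, metadata, scale=0.8):
    """Run inference on one image; return an RGB array and per-class counts."""
    import cv2
    from detectron2.utils.visualizer import Visualizer

    img = cv2.imread(image_path)  # BGR, the predictor's default input format
    outputs = predictor(img)
    instances = outputs["instances"].to("cpu")
    visualizer = Visualizer(img[:, :, ::-1], metadata=metadata, scale=scale)
    vis = visualizer.draw_instance_predictions(instances)
    counts = count_predictions(instances.pred_classes.tolist(),
                               metadata.thing_classes)
    return vis.get_image(), counts


print(count_predictions([0, 2, 2], ["date", "fig", "hazelnut"]))
# Counter({'hazelnut': 2, 'date': 1})
```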

Here is a sample image with the predictions overlaid.

Conclusion and further thought

You might have read my previous tutorial on MMDetection, a similar object detection framework that is also built on PyTorch. So how does Detectron2 compare with it? Here are a few thoughts.

Both frameworks are easy to configure with a config file that describes how you want to train a model. Detectron2’s YAML config files are more efficient to work with for two reasons. First, you can reuse configs by creating a “base” config and building the final training configs on top of it, which reduces duplication. Second, the config file can be loaded first and then modified as needed in Python code, which makes it more flexible.
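In Detectron2's YAML files, a config names its parent through a _BASE_ key. The layering idea itself can be sketched in plain Python as a recursive dictionary merge. This is an illustration of the concept with made-up values, not Detectron2's actual implementation, and merge_config is a hypothetical helper.

```python
import copy


def merge_config(base, overrides):
    """Return a new dict: `base` with `overrides` applied recursively."""
    merged = copy.deepcopy(base)
    for key, value in overrides.items():
        if isinstance(value, dict) and isinstance(merged.get(key), dict):
            merged[key] = merge_config(merged[key], value)
        else:
            merged[key] = copy.deepcopy(value)
    return merged


# Made-up values for illustration.
base = {"MODEL": {"MASK_ON": False, "ROI_HEADS": {"NUM_CLASSES": 80}},
        "SOLVER": {"MAX_ITER": 90000}}
child = {"MODEL": {"MASK_ON": True}}                      # config built on the base
in_python = {"MODEL": {"ROI_HEADS": {"NUM_CLASSES": 3}},  # later tweaks made in code
             "SOLVER": {"MAX_ITER": 300}}

cfg = merge_config(merge_config(base, child), in_python)
print(cfg["MODEL"])   # {'MASK_ON': True, 'ROI_HEADS': {'NUM_CLASSES': 3}}
print(cfg["SOLVER"])  # {'MAX_ITER': 300}
```

Only the keys that differ need to be written in the child config or in code; everything else is inherited from the base.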

What about inference speed? Simply put, Detectron2 is slightly faster than MMDetection for the same Mask R-CNN ResNet-50 FPN model. MMDetection gets 2.45 FPS while Detectron2 achieves 2.59 FPS, a 5.7% speed boost for inference on a single image. The benchmark is based on the following code.
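A minimal sketch of such a benchmark, not necessarily the exact script behind the numbers above: time repeated calls after a short warm-up, convert to frames per second, and compare two FPS figures. The helper names and the stand-in workload are assumptions.

```python
import time


def measure_fps(run_once, n_runs=20, warmup=3):
    """Average calls per second of `run_once` after a few warm-up calls."""
    for _ in range(warmup):
        run_once()
    start = time.perf_counter()
    for _ in range(n_runs):
        run_once()
    elapsed = time.perf_counter() - start
    return n_runs / elapsed


def relative_speedup(baseline_fps, new_fps):
    """Speed-up of new_fps over baseline_fps, in percent."""
    return (new_fps / baseline_fps - 1.0) * 100.0


# Stand-in workload; with a real model this would be e.g. lambda: predictor(img).
fps = measure_fps(lambda: sum(range(10_000)))
print(f"{fps:.1f} calls per second")
print(f"{relative_speedup(2.45, 2.59):.1f}% faster")  # 5.7% faster
```

On a GPU, timings tend to be more reliable when the device is synchronized before the clock is read, since GPU work is launched asynchronously.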

So there you have it: Detectron2 makes it simple to train a custom instance segmentation model on your own dataset. You might find the following resources helpful.

My previous post — How to create custom COCO data set for instance segmentation.

My previous post — How to train an object detection model with mmdetection.

Detectron2 GitHub repository.

The runnable Colab Notebook for this post.