Chainer model inference in Java, with ONNX and Apache MXNet

Source: Deep Learning on Medium

Authors: Vandana Kannan, Roshani Nagmote

In this post, we’ll see how to convert a model trained in Chainer to ONNX format and import it in MXNet for inference in a Java environment. We’ll demonstrate this with the help of an image classification example using a VGG16 model. Complete code for this example can be found in this repo.

There are many deep learning frameworks and tools, each offering its own set of features and capabilities. While this gives developers ample options to experiment with, building a solution that works across frameworks requires significant development effort. Open Neural Network Exchange (ONNX) was developed to let developers port models across frameworks easily, so they can take advantage of different frameworks with minimal engineering effort.

Chainer and Apache MXNet are two such deep learning frameworks, each with a wide range of capabilities. Chainer is known as the first framework to introduce the define-by-run approach to network training, while Apache MXNet is a lean, highly scalable framework that offers a flexible programming model and supports multiple programming languages (Python, C++, Java, Scala, etc.).

To combine these advantages, in this blog post we use a pre-trained VGG16 model (a deep convolutional network for large-scale image recognition) from ChainerCV, a computer vision library built on top of Chainer, and perform inference with it using MXNet’s Java API. We chose Java for inference to show how AI capabilities can be embedded into Java-based production services and systems.

To achieve this, the pre-trained Chainer model will be exported to ONNX format, which will in turn be imported into MXNet’s model format (symbol and params) for inference in a Java environment.

Environment Setup

Before we start exploring the code, we will go over the prerequisites for executing the code in the repo. The following packages are required for this example to work. The steps to install these are mentioned in the README.

Python 3.5+
MXNet 1.4.0
Chainer 5.3.0
ChainerCV 0.12.0
ONNX 1.3.0
onnx_chainer 1.3.3
Java 8 JDK

onnx-chainer works only with Python 3.5 and above. 
Chainer is not guaranteed to work on macOS and Windows.

From Chainer to MXNet

The code to convert a Chainer model to MXNet model format through ONNX is written in Python. After importing the necessary libraries, we export the VGG16 model pre-trained on the ImageNet dataset to ONNX format using the export API in onnx_chainer. Because export traces the model by running a forward pass, it needs an example input: since the model accepts input of shape (1, 3, 224, 224), a dummy array of zeros with this shape is passed to export.

import numpy as np
import chainer
import chainercv.links as L
import onnx_chainer

def convert_model_to_onnx(input_shape, onnx_file_path):
    # Export Chainer model to ONNX
    model = L.VGG16(pretrained_model='imagenet')

    # Pseudo input
    x = np.zeros(input_shape, dtype=np.float32)

    # Don't forget to set train flag off!
    chainer.config.train = False
    onnx_chainer.export(model, x, filename=onnx_file_path)

convert_model_to_onnx((1, 3, 224, 224), 'model/chainer_vgg16.onnx')

The exported ONNX model can be visualized using Netron.

chainer_vgg16.onnx in Netron

MXNet’s import_model() takes the path of the ONNX model to import and returns a symbol, which represents the graph of the network, together with two parameter dictionaries: arg_params (the weights) and aux_params (auxiliary states, such as batch normalization moving statistics). The symbol and parameters will be used to perform inference.

from mxnet.contrib import onnx as onnx_mxnet

# Import ONNX model to MXNet
sym, arg_params, aux_params = onnx_mxnet.import_model('model/chainer_vgg16.onnx')

Now load this imported model, bind it to allocate memory given the input shape, assign parameters, and export the symbol and parameters to JSON and params files respectively. These files are consumed by the Java API for inference.

import mxnet as mx
input_name, input_shape = 'Input_0', (1, 3, 224, 224)
# Use a GPU if one is available, otherwise fall back to the CPU
ctx = mx.gpu() if mx.context.num_gpus() > 0 else mx.cpu()

mod = mx.mod.Module(symbol=sym, data_names=[input_name], label_names=None, context=ctx)
mod.bind(for_training=False, data_shapes=[(input_name, input_shape)])
mod.set_params(arg_params=arg_params, aux_params=aux_params, allow_missing=True, allow_extra=True)

# Export model to JSON and params files (prefix 'vgg16', epoch 0)
sym.save('vgg16-symbol.json')
mod.save_params('vgg16-0000.params')

Inference in Java

Follow the step-by-step instructions to configure your development environment to use the Java APIs. This is also mentioned in the installation instructions of the README.

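pom.xml contains details about the project and the configuration required to build the Maven project. Adding the following repository to pom.xml lets Maven fetch the latest MXNet artifacts from the Apache snapshot repository:

<repositories>
  <repository>
    <id>Apache Snapshot</id>
    <url>https://repository.apache.org/content/groups/snapshots</url>
  </repository>
</repositories>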

The first step in inference is to load the pre-trained MXNet model. For this, apart from the path to the model, we need to provide a description of the model input, specified using an object of type DataDesc. The model that we are using here has input name ‘Input_0’ and input shape (1, 3, 224, 224).

Note: The input name varies between models. It could be ‘data’, ‘data_0’, ‘Input_0’, etc. Use the input name of your model in the DataDesc.

Tip: A quick way to figure this out would be to visualize the model in Netron and get the input ID of the first node in the graph.

import org.apache.mxnet.javaapi.*;
import java.util.*;

Shape inputShape = new Shape(new int[] {1, 3, 224, 224});

List<DataDesc> inputDescriptors = new ArrayList<>();
inputDescriptors.add(new DataDesc("Input_0", inputShape, DType.Float32(), "NCHW"));
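The layout string "NCHW" in the DataDesc says that the flattened input buffer is ordered by batch index, then channel, then row, then column. As a rough illustration, the plain-Java sketch below computes the flat offset of element (n, c, h, w) in a tensor of shape (N, C, H, W). The NchwLayout class is hypothetical and not part of the MXNet API; it only spells out the indexing that the layout string implies.

```java
// Hypothetical helper illustrating the "NCHW" layout; not an MXNet API.
final class NchwLayout {

    private NchwLayout() {}

    /**
     * Row-major flat offset of element (n, c, h, w) in a tensor of shape
     * {N, C, H, W}. Only the trailing dimension sizes are needed.
     */
    public static int offset(int[] shape, int n, int c, int h, int w) {
        int channels = shape[1];
        int height = shape[2];
        int width = shape[3];
        return ((n * channels + c) * height + h) * width + w;
    }

    /** Total number of elements in a tensor with the given shape. */
    public static int size(int[] shape) {
        int total = 1;
        for (int dim : shape) {
            total *= dim;
        }
        return total;
    }

    public static void main(String[] args) {
        int[] shape = {1, 3, 224, 224};
        System.out.println(size(shape));               // 150528
        System.out.println(offset(shape, 0, 1, 0, 0)); // 50176: first value of channel 1
    }
}
```

Each channel occupies one contiguous 224 × 224 plane, which is why channel 1 starts exactly 224 × 224 = 50176 elements into the buffer.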

Next, we specify the context for inference using the Context class. For CPU:

List<Context> inferenceContext = Arrays.asList(Context.cpu());

For GPU:

List<Context> inferenceContext = Arrays.asList(Context.gpu());

Specify the path to the model (model_path/prefix). This folder contains the symbol file, the params file, and a synset file containing the list of possible classes. In this example, the user can either provide this path on the command line or use the default VGG16 model that we saved earlier. We then create an object of the Predictor class, MXNet’s generic inference API, providing the model path, the input descriptors, the context, and the epoch number of the params file (0 here).

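import org.apache.mxnet.infer.javaapi.Predictor;
// Epoch 0 selects the vgg16-0000.params file next to vgg16-symbol.json
String modelPathPrefix = System.getProperty("user.dir") + "/model/vgg16";
Predictor predictor = new Predictor(modelPathPrefix, inputDescriptors, inferenceContext, 0);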
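If loading fails, a missing or misnamed file is a common cause. The prefix and epoch follow MXNet's checkpoint naming convention, <prefix>-symbol.json and <prefix>-<epoch as four digits>.params. The sketch below lists the files a prefix is expected to resolve to and reports which ones are missing. The ModelFiles class is hypothetical, and the synset file name synset.txt is an assumption about how the model folder is laid out.

```java
import java.io.File;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;

// Hypothetical helper (not an MXNet API) that lists the files a model path
// prefix is expected to resolve to, following MXNet's checkpoint naming.
final class ModelFiles {

    private ModelFiles() {}

    /** <prefix>-symbol.json, <prefix>-<epoch as 4 digits>.params, and the synset file. */
    public static List<String> expectedFiles(String modelPathPrefix, int epoch) {
        String symbolFile = modelPathPrefix + "-symbol.json";
        String paramsFile = String.format("%s-%04d.params", modelPathPrefix, epoch);
        // Assumption: the class list is a file named "synset.txt" in the same folder.
        String parent = new File(modelPathPrefix).getParent();
        String synsetFile = new File(parent == null ? "." : parent, "synset.txt").getPath();
        return Arrays.asList(symbolFile, paramsFile, synsetFile);
    }

    /** Returns the expected files that do not exist on disk (empty if all are present). */
    public static List<String> missingFiles(String modelPathPrefix, int epoch) {
        List<String> missing = new ArrayList<>();
        for (String path : expectedFiles(modelPathPrefix, epoch)) {
            if (!new File(path).isFile()) {
                missing.add(path);
            }
        }
        return missing;
    }

    public static void main(String[] args) {
        String prefix = System.getProperty("user.dir") + "/model/vgg16";
        List<String> missing = missingFiles(prefix, 0);
        System.out.println(missing.isEmpty() ? "All model files found" : "Missing: " + missing);
    }
}
```

Running such a check before constructing the Predictor turns a loading failure into a clear list of missing paths.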

Now, we load the image for inference. This example uses a default image but lets you enter the path to an image on the command line. The Image API lets you read and process images.

String inputImagePath = System.getProperty("user.dir") + "/data/Penguin.jpg";
// flag 1 loads a 3-channel color image; true converts it from BGR to RGB
NDArray img = Image.imRead(inputImagePath, 1, true);
Sample image for inference (Source)

Since the model expects a 224 × 224 image in “NCHW” layout, whereas imRead returns an image in height × width × channel (HWC) layout, we resize the image, move the channel axis to the front, and add a batch dimension.

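// Resize to the 224 x 224 spatial size expected by the model
img = Image.imResize(img, 224, 224);
// HWC -> CHW: move the channel axis to the front
NDArray nd = NDArray.transpose(img, new Shape(new int[]{2, 0, 1}), null)[0];
// CHW -> NCHW: add a batch dimension of size 1
nd = NDArray.expand_dims(nd, 0, null)[0];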
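To make the transpose and expand_dims calls concrete, here is a plain-Java sketch of what they do to the underlying data: an interleaved HWC buffer (RGBRGB...) becomes a planar CHW buffer (all R values, then all G, then all B), and the extra batch axis of size one changes the shape but adds no elements. The HwcToNchw class is hypothetical and only illustrates the data movement; in the example itself MXNet performs it on the NDArray.

```java
import java.util.Arrays;

// Hypothetical plain-Java equivalent of transpose(axes = {2, 0, 1}) followed by
// expand_dims(axis = 0); not an MXNet API.
final class HwcToNchw {

    private HwcToNchw() {}

    /** Converts an interleaved HWC buffer into a planar NCHW buffer with batch size 1. */
    public static float[] toNchw(float[] hwc, int height, int width, int channels) {
        if (hwc.length != height * width * channels) {
            throw new IllegalArgumentException("buffer size does not match H x W x C");
        }
        float[] nchw = new float[hwc.length]; // a batch axis of size 1 adds no elements
        for (int h = 0; h < height; h++) {
            for (int w = 0; w < width; w++) {
                for (int c = 0; c < channels; c++) {
                    // source (h, w, c) in HWC -> destination (0, c, h, w) in NCHW
                    nchw[(c * height + h) * width + w] = hwc[(h * width + w) * channels + c];
                }
            }
        }
        return nchw;
    }

    public static void main(String[] args) {
        // A 1 x 2 image with 3 channels: pixel 0 = (1, 2, 3), pixel 1 = (4, 5, 6)
        float[] hwc = {1, 2, 3, 4, 5, 6};
        System.out.println(Arrays.toString(toNchw(hwc, 1, 2, 3))); // [1.0, 4.0, 2.0, 5.0, 3.0, 6.0]
    }
}
```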

We use the predictWithNDArray method of the Predictor class to perform inference. It takes a list of NDArrays, one per model input, and returns a list of NDArrays, one per model output.

List<NDArray> ndList = Collections.singletonList(nd);
List<NDArray> ndResult = predictor.predictWithNDArray(ndList);

We then pick the result with the highest probability and find the corresponding class label from the synset file.
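The post-processing is short. Below is a minimal sketch, assuming the first output NDArray has already been copied into a float[] (for example with the NDArray's toArray() method) and that the synset file has been read into a list with one class per line in class-index order (for example with Files.readAllLines). The TopPrediction class is hypothetical, and the three-entry synset in main is a toy stand-in for the 1000 ImageNet classes.

```java
import java.util.Arrays;
import java.util.List;

// Hypothetical post-processing helper; not part of the MXNet API.
final class TopPrediction {

    private TopPrediction() {}

    /** Index of the largest value, i.e. the most probable class. */
    public static int argMax(float[] probabilities) {
        if (probabilities.length == 0) {
            throw new IllegalArgumentException("empty model output");
        }
        int best = 0;
        for (int i = 1; i < probabilities.length; i++) {
            if (probabilities[i] > probabilities[best]) {
                best = i;
            }
        }
        return best;
    }

    /** Formats the winning class; synset line i is assumed to describe class index i. */
    public static String describe(float[] probabilities, List<String> synset) {
        int best = argMax(probabilities);
        return "Probability : " + probabilities[best] + " Class : " + synset.get(best);
    }

    public static void main(String[] args) {
        float[] probs = {0.1f, 0.7f, 0.2f}; // toy 3-class output
        List<String> synset = Arrays.asList(
                "n01440764 tench, Tinca tinca",
                "n02056570 king penguin, Aptenodytes patagonica",
                "n01443537 goldfish, Carassius auratus");
        // prints: Probability : 0.7 Class : n02056570 king penguin, Aptenodytes patagonica
        System.out.println(describe(probs, synset));
    }
}
```

A linear scan is enough for a single top-1 answer. For top-k results, sorting the class indices by probability would be the natural extension.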

Prediction for /Users/vandanavk/ONNXinferenceJava/data/Penguin.jpg
Probability : 0.9999974 Class : n02056570 king penguin, Aptenodytes patagonica

In this blog post, we saw how a model trained in one framework can be transferred to another framework for inference through ONNX. We also saw how to perform inference on the imported MXNet model in Java.

While we exported a Chainer model to ONNX and then imported it into MXNet, you can also import ready-made ONNX models from the ONNX Model Zoo into MXNet and perform inference with them using the Java API.

Learn more

If you would like to get started with Apache MXNet, you could check out the installation page to build MXNet and try out some of the tutorials in a language of your choice.

If you would like to try out inference with other ONNX models in Java, the ONNX Model Zoo has a collection of pre-trained, state-of-the-art models in ONNX format that can be imported into MXNet. To learn more about MXNet’s support for the ONNX format, visit the ONNX API docs on the MXNet website.

The Java examples folder contains image classification and object detection examples that use MXNet’s Java API.