Deploying Keras Deep Learning Models with Java


The Keras library provides an approachable interface to deep learning, making neural networks accessible to a broad audience. However, one of the challenges I’ve faced is transitioning from exploring models in Keras to productizing models. Keras is written in Python, and until recently had limited support outside of Python. While tools such as Flask, PySpark, and Cloud ML make it possible to productize these models directly in Python, I usually prefer Java for deploying models.

Projects such as ONNX are moving towards standardizing deep learning model formats, but the runtimes that support these formats are still limited. One approach that’s often used is converting Keras models to TensorFlow graphs, and then using these graphs in other runtimes that support TensorFlow. I recently discovered the Deeplearning4J (DL4J) project, which natively supports Keras models, making it easy to get up and running with deep learning in Java.

One of the use cases that I’ve been exploring for deep learning is training models in Python using Keras, and then productizing models using Java. This is useful for situations where you need deep learning directly on the client, such as Android devices applying models, as well as situations where you want to leverage existing production systems written in Java. The DL4J documentation includes an introduction to importing Keras models.

This post provides an overview of training a Keras model in Python and deploying it with Java. I use Jetty to provide real-time predictions and Google’s DataFlow to build a batch prediction system. The full code and data needed to run these examples are available on GitHub.

Model Training

The first step is to train a model using the Keras library in Python. Once you have a model that is ready to deploy, you can save it in the h5 format and use it in Python and Java applications. For this tutorial, we’ll use the same model that I trained in my blog post on Flask, which predicts which players are likely to purchase a new game.

The input to the model is ten binary features (G1, G2, …, G10) that describe the games that a player has already purchased, and the label is a single variable that describes whether the player purchased a game not included in the inputs. The main steps involved in the training process are shown below:

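import keras
from keras import models, layers

# x: the ten binary features (G1-G10), y: the purchase label;
# loading the training data is not shown here

# Define the model structure
model = models.Sequential()
model.add(layers.Dense(64, activation='relu', input_shape=(10,)))
model.add(layers.Dense(1, activation='sigmoid'))

# Compile and fit the model
# (optimizer and loss are assumptions; binary cross-entropy fits a sigmoid output)
model.compile(optimizer='rmsprop', loss='binary_crossentropy')
history = model.fit(x, y, epochs=100, batch_size=100,
                    validation_split=.2, verbose=0)

# Save the model in h5 format
model.save("games.h5")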

The output of this process is an h5 file that represents the trained model that we can deploy in Python and Java applications. In my past post, I showed how to use Flask to serve real-time model predictions in Python. In this post, I’ll show how to build batch and real-time predictions in Java.
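The following sketch shows what the saved model computes. It is plain Java and implements the forward pass for the architecture above: a 64-unit ReLU layer followed by a single sigmoid unit. The class name `DenseForwardPass` and the weight arrays are hypothetical. In practice, the weights come from the trained h5 file and DL4J performs this computation for you. Weights are indexed `[input][unit]`, which matches the layout Keras uses for a Dense kernel.

```java
/**
 * Hypothetical sketch of a two-layer forward pass:
 * hidden = relu(x . w1 + b1), output = sigmoid(hidden . w2 + b2).
 */
final class DenseForwardPass {
    private final double[][] w1; // [inputs][hiddenUnits], Keras kernel layout
    private final double[] b1;   // [hiddenUnits]
    private final double[] w2;   // [hiddenUnits], weights of the single output unit
    private final double b2;     // bias of the single output unit

    DenseForwardPass(double[][] w1, double[] b1, double[] w2, double b2) {
        if (w1.length == 0 || w1[0].length != b1.length || b1.length != w2.length) {
            throw new IllegalArgumentException("inconsistent layer shapes");
        }
        this.w1 = w1;
        this.b1 = b1;
        this.w2 = w2;
        this.b2 = b2;
    }

    /** Returns the sigmoid output, a value in (0, 1). */
    double output(double[] x) {
        if (x.length != w1.length) {
            throw new IllegalArgumentException("expected " + w1.length + " features");
        }
        double logit = b2;
        for (int j = 0; j < b1.length; j++) {
            // Pre-activation of hidden unit j.
            double z = b1[j];
            for (int i = 0; i < x.length; i++) {
                z += x[i] * w1[i][j];
            }
            // Apply ReLU, then accumulate into the single output unit.
            logit += Math.max(0.0, z) * w2[j];
        }
        return 1.0 / (1.0 + Math.exp(-logit));
    }
}
```

For the games model, `w1` would be 10×64, and `b1` and `w2` would each have length 64.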

Java Setup

To deploy Keras models with Java, we’ll use the Deeplearning4j library. It provides functionality for deep learning in Java and can load and use models trained with Keras. We’ll also use DataFlow for batch predictions and Jetty for real-time predictions. Here are the libraries I used for this project:

  • Deeplearning4j: Provides deep neural network functionality for Java.
  • ND4J: Provides tensor operations for Java.
  • Jetty: Used for setting up a web endpoint.
  • Cloud DataFlow: Provides autoscaling for batch predictions on GCP.

I imported these into my project with Maven, declaring them as dependencies in the project’s pom.xml. For DL4J, both the core and model import libraries (deeplearning4j-core and deeplearning4j-modelimport) are needed when using Keras.

I set up my project in Eclipse, and once I got the pom file properly configured, no additional setup was needed to get started.

Keras Predictions with DL4J

Now that we have the libraries set up, we can start making predictions with the Keras model. I wrote the script below to test out loading a Keras model and making a prediction for a sample data set. The first step is to load the model from the h5 file. Next, I define a 1D tensor of length 10 and generate random binary values. The last step is to call the output method on the model to generate a prediction. Since my model has a single output node, I use getDouble(0) to return the output of the model.

// imports
import org.deeplearning4j.nn.modelimport.keras.KerasModelImport;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
// package varies by version (org.nd4j.common.io in newer DL4J releases)
import org.nd4j.linalg.io.ClassPathResource;

// load the model
String simpleMlp = new ClassPathResource("games.h5").getFile().getPath();
MultiLayerNetwork model = KerasModelImport.importKerasSequentialModelAndWeights(simpleMlp);

// make a random sample of binary features
int inputs = 10;
INDArray features = Nd4j.zeros(inputs);
for (int i = 0; i < inputs; i++)
    features.putScalar(new int[] {i}, Math.random() < 0.5 ? 0 : 1);

// get the prediction
double prediction = model.output(features).getDouble(0);

One of the key concepts to become familiar with when using DL4J is tensors. Java does not have a built-in library for efficient tensor operations, which is why ND4J is a prerequisite. It provides n-dimensional arrays for implementing deep learning backends in Java. To set a value in the tensor object, you pass an integer array that provides an n-dimensional index into the tensor, along with the value to set. Since I am using a 1D tensor, the array is of length one.
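The n-dimensional index passed to putScalar maps to a position in a flat buffer. To illustrate that idea, here is a hypothetical `TinyTensor` class that stores values in row-major order, which corresponds to ND4J's default 'c' ordering. It is a teaching sketch only, not ND4J's implementation.

```java
import java.util.Arrays;

/** Hypothetical row-major n-dimensional array backed by a flat double[]. */
final class TinyTensor {
    private final int[] shape;
    private final int[] strides;
    private final double[] data;

    TinyTensor(int... shape) {
        this.shape = shape.clone();
        this.strides = new int[shape.length];
        int size = 1;
        // Row-major: the last dimension varies fastest, so it gets stride 1.
        for (int d = shape.length - 1; d >= 0; d--) {
            strides[d] = size;
            size *= shape[d];
        }
        this.data = new double[size];
    }

    /** Converts an n-dimensional index into an offset in the flat buffer. */
    int offset(int[] index) {
        if (index.length != shape.length) {
            throw new IllegalArgumentException(
                    "expected " + shape.length + " indices, got " + Arrays.toString(index));
        }
        int off = 0;
        for (int d = 0; d < index.length; d++) {
            if (index[d] < 0 || index[d] >= shape[d]) {
                throw new IndexOutOfBoundsException("index " + Arrays.toString(index));
            }
            off += index[d] * strides[d];
        }
        return off;
    }

    /** Sets the value at the given n-dimensional index. */
    void putScalar(int[] index, double value) {
        data[offset(index)] = value;
    }

    /** Reads the value at the given n-dimensional index. */
    double getDouble(int... index) {
        return data[offset(index)];
    }
}
```

For a 1D tensor of length 10, `new int[] {3}` maps to offset 3. For a 2×3 tensor, `new int[] {1, 2}` maps to offset 1·3 + 2 = 5.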

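The model object provides predict and output methods. The output method returns the activations of the output layer. For this model, that is a continuous value between 0 and 1, similar to predict_proba in scikit-learn. The predict method returns the index of the highest-scoring output node, which is useful for classifiers with one output node per class. For a single sigmoid output like this one, you get a 0/1 class label by thresholding the output instead.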
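A simple way to turn the continuous output into a 0/1 label is to compare it against a threshold. The plain-Java sketch below uses hypothetical helper names, and the 0.5 cutoff is an assumption you may want to tune. It also includes an argmax helper. This shows why an index-of-the-largest-output prediction tells you nothing for a model with a single output node: the only possible index is 0.

```java
/** Hypothetical helpers for turning model outputs into class labels. */
final class ClassLabels {
    private ClassLabels() {}

    /** Binary label from a single sigmoid output: 1 if p >= threshold, else 0. */
    static int threshold(double probability, double threshold) {
        return probability >= threshold ? 1 : 0;
    }

    /** Index of the largest score, as used for one-output-node-per-class models. */
    static int argmax(double[] scores) {
        if (scores.length == 0) {
            throw new IllegalArgumentException("no scores");
        }
        int best = 0;
        for (int i = 1; i < scores.length; i++) {
            if (scores[i] > scores[best]) {
                best = i;
            }
        }
        return best;
    }
}
```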

Real-Time Predictions

Now that we have a Keras model up and running in Java, we can start serving model predictions. The first approach we’ll take is using Jetty to set up an endpoint on the web for providing model predictions. I previously covered setup for Jetty in my posts on tracking data and model production. The full code for the model endpoint is available on GitHub.

The model endpoint is implemented as a single class that loads the Keras model and provides predictions. It extends Jetty’s AbstractHandler class to serve model results. The code below shows how to set up the Jetty service to run on port 8080 and instantiate the JettyDL4J class, which loads the Keras model in its constructor.

// Setting up the web endpoint
Server server = new Server(8080);
server.setHandler(new JettyDL4J());
server.start();
server.join();

// Load the Keras model
public JettyDL4J() throws Exception {
  String p = new ClassPathResource("games.h5").getFile().getPath();
  model = KerasModelImport.importKerasSequentialModelAndWeights(p);
}

The handler for managing web requests is shown in the code snippet below. The passed-in parameters (G1, G2, …, G10) are converted into a 1D tensor object and passed to the output method of the Keras model. The request is then marked as handled and the prediction is returned as a string.

// Entry point for the model prediction request
@Override
public void handle(String target, Request baseRequest,
    HttpServletRequest request, HttpServletResponse response)
    throws IOException, ServletException {

  // create a dataset from the input parameters
  INDArray features = Nd4j.zeros(inputs);
  for (int i = 0; i < inputs; i++)
    features.putScalar(new int[] {i}, Double.parseDouble(
        baseRequest.getParameter("G" + (i + 1))));

  // output the estimate and mark the request as handled
  double prediction = model.output(features).getDouble(0);
  response.getWriter().println("Prediction: " + prediction);
  baseRequest.setHandled(true);
}
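The handler above throws if a parameter is missing or malformed. For example, Double.parseDouble(null) raises a NullPointerException. Below is a sketch of more defensive parsing. `FeatureParser` is a hypothetical helper that works on a plain Map of request parameters, so you can unit-test it without Jetty. Restricting values to 0 or 1 is an assumption based on the binary features of this model.

```java
import java.util.Map;

/** Hypothetical helper that turns G1..Gn request parameters into a feature array. */
final class FeatureParser {
    private FeatureParser() {}

    static double[] parse(Map<String, String> params, int inputs) {
        double[] features = new double[inputs];
        for (int i = 0; i < inputs; i++) {
            String name = "G" + (i + 1);
            String raw = params.get(name);
            if (raw == null) {
                throw new IllegalArgumentException("missing parameter " + name);
            }
            double value;
            try {
                value = Double.parseDouble(raw.trim());
            } catch (NumberFormatException e) {
                throw new IllegalArgumentException("non-numeric " + name + ": " + raw, e);
            }
            // Assumption: every feature is a binary purchase indicator.
            if (value != 0.0 && value != 1.0) {
                throw new IllegalArgumentException(name + " must be 0 or 1, got " + raw);
            }
            features[i] = value;
        }
        return features;
    }
}
```

In the handler, you could map an IllegalArgumentException to an HTTP 400 response, for example with response.sendError(HttpServletResponse.SC_BAD_REQUEST, e.getMessage()).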

When you run the class, it sets up an endpoint on port 8080. You can call the model service by pointing your browser to the following URL:

// Request
// Result
Prediction: 0.735433042049408
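You can try the request/response flow without setting up Jetty or DL4J. The sketch below swaps in the JDK's built-in com.sun.net.httpserver.HttpServer and uses a hypothetical scoring function in place of the Keras model. The class name, the `/` context path, and the scoring function are assumptions for illustration. The structure mirrors the Jetty version: the model is supplied once when the server starts, and each request parses G1 to G10 from the query string and writes back a prediction.

```java
import com.sun.net.httpserver.HttpServer;
import java.io.IOException;
import java.io.OutputStream;
import java.net.InetSocketAddress;
import java.net.URLDecoder;
import java.nio.charset.StandardCharsets;
import java.util.HashMap;
import java.util.Map;
import java.util.function.ToDoubleFunction;

/** Hypothetical stand-in for the Jetty endpoint, built on the JDK's HttpServer. */
final class PredictionServer {
    private static final int INPUTS = 10;

    private PredictionServer() {}

    /** Starts a server that scores G1..G10 query parameters with the given model. */
    static HttpServer start(int port, ToDoubleFunction<double[]> model) throws IOException {
        HttpServer server = HttpServer.create(new InetSocketAddress(port), 0);
        server.createContext("/", exchange -> {
            Map<String, String> params = parseQuery(exchange.getRequestURI().getRawQuery());
            int status;
            String body;
            try {
                double[] features = new double[INPUTS];
                for (int i = 0; i < INPUTS; i++) {
                    String raw = params.get("G" + (i + 1));
                    if (raw == null) {
                        throw new NumberFormatException("missing G" + (i + 1));
                    }
                    features[i] = Double.parseDouble(raw);
                }
                status = 200;
                body = "Prediction: " + model.applyAsDouble(features);
            } catch (NumberFormatException e) {
                status = 400;
                body = "Expected numeric parameters G1 to G" + INPUTS;
            }
            byte[] bytes = body.getBytes(StandardCharsets.UTF_8);
            exchange.sendResponseHeaders(status, bytes.length);
            try (OutputStream out = exchange.getResponseBody()) {
                out.write(bytes);
            }
        });
        server.start();
        return server;
    }

    /** Splits a raw query string such as "G1=1&G2=0" into decoded key/value pairs. */
    static Map<String, String> parseQuery(String rawQuery) {
        Map<String, String> params = new HashMap<>();
        if (rawQuery == null || rawQuery.isEmpty()) {
            return params;
        }
        for (String pair : rawQuery.split("&")) {
            int eq = pair.indexOf('=');
            if (eq > 0) {
                params.put(URLDecoder.decode(pair.substring(0, eq), StandardCharsets.UTF_8),
                        URLDecoder.decode(pair.substring(eq + 1), StandardCharsets.UTF_8));
            }
        }
        return params;
    }
}
```

To try it by hand, you could call start(8080, ...) and open a URL such as http://localhost:8080/?G1=1&G2=0&G3=0&G4=0&G5=0&G6=0&G7=0&G8=0&G9=0&G10=1. The values in this URL are hypothetical.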

The result is a Keras model that you can now invoke in real-time to get predictions from your deep learning model. For a production system, you’d want to set up a service in front of the Jetty endpoint, rather than exposing the endpoint directly on the web.

Batch Predictions

Another use case for Keras models is batch predictions, where you may need to apply the estimator to millions of records. You can do this directly in Python with your Keras model, but scalability is limited with this approach. I’ll show how to use Google’s DataFlow to apply predictions to massive data sets using a fully managed pipeline. I previously covered setting up DataFlow in my past posts on model production and game simulations.

With DataFlow, you specify a graph of operations to perform on a data set, where the source and destination data sets can be relational databases, messaging services, application databases, and other services. The graphs can be executed as a batch operation, where infrastructure is spun up to handle a large data set and then spun down, or in streaming mode, where infrastructure is maintained and requests are processed as they arrive. In both scenarios, the service will autoscale to meet demand. It’s fully managed and great for large calculations that can be performed independently.

DataFlow DAG for Batch Deep Learning

The DAG of operations in my DataFlow process is shown above. The first step is to create a dataset for the model to score. In this example, I’m loading values from my sample CSV, while in practice I’d usually use BigQuery as both the source and sink for model predictions. The next step is a transformation that takes TableRow objects as the input, transforms the rows to 1D tensors, applies the model to each tensor, and creates a new output TableRow with the predicted value. The complete code for the DAG is available on GitHub.

The key step in this pipeline is the Keras Predict transformation, which is shown in the code snippet below. A transformation operates on a collection of objects and returns a collection of objects. Within a transformer, you can define objects such as the Keras model, which are shared across each of the process element steps defined in the transformer. The result is that the model is loaded once for each transformer, rather than loaded for each record that needs a prediction.

// Apply the transform to the pipeline
.apply("Keras Predict", new PTransform<PCollection<TableRow>,
    PCollection<TableRow>>() {

  // Load the model in the transformer
  @Override
  public PCollection<TableRow> expand(PCollection<TableRow> input) {
    final int inputs = 10;
    final MultiLayerNetwork model;
    try {
      String p = new ClassPathResource("games.h5").getFile().getPath();
      model = KerasModelImport.importKerasSequentialModelAndWeights(p);
    } catch (Exception e) {
      throw new RuntimeException(e);
    }

    // create a DoFn for applying the Keras model to instances
    return input.apply("Pred", ParDo.of(new DoFn<TableRow, TableRow>() {
      @ProcessElement
      public void processElement(ProcessContext c) throws Exception {
        ... // Apply the Keras model
      }
    }));
  }
})
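The plain-Java sketch below shows why you load the model once rather than per record. In this hypothetical example, `loadModel` stands in for `KerasModelImport` and counts how often it is called. Scoring many records with a single shared instance calls it once, while loading inside the per-record function calls it once per record. A parallel stream serves as a local stand-in for workers processing records independently. It is an analogy only, not how DataFlow executes a pipeline.

```java
import java.util.List;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.function.ToDoubleFunction;
import java.util.stream.Collectors;

/** Hypothetical illustration of sharing one loaded model across many records. */
final class SharedModelDemo {
    /** Counts calls to loadModel so the two strategies can be compared. */
    static final AtomicInteger LOADS = new AtomicInteger();

    private SharedModelDemo() {}

    /** Stand-in for an expensive model load, such as reading an h5 file. */
    static ToDoubleFunction<double[]> loadModel() {
        LOADS.incrementAndGet();
        // Toy scoring rule: fraction of features that are set.
        return features -> {
            double sum = 0;
            for (double f : features) {
                sum += f;
            }
            return sum / features.length;
        };
    }

    /** Loads the model once, then scores every record independently. */
    static List<Double> scoreShared(List<double[]> records) {
        ToDoubleFunction<double[]> model = loadModel();
        return records.parallelStream()
                .map(model::applyAsDouble)
                .collect(Collectors.toList());
    }

    /** Anti-pattern: loads the model again for every record. */
    static List<Double> scorePerRecord(List<double[]> records) {
        return records.parallelStream()
                .map(r -> loadModel().applyAsDouble(r))
                .collect(Collectors.toList());
    }
}
```

Both strategies produce the same scores. Only the number of loads differs, and that difference grows with the size of the data set.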

The code for the process element method is shown below. It reads the input record, creates a tensor from the table row, applies the model, and then saves the record. The output row contains the predicted and actual values.

// get the record to score
TableRow row = c.element();

// create the feature vector
INDArray features = Nd4j.zeros(inputs);
for (int i = 0; i < inputs; i++)
  features.putScalar(new int[] {i},
      Double.parseDouble(row.get("G" + (i + 1)).toString()));

// get the prediction
double estimate = model.output(features).getDouble(0);

// save the result by emitting a new row
TableRow prediction = new TableRow();
prediction.set("actual", row.get("actual"));
prediction.set("predicted", estimate);
c.output(prediction);
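You can test the per-record logic outside a pipeline by factoring it into a plain function. The sketch below uses a `Map<String, Object>` as a stand-in for BigQuery's TableRow, which is itself a map-like type, and a hypothetical scoring function in place of the Keras model. It mirrors processElement above: it builds the feature vector from the G1 to G10 fields and returns a new row with the actual and predicted values. The class name `RowScorer` is an assumption for illustration.

```java
import java.util.HashMap;
import java.util.Map;
import java.util.function.ToDoubleFunction;

/** Hypothetical per-record scoring step, mirroring processElement above. */
final class RowScorer {
    private final ToDoubleFunction<double[]> model;
    private final int inputs;

    RowScorer(ToDoubleFunction<double[]> model, int inputs) {
        this.model = model;
        this.inputs = inputs;
    }

    Map<String, Object> score(Map<String, Object> row) {
        // Build the feature vector from fields G1..Gn.
        double[] features = new double[inputs];
        for (int i = 0; i < inputs; i++) {
            Object value = row.get("G" + (i + 1));
            if (value == null) {
                throw new IllegalArgumentException("row is missing G" + (i + 1));
            }
            features[i] = Double.parseDouble(value.toString());
        }
        // Copy the label through and attach the prediction.
        Map<String, Object> out = new HashMap<>();
        out.put("actual", row.get("actual"));
        out.put("predicted", model.applyAsDouble(features));
        return out;
    }
}
```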

I’ve excluded the CSV loading and BigQuery writing code blocks in this post, since you may be working with different endpoints. The code and CSV are available on GitHub if you want to try running the DAG. To save the results to BigQuery, you’ll need to set the tempLocation program argument as follows:


After running the DAG, a new table will be created in BigQuery with the actual and predicted values for the dataset. The image below shows sample data points from my application of the Keras model.

Prediction results in BigQuery

The result of using DataFlow with DL4J is that you can score millions of records using autoscaling infrastructure for batch predictions.


As deep learning becomes increasingly popular, more languages and environments are supporting these models. As libraries start to standardize on model formats, it’s becoming possible to use separate languages for model training and model deployment. This post showed how neural networks trained using the Keras library in Python can be used for batch and real-time predictions using the DL4J library in Java. For the first time, I’ve been able to set up a batch process for applying deep learning to millions of data points.

Ben Weber is a principal data scientist at Zynga. We are hiring!

Source: Deep Learning on Medium