Deep Java Library (DJL): a Deep Learning Toolkit for Java Developers

Source: Deep Learning on Medium

Deep Java Library (DJL) is an open-source library created by Amazon to develop machine learning (ML) and deep learning (DL) models natively in Java while simplifying the use of deep learning frameworks.

I recently used DJL to develop a footwear classification model and found the toolkit super intuitive and easy to use; it’s obvious a lot of thought went into the design and how Java developers would use it. DJL APIs abstract commonly used functions to develop models and orchestrate infrastructure management. I found the high-level APIs used to train, test and run inference allowed me to use my knowledge of Java and the ML lifecycle to develop a model in less than an hour with minimal code.

Footwear classification model

The footwear classification model is a multiclass classification computer vision (CV) model, trained using supervised learning, that classifies footwear into one of four class labels: boots, sandals, shoes, or slippers.

Image 1: footwear data (source UT Zappos50K)

About the data

The most important part of developing an accurate ML model is to use data from a reputable source. The data source for the footwear classification model is the UT Zappos50K dataset, which is provided by The University of Texas at Austin and is freely available for academic, non-commercial use. The dataset consists of 50,025 labeled catalog images collected from Zappos.com.

Train the footwear classification model

Training is the process of producing an ML model by giving a learning algorithm training data to study. The term model refers to the artifact produced during the training process; the model contains patterns found in the training data and can be used to make a prediction (or inference). Before I started the training process, I set up my local environment for development. You will need JDK 8 (or later), IntelliJ, an ML engine for training (like Apache MXNet), an environment variable pointing to your engine’s path, and the build dependencies for DJL.

dependencies {
    compile "org.apache.logging.log4j:log4j-slf4j-impl:2.12.1"
    compile "ai.djl:api:0.2.0"
    compile "ai.djl:basicdataset:0.2.0"
    compile "ai.djl:examples:0.2.0"
    compile "ai.djl:model-zoo:0.2.0"
    compile "ai.djl.mxnet:mxnet-model-zoo:0.2.0"
    runtimeOnly "ai.djl.mxnet:mxnet-native-mkl:1.6.0-a:osx-x86_64"
}

DJL stays true to Java’s motto, “write once, run anywhere (WORA)”, by being engine and deep learning framework-agnostic. Developers can write code once that runs on any supported engine. DJL currently provides an implementation for Apache MXNet, an ML engine that eases the development of deep neural networks. DJL APIs use Java Native Access (JNA) to call the corresponding Apache MXNet operations. From a hardware perspective, training occurred locally on my laptop using a CPU. However, for the best performance, the DJL team recommends using a machine with at least one GPU. If you don’t have a GPU available to you, there is always the option to use Apache MXNet on Amazon EC2. A nice feature of DJL is that it automatically detects from the hardware configuration whether a GPU is available, so the same code can use a GPU when one is present.

Load dataset from the source

The footwear data was saved locally and loaded using DJL’s ImageFolder dataset, which is a dataset that can retrieve images from a local folder. In DJL terms, a Dataset simply holds the training data. There are dataset implementations that can be used to download data (based on the URL you provide), extract data, and automatically separate data into training and validation sets. The automatic separation is a useful feature, as it is important never to validate the model’s performance with the same data the model was trained on. The training dataset is used to find patterns in the data; the validation dataset is used to estimate the footwear model’s accuracy during the training process.

//identify the location of the training data
String trainingDatasetRoot = "src/test/resources/imagefolder/train";
//identify the location of the validation data
String validateDatasetRoot = "src/test/resources/imagefolder/validate";
//create training ImageFolder dataset
ImageFolder trainingDataset = initDataset(trainingDatasetRoot);
//create validation ImageFolder dataset
ImageFolder validateDataset = initDataset(validateDatasetRoot);

private ImageFolder initDataset(String datasetRoot) throws IOException {
    ImageFolder dataset = new ImageFolder.Builder()
        .setRepository(new SimpleRepository(Paths.get(datasetRoot)))
        .optPipeline(
            // create preprocess pipeline
            new Pipeline()
                .add(new Resize(NEW_WIDTH, NEW_HEIGHT))
                .add(new ToTensor()))
        .setSampling(BATCH_SIZE, true)
        .build();
    dataset.prepare();
    return dataset;
}

Image 2: structure of local training data
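
In my case the images were already sorted into train and validate folders. If you start from a single folder instead, the separation could be done by hand along the following lines. This is only a minimal sketch in plain Java, not a DJL API: the class DatasetSplitSketch, the split method, and the 80/20 ratio in the usage note are assumptions made for illustration.

```java
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Random;

// Hypothetical helper (not part of DJL): splits file names into two disjoint lists.
final class DatasetSplitSketch {

    // Returns a list holding two lists: index 0 is the training set, index 1 the validation set.
    static List<List<String>> split(List<String> files, double trainFraction, long seed) {
        if (trainFraction < 0.0 || trainFraction > 1.0) {
            throw new IllegalArgumentException("trainFraction must be between 0 and 1");
        }
        // Shuffle a copy so the caller's list is untouched; a fixed seed keeps the split repeatable.
        List<String> shuffled = new ArrayList<>(files);
        Collections.shuffle(shuffled, new Random(seed));

        int cut = (int) Math.round(shuffled.size() * trainFraction);
        List<List<String>> parts = new ArrayList<>();
        parts.add(new ArrayList<>(shuffled.subList(0, cut)));
        parts.add(new ArrayList<>(shuffled.subList(cut, shuffled.size())));
        return parts;
    }
}
```

Calling DatasetSplitSketch.split(files, 0.8, 42L) on ten file names would put eight in the training list and two in the validation list, with no file appearing in both.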

Train the model

Now that I have the footwear data separated into training and validation sets, I will use a neural network to train the model.

public final class Training extends AbstractTraining {
    . . .
    @Override
    protected void train(Arguments arguments) throws IOException {
        . . .
        try (Model model = Models.getModel(NUM_OF_OUTPUT, NEW_HEIGHT, NEW_WIDTH)) {
            TrainingConfig config = setupTrainingConfig(loss);
            try (Trainer trainer = model.newTrainer(config)) {
                trainer.setMetrics(metrics);
                trainer.setTrainingListener(this);
                Shape inputShape = new Shape(1, 3, NEW_HEIGHT, NEW_WIDTH);

                // initialize trainer with proper input shape
                trainer.initialize(inputShape);
                //find the patterns in data
                fit(trainer, trainingDataset, validateDataset, "build/logs/training");
                //set model properties
                model.setProperty("Epoch", String.valueOf(EPOCHS));
                model.setProperty("Accuracy", String.format("%.2f", getValidationAccuracy()));
                //save the model after done training for inference later
                //model saved as shoeclassifier-0000.params
                model.save(Paths.get(modelParamsPath), modelParamsName);
            }
        }
    }
}

Training is started by feeding the training data as input to a Block. In DJL terms, a Block is a composable unit that forms a neural network. You can combine Blocks (just like Lego blocks) to form a complex network. At the end of the training process, a Block represents a fully-trained model. The first step is to get a model instance by calling Models.getModel(NUM_OF_OUTPUT, NEW_HEIGHT, NEW_WIDTH). The getModel() method creates an empty model, constructs the neural network, and sets the neural network to the model.

/*
 Use a neural network (ResNet-50) to train the model
 ResNet-50 is a deep residual network with 50 layers; good for image classification
*/
public class Models {
    public static ai.djl.Model getModel(int numOfOutput, int height, int width) {
        //create new instance of an empty model
        ai.djl.Model model = ai.djl.Model.newInstance();
        //Block is a composable unit that forms a neural network;
        //combine them like Lego blocks to form a complex network
        Block resNet50 = new ResNetV1.Builder()
            .setImageShape(new Shape(3, height, width))
            .setNumLayers(50)
            .setOutSize(numOfOutput)
            .build();
        //set the neural network to the model
        model.setBlock(resNet50);
        return model;
    }
}
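
To give a feel for the Lego analogy and for what "residual" means in ResNet, here is a toy sketch in plain Java. It does not use DJL’s Block API; the class BlockCompositionSketch and its helpers sequential, residual, scale and relu are hypothetical names I made up for illustration, and scale simply stands in for a learned layer.

```java
import java.util.function.UnaryOperator;

// Toy illustration only: these helpers are hypothetical and are not DJL's Block API.
final class BlockCompositionSketch {

    // Chains blocks so that each block's output becomes the next block's input.
    @SafeVarargs
    static UnaryOperator<double[]> sequential(UnaryOperator<double[]>... blocks) {
        return input -> {
            double[] current = input;
            for (UnaryOperator<double[]> block : blocks) {
                current = block.apply(current);
            }
            return current;
        };
    }

    // Wraps a block in a skip connection: output = inner(input) + input.
    static UnaryOperator<double[]> residual(UnaryOperator<double[]> inner) {
        return input -> {
            double[] transformed = inner.apply(input);
            double[] output = new double[input.length];
            for (int i = 0; i < input.length; i++) {
                output[i] = transformed[i] + input[i];
            }
            return output;
        };
    }

    // Stand-in for a learned layer: multiplies every element by a fixed factor.
    static UnaryOperator<double[]> scale(double factor) {
        return input -> {
            double[] output = new double[input.length];
            for (int i = 0; i < input.length; i++) {
                output[i] = input[i] * factor;
            }
            return output;
        };
    }

    // ReLU activation: negative values become zero.
    static UnaryOperator<double[]> relu() {
        return input -> {
            double[] output = new double[input.length];
            for (int i = 0; i < input.length; i++) {
                output[i] = Math.max(0.0, input[i]);
            }
            return output;
        };
    }
}
```

For example, sequential(residual(scale(2.0)), relu()) applied to {1.0, -1.0} first computes 2x + x, giving {3.0, -3.0}, and the ReLU then yields {3.0, 0.0}. A real residual network stacks many such units, with convolutions in place of scale.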

The next step is to set up and configure a Trainer by calling the model.newTrainer(config) method. The config object was initialized by calling the setupTrainingConfig(loss) method, which sets the training configuration (or hyperparameters) to determine how the network is trained.

private static TrainingConfig setupTrainingConfig(Loss loss) {
    // epoch number to change learning rate
    int[] epoch = {3, 5, 8};
    int[] steps = Arrays
        .stream(epoch)
        .map(k -> k * 60000 / BATCH_SIZE)
        .toArray();
    //initialize neural network weights using Xavier initializer
    Initializer initializer = new XavierInitializer(
        XavierInitializer.RandomType.UNIFORM,
        XavierInitializer.FactorType.AVG, 2);
    //set the learning rate
    //adjusts weights of network based on loss
    MultiFactorTracker learningRateTracker = LearningRateTracker
        .multiFactorTracker()
        .setSteps(steps)
        .optBaseLearningRate(0.01f)
        .optFactor(0.1f)
        .optWarmUpBeginLearningRate(1e-3f)
        .optWarmUpSteps(500)
        .build();
    //set optimization technique
    //minimizes loss to produce better and faster results
    //Stochastic gradient descent
    Optimizer optimizer = Optimizer
        .sgd()
        .setRescaleGrad(1.0f / BATCH_SIZE)
        .setLearningRateTracker(learningRateTracker)
        .optMomentum(0.9f)
        .optWeightDecays(0.001f)
        .optClipGrad(1f)
        .build();
    return new DefaultTrainingConfig(initializer, loss)
        .setOptimizer(optimizer)
        .addTrainingMetric(new Accuracy())
        .setBatchSize(BATCH_SIZE);
}
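
To make the learning-rate settings above more concrete, here is a minimal sketch in plain Java of a step-decay schedule with a linear warm-up. It is not DJL’s MultiFactorTracker implementation, and the engine may treat the boundary steps differently. The class LearningRateSketch, its method names, and the batch size of 32 in the usage note are assumptions for illustration; the 60000 constant mirrors the value in the code above.

```java
import java.util.Arrays;

// Hypothetical sketch of a warm-up plus step-decay schedule (not DJL's own tracker).
final class LearningRateSketch {

    // Converts epoch numbers into step counts, as the steps array in the article's code does.
    static int[] epochsToSteps(int[] epochs, int datasetSize, int batchSize) {
        return Arrays.stream(epochs).map(e -> e * datasetSize / batchSize).toArray();
    }

    // Learning rate to use at a given training step.
    static double learningRate(int step, int[] boundaries, double baseRate, double factor,
                               double warmUpBeginRate, int warmUpSteps) {
        // During warm-up, move linearly from the starting rate up to the base rate.
        if (step < warmUpSteps) {
            return warmUpBeginRate + (baseRate - warmUpBeginRate) * step / warmUpSteps;
        }
        // Afterwards, multiply the base rate by the factor once per boundary already passed.
        int drops = 0;
        for (int boundary : boundaries) {
            if (step >= boundary) {
                drops++;
            }
        }
        return baseRate * Math.pow(factor, drops);
    }
}
```

With an assumed batch size of 32, epochsToSteps(new int[] {3, 5, 8}, 60000, 32) gives the boundaries 5625, 9375 and 15000. Under this sketch the rate climbs from 0.001 to 0.01 over the first 500 steps and then shrinks tenfold at each boundary.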

There are multiple hyperparameters set for training:

  • newHeight and newWidth — the shape of the image.
  • batchSize — the batch size used for training; pick a proper size based on your model.
  • numOfOutput — the number of labels; there are 4 labels for footwear classification.
  • loss — loss functions evaluate model predictions against true labels measuring how good (or bad) a model is.
  • Initializer — identifies an initialization method; in this case, Xavier initialization.
  • MultiFactorTracker — configures the learning rate options.
  • Optimizer: an optimization technique to minimize the value of the loss function; in this case, stochastic gradient descent (SGD).
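
For readers who want to see what the optimizer settings do to a single weight, the following is a textbook-style sketch of one SGD update with momentum, gradient rescaling, clipping and weight decay, written in plain Java. It is an assumption-laden illustration rather than the engine’s code: the class SgdMomentumSketch is hypothetical, and the exact order in which MXNet rescales, clips and applies weight decay may differ from what is shown here.

```java
// Hypothetical sketch of SGD with momentum (not the MXNet or DJL implementation).
final class SgdMomentumSketch {
    private final double momentum;
    private final double weightDecay;
    private final double clipGrad;
    private final double rescaleGrad;
    private final double[] velocity; // one running velocity per weight, starts at zero

    SgdMomentumSketch(int size, double momentum, double weightDecay,
                      double clipGrad, double rescaleGrad) {
        this.momentum = momentum;
        this.weightDecay = weightDecay;
        this.clipGrad = clipGrad;
        this.rescaleGrad = rescaleGrad;
        this.velocity = new double[size];
    }

    // Updates the weights in place using the gradients summed over one batch.
    void step(double[] weights, double[] gradients, double learningRate) {
        for (int i = 0; i < weights.length; i++) {
            // Rescale the summed gradient (for example by 1 / batch size).
            double g = gradients[i] * rescaleGrad;
            // Clip to [-clipGrad, clipGrad] to limit the size of any single update.
            g = Math.max(-clipGrad, Math.min(clipGrad, g));
            // Weight decay pulls each weight gently toward zero.
            g += weightDecay * weights[i];
            // Momentum keeps part of the previous update direction.
            velocity[i] = momentum * velocity[i] - learningRate * g;
            weights[i] += velocity[i];
        }
    }
}
```

As a worked example under these assumptions: with a weight of 1.0, a summed gradient of 4.0, a rescale factor of 0.5, clipping at 1, weight decay 0.001, momentum 0.9 and a learning rate of 0.01, the effective gradient is 1.001 and the weight moves to roughly 0.98999.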

The next step is to set the Metrics and a training listener on the Trainer, and then initialize the Trainer with the proper input shape. Metrics collect and report key performance indicators (KPIs) during training that can be used to analyze and monitor training performance and stability. Next, I kick off the training process by calling the fit(trainer, trainingDataset, validateDataset, "build/logs/training") method, which iterates over the training data and stores the patterns found in the model.

public void fit(Trainer trainer, Dataset trainingDataset, Dataset validateDataset,
        String outputDir) throws IOException {
    // find patterns in data
    for (int epoch = 0; epoch < EPOCHS; epoch++) {
        for (Batch batch : trainer.iterateDataset(trainingDataset)) {
            trainer.trainBatch(batch);
            trainer.step();
            batch.close();
        }

        //validate patterns found
        if (validateDataset != null) {
            for (Batch batch : trainer.iterateDataset(validateDataset)) {
                trainer.validateBatch(batch);
                batch.close();
            }
        }
        //reset training and validation metric at end of epoch
        trainer.resetTrainingMetrics();
        //save model at end of each epoch
        if (outputDir != null) {
            Model model = trainer.getModel();
            model.setProperty("Epoch", String.valueOf(epoch));
            model.save(Paths.get(outputDir), "resnetv1");
        }
    }
}
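
The Accuracy metric registered in the training configuration is computed by DJL, but the idea behind it is simple enough to sketch in plain Java. The class AccuracySketch and its methods are hypothetical and only illustrate the calculation: for each image, take the label with the highest score and count how often it matches the true label.

```java
// Hypothetical sketch of a classification accuracy calculation (not DJL's Accuracy class).
final class AccuracySketch {

    // Index of the highest score, i.e. the predicted label.
    static int argMax(double[] scores) {
        int best = 0;
        for (int i = 1; i < scores.length; i++) {
            if (scores[i] > scores[best]) {
                best = i;
            }
        }
        return best;
    }

    // Fraction of rows whose highest-scoring label equals the true label.
    static double accuracy(double[][] predictions, int[] labels) {
        if (predictions.length != labels.length) {
            throw new IllegalArgumentException("one label is required per prediction");
        }
        if (labels.length == 0) {
            return 0.0;
        }
        int correct = 0;
        for (int i = 0; i < labels.length; i++) {
            if (argMax(predictions[i]) == labels[i]) {
                correct++;
            }
        }
        return (double) correct / labels.length;
    }
}
```

If three out of four validation images are classified correctly, this sketch returns 0.75. Comparing this value on the training and validation sets after each epoch is one way to spot overfitting.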

At the end of the training, a well-performing validated model artifact is saved locally along with its properties using the model.save(Paths.get(modelParamsPath), modelParamsName) method. The metrics reported during the training process are shown below.

Image 3: metrics reported during training

Run inference

Now that I have a model, I can use it to perform inference (or prediction) on new data for which I do not know the classification (or target). After setting the necessary paths to the model and the image to be classified, I obtain an empty model instance using the Models.getModel(NUM_OF_OUTPUT, NEW_HEIGHT, NEW_WIDTH) method and initialize it using the model.load(Paths.get(modelParamsPath), modelParamsName) method. This loads the model I trained in the previous step. Next, I initialize a Predictor with a specified Translator using the model.newPredictor(translator) method. In DJL terms, a Translator provides model pre-processing and post-processing functionality. For example, with CV models, images often need to be resized or converted to grayscale before they are passed to the network; a Translator can do this for you. The Predictor allows me to perform inference on the loaded Model using the predictor.predict(img) method, passing in the image to classify. I’m doing a single prediction, but DJL also supports batch predictions. The result is stored in predictResult, which contains the probability estimate per label. Because the Model and the Predictor are created in try-with-resources blocks, both are closed automatically once inference completes, which releases the resources they hold.

private Classifications predict() throws IOException, ModelException, TranslateException {

    //the location to the model saved during training
    String modelParamsPath = "build/logs";
    //the name of the model set during training
    String modelParamsName = "shoeclassifier";
    //the path of image to classify
    String imageFilePath = "src/test/resources/slippers.jpg";
    //Load the image file from the path
    BufferedImage img = BufferedImageUtils.fromFile(Paths.get(imageFilePath));
    //holds the probability score per label
    Classifications predictResult;
    try (Model model = Models.getModel(NUM_OF_OUTPUT, NEW_HEIGHT, NEW_WIDTH)) {
        //load the model
        model.load(Paths.get(modelParamsPath), modelParamsName);
        //define a translator for pre and post processing
        Translator<BufferedImage, Classifications> translator = new MyTranslator();
        //run the inference using a Predictor
        try (Predictor<BufferedImage, Classifications> predictor =
                model.newPredictor(translator)) {
            predictResult = predictor.predict(img);
        }
    }
    return predictResult;
}
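
MyTranslator is not shown above, so as a rough idea of what the post-processing half of a translator for this kind of model might do, here is a sketch in plain Java. It is not DJL code and not necessarily what my translator does: the class PostProcessSketch and its methods are hypothetical, and it assumes the network outputs one raw score (logit) per label, which softmax turns into probabilities that are then ranked.

```java
import java.util.ArrayList;
import java.util.Comparator;
import java.util.List;
import java.util.Locale;

// Hypothetical post-processing sketch (not DJL's Translator or Classifications classes).
final class PostProcessSketch {

    // Converts raw scores into probabilities that sum to 1.
    static double[] softmax(double[] logits) {
        // Subtracting the maximum first keeps Math.exp from overflowing.
        double max = Double.NEGATIVE_INFINITY;
        for (double value : logits) {
            max = Math.max(max, value);
        }
        double sum = 0.0;
        double[] probabilities = new double[logits.length];
        for (int i = 0; i < logits.length; i++) {
            probabilities[i] = Math.exp(logits[i] - max);
            sum += probabilities[i];
        }
        for (int i = 0; i < probabilities.length; i++) {
            probabilities[i] /= sum;
        }
        return probabilities;
    }

    // Returns "label: probability" strings, most likely label first.
    static List<String> rank(String[] labels, double[] probabilities) {
        List<Integer> order = new ArrayList<>();
        for (int i = 0; i < labels.length; i++) {
            order.add(i);
        }
        order.sort(Comparator.comparingDouble((Integer i) -> probabilities[i]).reversed());
        List<String> ranked = new ArrayList<>();
        for (int index : order) {
            ranked.add(String.format(Locale.ROOT, "%s: %.3f", labels[index], probabilities[index]));
        }
        return ranked;
    }
}
```

Given the four footwear labels and made-up scores in which the slippers entry is the largest, rank would list slippers first, followed by the remaining labels in descending order of probability.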

The inferences (per image) are shown below with their corresponding probability scores.

Image 4: Inference for boots
Image 5: Inference for sandals
Image 6: Inference for shoes
Image 7: Inference for slippers

Takeaways & Next Steps

I’ve been developing Java-based applications since the late ’90s and started my machine learning journey in 2017. My journey would’ve been much easier had DJL been around back then. I highly recommend that Java developers looking to transition to machine learning give DJL a try. In my example, I developed the footwear classification model from scratch; however, DJL also allows developers to deploy pre-trained models with minimal effort. DJL also comes with popular datasets out of the box to allow developers to get started with ML right away. Before starting with DJL, I would recommend that you have a firm understanding of the ML lifecycle and are familiar with common ML terms. Once you have a basic understanding of ML, you can quickly come up to speed on the DJL APIs.

Amazon has open-sourced DJL; further detailed information about the toolkit can be found on the DJL website and the Java Library API Specification page. The code for the footwear classification model can be found on GitLab. Good luck on your ML journey, and please feel free to reach out to me if you have any questions.