ANNs: Ship a trained model to production using Kotlin and DeepLearning4J

Introduction

In the previous article, we saw how to build and train a model on MNIST using Kotlin and DeepLearning4J. In this article, we’ll focus on how to use that trained model in production.

For that, we’ll create an application using the Spring framework with Kotlin and DeepLearning4J.

Generate Project

For this, we’ll use the Spring Initializr, either via the link provided or directly from IntelliJ.

The project we’ll build will have the following settings:

  • Maven Project
  • Kotlin
  • Spring Boot 2.1.4 (the latest stable version at the time of writing)
  • The Spring Web starter

In the end, you should have something that looks like the following screenshot.

Spring Initializr

Click on Generate Project, unzip it and open it using your editor.

Using IntelliJ, it’s even easier as it creates a new IntelliJ project for you.
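
If you prefer the command line, the Initializr also exposes an HTTP API; a rough equivalent of the settings above might be the following (a sketch: the parameter names follow start.spring.io’s API, and the archive name is arbitrary).

curl https://start.spring.io/starter.zip \
  -d type=maven-project \
  -d language=kotlin \
  -d bootVersion=2.1.4.RELEASE \
  -d dependencies=web \
  -o mnist-webapp.zip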

Create Front-End App

The goal here is an end result that lets the user draw MNIST-like digits and send them to the back-end to get the model’s answer on which digit has been drawn.

Using Spring MVC, the webapp can be a simple static web application; we’ll create the following folder structure to work in.
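
By default, Spring Boot serves static content from src/main/resources/static, so a minimal layout might look like this (a sketch; the file names are illustrative, and the exact structure lives in the project’s repository):

src/main/kotlin/                        <- application class, controller and model wrapper
src/main/resources/static/index.html    <- page with the drawing canvas
src/main/resources/static/script.js     <- canvas handling and the call to the back-end
src/main/resources/static/style.css     <- styling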

To avoid adding too much noise to this article, you can refer to the source code posted on GitHub to create your own webapp.

GitHub of the Project

The end result looks like the following:

  • A canvas for drawing the digit
  • A button to send the digit to the back-end API
  • A clear button to clear the canvas if needed
  • A label displaying the result received from the back-end

Create the Back-End API

The most important part regarding the use of our model is the following.

We’ll first take a look at the dependencies needed for the back-end to work.
The Spring and Kotlin dependencies have been generated automatically by the Spring Initializr; we only need to add the DeepLearning4J dependencies. I had to lighten them, since DeepLearning4J comes with a lot of libraries that are not needed in our case (e.g. OpenCV and MKL-DNN).

<dependencies>
    <!-- SPRING -->
    <dependency>
        <groupId>org.springframework.boot</groupId>
        <artifactId>spring-boot-starter</artifactId>
    </dependency>
    <dependency>
        <groupId>org.springframework.boot</groupId>
        <artifactId>spring-boot-autoconfigure</artifactId>
    </dependency>
    <dependency>
        <groupId>org.springframework.boot</groupId>
        <artifactId>spring-boot-starter-web</artifactId>
    </dependency>
    <!-- KOTLIN -->
    <dependency>
        <groupId>org.jetbrains.kotlin</groupId>
        <artifactId>kotlin-reflect</artifactId>
    </dependency>
    <dependency>
        <groupId>org.jetbrains.kotlin</groupId>
        <artifactId>kotlin-stdlib-jdk8</artifactId>
    </dependency>
    <!-- DEEPLEARNING4J (dl4j.version is expected to be declared in the pom's <properties> section) -->
    <dependency>
        <groupId>org.deeplearning4j</groupId>
        <artifactId>deeplearning4j-core</artifactId>
        <version>${dl4j.version}</version>
        <!-- EXCLUDE OPENCV TO LIGHTEN ARCHIVE -->
        <exclusions>
            <exclusion>
                <groupId>org.bytedeco.javacpp-presets</groupId>
                <artifactId>*</artifactId>
            </exclusion>
        </exclusions>
    </dependency>
    <dependency>
        <groupId>org.nd4j</groupId>
        <artifactId>nd4j-native-platform</artifactId>
        <version>${dl4j.version}</version>
        <!-- EXCLUDE MKL-DNN TO LIGHTEN ARCHIVE -->
        <exclusions>
            <exclusion>
                <groupId>org.bytedeco.javacpp-presets</groupId>
                <artifactId>mkl-dnn-platform</artifactId>
            </exclusion>
        </exclusions>
    </dependency>
</dependencies>

We can then focus on the code needed to expose an API that the front-end can call to ask our model which number has been sent.

All of this can fit into a single file if needed, containing:

  • The deserializer / wrapper for our model (which is a singleton loading the model at startup)
  • The controller exposed for guessing
  • A data class used to receive the drawn digit’s pixels from the front-end

// Imports for the whole file (model wrapper, controller and data class)
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork
import org.deeplearning4j.util.ModelSerializer
import org.nd4j.linalg.api.ndarray.INDArray
import org.nd4j.linalg.factory.Nd4j
import org.springframework.beans.factory.annotation.Autowired
import org.springframework.beans.factory.config.ConfigurableBeanFactory.SCOPE_SINGLETON
import org.springframework.context.annotation.Scope
import org.springframework.http.MediaType
import org.springframework.stereotype.Service
import org.springframework.web.bind.annotation.PostMapping
import org.springframework.web.bind.annotation.RequestBody
import org.springframework.web.bind.annotation.RestController
import java.nio.file.Files
import java.nio.file.Paths
import javax.annotation.PostConstruct

@Service
@Scope(SCOPE_SINGLETON)
class ModelWrapper {
    internal lateinit var model: MultiLayerNetwork

    // Locate the persisted model on disk and deserialize it once at startup
    @PostConstruct
    fun loadModel() {
        val path = Files.walk(Paths.get("."))
                .filter { path -> path.toFile().name.contains("persisted-model") }
                .map { path -> path.toAbsolutePath().toString() }
                .findFirst()
                .orElseThrow { RuntimeException("persisted-model could not be found! Check the path") }
        model = ModelSerializer.restoreMultiLayerNetwork(path)
    }

    // Feed the 28x28 pixel array to the network and return the predicted digit
    fun guess(image: IntArray): Int {
        val ndArray = asNdArray(image)
        val output = model.output(ndArray)
        val data = output.data()
        val results = data.asInt()
        val guessResult = results.indexOf(1)
        return if (guessResult in 0..9) guessResult else throw Exception("Unrecognized number")
    }

    // Convert the raw pixel values into an INDArray the network can consume
    private fun asNdArray(newImage: IntArray): INDArray {
        val mnistPixelsAmount = 28 * 28
        val ndArray = Nd4j.zeros(mnistPixelsAmount)

        for (i in 0 until mnistPixelsAmount) {
            ndArray.putScalar(i.toLong(), newImage[i])
        }
        return ndArray
    }
}
@RestController
class ModelController {

    @Autowired
    lateinit var modelWrapper: ModelWrapper

    // Receive the drawn image as JSON and return the model's guess
    @PostMapping("/guess", consumes = [MediaType.APPLICATION_JSON_VALUE])
    fun guessInput(@RequestBody image: FrontEndImage): Int {
        return modelWrapper.guess(image.arrayOfPixels)
    }
}

data class FrontEndImage(val arrayOfPixels: IntArray = intArrayOf())
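
Note that reading the prediction with indexOf(1) only works when the network outputs an exact one-hot vector. A more tolerant variant, not used in this article but sketched here for completeness, is to let DeepLearning4J pick the most likely class directly via MultiLayerNetwork’s predict method:

// Sketch: MultiLayerNetwork.predict returns the index of the highest
// output activation for each example, so there is no need to look for
// an exact 1 in the output vector.
fun guessByPredict(model: MultiLayerNetwork, input: INDArray): Int = model.predict(input)[0]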

As you can see, nothing really difficult here: we receive a simple HTTP request, invoke our deserialized model and send back the response.
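
For example, a call to the endpoint could look like this (an illustrative sketch using Spring’s RestTemplate; it is not part of the article’s code, assumes FrontEndImage is on the classpath, and uses placeholder pixel values):

import org.springframework.http.HttpEntity
import org.springframework.http.HttpHeaders
import org.springframework.http.MediaType
import org.springframework.web.client.RestTemplate

fun main() {
    // An all-zero 28x28 canvas; a real client would send the drawn pixels
    val pixels = IntArray(28 * 28)
    val headers = HttpHeaders().apply { contentType = MediaType.APPLICATION_JSON }
    val request = HttpEntity(FrontEndImage(pixels), headers)

    // POST the JSON body {"arrayOfPixels": [...]} to the running application
    val guess = RestTemplate().postForObject("http://localhost:8080/guess", request, Int::class.java)
    println("The model guessed: $guess")
}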

It’s as easy as that! :)

Final Result

If you want to play with the final result, go to the following link and send values through the web application. It does not return the correct digit every time, because the model is overfitted to its training set. To be truly accurate, it should be trained on a much wider variety of handwritten digits, transformed to the MNIST format and fed to the model during training.