Packaging your Pytorch Model using MLflow

Most machine learning projects start in a research environment, usually a notebook, where data scientists gather data, engineer features, and train and validate the model(s). The final product is usually a mix of data preprocessing code, machine learning code, environment dependencies, configuration files, and so on. Data scientists then turn to the engineering teams to package their code and make it ready for production. The most challenging part of this process is reproducing the model exactly in the production environment: given the same raw input data, both environments should return the same output.

MLflow is an open-source platform for managing the end-to-end machine learning lifecycle. It is built around four components: MLflow Tracking, MLflow Projects, MLflow Models, and MLflow Registry. In this project, I am using MLflow Models, which defines a standard format for packaging machine learning models so that they can be used in a variety of downstream tools, such as real-time serving through a REST API or batch inference on Apache Spark.

In this article, I build a text classification model in PyTorch and package it using MLflow Models. This post covers the following:

  1. Text preprocessing with pre-trained word embeddings.
  2. Model definition.
  3. Model training.
  4. Model packaging using MLflow.

Text preprocessing with pre-trained word embeddings

Every machine learning algorithm deals with numbers. To use text data as features, we need to transform each word type in the corpus into some numeric representation. Traditional NLP approaches, mostly based on word-count techniques, usually convert characters to lower case, remove punctuation, symbols, and special characters, and perform stemming or lemmatization on tokens. When using pre-trained embeddings, however, the quality of the final model depends on how closely our corpus’ vocabulary matches the embeddings’ vocabulary. I’m using GloVe embeddings, more specifically glove.840B.300d.zip, so I won’t convert characters in my corpus to lowercase, since these embeddings contain cased tokens. I clean the corpus by doing the following:

  1. Remove all the character symbols that do not appear in the embedding. A symbol here is any character that is not an ASCII character, a digit, or one of a small set of whitelisted Latin characters (including whitespace).
  2. Handle contractions using NLTK’s TreebankWordTokenizer.
  3. Remove the apostrophe symbol at the beginning of tokens.
Text preprocessing: remove symbols.
Text preprocessing: handle contractions.
Text preprocessing: remove the apostrophe.
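
A minimal sketch of these three cleaning steps is below. The whitelist of allowed characters and the regular expression are illustrative assumptions; the original code may keep a different set of symbols.

```python
import re
from nltk.tokenize import TreebankWordTokenizer

# Hypothetical whitelist: ASCII letters, digits, common punctuation, and a few
# accented Latin characters kept because they appear in the GloVe vocabulary.
NOT_ALLOWED = re.compile(r"[^A-Za-z0-9'\.,!?;:()\- éèêàçôûïü]")

treebank = TreebankWordTokenizer()

def clean_text(text: str) -> str:
    # 1. Drop symbols that are unlikely to appear in the embedding vocabulary.
    text = NOT_ALLOWED.sub(" ", text)
    # 2. Split contractions the Treebank way ("don't" -> "do", "n't"),
    #    so the resulting tokens match GloVe entries.
    tokens = treebank.tokenize(text)
    # 3. Strip the apostrophe left at the beginning of some tokens (e.g. "'s" -> "s").
    tokens = [t.lstrip("'") if t.startswith("'") and len(t) > 1 else t for t in tokens]
    return " ".join(tokens)

print(clean_text("I don’t like the café’s “weird” symbols…"))
```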

After cleaning the corpus, I tokenize it with the Keras Tokenizer to build the corpus’ vocabulary, build the embedding matrix, transform text tokens into sequences of indexes, and pad or truncate the sequences to a fixed length.
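
A rough sketch of this step follows; the vocabulary cap, the sequence length, and the tiny stand-in corpus and GloVe dictionary are assumptions, not the article’s exact values.

```python
import numpy as np
from tensorflow.keras.preprocessing.text import Tokenizer
from tensorflow.keras.preprocessing.sequence import pad_sequences

MAX_VOCAB = 100_000   # assumed vocabulary cap
MAX_LEN = 70          # assumed fixed sequence length
EMBED_DIM = 300       # dimensionality of glove.840B.300d

# `train_texts` stands for the cleaned corpus and `glove` for a dict mapping
# words to 300-d vectors parsed from glove.840B.300d.txt (tiny stand-ins here).
train_texts = ["I love this movie", "I hate this movie"]
glove = {w: np.random.rand(EMBED_DIM).astype("float32")
         for w in ["I", "love", "hate", "this", "movie"]}

# Keep casing and skip Keras' default punctuation filtering: the cleaning was
# already done so that tokens match the embedding vocabulary.
tokenizer = Tokenizer(num_words=MAX_VOCAB, lower=False, filters="")
tokenizer.fit_on_texts(train_texts)

# Embedding matrix: row i holds the GloVe vector of the word with index i.
word_index = tokenizer.word_index
num_words = min(MAX_VOCAB, len(word_index) + 1)
embedding_matrix = np.zeros((num_words, EMBED_DIM), dtype="float32")
for word, idx in word_index.items():
    if idx < num_words and word in glove:
        embedding_matrix[idx] = glove[word]

# Texts -> fixed-length sequences of word indexes (padded/truncated to MAX_LEN).
sequences = tokenizer.texts_to_sequences(train_texts)
X_train = pad_sequences(sequences, maxlen=MAX_LEN, padding="post", truncating="post")
```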

Model Definition

To train the model, I use a special kind of recurrent neural network (RNN) called a Long Short-Term Memory network (LSTM). LSTMs address the long-term dependency and vanishing gradient problems encountered in vanilla RNNs. LSTM modules contain interacting layers that control the information flow. They maintain a cell state and use structures called gates through which information is added to or removed from the cell state. More specifically, they forget irrelevant parts of the previous state through the forget gate, selectively update the state through the input (update) gate, and selectively output certain parts of the cell state through the output gate. I use a two-layer bidirectional LSTM.

LSTM architecture
Model code
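
A minimal sketch of what a two-layer bidirectional LSTM classifier can look like in PyTorch is below; the hidden size, the max-pooling over time, and the single-logit output are assumptions, not necessarily the article’s exact architecture.

```python
import torch
import torch.nn as nn

class TextClassifier(nn.Module):
    """Two-layer bidirectional LSTM over frozen pre-trained embeddings (sketch)."""

    def __init__(self, embedding_matrix, hidden_dim=128, num_classes=1):
        super().__init__()
        weights = torch.as_tensor(embedding_matrix, dtype=torch.float32)
        # Embedding layer initialised from the pre-trained matrix and frozen.
        self.embedding = nn.Embedding.from_pretrained(weights, freeze=True)
        self.lstm = nn.LSTM(
            input_size=weights.size(1),
            hidden_size=hidden_dim,
            num_layers=2,
            bidirectional=True,
            batch_first=True,
        )
        # Bidirectional -> each time step outputs 2 * hidden_dim features.
        self.fc = nn.Linear(2 * hidden_dim, num_classes)

    def forward(self, x):
        embedded = self.embedding(x)            # (batch, seq_len, embed_dim)
        output, _ = self.lstm(embedded)         # (batch, seq_len, 2 * hidden_dim)
        pooled, _ = torch.max(output, dim=1)    # max-pool over the time dimension
        return self.fc(pooled)                  # raw logits
```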

Model Training

I train the model for 10 epochs using the binary cross-entropy loss.

Training code
Learning curves
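
For reference, a bare-bones training loop under these settings might look as follows; the batch size, optimizer, and learning rate are assumptions, and `X_train`, `embedding_matrix`, and `TextClassifier` refer to the sketches above.

```python
import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset

# y_train stands in for the binary labels of the training texts.
y_train = np.array([1.0, 0.0], dtype="float32")

train_loader = DataLoader(
    TensorDataset(torch.as_tensor(X_train, dtype=torch.long),
                  torch.as_tensor(y_train)),
    batch_size=64, shuffle=True,
)

model = TextClassifier(embedding_matrix)
criterion = nn.BCEWithLogitsLoss()   # binary cross-entropy on raw logits
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

for epoch in range(10):
    model.train()
    epoch_loss = 0.0
    for x_batch, y_batch in train_loader:
        optimizer.zero_grad()
        logits = model(x_batch).squeeze(1)   # (batch,)
        loss = criterion(logits, y_batch)
        loss.backward()
        optimizer.step()
        epoch_loss += loss.item()
    print(f"epoch {epoch + 1}: loss = {epoch_loss / len(train_loader):.4f}")
```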

Package the model using MLflow

Now that we have trained and validated the model, it’s time to package it. We will need to do the following (a sketch of the full packaging code follows the list):

  1. Serialize the tokenizer, the embedding matrix, and the model’s weights.
  2. Define the artifacts dictionary that maps each of these objects to the file path of its serialized version.
  3. Define the Conda environment with all the dependencies.
  4. Define the PythonModel class, which contains the ‘predict’ function logic.
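
Putting the four steps together, a hedged sketch of the packaging code could look like the following. The artifact file names, the dependency list, and the `TextClassifierWrapper` class are illustrative assumptions; only the `mlflow.pyfunc.PythonModel` interface (`load_context` and `predict`) and `mlflow.pyfunc.save_model` are MLflow’s actual API.

```python
import mlflow.pyfunc

# 1. File paths of the serialized objects (tokenizer, embedding matrix, model weights).
artifacts = {
    "tokenizer": "artifacts/tokenizer.pkl",
    "embedding_matrix": "artifacts/embedding_matrix.npy",
    "model_state": "artifacts/model_state.pt",
}

# 2. Conda environment with the dependencies needed at inference time.
conda_env = {
    "channels": ["defaults", "conda-forge"],
    "dependencies": [
        "python=3.7",
        "pip",
        {"pip": ["mlflow", "torch", "tensorflow", "numpy", "nltk"]},
    ],
    "name": "text_classifier_env",
}

# 3. PythonModel subclass that reloads the artifacts and implements predict().
class TextClassifierWrapper(mlflow.pyfunc.PythonModel):
    def load_context(self, context):
        import pickle
        import numpy as np
        import torch

        with open(context.artifacts["tokenizer"], "rb") as f:
            self.tokenizer = pickle.load(f)
        embedding_matrix = np.load(context.artifacts["embedding_matrix"])
        # TextClassifier is the model class sketched earlier; it must be
        # importable when the packaged model is loaded.
        self.model = TextClassifier(embedding_matrix)
        self.model.load_state_dict(torch.load(context.artifacts["model_state"]))
        self.model.eval()

    def predict(self, context, model_input):
        import torch
        from tensorflow.keras.preprocessing.sequence import pad_sequences

        texts = model_input["text"].tolist()
        sequences = self.tokenizer.texts_to_sequences(texts)
        # Same MAX_LEN assumption as in the preprocessing sketch.
        padded = pad_sequences(sequences, maxlen=70, padding="post", truncating="post")
        with torch.no_grad():
            logits = self.model(torch.as_tensor(padded, dtype=torch.long)).squeeze(1)
        return torch.sigmoid(logits).numpy()

# 4. Package everything as an MLflow model.
mlflow.pyfunc.save_model(
    path="text_classifier_model",
    python_model=TextClassifierWrapper(),
    artifacts=artifacts,
    conda_env=conda_env,
)
```

In a real project, the module defining the model class would also be bundled through the `code_path` argument of `save_model`, so that it can be imported when the package is loaded.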

Our packaged model is ready to be used by any downstream application. We just need to load the package and call the ‘predict’ function, for example:
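
```python
import pandas as pd
import mlflow.pyfunc

# Load the packaged model (same assumed path and input column name as above)
# and score new text through the generic pyfunc interface.
loaded_model = mlflow.pyfunc.load_model("text_classifier_model")
predictions = loaded_model.predict(pd.DataFrame({"text": ["MLflow makes deployment easy."]}))
print(predictions)
```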