From PyTorch to PyTorch Lightning — A gentle introduction

Source: Artificial Intelligence on Medium

The Typical AI Research project

In a research project, we normally want to identify the following key components:

  • the model(s)
  • the data
  • the loss
  • the optimizer(s)

The Model

Let’s design a 3-layer fully-connected neural network that takes as input an image that is 28×28 and outputs a probability distribution over 10 possible labels.

First, let’s define the model in PyTorch.

This model defines the computational graph to take as input an MNIST image and convert it to a probability distribution over 10 classes for digits 0–9.
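As a minimal sketch of what such a model could look like: the class name `MNISTClassifier` and the hidden sizes (128 and 256) are illustrative assumptions, not values fixed by the text. The final `log_softmax` returns log-probabilities, so exponentiating them gives the probability distribution over the 10 digits.

```python
import torch
from torch import nn
from torch.nn import functional as F


class MNISTClassifier(nn.Module):
    """3-layer fully-connected network for 28x28 MNIST images."""

    def __init__(self):
        super().__init__()
        # Hidden sizes 128 and 256 are assumptions chosen for illustration.
        self.layer_1 = nn.Linear(28 * 28, 128)
        self.layer_2 = nn.Linear(128, 256)
        self.layer_3 = nn.Linear(256, 10)

    def forward(self, x):
        batch_size = x.size(0)
        # Flatten (batch, 1, 28, 28) -> (batch, 784)
        x = x.view(batch_size, -1)
        x = F.relu(self.layer_1(x))
        x = F.relu(self.layer_2(x))
        x = self.layer_3(x)
        # Log-probabilities over the 10 digit classes
        return F.log_softmax(x, dim=1)


model = MNISTClassifier()
log_probs = model(torch.rand(4, 1, 28, 28))  # a fake batch of 4 images
probs = log_probs.exp()  # exponentiate to recover probabilities
```

Each row of `probs` is a distribution over the digits 0–9 for one input image.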

3-layer network (illustration by: William Falcon)

To convert this model to PyTorch Lightning, we simply replace nn.Module with pl.LightningModule.

The new PyTorch Lightning class is EXACTLY the same as the PyTorch one, except that the LightningModule provides a structure for the research code.

Lightning provides structure to PyTorch code

See? The code is EXACTLY the same for both!

The Data

For this tutorial we’re using MNIST.

Source: Wikipedia

Let’s generate three splits of MNIST: a training, a validation and a test split.

Again, this is the same code in PyTorch as it is in Lightning.

Each dataset split is wrapped in a DataLoader, which handles the loading, shuffling and batching of the data.

In short, data preparation has 4 steps:

  1. Download images.
  2. Apply image transforms (these are highly subjective).
  3. Generate training, validation and test dataset splits.
  4. Wrap each dataset split in a DataLoader.
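The four steps above can be sketched as follows. To keep the example runnable without a network download, random tensors shaped like MNIST stand in for step 1 (with real data you would typically load `torchvision.datasets.MNIST` instead). The dataset size, the 800/100/100 split and the batch size of 64 are assumptions chosen for illustration; 0.1307 and 0.3081 are the mean and standard deviation commonly used to normalize MNIST.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset, random_split

# Step 1 (stand-in): random "images" and labels instead of a real download.
images = torch.rand(1000, 1, 28, 28)
labels = torch.randint(0, 10, (1000,))

# Step 2: a simple transform, here normalization with the usual MNIST stats.
images = (images - 0.1307) / 0.3081

# Step 3: training, validation and test splits (sizes are illustrative).
dataset = TensorDataset(images, labels)
train_set, val_set, test_set = random_split(dataset, [800, 100, 100])

# Step 4: wrap each split in a DataLoader; only training data is shuffled.
train_loader = DataLoader(train_set, batch_size=64, shuffle=True)
val_loader = DataLoader(val_set, batch_size=64)
test_loader = DataLoader(test_set, batch_size=64)

x, y = next(iter(train_loader))  # one batch of images and labels
```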

Again, the code is exactly the same except that we’ve organized the PyTorch code into 4 functions:


prepare_data

This function handles downloads and any data processing. It makes sure that when you use multiple GPUs you don’t download the dataset multiple times or apply the same manipulations to the data twice.

This matters because each GPU executes the same PyTorch code, which would otherwise duplicate the work. Lightning makes sure these critical parts are called from ONLY one GPU.

train_dataloader, val_dataloader, test_dataloader

Each of these is responsible for returning the appropriate data split. Lightning structures it this way so that it is VERY clear HOW the data are being manipulated. If you have ever read arbitrary PyTorch code on GitHub, you know how hard it can be to see how the data are manipulated.

Lightning even allows multiple dataloaders for testing or validating.

The Optimizer

Now we choose how we’re going to do the optimization. We’ll use Adam instead of SGD because it is a good default in most DL research.

Again, this is exactly the same in both, except that in Lightning it is organized into the configure_optimizers function.

Lightning is extremely extensible. For instance, if you wanted to use multiple optimizers (e.g., for a GAN), you could just return both here.

You’ll also notice that in Lightning we pass in self.parameters() and not a model because the LightningModule IS the model.
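A hedged sketch of both cases follows. Plain `nn.Module` classes are used so the example runs without Lightning installed; in a LightningModule the method bodies would look the same. The class names `TinyModel` and `TinyGAN`, the layer sizes and the learning rate of 1e-3 are all illustrative assumptions.

```python
import torch
from torch import nn


class TinyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.layer = nn.Linear(28 * 28, 10)

    def configure_optimizers(self):
        # self.parameters(), not model.parameters(): the module IS the model.
        return torch.optim.Adam(self.parameters(), lr=1e-3)


class TinyGAN(nn.Module):
    def __init__(self):
        super().__init__()
        self.generator = nn.Linear(16, 28 * 28)
        self.discriminator = nn.Linear(28 * 28, 1)

    def configure_optimizers(self):
        # One optimizer per sub-network, returned together.
        opt_g = torch.optim.Adam(self.generator.parameters(), lr=1e-3)
        opt_d = torch.optim.Adam(self.discriminator.parameters(), lr=1e-3)
        return [opt_g, opt_d]


optimizer = TinyModel().configure_optimizers()
gan_optimizers = TinyGAN().configure_optimizers()
```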

The Loss

For n-way classification we want to compute the cross-entropy loss. Cross-entropy on the raw network outputs is the same as the negative log-likelihood applied to log_softmax outputs, which is what we’ll use instead.
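A quick way to convince yourself of this equivalence is to compute both on the same inputs. The sketch below uses random logits and targets (the batch size of 8 is arbitrary) and compares `F.cross_entropy` against `F.nll_loss` applied to `F.log_softmax`.

```python
import torch
from torch.nn import functional as F

torch.manual_seed(0)
logits = torch.randn(8, 10)            # raw scores for 8 samples, 10 classes
targets = torch.randint(0, 10, (8,))   # one class index per sample

# Cross-entropy computed directly on the raw logits
ce = F.cross_entropy(logits, targets)

# Negative log-likelihood on log-softmax outputs
nll = F.nll_loss(F.log_softmax(logits, dim=1), targets)

same = torch.allclose(ce, nll)  # equal up to floating-point tolerance
```

This is why a model that ends in `log_softmax` should be paired with `nll_loss` rather than `cross_entropy`, which would apply the softmax a second time.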

Again… code is exactly the same!