Speed up the training process using all TPU cores at once

Original article was published by Grakesh on Deep Learning on Medium


Image source: https://cloud.google.com/tpu

In this article, I would like to share a trick that can be useful for training a deep learning model very quickly on all 8 cores of a TPU in parallel. Thanks to Abhishek Thakur for this amazing idea. You can check out his YouTube video, which walks through this kind of training on a TPU device.

What is a TPU?

Image source: https://cloud.google.com/tpu

A TPU (Tensor Processing Unit) is a custom accelerator for DL/ML computation, designed by Google as a matrix processor specialized for neural network workloads. TPUs can’t run word processors, control rocket engines, or execute bank transactions, but they can handle the massive numbers of multiplications and additions that neural networks need at blazingly fast speeds, while consuming much less power and occupying a smaller physical footprint.
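To make those multiplications and additions concrete: a fully connected layer computes y = Wx + b, which boils down to a series of multiply-accumulate (MAC) operations. The pure-Python sketch below is only an illustration, not TPU code (dense_forward is a hypothetical helper); it runs a toy 2×3 layer and counts its MACs. A TPU's matrix unit is built to perform a large grid of such MACs in parallel in hardware.

```python
def dense_forward(W, x, b):
    """Compute y = W @ x + b with explicit multiply-accumulate (MAC) steps.

    W is a list of rows (n_out x n_in), x has length n_in, b has length n_out.
    Returns the output vector and the number of MACs performed.
    """
    y = []
    macs = 0
    for row, bias in zip(W, b):
        acc = bias
        for w, xi in zip(row, x):
            acc += w * xi  # one multiply plus one add
            macs += 1
        y.append(acc)
    return y, macs


W = [[1.0, 2.0, 3.0],
     [4.0, 5.0, 6.0]]
x = [1.0, 0.5, -1.0]
b = [0.0, 1.0]
y, macs = dense_forward(W, x, b)
print(y, macs)  # [-1.0, 1.5] 6
```

Real layers are far larger: a single 1024×1024 dense layer applied to one input already needs 1,048,576 MACs, which is why dedicated matrix hardware pays off.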

Let’s look at this in action!

Thanks to Kaggle for providing 30 hours per week of TPU computation to experiment and get our hands dirty.

We are going to use PyTorch XLA, and the dataset for this experiment comes from the Kaggle competition SIIM-ISIC Melanoma Classification.

Install the required PyTorch XLA package

Import the required dependencies.

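Since the dataset is highly imbalanced, we will use the stratified k-fold technique for cross-validation, so that every fold keeps roughly the same ratio of positive to negative samples as the full dataset.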
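The core idea of stratified k-fold can be sketched in pure Python: group the sample indices by class, shuffle each group, and deal them round-robin across the folds, so each fold inherits the overall class balance. This is a minimal illustration; stratified_kfold is a hypothetical helper and the 5-positives-in-50 toy labels are made up. In practice you would typically use scikit-learn's StratifiedKFold (for example StratifiedKFold(n_splits=5, shuffle=True, random_state=42) and its split(X, y) method).

```python
import random
from collections import defaultdict


def stratified_kfold(labels, n_splits=5, seed=42):
    """Assign each sample a fold id so every fold keeps roughly the
    same class proportions as the full dataset."""
    by_class = defaultdict(list)
    for idx, label in enumerate(labels):
        by_class[label].append(idx)

    rng = random.Random(seed)
    folds = [-1] * len(labels)
    for label in sorted(by_class):
        indices = by_class[label]
        rng.shuffle(indices)
        # Deal this class's samples round-robin across the folds.
        for position, idx in enumerate(indices):
            folds[idx] = position % n_splits
    return folds


# Toy imbalanced target: 5 positives out of 50 samples (10%).
labels = [1] * 5 + [0] * 45
folds = stratified_kfold(labels, n_splits=5)
for k in range(5):
    val = [i for i, f in enumerate(folds) if f == k]
    positives = sum(labels[i] for i in val)
    # Each fold ends up with 10 samples, exactly 1 of them positive.
    print(f"fold {k}: {len(val)} samples, {positives} positive")
```

A plain (non-stratified) random split of the same data could easily leave a validation fold with no positives at all, which would make its metrics meaningless.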

We will use EfficientNet-B7 as our base model.

Now let’s dive straight into the part that matters for training the model in parallel, skipping over the data preprocessing and data loading code.

We will use a DistributedSampler for the DataLoader, so that each core receives a different shard of the training data.

xm.xrt_world_size() retrieves the number of devices (cores) taking part in replication.

xm.get_ordinal() retrieves the replication ordinal of the current process, which ranges from 0 to xm.xrt_world_size() - 1.
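To see what the sampler does with those two values, here is a pure-Python sketch that mimics torch's DistributedSampler with shuffle=False and drop_last=False: it pads the index list by wrapping around so it divides evenly across replicas, then gives each replica every num_replicas-th index starting at its own rank. In the training script, the real sampler is created along the lines of DistributedSampler(train_dataset, num_replicas=xm.xrt_world_size(), rank=xm.get_ordinal(), shuffle=True). The helper name shard_indices and the toy dataset size of 20 are illustrative assumptions.

```python
import math


def shard_indices(dataset_len, num_replicas, rank):
    """Mimic how torch's DistributedSampler (shuffle=False, drop_last=False)
    picks the indices that one replica will see."""
    num_samples = math.ceil(dataset_len / num_replicas)
    total_size = num_samples * num_replicas
    indices = list(range(dataset_len))
    # Pad by wrapping around so every replica gets the same number of samples.
    padding = total_size - len(indices)
    indices += (indices * math.ceil(padding / max(len(indices), 1)))[:padding]
    # Each replica takes every num_replicas-th index, starting at its rank.
    return indices[rank:total_size:num_replicas]


world_size = 8  # assumed: what xm.xrt_world_size() returns on an 8-core TPU
for ordinal in range(world_size):  # what xm.get_ordinal() returns in each process
    # e.g. ordinal 7 -> [7, 15, 3] (index 3 is padding)
    print(ordinal, shard_indices(20, world_size, ordinal))
```

With shuffle=True, the real sampler first permutes the indices using a seed shared by all replicas (offset by the epoch passed to set_epoch()), so the shards stay disjoint apart from the padding.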

To train the model in parallel, we use ParallelLoader, a class that wraps an existing PyTorch DataLoader.

The device is set to xm.xla_device(), which plays the same role as torch.device('cuda') does on a GPU.

Next come the train and evaluate functions.

ParallelLoader loads the training data onto each device.

para_loader.per_device_loader(device) returns the loader iterator for the given device. The optimizer should then be stepped with xm.optimizer_step(optimizer) instead of optimizer.step(), so that the gradients are averaged across all cores before the weights are updated.
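Under the hood, xm.optimizer_step(optimizer) all-reduces the gradients across the participating cores, averaging them, and then calls optimizer.step(). The pure-Python sketch below illustrates that data-parallel update with four hypothetical replicas (the TPU setup in this article uses eight) and plain SGD; all_reduce_mean and sgd_step are illustrative helpers, not torch_xla functions, and the gradient values are made up.

```python
def all_reduce_mean(per_core_grads):
    """Average gradients element-wise across replicas (a stand-in for the
    cross-core all-reduce that xm.optimizer_step performs)."""
    n = len(per_core_grads)
    return [sum(vals) / n for vals in zip(*per_core_grads)]


def sgd_step(params, grads, lr):
    """Plain SGD update, applied identically on every replica."""
    return [p - lr * g for p, g in zip(params, grads)]


# Shared starting weights and hypothetical per-replica gradients,
# each computed on that replica's own shard of the batch.
params = [1.0, -2.0]
per_core_grads = [
    [0.5, 1.0],
    [1.5, 0.0],
    [1.0, 2.0],
    [1.0, 1.0],
]
avg = all_reduce_mean(per_core_grads)
# Every replica applies the same averaged gradient, so the copies stay in sync.
replicas = [sgd_step(params, avg, lr=0.5) for _ in per_core_grads]
print(avg, replicas[0])  # [1.0, 1.0] [0.5, -2.5]
```

Because every core applies the same averaged gradient, the model copies stay identical without ever exchanging weights. When the data comes through ParallelLoader, xm.optimizer_step does not need barrier=True, since the loader's iterator already issues the XLA barrier.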

Running on multiple XLA devices with multiprocessing

xmp.spawn() creates one process per XLA device and calls the training function in each of them, passing the process index as the first argument.

Using a TPU, we can increase the training speed and also cut down on cost. We can also switch to the bfloat16 data type when running on a TPU, which halves the memory taken by each floating-point value compared with float32.
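The exact command from the original kernel is not reproduced here. In torch_xla releases from around that time, bfloat16 was commonly enabled by setting the XLA_USE_BF16 environment variable to 1 (for example os.environ["XLA_USE_BF16"] = "1") before the XLA device is created; check the documentation of your torch_xla version, as this mechanism may have changed. To see what bfloat16 trades away, the pure-Python sketch below rounds a value to bfloat16 by keeping only the top 16 bits of its float32 bit pattern: the sign, the full 8-bit exponent (so the range of float32 is preserved), and 7 mantissa bits. to_bfloat16 is a hypothetical helper, and NaN/inf handling is deliberately omitted.

```python
import struct


def to_bfloat16(x):
    """Round a Python float to the nearest bfloat16 value (round-half-to-even)
    and return it as a float. NaN/inf handling is omitted for brevity."""
    # Reinterpret the float32 bit pattern as a 32-bit unsigned integer.
    bits = struct.unpack("<I", struct.pack("<f", x))[0]
    # Round to nearest even on the 16 bits that are about to be dropped.
    rounding_bias = 0x7FFF + ((bits >> 16) & 1)
    # Keep the top 16 bits: sign, 8 exponent bits, 7 mantissa bits.
    bits = (bits + rounding_bias) & 0xFFFF0000
    return struct.unpack("<f", struct.pack("<I", bits))[0]


for value in [1.0, 3.14159265, 1e38]:
    # 3.14159265 becomes 3.140625; 1e38 stays finite (float16 would overflow).
    print(value, "->", to_bfloat16(value))
```

The loss of precision (roughly 2 to 3 significant decimal digits) is usually tolerable for neural network training, while keeping float32's exponent range avoids the overflow problems that float16 can run into.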

Link to my Kaggle kernel: https://www.kaggle.com/rocky03/effnet-pytorch-tpu/notebook