Original article was published on Deep Learning on Medium
So in summary, the model took ~2331s to train for one epoch, which is acceptable, but I still want it to be even lower.
I'm not rich enough to buy 4x P100s and train on them in parallel; I rely on Colab for all my training.
The next best option is to use a TPU!
TensorFlow models have good TPU support, and with the Estimator API it is straightforward to train on a TPU. But since I was already comfortable with PyTorch, I did not want to move to TensorFlow. One option is to use PyTorch Lightning, and you can easily find Colab notebooks for running a model on a TPU.
But I felt most of these don't work properly and seem buggy, and there are a lot of open issues. I will surely check it out some other time; for now I wanted to run my model with the least amount of changes.
So, I decided to follow the PyTorch XLA tutorials: https://github.com/pytorch/xla/blob/master/contrib/colab/resnet18-training.ipynb
And came up with the code below.
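As a rough illustration of what that code looks like, here is a minimal sketch of multi-core TPU training with PyTorch/XLA, loosely following the resnet18 Colab tutorial linked above. It assumes `torch_xla` is installed on a TPU runtime; `MyModel`, `train_dataset`, and `FLAGS` are hypothetical placeholders for your own model, dataset, and hyperparameters.

```python
# Sketch only: requires a TPU runtime with torch_xla installed.
# MyModel, train_dataset and FLAGS are placeholders, not real names.
import torch
import torch_xla.core.xla_model as xm
import torch_xla.distributed.parallel_loader as pl
import torch_xla.distributed.xla_multiprocessing as xmp

model = MyModel()                      # declared once, outside the run function
WRAPPED_MODEL = xmp.MpModelWrapper(model)

def run(rank, flags):
    device = xm.xla_device()           # model is moved to the XLA device here
    model = WRAPPED_MODEL.to(device)

    # Each core sees a different shard of the data via the sampler.
    sampler = torch.utils.data.distributed.DistributedSampler(
        train_dataset,
        num_replicas=xm.xrt_world_size(),   # 8 cores on Colab
        rank=xm.get_ordinal(),
        shuffle=True)
    loader = torch.utils.data.DataLoader(
        train_dataset, batch_size=flags["batch_size"], sampler=sampler)

    optimizer = torch.optim.SGD(model.parameters(), lr=flags["lr"])
    loss_fn = torch.nn.CrossEntropyLoss()

    for epoch in range(flags["epochs"]):
        # ParallelLoader feeds per-device queues and marks XLA steps for us.
        para_loader = pl.ParallelLoader(loader, [device])
        model.train()
        for data, target in para_loader.per_device_loader(device):
            optimizer.zero_grad()
            loss = loss_fn(model(data), target)
            loss.backward()
            xm.optimizer_step(optimizer)    # note: no barrier=True here

# Spawn one process per TPU core.
xmp.spawn(run, args=(FLAGS,), nprocs=8, start_method="fork")
```

This is a sketch under stated assumptions, not the article's exact notebook, but it shows the same pattern: a distributed sampler, a parallel loader, and `xmp.spawn` around an otherwise unchanged model.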
Notice there aren't many changes (zero changes to the model itself): the only things needed are to create a parallel loader and a distributed sampler, and then simply train the model. A few things to note:
A TPU is a Tensor Processing Unit. Each TPU has 8 cores, where each core is optimized for 128×128 matrix multiplies. In general, a single TPU is about as fast as 5 V100 GPUs!
A TPU pod hosts many TPUs. Currently, TPU pod v2 has 2048 cores! You can request a full pod from Google Cloud, or a "slice", which gives you some subset of those 2048 cores.
- xm.optimizer_step() does not take the barrier argument this time (the parallel loader takes care of marking the step)
- The model was declared outside the run function and sent to the XLA device inside it, whereas when using a single TPU core we did both in one place
- A ParallelLoader is wrapped around the DataLoader
- Use of the XLA_USE_BF16 environment variable
- And of course, we now call the spawn function (xmp.spawn) to execute the model training and evaluation
- You get 8 TPU cores on Colab
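On the XLA_USE_BF16 point above: it is just an environment variable, and it has to be set before torch_xla is imported so that tensors get stored as bfloat16 on the TPU. A minimal sketch:

```python
# Set XLA_USE_BF16 before any torch_xla import; with it set,
# torch.float tensors are stored as bfloat16 on the TPU.
import os

os.environ["XLA_USE_BF16"] = "1"
```

bfloat16 keeps float32's exponent range with a smaller mantissa, which is why it usually works as a drop-in replacement on TPU without loss scaling.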