Quick Primer on Distributed Training with PyTorch


PyTorch provides two main approaches for parallelizing training jobs — DataParallel and DistributedDataParallel. Each approach involves a wrapper that: a) encapsulates all the lower level communication details to orchestrate and synchronize distributed training; and b) exposes a clean API for end users. Barring a few small modifications to the code, the experience of working with wrapped models is almost the same as working with a local, non-distributed model.

DataParallel

DataParallel (DP) is the simpler of the two approaches, requiring minimal effort and few changes to the training code. It works by splitting each training batch into smaller, similarly sized sub-batches, one for each available GPU. These sub-batches are then processed in parallel on their respective GPUs.

One GPU serves as the master (GPU:0 by default), which orchestrates the overall flow. During each SGD iteration, the master: a) replicates the model weights and broadcasts them to every available GPU (replicate step); b) splits the training data batch and dispatches the sub-batches to the individual GPUs (scatter step). Each GPU then processes its own sub-batch of the data to compute a local loss and gradients with respect to the model parameters (parallel_apply step). Once finished, the master GPU: c) collects the local gradients from each GPU (gather step); and d) aggregates the gradients and performs the model update (update step). As a result, all GPUs start each iteration with exactly the same model parameters.

DataParallel: Single-node, multi-GPU setup; implemented as a single process with multiple threads to manage the workers

All of these steps are performed by the PyTorch library under the hood. As an end user, the only code change required is to wrap the model object in DataParallel, as shown in the gist below.
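The original gist is not reproduced here, but a minimal sketch of the idea looks as follows, using a placeholder model and a synthetic batch in place of the real dataset:

```python
import torch
import torch.nn as nn

# Placeholder model, standing in for whatever network the training script defines.
model = nn.Sequential(nn.Linear(512, 256), nn.ReLU(), nn.Linear(256, 10))

if torch.cuda.device_count() > 1:
    # The only required change: wrap the model so each forward call
    # replicates it and scatters the input batch across all visible GPUs.
    model = nn.DataParallel(model)
model = model.cuda()  # parameters live on the master GPU (cuda:0 by default)

criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)

inputs = torch.randn(64, 512).cuda()          # the full batch is fed on the master GPU
targets = torch.randint(0, 10, (64,)).cuda()

outputs = model(inputs)   # replicate, scatter, parallel_apply, gather happen inside the wrapper
loss = criterion(outputs, targets)
optimizer.zero_grad()
loss.backward()           # local gradients are reduced onto the master replica
optimizer.step()          # the master updates the parameters; replicas are refreshed next iteration
```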

While DP is an effective first line of attack when trying to squeeze performance out of multiple available GPUs, it has at least a few limitations: a) it follows a single-process, multi-threaded design and hence only works on a single node; b) it incurs significant communication overhead from the data transferred between master and workers in each iteration; c) the master GPU clearly has more responsibilities and hence works harder than the other GPUs in the mix (skewed workload). In fact, during the update step the master alone updates the model parameters while the other devices in the cluster wait, remaining underutilized.

DistributedDataParallel

Like DP, DistributedDataParallel (DDP) also parallelizes along the data dimension. It addresses some of the fundamental limitations of DP; in particular, DDP works in both single- and multi-node environments. The implementation of DDP involves multiple Python processes that must be coordinated and synchronized at appropriate points in the code. This yields a better balance of workload across workers and lower communication overhead from data transfer. Because of these advantages, DDP is the recommended approach even in a single-node context.

DistributedDataParallel: Multi-node, multi-GPU setup; each worker is managed by its own Python process.

One key difference from DP is the protocol for updating model parameters. Instead of collecting local gradients on the master for aggregation and a model update that is then broadcast back to the workers, DDP aggregates gradients across all GPUs using an efficient all-reduce operation, which makes the averaged gradients available to every GPU worker; each worker then updates its own model replica in parallel.
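DDP performs this reduction internally during backward(), so no user code is needed; the sketch below only illustrates the idea behind the all-reduce step using the torch.distributed primitives, and assumes a process group has already been initialized:

```python
import torch.distributed as dist

def average_gradients(model):
    # Illustration only: DDP does this internally (and more efficiently,
    # overlapping communication with the backward pass).
    world_size = dist.get_world_size()
    for param in model.parameters():
        if param.grad is None:
            continue
        # Every process contributes its local gradient; after the call,
        # every process holds the sum, which is divided to get the average.
        dist.all_reduce(param.grad, op=dist.ReduceOp.SUM)
        param.grad /= world_size
```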

Under the hood, PyTorch performs synchronization at critical points in the code so no GPU is left behind. As a result, and because models are replicated on each GPU upfront before training begins, at the end of each iteration all model replicas have the same updated parameters.

According to the PyTorch docs, the most effective approach to DDP is to assign one process to manage each GPU device. The processes must all be launched and initialized for communication and synchronization before training begins. There are multiple ways to achieve this: a shared file system, environment variables, or TCP. In this post, I have used environment variables, which are populated by the launcher utility provided by PyTorch. This sets up the required environment for all processes to operate in tandem, including a master address and port for all processes to coordinate through, the world size indicating the total number of processes to wait for when synchronizing, and the rank of each process (both local to each node and global across nodes). The gist below shows the key steps and commands of interest for both single-node and multi-node launch.
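For reference, a sketch of what the launch step can look like is below; the script name (train_ddp.py), node counts, master address, and port are placeholders rather than values from the original gist. The Python snippet simply reads back the environment variables the launcher populates for each spawned process.

```python
# Single-node launch (one process per GPU; script name and counts are placeholders):
#   python -m torch.distributed.launch --nproc_per_node=4 train_ddp.py
#
# Multi-node launch (run on every node, changing only --node_rank):
#   python -m torch.distributed.launch --nproc_per_node=4 --nnodes=2 --node_rank=0 \
#       --master_addr=10.0.0.1 --master_port=29500 train_ddp.py

import os

# Inside each spawned process, the launcher has populated these variables.
master_addr = os.environ["MASTER_ADDR"]     # address all processes rendezvous at
master_port = os.environ["MASTER_PORT"]     # port used for the rendezvous
world_size = int(os.environ["WORLD_SIZE"])  # total number of processes across all nodes
global_rank = int(os.environ["RANK"])       # this process's rank across the whole job
```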

After launch, the first step in each process is to call the blocking init_process_group function, which lets the processes establish communication with each other and specifies both the initialization method (environment variables here) and the backend (several are available: Gloo, NCCL, and MPI; in this example, I've used NCCL).
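A minimal sketch of that call, assuming the environment-variable initialization method and the NCCL backend:

```python
import torch.distributed as dist

# Blocks until all WORLD_SIZE processes have joined the group.
# "env://" tells PyTorch to read MASTER_ADDR, MASTER_PORT, WORLD_SIZE and RANK
# from the environment variables set by the launcher.
dist.init_process_group(backend="nccl", init_method="env://")
```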

Once all processes are initialized, the next step is to map each process to the GPU device it will control. This must be handled in the client code. The local rank (set by the launcher utility described above) is the rank of a process on a specific node and runs from 0 up to one less than the number of processes launched on that particular node. It can therefore be used to index the GPU device to be managed by that specific process.
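A sketch of this mapping, assuming the launcher passes --local_rank on each process's command line (its default behaviour):

```python
import argparse
import torch

parser = argparse.ArgumentParser()
# The launcher appends --local_rank=<n> to each process's command line by default.
parser.add_argument("--local_rank", type=int, default=0)
args = parser.parse_args()

# Bind this process to one GPU; subsequent .cuda() calls target this device.
torch.cuda.set_device(args.local_rank)
```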

This also ensures that each process / GPU, in conjunction with a distributed data sampler object, handles its own subset of the data. Each GPU on each node works in parallel on its own batch of the training data.
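A sketch of the data-loading side, using a synthetic TensorDataset as a stand-in for the real dataset and assuming the process group has been initialized:

```python
import torch
from torch.utils.data import DataLoader, TensorDataset
from torch.utils.data.distributed import DistributedSampler

# Synthetic stand-in for the real dataset used in the post.
train_dataset = TensorDataset(torch.randn(1024, 512), torch.randint(0, 10, (1024,)))

# The sampler uses this process's global rank and the world size to hand
# each process a disjoint shard of the data for every epoch.
train_sampler = DistributedSampler(train_dataset)
train_loader = DataLoader(train_dataset, batch_size=64,
                          sampler=train_sampler, num_workers=4, pin_memory=True)

for epoch in range(3):           # epoch count is a placeholder
    # Reshuffles the shard assignment differently at every epoch.
    train_sampler.set_epoch(epoch)
    for inputs, targets in train_loader:
        pass                     # forward / backward / step as usual
```

Calling set_epoch at the start of every epoch is what keeps shuffling consistent across processes while still varying it from epoch to epoch.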

As mentioned earlier, the DDP wrapper orchestrates training once the initialization and setup steps described above have been performed. From the end user's perspective, it is a drop-in replacement for an unwrapped model. Note that since the GPU device has been set appropriately for each process using the local rank, the model replica for each worker is automatically moved to the right device by the cuda() method call. Similarly, device_ids and output_device are correctly instantiated in the DDP wrapper's constructor call.
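A sketch of the wrapping step, again with a placeholder model; the current device is assumed to have been set to the local rank as described above:

```python
import torch
import torch.nn as nn
from torch.nn.parallel import DistributedDataParallel as DDP

local_rank = torch.cuda.current_device()  # set earlier via torch.cuda.set_device(args.local_rank)

# Placeholder model; .cuda() moves this replica onto the GPU owned by this process.
model = nn.Sequential(nn.Linear(512, 256), nn.ReLU(), nn.Linear(256, 10)).cuda()

# Drop-in replacement: forward and backward calls look exactly like the local model,
# but gradients are all-reduced across processes during backward().
model = DDP(model, device_ids=[local_rank], output_device=local_rank)

optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
```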

This is mostly all that needs to be done to leverage the native distributed training wrappers from PyTorch.

Parallel implementation strategies for scaling SGD over multiple GPU devices expose a tradeoff between training efficiency and model accuracy, especially when using a large number of devices. This is an active area of research, and a variety of approaches and tricks have been proposed to manage this tradeoff empirically. While the basic usage of PyTorch features discussed in this post can be sufficient for many practical use cases and a moderate amount of scaling, PyTorch also exposes lower-level communication primitives for implementing more advanced or novel strategies.

In addition, more recent developments, such as the RAPIDS data science platform, will likely give a further significant boost, as RAPIDS directly attacks the core issue of communication overhead when using a large cluster of instances, which can otherwise significantly diminish the gains achievable from distributed training. More on this in a follow-up post.

Additional Notes

For the purposes of demonstration in this post, I have used and adapted much of the code and dataset from an example in the PyTorch documentation.

All testing was done in a high-performance computing (HPC) environment where our compute nodes are equipped with 4 or 8 NVIDIA Tesla V100 GPUs, each with 16 GB of on-device memory. Each node also has access to a shared volume from which the dataset can be loaded. The tests were run against PyTorch v1.3.1 with CUDA toolkit v10.1 and cuDNN v7.6 on the Red Hat Enterprise Linux v7.4 platform, using Python v3.6.5. The scripts were successfully executed for DP on a single node (up to 8 GPUs) and for DDP on multiple nodes (up to 4 nodes, each with 4 GPUs).