A Guide to Optimizer Implementation for BERT at Scale

Source: Deep Learning on Medium

Training with larger batches is a straightforward way to scale training of deep neural networks to larger numbers of accelerators and reduce training time. However, as the batch size increases, numerical instability can appear in the training process. The purpose of this article is to provide an overview of one class of solutions to this problem: layer-wise adaptive optimizers, such as LARS and LAMB. We will also discuss how NVIDIA’s implementation of LAMB, or NVLAMB, differs from the originally published algorithm.

Layer-wise Adaptive Approaches

Typically, DNN training uses mini-batch Stochastic Gradient Descent (SGD), which updates all model weights using a tunable parameter called the learning rate λ in the following way: wᵢ₊₁ = wᵢ − λᵢ∇L(wᵢ), where wᵢ and ∇L(wᵢ) are the weight and the stochastic gradient of the loss L with respect to the weight, and λᵢ is the learning rate, at the current training step i. When λ is large, the update ||λ ∗ ∇L(wᵢ)|| can become larger than ||wᵢ||, and this can cause the training process to diverge. This is particularly problematic with larger mini-batch sizes because they require higher learning rates to compensate for fewer training updates, but training frequently diverges when the learning rate is too high. This limits the maximum mini-batch size we can scale up to. You et al. observed that some layers may cause instability before others; the “weakest” of these layers limits the overall learning rate that can be applied to the model, which increases training time.
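To make the divergence condition concrete, here is a minimal NumPy sketch of the SGD update together with the ratio ||λ∇L(wᵢ)|| / ||wᵢ||. The function names and the toy values are illustrative assumptions, not part of any library or of the original experiments.

```python
import numpy as np

def sgd_step(w, grad, lr):
    """Plain SGD update: w_next = w - lr * grad."""
    return w - lr * grad

def update_to_weight_ratio(w, grad, lr):
    """Return ||lr * grad|| / ||w||; a value above 1 means the step is larger than the weights."""
    return np.linalg.norm(lr * grad) / np.linalg.norm(w)

# Toy layer (hypothetical values): small weights, comparatively large gradient.
w = np.array([0.1, -0.2, 0.05])
grad = np.array([3.0, -4.0, 0.0])

for lr in (0.01, 0.1):
    print(f"lr={lr}: update/weight ratio = {update_to_weight_ratio(w, grad, lr):.2f}")
```

With these toy values the ratio is about 0.22 at a learning rate of 0.01 and about 2.18 at 0.1, so a tenfold increase in learning rate moves this layer from a modest step to one that is more than twice the size of the weights themselves.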

The Layer-wise Adaptive Rate Scaling (LARS) optimizer by You et al. is an extension of SGD with momentum that determines a learning rate per layer. It normalizes each layer's gradient by its L2 norm and scales the normalized gradient by the L2 norm of that layer's weights, which decouples the magnitude of the update from the magnitude of the gradient. The ratio of the weight norm to the gradient norm is called the trust ratio for each layer. This allows the more stable layers (with larger ||wᵢ||) to use a more aggressive learning rate and often converge more quickly without a loss in accuracy.
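The trust ratio described above can be sketched in a few lines of NumPy. This is a simplified illustration that assumes no momentum and no weight decay, both of which the full LARS optimizer includes; the function names and the zero-norm fallback of 1.0 are assumptions made for this example.

```python
import numpy as np

def trust_ratio(w, grad):
    """Per-layer trust ratio ||w|| / ||grad||, falling back to 1.0 if either norm is zero."""
    w_norm = np.linalg.norm(w)
    g_norm = np.linalg.norm(grad)
    if w_norm == 0.0 or g_norm == 0.0:
        return 1.0
    return w_norm / g_norm

def lars_like_step(layers, grads, lr):
    """Apply one update per layer, scaling the global learning rate by each layer's trust ratio."""
    return [w - lr * trust_ratio(w, g) * g for w, g in zip(layers, grads)]

# Hypothetical two-layer example: the step size tracks ||w||, not ||grad||.
layers = [np.array([3.0, 4.0]), np.array([0.3, 0.4])]
grads = [np.array([100.0, 0.0]), np.array([0.0, 0.001])]
new_layers = lars_like_step(layers, grads, lr=0.1)
```

In this sketch every layer moves by exactly lr times its own weight norm, regardless of how large or small its gradient is, which is the decoupling the paragraph above refers to.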

Adam is a member of an algorithm class inspired by AdaGrad and introduces running averages of the first two gradient moments, the mean and the uncentered variance. Loshchilov and Hutter proposed a variation, AdamW, which decouples weight decay from the gradient-based update. A great overview of optimizers can be found here.
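As a rough illustration of these ideas, the following NumPy sketch performs one AdamW-style step: running averages of the gradient and the squared gradient, bias correction, and a weight decay term that is added to the update directly rather than to the gradient. The function name and hyperparameter defaults are assumptions chosen for the example, not values taken from this article.

```python
import numpy as np

def adamw_step(w, grad, m, v, t, lr=1e-3, beta1=0.9, beta2=0.999,
               eps=1e-8, weight_decay=0.01):
    """One AdamW-style step (sketch). t is the 1-based step count."""
    # Running averages of the first and second (uncentered) gradient moments.
    m = beta1 * m + (1.0 - beta1) * grad
    v = beta2 * v + (1.0 - beta2) * grad ** 2
    # Bias correction for the zero initialization of m and v.
    m_hat = m / (1.0 - beta1 ** t)
    v_hat = v / (1.0 - beta2 ** t)
    # Decoupled weight decay: applied to the weights, not folded into grad.
    w = w - lr * (m_hat / (np.sqrt(v_hat) + eps) + weight_decay * w)
    return w, m, v

w = np.array([1.0])
m = np.zeros_like(w)
v = np.zeros_like(w)
w, m, v = adamw_step(w, np.array([0.5]), m, v, t=1)
```

Because the decay term bypasses the division by the square root of the second moment, every weight is decayed at the same relative rate, which is the practical difference from adding an L2 penalty to the loss.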

The Layer-wise Adaptive Moments Based (LAMB) optimizer can be seen as the application of LARS to the AdamW optimizer, which adds a per-weight normalization with respect to the square root of the second moment to compute the update. The next sections discuss the details in NVIDIA’s open-source implementation of LAMB and the adjustments involved to ensure SoTA pretraining convergence for BERT.


We started developing our implementation from the preprint of the LAMB paper (v1) on arXiv, before the upcoming publication, and our findings led us to a different final algorithm. The goal of the following sections is to help shed some light on the choices made in our implementation of LAMB (NVLAMB) w.r.t. the addition of gradient pre-normalization and bias correction as described in Figure 1.

Note: In Step 6 of NVLAMB, the dense weights and bias weights of a particular transformation are treated as separate layers.
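For readers without the figure at hand, the following NumPy sketch reconstructs one optimizer step from the steps this article refers to: global gradient pre-normalization (Step 2), moment updates (Step 3), bias correction (Step 4), and a per-layer trust ratio (Step 6). Steps 5 and 7 are assumptions here (an AdamW-style update with decoupled weight decay, followed by the scaled weight update), as are the function name `nvlamb_like_step`, the zero-norm fallback, and all hyperparameter defaults. It is an approximation for illustration, not NVIDIA's implementation.

```python
import numpy as np

def nvlamb_like_step(params, grads, ms, vs, t, lr=1e-3, beta1=0.9,
                     beta2=0.999, eps=1e-6, weight_decay=0.01):
    """One LAMB-style step over a list of layers (sketch). t is the 1-based step count."""
    # Step 2: normalize all gradients by the L2 norm of the whole model's gradient.
    global_norm = np.sqrt(sum(float(np.sum(g * g)) for g in grads))
    if global_norm > 0.0:
        grads = [g / global_norm for g in grads]

    new_params, new_ms, new_vs = [], [], []
    for w, g, m, v in zip(params, grads, ms, vs):
        # Step 3: moving averages of the gradient and the squared gradient.
        m = beta1 * m + (1.0 - beta1) * g
        v = beta2 * v + (1.0 - beta2) * g * g
        # Step 4: bias correction.
        m_hat = m / (1.0 - beta1 ** t)
        v_hat = v / (1.0 - beta2 ** t)
        # Step 5 (assumed): AdamW-style update with decoupled weight decay.
        update = m_hat / (np.sqrt(v_hat) + eps) + weight_decay * w
        # Step 6: per-layer trust ratio, with no extra scaling of the weight norm.
        w_norm = np.linalg.norm(w)
        u_norm = np.linalg.norm(update)
        ratio = w_norm / u_norm if w_norm > 0.0 and u_norm > 0.0 else 1.0
        # Step 7 (assumed): apply the scaled update.
        new_params.append(w - lr * ratio * update)
        new_ms.append(m)
        new_vs.append(v)
    return new_params, new_ms, new_vs

# Dense weights and bias are passed as separate list entries, i.e. separate layers.
params = [np.array([[1.0, 2.0], [3.0, 4.0]]), np.array([0.5, -0.5])]
grads = [np.array([[0.1, -0.2], [0.3, 0.4]]), np.array([0.05, 0.02])]
ms = [np.zeros_like(p) for p in params]
vs = [np.zeros_like(p) for p in params]
new_params, ms_next, vs_next = nvlamb_like_step(params, grads, ms, vs, t=1, lr=0.01)
```

Passing the dense matrix and the bias vector as separate entries means each gets its own trust ratio, matching the note above.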

The Importance of Gradient Pre-normalization

We perform a gradient pre-normalization step such that the gradients of the entire model combined (all individual layers / weight matrices) have unit L2 norm, as described in Step 2 in the NVLAMB algorithm above. Pre-normalization is important because it makes updates depend only on the gradient direction and not on its magnitude. This is particularly beneficial in large batch settings, where the direction of the gradient is largely preserved: the larger the batch size, the closer the stochastic gradient is to the true (full-batch) gradient, and the less it suffers from noise. While the LAMB publication does not include this step, our experiments found that without pre-normalization, BERT pretraining does not converge as expected.
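A minimal NumPy sketch of this pre-normalization step is shown below. The helper names are hypothetical, and returning the gradients unchanged when the global norm is zero is an assumption made to keep the example safe to run.

```python
import numpy as np

def global_l2_norm(grads):
    """L2 norm of all gradients in the model, treated as one long vector."""
    return float(np.sqrt(sum(float(np.sum(g * g)) for g in grads)))

def prenormalize(grads):
    """Scale every layer's gradient by the same factor so the global L2 norm is 1."""
    norm = global_l2_norm(grads)
    if norm == 0.0:
        return [g.copy() for g in grads]
    return [g / norm for g in grads]

# Hypothetical gradients for two layers of different shapes.
grads = [np.array([3.0]), np.array([[4.0, 0.0]])]
unit_grads = prenormalize(grads)
```

Since all layers share a single scaling factor, the relative magnitudes between layers are preserved; only the overall scale is removed, so multiplying every gradient by a constant leaves the result unchanged.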

* Google’s original BERT GitHub repository, which uses the unmodified Adam optimizer, also performs gradient pre-normalization.

Figure 2. BERT Phase1 pretraining behavior with and without gradient pre-normalization

The LAMB publication additionally applies a scaling factor on the weight norm while computing the weight update. However, the publication doesn’t provide exact guidance on what scaling factor works best. In step 6 of our NVLAMB implementation, we do not scale the norm of the weight, but are still able to achieve state of the art accuracy on downstream tasks as shown in Table 1 below.

Bias Correction and Learning Rate Decay

We also note the authors use bias correction in the algorithm as well as include learning rate warmup for BERT pretraining. However, a later section in the appendix claims that bias correction can be dropped since its behavior is similar to warmup.

Figure 3. BERT pretraining behavior with different learning rate decays on both phases

We experimented further and found that without the correction term, BERT pretraining diverges early in the training process. This is because initializing the moving averages m and v to zero biases them toward zero: the first gradients are scaled by factors of (1 − β₁) and (1 − β₂), as shown in Step 3 in the algorithm above. To correct for this factor, the bias correction seen in Step 4 of the NVLAMB algorithm above is necessary. For a more rigorous derivation, please refer to Sections 3 and 6.4 in the Adam paper. As shown in Figure 3, the degree of polynomial learning rate decay makes no observable difference. Fine-tuning on downstream SQuAD 1.1 yields F1 scores in the same 91–91.5% range in both settings.
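The effect of zero initialization is easy to see on a constant input. The sketch below, which uses a hypothetical helper name and toy values, tracks an exponential moving average with and without dividing by (1 − βᵗ). For a constant input, the raw average after t steps equals (1 − βᵗ) times the input, so the corrected value recovers the input exactly.

```python
def ema_with_bias_correction(values, beta):
    """Return (raw, corrected) exponential moving averages, starting from zero."""
    avg = 0.0
    raw, corrected = [], []
    for t, x in enumerate(values, start=1):
        avg = beta * avg + (1.0 - beta) * x
        raw.append(avg)
        # Divide out the (1 - beta**t) factor introduced by the zero start.
        corrected.append(avg / (1.0 - beta ** t))
    return raw, corrected

# A constant "gradient" of 1.0 with beta = 0.9, as for the first moment m.
raw, corrected = ema_with_bias_correction([1.0] * 5, beta=0.9)
for t, (r, c) in enumerate(zip(raw, corrected), start=1):
    print(f"t={t}: raw={r:.4f} corrected={c:.4f}")
```

Here the raw average starts at roughly 0.1 and only slowly approaches 1.0, while the corrected average is 1.0 from the first step. With β₂ close to 1 the uncorrected second moment is underestimated far more strongly, which inflates early updates once the first moment is divided by its square root.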

Table 1. Fine-tuning results on SQuAD v1.1 and GLUE benchmarks. * best scores obtained using published checkpoint
BERT paper here, LAMBv4 paper here.

Note: The LAMB results were obtained using twice the number of training samples as Adam-W, to achieve similar accuracies on downstream fine-tuning tasks as seen in Table 2. The original LAMB publication doesn’t explain how this was determined. We did not attempt to understand whether a different training recipe could use fewer total training samples. This is a potential area for further investigation.


We showcased the general idea behind layer-wise adaptive optimizers and how they build on top of existing optimizers that use a common global learning rate across all layers. In particular, we covered the various published versions of LAMB as well as our implementation, NVLAMB. Layer-wise adaptive optimizer approaches enable training with larger mini-batches with no compromise in accuracy, as shown in Table 1. This results in dramatically reduced training times on modern parallel hardware, down from days to almost an hour, as described in our earlier blog. We also provide the implementation in our BERT repositories based on PyTorch and TensorFlow.

Be sure to check out the extended version of this blog, here.


Sharath Sreenivas, Deep Learning Engineer, NVIDIA

Swetha Mandava, Deep Learning Engineer, NVIDIA

Boris Ginsburg, Principal Deep Learning Engineer, NVIDIA

Chris Forster, Senior CUDA Algorithms Software Engineer, NVIDIA