Understanding BERT architecture

Source: Deep Learning on Medium

BERT is probably one of the most exciting developments in NLP in recent years. Just last month, even Google announced that it is using BERT in Search, describing it as the “biggest leap forward” for understanding searches in the past five years. That is a huge testament coming from Google, and about Search, no less. That’s just how significant BERT is.

There are some amazing resources for understanding BERT, Transformers and Attention networks in detail (Attention and Transformers are the building blocks of BERT); I have linked them in the footnote. This article is about understanding the architecture and parameters better, once you already understand BERT at a decent level.

Fortunately, the model is very easy to build in Python using Keras (and keras_bert). The following code builds the model, with randomly initialized weights, and prints a summary of all the layers.

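```python
from keras_bert import get_model, compile_model

# Build (not train) a BERT-Base-sized model; weights are randomly initialized
model = get_model(
    token_num=30000,        # T: number of distinct tokens
    head_num=12,            # A: attention heads per Transformer layer
    transformer_num=12,     # L: number of Transformer layers
    embed_dim=768,          # H: embedding length
    feed_forward_dim=3072,  # FFD = 4 * H
    seq_len=512,            # S: maximum input length
    pos_num=512,            # P: number of positions to encode
    dropout_rate=0.1,       # BERT-Base uses 0.1; dropout adds no parameters
)
compile_model(model)  # compile for the MLM and NSP pre-training objectives
model.summary()
```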

All the parameters I have used here, including the token count, are from the BERT-Base model (BERT has two variants: a smaller one called Base and a larger one called Large). The parameters are as follows (remember the notation, as we will be using it later):

  1. token_num (T) = 30k. This is the number of distinct tokens, derived from WordPiece tokenization (the released BERT-Base uncased vocabulary has 30,522 tokens; 30k is a round figure). WordPiece breaks single words down into subword units to improve coverage, e.g. “playing” is converted to (play, ##ing). So as long as the model knows “sleep” and “##ing”, it can infer something about the meaning of “sleeping” even if it is seeing that word for the first time
  2. head_num (A) = 12: the number of attention heads in each Transformer layer
  3. transformer_num (L) = 12: the number of stacked Transformer layers
  4. embed_dim (H) = 768: the embedding length
  5. feed_forward_dim (FFD) = 4 * H = 3072
  6. seq_len (S) = 512: the maximum number of tokens in one input sequence
  7. pos_num (P) = S = 512: the number of positions to be encoded
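To make the WordPiece idea concrete, here is a minimal sketch of greedy longest-match-first subword splitting, the strategy BERT's tokenizer applies to each word. It is deliberately simplified (no lowercasing, punctuation splitting or maximum word length), and both the function name and the toy vocabulary are hypothetical; a real BERT vocabulary has around 30k entries.

```python
def wordpiece_tokenize(word, vocab, unk_token="[UNK]"):
    """Greedy longest-match-first split of one word into vocabulary pieces.

    Pieces after the first carry the '##' continuation prefix, as in BERT.
    Simplified sketch: no lowercasing, punctuation handling or length limit.
    """
    pieces = []
    start = 0
    while start < len(word):
        end = len(word)
        match = None
        # Try the longest remaining substring first, shrinking until a match.
        while start < end:
            candidate = word[start:end]
            if start > 0:
                candidate = "##" + candidate
            if candidate in vocab:
                match = candidate
                break
            end -= 1
        if match is None:
            # Part of the word cannot be covered, so the whole word is unknown.
            return [unk_token]
        pieces.append(match)
        start = end
    return pieces


# Hypothetical toy vocabulary (a real BERT vocabulary has ~30k entries).
toy_vocab = {"play", "sleep", "un", "##play", "##ing", "##ed"}

print(wordpiece_tokenize("playing", toy_vocab))   # ['play', '##ing']
print(wordpiece_tokenize("sleeping", toy_vocab))  # ['sleep', '##ing']
print(wordpiece_tokenize("unplayed", toy_vocab))  # ['un', '##play', '##ed']
print(wordpiece_tokenize("xyz", toy_vocab))       # ['[UNK]']
```

Trying the longest candidate first means a word already in the vocabulary stays whole, and only unseen words get split into pieces.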

This should give you a very long summary of all the layers in BERT, which looks like this:

```text
Layer (type)                    Output Shape         Param #     Connected to
==================================================================================================
Input-Token (InputLayer)        (None, 512)          0
__________________________________________________________________________________________________
Input-Segment (InputLayer)      (None, 512)          0
__________________________________________________________________________________________________
Embedding-Token (TokenEmbedding [(None, 512, 768), ( 23040000    Input-Token[0][0]
__________________________________________________________________________________________________
Embedding-Segment (Embedding)   (None, 512, 768)     1536        Input-Segment[0][0]
__________________________________________________________________________________________________
Embedding-Token-Segment (Add)   (None, 512, 768)     0           Embedding-Token[0][0]
                                                                 Embedding-Segment[0][0]
__________________________________________________________________________________________________
Embedding-Position (PositionEmb (None, 512, 768)     393216      Embedding-Token-Segment[0][0]
__________________________________________________________________________________________________
Embedding-Dropout (Dropout)     (None, 512, 768)     0           Embedding-Position[0][0]
__________________________________________________________________________________________________
Embedding-Norm (LayerNormalizat (None, 512, 768)     1536        Embedding-Dropout[0][0]
__________________________________________________________________________________________________
Encoder-1-MultiHeadSelfAttentio (None, 512, 768)     2362368     Embedding-Norm[0][0]
__________________________________________________________________________________________________
Encoder-1-MultiHeadSelfAttentio (None, 512, 768)     0           Encoder-1-MultiHeadSelfAttention[
__________________________________________________________________________________________________
Encoder-1-MultiHeadSelfAttentio (None, 512, 768)     0           Embedding-Norm[0][0]
                                                                 Encoder-1-MultiHeadSelfAttention-
__________________________________________________________________________________________________
Encoder-1-MultiHeadSelfAttentio (None, 512, 768)     1536        Encoder-1-MultiHeadSelfAttention-
__________________________________________________________________________________________________
Encoder-1-FeedForward (FeedForw (None, 512, 768)     4722432     Encoder-1-MultiHeadSelfAttention-
__________________________________________________________________________________________________
Encoder-1-FeedForward-Dropout ( (None, 512, 768)     0           Encoder-1-FeedForward[0][0]
__________________________________________________________________________________________________
Encoder-1-FeedForward-Add (Add) (None, 512, 768)     0           Encoder-1-MultiHeadSelfAttention-
                                                                 Encoder-1-FeedForward-Dropout[0][
__________________________________________________________________________________________________
Encoder-1-FeedForward-Norm (Lay (None, 512, 768)     1536        Encoder-1-FeedForward-Add[0][0]
__________________________________________________________________________________________________
Encoder-2-MultiHeadSelfAttentio (None, 512, 768)     2362368     Encoder-1-FeedForward-Norm[0][0]
__________________________________________________________________________________________________
Encoder-2-MultiHeadSelfAttentio (None, 512, 768)     0           Encoder-2-MultiHeadSelfAttention[
__________________________________________________________________________________________________
Encoder-2-MultiHeadSelfAttentio (None, 512, 768)     0           Encoder-1-FeedForward-Norm[0][0]
                                                                 Encoder-2-MultiHeadSelfAttention-
__________________________________________________________________________________________________
Encoder-2-MultiHeadSelfAttentio (None, 512, 768)     1536        Encoder-2-MultiHeadSelfAttention-
__________________________________________________________________________________________________
Encoder-2-FeedForward (FeedForw (None, 512, 768)     4722432     Encoder-2-MultiHeadSelfAttention-
__________________________________________________________________________________________________
Encoder-2-FeedForward-Dropout ( (None, 512, 768)     0           Encoder-2-FeedForward[0][0]
__________________________________________________________________________________________________
Encoder-2-FeedForward-Add (Add) (None, 512, 768)     0           Encoder-2-MultiHeadSelfAttention-
                                                                 Encoder-2-FeedForward-Dropout[0][
__________________________________________________________________________________________________
Encoder-2-FeedForward-Norm (Lay (None, 512, 768)     1536        Encoder-2-FeedForward-Add[0][0]
__________________________________________________________________________________________________
Encoder-3-MultiHeadSelfAttentio (None, 512, 768)     2362368     Encoder-2-FeedForward-Norm[0][0]
__________________________________________________________________________________________________
Encoder-3-MultiHeadSelfAttentio (None, 512, 768)     0           Encoder-3-MultiHeadSelfAttention[
__________________________________________________________________________________________________
Encoder-3-MultiHeadSelfAttentio (None, 512, 768)     0           Encoder-2-FeedForward-Norm[0][0]
                                                                 Encoder-3-MultiHeadSelfAttention-
__________________________________________________________________________________________________
Encoder-3-MultiHeadSelfAttentio (None, 512, 768)     1536        Encoder-3-MultiHeadSelfAttention-
__________________________________________________________________________________________________
Encoder-3-FeedForward (FeedForw (None, 512, 768)     4722432     Encoder-3-MultiHeadSelfAttention-
__________________________________________________________________________________________________
Encoder-3-FeedForward-Dropout ( (None, 512, 768)     0           Encoder-3-FeedForward[0][0]
__________________________________________________________________________________________________
Encoder-3-FeedForward-Add (Add) (None, 512, 768)     0           Encoder-3-MultiHeadSelfAttention-
                                                                 Encoder-3-FeedForward-Dropout[0][
__________________________________________________________________________________________________
Encoder-3-FeedForward-Norm (Lay (None, 512, 768)     1536        Encoder-3-FeedForward-Add[0][0]
__________________________________________________________________________________________________
Encoder-4-MultiHeadSelfAttentio (None, 512, 768)     2362368     Encoder-3-FeedForward-Norm[0][0]
__________________________________________________________________________________________________
Encoder-4-MultiHeadSelfAttentio (None, 512, 768)     0           Encoder-4-MultiHeadSelfAttention[
__________________________________________________________________________________________________
Encoder-4-MultiHeadSelfAttentio (None, 512, 768)     0           Encoder-3-FeedForward-Norm[0][0]
                                                                 Encoder-4-MultiHeadSelfAttention-
__________________________________________________________________________________________________
Encoder-4-MultiHeadSelfAttentio (None, 512, 768)     1536        Encoder-4-MultiHeadSelfAttention-
__________________________________________________________________________________________________
Encoder-4-FeedForward (FeedForw (None, 512, 768)     4722432     Encoder-4-MultiHeadSelfAttention-
__________________________________________________________________________________________________
Encoder-4-FeedForward-Dropout ( (None, 512, 768)     0           Encoder-4-FeedForward[0][0]
__________________________________________________________________________________________________
Encoder-4-FeedForward-Add (Add) (None, 512, 768)     0           Encoder-4-MultiHeadSelfAttention-
                                                                 Encoder-4-FeedForward-Dropout[0][
__________________________________________________________________________________________________
Encoder-4-FeedForward-Norm (Lay (None, 512, 768)     1536        Encoder-4-FeedForward-Add[0][0]
__________________________________________________________________________________________________
Encoder-5-MultiHeadSelfAttentio (None, 512, 768)     2362368     Encoder-4-FeedForward-Norm[0][0]
__________________________________________________________________________________________________
Encoder-5-MultiHeadSelfAttentio (None, 512, 768)     0           Encoder-5-MultiHeadSelfAttention[
__________________________________________________________________________________________________
Encoder-5-MultiHeadSelfAttentio (None, 512, 768)     0           Encoder-4-FeedForward-Norm[0][0]
                                                                 Encoder-5-MultiHeadSelfAttention-
__________________________________________________________________________________________________
Encoder-5-MultiHeadSelfAttentio (None, 512, 768)     1536        Encoder-5-MultiHeadSelfAttention-
__________________________________________________________________________________________________
Encoder-5-FeedForward (FeedForw (None, 512, 768)     4722432     Encoder-5-MultiHeadSelfAttention-
__________________________________________________________________________________________________
Encoder-5-FeedForward-Dropout ( (None, 512, 768)     0           Encoder-5-FeedForward[0][0]
__________________________________________________________________________________________________
Encoder-5-FeedForward-Add (Add) (None, 512, 768)     0           Encoder-5-MultiHeadSelfAttention-
                                                                 Encoder-5-FeedForward-Dropout[0][
__________________________________________________________________________________________________
Encoder-5-FeedForward-Norm (Lay (None, 512, 768)     1536        Encoder-5-FeedForward-Add[0][0]
__________________________________________________________________________________________________
Encoder-6-MultiHeadSelfAttentio (None, 512, 768)     2362368     Encoder-5-FeedForward-Norm[0][0]
__________________________________________________________________________________________________
Encoder-6-MultiHeadSelfAttentio (None, 512, 768)     0           Encoder-6-MultiHeadSelfAttention[
__________________________________________________________________________________________________
Encoder-6-MultiHeadSelfAttentio (None, 512, 768)     0           Encoder-5-FeedForward-Norm[0][0]
                                                                 Encoder-6-MultiHeadSelfAttention-
__________________________________________________________________________________________________
Encoder-6-MultiHeadSelfAttentio (None, 512, 768)     1536        Encoder-6-MultiHeadSelfAttention-
__________________________________________________________________________________________________
Encoder-6-FeedForward (FeedForw (None, 512, 768)     4722432     Encoder-6-MultiHeadSelfAttention-
__________________________________________________________________________________________________
Encoder-6-FeedForward-Dropout ( (None, 512, 768)     0           Encoder-6-FeedForward[0][0]
__________________________________________________________________________________________________
Encoder-6-FeedForward-Add (Add) (None, 512, 768)     0           Encoder-6-MultiHeadSelfAttention-
                                                                 Encoder-6-FeedForward-Dropout[0][
__________________________________________________________________________________________________
Encoder-6-FeedForward-Norm (Lay (None, 512, 768)     1536        Encoder-6-FeedForward-Add[0][0]
__________________________________________________________________________________________________
Encoder-7-MultiHeadSelfAttentio (None, 512, 768)     2362368     Encoder-6-FeedForward-Norm[0][0]
__________________________________________________________________________________________________
Encoder-7-MultiHeadSelfAttentio (None, 512, 768)     0           Encoder-7-MultiHeadSelfAttention[
__________________________________________________________________________________________________
Encoder-7-MultiHeadSelfAttentio (None, 512, 768)     0           Encoder-6-FeedForward-Norm[0][0]
                                                                 Encoder-7-MultiHeadSelfAttention-
__________________________________________________________________________________________________
Encoder-7-MultiHeadSelfAttentio (None, 512, 768)     1536        Encoder-7-MultiHeadSelfAttention-
__________________________________________________________________________________________________
Encoder-7-FeedForward (FeedForw (None, 512, 768)     4722432     Encoder-7-MultiHeadSelfAttention-
__________________________________________________________________________________________________
Encoder-7-FeedForward-Dropout ( (None, 512, 768)     0           Encoder-7-FeedForward[0][0]
__________________________________________________________________________________________________
Encoder-7-FeedForward-Add (Add) (None, 512, 768)     0           Encoder-7-MultiHeadSelfAttention-
                                                                 Encoder-7-FeedForward-Dropout[0][
__________________________________________________________________________________________________
Encoder-7-FeedForward-Norm (Lay (None, 512, 768)     1536        Encoder-7-FeedForward-Add[0][0]
__________________________________________________________________________________________________
Encoder-8-MultiHeadSelfAttentio (None, 512, 768)     2362368     Encoder-7-FeedForward-Norm[0][0]
__________________________________________________________________________________________________
Encoder-8-MultiHeadSelfAttentio (None, 512, 768)     0           Encoder-8-MultiHeadSelfAttention[
__________________________________________________________________________________________________
Encoder-8-MultiHeadSelfAttentio (None, 512, 768)     0           Encoder-7-FeedForward-Norm[0][0]
                                                                 Encoder-8-MultiHeadSelfAttention-
__________________________________________________________________________________________________
Encoder-8-MultiHeadSelfAttentio (None, 512, 768)     1536        Encoder-8-MultiHeadSelfAttention-
__________________________________________________________________________________________________
Encoder-8-FeedForward (FeedForw (None, 512, 768)     4722432     Encoder-8-MultiHeadSelfAttention-
__________________________________________________________________________________________________
Encoder-8-FeedForward-Dropout ( (None, 512, 768)     0           Encoder-8-FeedForward[0][0]
__________________________________________________________________________________________________
Encoder-8-FeedForward-Add (Add) (None, 512, 768)     0           Encoder-8-MultiHeadSelfAttention-
                                                                 Encoder-8-FeedForward-Dropout[0][
__________________________________________________________________________________________________
Encoder-8-FeedForward-Norm (Lay (None, 512, 768)     1536        Encoder-8-FeedForward-Add[0][0]
__________________________________________________________________________________________________
Encoder-9-MultiHeadSelfAttentio (None, 512, 768)     2362368     Encoder-8-FeedForward-Norm[0][0]
__________________________________________________________________________________________________
Encoder-9-MultiHeadSelfAttentio (None, 512, 768)     0           Encoder-9-MultiHeadSelfAttention[
__________________________________________________________________________________________________
Encoder-9-MultiHeadSelfAttentio (None, 512, 768)     0           Encoder-8-FeedForward-Norm[0][0]
                                                                 Encoder-9-MultiHeadSelfAttention-
__________________________________________________________________________________________________
Encoder-9-MultiHeadSelfAttentio (None, 512, 768)     1536        Encoder-9-MultiHeadSelfAttention-
__________________________________________________________________________________________________
Encoder-9-FeedForward (FeedForw (None, 512, 768)     4722432     Encoder-9-MultiHeadSelfAttention-
__________________________________________________________________________________________________
Encoder-9-FeedForward-Dropout ( (None, 512, 768)     0           Encoder-9-FeedForward[0][0]
__________________________________________________________________________________________________
Encoder-9-FeedForward-Add (Add) (None, 512, 768)     0           Encoder-9-MultiHeadSelfAttention-
                                                                 Encoder-9-FeedForward-Dropout[0][
__________________________________________________________________________________________________
Encoder-9-FeedForward-Norm (Lay (None, 512, 768)     1536        Encoder-9-FeedForward-Add[0][0]
__________________________________________________________________________________________________
Encoder-10-MultiHeadSelfAttenti (None, 512, 768)     2362368     Encoder-9-FeedForward-Norm[0][0]
__________________________________________________________________________________________________
Encoder-10-MultiHeadSelfAttenti (None, 512, 768)     0           Encoder-10-MultiHeadSelfAttention
__________________________________________________________________________________________________
Encoder-10-MultiHeadSelfAttenti (None, 512, 768)     0           Encoder-9-FeedForward-Norm[0][0]
                                                                 Encoder-10-MultiHeadSelfAttention
__________________________________________________________________________________________________
Encoder-10-MultiHeadSelfAttenti (None, 512, 768)     1536        Encoder-10-MultiHeadSelfAttention
__________________________________________________________________________________________________
Encoder-10-FeedForward (FeedFor (None, 512, 768)     4722432     Encoder-10-MultiHeadSelfAttention
__________________________________________________________________________________________________
Encoder-10-FeedForward-Dropout  (None, 512, 768)     0           Encoder-10-FeedForward[0][0]
__________________________________________________________________________________________________
Encoder-10-FeedForward-Add (Add (None, 512, 768)     0           Encoder-10-MultiHeadSelfAttention
                                                                 Encoder-10-FeedForward-Dropout[0]
__________________________________________________________________________________________________
Encoder-10-FeedForward-Norm (La (None, 512, 768)     1536        Encoder-10-FeedForward-Add[0][0]
__________________________________________________________________________________________________
Encoder-11-MultiHeadSelfAttenti (None, 512, 768)     2362368     Encoder-10-FeedForward-Norm[0][0]
__________________________________________________________________________________________________
Encoder-11-MultiHeadSelfAttenti (None, 512, 768)     0           Encoder-11-MultiHeadSelfAttention
__________________________________________________________________________________________________
Encoder-11-MultiHeadSelfAttenti (None, 512, 768)     0           Encoder-10-FeedForward-Norm[0][0]
                                                                 Encoder-11-MultiHeadSelfAttention
__________________________________________________________________________________________________
Encoder-11-MultiHeadSelfAttenti (None, 512, 768)     1536        Encoder-11-MultiHeadSelfAttention
__________________________________________________________________________________________________
Encoder-11-FeedForward (FeedFor (None, 512, 768)     4722432     Encoder-11-MultiHeadSelfAttention
__________________________________________________________________________________________________
Encoder-11-FeedForward-Dropout  (None, 512, 768)     0           Encoder-11-FeedForward[0][0]
__________________________________________________________________________________________________
Encoder-11-FeedForward-Add (Add (None, 512, 768)     0           Encoder-11-MultiHeadSelfAttention
                                                                 Encoder-11-FeedForward-Dropout[0]
__________________________________________________________________________________________________
Encoder-11-FeedForward-Norm (La (None, 512, 768)     1536        Encoder-11-FeedForward-Add[0][0]
__________________________________________________________________________________________________
Encoder-12-MultiHeadSelfAttenti (None, 512, 768)     2362368     Encoder-11-FeedForward-Norm[0][0]
__________________________________________________________________________________________________
Encoder-12-MultiHeadSelfAttenti (None, 512, 768)     0           Encoder-12-MultiHeadSelfAttention
__________________________________________________________________________________________________
Encoder-12-MultiHeadSelfAttenti (None, 512, 768)     0           Encoder-11-FeedForward-Norm[0][0]
                                                                 Encoder-12-MultiHeadSelfAttention
__________________________________________________________________________________________________
Encoder-12-MultiHeadSelfAttenti (None, 512, 768)     1536        Encoder-12-MultiHeadSelfAttention
__________________________________________________________________________________________________
Encoder-12-FeedForward (FeedFor (None, 512, 768)     4722432     Encoder-12-MultiHeadSelfAttention
__________________________________________________________________________________________________
Encoder-12-FeedForward-Dropout  (None, 512, 768)     0           Encoder-12-FeedForward[0][0]
__________________________________________________________________________________________________
Encoder-12-FeedForward-Add (Add (None, 512, 768)     0           Encoder-12-MultiHeadSelfAttention
                                                                 Encoder-12-FeedForward-Dropout[0]
__________________________________________________________________________________________________
Encoder-12-FeedForward-Norm (La (None, 512, 768)     1536        Encoder-12-FeedForward-Add[0][0]
__________________________________________________________________________________________________
MLM-Dense (Dense)               (None, 512, 768)     590592      Encoder-12-FeedForward-Norm[0][0]
__________________________________________________________________________________________________
MLM-Norm (LayerNormalization)   (None, 512, 768)     1536        MLM-Dense[0][0]
__________________________________________________________________________________________________
Extract (Extract)               (None, 768)          0           Encoder-12-FeedForward-Norm[0][0]
__________________________________________________________________________________________________
MLM-Sim (EmbeddingSimilarity)   (None, 512, 30000)   30000       MLM-Norm[0][0]
                                                                 Embedding-Token[0][1]
__________________________________________________________________________________________________
Input-Masked (InputLayer)       (None, 512)          0
__________________________________________________________________________________________________
NSP-Dense (Dense)               (None, 768)          590592      Extract[0][0]
__________________________________________________________________________________________________
MLM (Masked)                    (None, 512, 30000)   0           MLM-Sim[0][0]
                                                                 Input-Masked[0][0]
__________________________________________________________________________________________________
NSP (Dense)                     (None, 2)            1538        NSP-Dense[0][0]
==================================================================================================
Total params: 109,705,010
Trainable params: 109,705,010
Non-trainable params: 0
```

That’s a daunting list. The total number of trainable parameters is about 110M (109,705,010), matching the roughly 110M that the BERT paper reports for BERT-Base. That’s reassuring: it suggests the model we built has the right configuration.

The same model can also be drawn as an image, which helps in understanding the computation graph:

```python
from keras.utils import plot_model

# Writes the computation graph to bert.png (needs the pydot and graphviz packages)
plot_model(model, to_file='bert.png')
```

Here is a brief overview of the steps in the model:

  1. Two inputs: one of word tokens and one of segment IDs (a third input, Input-Masked, is only used by the MLM output at the end)
  2. The token and segment embeddings are added together, then a third embedding, the position embedding, is added to the sum, followed by dropout and layer normalization
  3. Then come the Transformer encoder blocks, each built around a multi-head self-attention layer. Each block consists of 8 layers (all rows starting with Encoder-1 in the summary above), and there are 12 such blocks, so 96 of the layers in the summary belong to these blocks. If we understand one block well, we understand almost the whole architecture
  4. After these 12 blocks, there are two outputs: one for NSP (Next Sentence Prediction) and one for MLM (Masked Language Modeling)
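Steps 1 and 2 can be sketched in plain Python, with lists standing in for tensors. This is a minimal sketch with tiny, hypothetical sizes; it only shows that the three lookup tables are summed position by position (the dropout and layer normalization that follow are omitted), and how the table sizes give the embedding parameter counts for BERT-Base. The function names are illustrative, not keras_bert's.

```python
import random

# Tiny, hypothetical sizes so the sketch runs instantly
# (BERT-Base would use T=30000, H=768, P=512).
T_SMALL, H_SMALL, P_SMALL = 10, 4, 8
rng = random.Random(0)


def make_table(rows, dim):
    """Randomly initialized lookup table; in BERT all three tables are learned."""
    return [[rng.uniform(-0.1, 0.1) for _ in range(dim)] for _ in range(rows)]


token_table = make_table(T_SMALL, H_SMALL)     # T x H
segment_table = make_table(2, H_SMALL)         # 2 x H: segment A, segment B
position_table = make_table(P_SMALL, H_SMALL)  # P x H: one vector per position


def embed(token_ids, segment_ids):
    """Sum token, segment and position embeddings at every position.

    Dropout and layer normalization, which follow in BERT, are omitted.
    """
    assert len(token_ids) == len(segment_ids) <= P_SMALL
    return [
        [t + s + p for t, s, p in zip(token_table[tok], segment_table[seg], position_table[pos])]
        for pos, (tok, seg) in enumerate(zip(token_ids, segment_ids))
    ]


def embedding_param_counts(T, H, P, num_segments=2):
    """Parameters of the token, segment and position tables."""
    return T * H, num_segments * H, P * H


x = embed([1, 5, 7, 2], [0, 0, 1, 1])
print(len(x), len(x[0]))                        # 4 4
print(embedding_param_counts(30000, 768, 512))  # (23040000, 1536, 393216)
```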

Layer-wise accounting:

Going through the layers from top to bottom, we can see the following:

  1. Inputs — Token and segment do not have any trainable parameters, as expected.
  2. Token embedding parameters = 23040000 (T * H), because each of the 30k (T) tokens needs a representation of dimension 768 (H)
  3. Segment embedding parameters = 1536 (2 * H), because we need two vectors, each of length H, representing Segment A and Segment B respectively
  4. The token and segment embeddings are added to the position embedding. Parameters = 393216 (P * H), because it needs to generate P vectors, each of length H, one for each position from 1 to 512 (P). The position embeddings in BERT are learned, not the fixed sinusoids used in “Attention Is All You Need”. A dropout is then applied, followed by Layer Normalization
  5. Layer Normalization parameters = 1536 (2 * H). Layer normalization learns two parameters per embedding dimension: a scale (gamma) and a shift (beta). The mean and standard deviation it normalizes with are computed from each input vector, not learned. Hence 2 * H
  6. Encoder, MultiHeadSelfAttention (MultiHeadAttention): parameters = 2362368

This needs a bit of explanation. Here is what happens inside this step[ref]:

There are 12 heads in total, and the input has dimension 768, so each head produces embeddings of length 768/12 = 64. Three embeddings are generated: Q, K and V. That is 768 * 64 * 3 weights per head, or 12 * 768 * 64 * 3 for all heads. Adding a bias for each output unit of Q, K and V gives another 768 * 3 parameters, for a total of 12 * 768 * 64 * 3 + 768 * 3. The outputs of all heads are then concatenated and an additional weight (W^O in the notation of “Attention Is All You Need”) is applied. That is a fully connected dense layer whose output dimension equals its input dimension, so its parameters (with bias) = 768 * 768 + 768. The total number of parameters in this step is therefore A * H * (H/A) * 3 + H * 3 + H * H + H = 12 * 768 * 64 * 3 + 768 * 3 + 768 * 768 + 768 = 2362368. Since A * (H/A) = H, this simplifies to 4 * H * H + 4 * H, which does not depend on the number of heads.
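The same arithmetic as a small Python function, assuming (as in the count above) one bias vector for each of Q, K and V and one for the output projection. The function name is illustrative, not part of keras_bert.

```python
def multi_head_attention_params(H, A):
    """Trainable parameters of one multi-head self-attention layer."""
    assert H % A == 0, "H must be divisible by the number of heads A"
    head_dim = H // A                   # 768 / 12 = 64
    qkv_weights = A * H * head_dim * 3  # Q, K, V projections for all heads
    qkv_biases = H * 3                  # one bias per output unit of Q, K, V
    output_proj = H * H + H             # W^O plus its bias
    return qkv_weights + qkv_biases + output_proj


n = multi_head_attention_params(768, 12)
print(n)                             # 2362368
print(n == 4 * 768 * 768 + 4 * 768)  # True
```

Changing A while keeping H fixed leaves the count unchanged: more heads simply means narrower heads.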

7. Dropout and a residual Add, neither of which has any parameters, followed by another Layer Normalization, following the same logic as #5
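Since the same Layer Normalization appears after every sublayer, here is a minimal sketch of what it computes for a single H-length vector. The mean and variance come from the input itself; only the scale (gamma) and shift (beta) are learned, which is where the 2 * H parameters of #5 come from. The function name and the epsilon value are illustrative choices.

```python
import math


def layer_norm(x, gamma, beta, eps=1e-12):
    """Normalize one H-length vector, then apply a learned scale and shift.

    mean and variance are computed from x itself; only gamma and beta
    (2 * H numbers) are trainable.
    """
    H = len(x)
    mean = sum(x) / H
    var = sum((v - mean) ** 2 for v in x) / H
    return [g * (v - mean) / math.sqrt(var + eps) + b
            for v, g, b in zip(x, gamma, beta)]


H = 4
gamma = [1.0] * H  # common initialization: scale 1
beta = [0.0] * H   # common initialization: shift 0
y = layer_norm([1.0, 2.0, 3.0, 4.0], gamma, beta)
print([round(v, 4) for v in y])              # [-1.3416, -0.4472, 0.4472, 1.3416]
print("trainable:", len(gamma) + len(beta))  # trainable: 8
```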

8. FeedForward (FeedForward). This is a position-wise feed-forward network with two fully connected layers. It maps the input dimension (H) up to FFD and back down to H, with a GELU activation in between (BERT uses GELU rather than the ReLU of the original Transformer). So the total parameters with biases = (H * FFD + FFD) + (FFD * H + H) = (768 * 3072 + 3072) + (3072 * 768 + 768) = 4722432. This is followed by another Dropout layer and a residual Add, neither of which has parameters
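A minimal sketch of the feed-forward sublayer for one position, with nested Python lists standing in for weight matrices, together with its parameter count. It uses the exact, erf-based form of GELU; some implementations (including, as far as I know, Google's original BERT code) use a close tanh-based approximation instead. All names and the tiny weights are hypothetical.

```python
import math


def gelu(x):
    """Exact GELU: x * Phi(x), with Phi the standard normal CDF."""
    return 0.5 * x * (1.0 + math.erf(x / math.sqrt(2.0)))


def feed_forward(x, W1, b1, W2, b2):
    """Position-wise feed-forward for one H-length vector: H -> FFD -> H.

    W1 is H x FFD, b1 has FFD entries, W2 is FFD x H, b2 has H entries.
    """
    hidden = [gelu(sum(xi * W1[i][j] for i, xi in enumerate(x)) + b1[j])
              for j in range(len(b1))]
    return [sum(hj * W2[j][k] for j, hj in enumerate(hidden)) + b2[k]
            for k in range(len(b2))]


def feed_forward_params(H, FFD):
    """Weights and biases of the two dense layers."""
    return (H * FFD + FFD) + (FFD * H + H)


# Tiny hypothetical weights with H=2, FFD=8 (= 4 * H).
H, FFD = 2, 8
W1 = [[0.1 * (i + j) for j in range(FFD)] for i in range(H)]
b1 = [0.0] * FFD
W2 = [[0.05] * H for _ in range(FFD)]
b2 = [0.0] * H
print(feed_forward([1.0, -0.5], W1, b1, W2, b2))
print(feed_forward_params(768, 3072))  # 4722432
```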

9. Another Layer Normalization, following the same logic as #5

Steps 6–9 cover a single Transformer layer, and the same set repeats 12 (L) times.
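The order of operations inside one block, and the stacking of L blocks, can be sketched as below. This is a structural sketch only: the sublayers are passed in as plain Python functions (hypothetical stand-ins, not keras_bert objects), and it assumes the post-LN arrangement the summary shows, where normalization comes after each residual Add.

```python
def add(a, b):
    """Residual connection: element-wise sum of two lists of vectors."""
    return [[u + v for u, v in zip(ra, rb)] for ra, rb in zip(a, b)]


def no_dropout(x):
    """Dropout is the identity at inference time."""
    return x


def encoder_block(x, attention, feed_forward, norm1, norm2, dropout=no_dropout):
    """One post-LN encoder block, in the order the summary lists its 8 layers:

    MultiHeadSelfAttention -> Dropout -> Add -> Norm,
    FeedForward -> Dropout -> Add -> Norm.
    x is a list of per-position vectors; every sublayer maps such a list to one.
    """
    x = norm1(add(x, dropout(attention(x))))
    x = norm2(add(x, dropout(feed_forward(x))))
    return x


def encoder(x, layers):
    """Apply L blocks in sequence; each dict holds one layer's own sublayers."""
    for sublayers in layers:
        x = encoder_block(x, **sublayers)
    return x


# Hypothetical stand-ins: zero-output sublayers and identity norms,
# so every residual connection passes its input through unchanged.
def zero_sublayer(x):
    return [[0.0] * len(v) for v in x]


def identity(x):
    return x


layers = [dict(attention=zero_sublayer, feed_forward=zero_sublayer,
               norm1=identity, norm2=identity)] * 12
print(encoder([[1.0, 2.0], [3.0, 4.0]], layers))  # [[1.0, 2.0], [3.0, 4.0]]
```

The residual Add is what lets information from the embeddings flow through all 12 blocks even when a sublayer contributes little.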

These are followed by two output heads: one for MLM (Masked Language Modeling) and one for NSP (Next Sentence Prediction). Let’s look at their parameters:

10. MLM-Dense: This takes an embedding as input and tries to predict the masked word’s embedding. So parameters (with bias) = H * H + H = 768 * 768 + 768 = 590592

11. MLM-Norm: Normalization layer, with parameter count following the same logic as #5

12. MLM-Sim (EmbeddingSimilarity): This computes the similarity (a dot product) between the output of MLM-Norm and the embedding of every token in the vocabulary, reusing the token embedding matrix (hence the second connection to Embedding-Token[0][1]), which is why its output shape is (None, 512, 30000). The layer also learns a per-token bias, so it has T (= 30k) parameters. (Intuitively, I understand this bias as similar to token-level priors, but please correct me if I am wrong.)
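A minimal sketch of an MLM-Sim-style computation for a single position, assuming the layer scores the hidden vector against every row of the token embedding table (the second connection to Embedding-Token[0][1] and the 30000-wide output shape point that way), adds one learned bias per token, and applies a softmax. This reflects my reading of the layer, not keras_bert's source; the function name, sizes and values are hypothetical.

```python
import math


def embedding_similarity(h, token_table, bias):
    """Score one hidden vector against every token embedding.

    logits[t] = dot(h, token_table[t]) + bias[t], followed by a softmax,
    so the output is a probability distribution over the T tokens.
    """
    logits = [sum(hi * ei for hi, ei in zip(h, row)) + b
              for row, b in zip(token_table, bias)]
    m = max(logits)  # subtract the max for numerical stability
    exps = [math.exp(z - m) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]


# Hypothetical vocabulary of T=3 tokens with H=2. The embedding table is the
# one shared with the input side; only `bias` (T numbers) is new.
token_table = [[1.0, 0.0], [0.0, 1.0], [-1.0, 0.0]]
bias = [0.0, 0.0, 0.0]
probs = embedding_similarity([2.0, 0.5], token_table, bias)
print(max(range(len(probs)), key=probs.__getitem__))  # 0
```

Tying the output scores to the input embeddings means the MLM head needs no separate H * T output matrix, which is why this layer contributes only T parameters.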

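13. NSP-Dense (Dense): This takes the output of Extract, which picks out the embedding at the first position (the [CLS] token), and converts that H-length embedding into another H-length embedding. Parameters = H * H + H = 590592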

14. NSP (Dense): The H-length output of the previous layer is then mapped to two outputs, one for IsNext and one for NotNext. Hence, parameters = 2 * H + 2 = 1538
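Putting the layer-wise accounting together, the sketch below recomputes the total from the hyperparameters T, H, A, L, FFD and P. The function is hypothetical and simply sums the per-layer counts listed above (inputs, dropout, Add, Extract and Masked layers contribute nothing); it reproduces the 109,705,010 reported by `model.summary()`.

```python
def bert_param_count(T, H, A, L, FFD, P, num_segments=2):
    """Total trainable parameters, summing the layer-wise accounting."""
    assert H % A == 0
    layer_norm = 2 * H                                  # gamma and beta
    embeddings = T * H + num_segments * H + P * H + layer_norm
    attention = A * H * (H // A) * 3 + 3 * H + H * H + H
    feed_forward = (H * FFD + FFD) + (FFD * H + H)
    encoder = L * (attention + layer_norm + feed_forward + layer_norm)
    mlm_head = (H * H + H) + layer_norm + T             # MLM-Dense, MLM-Norm, MLM-Sim bias
    nsp_head = (H * H + H) + (2 * H + 2)                # NSP-Dense, NSP
    return embeddings + encoder + mlm_head + nsp_head


total = bert_param_count(T=30000, H=768, A=12, L=12, FFD=3072, P=512)
print(f"{total:,}")  # 109,705,010
```

Roughly 77% of the total sits in the 12 encoder layers (about 85.1M) and most of the rest in the token embedding table (23.0M).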

This concludes the overview of the whole network. Going through it answered the following questions for me:

  1. Where the Segment and Position embeddings come from, and the fact that both are trainable
  2. Detailed understanding of what’s happening inside the Transformer cell
  3. Role of Layer Normalization
  4. How both the MLM and NSP tasks are handled at once

Footnote:

PyTorch walkthrough implementation of Attention

Three types of Embeddings in BERT

Colab notebook to understand attention in BERT: this also has a cool interactive visualization that explains how the Q, K, V embeddings interact with each other to produce the attention distribution

Explaining Transformer, Self-Attention, and Cross-Attention