Normalization#

Tip

To improve training stability, Llama normalizes the input of each transformer sub-layer (pre-norm) instead of normalizing the output (post-norm), using RMSNorm as the normalization function.

Figure: post-norm (normalize the sub-layer output) vs. pre-norm (normalize the sub-layer input) placement of the normalization layer (../_images/post-pre-norm.svg).
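
The difference between the two placements can be sketched as two residual wrappers (an illustrative sketch only; the names post_norm_block and pre_norm_block are not from any library):

def post_norm_block(x, sublayer, norm):
    # Original Transformer: apply the sub-layer first, then normalize the output.
    return norm(x + sublayer(x))

def pre_norm_block(x, sublayer, norm):
    # Llama-style pre-norm: normalize the input to the sub-layer; the residual path stays unnormalized.
    return x + sublayer(norm(x))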

Batch Normalization#

\[y = \frac{x - \mathbf{E}[x]}{\sqrt{\mathbf{Var}[x] + \epsilon}}*\gamma+\beta\]

\(\mathbf{E}[x]\) and \(\mathbf{Var}[x]\) are calculated over the \((N, L)\) dimensions, so they are vectors of size \(C\), where \(N\) denotes the batch size, \(L\) the sequence length, and \(C\) the number of features.

import torch
from torch import nn

# (N, L, C)
batch, seq_len, num_features = 2, 3, 4
x = torch.randn(batch, seq_len, num_features).permute(0, 2, 1)  # BatchNorm1d requires (N, C, L) input
bn = nn.BatchNorm1d(num_features)

t1 = bn(x)
t2 = (x - x.mean(dim=[0, 2], keepdim=True)) / x.std(dim=[0, 2], unbiased=False, keepdim=True)  # statistics over (N, L), one mean/std per feature
torch.allclose(t1, t2, atol=1e-3, rtol=0)
True

Layer Normalization#

\[y = \frac{x - \mathbf{E}[x]}{\sqrt{\mathbf{Var}[x] + \epsilon}}*\gamma+\beta\]

\(\mathbf{E}[x]\) and \(\mathbf{Var}[x]\) are calculated over the \(C\) dimension, giving one value per example and position, i.e. a tensor of shape \((N, L)\).

Tip

LayerNorm normalizes each example and each position independently.

Tip

Layer Normalization (LN) is preferred over Batch Normalization (BN) in Transformer models for several key reasons:

  • Batch Size Sensitivity: BN statistics become noisy with small batches, while LN does not depend on the batch size at all.

  • Sequence Length Variability: variable-length sequences and padding make batch statistics across positions unreliable; LN normalizes each position on its own.

  • Consistent behavior during both training and inference: LN keeps no running statistics, so it computes exactly the same thing in both modes (see the sketch after this list).

  • Transformer Architecture Compatibility: LN normalizes over the feature dimension of each token, matching the token-wise computation in attention and feed-forward layers.

  • Training Stability: per-token normalization keeps activations well-scaled regardless of batch composition, which helps the optimization of deep Transformers.
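
To make the training/inference point concrete, the following sketch compares a freshly created BatchNorm1d and LayerNorm on the same random inputs in train() and eval() mode (it reuses batch, seq_len, and num_features from above; the expected result is noted in the trailing comment):

bn = nn.BatchNorm1d(num_features)
ln = nn.LayerNorm(num_features)

x_bn = torch.randn(batch, num_features, seq_len)   # (N, C, L) layout for BatchNorm1d
x_ln = torch.randn(batch, seq_len, num_features)   # (N, L, C) layout for LayerNorm

bn.train()
y_bn_train = bn(x_bn)   # normalized with batch statistics (also updates running stats)
bn.eval()
y_bn_eval = bn(x_bn)    # normalized with running statistics

ln.train()
y_ln_train = ln(x_ln)
ln.eval()
y_ln_eval = ln(x_ln)    # LayerNorm has no running statistics

torch.allclose(y_bn_train, y_bn_eval), torch.allclose(y_ln_train, y_ln_eval)  # expected: (False, True)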

# (N, L, C)
batch, seq_len, num_features = 2, 3, 4
x = torch.randn(batch, seq_len, num_features)
ln = nn.LayerNorm(num_features)

t1 = ln(x)
t2 = (x - x.mean(dim=-1, keepdim=True)) / x.std(dim=-1, unbiased=False, keepdim=True)  # statistics over C, one mean/std per token
torch.allclose(t1, t2, atol=1e-3, rtol=0)
True

RMSNorm#

\[y = \frac{x}{\sqrt{\mathbf{MeanSquare}[x] + \epsilon}}*\gamma\]

\(\mathbf{MeanSquare}[x]\) is calculated over the \(C\) dimension and has shape \((N, L)\).

class RMSNorm(torch.nn.Module):
    """RMSNorm in Llama 3"""
    
    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))

    def _norm(self, x):
        # Divide by the root mean square of x over the feature dimension.
        return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)

    def forward(self, x):
        # x shape: (N, L, C)
        # weight shape: (C,)
        # Normalize in float32 for numerical stability, then cast back to the input dtype.
        output = self._norm(x.float()).type_as(x)
        return output * self.weight
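
As a quick sanity check in the same style as the cells above (a sketch reusing x and num_features from the LayerNorm example, so x has shape (N, L, C)), the module with its default all-ones weight should match the formula:

rms = RMSNorm(num_features)

t1 = rms(x)
t2 = x / torch.sqrt(x.pow(2).mean(dim=-1, keepdim=True) + rms.eps)
torch.allclose(t1, t2, atol=1e-3, rtol=0)  # expected: True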