Normalization#
Tip
To improve training stability, Llama normalizes the input of each transformer sub-layer instead of normalizing the output, using RMSNorm as the normalizing function.
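The sketch below illustrates this pre-normalization pattern. It is an illustration under assumptions, not Llama's actual code: the generic sublayer stands in for attention/FFN blocks, and nn.LayerNorm stands in for RMSNorm (defined later on this page).

import torch
from torch import nn

class PreNormBlock(nn.Module):
    """Pre-normalization: normalize the sub-layer *input*, then add the residual."""
    def __init__(self, dim: int, sublayer: nn.Module):
        super().__init__()
        self.norm = nn.LayerNorm(dim)  # Llama uses RMSNorm here; LayerNorm is a stand-in
        self.sublayer = sublayer       # e.g. self-attention or feed-forward
    def forward(self, x):
        return x + self.sublayer(self.norm(x))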
Batch Normalization#
\(\mathbf{E}[x]\) and \(\mathbf{Var}[x]\) are calculated over \((N, L)\) and are vectors of size \(C\), where \(N\) denotes the batch size, \(L\) the sequence length, and \(C\) the number of features.
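In other words, each feature is normalized as \(y = \frac{x - \mathbf{E}[x]}{\sqrt{\mathbf{Var}[x] + \epsilon}} \cdot \gamma + \beta\), where \(\gamma\) and \(\beta\) are learnable affine parameters; PyTorch initializes them to ones and zeros, which is why the check below matches without applying them.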
import torch
from torch import nn
# (N, L, C)
batch, seq_len, num_features = 2, 3, 4
x = torch.randn(batch, seq_len, num_features).permute(0, 2, 1) # BatchNorm1d requires (N, C, L) input
bn = nn.BatchNorm1d(num_features)
t1 = bn(x)
t2 = (x - x.mean(axis=[0, 2], keepdim=True)) / x.std(axis=[0, 2], unbiased=False, keepdim=True)
torch.allclose(t1, t2, atol=1e-3, rtol=0)
True
Layer Normalization#
\(\mathbf{E}[x]\) and \(\mathbf{Var}[x]\) are calculated over \(C\) and have shape \((N, L)\): one value per example and position.
Tip
LayerNorm normalizes each example and each position independently.
Tip
Layer Normalization (LN) is preferred over Batch Normalization (BN) in Transformer models for several key reasons:

- Batch size sensitivity: BN statistics are computed over the batch, so small or variable batch sizes give noisy estimates; LN statistics come from a single example (see the sketch below).
- Sequence length variability: with padded, variable-length sequences, BN mixes statistics across real tokens and padding at each position, while LN only looks within one token's features.
- Consistent behavior during both training and inference: BN switches from batch statistics to running averages at inference time, whereas LN computes the same statistics in both modes.
- Transformer architecture compatibility: normalizing over the feature dimension fits naturally into per-token residual blocks.
- Training stability: avoiding the train/inference mismatch and batch-statistics noise helps large Transformer models train stably.
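To make the first point concrete, here is a small sketch (random data, BN in training mode) showing that LayerNorm's output for one example does not depend on what else is in the batch, whereas BatchNorm's does:

torch.manual_seed(0)
a = torch.randn(1, 3, 4)  # one example, (N, L, C)
b = torch.randn(1, 3, 4)  # an unrelated second example
ln_demo = nn.LayerNorm(4)
bn_demo = nn.BatchNorm1d(4)  # training mode: statistics come from the current batch
# LayerNorm: same output for `a` whether or not `b` shares the batch
print(torch.allclose(ln_demo(a), ln_demo(torch.cat([a, b]))[:1], atol=1e-6))  # True
# BatchNorm: output for `a` changes with the batch composition
print(torch.allclose(bn_demo(a.permute(0, 2, 1)),
                     bn_demo(torch.cat([a, b]).permute(0, 2, 1))[:1], atol=1e-6))  # False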
# (N, L, C)
batch, seq_len, num_features = 2, 3, 4
x = torch.randn(batch, seq_len, num_features)
ln = nn.LayerNorm(num_features)
t1 = ln(x)
t2 = (x - x.mean(axis=-1, keepdim=True)) / x.std(axis=-1, unbiased=False, keepdim=True)
torch.allclose(t1, t2, atol=1e-3, rtol=0)
True
RMSNorm#
\(\mathbf{MeanSquare}[x]\) is calculated over \(C\) and has shape \((N, L)\): one value per example and position.
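Each position is rescaled as \(y = \frac{x}{\sqrt{\mathbf{MeanSquare}[x] + \epsilon}} \cdot \gamma\); unlike LayerNorm, the mean is not subtracted and there is no bias term.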
class RMSNorm(torch.nn.Module):
    """RMSNorm in Llama 3"""

    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))

    def _norm(self, x):
        return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)

    def forward(self, x):
        # x shape: (N, L, C)
        # weight shape: (C,)
        output = self._norm(x.float()).type_as(x)
        return output * self.weight
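As a quick sanity check in the same style as the cells above (assuming the default initialization, where weight is all ones and therefore drops out of the comparison):

batch, seq_len, num_features = 2, 3, 4
x = torch.randn(batch, seq_len, num_features)
rms = RMSNorm(num_features)
t1 = rms(x)
t2 = x / torch.sqrt(x.pow(2).mean(-1, keepdim=True) + 1e-6)
torch.allclose(t1, t2, atol=1e-3, rtol=0)
True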