Normalization
Normalization refers to adjusting data so that it follows a particular distribution (typically zero mean and unit variance). Most normalization layers operate on a tensor of shape (N, C, H, W), where N is the batch dimension and C is the channel dimension.
Batch Normalization
Early deep networks had the problem that the distribution of each layer's inputs kept changing during training as the earlier layers updated. This is called Internal Covariate Shift, and it used to be mitigated by using small learning rates and careful initialization.
Batch normalization addresses this by forcing the input of every layer to have a mean of 0 and a variance of 1, computed per channel over the batch (and spatial dimensions).
Then it applies a learnable scale and shift so the network can "undo" the normalization when needed.
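Written out, this is the standard batch norm transform, where $\mu_B$ and $\sigma_B^2$ are the per-channel batch mean and variance, $\gamma$ and $\beta$ are the learnable scale and shift, and $\epsilon$ is a small constant for numerical stability:

$$\hat{x} = \frac{x - \mu_B}{\sqrt{\sigma_B^2 + \epsilon}}, \qquad y = \gamma \hat{x} + \beta$$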
+ Allows higher learning rates
+ Acts as a regularizer for the network
+ More efficient at inference because it's foldable: the norm can be fused into the preceding convolution kernel, so the norm layer disappears at inference time (see the sketch below).
- Becomes noisy with small batch sizes, which is an issue for larger models
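As a rough illustration of what "foldable" means, here is a minimal PyTorch sketch of fusing a BatchNorm2d into the preceding Conv2d for inference. The function name `fold_bn_into_conv` is just illustrative, and it assumes the conv output feeds directly into the BN:

```python
import torch
import torch.nn as nn

def fold_bn_into_conv(conv: nn.Conv2d, bn: nn.BatchNorm2d) -> nn.Conv2d:
    """Fuse a trained BatchNorm2d into the preceding Conv2d (inference only)."""
    # At inference, BN computes y = gamma * (x - mean) / sqrt(var + eps) + beta,
    # a per-channel affine transform that can be absorbed into the conv:
    #   W' = W * gamma / sqrt(var + eps)
    #   b' = (b - mean) * gamma / sqrt(var + eps) + beta
    fused = nn.Conv2d(conv.in_channels, conv.out_channels, conv.kernel_size,
                      stride=conv.stride, padding=conv.padding,
                      dilation=conv.dilation, groups=conv.groups, bias=True)
    scale = bn.weight / torch.sqrt(bn.running_var + bn.eps)
    fused.weight.data = conv.weight.data * scale.reshape(-1, 1, 1, 1)
    conv_bias = conv.bias.data if conv.bias is not None else torch.zeros(conv.out_channels)
    fused.bias.data = (conv_bias - bn.running_mean) * scale + bn.bias.data
    return fused
```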
Training and Inference Behavior of BatchNorm
During training, batch normalization uses the batch statistics as defined above. To work with single images at inference, it instead uses running averages of the mean and variance accumulated during training.
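A minimal sketch of the two modes, assuming an NCHW input; the `batch_norm` function and its argument names here are illustrative, not the actual library API:

```python
import torch

def batch_norm(x, gamma, beta, running_mean, running_var,
               training: bool, momentum: float = 0.1, eps: float = 1e-5):
    """Minimal BatchNorm over an (N, C, H, W) tensor, normalizing each channel."""
    if training:
        # Statistics of the current batch, per channel
        mean = x.mean(dim=(0, 2, 3))
        var = x.var(dim=(0, 2, 3), unbiased=False)
        # Keep an exponential moving average for use at inference
        running_mean.mul_(1 - momentum).add_(momentum * mean)
        running_var.mul_(1 - momentum).add_(momentum * var)
    else:
        # Inference (possibly a single image): use the accumulated statistics
        mean, var = running_mean, running_var
    x_hat = (x - mean[None, :, None, None]) / torch.sqrt(var[None, :, None, None] + eps)
    return gamma[None, :, None, None] * x_hat + beta[None, :, None, None]
```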
Layer Normalization
Instead of normalizing over the batch dimension, it normalizes over the channel/feature dimension of each sample independently. It replaces batch norm for text/sequence tasks.
+ Works with smaller batch sizes, since its statistics are computed per sample; this suits Transformer architectures and bigger vision models that use smaller batches to fit in GPU memory (see the sketch below).
- Doesn't work well with CNNs because it removes some of the spatial features that CNNs need.
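A minimal sketch, assuming the common case of normalizing over the last (feature) dimension of a (batch, seq_len, d_model) tensor; `layer_norm` here is an illustrative function, not `torch.nn.LayerNorm` itself:

```python
import torch

def layer_norm(x, gamma, beta, eps: float = 1e-5):
    """Minimal LayerNorm: statistics over the feature dim of each sample/token."""
    mean = x.mean(dim=-1, keepdim=True)
    var = x.var(dim=-1, keepdim=True, unbiased=False)
    return gamma * (x - mean) / torch.sqrt(var + eps) + beta
```

Note that no batch statistics are involved, so the behavior is identical at training and inference and independent of batch size.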
Group Normalization
Divides the channels into G groups and computes the mean and variance within each group, for a single sample.
+ Works better with CNNs.
- Doesn't work when all channels learn a single, unified representation.
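A minimal sketch, assuming an NCHW input whose channel count is divisible by the number of groups (`group_norm` is an illustrative function, not the library layer):

```python
import torch

def group_norm(x, gamma, beta, num_groups: int, eps: float = 1e-5):
    """Minimal GroupNorm: per-sample statistics within each group of channels."""
    n, c, h, w = x.shape
    xg = x.reshape(n, num_groups, c // num_groups, h, w)
    mean = xg.mean(dim=(2, 3, 4), keepdim=True)
    var = xg.var(dim=(2, 3, 4), keepdim=True, unbiased=False)
    xg = (xg - mean) / torch.sqrt(var + eps)
    out = xg.reshape(n, c, h, w)
    return gamma[None, :, None, None] * out + beta[None, :, None, None]
```

With `num_groups=1` this normalizes over all of C, H, W per sample; with `num_groups=C` it reduces to instance norm.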
RMSNorm
Root Mean Square Normalization was developed as a simplification of LN. Instead of centering (subtracting the mean) and scaling (dividing by the standard deviation), RMSNorm only scales the input.
SOTA LLMs (Llama, Gemma, DeepSeek) use it now.
+ Cheaper to compute than Layer Normalization, since there is no mean to subtract
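Concretely, the standard RMSNorm formulation (with $\epsilon$ placed inside the square root, as in common implementations; the symbols are described below) is:

$$\mathrm{RMSNorm}(x)_i = \frac{x_i}{\mathrm{RMS}(x)}\, g_i, \qquad \mathrm{RMS}(x) = \sqrt{\frac{1}{d}\sum_{j=1}^{d} x_j^2 + \epsilon}$$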
- $g$: learnable gain, allows the model to learn a scaling amount for each feature
- $\epsilon$: a small constant that prevents division by zero
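The same thing as a minimal PyTorch sketch (the function name and default $\epsilon$ are illustrative):

```python
import torch

def rms_norm(x, g, eps: float = 1e-6):
    """Minimal RMSNorm over the last dimension: rescale only, no centering."""
    rms = torch.sqrt(x.pow(2).mean(dim=-1, keepdim=True) + eps)
    return x / rms * g
```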
TODO: NF-Nets