BatchNorm1 is a technique for normalising the outputs of intermediate layers inside a neural network using the batch's statistics. The method was originally designed to "reduce internal covariate shift": to force the intermediate layers to have zero mean and unit variance, so that activations do not explode (especially in deep networks) and everything stays nicely centred around zero. BatchNorm has proved vital to getting networks to train, enabling larger learning rates, faster training and results that generalise better. Since the algorithm's original publication, however, much more work has been done on identifying why BatchNorm works. It is not (mostly) because it reduces covariate shift, but because it has a profound impact on the loss landscape2 that the optimiser must traverse to find a good solution: BatchNorm smooths the loss landscape, so our optimisers can find good solutions much more easily. With this in mind, couldn't we design a better normalisation strategy, one that exploits our improved understanding of why the method works? That is what I hope to explore in this article.
BatchNorm
As a brief reminder, BatchNorm normalises batches using the mean and variance of the batch during training according to the following formula:
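For a mini-batch $B = \{x_1, \dots, x_m\}$:

$$
\mu_B = \frac{1}{m}\sum_{i=1}^{m} x_i, \qquad
\sigma_B^2 = \frac{1}{m}\sum_{i=1}^{m} (x_i - \mu_B)^2
$$

$$
\hat{x}_i = \frac{x_i - \mu_B}{\sqrt{\sigma_B^2 + \epsilon}}, \qquad
y_i = \gamma \hat{x}_i + \beta
$$

where $\gamma$ and $\beta$ are learned per-channel scale and shift parameters, and $\epsilon$ is a small constant added to the divisor for numerical stability. At inference time, the batch statistics $\mu_B$ and $\sigma_B^2$ are replaced by running averages accumulated during training.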
Why Should we Replace BatchNorm?
BatchNorm has proven instrumental in creating and training increasingly deep networks, stabilising training and reducing training times. Nevertheless, the method has a few drawbacks:
- One of the main reasons researchers look for alternatives to BatchNorm is training with a small batch size. Since BatchNorm computes statistics across the batch, when the batch size is small the standard deviation is likely to be very small for at least one batch of data, leading to numerical instability in the normalisation. Yet with the increasing model and input sizes used in modern networks, the batch size must often be small to fit into memory, creating a trade-off between the size of the architecture and the size of the batch.
- Additionally, the original premise behind BatchNorm proved to be somewhat incorrect2, yet the technique still worked (like many other "bugs" in ML). With this improved knowledge, could we not further improve the normalisation by redesigning it? It is a bit like creating a new cake recipe after learning how the oven actually works.
- BatchNorm causes trouble across domains when using transfer learning. When transferring from one domain to another, particularly with little data, the input has a different distribution, so the normalisation is skewed and may ruin training. This is especially problematic when training just the head of the network, a common practice in transfer learning: we have to train the BatchNorm layers (or make sure they are in inference mode) in order for the body of the network to produce sensible outputs to train the head on.
- BatchNorm behaves differently during training and testing, making generalisation slightly less predictable. The network cannot dynamically adjust to an input whose distribution is completely different from those it has previously encountered, which may lead to strange performance3.
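The small-batch problem above is easy to reproduce numerically. The sketch below is a minimal NumPy illustration (not the code from the training notebook): it draws many batches from the same activation distribution and compares how stable the per-batch statistics are at batch size 128 versus batch size 2.

```python
import numpy as np

# BatchNorm estimates the mean and variance from each batch. The smaller
# the batch, the noisier those estimates are; with very small batches,
# some batch will eventually have a near-zero standard deviation, making
# the division in the normalisation step numerically unstable.
rng = np.random.default_rng(0)
population = rng.normal(loc=2.0, scale=3.0, size=100_000)  # stand-in for "true" activations

def batch_stats(batch_size, n_batches=1000):
    """Per-batch mean and standard deviation over many random batches."""
    batches = rng.choice(population, size=(n_batches, batch_size))
    return batches.mean(axis=1), batches.std(axis=1)

for bs in (128, 2):
    means, stds = batch_stats(bs)
    print(f"bs={bs:3d}  spread of batch means={means.std():.3f}  "
          f"smallest batch std={stds.min():.4f}")
```

At batch size 2, the batch means vary wildly and some batches have a standard deviation close to zero, so the normalisation ends up dividing nearly identical activations by a tiny number — exactly the instability described above.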
Setting the Baseline
For most of this exploration, I will compare two networks: a small four-layer convolutional network and a ResNet18, both trained on Imagenette, with results averaged over 3 training runs each. For full details of the training process that I used, please see this notebook. Each result is run over a small grid search of parameters (batch sizes 2 and 128, learning rates 0.001 and 0.01).
The baseline below shows each network without normalisation. The best performing result is in bold.
| bs | lr | Accuracy | Loss |
|---|---|---|---|
| 128 | 0.01 | 0.47 ± 0.01 | 1.59 ± 0.02 |
| 128 | 0.001 | 0.35 ± 0.02 | 1.91 ± 0.04 |
| 2 | 0.01 | 0.46 ± 0.01 | 1.62 ± 0.03 |
| 2 | 0.001 | 0.51 ± 0.01 | 1.50 ± 0.02 |
With BatchNorm
Adding BatchNorm to the same baselines above leads to the following results. BatchNorm improves performance overall, particularly at the higher learning rate with the large batch size; training does not perform as well when the batch size is small.
| bs | lr | Accuracy | Loss |
|---|---|---|---|
| 128 | 0.01 | 0.59 ± 0.00 | 1.28 ± 0.01 |
| 128 | 0.001 | 0.44 ± 0.01 | 1.70 ± 0.02 |
| 2 | 0.01 | 0.43 ± 0.04 | 2.78 ± 0.77 |
| 2 | 0.001 | 0.44 ± 0.02 | 1.77 ± 0.10 |
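For concreteness, here is a sketch of how BatchNorm is typically inserted into a small convolutional network in PyTorch: one `nn.BatchNorm2d` after each convolution, before the nonlinearity. This illustrates the pattern rather than reproducing the exact architecture from the notebook — the channel counts and layer sizes here are assumptions.

```python
import torch.nn as nn

def conv_block(in_ch, out_ch, norm=True):
    """Conv -> (optional) BatchNorm -> ReLU, the usual ordering."""
    layers = [nn.Conv2d(in_ch, out_ch, kernel_size=3, stride=2, padding=1)]
    if norm:
        layers.append(nn.BatchNorm2d(out_ch))
    layers.append(nn.ReLU())
    return nn.Sequential(*layers)

# A small four-layer conv net in the spirit of the one used above
# (hypothetical sizes; Imagenette has 10 classes).
model = nn.Sequential(
    conv_block(3, 16),
    conv_block(16, 32),
    conv_block(32, 64),
    conv_block(64, 128),
    nn.AdaptiveAvgPool2d(1),
    nn.Flatten(),
    nn.Linear(128, 10),
)
```

Passing `norm=False` to each block recovers the unnormalised baseline, which is how the two sets of results above differ.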
Managing Small Batch Sizes
How can we manage small batch sizes? The first possible solution is to increase the epsilon parameter in the divisor of the normalisation step, which is there to prevent numerical instability. If we increase it to 0.1, then even when the batch variance collapses towards zero, the scaling factor 1/√(σ² + ε) is bounded by 1/√0.1 ≈ 3.16, so the output of the normalisation will be constrained to
