Training a neural net is far from being a straightforward task, as the slightest mistake can lead to suboptimal results without any warning. Training depends on many factors and parameters and thus requires a thoughtful approach.
It is known that the beginning of training (i.e., the first few iterations) is very important. When it is done improperly, you get bad results; sometimes, the network won’t even learn anything at all! For this reason, the way you initialize the weights of the neural network is one of the key factors for good training.
The goal of this article is to explain why initialization matters and to present several ways to implement it efficiently. We will test our approaches on practical examples.
The code uses the fastai library (based on pytorch) and lessons from the last fastai MOOC (which, by the way, is really great!). All experiment notebooks are available in this github repository.
Why is initialization important?
Neural-net training essentially consists of repeating the following two steps:
- A forward step that consists of a huge number of matrix multiplications between weights and inputs / activations (we call activations the output of a layer, which becomes the input of the next layer; these are the hidden activations)
- A backward step that computes the gradients of the loss with respect to the parameters and uses them to update the weights of the network in order to minimize the loss function
During the forward step, the activations (and then the gradients) can quickly get really big or really small, because we repeat a lot of matrix multiplications. More specifically, we might get either:
- very big activations and hence large gradients that shoot towards infinity
- very small activations and hence infinitesimal gradients, which may be rounded to zero due to limited numerical precision
Either of these effects is fatal for training. Below is an example of explosion with randomly initialized weights, on the first forward pass.
In this particular example, the mean and standard deviation are already huge at the 10th layer!
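The notebook code is not reproduced here, but the effect is easy to recreate. The sketch below is our own pure-Python stand-in for the pytorch version (no dependency needed): it pushes a random input vector of size 512 through 10 matrix multiplications whose weights are drawn from a standard normal distribution, with a fresh matrix per layer and no activation function, and records the mean and standard deviation after each layer. The names `matvec` and `forward_random_init` are illustrative, not from the notebook.

```python
import random
import statistics


def matvec(a, x):
    """Multiply matrix a (a list of rows) by vector x."""
    return [sum(w * v for w, v in zip(row, x)) for row in a]


def forward_random_init(size=512, n_layers=10, seed=0):
    """Chain n_layers matrix multiplications with unscaled N(0, 1) weights."""
    rng = random.Random(seed)
    x = [rng.gauss(0.0, 1.0) for _ in range(size)]
    stats = []
    for _ in range(n_layers):
        a = [[rng.gauss(0.0, 1.0) for _ in range(size)] for _ in range(size)]
        x = matvec(a, x)
        stats.append((statistics.fmean(x), statistics.pstdev(x)))
    return stats


stats = forward_random_init()
for i, (mean, std) in enumerate(stats, start=1):
    print(f"layer {i:2d}: mean={mean:.3e}  std={std:.3e}")
```

Each multiplication scales the standard deviation by roughly sqrt(512) ≈ 22.6, so after 10 layers it has grown by many orders of magnitude.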
What makes things even trickier is that, in practice, you can still get non-optimal results after long periods of training even while avoiding explosion or vanishing effects. This is illustrated below on a simple convnet (experiments will be detailed in the second part of the article):
Notice that the default pytorch approach is not the best one, and that random init does not learn much (also, this is only a 5-layer network, so a deeper network would likely learn even less).
How to initialize your network
Recall that the goal of a good initialization is to:
- get random weights
- keep the activations in a good range during the first forward passes (and likewise for the gradients during the backward passes)
What is a good range in practice? Quantitatively, it means that multiplying the weight matrix by the input vector should produce an output vector (i.e., the activations) with a mean near 0 and a standard deviation near 1. If every layer does this, these statistics are preserved from one layer to the next.
This way, even in a deep network, the statistics stay stable during the first iterations.
We now discuss two approaches to do so.
The math approach: Kaiming init
So let’s picture the issue. If the initialized weights are too big at the beginning of training, then the activations will grow exponentially with each matrix multiplication, leading to exploding activations and, in turn, exploding gradients.
Conversely, if the weights are too small, then each matrix multiplication will decrease the activations until they vanish completely.
So the key here is to scale the weight matrix so that the output of each matrix multiplication has a mean around 0 and a standard deviation around 1.
But then how do we define the scale of the weights? Well, if we assume that the weights (as well as the inputs) are independent and normally distributed, we can work it out with some math.
Two famous papers present a good initialization scheme based on this idea:
- The “Xavier initialization”, presented in 2010 in the paper Understanding the difficulty of training deep feedforward neural networks
- The “Kaiming initialization”, presented in 2015 in the paper Delving Deep into Rectifiers: Surpassing Human-Level Performance on ImageNet Classification
In practice, the two schemes are quite similar: the “main” difference is that Kaiming initialization takes into account the ReLU activation function following each matrix multiplication.
Nowadays, most neural nets use ReLU (or a similar function like leaky ReLU). Here, we only focus on the Kaiming initialization.
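The simplified formula (for standard ReLU) is to scale the random weights (drawn from a standard normal distribution) by:
$$\sqrt{\frac{2}{n_{in}}}$$
where $n_{in}$ is the size of the input vector (the number of inputs of the layer).
For instance, with an input of size 512, the random weights are scaled by $\sqrt{2/512} = 1/16 = 0.0625$.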
In addition, all bias parameters should be initialized to zeros.
Note that for Leaky ReLU the formula has an additional component, which we do not consider here (we refer the reader to the original paper).
Let’s check how this approach works on our previous example:
Notice that we now get activations with a mean of 0.64 and a standard deviation of 0.87 after initialization. Obviously, this is not perfect (how could it be with random numbers?), but it is much better than with unscaled, normally distributed random weights.
After 50 layers, we get a mean of 0.27 and a standard deviation of 0.464, so there is no more explosion or vanishing effect.
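Here is a comparable pure-Python sketch of the Kaiming version (again our own illustration, not the notebook's code): each layer draws weights from a standard normal distribution, scales them by sqrt(2 / n_in), keeps a zero bias and applies a ReLU. To stay fast without pytorch, it uses vectors of size 256 instead of 512 and a fresh matrix per layer, so the exact numbers will differ from the ones above; the point is that the standard deviation stays in a reasonable range after 50 layers instead of exploding or vanishing. The names `kaiming_matrix` and `forward_kaiming` are hypothetical.

```python
import math
import random
import statistics


def matvec(a, x):
    """Multiply matrix a (a list of rows) by vector x."""
    return [sum(w * v for w, v in zip(row, x)) for row in a]


def kaiming_matrix(n_in, n_out, rng):
    """N(0, 1) weights scaled by sqrt(2 / n_in) (Kaiming init, fan-in, ReLU)."""
    scale = math.sqrt(2.0 / n_in)
    return [[rng.gauss(0.0, 1.0) * scale for _ in range(n_in)] for _ in range(n_out)]


def forward_kaiming(size=256, n_layers=50, seed=0):
    rng = random.Random(seed)
    x = [rng.gauss(0.0, 1.0) for _ in range(size)]
    for _ in range(n_layers):
        a = kaiming_matrix(size, size, rng)
        # bias is initialized to zero, so a layer is just matmul followed by ReLU
        x = [max(0.0, v) for v in matvec(a, x)]
    return statistics.fmean(x), statistics.pstdev(x)


print(math.sqrt(2 / 512))  # scaling factor for an input of size 512: 0.0625
mean, std = forward_kaiming()
print(f"after 50 layers: mean={mean:.3f}  std={std:.3f}")
```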
Optional: Quick explanation of Kaiming formula
The math derivations that lead to the magic scaling number of math.sqrt(2 / size of input vector) are provided in the Kaiming paper. In addition, we provide below some useful code; the reader can skip it entirely and proceed to the next section. Note that the code requires an understanding of matrix multiplication and of variance / standard deviation.
To understand the formula, consider the variance of the result of a matrix multiplication. In this example, we multiply a vector of size 512 by a 512×512 matrix, which produces another vector of size 512.
Each element of the output is a sum of 512 products of a weight and an input. If both are independent with mean 0 and variance 1, each product has variance 1, and the variances of independent terms add up. So in our case, the variance of the output of a matrix multiplication is around the size of the input vector, and, by definition, the standard deviation is the square root of that.
This is why dividing the weight matrix by the square root of the input vector size (512 in this example) gives us results with a standard deviation close to 1.
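The following sketch (our own illustration, standard library only) checks this numerically: with a standard normal input of size 512 and a 512×512 standard normal weight matrix, the output variance comes out close to 512, and dividing the weights by sqrt(512) brings the output standard deviation close to 1.

```python
import math
import random
import statistics

rng = random.Random(0)
n = 512
x = [rng.gauss(0.0, 1.0) for _ in range(n)]                      # input vector
a = [[rng.gauss(0.0, 1.0) for _ in range(n)] for _ in range(n)]  # n x n weights

# Each output element sums n products w_ij * x_j of independent N(0, 1) values;
# each product has variance 1, so each output element has variance close to n.
y = [sum(w * v for w, v in zip(row, x)) for row in a]
print(statistics.pvariance(y))  # roughly 512
print(statistics.pstdev(y))     # roughly sqrt(512) ≈ 22.6

# Dividing the weights by sqrt(n) divides every output by sqrt(n) as well.
a_scaled = [[w / math.sqrt(n) for w in row] for row in a]
y_scaled = [sum(w * v for w, v in zip(row, x)) for row in a_scaled]
print(statistics.pstdev(y_scaled))  # roughly 1
```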
But where does the numerator of “2” come from? It accounts for the ReLU layer.
As you know, ReLU sets the negative numbers to 0 (it’s only max(0, input)). Because the values are centered around a mean of 0, it zeroes out about half of them and thus keeps only half of the mean square E[y²], which is the quantity the Kaiming derivation tracks (loosely speaking, it removes half the variance). This is why we add a numerator of 2.
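A quick numerical check of the ReLU argument (an illustrative sketch using only the standard library): for zero-mean normal pre-activations with variance 1, the mean square E[y²] drops from about 1 to about 0.5 after ReLU, which is the factor the numerator of 2 compensates for. The variance of the ReLU output drops a bit further, to 0.5 - 1/(2π) ≈ 0.34, because the output mean is no longer 0.

```python
import math
import random
import statistics

rng = random.Random(0)
y = [rng.gauss(0.0, 1.0) for _ in range(100_000)]  # pre-activations: mean 0, variance 1
r = [max(0.0, v) for v in y]                        # ReLU output

mean_square_before = statistics.fmean(v * v for v in y)  # close to 1
mean_square_after = statistics.fmean(v * v for v in r)   # close to 0.5: the 2 compensates this
variance_after = statistics.pvariance(r)                 # close to 0.5 - 1 / (2 * pi) ≈ 0.34

print(mean_square_before, mean_square_after, variance_after)
print(0.5 - 1 / (2 * math.pi))
```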
The downside of the Kaiming init
The Kaiming init works great in practice, so why consider another approach? It turns out that Kaiming init has some downsides:
- The mean after a layer is not 0 but around 0.5. This is because of the ReLU activation function, which removes all the negative numbers, effectively shifting the mean upward
- The simplified Kaiming formula is derived for ReLU activation functions only. Hence, if you have a more complex architecture (not only matmul → ReLU layers), it won’t be able to keep a standard deviation around 1 on all the layers
- The standard deviation after a layer is not 1 but only close to 1. In a deep network, this might not be enough to keep the standard deviation close to 1 all the way through.
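Where does a mean of about 0.5 come from? Assuming the pre-activations are normally distributed, y ~ N(0, σ²), the ReLU output has mean σ/√(2π) and mean square σ²/2. With Kaiming init, the input of each layer keeps E[x²] ≈ 1, so σ² ≈ 2, which gives a mean of 1/√π ≈ 0.56 and a standard deviation of sqrt(1 - 1/π) ≈ 0.83. The sketch below (standard library only; the normality of the pre-activations is an assumption) compares these formulas with a simulation.

```python
import math
import random
import statistics

# Assumption: pre-activations y ~ N(0, sigma^2) with sigma^2 = 2, which is what
# Kaiming init gives when the layer input has E[x^2] = 1 (Var(y) = n * (2 / n) * 1).
sigma = math.sqrt(2.0)
expected_mean = sigma / math.sqrt(2 * math.pi)             # E[ReLU(y)] = 1 / sqrt(pi) ≈ 0.564
expected_std = math.sqrt(sigma**2 / 2 - expected_mean**2)  # sqrt(1 - 1 / pi) ≈ 0.826

rng = random.Random(0)
post_relu = [max(0.0, rng.gauss(0.0, sigma)) for _ in range(100_000)]
print(expected_mean, statistics.fmean(post_relu))
print(expected_std, statistics.pstdev(post_relu))
```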
The algorithmic approach: LSUV
So what can we do to get a good initialization scheme, without manually customizing the Kaiming init for more complex architectures?
The paper All you need is a good init, from 2015, presents an interesting approach called LSUV (Layer-Sequential Unit-Variance).
The solution is a simple algorithm: first, initialize all the layers with orthogonal initialization. Then, feed a mini-batch through the network and, for each layer in turn, compute the standard deviation of its output. Dividing the layer’s weights by this standard deviation resets the output standard deviation to 1. Below is the algorithm as explained in the paper:
After some testing, I have found that orthogonal initialization gives similar (and sometimes worse) results compared with a Kaiming init for ReLU networks.
Jeremy Howard, in the fastai MOOC, shows another implementation, which adds an update to the weights to keep a mean around 0. In my experiments, I also find that keeping the mean around 0 gives better results.
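Below is a minimal pure-Python sketch of the idea, under our own assumptions: a small fully connected ReLU network stands in for the convnet, the layers are pre-initialized with Kaiming init (instead of orthogonal init), and, in the spirit of the variant that also keeps the mean around 0, each layer’s pre-activations are centered (through the bias) and rescaled (through the weights and bias) until they have mean 0 and standard deviation 1 on the mini-batch. The helper names (`dense`, `kaiming_layer`, `batch_stats`, `lsuv`) are hypothetical, and this is not the fastai lesson code.

```python
import math
import random
import statistics


def dense(W, b, x):
    """Pre-activation of a fully connected layer: W @ x + b."""
    return [sum(w * v for w, v in zip(row, x)) + bi for row, bi in zip(W, b)]


def relu(v):
    return [max(0.0, t) for t in v]


def kaiming_layer(n_in, n_out, rng):
    """Kaiming pre-init: N(0, 1) weights scaled by sqrt(2 / n_in), zero biases."""
    scale = math.sqrt(2.0 / n_in)
    W = [[rng.gauss(0.0, 1.0) * scale for _ in range(n_in)] for _ in range(n_out)]
    return W, [0.0] * n_out


def batch_stats(outputs):
    """Mean and std over every unit of every sample in the mini-batch."""
    flat = [t for out in outputs for t in out]
    return statistics.fmean(flat), statistics.pstdev(flat)


def lsuv(layers, batch, tol=1e-3, max_iter=10):
    """Layer-sequential init: give each layer's pre-activations mean 0, std 1 on batch."""
    inputs = batch
    for W, b in layers:  # layers are processed in order, first to last
        for _ in range(max_iter):
            mean, std = batch_stats([dense(W, b, x) for x in inputs])
            if abs(mean) < tol and abs(std - 1.0) < tol:
                break
            # (W @ x + b - mean) / std has mean 0 and std 1: update bias and weights in place
            for i in range(len(b)):
                b[i] = (b[i] - mean) / std
            for row in W:
                for j in range(len(row)):
                    row[j] /= std
        # the normalized layer's ReLU output becomes the next layer's input
        inputs = [relu(dense(W, b, x)) for x in inputs]
    return layers


rng = random.Random(0)
sizes = [32, 64, 64, 64, 10]
layers = [kaiming_layer(n_in, n_out, rng) for n_in, n_out in zip(sizes, sizes[1:])]
batch = [[rng.gauss(0.0, 1.0) for _ in range(sizes[0])] for _ in range(64)]
lsuv(layers, batch)
```

Because the correction here is applied to a linear output, a single pass already makes it exact; the loop and `tol` only mirror the paper’s iterate-until-within-tolerance structure.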
Now let’s compare the results of these two approaches.
Performance of initialization schemes
We will check the performance of the different initialization schemes on two architectures: a “simple” convnet with 5 layers, and a more complex resnet-like architecture.
The task is to do image classification on the imagenette dataset (a subset of 10 classes from the Imagenet dataset).
This experiment can be found in this notebook. Note that because of randomness, the results could be slightly different each time (but this does not change the ranking or the big picture).
It uses a simple model, defined as:
```python
# ConvLayer is a Conv2D layer followed by a ReLU
# data.c is the number of classes of the fastai DataBunch (10 for imagenette)
nn.Sequential(ConvLayer(3, 32, ks=5), ConvLayer(32, 64), ConvLayer(64, 128), ConvLayer(128, 128),
              nn.AdaptiveAvgPool2d(1), Flatten(), nn.Linear(128, data.c))
```
Below is a comparison of 3 initialization schemes: pytorch’s default init (a Kaiming-style init with some specific parameters), Kaiming init and LSUV init.
Note that the random init performance is so bad that we removed it from the results that follow.
Activations stats after init
The first question is: what are the activation stats after a forward pass on the first iteration? The closer they are to a mean of 0 and a standard deviation of 1, the better.
This figure shows the stats of the activations at each layer, after initialization (before training).
For the standard deviation (right figure), both the LSUV and Kaiming init are close to 1 (and LSUV is closer). But for the pytorch default, the standard deviation is much lower.
For the mean value, though, the Kaiming init gives worse results. This is expected: Kaiming init does not account for the effect of ReLU on the mean, so the mean stays around 0.5 instead of 0.
Complex architecture (resnet50)
Now let’s check if we get similar results on a more complex architecture.
The architecture is xresnet-50, as implemented in the fastai library. It has 10x more layers than our previous simple model.
We will check it in 2 steps:
- without normalization layers: batchnorm will be disabled. Because batchnorm re-normalizes the stats of each mini-batch, it should decrease the impact of the initialization, so we first test without it
- with normalization layer: batchnorm will be enabled
Step 1: Without batchnorm
This experiment can be found in this notebook.
Without batchnorm, the results for 10 epochs are:
The plot shows that the accuracy (y-axis) is 67% for LSUV, 57% for Kaiming init and 48% for the pytorch default. The difference is huge!
Let’s check the activations stats before training:
Let’s zoom to get a better scale:
We see that some layers have stats of 0: this is by design in xresnet50, and independent of the init scheme. It is a trick from the paper Bag of Tricks for Image Classification with Convolutional Neural Networks (implemented in the fastai library).
We see that for:
- Pytorch default init: the standard deviation and mean are close to 0. This is not good and shows a vanishing issue
- Kaiming init: We get a big mean and standard deviation
- LSUV init: We get good stats, not perfect but better than other schemes
We see that the best init scheme for this example gives much better results for the full training, even after 10 full epochs. This shows the importance of keeping good stats across the layers during the first iteration.
Step 2: with batchnorm layers
This experiment can be found in this notebook.
Because batchnorm normalizes the output of a layer, we should expect the init schemes to have less impact.
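To see why, here is a minimal sketch (our own illustration, not the pytorch or fastai implementation) of a training-mode batchnorm with γ = 1 and β = 0 applied after a linear layer: multiplying all the weights by a constant leaves the normalized output essentially unchanged, because batchnorm divides the scale back out. Real batchnorm layers also have learnable γ and β and keep running statistics, which we leave out; `eps=1e-5` is our choice.

```python
import random
import statistics


def dense(W, x):
    """Linear layer without bias: W @ x."""
    return [sum(w * v for w, v in zip(row, x)) for row in W]


def batchnorm(outputs, eps=1e-5):
    """Training-mode batchnorm with gamma=1, beta=0: normalize each unit over the batch."""
    normed = [list(row) for row in outputs]
    for j in range(len(outputs[0])):
        column = [row[j] for row in outputs]
        mean = statistics.fmean(column)
        var = statistics.pvariance(column)
        for row in normed:
            row[j] = (row[j] - mean) / (var + eps) ** 0.5
    return normed


rng = random.Random(0)
n_in, n_out, batch_size = 16, 8, 32
W = [[rng.gauss(0.0, 1.0) for _ in range(n_in)] for _ in range(n_out)]
batch = [[rng.gauss(0.0, 1.0) for _ in range(n_in)] for _ in range(batch_size)]

results = {}
for scale in (0.1, 1.0, 100.0):  # too small, standard and far too large weights
    W_scaled = [[w * scale for w in row] for row in W]
    results[scale] = batchnorm([dense(W_scaled, x) for x in batch])
    flat = [t for row in results[scale] for t in row]
    print(f"scale={scale:>6}: mean={statistics.fmean(flat):+.4f} std={statistics.pstdev(flat):.4f}")
```

Whatever the weight scale, the output comes back with a mean of 0 and a standard deviation of 1, which is why a poor initialization hurts much less once batchnorm is present.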
The results show close accuracy for all init schemes, near 88%. Note that at each run the best init scheme may change depending on the random seed.
It shows that batchnorm layers make the network less sensitive to the initialization scheme.
The activations stats before training are the following:
Like before, the best seems to be the LSUV init (only one to keep a mean around 0 as well as a standard deviation close to 1).
But the results show this has no impact on the accuracy, at least for this architecture and this dataset. It confirms one thing though: batchnorm makes the network much less sensitive to the quality of the initialization.
What to remember from this article?
- The first iterations are very important and can have a lasting impact on the full training.
- A good initialization scheme should keep the input stats (mean of 0 and standard deviation of 1) on the activations across all the layers of the network (for the first iteration).
- Batchnorm layers reduce the neural net sensitivity to the initialization scheme.
- Using Kaiming init + LSUV seems to be a good approach, especially when the network lacks a normalization layer.
- Other kinds of architecture could have different behaviors regarding initialization.