Why does Batch Normalization work?


So, why does batch normalization work? Here's one reason: you know how normalizing the input features, the X's, to mean zero and variance one can speed up learning? Rather than having some features that range from zero to one and others from one to 1,000, normalizing all the input features X to take on a similar range of values speeds up learning. So one intuition behind batch norm is that it is doing a similar thing, but for the values in your hidden units and not just your inputs.
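As a quick illustration of that input normalization, here is a minimal sketch in Python/NumPy; the toy matrix X and its values are made up just for this example:

```python
import numpy as np

# Toy design matrix X of shape (n_features, m_examples), following the course convention.
X = np.array([[0.2,   0.9,  0.5,  0.1],    # feature 1 ranges roughly 0..1
              [300., 950., 120., 870.]])   # feature 2 ranges roughly 1..1000

mu = X.mean(axis=1, keepdims=True)              # per-feature mean
sigma2 = X.var(axis=1, keepdims=True)           # per-feature variance
X_norm = (X - mu) / np.sqrt(sigma2 + 1e-8)      # each feature now has mean 0, variance 1
```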

A second reason batch norm works is that it makes the weights deeper in your network, say the weights in layer 10, more robust to changes in the weights in earlier layers of the neural network, say in layer one.

To explain what I mean, let's look at a vivid example.

[Figure: cat detection example]

Let's say you are training a network, maybe a shallow network like logistic regression or maybe a deep network, on the cat detection problem. Suppose you have trained your model on a data set containing only images of black cats. If you now try to apply this network to data with colored cats, where the positive examples are not just black cats like on the left but also colored cats like on the right, then your classifier might not do very well. In pictures: if your training set looks like the plot on the left, with positive examples here and negative examples there, and you try to generalize to a data set where the positive and negative examples sit in different regions, then you might not expect a model trained on the left to do very well on the right.

Even though there might be one function that actually works well on both, you wouldn't expect your learning algorithm to discover that decision boundary (shown in green) by looking only at the data on the left.

This idea of the data distribution changing goes by the name "covariate shift". The idea is that if you have learned some X to Y mapping and the distribution of X changes, then you might need to retrain your learning algorithm. This is true even if the ground truth function mapping X to Y remains unchanged, which it does in this example, because the ground truth function is simply "is this picture a cat or not". And the need to retrain becomes even more acute, or even worse, if the ground truth function shifts as well.

So, how does this problem of covariate shift apply to a neural network? Consider a deep network like this one.

[Figure: deep neural network]

Let's look at the learning process from the perspective of a particular layer, say the third layer. This network has learned the parameters W[3] and B[3], and from the perspective of the third hidden layer, it gets some set of values from the earlier layers and then has to do some work to hopefully make the output Y_hat close to the ground truth value Y. So let me cover up the nodes on the left for a moment.

[Figure: deep neural network with the earlier layers covered]

From the perspective of this third hidden layer, it gets some values; let's call them A[2]1, A[2]2, A[2]3, A[2]4. But these values might as well be features X1, X2, X3, X4. The job of the third hidden layer is to take these values and find a way to map them to Y_hat. So you could imagine doing gradient descent on the parameters W[3], B[3], as well as W[4], B[4] and even W[5], B[5], so that the network does a good job mapping from these values (shown on the left in the blue box) to the output Y_hat. But the network is also adapting the parameters W[2], B[2] and W[1], B[1], and if those parameters change, these A[2] values change as well. So from the perspective of the third hidden layer, its input values are changing all the time, and it is therefore suffering from the problem of covariate shift.

What batch norm does is reduce the amount that the distribution of these hidden unit values shifts around. Let's plot the distribution of these hidden unit values; to be technical, we'll plot the Z values rather than the A values.

[Figure: distribution of the hidden unit values Z[2]1 and Z[2]2]

Here we plot two values instead of four, so we can visualize it in 2D. The values of Z[2]1 and Z[2]2 can change, and indeed they will change as the neural network updates the parameters in the earlier layers. What batch norm ensures is that however Z[2]1 and Z[2]2 change, their mean and variance remain the same. So even if the exact values of Z[2]1 and Z[2]2 change, their mean and variance stay fixed, say at mean 0 and variance 1, although not necessarily those values: the mean and variance are governed by β[2] and γ[2], which the neural network could set to force mean 0 and variance 1, or any other mean and variance.

What this does is limit the amount by which updating the parameters in the earlier layers can affect the distribution of values that the third layer sees and therefore has to learn on. So batch norm reduces the problem of the input values changing; it causes these values to become more stable, so that the later layers of the neural network have firmer ground to stand on. Even though the input distribution changes a bit, it changes less, and even as the earlier layers keep learning, the amount that the later layers are forced to adapt is reduced. Or, if you will, it weakens the coupling between what the early layer parameters have to do and what the later layer parameters have to do. It therefore allows each layer of the network to learn a little more independently of the other layers, and this has the effect of speeding up learning in the whole network.

The takeaway is that batch normalization means, especially from the perspective of one of the later layers of the neural network, that the earlier layers' outputs don't get to shift around as much, because they are constrained to have the same mean and variance. This makes the job of learning in the later layers easier. It turns out that batch norm also has a second effect: a slight regularization effect.
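To make the role of β and γ concrete, here is a minimal sketch of the batch norm transform for one layer's pre-activations; the function name and shapes are my own, assuming Z has shape (number of units, mini batch size):

```python
import numpy as np

def batch_norm_transform(Z, gamma, beta, eps=1e-8):
    """Normalize one layer's pre-activations Z over the mini batch, then rescale
    with the learnable parameters gamma (γ) and beta (β).
    Z has shape (n_units, m_batch); gamma and beta have shape (n_units, 1)."""
    mu = Z.mean(axis=1, keepdims=True)        # per-unit mean over the mini batch
    var = Z.var(axis=1, keepdims=True)        # per-unit variance over the mini batch
    Z_norm = (Z - mu) / np.sqrt(var + eps)    # mean 0, variance 1 for each unit
    Z_tilde = gamma * Z_norm + beta           # mean beta, std gamma, whatever the network learns
    return Z_tilde, mu, var
```

However the earlier layers move Z around, the distribution the next layer sees is pinned to the mean and variance set by β and γ.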

Batch norm as a regularizer

  • Each mini batch is scaled by the mean/variance computed on just that mini batch.
  • This adds some noise to the values Z[l] within that mini batch. So, similar to dropout, it adds some noise to each hidden layer's activations.

One slightly non-intuitive thing about batch norm is that each mini batch, say mini batch X{t}, has its values Z[l] scaled by the mean and variance computed on just that mini batch. Because the mean and variance are computed on just that mini batch, as opposed to on the entire data set, they have a little noise in them: they are estimated from a relatively small sample of, say, 64 or 128 or 256 examples. The scaling process going from Z[l] to Z~[l] is therefore a little bit noisy as well, because it is computed using a slightly noisy mean and variance. So, similar to dropout, batch norm adds noise to each hidden layer's activations. Dropout adds noise by taking a hidden unit and multiplying it by zero with some probability and by one with some probability. Batch norm instead has multiplicative noise, from scaling by the estimated standard deviation, as well as additive noise, from subtracting the estimated mean.
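Here is a toy illustration of why those per-mini-batch statistics are noisy; the distribution and numbers are invented just to show the effect:

```python
import numpy as np

rng = np.random.default_rng(0)
# Pretend these are one hidden unit's Z values over the whole training set.
Z_all = rng.normal(loc=2.0, scale=3.0, size=10_000)

mini_batch = Z_all[:64]                        # one mini batch of 64 examples
print(Z_all.mean(), Z_all.std())               # close to the true 2.0 and 3.0
print(mini_batch.mean(), mini_batch.std())     # noticeably off -> noisy scaling of Z~[l]
```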

Because the estimates of the mean and standard deviation are noisy, batch norm, similar to dropout, has a slight regularization effect: by adding noise to the hidden units, it forces the downstream hidden units not to rely too much on any one hidden unit.

Because the noise added is quite small, this is not a huge regularization effect, and you might choose to use batch norm together with dropout if you want the more powerful regularization effect of dropout. A perhaps slightly less intuitive consequence is that if you use a bigger mini batch size, say 512 instead of 64, you reduce the noise and therefore reduce the regularization effect. That's one strange property of this regularization effect of batch norm: using a bigger mini batch size weakens it. Sometimes it has this extra, unintended effect on your learning algorithm. But don't turn to batch norm as a regularizer; use it as a way to normalize your hidden units and therefore speed up training, and treat the regularization as an unintended side effect.

Batch norm handles data one mini batch at a time; it computes means and variances on mini batches. But at test time, when you try to make predictions and evaluate the neural network, you might not have a mini batch of examples; you might be processing a single example at a time. So at test time you need to do something slightly different to make sure your predictions make sense.
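One common way to handle this, sketched below under the same shape conventions as before, is to keep exponentially weighted averages of the mini-batch means and variances during training and use those fixed estimates to normalize single examples at test time; the function names, layer width, and momentum value here are illustrative assumptions:

```python
import numpy as np

n_units = 4                                    # hypothetical layer width, for illustration only
running_mu = np.zeros((n_units, 1))            # exponentially weighted average of mini-batch means
running_var = np.ones((n_units, 1))            # exponentially weighted average of mini-batch variances
momentum = 0.9                                 # decay rate of the running averages

def update_running_stats(mu_batch, var_batch):
    """Call once per mini batch during training with that batch's mean and variance."""
    global running_mu, running_var
    running_mu = momentum * running_mu + (1 - momentum) * mu_batch
    running_var = momentum * running_var + (1 - momentum) * var_batch

def batch_norm_inference(z_single, gamma, beta, eps=1e-8):
    """At test time, normalize a single example (shape (n_units, 1)) with the
    running estimates, since there is no mini batch to compute statistics on."""
    z_norm = (z_single - running_mu) / np.sqrt(running_var + eps)
    return gamma * z_norm + beta
```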

Written on September 11, 2017