Batch normalization (also known as batch norm) is a method used to make training of artificial neural networks faster and more stable through normalization of the layers' inputs by re-centering and re-scaling. It was proposed by Sergey Ioffe and Christian Szegedy in 2015.[1]
While the effect of batch normalization is evident, the reasons behind its effectiveness remain under discussion. It was believed that it can mitigate the problem of internal covariate shift, where parameter initialization and changes in the distribution of the inputs of each layer affect the learning rate of the network. Recently, some scholars have argued that batch normalization does not reduce internal covariate shift, but rather smooths the objective function, which in turn improves the performance.[2] However, at initialization, batch normalization in fact induces severe gradient explosion in deep networks, which is only alleviated by skip connections in residual networks.[3] Others maintain that batch normalization achieves length-direction decoupling, and thereby accelerates neural networks.[4]
Each layer of a neural network has inputs with a corresponding distribution, which is affected during the training process by the randomness in the parameter initialization and the randomness in the input data. The effect of these sources of randomness on the distribution of the inputs to internal layers during training is described as internal covariate shift. Although a clear-cut precise definition seems to be missing, the phenomenon observed in experiments is the change in the means and variances of the inputs to internal layers during training.
Batch normalization was initially proposed to mitigate internal covariate shift. During the training stage of networks, as the parameters of the preceding layers change, the distribution of inputs to the current layer changes accordingly, such that the current layer needs to constantly readjust to new distributions. This problem is especially severe for deep networks, because small changes in shallower hidden layers are amplified as they propagate within the network, resulting in a significant shift in deeper hidden layers. Batch normalization was therefore proposed to reduce these unwanted shifts, to speed up training and to produce more reliable models.
Besides reducing internal covariate shift, batch normalization is believed to introduce many other benefits. With this additional operation, the network can use a higher learning rate without vanishing or exploding gradients. Furthermore, batch normalization seems to have a regularizing effect, improving the network's generalization properties, so that it is often unnecessary to use dropout to mitigate overfitting. It has also been observed that the network becomes more robust to different initialization schemes and learning rates when batch normalization is used.
In a neural network, batch normalization is achieved through a normalization step that fixes the means and variances of each layer's inputs. Ideally, the normalization would be conducted over the entire training set, but to use this step jointly with stochastic optimization methods, it is impractical to use the global information. Thus, normalization is restrained to each mini-batch in the training process.
Let us use B to denote a mini-batch of size m of the entire training set. The empirical mean and variance of B could thus be denoted as
$\mu_B = \frac{1}{m}\sum_{i=1}^{m} x_i \quad\text{and}\quad \sigma_B^2 = \frac{1}{m}\sum_{i=1}^{m}\left(x_i - \mu_B\right)^2.$
For a layer of the network with $d$-dimensional input, $x = \left(x^{(1)}, \ldots, x^{(d)}\right)$, each dimension of the input is then normalized (i.e. re-centered and re-scaled) separately,

$\hat{x}_i^{(k)} = \frac{x_i^{(k)} - \mu_B^{(k)}}{\sqrt{\left(\sigma_B^{(k)}\right)^2 + \epsilon}}, \quad k\in[1,d],\; i\in[1,m],$

where $\mu_B^{(k)}$ and $\left(\sigma_B^{(k)}\right)^2$ are the per-dimension mean and variance, respectively, and $\epsilon$ is an arbitrarily small constant added in the denominator for numerical stability. The resulting normalized activation $\hat{x}^{(k)}$ has zero mean and unit variance, if $\epsilon$ is not taken into account. To restore the representation power of the network, a transformation step then follows as

$y_i^{(k)} = \gamma^{(k)}\hat{x}_i^{(k)} + \beta^{(k)},$

where the parameters $\gamma^{(k)}$ and $\beta^{(k)}$ are subsequently learned in the optimization process.
Formally, the operation that implements batch normalization is a transform $BN_{\gamma^{(k)},\beta^{(k)}}\colon x_{1\ldots m}^{(k)} \to y_{1\ldots m}^{(k)}$ called the batch normalizing transform. The output of the BN transform, $y^{(k)} = BN_{\gamma^{(k)},\beta^{(k)}}\left(x^{(k)}\right)$, is then passed to other network layers, while the normalized output $\hat{x}_i^{(k)}$ stays internal to the current layer.
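As a concrete illustration, the training-time transform can be written in a few lines of NumPy. The sketch below is illustrative only; the function and variable names are chosen for this example and do not come from any particular library.

```python
import numpy as np

def batch_norm_forward(x, gamma, beta, eps=1e-5):
    """Training-time batch normalization over one mini-batch.

    x:     array of shape (m, d) -- a mini-batch of m examples with d features
    gamma: array of shape (d,)   -- learnable per-dimension scale
    beta:  array of shape (d,)   -- learnable per-dimension shift
    """
    mu = x.mean(axis=0)                      # per-dimension mini-batch mean
    var = x.var(axis=0)                      # per-dimension mini-batch variance
    x_hat = (x - mu) / np.sqrt(var + eps)    # normalized activations (zero mean, unit variance)
    y = gamma * x_hat + beta                 # scale and shift restore representation power
    cache = (x, x_hat, mu, var, gamma, eps)  # saved for the backward pass
    return y, cache

# Example: a mini-batch of m = 4 examples with d = 3 features
rng = np.random.default_rng(0)
x = rng.normal(loc=2.0, scale=3.0, size=(4, 3))
gamma, beta = np.ones(3), np.zeros(3)
y, _ = batch_norm_forward(x, gamma, beta)
print(y.mean(axis=0))  # approximately 0 in each dimension
print(y.std(axis=0))   # approximately 1 in each dimension
```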
The described BN transform is a differentiable operation, and the gradient of the loss l with respect to the different parameters can be computed directly with the chain rule.
Specifically,

$\frac{\partial l}{\partial \hat{x}_i^{(k)}} = \frac{\partial l}{\partial y_i^{(k)}}\,\gamma^{(k)},$

$\frac{\partial l}{\partial \gamma^{(k)}} = \sum_{i=1}^{m}\frac{\partial l}{\partial y_i^{(k)}}\,\hat{x}_i^{(k)}, \qquad \frac{\partial l}{\partial \beta^{(k)}} = \sum_{i=1}^{m}\frac{\partial l}{\partial y_i^{(k)}},$

$\frac{\partial l}{\partial \left(\sigma_B^{(k)}\right)^2} = \sum_{i=1}^{m}\frac{\partial l}{\partial y_i^{(k)}}\left(x_i^{(k)} - \mu_B^{(k)}\right)\left(-\frac{\gamma^{(k)}}{2}\right)\left(\left(\sigma_B^{(k)}\right)^2 + \epsilon\right)^{-3/2},$

$\frac{\partial l}{\partial \mu_B^{(k)}} = \sum_{i=1}^{m}\frac{\partial l}{\partial y_i^{(k)}}\,\frac{-\gamma^{(k)}}{\sqrt{\left(\sigma_B^{(k)}\right)^2 + \epsilon}},$

and

$\frac{\partial l}{\partial x_i^{(k)}} = \frac{\partial l}{\partial \hat{x}_i^{(k)}}\,\frac{1}{\sqrt{\left(\sigma_B^{(k)}\right)^2 + \epsilon}} + \frac{\partial l}{\partial \left(\sigma_B^{(k)}\right)^2}\,\frac{2\left(x_i^{(k)} - \mu_B^{(k)}\right)}{m} + \frac{\partial l}{\partial \mu_B^{(k)}}\,\frac{1}{m}.$
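These chain-rule expressions translate directly into a backward pass. The NumPy sketch below is illustrative; the finite-difference check at the end uses a toy loss chosen only for this example.

```python
import numpy as np

def batch_norm_backward(dl_dy, x, gamma, eps=1e-5):
    """Gradients of the loss through a batch normalization layer.

    dl_dy: gradient of the loss w.r.t. the BN outputs y, shape (m, d)
    x:     the original mini-batch inputs, shape (m, d)
    gamma: per-dimension scale, shape (d,)
    """
    m = x.shape[0]
    mu = x.mean(axis=0)
    var = x.var(axis=0)
    inv_std = 1.0 / np.sqrt(var + eps)
    x_hat = (x - mu) * inv_std

    dl_dxhat = dl_dy * gamma                                           # dl/dx_hat = dl/dy * gamma
    dl_dgamma = np.sum(dl_dy * x_hat, axis=0)                          # dl/dgamma
    dl_dbeta = np.sum(dl_dy, axis=0)                                   # dl/dbeta
    dl_dvar = np.sum(dl_dxhat * (x - mu), axis=0) * -0.5 * inv_std**3  # dl/d(sigma^2)
    dl_dmu = np.sum(dl_dxhat, axis=0) * -inv_std                       # dl/dmu (remaining term sums to zero)
    dl_dx = dl_dxhat * inv_std + dl_dvar * 2.0 * (x - mu) / m + dl_dmu / m
    return dl_dx, dl_dgamma, dl_dbeta

# Finite-difference check of dl/dx for the toy loss l = 0.5 * sum(y**2)
rng = np.random.default_rng(1)
x = rng.normal(size=(5, 2))
gamma, beta, eps = np.array([1.5, 0.7]), np.array([0.1, -0.2]), 1e-5

def loss(x):
    y = gamma * (x - x.mean(axis=0)) / np.sqrt(x.var(axis=0) + eps) + beta
    return 0.5 * np.sum(y**2)

y = gamma * (x - x.mean(axis=0)) / np.sqrt(x.var(axis=0) + eps) + beta
dl_dx, _, _ = batch_norm_backward(dl_dy=y, x=x, gamma=gamma, eps=eps)

h, (i, j) = 1e-6, (2, 1)
x_pert = x.copy()
x_pert[i, j] += h
print(dl_dx[i, j], (loss(x_pert) - loss(x)) / h)  # the two values should agree closely
```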
During the training stage, the normalization steps depend on the mini-batches to ensure efficient and reliable training. However, in the inference stage, this dependence is not useful any more. Instead, the normalization step in this stage is computed with the population statistics such that the output could depend on the input in a deterministic manner. The population mean,
$E\left[x^{(k)}\right]$, and variance, $\operatorname{Var}\left[x^{(k)}\right]$, are computed as

$E\left[x^{(k)}\right] = E_B\left[\mu_B^{(k)}\right] \quad\text{and}\quad \operatorname{Var}\left[x^{(k)}\right] = \frac{m}{m-1}\,E_B\left[\left(\sigma_B^{(k)}\right)^2\right],$

where the factor $\frac{m}{m-1}$ gives an unbiased estimate of the variance.
The population statistics are thus a complete representation of the mini-batches.
The BN transform in the inference step thus becomes
$y^{(k)} = BN_{\gamma^{(k)},\beta^{(k)}}^{\text{inf}}\left(x^{(k)}\right) = \gamma^{(k)}\,\frac{x^{(k)} - E\left[x^{(k)}\right]}{\sqrt{\operatorname{Var}\left[x^{(k)}\right] + \epsilon}} + \beta^{(k)},$

where $y^{(k)}$ is passed on to future layers instead of $x^{(k)}$. Since the parameters are fixed in this transformation, the batch normalization procedure essentially applies a linear transform to the activation.
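Deep learning libraries typically approximate these population statistics with running averages collected during training and then freeze them for inference. The sketch below illustrates this common implementation pattern; the momentum-based update and all names are choices made for this example rather than part of the formulation above.

```python
import numpy as np

class BatchNorm1D:
    """Minimal batch normalization layer with separate training and inference behavior."""

    def __init__(self, d, eps=1e-5, momentum=0.1):
        self.gamma, self.beta = np.ones(d), np.zeros(d)
        self.eps, self.momentum = eps, momentum
        self.running_mean = np.zeros(d)   # estimate of the population mean E[x]
        self.running_var = np.ones(d)     # estimate of the population variance Var[x]

    def forward(self, x, training=True):
        if training:
            mu, var, m = x.mean(axis=0), x.var(axis=0), x.shape[0]
            # Exponential moving averages of the population statistics;
            # the factor m/(m-1) makes the variance estimate unbiased.
            self.running_mean = (1 - self.momentum) * self.running_mean + self.momentum * mu
            self.running_var = (1 - self.momentum) * self.running_var \
                + self.momentum * var * m / (m - 1)
            x_hat = (x - mu) / np.sqrt(var + self.eps)
        else:
            # Inference: a fixed, deterministic (linear) transform of the input.
            x_hat = (x - self.running_mean) / np.sqrt(self.running_var + self.eps)
        return self.gamma * x_hat + self.beta

# Training-mode calls update the running statistics; inference reuses them.
rng = np.random.default_rng(2)
bn = BatchNorm1D(d=3)
for _ in range(500):
    bn.forward(rng.normal(loc=5.0, scale=2.0, size=(32, 3)), training=True)
print(bn.running_mean)          # close to the population mean, 5
print(np.sqrt(bn.running_var))  # close to the population standard deviation, 2
print(bn.forward(rng.normal(loc=5.0, scale=2.0, size=(4, 3)), training=False))
```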
Although batch normalization has become popular due to its strong empirical performance, the working mechanism of the method is not yet well understood. The explanation made in the original paper was that batch norm works by reducing internal covariate shift, but this has been challenged by more recent work. One experiment[2] trained a VGG-16 network[5] under three different training regimes: standard (no batch norm), batch norm, and batch norm with noise added to each layer during training. In the third model, the noise has non-zero mean and non-unit variance, i.e. it explicitly introduces covariate shift. Despite this, the third model achieved accuracy similar to the second, and both performed better than the first, suggesting that covariate shift is not the reason that batch norm improves performance.
Using batch normalization causes the items in a batch to no longer be iid, which can lead to difficulties in training due to lower quality gradient estimation.[6]
One alternative explanation is that the improvement with batch normalization is instead due to it producing a smoother parameter space and smoother gradients, as formalized by a smaller Lipschitz constant.
Consider two identical networks, one containing batch normalization layers and the other not; the behaviors of these two networks are then compared. Denote the loss functions as $\hat{L}$ and $L$, respectively. Let the input to both networks be $x$, and the output be $y$, for which $y = Wx$, where $W$ is the layer weights. For the second network, $y$ additionally goes through a batch normalization layer. Denote the normalized activation as $\hat{y}$, which has zero mean and unit variance. Let the transformed activation be $z = \gamma\hat{y} + \beta$, and suppose $\gamma$ and $\beta$ are constants. Finally, denote the standard deviation of the activations over a mini-batch $\hat{y}_j\in\mathbb{R}^m$ as $\sigma_j$.
First, it can be shown that the gradient magnitude of a batch normalized network, $\left\|\nabla_{y_j}\hat{L}\right\|$, is bounded, with the bound expressed as

$\left\|\nabla_{y_j}\hat{L}\right\|^2 \le \frac{\gamma^2}{\sigma_j^2}\left(\left\|\nabla_{y_j}L\right\|^2 - \frac{1}{m}\left\langle \mathbf{1},\, \nabla_{y_j}L\right\rangle^2 - \frac{1}{m}\left\langle \nabla_{y_j}L,\, \hat{y}_j\right\rangle^2\right).$
Since the gradient magnitude represents the Lipschitzness of the loss, this relationship indicates that a batch normalized network could achieve greater Lipschitzness comparatively. Notice that the bound gets tighter when the gradient $\nabla_{y_j}\hat{L}$ correlates with the activation $\hat{y}_j$, and that the scaling by $\frac{\gamma^2}{\sigma_j^2}$ is significant whenever the mini-batch variance $\sigma_j^2$ is large.
Secondly, the quadratic form of the loss Hessian with respect to the activation in the gradient direction can be bounded as

$\left(\nabla_{y_j}\hat{L}\right)^{T}\frac{\partial\hat{L}}{\partial y_j\,\partial y_j}\left(\nabla_{y_j}\hat{L}\right) \le \frac{\gamma^2}{\sigma_j^2}\left(\frac{\partial\hat{L}}{\partial y_j}\right)^{T}\frac{\partial L}{\partial y_j\,\partial y_j}\left(\frac{\partial\hat{L}}{\partial y_j}\right) - \frac{\gamma}{m\sigma_j^2}\left\langle \nabla_{y_j}\hat{L},\, \hat{y}_j\right\rangle\left\|\frac{\partial\hat{L}}{\partial y_j}\right\|^2.$

The scaling by $\frac{\gamma^2}{\sigma_j^2}$ indicates that the loss Hessian is resilient to the mini-batch variance, whereas the second term on the right-hand side suggests that the landscape becomes smoother when the Hessian and the inner product are non-negative, which is the case when the loss is locally convex and the gradient points towards the minimum. It can thus be concluded that the gradient generally becomes more predictive with the batch normalization layer.
These bounds, stated for the loss with respect to the normalized activations, can then be translated into a bound on the loss with respect to the network weights:

$\hat{g}_j \le \frac{\gamma^2}{\sigma_j^2}\left(g_j^2 - \lambda^2\left\langle \nabla_{y_j}L,\, \hat{y}_j\right\rangle^2\right),$

where $g_j = \max_{\|X\|\le\lambda}\left\|\nabla_W L\right\|^2$ and $\hat{g}_j = \max_{\|X\|\le\lambda}\left\|\nabla_W \hat{L}\right\|^2$.
In addition to the smoother landscape, it is further shown that batch normalization could result in a better initialization with the following inequality:
$\left\|W_0 - \hat{W}^*\right\|^2 \le \left\|W_0 - W^*\right\|^2 - \frac{1}{\left\|W^*\right\|^2}\left(\left\|W^*\right\|^2 - \left\langle W^*,\, W_0\right\rangle\right)^2,$

where $W^*$ and $\hat{W}^*$ are the local optimal weights for the network without and with batch normalization, respectively, and $W_0$ is the common initialization.
Some scholars argue that the above analysis cannot fully capture the performance of batch normalization, because the proof only concerns the largest eigenvalue, or equivalently, one direction in the landscape at all points. It is suggested that the complete eigenspectrum needs to be taken into account to make a conclusive analysis.
Since it is hypothesized that batch normalization layers could reduce internal covariate shift, an experiment is set up to measure quantitatively how much covariate shift is reduced. First, the notion of internal covariate shift needs to be defined mathematically. Specifically, to quantify the adjustment that a layer's parameters make in response to updates in previous layers, the correlation between the gradients of the loss before and after all previous layers are updated is measured, since gradients could capture the shifts from the first-order training method. If the shift introduced by the changes in previous layers is small, then the correlation between the gradients would be close to 1.
The correlation between the gradients is computed for four models: a standard VGG network,[5] a VGG network with batch normalization layers, a 25-layer deep linear network (DLN) trained with full-batch gradient descent, and a DLN network with batch normalization layers. Interestingly, it is shown that the standard VGG and DLN models both have higher correlations of gradients compared with their counterparts, indicating that the additional batch normalization layers are not reducing internal covariate shift.
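The quantity being measured can be sketched on a toy model. The example below (a two-layer linear network with a squared loss; names and sizes are chosen only for illustration and this is not the protocol of the cited experiment) computes the cosine similarity between a layer's gradient before and after the preceding layer has been updated.

```python
import numpy as np

rng = np.random.default_rng(3)
d, h, n = 10, 8, 256
X = rng.normal(size=(n, d))
y = rng.normal(size=(n, 1))
W1 = 0.3 * rng.normal(size=(d, h))   # preceding layer
W2 = 0.3 * rng.normal(size=(h, 1))   # layer whose gradient is tracked

def grads(W1, W2):
    """Gradients of the mean squared error 0.5 * mean((X W1 W2 - y)^2)."""
    H = X @ W1
    err = H @ W2 - y
    return X.T @ (err @ W2.T) / n, H.T @ err / n   # (dW1, dW2)

def cosine(a, b):
    return float(np.sum(a * b) / (np.linalg.norm(a) * np.linalg.norm(b)))

dW1, dW2_before = grads(W1, W2)          # gradient of W2 before the update of W1 ...
lr = 0.1
_, dW2_after = grads(W1 - lr * dW1, W2)  # ... and after W1 has taken a gradient step
print(cosine(dW2_before, dW2_after))     # near 1 means the update induced little shift
```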
Even though batchnorm was originally introduced to alleviate gradient vanishing or explosion problems, a deep batchnorm network in fact suffers from gradient explosion at initialization time, no matter what it uses for nonlinearity. Thus the optimization landscape is very far from smooth for a randomly initialized, deep batchnorm network. More precisely, if the network has $L$ layers, then the gradient of the first layer weights has norm greater than $c\lambda^{L}$ for some $\lambda > 1$ and $c > 0$ depending only on the nonlinearity. For any fixed nonlinearity, $\lambda$ decreases as the batch size increases; for ReLU, for example, $\lambda$ decreases to $\pi/(\pi-1) \approx 1.467$ as the batch size tends to infinity.
This gradient explosion on the surface contradicts the smoothness property explained in the previous section, but in fact they are consistent. The previous section studies the effect of inserting a single batchnorm in a network, while the gradient explosion depends on stacking batchnorms typical of modern deep neural networks.
Another possible reason for the success of batch normalization is that it decouples the length and direction of the weight vectors and thus facilitates better training.
By interpreting batch norm as a reparametrization of weight space, it can be shown that the length and the direction of the weights are separated and can thus be trained separately. For a particular neural network unit with input $x$ and weight vector $w$, denote its output as $f(w) = E_x\left[\phi\left(x^{T}w\right)\right]$, where $\phi$ is the activation function, and denote $S = E\left[xx^{T}\right]$. Assume that $E[x] = 0$ and that the spectrum of the matrix $S$ is bounded as $0 < \mu = \lambda_{\min}(S)$ and $L = \lambda_{\max}(S) < \infty$, so that $S$ is symmetric positive definite. Adding batch normalization to this unit then results in

$f_{BN}(w,\gamma,\beta) = E_x\left[\phi\left(BN\left(x^{T}w\right)\right)\right] = E_x\left[\phi\left(\gamma\,\frac{x^{T}w - E_x\left[x^{T}w\right]}{\sqrt{\operatorname{var}_x\left[x^{T}w\right]}} + \beta\right)\right].$
The variance term can be simplified such that $\operatorname{var}_x\left[x^{T}w\right] = w^{T}Sw$, since $x$ has zero mean, and $\beta$ can be omitted, so it follows that

$f_{BN}(w,\gamma) = E_x\left[\phi\left(\gamma\,\frac{x^{T}w}{\sqrt{w^{T}Sw}}\right)\right],$

where $\left(w^{T}Sw\right)^{1/2}$ is the norm of $w$ induced by $S$, denoted $\left\|w\right\|_S$.

Hence, it can be concluded that

$f_{BN}(w,\gamma) = E_x\left[\phi\left(x^{T}\tilde{w}\right)\right], \qquad \tilde{w} = \gamma\,\frac{w}{\left\|w\right\|_S},$

where $\gamma$ and $w$ account for the length and the direction of the reparametrized weight $\tilde{w}$, respectively, and can thus be trained separately.
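The reparametrization can be verified numerically: over a zero-mean sample, the batch-normalized pre-activation (with $\beta = 0$) coincides with $x^{T}\tilde{w}$ when $\tilde{w} = \gamma\,w/\|w\|_{\hat{S}}$ is formed with the empirical second-moment matrix $\hat{S}$. The sketch below is illustrative.

```python
import numpy as np

rng = np.random.default_rng(4)
d, n = 5, 10_000
A = rng.normal(size=(d, d))
x = rng.normal(size=(n, d)) @ A.T      # anisotropic inputs
x -= x.mean(axis=0)                    # enforce the zero-mean assumption in the sample

w = rng.normal(size=d)
gamma = 1.7

# Batch-normalized pre-activation with beta = 0, using batch statistics
z = x @ w
bn = gamma * (z - z.mean()) / z.std()

# Reparametrized weight w_tilde = gamma * w / ||w||_S with the empirical S
S_hat = x.T @ x / n
w_tilde = gamma * w / np.sqrt(w @ S_hat @ w)

print(np.max(np.abs(bn - x @ w_tilde)))  # essentially zero: the two formulations agree
```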
With the reparametrization interpretation, it could then be proved that applying batch normalization to the ordinary least squares problem achieves a linear convergence rate in gradient descent, which is faster than the regular gradient descent with only sub-linear convergence.
Denote the objective of minimizing an ordinary least squares problem as

$\min_{\tilde{w}\in\mathbb{R}^d} f_{OLS}\left(\tilde{w}\right) = \min_{\tilde{w}\in\mathbb{R}^d} E_{x,y}\left[\left(y - x^{T}\tilde{w}\right)^2\right] = \min_{\tilde{w}\in\mathbb{R}^d}\left(2u^{T}\tilde{w} + \tilde{w}^{T}S\tilde{w}\right),$

where $u = E[-yx]$, $S = E\left[xx^{T}\right]$, and the additive constant $E\left[y^2\right]$ is dropped.
Since $\tilde{w} = \gamma\,\frac{w}{\left\|w\right\|_S}$, which implies $\tilde{w}^{T}S\tilde{w} = \gamma^2$, the objective becomes

$\min_{w\in\mathbb{R}^d\setminus\{0\},\,\gamma\in\mathbb{R}} f_{OLS}(w,\gamma) = \min_{w\in\mathbb{R}^d\setminus\{0\},\,\gamma\in\mathbb{R}}\left(2\gamma\,\frac{u^{T}w}{\left\|w\right\|_S} + \gamma^2\right).$
Since the objective is convex with respect to $\gamma$, its optimal value can be calculated by setting the partial derivative of the objective with respect to $\gamma$ to zero, which gives $\gamma = -\frac{u^{T}w}{\left\|w\right\|_S}$. The objective thus simplifies to

$\min_{w\in\mathbb{R}^d\setminus\{0\}} \rho(w) = \min_{w\in\mathbb{R}^d\setminus\{0\}}\left(-\frac{\left(u^{T}w\right)^2}{w^{T}Sw}\right).$

Note that this objective is a form of the generalized Rayleigh quotient

$\tilde{\rho}(w) = \frac{w^{T}Bw}{w^{T}Aw},$

where $B\in\mathbb{R}^{d\times d}$ is a symmetric matrix and $A\in\mathbb{R}^{d\times d}$ is a symmetric positive definite matrix.
It is proven that the gradient descent convergence rate of the generalized Rayleigh quotient is

$\frac{\lambda_1 - \rho\left(w_{t+1}\right)}{\rho\left(w_{t+1}\right) - \lambda_2} \le \left(1 - \frac{\lambda_1 - \lambda_2}{\lambda_1 - \lambda_{\min}}\right)^{2t}\,\frac{\lambda_1 - \rho\left(w_{t}\right)}{\rho\left(w_{t}\right) - \lambda_2},$

where $\lambda_1$ is the largest eigenvalue of $B$, $\lambda_2$ is the second largest eigenvalue of $B$, and $\lambda_{\min}$ is the smallest eigenvalue of $B$.
In our case, $B = uu^{T}$ is a rank-one matrix, and the convergence result can be simplified accordingly. Specifically, consider gradient descent steps of the form $w_{t+1} = w_t - \eta_t\nabla\rho\left(w_t\right)$ with a suitably chosen step size $\eta_t$, starting from $\rho\left(w_0\right)\ne 0$; then

$\rho\left(w_t\right) - \rho\left(w^*\right) \le \left(1 - \frac{\mu}{L}\right)^{2t}\left(\rho\left(w_0\right) - \rho\left(w^*\right)\right).$
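The linear rate can be observed on a small synthetic instance. The sketch below runs gradient descent on $\rho(w)$, with the optimal $\gamma$ recovered in closed form at the end; the matrix, step size and iteration count are hand-chosen for illustration rather than taken from the cited analysis.

```python
import numpy as np

rng = np.random.default_rng(5)
d = 5
Q, _ = np.linalg.qr(rng.normal(size=(d, d)))
S = Q @ np.diag(np.linspace(1.0, 5.0, d)) @ Q.T   # SPD second-moment matrix, condition number 5
u = rng.normal(size=d)
u /= np.linalg.norm(u)

rho_star = -u @ np.linalg.solve(S, u)             # optimum, attained at w proportional to S^{-1} u

def rho(w):
    return -(u @ w) ** 2 / (w @ S @ w)

def grad_rho(w):
    a, q = u @ w, w @ S @ w
    return -2 * a / q * u + 2 * a**2 / q**2 * (S @ w)

w = rng.normal(size=d)
eta = 0.5 / np.linalg.eigvalsh(S)[-1]             # conservative hand-chosen step size
for t in range(500):
    w = w - eta * grad_rho(w)
    w /= np.sqrt(w @ S @ w)                       # rho is scale-invariant, so rescaling is harmless
    if (t + 1) % 100 == 0:
        print(t + 1, rho(w) - rho_star)           # suboptimality shrinks geometrically

gamma = -u @ w / np.sqrt(w @ S @ w)               # optimal length for the final direction
w_tilde = gamma * w / np.sqrt(w @ S @ w)
print(np.linalg.norm(w_tilde + np.linalg.solve(S, u)))  # distance to the OLS solution -S^{-1} u
```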
The problem of learning halfspaces refers to the training of the Perceptron, which is the simplest form of neural network. The optimization problem in this case is
$\min_{\tilde{w}\in\mathbb{R}^d} f_{LH}\left(\tilde{w}\right) = E_{y,x}\left[\phi\left(z^{T}\tilde{w}\right)\right],$

where $z = -yx$ and $\phi$ is the loss function.

Suppose that $\phi$ is infinitely differentiable and has a bounded derivative. Assume that the objective function $f_{LH}$ is $\zeta$-smooth, and that a solution $\alpha^* = \arg\min_\alpha \left\|\nabla f(\alpha w)\right\|^2$ exists and is bounded such that $-\infty < \alpha^* < \infty$. Also assume that $z$ is a multivariate normal random variable. With the Gaussian assumption, it can be shown that all critical points of $f_{LH}$ lie on the same line, for any choice of loss function $\phi$. Specifically, the gradient of $f_{LH}$ can be represented as

$\nabla_{\tilde{w}} f_{LH}\left(\tilde{w}\right) = c_1\left(\tilde{w}\right)u + c_2\left(\tilde{w}\right)S\tilde{w},$

where

$c_1\left(\tilde{w}\right) = E_z\left[\phi^{(1)}\left(z^{T}\tilde{w}\right)\right] - E_z\left[\phi^{(2)}\left(z^{T}\tilde{w}\right)\right]\left(u^{T}\tilde{w}\right), \qquad c_2\left(\tilde{w}\right) = E_z\left[\phi^{(2)}\left(z^{T}\tilde{w}\right)\right],$

and $\phi^{(i)}$ denotes the $i$-th derivative of $\phi$.

By setting the gradient to zero, it follows that the bounded critical points $\tilde{w}^*$ can be expressed as

$\tilde{w}^* = g_*\,S^{-1}u,$

where $g_*$ is a scalar that depends on $\tilde{w}^*$ and $\phi$. Combining this global property with length-direction decoupling, it can then be proved that this optimization problem converges linearly.
First, a variation of gradient descent with batch normalization, Gradient Descent in Normalized Parameterization (GDNP), is designed for the objective function
$\min_{w\in\mathbb{R}^d\setminus\{0\},\,\gamma\in\mathbb{R}} f_{LH}(w,\gamma),$

such that the direction and the length of the weights are updated separately.

Denote the stopping criterion of GDNP as $h\left(w_t,\gamma_t\right)$, a scalar function of expectations of the loss derivatives at the current iterate, and let the step size be $s_t = s\left(w_t,\gamma_t\right)$.

For each step, if $h\left(w_t,\gamma_t\right)\ne 0$, then update the direction as

$w_{t+1} = w_t - s_t\,\nabla_w f\left(w_t,\gamma_t\right).$

Then update the length according to

$\gamma_t = \mathrm{Bisection}\left(T_s, f, w_t\right),$

where $\mathrm{Bisection}$ is a one-dimensional search that optimizes the length for the current direction by the bisection method, and $T_s$ is the total number of iterations in the bisection step.

Denote the total number of direction iterations as $T_d$; the final output of GDNP is then

$\tilde{w}_{T_d} = \gamma_{T_d}\,\frac{w_{T_d}}{\left\|w_{T_d}\right\|_S}.$
The GDNP algorithm thus slightly modifies the batch normalization step for the ease of mathematical analysis.
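The structure of GDNP, a gradient step on the direction followed by a bisection search on the length, can be sketched on a Monte Carlo version of the halfspace objective. Everything concrete below (the softplus loss, the Gaussian sample standing in for $z$, the backtracking line search used in place of the step-size schedule, and the bisection bracket) is an assumption of this example rather than part of the cited analysis.

```python
import numpy as np

rng = np.random.default_rng(6)
d, n = 4, 50_000
z = rng.normal(size=(n, d)) + rng.normal(size=d)   # Monte Carlo sample standing in for z = -yx
S = z.T @ z / n                                    # empirical second-moment matrix
u = z.mean(axis=0)                                 # empirical mean of z

phi = lambda t: np.logaddexp(0.0, t)               # softplus loss: smooth, bounded derivative
dphi = lambda t: 1.0 / (1.0 + np.exp(-t))          # its derivative (the sigmoid)

def f(w, gamma):
    return phi(z @ (gamma * w / np.sqrt(w @ S @ w))).mean()

def grad_w(w, gamma):
    """Gradient of the empirical objective with respect to the direction parameter w."""
    norm = np.sqrt(w @ S @ w)
    s = z @ w / norm
    coeff = gamma * dphi(gamma * s)
    return (coeff[:, None] * (z / norm - np.outer(s, S @ w) / norm**2)).mean(axis=0)

def dgamma(w, gamma):
    s = z @ (w / np.sqrt(w @ S @ w))
    return (dphi(gamma * s) * s).mean()            # partial derivative of f with respect to gamma

def bisection(w, lo=-20.0, hi=20.0, steps=40):
    """Length step: f is convex in gamma, so locate the zero of its gamma-derivative."""
    for _ in range(steps):                         # assumes the bracket [lo, hi] contains the zero
        mid = 0.5 * (lo + hi)
        lo, hi = (lo, mid) if dgamma(w, mid) > 0 else (mid, hi)
    return 0.5 * (lo + hi)

w, gamma = rng.normal(size=d), 1.0
for t in range(200):
    gamma = bisection(w)                           # optimize the length for the current direction
    g = grad_w(w, gamma)
    step = 1.0                                     # backtracking line search on the direction step
    while f(w - step * g, gamma) > f(w, gamma) - 1e-4 * step * (g @ g):
        step *= 0.5
    w = (w - step * g) / np.linalg.norm(w - step * g)   # f is scale-invariant in w

# All bounded critical points lie on the line spanned by S^{-1} u
w_dir = w / np.linalg.norm(w)
ref = np.linalg.solve(S, u)
ref /= np.linalg.norm(ref)
print(f(w, gamma), abs(w_dir @ ref))               # the cosine should approach 1 as GDNP converges
```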
It can be shown that in GDNP the partial derivative of $f_{LH}$ with respect to the length component converges to zero at a linear rate: the squared derivative $\left(\partial_\gamma f_{LH}\left(w_t, a_t^{(T_s)}\right)\right)^2$ is bounded by a quantity proportional to $\left(b_t^{(0)} - a_t^{(0)}\right)^2/\mu^2$ that shrinks geometrically in the number of bisection steps $T_s$, where $a_t^{(0)}$ and $b_t^{(0)}$ are the initial left and right endpoints of the bisection search.

Further, for each iteration, the norm of the gradient of $f_{LH}$ with respect to the direction $w$ converges linearly, such that

$\left\|w_t\right\|_S^2\,\left\|\nabla_w f_{LH}\left(w_t,\gamma_t\right)\right\|_{S^{-1}}^2 \le \left(1 - \frac{\mu}{L}\right)^{2t}\,\Phi,$

where $\Phi$ is a quantity that depends on the initial suboptimality $\rho\left(w_0\right) - \rho\left(w^*\right)$.
Combining these two inequalities, a bound could thus be obtained for the gradient with respect to
$\tilde{w}_{T_d}$:

$\left\|\nabla_{\tilde{w}} f\left(\tilde{w}_{T_d}\right)\right\|^2 \le \left(1 - \frac{\mu}{L}\right)^{2T_d}\,\Phi,$

so that the algorithm is guaranteed to converge at a linear rate.
Although the proof stands on the assumption of Gaussian input, it is also shown in experiments that GDNP could accelerate optimization without this constraint.
Consider a multilayer perceptron (MLP) with one hidden layer and $m$ hidden units, with the mapping from input $x\in\mathbb{R}^d$ to a scalar output described as

$F_x\left(\tilde{W},\Theta\right) = \sum_{i=1}^{m}\theta_i\,\phi\left(x^{T}\tilde{w}^{(i)}\right),$

where $\tilde{w}^{(i)}$ and $\theta_i$ are the input and output weights of unit $i$, respectively, and $\phi$ is the activation function, assumed to be a tanh function.
The input and output weights could then be optimized with
$\min_{\tilde{W},\Theta}\left(f_{NN}\left(\tilde{W},\Theta\right) = E_{y,x}\left[l\left(-yF_x\left(\tilde{W},\Theta\right)\right)\right]\right),$

where $l$ is a loss function, $\tilde{W} = \left\{\tilde{w}^{(1)},\ldots,\tilde{w}^{(m)}\right\}$, and $\Theta = \left\{\theta^{(1)},\ldots,\theta^{(m)}\right\}$.
Consider fixed $\Theta$ and optimizing only $\tilde{W}$; it can be shown that the critical points of $f_{NN}\left(\tilde{W}\right)$ for a particular hidden unit $i$, denoted $\hat{w}^{(i)}$, all align along one line depending on the incoming information into the hidden layer, such that

$\hat{w}^{(i)} = \hat{c}^{(i)}\,S^{-1}u,$

where $\hat{c}^{(i)}\in\mathbb{R}$ is a scalar, for $i = 1,\ldots,m$.
This result can be proved by setting the gradient of $f_{NN}$ to zero and solving the resulting system of equations.
Apply the GDNP algorithm to this optimization problem by alternating optimization over the different hidden units. Specifically, for each hidden unit, run GDNP to find the optimal
$W$ and $\gamma$. With the same choice of stopping criterion and step size as before, it follows that the gradient with respect to each unit's weights, $\left\|\nabla_{\tilde{w}^{(i)}} f\right\|^2$, converges to zero at a linear rate.
Since the parameters of each hidden unit converge linearly, the whole optimization problem has a linear rate of convergence.