When I first came across batch normalization, I hated it. It looked like an arbitrary hack. And worse, I had a tough time finding everything I needed to know to see how it worked. But since then, I've come to see that it's a wildly effective method, and that it drops tantalizing hints at deeper machine learning concepts. Here's the tutorial I wish I had found at the outset.

Batch normalization is an element-by-element shift (adding a constant) and scale (multiplying by a constant) so that the mean of each element's values is zero and the variance of each element's values is one within a batch. It's typically inserted before the non-linearity layer in a neural network. It works pretty well, and we're still trying to find out why.

If you collect the activity of a single node across many iterations, you can represent that node's activity as a distribution. That distribution can drift far from zero. It can take on odd shapes. Normalization can mean different things, but in this case it means making it as close as possible to the normal distribution, that is, giving it a mean of zero and a variance of one. Step one is to subtract the mean to shift the distribution. Step two is to divide all the shifted values by their standard deviation, the square root of the variance. The end result may not have a nice normal bell shape, but at least it will have zero mean and unit variance.

How many observations of a node's activity should we include when normalizing? There's a trade-off here. Including too few observations gives us noisy estimates of the mean and variance, making the normalization step a source of erratic changes rather than a source of stability. (For example, all the histograms below were drawn from the same distribution.) On the other hand, all the parameters in the model are being updated at every time step, so if we wait too long, we will in effect be straddling different models, and the node's meaning may have changed entirely over time.

Trial and error has shown that a convenient middle ground is the size of one batch: the number of parallel feedforward runs that take place during a single time step. Batching has become commonplace as neural networks are most often trained on specialized hardware, like graphics processing units, or GPUs, that compute the required matrix multiplications in parallel at blazing speeds. Driven by the interplay between model size and GPU capacity, batch sizes tend to run between 2^6 and 2^10, that is, between 64 and 1,024. These have been shown to work well for normalization. They give estimates of mean and variance that are accurate enough to help, but that still change rapidly enough to keep up with any changes the model makes.

This batch-wise normalization of each node's activities ensures that it will appear to produce a set of zero-mean, unit-variance outputs to everything downstream of it in the network. A concise description of this is given by Ioffe and Szegedy in their 2015 paper introducing the method: each output y_i equals the input x_i minus μ, the input mean, divided by σ, the input's standard deviation, that is, y_i = (x_i − μ) / σ. This occurs element by element. Within each batch, the activities of each element are separately shifted and scaled so that they have zero mean and unit variance within the batch. The mean and variance of one node's activities will be independent from those of another. There's also an optional scale and shift that can be learned to further improve performance, known as an affine transformation.
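If it helps to see this in code, here is a minimal numpy sketch of the computation just described. The function name and the small epsilon constant are my own choices for illustration, not notation from the paper.

```python
import numpy as np

def batch_normalize(x, gamma=1.0, beta=0.0, epsilon=1e-5):
    """Shift and scale each element (column) of a batch so that it has
    zero mean and unit variance within the batch.

    x: activities with shape (batch_size, n_elements), one row per example.
    gamma, beta: the optional learned scale and shift (the affine transformation).
    epsilon: a small constant to avoid dividing by zero.
    """
    mu = x.mean(axis=0)                        # per-element mean over the batch
    var = x.var(axis=0)                        # per-element variance over the batch
    x_hat = (x - mu) / np.sqrt(var + epsilon)  # zero mean, unit variance
    return gamma * x_hat + beta                # optional affine transformation

# A batch of 256 examples of an 8-element layer whose activities
# have drifted to a mean of 3 and a standard deviation of 5.
rng = np.random.default_rng(0)
activities = 3.0 + 5.0 * rng.normal(size=(256, 8))
normalized = batch_normalize(activities)
print(normalized.mean(axis=0))  # each element's mean is close to 0
print(normalized.std(axis=0))   # each element's standard deviation is close to 1
```

This sketch only covers training-time behavior; real implementations also keep running estimates of the mean and variance for use at inference time.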
The shift parameter is represented by β and the scale parameter by γ in the full definition of batch normalization. The affine transformation is a popular option; it's the default in TensorFlow and PyTorch. An informal survey of practitioners on Twitter and LinkedIn shows that they tend to use batch normalization with γ and β enabled about three times out of four. They've been shown to improve accuracy on image classification benchmarks by a few percent, and they seem to sparsify feature activations. Affine transformations are definitely helpful, but we'll set them aside for now. They're the spinning rims of batch normalization: cool, but not essential to its function. They could just as easily be separated out into their own layer. There's no reason to bundle them with batch normalization. The most interesting part of what batch normalization does, it does without them.

Although batch normalization is usually used to compute a separate mean and variance for every element, when it follows a convolution layer it works a little differently. Instead, it computes just one mean and variance for each channel, not a separate one for every pixel of every channel. This is because, thanks to convolution, every pixel value within a channel is generated by the same convolutional kernel. They're all part of the same pool of activities that result from that kernel, and so they can reasonably be lumped together into the same activity distribution.

So how does batch normalization really work? Realistically, the first question to ask here is: does it matter? Why should you care about how it works, as long as it brings the loss down? If your focus is on climbing a leaderboard, then it probably doesn't matter, and you can skip this section. But if you'd like to build a deeper intuition for what's going on, and possibly introduce some new customizations to your toolbox, then understanding why batch normalization is so effective can help you out.

Internal Covariate Shift

This phrase, taken from the title of Ioffe and Szegedy's original batch normalization paper, is described in that paper as the change in nodes' activation distributions due to changes in network weights. The authors give several examples of mechanisms that can cause those distributions to wander far from normal, and postulate ways in which this may hurt the network's performance or make it difficult to train. The non-linearity functions that lie at the heart of neural networks all do their most interesting work in the neighborhood of zero, plus or minus one or two. If a node's activity distribution climbs too far away from that, it can escape the reach of the backpropagation training signal. It's said to enter a gradient plateau.

In a wonderfully informative post, David C. Page illustrates this shift in nodes' distributions and how batch normalization corrects for it. An accompanying Twitter thread gives additional context. He also investigates the curvature, sometimes called the smoothness, of the loss function and finds that the very directions in which the loss function is least smooth are the directions that get adjusted by batch normalization. In other words, batch normalization tends to smooth out the loss function. This allows for more aggressive learning rates and shorter training runs. A subsequent paper demonstrates that if you carefully construct your example, you can in fact get all of the benefits of batch normalization while intentionally inducing distributional shifts in nodes' activities.
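As a quick aside, here is a small numpy illustration of that gradient plateau (my own sketch, not taken from Page's post or the papers above): shift a node's activities away from zero and the derivative of a tanh non-linearity all but vanishes; normalize them and it comes back.

```python
import numpy as np

def tanh_grad(x):
    """Derivative of tanh: 1 - tanh(x)^2. Largest near zero, tiny far from it."""
    return 1.0 - np.tanh(x) ** 2

rng = np.random.default_rng(1)
# A node whose activity distribution has wandered off to a mean of 4.
drifted = rng.normal(loc=4.0, scale=2.0, size=10_000)
# The same activities after batch normalization.
normalized = (drifted - drifted.mean()) / drifted.std()

print(tanh_grad(drifted).mean())     # small: most samples sit on a gradient plateau
print(tanh_grad(normalized).mean())  # much larger: samples stay where tanh is responsive
```

Nothing about the drifted distribution changed except its position, but most of the training signal through the tanh evaporated.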
That paper's result suggests that stabilizing nodes' activity distributions is a side effect of whatever good thing batch normalization does, rather than the root cause. If you care to descend down the rabbit hole, there's a Twitter thread with some back and forth on the precise definition of internal covariate shift. The conversation gets bogged down a bit in fine interpretations of terminology, but my takeaway is that internal covariate shift is a broad category and doesn't necessarily mean a shift in a node's activity distribution away from normal.

This lack of clarity is widely manifest. A very non-scientific poll of the community on Twitter and LinkedIn suggests that there's little consensus around whether internal covariate shift or gradient smoothing is the cause of batch normalization's success. As far as I can tell, there's one point that everyone agrees on: batch normalization smooths the loss landscape. Exactly how it does this, and precisely how this generates all the resulting benefits, is still a topic of investigation. If you're interested in diving deeper, I recommend starting with the paper linked in the notes. For a broader survey of possible explanations and their implications, the Wikipedia entry gives a good starting place.

Batch normalization is a fascinating example of a method molding itself to the physical constraints of the hardware. The method of processing data in batches co-evolved with the use of GPUs. GPUs are made of lots of parallel processors, so breaking the training job up into parallel batches makes perfect sense as a trick for speeding it up. Once data was in batches, averaging the gradient across the batch was a natural thing to try. Batch normalization followed after.

But what if you're not using a GPU? Not everybody can afford one, and not every problem demands one. This limitation of batch normalization has been addressed in a series of advances, beginning with batch renormalization, moving on to streaming normalization, and culminating in online normalization, a highly performant variant of batch normalization that works beautifully even with a batch size of one. It's probably not a coincidence that the paper was published by a team from Cerebras Systems, a startup that's building new neural network chips from scratch, which are decidedly not GPUs. I've also implemented online normalization in Cottonwood, the CPU-centric machine learning framework I'm building. If you're writing your own neural networks from scratch and running them on your laptop, you can still get the benefits of batch normalization without having to port your code to cloud GPUs or go through a heavyweight framework.

Now here's the cool bit. What made me fall in love with batch normalization is its adaptive signal processing. Here are a few examples.

Auto-scaling of pixel values. Image pixel values come in a variety of ranges: sometimes 0 to 1, sometimes 0 to 255, sometimes something else if there's been some fancy pre-processing. If you'd rather not go look it up, verify it in your data, and trust that it never drifts or decays, you can just slip a batch normalization layer right after the input layer of your network and it's all taken care of for you.

Multimodal inputs. When you have inputs from several different types of sensors or data streams, each is likely to have its own range and distribution. This can be tough for neural networks to handle. A batch normalization layer brings them all into a similar distribution, for apples-to-apples comparison and computation.
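Here is a small sketch of those first two cases, using made-up example values: one input element carrying raw pixel values in the 0-to-255 range and another carrying a sensor reading in the thousandths, both brought onto the same footing by the same normalization step.

```python
import numpy as np

rng = np.random.default_rng(2)
batch_size = 256

# Two input streams with wildly different ranges (made-up example values):
pixels = rng.integers(0, 256, size=(batch_size, 1)).astype(float)  # roughly 0 to 255
sensor = 0.001 * rng.normal(size=(batch_size, 1)) + 0.003          # tiny values near 0.003

inputs = np.hstack([pixels, sensor])

# The same element-wise normalization as before, applied right after the input layer.
normalized = (inputs - inputs.mean(axis=0)) / (inputs.std(axis=0) + 1e-5)

print(inputs.std(axis=0))      # very different scales, roughly 74 and 0.001
print(normalized.std(axis=0))  # both close to 1: apples-to-apples for the next layer
```

No lookup tables, no hand-tuned scaling constants: the layer measures what it gets and adapts. The same adaptability shows up in a few more situations.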
Constant inputs. Occasionally one of the inputs to a model is constant, whether due to some quirk of data selection or a broken sensor somewhere. Many models can learn to ignore such an input if it provides no predictive value, but that's not guaranteed. A batch normalization layer will adapt to a constant input element, reducing it to 0.

Two-level inputs. For an input element that splits its time between two distinct values, a low and a high, batch normalization performs the convenient service of making those values plus and minus 1, whatever they were originally.

Rarely active elements. For inputs that spend most of their time at a baseline level but occasionally deviate from it, batch normalization helpfully shifts that baseline to 0 and magnifies the deviation. This amplifies the underlying information carried by that input element while attenuating the constant background level. And all of this without having to specify any additional thresholds, constants, or hyperparameters. This can also happen deep within a network, say at the bottleneck layer of an autoencoder, or at the far end of a network near the categorical output layer. Some deeper representations or categories may occur infrequently. In this case, batch normalization helps the network out by shouting extra loud when they do occur.

Drifting elements. When working in applications, the distributions of inputs can change slowly over time. When these drifts are distractions, rather than the phenomena of interest, a batch normalization layer will remove the shifting baselines automatically and preserve the fluctuations.

This last property is exactly what makes batch normalization so useful at all levels of the network. By its nature, a deep neural network is changing throughout training. The distribution of every element is drifting. A batch normalization layer before every linear and convolutional layer grants them some stability. It allows them to train more effectively, since their inputs will always fall within the same region of their non-linearity function, and it makes them less sensitive to interactions with changes in weights elsewhere in the network.

Taken together, these do a little bit of taming of the wildly unpredictable animals that deep neural networks are. Batch normalization makes networks more shareable, more easily reusable. A neural network is a non-linear processor, so seemingly benign things like input scaling and offset make a difference. Used well, batch normalization can get you a bit closer to your goals. Best of luck.