Gating and Depth in Neural Networks

Depth is a critical part of modern neural networks. It enables efficient representations by composing hierarchical rules. By now we all know this, so I’ll assume I don’t need to convince anyone, but in case you need a refresher: many data distributions that appear in the wild cannot be modeled efficiently with a single function, or a few, without an exponential number of neurons. The distributions are simply too complex to model in such a direct way. Despite this they do have structure; it just so happens to be hierarchical. To think of it another way, imagine the data-generating process were not hierarchical. Then generating a complex distribution would take exponential resources as well. No doubt there are processes like this. That said, one thing we know about the world is that it is built by composing simpler parts, which put together can produce extremely complicated behavior.

Unfortunately our network training algorithm, error backpropagation, doesn’t like it when things recurse too much. Consider a chain of linear neurons. If the (norm of the) weights of those neurons is not equal to one, the error signal will successively shrink or grow as it passes through each neuron. If it shrinks too fast we say the gradient vanishes, and if it grows too fast we say it explodes. The picture is roughly the same for neurons with nonlinear activation functions, except that we think about the norms of the Jacobians instead of the weights. The scaling is not cool, because when we propagate errors backwards from the loss the signal gets way out of whack. The updates to the weights are then useless: training either stalls (vanishing) or diverges (exploding).
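
To make the linear-chain case concrete, here is a minimal sketch. It assumes scalar neurons that all share one weight, which is a simplification of my own; the function name `backprop_scale` is made up for illustration.

```python
def backprop_scale(weight, depth):
    """Factor by which an error signal is scaled after passing backwards
    through `depth` scalar linear neurons y = weight * x."""
    grad = 1.0
    for _ in range(depth):
        # Chain rule: each linear neuron multiplies the signal by its weight.
        grad *= weight
    return grad


# Weights slightly below, at, and slightly above one, over 50 layers.
for w in (0.9, 1.0, 1.1):
    print(f"weight={w}: gradient scale after 50 layers = {backprop_scale(w, 50):.4g}")
```

Only the weight of exactly one leaves the signal untouched; the other two shrink or grow geometrically with depth.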

Great, we’re all familiar with vanishing and exploding gradients. What do we do about it? The basic strategy is to make sure all that scaling doesn’t mess with the error signal while still enabling depth. We need to protect some of that signal.

In feed-forward networks (FFNs) we can accomplish this with skip connections. The simplest case is residual learning, where a layer’s input (the output of the layer below) is added to that layer’s output. Successive layers then only have to learn an increment, since they are already “starting” from where the previous layer was. This residual is argued to be easier to learn than building a new transformation from scratch.
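
A rough sketch of why the identity path protects the signal, again assuming scalar linear layers (my simplification, with hypothetical function names). In a plain chain each layer contributes a factor `w` to the gradient; in a residual chain `y = x + w * x` each layer contributes `1 + w`, so small weights no longer wipe the signal out.

```python
def plain_grad(weights):
    """d(output)/d(input) for stacked scalar layers y = w * x."""
    grad = 1.0
    for w in weights:
        grad *= w
    return grad


def residual_grad(weights):
    """d(output)/d(input) for stacked scalar residual layers y = x + w * x."""
    grad = 1.0
    for w in weights:
        # The identity path adds 1 to each layer's local derivative.
        grad *= 1.0 + w
    return grad


# Twenty layers with small weights (an illustrative choice, not from any paper).
weights = [0.01] * 20
print("plain:   ", plain_grad(weights))
print("residual:", residual_grad(weights))
```

With these weights the plain gradient is vanishingly small while the residual one stays close to one. Note that the residual form does not rule out explosion if the increments are large.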

Another reason why this works is that residual connections enable a shorter path from a given layer to the output. All you need to do is travel through the residual connections of successive layers until you get to the end, skipping the layers in between. As it happens, deep residual networks have an effective depth that is significantly smaller than the specified depth, because this is exactly what happens in training. In other words, we have avoided a lot of unfavorable transformations. This happens implicitly in residual nets, but it can be made explicit by connecting every layer directly to the layers above it, which gives a densely connected network.
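
One way to picture the shorter paths, as a sketch under a simplifying assumption: treat each residual block as a binary choice between taking the skip connection and going through the layer. Then `n` blocks give `2**n` paths, and the number of paths that pass through exactly `k` layers is the binomial coefficient. The function name is made up for illustration.

```python
from math import comb


def path_length_counts(num_blocks):
    """counts[k] = number of input-to-output paths that pass through
    exactly k of the `num_blocks` residual layers."""
    return [comb(num_blocks, k) for k in range(num_blocks + 1)]


counts = path_length_counts(10)
print(counts)
# Fraction of all paths that use at most half of the layers.
print(sum(counts[:6]) / sum(counts))
```

Only one of the paths uses every layer, so most of the routes available to the signal are much shallower than the nominal depth.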

By the way, it is possible that starting from the previous layer’s output actually makes it more difficult to learn the best transformation to come next. Or at least, it could just be an inefficient use of many layers. So we introduce the concept of a gate: let’s learn a coefficient that controls how much the identity connection is used versus the normal stacked layer. Residual networks then become special cases of networks with weighted skip connections, called highway networks.
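
A minimal sketch of a gated layer, assuming a scalar input, a `tanh` transform, and gates passed in as plain numbers (in a real highway layer they would be learned, typically through a sigmoid). The name `gated_layer` and the default of tying the carry gate to `1 - t` are my assumptions for illustration.

```python
import math


def gated_layer(x, w, t, c=None):
    """y = t * H(x) + c * x, where H(x) = tanh(w * x).

    t is the transform gate and c the carry gate, both in [0, 1].
    If c is not given it is tied to the transform gate as 1 - t.
    """
    if c is None:
        c = 1.0 - t
    h = math.tanh(w * x)
    return t * h + c * x


x, w = 0.5, 2.0
print(gated_layer(x, w, t=0.0))         # gate closed: pure identity
print(gated_layer(x, w, t=1.0))         # gate open: plain stacked layer
print(gated_layer(x, w, t=1.0, c=1.0))  # both paths fully on: residual layer
```

Setting both gates to one recovers the residual form `x + H(x)`, which is the sense in which the residual layer is a special case here.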

This is all very interesting because in a sense we have implemented a memory for FFNs, and with gating mechanisms we can learn when to access that memory and when to ignore it. You’ve created some representation at a lower layer, and that signal can skip many layers and turn up somewhere else. The (learned) gate on the destination layer tells you whether to access that information or not.

Let us think about a different problem now: sequence modeling. FFNs are great for this of course, and sequences are a nice way to justify depth. If you are using a convolutional FFN (as everyone does nowadays), you actually require a certain level of depth to take into account long-range correlations in your input. This is especially obvious in language modeling, where long-range correlations are potentially really long. Skip connections, and also things like attention, then act as memory lookup mechanisms that serve more fine-grained information to the higher-level representations. It’s all very cool.

There is another fun way to look at this too, and that is with recurrent networks (RNNs). In that case, we have a single hidden layer that is applied to every element in a sequence, and then after the final element you backpropagate your error “in time”. So as you can imagine, our gradients can and do vanish by the time they get back to the beginning of the sequence. However, since we use the same weights at each sequence step, adding a skip connection doesn’t make sense here.

What the RNN people have done instead is to put in other little memory systems using a gating mechanism. Now you can protect some part of your signal by writing it to a special state somewhere and learning when to read it back. So what does effective depth mean for an RNN? I’m not sure it is really comparable, beyond thinking of depth conceptually as running in the sequence direction, but at any rate it is an interesting analogy.
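
Here is a toy sketch of such a protected state, loosely in the spirit of an LSTM-style cell. The gates are fixed by hand, whereas in a real gated RNN they would be learned functions of the input and hidden state; the function name and the values are assumptions for illustration.

```python
def run_memory_cell(inputs, write_gates, forget_gates, cell=0.0):
    """Update a scalar memory over a sequence:
    cell_t = forget_t * cell_{t-1} + write_t * input_t."""
    for x, write, forget in zip(inputs, write_gates, forget_gates):
        cell = forget * cell + write * x
    return cell


steps = 100
inputs = [1.0] + [0.3] * (steps - 1)       # first value matters, the rest are distractors
write_gates = [1.0] + [0.0] * (steps - 1)  # write once, then stop writing
forget_gates = [1.0] * steps               # never forget

print(run_memory_cell(inputs, write_gates, forget_gates))
```

Because the cell is carried forward with a factor of one at every step, the stored value (and the gradient with respect to it) arrives at the end of the sequence unscaled, which is the same trick as the identity path in a residual layer.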

It is actually pretty funny to think that we haven’t solved the vanishing or exploding gradients problem at all; we have just sidestepped it. So now we can ask whether this is natural or weird. I happen to feel that it is still weird that we can pull hierarchical representations out of thin air so well. On that note, it is interesting how effective intermediate losses are, particularly in NLP applications, where this kind of supervision is available because various parsers can generate features for you to predict (e.g., part-of-speech tags, named entities).

It is pretty obvious why this works, and it’s related to skip connections. Simply put, you are anchoring your representations to something besides the final layer’s output. In the residual FFN case the anchor is the lower layer’s output. An intermediate loss is no different, in that it is also an additional signal you know a priori to be a good representation to some degree (just like your inputs!).

It does seem for the moment that vanishing and exploding gradients are a fundamental, though not insurmountable, problem.
