However... for some reason, when reading the book, there was one section that made me pause, and the more I thought about it, the more I came to a different conclusion than the author. It was, as foreshadowed by the title, the section on L2 Regularization and neural network "simplicity".
In it, he essentially makes the claim that L2 Regularization results in a simpler "model". He spends a fair amount of time discussing simplicity in a larger sense, and examples where simpler explanations are or are not more correct... but never really makes a convincing argument for why the smaller-weight models favored by L2 regularization should be considered "simpler".
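(For concreteness, here's roughly what I mean by L2 regularization throughout this post: a penalty proportional to the sum of the squared weights gets added to the ordinary training loss. This is just a minimal sketch - the names data_loss, weights, and lmbda are placeholders of my own, not anything from the book's code:)

# Minimal sketch of an L2-regularized loss: the data loss plus the sum of
# squared weights, scaled by a regularization strength lmbda.
def regularized_loss(data_loss, weights, lmbda):
    l2_penalty = lmbda * sum(w ** 2 for w in weights)
    return data_loss + l2_penalty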
He DOES make some good points about why it might favor models that generalize better, vs. just memorizing noise... and then implies that this therefore makes them "simpler". The main justification here is an analogy to a situation where you have some noisy data and can fit it with either a linear approximation or a polynomial. Intuitively, the linear model is both simpler AND generalizes better, but I don't know that the two things - simplicity and generalization - always go hand in hand, and I would argue that in the case of neural networks and L2 regularization, they don't.
To see why, let's consider one of the ways in which his linear vs. polynomial comparison differs from our regularized vs. unregularized comparison: the number of variables. His polynomial model essentially has 10 different variables, while his linear model has only one, the slope (or two, if you count the offset, though in his pictured example it's 0). So, another way of looking at it might be to say that the simpler model is the one with fewer variables.
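(If you want to play with that comparison yourself, here's a rough sketch of the kind of thing he's describing - the toy data and the specific degrees are my own invention, not the book's:)

# Rough sketch of the linear-vs-polynomial comparison: fit noisy, roughly
# linear points with a line (2 coefficients) and with a 9th-degree polynomial
# (10 coefficients, enough to chase the noise exactly).
import numpy as np

rng = np.random.default_rng(0)
x = np.linspace(0, 1, 10)
y = 2 * x + rng.normal(scale=0.1, size=x.size)   # noisy line

line_coeffs = np.polyfit(x, y, deg=1)    # slope and offset
poly_coeffs = np.polyfit(x, y, deg=9)    # 10 coefficients

x_new = 1.5                               # a point outside the training range
print(np.polyval(line_coeffs, x_new))     # extrapolates sensibly
print(np.polyval(poly_coeffs, x_new))     # typically wildly off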
Ah, you say, what relevance does that have to our regularized vs. unregularized comparison? Don't both of those have the same number of parameters? Technically, yes, that's true... but consider this: regularization is something that helps a network perform better when it's overfitting... that is, when its number of parameters is large relative to the number of training examples. So, say we have a situation where regularization is helping; in that case, it's likely that if we take the unregularized version and simply increase the size of the network (but keep the training set the same size), we'll see relatively small gains in real-world performance... but if we do the same with the regularized one, we might expect to see a bigger impact. That implies that regularized networks are making "better use of" their parameters... that is, that even though the two networks technically have the same number of parameters, the unregularized one is carrying more "useless" parameters... and I think that's exactly what's happening.
To see why I think L2 regularization helps avoid "useless" parameters, let's take things down to more concrete terms. On the most basic level, suppose we have a set of 4 weights, and consider two possible distributions of those weights, A and B:
A = [.03, .9, .02, .05]
B = [.3, .4, .1, .2]
...then regularization will strongly favor B over A. But without any context, simply looking at the weights, I think most people would say that A is simpler than B - A is effectively saying "the second input is so much more important than the other inputs that we can effectively ignore the rest of them" - and that sounds a whole lot simpler than the approach B is taking, which is effectively to say, "while the second IS more important, it's still important to consider all the others as well!" Without regularization, there's nothing to prevent this from going to extremes, as long as it happens to fit the training data better - i.e., the unregularized result, A, might end up looking like [.0001, .99999, 1e-10, .001] - which can pretty much be modeled by a 1-parameter system, and is fairly "simple" - even though B might only give a 2% worse result on the training data, while still using all 4 parameters.
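(You can check that claim directly against the squared-weight penalty sketched above - both A and B sum to 1, but A's penalty is well over twice B's, so the regularization term pushes toward B:)

# Quick check: the L2 penalty (sum of squared weights) is much larger for the
# "spiky" distribution A than for the more even distribution B.
A = [.03, .9, .02, .05]
B = [.3, .4, .1, .2]
print(sum(w ** 2 for w in A))   # ~0.814
print(sum(w ** 2 for w in B))   # ~0.30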
To put things in a different perspective, let's look at the handwriting-recognition problem. Say we happen to notice that ALL the "9"s in our training sample have a value > .5 in a given pixel. As time goes on, without regularization, our neural network will tend to HEAVILY weight the input from that pixel when deciding if something is a 9, which effectively decreases the importance of other pixels, or of larger patterns. The regularized approach, on the other hand, will in effect be saying, "ok, even though that one pixel seems more important on this data set, I don't want to forget the contributions of all the other pixels" - so that when we feed the network a 9 that is < .5 in that pixel, it is able to cope better. That is a more nuanced approach, and to my mind at least, a more complex one.
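(Mechanically, that's because the L2 term adds a small "pull toward zero" to every weight on every gradient step - a weight only stays large if the training data keeps pushing it back up. Something like the sketch below, where eta, lmbda, and grad are just illustrative names for the learning rate, regularization strength, and data-loss gradient for one weight:)

# Sketch of a single gradient-descent step with the L2 penalty included.
def update_weight(w, grad, eta, lmbda):
    # lmbda * w is the derivative of the penalty term lmbda * w**2 / 2; it
    # shrinks every weight toward zero, so only weights the data keeps
    # "defending" stay large.
    return w - eta * (grad + lmbda * w)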
Finally, I would argue that, for most problems we want to use machine learning for, Occam's razor is reversed - the simpler solution is LESS likely to be correct! Indeed, the whole field of machine learning can be thought of as having been birthed by the desire to find more complex solutions - i.e., by problems for which we can't find any simple models. The problems are so complex that, intuitively, I'm inclined to think the more correct model is also likely more complex*... so, since regularized models tend to give better results on these problems, I'm more inclined to believe they're the more complex ones!
Now, I know that a lot of these arguments are pretty complex and hand-wavey... but to me, they feel closer to the truth of what's happening here... and, I suppose, the real point of all this is that I think it gave me a better intuition for how L2 regularization is likely working!
*I think this heuristic - that machine-learning problems are so complex that the more correct model is also likely more complex - will often hold because the "ground truth" for many of these problems is what a human would say - i.e., our basis for comparison is the model mapped in the neurons in our brains, which are incredibly complex. Of course, there are counterexamples - handwritten digit recognition is largely solved, for instance, with relatively small networks, so the heuristic sort of fails there. But the standard MNIST handwriting-recognition problem is also one with a lot of constraints and preconditions that make it a lot easier to solve - we're presupposing that the images we're fed ARE digits, they're frequently segmented already, we're only considering digits (and not letters, and capital letters, and punctuation), we don't have to find them within larger images, etc. The more of those preconditions are eliminated, the closer the task gets to what our brains are actually doing, and the more complex the problem gets... and the more I will believe that the more correct network is the more complex one.