Why MNIST is the worst thing that has ever happened to humanity
All right, all right, it’s not that bad. But still.
Everyone who has approached the magical world of data science, including me, has at some point gone through the same initiation ritual: fire up a Jupyter notebook, import PyTorch or Keras (pick your poison), implement a simple CNN and train it on the almighty MNIST dataset for a few epochs. And then rejoice at the sight of those sweet, sweet figures next to the words “validation accuracy” in the output cell.
And that’s OK. I mean, we have to start somewhere and, jokes aside, handwritten digit classification is a good place to start:
- It’s a well-defined problem, meaning that it’s very easy to understand what the goal of the task is and therefore how to evaluate it.
- It takes just a few minutes to train a neural network that reaches very high accuracy. Even with a very simple architecture (something like LeNet-5, or even simpler) you’ll probably get >90% accuracy on the validation set.
For these reasons, and possibly others, it lets you focus on getting the basics of neural networks right: the actual implementation of the network, the choice of hyperparameters, how to split the dataset and so on.
So what’s my problem with this?
A fundamental, and yet too often overlooked, aspect of machine learning models is that, assuming nothing fishy happened during training (or to the data before that moment), you can expect them to perform as well as they did on the test set as long as you use them on data coming from the same distribution as the data they were trained on. And this is true less often than one might think.
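To make the point concrete, here is a toy illustration. It is not MNIST and not any library’s API, just a deliberately trivial one-dimensional “classifier” whose single threshold is learned from one distribution. It is then evaluated on fresh data from that same distribution and on data whose inputs have all been shifted by a constant, a crude stand-in for images acquired in a different way. The function names, class means and shift are arbitrary choices made up for this example.

```python
import random
import statistics


def sample(n, mean0, mean1, rng):
    """Draw n labelled points per class from two 1-D Gaussians with std 1."""
    data = [(rng.gauss(mean0, 1.0), 0) for _ in range(n)]
    data += [(rng.gauss(mean1, 1.0), 1) for _ in range(n)]
    return data


def fit_threshold(data):
    """'Train' the simplest possible classifier: a threshold halfway between the class means."""
    m0 = statistics.mean(x for x, y in data if y == 0)
    m1 = statistics.mean(x for x, y in data if y == 1)
    return (m0 + m1) / 2


def accuracy(threshold, data):
    """Fraction of points where 'x > threshold' agrees with label 1."""
    correct = sum((x > threshold) == bool(y) for x, y in data)
    return correct / len(data)


rng = random.Random(0)  # fixed seed so the run is reproducible
train = sample(2000, 0.0, 3.0, rng)
threshold = fit_threshold(train)

# Test set drawn from the SAME distribution as the training data.
same_dist = sample(2000, 0.0, 3.0, rng)
# Test set where every input is shifted by +2 (hypothetical systematic change in acquisition).
shifted = sample(2000, 2.0, 5.0, rng)

acc_same = accuracy(threshold, same_dist)
acc_shifted = accuracy(threshold, shifted)
print(f"same distribution: {acc_same:.3f}")
print(f"shifted inputs:    {acc_shifted:.3f}")
```

The model is identical in both evaluations; only the data changed, and the accuracy on the shifted set drops sharply.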
What would you expect from a model trained on MNIST? A reasonable assumption would be that you can use it to classify handwritten digits with decent accuracy, right?
One problem is that, while the overall shape of each digit is well defined (a zero is almost always drawn as a rough circle, an eight as two small circles stacked on top of each other, etc.), people draw digits with minor variations. Take the number one, for instance: a single vertical segment is a valid way to write it, but so is a vertical segment with a small flag on the left, starting from the top.
The MNIST dataset, however, contains almost exclusively ones of the first kind, so if you ask your model what the second kind is, it will probably answer “it’s a 7”. This is probably because the sharp angle between the two strokes at the top of that second style is an important feature your model uses to tell digits apart, and in this dataset that feature is characteristic of the digit “7”, not of the digit “1”.
It is particularly disappointing, for instance, to try a model trained on MNIST on printed rather than handwritten digits: most fonts draw digits slightly differently, and that alone is enough to make predictions unreliable, especially for some classes.
Noiseless = bad
Another issue with the infamous dataset that makes the world a darker place™ is that all images are essentially noiseless: no “stains” of any kind, no poorly cropped borders, no marks that are not part of the actual digit.
The problem is that in reality this is quite an unlikely scenario: images are acquired by taking a photo with a camera or by scanning a piece of paper with actual ink on it, and there might be smears, lines printed on the paper itself and whatnot. No matter how much data cleaning one does, some of these defects are likely to end up in the input image fed to the classifier, which has never seen anything like them and will therefore struggle to classify such images correctly.
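One cheap way to find out how fragile a model is, before someone’s phone does it for you, is to evaluate it on artificially corrupted copies of the test images. The sketch below uses a hypothetical helper, `add_defects`, that works on a 28 x 28 image stored as a list of lists of 0-255 intensities (MNIST convention: 0 is background, higher values are ink) and adds three kinds of defects mentioned above: random specks, a faint ruled line and a solid strip along one border, like a badly cropped scan. The defect types and parameters are arbitrary choices for illustration, not a standard benchmark.

```python
import random

Image = list[list[int]]  # 28x28 grid of 0-255 intensities, 0 = background


def add_defects(image: Image, rng: random.Random, n_specks: int = 15,
                line_intensity: int = 120, border_width: int = 2) -> Image:
    """Return a corrupted copy of `image`; the input is left untouched."""
    size = len(image)
    noisy = [row[:] for row in image]

    # 1. Random specks of dirt or stray ink.
    for _ in range(n_specks):
        r, c = rng.randrange(size), rng.randrange(size)
        noisy[r][c] = max(noisy[r][c], rng.randint(80, 255))

    # 2. A faint horizontal ruled line, as on lined notebook paper.
    row = rng.randrange(size)
    noisy[row] = [max(v, line_intensity) for v in noisy[row]]

    # 3. A solid strip along the left edge, like a poorly cropped border.
    for r in range(size):
        for c in range(border_width):
            noisy[r][c] = 255
    return noisy


# Usage: a crude "1" (a vertical stroke) on a blank background.
digit = [[0] * 28 for _ in range(28)]
for r in range(4, 24):
    digit[r][13] = 255
noisy = add_defects(digit, random.Random(42))
```

You would then compare your model’s accuracy on the clean test set and on its corrupted copy; the model itself is not shown here.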
But… but… it worked on my notebook!!!
My personal issue with MNIST (and most tutorials that use it) is that it gives the people who use it (likely beginner or student-level aspiring data scientists) the unrealistic impression that they have solved a problem. I think it’s hard not to feel that way at the sight of an extremely high accuracy score after a few epochs of training with almost any network on default settings. But what they have actually done is train a classifier that correctly identifies digits belonging to the MNIST dataset, and probably not much more. If they wrote a digit on a piece of paper, took a picture of it with their phone and fed it to the model, there is a good chance the result would be wrong, for any of the reasons explained above.
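Part of the gap is simply format. According to the dataset’s original web page, MNIST digits were size-normalized to fit a 20 x 20 box (preserving aspect ratio) and then centered by center of mass in a 28 x 28 field, with the ink stored as high values on a zero background. A phone picture is none of those things. Below is a minimal pure-Python sketch, with a hypothetical `to_mnist_format` function, that brings a grayscale photo (dark ink on light paper, as a list of lists of 0-255 values) into roughly that format. It is a simplification: it uses nearest-neighbor resizing instead of anti-aliasing, and a fixed threshold that assumes reasonably even lighting.

```python
Image = list[list[int]]  # rows of 0-255 grayscale values


def to_mnist_format(photo: Image, threshold: int = 50,
                    box: int = 20, field: int = 28) -> Image:
    # 1. Invert (ink becomes bright) and zero out faint background such as paper texture.
    inv = [[255 - v if 255 - v > threshold else 0 for v in row] for row in photo]

    # 2. Crop to the bounding box of the remaining ink.
    rows = [r for r, row in enumerate(inv) if any(row)]
    if not rows:
        raise ValueError("no ink found above the threshold")
    cols = [c for c in range(len(inv[0])) if any(row[c] for row in inv)]
    crop = [row[cols[0]:cols[-1] + 1] for row in inv[rows[0]:rows[-1] + 1]]
    h, w = len(crop), len(crop[0])

    # 3. Nearest-neighbor resize so the longer side is `box` pixels (aspect ratio kept).
    scale = box / max(h, w)
    nh, nw = max(1, round(h * scale)), max(1, round(w * scale))
    small = [[crop[min(h - 1, int(r / scale))][min(w - 1, int(c / scale))]
              for c in range(nw)] for r in range(nh)]

    # 4. Center of mass of the resized digit (fallback: geometric center of the box).
    total = sum(map(sum, small))
    if total:
        cy = sum(r * v for r, row in enumerate(small) for v in row) / total
        cx = sum(c * v for row in small for c, v in enumerate(row)) / total
    else:
        cy, cx = (nh - 1) / 2, (nw - 1) / 2

    # 5. Paste into an empty field so the center of mass lands at the field's center.
    center = (field - 1) / 2
    dr, dc = round(center - cy), round(center - cx)
    out = [[0] * field for _ in range(field)]
    for r in range(nh):
        for c in range(nw):
            if 0 <= r + dr < field and 0 <= c + dc < field:
                out[r + dr][c + dc] = small[r][c]
    return out
```

Even a perfectly preprocessed photo can still be misclassified, for the reasons discussed above; this only removes the most trivial mismatch.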
Things are, unfortunately, a bit more complicated than that, and I think it would be nice to be very explicit about it instead of sweeping the dirt under the carpet.
I think it can still be used as an exercise for the “mechanical” part of implementing a machine learning pipeline, which is, after all, a non-trivial task for a complete beginner in the field, but there should be more transparency about what kind of results one will obtain by using this dataset.
Personally, I would encourage people to use a different dataset in image classification tutorials. Something like Fashion MNIST, available from https://github.com/zalandoresearch/fashion-mnist (and also included in keras.datasets), is basically the same as MNIST, except that the images show different categories of clothes and shoes instead of digits.
While it has similar characteristics (28 x 28 pixel grayscale images, perfectly isolated from the background, no noise, etc.), I think it doesn’t trigger the same over-optimistic feeling about how the model would perform in a real-life scenario, simply because its images are clearly different from the ones one can imagine acquiring with a camera.
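If you are following a Keras tutorial, the switch is usually a one-line change, since both loaders return `((x_train, y_train), (x_test, y_test))` with 60,000 training and 10,000 test 28 x 28 grayscale images. A minimal sketch, assuming TensorFlow’s bundled Keras is installed: the `load_dataset` wrapper is hypothetical, while `keras.datasets.mnist.load_data` and `keras.datasets.fashion_mnist.load_data` are the real loaders. The class names are listed in label order, as given in the Fashion-MNIST README.

```python
# Fashion-MNIST class names, indexed by integer label (0-9).
FASHION_MNIST_CLASSES = (
    "T-shirt/top", "Trouser", "Pullover", "Dress", "Coat",
    "Sandal", "Shirt", "Sneaker", "Bag", "Ankle boot",
)


def load_dataset(name: str = "fashion_mnist"):
    """Hypothetical wrapper: return ((x_train, y_train), (x_test, y_test)).

    TensorFlow is imported lazily, so defining this function does not
    require it; calling it does (and downloads the data on first use).
    """
    from tensorflow import keras

    loaders = {
        "mnist": keras.datasets.mnist,
        "fashion_mnist": keras.datasets.fashion_mnist,
    }
    return loaders[name].load_data()


# Usage (needs TensorFlow and network access the first time):
# (x_train, y_train), (x_test, y_test) = load_dataset("fashion_mnist")
# print(FASHION_MNIST_CLASSES[y_train[0]])
```

The rest of a typical MNIST tutorial (model, training loop, evaluation) can stay exactly as it is.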
Another viable option is the CIFAR-10 dataset, where at least performance won’t be so good that one gets a distorted idea of the real difficulty of this class of problems. Plus, one of the classes is 🐱, which is an added value as far as I’m concerned.
Stay in school, don’t do MNIST. But if you do, don’t be fooled by overwhelmingly good performance.