MNIST: Handwritten digit classification in Stata

Posted on February 08, 2018

To a large extent, the prominent role of neural networks in machine learning is due to their successful applications in image recognition and classification. Neural networks are particularly effective at modeling complex structures such as the shapes, forms, and textures formed by image pixels. The hierarchical nature of these structures explains the success of multilevel network models such as convolutional neural networks.

Although image analysis is far beyond the traditional application areas of Stata, I find it interesting to try some basic image classification problems in Stata. In this post, I demonstrate an application of the mlp2 command for classifying handwritten digits from the MNIST database [1].

The mlp2 command can be installed by typing the following in Stata:

. net install mlp2, from("http://www.stata.com/users/nbalov")

To see the help file, type

. help mlp2

See the MNIST post for a description of the dataset and the mlp2 post for a description of the model and the command.
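
The dataset stores each 28x28 image as 784 pixel-intensity variables v1-v784, rescaled to [0,1], together with the digit label y. As a quick sanity check of this layout, standard Stata commands can be used (this snippet is mine, not part of mlp2):

. use http://www.stata.com/users/nbalov/datasets/mnist-train, clear
. quietly summarize v1
. assert r(min) >= 0 & r(max) <= 1   // pixel intensities lie in [0,1]
. tabulate y                         // the outcome takes the ten values 0-9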

Model specification and training

. use http://www.stata.com/users/nbalov/datasets/mnist-train
(Training MNIST: 28x28 images with stacked pixel values v* in [0,1] range)
. set seed 12345
. mlp2 fit y v*, layer1(100) layer2(100)

------------------------------------------------------------------------------
Multilayer perceptron                              input variables =      784
                                                   layer1 neurons  =      100
                                                   layer2 neurons  =      100
                                                   output levels   =       10

Optimizer: sgd                                     batch size      =       50
                                                   max epochs      =      100
                                                   loss tolerance  =  1.0e-04
                                                   learning rate   =  1.0e-01

Training ended:                                    epochs          =       21
                                                   start loss      = 0.287316
                                                   end loss        = 0.000616
------------------------------------------------------------------------------

I then perform in-sample prediction using the mlp2 predict command to evaluate the fit of the model. The genvar(ypred) option provides a stub for the new variables holding the predicted class probabilities. The command thus generates variables ypred_0 through ypred_9 whose probability values sum to 1. The digit with the maximum probability is the predicted one, and the prediction accuracy is the proportion of correctly predicted digits in the sample.

. mlp2 predict, genvar(ypred)

Prediction accuracy: .9999833333333333
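
To make this computation concrete, here is a short sketch (in do-file form) that recovers the predicted digit from the generated probabilities and recomputes the accuracy by hand. The rowmax construction is my own, not part of mlp2, and assumes y is coded 0-9:

// predicted digit = class with the largest predicted probability
egen double pmax = rowmax(ypred_*)
generate byte yhat = .
forvalues d = 0/9 {
    replace yhat = `d' if ypred_`d' == pmax
}
// the mean of this indicator reproduces the reported accuracy
generate byte correct = (yhat == y)
summarize correct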

The reported in-sample prediction accuracy of about 1 shows that the model fits the training data almost perfectly. This alone, however, is not sufficient evidence of a good model: a model this complex can easily overfit even large training samples. The important problem of designing and training efficient models without overfitting is beyond the scope of this post.

Validation

. use http://www.stata.com/users/nbalov/datasets/mnist-test, clear
(Testing MNIST: 28x28 images with stacked pixel values v* in [0,1] range)
. mlp2 predict, genvar(ypred)

Prediction accuracy: .9742

The reported test prediction accuracy is about 0.97, so the test error is less than 3%. This is a reassuring result, in line with the performance of other similar classification models; see, for example, the 3-layer NN with 500+150 hidden units listed in the classification table at http://yann.lecun.com/exdb/mnist. The best-performing classifiers achieve less than 0.5% error but employ substantially more elaborate models than the one presented here.
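
Beyond the overall accuracy, a confusion matrix shows which digits the model tends to confuse. Rebuilding yhat on the test data exactly as in the earlier sketch (again my own construction, in do-file form):

egen double pmax = rowmax(ypred_*)
generate byte yhat = .
forvalues d = 0/9 {
    replace yhat = `d' if ypred_`d' == pmax
}
// rows are the true digits, columns the predicted ones;
// off-diagonal cells count the misclassifications
tabulate y yhat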

It is instructive to look at some of the test digits that are misclassified. The indices of the misclassified records are saved in the matrix e(pred_err_ind); some of them are 116, 152, 246, and 322.
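
These indices can be pulled out directly from the saved matrix. A minimal sketch, assuming e(pred_err_ind) stores one entry per misclassified record:

. matrix err = e(pred_err_ind)    // saved by mlp2 predict
. display err[1,1]                // the first stored index

As seen from the predicted class probability vector below, test image 116 depicts the digit 4 but is classified as 9.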

. list y ypred* in 116

     +--------------------------------------------------------------------+
116. | y |  ypred_0 |  ypred_1 |  ypred_2 |  ypred_3 | ypred_4 |  ypred_5 |
     | 4 | 1.22e-08 | 4.05e-08 | 2.46e-06 | 1.51e-07 | .121335 | 8.60e-07 |
     |--------------------------------------------------------------------|
     |     ypred_6    |     ypred_7    |     ypred_8     |    ypred_9     |
     |    .0007171    |    5.11e-08    |    .0001033     |    .877841     |
     +--------------------------------------------------------------------+

By looking at test image 116 (Figure 3), we indeed see a resemblance with the digit 9, so it is not surprising that the algorithm has been fooled. Similarly, you can check that test image 152 depicts the digit 9 but is classified as 8, test image 246 depicts the digit 3 but is classified as 6, and test image 322 depicts the digit 2 but is classified as 7. I show these test images in Figure 3.

Figure 3: The incorrectly classified test images 116, 152, 246, and 322.
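
For completeness, here is one way such an image can be rendered in Stata. This sketch is my own construction and assumes, per the dataset labels, that v1-v784 stack the 28x28 pixels row by row from the top:

. preserve
. keep in 116                        // the misclassified "4"
. generate id = 1
. reshape long v, i(id) j(pix)       // one observation per pixel
. generate row = 29 - ceil(pix/28)   // flip vertically so the digit is upright
. generate col = mod(pix - 1, 28) + 1
. twoway contour v row col, levels(15)
. restore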

References

  1. Y. LeCun, L. Bottou, Y. Bengio, and P. Haffner (1998). Gradient-Based Learning Applied to Document Recognition. Proceedings of the IEEE, 86(11), 2278-2324.