Handwritten Digit Recognition with KNN in Python (MNIST Dataset, 90%+ Accuracy)

In this beginner-friendly machine learning tutorial, you’ll learn how to recognize handwritten digits using Python and the MNIST dataset. We’ll use the K-Nearest Neighbors (KNN) algorithm, one of the simplest classifiers that still works well on this task, to predict digits with over 90% accuracy. This project is a great way to practice working with real-world image data, model training, and accuracy evaluation, all in under 50 lines of code!

Let’s dive in!

1. Importing Essential Python Libraries

To build our handwritten digit recognition model, we first import the necessary Python libraries. Each of these plays a crucial role in different stages of the machine learning workflow:

  • Scikit-learn: It is one of the most popular and beginner-friendly machine learning libraries in Python. It provides a wide range of tools for classification, regression, clustering, model evaluation, and data preprocessing.
  • TensorFlow (Keras Datasets Module): It is a powerful open-source deep learning framework developed by Google. While we’re not building a neural network in this project, we use the mnist dataset from its Keras Datasets module.
  • Matplotlib: It is a well-known plotting library in Python. It allows us to create visualizations, such as displaying images and drawing graphs. In this project, we use matplotlib.pyplot to visualize digit images and show the predicted vs. actual labels.

2. Load the MNIST Dataset

We load the data through the mnist module in TensorFlow’s Keras Datasets. MNIST contains 70,000 grayscale images of handwritten digits (0–9), split into 60,000 training images and 10,000 test images. Each image is 28×28 pixels.
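As a rough sketch of this step, the snippet below wraps the Keras loader in a small function and adds a helper for inspecting whatever arrays come back. The names `load_mnist` and `summarize` are illustrative, not part of any library, and the TensorFlow import sits inside the function so that defining it does not trigger a download.

```python
import numpy as np


def load_mnist():
    """Return (x_train, y_train), (x_test, y_test) as NumPy arrays."""
    # Imported here so the module can be loaded without TensorFlow installed.
    from tensorflow.keras.datasets import mnist
    return mnist.load_data()


def summarize(images, labels):
    """Collect basic facts about an image array and its label array."""
    return {
        "n_samples": images.shape[0],
        "image_shape": images.shape[1:],
        "classes": np.unique(labels).tolist(),
    }
```

Calling `(x_train, y_train), (x_test, y_test) = load_mnist()` and then `summarize(x_train, y_train)` should report 60,000 samples of shape (28, 28) and the ten classes 0 through 9.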

3. Flatten the Image Data

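Scikit-learn’s KNN classifier expects each sample to be a one-dimensional feature vector, so the full input is a 2D array with one row per image. We reshape each 28×28 image into a 784-length vector and divide the pixel values (originally 0 to 255) by 255 to normalize them to the range [0, 1].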
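One way to write this step is sketched below. The function name `flatten_and_normalize` is made up for illustration, and it assumes the images arrive as an array of shape (n, 28, 28) with integer pixel values from 0 to 255, as the Keras loader provides them.

```python
import numpy as np


def flatten_and_normalize(images):
    """Turn (n, 28, 28) integer images into (n, 784) floats in [0, 1]."""
    n_samples = images.shape[0]
    # -1 lets NumPy work out the row length (28 * 28 = 784).
    flat = images.reshape(n_samples, -1).astype("float32")
    # Pixel values run from 0 to 255, so dividing by 255 maps them to [0, 1].
    return flat / 255.0
```

Scaling does not change which pixels are bright or dark; it only keeps the feature values in a small, uniform range.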

4. Use a Subset (Optional but Recommended)

KNN is a distance-based algorithm: every prediction compares the new sample against the stored training samples, so it can be slow on large datasets. We select a smaller subset (10,000 training samples and 1,000 test samples) for faster processing.
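A minimal sketch of the subsetting step follows. The article does not say how the subset is chosen, so taking the first n samples is an assumption here, and `take_subset` is a hypothetical helper name.

```python
import numpy as np


def take_subset(x, y, n):
    """Keep the first n samples and their labels (assumed selection rule)."""
    x, y = np.asarray(x), np.asarray(y)
    if len(x) != len(y):
        raise ValueError("x and y must have the same number of samples")
    # Slicing both arrays the same way keeps images and labels aligned.
    return x[:n], y[:n]
```

With this helper, the sizes in the text would correspond to `take_subset(x_train, y_train, 10_000)` and `take_subset(x_test, y_test, 1_000)`, where the array names are assumed.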

5. Train the KNN Classifier

We use KNeighborsClassifier from Scikit-learn with k=3. The n_neighbors argument in KNeighborsClassifier(n_neighbors=3) is the k in KNN: the number of closest training samples the model consults for each prediction. Fitting a KNN model mostly means storing the training data, since the real work happens at prediction time.
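The training call might look like the sketch below. `train_knn` is an illustrative wrapper, not a Scikit-learn function; only `KNeighborsClassifier`, its `n_neighbors` argument, and `fit` come from the library.

```python
import numpy as np
from sklearn.neighbors import KNeighborsClassifier


def train_knn(x_train, y_train, k=3):
    """Fit a k-nearest-neighbors classifier on flattened image vectors."""
    model = KNeighborsClassifier(n_neighbors=k)
    # x_train has one row per sample; y_train holds the matching labels.
    model.fit(x_train, y_train)
    return model
```

The returned model exposes `predict`, which takes a 2D array with one row per sample, just like the training input.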

Here’s how it works:

  1. When the model receives a new, unseen data point (like a digit image from the test set), it compares this point to all the training data.
  2. It calculates the distance (usually Euclidean distance) between the new point and every training sample.
  3. It then finds the 3 closest training samples — these are the “3 nearest neighbors.”
  4. The model then looks at the labels of these 3 neighbors (e.g., digits 2, 2, and 3).
  5. It chooses the most common label among the neighbors — in this case, it would predict 2.
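To make the steps above concrete, here is a small from-scratch sketch in NumPy. It is written for illustration only and is not how Scikit-learn implements KNN internally; the function name `knn_predict_one` is hypothetical, and ties are broken by whichever label is encountered first.

```python
from collections import Counter

import numpy as np


def knn_predict_one(x_train, y_train, query, k=3):
    """Predict the label of one query vector by majority vote of k neighbors."""
    # Steps 1-2: Euclidean distance from the query to every training sample.
    distances = np.linalg.norm(x_train - query, axis=1)
    # Step 3: indices of the k smallest distances.
    nearest = np.argsort(distances)[:k]
    # Step 4: labels of those k neighbors.
    neighbor_labels = [int(label) for label in y_train[nearest]]
    # Step 5: the most common label among the neighbors wins.
    winner, _ = Counter(neighbor_labels).most_common(1)[0]
    return winner, neighbor_labels
```

If the three nearest neighbors carry the labels 2, 2, and 3, as in the example above, the function returns 2 as the prediction.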

6. Make Predictions and Evaluate Accuracy

We predict on the test subset and evaluate the result with accuracy_score, which returns the fraction of test samples whose predicted label matches the true label. In this case, the model achieves over 90% accuracy.

At the time of writing this article, the test yielded the following result: “Model accuracy on MNIST: 94.56%”.
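The evaluation step could be sketched as follows. `report_accuracy` is an illustrative helper; `accuracy_score` is the Scikit-learn function named above, and the message format mirrors the output quoted in this section.

```python
import numpy as np
from sklearn.metrics import accuracy_score


def report_accuracy(y_true, y_pred):
    """Return the accuracy as a fraction plus a formatted message."""
    # accuracy_score gives the share of predictions that match the true labels.
    accuracy = accuracy_score(y_true, y_pred)
    return accuracy, f"Model accuracy on MNIST: {accuracy * 100:.2f}%"
```

Here `y_pred` would come from calling the trained model's `predict` on the test subset. The exact percentage depends on which samples end up in the training and test subsets.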

7. Visualize Predictions

Using Matplotlib, we display a sample from the test set along with its predicted and actual labels. Looking at individual predictions is a quick spot check on how the model behaves, complementing the overall accuracy figure.

You can also look at other digits by changing the value of i, the index of the test sample being displayed.
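A plotting helper along these lines would do the job. The name `plot_prediction` and the title wording are assumptions for illustration; it accepts either 28×28 images or flattened 784-length rows.

```python
import matplotlib.pyplot as plt
import numpy as np


def plot_prediction(images, y_true, y_pred, i):
    """Draw test image i with its predicted and actual labels in the title."""
    fig, ax = plt.subplots()
    # Flattened 784-length rows are reshaped back to 28x28 for display.
    ax.imshow(np.asarray(images[i]).reshape(28, 28), cmap="gray")
    ax.set_title(f"Predicted: {y_pred[i]}, Actual: {y_true[i]}")
    ax.axis("off")
    return fig
```

Call `plt.show()` after `plot_prediction(...)` to open the figure window, and pass a different `i` to inspect another sample.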

In this project, you built a simple yet effective handwritten digit recognition system using Python, Scikit-learn, and the MNIST dataset. The K-Nearest Neighbors algorithm achieved over 90% accuracy, making it a great choice for quick prototyping and learning how image classification works.

Feel free to reach out via email or connect with me on LinkedIn. I’ll do my best to get back to you as soon as possible.

Best Regards,
Can Ozgan