Implement Handwritten Digit Recognition with PyTorch

Hello everyone! I’m Dog Brother. Today we’re going to work through a fun project together: handwritten digit recognition with PyTorch. It covers the basics of deep learning, and along the way we’ll build and train a neural network model ourselves. I’ll go step by step and keep things as simple as possible, so this project is much less daunting than it looks.

1. Preparation

First, install the libraries we need: PyTorch (torch) and torchvision, which provides the MNIST dataset and the image transforms used below.

pip install torch torchvision

2. Import Required Modules

Let’s first import the modules we need:

import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms

3. Data Preparation

MNIST is a classic dataset of handwritten digits: 28×28 grayscale images of the digits 0 to 9, with 60,000 training images and 10,000 test images. Let’s load it:

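# Define data transformations:
# ToTensor() turns a PIL image into a float tensor in [0, 1] with shape (1, 28, 28);
# Normalize() then applies (x - mean) / std with the MNIST training-set mean and std
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,))
])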

# Load training and test datasets
train_dataset = torchvision.datasets.MNIST(
    root='./data', 
    train=True, 
    download=True,
    transform=transform
)

test_dataset = torchvision.datasets.MNIST(
    root='./data',
    train=False,
    download=True,
    transform=transform
)

# Create data loaders
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=64, shuffle=True)
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=1000, shuffle=False)
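To make the two numbers in Normalize((0.1307,), (0.3081,)) and the batch settings more concrete, here is a small pure-Python sketch of the arithmetic. It assumes 8-bit pixels (0 to 255) as stored in MNIST and a DataLoader with the default drop_last=False; the helper names are illustrative, not part of torchvision.

```python
import math

# Constants copied from the tutorial's transforms.Normalize call
MNIST_MEAN = 0.1307
MNIST_STD = 0.3081

def normalize_pixel(raw: int) -> float:
    """Mimic ToTensor() followed by Normalize() for one 8-bit grayscale pixel."""
    scaled = raw / 255.0                      # ToTensor(): 0..255 -> 0.0..1.0
    return (scaled - MNIST_MEAN) / MNIST_STD  # Normalize(): (x - mean) / std

def batches_per_epoch(num_samples: int, batch_size: int) -> int:
    """Number of batches a DataLoader yields when drop_last=False (the default)."""
    return math.ceil(num_samples / batch_size)

def last_batch_size(num_samples: int, batch_size: int) -> int:
    """Size of the final, possibly smaller, batch."""
    remainder = num_samples % batch_size
    return remainder if remainder else batch_size

print(f"black pixel (0)   -> {normalize_pixel(0):.4f}")    # -0.4242
print(f"white pixel (255) -> {normalize_pixel(255):.4f}")  # 2.8215
print(batches_per_epoch(60000, 64), last_batch_size(60000, 64))  # 938 32
print(batches_per_epoch(10000, 1000))                            # 10
```

So after normalization a blank (black) pixel becomes about -0.42 and a full-intensity stroke about 2.82. With 938 training batches per epoch, the check i % 100 == 99 in the step 5 training loop fires 9 times per epoch.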

4. Build the Neural Network

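Now, let’s build our neural network model. Here, I will use a simple fully connected network with three linear layers: two hidden layers of 512 units with ReLU activations, and an output layer with 10 units, one per digit: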

class Net(nn.Module):
    def __init__(self):
        super().__init__()
        self.flatten = nn.Flatten()  # (batch, 1, 28, 28) -> (batch, 784)
        self.linear_relu_stack = nn.Sequential(
            nn.Linear(28*28, 512),
            nn.ReLU(),
            nn.Linear(512, 512),
            nn.ReLU(),
            nn.Linear(512, 10)       # one raw score (logit) per digit 0-9
        )

    def forward(self, x):
        x = self.flatten(x)
        logits = self.linear_relu_stack(x)
        # No softmax here: nn.CrossEntropyLoss applies log-softmax itself
        return logits

# Create model instance
model = Net()
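As a sanity check on the size of this network, the parameter count can be worked out by hand: each nn.Linear(in, out) stores an out × in weight matrix plus out biases, while Flatten and ReLU have no parameters. A short sketch of that arithmetic (the helper names are illustrative):

```python
# Layer widths of Net: 28*28 inputs -> 512 -> 512 -> 10 outputs
LAYER_SIZES = [28 * 28, 512, 512, 10]

def linear_param_count(in_features: int, out_features: int) -> int:
    """nn.Linear holds an (out_features x in_features) weight plus one bias per output."""
    return in_features * out_features + out_features

def mlp_param_count(sizes: list) -> int:
    """Sum the parameters of consecutive Linear layers (ReLU and Flatten have none)."""
    return sum(linear_param_count(a, b) for a, b in zip(sizes, sizes[1:]))

for a, b in zip(LAYER_SIZES, LAYER_SIZES[1:]):
    print(f"Linear({a}, {b}): {linear_param_count(a, b):,} parameters")
print(f"Total: {mlp_param_count(LAYER_SIZES):,}")  # Total: 669,706
```

With the model from step 4 in scope, sum(p.numel() for p in model.parameters()) should report the same total, 669,706.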

5. Train the Model

Next comes the most important part, training:

# Define loss function and optimizer
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=0.01)

# Train the model
def train(epochs):
    model.train()  # training mode (matters for dropout/batch norm; harmless here)
    for epoch in range(epochs):
        running_loss = 0.0
        for i, (inputs, labels) in enumerate(train_loader):
            optimizer.zero_grad()              # clear gradients from the previous step
            outputs = model(inputs)            # forward pass: (batch, 10) logits
            loss = criterion(outputs, labels)
            loss.backward()                    # backpropagation
            optimizer.step()                   # update the weights

            running_loss += loss.item()
            if i % 100 == 99:                  # print the average loss every 100 batches
                print(f'[{epoch + 1}, {i + 1}] loss: {running_loss / 100:.3f}')
                running_loss = 0.0

# Start training
train(epochs=5)
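To see what a single pass through the loop body does, the sketch below runs one step on a tiny, randomly generated toy problem (4 samples, 5 features, 3 classes, a single nn.Linear layer; these sizes are assumptions for illustration, not MNIST). It checks two things: that nn.CrossEntropyLoss on raw logits equals the mean negative log-softmax probability of the correct class, and that optim.SGD with only lr set (no momentum) updates each parameter as param - lr * grad.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)

# Hypothetical toy data: 4 samples, 5 features, 3 classes
inputs = torch.randn(4, 5)
labels = torch.tensor([0, 2, 1, 2])
layer = nn.Linear(5, 3)
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(layer.parameters(), lr=0.01)

# 1) Cross-entropy on raw logits == mean of -log_softmax at the true class
logits = layer(inputs)
loss = criterion(logits, labels)
manual = -F.log_softmax(logits, dim=1)[torch.arange(4), labels].mean()
print(torch.allclose(loss, manual))  # True

# 2) One plain SGD step: param <- param - lr * grad
optimizer.zero_grad()
loss.backward()
expected_weight = layer.weight.detach() - 0.01 * layer.weight.grad
optimizer.step()
print(torch.allclose(layer.weight.detach(), expected_weight))  # True
```

This is also why Net returns raw logits: adding a softmax layer in front of nn.CrossEntropyLoss would apply softmax twice.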

6. Test the Model

After training, let’s measure the model’s accuracy on the 10,000 test images:

def test():
    model.eval()  # switch to evaluation mode
    correct = 0
    total = 0
    with torch.no_grad():  # no gradients needed for evaluation
        for images, labels in test_loader:
            outputs = model(images)
            predicted = outputs.argmax(dim=1)  # index of the highest logit = predicted digit
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

    print(f'Accuracy: {100 * correct / total:.2f}%')

test()
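The evaluation boils down to taking the index of the largest logit as the prediction and counting matches. Here is a minimal sketch on a hand-made batch of 4 samples and 3 classes (made-up numbers, not model output), with a hypothetical per_class_accuracy helper that shows which classes are getting wrong answers:

```python
import torch

def accuracy_from_logits(logits: torch.Tensor, labels: torch.Tensor) -> float:
    """Percentage of samples whose highest logit matches the label."""
    predicted = logits.argmax(dim=1)
    return (predicted == labels).float().mean().item() * 100

def per_class_accuracy(logits: torch.Tensor, labels: torch.Tensor, num_classes: int) -> dict:
    """Accuracy per true class; classes with no samples are skipped."""
    predicted = logits.argmax(dim=1)
    result = {}
    for c in range(num_classes):
        mask = labels == c
        if mask.any():
            result[c] = (predicted[mask] == c).float().mean().item() * 100
    return result

logits = torch.tensor([[2.0, 0.1, -1.0],   # predicted 0
                       [0.0, 3.0, 0.5],    # predicted 1
                       [1.0, 0.2, 4.0],    # predicted 2
                       [0.3, 0.9, 0.1]])   # predicted 1
labels = torch.tensor([0, 1, 0, 1])
print(accuracy_from_logits(logits, labels))   # 75.0
print(per_class_accuracy(logits, labels, 3))  # {0: 50.0, 1: 100.0}
```

On the real test set, the same helpers could be fed the concatenated outputs and labels collected from test_loader with num_classes=10.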

Tips:

  1. If you run out of memory, try reducing batch_size.
  2. Adjust the number of training epochs based on the results: if test accuracy is still improving, train longer; if the training loss keeps falling while test accuracy stalls or drops, the model may be overfitting.
  3. The learning rate (lr) matters too: if it is too high, the loss may oscillate or diverge instead of converging; if it is too low, training is slow.
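Tip 3 can be seen on a toy problem. The sketch below (a one-parameter example, not the MNIST model) minimizes f(w) = w² with plain gradient descent. Each step multiplies w by (1 - 2·lr), so a tiny lr barely moves, a moderate lr converges toward the minimum at 0, and an lr above 1 makes |w| grow every step:

```python
def gradient_descent_on_square(lr: float, steps: int = 10, w0: float = 1.0) -> float:
    """Minimize f(w) = w**2 (gradient 2*w) with plain gradient descent."""
    w = w0
    for _ in range(steps):
        w -= lr * 2 * w  # w <- w - lr * f'(w)
    return w

for lr in (0.001, 0.1, 1.1):
    print(f"lr={lr}: w after 10 steps = {gradient_descent_on_square(lr):.4f}")
```

Starting from w = 1, the three runs end at roughly 0.98 (too slow), 0.11 (converging), and 6.19 (diverging).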

7. Practical Use

Try recognizing your own handwritten digits:

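from PIL import Image, ImageOps

def predict(image_path):
    # Load the image as grayscale
    image = Image.open(image_path).convert('L')
    # MNIST digits are 28x28 with light strokes on a dark background.
    # Assumes dark ink on light paper, so invert the colors and resize.
    image = ImageOps.invert(image).resize((28, 28))
    image = transform(image).unsqueeze(0)  # add a batch dimension: (1, 1, 28, 28)

    # Prediction
    model.eval()
    with torch.no_grad():
        output = model(image)
        return output.argmax(dim=1).item()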
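One practical catch: MNIST images are 28×28 with light strokes on a dark background, while a photo or scan of a digit written on paper is usually dark ink on a light background and much larger. The sketch below uses Pillow (the PIL package already used for loading images) to draw a hypothetical test digit on a white canvas and shows the invert-and-resize step that brings it closer to MNIST format before transform is applied. The helper name to_mnist_style is illustrative.

```python
from PIL import Image, ImageDraw, ImageOps

def to_mnist_style(img: Image.Image) -> Image.Image:
    """Grayscale, invert (dark-on-light -> light-on-dark), and shrink to 28x28."""
    gray = img.convert("L")
    inverted = ImageOps.invert(gray)
    return inverted.resize((28, 28))

# Hypothetical input: a black vertical stroke (a "1") on a white 200x200 canvas
canvas = Image.new("L", (200, 200), color=255)
draw = ImageDraw.Draw(canvas)
draw.line([(100, 30), (100, 170)], fill=0, width=20)

small = to_mnist_style(canvas)
print(small.size)                # (28, 28)
print(small.getpixel((0, 0)))    # 0: the white background is now black, as in MNIST
print(small.getpixel((14, 14)))  # a bright value near the middle of the stroke
```

For best results the digit should also be roughly centered and fill most of the frame, as MNIST digits do.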

Everyone, that’s all for today’s Python learning journey! Remember to code along, and if you have any questions, feel free to ask Dog Brother in the comments. I hope this project helps you appreciate the charm of deep learning, and I look forward to seeing what you build in AI. Happy learning, and may your Python skills keep improving!
