Hao Shan's Studio.

零基础上手深度学习-手写体识别

Word count: 2.7kReading time: 12 min
2025/08/14
loading

从零开始:用 PyTorch 实现手写体识别(MNIST)

这篇文章适合谁

  • 零基础同学:想亲手跑通“AI 识别手写数字”的小项目
  • 快速上手:想用最短路径理解深度学习项目的基本结构(数据 → 模型 → 训练 → 评估 → 保存 → 推理)

我们要做什么

  • 数据:使用经典数据集 MNIST(28×28 的手写数字图片,标注 0–9)
  • 训练:运行现成脚本 mnist_train.py 完成训练与评估
  • 保存:把最优与最终模型权重保存到磁盘
  • 推理:用一段小脚本对单张图片进行预测

环境准备(Windows + PowerShell)

假设你的项目路径为 D:\sh\手写体识别,并且你已安装了 Python 3.8+。

1) 创建并激活虚拟环境

(PS:conda的使用需要安装anaconda,可自行搜索或私信作者)

1
2
3
4
5
6
7
8
# 进入你的项目目录(请换成你的实际路径)
cd path/to/code-directory #目录需替换为自己的目录

# 创建虚拟环境
conda create -n HD python=3.10

# 激活虚拟环境
conda activate HD

2) 安装依赖(CPU 版)

如果没有 NVIDIA GPU,建议安装 CPU 版,简单稳定。

1
2
python -m pip install --upgrade pip
pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cpu

如有 NVIDIA GPU 且已配置合适的 CUDA,可到 PyTorch 官网 选择对应命令安装。


快速运行训练

代码:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
import os
import random
import argparse
from typing import Optional, Dict, Any, Tuple

import numpy as np
import torch
from torch import nn
from torch.utils.data import DataLoader
from torchvision import datasets, transforms


def set_global_seed(seed: int = 42) -> None:
"""
设置全局随机种子,提升实验可复现性。

参数:
seed: 随机种子数值。
"""
os.environ["PYTHONHASHSEED"] = str(seed)
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
if torch.cuda.is_available():
torch.cuda.manual_seed_all(seed)
# 为了更强可复现性(可能牺牲速度)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False


def get_device() -> torch.device:
"""
获取训练设备(GPU 优先,否则 CPU)。

返回:
torch.device 对象。
"""
if torch.cuda.is_available():
return torch.device("cuda")
# Apple Silicon MPS
if hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
return torch.device("mps")
return torch.device("cpu")


def load_mnist_datasets(normalize: bool = True) -> Tuple[torch.utils.data.Dataset, torch.utils.data.Dataset]:
"""
加载 MNIST 训练/测试数据集。

参数:
normalize: 是否对像素做标准化处理。

返回:
(train_dataset, test_dataset)
"""
data_dir = os.path.join(os.path.dirname(__file__), "data")
os.makedirs(data_dir, exist_ok=True)

transform_list = [transforms.ToTensor()]
if normalize:
# MNIST 的全局均值/方差
transform_list.append(transforms.Normalize((0.1307,), (0.3081,)))
transform = transforms.Compose(transform_list)

train_dataset = datasets.MNIST(root=data_dir, train=True, download=True, transform=transform)
test_dataset = datasets.MNIST(root=data_dir, train=False, download=True, transform=transform)
return train_dataset, test_dataset


class MnistCNN(nn.Module):
"""
用于 MNIST 分类的简单 CNN 模型。
结构: Conv(32)-ReLU-MaxPool -> Conv(64)-ReLU-MaxPool -> Flatten -> FC(128)-ReLU-Dropout -> FC(10)
"""

def __init__(self, num_classes: int = 10, dropout_rate: float = 0.5) -> None:
super().__init__()
self.features = nn.Sequential(
nn.Conv2d(1, 32, kernel_size=3, padding=0),
nn.ReLU(inplace=True),
nn.MaxPool2d(kernel_size=2),
nn.Conv2d(32, 64, kernel_size=3, padding=0),
nn.ReLU(inplace=True),
nn.MaxPool2d(kernel_size=2),
)
# 输入 28x28 -> conv3 -> 26x26 -> pool -> 13x13 -> conv3 -> 11x11 -> pool -> 5x5
self.classifier = nn.Sequential(
nn.Flatten(),
nn.Linear(64 * 5 * 5, 128),
nn.ReLU(inplace=True),
nn.Dropout(p=dropout_rate),
nn.Linear(128, num_classes),
)

def forward(self, x: torch.Tensor) -> torch.Tensor:
x = self.features(x)
x = self.classifier(x)
return x


def evaluate(model: nn.Module, loader: DataLoader, device: torch.device, criterion: nn.Module) -> Tuple[float, float]:
"""
在给定数据加载器上评估模型。

参数:
model: 待评估模型。
loader: 数据加载器。
device: 设备。
criterion: 损失函数。

返回:
(平均损失, 准确率)
"""
model.eval()
total_loss = 0.0
correct = 0
total = 0
with torch.no_grad():
for inputs, targets in loader:
inputs = inputs.to(device)
targets = targets.to(device)
outputs = model(inputs)
loss = criterion(outputs, targets)
total_loss += loss.item() * inputs.size(0)
_, predicted = outputs.max(1)
correct += predicted.eq(targets).sum().item()
total += targets.size(0)
avg_loss = total_loss / max(total, 1)
accuracy = correct / max(total, 1)
return avg_loss, accuracy


def train_and_evaluate(
model: nn.Module,
train_loader: DataLoader,
val_loader: DataLoader,
device: torch.device,
epochs: int = 5,
output_dir: Optional[str] = None,
) -> Dict[str, Any]:
"""
训练并评估模型,可选保存最优权重与最终模型。

参数:
model: 待训练模型。
train_loader: 训练集 DataLoader。
val_loader: 验证/测试集 DataLoader。
device: 训练设备。
epochs: 训练轮数。
output_dir: 若提供,将保存最佳与最终模型。

返回:
包含 eval_loss、eval_acc、best_model_path、final_model_path 的字典。
"""
if output_dir:
os.makedirs(output_dir, exist_ok=True)
best_model_path = os.path.join(output_dir, "best_model.pth")
final_model_path = os.path.join(output_dir, "mnist_cnn_final.pth")
else:
best_model_path = None
final_model_path = None

model.to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
criterion = nn.CrossEntropyLoss()

best_val_acc = -1.0
epochs_no_improve = 0
patience = 3

for epoch in range(1, epochs + 1):
model.train()
running_loss = 0.0
total = 0
for inputs, targets in train_loader:
inputs = inputs.to(device)
targets = targets.to(device)

optimizer.zero_grad()
outputs = model(inputs)
loss = criterion(outputs, targets)
loss.backward()
optimizer.step()

batch_size_actual = targets.size(0)
running_loss += loss.item() * batch_size_actual
total += batch_size_actual

train_loss = running_loss / max(total, 1)
val_loss, val_acc = evaluate(model, val_loader, device, criterion)

print(
f"Epoch {epoch}/{epochs} - train_loss: {train_loss:.4f} - val_loss: {val_loss:.4f} - val_acc: {val_acc:.4f}"
)

improved = val_acc > best_val_acc
if improved:
best_val_acc = val_acc
epochs_no_improve = 0
if best_model_path:
torch.save(model.state_dict(), best_model_path)
else:
epochs_no_improve += 1

if epochs_no_improve >= patience:
print("Early stopping triggered.")
break

# 最终评估(使用当前权重)
final_loss, final_acc = evaluate(model, val_loader, device, criterion)
if final_model_path:
torch.save(model.state_dict(), final_model_path)

return {
"eval_loss": float(final_loss),
"eval_acc": float(final_acc),
"best_model_path": best_model_path,
"final_model_path": final_model_path,
}


def parse_args() -> argparse.Namespace:
"""
解析命令行参数。

返回:
参数命名空间。
"""
parser = argparse.ArgumentParser(
description="基于 MNIST 的手写体识别 (PyTorch)"
)
parser.add_argument("--epochs", type=int, default=5, help="训练轮数,默认 5")
parser.add_argument("--batch_size", type=int, default=128, help="批大小,默认 128")
parser.add_argument("--dropout", type=float, default=0.5, help="Dropout 比例,默认 0.5")
parser.add_argument("--seed", type=int, default=42, help="随机种子,默认 42")
parser.add_argument(
"--output_dir",
type=str,
default=None,
help="输出目录(可选),用于保存最优与最终模型",
)
return parser.parse_args()


def main() -> None:
"""
主入口:配置环境、加载数据、构建模型并训练评估。
"""
args = parse_args()

set_global_seed(args.seed)
device = get_device()

train_dataset, test_dataset = load_mnist_datasets(normalize=True)
train_loader = DataLoader(train_dataset, batch_size=args.batch_size, shuffle=True, num_workers=2, pin_memory=torch.cuda.is_available())
val_loader = DataLoader(test_dataset, batch_size=args.batch_size, shuffle=False, num_workers=2, pin_memory=torch.cuda.is_available())

model = MnistCNN(num_classes=10, dropout_rate=args.dropout)
results = train_and_evaluate(
model=model,
train_loader=train_loader,
val_loader=val_loader,
device=device,
epochs=args.epochs,
output_dir=args.output_dir,
)

print(
{
"eval_loss": results["eval_loss"],
"eval_acc": results["eval_acc"],
"best_model_path": results["best_model_path"],
"final_model_path": results["final_model_path"],
}
)


if __name__ == "__main__":
main()



1) 启动训练

1
2
# 仍在项目根目录
python .\mnist_train.py --epochs 5 --batch_size 128 --dropout 0.5 --seed 42 --output_dir .\outputs
  • --output_dir 会保存两个权重文件(训练过程中表现最好和最终权重)
  • 第一次运行会自动下载 MNIST 到脚本同级的 data 文件夹

运行时会看到类似日志:

1
2
3
Epoch 1/5 - train_loss: 0.2503 - val_loss: 0.0801 - val_acc: 0.9750
...
{'eval_loss': 0.05, 'eval_acc': 0.985, 'best_model_path': 'outputs\\best_model.pth', 'final_model_path': 'outputs\\mnist_cnn_final.pth'}
  • val_acc 常见能到 0.98 左右(即 98%+ 的准确率)

这份代码在做什么(通俗版)

  • **数据加载 load_mnist_datasets**:

    • 自动在脚本同级创建 data 并下载 MNIST
    • 把图片转为张量 ToTensor()
    • 标准化 Normalize((0.1307,), (0.3081,))(训练更稳定)
  • **模型结构 MnistCNN**:

    • 两个卷积模块提取“笔画/边缘”等局部特征
    • 两次池化减小尺寸、保留关键信息
    • 全连接层 + Dropout 输出 10 类(数字 0–9)
  • **训练流程 train_and_evaluate**:

    • 优化器 Adam,损失函数交叉熵
    • 每轮训练后在测试集上评估
    • 简单 Early Stopping(连续 3 次无提升提前停止)
    • 可保存“最佳模型”和“最终模型”到 --output_dir
  • **设备选择 get_device**:

    • 优先 CUDA(NVIDIA GPU)
    • 支持 Apple Silicon 的 MPS
    • 没有就用 CPU(也能跑,只是慢一些)

训练完了,如何做“单张图片预测”?

在项目根目录创建 infer_single.py,用于对单张图片进行预测:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
# 文件名:infer_single.py
import os
import torch
from PIL import Image
from torchvision import transforms
from mnist_train import MnistCNN # 复用训练时的模型定义

def load_image_for_mnist(image_path: str) -> torch.Tensor:
"""
读取单张图片并做与训练一致的预处理。
要求:单通道(灰度) 28x28。如果不是,会自动转灰度并缩放。
"""
img = Image.open(image_path).convert("L")
img = img.resize((28, 28))
transform = transforms.Compose([
transforms.ToTensor(),
transforms.Normalize((0.1307,), (0.3081,)),
])
return transform(img).unsqueeze(0) # [1, 1, 28, 28]

def infer(image_path: str, model_path: str) -> int:
"""
加载模型并对单张图片进行预测,返回 0-9 的整数类别。
"""
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = MnistCNN(num_classes=10, dropout_rate=0.0).to(device)
state_dict = torch.load(model_path, map_location=device)
model.load_state_dict(state_dict)
model.eval()

x = load_image_for_mnist(image_path).to(device)
with torch.no_grad():
logits = model(x)
pred = logits.argmax(dim=1).item()
return pred

if __name__ == "__main__":
# 示例:从 outputs 加载最终权重,对 test.png 进行预测
model_file = os.path.join("outputs", "mnist_cnn_final.pth")
test_image = "test.png" # 请换成你的图片路径
result = infer(test_image, model_file)
print(f"Predicted digit: {result}")

运行方式:

1
2
# 仍在虚拟环境中
python .\infer_single.py

准备测试图片的小建议:

  • 尺寸与通道:任意尺寸均可,脚本会转为灰度并缩放至 28×28
  • 颜色:白底黑字或黑底白字都可(保持笔迹清晰)

常见问题(Windows/国内网络)

  • 下载数据很慢/失败

    • 多试几次,或更换网络
    • 也可手动下载 MNIST 的四个 gz 文件放到 ./data/MNIST/raw/
  • 训练报显存不足(GPU)

    • 减小 --batch_size(如 64、32)
  • 训练太慢(CPU)

    • 先用 --epochs 1 验证流程
    • 可将 DataLoadernum_workers 改为 0,或关闭 pin_memory

可以改动哪些超参数来“玩一玩”?

  • epochs:训练轮数(更多轮通常更准)
  • batch_size:批大小(更大通常更快,但更占显存)
  • dropout:过拟合时可增大
  • 学习率:代码中默认为 1e-3,可尝试稍大/稍小
  • Early Stopping:将 patience 调大以延迟提前停止

代码与目录结构简述

  • mnist_train.py(训练入口):
    • set_global_seed:固定随机种子
    • get_device:选择 CUDA/MPS/CPU
    • load_mnist_datasets:下载并加载数据集(到脚本同级 data
    • MnistCNN:CNN 模型
    • evaluate:评估
    • train_and_evaluate:训练循环 + Early Stopping + 模型保存
    • main:参数解析与运行入口
  • outputs/:训练时保存权重
  • data/:自动下载的 MNIST 数据

一步到位:命令清单(可复制运行)

推荐使用conda构建虚拟环境后配置环境,有关conda的使用可以自行搜索或联系作者
如不使用conda,则不需要运行下面质量中的第2、3行

1
2
3
4
5
6
7
8
9
10
11
12
13
14
# 1) 进入项目并创建/激活虚拟环境
cd path/to/code-directory #目录需替换为自己的目录
conda create -n HD python=3.10
conda activate HD

# 2) 安装依赖(CPU 版)
python -m pip install --upgrade pip
pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cpu

# 3) 训练(保存模型到 outputs)
python .\mnist_train.py --epochs 5 --batch_size 128 --dropout 0.5 --seed 42 --output_dir .\outputs

# 4) 推理(准备一张 test.png)
python .\infer_single.py

延伸学习

  • 尝试数据增强(旋转、平移、弹性形变等)
  • 引入学习率调度器(StepLR、CosineAnnealingLR 等)
  • 试更深/更轻量的网络
  • 导出 ONNX 或 TorchScript,做跨平台部署

祝你玩得开心,亲手训练出你的第一个“能看懂手写数字”的小模型!

CATALOG
  1. 1. 从零开始:用 PyTorch 实现手写体识别(MNIST)
    1. 1.1. 这篇文章适合谁
    2. 1.2. 我们要做什么
  2. 2. 环境准备(Windows + PowerShell)
    1. 2.1. 1) 创建并激活虚拟环境
    2. 2.2. 2) 安装依赖(CPU 版)
  3. 3. 快速运行训练
    1. 3.1. 1) 启动训练
  4. 4. 这份代码在做什么(通俗版)
  5. 5. 训练完了,如何做“单张图片预测”?
  6. 6. 常见问题(Windows/国内网络)
  7. 7. 可以改动哪些超参数来“玩一玩”?
  8. 8. 代码与目录结构简述
  9. 9. 一步到位:命令清单(可复制运行)
  10. 10. 延伸学习