一、环境搭建
实验采用Python3.8环境,主要依赖库:
- PyTorch 1.12:深度学习框架
- Torchvision 0.13:提供MNIST数据集
- OpenCV 4.6:图像预处理
安装命令:pip install torch torchvision opencv-python
二、实战开发步骤
- 数据加载技巧
transform = transforms.Compose([transforms.ToTensor(),transforms.Normalize((0.1307,), (0.3081,)) # 数据集均值标准差
])train_set = datasets.MNIST('data/', train=True, download=True, transform=transform)
test_set = datasets.MNIST('data/', train=False, transform=transform)
- 改进型网络设计
class EnhancedCNN(nn.Module):def __init__(self):super().__init__()self.features = nn.Sequential(nn.Conv2d(1, 16, 3, padding=1), nn.BatchNorm2d(16),nn.ReLU(),nn.MaxPool2d(2),nn.Conv2d(16, 32, 3, padding=1),nn.Dropout(0.25) )self.classifier = nn.Sequential(nn.Linear(32*7*7, 128),nn.ReLU(),nn.Linear(128, 10))
- 训练优化技巧
def train_model():model = EnhancedCNN()optimizer = torch.optim.Adam(model.parameters(), lr=0.001)scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=3, gamma=0.1)for epoch in range(10):model.train()for data, target in train_loader:optimizer.zero_grad()output = model(data)loss = F.cross_entropy(output, target)loss.backward()optimizer.step()scheduler.step()
三、效果验证
在测试集上达到98.7%准确率的关键:
- 添加BatchNorm层加速收敛
- 使用Dropout防止过拟合
- 学习率阶梯下降策略
四、模型部署示例
def predict_image(img_path):img = cv2.imread(img_path, cv2.IMREAD_GRAYSCALE)img = cv2.resize(img, (28,28))img_tensor = transform(255 - img).unsqueeze(0)with torch.no_grad():pred = torch.argmax(model(img_tensor)).item()return pred
思考延伸
尝试使用数据增强(旋转、平移)提升模型鲁棒性,比较不同优化器的性能差异,思考如何将模型部署到移动端应用。