图像分类训练技巧
图像分类项目不是模型一换就好。很多时候,真正决定效果的是训练细节:数据增强是否合理、学习率是否稳定、验证集是否可信、错误样本有没有被分析。
学习目标
- 能判断训练不收敛、过拟合、欠拟合的常见原因
- 理解学习率、batch size、数据增强和正则化的作用
- 知道类别不平衡和数据泄漏会怎样影响分类结果
- 能用错误样本分析指导下一轮改进
先看训练问题地图
一、学习率是最先检查的旋钮
学习率太大,loss 可能震荡甚至发散;学习率太小,训练会非常慢,模型看起来像没学到东西。初学时可以先从一个常见默认值开始,再观察训练曲线。
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=5, gamma=0.1)
如果训练 loss 和验证 loss 都很高,可能是欠拟合或学习率不合适。如果训练 loss 很低但验证 loss 很高,通常是过拟合或数据划分有问题。
二、数据增强要符合真实场景
数据增强不是越多越好,而是模拟真实世界可能出现的变化。猫狗分类可以水平翻转,但数字识别随便旋转 180 度可能改变语义;医学影像也不能随意做不符合成像逻辑的增强。
from torchvision import transforms
train_tfms = transforms.Compose([
transforms.RandomResizedCrop(224),
transforms.RandomHorizontalFlip(),
transforms.ColorJitter(brightness=0.2, contrast=0.2),
transforms.ToTensor(),
])
增强的原则是:训练集做增强,验证集不做随机增强;增强应该保留标签语义;增强后最好人工抽查几张图。
三、过拟合和欠拟合怎么区分
| 现象 | 可能原因 | 优先处理 |
|---|---|---|
| 训练和验证都差 | 模型太弱、训练不够、学习率问题 | 增加训练轮数、调学习率、换 backbone |
| 训练好验证差 | 过拟合、数据少、增强不足 | 加强增强、正则化、早停、更多数据 |
| 训练波动大 | batch 太小、学习率偏大 | 降学习率、增大 batch、检查数据 |
| 验证分数异常高 | 数据泄漏 | 检查重复图片、同一主体是否跨集合 |
四、类别不平衡要看混淆矩阵
准确率在类别不平衡时很容易骗人。比如 95% 图片都是正常样本,模型全预测正常也有 95% 准确率,但它完全不会识别异常。
from sklearn.metrics import classification_report, confusion_matrix
print(classification_report(y_true, y_pred))
print(confusion_matrix(y_true, y_pred))
类别不平衡可以考虑重采样、class weight、focal loss 或补充少数类数据。选择哪种方法,要看少数类样本是否足够可靠。
五、错误样本分析
每次训练后至少抽查 20 个错误样本。把它们分成几类:标注错误、图像质量差、类别边界模糊、模型关注错区域、训练集中类似样本太少。错误样本分析比盲目换模型更能指导下一步。
六、最小训练记录模板
README 或实验记录里建议保留:数据集版本、训练/验证划分方式、模型结构、输入尺寸、增强策略、学习率、batch size、epoch、最佳指标、混淆矩阵、错误样本截图和下一步计划。
常见误区
第一个误区是只看 accuracy,不看类别级指标 。第二个误区是验证集也用了随机增强。第三个误区是同一对象或同一视频帧同时出现在训练和验证,造成泄漏。第四个误区是一遇到效果差就换模型,而不先检查数据和训练曲线。
练习
- 训练一个小型分类模型,画出 train loss 和 val loss 曲线。
- 对同一模型分别使用弱增强和强增强,比较验证集效果。
- 输出混淆矩阵,找出最容易混淆的两个类别。
- 整理 10 张错误样本,给每张写一句可能原因。
过关标准
学完本节后,你应该能根据训练曲线判断常见问题,能设计合理的数据增强,能用混淆矩阵分析类别问题,并能把错误样本分析写进图像分类项目 README。