偏差-方差权衡
本节定位
偏差-方差权衡(Bias-Variance Tradeoff) 是机器学习中最重要的理论框架之一。它解释了为什么模型会欠拟合或过拟合,以及如何找到两者之间的最佳平衡。
学习目标
- 深入理解偏差(Bias)和方差(Variance)
- 理解欠拟合和过拟合的本质
- 掌握学习曲线分析
- 掌握验证曲线分析
- 理解正则化如何影响偏差-方差
一、什么是偏差和方差?
1.1 直觉理解——打靶比喻
| 偏差(Bias) | 方差(Variance) | |
|---|---|---|
| 含义 | 模型预测值与真实值的系统性偏移 | 模型对不同训练数据的敏感程度 |
| 高 → | 欠拟合(模型太简单) | 过拟合(模型太复杂) |
| 解决 | 增加模型复杂度 | 减少模型复杂度、增加数据 |
1.2 总误差分解
总误差 = 偏差² + 方差 + 不可约误差(噪声)
import numpy as np
import matplotlib.pyplot as plt
# 可视化偏差-方差权衡
complexity = np.linspace(0.1, 10, 100)
bias_sq = 5 / complexity
variance = 0.5 * complexity
noise = 0.5 * np.ones_like(complexity)
total = bias_sq + variance + noise
plt.figure(figsize=(8, 5))
plt.plot(complexity, bias_sq, 'b-', linewidth=2, label='偏差²')
plt.plot(complexity, variance, 'r-', linewidth=2, label='方差')
plt.plot(complexity, noise, 'g--', linewidth=1, label='噪声(不可约)')
plt.plot(complexity, total, 'k-', linewidth=2, label='总误差')
best_idx = np.argmin(total)
plt.axvline(x=complexity[best_idx], color='orange', linestyle=':', label='最优复杂度')
plt.xlabel('模型复杂度')
plt.ylabel('误差')
plt.title('偏差-方差权衡')
plt.legend()
plt.grid(True, alpha=0.3)
plt.show()
二、实际观察偏差和方差
2.1 用多项式回归演示
from sklearn.preprocessing import PolynomialFeatures
from sklearn.linear_model import LinearRegression
from sklearn.pipeline import make_pipeline
# 生成非线性数据
np.random.seed(42)
n = 30
X = np.sort(np.random.uniform(-3, 3, n))
y_true_func = lambda x: np.sin(x)
y = y_true_func(X) + np.random.randn(n) * 0.3
x_plot = np.linspace(-3.5, 3.5, 200)
fig, axes = plt.subplots(1, 3, figsize=(15, 4))
configs = [
(1, '欠拟合(degree=1)\n高偏差,低方差'),
(4, '刚好(degree=4)\n偏差方差平衡'),
(15, '过拟合(degree=15)\n低偏差,高方差'),
]
for ax, (deg, title) in zip(axes, configs):
# 用不同数据子集训练多次,观察方差
for seed in range(10):
np.random.seed(seed)
X_sample = np.sort(np.random.uniform(-3, 3, n))
y_sample = y_true_func(X_sample) + np.random.randn(n) * 0.3
model = make_pipeline(PolynomialFeatures(deg, include_bias=False), LinearRegression())
model.fit(X_sample.reshape(-1, 1), y_sample)
y_pred = model.predict(x_plot.reshape(-1, 1))
y_pred = np.clip(y_pred, -3, 3)
ax.plot(x_plot, y_pred, alpha=0.3, color='steelblue')
ax.plot(x_plot, y_true_func(x_plot), 'r--', linewidth=2, label='真实函数')
ax.scatter(X, y, color='black', s=20, zorder=5)
ax.set_title(title)
ax.set_ylim(-3, 3)
ax.legend(fontsize=8)
ax.grid(True, alpha=0.3)
plt.suptitle('偏差-方差直觉(10 次不同数据训练)', fontsize=13)
plt.tight_layout()
plt.show()
观察要点
- degree=1:10 条线几乎重合(低方差),但都偏离真实函数(高偏差)
- degree=15:10 条线差异很大(高方差),但平均更接近真实(低偏差)
- degree=4:10 条线较一致(适当方差),且接近真实函数(适当偏差)
三、学习曲线
3.1 什么是学习曲线?
学习曲线展示训练集大小对模型性能的影响。它能告诉你:
- 模型是欠拟合还是过拟合
- 增加数据是否有帮助
from sklearn.model_selection import learning_curve
from sklearn.tree import DecisionTreeClassifier
from sklearn.datasets import load_digits
digits = load_digits()
X, y = digits.data, digits.target
def plot_learning_curve(model, X, y, title, ax):
train_sizes, train_scores, val_scores = learning_curve(
model, X, y, cv=5,
train_sizes=np.linspace(0.1, 1.0, 10),
scoring='accuracy', n_jobs=-1
)
train_mean = train_scores.mean(axis=1)
train_std = train_scores.std(axis=1)
val_mean = val_scores.mean(axis=1)
val_std = val_scores.std(axis=1)
ax.fill_between(train_sizes, train_mean - train_std, train_mean + train_std, alpha=0.1, color='blue')
ax.fill_between(train_sizes, val_mean - val_std, val_mean + val_std, alpha=0.1, color='red')
ax.plot(train_sizes, train_mean, 'bo-', label='训练集')
ax.plot(train_sizes, val_mean, 'ro-', label='验证集')
ax.set_xlabel('训练样本数')
ax.set_ylabel('准确率')
ax.set_title(title)
ax.legend()
ax.grid(True, alpha=0.3)
fig, axes = plt.subplots(1, 3, figsize=(18, 5))
# 欠拟合模型
plot_learning_curve(
DecisionTreeClassifier(max_depth=1, random_state=42),
X, y, '欠拟合(max_depth=1)\n训练和验证都低', axes[0]
)
# 刚好的模型
plot_learning_curve(
DecisionTreeClassifier(max_depth=10, random_state=42),
X, y, '适当复杂度(max_depth=10)', axes[1]
)
# 过拟合模型
plot_learning_curve(
DecisionTreeClassifier(max_depth=None, random_state=42),
X, y, '过拟合(max_depth=None)\n训练和验证差距大', axes[2]
)
plt.tight_layout()
plt.show()
3.2 如何解读学习曲线
| 现象 | 诊断 | 解决方案 |
|---|---|---|
| 训练和验证都低 | 欠拟合 | 增加模型复杂度 |
| 训练高,验证低 | 过拟合 | 更多数据 / 正则化 / 简化模型 |
| 两条线收敛且都高 | 刚好 | 模型不错 |
| 验证还在上升 | 需要更多数据 | 收集更多数据 |