Python实战:用sklearn Pipeline轻松搞定多项式回归(附完整代码)
最近在帮一个做电商数据分析的朋友优化他们的销量预测模型,他们发现用简单的线性回归去拟合广告投入和销售额的关系,效果总是不尽人意。数据点散落在一条曲线周围,硬用直线去套,偏差很大。这其实是一个典型的非线性关系场景,也是很多数据分析师和算法工程师在实际项目中都会遇到的“拦路虎”。线性回归的简洁高效让人爱不释手,但面对现实世界复杂的曲线关系,我们该怎么办?答案不是抛弃线性回归,而是升级它——这就是多项式回归的魅力所在。
多项式回归的核心思想非常巧妙:它没有发明一个全新的、复杂的非线性模型,而是通过特征工程,将原始特征进行多项式组合(比如平方、立方),生成新的特征,然后在这些新特征构成的高维空间里,依然使用我们熟悉的线性回归模型进行拟合。简单说,就是“把曲线掰直了再套用直线”。这种方法既保留了线性回归在求解和解释上的优势,又具备了拟合非线性数据的能力。
然而,从理论到实践,中间隔着好几道坎:如何高效地生成多项式特征?如何避免高次项带来的数值不稳定问题?如何将数据预处理和模型训练无缝衔接,避免在测试集上犯错?这些问题如果处理不当,不仅代码会变得冗长混乱,还极易引入难以察觉的Bug。幸运的是,scikit-learn(sklearn)为我们提供了一个堪称“神器”的工具——Pipeline。它能把数据预处理、特征工程、模型训练等一系列步骤封装成一个流畅的“流水线”,让代码既简洁又健壮。这篇文章,我就以一个实战案例为线索,带你彻底掌握如何用sklearn Pipeline优雅、高效地实现多项式回归,并分享一些我踩过坑后总结的实用技巧。
1. 为什么线性回归会“失灵”?从数据可视化开始
在动手写代码之前,我们先得用眼睛“看”数据。理解数据的内在模式,是选择正确模型的第一步。很多初学者拿到数据就直接往模型里塞,这是大忌。让我们先构造一个具有明显非线性关系的数据集,直观感受一下。
import numpy as np
import matplotlib.pyplot as plt
plt.style.use('seaborn-v0_8-darkgrid') # 使用更美观的绘图样式
# 设置随机种子,确保结果可复现
np.random.seed(42)
# 生成模拟数据:一个二次关系加上随机噪声
n_samples = 150
X_raw = np.random.uniform(-4, 4, size=n_samples)
# 真实关系:y = 0.7 * X^2 + 1.5 * X + 3
y = 0.7 * X_raw**2 + 1.5 * X_raw + 3 + np.random.normal(0, 1.5, n_samples)
# 将数据整理为sklearn需要的二维数组格式 (n_samples, n_features)
X = X_raw.reshape(-1, 1)
# 绘制原始数据散点图
fig, ax = plt.subplots(figsize=(10, 6))
ax.scatter(X, y, alpha=0.7, label='原始数据点', color='steelblue')
ax.set_xlabel('特征 X', fontsize=12)
ax.set_ylabel('目标值 y', fontsize=12)
ax.set_title('原始数据分布(明显的非线性关系)', fontsize=14)
ax.legend()
plt.tight_layout()
plt.show()
运行这段代码,你会看到数据点大致沿着一条开口向上的抛物线分布。这时,如果我们强行使用普通线性回归去拟合,结果会怎样?
from sklearn.linear_model import LinearRegression
# 训练简单线性回归模型
simple_lr = LinearRegression()
simple_lr.fit(X, y)
y_pred_linear = simple_lr.predict(X)
# 计算性能指标(这里用简单的R²分数)
from sklearn.metrics import r2_score
r2_linear = r2_score(y, y_pred_linear)
# 在同一张图上绘制拟合直线
fig, ax = plt.subplots(figsize=(10, 6))
ax.scatter(X, y, alpha=0.7, label='原始数据点', color='steelblue')
ax.plot(X, y_pred_linear, color='crimson', linewidth=3, label=f'线性回归拟合 (R²={r2_linear:.3f})')
ax.set_xlabel('特征 X', fontsize=12)
ax.set_ylabel('目标值 y', fontsize=12)
ax.set_title('线性回归拟合效果:明显欠拟合', fontsize=14)
ax.legend()
plt.tight_layout()
plt.show()
print(f"线性回归模型方程: y = {simple_lr.coef_[0]:.3f} * X + {simple_lr.intercept_:.3f}")
注意:R²分数越接近1,表示模型对数据的解释力越强。从图中和分数可以清晰看到,直线无法捕捉数据的弯曲趋势,这就是欠拟合。模型过于简单,无法学习数据中的潜在模式。
这个视觉化的对比至关重要。它告诉我们,当特征与目标值之间存在非线性关联时,线性模型的表达能力是不够的。我们需要一种方法,让模型能够“看到”X和X²、X³等之间的关系。这就是多项式特征的用武之地。
<
199

被折叠的 条评论
为什么被折叠?



