当前位置: 首页 > news >正文

别再死记硬背公式了!用Python手写线性回归,从MSE、R²到梯度下降一次搞懂

从零实现线性回归:用Python揭开机器学习黑箱

当你第一次接触机器学习时,线性回归往往是入门的第一课。但太多教程止步于调用sklearn的几行代码,把最关键的原理变成了一个"黑箱"。今天,我们要用Python和NumPy亲手拆解这个黑箱,从数学推导到代码实现,完整走一遍线性回归的构建过程。

1. 线性回归的本质与数学基础

线性回归的核心思想非常简单:找到一条直线(在高维空间中是超平面),使得所有数据点到这条直线的垂直距离之和最小。但这条简单的直线背后,蕴含着丰富的数学原理。

关键概念解析

  • 假设函数:$h_\theta(x) = \theta_0 + \theta_1x_1 + \theta_2x_2 + ... + \theta_nx_n$
  • 参数(θ):模型需要学习的权重值
  • 特征(x):输入数据的各个维度

在实际应用中,我们通常会遇到两种求解线性回归参数的方法:

方法类型求解方式优点缺点
解析解(最小二乘法)直接通过矩阵运算求解一次计算得到最优解大数据集计算成本高
数值解(梯度下降)迭代逼近最优解适合大规模数据需要选择学习率等超参数

2. 实现最小二乘法解析解

最小二乘法是线性回归最直接的求解方式,它通过矩阵运算直接计算出最优参数θ。让我们看看如何用NumPy实现:

import numpy as np class LinearRegression: def __init__(self): self.theta = None def fit(self, X, y): # 添加偏置项(x0=1) X_b = np.c_[np.ones((X.shape[0], 1)), X] # 计算解析解:(X^T X)^-1 X^T y self.theta = np.linalg.inv(X_b.T.dot(X_b)).dot(X_b.T).dot(y) return self def predict(self, X): X_b = np.c_[np.ones((X.shape[0], 1)), X] return X_b.dot(self.theta)

实现细节说明

  1. np.c_用于将一列1与原始特征矩阵拼接,对应偏置项θ0
  2. np.linalg.inv计算矩阵的逆
  3. 矩阵乘法遵循(X^T X)^-1 X^T y的数学公式

注意:当特征矩阵X^T X不可逆时,需要添加小的正则项或使用伪逆

3. 梯度下降法实现

对于大规模数据集,矩阵求逆可能计算量太大。这时,梯度下降这种迭代方法就派上用场了。梯度下降通过不断沿损失函数梯度方向更新参数,逐步逼近最优解。

批量梯度下降实现

def fit_gradient_descent(self, X, y, learning_rate=0.01, n_iters=1000): n_samples, n_features = X.shape self.theta = np.zeros(n_features + 1) # 初始化参数 X_b = np.c_[np.ones((n_samples, 1)), X] for _ in range(n_iters): gradients = 2/n_samples * X_b.T.dot(X_b.dot(self.theta) - y) self.theta -= learning_rate * gradients return self

梯度下降有几个关键点需要注意:

  • 学习率选择:太大导致震荡,太小收敛慢
  • 迭代次数:需要足够让算法收敛
  • 特征缩放:不同特征尺度差异大时需要归一化

梯度下降变体对比

类型每次迭代样本数收敛速度内存需求
批量梯度下降全部样本稳定但慢
随机梯度下降1个样本快但不稳定
小批量梯度下降小批量样本折中中等

4. 模型评估指标实现

模型建好了,如何评估它的表现呢?最常用的两个指标是均方误差(MSE)和R²分数。

MSE实现

def mse_score(y_true, y_pred): return np.mean((y_true - y_pred) ** 2)

R²分数实现

def r2_score(y_true, y_pred): ss_res = np.sum((y_true - y_pred) ** 2) ss_tot = np.sum((y_true - np.mean(y_true)) ** 2) return 1 - (ss_res / ss_tot)

这两个指标各有侧重:

  • MSE:直接反映预测值与真实值的平均平方误差
  • :表示模型解释的方差比例,范围通常在0-1之间

提示:在实际项目中,建议同时计算多个评估指标,从不同角度评估模型性能

5. 完整案例演示

让我们用一个完整的例子把所有这些概念串联起来。假设我们有一组房屋面积与价格的数据:

# 生成示例数据 np.random.seed(42) X = 2 * np.random.rand(100, 1) y = 4 + 3 * X + np.random.randn(100, 1) # 划分训练测试集 X_train, X_test = X[:80], X[80:] y_train, y_test = y[:80], y[80:] # 训练模型 lin_reg = LinearRegression() lin_reg.fit(X_train, y_train) # 预测并评估 y_pred = lin_reg.predict(X_test) print("MSE:", mse_score(y_test, y_pred)) print("R²:", r2_score(y_test, y_pred)) # 可视化结果 import matplotlib.pyplot as plt plt.scatter(X, y) plt.plot(X, lin_reg.predict(X), color='red') plt.xlabel('房屋面积') plt.ylabel('价格') plt.show()

这个完整流程展示了从数据准备到模型评估的全过程。在实际项目中,你还需要考虑:

  • 数据预处理(缺失值、异常值处理)
  • 特征工程(特征选择、多项式特征)
  • 模型调参(正则化、学习率调整)

6. 进阶话题与优化方向

掌握了线性回归的基础实现后,有几个重要的进阶方向值得探索:

正则化方法

  • 岭回归(L2正则):解决特征共线性问题
  • Lasso回归(L1正则):自动进行特征选择

多项式回归: 通过增加特征的高次项,线性回归可以拟合非线性关系:

from sklearn.preprocessing import PolynomialFeatures poly_features = PolynomialFeatures(degree=2, include_bias=False) X_poly = poly_features.fit_transform(X)

数值稳定性优化

  • 使用np.linalg.pinv代替逆矩阵计算
  • 添加小的正则项保证矩阵可逆
  • 实现QR分解等更稳定的求解方法

在实现这些优化时,你会发现线性回归这个看似简单的模型,其实蕴含着丰富的优化空间和技巧。

http://www.jsqmd.com/news/907363/

相关文章:

  • 深入解析 SmartPrintAI:基于 MAF + DeepSeek + MCP 的智能物流打印平台
  • 免费服务器指南:GitHub Pages搭建静态网站全攻略
  • Bootstrap方法避坑指南:什么时候用?什么时候千万别用?(附R代码验证)
  • 从安装到第一个视觉项目:Halcon20.11环境搭建与‘Hello World’实战
  • Conan C++ 包管理工具深度解析
  • 26HVV护网行动 初 中 高 级人员招聘
  • 7nm工艺下,我为什么从ICC2换到了Innovus?聊聊真实项目里的那些坑
  • 测试左移 + 右移 + 自动化,三位一体构建质量护城河
  • 别再只仿真了!用100个三极管在面包板上还原4位加法器,我总结了这些避坑指南
  • CocosCreator 2.4.4 长列表性能翻倍:手把手教你实现带缓存池的无尽循环列表(告别图片闪烁)
  • 华为BGP选路实战:用这3个属性(PrefVal、Local_Pref、MED)轻松搞定网络流量调度
  • AMD电脑装VMware报错?手把手教你进BIOS开启SVM Mode(附华硕/微星/技嘉主板截图)
  • EasyOCR模型下载太慢?手把手教你离线部署与自定义训练,打造专属OCR识别引擎
  • 有机化学真的在指数增长吗?数据告诉你另一个故事
  • 告别‘丑地图’!用ArcGIS Pro的视觉效果和后处理,轻松打造高级感分析图
  • RAG 04:向量数据库与索引算法
  • Shader - 水体(保姆级)
  • CentOS环境下手动升级openssl、openssh
  • MacType字体渲染引擎深度解析:Windows字体美化的核心技术方案
  • AVL Cruise 2023 保姆级教程:手把手教你用自带实例模型搞定纯电动车续航仿真
  • RTX51 Tiny在SiLABS SFR分页机制下的移植优化
  • RTX51 Tiny调试技巧与C源代码显示问题解析
  • 在mac上安装hermes
  • 鼎捷Tiptop ERP 5.3版本下,手把手教你用SoapUI测试一个用户登录WebService接口
  • RAG 技术体系:从向量检索到生产级 Pipeline
  • 保姆级教程:用PyTorch Geometric搭建GCN,实战DEAP脑电情绪分类(附完整代码)
  • 深入UGUI底层:手把手教你用OnPopulateMesh和顶点偏移,实现Image的任意变形(不只是倾斜)
  • 大数据处理:Spark与分布式计算
  • 用 Nerfstudio 和手机照片,5分钟快速生成你的第一个 3D 数字手办(Nerfacto 模型实战)
  • 告别双系统安装噩梦:Intel RST模式下无损切换AHCI,保住Windows再装Ubuntu