当前位置: 首页 > news >正文

从西瓜数据到决策边界:手把手实现周志华《机器学习》中的对率回归分类器

1. 从西瓜数据到决策边界:初识对率回归

第一次翻开周志华老师的《机器学习》时,我被"对率回归"这个名词搞得一头雾水。听起来像是要做对数运算的回归分析?后来才发现这其实就是我们常说的逻辑回归(Logistic Regression),只不过教材采用了更严谨的学术命名。对率回归最神奇的地方在于,它能把线性回归的输出压缩到0到1之间,正好对应概率的概念。

西瓜数据集3.0α是个绝佳的入门案例。这个数据集记录了17个西瓜样本的两个关键特征:密度和含糖率,以及它们是否是好瓜的标签。我刚开始学的时候总在想,为什么不用简单的阈值判断呢?比如含糖率高于0.3就是好瓜。但实际数据会打脸——编号15的西瓜含糖率0.37却被标记为坏瓜,这说明单一特征判断会出错。

对率回归的优势这时候就显现出来了。它通过Sigmoid函数将线性组合wTx+b映射到(0,1)区间,相当于用概率的方式表达分类结果。比如输出0.7表示有70%概率属于正类。这种处理方式既考虑了特征间的线性组合,又保证了输出符合概率定义,比简单阈值法靠谱多了。

2. 数据准备与自定义DataLoader

2.1 理解西瓜数据集的结构

西瓜数据集3.0α虽然只有17条数据,但结构非常清晰:

  • 前两列是数值特征:密度(0.243~0.774)、含糖率(0.0267~0.46)
  • 最后一列是分类标签:1表示好瓜,0表示坏瓜

我建议先用pandas的describe()看看数据分布:

import pandas as pd data = pd.read_csv('watermelon_3.0a.csv') print(data.describe())

输出会显示两个特征的均值、标准差等信息。观察发现好瓜的平均含糖率(0.32)确实高于坏瓜(0.13),但存在交叉区域,这就是需要用模型学习的地方。

2.2 仿PyTorch实现DataLoader

虽然可以直接用numpy数组,但模仿PyTorch的DataLoader接口会让代码更规范。我实现的版本主要包含三个关键方法:

  • __init__:读取CSV文件并提取特征矩阵和标签矩阵
  • __len__:返回数据集样本数
  • __getitem__:支持索引访问和迭代
class WatermelonLoader: def __init__(self, data_path, x_cols, y_col): self.data = pd.read_csv(data_path) self.x = self.data[x_cols].values self.y = self.data[y_col].values.reshape(-1,1) def __len__(self): return len(self.x) def __getitem__(self, idx): return self.x[idx], self.y[idx]

使用时就像这样:

loader = WatermelonLoader('watermelon_3.0a.csv', ['密度','含糖率'], '好瓜') for x, y in loader: print(f"特征:{x}, 标签:{y}")

3. 对率回归模型的数学原理与实现

3.1 Sigmoid函数与决策边界

对率回归的核心是Sigmoid函数: σ(z) = 1/(1+e⁻ᶻ)

这个S型曲线将任意实数映射到(0,1)区间。当z=wTx+b>0时,σ(z)>0.5,我们预测为正类;反之预测为负类。决策边界就是wTx+b=0这个超平面。

在西瓜数据集的二维情况下,决策边界是一条直线: w₁·密度 + w₂·含糖率 + b = 0

3.2 极大似然估计推导

与线性回归用最小二乘法不同,对率回归使用极大似然估计。对于单个样本,其似然函数为: L(w,b) = ŷʸ(1-ŷ)⁽¹⁻ʸ⁾ 其中ŷ=σ(wTx+b)

对所有样本取对数似然: ℓ(w,b) = Σ[yⁱlog(ŷⁱ)+(1-yⁱ)log(1-ŷⁱ)]

我们的目标就是最大化这个对数似然函数。通过求导可以得到梯度: ∂ℓ/∂w = Σ(yⁱ-ŷⁱ)xⁱ

3.3 Python实现细节

我实现的LogisticRegression类包含三个关键方法:

import numpy as np class LogisticRegression: def __init__(self, lr=0.1): self.w = None self.lr = lr def sigmoid(self, z): return 1 / (1 + np.exp(-z)) def fit(self, X, y, epochs=100): # 添加偏置项 X = np.c_[X, np.ones(X.shape[0])] self.w = np.zeros(X.shape[1]) for _ in range(epochs): z = np.dot(X, self.w) y_pred = self.sigmoid(z) grad = np.dot(X.T, (y - y_pred)) self.w += self.lr * grad def predict(self, X): X = np.c_[X, np.ones(X.shape[0])] return (self.sigmoid(np.dot(X, self.w)) >= 0.5).astype(int)

注意几个关键点:

  1. 在特征矩阵X最后添加一列1,相当于把偏置b并入权重向量
  2. 使用向量化实现,避免低效的循环
  3. 学习率lr不宜过大,否则容易震荡

4. 模型训练与评估实战

4.1 留一法交叉验证

由于数据集只有17个样本,我采用留一法(Leave-One-Out)进行验证:

from sklearn.metrics import accuracy_score def loo_validation(data, epochs=100): accuracies = [] for i in range(len(data)): train = np.delete(data, i, axis=0) test = data[i:i+1] model = LogisticRegression() model.fit(train[:,:2], train[:,2], epochs) pred = model.predict(test[:,:2]) accuracies.append(pred == test[:,2]) return np.mean(accuracies) print(f"留一法准确率:{loo_validation(data.values):.2%}")

4.2 训练过程可视化

观察权重变化能更好理解模型学习过程:

plt.figure(figsize=(10,4)) for epoch in [10,50,100]: model = LogisticRegression() model.fit(X, y, epochs=epoch) # 绘制决策边界 x1 = np.linspace(0.2,0.8,100) x2 = -(model.w[0]*x1 + model.w[2])/model.w[1] plt.plot(x1,x2, label=f'epoch={epoch}') plt.scatter(X[y==0,0],X[y==0,1], c='blue', label='坏瓜') plt.scatter(X[y==1,0],X[y==1,1], c='red', label='好瓜') plt.legend()

可以看到随着训练轮次增加,决策边界逐渐移动到更合理的位置,将多数样本正确分类。

5. 决策边界与Sigmoid函数可视化

5.1 绘制二维决策边界

最终的决策边界可视化:

def plot_decision_boundary(model, X, y): # 创建网格点 x1_min, x1_max = X[:,0].min()-0.1, X[:,0].max()+0.1 x2_min, x2_max = X[:,1].min()-0.1, X[:,1].max()+0.1 xx1, xx2 = np.meshgrid(np.linspace(x1_min,x1_max,100), np.linspace(x2_min,x2_max,100)) # 预测每个网格点 Z = model.predict(np.c_[xx1.ravel(),xx2.ravel()]) Z = Z.reshape(xx1.shape) # 绘制 plt.contourf(xx1,xx2,Z,alpha=0.3) plt.scatter(X[y==0,0],X[y==0,1], c='blue', label='坏瓜') plt.scatter(X[y==1,0],X[y==1,1], c='red', label='好瓜') plt.xlabel('密度') plt.ylabel('含糖率')

5.2 Sigmoid函数与样本分布

理解Sigmoid如何将线性输出转为概率:

z = np.linspace(-10,10,100) y = 1/(1+np.exp(-z)) plt.figure(figsize=(10,4)) plt.subplot(121) plt.plot(z,y) plt.title('Sigmoid函数') # 绘制样本在Sigmoid曲线上的位置 z_samples = X @ model.w[:-1] + model.w[-1] y_samples = 1/(1+np.exp(-z_samples)) plt.subplot(122) plt.scatter(z_samples[y==0], y_samples[y==0], c='blue') plt.scatter(z_samples[y==1], y_samples[y==1], c='red') plt.plot(z,y,'k--')

右图显示好瓜样本(红色)大多位于Sigmoid曲线右侧(z>0),而坏瓜样本(蓝色)多在左侧。这正是我们希望看到的分布。

6. 完整代码实现与优化建议

6.1 完整代码结构

建议按以下结构组织代码:

/logistic_regression │── data/ │ └── watermelon_3.0a.csv │── utils.py # DataLoader等工具类 │── model.py # LogisticRegression实现 │── train.py # 训练与评估脚本 │── visualize.py # 可视化代码

6.2 性能优化技巧

在实际项目中,我总结了几个优化点:

  1. 添加L2正则化防止过拟合:
def fit(self, X, y, epochs=100, reg=0.1): # 在梯度更新步骤添加 self.w += self.lr * (grad - reg*self.w)
  1. 使用随机梯度下降(SGD)加速收敛:
for epoch in range(epochs): idx = np.random.permutation(len(X)) for i in idx: xi, yi = X[i:i+1], y[i:i+1] # 计算单个样本梯度并更新
  1. 添加早停机制(Early Stopping):
best_loss = float('inf') for epoch in range(epochs): # ...训练代码... current_loss = compute_loss() if current_loss > best_loss: break best_loss = current_loss

对率回归虽然简单,但包含了许多机器学习的关键思想。通过这个西瓜数据集的实践,我深刻理解了从数据准备、模型实现到评估可视化的完整流程。特别是在实现梯度下降时,手动推导并验证梯度的正确性让我对反向传播有了更直观的认识。建议初学者一定要亲手实现一遍,这比直接调用sklearn的LogisticRegression收获大得多。

http://www.jsqmd.com/news/853928/

相关文章:

  • 智慧工业火花火星烟火火灾检测数据集VOC+YOLO格式3965张4类别
  • 测试工程师的终身学习:如何保持测试技术竞争力
  • 终极指南:3分钟快速上手AMD Ryzen调试神器SMUDebugTool
  • 2026 PM知行商学院深度解析:定位、适配人群与创业优势测评 - 资讯速览
  • 从‘实体’到‘铰接’:一个SOLIDWORKS Simulation案例,带你理解有限元中的约束本质
  • 用STM32CubeMX的TIM6实现精准1秒定时:HAL库与LL库代码对比与选择建议
  • 终于有人把图计算讲明白了
  • 如何将 Infinix 手机中的联系人传输到 iPhone
  • Layerdivider终极指南:5步掌握AI图像分层技术,免费生成专业PSD文件
  • 如何在Photoshop中无缝集成AI绘图能力?SD-PPP插件的完整指南
  • 【vue】avue-crud表格与列属性实战:从配置清单到高效开发
  • 测试工程师的人生规划:如何平衡测试工作和生活
  • Vue3 Composition API:深度解析与最佳实践
  • 非谓语动词实战指南:解锁不定式、分词与动名词的进阶用法
  • 2026 广州天河空调移机 海珠空调维修服务前五强:拆装移机、中央空调维修清洗,靠谱实惠首选 - 广州搬家老班长
  • 从账单明细看 Taotoken 按 Token 计费模式带来的成本控制优势
  • wms系统核心功能拆解:wms系统如何提升库存准确率与作业效率
  • Nginx 是独立的反向代理 / 负载均衡软件;Ingress 是 K8s 的路由规则 API,本身不处理流量,需要 Ingress Controller(最常见就是 Nginx Ingress)
  • 告别命令盲敲:在甲骨文ARM服务器上为宝塔面板做这些安全初始化
  • 三菱PLC上位机开发避坑指南:MC协议读写D寄存器时,Float和Double到底差几个点?
  • 测试工程师的幸福感:如何在测试工作中找到成就感
  • 从化做出口怎么找财税服务商?从化出口企业找财税服务商,这6个陷阱踩了就是真金白银的损失 - 欢欢在创业
  • ExternalDNS 配置实践:自动化 DNS 记录管理
  • 从零到一:基于TrueNAS SCALE构建家庭媒体与数据备份中心
  • 2026 广州天河保洁 海珠开荒保洁前五强 开荒 上门 办公室保洁 - 广州搬家老班长
  • 不止于显示图片:在ROS2 Foxy中,用OpenCV和cv_bridge玩转摄像头图像订阅与简单处理
  • 专业视角 | 宜昌高考志愿填报的「隐形陷阱」:90%家长忽略了这三点 - 新闻快传
  • 从零到一:STM32驱动TM1637四位数码管实战解析
  • 企业如何利用多模型聚合能力构建稳定的AI客服系统
  • Vue3响应式原理:深入理解Proxy和Ref