当前位置: 首页 > news >正文

PyTorch实现二分类

二分类问题的实现方法,核心是把线性回归的 “连续值输出” 改成 “0/1 类别概率输出”。最基础常用的二分类模型基于逻辑回归(Logistic Regression)。

线性回归实现方式:PyTorch实现线性回归-CSDN博客

二分类本质上也是一种回归(Regression)问题,在上述线性回归的基础上修改就可以实现。下面是线性回归与二分类任务的差异:

环节线性回归(回归任务)二分类(分类任务)
输出目标连续数值(如 y=2x 的预测值)0/1 类别概率(0≤P≤1)
核心激活函数无(直接输出线性结果)Sigmoid(把线性输出映射到 0-1)
损失函数MSELoss(均方误差)BCELoss(二元交叉熵损失)
预测逻辑直接取输出值概率 > 0.5 归为 1 类,≤0.5 归为 0 类

1. 准备数据集(Prepare Dataset)

对比线性回归,数据格式还是 Tensor,但标签y_data是0/1 离散值,这是分类任务的核心特征。

import torch # 构造数据集:特征x(学分),标签y(0=不及格,1=及格) # 样本:[1.0], [2.0], [3.0] → 标签:0,0,1 x_data = torch.Tensor([[1.0], [2.0], [3.0]]) y_data = torch.Tensor([[0], [0], [1]])

2. 设计模型(Design model)

class LogisticRegressionModel(torch.nn.Module): def __init__(self): super().__init__() self.linear = torch.nn.Linear(1, 1) def forward(self, x): # 核心:线性输出 + Sigmoid激活 → 映射到0-1概率 y_pred = torch.sigmoid(self.linear(x)) return y_pred model = LogisticRegressionModel()
  • Sigmoid函数公式:
  • forward中增加torch.sigmoid()把线性层的 “任意实数输出” 压缩到0~1 区间,这个值就是 “样本属于 1 类的概率”。

3. 构造损失函数(Construct Loss)

criterion = nn.BCELoss(reduction='sum')
  • BCELoss:二元交叉熵损失,是二分类的专用损失。

关于二元交叉熵损失函数的介绍,参考文章PyTorch_conda-CSDN博客中《nn.BCELoss(二元交叉熵损失)》一节。

4. 构造优化器(Construct Optimizer)

optimizer = torch.optim.SGD(model.parameters(), lr=0.0001)

5. 训练循环(Training Cycle)

for epoch in range(1000): y_pred = model(x_data) # 前向传播(计算预测值) loss = criterion(y_pred, y_data) # 计算损失值 print(epoch, loss.item()) optimizer.zero_grad() # 梯度清零 loss.backward() # 反向传播计算梯度 optimizer.step() # 更新参数

完整实例

import torch import torch.nn as nn x_data = torch.Tensor([[1.0], [2.0], [3.0]]) y_data = torch.Tensor([[0], [0], [1]]) class LogisticRegressionModel(nn.Module): def __init__(self): super().__init__() self.linear = nn.Linear(1, 1) def forward(self, x): return torch.sigmoid(self.linear(x)) model = LogisticRegressionModel() criterion = nn.BCELoss(reduction='sum') optimizer = torch.optim.SGD(model.parameters(), lr=0.01) for epoch in range(10000): y_pred = model(x_data) loss = criterion(y_pred, y_data) if epoch % 1000 == 0: print(f"Epoch: {epoch}, Loss: {loss.item():.4f}") optimizer.zero_grad() loss.backward() optimizer.step() x_test = torch.Tensor([[4.0]]) y_test_pred = model(x_test) print("\n测试结果:") print('y_pred = ', y_test_pred.data) # 查看模型参数 print(f"\n模型权重:{model.linear.weight.item():.6f}") print(f"模型偏置:{model.linear.bias.item():.6f}")
http://www.jsqmd.com/news/359100/

相关文章:

  • (100分)- 对称美学(Java JS Python)
  • java并发:管道流(Piped Streams)的应用场景
  • 【计算机毕业设计案例】基于springboot+vue的微信小程序的智慧校园平台基于springboot+小程序的高校校园信息交流平台小程序设计与实现(程序+文档+讲解+定制)
  • (100分)- 二元组个数(Java JS Python)
  • 计算机小程序毕设实战-基于SpringBoot中小学家校通系统的设计与实现springboot+小程序的家校通程序设计与实现【完整源码+LW+部署说明+演示视频,全bao一条龙等】
  • ubuntu启用ssh (广域网访问)(IPV6访问)
  • 要不然让ai研究原神的界面也行,比如写个skill文件按下某个按键会进入什么界面,不给坐标,搞个程序识别按钮给个固定标签
  • (100分)- 等和子数组最小和(Java JS Python)
  • 【课程设计/毕业设计】基于微信小程序的校园信息交流平台基于springboot+小程序的高校校园信息交流平台小程序设计与实现【附源码、数据库、万字文档】
  • 内网共享神器,手机电脑一键互传大文件
  • (100分)- 端口合并(Java JS Python)
  • 【课程设计/毕业设计】基于springboot+小程序的家校通程序设计与实现消息推送、班级管理、作业管理、考勤管理、成绩管理【附源码、数据库、万字文档】
  • (100分)- 单词倒序(Java JS Python)
  • 小程序毕设项目:基于springboot+小程序的高校校园信息交流平台小程序设计与实现(源码+文档,讲解、调试运行,定制等)
  • 小程序毕设项目:基于springboot+小程序的家校通程序设计与实现(源码+文档,讲解、调试运行,定制等)
  • (100分)- 单向链表中间节点(Java JS Python)
  • (100分)- 打印机队列(Java JS Python)
  • 创业三年,记录来时路
  • jwt和oauth2的原理、特点、区别及使用场景
  • 计算机小程序毕设实战-基于springboot+小程序的高校生活互助平台小程序基于SpringBoot的高校报修与互助平台小程序【完整源码+LW+部署说明+演示视频,全bao一条龙等】
  • 戴尔服务器常用设置
  • 如何在 Teams 中添加一个页面
  • 【课程设计/毕业设计】基于SpringBoot校园生活服务小程序基于springboot+小程序的高校生活互助平台小程序【附源码、数据库、万字文档】
  • STC15F204EA概述
  • 对于tarjan的思考
  • 小程序毕设项目:基于springboot+小程序的高校生活互助平台小程序(源码+文档,讲解、调试运行,定制等)
  • Python快速入门——学习笔记(持续更新中~)
  • 2月8日-(OpenSpec规范)
  • 《深入理解Java虚拟机》| 运行时数据区与OOM异常
  • 小程序计算机毕设之基于springboot+小程序的高校生活互助平台小程序基于SpringBoot校园生活服务小程序(完整前后端代码+说明文档+LW,调试定制等)