当前位置: 首页 > news >正文

PyTorch实现二分类(多特征输出+多层神经网络)

前置文章:PyTorch实现二分类(单特征输出+单层神经网络)-CSDN博客

⭐处理多维特征输入

在上述实例中,x_data = torch.Tensor([[1.0], [2.0], [3.0]])是二维列表(矩阵),外层列表表示样本集,内层每个列表表示单个样本的特征,即表示输入三个样本,每个样本的维度为1维(1个特征feature)。当输入为多维时,如下所示二分类问题中的维度是8维(3行8列的矩阵)(8个特征features):

x_data = torch.Tensor([[-0.29, 0.49, 0.18, -0.29, 0.00, 0.00, -0.53, -0.03], [-0.88, -0.15, 0.08, -0.41, 0.00, -0.21, -0.77, -0.67], [-0.06, 0.84, 0.05, 0.00, 0.00, -0.31, -0.49, -0.63]]) y_data = torch.Tensor([[0], [1], [0]])

对于多维特征输入,模型的输入也要能够接收输入的多个特征。公式变化如下:

  • 原始公式一个样本只有一个特征,将此特征值代入线性公式后再经过激活函数Sigmoid非线性化后即可得到第个样本对应的输出
  • 变化后公式一个样本有8个特征(),将这8个特征分别进行线性公式(一个特征对应一个权重)后再经过激活函数Sigmoid非线性化后即可得到第个样本对应的输出

实际上就是将公式从标量(单维度)计算转换为矩阵计算:

  • 单样本(一维输入,单维输出):

  • N个样本(二维输入,多维输出):

输出矩阵,输入矩阵,权重矩阵

在PyTorch中我们只需要修改相应层的输入与输出维度即可,如下以线性层为例,将输入的8维数据经过线性层后降维为6维输出数据:

⭐神经网络

对于神经网络而言,就是将网络层堆叠几次,使输入数据的维度不断降低或升高(下面的图片实例是维度不断下降)。如果你最终需要的输出是一个数值,那最终就需要降维到一维,如果最终你需要的输出是一个矩阵那就降维到二维。

至于中间的维度如何变化、需要几层网络层以及需要哪些网络层

  • 上图所示的神经网络结构是全连接层 (Linear) + 激活函数 (σ)堆叠,属于最基础的多层感知机 (MLP),维度变化完全由torch.nn.Linear(in_features, out_features)决定,规则只有一条,就是上一层的输出维度 = 下一层的输入维度。
  • 输入8 维是原始特征,6 维、2 维是抽象特征,最终的1 维是任务输出。之所以要设置中间的8维到6维,6维到2维的隐藏层而不是直接8维到1维,目的是逐层提取核心特征,抛弃冗余信息,让模型学到规律。
  • 层数的选择:隐藏层越多,模型的学习能力越强,所以对于复杂任务(高维特征)隐藏层一般都在3层及以上。但是这并不意味着层数越多越好,模型学习能力过强会将数据的噪点都学习,出现过拟合(Overfitting)现象。
  • 维度的变化:一般使用漏斗型结构,即维度逐渐减小,不建议中间维度突然增大,低维输入强行扩维,会学噪声,过拟合。
  • 网络结构的选择:中间隐藏层的激活函数一般使用ReLU,Sigmoid易出现梯度消失。输出层一般使用Sigmoid(二分类)、Softmax(多分类)。除此之外也可以加一些其他神经网络层,并不只局限于线性层。

模型设计过程:

完整代码实例

在此例中,一个样本输入的特征有8个,要训练出好的模型参数需要的样本数要远远大于8个,这里为了方便只给出了三个样本。

import torch x_data = torch.Tensor([[-0.29, 0.49, 0.18, -0.29, 0.00, 0.00, -0.53, -0.03], [-0.88, -0.15, 0.08, -0.41, 0.00, -0.21, -0.77, -0.67], [-0.06, 0.84, 0.05, 0.00, 0.00, -0.31, -0.49, -0.63]]) y_data = torch.Tensor([[0], [1], [0]]) class Model(torch.nn.Module): def __init__(self): super().__init__() self.linear1 = torch.nn.Linear(8, 6) self.linear2 = torch.nn.Linear(6, 4) self.linear3 = torch.nn.Linear(4, 1) self.sigmoid = torch.nn.Sigmoid() def forward(self, x): x = self.sigmoid(self.linear1(x)) x = self.sigmoid(self.linear2(x)) x = self.sigmoid(self.linear3(x)) return x model = Model() criterion = torch.nn.BCELoss(reduction='sum') optimizer = torch.optim.SGD(model.parameters(), lr=0.01) for epoch in range(10000): y_pred = model(x_data) loss = criterion(y_pred, y_data) if epoch % 1000 == 0: print(f"Epoch: {epoch}, Loss: {loss.item():.4f}") optimizer.zero_grad() loss.backward() optimizer.step() x_test = torch.Tensor([[-0.41, 0.17, 0.21, 0.00, 0.00, 0.24, -0.89, -0.70]]) y_test_pred = model(x_test) print("\n测试结果:") print('y_pred = ', y_test_pred.data)
http://www.jsqmd.com/news/363393/

相关文章:

  • OFA视觉蕴含模型实战案例:电商搜索排序中文本相关性增强
  • 使用Anaconda快速搭建Nano-Banana开发环境
  • 游戏资源提取工具:零基础提取游戏素材完整攻略
  • NCM音频格式解锁全攻略:从加密原理到无损转换的技术探索
  • 如何实现文件格式转换与跨平台兼容:qmcdump工具的完整应用指南
  • LVGL下拉列表lv_ddlist全API详解与嵌入式实战
  • Qwen3-ASR-0.6B在C语言项目中的嵌入式集成
  • PasteMD高级配置指南:定制你的剪贴板转换规则
  • STM32蓝牙遥控机械臂:硬件匹配、协议解析与PWM运动控制
  • 解锁智能翻译工具:从入门到精通的游戏本地化实战指南
  • Pi0具身智能GitHub协作:开源项目管理实战
  • 使用VSCode调试通义千问3-Reranker-0.6B模型的完整指南
  • Pi0具身智能模型安全防护与对抗样本防御
  • 图片旋转判断高效率:单卡4090D每小时处理2.7万张JPEG/PNG图像
  • Seedance2.0像素级一致性算法原理(含3类典型失效场景的数学建模+Jacobian奇异点规避策略)
  • 六音音源配置完全指南:音乐播放修复与音源配置优化详解
  • 【工业视觉落地生死线】:Seedance2.0突破传统光流局限的4层自适应一致性验证机制,已通过ISO/IEC 19794-5认证
  • Qwen3-Reranker-8B量化部署:在边缘设备上的实践
  • 云容笔谈从零开始:东方审美影像生成系统环境搭建与首次生成步骤
  • 音乐插件系统:多平台音频资源聚合解决方案
  • HC-05/HC-06蓝牙模块AT指令配置全解析
  • Nano-Banana在MobaXterm中的远程开发配置
  • 春联生成模型-中文-base实战教程:两字祝福词一键生成高清春联
  • 突破限制:Windows多用户远程访问完全指南(2024实测有效)
  • AnimateDiff插件开发:C++高性能扩展模块编写指南
  • 漫画脸提示词生成器:Vue前端集成Qwen3-32B模型实战
  • Chandra AI与强化学习结合:游戏AI开发实战
  • 高效留存与智能管理:内容导出工具XHS-Downloader全攻略
  • AIGlasses OS Pro Python爬虫实战:智能网页内容抓取
  • HY-Motion 1.0在计算机网络教学中的可视化应用