当前位置: 首页 > news >正文

Pytorch基础——(3)神经网络工具箱

文章目录

  • 一、基础知识
  • 二、构建模型
    • 1.1 方法1:继承nn.Model基类构建模型
    • 1.2 方法2:使用 nn.Sequential 容器
      • 1.2.1 添加参数
      • 1.2.2 add_module可指定名称
      • 1.2.2 orderedDict可指定名称
    • 1.3 结合1和2,集成基类并使用模拟容器
      • 1.3.1 使用nn.Sequential()
      • 1.3.2 使用ModuleList
      • 1.3.3 使用ModuleDict
  • 三、训练和评估模型
    • 3.1 训练和评估步骤
    • 3.2 Python代码

一、基础知识

torch.nn:nn 是 Neural Networks(神经网络) 的缩写。它就是一个工具库,里面包含了深度学习常用的所有零件建造神经网络的工具箱。nn 就是给你提供积木,让你搭神经网络。最常用的:

  • 层结构
    • nn.Linear(5, 1) 线性层(全连接层)
    • nn.Conv2d 卷积层
    • nn.ReLU / nn.Sigmoid 激活函数
  • 损失函数
    • nn.MSELoss() 回归损失
    • nn.CrossEntropyLoss() 分类损失
  • 其他工具
    • nn.Sequential 快速堆叠网络
    • nn.BatchNorm2d 归一化

nn.model:是所有模型的骨架 / 模板 / 父类,模型都必须继承 nn.Module,因为它帮你自动完成:

  • 自动管理参数 w、b
  • 自动求梯度(backward)
  • 自动更新参数
  • 自动保存 / 加载模型
  • 自动把模型搬到 GPU

nn.functional:nn.functional(简称 F),是 PyTorch 里的数学函数工具箱它里面全是无参数、纯计算的函数,专门用来做:

  • 激活计算(relu、sigmoid)
  • 损失计算(mse_loss)
  • 池化计算(max_pool)
  • 归一化、softmax、dropout 等

nn.model 里面的 nn.Xxx(nn.Linear)

  • nn.Xxx 继承于 nn.model,需要先实例化并传入参数,然后以函数调用的方式,调用实例对象,并传入输入数据
  • nn.Xxx不需要自己定义和管理 weight 和 bias 参数
importtorchimporttorch.nnasnn# 1. 实例化(创建层,并传入参数 w 和 b)layer=nn.Linear(5,1)# 2. 构造输入数据x=torch.randn(10,5)# 3. 把层当函数用,函数调用实例对象 layer, 传入输入数据y=layer(x)# 等价于 y = layer.forward(x)

nn.functinal里面的函数 nn.functional.xxx(nn.functional.linear)

  • nn.functional.xxx需要自己定义和管理 weight 和 bias 参数,每次调用的时候需要手动传入。

建议:

  • 具有学习参数的:必须用 nn.XXX,如 Linear, Conv2d, BatchNorm,不能用 nn.functional 代替,因为权重不会被自动管理。
  • 没有学习参数的参数,如 relu, pool, softmax,推荐用 nn.functional
importtorch.nn.functionalasF F.relu(x)# 激活函数F.sigmoid(x)F.tanh(x)F.max_pool2d(x)# 池化F.avg_pool2d(x)F.dropout(x)# 随机失活F.softmax(x,dim=1)

二、构建模型

如下图,采用不同的方式构建一下神经网络:

1.1 方法1:继承nn.Model基类构建模型

nn.Module是 PyTorch 中所有神经网络模块的基类,它提供了三大核心功能:

  • 参数管理:自动跟踪和管理网络中的可学习参数(权重、偏置等),方便优化器更新。
  • 子模块管理:自动注册和管理网络中的子层(如 nn.Linear、nn.Conv2d 等),支持递归遍历。
  • 前向传播接口:规定必须实现 forward 方法,定义数据在网络中的流动逻辑。

完整步骤:

  • 导入必要的库;
  • 定义一个类继承 nn.Module 并实现2个核心方法:
    • 在_init_ 方法中定义需要用到的层·
      • 调用父类 nn.Module 的初始化方法
      • 定义网络中用到的所有层(如全连接层、批归一化层等),并绑定为类的属性(self.xxx)。
    • 在 forward 方法中手动编写数据流动逻辑(前向传播)
      • 模型的输入数据为x
      • 定义数据 x 如何从输入经过各层处理,最终输出结果
importtorchimporttorch.nnasnnimporttorch.nn.functionalasFclassMyModule(nn.Module):def__init__(self,indim,h1,h2,outdim):super(MyModule,self).__init__()self.flatten=nn.Flatten()self.Linear1=nn.Linear(indim,h1)self.bn1=nn.BatchNorm1d(h1)self.Linear2=nn.Linear(h1,h2)self.bn2=nn.BatchNorm1d(h2)self.out=nn.Linear(h2,outdim)defforward(self,x):x=self.flatten(x)x=self.Linear1(x)x=self.bn1(x)x=F.relu(x)# 用nn.functionalx=F.relu(self.bn2(x=self.Linear2(x)))# 全连接层2+批归一化2+激活层2x=self.out(x)x=F.softmax(x,dim=1)# 用nn.functionalreturnx model=MyModule(28*28,300,100,10)print(model)
MyModule((flatten): Flatten(start_dim=1,end_dim=-1)(layer1): Linear(in_features=784,out_features=300,bias=True)(bn1): BatchNorm1d(300,eps=1e-05,momentum=0.1,affine=True,track_running_stats=True)(layer2): Linear(in_features=300,out_features=100,bias=True)(bn2): BatchNorm1d(100,eps=1e-05,momentum=0.1,affine=True,track_running_stats=True)(out): Linear(in_features=100,out_features
http://www.jsqmd.com/news/707064/

相关文章:

  • Phi-3-mini-4k-instruct-gguf效果展示:Chainlit前端实时流式输出+Markdown格式化响应截图
  • 从0到1集成FlyRefresh:Android开发者必备的下拉刷新解决方案
  • 2026年怎么选变压器生产厂家:变压器回收价格/变压器回收公司/变压器回收厂家/变压器回收多少钱一台/干式变压器厂家/选择指南 - 优质品牌商家
  • 2.6 应用容器:给应用套上的“现代化沙箱”
  • TVA检测技术在普通电子元器件领域的全维度解析(17)
  • 团体程序设计天梯赛竞赛题--登顶题【L3-043 门诊预约排队系统】
  • 南京邮电大学电装实习报告-2026版
  • 大学生就业信息管理|基于java+ vue大学生就业信息管理系统(源码+数据库+文档)
  • Qwen-Turbo-BF16部署教程:离线环境预下载模型权重与LoRA文件校验方案
  • AI项目环境管理利器:PyTorch 2.9云端镜像多实例使用攻略
  • 【Linux3】压缩解压缩,命令解释器,账户和组管理,文件系统权限
  • Arm A-profile架构TLB维护与内存管理机制解析
  • nlp_structbert_sentence-similarity_chinese-large效果展示:多领域中文文本相似度计算案例集
  • Python时间序列数据分析:从基础到实战
  • Qianfan-OCR在MobaXterm中的实践:远程服务器部署与中文环境调试
  • Phi-3.5-Mini-Instruct实战手册:系统提示词工程——从通用助手到领域专家
  • C++位图学习笔记
  • 【大白话说Java面试题】【Java基础篇】第8题:HashMap在计算元素下标时,为什么要进行二次hash
  • 线性表小回顾
  • Linux 0.11源码深度解析:kernel/chr_drv/tty_io.c —— 终端I/O的控制中枢与行规约引擎
  • Python新手在PyCharm写if总报错?5个坑90%人踩过,看完修复
  • C语言函数全解析
  • AI自主监测宠物健康,陪狗都不用自己来了!涂鸦Hey Tuya打造全屋智能“超级入口”
  • 快速上手:使用Clawdbot将星图平台Qwen3-VL接入飞书,实现智能问答
  • 【Linux从入门到精通】第17篇:日志系统——系统运行的黑匣子
  • 深度解析YOLOv11多光谱目标检测的技术实现与性能优化
  • 第78篇:AI辅助创意与设计工作流——Logo、海报、UI的自动化生成与迭代(操作教程)
  • 万物识别中文镜像部署教程:环境配置与推理测试
  • Python Web框架实战:Flask与Dash构建数据应用
  • OpenClaw本地部署接入飞书机器人并安装Skills(图文并茂超详细)