当前位置: 首页 > news >正文

用pytorch来自动求导

PyTorch 提供了一个非常强大的自动求导引擎(Autograd),它能够自动计算神经网络中张量的梯度,是训练深度学习模型的基础。

1. 核心概念:计算图与 requires_grad

在 PyTorch 中,当创建一个张量(Tensor)并设置 requires_grad=True 时,PyTorch 会开始追踪对该张量的所有操作,并构建一个计算图。这个图记录了从输入张量到输出张量的计算过程,为后续的梯度计算做准备。

import torch# 创建一个需要梯度的张量
x = torch.tensor([2.0], requires_grad=True)
print(x)
# 输出: tensor([2.], requires_grad=True)

2. 正向传播:构建计算图

对设置了 requires_grad 的张量进行运算,会生成新的张量,并且这些新张量也会自动关联梯度信息。

y = x ** 2      # y = x^2
z = y.mean()    # 对标量求导更方便,通常损失是标量
print(z)        # 输出: tensor(4., grad_fn=<MeanBackward0>)

此时,z 是一个标量(因为只有一个元素),并且它有一个 grad_fn 属性,表明它是如何计算出来的,PyTorch 已经构建好了从 xz 的计算图。

3. 反向传播:自动计算梯度

调用 z.backward() 即可自动计算 z 对所有需要梯度的张量的梯度。梯度会累积在张量的 .grad 属性中。

z.backward()    # 反向传播
print(x.grad)   # 查看 dz/dx
# 输出: tensor([4.]) 
# 因为 z = mean(x^2) = (x^2)/1 当只有一个元素时,z=x^2,导数为 2x,当 x=2 时,导数为 4

4. 在训练循环中的应用:梯度清零与参数更新

在训练神经网络时,我们通常使用优化器(如 SGD)来更新参数。典型流程如下:

  1. 正向传播计算损失(例如 MSE)。
  2. 反向传播计算梯度(loss.backward())。
  3. 优化器更新参数(optimizer.step())。
  4. 清零梯度(optimizer.zero_grad()),防止梯度累积。

结合 MSE 和 SGD 的示例:

import torch
import torch.nn as nn
import torch.optim as optim# 模拟数据:真实参数 w_true = 2.0, b_true = 1.0
x = torch.tensor([[1.0], [2.0], [3.0]])   # 输入
y_true = 2.0 * x + 1.0                     # 真实输出# 定义模型参数(需要梯度)
w = torch.randn(1, 1, requires_grad=True)
b = torch.zeros(1, requires_grad=True)# 定义损失函数(MSE)
criterion = nn.MSELoss()# 定义优化器(SGD,学习率 0.01)
optimizer = optim.SGD([w, b], lr=0.01)# 训练几个 epoch
for epoch in range(100):# 正向传播:预测y_pred = x @ w + b   # @ 表示矩阵乘法# 计算损失(MSE)loss = criterion(y_pred, y_true)# 反向传播:计算梯度loss.backward()# 更新参数optimizer.step()# 清零梯度optimizer.zero_grad()# 打印损失if (epoch+1) % 20 == 0:print(f'Epoch {epoch+1}, Loss: {loss.item():.4f}, w: {w.item():.4f}, b: {b.item():.4f}')

输出
Pasted image 20260227210211

可以看到,通过自动求导计算梯度,并用 SGD 更新,参数 wb 逐渐接近真实值 2.0 和 1.0。

5. 注意

  • 梯度累积:默认情况下,backward() 会累积梯度,而不是替换。所以在每个 batch 更新后需要手动清零(optimizer.zero_grad()),否则梯度会累加。
  • 禁用梯度追踪:在模型推理或评估时,不需要计算梯度,可以使用 torch.no_grad() 上下文管理器,以减少内存消耗和加速计算。
    with torch.no_grad():y_pred = model(x_test)   # 这里不会构建计算图
  • 保留计算图:如果需要多次调用 backward()(例如在某些复杂的模型中),可以传递 retain_graph=Truebackward()
  • 高阶导数:通过设置 create_graph=True 可以计算二阶导数。
http://www.jsqmd.com/news/418266/

相关文章:

  • 网易云音乐信息采集可视化分析系统 | 技术栈Flask+Echarts 多模块全流程实现 毕业设计源码 deepseek 人工智能 深度学习
  • ue 日志等级
  • 泓动数据各地区官方联系方式,如何联系到泓动数据咨询GEO业务 - 资讯焦点
  • 需要学习的东西
  • pycharm 启动关闭flask 关闭test
  • IEXS盈十证券:距活动结束仅剩半月,10倍收益加成与特斯拉豪礼静待最后赢家 - 资讯焦点
  • 【AI+教育】用飞书多维表格,零门槛实现教学内容自动化
  • 2026年昆山离婚律师专业甄选推荐:从资质到案例的全方位实用指南 - 资讯焦点
  • 从SEO到GEO: 一位百度前算法工程师的十五年探索 - 资讯焦点
  • 【项目实战】VSCode 里 Git 怎么提交空文件夹?超简单教程
  • 使用 IDEA 插件 JarEditor 修改 JAR 文件,无需手动解压重打包
  • 罗小军拆解AI“黑箱”:生成式引擎挑选答案的四步机制 - 资讯焦点
  • 上海正品兔宝宝全屋定制购买指南:源头工厂选择核心攻略 - 资讯焦点
  • 上海嘉定博园路全屋定制工厂怎么选?靠谱选择指南 - 资讯焦点
  • 吉舍吉屋定制工厂:以“快、真、新”重塑长三角高端定制家居代工新标杆 - 资讯焦点
  • 带指针的结构体-链表节点-随笔
  • 2026年2月液压货梯实力品牌,自动化升降控制技术深度解析 - 品牌鉴赏师
  • 系分/架构——案例之可行性分析
  • 还在手撸提示词?向量引擎+Flux才是AI绘画的终极外挂,画师看完都沉默了...
  • 系分/架构——领域驱动设计之战略设计
  • 基于深度学习的YOLOv8木材缺陷检测系统 deepseek可定制 木材死结木材裂缝图像识别(数据集+模型+jpyqy界面)
  • 单片机基础知识 -- 普通推挽和复用推挽模式
  • 大数据领域Kafka的消息队列容量规划
  • Python基于Vue的软件产品展示销售系统 django flask pycharm
  • Dify搭建Agent
  • P1012 [NOIP 1998 提高组] 拼数题解
  • Qt 的 .ui (XML) 文件和 WPF 的 .xaml (XML) 文件
  • CompletableFuture 完全指南:定义、使用、场景与实战
  • 深度学习--卷积神经网络之迁移学习ResNet
  • MSYS 环境下 GCC 启用本地化支持