当前位置: 首页 > news >正文

深度学习第三章,线性表示

#模型就是函数 #李哥样例:(我的分部解析) import torch import matplotlib.pyplot as plt # 画图的包 import random #随机 def create_data(w, b, data_num): #生成数据 x = torch.normal(0, 1, (data_num, len(w))) # torch.normal 函数用于正态分布 y = torch.matmul(x, w) + b #matmul表示矩阵相乘 noise = torch.normal(0, 0.01, y.shape) #噪声要加到y上 y += noise #设置噪声 return x, y num = 500 true_w = torch.tensor([8.1,2,2,4]) #创建张量 true_b = torch.tensor(1.1) #创建张量的好处:PyTorch中的张量支持多种操作,包括转置、索引、切片、数学运算、线性代数和随机数生成等。这些操作使得张量成为处理和变换数据的强大工具 X, Y = create_data(true_w, true_b, num) #X通过正态分布,Y通过线性组合 plt.scatter(X[:, 3], Y, 1)#绘图 plt.show() def data_provider(data, label, batchsize): #每次访问这个函数, 就能提供一批数据 length = len(label) indices = list(range(length))#创建一张从 0到 length - 1的连续整数列表 #我不能按顺序取 把数据打乱 random.shuffle(indices)#打乱顺序 for each in range(0, length, batchsize): get_indices = indices[each: each+batchsize] get_data = data[get_indices] get_label = label[get_indices] yield get_data,get_label #有存档点的return batchsize = 16 # for batch_x, batch_y in data_provider(X, Y, batchsize): # print(batch_x, batch_y) # break def fun(x, w, b): #实现简单的线性变换,基于输入的张量x\权重w和偏置b计算预测值 pred_y = torch.matmul(x, w) + b return pred_y def maeLoss(pre_y, y): #损失函数 return torch.sum(abs(pre_y-y))/len(y) def sgd(paras, lr): #随机梯度下降,更新参数 with torch.no_grad(): #属于这句代码的部分,不计算梯度 for para in paras: para -= para.grad * lr #不能写成 para = para - para.grad*lr para.grad.zero_() #使用过的梯度,归0 lr = 0.03 w_0 = torch.normal(0, 0.01, true_w.shape, requires_grad=True) #这个w需要计算梯度(参数1:均值,参数2:标准差,参数3:张量形状,参数4:是否启动微分机制) b_0 = torch.tensor(0.01, requires_grad=True) print(w_0, b_0) epochs = 50 for epoch in range(epochs): #机器学习训练循环 epochs是训练的总轮数 data_loss = 0##初始化一个变量 data_loss,用于累积当前 epoch 中所 有批次(batch)的损失值。目的是计算整个 epoch 的平均或总损失,用于评估训练效果。 for batch_x, batch_y in data_provider(X, Y, batchsize): #内层循环,从 data_provider 函数获取训练数据的批次(mini-batch)。data_provider 是一个自定义的数据迭代器或生成器,每次返回一个批次的输入数据 batch_x 和对应标签 batch_y。X, Y 是完整的训练数据和标签。 batchsize 是批次大小,比如 32 或 64,表示每次训练多少条数据。 pred_y = fun(batch_x,w_0, b_0)#调用模型函数 fun进行前向计算 loss = maeLoss(pred_y, batch_y)#计算预测值 pred_y 和真实标签 batch_y 之间的损失。 loss.backward()#执行反向传播,自动计算损失函数相对于模型参数的梯度 sgd([w_0, b_0], lr)#调用自定义的随机梯度下降(SGD)优化函数,利用计算好的梯度更新模型参数 data_loss += loss print("epoch %03d: loss: %.6f"%(epoch, data_loss))#格式化 print("真实的函数值是", true_w, true_b) print("训练得到的参数值是", w_0, b_0) idx = 3 plt.plot(X[:, idx].detach().numpy(), X[:, idx].detach().numpy()*w_0[idx].detach().numpy()+b_0.detach().numpy()) plt.scatter(X[:, idx], Y, 1) plt.show() #一定注意维度一般维度不错就不会错

同学们尽量理解该部分每一行代码,过段时间进行知识回顾

训练后的结果

http://www.jsqmd.com/news/604687/

相关文章:

  • SpringBoot 三大参数注解详解:@RequestParam @RequestBody @PathVariable 区别及常用开发注解
  • 【C++ 引用全解析】左值 / 右值、左右值引用、万能引用及其底层原理:引用折叠
  • 如何在Windows上轻松安装安卓应用?APK-Installer完整指南
  • 关于Tsak Traker
  • 5大核心价值解析:Jsxer如何破解Adobe ExtendScript二进制黑盒
  • 2026自贡特殊儿童康复:自贡多动症儿童康复/自贡孤独症康复培训机构/自贡孤独症康复寄宿学校/选择指南 - 优质品牌商家
  • 免费且好用的精益工具在哪里?2026年精益工具清单整理
  • S2-Pro模型提示词(Prompt)工程高级教程:从基础到实战技巧
  • 终极Windows系统优化工具Dism++:从新手到专家的完整使用指南
  • 应急响应-vulntarget-n-勒索病毒应急靶场
  • Vue3中如何实现动态页面的SEO优化
  • 关于springboot的面试题
  • 23岁+计算机人注意!困在传统开发?这份大模型报告助你职场逆袭,薪资翻倍!
  • 华硕笔记本色彩修复终极指南:3步恢复完美显示效果
  • 文化墙13种常见工艺材质全解析|一篇讲透!建议收藏!
  • LangGraph 实战:搭建一个智能研发多Agent协作系统(含代码)
  • 嵌入式开发:在Clion中构建面向对象的STM32 C++编程框架
  • IDM 下载管理器 下载安装
  • sqlmap基本操作流程介绍
  • Realistic Vision V5.1虚拟摄影棚效果:烟雾/蒸汽/粉尘等大气介质物理模拟
  • 快速生成jdk配置交互教程:用快马平台制作可视化环境搭建原型
  • python telebot
  • Cobbler v3.3.7 配置 Ubuntu 24.04 无人值守安装,我踩过的那些坑(附完整脚本)
  • Koikatu HF Patch终极指南:3分钟解锁200+模组完整游戏体验
  • 领英大规模账户攻击事件技术溯源与反钓鱼防御体系研究
  • 嵌入式工程师必看:用STM32的PWM驱动Buck电路给MCU供电的5个坑
  • Redisson进阶:Lua脚本与API在分布式锁与限流中的深度整合
  • 如何从 Polygon 到 QOJ 无缝衔接
  • AI智能体刚火就“撞墙”?揭秘大厂落地最怕的巨坑,别掉进去了
  • 在Ubuntu里同时安装mozc和sogoupinyin输入法的后续故事