当前位置: 首页 > news >正文

线性回归代码

import torch import matplotlib.pyplot as plt #画图的 import random #产生随机数 def create_data(w, b, data_num): #生成数据 x = torch.normal(0, 1, (data_num, len(w))) #生成随机数x y = torch.matmul(x, w)+b #matmul表示矩阵相乘 noise = torch.normal(0, 0.01, y.shape) #噪声加在y上 y += noise return x, y num = 500 true_w = torch.tensor([8.1,2,2,4]) #必须用张量才能相乘 true_b = torch.tensor(1.1) X, Y = create_data(true_w, true_b, num) plt.scatter(X[:, 3], Y, 1) plt.show() def data_provider(data, label, batchsize): #每次访问这个函数,就能提供一批数据 length = len(label) indices = list(range(length)) random.shuffle(indices) #不能按顺序取数据,把数据打乱 for each in range(0,length, batchsize): get_indices = indices[each:each+batchsize] get_data = data[get_indices] get_label = label[get_indices] yield get_data, get_label #存档点的return 分批次返回数据,每次从这里取一个batch_size的数据,下次从这里继续取 batchsize = 16 # for batch_x, batch_y in data_provider(X, Y, batchsize): # print(batch_x, batch_y) # break def fun(x, w, b): pred_y = torch.matmul(x, w)+b return pred_y def maeLoss(pre_y, y): #计算loss:预测值和真实值 return torch.sum(abs(pre_y-y))/len(y) def sgd(paras, lr): #随机梯度下降,更新参数,梯度保存在参数中 with torch.no_grad(): #属于这句代码的部分,不计入梯度 for para in paras: para-=para.grad* lr para.grad.zero_() #使用过的梯度,归0 lr = 0.03 w_0 = torch.normal(0, 0.01, true_w.shape, requires_grad=True) #这个w需要计算梯度 b_0 = torch.tensor(0.01, requires_grad=True) print(w_0, b_0) epochs = 50 for epoch in range(epochs): data_loss = 0 for batch_x, batch_y in data_provider(X, Y, batchsize): #取一批数据 pred_y=fun(batch_x, w_0, b_0) loss = maeLoss(pred_y,batch_y) loss.backward() #梯度回传 sgd([w_0, b_0], lr) #更新参数 data_loss += loss #统计一轮所有loss,loss越小越好 print("epoch %03d: loss: %.6f"%(epoch, data_loss)) print("真实的函数值是" , true_w, true_b) print("训练得到的参数值是", w_0, b_0) idx = 2 #画图,可视化 plt.plot(X[:, idx].detach().numpy(), X[:, idx].detach().numpy()*w_0[idx].detach().numpy()+b_0.detach().numpy()) #只能取某一列作为自变量 plt.scatter(X[:, idx], Y, 1) plt.show()
http://www.jsqmd.com/news/492923/

相关文章:

  • 告别数据孤岛:基于WebDAV的Zotero与InfiniCLOUD跨平台同步实战
  • Linux操作系统(一)
  • 免费降AI率工具横评:嘎嘎vs比话vs率零谁更值
  • 深入分代 GC:新生代内存不足时的对象晋升规则
  • 用XGO Rider教孩子学编程:从Scratch到Python的AI机器人实战教程
  • Linux apt commands All In One
  • 游戏原画师福音:Kook Zimage真实幻想Turbo保姆级入门教程
  • 《道德经》第三章
  • 草莓成熟度目标检测数据集(2000张图片已标注)| YOLO训练数据集 AI视觉检测
  • 【已解决】xFormers安装报错:CPATH环境变量缺失导致cuda_runtime.h找不到
  • 【YOLOv11工业级实战】32. 超轻量分割模型实战:YOLOv11-seg剪枝+蒸馏压缩至2MB(精度仅降2%)
  • 解锁Edge内置Copilot:无需插件,一键直达GPT-4 Turbo智能助手
  • Z-Image Turbo性能评测:不同硬件下的生成速度对比
  • ESP32智慧时钟:嵌入式物联网教学硬件平台设计
  • 如何在MacBook上安装部署原生openclaw
  • 嘉立创EDA专业版多账号管理技巧:如何避免激活文件冲突
  • 一篇文章掌握PyQt5高级表格开发:从零复现工业级加工步骤设置界面
  • wan2.1-vae惊艳效果展示:人物写实度对比——发丝/皮肤纹理/瞳孔反光细节放大
  • Fish Speech 1.5镜像交付物清单:含启动脚本、日志、配置、证书模板
  • PP-DocLayoutV3内网穿透部署方案
  • 【Dify私有化部署黄金标准】:工信部等保三级/ISO 27001双认证配置模板(含OpenTelemetry全链路追踪脚本)
  • DeOldify图像上色服务效果深度评测:多场景色彩还原对比
  • Llama-3.2V-11B-cot 安全与合规:模型输入输出过滤与内容审核策略
  • Android Studio 安装教程(小白零基础,2026最新版,全程避坑)
  • 实测封神!6款小学语文学习APP,解放家长还提分 - 品牌测评鉴赏家
  • OSPF基础配置实验
  • 跨浏览器必备:高效IP定位查询扩展推荐(Edge/Chrome/Firefox全支持)
  • 实测4类小学语文素养线上课|告别盲目报课,1-6年级素养提升不踩坑 - 品牌测评鉴赏家
  • OpenWrt在树莓派Zero2W上的实战:如何用USB网卡替代板载WiFi
  • AI原生应用上下文理解:为智能交互添砖加瓦