当前位置: 首页 > news >正文

pytorch17->一张实际图片的识别实战

import torch import torchvision from PIL import Image from torch import nn from torch.nn import Sequential, Conv2d, MaxPool2d, Flatten, Linear # ==================== 1. 网络结构(必须和训练时完全一致) ==================== class Tudui(nn.Module): def __init__(self): super(Tudui, self).__init__() self.model = Sequential( Conv2d(3, 32, 5, padding=2), MaxPool2d(2), Conv2d(32, 32, 5, padding=2), MaxPool2d(2), Conv2d(32, 64, 5, padding=2), MaxPool2d(2), Flatten(), Linear(1024, 64), Linear(64, 10) ) def forward(self, x): return self.model(x) # ==================== 2. 加载模型(直接用完整模型) ==================== # 用你保存的完整模型文件,选一个(比如 tudui_9.pth 是训练10轮后的) model = torch.load("tudui_9.pth", map_location=torch.device('cpu')) model.eval() print("✅ 模型加载成功!") # ==================== 3. 加载图片 ==================== image_path = "img/dog.jpg" # 改成你的图片路径 image = Image.open(image_path) print(f"✅ 图片加载成功,原始尺寸: {image.size}") # ==================== 4. 预处理图片 ==================== transform = torchvision.transforms.Compose([ torchvision.transforms.Resize((32, 32)), torchvision.transforms.ToTensor(), ]) image_tensor = transform(image) print(f"✅ 预处理后尺寸: {image_tensor.shape}") # 加 batch 维度 image_tensor = torch.reshape(image_tensor, (1, 3, 32, 32)) print(f"✅ 添加 batch 维度后: {image_tensor.shape}") # ==================== 5. 推理 ==================== with torch.no_grad(): output = model(image_tensor) predict = output.argmax(1).item() # ==================== 6. 输出结果 ==================== classes = ['airplane', 'automobile', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck'] print(f"\n{'='*40}") print(f"🎯 模型预测结果: {classes[predict]}") print(f"{'='*40}") print("\n📊 各类别得分详情:") for i, score in enumerate(output[0]): print(f" {classes[i]}: {score:.4f}")

1.选用一个训练过10轮的网络用cpu进行测试

model = torch.load("tudui_9.pth", map_location=torch.device('cpu')) model.eval()

2.把图片放入指定路径,打开

image_path = "img/dog.jpg" # 改成你的图片路径 image = Image.open(image_path)

3.修改图片的像素为32*32,必须修改,因为你的卷积层第一步就是Conv2d(3, 32, 5, padding=2),他只能卷32*32的,别的图片他会报错。然后再张量化,为后续操作做准备

transform = torchvision.transforms.Compose([ torchvision.transforms.Resize((32, 32)), torchvision.transforms.ToTensor(), ])

4.PyTorch 官方在设计和实现Conv2dBatchNorm2dLinear等层时,就规定了输入必须是 4 维张量,(batch, channels, height, width),而单张图片默认是 3 维的(channels, height, width)所以必须reshape

image_tensor = torch.reshape(image_tensor, (1, 3, 32, 32)) print(f"✅ 添加 batch 维度后: {image_tensor.shape}")

5.

with torch.no_grad(): output = model(image_tensor) predict = output.argmax(1).item()

with torch.no_grad():不计算梯度,只计算不更新模型

output = model(image_tensor),output是什么?

output │ ├── 类型:torch.Tensor(PyTorch 张量) │ ├── 形状:torch.Size([1, 10]) │ │ │ ├── 第0维大小=1(1张图) │ └── 第1维大小=10(10个类别) │ ├── 数据类型:torch.float32(32位浮点数) │ └── 存储的内容:10个浮点数

output是一个张量,存了二维的数,第0维是多少张图片,第1维是10个类型的得分。

output.argmax(1)拿到第一维的10个类型中得分最高的位置,output.argmax(1).item()给他从张量还原回数字,从而得到序号

为什么output是二维的?

因为模型会始终保持 batch 维度。输入是[batch, 3, 32, 32],输出也是[batch, 10]。batch 是第一维,类别得分是第二维。

6.通过predict序号找出类别,输出各类别和得分

classes = ['airplane', 'automobile', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck'] print(f"\n{'='*40}") print(f"🎯 模型预测结果: {classes[predict]}")
print("\n📊 各类别得分详情:") for i, score in enumerate(output[0]): print(f" {classes[i]}: {score:.4f}")
http://www.jsqmd.com/news/1084170/

相关文章:

  • 为什么AI只引用2-7个网站?内容结构优化才是GEO的隐藏密码!
  • volatile 这个坑,很多 STM32 新手都踩过
  • 03_Agent智能体与LangGraph
  • 出版商联盟指控 OpenAI 与微软:未经授权用作品训练 AI,版权诉讼再升级!
  • DESIGN.md:为编码代理提供设计系统持久结构化理解,支持多格式转换
  • AI 辅助智能合约安全审计:从静态分析到 LLM 漏洞检测的工程实战
  • 抖音音频下载终极指南:5分钟学会免费提取抖音热门背景音乐
  • 如何校准LED显示屏色彩均匀性以消除视觉马赛克
  • 华强北内存降价,资本市场却疯涨!内存缺货真相究竟几何?
  • Navicat Mac版无限试用期终极指南:3种方法实现永久免费使用
  • 【2026】超详细GraphPad Prism 10安装保姆级教程,永久免费使用,科研绘图和数据分析指南,看完这一篇就够了
  • 3分钟轻松搞定!为Royal TSX添加完美中文汉化包,告别英文界面困扰
  • 高通近 40 亿美元收购 Modular,拓展业务进军 AI 与数据中心市场
  • 科技企业如何通过智能化工具快速识别行业技术趋势并优化研发方向?
  • AWVS实战:构建自动化扫描与手动验证的Web漏洞评估闭环
  • +1毛也是首选!申通这家五星网点的底气
  • JMeter性能测试从入门到实战:核心组件、脚本编写与结果分析
  • Anuttacon研究模拟多智能体社会系统Agentopia:让AI更有人味儿,但仍面临挑战
  • Kill-Doc:浏览器脚本实现一站式文档下载解决方案
  • 工信局如何利用数智工具判断技术改造项目的可行性?
  • StarRailAssistant:解放双手的崩坏星穹铁道智能助手完全指南
  • ComfyUI ControlNet Aux完全指南:解锁40+图像预处理节点的终极AI绘画控制方案
  • JMeter压测实战:秒杀场景下401与200异常问题的深度排查与优化
  • 如何彻底解决游戏按键冲突:Hitboxer智能按键重映射完全指南
  • Deep3D深度解析:实时端到端2D转3D视频转换技术架构与实现原理
  • 云南旅游产品设计拆解:一条8天线路背后的逻辑
  • 从圈量子引力与分形几何到凯瑟琳轮:一个跨学科计算模型的构建
  • 专业防火墙管理方案:Destiny 2 Solo Enabler技术深度解析
  • SSL证书验证失败全解析:从诊断到修复的实战指南
  • 音频格式解码之opus