当前位置: 首页 > news >正文

基于PyTorch的VGG19图像分类——从CPU到DLP的完整实践

【智能计算系统】实验三:基于PyTorch的VGG19图像分类——从CPU到DLP的完整实践(附完整代码)

本文是智能计算系统课程实验三的完整实现,使用PyTorch框架实现基于VGG19网络的图像分类,并在CPU和DLP平台上进行推理。通过对比实验一、二,展示使用编程框架的便捷性和DLP的加速效果。

一、实验概述

本实验目的是掌握PyTorch编程框架的使用,在CPU平台上使用PyTorch实现基于VGG19网络的图像分类,并在DLP平台上完成图像分类。

实验环境:

  • 硬件:CPU、DLP
  • 软件:Torch 1.6.0、CNNL高性能算子库、CNRT运行时库、Python 3.7.4

二、VGG19网络介绍

VGG19是Visual Geometry Group在2014年提出的深度卷积神经网络,在ImageNet图像分类任务上取得了优异的成绩。

网络结构特点:

  • 使用3×3的小卷积核,通过堆叠增加网络深度
  • 使用2×2的最大池化层进行下采样
  • 包含16个卷积层和3个全连接层
  • 总参数量约1.44亿

三、核心代码实现

3.1 VGG19网络定义

使用PyTorch的nn.Sequential构建VGG19网络:

import torch
import torch.nn as nncfgs = [64,'R', 64,'R', 'M', 128,'R', 128,'R', 'M',256,'R', 256,'R', 256,'R', 256,'R', 'M',512,'R', 512,'R', 512,'R', 512,'R', 'M',512,'R', 512,'R', 512,'R', 512,'R', 'M']def vgg19():layers = ['conv1_1', 'relu1_1', 'conv1_2', 'relu1_2', 'pool1','conv2_1', 'relu2_1', 'conv2_2', 'relu2_2', 'pool2','conv3_1', 'relu3_1', 'conv3_2', 'relu3_2', 'conv3_3', 'relu3_3', 'conv3_4', 'relu3_4', 'pool3','conv4_1', 'relu4_1', 'conv4_2', 'relu4_2', 'conv4_3', 'relu4_3', 'conv4_4', 'relu4_4', 'pool4','conv5_1', 'relu5_1', 'conv5_2', 'relu5_2', 'conv5_3', 'relu5_3', 'conv5_4', 'relu5_4', 'pool5','flatten', 'fc6', 'relu6','fc7', 'relu7', 'fc8', 'softmax']layer_container = nn.Sequential()in_channels = 3num_classes = 1000conv_cfgs = [c for c in cfgs if isinstance(c, int)]cfg_idx = 0for i, layer_name in enumerate(layers):if layer_name.startswith('conv'):out_channels = conv_cfgs[cfg_idx]cfg_idx += 1layer_container.add_module(layer_name, nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1))in_channels = out_channelselif layer_name.startswith('relu'):layer_container.add_module(layer_name, nn.ReLU(inplace=True))elif layer_name.startswith('pool'):layer_container.add_module(layer_name, nn.MaxPool2d(kernel_size=2, stride=2))elif layer_name == 'flatten':layer_container.add_module(layer_name, nn.Flatten())elif layer_name == 'fc6':layer_container.add_module(layer_name, nn.Linear(25088, 4096))elif layer_name == 'fc7':layer_container.add_module(layer_name, nn.Linear(4096, 4096))elif layer_name == 'fc8':layer_container.add_module(layer_name, nn.Linear(4096, num_classes))elif layer_name == 'softmax':layer_container.add_module(layer_name, nn.Softmax(dim=1))return layer_container

3.2 生成.pth权重文件

从.mat文件加载预训练权重并保存为.pth格式:

import scipy.io
from collections import OrderedDictdef generate_pth():datas = scipy.io.loadmat(VGG_PATH)model = vgg19()new_state_dict = OrderedDict()for i, param_name in enumerate(model.state_dict()):name = param_name.split('.')if name[-1] == 'weight':new_state_dict[param_name] = torch.from_numpy(datas[str(i)]).float()else:new_state_dict[param_name] = torch.from_numpy(datas[str(i)][0]).float()model.load_state_dict(new_state_dict)torch.save(model.state_dict(), 'models/vgg19.pth')

3.3 图像预处理

使用torchvision.transforms进行图像预处理:

from PIL import Image
from torchvision import transformsdef load_image(path):image = Image.open(path).convert('RGB')transform = transforms.Compose([transforms.Resize(256),transforms.CenterCrop(224),transforms.ToTensor(),transforms.Normalize(mean=[0.485, 0.456, 0.406],std=[0.229, 0.224, 0.225])])image = transform(image)image = image.unsqueeze(0)return image

3.4 CPU平台推理

import timeif __name__ == '__main__':input_image = load_image(IMAGE_PATH)net = vgg19()net.load_state_dict(torch.load(VGG_PATH, map_location='cpu'))net.eval()st = time.time()prob = net(input_image)print("cpu infer time:{:.3f} s".format(time.time()-st))with open('./labels/imagenet_classes.txt') as f:classes = [line.strip() for line in f.readlines()]_, indices = torch.sort(prob, descending=True)print("Classification result: id = %s, prob = %f " % (classes[indices[0][0]], prob[0][indices[0][0]].item()))if classes[indices[0][0]] == 'strawberry':print('TEST RESULT PASS.')

3.5 DLP平台推理

使用torch_mlu在DLP上进行推理:

import torch_mlu
import torch_mlu.core.mlu_model as ctif __name__ == '__main__':input_image = load_image(IMAGE_PATH)net = vgg19()net.load_state_dict(torch.load(VGG_PATH, map_location='cpu'))net.eval()# 使用JIT trace优化example_forward_input = torch.rand((1,3,224,224), dtype=torch.float)net_trace = torch.jit.trace(net, example_forward_input, check_trace=False)# 移动到DLP设备input_image = input_image.to(ct.mlu_device())net_trace = net_trace.to(ct.mlu_device())st = time.time()prob = net_trace(input_image)print("mlu370<cnnl backend> infer time:{:.3f} s".format(time.time()-st))prob = prob.cpu()with open('./labels/imagenet_classes.txt') as f:classes = [line.strip() for line in f.readlines()]_, indices = torch.sort(prob, descending=True)print("Classification result: id = %s, prob = %f " % (classes[indices[0][0]], prob[0][indices[0][0]].item()))if classes[indices[0][0]] == 'strawberry':print('TEST RESULT PASS.')

四、运行结果

平台 推理时间 分类结果
CPU 约0.5-1.0秒 strawberry(概率约0.99)
DLP 约0.01-0.05秒 strawberry(概率约0.99)

性能提升:约10-50倍

五、与实验一、二的对比

对比项 实验一 实验二 实验三
代码复杂度 手动实现约100行 pycnnl约50行 PyTorch约30行
网络类型 三层全连接 三层全连接 VGG19卷积网络
参数量 约100万 约100万 约1.44亿
推理平台 CPU DLP CPU + DLP

六、评分标准

分数 要求
60分 正确生成.pth文件
80分 CPU上正确推理,得到正确分类结果
100分 DLP上正确推理,处理时间相比CPU有明显提升

七、实验总结

通过本实验,我掌握了PyTorch框架的使用:

  1. PyTorch提供了简洁的API来构建复杂的神经网络
  2. 使用nn.Sequential可以方便地堆叠各种网络层
  3. torchvision.transforms提供了丰富的图像预处理工具
  4. torch_mlu库可以方便地将模型迁移到DLP平台
  5. 相比手动实现,使用框架可以大大提高开发效率

GitHub仓库地址: https://github.com/NiMark886/smart-computing-exp3-vgg19-pytorch

Gitee仓库地址: https://gitee.com/NiMark886/smart-computing-exp3-vgg19-pytorch

http://www.jsqmd.com/news/904875/

相关文章:

  • 手把手教你学Simulink——电动汽车V2G(车网互动)双向DC-AC充电逆变器建模
  • BG3模组管理器终极指南:5步解决模组冲突,轻松管理《博德之门3》模组
  • 基于Arduino的智能植物监测系统DIY:从传感器到低功耗设计
  • 终极指南:如何用SleeperX彻底掌控Mac睡眠行为
  • P1325 雷达安装【洛谷算法习题】
  • 国内优质砖雕厂家实力排行:工艺与服务全维度对比 - 奔跑123
  • 2026年5月徐州黄金回收哪家好?10家实测+选店避坑全攻略 - 生活测评君
  • Ant Design Pro v6.0.2 发布:升级 antd、新增 AI 辅助升级能力,多项功能改进
  • 基于Arduino与FFT的音频频谱可视化:从原理到实现的完整指南
  • Zabbix监控初步搭建
  • 猫抓浏览器扩展完全指南:告别网页资源获取烦恼
  • 2026年5月泰安黄金回收哪家好?8家实测+避坑全攻略 - 生活测评君
  • 2026年5月停车场出入口设备厂家选型攻略|智慧停车采购指南 - TOP10品牌推荐榜单
  • 有什么软件可以去视频水印?四款小程序加桌面工具实测
  • 2026年国内3大主流一物一码服务商对比:中大型快消选型权威测评报告 - 纳宝科技一物一码
  • 山东省 乳山市寄件省钱天花板!2026全国靠谱快递平台实测,低价寄件不踩坑 - 时讯资讯
  • 2026广州白云区注册公司攻略|靠谱财税代办机构TOP5科普推荐 - GrowthUME
  • 【数据分析】python-pandas速查文档(3)
  • Sora 2 AI主播生成全链路拆解:从提示词工程、语音驱动到唇形同步的7大关键技术突破
  • 基于DLP平台的手写数字分类——CPU到深度学习处理器的加速实践
  • 从零打造蓝牙遥控履带车:Arduino、3D打印与FPV系统全解析
  • 2025泉州除甲醛公司Top5深度测评:绿舒环保稳居榜首 - 绿舒环保母婴除甲醛
  • 2026年最值得关注的8款AI简历工具深度解析
  • 基于Raspberry Pi Pico W的Wi-Fi邮件报警系统设计与实现
  • 踩坑!JDK8u371 报 No appropriate protocol,加启动参数无效
  • 选择题专练数据库原理精选30题
  • 如何使用Legacy iOS Kit实现旧款iOS设备降级与越狱的完整指南
  • Arduino LED乒乓球游戏:从电路设计到状态机编程的嵌入式开发实践
  • 2.隐藏账户
  • crabc - api 开源项目更名 ApiGo,一站式 API 数据服务平台更新多项功能