当前位置: 首页 > news > 正文

使用PyTorch实现猫狗分类Python源码及准确度对比(CNN、VGG16迁移学习两种方式)

数据集下载:Dogs vs. Cats | Kaggle

不同方法准确度

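| 方法 | 轮数 | 准确度 |
| --- | --- | --- |
| cnn | 5 | 67.64% |
| cnn | 10 | 74.92% |
| cnn | 15 | 73.42% |
| cnn | 20 | 79.28% |
| cnn | 25 | 78.28% |
| vgg16 | 5 | 86.5% |
| vgg16 | 10 | 86.98% |
| vgg16 | 15 | 85.42% |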

cnn.py

"""Train a small CNN from scratch on Kaggle "Dogs vs. Cats" and report test accuracy.

The Kaggle ``train`` directory holds files named ``dog.<n>.jpg`` / ``cat.<n>.jpg``.
It is split into disjoint train/test subsets (80 % / 20 % by default) using a
seeded permutation, so the split is reproducible across runs and machines.
"""
import os

import torch
import torch.nn as nn
import torch.optim as optim
from PIL import Image
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms

num_epochs = 20
batch_size = 50
learning_rate = 0.001
log_interval = 50  # print the training loss every this many steps
# os.path.join instead of a hard-coded Windows path, so the script runs on any OS.
data_root = os.path.join("data", "Dogs Vs Cats", "train")

IMAGE_EXTENSIONS = (".jpg", ".jpeg", ".png")
NORMALIZE = transforms.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225))


class DogsVsCatsDataset(Dataset):
    """Dogs vs. Cats images from a single directory; label 0 = dog, 1 = cat.

    Args:
        root: directory containing the image files.
        train: select the training subset (True) or the held-out test subset (False).
        transform: optional callable applied to each RGB PIL image.
        train_ratio: fraction of the files assigned to the training subset.
        seed: seed of the split permutation. Train and test instances built with
            the same ``root``/``train_ratio``/``seed`` are disjoint and together
            cover every image file.
    """

    def __init__(self, root, train=True, transform=None, train_ratio=0.8, seed=0):
        super().__init__()
        self.root = root
        self.transform = transform
        self.classes = ["dog", "cat"]
        # Sort so the split does not depend on os.listdir's arbitrary order, and
        # skip non-image files (e.g. Windows' Thumbs.db) that would get a bogus label.
        names = sorted(
            name for name in os.listdir(root) if name.lower().endswith(IMAGE_EXTENSIONS)
        )
        generator = torch.Generator().manual_seed(seed)
        order = torch.randperm(len(names), generator=generator).tolist()
        n_train = int(len(names) * train_ratio)
        chosen = order[:n_train] if train else order[n_train:]
        self.files = [names[i] for i in chosen]
        self.labels = [0 if "dog" in name else 1 for name in self.files]

    def __len__(self):
        return len(self.files)

    def __getitem__(self, index):
        path = os.path.join(self.root, self.files[index])
        # The context manager releases the file handle once the pixels are decoded.
        with Image.open(path) as img:
            image = img.convert("RGB")
        if self.transform:
            image = self.transform(image)
        return image, self.labels[index]


# Random augmentation is applied to training images only.
train_transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.RandomHorizontalFlip(),
    transforms.RandomVerticalFlip(),
    transforms.RandomRotation(degrees=30),
    transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1),
    transforms.ToTensor(),
    NORMALIZE,
])
# Evaluation must be deterministic, so test images are only resized and normalised.
eval_transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    NORMALIZE,
])


def _conv_block(in_channels, out_channels):
    """Return Conv3x3 -> BatchNorm -> ReLU -> 2x2 max-pool (halves height and width)."""
    return nn.Sequential(
        nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1),
        nn.BatchNorm2d(out_channels),
        nn.ReLU(inplace=True),
        nn.MaxPool2d(kernel_size=2, stride=2),
    )


class CNNModel(nn.Module):
    """Four conv blocks followed by a two-layer classifier; expects 3x224x224 input."""

    def __init__(self):
        super().__init__()
        self.cnn1 = _conv_block(3, 24)
        self.cnn2 = _conv_block(24, 48)
        self.cnn3 = _conv_block(48, 96)
        self.cnn4 = _conv_block(96, 48)
        self.dropout = nn.Dropout()
        # Four 2x2 pools shrink 224x224 to 14x14; the last block has 48 channels.
        self.line1 = nn.Linear(14 * 14 * 48, 512)
        self.line2 = nn.Linear(512, 2)

    def forward(self, x):
        out = self.cnn4(self.cnn3(self.cnn2(self.cnn1(x))))
        out = out.reshape(out.size(0), -1)
        out = self.dropout(out)
        # Without a non-linearity here, line1 and line2 would collapse into one affine map.
        out = torch.relu(self.line1(out))
        return self.line2(out)


def evaluate(model, loader, device):
    """Return the accuracy of ``model`` on ``loader`` as a percentage (0.0 if empty)."""
    model.eval()
    total = 0
    correct = 0
    with torch.no_grad():
        for image, label in loader:
            image = image.to(device)
            label = label.to(device)
            predict = model(image).argmax(dim=1)
            total += label.size(0)
            correct += (predict == label).sum().item()
    return correct / total * 100 if total else 0.0


def main():
    """Build the data pipeline, train CNNModel, and print its test accuracy."""
    train_dataset = DogsVsCatsDataset(root=data_root, train=True, transform=train_transform)
    test_dataset = DogsVsCatsDataset(root=data_root, train=False, transform=eval_transform)
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = CNNModel().to(device)
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=learning_rate)

    model.train()
    for epoch in range(num_epochs):
        for i, (image, label) in enumerate(train_loader):
            image = image.to(device)
            label = label.to(device)
            loss = criterion(model(image), label)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            if (i + 1) % log_interval == 0:
                print("Epoch [{}/{}], Step [{}/{}], Loss: {:.4f}".format(
                    epoch + 1, num_epochs, i + 1, len(train_loader), loss.item()))

    accuracy = evaluate(model, test_loader, device)
    print("Accuracy of test {} images: {} %".format(len(test_dataset), accuracy))


# The guard keeps imports side-effect free and is required if DataLoader workers
# are ever enabled on Windows (spawned workers re-import this module).
if __name__ == "__main__":
    main()

vgg16.py

"""Fine-tune an ImageNet-pretrained VGG16 on Kaggle "Dogs vs. Cats" and report test accuracy.

The convolutional feature extractor is frozen. Only the classifier is trained,
and its last layer is replaced with a fresh 2-way Linear layer. The Kaggle
``train`` directory is split into disjoint train/test subsets (80 % / 20 % by
default) using a seeded permutation.
"""
import os

import torch
import torch.nn as nn
import torch.optim as optim
from PIL import Image
from torch.utils.data import DataLoader, Dataset
from torchvision import models, transforms

num_epochs = 5
batch_size = 10
learning_rate = 0.001
log_interval = 100  # print the training loss every this many steps
# os.path.join instead of a hard-coded Windows path, so the script runs on any OS.
data_root = os.path.join("data", "Dogs Vs Cats", "train")

IMAGE_EXTENSIONS = (".jpg", ".jpeg", ".png")
NORMALIZE = transforms.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225))


class DogsVsCatsDataset(Dataset):
    """Dogs vs. Cats images from a single directory; label 0 = dog, 1 = cat.

    Args:
        root: directory containing the image files.
        train: select the training subset (True) or the held-out test subset (False).
        transform: optional callable applied to each RGB PIL image.
        train_ratio: fraction of the files assigned to the training subset.
        seed: seed of the split permutation. Train and test instances built with
            the same ``root``/``train_ratio``/``seed`` are disjoint and together
            cover every image file.
    """

    def __init__(self, root, train=True, transform=None, train_ratio=0.8, seed=0):
        super().__init__()
        self.root = root
        self.transform = transform
        self.classes = ["dog", "cat"]
        # Sort so the split does not depend on os.listdir's arbitrary order, and
        # skip non-image files (e.g. Windows' Thumbs.db) that would get a bogus label.
        names = sorted(
            name for name in os.listdir(root) if name.lower().endswith(IMAGE_EXTENSIONS)
        )
        generator = torch.Generator().manual_seed(seed)
        order = torch.randperm(len(names), generator=generator).tolist()
        n_train = int(len(names) * train_ratio)
        chosen = order[:n_train] if train else order[n_train:]
        self.files = [names[i] for i in chosen]
        self.labels = [0 if "dog" in name else 1 for name in self.files]

    def __len__(self):
        return len(self.files)

    def __getitem__(self, index):
        path = os.path.join(self.root, self.files[index])
        # The context manager releases the file handle once the pixels are decoded.
        with Image.open(path) as img:
            image = img.convert("RGB")
        if self.transform:
            image = self.transform(image)
        return image, self.labels[index]


# Random augmentation is applied to training images only.
train_transform = transforms.Compose([
    transforms.RandomResizedCrop(size=224),
    transforms.RandomHorizontalFlip(),
    transforms.RandomVerticalFlip(),
    transforms.RandomRotation(degrees=30),
    transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1),
    transforms.ToTensor(),
    NORMALIZE,
])
# Deterministic evaluation pipeline: the usual ImageNet resize + centre crop.
eval_transform = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    NORMALIZE,
])


def build_model(num_classes=2):
    """Return a pretrained VGG16 with frozen features and a new ``num_classes``-way output layer."""
    model = models.vgg16(weights=models.VGG16_Weights.IMAGENET1K_V1)
    for param in model.features.parameters():
        param.requires_grad = False
    # Setting ``classifier[6].out_features`` only changes an attribute; the layer's
    # weight would stay 1000x4096 and keep emitting 1000 logits. Replace the layer.
    in_features = model.classifier[6].in_features
    model.classifier[6] = nn.Linear(in_features, num_classes)
    return model


def evaluate(model, loader, device):
    """Return the accuracy of ``model`` on ``loader`` as a percentage (0.0 if empty)."""
    model.eval()
    total = 0
    correct = 0
    with torch.no_grad():
        for image, label in loader:
            image = image.to(device)
            label = label.to(device)
            predict = model(image).argmax(dim=1)
            total += label.size(0)
            correct += (predict == label).sum().item()
    return correct / total * 100 if total else 0.0


def main():
    """Build the data pipeline, fine-tune VGG16, and print its test accuracy."""
    train_dataset = DogsVsCatsDataset(root=data_root, train=True, transform=train_transform)
    test_dataset = DogsVsCatsDataset(root=data_root, train=False, transform=eval_transform)
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = build_model(num_classes=2).to(device)
    criterion = nn.CrossEntropyLoss()
    # Only the unfrozen (classifier) parameters are optimised.
    trainable = [p for p in model.parameters() if p.requires_grad]
    optimizer = optim.Adam(trainable, lr=learning_rate)

    model.train()
    for epoch in range(num_epochs):
        for i, (image, label) in enumerate(train_loader):
            image = image.to(device)
            label = label.to(device)
            loss = criterion(model(image), label)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            if (i + 1) % log_interval == 0:
                print("Epoch [{}/{}], Step [{}/{}], Loss: {:.4f}".format(
                    epoch + 1, num_epochs, i + 1, len(train_loader), loss.item()))

    accuracy = evaluate(model, test_loader, device)
    print("Accuracy of test {} images: {} %".format(len(test_dataset), accuracy))


# The guard keeps imports side-effect free and is required if DataLoader workers
# are ever enabled on Windows (spawned workers re-import this module).
if __name__ == "__main__":
    main()

参考文档:PyTorch猫狗大战:CNN vs VGG16迁移学习,谁更胜一筹?- 超腾开源

http://www.jsqmd.com/news/911205/

相关文章:

  • 基于Arduino与AI的Furby智能改造:硬件拆解与Python集成实践
  • 在 VS2022 中创建 Presenter 文件对
  • 3种Janus-7B应用场景:从客服助手到个性化内容创作
  • Simple Live终极指南:一站式跨平台直播聚合解决方案,5分钟搭建专属直播中心
  • GA/T 1400协议实战:用Java和RestTemplate搞定通知消息推送(附完整代码)
  • 医药冷链运输的温湿度监控能做到无人值守吗?企业级Agent如何重塑效率
  • 深入解析LibreHardwareMonitor:开源硬件监控解决方案的核心架构与实践应用
  • 免费写标书软件推荐:一个功能永久免费的标书AI,值不值得试? - 陈工0237
  • 2026上海黄金回收TOP5靠谱商家(实测推荐)上海捷当黄金领跑黄金回收靠谱榜单 - 资讯快报
  • BERT-large-uncased训练数据揭秘:BookCorpus+Wikipedia的11亿词元预训练
  • 2026年前端开发完全指南:AI辅助写组件、调Bug、生成接口代码,效率翻倍
  • 开源矢量网络分析仪LibreVNA:从6GHz射频测量到专业级信号分析的完整指南
  • Gemma-4-31B-it-assistant:Google开源多模态AI助手完全指南
  • 企业矩阵系统建设实践:从账号管理到AI内容协同
  • 2026徐州黄金回收甄选TOP4:仅这几家满足零投诉无隐形扣费 - 生活测评君
  • 微信聊天记录永久保存终极指南:如何让每一段对话都成为永恒记忆
  • 深度解析:洛雪音乐音源架构的技术实现与性能优化
  • 基于Raspberry Pi Pico与MicroPython的嵌入式记忆游戏开发实战
  • 2026年沈阳地坪市场扫描:水性聚氨酯砂浆厂家多维实力梳理 - 兔兔不是荼荼
  • 从BIOS到ACPI:聊聊操作系统电源管理这二十年的‘幕后英雄’
  • h2o-danube-1.8b-sft 对比分析:与同类18亿参数模型的性能评测
  • 泰国DAB法规 学习英语~
  • NPU加速实战:Llama3-ChatQA-1.5-8B在国产硬件上的部署与性能优化指南
  • 2026年前端框架选型指南:React、Vue、Angular怎么选?AI辅助开发全流程演示
  • 2026年6月租房不收中介费指南,房东直租app省心租房攻略 - 资讯速览
  • 从Modbus到XMODEM:一文搞懂CRC-16不同变体的区别与C语言实战
  • 跨平台资源下载神器:3分钟快速掌握res-downloader完整教程
  • 2026苏州闲置黄金处置科普 | 选对门店避开回收各类套路 - 奢侈品回收测评
  • 原神FPS解锁器终极指南:三步实现高帧率游戏体验
  • 平台认证 + 实绩核验 拼多多代运营优质服务商推荐 - 品牌榜中榜