使用PyTorch实现猫狗分类Python源码及准确度对比(CNN、VGG16迁移学习两种方式)
数据集下载:Dogs vs. Cats | Kaggle
不同方法的测试准确度(测试集为从Kaggle训练集25000张图片中随机留出的5000张)
| 方法 | 轮数 | 准确度 |
|---|---|---|
| cnn | 5 | 67.64% |
| cnn | 10 | 74.92% |
| cnn | 15 | 73.42% |
| cnn | 20 | 79.28% |
| cnn | 25 | 78.28% |
| vgg16 | 5 | 86.5% |
| vgg16 | 10 | 86.98% |
| vgg16 | 15 | 85.42% |
cnn.py
"""Dogs vs. Cats classification with a small CNN trained from scratch (PyTorch).

Reads the Kaggle "Dogs vs. Cats" training folder (``cat.N.jpg`` / ``dog.N.jpg``),
holds out a fixed 20% of it as the test set, trains for ``NUM_EPOCHS`` epochs and
prints the test accuracy.
"""
import os

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset

NUM_EPOCHS = 20
BATCH_SIZE = 50
LEARNING_RATE = 0.001
LOG_INTERVAL = 50  # print the training loss every LOG_INTERVAL steps
TRAIN_RATIO = 0.8  # 20000 of the 25000 Kaggle training images
SPLIT_SEED = 0  # shared by train/test datasets -> complementary, reproducible split
DATA_ROOT = os.path.join(".", "data", "Dogs Vs Cats", "train")
IMAGE_EXTENSIONS = (".jpg", ".jpeg", ".png")
NORMALIZE_MEAN = (0.485, 0.456, 0.406)  # ImageNet channel statistics
NORMALIZE_STD = (0.229, 0.224, 0.225)


class DogsVsCatsDataset(Dataset):
    """One side of a deterministic train/test split of a flat Dogs vs. Cats folder.

    Labels are 0 for dog and 1 for cat (matching ``classes``).

    Args:
        root: directory containing the image files.
        train: take the training part of the split if True, else the test part.
        transform: optional callable applied to each loaded PIL image.
        train_ratio: fraction of the images that goes to the training part.
        seed: shuffle seed; a train and a test dataset built from the same folder
            with the same seed are disjoint and together cover every image.
    """

    def __init__(self, root, train=True, transform=None,
                 train_ratio=TRAIN_RATIO, seed=SPLIT_SEED):
        super().__init__()
        self.root = root
        self.transform = transform
        self.classes = ["dog", "cat"]
        # Sorted because os.listdir order is arbitrary and both splits must index
        # the same list; non-image files (e.g. Thumbs.db) are skipped.
        names = sorted(
            name for name in os.listdir(root)
            if name.lower().endswith(IMAGE_EXTENSIONS)
        )
        generator = torch.Generator().manual_seed(seed)
        order = torch.randperm(len(names), generator=generator).tolist()
        cut = int(len(names) * train_ratio)
        chosen = order[:cut] if train else order[cut:]
        self.files = [names[i] for i in chosen]
        self.labels = [self._label_for(name) for name in self.files]

    @staticmethod
    def _label_for(name):
        """Return 0 for Kaggle dog files (``dog.N.jpg``) and 1 otherwise (cats)."""
        return 0 if name.startswith("dog") else 1

    def __len__(self):
        return len(self.files)

    def __getitem__(self, index):
        # Imported lazily so the split logic and the model work without Pillow.
        from PIL import Image

        path = os.path.join(self.root, self.files[index])
        with Image.open(path) as img:  # close the file handle after decoding
            image = img.convert("RGB")
        if self.transform is not None:
            image = self.transform(image)
        return image, self.labels[index]


def build_transform(train):
    """Return the preprocessing pipeline; random augmentation only when ``train``.

    Evaluation must be deterministic, so the test set gets only resize + normalize.
    """
    from torchvision import transforms

    steps = [transforms.Resize((224, 224))]
    if train:
        steps += [
            transforms.RandomHorizontalFlip(),
            transforms.RandomVerticalFlip(),
            transforms.RandomRotation(degrees=30),
            transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1),
        ]
    steps += [
        transforms.ToTensor(),
        transforms.Normalize(mean=NORMALIZE_MEAN, std=NORMALIZE_STD),
    ]
    return transforms.Compose(steps)


class CNNModel(nn.Module):
    """Four conv blocks followed by a two-layer classifier producing 2 logits.

    Expects 224x224 RGB input: four 2x2 max-pools reduce it to 48 x 14 x 14.
    """

    def __init__(self):
        super().__init__()
        self.cnn1 = self._conv_block(3, 24)
        self.cnn2 = self._conv_block(24, 48)
        self.cnn3 = self._conv_block(48, 96)
        self.cnn4 = self._conv_block(96, 48)
        self.dropout = nn.Dropout()
        self.line1 = nn.Linear(14 * 14 * 48, 512)
        # Without a non-linearity line1 and line2 collapse into a single linear map.
        self.relu = nn.ReLU()
        self.line2 = nn.Linear(512, 2)

    @staticmethod
    def _conv_block(in_channels, out_channels):
        """Conv3x3 -> BatchNorm -> ReLU -> MaxPool2x2 (halves height and width)."""
        return nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=2, stride=2),
        )

    def forward(self, x):
        out = self.cnn4(self.cnn3(self.cnn2(self.cnn1(x))))
        out = torch.flatten(out, 1)
        out = self.dropout(out)
        out = self.relu(self.line1(out))
        return self.line2(out)


def train(model, loader, criterion, optimizer, device, num_epochs,
          log_interval=LOG_INTERVAL):
    """Train ``model`` in place for ``num_epochs`` passes over ``loader``."""
    model.train()
    for epoch in range(num_epochs):
        for step, (images, labels) in enumerate(loader, start=1):
            images, labels = images.to(device), labels.to(device)
            loss = criterion(model(images), labels)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            if step % log_interval == 0:
                print("Epoch [{}/{}], Step [{}/{}], Loss: {:.4f}".format(
                    epoch + 1, num_epochs, step, len(loader), loss.item()))


def evaluate(model, loader, device):
    """Return the accuracy of ``model`` over ``loader`` in percent.

    Raises:
        ValueError: if ``loader`` yields no samples.
    """
    model.eval()
    correct = total = 0
    with torch.no_grad():
        for images, labels in loader:
            images, labels = images.to(device), labels.to(device)
            predictions = model(images).argmax(dim=1)
            total += labels.size(0)
            correct += (predictions == labels).sum().item()
    if total == 0:
        raise ValueError("evaluation loader yielded no samples")
    return correct / total * 100


def main():
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    train_dataset = DogsVsCatsDataset(DATA_ROOT, train=True,
                                      transform=build_transform(train=True))
    test_dataset = DogsVsCatsDataset(DATA_ROOT, train=False,
                                     transform=build_transform(train=False))
    train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
    test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False)

    model = CNNModel().to(device)
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)

    train(model, train_loader, criterion, optimizer, device, NUM_EPOCHS)
    accuracy = evaluate(model, test_loader, device)
    print("Accuracy of test {} images: {} %".format(len(test_dataset), accuracy))


if __name__ == "__main__":
    main()

# vgg16.py
"""Dogs vs. Cats classification by transfer learning from ImageNet-pretrained VGG16.

The convolutional feature extractor is frozen and the last classifier layer is
replaced by a 2-way output layer; only the classifier is trained. Uses the same
deterministic 80/20 split of the Kaggle training folder as the CNN script.
"""
import os

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset

NUM_EPOCHS = 5
BATCH_SIZE = 10
LEARNING_RATE = 0.001
LOG_INTERVAL = 100  # print the training loss every LOG_INTERVAL steps
TRAIN_RATIO = 0.8  # 20000 of the 25000 Kaggle training images
SPLIT_SEED = 0  # shared by train/test datasets -> complementary, reproducible split
DATA_ROOT = os.path.join(".", "data", "Dogs Vs Cats", "train")
IMAGE_EXTENSIONS = (".jpg", ".jpeg", ".png")
NORMALIZE_MEAN = (0.485, 0.456, 0.406)  # ImageNet channel statistics
NORMALIZE_STD = (0.229, 0.224, 0.225)


class DogsVsCatsDataset(Dataset):
    """One side of a deterministic train/test split of a flat Dogs vs. Cats folder.

    Labels are 0 for dog and 1 for cat (matching ``classes``).

    Args:
        root: directory containing the image files.
        train: take the training part of the split if True, else the test part.
        transform: optional callable applied to each loaded PIL image.
        train_ratio: fraction of the images that goes to the training part.
        seed: shuffle seed; a train and a test dataset built from the same folder
            with the same seed are disjoint and together cover every image.
    """

    def __init__(self, root, train=True, transform=None,
                 train_ratio=TRAIN_RATIO, seed=SPLIT_SEED):
        super().__init__()
        self.root = root
        self.transform = transform
        self.classes = ["dog", "cat"]
        # Sorted because os.listdir order is arbitrary and both splits must index
        # the same list; non-image files (e.g. Thumbs.db) are skipped.
        names = sorted(
            name for name in os.listdir(root)
            if name.lower().endswith(IMAGE_EXTENSIONS)
        )
        generator = torch.Generator().manual_seed(seed)
        order = torch.randperm(len(names), generator=generator).tolist()
        cut = int(len(names) * train_ratio)
        chosen = order[:cut] if train else order[cut:]
        self.files = [names[i] for i in chosen]
        self.labels = [self._label_for(name) for name in self.files]

    @staticmethod
    def _label_for(name):
        """Return 0 for Kaggle dog files (``dog.N.jpg``) and 1 otherwise (cats)."""
        return 0 if name.startswith("dog") else 1

    def __len__(self):
        return len(self.files)

    def __getitem__(self, index):
        # Imported lazily so the split logic works without Pillow installed.
        from PIL import Image

        path = os.path.join(self.root, self.files[index])
        with Image.open(path) as img:  # close the file handle after decoding
            image = img.convert("RGB")
        if self.transform is not None:
            image = self.transform(image)
        return image, self.labels[index]


def build_transform(train):
    """Return the preprocessing pipeline; random augmentation only when ``train``.

    Evaluation uses a deterministic resize + center crop instead of random crops,
    so the measured accuracy does not depend on augmentation noise.
    """
    from torchvision import transforms

    if train:
        steps = [
            transforms.RandomResizedCrop(size=224),
            transforms.RandomHorizontalFlip(),
            transforms.RandomVerticalFlip(),
            transforms.RandomRotation(degrees=30),
            transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1),
        ]
    else:
        steps = [transforms.Resize(256), transforms.CenterCrop(224)]
    steps += [
        transforms.ToTensor(),
        transforms.Normalize(mean=NORMALIZE_MEAN, std=NORMALIZE_STD),
    ]
    return transforms.Compose(steps)


def replace_classifier_head(model, num_classes=2):
    """Freeze ``model.features`` and swap the last classifier layer for a new Linear.

    Assigning ``out_features`` on an existing ``nn.Linear`` only changes the
    attribute: its weight keeps its original shape and the layer keeps producing
    the original number of logits. A fresh layer is therefore created.

    Returns:
        The same ``model`` object, modified in place.
    """
    for param in model.features.parameters():
        param.requires_grad = False
    old_head = model.classifier[-1]
    model.classifier[-1] = nn.Linear(old_head.in_features, num_classes)
    return model


def build_model(num_classes=2):
    """Return ImageNet-pretrained VGG16 with a frozen backbone and a new output layer."""
    from torchvision import models

    model = models.vgg16(weights=models.VGG16_Weights.IMAGENET1K_V1)
    return replace_classifier_head(model, num_classes)


def train(model, loader, criterion, optimizer, device, num_epochs,
          log_interval=LOG_INTERVAL):
    """Train ``model`` in place for ``num_epochs`` passes over ``loader``."""
    model.train()
    for epoch in range(num_epochs):
        for step, (images, labels) in enumerate(loader, start=1):
            images, labels = images.to(device), labels.to(device)
            loss = criterion(model(images), labels)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            if step % log_interval == 0:
                print("Epoch [{}/{}], Step [{}/{}], Loss: {:.4f}".format(
                    epoch + 1, num_epochs, step, len(loader), loss.item()))


def evaluate(model, loader, device):
    """Return the accuracy of ``model`` over ``loader`` in percent.

    Raises:
        ValueError: if ``loader`` yields no samples.
    """
    model.eval()
    correct = total = 0
    with torch.no_grad():
        for images, labels in loader:
            images, labels = images.to(device), labels.to(device)
            predictions = model(images).argmax(dim=1)
            total += labels.size(0)
            correct += (predictions == labels).sum().item()
    if total == 0:
        raise ValueError("evaluation loader yielded no samples")
    return correct / total * 100


def main():
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    train_dataset = DogsVsCatsDataset(DATA_ROOT, train=True,
                                      transform=build_transform(train=True))
    test_dataset = DogsVsCatsDataset(DATA_ROOT, train=False,
                                     transform=build_transform(train=False))
    train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
    test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False)

    model = build_model(num_classes=2).to(device)
    criterion = nn.CrossEntropyLoss()
    # Only the unfrozen classifier parameters need optimizer state.
    trainable = [p for p in model.parameters() if p.requires_grad]
    optimizer = optim.Adam(trainable, lr=LEARNING_RATE)

    train(model, train_loader, criterion, optimizer, device, NUM_EPOCHS)
    accuracy = evaluate(model, test_loader, device)
    print("Accuracy of test {} images: {} %".format(len(test_dataset), accuracy))


if __name__ == "__main__":
    main()

# Reference: "PyTorch猫狗大战:CNN vs VGG16迁移学习,谁更胜一筹?" - 超腾开源
# (PyTorch Dogs vs. Cats: CNN vs. VGG16 transfer learning, which one wins?)
