YOLOv10模型改进-Backbone改进-第60篇: YOLOv10改进策略【Backbone】| PVT Backbone替换
一、本文介绍
本文记录的是利用PVT(Pyramid Vision Transformer)作为Backbone改进YOLOv10的特征提取部分。PVT通过金字塔结构和空间缩减注意力,实现高效的多尺度特征提取。
二、PVT模块介绍
2.1 设计出发点
ViT缺乏多尺度特征提取能力,PVT通过金字塔结构和空间缩减注意力,同时兼顾全局建模和多尺度特征。
2.2 模块结构
PVT块:
- 空间缩减注意力:减少注意力计算复杂度
- 前馈网络:非线性变换
- 层次化设计:多尺度特征输出
三、PVT的实现代码
importtorchimporttorch.nnasnnclassSpatialReductionAttention(nn.Module):def__init__(self,dim,num_heads=4,sr_ratio=1):super().__init__()self.num_heads=num_heads self.scale=(dim//num_heads)**-0.5self.q=nn.Linear(dim,dim)self.kv=nn.Linear(dim,dim*2)self.proj=nn.Linear(dim,dim)self.sr_ratio=sr_ratioifsr_ratio>1:self.sr=nn.Conv2d(dim,dim,sr_ratio,sr_ratio)self.norm=nn.LayerNorm(dim)defforward(self,x,H,W):B,N,C=x.shape q=self.q(x).reshape(B,N,self.num_heads,C//self.num_heads).permute(0,2,1,3)ifself.sr_ratio>1:x_=x.transpose(1,2).view(B,C,H,W)x_=self.sr(x_).reshape(B,C,-1).transpose(1,2)x_=self.norm(x_)kv=self.kv(x_).reshape(B,-1,2,self.num_heads,C//self.num_heads).permute(2,0,3,1,4)else:kv=self.kv(x).reshape(B,N,2,self.num_heads,C//self.num_heads).permute(2,0,3,1,4)k,v=kv[0],kv[1]attn=(q @ k.transpose(-2,-1))*self.scale attn=attn.softmax(dim=-1)x=(attn @ v).transpose(1,2).reshape(B,N,C)returnself.proj(x)classPVTBasicLayer(nn.Module):def__init__(self,dim,num_heads,sr_ratio=1):super().__init__()self.norm1=nn.LayerNorm(dim)self.attn=SpatialReductionAttention(dim,num_heads,sr_ratio)self.norm2=nn.LayerNorm(dim)self.mlp=nn.Sequential(nn.Linear(dim,dim*4),nn.GELU(),nn.Linear(dim*4,dim))defforward(self,x,H,W):x=x+self.attn(self.norm1(x),H,W)x=x+self.mlp(self.norm2(x))returnxclassPVT(nn.Module):def__init__(self,c1=3,c2=1024,embed_dims=[64,128,256,512],num_heads=[1,2,4,8],sr_ratios=[8,4,2,1]):super().__init__()self.patch_embeds=nn.ModuleList()self.patch_embeds.append(nn.Sequential(nn.Conv2d(c1,embed_dims[0],7,4,3),nn.LayerNorm(embed_dims[0])))foriinrange(1,4):self.patch_embeds.append(nn.Sequential(nn.Conv2d(embed_dims[i-1],embed_dims[i],3,2,1),nn.LayerNorm(embed_dims[i])))self.layers=nn.ModuleList()foriinrange(4):self.layers.append(PVTBasicLayer(embed_dims[i],num_heads[i],sr_ratios[i]))self.final_conv=nn.Conv2d(embed_dims[-1],c2,1,bias=False)defforward(self,x):B=x.shape[0]fori,embedinenumerate(self.patch_embeds):x=embed(x)H,W=x.shape[2:]x=x.flatten(2).transpose(1,2)x=self.layers[i](x,H,W)ifi<3:x=x.transpose(1,2).reshape(B,-1,H,W)x=x.transpose(1,2).reshape(B,-1,H,W)x=self.final_conv(x)returnx四、创新模块
将PVT作为Backbone集成到YOLOv10中:
# yolov10n_pvt.yamlbackbone:-[-1,1,PVT,[3,1024]]-[-1,1,SPPF,[1024,5]]五、预期结果
| 模型 | mAP@0.5 | mAP@0.5:0.95 | 参数量 |
|---|---|---|---|
| YOLOv10n | 52.3% | 27.9% | 2.7M |
| YOLOv10n-PVT | 53.2% | 28.8% | 13.0M |
📌项目环境配置:
- Python:3.8.10+
- PyTorch:2.0.0+
- CUDA:11.8+
- Ultralytics:8.3.13+
