Reproducing FCN Semantic Segmentation in PyTorch: From Pretrained VGG16 to a Working FCN-8s, with Complete Code and a Pitfall Guide
When we need a computer to understand the semantics of every pixel in an image, traditional classification networks fall short. Picture an autonomous car that must identify pedestrians, vehicles, and traffic signs on the road, or medical image analysis that must precisely delineate tumor boundaries: these scenarios demand pixel-level understanding. The Fully Convolutional Network (FCN) is the groundbreaking architecture designed for exactly this class of problems.
1. Environment Setup and Data Preprocessing
Before building the model, make sure the development environment is configured correctly. Python 3.8+ and PyTorch 1.10+ are recommended for their stability and feature support; for GPU acceleration, CUDA 11.3 pairs well with these PyTorch releases.
```bash
conda create -n fcn python=3.8
conda activate fcn
pip install torch torchvision opencv-python matplotlib
```

1.1 Dataset Processing in Practice
Semantic segmentation datasets typically pair raw images with annotation masks. Taking Cityscapes as an example, the label images need special handling:
```python
import os

import cv2
from torch.utils.data import Dataset

class CityscapesDataset(Dataset):
    def __init__(self, root, split='train', transform=None):
        self.images_dir = os.path.join(root, 'leftImg8bit', split)
        self.targets_dir = os.path.join(root, 'gtFine', split)
        self.transform = transform
        self.images = []
        self.targets = []
        for city in os.listdir(self.images_dir):
            img_dir = os.path.join(self.images_dir, city)
            target_dir = os.path.join(self.targets_dir, city)
            for file_name in os.listdir(img_dir):
                self.images.append(os.path.join(img_dir, file_name))
                target_name = file_name.replace('leftImg8bit', 'gtFine_labelIds')
                self.targets.append(os.path.join(target_dir, target_name))

    def __len__(self):
        return len(self.images)

    def __getitem__(self, index):
        image = cv2.cvtColor(cv2.imread(self.images[index]), cv2.COLOR_BGR2RGB)
        target = cv2.imread(self.targets[index], cv2.IMREAD_GRAYSCALE)
        if self.transform:
            augmented = self.transform(image=image, mask=target)
            image, target = augmented['image'], augmented['mask']
        return image, target
```

Note: when processing labels, account for class imbalance. In a city street scene, sky and road pixels can vastly outnumber pedestrian pixels, which skews training.
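One common way to counter this imbalance is to weight the cross-entropy loss by inverse class frequency. A minimal sketch — the `compute_class_weights` helper and the ENet-style log smoothing are illustrative choices, not part of the original pipeline:

```python
import numpy as np
import torch
import torch.nn as nn

def compute_class_weights(masks, n_class, ignore_index=255):
    """Inverse-frequency class weights from a list of integer label masks."""
    counts = np.zeros(n_class, dtype=np.float64)
    for m in masks:
        valid = m[m != ignore_index]
        counts += np.bincount(valid.ravel(), minlength=n_class)[:n_class]
    freq = counts / max(counts.sum(), 1)
    weights = 1.0 / np.log(1.02 + freq)  # ENet-style smoothing
    return torch.tensor(weights, dtype=torch.float32)

# masks would come from CityscapesDataset; here a toy two-class example
masks = [np.array([[0, 0, 0, 1]], dtype=np.int64)]
w = compute_class_weights(masks, n_class=2)
criterion = nn.CrossEntropyLoss(weight=w, ignore_index=255)
```

The rarer class receives the larger weight, so its pixels contribute more to the gradient.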
2. Adapting the VGG16 Backbone
FCN-8s uses VGG16 as its feature extractor, but several key modifications are required:
- Remove the final fully connected layers
- Keep the convolutional and pooling layers as the feature extractor
- Replace the original classification head with 1x1 convolutions
```python
import torch.nn as nn
from torchvision import models

class VGG16FeatureExtractor(nn.Module):
    def __init__(self, pretrained=True):
        super().__init__()
        vgg = models.vgg16(pretrained=pretrained).features
        self.slice1 = nn.Sequential()
        self.slice2 = nn.Sequential()
        self.slice3 = nn.Sequential()
        self.slice4 = nn.Sequential()
        self.slice5 = nn.Sequential()
        for x in range(5):        # conv1_1 ... pool1
            self.slice1.add_module(str(x), vgg[x])
        for x in range(5, 10):    # conv2_1 ... pool2
            self.slice2.add_module(str(x), vgg[x])
        for x in range(10, 17):   # conv3_1 ... pool3
            self.slice3.add_module(str(x), vgg[x])
        for x in range(17, 24):   # conv4_1 ... pool4
            self.slice4.add_module(str(x), vgg[x])
        for x in range(24, 31):   # conv5_1 ... pool5
            self.slice5.add_module(str(x), vgg[x])
        if pretrained:
            # freeze the backbone; unfreeze later for fine-tuning if desired
            for param in self.parameters():
                param.requires_grad = False

    def forward(self, x):
        # each slice ends in a max-pooling layer, so the outputs are
        # the pool1..pool5 feature maps that FCN-8s fuses
        h_pool1 = self.slice1(x)
        h_pool2 = self.slice2(h_pool1)
        h_pool3 = self.slice3(h_pool2)
        h_pool4 = self.slice4(h_pool3)
        h_pool5 = self.slice5(h_pool4)
        return h_pool1, h_pool2, h_pool3, h_pool4, h_pool5
```

3. Implementing the FCN-8s Architecture
The core innovation of FCN-8s is multi-scale feature fusion:
- 32x upsampling path: upsample directly from the conv7 output
- 16x upsampling path: fuse in pool4 features
- 8x upsampling path: additionally fuse in pool3 features
```python
import torch.nn as nn
import torch.nn.functional as F

class FCN8s(nn.Module):
    def __init__(self, n_class=21):
        super().__init__()
        self.features = VGG16FeatureExtractor(pretrained=True)
        # 1x1 convolutions in place of the fully connected layers
        self.conv6 = nn.Conv2d(512, 4096, kernel_size=1)
        self.drop6 = nn.Dropout2d()
        self.conv7 = nn.Conv2d(4096, 4096, kernel_size=1)
        self.drop7 = nn.Dropout2d()
        self.score_fr = nn.Conv2d(4096, n_class, kernel_size=1)
        # skip connections
        self.score_pool4 = nn.Conv2d(512, n_class, kernel_size=1)
        self.score_pool3 = nn.Conv2d(256, n_class, kernel_size=1)
        # upsampling
        self.upscore2 = nn.ConvTranspose2d(
            n_class, n_class, kernel_size=4, stride=2, padding=1)
        self.upscore8 = nn.ConvTranspose2d(
            n_class, n_class, kernel_size=16, stride=8, padding=4)
        self.upscore_pool4 = nn.ConvTranspose2d(
            n_class, n_class, kernel_size=4, stride=2, padding=1)

    def forward(self, x):
        _, _, h_pool3, h_pool4, h_pool5 = self.features(x)
        # main path
        h = self.drop6(F.relu(self.conv6(h_pool5)))
        h = self.drop7(F.relu(self.conv7(h)))
        h = self.score_fr(h)
        h = self.upscore2(h)  # 2x upsample: stride 32 -> stride 16
        # fuse pool4 features (crop to align sizes before adding)
        score_pool4 = self.score_pool4(h_pool4)
        h = h[:, :, :score_pool4.size(2), :score_pool4.size(3)]
        h = h + score_pool4
        h = self.upscore_pool4(h)  # 2x upsample: stride 16 -> stride 8
        # fuse pool3 features
        score_pool3 = self.score_pool3(h_pool3)
        h = h[:, :, :score_pool3.size(2), :score_pool3.size(3)]
        h = h + score_pool3
        return self.upscore8(h)  # final 8x upsample back to input resolution
```

Tip: watch tensor size alignment during feature fusion. The FCN paper handles boundary mismatches by cropping, and this is one of the easiest places to introduce bugs.
4. Training Strategy and Performance Tuning
4.1 Choosing a Loss Function
A comparison of loss functions commonly used in semantic segmentation:
| Loss function | Strengths | Weaknesses | Typical use |
|---|---|---|---|
| CrossEntropy | Standard choice for classification | Ignores class imbalance | Balanced datasets |
| DiceLoss | Handles class imbalance | Training can be unstable | Medical imaging |
| FocalLoss | Focuses on hard examples | Sensitive to hyperparameters | Object detection |
| Lovász-Softmax | Directly optimizes mIoU | Computationally expensive | Competition settings |
For urban street-scene segmentation, a combined loss is recommended:
```python
import torch.nn as nn

class MixedLoss(nn.Module):
    def __init__(self, alpha=0.5):
        super().__init__()
        self.alpha = alpha
        self.ce = nn.CrossEntropyLoss()
        self.dice = DiceLoss()  # DiceLoss must be defined separately

    def forward(self, pred, target):
        return (self.alpha * self.ce(pred, target)
                + (1 - self.alpha) * self.dice(pred, target))
```

4.2 Learning Rate Scheduling
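MixedLoss assumes a `DiceLoss` module that the post never defines. A minimal multi-class soft-Dice sketch — the smoothing constant and the one-hot formulation here are one common choice, not the only one:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class DiceLoss(nn.Module):
    """Soft multi-class Dice loss.

    pred: raw logits of shape (N, C, H, W); target: integer labels (N, H, W).
    """
    def __init__(self, smooth=1.0):
        super().__init__()
        self.smooth = smooth

    def forward(self, pred, target):
        n_class = pred.size(1)
        probs = F.softmax(pred, dim=1)
        one_hot = F.one_hot(target, n_class).permute(0, 3, 1, 2).float()
        dims = (0, 2, 3)  # sum over batch and spatial dims, keep classes
        intersection = (probs * one_hot).sum(dims)
        cardinality = probs.sum(dims) + one_hot.sum(dims)
        dice = (2 * intersection + self.smooth) / (cardinality + self.smooth)
        return 1.0 - dice.mean()

# a near-perfect prediction should give a loss near zero
pred = torch.zeros(1, 2, 4, 4)
pred[:, 0] = 10.0  # logits strongly favor class 0 everywhere
target = torch.zeros(1, 4, 4, dtype=torch.long)
loss = DiceLoss()(pred, target)
```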
Use a warmup plus cosine-annealing schedule:
```python
import math

import torch

def get_lr_scheduler(optimizer, n_iter_per_epoch, args):
    def lr_lambda(current_iter):
        warmup_iters = args.warmup_epochs * n_iter_per_epoch
        # warmup phase: linear ramp from 0 to the base learning rate
        if current_iter < warmup_iters:
            return float(current_iter) / float(max(1, warmup_iters))
        # cosine annealing phase
        total_iters = (args.epochs - args.warmup_epochs) * n_iter_per_epoch
        return 0.5 * (1. + math.cos(
            math.pi * (current_iter - warmup_iters) / total_iters))
    return torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda)
```

4.3 Visualizing Training
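As a quick sanity check on this schedule, here is a self-contained toy run; the hyperparameters are arbitrary and `types.SimpleNamespace` stands in for the real `args` object:

```python
import math
import types

import torch

# toy hyperparameters (stand-ins for the real training args)
args = types.SimpleNamespace(warmup_epochs=1, epochs=5)
n_iter_per_epoch = 10
opt = torch.optim.SGD([torch.nn.Parameter(torch.zeros(1))], lr=0.1)

def lr_lambda(current_iter):
    warmup_iters = args.warmup_epochs * n_iter_per_epoch
    if current_iter < warmup_iters:
        return float(current_iter) / float(max(1, warmup_iters))
    total = (args.epochs - args.warmup_epochs) * n_iter_per_epoch
    return 0.5 * (1. + math.cos(math.pi * (current_iter - warmup_iters) / total))

sched = torch.optim.lr_scheduler.LambdaLR(opt, lr_lambda)
lrs = [opt.param_groups[0]['lr']]
for _ in range(args.epochs * n_iter_per_epoch - 1):
    opt.step()
    sched.step()
    lrs.append(opt.param_groups[0]['lr'])
# lrs ramps linearly to the base lr over the warmup, then decays cosinely
```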
Log the key metrics with TensorBoard:
```python
from torch.utils.tensorboard import SummaryWriter

writer = SummaryWriter(log_dir='runs/fcn8s_experiment')
for epoch in range(epochs):
    model.train()
    for i, (images, masks) in enumerate(train_loader):
        outputs = model(images)
        loss = criterion(outputs, masks)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        # log training metrics
        writer.add_scalar('Loss/train', loss.item(),
                          epoch * len(train_loader) + i)
        # periodic validation
        if i % 100 == 0:
            model.eval()
            val_loss = 0
            mIoU = 0
            with torch.no_grad():
                for val_images, val_masks in val_loader:
                    val_outputs = model(val_images)
                    val_loss += criterion(val_outputs, val_masks).item()
                    mIoU += mean_iou(val_outputs, val_masks)
            step = epoch * len(train_loader) + i
            writer.add_scalar('Loss/val', val_loss / len(val_loader), step)
            writer.add_scalar('mIoU/val', mIoU / len(val_loader), step)
            model.train()
```

5. Model Deployment and Performance Optimization
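The loop above calls a `mean_iou` helper that is not defined in the post; a minimal per-batch version might look like this (classes absent from both prediction and label are skipped rather than counted as IoU 0):

```python
import torch

def mean_iou(outputs, targets, n_class=21, ignore_index=255):
    """Mean IoU over the classes present in this batch.

    outputs: logits of shape (N, C, H, W); targets: labels (N, H, W).
    """
    preds = outputs.argmax(dim=1)
    valid = targets != ignore_index
    ious = []
    for c in range(n_class):
        pred_c = (preds == c) & valid
        target_c = (targets == c) & valid
        union = (pred_c | target_c).sum().item()
        if union == 0:
            continue  # class absent from both; skip it
        inter = (pred_c & target_c).sum().item()
        ious.append(inter / union)
    return sum(ious) / max(len(ious), 1)

# a perfect prediction should score exactly 1.0
targets = torch.tensor([[[0, 1], [1, 0]]])
outputs = torch.zeros(1, 2, 2, 2)
outputs[0, 0, 0, 0] = 5.0
outputs[0, 1, 0, 1] = 5.0
outputs[0, 1, 1, 0] = 5.0
outputs[0, 0, 1, 1] = 5.0
score = mean_iou(outputs, targets, n_class=2)
```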
5.1 Model Quantization for Speed
Use PyTorch's quantization tooling to shrink the model:
```python
import torch

model = FCN8s(n_class=21).eval()
# static quantization config for the x86 (fbgemm) backend; note that
# eager-mode static quantization also expects QuantStub/DeQuantStub
# wrappers around the model's inputs and outputs
model.qconfig = torch.quantization.get_default_qconfig('fbgemm')
prepared_model = torch.quantization.prepare(model, inplace=False)
# calibration: run representative data through prepared_model here so the
# observers can record activation ranges, then convert
quantized_model = torch.quantization.convert(prepared_model, inplace=False)
# sanity-check the quantized model
with torch.no_grad():
    quantized_output = quantized_model(torch.rand(1, 3, 512, 512))
print(f"Quantized model output shape: {quantized_output.shape}")
```

5.2 Exporting to ONNX
```python
dummy_input = torch.randn(1, 3, 512, 512)
torch.onnx.export(model, dummy_input, "fcn8s.onnx",
                  input_names=["input"], output_names=["output"],
                  dynamic_axes={"input": {0: "batch"},
                                "output": {0: "batch"}})
```

5.3 TensorRT Optimization
```bash
# convert the ONNX model to a TensorRT engine with trtexec
trtexec --onnx=fcn8s.onnx --saveEngine=fcn8s.engine --fp16 --workspace=2048
```

6. Challenges and Solutions in Real-World Applications
6.1 Segmenting Small Objects
FCN-8s can struggle with small objects. Possible remedies include:
- Multi-scale training: randomly rescale input images during training
- Attention mechanisms: add attention modules at the skip connections
- A high-resolution branch: preserve more low-level features
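The first item, multi-scale training, can be sketched as a paired image/mask transform; the tensor layout and scale set here are illustrative choices:

```python
import random

import torch
import torch.nn.functional as F

def random_scale(image, mask, scales=(0.5, 0.75, 1.0, 1.25, 1.5)):
    """Rescale image and mask together by a random factor.

    image: (C, H, W) float tensor; mask: (H, W) long tensor.
    Nearest-neighbor resampling keeps mask labels integral.
    """
    s = random.choice(scales)
    h, w = image.shape[-2:]
    size = (int(h * s), int(w * s))
    image = F.interpolate(image[None], size=size, mode='bilinear',
                          align_corners=False)[0]
    mask = F.interpolate(mask[None, None].float(), size=size,
                         mode='nearest')[0, 0].long()
    return image, mask

# deterministic demo with a single-element scale set
img = torch.rand(3, 64, 64)
m = torch.zeros(64, 64, dtype=torch.long)
img2, m2 = random_scale(img, m, scales=(0.5,))
```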
```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class AttentionBlock(nn.Module):
    def __init__(self, in_channels):
        super().__init__()
        self.query = nn.Conv2d(in_channels, in_channels // 8, kernel_size=1)
        self.key = nn.Conv2d(in_channels, in_channels // 8, kernel_size=1)
        self.value = nn.Conv2d(in_channels, in_channels, kernel_size=1)
        self.gamma = nn.Parameter(torch.zeros(1))

    def forward(self, x):
        batch_size, C, H, W = x.size()
        query = self.query(x).view(batch_size, -1, H * W).permute(0, 2, 1)
        key = self.key(x).view(batch_size, -1, H * W)
        energy = torch.bmm(query, key)
        attention = F.softmax(energy, dim=-1)
        value = self.value(x).view(batch_size, -1, H * W)
        out = torch.bmm(value, attention.permute(0, 2, 1))
        out = out.view(batch_size, C, H, W)
        # gamma starts at 0, so training begins from the identity mapping
        return self.gamma * out + x
```

6.2 Real-Time Optimization
For real-time applications, the following optimizations help:
- Channel pruning: remove unimportant convolution channels
- Knowledge distillation: let a large model guide a small model's training
- Architecture search: automatically discover efficient network structures
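The distillation idea can be sketched as a pixel-wise KL divergence between temperature-softened teacher and student logits, mixed with the usual hard-label cross-entropy; the temperature and weighting below are illustrative defaults, not tuned values:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class SegDistillLoss(nn.Module):
    """Pixel-wise knowledge distillation for segmentation logits."""
    def __init__(self, T=4.0, alpha=0.5):
        super().__init__()
        self.T = T
        self.alpha = alpha
        self.ce = nn.CrossEntropyLoss(ignore_index=255)

    def forward(self, student_logits, teacher_logits, target):
        # KL between softened distributions, rescaled by T^2 as usual
        kd = F.kl_div(
            F.log_softmax(student_logits / self.T, dim=1),
            F.softmax(teacher_logits / self.T, dim=1),
            reduction='batchmean') * (self.T ** 2)
        return self.alpha * kd + (1 - self.alpha) * self.ce(student_logits, target)

# with identical teacher and student, the KD term vanishes and only
# the weighted cross-entropy remains
s = torch.randn(1, 3, 4, 4)
t = s.clone()
target = torch.zeros(1, 4, 4, dtype=torch.long)
loss = SegDistillLoss(alpha=0.5)(s, t, target)
```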
```python
import torch.nn as nn
from torch.nn.utils import prune

def channel_prune(model, prune_percent=0.3):
    # Note: global L1 unstructured pruning zeroes individual weights;
    # for true channel pruning, use prune.ln_structured with dim=0 per layer.
    parameters_to_prune = []
    for name, module in model.named_modules():
        if isinstance(module, nn.Conv2d):
            parameters_to_prune.append((module, 'weight'))
    prune.global_unstructured(
        parameters_to_prune,
        pruning_method=prune.L1Unstructured,
        amount=prune_percent,
    )
    # make the pruning permanent by removing the reparametrization
    for module, _ in parameters_to_prune:
        prune.remove(module, 'weight')
    return model
```

7. Advanced Techniques and Recent Improvements
7.1 Deep Supervision
Attach auxiliary loss functions to intermediate layers:
```python
import torch.nn as nn

class FCN8sWithDS(nn.Module):
    def __init__(self, n_class=21):
        super().__init__()
        # ...original initialization code...
        # deep supervision branches
        self.ds_conv1 = nn.Conv2d(256, n_class, kernel_size=1)
        self.ds_conv2 = nn.Conv2d(512, n_class, kernel_size=1)

    def forward(self, x):
        _, _, h_pool3, h_pool4, h_pool5 = self.features(x)
        # deep supervision outputs
        ds1 = self.ds_conv1(h_pool3)
        ds2 = self.ds_conv2(h_pool4)
        # ...main path as before...
        return main_output, ds1, ds2
```

7.2 Self-Attention Enhancement
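The auxiliary outputs come from pool3/pool4, so they are lower resolution than the labels; one way to combine the three losses is to upsample the auxiliary heads first. A sketch — the 0.4 auxiliary weight is a common but arbitrary choice:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

def deep_supervision_loss(outputs, target, aux_weight=0.4,
                          criterion=nn.CrossEntropyLoss(ignore_index=255)):
    """outputs = (main, ds1, ds2) as returned by FCN8sWithDS.

    Auxiliary heads are upsampled to the label resolution before
    their losses are computed and added with a reduced weight.
    """
    main, *aux = outputs
    loss = criterion(main, target)
    for a in aux:
        a = F.interpolate(a, size=target.shape[-2:], mode='bilinear',
                          align_corners=False)
        loss = loss + aux_weight * criterion(a, target)
    return loss

# toy shapes: main at label resolution, aux heads at 1/4 and 1/2 of it
main = torch.randn(1, 2, 8, 8)
ds1 = torch.randn(1, 2, 2, 2)
ds2 = torch.randn(1, 2, 4, 4)
target = torch.zeros(1, 8, 8, dtype=torch.long)
loss = deep_supervision_loss((main, ds1, ds2), target)
```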
```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class SelfAttention(nn.Module):
    def __init__(self, in_channels):
        super().__init__()
        self.query = nn.Conv2d(in_channels, in_channels // 8, 1)
        self.key = nn.Conv2d(in_channels, in_channels // 8, 1)
        self.value = nn.Conv2d(in_channels, in_channels, 1)
        self.gamma = nn.Parameter(torch.zeros(1))

    def forward(self, x):
        batch_size, C, H, W = x.size()
        proj_query = self.query(x).view(batch_size, -1, H * W).permute(0, 2, 1)
        proj_key = self.key(x).view(batch_size, -1, H * W)
        energy = torch.bmm(proj_query, proj_key)
        attention = F.softmax(energy, dim=-1)
        proj_value = self.value(x).view(batch_size, -1, H * W)
        out = torch.bmm(proj_value, attention.permute(0, 2, 1))
        out = out.view(batch_size, C, H, W)
        return self.gamma * out + x
```