Beyond VGG Classification: Reproducing FCN-8s Semantic Segmentation in PyTorch, Step by Step (Complete Code Included)
From VGG to FCN-8s: A Complete PyTorch Walkthrough of Semantic Segmentation
Introduction
Semantic segmentation is one of the core tasks in computer vision: the goal is to assign a class label to every pixel in an image. Unlike conventional image classification, it requires understanding both the global semantics of a scene and its local details. In 2015, the Fully Convolutional Network (FCN) transformed the field by converting a standard convolutional neural network (CNN) into a fully convolutional architecture that accepts inputs of arbitrary size, pioneering the use of deep learning for semantic segmentation.
In this article we build a complete FCN-8s model from scratch with PyTorch. Rather than stopping at theory, we focus on the following practical points:
- How to convert a pretrained VGG16 network into an FCN architecture
- Implementing and tuning transposed convolution (deconvolution) layers
- Coding techniques for skip connections
- Loading and preprocessing the PASCAL VOC dataset
- Common training problems and how to debug them
Whether you are new to deep learning or an intermediate developer who wants to understand the details of a segmentation pipeline, the complete code and explanations here should help you master this core technique quickly.
1. Environment Setup and Data Loading
1.1 Installing the Required Libraries
Before starting, make sure the following Python libraries are installed:
```bash
pip install torch torchvision matplotlib numpy opencv-python pillow
```
1.2 Downloading and Preparing the PASCAL VOC Dataset
PASCAL VOC is one of the most widely used benchmarks for semantic segmentation. We use the 2012 release:
```python
import torchvision.datasets as datasets

# Download the dataset
voc_train = datasets.VOCSegmentation(
    root='./data', year='2012', image_set='train',
    download=True, transform=None)
voc_val = datasets.VOCSegmentation(
    root='./data', year='2012', image_set='val',
    download=True, transform=None)
```
1.3 Implementing a Custom Data Loader
To load and preprocess samples efficiently, we implement a custom dataset class:
```python
import os
import numpy as np
import cv2
import torch
from PIL import Image
from torch.utils.data import Dataset
import torchvision.transforms as transforms

class VOCDataset(Dataset):
    def __init__(self, root_dir, split='train'):
        self.root_dir = root_dir
        self.split = split
        self.image_dir = os.path.join(root_dir, 'JPEGImages')
        self.mask_dir = os.path.join(root_dir, 'SegmentationClass')
        # Collect the image IDs for this split
        with open(os.path.join(root_dir, f'ImageSets/Segmentation/{split}.txt'), 'r') as f:
            self.file_names = [x.strip() for x in f.readlines()]
        # Image preprocessing (ImageNet mean/std, matching the pretrained VGG16)
        self.image_transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225])
        ])

    def __len__(self):
        return len(self.file_names)

    def __getitem__(self, idx):
        img_name = os.path.join(self.image_dir, self.file_names[idx] + '.jpg')
        mask_name = os.path.join(self.mask_dir, self.file_names[idx] + '.png')
        # Read the image with OpenCV and convert BGR -> RGB
        image = cv2.imread(img_name)
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        # The VOC masks are palette PNGs: read them with PIL so the pixel
        # values stay class indices (cv2.IMREAD_GRAYSCALE would remap them)
        mask = np.array(Image.open(mask_name))
        # Apply preprocessing
        image = self.image_transform(image)
        mask = torch.from_numpy(mask).long()
        return image, mask
```
Note that VOC images come in varying sizes, so batching with batch_size > 1 requires resizing or cropping image and mask to a common size (or simply using batch_size=1).
2. The FCN-8s Architecture in Detail
2.1 Backbone: Adapting VGG16
FCN-8s uses VGG16 as its feature extractor, with the following modifications:
- Remove the final fully connected layers
- Convert the fully connected layers into equivalent convolutions
- Add transposed convolution layers for upsampling
```python
import torch.nn as nn
import torch.nn.functional as F
from torchvision.models import vgg16

class FCN8s(nn.Module):
    def __init__(self, num_classes):
        super(FCN8s, self).__init__()
        # Load the pretrained VGG16 model
        vgg = vgg16(pretrained=True)
        features = list(vgg.features.children())
        # Pad the input heavily (the classic FCN trick) so the fixed crop
        # offsets used in forward() (5, 9, 31) line up for any input size
        features[0].padding = (100, 100)

        # Encoder stages
        self.encoder1 = nn.Sequential(*features[:5])     # through conv1_2 + pool1
        self.encoder2 = nn.Sequential(*features[5:10])   # through conv2_2 + pool2
        self.encoder3 = nn.Sequential(*features[10:17])  # through conv3_3 + pool3
        self.encoder4 = nn.Sequential(*features[17:24])  # through conv4_3 + pool4
        self.encoder5 = nn.Sequential(*features[24:])    # through conv5_3 + pool5

        # Fully connected layers converted to convolutions
        self.fc6 = nn.Conv2d(512, 4096, kernel_size=7)
        self.fc7 = nn.Conv2d(4096, 4096, kernel_size=1)

        # Classifier
        self.classifier = nn.Conv2d(4096, num_classes, kernel_size=1)

        # Transposed convolutions for upsampling
        self.upscore2 = nn.ConvTranspose2d(
            num_classes, num_classes, kernel_size=4, stride=2, bias=False)
        self.upscore8 = nn.ConvTranspose2d(
            num_classes, num_classes, kernel_size=16, stride=8, bias=False)
        self.upscore_pool4 = nn.ConvTranspose2d(
            num_classes, num_classes, kernel_size=4, stride=2, bias=False)

        # Skip-connection scoring layers
        self.score_pool4 = nn.Conv2d(512, num_classes, kernel_size=1)
        self.score_pool3 = nn.Conv2d(256, num_classes, kernel_size=1)
```
2.2 Forward Pass and Skip Connections
The essence of FCN-8s is fusing deep semantic features with shallow detail features through skip connections:
```python
    # forward() method of the FCN8s class above
    def forward(self, x):
        h = x
        h = self.encoder1(h)
        h = self.encoder2(h)
        h = self.encoder3(h)
        pool3 = h                      # 1/8 resolution
        h = self.encoder4(h)
        pool4 = h                      # 1/16 resolution
        h = self.encoder5(h)           # 1/32 resolution

        h = self.fc6(h)
        h = F.relu(h)
        h = F.dropout(h, p=0.5, training=self.training)
        h = self.fc7(h)
        h = F.relu(h)
        h = F.dropout(h, p=0.5, training=self.training)
        h = self.classifier(h)

        # First 2x upsampling
        h = self.upscore2(h)
        upscore2 = h                   # 1/16 resolution

        # Fuse pool4 features (crop to match sizes, then add)
        h = self.score_pool4(pool4)
        h = h[:, :, 5:5 + upscore2.size()[2], 5:5 + upscore2.size()[3]]
        score_pool4c = h               # 1/16 resolution
        h = upscore2 + score_pool4c

        # Second 2x upsampling
        h = self.upscore_pool4(h)
        upscore_pool4 = h              # 1/8 resolution

        # Fuse pool3 features
        h = self.score_pool3(pool3)
        h = h[:, :, 9:9 + upscore_pool4.size()[2], 9:9 + upscore_pool4.size()[3]]
        score_pool3c = h               # 1/8 resolution
        h = upscore_pool4 + score_pool3c

        # Final 8x upsampling back to the input resolution
        h = self.upscore8(h)
        h = h[:, :, 31:31 + x.size()[2], 31:31 + x.size()[3]].contiguous()
        return h
```
3. Training Strategy and Implementation
3.1 Choosing the Loss Function
Semantic segmentation typically uses a per-pixel cross-entropy loss:
```python
import torch.nn as nn

def loss_fn(outputs, labels):
    # Ignore void/boundary pixels (label 255 in the VOC masks)
    criterion = nn.CrossEntropyLoss(ignore_index=255)
    return criterion(outputs, labels)
```
3.2 Configuring the Optimizer
We use SGD with momentum, plus a step learning-rate decay schedule:
```python
import torch
import torch.optim as optim

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = FCN8s(num_classes=21).to(device)  # PASCAL VOC: 20 classes + background
optimizer = optim.SGD(model.parameters(), lr=1e-3,
                      momentum=0.9, weight_decay=5e-4)
scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=5, gamma=0.1)
```
3.3 Implementing the Training Loop
A complete training pass consists of the following steps:
```python
def train(model, dataloader, optimizer, epoch):
    model.train()
    running_loss = 0.0
    for i, (inputs, labels) in enumerate(dataloader):
        inputs = inputs.to(device)
        labels = labels.to(device)
        # Reset gradients
        optimizer.zero_grad()
        # Forward pass
        outputs = model(inputs)
        loss = loss_fn(outputs, labels)
        # Backward pass and parameter update
        loss.backward()
        optimizer.step()
        # Track the running loss
        running_loss += loss.item()
        if i % 100 == 99:
            print(f'Epoch: {epoch}, Batch: {i+1}, Loss: {running_loss/100:.4f}')
            running_loss = 0.0
    # Decay the learning rate once per epoch
    scheduler.step()
```
4. Evaluation and Visualization
4.1 Computing the mIoU Metric
The standard evaluation metric for semantic segmentation is mean Intersection over Union (mIoU):
```python
import torch

def compute_miou(output, target, num_classes=21):
    # output: (N, C, H, W) logits; target: (N, H, W) labels
    pred = output.argmax(1)
    # Ignore void/boundary pixels before building the confusion matrix
    mask = (target != 255)
    pred = pred[mask]
    target = target[mask]
    # Accumulate the confusion matrix
    confusion_matrix = torch.zeros(num_classes, num_classes)
    for t, p in zip(target.view(-1), pred.view(-1)):
        confusion_matrix[t.long(), p.long()] += 1
    # Per-class IoU
    intersection = torch.diag(confusion_matrix)
    union = confusion_matrix.sum(0) + confusion_matrix.sum(1) - intersection
    iou = intersection / union.clamp(min=1)  # guard against empty classes
    # Return the mean IoU
    return iou.mean().item()
```
4.2 Visualizing the Results
Compare the predictions against the ground-truth labels visually:
```python
import matplotlib.pyplot as plt

def visualize_results(image, target, prediction):
    fig, (ax1, ax2, ax3) = plt.subplots(1, 3, figsize=(15, 5))
    # Input image (note: still ImageNet-normalized; un-normalize it first
    # if you want natural colors)
    ax1.imshow(image.permute(1, 2, 0))
    ax1.set_title('Original Image')
    ax1.axis('off')
    # Ground-truth mask
    ax2.imshow(target, cmap='jet')
    ax2.set_title('Ground Truth')
    ax2.axis('off')
    # Predicted mask
    ax3.imshow(prediction.argmax(0), cmap='jet')
    ax3.set_title('Prediction')
    ax3.axis('off')
    plt.show()
```
5. Common Problems and Debugging Tips
5.1 Running Out of GPU Memory
When you hit GPU out-of-memory errors, try the following:
- Reduce the batch size
- Use mixed-precision training
- Use gradient accumulation
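Of these, gradient accumulation trades compute for memory: gradients from several small micro-batches are summed before each optimizer step, emulating one larger batch. A minimal sketch (the toy model, random data, and `accum_steps=4` below are placeholders for the real FCN-8s training setup):

```python
import torch
import torch.nn as nn

# Toy stand-ins for the real model and data; accum_steps is the number
# of micro-batches whose gradients are summed per parameter update
model = nn.Conv2d(3, 21, kernel_size=1)
optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)
criterion = nn.CrossEntropyLoss(ignore_index=255)
accum_steps = 4

optimizer.zero_grad()
for step in range(8):
    inputs = torch.randn(2, 3, 16, 16)          # small micro-batch
    labels = torch.randint(0, 21, (2, 16, 16))
    loss = criterion(model(inputs), labels)
    # Scale the loss so the accumulated gradient matches a single
    # large-batch update of effective size 2 * accum_steps
    (loss / accum_steps).backward()
    if (step + 1) % accum_steps == 0:
        optimizer.step()
        optimizer.zero_grad()
```

The effective batch size is micro-batch size times `accum_steps`, so you can keep the learning rate you would use for the large batch.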
```python
# Mixed-precision training example
from torch.cuda.amp import GradScaler, autocast

scaler = GradScaler()
for inputs, labels in dataloader:
    inputs = inputs.to(device)
    labels = labels.to(device)
    optimizer.zero_grad()
    with autocast():
        outputs = model(inputs)
        loss = loss_fn(outputs, labels)
    scaler.scale(loss).backward()
    scaler.step(optimizer)
    scaler.update()
```
5.2 Training Does Not Converge
If the model fails to converge, consider:
- Checking whether the learning rate is appropriate
- Verifying that the data preprocessing is correct
- Checking the loss function implementation
- Monitoring gradient flow
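For the last point, a lightweight way to monitor gradient flow is to log per-layer gradient norms right after `loss.backward()`: near-zero norms in early layers suggest vanishing gradients, while very large norms point to a learning-rate problem. A sketch (the small stand-in model below is a placeholder for the FCN-8s instance):

```python
import torch
import torch.nn as nn

def grad_norms(model):
    """Return {parameter name: L2 norm of its gradient} after a backward pass."""
    return {name: p.grad.norm().item()
            for name, p in model.named_parameters()
            if p.grad is not None}

# Toy stand-in for the segmentation model
model = nn.Sequential(nn.Conv2d(3, 8, 3, padding=1), nn.ReLU(),
                      nn.Conv2d(8, 21, 1))
inputs = torch.randn(2, 3, 16, 16)
labels = torch.randint(0, 21, (2, 16, 16))
loss = nn.CrossEntropyLoss()(model(inputs), labels)
loss.backward()

norms = grad_norms(model)
for name, g in norms.items():
    print(f'{name}: {g:.3e}')  # compare magnitudes across layers
```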
```python
# Numerical gradient check (slow; run it on a tiny input)
from torch.autograd import gradcheck

# gradcheck requires double precision and a deterministic forward pass,
# so disable dropout and switch the model to double temporarily
model.eval()
model = model.double()
input_sample = torch.randn(1, 3, 32, 32, dtype=torch.double, requires_grad=True)
test = gradcheck(model, input_sample, eps=1e-6, atol=1e-4)
print("Gradient check passed:", test)
model = model.float()  # switch back for normal training
```
5.3 Blurry Prediction Boundaries
FCN-8s sometimes produces blurry prediction boundaries. Possible improvements:
- Add more skip connections (e.g., an FCN-4s variant)
- Apply CRF post-processing
- Try more advanced architectures (DeepLab, U-Net, etc.)
```python
# CRF post-processing example (requires pydensecrf)
import numpy as np
import pydensecrf.densecrf as dcrf
from pydensecrf.utils import unary_from_softmax

def apply_crf(image, prob_map):
    # image: (H, W, 3) uint8 RGB; prob_map: (21, H, W) softmax probabilities
    h, w = image.shape[:2]
    # Build the dense CRF
    d = dcrf.DenseCRF2D(w, h, 21)
    # Unary potentials from the network's softmax output
    U = unary_from_softmax(prob_map)
    d.setUnaryEnergy(U)
    # Pairwise potentials: smoothness and appearance terms
    d.addPairwiseGaussian(sxy=3, compat=3)
    d.addPairwiseBilateral(sxy=80, srgb=13, rgbim=image, compat=10)
    # Mean-field inference
    Q = d.inference(5)
    return np.argmax(np.array(Q).reshape((21, h, w)), axis=0)
```