当前位置: 首页 > news >正文

告别VGG分类:手把手教你用PyTorch复现FCN-8s语义分割(附完整代码)

从VGG到FCN-8s:PyTorch实战语义分割全流程解析

引言

语义分割作为计算机视觉领域的核心任务之一,其目标是为图像中的每个像素分配一个类别标签。与传统的图像分类不同,语义分割需要同时理解图像的全局语义信息和局部细节特征。2015年,FCN(全卷积网络)的提出彻底改变了这一领域,它将传统的卷积神经网络(CNN)改造为能够处理任意尺寸输入的全卷积结构,开创了深度学习在语义分割中的应用先河。

本文将带您从零开始,使用PyTorch框架完整实现FCN-8s模型。不同于简单的理论讲解,我们将重点关注以下实战要点:

  • 如何将预训练的VGG16网络改造为FCN架构
  • 转置卷积(反卷积)层的具体实现与参数调优
  • 跳级连接(Skip Connection)的代码实现技巧
  • PASCAL VOC数据集的加载与预处理
  • 训练过程中的常见问题与调试方法

无论您是刚接触深度学习的初学者,还是希望深入了解语义分割实现细节的中级开发者,本文提供的完整代码和详细解释都将帮助您快速掌握这一核心技术。

1. 环境准备与数据加载

1.1 安装必要的库

在开始之前,请确保已安装以下Python库:

pip install torch torchvision matplotlib numpy opencv-python pillow

1.2 下载并准备PASCAL VOC数据集

PASCAL VOC是语义分割领域最常用的基准数据集之一。我们将使用2012版本:

import torchvision.datasets as datasets # 下载数据集 voc_train = datasets.VOCSegmentation( root='./data', year='2012', image_set='train', download=True, transform=None ) voc_val = datasets.VOCSegmentation( root='./data', year='2012', image_set='val', download=True, transform=None )

1.3 实现自定义数据加载器

为了高效地加载和处理数据,我们需要实现一个自定义的数据集类:

from torch.utils.data import Dataset import torchvision.transforms as transforms import cv2 import numpy as np class VOCDataset(Dataset): def __init__(self, root_dir, split='train'): self.root_dir = root_dir self.split = split self.image_dir = os.path.join(root_dir, 'JPEGImages') self.mask_dir = os.path.join(root_dir, 'SegmentationClass') # 获取所有图像文件名 with open(os.path.join(root_dir, f'ImageSets/Segmentation/{split}.txt'), 'r') as f: self.file_names = [x.strip() for x in f.readlines()] # 定义图像预处理 self.image_transform = transforms.Compose([ transforms.ToTensor(), transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) ]) def __len__(self): return len(self.file_names) def __getitem__(self, idx): img_name = os.path.join(self.image_dir, self.file_names[idx] + '.jpg') mask_name = os.path.join(self.mask_dir, self.file_names[idx] + '.png') # 读取图像和掩码 image = cv2.imread(img_name) image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB) mask = cv2.imread(mask_name, cv2.IMREAD_GRAYSCALE) # 应用预处理 image = self.image_transform(image) mask = torch.from_numpy(mask).long() return image, mask

2. FCN-8s网络架构详解

2.1 骨干网络:VGG16的改造

FCN-8s使用VGG16作为特征提取器,但需要进行以下修改:

  1. 移除最后的全连接层
  2. 将全连接层转换为等效的卷积层
  3. 添加转置卷积层进行上采样
import torch.nn as nn from torchvision.models import vgg16 class FCN8s(nn.Module): def __init__(self, num_classes): super(FCN8s, self).__init__() # 加载预训练的VGG16模型 vgg = vgg16(pretrained=True) features = list(vgg.features.children()) # 编码器部分 self.encoder1 = nn.Sequential(*features[:5]) # 到conv1_2 self.encoder2 = nn.Sequential(*features[5:10]) # 到conv2_2 self.encoder3 = nn.Sequential(*features[10:17]) # 到conv3_3 self.encoder4 = nn.Sequential(*features[17:24]) # 到conv4_3 self.encoder5 = nn.Sequential(*features[24:]) # 到conv5_3 # 将全连接层转换为卷积层 self.fc6 = nn.Conv2d(512, 4096, kernel_size=7, padding=3) self.fc7 = nn.Conv2d(4096, 4096, kernel_size=1) # 分类器 self.classifier = nn.Conv2d(4096, num_classes, kernel_size=1) # 转置卷积层 self.upscore2 = nn.ConvTranspose2d( num_classes, num_classes, kernel_size=4, stride=2, bias=False) self.upscore8 = nn.ConvTranspose2d( num_classes, num_classes, kernel_size=16, stride=8, bias=False) self.upscore_pool4 = nn.ConvTranspose2d( num_classes, num_classes, kernel_size=4, stride=2, bias=False) # 跳级连接 self.score_pool4 = nn.Conv2d(512, num_classes, kernel_size=1) self.score_pool3 = nn.Conv2d(256, num_classes, kernel_size=1)

2.2 前向传播与跳级连接

FCN-8s的核心在于将深层特征与浅层特征通过跳级连接融合:

def forward(self, x): h = x h = self.encoder1(h) h = self.encoder2(h) h = self.encoder3(h) pool3 = h # 1/8尺寸 h = self.encoder4(h) pool4 = h # 1/16尺寸 h = self.encoder5(h) # 1/32尺寸 h = self.fc6(h) h = F.relu(h) h = F.dropout(h, p=0.5, training=self.training) h = self.fc7(h) h = F.relu(h) h = F.dropout(h, p=0.5, training=self.training) h = self.classifier(h) # 第一次上采样 h = self.upscore2(h) upscore2 = h # 1/16尺寸 # 融合pool4特征 h = self.score_pool4(pool4) h = h[:, :, 5:5+upscore2.size()[2], 5:5+upscore2.size()[3]] score_pool4c = h # 1/16尺寸 h = upscore2 + score_pool4c # 第二次上采样 h = self.upscore_pool4(h) upscore_pool4 = h # 1/8尺寸 # 融合pool3特征 h = self.score_pool3(pool3) h = h[:, :, 9:9+upscore_pool4.size()[2], 9:9+upscore_pool4.size()[3]] score_pool3c = h # 1/8尺寸 h = upscore_pool4 + score_pool3c # 最终上采样到原图尺寸 h = self.upscore8(h) h = h[:, :, 31:31+x.size()[2], 31:31+x.size()[3]].contiguous() return h

3. 训练策略与实现

3.1 损失函数选择

语义分割通常使用像素级的交叉熵损失:

def loss_fn(outputs, labels): # 忽略边界像素(标签为255) criterion = nn.CrossEntropyLoss(ignore_index=255) return criterion(outputs, labels)

3.2 优化器配置

使用带动量的SGD优化器,并设置适当的学习率衰减策略:

import torch.optim as optim model = FCN8s(num_classes=21) # PASCAL VOC有20类+背景 optimizer = optim.SGD(model.parameters(), lr=1e-3, momentum=0.9, weight_decay=5e-4) scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=5, gamma=0.1)

3.3 训练循环实现

完整的训练过程包括以下步骤:

def train(model, dataloader, optimizer, epoch): model.train() running_loss = 0.0 for i, (inputs, labels) in enumerate(dataloader): inputs = inputs.to(device) labels = labels.to(device) # 梯度清零 optimizer.zero_grad() # 前向传播 outputs = model(inputs) loss = loss_fn(outputs, labels) # 反向传播 loss.backward() optimizer.step() # 统计损失 running_loss += loss.item() if i % 100 == 99: print(f'Epoch: {epoch}, Batch: {i+1}, Loss: {running_loss/100:.4f}') running_loss = 0.0 # 更新学习率 scheduler.step()

4. 评估与可视化

4.1 计算mIoU指标

语义分割常用的评估指标是平均交并比(mIoU):

def compute_miou(output, target, num_classes=21): # 忽略边界像素 mask = (target != 255) output = output[mask] target = target[mask] # 计算混淆矩阵 confusion_matrix = torch.zeros(num_classes, num_classes) for t, p in zip(target.view(-1), output.argmax(1).view(-1)): confusion_matrix[t.long(), p.long()] += 1 # 计算每个类的IoU intersection = torch.diag(confusion_matrix) union = confusion_matrix.sum(0) + confusion_matrix.sum(1) - intersection iou = intersection / union # 返回平均IoU return iou.mean().item()

4.2 结果可视化

将预测结果与真实标签进行可视化对比:

import matplotlib.pyplot as plt def visualize_results(image, target, prediction): fig, (ax1, ax2, ax3) = plt.subplots(1, 3, figsize=(15, 5)) # 原始图像 ax1.imshow(image.permute(1, 2, 0)) ax1.set_title('Original Image') ax1.axis('off') # 真实标签 ax2.imshow(target, cmap='jet') ax2.set_title('Ground Truth') ax2.axis('off') # 预测结果 ax3.imshow(prediction.argmax(0), cmap='jet') ax3.set_title('Prediction') ax3.axis('off') plt.show()

5. 常见问题与调试技巧

5.1 内存不足问题

当遇到GPU内存不足时,可以尝试以下解决方案:

  1. 减小批量大小
  2. 使用混合精度训练
  3. 梯度累积技术
# 混合精度训练示例 from torch.cuda.amp import GradScaler, autocast scaler = GradScaler() for inputs, labels in dataloader: inputs = inputs.to(device) labels = labels.to(device) optimizer.zero_grad() with autocast(): outputs = model(inputs) loss = loss_fn(outputs, labels) scaler.scale(loss).backward() scaler.step(optimizer) scaler.update()

5.2 训练不收敛问题

如果模型训练不收敛,可以考虑:

  1. 检查学习率是否合适
  2. 验证数据预处理是否正确
  3. 检查损失函数实现
  4. 监控梯度流动情况
# 梯度检查工具 from torch.autograd import gradcheck # 选择一个小的输入样本 input_sample = torch.randn(1, 3, 32, 32, dtype=torch.double, requires_grad=True) test = gradcheck(model, input_sample, eps=1e-6, atol=1e-4) print("Gradient check passed:", test)

5.3 预测结果模糊问题

FCN-8s有时会产生模糊的预测边界,可以通过以下方法改进:

  1. 增加跳级连接的数量(如使用FCN-4s)
  2. 使用CRF后处理
  3. 尝试更先进的网络结构(如DeepLab、UNet等)
# CRF后处理示例(需要安装pydensecrf) import pydensecrf.densecrf as dcrf from pydensecrf.utils import unary_from_softmax def apply_crf(image, prob_map): # 图像尺寸 h, w = image.shape[:2] # 创建CRF d = dcrf.DenseCRF2D(w, h, 21) # 一元势能 U = unary_from_softmax(prob_map) d.setUnaryEnergy(U) # 二元势能 d.addPairwiseGaussian(sxy=3, compat=3) d.addPairwiseBilateral(sxy=80, srgb=13, rgbim=image, compat=10) # 推理 Q = d.inference(5) return np.argmax(np.array(Q).reshape((21, h, w)), axis=0)
http://www.jsqmd.com/news/683870/

相关文章:

  • 2026灯箱卷王横评:5大3M灯箱供应商性能实测 选型建议 - 资讯焦点
  • 为什么你的边缘Docker服务总在凌晨3点崩溃?——基于127台边缘设备日志的11项隐性资源耗尽预警指标
  • 从零开始手搓机器人关节:我用Arduino+步进电机驱动器DIY了一个二自由度机械臂控制器
  • 【会议征稿通知 | 中南大学主办 | IEEE出版 | EI 、Scopus稳定检索】第二届机电一体化、机器人与人工智能国际学术会议(MRAI 2026)
  • 从原理到实战:一文读懂随机森林(Random Forest)的集成智慧
  • 零基础制作宠物行业小程序 - 码云数智
  • 宠物服务小程序搭建步骤 - 码云数智
  • 【运维实战】企业级VSFTPD 文件服务 一键自动化部署方案 (适配银河麒麟 V10 /openEuler /CentOS)
  • 别再只输密码了!手把手教你用Windows 11连接公司WPA2-Enterprise企业WiFi(含EAP-PEAP配置)
  • 终极指南:用Android手机变身为专业USB键盘鼠标的完整解决方案
  • 【超简单教程】OpenClaw 2.6.4 本地 AI 零代码建站实战(内含安装包)
  • 2026NMN行业深度科普:从原理、选购标准到优质产品全解析 - 资讯焦点
  • Dify车载问答调试黄金 checklist(覆盖Qwen-2-VL+RAG+边缘缓存全链路)
  • 美业小程序怎么制作,助力门店实现数字化升级 - 码云数智
  • 地热井水位监测仪厂家排行榜 源头品牌推荐 - WHSENSORS
  • 别再折腾图数据增强了!用SimGCL/XSimGCL在PyTorch里5分钟搞定对比学习推荐
  • 2026 年成都五大 GEO 优化服务商深度盘点:AI 搜索时代本土增长引擎甄选 - GEO优化
  • P15940 [JOI Final 2026] 花园 3 / Garden 3
  • 告别许可证错误!深度解析UG NX安装后lmtools服务配置与菜单栏去水印实战
  • 3种模式实战VoiceFixer:从噪音录音到清晰人声的AI修复指南
  • 拯救者笔记本终极优化指南:Lenovo Legion Toolkit 完整使用教程
  • 加密结果看起来像正常汉字——我做了一个加密工具(密语盒子开发笔记)
  • # 034、AutoSAR OTA软件更新设计与实现:从深夜告警到量产落地
  • CF1810G题解
  • 从原理图到代码:手把手教你用STM32F103C8T6最小系统板驱动矩阵键盘做密码锁
  • 如何彻底告别网盘限速:8大平台直链下载助手完全指南
  • 从设计动机,决策链一步步推出 Shared ptr
  • 2026年上海五大GEO优化服务商深度盘点TOP机构 - GEO优化
  • Mplus链式中介实战:从模型设定到效应检验的完整指南
  • DeepSeek V4 这周发!梁文锋扛不住了