当前位置: 首页 > news >正文

别再死记ResNet结构了!用PyTorch手搓一个ResNet-18,带你彻底搞懂残差连接

用PyTorch手搓ResNet-18:从代码实现透视残差连接的本质

残差网络(ResNet)自2015年问世以来,一直是计算机视觉领域的基石模型。但很多开发者对它的理解停留在"跳跃连接"这个表面概念上,真正动手实现时才发现诸多细节问题:为什么有的残差块用1x1卷积?维度不匹配时如何处理?Basic Block和Bottleneck Block究竟有什么区别?今天我们就用PyTorch从零构建一个ResNet-18,在代码层面彻底搞懂这些核心问题。

1. 残差网络的设计哲学

深度神经网络在图像识别任务中表现出色,但当网络深度超过20层后,准确率不升反降。这种现象并非过拟合导致,而是源于梯度消失——深层网络在反向传播时,梯度信号经过多层传递后逐渐衰减直至消失。ResNet的创新之处在于提出了残差学习框架,让网络能够学习输入与输出之间的残差(即变化部分),而非直接学习完整的映射。

残差块的核心公式简单优雅:

output = F(x) + x

其中F(x)是需要学习的残差映射,x是恒等映射。当网络已经达到最优状态时,理论上可以让F(x)趋近于0,此时网络就退化为恒等映射,避免了性能退化。

在PyTorch中实现这个思想时,需要考虑几个关键点:

  • F(x)x的维度不一致时,需要用1x1卷积调整通道数
  • 残差块内部通常采用"卷积-BN-ReLU"的标准组合
  • 最终输出前需要再次经过ReLU激活

2. 构建Basic Block:ResNet-18的核心组件

ResNet-18使用的是Basic Block结构,每个残差块包含两个3x3卷积层。我们先实现这个基础构件:

import torch import torch.nn as nn class BasicBlock(nn.Module): expansion = 1 # 通道数扩展系数 def __init__(self, in_channels, out_channels, stride=1): super().__init__() # 第一个卷积层 self.conv1 = nn.Conv2d( in_channels, out_channels, kernel_size=3, stride=stride, padding=1, bias=False ) self.bn1 = nn.BatchNorm2d(out_channels) # 第二个卷积层 self.conv2 = nn.Conv2d( out_channels, out_channels, kernel_size=3, stride=1, padding=1, bias=False ) self.bn2 = nn.BatchNorm2d(out_channels) # 跳跃连接处理 self.shortcut = nn.Sequential() if stride != 1 or in_channels != out_channels * self.expansion: self.shortcut = nn.Sequential( nn.Conv2d( in_channels, out_channels * self.expansion, kernel_size=1, stride=stride, bias=False ), nn.BatchNorm2d(out_channels * self.expansion) ) def forward(self, x): residual = x out = self.conv1(x) out = self.bn1(out) out = nn.ReLU()(out) out = self.conv2(out) out = self.bn2(out) # 处理维度匹配 residual = self.shortcut(residual) out += residual out = nn.ReLU()(out) return out

这个实现中有几个值得注意的技术细节:

  1. 维度匹配处理:当输入输出维度不一致时(通常发生在每个stage的第一个block),使用1x1卷积调整通道数和空间尺寸
  2. 批归一化:每个卷积层后都接BatchNorm,这是现代CNN的标准配置
  3. 残差相加:在相加前不进行激活,这是原始论文的设计

提示:Basic Block中的expansion参数是为了保持与Bottleneck Block的接口一致,在Basic Block中其值为1

3. 组装完整的ResNet-18架构

现在我们可以用Basic Block搭建完整的ResNet-18了。ResNet的网络结构遵循一个通用范式:

  1. 初始卷积层(较大的卷积核和下采样)
  2. 4个stage的残差块堆叠
  3. 全局平均池化和全连接层
class ResNet(nn.Module): def __init__(self, block, num_blocks, num_classes=1000): super().__init__() self.in_channels = 64 # 初始卷积层 self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False) self.bn1 = nn.BatchNorm2d(64) self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) # 四个stage的残差块 self.layer1 = self._make_layer(block, 64, num_blocks[0], stride=1) self.layer2 = self._make_layer(block, 128, num_blocks[1], stride=2) self.layer3 = self._make_layer(block, 256, num_blocks[2], stride=2) self.layer4 = self._make_layer(block, 512, num_blocks[3], stride=2) # 分类头 self.avgpool = nn.AdaptiveAvgPool2d((1, 1)) self.fc = nn.Linear(512 * block.expansion, num_classes) def _make_layer(self, block, out_channels, num_blocks, stride): strides = [stride] + [1] * (num_blocks - 1) layers = [] for stride in strides: layers.append(block(self.in_channels, out_channels, stride)) self.in_channels = out_channels * block.expansion return nn.Sequential(*layers) def forward(self, x): x = self.conv1(x) x = self.bn1(x) x = nn.ReLU()(x) x = self.maxpool(x) x = self.layer1(x) x = self.layer2(x) x = self.layer3(x) x = self.layer4(x) x = self.avgpool(x) x = torch.flatten(x, 1) x = self.fc(x) return x

创建ResNet-18实例的代码如下:

def resnet18(): return ResNet(BasicBlock, [2, 2, 2, 2])

这里[2,2,2,2]表示四个stage各自包含2个Basic Block,总计2*4=8个残差块,加上初始卷积层和最后的全连接层,正好是18层(每个Basic Block包含2个卷积层)。

4. 残差网络的训练技巧与可视化

实现网络结构只是第一步,要让ResNet真正发挥作用,还需要注意训练过程中的几个关键点:

4.1 初始化策略

残差网络对参数初始化比较敏感。推荐使用以下初始化方法:

def initialize_weights(model): for m in model.modules(): if isinstance(m, nn.Conv2d): nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') elif isinstance(m, nn.BatchNorm2d): nn.init.constant_(m.weight, 1) nn.init.constant_(m.bias, 0)

4.2 学习率调度

使用带热重启的余弦退火学习率(CosineAnnealingWarmRestarts)通常能取得不错的效果:

optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9, weight_decay=1e-4) scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer, T_0=10)

4.3 梯度流动可视化

为了直观理解残差连接如何缓解梯度消失,我们可以可视化不同层的梯度范数:

def plot_gradient_flow(model): gradients = [] for name, param in model.named_parameters(): if param.grad is not None and 'weight' in name: gradients.append(param.grad.norm().item()) plt.figure(figsize=(10, 5)) plt.plot(gradients, alpha=0.3, color='b') plt.hlines(0, 0, len(gradients)+1, linewidth=1, color='k') plt.title('Gradient flow') plt.xlabel('Layers') plt.ylabel('Average gradient norm') plt.yscale('log')

与普通CNN相比,ResNet的梯度分布更加均匀,深层仍然能接收到较强的梯度信号。

5. ResNet变体与实战选择

虽然我们实现了ResNet-18,但ResNet家族还有多个重要变体:

模型层数残差块类型参数量(M)ImageNet Top-1 Acc
ResNet-1818Basic Block11.769.8%
ResNet-3434Basic Block21.873.3%
ResNet-5050Bottleneck25.676.2%
ResNet-101101Bottleneck44.577.4%
ResNet-152152Bottleneck60.278.0%

对于不同应用场景,选择建议如下:

  • 轻量级应用:ResNet-18/34,适合移动端或实时系统
  • 平衡型应用:ResNet-50,在精度和计算量间取得良好平衡
  • 高性能应用:ResNet-101/152,追求最高准确率

Bottleneck Block的实现与Basic Block类似,只是在两个3x3卷积之间增加了1x1卷积用于降维和升维:

class Bottleneck(nn.Module): expansion = 4 # 最终输出通道数是中间通道数的4倍 def __init__(self, in_channels, out_channels, stride=1): super().__init__() # 1x1卷积降维 self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=1, bias=False) self.bn1 = nn.BatchNorm2d(out_channels) # 3x3卷积 self.conv2 = nn.Conv2d( out_channels, out_channels, kernel_size=3, stride=stride, padding=1, bias=False ) self.bn2 = nn.BatchNorm2d(out_channels) # 1x1卷积升维 self.conv3 = nn.Conv2d( out_channels, out_channels * self.expansion, kernel_size=1, bias=False ) self.bn3 = nn.BatchNorm2d(out_channels * self.expansion) # 跳跃连接 self.shortcut = nn.Sequential() if stride != 1 or in_channels != out_channels * self.expansion: self.shortcut = nn.Sequential( nn.Conv2d( in_channels, out_channels * self.expansion, kernel_size=1, stride=stride, bias=False ), nn.BatchNorm2d(out_channels * self.expansion) ) def forward(self, x): residual = x out = self.conv1(x) out = self.bn1(out) out = nn.ReLU()(out) out = self.conv2(out) out = self.bn2(out) out = nn.ReLU()(out) out = self.conv3(out) out = self.bn3(out) residual = self.shortcut(residual) out += residual out = nn.ReLU()(out) return out

在实际项目中,我通常先尝试ResNet-50作为基线模型,它提供了较好的精度与计算效率平衡。当需要更高精度时,会考虑使用ResNet-101,但要注意这会使训练时间显著增加。

http://www.jsqmd.com/news/905778/

相关文章:

  • STM32 HAL库三LED九种模式闪烁项目实战:从GPIO原理到工程优化
  • 2026年新都财务代理公司应该怎么选?五家财务公司服务全解析 - 速递信息
  • 基于Arduino与NRF24L01的无线遥控车DIY全攻略:从电路设计到代码实现
  • 弯头厂家哪家好主流厂商横评:近两年核心差异(含行业FAQ - 速递信息
  • PS 怎么去掉灰色水印?零基础保姆级完整解决方案
  • JSON.stringify() 方法详解
  • 2026年5月电磁流量计生产厂家推荐——污水测量哪款能真正获得市场认可?
  • 基于Arduino与红外传感器的DIY音乐盒:从传感器原理到嵌入式音乐合成
  • 基于OpenLIT实现三层 LLM Agent 可观测性的实践
  • STM32入门实战:从零开始用STM32CubeIDE实现LED闪烁
  • AI Agent 开发大比拼!2026年选型指南,Python仍是王者,TypeScript崛起,混合架构成主流!
  • 从‘像素对错’到‘结构好坏’:一个迭代细化技巧,让你的模型预测自己纠错(Topology Loss实战)
  • HarmonyOS 全局状态管理实战:GlobalContext 跨页面数据共享完全指南
  • 别再手动移植算法了!保姆级教程:用MATLAB Coder App把.m文件一键转成C静态库
  • 从一次线上宕机复盘说起:我是如何用JMeter压测,定位到RT暴增和QPS暴跌的罪魁祸首
  • 嵌入式Linux内存稳定性测试:手把手教你用memtester排查硬件‘暗病’(附RK3399实测)
  • SAP PS项目模板搭建保姆级教程:从CJ91到CN13,手把手教你构建企业核心资产
  • 创客教育实战:从电路设计到生活应用的跨学科项目指南
  • 咸阳华帝热水器燃气灶维修|秦都渭城世纪大道上门检修 - GrowthUME
  • 移动端电声乐器音频处理:从DSP算法到硬件接口的完整实现
  • Ka波段SIW接收机设计:实现立方星高速星间通信
  • 别再踩坑了!用mqtt.js连接MQTT时,WebSocket端口(8083/8084)和TCP端口(1883)到底怎么选?
  • Arduino红外传感器触发OLED显示系统:实现智能感应与节能显示
  • Python3 注释
  • 047、直播录制丢帧、音画不同步?实时 TS 切片写入、Buffer 缓冲与降级策略
  • 大厂面试高频考点!手把手拆解AI Agent工具调用与Function Calling原理及工程实践
  • Oracle 11g静默安装后,别忘了这几步:从创建用户到优化Redo Log的实战配置
  • IDEA生成UML类图保姆级教程:从快捷键到高级配置,看完就能用
  • 保姆级教程:手把手教你搞定Windows 10/11的远程开机(WOL),告别办公室加班
  • GRBL Plotter:从创意到现实的数控加工终极指南 [特殊字符]