当前位置: 首页 > news >正文

[深度学习网络从入门到入土] 残差网络ResNet

[深度学习网络从入门到入土] 残差网络ResNet

个人导航

知乎:https://www.zhihu.com/people/byzh_rc

CSDN:https://blog.csdn.net/qq_54636039

注:本文仅对所述内容做了框架性引导,具体细节可查询其余相关资料or源码

参考文章:各方资料

文章目录

  • [深度学习网络从入门到入土] 残差网络ResNet
  • 个人导航
  • 参考资料
  • 背景
  • 架构(公式)
        • 1. ==BasicBlock(ResNet18/34)==
        • 2. ==Bottleneck(ResNet50/101/152)==
        • 3. ==Shortcut 类型==
  • 创新点
        • 1. ==残差连接==(Skip Connection)
        • 2. 可训练超深网络
        • 3. 结构简洁但极强
  • 为什么 ResNet 能训练 152 层
  • 代码实现
  • 项目实例

参考资料

Deep Residual Learning for Image Recognition.

背景

在 2014–2015 年,深度 CNN 进入“越深越好”的阶段:

  • AlexNet:8 层
  • VGG:16–19 层
  • GoogLeNet:22 层

问题来了:当网络超过 20 层后,训练误差反而上升

这不是过拟合,而是优化困难(degradation problem)

resnet横空出世:让网络学习“残差”,而不是直接学习映射

传统网络:
H ( x ) = F ( x ) H(x)=F(x)H(x)=F(x)
ResNet:
H ( x ) = F ( x ) + x H(x) = F(x) + xH(x)=F(x)+x

架构(公式)

1.BasicBlock(ResNet18/34)
Conv → BN → ReLU Conv → BN + Shortcut → ReLU

y = ReLU ( F ( x ) + x ) F ( x ) = W 2 σ ( W 1 x ) y = \text{ReLU}(F(x) + x) \\ F(x) = W_2 \sigma(W_1 x)y=ReLU(F(x)+x)F(x)=W2σ(W1x)

2.Bottleneck(ResNet50/101/152)

当网络变得非常深时,使用瓶颈结构:

1×1(降维) → 3×3(提取特征) → 1×1(升维)

F ( x ) = W 3 σ ( W 2 σ ( W 1 x ) ) F(x) = W_3 \sigma(W_2 \sigma(W_1 x))F(x)=W3σ(W2σ(W1x))

3.Shortcut 类型

情况1:尺寸相同
y = F ( x ) + x y = F(x) + xy=F(x)+x
情况2:尺寸不同(下采样)
y = F ( x ) + W s x W s = 1 × 1 Conv y = F(x) + W_s x \\ \color{purple}{W_s = 1\times1 \text{ Conv}}y=F(x)+WsxWs=1×1Conv

创新点

1.残差连接(Skip Connection)

允许梯度直接传播

2. 可训练超深网络

152 层首次成功训练

3. 结构简洁但极强

成为后续几乎所有视觉网络的基础(DenseNet, U-Net)

为什么 ResNet 能训练 152 层

残差网络的理论基础:
∂ y ∂ x = ∂ F ( x ) ∂ x + 1 \frac{\partial y}{\partial x} = \frac{\partial F(x)}{\partial x} + 1xy=xF(x)+1
即梯度中始终存在 “+1” 项:

  • 梯度不会消失
  • 网络可以直接传递恒等映射

代码实现

importtorchimporttorch.nnasnnimporttorch.nn.functionalasFfrombyzh.ai.Butilsimportb_get_paramsclassBasicBlock(nn.Module):""" 给 ResNet18/34 用 """expansion=1def__init__(self,in_ch,out_ch,stride=1):super().__init__()self.conv1=nn.Conv2d(in_ch,out_ch,3,stride,1,bias=False)self.bn1=nn.BatchNorm2d(out_ch)self.conv2=nn.Conv2d(out_ch,out_ch,3,1,1,bias=False)self.bn2=nn.BatchNorm2d(out_ch)self.shortcut=nn.Sequential()ifstride!=1orin_ch!=out_ch:self.shortcut=nn.Sequential(nn.Conv2d(in_ch,out_ch,1,stride,bias=False),nn.BatchNorm2d(out_ch))defforward(self,x):out=torch.relu(self.bn1(self.conv1(x)))out=self.bn2(self.conv2(out))out+=self.shortcut(x)out=torch.relu(out)returnoutclassBottleneck(nn.Module):""" 给 ResNet50/101/152 用 """expansion=4# 输出通道 = out_ch * 4def__init__(self,in_ch,out_ch,stride=1):super().__init__()# 1x1 降维self.conv1=nn.Conv2d(in_ch,out_ch,kernel_size=1,bias=False)self.bn1=nn.BatchNorm2d(out_ch)# 3x3 特征提取(这里做 stride 下采样)self.conv2=nn.Conv2d(out_ch,out_ch,kernel_size=3,stride=stride,padding=1,bias=False)self.bn2=nn.BatchNorm2d(out_ch)# 1x1 升维self.conv3=nn.Conv2d(out_ch,out_ch*self.expansion,kernel_size=1,bias=False)self.bn3=nn.BatchNorm2d(out_ch*self.expansion)self.shortcut=nn.Sequential()ifstride!=1orin_ch!=out_ch*self.expansion:self.shortcut=nn.Sequential(nn.Conv2d(in_ch,out_ch*self.expansion,kernel_size=1,stride=stride,bias=False),nn.BatchNorm2d(out_ch*self.expansion))defforward(self,x):out=torch.relu(self.bn1(self.conv1(x)))out=torch.relu(self.bn2(self.conv2(out)))out=self.bn3(self.conv3(out))out+=self.shortcut(x)out=torch.relu(out)returnoutclassResNet(nn.Module):""" input shape: (N, 3, 224, 224) """def__init__(self,block,layers,num_classes=1000):super().__init__()self.in_ch=64self.conv1=nn.Conv2d(3,64,7,2,3,bias=False)self.bn1=nn.BatchNorm2d(64)self.maxpool=nn.MaxPool2d(3,2,1)self.layer1=self._make_layer(block,64,layers[0])self.layer2=self._make_layer(block,128,layers[1],stride=2)self.layer3=self._make_layer(block,256,layers[2],stride=2)self.layer4=self._make_layer(block,512,layers[3],stride=2)self.avgpool=nn.AdaptiveAvgPool2d((1,1))self.fc=nn.Linear(512*block.expansion,num_classes)def_make_layer(self,block,out_ch,blocks,stride=1):layers=[]layers.append(block(self.in_ch,out_ch,stride))self.in_ch=out_ch*block.expansionfor_inrange(1,blocks):layers.append(block(self.in_ch,out_ch))returnnn.Sequential(*layers)defforward(self,x):x=torch.relu(self.bn1(self.conv1(x)))x=self.maxpool(x)x=self.layer1(x)x=self.layer2(x)x=self.layer3(x)x=self.layer4(x)x=self.avgpool(x)x=torch.flatten(x,1)x=self.fc(x)returnxclassB_ResNet18_Paper(ResNet):""" input shape: (N, 3, 224, 224) """def__init__(self,num_classes=1000):block=BasicBlock layers=[2,2,2,2]super().__init__(block=block,layers=layers,num_classes=num_classes)classB_ResNet34_Paper(ResNet):""" input shape: (N, 3, 224, 224) """def__init__(self,num_classes=1000):block=BasicBlock layers=[3,4,6,3]super().__init__(block=block,layers=layers,num_classes=num_classes)classB_ResNet50_Paper(ResNet):""" input shape: (N, 3, 224, 224) """def__init__(self,num_classes=1000):block=Bottleneck layers=[3,4,6,3]super().__init__(block=block,layers=layers,num_classes=num_classes)classB_ResNet101_Paper(ResNet):""" input shape: (N, 3, 224, 224) """def__init__(self,num_classes=1000):block=Bottleneck layers=[3,4,23,3]super().__init__(block=block,layers=layers,num_classes=num_classes)classB_ResNet152_Paper(ResNet):""" input shape: (N, 3, 224, 224) """def__init__(self,num_classes=1000):block=Bottleneck layers=[3,8,36,3]super().__init__(block=block,layers=layers,num_classes=num_classes)if__name__=='__main__':# ResNet18net=B_ResNet18_Paper(num_classes=1000)a=torch.randn(50,3,224,224)result=net(a)print(result.shape)print(f"参数量:{b_get_params(net)}")# 11_689_512# ResNet34net=B_ResNet34_Paper(num_classes=1000)a=torch.randn(50,3,224,224)result=net(a)print(result.shape)print(f"参数量:{b_get_params(net)}")# 21_797_672# ResNet50net=B_ResNet50_Paper(num_classes=1000)a=torch.randn(50,3,224,224)result=net(a)print(result.shape)print(f"参数量:{b_get_params(net)}")# 25_557_032# ResNet101net=B_ResNet101_Paper(num_classes=1000)a=torch.randn(50,3,224,224)result=net(a)print(result.shape)print(f"参数量:{b_get_params(net)}")# 44_549_160# ResNet152net=B_ResNet152_Paper(num_classes=1000)a=torch.randn(50,3,224,224)result=net(a)print(result.shape)print(f"参数量:{b_get_params(net)}")# 60_192_808

项目实例

库环境:

numpy==1.26.4 torch==2.2.2cu121 byzh-core==0.0.9.21 byzh-ai==0.0.9.53 byzh-extra==0.0.9.12 ...

ResNet18训练MNIST数据集:

# copy all the codes from here to runimporttorchimporttorch.nn.functionalasFfromuploadToPypi_ai.byzh.ai.Bdataimportb_stratified_indicesfrombyzh.ai.BtrainerimportB_Classification_Trainerfrombyzh.ai.BdataimportB_Download_MNIST,b_get_dataloader_from_tensor# from uploadToPypi_ai.byzh.ai.Bmodel.study_cnn import B_ResNet18_Paperfrombyzh.ai.Bmodel.study_cnnimportB_ResNet18_Paperfrombyzh.ai.Butilsimportb_get_device##### hyper params #####epochs=10lr=1e-3batch_size=32device=b_get_device(use_idle_gpu=True)##### data #####downloader=B_Download_MNIST(save_dir='D:/study_cnn/datasets/MNIST')data_dict=downloader.get_data()X_train=data_dict['X_train_standard']y_train=data_dict['y_train']X_test=data_dict['X_test_standard']y_test=data_dict['y_test']num_classes=data_dict['num_classes']num_samples=data_dict['num_samples']indices=b_stratified_indices(y_train,num_samples//5)X_train=X_train[indices]X_train=F.interpolate(X_train,size=(224,224),mode='bilinear')X_train=X_train.repeat(1,3,1,1)y_train=y_train[indices]indices=b_stratified_indices(y_test,num_samples//5)X_test=X_test[indices]X_test=F.interpolate(X_test,size=(224,224),mode='bilinear')X_test=X_test.repeat(1,3,1,1)y_test=y_test[indices]train_dataloader,val_dataloader=b_get_dataloader_from_tensor(X_train,y_train,X_test,y_test,batch_size=batch_size)##### model #####model=B_ResNet18_Paper(num_classes=num_classes)##### else #####optimizer=torch.optim.Adam(model.parameters(),lr=lr)criterion=torch.nn.CrossEntropyLoss()##### trainer #####trainer=B_Classification_Trainer(model=model,optimizer=optimizer,criterion=criterion,train_loader=train_dataloader,val_loader=val_dataloader,device=device)trainer.set_writer1('./runs/resnet18/log.txt')##### run #####trainer.train_eval_s(epochs=epochs)##### calculate #####trainer.draw_loss_acc('./runs/resnet18/loss_acc.png',y_lim=False)trainer.save_best_checkpoint('./runs/resnet18/best_checkpoint.pth')trainer.calculate_model()
http://www.jsqmd.com/news/392318/

相关文章:

  • 实用指南:【随手记】uniapp + V3 使用TailwindCss3
  • Code Review 2.0:当AI助理在我的PR下留言“建议重构”,我默默点了Resolve
  • 2026-01-19-论文阅读-Agentic-Reasoning-for-Large-Language-Models - 详解
  • OpenEuler 22.03安装mysql
  • 如何为不同紧急场景选开锁服务?2026年24小时开锁全面评测与推荐,直击响应慢与价格不透明痛点 - 品牌推荐
  • 如何为不同场景选开锁服务?2026年24小时开锁全面评测与推荐,直击响应慢痛点 - 品牌推荐
  • 2026年淄博管道疏通推荐:居家应急与市政维护场景深度评测排名,解决堵塞与清淤痛点 - 品牌推荐
  • Redis哨兵机制
  • 为什么需要哨兵机制?
  • Python基于微信小程序的停车场预约自助停取车系统
  • Python基于微信小程序的健康卫生医院导诊咨询交流平台
  • 如何为不同场景选开锁服务?2026年24小时上门开锁全面评测与推荐,直击响应慢痛点 - 品牌推荐
  • 管道疏通服务哪家强?2026年株洲管道疏通推荐排名解决响应慢痛点 - 品牌推荐
  • Tire前缀树应用
  • 郑州管道疏通哪家强?2026年郑州管道疏通服务排名与推荐,解决响应慢与施工安全隐忧 - 品牌推荐
  • 开锁服务哪个靠谱?2026年24小时上门开锁推荐与排名解决响应慢痛点 - 品牌推荐
  • 开锁修锁换锁哪家强?2026年服务商推荐与排名,解决价格不透明与信任痛点 - 品牌推荐
  • 如何为不同场景选疏通服务?2026年郑州管道疏通全面评测与推荐,直击响应慢与效果差痛点 - 品牌推荐
  • 如何选择2026年淄博管道疏通服务?场景化评测与推荐直击痛点 - 品牌推荐
  • 2026年长沙管道疏通推荐:基于多场景实测评价,解决堵塞与异味核心痛点 - 品牌推荐
  • 重庆管道疏通哪家靠谱?2026年服务商推荐评测,针对复杂堵塞与安全痛点 - 品牌推荐
  • K8S的HorizontalPodAutoscaler
  • 管道疏通服务哪家强?2026年珠海管道疏通推荐与排名,直击响应慢与效果差痛点 - 品牌推荐
  • 如何选择西安管道疏通服务?2026年服务商推荐与综合性能评价 - 品牌推荐
  • 宜宾管道疏通哪家强?2026年宜宾管道疏通推荐与排名,解决复杂堵塞与安全隐忧痛点 - 品牌推荐
  • 如何为不同堵塞场景选服务商?2026年长沙管道疏通全面评测与推荐,直击效率与安全痛点 - 品牌推荐
  • 管道疏通服务如何选?2026年中山管道疏通推荐与评价,直击响应慢与效果差痛点 - 品牌推荐
  • 宜宾管道疏通哪家专业?2026年服务商排名与推荐,解决复杂堵塞与安全隐忧核心痛点 - 品牌推荐
  • 管道疏通哪家靠谱?2026年西安管道疏通服务推荐排名解决质量隐忧 - 品牌推荐
  • 管道疏通哪家靠谱?2026年扬州管道疏通服务推荐与专业评价 - 品牌推荐