当前位置: 首页 > news >正文

《动手学深度学习》-48全连接卷积神经网络FCN实现

全连接神经网络通过卷积神经网络CNN实现特征提取,然后通过1x1的卷积将通道数转换为类别个数,最后通过转置卷积层将图像的高宽变换为原输入图的尺寸大小

一、代码

1.构建net

(1)框架

pretrained_net=torchvision.models.resnet18(pretrained=True) # print(list(pretrained_net.children())[-3:])#最后两层为AdaptiveAvgPool2d、Linear去掉 net=nn.Sequential(*list(pretrained_net.children())[:-2])
num_classes=21 net.add_module('final_conv',nn.Conv2d(in_features=512, out_features=num_classes,kernel_size=1)) net.add_module('Transposed_conv',nn.ConvTranspose2d(num_classes,num_classes,kernel_size=64,padding=16,stride=32))

(2)初始化

def bilinear_kernel(in_channel,out_channel,kernel_size): factor=(kernel_size+1)//2 #上采样放大倍数 if kernel_size %2==1: center=factor-1 else: center=factor-0.5 og=(torch.arange(kernel_size).reshape(-1,1),torch.arange(kernel_size).reshape(1,-1))#og[0]是行向量kx1,ogp[1]列向量1xk,广播之后变成kxk, filt=(1-torch.abs(og[0]-center)/factor)*(1-torch.abs(og[1]-center)/factor)#kxk的矩阵,中心大,周围小 weight=torch.zeros((in_channel,out_channel,kernel_size,kernel_size)) weight[range(in_channel),range(out_channel),:,:]=filt#让输入通道c只影响同编号C’输出,不进行混合,只改变对角线上的K初始化 return weight
W=bilinear_kernel(num_classes,num_classes,64) net.Transposed_conv.weight.data.copy_(W)

(3)测试

conv_transopsed=nn.ConvTranspose2d(3,3,kernel_size=4,padding=1,stride=2,bias=False) conv_transopsed.weight.data.copy_(bilinear_kernel(3,3,4)) img=torchvision.transforms.ToTensor()(Image.open('D:/PycharmDocument/limu/data/dogcat.png').convert('RGB')) X=img.unsqueeze(0) Y=conv_transopsed(X) out_img=Y[0].permute(1,2,0).detach() print('input image shape',img.permute(1,2,0).shape) print('output image shape',out_img.shape) d2l.set_figsize() fig,axes=plt.subplots(1,2) axes[0].imshow(img.permute(1,2,0)) axes[0].set_title('input image') axes[1].imshow(out_img) axes[1].set_title('output image') d2l.plt.show()

输入一张图,采用conv_transopsed操作,看一下大小,可以看出经过转置卷积,输出图片尺寸大一倍,

2.读取数据

batch_size,crop_size=36,(320,480)
train_iter,test_iter=test46SemanticSegmentation.load_data_voc(batch_size=batch_size,crop_size=crop_size)
voc_dir = 'D:/VOCtrainval_11-May-2012/VOCdevkit/VOC2012'
def read_voc_images(voc_dir, is_train=True):
"""读取所有VOC图像并标注"""
# 这里代码会自动拼路径:voc_dir + ImageSets + Segmentation + train.txt
txt_fname = os.path.join(voc_dir, 'ImageSets', 'Segmentation',
'train.txt' if is_train else 'val.txt')
mode = torchvision.io.image.ImageReadMode.RGB
with open(txt_fname, 'r') as f:
images = f.read().split()
features, labels = [], []
for i, fname in enumerate(images):
# 读取原始图片
features.append(torchvision.io.read_image(os.path.join(
voc_dir, 'JPEGImages', f'{fname}.jpg')))
# 读取语义分割标签图
labels.append(torchvision.io.read_image(os.path.join(
voc_dir, 'SegmentationClass' ,f'{fname}.png'), mode))
return features, labels

3.训练

def loss(inputs,targets):
return F.cross_entropy(inputs,targets,reduction='none').mean(1).mean(1)
num_epochs,lr,wd,device=5,0.01,1e-3,d2l.try_gpu()
trainer=torch.optim.SGD(net.parameters(),lr=lr,weight_decay=wd)
d2l.train_ch3(net,trainer,num_epochs,batch_size,device)
4.预测
def predect(img):
X=test_iter.dataset.normalize_image(img).unsqueeze(0)#(1,3,h,w,)
pred=net(X.to(device)).argmax(dim=1)#(1,h,w)
return pred.reshape(pred.shape[1],pred.shape[2])#(h,w)
#根据类别反向找对应的rgb,将像素点涂对应的颜色
def label2image(pred):
colormap=torch.tensor(test46SemanticSegmentation.VOC_COLORMAP,device=device)
X=pred.long()
return colormap[X,:]
test_images,test_labels=read_voc_images(voc_dir,is_train=False)
n,imags=4,[]
for i in range(n):
crop_rect=(0,0,320,480)
X=torchvision.transforms.functional.crop(test_images[i],*crop_rect)
pred=label2image(predect(X))
imags+=[X.permute(1,2,0),pred.cpu(),torchvision.transforms.functional.crop(test_labels[i],*crop_rect).permute(1,2,0)]
d2l.show_images(imags[::3]+imags[1::3]+imags[2::3],3,n,scale=2)

http://www.jsqmd.com/news/264362/

相关文章:

  • QSPI地址与数据复用总线原理:图解说明多路复用
  • Emotion2Vec+ Large情感类型有哪些?9类Emoji标签详细解读
  • RetinaFace魔改实战:基于预装环境快速实现GhostNet轻量化改造
  • 小白也能懂的YOLOE目标检测:官版镜像保姆级教程
  • GLM-4.6V-Flash-WEB轻量秘籍:如何在低配环境高效运行?
  • LangFlow智能招聘系统:HR的AI面试官搭建指南
  • 机器学习中的性能指标
  • 全网最全8个AI论文平台,本科生搞定毕业论文!
  • Speech Seaco Paraformer ASR代码实例:调用API实现自动化语音转写
  • 社交网络影响力分析:大数据方法与实践
  • 初学者掌握 claude code 的一些进阶知识
  • 如何通过服装管理ERP软件实现生产流程的高效优化?
  • 打包 Python 项目
  • 搞定提示工程优化文本生成
  • 尺寸约束下商业卫星编码器系统的抗辐照MCU性能边界研究
  • 无人驾驶物流车网关的多路CANFD冗余架构与通信可靠性分析
  • json库使用教程
  • 西门子PLC S7-1200实现4ms精准周期数据采集(带时间戳)
  • 2026.1.15总结
  • 2026年普通人有什么机会?
  • Linux操作系统(1)
  • P1119 灾后重建
  • Linux操作系统(3)
  • <Linux基础第5集>关于apt命令的细节
  • Linux操作系统(2)
  • 11-3 register integration
  • 智能驾驶三剑客:NDS、KIWI与ADASIS
  • day147—递归—二叉树的最近公共祖先(LeetCode-236)
  • 题解:P9353 [JOI 2023 Final] 现代机器 / Modern Machine
  • 12款论文AI工具横向对比:数学建模论文复现效率提升与格式优化方法