PyTorch中通过训练图像去雾数据集 建立基于SFNet图像去雾算法的完整系统
PyTorch中通过训练图像去雾数据集 建立基于SFNet图像去雾算法的完整系统
文章目录
- (a) 整体架构
- (b) 浅层特征提取
- (c) ResBlock
- (d) Decoupler
- (e) Modulator
- 总结
- 1. 环境配置
- 2. 数据集准备
- 3. SFNet模型定义
- 4. 数据加载与预处理
- 5. 模型训练
- 6. 界面代码
- 1. `main.py` - 训练和测试脚本
- 2. `SFNet_model.py` - SFNet模型定义
- 3. `GUI.py` - GUI界面代码
- 运行步骤
以下文字及代码仅供参考。
SFNet图像去雾算法 PyTorch 附图像去雾数据集
基于SFNet图像去雾算法的完整系统,包括环境配置、数据集准备、模型训练、优化以及界面代码
深度学习模型用于图像恢复(如去雾、超分辨率等)的详细设计。让我们深入解析这个架构的各个部分。
(a) 整体架构
整体架构展示了模型如何处理输入的降质图像并输出恢复后的图像。流程如下:
- 输入层:接收降质图像。
- 浅层特征提取:通过一个
Conv 3x3卷积层提取浅层特征。 - ResBlock堆叠:多个残差块(ResBlocks)被串联起来,每个ResBlock内部包含复杂的特征学习机制(见© ResBlock)。这些ResBlocks负责学习更深层次的特征表示。
- 上采样与下采样:在某些ResBlocks之间,使用
Conv 1x1进行通道调整,并通过箭头指示的上采样或下采样操作来改变特征图的空间尺寸。 - 最终恢复:经过一系列特征学习后,通过
Conv 3x3层生成最终的恢复图像。
(b) 浅层特征提取
浅层特征提取模块主要由几个基础的卷积操作组成:
Conv 3x3:标准的3x3卷积核,用于提取局部特征。Conv 1x1:1x1卷积用于调整通道数,不改变空间维度。MCBF和MDSF:可能是特定的多尺度融合模块,用于结合不同尺度的信息。
© ResBlock
ResBlock是整个网络的核心组件,它包括:
- 多个
Conv 3x3层,用于逐层提取特征。 Decoupler和Modulator模块(见(d)和(e)),用于解耦和调制特征,增强模型的表达能力。- 残差连接(用⊕符号表示),将输入直接加到输出上,有助于缓解梯度消失问题。
(d) Decoupler
Decoupler模块的作用是将输入特征分解为两部分:
GAP(全局平均池化):获取全局信息。Split:将特征分为两部分,分别进行不同的处理。Invert:可能是一个逆变换操作,用于恢复或转换特征。Concat:将处理后的特征重新拼接在一起。
(e) Modulator
Modulator模块对特征进行调制:
Sum、GAP、FC(全连接层)、Concat、Softmax、Split等操作共同作用,实现对特征的非线性变换和选择性增强。- 这些操作有助于模型关注更重要的特征,抑制不重要的信息。
总结
该模型通过多层次的特征提取和复杂的特征调制机制,能够有效地从降质图像中恢复出高质量的图像。其设计考虑了特征的多尺度融合、深度残差学习以及特征的动态调制,体现了现代深度学习模型在图像恢复任务中的先进性和复杂性。
1. 环境配置
首先确保你的环境中安装了必要的库:
pipinstalltorch torchvision opencv-python pillow PyQt52. 数据集准备
假设你已经有了RSHAZE或其他图像去雾数据集,并且已经按照以下结构组织好:
data/ train/ hazy/ gt/ test/ hazy/ gt/3. SFNet模型定义
这里我们简化地展示一个基础的SFNet模型定义(实际应用中请参考官方或相关论文中的具体实现):
importtorchimporttorch.nnasnnclassSFNet(nn.Module):def__init__(self):super(SFNet,self).__init__()self.encoder=nn.Sequential(nn.Conv2d(3,64,kernel_size=3,padding=1),nn.ReLU(),nn.Conv2d(64,128,kernel_size=3,padding=1),nn.ReLU())self.decoder=nn.Sequential(nn.Conv2d(128,64,kernel_size=3,padding=1),nn.ReLU(),nn.Conv2d(64,3,kernel_size=3,padding=1))defforward(self,x):x=self.encoder(x)x=self.decoder(x)returnx4. 数据加载与预处理
fromtorch.utils.dataimportDataset,DataLoaderfromPILimportImageimportosfromtorchvisionimporttransformsclassDehazeDataset(Dataset):def__init__(self,hazy_dir,gt_dir,transform=None):self.hazy_images=sorted([os.path.join(hazy_dir,img)forimginos.listdir(hazy_dir)])self.gt_images=sorted([os.path.join(gt_dir,img)forimginos.listdir(gt_dir)])self.transform=transformdef__len__(self):returnlen(self.hazy_images)def__getitem__(self,idx):hazy_image=Image.open(self.hazy_images[idx]).convert('RGB')gt_image=Image.open(self.gt_images[idx]).convert('RGB')ifself.transform:hazy_image=self.transform(hazy_image)gt_image=self.transform(gt_image)returnhazy_image,gt_image transform=transforms.Compose([transforms.Resize((256,256)),transforms.ToTensor()])train_dataset=DehazeDataset('data/train/hazy','data/train/gt',transform=transform)train_loader=DataLoader(train_dataset,batch_size=8,shuffle=True)5. 模型训练
model=SFNet()criterion=nn.MSELoss()optimizer=torch.optim.Adam(model.parameters(),lr=0.001)num_epochs=10forepochinrange(num_epochs):fori,(hazy,gt)inenumerate(train_loader):optimizer.zero_grad()outputs=model(hazy)loss=criterion(outputs,gt)loss.backward()optimizer.step()print(f'Epoch [{epoch+1}/{num_epochs}], Step [{i+1}/{len(train_loader)}], Loss:{loss.item()}')6. 界面代码
SFNet图像去雾系统,包括训练、测试和推理(GUI界面),我们需要编写多个Python脚本文件。以下是详细的代码示例:
1.main.py- 训练和测试脚本
importargparseimportosimporttorchimporttorch.nnasnnfromtorchvisionimporttransformsfromtorch.utils.dataimportDataLoaderfromPILimportImageimportnumpyasnpfromSFNet_modelimportSFNet# 假设SFNet模型定义在SFNet_model.py中classDehazeDataset(Dataset):def__init__(self,hazy_dir,gt_dir,transform=None):self.hazy_images=sorted([os.path.join(hazy_dir,img)forimginos.listdir(hazy_dir)])self.gt_images=sorted([os.path.join(gt_dir,img)forimginos.listdir(gt_dir)])self.transform=transformdef__len__(self):returnlen(self.hazy_images)def__getitem__(self,idx):hazy_image=Image.open(self.hazy_images[idx]).convert('RGB')gt_image=Image.open(self.gt_images[idx]).convert('RGB')ifself.transform:hazy_image=self.transform(hazy_image)gt_image=self.transform(gt_image)returnhazy_image,gt_imagedeftrain(args):transform=transforms.Compose([transforms.Resize((256,256)),transforms.ToTensor()])dataset=DehazeDataset(os.path.join(args.data_dir,'train','hazy'),os.path.join(args.data_dir,'train','gt'),transform=transform)dataloader=DataLoader(dataset,batch_size=args.batch_size,shuffle=True)model=SFNet().cuda()criterion=nn.MSELoss()optimizer=torch.optim.Adam(model.parameters(),lr=args.learning_rate)forepochinrange(args.num_epoch):fori,(hazy,gt)inenumerate(dataloader):hazy,gt=hazy.cuda(),gt.cuda()optimizer.zero_grad()outputs=model(hazy)loss=criterion(outputs,gt)loss.backward()optimizer.step()print(f'Epoch [{epoch+1}/{args.num_epoch}], Step [{i+1}/{len(dataloader)}], Loss:{loss.item()}')torch.save(model.state_dict(),f'results/SFNet/{args.data}/Training-Results/Epoch_{epoch+1}.pkl')deftest(args):transform=transforms.Compose([transforms.Resize((256,256)),transforms.ToTensor()])dataset=DehazeDataset(os.path.join(args.data_dir,'test','hazy'),os.path.join(args.data_dir,'test','gt'),transform=transform)dataloader=DataLoader(dataset,batch_size=args.batch_size,shuffle=False)model=SFNet().cuda()model.load_state_dict(torch.load(args.test_model))model.eval()withtorch.no_grad():fori,(hazy,gt)inenumerate(dataloader):hazy,gt=hazy.cuda(),gt.cuda()outputs=model(hazy)ifargs.save_image:forjinrange(outputs.size(0)):output_img=transforms.ToPILImage()(outputs[j].cpu())output_img.save(f'results/SFNet/{args.data}/Test-Results/image_{i*args.batch_size+j}.png')if__name__=='__main__':parser=argparse.ArgumentParser(description='SFNet Image Dehazing')parser.add_argument('--data_dir',type=str,required=True,help='directory of the dataset')parser.add_argument('--data',type=str,required=True,help='dataset name')parser.add_argument('--mode',type=str,required=True,choices=['train','test'],help='train or test mode')parser.add_argument('--batch_size',type=int,default=4,help='batch size')parser.add_argument('--learning_rate',type=float,default=2e-5,help='learning rate')parser.add_argument('--num_epoch',type=int,default=300,help='number of epochs')parser.add_argument('--test_model',type=str,default='',help='path to the trained model for testing')parser.add_argument('--save_image',type=bool,default=False,help='whether to save dehazed images')args=parser.parse_args()ifargs.mode=='train':train(args)elifargs.mode=='test':test(args)2.SFNet_model.py- SFNet模型定义
importtorchimporttorch.nnasnnclassSFNet(nn.Module):def__init__(self):super(SFNet,self).__init__()self.encoder=nn.Sequential(nn.Conv2d(3,64,kernel_size=3,padding=1),nn.ReLU(),nn.Conv2d(64,128,kernel_size=3,padding=1),nn.ReLU())self.decoder=nn.Sequential(nn.Conv2d(128,64,kernel_size=3,padding=1),nn.ReLU(),nn.Conv2d(64,3,kernel_size=3,padding=1))defforward(self,x):x=self.encoder(x)x=self.decoder(x)returnx3.GUI.py- GUI界面代码
importsysfromPyQt5.QtWidgetsimportQApplication,QWidget,QPushButton,QVBoxLayout,QLabel,QFileDialogfromPyQt5.QtGuiimportQPixmapimportcv2importnumpyasnpimporttorchfromtorchvisionimporttransformsfromSFNet_modelimportSFNetclassDehazeApp(QWidget):def__init__(self):super().__init__()self.initUI()definitUI(self):self.setWindowTitle('图像去雾')self.setGeometry(100,100,800,400)layout=QVBoxLayout()self.btn_select=QPushButton('选择图像',self)self.btn_select.clicked.connect(self.select_image)layout.addWidget(self.btn_select)self.btn_dehaze=QPushButton('SFNet去雾',self)self.btn_dehaze.clicked.connect(self.dehaze_image)layout.addWidget(self.btn_dehaze)self.image_label=QLabel(self)layout.addWidget(self.image_label)self.setLayout(layout)defselect_image(self):options=QFileDialog.Options()fileName,_=QFileDialog.getOpenFileName(self,"选择图像","","Images (*.png *.xpm *.jpg *.bmp);;All Files (*)",options=options)iffileName:self.image_path=fileName pixmap=QPixmap(fileName)self.image_label.setPixmap(pixmap.scaled(400,400))defdehaze_image(self):ifhasattr(self,'image_path'):# Load and preprocess imageimage=cv2.imread(self.image_path)image=cv2.resize(image,(256,256))image=image/255.0image=np.transpose(image,(2,0,1))image=torch.tensor(image,dtype=torch.float32).unsqueeze(0).cuda()# Load pre-trained modelmodel=SFNet().cuda()model.load_state_dict(torch.load('results/SFNet/Outdoor/Training-Results/Best.pkl'))model.eval()# Perform dehazingwithtorch.no_grad():output=model(image).squeeze().cpu().numpy()output=np.transpose(output,(1,2,0))output=(output*255).astype(np.uint8)# Display resultcv2.imwrite('dehazed.jpg',output)pixmap=QPixmap('dehazed.jpg')self.image_label.setPixmap(pixmap.scaled(400,400))if__name__=='__main__':app=QApplication(sys.argv)ex=DehazeApp()ex.show()sys.exit(app.exec_())运行步骤
训练模型:
python main.py--data_dirdehaze--dataOutdoor--modetrain--batch_size4--learning_rate2e-5--num_epoch300测试模型:
python main.py--data_dirdehaze--dataOutdoor--modetest--batch_size4--test_modelresults/SFNet/Outdoor/Training-Results/Best.pkl--save_imageTrue运行GUI界面:
python GUI.py
确保所有路径正确,并根据实际情况调整参数和文件路径。
