Dual-Input Image Classification with PyTorch: Building a "Twin" CNN Model from Scratch
When a task requires analyzing two related images at once, such as the front and back of a product, different modalities of a medical scan, or different viewpoints of the same object, a conventional single-input convolutional network falls short. This article walks through building a PyTorch model that processes two image inputs simultaneously, from data preparation to training, with code examples and practical tips at every step.
1. Understanding the Core Design of Dual-Input Networks
A dual-input CNN differs from a conventional CNN in that it processes two independent image inputs in parallel and merges the two information streams at some layer of the network. This architecture is particularly well suited to:
- Complementary information: e.g., front and back views of a product
- Multi-modal fusion: e.g., CT and MRI medical images
- Temporal comparison: e.g., consecutive frames of surveillance video
- Multi-view analysis: e.g., different viewpoints of a 3D object
Key design decisions:
- Parallel branches: two independent convolutional branches, one per input
- Fusion strategy: features are typically merged after the convolutional layers by concatenation (concat) or element-wise addition (add)
- Weight sharing: whether the two branches share convolutional weights
```python
import torch
import torch.nn as nn

# Basic dual-input network structure
class DualInputCNN(nn.Module):
    def __init__(self):
        super().__init__()
        # Branch 1
        self.branch1 = nn.Sequential(
            nn.Conv2d(3, 16, 3),
            nn.ReLU(),
            nn.MaxPool2d(2)
        )
        # Branch 2
        self.branch2 = nn.Sequential(
            nn.Conv2d(3, 16, 3),
            nn.ReLU(),
            nn.MaxPool2d(2)
        )
        # Fully connected layer after fusion; the feature dimension
        # doubles after concatenation (16*15*15 assumes 32x32 inputs)
        self.fc = nn.Linear(16*15*15*2, 10)

    def forward(self, x1, x2):
        x1 = self.branch1(x1)
        x2 = self.branch2(x2)
        x1 = x1.view(x1.size(0), -1)  # flatten
        x2 = x2.view(x2.size(0), -1)
        x = torch.cat((x1, x2), dim=1)  # concatenate along the feature dimension
        return self.fc(x)
```
Note: in practice, the input dimension of the fully connected layer must be adjusted to match the input image size.
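Rather than hard-coding the flattened size, one way to address that note is to run a dummy tensor through a branch at construction time and size the fully connected layer from the result. A minimal sketch of the same model with this change (the 64x64 default `input_size` is an illustrative assumption):

```python
import torch
import torch.nn as nn

class DualInputCNN(nn.Module):
    """Dual-input CNN that infers its fc input size with a dummy forward pass."""
    def __init__(self, input_size=64, num_classes=10):
        super().__init__()
        self.branch1 = nn.Sequential(nn.Conv2d(3, 16, 3), nn.ReLU(), nn.MaxPool2d(2))
        self.branch2 = nn.Sequential(nn.Conv2d(3, 16, 3), nn.ReLU(), nn.MaxPool2d(2))
        # Probe one branch with a dummy input to measure the flattened size
        with torch.no_grad():
            dummy = torch.zeros(1, 3, input_size, input_size)
            flat = self.branch1(dummy).flatten(1).size(1)
        self.fc = nn.Linear(flat * 2, num_classes)  # two branches concatenated

    def forward(self, x1, x2):
        x1 = self.branch1(x1).flatten(1)
        x2 = self.branch2(x2).flatten(1)
        return self.fc(torch.cat((x1, x2), dim=1))
```

Changing `input_size` now automatically resizes `self.fc` without any manual arithmetic.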
2. Building a Custom Dataset Loader
When working with image pairs, the standard PyTorch dataset class needs adjustment: each iteration must return both images together with the corresponding label.
Example dataset directory structure:
```
dataset/
├── class1/
│   ├── image1.jpg
│   ├── image1(1).jpg
│   ├── image2.jpg
│   ├── image2(1).jpg
├── class2/
│   ├── imageA.jpg
│   ├── imageA(1).jpg
...
```
```python
import os

from PIL import Image
from torch.utils.data import Dataset

class DualImageDataset(Dataset):
    def __init__(self, root_dir, transform=None, pair_suffix='(1)'):
        self.root_dir = root_dir
        self.transform = transform
        self.pair_suffix = pair_suffix
        self.classes = sorted(os.listdir(root_dir))
        self.class_to_idx = {cls: i for i, cls in enumerate(self.classes)}
        self.samples = self._make_dataset()

    def _make_dataset(self):
        samples = []
        for target_class in self.classes:
            class_dir = os.path.join(self.root_dir, target_class)
            if not os.path.isdir(class_dir):
                continue
            for filename in os.listdir(class_dir):
                if self.pair_suffix in filename:
                    continue  # skip the paired files themselves
                # Build the paired filename
                base, ext = os.path.splitext(filename)
                pair_name = f"{base}{self.pair_suffix}{ext}"
                pair_path = os.path.join(class_dir, pair_name)
                if os.path.exists(pair_path):
                    samples.append((
                        os.path.join(class_dir, filename),
                        pair_path,
                        self.class_to_idx[target_class]
                    ))
        return samples

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        img1_path, img2_path, label = self.samples[idx]
        img1 = Image.open(img1_path).convert('RGB')
        img2 = Image.open(img2_path).convert('RGB')
        if self.transform:
            img1 = self.transform(img1)
            img2 = self.transform(img2)
        return img1, img2, label
```
Data augmentation tips:
- Independent augmentation: apply different random transforms to each image
- Synchronized augmentation: apply identical geometric transforms to both images
- Color handling: decide per task whether color processing should be unified
```python
import torch
from torchvision import transforms

# Independent augmentation: each image draws its own random parameters
transform = transforms.Compose([
    transforms.RandomResizedCrop(224),
    transforms.RandomHorizontalFlip(p=0.5),  # each image flips independently
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225])
])

# Synchronized augmentation requires a custom implementation
class PairTransform:
    def __init__(self):
        # Geometric transforms that must stay in sync across the pair
        # (RandomResizedCrop is geometric too, so it belongs here)
        self.geometric = transforms.Compose([
            transforms.RandomResizedCrop(224),
            transforms.RandomHorizontalFlip(p=0.5)
        ])
        # Per-image transforms that can be applied independently
        self.to_tensor = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225])
        ])

    def __call__(self, img1, img2):
        # Re-seed the RNG so both images receive identical geometric transforms
        seed = torch.random.seed()
        torch.random.manual_seed(seed)
        img1 = self.geometric(img1)
        torch.random.manual_seed(seed)
        img2 = self.geometric(img2)
        # Tensor conversion and normalization are applied independently
        img1 = self.to_tensor(img1)
        img2 = self.to_tensor(img2)
        return img1, img2
```
3. Advanced Model Architectures
The basic dual-input network can be refined in several ways; this section introduces a few advanced architectures.
3.1 Shared-Weight Architecture
When the two inputs come from the same or similar domains, the convolutional weights can be shared to reduce the parameter count.
```python
class SharedWeightDualCNN(nn.Module):
    def __init__(self):
        super().__init__()
        # Shared convolutional layers: the same weights process both inputs
        self.shared_conv = nn.Sequential(
            nn.Conv2d(3, 32, 3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2),
            nn.Conv2d(32, 64, 3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2)
        )
        self.fc = nn.Sequential(
            nn.Linear(64*56*56*2, 512),  # assumes 224x224 inputs
            nn.ReLU(),
            nn.Linear(512, 10)
        )

    def forward(self, x1, x2):
        x1 = self.shared_conv(x1)
        x2 = self.shared_conv(x2)
        x1 = x1.view(x1.size(0), -1)
        x2 = x2.view(x2.size(0), -1)
        x = torch.cat((x1, x2), dim=1)
        return self.fc(x)
```
3.2 Multi-Level Feature Fusion
Fusing features at multiple convolutional levels captures information at several scales.
```python
import torch.nn.functional as F

class MultiLevelFusionCNN(nn.Module):
    def __init__(self):
        super().__init__()
        # Branch 1
        self.b1_conv1 = nn.Conv2d(3, 16, 3, padding=1)
        self.b1_conv2 = nn.Conv2d(16, 32, 3, padding=1)
        # Branch 2
        self.b2_conv1 = nn.Conv2d(3, 16, 3, padding=1)
        self.b2_conv2 = nn.Conv2d(16, 32, 3, padding=1)
        # Fusion layer: 1x1 convolution to mix the concatenated channels
        self.fusion_conv = nn.Conv2d(64, 64, 1)
        # Assumes 224x224 inputs with the two pooling steps below
        self.fc = nn.Linear(64*56*56, 10)

    def forward(self, x1, x2):
        # Level 1 (pooling added so the flattened size matches the fc layer)
        x1 = F.max_pool2d(F.relu(self.b1_conv1(x1)), 2)
        x2 = F.max_pool2d(F.relu(self.b2_conv1(x2)), 2)
        # Level 2 + fusion
        x1 = F.max_pool2d(F.relu(self.b1_conv2(x1)), 2)
        x2 = F.max_pool2d(F.relu(self.b2_conv2(x2)), 2)
        x = torch.cat((x1, x2), dim=1)  # concatenate along the channel dimension
        x = F.relu(self.fusion_conv(x))
        # Fully connected classifier
        x = x.view(x.size(0), -1)
        return self.fc(x)
```
3.3 Attention-Based Fusion
An attention mechanism can dynamically decide how to weight the two branches' features during fusion.
```python
class AttentionFusion(nn.Module):
    def __init__(self, channels):
        super().__init__()
        # Predict a per-pixel softmax weight for each branch
        self.attention = nn.Sequential(
            nn.Conv2d(channels*2, channels//2, 1),
            nn.ReLU(),
            nn.Conv2d(channels//2, 2, 1),
            nn.Softmax(dim=1)
        )

    def forward(self, x1, x2):
        x = torch.cat((x1, x2), dim=1)
        attn_weights = self.attention(x)
        return x1 * attn_weights[:, 0:1] + x2 * attn_weights[:, 1:2]

class AttentionFusionCNN(nn.Module):
    def __init__(self):
        super().__init__()
        self.branch1 = nn.Sequential(
            nn.Conv2d(3, 32, 3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2)
        )
        self.branch2 = nn.Sequential(
            nn.Conv2d(3, 32, 3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2)
        )
        self.fusion = AttentionFusion(32)
        self.fc = nn.Linear(32*112*112, 10)  # assumes 224x224 inputs

    def forward(self, x1, x2):
        x1 = self.branch1(x1)
        x2 = self.branch2(x2)
        x = self.fusion(x1, x2)
        x = x.view(x.size(0), -1)
        return self.fc(x)
```
4. Training Tips and Debugging
Training a dual-input network comes with a few special considerations.
Learning rate strategy:
```python
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
    optimizer,
    mode='max',  # step on validation accuracy
    patience=3,
    factor=0.5
)
```
Loss function choices:
- Classification: CrossEntropyLoss
- Regression: MSELoss
- Contrastive learning: ContrastiveLoss
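Unlike the first two, PyTorch ships no built-in `ContrastiveLoss` module. A common margin-based formulation for paired embeddings can be sketched as follows (the `margin` default is an assumption to tune per task):

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class ContrastiveLoss(nn.Module):
    """Margin-based contrastive loss for paired embeddings.

    label = 1 marks a similar pair, label = 0 a dissimilar pair.
    """
    def __init__(self, margin=1.0):
        super().__init__()
        self.margin = margin

    def forward(self, emb1, emb2, label):
        dist = F.pairwise_distance(emb1, emb2)
        # Similar pairs are pulled together; dissimilar pairs are pushed
        # apart until they are at least `margin` away.
        loss = label * dist.pow(2) + (1 - label) * F.relu(self.margin - dist).pow(2)
        return loss.mean()
```

With a dual-input model, `emb1` and `emb2` would be the flattened outputs of the two branches before any classification head.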
Multi-GPU training:
```python
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
if torch.cuda.device_count() > 1:
    print(f"Using {torch.cuda.device_count()} GPUs!")
    model = nn.DataParallel(model)
model.to(device)
```
Common troubleshooting:
Dimension mismatch errors
- Check that the two branches produce outputs of compatible shapes
- Verify that the concatenated feature dimension matches the fully connected layer's input
Training does not converge
- Try lowering the learning rate
- Check whether the data augmentation is too aggressive
- Verify that labels are correctly paired with both inputs
Overfitting
- Add Dropout layers
- Use stronger regularization
- Enlarge the training dataset
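For the shape problems above, it can help to print every intermediate tensor shape with forward hooks before touching any layer sizes. A small debugging sketch (`audit_shapes` and the `TinyDual` model are made-up stand-ins for your own helper and network):

```python
import torch
import torch.nn as nn

def audit_shapes(model, *inputs):
    """Run a dummy forward pass and record the output shape of every leaf module."""
    shapes, hooks = [], []
    for name, module in model.named_modules():
        if not list(module.children()):  # leaf modules only
            hooks.append(module.register_forward_hook(
                lambda mod, inp, out, name=name: shapes.append((name, tuple(out.shape)))))
    with torch.no_grad():
        model(*inputs)
    for h in hooks:
        h.remove()  # clean up so the hooks do not linger during training
    return shapes

# Hypothetical two-branch model, used only to demonstrate the helper
class TinyDual(nn.Module):
    def __init__(self):
        super().__init__()
        self.b1 = nn.Conv2d(3, 8, 3)
        self.b2 = nn.Conv2d(3, 8, 3)

    def forward(self, x1, x2):
        return torch.cat((self.b1(x1), self.b2(x2)), dim=1)

for name, shape in audit_shapes(TinyDual(),
                                torch.randn(1, 3, 32, 32),
                                torch.randn(1, 3, 32, 32)):
    print(name, shape)
```

If the two branch shapes printed here differ, the `torch.cat` call (and any fully connected layer after it) will fail, which pinpoints the mismatch before training starts.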
Example training loop:
```python
def train_model(model, dataloaders, criterion, optimizer, scheduler, num_epochs=25):
    best_acc = 0.0
    for epoch in range(num_epochs):
        print(f'Epoch {epoch}/{num_epochs-1}')
        print('-' * 10)
        for phase in ['train', 'val']:
            if phase == 'train':
                model.train()
            else:
                model.eval()
            running_loss = 0.0
            running_corrects = 0
            for inputs1, inputs2, labels in dataloaders[phase]:
                inputs1 = inputs1.to(device)
                inputs2 = inputs2.to(device)
                labels = labels.to(device)
                optimizer.zero_grad()
                with torch.set_grad_enabled(phase == 'train'):
                    outputs = model(inputs1, inputs2)
                    _, preds = torch.max(outputs, 1)
                    loss = criterion(outputs, labels)
                    if phase == 'train':
                        loss.backward()
                        optimizer.step()
                running_loss += loss.item() * inputs1.size(0)
                running_corrects += torch.sum(preds == labels.data)
            epoch_loss = running_loss / len(dataloaders[phase].dataset)
            epoch_acc = running_corrects.double() / len(dataloaders[phase].dataset)
            print(f'{phase} Loss: {epoch_loss:.4f} Acc: {epoch_acc:.4f}')
            if phase == 'val' and epoch_acc > best_acc:
                best_acc = epoch_acc
                torch.save(model.state_dict(), 'best_model.pth')
            if phase == 'val':
                scheduler.step(epoch_acc)
    print(f'Best val Acc: {best_acc:.4f}')
    return model
```
5. Real-World Applications
5.1 Product Image Classification
Scenario: an e-commerce platform classifies products automatically from a main photo plus a detail shot.
Approach:
- Main-image branch: captures overall shape and color
- Detail-image branch: extracts text and local features
- Late fusion: features are merged just before the fully connected layers
```python
class ProductClassifier(nn.Module):
    def __init__(self, num_classes):
        super().__init__()
        # Main-image branch - captures global features
        self.main_branch = nn.Sequential(
            nn.Conv2d(3, 32, 3, stride=2, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2),
            nn.Conv2d(32, 64, 3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2)
        )
        # Detail-image branch - less downsampling to keep finer features
        self.detail_branch = nn.Sequential(
            nn.Conv2d(3, 32, 3, padding=1),
            nn.ReLU(),
            nn.Conv2d(32, 64, 3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2)
        )
        # Fused classification head; the dimensions below assume
        # 224x224 main images and 112x112 detail crops
        self.classifier = nn.Sequential(
            nn.Linear(64*28*28 + 64*56*56, 512),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(512, num_classes)
        )

    def forward(self, main_img, detail_img):
        main_feat = self.main_branch(main_img)
        detail_feat = self.detail_branch(detail_img)
        main_feat = main_feat.view(main_feat.size(0), -1)
        detail_feat = detail_feat.view(detail_feat.size(0), -1)
        combined = torch.cat((main_feat, detail_feat), dim=1)
        return self.classifier(combined)
```
5.2 Medical Image Analysis
Scenario: combine CT and MRI modalities for disease diagnosis.
Approach:
- Modality-specific preprocessing
- Early feature fusion
- Cross-modal attention mechanisms
```python
class MedicalFusionNet(nn.Module):
    def __init__(self):
        super().__init__()
        # CT branch
        self.ct_branch = nn.Sequential(
            nn.Conv2d(1, 32, 3, padding=1),
            nn.BatchNorm2d(32),
            nn.ReLU(),
            nn.MaxPool2d(2)
        )
        # MRI branch
        self.mri_branch = nn.Sequential(
            nn.Conv2d(1, 32, 3, padding=1),
            nn.BatchNorm2d(32),
            nn.ReLU(),
            nn.MaxPool2d(2)
        )
        # Fusion module
        self.fusion = nn.Sequential(
            nn.Conv2d(64, 64, 3, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.AdaptiveAvgPool2d(1)
        )
        # Classification head
        self.classifier = nn.Sequential(
            nn.Linear(64, 32),
            nn.ReLU(),
            nn.Linear(32, 2)  # binary classification
        )

    def forward(self, ct, mri):
        ct_feat = self.ct_branch(ct)
        mri_feat = self.mri_branch(mri)
        fused = torch.cat((ct_feat, mri_feat), dim=1)
        fused = self.fusion(fused)
        flattened = torch.flatten(fused, 1)
        return self.classifier(flattened)
```
5.3 Remote Sensing Image Interpretation
Scenario: classify land cover using multispectral imagery together with high-resolution panchromatic imagery.
Approach:
- Multispectral branch: handles low-resolution, multi-channel data
- Panchromatic branch: handles high-resolution, single-channel data
- Feature-pyramid fusion
```python
class RemoteSensingNet(nn.Module):
    def __init__(self, num_spectral_bands=8):
        super().__init__()
        # Multispectral branch (low resolution, many channels)
        self.spectral = nn.Sequential(
            nn.Conv2d(num_spectral_bands, 32, 3, padding=1),
            nn.ReLU(),
            nn.Conv2d(32, 64, 3, padding=1),
            nn.ReLU()
        )
        # Panchromatic branch (high resolution, single channel)
        self.panchromatic = nn.Sequential(
            nn.Conv2d(1, 32, 3, padding=1),
            nn.ReLU(),
            nn.Conv2d(32, 64, 3, padding=1),
            nn.ReLU()
        )
        # Feature-pyramid fusion: bring both maps to a common resolution
        self.upsample = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
        self.downsample = nn.MaxPool2d(2)
        self.fusion_conv = nn.Sequential(
            nn.Conv2d(128, 64, 1),
            nn.ReLU()
        )
        self.classifier = nn.Conv2d(64, 10, 1)  # per-pixel classification

    def forward(self, spectral, pan):
        # Multispectral path: upsample to the common resolution
        spectral_feat = self.spectral(spectral)
        spectral_feat = self.upsample(spectral_feat)
        # Panchromatic path: downsample to the common resolution
        pan_feat = self.panchromatic(pan)
        pan_feat = self.downsample(pan_feat)
        # Fuse
        fused = torch.cat((spectral_feat, pan_feat), dim=1)
        fused = self.fusion_conv(fused)
        return self.classifier(fused)
```
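Note that the fusion step implicitly assumes the panchromatic image has 4x the spatial resolution of the multispectral one: upsampling the spectral features by 2 and pooling the pan features by 2 then lands both on the same grid. A quick standalone check of that assumption (the 64/256 sizes are illustrative):

```python
import torch
import torch.nn as nn

# The same resampling ops RemoteSensingNet uses
up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
down = nn.MaxPool2d(2)

spectral_feat = torch.randn(1, 64, 64, 64)   # features from a 64x64 multispectral image
pan_feat = torch.randn(1, 64, 256, 256)      # features from a 256x256 panchromatic image

# After resampling, both feature maps sit on a 128x128 grid,
# so they can be concatenated along the channel dimension.
fused = torch.cat((up(spectral_feat), down(pan_feat)), dim=1)
print(fused.shape)  # torch.Size([1, 128, 128, 128])
```

If your pan/multispectral resolution ratio differs from 4:1, adjust `scale_factor` and the pooling stride so the two feature maps align before concatenation.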