PyTorch全连接层实战:从图像分类到文本处理的5个经典案例
PyTorch全连接层实战:从图像分类到文本处理的5个经典案例
全连接层作为神经网络的基础构建块,其重要性不言而喻。但很多学习者在掌握了基础理论后,面对实际项目时仍会感到无从下手。本文将带你深入五个典型应用场景,通过完整可运行的代码示例,展示如何用PyTorch的nn.Linear解决实际问题。
1. 手写数字识别:MNIST图像分类实战
MNIST数据集是深度学习入门的经典案例。让我们构建一个简单的全连接网络来实现手写数字识别。
首先准备数据集:
import torch from torchvision import datasets, transforms # 数据预处理 transform = transforms.Compose([ transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,)) ]) # 加载数据集 train_dataset = datasets.MNIST('./data', train=True, download=True, transform=transform) test_dataset = datasets.MNIST('./data', train=False, transform=transform) # 创建数据加载器 train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=64, shuffle=True) test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=1000, shuffle=True)接下来定义网络结构:
import torch.nn as nn import torch.nn.functional as F class MNISTNet(nn.Module): def __init__(self): super(MNISTNet, self).__init__() self.fc1 = nn.Linear(28*28, 512) # 第一层全连接 self.fc2 = nn.Linear(512, 256) # 第二层全连接 self.fc3 = nn.Linear(256, 10) # 输出层 def forward(self, x): x = x.view(-1, 28*28) # 展平图像 x = F.relu(self.fc1(x)) x = F.relu(self.fc2(x)) return F.log_softmax(self.fc3(x), dim=1)训练过程的关键代码:
def train(model, device, train_loader, optimizer, epoch): model.train() for batch_idx, (data, target) in enumerate(train_loader): data, target = data.to(device), target.to(device) optimizer.zero_grad() output = model(data) loss = F.nll_loss(output, target) loss.backward() optimizer.step()这个简单网络在测试集上能达到约98%的准确率。实践中可以通过以下方式进一步提升性能:
- 增加网络深度
- 使用批归一化(BatchNorm)
- 添加Dropout层防止过拟合
- 尝试不同的优化器和学习率调度
2. 情感分析:IMDb电影评论分类
文本分类是自然语言处理中的基础任务。我们使用IMDb电影评论数据集构建情感分析模型。
首先处理文本数据:
from torchtext.datasets import IMDB from torchtext.data import Field, LabelField, BucketIterator TEXT = Field(tokenize='spacy', lower=True, include_lengths=True) LABEL = LabelField(dtype=torch.float) train_data, test_data = IMDB.splits(TEXT, LABEL) # 构建词汇表 TEXT.build_vocab(train_data, max_size=25000) LABEL.build_vocab(train_data) # 创建迭代器 train_iterator, test_iterator = BucketIterator.splits( (train_data, test_data), batch_size=64, sort_within_batch=True, sort_key=lambda x: len(x.text) )定义网络结构:
class SentimentClassifier(nn.Module): def __init__(self, vocab_size, embedding_dim, hidden_dim, output_dim): super().__init__() self.embedding = nn.Embedding(vocab_size, embedding_dim) self.fc1 = nn.Linear(embedding_dim, hidden_dim) self.fc2 = nn.Linear(hidden_dim, output_dim) def forward(self, text, text_lengths): embedded = self.embedding(text) pooled = embedded.mean(1) # 平均池化 hidden = torch.relu(self.fc1(pooled)) return torch.sigmoid(self.fc2(hidden))训练时需要注意文本数据的特殊性:
- 使用嵌入层将单词转换为向量
- 处理变长序列
- 选择合适的池化方式
这个基础模型可以达到约85%的准确率。改进方向包括:
- 使用预训练词向量
- 引入LSTM或Transformer结构
- 调整文本预处理流程
3. 房价预测:回归任务实战
全连接网络同样适用于回归问题。我们使用波士顿房价数据集演示如何预测连续值。
加载和处理数据:
from sklearn.datasets import load_boston from sklearn.preprocessing import StandardScaler from sklearn.model_selection import train_test_split boston = load_boston() X = boston.data y = boston.target # 数据标准化 scaler = StandardScaler() X = scaler.fit_transform(X) # 划分训练测试集 X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42) # 转换为PyTorch张量 X_train = torch.FloatTensor(X_train) y_train = torch.FloatTensor(y_train).unsqueeze(1) X_test = torch.FloatTensor(X_test) y_test = torch.FloatTensor(y_test).unsqueeze(1)定义回归模型:
class HousePricePredictor(nn.Module): def __init__(self, input_dim): super(HousePricePredictor, self).__init__() self.fc1 = nn.Linear(input_dim, 64) self.fc2 = nn.Linear(64, 32) self.fc3 = nn.Linear(32, 1) def forward(self, x): x = torch.relu(self.fc1(x)) x = torch.relu(self.fc2(x)) return self.fc3(x)训练回归模型的要点:
- 使用MSE损失函数
- 评估指标改为RMSE或MAE
- 注意数据标准化
- 可能需要调整学习率
回归任务常见挑战及解决方案:
| 问题 | 解决方案 |
|---|---|
| 数据量少 | 使用更小的网络,增加正则化 |
| 特征尺度差异大 | 标准化或归一化输入 |
| 非线性关系 | 增加网络深度,使用合适的激活函数 |
| 过拟合 | 添加Dropout或L2正则化 |
4. 多标签分类:新闻主题分类
Reuters新闻数据集包含46个互斥的新闻主题类别,是多分类问题的典型案例。
数据处理:
from torchtext.datasets import Reuters from torchtext.data import Field, LabelField, BucketIterator TEXT = Field(tokenize='spacy', lower=True) LABEL = LabelField(dtype=torch.long) train_data, test_data = Reuters.splits(TEXT, LABEL) # 构建词汇表 TEXT.build_vocab(train_data, max_size=25000) LABEL.build_vocab(train_data) # 创建迭代器 train_iterator, test_iterator = BucketIterator.splits( (train_data, test_data), batch_size=64, sort_within_batch=True, sort_key=lambda x: len(x.text) )定义多分类网络:
class NewsClassifier(nn.Module): def __init__(self, vocab_size, embedding_dim, hidden_dim, output_dim): super().__init__() self.embedding = nn.Embedding(vocab_size, embedding_dim) self.fc1 = nn.Linear(embedding_dim, hidden_dim) self.fc2 = nn.Linear(hidden_dim, output_dim) def forward(self, text, text_lengths): embedded = self.embedding(text) pooled = embedded.mean(1) # 平均池化 hidden = torch.relu(self.fc1(pooled)) return self.fc2(hidden)多分类任务的关键点:
- 输出层神经元数等于类别数
- 使用交叉熵损失
- 评估指标关注准确率和混淆矩阵
- 类别不平衡问题可能需要特殊处理
提升多分类性能的技巧:
尝试不同的文本表示方法:
- TF-IDF
- Word2Vec
- BERT等预训练模型
网络结构优化:
- 增加隐藏层
- 使用批归一化
- 添加注意力机制
数据增强:
- 同义词替换
- 随机删除单词
- 回译
5. 自定义数据集:花卉图像分类
最后我们看一个自定义数据集的案例,使用Oxford 102花卉数据集。
首先实现自定义数据集类:
from torch.utils.data import Dataset from PIL import Image import os class FlowerDataset(Dataset): def __init__(self, root_dir, transform=None): self.root_dir = root_dir self.transform = transform self.classes = sorted(os.listdir(root_dir)) self.class_to_idx = {cls: i for i, cls in enumerate(self.classes)} self.images = [] for cls in self.classes: cls_dir = os.path.join(root_dir, cls) for img_name in os.listdir(cls_dir): self.images.append((os.path.join(cls_dir, img_name), self.class_to_idx[cls])) def __len__(self): return len(self.images) def __getitem__(self, idx): img_path, label = self.images[idx] image = Image.open(img_path).convert('RGB') if self.transform: image = self.transform(image) return image, label定义图像分类网络:
class FlowerClassifier(nn.Module): def __init__(self, input_size, hidden_size, num_classes): super(FlowerClassifier, self).__init__() self.fc1 = nn.Linear(input_size, hidden_size) self.fc2 = nn.Linear(hidden_size, hidden_size//2) self.fc3 = nn.Linear(hidden_size//2, num_classes) def forward(self, x): x = x.view(x.size(0), -1) # 展平图像 x = torch.relu(self.fc1(x)) x = torch.relu(self.fc2(x)) return self.fc3(x)处理自定义数据集的注意事项:
- 确保图像尺寸一致
- 合理的数据增强策略
- 处理类别不平衡
- 适当的学习率调度
全连接网络在图像分类中的局限性:
虽然全连接网络可以处理图像数据,但当图像尺寸较大时,参数量会急剧增加。例如:
| 输入尺寸 | 第一层参数数量(隐藏层512) |
|---|---|
| 32x32 | 32x32x512 ≈ 500K |
| 224x224 | 224x224x512 ≈ 25M |
因此,对于大尺寸图像,通常会先使用卷积层提取特征,再连接全连接层进行分类。
