当前位置: 首页 > news >正文

卷积神经网络的引入3 —— MLP 与 CNN 在更大数据集上的性能对比实验

卷积神经网络的引入3 —— MLP 与 CNN 在更大数据集上的性能对比实验

在前两篇文章中,我们分别验证了:

  1. MLP 对平移等扰动非常敏感,而 CNN 具备更好的鲁棒性
  2. 在 Fashion-MNIST(低维灰度图)下,MLP 与 CNN 的表现差距不算巨大

为了进一步理解 CNN 的结构优势是否会随 数据集复杂度的提升 而真正显现,本篇将进入本系列的第三个验证点:


一、实验目标

本篇旨在验证:

MLP 与 CNN 在更大、更复杂的图像数据集上是否会出现明显性能差异?

更具体地说,我们希望回答以下问题:

🧪 1. 当图片不再是低维灰度图(如 CIFAR10、STL10),MLP 的表达能力是否明显不足?

🧪 2. CNN 由于卷积与池化机制,是否在更大数据集上展现出更强的泛化能力?

🧪 3. 随着训练轮次提升,两者的收敛速度与最终精度差异是否会逐步拉大?


二、数据集选择与对比策略

本次实验选择 三个不同复杂度的数据集

数据集 通道数 尺寸 难度 说明
Fashion-MNIST 1 28×28 上一章实验基准
CIFAR-10 3 32×32 彩色图片,分类更复杂
STL-10 3 96×96 图片分辨率大、类别难度高

本篇重点展示 CIFAR-10 实验(也是最经典的数据集)。


三、实验步骤

  1. 构建 MLP 与 CNN 两种模型基线

    • MLP:输入直接 Flatten → 全连接层
    • CNN:多层卷积 + 池化 + 全局池化
  2. 在同一数据集上训练 10 个 Epoch

    • 优化器:Adam
    • 学习率:1e-3
    • 批次大小:64
  3. 对比训练集精度与验证集精度

    • 用折线图对比两种模型的收敛过程
    • 观察最终的测试集表现

四、实验代码

以下代码可完整复现本章实验,结构与上一篇保持一致。

# -*- coding: utf-8 -*-
# 卷积神经网络的引入3 —— 不同数据集规模下的 MLP 与 CNN 对比实验
# Author: 方子敬
# Date: 2025-11-11import torch, torchvision
import torch.nn as nn
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
import matplotlib.pyplot as pltdevice = 'mps' if torch.backends.mps.is_available() else 'cpu'# =============================
# 1️⃣ 数据集选择(可修改)
# =============================
DATASET = 'CIFAR10'  # FashionMNIST / CIFAR10 / STL10# =============================
# 2️⃣ 数据加载
# =============================
if DATASET == 'FashionMNIST':transform = transforms.Compose([transforms.ToTensor(),transforms.Normalize((0.5,), (0.5,))])trainset = datasets.FashionMNIST('./data', train=True, download=True, transform=transform)testset = datasets.FashionMNIST('./data', train=False, download=True, transform=transform)input_channels = 1input_dim = 28 * 28elif DATASET == 'CIFAR10':transform = transforms.Compose([transforms.ToTensor(),transforms.Normalize((0.5,0.5,0.5), (0.5,0.5,0.5))])trainset = datasets.CIFAR10('./data', train=True, download=True, transform=transform)testset = datasets.CIFAR10('./data', train=False, download=True, transform=transform)input_channels = 3input_dim = 32 * 32 * 3elif DATASET == 'STL10':transform = transforms.Compose([transforms.Resize((96,96)),transforms.ToTensor(),transforms.Normalize((0.5,0.5,0.5), (0.5,0.5,0.5))])trainset = datasets.STL10('./data', split='train', download=True, transform=transform)testset = datasets.STL10('./data', split='test', download=True, transform=transform)input_channels = 3input_dim = 96 * 96 * 3train_loader = DataLoader(trainset, batch_size=64, shuffle=True)
test_loader = DataLoader(testset, batch_size=256)# =============================
# 3️⃣ 定义模型
# =============================
class MLP(nn.Module):def __init__(self, input_dim, hidden=1024):super().__init__()self.net = nn.Sequential(nn.Flatten(),nn.Linear(input_dim, hidden),nn.ReLU(),nn.Linear(hidden, 10))def forward(self, x):return self.net(x)class CNN(nn.Module):def __init__(self, in_ch):super().__init__()self.net = nn.Sequential(nn.Conv2d(in_ch, 32, 3, padding=1),nn.BatchNorm2d(32),nn.ReLU(),nn.MaxPool2d(2),nn.Conv2d(32, 64, 3, padding=1),nn.BatchNorm2d(64),nn.ReLU(),nn.MaxPool2d(2),nn.Conv2d(64, 128, 3, padding=1),nn.ReLU(),nn.AdaptiveAvgPool2d(1),nn.Flatten(),nn.Linear(128, 10))def forward(self, x):return self.net(x)# =============================
# 4️⃣ 训练与验证
# =============================
loss_fn = nn.CrossEntropyLoss()def train_one_epoch(model, loader, opt):model.train()total_loss, total_correct = 0, 0for x, y in loader:x, y = x.to(device), y.to(device)out = model(x)loss = loss_fn(out, y)opt.zero_grad()loss.backward()opt.step()total_loss += loss.item()total_correct += (out.argmax(1) == y).sum().item()return total_loss / len(loader), total_correct / len(loader.dataset)def evaluate(model, loader):model.eval()total_correct = 0with torch.no_grad():for x, y in loader:x, y = x.to(device), y.to(device)total_correct += (model(x).argmax(1) == y).sum().item()return total_correct / len(loader.dataset)# =============================
# 5️⃣ 实验执行
# =============================
mlp = MLP(input_dim).to(device)
cnn = CNN(input_channels).to(device)opt_mlp = torch.optim.Adam(mlp.parameters(), lr=1e-3)
opt_cnn = torch.optim.Adam(cnn.parameters(), lr=1e-3)epochs = 10
mlp_train_acc, cnn_train_acc = [], []
mlp_val_acc, cnn_val_acc = [], []for ep in range(epochs):_, acc_m = train_one_epoch(mlp, train_loader, opt_mlp)_, acc_c = train_one_epoch(cnn, train_loader, opt_cnn)val_m = evaluate(mlp, test_loader)val_c = evaluate(cnn, test_loader)mlp_train_acc.append(acc_m)cnn_train_acc.append(acc_c)mlp_val_acc.append(val_m)cnn_val_acc.append(val_c)print(f"[{ep+1}/{epochs}] MLP val acc={val_m:.3f} | CNN val acc={val_c:.3f}")# =============================
# 6️⃣ 精度曲线对比
# =============================
plt.figure(figsize=(10,6))
plt.plot(range(1, epochs+1), mlp_train_acc, 'r--o', label='MLP Train')
plt.plot(range(1, epochs+1), mlp_val_acc, 'r-', label='MLP Val')plt.plot(range(1, epochs+1), cnn_train_acc, 'b--o', label='CNN Train')
plt.plot(range(1, epochs+1), cnn_val_acc, 'b-', label='CNN Val')plt.title(f"Training vs Validation Accuracy on {DATASET}")
plt.xlabel("Epoch")
plt.ylabel("Accuracy")
plt.legend()
plt.grid(True)
plt.tight_layout()
plt.show()

五、训练结果(图示)

image

六、实验结论(可根据图示补充)

从图中可以明显观察到:

  1. MLP 的学习能力在 CIFAR10 上严重受限

训练集精度 从 42% → 66%,虽然逐步上升,但速度较慢。

验证集精度 长期停留在 48% ~ 53% 区间波动,几乎没有随训练改善。

出现典型的:

高维输入导致参数量巨大(32×32×3=3072维)

特征表达能力不足 → 难以捕捉局部图像结构

过拟合风险不断加剧 → 训练精度升,验证精度停滞

简而言之:

MLP 在 CIFAR10 这种复杂多类彩色图片上已经力不从心。

  1. CNN 从早期阶段就展现出显著优势

训练集精度 第一轮就达到了 ~49%,明显高于 MLP 的 42%。

验证集精度 随 epoch 持续提升,从 48% → 最终 71.0%。

并且 CNN 的训练线和验证线之间差距较小,说明泛化性良好。

这验证了 CNN 的结构优势:

卷积核能捕捉 局部空间信息

池化 & 步长提升模型的 平移不变性

BN 提升收敛速度与稳定性

参数量远小于 MLP,过拟合风险更低

  1. CNN 的提升速度明显更快

曲线中可以看到:

CNN 在 第 2 ~ 3 个 epoch 就已经达到 MLP 第 10 个 epoch 都无法达到的验证精度

随着迭代继续进行,两者差距持续被拉大

CNN 不仅最终精度更高,而且学习速度显著快于 MLP。

  1. 数据集复杂度越高,MLP 和 CNN 的差距会越大

Fashion-MNIST:两者差距有限
CIFAR10:差距明显
STL10:差距会进一步扩大(会在下一篇验证)

最终总结

通过 CIFAR10 的实验我们能够非常明确地得出:

随着数据维度和视觉复杂度的提升,MLP 的能力呈现下降趋势,而 CNN 的结构优势将快速显现。
CNN 在高维彩色图片上的泛化性能、特征提取能力与收敛速度均远胜于 MLP。

http://www.jsqmd.com/news/47066/

相关文章:

  • 全网都在找的Nano Banana Pro API 来了!便宜稳定0.15/张
  • 通过DataReader获取sql查询的字段元数据信息
  • 2025.11.21 - A
  • 2025年新版ADB工具箱下载+驱动+ADB指令集+fastboot刷机ROOT程序
  • P7960 [NOIP2021] 报数__洛谷题解
  • 与括号序列相关的 DP 笔记
  • 【251121】CF2171 Div.3 vp 总结
  • OI 笑传 #32
  • PyOpenGL实现Bresenham算法
  • 【Linux】教你在 Linux 上搭建 Web 服务器,步骤清晰无门槛 - 详解
  • 【第7章 I/O编程与异常】\r\n 和 \n\r是一回事吗?
  • 深入解析:windows显示驱动开发-CCD api的摘要及方案(一)
  • nju实验七 状态机及键盘输入
  • Gephi如何支持MySQL数据的复杂查询
  • Mozilla CI日志中暴露微软x-apikey的安全事件分析
  • Gephi怎样优化MySQL数据的展示效果
  • Gephi对MySQL数据的导入导出有何支持
  • 智能制造(MOM)-详细设计 - 智慧园区
  • nju实验六 移位寄存器及桶形移位器
  • P6727 [COCI 2015/2016 #5] OOP
  • 完整教程:政务系统信创改造中,金仓日志如何满足等保2.0三级审计要求
  • 基于 Erlang 的英文数字验证码识别系统设计与实现
  • 如何使用IDM嗅探视频并下载?
  • java数据结构--LinkedList与链表 - 教程
  • 洛谷 B4409:[GESP202509 一级] 商店折扣 ← 模拟算法
  • 深入解析:自动化文件管理:分类、重命名和备份
  • nju实验三 加法器与ALU
  • 信息论(八):吉布斯不等式的证明
  • macos: 景观类动态的壁纸和屏保保存在哪里
  • pyppeteer: 得到当前运行中的浏览器