当前位置: 首页 > news >正文

手写数字识别(3种算法对比)

点击查看代码
import numpy as np
import matplotlib.pyplot as plt
import time
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import datasets, transforms# ======================== 1. 设备配置与参数设置 ========================
# ======================== 1. Device and hyper-parameters ========================
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
batch_size = 128
num_classes = 10   # model output width: classes 0-9 (reported as digits 1-10)
num_epochs = 5
lr = 0.001

# ======================== 2. Data pipeline ========================
# Tensor conversion followed by per-channel standardization
# (0.1307 / 0.3081 are the published MNIST mean and std).
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,)),
])

# MNIST with its raw 0-9 labels. Labels are NOT shifted here: training needs
# 0-9 indices (CrossEntropyLoss requirement); the 1-10 mapping happens only
# at evaluation/report time.
train_dataset = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
test_dataset = datasets.MNIST(root='./data', train=False, download=True, transform=transform)

train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

# ======================== 3. Model definitions ========================
# 3.1 Shared CNN feature extractor
class CNNFeatureExtractor(nn.Module):
    """Three-stage convolutional backbone shared by all classifier heads.

    Maps a (N, 1, 28, 28) MNIST batch to a flat (N, 1152) feature vector.
    """

    def __init__(self):
        super(CNNFeatureExtractor, self).__init__()
        self.features = nn.Sequential(
            nn.Conv2d(1, 32, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2, 2),   # -> 32 x 14 x 14
            nn.Conv2d(32, 64, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2, 2),   # -> 64 x 7 x 7
            nn.Conv2d(64, 128, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2, 2),   # -> 128 x 3 x 3
        )
        # Width of the flattened feature vector (128 * 3 * 3).
        self.flatten_dim = 128 * 3 * 3

    def forward(self, x):
        feats = self.features(x)
        return feats.view(-1, self.flatten_dim)

# 3.2 CNN + Softmax head
class CNNSoftmax(nn.Module):
    """CNN backbone with a linear head plus an explicit Softmax.

    forward() returns (probabilities, raw logits); the logits are what
    CrossEntropyLoss should see during training, the probabilities are
    used for prediction.
    """

    def __init__(self, num_classes):
        super(CNNSoftmax, self).__init__()
        self.feature_extractor = CNNFeatureExtractor()
        self.classifier = nn.Linear(self.feature_extractor.flatten_dim, num_classes)
        self.softmax = nn.Softmax(dim=1)

    def forward(self, x):
        logits = self.classifier(self.feature_extractor(x))
        return self.softmax(logits), logits

# 3.3 CNN + Sigmoid head
class CNNSigmoid(nn.Module):
    """CNN backbone with a linear head plus an element-wise Sigmoid.

    forward() returns (sigmoid activations, raw logits); training uses the
    logits with BCEWithLogitsLoss against one-hot targets.
    """

    def __init__(self, num_classes):
        super(CNNSigmoid, self).__init__()
        self.feature_extractor = CNNFeatureExtractor()
        self.classifier = nn.Linear(self.feature_extractor.flatten_dim, num_classes)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        logits = self.classifier(self.feature_extractor(x))
        return self.sigmoid(logits), logits

# 3.4 CNN + SVM head (multi-class hinge loss)
class CNNSVM(nn.Module):
    """CNN backbone with a plain linear head, trained with hinge loss.

    forward() returns raw scores only (no probability layer), as expected
    by MultiClassHingeLoss.
    """

    def __init__(self, num_classes):
        super(CNNSVM, self).__init__()
        self.feature_extractor = CNNFeatureExtractor()
        self.classifier = nn.Linear(self.feature_extractor.flatten_dim, num_classes)

    def forward(self, x):
        return self.classifier(self.feature_extractor(x))

# Multi-class SVM hinge loss definition
class MultiClassHingeLoss(nn.Module):
    """Multi-class hinge loss (Weston-Watkins style, summed over classes).

    For each sample, every wrong class j is penalized by
    max(0, margin - (logit_true - logit_j)); the penalties are summed over
    the wrong classes and averaged over the batch.
    """

    def __init__(self, margin=1.0):
        super(MultiClassHingeLoss, self).__init__()
        self.margin = margin

    def forward(self, logits, labels):
        """Compute the loss.

        logits: (N, C) raw class scores.
        labels: (N,) integer class indices in [0, C) — already 0-9,
                no shifting needed.
        """
        one_hot = torch.zeros_like(logits).scatter(1, labels.unsqueeze(1), 1)
        correct_logit = torch.sum(logits * one_hot, dim=1, keepdim=True)
        # clamp(min=0) replaces the original
        # torch.maximum(torch.tensor(0.0).to(device), ...): same math, but it
        # no longer depends on the module-level `device` global and avoids
        # allocating a fresh scalar tensor on every call.
        loss = torch.clamp(self.margin - (correct_logit - logits), min=0.0)
        # Mask out the true-class term, then average over the batch.
        loss = torch.sum(loss * (1 - one_hot)) / logits.size(0)
        return loss

# ======================== 4. Training and evaluation helpers ========================
# 4.1 Generic CNN training loop
def train_cnn_model(model, criterion, optimizer, train_loader, num_epochs, is_sigmoid=False):
    """Train `model` on `train_loader` for `num_epochs` epochs.

    Relies on the module-level `device` and (for the sigmoid head)
    `num_classes`. When `is_sigmoid` is True, integer labels are expanded
    to one-hot targets as required by BCEWithLogitsLoss; otherwise the raw
    0-9 indices are fed to the criterion directly.

    Returns (trained model, wall-clock training time in seconds).
    """
    model.to(device)
    model.train()
    t0 = time.time()
    for epoch in range(num_epochs):
        running_loss = 0.0
        for images, labels in train_loader:
            images = images.to(device)
            labels = labels.to(device)  # raw 0-9 class indices
            optimizer.zero_grad()
            if is_sigmoid:
                # BCEWithLogitsLoss needs a one-hot (N, num_classes) target.
                onehot = torch.zeros(len(labels), num_classes).to(device).scatter(1, labels.unsqueeze(1), 1)
                _, logits = model(images)
                loss = criterion(logits, onehot)
            else:
                out = model(images)
                # Softmax/Sigmoid heads return (probs, logits); SVM returns logits only.
                logits = out[1] if isinstance(model, (CNNSoftmax, CNNSigmoid)) else out
                loss = criterion(logits, labels)
            loss.backward()
            optimizer.step()
            # Accumulate the sum of per-sample losses for the epoch average.
            running_loss += loss.item() * images.size(0)
        avg_loss = running_loss / len(train_loader.dataset)
        print(f'Epoch [{epoch + 1}/{num_epochs}], Loss: {avg_loss:.4f}')
    return model, time.time() - t0

# 4.2 Model evaluation (accuracy, precision, recall, F1)
def evaluate_model(model, test_loader, is_cnn_svm=False):
    """Evaluate a trained model on `test_loader`.

    Predictions and targets are shifted from the raw 0-9 labels to 1-10
    purely for the "digits 1-10" presentation; the metrics are computed
    after shifting back, so index ranges stay valid.

    Returns (accuracy, macro precision, macro recall, macro F1).
    Raises ValueError for an unrecognized model type (the original code
    left `preds` unbound in that case, crashing with a NameError later).
    """
    model.eval()
    all_preds = []
    all_targets = []
    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)  # target is 0-9
            if isinstance(model, (CNNSoftmax, CNNSigmoid)):
                # Both heads return (activations, logits); the two original
                # branches were identical, so they are merged here.
                output, _ = model(data)
                preds = torch.argmax(output, dim=1)  # 0-9
            elif is_cnn_svm:
                logits = model(data)
                preds = torch.argmax(logits, dim=1)  # 0-9
            else:
                raise ValueError(f"Unsupported model type: {type(model).__name__}")
            # Map 0-9 -> 1-10 only for reporting consistency.
            all_preds.extend((preds + 1).cpu().numpy())
            all_targets.extend((target + 1).cpu().numpy())
    all_preds = np.array(all_preds)
    all_targets = np.array(all_targets)

    # Macro-averaged precision/recall/F1 by hand (no sklearn dependency).
    def calculate_metrics(y_true, y_pred, num_classes):
        # Shift back to 0-9 so class indices match array positions.
        y_true = y_true - 1
        y_pred = y_pred - 1
        TP = np.zeros(num_classes)
        FP = np.zeros(num_classes)
        FN = np.zeros(num_classes)
        for cls in range(num_classes):
            TP[cls] = np.sum((y_true == cls) & (y_pred == cls))
            FP[cls] = np.sum((y_true != cls) & (y_pred == cls))
            FN[cls] = np.sum((y_true == cls) & (y_pred != cls))
        # +1e-8 guards against division by zero for absent classes.
        precision_per_cls = TP / (TP + FP + 1e-8)
        recall_per_cls = TP / (TP + FN + 1e-8)
        f1_per_cls = 2 * (precision_per_cls * recall_per_cls) / (precision_per_cls + recall_per_cls + 1e-8)
        precision = np.mean(precision_per_cls)
        recall = np.mean(recall_per_cls)
        f1 = np.mean(f1_per_cls)
        accuracy = np.sum(TP) / len(y_true)
        return accuracy, precision, recall, f1

    accuracy, precision, recall, f1 = calculate_metrics(all_targets, all_preds, num_classes)
    return accuracy, precision, recall, f1

# ======================== 5. Model training ========================
print("=" * 30 + " Training CNN+Softmax " + "=" * 30)
# Softmax head: CrossEntropyLoss consumes the raw logits directly.
model_softmax = CNNSoftmax(num_classes)
criterion_softmax = nn.CrossEntropyLoss()
optimizer_softmax = optim.Adam(model_softmax.parameters(), lr=lr)
model_softmax, time_softmax = train_cnn_model(
    model_softmax, criterion_softmax, optimizer_softmax, train_loader, num_epochs)

print("\n" + "=" * 30 + " Training CNN+Sigmoid " + "=" * 30)
# Sigmoid head: BCEWithLogitsLoss against one-hot targets (is_sigmoid=True).
model_sigmoid = CNNSigmoid(num_classes)
criterion_sigmoid = nn.BCEWithLogitsLoss()
optimizer_sigmoid = optim.Adam(model_sigmoid.parameters(), lr=lr)
model_sigmoid, time_sigmoid = train_cnn_model(
    model_sigmoid, criterion_sigmoid, optimizer_sigmoid, train_loader, num_epochs, is_sigmoid=True)

print("\n" + "=" * 30 + " Training CNN+SVM " + "=" * 30)
# SVM head: multi-class hinge loss on the raw linear scores.
model_cnn_svm = CNNSVM(num_classes)
criterion_cnn_svm = MultiClassHingeLoss()
optimizer_cnn_svm = optim.Adam(model_cnn_svm.parameters(), lr=lr)
model_cnn_svm, time_cnn_svm = train_cnn_model(
    model_cnn_svm, criterion_cnn_svm, optimizer_cnn_svm, train_loader, num_epochs)

# ======================== 6. Model evaluation ========================
print("\n" + "=" * 30 + " Model Evaluation " + "=" * 30)
# Evaluate the three trained CNN models.
acc_softmax, pre_softmax, rec_softmax, f1_softmax = evaluate_model(model_softmax, test_loader)
acc_sigmoid, pre_sigmoid, rec_sigmoid, f1_sigmoid = evaluate_model(model_sigmoid, test_loader)
acc_cnn_svm, pre_cnn_svm, rec_cnn_svm, f1_cnn_svm = evaluate_model(model_cnn_svm, test_loader, is_cnn_svm=True)

# Collect results per model for the table and the plots.
models = ['CNN+Softmax', 'CNN+Sigmoid', 'CNN+SVM']
accuracy_list = [acc_softmax, acc_sigmoid, acc_cnn_svm]
precision_list = [pre_softmax, pre_sigmoid, pre_cnn_svm]
recall_list = [rec_softmax, rec_sigmoid, rec_cnn_svm]
f1_list = [f1_softmax, f1_sigmoid, f1_cnn_svm]
time_list = [time_softmax, time_sigmoid, time_cnn_svm]

# Print the comparison table.
print("\n" + "=" * 90)
print(f"{'Model':<18} {'Accuracy':<12} {'Precision':<12} {'Recall':<12} {'F1-Score':<12} {'Train Time(s)':<12}")
print("=" * 90)
for name, acc, pre, rec, f1v, sec in zip(models, accuracy_list, precision_list, recall_list, f1_list, time_list):
    print(f"{name:<18} {acc:<12.4f} {pre:<12.4f} {rec:<12.4f} {f1v:<12.4f} {sec:<12.2f}")
print("=" * 90)

# ======================== 7. Result visualization ========================
def plot_metrics_comparison(models, accuracy, precision, recall, f1, time_list):
    """Draw a two-panel comparison figure and save it as a PNG.

    Left panel: grouped bars of accuracy / precision / recall / F1 per
    model. Right panel: training time per model with value labels.
    The figure is written to mnist_digit_recognition_comparison.png and
    then shown interactively.
    """
    fig, ax = plt.subplots(1, 2, figsize=(18, 7))
    x = np.arange(len(models))
    width = 0.2

    # Panel 1: classification metrics, one bar group per model.
    metric_series = [
        (-1.5, accuracy, 'Accuracy', '#1f77b4'),
        (-0.5, precision, 'Precision', '#ff7f0e'),
        (0.5, recall, 'Recall', '#2ca02c'),
        (1.5, f1, 'F1-Score', '#d62728'),
    ]
    for offset, values, label, color in metric_series:
        ax[0].bar(x + offset * width, values, width, label=label, color=color, edgecolor='black')
    ax[0].set_xlabel('Models', fontsize=12, fontweight='bold')
    ax[0].set_ylabel('Score', fontsize=12, fontweight='bold')
    ax[0].set_title('Classification Metrics Comparison (Digits 1-10)', fontsize=14, fontweight='bold')
    ax[0].set_xticks(x)
    ax[0].set_xticklabels(models, rotation=15, fontsize=10)
    ax[0].legend(fontsize=10)
    ax[0].grid(axis='y', linestyle='--', alpha=0.7)

    # Panel 2: training time, one bar per model.
    bars = ax[1].bar(models, time_list, color=['#1f77b4', '#ff7f0e', '#2ca02c'], edgecolor='black')
    ax[1].set_xlabel('Models', fontsize=12, fontweight='bold')
    ax[1].set_ylabel('Training Time (s)', fontsize=12, fontweight='bold')
    ax[1].set_title('Training Time Comparison', fontsize=14, fontweight='bold')
    ax[1].tick_params(axis='x', rotation=15, labelsize=10)
    ax[1].grid(axis='y', linestyle='--', alpha=0.7)
    # Annotate each time bar with its numeric value.
    for bar in bars:
        height = bar.get_height()
        ax[1].text(bar.get_x() + bar.get_width() / 2., height + 0.5, f'{height:.2f}',
                   ha='center', va='bottom', fontsize=10)

    plt.tight_layout()
    plt.savefig('mnist_digit_recognition_comparison.png', dpi=300, bbox_inches='tight')
    plt.show()

# Render the comparison figure
# Render the comparison charts; also saves mnist_digit_recognition_comparison.png.
plot_metrics_comparison(models, accuracy_list, precision_list, recall_list, f1_list, time_list)
# Student-ID tag printed at the end of the run (assignment requirement, presumably).
print("2024310143014")
http://www.jsqmd.com/news/140003/

相关文章:

  • CRMEB WxJava,微信生态开发外挂来袭!
  • Perfecxion.ai发布:生产级安全编程数据集防范AI代码漏洞
  • 夸克网盘下载不限速_在线解析站
  • 别再把 K8s 当大号 Docker 了:我用 Kubernetes 跑数据任务踩过的那些坑
  • 前端架构演进之路——从网页到应用
  • 利用SAT求解优化量子电路映射
  • P3241 [HNOI2015] 开店
  • Shell 脚本
  • 不懂技术怕什么?陀螺匠低代码平台,拖拽之间搞定复杂数据关联
  • 夸克网盘不限速_在线公益解析站
  • 同步通信协议(I2C/SPI)驱动OLED/EEPROM/传感器实战
  • Chat2PDF 的最神级用法,其实是一键把 AI 对话变成干净高保真的 PDF - 实践
  • CRMEB 标准版系统(PHP)- 前端多语言开发指南
  • 午餐肉灌装机市场风向标:优质午餐肉生产厂家大公开,行业内评价好的灌装机公司博锐层层把关品质优 - 品牌推荐师
  • 高速斩拌机品牌权威测评,谁是行业真王者?搅拌机源头厂家精选实力品牌榜单发布 - 品牌推荐师
  • 当“同时发生”成为攻击武器
  • 学习《Transformer原理》读书报告
  • OriginPro 2024 保姆级下载安装教程图文详细步骤(附激活 + 中文切换,亲测有效)
  • 跨数据源搜索的优化过程
  • 学长亲荐8个AI论文工具,本科生轻松搞定论文格式!
  • 三星自研GPU剑指AI芯片霸权,2027年能否撼动英伟达?
  • 高速斩拌机厂家综合实力排行,国内有实力的搅拌机品牌怎么选择博锐满足多元需求 - 品牌推荐师
  • 学生管理系统!
  • 当CAIE证书遇上职场现实:考后的路该怎么走?
  • 天气查询前端
  • 天气查询前端
  • DeepAnaX「GEO优化分析统计系统」重磅升级:让每一份数据都通往清晰决策
  • MySQL 日志体系总览
  • 在postgresql和duckdb的多表连接中其中一个表引用另一个表的数据
  • 2025最新!研究生必备8个AI论文工具:开题报告与文献综述全测评