当前位置: 首页 > news >正文

文字识别系统代码

点击查看代码
import torch
from torch import optim
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable
from torchvision import datasets, transforms
from torch.utils.data import DataLoader, Dataset
import matplotlib.pyplot as plt
import numpy as np
from PIL import Image
from sklearn.metrics import classification_report
import matplotlib.pyplot as plt
import os
import timedevice = torch.device('cuda' if torch.cuda.is_available() else 'cpu')#数据加载与预处理
transforms = transforms.Compose([transforms.ToTensor(),transforms.Normalize((0.1307,), (0.3081,))
])
def load_data():train_dataset = datasets.MNIST(root='./data', train=True, transform=transforms, download=True)test_dataset = datasets.MNIST(root='./data', train=False, transform=transforms, download=True)batch_size = 128train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)return train_loader, test_loader, test_dataset#定义CNN模型
class MyCNN(nn.Module):def __init__(self):super(MyCNN, self).__init__()self.conv1 = nn.Conv2d(1, 16, kernel_size=3, padding=1)self.pool1 = nn.MaxPool2d(2, 2)self.conv2 = nn.Conv2d(16, 16, kernel_size=3, padding=1)self.pool2 = nn.MaxPool2d(2, 2)self.output = nn.Linear(16 * 7 * 7, 10)self.dropout = nn.Dropout(0.5)def forward(self, x):x = self.pool1(F.relu(self.conv1(x)))x = self.pool2(F.relu(self.conv2(x)))x = x.view(x.size(0), -1)x = self.dropout(x)output = self.output(x)return output, x#训练模型
model = MyCNN().to(device)
optimizer = optim.Adam(model.parameters(), lr=0.001)
critirion = nn.CrossEntropyLoss()def train(epochs, train_loader, optimizer, critirion):model.train()train_loss = []train_acc = []for epoch in range(epochs):start_time = time.time()running_loss = 0.0total = 0current = 0for i, (images, labels) in enumerate(train_loader):images, labels = images.to(device), labels.to(device)optimizer.zero_grad()output = model(images)[0]loss = critirion(output, labels)loss.backward()optimizer.step()running_loss += loss.item()_, predicted = torch.max(output.data, 1)total += labels.size(0)current += (predicted == labels).sum().item()epoch_loss = running_loss / len(train_loader)epoch_acc = 100 * current / totaltrain_loss.append(epoch_loss)train_acc.append(epoch_acc)end_time = time.time()print(f"Epoch [{epoch}/{epochs}], Loss: {epoch_loss:.4f}, Acc: {epoch_acc:.2f}%, Time: {end_time-start_time:.2f}s")return train_loss, train_acc#模型评估
def test(model, test_loader):model.eval()all_preds = []all_labels = []with torch.no_grad():for images, labels in test_loader:images, labels = images.to(device), labels.to(device)output = model(images)[0]_, predicted = torch.max(output.data, 1)all_preds.extend(predicted.cpu().numpy())all_labels.extend(labels.cpu().numpy())all_preds = np.array(all_preds)all_labels = np.array(all_labels)accuracy = (all_preds == all_labels).mean() * 100print(f"测试准确率: {accuracy:.2f}%")print("分类效果评估:")target_names = [str(i) for i in range(10)]report = classification_report(all_labels, all_preds, target_names=target_names)print(report)if __name__ == '__main__':print("(24信计2班 王晶莹 2024310143126)")print(f"device: {device}")train_loader, test_loader, test_dataset = load_data()epochs = 20train_loss, train_acc = train(epochs, train_loader, optimizer, critirion)test(model, test_loader)#绘制结果plt.figure(figsize=(10, 5))plt.subplot(1, 2, 1)plt.plot(range(1, epochs+1), train_loss)plt.title("Training Loss")plt.xlabel("Epoch")plt.ylabel("Loss")plt.subplot(1, 2, 2)plt.plot(range(1, epochs+1), train_acc)plt.title("Training Accuracy")plt.xlabel("Epoch")plt.ylabel("Accuracy (%)")plt.tight_layout()plt.show()

屏幕截图 2025-11-12 201725

http://www.jsqmd.com/news/38782/

相关文章:

  • B4093 [CSP-X2021 山东] 发送快递
  • 从零上手 Rokid JSAR:打造专属 AR 桌面交互式 3D魔方,开启空间创建之旅
  • 微软2025年11月补丁星期二修复1个零日漏洞和63个安全漏洞
  • CF468C Hack it!
  • 深入解析:FT62FC3X 8位MCU单片机选型表,详细解析FT62FC31A/32A/33A/35A/3FA
  • 语法记录
  • Can Large Language Models Detect Rumors on Social Media?
  • 压迫
  • P13573 [CCPC 2024 重庆站] Pico Park
  • 手工安装gcc-13.3.0
  • 深入解析:Cookie、Session、JWT、SSO,网站与 APP 登录持久化与缓存
  • gowin ide linux安装教程
  • AT_arc111_f [ARC111F] Do you like query problems?
  • Win7 隐藏文件夹盘符
  • pythontip 按条件过滤字典
  • DotNetGuide 突破了 9.5K + Star,一份全面的C#/.NET/.NET Core学习、工作、面试指南知识库!
  • 如何把华为mate 60手机备份到移动硬盘
  • Vue实例学习
  • 2.2 语言处理程序基础
  • Ai元人文:价值的“迷思”与“归真”——从家庭之爱到文明共生
  • MATLAB 数据可视化教程:从基础到进阶
  • 在ec2上部署qwen3-VL-2B模型
  • 37
  • Daily Scrum 2025.11.12
  • 完整教程:mit6s081 lab8 locks
  • 软件工程学习日志2025.11.12
  • [集训队互测 2025] 火花 做题记录
  • 返璞归真,因为自指,所以自洽
  • NLTK库用法示例:Python自然语言处理入门到实践 - 实践
  • 2025大桶/桶装/纯净/瓶装/灌装水设备推荐榜:青州市路得自动化五星领跑 四大品牌赋能水企高效生产