当前位置: 首页 > news > 正文

汉字识别

点击查看代码
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import numpy as np
from PIL import Image, ImageDraw, ImageFont
import matplotlib.pyplot as plt
import random
import os
from torchvision import transforms
import matplotlib# 确保中文显示正常(只保留Windows系统默认存在的字体)
# Make matplotlib able to render Chinese titles and keep minus signs readable.
# SimHei / Microsoft YaHei ship with Windows; if neither is installed,
# matplotlib falls back to another font (with a "findfont" warning).
# Note: plt.rcParams and matplotlib.rcParams are the same object.
matplotlib.rcParams.update(
    {
        "font.family": ["SimHei", "Microsoft YaHei"],
        # Use an ASCII hyphen for negative numbers instead of U+2212,
        # which many CJK fonts do not contain.
        "axes.unicode_minus": False,
    }
)
# 1. Generate the Chinese character dataset
class ChineseCharacterDataset(Dataset):def __init__(self, num_samples=10000, img_size=(64, 64), transform=None):"""生成模拟手写中文字符数据集"""self.num_samples = num_samplesself.img_size = img_sizeself.transform = transform# 选择常用中文字符(可扩展)self.chars = "一二三四五六七八九十甲乙丙丁戊己庚辛壬癸金木水火土天地日月"self.classes = list(self.chars)self.num_classes = len(self.classes)self.char_to_idx = {char: i for i, char in enumerate(self.classes)}# 生成数据集self.images, self.labels = self._generate_data()def _generate_data(self):"""生成模拟手写中文字符图像"""images = []labels = []# 尝试加载中文字体,若失败则使用默认字体try:# 直接使用系统字体目录中的黑体font = ImageFont.truetype("C:/Windows/Fonts/simhei.ttf", 40)  # 明确指定系统黑体路径except:try:# 尝试系统宋体font = ImageFont.truetype("C:/Windows/Fonts/simsun.ttc", 40)except:font = ImageFont.load_default()print("警告:未找到中文字体,使用默认字体可能导致显示异常")for _ in range(self.num_samples):# 随机选择一个字符char = random.choice(self.classes)label = self.char_to_idx[char]# 创建空白图像img = Image.new('L', self.img_size, color=255)  # 白色背景draw = ImageDraw.Draw(img)# 计算文字位置(居中)- 兼容新版本PILbbox = draw.textbbox((0, 0), char, font=font)char_width = bbox[2] - bbox[0]  # 宽度char_height = bbox[3] - bbox[1]  # 高度x = (self.img_size[0] - char_width) // 2y = (self.img_size[1] - char_height) // 2# 添加随机扰动,模拟手写效果x += random.randint(-5, 5)y += random.randint(-5, 5)# 绘制文字(黑色)draw.text((x, y), char, font=font, fill=0)# 添加随机噪声if random.random() > 0.5:img = self._add_noise(img)# 添加随机旋转if random.random() > 0.5:angle = random.randint(-10, 10)img = img.rotate(angle, expand=False, fillcolor=255)images.append(img)labels.append(label)return images, labelsdef _add_noise(self, img):"""为图像添加随机噪声"""img_np = np.array(img)noise = np.random.randint(-20, 20, size=img_np.shape)img_np = np.clip(img_np + noise, 0, 255).astype(np.uint8)return Image.fromarray(img_np)def __len__(self):return self.num_samplesdef __getitem__(self, idx):img = self.images[idx]label = self.labels[idx]if self.transform:img = self.transform(img)return img, label# 2. 数据预处理与加载
# Preprocessing: ToTensor converts a PIL image to a CxHxW float tensor in
# [0, 1]; Normalize((0.5,), (0.5,)) then maps it to [-1, 1] via (x - 0.5) / 0.5.
# The visualization step later inverts this with x * 0.5 + 0.5, so both
# transforms must stay in the pipeline.
transform = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize((0.5,), (0.5,)),
    ]
)
# Create the datasets
# Two independently generated synthetic splits (8000 train / 2000 test),
# both 64x64 and sharing the same preprocessing transform.
train_dataset, test_dataset = (
    ChineseCharacterDataset(num_samples=count, img_size=(64, 64), transform=transform)
    for count in (8000, 2000)
)
# Create the data loaders: shuffle only the training split.
train_loader = DataLoader(train_dataset, shuffle=True, batch_size=64)
test_loader = DataLoader(test_dataset, shuffle=False, batch_size=64)
# 3. Define a CNN model for Chinese character recognition
class ChineseCNN(nn.Module):
    """Three-stage conv/pool CNN classifier for single-channel (grayscale) images.

    Args:
        num_classes: number of output logits.
        input_size: side length of the square input image (default 64, matching
            the original 128 * 8 * 8 flattened size).  Three 2x2 max-pools
            reduce the side to ``input_size // 8``.

    Raises:
        ValueError: if ``input_size`` is smaller than 8 (nothing left after pooling).
    """

    def __init__(self, num_classes, input_size=64):
        super().__init__()
        side = input_size // 8
        if side < 1:
            raise ValueError(f"input_size must be >= 8, got {input_size}")

        # Convolutions (3x3, padding 1 keeps the spatial size); input has 1 channel.
        self.conv1 = nn.Conv2d(1, 32, kernel_size=3, stride=1, padding=1)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1)
        self.conv3 = nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1)
        # 2x2 max-pooling halves height and width after each conv.
        self.pool = nn.MaxPool2d(2, 2)

        # Fully connected head; 64x64 inputs give 8x8 maps after three pools.
        self.flat_features = 128 * side * side
        self.fc1 = nn.Linear(self.flat_features, 512)
        self.fc2 = nn.Linear(512, 256)
        self.fc3 = nn.Linear(256, num_classes)
        # Dropout to reduce overfitting (active only in train mode).
        self.dropout = nn.Dropout(0.5)

    def forward(self, x):
        """Map an (N, 1, input_size, input_size) batch to (N, num_classes) logits."""
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = self.pool(F.relu(self.conv3(x)))
        # Flatten per sample, keeping the batch dimension.  view(-1, ...) could
        # silently fold a mis-sized input into a different batch size; this
        # makes fc1 raise a shape error instead.
        x = torch.flatten(x, 1)
        x = self.dropout(F.relu(self.fc1(x)))
        x = self.dropout(F.relu(self.fc2(x)))
        return self.fc3(x)
# Create the model instance
# Size the classifier head to the dataset's character set.
num_classes = train_dataset.num_classes
model = ChineseCNN(num_classes=num_classes)
# 4. Define the loss function and optimizer
criterion = nn.CrossEntropyLoss()  # expects raw logits and integer class labels
optimizer = optim.Adam(params=model.parameters(), lr=1e-3)
# 5. Train the model
# Train for a fixed number of epochs, reporting mean batch loss and
# training accuracy after each epoch.
num_epochs = 15
for epoch in range(num_epochs):
    # Set train mode every epoch so dropout is active even if evaluation
    # (which switches to eval mode) is ever interleaved with training.
    model.train()
    total_loss = 0.0
    correct = 0
    total = 0
    for images, labels in train_loader:
        # Forward pass
        outputs = model(images)
        loss = criterion(outputs, labels)

        # Backward pass and parameter update
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        total_loss += loss.item()
        # Training accuracy: detach() instead of the discouraged .data access.
        predicted = outputs.detach().argmax(dim=1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

    train_accuracy = 100 * correct / total
    avg_loss = total_loss / len(train_loader)
    print(f"Epoch [{epoch + 1}/{num_epochs}], 损失: {avg_loss:.4f}, 训练准确率: {train_accuracy:.2f}%")
# 6. Evaluate the model
# Measure accuracy on the held-out split: eval mode disables dropout, and
# no_grad skips building the autograd graph since nothing is backpropagated.
model.eval()
correct = 0
total = 0
with torch.no_grad():
    for images, labels in test_loader:
        outputs = model(images)
        predicted = outputs.argmax(dim=1)  # same first-max index as torch.max(..., 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()
test_accuracy = 100 * correct / total
print(f"测试准确率: {test_accuracy:.2f}%")
# 7. Visualize test predictions
# Predict one test batch and plot up to 6 samples with true/predicted labels.
model.eval()  # ensure dropout is off regardless of what ran before
dataiter = iter(test_loader)
images, labels = next(dataiter)
with torch.no_grad():  # inference only; no autograd graph needed
    outputs = model(images)
_, predictions = torch.max(outputs, 1)
# Class names: index -> character
classes = train_dataset.classes

# Show at most 6 samples; a smaller batch must not raise IndexError.
num_show = 6
fig, axes = plt.subplots(1, num_show, figsize=(15, 3))
for ax in axes:
    ax.axis("off")
for i, ax in enumerate(axes[: min(num_show, len(images))]):
    # Undo Normalize((0.5,), (0.5,)): map [-1, 1] back to [0, 1] for display.
    img = images[i].numpy().squeeze()
    img = (img * 0.5) + 0.5
    ax.imshow(img, cmap="gray")
    true_char = classes[labels[i].item()]
    pred_char = classes[predictions[i].item()]
    ax.set_title(f"真实: {true_char}\n预测: {pred_char}")
plt.tight_layout()
plt.show()
http://www.jsqmd.com/news/32661/

相关文章:

  • AGC与AVC是什么
  • 链表1
  • 競プロ典型 90 問-难题
  • c++函数调用的大致工作过程
  • Slack端到端测试管道优化:构建时间减半的技术实践
  • 结构体与联合体的区别
  • Day14综合案例二--
  • 解决colcon编译卡死
  • 新学期每日总结(第20天)
  • 铁杆粉丝占比20251105
  • Mybatis 都有哪些 Executor 执行器?它们之间的区别是什么? - Higurashi
  • 100小时学会SAP—问题10:ME51N提示物料XX的强制账户设置(输入账户设置类别)
  • P8990 [北大集训 2021] 小明的树 题解
  • 100小时学会SAP—问题11:MIGO收货时报错不可能为条目BSX CN01确立账户
  • 【动态维护前 x 大元素】LeetCode 3321. 计算子数组的 x-sum II
  • 100小时学会SAP—问题8:财务凭证行项目BSEG及对应的六张表
  • 100小时学会SAP—问题9:MD03提示日期在有效工厂日历之后(请改正)
  • 100小时学会SAP—问题6:创建采购收货时出现WE在年2025中编号不存在
  • 100小时学会SAP—问题7:FB70提示过账码没有定义
  • 树剖
  • 100小时学会SAP—问题5:SAP导航菜单字体突然变小
  • 如何降低大模型幻觉
  • 11月5日---学习总结
  • 11-2
  • 100小时学会SAP—问题4:ME21N创建采购订单报错
  • 11-1
  • 多智能体架构中 如何解决总控agent路由错误的问题
  • 回归(监督学习)
  • 100小时学会SAP—问题3:成本控制控制凭证的编号范围
  • 10-20