当前位置: 首页 > news >正文

手写识别

import os
import random
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image, ImageDraw, ImageFont
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms

# Work around the OpenMP duplicate-runtime conflict (common when torch and
# matplotlib/numpy each bundle their own libiomp, especially on Windows).
os.environ["KMP_DUPLICATE_LIB_OK"] = "TRUE"
# CPU-only training keeps the script runnable everywhere.
device = torch.device('cpu')

# -------------------------- Configuration --------------------------
CHARS = ['一', '二', '三', '十', '人', '口', '手', '日', '月', '水']  # the 10 target classes
TRAIN_NUM = 200          # training images generated per character
TEST_NUM = 50            # test images generated per character
IMG_SIZE = 64            # square image side length in pixels
DATA_SAVE_DIR = 'hanzi_data'  # root folder for the generated dataset
BATCH_SIZE = 32
EPOCHS = 30
LEARNING_RATE = 0.005

# ------------- Key improvement: render hanzi with PIL's default font (no extra installs) -------------
class HanziDatasetGenerator:
    """Generates a small synthetic dataset of hanzi images on disk.

    Does not require a Chinese system font: it falls back to PIL's bundled
    default font with per-character manual offsets so the script always runs.
    """

    def __init__(self):
        # PIL's default font is always available, so rendering cannot fail.
        self.font = ImageFont.load_default()
        print("提示:使用默认字体生成汉字(可能显示较简单,但能保证运行)")

    def _generate_single_img(self, char):
        """Render one simple but distinguishable grayscale image for *char*."""
        img = Image.new('L', (IMG_SIZE, IMG_SIZE), color=255)  # white background
        draw = ImageDraw.Draw(img)

        # Offsets tuned per character for the default font so every glyph
        # stays fully inside the canvas and classes keep distinct features.
        char_offsets = {'一': (5, 25), '二': (5, 15), '三': (5, 10),
                        '十': (20, 15), '人': (10, 20), '口': (15, 15),
                        '手': (5, 10), '日': (15, 15), '月': (10, 15),
                        '水': (5, 10)}
        x, y = char_offsets[char]
        # Fixed, fairly large font size keeps strokes legible.
        font_size = 40
        try:
            # Prefer a real CJK system font (SimSun) when present;
            # ImageFont.truetype raises OSError when the font file is missing.
            font = ImageFont.truetype('simsun.ttc', size=font_size)
            draw.text((x, y), char, font=font, fill=0, stroke_width=2)
        except OSError:
            # Fall back to the default font; draw twice with a 1px shift to
            # thicken strokes that would otherwise be too thin to classify.
            draw.text((x, y), char, font=self.font, fill=0, stroke_width=3)
            draw.text((x + 1, y), char, font=self.font, fill=0, stroke_width=2)

        # Small random rotation adds intra-class variety.
        rotation = random.randint(-10, 10)
        img = img.rotate(rotation, expand=False, fillcolor=255)
        return img

    def generate_dataset(self):
        """(Re)create DATA_SAVE_DIR with train/test PNG samples per character."""
        # Wipe any previous dataset so the sample counts stay exact.
        if os.path.exists(DATA_SAVE_DIR):
            for root, dirs, files in os.walk(DATA_SAVE_DIR, topdown=False):
                for f in files:
                    os.remove(os.path.join(root, f))
                for d in dirs:
                    os.rmdir(os.path.join(root, d))
            os.rmdir(DATA_SAVE_DIR)

        # Create the directory tree: <root>/<split>/<char>/
        for split in ['train', 'test']:
            for char in CHARS:
                os.makedirs(os.path.join(DATA_SAVE_DIR, split, char), exist_ok=True)

        # Generate the samples.
        print("生成数据集...")
        for char in CHARS:
            for i in range(TRAIN_NUM):
                img = self._generate_single_img(char)
                img.save(os.path.join(DATA_SAVE_DIR, 'train', char, f'{i}.png'))
            for i in range(TEST_NUM):
                img = self._generate_single_img(char)
                img.save(os.path.join(DATA_SAVE_DIR, 'test', char, f'{i}.png'))
        print(f"数据集生成完成:{os.path.abspath(DATA_SAVE_DIR)}")

# -------------------------- Dataset loading --------------------------
class HanziDataset(Dataset):
    """PyTorch dataset reading the generated hanzi PNGs from disk.

    Images are loaded lazily in __getitem__ as grayscale tensors in [0, 1];
    labels are the character's index within CHARS.
    """

    def __init__(self, split='train'):
        self.split = split
        self.data_dir = os.path.join(DATA_SAVE_DIR, split)
        self.char_list = CHARS
        self.char2idx = {c: i for i, c in enumerate(self.char_list)}
        self.images, self.labels = self._load_data()
        self.transform = transforms.ToTensor()

    def _load_data(self):
        """Collect (image path, class index) pairs for every character."""
        images = []
        labels = []
        for char in self.char_list:
            char_dir = os.path.join(self.data_dir, char)
            for img_name in os.listdir(char_dir):
                images.append(os.path.join(char_dir, img_name))
                labels.append(self.char2idx[char])
        return images, labels

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        # Load lazily as grayscale; ToTensor scales pixels to [0, 1].
        img = Image.open(self.images[idx]).convert('L')
        return self.transform(img), self.labels[idx]
# -------------------------- Model --------------------------
class FeatureCNN(nn.Module):
    """Tiny two-conv CNN for 64x64 grayscale hanzi classification.

    Two conv+ReLU+maxpool stages halve the spatial size twice (64 -> 16),
    then a single linear layer maps the flattened features to class logits.
    """

    def __init__(self, num_classes=10):
        super().__init__()
        self.features = nn.Sequential(
            nn.Conv2d(1, 8, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2, 2),   # 64 -> 32

            nn.Conv2d(8, 16, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2, 2),   # 32 -> 16
        )
        self.classifier = nn.Linear(16 * 16 * 16, num_classes)

    def forward(self, x):
        """Return raw class logits for a (N, 1, 64, 64) input batch."""
        x = self.features(x)
        x = x.view(-1, 16 * 16 * 16)  # flatten the 16x16x16 feature maps
        x = self.classifier(x)
        return x
# -------------------------- Training & recognition --------------------------
def main():
    """Generate the dataset, train the CNN, then run an interactive recognizer."""
    # Generate the dataset (works even without a Chinese system font).
    generator = HanziDatasetGenerator()
    generator.generate_dataset()

    # Data loaders.
    train_dataset = HanziDataset('train')
    test_dataset = HanziDataset('test')
    train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=0)
    test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=0)

    # Model, loss and optimizer.
    model = FeatureCNN().to(device)
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)

    # Training loop with per-epoch evaluation; best checkpoint is kept.
    print("\n开始训练...")
    best_acc = 0.0
    for epoch in range(EPOCHS):
        model.train()
        total_loss = 0.0
        for imgs, labels in train_loader:
            imgs, labels = imgs.to(device), labels.to(device)
            optimizer.zero_grad()
            outputs = model(imgs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
            # Weight the batch loss by its size for an exact epoch average.
            total_loss += loss.item() * imgs.size(0)
        avg_loss = total_loss / len(train_dataset)

        # Evaluate on the held-out split.
        model.eval()
        correct = 0
        total = 0
        with torch.no_grad():
            for imgs, labels in test_loader:
                imgs, labels = imgs.to(device), labels.to(device)
                outputs = model(imgs)
                _, preds = torch.max(outputs, 1)
                total += labels.size(0)
                correct += (preds == labels).sum().item()
        acc = 100 * correct / total

        print(f"轮次{epoch+1:2d} | 损失:{avg_loss:.4f} | 准确率:{acc:.2f}%")
        if acc > best_acc:
            best_acc = acc
            torch.save(model.state_dict(), 'best_model.pth')
        # Early stop once the target accuracy is reached.
        if acc >= 85:
            print(f"达标!准确率:{acc:.2f}%")
            break

    # Interactive recognition using the best checkpoint.
    model.load_state_dict(torch.load('best_model.pth'))
    print(f"\n最优准确率:{best_acc:.2f}%")

    while True:
        path = input("\n输入图片路径(q退出):")
        if path.lower() == 'q':
            break
        if not os.path.exists(path):
            print("路径错误")
            continue

        try:
            img = Image.open(path).convert('L').resize((64, 64))
            img_tensor = transforms.ToTensor()(img).unsqueeze(0).to(device)
            with torch.no_grad():
                output = model(img_tensor)
            pred_char = CHARS[torch.argmax(output).item()]
            confidence = torch.softmax(output, dim=1).max().item() * 100
            print(f"识别结果:{pred_char} | 可信度:{confidence:.2f}%")
        except Exception as e:
            # Best-effort REPL: report the failure and keep prompting.
            print(f"错误:{e}")


if __name__ == '__main__':
    main()
http://www.jsqmd.com/news/37931/

相关文章:

  • 01人月神话读后感--软件中的“焦油坑”
  • 线程池FAQ
  • Python Threading new thread
  • 从同步耦合到异步解耦:消息中间件如何重塑系统间的通信范式?
  • 深入理解OpenWrt生态:LuCI、UCI、ubus与rpcd的协同工作机制 - 实践
  • 251111重点
  • 第22天(简单题中等题 二分查找)
  • In the name of capitalists
  • 2025.11.11总结
  • K8S百万资源预list加载数据方案
  • 102302105汪晓红数据采集作业2
  • 【数据结构】:链表的核心实现与运行解析
  • Meta AI 推出全语种语音识别系统,支持 1600+语言;谢赛宁、李飞飞、LeCun 联手发布「空间超感知」AI 框架丨日报
  • Python Socket网络编程
  • 研发度量DORA+SPACE+OST 影响模型
  • 详细介绍:HUD-汽车图标内容
  • 比特币的简单技术原理
  • 后端八股之mysql - 指南
  • 2025年包装机厂家推荐排行榜,全自动包装机,全自动包装机生产线,非标定制生产线,非标定制机器公司精选指南
  • nginx拦截ip
  • 【CI130x 离在线】FIFO的学习及实例
  • 2025年包装机厂家权威推荐榜:全自动包装机、半自动包装机,高效智能包装解决方案精选
  • CF1187F
  • 刷题日记—数组—数组偏移
  • 【数据结构】:C 语言常见排序算法的实现与特性解析 - 指南
  • rdp远程桌面协议进行远程桌面控制
  • 第五届 RTE 年度 Demo Day 三强公布!看到对话式 AI 的 N 种未来
  • 活用数组题目参考
  • static、static静态代码块、Math库、final
  • Miko Framework 系列(一):简介与核心理念