当前位置: 首页 > news >正文

基于pytorch卷积神经网络的汉字识别系统

基于pytorch卷积神经网络的汉字识别系统

源代码如下(PyCharm 工程,文末附运行结果):

import os
import shutil
import random
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from PIL import Image
from sklearn.metrics import accuracy_score
import warnings
from tqdm import tqdm # 进度条显示

warnings.filterwarnings('ignore')


# ======================== 1. Configuration ========================
class Config:
    # Dataset path configuration (absolute Windows paths — adjust per machine)
    TXT_PATH = "C:/Users/33946/Downloads/hd_chinese/hd_chinese/train.txt"
    RAW_PNG_DIR = "C:/Users/33946/Downloads/hd_chinese/hd_chinese/test_data"
    OUTPUT_DATASET_ROOT = "C:/Users/33946/Downloads/hd_chinese/hd_chinese/dataset"

    # Training hyperparameters
    IMAGE_SIZE = (64, 64)      # (H, W) fed to the CNN
    BATCH_SIZE = 64            # use 64 when a GPU is available, 32 on CPU
    EPOCHS = 100
    LR = 1e-4
    DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    SAVE_DIR = "saved_models"  # checkpoint output directory
    SAVE_INTERVAL = 10         # save a checkpoint every N epochs
    ROTATION_DEGREES = 5       # augmentation: max random rotation (degrees)
    TRANSLATE = (0.05, 0.05)   # augmentation: max fractional translation


# Create the checkpoint directory (no-op if it already exists)
os.makedirs(Config.SAVE_DIR, exist_ok=True)


# ======================== 2. 数据集处理 ========================
def process_train_txt_and_generate_dataset():
    """Parse the label txt, group images by first character, and copy them
    into a 70/20/10 train/val/test directory layout.

    Returns:
        dict mapping each character to the list of PNG file names for it.
    """
    print("===== 开始处理数据集 =====")
    for subset in ('train', 'val', 'test'):
        os.makedirs(os.path.join(Config.OUTPUT_DATASET_ROOT, subset), exist_ok=True)

    # Read "<png path>\t<text>" lines; keep (file name, first char) pairs.
    samples = []
    with open(Config.TXT_PATH, 'r', encoding='utf-8') as handle:
        for raw in handle:
            raw = raw.strip()
            if not raw:
                continue
            rel_path, text = raw.split('\t', 1)
            name = os.path.basename(rel_path)
            if not text:
                print(f"⚠️ 跳过空文本:{rel_path}")
                continue
            samples.append((name, text[0]))

    # Bucket file names per character.
    char_groups = {}
    for name, ch in samples:
        char_groups.setdefault(ch, []).append(name)

    image_count = 0
    for ch, files in char_groups.items():
        random.shuffle(files)
        image_count += len(files)
        n_train = int(len(files) * 0.7)
        n_val = int(len(files) * 0.2)

        for pos, name in enumerate(files):
            src = os.path.join(Config.RAW_PNG_DIR, name)
            if not os.path.exists(src):
                print(f"⚠️ 跳过不存在的文件:{src}")
                continue

            # First 70% train, next 20% val, remainder test.
            if pos < n_train:
                subset = 'train'
            elif pos < n_train + n_val:
                subset = 'val'
            else:
                subset = 'test'

            target_dir = os.path.join(Config.OUTPUT_DATASET_ROOT, subset, ch)
            os.makedirs(target_dir, exist_ok=True)
            shutil.copy(src, os.path.join(target_dir, name))

    print(f"✅ 数据集处理完成!共处理 {image_count} 张图像,{len(char_groups)} 个汉字类别")
    print(f" 数据集目录:{Config.OUTPUT_DATASET_ROOT}")
    return char_groups


# Build the split dataset on disk (only needed on the first run; comment
# this out afterwards to skip the copy step)
char_groups = process_train_txt_and_generate_dataset()

# ======================== 3. Data loading ========================
# Sorting gives a stable, reproducible class-index assignment across runs.
CHINESE_CHARS = sorted(char_groups.keys())
CHAR_TO_IDX = {char: idx for idx, char in enumerate(CHINESE_CHARS)}  # char -> class index
IDX_TO_CHAR = {idx: char for idx, char in enumerate(CHINESE_CHARS)}  # class index -> char
NUM_CLASSES = len(CHINESE_CHARS)
print(f"\n===== 模型配置 =====")
print(f" 识别类别数:{NUM_CLASSES},示例汉字:{CHINESE_CHARS[:10]}...")


class ChineseCharDataset(Dataset):
    """Image dataset over a directory laid out as <data_dir>/<char>/<name>.png."""

    def __init__(self, data_dir, char_to_idx, transform=None):
        # Collect every PNG under each known character sub-directory,
        # pairing it with the character's integer class label.
        self.data_dir = data_dir
        self.char_to_idx = char_to_idx
        self.transform = transform
        self.image_paths = []
        self.labels = []

        for char in os.listdir(data_dir):
            char_dir = os.path.join(data_dir, char)
            # Skip stray files and characters outside the label mapping.
            if char not in char_to_idx or not os.path.isdir(char_dir):
                continue
            pngs = [f for f in os.listdir(char_dir) if f.endswith(".png")]
            self.image_paths.extend(os.path.join(char_dir, f) for f in pngs)
            self.labels.extend([char_to_idx[char]] * len(pngs))

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        # Load as grayscale ("L"); apply the transform when one is set.
        sample = Image.open(self.image_paths[idx]).convert("L")
        target = self.labels[idx]
        if self.transform is not None:
            sample = self.transform(sample)
        return sample, torch.tensor(target, dtype=torch.long)


def get_transforms():
    """Build the (train, val/test) torchvision transform pipelines.

    Training adds light augmentation (rotation, translation, random crop,
    random erasing); validation/test only resizes, tensorizes and normalizes.
    """
    normalize = transforms.Normalize(mean=[0.5], std=[0.5])
    augmented_steps = [
        transforms.Resize(Config.IMAGE_SIZE),
        transforms.RandomRotation(Config.ROTATION_DEGREES),
        transforms.RandomAffine(0, translate=Config.TRANSLATE),
        transforms.RandomResizedCrop(Config.IMAGE_SIZE, scale=(0.9, 1.0)),
        transforms.ToTensor(),
        transforms.RandomErasing(p=0.1, scale=(0.02, 0.05)),
        normalize,
    ]
    plain_steps = [
        transforms.Resize(Config.IMAGE_SIZE),
        transforms.ToTensor(),
        normalize,
    ]
    return transforms.Compose(augmented_steps), transforms.Compose(plain_steps)


# Build transforms and one dataset per split.
train_transform, val_test_transform = get_transforms()
train_dir = os.path.join(Config.OUTPUT_DATASET_ROOT, "train")
val_dir = os.path.join(Config.OUTPUT_DATASET_ROOT, "val")
test_dir = os.path.join(Config.OUTPUT_DATASET_ROOT, "test")

train_dataset = ChineseCharDataset(train_dir, CHAR_TO_IDX, train_transform)
val_dataset = ChineseCharDataset(val_dir, CHAR_TO_IDX, val_test_transform)
test_dataset = ChineseCharDataset(test_dir, CHAR_TO_IDX, val_test_transform)

# num_workers=0 disables worker multiprocessing (avoids Windows path/pickling
# issues); pin_memory only helps host->GPU transfers, so enable it on CUDA.
train_loader = DataLoader(
    train_dataset,
    batch_size=Config.BATCH_SIZE,
    shuffle=True,  # only the training split is shuffled
    num_workers=0,
    pin_memory=True if Config.DEVICE.type == 'cuda' else False
)
val_loader = DataLoader(
    val_dataset,
    batch_size=Config.BATCH_SIZE,
    shuffle=False,
    num_workers=0,
    pin_memory=True if Config.DEVICE.type == 'cuda' else False
)
test_loader = DataLoader(
    test_dataset,
    batch_size=Config.BATCH_SIZE,
    shuffle=False,
    num_workers=0,
    pin_memory=True if Config.DEVICE.type == 'cuda' else False
)

print(f"\n===== 数据集加载 =====")
print(f" 训练集:{len(train_dataset)} 张图像")
print(f" 验证集:{len(val_dataset)} 张图像")
print(f" 测试集:{len(test_dataset)} 张图像")


# ======================== 4. 模型定义 ========================
class ImprovedChineseCharCNN(nn.Module):
    """Three-stage VGG-style CNN for single grayscale character images.

    Args:
        num_classes: number of output character classes.
        image_size: (H, W) of the input images; defaults to
            ``Config.IMAGE_SIZE`` for backward compatibility. Only used once,
            to infer the flattened feature dimension of the classifier head.
    """

    def __init__(self, num_classes, image_size=None):
        super(ImprovedChineseCharCNN, self).__init__()
        if image_size is None:
            image_size = Config.IMAGE_SIZE  # preserve the original default
        self.conv_layers = nn.Sequential(
            # Stage 1: 1 -> 32 -> 64 channels, spatial /2
            nn.Conv2d(1, 32, kernel_size=3, padding=1),
            nn.BatchNorm2d(32),
            nn.ReLU(inplace=True),
            nn.Conv2d(32, 64, kernel_size=3, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(2, 2),
            nn.Dropout(0.05),

            # Stage 2: 64 -> 128 channels, spatial /2
            nn.Conv2d(64, 128, kernel_size=3, padding=1),
            nn.BatchNorm2d(128),
            nn.ReLU(inplace=True),
            nn.Conv2d(128, 128, kernel_size=3, padding=1),
            nn.BatchNorm2d(128),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(2, 2),
            nn.Dropout(0.05),

            # Stage 3: 128 -> 256 channels, spatial /2
            nn.Conv2d(128, 256, kernel_size=3, padding=1),
            nn.BatchNorm2d(256),
            nn.ReLU(inplace=True),
            nn.Conv2d(256, 256, kernel_size=3, padding=1),
            nn.BatchNorm2d(256),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(2, 2),
            nn.Dropout(0.05)
        )

        # Size the classifier head with a dummy forward pass so it adapts to
        # any input size. Use zeros under no_grad: the original torch.randn
        # forward built an autograd graph and consumed RNG state for no reason.
        with torch.no_grad():
            dummy = torch.zeros(1, 1, image_size[0], image_size[1])
            self.fc_input_dim = self.conv_layers(dummy).view(1, -1).size(1)

        self.fc_layers = nn.Sequential(
            nn.Linear(self.fc_input_dim, 1024),
            nn.ReLU(inplace=True),
            nn.BatchNorm1d(1024),
            nn.Dropout(0.2),

            nn.Linear(1024, 512),
            nn.ReLU(inplace=True),
            nn.BatchNorm1d(512),
            nn.Dropout(0.2),

            nn.Linear(512, num_classes)
        )

    def forward(self, x):
        """Return raw class logits of shape (batch, num_classes)."""
        x = self.conv_layers(x)
        x = x.view(x.size(0), -1)  # flatten all features per sample
        x = self.fc_layers(x)
        return x


# Instantiate the model on the configured device.
model = ImprovedChineseCharCNN(NUM_CLASSES).to(Config.DEVICE)
print(f"\n===== 模型信息 =====")
print(f" 设备:{Config.DEVICE}")
print(f" 模型结构:{model}")


# ======================== 5. 训练与评估函数 ========================
def train_one_epoch(model, train_loader, criterion, optimizer, device):
    """Run one optimization pass over the training set.

    Returns:
        (average loss per sample, accuracy) for the epoch.
    """
    model.train()
    running_loss = 0.0
    preds, targets = [], []

    for images, labels in tqdm(train_loader, desc="训练中", leave=False):
        images = images.to(device)
        labels = labels.to(device)

        outputs = model(images)
        loss = criterion(outputs, labels)

        # Standard backprop step.
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Weight the batch loss by batch size for a true per-sample average.
        running_loss += loss.item() * images.size(0)
        preds.extend(torch.argmax(outputs, 1).cpu().numpy())
        targets.extend(labels.cpu().numpy())

    return running_loss / len(train_loader.dataset), accuracy_score(targets, preds)


def evaluate(model, dataloader, criterion, device, split="验证"):
    """Evaluate the model on a dataloader without gradient tracking.

    Returns:
        (average loss per sample, accuracy).
    """
    model.eval()
    running_loss = 0.0
    preds, targets = [], []

    with torch.no_grad():
        for images, labels in tqdm(dataloader, desc=f"{split}中", leave=False):
            images = images.to(device)
            labels = labels.to(device)

            outputs = model(images)
            running_loss += criterion(outputs, labels).item() * images.size(0)
            preds.extend(torch.argmax(outputs, 1).cpu().numpy())
            targets.extend(labels.cpu().numpy())

    return running_loss / len(dataloader.dataset), accuracy_score(targets, preds)


# ======================== 新增:输出识别文字结果 ========================
def print_recognition_results(model, dataloader, device, idx_to_char, num_samples=5):
    """Print predicted vs. ground-truth characters for a few random samples."""
    model.eval()
    dataset = dataloader.dataset
    # Sample indices without replacement so different items show each call.
    chosen = random.sample(range(len(dataset)), min(num_samples, len(dataset)))

    with torch.no_grad():
        for shown, idx in enumerate(chosen):
            image, label = dataset[idx]
            logits = model(image.unsqueeze(0).to(device))  # add batch dim
            pred_char = idx_to_char[torch.argmax(logits, 1).cpu().item()]
            true_char = idx_to_char[label.item()]

            mark = '✅' if pred_char == true_char else '❌'
            print(f"样本 {shown + 1}:预测='{pred_char}',真实='{true_char}',{mark}")


# ======================== 6. 主训练函数(支持断点续训) ========================
def main_train(load_from_checkpoint=True, checkpoint_path="saved_models/best_model.pth"):
    """Train the module-level `model`, checkpoint periodically, track the
    best validation accuracy, then evaluate the best checkpoint on test data.

    Args:
        load_from_checkpoint: resume from `checkpoint_path` if it exists.
        checkpoint_path: path of the resumable "best model" checkpoint.
    """
    criterion = nn.CrossEntropyLoss()
    # AdamW decouples weight decay from the gradient update.
    optimizer = optim.AdamW(
        model.parameters(),
        lr=Config.LR,
        weight_decay=1e-4
    )
    # Halve the LR when validation accuracy plateaus for `patience` epochs.
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(
        optimizer,
        mode='max',
        patience=3,
        factor=0.5
    )

    best_val_acc = 0.0
    start_epoch = 1

    # Optional resume: restore weights, optimizer state, best accuracy and
    # the epoch counter from the checkpoint when available.
    if load_from_checkpoint and os.path.exists(checkpoint_path):
        checkpoint = torch.load(checkpoint_path)
        model.load_state_dict(checkpoint["model_state_dict"])
        if "optimizer_state_dict" in checkpoint:
            optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
        if "val_acc" in checkpoint:
            best_val_acc = checkpoint["val_acc"]
        if "epoch" in checkpoint:
            start_epoch = checkpoint["epoch"] + 1
        print(f"📌 已加载历史模型,从第{start_epoch}轮继续训练(历史最佳准确率:{best_val_acc:.4f})")

    print(f"\n===== 开始训练 =====")
    for epoch in range(start_epoch, Config.EPOCHS + 1):
        train_loss, train_acc = train_one_epoch(model, train_loader, criterion, optimizer, Config.DEVICE)
        val_loss, val_acc = evaluate(model, val_loader, criterion, Config.DEVICE, split="验证")

        scheduler.step(val_acc)  # the scheduler monitors validation accuracy

        print(f"Epoch [{epoch:3d}/{Config.EPOCHS}] | "
              f"Train Loss: {train_loss:.4f} | Train Acc: {train_acc:.4f} | "
              f"Val Loss: {val_loss:.4f} | Val Acc: {val_acc:.4f} | "
              f"LR: {optimizer.param_groups[0]['lr']:.6f}")

        # Periodic checkpoint every SAVE_INTERVAL epochs.
        if epoch % Config.SAVE_INTERVAL == 0:
            torch.save({
                "epoch": epoch,
                "model_state_dict": model.state_dict(),
                "optimizer_state_dict": optimizer.state_dict(),
                "val_acc": val_acc
            }, os.path.join(Config.SAVE_DIR, f"model_epoch_{epoch}.pth"))
            print(f"💾 已保存第{epoch}轮模型")

        # Track the best model; the label mappings are stored alongside the
        # weights so the checkpoint is self-contained for inference.
        if val_acc > best_val_acc:
            best_val_acc = val_acc
            torch.save({
                "epoch": epoch,
                "model_state_dict": model.state_dict(),
                "optimizer_state_dict": optimizer.state_dict(),
                "val_acc": best_val_acc,
                "char_to_idx": CHAR_TO_IDX,
                "idx_to_char": IDX_TO_CHAR
            }, os.path.join(Config.SAVE_DIR, "best_model.pth"))
            print(f"🌟 最佳模型更新(Val Acc: {best_val_acc:.4f})")

    # After training: reload the best checkpoint and report test accuracy.
    best_model_path = os.path.join(Config.SAVE_DIR, "best_model.pth")
    if os.path.exists(best_model_path):
        best_model = torch.load(best_model_path)
        model.load_state_dict(best_model["model_state_dict"])
        test_loss, test_acc = evaluate(model, test_loader, criterion, Config.DEVICE, split="测试")
        print(f"\n===== 训练完成 =====")
        print(f" 测试集准确率:{test_acc:.4f}")

        # Show a few concrete predictions from the test set.
        print(f"\n===== 随机抽取5个测试样本的识别结果 =====")
        print_recognition_results(model, test_loader, Config.DEVICE, IDX_TO_CHAR, num_samples=5)
    else:
        print("\n⚠️ 未找到最佳模型文件")


# ======================== Entry point ========================
if __name__ == "__main__":
    main_train(load_from_checkpoint=True)

运行结果:准确率达 90% 以上。

http://www.jsqmd.com/news/34416/

相关文章:

  • 制图-学习日志
  • 2025年热门成人自考机构推荐
  • 实用指南:手写MyBatis第95弹:调试追踪MyBatis SQL执行流程的终极指南
  • SOCKS5代理:通用性与协议覆盖
  • 口碑好的成人自考机构2025年推荐榜单
  • 2025年国内成人自考机构口碑推荐排行榜单:选择指南与深度解析
  • 2025 年 11 月除锈剂厂家推荐排行榜,钢铁除锈剂,金属除锈剂,钢材除锈剂,不锈钢除锈剂,螺丝除锈剂,弹簧除锈剂,铝型材除锈剂公司推荐
  • CANopen转Profinet是一种构建于控制局域网设备之上的协议网关
  • 2025 年 11 月喷头漏墨维修厂家推荐排行榜,理光喷头漏墨,京瓷喷头漏墨,精工喷头漏墨,喷绘机喷头漏墨维修公司推荐
  • Cohen‘s Kappa系数:衡量分类一致性的黄金标准及其在NLP中的应用 - 实践
  • 2025年国内成人自考机构口碑推荐榜单:如何选择靠谱的学历提升平台
  • 2025年11月星光喷头厂家推荐排行榜:专业选购与维护指南
  • Spring Cloud Alibaba + Sentinel
  • 德鲁克管理哲学:管理是知行统一的实践创新 - 详解
  • 2025 年 11 月食堂承包公司推荐排行榜,食堂承包商,食堂承包方案,大型食堂承包,专业餐饮服务与高效运营管理口碑之选
  • 2025年双组份喷涂泵定做厂家权威推荐榜单:双组份喷漆机专用喷枪/无气喷涂机/高压无气喷涂泵专用喷枪源头厂家精选
  • 智能充气泵方案:充气泵电机怎么选?怎么适配
  • 智能家居产品品牌推荐排行2025:权威榜单揭晓
  • 2025 年 11 月电弧故障保护器厂家推荐排行榜,断路器/检测断路器,并联/串联电弧故障保护器,防火限流式保护器,故障电弧探测器公司推荐
  • 2025 年 11 月食堂送菜平台推荐排行榜,送菜上门,食堂送菜公司,饭堂送菜平台,专业高效与新鲜直达服务口碑之选
  • 小 E 的传奇一生
  • 2025 年黄锈石供应厂家最新推荐排行榜:聚焦实力厂商与新锐品牌,揭秘口碑优质服务商黄锈锈石/非标锈石/石材锈石公司推荐
  • 2025 年 11 月农产品配送厂家推荐排行榜,蔬菜配送,新鲜生鲜配送,食堂农产品配送公司,专业高效服务口碑之选
  • 2025年智能家居产品品牌推荐排行榜:权威口碑指南
  • 现今有实力的智能家居产品公司排行
  • 2025 年 11 月蔬菜配送厂家推荐排行榜,新鲜生鲜水果有机食材,食堂蔬菜配送中心,生鲜蔬菜配送供应商及平台上门服务精选指南
  • 用Dify工作流打造你的AI测试智能体,效率提升500%
  • 2025 年 11 月食材配送厂家推荐排行榜,食材采购,生鲜食材配送,食堂食材配送,食材配送中心公司推荐
  • Serverless感悟与杂谈
  • 2025 年 11 月展厅设计厂家推荐排行榜,企业展厅定制,科技展馆设计,全屋定制展厅,数字化多媒体展厅,人工智能展台设计公司推荐