当前位置: 首页 > news >正文

案例之 手写数字识别

利用KNN算法实现手写数字识别

需求:从数万个手写图像的数据集中正确识别数字



利用 plt.imread读取 图片,如果是 png格式,则自动归一化 如果是jpg,则不会进行归一化


1列标签(label) ➕️784列特征(0~783)
行:24000
每个图像 28 * 28,一共784个像素;(一行784要转成28*28)

""" 案例:演示KNN算法 识别图片,即:手写数字识别案例. 介绍: 每张图片都是由 28*28 像素组成的,即:我们的CSV文件中每一行都有 784个像素点,表示图片(每个像素)的 颜色. 最终构成图像 """importmatplotlib.pyplotaspltimportpandasaspdfromsklearn.metricsimportaccuracy_scorefromsklearn.model_selectionimporttrain_test_split# 分割训练集和测试集的fromsklearn.neighborsimportKNeighborsClassifier#KNN算法 分类对象importjoblibfromcollectionsimportCounter# 扩展:忽略警告importwarnings#参1:忽略警告,参2:忽略模块warnings.filterwarnings('ignore',module='sklearn')#1.定义函数,接收用户传入的索引,展示该索引对应的图片。defshow_digit(idx):# 1.读取数据集,获取源数据df=pd.read_csv('./data/手写数字识别.csv')# print(df) #[42000 rows x 785 columns]# 2.判断传入的索引是否越界ifidx<0oridx>len(df)-1:print('索引越界')return# 3.走这里说明没有越界,就正常获取数据x=df.iloc[:,1:]y=df.iloc[:,0]# 4.查看用户传入的索引对应的图片——》是几?print(f'该索引对应的图片:{y.iloc[idx]}')print(f'查看所有的标签分布情况:{Counter(y)}')#Counter({1: 4684, 7: 4401, 3: 4351, 9: 4188, 2: 4177, 6: 4137, 0: 4132, 4: 4072, 8: 4063, 5: 3795})# 5.查看用户传入的索引对应的图片的形状print(f'x的形状:{x.iloc[idx].shape}')#(784,) 要想办法将(784,)转换为(28,28)# print(x.iloc[idx].values) #具体的784个像素点数据# 6.把(784,)转换为(28,28)x=x.iloc[idx].values.reshape(28,28)# print(x)# 7.具体绘制灰度图的动作plt.imshow(x,cmap='gray')plt.axis('off')plt.show()# 2.定义函数,训练模型,并保存训练好的模型deftrain_model():# 1.加载数据集df=pd.read_csv('./data/手写数字识别.csv')# 2.数据的预处理# 2.1 拆分出特征列x=df.iloc[:,1:]#特征列# 2.2 拆分出标签列y=df.iloc[:,0]#标签列# 2.3 打印特征和标签的形状print(f'特征x的形状:{x.shape}')#(42000, 784)print(f'标签y的形状:{y.shape}')#(42000,)print(f'查看所有标签的分布情况:{Counter(y)}')# 2.4 对特征列(拆分前)进行归一化x=x/255# 2.5 拆分数据集和测试集#参1:特征列,参2:标签列,参数3:测试集所占比例,参数4:随机种子,参数5:标签y值进行抽取,保持标签的比例(数据均衡)x_train,x_test,y_train,y_test=train_test_split(x,y,test_size=0.2,random_state=20,stratify=y)# 3.模型训练# 3.1 创建模型对象estimator=KNeighborsClassifier(n_neighbors=3)# 3.2 模型训练estimator.fit(x_train,y_train)# 4.模型评估print(f'准确率:{estimator.score(x_test,y_test)}')print(f'准确率:{accuracy_score(y_test,estimator.predict(x_test))}')# 5.保存模型#参1:模型对象,参2:模型保存路径joblib.dump(estimator,'./model/手写数字识别.pkl')#pickle文件:Python(Pandas)独有的文件类型print('模型保存成功!')# 3.定义函数,测试模型defuse_mode():# 1.加载图片x=plt.imread('./data/demo.png')#28 * 28像素# 2.绘制图片plt.imshow(x,cmap='gray')#灰度图plt.axis('off')#关闭坐标轴plt.show()# 3.加载模型estimator=joblib.load('./model/手写数字识别.pkl')# 4.模型预测# 4.1 查看数据集转换print(x.shape)#(28, 28)print(x.reshape(1,784).shape)#(1, 784)print(x.reshape(1,-1).shape)#语法糖,效果等同于(1,784)# 4.2 具体的转换动作,记得归一化!(因为训练的时候使用了归一化动作)# x=x.reshape(1,-1)/255 #按道理这里需要归一化,因为2.train_model()在训练的时候使用了归一化动作x=x.reshape(1,-1)#这里不需要归一化,因为利用plt.imread读取 图片,如果是 png格式,则自动归一化 如果是jpg,则不会进行归一化# 4.3 模型预测y_predict=estimator.predict(x)# 5.打印预测结果print(f'预测结果:{y_predict}')# 4.测试:if__name__=='__main__':# 绘制数字# show_digit(9) #传入9对应11列的label值:3# show_digit(20)# show_digit(23)#传入23对应25列的label值:0# 训练模型,并保存模型train_model()# 模型预测(使用模型)use_mode()

show_digit(idx):即给定一个索引如传入9,则拿到的是第11行的label,把后面对应的784个像素点读出来转成28* 28的列表,然后绘制;先导包:matplotlib.pyplot是可视化的,pandas是处理的,train_test_split是切割数据的,KNeighborsClassifier是分类对象,joblib是保存模型的(如网格搜索与交叉验证案例中代码中已经有最优超参是3了,但是每次运行代码第4步模型训练都还会再次执行,所以可以将训练完的模型保存下来,用的时候直接读取文件拿到模型即可);Counter去重统计。
1.绘制数字:定义函数,接收用户传入的索引,展示该索引对应的图片,实现在第5步测试中,传入9时能拿到数字3的图片(根据传入的index=9,csv文件中对应第11行对应的label是3,所以绘制的图片是3);2是保存模型到文件中;3是从文件中把模型读出来使用模型进行预测;这样可以将代码分成两部分:一部分是保存、一部分是使用,实现了代码的分离;
对于1:定义函数,接收用户传入的索引,展示该索引对应的图片:show_digit方法:1.读取数据集,获取源数据:pd.read_csv获取df对象,打印出来是42000行 x 785列;2.判断传入的索引是否越界;3.x=df.iloc[:, 1:]x即为特征x_train是第二列(它的索引是1所以是df.iloc[:, 1:])一直到最后,第一列是标签y_train他的索引是0,y=df.iloc[:, 0]即刚好拿到第一列,4.查看用户传入的索引对应的图片—》是几?y.iloc[idx]:传入的是idx=9,对应的iloc为3;然后想办法将要想办法将784个像素点转换为(28,28);5.查看用户传入的索引对应的图片的形状:x.iloc[idx].shape要想办法将(784,)转换为(28,28);看一下数据值x.iloc[idx],但直接这样会包含了索引列值,所以只要数据的话是x.iloc[idx].values即具体的784个像素点数据;6.把(784,)转换为(28,28)x.iloc[idx].values.reshape(28,28),再打印;7.具体绘制灰度图的动作plt.imshow(x,cmap='gray')即imshow展示x中的数据,cmap='gray’表示灰度图,即边界没有锯齿的图;plt.show()展示但是会带有坐标线,所以需要使用plt.axis('off')关闭坐标;Counter包中的Counter(y)查看所有的标签分布情况,会显示42000个图片中几个0、几个1…。
到这里如何展示图片完成:传一个索引,读到那一行的数据它的第1列是结果(标签),后面784列是像素点,想办法将784个像素点转成(28, 28),即可画图显示图片。
2.训练模型,并保存模型:23都要掌握:2指训练模型并把它保存下来,3指使用模型做预测。2.定义函数,训练模型,并保存训练好的模型:1.加载数据集read_csv,2.数据的预处理,2.1 拆分出特征列x=df.iloc[:,1:]即第1列到最后所有列的数据即特征列、y=df.iloc[:,0]第1列标签列;2.3 打印特征和标签的形状:x.shape是(42000, 784),y.shape是(42000,1)1可不写。接着做划分,之前先transfer=StandardScaler()创建一个标准化对象,对x_train=transfer.fit_transform(x_train)x_train做标准化,在x_test=transfer.transform(x_test)对x_test标准化,但这里如果先划分再做标准化则像之前一样x_train、x_test(图示右侧两部分) 都要做标准化需要两次,如果不切割直接对x做标准化之后再切割为x_train、x_test则之前两者都是处理后的数据,则2.4 对特征列(拆分前)进行归一化即可,不需要再进行标准化,归一化的原因:最小像素是0,最大像素是255,差值较大,有必要进行归一化:x=x/255;2.5 拆分数据集和测试集:x_train,x_test,y_train,y_test=train_test_split(x,y,test_size=0.2,random_state=23,stratify=y)参1:特征列,参2:标签列,参数3:测试集所占比例,参数4:随机种子,不加stratify=y有可能猜出来的训练集中没有某值,但有的值又太多,需要基于比例来拆,如1占比多可多拆,占比少少拆,所以需要加上stratify=y参考y轴进行标签的抽取,保持数字的比例均衡;3.模型训练:3.1 创建模型对象、3.2 模型训练,4.模型评估,有两种方式:estimator.score(x_test,y_test)accuracy_score(y_test,estimator.predict(x_test))5.保存模型:将训练好的模型保存起来,如果未来想要使用可直接读取使用,省掉了重复训练的操作:保存模型joblib.dump(estimator,'./model/手写数字识别.pkl')参1:模型对象,参2:模型保存路径,.pkl指pickle文件:Python(Pandas)独有的文件类型,使用它效率更高,pth也一样;
3.模型预测(使用模型):用2中生成且保存的模型读出图片中的数字,plt.imread()imread读图,读出来的图片是2828的序列,但训练模型时用的是784的数据,所以需要先将2828的数据转成784;imshow展示图片;load记载模型;加载完成后做预测img.reahape(1, -1)等价于(1, 784):-1相当于能转多少就转多少:3.定义函数,测试模型:1.加载图片plt.imread('./data/demo.png'),读的数据是x是28 * 28像素,2.绘制图片plt.imshow(x,cmap='gray')plt.axis('off')可关闭坐标轴;3.加载模型estimator=joblib.load('./model/手写数字识别.pkl'):以前的模型estimator=KNeighborsClassifier()获取的,这里的模型是从保存的模型中加载的(日常中运营商将底层的模型训练好之后保存在一个地方,对上层用户提供一个使用权3.use_mode(),不能训练模型2.train_model(),使用时会产生一些日志记录,拿到这些数据他会再train_model训练模型,这就是很多模型开源的原因,借助用户的数据来优化其模型);4.模型预测,4.1 查看数据集转换x.shape(28, 28),训练时用的是784一横行来训练的,但现在的数据是(28, 28),需要将其转成一行x.reshape(1,784).shape即先reshape改形状,再.shape查看,语法糖x.reshape(1,-1).shape,这里(1,784) = (1,-1);4.2 具体的转换动作,按道理这里需要归一化x.reshape(1,-1)/255因为训练的时候使用了归一化动作,但利用plt.imread读取 图片,如果是 png格式,则自动归一化 如果是jpg,则不会进行归一化;4.3 模型预测y_predict=estimator.predict(x)再打印y_predict;

扩展:运行结果中Warning的忽略:import warnings warnings.filterwarnings('ignore',module='sklearn')参1:忽略警告,参2:忽略模块;

http://www.jsqmd.com/news/700352/

相关文章:

  • CSS如何实现模块化的颜色主题_通过CSS变量集中定义色板
  • ROS Melodic下,如何用TurtleBot3模型快速验证你的Gazebo SLAM仿真流程?
  • 别再只盯着IoU了!目标检测中GIoU、DIoU、CIoU损失函数详解与PyTorch实现
  • 终极指南:Disque分布式消息队列DELAY/RETRY/TTL时间参数配置最佳实践
  • FireRedASR-AED-L实际作品:教育场景课堂录音→教学笔记一键生成
  • 【AI原生开发实战专栏】5.5 RAG高级技巧:从Naive RAG到生产级系统
  • 掌握pmu-tools:大规模分布式系统性能监控的终极解决方案
  • SGPlayer在tvOS上的特殊适配:为大屏体验优化的播放器开发技巧
  • 如何用OpenResume实现简历数据可视化:打造专业简历统计与分析功能
  • 2026届必备的五大降重复率助手实际效果
  • 如何快速构建低延迟智能语音应用:RealtimeSTT实战指南
  • 从 ChatGPT 到 AutoGPT:对话式 AI 向智能体演进的关键转折
  • 图像融合新思路:拆开再拼起来——DeFusion论文精读与代码实战指南
  • 《把 Hermes Agent 养成你的专属帕鲁:从捕捉到满级实战指南》(二)
  • 如何快速上手AtCoder Library:5分钟完成安装与配置
  • 避坑指南:Seurat v4/v5对象互转时,你的差异表达结果还可靠吗?
  • 如何在Windows电脑上直接安装安卓应用:APK安装器完整指南
  • LOOT模组加载优化工具:5分钟掌握完美游戏体验的秘诀
  • 如何将Disque消息代理无缝集成到CI/CD流程:自动化部署与版本管理终极指南
  • innovus LEF/DEF 6.0 语言学习参考(1)
  • 2026家装墙板优选指南:适配全场景,告别后期维修烦恼 - 速递信息
  • Python使用XPath定位元素:动态计算与函数调用
  • MySQL主从复制过程中怎么增加从库_利用mysqldump快速扩容从库.txt
  • Apache Kylin 3.1.3 自动化构建指南:如何用Shell脚本调用REST API定时触发增量构建
  • JVM 学习第五天:类加载机制 + 内存调优实战 + 新面试题全解(无重复)
  • XUnity自动翻译器:如何为Unity游戏实现实时文本翻译
  • Simple Form开源项目安全政策:漏洞披露完整指南
  • Qwen3.5-2B实操手册:WebUI中启用RAG插件连接本地知识库方法
  • RocketMQ 系列文章(高级篇第 2 篇):消息追踪与性能优化实战
  • 终极指南:3分钟快速搭建Kafka可视化管理平台