当前位置: 首页 > news >正文

TensorFlow——Keras 框架

摘要:本文介绍了使用Keras框架在TensorFlow上构建卷积神经网络(CNN)处理MNIST手写数字识别的完整流程。首先加载并预处理数据,包括维度调整、归一化和独热编码;然后构建包含两个卷积层、池化层、Dropout层和全连接层的序贯模型;接着使用交叉熵损失和Adam优化器编译模型;经过10轮训练后,模型在测试集上达到99.1%的准确率。整个过程展示了Keras简化深度学习模型开发的优势,包括直观的API设计、灵活的层配置和高效的训练流程。

目录

TensorFlow——Keras 框架

利用 Keras 构建深度学习模型的八大步骤

步骤一:加载并预处理数据

步骤二:定义模型架构

步骤三:编译模型

步骤四:训练模型

术语备注


TensorFlow——Keras 框架

Keras 是一款轻量易用的高级 Python 库,运行在 TensorFlow 框架之上。该库的设计核心是帮助开发者理解深度学习相关技术,比如为神经网络搭建网络层,同时兼顾维度形态与数学细节的相关概念。

Keras 可搭建的模型框架主要分为以下两种类型:

  • 序贯式 API(Sequential API)
  • 函数式 API(Functional API)

利用 Keras 构建深度学习模型的八大步骤

  1. 加载数据
  2. 对加载的数据进行预处理
  3. 定义模型结构
  4. 编译模型
  5. 训练模型
  6. 评估模型性能
  7. 执行所需的预测任务
  8. 保存模型

本文将使用 Jupyter 笔记本完成代码运行与结果输出,具体操作步骤如下:

步骤一:加载并预处理数据

这是运行深度学习模型的首要步骤,先导入相关库和模块,再完成数据的加载与预处理。

import warnings warnings.filterwarnings('ignore') import numpy as np np.random.seed(123) # 固定随机种子,保证实验可复现 from keras.models import Sequential from keras.layers import Flatten, MaxPool2D, Conv2D, Dense, Reshape, Dropout from keras.utils import np_utils # 后端使用TensorFlow from keras.datasets import mnist # 加载已打乱的MNIST手写数字数据集,划分为训练集和测试集 (X_train, y_train), (X_test, y_test) = mnist.load_data() # 重塑训练集数据维度,适配卷积层输入 X_train = X_train.reshape(X_train.shape[0], 28, 28, 1) # 重塑测试集数据维度,适配卷积层输入 X_test = X_test.reshape(X_test.shape[0], 28, 28, 1) # 将数据类型转换为32位浮点型 X_train = X_train.astype('float32') X_test = X_test.astype('float32') # 数据归一化,将像素值缩放到0-1区间 X_train /= 255 X_test /= 255 # 将标签进行独热编码,适配多分类任务 Y_train = np_utils.to_categorical(y_train, 10) Y_test = np_utils.to_categorical(y_test, 10)
步骤二:定义模型架构

采用序贯式模型搭建卷积神经网络结构:

model = Sequential() # 添加卷积层,32个3×3卷积核,激活函数为ReLU,指定输入维度为28×28×1 model.add(Conv2D(32, 3, 3, activation ='relu', input_shape = (28,28,1))) # 再次添加卷积层,提取更深层特征 model.add(Conv2D(32, 3, 3, activation ='relu')) # 添加最大池化层,2×2池化窗口,降维并保留关键特征 model.add(MaxPool2D(pool_size = (2,2))) # 添加Dropout层,随机丢弃25%的神经元,防止过拟合 model.add(Dropout(0.25)) # 展平层,将多维特征映射为一维,连接卷积层与全连接层 model.add(Flatten()) # 全连接层,128个神经元,激活函数为ReLU model.add(Dense(128, activation = 'relu')) # 再次添加Dropout层,随机丢弃50%的神经元,进一步防止过拟合 model.add(Dropout(0.5)) # 输出层,10个神经元,softmax激活函数,输出各分类的概率 model.add(Dense(10, activation = 'softmax'))
步骤三:编译模型

配置模型的损失函数、优化器和评估指标,为训练做准备:

# 损失函数选用交叉熵损失,优化器为Adam,评估指标为准确率 model.compile(loss = 'categorical_crossentropy', optimizer = 'adam', metrics = ['accuracy'])
步骤四:训练模型

使用训练集数据对模型进行训练,设置训练参数:

# 批次大小32,训练轮数10,显示训练过程 model.fit(X_train, Y_train, batch_size = 32, epochs = 10, verbose = 1)

训练过程的迭代输出结果如下:

plaintext

第1轮/共10轮 60000/60000 [==============================] - 65s - 损失值:0.2124 - 准确率:0.9345 第2轮/共10轮 60000/60000 [==============================] - 62s - 损失值:0.0893 - 准确率:0.9740 第3轮/共10轮 60000/60000 [==============================] - 58s - 损失值:0.0665 - 准确率:0.9802 第4轮/共10轮 60000/60000 [==============================] - 62s - 损失值:0.0571 - 准确率:0.9830 第5轮/共10轮 60000/60000 [==============================] - 62s - 损失值:0.0474 - 准确率:0.9855 第6轮/共10轮 60000/60000 [==============================] - 59s - 损失值:0.0416 - 准确率:0.9871 第7轮/共10轮 60000/60000 [==============================] - 61s - 损失值:0.0380 - 准确率:0.9877 第8轮/共10轮 60000/60000 [==============================] - 63s - 损失值:0.0333 - 准确率:0.9895 第9轮/共10轮 60000/60000 [==============================] - 64s - 损失值:0.0325 - 准确率:0.9898 第10轮/共10轮 60000/60000 [==============================] - 60s - 损失值:0.0284 - 准确率:0.9910

术语备注

  1. Sequential API:序贯式 API,是 Keras 中最简单的模型构建方式,适用于层与层之间依次连接的线性模型
  2. Functional API:函数式 API,更灵活的模型构建方式,可搭建多输入、多输出、带残差连接的复杂网络
  3. one-hot encoding:独热编码,将离散型标签转换为二进制向量,避免标签间的数值大小干扰模型训练
  4. Dropout:随机失活,深度学习中常用的正则化方法,通过随机丢弃部分神经元,解决模型过拟合问题
  5. Adam:一种自适应学习率优化器,结合了动量法和 RMSprop 的优点,收敛速度快且稳定性好
  6. softmax:归一化指数函数,将神经网络的输出转换为 0-1 之间的概率值,且所有类别概率之和为 1,适用于多分类任务
http://www.jsqmd.com/news/387944/

相关文章:

  • TensorFlow—— 卷积神经网络(CNN)与循环神经网络(RNN)的区别
  • Flink Exactly-Once语义:大数据处理的精确一次性
  • 企业级AI平台架构设计,AI应用架构师的技术创新之路
  • 逐字解析 json 对我来说太难了
  • 谁在帮企业成为AI的答案?2026年GEO服务商全景 - 品牌2025
  • 琼海海鲜美食推荐,2026年人气大厨为你揭晓十大必试佳肴
  • 《P5785 [SDOI2012] 任务安排》
  • 知识检索增强AI Agent:结合LLM与高效搜索算法
  • TG 专题模拟考试
  • Hadoop与GraphQL:构建高效数据API
  • 掌握AI原生应用领域知识库构建的秘诀
  • 每天 5000W Token 免费白嫖! 国内零门槛接入 Claude Code + Longcat,轻松开启 AI-Agent 生产力!全流程手把手教程
  • 豆包和deepseek可以打广告吗?2026年特色GEO服务商盘点 - 品牌2025
  • [数据结构]主席树/可持久化线段树
  • 信息安全管理与评估广东省2026模块一参考答案
  • 详细介绍:Maven 依赖作用域实战避坑指南
  • 循环同构问题证明
  • 生产环境【OpenCV】(六)滤波器最佳实践与性能优化
  • 春晚魔术代码
  • 在风里,在梦中
  • Flutter三方库适配OpenHarmony【flutter_speech】— 语音识别启动与参数配置
  • Flutter三方库适配OpenHarmony【flutter_speech】— 语音识别停止与取消
  • Zookeeper客户端连接池优化实战
  • AI提示设计实证研究:提示工程架构师的创新思路
  • 企业AI创新场景怎么选?AI应用架构师的5步筛选法(附案例)
  • 春节网络“春运”,你家路由器扛得住吗?
  • 掌握大数据领域数据架构,开启数据新征程
  • 智能AR_VR内容创作平台的高可用架构:架构师如何保证7x24运行?(附容灾方案)
  • ‌智慧校园建设:为中小学生找到普惠与实用的黄金平衡点
  • 人工智能之核心基础 机器学习 第十七章 Scikit-learn工具全解析 - 详解