当前位置: 首页 > news >正文

TensorFlow 实现循环神经网络

摘要:本文介绍了使用TensorFlow实现循环神经网络(RNN)的方法,重点针对MNIST手写数字分类任务。主要内容包括:1) RNN的基本原理,通过序列式处理保留上下文信息;2) 具体实现步骤:数据预处理、参数定义、LSTM单元构建、损失函数和优化器配置;3) 训练过程展示,包括批次训练和准确率评估。实验结果表明,该方法在测试集上取得了良好的分类效果,验证了RNN处理序列数据的有效性。代码实现完整展示了从数据加载到模型评估的全流程,为RNN的TensorFlow实践提供了参考范例。

目录

TensorFlow 实现循环神经网络

基于 TensorFlow 的循环神经网络实现

步骤 1:导入所需模块

步骤 2:定义输入参数

步骤 3:定义循环神经网络计算函数并配置损失函数与优化器

步骤 4:启动计算图并训练模型

模型运行输出结果


TensorFlow 实现循环神经网络

循环神经网络是一类面向深度学习的算法,采用序列式处理方法。在传统神经网络中,我们通常假设每个输入和输出都与其他所有层相互独立,而循环神经网络之所以被称为 “循环”,是因为它会以序列的方式执行数学运算。

以下是训练循环神经网络的具体步骤:

  1. 从数据集中输入一个特定的样本;
  2. 网络接收该样本,并利用随机初始化的变量完成相关计算;
  3. 计算得到预测结果;
  4. 将实际输出结果与预期值对比,得到误差值;
  5. 沿原计算路径反向传播误差,同时调整相关变量;
  6. 重复步骤 1 至步骤 5,直至确定用于输出结果的变量已得到合理定义;
  7. 应用这些优化后的变量,对未见过的新输入数据进行系统性的预测。

循环神经网络的示意图表示如下:

基于 TensorFlow 的循环神经网络实现

本节将介绍如何使用 TensorFlow 实现循环神经网络,具体步骤如下:

步骤 1:导入所需模块

TensorFlow 提供了多个专用库,用于实现循环神经网络模块,通过以下代码导入核心模块:

from __future__ import print_function import tensorflow as tf from tensorflow.contrib import rnn from tensorflow.examples.tutorials.mnist import input_data mnist = input_data.read_data_sets("/tmp/data/", one_hot = True)

上述库的核心作用是定义输入数据,这是实现循环神经网络的基础环节。

步骤 2:定义输入参数

我们的核心目标是利用循环神经网络对图像进行分类,将每张图像的行视为一个像素序列。MNIST 数据集的图像尺寸固定为 28×28 像素,因此每个样本需处理 28 个序列,每个序列包含 28 个步骤,以下是输入参数的定义代码:

python

运行

n_input = 28 # MNIST数据输入,图像尺寸28*28 n_steps = 28 # 序列步数 n_hidden = 128 # 隐藏层神经元数量 n_classes = 10 # 分类类别数(0-9数字) # 定义TensorFlow计算图的输入占位符 x = tf.placeholder("float", [None, n_steps, n_input]) y = tf.placeholder("float", [None, n_classes]) # 定义权重和偏置项 weights = { 'out': tf.Variable(tf.random_normal([n_hidden, n_classes])) } biases = { 'out': tf.Variable(tf.random_normal([n_classes])) }

步骤 3:定义循环神经网络计算函数并配置损失函数与优化器

通过自定义函数实现循环神经网络的核心计算逻辑,对比数据形状与当前输入形状,保证计算精度,同时定义损失函数、优化器和模型评估指标:

def RNN(x, weights, biases): # 将输入数据按序列维度拆解 x = tf.unstack(x, n_steps, 1) # 定义LSTM细胞单元 lstm_cell = rnn.BasicLSTMCell(n_hidden, forget_bias=1.0) # 获取LSTM细胞的输出和状态 outputs, states = rnn.static_rnn(lstm_cell, x, dtype = tf.float32) # 对最后一个时间步的输出做线性激活,得到预测结果 return tf.matmul(outputs[-1], weights['out']) + biases['out'] # 得到模型预测值 pred = RNN(x, weights, biases) # 定义交叉熵损失函数 cost = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(logits = pred, labels = y)) # 定义Adam优化器,最小化损失函数 optimizer = tf.train.AdamOptimizer(learning_rate = learning_rate).minimize(cost) # 计算模型预测准确率 correct_pred = tf.equal(tf.argmax(pred,1), tf.argmax(y,1)) accuracy = tf.reduce_mean(tf.cast(correct_pred, tf.float32)) # 初始化所有全局变量 init = tf.global_variables_initializer()

步骤 4:启动计算图并训练模型

启动 TensorFlow 计算图执行计算,完成模型训练并测试模型准确率:

with tf.Session() as sess: # 初始化变量 sess.run(init) step = 1 # 迭代训练,直至达到最大迭代次数 while step * batch_size < training_iters: # 获取批次训练数据 batch_x, batch_y = mnist.train.next_batch(batch_size) # 调整数据形状以匹配模型输入 batch_x = batch_x.reshape((batch_size, n_steps, n_input)) # 执行优化步骤 sess.run(optimizer, feed_dict={x: batch_x, y: batch_y}) # 定期打印训练结果 if step % display_step == 0: # 计算批次准确率 acc = sess.run(accuracy, feed_dict={x: batch_x, y: batch_y}) # 计算批次损失值 loss = sess.run(cost, feed_dict={x: batch_x, y: batch_y}) # 打印迭代次数、损失值和准确率 print("Iter " + str(step*batch_size) + ", Minibatch Loss= " + \ "{:.6f}".format(loss) + ", Training Accuracy= " + \ "{:.5f}".format(acc)) step += 1 # 打印训练完成提示 print("Optimization Finished!") # 定义测试数据量 test_len = 128 # 准备测试数据并调整形状 test_data = mnist.test.images[:test_len].reshape((-1, n_steps, n_input)) test_label = mnist.test.labels[:test_len] # 打印测试准确率 print("Testing Accuracy:", \ sess.run(accuracy, feed_dict={x: test_data, y: test_label}))

模型运行输出结果

执行上述代码的终端命令及输出如下:

plaintext

E:\Tensorflowproject>activate tensorflow (tensorflow) E:\TensorFlowProject>python recurrent_network.py

运行过程中会出现部分 TensorFlow 弃用警告(提示后续版本将移除相关接口,建议使用 tf.data 等新接口替代),同时输出数据集解压信息,最终的训练迭代结果如下:

http://www.jsqmd.com/news/376189/

相关文章:

  • TensorFlow - TensorBoard 可视化
  • 2026年冷却塔改造厂家最新推荐:上海良机冷却塔/冷却塔填料更换/圆形冷却塔/常州良机冷却塔/方型冷却塔/无锡良机冷却塔/选择指南 - 优质品牌商家
  • 2026年无锡冷却塔维修厂家权威推荐榜:苏州良机冷却塔、闭式冷却塔、上海冷却塔维修、冷却塔填料更换、冷却塔改造选择指南 - 优质品牌商家
  • 2026年评价高的冷却塔配件公司推荐:良机冷却塔厂家/良机冷却塔维修/良机冷却塔配件/苏州良机冷却塔/闭式冷却塔/选择指南 - 优质品牌商家
  • 寒假集训9——图论
  • Java毕设项目:基于springboot的文创销售管理系统(源码+文档,讲解、调试运行,定制等)
  • blender 修改物体 修改衣服
  • ue 蓝图添加灯光
  • 2026年常州冷却塔维修厂家权威推荐榜:昆山冷却塔维修/昆山良机冷却塔/杭州良机冷却塔/良机冷却塔维修/良机冷却塔配件/选择指南 - 优质品牌商家
  • ue 框选 多个对象 框选物体
  • 2026年冷却塔厂家公司权威推荐:冷却塔填料更换、圆形冷却塔、常州良机冷却塔、方型冷却塔、无锡良机冷却塔、昆山冷却塔维修选择指南 - 优质品牌商家
  • 力扣第45题:二叉树的右视图
  • Nodejs+vue+ElementUI框架的在线学习资源推荐的设计与实现
  • 2026年开年室内健身器材综合制造厂商权威评测与选型指南 - 2026年企业推荐榜
  • 2026年月嫂培训机构厂家最新推荐:北京正规家政月嫂公司招商连锁加盟、北京正规家政月嫂公司招聘合伙人、北京高端月嫂公司选择指南 - 优质品牌商家
  • Nodejs+vue+ElementUI框架的志愿服务管理系统的设计与实现
  • 商用与家用兼顾:2026江苏健身器材品牌全景观察 - 2026年企业推荐榜
  • NASA 先进的空中交通(AAM)概述 2025
  • Nodejs+vue+ElementUI框架电动车辆充电桩报修管理系统的设计与开发
  • 2026年冷却塔填料更换公司权威推荐:良机冷却塔厂家/良机冷却塔维修/良机冷却塔配件/苏州冷却塔维修/苏州良机冷却塔/选择指南 - 优质品牌商家
  • Nodejs+vue+ElementUI框架的一键选择“搭子”线下社交陪伴聊天系统
  • Nodejs+vue+ElementUI框架二手交易系统的设计与实现
  • 2026年AI客服机器人厂家权威推荐榜:BOSS直聘AI客服机器人、抖音AI客服机器人、VEO视频生成、京东AI客服机器人选择指南 - 优质品牌商家
  • Windows 下 Node.js 重定向输出导致中文乱码的问题分析
  • Nodejs+vue+ElementUI框架共享厨师预约平台的设计与实现
  • 1.77秒克隆了100字!1G显存就能玩语音声音克隆,速度增快150倍,效果不输大模型,LuxTTS离线整合包_封面
  • 2026年AI视频生成厂家最新推荐:文字生成视频AI、电商短视频AI、美团AI客服机器人、营销视频AI制作、视频号AI制作选择指南 - 优质品牌商家
  • 2026年淘宝AI客服机器人厂家最新推荐:BOSS直聘AI客服机器人、商品视频AI生成、小红书AI客服机器人、微信AI客服机器人选择指南 - 优质品牌商家
  • Nodejs+vue+ElementUI框架家庭装修 家装 装饰工程管理系统
  • 2026年抖音AI客服机器人厂家最新推荐:拼多多智能客服/文字生成视频AI/电商短视频AI/美团AI客服机器人/选择指南 - 优质品牌商家