Tensorboard使用
一.整体认知
1.核心定位
TensorBoard 是专门配合 PyTorch/TensorFlow 使用的可视化工具,本身不参与模型运算、训练逻辑,核心作用:读取代码输出的日志文件,在浏览器中生成图表、图片、网络结构图,直观观察模型训练状态。
2.完整工作流程
编写
Python代码,创建日志记录工具writer;调用对应方法,将数值、图片、网络结构等数据写入本地
logs日志文件夹;在终端启动
TensorBoard服务,指定读取的日志文件夹;
conda activate D:\conda_envs\pytorch#首先进入虚拟环境,根据条件改变#启动 TensorBoard,根据条件改变--logdir后的logs路径tensorboard--logdir=D:\Users\xin\PycharmProjects\PythonProject2\python\logs- 打开浏览器访问对应地址,查看各类可视化内容。
3.环境前置准备
仅需安装一次,在当前项目所用的虚拟环境终端中执行命令:
pip install tensorboard
二.核心对象:SummaryWriter
1.作用
SummaryWriter是TensorBoard的唯一入口类,所有数据写入、日志生成都依赖它。我们代码中定义的writer变量,就是SummaryWriter实例对象。
2.完整导入语法
除核心库外,额外导入os、shutil:
shutil:用于自动删除旧日志文件夹,避免新旧数据叠加导致图表混乱;os:用于判断文件夹是否存在、打印日志绝对路径,解决跨盘、路径找不到的问题。
# 核心库:TensorBoard 日志写入fromtorch.utils.tensorboardimportSummaryWriter# 辅助库:文件路径判断、文件夹删除importosimportshutil3.创建writer实例(标准完整写法)
代码逻辑说明
先判断项目目录下是否存在旧的
logs文件夹,如果存在则彻底删除;实例化
SummaryWriter,指定日志存放文件夹名称(示例为logs,可自定义);打印日志文件夹的绝对路径,用于排查路径错误问题。
完整代码
# 1. 判断本地是否存在旧的 logs 文件夹ifos.path.exists("logs"):# 2. 递归删除整个 logs 文件夹及内部所有日志文件shutil.rmtree("logs")# 3. 创建日志写入对象,自动生成全新的 logs 文件夹# 括号内字符串 = 日志文件夹名,可修改为 my_log / train_log 等writer=SummaryWriter("logs")# 4. 打印日志文件夹绝对路径,方便终端启动时核对路径log_abs_path=os.path.abspath(writer.log_dir)print(f"当前日志文件存放绝对路径:{log_abs_path}")补充说明
- 文件夹名称自定义规则:代码中写什么名称,终端启动命令
--logdir后就必须写相同名称,大小写、字符必须完全一致;(一般写绝对路径,如
tensorboard --logdir=D:\Users\xin\PycharmProjects\PythonProject2\python\logs)
- 路径规则:不写绝对路径时,
logs默认生成在当前.py文件同级目录。
- 收尾操作(强制要求,不可省略)
所有数据记录完成后,必须关闭writer对象:
# 关闭日志写入器,将缓存数据完整写入本地文件,防止日志损坏、数据丢失writer.close()使用场景:代码末尾、训练循环结束后执行。
三.常用方法 1:add_scalar记录标量
1.概念解释
标量(scalar):单一的数字,无维度、无形状。
深度学习训练中最常用的标量:损失值loss、准确率accuracy、学习率lr。
适用场景:记录训练过程中数值的变化趋势,生成折线图。
2.标准语法格式
writer.add_scalar(tag, scalar_value, global_step)
3.参数逐行详解
tag(字符串):图表/曲线的名称,自定义,会展示在TensorBoard页面上。
示例: “训练损失值” 、 “准确率” 、 “损失/训练集” ;
支持分级命名(用 / 分隔),页面会自动分组管理多条曲线。
2.scalar_value(数字/变量):需要记录的标量数值,可以是常量、变量、运算表达式。
global_step(整数):横坐标数值,代表步数/轮数,对应折线图 X 轴。
常用变量:循环变量step、训练轮数epoch。
注意:不要手动添加tag= 关键字传参,会导致参数顺序错乱、数据写入异常。
4.基础示例(单条曲线)
功能:模拟 5 轮训练,记录一组线性变化的数值,生成单条折线图。
# 导入所需库fromtorch.utils.tensorboardimportSummaryWriterimportosimportshutil# 步骤1:清理旧日志文件夹ifos.path.exists("logs"):shutil.rmtree("logs")# 步骤2:创建日志写入器writer=SummaryWriter("logs")# 步骤3:打印日志绝对路径log_abs_path=os.path.abspath(writer.log_dir)print(f"当前日志文件存放绝对路径:{log_abs_path}")# 步骤4:循环模拟训练,写入标量数据# range(5):循环 0、1、2、3、4,共5步forstepinrange(5):# 定义需要记录的数值current_value=10-step# 调用方法写入数据writer.add_scalar("测试曲线",current_value,step)# 步骤5:关闭写入器(必写)writer.close()print("标量数据写入完成!")5.代码运行 + 页面查看完整步骤
运行上述 Python 代码,控制台打印出日志绝对路径,同时项目目录生成全新 logs 文件夹;
打开虚拟环境终端,切换到日志所在目录:
- 方式(绝对路径,路径出错时使用,复制代码打印的路径):
tensorboard --logdir=你代码打印的日志绝对路径
终端出现
TensorBoard xxx at [http://localhost:6006/](http://localhost:6006/)代表启动成功;浏览器输入地址
[http://localhost:6006](http://localhost:6006),点击左侧Scalars选项卡,查看折线图。
6.进阶示例(多条曲线分组展示)
功能:同时记录训练集损失、验证集损失,使用分级标签分组展示。
fromtorch.utils.tensorboardimportSummaryWriterimportosimportshutil# 清理旧日志ifos.path.exists("logs"):shutil.rmtree("logs")# 创建写入器writer=SummaryWriter("logs")log_abs_path=os.path.abspath(writer.log_dir)print(f"当前日志文件存放绝对路径:{log_abs_path}")# 模拟6轮训练 epochforepochinrange(6):# 模拟训练集损失、验证集损失train_loss=5-epoch*0.6val_loss=5-epoch*0.4# 分级标签:同一分组下两条曲线writer.add_scalar("损失/训练集",train_loss,epoch)writer.add_scalar("损失/验证集",val_loss,epoch)# 关闭写入器writer.close()print("多组标量数据写入完成!")运行后,TensorBoard 页面会出现「损失」分组,内部包含两条对比曲线。
四.常用方法 2:add_images记录图片
1.作用
将数据集内的图片张量写入日志,在浏览器可视化展示。
核心用途:校验图片加载是否正常、图像预处理逻辑是否出错(缩放、归一化、裁剪等)。
2.标准语法格式
writer.add_images(tag, img_tensor, global_step)
3.参数逐行详解
tag(字符串):图片面板名称,自定义;img_tensor(张量):批量图片数据,一般是DataLoader迭代取出的张量,
要求形状:[batch_size, channel, height, width](PyTorch 默认格式);
:::info
形状:[batch_size, channel, height, width] (简写 [B, C, H, W] )
:::
:::info
- batch_size(B):批次大小,一个批次里包含多少张图片
- channel(C):通道数
- 灰度图: C=1 (单通道)
- 彩色图(RGB): C=3 (三通道)
- height(H):图像高度(像素)
- width(W):图像宽度(像素)
:::
global_step(整数):序号/轮数,区分不同批次的图片,单批次图片写 0 即可。
4.完整可运行示例
fromtorch.utils.tensorboardimportSummaryWriterimportosimportshutil#数据集,数据加载器,图像预处理库fromtorch.utils.dataimportDataset,DataLoaderfromtorchvisionimporttransforms#图片读取库fromPILimportImage#===========1.初始化+清理日志===============ifos.path.exists('logs'):shutil.rmtree('logs')writer=SummaryWriter('logs')log_abs_path=os.path.abspath('writer.log_dir')print(f'当前日志文件存放绝对路径:{log_abs_path}')#============2.定义图像预处理规则==============transforms=transforms.Compose([transforms.Resize((64,64)),#统一缩放到64*64transforms.ToTensor()#PIL图片转化为 张量,归一化到【0,1】])#=============3.自定义数据集类==================classMyImageDataset(Dataset):def__init__(self,root_dir,transform=None):self.root_dir=root_dir self.transform=transform# 只获取图片文件,过滤掉文件夹和非图片格式self.img_name_list=[]fornameinos.listdir(root_dir):path=os.path.join(root_dir,name)# 只保留文件,并且是常见图片格式ifos.path.isfile(path)andname.lower().endswith(('.jpg','.jpeg','.png','.bmp')):self.img_name_list.append(name)def__len__(self):returnlen(self.img_name_list)def__getitem__(self,idx):img_name=self.img_name_list[idx]img_path=os.path.join(self.root_dir,img_name)# 确保路径存在ifnotos.path.exists(img_path):raiseFileNotFoundError(f"图片不存在:{img_path}")img=Image.open(img_path).convert('RGB')ifself.transform:img=self.transform(img)returnimg#=============4.加载数据集与批量数据===================#替换为本地真实的图片文件夹路径# 推荐写法(原始字符串)image_root=r"D:\Users\xin\PycharmProjects\PythonProject2\pytorch\train\ants"dataset=MyImageDataset(image_root,transform=transforms)print("数据集总长度:",len(dataset))# 打印长度,看是不是 0print("图片列表:",dataset.img_name_list)# 打印读到的文件名#批量加载,每批8张图片dataloader=DataLoader(dataset,batch_size=8,shuffle=True)#取出第一批图片数量batch_imgs=next(iter(dataloader))#================5.写入图片数据到日志====================writer.add_images('训练集的图片预览',batch_imgs,0)#======================6.收尾关闭=======================writer.close()print('图片数据写入完成!')5.页面查看步骤
运行代码,生成 logs 文件夹;
终端启动 TensorBoard(命令同前文);
浏览器打开页面,点击左侧 Images 选项卡,即可查看批量图片。
五.常用方法 3:add_graph记录网络结构
1.作用
可视化手写的神经网络模型结构,直观查看:网络层顺序、层与层之间的连接、张量维度传递,快速排查网络搭建错误。
2.标准语法格式
writer.add_graph(model, input_to_model)
3.参数逐行详解
model:自己搭建的神经网络模型对象(必须继承torch.nn.Module);input_to_model:模拟模型的输入张量,形状必须和真实训练输入完全一致。
4.完整可运行示例
fromtorch.utils.tensorboardimportSummaryWriterimportosimportshutilimporttorchimporttorch.nnasnn#=================1.初始化+清理日志================ifos.path.exists('logs'):shutil.rmtree('logs')writer=SummaryWriter('logs')log_abs_path=os.path.abspath(writer.log_dir)print(f'当前日志存放绝对路径:{log_abs_path}')#==================2.搭建简单神经网络================classSimpleNet(nn.Module):def__init__(self):super(SimpleNet,self).__init__()#全连接层:输入64*64*3,输出2个分类self.fc=nn.Linear(64*64*3,2)#向前传播defforward(self,x):#展开张量:[batch,3,64,64]->[batch,64*64*3]x=x.flatten(1)x=self.fc(x)returnx#实例化网络模型net=SimpleNet()#====================3.构造模拟输入 张量 ===================#形状:[batch_size=1,channel=3,height=64,width=64]test_input=torch.randn(1,3,64,64)#=======================4.写入网络结构======================writer.add_graph(net,test_input)#=======================5.收尾关闭==========================writer.close()print('网络结构数据写入完成')5.页面查看步骤
运行代码,生成日志文件;
终端启动
TensorBoard;浏览器页面点击左侧
Graphs选项卡,查看完整网络拓扑结构。
六.关键规则
1.文件夹名称强匹配规则
代码中:
SummaryWriter("文件夹名")定义日志文件夹名称;终端中:
tensorboard --logdir=文件夹名读取日志;硬性要求:两个名称必须一字不差,包括字母大小写、符号、空格,否则无法读取日志。
2.旧日志问题解决方案
问题现象:多次运行代码, logs 内生成多个日志文件,TensorBoard 叠加所有历史数据,曲线错乱、出现 y=x 异常线条;
解决方案:代码开头固定添加 os + shutil 逻辑,每次运行自动删除旧文件夹,从根源避免干扰。
3.端口占用问题
默认端口:
TensorBoard默认使用 6006 端口;问题现象:端口被其他程序占用,服务启动失败;
解决命令(手动指定新端口):
tensorboard --logdir=logs --port=6007访问地址:
[http://localhost:6007](http://localhost:6007)(端口号和命令保持一致)。
4.writer.close()必写规则
日志数据会先存入内存缓存,不会实时写入本地;
不执行
close():代码异常终止、程序退出后,缓存数据丢失,logs为空或文件损坏;执行时机:所有数据写入操作全部完成后,放在代码最后一行。
七.全套代码模板
1.基础通用模板(所有场景通用框架)
# 1. 统一导入库fromtorch.utils.tensorboardimportSummaryWriterimportosimportshutil# 2. 自动清理旧日志ifos.path.exists("logs"):shutil.rmtree("logs")# 3. 创建日志写入器writer=SummaryWriter("logs")# 4. 打印日志绝对路径(路径排查专用)log_path=os.path.abspath(writer.log_dir)print(f"日志绝对路径:{log_path}")# ========== 中间区域:根据需求选择对应方法 ==========# 记录标量:writer.add_scalar(标签, 数值, 步数)# 记录图片:writer.add_images(标签, 图片张量, 序号)# 记录网络:writer.add_graph(模型, 模拟输入张量)# ================================================# 5. 关闭写入器(结尾必写)writer.close()2.终端启动命令汇总
- 相对路径(终端与 logs 同目录)
tensorboard --logdir=logs
- 绝对路径(路径异常/跨盘使用,替换为代码打印的真实路径)
tensorboard --logdir=D:\xxx\项目文件夹\logs
- 端口被占用,指定新端口
tensorboard --logdir=logs --port=6007
- 浏览器访问地址汇总
默认端口:
[http://localhost:6006](http://localhost:6006)自定义端口 6007:
[http://localhost:6007](http://localhost:6007)
