当前位置: 首页 > news >正文

Day41 TensorBoard

TensorBoard 是 TensorFlow 官方的可视化工具,PyTorch 也通过 torch.utils.tensorboard 提供了原生支持。它的核心作用是实时监控训练过程、分析模型性能、可视化模型结构;训练数据通过网页端交互展示,比单纯打印日志更直观。

import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import models, datasets, transforms
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter  # core TensorBoard class
import numpy as np
import matplotlib.pyplot as plt

# ===================== 1. Initialise the TensorBoard writer =====================
# Log directory (created automatically; consider a timestamped name so that
# repeated runs do not overwrite each other).
log_dir = "./runs/cifar10_resnet18_experiment"
writer = SummaryWriter(log_dir=log_dir)  # object that writes the event logs

# ===================== 2. Basic configuration =====================
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
batch_size = 64
lr = 0.001
epochs = 5
num_classes = 10
classes = ['airplane', 'automobile', 'bird', 'cat', 'deer',
           'dog', 'frog', 'horse', 'ship', 'truck']

# ===================== 3. Data preprocessing =====================
# CIFAR-10 images are resized to 224x224 and normalised with the same
# per-channel mean/std that denormalize() later undoes for display.
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225]),
])

train_dataset = datasets.CIFAR10(root='./data', train=True, download=True,
                                 transform=transform)
test_dataset = datasets.CIFAR10(root='./data', train=False,
                                transform=transform)

train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

# ===================== 4. Load the pretrained model =====================
# `pretrained=True` is deprecated in favour of the `weights=` enum API.
# Fall back to the legacy flag only when the enum is not available.
if hasattr(models, "ResNet18_Weights"):
    model = models.resnet18(weights=models.ResNet18_Weights.IMAGENET1K_V1)
else:
    model = models.resnet18(pretrained=True)

# Freeze every existing parameter; only the new head below is trained.
for param in model.parameters():
    param.requires_grad = False

# Replace the final fully connected layer with a num_classes-way classifier.
in_features = model.fc.in_features
model.fc = nn.Linear(in_features, num_classes)
model = model.to(device)

# ===================== 5. Loss function & optimizer =====================
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.fc.parameters(), lr=lr)  # optimise the head only

# ===================== 6.
# Helper: undo the input normalisation so images display correctly =====================
def denormalize(img_tensor):
    """Undo the per-channel mean/std normalisation for TensorBoard display.

    Args:
        img_tensor: a single image tensor laid out as (C, H, W).

    Returns:
        A NumPy array laid out as (H, W, C) with values clipped to [0, 1].
    """
    mean = np.array([0.485, 0.456, 0.406])
    std = np.array([0.229, 0.224, 0.225])
    # detach() keeps this usable on tensors that are part of an autograd graph.
    img = img_tensor.detach().cpu().numpy().transpose((1, 2, 0))  # (C,H,W) -> (H,W,C)
    img = img * std + mean
    return np.clip(img, 0, 1)  # keep values inside the displayable range


# ===================== 7. Training + TensorBoard logging =====================
def train_and_log():
    """Train the model and log graph, sample images, scalars and histograms.

    Relies on the module-level ``writer``, ``model``, ``train_loader``,
    ``test_loader``, ``criterion``, ``optimizer``, ``device`` and ``epochs``.
    The writer is closed on exit, including when training raises, so that
    already-recorded events are written out.
    """
    global_step = 0  # global step counter, accumulated across epochs

    try:
        # 1. Record the model graph from an example input of shape (1, 3, 224, 224).
        dummy_input = torch.randn(1, 3, 224, 224).to(device)
        writer.add_graph(model, dummy_input)

        # 2. Record a strip of sample images from one training batch.
        images, _ = next(iter(train_loader))
        # Slice instead of indexing 0..7 so a batch with fewer than 8 images works.
        samples = [denormalize(img) for img in images[:8]]
        img_grid = np.concatenate(samples, axis=1)  # join side by side (width axis)
        writer.add_image('CIFAR10_Samples', img_grid.transpose((2, 0, 1)),
                         global_step=0)  # back to (C, H, W)

        # 3. Train and record scalars / histograms.
        for epoch in range(1, epochs + 1):
            model.train()
            train_loss = 0.0
            correct = 0
            total = 0

            for data, target in train_loader:
                data, target = data.to(device), target.to(device)

                optimizer.zero_grad()
                output = model(data)
                loss = criterion(output, target)
                loss.backward()
                optimizer.step()

                # Weight each batch loss by its size so a shorter final batch
                # does not skew the epoch average (assumes the criterion
                # returns the batch mean, as nn.CrossEntropyLoss does by default).
                n = target.size(0)
                train_loss += loss.item() * n
                correct += output.argmax(dim=1).eq(target).sum().item()
                total += n

                # Per-batch scalar: training loss at every optimisation step.
                writer.add_scalar('Train/Batch_Loss', loss.item(), global_step)
                global_step += 1

            # Per-epoch scalars: mean loss and accuracy.
            seen = max(total, 1)  # avoid dividing by zero on an empty loader
            avg_train_loss = train_loss / seen
            train_acc = 100.0 * correct / seen

            # Evaluate on the test set.
            test_loss, test_acc = test(model, test_loader, criterion, device)

            # add_scalars draws several curves on one chart for comparison.
            writer.add_scalars('Loss', {
                'Train': avg_train_loss,
                'Test': test_loss
            }, epoch)
            writer.add_scalars('Accuracy', {
                'Train': train_acc,
                'Test': test_acc
            }, epoch)

            # Histograms of parameters and gradients; restricted to the
            # fully connected layer to keep the log small.
            for name, param in model.named_parameters():
                if 'fc' in name:
                    writer.add_histogram(f'Params/{name}', param, epoch)
                    # grad is None if no backward pass has touched this parameter.
                    if param.grad is not None:
                        writer.add_histogram(f'Grads/{name}', param.grad, epoch)

            print(f"Epoch {epoch:2d} | "
                  f"Train Loss: {avg_train_loss:.4f} | Train Acc: {train_acc:.2f}% | "
                  f"Test Loss: {test_loss:.4f} | Test Acc: {test_acc:.2f}%")
    finally:
        # 4. Always close the writer so pending events are written out.
        writer.close()


# ===================== 8. Evaluation function =====================
def test(model, loader, criterion, device):
    """Evaluate ``model`` on ``loader`` with gradient tracking disabled.

    Args:
        model: the network to evaluate; it is switched to eval mode.
        loader: iterable yielding ``(data, target)`` batches.
        criterion: loss function called as ``criterion(output, target)``;
            assumed to return the batch mean.
        device: device the batches are moved to before the forward pass.

    Returns:
        ``(avg_loss, acc)``: the sample-weighted mean loss and the accuracy
        in percent, computed over the samples actually seen.
        ``(0.0, 0.0)`` when the loader yields no samples.
    """
    model.eval()
    total_loss = 0.0
    correct = 0
    total = 0

    with torch.no_grad():
        for data, target in loader:
            data, target = data.to(device), target.to(device)
            output = model(data)
            n = target.size(0)
            total_loss += criterion(output, target).item() * n
            correct += output.argmax(dim=1).eq(target).sum().item()
            total += n

    if total == 0:
        return 0.0, 0.0
    # Dividing by the samples seen stays correct with samplers or drop_last,
    # where len(loader.dataset) would overcount.
    return total_loss / total, 100.0 * correct / total


# ===================== 9. Run training =====================
if __name__ == "__main__":
    train_and_log()
    print(f"\n训练完成!TensorBoard日志已保存至:{log_dir}")
    # Build the hint from log_dir so it cannot drift from the configured path.
    print(f"启动命令:tensorboard --logdir={log_dir}")
| 操作方法 | 作用 | 网页端对应标签 |
| --- | --- | --- |
| writer.add_scalar() | 记录标量(损失、准确率) | Scalars |
| writer.add_scalars() | 同时记录多个标量(对比训练 / 测试曲线) | Scalars |
| writer.add_image() | 记录单张 / 网格图像(数据集 / 预测结果) | Images |
| writer.add_graph() | 可视化模型计算图 | Graphs |
| writer.add_histogram() | 记录参数 / 梯度的分布(直方图) | Histograms |

@浙大疏锦行

http://www.jsqmd.com/news/478632/

相关文章:

  • 严格控制GOTO语句注意事项
  • 图算法中的边松弛与最短路径更新机制的技术6
  • 先知道“有什么”,再决定“学什么”
  • 2026-3-14 ABC算法题打卡
  • SpringCloud动态路由利器--router4j
  • 2026年毕业论文降AI过审技巧:学姐整理的保姆级攻略
  • 基于MATLAB环境,利用卷积神经网络-长短时记忆网络结合SE注意力机制的数据分类预测模型
  • Altium生成Gerber及CAM350、DFM检查
  • Gorilla项目管理工具:任务跟踪与团队协作API调用实践
  • 如何快速搭建高性能GraphQL服务器:Prisma与GraphQL的完美实战指南
  • {“code“:“40002“,“msg“:“Invalid Arguments“,“sub_code“:“isv.invalid-app-id“,“sub_msg“:“ 无效的AppID参数“}
  • 小爱音响L07A改装AUX血泪史:一根铜丝引发的“血案”与终极救赎
  • 100元打造便携显示器:PocketLCD完整物料清单与采购指南
  • 基于Django技术的建材销售平台(角色:用户、商家、管理员)
  • Git操作的基本命令
  • 3 xgboost
  • Schema.org未来路线图:2026年最新发展计划与功能预览
  • 代码随想录 Day-19(回溯算法)
  • 推荐使用:react-html-email - 优雅的React邮件模板库
  • 探秘 ESCRCPY:一款高效便捷的无线屏幕镜像工具
  • 动态代理详解
  • 通过git上传代码到gitlab(包含第一次上传)小结
  • wow-time时间操作说明
  • Agentic插件系统:扩展平台功能的终极架构设计指南
  • M3U8 在线调试神器!m3u8live.cn让 HLS 流测试更高效
  • HLS 开发必备!详解m3u8live.cn在线播放器的使用与价值
  • 【Index to Lectures or Courses】
  • 如何用代码定义架构:深入探索LikeC4项目
  • WebRTC系列-网络之带宽估计和码率估计(2)接收端带宽估计
  • 如何在Linux终端使用sc-im?新手入门的完整指南