当前位置: 首页 > news >正文

基于时频谱图特征提取和改进型UNet卷积神经网络的机械故障诊断(Pytorch)

首先,将原始一维振动信号通过短时傅里叶变换转换为时频谱图,形成二维图像特征;接着构建了一个改进的UNet神经网络架构,该网络在保留UNet编码器-解码器结构的基础上移除了时间嵌入模块,增加了注意力机制和残差连接,专门用于谱图特征提取和分类;然后采用数据增强技术扩充训练样本,通过分层抽样划分数据集;模型训练阶段使用交叉熵损失函数、Adam优化器和学习率调度策略,并在验证集上监控性能保存最佳模型;最后在测试集上评估模型性能,计算准确率、混淆矩阵等指标,并可视化训练过程、预测结果和诊断报告,实现对轴承正常、滚珠故障、内圈故障和外圈故障四种状态的准确分类诊断。

import numpy as np import pandas as pd import matplotlib.pyplot as plt from pathlib import Path import warnings warnings.filterwarnings('ignore') import torch import torch.nn as nn from torch.utils.data import Dataset, DataLoader from torch.nn import functional as F import torch.optim as optim from torch.optim.lr_scheduler import ReduceLROnPlateau import math from sklearn.metrics import confusion_matrix, classification_report from sklearn.model_selection import train_test_split import seaborn as sns from scipy import signal import scipy.stats as stats # 设置随机种子确保可重复性 def set_seed(seed=42): torch.manual_seed(seed) torch.cuda.manual_seed_all(seed) if torch.cuda.is_available() else None np.random.seed(seed) torch.backends.cudnn.deterministic = True torch.backends.cudnn.benchmark = False set_seed(42) # 定义设备 device = torch.device("cuda" if torch.cuda.is_available() else "cpu") print(f"Using device: {device}") # ==================== 原始UNet组件 ==================== class Swish(nn.Module): """Swish激活函数""" def forward(self, x): return x * torch.sigmoid(x) class TimeEmbedding(nn.Module): """时间嵌入层(在原始UNet中用于扩散模型,这里保留但会简化)""" def __init__(self, T, d_model, dim): assert d_model % 2 == 0 super().__init__() emb = torch.arange(0, d_model, step=2) / d_model * math.log(10000) emb = torch.exp(-emb) pos = torch.arange(T).float() emb = pos[:, None] * emb[None, :] assert list(emb.shape) == [T, d_model // 2] emb = torch.stack([torch.sin(emb), torch.cos(emb)], dim=-1) assert list(emb.shape) == [T, d_model // 2, 2] emb = emb.view(T, d_model) self.timembedding = nn.Sequential( nn.Embedding.from_pretrained(emb), nn.Linear(d_model, dim), Swish(), nn.Linear(dim, dim), ) self.initialize() def initialize(self): """初始化权重""" for module in self.modules(): if isinstance(module, nn.Linear): nn.init.xavier_uniform_(module.weight) nn.init.zeros_(module.bias) def forward(self, t): emb = self.timembedding(t) return emb

详细算法步骤

数据采集与预处理阶段:从西储大学轴承数据集中加载四种不同故障状态的原始振动信号文件,对每一类信号进行标准化处理消除量纲影响,然后将长时信号按照固定长度和重叠率分割成多个短时信号片段,为后续分析准备基础数据单元。

特征工程构建阶段:对每个信号片段应用短时傅里叶变换算法,将一维时域信号转换为二维时频谱图,通过对数变换增强特征对比度,再统一缩放至固定尺寸形成标准化图像特征,同时提取信号的时域统计特征和频域特征作为补充信息。

数据集划分阶段:采用分层抽样策略将特征数据集按比例划分为训练集、验证集和测试集,确保每个子集中各类别样本分布均衡,避免因数据划分不当导致的模型评估偏差。

数据增强处理阶段:在训练集上应用随机噪声添加和随机幅度缩放等数据增强技术,人工扩充训练样本多样性,提高模型对实际工况变化的适应能力和泛化性能。

神经网络模型构建阶段:设计改进型UNet分类网络架构,保留编码器-解码器对称结构用于多尺度特征提取,引入注意力机制增强关键特征识别能力,使用残差连接缓解梯度消失问题,最后接入全局平均池化层和全连接分类器输出故障类别概率。

模型训练优化阶段:初始化网络权重参数,设置交叉熵损失函数和自适应矩估计优化器,采用动态学习率调整策略,在训练过程中实施梯度裁剪防止梯度爆炸,通过前向传播计算预测输出,反向传播更新网络参数,迭代优化模型性能。

模型验证与选择阶段:在独立验证集上定期评估模型表现,监控验证损失和分类准确率变化趋势,保存验证性能最优的模型权重,避免过拟合现象发生,确保模型具备良好泛化能力。

模型测试评估阶段:加载最佳模型权重,在未见过的测试集上进行全面性能评估,计算总体分类准确率、每类故障的精确率和召回率,生成混淆矩阵可视化分类错误分布,定量分析模型诊断能力。

参考文章:

基于时频谱图特征提取和改进型UNet卷积神经网络的机械故障诊断(Pytorch) - 哥廷根数学学派的文章 -
https://zhuanlan.zhihu.com/p/1998402980043583749

http://www.jsqmd.com/news/294627/

相关文章:

  • 基于贝叶斯物理信息神经网络的工业装备退化趋势预测方法(Pytorch)
  • 基于图拉普拉斯正则化物理信息神经网络的工业装备退化趋势预测方法(Pytorch)
  • 基于可学习Morlet小波匹配滤波和统计特征融合的引力波信号检测算法(算法完善中,Python)
  • 基于点堆动力学-热传递耦合物理模型与支持向量机残差分析的核反应堆数字孪生混合异常检测算法(以模拟信号为例,Python)
  • 基于多阶段参数辨识与蒙特卡洛不确定性传播的质子交换膜水电解槽电压退化预测和预后地平线评估集成算法(Python)
  • 基于希尔伯特变换与带通滤波的滚动轴承振动信号包络谱故障诊断算法(Python,jupyter nootbook文件)
  • 最小生成树专题
  • 1月24号
  • 别再二选一了:高手都在用的微调+RAG混合策略,今天一次讲透
  • 导师严选9个一键生成论文工具,研究生论文写作必备!
  • samp-cef 解决客户端显示服务端传回数据乱码问题
  • 高中学习机深度测评:告别智商税!热门机型实测对比
  • 【开题答辩全过程】以 某县农村留守儿童爱心帮扶平台为例,包含答辩的问题和答案
  • Day28-20260124
  • America has been dead!
  • 冲刺Day5
  • JavaScript 中 ||(逻辑或)和 (逻辑与)
  • 数据结构——三十九、顺序查找(王道408) - 指南
  • NVIDIA GPU 系列用途分类梳理
  • PADS Layout 添加板宽圆角
  • 亲测好用!8款AI论文软件测评:研究生开题报告必备工具
  • 百度文库与网盘重组新事业群,向李彦宏汇报,压力之下的改革能不能成?
  • 排列组合专题
  • 数字化转型下零售门店管理软件的功能与选择考量
  • 闲鱼开店不用愁!自动回复 + 远程管理,随时随地搞定买家咨询就靠cpolar
  • JBoltAI网关:Java企业级AI的稳定“交通枢纽”
  • 连锁门店数字化平台核心功能与适用场景解析
  • 技术已到位,失业潮为何还未爆发?决策层的认知盲区才是真正的“缓冲带”
  • [Android] vFlow v1.4.0 可视化工作流自动化工具
  • [Windows] WeFlow v1.3.1-V信聊天记录浏览、导出