当前位置: 首页 > news >正文

红外火情时序预判 CNN-LSTM 模型

基于 CNN-LSTM 的阴燃蔓延时序预判,实现从 “事后识别火情” 升级为 “提前预判高温扩张风险”。整套流程分为:视频时序样本构建、CNN-LSTM 模型训练、ONNX 模型导出、嵌入式部署四大环节。

  • 数据来源:红外摄像头录制1m 30s 分段 AVI 火情视频;
  • 样本构造:滑动窗口截取连续 8 帧红外图像为输入序列,提取未来帧高温特征作为回归标签,划分安全 / 阴燃扩张 / 明火三类风险标签;
  • 标签提取:对每组序列,提取未来帧的高温像素面积、图像平均灰度、高温区域中心坐标作为回归标签;同时标注 0 = 无扩张风险、1 = 阴燃持续扩张、2 = 即将出现明火三分类标签;
  • 数据划分:训练集 80%、验证集 20%,统一归一化至 [0,1],适配模型输入。
  • 数据集规格:共 8 个.npz时序文件,数据集硬盘占用 12GB;
  • 训练硬件:笔记本 CPU,内存有限,无法一次性加载全部数据。

初始代码一次性读取全部 npz 文件,加载数组时 numpy 直接申请 4GB 以上内存,程序初始化阶段直接崩溃:

MemoryError: Unable to allocate 4.17 GiB for an array with shape (559153152,) and data type float64

摒弃全局一次性加载,改为单文件分批训练不再扫描全部文件构建全局索引,循环逐个读取 npz,训练完成立刻关闭文件释放内存,同一时间仅占用单个文件内存;

import os import torch import torch.nn as nn import torch.optim as optim import numpy as np from torch.utils.data import Dataset, DataLoader device = torch.device("cuda" if torch.cuda.is_available() else "cpu") SEQ_LENGTH = 8 H, W = 192, 256 WEIGHT_SAVE = "temp_weight.pth" class FireSingleFileDataset(Dataset): def __init__(self, file_path): print(f"Load single file: {os.path.basename(file_path)}") self.data = np.load(file_path, mmap_mode="r") self.seq_arr = self.data["seq"] self.label_arr = self.data["label"] self.total = len(self.seq_arr) print(f"Samples in this file: {self.total}") def __len__(self): return self.total def __getitem__(self, idx): seq = self.seq_arr[idx] label = self.label_arr[idx] seq = torch.from_numpy(seq).permute(0, 3, 1, 2).float() label = torch.from_numpy(label).float() return seq, label def close(self): self.data.close() class ConvBlock(nn.Module): def __init__(self, in_c, out_c): super().__init__() self.block = nn.Sequential( nn.Conv2d(in_c, out_c, 3, padding=1), nn.ReLU(), nn.MaxPool2d(2) ) def forward(self, x): return self.block(x) class CNNLSTM(nn.Module): def __init__(self): super().__init__() self.cnn = nn.Sequential( ConvBlock(3, 16), ConvBlock(16, 32), ConvBlock(32, 64) ) self.lstm = nn.LSTM(input_size=64*24*32, hidden_size=128, batch_first=True) self.reg_head = nn.Linear(128, 4) self.cls_head = nn.Sequential(nn.Linear(128, 64), nn.ReLU(), nn.Linear(64, 3)) def forward(self, x): b, seq, c, h, w = x.shape feat_list = [] for t in range(seq): feat = self.cnn(x[:, t]) feat_list.append(feat.flatten(1)) seq_feat = torch.stack(feat_list, dim=1) lstm_out, _ = self.lstm(seq_feat) last_feat = lstm_out[:, -1, :] reg_out = self.reg_head(last_feat) cls_out = self.cls_head(last_feat) return reg_out, cls_out def train_one_file(model, file_path, epoch_num=1): dataset = FireSingleFileDataset(file_path) loader = DataLoader(dataset, batch_size=1, shuffle=True, num_workers=0) opt = optim.Adam(model.parameters(), lr=1e-4) loss_mse = nn.MSELoss() loss_ce = nn.CrossEntropyLoss() print(f"===== Start train file {os.path.basename(file_path)} =====") for e in range(epoch_num): total_loss = 0.0 for seq, label in loader: seq, label = seq.to(device), label.to(device) pred_reg, pred_cls = model(seq) loss1 = loss_mse(pred_reg, label) cls_label = torch.clamp(torch.round(label[:, 0]).long(), 0, 2) loss2 = loss_ce(pred_cls, cls_label) loss = loss1 + loss2 opt.zero_grad() loss.backward() opt.step() total_loss += loss.item() avg_loss = total_loss / len(loader) print(f"Epoch {e+1}/{epoch_num}, Loss: {avg_loss:.4f}") dataset.close() torch.save(model.state_dict(), WEIGHT_SAVE) print(f"Temporary weight saved to {WEIGHT_SAVE}") if __name__ == "__main__": print("===== Start Initializing CNN-LSTM Model =====") model = CNNLSTM().to(device) print(f"Model loaded on device: {device}") dataset_path = r"D:\图像处理\venv\train_seq" all_files = [os.path.join(dataset_path, f) for f in os.listdir(dataset_path) if f.endswith(".npz")] print(f"Total npz files: {len(all_files)}") if os.path.exists(WEIGHT_SAVE): print(f"Load previous weight from {WEIGHT_SAVE}") model.load_state_dict(torch.load(WEIGHT_SAVE, map_location=device)) for f in all_files: train_one_file(model, f, epoch_num=1) print("Release memory, prepare next file\n") print("===== All files training complete, export ONNX Model =====") model.eval() dummy = torch.randn(1, SEQ_LENGTH, 3, H, W).to(device) save_name = "fire_predict.onnx" full_save_path = os.path.abspath(save_name) torch.onnx.export(model, dummy, full_save_path, opset_version=12) print(f"ONNX model saved successfully! Path: {full_save_path}") print("Upload fire_predict.onnx to RK3568 /home via MobaXterm")

  • 每次仅打开单个 npz,训练结束调用close()释放内存,杜绝多文件同时占用内存;
  • 每个文件训练完成保存临时权重,下次运行可加载权重接续训练,无需从头开始;
  • CNN 提取单帧红外热区空间特征,LSTM 学习时序扩张规律,同时输出热区回归预测值 + 火情风险三分类结果。
http://www.jsqmd.com/news/1132177/

相关文章:

  • 多模态AI Agent在内容生成领域的研究进展综述
  • 3大核心功能彻底解决Android存储空间不足问题:SD Maid SE深度清理指南
  • 《怪物猎人:荒野》 豪华中文版 全DLC VBS一键启狩猎
  • 开源中文字体的终极解决方案:思源宋体专业设计指南
  • 可视化 vs 终端 vs 云端:VTJ.PRO、Claude Code、Codex 三强横评
  • AI编程助手会“分期付款”藏毒?实验:65%攻击绕过了监控
  • 【Python环境】从零解读PyCharm项目结构:虚拟环境、外部库与uv包管理器
  • DQN 高估问题深度解析:3 种成因与双 DQN 等 2 类解决方案对比
  • 沃尔安智能摄像机删除后的恢复方法
  • 郑州高口碑黄金回收白银回收
  • 超参数调优进阶:Optuna/Bayesian/Early Stopping
  • 出行和货运行业正在被智能体重塑,效率提升超过50%
  • PCB布线设计 2025:3W/20H/蛇形线等5大核心规则实战解析与量化验证
  • C++ 线程安全日志系统:策略模式解耦输出端,RAII 实现 glog 风格流式日志
  • 集成隔离电源的RS-485/RS-422收发器:PCB拼接电容设计实战与EMC优化
  • nlpconnect/vit-gpt2-image-captioning 超详细入门解析
  • Java---牛客的ACM模式被卡输入输出时间,如何解决?一个模版即可解决
  • AI 音频生成流水线:异步任务要有进度和取消
  • 基于社交图谱的校园活动与交友系统(SpringBoot + Neo4j + UniApp)
  • 舟山高口碑黄金回收白银回收
  • 2025黑科技!加持会议任务提醒,快准稳颠覆你的认知?
  • Flutter 开发鸿蒙实战:Windows 环境下从 HAP 构建到四 Tab 页面运行
  • MT7621 PCIe WiFi 驱动移植:从 5.4 内核到 OpenWrt 22.03 的 3 个关键步骤
  • 对比聚类 (Contrastive Clustering) 与 SimCLR 深度对比:3 个核心差异与 2 个应用场景分析
  • C++26 std::inplace_vector 详解:零堆分配的定容向量
  • C++26 std::chrono 哈希与 SI 词头详解
  • Want 参数安全:类型、边界、异常兜底怎么写
  • 机器学习系统设计:从原型到生产
  • 开始委托之旅 委托与接口
  • 张掖口碑黄金铂金回收白银回收实体老店