用DCRNN搞定城市交通预测:从论文到PyTorch实战(附METR-LA数据集处理)
用DCRNN实现城市交通预测:从理论到PyTorch工程实践
交通拥堵是现代城市治理的顽疾,而精准的流量预测能为智慧交通系统提供关键决策支持。传统时间序列方法在捕捉复杂空间关联时捉襟见肘,这正是DCRNN(扩散卷积循环神经网络)的突破点——它将图神经网络与循环神经网络融合,开创性地用扩散过程建模交通路网的动态传播效应。本文将以METR-LA数据集为例,手把手带你完成从论文公式到可部署模型的完整实现链路。
1. 环境配置与数据准备
工欲善其事,必先利其器。我们需要搭建支持图神经网络的开发环境:
conda create -n dcrnn python=3.8 conda install pytorch=1.12.0 torchvision cudatoolkit=11.3 -c pytorch pip install torch-geometric scikit-learn pandas matplotlibMETR-LA数据集包含洛杉矶高速公路4个月的车速传感器数据,原始格式需要特殊处理:
- 传感器元数据:207个检测器的经纬度坐标
- 时间序列数据:5分钟间隔的车速记录(单位:mph)
- 时间范围:2012年3月1日至6月30日
使用以下代码加载并可视化数据分布:
import pandas as pd import matplotlib.pyplot as plt # 加载传感器位置 sensors = pd.read_csv('sensor_graph/graph_sensor_locations.csv') plt.scatter(sensors['longitude'], sensors['latitude']) plt.title('METR-LA传感器空间分布')注意:原始数据中的缺失值需用线性插值或相邻传感器均值填充,否则会影响扩散过程建模。
2. 图结构构建与邻接矩阵计算
DCRNN的核心创新在于用扩散卷积替代传统卷积,这要求我们首先定义路网的图表示。基于传感器间距构建带权邻接矩阵:
from sklearn.metrics.pairwise import haversine_distances def build_adjacency_matrix(coords, threshold_km=3): """ 基于haversine距离构建阈值化邻接矩阵 :param coords: (N,2)维度的经纬度数组 :param threshold_km: 连接阈值(公里) :return: 标准化邻接矩阵 """ rad_coords = np.radians(coords) dist_matrix = haversine_distances(rad_coords) * 6371 # 转换为公里 adj_matrix = np.exp(-dist_matrix**2 / threshold_km**2) adj_matrix[dist_matrix > threshold_km] = 0 # 阈值截断 return adj_matrix / adj_matrix.sum(axis=1) # 行归一化关键参数对比:
| 参数 | 典型值 | 影响分析 |
|---|---|---|
| 距离阈值 | 3-5km | 值过小导致图稀疏,过大引入噪声 |
| 衰减系数 | 0.5-1.5 | 控制空间依赖衰减速度 |
| 归一化方式 | 行归一化 | 保证扩散过程稳定性 |
3. DCGRU单元实现详解
DCGRU(Diffusion Convolutional GRU)是DCRNN的核心组件,其在传统GRU中注入扩散卷积操作。以下是PyTorch实现关键步骤:
import torch import torch.nn as nn from torch_geometric.nn import MessagePassing class DiffusionConv(MessagePassing): def __init__(self, in_channels, out_channels, num_diffusions): super().__init__(aggr='add') self.lin = nn.Linear(in_channels, out_channels) self.num_diffusions = num_diffusions def forward(self, x, edge_index, edge_weight): # 前向扩散 h = x for _ in range(self.num_diffusions): h = self.propagate(edge_index, x=h, edge_weight=edge_weight) return self.lin(h) class DCGRUCell(nn.Module): def __init__(self, input_dim, hidden_dim, adj_matrix): super().__init__() self.diff_conv = DiffusionConv(input_dim+hidden_dim, 2*hidden_dim, 2) self.update_gate = nn.Linear(hidden_dim, hidden_dim) def forward(self, x, h_prev, adj): combined = torch.cat([x, h_prev], dim=-1) gates = torch.sigmoid(self.diff_conv(combined, adj)) reset_gate, update_gate = gates.chunk(2, dim=-1) h_candidate = torch.tanh(self.update_gate(reset_gate * h_prev)) h_new = (1 - update_gate) * h_prev + update_gate * h_candidate return h_new训练时采用计划采样(Scheduled Sampling)策略缓解自回归误差累积:
def scheduled_sampling(epoch, max_epochs): """线性衰减的教师强制比率""" epsilon = max(0.05, 1.0 - epoch / max_epochs) return epsilon4. 完整模型训练与调优
组装完整的DCRNN模型并进行端到端训练:
class DCRNN(nn.Module): def __init__(self, adj_matrix, input_dim=1, hidden_dim=64): super().__init__() self.encoder = nn.ModuleList([DCGRUCell(input_dim, hidden_dim, adj_matrix)]) self.decoder = nn.ModuleList([DCGRUCell(input_dim, hidden_dim, adj_matrix)]) self.projection = nn.Linear(hidden_dim, input_dim) def forward(self, x, y_true, teacher_forcing_ratio): # 编码器处理历史序列 h = torch.zeros(x.size(0), self.hidden_dim).to(x.device) for t in range(x.size(1)): h = self.encoder[0](x[:,t], h) # 解码器多步预测 outputs = [] input = x[:,-1] # 最后一步作为解码器初始输入 for t in range(y_true.size(1)): h = self.decoder[0](input, h) output = self.projection(h) outputs.append(output) # 计划采样决定下一时刻输入 if torch.rand(1) < teacher_forcing_ratio: input = y_true[:,t] else: input = output return torch.stack(outputs, dim=1)训练过程中的关键监控指标:
| 指标 | 健康范围 | 异常处理 |
|---|---|---|
| 训练损失 | 稳定下降 | 检查梯度裁剪 |
| 验证MAE | <3.5 | 调整学习率 |
| 过拟合gap | <15% | 增加Dropout |
使用Adam优化器时推荐初始参数:
optimizer = torch.optim.Adam(model.parameters(), lr=0.01, weight_decay=1e-4) scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, patience=5)5. 实战效果分析与部署建议
在METR-LA测试集上的典型表现(预测 horizon=15分钟):
| 模型 | MAE | RMSE | 训练时间/epoch |
|---|---|---|---|
| HA | 4.16 | 7.80 | - |
| ARIMA | 3.99 | 8.21 | - |
| DCRNN | 2.77 | 5.38 | 2.3min |
可视化预测效果时,重点关注以下异常模式:
def plot_prediction(true, pred, sensor_idx): plt.figure(figsize=(12,4)) plt.plot(true[:,sensor_idx], label='Ground Truth') plt.plot(pred[:,sensor_idx], '--', label='DCRNN Prediction') plt.legend() plt.xlabel('Time steps (5min)') plt.ylabel('Speed (mph)')实际部署时建议:
- 使用TorchScript将模型转换为生产环境可用的格式
- 对输入数据实施在线标准化(保留训练集的均值和方差)
- 设置异常值过滤器(如车速>100mph视为传感器故障)
我在实际项目中发现,将DCRNN与简单的规则引擎结合(如特殊天气事件处理),能进一步提升复杂场景下的鲁棒性。模型对传感器故障具有较好的容错能力,但当超过30%的节点数据缺失时,建议触发人工干预流程。
