用PyTorch复现AirFormer:手把手教你搭建空气质量预测Transformer(附代码)
用PyTorch复现AirFormer:手把手教你搭建空气质量预测Transformer(附代码)
空气质量预测一直是环境科学和机器学习交叉领域的重要课题。传统方法往往受限于局部特征提取能力不足或计算复杂度高的问题,而Transformer架构凭借其强大的全局建模能力,正在这一领域展现出独特优势。今天我们要实现的AirFormer模型,通过创新的DS-MSA和CT-MSA机制,在保持线性计算复杂度的同时,实现了对全国范围内数千个监测站点的精准预测。
1. 环境准备与数据预处理
在开始构建模型前,我们需要配置合适的开发环境。推荐使用Python 3.8+和PyTorch 1.10+版本,这些版本在自注意力机制实现上具有更好的优化。以下是关键依赖的安装命令:
pip install torch torchvision torchaudio pip install pandas scikit-learn matplotlib空气质量数据集通常包含多种污染物指标(PM2.5、SO2、NO2等)和气象数据(温度、湿度、风速等)。我们需要对这些数据进行标准化处理:
from sklearn.preprocessing import StandardScaler def preprocess_data(data): # 处理缺失值 data = data.interpolate(method='linear', limit_direction='both') # 标准化特征 scaler = StandardScaler() scaled_data = scaler.fit_transform(data) # 构建时空序列样本 seq_length = 24 # 使用24小时历史数据 X, y = [], [] for i in range(len(data)-seq_length-1): X.append(scaled_data[i:i+seq_length]) y.append(scaled_data[i+seq_length]) return np.array(X), np.array(y), scaler关键预处理步骤:
- 时间对齐:确保所有监测站点的数据时间戳一致
- 空间编码:为每个站点生成经纬度特征
- 特征工程:添加星期几、节假日等时间特征
注意:实际应用中建议使用滑动窗口验证来评估模型性能,避免数据泄露问题。
2. 模型架构设计
AirFormer的核心创新在于其双阶段设计:自下而上的确定性阶段和自上而下的随机阶段。我们先来看模型的主体结构:
import torch.nn as nn class AirFormer(nn.Module): def __init__(self, num_stations, feature_dim, num_heads=8, num_layers=6): super().__init__() self.embedding = nn.Linear(feature_dim, 64) # 确定性阶段 self.deterministic_layers = nn.ModuleList([ AirFormerBlock(64, num_heads) for _ in range(num_layers) ]) # 随机阶段 self.stochastic_layers = nn.ModuleList([ StochasticBlock(64) for _ in range(num_layers) ]) self.output_layer = nn.Linear(64, feature_dim) def forward(self, x): # x形状: (batch, time, stations, features) x = self.embedding(x) # 确定性阶段处理 deterministic_states = [] for layer in self.deterministic_layers: x = layer(x) deterministic_states.append(x) # 随机阶段处理 predictions = [] for t in range(x.size(1)): z = torch.randn_like(x[:,t]) # 潜在变量 for l, layer in enumerate(self.stochastic_layers): z = layer(z, deterministic_states[l][:,t]) predictions.append(self.output_layer(z)) return torch.stack(predictions, dim=1)2.1 Dartboard空间注意力(DS-MSA)
DS-MSA是AirFormer的关键创新之一,它通过dartboard映射将计算复杂度从O(N²)降低到O(N):
class DS_MSA(nn.Module): def __init__(self, dim, num_heads, region_size=25): super().__init__() self.num_heads = num_heads self.region_size = region_size self.qkv = nn.Linear(dim, dim*3) self.proj = nn.Linear(dim, dim) # Dartboard映射矩阵 self.dartboard = self._init_dartboard() def _init_dartboard(self): # 实现同心圆区域划分逻辑 # 返回形状为(region_size, num_stations)的映射矩阵 ... def forward(self, x): B, T, N, C = x.shape qkv = self.qkv(x).reshape(B, T, N, 3, self.num_heads, C//self.num_heads) q, k, v = qkv.unbind(3) # 形状均为(B,T,N,num_heads,C/num_heads) # Dartboard映射 k_region = torch.einsum('rsn,btnhc->btrshc', self.dartboard, k) v_region = torch.einsum('rsn,btnhc->btrshc', self.dartboard, v) # 注意力计算 attn = torch.einsum('btnhc,btrshc->btrshn', q, k_region) / (C**0.5) attn = attn.softmax(dim=-2) out = torch.einsum('btrshn,btrshc->btnhc', attn, v_region) out = out.reshape(B, T, N, C) return self.proj(out)DS-MSA的优势:
- 线性复杂度:通过区域聚合减少计算量
- 空间感知:自动学习邻近站点的更强相关性
- 可解释性:注意力权重反映真实的空间影响模式
2.2 因果时间注意力(CT-MSA)
CT-MSA通过局部窗口和逐步扩大的感受野来高效捕获时间依赖性:
class CT_MSA(nn.Module): def __init__(self, dim, num_heads, window_sizes=[3,5,7]): super().__init__() self.num_heads = num_heads self.window_sizes = window_sizes self.qkv = nn.Linear(dim, dim*3) self.proj = nn.Linear(dim, dim) def forward(self, x): B, T, N, C = x.shape qkv = self.qkv(x).reshape(B, T, N, 3, self.num_heads, C//self.num_heads) q, k, v = qkv.unbind(3) outputs = [] for t in range(T): # 逐步扩大的时间窗口 window = self.window_sizes[min(t//10, len(self.window_sizes)-1)] start = max(0, t-window) # 局部注意力计算 q_t = q[:,t] # (B,N,num_heads,C/num_heads) k_window = k[:,start:t+1] # (B,t-start+1,N,num_heads,C/num_heads) v_window = v[:,start:t+1] attn = torch.einsum('bnhc,btnhc->btnh', q_t, k_window) / (C**0.5) attn = attn.softmax(dim=1) out = torch.einsum('btnh,btnhc->bnhc', attn, v_window) outputs.append(out) out = torch.stack(outputs, dim=1).reshape(B, T, N, C) return self.proj(out)3. 模型训练与优化
AirFormer的训练需要同时优化确定性预测损失和随机阶段的ELBO目标:
def train(model, dataloader, optimizer, epoch): model.train() total_loss = 0 for batch_idx, (x, y) in enumerate(dataloader): optimizer.zero_grad() # 前向传播 pred = model(x) # 确定性损失 deterministic_loss = F.l1_loss(pred, y) # 随机阶段ELBO kl_loss = model.get_kl_loss() # 实现潜在变量的KL散度计算 reconstruction_loss = F.mse_loss(pred, y) elbo = reconstruction_loss + kl_loss # 组合损失 loss = deterministic_loss + 0.5 * elbo # 反向传播 loss.backward() optimizer.step() total_loss += loss.item() print(f'Epoch {epoch} Loss: {total_loss/len(dataloader):.4f}')训练技巧:
- 学习率预热:前5个epoch线性增加学习率
- 梯度裁剪:防止Transformer训练不稳定
- 混合精度训练:使用torch.cuda.amp加速训练
提示:实际训练时建议使用学习率调度器,如CosineAnnealingLR
4. 结果可视化与分析
训练完成后,我们需要评估模型性能并进行结果可视化:
def plot_predictions(true, pred, station_idx=0, feature_idx=0): plt.figure(figsize=(12, 6)) plt.plot(true[:, station_idx, feature_idx], label='True') plt.plot(pred[:, station_idx, feature_idx], label='Predicted', alpha=0.7) plt.title(f'Station {station_idx} - Feature {feature_idx}') plt.legend() plt.show() # 计算评估指标 def evaluate(model, dataloader): model.eval() mae, rmse = 0, 0 with torch.no_grad(): for x, y in dataloader: pred = model(x) mae += F.l1_loss(pred, y).item() rmse += torch.sqrt(F.mse_loss(pred, y)).item() print(f'MAE: {mae/len(dataloader):.4f}, RMSE: {rmse/len(dataloader):.4f}')性能优化方向:
- 注意力头数调整:通常4-8个头效果最佳
- 区域大小优化:根据实际空间分布调整dartboard区域
- 潜在变量维度:影响模型捕捉不确定性的能力
5. 实际部署建议
将训练好的AirFormer模型投入实际应用时,有几个关键考虑因素:
- 实时预测优化:
class OptimizedAirFormer(AirFormer): def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) self.cache = {} # 用于存储中间计算结果 def predict_next(self, new_observation): # 增量式预测,利用缓存避免重复计算 if not self.cache: # 初始化缓存 ... else: # 增量更新 ... return prediction- 模型量化与加速:
# 使用TorchScript导出模型 torch.jit.script(model).save("airformer_quantized.pt")- 持续学习机制:
def online_update(model, new_data, optimizer, steps=100): # 小批量在线学习 for _ in range(steps): loss = model(new_data) optimizer.zero_grad() loss.backward() optimizer.step()在实际项目中,我们发现几个关键经验:首先,DS-MSA的区域划分需要根据监测站点的实际地理分布进行调整;其次,模型对风速等气象特征的依赖性较强,需要确保这些数据的质量;最后,随机阶段的潜在变量维度不宜过大,否则会导致训练不稳定。
