当前位置: 首页 > news >正文

从‘挖土填土’到最优传输:用Python和POT库5分钟上手Wasserstein距离计算

从‘挖土填土’到最优传输:用Python和POT库5分钟上手Wasserstein距离计算

在数据科学和机器学习领域,衡量两个概率分布之间的差异是一个基础而关键的问题。无论是评估生成模型的输出质量,还是检测数据漂移,选择合适的距离度量方法都直接影响结果的可靠性。传统方法如KL散度和JS散度虽然计算高效,但在处理无重叠分布时存在明显缺陷——KL散度可能无限大,JS散度则会产生突变。这时,Wasserstein距离(又称推土机距离)展现出独特优势:即使分布完全不重叠,它仍能提供有意义的平滑度量结果。

本文将带你快速掌握Wasserstein距离的核心思想,并通过Python的POT库实现高效计算。我们不会深入复杂的数学推导,而是从直观的"挖土填土"比喻出发,让你在5分钟内获得可直接应用于实际项目的代码能力。

1. 环境准备与数据生成

1.1 安装POT库

POT(Python Optimal Transport)是当前最成熟的最优传输Python库,支持CPU和GPU加速。安装只需一行命令:

pip install pot

同时确保已安装以下依赖库:

  • NumPy ≥ 1.16
  • SciPy ≥ 1.0
  • Matplotlib(用于可视化)

1.2 生成示例数据

我们模拟两个客户群体的特征分布:群体A年龄集中在20-30岁,收入呈正态分布;群体B年龄偏大(30-40岁),收入分布更分散:

import numpy as np # 设置随机种子保证可复现 np.random.seed(42) # 生成群体A的100个样本 age_A = np.random.uniform(20, 30, 100) income_A = np.random.normal(loc=5000, scale=1000, size=100) # 生成群体B的100个样本 age_B = np.random.uniform(30, 40, 100) income_B = np.random.normal(loc=6000, scale=1500, size=100) # 合并特征 X_A = np.column_stack((age_A, income_A)) X_B = np.column_stack((age_B, income_B))

2. Wasserstein距离计算实战

2.1 距离矩阵构建

计算Wasserstein距离的第一步是定义样本间的"移动成本"。对于我们的二维特征空间,使用欧氏距离作为基础度量:

from scipy.spatial import distance_matrix # 计算所有样本对之间的距离 M = distance_matrix(X_A, X_B) # 归一化到[0,1]范围(可选) M /= M.max()

2.2 使用POT计算精确距离

POT库提供了emd2函数直接计算Wasserstein距离:

from ot import emd2 # 均匀权重(假设每个样本权重相同) a = np.ones(len(X_A)) / len(X_A) b = np.ones(len(X_B)) / len(X_B) # 计算Wasserstein距离 w_dist = emd2(a, b, M) print(f"Wasserstein距离: {w_dist:.4f}")

注意:当样本量较大(>1000)时,考虑使用ot.sinkhorn2近似计算以提升性能

2.3 可视化传输计划

理解"挖土填土"过程最直观的方式是可视化最优传输计划:

import matplotlib.pyplot as plt from ot import emd # 计算传输计划 G = emd(a, b, M) plt.figure(figsize=(10, 5)) plt.scatter(X_A[:,0], X_A[:,1], label='群体A', alpha=0.7) plt.scatter(X_B[:,0], X_B[:,1], label='群体B', alpha=0.7) # 绘制传输量最大的前20条连接 indices = np.argsort(G.ravel())[-20:] for i in indices: row, col = np.unravel_index(i, G.shape) plt.plot([X_A[row,0], X_B[col,0]], [X_A[row,1], X_B[col,1]], 'k-', alpha=0.3, linewidth=G[row,col]*50) plt.legend() plt.xlabel('年龄') plt.ylabel('收入') plt.title('最优传输计划可视化') plt.show()

3. 与传统散度方法的对比

3.1 KL散度与JS散度实现

使用SciPy计算传统散度作为基准:

from scipy.stats import entropy from sklearn.neighbors import KernelDensity # 核密度估计 kde_A = KernelDensity(bandwidth=1.0).fit(X_A) kde_B = KernelDensity(bandwidth=1.0).fit(X_B) # 在网格点上评估概率 grid = np.mgrid[20:40:100j, 3000:8000:100j] points = np.vstack([grid[0].ravel(), grid[1].ravel()]).T log_p_A = kde_A.score_samples(points) log_p_B = kde_B.score_samples(points) p_A = np.exp(log_p_A) p_B = np.exp(log_p_B) # 计算KL散度(非对称) kl_div = entropy(p_A, p_B) # 计算JS散度(对称) m = 0.5 * (p_A + p_B) js_div = 0.5 * (entropy(p_A, m) + entropy(p_B, m)) print(f"KL散度: {kl_div:.4f}") print(f"JS散度: {js_div:.4f}")

3.2 结果对比分析

将三种度量结果整理如下表:

度量方法计算值计算时间(ms)重叠敏感度
Wasserstein距离1.243715.2
KL散度42.8
JS散度0.693143.5

关键发现:

  • 当分布重叠区域很小时,KL散度趋向无穷大,完全失去区分能力
  • JS散度饱和到log(2),无法反映分布间的实际距离变化
  • Wasserstein距离始终提供有意义的数值,且计算效率最高

4. 高级应用与优化技巧

4.1 处理大规模数据集

对于超过10,000个样本的情况,使用熵正则化的Sinkhorn算法:

from ot import sinkhorn2 # 使用Sinkhorn近似计算 reg = 0.1 # 正则化系数 w_dist_approx = sinkhorn2(a, b, M, reg=reg)[0] print(f"近似Wasserstein距离: {w_dist_approx:.4f}")

4.2 自动超参数选择

POT库提供了自动选择最佳正则化参数的工具:

from ot import tune_regularization best_reg = tune_regularization(a, b, M) print(f"最优正则化参数: {best_reg:.4f}")

4.3 GPU加速计算

对于超大规模数据,启用CUDA加速:

import torch import ot.gpu # 将数据转移到GPU M_gpu = torch.from_numpy(M).cuda() # GPU加速计算 w_dist_gpu = ot.gpu.emd2(torch.from_numpy(a).cuda(), torch.from_numpy(b).cuda(), M_gpu) print(f"GPU计算结果: {w_dist_gpu:.4f}")

5. 实际应用场景解析

5.1 生成模型评估

在训练GAN或VAE时,Wasserstein距离可直接作为损失函数:

def wasserstein_loss(real_samples, generated_samples): M = distance_matrix(real_samples, generated_samples) a = np.ones(len(real_samples)) / len(real_samples) b = np.ones(len(generated_samples)) / len(generated_samples) return emd2(a, b, M)

5.2 数据漂移检测

监控生产环境中的数据分布变化:

def detect_drift(reference_data, new_data, threshold=0.5): M = distance_matrix(reference_data, new_data) a = np.ones(len(reference_data)) / len(reference_data) b = np.ones(len(new_data)) / len(new_data) dist = emd2(a, b, M) return dist > threshold

5.3 特征匹配与领域适应

对齐不同来源的数据分布:

from ot.da import sinkhorn_lpl1_mm # 源领域和目标领域数据 Xs = X_A # 源数据 Xt = X_B # 目标数据 # 计算领域适应映射 transp_Xs = sinkhorn_lpl1_mm(Xs, Xt, reg=0.1)
http://www.jsqmd.com/news/893969/

相关文章:

  • 基于深度学习的石油泄漏检测系统(YOLOv8+YOLO数据集+UI界面+Python项目+模型)
  • 告别杂乱,家庭管理一站式解决!用NAS自建家庭规划中心『Oikos』
  • 多Agent虚拟开发:构造功能设想与开发方案(一)
  • A51汇编器行号偏移问题解析与调试优化
  • AI Agent Harness Engineering 的并发控制:多任务同时执行的挑战
  • GD32F407硬件IIC从机模式实战:从官方源码到项目移植的避坑指南
  • 基于粒子群和二进制遗传算法的热电联产经济调度研究附Python代码
  • 命令行终端正在被重写
  • 手把手教你用立创GD32E230开发板实现按键控制LED(GPIO输入输出实战)
  • 住宅 IP 和机房 IP 有什么区别?跨境账号为什么不能只看 IP 国家
  • 用STM32F103C8T6做个桌面小钢炮:0-30V/1.5A数控电源DIY全记录(附源码与PCB)
  • 城市内涝反.复?高精度电子水尺传感器精准监测积水深
  • 从零开始:Hello World 标准 Skill 入门教程
  • 2026年Q2水玻璃厂家联系方式:水玻璃哪个厂家好/水玻璃多少钱一吨/水玻璃批发厂家/水玻璃报价/水玻璃生产厂/选择指南 - 优质品牌商家
  • 【热力学】稳态与瞬态二维热传导的有限差分分析Matlab仿真
  • Win10/Win11系统版本兼容性实测:eNSP搭配VirtualBox 5.2.26如何避开AR 40错误?
  • 告别手动发送!用Python脚本自动化你的Proteus串口仿真测试(STM32篇)
  • LM741反相放大器设计避坑指南:电源、电阻选型与失真问题全解析
  • 2026年中大力德一级授权代理商TOP5权威排行:广州LED驱动电源/广州减速电机/广州工业类开关电源/广州机壳电源/选择指南 - 优质品牌商家
  • PX4Ctrl起飞逻辑深度解析:get_rotor_speed_up_des函数里的6.0和7.0参数到底怎么调?
  • 2026水玻璃标杆厂家盘点:四川硅溶胶厂家推荐、四川硅溶胶厂家电话、四川硅溶胶厂家联系方式、新昂水玻璃厂家联系方式选择指南 - 优质品牌商家
  • SpringBoot实战:三种主流CORS跨域配置方案详解与选型
  • IMXRT开发板SWO跟踪配置与调试指南
  • 保姆级教程:手把手教你安装配置Ultimaker Cura 4.8中文版(Win系统)
  • 别再乱焊了!HC-SR501人体感应模块的光敏电阻,实测告诉你到底该用多大的(附计算方法和串联技巧)
  • 【PFJSP问题】基于自适应双种群协同鸡群算法ADPCCSO求解置换流水车间调度问题PFSP附Matlab代码
  • 2026乐山临江鳝丝TOP5门店排行:乐山跷脚牛肉店有哪些、乐山跷脚牛肉排行前三、乐山跷脚牛肉更正宗、乐山跷脚牛肉哪家好选择指南 - 优质品牌商家
  • A51宏汇编器预定义宏详解与应用技巧
  • 别再傻傻重启Word了!Windows 11/10字体安装后立即生效的正确姿势
  • 从“富足的一生”到代码人生:技术人的精神富足与价值重构