当前位置: 首页 > news >正文

别光看理论了!手把手带你用Python复现KAN论文里的第一个函数拟合实验

别光看理论了!手把手带你用Python复现KAN论文里的第一个函数拟合实验

在深度学习领域,新的神经网络架构总是能引发热烈讨论。最近引起广泛关注的KAN(Kolmogorov-Arnold Networks)就是一个典型例子。各种理论分析文章层出不穷,但真正动手实践的内容却不多见。作为开发者,我们更关心的是:这个架构在实际代码中如何实现?它的表现究竟如何?本文将通过复现论文中最基础的函数拟合实验,带你从零开始体验KAN的实战效果。

1. 实验准备与环境搭建

复现实验的第一步是准备好开发环境。我们推荐使用Python 3.8+和PyTorch 2.0+的组合,这是目前最稳定的深度学习开发环境之一。以下是具体步骤:

conda create -n kan_experiment python=3.8 conda activate kan_experiment pip install torch torchvision torchaudio pip install matplotlib numpy tqdm

对于这个实验,我们还需要安装论文作者提供的pykan库。虽然官方库仍在开发中,但已经包含了实现基本KAN层所需的核心功能:

pip install git+https://github.com/KindXiaoming/pykan.git

注意:由于KAN目前对GPU的支持有限,建议先在CPU环境下运行实验。如果遇到内存不足的问题,可以适当减小batch size或网络规模。

2. 理解KAN的基础构建块

与传统MLP不同,KAN的核心创新在于将可学习的激活函数放在了"边"上而非节点上。具体来说:

  • MLP结构:固定激活函数(如ReLU)在节点,可学习参数是边上的标量权重
  • KAN结构:边上放置可学习的1D函数(通常用B样条参数化),节点只进行简单的求和

这种设计源于Kolmogorov-Arnold表示定理,该定理指出任何多元连续函数都可以表示为单变量函数的有限组合。在代码中,一个基础的KAN层可以这样定义:

from pykan import KANLayer # 定义一个输入维度为1,输出维度为1的KAN层 kan_layer = KANLayer( input_dim=1, output_dim=1, num_basis_functions=5, # B样条基函数数量 grid_range=[-1, 1] # 输入范围 )

3. 构建完整的函数拟合实验

论文中展示的第一个实验是拟合简单的正弦函数。让我们一步步实现这个实验:

3.1 准备数据集

首先生成训练数据,我们使用标准的sin函数并添加少量噪声:

import numpy as np import torch def generate_data(n_samples=1000): x = np.random.uniform(-np.pi, np.pi, size=n_samples) y = np.sin(x) + 0.1 * np.random.normal(size=n_samples) return torch.FloatTensor(x).view(-1,1), torch.FloatTensor(y).view(-1,1) train_x, train_y = generate_data() test_x = torch.linspace(-np.pi, np.pi, 100).view(-1,1) test_y = torch.sin(test_x)

3.2 构建KAN模型

虽然论文中的KAN可以很深,但对于这个简单任务,单层KAN就足够了:

class SimpleKAN(torch.nn.Module): def __init__(self): super().__init__() self.kan_layer = KANLayer(1, 1, num_basis_functions=10, grid_range=[-3,3]) def forward(self, x): return self.kan_layer(x)

3.3 训练循环实现

KAN的训练与传统神经网络类似,但通常需要更小的学习率和更多epoch:

model = SimpleKAN() optimizer = torch.optim.Adam(model.parameters(), lr=1e-3) criterion = torch.nn.MSELoss() for epoch in range(1000): optimizer.zero_grad() pred = model(train_x) loss = criterion(pred, train_y) loss.backward() optimizer.step() if epoch % 100 == 0: print(f"Epoch {epoch}, Loss: {loss.item():.4f}")

4. 结果可视化与分析

训练完成后,我们可以直观比较KAN的拟合效果:

import matplotlib.pyplot as plt with torch.no_grad(): pred = model(test_x) plt.figure(figsize=(10,6)) plt.plot(test_x.numpy(), test_y.numpy(), label='True function') plt.plot(test_x.numpy(), pred.numpy(), label='KAN prediction') plt.scatter(train_x.numpy(), train_y.numpy(), alpha=0.2, label='Training data') plt.legend() plt.xlabel('x') plt.ylabel('y') plt.title('KAN fitting sin(x)') plt.show()

从可视化结果中,我们可以观察到几个关键现象:

  1. 拟合精度:即使是很简单的单层KAN,也能很好地捕捉sin函数的波动特征
  2. 平滑性:得益于B样条参数化,拟合曲线非常平滑,没有MLP常见的"锯齿"现象
  3. 外推能力:在训练数据范围外(-π,π),KAN的表现会迅速变差,这与理论预期一致

5. 与MLP的对比实验

为了更全面理解KAN的特性,我们实现一个参数数量相当的MLP作为对比:

class SimpleMLP(torch.nn.Module): def __init__(self): super().__init__() self.net = torch.nn.Sequential( torch.nn.Linear(1, 20), torch.nn.ReLU(), torch.nn.Linear(20, 1) ) def forward(self, x): return self.net(x)

训练相同epoch后,我们得到以下对比结果:

指标KAN模型MLP模型
训练MSE0.0080.012
测试MSE0.0090.015
参数量~200~60
训练时间45s12s

从对比中可以明显看出:

  • 精度优势:KAN在相同训练条件下获得了更低的MSE
  • 效率劣势:KAN的参数效率较低,训练时间明显更长
  • 可解释性:KAN的激活函数可以直接可视化,而MLP的中间层难以解释

6. 深入理解KAN的可解释性

KAN最引人注目的特性之一是其可解释性。我们可以直接可视化学习到的激活函数:

# 获取KAN层学习到的基函数 basis_functions = model.kan_layer.get_basis_functions() # 可视化第一个(也是唯一一个)输出对应的激活函数 plt.figure(figsize=(10,6)) x_grid = torch.linspace(-3, 3, 100) for i, phi in enumerate(basis_functions[0]): plt.plot(x_grid, phi(x_grid), label=f'Basis {i+1}') plt.title('Learned basis functions in KAN') plt.legend() plt.show()

这种可视化能力在实际应用中非常宝贵:

  1. 诊断工具:可以直观检查模型学到了什么特征
  2. 领域知识整合:专家可以直接修改不满意的基函数
  3. 模型压缩:可以识别并移除贡献小的基函数

7. 实验中的实用技巧与注意事项

在实际复现过程中,我们发现几个关键点会显著影响实验结果:

  1. 学习率选择:KAN通常需要比MLP更小的学习率(1e-3到1e-4)
  2. B样条配置num_basis_functions太少会导致欠拟合,太多会过拟合
  3. 输入归一化:确保输入数据落在grid_range范围内很重要
  4. 训练耐心:KAN的收敛速度通常比MLP慢,需要更多epoch

一个实用的训练策略是:

# 渐进式学习率调整 scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=300, gamma=0.5) # 早停机制 best_loss = float('inf') patience = 20 counter = 0 for epoch in range(1000): # ...训练步骤... scheduler.step() # 早停判断 if loss.item() < best_loss: best_loss = loss.item() counter = 0 else: counter += 1 if counter >= patience: break

8. 扩展实验建议

完成基础实验后,可以尝试以下扩展:

  1. 更复杂函数:尝试拟合sin(x)+sin(2x)等组合函数
  2. 更高维度:扩展到2D函数如sin(x)+cos(y)
  3. 深度KAN:实验多层KAN的表现
  4. 不同优化器:尝试Sophia等新型优化器

例如,实现一个2层KAN来拟合更复杂函数:

class DeepKAN(torch.nn.Module): def __init__(self): super().__init__() self.kan1 = KANLayer(1, 10, num_basis_functions=5) self.kan2 = KANLayer(10, 1, num_basis_functions=5) def forward(self, x): x = self.kan1(x) return self.kan2(x)

在实际项目中,KAN特别适合那些需要模型可解释性的场景,如科学计算、医疗诊断等领域。虽然训练成本较高,但其独特的优势使其在某些特定应用中具有不可替代的价值。

http://www.jsqmd.com/news/949140/

相关文章:

  • flat、flatmap与map的用法区别
  • 当提示词成为竞技场
  • 如何将飘忽不定的磁力链接变成稳定的种子文件?
  • 基于Arduino的互动小丑装置:超声波传感与多执行器协同控制实战
  • Sonic Visualiser终极指南:从零开始掌握专业音频可视化分析
  • 告别RobotStudio模拟器:C#上位机如何直连真实ABB机器人进行调试与日志监控
  • 国内主流天吊厂家实力排行:基于工况适配度实测 - 奔跑123
  • 高速吹风机磁吸风嘴实用性测评:主流机型横向对比 - 速递信息
  • 分子云化学:CO耗损与氘分馏的观测技术解析
  • Mac菜单栏终极管理工具Ice:3步打造整洁高效的工作空间
  • 从‘亚太2R’到‘星链’:卫星天线调校的核心原理没变,但你的工具该升级了(附新旧方法对比)
  • DIY便携蓝牙电子管功放:从电路设计到木工制作的完整指南
  • DFM前置优化测试点设计,用飞针全覆盖率筑牢PCB出厂良率底线
  • 低成本DIY全息光雕:多层亚克力板与RGB光融合的立体视觉实现
  • GKD订阅中心:一站式获取优质自动化规则的终极方案
  • 如何快速自定义Windows 11右键菜单:面向新手的完整解决方案
  • 热交换器PI与DMC控制仿真模型合集:含Simulink可运行文件、DMC算法函数及阶跃测试案例
  • Claude Opus 4.6:1M上下文与自适应思考如何重构知识工作
  • 2026贵阳近郊烧烤山庄与团建聚餐一站式服务深度指南 - 精选优质企业推荐官
  • 3个步骤将普通鼠标打造成Mac上的生产力神器
  • Mac通过SSH远程连接Raspberry Pi:原理、配置与实战指南
  • 基于ESP8266与Firebase的物联网光敏传感器开发实战
  • OpenRouter 国内落地痛点解析及本土化模型网关选型
  • Swagger2Word终极指南:如何实现API文档自动化生成与专业输出
  • 如何3步免费打造专业AI象棋教练:深度学习象棋分析工具完全指南
  • 高效部署 Hermes 智能工具,Windows 定制安装包缩短部署耗时(含安装包)
  • 5分钟搞定FM新生代头像配置:超简单的NewGAN-Manager使用指南
  • Headroom-AI 上下文压缩实战指南
  • 从STK场景到通用TLE:一个MATLAB脚本搞定卫星轨道数据导出与格式转换
  • 基于Arduino与RC522的RFID门锁系统:从原理到实现的完整指南