当前位置: 首页 > news >正文

TPU 架构与 Pallas Kernel 编程入门:从内存层次结构到 FlashAttention

做过 GPU kernel 优化的人对以下编程模型肯定不会陌生:写一个 CUDA kernel分发到流式多处理器(SM)上执行,缓存层次结构自行负责数据搬运。而TPU 则完全不同,除非明确告诉编译器要把哪些数据块搬到哪里,否则kernel 根本无法编译。实际操作确实和听起来一样繁琐,所以JAX 的Pallas 就是解决的这个问题:以 tile 为单位描述计算,无需手动指定输入张量各部分的搬运路径,编译器自动生成所需的数据移动操作。

本文从硬件约束入手,接着逐步编写复杂度递增的 kernel,最后分析 JAX 生产级 FlashAttention 实现。我们先从基础开始,把那些绕不开的"为什么"讲清楚。

为什么不能在 TPU 上直接写循环?

GPU 上的基本原理很简单:写一个对单个元素或小块数据操作的 kernel,硬件调度成千上万份到各核心执行。线程通常处理同一张量中位置相邻的元素,大量线程同时读取内存中相邻的区域。GPU 的设计就是围绕这一模式展开的:自动合并相邻读取,将近期访问的数据保留在靠近计算单元的位置。内存访问符合这个模式时性能很好;不符合时,硬件通常也能平滑掉一部分开销。

  1. __global__ void add(float* x, float* y, float* out, int n) {
  2. int i = blockIdx.x * blockDim.x + threadIdx.x;
  3. if (i < n) {
  4. out[i] = x[i] + y[i];
  5. }
  6. }
  7. // 幕后:数千个线程在 GPU 上同时运行这同一个 kernel。
  8. // thread 0 → out[0] = x[0] + y[0]
  9. // thread 1 → out[1] = x[1] + y[1]
  10. // thread 2 → out[2] = x[2] + y[2]

理解 Pallas 的价值,先要看清 TPU 和 GPU 在定位上的根本差异。TPU 不是通用并行处理器,它只做一件事,矩阵运算而且做得极好。它不会给游戏带来更高帧率,但一定可以加速模型训练。TPU v5e 芯片围绕一个称为 TensorCore 的计算模块构建,内含四个 MXU(Matrix Multiply Unit),可以理解为 128×128 的 systolic array乘法器排成网格,计算结果沿网格逐级传递给相邻单元。TPU 的内存层次结构不像 GPU 那样自动管理缓存,数据必须在三个层次之间显式搬运:

  • HBM(高带宽内存):v5e 上约 16 GB,张量存放的位置,片外,速度相对较慢。
  • VMEM(向量内存):16+ MB 的片上 SRAM,速度快但容量小;数据到达这里后计算单元才能访问。
  • 寄存器:算术运算实际发生的位置,值从 VMEM 加载到寄存器、完成计算后写回 VMEM。

TPU 计算需要显式的数据暂存。

没法在 TPU 上像 CPU 或 GPU 那样对数据写一个简单循环,原因就在这里,数据不会自动从 HBM 流到寄存器。必须显式调度 DMA(直接内存访问)传输,将数据从 HBM 搬入 VMEM;kernel 执行完毕后 VMEM 中的结果再写回 HBM,这是 Pallas 存在的根本理由。GPU 上写

 

https://avoid.overfit.cn/post/12fe51915c5b439aacc1d33f3e4a2b12

http://www.jsqmd.com/news/535049/

相关文章:

  • Linux软RAID实战:mdadm构建RAID5及故障磁盘热替换指南
  • 2026年毕设AIGC检测过不了?这3款降AI工具亲测靠谱
  • Python VTK实战:5步搞定瓦力机器人3D模型渲染(附完整代码)
  • 20252906 2025-2026-2 《网络攻防实践》第1周作业
  • Python实战:5分钟搞定三菱PLC数据读取(附HslCommunication模块避坑指南)
  • 从Kettle老手到Hop新手:我的第一个数据管道迁移踩坑实录(附避坑清单)
  • 【全网首发】2026华为OD双机位C卷 机考真题题库含考点说明以及在线OJ (Java)
  • 亲测有效!论文AIGC率直降40%攻略:4个指令+3个技巧
  • Fluent 熔覆质量流模拟与激光电弧复合熔滴熔池模拟探索
  • LangChain实战:10行代码创建智能Agent,小白也能看懂(建议收藏)
  • AI报告文档审核护航飞行安全:IACheck打造航电与飞控检测报告智能审核新利器
  • CVPR2024无监督学习新突破:17篇论文中的5个实战技巧与避坑指南
  • ESP32玩转Matter协议:手把手教你用ESP-Matter搭建智能家居设备(附避坑指南)
  • 手把手教你用GPEN镜像修复老照片:单图增强+批量处理全攻略
  • Wan2.2-I2V-A14B构建MCP服务:实现与Claude等AI助手的无缝协作
  • SWAT模型数据准备保姆级避坑指南:从DEM到气象数据的完整ArcGIS+SWATweather流程
  • 告别手动复制!用Apifox Helper插件实现IDEA代码注释自动同步API文档(2024最新版)
  • 西门子S7-1200PLC与TP700触摸屏联机的自动洗车机控制系统博途V16应用解析
  • OpenClaw任务编排:GLM-4.7-Flash复杂流程自动化
  • 开源社区运营:Qwen1.5-1.8B GPTQ自动回复GitHub Issues与生成Release Note
  • 题解:qoj17256 Keep or Gamble
  • 全球微高压氧舱:健康消费升级与康复需求驱动下的爆发扩容,2026-2032年CAGR14.9%,2032年规模4.14亿美元
  • ZLMediaKit专业级流媒体服务器:3步完成高效部署方案
  • Lightpanda无头浏览器:11倍性能提升的自动化革命指南
  • 从焊接台到代码:手把手调试LAN8742以太网PHY的5个关键步骤
  • 5步搞定黑苹果配置:OpCore Simplify让EFI生成效率提升95%的实战指南
  • AI智能体权限过大?OpenClaw等框架的5个高危配置必须检查,否则代码真会“裸奔“!
  • 20253912 2025-2026-2 《网络攻防实践》第二周作业
  • ssm+java2026年毕设舒旅程旅游景点预订网站【源码+论文】
  • Flutter GetX Snackbar实战:5分钟实现顶部弹窗通知(附完整属性表)