当前位置: 首页 > news >正文

整体理解pai0-具身智能-PyTorch einsum 完全教程-11 - jack

目录
  • 1. 基础概念
  • 2. 基础语法
    • Level 1: 向量点积
    • Level 2: 矩阵乘法
    • Level 3: 批次矩阵乘法(Transformer中常用)
  • 4. PI0 代码中的实际例子
    • 例子1: QKV 投影 (gemma.py:183)
    • 例子2: 注意力计算 (gemma.py:217)
    • 例子3: 注意力输出 (gemma.py:230)
  • 5. 常见模式总结
  • 6. 调试技巧
  • 7. 练习题

1. 基础概念

einsum = Einstein Summation (爱因斯坦求和约定)

用简洁的字符串表示复杂的张量运算(乘法、求和、转置等)

2. 基础语法

torch.einsum("equation", tensor1, tensor2, ...)
字母代表维度
相同字母会进行对应相乘
输出中不出现的字母会被求和消除
逗号分隔不同的输入张量
箭头 -> 指定输出维度

Level 1: 向量点积

# 传统方法
a = torch.tensor([1, 2, 3])
b = torch.tensor([4, 5, 6])
result = torch.dot(a, b)  # 1*4 + 2*5 + 3*6 = 32# einsum 方法
result = torch.einsum('i,i->', a, b)
#                      ↑ ↑  ↑
#                      a b  输出(标量)

i: a 的第 0 维,b 的第 0 维
两个 i 相同 → 对应元素相乘
输出没有 i → 求和

Level 2: 矩阵乘法

# 传统方法
A = torch.randn(3, 4)  # [3, 4]
B = torch.randn(4, 5)  # [4, 5]
C = torch.mm(A, B)     # [3, 5]# einsum 方法
C = torch.einsum('ik,kj->ij', A, B)
#                 ↑↑  ↑↑  ↑↑
#                 A   B   输出

解析:

A.shape = (3, 4)  # i=3, k=4
B.shape = (4, 5)  # k=4, j=5# 运算: C[i,j] = Σ_k A[i,k] * B[k,j]
# k 出现在两边但不在输出 → 求和消除
# i, j 在输出 → 保留C.shape = (3, 5)  # i=3, j=5

Level 3: 批次矩阵乘法(Transformer中常用)

# Batch Matrix Multiplication
A = torch.randn(2, 3, 4)  # [batch, n, k]
B = torch.randn(2, 4, 5)  # [batch, k, m]# einsum 方法
C = torch.einsum('bik,bkj->bij', A, B)
#                 ↑             ↑
#              batch维度     batch维度
A.shape = (2, 3, 4)  # b=2, i=3, k=4
B.shape = (2, 4, 5)  # b=2, k=4, j=5# 运算: C[b,i,j] = Σ_k A[b,i,k] * B[b,k,j]
# b 在输出 → 保留(不求和)
# k 不在输出 → 求和消除C.shape = (2, 3, 5)  # [batch, n, m]

4. PI0 代码中的实际例子

例子1: QKV 投影 (gemma.py:183)

qkvs.append(qkv_einsum("BSD,3KDH->3BSKH", x))

# 输入
x.shape = (B, S, D)
# B = batch_size (例如 2)
# S = sequence_length (例如 512)
# D = hidden_dim (例如 2048)# 权重
weight.shape = (3, K, D, H)
# 3 = Q, K, V 三个矩阵
# K = num_kv_heads (例如 1)
# D = hidden_dim (2048)
# H = head_dim (256)# einsum: "BSD,3KDH->3BSKH"
#          ↑    ↑      ↑
#          x  weight  输出# 维度对应:
# B: batch (保留)
# S: sequence (保留)
# D: hidden_dim (求和消除,因为不在输出)
# 3: QKV (保留)
# K: num_heads (保留)
# H: head_dim (保留)# 输出
output.shape = (3, B, S, K, H)
# 例如: (3, 2, 512, 1, 256)

例子2: 注意力计算 (gemma.py:217)

logits = jnp.einsum("BTKGH,BSKH->BKGTS", q, k)

解析:

# 输入
q.shape = (B, T, K, G, H)
# B = batch_size
# T = query_length (例如 512)
# K = num_kv_heads (1)
# G = group_size (8, 因为8个query heads / 1个kv head)
# H = head_dim (256)k.shape = (B, S, K, H)
# S = key_length (例如 512)# einsum: "BTKGH,BSKH->BKGTS"
#          ↑      ↑     ↑
#          q      k    输出# 维度对应:
# B: batch (保留)
# T: query_length (保留)
# K: num_kv_heads (保留)
# G: group_size (保留)
# H: head_dim (求和消除!)
# S: key_length (保留)# 输出
logits.shape = (B, K, G, T, S)# 语义: logits[b,k,g,t,s] = Σ_h q[b,t,k,g,h] * k[b,s,k,h]
#       即: query位置t 对 key位置s 的注意力分数

T query的长度
S key的长度
G group_size

例子3: 注意力输出 (gemma.py:230)

encoded = jnp.einsum("BKGTS,BSKH->BTKGH", probs, v)
解析:

# 输入
probs.shape = (B, K, G, T, S)  # 注意力权重(softmax后)
v.shape = (B, S, K, H)          # Value# einsum: "BKGTS,BSKH->BTKGH"
#          ↑      ↑     ↑
#        probs    v    输出# 维度对应:
# B: batch (保留)
# K: num_kv_heads (保留)
# G: group_size (保留)
# T: query_length (保留)
# S: key_length (求和消除!) <- 加权求和
# H: head_dim (保留)# 输出
encoded.shape = (B, T, K, G, H)# 语义: encoded[b,t,k,g,h] = Σ_s probs[b,k,g,t,s] * v[b,s,k,h]
#       即: 用注意力权重加权 value

维度对应:
B: batch (保留)
K: num_kv_heads (保留)
G: group_size (保留)
T: query_length (保留)
H: head_dim (保留)

5. 常见模式总结

模式1: 矩阵乘法

# 2D
'ik,kj->ij'  # (i,k) @ (k,j) = (i,j)# 3D (batch)
'bik,bkj->bij'  # (b,i,k) @ (b,k,j) = (b,i,j)# 4D
'bhik,bhkj->bhij'  # 多头注意力

模式2: 外积

# 向量外积
'i,j->ij'  # (i,) ⊗ (j,) = (i,j)# 批次外积
'bi,bj->bij'(i,) ⊗ (j,) = (i, j),表示外积,维度相乘得到二维矩阵。

image
'i,j->ij' 表示将两个一维向量的所有元素两两相乘,生成一个二维矩阵,也就是向量的 外积(outer product)。

模式3: 求和

# 沿某个维度求和
'ijk->ij'   # 对k求和
'ijk->ik'   # 对j求和
'ijk->'     # 全部求和(标量)

模式4: 转置

'ij->ji'    # 转置
'ijk->ikj'  # 交换维度

模式5: 对角线

'ii->i'     # 提取对角线
'bii->bi'   # 批次对角线

6. 调试技巧

技巧1: 写出维度

# 先写出每个张量的维度
A: (3, 4)  # i=3, k=4
B: (4, 5)  # k=4, j=5# 再写 einsum
'ik,kj->ij'# 验证: k 求和消除,输出 (i, j) = (3, 5) ✓

技巧2: 分步理解

result = torch.einsum('bik,bkj->bij', A, B)# 步骤1: 找共同维度
# b: 共同(batch)
# k: 共同(求和)# 步骤2: 找独有维度
# i: 只在 A
# j: 只在 B# 步骤3: 确定输出
# b: 保留(在输出中)
# i: 保留(在输出中)
# j: 保留(在输出中)
# k: 消除(不在输出中)

技巧3: 用注释

q = torch.einsum('BTD,NDH->BTNH',  # Query projectionx,      # [B, T, D] = [batch, seq, hidden]w_q,    # [N, D, H] = [heads, hidden, head_dim]
)           # → [B, T, N, H]

7. 练习题

# 1. 简单点积
'i,i->'# 2. 批次矩阵乘法
'bmn,bnk->bmk'# 3. 多头注意力
'bhqd,bhkd->bhqk'# 4. 位置编码
'm,d->md'# 5. 交叉注意力
'bid,bjd->bij'

希望这个教程能帮你理解 einsum!关键是:
把字母当作维度的名字
相同字母 = 对应相乘
输出中没有的字母 = 求和消除

http://www.jsqmd.com/news/21517/

相关文章:

  • 2025年北京奢侈品品牌首饰回收公司权威推荐榜单:钻石回收/黄金回收/钻戒回收源头公司精选
  • 查询每门成绩都大于80分的同学学号
  • 【C++】函数参数传递
  • C++ lambd表达式
  • NVIDIA与Adobe漏洞深度解析
  • 监督学习、无监督学习、半监督学习、强化学习、自监督学习
  • 2025 年退磁器生产厂家最新推荐榜:技术创新、行业适配与服务保障全景对比及权威测评结果强力退磁器/手提退磁器/小型退磁器公司推荐
  • 计算机组成原理:磁盘存储设备 - 实践
  • 2025 年最新推荐辊涂机源头厂家推荐榜单:UV 漆 / 玻璃 / 铝板 / 木门 / PVC 地板辊涂机优质企业全解析
  • 【哲学思考】:规则
  • 2025.10.24第一节课内容
  • 【IEEE出版 | 高届数会议 | 上届已于会后3个多月完成见刊检索】2025第九届控制工程与国际论坛(IWCEAA 2025)
  • SQLServer截取字符串、字符串长度、特殊字符在字符串的下标索引
  • 题解:P8134 [ICPC 2020 WF] Opportunity Cost
  • 深入解析:数据结构 之 【图的遍历与最小生成树】(广度优先遍历算法、深度优先遍历算法、Kruskal算法、Prim算法实现)
  • 完整教程:构建并运行最小 Linux 内核
  • word批量转pdf
  • 【SAE出版 | 高届数 | 检索稳定】第七届土木建筑与城市工程国际学术会议(ICCAUE 2025)
  • qcefview库的使用
  • 解决Qt 不能debug问题
  • Exadata数据库性能异常,备份进程卡住
  • 做本地门户网站 10 年,我靠微擎摆脱了 “客户需求五花八门” 的噩梦
  • 2025 年国内吸顶灯源头厂家最新推荐排行榜:聚焦全光谱技术与品质生产,精选优质厂家助力家居照明选购全光谱/中山现代/客厅现代/吊灯吸顶灯公司推荐
  • RabbitMQ框架及应用场景
  • 【开题答辩全过程】以 “辛巴克餐饮”小程序为例,具备答辩的问题和答案
  • 2025年一体化雨水提升泵站厂家权威推荐榜单:污水提升泵站/一体化污水泵站/一体化雨水泵站源头厂家精选
  • STM32软件I2C读写AT24C64 - 指南
  • bcc
  • 手写ibatis
  • 国产IPD项目管理软件推荐|别再靠 Excel 推 IPD 了!帮你把IPD流程从“纸上”搬进系统