当前位置: 首页 > news >正文

面试-Torch函数

0. 连续张量和非连续张量

1.核心含义:“连续(contiguous)” 描述的是张量底层数据在内存中的存储方式。
2.连续张量:张量的元素在内存中按“行优先”顺序连续排列,没有间隔,能通过固定步长遍历所有元素;
3.非连续张量:经过transpose()、permute()等操作后,张量的维度顺序变了,但底层数据的存储顺序没改,导致元素在内存中不再连续,遍历需要不规则步长。

用 “书” 举例:

  • 连续张量:书按[0,0]→[0,1]→[0,2]→[1,0]→[1,1]→[1,2]的顺序摆放在一排,没有空隙;
  • 非连续张量(如转置后):维度变成[列,行],但书的摆放顺序还是原来的[0,0]→[0,1]→[0,2]→[1,0]→[1,1]→[1,2],此时要按列取数(如[0,0]→[1,0]→[0,1]→[1,1]),需要跳着找书,内存不连续。

1. torch.view()

核心作用:重塑张量形状,采用方式是 “共享内存”(修改新张量会影响原张量的位置),要求张量是 “连续的(contiguous)”,否则会报错。
特点:不改变原始 x ,通过共享内存的方式改变张量的形状,并且仅支持连续张量。因为view()需要按固定步长重塑维度。

进一步解读:PyTorch 中像view()这类操作,并不会复制张量的底层数据,而是创建一个新的 “视图(view)” —— 新张量和原张量共用同一块内存空间,只是对数据的 “解读方式”(维度、步长)不同。因此,修改新张量的某个元素,原张量对应位置的元素也会同步改变,反之亦然。

importtorch# 原始张量x=torch.randn(2,3)print("原始x shape:",x.shape)# ([2,6])# 重塑x_view=x.view(2,3,3)print("重塑x shape:",x_view.shape)# ([2,3,3])# 验证共享内存x_view[0,0,0]=100.0print("原始x[0,0]:",x[0,0])# tensor(100.)

2. torch.reshape()

核心作用:重塑张量形状,无需张量连续,是更推荐的通用重塑方法。
特点:reshape兼容非连续张量,view仅支持连续张量;功能上几乎等价,新手优先用reshape。

importtorch# 原始张量x=torch.arange(12).reshape(3,4)# torch.Size([3,4])print("x shape:",x.shape)# 重塑为[4,3]x_trans=x.transpose(0,1)# 连续张量->非连续张量x_reshape=x_trans.reshape(4,3)# [3,4] -> [4,3]print("x_reshape:",x_reshape.shape)# 展平x_flat=x_reshape.reshape(-1)# -1 表示自动计算维度print("x_flat shape:",x_flat.shape)

3. torch.triu()

核心作用:提取张量的上三角部分,其余元素置 0;常用来构造因果掩码(如 Transformer 的自注意力)。
特点:提取张量的上三角部分。其中,diagonal(对角线偏移,默认 0,diagonal=1表示主对角线以上的部分)。

importtorch# 原始矩阵:torch.ones(x,y) 创建x=torch.ones(3,3)# 提取上三角部分,(diagonal=1:主对角线以上保留)x_triu=torch.triu(x,diagonal=1)print("x_triu:",x_triu)# 输出:# tensor([[0., 1., 1.],# [0., 0., 1.],# [0., 0., 0.]])# 构造因果掩码矩阵seq_len=3mask=torch.triu(torch.full(seqlen,seqlen),float("-inf"),diagonal=1)print("mask:",mask)# 输出:# tensor([[-inf, -inf, -inf],# [ -inf, -inf, -inf],# [ -inf, -inf, -inf]])

4. torch.full()

核心作用:创建指定形状、所有元素均为固定值的张量;常用于构造掩码(如负无穷、0/1 掩码)。
特点:必须得传入默认值。
参数:size(张量形状)、fill_value(填充值)、device(可选,指定设备)。比 torch.ones(x,y) 和 torch.zeros(x,y) 要更灵活。

importtorch# 创建2x3的全5张量 torch.full((seq, seq), float("-inf"))x_full=torch.full((2,3),5.0)print("x_full:\n",x_full)# 输出:# tensor([[5., 5., 5.],# [5., 5., 5.]])# 创建3x3的全负无穷张量(注意力掩码常用)mask=torch.full((3,3),float("-inf"),device="cpu")print("mask:\n",mask)# 输出:# tensor([[-inf, -inf, -inf],# [-inf, -inf, -inf],# [-inf, -inf, -inf]])

5. torch.transpose()

核心作用:交换张量的两个维度;常用于矩阵转置、调整注意力张量的维度顺序(如[bsz, seq_len, heads]→[bsz, heads, seq_len])。
参数:dim0、dim1(要交换的两个维度索引)。

importtorch# 通过 torch.randn()、torch.full()、torch.ones()创建张量x=torch.randn(1,512,16,1024)x=torch.full((1,512,16,1024),"float(-inf)")x=torch.full((1,512,16,1024))# 报错,torch.full(tensor, value)必须得同时传入默认值、张量两个元素x=torch.ones(1,512,16,1024)print("x:",x)# torch.Size([bsz, seq, heads, dim])

6. torch.cat()

核心作用:指定维度上拼接多个张量;要求除拼接维度外,其他维度形状完全一致。
参数:tensors(待拼接的张量列表)、dim(拼接维度)。

importtorch# 通过 torch.randn()、torch.full()、torch.ones()创建张量x=torch.randn(1,512,16,1024)x=torch.full((1,512,16,1024),"float(-inf)")x=torch.full((1,512,16,1024))# 报错,torch.full(tensor, value)必须得同时传入默认值、张量两个元素x=torch.ones(1,512,16,1024)print("x:",x)# [bsz, seq, heads, dim]x2=torch.randn(1,512,8,1024)# 在维度2上进行拼接x_cat=torch.cat([x1,x2],dim=2)print("x_cat shape:",x_cat.shape)# torch.Size([1,512,24,1024])# 注意力 KV 缓存拼接past_kv=torch.randn(1,10,1024)# [bsz, seq, dim],这里seq代表已经处理了 10 个kv健cur_kv=torch.randn(1,1,1024)# 当前 kv 键值对new_kv=torch.cat([past_kv,new_kv],dim=1)print("new_kv cache:",new_kv)# torch.cat([a, b], dim=c):torch.Size([1, 11, 1024])

7. torch.arange()

核心作用:创建连续整数序列的一维张量;常用于生成索引、位置编码等。
特点:torch.arange() 是根据步长来生成张量的,没有默认值,只能生成一维张量;torch.full() 能生成任意维度张量,且支持默认值;torch.randn() 随机生成指定维度的张量,不支持默认值。
参数:start(起始值,默认 0)、end(结束值,不包含)、step(步长,默认 1)。

# 生成0到9的整数:[0,1,2,...,9]x1=torch.arange(10)print("x1:",x1)# tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9])# 生成1到9,步长2:[1,3,5,7,9]x2=torch.arange(1,10,2)print("x2:",x2)# tensor([1, 3, 5, 7, 9])# 结合size()使用:生成与张量某维度长度匹配的索引x=torch.randn(2,5,8)# 生成0到x.size(1)-1的索引(x.size(1)=5)idx=torch.arange(x.size(1))print("idx:",idx)# tensor([0, 1, 2, 3, 4])

8. tensor.size() / tensor.shape

核心作用:获取张量的形状信息;size()是方法,shape是属性,功能几乎等价。
参数:dim(可选,指定维度索引,返回该维度的长度;不指定则返回 torch.Size 对象)。如 x.size(0) 代表张量中第一个维度的大小。

x=torch.randn(2,3,4)# 获取整体形状print("x.size():",x.size())# torch.Size([2, 3, 4])print("x.shape:",x.shape)# torch.Size([2, 3, 4])# 获取指定维度的长度print("维度0长度:",x.size(0))# 2(批次大小)print("维度1长度:",x.size(1))# 3(序列长度)print("维度2长度:",x.size(2))# 4(特征维度)# 解包形状(常用操作)bsz,seq_len,hidden_dim=x.size()print(f"批次:{bsz}, 序列长度:{seq_len}, 特征维度:{hidden_dim}")# 批次:2, 序列长度:3, 特征维度:4

9. torch.unsqueeze() / torch.squeeze()

核心作用:插入和删除指定维度,插入和删除的维度的长度为1.
torch.unsqueeze(tensor, dim):在指定维度插入一个维度(维度长度为 1),常用于扩展掩码维度;
torch.squeeze(tensor, dim):删除长度为 1 的维度,简化张量形状。

# unsqueeze:扩展维度(注意力掩码常用)mask=torch.randn(2,3)# torch.Size([2,3])# 插入维度1和2:shape [2,1,1,3](匹配注意力分数维度)mask_unsq=mask.unsqueeze(1).unsqueeze(2)print("mask_unsq shape:",mask_unsq.shape)# torch.Size([2, 1, 1, 3])# squeeze:删除长度为1的维度x=torch.randn(2,1,3,1)x_sq=x.squeeze()# 删除所有长度为1的维度print("x_sq shape:",x_sq.shape)# torch.Size([2, 3])

总结:

  • 形状调整:reshape(通用)、view(共享内存)是核心,优先用reshape;size()/shape用于获取形状信息。
  • 维度操作:transpose(交换维度)、unsqueeze/squeeze(增 / 删维度)、cat(拼接张量)是维度调整高频函数。
  • 特殊张量创建:arange(生成序列)、full(固定值张量)、triu(上三角矩阵)常用于掩码、索引构造。
  • 记忆要点:cat要求非拼接维度形状一致;triu(diagonal=1)是 Transformer 因果掩码的核心;unsqueeze是扩展掩码维度的常用操作。
http://www.jsqmd.com/news/358807/

相关文章:

  • 2025 AI 变局:大模型“退烧”,Agent“上位” —— 深度复盘 DeepSeek、GPT-4o 与 Llama 3 的三国杀
  • 升鲜宝生鲜配送供应链管理系统 仓储式收银系统(多公司多门店 POS+会员+钱包+权益+门店WMS+库存成本+离线同步)
  • PostgreSQL 性能优化: I/O 瓶颈分析,以及如何提高数据库的 I/O 性能?
  • AI取代人工?别傻了,真正的危机是“超级个体”正在吞噬“平庸团队” —— 深度解析人机协作新范式
  • 《程序员修炼之道》——从小工到专家的习惯养成
  • 常用的 PNG 转 JPG 在线网站整理(无需安装,直接使用)
  • 【2 月小记】Part 3: CROI-R3 比赛总结 - L
  • 国内科研必备:16个Google和谷歌学术镜像站,2026最新更新
  • 集成灶的噪音大不大?揭秘静音真相+选购攻略|厨房宁静指南 - 匠言榜单
  • yolo姿态估计的板端算力占用评估
  • 如何选择合适的IP查询工具?精准度与更新频率全面分析
  • QMdiArea多窗口管理容器。官方demo,搜素mdi。复制,剪切,粘贴
  • QMimeData 是 Qt 中数据交换的标准化载体。粘贴复制,跨应用的标准格式。也能自定义数据类型
  • 2026年我会推荐哪些IP归属地查询网站?
  • 《梦断代码》——软件项目的理想与现实
  • 《人月神话》中的项目管理陷阱与启示
  • 外贸站必备!WordPress经销商地图,多国家适配+自动检索,省爆客服力!
  • 当内容遇冷之后:系统化运营如何激活短视频生命力 - 品牌之家
  • 【取模】思源黑体 取模只显示一部分问题,或者挤在一起
  • Excel分类汇总完全指南:从数据分析到分页打印的专业应用
  • 历史课不再枯燥!老师用什么AI工具做历史人物生平教学视频?横评 3 类神器,这款让学生抢着听课
  • 直流无刷电机,直径38mm,径向长23.8mm,转速25000rpm,功率200W
  • 嵌入式Linux:线程同步(读写锁) - 教程
  • 运用 HTML5 Canvas 实现可交互的内容瀑布流(隐藏式运维模式)
  • 《一文搞懂PyTorch优化器:SGD/Adam原理、使用流程与实战调优指南》
  • 本科生必看!万众偏爱的AI论文网站 —— 千笔ai写作
  • 救命神器!AI论文平台 千笔写作工具 VS 知文AI,专为本科生量身打造!
  • 一遍搞定全流程!专科生专属AI论文神器 —— 千笔·专业论文写作工具
  • 单例模式管理模型客户端的几种实现方式
  • OpenClaw 最新保姆级飞书对接指南教程 搭建属于你的 AI 助手