PyTorch 无循环张量多对一求和方法
时间:2026-04-09 18:45:44 198浏览 收藏
本文揭秘了如何利用 PyTorch 的 `scatter_add_` 原语,结合 `repeat_interleave` 和索引展平技巧,以完全向量化、零 Python 循环的方式高效实现一维张量到另一维张量的“一对多”映射累加(如多源值聚合至目标位置),不仅大幅提升 GPU 并行计算效率、保持梯度可导性,还显著简化代码逻辑——告别慢速循环与手动索引遍历,让复杂映射操作变得简洁、健壮且生产就绪。

本文介绍使用 torch.Tensor.scatter_add_ 配合索引展开与值重复,高效完成一维张量到另一维张量的一对多映射累加操作,避免 Python 循环,完全基于向量化运算。
本文介绍使用 `torch.Tensor.scatter_add_` 配合索引展开与值重复,高效完成一维张量到另一维张量的一对多映射累加操作,避免 Python 循环,完全基于向量化运算。
在 PyTorch 中处理「一对多」映射关系(即每个输入元素贡献至多个输出位置)并执行聚合(如求和)时,若采用 Python 循环或列表推导,不仅代码冗长,更会严重拖慢训练速度、破坏计算图完整性,且无法充分利用 GPU 并行能力。幸运的是,PyTorch 提供了高度优化的原语——scatter_add,它专为这类“按索引分散累加”场景设计,可一次性完成全部映射与聚合。
核心思想是将不规则映射结构(如嵌套列表 mapping)转化为两个齐次一维张量:
- src:待累加的源值序列,其中每个 input[i] 根据其映射目标数量被重复;
- index:对应的目标位置索引序列,与 src 严格对齐;
- out:初始化为零的输出张量,长度由最大目标索引决定。
以下为完整实现示例:
import torch # 输入定义 input = torch.tensor([0, 1, 2, 3], dtype=torch.float32) mapping = [[1], [0, 2, 4], [0, 3], [1, 2]] # 步骤 1:计算各输入项的重复次数(即每个 input[i] 映射到多少个 output 位置) reps = torch.tensor([len(x) for x in mapping]) # 步骤 2:构建 src —— 按 reps 重复 input 中每个元素 src = input.repeat_interleave(reps) # tensor([0, 1, 1, 1, 2, 2, 3, 3]) # 步骤 3:构建 index —— 展平 mapping,得到所有 (src[i] → output[j]) 的 j 序列 index = torch.tensor([j for sublist in mapping for j in sublist]) # tensor([1, 0, 2, 4, 0, 3, 1, 2]) # 步骤 4:初始化输出张量(长度 = max(index) + 1) out = torch.zeros(max(index) + 1, dtype=src.dtype) # 步骤 5:执行向量化累加:out[j] += src[i] for each (i,j) pair result = out.scatter_add(dim=0, index=index, src=src) print(result) # tensor([3., 3., 4., 2., 1.])
✅ 关键优势:
- 全程无 Python 循环,100% 张量操作,支持 CUDA 加速;
- 时间复杂度为 O(∑|mapping[i]|),空间复杂度为 O(len(output)),理论最优;
- 自动兼容梯度传播(scatter_add 是可微分操作),适用于模型中间层。
⚠️ 注意事项:
- index 中的索引必须是非负整数,且严格小于 out.size(dim),否则抛出 RuntimeError;
- 若 mapping 可能为空(如 []),需提前过滤或用 max(index, default=0) 防御;
- 当 output 维度极大但稀疏时,该方法仍会分配全量内存;如需极致稀疏支持,可考虑结合 torch.sparse 或自定义 CUDA kernel,但绝大多数场景 scatter_add 已足够高效。
总结而言,scatter_add 是解决 PyTorch 中「一对多映射+聚合」问题的标准、简洁且高性能方案。掌握其与 repeat_interleave、索引展平等组合技巧,能显著提升数据预处理与自定义层的表达力与执行效率。
好了,本文到此结束,带大家了解了《PyTorch 无循环张量多对一求和方法》,希望本文对你有所帮助!关注golang学习网公众号,给大家分享更多文章知识!
相关阅读
更多>
-
501 收藏
-
501 收藏
-
501 收藏
-
501 收藏
-
501 收藏
最新阅读
更多>
-
342 收藏
-
296 收藏
-
483 收藏
-
177 收藏
-
145 收藏
-
343 收藏
-
266 收藏
-
488 收藏
-
121 收藏
-
438 收藏
-
495 收藏
-
431 收藏
课程推荐
更多>
-
- 前端进阶之JavaScript设计模式
- 设计模式是开发人员在软件开发过程中面临一般问题时的解决方案,代表了最佳的实践。本课程的主打内容包括JS常见设计模式以及具体应用场景,打造一站式知识长龙服务,适合有JS基础的同学学习。
- 立即学习 543次学习
-
- GO语言核心编程课程
- 本课程采用真实案例,全面具体可落地,从理论到实践,一步一步将GO核心编程技术、编程思想、底层实现融会贯通,使学习者贴近时代脉搏,做IT互联网时代的弄潮儿。
- 立即学习 516次学习
-
- 简单聊聊mysql8与网络通信
- 如有问题加微信:Le-studyg;在课程中,我们将首先介绍MySQL8的新特性,包括性能优化、安全增强、新数据类型等,帮助学生快速熟悉MySQL8的最新功能。接着,我们将深入解析MySQL的网络通信机制,包括协议、连接管理、数据传输等,让
- 立即学习 500次学习
-
- JavaScript正则表达式基础与实战
- 在任何一门编程语言中,正则表达式,都是一项重要的知识,它提供了高效的字符串匹配与捕获机制,可以极大的简化程序设计。
- 立即学习 487次学习
-
- 从零制作响应式网站—Grid布局
- 本系列教程将展示从零制作一个假想的网络科技公司官网,分为导航,轮播,关于我们,成功案例,服务流程,团队介绍,数据部分,公司动态,底部信息等内容区块。网站整体采用CSSGrid布局,支持响应式,有流畅过渡和展现动画。
- 立即学习 485次学习