登录
首页 >  文章 >  python教程

PyTorch广播与矩阵乘法教程

时间:2026-01-02 12:12:39 352浏览 收藏

你在学习文章相关的知识吗?本文《PyTorch 广播与矩阵乘法详解》,主要介绍的内容就涉及到,如果你想提升自己的开发能力,就不要错过这篇文章,大家要知道编程理论基础和实战操作都是不可或缺的哦!

PyTorch 中的广播机制与矩阵乘法:彻底厘清常见误解

本文澄清 PyTorch 中广播(broadcasting)与矩阵乘法(`matmul`)的本质区别:广播不适用于形状不兼容的逐元素运算(如 `+`),而 `X @ Y` 或 `torch.matmul(X, Y)` 才是正确执行 2×4 与 4×2 矩阵乘法的方式。

在 PyTorch 中,初学者常将「形状满足矩阵乘法条件」与「支持广播运算」混淆。实际上,二者遵循完全不同的规则:

  • *逐元素运算(如 +, -, `,/)依赖广播机制**:要求张量在每个维度上满足广播兼容性——即从尾部维度开始比对,任一维度为1或两维度相等,才能自动扩展。 例如:X.shape = (2, 4)与Y.shape = (4, 2)**无法广播**,因为最后维度4 ≠ 2,倒数第二维2 ≠ 4,且无维度为1可触发扩展。因此X + Y` 报错:

    RuntimeError: The size of tensor a (4) must match the size of tensor b (2) at non-singleton dimension 1

    这明确指出:第 1 维(0-indexed)尺寸不匹配,且均非 1,广播失败。

  • 矩阵乘法(@ 或 torch.matmul)不依赖广播,而是遵循线性代数规则:只要 X 的最后一维等于 Y 的倒数第二维(即 X.shape[-1] == Y.shape[-2]),即可计算。本例中 X 为 (2, 4),Y 为 (4, 2),满足 4 == 4,结果为 (2, 2):

    import torch
    X = torch.tensor([[1,5,2,7],
                      [8,2,5,3]])      # shape: (2, 4)
    Y = torch.tensor([[2,9],
                      [11,4],
                      [9,2],
                      [22,7]])         # shape: (4, 2)
    
    result = torch.matmul(X, Y)  # 或 X @ Y
    print(result)
    # 输出:
    # tensor([[229,  82],
    #         [149, 111]])

⚠️ 注意:torch.mm() 仅支持 2D 张量,而 torch.matmul() 支持高维批量矩阵乘(如 (b, m, k) @ (b, k, n) → (b, m, n)),并可在必要时对缺失的 batch 维度进行隐式广播(如将 (2,4) 视为 (1,2,4) 与 (4,2) 相乘)。但这种广播是 matmul 内部行为,不改变逐元素运算的广播规则

✅ 正确实践建议:

  • 需逐元素运算?先确保形状兼容或显式 unsqueeze()/expand();
  • 需矩阵乘法?直接用 @ 或 torch.matmul(),无需手动调整形状;
  • 调试时善用 .shape 和 torch.broadcast_shapes()(PyTorch 2.0+)验证广播可行性。

归根结底:广播不是“万能适配器”,而是有严格维度对齐规则的逐元素操作机制;而矩阵乘法是独立的、基于线性代数定义的运算——二者不可混为一谈。

文中关于的知识介绍,希望对你的学习有所帮助!若是受益匪浅,那就动动鼠标收藏这篇《PyTorch广播与矩阵乘法教程》文章吧,也可关注golang学习网公众号了解相关技术文章。

前往漫画官网入口并下载 ➜
相关阅读
更多>
最新阅读
更多>
课程推荐
更多>